Spaces:
Build error
Build error
| import numpy as np | |
| import pandas as pd | |
| from sklearn.datasets import make_classification | |
| from sklearn.ensemble import IsolationForest | |
| from sklearn.metrics import roc_curve, auc | |
| import shap | |
| import matplotlib.pyplot as plt | |
| import gradio as gr | |
# Synthetic dataset: "inlier" points from make_classification followed by
# uniformly scattered points injected as ground-truth outliers.
N_INLIERS = 500
N_OUTLIERS = 50
N_FEATURES = 20

np.random.seed(42)  # seeds the global RNG used for the injected outliers below
X_inliers, _ = make_classification(
    n_samples=N_INLIERS,
    n_features=N_FEATURES,
    n_informative=10,
    n_redundant=5,
    n_clusters_per_class=1,
    random_state=42,
)
outliers = np.random.uniform(low=-6, high=6, size=(N_OUTLIERS, N_FEATURES))
# Row order matters: inliers first, then outliers (used for true_labels below).
X = np.vstack([X_inliers, outliers])

# Convert to DataFrame
columns = [f"Feature{i + 1}" for i in range(N_FEATURES)]
df = pd.DataFrame(X, columns=columns)

# Fit Isolation Forest
iso_forest = IsolationForest(
    n_estimators=100,
    max_samples=256,
    contamination=0.1,
    random_state=42,
)
iso_forest.fit(df)

# decision_function: negative values indicate anomalies.
anomaly_scores = iso_forest.decision_function(df)
# predict: -1 for anomaly, 1 for normal.
anomaly_labels = iso_forest.predict(df)

# Add results to DataFrame
df["Anomaly_Score"] = anomaly_scores
df["Anomaly_Label"] = np.where(anomaly_labels == -1, "Anomaly", "Normal")

# Ground-truth labels for the ROC curve (1 = injected outlier, 0 = inlier).
# These must come from how the data was generated, not from the model's own
# predictions: deriving them from Anomaly_Label would evaluate the model
# against itself and make the ROC/AUC meaningless.
true_labels = np.concatenate(
    [np.zeros(N_INLIERS, dtype=int), np.ones(N_OUTLIERS, dtype=int)]
)

# SHAP explainability over the model's input features only; the score/label
# columns added above are model outputs, not inputs.
explainer = shap.Explainer(iso_forest, df[columns])
shap_values = explainer(df[columns])

# Define functions for Gradio
def get_shap_summary():
    """Render the global SHAP summary plot and return the saved image path.

    Uses the module-level ``shap_values`` together with the feature columns
    of ``df``.

    Returns:
        str: Path of the written PNG, ``"shap_summary.png"``.
    """
    fig = plt.figure()
    try:
        # show=False stops shap from calling plt.show(); it draws on the
        # current figure, which is the one created above.
        shap.summary_plot(shap_values, df[columns], feature_names=columns, show=False)
        # A tight bounding box keeps the long y-axis feature names from being clipped.
        plt.savefig("shap_summary.png", bbox_inches="tight")
    finally:
        # This runs on every button click; close the figure so open figures
        # don't accumulate in pyplot's global state.
        plt.close(fig)
    return "shap_summary.png"
def get_shap_waterfall(index):
    """Render a SHAP waterfall plot for a single data point.

    Args:
        index: Row position in ``df`` as delivered by the Gradio Number input
            (may be a float or ``None``); it is converted with ``int()``.

    Returns:
        str: Path of the written PNG, ``"shap_waterfall.png"``.

    Raises:
        gr.Error: If no index is given or it is outside ``[0, n_rows)``.
    """
    n_rows = shap_values.values.shape[0]
    if index is None:
        raise gr.Error("Please enter a data point index.")
    specific_index = int(index)
    if not 0 <= specific_index < n_rows:
        raise gr.Error(f"Index must be between 0 and {n_rows - 1}.")

    fig = plt.figure()
    try:
        # Indexing the Explanation yields this row's values, base value,
        # feature data and feature names together. The data is restricted to
        # the model's input features, so df's Anomaly_Score/Anomaly_Label
        # columns cannot leak into the plot.
        shap.waterfall_plot(shap_values[specific_index], show=False)
        plt.savefig("shap_waterfall.png", bbox_inches="tight")
    finally:
        # Close the figure so repeated clicks don't leak pyplot figures.
        plt.close(fig)
    return "shap_waterfall.png"
def get_scatter_plot(feature1, feature2):
    """Scatter two features against each other, coloured by anomaly label.

    Args:
        feature1: Feature column of ``df`` to put on the x-axis.
        feature2: Feature column of ``df`` to put on the y-axis.

    Returns:
        str: Path of the written PNG, ``"scatter_plot.png"``.

    Raises:
        gr.Error: If either argument is not one of the feature columns
            (e.g. ``None`` when a dropdown has no selection).
    """
    if feature1 not in columns or feature2 not in columns:
        raise gr.Error("Please select two features to plot.")

    is_anomaly = df["Anomaly_Label"] == "Anomaly"
    # Use the same endpoint colours the coolwarm colour map gives a boolean
    # mask (blue = normal, red = anomaly), but draw each group separately so
    # the plot can carry a legend. Anomalies are drawn last, so they sit on top.
    groups = (
        ("Normal", ~is_anomaly, plt.cm.coolwarm(0.0)),
        ("Anomaly", is_anomaly, plt.cm.coolwarm(1.0)),
    )
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        for label, mask, colour in groups:
            ax.scatter(
                df.loc[mask, feature1],
                df.loc[mask, feature2],
                color=colour,
                edgecolor="k",
                alpha=0.7,
                label=label,
            )
        ax.set_title(f"Isolation Forest - {feature1} vs {feature2}")
        ax.set_xlabel(feature1)
        ax.set_ylabel(feature2)
        ax.legend()
        fig.savefig("scatter_plot.png")
    finally:
        # Close the figure so repeated clicks don't leak pyplot figures.
        plt.close(fig)
    return "scatter_plot.png"
def get_roc_curve():
    """Plot the ROC curve for the Isolation Forest scores.

    Compares the module-level ``true_labels`` (1 = anomaly, 0 = normal) with
    the negated ``Anomaly_Score``. ``decision_function`` is higher for normal
    points, so negating it makes higher values mean "more anomalous", which
    is the orientation ``roc_curve`` expects for the positive class.

    Returns:
        str: Path of the written PNG, ``"roc_curve.png"``.
    """
    fpr, tpr, _ = roc_curve(true_labels, -df["Anomaly_Score"])
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.plot(fpr, tpr, label=f"ROC Curve (AUC = {roc_auc:.2f})")
        ax.plot([0, 1], [0, 1], "k--", label="Random Guess")
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.set_title("Receiver Operating Characteristic (ROC) Curve")
        ax.legend(loc="lower right")
        ax.grid()
        fig.savefig("roc_curve.png")
    finally:
        # Close the figure so repeated clicks don't leak pyplot figures.
        plt.close(fig)
    return "roc_curve.png"
def get_anomaly_samples(data=None, n=10):
    """Return three tables of records ranked by Isolation Forest anomaly score.

    ``Anomaly_Score`` comes from ``decision_function``, where lower means more
    anomalous.

    Args:
        data: DataFrame with ``Anomaly_Score`` and ``Anomaly_Label`` columns.
            Defaults to the module-level ``df``.
        n: Number of records per table (default 10).

    Returns:
        tuple of three DataFrames:
            * top: the ``n`` most anomalous "Anomaly" records;
            * middle: a mix straddling the decision threshold, made of the
              ``n // 2`` least anomalous "Anomaly" records and the
              ``n - n // 2`` most anomalous "Normal" records;
            * bottom: the ``n`` most normal "Normal" records.
        Each table is ordered from most to least anomalous. A table may hold
        fewer than ``n`` rows when a label group is too small.

    Raises:
        ValueError: If ``n`` is negative.
    """
    if data is None:
        data = df
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")

    # Ascending score puts the most anomalous first; a stable sort keeps the
    # order of tied scores deterministic.
    ranked = data.sort_values("Anomaly_Score", kind="stable")
    anomalies = ranked[ranked["Anomaly_Label"] == "Anomaly"]
    normals = ranked[ranked["Anomaly_Label"] == "Normal"]

    top = anomalies.head(n)

    # The centre of the score distribution is entirely "Normal" (with
    # contamination=0.1 only the lowest ~10% are labelled anomalies), so a
    # genuinely mixed sample has to come from around the threshold.
    # head/tail also never raise when a group is smaller than requested,
    # unlike DataFrame.sample(n=...).
    n_anomalous = n // 2
    middle = pd.concat(
        [anomalies.tail(n_anomalous), normals.head(n - n_anomalous)]
    ).sort_values("Anomaly_Score", kind="stable")

    bottom = normals.tail(n)
    return top, middle, bottom
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Isolation Forest Anomaly Detection")

    with gr.Tab("SHAP Summary"):
        gr.Markdown("### Global Explainability: SHAP Summary Plot")
        shap_button = gr.Button("Generate SHAP Summary Plot")
        shap_image = gr.Image()
        shap_button.click(get_shap_summary, outputs=shap_image)

    with gr.Tab("SHAP Waterfall"):
        gr.Markdown("### Local Explainability: SHAP Waterfall Plot")
        # precision=0 makes the component deliver an integer row index.
        index_input = gr.Number(label="Data Point Index", value=0, precision=0)
        shap_waterfall_button = gr.Button("Generate SHAP Waterfall Plot")
        shap_waterfall_image = gr.Image()
        shap_waterfall_button.click(
            get_shap_waterfall, inputs=index_input, outputs=shap_waterfall_image
        )

    with gr.Tab("Feature Scatter Plot"):
        gr.Markdown("### Feature Interaction: Scatter Plot")
        # Preselect two different features. Without a default value the
        # dropdowns send None, and plotting fails until the user picks both.
        feature1_dropdown = gr.Dropdown(choices=columns, value=columns[0], label="Feature 1")
        feature2_dropdown = gr.Dropdown(choices=columns, value=columns[1], label="Feature 2")
        scatter_button = gr.Button("Generate Scatter Plot")
        scatter_image = gr.Image()
        scatter_button.click(
            get_scatter_plot,
            inputs=[feature1_dropdown, feature2_dropdown],
            outputs=scatter_image,
        )

    with gr.Tab("Anomaly Samples"):
        gr.HTML("<h3 style='text-align: center; font-size: 18px; font-weight: bold;'>Top 10 Records (Anomalies)</h3>")
        top_table = gr.Dataframe(label="Top 10 Records")
        gr.HTML("<h3 style='text-align: center; font-size: 18px; font-weight: bold;'>Middle 10 Records (Mixed)</h3>")
        middle_table = gr.Dataframe(label="Middle 10 Records")
        gr.HTML("<h3 style='text-align: center; font-size: 18px; font-weight: bold;'>Bottom 10 Records (Normal)</h3>")
        bottom_table = gr.Dataframe(label="Bottom 10 Records")
        anomaly_samples_button = gr.Button("Show Anomaly Samples")
        anomaly_samples_button.click(
            get_anomaly_samples,
            outputs=[top_table, middle_table, bottom_table],
        )

    with gr.Tab("ROC Curve"):
        gr.Markdown("### ROC Curve for Isolation Forest")
        roc_button = gr.Button("Generate ROC Curve")
        roc_image = gr.Image()
        roc_button.click(get_roc_curve, outputs=roc_image)

# Launch the Gradio app only when run as a script, so importing this module
# does not start a server.
if __name__ == "__main__":
    demo.launch()