Spaces:
Paused
Paused
import json

import gradio as gr
import pandas as pd
from groq import Groq
from pydantic import BaseModel, Field
class ValidationStatus(BaseModel):
    """Syntax validation result for a generated SQL query."""

    # NOTE: the class docstring and the Field descriptions below are emitted by
    # model_json_schema(), which is passed to the API as the response_format
    # schema, so they act as per-field instructions for the model.

    # Whether the generated SQL query is syntactically valid.
    is_valid: bool = Field(
        description="Whether the generated SQL query is syntactically valid."
    )
    # SQL syntax error messages; an empty list when no errors are found.
    syntax_errors: list[str] = Field(
        description="SQL syntax error messages; an empty list when no errors are found."
    )
class SQLQueryGeneration(BaseModel):
    """Structured result of translating a natural-language request into SQL,
    together with a sample schema, sample data and simulated execution results."""

    # NOTE: the class docstring and the Field descriptions are emitted by
    # model_json_schema() and sent to the API as the response_format schema,
    # so they are the instructions the model actually sees for each field.

    # The generated SQL statement, e.g.
    # SELECT product_id, name, price FROM products WHERE price < 50 ORDER BY price ASC
    query: str = Field(description="The generated SQL query.")
    # e.g. "SELECT"
    query_type: str = Field(
        description="The type of SQL query, e.g. SELECT, INSERT, UPDATE or DELETE."
    )
    # e.g. ["products"]
    tables_used: list[str] = Field(
        description="Names of the tables referenced in the SQL query."
    )
    # e.g. "LOW"
    estimated_complexity: str = Field(
        description="Estimated complexity of the query: LOW, MEDIUM or HIGH."
    )
    # e.g. ["Simple SELECT query on products table",
    #       "Filter products with price less than $50",
    #       "Order results by price ascending"]
    execution_notes: list[str] = Field(
        description="Notes describing how the query executes or any assumptions made."
    )
    # Validation results for the generated query. Left without a Field(...)
    # wrapper so the nested-model $ref in the schema keeps its plain shape.
    validation_status: ValidationStatus
    # CREATE TABLE statement for the sample table the query runs against.
    table_schema: str = Field(
        description="SQL CREATE TABLE statement describing the sample table schema."
    )
    # INSERT statements (or a table view) that populate the sample table.
    sample_data: str = Field(
        description="Sample data for the table, as INSERT statements or a table view."
    )
    # Parsed into a DataFrame by the UI, so the pipe-table format matters.
    execution_results: str = Field(
        description=(
            "Results of executing the SQL query against the sample data, formatted "
            "as a pipe-delimited table: a header row of column names separated by "
            "pipes (|), a separator row of dashes, then one row per result."
        )
    )
    # Suggestions such as indexes, join strategies or extra filters.
    optimization_notes: list[str] = Field(
        description="Suggestions for optimizing the SQL query (indexes, joins, filters, etc.)."
    )
# Characters that may make up a table rule line such as "----|----",
# "|:---|---:|" or "+----+----+".
_TABLE_RULE_CHARS = frozenset("-=+|: \t")


def _is_table_rule(line):
    """Return True if *line* is a horizontal rule/separator rather than data."""
    stripped = line.strip()
    # Require an actual dash/equals so a data row like "-5 | x" (which
    # contains other characters) is never mistaken for a separator.
    return (
        bool(stripped)
        and set(stripped) <= _TABLE_RULE_CHARS
        and ("-" in stripped or "=" in stripped)
    )


def _split_table_row(line, bordered):
    """Split one pipe-delimited row into whitespace-stripped cells.

    When the table is *bordered* (Markdown style ``| a | b |``), the outer
    pipes are removed first so they do not create empty edge columns.
    """
    text = line.strip()
    if bordered:
        text = text.removeprefix("|").removesuffix("|")
    return [cell.strip() for cell in text.split("|")]


def parse_execution_results_to_dataframe(execution_results):
    """Convert a text-based, pipe-delimited table into a pandas DataFrame.

    The first non-blank, non-rule line is the header; every following
    non-blank, non-rule line is a data row. Rule lines (dashes, with optional
    pipes, colons, plus signs) may appear anywhere and are ignored. Both bare
    (``a | b``) and bordered (``| a | b |``) tables are accepted.

    Args:
        execution_results: Table text, e.g. the model's ``execution_results``.

    Returns:
        A DataFrame whose columns are the header cells and whose rows are the
        data rows with the same number of cells as the header (other rows are
        dropped). Returns None when the input is empty, has no header plus at
        least one matching data row, or cannot be parsed.
    """
    try:
        if not execution_results:
            return None

        # Keep only meaningful lines: drop blanks and separator/rule lines.
        rows = [
            line.strip()
            for line in execution_results.splitlines()
            if line.strip() and not _is_table_rule(line)
        ]
        # Need a header and at least one data row.
        if len(rows) < 2:
            return None

        header_line, *data_lines = rows
        # Decide once, from the header, whether the table has outer pipes;
        # a bare row ending in "|" then still keeps its empty last cell.
        bordered = header_line.startswith("|") and header_line.endswith("|")
        headers = _split_table_row(header_line, bordered)

        # Only keep rows whose cell count matches the header.
        data_rows = [
            cells
            for cells in (_split_table_row(line, bordered) for line in data_lines)
            if len(cells) == len(headers)
        ]
        if not data_rows:
            return None

        return pd.DataFrame(data_rows, columns=headers)
    except Exception as e:
        # Best-effort parser: report and fall back to "no table".
        print(f"Error parsing results: {e}")
        return None
# System prompt sent with every request: asks the model for the SQL query plus
# a sample schema, sample data and pipe-delimited "execution" results.
_SYSTEM_PROMPT = """You are a SQL expert. Generate structured SQL queries from natural language descriptions with proper syntax validation and metadata.
After generating the SQL query, you must:
1. Create a sample SQL table schema based on the natural language description, including all necessary columns with appropriate data types
2. Populate the table with realistic sample data that demonstrates the query's functionality
3. Execute the generated SQL query against the sample table
4. Display the SQL table structure and data clearly
5. Show the query execution results in a pipe-delimited table format
IMPORTANT: The execution_results field must contain a properly formatted table with:
- Header row with column names separated by pipes (|)
- A separator row with dashes
- Data rows with values separated by pipes (|)
Example format:
column1 | column2 | column3
--------|---------|--------
value1 | value2 | value3
value4 | value5 | value6
Always present your response in this order:
- Generated SQL query with syntax explanation
- Table schema (CREATE TABLE statement)
- Sample data (INSERT statements or table visualization)
- Query execution results (in pipe-delimited table format)
- Any relevant notes about assumptions made or query optimization suggestions"""


def _bullet_list(items):
    """Render *items* as newline-separated "- item" lines."""
    return "\n".join(f"- {item}" for item in items)


def _format_validation(validation_status):
    """Render a ValidationStatus as the text shown in the 'Validation Status' box."""
    text = f"Valid: {validation_status.is_valid}\n"
    if validation_status.syntax_errors:
        return text + "Errors:\n" + _bullet_list(validation_status.syntax_errors)
    return text + "No syntax errors found"


def _format_metadata(result):
    """Render query type, tables, complexity and notes for the 'Query Metadata' box."""
    return (
        f"Query Type: {result.query_type}\n"
        f"Tables Used: {', '.join(result.tables_used)}\n"
        f"Complexity: {result.estimated_complexity}\n"
        "Execution Notes:\n"
        f"{_bullet_list(result.execution_notes)}\n"
        "Optimization Notes:\n"
        f"{_bullet_list(result.optimization_notes)}"
    )


def generate_sql_query(api_key, user_query):
    """Generate a SQL query from a natural-language description via the GROQ API.

    Args:
        api_key: GROQ API key entered by the user.
        user_query: Natural-language description of the desired query.

    Returns:
        A 6-tuple in the order of the Gradio outputs:
        (sql_query, metadata_text, table_schema, sample_data, results_df,
        validation_text). ``results_df`` is a DataFrame or None. On any
        failure the first element is an "Error: ..." message, ``results_df``
        is None and the remaining elements are empty strings.
    """
    # Gradio passes raw textbox contents: strip so whitespace-only input is
    # rejected and a key pasted with a trailing newline still authenticates.
    api_key = (api_key or "").strip()
    user_query = (user_query or "").strip()

    if not api_key:
        return "Error: Please enter your GROQ API key", "", "", "", None, ""
    if not user_query:
        return "Error: Please enter a query description", "", "", "", None, ""

    try:
        client = Groq(api_key=api_key)
        response = client.chat.completions.create(
            model="moonshotai/kimi-k2-instruct-0905",
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": user_query},
            ],
            # Constrain the reply to the SQLQueryGeneration JSON schema.
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": "sql_query_generation",
                    "schema": SQLQueryGeneration.model_json_schema(),
                },
            },
        )
        # Parse and validate the JSON reply in one step.
        result = SQLQueryGeneration.model_validate_json(
            response.choices[0].message.content
        )

        return (
            result.query,
            _format_metadata(result),
            result.table_schema,
            result.sample_data,
            # None when the model's table text cannot be parsed.
            parse_execution_results_to_dataframe(result.execution_results),
            _format_validation(result.validation_status),
        )
    except Exception as e:
        # UI boundary: surface any API/parsing failure in the SQL output box
        # instead of crashing the Gradio callback.
        return f"Error: {e}", "", "", "", None, ""
# ---------------------------------------------------------------------------
# Gradio interface
# Inputs: API key, natural-language query (with examples).
# Outputs: generated SQL, metadata, validation summary, table schema,
# sample data and the execution results table.
# ---------------------------------------------------------------------------
with gr.Blocks(title="SQL Query Generator", theme=gr.themes.Ocean()) as demo:
    gr.Markdown(
        """
# 🗄️ Natural Language to SQL Query Generator
Convert your natural language descriptions into structured SQL queries with validation and execution results.
"""
    )

    # Input section.
    with gr.Row():
        with gr.Column():
            api_key_input = gr.Textbox(
                label="GROQ API Key",
                type="password",
                placeholder="Enter your GROQ API key here...",
                info="Your API key is not stored and only used for this session",
            )
            query_input = gr.Textbox(
                label="Natural Language Query",
                placeholder="e.g., Find all the students who scored more than 90 out of 100",
                lines=3,
                value="Find all the students who scored more than 90 out of 100",
            )
            generate_btn = gr.Button("Generate SQL Query", variant="primary", size="lg")
            # Clicking an example fills query_input.
            gr.Examples(
                examples=[
                    ["Find all the students who scored more than 90 out of 100"],
                    ["Get the top 5 customers by total purchase amount"],
                    ["List all employees hired in the last 6 months"],
                    ["Find products with price between $50 and $100"],
                    ["Show average salary by department"],
                ],
                inputs=query_input,
                label="Example Queries",
            )

    # Generated query, its metadata and validation summary.
    with gr.Row():
        with gr.Column():
            sql_output = gr.Code(
                label="Generated SQL Query",
                language="sql",
                lines=5,
            )
            metadata_output = gr.Textbox(
                label="Query Metadata",
                lines=8,
            )
            validation_output = gr.Textbox(
                label="Validation Status",
                lines=3,
            )

    # Sample schema and sample data side by side.
    with gr.Row():
        with gr.Column():
            schema_output = gr.Code(
                label="Table Schema",
                language="sql",
                lines=8,
            )
        with gr.Column():
            sample_data_output = gr.Code(
                label="Sample Data",
                language="sql",
                lines=8,
            )

    # Execution results; headers/columns come from the returned DataFrame.
    with gr.Row():
        execution_output = gr.Dataframe(
            label="📊 Execution Results",
            headers=None,
            datatype="str",
            row_count=10,
            col_count=None,
            wrap=True,
            interactive=False,
        )

    # Output order must match the tuple returned by generate_sql_query.
    generate_btn.click(
        fn=generate_sql_query,
        inputs=[api_key_input, query_input],
        outputs=[
            sql_output,
            metadata_output,
            schema_output,
            sample_data_output,
            execution_output,
            validation_output,
        ],
    )

if __name__ == "__main__":
    demo.launch(share=True)