Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- Dockerfile +37 -0
- README.md +85 -11
- backend/.pytest_cache/.gitignore +2 -0
- backend/.pytest_cache/CACHEDIR.TAG +4 -0
- backend/.pytest_cache/README.md +8 -0
- backend/.pytest_cache/v/cache/lastfailed +8 -0
- backend/.pytest_cache/v/cache/nodeids +74 -0
- backend/.ruff_cache/.gitignore +2 -0
- backend/.ruff_cache/0.14.6/18015614173546374012 +0 -0
- backend/.ruff_cache/CACHEDIR.TAG +1 -0
- backend/app/__init__.py +4 -0
- backend/app/api/__init__.py +1 -0
- backend/app/api/cache.py +45 -0
- backend/app/automata/__init__.py +16 -0
- backend/app/automata/ast_fixer.py +196 -0
- backend/app/automata/base.py +78 -0
- backend/app/automata/formatter.py +86 -0
- backend/app/automata/linter.py +139 -0
- backend/app/automata/runtime_fixer.py +297 -0
- backend/app/automata/test_generator.py +161 -0
- backend/app/automata/trace_parser.py +177 -0
- backend/app/config.py +91 -0
- backend/app/core/__init__.py +1 -0
- backend/app/core/automata_manager.py +73 -0
- backend/app/core/distillation.py +92 -0
- backend/app/core/lifecycle.py +97 -0
- backend/app/core/model_cache.py +240 -0
- backend/app/core/orchestrator.py +695 -0
- backend/app/core/orchestrator_decomposition.py +193 -0
- backend/app/core/pipeline.py +42 -0
- backend/app/core/rag.py +124 -0
- backend/app/core/router.py +100 -0
- backend/app/core/router_v2.py +174 -0
- backend/app/core/slm_registry.py +120 -0
- backend/app/core/task_decomposer.py +309 -0
- backend/app/engines/__init__.py +10 -0
- backend/app/engines/base.py +279 -0
- backend/app/engines/codet5.py +180 -0
- backend/app/engines/groq_engine.py +228 -0
- backend/app/engines/micro_slm.py +135 -0
- backend/app/engines/phi2.py +191 -0
- backend/app/engines/starcoder.py +212 -0
- backend/app/locales/en.json +124 -0
- backend/app/locales/fr.json +124 -0
- backend/app/main.py +265 -0
- backend/app/models/__init__.py +1 -0
- backend/app/models/schemas.py +154 -0
- backend/app/rag/__init__.py +10 -0
- backend/app/rag/embedder.py +95 -0
- backend/app/rag/retriever.py +215 -0
Dockerfile
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Installer les dépendances système
|
| 6 |
+
RUN apt-get update && apt-get install -y \
|
| 7 |
+
git \
|
| 8 |
+
curl \
|
| 9 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 10 |
+
|
| 11 |
+
# Copier requirements
|
| 12 |
+
COPY backend/requirements.txt .
|
| 13 |
+
|
| 14 |
+
# Installer les dépendances Python
|
| 15 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 16 |
+
|
| 17 |
+
# Copier le code
|
| 18 |
+
COPY backend/ ./backend/
|
| 19 |
+
COPY data/ ./data/
|
| 20 |
+
|
| 21 |
+
# Créer les répertoires nécessaires
|
| 22 |
+
RUN mkdir -p logs
|
| 23 |
+
|
| 24 |
+
# Exposer le port (Hugging Face Spaces utilise 7860)
|
| 25 |
+
EXPOSE 7860
|
| 26 |
+
|
| 27 |
+
# Variables d'environnement
|
| 28 |
+
ENV PYTHONUNBUFFERED=1
|
| 29 |
+
ENV HOST=0.0.0.0
|
| 30 |
+
ENV PORT=7860
|
| 31 |
+
|
| 32 |
+
# Healthcheck
|
| 33 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
|
| 34 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 35 |
+
|
| 36 |
+
# Lancer le serveur
|
| 37 |
+
CMD ["uvicorn", "backend.app.main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
|
README.md
CHANGED
|
@@ -1,11 +1,85 @@
|
|
| 1 |
-
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom: blue
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk: docker
|
| 7 |
-
pinned: false
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: SLM Code Engine
|
| 3 |
+
emoji: 🤖
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# 🤖 SLM Code Engine
|
| 11 |
+
|
| 12 |
+
Moteur de code intelligent avec Micro-SLMs spécialisés pour la génération de code.
|
| 13 |
+
|
| 14 |
+
## 🚀 Utilisation
|
| 15 |
+
|
| 16 |
+
### API Endpoint
|
| 17 |
+
|
| 18 |
+
```
|
| 19 |
+
POST /api/v1/query
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
### Exemple de requête
|
| 23 |
+
|
| 24 |
+
```bash
|
| 25 |
+
curl -X POST https://YOUR-USERNAME-slm-code-engine.hf.space/api/v1/query \
|
| 26 |
+
-H "Content-Type: application/json" \
|
| 27 |
+
-d '{
|
| 28 |
+
"task": "boilerplate",
|
| 29 |
+
"code": "",
|
| 30 |
+
"language": "python",
|
| 31 |
+
"context": "Génère une fonction pour calculer la moyenne"
|
| 32 |
+
}'
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
### Réponse
|
| 36 |
+
|
| 37 |
+
```json
|
| 38 |
+
{
|
| 39 |
+
"success": true,
|
| 40 |
+
"result": "def calculer_moyenne(nombres):\n return sum(nombres) / len(nombres)",
|
| 41 |
+
"explanation": "Fonction pour calculer la moyenne d'une liste",
|
| 42 |
+
"used_slm": true,
|
| 43 |
+
"total_duration_ms": 2500
|
| 44 |
+
}
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
## 🧠 Modèles disponibles
|
| 48 |
+
|
| 49 |
+
- **boilerplate_slm** : Génération de code boilerplate Python (Phi-2 fine-tuné)
|
| 50 |
+
- **Groq API** : Fallback pour tâches complexes (Llama 3.3 70B)
|
| 51 |
+
|
| 52 |
+
## 📊 Endpoints
|
| 53 |
+
|
| 54 |
+
| Endpoint | Méthode | Description |
|
| 55 |
+
|----------|---------|-------------|
|
| 56 |
+
| `/health` | GET | Vérifier le statut du serveur |
|
| 57 |
+
| `/api/v1/query` | POST | Générer du code |
|
| 58 |
+
| `/cache/stats` | GET | Statistiques du cache de modèles |
|
| 59 |
+
|
| 60 |
+
## 🔧 Configuration
|
| 61 |
+
|
| 62 |
+
Le système utilise :
|
| 63 |
+
- **Routeur intelligent** : Sélectionne automatiquement le meilleur modèle
|
| 64 |
+
- **Cache LRU** : Garde les modèles en mémoire pour des réponses rapides
|
| 65 |
+
- **Automates** : Formatage et linting automatiques
|
| 66 |
+
|
| 67 |
+
## 📈 Performance
|
| 68 |
+
|
| 69 |
+
- **Micro-SLM** : ~2-5s par requête
|
| 70 |
+
- **Groq API** : ~1-3s par requête
|
| 71 |
+
- **Cache hit** : ~0.1s par requête
|
| 72 |
+
|
| 73 |
+
## 🛠️ Technologies
|
| 74 |
+
|
| 75 |
+
- **Backend** : FastAPI + Uvicorn
|
| 76 |
+
- **Modèles** : Phi-2 (2.7B), Llama 3.3 (70B via Groq)
|
| 77 |
+
- **Framework** : Transformers, PEFT, PyTorch
|
| 78 |
+
|
| 79 |
+
## 📝 License
|
| 80 |
+
|
| 81 |
+
Apache 2.0
|
| 82 |
+
|
| 83 |
+
---
|
| 84 |
+
|
| 85 |
+
Développé avec ❤️ pour la communauté des développeurs
|
backend/.pytest_cache/.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Created by pytest automatically.
|
| 2 |
+
*
|
backend/.pytest_cache/CACHEDIR.TAG
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Signature: 8a477f597d28d172789f06886806bc55
|
| 2 |
+
# This file is a cache directory tag created by pytest.
|
| 3 |
+
# For information about cache directory tags, see:
|
| 4 |
+
# https://bford.info/cachedir/spec.html
|
backend/.pytest_cache/README.md
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pytest cache directory #
|
| 2 |
+
|
| 3 |
+
This directory contains data from the pytest's cache plugin,
|
| 4 |
+
which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
|
| 5 |
+
|
| 6 |
+
**Do not** commit this to version control.
|
| 7 |
+
|
| 8 |
+
See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
|
backend/.pytest_cache/v/cache/lastfailed
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"tests/test_automata.py::TestPythonLinter::test_can_handle_lint_task": true,
|
| 3 |
+
"tests/test_automata_unit.py::TestPythonFormatter::test_format_already_formatted": true,
|
| 4 |
+
"tests/test_automata_unit.py::TestTestTemplateGenerator::test_can_handle_test_task": true,
|
| 5 |
+
"tests/test_automata_unit.py::TestTestTemplateGenerator::test_generate_template": true,
|
| 6 |
+
"tests/test_orchestrator.py::test_orchestrator_valid_code_no_changes": true,
|
| 7 |
+
"tests/test_code_validators.py::TestExecutionValidator::test_timeout": true
|
| 8 |
+
}
|
backend/.pytest_cache/v/cache/nodeids
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
"tests/test_api.py::test_health_endpoint",
|
| 3 |
+
"tests/test_api.py::test_query_fix_endpoint",
|
| 4 |
+
"tests/test_api.py::test_query_format_endpoint",
|
| 5 |
+
"tests/test_api.py::test_query_invalid_task",
|
| 6 |
+
"tests/test_api.py::test_query_missing_code",
|
| 7 |
+
"tests/test_api.py::test_query_with_context",
|
| 8 |
+
"tests/test_api.py::test_query_with_trace",
|
| 9 |
+
"tests/test_api.py::test_stats_endpoint",
|
| 10 |
+
"tests/test_automata.py::TestASTFixer::test_can_handle_fix_task",
|
| 11 |
+
"tests/test_automata.py::TestASTFixer::test_fix_missing_colon",
|
| 12 |
+
"tests/test_automata.py::TestASTFixer::test_fix_multiple_errors",
|
| 13 |
+
"tests/test_automata.py::TestASTFixer::test_valid_code_unchanged",
|
| 14 |
+
"tests/test_automata.py::TestPythonFormatter::test_can_handle_python_format",
|
| 15 |
+
"tests/test_automata.py::TestPythonFormatter::test_cannot_handle_other_language",
|
| 16 |
+
"tests/test_automata.py::TestPythonFormatter::test_format_execution",
|
| 17 |
+
"tests/test_automata.py::TestPythonFormatter::test_format_invalid_syntax",
|
| 18 |
+
"tests/test_automata.py::TestPythonLinter::test_can_handle_format_task",
|
| 19 |
+
"tests/test_automata.py::TestPythonLinter::test_can_handle_lint_task",
|
| 20 |
+
"tests/test_automata.py::TestPythonLinter::test_lint_clean_code",
|
| 21 |
+
"tests/test_automata.py::TestTestTemplateGenerator::test_can_handle_test_task",
|
| 22 |
+
"tests/test_automata.py::TestTestTemplateGenerator::test_generate_class_tests",
|
| 23 |
+
"tests/test_automata.py::TestTestTemplateGenerator::test_generate_function_tests",
|
| 24 |
+
"tests/test_automata_unit.py::TestASTFixer::test_can_handle_python_fix",
|
| 25 |
+
"tests/test_automata_unit.py::TestASTFixer::test_fix_missing_colon",
|
| 26 |
+
"tests/test_automata_unit.py::TestASTFixer::test_fix_missing_colon_if",
|
| 27 |
+
"tests/test_automata_unit.py::TestASTFixer::test_no_changes_needed",
|
| 28 |
+
"tests/test_automata_unit.py::TestPythonFormatter::test_can_handle_python",
|
| 29 |
+
"tests/test_automata_unit.py::TestPythonFormatter::test_cannot_handle_other_languages",
|
| 30 |
+
"tests/test_automata_unit.py::TestPythonFormatter::test_cannot_handle_other_tasks",
|
| 31 |
+
"tests/test_automata_unit.py::TestPythonFormatter::test_format_already_formatted",
|
| 32 |
+
"tests/test_automata_unit.py::TestPythonFormatter::test_format_messy_code",
|
| 33 |
+
"tests/test_automata_unit.py::TestPythonLinter::test_can_handle_python",
|
| 34 |
+
"tests/test_automata_unit.py::TestPythonLinter::test_lint_code",
|
| 35 |
+
"tests/test_automata_unit.py::TestRuntimeFixer::test_fix_index_error",
|
| 36 |
+
"tests/test_automata_unit.py::TestRuntimeFixer::test_fix_zero_division",
|
| 37 |
+
"tests/test_automata_unit.py::TestTestTemplateGenerator::test_can_handle_test_task",
|
| 38 |
+
"tests/test_automata_unit.py::TestTestTemplateGenerator::test_generate_template",
|
| 39 |
+
"tests/test_automata_unit.py::TestTraceParser::test_can_handle_explain_with_trace",
|
| 40 |
+
"tests/test_automata_unit.py::TestTraceParser::test_parse_python_traceback",
|
| 41 |
+
"tests/test_automata_unit.py::TestTraceParser::test_parse_syntax_error",
|
| 42 |
+
"tests/test_code_validators.py::TestCompositeValidator::test_all_validators",
|
| 43 |
+
"tests/test_code_validators.py::TestCompositeValidator::test_overall_score",
|
| 44 |
+
"tests/test_code_validators.py::TestCompositeValidator::test_syntax_failure_stops_execution",
|
| 45 |
+
"tests/test_code_validators.py::TestExecutionValidator::test_runtime_error",
|
| 46 |
+
"tests/test_code_validators.py::TestExecutionValidator::test_successful_execution",
|
| 47 |
+
"tests/test_code_validators.py::TestExecutionValidator::test_timeout",
|
| 48 |
+
"tests/test_code_validators.py::TestQualityValidator::test_high_quality_code",
|
| 49 |
+
"tests/test_code_validators.py::TestQualityValidator::test_low_quality_code",
|
| 50 |
+
"tests/test_code_validators.py::TestSyntaxValidator::test_indentation_error",
|
| 51 |
+
"tests/test_code_validators.py::TestSyntaxValidator::test_invalid_syntax",
|
| 52 |
+
"tests/test_code_validators.py::TestSyntaxValidator::test_valid_syntax",
|
| 53 |
+
"tests/test_code_validators.py::TestTestValidator::test_failing_tests",
|
| 54 |
+
"tests/test_code_validators.py::TestTestValidator::test_passing_tests",
|
| 55 |
+
"tests/test_orchestrator.py::test_orchestrator_fix_via_automata",
|
| 56 |
+
"tests/test_orchestrator.py::test_orchestrator_format_via_automata",
|
| 57 |
+
"tests/test_orchestrator.py::test_orchestrator_performance",
|
| 58 |
+
"tests/test_orchestrator.py::test_orchestrator_pipeline_tracking",
|
| 59 |
+
"tests/test_orchestrator.py::test_orchestrator_valid_code_no_changes",
|
| 60 |
+
"tests/test_orchestrator_unit.py::TestOrchestratorInit::test_automata_loaded",
|
| 61 |
+
"tests/test_orchestrator_unit.py::TestOrchestratorInit::test_initialization",
|
| 62 |
+
"tests/test_orchestrator_unit.py::TestOrchestratorPipeline::test_duration_tracking",
|
| 63 |
+
"tests/test_orchestrator_unit.py::TestOrchestratorPipeline::test_pipeline_records_steps",
|
| 64 |
+
"tests/test_orchestrator_unit.py::TestOrchestratorRouting::test_fix_tries_automata_first",
|
| 65 |
+
"tests/test_orchestrator_unit.py::TestOrchestratorRouting::test_format_uses_automata",
|
| 66 |
+
"tests/test_orchestrator_unit.py::TestOrchestratorShutdown::test_shutdown",
|
| 67 |
+
"tests/test_orchestrator_unit.py::TestOrchestratorStatus::test_get_status",
|
| 68 |
+
"tests/test_orchestrator_unit.py::TestOrchestratorStatus::test_status_before_init",
|
| 69 |
+
"tests/test_router.py::test_router_boilerplate_task",
|
| 70 |
+
"tests/test_router.py::test_router_explain_task",
|
| 71 |
+
"tests/test_router.py::test_router_fix_task",
|
| 72 |
+
"tests/test_router.py::test_router_format_task",
|
| 73 |
+
"tests/test_router.py::test_router_test_task"
|
| 74 |
+
]
|
backend/.ruff_cache/.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Automatically created by ruff.
|
| 2 |
+
*
|
backend/.ruff_cache/0.14.6/18015614173546374012
ADDED
|
Binary file (387 Bytes). View file
|
|
|
backend/.ruff_cache/CACHEDIR.TAG
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Signature: 8a477f597d28d172789f06886806bc55
|
backend/app/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SLM Code Engine - Main application package
|
| 3 |
+
"""
|
| 4 |
+
__version__ = "0.1.0"
|
backend/app/api/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""API package"""
|
backend/app/api/cache.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Cache Statistics Endpoint
|
| 3 |
+
|
| 4 |
+
Provides real-time statistics about the model cache.
|
| 5 |
+
"""
|
| 6 |
+
from fastapi import APIRouter
|
| 7 |
+
from app.core.model_cache import model_cache
|
| 8 |
+
|
| 9 |
+
router = APIRouter(prefix="/cache", tags=["cache"])
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@router.get("/stats")
|
| 13 |
+
async def get_cache_stats():
|
| 14 |
+
"""Get model cache statistics"""
|
| 15 |
+
return model_cache.get_stats()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@router.post("/clear")
|
| 19 |
+
async def clear_cache():
|
| 20 |
+
"""Clear all cached models"""
|
| 21 |
+
await model_cache.clear()
|
| 22 |
+
return {"message": "Cache cleared successfully"}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@router.post("/preload/{model_name}")
|
| 26 |
+
async def preload_model(model_name: str):
|
| 27 |
+
"""Preload a model into cache"""
|
| 28 |
+
from app.core.slm_registry import slm_registry
|
| 29 |
+
from app.engines.micro_slm import MicroSLMEngine
|
| 30 |
+
|
| 31 |
+
micro_slm_info = slm_registry.get_model(model_name)
|
| 32 |
+
if not micro_slm_info:
|
| 33 |
+
return {"error": f"Model {model_name} not found in registry"}
|
| 34 |
+
|
| 35 |
+
async def load_micro_slm():
|
| 36 |
+
engine = MicroSLMEngine(
|
| 37 |
+
name=model_name,
|
| 38 |
+
model_path=micro_slm_info.model_path
|
| 39 |
+
)
|
| 40 |
+
await engine.initialize()
|
| 41 |
+
return engine
|
| 42 |
+
|
| 43 |
+
await model_cache.preload(model_name, load_micro_slm)
|
| 44 |
+
|
| 45 |
+
return {"message": f"Model {model_name} preloaded successfully"}
|
backend/app/automata/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Automata package"""
|
| 2 |
+
from app.automata.base import BaseAutomaton
|
| 3 |
+
from app.automata.formatter import PythonFormatter
|
| 4 |
+
from app.automata.linter import PythonLinter
|
| 5 |
+
from app.automata.trace_parser import TraceParser
|
| 6 |
+
from app.automata.ast_fixer import ASTFixer
|
| 7 |
+
from app.automata.test_generator import TestTemplateGenerator
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"BaseAutomaton",
|
| 11 |
+
"PythonFormatter",
|
| 12 |
+
"PythonLinter",
|
| 13 |
+
"TraceParser",
|
| 14 |
+
"ASTFixer",
|
| 15 |
+
"TestTemplateGenerator"
|
| 16 |
+
]
|
backend/app/automata/ast_fixer.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AST-based code fixer for simple syntax errors
|
| 3 |
+
|
| 4 |
+
Uses Python's AST module to detect and fix common issues:
|
| 5 |
+
- Indentation errors
|
| 6 |
+
- Missing colons
|
| 7 |
+
- Simple syntax errors
|
| 8 |
+
"""
|
| 9 |
+
import ast
|
| 10 |
+
import logging
|
| 11 |
+
from typing import Dict, Any, Optional
|
| 12 |
+
|
| 13 |
+
from app.automata.base import BaseAutomaton
|
| 14 |
+
from app.models.schemas import TaskType, Language
|
| 15 |
+
from app.utils.localization import get_string
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ASTFixer(BaseAutomaton):
|
| 21 |
+
"""Fixes simple Python syntax errors using AST analysis"""
|
| 22 |
+
|
| 23 |
+
def __init__(self):
|
| 24 |
+
super().__init__("ast_fixer")
|
| 25 |
+
|
| 26 |
+
def can_handle(
|
| 27 |
+
self,
|
| 28 |
+
code: str,
|
| 29 |
+
language: Language,
|
| 30 |
+
task: TaskType
|
| 31 |
+
) -> bool:
|
| 32 |
+
"""Check if can fix this code"""
|
| 33 |
+
return (
|
| 34 |
+
language == Language.PYTHON
|
| 35 |
+
and task == TaskType.FIX
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
async def execute(
|
| 39 |
+
self,
|
| 40 |
+
code: str,
|
| 41 |
+
**kwargs
|
| 42 |
+
) -> Dict[str, Any]:
|
| 43 |
+
"""Try to fix simple syntax errors"""
|
| 44 |
+
try:
|
| 45 |
+
# First, try to parse as-is
|
| 46 |
+
ast.parse(code)
|
| 47 |
+
|
| 48 |
+
# No syntax errors
|
| 49 |
+
return self._format_result(
|
| 50 |
+
success=True,
|
| 51 |
+
result=code,
|
| 52 |
+
explanation=get_string("ast_fixer_no_errors"),
|
| 53 |
+
suggestions=[]
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
except SyntaxError as e:
|
| 57 |
+
# Try multiple passes to fix errors (up to 5)
|
| 58 |
+
current_code = code
|
| 59 |
+
fixes_applied = []
|
| 60 |
+
max_attempts = 5
|
| 61 |
+
|
| 62 |
+
for attempt in range(max_attempts):
|
| 63 |
+
try:
|
| 64 |
+
# Try to parse current code
|
| 65 |
+
ast.parse(current_code)
|
| 66 |
+
# Success! All errors fixed
|
| 67 |
+
return self._format_result(
|
| 68 |
+
success=True,
|
| 69 |
+
result=current_code,
|
| 70 |
+
explanation=get_string(
|
| 71 |
+
"ast_fixer_fixed_issues",
|
| 72 |
+
issue_count=len(fixes_applied),
|
| 73 |
+
issues=', '.join(fixes_applied)
|
| 74 |
+
),
|
| 75 |
+
suggestions=[get_string("ast_fixer_suggestion_linter")]
|
| 76 |
+
)
|
| 77 |
+
except SyntaxError as error:
|
| 78 |
+
# Try to fix this error
|
| 79 |
+
fixed_code, explanation = self._attempt_fix(current_code, error)
|
| 80 |
+
|
| 81 |
+
if fixed_code and fixed_code != current_code:
|
| 82 |
+
current_code = fixed_code
|
| 83 |
+
fixes_applied.append(explanation)
|
| 84 |
+
else:
|
| 85 |
+
# Can't fix this error
|
| 86 |
+
break
|
| 87 |
+
|
| 88 |
+
# Check if we made any progress
|
| 89 |
+
if fixes_applied:
|
| 90 |
+
# Some fixes worked - return False to trigger SLM fallback
|
| 91 |
+
return self._format_result(
|
| 92 |
+
success=False,
|
| 93 |
+
result=current_code,
|
| 94 |
+
explanation=get_string("ast_fixer_failed_autofix"),
|
| 95 |
+
suggestions=[get_string("ast_fixer_suggestion_slm")]
|
| 96 |
+
)
|
| 97 |
+
else:
|
| 98 |
+
# No fixes worked - trigger SLM fallback
|
| 99 |
+
return self._format_result(
|
| 100 |
+
success=False,
|
| 101 |
+
explanation=get_string("ast_fixer_syntax_error", error=str(e)),
|
| 102 |
+
suggestions=[get_string("ast_fixer_suggestion_slm")]
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
except Exception as e:
|
| 106 |
+
logger.error(f"AST analysis failed: {e}")
|
| 107 |
+
return self._format_result(
|
| 108 |
+
success=False,
|
| 109 |
+
explanation=get_string("ast_fixer_analysis_error", error=str(e))
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
def _attempt_fix(self, code: str, error: SyntaxError) -> tuple[Optional[str], Optional[str]]:
|
| 113 |
+
"""Attempt to fix common syntax errors"""
|
| 114 |
+
lines = code.split('\n')
|
| 115 |
+
error_line = error.lineno - 1 if error.lineno else 0
|
| 116 |
+
|
| 117 |
+
# Common fixes
|
| 118 |
+
fixes = [
|
| 119 |
+
self._fix_missing_colon,
|
| 120 |
+
self._fix_indentation,
|
| 121 |
+
self._fix_parentheses,
|
| 122 |
+
]
|
| 123 |
+
|
| 124 |
+
for fix_func in fixes:
|
| 125 |
+
try:
|
| 126 |
+
fixed_code, explanation = fix_func(lines, error_line, error)
|
| 127 |
+
if fixed_code:
|
| 128 |
+
return fixed_code, explanation
|
| 129 |
+
except Exception as e:
|
| 130 |
+
logger.debug(f"Fix attempt failed: {e}")
|
| 131 |
+
continue
|
| 132 |
+
|
| 133 |
+
return None, None
|
| 134 |
+
|
| 135 |
+
def _fix_missing_colon(self, lines: list, error_line: int, error: SyntaxError) -> tuple[Optional[str], Optional[str]]:
|
| 136 |
+
"""Fix missing colon in function/class definitions"""
|
| 137 |
+
if error_line >= len(lines):
|
| 138 |
+
return None, None
|
| 139 |
+
|
| 140 |
+
line = lines[error_line].rstrip()
|
| 141 |
+
|
| 142 |
+
# Check if it's a definition without colon
|
| 143 |
+
keywords = ['def ', 'class ', 'if ', 'elif ', 'else', 'for ', 'while ', 'try', 'except', 'finally', 'with ']
|
| 144 |
+
|
| 145 |
+
for keyword in keywords:
|
| 146 |
+
if line.strip().startswith(keyword) and not line.endswith(':'):
|
| 147 |
+
# Add missing colon
|
| 148 |
+
lines[error_line] = line + ':'
|
| 149 |
+
fixed_code = '\n'.join(lines)
|
| 150 |
+
return fixed_code, get_string(
|
| 151 |
+
"ast_fixer_added_colon",
|
| 152 |
+
keyword=keyword.strip(),
|
| 153 |
+
line_number=error_line + 1
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
return None, None
|
| 157 |
+
|
| 158 |
+
def _fix_indentation(self, lines: list, error_line: int, error: SyntaxError) -> tuple[Optional[str], Optional[str]]:
|
| 159 |
+
"""Fix simple indentation errors"""
|
| 160 |
+
if error_line >= len(lines) or error_line == 0:
|
| 161 |
+
return None, None
|
| 162 |
+
|
| 163 |
+
current_line = lines[error_line]
|
| 164 |
+
prev_line = lines[error_line - 1].rstrip()
|
| 165 |
+
|
| 166 |
+
# If previous line ends with colon, current should be indented
|
| 167 |
+
if prev_line.endswith(':'):
|
| 168 |
+
if not current_line.startswith(' ') and current_line.strip():
|
| 169 |
+
lines[error_line] = ' ' + current_line.lstrip()
|
| 170 |
+
fixed_code = '\n'.join(lines)
|
| 171 |
+
return fixed_code, get_string("ast_fixer_fixed_indentation", line_number=error_line + 1)
|
| 172 |
+
|
| 173 |
+
return None, None
|
| 174 |
+
|
| 175 |
+
def _fix_parentheses(self, lines: list, error_line: int, error: SyntaxError) -> tuple[Optional[str], Optional[str]]:
|
| 176 |
+
"""Fix unmatched parentheses"""
|
| 177 |
+
if error_line >= len(lines):
|
| 178 |
+
return None, None
|
| 179 |
+
|
| 180 |
+
line = lines[error_line]
|
| 181 |
+
|
| 182 |
+
# Count parentheses
|
| 183 |
+
open_count = line.count('(')
|
| 184 |
+
close_count = line.count(')')
|
| 185 |
+
|
| 186 |
+
if open_count > close_count:
|
| 187 |
+
# Missing closing parenthesis
|
| 188 |
+
lines[error_line] = line.rstrip() + ')' * (open_count - close_count)
|
| 189 |
+
fixed_code = '\n'.join(lines)
|
| 190 |
+
return fixed_code, get_string("ast_fixer_added_paren", line_number=error_line + 1)
|
| 191 |
+
|
| 192 |
+
elif close_count > open_count:
|
| 193 |
+
# Extra closing parenthesis - harder to fix automatically
|
| 194 |
+
pass
|
| 195 |
+
|
| 196 |
+
return None, None
|
backend/app/automata/base.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Base class for all automata
|
| 3 |
+
|
| 4 |
+
Automata are deterministic, rule-based components that handle
|
| 5 |
+
specific tasks without requiring LLM inference.
|
| 6 |
+
"""
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
from typing import Dict, Any, Optional
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
from app.models.schemas import TaskType, Language
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BaseAutomaton(ABC):
|
| 17 |
+
"""Base class for all automata"""
|
| 18 |
+
|
| 19 |
+
def __init__(self, name: str):
|
| 20 |
+
self.name = name
|
| 21 |
+
logger.info(f"Initializing automaton: {name}")
|
| 22 |
+
|
| 23 |
+
@abstractmethod
|
| 24 |
+
def can_handle(
|
| 25 |
+
self,
|
| 26 |
+
code: str,
|
| 27 |
+
language: Language,
|
| 28 |
+
task: TaskType
|
| 29 |
+
) -> bool:
|
| 30 |
+
"""
|
| 31 |
+
Determine if this automaton can handle the task
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
code: Source code
|
| 35 |
+
language: Programming language
|
| 36 |
+
task: Task type
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
True if automaton can handle this task
|
| 40 |
+
"""
|
| 41 |
+
pass
|
| 42 |
+
|
| 43 |
+
@abstractmethod
|
| 44 |
+
async def execute(
|
| 45 |
+
self,
|
| 46 |
+
code: str,
|
| 47 |
+
**kwargs
|
| 48 |
+
) -> Dict[str, Any]:
|
| 49 |
+
"""
|
| 50 |
+
Execute the automaton
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
code: Source code to process
|
| 54 |
+
**kwargs: Additional parameters
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
Dict with:
|
| 58 |
+
- success: bool
|
| 59 |
+
- result: str (processed code or output)
|
| 60 |
+
- explanation: Optional[str]
|
| 61 |
+
- suggestions: Optional[List[str]]
|
| 62 |
+
"""
|
| 63 |
+
pass
|
| 64 |
+
|
| 65 |
+
def _format_result(
|
| 66 |
+
self,
|
| 67 |
+
success: bool,
|
| 68 |
+
result: Optional[str] = None,
|
| 69 |
+
explanation: Optional[str] = None,
|
| 70 |
+
suggestions: Optional[list] = None
|
| 71 |
+
) -> Dict[str, Any]:
|
| 72 |
+
"""Helper to format results consistently"""
|
| 73 |
+
return {
|
| 74 |
+
"success": success,
|
| 75 |
+
"result": result,
|
| 76 |
+
"explanation": explanation,
|
| 77 |
+
"suggestions": suggestions or []
|
| 78 |
+
}
|
backend/app/automata/formatter.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Python code formatter using Black
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Dict, Any
|
| 6 |
+
|
| 7 |
+
from app.automata.base import BaseAutomaton
|
| 8 |
+
from app.models.schemas import TaskType, Language
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class PythonFormatter(BaseAutomaton):
    """Formats Python code using Black.

    Wraps the ``black`` library behind the BaseAutomaton interface so the
    pipeline can apply deterministic PEP 8 formatting without a model call.
    """

    def __init__(self):
        super().__init__("python_formatter")
        # Import lazily so the automaton degrades gracefully when the
        # ``black`` package is absent from the environment.
        self._black_available = False

        try:
            import black
        except ImportError:
            logger.warning("Black not available, formatter will be limited")
        else:
            self._black = black
            self._black_available = True
            logger.info("Black formatter loaded successfully")

    def can_handle(
        self,
        code: str,
        language: Language,
        task: TaskType,
    ) -> bool:
        """Check if can format this code."""
        if not self._black_available:
            return False
        return language == Language.PYTHON and task == TaskType.FORMAT

    async def execute(
        self,
        code: str,
        **kwargs,
    ) -> Dict[str, Any]:
        """Format Python code with Black."""
        if not self._black_available:
            return self._format_result(
                success=False,
                explanation="Black formatter not available",
            )

        try:
            # These are Black's stock defaults, spelled out explicitly
            # so the house style is visible at a glance.
            style = self._black.Mode(
                line_length=88,
                string_normalization=True,
                magic_trailing_comma=True,
            )

            reformatted = self._black.format_str(code, mode=style)

            if reformatted != code:
                return self._format_result(
                    success=True,
                    result=reformatted,
                    explanation="Code formatted with Black (PEP 8 style)",
                    suggestions=[
                        "Consider using Black in your pre-commit hooks",
                        "Configure Black in pyproject.toml for project-wide consistency",
                    ],
                )

            # Black produced byte-identical output: nothing to change.
            return self._format_result(
                success=True,
                result=code,
                explanation="Code is already properly formatted",
                suggestions=[],
            )

        except Exception as e:
            # format_str raises on syntactically invalid input, among others.
            logger.error(f"Black formatting failed: {e}")
            return self._format_result(
                success=False,
                explanation=f"Formatting error: {str(e)}",
            )
|
backend/app/automata/linter.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Python code linter using Ruff
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
import subprocess
|
| 6 |
+
import sys
|
| 7 |
+
import tempfile
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, Any
|
| 10 |
+
|
| 11 |
+
from app.automata.base import BaseAutomaton
|
| 12 |
+
from app.models.schemas import TaskType, Language
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class PythonLinter(BaseAutomaton):
    """Lints and auto-fixes Python code using Ruff.

    The code is written to a temporary ``.py`` file, checked with
    ``ruff check --output-format=json`` to count issues, then auto-fixed
    in place with ``ruff check --fix``. The (possibly modified) file
    contents are returned as the result.
    """

    def __init__(self):
        super().__init__("python_linter")
        self._ruff_available = self._check_ruff()

    def _check_ruff(self) -> bool:
        """Return True iff ``python -m ruff --version`` runs and exits 0."""
        try:
            # Use python -m ruff for better cross-platform compatibility
            result = subprocess.run(
                [sys.executable, "-m", "ruff", "--version"],
                capture_output=True,
                text=True,
                timeout=5,
            )
            if result.returncode == 0:
                logger.info(f"Ruff available: {result.stdout.strip()}")
                return True
        except (FileNotFoundError, subprocess.TimeoutExpired) as e:
            logger.warning(f"Ruff not available: {e}")

        return False

    def can_handle(
        self,
        code: str,
        language: Language,
        task: TaskType,
    ) -> bool:
        """Check if can lint this code."""
        return (
            self._ruff_available
            and language == Language.PYTHON
            and task in [TaskType.FIX, TaskType.FORMAT]
        )

    async def execute(
        self,
        code: str,
        **kwargs,
    ) -> Dict[str, Any]:
        """Lint and auto-fix Python code with Ruff.

        Returns a result dict whose ``result`` is the fixed code and whose
        ``explanation`` reports the number of issues the check run found.
        """
        # stdlib; kept function-local to match the original module layout.
        import json

        if not self._ruff_available:
            return self._format_result(
                success=False,
                explanation="Ruff linter not available",
            )

        try:
            # Ruff operates on files, so write the code to a temp path.
            with tempfile.NamedTemporaryFile(
                mode='w',
                suffix='.py',
                delete=False,
                encoding='utf-8',
            ) as tmp:
                tmp.write(code)
                tmp_path = tmp.name

            try:
                # Count issues first (JSON output is machine-readable).
                check_result = subprocess.run(
                    [sys.executable, "-m", "ruff", "check", tmp_path, "--output-format=json"],
                    capture_output=True,
                    text=True,
                    timeout=10,
                )

                # Apply auto-fixes in place. Ruff exits 0 (clean) or 1
                # (issues remain); any other status is a real failure and
                # was previously ignored silently.
                fix_result = subprocess.run(
                    [sys.executable, "-m", "ruff", "check", tmp_path, "--fix"],
                    capture_output=True,
                    text=True,
                    timeout=10,
                )
                if fix_result.returncode not in (0, 1):
                    logger.warning(
                        f"Ruff --fix exited {fix_result.returncode}: "
                        f"{fix_result.stderr.strip()}"
                    )

                # Read back the (possibly fixed) code.
                fixed_code = Path(tmp_path).read_text(encoding='utf-8')

                # Count issues reported by the check run.
                try:
                    issues = json.loads(check_result.stdout) if check_result.stdout else []
                    issue_count = len(issues)
                except json.JSONDecodeError:
                    issue_count = 0

                if fixed_code == code:
                    return self._format_result(
                        success=True,
                        result=code,
                        explanation="No linting issues found" if issue_count == 0 else f"Found {issue_count} issues but couldn't auto-fix",
                        suggestions=["Code follows Python best practices"] if issue_count == 0 else [],
                    )

                return self._format_result(
                    success=True,
                    result=fixed_code,
                    explanation=f"Auto-fixed {issue_count} linting issues",
                    suggestions=[
                        "Configure Ruff in pyproject.toml",
                        "Add Ruff to your CI/CD pipeline",
                    ],
                )

            finally:
                # Always remove the temp file, even on error.
                Path(tmp_path).unlink(missing_ok=True)

        except subprocess.TimeoutExpired:
            logger.error("Ruff execution timed out")
            return self._format_result(
                success=False,
                explanation="Linting timed out",
            )
        except Exception as e:
            logger.error(f"Ruff linting failed: {e}")
            return self._format_result(
                success=False,
                explanation=f"Linting error: {str(e)}",
            )
|
backend/app/automata/runtime_fixer.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Runtime error fixer for common Python errors
|
| 3 |
+
|
| 4 |
+
Fixes simple runtime errors based on trace analysis:
|
| 5 |
+
- ZeroDivisionError: Add checks before division
|
| 6 |
+
- NameError: Detect typos in variable names
|
| 7 |
+
- IndexError: Add boundary checks
|
| 8 |
+
- SyntaxError with = vs ==: Fix comparison operators
|
| 9 |
+
"""
|
| 10 |
+
import ast
|
| 11 |
+
import re
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Dict, Any, Optional, List, Tuple
|
| 14 |
+
from difflib import get_close_matches
|
| 15 |
+
|
| 16 |
+
from app.automata.base import BaseAutomaton
|
| 17 |
+
from app.models.schemas import TaskType, Language
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class RuntimeFixer(BaseAutomaton):
    """Fixes common runtime errors using trace analysis.

    Supported error classes (identified from the traceback text):
      - ZeroDivisionError: guard divisions whose denominator is ``len(...)``
      - NameError: repair likely typos via closest-match rename
      - IndexError: guard constant-index accesses against empty sequences
      - SyntaxError ('=' vs '=='): repair comparisons inside ``if`` lines

    Each ``_fix_*`` helper returns ``(fixed_code, explanation)`` on success
    or ``(None, None)`` when no safe fix could be produced.
    """

    def __init__(self):
        super().__init__("runtime_fixer")

    def can_handle(
        self,
        code: str,
        language: Language,
        task: TaskType,
        trace: Optional[str] = None,
    ) -> bool:
        """Check if can fix this code."""
        return (
            language == Language.PYTHON
            and task == TaskType.FIX
            and trace is not None  # Need trace to know what to fix
        )

    async def execute(
        self,
        code: str,
        trace: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Try to fix runtime errors based on trace."""
        if not trace:
            return self._format_result(
                success=False,
                explanation="No trace provided",
            )

        # Identify error type from trace
        error_type = self._identify_error_type(trace)

        if not error_type:
            return self._format_result(
                success=False,
                explanation="Unknown error type",
            )

        # Dispatch table: error class -> targeted fixer
        fixers = {
            "ZeroDivisionError": self._fix_zero_division,
            "NameError": self._fix_name_error,
            "IndexError": self._fix_index_error,
            "SyntaxError": self._fix_syntax_error,
        }

        fixer_func = fixers.get(error_type)
        if not fixer_func:
            return self._format_result(
                success=False,
                explanation=f"No fixer available for {error_type}",
            )

        # Apply the fix
        try:
            fixed_code, explanation = fixer_func(code, trace)
            if fixed_code:
                return self._format_result(
                    success=True,
                    result=fixed_code,
                    explanation=explanation,
                    suggestions=["Test the fixed code with various inputs"],
                )
            return self._format_result(
                success=False,
                explanation="Could not automatically fix this error",
            )
        except Exception as e:
            logger.error(f"Runtime fixer failed: {e}")
            return self._format_result(
                success=False,
                explanation=f"Fix attempt failed: {str(e)}",
            )

    def _identify_error_type(self, trace: str) -> Optional[str]:
        """Identify the type of error from trace."""
        error_patterns = [
            (r"ZeroDivisionError", "ZeroDivisionError"),
            (r"NameError: name '(\w+)' is not defined", "NameError"),
            (r"IndexError", "IndexError"),
            (r"SyntaxError.*'=' .*'=='", "SyntaxError"),
        ]

        for pattern, error_type in error_patterns:
            if re.search(pattern, trace):
                return error_type

        return None

    def _fix_zero_division(self, code: str, trace: str) -> Tuple[Optional[str], Optional[str]]:
        """Fix division by zero errors.

        Inserts an early-return guard before lines that divide by
        ``len(x)``. Returns (None, None) when the code does not parse,
        no risky division was detected, or no guard was actually inserted.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):  # was a bare except: — narrow it
            return None, None

        # AST pass only *detects* a potentially-zero denominator; the
        # actual patch is applied textually below.
        class DivisionDetector(ast.NodeTransformer):
            def __init__(self):
                self.fixed = False

            def visit_BinOp(self, node):
                if isinstance(node.op, (ast.Div, ast.FloorDiv)):
                    denom = node.right
                    if isinstance(denom, ast.Call):
                        # len(...) denominator: zero for empty sequences
                        if isinstance(denom.func, ast.Name) and denom.func.id == 'len':
                            self.fixed = True
                    elif isinstance(denom, ast.Name):
                        # Plain variable denominator: might be zero
                        self.fixed = True
                return self.generic_visit(node)

        detector = DivisionDetector()
        detector.visit(tree)

        if not detector.fixed:
            return None, None

        lines = code.split('\n')
        fixed_lines = []

        for line in lines:
            # Guard divisions by len(...): empty sequence -> return 0.
            if '/ len(' in line or '// len(' in line:
                indent = len(line) - len(line.lstrip())
                spacing = ' ' * indent

                match = re.search(r'len\((\w+)\)', line)
                if match:
                    var_name = match.group(1)
                    fixed_lines.append(f"{spacing}if not {var_name}:")
                    fixed_lines.append(f"{spacing}    return 0")
            # BUGFIX: the original only kept the line on the non-division
            # branch, silently *deleting* division lines whose len() call
            # did not match the simple regex (e.g. len(obj.attr)).
            fixed_lines.append(line)

        fixed_code = '\n'.join(fixed_lines)

        # BUGFIX: the original claimed "Added zero-division check" even
        # when no guard was inserted and the code was unchanged.
        if fixed_code == code:
            return None, None

        # Verify the patched code still parses before returning it.
        try:
            ast.parse(fixed_code)
            return fixed_code, "Added zero-division check"
        except (SyntaxError, ValueError):
            return None, None

    def _fix_name_error(self, code: str, trace: str) -> Tuple[Optional[str], Optional[str]]:
        """Fix undefined variable names (typos) via closest-match rename."""
        # Extract the undefined variable name from the traceback.
        match = re.search(r"name '(\w+)' is not defined", trace)
        if not match:
            return None, None

        undefined_var = match.group(1)

        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):
            return None, None

        # Collect every name the code assigns, plus function parameters.
        defined_names = set()

        class NameCollector(ast.NodeVisitor):
            def visit_Name(self, node):
                if isinstance(node.ctx, ast.Store):
                    defined_names.add(node.id)

            def visit_FunctionDef(self, node):
                for arg in node.args.args:
                    defined_names.add(arg.arg)
                self.generic_visit(node)

        NameCollector().visit(tree)

        # Closest lexical match above a 0.6 similarity cutoff.
        matches = get_close_matches(undefined_var, defined_names, n=1, cutoff=0.6)

        if matches:
            correct_name = matches[0]
            # re.escape is defensive; \b anchors avoid touching longer
            # identifiers that merely contain the typo as a substring.
            fixed_code = re.sub(
                r'\b' + re.escape(undefined_var) + r'\b', correct_name, code
            )

            # Verify the renamed code still parses.
            try:
                ast.parse(fixed_code)
                return fixed_code, f"Fixed typo: '{undefined_var}' → '{correct_name}'"
            except (SyntaxError, ValueError):
                return None, None

        return None, None

    def _fix_index_error(self, code: str, trace: str) -> Tuple[Optional[str], Optional[str]]:
        """Fix index out of range errors.

        NOTE(review): the guard only covers the empty-sequence case; a
        non-zero constant index can still be out of range — TODO confirm
        this is the intended level of protection.
        """
        lines = code.split('\n')
        fixed_lines = []

        for line in lines:
            # Single search; the original ran the same pattern twice and
            # bound an ``index`` local it never used.
            match = re.search(r'(\w+)\[(\d+)\]', line)
            if match:
                indent = len(line) - len(line.lstrip())
                spacing = ' ' * indent
                var_name = match.group(1)

                # Guard: empty sequence -> bail out before indexing.
                fixed_lines.append(f"{spacing}if not {var_name}:")
                fixed_lines.append(f"{spacing}    return None")
            fixed_lines.append(line)

        fixed_code = '\n'.join(fixed_lines)

        # Verify it parses and that a guard was actually inserted.
        try:
            ast.parse(fixed_code)
            if fixed_code != code:
                return fixed_code, "Added index bounds check"
        except (SyntaxError, ValueError):
            pass

        return None, None

    def _fix_syntax_error(self, code: str, trace: str) -> Tuple[Optional[str], Optional[str]]:
        """Fix = vs == in conditions."""
        # Only handle the specific "'=' ... '=='" hint CPython emits.
        if "'=' " in trace and "'=='" in trace:
            lines = code.split('\n')
            fixed_lines = []
            fixed = False

            for line in lines:
                # Look for an ``if x = value`` pattern (assignment in a
                # condition); skip lines ending in '=' (e.g. ==, <=).
                if 'if ' in line and ' = ' in line and not line.strip().endswith('='):
                    parts = line.split(' = ', 1)
                    if len(parts) == 2:
                        fixed_lines.append(parts[0] + ' == ' + parts[1])
                        fixed = True
                        continue
                fixed_lines.append(line)

            if fixed:
                fixed_code = '\n'.join(fixed_lines)
                try:
                    ast.parse(fixed_code)
                    return fixed_code, "Fixed comparison: '=' → '=='"
                except (SyntaxError, ValueError):
                    pass

        return None, None
|
backend/app/automata/test_generator.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Template-based test generator
|
| 3 |
+
|
| 4 |
+
Generates basic test structure using templates.
|
| 5 |
+
SLM will fill in the specific test cases.
|
| 6 |
+
"""
|
| 7 |
+
import re
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Dict, Any, Optional, List
|
| 10 |
+
|
| 11 |
+
from app.automata.base import BaseAutomaton
|
| 12 |
+
from app.models.schemas import TaskType, Language
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TestTemplateGenerator(BaseAutomaton):
    """Generates test templates for code.

    Produces skeleton test files (pytest / Jest style) for the functions
    and classes found in the input; the SLM fills in the actual cases.
    """

    def __init__(self):
        super().__init__("test_template")

        # Per-language template builders.
        self.templates = {
            Language.PYTHON: self._python_template,
            Language.JAVASCRIPT: self._javascript_template,
        }

    def can_handle(
        self,
        code: str,
        language: Language,
        task: TaskType,
    ) -> bool:
        """Check if can generate test template."""
        # This automaton only produces skeletons, never complete tests,
        # so it never claims the task — the SLM handles full generation.
        return False

    async def execute(
        self,
        code: str,
        language: Language = Language.PYTHON,
        **kwargs,
    ) -> Dict[str, Any]:
        """Generate test template."""
        builder = self.templates.get(language)

        if builder is None:
            return self._format_result(
                success=False,
                explanation=f"No template for {language}",
            )

        try:
            # Find the functions/classes, then render the skeleton.
            found = self._extract_entities(code, language)
            skeleton = builder(found)

            return self._format_result(
                success=True,
                result=skeleton,
                explanation=f"Generated test template for {len(found)} entities",
                suggestions=[
                    "Fill in test cases with specific scenarios",
                    "Add edge cases and error handling tests",
                    "Consider using parametrized tests",
                ],
            )

        except Exception as e:
            logger.error(f"Template generation failed: {e}")
            return self._format_result(
                success=False,
                explanation=f"Generation error: {str(e)}",
            )

    def _extract_entities(self, code: str, language: Language) -> List[Dict[str, str]]:
        """Extract functions and classes from code."""
        found: List[Dict[str, str]] = []

        def collect(pattern: str, kind: str) -> None:
            # One entry per regex hit, in source order per pattern.
            for m in re.finditer(pattern, code, re.MULTILINE):
                found.append({"type": kind, "name": m.group(1)})

        if language == Language.PYTHON:
            # Top-level defs and classes only (patterns are anchored).
            collect(r'^def\s+(\w+)\s*\(', "function")
            collect(r'^class\s+(\w+)', "class")
        elif language == Language.JAVASCRIPT:
            # Classic function declarations, then const-arrow bindings.
            collect(r'function\s+(\w+)\s*\(', "function")
            collect(r'const\s+(\w+)\s*=\s*\(', "function")

        return found

    def _python_template(self, entities: List[Dict[str, str]]) -> str:
        """Generate Python test template."""
        pieces = ['''"""
Unit tests for generated code
"""
import pytest


''']
        for entity in entities:
            name = entity["name"]
            if entity["type"] == "function":
                pieces.append(f'''def test_{name}():
    """Test {name} function"""
    # TODO: Add test cases
    pass


''')
            elif entity["type"] == "class":
                pieces.append(f'''class Test{name}:
    """Test {name} class"""

    def test_init(self):
        """Test initialization"""
        # TODO: Add test
        pass

    def test_methods(self):
        """Test methods"""
        # TODO: Add test
        pass


''')

        return "".join(pieces)

    def _javascript_template(self, entities: List[Dict[str, str]]) -> str:
        """Generate JavaScript test template."""
        pieces = ['''/**
 * Unit tests for generated code
 */

''']
        for entity in entities:
            name = entity["name"]
            pieces.append(f'''describe('{name}', () => {{
    test('should work correctly', () => {{
        // TODO: Add test cases
        expect(true).toBe(true);
    }});
}});

''')

        return "".join(pieces)
|
backend/app/automata/trace_parser.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Error trace parser and explainer
|
| 3 |
+
|
| 4 |
+
Uses regex and pattern matching to extract key information
|
| 5 |
+
from error traces before passing to LLM.
|
| 6 |
+
"""
|
| 7 |
+
import re
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Dict, Any, Optional, List
|
| 10 |
+
|
| 11 |
+
from app.automata.base import BaseAutomaton
|
| 12 |
+
from app.models.schemas import TaskType, Language
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TraceParser(BaseAutomaton):
|
| 18 |
+
"""Parses and extracts information from error traces"""
|
| 19 |
+
|
| 20 |
+
def __init__(self):
|
| 21 |
+
super().__init__("trace_parser")
|
| 22 |
+
|
| 23 |
+
# Common error patterns
|
| 24 |
+
self.patterns = {
|
| 25 |
+
"python": [
|
| 26 |
+
(r"(\w+Error): (.+)", "error_type"),
|
| 27 |
+
(r'File "([^"]+)", line (\d+)', "file_location"),
|
| 28 |
+
(r"NameError: name '(\w+)' is not defined", "undefined_variable"),
|
| 29 |
+
(r"TypeError: (.+)", "type_error"),
|
| 30 |
+
(r"AttributeError: (.+) has no attribute '(\w+)'", "attribute_error"),
|
| 31 |
+
(r"IndexError: (.+)", "index_error"),
|
| 32 |
+
(r"KeyError: (.+)", "key_error"),
|
| 33 |
+
],
|
| 34 |
+
"javascript": [
|
| 35 |
+
(r"(\w+Error): (.+)", "error_type"),
|
| 36 |
+
(r"at (.+) \((.+):(\d+):(\d+)\)", "stack_location"),
|
| 37 |
+
(r"ReferenceError: (\w+) is not defined", "undefined_variable"),
|
| 38 |
+
(r"TypeError: (.+)", "type_error"),
|
| 39 |
+
]
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
def can_handle(
|
| 43 |
+
self,
|
| 44 |
+
code: str,
|
| 45 |
+
language: Language,
|
| 46 |
+
task: TaskType
|
| 47 |
+
) -> bool:
|
| 48 |
+
"""Check if can parse this trace"""
|
| 49 |
+
# Only handle explain tasks
|
| 50 |
+
return task == TaskType.EXPLAIN
|
| 51 |
+
|
| 52 |
+
async def execute(
|
| 53 |
+
self,
|
| 54 |
+
code: str,
|
| 55 |
+
trace: Optional[str] = None,
|
| 56 |
+
**kwargs
|
| 57 |
+
) -> Dict[str, Any]:
|
| 58 |
+
"""Parse error trace"""
|
| 59 |
+
if not trace:
|
| 60 |
+
return self._format_result(
|
| 61 |
+
success=False,
|
| 62 |
+
explanation="No trace provided"
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
# Detect language from trace
|
| 67 |
+
language = self._detect_language(trace)
|
| 68 |
+
|
| 69 |
+
# Extract structured information
|
| 70 |
+
info = self._extract_info(trace, language)
|
| 71 |
+
|
| 72 |
+
if not info:
|
| 73 |
+
# Couldn't parse, let SLM handle it
|
| 74 |
+
return self._format_result(
|
| 75 |
+
success=False,
|
| 76 |
+
explanation="Trace format not recognized, needs SLM analysis"
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# Build explanation
|
| 80 |
+
explanation = self._build_explanation(info)
|
| 81 |
+
suggestions = self._get_suggestions(info)
|
| 82 |
+
|
| 83 |
+
return self._format_result(
|
| 84 |
+
success=True,
|
| 85 |
+
result=trace, # Return original trace
|
| 86 |
+
explanation=explanation,
|
| 87 |
+
suggestions=suggestions
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
except Exception as e:
|
| 91 |
+
logger.error(f"Trace parsing failed: {e}")
|
| 92 |
+
return self._format_result(
|
| 93 |
+
success=False,
|
| 94 |
+
explanation=f"Parse error: {str(e)}"
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
def _detect_language(self, trace: str) -> str:
|
| 98 |
+
"""Detect programming language from trace"""
|
| 99 |
+
if "Traceback (most recent call last)" in trace or "Error:" in trace:
|
| 100 |
+
return "python"
|
| 101 |
+
elif "at " in trace and "Error:" in trace:
|
| 102 |
+
return "javascript"
|
| 103 |
+
return "unknown"
|
| 104 |
+
|
| 105 |
+
def _extract_info(self, trace: str, language: str) -> Dict[str, Any]:
|
| 106 |
+
"""Extract structured information from trace"""
|
| 107 |
+
info = {
|
| 108 |
+
"language": language,
|
| 109 |
+
"error_type": None,
|
| 110 |
+
"error_message": None,
|
| 111 |
+
"file": None,
|
| 112 |
+
"line": None,
|
| 113 |
+
"details": {}
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
patterns = self.patterns.get(language, [])
|
| 117 |
+
|
| 118 |
+
for pattern, name in patterns:
|
| 119 |
+
match = re.search(pattern, trace)
|
| 120 |
+
if match:
|
| 121 |
+
if name == "error_type":
|
| 122 |
+
info["error_type"] = match.group(1)
|
| 123 |
+
info["error_message"] = match.group(2)
|
| 124 |
+
elif name == "file_location":
|
| 125 |
+
info["file"] = match.group(1)
|
| 126 |
+
info["line"] = match.group(2)
|
| 127 |
+
elif name == "undefined_variable":
|
| 128 |
+
info["details"]["undefined_var"] = match.group(1)
|
| 129 |
+
elif name in ["type_error", "attribute_error", "index_error", "key_error"]:
|
| 130 |
+
info["details"][name] = match.groups()
|
| 131 |
+
|
| 132 |
+
return info if info["error_type"] else {}
|
| 133 |
+
|
| 134 |
+
def _build_explanation(self, info: Dict[str, Any]) -> str:
    """Render a human-readable markdown explanation from parsed trace info."""
    err_name = info.get("error_type", "Unknown")
    message = info.get("error_message", "")

    # Only mention the location when both file and line were extracted.
    location = (
        f" in {info['file']} at line {info['line']}"
        if info.get("file") and info.get("line")
        else ""
    )

    text = f"**{err_name}**{location}\n\n{message}"

    # Targeted guidance for undefined-variable errors.
    details = info.get("details", {})
    if "undefined_var" in details:
        missing = details["undefined_var"]
        text += f"\n\nThe variable '{missing}' is used but not defined. Check for typos or ensure it's declared before use."

    return text
|
| 151 |
+
|
| 152 |
+
def _get_suggestions(self, info: Dict[str, Any]) -> List[str]:
    """Return actionable suggestions for the parsed error type.

    Unknown error types yield an empty list. A KeyError branch was added so
    the suggestions stay consistent with ``_extract_info``, which already
    recognizes a ``key_error`` pattern but previously got no guidance here.
    """
    suggestions_by_error: Dict[str, List[str]] = {
        "NameError": [
            "Check for typos in variable names",
            "Ensure variables are defined before use",
            "Check import statements",
        ],
        "TypeError": [
            "Verify function arguments match expected types",
            "Check None values before operations",
            "Add type hints for better clarity",
        ],
        "AttributeError": [
            "Verify the object has the expected attribute",
            "Check for None values",
            "Review object initialization",
        ],
        "IndexError": [
            "Check list/array bounds",
            "Verify index values are within range",
            "Use len() to validate indices",
        ],
        "KeyError": [
            "Verify the key exists before accessing it",
            "Use dict.get() with a default value",
            "Inspect the dictionary contents when debugging",
        ],
    }

    error_type = info.get("error_type", "")
    # Return a fresh list so callers can safely mutate the result.
    return list(suggestions_by_error.get(error_type, []))
|
backend/app/config.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration management for SLM Code Engine
|
| 3 |
+
"""
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from pydantic_settings import BaseSettings
|
| 7 |
+
from pydantic import Field
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Settings(BaseSettings):
    """Application settings with environment variable support.

    Values are resolved from the process environment and a ``.env`` file in
    the project root; directories referenced by the settings are created on
    instantiation so later file writes do not fail.
    """

    # API Configuration
    api_host: str = Field(default="0.0.0.0", env="API_HOST")
    api_port: int = Field(default=8000, env="API_PORT")
    api_workers: int = Field(default=1, env="API_WORKERS")
    debug: bool = Field(default=True, env="DEBUG")

    # Groq Configuration
    groq_api_key: Optional[str] = Field(default=None, env="GROQ_API_KEY")

    # Localization
    language: str = Field(
        default="en",
        env="LANGUAGE",
        description="Language for responses (e.g., 'en', 'fr')",
    )

    # Project paths
    project_root: Path = Path(__file__).parent.parent.parent
    models_dir: Path = Field(default_factory=lambda: Path(__file__).parent.parent.parent / "models")
    data_dir: Path = Field(default_factory=lambda: Path(__file__).parent.parent.parent / "data")
    cache_dir: Path = Field(default_factory=lambda: Path(__file__).parent.parent.parent / "data" / "cache")

    # Models Configuration
    starcoder_model: str = Field(default="phi-2.Q4_K_M.gguf", env="STARCODER_MODEL")
    codet5_model: str = Field(default="codet5-small", env="CODET5_MODEL")
    embedding_model: str = Field(default="all-MiniLM-L6-v2", env="EMBEDDING_MODEL")

    # Model inference settings
    max_tokens: int = Field(default=2048, env="MAX_TOKENS")
    temperature: float = Field(default=0.2, env="TEMPERATURE")
    n_ctx: int = Field(default=4096, env="N_CTX")  # Context window
    n_threads: Optional[int] = Field(default=None, env="N_THREADS")  # CPU threads

    # Database
    db_path: Path = Field(default_factory=lambda: Path(__file__).parent.parent.parent / "data" / "usage.db")

    # Sandbox Configuration
    sandbox_enabled: bool = Field(default=True, env="SANDBOX_ENABLED")
    sandbox_timeout: int = Field(default=30, env="SANDBOX_TIMEOUT")  # seconds
    sandbox_memory_limit: str = Field(default="512m", env="SANDBOX_MEMORY_LIMIT")

    # Orchestrator Configuration
    router_threshold: float = Field(default=0.7, env="ROUTER_THRESHOLD")  # Confidence threshold
    # FIX: this field was previously declared twice (copy-paste duplicate);
    # the second declaration silently shadowed the first. Declared once now.
    enable_automata_first: bool = Field(default=True, env="ENABLE_AUTOMATA_FIRST")
    enable_rag: bool = Field(default=True, env="ENABLE_RAG")  # Enable RAG context enrichment
    enable_distillation: bool = Field(default=True, env="ENABLE_DISTILLATION")  # Enable data collection

    # Logging
    log_level: str = Field(default="INFO", env="LOG_LEVEL")
    log_file: Optional[Path] = Field(default=None, env="LOG_FILE")

    class Config:
        # Look for .env in project root
        env_file = str(Path(__file__).parent.parent.parent / ".env")
        env_file_encoding = "utf-8"
        case_sensitive = False
        extra = "ignore"  # Ignore extra fields in .env

    def __init__(self, **kwargs):
        """Build settings, then make sure all configured directories exist."""
        super().__init__(**kwargs)
        self.models_dir.mkdir(parents=True, exist_ok=True)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    @property
    def starcoder_path(self) -> Path:
        """Get full path to StarCoder model (currently using Phi-2)"""
        return self.models_dir / "phi-2" / self.starcoder_model

    @property
    def codet5_path(self) -> Path:
        """Get full path to CodeT5 model"""
        return self.models_dir / "codet5-small"


# Global settings instance
settings = Settings()
|
backend/app/core/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Core package"""
|
backend/app/core/automata_manager.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Automata Manager
|
| 3 |
+
|
| 4 |
+
Manages the lifecycle and execution of deterministic automata.
|
| 5 |
+
Automata are fast, rule-based code processors that handle simple tasks.
|
| 6 |
+
"""
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Dict, Optional, Any
|
| 9 |
+
from app.models.schemas import TaskType, Language
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class AutomataManager:
    """Registry and dispatcher for deterministic, rule-based automata."""

    def __init__(self):
        # Maps automaton name -> automaton instance.
        self.automata: Dict[str, Any] = {}
        logger.info("AutomataManager initialized")

    def register_automaton(self, name: str, automaton: Any):
        """Register an automaton"""
        self.automata[name] = automaton
        logger.info(f"Registered automaton: {name}")

    def get_automaton(self, name: str) -> Optional[Any]:
        """Get an automaton by name"""
        return self.automata.get(name)

    def list_automata(self) -> list:
        """List all registered automata"""
        return [*self.automata]

    async def execute(
        self,
        automaton_name: str,
        code: str,
        language: Language,
        task: TaskType,
        **kwargs
    ) -> Dict[str, Any]:
        """Run the named automaton on ``code``.

        Returns a dict with at least ``success``; on failure an ``error``
        message explains why (unknown name, declined task, missing method,
        or an exception raised during execution).
        """
        target = self.get_automaton(automaton_name)
        if not target:
            return {
                "success": False,
                "error": f"Automaton '{automaton_name}' not found"
            }

        try:
            # An automaton may opt out of tasks it does not support.
            if hasattr(target, 'can_handle') and not target.can_handle(code, language, task):
                return {
                    "success": False,
                    "error": "Automaton cannot handle this task"
                }

            if not hasattr(target, 'execute'):
                return {
                    "success": False,
                    "error": "Automaton does not implement execute method"
                }

            return await target.execute(code, **kwargs)

        except Exception as e:
            logger.error(f"Error executing automaton {automaton_name}: {e}")
            return {
                "success": False,
                "error": str(e)
            }
|
backend/app/core/distillation.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Distillation Logger
|
| 3 |
+
|
| 4 |
+
Captures high-quality interactions (Teacher -> Student) for Knowledge Distillation.
|
| 5 |
+
Logs Prompt/Response pairs to a dataset file for future fine-tuning of local SLMs.
|
| 6 |
+
"""
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
import time
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Dict, Any, Optional
|
| 12 |
+
|
| 13 |
+
from app.config import settings
|
| 14 |
+
from app.models.schemas import TaskType, Language
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class DistillationLogger:
    """Appends instruction/input/output records to a JSONL dataset.

    Captures Teacher -> Student interactions so local SLMs can later be
    fine-tuned on them. Disabled entirely when the settings flag is off or
    the dataset directory cannot be created.
    """

    def __init__(self):
        self.enabled = settings.enable_distillation
        self.dataset_dir = settings.data_dir / "datasets"
        self.dataset_file = self.dataset_dir / "distillation_v1.jsonl"

        if self.enabled:
            self._ensure_setup()

    def _ensure_setup(self):
        """Create the dataset directory; disable logging if that fails."""
        try:
            self.dataset_dir.mkdir(parents=True, exist_ok=True)
            logger.info(f"Distillation logger initialized. Dataset: {self.dataset_file}")
        except Exception as e:
            logger.error(f"Failed to setup distillation logger: {e}")
            self.enabled = False

    async def log_interaction(
        self,
        task: TaskType,
        language: Language,
        code_input: str,
        context: Optional[str],
        output: str,
        model: str,
        score: float = 1.0
    ):
        """
        Log a successful interaction.

        The record follows the Alpaca / instruction-tuning layout:
        ``{"instruction": ..., "input": ..., "output": ..., "metadata": ...}``.
        Failures are logged but never propagated to the caller.
        """
        if not self.enabled:
            return

        try:
            # Derive the instruction text from the task description.
            prompt = f"Perform task: {task} for language: {language}"
            if context:
                prompt += f". Context: {context}"

            record = {
                "instruction": prompt,
                "input": code_input,
                "output": output,
                "metadata": {
                    "task": task,
                    "language": language,
                    "teacher_model": model,
                    "timestamp": time.time(),
                    "score": score,
                },
            }

            # One JSON object per line (JSONL append-only dataset).
            with open(self.dataset_file, "a", encoding="utf-8") as sink:
                sink.write(json.dumps(record, ensure_ascii=False) + "\n")

            logger.debug("Logged distillation example")

        except Exception as e:
            logger.error(f"Failed to log distillation example: {e}")


# Global instance
distillation_logger = DistillationLogger()
|
backend/app/core/lifecycle.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Lifecycle Manager
|
| 3 |
+
|
| 4 |
+
Manages the lifecycle of engines and components:
|
| 5 |
+
- Initialization
|
| 6 |
+
- Health checks
|
| 7 |
+
- Graceful shutdown
|
| 8 |
+
- Resource cleanup
|
| 9 |
+
"""
|
| 10 |
+
import logging
|
| 11 |
+
import asyncio
|
| 12 |
+
from typing import Dict, Any, List
|
| 13 |
+
from contextlib import asynccontextmanager
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)


class LifecycleManager:
    """Coordinates startup, shutdown and health reporting for registered components.

    Components may optionally expose async ``initialize``, ``shutdown`` and
    ``health_check`` methods; anything missing is simply skipped (or reported
    as "unknown" for health).
    """

    def __init__(self):
        # Insertion order matters: shutdown runs in reverse registration order.
        self.components: Dict[str, Any] = {}
        self.initialized = False
        logger.info("LifecycleManager created")

    def register_component(self, name: str, component: Any):
        """Register a component for lifecycle management"""
        self.components[name] = component
        logger.info(f"Registered component: {name}")

    async def initialize_all(self):
        """Initialize every registered component, failing fast on error."""
        logger.info("Initializing all components...")

        for name, comp in self.components.items():
            if not hasattr(comp, 'initialize'):
                continue
            try:
                logger.info(f"Initializing {name}...")
                await comp.initialize()
                logger.info(f"✓ {name} initialized")
            except Exception as e:
                logger.error(f"Failed to initialize {name}: {e}")
                raise

        self.initialized = True
        logger.info("All components initialized successfully")

    async def shutdown_all(self):
        """Shut down components in reverse registration order; errors are logged, not raised."""
        logger.info("Shutting down all components...")

        for name, comp in reversed(list(self.components.items())):
            if not hasattr(comp, 'shutdown'):
                continue
            try:
                logger.info(f"Shutting down {name}...")
                await comp.shutdown()
                logger.info(f"✓ {name} shut down")
            except Exception as e:
                logger.error(f"Error shutting down {name}: {e}")

        self.initialized = False
        logger.info("All components shut down")

    async def health_check(self) -> Dict[str, Any]:
        """Aggregate per-component health; overall status degrades on any failure."""
        report: Dict[str, Any] = {"status": "healthy", "components": {}}

        for name, comp in self.components.items():
            try:
                if hasattr(comp, 'health_check'):
                    report["components"][name] = await comp.health_check()
                else:
                    report["components"][name] = {
                        "status": "unknown",
                        "message": "No health check implemented"
                    }
            except Exception as e:
                report["components"][name] = {
                    "status": "unhealthy",
                    "error": str(e)
                }
                report["status"] = "degraded"

        return report

    @asynccontextmanager
    async def lifespan(self):
        """Context manager pairing initialize_all with a guaranteed shutdown_all."""
        try:
            await self.initialize_all()
            yield
        finally:
            await self.shutdown_all()
|
backend/app/core/model_cache.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model Cache with LRU Eviction
|
| 3 |
+
|
| 4 |
+
Intelligent caching system for Micro-SLMs to minimize loading time.
|
| 5 |
+
Keeps the most recently used models in memory.
|
| 6 |
+
"""
|
| 7 |
+
import logging
|
| 8 |
+
import asyncio
|
| 9 |
+
from collections import OrderedDict
|
| 10 |
+
from typing import Optional, Callable, Any, Dict
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
|
| 13 |
+
# psutil is optional: without it memory tracking reports 0.0 and the
# memory-based eviction path is effectively disabled.
try:
    import psutil
    HAS_PSUTIL = True
except ImportError:
    HAS_PSUTIL = False

logger = logging.getLogger(__name__)


class ModelCache:
    """
    LRU (Least Recently Used) cache for Micro-SLM models.

    Keeps at most ``max_models`` models resident; the least recently used
    entry is evicted (and shut down) when room is needed. A soft memory
    limit can also trigger eviction when psutil is available. All cache
    mutation and model loading is serialized through one asyncio lock.
    """

    def __init__(
        self,
        max_models: int = 3,
        max_memory_mb: int = 2000,
        enable_stats: bool = True
    ):
        """
        Initialize the model cache.

        Args:
            max_models: Maximum number of models to keep in cache
            max_memory_mb: Maximum memory usage in MB (soft limit)
            enable_stats: Enable statistics tracking
        """
        self.cache: OrderedDict[str, Any] = OrderedDict()
        self.max_models = max_models
        self.max_memory_mb = max_memory_mb
        self.enable_stats = enable_stats

        # Counters backing hit-rate reporting.
        self.stats = {
            "hits": 0,
            "misses": 0,
            "evictions": 0,
            "loads": 0
        }

        # Serializes lookups, loads and evictions.
        self._lock = asyncio.Lock()

        logger.info(f"ModelCache initialized: max_models={max_models}, max_memory={max_memory_mb}MB")

    async def get_or_load(
        self,
        model_name: str,
        loader_func: Callable,
        *args,
        **kwargs
    ) -> Any:
        """
        Return the cached model, loading it via ``loader_func`` on a miss.

        Args:
            model_name: Unique identifier for the model
            loader_func: Async function to load the model
            *args, **kwargs: Arguments to pass to loader_func

        Returns:
            The loaded model instance
        """
        async with self._lock:
            if model_name in self.cache:
                # Hit: refresh recency and report.
                self.cache.move_to_end(model_name)
                self.stats["hits"] += 1
                logger.info(
                    f"✅ Cache HIT: {model_name} "
                    f"(hit rate: {self.get_hit_rate():.1%})"
                )
                return self.cache[model_name]

            # Miss: make room, then load while still holding the lock.
            self.stats["misses"] += 1
            logger.info(f"❌ Cache MISS: {model_name}")

            await self._evict_if_needed()

            logger.info(f"📥 Loading model: {model_name}...")
            started = datetime.now()

            try:
                loaded = await loader_func(*args, **kwargs)
                load_duration = (datetime.now() - started).total_seconds()
                logger.info(f"✓ Loaded {model_name} in {load_duration:.2f}s")

                # Newly inserted entries land at the most-recent end.
                self.cache[model_name] = loaded
                self.stats["loads"] += 1
                return loaded
            except Exception as e:
                logger.error(f"Failed to load {model_name}: {e}")
                raise

    async def _evict_if_needed(self):
        """Evict the LRU entry when the count limit or soft memory limit is hit."""
        if len(self.cache) >= self.max_models:
            await self._evict_oldest()
            return

        memory_usage = self._get_memory_usage_mb()
        if memory_usage > self.max_memory_mb:
            logger.warning(
                f"Memory usage ({memory_usage:.0f}MB) exceeds limit "
                f"({self.max_memory_mb}MB)"
            )
            await self._evict_oldest()

    async def _evict_oldest(self):
        """Pop and shut down the least recently used model (no-op when empty)."""
        if not self.cache:
            return

        # popitem(last=False) removes from the least-recent end.
        victim_name, victim = self.cache.popitem(last=False)
        self.stats["evictions"] += 1
        logger.info(f"🗑️ Evicting: {victim_name}")

        # Best-effort resource cleanup on the evicted model.
        try:
            if hasattr(victim, 'shutdown'):
                await victim.shutdown()
            elif hasattr(victim, 'cleanup'):
                await victim.cleanup()
        except Exception as e:
            logger.warning(f"Error during model cleanup: {e}")

    def _get_memory_usage_mb(self) -> float:
        """Resident set size of this process in MB (0.0 without psutil)."""
        if not HAS_PSUTIL:
            return 0.0
        try:
            return psutil.Process().memory_info().rss / (1024 * 1024)
        except Exception:
            return 0.0

    def get_hit_rate(self) -> float:
        """Fraction of lookups served from cache (0.0 before any lookup)."""
        attempts = self.stats["hits"] + self.stats["misses"]
        return self.stats["hits"] / attempts if attempts else 0.0

    def get_stats(self) -> Dict[str, Any]:
        """Snapshot of counters plus current cache contents and memory use."""
        return {
            **self.stats,
            "cached_models": len(self.cache),
            "model_names": list(self.cache.keys()),
            "hit_rate": self.get_hit_rate(),
            "memory_usage_mb": self._get_memory_usage_mb()
        }

    async def clear(self):
        """Shut down and drop every cached model."""
        async with self._lock:
            logger.info("Clearing model cache...")
            for name, model in self.cache.items():
                try:
                    if hasattr(model, 'shutdown'):
                        await model.shutdown()
                except Exception as e:
                    logger.warning(f"Error shutting down {name}: {e}")
            self.cache.clear()
            logger.info("Cache cleared")

    async def preload(self, model_name: str, loader_func: Callable, *args, **kwargs):
        """
        Preload a model into cache (prefetching).

        Useful for anticipating which model will be needed next.
        """
        logger.info(f"🔮 Prefetching: {model_name}")
        await self.get_or_load(model_name, loader_func, *args, **kwargs)

    def contains(self, model_name: str) -> bool:
        """Check if model is in cache"""
        return model_name in self.cache

    async def remove(self, model_name: str):
        """Manually remove a model from cache"""
        async with self._lock:
            if model_name not in self.cache:
                return
            model = self.cache.pop(model_name)
            logger.info(f"Removed {model_name} from cache")
            try:
                if hasattr(model, 'shutdown'):
                    await model.shutdown()
            except Exception as e:
                logger.warning(f"Error shutting down {model_name}: {e}")


# Global cache instance
model_cache = ModelCache(
    max_models=3,        # Keep 3 models in memory
    max_memory_mb=2000,  # 2GB soft limit
    enable_stats=True,
)
|
backend/app/core/orchestrator.py
ADDED
|
@@ -0,0 +1,695 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Core orchestrator for SLM Code Engine
|
| 3 |
+
|
| 4 |
+
Responsible for:
|
| 5 |
+
- Routing tasks to appropriate engines (automata vs SLM)
|
| 6 |
+
- Building execution pipelines
|
| 7 |
+
- Coordinating between micro-SLMs and automata
|
| 8 |
+
- Collecting metrics and logging
|
| 9 |
+
"""
|
| 10 |
+
import time
|
| 11 |
+
import logging
|
| 12 |
+
from typing import Dict, Any, List, Optional
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
from app.config import settings
|
| 16 |
+
from app.models.schemas import TaskType, Language
|
| 17 |
+
from app.core.router import Router
|
| 18 |
+
from app.core.router_v2 import router_v2
|
| 19 |
+
from app.core.task_decomposer import task_decomposer
|
| 20 |
+
from app.core.slm_registry import slm_registry
|
| 21 |
+
from app.core.automata_manager import AutomataManager
|
| 22 |
+
from app.core.rag import RAGRetriever
|
| 23 |
+
from app.core.lifecycle import LifecycleManager
|
| 24 |
+
from app.core.pipeline import Pipeline
|
| 25 |
+
from app.core.model_cache import model_cache
|
| 26 |
+
from app.engines.base import BaseEngine
|
| 27 |
+
from app.automata.base import BaseAutomaton
|
| 28 |
+
from app.rag.retriever import CodeRetriever
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Orchestrator:
|
| 34 |
+
"""Main orchestrator coordinating all components"""
|
| 35 |
+
|
| 36 |
+
def __init__(self):
|
| 37 |
+
self.router = Router()
|
| 38 |
+
self.router_v2 = router_v2 # New micro-SLM router
|
| 39 |
+
self.task_decomposer = task_decomposer
|
| 40 |
+
self.slm_registry = slm_registry
|
| 41 |
+
self.automata_manager = AutomataManager()
|
| 42 |
+
self.pipeline: Optional[Pipeline] = None
|
| 43 |
+
self.engines: Dict[str, BaseEngine] = {}
|
| 44 |
+
self.automata: Dict[str, BaseAutomaton] = {}
|
| 45 |
+
self.retriever = None
|
| 46 |
+
self.enable_decomposition = True # Enable task decomposition
|
| 47 |
+
self.initialized = False
|
| 48 |
+
self._metrics = []
|
| 49 |
+
|
| 50 |
+
    async def initialize(self):
        """Initialize all components.

        Builds the router and pipeline, registers automata, configures lazy
        SLM-engine loading, and wires up the RAG retriever. Sets
        ``self.initialized`` on success; logs and re-raises on any failure.
        """
        logger.info("Initializing orchestrator...")

        try:
            # Initialize router
            # NOTE(review): this replaces the Router already created in
            # __init__ with a fresh instance before initializing it.
            self.router = Router()
            await self.router.initialize()

            # Initialize pipeline builder
            self.pipeline = Pipeline()

            # Load automata (lightweight, always loaded)
            await self._load_automata()

            # Load SLM engines (lazy loading for performance)
            await self._load_engines()

            # Initialize RAG retriever (lazy loading) — the FAISS index file
            # is only opened when the retriever is first used.
            index_path = settings.data_dir / "rag_index.faiss"
            self.retriever = CodeRetriever(index_path=str(index_path))
            logger.info("RAG retriever configured (lazy loading)")

            self.initialized = True
            logger.info("Orchestrator initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize orchestrator: {e}")
            raise
|
| 79 |
+
|
| 80 |
+
    async def _load_automata(self):
        """Load all available automata.

        Imports and registers the deterministic (non-LLM) automata into
        ``self.automata``. Failures are logged but non-fatal: the
        orchestrator continues with whatever automata did load.
        """
        logger.info("Loading automata...")

        try:
            # Import automata lazily so a broken automaton module cannot
            # prevent orchestrator startup.
            from app.automata.formatter import PythonFormatter
            from app.automata.linter import PythonLinter
            from app.automata.trace_parser import TraceParser
            from app.automata.ast_fixer import ASTFixer
            from app.automata.test_generator import TestTemplateGenerator
            from app.automata.runtime_fixer import RuntimeFixer

            # Register automata under the names used by _try_automata's
            # task-to-automata map.
            self.automata["python_formatter"] = PythonFormatter()
            self.automata["python_linter"] = PythonLinter()
            self.automata["trace_parser"] = TraceParser()
            self.automata["ast_fixer"] = ASTFixer()
            self.automata["runtime_fixer"] = RuntimeFixer()  # NEW: Fix runtime errors
            self.automata["test_template"] = TestTemplateGenerator()

            logger.info(f"Loaded {len(self.automata)} automata: {list(self.automata.keys())}")

        except Exception as e:
            logger.warning(f"Failed to load some automata: {e}")
            # Non-critical, continue with available automata
|
| 106 |
+
|
| 107 |
+
    async def _load_engines(self):
        """Load SLM engines (lazy loading).

        Intentionally loads nothing: engines are instantiated on first use by
        ``_get_engine`` to save memory. This hook exists so ``initialize``
        has a single place to change if eager loading is ever wanted.
        """
        logger.info("Loading SLM engines...")

        # For V1, we'll implement lazy loading
        # Engines are loaded on first use to save memory
        logger.info("SLM engines configured for lazy loading")
|
| 114 |
+
|
| 115 |
+
    async def _get_engine(self, engine_name: str) -> BaseEngine:
        """Get or load an engine on demand.

        Engines are created and initialized on first request, then cached in
        ``self.engines``. Names not matching a built-in engine are looked up
        in the micro-SLM registry and loaded through the shared
        ``model_cache``. Raises ValueError for an unknown ``engine_name``
        and re-raises any initialization failure.
        """
        if engine_name not in self.engines:
            logger.info(f"Loading engine: {engine_name}")

            try:
                # Each branch imports lazily so heavyweight engine modules
                # are only paid for when that engine is actually requested.
                if engine_name == "groq":
                    from app.engines.groq_engine import GroqEngine
                    self.engines[engine_name] = GroqEngine()
                    await self.engines[engine_name].initialize()

                elif engine_name == "phi2":
                    from app.engines.phi2 import Phi2Engine
                    self.engines[engine_name] = Phi2Engine()
                    await self.engines[engine_name].initialize()

                elif engine_name == "starcoder":
                    from app.engines.starcoder import StarCoderEngine
                    self.engines[engine_name] = StarCoderEngine()
                    await self.engines[engine_name].initialize()

                elif engine_name == "codet5":
                    from app.engines.codet5 import CodeT5Engine
                    self.engines[engine_name] = CodeT5Engine()
                    await self.engines[engine_name].initialize()

                else:
                    # Check if it's a registered Micro-SLM
                    micro_slm_info = self.slm_registry.get_model(engine_name)
                    if micro_slm_info:
                        from app.engines.micro_slm import MicroSLMEngine

                        # Use cache for Micro-SLMs
                        async def load_micro_slm():
                            # Loader callback invoked by model_cache on a miss.
                            engine = MicroSLMEngine(
                                name=engine_name,
                                model_path=micro_slm_info.model_path
                            )
                            await engine.initialize()
                            return engine

                        self.engines[engine_name] = await model_cache.get_or_load(
                            model_name=engine_name,
                            loader_func=load_micro_slm
                        )
                    else:
                        raise ValueError(f"Unknown engine: {engine_name}")

                logger.info(f"Engine {engine_name} loaded successfully")

            except Exception as e:
                logger.error(f"Failed to load engine {engine_name}: {e}")
                raise

        return self.engines[engine_name]
|
| 170 |
+
|
| 171 |
+
    async def process(
        self,
        task: TaskType,
        code: str,
        language: Language,
        context: Optional[str] = None,
        trace: Optional[str] = None,
        history: Optional[List[Dict[str, str]]] = None
    ) -> Dict[str, Any]:
        """
        Main processing method

        Args:
            task: Type of task to perform
            code: Source code to process
            language: Programming language
            context: Additional context
            trace: Error trace (if applicable)
            history: Conversation history for context

        Returns:
            Dict with results and metadata: success flag, result text,
            explanation, suggestions, which layers ran (automata / SLM),
            the per-step pipeline list, and total duration in milliseconds.
            On internal failure returns success=False with a localized
            "error" message instead of raising.
        """
        start_time = time.time()
        pipeline_steps = []

        try:
            # Step 0: Check if we should decompose this task.
            # Decomposable task types take the subtask pipeline and skip the
            # automata/SLM flow below entirely.
            if self.enable_decomposition and self._should_decompose(task):
                return await self._process_with_decomposition(
                    task=task,
                    code=code,
                    language=language,
                    context=context,
                    trace=trace,
                    history=history,
                    start_time=start_time,
                    pipeline_steps=pipeline_steps
                )

            # Step 1: Route the task using RouterV2 (supports Micro-SLMs)
            routing_decision = await self.router_v2.route_task(
                task=task,
                code=code,
                language=language,
                context=context
            )

            logger.info(f"Routing decision: {routing_decision}")

            # Step 2: Try automata first (if enabled and applicable)
            if settings.enable_automata_first and routing_decision.get("try_automata"):
                automata_result = await self._try_automata(
                    task=task,
                    code=code,
                    language=language,
                    trace=trace,
                    pipeline_steps=pipeline_steps
                )

                if automata_result.get("success"):
                    # Check if code was actually modified
                    code_changed = automata_result.get("result") != code

                    # For FIX tasks, if code unchanged, try SLM for deeper analysis
                    if task == TaskType.FIX and not code_changed:
                        logger.info("Automata found no issues, trying SLM for deeper analysis")
                        # Continue to SLM fallback
                    else:
                        # Automata succeeded with changes, return result
                        duration_ms = (time.time() - start_time) * 1000
                        return {
                            "success": True,
                            "task": task,
                            "result": automata_result["result"],
                            "explanation": automata_result.get("explanation"),
                            "suggestions": automata_result.get("suggestions", []),
                            "used_automata": True,
                            "used_slm": False,
                            "pipeline": pipeline_steps,
                            "total_duration_ms": duration_ms
                        }

            # Step 3: Use SLM
            slm_result = await self._use_slm(
                task=task,
                code=code,
                language=language,
                context=context,
                trace=trace,
                routing_decision=routing_decision,
                pipeline_steps=pipeline_steps,
                history=history
            )

            duration_ms = (time.time() - start_time) * 1000

            return {
                "success": slm_result.get("success", True),
                "task": task,
                "result": slm_result.get("result"),
                "explanation": slm_result.get("explanation"),
                "suggestions": slm_result.get("suggestions", []),
                # True when any automata step ran above, even though the
                # final answer came from the SLM.
                "used_automata": len([s for s in pipeline_steps if s["step_type"] == "automata"]) > 0,
                "used_slm": True,
                "pipeline": pipeline_steps,
                "total_duration_ms": duration_ms
            }

        except Exception as e:
            logger.error(f"Error in orchestrator.process: {e}", exc_info=True)
            duration_ms = (time.time() - start_time) * 1000

            # Build a localized, user-facing error message.
            from app.utils.localization import get_string
            error_message = get_string("backend_error_generic", error=str(e))

            return {
                "success": False,
                "task": task,
                "error": error_message,
                "used_automata": False,
                "used_slm": False,
                "pipeline": pipeline_steps,
                "total_duration_ms": duration_ms
            }
|
| 296 |
+
|
| 297 |
+
    async def _try_automata(
        self,
        task: TaskType,
        code: str,
        language: Language,
        trace: Optional[str],
        pipeline_steps: List[Dict]
    ) -> Dict[str, Any]:
        """Try to handle task with automata only.

        Runs the registered automata mapped to ``task`` in order, recording
        each attempt in ``pipeline_steps`` (mutated in place), and returns
        the first successful result. Returns ``{"success": False}`` when no
        automaton applies or succeeds.
        """

        # Map tasks to automata. Format/fix automata are Python-only; trace
        # parsing only makes sense when a trace was supplied.
        automata_map = {
            TaskType.FORMAT: ["python_formatter"] if language == Language.PYTHON else [],
            TaskType.EXPLAIN: ["trace_parser"] if trace else [],
            TaskType.FIX: ["runtime_fixer", "ast_fixer", "python_linter"] if language == Language.PYTHON else [],
        }

        automata_to_try = automata_map.get(task, [])

        for automaton_name in automata_to_try:
            # Skip automata that failed to load in _load_automata.
            if automaton_name not in self.automata:
                continue

            automaton = self.automata[automaton_name]
            step_start = time.time()

            try:
                if automaton.can_handle(code, language, task):
                    result = await automaton.execute(code, trace=trace)

                    step_duration = (time.time() - step_start) * 1000
                    pipeline_steps.append({
                        "step_type": "automata",
                        "component": automaton_name,
                        "duration_ms": step_duration,
                        "success": True,
                        "details": {"automaton": automaton_name}
                    })

                    # First automaton that reports success wins.
                    if result.get("success"):
                        return result

            except Exception as e:
                # A failing automaton is logged and recorded, then the next
                # candidate in the list is tried.
                logger.warning(f"Automaton {automaton_name} failed: {e}")
                step_duration = (time.time() - step_start) * 1000
                pipeline_steps.append({
                    "step_type": "automata",
                    "component": automaton_name,
                    "duration_ms": step_duration,
                    "success": False,
                    "details": {"error": str(e)}
                })

        return {"success": False}
|
| 351 |
+
|
| 352 |
+
    async def _use_slm(
        self,
        task: TaskType,
        code: str,
        language: Language,
        context: Optional[str],
        trace: Optional[str],
        routing_decision: Dict,
        pipeline_steps: List[Dict],
        history: Optional[List[Dict[str, str]]] = None
    ) -> Dict[str, Any]:
        """Use SLM engine to process task.

        Resolves the engine from ``routing_decision``, optionally enriches
        the prompt context via RAG, runs the engine, records the step in
        ``pipeline_steps``, and logs Groq outputs for distillation.
        Re-raises engine failures after recording them.
        """

        # Determine which engine to use.
        # If routed to micro_slm, use the handler_name as the engine name.
        # NOTE(review): a missing handler_name yields engine_name=None, which
        # _get_engine rejects with ValueError — confirm router_v2 always sets it.
        if routing_decision.get("handler_type") == "micro_slm":
            engine_name = routing_decision.get("handler_name")
        else:
            engine_name = routing_decision.get("engine", "groq")

        # Enrich context with RAG examples (best-effort: failures fall back
        # to the raw context).
        enriched_context = context or ""
        if self.retriever and settings.enable_rag:
            try:
                rag_context = self.retriever.build_context(
                    query_code=code,
                    language=language,
                    task=task,
                    k=3  # Retrieve top 3 similar examples
                )
                if rag_context:
                    enriched_context = f"{enriched_context}\n\n{rag_context}" if enriched_context else rag_context
                    logger.debug("Enriched context with RAG examples")
            except Exception as e:
                logger.warning(f"Failed to enrich context with RAG: {e}")

        step_start = time.time()

        try:
            engine = await self._get_engine(engine_name)

            result = await engine.process(
                task=task,
                code=code,
                language=language,
                context=enriched_context,
                trace=trace,
                history=history
            )

            step_duration = (time.time() - step_start) * 1000
            pipeline_steps.append({
                "step_type": "slm",
                "component": engine_name,
                "duration_ms": step_duration,
                "success": True,
                "details": {"engine": engine_name}
            })

            # Log for distillation if using Teacher model (Groq).
            # Best-effort: logging failures never affect the user result.
            if engine_name == "groq" and result.get("success"):
                try:
                    from app.core.distillation import distillation_logger
                    await distillation_logger.log_interaction(
                        task=task,
                        language=language,
                        code_input=code,
                        context=enriched_context,
                        output=result.get("result") or result.get("explanation", ""),
                        model="groq-llama-3"
                    )
                except Exception as e:
                    logger.warning(f"Failed to log distillation data: {e}")

            return result

        except Exception as e:
            # Record the failed step, then let process()'s handler build the
            # user-facing error response.
            logger.error(f"SLM {engine_name} failed: {e}", exc_info=True)
            step_duration = (time.time() - step_start) * 1000
            pipeline_steps.append({
                "step_type": "slm",
                "component": engine_name,
                "duration_ms": step_duration,
                "success": False,
                "details": {"error": str(e)}
            })
            raise
|
| 439 |
+
|
| 440 |
+
async def translate(
|
| 441 |
+
self,
|
| 442 |
+
code: str,
|
| 443 |
+
source_lang: Language,
|
| 444 |
+
target_lang: Language,
|
| 445 |
+
preserve_comments: bool = True
|
| 446 |
+
) -> Dict[str, Any]:
|
| 447 |
+
"""Translate code between languages"""
|
| 448 |
+
context = f"Translate from {source_lang} to {target_lang}"
|
| 449 |
+
if preserve_comments:
|
| 450 |
+
context += ". Preserve all comments."
|
| 451 |
+
|
| 452 |
+
return await self.process(
|
| 453 |
+
task=TaskType.TRANSLATE,
|
| 454 |
+
code=code,
|
| 455 |
+
language=source_lang,
|
| 456 |
+
context=context
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
async def generate_boilerplate(
|
| 460 |
+
self,
|
| 461 |
+
template_type: str,
|
| 462 |
+
language: Language,
|
| 463 |
+
name: str,
|
| 464 |
+
options: Optional[Dict[str, Any]] = None
|
| 465 |
+
) -> Dict[str, Any]:
|
| 466 |
+
"""Generate boilerplate code"""
|
| 467 |
+
context = f"Generate {template_type} boilerplate for {name}"
|
| 468 |
+
if options:
|
| 469 |
+
context += f" with options: {options}"
|
| 470 |
+
|
| 471 |
+
return await self.process(
|
| 472 |
+
task=TaskType.BOILERPLATE,
|
| 473 |
+
code="", # Empty code for generation
|
| 474 |
+
language=language,
|
| 475 |
+
context=context
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
async def get_status(self) -> Dict[str, Any]:
|
| 479 |
+
"""Get orchestrator status"""
|
| 480 |
+
return {
|
| 481 |
+
"ready": self.initialized,
|
| 482 |
+
"models_loaded": {
|
| 483 |
+
engine: True for engine in self.engines.keys()
|
| 484 |
+
},
|
| 485 |
+
"automata_available": list(self.automata.keys())
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
def _should_decompose(self, task: TaskType) -> bool:
|
| 489 |
+
"""Determine if a task should be decomposed into subtasks"""
|
| 490 |
+
# Decompose complex tasks
|
| 491 |
+
decomposable_tasks = [
|
| 492 |
+
TaskType.FIX,
|
| 493 |
+
TaskType.REFACTOR,
|
| 494 |
+
TaskType.FORMAT
|
| 495 |
+
]
|
| 496 |
+
return task in decomposable_tasks
|
| 497 |
+
|
| 498 |
+
async def _process_with_decomposition(
|
| 499 |
+
self,
|
| 500 |
+
task: TaskType,
|
| 501 |
+
code: str,
|
| 502 |
+
language: Language,
|
| 503 |
+
context: Optional[str],
|
| 504 |
+
trace: Optional[str],
|
| 505 |
+
history: Optional[List[Dict[str, str]]],
|
| 506 |
+
start_time: float,
|
| 507 |
+
pipeline_steps: List[Dict]
|
| 508 |
+
) -> Dict[str, Any]:
|
| 509 |
+
"""Process task using decomposition and micro-SLM routing"""
|
| 510 |
+
|
| 511 |
+
logger.info("Using decomposition-based processing")
|
| 512 |
+
|
| 513 |
+
# Step 1: Decompose task into subtasks
|
| 514 |
+
subtasks = await self.task_decomposer.decompose(
|
| 515 |
+
task=task,
|
| 516 |
+
code=code,
|
| 517 |
+
language=language,
|
| 518 |
+
context=context,
|
| 519 |
+
trace=trace
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
logger.info(f"Decomposed into {len(subtasks)} subtasks")
|
| 523 |
+
|
| 524 |
+
# Step 2: Process each subtask
|
| 525 |
+
results = []
|
| 526 |
+
current_code = code
|
| 527 |
+
|
| 528 |
+
for i, subtask in enumerate(subtasks):
|
| 529 |
+
logger.info(f"Processing subtask {i+1}/{len(subtasks)}: {subtask.subtask_type}")
|
| 530 |
+
|
| 531 |
+
# Route subtask to best handler
|
| 532 |
+
routing = await self.router_v2.route_subtask(
|
| 533 |
+
subtask_type=subtask.subtask_type,
|
| 534 |
+
code=current_code,
|
| 535 |
+
language=language,
|
| 536 |
+
context=subtask.context
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
logger.info(f"Routed to: {routing['handler_type']} ({routing['handler_name']})")
|
| 540 |
+
|
| 541 |
+
# Execute based on handler type
|
| 542 |
+
step_start = time.time()
|
| 543 |
+
|
| 544 |
+
if routing['handler_type'] == 'automata':
|
| 545 |
+
# Use automaton
|
| 546 |
+
result = await self._execute_automaton(
|
| 547 |
+
automaton_name=routing['handler_name'],
|
| 548 |
+
code=current_code,
|
| 549 |
+
trace=trace,
|
| 550 |
+
pipeline_steps=pipeline_steps
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
elif routing['handler_type'] == 'micro_slm':
|
| 554 |
+
# Use micro-SLM
|
| 555 |
+
logger.warning(f"Micro-SLM execution not yet fully implemented, falling back to Groq")
|
| 556 |
+
result = await self._execute_groq(
|
| 557 |
+
task=task,
|
| 558 |
+
code=current_code,
|
| 559 |
+
language=language,
|
| 560 |
+
context=subtask.context,
|
| 561 |
+
trace=trace,
|
| 562 |
+
history=history,
|
| 563 |
+
pipeline_steps=pipeline_steps
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
else: # groq
|
| 567 |
+
# Use Groq
|
| 568 |
+
result = await self._execute_groq(
|
| 569 |
+
task=task,
|
| 570 |
+
code=current_code,
|
| 571 |
+
language=language,
|
| 572 |
+
context=subtask.context,
|
| 573 |
+
trace=trace,
|
| 574 |
+
history=history,
|
| 575 |
+
pipeline_steps=pipeline_steps
|
| 576 |
+
)
|
| 577 |
+
|
| 578 |
+
step_duration = (time.time() - step_start) * 1000
|
| 579 |
+
|
| 580 |
+
# Record step
|
| 581 |
+
pipeline_steps.append({
|
| 582 |
+
"step_type": routing['handler_type'],
|
| 583 |
+
"component": routing['handler_name'],
|
| 584 |
+
"subtask": subtask.subtask_type,
|
| 585 |
+
"duration_ms": step_duration,
|
| 586 |
+
"success": result.get("success", False)
|
| 587 |
+
})
|
| 588 |
+
|
| 589 |
+
# Update current_code for next subtask
|
| 590 |
+
if result.get("success") and result.get("result"):
|
| 591 |
+
current_code = result["result"]
|
| 592 |
+
|
| 593 |
+
results.append(result)
|
| 594 |
+
|
| 595 |
+
# Step 3: Combine results
|
| 596 |
+
duration_ms = (time.time() - start_time) * 1000
|
| 597 |
+
|
| 598 |
+
# Get final result (last successful result)
|
| 599 |
+
final_result = current_code
|
| 600 |
+
|
| 601 |
+
# Combine explanations
|
| 602 |
+
explanations = [r.get("explanation", "") for r in results if r.get("explanation")]
|
| 603 |
+
combined_explanation = "\n\n".join(explanations) if explanations else "Processed via decomposition"
|
| 604 |
+
|
| 605 |
+
return {
|
| 606 |
+
"success": True,
|
| 607 |
+
"task": task,
|
| 608 |
+
"result": final_result,
|
| 609 |
+
"explanation": combined_explanation,
|
| 610 |
+
"suggestions": ["Code processed through micro-SLM mesh"],
|
| 611 |
+
"used_automata": any(s["step_type"] == "automata" for s in pipeline_steps),
|
| 612 |
+
"used_slm": any(s["step_type"] in ["micro_slm", "groq"] for s in pipeline_steps),
|
| 613 |
+
"pipeline": pipeline_steps,
|
| 614 |
+
"total_duration_ms": duration_ms,
|
| 615 |
+
"subtasks_processed": len(subtasks)
|
| 616 |
+
}
|
| 617 |
+
|
| 618 |
+
async def _execute_automaton(
|
| 619 |
+
self,
|
| 620 |
+
automaton_name: str,
|
| 621 |
+
code: str,
|
| 622 |
+
trace: Optional[str],
|
| 623 |
+
pipeline_steps: List[Dict]
|
| 624 |
+
) -> Dict[str, Any]:
|
| 625 |
+
"""Execute an automaton"""
|
| 626 |
+
try:
|
| 627 |
+
automaton = self.automata.get(automaton_name)
|
| 628 |
+
if not automaton:
|
| 629 |
+
return {"success": False, "error": f"Automaton {automaton_name} not found"}
|
| 630 |
+
|
| 631 |
+
result = await automaton.execute(code, trace=trace)
|
| 632 |
+
return result
|
| 633 |
+
|
| 634 |
+
except Exception as e:
|
| 635 |
+
logger.error(f"Automaton execution failed: {e}")
|
| 636 |
+
return {"success": False, "error": str(e)}
|
| 637 |
+
|
| 638 |
+
    async def _execute_groq(
        self,
        task: TaskType,
        code: str,
        language: Language,
        context: Optional[str],
        trace: Optional[str],
        history: Optional[List[Dict[str, str]]],
        pipeline_steps: List[Dict]
    ) -> Dict[str, Any]:
        """Execute using Groq engine.

        Optionally enriches ``context`` with RAG examples, then delegates to
        the Groq engine. Failures are converted to
        ``{"success": False, "error": ...}`` rather than raised.

        NOTE(review): ``pipeline_steps`` is accepted but never appended to
        here — the caller (_process_with_decomposition) records the step.
        """
        try:
            engine = await self._get_engine("groq")

            # Enrich context with RAG if available (best-effort: failures
            # fall back to the plain context).
            enriched_context = context or ""
            if self.retriever and settings.enable_rag:
                try:
                    rag_context = self.retriever.build_context(
                        query_code=code,
                        language=language,
                        task=task,
                        k=3
                    )
                    if rag_context:
                        enriched_context = f"{enriched_context}\n\n{rag_context}" if enriched_context else rag_context
                except Exception as e:
                    logger.warning(f"RAG enrichment failed: {e}")

            result = await engine.process(
                task=task,
                code=code,
                language=language,
                context=enriched_context,
                trace=trace,
                history=history
            )

            return result

        except Exception as e:
            logger.error(f"Groq execution failed: {e}")
            return {"success": False, "error": str(e)}
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
async def shutdown(self):
|
| 684 |
+
"""Cleanup resources"""
|
| 685 |
+
logger.info("Shutting down orchestrator...")
|
| 686 |
+
|
| 687 |
+
for engine_name, engine in self.engines.items():
|
| 688 |
+
try:
|
| 689 |
+
await engine.shutdown()
|
| 690 |
+
logger.info(f"Engine {engine_name} shutdown complete")
|
| 691 |
+
except Exception as e:
|
| 692 |
+
logger.error(f"Error shutting down engine {engine_name}: {e}")
|
| 693 |
+
|
| 694 |
+
self.initialized = False
|
| 695 |
+
logger.info("Orchestrator shutdown complete")
|
backend/app/core/orchestrator_decomposition.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def _should_decompose(self, task: TaskType) -> bool:
    """Determine if a task should be decomposed into subtasks"""
    # NOTE(review): this module duplicates methods that also exist on
    # Orchestrator in orchestrator.py, and its functions take ``self`` at
    # module level — confirm which copy is actually used.
    # Decompose complex tasks
    decomposable_tasks = [
        TaskType.FIX,
        TaskType.REFACTOR,
        TaskType.FORMAT
    ]
    return task in decomposable_tasks
|
| 10 |
+
|
| 11 |
+
async def _process_with_decomposition(
    self,
    task: TaskType,
    code: str,
    language: Language,
    context: Optional[str],
    trace: Optional[str],
    history: Optional[List[Dict[str, str]]],
    start_time: float,
    pipeline_steps: List[Dict]
) -> Dict[str, Any]:
    """Process task using decomposition and micro-SLM routing.

    Splits the task into subtasks, routes each to an automaton, micro-SLM
    (currently a Groq fallback), or Groq handler, and threads each
    successful subtask's output code into the next. Returns a result dict
    shaped like ``process``'s, plus ``subtasks_processed``.

    NOTE(review): near-duplicate of Orchestrator._process_with_decomposition
    in orchestrator.py — confirm which copy is live.
    """

    logger.info("Using decomposition-based processing")

    # Step 1: Decompose task into subtasks
    subtasks = await self.task_decomposer.decompose(
        task=task,
        code=code,
        language=language,
        context=context,
        trace=trace
    )

    logger.info(f"Decomposed into {len(subtasks)} subtasks")

    # Step 2: Process each subtask, threading rewritten code forward.
    results = []
    current_code = code

    for i, subtask in enumerate(subtasks):
        logger.info(f"Processing subtask {i+1}/{len(subtasks)}: {subtask.subtask_type}")

        # Route subtask to best handler
        routing = await self.router_v2.route_subtask(
            subtask_type=subtask.subtask_type,
            code=current_code,
            language=language,
            context=subtask.context
        )

        logger.info(f"Routed to: {routing['handler_type']} ({routing['handler_name']})")

        # Execute based on handler type
        step_start = time.time()

        if routing['handler_type'] == 'automata':
            # Use automaton
            result = await self._execute_automaton(
                automaton_name=routing['handler_name'],
                code=current_code,
                trace=trace,
                pipeline_steps=pipeline_steps
            )

        elif routing['handler_type'] == 'micro_slm':
            # Use micro-SLM (placeholder for now)
            logger.warning(f"Micro-SLM execution not yet implemented, falling back to Groq")
            result = await self._execute_groq(
                task=task,
                code=current_code,
                language=language,
                context=subtask.context,
                trace=trace,
                history=history,
                pipeline_steps=pipeline_steps
            )

        else:  # groq
            # Use Groq
            result = await self._execute_groq(
                task=task,
                code=current_code,
                language=language,
                context=subtask.context,
                trace=trace,
                history=history,
                pipeline_steps=pipeline_steps
            )

        step_duration = (time.time() - step_start) * 1000

        # Record step
        pipeline_steps.append({
            "step_type": routing['handler_type'],
            "component": routing['handler_name'],
            "subtask": subtask.subtask_type,
            "duration_ms": step_duration,
            "success": result.get("success", False)
        })

        # Update current_code for next subtask
        if result.get("success") and result.get("result"):
            current_code = result["result"]

        results.append(result)

    # Step 3: Combine results
    duration_ms = (time.time() - start_time) * 1000

    # Get final result (last successful result)
    final_result = current_code

    # Combine explanations
    explanations = [r.get("explanation", "") for r in results if r.get("explanation")]
    combined_explanation = "\n\n".join(explanations) if explanations else "Processed via decomposition"

    return {
        "success": True,
        "task": task,
        "result": final_result,
        "explanation": combined_explanation,
        "suggestions": ["Code processed through micro-SLM mesh"],
        "used_automata": any(s["step_type"] == "automata" for s in pipeline_steps),
        "used_slm": any(s["step_type"] in ["micro_slm", "groq"] for s in pipeline_steps),
        "pipeline": pipeline_steps,
        "total_duration_ms": duration_ms,
        "subtasks_processed": len(subtasks)
    }
|
| 130 |
+
|
| 131 |
+
async def _execute_automaton(
    self,
    automaton_name: str,
    code: str,
    trace: Optional[str],
    pipeline_steps: List[Dict]
) -> Dict[str, Any]:
    """Execute an automaton.

    Unknown names and raised exceptions are converted to
    ``{"success": False, "error": ...}`` so the caller's loop can continue.

    NOTE(review): duplicate of Orchestrator._execute_automaton in
    orchestrator.py — confirm which copy is live.
    """
    try:
        automaton = self.automata.get(automaton_name)
        if not automaton:
            return {"success": False, "error": f"Automaton {automaton_name} not found"}

        result = await automaton.execute(code, trace=trace)
        return result

    except Exception as e:
        logger.error(f"Automaton execution failed: {e}")
        return {"success": False, "error": str(e)}
|
| 150 |
+
|
| 151 |
+
async def _execute_groq(
|
| 152 |
+
self,
|
| 153 |
+
task: TaskType,
|
| 154 |
+
code: str,
|
| 155 |
+
language: Language,
|
| 156 |
+
context: Optional[str],
|
| 157 |
+
trace: Optional[str],
|
| 158 |
+
history: Optional[List[Dict[str, str]]],
|
| 159 |
+
pipeline_steps: List[Dict]
|
| 160 |
+
) -> Dict[str, Any]:
|
| 161 |
+
"""Execute using Groq engine"""
|
| 162 |
+
try:
|
| 163 |
+
engine = await self._get_engine("groq")
|
| 164 |
+
|
| 165 |
+
# Enrich context with RAG if available
|
| 166 |
+
enriched_context = context or ""
|
| 167 |
+
if self.retriever and settings.enable_rag:
|
| 168 |
+
try:
|
| 169 |
+
rag_context = self.retriever.build_context(
|
| 170 |
+
query_code=code,
|
| 171 |
+
language=language,
|
| 172 |
+
task=task,
|
| 173 |
+
k=3
|
| 174 |
+
)
|
| 175 |
+
if rag_context:
|
| 176 |
+
enriched_context = f"{enriched_context}\n\n{rag_context}" if enriched_context else rag_context
|
| 177 |
+
except Exception as e:
|
| 178 |
+
logger.warning(f"RAG enrichment failed: {e}")
|
| 179 |
+
|
| 180 |
+
result = await engine.process(
|
| 181 |
+
task=task,
|
| 182 |
+
code=code,
|
| 183 |
+
language=language,
|
| 184 |
+
context=enriched_context,
|
| 185 |
+
trace=trace,
|
| 186 |
+
history=history
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
return result
|
| 190 |
+
|
| 191 |
+
except Exception as e:
|
| 192 |
+
logger.error(f"Groq execution failed: {e}")
|
| 193 |
+
return {"success": False, "error": str(e)}
|
backend/app/core/pipeline.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pipeline builder for multi-step task execution
|
| 3 |
+
|
| 4 |
+
Future feature for V1.5: Chain multiple automata/SLMs
|
| 5 |
+
V1: Simple pass-through
|
| 6 |
+
"""
|
| 7 |
+
import logging
|
| 8 |
+
from typing import List, Dict, Any
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Pipeline:
    """Builds and executes multi-step pipelines.

    V1 is a simple pass-through: ``build`` returns an empty pipeline and
    ``execute`` echoes its input. Multi-step chaining of automata/SLMs is
    planned for V1.5+.
    """

    def __init__(self):
        logging.getLogger(__name__).info("Pipeline builder initialized (simple mode for V1)")

    async def build(
        self,
        task: str,
        steps: List[str]
    ) -> List[Dict[str, Any]]:
        """
        Build an execution pipeline.

        V1: no chaining yet — always returns an empty pipeline.
        V1.5+: will support multi-step chaining.
        """
        return []

    async def execute(
        self,
        pipeline: List[Dict[str, Any]],
        input_data: str
    ) -> Dict[str, Any]:
        """
        Execute a pipeline.

        V1: pass-through — echoes the input. Fix: the original body was
        ``pass``, which returned None and contradicted the declared
        Dict[str, Any] return type; callers now always receive a dict.
        V1.5+: will execute multi-step pipelines for real.
        """
        return {
            "success": True,
            "result": input_data,
            "steps_executed": len(pipeline),
        }
|
backend/app/core/rag.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RAG Retriever
|
| 3 |
+
|
| 4 |
+
Retrieval-Augmented Generation for code context.
|
| 5 |
+
Retrieves relevant code examples and documentation to enhance SLM responses.
|
| 6 |
+
"""
|
| 7 |
+
import logging
|
| 8 |
+
from typing import List, Dict, Any, Optional
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import json
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RAGRetriever:
    """Retrieval-Augmented Generation retriever for code context."""

    def __init__(self, examples_dir: Optional[Path] = None):
        # Default to the repo-local examples directory when none is given.
        self.examples_dir = examples_dir or Path("data/examples")
        # Per-language cache so each examples file is parsed at most once.
        self.examples_cache: Dict[str, List[Dict]] = {}
        logging.getLogger(__name__).info(f"RAGRetriever initialized with examples_dir: {self.examples_dir}")

    def load_examples(self, language: str = "python") -> List[Dict[str, Any]]:
        """Load (and cache) code examples for a specific language."""
        log = logging.getLogger(__name__)
        if language in self.examples_cache:
            return self.examples_cache[language]

        collected: List[Dict[str, Any]] = []
        if self.examples_dir.exists():
            # Example files are named <language>*.json; each holds either a
            # single example object or a list of them.
            for example_file in self.examples_dir.glob(f"{language}*.json"):
                try:
                    with open(example_file, 'r', encoding='utf-8') as f:
                        payload = json.load(f)
                    if isinstance(payload, list):
                        collected.extend(payload)
                    else:
                        collected.append(payload)
                except Exception as e:
                    log.warning(f"Failed to load examples from {example_file}: {e}")

        self.examples_cache[language] = collected
        log.info(f"Loaded {len(collected)} examples for {language}")
        return collected

    async def retrieve_context(
        self,
        query: str,
        language: str = "python",
        task_type: Optional[str] = None,
        max_results: int = 3
    ) -> List[Dict[str, Any]]:
        """
        Retrieve relevant code examples based on a query.

        Simple keyword scoring (task match +10, description hit +5, code
        hit +2); examples with a zero score are dropped. In production this
        would be replaced by embeddings + vector search.

        Args:
            query: Search query (e.g. task description)
            language: Programming language
            task_type: Optional task type filter
            max_results: Maximum number of results to return

        Returns:
            Up to ``max_results`` examples, highest score first.
        """
        examples = self.load_examples(language)
        if not examples:
            return []

        words = query.lower().split()

        def relevance(entry: Dict[str, Any]) -> int:
            # Keyword-overlap heuristic; flat bonuses, not per-word counts.
            points = 0
            if task_type and entry.get('task') == task_type:
                points += 10
            if any(w in entry.get('description', '').lower() for w in words):
                points += 5
            if any(w in entry.get('code', '').lower() for w in words):
                points += 2
            return points

        ranked = [(relevance(entry), entry) for entry in examples]
        ranked = [pair for pair in ranked if pair[0] > 0]
        ranked.sort(reverse=True, key=lambda pair: pair[0])

        top = [entry for _, entry in ranked[:max_results]]
        logging.getLogger(__name__).debug(f"Retrieved {len(top)} examples for query: {query}")
        return top

    def format_context(self, examples: List[Dict[str, Any]]) -> str:
        """Format retrieved examples as a single context string."""
        if not examples:
            return ""

        lines = ["Here are some relevant code examples:\n"]
        for idx, entry in enumerate(examples, 1):
            lines.append(f"\nExample {idx}:")
            if 'description' in entry:
                lines.append(f"Description: {entry['description']}")
            if 'code' in entry:
                # Fenced block tagged with the example's language (python default).
                lines.append(f"```{entry.get('language', 'python')}")
                lines.append(entry['code'])
                lines.append("```")
        return "\n".join(lines)
|
backend/app/core/router.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Router for determining how to handle tasks
|
| 3 |
+
|
| 4 |
+
Uses a combination of:
|
| 5 |
+
- Rule-based routing (simple cases)
|
| 6 |
+
- Embedding-based classification (complex cases)
|
| 7 |
+
"""
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Dict, Any, Optional
|
| 10 |
+
|
| 11 |
+
from app.models.schemas import TaskType, Language
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Router:
    """Routes tasks to appropriate engines or automata."""

    def __init__(self):
        self.initialized = False
        self.embedding_model = None  # reserved for V1.5 embedding-based routing

    async def initialize(self):
        """Initialize router components."""
        logger.info("Initializing router...")
        # V1 ships rule-based routing only; V1.5 adds embedding classification.
        self.initialized = True
        logger.info("Router initialized (rule-based mode)")

    async def route(
        self,
        task: TaskType,
        code: str,
        language: Language,
        context: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Determine how to handle the task.

        Returns:
            Dict with routing decision:
            - try_automata: bool
            - engine: str (which SLM to use)
            - confidence: float
        """
        # Groq is preferred whenever an API key is configured.
        from app.config import settings
        default_engine = "groq" if bool(settings.groq_api_key) else "phi2"

        decision = {
            "try_automata": False,
            "engine": default_engine,
            "confidence": 0.8
        }

        if task == TaskType.FORMAT:
            # Formatting is deterministic — automata handle it best.
            decision["try_automata"] = True
            decision["confidence"] = 0.95
        elif task == TaskType.FIX and language == Language.PYTHON:
            # Try automata first for Python fixes.
            decision["try_automata"] = True
            decision["confidence"] = 0.7
        elif task == TaskType.EXPLAIN:
            # NOTE(review): the original comment says explanation needs an
            # SLM, yet try_automata is enabled — preserved as-is; confirm intent.
            decision["try_automata"] = True
            decision["confidence"] = 0.85
        elif task == TaskType.TRANSLATE:
            # Translation needs an SLM.
            decision["confidence"] = 0.9
        elif task in (TaskType.TEST, TaskType.BOILERPLATE):
            # Generation tasks need an SLM.
            decision["confidence"] = 0.8
        elif task == TaskType.REFACTOR:
            # Refactoring needs an SLM.
            decision["confidence"] = 0.75

        logger.debug(f"Routing decision for {task}: {decision}")
        return decision

    def _calculate_confidence(self, task: TaskType, code: str) -> float:
        """Calculate confidence in the routing decision."""
        # Placeholder for future embedding-based classification.
        return 0.8
|
backend/app/core/router_v2.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Router v2 - Intelligent Capability-Based Routing
|
| 3 |
+
|
| 4 |
+
Routes subtasks to the most appropriate handler:
|
| 5 |
+
1. Automata (fastest, deterministic)
|
| 6 |
+
2. Micro-SLMs (fast, specialized)
|
| 7 |
+
3. Groq (slow, general-purpose, teacher)
|
| 8 |
+
"""
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Dict, Any, Optional
|
| 11 |
+
from app.models.schemas import TaskType, Language
|
| 12 |
+
from app.core.slm_registry import slm_registry
|
| 13 |
+
from app.config import settings
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class RouterV2:
    """Intelligent router for the micro-SLM mesh."""

    def __init__(self):
        # Micro-SLMs below this accuracy are skipped in favor of Groq.
        self.min_micro_slm_accuracy = 0.85
        # Subtask types a deterministic automaton can fully handle.
        self.automata_capabilities = {
            "fix_syntax": "ast_fixer",
            "format_code": "black",
            "format_imports": "isort",
        }

    async def route_subtask(
        self,
        subtask_type: str,
        code: str,
        language: Language,
        context: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Route a subtask to the best handler.

        Probes handlers cheapest-first (automata, then micro-SLMs) and
        falls back to Groq.

        Returns:
            {"handler_type": "automata" | "micro_slm" | "groq",
             "handler_name": str,
             "confidence": float}
        """
        logger.debug(f"Routing subtask: {subtask_type}")

        for probe in (self._try_automata, self._try_micro_slm):
            decision = await probe(subtask_type, code, language)
            if decision:
                return decision

        # Nothing specialized matched — Groq is the universal fallback.
        return {
            "handler_type": "groq",
            "handler_name": "groq",
            "confidence": 1.0,  # Groq is always confident
            "reason": "No specialized handler available"
        }

    async def _try_automata(
        self,
        subtask_type: str,
        code: str,
        language: Language
    ) -> Optional[Dict[str, Any]]:
        """Return an automata routing decision, or None if no automaton applies."""
        name = self.automata_capabilities.get(subtask_type)
        if name is None:
            return None

        logger.debug(f"Automaton '{name}' can handle '{subtask_type}'")
        return {
            "handler_type": "automata",
            "handler_name": name,
            "confidence": 1.0,  # automata are deterministic
            "reason": f"Automaton '{name}' handles this pattern"
        }

    async def _try_micro_slm(
        self,
        subtask_type: str,
        code: str,
        language: Language
    ) -> Optional[Dict[str, Any]]:
        """Return a micro-SLM routing decision, or None if none qualifies."""
        # Ask the registry for the best model covering this capability.
        candidate = slm_registry.get_best_for_capability(
            capability=subtask_type,
            min_accuracy=self.min_micro_slm_accuracy
        )
        if candidate is None:
            logger.debug(f"No micro-SLM available for '{subtask_type}'")
            return None

        logger.info(
            f"Micro-SLM '{candidate.name}' selected for '{subtask_type}' "
            f"(accuracy: {candidate.accuracy:.2f})"
        )
        return {
            "handler_type": "micro_slm",
            "handler_name": candidate.name,
            "confidence": candidate.accuracy,
            "reason": f"Specialized micro-SLM (accuracy: {candidate.accuracy:.2f})"
        }

    async def route_task(
        self,
        task: TaskType,
        code: str,
        language: Language,
        context: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Route a full task (backward compatibility with the old Router).

        Used when the task is NOT decomposed.
        """
        logger.info(f"RouterV2 checking for capability: {task.value}")

        # Step 1: automata (fastest) — Python formatting goes to black.
        if task == TaskType.FORMAT and language == Language.PYTHON:
            return {
                "handler_type": "automata",
                "handler_name": "black",
                "confidence": 1.0,
                "try_automata": True,
                "engine": "black"
            }

        # Step 2: micro-SLMs, keyed by the task's string value (e.g. "boilerplate").
        candidate = slm_registry.get_best_for_capability(
            capability=task.value,
            min_accuracy=self.min_micro_slm_accuracy
        )
        if candidate:
            logger.info(
                f"Micro-SLM '{candidate.name}' selected for '{task.value}' "
                f"(accuracy: {candidate.accuracy:.2f})"
            )
            return {
                "handler_type": "micro_slm",
                "handler_name": candidate.name,
                "confidence": candidate.accuracy,
                "try_automata": False,
                "engine": candidate.name
            }

        # Step 3: Groq fallback.
        logger.info(f"No Micro-SLM found for '{task.value}', using Groq")
        return {
            "handler_type": "groq",
            "handler_name": "groq",
            "confidence": 1.0,
            "try_automata": settings.enable_automata_first,
            "engine": "groq"
        }
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# Global instance
|
| 174 |
+
router_v2 = RouterV2()
|
backend/app/core/slm_registry.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SLM Registry
|
| 3 |
+
|
| 4 |
+
Manages the registration and discovery of specialized micro-SLMs.
|
| 5 |
+
Acts as the central catalog for the "Mesh" architecture.
|
| 6 |
+
"""
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Dict, List, Optional, Any
|
| 11 |
+
from dataclasses import dataclass, asdict
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
|
| 14 |
+
from app.config import settings
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class MicroSLMInfo:
    """Metadata for a registered micro-SLM"""
    # Unique registry key; also used as the dict key on disk.
    name: str
    # Location of the model artifact (consumed by the loader; exact layout
    # not visible here -- confirm against the engine that loads it).
    model_path: str
    # Identifier of the base model this micro-SLM was derived from.
    base_model: str
    # Subtask/capability names (e.g. "fix_syntax") this model handles;
    # matched by routers via `capability in capabilities`.
    capabilities: List[str]
    # Evaluation accuracy (higher is better); routers compare it against a
    # minimum threshold and pick the maximum.
    accuracy: float
    # Average inference latency, in milliseconds.
    avg_latency_ms: float
    # On-disk size, in megabytes; summed for registry statistics.
    size_mb: float
    # Number of samples used to train/fine-tune the model.
    training_samples: int
    # Timestamp string of the last update (format set by the writer).
    last_updated: str
    # Free-form extra attributes.
    metadata: Dict[str, Any]

    def to_dict(self) -> Dict[str, Any]:
        # Flatten the dataclass into a plain, JSON-serializable dict.
        return asdict(self)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class SLMRegistry:
|
| 38 |
+
"""
|
| 39 |
+
Registry for managing micro-SLMs.
|
| 40 |
+
Persists data to a JSON file.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(self):
|
| 44 |
+
self.registry_file = settings.data_dir / "slm_registry.json"
|
| 45 |
+
self.models: Dict[str, MicroSLMInfo] = {}
|
| 46 |
+
self._load_registry()
|
| 47 |
+
|
| 48 |
+
def _load_registry(self):
|
| 49 |
+
"""Load registry from disk"""
|
| 50 |
+
if self.registry_file.exists():
|
| 51 |
+
try:
|
| 52 |
+
with open(self.registry_file, 'r', encoding='utf-8') as f:
|
| 53 |
+
data = json.load(f)
|
| 54 |
+
for name, info in data.items():
|
| 55 |
+
self.models[name] = MicroSLMInfo(**info)
|
| 56 |
+
logger.info(f"Loaded {len(self.models)} micro-SLMs from registry")
|
| 57 |
+
except Exception as e:
|
| 58 |
+
logger.error(f"Failed to load SLM registry: {e}")
|
| 59 |
+
self.models = {}
|
| 60 |
+
|
| 61 |
+
def _save_registry(self):
|
| 62 |
+
"""Save registry to disk"""
|
| 63 |
+
try:
|
| 64 |
+
data = {name: model.to_dict() for name, model in self.models.items()}
|
| 65 |
+
with open(self.registry_file, 'w', encoding='utf-8') as f:
|
| 66 |
+
json.dump(data, f, indent=2, ensure_ascii=False)
|
| 67 |
+
logger.info("Saved SLM registry to disk")
|
| 68 |
+
except Exception as e:
|
| 69 |
+
logger.error(f"Failed to save SLM registry: {e}")
|
| 70 |
+
|
| 71 |
+
def register(self, info: MicroSLMInfo):
|
| 72 |
+
"""Register or update a micro-SLM"""
|
| 73 |
+
self.models[info.name] = info
|
| 74 |
+
self._save_registry()
|
| 75 |
+
logger.info(f"Registered micro-SLM: {info.name}")
|
| 76 |
+
|
| 77 |
+
def get_model(self, name: str) -> Optional[MicroSLMInfo]:
|
| 78 |
+
"""Get model info by name"""
|
| 79 |
+
return self.models.get(name)
|
| 80 |
+
|
| 81 |
+
def get_best_for_capability(self, capability: str, min_accuracy: float = 0.0) -> Optional[MicroSLMInfo]:
|
| 82 |
+
"""
|
| 83 |
+
Find the best model for a specific capability (subtask).
|
| 84 |
+
Returns the model with the highest accuracy that meets the minimum requirement.
|
| 85 |
+
"""
|
| 86 |
+
candidates = [
|
| 87 |
+
m for m in self.models.values()
|
| 88 |
+
if capability in m.capabilities and m.accuracy >= min_accuracy
|
| 89 |
+
]
|
| 90 |
+
|
| 91 |
+
if not candidates:
|
| 92 |
+
return None
|
| 93 |
+
|
| 94 |
+
# Sort by accuracy (descending)
|
| 95 |
+
candidates.sort(key=lambda x: x.accuracy, reverse=True)
|
| 96 |
+
return candidates[0]
|
| 97 |
+
|
| 98 |
+
def get_all_models(self) -> List[MicroSLMInfo]:
|
| 99 |
+
"""Get all registered models"""
|
| 100 |
+
return list(self.models.values())
|
| 101 |
+
|
| 102 |
+
def get_stats(self) -> Dict[str, Any]:
|
| 103 |
+
"""Get registry statistics"""
|
| 104 |
+
capabilities = set()
|
| 105 |
+
total_size = 0.0
|
| 106 |
+
|
| 107 |
+
for m in self.models.values():
|
| 108 |
+
capabilities.update(m.capabilities)
|
| 109 |
+
total_size += m.size_mb
|
| 110 |
+
|
| 111 |
+
return {
|
| 112 |
+
"total_micro_slms": len(self.models),
|
| 113 |
+
"total_size_mb": round(total_size, 2),
|
| 114 |
+
"capabilities_covered": list(capabilities),
|
| 115 |
+
"last_updated": datetime.now().isoformat()
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# Global instance
|
| 120 |
+
slm_registry = SLMRegistry()
|
backend/app/core/task_decomposer.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Task Decomposer
|
| 3 |
+
|
| 4 |
+
Breaks down complex code tasks into atomic subtasks that can be handled
|
| 5 |
+
by specialized micro-SLMs or automata.
|
| 6 |
+
"""
|
| 7 |
+
import logging
|
| 8 |
+
from typing import List, Dict, Any, Optional
|
| 9 |
+
from app.models.schemas import TaskType, Language
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Subtask:
    """Represents an atomic subtask."""

    def __init__(
        self,
        subtask_type: str,
        code: str,
        priority: int = 1,
        context: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        self.subtask_type = subtask_type
        self.code = code
        self.priority = priority  # lower runs earlier (decomposer sorts ascending)
        self.context = context
        self.metadata = metadata or {}

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the subtask to a plain dict (stable key order)."""
        return {
            field: getattr(self, field)
            for field in ("subtask_type", "code", "priority", "context", "metadata")
        }
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class TaskDecomposer:
|
| 42 |
+
"""Decomposes complex tasks into atomic subtasks"""
|
| 43 |
+
|
| 44 |
+
    def __init__(self):
        # Define decomposition rules: map each TaskType to the coroutine
        # that splits it into subtasks. Task types not listed here fall
        # back to a single pass-through subtask in decompose().
        self.decomposition_rules = {
            TaskType.FIX: self._decompose_fix,
            TaskType.REFACTOR: self._decompose_refactor,
            TaskType.TEST: self._decompose_test,
            TaskType.BOILERPLATE: self._decompose_boilerplate,
            TaskType.EXPLAIN: self._decompose_explain,
            TaskType.FORMAT: self._decompose_format,
            TaskType.TRANSLATE: self._decompose_translate,
        }
|
| 55 |
+
|
| 56 |
+
async def decompose(
|
| 57 |
+
self,
|
| 58 |
+
task: TaskType,
|
| 59 |
+
code: str,
|
| 60 |
+
language: Language,
|
| 61 |
+
context: Optional[str] = None,
|
| 62 |
+
trace: Optional[str] = None
|
| 63 |
+
) -> List[Subtask]:
|
| 64 |
+
"""
|
| 65 |
+
Decompose a task into subtasks
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
List of Subtask objects, ordered by priority
|
| 69 |
+
"""
|
| 70 |
+
logger.info(f"Decomposing task: {task}")
|
| 71 |
+
|
| 72 |
+
# Get decomposition function for this task type
|
| 73 |
+
decompose_fn = self.decomposition_rules.get(task)
|
| 74 |
+
|
| 75 |
+
if not decompose_fn:
|
| 76 |
+
# No decomposition needed, return single subtask
|
| 77 |
+
return [Subtask(
|
| 78 |
+
subtask_type=task.value,
|
| 79 |
+
code=code,
|
| 80 |
+
priority=1,
|
| 81 |
+
context=context
|
| 82 |
+
)]
|
| 83 |
+
|
| 84 |
+
# Decompose
|
| 85 |
+
subtasks = await decompose_fn(code, language, context, trace)
|
| 86 |
+
|
| 87 |
+
# Sort by priority
|
| 88 |
+
subtasks.sort(key=lambda x: x.priority)
|
| 89 |
+
|
| 90 |
+
logger.info(f"Decomposed into {len(subtasks)} subtasks")
|
| 91 |
+
return subtasks
|
| 92 |
+
|
| 93 |
+
async def _decompose_fix(
|
| 94 |
+
self,
|
| 95 |
+
code: str,
|
| 96 |
+
language: Language,
|
| 97 |
+
context: Optional[str],
|
| 98 |
+
trace: Optional[str]
|
| 99 |
+
) -> List[Subtask]:
|
| 100 |
+
"""Decompose fix task"""
|
| 101 |
+
subtasks = []
|
| 102 |
+
|
| 103 |
+
# Priority 1: Syntax errors (must fix first)
|
| 104 |
+
if self._has_syntax_errors(code, language):
|
| 105 |
+
subtasks.append(Subtask(
|
| 106 |
+
subtask_type="fix_syntax",
|
| 107 |
+
code=code,
|
| 108 |
+
priority=1,
|
| 109 |
+
context=trace
|
| 110 |
+
))
|
| 111 |
+
|
| 112 |
+
# Priority 2: Import errors
|
| 113 |
+
if self._has_import_errors(code, trace):
|
| 114 |
+
subtasks.append(Subtask(
|
| 115 |
+
subtask_type="fix_imports",
|
| 116 |
+
code=code,
|
| 117 |
+
priority=2,
|
| 118 |
+
context=trace
|
| 119 |
+
))
|
| 120 |
+
|
| 121 |
+
# Priority 3: Runtime errors
|
| 122 |
+
if trace and "Error" in trace:
|
| 123 |
+
subtasks.append(Subtask(
|
| 124 |
+
subtask_type="fix_runtime_error",
|
| 125 |
+
code=code,
|
| 126 |
+
priority=3,
|
| 127 |
+
context=trace
|
| 128 |
+
))
|
| 129 |
+
|
| 130 |
+
# If no specific errors detected, general fix
|
| 131 |
+
if not subtasks:
|
| 132 |
+
subtasks.append(Subtask(
|
| 133 |
+
subtask_type="fix_general",
|
| 134 |
+
code=code,
|
| 135 |
+
priority=1,
|
| 136 |
+
context=context
|
| 137 |
+
))
|
| 138 |
+
|
| 139 |
+
return subtasks
|
| 140 |
+
|
| 141 |
+
async def _decompose_refactor(
    self,
    code: str,
    language: Language,
    context: Optional[str],
    trace: Optional[str]
) -> List[Subtask]:
    """Split a refactor request into subtasks keyed on context keywords.

    Performance, readability and type-hint requests each become their
    own prioritized subtask; with no recognized keyword (or no context
    at all) a single general refactoring subtask is returned.
    """
    # (trigger words, subtask type, priority) — checked in priority order.
    keyword_rules = (
        (("performance", "optimize"), "optimize_performance", 1),
        (("readability", "clean"), "improve_readability", 2),
        (("type", "hint"), "add_type_hints", 3),
    )

    planned: List[Subtask] = []

    if context:
        lowered = context.lower()
        for triggers, kind, rank in keyword_rules:
            if any(word in lowered for word in triggers):
                planned.append(Subtask(
                    subtask_type=kind,
                    code=code,
                    priority=rank,
                    context=context
                ))

    if not planned:
        # Default: one generic refactoring pass.
        planned.append(Subtask(
            subtask_type="refactor_general",
            code=code,
            priority=1,
            context=context
        ))

    return planned
|
| 189 |
+
|
| 190 |
+
async def _decompose_test(
    self,
    code: str,
    language: Language,
    context: Optional[str],
    trace: Optional[str]
) -> List[Subtask]:
    """Test generation is atomic: emit exactly one subtask."""
    single = Subtask(
        subtask_type="generate_tests",
        code=code,
        priority=1,
        context=context
    )
    return [single]
|
| 204 |
+
|
| 205 |
+
async def _decompose_boilerplate(
    self,
    code: str,
    language: Language,
    context: Optional[str],
    trace: Optional[str]
) -> List[Subtask]:
    """Boilerplate generation is atomic: emit exactly one subtask."""
    single = Subtask(
        subtask_type="generate_boilerplate",
        code=code,
        priority=1,
        context=context
    )
    return [single]
|
| 219 |
+
|
| 220 |
+
async def _decompose_explain(
    self,
    code: str,
    language: Language,
    context: Optional[str],
    trace: Optional[str]
) -> List[Subtask]:
    """Explanation is atomic: emit exactly one subtask."""
    single = Subtask(
        subtask_type="explain_code",
        code=code,
        priority=1,
        context=context
    )
    return [single]
|
| 234 |
+
|
| 235 |
+
async def _decompose_format(
    self,
    code: str,
    language: Language,
    context: Optional[str],
    trace: Optional[str]
) -> List[Subtask]:
    """Plan the formatting pipeline for the given language.

    Python gets a two-stage pipeline (imports first, then code);
    every other language gets a single general formatting pass.
    Note: these subtasks deliberately carry no context.
    """
    if language == Language.PYTHON:
        stages = (("format_imports", 1), ("format_code", 2))
    else:
        stages = (("format_general", 1),)

    return [
        Subtask(subtask_type=kind, code=code, priority=rank)
        for kind, rank in stages
    ]
|
| 265 |
+
|
| 266 |
+
async def _decompose_translate(
    self,
    code: str,
    language: Language,
    context: Optional[str],
    trace: Optional[str]
) -> List[Subtask]:
    """Translation is atomic: emit exactly one subtask."""
    single = Subtask(
        subtask_type="translate_code",
        code=code,
        priority=1,
        context=context
    )
    return [single]
|
| 280 |
+
|
| 281 |
+
# Helper methods for error detection
|
| 282 |
+
|
| 283 |
+
def _has_syntax_errors(self, code: str, language: Language) -> bool:
    """Return True when Python source fails to byte-compile.

    Only Python is checked; any other language always reports False.
    """
    if language != Language.PYTHON:
        return False
    try:
        compile(code, '<string>', 'exec')
    except SyntaxError:
        return True
    return False
|
| 292 |
+
|
| 293 |
+
def _has_import_errors(self, code: str, trace: Optional[str]) -> bool:
|
| 294 |
+
"""Check if there are import-related errors"""
|
| 295 |
+
if not trace:
|
| 296 |
+
return False
|
| 297 |
+
|
| 298 |
+
import_error_indicators = [
|
| 299 |
+
"ImportError",
|
| 300 |
+
"ModuleNotFoundError",
|
| 301 |
+
"cannot import",
|
| 302 |
+
"No module named"
|
| 303 |
+
]
|
| 304 |
+
|
| 305 |
+
return any(indicator in trace for indicator in import_error_indicators)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
# Global instance
# Module-level singleton: importers should use this shared decomposer
# rather than constructing a new TaskDecomposer per request.
task_decomposer = TaskDecomposer()
|
backend/app/engines/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SLM Engines package"""
|
| 2 |
+
from app.engines.base import BaseEngine
|
| 3 |
+
from app.engines.starcoder import StarCoderEngine
|
| 4 |
+
from app.engines.codet5 import CodeT5Engine
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"BaseEngine",
|
| 8 |
+
"StarCoderEngine",
|
| 9 |
+
"CodeT5Engine"
|
| 10 |
+
]
|
backend/app/engines/base.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Base class for all SLM engines
|
| 3 |
+
|
| 4 |
+
Engines are neural network-based components that use
|
| 5 |
+
Small Language Models for code understanding and generation.
|
| 6 |
+
"""
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
from typing import Dict, Any, Optional
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
from app.models.schemas import TaskType, Language
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BaseEngine(ABC):
    """Base class for all SLM engines.

    Concrete engines implement initialize/process/shutdown; this base
    provides shared prompt builders, stop-token selection, result
    formatting and code extraction from model responses.
    """

    def __init__(self, name: str, model_path: Optional[str] = None):
        """Record identity and model location; no model is loaded here.

        Args:
            name: Short engine identifier used in logs.
            model_path: Optional filesystem path to local weights
                (None for API-based engines).
        """
        self.name = name
        self.model_path = model_path
        # Populated by subclasses during initialize().
        self.model = None
        self.initialized = False
        logger.info(f"Creating engine: {name}")

    @abstractmethod
    async def initialize(self):
        """
        Initialize the engine and load model

        Should set self.initialized = True when done
        """
        pass

    @abstractmethod
    async def process(
        self,
        task: TaskType,
        code: str,
        language: Language,
        context: Optional[str] = None,
        trace: Optional[str] = None,
        history: Optional[list] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Process a task using the SLM

        Args:
            task: Type of task
            code: Source code
            language: Programming language
            context: Additional context
            trace: Error trace
            history: Conversation history for context
            **kwargs: Additional parameters

        Returns:
            Dict with:
                - success: bool
                - result: str (generated/fixed code or explanation)
                - explanation: Optional[str]
                - suggestions: Optional[List[str]]
        """
        pass

    @abstractmethod
    async def shutdown(self):
        """Cleanup and free resources"""
        pass

    def build_prompt(self, task: TaskType, code: str, context: Optional[str] = None) -> str:
        """
        Build prompt for SLM based on task type

        Args:
            task: Type of task
            code: Source code
            context: Additional context

        Returns:
            Formatted prompt string
        """
        if task == TaskType.FIX:
            return self._build_fix_prompt(code, context)
        elif task == TaskType.EXPLAIN:
            return self._build_explain_prompt(code, context)
        elif task == TaskType.REFACTOR:
            return self._build_refactor_prompt(code, context)
        elif task == TaskType.TEST:
            return self._build_test_prompt(code, context)
        elif task == TaskType.BOILERPLATE:
            return self._build_boilerplate_prompt(context)
        elif task == TaskType.TRANSLATE:
            return self._build_translate_prompt(code, context)
        else:
            # NOTE(review): {task} interpolates the enum repr
            # (e.g. "TaskType.FORMAT") in the fallback prompt — confirm
            # whether task.value was intended here.
            return f"Code:\n{code}\n\nTask: {task}"

    def get_stop_tokens(self, task: TaskType) -> list:
        """
        Get stop tokens to prevent over-generation

        Args:
            task: Type of task

        Returns:
            List of stop tokens
        """
        # Common stop tokens shared by every task.
        common_stops = ["```", "\n\n\n", "# Example", "# Test"]

        # Task-specific stops
        task_stops = {
            TaskType.FIX: ["# Fixed code:", "# Original:"],
            TaskType.BOILERPLATE: ["# Usage:", "# Example usage:"],
            TaskType.TEST: ["# Run tests:"],
            TaskType.EXPLAIN: ["# Code:", "# Summary:"],
        }

        return common_stops + task_stops.get(task, [])

    def _build_fix_prompt(self, code: str, context: Optional[str]) -> str:
        """Build prompt for code fixing"""
        prompt = """Fix the following Python code. Return ONLY the corrected code without explanation.

Rules:
- Fix syntax errors
- Fix logic errors
- Maintain original functionality
- Keep the same structure
- Do not add comments unless necessary

"""
        if context:
            prompt += f"Additional context: {context}\n\n"

        prompt += f"Code to fix:\n```python\n{code}\n```\n\n"
        # The trailing opening fence nudges the model to emit bare code.
        prompt += "Fixed code:\n```python\n"
        return prompt

    def _build_explain_prompt(self, code: str, context: Optional[str]) -> str:
        """Build prompt for explaining code"""
        prompt = """Explain the following Python code in detail.

Rules:
- Provide a high-level summary.
- Break down the code into logical sections and explain each.
- Highlight key concepts and potential improvements.
- Use clear and concise language.

"""
        if context:
            prompt += f"Focus on: {context}\n\n"

        prompt += f"Code to explain:\n```python\n{code}\n```\n\n"
        prompt += "Explanation:\n"
        return prompt

    def _build_refactor_prompt(self, code: str, context: Optional[str]) -> str:
        """Build prompt for refactoring code"""
        prompt = """Refactor the following Python code to improve readability, maintainability, and performance. Return ONLY the refactored code without explanation.

Rules:
- Maintain original functionality.
- Apply Python best practices and idioms.
- Improve variable names, function structure, and overall design.
- Do not add comments unless necessary.

"""
        if context:
            prompt += f"Refactoring requirements: {context}\n\n"

        prompt += f"Original code:\n```python\n{code}\n```\n\n"
        prompt += "Refactored code:\n```python\n"
        return prompt

    def _build_test_prompt(self, code: str, context: Optional[str]) -> str:
        """Build prompt for generating unit tests"""
        prompt = """Generate comprehensive unit tests for the following Python code using `pytest`. Return ONLY the test code without explanation.

Rules:
- Cover normal cases, edge cases, and error cases.
- Use descriptive test function names.
- Include assertions for expected behavior.
- Use `pytest.fixture` for setup if needed.

"""
        if context:
            prompt += f"Test requirements: {context}\n\n"

        prompt += f"Code to test:\n```python\n{code}\n```\n\n"
        prompt += "Test code:\n```python\n"
        return prompt

    def _build_translate_prompt(self, code: str, context: Optional[str]) -> str:
        """Build prompt for translating code"""
        prompt = """Translate the following code. Return ONLY the translated code without explanation.

"""
        if context:
            prompt += f"{context}\n\n"

        prompt += f"Code to translate:\n```\n{code}\n```\n\n"
        prompt += "Translated code:\n```\n"
        return prompt

    def _build_boilerplate_prompt(self, context: Optional[str]) -> str:
        """Build prompt for boilerplate generation"""
        prompt = """Generate clean, well-structured Python code based on the description below.

Requirements:
- Follow PEP 8 style guide
- Include docstrings
- Handle edge cases
- Use type hints where appropriate
- Keep it simple and readable

"""
        if context:
            prompt += f"Description: {context}\n\n"
        else:
            prompt += "Description: Create a basic implementation\n\n"

        prompt += "Code:\n```python\n"
        return prompt

    def _default_prompt(
        self,
        code: str,
        language: Language,
        context: Optional[str],
        trace: Optional[str]
    ) -> str:
        """Default prompt.

        Fix: use language.value (the plain name, e.g. "python") instead of
        interpolating the enum itself, which rendered "Language.PYTHON"
        in the prompt. Sibling engines consistently use language.value.
        """
        return f"Process the following {language.value} code:\n\n```{language.value}\n{code}\n```"

    def _format_result(
        self,
        success: bool,
        result: Optional[str] = None,
        explanation: Optional[str] = None,
        suggestions: Optional[list] = None
    ) -> Dict[str, Any]:
        """Helper to format results consistently"""
        return {
            "success": success,
            "result": result,
            "explanation": explanation,
            "suggestions": suggestions or []
        }

    def _extract_code_from_response(self, response: str) -> str:
        """Extract code from model response (handles markdown code blocks)"""
        import re

        # Look for ```language\ncode\n``` pattern
        pattern = r'```(?:\w+)?\n(.*?)\n?```'
        matches = re.findall(pattern, response, re.DOTALL)

        if matches:
            # Return first code block
            return matches[0].strip()

        # If no code blocks, check if response looks like code already
        # (happens when prompt ends with ``` and model just generates code)
        lines = response.strip().split('\n')

        # Skip empty lines at start/end
        while lines and not lines[0].strip():
            lines.pop(0)
        while lines and not lines[-1].strip():
            lines.pop()

        # If we have content, return it
        if lines:
            return '\n'.join(lines)

        # Empty response
        return ""
|
backend/app/engines/codet5.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CodeT5 engine implementation
|
| 3 |
+
|
| 4 |
+
Uses CodeT5-small for code explanation and translation.
|
| 5 |
+
Loaded via HuggingFace transformers.
|
| 6 |
+
"""
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Dict, Any, Optional
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
from app.engines.base import BaseEngine
|
| 12 |
+
from app.models.schemas import TaskType, Language
|
| 13 |
+
from app.config import settings
|
| 14 |
+
from app.utils.localization import get_string
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class CodeT5Engine(BaseEngine):
    """CodeT5-small engine for explanations and translation.

    Loads Salesforce/codet5-small via HuggingFace transformers, falling
    back to the hub checkpoint when no local copy exists.
    """

    def __init__(self):
        super().__init__(
            name="codet5",
            model_path=str(settings.codet5_path)
        )
        # Set lazily in initialize(); torch is kept as an attribute so the
        # heavy import happens only once and only when the engine is used.
        self.tokenizer = None
        self.model_instance = None
        self.torch = None

    async def initialize(self):
        """Load CodeT5 model (local path if present, hub otherwise)."""
        if self.initialized:
            logger.info("CodeT5 already initialized")
            return

        logger.info(f"Loading CodeT5 from {self.model_path}")

        try:
            from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
            import torch
            self.torch = torch

            model_name = "Salesforce/codet5-small"
            if Path(self.model_path).exists():
                model_name = str(self.model_path)
            else:
                logger.warning(f"Local model not found at {self.model_path}, using default: {model_name}")

            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model_instance = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            self.model_instance.eval()

            if self.torch.cuda.is_available():
                self.model_instance = self.model_instance.cuda()
                logger.info("CodeT5 loaded on GPU")
            else:
                logger.info("CodeT5 loaded on CPU")

            self.initialized = True
            logger.info("CodeT5 loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load CodeT5: {e}")
            raise

    async def process(
        self,
        task: TaskType,
        code: str,
        language: Language,
        context: Optional[str] = None,
        trace: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Process task with CodeT5.

        Returns the standard engine result dict (see BaseEngine.process).
        EXPLAIN results go into "explanation"; everything else into
        "result".
        """
        if not self.initialized:
            await self.initialize()

        try:
            prompt = self._build_codet5_prompt(task, code, language, context, trace)
            logger.info(f"CodeT5 processing {task} for {language}")
            logger.debug(f"Prompt: {prompt[:200]}...")

            # 512 tokens matches CodeT5-small's encoder input limit.
            inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
            if self.torch.cuda.is_available():
                inputs = {k: v.cuda() for k, v in inputs.items()}

            with self.torch.no_grad():
                # NOTE(review): temperature is passed without do_sample=True,
                # so beam search ignores it — confirm sampling intent.
                outputs = self.model_instance.generate(
                    **inputs,
                    max_length=settings.max_tokens,
                    temperature=settings.temperature,
                    num_beams=2,
                    early_stopping=True
                )
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            if task == TaskType.EXPLAIN:
                return self._format_result(
                    success=True,
                    explanation=generated_text.strip(),
                    suggestions=self._get_explanation_suggestions()
                )
            elif task == TaskType.TRANSLATE:
                return self._format_result(
                    success=True,
                    result=generated_text.strip(),
                    explanation=get_string("codet5_translate_explanation"),
                    suggestions=[get_string("codet5_translate_suggestion")]
                )
            else:
                return self._format_result(success=True, result=generated_text.strip())

        except Exception as e:
            logger.error(f"CodeT5 processing failed: {e}", exc_info=True)
            return self._format_result(
                success=False,
                explanation=get_string("codet5_error", error=str(e))
            )

    def _build_codet5_prompt(
        self,
        task: TaskType,
        code: str,
        language: Language,
        context: Optional[str],
        trace: Optional[str]
    ) -> str:
        """Build an improved, task-specific prompt for CodeT5."""
        # Fix: `re` was used below but never imported anywhere in this
        # module, so TRANSLATE with a context string raised NameError.
        import re

        base_instruction = f"As an expert programmer, please perform the following task in {settings.language}."

        if task == TaskType.EXPLAIN:
            if trace:
                instruction = (
                    f"{base_instruction} Explain the root cause of the following error trace "
                    f"in the context of the provided {language.value} code."
                )
                return f"{instruction}\n\nError Trace:\n{trace}\n\nCode:\n{code}"
            else:
                instruction = (
                    f"{base_instruction} Provide a concise summary of the following "
                    f"{language.value} code. Describe its purpose and functionality."
                )
                return f"{instruction}\n\nCode:\n{code}"

        elif task == TaskType.TRANSLATE:
            # Try to pull the target language out of context, e.g. "to Java".
            target_language = "the target language"
            if context:
                match = re.search(r"to (\w+)", context, re.IGNORECASE)
                if match:
                    target_language = match.group(1)

            instruction = f"Translate the following {language.value} code to {target_language}."
            return f"{instruction}\n\n{code}"

        else:
            return f"Process the following {language.value} code:\n{code}"

    def _get_explanation_suggestions(self) -> list:
        """Get suggestions for explanation tasks using localized strings."""
        return [
            get_string("codet5_explanation_suggestion_1"),
            get_string("codet5_explanation_suggestion_2")
        ]

    async def shutdown(self):
        """Cleanup CodeT5: drop model, tokenizer and torch references."""
        logger.info("Shutting down CodeT5 engine")
        if self.model_instance:
            del self.model_instance
            self.model_instance = None
        if self.tokenizer:
            del self.tokenizer
            self.tokenizer = None
        if self.torch:
            del self.torch
            self.torch = None
        self.initialized = False
|
backend/app/engines/groq_engine.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Groq API engine implementation
|
| 3 |
+
|
| 4 |
+
Uses Groq's ultra-fast inference API with models like:
|
| 5 |
+
- llama-3.1-70b-versatile (best quality)
|
| 6 |
+
- llama-3.1-8b-instant (fastest)
|
| 7 |
+
- mixtral-8x7b-32768 (good balance)
|
| 8 |
+
"""
|
| 9 |
+
import logging
|
| 10 |
+
import os
|
| 11 |
+
from typing import Dict, Any, Optional
|
| 12 |
+
|
| 13 |
+
from app.engines.base import BaseEngine
|
| 14 |
+
from app.models.schemas import TaskType, Language
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class GroqEngine(BaseEngine):
|
| 20 |
+
"""Groq API engine for code tasks"""
|
| 21 |
+
|
| 22 |
+
def __init__(self):
    """Set up an API-backed engine: no local weights, lazy client."""
    # model_path stays None — Groq is fully API-based.
    super().__init__(name="groq", model_path=None)
    self.client = None
    self.api_key = None
    # Model is overridable via the GROQ_MODEL environment variable;
    # the default tracks the current stable model.
    self.model_name = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
|
| 31 |
+
|
| 32 |
+
async def initialize(self):
    """Initialize the Groq API client.

    Reads GROQ_API_KEY from application settings and constructs the SDK
    client. Raises ValueError when the key is missing and ImportError
    when the groq package is not installed; any failure is logged and
    re-raised.
    """
    if self.initialized:
        logger.info("Groq already initialized")
        return

    logger.info("Initializing Groq API client")

    try:
        # Get API key from settings
        from app.config import settings
        self.api_key = settings.groq_api_key

        if not self.api_key:
            # Fix: these messages used "\\n" (a literal backslash-n in
            # the rendered error text); real newlines are intended.
            raise ValueError(
                "GROQ_API_KEY not found in configuration.\n"
                "Please add it to your .env file:\n"
                "GROQ_API_KEY=gsk_your_key_here"
            )

        # Import Groq client
        try:
            from groq import Groq
        except ImportError:
            raise ImportError(
                "Groq package not installed.\n"
                "Install with: pip install groq"
            )

        self.client = Groq(api_key=self.api_key)

        self.initialized = True
        logger.info(f"Groq initialized with model: {self.model_name}")

    except Exception as e:
        logger.error(f"Failed to initialize Groq: {e}")
        raise
|
| 69 |
+
|
| 70 |
+
async def process(
    self,
    task: TaskType,
    code: str,
    language: Language,
    context: Optional[str] = None,
    trace: Optional[str] = None,
    history: Optional[list] = None,
    **kwargs
) -> Dict[str, Any]:
    """Process task with Groq.

    Builds a chat message list (system instruction + optional history +
    task prompt), calls the Groq chat completions API synchronously, and
    routes the response: code-producing tasks return extracted code in
    "result", EXPLAIN returns prose in "explanation", everything else
    returns the raw text. API failures are caught and reported as a
    success=False result rather than raised.
    """
    if not self.initialized:
        await self.initialize()

    try:
        # Build prompt
        prompt = self._build_groq_prompt(task, code, language, context, trace)

        logger.info(f"Groq processing {task} for {language}")
        logger.debug(f"Using model: {self.model_name}")

        # Get language from settings (UI language, not programming language).
        from app.config import settings
        lang_instruction = ""
        if settings.language == "fr":
            lang_instruction = " Répondez toujours en français. Expliquez le code en français."

        # Add instruction for file creation
        file_instruction = " If the user asks to create a file, specify the filename in your explanation using the format: [FILE: filename.ext]"

        # Build message chain
        messages = [
            {
                "role": "system",
                "content": f"You are an expert programmer. Provide clear, concise, and correct code solutions.{lang_instruction}{file_instruction}"
            }
        ]

        # Add conversation history if provided
        # (each entry is assumed to be a dict with "role"/"content" keys —
        # missing keys default to "user"/"").
        if history:
            for msg in history:
                messages.append({
                    "role": msg.get("role", "user"),
                    "content": msg.get("content", "")
                })

        # Add current prompt as the last user message
        messages.append({
            "role": "user",
            "content": prompt
        })

        # Call Groq API
        # NOTE(review): this is a blocking SDK call inside an async
        # method — confirm it is acceptable to block the event loop here.
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            temperature=0.3,  # Low for code accuracy
            max_tokens=2048,
            top_p=0.95
        )

        generated_text = response.choices[0].message.content.strip()

        logger.debug(f"Generated {len(generated_text)} chars")

        # Extract code from response
        if task in [TaskType.FIX, TaskType.REFACTOR, TaskType.TRANSLATE, TaskType.BOILERPLATE, TaskType.TEST]:
            result_code = self._extract_code_from_response(generated_text)

            # If no code block found, use whole response
            if not result_code:
                result_code = generated_text

            return self._format_result(
                success=True,
                result=result_code,
                explanation=f"Generated using Groq ({self.model_name})",
                suggestions=[
                    "Review the generated code",
                    "Test thoroughly before production use"
                ]
            )

        elif task == TaskType.EXPLAIN:
            return self._format_result(
                success=True,
                explanation=generated_text,
                suggestions=[
                    "Review the explanation for accuracy"
                ]
            )

        else:
            return self._format_result(
                success=True,
                result=generated_text
            )

    except Exception as e:
        # Errors are surfaced to the caller as a failed result, not raised.
        logger.error(f"Groq processing failed: {e}", exc_info=True)
        return self._format_result(
            success=False,
            explanation=f"Groq API error: {str(e)}"
        )
|
| 174 |
+
|
| 175 |
+
def _build_groq_prompt(
|
| 176 |
+
self,
|
| 177 |
+
task: TaskType,
|
| 178 |
+
code: str,
|
| 179 |
+
language: Language,
|
| 180 |
+
context: Optional[str],
|
| 181 |
+
trace: Optional[str]
|
| 182 |
+
) -> str:
|
| 183 |
+
"""Build prompt for Groq"""
|
| 184 |
+
|
| 185 |
+
if task == TaskType.BOILERPLATE:
|
| 186 |
+
prompt = f"Write {language.value} code that {context}.\\n\\n"
|
| 187 |
+
prompt += "Provide ONLY the code, no explanations.\\n\\n"
|
| 188 |
+
prompt += f"```{language.value}\\n"
|
| 189 |
+
return prompt
|
| 190 |
+
|
| 191 |
+
elif task == TaskType.FIX:
|
| 192 |
+
prompt = f"Fix this {language.value} code:\\n\\n```{language.value}\\n{code}\\n```\\n\\n"
|
| 193 |
+
if trace:
|
| 194 |
+
prompt += f"Error:\\n```\\n{trace}\\n```\\n\\n"
|
| 195 |
+
prompt += "Provide the corrected code only.\\n\\n"
|
| 196 |
+
prompt += f"```{language.value}\\n"
|
| 197 |
+
return prompt
|
| 198 |
+
|
| 199 |
+
elif task == TaskType.EXPLAIN:
|
| 200 |
+
prompt = f"Explain this {language.value} code:\\n\\n```{language.value}\\n{code}\\n```\\n\\n"
|
| 201 |
+
if context:
|
| 202 |
+
prompt += f"Focus on: {context}\\n\\n"
|
| 203 |
+
prompt += "Provide a clear explanation."
|
| 204 |
+
return prompt
|
| 205 |
+
|
| 206 |
+
elif task == TaskType.REFACTOR:
|
| 207 |
+
prompt = f"Refactor this {language.value} code to improve it:\\n\\n```{language.value}\\n{code}\\n```\\n\\n"
|
| 208 |
+
if context:
|
| 209 |
+
prompt += f"Requirements: {context}\\n\\n"
|
| 210 |
+
prompt += "Provide the refactored code only.\\n\\n"
|
| 211 |
+
prompt += f"```{language.value}\\n"
|
| 212 |
+
return prompt
|
| 213 |
+
|
| 214 |
+
elif task == TaskType.TEST:
|
| 215 |
+
prompt = f"Write comprehensive tests for this {language.value} code:\\n\\n```{language.value}\\n{code}\\n```\\n\\n"
|
| 216 |
+
prompt += f"Use pytest for Python or appropriate framework for {language.value}.\\n\\n"
|
| 217 |
+
prompt += f"```{language.value}\\n"
|
| 218 |
+
return prompt
|
| 219 |
+
|
| 220 |
+
else:
|
| 221 |
+
prompt = f"Process this {language.value} code:\\n\\n```{language.value}\\n{code}\\n```"
|
| 222 |
+
return prompt
|
| 223 |
+
|
| 224 |
+
async def shutdown(self):
|
| 225 |
+
"""Cleanup Groq"""
|
| 226 |
+
logger.info("Shutting down Groq engine")
|
| 227 |
+
self.client = None
|
| 228 |
+
self.initialized = False
|
backend/app/engines/micro_slm.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Micro-SLM Engine
|
| 3 |
+
|
| 4 |
+
Generic engine for running specialized micro-SLMs (usually based on Phi-2 or similar).
|
| 5 |
+
Loads models dynamically from the registry.
|
| 6 |
+
"""
|
| 7 |
+
import logging
|
| 8 |
+
import torch
|
| 9 |
+
from typing import Dict, Any, Optional, List
|
| 10 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 11 |
+
|
| 12 |
+
from app.engines.base import BaseEngine
|
| 13 |
+
from app.models.schemas import TaskType, Language
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MicroSLMEngine(BaseEngine):
    """
    Generic engine for Micro-SLMs.

    Can load any HuggingFace model compatible with AutoModelForCausalLM,
    either from the Hub (``hf://org/name``) or from a local path.
    """

    def __init__(self, name: str, model_path: str):
        """
        Args:
            name: Registry name for this engine instance.
            model_path: Local filesystem path, or a Hugging Face Hub model id
                prefixed with ``hf://`` (e.g. ``hf://vienoux/boilerplate-slm``).
        """
        super().__init__(name=name, model_path=model_path)
        self.tokenizer = None
        # Prefer GPU when available; dtype and device_map below depend on this.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Support Hugging Face Hub models with 'hf://' prefix
        # Example: hf://vienoux/boilerplate-slm
        if model_path.startswith("hf://"):
            self.hf_model_id = model_path[5:]  # Remove 'hf://' prefix
            self.is_hf_model = True
        else:
            self.hf_model_id = model_path  # Use as-is for local paths
            self.is_hf_model = False

    async def initialize(self):
        """Load the model and tokenizer (idempotent)."""
        if self.initialized:
            return

        if self.is_hf_model:
            logger.info(f"Loading Micro-SLM {self.name} from Hugging Face Hub: {self.hf_model_id} on {self.device}")
        else:
            logger.info(f"Loading Micro-SLM {self.name} from local path: {self.model_path} on {self.device}")

        try:
            # from_pretrained accepts both Hub ids and local directories.
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.hf_model_id,
                trust_remote_code=True
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.hf_model_id,
                # fp16 only on CUDA; fp32 on CPU for numerical stability.
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto" if self.device == "cuda" else None,
                trust_remote_code=True
            )

            if self.device == "cpu":
                self.model.to("cpu")

            self.initialized = True
            logger.info(f"Micro-SLM {self.name} initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize Micro-SLM {self.name}: {e}")
            raise

    async def process(
        self,
        task: TaskType,
        code: str,
        language: Language,
        context: Optional[str] = None,
        trace: Optional[str] = None,
        history: Optional[list] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Process a task using the micro-SLM.

        NOTE(review): ``language``, ``trace`` and ``history`` are currently
        unused here — the prompt is built from task/code/context only.
        """

        if not self.initialized:
            await self.initialize()

        prompt = self.build_prompt(task, code, context)

        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

            # Generate with low temperature and a capped continuation length.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    temperature=0.2,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # The model echoes the prompt; keep only the new continuation.
            generated_text = response[len(prompt):].strip()

            # Prefer a fenced code block when the model produced one.
            result_code = self._extract_code_from_response(generated_text)
            if not result_code:
                result_code = generated_text  # Fallback to full text if no block found

            return self._format_result(
                success=True,
                result=result_code,
                explanation=f"Generated by {self.name}"
            )

        except Exception as e:
            logger.error(f"Error in Micro-SLM {self.name}: {e}")
            return self._format_result(
                success=False,
                explanation=f"Error: {str(e)}"
            )

    async def shutdown(self):
        """Unload model/tokenizer to free memory.

        FIX: the previous version used ``del self.model`` / ``del
        self.tokenizer``, which removed the attributes entirely, so a second
        shutdown (or any later attribute access) raised AttributeError.
        Rebinding to None drops the references just as effectively and keeps
        the object safe to inspect or shut down again.
        """
        self.model = None
        self.tokenizer = None

        if self.device == "cuda":
            torch.cuda.empty_cache()

        self.initialized = False
        logger.info(f"Micro-SLM {self.name} unloaded")
|
backend/app/engines/phi2.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Phi-2 engine implementation
|
| 3 |
+
|
| 4 |
+
Uses Microsoft Phi-2 (quantized GGUF) for code generation, fixing, and refactoring.
|
| 5 |
+
Loaded via llama-cpp-python.
|
| 6 |
+
"""
|
| 7 |
+
import logging
|
| 8 |
+
import re
|
| 9 |
+
from typing import Dict, Any, Optional
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
from app.engines.base import BaseEngine
|
| 13 |
+
from app.models.schemas import TaskType, Language
|
| 14 |
+
from app.config import settings
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Phi2Engine(BaseEngine):
    """Phi-2 engine for code tasks.

    Runs a quantized (GGUF) Microsoft Phi-2 model locally via
    llama-cpp-python, using plain text-completion prompts (no chat template).
    """

    def __init__(self):
        # Use Phi-2 model path from settings
        model_path = settings.models_dir / "phi-2-Q4_K_M.gguf"
        super().__init__(
            name="phi2",
            model_path=str(model_path)
        )
        # llama_cpp.Llama handle; created lazily in initialize().
        self.llm = None

    async def initialize(self):
        """Load Phi-2 model (idempotent — returns early if already loaded)."""
        if self.initialized:
            logger.info("Phi-2 already initialized")
            return

        logger.info(f"Loading Phi-2 from {self.model_path}")

        try:
            # Fail fast with an actionable message if the GGUF file is absent.
            if not Path(self.model_path).exists():
                raise FileNotFoundError(
                    f"Model file not found: {self.model_path}\n"
                    f"Please run: python scripts/download_phi2.py"
                )

            # Imported lazily so the app can start without llama-cpp installed.
            from llama_cpp import Llama

            self.llm = Llama(
                model_path=self.model_path,
                n_ctx=2048,  # Context window
                n_threads=4,  # CPU threads
                n_batch=512,
                verbose=False
            )

            self.initialized = True
            logger.info("Phi-2 loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load Phi-2: {e}")
            raise

    async def process(
        self,
        task: TaskType,
        code: str,
        language: Language,
        context: Optional[str] = None,
        trace: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Process task with Phi-2.

        NOTE(review): ``language`` and ``trace`` are only logged here — the
        prompt builder below uses task/code/context and assumes Python.
        Returns a result dict via ``self._format_result``.
        """
        if not self.initialized:
            await self.initialize()

        try:
            # Build Phi-2-specific prompt (simple completion format)
            prompt = self._build_phi2_prompt(task, code, context)

            logger.info(f"Phi-2 processing {task} for {language}")
            logger.debug(f"Prompt: {prompt[:300]}...")

            # Task-specific max tokens (increased for better output)
            task_max_tokens = {
                TaskType.FIX: 512,
                TaskType.EXPLAIN: 512,
                TaskType.REFACTOR: 1024,
                TaskType.TEST: 1024,
                TaskType.TRANSLATE: 1024,
                TaskType.BOILERPLATE: 512  # Enough for simple functions
            }
            max_tokens = task_max_tokens.get(task, 512)

            # Get stop tokens
            stop_tokens = ["\n\n\n", "###", "Example:", "Note:"]

            # Generate with Phi-2
            response = self.llm(
                prompt,
                max_tokens=max_tokens,
                temperature=0.7,  # Higher for more creative code
                top_p=0.95,
                top_k=50,
                repeat_penalty=1.15,
                stop=stop_tokens,
                echo=False
            )

            generated_text = response["choices"][0]["text"].strip()

            logger.debug(f"Generated: {generated_text[:200]}...")

            # Extract code from response
            if task in [TaskType.FIX, TaskType.REFACTOR, TaskType.TRANSLATE, TaskType.BOILERPLATE, TaskType.TEST]:
                result_code = self._extract_code_from_response(generated_text)

                # If extraction fails, use the whole response
                if not result_code:
                    result_code = generated_text

                return self._format_result(
                    success=True,
                    result=result_code,
                    explanation=f"Generated using Phi-2",
                    suggestions=[
                        "Review the generated code for accuracy",
                        "Test the code before using in production"
                    ]
                )

            elif task == TaskType.EXPLAIN:
                # Explanations are returned as prose, not code.
                return self._format_result(
                    success=True,
                    explanation=generated_text,
                    suggestions=[
                        "Review the explanation for accuracy",
                        "Consider adding inline comments to your code for clarity"
                    ]
                )

            else:
                return self._format_result(
                    success=True,
                    result=generated_text
                )

        except Exception as e:
            logger.error(f"Phi-2 processing failed: {e}", exc_info=True)
            return self._format_result(
                success=False,
                explanation=f"Error: {str(e)}"
            )

    def _build_phi2_prompt(self, task: TaskType, code: str, context: Optional[str]) -> str:
        """Build Phi-2-specific prompt (simple completion format).

        Prompts end with a lead-in ("def ", "Fixed code:", ...) so the model
        continues directly with the desired content. All templates are
        Python-oriented regardless of the requested language.
        """

        if task == TaskType.BOILERPLATE:
            # For code generation
            prompt = f"Write a Python function that {context}.\n\n"
            prompt += "def "
            return prompt

        elif task == TaskType.FIX:
            prompt = "Fix this Python code:\n\n"
            prompt += f"{code}\n\n"
            prompt += "Fixed code:\n\n"
            return prompt

        elif task == TaskType.EXPLAIN:
            prompt = f"Explain this Python code:\n\n{code}\n\nExplanation: "
            return prompt

        elif task == TaskType.REFACTOR:
            prompt = f"Refactor this Python code to make it better:\n\n{code}\n\nRefactored code:\n\n"
            return prompt

        elif task == TaskType.TEST:
            prompt = f"Write pytest tests for this Python code:\n\n{code}\n\nimport pytest\n\n"
            return prompt

        else:
            # Fallback for task types without a dedicated template.
            prompt = f"Process this code:\n\n{code}\n\nResult:\n\n"
            return prompt

    async def shutdown(self):
        """Cleanup Phi-2: drop the llama.cpp handle and reset state."""
        logger.info("Shutting down Phi-2 engine")
        if self.llm:
            del self.llm
        self.llm = None
        self.initialized = False
|
backend/app/engines/starcoder.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
StarCoder engine implementation
|
| 3 |
+
|
| 4 |
+
Uses StarCoder2-3B (quantized) for code generation, fixing, and refactoring.
|
| 5 |
+
Loaded via llama-cpp-python (GGUF format).
|
| 6 |
+
"""
|
| 7 |
+
import logging
|
| 8 |
+
import re
|
| 9 |
+
from typing import Dict, Any, Optional
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
from app.engines.base import BaseEngine
|
| 13 |
+
from app.models.schemas import TaskType, Language
|
| 14 |
+
from app.config import settings
|
| 15 |
+
from app.utils.localization import get_string
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class StarCoderEngine(BaseEngine):
    """StarCoder2-3B engine for code tasks.

    Loads a quantized GGUF build of StarCoder2-3B via llama-cpp-python and
    serves fix/explain/refactor/test/translate/boilerplate requests using a
    chat-style ``<|system|>/<|user|>/<|assistant|>`` prompt.
    """

    def __init__(self):
        super().__init__(
            name="starcoder",
            model_path=str(settings.starcoder_path)
        )
        # llama_cpp.Llama handle; created lazily in initialize().
        self.llm = None

    async def initialize(self):
        """Load StarCoder model (idempotent)."""
        if self.initialized:
            logger.info("StarCoder already initialized")
            return

        logger.info(f"Loading StarCoder from {self.model_path}")

        try:
            # Fail fast with an actionable message if the GGUF file is absent.
            if not Path(self.model_path).exists():
                raise FileNotFoundError(
                    f"Model file not found: {self.model_path}\n"
                    f"Please run: python scripts/download_models.py"
                )
            # Imported lazily so the app can start without llama-cpp installed.
            from llama_cpp import Llama
            self.llm = Llama(
                model_path=self.model_path,
                n_ctx=settings.n_ctx,
                n_threads=settings.n_threads,
                n_batch=512,
                verbose=False
            )
            self.initialized = True
            logger.info("StarCoder loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load StarCoder: {e}")
            raise

    async def process(
        self,
        task: TaskType,
        code: str,
        language: Language,
        context: Optional[str] = None,
        trace: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Process task with StarCoder; returns a dict via _format_result."""
        if not self.initialized:
            await self.initialize()

        try:
            prompt = self._build_prompt(task, code, language, context, trace)
            logger.info(f"StarCoder processing {task} for {language}")
            logger.debug(f"Prompt: {prompt[:300]}...")

            # Per-task generation budget; boilerplate gets the largest.
            task_max_tokens = {
                TaskType.FIX: 512,
                TaskType.EXPLAIN: 512,
                TaskType.REFACTOR: 1024,
                TaskType.TEST: 1024,
                TaskType.TRANSLATE: 1024,
                TaskType.BOILERPLATE: 2048
            }
            max_tokens = task_max_tokens.get(task, 512)

            # Low temperature / high repeat penalty: favor faithful code.
            response = self.llm(
                prompt,
                max_tokens=max_tokens,
                temperature=0.1,
                top_p=0.9,
                top_k=20,
                repeat_penalty=1.2,
                stop=["\n```\n", "```\n\n", "\nExample", "\nNow fix", "\nBuggy code", "Exercise"],
                echo=False
            )

            generated_text = response["choices"][0]["text"]

            if task in [TaskType.FIX, TaskType.REFACTOR, TaskType.TRANSLATE, TaskType.BOILERPLATE]:
                result_code = self._extract_code_from_response(generated_text, language)
                explanation = self._extract_explanation(generated_text)
                return self._format_result(
                    success=True,
                    result=result_code or code,  # Return original code if extraction fails
                    explanation=explanation,
                    suggestions=self._generate_suggestions(task)
                )
            elif task == TaskType.EXPLAIN:
                # Explanations are prose; no code result.
                return self._format_result(
                    success=True,
                    result=None,
                    explanation=generated_text.strip(),
                    suggestions=[]
                )
            elif task == TaskType.TEST:
                test_code = self._extract_code_from_response(generated_text, language)
                return self._format_result(
                    success=True,
                    result=test_code,
                    explanation=get_string("starcoder_test_explanation"),
                    suggestions=self._generate_suggestions(task)
                )
            else:
                return self._format_result(
                    success=True,
                    result=generated_text.strip()
                )
        except Exception as e:
            logger.error(f"StarCoder processing failed: {e}", exc_info=True)
            return self._format_result(
                success=False,
                explanation=get_string("starcoder_error", error=str(e))
            )

    def _build_prompt(self, task: TaskType, code: str, language: Language, context: Optional[str], trace: Optional[str]) -> str:
        """Builds a task-specific, improved prompt."""

        system_prompt_content = (
            "You are an expert programmer and a helpful coding assistant. "
            "Provide a clear and concise response. "
            f"The user's preferred language for explanations is {settings.language}."
        )

        # (Removed a no-op .replace('\n', '\n') that was applied here.)
        system_block = f"<|system|>\n{system_prompt_content}\n<|end|>"

        task_instructions = {
            TaskType.FIX: (
                "The following code has an error. Analyze the code and the error trace, then provide a corrected version. "
                "Explain the fix in a comment or before the code block."
            ),
            TaskType.EXPLAIN: "Explain the following code. Describe its purpose, how it works, and any key algorithms or patterns used.",
            TaskType.REFACTOR: "Refactor the following code to improve its readability, performance, or maintainability. Explain the changes made.",
            TaskType.TEST: (f"Generate a comprehensive suite of unit tests for the following {language.value} code "
                            "using a standard testing framework (e.g., pytest for Python, Jest for JavaScript)."),
            TaskType.TRANSLATE: f"Translate the following code snippet from its current language to {language.value}. Preserve logic and comments.",
            TaskType.BOILERPLATE: f"Generate boilerplate code for a {context} in {language.value}."
        }

        instruction = task_instructions.get(task, "Process the following code:")

        prompt_parts = []
        prompt_parts.append(system_block)
        prompt_parts.append(f"<|user|>\n{instruction}")

        if context:
            prompt_parts.append(f"\nHere is some additional context and examples:\n```\n{context}\n```")

        if trace:
            prompt_parts.append(f"\nHere is the error trace:\n```\n{trace}\n```")

        # FIX: was {{language.value}}, which put the literal text
        # "{language.value}" into the code fence instead of the language name.
        prompt_parts.append(f"\nHere is the code:\n```{language.value}\n{code}\n```")
        prompt_parts.append("\n<|assistant|>\n")

        return "\n".join(prompt_parts)

    def _extract_code_from_response(self, text: str, language: Language) -> Optional[str]:
        """Extracts the first code block from the model's response.

        Falls back to the whole response when it contains no fences at all;
        returns None for an empty or partially fenced response.
        """
        pattern = re.compile(r"```(?:" + re.escape(language.value) + r")?\s*\n(.*?)\n```", re.DOTALL)
        match = pattern.search(text)
        if match:
            return match.group(1).strip()

        if text.strip() and "```" not in text:
            return text.strip()

        return None

    def _extract_explanation(self, text: str) -> Optional[str]:
        """Extract explanation from response (text before the first code block)."""
        # Keyword maxsplit: positional maxsplit for re.split is deprecated
        # (removed in Python 3.13's deprecation path).
        parts = re.split(r"```.*", text, maxsplit=1)
        if parts and parts[0].strip():
            return parts[0].strip()
        return None

    def _generate_suggestions(self, task: TaskType) -> list:
        """Generate task-specific suggestions using localized strings."""
        suggestion_keys = {
            TaskType.FIX: ["starcoder_suggestion_fix_1", "starcoder_suggestion_fix_2"],
            TaskType.REFACTOR: ["starcoder_suggestion_refactor_1", "starcoder_suggestion_refactor_2"],
            TaskType.TRANSLATE: ["starcoder_suggestion_translate_1", "starcoder_suggestion_translate_2"],
            TaskType.BOILERPLATE: ["starcoder_suggestion_boilerplate_1", "starcoder_suggestion_boilerplate_2"],
            TaskType.TEST: ["starcoder_suggestion_test_1", "starcoder_suggestion_test_2"]
        }
        return [get_string(key) for key in suggestion_keys.get(task, [])]

    async def shutdown(self):
        """Cleanup StarCoder: drop the llama.cpp handle and reset state."""
        logger.info("Shutting down StarCoder engine")
        if self.llm:
            self.llm = None
        self.initialized = False
|
backend/app/locales/en.json
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cli_ready": "Ready! Start chatting or type /help for commands",
|
| 3 |
+
"cli_help_title": "SLM Code Engine - Interactive CLI",
|
| 4 |
+
"cli_help_commands_title": "Available Commands:",
|
| 5 |
+
"cli_help_command_help": "/help - Show this help message",
|
| 6 |
+
"cli_help_command_status": "/status - Show the status of the backend and loaded models",
|
| 7 |
+
"cli_help_command_history": "/history - Show the command history",
|
| 8 |
+
"cli_help_command_save": "/save - (After a good response) Save the last interaction to improve the model",
|
| 9 |
+
"cli_help_command_exit": "/exit - Exit the interactive CLI",
|
| 10 |
+
"cli_help_examples_title": "Examples:",
|
| 11 |
+
"cli_help_example_fix": "fix this code:\n<code>",
|
| 12 |
+
"cli_help_example_explain": "explain this error:\n<traceback>",
|
| 13 |
+
"backend_error_generic": "An unexpected error occurred: {error}",
|
| 14 |
+
"automaton_applied_fix": "Automaton '{automaton_name}' applied a fix.",
|
| 15 |
+
"slm_applied_fix": "SLM '{engine_name}' applied a fix.",
|
| 16 |
+
"explanation_generated": "Explanation generated by '{component_name}'.",
|
| 17 |
+
"ast_fixer_no_errors": "No syntax errors detected by AST scan.",
|
| 18 |
+
"ast_fixer_fixed_issues": "Fixed {issue_count} issue(s): {issues}",
|
| 19 |
+
"ast_fixer_suggestion_linter": "Consider using a linter to prevent future errors.",
|
| 20 |
+
"ast_fixer_failed_autofix": "Found syntax issues but could not auto-fix them.",
|
| 21 |
+
"ast_fixer_suggestion_slm": "Complex errors require SLM analysis.",
|
| 22 |
+
"ast_fixer_syntax_error": "Syntax error detected but could not auto-fix: {error}",
|
| 23 |
+
"ast_fixer_analysis_error": "An error occurred during AST analysis: {error}",
|
| 24 |
+
"ast_fixer_added_colon": "Added missing colon after {keyword} statement on line {line_number}",
|
| 25 |
+
"ast_fixer_fixed_indentation": "Fixed indentation on line {line_number}",
|
| 26 |
+
"ast_fixer_added_paren": "Added missing closing parenthesis on line {line_number}",
|
| 27 |
+
"cmd_help_title": "SLM Code Engine - Available Commands",
|
| 28 |
+
"cmd_help_col_command": "Command",
|
| 29 |
+
"cmd_help_col_description": "Description",
|
| 30 |
+
"cmd_help_desc_help": "Show this help message",
|
| 31 |
+
"cmd_help_desc_exit": "Exit the SLM Code Engine",
|
| 32 |
+
"cmd_help_desc_clear": "Clear conversation history",
|
| 33 |
+
"cmd_help_desc_history": "Show last N messages (default: all)",
|
| 34 |
+
"cmd_help_desc_status": "Check backend status and loaded models",
|
| 35 |
+
"cmd_help_desc_file": "Set current working file",
|
| 36 |
+
"cmd_help_desc_lang": "Set current language (python, javascript, etc.)",
|
| 37 |
+
"cmd_help_desc_save": "Save current session",
|
| 38 |
+
"cmd_help_desc_load": "Load a previous session",
|
| 39 |
+
"cmd_help_desc_read": "Read and display a file",
|
| 40 |
+
"cmd_help_desc_write": "Write content to a file",
|
| 41 |
+
"cmd_read_usage": "Usage: /read <path>",
|
| 42 |
+
"cmd_read_not_found": "File not found: {path}",
|
| 43 |
+
"cmd_read_error": "Error reading file: {error}",
|
| 44 |
+
"cmd_write_usage": "Usage: /write <path> [content]",
|
| 45 |
+
"cmd_write_success": "✓ File written: {path}",
|
| 46 |
+
"cmd_write_error": "Error writing file: {error}",
|
| 47 |
+
"cmd_write_no_content": "No content provided and no previous result to save.",
|
| 48 |
+
"cmd_help_tips_title": "💡 Usage Tips:",
|
| 49 |
+
"cmd_help_tip_1": "• Type naturally: 'fix this code', 'explain this error', etc.",
|
| 50 |
+
"cmd_help_tip_2": "• Paste code directly - the assistant will understand the context",
|
| 51 |
+
"cmd_help_tip_3": "• Use /file to set a working file for context",
|
| 52 |
+
"cmd_help_tip_4": "• Conversation history is maintained automatically",
|
| 53 |
+
"cmd_unknown": "Unknown command: /{cmd}",
|
| 54 |
+
"cmd_unknown_suggestion": "Type /help for available commands",
|
| 55 |
+
"cmd_exit_message": "Goodbye! 👋",
|
| 56 |
+
"cmd_clear_success": "✓ Conversation history cleared",
|
| 57 |
+
"cmd_history_empty": "No conversation history yet",
|
| 58 |
+
"cmd_history_title": "📜 Conversation History ({count} messages)",
|
| 59 |
+
"cmd_status_error": "Error checking status: {error}",
|
| 60 |
+
"cmd_status_title": "🚀 Backend Status",
|
| 61 |
+
"cmd_file_current": "Current file: {file}",
|
| 62 |
+
"cmd_file_none": "No file set",
|
| 63 |
+
"cmd_file_usage": "Usage: /file <path>",
|
| 64 |
+
"cmd_file_not_found": "File not found: {path}",
|
| 65 |
+
"cmd_file_success": "✓ Current file set to: {path}",
|
| 66 |
+
"cmd_lang_current": "Current language: {lang}",
|
| 67 |
+
"cmd_lang_usage": "Usage: /lang <python|javascript|typescript|bash|rust|go|auto>",
|
| 68 |
+
"cmd_lang_invalid": "Invalid language: {lang}",
|
| 69 |
+
"cmd_lang_valid": "Valid options: {options}",
|
| 70 |
+
"cmd_lang_success": "✓ Language set to: {lang}",
|
| 71 |
+
"cmd_save_success": "✓ Session saved: {path}",
|
| 72 |
+
"cmd_save_error": "Error saving session: {error}",
|
| 73 |
+
"cmd_load_usage": "Usage: /load <session_file>",
|
| 74 |
+
"cmd_load_success": "✓ Session loaded: {path}",
|
| 75 |
+
"cmd_load_success_details": "Messages: {count}",
|
| 76 |
+
"cmd_load_error": "Error loading session: {error}",
|
| 77 |
+
"repl_banner_title": "SLM Code Engine - Interactive CLI",
|
| 78 |
+
"repl_banner_subtitle": "Local AI-powered code assistant (100% local)",
|
| 79 |
+
"repl_banner_help_hint": "Type /help for available commands or just chat naturally",
|
| 80 |
+
"repl_backend_check": "Checking backend connection...",
|
| 81 |
+
"repl_backend_conn_error_title": "Connection Error",
|
| 82 |
+
"repl_backend_conn_error_message": "❌ Cannot connect to SLM backend",
|
| 83 |
+
"repl_backend_conn_error_expected": "Expected backend at: {url}",
|
| 84 |
+
"repl_backend_conn_error_start_prompt": "Please start the backend:",
|
| 85 |
+
"repl_backend_conn_success": "✓ Connected to backend (v{version})",
|
| 86 |
+
"repl_backend_models_loaded": "Models loaded: {models}",
|
| 87 |
+
"repl_error_panel_title": "❌ Error",
|
| 88 |
+
"repl_result_panel_title": "✅ {task} Result",
|
| 89 |
+
"repl_explanation_panel_title": "💡 Explanation",
|
| 90 |
+
"repl_suggestions_title": "💡 Suggestions:",
|
| 91 |
+
"repl_performance_info": "⚡ {duration:.2f}s using {used_info}",
|
| 92 |
+
"repl_processing": "🤔 Processing...",
|
| 93 |
+
"repl_connection_lost": "❌ Lost connection to backend",
|
| 94 |
+
"repl_api_error": "❌ API error: {status_code}",
|
| 95 |
+
"repl_generic_error": "❌ Error: {error}",
|
| 96 |
+
"repl_ready": "Ready! Start chatting or type /help for commands",
|
| 97 |
+
"repl_prompt": "You",
|
| 98 |
+
"repl_interrupt_exit_hint": "Use /exit to quit",
|
| 99 |
+
"repl_interrupt_goodbye": "Interrupted. Goodbye! 👋",
|
| 100 |
+
"repl_session_saved": "Session saved",
|
| 101 |
+
"repl_autowrite_confirm": "🤖 Assistant wants to create file: {file}",
|
| 102 |
+
"repl_autowrite_prompt": "Do you want to create this file?",
|
| 103 |
+
"cmd_feedback_saved": "✓ Feedback saved. Thank you for helping the assistant improve!",
|
| 104 |
+
"cmd_feedback_no_last_interaction": "There is no previous interaction to save.",
|
| 105 |
+
"cmd_feedback_error": "Error saving feedback: {error}",
|
| 106 |
+
"cmd_help_desc_session_save": "Save the current chat session to a file",
|
| 107 |
+
"starcoder_test_explanation": "Generated unit tests.",
|
| 108 |
+
"starcoder_suggestion_test_1": "Review and adjust test cases as needed.",
|
| 109 |
+
"starcoder_suggestion_test_2": "Add more edge cases if necessary.",
|
| 110 |
+
"starcoder_suggestion_fix_1": "Review the fix to ensure it addresses the root cause.",
|
| 111 |
+
"starcoder_suggestion_fix_2": "Add tests to prevent regression.",
|
| 112 |
+
"starcoder_suggestion_refactor_1": "Consider adding documentation.",
|
| 113 |
+
"starcoder_suggestion_refactor_2": "Review performance implications.",
|
| 114 |
+
"starcoder_suggestion_translate_1": "Verify behavior matches original code.",
|
| 115 |
+
"starcoder_suggestion_translate_2": "Check for language-specific idioms.",
|
| 116 |
+
"starcoder_suggestion_boilerplate_1": "Customize the generated code for your needs.",
|
| 117 |
+
"starcoder_suggestion_boilerplate_2": "Add error handling as appropriate.",
|
| 118 |
+
"starcoder_error": "Processing error: {error}",
|
| 119 |
+
"codet5_explanation_suggestion_1": "Review the explanation for accuracy.",
|
| 120 |
+
"codet5_explanation_suggestion_2": "Consider adding inline comments to your code for clarity.",
|
| 121 |
+
"codet5_translate_explanation": "Translated code to target language.",
|
| 122 |
+
"codet5_translate_suggestion": "Verify the translation maintains original behavior and syntax.",
|
| 123 |
+
"codet5_error": "Processing error: {error}"
|
| 124 |
+
}
|
backend/app/locales/fr.json
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cli_ready": "Prêt ! Discutez ou tapez /help pour voir les commandes",
|
| 3 |
+
"cli_help_title": "SLM Code Engine - CLI Interactif",
|
| 4 |
+
"cli_help_commands_title": "Commandes Disponibles :",
|
| 5 |
+
"cli_help_command_help": "/help - Affiche ce message d'aide",
|
| 6 |
+
"cli_help_command_status": "/status - Affiche le statut du backend et des modèles chargés",
|
| 7 |
+
"cli_help_command_history": "/history - Affiche l'historique des commandes",
|
| 8 |
+
"cli_help_command_save": "/save - (Après une bonne réponse) Sauvegarde la dernière interaction pour améliorer le modèle",
|
| 9 |
+
"cli_help_command_exit": "/exit - Quitte le CLI interactif",
|
| 10 |
+
"cli_help_examples_title": "Exemples :",
|
| 11 |
+
"cli_help_example_fix": "corrige ce code :\n<code>",
|
| 12 |
+
"cli_help_example_explain": "explique cette erreur :\n<traceback>",
|
| 13 |
+
"backend_error_generic": "Une erreur inattendue est survenue : {error}",
|
| 14 |
+
"automaton_applied_fix": "L'automate '{automaton_name}' a appliqué une correction.",
|
| 15 |
+
"slm_applied_fix": "Le SLM '{engine_name}' a appliqué une correction.",
|
| 16 |
+
"explanation_generated": "Explication générée par '{component_name}'.",
|
| 17 |
+
"ast_fixer_no_errors": "Aucune erreur de syntaxe détectée par l'analyse AST.",
|
| 18 |
+
"ast_fixer_fixed_issues": "{issue_count} problème(s) corrigé(s) : {issues}",
|
| 19 |
+
"ast_fixer_suggestion_linter": "Envisagez d'utiliser un linter pour prévenir de futures erreurs.",
|
| 20 |
+
"ast_fixer_failed_autofix": "Des problèmes de syntaxe ont été trouvés mais n'ont pas pu être corrigés automatiquement.",
|
| 21 |
+
"ast_fixer_suggestion_slm": "Les erreurs complexes nécessitent une analyse par le SLM.",
|
| 22 |
+
"ast_fixer_syntax_error": "Erreur de syntaxe détectée mais non corrigible automatiquement : {error}",
|
| 23 |
+
"ast_fixer_analysis_error": "Une erreur est survenue lors de l'analyse AST : {error}",
|
| 24 |
+
"ast_fixer_added_colon": "Deux-points manquant ajouté après l'instruction {keyword} à la ligne {line_number}",
|
| 25 |
+
"ast_fixer_fixed_indentation": "Indentation corrigée à la ligne {line_number}",
|
| 26 |
+
"ast_fixer_added_paren": "Parenthèse fermante manquante ajoutée à la ligne {line_number}",
|
| 27 |
+
"cmd_help_title": "SLM Code Engine - Commandes Disponibles",
|
| 28 |
+
"cmd_help_col_command": "Commande",
|
| 29 |
+
"cmd_help_col_description": "Description",
|
| 30 |
+
"cmd_help_desc_help": "Affiche ce message d'aide",
|
| 31 |
+
"cmd_help_desc_exit": "Quitte le SLM Code Engine",
|
| 32 |
+
"cmd_help_desc_clear": "Efface l'historique de la conversation",
|
| 33 |
+
"cmd_help_desc_history": "Affiche les N derniers messages (défaut : tous)",
|
| 34 |
+
"cmd_help_desc_status": "Vérifie le statut du backend et les modèles chargés",
|
| 35 |
+
"cmd_help_desc_file": "Définit le fichier de travail actuel",
|
| 36 |
+
"cmd_help_desc_lang": "Définit la langue actuelle (python, javascript, etc.)",
|
| 37 |
+
"cmd_help_desc_save": "Sauvegarde la session actuelle",
|
| 38 |
+
"cmd_help_desc_load": "Charge une session précédente",
|
| 39 |
+
"cmd_help_desc_read": "Lit et affiche un fichier",
|
| 40 |
+
"cmd_help_desc_write": "Écrit du contenu dans un fichier",
|
| 41 |
+
"cmd_read_usage": "Usage : /read <chemin>",
|
| 42 |
+
"cmd_read_not_found": "Fichier non trouvé : {path}",
|
| 43 |
+
"cmd_read_error": "Erreur de lecture du fichier : {error}",
|
| 44 |
+
"cmd_write_usage": "Usage : /write <chemin> [contenu]",
|
| 45 |
+
"cmd_write_success": "✓ Fichier écrit : {path}",
|
| 46 |
+
"cmd_write_error": "Erreur d'écriture du fichier : {error}",
|
| 47 |
+
"cmd_write_no_content": "Aucun contenu fourni et aucun résultat précédent à sauvegarder.",
|
| 48 |
+
"cmd_help_tips_title": "💡 Astuces d'Utilisation :",
|
| 49 |
+
"cmd_help_tip_1": "• Tapez naturellement : 'corrige ce code', 'explique cette erreur', etc.",
|
| 50 |
+
"cmd_help_tip_2": "• Collez du code directement - l'assistant comprendra le contexte",
|
| 51 |
+
"cmd_help_tip_3": "• Utilisez /file pour définir un fichier de travail pour le contexte",
|
| 52 |
+
"cmd_help_tip_4": "• L'historique de la conversation est conservé automatiquement",
|
| 53 |
+
"cmd_unknown": "Commande inconnue : /{cmd}",
|
| 54 |
+
"cmd_unknown_suggestion": "Tapez /help pour voir les commandes disponibles",
|
| 55 |
+
"cmd_exit_message": "Au revoir ! 👋",
|
| 56 |
+
"cmd_clear_success": "✓ Historique de la conversation effacé",
|
| 57 |
+
"cmd_history_empty": "Pas encore d'historique de conversation",
|
| 58 |
+
"cmd_history_title": "📜 Historique de la Conversation ({count} messages)",
|
| 59 |
+
"cmd_status_error": "Erreur lors de la vérification du statut : {error}",
|
| 60 |
+
"cmd_status_title": "🚀 Statut du Backend",
|
| 61 |
+
"cmd_file_current": "Fichier actuel : {file}",
|
| 62 |
+
"cmd_file_none": "Aucun fichier défini",
|
| 63 |
+
"cmd_file_usage": "Usage : /file <chemin>",
|
| 64 |
+
"cmd_file_not_found": "Fichier non trouvé : {path}",
|
| 65 |
+
"cmd_file_success": "✓ Fichier de travail défini : {path}",
|
| 66 |
+
"cmd_lang_current": "Langue actuelle : {lang}",
|
| 67 |
+
"cmd_lang_usage": "Usage : /lang <python|javascript|typescript|bash|rust|go|auto>",
|
| 68 |
+
"cmd_lang_invalid": "Langue invalide : {lang}",
|
| 69 |
+
"cmd_lang_valid": "Options valides : {options}",
|
| 70 |
+
"cmd_lang_success": "✓ Langue définie : {lang}",
|
| 71 |
+
"cmd_save_success": "✓ Session sauvegardée : {path}",
|
| 72 |
+
"cmd_save_error": "Erreur lors de la sauvegarde de la session : {error}",
|
| 73 |
+
"cmd_load_usage": "Usage : /load <fichier_session>",
|
| 74 |
+
"cmd_load_success": "✓ Session chargée : {path}",
|
| 75 |
+
"cmd_load_success_details": "Messages : {count}",
|
| 76 |
+
"cmd_load_error": "Erreur lors du chargement de la session : {error}",
|
| 77 |
+
"repl_banner_title": "SLM Code Engine - CLI Interactif",
|
| 78 |
+
"repl_banner_subtitle": "Assistant de code IA (100% local)",
|
| 79 |
+
"repl_banner_help_hint": "Tapez /help pour les commandes ou discutez normalement",
|
| 80 |
+
"repl_backend_check": "Vérification de la connexion au backend...",
|
| 81 |
+
"repl_backend_conn_error_title": "Erreur de Connexion",
|
| 82 |
+
"repl_backend_conn_error_message": "❌ Connexion impossible au backend SLM",
|
| 83 |
+
"repl_backend_conn_error_expected": "Backend attendu à : {url}",
|
| 84 |
+
"repl_backend_conn_error_start_prompt": "Veuillez démarrer le backend :",
|
| 85 |
+
"repl_backend_conn_success": "✓ Connecté au backend (v{version})",
|
| 86 |
+
"repl_backend_models_loaded": "Modèles chargés : {models}",
|
| 87 |
+
"repl_error_panel_title": "❌ Erreur",
|
| 88 |
+
"repl_result_panel_title": "✅ Résultat de {task}",
|
| 89 |
+
"repl_explanation_panel_title": "💡 Explication",
|
| 90 |
+
"repl_suggestions_title": "💡 Suggestions :",
|
| 91 |
+
"repl_performance_info": "⚡ {duration:.2f}s en utilisant {used_info}",
|
| 92 |
+
"repl_processing": "🤔 Traitement en cours...",
|
| 93 |
+
"repl_connection_lost": "❌ Connexion au backend perdue",
|
| 94 |
+
"repl_api_error": "❌ Erreur API : {status_code}",
|
| 95 |
+
"repl_generic_error": "❌ Erreur : {error}",
|
| 96 |
+
"repl_ready": "Prêt ! Discutez ou tapez /help pour voir les commandes",
|
| 97 |
+
"repl_prompt": "Vous",
|
| 98 |
+
"repl_interrupt_exit_hint": "Utilisez /exit pour quitter",
|
| 99 |
+
"repl_interrupt_goodbye": "Interrompu. Au revoir ! 👋",
|
| 100 |
+
"repl_session_saved": "Session sauvegardée",
|
| 101 |
+
"repl_autowrite_confirm": "🤖 L'assistant veut créer le fichier : {file}",
|
| 102 |
+
"repl_autowrite_prompt": "Voulez-vous créer ce fichier ?",
|
| 103 |
+
"cmd_feedback_saved": "✓ Feedback sauvegardé. Merci d'aider l'assistant à s'améliorer !",
|
| 104 |
+
"cmd_feedback_no_last_interaction": "Il n'y a pas d'interaction précédente à sauvegarder.",
|
| 105 |
+
"cmd_feedback_error": "Erreur lors de la sauvegarde du feedback : {error}",
|
| 106 |
+
"cmd_help_desc_session_save": "Sauvegarde la session de chat actuelle dans un fichier",
|
| 107 |
+
"starcoder_test_explanation": "Tests unitaires générés.",
|
| 108 |
+
"starcoder_suggestion_test_1": "Révisez et ajustez les cas de test si nécessaire.",
|
| 109 |
+
"starcoder_suggestion_test_2": "Ajoutez plus de cas limites si nécessaire.",
|
| 110 |
+
"starcoder_suggestion_fix_1": "Vérifiez que la correction résout la cause première du problème.",
|
| 111 |
+
"starcoder_suggestion_fix_2": "Ajoutez des tests pour prévenir les régressions.",
|
| 112 |
+
"starcoder_suggestion_refactor_1": "Envisagez d'ajouter de la documentation.",
|
| 113 |
+
"starcoder_suggestion_refactor_2": "Examinez les implications sur la performance.",
|
| 114 |
+
"starcoder_suggestion_translate_1": "Vérifiez que le comportement correspond au code original.",
|
| 115 |
+
"starcoder_suggestion_translate_2": "Vérifiez les idiomes spécifiques au langage.",
|
| 116 |
+
"starcoder_suggestion_boilerplate_1": "Personnalisez le code généré pour vos besoins.",
|
| 117 |
+
"starcoder_suggestion_boilerplate_2": "Ajoutez la gestion des erreurs de manière appropriée.",
|
| 118 |
+
"starcoder_error": "Erreur de traitement : {error}",
|
| 119 |
+
"codet5_explanation_suggestion_1": "Vérifiez l'exactitude de l'explication.",
|
| 120 |
+
"codet5_explanation_suggestion_2": "Envisagez d'ajouter des commentaires en ligne à votre code pour plus de clarté.",
|
| 121 |
+
"codet5_translate_explanation": "Code traduit dans la langue cible.",
|
| 122 |
+
"codet5_translate_suggestion": "Vérifiez que la traduction conserve le comportement et la syntaxe d'origine.",
|
| 123 |
+
"codet5_error": "Erreur de traitement : {error}"
|
| 124 |
+
}
|
backend/app/main.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI main application for SLM Code Engine
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
from contextlib import asynccontextmanager
|
| 6 |
+
from typing import Dict
|
| 7 |
+
|
| 8 |
+
from fastapi import FastAPI, HTTPException
|
| 9 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
+
from fastapi.responses import JSONResponse
|
| 11 |
+
|
| 12 |
+
from app.config import settings
|
| 13 |
+
from app.models.schemas import (
|
| 14 |
+
QueryRequest,
|
| 15 |
+
QueryResponse,
|
| 16 |
+
HealthResponse,
|
| 17 |
+
TranslateRequest,
|
| 18 |
+
BoilerplateRequest,
|
| 19 |
+
FeedbackRequest,
|
| 20 |
+
FeedbackResponse,
|
| 21 |
+
)
|
| 22 |
+
from app.core.orchestrator import Orchestrator
|
| 23 |
+
from app import __version__
|
| 24 |
+
|
| 25 |
+
# Configure logging once at import time, honouring the configured level
# (settings.log_level is an upper-case level name such as "INFO").
logging.basicConfig(
    level=getattr(logging, settings.log_level),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Global orchestrator instance, created during application startup (see
# `lifespan` below) and None until then.  The annotation is a string so
# `Orchestrator | None` is not evaluated at import time; the previous
# bare `Orchestrator` annotation was wrong while the value is None.
orchestrator: "Orchestrator | None" = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifecycle manager for the application.

    Builds the global orchestrator on startup, hands control to the
    server while it runs, and always shuts the orchestrator down on exit.
    """
    global orchestrator

    logger.info("Starting SLM Code Engine...")

    try:
        # Build and warm up the orchestrator before serving any request.
        orchestrator = Orchestrator()
        await orchestrator.initialize()
        logger.info("Orchestrator initialized successfully")

        yield

    except Exception as exc:
        logger.error(f"Failed to initialize: {exc}")
        raise

    finally:
        # Runs on normal shutdown and on failure alike.
        logger.info("Shutting down SLM Code Engine...")
        if orchestrator:
            await orchestrator.shutdown()
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# Create FastAPI app
|
| 63 |
+
app = FastAPI(
|
| 64 |
+
title="SLM Code Engine",
|
| 65 |
+
description="Local AI-powered code assistant using Small Language Models",
|
| 66 |
+
version=__version__,
|
| 67 |
+
lifespan=lifespan
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# CORS middleware
|
| 71 |
+
app.add_middleware(
|
| 72 |
+
CORSMiddleware,
|
| 73 |
+
allow_origins=["*"], # Configure appropriately for production
|
| 74 |
+
allow_credentials=True,
|
| 75 |
+
allow_methods=["*"],
|
| 76 |
+
allow_headers=["*"],
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@app.get("/", response_model=Dict[str, str])
async def root():
    """Root endpoint: basic service identity and a pointer to the docs."""
    info = {
        "name": "SLM Code Engine",
        "version": __version__,
        "status": "running",
        "docs": "/docs",
    }
    return info
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    503 while the orchestrator has not been constructed; otherwise
    reports readiness plus model/automata availability.
    """
    if not orchestrator:
        raise HTTPException(status_code=503, detail="Orchestrator not initialized")

    status = await orchestrator.get_status()
    service_state = "healthy" if status["ready"] else "initializing"

    return HealthResponse(
        status=service_state,
        version=__version__,
        models_loaded=status.get("models_loaded", {}),
        automata_available=status.get("automata_available", []),
    )
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@app.post("/api/v1/query", response_model=QueryResponse)
async def process_query(request: QueryRequest):
    """
    Main endpoint for code processing

    Supports:
    - fix: Fix code errors
    - explain: Explain code or errors
    - refactor: Refactor code
    - test: Generate unit tests
    - translate: Translate code between languages
    - format: Format code
    - boilerplate: Generate boilerplate code
    """
    if not orchestrator:
        raise HTTPException(status_code=503, detail="Orchestrator not initialized")

    try:
        logger.info(f"Processing {request.task} request for {request.language}")

        outcome = await orchestrator.process(
            task=request.task,
            code=request.code,
            language=request.language,
            context=request.context,
            trace=request.trace,
        )
        return QueryResponse(**outcome)

    except Exception as exc:
        # Errors are reported in-band as a failed QueryResponse rather
        # than as an HTTP error.
        logger.error(f"Error processing query: {exc}", exc_info=True)
        return QueryResponse(
            success=False,
            task=request.task,
            error=str(exc),
            used_automata=False,
            used_slm=False,
            pipeline=[],
            total_duration_ms=0,
        )
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
@app.post("/api/v1/translate", response_model=QueryResponse)
async def translate_code(request: TranslateRequest):
    """Translate code between programming languages"""
    if not orchestrator:
        raise HTTPException(status_code=503, detail="Orchestrator not initialized")

    try:
        outcome = await orchestrator.translate(
            code=request.code,
            source_lang=request.source_language,
            target_lang=request.target_language,
            preserve_comments=request.preserve_comments,
        )
        return QueryResponse(**outcome)

    except Exception as exc:
        # Report failure in-band with an empty pipeline.
        logger.error(f"Error translating code: {exc}", exc_info=True)
        return QueryResponse(
            success=False,
            task="translate",
            error=str(exc),
            used_automata=False,
            used_slm=False,
            pipeline=[],
            total_duration_ms=0,
        )
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
@app.post("/api/v1/boilerplate", response_model=QueryResponse)
async def generate_boilerplate(request: BoilerplateRequest):
    """Generate boilerplate code"""
    if not orchestrator:
        raise HTTPException(status_code=503, detail="Orchestrator not initialized")

    try:
        outcome = await orchestrator.generate_boilerplate(
            template_type=request.template_type,
            language=request.language,
            name=request.name,
            options=request.options,
        )
        return QueryResponse(**outcome)

    except Exception as exc:
        # Report failure in-band with an empty pipeline.
        logger.error(f"Error generating boilerplate: {exc}", exc_info=True)
        return QueryResponse(
            success=False,
            task="boilerplate",
            error=str(exc),
            used_automata=False,
            used_slm=False,
            pipeline=[],
            total_duration_ms=0,
        )
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
from app.storage.feedback import FeedbackLogger
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
@app.post("/api/v1/feedback", response_model=FeedbackResponse)
async def log_feedback(request: FeedbackRequest):
    """
    Endpoint to log positive user feedback on an interaction.
    This feedback is used to improve the model over time.
    """
    try:
        feedback_logger = FeedbackLogger()
        entry_created = feedback_logger.log_feedback(
            task=request.task.value,
            language=request.language.value,
            request_code=request.request_code,
            response_code=request.response_code,
            response_explanation=request.response_explanation,
        )

        # A duplicate submission is still a success, with a distinct message.
        if entry_created:
            message = "Feedback logged successfully. Thank you!"
        else:
            message = "This feedback was already recorded."

        return FeedbackResponse(
            success=True,
            message=message,
            entry_created=entry_created,
        )

    except Exception as exc:
        logger.error(f"Error logging feedback: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to log feedback: {str(exc)}")
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Global exception handler: last-resort 500 for unhandled errors.

    The real error detail is only exposed when debug mode is enabled.
    """
    logger.error(f"Unhandled exception: {exc}", exc_info=True)
    detail = str(exc) if settings.debug else "An error occurred"
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal server error",
            "detail": detail,
        },
    )
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
if __name__ == "__main__":
    # Direct invocation: run the ASGI server with the configured settings.
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host=settings.api_host,
        port=settings.api_port,
        reload=settings.debug,
        workers=settings.api_workers,
    )
|
backend/app/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Models package"""
|
backend/app/models/schemas.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic models for API requests and responses
|
| 3 |
+
"""
|
| 4 |
+
from enum import Enum
|
| 5 |
+
from typing import Optional, Dict, Any, List
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TaskType(str, Enum):
    """Supported task types.

    Inherits from str so members compare equal to their wire values.
    """

    FIX = "fix"
    EXPLAIN = "explain"
    REFACTOR = "refactor"
    TEST = "test"
    TRANSLATE = "translate"
    FORMAT = "format"
    BOILERPLATE = "boilerplate"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Language(str, Enum):
    """Supported programming languages.

    Inherits from str so members compare equal to their wire values.
    """

    PYTHON = "python"
    JAVASCRIPT = "javascript"
    TYPESCRIPT = "typescript"
    BASH = "bash"
    RUST = "rust"
    GO = "go"
    AUTO = "auto"  # Auto-detect
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class QueryRequest(BaseModel):
    """Request for code processing.

    `trace` carries an error traceback for fix/explain tasks; `history`
    carries prior conversation turns for context.
    """

    task: TaskType = Field(..., description="Type of task to perform")
    code: str = Field(..., description="Source code to process")
    language: Language = Field(default=Language.AUTO, description="Programming language")
    context: Optional[str] = Field(default=None, description="Additional context or instructions")
    trace: Optional[str] = Field(default=None, description="Error trace (for fix/explain tasks)")
    history: Optional[List[Dict[str, str]]] = Field(default=None, description="Conversation history for context")

    class Config:
        json_schema_extra = {
            "example": {
                "task": "fix",
                "code": "def hello)\n    print('hello')",
                "language": "python",
                "context": "Fix syntax errors",
            }
        }
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class ExecutionStep(BaseModel):
    """Single step in the execution pipeline (one automaton or SLM call)."""

    step_type: str = Field(..., description="Type of step (automata/slm)")
    component: str = Field(..., description="Component used (e.g., 'black', 'starcoder')")
    duration_ms: float = Field(..., description="Execution duration in milliseconds")
    success: bool = Field(..., description="Whether step succeeded")
    details: Optional[Dict[str, Any]] = Field(default=None, description="Additional details")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class QueryResponse(BaseModel):
    """Response from code processing.

    Groups the actual result, routing metadata (automata vs SLM and the
    per-step pipeline), timing, and an optional error message.
    """

    success: bool = Field(..., description="Whether request succeeded")
    task: TaskType = Field(..., description="Task type executed")

    # Results
    result: Optional[str] = Field(default=None, description="Processed code or explanation")
    explanation: Optional[str] = Field(default=None, description="Human-readable explanation")
    suggestions: Optional[List[str]] = Field(default=None, description="Additional suggestions")

    # Metadata
    used_automata: bool = Field(..., description="Whether automata were used")
    used_slm: bool = Field(..., description="Whether SLM was used")
    pipeline: List[ExecutionStep] = Field(default_factory=list, description="Execution pipeline steps")

    # Performance
    total_duration_ms: float = Field(..., description="Total execution time")

    # Error handling
    error: Optional[str] = Field(default=None, description="Error message if failed")

    class Config:
        json_schema_extra = {
            "example": {
                "success": True,
                "task": "fix",
                "result": "def hello():\n    print('hello')",
                "explanation": "Fixed: Missing ':' after function definition and incorrect indentation",
                "suggestions": ["Consider adding type hints", "Add docstring"],
                "used_automata": False,
                "used_slm": True,
                "pipeline": [
                    {
                        "step_type": "slm",
                        "component": "starcoder",
                        "duration_ms": 1234.5,
                        "success": True,
                    }
                ],
                "total_duration_ms": 1250.0,
                "error": None,
            }
        }
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class HealthResponse(BaseModel):
    """Health check response: service state plus component availability."""

    status: str = Field(..., description="Service status")
    version: str = Field(..., description="API version")
    models_loaded: Dict[str, bool] = Field(..., description="Model loading status")
    automata_available: List[str] = Field(..., description="Available automata")
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class TranslateRequest(BaseModel):
    """Request for code translation between two explicit languages."""

    code: str = Field(..., description="Source code to translate")
    source_language: Language = Field(..., description="Source language")
    target_language: Language = Field(..., description="Target language")
    preserve_comments: bool = Field(default=True, description="Preserve code comments")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class BoilerplateRequest(BaseModel):
    """Request for boilerplate generation from a named template."""

    template_type: str = Field(..., description="Type of boilerplate (cli, api, class, etc.)")
    language: Language = Field(..., description="Programming language")
    name: str = Field(..., description="Name for the component")
    options: Optional[Dict[str, Any]] = Field(default=None, description="Additional options")
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class FeedbackRequest(BaseModel):
    """Request to log a successful interaction for feedback.

    Captures what the user asked and what the assistant produced so the
    pair can later be used to improve the model.
    """

    task: TaskType = Field(..., description="The task that was performed")
    language: Language = Field(..., description="The programming language")
    request_code: str = Field(..., description="The original user code or query")
    response_code: Optional[str] = Field(default=None, description="The successful code response from the AI")
    response_explanation: Optional[str] = Field(default=None, description="The successful explanation from the AI")
    session_id: Optional[str] = Field(default=None, description="The session ID for context")

    class Config:
        json_schema_extra = {
            "example": {
                "task": "fix",
                "language": "python",
                "request_code": "def hello)",
                "response_code": "def hello():",
                "response_explanation": "Added missing parentheses.",
            }
        }
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class FeedbackResponse(BaseModel):
    """Response from logging feedback."""

    success: bool = Field(..., description="Whether the feedback was logged successfully")
    message: str = Field(..., description="A confirmation message")
    entry_created: bool = Field(..., description="Whether a new feedback entry was created (vs. being a duplicate)")
|
backend/app/rag/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RAG (Retrieval Augmented Generation) module
|
| 3 |
+
|
| 4 |
+
Provides code example retrieval using FAISS vector similarity search
|
| 5 |
+
"""
|
| 6 |
+
from .embedder import CodeEmbedder
|
| 7 |
+
from .vector_store import VectorStore
|
| 8 |
+
from .retriever import CodeRetriever
|
| 9 |
+
|
| 10 |
+
__all__ = ["CodeEmbedder", "VectorStore", "CodeRetriever"]
|
backend/app/rag/embedder.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Code embedder using sentence-transformers
|
| 3 |
+
|
| 4 |
+
Converts code snippets into vector embeddings for similarity search
|
| 5 |
+
"""
|
| 6 |
+
import logging
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class CodeEmbedder:
    """Generates embeddings for code using sentence-transformers.

    The underlying model is loaded lazily on first use so importing this
    module stays cheap.
    """

    # Character truncation limit applied before encoding.  The comment in
    # the original code ties this to CodeBERT's 512-token input limit —
    # TODO confirm 2000 chars is a safe bound for other models.  Hoisted
    # to one named constant so embed() and embed_batch() cannot drift.
    MAX_CHARS = 2000

    def __init__(self, model_name: str = "microsoft/codebert-base"):
        """
        Initialize the code embedder

        Args:
            model_name: HuggingFace model for code embeddings
                        Default: microsoft/codebert-base (125M params, fast)
        """
        self.model_name = model_name
        self.model: Optional[object] = None

    def initialize(self):
        """Load the embedding model (lazy loading); no-op if already loaded."""
        if self.model is not None:
            return

        try:
            # Imported here so the heavy dependency is only required when
            # embeddings are actually used.
            from sentence_transformers import SentenceTransformer

            logger.info(f"Loading embedding model: {self.model_name}")
            self.model = SentenceTransformer(self.model_name)
            logger.info("Embedding model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load embedding model: {e}")
            raise

    def embed(self, code: str) -> np.ndarray:
        """
        Generate embedding for a single code snippet

        Args:
            code: Source code string

        Returns:
            Embedding vector as numpy array

        Raises:
            Exception: propagated from model loading or encoding.
        """
        if self.model is None:
            self.initialize()

        try:
            # Truncate very long code before encoding (see MAX_CHARS).
            embedding = self.model.encode(code[: self.MAX_CHARS], convert_to_numpy=True)
            return embedding

        except Exception as e:
            logger.error(f"Failed to generate embedding: {e}")
            raise

    def embed_batch(self, codes: List[str]) -> np.ndarray:
        """
        Generate embeddings for multiple code snippets

        Args:
            codes: List of source code strings

        Returns:
            Matrix of embeddings (n_samples x embedding_dim)

        Raises:
            Exception: propagated from model loading or encoding.
        """
        if self.model is None:
            self.initialize()

        try:
            # Same truncation rule as embed(), applied element-wise.
            truncated_codes = [c[: self.MAX_CHARS] for c in codes]

            embeddings = self.model.encode(
                truncated_codes,
                convert_to_numpy=True,
                show_progress_bar=True
            )

            return embeddings

        except Exception as e:
            logger.error(f"Failed to generate batch embeddings: {e}")
            raise
|
backend/app/rag/retriever.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Code retriever - High-level interface for RAG
|
| 3 |
+
|
| 4 |
+
Combines embedding and vector search to retrieve similar code examples
|
| 5 |
+
"""
|
| 6 |
+
import logging
|
| 7 |
+
from typing import List, Dict, Any, Optional
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
from .embedder import CodeEmbedder
|
| 11 |
+
from .vector_store import VectorStore
|
| 12 |
+
from app.models.schemas import Language, TaskType
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class CodeRetriever:
    """High-level interface for code example retrieval.

    Combines a CodeEmbedder (text -> vector) with a VectorStore
    (vector -> nearest neighbours) to index code examples and retrieve
    the most similar ones for RAG prompts.
    """

    def __init__(
        self,
        embedder: Optional[CodeEmbedder] = None,
        vector_store: Optional[VectorStore] = None,
        index_path: Optional[str] = None
    ):
        """
        Initialize code retriever.

        Args:
            embedder: CodeEmbedder instance (creates default if None)
            vector_store: VectorStore instance (creates default if None)
            index_path: Path to FAISS index file
        """
        self.embedder = embedder or CodeEmbedder()
        self.vector_store = vector_store or VectorStore(
            embedding_dim=768,  # CodeBERT dimension
            index_path=index_path
        )
        self.initialized = False

    @staticmethod
    def _enum_value(item) -> str:
        """Return an enum member's string value, or str() of anything else."""
        return item.value if hasattr(item, 'value') else str(item)

    def initialize(self):
        """Initialize embedder and vector store. Idempotent.

        Raises:
            Exception: Re-raised if either component fails to initialize.
        """
        if self.initialized:
            return

        try:
            logger.info("Initializing CodeRetriever...")

            # Initialize embedder
            self.embedder.initialize()

            # Initialize vector store
            self.vector_store.initialize()

            self.initialized = True
            logger.info("CodeRetriever initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize CodeRetriever: {e}")
            raise

    def add_examples(
        self,
        codes: List[str],
        languages: List[Language],
        tasks: List[TaskType],
        descriptions: Optional[List[str]] = None
    ):
        """
        Add code examples to the index.

        Args:
            codes: List of code snippets
            languages: List of programming languages (one per snippet)
            tasks: List of task types (one per snippet)
            descriptions: Optional list of descriptions

        Raises:
            ValueError: If codes, languages and tasks differ in length.
                Without this check zip() would silently drop metadata
                while ALL codes were embedded, leaving the vector store
                with mismatched embedding/metadata counts.
            Exception: Re-raised if embedding or indexing fails.
        """
        if not (len(codes) == len(languages) == len(tasks)):
            raise ValueError(
                f"codes ({len(codes)}), languages ({len(languages)}) and "
                f"tasks ({len(tasks)}) must have the same length"
            )

        if not self.initialized:
            self.initialize()

        try:
            logger.info(f"Adding {len(codes)} code examples to index")

            # Generate embeddings
            embeddings = self.embedder.embed_batch(codes)

            # Prepare metadata (one entry per embedding)
            metadata = []
            for i, (code, lang, task) in enumerate(zip(codes, languages, tasks)):
                meta = {
                    "code": code,
                    "language": self._enum_value(lang),
                    "task": self._enum_value(task),
                    "description": descriptions[i] if descriptions and i < len(descriptions) else None
                }
                metadata.append(meta)

            # Add to vector store
            self.vector_store.add(embeddings, metadata)

            logger.info(f"Successfully added {len(codes)} examples")

        except Exception as e:
            logger.error(f"Failed to add examples: {e}")
            raise

    def retrieve(
        self,
        query_code: str,
        language: Optional[Language] = None,
        task: Optional[TaskType] = None,
        k: int = 3
    ) -> List[Dict[str, Any]]:
        """
        Retrieve similar code examples.

        Args:
            query_code: Code snippet to find similar examples for
            language: Filter by programming language (optional)
            task: Filter by task type (optional)
            k: Number of examples to retrieve

        Returns:
            List of similar code examples with metadata; an empty list
            on failure (best-effort — retrieval errors are logged, not
            raised).
        """
        if not self.initialized:
            self.initialize()

        try:
            logger.debug(f"Retrieving {k} similar examples for query")

            # Generate query embedding
            query_embedding = self.embedder.embed(query_code)

            # Over-fetch when filters are active so that post-filtering
            # can still yield up to k results.
            search_k = k * 3 if (language or task) else k
            results = self.vector_store.search(query_embedding, k=search_k)

            # Filter by language/task if specified
            filtered_results = []
            for distance, metadata in results:
                if language and metadata.get("language") != self._enum_value(language):
                    continue

                if task and metadata.get("task") != self._enum_value(task):
                    continue

                filtered_results.append({
                    "code": metadata.get("code"),
                    "language": metadata.get("language"),
                    "task": metadata.get("task"),
                    "description": metadata.get("description"),
                    "similarity_score": 1.0 / (1.0 + distance)  # Convert distance to similarity
                })

                if len(filtered_results) >= k:
                    break

            logger.info(f"Retrieved {len(filtered_results)} similar examples")
            return filtered_results

        except Exception as e:
            logger.error(f"Failed to retrieve examples: {e}")
            return []

    def save(self):
        """Save the vector store index (no-op if never initialized)."""
        if self.initialized:
            self.vector_store.save()

    def clear(self):
        """Clear all indexed examples (no-op if never initialized)."""
        if self.initialized:
            self.vector_store.clear()

    def build_context(
        self,
        query_code: str,
        language: Optional[Language] = None,
        task: Optional[TaskType] = None,
        k: int = 3
    ) -> str:
        """
        Build context string from retrieved examples.

        Args:
            query_code: Code snippet to find similar examples for
            language: Filter by programming language
            task: Filter by task type
            k: Number of examples to include

        Returns:
            Formatted context string for LLM prompts; empty string if
            nothing was retrieved.
        """
        examples = self.retrieve(query_code, language, task, k)

        if not examples:
            return ""

        context_parts = ["Here are similar code examples:\n"]

        for i, example in enumerate(examples, 1):
            context_parts.append(f"\nExample {i}:")
            if example.get("description"):
                context_parts.append(f"Description: {example['description']}")
            context_parts.append(f"```{example.get('language', 'python')}")
            context_parts.append(example.get("code", ""))
            context_parts.append("```")

        return "\n".join(context_parts)
|