| """Command-line interface entry points for BitTransformerLM.""" |
|
|
| import sys |
| import logging |
| from pathlib import Path |
| from typing import Optional |
|
|
| import torch |
|
|
| from .cli_standards import create_training_parser, create_inference_parser, BitTransformerCLI |
| from .config import ( |
| ExperimentConfig, |
| ModelConfig, |
| TrainingConfig, |
| SafetyConfig, |
| DataConfig, |
| get_small_config, |
| get_medium_config, |
| get_large_config, |
| ) |
| from .model import BitTransformerLM, diffusion_inference |
| from .training import train_loop |
| from .bit_io import text_to_bits, bits_to_text, infer_text |
| from .utils import save_model, load_model |
| from .dashboard_app import run_dashboard |
|
|
|
|
def setup_logging(level: str = "INFO") -> None:
    """Configure root logging to write timestamped records to stdout.

    Args:
        level: Logging level name, case-insensitive (e.g. "info", "DEBUG").
            An unknown name raises AttributeError via the getattr lookup.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    log_level = getattr(logging, level.upper())
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        level=log_level,
        handlers=[stdout_handler],
    )
|
|
|
|
def _build_config(args) -> "ExperimentConfig":
    """Build an ExperimentConfig from parsed CLI arguments.

    Starts from the preset matching ``args.model_size`` (falling back to the
    default ``ExperimentConfig``), then unconditionally overrides every field
    the CLI exposes, so explicit flags always take effect.

    Args:
        args: Namespace produced by ``create_training_parser().parse_args()``.

    Returns:
        A fully-populated ``ExperimentConfig`` with device auto-selected.
    """
    presets = {
        "small": get_small_config,
        "medium": get_medium_config,
        "large": get_large_config,
    }
    config = presets.get(args.model_size, ExperimentConfig)()

    # Model architecture.
    config.model.d_model = args.d_model
    config.model.nhead = args.num_heads
    config.model.num_layers = args.num_layers
    config.model.max_seq_len = args.max_seq_len

    # Training hyperparameters.
    config.training.epochs = args.epochs
    config.training.batch_size = args.batch_size
    config.training.learning_rate = args.learning_rate
    config.training.weight_decay = args.weight_decay
    config.training.gradient_clip_val = args.grad_clip
    config.training.warmup_steps = args.warmup_steps
    config.training.amp = args.use_amp
    config.training.compile_model = args.compile_model

    # Safety telemetry thresholds (K/C/S gates).
    config.safety.k_threshold = args.min_negentropy
    config.safety.c_threshold = args.max_complexity
    config.safety.s_threshold = args.min_symbiosis
    config.safety.enable_safety = args.enable_safety_gates

    # Data pipeline settings.
    config.data.dataset_path = Path(args.input_path) if args.input_path else None
    config.data.max_sequence_length = args.seq_length
    config.data.num_workers = args.num_workers

    # Output location, reproducibility, and device selection.
    config.output_dir = Path(args.output_path)
    config.seed = args.seed
    config.device = "cuda" if torch.cuda.is_available() else "cpu"
    return config


def train_cli() -> None:
    """CLI entry point for training BitTransformerLM models.

    Parses CLI arguments, builds the experiment config, trains on synthetic
    random-bit data, and saves the final weights to
    ``<output_dir>/model_final.pt``. Exits with status 1 on training failure.
    """
    parser = create_training_parser()
    args = parser.parse_args()

    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    config = _build_config(args)

    # Lazy %-style args avoid formatting when the level is disabled.
    logger.info("Starting training with config: %s", config.experiment_name)
    logger.info(
        "Model: %sd, %sL, %sH",
        config.model.d_model,
        config.model.num_layers,
        config.model.nhead,
    )
    logger.info("Device: %s", config.device)

    model = BitTransformerLM(**config.model.to_dict())
    model = model.to(config.device)

    # NOTE(review): training data is synthetic random bits; the dataset set
    # via --input-path (config.data.dataset_path) is never read here — confirm
    # whether real-data loading is intended.
    logger.info("Creating synthetic training data...")
    torch.manual_seed(config.seed)
    data = torch.randint(0, 2, (args.dataset_size, config.data.max_sequence_length))

    logger.info("Starting training...")
    try:
        train_loop(
            model,
            data,
            epochs=config.training.epochs,
            batch_size=config.training.batch_size,
            amp=config.training.amp,
            compile_model=config.training.compile_model,
            log=True,
        )

        save_path = config.output_dir / "model_final.pt"
        save_model(model, save_path)
        logger.info("Model saved to %s", save_path)

    except Exception:
        # logger.exception records the full traceback, which logger.error
        # with an f-string did not.
        logger.exception("Training failed")
        sys.exit(1)
|
|
|
|
def _diffusion_generate(model, args, logger) -> str:
    """Generate text via diffusion denoising over a bit sequence.

    Args:
        model: Loaded BitTransformerLM in eval mode.
        args: Parsed CLI namespace (prompt, max_tokens, diffusion_steps,
            noise_schedule).
        logger: Logger for progress messages.

    Returns:
        Decoded text from the generated bit sequence.
    """
    logger.info("Using diffusion inference mode")
    prompt_bits = text_to_bits(args.prompt)
    # 9 bits per generated character — presumably 8 data bits + parity as
    # encoded by bit_io; TODO(review): confirm against text_to_bits.
    length = len(prompt_bits) + args.max_tokens * 9

    generated_bits = diffusion_inference(
        model,
        length=length,
        steps=args.diffusion_steps,
        schedule=args.noise_schedule,
    )

    # Assumes diffusion_inference returns a batched tensor; decode the first
    # (and only) sample — verify against diffusion_inference's return shape.
    return bits_to_text(generated_bits[0].tolist())


def _autoregressive_generate(model, args) -> str:
    """Generate text autoregressively, with optional K/C/S safety gating."""
    if args.enable_safety_gates:
        return infer_text(
            model,
            args.prompt,
            c_floor=args.max_complexity,
            s_floor=args.min_symbiosis,
        )

    # Imported lazily so the ungated path stays optional at module load.
    from .bit_io import sample_text
    return sample_text(
        model,
        args.prompt,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
    )


def infer_cli() -> None:
    """CLI entry point for BitTransformerLM inference.

    Loads model weights, then generates text either by diffusion denoising
    (``--use-diffusion``) or autoregressive sampling, printing the result to
    stdout. Exits with status 1 on missing weights or generation failure.
    """
    parser = create_inference_parser()
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
    parser.add_argument("--max-tokens", type=int, default=50, help="Maximum tokens to generate")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
    parser.add_argument("--use-diffusion", action="store_true", help="Use diffusion mode")
    args = parser.parse_args()

    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    if not Path(args.weights_path).exists():
        logger.error("Model weights not found at %s", args.weights_path)
        sys.exit(1)

    logger.info("Loading model from %s", args.weights_path)
    model = load_model(args.weights_path)
    model.eval()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    logger.info("Model loaded on %s", device)
    logger.info("Prompt: %s", args.prompt)

    try:
        if args.use_diffusion:
            result = _diffusion_generate(model, args, logger)
        else:
            result = _autoregressive_generate(model, args)

        print(f"\nGenerated text:\n{result}")

    except Exception:
        # logger.exception records the full traceback, which logger.error
        # with an f-string did not.
        logger.exception("Inference failed")
        sys.exit(1)
|
|
|
|
def dashboard_cli() -> None:
    """Launch the interactive BitTransformerLM dashboard from the CLI.

    Builds the standard parser, registers host/port/share flags, and starts
    the dashboard server. Exits with status 1 if startup fails.
    """
    parser = BitTransformerCLI.create_standard_parser(
        "BitTransformerLM Dashboard",
        ["io"]
    )
    # Register the dashboard-specific flags in one pass.
    extra_flags = (
        ("--host", dict(type=str, default="127.0.0.1", help="Dashboard host")),
        ("--port", dict(type=int, default=7860, help="Dashboard port")),
        ("--share", dict(action="store_true", help="Create public link")),
    )
    for flag, options in extra_flags:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)
    logger.info(f"Starting BitTransformerLM dashboard on {args.host}:{args.port}")

    try:
        run_dashboard(host=args.host, port=args.port, share=args.share)
    except Exception as e:
        logger.error(f"Dashboard failed to start: {e}")
        sys.exit(1)
|
|
|
|
if __name__ == "__main__":
    # One module backs three console scripts: dispatch on which entry-point
    # name was used to invoke it.
    import os

    invoked_as = os.path.basename(sys.argv[0])
    dispatch = (
        ("train", train_cli),
        ("infer", infer_cli),
        ("dashboard", dashboard_cli),
    )
    for keyword, entry_point in dispatch:
        if keyword in invoked_as:
            entry_point()
            break
    else:
        # Unknown invocation name: list the supported commands and fail.
        print("Available commands:")
        print(" bit-transformer-train - Train a BitTransformerLM model")
        print(" bit-transformer-infer - Run inference with a trained model")
        print(" bit-transformer-dashboard - Launch interactive dashboard")
        sys.exit(1)