taylorliu2000 committed on
Commit 7af68b6 · verified · 1 Parent(s): 6b72783

Upload train_qwen_capybara_sft.py with huggingface_hub
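The commit message names huggingface_hub as the upload mechanism. A minimal sketch of what such a call might look like; this is a hypothetical reconstruction, not the author's actual invocation, and the repo_id (borrowed from the script's hub_model_id) is an assumption since the commit doesn't show which repo it targets:

from huggingface_hub import HfApi

# Hypothetical reconstruction of the upload step named in the commit message.
# repo_id is assumed; the diff itself does not reveal the destination repo.
api = HfApi()
api.upload_file(
    path_or_fileobj="train_qwen_capybara_sft.py",  # local file to upload
    path_in_repo="train_qwen_capybara_sft.py",     # destination path in the repo
    repo_id="taylorliu2000/qwen25-05b-capybara-sft",
    commit_message="Upload train_qwen_capybara_sft.py with huggingface_hub",
)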

Files changed (1)
  1. train_qwen_capybara_sft.py +100 -0
train_qwen_capybara_sft.py ADDED
@@ -0,0 +1,100 @@
+ #!/usr/bin/env python3
+ # /// script
+ # requires-python = ">=3.10"
+ # dependencies = [
+ #     "trl>=0.12.0",
+ #     "peft>=0.7.0",
+ #     "transformers>=4.36.0",
+ #     "accelerate>=0.24.0",
+ #     "trackio",
+ # ]
+ # ///
+
+ """
+ SFT fine-tuning Qwen2.5-0.5B on the Capybara dataset with LoRA.
+ """
+
+ import trackio
+ from datasets import load_dataset
+ from peft import LoraConfig
+ from trl import SFTTrainer, SFTConfig
+
+
+ # Load dataset
+ print("📦 Loading dataset...")
+ dataset = load_dataset("trl-lib/Capybara", split="train")
+ print(f"✅ Dataset loaded: {len(dataset)} examples")
+
+ # Create train/eval split
+ print("🔀 Creating train/eval split...")
+ dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
+ train_dataset = dataset_split["train"]
+ eval_dataset = dataset_split["test"]
+ print(f"   Train: {len(train_dataset)} examples")
+ print(f"   Eval: {len(eval_dataset)} examples")
+
+ # Training configuration
+ config = SFTConfig(
+     # Hub settings
+     output_dir="qwen25-05b-capybara-sft",
+     push_to_hub=True,
+     hub_model_id="taylorliu2000/qwen25-05b-capybara-sft",
+     hub_strategy="every_save",
+
+     # Training parameters
+     num_train_epochs=3,
+     per_device_train_batch_size=4,
+     gradient_accumulation_steps=4,
+     learning_rate=2e-5,
+
+     # Logging & checkpointing
+     logging_steps=10,
+     save_strategy="steps",
+     save_steps=100,
+     save_total_limit=2,
+
+     # Evaluation
+     eval_strategy="steps",
+     eval_steps=100,
+
+     # Optimization
+     warmup_ratio=0.1,
+     lr_scheduler_type="cosine",
+
+     # Monitoring
+     report_to="trackio",
+     project="qwen25-capybara-sft",
+     run_name="qwen25-05b-capybara-lora-r16",
+ )
+
+ # LoRA configuration
+ peft_config = LoraConfig(
+     r=16,
+     lora_alpha=32,
+     lora_dropout=0.05,
+     bias="none",
+     task_type="CAUSAL_LM",
+     target_modules=["q_proj", "v_proj"],
+ )
+
+ # Initialize and train
+ print("🎯 Initializing trainer...")
+ trainer = SFTTrainer(
+     model="Qwen/Qwen2.5-0.5B",
+     train_dataset=train_dataset,
+     eval_dataset=eval_dataset,
+     args=config,
+     peft_config=peft_config,
+ )
+
+ print("🚀 Starting training...")
+ trainer.train()
+
+ print("💾 Pushing to Hub...")
+ trainer.push_to_hub()
+
+ # Finish Trackio tracking
+ trackio.finish()
+
+ print("✅ Complete! Model at: https://huggingface.co/taylorliu2000/qwen25-05b-capybara-sft")
+ print("📊 View metrics at: https://huggingface.co/spaces/taylorliu2000/trackio")