---
language: en
license: mit
tags:
- audio
- audio-classification
- musical-instruments
- wav2vec2
- transformers
- pytorch
datasets:
- custom
metrics:
- accuracy
- roc_auc
model-index:
- name: epoch_musical_instruments_identification_2
  results:
  - task:
      type: audio-classification
      name: Musical Instrument Classification
    metrics:
    - type: accuracy
      value: 0.9333
      name: Accuracy
    - type: roc_auc
      value: 0.9859
      name: ROC AUC (Macro)
    - type: loss
      value: 1.0639
      name: Validation Loss
base_model:
- facebook/wav2vec2-base-960h
---

# Musical Instrument Classification Model

This model is a fine-tuned version of [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) for musical instrument classification. It identifies 9 musical instruments from audio recordings and reaches 93.33% accuracy on the evaluation set.

## Model Description

- **Model type:** Audio Classification
- **Base model:** facebook/wav2vec2-base-960h
- **Language:** Audio (no specific language)
- **License:** MIT
- **Fine-tuned on:** Custom musical instrument dataset (200 samples per class)

## Performance

After 5 epochs of training, the model reaches the following results on the evaluation set:

- **Final Accuracy:** 93.33%
- **Final ROC AUC (Macro):** 0.9859
- **Final Validation Loss:** 1.0639
- **Evaluation Runtime:** 14.18 seconds
- **Evaluation Speed:** 25.39 samples/second
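
As a rough consistency check, the runtime and throughput above imply the size of the evaluation set. The sketch below is plain arithmetic on figures from this card. The count of 40 held-out files per class is an assumption derived from the 200-samples-per-class and files-below-160 statements, not a number reported by the author.

```python
# Figures reported in this card
eval_runtime_s = 14.18        # evaluation runtime in seconds
eval_samples_per_s = 25.39    # evaluation throughput

# Implied number of evaluation clips (runtime x throughput)
implied_eval_samples = round(eval_runtime_s * eval_samples_per_s)

# Assumption: 200 files per class, with files numbered below 160 used for training
num_classes = 9
held_out_per_class = 200 - 160
expected_eval_samples = num_classes * held_out_per_class

print(implied_eval_samples, expected_eval_samples)
```

Both values come out to 360, which is consistent with 40 held-out files for each of the 9 classes.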

### Training Progress

| Epoch | Training Loss | Validation Loss | ROC AUC | Accuracy |
|-------|---------------|-----------------|---------|----------|
| 1 | 1.9872 | 1.8875 | 0.9248 | 0.6639 |
| 2 | 1.8652 | 1.4793 | 0.9799 | 0.8000 |
| 3 | 1.3868 | 1.2311 | 0.9861 | 0.8194 |
| 4 | 1.3242 | 1.1121 | 0.9827 | 0.9250 |
| 5 | 1.1869 | 1.0639 | 0.9859 | 0.9333 |

## Supported Instruments

The model can classify the following 9 musical instruments:

1. **Acoustic Guitar**
2. **Bass Guitar**
3. **Drum Set**
4. **Electric Guitar**
5. **Flute**
6. **Hi-Hats**
7. **Keyboard**
8. **Trumpet**
9. **Violin**
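
Class indices map to instrument names through `id2label` and `label2id` in the model config. A minimal sketch of such a mapping follows. It assumes the ids follow the order of the list above and that the label strings match these names exactly, neither of which this card confirms, so `model.config.id2label` on the loaded model remains the authoritative source. The `top_k` helper is hypothetical.

```python
# Instrument names as listed in this card
INSTRUMENTS = [
    "Acoustic Guitar", "Bass Guitar", "Drum Set", "Electric Guitar", "Flute",
    "Hi-Hats", "Keyboard", "Trumpet", "Violin",
]

# Assumption: ids follow the list order; the real mapping lives in the model config
id2label = {i: name for i, name in enumerate(INSTRUMENTS)}
label2id = {name: i for i, name in id2label.items()}


def top_k(probabilities, k=3):
    """Return the k most likely (label, probability) pairs, highest first."""
    ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)
    return [(id2label[i], p) for i, p in ranked[:k]]


# Example with made-up probabilities for the 9 classes
example = [0.05, 0.05, 0.10, 0.02, 0.60, 0.03, 0.05, 0.05, 0.05]
print(top_k(example, k=2))
```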

## Usage

### Quick Start with Pipeline

```python
from transformers import pipeline
import torchaudio

# Load the classification pipeline
classifier = pipeline("audio-classification", model="Bhaveen/epoch_musical_instruments_identification_2")

# Load the audio, resample to 16 kHz, and keep the first 48,000 samples (3 seconds)
audio, rate = torchaudio.load("your_audio_file.wav")
transform = torchaudio.transforms.Resample(rate, 16000)
audio = transform(audio).numpy().reshape(-1)[:48000]

# Classify the audio
result = classifier(audio)
print(result)
```

### Using Transformers Directly

```python
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import torchaudio
import torch

# Load model and feature extractor
model_name = "Bhaveen/epoch_musical_instruments_identification_2"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForAudioClassification.from_pretrained(model_name)

# Load the audio, resample to 16 kHz, and keep the first 48,000 samples (3 seconds)
audio, rate = torchaudio.load("your_audio_file.wav")
transform = torchaudio.transforms.Resample(rate, 16000)
audio = transform(audio).numpy().reshape(-1)[:48000]

# Extract features and run the model without tracking gradients
inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Convert logits to probabilities and pick the most likely class
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_class = torch.argmax(predictions, dim=-1)

print(f"Predicted instrument: {model.config.id2label[predicted_class.item()]}")
```

## Training Details

### Dataset and Preprocessing

- **Custom dataset** with audio recordings of 9 musical instruments
- **Train/Test Split:** 80/20 by file number (files numbered below 160 are used for training, the rest for testing)
- **Data Balancing:** Random oversampling applied to minority classes
- **Audio Preprocessing:**
  - Resampling to 16,000 Hz
  - Fixed length of 48,000 samples (3 seconds)
  - Truncation of longer audio files
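
The split and balancing steps above can be sketched as follows. The file-naming pattern (`<label>_<number>.wav`), the helper names, and the pure-Python oversampler are illustrative assumptions. The card lists `imblearn` among its libraries, so the original training code likely used that package rather than this stand-in.

```python
import random
import re
from collections import defaultdict
from pathlib import Path


def split_by_file_number(paths, threshold=160):
    """Files numbered below `threshold` go to train, the rest to test."""
    train, test = [], []
    for path in paths:
        # Hypothetical naming scheme: the file stem ends with the file number
        match = re.search(r"(\d+)$", Path(path).stem)
        if match is None:
            raise ValueError(f"no trailing file number in {path!r}")
        (train if int(match.group(1)) < threshold else test).append(path)
    return train, test


def random_oversample(items, labels, seed=0):
    """Duplicate random minority-class items until every class matches the largest."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for item, label in zip(items, labels):
        by_label[label].append(item)
    target = max(len(group) for group in by_label.values())
    out_items, out_labels = [], []
    for label, group in by_label.items():
        extra = [rng.choice(group) for _ in range(target - len(group))]
        out_items.extend(group + extra)
        out_labels.extend([label] * target)
    return out_items, out_labels


# Example with made-up file names: 200 files of one class
paths = [f"flute_{i:03d}.wav" for i in range(200)]
train, test = split_by_file_number(paths)
print(len(train), len(test))
```

With 200 files per class, the threshold of 160 yields the 80/20 split described above.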

### Training Configuration

```python
# Training hyperparameters
batch_size = 1
gradient_accumulation_steps = 4
learning_rate = 5e-6
num_train_epochs = 5
warmup_steps = 50
weight_decay = 0.02
```
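
For orientation, these hyperparameters imply an effective batch size and a number of optimizer steps. The sketch below works through that arithmetic. The training-set size of 1,440 clips (9 classes × 160 files) is an assumption inferred from the split description, and oversampling could change it.

```python
batch_size = 1
gradient_accumulation_steps = 4
num_train_epochs = 5
warmup_steps = 50

# Assumption: 9 classes x 160 training files each
num_train_samples = 9 * 160

# Gradients from 4 single-clip batches are accumulated before each optimizer step
effective_batch_size = batch_size * gradient_accumulation_steps
optimizer_steps_per_epoch = num_train_samples // effective_batch_size
total_optimizer_steps = optimizer_steps_per_epoch * num_train_epochs
warmup_fraction = warmup_steps / total_optimizer_steps

print(effective_batch_size, optimizer_steps_per_epoch, total_optimizer_steps)
print(f"warmup covers {warmup_fraction:.1%} of training")
```

Under that assumption the run takes 360 optimizer steps per epoch and 1,800 in total, so the 50 warmup steps cover a little under 3% of training.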

### Model Architecture

- **Base Model:** facebook/wav2vec2-base-960h
- **Classification Head:** Added for 9-class classification
- **Parameters:** ~95M trainable parameters
- **Features:** Wav2Vec2 audio representations with fine-tuned classification layer

## Technical Specifications

- **Audio Format:** WAV files
- **Sample Rate:** 16,000 Hz
- **Input Length:** 3 seconds (48,000 samples)
- **Model Framework:** PyTorch + Transformers
- **Inference Device:** GPU recommended (CUDA)

## Evaluation Metrics

The model is evaluated with the following metrics:

- **Accuracy:** Standard classification accuracy
- **ROC AUC:** Macro-averaged ROC AUC with one-vs-rest approach
- **Multi-class Classification:** Softmax probabilities over all 9 instrument classes
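
To make these metrics concrete, here is a minimal pure-Python sketch of accuracy and macro-averaged one-vs-rest ROC AUC computed from softmax probabilities. It is an illustration on toy data with 3 classes, not the author's evaluation code. The function names are hypothetical, and AUC is computed with the pairwise-ranking definition rather than a library call.

```python
import math


def softmax(logits):
    """Turn one row of logits into probabilities."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def accuracy(y_true, probs):
    """Fraction of samples whose highest-probability class is the true class."""
    preds = [max(range(len(p)), key=p.__getitem__) for p in probs]
    return sum(int(p == t) for p, t in zip(preds, y_true)) / len(y_true)


def binary_auc(is_positive, scores):
    """Probability that a random positive outscores a random negative (ties count half)."""
    pos = [s for s, flag in zip(scores, is_positive) if flag]
    neg = [s for s, flag in zip(scores, is_positive) if not flag]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def macro_ovr_auc(y_true, probs):
    """Average the one-vs-rest AUC over all classes (each class must occur in y_true)."""
    n_classes = len(probs[0])
    aucs = [
        binary_auc([t == c for t in y_true], [p[c] for p in probs])
        for c in range(n_classes)
    ]
    return sum(aucs) / n_classes


# Toy example: 4 clips, 3 classes; the last clip is misclassified
y_true = [0, 1, 2, 1]
logits = [[2.0, 0.5, 0.1], [0.2, 1.5, 0.3], [0.1, 0.4, 2.2], [1.0, 0.8, 0.1]]
probs = [softmax(row) for row in logits]
print(accuracy(y_true, probs), macro_ovr_auc(y_true, probs))
```

On this toy data accuracy is 0.75 while the macro AUC is 1.0, because each clip of a given class still scores higher for that class than every clip of another class does. This illustrates how ROC AUC can sit above accuracy, as it does in the results reported in this card.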

## Limitations and Considerations

1. **Audio Duration:** The model was trained on 3-second clips. The preprocessing truncates longer audio, and shorter clips may be classified less reliably.
2. **Single Instrument Focus:** The model is optimized for single-instrument recordings; mixed instruments may produce uncertain results.
3. **Audio Quality:** Performance depends on audio quality and recording conditions.
4. **Sample Rate:** Input must be resampled to 16 kHz for optimal performance.
5. **Domain Specificity:** The model was trained on specific instrument recordings and may not generalize to all variants or playing styles.
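
For the duration limitation, one common workaround is to force every clip to the expected length before inference. The sketch below truncates long clips and zero-pads short ones. Zero-padding is an assumption on my part: the card only mentions truncation during training, so padded clips may still be classified less reliably. The `fix_length` helper is hypothetical.

```python
import numpy as np

TARGET_LEN = 48_000  # 3 seconds at 16 kHz, as stated in this card


def fix_length(samples, target_len=TARGET_LEN):
    """Return a 1-D float32 array of exactly `target_len` samples."""
    samples = np.asarray(samples, dtype=np.float32).reshape(-1)
    if samples.shape[0] >= target_len:
        # Keep only the first 3 seconds, matching the truncation described above
        return samples[:target_len]
    # Pad the end with zeros (silence) up to the target length
    return np.pad(samples, (0, target_len - samples.shape[0]))


# Example: a 1-second clip becomes 3 seconds, a 5-second clip is cut to 3
print(fix_length(np.ones(16_000)).shape, fix_length(np.ones(80_000)).shape)
```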

## Training Environment

- **Platform:** Google Colab
- **GPU:** CUDA-enabled device
- **Libraries:**
  - transformers==4.28.1
  - torchaudio==0.12
  - datasets
  - evaluate
  - imbalanced-learn (`imblearn`)

## Model Files

The repository contains:

- Model weights and configuration
- Feature extractor configuration
- Training logs and metrics
- Label mappings (id2label, label2id)

---

*Model trained as part of a hackathon project*