| |
|
| | """ SincNet model """ |
| | from functools import lru_cache |
| | import numpy as np |
| | import logging |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
class SincNetFilterConvLayer(nn.Module):
    """SincNet band-pass filter layer: a Conv1d whose kernels are parameterized
    sinc band-pass filters defined by learnable (or fixed) low/high cutoffs.
    """

    def __init__(self, out_channels: int, kernel_size: int, sample_rate=16000,
                 stride=1, padding=0, dilation=1, min_low_hz=50, min_band_hz=50,
                 in_channels=1, requires_grad=False):
        """
        Args:
            out_channels : `int` number of filters.
            kernel_size : `int` filter length (forced odd so each filter has a
                well-defined center sample).
            sample_rate : `int`, optional sample rate. Defaults to 16000.
            stride / padding / dilation : passed through to `F.conv1d`.
            min_low_hz : lower bound (Hz) on each filter's low cutoff.
            min_band_hz : minimum bandwidth (Hz) of each filter.
            in_channels : must be 1 (mono waveform).
            requires_grad : whether the cutoff frequencies are learnable.

        Raises:
            ValueError: if ``in_channels`` is not 1.
        """
        super().__init__()

        if in_channels != 1:
            raise ValueError(f"SincNetFilterConvLayer only support in_channels = 1, was in_channels = {in_channels}")

        self._out_channels = out_channels
        self._kernel_size = kernel_size
        # Force an odd kernel length: the impulse response is built as
        # left half + center sample + mirrored right half.
        if kernel_size % 2 == 0:
            self._kernel_size += 1

        self._stride = stride
        self._padding = padding
        self._dilation = dilation
        self._sample_rate = sample_rate
        self._min_low_hz = min_low_hz
        self._min_band_hz = min_band_hz

        # Initialize cutoffs equally spaced on the mel scale between 30 Hz and
        # the Nyquist frequency (minus the reserved low/band margins).
        low_hz = 30
        high_hz = self._sample_rate / 2 - (self._min_low_hz + self._min_band_hz)
        # BUG FIX: need out_channels + 1 mel points so that hz[:-1] / np.diff(hz)
        # yield exactly out_channels filters. The previous
        # `out_channels // 2 + 1` produced only out_channels // 2 filters and
        # made the final .view(out_channels, 1, kernel_size) fail.
        mel = np.linspace(
            2595 * np.log10(1 + low_hz / 700),
            2595 * np.log10(1 + high_hz / 700),
            self._out_channels + 1,
        )
        hz = 700 * (10 ** (mel / 2595) - 1)

        # Per-filter low cutoff (Hz) and bandwidth (Hz), shape (out_channels, 1).
        self._low_hz = nn.Parameter(
            torch.Tensor(hz[:-1]).view(-1, 1),
            requires_grad=requires_grad,
        )
        self._band_hz = nn.Parameter(
            torch.Tensor(np.diff(hz)).view(-1, 1),
            requires_grad=requires_grad,
        )
        # Left half of a Hamming window (the right half is obtained by mirroring).
        self.register_buffer(
            "_window",
            torch.from_numpy(np.hamming(self._kernel_size)[: self._kernel_size // 2]).float(),
        )
        # Time axis (scaled by 2*pi / sample_rate) for the left half of the kernel.
        self.register_buffer(
            "_n",
            (2 * np.pi * torch.arange(-(self._kernel_size // 2), 0.0).view(1, -1) / self._sample_rate),
        )

    @property
    def filters(self) -> torch.Tensor:
        """(out_channels, 1, kernel_size) sinc band-pass filter bank.

        Recomputed on every access on purpose: caching (the previous
        `lru_cache`) returned stale tensors after the module was moved to
        another device and froze the bank when the cutoffs are learnable.
        """
        # Effective low/high cutoffs, kept inside [min_low_hz, Nyquist].
        low = self._min_low_hz + torch.abs(self._low_hz)
        high = torch.clamp(low + self._min_band_hz + torch.abs(self._band_hz), self._min_low_hz, self._sample_rate / 2)
        band = (high - low)[:, 0]

        f_times_t_low = torch.matmul(low, self._n)
        f_times_t_high = torch.matmul(high, self._n)

        # Windowed left half of the symmetric band-pass impulse response.
        band_pass_left = ((torch.sin(f_times_t_high) - torch.sin(f_times_t_low)) / (self._n / 2)) * self._window
        band_pass_center = 2 * band.view(-1, 1)
        band_pass_right = torch.flip(band_pass_left, dims=[1])

        band_pass = torch.cat([band_pass_left, band_pass_center, band_pass_right], dim=1)
        # Normalize each filter by twice its bandwidth.
        band_pass = band_pass / (2 * band[:, None])
        return band_pass.view(self._out_channels, 1, self._kernel_size)

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveforms : (batch_size, 1, n_samples) batch of waveforms.

        Returns:
            features : (batch_size, out_channels, n_samples_out) batch of
                sinc filter activations (absolute value of the convolution).
        """
        return F.conv1d(waveforms, self.filters, stride=self._stride,
                        padding=self._padding, dilation=self._dilation,
                        ).abs_()
| |
|
class SincNet(nn.Module):
    """SincNet feature extractor: a sinc-filter conv layer followed by two
    standard Conv1d layers, each with max-pooling, instance norm and
    leaky-ReLU (Ravanelli & Bengio style front-end).
    """

    def __init__(
        self,
        num_sinc_filters: int = 80,
        sinc_filter_length: int = 251,
        num_conv_filters: int = 60,
        conv_filter_length: int = 5,
        pool_kernel_size: int = 3,
        pool_stride: int = 3,
        sample_rate: int = 16000,
        sinc_filter_stride: int = 10,
        sinc_filter_padding: int = 0,
        sinc_filter_dilation: int = 1,
        min_low_hz: int = 50,
        min_band_hz: int = 50,
        sinc_filter_in_channels: int = 1,
        num_wavform_channels: int = 1,
    ):
        """
        Args:
            num_sinc_filters : number of sinc band-pass filters in layer 1.
            sinc_filter_length : sinc kernel length (samples).
            num_conv_filters : channels of the two plain Conv1d layers.
            conv_filter_length : kernel length of the plain Conv1d layers.
            pool_kernel_size / pool_stride : max-pool parameters (all stages).
            sample_rate : must be 16000 (only supported rate).
            sinc_filter_stride / padding / dilation : sinc conv parameters.
            min_low_hz / min_band_hz : sinc cutoff constraints (Hz).
            sinc_filter_in_channels : input channels of the sinc layer (1).
            num_wavform_channels : waveform channels for input normalization.

        Raises:
            NotImplementedError: if ``sample_rate`` is not 16000.
        """
        super().__init__()

        if sample_rate != 16000:
            raise NotImplementedError(f"SincNet only supports 16kHz audio (sample_rate = 16000), was sample_rate = {sample_rate}")

        # Per-channel normalization of the raw waveform.
        self.wav_norm1d = nn.InstanceNorm1d(num_wavform_channels, affine=True)

        # Stage 1 is the sinc filter bank; stages 2-3 are plain convolutions.
        self.conv1d = nn.ModuleList([
            SincNetFilterConvLayer(
                num_sinc_filters,
                sinc_filter_length,
                sample_rate=sample_rate,
                stride=sinc_filter_stride,
                padding=sinc_filter_padding,
                dilation=sinc_filter_dilation,
                min_low_hz=min_low_hz,
                min_band_hz=min_band_hz,
                in_channels=sinc_filter_in_channels,
            ),
            nn.Conv1d(num_sinc_filters, num_conv_filters, conv_filter_length),
            nn.Conv1d(num_conv_filters, num_conv_filters, conv_filter_length),
        ])
        self.pool1d = nn.ModuleList([
            nn.MaxPool1d(pool_kernel_size, stride=pool_stride),
            nn.MaxPool1d(pool_kernel_size, stride=pool_stride),
            nn.MaxPool1d(pool_kernel_size, stride=pool_stride),
        ])
        self.norm1d = nn.ModuleList([
            nn.InstanceNorm1d(num_sinc_filters, affine=True),
            nn.InstanceNorm1d(num_conv_filters, affine=True),
            nn.InstanceNorm1d(num_conv_filters, affine=True),
        ])

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveforms : (batch, channel, sample) input waveforms.

        Returns:
            (batch, num_conv_filters, frame) features after the three
            conv -> pool -> norm -> leaky-ReLU stages.
        """
        outputs = self.wav_norm1d(waveforms)

        # Idiom fix: the original wrapped this zip in enumerate() and
        # discarded the index; iterate the stages directly.
        for conv1d, pool1d, norm1d in zip(self.conv1d, self.pool1d, self.norm1d):
            outputs = conv1d(outputs)
            outputs = F.leaky_relu(norm1d(pool1d(outputs)))

        return outputs
| |
|