| | import os |
| | import errno |
| | import numpy as np |
| |
|
| | from copy import deepcopy |
| | from miscc.config import cfg |
| | from scipy.io.wavfile import write |
| | from torch.nn import init |
| | import torch |
| | import torch.nn as nn |
| | import torchvision.utils as vutils |
| | from wavefile import WaveWriter, Format |
| | import RT60 |
| | from multiprocessing import Pool |
| |
|
| |
|
| | |
def KL_loss(mu, logvar):
    """Return the mean KL-divergence term -0.5 * mean(1 + logvar - mu^2 - exp(logvar)).

    Args:
        mu: tensor of means.
        logvar: tensor of log-variances, same shape as ``mu``.

    Returns:
        A scalar tensor. Neither input is modified (all intermediates are new tensors).
    """
    # Negated sum first, then +1 and +logvar: same float evaluation order as before.
    negated = -(mu.pow(2) + logvar.exp())
    per_element = (negated + 1) + logvar
    return torch.mean(per_element) * -0.5
| |
|
| |
|
def compute_discriminator_loss(netD, real_RIRs, fake_RIRs,
                               real_labels, fake_labels,
                               conditions, gpus):
    """Compute the discriminator's BCE loss on real, fake and mismatched pairs.

    Three conditional terms are computed via ``netD.get_cond_logits``:
      * real features with their own conditions, labelled ``real_labels``;
      * "wrong" pairs: real features of sample i with the condition of sample
        i+1 (``real[:-1]`` vs ``cond[1:]``), labelled ``fake_labels[1:]``;
      * fake features (generator output detached) with conditions, labelled
        ``fake_labels``.
    If ``netD.get_uncond_logits`` is not None, unconditional real/fake terms are
    added and averaged in.

    Note: ``nn.BCELoss`` expects probabilities, so the logits helpers are
    presumably sigmoid-terminated — confirm in the model definition.

    Returns:
        (errD, errD_real.data, errD_wrong.data, errD_fake.data). In the
        unconditional branch the returned real/fake parts are the averages of
        their conditional and unconditional terms.
    """
    bce = nn.BCELoss()
    n = real_RIRs.size(0)
    cond = conditions.detach()

    def run(fn, args):
        return nn.parallel.data_parallel(fn, args, gpus)

    real_feat = run(netD, real_RIRs)
    # Detach so the discriminator step does not backpropagate into the generator.
    fake_feat = run(netD, fake_RIRs.detach())

    loss_real = bce(run(netD.get_cond_logits, (real_feat, cond)), real_labels)
    loss_wrong = bce(run(netD.get_cond_logits, (real_feat[:(n - 1)], cond[1:])),
                     fake_labels[1:])
    loss_fake = bce(run(netD.get_cond_logits, (fake_feat, cond)), fake_labels)

    if netD.get_uncond_logits is None:
        total = loss_real + (loss_fake + loss_wrong) * 0.5
        return total, loss_real.data, loss_wrong.data, loss_fake.data

    uncond_real_logits = run(netD.get_uncond_logits, real_feat)
    uncond_fake_logits = run(netD.get_uncond_logits, fake_feat)
    uncond_real = bce(uncond_real_logits, real_labels)
    uncond_fake = bce(uncond_fake_logits, fake_labels)

    total = ((loss_real + uncond_real) / 2. +
             (loss_fake + loss_wrong + uncond_fake) / 3.)
    loss_real = (loss_real + uncond_real) / 2.
    loss_fake = (loss_fake + uncond_fake) / 2.
    return total, loss_real.data, loss_wrong.data, loss_fake.data
| | |
| |
|
| |
|
| |
|
def _rt60_error(real_RIRs, fake_RIRs, channel=12, fs=16000):
    """Average RT60 discrepancy over a random window of consecutive RIRs.

    A window of up to ``channel`` consecutive samples is taken from the batch.
    Each sample is flattened to one row, and ``RT60.t60_parallel(n, real, fake, fs)``
    is evaluated for every row index ``n`` in a worker process. The results are
    summed and divided by the window size.

    Args:
        real_RIRs: batch of real impulse responses (torch tensor, batch first).
        fake_RIRs: batch of generated impulse responses, same layout.
        channel: maximum window size; also the number of worker processes.
        fs: sampling rate forwarded to ``RT60.t60_parallel``.

    Returns:
        The mean of the values returned by ``RT60.t60_parallel`` (0 for an empty
        batch). It is computed on detached numpy copies, so no gradient flows
        through it.
    """
    sample_size = real_RIRs.size(0)
    # Batches smaller than `channel` used to crash in reshape; shrink the window.
    channel = min(channel, sample_size)
    if channel <= 0:
        return 0
    # Keep the original [0, sample_size - 2*channel) start range so the numpy RNG
    # stream is unchanged for normal batch sizes. Never pass a non-positive bound
    # to randint: batch sizes <= 2*channel used to raise ValueError here.
    start = np.random.randint(max(sample_size - channel * 2, 1))
    window = slice(start, start + channel)
    # One row per sample. Use -1 rather than a hard-coded 4096 so other RIR lengths work.
    real_wave = real_RIRs[window].detach().cpu().numpy().reshape(channel, -1)
    fake_wave = fake_RIRs[window].detach().cpu().numpy().reshape(channel, -1)

    # Pool.__exit__ calls terminate(). Every result has been collected by then,
    # and the pool is now released even if a worker raises (the original leaked it).
    with Pool(processes=channel) as pool:
        pending = [
            pool.apply_async(RT60.t60_parallel, args=(n, real_wave, fake_wave, fs))
            for n in range(channel)
        ]
        total = sum(result.get() for result in pending)
    return total / channel


def compute_generator_loss(epoch, netD, real_RIRs, fake_RIRs, real_labels, conditions, gpus):
    """Compute the generator loss: adversarial BCE + weighted MSE + weighted RT60 error.

    Args:
        epoch: unused; kept for interface compatibility.
        netD: discriminator exposing ``get_cond_logits`` and ``get_uncond_logits``
            (the latter may be None).
        real_RIRs: batch of real impulse responses.
        fake_RIRs: batch of generated impulse responses (not detached, so
            gradients reach the generator).
        real_labels: targets used for the adversarial terms (the generator tries
            to be classified as real).
        conditions: conditioning tensor; detached before use.
        gpus: device ids forwarded to ``nn.parallel.data_parallel``.

    Returns:
        (errG, l1_error, RT_error):
          * ``errG`` = BCE(cond logits) + 5*4096*MSE + 40*RT_error
            (+ BCE(uncond logits) when available);
          * ``l1_error`` = ``nn.L1Loss`` between real and fake RIRs. It is
            reported only and is not part of ``errG``;
          * ``RT_error`` from ``_rt60_error``. It is a constant with respect to
            autograd, so it shifts the loss value but contributes no gradient.
    """
    criterion = nn.BCELoss()
    cond = conditions.detach()

    fake_features = nn.parallel.data_parallel(netD, fake_RIRs, gpus)
    fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, (fake_features, cond), gpus)

    # Originally named MSE_error, but it is nn.L1Loss (mean absolute error).
    l1_error = nn.L1Loss()(real_RIRs, fake_RIRs)
    mse_error = nn.MSELoss()(real_RIRs, fake_RIRs)
    RT_error = _rt60_error(real_RIRs, fake_RIRs, channel=12, fs=16000)

    # NOTE(review): 4096 presumably rescales the per-element MSE mean by the RIR
    # length — confirm against the dataset.
    errG = criterion(fake_logits, real_labels) + 5 * 4096 * mse_error + 40 * RT_error
    if netD.get_uncond_logits is not None:
        uncond_logits = nn.parallel.data_parallel(netD.get_uncond_logits, fake_features, gpus)
        errG += criterion(uncond_logits, real_labels)
    return errG, l1_error, RT_error
| |
|
| |
|
| | |
def weights_init(m):
    """Re-initialise one module's parameters in place, chosen by its class name.

    Presumably used as ``net.apply(weights_init)`` — confirm at the call sites.

    * Class names containing 'Conv': weight ~ N(0, 0.02); bias is left untouched.
    * Class names containing 'BatchNorm': weight ~ N(1, 0.02), bias = 0.
    * Class names containing 'Linear': weight ~ N(0, 0.02), bias = 0.

    A missing or None weight/bias is skipped. This covers ``BatchNorm*(affine=False)``
    and custom container modules whose names merely contain these substrings;
    the original raised AttributeError in both cases.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)
    elif 'Linear' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
        if bias is not None:
            bias.data.fill_(0.0)
| |
|
| |
|
| | |
def _write_rir_wav(path, rir, fs):
    """Write one RIR tensor to a mono WAV file at ``path`` and close the file."""
    # The with-block closes the writer deterministically so the file is finalized
    # on disk. The original never closed any WaveWriter.
    with WaveWriter(path, channels=1, samplerate=fs) as writer:
        writer.write(np.array(rir.to("cpu").detach()))


def save_RIR_results(data_RIR, fake, epoch, RIR_dir):
    """Write up to ``cfg.VIS_COUNT`` RIRs as 16 kHz mono WAV files into ``RIR_dir``.

    When ``data_RIR`` is given, each index i produces ``real_sample{i}.wav`` and
    ``fake_sample{i}_epoch_{epoch}.wav``. Otherwise only
    ``small_fake_sample{i}_epoch_{epoch}.wav`` is written.

    The number of files is bounded by the available samples, so a batch smaller
    than ``cfg.VIS_COUNT`` no longer raises IndexError.
    """
    fs = 16000
    num = cfg.VIS_COUNT
    fake = fake[0:num]

    if data_RIR is not None:
        data_RIR = data_RIR[0:num]
        for i in range(min(len(data_RIR), len(fake))):
            real_RIR_path = RIR_dir + "/real_sample" + str(i) + ".wav"
            fake_RIR_path = RIR_dir + "/fake_sample" + str(i) + "_epoch_" + str(epoch) + ".wav"
            _write_rir_wav(real_RIR_path, data_RIR[i], fs)
            _write_rir_wav(fake_RIR_path, fake[i], fs)
    else:
        for i in range(len(fake)):
            fake_RIR_path = RIR_dir + "/small_fake_sample" + str(i) + "_epoch_" + str(epoch) + ".wav"
            _write_rir_wav(fake_RIR_path, fake[i], fs)
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| |
|
def save_model(netG, netD, epoch, model_dir):
    """Save both networks' state dicts under ``model_dir``.

    The generator file carries the epoch number (``netG_epoch_<epoch>.pth``).
    The discriminator always overwrites ``netD_epoch_last.pth``.
    """
    generator_path = '%s/netG_epoch_%d.pth' % (model_dir, epoch)
    discriminator_path = '%s/netD_epoch_last.pth' % (model_dir)
    for net, path in ((netG, generator_path), (netD, discriminator_path)):
        torch.save(net.state_dict(), path)
| | |
| |
|
| |
|
def mkdir_p(path):
    """Create ``path`` and any missing parents, like ``mkdir -p``.

    Succeeds silently when ``path`` already exists as a directory. Any other
    OSError is re-raised, including EEXIST when ``path`` is an existing
    non-directory.
    """
    try:
        os.makedirs(path)
    except OSError as err:
        already_a_dir = err.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_dir:
            raise
| |
|