def train():
    parser = argparse.ArgumentParser(
        description='Parameters for training Deep Clustering')
    parser.add_argument('--opt',
                        type=str,
                        help='Path to option YAML file.',
                        default='./config/train.yml')
    args = parser.parse_args()
    opt = option.parse(args.opt)

    set_logger.setup_logger(opt['logger']['name'],
                            opt['logger']['path'],
                            screen=opt['logger']['screen'],
                            tofile=opt['logger']['tofile'])
    logger = logging.getLogger(opt['logger']['name'])

    logger.info("Building the model of Deep Clustering")
    dpcl = model.DPCL(**opt['DPCL'])

    logger.info("Building the optimizer of Deep Clustering")
    optimizer = make_optimizer(dpcl.parameters(), opt)

    logger.info('Building the dataloader of Deep Clustering')
    train_dataloader, val_dataloader = make_dataloader(opt)
    logger.info('Train Datasets Length: {}, Val Datasets Length: {}'.format(
        len(train_dataloader), len(val_dataloader)))

    logger.info('Building the Trainer of Deep Clustering')
    trainer = Trainer(train_dataloader, val_dataloader, dpcl, optimizer, opt)
    trainer.run()
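# NOTE: make_optimizer is not shown in this excerpt. Below is a minimal,
# hypothetical sketch of such a helper, assuming opt['optim'] carries keys
# like 'name', 'lr' and 'weight_decay' (these key names are assumptions,
# not taken from the repository's YAML files).
import torch.optim as optim


def make_optimizer(params, opt):
    optim_opt = opt['optim']
    supported = {'sgd': optim.SGD, 'adam': optim.Adam, 'rmsprop': optim.RMSprop}
    optimizer_cls = supported[optim_opt['name'].lower()]
    # weight decay falls back to 0.0 if the key is absent
    return optimizer_cls(params,
                         lr=optim_opt['lr'],
                         weight_decay=optim_opt.get('weight_decay', 0.0))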
def __init__(self, mix_path, yaml_path, model, gpuid):
    super(Separation, self).__init__()
    self.mix = read_wav(mix_path)
    opt = parse(yaml_path, is_tain=False)

    net = ConvTasNet(**opt['Conv_Tasnet'])
    dicts = torch.load(model, map_location='cpu')
    net.load_state_dict(dicts["model_state_dict"])

    self.logger = get_logger(__name__)
    self.logger.info('Load checkpoint from {}, epoch {: d}'.format(
        model, dicts["epoch"]))

    self.net = net.cuda()
    self.device = torch.device(
        'cuda:{}'.format(gpuid[0]) if len(gpuid) > 0 else 'cpu')
    self.gpuid = tuple(gpuid)
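# A usage sketch for the constructor above; the wav path, option file and
# checkpoint path are placeholders, and nothing beyond the __init__ signature
# shown in this excerpt is assumed.
separation = Separation(mix_path='./test/mix.wav',
                        yaml_path='./config/train.yml',
                        model='./checkpoint/best.pt',
                        gpuid=[0])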
def __init__(self, mix_json, aux_json, ref_json, yaml_path, model, gpuid):
    super(Separation, self).__init__()
    opt = parse(yaml_path)
    self.data_loader = Datasets_Test(
        mix_json,
        aux_json,
        ref_json,
        sample_rate=opt['datasets']['audio_setting']['sample_rate'],
        model=opt['model']['MODEL'])
    self.mix_infos = self.data_loader.mix
    self.total_wavs = len(self.mix_infos)
    self.model = opt['model']['MODEL']

    # Extraction and Suppression model
    if self.model == 'DPRNN_Speaker_Extraction' or self.model == 'DPRNN_Speaker_Suppression':
        net = model_function.Extractin_Suppression_Model(
            **opt['Dual_Path_Aux_Speaker'])
    # Separation model
    if self.model == 'DPRNN_Speech_Separation':
        net = model_function.Speech_Serapation_Model(
            **opt['Dual_Path_Aux_Speaker'])

    net = torch.nn.DataParallel(net)
    dicts = torch.load(model, map_location='cpu')
    net.load_state_dict(dicts["model_state_dict"])

    setup_logger(opt['logger']['name'],
                 opt['logger']['path'],
                 screen=opt['logger']['screen'],
                 tofile=opt['logger']['tofile'])
    self.logger = logging.getLogger(opt['logger']['name'])
    self.logger.info('Load checkpoint from {}, epoch {: d}'.format(
        model, dicts["epoch"]))

    self.net = net.cuda()
    self.device = torch.device(
        'cuda:{}'.format(gpuid[0]) if len(gpuid) > 0 else 'cpu')
    self.gpuid = tuple(gpuid)
    self.Output_wav = False
def train():
    parser = argparse.ArgumentParser(
        description='Parameters for training Conv-TasNet')
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt)

    set_logger.setup_logger(opt['logger']['name'],
                            opt['logger']['path'],
                            screen=opt['logger']['screen'],
                            tofile=opt['logger']['tofile'])
    logger = logging.getLogger(opt['logger']['name'])

    # build model
    logger.info("Building the model of Conv-TasNet")
    Conv_Tasnet = model.Conv_TasNet(**opt['Conv_Tasnet'])

    # build optimizer
    logger.info("Building the optimizer of Conv-TasNet")
    optimizer = make_optimizer(Conv_Tasnet.parameters(), opt)

    # build dataloader
    logger.info('Building the dataloader of Conv-TasNet')
    train_dataloader, val_dataloader = make_dataloader(opt)
    logger.info('Train Datasets Length: {}, Val Datasets Length: {}'.format(
        len(train_dataloader), len(val_dataloader)))

    # build scheduler
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  factor=opt['scheduler']['factor'],
                                  patience=opt['scheduler']['patience'],
                                  verbose=True,
                                  min_lr=opt['scheduler']['min_lr'])

    # build trainer
    logger.info('Building the Trainer of Conv-TasNet')
    trainer = trainer_Tasnet.Trainer(train_dataloader, val_dataloader,
                                     Conv_Tasnet, optimizer, scheduler, opt)
    trainer.run()
def train():
    parser = argparse.ArgumentParser(
        description='Parameters for training Model')
    # configuration file
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt)

    set_logger.setup_logger(opt['logger']['name'],
                            opt['logger']['path'],
                            screen=opt['logger']['screen'],
                            tofile=opt['logger']['tofile'])
    logger = logging.getLogger(opt['logger']['name'])
    day_time = datetime.date.today().strftime('%y%m%d')

    # build model
    model = opt['model']['MODEL']
    logger.info("Building the model of {}".format(model))
    # Extraction and Suppression model
    if model == 'DPRNN_Speaker_Extraction' or model == 'DPRNN_Speaker_Suppression':
        net = model_function.Extractin_Suppression_Model(
            **opt['Dual_Path_Aux_Speaker'])
    # Separation model
    if model == 'DPRNN_Speech_Separation':
        net = model_function.Speech_Serapation_Model(
            **opt['Dual_Path_Aux_Speaker'])

    if opt['train']['gpuid']:
        logger.info('We use GPUs: {}'.format(opt['train']['gpuid']))
        device = torch.device('cuda:{}'.format(opt['train']['gpuid'][0]))
        gpuids = opt['train']['gpuid']
        if len(gpuids) > 1:
            net = torch.nn.DataParallel(net, device_ids=gpuids)
        net = net.to(device)
    logger.info('Loading {} parameters: {:.3f} Mb'.format(
        model, check_parameters(net)))

    # build optimizer
    logger.info("Building the optimizer of {}".format(model))
    Optimizer = make_optimizer(net.parameters(), opt)
    Scheduler = ReduceLROnPlateau(Optimizer,
                                  mode='min',
                                  factor=opt['scheduler']['factor'],
                                  patience=opt['scheduler']['patience'],
                                  verbose=True,
                                  min_lr=opt['scheduler']['min_lr'])

    # build dataloader
    logger.info('Building the dataloader of {}'.format(model))
    train_dataloader, val_dataloader = make_dataloader(opt)
    logger.info('Train Datasets Length: {}, Val Datasets Length: {}'.format(
        len(train_dataloader), len(val_dataloader)))

    # build trainer
    logger.info('............. Training ................')
    total_epoch = opt['train']['epoch']
    num_spks = opt['num_spks']
    print_freq = opt['logger']['print_freq']
    checkpoint_path = opt['train']['path']
    early_stop = opt['train']['early_stop']
    max_norm = opt['optim']['clip_norm']
    best_loss = np.inf
    no_improve = 0
    ce_loss = torch.nn.CrossEntropyLoss()
    weight = 0.1
    epoch = 0

    # resume training settings
    if opt['resume']['state']:
        opt['resume']['path'] = opt['resume'][
            'path'] + '/' + '200722_epoch{}.pth.tar'.format(
                opt['resume']['epoch'])
        ckp = torch.load(opt['resume']['path'], map_location='cpu')
        epoch = ckp['epoch']
        logger.info("Resume from checkpoint {}: epoch {:d}".format(
            opt['resume']['path'], epoch))
        net.load_state_dict(ckp['model_state_dict'])
        net.to(device)
        Optimizer.load_state_dict(ckp['optim_state_dict'])

    while epoch < total_epoch:
        epoch += 1
        logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
            epoch, 0))
        num_steps = len(train_dataloader)

        # training process
        total_SNRloss = 0.0
        total_CEloss = 0.0
        num_index = 1
        start_time = time.time()
        for inputs, targets in train_dataloader:
            # Separation training
            if opt['model']['MODEL'] == 'DPRNN_Speech_Separation':
                mix = inputs
                ref = targets
                net.train()
                mix = mix.to(device)
                ref = [ref[i].to(device) for i in range(num_spks)]
                net.zero_grad()
                train_out = net(mix)
                SNR_loss = Loss(train_out, ref)
                loss = SNR_loss

            # Extraction training
            if opt['model']['MODEL'] == 'DPRNN_Speaker_Extraction':
                mix, aux = inputs
                ref, aux_len, sp_label = targets
                net.train()
                mix = mix.to(device)
                aux = aux.to(device)
                ref = ref.to(device)
                aux_len = aux_len.to(device)
                sp_label = sp_label.to(device)
                net.zero_grad()
                train_out = net([mix, aux, aux_len])
                SNR_loss = Loss_SI_SDR(train_out[0], ref)
                CE_loss = torch.mean(ce_loss(train_out[1], sp_label))
                loss = SNR_loss + weight * CE_loss
                total_CEloss += CE_loss.item()

            # Suppression training
            if opt['model']['MODEL'] == 'DPRNN_Speaker_Suppression':
                mix, aux = inputs
                ref, aux_len = targets
                net.train()
                mix = mix.to(device)
                aux = aux.to(device)
                ref = ref.to(device)
                aux_len = aux_len.to(device)
                net.zero_grad()
                train_out = net([mix, aux, aux_len])
                SNR_loss = Loss_SI_SDR(train_out[0], ref)
                loss = SNR_loss

            # backpropagation
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm)
            Optimizer.step()
            total_SNRloss += SNR_loss.item()

            if num_index % print_freq == 0:
                message = '<Training epoch:{:d} / {:d}, iter:{:d} / {:d}, lr:{:.3e}, ' \
                          'SI-SNR loss:{:.3f}, CE loss:{:.3f}>'.format(
                              epoch, total_epoch, num_index, num_steps,
                              Optimizer.param_groups[0]['lr'],
                              total_SNRloss / num_index,
                              total_CEloss / num_index)
                logger.info(message)
            num_index += 1

        end_time = time.time()
        mean_SNRLoss = total_SNRloss / num_index
        mean_CELoss = total_CEloss / num_index
        message = 'Finished Training *** <epoch:{:d} / {:d}, iter:{:d}, lr:{:.3e}, ' \
                  'SNR loss:{:.3f}, CE loss:{:.3f}, Total time:{:.3f} min>'.format(
                      epoch, total_epoch, num_index,
                      Optimizer.param_groups[0]['lr'], mean_SNRLoss,
                      mean_CELoss, (end_time - start_time) / 60)
        logger.info(message)

        # development (validation) process
        val_num_index = 1
        val_total_loss = 0.0
        val_CE_loss = 0.0
        val_acc_total = 0.0
        val_acc = 0.0
        val_start_time = time.time()
        val_num_steps = len(val_dataloader)
        for inputs, targets in val_dataloader:
            net.eval()
            with torch.no_grad():
                # Separation development
                if opt['model']['MODEL'] == 'DPRNN_Speech_Separation':
                    mix = inputs
                    ref = targets
                    mix = mix.to(device)
                    ref = [ref[i].to(device) for i in range(num_spks)]
                    Optimizer.zero_grad()
                    val_out = net(mix)
                    val_loss = Loss(val_out, ref)
                    val_total_loss += val_loss.item()

                # Extraction development
                if opt['model']['MODEL'] == 'DPRNN_Speaker_Extraction':
                    mix, aux = inputs
                    ref, aux_len, label = targets
                    mix = mix.to(device)
                    aux = aux.to(device)
                    ref = ref.to(device)
                    aux_len = aux_len.to(device)
                    label = label.to(device)
                    Optimizer.zero_grad()
                    val_out = net([mix, aux, aux_len])
                    val_loss = Loss_SI_SDR(val_out[0], ref)
                    val_ce = torch.mean(ce_loss(val_out[1], label))
                    val_acc = accuracy_speaker(val_out[1], label)
                    val_acc_total += val_acc
                    val_total_loss += val_loss.item()
                    val_CE_loss += val_ce.item()

                # Suppression development
                if opt['model']['MODEL'] == 'DPRNN_Speaker_Suppression':
                    mix, aux = inputs
                    ref, aux_len = targets
                    mix = mix.to(device)
                    aux = aux.to(device)
                    ref = ref.to(device)
                    aux_len = aux_len.to(device)
                    Optimizer.zero_grad()
                    val_out = net([mix, aux, aux_len])
                    val_loss = Loss_SI_SDR(val_out[0], ref)
                    val_total_loss += val_loss.item()

            if val_num_index % print_freq == 0:
                message = '<Valid epoch:{:d} / {:d}, iter:{:d} / {:d}, lr:{:.3e}, ' \
                          'val_SISNR_loss:{:.3f}, val_CE_loss:{:.3f}, val_acc:{:.3f}>'.format(
                              epoch, total_epoch, val_num_index, val_num_steps,
                              Optimizer.param_groups[0]['lr'],
                              val_total_loss / val_num_index,
                              val_CE_loss / val_num_index,
                              val_acc_total / val_num_index)
                logger.info(message)
            val_num_index += 1

        val_end_time = time.time()
        mean_val_total_loss = val_total_loss / val_num_index
        mean_val_CE_loss = val_CE_loss / val_num_index
        mean_acc = val_acc_total / val_num_index
        message = 'Finished Validation *** <epoch:{:d}, iter:{:d}, lr:{:.3e}, ' \
                  'val SI-SNR loss:{:.3f}, val_CE_loss:{:.3f}, val_acc:{:.3f}, ' \
                  'Total time:{:.3f} min>'.format(
                      epoch, val_num_index, Optimizer.param_groups[0]['lr'],
                      mean_val_total_loss, mean_val_CE_loss, mean_acc,
                      (val_end_time - val_start_time) / 60)
        logger.info(message)

        Scheduler.step(mean_val_total_loss)
        if mean_val_total_loss >= best_loss:
            no_improve += 1
            logger.info(
                'No improvement, Best SI-SNR Loss: {:.4f}'.format(best_loss))
        if mean_val_total_loss < best_loss:
            best_loss = mean_val_total_loss
            no_improve = 0
            save_checkpoint(epoch, checkpoint_path, net, Optimizer, day_time)
            logger.info(
                'Epoch: {:d}, Now Best SI-SNR Loss Change: {:.4f}'.format(
                    epoch, best_loss))
        if no_improve == early_stop:
            save_checkpoint(epoch, checkpoint_path, net, Optimizer, day_time)
            logger.info(
                "Stop training because of no improvement for {:d} epochs".format(
                    no_improve))
            break
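# NOTE: Loss_SI_SDR is defined elsewhere in the repository. The sketch below
# is a hypothetical negative SI-SDR objective of the kind commonly used for
# this sort of training; the actual implementation (zero-mean handling,
# reduction, PIT wrapping) may differ.
import torch


def Loss_SI_SDR_sketch(estimate, reference, eps=1e-8):
    # remove the per-utterance mean so the measure is scale-invariant
    estimate = estimate - torch.mean(estimate, dim=-1, keepdim=True)
    reference = reference - torch.mean(reference, dim=-1, keepdim=True)
    # project the estimate onto the reference to obtain the scaled target
    dot = torch.sum(estimate * reference, dim=-1, keepdim=True)
    energy = torch.sum(reference ** 2, dim=-1, keepdim=True) + eps
    target = dot * reference / energy
    noise = estimate - target
    si_sdr = 10 * torch.log10((torch.sum(target ** 2, dim=-1) + eps) /
                              (torch.sum(noise ** 2, dim=-1) + eps))
    # training minimizes the loss, so the mean SI-SDR is negated
    return -torch.mean(si_sdr)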
        # B x TF x nspk
        y = ibm * non_silent.expand_as(ibm)
        # attractors are the weighted average of the embeddings,
        # calculated by V and Y
        # B x K x nspk
        v_y = torch.bmm(torch.transpose(v, 1, 2), y)
        # B x K x nspk
        sum_y = torch.sum(y, 1, keepdim=True).expand_as(v_y)
        # B x K x nspk
        attractor = v_y / (sum_y + self.eps)
        # calculate the distance between embeddings and attractors
        # and generate the masks
        # B x TF x nspk
        dist = v.bmm(attractor)
        # B x TF x nspk
        mask = Fun.softmax(dist, dim=2)
        return mask, hidden

    def init_hidden(self, batch_size):
        return self.rnn.init_hidden(batch_size)


if __name__ == "__main__":
    opt = option.parse('../config/train.yml')
    net = DANet(**opt['DANet'])
    input = torch.randn(5, 10, 129)
    ibm = torch.randint(2, (5, 1290, 2))
    non_silent = torch.randn(5, 1290, 1)
    out = net(input, ibm, non_silent)
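# The masks returned by DANet are applied to the flattened mixture spectrogram
# to recover one source per attractor. The snippet below is an illustrative,
# self-contained sketch: mix_spec and the shapes mirror the test block above,
# and the real pipeline would reshape back to B x T x F and invert the STFT
# with the same parameters used to build the network input.
import torch

B, T, F, nspk = 5, 10, 129, 2
mix_spec = torch.randn(B, T, F).reshape(B, T * F)           # B x TF mixture magnitudes
mask = torch.softmax(torch.randn(B, T * F, nspk), dim=2)    # stand-in for the DANet mask
separated = [mask[:, :, i] * mix_spec for i in range(nspk)]  # one B x TF estimate per speaker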