def train(model, args, device, writer):
    print('preparing data...')
    dataloader = make_dataloader(
        args.tr_clean_list,
        args.tr_noise_list,
        args.tr_rir_list,
        batch_size=args.batch_size,
        repeate=1,
        segement_length=8,
        sample_rate=args.sample_rate,
        num_workers=args.num_threads,
    )
    print_freq = 100
    num_batch = len(dataloader)
    params = model.get_params(args.weight_decay)
    optimizer = optim.Adam(params, lr=args.learn_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', factor=0.5, patience=1, verbose=True)

    if args.retrain:
        start_epoch, step = reload_model(model, optimizer, args.exp_dir,
                                         args.use_cuda)
    else:
        start_epoch, step = 0, 0

    print('---------PRERUN-----------')
    lr = get_learning_rate(optimizer)
    print('(Initialization)')
    val_loss, val_sisnr = validation(model, args, lr, -1, device)
    writer.add_scalar('Loss/Train', val_loss, step)
    writer.add_scalar('Loss/Cross-Validation', val_loss, step)
    writer.add_scalar('SISNR/Train', -val_sisnr, step)
    writer.add_scalar('SISNR/Cross-Validation', -val_sisnr, step)

    warmup_epoch = 6
    warmup_lr = args.learn_rate / (4 * warmup_epoch)

    for epoch in range(start_epoch, args.max_epoch):
        torch.manual_seed(args.seed + epoch)
        if args.use_cuda:
            torch.cuda.manual_seed(args.seed + epoch)
        model.train()
        loss_total = 0.0
        loss_print = 0.0
        sisnr_total = 0.0
        sisnr_print = 0.0
        '''
        if epoch == 0 and warmup_epoch > 0:
            print('Use warmup strategy, and the lr is set to {:.5f}'.format(warmup_lr))
            setup_lr(optimizer, warmup_lr)
            warmup_lr *= 4 * (epoch + 1)
        elif epoch == warmup_epoch:
            print('The warmup has ended, and the lr is set to {:.5f}'.format(args.learn_rate))
            setup_lr(optimizer, args.learn_rate)
        '''
        stime = time.time()
        lr = get_learning_rate(optimizer)
        for idx, data in enumerate(dataloader):
            torch.cuda.empty_cache()
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            model.zero_grad()
            est_spec, est_wav = data_parallel(model, (inputs, ))
            '''
            if epoch > 8:
                gth_spec, gth_wav = data_parallel(model, (labels,))
            else:
                gth_spec = data_parallel(model.stft, (labels,))[0]
            '''
            # gth_spec = data_parallel(model.stft, (labels,))
            # loss = model.loss(est_spec, gth_spec, loss_mode='MSE')
            # loss.backward()
            sisnr = model.loss(est_wav, labels, loss_mode='SI-SNR')
            sisnr.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()
            step += 1
            # loss_total += loss.data.cpu()
            # loss_print += loss.data.cpu()
            sisnr_total += sisnr.data.cpu()
            sisnr_print += sisnr.data.cpu()
            loss_total = sisnr_total
            loss_print = sisnr_print
            del est_wav, est_spec
            if (idx + 1) % 3000 == 0:
                save_checkpoint(model, optimizer, -1, step, args.exp_dir)
            if (idx + 1) % print_freq == 0:
                eplashed = time.time() - stime
                speed_avg = eplashed / (idx + 1)
                loss_print_avg = loss_print / print_freq
                sisnr_print_avg = sisnr_print / print_freq
                print('Epoch {:3d}/{:3d} | batches {:5d}/{:5d} | lr {:1.4e} |'
                      ' {:2.3f}s/batch | loss {:2.6f} |'
                      ' SI-SNR {:2.4f}'.format(
                          epoch, args.max_epoch, idx + 1, num_batch, lr,
                          speed_avg, loss_print_avg, -sisnr_print_avg,
                      ))
                sys.stdout.flush()
                writer.add_scalar('Loss/Train', loss_print_avg, step)
                writer.add_scalar('SISNR/Train', -sisnr_print_avg, step)
                loss_print = 0.0
                sisnr_print = 0.0

        eplashed = time.time() - stime
        loss_total_avg = loss_total / num_batch
        sisnr_total_avg = sisnr_total / num_batch
        print('Training AVG.LOSS |'
              ' Epoch {:3d}/{:3d} | lr {:1.4e} |'
              ' {:2.3f}s/batch | time {:3.2f}mins |'
              ' loss {:2.6f} |'
              ' SISNR {:2.4f} |'.format(epoch + 1, args.max_epoch, lr,
                                        eplashed / num_batch, eplashed / 60.0,
                                        loss_total_avg.item(),
                                        -sisnr_total_avg.item()))

        val_loss, val_sisnr = validation(model, args, lr, epoch, device)
        writer.add_scalar('Loss/Cross-Validation', val_loss, step)
        writer.add_scalar('SISNR/Cross-Validation', -val_sisnr, step)
        writer.add_scalar('learn_rate', lr, step)
        if val_loss > scheduler.best:
            print('Rejected !!! The best is {:2.6f}'.format(scheduler.best))
        else:
            save_checkpoint(model, optimizer, epoch + 1, step, args.exp_dir,
                            mode='best_model')
        scheduler.step(val_loss)
        sys.stdout.flush()
        stime = time.time()
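# The training loops in this file call a few small helpers (get_learning_rate,
# setup_lr) that are defined elsewhere. Below is a minimal sketch of what they
# are assumed to look like, based only on how they are called here; the actual
# implementations may differ.

def get_learning_rate(optimizer):
    """Return the learning rate of the first parameter group."""
    return optimizer.param_groups[0]['lr']


def setup_lr(optimizer, lr):
    """Set the same learning rate for every parameter group."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr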
def main(args):
    if args.name == 'sbr':
        experiment_dir = Path(
            os.path.join(BASE_DIR, 'save/sbr/' + get_run_name()))
    else:
        experiment_dir = Path(
            os.path.join(BASE_DIR,
                         'save/double_{}/'.format(args.name) + get_run_name()))
    experiment_dir.mkdir(parents=True, exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    # writer = SummaryWriter(log_dir)

    config_f = open(os.path.join(log_dir, 'config.json'), 'w')
    json.dump(vars(args), config_f)
    config_f.close()

    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    num_class = DATASETS[args.dataset]['n_classes']
    if args.name == 'sbr':
        net_p = ngvnn.Net_Prev(pretraining=args.pretrain,
                               num_views=args.num_views)
        net_whole = ngvnn.Net_Whole(pretraining=args.pretrain)
    elif args.name == 'pointnet':
        import importlib
        model = importlib.import_module(args.model)
        net_p = model.get_encoder_model(num_class, normal_channel=False)
        net_whole = model.get_encoder_model(num_class, normal_channel=False)
    elif args.name == 'ngvnn':
        net_p = ngvnn.Net_Prev(pretraining=args.pretrain,
                               num_views=args.num_views)
        net_whole = ngvnn.Net_Prev(pretraining=args.pretrain,
                                   num_views=args.num_views)
    else:
        raise NotImplementedError
    net_cls = ngvnn.Net_Classifier(nclasses=num_class)

    crt_cls = nn.CrossEntropyLoss().cuda()
    if args.triplet_type == 'tcl':
        center_embed = 512
        crt_tpl = custom_loss.TripletCenterLoss(margin=args.margin,
                                                center_embed=center_embed,
                                                num_classes=num_class).cuda()
        optim_centers = torch.optim.SGD(crt_tpl.parameters(), lr=0.1)
    else:
        from dataset.TripletSampler import HardestNegativeTripletSelector
        anchor_index = args.n_classes * args.n_samples
        crt_tpl = custom_loss.OnlineTripletLoss(
            args.margin,
            HardestNegativeTripletSelector(args.margin, False, anchor_index))
    criterion = [crt_cls, crt_tpl, args.w1, args.w2]

    # Load from checkpoint
    start_epoch = best_top1 = 0
    if args.resume:
        checkpoint = torch.load(args.resume)
        net_p.load_state_dict(checkpoint['net_p'])
        net_whole.load_state_dict(checkpoint['net_whole'])
        net_cls.load_state_dict(checkpoint['net_cls'])
        start_epoch = checkpoint['epoch']
        best_top1 = checkpoint['best_prec']

    net_whole = nn.DataParallel(net_whole).cuda()
    net_cls = nn.DataParallel(net_cls).cuda()
    net_p = nn.DataParallel(net_p).cuda()

    optim_shape = optim.SGD([{'params': net_p.parameters()},
                             {'params': net_cls.parameters()}],
                            lr=0.001,
                            momentum=0.9,
                            weight_decay=args.weight_decay)
    if args.name in ['sbr', 'ngvnn', 'pointnet']:
        base_param_ids = set(map(id, net_whole.module.features.parameters()))
        new_params = [
            p for p in net_whole.parameters() if id(p) not in base_param_ids
        ]
        param_groups = [{
            'params': net_whole.module.features.parameters(),
            'lr_mult': 0.1
        }, {
            'params': new_params,
            'lr_mult': 1.0
        }]
        optim_sketch = optim.SGD(param_groups,
                                 lr=0.001,
                                 momentum=0.9,
                                 weight_decay=args.weight_decay)
    else:
        optim_sketch = optim.SGD([{'params': net_p.parameters()},
                                  {'params': net_cls.parameters()}],
                                 lr=0.001,
                                 momentum=0.9,
                                 weight_decay=args.weight_decay)

    if args.triplet_type == 'tcl':
        optimizer = (optim_sketch, optim_shape, optim_centers)
    else:
        optimizer = (optim_sketch, optim_shape)
    model = (net_whole, net_p, net_cls)

    # Schedule learning rate
    def adjust_lr(epoch, optimizer):
        step_size = 800 if args.pk_flag else 80  # 40
        lr = args.lr * (0.1 ** (epoch // step_size))
        for g in optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    train_shape_loader, train_sketch_loader, test_shape_loader, test_sketch_loader = get_dataloader(
        args)

    # Start training
    top1 = 0.0
    best_epoch = -1
    best_metric = None
    '''TRAINING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        # cls acc top1
        train_top1 = train(train_sketch_loader, train_shape_loader, model,
                           criterion, optimizer, epoch, args, logger)
        if train_top1 > 0.1:
            print("Test:")
            cur_metric = validate(test_sketch_loader, test_shape_loader,
                                  model, criterion, args, logger)
            top1 = cur_metric[3]  # mAP_feat_norm
        is_best = top1 > best_top1
        if is_best:
            best_epoch = epoch + 1
            best_metric = cur_metric
        best_top1 = max(top1, best_top1)
        # path_checkpoint = '{0}/model_latest.pth'.format(checkpoints_dir)
        # misc.save_checkpoint(checkpoint, path_checkpoint)
        if is_best:
            # save checkpoint
            logger.info('Save model...')
            savepath = str(checkpoints_dir) + '/best_model.pth'
            checkpoint = {}
            checkpoint['epoch'] = epoch + 1
            checkpoint['current_prec'] = top1
            checkpoint['best_prec'] = best_top1
            checkpoint['net_p'] = net_p.module.state_dict()
            checkpoint['net_whole'] = net_whole.module.state_dict()
            checkpoint['net_cls'] = net_cls.module.state_dict()
            # torch.save(checkpoint, savepath)
            # path_checkpoint = '{0}/best_model.pth'.format(checkpoints_dir)
            misc.save_checkpoint(checkpoint, savepath)  # path_checkpoint)
        log_string(
            '\n * Finished epoch {:3d} top1: {:5.3%} best: {:5.3%}{} @epoch {}\n'
            .format(epoch, top1, best_top1, ' *' if is_best else '',
                    best_epoch), logger)
    logger.info('End of training...')
    log_string('Best metric {}'.format(best_metric), logger)
    return experiment_dir
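# log_string is used by main() above but not defined in this file. A minimal
# sketch, assuming it simply mirrors a message to stdout and to the experiment
# logger (the real helper may format differently):

def log_string(msg, logger):
    """Print a message and also record it through the given logger."""
    print(msg)
    logger.info(msg)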
def train(model, args, device, writer):
    print('preparing data...')
    dataloader, dataset = make_loader(
        args.tr_list,
        args.batch_size,
        num_workers=args.num_threads,
    )
    print_freq = 100
    num_batch = len(dataloader)
    params = model.get_params(args.weight_decay)
    optimizer = optim.Adam(params, lr=args.learn_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', factor=0.5, patience=1, verbose=True)

    if args.retrain:
        start_epoch, step = reload_model(model, optimizer, args.exp_dir,
                                         args.use_cuda)
    else:
        start_epoch, step = 0, 0

    print('---------PRERUN-----------')
    lr = get_learning_rate(optimizer)
    print('(Initialization)')
    val_loss, val_sisnr = validation(model, args, lr, -1, device)
    writer.add_scalar('Loss/Train', val_loss, step)
    writer.add_scalar('Loss/Cross-Validation', val_loss, step)
    writer.add_scalar('SiSNR/Train', -val_sisnr, step)
    writer.add_scalar('SiSNR/Cross-Validation', -val_sisnr, step)

    for epoch in range(start_epoch, args.max_epoch):
        torch.manual_seed(args.seed + epoch)
        if args.use_cuda:
            torch.cuda.manual_seed(args.seed + epoch)
        model.train()
        sisnr_total = 0.0
        sisnr_print = 0.0
        mix_loss_total = 0.0
        mix_loss_print = 0.0
        amp_loss_total = 0.0
        amp_loss_print = 0.0
        phase_loss_total = 0.0
        phase_loss_print = 0.0
        stime = time.time()
        lr = get_learning_rate(optimizer)
        for idx, data in enumerate(dataloader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            model.zero_grad()
            outputs, wav = data_parallel(model, (inputs,))
            loss = model.loss(outputs, labels, mode='Mix')
            loss[0].backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()
            step += 1
            sisnr = model.loss(wav, labels, mode='SiSNR')
            mix_loss_total += loss[0].data.cpu()
            mix_loss_print += loss[0].data.cpu()
            amp_loss_total += loss[1].data.cpu()
            amp_loss_print += loss[1].data.cpu()
            phase_loss_total += loss[2].data.cpu()
            phase_loss_print += loss[2].data.cpu()
            sisnr_print += sisnr.data.cpu()
            sisnr_total += sisnr.data.cpu()
            del outputs, labels, inputs, loss, wav
            if (idx + 1) % 1000 == 0:
                save_checkpoint(model, optimizer, -1, step, args.exp_dir)
            if (idx + 1) % print_freq == 0:
                eplashed = time.time() - stime
                speed_avg = eplashed / (idx + 1)
                mix_loss_print_avg = mix_loss_print / print_freq
                amp_loss_print_avg = amp_loss_print / print_freq
                phase_loss_print_avg = phase_loss_print / print_freq
                sisnr_print_avg = sisnr_print / print_freq
                print('Epoch {:3d}/{:3d} | batches {:5d}/{:5d} | lr {:1.4e} |'
                      '{:2.3f}s/batches '
                      '| Mixloss {:2.4f}'
                      '| AMPloss {:2.4f}'
                      '| Phaseloss {:2.4f}'
                      '| SiSNR {:2.4f}'.format(
                          epoch, args.max_epoch, idx + 1, num_batch, lr,
                          speed_avg, mix_loss_print_avg, amp_loss_print_avg,
                          phase_loss_print_avg, -sisnr_print_avg))
                sys.stdout.flush()
                writer.add_scalar('SiSNR/Train', -sisnr_print_avg, step)
                writer.add_scalar('Loss/Train', mix_loss_print_avg, step)
                mix_loss_print = 0.
                amp_loss_print = 0.
                phase_loss_print = 0.
                sisnr_print = 0.

        eplashed = time.time() - stime
        mix_loss_total_avg = mix_loss_total / num_batch
        sisnr_total_avg = sisnr_total / num_batch
        print('Training AVG.LOSS |'
              ' Epoch {:3d}/{:3d} | lr {:1.4e} |'
              ' {:2.3f}s/batch | time {:3.2f}mins |'
              ' Mixloss {:2.4f}'
              ' SiSNR {:2.4f}'.format(epoch + 1, args.max_epoch, lr,
                                      eplashed / num_batch, eplashed / 60.0,
                                      mix_loss_total_avg, -sisnr_total_avg))

        val_loss, val_sisnr = validation(model, args, lr, epoch, device)
        writer.add_scalar('Loss/Cross-Validation', val_loss, step)
        writer.add_scalar('SiSNR/Cross-Validation', -val_sisnr, step)
        writer.add_scalar('learn_rate', lr, step)
        if val_loss > scheduler.best:
            print('Rejected !!! The best is {:2.6f}'.format(scheduler.best))
        else:
            save_checkpoint(model, optimizer, epoch + 1, step, args.exp_dir,
                            mode='best_model')
        scheduler.step(val_loss)
        sys.stdout.flush()
        stime = time.time()
def train(model, args, device, writer):
    print('preparing data...')
    dataloader, dataset = make_loader(
        args.tr_list,
        args.batch_size,
        num_workers=args.num_threads,
        processer=Processer(
            win_len=args.win_len,
            win_inc=args.win_inc,
            left_context=args.left_context,
            right_context=args.right_context,
            fft_len=args.fft_len,
            window_type=args.win_type))
    print_freq = 100
    num_batch = len(dataloader)
    params = model.get_params(args.weight_decay)
    optimizer = optim.Adam(params, lr=args.learn_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', factor=0.5, patience=1, verbose=True)

    if args.retrain:
        start_epoch, step = reload_model(model, optimizer, args.exp_dir,
                                         args.use_cuda)
    else:
        start_epoch, step = 0, 0

    print('---------PRERUN-----------')
    lr = get_learning_rate(optimizer)
    print('(Initialization)')
    val_loss = validation(model, args, lr, -1, device)
    writer.add_scalar('Loss/Train', val_loss, step)
    writer.add_scalar('Loss/Cross-Validation', val_loss, step)

    for epoch in range(start_epoch, args.max_epoch):
        torch.manual_seed(args.seed + epoch)
        if args.use_cuda:
            torch.cuda.manual_seed(args.seed + epoch)
        model.train()
        loss_total = 0.0
        loss_print = 0.0
        stime = time.time()
        lr = get_learning_rate(optimizer)
        for idx, data in enumerate(dataloader):
            inputs, labels, lengths = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            model.zero_grad()
            outputs, _ = data_parallel(model, (inputs, lengths))
            loss = model.loss(outputs, labels, lengths)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()
            step += 1
            loss_total += loss.data.cpu()
            loss_print += loss.data.cpu()
            del lengths, outputs, labels, inputs, loss, _
            if (idx + 1) % 3000 == 0:
                save_checkpoint(model, optimizer, epoch + 1, step,
                                args.exp_dir)
            if (idx + 1) % print_freq == 0:
                eplashed = time.time() - stime
                speed_avg = eplashed / (idx + 1)
                loss_print_avg = loss_print / print_freq
                print('Epoch {:3d}/{:3d} | batches {:5d}/{:5d} | lr {:1.4e} |'
                      '{:2.3f}s/batches | loss {:2.6f}'.format(
                          epoch, args.max_epoch, idx + 1, num_batch, lr,
                          speed_avg, loss_print_avg))
                sys.stdout.flush()
                writer.add_scalar('Loss/Train', loss_print_avg, step)
                loss_print = 0.0

        eplashed = time.time() - stime
        loss_total_avg = loss_total / num_batch
        print('Training AVG.LOSS |'
              ' Epoch {:3d}/{:3d} | lr {:1.4e} |'
              ' {:2.3f}s/batch | time {:3.2f}mins |'
              ' loss {:2.6f}'.format(epoch + 1, args.max_epoch, lr,
                                     eplashed / num_batch, eplashed / 60.0,
                                     loss_total_avg.item()))

        val_loss = validation(model, args, lr, epoch, device)
        writer.add_scalar('Loss/Cross-Validation', val_loss, step)
        writer.add_scalar('learn_rate', lr, step)
        if val_loss > scheduler.best:
            print('Rejected !!! The best is {:2.6f}'.format(scheduler.best))
        else:
            save_checkpoint(model, optimizer, epoch + 1, step, args.exp_dir)
        scheduler.step(val_loss)
        sys.stdout.flush()
        stime = time.time()
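# save_checkpoint and reload_model are shared by all of the training loops in
# this file but are defined elsewhere. A minimal sketch consistent with how
# they are called here; the file name and the dictionary key names are
# assumptions, not the original implementation. It relies on the module-level
# torch / os imports already used by the rest of this file.

def save_checkpoint(model, optimizer, epoch, step, exp_dir, mode='checkpoint'):
    """Persist model/optimizer state together with the epoch and step counters."""
    path = os.path.join(exp_dir, '{}.pth'.format(mode))
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'step': step,
    }, path)


def reload_model(model, optimizer, exp_dir, use_cuda=True):
    """Restore the latest checkpoint and return (start_epoch, step)."""
    path = os.path.join(exp_dir, 'checkpoint.pth')
    ckpt = torch.load(path, map_location='cuda' if use_cuda else 'cpu')
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    return ckpt['epoch'], ckpt['step']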
def train(model, args, device, writer):
    print('preparing data...')
    dataloader, dataset = make_loader(
        args.tr_list,
        args.batch_size,
        num_workers=args.num_threads,
        segment_length=1.5,
        sample_rate=args.sample_rate,
        processer=Processer(
            sample_rate=args.sample_rate,
        ))
    print_freq = 100
    num_batch = len(dataloader)
    params = model.get_params(args.weight_decay)
    optimizer = optim.Adam(params, lr=args.learn_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', factor=0.5, patience=1, verbose=True)

    if args.retrain:
        start_epoch, step = reload_model(model, optimizer, args.exp_dir,
                                         args.use_cuda)
    else:
        start_epoch, step = 0, 0

    print('---------PRERUN-----------')
    lr = get_learning_rate(optimizer)
    print('(Initialization)')
    val_loss, val_sisnr = 30, 30.  # validation(model, args, lr, -1, device)
    writer.add_scalar('Loss/Train', val_loss, step)
    writer.add_scalar('Loss/Cross-Validation', val_loss, step)
    writer.add_scalar('SISNR/Train', -val_sisnr, step)
    writer.add_scalar('SISNR/Cross-Validation', -val_sisnr, step)

    for epoch in range(start_epoch, args.max_epoch):
        torch.manual_seed(args.seed + epoch)
        if args.use_cuda:
            torch.cuda.manual_seed(args.seed + epoch)
        model.train()
        all_moniter = Moniter('All', num_batch, print_freq)
        sdr_moniter = Moniter('SDR', num_batch, print_freq)
        speaker_moniter = Moniter('Speaker', num_batch, print_freq)
        reg_moniter = Moniter('Reg', num_batch, print_freq)
        stime = time.time()
        lr = get_learning_rate(optimizer)
        for idx, data in enumerate(dataloader):
            inputs, labels, spkid = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            spkid = spkid.to(device)
            model.zero_grad()
            est_wav, speaker_loss, reg_loss = data_parallel(
                model, (inputs, spkid))
            speaker_loss = torch.mean(speaker_loss)
            reg_loss = torch.mean(reg_loss)
            sdr = model.loss(est_wav, labels, loss_mode='SDR')
            all = sdr + 2 * speaker_loss + 0.3 * reg_loss
            all.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            optimizer.step()
            all_moniter(all)
            sdr_moniter(sdr)
            speaker_moniter(speaker_loss)
            reg_moniter(reg_loss)
            step += 1
            if (idx + 1) % 3000 == 0:
                save_checkpoint(model, optimizer, -1, step, args.exp_dir)
            # val_loss, val_sisnr = validation(model, args, lr, epoch, device)
            # scheduler.step(val_loss)
            # lr = get_learning_rate(optimizer)
            if (idx + 1) % print_freq == 0:
                eplashed = time.time() - stime
                speed_avg = eplashed / (idx + 1)
                log_str = ('Epoch {:3d}/{:3d} | batches {:5d}/{:5d} | lr {:1.4e} |'
                           '{:2.3f}s/batches |'
                           ' {:s} |'
                           ' {:s} |'
                           ' {:s} |'
                           ' {:s} |'.format(
                               epoch, args.max_epoch, idx + 1, num_batch, lr,
                               speed_avg,
                               all_moniter.recent(),
                               sdr_moniter.recent(),
                               speaker_moniter.recent(),
                               reg_moniter.recent()))
                writer.add_scalar('Loss/Train', all_moniter.rec_float(), step)
                writer.add_scalar('SISNR/Train', -sdr_moniter.rec_float(), step)
                print(log_str)
                all_moniter.reset()
                sdr_moniter.reset()
                speaker_moniter.reset()
                reg_moniter.reset()
                sys.stdout.flush()

        eplashed = time.time() - stime
        log_str = ('Training AVG.LOSS |'
                   ' Epoch {:3d}/{:3d} | lr {:1.4e} |'
                   ' {:2.3f}s/batch | time {:3.2f}mins |'
                   ' {:s} |'
                   ' {:s} |'
                   ' {:s} |'
                   ' {:s} |'.format(
                       epoch + 1, args.max_epoch, lr,
                       eplashed / num_batch, eplashed / 60.0,
                       all_moniter.average(),
                       sdr_moniter.average(),
                       speaker_moniter.average(),
                       reg_moniter.average()))
        print(log_str)

        val_loss, val_sisnr = validation(model, args, lr, epoch, device)
        writer.add_scalar('Loss/Cross-Validation', val_loss, step)
        writer.add_scalar('SISNR/Cross-Validation', -val_sisnr, step)
        writer.add_scalar('learn_rate', lr, step)
        if val_loss > scheduler.best:
            print('Rejected !!! The best is {:2.6f}'.format(scheduler.best))
        else:
            save_checkpoint(model, optimizer, epoch + 1, step, args.exp_dir,
                            mode='best_model')
        scheduler.step(val_loss)
        sys.stdout.flush()
        stime = time.time()
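# The Moniter class used by the last training loop above is defined elsewhere.
# Below is a minimal sketch of a running-average tracker that matches the calls
# made there (constructor arguments, recent(), rec_float(), average(), reset());
# the internals are an assumption, not the original class.

class Moniter(object):
    def __init__(self, name, num_batch, print_freq):
        self.name = name
        self.num_batch = num_batch
        self.print_freq = print_freq
        self.recent_sum = 0.0   # accumulated since the last reset()
        self.recent_cnt = 0
        self.total_sum = 0.0    # accumulated over the whole epoch
        self.total_cnt = 0

    def __call__(self, value):
        # Accept either a scalar tensor or a plain float.
        if torch.is_tensor(value):
            value = float(value.detach().cpu())
        self.recent_sum += value
        self.recent_cnt += 1
        self.total_sum += value
        self.total_cnt += 1

    def rec_float(self):
        """Average of the values seen since the last reset(), as a float."""
        return self.recent_sum / max(self.recent_cnt, 1)

    def recent(self):
        return '{:s} {:2.4f}'.format(self.name, self.rec_float())

    def average(self):
        """Formatted average over all values seen this epoch."""
        return '{:s} {:2.4f}'.format(self.name,
                                     self.total_sum / max(self.total_cnt, 1))

    def reset(self):
        self.recent_sum = 0.0
        self.recent_cnt = 0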