def __call__(self, model: nn.Module, epoch_idx, output_dir, eval_rtn: dict,
             test_rtn: dict, logger: logging.Logger, writer: SummaryWriter):
    # save model
    acc = eval_rtn.get('err_spk', 0) - eval_rtn.get('err_sph', 1)
    is_best = acc > self.best_accu
    self.best_accu = acc if is_best else self.best_accu

    model_filename = "epoch_{}.pth".format(epoch_idx)
    save_checkpoint(model, os.path.join(output_dir, model_filename),
                    meta={'epoch': epoch_idx})
    os.system("ln -sf {} {}".format(
        os.path.abspath(os.path.join(output_dir, model_filename)),
        os.path.join(output_dir, "latest.pth")))
    if is_best:
        os.system("ln -sf {} {}".format(
            os.path.abspath(os.path.join(output_dir, model_filename)),
            os.path.join(output_dir, "best.pth")))

    if logger is not None:
        logger.info("EvalHook: best accu: {:.3f}, is_best: {}".format(
            self.best_accu, is_best))

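# The hook above relies on a `save_checkpoint(model, path, meta=...)` helper that is
# not shown here. A minimal sketch of such a helper, assuming it only needs to persist
# the state dict together with the meta dictionary; the real project's helper may
# store additional fields.
import torch
import torch.nn as nn


def save_checkpoint(model: nn.Module, path: str, meta: dict = None):
    # Persist the model weights plus an optional meta dictionary (e.g. the epoch index).
    payload = {'state_dict': model.state_dict(), 'meta': meta or {}}
    torch.save(payload, path)
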
def main():
    args = arguments()

    num_templates = 25  # aka the number of clusters

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    img_transforms = transforms.Compose([transforms.ToTensor(), normalize])
    train_loader, _ = get_dataloader(args.traindata, args, num_templates,
                                     img_transforms=img_transforms)

    model = DetectionModel(num_objects=1, num_templates=num_templates)
    loss_fn = DetectionCriterion(num_templates)

    # directory where we'll store model weights
    weights_dir = "weights"
    if not osp.exists(weights_dir):
        os.mkdir(weights_dir)

    # check for CUDA
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    optimizer = optim.SGD(model.learnable_parameters(args.lr), lr=args.lr,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    # optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # set the start epoch if it has not been set
        if not args.start_epoch:
            args.start_epoch = checkpoint['epoch']

    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20,
                                          last_epoch=args.start_epoch - 1)

    # train and evaluate for `epochs`
    for epoch in range(args.start_epoch, args.epochs):
        scheduler.step()
        trainer.train(model, loss_fn, optimizer, train_loader, epoch, device=device)

        if (epoch + 1) % args.save_every == 0:
            trainer.save_checkpoint({
                'epoch': epoch + 1,
                'batch_size': train_loader.batch_size,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, filename="checkpoint_{0}.pth".format(epoch + 1), save_path=weights_dir)

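# A minimal sketch of the `arguments()` parser this detection script appears to expect.
# The flag names and defaults below are assumptions inferred from how `args` is used
# above, not the project's actual CLI.
import argparse


def arguments():
    parser = argparse.ArgumentParser(description="Train the detection model")
    parser.add_argument('--traindata', type=str, required=True)
    parser.add_argument('--lr', type=float, default=1e-2)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight-decay', dest='weight_decay', type=float, default=5e-4)
    parser.add_argument('--resume', type=str, default='')
    parser.add_argument('--start-epoch', dest='start_epoch', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--save-every', dest='save_every', type=int, default=5)
    return parser.parse_args()
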
def train(train_loader, model, optimizer, train_vars, control_vars, verbose=True):
    curr_epoch_iter = 1
    for batch_idx, (data, target) in enumerate(train_loader):
        control_vars['batch_idx'] = batch_idx
        if batch_idx < control_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " +
                          str(batch_idx+1) + "/" + str(control_vars['iter_size']),
                          verbose, n_tabs=0, erase_line=True)
        # check if arrived at iter to start
        if control_vars['curr_epoch_iter'] < control_vars['start_iter_mod']:
            if batch_idx % control_vars['iter_size'] == 0:
                print_verbose("\rGoing through iterations to arrive at last one saved... " +
                              str(int(control_vars['curr_epoch_iter']*100.0/control_vars['start_iter_mod'])) +
                              "% of " + str(control_vars['start_iter_mod']) + " iterations (" +
                              str(control_vars['curr_epoch_iter']) + "/" + str(control_vars['start_iter_mod']) + ")",
                              verbose, n_tabs=0, erase_line=True)
                control_vars['curr_epoch_iter'] += 1
                control_vars['curr_iter'] += 1
                curr_epoch_iter += 1
            continue
        # save checkpoint after final iteration
        if control_vars['curr_iter'] == control_vars['num_iter']:
            print_verbose("\nReached final number of iterations: " + str(control_vars['num_iter']), verbose)
            print_verbose("\tSaving final model checkpoint...", verbose)
            final_model_dict = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'control_vars': control_vars,
                'train_vars': train_vars,
            }
            trainer.save_checkpoint(final_model_dict,
                                    filename=train_vars['checkpoint_filenamebase'] +
                                    'final' + str(control_vars['num_iter']) + '.pth.tar')
            control_vars['done_training'] = True
            break
        # start time counter
        start = time.time()
        # get data and target as cuda variables
        target_heatmaps, target_joints, _, target_prior = target
        data, target_heatmaps, target_prior = Variable(data), Variable(target_heatmaps), Variable(target_prior)
        if train_vars['use_cuda']:
            data = data.cuda()
            target_heatmaps = target_heatmaps.cuda()
            target_prior = target_prior.cuda()
        # visualize if debugging
        # get model output
        output = model(data)
        # accumulate loss for sub-mini-batch
        if train_vars['cross_entropy']:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        loss, loss_prior = my_losses.calculate_loss_HALNet_prior(
            loss_func, output, target_heatmaps, target_prior, model.joint_ixs,
            model.WEIGHT_LOSS_INTERMED1, model.WEIGHT_LOSS_INTERMED2,
            model.WEIGHT_LOSS_INTERMED3, model.WEIGHT_LOSS_MAIN,
            control_vars['iter_size'])
        loss.backward()
        train_vars['total_loss'] += loss
        train_vars['total_loss_prior'] += loss_prior
        # accumulate pixel dist loss for sub-mini-batch
        train_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            train_vars['total_pixel_loss'], output[3], target_heatmaps, control_vars['batch_size'])
        if train_vars['cross_entropy']:
            train_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
                train_vars['total_pixel_loss_sample'], output[3], target_heatmaps, control_vars['batch_size'])
        else:
            train_vars['total_pixel_loss_sample'] = [-1] * len(model.joint_ixs)
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx+1) % control_vars['iter_size'] == 0
        if minibatch_completed:
            # optimise for mini-batch
            optimizer.step()
            # clear optimiser
            optimizer.zero_grad()
            # append total loss
            train_vars['losses'].append(train_vars['total_loss'].data[0])
            # erase total loss
            total_loss = train_vars['total_loss'].data[0]
            train_vars['total_loss'] = 0
            # append total loss prior
            train_vars['losses_prior'].append(train_vars['total_loss_prior'].data[0])
            # erase total loss prior
            total_loss_prior = train_vars['total_loss_prior'].data[0]
            train_vars['total_loss_prior'] = 0
            # append dist loss
            train_vars['pixel_losses'].append(train_vars['total_pixel_loss'])
            # erase pixel dist loss
            train_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            train_vars['pixel_losses_sample'].append(train_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            train_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            if train_vars['losses'][-1] < train_vars['best_loss']:
                train_vars['best_loss'] = train_vars['losses'][-1]
                print_verbose(" This is a best loss found so far: " + str(train_vars['losses'][-1]), verbose)
                train_vars['best_model_dict'] = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': train_vars,
                }
            if train_vars['losses_prior'][-1] < train_vars['best_loss_prior']:
                train_vars['best_loss_prior'] = train_vars['losses_prior'][-1]
            # log checkpoint
            if control_vars['curr_iter'] % control_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, epoch, total_loss, train_vars, control_vars)
                msg = ''
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                msg += print_verbose("Current loss (prior): " + str(total_loss_prior), verbose) + "\n"
                msg += print_verbose("Best loss (prior): " + str(train_vars['best_loss_prior']), verbose) + "\n"
                msg += print_verbose("Mean total loss (prior): " + str(np.mean(train_vars['losses_prior'])), verbose) + "\n"
                msg += print_verbose("Mean loss (prior) for last " + str(control_vars['log_interval']) +
                                     " iterations (average total loss): " +
                                     str(np.mean(train_vars['losses_prior'][-control_vars['log_interval']:])),
                                     verbose) + "\n"
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                if not control_vars['output_filepath'] == '':
                    with open(control_vars['output_filepath'], 'a') as f:
                        f.write(msg + '\n')
            if control_vars['curr_iter'] % control_vars['log_interval_valid'] == 0:
                print_verbose("\nSaving model and checkpoint model for validation", verbose)
                checkpoint_model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': train_vars,
                }
                trainer.save_checkpoint(checkpoint_model_dict,
                                        filename=train_vars['checkpoint_filenamebase'] +
                                        'for_valid_' + str(control_vars['curr_iter']) + '.pth.tar')
            # print time lapse
            prefix = 'Training (Epoch #' + str(epoch) + ' ' + str(control_vars['curr_epoch_iter']) + '/' +\
                     str(control_vars['tot_iter']) + ')' + ', (Batch ' + str(control_vars['batch_idx']+1) +\
                     '(' + str(control_vars['iter_size']) + ')' + '/' +\
                     str(control_vars['num_batches']) + ')' + ', (Iter #' + str(control_vars['curr_iter']) +\
                     '(' + str(control_vars['batch_size']) + ')' +\
                     ' - log every ' + str(control_vars['log_interval']) + ' iter): '
            control_vars['tot_toc'] = display_est_time_loop(control_vars['tot_toc'] + time.time() - start,
                                                            control_vars['curr_iter'], control_vars['num_iter'],
                                                            prefix=prefix)
            control_vars['curr_iter'] += 1
            control_vars['start_iter'] = control_vars['curr_iter'] + 1
            control_vars['curr_epoch_iter'] += 1
    return train_vars, control_vars

def main():
    args = arguments()
    segmentation = False

    # check for CUDA
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    train_loader = get_dataloader(args.traindata, args, device=device)

    model = CoattentionNet()
    loss_fn = SiameseCriterion(device=device)

    pretrained_dict = torch.load(
        "../crowd-counting-revise/weight/checkpoint_104.pth")["model"]
    model_dict = model.state_dict()
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if "frontend" not in k and "backend2" not in k and "main_classifier" not in k
    }
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    # directory where we'll store model weights
    weights_dir = "weight_all"
    if not osp.exists(weights_dir):
        os.mkdir(weights_dir)

    optimizer = optim.Adam(model.learnable_parameters(args.lr), lr=args.lr,
                           weight_decay=args.weight_decay)
    # optimizer = optim.Adam(model.learnable_parameters(args.lr), lr=args.lr)

    if args.resume:
        checkpoint = torch.load(args.resume)
        model = model.to(device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # set the start epoch if it has not been set
        if not args.start_epoch:
            args.start_epoch = checkpoint['epoch']

    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20,
                                          last_epoch=args.start_epoch - 1)

    print("Start training!")

    # train and evaluate for `epochs`
    best_mae = sys.maxsize
    for epoch in range(args.start_epoch, args.epochs):
        if epoch % 4 == 0 and epoch != 0:
            val_loader = val_dataloader(args, device=device)
            with torch.no_grad():
                if not segmentation:
                    mae, mse = evaluate_model(model, val_loader, device=device,
                                              training=True, debug=args.debug,
                                              segmentation=segmentation)
                    if mae < best_mae:
                        best_mae = mae
                        best_mse = mse
                        best_model = "checkpoint_{0}.pth".format(epoch)
                    log_text = 'epoch: %4d, mae: %4.2f, mse: %4.2f' % (epoch - 1, mae, mse)
                    log_print(log_text, color='green', attrs=['bold'])
                    log_text = 'BEST MAE: %0.1f, BEST MSE: %0.1f, BEST MODEL: %s' % (
                        best_mae, best_mse, best_model)
                    log_print(log_text, color='green', attrs=['bold'])
                else:
                    _, _ = evaluate_model(model, val_loader, device=device,
                                          training=True, debug=args.debug,
                                          segmentation=segmentation)

        scheduler.step()
        trainer.train(model, loss_fn, optimizer, train_loader, epoch, device=device)

        if (epoch + 1) % args.save_every == 0:
            trainer.save_checkpoint({
                'epoch': epoch + 1,
                'batch_size': train_loader.batch_size,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, filename="checkpoint_{0}.pth".format(epoch + 1), save_path=weights_dir)

def validate(valid_loader, model, optimizer, valid_vars, control_vars, verbose=True):
    curr_epoch_iter = 1
    for batch_idx, (data, target) in enumerate(valid_loader):
        control_vars['batch_idx'] = batch_idx
        if batch_idx < control_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " +
                          str(batch_idx + 1) + "/" + str(control_vars['iter_size']),
                          verbose, n_tabs=0, erase_line=True)
        # start time counter
        start = time.time()
        # get data and target as cuda variables
        target_heatmaps, target_joints, target_joints_z = target
        data, target_heatmaps = Variable(data), Variable(target_heatmaps)
        if valid_vars['use_cuda']:
            data = data.cuda()
            target_heatmaps = target_heatmaps.cuda()
        # visualize if debugging
        # get model output
        output = model(data)
        # accumulate loss for sub-mini-batch
        if valid_vars['cross_entropy']:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        loss = my_losses.calculate_loss_HALNet(
            loss_func, output, target_heatmaps, model.joint_ixs,
            model.WEIGHT_LOSS_INTERMED1, model.WEIGHT_LOSS_INTERMED2,
            model.WEIGHT_LOSS_INTERMED3, model.WEIGHT_LOSS_MAIN,
            control_vars['iter_size'])
        if DEBUG_VISUALLY:
            for i in range(control_vars['max_mem_batch']):
                filenamebase_idx = (batch_idx * control_vars['max_mem_batch']) + i
                filenamebase = valid_loader.dataset.get_filenamebase(filenamebase_idx)
                fig = visualize.create_fig()
                # visualize.plot_joints_from_heatmaps(output[3][i].data.numpy(), fig=fig,
                #                                     title=filenamebase, data=data[i].data.numpy())
                # visualize.plot_image_and_heatmap(output[3][i][8].data.numpy(),
                #                                  data=data[i].data.numpy(),
                #                                  title=filenamebase)
                # visualize.savefig('/home/paulo/' + filenamebase.replace('/', '_') + '_heatmap')
                labels_colorspace = conv.heatmaps_to_joints_colorspace(output[3][i].data.numpy())
                data_crop, crop_coords, labels_heatmaps, labels_colorspace = \
                    converter.crop_image_get_labels(data[i].data.numpy(), labels_colorspace, range(21))
                visualize.plot_image(data_crop, title=filenamebase, fig=fig)
                visualize.plot_joints_from_colorspace(labels_colorspace, title=filenamebase,
                                                      fig=fig, data=data_crop)
                # visualize.savefig('/home/paulo/' + filenamebase.replace('/', '_') + '_crop')
                visualize.show()
        # loss.backward()
        valid_vars['total_loss'] += loss
        # accumulate pixel dist loss for sub-mini-batch
        valid_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            valid_vars['total_pixel_loss'], output[3], target_heatmaps, control_vars['batch_size'])
        if valid_vars['cross_entropy']:
            valid_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
                valid_vars['total_pixel_loss_sample'], output[3], target_heatmaps, control_vars['batch_size'])
        else:
            valid_vars['total_pixel_loss_sample'] = [-1] * len(model.joint_ixs)
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx+1) % control_vars['iter_size'] == 0
        if minibatch_completed:
            # append total loss
            valid_vars['losses'].append(valid_vars['total_loss'].item())
            # erase total loss
            total_loss = valid_vars['total_loss'].item()
            valid_vars['total_loss'] = 0
            # append dist loss
            valid_vars['pixel_losses'].append(valid_vars['total_pixel_loss'])
            # erase pixel dist loss
            valid_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            valid_vars['pixel_losses_sample'].append(valid_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            valid_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            if valid_vars['losses'][-1] < valid_vars['best_loss']:
                valid_vars['best_loss'] = valid_vars['losses'][-1]
                # print_verbose(" This is a best loss found so far: " + str(valid_vars['losses'][-1]), verbose)
            # log checkpoint
            if control_vars['curr_iter'] % control_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, 1, total_loss, valid_vars, control_vars)
                model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': valid_vars,
                }
                trainer.save_checkpoint(model_dict,
                                        filename=valid_vars['checkpoint_filenamebase'] +
                                        str(control_vars['num_iter']) + '.pth.tar')
            # print time lapse
            prefix = 'Validating (Epoch #' + str(1) + ' ' + str(control_vars['curr_epoch_iter']) + '/' +\
                     str(control_vars['tot_iter']) + ')' + ', (Batch ' + str(control_vars['batch_idx']+1) +\
                     '(' + str(control_vars['iter_size']) + ')' + '/' +\
                     str(control_vars['num_batches']) + ')' + ', (Iter #' + str(control_vars['curr_iter']) +\
                     '(' + str(control_vars['batch_size']) + ')' +\
                     ' - log every ' + str(control_vars['log_interval']) + ' iter): '
            control_vars['tot_toc'] = display_est_time_loop(control_vars['tot_toc'] + time.time() - start,
                                                            control_vars['curr_iter'], control_vars['num_iter'],
                                                            prefix=prefix)
            control_vars['curr_iter'] += 1
            control_vars['start_iter'] = control_vars['curr_iter'] + 1
            control_vars['curr_epoch_iter'] += 1
    return valid_vars, control_vars

def train():
    parser = argparse.ArgumentParser(description='Parameters for training Model')
    # configuration file
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt)

    set_logger.setup_logger(opt['logger']['name'], opt['logger']['path'],
                            screen=opt['logger']['screen'],
                            tofile=opt['logger']['tofile'])
    logger = logging.getLogger(opt['logger']['name'])
    day_time = datetime.date.today().strftime('%y%m%d')

    # build model
    model = opt['model']['MODEL']
    logger.info("Building the model of {}".format(model))
    # Extraction and Suppression model
    if opt['model']['MODEL'] == 'DPRNN_Speaker_Extraction' or opt['model']['MODEL'] == 'DPRNN_Speaker_Suppression':
        net = model_function.Extractin_Suppression_Model(**opt['Dual_Path_Aux_Speaker'])
    # Separation model
    if opt['model']['MODEL'] == 'DPRNN_Speech_Separation':
        net = model_function.Speech_Serapation_Model(**opt['Dual_Path_Aux_Speaker'])

    if opt['train']['gpuid']:
        if len(opt['train']['gpuid']) > 1:
            logger.info('We use GPUs : {}'.format(opt['train']['gpuid']))
        else:
            logger.info('We use GPUs : {}'.format(opt['train']['gpuid']))
        device = torch.device('cuda:{}'.format(opt['train']['gpuid'][0]))
        gpuids = opt['train']['gpuid']
        if len(gpuids) > 1:
            net = torch.nn.DataParallel(net, device_ids=gpuids)
        net = net.to(device)

    logger.info('Loading {} parameters: {:.3f} Mb'.format(model, check_parameters(net)))

    # build optimizer
    logger.info("Building the optimizer of {}".format(model))
    Optimizer = make_optimizer(net.parameters(), opt)
    Scheduler = ReduceLROnPlateau(Optimizer, mode='min',
                                  factor=opt['scheduler']['factor'],
                                  patience=opt['scheduler']['patience'],
                                  verbose=True, min_lr=opt['scheduler']['min_lr'])

    # build dataloader
    logger.info('Building the dataloader of {}'.format(model))
    train_dataloader, val_dataloader = make_dataloader(opt)
    logger.info('Train Datasets Length: {}, Val Datasets Length: {}'.format(
        len(train_dataloader), len(val_dataloader)))

    # build trainer
    logger.info('............. Training ................')
    total_epoch = opt['train']['epoch']
    num_spks = opt['num_spks']
    print_freq = opt['logger']['print_freq']
    checkpoint_path = opt['train']['path']
    early_stop = opt['train']['early_stop']
    max_norm = opt['optim']['clip_norm']
    best_loss = np.inf
    no_improve = 0
    ce_loss = torch.nn.CrossEntropyLoss()
    weight = 0.1
    epoch = 0

    # resume training settings
    if opt['resume']['state']:
        opt['resume']['path'] = opt['resume']['path'] + '/' + '200722_epoch{}.pth.tar'.format(opt['resume']['epoch'])
        ckp = torch.load(opt['resume']['path'], map_location='cpu')
        epoch = ckp['epoch']
        logger.info("Resume from checkpoint {}: epoch {:.3f}".format(opt['resume']['path'], epoch))
        net.load_state_dict(ckp['model_state_dict'])
        net.to(device)
        Optimizer.load_state_dict(ckp['optim_state_dict'])

    while epoch < total_epoch:
        epoch += 1
        logger.info('Start training from epoch: {:d}, iter: {:d}'.format(epoch, 0))
        num_steps = len(train_dataloader)

        # training process
        total_SNRloss = 0.0
        total_CEloss = 0.0
        num_index = 1
        start_time = time.time()
        for inputs, targets in train_dataloader:
            # Separation train
            if opt['model']['MODEL'] == 'DPRNN_Speech_Separation':
                mix = inputs
                ref = targets
                net.train()
                mix = mix.to(device)
                ref = [ref[i].to(device) for i in range(num_spks)]
                net.zero_grad()
                train_out = net(mix)
                SNR_loss = Loss(train_out, ref)
                loss = SNR_loss
            # Extraction train
            if opt['model']['MODEL'] == 'DPRNN_Speaker_Extraction':
                mix, aux = inputs
                ref, aux_len, sp_label = targets
                net.train()
                mix = mix.to(device)
                aux = aux.to(device)
                ref = ref.to(device)
                aux_len = aux_len.to(device)
                sp_label = sp_label.to(device)
                net.zero_grad()
                train_out = net([mix, aux, aux_len])
                SNR_loss = Loss_SI_SDR(train_out[0], ref)
                CE_loss = torch.mean(ce_loss(train_out[1], sp_label))
                loss = SNR_loss + weight * CE_loss
                total_CEloss += CE_loss.item()
            # Suppression train
            if opt['model']['MODEL'] == 'DPRNN_Speaker_Suppression':
                mix, aux = inputs
                ref, aux_len = targets
                net.train()
                mix = mix.to(device)
                aux = aux.to(device)
                ref = ref.to(device)
                aux_len = aux_len.to(device)
                net.zero_grad()
                train_out = net([mix, aux, aux_len])
                SNR_loss = Loss_SI_SDR(train_out[0], ref)
                loss = SNR_loss

            # backpropagation process
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm)
            Optimizer.step()
            total_SNRloss += SNR_loss.item()

            if num_index % print_freq == 0:
                message = '<Training epoch:{:d} / {:d} , iter:{:d} / {:d}, lr:{:.3e}, ' \
                          'SI-SNR_loss:{:.3f}, CE loss:{:.3f}>'.format(
                              epoch, total_epoch, num_index, num_steps,
                              Optimizer.param_groups[0]['lr'],
                              total_SNRloss / num_index, total_CEloss / num_index)
                logger.info(message)
            num_index += 1

        end_time = time.time()
        mean_SNRLoss = total_SNRloss / num_index
        mean_CELoss = total_CEloss / num_index
        message = 'Finished Training *** <epoch:{:d} / {:d}, iter:{:d}, lr:{:.3e}, ' \
                  'SNR loss:{:.3f}, CE loss:{:.3f}, Total time:{:.3f} min> '.format(
                      epoch, total_epoch, num_index, Optimizer.param_groups[0]['lr'],
                      mean_SNRLoss, mean_CELoss, (end_time - start_time) / 60)
        logger.info(message)

        # development (validation) process
        val_num_index = 1
        val_total_loss = 0.0
        val_CE_loss = 0.0
        val_acc_total = 0.0
        val_acc = 0.0
        val_start_time = time.time()
        val_num_steps = len(val_dataloader)
        for inputs, targets in val_dataloader:
            net.eval()
            with torch.no_grad():
                # Separation development
                if opt['model']['MODEL'] == 'DPRNN_Speech_Separation':
                    mix = inputs
                    ref = targets
                    mix = mix.to(device)
                    ref = [ref[i].to(device) for i in range(num_spks)]
                    Optimizer.zero_grad()
                    val_out = net(mix)
                    val_loss = Loss(val_out, ref)
                    val_total_loss += val_loss.item()
                # Extraction development
                if opt['model']['MODEL'] == 'DPRNN_Speaker_Extraction':
                    mix, aux = inputs
                    ref, aux_len, label = targets
                    mix = mix.to(device)
                    aux = aux.to(device)
                    ref = ref.to(device)
                    aux_len = aux_len.to(device)
                    label = label.to(device)
                    Optimizer.zero_grad()
                    val_out = net([mix, aux, aux_len])
                    val_loss = Loss_SI_SDR(val_out[0], ref)
                    val_ce = torch.mean(ce_loss(val_out[1], label))
                    val_acc = accuracy_speaker(val_out[1], label)
                    val_acc_total += val_acc
                    val_total_loss += val_loss.item()
                    val_CE_loss += val_ce.item()
                # Suppression development
                if opt['model']['MODEL'] == 'DPRNN_Speaker_Suppression':
                    mix, aux = inputs
                    ref, aux_len = targets
                    mix = mix.to(device)
                    aux = aux.to(device)
                    ref = ref.to(device)
                    aux_len = aux_len.to(device)
                    Optimizer.zero_grad()
                    val_out = net([mix, aux, aux_len])
                    val_loss = Loss_SI_SDR(val_out[0], ref)
                    val_total_loss += val_loss.item()

                if val_num_index % print_freq == 0:
                    message = '<Valid-Epoch:{:d} / {:d}, iter:{:d} / {:d}, lr:{:.3e}, ' \
                              'val_SISNR_loss:{:.3f}, val_CE_loss:{:.3f}, val_acc :{:.3f}>'.format(
                                  epoch, total_epoch, val_num_index, val_num_steps,
                                  Optimizer.param_groups[0]['lr'],
                                  val_total_loss / val_num_index,
                                  val_CE_loss / val_num_index,
                                  val_acc_total / val_num_index)
                    logger.info(message)
                val_num_index += 1

        val_end_time = time.time()
        mean_val_total_loss = val_total_loss / val_num_index
        mean_val_CE_loss = val_CE_loss / val_num_index
        mean_acc = val_acc_total / val_num_index
        message = 'Finished *** <epoch:{:d}, iter:{:d}, lr:{:.3e}, val SI-SNR loss:{:.3f}, ' \
                  'val_CE_loss:{:.3f}, val_acc:{:.3f} Total time:{:.3f} min> '.format(
                      epoch, val_num_index, Optimizer.param_groups[0]['lr'],
                      mean_val_total_loss, mean_val_CE_loss, mean_acc,
                      (val_end_time - val_start_time) / 60)
        logger.info(message)

        Scheduler.step(mean_val_total_loss)

        if mean_val_total_loss >= best_loss:
            no_improve += 1
            logger.info('No improvement, Best SI-SNR Loss: {:.4f}'.format(best_loss))
        if mean_val_total_loss < best_loss:
            best_loss = mean_val_total_loss
            no_improve = 0
            save_checkpoint(epoch, checkpoint_path, net, Optimizer, day_time)
            logger.info('Epoch: {:d}, Now Best SI-SNR Loss Change: {:.4f}'.format(epoch, best_loss))
        if no_improve == early_stop:
            save_checkpoint(epoch, checkpoint_path, net, Optimizer, day_time)
            logger.info("Stop training cause no impr for {:d} epochs".format(no_improve))
            break

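# The training loop above calls a `save_checkpoint(epoch, checkpoint_path, net, Optimizer, day_time)`
# helper that is not shown. A minimal sketch of what it might do: the saved keys mirror what the
# resume branch reads ('epoch', 'model_state_dict', 'optim_state_dict'), and the filename pattern
# is an assumption based on the '200722_epoch{}.pth.tar' path built during resume.
import os
import torch


def save_checkpoint(epoch, checkpoint_path, net, optimizer, day_time):
    os.makedirs(checkpoint_path, exist_ok=True)
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optim_state_dict': optimizer.state_dict(),
        },
        os.path.join(checkpoint_path, '{}_epoch{}.pth.tar'.format(day_time, epoch)))
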
    # Hyperparams that have been found to work well
    tconf = trainer.TrainerConfig(max_epochs=650,
                                  batch_size=128,
                                  learning_rate=6e-3,
                                  lr_decay=True,
                                  warmup_tokens=512 * 20,
                                  final_tokens=200 * len(pretrain_dataset) * block_size,
                                  num_workers=4,
                                  ckpt_path=args.writing_params_path)
    # Initiate trainer, train, then save params of model
    trainer = trainer.Trainer(model, pretrain_dataset, None, tconf)
    trainer.train()
    trainer.save_checkpoint()
elif args.function == 'finetune':
    assert args.writing_params_path is not None
    assert args.finetune_corpus_path is not None
    # - Given:
    #     1. A finetuning corpus specified in args.finetune_corpus_path
    #     2. A path args.reading_params_path containing pretrained model
    #        parameters, or None if finetuning without a pretrained model
    #     3. An output path args.writing_params_path for the model parameters
    # - Goals:
    #     1. If args.reading_params_path is specified, load these parameters
    #        into the model
    #     2. Finetune the model on this corpus
    #     3. Save the resulting model in args.writing_params_path

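    # A sketch of what the finetune branch described above might look like, assuming the
    # same trainer.Trainer / trainer.TrainerConfig interface as the pretraining branch and
    # that torch is imported at module level. `build_finetune_dataset` is a hypothetical
    # helper for turning the corpus text into a dataset, and the hyperparameters below are
    # placeholders rather than reference values.
    if args.reading_params_path is not None:
        model.load_state_dict(torch.load(args.reading_params_path))
    finetune_text = open(args.finetune_corpus_path, encoding='utf-8').read()
    finetune_dataset = build_finetune_dataset(finetune_text, pretrain_dataset)
    tconf = trainer.TrainerConfig(max_epochs=10,
                                  batch_size=256,
                                  learning_rate=6e-4,
                                  lr_decay=True,
                                  warmup_tokens=512 * 20,
                                  final_tokens=200 * len(pretrain_dataset) * block_size,
                                  num_workers=4,
                                  ckpt_path=args.writing_params_path)
    finetune_trainer = trainer.Trainer(model, finetune_dataset, None, tconf)
    finetune_trainer.train()
    finetune_trainer.save_checkpoint()
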
def train(train_loader, model, optimizer, train_vars):
    verbose = train_vars['verbose']
    for batch_idx, (data, target) in enumerate(train_loader):
        train_vars['batch_idx'] = batch_idx
        # print info about performing first iter
        if batch_idx < train_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " +
                          str(batch_idx + 1) + "/" + str(train_vars['iter_size']),
                          verbose, n_tabs=0, erase_line=True)
        # check if arrived at iter to start
        arrived_curr_iter, train_vars = run_until_curr_iter(batch_idx, train_vars)
        if not arrived_curr_iter:
            continue
        # save checkpoint after final iteration
        if train_vars['curr_iter'] - 1 == train_vars['num_iter']:
            train_vars = trainer.save_final_checkpoint(train_vars, model, optimizer)
            break
        # start time counter
        start = time.time()
        # get data and target as torch Variables
        _, target_joints, target_heatmaps, target_joints_z = target
        # make target joints be relative
        target_joints = target_joints[:, 3:]
        data, target_heatmaps = Variable(data), Variable(target_heatmaps)
        if train_vars['use_cuda']:
            data = data.cuda()
            target_heatmaps = target_heatmaps.cuda()
            target_joints = target_joints.cuda()
            target_joints_z = target_joints_z.cuda()
        # get model output
        output = model(data)
        # accumulate loss for sub-mini-batch
        if train_vars['cross_entropy']:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        weights_heatmaps_loss, weights_joints_loss = get_loss_weights(train_vars['curr_iter'])
        loss, loss_heatmaps, loss_joints = my_losses.calculate_loss_JORNet(
            loss_func, output, target_heatmaps, target_joints, train_vars['joint_ixs'],
            weights_heatmaps_loss, weights_joints_loss, train_vars['iter_size'])
        loss.backward()
        train_vars['total_loss'] += loss.item()
        train_vars['total_joints_loss'] += loss_joints.item()
        train_vars['total_heatmaps_loss'] += loss_heatmaps.item()
        # accumulate pixel dist loss for sub-mini-batch
        train_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            train_vars['total_pixel_loss'], output[3], target_heatmaps, train_vars['batch_size'])
        if train_vars['cross_entropy']:
            train_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
                train_vars['total_pixel_loss_sample'], output[3], target_heatmaps, train_vars['batch_size'])
        else:
            train_vars['total_pixel_loss_sample'] = [-1] * len(model.joint_ixs)
        ''' For debugging training
        for i in range(train_vars['max_mem_batch']):
            filenamebase_idx = (batch_idx * train_vars['max_mem_batch']) + i
            filenamebase = train_loader.dataset.get_filenamebase(filenamebase_idx)
            visualize.plot_joints_from_heatmaps(target_heatmaps[i].data.cpu().numpy(),
                                                title='GT joints: ' + filenamebase,
                                                data=data[i].data.cpu().numpy())
            visualize.plot_joints_from_heatmaps(output[3][i].data.cpu().numpy(),
                                                title='Pred joints: ' + filenamebase,
                                                data=data[i].data.cpu().numpy())
            visualize.plot_image_and_heatmap(output[3][i][4].data.numpy(),
                                             data=data[i].data.numpy(),
                                             title='Thumb tip heatmap: ' + filenamebase)
            visualize.show()
        '''
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx + 1) % train_vars['iter_size'] == 0
        if minibatch_completed:
            # visualize
            # ax, fig = visualize.plot_3D_joints(target_joints[0])
            # visualize.plot_3D_joints(target_joints[1], ax=ax, fig=fig)
            if train_vars['curr_iter'] % train_vars['log_interval'] == 0:
                fig, ax = visualize.plot_3D_joints(target_joints[0])
                visualize.savefig('joints_GT_' + str(train_vars['curr_iter']) + '.png')
                # visualize.plot_3D_joints(target_joints[1], fig=fig, ax=ax, color_root='C7')
                # visualize.plot_3D_joints(output[7].data.cpu().numpy()[0], fig=fig, ax=ax, color_root='C7')
                visualize.plot_3D_joints(output[7].data.cpu().numpy()[0])
                visualize.savefig('joints_model_' + str(train_vars['curr_iter']) + '.png')
                # visualize.show()
                # visualize.savefig('joints_' + str(train_vars['curr_iter']) + '.png')
            # change learning rate to 0.01 after 45000 iterations
            optimizer = change_learning_rate(optimizer, 0.01, train_vars['curr_iter'])
            # optimise for mini-batch
            optimizer.step()
            # clear optimiser
            optimizer.zero_grad()
            # append total loss
            train_vars['losses'].append(train_vars['total_loss'])
            # erase total loss
            total_loss = train_vars['total_loss']
            train_vars['total_loss'] = 0
            # append total joints loss
            train_vars['losses_joints'].append(train_vars['total_joints_loss'])
            # erase total joints loss
            train_vars['total_joints_loss'] = 0
            # append total heatmaps loss
            train_vars['losses_heatmaps'].append(train_vars['total_heatmaps_loss'])
            # erase total heatmaps loss
            train_vars['total_heatmaps_loss'] = 0
            # append dist loss
            train_vars['pixel_losses'].append(train_vars['total_pixel_loss'])
            # erase pixel dist loss
            train_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            train_vars['pixel_losses_sample'].append(train_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            train_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            if train_vars['losses'][-1] < train_vars['best_loss']:
                train_vars['best_loss'] = train_vars['losses'][-1]
                print_verbose(" This is a best loss found so far: " + str(train_vars['losses'][-1]), verbose)
                train_vars['best_model_dict'] = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_vars': train_vars
                }
            # log checkpoint
            if train_vars['curr_iter'] % train_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, epoch, total_loss, train_vars, train_vars)
                aa1 = target_joints[0].data.cpu().numpy()
                aa2 = output[7][0].data.cpu().numpy()
                output_joint_loss = np.sum(np.abs(aa1 - aa2)) / 63
                msg = ''
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                msg += print_verbose('\tJoint Coord Avg Loss for first image of current mini-batch: ' +
                                     str(output_joint_loss) + '\n', train_vars['verbose'])
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                if not train_vars['output_filepath'] == '':
                    with open(train_vars['output_filepath'], 'a') as f:
                        f.write(msg + '\n')
            if train_vars['curr_iter'] % train_vars['log_interval_valid'] == 0:
                print_verbose("\nSaving model and checkpoint model for validation", verbose)
                checkpoint_model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_vars': train_vars,
                }
                trainer.save_checkpoint(checkpoint_model_dict,
                                        filename=train_vars['checkpoint_filenamebase'] +
                                        'for_valid_' + str(train_vars['curr_iter']) + '.pth.tar')
            # print time lapse
            prefix = 'Training (Epoch #' + str(epoch) + ' ' + str(train_vars['curr_epoch_iter']) + '/' +\
                     str(train_vars['tot_iter']) + ')' + ', (Batch ' + str(train_vars['batch_idx']+1) +\
                     '(' + str(train_vars['iter_size']) + ')' + '/' +\
                     str(train_vars['num_batches']) + ')' + ', (Iter #' + str(train_vars['curr_iter']) +\
                     '(' + str(train_vars['batch_size']) + ')' +\
                     ' - log every ' + str(train_vars['log_interval']) + ' iter): '
            train_vars['tot_toc'] = display_est_time_loop(train_vars['tot_toc'] + time.time() - start,
                                                          train_vars['curr_iter'], train_vars['num_iter'],
                                                          prefix=prefix)
            train_vars['curr_iter'] += 1
            train_vars['start_iter'] = train_vars['curr_iter'] + 1
            train_vars['curr_epoch_iter'] += 1
    return train_vars

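# The loop above assumes a `change_learning_rate` helper that switches the optimizer to a
# fixed learning rate once a given iteration is reached. A minimal sketch consistent with
# the comment "change learning rate to 0.01 after 45000 iterations"; the 45000 threshold is
# taken from that comment and is otherwise an assumption.
def change_learning_rate(optimizer, new_lr, curr_iter, switch_iter=45000):
    # Only touch the optimizer once the switch iteration has been reached.
    if curr_iter >= switch_iter:
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
    return optimizer
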
def validate(valid_loader, model, optimizer, valid_vars, control_vars, verbose=True):
    curr_epoch_iter = 1
    for batch_idx, (data, target) in enumerate(valid_loader):
        control_vars['batch_idx'] = batch_idx
        if batch_idx < control_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " +
                          str(batch_idx + 1) + "/" + str(control_vars['iter_size']),
                          verbose, n_tabs=0, erase_line=True)
        # start time counter
        start = time.time()
        # get data and target as cuda variables
        target_heatmaps, target_joints, target_handroot = target
        # make target joints be relative
        target_joints = target_joints[:, 3:]
        data, target_heatmaps = Variable(data), Variable(target_heatmaps)
        if valid_vars['use_cuda']:
            data = data.cuda()
            target_joints = target_joints.cuda()
            target_heatmaps = target_heatmaps.cuda()
            target_handroot = target_handroot.cuda()
        # visualize if debugging
        # get model output
        output = model(data)
        # accumulate loss for sub-mini-batch
        if model.cross_entropy:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        weights_heatmaps_loss, weights_joints_loss = get_loss_weights(control_vars['curr_iter'])
        loss, loss_heatmaps, loss_joints = my_losses.calculate_loss_JORNet(
            loss_func, output, target_heatmaps, target_joints, valid_vars['joint_ixs'],
            weights_heatmaps_loss, weights_joints_loss, control_vars['iter_size'])
        valid_vars['total_loss'] += loss
        valid_vars['total_joints_loss'] += loss_joints
        valid_vars['total_heatmaps_loss'] += loss_heatmaps
        # accumulate pixel dist loss for sub-mini-batch
        valid_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            valid_vars['total_pixel_loss'], output[3], target_heatmaps, control_vars['batch_size'])
        valid_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
            valid_vars['total_pixel_loss_sample'], output[3], target_heatmaps, control_vars['batch_size'])
        # debug: print and plot predictions for each sample of the sub-mini-batch
        for i in range(control_vars['max_mem_batch']):
            filenamebase_idx = (batch_idx * control_vars['max_mem_batch']) + i
            filenamebase = valid_loader.dataset.get_filenamebase(filenamebase_idx)
            print('')
            print(filenamebase)
            visualize.plot_image(data[i].data.numpy())
            visualize.show()
            output_batch_numpy = output[7][i].data.cpu().numpy()
            print('\n-------------------------------')
            reshaped_out = output_batch_numpy.reshape((20, 3))
            for j in range(20):
                print('[{}, {}, {}],'.format(reshaped_out[j, 0], reshaped_out[j, 1], reshaped_out[j, 2]))
            print('-------------------------------')
            fig, ax = visualize.plot_3D_joints(target_joints[i])
            visualize.plot_3D_joints(output_batch_numpy, fig=fig, ax=ax, color='C6')
            visualize.title(filenamebase)
            visualize.show()
            temp = np.zeros((21, 3))
            output_batch_numpy_abs = output_batch_numpy.reshape((20, 3))
            temp[1:, :] = output_batch_numpy_abs
            output_batch_numpy_abs = temp
            output_joints_colorspace = camera.joints_depth2color(
                output_batch_numpy_abs,
                depth_intr_matrix=synthhands_handler.DEPTH_INTR_MTX,
                handroot=target_handroot[i].data.cpu().numpy())
            visualize.plot_3D_joints(output_joints_colorspace)
            visualize.show()
            aa1 = target_joints[i].data.cpu().numpy().reshape((20, 3))
            aa2 = output[7][i].data.cpu().numpy().reshape((20, 3))
            print('\n----------------------------------')
            print(np.sum(np.abs(aa1 - aa2)) / 60)
            print('----------------------------------')
        # loss.backward()
        valid_vars['total_loss'] += loss
        valid_vars['total_joints_loss'] += loss_joints
        valid_vars['total_heatmaps_loss'] += loss_heatmaps
        # accumulate pixel dist loss for sub-mini-batch
        valid_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            valid_vars['total_pixel_loss'], output[3], target_heatmaps, control_vars['batch_size'])
        valid_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
            valid_vars['total_pixel_loss_sample'], output[3], target_heatmaps, control_vars['batch_size'])
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx + 1) % control_vars['iter_size'] == 0
        if minibatch_completed:
            # append total loss
            valid_vars['losses'].append(valid_vars['total_loss'].data[0])
            # erase total loss
            total_loss = valid_vars['total_loss'].data[0]
            valid_vars['total_loss'] = 0
            # append total joints loss
            valid_vars['losses_joints'].append(valid_vars['total_joints_loss'].data[0])
            # erase total joints loss
            valid_vars['total_joints_loss'] = 0
            # append total heatmaps loss
            valid_vars['losses_heatmaps'].append(valid_vars['total_heatmaps_loss'].data[0])
            # erase total heatmaps loss
            valid_vars['total_heatmaps_loss'] = 0
            # append dist loss
            valid_vars['pixel_losses'].append(valid_vars['total_pixel_loss'])
            # erase pixel dist loss
            valid_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            valid_vars['pixel_losses_sample'].append(valid_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            valid_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            # if valid_vars['losses'][-1] < valid_vars['best_loss']:
            #     valid_vars['best_loss'] = valid_vars['losses'][-1]
            #     print_verbose(" This is a best loss found so far: " + str(valid_vars['losses'][-1]), verbose)
            # log checkpoint
            if control_vars['curr_iter'] % control_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, 1, total_loss, valid_vars, control_vars)
                model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': valid_vars,
                }
                trainer.save_checkpoint(model_dict,
                                        filename=valid_vars['checkpoint_filenamebase'] +
                                        str(control_vars['num_iter']) + '.pth.tar')
            # print time lapse
            prefix = 'Validating (Epoch #' + str(1) + ' ' + str(control_vars['curr_epoch_iter']) + '/' +\
                     str(control_vars['tot_iter']) + ')' + ', (Batch ' + str(control_vars['batch_idx']+1) +\
                     '(' + str(control_vars['iter_size']) + ')' + '/' +\
                     str(control_vars['num_batches']) + ')' + ', (Iter #' + str(control_vars['curr_iter']) +\
                     '(' + str(control_vars['batch_size']) + ')' +\
                     ' - log every ' + str(control_vars['log_interval']) + ' iter): '
            control_vars['tot_toc'] = display_est_time_loop(control_vars['tot_toc'] + time.time() - start,
                                                            control_vars['curr_iter'], control_vars['num_iter'],
                                                            prefix=prefix)
            control_vars['curr_iter'] += 1
            control_vars['start_iter'] = control_vars['curr_iter'] + 1
            control_vars['curr_epoch_iter'] += 1
    return valid_vars, control_vars

def train(train_loader, model, optimizer, train_vars):
    verbose = train_vars['verbose']
    for batch_idx, (data, target) in enumerate(train_loader):
        train_vars['batch_idx'] = batch_idx
        # print info about performing first iter
        if batch_idx < train_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " +
                          str(batch_idx+1) + "/" + str(train_vars['iter_size']),
                          verbose, n_tabs=0, erase_line=True)
        # check if arrived at iter to start
        arrived_curr_iter, train_vars = run_until_curr_iter(batch_idx, train_vars)
        if not arrived_curr_iter:
            continue
        # save checkpoint after final iteration
        if train_vars['curr_iter'] - 1 == train_vars['num_iter']:
            train_vars = save_final_checkpoint(train_vars, model, optimizer)
            break
        # start time counter
        start = time.time()
        # get data and target as torch Variables
        _, target_joints, target_heatmaps, target_joints_z = target
        data, target_heatmaps = Variable(data), Variable(target_heatmaps)
        if train_vars['use_cuda']:
            data = data.cuda()
            target_heatmaps = target_heatmaps.cuda()
        # get model output
        output = model(data)
        # accumulate loss for sub-mini-batch
        if model.cross_entropy:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        loss = my_losses.calculate_loss_HALNet(
            loss_func, output, target_heatmaps, model.joint_ixs,
            model.WEIGHT_LOSS_INTERMED1, model.WEIGHT_LOSS_INTERMED2,
            model.WEIGHT_LOSS_INTERMED3, model.WEIGHT_LOSS_MAIN,
            train_vars['iter_size'])
        loss.backward()
        train_vars['total_loss'] += loss
        # accumulate pixel dist loss for sub-mini-batch
        train_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            train_vars['total_pixel_loss'], output[3], target_heatmaps, train_vars['batch_size'])
        if train_vars['cross_entropy']:
            train_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
                train_vars['total_pixel_loss_sample'], output[3], target_heatmaps, train_vars['batch_size'])
        else:
            train_vars['total_pixel_loss_sample'] = [-1] * len(model.joint_ixs)
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx+1) % train_vars['iter_size'] == 0
        if minibatch_completed:
            # optimise for mini-batch
            optimizer.step()
            # clear optimiser
            optimizer.zero_grad()
            # append total loss
            train_vars['losses'].append(train_vars['total_loss'].item())
            # erase total loss
            total_loss = train_vars['total_loss'].item()
            train_vars['total_loss'] = 0
            # append dist loss
            train_vars['pixel_losses'].append(train_vars['total_pixel_loss'])
            # erase pixel dist loss
            train_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            train_vars['pixel_losses_sample'].append(train_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            train_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            if train_vars['losses'][-1] < train_vars['best_loss']:
                train_vars['best_loss'] = train_vars['losses'][-1]
                print_verbose(" This is a best loss found so far: " + str(train_vars['losses'][-1]), verbose)
                train_vars['best_model_dict'] = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_vars': train_vars
                }
            # log checkpoint
            if train_vars['curr_iter'] % train_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, epoch, total_loss, train_vars, train_vars)
            if train_vars['curr_iter'] % train_vars['log_interval_valid'] == 0:
                print_verbose("\nSaving model and checkpoint model for validation", verbose)
                checkpoint_model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_vars': train_vars,
                }
                trainer.save_checkpoint(checkpoint_model_dict,
                                        filename=train_vars['checkpoint_filenamebase'] +
                                        'for_valid_' + str(train_vars['curr_iter']) + '.pth.tar')
            # print time lapse
            prefix = 'Training (Epoch #' + str(epoch) + ' ' + str(train_vars['curr_epoch_iter']) + '/' +\
                     str(train_vars['tot_iter']) + ')' + ', (Batch ' + str(train_vars['batch_idx']+1) +\
                     '(' + str(train_vars['iter_size']) + ')' + '/' +\
                     str(train_vars['num_batches']) + ')' + ', (Iter #' + str(train_vars['curr_iter']) +\
                     '(' + str(train_vars['batch_size']) + ')' +\
                     ' - log every ' + str(train_vars['log_interval']) + ' iter): '
            train_vars['tot_toc'] = display_est_time_loop(train_vars['tot_toc'] + time.time() - start,
                                                          train_vars['curr_iter'], train_vars['num_iter'],
                                                          prefix=prefix)
            train_vars['curr_iter'] += 1
            train_vars['start_iter'] = train_vars['curr_iter'] + 1
            train_vars['curr_epoch_iter'] += 1
    return train_vars

def main_func(args):
    cdf = mc.ConfidenceDepthFrameworkFactory()

    val_loader, _ = df.create_data_loaders(args.data_path,
                                           loader_type='val',
                                           data_type=args.data_type,
                                           modality=args.data_modality,
                                           num_samples=args.num_samples,
                                           depth_divisor=args.divider,
                                           max_depth=args.max_depth,
                                           max_gt_depth=args.max_gt_depth,
                                           workers=args.workers,
                                           batch_size=1)
    if not args.evaluate:
        train_loader, _ = df.create_data_loaders(args.data_path,
                                                 loader_type='train',
                                                 data_type=args.data_type,
                                                 modality=args.data_modality,
                                                 num_samples=args.num_samples,
                                                 depth_divisor=args.divider,
                                                 max_depth=args.max_depth,
                                                 max_gt_depth=args.max_gt_depth,
                                                 workers=args.workers,
                                                 batch_size=args.batch_size)

    # evaluation mode
    if args.evaluate:
        cdfmodel, loss, epoch = trainer.resume(args.evaluate, cdf, True)
        output_directory = create_eval_output_folder(args)
        os.makedirs(output_directory)
        print(output_directory)
        save_arguments(args, output_directory)
        trainer.validate(val_loader, cdfmodel, loss, epoch,
                         print_frequency=args.print_freq,
                         num_image_samples=args.val_images,
                         output_folder=output_directory,
                         conf_recall=args.pr, conf_threshold=args.thrs)
        return

    output_directory = create_output_folder(args)
    os.makedirs(output_directory)
    print(output_directory)
    save_arguments(args, output_directory)

    # optionally resume from a checkpoint
    if args.resume:
        cdfmodel, loss, loss_definition, best_result_error, optimizer, scheduler = trainer.resume(args.resume, cdf, False)
    # create new model
    else:
        cdfmodel = cdf.create_model(args.dcnet_modality, args.training_mode, args.dcnet_arch,
                                    args.dcnet_pretrained, args.confnet_arch, args.confnet_pretrained,
                                    args.lossnet_arch, args.lossnet_pretrained)
        cdfmodel, opt_parameters = cdf.to_device(cdfmodel)
        optimizer, scheduler = trainer.create_optimizer(args.optimizer, opt_parameters, args.momentum,
                                                        args.weight_decay, args.lr, args.lrs, args.lrm)
        loss, loss_definition = cdf.create_loss(args.criterion, ('ln' in args.training_mode),
                                                (0.5 if 'dc1' in args.training_mode else 1.0))
        best_result_error = math.inf

    for epoch in range(0, args.epochs):
        trainer.train(train_loader, cdfmodel, loss, optimizer, output_directory, epoch)
        epoch_result = trainer.validate(val_loader, cdfmodel, loss, epoch=epoch,
                                        print_frequency=args.print_freq,
                                        num_image_samples=args.val_images,
                                        output_folder=output_directory)
        scheduler.step(epoch)
        is_best = epoch_result.rmse < best_result_error
        if is_best:
            best_result_error = epoch_result.rmse
            trainer.report_top_result(os.path.join(output_directory, 'best_result.txt'), epoch, epoch_result)
            # if img_merge is not None:
            #     img_filename = output_directory + '/comparison_best.png'
            #     utils.save_image(img_merge, img_filename)
        trainer.save_checkpoint(cdf, cdfmodel, loss_definition, optimizer, scheduler,
                                best_result_error, is_best, epoch, output_directory)

def train(train_loader, model, optimizer, train_vars, control_vars, verbose=True):
    curr_epoch_iter = 1
    for batch_idx, (data, target) in enumerate(train_loader):
        control_vars['batch_idx'] = batch_idx
        if batch_idx < control_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " +
                          str(batch_idx+1) + "/" + str(control_vars['iter_size']),
                          verbose, n_tabs=0, erase_line=True)
        # check if arrived at iter to start
        if control_vars['curr_epoch_iter'] < control_vars['start_iter_mod']:
            control_vars['curr_epoch_iter'] = control_vars['start_iter_mod']
            msg = ''
            if batch_idx % control_vars['iter_size'] == 0:
                msg += print_verbose("\rGoing through iterations to arrive at last one saved... " +
                                     str(int(control_vars['curr_epoch_iter']*100.0/control_vars['start_iter_mod'])) +
                                     "% of " + str(control_vars['start_iter_mod']) + " iterations (" +
                                     str(control_vars['curr_epoch_iter']) + "/" + str(control_vars['start_iter_mod']) + ")",
                                     verbose, n_tabs=0, erase_line=True)
                control_vars['curr_epoch_iter'] += 1
                control_vars['curr_iter'] += 1
                curr_epoch_iter += 1
                if not control_vars['output_filepath'] == '':
                    with open(control_vars['output_filepath'], 'a') as f:
                        f.write(msg + '\n')
            continue
        # save checkpoint after final iteration
        if control_vars['curr_iter'] == control_vars['num_iter']:
            print_verbose("\nReached final number of iterations: " + str(control_vars['num_iter']), verbose)
            print_verbose("\tSaving final model checkpoint...", verbose)
            final_model_dict = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'control_vars': control_vars,
                'train_vars': train_vars,
            }
            trainer.save_checkpoint(final_model_dict,
                                    filename=train_vars['checkpoint_filenamebase'] +
                                    'final' + str(control_vars['num_iter']) + '.pth.tar')
            control_vars['done_training'] = True
            break
        # start time counter
        start = time.time()
        # get data and target as cuda variables
        target_heatmaps, target_joints, target_roothand = target
        data, target_heatmaps, target_joints, target_roothand = Variable(data), Variable(target_heatmaps),\
            Variable(target_joints), Variable(target_roothand)
        if train_vars['use_cuda']:
            data = data.cuda()
            target_heatmaps = target_heatmaps.cuda()
            target_joints = target_joints.cuda()
        # get model output
        output = model(data)
        '''
        visualize.plot_joints_from_heatmaps(target_heatmaps[0, :, :, :].cpu().data.numpy(),
                                            title='', data=data[0].cpu().data.numpy())
        visualize.show()
        visualize.plot_image_and_heatmap(target_heatmaps[0][4].cpu().data.numpy(),
                                         data=data[0].cpu().data.numpy(), title='')
        visualize.show()
        visualize.plot_image_and_heatmap(output[3][0][4].cpu().data.numpy(),
                                         data=data[0].cpu().data.numpy(), title='')
        visualize.show()
        '''
        # accumulate loss for sub-mini-batch
        if train_vars['cross_entropy']:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        weights_heatmaps_loss, weights_joints_loss = get_loss_weights(control_vars['curr_iter'])
        loss, loss_heatmaps, loss_joints = my_losses.calculate_loss_JORNet(
            loss_func, output, target_heatmaps, target_joints, train_vars['joint_ixs'],
            weights_heatmaps_loss, weights_joints_loss, control_vars['iter_size'])
        loss.backward()
        train_vars['total_loss'] += loss.data[0]
        train_vars['total_joints_loss'] += loss_joints.data[0]
        train_vars['total_heatmaps_loss'] += loss_heatmaps.data[0]
        # accumulate pixel dist loss for sub-mini-batch
        train_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            train_vars['total_pixel_loss'], output[3], target_heatmaps, control_vars['batch_size'])
        train_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
            train_vars['total_pixel_loss_sample'], output[3], target_heatmaps, control_vars['batch_size'])
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx+1) % control_vars['iter_size'] == 0
        if minibatch_completed:
            # optimise for mini-batch
            optimizer.step()
            # clear optimiser
            optimizer.zero_grad()
            # append total loss
            train_vars['losses'].append(train_vars['total_loss'])
            # erase total loss
            total_loss = train_vars['total_loss']
            train_vars['total_loss'] = 0
            # append total joints loss
            train_vars['losses_joints'].append(train_vars['total_joints_loss'])
            # erase total joints loss
            train_vars['total_joints_loss'] = 0
            # append total heatmaps loss
            train_vars['losses_heatmaps'].append(train_vars['total_heatmaps_loss'])
            # erase total heatmaps loss
            train_vars['total_heatmaps_loss'] = 0
            # append dist loss
            train_vars['pixel_losses'].append(train_vars['total_pixel_loss'])
            # erase pixel dist loss
            train_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            train_vars['pixel_losses_sample'].append(train_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            train_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            if train_vars['losses'][-1] < train_vars['best_loss']:
                train_vars['best_loss'] = train_vars['losses'][-1]
                print_verbose(" This is a best loss found so far: " + str(train_vars['losses'][-1]), verbose)
                train_vars['best_model_dict'] = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': train_vars,
                }
            if train_vars['losses_joints'][-1] < train_vars['best_loss_joints']:
                train_vars['best_loss_joints'] = train_vars['losses_joints'][-1]
            if train_vars['losses_heatmaps'][-1] < train_vars['best_loss_heatmaps']:
                train_vars['best_loss_heatmaps'] = train_vars['losses_heatmaps'][-1]
            # log checkpoint
            if control_vars['curr_iter'] % control_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, epoch, total_loss, train_vars, control_vars)
                aa1 = target_joints[0].data.cpu().numpy()
                aa2 = output[7][0].data.cpu().numpy()
                output_joint_loss = np.sum(np.abs(aa1 - aa2)) / 63
                msg = ''
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                msg += print_verbose('\tJoint Coord Avg Loss for first image of current mini-batch: ' +
                                     str(output_joint_loss) + '\n', control_vars['verbose'])
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                if not control_vars['output_filepath'] == '':
                    with open(control_vars['output_filepath'], 'a') as f:
                        f.write(msg + '\n')
            if control_vars['curr_iter'] % control_vars['log_interval_valid'] == 0:
                print_verbose("\nSaving model and checkpoint model for validation", verbose)
                checkpoint_model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': train_vars,
                }
                trainer.save_checkpoint(checkpoint_model_dict,
                                        filename=train_vars['checkpoint_filenamebase'] +
                                        'for_valid_' + str(control_vars['curr_iter']) + '.pth.tar')
            # print time lapse
            prefix = 'Training (Epoch #' + str(epoch) + ' ' + str(control_vars['curr_epoch_iter']) + '/' +\
                     str(control_vars['tot_iter']) + ')' + ', (Batch ' + str(control_vars['batch_idx']+1) +\
                     '(' + str(control_vars['iter_size']) + ')' + '/' +\
                     str(control_vars['num_batches']) + ')' + ', (Iter #' + str(control_vars['curr_iter']) +\
                     '(' + str(control_vars['batch_size']) + ')' +\
                     ' - log every ' + str(control_vars['log_interval']) + ' iter): '
            control_vars['tot_toc'] = display_est_time_loop(control_vars['tot_toc'] + time.time() - start,
                                                            control_vars['curr_iter'], control_vars['num_iter'],
                                                            prefix=prefix)
            control_vars['curr_iter'] += 1
            control_vars['start_iter'] = control_vars['curr_iter'] + 1
            control_vars['curr_epoch_iter'] += 1
    return train_vars, control_vars

def main():
    args = arguments.parse()

    checkpoint = args.checkpoint if args.checkpoint else None
    model, params = get_network(args.arch, args.n_attrs, checkpoint=checkpoint,
                                base_frozen=args.freeze_base)
    criterion = get_criterion(loss_type=args.loss, args=args)
    optimizer = get_optimizer(params, fc_lr=float(args.lr),
                              opt_type=args.optimizer_type,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1,
                                          last_epoch=args.start_epoch - 1)

    if checkpoint:
        state = torch.load(checkpoint)
        model.load_state_dict(state["state_dict"])
        scheduler.load_state_dict(state['scheduler'])

    # Dataloader code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    logger.info("Setting up training data")
    train_loader = data.DataLoader(COCOAttributes(args.attributes,
                                                  args.train_ann,
                                                  train=True,
                                                  split='train2014',
                                                  transforms=train_transforms,
                                                  dataset_root=args.dataset_root),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.workers,
                                   pin_memory=True)

    logger.info("Setting up validation data")
    val_loader = data.DataLoader(COCOAttributes(args.attributes,
                                                args.val_ann,
                                                train=False,
                                                split='val2014',
                                                transforms=val_transforms,
                                                dataset_root=args.dataset_root),
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 pin_memory=True)

    best_prec1 = 0

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    logger.info("Beginning training...")
    for epoch in range(args.start_epoch, args.epochs):
        scheduler.step()

        # train for one epoch
        trainer.train(train_loader, model, criterion, optimizer, epoch, args.print_freq)

        # evaluate on validation set
        # trainer.validate(val_loader, model, criterion, epoch, args.print_freq)
        prec1 = 0

        # remember best prec@1 and save checkpoint
        best_prec1 = max(prec1, best_prec1)
        trainer.save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'loss': args.loss,
            'optimizer': args.optimizer_type,
            'state_dict': model.state_dict(),
            'scheduler': scheduler.state_dict(),
            'batch_size': args.batch_size,
            'best_prec1': best_prec1,
        }, args.save_dir, '{0}_{1}_checkpoint.pth.tar'.format(args.arch, args.loss).lower())

    logger.info('Finished Training')

    logger.info('Running evaluation')
    evaluator = evaluation.Evaluator(model, val_loader,
                                     batch_size=args.batch_size,
                                     name="{0}_{1}".format(args.arch, args.loss))
    with torch.no_grad():
        evaluator.evaluate()