def get_training_annotation(training_filepath, output_filepath, verbose=False):
    guess_num_lines = 1e6
    read_interval = 10000000
    num_joints = 21
    print_verbose("Training input file path: " + training_filepath, verbose)
    print_verbose("Testing if program can write to output: " + output_filepath, verbose)
    with open(output_filepath, 'wb') as f:
        pickle.dump([], f)
    joints = []
    image_names = []
    with open(training_filepath, 'r') as f:
        line = f.readline()
        curr_line_ix = 1
        tot_toc = 0
        while line:
            start = time.time()
            image_names.append(training_file_line_to_image_name(line))
            joints.append(training_file_line_to_numpy_array(line, num_joints))
            # periodically persist partial results
            if curr_line_ix % read_interval == 0:
                with open(output_filepath + '.pkl', 'wb') as pf:
                    pickle.dump([image_names, joints], pf)
            line = f.readline()
            curr_line_ix += 1
            tot_toc = display_est_time_loop(tot_toc + (time.time() - start),
                                            curr_line_ix, guess_num_lines,
                                            prefix='Line: ' + str(curr_line_ix) + ' ')
    with open(output_filepath, 'wb') as pf:
        pickle.dump([image_names, joints], pf)
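# display_est_time_loop is called above (and throughout the training loops) but is not
# defined in this file. A minimal sketch, assuming it prints a progress prefix with a
# rough estimate of the remaining time and returns the accumulated elapsed time; the
# real helper may format its output differently.
import sys
import time

def display_est_time_loop(tot_toc, curr_ix, total_ix, prefix=''):
    # average time per iteration so far times the remaining iterations gives a rough ETA
    avg_toc = tot_toc / max(curr_ix, 1)
    est_remaining = avg_toc * (total_ix - curr_ix)
    sys.stdout.write('\r' + prefix + 'ETA: {:.1f}s'.format(est_remaining))
    sys.stdout.flush()
    return tot_toc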
def save_final_checkpoint(train_vars, model, optimizer):
    msg = ''
    msg += print_verbose("\nReached final number of iterations: " + str(train_vars['num_iter']),
                         train_vars['verbose'])
    msg += print_verbose("\tSaving final model checkpoint...", train_vars['verbose'])
    if not train_vars['output_filepath'] == '':
        with open(train_vars['output_filepath'], 'a') as f:
            f.write(msg + '\n')
    final_model_dict = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_vars': train_vars,
    }
    save_checkpoint(final_model_dict,
                    filename=train_vars['checkpoint_filenamebase'] + 'final' +
                             str(train_vars['num_iter']) + '.pth.tar')
    train_vars['done_training'] = True
    return train_vars
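# save_checkpoint is used above but not defined in this file. A minimal sketch,
# assuming it is a thin wrapper around torch.save; the actual helper may do more
# (e.g. copy the best checkpoint separately).
import torch

def save_checkpoint(state, filename='checkpoint.pth.tar'):
    # serialize the checkpoint dict (model/optimizer state plus training vars) to disk
    torch.save(state, filename)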
def run_until_curr_iter(batch_idx, train_vars):
    if train_vars['curr_epoch_iter'] < train_vars['start_iter_mod']:
        msg = ''
        if batch_idx % train_vars['iter_size'] == 0:
            msg += print_verbose("\rGoing through iterations to arrive at last one saved... " +
                                 str(int(train_vars['curr_epoch_iter'] * 100.0 /
                                         train_vars['start_iter_mod'])) + "% of " +
                                 str(train_vars['start_iter_mod']) + " iterations (" +
                                 str(train_vars['curr_epoch_iter']) + "/" +
                                 str(train_vars['start_iter_mod']) + ")",
                                 train_vars['verbose'], n_tabs=0, erase_line=True)
            train_vars['curr_epoch_iter'] += 1
            train_vars['curr_iter'] += 1
            train_vars['curr_epoch_iter'] += 1
        if not train_vars['output_filepath'] == '':
            with open(train_vars['output_filepath'], 'a') as f:
                f.write(msg + '\n')
        return False, train_vars
    return True, train_vars
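# print_verbose is used throughout these scripts but is not defined in this file.
# A minimal sketch inferred from its call sites: it prints only when verbose is set,
# can indent with tabs or rewrite the current terminal line, and always returns the
# message so callers can accumulate it into a log string. The real helper may differ.
import sys

def print_verbose(msg, verbose, n_tabs=0, erase_line=False):
    text = '\t' * n_tabs + msg
    if verbose:
        if erase_line:
            # rewrite the current line instead of printing a new one
            sys.stdout.write('\r' + text)
            sys.stdout.flush()
        else:
            print(text)
    return text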
def print_header_info(model, dataset_loader, train_vars):
    msg = ''
    msg += print_verbose("-----------------------------------------------------------", train_vars['verbose']) + "\n"
    msg += print_verbose("Output filenamebase: " + train_vars['output_filepath'], train_vars['verbose']) + "\n"
    msg += print_verbose("Model info", train_vars['verbose']) + "\n"
    try:
        heatmap_ixs = model.heatmap_ixs
        msg += print_verbose("Joints indexes: " + str(heatmap_ixs), train_vars['verbose']) + "\n"
        msg += print_verbose("Number of joints: " + str(len(heatmap_ixs)), train_vars['verbose']) + "\n"
    except AttributeError:
        msg += print_verbose("Joints indexes: " + str(model.num_heatmaps), train_vars['verbose']) + "\n"
    msg += print_verbose("-----------------------------------------------------------", train_vars['verbose']) + "\n"
    msg += print_verbose("Max memory batch size: " + str(train_vars['max_mem_batch']), train_vars['verbose']) + "\n"
    msg += print_verbose("Length of dataset (in max mem batch size): " + str(len(dataset_loader)),
                         train_vars['verbose']) + "\n"
    msg += print_verbose("Training batch size: " + str(train_vars['batch_size']), train_vars['verbose']) + "\n"
    msg += print_verbose("Starting epoch: " + str(train_vars['start_epoch']), train_vars['verbose']) + "\n"
    msg += print_verbose("Starting epoch iteration: " + str(train_vars['start_iter_mod']), train_vars['verbose']) + "\n"
    msg += print_verbose("Starting overall iteration: " + str(train_vars['start_iter']), train_vars['verbose']) + "\n"
    msg += print_verbose("-----------------------------------------------------------", train_vars['verbose']) + "\n"
    msg += print_verbose("Number of iterations per epoch: " + str(train_vars['n_iter_per_epoch']),
                         train_vars['verbose']) + "\n"
    msg += print_verbose("Number of iterations to train: " + str(train_vars['num_iter']), train_vars['verbose']) + "\n"
    msg += print_verbose("Approximate number of epochs to train: " +
                         str(round(train_vars['num_iter'] / train_vars['n_iter_per_epoch'], 1)),
                         train_vars['verbose']) + "\n"
    msg += print_verbose("-----------------------------------------------------------", train_vars['verbose']) + "\n"
    if not train_vars['output_filepath'] == '':
        with open(train_vars['output_filepath'], 'w+') as f:
            f.write(msg + '\n')
def print_log_info(model, optimizer, epoch, total_loss, vars, train_vars,
                   save_best=True, save_a_checkpoint=True):
    sep = "-------------------------------------------------------------------------------------------"
    model_class_name = type(model).__name__
    verbose = train_vars['verbose']
    print_verbose("", verbose)
    print_verbose(sep, verbose)
    if save_a_checkpoint:
        print_verbose("Saving checkpoints:", verbose)
        print_verbose(sep, verbose)
        checkpoint_model_dict = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_vars': train_vars,
        }
        save_checkpoint(checkpoint_model_dict, filename=vars['checkpoint_filenamebase'] + '.pth.tar')
        if save_best:
            save_checkpoint(vars['best_model_dict'], filename=vars['checkpoint_filenamebase'] + 'best.pth.tar')
    msg = ''
    msg += print_verbose(sep, verbose) + "\n"
    now = datetime.datetime.now()
    msg += print_verbose('Time: ' + now.strftime("%Y-%m-%d %H:%M"), verbose) + "\n"
    msg += print_verbose(sep, verbose) + "\n"
    msg += print_verbose(sep, verbose) + "\n"
    msg += print_verbose('Training (Epoch #' + str(epoch) + ' ' + str(train_vars['curr_epoch_iter']) + '/' +
                         str(train_vars['tot_iter']) + ')' + ', (Batch ' + str(train_vars['batch_idx'] + 1) +
                         '(' + str(train_vars['iter_size']) + ')' + '/' +
                         str(train_vars['num_batches']) + ')' + ', (Iter #' + str(train_vars['curr_iter']) +
                         '(' + str(train_vars['batch_size']) + ')' +
                         ' - log every ' + str(train_vars['log_interval']) + ' iter): ', verbose) + '\n'
    msg += print_verbose(sep, verbose) + "\n"
    msg += print_verbose("Current loss: " + str(total_loss), verbose) + "\n"
    msg += print_verbose("Best loss: " + str(vars['best_loss']), verbose) + "\n"
    msg += print_verbose("Mean total loss: " + str(np.mean(vars['losses'])), verbose) + "\n"
    msg += print_verbose("Mean loss for last " + str(train_vars['log_interval']) +
                         " iterations (average total loss): " +
                         str(np.mean(vars['losses'][-train_vars['log_interval']:])), verbose) + "\n"
    if model_class_name == 'JORNet':
        msg += print_verbose(sep, verbose) + "\n"
        msg += print_verbose("Current joints loss: " + str(vars['losses_joints'][-1]), verbose) + "\n"
        msg += print_verbose("Best joints loss: " + str(vars['best_loss_joints']), verbose) + "\n"
        msg += print_verbose("Mean total joints loss: " + str(np.mean(vars['losses_joints'])), verbose) + "\n"
        msg += print_verbose("Mean joints loss for last " + str(train_vars['log_interval']) +
                             " iterations (average total joints loss): " +
                             str(np.mean(vars['losses_joints'][-train_vars['log_interval']:])), verbose) + "\n"
        msg += print_verbose(sep, verbose) + "\n"
        msg += print_verbose("Current heatmaps loss: " + str(vars['losses_heatmaps'][-1]), verbose) + "\n"
        msg += print_verbose("Best heatmaps loss: " + str(vars['best_loss_heatmaps']), verbose) + "\n"
        msg += print_verbose("Mean total heatmaps loss: " + str(np.mean(vars['losses_heatmaps'])), verbose) + "\n"
        msg += print_verbose("Mean heatmaps loss for last " + str(train_vars['log_interval']) +
                             " iterations (average total heatmaps loss): " +
                             str(np.mean(vars['losses_heatmaps'][-train_vars['log_interval']:])), verbose) + "\n"
    msg += print_verbose(sep, verbose) + "\n"
    msg += print_verbose("Joint pixel losses:", verbose) + "\n"
    msg += print_verbose(sep, verbose) + "\n"
    joint_loss_avg = 0
    aa = np.mean(np.array(vars['pixel_losses']))
    msg += print_verbose("\tTotal mean pixel loss: " + str(aa), verbose) + '\n'
    msg += print_verbose(sep, verbose) + "\n"
    tot_joint_loss_avg = 0
    for heatmap_ix in range(model.num_heatmaps):
        msg += print_verbose("\tJoint index: " + str(heatmap_ix), verbose) + "\n"
        mean_joint_pixel_loss = np.mean(
            np.array(vars['pixel_losses'])[-train_vars['log_interval']:, heatmap_ix])
        joint_loss_avg += mean_joint_pixel_loss
        tot_mean_joint_pixel_loss = np.mean(np.array(vars['pixel_losses'])[:, heatmap_ix])
        tot_joint_loss_avg += tot_mean_joint_pixel_loss
        msg += print_verbose("\tTraining set mean error for last " + str(train_vars['log_interval']) +
                             " iterations (average pixel loss): " + str(mean_joint_pixel_loss), verbose) + "\n"
        msg += print_verbose("\tTraining set stddev error for last " + str(train_vars['log_interval']) +
                             " iterations (average pixel loss): " +
                             str(np.std(np.array(vars['pixel_losses'])[-train_vars['log_interval']:, heatmap_ix])),
                             verbose) + "\n"
        msg += print_verbose("\tThis is the last pixel dist loss: " +
                             str(vars['pixel_losses'][-1][heatmap_ix]), verbose) + "\n"
        msg += print_verbose("\tTraining set mean error for last " + str(train_vars['log_interval']) +
                             " iterations (average pixel loss of sample): " +
                             str(np.mean(np.array(vars['pixel_losses_sample'])[-train_vars['log_interval']:, heatmap_ix])),
                             verbose) + "\n"
        msg += print_verbose("\tTraining set stddev error for last " + str(train_vars['log_interval']) +
                             " iterations (average pixel loss of sample): " +
                             str(np.std(np.array(vars['pixel_losses_sample'])[-train_vars['log_interval']:, heatmap_ix])),
                             verbose) + "\n"
        msg += print_verbose("\tThis is the last pixel dist loss of sample: " +
                             str(vars['pixel_losses_sample'][-1][heatmap_ix]), verbose) + "\n"
        msg += print_verbose("\t" + sep, verbose) + "\n"
    msg += print_verbose(sep, verbose) + "\n"
    joint_loss_avg /= len(model.heatmap_ixs)
    tot_joint_loss_avg /= len(model.heatmap_ixs)
    msg += print_verbose(sep, verbose) + "\n"
    msg += print_verbose("\tCurrent mean pixel loss: " + str(joint_loss_avg), verbose) + '\n'
    msg += print_verbose(sep, verbose) + "\n"
    msg += print_verbose(sep, verbose) + "\n"
    msg += print_verbose("\tTotal mean pixel loss: " + str(tot_joint_loss_avg), verbose) + '\n'
    msg += print_verbose(sep, verbose) + "\n"
    if not train_vars['output_filepath'] == '':
        with open(train_vars['output_filepath'], 'a') as f:
            f.write(msg + '\n')
    return tot_joint_loss_avg
def train(train_loader, model, optimizer, train_vars):
    verbose = train_vars['verbose']
    for batch_idx, (data, target) in enumerate(train_loader):
        train_vars['batch_idx'] = batch_idx
        # print info about performing first iter
        if batch_idx < train_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " + str(batch_idx + 1) + "/" +
                          str(train_vars['iter_size']), verbose, n_tabs=0, erase_line=True)
        # check if arrived at iter to start
        arrived_curr_iter, train_vars = run_until_curr_iter(batch_idx, train_vars)
        if not arrived_curr_iter:
            continue
        # save checkpoint after final iteration
        if train_vars['curr_iter'] - 1 == train_vars['num_iter']:
            train_vars = save_final_checkpoint(train_vars, model, optimizer)
            break
        # start time counter
        start = time.time()
        # get data and target as torch Variables
        _, target_joints, target_heatmaps, target_joints_z = target
        data, target_heatmaps = Variable(data), Variable(target_heatmaps)
        if train_vars['use_cuda']:
            data = data.cuda()
            target_heatmaps = target_heatmaps.cuda()
        # get model output
        output = model(data)
        # accumulate loss for sub-mini-batch
        if model.cross_entropy:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        loss = my_losses.calculate_loss_HALNet(loss_func,
                                               output, target_heatmaps, model.joint_ixs,
                                               model.WEIGHT_LOSS_INTERMED1, model.WEIGHT_LOSS_INTERMED2,
                                               model.WEIGHT_LOSS_INTERMED3, model.WEIGHT_LOSS_MAIN,
                                               train_vars['iter_size'])
        loss.backward()
        train_vars['total_loss'] += loss
        # accumulate pixel dist loss for sub-mini-batch
        train_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            train_vars['total_pixel_loss'], output[3], target_heatmaps, train_vars['batch_size'])
        if train_vars['cross_entropy']:
            train_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
                train_vars['total_pixel_loss_sample'], output[3], target_heatmaps, train_vars['batch_size'])
        else:
            train_vars['total_pixel_loss_sample'] = [-1] * len(model.joint_ixs)
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx + 1) % train_vars['iter_size'] == 0
        if minibatch_completed:
            # optimise for mini-batch
            optimizer.step()
            # clear optimiser
            optimizer.zero_grad()
            # append total loss
            train_vars['losses'].append(train_vars['total_loss'].item())
            # erase total loss
            total_loss = train_vars['total_loss'].item()
            train_vars['total_loss'] = 0
            # append dist loss
            train_vars['pixel_losses'].append(train_vars['total_pixel_loss'])
            # erase pixel dist loss
            train_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            train_vars['pixel_losses_sample'].append(train_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            train_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            if train_vars['losses'][-1] < train_vars['best_loss']:
                train_vars['best_loss'] = train_vars['losses'][-1]
                print_verbose(" This is a best loss found so far: " + str(train_vars['losses'][-1]), verbose)
                train_vars['best_model_dict'] = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_vars': train_vars
                }
            # log checkpoint (note: 'epoch' is not a parameter here; it is read from the
            # enclosing script's epoch loop)
            if train_vars['curr_iter'] % train_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, epoch, total_loss, train_vars, train_vars)
            if train_vars['curr_iter'] % train_vars['log_interval_valid'] == 0:
                print_verbose("\nSaving model and checkpoint model for validation", verbose)
                checkpoint_model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_vars': train_vars,
                }
                trainer.save_checkpoint(checkpoint_model_dict,
                                        filename=train_vars['checkpoint_filenamebase'] +
                                                 'for_valid_' + str(train_vars['curr_iter']) + '.pth.tar')
            # print time lapse
            prefix = 'Training (Epoch #' + str(epoch) + ' ' + str(train_vars['curr_epoch_iter']) + '/' +\
                     str(train_vars['tot_iter']) + ')' + ', (Batch ' + str(train_vars['batch_idx'] + 1) +\
                     '(' + str(train_vars['iter_size']) + ')' + '/' +\
                     str(train_vars['num_batches']) + ')' + ', (Iter #' + str(train_vars['curr_iter']) +\
                     '(' + str(train_vars['batch_size']) + ')' +\
                     ' - log every ' + str(train_vars['log_interval']) + ' iter): '
            train_vars['tot_toc'] = display_est_time_loop(train_vars['tot_toc'] + time.time() - start,
                                                          train_vars['curr_iter'], train_vars['num_iter'],
                                                          prefix=prefix)
            train_vars['curr_iter'] += 1
            train_vars['start_iter'] = train_vars['curr_iter'] + 1
            train_vars['curr_epoch_iter'] += 1
    return train_vars
def load_resnet_weights_into_HALNet(halnet, verbose, n_tabs=1):
    print_verbose("Loading ResNet50...", verbose, n_tabs)
    resnet50 = resnet.resnet50(pretrained=True)
    print_verbose("Done loading ResNet50", verbose, n_tabs)
    # initialize HALNet with ResNet50
    print_verbose("Initializing network with ResNet50...", verbose, n_tabs)
    # initialize level 1: conv1
    # ResNet50's conv1 expects 3 input channels; HALNet takes one more, so the extra
    # channel is sampled from a normal distribution matching the ResNet weights
    resnet_weight = resnet50.conv1.weight.data.cpu()
    float_tensor = np.random.normal(np.mean(resnet_weight.numpy()),
                                    np.std(resnet_weight.numpy()),
                                    (resnet_weight.shape[0], 1,
                                     resnet_weight.shape[2], resnet_weight.shape[2]))
    resnet_weight_numpy = resnet_weight.numpy()
    resnet_weight = np.concatenate((resnet_weight_numpy, float_tensor), axis=1)
    resnet_weight = torch.FloatTensor(resnet_weight)
    halnet.conv1[0]._parameters['weight'].data.copy_(resnet_weight)
    # initialize residual blocks: copy the conv (and downsample, where present) weights
    # of ResNet50 layer1/layer2 into the corresponding HALNet residual blocks
    block_map = [
        (resnet50.layer1[0], halnet.res2a, True),   # res2a (with downsample)
        (resnet50.layer1[1], halnet.res2b, False),  # res2b
        (resnet50.layer1[2], halnet.res2c, False),  # res2c
        (resnet50.layer2[0], halnet.res3a, True),   # res3a (with downsample)
        (resnet50.layer2[1], halnet.res3b, False),  # res3b
        (resnet50.layer2[2], halnet.res3c, False),  # res3c
    ]
    for resnet_block, halnet_block, has_downsample in block_map:
        halnet_block.right_res[0][0]._parameters['weight'].data.copy_(resnet_block.conv1.weight.data)
        halnet_block.right_res[2][0]._parameters['weight'].data.copy_(resnet_block.conv2.weight.data)
        halnet_block.right_res[4][0]._parameters['weight'].data.copy_(resnet_block.conv3.weight.data)
        if has_downsample:
            halnet_block.left_res[0]._parameters['weight'].data.copy_(resnet_block.downsample[0].weight.data)
    print_verbose("Done initializing network with ResNet50", verbose, n_tabs)
    print_verbose("Deleting resnet from memory", verbose, n_tabs)
    del resnet50
    return halnet
def train(train_loader, model, optimizer, train_vars):
    verbose = train_vars['verbose']
    for batch_idx, (data, target) in enumerate(train_loader):
        train_vars['batch_idx'] = batch_idx
        # print info about performing first iter
        if batch_idx < train_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " + str(batch_idx + 1) + "/" +
                          str(train_vars['iter_size']), verbose, n_tabs=0, erase_line=True)
        # check if arrived at iter to start
        arrived_curr_iter, train_vars = run_until_curr_iter(batch_idx, train_vars)
        if not arrived_curr_iter:
            continue
        # save checkpoint after final iteration
        if train_vars['curr_iter'] - 1 == train_vars['num_iter']:
            train_vars = trainer.save_final_checkpoint(train_vars, model, optimizer)
            break
        # start time counter
        start = time.time()
        # get data and target as torch Variables
        _, target_joints, target_heatmaps, target_joints_z = target
        # make target joints be relative
        target_joints = target_joints[:, 3:]
        data, target_heatmaps = Variable(data), Variable(target_heatmaps)
        if train_vars['use_cuda']:
            data = data.cuda()
            target_heatmaps = target_heatmaps.cuda()
            target_joints = target_joints.cuda()
            target_joints_z = target_joints_z.cuda()
        # get model output
        output = model(data)
        # accumulate loss for sub-mini-batch
        if train_vars['cross_entropy']:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        weights_heatmaps_loss, weights_joints_loss = get_loss_weights(train_vars['curr_iter'])
        loss, loss_heatmaps, loss_joints = my_losses.calculate_loss_JORNet(
            loss_func, output, target_heatmaps, target_joints, train_vars['joint_ixs'],
            weights_heatmaps_loss, weights_joints_loss, train_vars['iter_size'])
        loss.backward()
        train_vars['total_loss'] += loss.item()
        train_vars['total_joints_loss'] += loss_joints.item()
        train_vars['total_heatmaps_loss'] += loss_heatmaps.item()
        # accumulate pixel dist loss for sub-mini-batch
        train_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            train_vars['total_pixel_loss'], output[3], target_heatmaps, train_vars['batch_size'])
        if train_vars['cross_entropy']:
            train_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
                train_vars['total_pixel_loss_sample'], output[3], target_heatmaps, train_vars['batch_size'])
        else:
            train_vars['total_pixel_loss_sample'] = [-1] * len(model.joint_ixs)
        '''
        For debugging training
        for i in range(train_vars['max_mem_batch']):
            filenamebase_idx = (batch_idx * train_vars['max_mem_batch']) + i
            filenamebase = train_loader.dataset.get_filenamebase(filenamebase_idx)
            visualize.plot_joints_from_heatmaps(target_heatmaps[i].data.cpu().numpy(),
                                                title='GT joints: ' + filenamebase,
                                                data=data[i].data.cpu().numpy())
            visualize.plot_joints_from_heatmaps(output[3][i].data.cpu().numpy(),
                                                title='Pred joints: ' + filenamebase,
                                                data=data[i].data.cpu().numpy())
            visualize.plot_image_and_heatmap(output[3][i][4].data.numpy(),
                                             data=data[i].data.numpy(),
                                             title='Thumb tib heatmap: ' + filenamebase)
            visualize.show()
        '''
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx + 1) % train_vars['iter_size'] == 0
        if minibatch_completed:
            # visualize
            # ax, fig = visualize.plot_3D_joints(target_joints[0])
            # visualize.plot_3D_joints(target_joints[1], ax=ax, fig=fig)
            if train_vars['curr_iter'] % train_vars['log_interval'] == 0:
                fig, ax = visualize.plot_3D_joints(target_joints[0])
                visualize.savefig('joints_GT_' + str(train_vars['curr_iter']) + '.png')
                #visualize.plot_3D_joints(target_joints[1], fig=fig, ax=ax, color_root='C7')
                #visualize.plot_3D_joints(output[7].data.cpu().numpy()[0], fig=fig, ax=ax, color_root='C7')
                visualize.plot_3D_joints(output[7].data.cpu().numpy()[0])
                visualize.savefig('joints_model_' + str(train_vars['curr_iter']) + '.png')
                #visualize.show()
                #visualize.savefig('joints_' + str(train_vars['curr_iter']) + '.png')
            # change learning rate to 0.01 after 45000 iterations
            optimizer = change_learning_rate(optimizer, 0.01, train_vars['curr_iter'])
            # optimise for mini-batch
            optimizer.step()
            # clear optimiser
            optimizer.zero_grad()
            # append total loss
            train_vars['losses'].append(train_vars['total_loss'])
            # erase total loss
            total_loss = train_vars['total_loss']
            train_vars['total_loss'] = 0
            # append total joints loss
            train_vars['losses_joints'].append(train_vars['total_joints_loss'])
            # erase total joints loss
            train_vars['total_joints_loss'] = 0
            # append total heatmaps loss
            train_vars['losses_heatmaps'].append(train_vars['total_heatmaps_loss'])
            # erase total heatmaps loss
            train_vars['total_heatmaps_loss'] = 0
            # append dist loss
            train_vars['pixel_losses'].append(train_vars['total_pixel_loss'])
            # erase pixel dist loss
            train_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            train_vars['pixel_losses_sample'].append(train_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            train_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            if train_vars['losses'][-1] < train_vars['best_loss']:
                train_vars['best_loss'] = train_vars['losses'][-1]
                print_verbose(" This is a best loss found so far: " + str(train_vars['losses'][-1]), verbose)
                train_vars['best_model_dict'] = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_vars': train_vars
                }
            # log checkpoint (note: 'epoch' is read from the enclosing script's epoch loop)
            if train_vars['curr_iter'] % train_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, epoch, total_loss, train_vars, train_vars)
                aa1 = target_joints[0].data.cpu().numpy()
                aa2 = output[7][0].data.cpu().numpy()
                output_joint_loss = np.sum(np.abs(aa1 - aa2)) / 63
                msg = ''
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                msg += print_verbose('\tJoint Coord Avg Loss for first image of current mini-batch: ' +
                                     str(output_joint_loss) + '\n', train_vars['verbose'])
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                if not train_vars['output_filepath'] == '':
                    with open(train_vars['output_filepath'], 'a') as f:
                        f.write(msg + '\n')
            if train_vars['curr_iter'] % train_vars['log_interval_valid'] == 0:
                print_verbose("\nSaving model and checkpoint model for validation", verbose)
                checkpoint_model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_vars': train_vars,
                }
                trainer.save_checkpoint(checkpoint_model_dict,
                                        filename=train_vars['checkpoint_filenamebase'] +
                                                 'for_valid_' + str(train_vars['curr_iter']) + '.pth.tar')
            # print time lapse
            prefix = 'Training (Epoch #' + str(epoch) + ' ' + str(train_vars['curr_epoch_iter']) + '/' +\
                     str(train_vars['tot_iter']) + ')' + ', (Batch ' + str(train_vars['batch_idx'] + 1) +\
                     '(' + str(train_vars['iter_size']) + ')' + '/' +\
                     str(train_vars['num_batches']) + ')' + ', (Iter #' + str(train_vars['curr_iter']) +\
                     '(' + str(train_vars['batch_size']) + ')' +\
                     ' - log every ' + str(train_vars['log_interval']) + ' iter): '
            train_vars['tot_toc'] = display_est_time_loop(train_vars['tot_toc'] + time.time() - start,
                                                          train_vars['curr_iter'], train_vars['num_iter'],
                                                          prefix=prefix)
            train_vars['curr_iter'] += 1
            train_vars['start_iter'] = train_vars['curr_iter'] + 1
            train_vars['curr_epoch_iter'] += 1
    return train_vars
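# change_learning_rate and get_loss_weights are called in the training loop above but
# are not defined in this file. Minimal sketches of plausible implementations, inferred
# from the call sites and from the in-code comment about switching to 0.01 after 45000
# iterations; the weight values and the iteration threshold in get_loss_weights are
# assumptions, not the actual schedule.
def change_learning_rate(optimizer, new_lr, curr_iter, change_at_iter=45000):
    # per the comment in train(): set a new learning rate once a given iteration is reached
    if curr_iter >= change_at_iter:
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
    return optimizer

def get_loss_weights(curr_iter):
    # placeholder schedule: weights for the intermediate/main heatmap losses and the
    # joint regression loss, possibly re-balanced later in training
    weights_heatmaps_loss = [0.5, 0.5, 0.5, 1.0]
    weights_joints_loss = 1.0
    if curr_iter > 45000:
        weights_heatmaps_loss = [0.1, 0.1, 0.1, 1.0]
    return weights_heatmaps_loss, weights_joints_loss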
def validate(valid_loader, model, optimizer, valid_vars, control_vars, verbose=True):
    curr_epoch_iter = 1
    for batch_idx, (data, target) in enumerate(valid_loader):
        control_vars['batch_idx'] = batch_idx
        if batch_idx < control_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " + str(batch_idx + 1) + "/" +
                          str(control_vars['iter_size']), verbose, n_tabs=0, erase_line=True)
        # start time counter
        start = time.time()
        # get data and target as cuda variables
        target_heatmaps, target_joints, target_joints_z = target
        data, target_heatmaps = Variable(data), Variable(target_heatmaps)
        if valid_vars['use_cuda']:
            data = data.cuda()
            target_heatmaps = target_heatmaps.cuda()
        # visualize if debugging
        # get model output
        output = model(data)
        # accumulate loss for sub-mini-batch
        if valid_vars['cross_entropy']:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        loss = my_losses.calculate_loss_HALNet(loss_func,
                                               output, target_heatmaps, model.joint_ixs,
                                               model.WEIGHT_LOSS_INTERMED1, model.WEIGHT_LOSS_INTERMED2,
                                               model.WEIGHT_LOSS_INTERMED3, model.WEIGHT_LOSS_MAIN,
                                               control_vars['iter_size'])
        if DEBUG_VISUALLY:
            for i in range(control_vars['max_mem_batch']):
                filenamebase_idx = (batch_idx * control_vars['max_mem_batch']) + i
                filenamebase = valid_loader.dataset.get_filenamebase(filenamebase_idx)
                fig = visualize.create_fig()
                #visualize.plot_joints_from_heatmaps(output[3][i].data.numpy(), fig=fig,
                #                                    title=filenamebase, data=data[i].data.numpy())
                #visualize.plot_image_and_heatmap(output[3][i][8].data.numpy(),
                #                                 data=data[i].data.numpy(),
                #                                 title=filenamebase)
                #visualize.savefig('/home/paulo/' + filenamebase.replace('/', '_') + '_heatmap')
                labels_colorspace = conv.heatmaps_to_joints_colorspace(output[3][i].data.numpy())
                data_crop, crop_coords, labels_heatmaps, labels_colorspace = \
                    converter.crop_image_get_labels(data[i].data.numpy(), labels_colorspace, range(21))
                visualize.plot_image(data_crop, title=filenamebase, fig=fig)
                visualize.plot_joints_from_colorspace(labels_colorspace, title=filenamebase,
                                                      fig=fig, data=data_crop)
                #visualize.savefig('/home/paulo/' + filenamebase.replace('/', '_') + '_crop')
                visualize.show()
        #loss.backward()
        valid_vars['total_loss'] += loss
        # accumulate pixel dist loss for sub-mini-batch
        valid_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            valid_vars['total_pixel_loss'], output[3], target_heatmaps, control_vars['batch_size'])
        if valid_vars['cross_entropy']:
            valid_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
                valid_vars['total_pixel_loss_sample'], output[3], target_heatmaps, control_vars['batch_size'])
        else:
            valid_vars['total_pixel_loss_sample'] = [-1] * len(model.joint_ixs)
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx + 1) % control_vars['iter_size'] == 0
        if minibatch_completed:
            # append total loss
            valid_vars['losses'].append(valid_vars['total_loss'].item())
            # erase total loss
            total_loss = valid_vars['total_loss'].item()
            valid_vars['total_loss'] = 0
            # append dist loss
            valid_vars['pixel_losses'].append(valid_vars['total_pixel_loss'])
            # erase pixel dist loss
            valid_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            valid_vars['pixel_losses_sample'].append(valid_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            valid_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            if valid_vars['losses'][-1] < valid_vars['best_loss']:
                valid_vars['best_loss'] = valid_vars['losses'][-1]
                #print_verbose(" This is a best loss found so far: " + str(valid_vars['losses'][-1]), verbose)
            # log checkpoint
            if control_vars['curr_iter'] % control_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, 1, total_loss, valid_vars, control_vars)
                model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': valid_vars,
                }
                trainer.save_checkpoint(model_dict,
                                        filename=valid_vars['checkpoint_filenamebase'] +
                                                 str(control_vars['num_iter']) + '.pth.tar')
            # print time lapse
            prefix = 'Validating (Epoch #' + str(1) + ' ' + str(control_vars['curr_epoch_iter']) + '/' +\
                     str(control_vars['tot_iter']) + ')' + ', (Batch ' + str(control_vars['batch_idx'] + 1) +\
                     '(' + str(control_vars['iter_size']) + ')' + '/' +\
                     str(control_vars['num_batches']) + ')' + ', (Iter #' + str(control_vars['curr_iter']) +\
                     '(' + str(control_vars['batch_size']) + ')' +\
                     ' - log every ' + str(control_vars['log_interval']) + ' iter): '
            control_vars['tot_toc'] = display_est_time_loop(control_vars['tot_toc'] + time.time() - start,
                                                            control_vars['curr_iter'], control_vars['num_iter'],
                                                            prefix=prefix)
            control_vars['curr_iter'] += 1
            control_vars['start_iter'] = control_vars['curr_iter'] + 1
            control_vars['curr_epoch_iter'] += 1
    return valid_vars, control_vars
def validate(valid_loader, model, optimizer, valid_vars, control_vars, verbose=True):
    curr_epoch_iter = 1
    for batch_idx, (data, target) in enumerate(valid_loader):
        control_vars['batch_idx'] = batch_idx
        if batch_idx < control_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " + str(batch_idx + 1) + "/" +
                          str(control_vars['iter_size']), verbose, n_tabs=0, erase_line=True)
        # start time counter
        start = time.time()
        # get data and target as cuda variables
        target_heatmaps, target_joints, target_handroot = target
        # make target joints be relative
        target_joints = target_joints[:, 3:]
        data, target_heatmaps = Variable(data), Variable(target_heatmaps)
        if valid_vars['use_cuda']:
            data = data.cuda()
            target_joints = target_joints.cuda()
            target_heatmaps = target_heatmaps.cuda()
            target_handroot = target_handroot.cuda()
        # visualize if debugging
        # get model output
        output = model(data)
        # accumulate loss for sub-mini-batch
        if model.cross_entropy:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        weights_heatmaps_loss, weights_joints_loss = get_loss_weights(control_vars['curr_iter'])
        loss, loss_heatmaps, loss_joints = my_losses.calculate_loss_JORNet(
            loss_func, output, target_heatmaps, target_joints, valid_vars['joint_ixs'],
            weights_heatmaps_loss, weights_joints_loss, control_vars['iter_size'])
        valid_vars['total_loss'] += loss
        valid_vars['total_joints_loss'] += loss_joints
        valid_vars['total_heatmaps_loss'] += loss_heatmaps
        # accumulate pixel dist loss for sub-mini-batch
        valid_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            valid_vars['total_pixel_loss'], output[3], target_heatmaps, control_vars['batch_size'])
        valid_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
            valid_vars['total_pixel_loss_sample'], output[3], target_heatmaps, control_vars['batch_size'])
        # debug: print and plot predictions for every example in the sub-mini-batch
        for i in range(control_vars['max_mem_batch']):
            filenamebase_idx = (batch_idx * control_vars['max_mem_batch']) + i
            filenamebase = valid_loader.dataset.get_filenamebase(filenamebase_idx)
            print('')
            print(filenamebase)
            visualize.plot_image(data[i].data.numpy())
            visualize.show()
            output_batch_numpy = output[7][i].data.cpu().numpy()
            print('\n-------------------------------')
            reshaped_out = output_batch_numpy.reshape((20, 3))
            for j in range(20):
                print('[{}, {}, {}],'.format(reshaped_out[j, 0], reshaped_out[j, 1], reshaped_out[j, 2]))
            print('-------------------------------')
            fig, ax = visualize.plot_3D_joints(target_joints[i])
            visualize.plot_3D_joints(output_batch_numpy, fig=fig, ax=ax, color='C6')
            visualize.title(filenamebase)
            visualize.show()
            temp = np.zeros((21, 3))
            output_batch_numpy_abs = output_batch_numpy.reshape((20, 3))
            temp[1:, :] = output_batch_numpy_abs
            output_batch_numpy_abs = temp
            output_joints_colorspace = camera.joints_depth2color(
                output_batch_numpy_abs,
                depth_intr_matrix=synthhands_handler.DEPTH_INTR_MTX,
                handroot=target_handroot[i].data.cpu().numpy())
            visualize.plot_3D_joints(output_joints_colorspace)
            visualize.show()
            aa1 = target_joints[i].data.cpu().numpy().reshape((20, 3))
            aa2 = output[7][i].data.cpu().numpy().reshape((20, 3))
            print('\n----------------------------------')
            print(np.sum(np.abs(aa1 - aa2)) / 60)
            print('----------------------------------')
        #loss.backward()
        valid_vars['total_loss'] += loss
        valid_vars['total_joints_loss'] += loss_joints
        valid_vars['total_heatmaps_loss'] += loss_heatmaps
        # accumulate pixel dist loss for sub-mini-batch
        valid_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            valid_vars['total_pixel_loss'], output[3], target_heatmaps, control_vars['batch_size'])
        valid_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
            valid_vars['total_pixel_loss_sample'], output[3], target_heatmaps, control_vars['batch_size'])
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx + 1) % control_vars['iter_size'] == 0
        if minibatch_completed:
            # append total loss
            valid_vars['losses'].append(valid_vars['total_loss'].data[0])
            # erase total loss
            total_loss = valid_vars['total_loss'].data[0]
            valid_vars['total_loss'] = 0
            # append total joints loss
            valid_vars['losses_joints'].append(valid_vars['total_joints_loss'].data[0])
            # erase total joints loss
            valid_vars['total_joints_loss'] = 0
            # append total heatmaps loss
            valid_vars['losses_heatmaps'].append(valid_vars['total_heatmaps_loss'].data[0])
            # erase total heatmaps loss
            valid_vars['total_heatmaps_loss'] = 0
            # append dist loss
            valid_vars['pixel_losses'].append(valid_vars['total_pixel_loss'])
            # erase pixel dist loss
            valid_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            valid_vars['pixel_losses_sample'].append(valid_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            valid_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            #if valid_vars['losses'][-1] < valid_vars['best_loss']:
            #    valid_vars['best_loss'] = valid_vars['losses'][-1]
            #    print_verbose(" This is a best loss found so far: " + str(valid_vars['losses'][-1]), verbose)
            # log checkpoint
            if control_vars['curr_iter'] % control_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, 1, total_loss, valid_vars, control_vars)
                model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': valid_vars,
                }
                trainer.save_checkpoint(model_dict,
                                        filename=valid_vars['checkpoint_filenamebase'] +
                                                 str(control_vars['num_iter']) + '.pth.tar')
            # print time lapse
            prefix = 'Validating (Epoch #' + str(1) + ' ' + str(control_vars['curr_epoch_iter']) + '/' +\
                     str(control_vars['tot_iter']) + ')' + ', (Batch ' + str(control_vars['batch_idx'] + 1) +\
                     '(' + str(control_vars['iter_size']) + ')' + '/' +\
                     str(control_vars['num_batches']) + ')' + ', (Iter #' + str(control_vars['curr_iter']) +\
                     '(' + str(control_vars['batch_size']) + ')' +\
                     ' - log every ' + str(control_vars['log_interval']) + ' iter): '
            control_vars['tot_toc'] = display_est_time_loop(control_vars['tot_toc'] + time.time() - start,
                                                            control_vars['curr_iter'], control_vars['num_iter'],
                                                            prefix=prefix)
            control_vars['curr_iter'] += 1
            control_vars['start_iter'] = control_vars['curr_iter'] + 1
            control_vars['curr_epoch_iter'] += 1
    return valid_vars, control_vars
train_vars['start_iter_mod'] = train_vars['start_iter'] % train_vars['tot_iter']
train_vars['start_epoch'] = int(train_vars['start_iter'] / train_vars['n_iter_per_epoch'])

trainer.print_header_info(model, train_loader, train_vars)
model.train()
train_vars['curr_iter'] = 1

msg = ''
for epoch in range(train_vars['num_epochs']):
    train_vars['curr_epoch_iter'] = 1
    if epoch + 1 < train_vars['start_epoch']:
        msg += print_verbose("\nAdvancing through epochs: " + str(epoch + 1),
                             train_vars['verbose'], erase_line=True)
        train_vars['curr_iter'] += train_vars['n_iter_per_epoch']
        if not train_vars['output_filepath'] == '':
            with open(train_vars['output_filepath'], 'a') as f:
                f.write(msg + '\n')
        continue
    else:
        msg = ''
    train_vars['total_loss'] = 0
    train_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
    train_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
    optimizer.zero_grad()
    # train model
    train_vars = train(train_loader, model, optimizer, train_vars)
    if train_vars['done_training']:
def validate(valid_loader, model, optimizer, valid_vars, control_vars, verbose=True):
    losses_main = []
    for batch_idx, (data, target) in enumerate(valid_loader):
        control_vars['batch_idx'] = batch_idx
        if batch_idx < control_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " + str(batch_idx + 1) + "/" +
                          str(control_vars['iter_size']), verbose, n_tabs=0, erase_line=True)
        # start time counter
        start = time.time()
        # get data and target as cuda variables
        target_heatmaps, target_joints, target_handroot = target
        # make target joints be relative
        target_joints = target_joints[:, 3:]
        data, target_heatmaps = Variable(data), Variable(target_heatmaps)
        if valid_vars['use_cuda']:
            data = data.cuda()
            target_joints = target_joints.cuda()
            target_heatmaps = target_heatmaps.cuda()
            target_handroot = target_handroot.cuda()
        # visualize if debugging
        # get model output
        output = model(data)
        # accumulate loss for sub-mini-batch
        if model.cross_entropy:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        weights_heatmaps_loss, weights_joints_loss = get_loss_weights(control_vars['curr_iter'])
        loss, loss_heatmaps, loss_joints, loss_main = my_losses.calculate_loss_JORNet_for_valid(
            loss_func, output, target_heatmaps, target_joints, valid_vars['joint_ixs'],
            weights_heatmaps_loss, weights_joints_loss, control_vars['iter_size'])
        losses_main.append(loss_main.item() / 63.0)
        valid_vars['total_loss'] += loss
        valid_vars['total_joints_loss'] += loss_joints
        valid_vars['total_heatmaps_loss'] += loss_heatmaps
        # accumulate pixel dist loss for sub-mini-batch
        valid_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            valid_vars['total_pixel_loss'], output[3], target_heatmaps, control_vars['batch_size'])
        valid_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
            valid_vars['total_pixel_loss_sample'], output[3], target_heatmaps, control_vars['batch_size'])
        # NOTE: the accumulation above is repeated here, so these totals count each
        # sub-mini-batch twice
        valid_vars['total_loss'] += loss
        valid_vars['total_joints_loss'] += loss_joints
        valid_vars['total_heatmaps_loss'] += loss_heatmaps
        # accumulate pixel dist loss for sub-mini-batch
        valid_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            valid_vars['total_pixel_loss'], output[3], target_heatmaps, control_vars['batch_size'])
        valid_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
            valid_vars['total_pixel_loss_sample'], output[3], target_heatmaps, control_vars['batch_size'])
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx + 1) % control_vars['iter_size'] == 0
        if minibatch_completed:
            # append total loss
            valid_vars['losses'].append(valid_vars['total_loss'].item())
            # erase total loss
            total_loss = valid_vars['total_loss'].item()
            valid_vars['total_loss'] = 0
            # append total joints loss
            valid_vars['losses_joints'].append(valid_vars['total_joints_loss'].item())
            # erase total joints loss
            valid_vars['total_joints_loss'] = 0
            # append total heatmaps loss
            valid_vars['losses_heatmaps'].append(valid_vars['total_heatmaps_loss'].item())
            # erase total heatmaps loss
            valid_vars['total_heatmaps_loss'] = 0
            # append dist loss
            valid_vars['pixel_losses'].append(valid_vars['total_pixel_loss'])
            # erase pixel dist loss
            valid_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            valid_vars['pixel_losses_sample'].append(valid_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            valid_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            #if valid_vars['losses'][-1] < valid_vars['best_loss']:
            #    valid_vars['best_loss'] = valid_vars['losses'][-1]
            #    print_verbose(" This is a best loss found so far: " + str(valid_vars['losses'][-1]), verbose)
            # log checkpoint
            if control_vars['curr_iter'] % control_vars['log_interval'] == 0:
                trainer.print_log_info(model, optimizer, 1, total_loss, valid_vars, control_vars,
                                       save_best=False, save_a_checkpoint=False)
                model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': valid_vars,
                }
                #trainer.save_checkpoint(model_dict,
                #                        filename=valid_vars['checkpoint_filenamebase'] +
                #                                 str(control_vars['num_iter']) + '.pth.tar')
            # print time lapse
            prefix = 'Validating (Epoch #' + str(1) + ' ' + str(control_vars['curr_epoch_iter']) + '/' +\
                     str(control_vars['tot_iter']) + ')' + ', (Batch ' + str(control_vars['batch_idx'] + 1) +\
                     '(' + str(control_vars['iter_size']) + ')' + '/' +\
                     str(control_vars['num_batches']) + ')' + ', (Iter #' + str(control_vars['curr_iter']) +\
                     '(' + str(control_vars['batch_size']) + ')' +\
                     ' - log every ' + str(control_vars['log_interval']) + ' iter): '
            control_vars['tot_toc'] = display_est_time_loop(control_vars['tot_toc'] + time.time() - start,
                                                            control_vars['curr_iter'], control_vars['num_iter'],
                                                            prefix=prefix)
            control_vars['curr_iter'] += 1
            control_vars['start_iter'] = control_vars['curr_iter'] + 1
            control_vars['curr_epoch_iter'] += 1
    total_avg_loss = np.mean(losses_main)
    return valid_vars, control_vars, total_avg_loss
def print_log_info(model, optimizer, epoch, train_vars, save_best=True, save_a_checkpoint=True):
    sep = "-------------------------------------------------------------------------------------------"
    vars = train_vars
    total_loss = train_vars['total_loss']
    model_class_name = type(model).__name__
    verbose = train_vars['verbose']
    print_verbose("", verbose)
    print_verbose(sep, verbose)
    if save_a_checkpoint:
        print_verbose("Saving checkpoints:", verbose)
        print_verbose(sep, verbose)
        if optimizer is None:
            checkpoint_model_dict = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': None,
                'train_vars': train_vars,
            }
        else:
            checkpoint_model_dict = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_vars': train_vars,
            }
        save_checkpoint(checkpoint_model_dict, filename=vars['checkpoint_filenamebase'] + '.pth.tar')
    msg = ''
    msg += print_verbose(sep, verbose) + "\n"
    now = datetime.datetime.now()
    msg += print_verbose('Time: ' + now.strftime("%Y-%m-%d %H:%M"), verbose) + "\n"
    msg += print_verbose(sep, verbose) + "\n"
    msg += print_verbose(sep, verbose) + "\n"
    msg += print_verbose('Training (Epoch #' + str(epoch) + ' ' + str(train_vars['curr_epoch_iter']) + '/' +
                         str(train_vars['tot_epoch']) + ')' + ', (Batch ' + str(train_vars['batch_idx'] + 1) +
                         '(' + str(train_vars['iter_size']) + ')' + '/' +
                         str(train_vars['num_batches']) + ')' + ', (Iter #' + str(train_vars['curr_iter']) +
                         '(' + str(train_vars['batch_size']) + ')' +
                         ' - log every ' + str(train_vars['log_interval']) + ' iter): ', verbose) + '\n'
    msg += print_verbose(sep, verbose) + "\n"
    msg += print_verbose("Current loss: " + str(total_loss), verbose) + "\n"
    msg += print_verbose("Best loss: " + str(vars['best_loss']), verbose) + "\n"
    msg += print_verbose("Mean total loss: " + str(np.mean(vars['losses'])), verbose) + "\n"
    msg += print_verbose("Stddev total loss: " + str(np.std(vars['losses'])), verbose) + "\n"
    msg += print_verbose("Mean loss for last " + str(train_vars['log_interval']) +
                         " iterations (average total loss): " +
                         str(np.mean(vars['losses'][-train_vars['log_interval']:])), verbose) + "\n"
    msg += print_verbose("Stddev loss for last " + str(train_vars['log_interval']) +
                         " iterations (average total loss): " +
                         str(np.std(vars['losses'][-train_vars['log_interval']:])), verbose) + "\n"
    msg += print_verbose(sep, verbose) + "\n"
    if not train_vars['output_filepath'] == '':
        with open(train_vars['output_filepath'], 'a') as f:
            f.write(msg + '\n')
    return 1
def parse_args(model_class):
    parser = argparse.ArgumentParser(description='Train a hand-tracking deep neural network')
    parser.add_argument('--num_iter', dest='num_iter', type=int,
                        help='Total number of iterations to train')
    parser.add_argument('-c', dest='checkpoint_filepath', default='', required=True,
                        help='Checkpoint file from which to begin training')
    parser.add_argument('--log_interval', type=int, dest='log_interval', default=10,
                        help='Number of iterations interval on which to log'
                             ' a model checkpoint (default 10)')
    parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', default=True,
                        help='Verbose mode')
    parser.add_argument('--max_mem_batch', type=int, dest='max_mem_batch', default=8,
                        help='Max size of batch given GPU memory (default 8)')
    parser.add_argument('--batch_size', type=int, dest='batch_size', default=16,
                        help='Batch size for training (if larger than max memory batch, training will take '
                             'the required amount of iterations to complete a batch)')
    parser.add_argument('-r', dest='root_folder', default='', required=True,
                        help='Root folder for dataset')
    parser.add_argument('--visual', dest='visual_debugging', action='store_true', default=False,
                        help='Whether to visually inspect results')
    parser.add_argument('--cuda', dest='use_cuda', action='store_true', default=False,
                        help='Whether to use cuda for training')
    parser.add_argument('--split_filename', default='', required=False,
                        help='Split filename for the file with dataset splits')
    args = parser.parse_args()

    control_vars, valid_vars = initialize_vars(args)
    control_vars['visual_debugging'] = args.visual_debugging
    print_verbose("Loading model and optimizer from file: " + args.checkpoint_filepath, args.verbose)
    model, optimizer, valid_vars, train_control_vars = \
        trainer.load_checkpoint(filename=args.checkpoint_filepath,
                                model_class=model_class,
                                use_cuda=args.use_cuda)
    valid_vars['root_folder'] = args.root_folder
    valid_vars['use_cuda'] = args.use_cuda
    control_vars['log_interval'] = args.log_interval
    random_int_str = args.checkpoint_filepath.split('_')[-2]
    valid_vars['checkpoint_filenamebase'] = 'valid_halnet_log_' + str(random_int_str) + '_'
    control_vars['output_filepath'] = 'validated_halnet_log_' + random_int_str + '.txt'
    msg = print_verbose("Printing also to output filepath: " + control_vars['output_filepath'], args.verbose)
    with open(control_vars['output_filepath'], 'w+') as f:
        f.write(msg + '\n')
    if valid_vars['use_cuda']:
        print_verbose("Using CUDA", args.verbose)
    else:
        print_verbose("Not using CUDA", args.verbose)
    control_vars['num_epochs'] = 100
    control_vars['verbose'] = True
    if valid_vars['cross_entropy']:
        print_verbose("Using cross entropy loss", args.verbose)
    control_vars['num_iter'] = 0
    valid_vars['split_filename'] = args.split_filename
    return model, optimizer, control_vars, valid_vars, train_control_vars
                                       verbose=control_vars['verbose'], dataset_type='prior')
control_vars['num_batches'] = len(train_loader)
control_vars['n_iter_per_epoch'] = int(len(train_loader) / control_vars['iter_size'])
control_vars['tot_iter'] = int(len(train_loader) / control_vars['iter_size'])
control_vars['start_iter_mod'] = control_vars['start_iter'] % control_vars['tot_iter']

trainer.print_header_info(model, train_loader, control_vars)
model.train()
control_vars['curr_iter'] = 1

train_vars['best_loss_prior'] = 1e10
train_vars['losses_prior'] = []
train_vars['total_loss_prior'] = 0

for epoch in range(control_vars['num_epochs']):
    control_vars['curr_epoch_iter'] = 1
    if epoch + 1 < control_vars['start_epoch']:
        print_verbose("Advancing through epochs: " + str(epoch + 1),
                      control_vars['verbose'], erase_line=True)
        control_vars['curr_iter'] += control_vars['n_iter_per_epoch']
        continue
    train_vars['total_loss'] = 0
    train_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
    train_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
    optimizer.zero_grad()
    # train model
    train_vars, control_vars = train(train_loader, model, optimizer, train_vars, control_vars,
                                     control_vars['verbose'])
    if control_vars['done_training']:
        print_verbose("Done training.", control_vars['verbose'])
        break
def parse_args(model_class, random_id=-1):
    parser = argparse.ArgumentParser(description='Train a hand-tracking deep neural network')
    parser.add_argument('--num_iter', dest='num_iter', type=int, required=True,
                        help='Total number of iterations to train')
    parser.add_argument('-c', dest='checkpoint_filepath', default='',
                        help='Checkpoint file from which to begin training')
    parser.add_argument('--log_interval', type=int, dest='log_interval', default=10,
                        help='Number of iterations interval on which to log'
                             ' a model checkpoint (default 10)')
    parser.add_argument('--log_interval_valid', type=int, dest='log_interval_valid', default=1000,
                        help='Number of iterations interval on which to log'
                             ' a model checkpoint for validation (default 1000)')
    parser.add_argument('--num_epochs', dest='num_epochs', default=100,
                        help='Total number of epochs to train')
    parser.add_argument('--cuda', dest='use_cuda', action='store_true', default=False,
                        help='Whether to use cuda for training')
    parser.add_argument('--override', dest='override', action='store_true', default=False,
                        help='Whether to override checkpoint args for command line ones')
    parser.add_argument('-o', dest='output_filepath', default='',
                        help='Output file for logging')
    parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', default=True,
                        help='Verbose mode')
    parser.add_argument('-j', '--heatmap_ixs', dest='heatmap_ixs', nargs='+',
                        help='', default=list(range(21)))
    parser.add_argument('--resnet', dest='load_resnet', action='store_true', default=False,
                        help='Whether to load RESNet weights onto the network when creating it')
    parser.add_argument('--max_mem_batch', type=int, dest='max_mem_batch', default=8,
                        help='Max size of batch given GPU memory (default 8)')
    parser.add_argument('--batch_size', type=int, dest='batch_size', default=16,
                        help='Batch size for training (if larger than max memory batch, training will take '
                             'the required amount of iterations to complete a batch)')
    parser.add_argument('--cross_entropy', dest='cross_entropy', action='store_true', default=False,
                        help='Whether to use cross entropy loss on HALNet')
    parser.add_argument('-r', dest='root_folder', default='', required=True,
                        help='Root folder for dataset')
    args = parser.parse_args()
    args.heatmap_ixs = list(map(int, args.heatmap_ixs))

    train_vars = initialize_train_vars(args)
    train_vars['checkpoint_filenamebase'] = 'trained_' + str(model_class.__name__) + '_'
    if random_id >= 0:
        train_vars['checkpoint_filenamebase'] += str(random_id) + '_'
    if train_vars['output_filepath'] == '':
        print_verbose("No output filepath specified", args.verbose)
    else:
        print_verbose("Printing also to output filepath: " + train_vars['output_filepath'], args.verbose)
    if args.checkpoint_filepath == '':
        print_verbose("Creating network from scratch", args.verbose)
        print_verbose("Building network...", args.verbose)
        train_vars['use_cuda'] = args.use_cuda
        train_vars['cross_entropy'] = args.cross_entropy
        params_dict = {}
        params_dict['heatmap_ixs'] = args.heatmap_ixs
        params_dict['use_cuda'] = args.use_cuda
        params_dict['cross_entropy'] = args.cross_entropy
        model = model_class(params_dict)
        if args.load_resnet:
            model = load_resnet_weights_into_HALNet(model, args.verbose)
        print_verbose("Done building network", args.verbose)
        optimizer = my_optimizers.get_adadelta_halnet(model)
    else:
        print_verbose("Loading model and optimizer from file: " + args.checkpoint_filepath, args.verbose)
        model, optimizer, train_vars, train_vars = \
            load_checkpoint(filename=args.checkpoint_filepath, model_class=model_class)
    if train_vars['use_cuda']:
        print_verbose("Using CUDA", args.verbose)
    else:
        print_verbose("Not using CUDA", args.verbose)
    if args.override or args.checkpoint_filepath == '':
        train_vars['root_folder'] = args.root_folder
        train_vars['use_cuda'] = args.use_cuda
        train_vars['log_interval'] = args.log_interval
        train_vars['max_mem_batch'] = args.max_mem_batch
        train_vars['batch_size'] = args.batch_size
    train_vars['num_epochs'] = 100
    train_vars['verbose'] = True
    if train_vars['cross_entropy']:
        print_verbose("Using cross entropy loss", args.verbose)
    return model, optimizer, train_vars, train_vars
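# initialize_train_vars is called by parse_args above but is not defined in this file.
# A minimal sketch of what it is assumed to return, based on the keys that the training
# and logging code reads from train_vars; the real initializer may set additional fields.
def initialize_train_vars(args):
    train_vars = {
        'num_iter': args.num_iter,
        'num_epochs': args.num_epochs,
        'log_interval': args.log_interval,
        'log_interval_valid': args.log_interval_valid,
        'output_filepath': args.output_filepath,
        'verbose': args.verbose,
        'use_cuda': args.use_cuda,
        'cross_entropy': args.cross_entropy,
        'max_mem_batch': args.max_mem_batch,
        'batch_size': args.batch_size,
        # one optimizer iteration spans batch_size / max_mem_batch sub-mini-batches
        'iter_size': int(args.batch_size / args.max_mem_batch),
        'start_epoch': 0,
        'start_iter': 1,
        'curr_iter': 1,
        'curr_epoch_iter': 1,
        'tot_toc': 0,
        'done_training': False,
        'best_loss': 1e10,
        'best_model_dict': {},
        'losses': [],
        'pixel_losses': [],
        'pixel_losses_sample': [],
        'total_loss': 0,
        'total_pixel_loss': [],
        'total_pixel_loss_sample': [],
    }
    return train_vars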
def train(train_loader, model, optimizer, train_vars, control_vars, verbose=True):
    curr_epoch_iter = 1
    for batch_idx, (data, target) in enumerate(train_loader):
        control_vars['batch_idx'] = batch_idx
        if batch_idx < control_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " + str(batch_idx+1) + "/" +
                          str(control_vars['iter_size']), verbose, n_tabs=0, erase_line=True)
        # check if arrived at iter to start
        if control_vars['curr_epoch_iter'] < control_vars['start_iter_mod']:
            if batch_idx % control_vars['iter_size'] == 0:
                print_verbose("\rGoing through iterations to arrive at last one saved... " +
                              str(int(control_vars['curr_epoch_iter'] * 100.0 / control_vars['start_iter_mod'])) +
                              "% of " + str(control_vars['start_iter_mod']) + " iterations (" +
                              str(control_vars['curr_epoch_iter']) + "/" + str(control_vars['start_iter_mod']) + ")",
                              verbose, n_tabs=0, erase_line=True)
                control_vars['curr_epoch_iter'] += 1
                control_vars['curr_iter'] += 1
                curr_epoch_iter += 1
            continue
        # save checkpoint after final iteration
        if control_vars['curr_iter'] == control_vars['num_iter']:
            print_verbose("\nReached final number of iterations: " + str(control_vars['num_iter']), verbose)
            print_verbose("\tSaving final model checkpoint...", verbose)
            final_model_dict = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'control_vars': control_vars,
                'train_vars': train_vars,
            }
            trainer.save_checkpoint(final_model_dict,
                                    filename=train_vars['checkpoint_filenamebase'] +
                                             'final' + str(control_vars['num_iter']) + '.pth.tar')
            control_vars['done_training'] = True
            break
        # start time counter
        start = time.time()
        # get data and target as cuda variables
        target_heatmaps, target_joints, _, target_prior = target
        data, target_heatmaps, target_prior = Variable(data), Variable(target_heatmaps), Variable(target_prior)
        if train_vars['use_cuda']:
            data = data.cuda()
            target_heatmaps = target_heatmaps.cuda()
            target_prior = target_prior.cuda()
        # visualize if debugging
        # get model output
        output = model(data)
        # accumulate loss for sub-mini-batch
        if train_vars['cross_entropy']:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        loss, loss_prior = my_losses.calculate_loss_HALNet_prior(
            loss_func, output, target_heatmaps, target_prior, model.joint_ixs,
            model.WEIGHT_LOSS_INTERMED1, model.WEIGHT_LOSS_INTERMED2,
            model.WEIGHT_LOSS_INTERMED3, model.WEIGHT_LOSS_MAIN,
            control_vars['iter_size'])
        loss.backward()
        train_vars['total_loss'] += loss
        train_vars['total_loss_prior'] += loss_prior
        # accumulate pixel dist loss for sub-mini-batch
        train_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            train_vars['total_pixel_loss'], output[3], target_heatmaps, control_vars['batch_size'])
        if train_vars['cross_entropy']:
            train_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
                train_vars['total_pixel_loss_sample'], output[3], target_heatmaps, control_vars['batch_size'])
        else:
            train_vars['total_pixel_loss_sample'] = [-1] * len(model.joint_ixs)
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx + 1) % control_vars['iter_size'] == 0
        if minibatch_completed:
            # optimise for mini-batch
            optimizer.step()
            # clear optimiser
            optimizer.zero_grad()
            # append total loss
            train_vars['losses'].append(train_vars['total_loss'].data[0])
            # erase total loss
            total_loss = train_vars['total_loss'].data[0]
            train_vars['total_loss'] = 0
            # append total loss prior
            train_vars['losses_prior'].append(train_vars['total_loss_prior'].data[0])
            # erase total loss prior
            total_loss_prior = train_vars['total_loss_prior'].data[0]
            train_vars['total_loss_prior'] = 0
            # append dist loss
            train_vars['pixel_losses'].append(train_vars['total_pixel_loss'])
            # erase pixel dist loss
            train_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            train_vars['pixel_losses_sample'].append(train_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            train_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            if train_vars['losses'][-1] < train_vars['best_loss']:
                train_vars['best_loss'] = train_vars['losses'][-1]
                print_verbose(" This is a best loss found so far: " + str(train_vars['losses'][-1]), verbose)
                train_vars['best_model_dict'] = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': train_vars,
                }
            if train_vars['losses_prior'][-1] < train_vars['best_loss_prior']:
                train_vars['best_loss_prior'] = train_vars['losses_prior'][-1]
            # log checkpoint
            if control_vars['curr_iter'] % control_vars['log_interval'] == 0:
                # 'epoch' is expected to be defined in the enclosing scope
                trainer.print_log_info(model, optimizer, epoch, total_loss, train_vars, control_vars)
                msg = ''
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                msg += print_verbose("Current loss (prior): " + str(total_loss_prior), verbose) + "\n"
                msg += print_verbose("Best loss (prior): " + str(train_vars['best_loss_prior']), verbose) + "\n"
                msg += print_verbose("Mean total loss (prior): " + str(np.mean(train_vars['losses_prior'])),
                                     verbose) + "\n"
                msg += print_verbose("Mean loss (prior) for last " + str(control_vars['log_interval']) +
                                     " iterations (average total loss): " +
                                     str(np.mean(train_vars['losses_prior'][-control_vars['log_interval']:])),
                                     verbose) + "\n"
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                if not control_vars['output_filepath'] == '':
                    with open(control_vars['output_filepath'], 'a') as f:
                        f.write(msg + '\n')
            if control_vars['curr_iter'] % control_vars['log_interval_valid'] == 0:
                print_verbose("\nSaving model and checkpoint model for validation", verbose)
                checkpoint_model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': train_vars,
                }
                trainer.save_checkpoint(checkpoint_model_dict,
                                        filename=train_vars['checkpoint_filenamebase'] +
                                                 'for_valid_' + str(control_vars['curr_iter']) + '.pth.tar')
            # print time lapse
            prefix = 'Training (Epoch #' + str(epoch) + ' ' + str(control_vars['curr_epoch_iter']) + '/' + \
                     str(control_vars['tot_iter']) + ')' + ', (Batch ' + str(control_vars['batch_idx'] + 1) + \
                     '(' + str(control_vars['iter_size']) + ')' + '/' + \
                     str(control_vars['num_batches']) + ')' + ', (Iter #' + str(control_vars['curr_iter']) + \
                     '(' + str(control_vars['batch_size']) + ')' + \
                     ' - log every ' + str(control_vars['log_interval']) + ' iter): '
            control_vars['tot_toc'] = display_est_time_loop(control_vars['tot_toc'] + time.time() - start,
                                                            control_vars['curr_iter'], control_vars['num_iter'],
                                                            prefix=prefix)
            control_vars['curr_iter'] += 1
            control_vars['start_iter'] = control_vars['curr_iter'] + 1
            control_vars['curr_epoch_iter'] += 1
    return train_vars, control_vars
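# --- Illustrative sketch (not part of the original code) --------------------------
# The train() loop above accumulates gradients over control_vars['iter_size']
# sub-mini-batches (the loss functions are passed iter_size so each backward pass is
# already scaled down) and only calls optimizer.step()/zero_grad() once the full
# virtual mini-batch has been seen. The self-contained example below shows the same
# gradient-accumulation pattern with a dummy model and random data, written against
# the current torch API rather than the Variable-era API used in this codebase.
import torch

def gradient_accumulation_demo(iter_size=4, num_batches=20):
    model = torch.nn.Linear(10, 2)                          # stand-in for HALNet/JORNet
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = torch.nn.MSELoss()
    optimizer.zero_grad()
    for batch_idx in range(num_batches):
        data = torch.randn(8, 10)                           # stand-in sub-mini-batch
        target = torch.randn(8, 2)
        loss = criterion(model(data), target) / iter_size   # scale so gradients average correctly
        loss.backward()                                      # gradients accumulate across sub-batches
        if (batch_idx + 1) % iter_size == 0:                 # mini-batch completed
            optimizer.step()                                 # one update per virtual mini-batch
            optimizer.zero_grad()
# -----------------------------------------------------------------------------------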
def train(train_loader, model, optimizer, train_vars, control_vars, verbose=True):
    curr_epoch_iter = 1
    for batch_idx, (data, target) in enumerate(train_loader):
        control_vars['batch_idx'] = batch_idx
        if batch_idx < control_vars['iter_size']:
            print_verbose("\rPerforming first iteration; current mini-batch: " + str(batch_idx+1) + "/" +
                          str(control_vars['iter_size']), verbose, n_tabs=0, erase_line=True)
        # check if arrived at iter to start
        if control_vars['curr_epoch_iter'] < control_vars['start_iter_mod']:
            # jump straight to the saved iteration (skips the batch-by-batch fast-forward)
            control_vars['curr_epoch_iter'] = control_vars['start_iter_mod']
            msg = ''
            if batch_idx % control_vars['iter_size'] == 0:
                msg += print_verbose("\rGoing through iterations to arrive at last one saved... " +
                                     str(int(control_vars['curr_epoch_iter'] * 100.0 / control_vars['start_iter_mod'])) +
                                     "% of " + str(control_vars['start_iter_mod']) + " iterations (" +
                                     str(control_vars['curr_epoch_iter']) + "/" + str(control_vars['start_iter_mod']) + ")",
                                     verbose, n_tabs=0, erase_line=True)
                control_vars['curr_epoch_iter'] += 1
                control_vars['curr_iter'] += 1
                curr_epoch_iter += 1
            if not control_vars['output_filepath'] == '':
                with open(control_vars['output_filepath'], 'a') as f:
                    f.write(msg + '\n')
            continue
        # save checkpoint after final iteration
        if control_vars['curr_iter'] == control_vars['num_iter']:
            print_verbose("\nReached final number of iterations: " + str(control_vars['num_iter']), verbose)
            print_verbose("\tSaving final model checkpoint...", verbose)
            final_model_dict = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'control_vars': control_vars,
                'train_vars': train_vars,
            }
            trainer.save_checkpoint(final_model_dict,
                                    filename=train_vars['checkpoint_filenamebase'] +
                                             'final' + str(control_vars['num_iter']) + '.pth.tar')
            control_vars['done_training'] = True
            break
        # start time counter
        start = time.time()
        # get data and target as cuda variables
        target_heatmaps, target_joints, target_roothand = target
        data, target_heatmaps, target_joints, target_roothand = Variable(data), Variable(target_heatmaps), \
                                                                Variable(target_joints), Variable(target_roothand)
        if train_vars['use_cuda']:
            data = data.cuda()
            target_heatmaps = target_heatmaps.cuda()
            target_joints = target_joints.cuda()
        # get model output
        output = model(data)
        '''
        visualize.plot_joints_from_heatmaps(target_heatmaps[0, :, :, :].cpu().data.numpy(),
                                            title='', data=data[0].cpu().data.numpy())
        visualize.show()
        visualize.plot_image_and_heatmap(target_heatmaps[0][4].cpu().data.numpy(),
                                         data=data[0].cpu().data.numpy(), title='')
        visualize.show()
        visualize.plot_image_and_heatmap(output[3][0][4].cpu().data.numpy(),
                                         data=data[0].cpu().data.numpy(), title='')
        visualize.show()
        '''
        # accumulate loss for sub-mini-batch
        if train_vars['cross_entropy']:
            loss_func = my_losses.cross_entropy_loss_p_logq
        else:
            loss_func = my_losses.euclidean_loss
        weights_heatmaps_loss, weights_joints_loss = get_loss_weights(control_vars['curr_iter'])
        loss, loss_heatmaps, loss_joints = my_losses.calculate_loss_JORNet(
            loss_func, output, target_heatmaps, target_joints, train_vars['joint_ixs'],
            weights_heatmaps_loss, weights_joints_loss, control_vars['iter_size'])
        loss.backward()
        train_vars['total_loss'] += loss.data[0]
        train_vars['total_joints_loss'] += loss_joints.data[0]
        train_vars['total_heatmaps_loss'] += loss_heatmaps.data[0]
        # accumulate pixel dist loss for sub-mini-batch
        train_vars['total_pixel_loss'] = my_losses.accumulate_pixel_dist_loss_multiple(
            train_vars['total_pixel_loss'], output[3], target_heatmaps, control_vars['batch_size'])
        train_vars['total_pixel_loss_sample'] = my_losses.accumulate_pixel_dist_loss_from_sample_multiple(
            train_vars['total_pixel_loss_sample'], output[3], target_heatmaps, control_vars['batch_size'])
        # get boolean variable stating whether a mini-batch has been completed
        minibatch_completed = (batch_idx + 1) % control_vars['iter_size'] == 0
        if minibatch_completed:
            # optimise for mini-batch
            optimizer.step()
            # clear optimiser
            optimizer.zero_grad()
            # append total loss
            train_vars['losses'].append(train_vars['total_loss'])
            # erase total loss
            total_loss = train_vars['total_loss']
            train_vars['total_loss'] = 0
            # append total joints loss
            train_vars['losses_joints'].append(train_vars['total_joints_loss'])
            # erase total joints loss
            train_vars['total_joints_loss'] = 0
            # append total heatmaps loss
            train_vars['losses_heatmaps'].append(train_vars['total_heatmaps_loss'])
            # erase total heatmaps loss
            train_vars['total_heatmaps_loss'] = 0
            # append dist loss
            train_vars['pixel_losses'].append(train_vars['total_pixel_loss'])
            # erase pixel dist loss
            train_vars['total_pixel_loss'] = [0] * len(model.joint_ixs)
            # append dist loss of sample from output
            train_vars['pixel_losses_sample'].append(train_vars['total_pixel_loss_sample'])
            # erase dist loss of sample from output
            train_vars['total_pixel_loss_sample'] = [0] * len(model.joint_ixs)
            # check if loss is better
            if train_vars['losses'][-1] < train_vars['best_loss']:
                train_vars['best_loss'] = train_vars['losses'][-1]
                print_verbose(" This is a best loss found so far: " + str(train_vars['losses'][-1]), verbose)
                train_vars['best_model_dict'] = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': train_vars,
                }
            if train_vars['losses_joints'][-1] < train_vars['best_loss_joints']:
                train_vars['best_loss_joints'] = train_vars['losses_joints'][-1]
            if train_vars['losses_heatmaps'][-1] < train_vars['best_loss_heatmaps']:
                train_vars['best_loss_heatmaps'] = train_vars['losses_heatmaps'][-1]
            # log checkpoint
            if control_vars['curr_iter'] % control_vars['log_interval'] == 0:
                # 'epoch' is expected to be defined in the enclosing scope
                trainer.print_log_info(model, optimizer, epoch, total_loss, train_vars, control_vars)
                aa1 = target_joints[0].data.cpu().numpy()
                aa2 = output[7][0].data.cpu().numpy()
                # 63 = 21 joints x 3 coordinates
                output_joint_loss = np.sum(np.abs(aa1 - aa2)) / 63
                msg = ''
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                msg += print_verbose('\tJoint Coord Avg Loss for first image of current mini-batch: ' +
                                     str(output_joint_loss) + '\n', control_vars['verbose'])
                msg += print_verbose(
                    "-------------------------------------------------------------------------------------------",
                    verbose) + "\n"
                if not control_vars['output_filepath'] == '':
                    with open(control_vars['output_filepath'], 'a') as f:
                        f.write(msg + '\n')
            if control_vars['curr_iter'] % control_vars['log_interval_valid'] == 0:
                print_verbose("\nSaving model and checkpoint model for validation", verbose)
                checkpoint_model_dict = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'control_vars': control_vars,
                    'train_vars': train_vars,
                }
                trainer.save_checkpoint(checkpoint_model_dict,
                                        filename=train_vars['checkpoint_filenamebase'] +
                                                 'for_valid_' + str(control_vars['curr_iter']) + '.pth.tar')
            # print time lapse
            prefix = 'Training (Epoch #' + str(epoch) + ' ' + str(control_vars['curr_epoch_iter']) + '/' + \
                     str(control_vars['tot_iter']) + ')' + ', (Batch ' + str(control_vars['batch_idx'] + 1) + \
                     '(' + str(control_vars['iter_size']) + ')' + '/' + \
                     str(control_vars['num_batches']) + ')' + ', (Iter #' + str(control_vars['curr_iter']) + \
                     '(' + str(control_vars['batch_size']) + ')' + \
                     ' - log every ' + str(control_vars['log_interval']) + ' iter): '
            control_vars['tot_toc'] = display_est_time_loop(control_vars['tot_toc'] + time.time() - start,
                                                            control_vars['curr_iter'], control_vars['num_iter'],
                                                            prefix=prefix)
            control_vars['curr_iter'] += 1
            control_vars['start_iter'] = control_vars['curr_iter'] + 1
            control_vars['curr_epoch_iter'] += 1
    return train_vars, control_vars
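# --- Illustrative sketch (not part of the original code) --------------------------
# Both train() variants above mutate and return (train_vars, control_vars) and flag
# completion through control_vars['done_training']; they also read a name `epoch`
# from their enclosing scope for log messages. The epoch-level driver is not shown in
# this section; the hypothetical outline below indicates how such a driver could call
# train() until the final iteration is reached, using only keys that appear in the
# code above.
def run_training_sketch(train_loader, model, optimizer, train_vars, control_vars):
    global epoch                                   # train() reads `epoch` for its log prefix
    epoch = 0
    while not control_vars['done_training']:
        control_vars['curr_epoch_iter'] = 1        # restart the per-epoch iteration counter
        train_vars, control_vars = train(train_loader, model, optimizer,
                                         train_vars, control_vars,
                                         verbose=train_vars['verbose'])
        epoch += 1
    return train_vars, control_vars
# -----------------------------------------------------------------------------------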