def _train(model, optim_inf, scheduler_inf, checkpointer, epochs, train_loader,
           test_loader, stat_tracker, log_dir, device, decoder_training=False):
    '''
    Training loop for optimizing encoder
    '''
    # If mixed precision is on, will add the necessary hooks into the model
    # and optimizer for half() conversions
    model, optim_inf = mixed_precision.initialize(model, optim_inf)
    optim_raw = mixed_precision.get_optimizer(optim_inf)
    # Nawid - choose which method of testing to use, depending on whether the
    # decoder is being trained
    test = test_decoder_model if model.decoder_training else test_model
    # get target LR for LR warmup -- assume same LR for all param groups
    for pg in optim_raw.param_groups:
        lr_real = pg['lr']
    # IDK, maybe this helps?
    torch.cuda.empty_cache()
    # prepare checkpoint and stats accumulator
    next_epoch, total_updates = checkpointer.get_current_position()
    fast_stats = AverageMeterSet()
    # run main training loop
    for epoch in range(next_epoch, epochs):
        epoch_stats = AverageMeterSet()
        epoch_updates = 0
        time_start = time.time()
        # Nawid - obtain the images and the labels
        for _, ((images1, images2), labels) in enumerate(train_loader):
            # get data and info about this minibatch
            labels = torch.cat([labels, labels]).to(device)
            images1 = images1.to(device)
            images2 = images2.to(device)
            # run forward pass through model to get global and local features
            res_dict = model(x1=images1, x2=images2, class_only=False)
            # compute costs for all self-supervised tasks
            # Nawid - loss for the global-to-local feature predictions
            loss_g2l = (res_dict['g2l_1t5'] +
                        res_dict['g2l_1t7'] +
                        res_dict['g2l_5t5'])
            loss_inf = loss_g2l + res_dict['lgt_reg']
            if model.decoder_training:
                image_reconstructions = res_dict['decoder_output']
                # Nawid - concatenate both batches along the batch dimension
                target_images = torch.cat([images1, images2])
                auxiliary_loss = loss_MSE(image_reconstructions, target_images)
                epoch_stats.update_dict({'loss_decoder': auxiliary_loss.item()}, n=1)
            else:
                # compute loss for online evaluation classifiers
                lgt_glb_mlp, lgt_glb_lin = res_dict['class']
                # Nawid - loss for the classifier terms
                auxiliary_loss = (loss_xent(lgt_glb_mlp, labels) +
                                  loss_xent(lgt_glb_lin, labels))
                epoch_stats.update_dict({'loss_cls': auxiliary_loss.item()}, n=1)
                update_train_accuracies(epoch_stats, labels, lgt_glb_mlp, lgt_glb_lin)
            # do hacky learning rate warmup -- we stop when LR hits lr_real
            if (total_updates < 500):
                lr_scale = min(1., float(total_updates + 1) / 500.)
                for pg in optim_raw.param_groups:
                    pg['lr'] = lr_scale * lr_real
            # reset gradient accumulators and do backprop
            # NOTE: only the auxiliary loss is optimized here; loss_inf is
            # computed for logging but excluded from the objective
            loss_opt = auxiliary_loss  # + loss_inf
            optim_inf.zero_grad()
            mixed_precision.backward(loss_opt, optim_inf)  # backwards with fp32/fp16 awareness
            optim_inf.step()
            # record loss and accuracy on minibatch
            # Nawid - the auxiliary loss is recorded above, in the branch that computes it
            epoch_stats.update_dict(
                {
                    'loss_inf': loss_inf.item(),
                    'loss_g2l': loss_g2l.item(),
                    'lgt_reg': res_dict['lgt_reg'].item(),
                    'loss_g2l_1t5': res_dict['g2l_1t5'].item(),
                    'loss_g2l_1t7': res_dict['g2l_1t7'].item(),
                    'loss_g2l_5t5': res_dict['g2l_5t5'].item()
                }, n=1)
            # shortcut diagnostics to deal with long epochs
            total_updates += 1
            epoch_updates += 1
            if (total_updates % 100) == 0:
                # IDK, maybe this helps?
                torch.cuda.empty_cache()
                time_stop = time.time()
                spu = (time_stop - time_start) / 100.
                print('Epoch {0:d}, {1:d} updates -- {2:.4f} sec/update'.format(
                    epoch, epoch_updates, spu))
                time_start = time.time()
            if (total_updates % 500) == 0:
                # record diagnostics
                # Nawid - fast_stats holds short-term stats which are reset regularly
                eval_start = time.time()
                fast_stats = AverageMeterSet()
                # Nawid - test was chosen to be test_decoder_model or test_model at the
                # start of the function, based on whether decoder training is occurring
                test(model, test_loader, device, fast_stats, log_dir, max_evals=100000)
                # Nawid - record the averaged values in TensorBoard; total_updates is the
                # step index used to place the information in TensorBoard
                stat_tracker.record_stats(
                    fast_stats.averages(total_updates, prefix='fast/'))
                eval_time = time.time() - eval_start
                stat_str = fast_stats.pretty_string(ignore=model.tasks)
                stat_str = '-- {0:d} updates, eval_time {1:.2f}: {2:s}'.format(
                    total_updates, eval_time, stat_str)
                print(stat_str)
        # update learning rate
        scheduler_inf.step(epoch)
        test(model, test_loader, device, epoch_stats, log_dir, max_evals=500000)
        epoch_str = epoch_stats.pretty_string(ignore=model.tasks)
        diag_str = '{0:d}: {1:s}'.format(epoch, epoch_str)
        print(diag_str)
        sys.stdout.flush()
        # Nawid - record the long-term stats which are tracked over the whole run
        stat_tracker.record_stats(epoch_stats.averages(epoch, prefix='costs/'))
        checkpointer.update(epoch + 1, total_updates)
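# ----------------------------------------------------------------------------
# The training loop above calls loss helpers (loss_xent, loss_MSE) that are
# defined elsewhere in the codebase. Below is a minimal sketch of plausible
# definitions, assuming standard PyTorch criteria; the exact names and
# behaviour in the real repo may differ.
# ----------------------------------------------------------------------------
import torch
import torch.nn.functional as F


def loss_xent(logits, labels):
    # plain cross-entropy over the online-evaluation classifier logits
    return F.cross_entropy(logits, labels)


def loss_MSE(reconstructions, targets):
    # mean-squared error between decoder reconstructions and the input images
    return F.mse_loss(reconstructions, targets)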
def _train(model, optim_inf, scheduler_inf, checkpointer, epochs, train_loader,
           test_loader, stat_tracker, log_dir, device, args):
    '''
    Training loop for optimizing encoder
    '''
    # If mixed precision is on, will add the necessary hooks into the model
    # and optimizer for half() conversions
    model, optim_inf = mixed_precision.initialize(model, optim_inf)
    optim_raw = mixed_precision.get_optimizer(optim_inf)
    # get target LR for LR warmup -- assume same LR for all param groups
    for pg in optim_raw.param_groups:
        lr_real = pg['lr']
    # IDK, maybe this helps?
    # but it makes the training slow
    # torch.cuda.empty_cache()
    # prepare checkpoint and stats accumulator
    next_epoch, total_updates = checkpointer.get_current_position()
    fast_stats = AverageMeterSet()
    # run main training loop
    for epoch in range(next_epoch, epochs):
        epoch_stats = AverageMeterSet()
        epoch_updates = 0
        time_start = time.time()
        for _, ((images1, images2), labels) in enumerate(train_loader):
            # get data and info about this minibatch
            labels = torch.cat([labels, labels]).to(device)
            images1 = images1.to(device)
            images2 = images2.to(device)
            # run forward pass through model to get global and local features
            res_dict = model(args, x1=images1, x2=images2, class_only=False)
            lgt_glb_mlp, lgt_glb_lin = res_dict['class']
            # compute costs for all self-supervised tasks
            loss_g2l = (res_dict['g2l_1t5'] +
                        res_dict['g2l_1t7'] +
                        res_dict['g2l_5t5'])
            loss_inf = loss_g2l + res_dict['lgt_reg']
            # compute loss for online evaluation classifiers
            loss_cls = (loss_xent(lgt_glb_mlp, labels) +
                        loss_xent(lgt_glb_lin, labels))
            # do hacky learning rate warmup -- we stop when LR hits lr_real
            if (total_updates < 500):
                lr_scale = min(1., float(total_updates + 1) / 500.)
                for pg in optim_raw.param_groups:
                    pg['lr'] = lr_scale * lr_real
            # reset gradient accumulators and do backprop
            loss_opt = loss_inf + loss_cls
            optim_inf.zero_grad()
            mixed_precision.backward(loss_opt, optim_inf)  # backwards with fp32/fp16 awareness
            if args.grad_clip_value and (total_updates >= 500):
                torch.nn.utils.clip_grad_value_(model.parameters(), args.grad_clip_value)
            optim_inf.step()
            # record loss and accuracy on minibatch
            epoch_stats.update_dict(
                {
                    'loss_inf': loss_inf.detach().item(),
                    'loss_cls': loss_cls.detach().item(),
                    'loss_g2l': loss_g2l.detach().item(),
                    'lgt_reg': res_dict['lgt_reg'].detach().item(),
                    'loss_g2l_1t5': res_dict['g2l_1t5'].detach().item(),
                    'loss_g2l_1t7': res_dict['g2l_1t7'].detach().item(),
                    'loss_g2l_5t5': res_dict['g2l_5t5'].detach().item()
                }, n=1)
            update_train_accuracies(epoch_stats, labels, lgt_glb_mlp, lgt_glb_lin)
            # shortcut diagnostics to deal with long epochs
            total_updates += 1
            epoch_updates += 1
            # this command makes the training slow
            # torch.cuda.empty_cache()
            if (total_updates % 100) == 0:
                # IDK, maybe this helps?
                time_stop = time.time()
                spu = (time_stop - time_start) / 100.
                print('Epoch {0:d}, {1:d} updates -- {2:.4f} sec/update'.format(
                    epoch, epoch_updates, spu))
                time_start = time.time()
            if (total_updates % 500) == 0:
                # record diagnostics
                eval_start = time.time()
                fast_stats = AverageMeterSet()
                test_model(args, model, test_loader, device, fast_stats, max_evals=100000)
                stat_tracker.record_stats(
                    fast_stats.averages(total_updates, prefix='fast/'))
                eval_time = time.time() - eval_start
                stat_str = fast_stats.pretty_string(ignore=model.tasks)
                stat_str = '-- {0:d} updates, eval_time {1:.2f}: {2:s}'.format(
                    total_updates, eval_time, stat_str)
                print(stat_str)
        # update learning rate
        scheduler_inf.step(epoch)
        test_model(args, model, test_loader, device, epoch_stats, max_evals=500000)
        epoch_str = epoch_stats.pretty_string(ignore=model.tasks)
        diag_str = '{0:d}: {1:s}'.format(epoch, epoch_str)
        print(diag_str)
        sys.stdout.flush()
        stat_tracker.record_stats(epoch_stats.averages(epoch, prefix='costs/'))
        checkpointer.update(epoch + 1, total_updates)
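# ----------------------------------------------------------------------------
# Both _train variants accumulate diagnostics through an AverageMeterSet with
# update_dict / averages / pretty_string methods. The real class is defined
# elsewhere in the repo; the sketch below is an assumed minimal implementation
# with the same interface, shown only to make the calls above concrete.
# ----------------------------------------------------------------------------
class AverageMeterSet:
    def __init__(self):
        # name -> running sum and running count of observations
        self.sums = {}
        self.counts = {}

    def update_dict(self, name_to_value, n=1):
        # add a weighted observation for each named statistic
        for name, value in name_to_value.items():
            self.sums[name] = self.sums.get(name, 0.0) + value * n
            self.counts[name] = self.counts.get(name, 0) + n

    def averages(self, idx, prefix=''):
        # return (prefixed name, mean, step index) tuples for a stat tracker
        return [(prefix + name, self.sums[name] / self.counts[name], idx)
                for name in self.sums]

    def pretty_string(self, ignore=()):
        # one-line summary of all tracked means, skipping ignored names
        parts = ['{0:s}: {1:.4f}'.format(name, self.sums[name] / self.counts[name])
                 for name in sorted(self.sums) if name not in ignore]
        return ', '.join(parts)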
mods_cls = [m for m in model.class_modules]
mods_to_opt = mods_inf + mods_cls
optim_inf = torch.optim.Adam(
    [{'params': mod.parameters(), 'lr': args.learning_rate} for mod in mods_to_opt],
    betas=(0.8, 0.999),
    weight_decay=1e-5,
    eps=1e-8)
scheduler = MultiStepLR(optim_inf, milestones=[10, 20], gamma=0.2)
epochs = 50
model, optim_inf = mixed_precision.initialize(model, optim_inf)
optim_raw = mixed_precision.get_optimizer(optim_inf)
for p in optim_raw.param_groups:
    lr_real = p['lr']

# --
# Run
_ = torch.cuda.empty_cache()
writer = SummaryWriter(
    logdir=os.path.join(os.path.join(args.output_dir, args.run_name), 'logs'))
t = time()
step_counter = 0
for epoch_idx in range(epochs):