def main():
    # enable mixed-precision computation if desired
    if args.amp:
        mixed_precision.enable_mixed_precision()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # get the dataset
    dataset = get_dataset(args.dataset)
    _, test_loader, _ = build_dataset(dataset=dataset,
                                      batch_size=args.batch_size,
                                      input_dir=args.input_dir)

    torch_device = torch.device("cuda")
    checkpointer = Checkpointer()
    model = checkpointer.restore_model_from_checkpoint(args.checkpoint_path)
    model = model.to(torch_device)
    model, _ = mixed_precision.initialize(model, None)

    test_stats = AverageMeterSet()
    test(model, test_loader, torch_device, test_stats)
    stat_str = test_stats.pretty_string(ignore=model.tasks)
    print(stat_str)
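
# A minimal sketch of the command-line entry point that main() above assumes.
# The flag names are inferred from the attributes accessed on `args` (amp, seed,
# dataset, batch_size, input_dir, checkpoint_path); the defaults are illustrative
# guesses, not the repo's actual parser.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Evaluate a checkpointed model.')
    parser.add_argument('--dataset', type=str, default='STL10')
    parser.add_argument('--batch_size', type=int, default=200)
    parser.add_argument('--input_dir', type=str, default='./data')
    parser.add_argument('--checkpoint_path', type=str, required=True)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--amp', action='store_true', help='use mixed precision')
    # args is assigned at module scope so main() can read it as a global
    args = parser.parse_args()
    main()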
def _train(model, optimizer, scheduler, checkpointer, epochs, train_loader,
           test_loader, stat_tracker, log_dir, device):
    '''
    Training loop to train classifiers on top of an encoder with fixed weights.
    -- e.g., use this for eval or running on new data
    '''
    # If mixed precision is on, will add the necessary hooks into the model and
    # optimizer for half precision conversions
    model, optimizer = mixed_precision.initialize(model, optimizer)
    # ...
    time_start = time.time()
    total_updates = 0
    next_epoch, total_updates = checkpointer.get_current_position(classifier=True)
    for epoch in range(next_epoch, epochs):
        epoch_updates = 0
        epoch_stats = AverageMeterSet()
        for _, ((images1, images2), labels) in enumerate(train_loader):
            # get data and info about this minibatch
            images1 = images1.to(device)
            images2 = images2.to(device)
            labels = labels.to(device)
            # run forward pass through model and collect activations
            res_dict = model(x1=images1, x2=images2, class_only=True)
            lgt_glb_mlp, lgt_glb_lin = res_dict['class']
            # compute total loss for optimization
            loss = (loss_xent(lgt_glb_mlp, labels) +
                    loss_xent(lgt_glb_lin, labels))
            # do optimizer step for encoder
            optimizer.zero_grad()
            mixed_precision.backward(loss, optimizer)  # special mixed precision stuff
            optimizer.step()
            # record loss and accuracy on minibatch
            epoch_stats.update('loss', loss.item(), n=1)
            update_train_accuracies(epoch_stats, labels, lgt_glb_mlp, lgt_glb_lin)
            # shortcut diagnostics to deal with long epochs
            total_updates += 1
            epoch_updates += 1
            if (total_updates % 100) == 0:
                time_stop = time.time()
                spu = (time_stop - time_start) / 100.
                print('Epoch {0:d}, {1:d} updates -- {2:.4f} sec/update'.format(
                    epoch, epoch_updates, spu))
                time_start = time.time()
        # step learning rate scheduler
        scheduler.step(epoch)
        # record diagnostics
        test_model(model, test_loader, device, epoch_stats, max_evals=500000)
        epoch_str = epoch_stats.pretty_string(ignore=model.tasks)
        diag_str = '{0:d}: {1:s}'.format(epoch, epoch_str)
        print(diag_str)
        sys.stdout.flush()
        stat_tracker.record_stats(epoch_stats.averages(epoch, prefix='eval/'))
        checkpointer.update(epoch + 1, total_updates, classifier=True)
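
# The classifier loop above relies on two helpers defined elsewhere in the repo:
# `loss_xent` and `update_train_accuracies`. The sketches below show what they are
# assumed to do -- cross-entropy over the classifier logits, and top-1 accuracy for
# both evaluation heads. The `_sketch` names and meter keys are illustrative only,
# not the repo's actual implementation.
import torch.nn.functional as F


def loss_xent_sketch(logits, labels):
    # standard cross-entropy between classifier logits and integer class labels
    return F.cross_entropy(logits, labels)


def update_train_accuracies_sketch(epoch_stats, labels, lgt_glb_mlp, lgt_glb_lin):
    # record top-1 accuracy of the MLP and linear evaluation heads in the meter set
    with torch.no_grad():
        acc_mlp = (lgt_glb_mlp.argmax(dim=1) == labels).float().mean().item()
        acc_lin = (lgt_glb_lin.argmax(dim=1) == labels).float().mean().item()
    epoch_stats.update('train_acc_glb_mlp', acc_mlp, n=labels.size(0))
    epoch_stats.update('train_acc_glb_lin', acc_lin, n=labels.size(0))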
def _train(model, optim_inf, scheduler_inf, checkpointer, epochs, train_loader,
           test_loader, stat_tracker, log_dir, device, args):
    '''
    Training loop for optimizing encoder
    '''
    # If mixed precision is on, will add the necessary hooks into the model
    # and optimizer for half() conversions
    model, optim_inf = mixed_precision.initialize(model, optim_inf)
    optim_raw = mixed_precision.get_optimizer(optim_inf)
    # get target LR for LR warmup -- assume same LR for all param groups
    for pg in optim_raw.param_groups:
        lr_real = pg['lr']

    # IDK, maybe this helps?
    # but it makes the training slow
    # torch.cuda.empty_cache()

    # prepare checkpoint and stats accumulator
    next_epoch, total_updates = checkpointer.get_current_position()
    fast_stats = AverageMeterSet()
    # run main training loop
    for epoch in range(next_epoch, epochs):
        epoch_stats = AverageMeterSet()
        epoch_updates = 0
        time_start = time.time()
        for _, ((images1, images2), labels) in enumerate(train_loader):
            # get data and info about this minibatch
            labels = torch.cat([labels, labels]).to(device)
            images1 = images1.to(device)
            images2 = images2.to(device)
            # run forward pass through model to get global and local features
            res_dict = model(args, x1=images1, x2=images2, class_only=False)
            lgt_glb_mlp, lgt_glb_lin = res_dict['class']
            # compute costs for all self-supervised tasks
            loss_g2l = (res_dict['g2l_1t5'] +
                        res_dict['g2l_1t7'] +
                        res_dict['g2l_5t5'])
            loss_inf = loss_g2l + res_dict['lgt_reg']
            # compute loss for online evaluation classifiers
            loss_cls = (loss_xent(lgt_glb_mlp, labels) +
                        loss_xent(lgt_glb_lin, labels))
            # do hacky learning rate warmup -- we stop when LR hits lr_real
            if (total_updates < 500):
                lr_scale = min(1., float(total_updates + 1) / 500.)
                for pg in optim_raw.param_groups:
                    pg['lr'] = lr_scale * lr_real
            # reset gradient accumulators and do backprop
            loss_opt = loss_inf + loss_cls
            optim_inf.zero_grad()
            mixed_precision.backward(loss_opt, optim_inf)  # backwards with fp32/fp16 awareness
            if args.grad_clip_value and (total_updates >= 500):
                torch.nn.utils.clip_grad_value_(model.parameters(),
                                                args.grad_clip_value)
            optim_inf.step()
            # record loss and accuracy on minibatch
            epoch_stats.update_dict({
                'loss_inf': loss_inf.detach().item(),
                'loss_cls': loss_cls.detach().item(),
                'loss_g2l': loss_g2l.detach().item(),
                'lgt_reg': res_dict['lgt_reg'].detach().item(),
                'loss_g2l_1t5': res_dict['g2l_1t5'].detach().item(),
                'loss_g2l_1t7': res_dict['g2l_1t7'].detach().item(),
                'loss_g2l_5t5': res_dict['g2l_5t5'].detach().item()
            }, n=1)
            update_train_accuracies(epoch_stats, labels, lgt_glb_mlp, lgt_glb_lin)
            # shortcut diagnostics to deal with long epochs
            total_updates += 1
            epoch_updates += 1
            # this command makes the training slow
            # torch.cuda.empty_cache()
            if (total_updates % 100) == 0:
                # IDK, maybe this helps?
                time_stop = time.time()
                spu = (time_stop - time_start) / 100.
                print('Epoch {0:d}, {1:d} updates -- {2:.4f} sec/update'.format(
                    epoch, epoch_updates, spu))
                time_start = time.time()
            if (total_updates % 500) == 0:
                # record diagnostics
                eval_start = time.time()
                fast_stats = AverageMeterSet()
                test_model(args, model, test_loader, device, fast_stats,
                           max_evals=100000)
                stat_tracker.record_stats(
                    fast_stats.averages(total_updates, prefix='fast/'))
                eval_time = time.time() - eval_start
                stat_str = fast_stats.pretty_string(ignore=model.tasks)
                stat_str = '-- {0:d} updates, eval_time {1:.2f}: {2:s}'.format(
                    total_updates, eval_time, stat_str)
                print(stat_str)
        # update learning rate
        scheduler_inf.step(epoch)
        test_model(args, model, test_loader, device, epoch_stats, max_evals=500000)
        epoch_str = epoch_stats.pretty_string(ignore=model.tasks)
        diag_str = '{0:d}: {1:s}'.format(epoch, epoch_str)
        print(diag_str)
        sys.stdout.flush()
        stat_tracker.record_stats(epoch_stats.averages(epoch, prefix='costs/'))
        checkpointer.update(epoch + 1, total_updates)
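
# The warmup in the encoder loop above scales the learning rate linearly from
# lr_real / 500 up to lr_real over the first 500 updates. A small standalone helper
# expressing the same schedule (the name and default are illustrative, not from the repo):
def warmup_lr_scale_sketch(total_updates, warmup_steps=500):
    # multiplier applied to the target learning rate for this update
    return min(1., float(total_updates + 1) / float(warmup_steps))


# e.g. update 0 -> 0.002 * lr_real, update 249 -> 0.5 * lr_real, update >= 499 -> lr_real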
def _train(model, optim_inf, scheduler_inf, checkpointer, epochs, train_loader,
           test_loader, stat_tracker, log_dir, device, decoder_training=False):
    '''
    Training loop for optimizing encoder
    '''
    # If mixed precision is on, will add the necessary hooks into the model
    # and optimizer for half() conversions
    model, optim_inf = mixed_precision.initialize(model, optim_inf)
    optim_raw = mixed_precision.get_optimizer(optim_inf)
    # Nawid - This chooses which method of testing to use
    test = test_decoder_model if model.decoder_training else test_model
    # get target LR for LR warmup -- assume same LR for all param groups
    for pg in optim_raw.param_groups:
        lr_real = pg['lr']
    # IDK, maybe this helps?
    torch.cuda.empty_cache()
    # prepare checkpoint and stats accumulator
    next_epoch, total_updates = checkpointer.get_current_position()
    fast_stats = AverageMeterSet()
    # run main training loop
    for epoch in range(next_epoch, epochs):
        epoch_stats = AverageMeterSet()
        epoch_updates = 0
        time_start = time.time()
        # Nawid - obtains the images and the labels
        for _, ((images1, images2), labels) in enumerate(train_loader):
            # get data and info about this minibatch
            labels = torch.cat([labels, labels]).to(device)
            images1 = images1.to(device)
            images2 = images2.to(device)
            # run forward pass through model to get global and local features
            res_dict = model(x1=images1, x2=images2, class_only=False)
            # compute costs for all self-supervised tasks
            # Nawid - loss for the global-to-local feature predictions
            loss_g2l = (res_dict['g2l_1t5'] +
                        res_dict['g2l_1t7'] +
                        res_dict['g2l_5t5'])
            loss_inf = loss_g2l + res_dict['lgt_reg']
            if model.decoder_training:
                image_reconstructions = res_dict['decoder_output']
                # Nawid - Concatenate both augmented batches along the batch dimension
                target_images = torch.cat([images1, images2])
                auxiliary_loss = loss_MSE(image_reconstructions, target_images)
                epoch_stats.update_dict({'loss_decoder': auxiliary_loss.item()}, n=1)
            else:
                # compute loss for online evaluation classifiers
                lgt_glb_mlp, lgt_glb_lin = res_dict['class']
                # Nawid - Loss for the classifier terms
                auxiliary_loss = (loss_xent(lgt_glb_mlp, labels) +
                                  loss_xent(lgt_glb_lin, labels))
                epoch_stats.update_dict({'loss_cls': auxiliary_loss.item()}, n=1)
                update_train_accuracies(epoch_stats, labels, lgt_glb_mlp, lgt_glb_lin)
            # do hacky learning rate warmup -- we stop when LR hits lr_real
            if (total_updates < 500):
                lr_scale = min(1., float(total_updates + 1) / 500.)
                for pg in optim_raw.param_groups:
                    pg['lr'] = lr_scale * lr_real
            # reset gradient accumulators and do backprop
            # Nawid - only the auxiliary (decoder or classifier) loss is optimized here;
            # add loss_inf back in to also optimize the global-to-local prediction loss
            loss_opt = auxiliary_loss  # + loss_inf
            optim_inf.zero_grad()
            mixed_precision.backward(loss_opt, optim_inf)  # backwards with fp32/fp16 awareness
            optim_inf.step()
            # record loss and accuracy on minibatch
            # Nawid - Changed the update so that the auxiliary loss is recorded above
            epoch_stats.update_dict({
                'loss_inf': loss_inf.item(),
                'loss_g2l': loss_g2l.item(),
                'lgt_reg': res_dict['lgt_reg'].item(),
                'loss_g2l_1t5': res_dict['g2l_1t5'].item(),
                'loss_g2l_1t7': res_dict['g2l_1t7'].item(),
                'loss_g2l_5t5': res_dict['g2l_5t5'].item()
            }, n=1)
            # shortcut diagnostics to deal with long epochs
            total_updates += 1
            epoch_updates += 1
            if (total_updates % 100) == 0:
                # IDK, maybe this helps?
                torch.cuda.empty_cache()
                time_stop = time.time()
                spu = (time_stop - time_start) / 100.
                print('Epoch {0:d}, {1:d} updates -- {2:.4f} sec/update'.format(
                    epoch, epoch_updates, spu))
                time_start = time.time()
            if (total_updates % 500) == 0:
                # record diagnostics
                eval_start = time.time()
                # Nawid - These are short-term stats which are reset regularly
                fast_stats = AverageMeterSet()
                # Nawid - test was set to test_decoder_model or test_model at the start
                # of the function, based on whether decoder training is occurring
                test(model, test_loader, device, fast_stats, log_dir,
                     max_evals=100000)
                # Nawid - Record the averaged values in tensorboard, indexed by total_updates
                stat_tracker.record_stats(
                    fast_stats.averages(total_updates, prefix='fast/'))
                eval_time = time.time() - eval_start
                stat_str = fast_stats.pretty_string(ignore=model.tasks)
                stat_str = '-- {0:d} updates, eval_time {1:.2f}: {2:s}'.format(
                    total_updates, eval_time, stat_str)
                print(stat_str)
        # update learning rate
        scheduler_inf.step(epoch)
        test(model, test_loader, device, epoch_stats, log_dir, max_evals=500000)
        epoch_str = epoch_stats.pretty_string(ignore=model.tasks)
        diag_str = '{0:d}: {1:s}'.format(epoch, epoch_str)
        print(diag_str)
        sys.stdout.flush()
        # Nawid - These long-term stats are kept for the whole run
        stat_tracker.record_stats(epoch_stats.averages(epoch, prefix='costs/'))
        checkpointer.update(epoch + 1, total_updates)
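
# loss_MSE in the decoder branch above is assumed to be a pixel-wise mean-squared-error
# criterion; the decoder output is compared against the concatenation of both augmented
# batches, so the two tensors must share the shape [2 * batch_size, C, H, W]. A sketch
# with illustrative names (not the repo's actual definitions):
import torch.nn as nn

loss_MSE_sketch = nn.MSELoss()  # reduction='mean' averages over all pixels


def reconstruction_loss_sketch(decoder_output, images1, images2):
    # decoder_output: [2B, C, H, W], matching torch.cat([images1, images2], dim=0)
    target_images = torch.cat([images1, images2], dim=0)
    return loss_MSE_sketch(decoder_output, target_images)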
def _train(model, optimizer, scheduler, checkpointer, epochs, train_loader,
           test_loader, stat_tracker, log_dir, device):
    '''
    Training loop to train a decoder on top of an encoder with fixed weights.
    -- e.g., use this for eval or running on new data
    '''
    # If mixed precision is on, will add the necessary hooks into the model and
    # optimizer for half precision conversions
    model, optimizer = mixed_precision.initialize(model, optimizer)
    # ...
    time_start = time.time()
    total_updates = 0
    # Nawid - When resuming training, this finds the epoch and update to continue from
    next_epoch, total_updates = checkpointer.get_current_position(decoder=True)
    for epoch in range(next_epoch, epochs):
        epoch_updates = 0
        epoch_stats = AverageMeterSet()
        # Nawid - loads the two differently augmented views of each image
        for _, ((images1, images2), labels) in enumerate(train_loader):
            # get data and info about this minibatch
            images1 = images1.to(device)
            images2 = images2.to(device)
            labels = labels.to(device)
            # run forward pass through model and collect activations
            # Nawid - decoder_only=True: only the first input is required; the result
            # dictionary holds the encoder features and the decoded output
            res_dict = model(x1=images1, x2=images2, decoder_only=True)
            # Nawid - Obtains the reconstructed images from the decoder
            image_reconstructions = res_dict['decoder_output']
            # compute total loss for optimization
            # Nawid - Reconstruction loss between the inputs and the decoder outputs;
            # there is no loss term for the (frozen) encoder
            loss = loss_MSE(image_reconstructions, images1)
            # do optimizer step for decoder
            optimizer.zero_grad()
            mixed_precision.backward(loss, optimizer)  # special mixed precision stuff
            optimizer.step()  # Nawid - Gradient step
            # record loss on minibatch (classification accuracies do not apply here)
            epoch_stats.update('loss', loss.item(), n=1)
            # shortcut diagnostics to deal with long epochs
            total_updates += 1
            epoch_updates += 1
            if (total_updates % 100) == 0:
                save_reconstructions(images1, image_reconstructions)
                time_stop = time.time()
                spu = (time_stop - time_start) / 100.
                print('Epoch {0:d}, {1:d} updates -- {2:.4f} sec/update'.format(
                    epoch, epoch_updates, spu))
                time_start = time.time()
        # step learning rate scheduler
        scheduler.step(epoch)
        # record diagnostics
        test_decoder_model(model, test_loader, device, epoch_stats, max_evals=500000)
        epoch_str = epoch_stats.pretty_string(ignore=model.tasks)
        diag_str = '{0:d}: {1:s}'.format(epoch, epoch_str)
        print(diag_str)
        sys.stdout.flush()
        stat_tracker.record_stats(epoch_stats.averages(epoch, prefix='decoder/'))
        # Nawid - Updates the decoder checkpoint
        checkpointer.update(epoch + 1, total_updates, decoder=True)
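
# save_reconstructions() is called every 100 updates above but is defined elsewhere in
# the repo. A plausible minimal version, assuming torchvision is available, writes a few
# inputs and their reconstructions side by side; the name, filename, and layout here are
# guesses for illustration only.
import torchvision


def save_reconstructions_sketch(images, reconstructions,
                                out_path='reconstructions.png', max_images=8):
    # stack a few inputs above their reconstructions and save them as one image grid
    grid = torch.cat([images[:max_images], reconstructions[:max_images]], dim=0)
    torchvision.utils.save_image(grid.detach().cpu().float(), out_path, nrow=max_images)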