for loss in train_obj: writer.add_scalar(f'train_{loss}', train_obj[loss], epoch) # Compute validation objective. val_obj = validate(gen_val, model, losses, report_freq=20) report_loss('Validation', val_obj['nll'], epoch) for loss in val_obj: writer.add_scalar(f'val_{loss}', val_obj[loss], epoch) if gen_plot is not None: plot_model_task(model, gen_plot, epoch, wd) update_learning_rate(opt, decay_rate=0.999, lowest=args.learning_rate / 10) # Update the best objective value and checkpoint the model. is_best = False if val_obj['nll'] < best_obj: best_obj = val_obj['nll'] is_best = True save_checkpoint(wd, { 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_acc_top1': best_obj, 'optimizer': opt.state_dict() }, is_best=is_best)
# --- Logging: one log file per experiment; the command line is recorded first. ---
if not os.path.exists("logs/"):
    utils.makedirs("logs/")
log_path = "logs/" + file_name + "_" + str(experimentID) + ".log"
logger = utils.get_logger(logpath=log_path, filepath=os.path.abspath(__file__))
logger.info(input_command)

# --- Optimisation setup. ---
optimizer = optim.Adamax(model.parameters(), lr=args.lr)
num_batches = data_obj["n_train_batches"]

# KL weight is 0 while itr // num_batches < wait_until_kl_inc, then 1 - 0.99**k,
# where k counts the passes over the training batches completed after that warm-up.
wait_until_kl_inc = 10

# NOTE(review): this runs num_batches * (args.niters + 1) - 1 iterations, i.e. almost
# one pass more than args.niters; kept as-is to preserve existing training length.
for itr in range(1, num_batches * (args.niters + 1)):
    optimizer.zero_grad()
    utils.update_learning_rate(optimizer, decay_rate=0.999, lowest=args.lr / 10)

    epochs_past_warmup = itr // num_batches - wait_until_kl_inc
    kl_coef = 0. if epochs_past_warmup < 0 else 1 - 0.99**epochs_past_warmup

    # Draw the next training batch; for each element, 3 trajectories are sampled
    # from the same encoded posterior q(z|x).
    batch_dict = utils.get_next_batch(data_obj["train_dataloader"])
    train_res = model.compute_all_losses(batch_dict, n_traj_samples=3, kl_coef=kl_coef)
    train_res["loss"].backward()
    optimizer.step()
def _classification_metrics(split, y_true, y_pred):
    """Return macro-averaged classification metrics for one split, keyed for wandb.

    Produces the ``Other_metrics/<split>_{cm,precision,recall,f1,kappa}`` entries.
    Argument order matters: ``y_true`` must hold the ground-truth labels and
    ``y_pred`` the predictions, because precision and recall are not symmetric
    in their arguments.
    """
    prefix = 'Other_metrics/' + split
    return {
        prefix + '_cm': sklearn_cm(y_true, y_pred),
        prefix + '_precision': precision_score(y_true, y_pred, average='macro'),
        prefix + '_recall': recall_score(y_true, y_pred, average='macro'),
        prefix + '_f1': f1_score(y_true, y_pred, average='macro'),
        prefix + '_kappa': cohen_kappa_score(y_true, y_pred),
    }


def train_it(
        Model,
        Data_obj,
        args,
        file_name,
        ExperimentID,
        Validationwriter,
        input_command,
        Devices):
    """Train one model per device in lockstep, validating and checkpointing periodically.

    Parameters
    ----------
    Model : list
        One model per device; each must provide ``compute_all_losses``,
        ``parameters()`` and ``state_dict()``.
    Data_obj : list of dict
        One data object per device. Keys read here: ``"n_train_batches"``,
        ``"train_dataloader"``, ``"test_dataloader"``, ``"n_test_batches"`` and
        ``"dataset_obj"`` (whose ``label_list`` names the classes).
    args : namespace
        Run configuration. Reads ``save``, ``optimizer``, ``lr``, ``lrdecay``,
        ``niters``, ``v``, ``val_freq`` and ``batch_size``; it is also passed to
        ``compute_loss_all_batches`` and stored in every checkpoint.
    file_name : str
        Prefix of the per-experiment log files written under ``logs/``.
    ExperimentID : list
        One experiment ID per device, used in checkpoint and log file names.
    Validationwriter : list
        One writer per device (``add_figure`` / ``add_scalar``).
    input_command : str
        Command line, written as the first entry of every log file.
    Devices : list
        Devices the models and data live on; its length is the number of models.

    Side effects
    ------------
    Every ``args.val_freq`` iterations, saves ``<args.save>/experiment_<ID>.ckpt``
    per device, plus ``experiment_<ID>_topscore.ckpt`` when validation accuracy
    improves, and logs metrics via ``wandb.log``.

    Returns
    -------
    tuple
        ``(train_res, test_res, best_acc, best_acc_step)``: the per-device training
        results of the last iteration, the latest per-device validation results
        (``{"accuracy": 0.0}`` placeholders if validation never ran), and the peak
        validation accuracy of the first device with the step
        (``itr * args.batch_size``) at which it was reached.
    """
    num_gpus = len(Devices)

    Ckpt_path = []
    Top_ckpt_path = []
    Best_test_acc = [0] * num_gpus
    Best_test_acc_step = [0] * num_gpus
    Logger = []
    Optimizer = []

    if not os.path.exists("logs/"):
        utils.makedirs("logs/")

    for i in range(num_gpus):
        experiment_name = "experiment_" + str(ExperimentID[i])
        Ckpt_path.append(os.path.join(args.save, experiment_name + '.ckpt'))
        Top_ckpt_path.append(os.path.join(args.save, experiment_name + '_topscore.ckpt'))

        log_path = "logs/" + file_name + "_" + str(ExperimentID[i]) + ".log"
        logger = utils.get_logger(logpath=log_path, filepath=os.path.abspath(__file__))
        logger.info(input_command)
        Logger.append(logger)

        Optimizer.append(get_optimizer(args.optimizer, args.lr, Model[i].parameters()))

    num_batches = Data_obj[0]["n_train_batches"]

    train_res = [None] * num_gpus
    batch_dict = [None] * num_gpus
    label_dict = [None] * num_gpus
    # Placeholder validation results: the progress bar (and the return value) read
    # test_res[0]["accuracy"] before the first validation pass. One independent dict
    # per device so that test_res[i] can be assigned for every device.
    test_res = [{"accuracy": 0.0} for _ in range(num_gpus)]

    iterations = range(1, num_batches * args.niters + 1)
    if args.v in (1, 2):
        pbar = tqdm(iterations, position=0, leave=True, ncols=160)
    else:
        pbar = iterations

    # KL annealing: kl_coef is 0.01 while itr // num_batches < wait_until_kl_inc,
    # afterwards 1 - 0.99**k with k the passes completed after the warm-up.
    wait_until_kl_inc = 10
    # Debug switch: write training loss/accuracy after every batch.
    fine_train_writer = False

    for itr in pbar:
        # Each stage runs for every device before the next stage starts.
        # NOTE(review): presumably deliberate so work queued on one device can
        # overlap with the others; keep the loops separate.
        for i in range(num_gpus):
            Optimizer[i].zero_grad()

        for i in range(num_gpus):
            # Per-iteration learning-rate decay; `lowest` is presumably the floor.
            utils.update_learning_rate(Optimizer[i], decay_rate=args.lrdecay,
                                       lowest=args.lr / 1000)

        epoch = itr // num_batches
        if epoch < wait_until_kl_inc:
            kl_coef = 0.01
        else:
            kl_coef = 1 - 0.99 ** (epoch - wait_until_kl_inc)

        for i in range(num_gpus):
            batch_dict[i] = utils.get_next_batch(Data_obj[i]["train_dataloader"])

        for i in range(num_gpus):
            train_res[i] = Model[i].compute_all_losses(batch_dict[i], n_traj_samples=3,
                                                       kl_coef=kl_coef)

        for i in range(num_gpus):
            train_res[i]["loss"].backward()

        for i in range(num_gpus):
            Optimizer[i].step()

        if itr % args.val_freq == 0:
            with torch.no_grad():
                # Calculate labels and loss on the test data of every device first.
                for i, device in enumerate(Devices):
                    test_res[i], label_dict[i] = compute_loss_all_batches(
                        Model[i],
                        Data_obj[i]["test_dataloader"],
                        args,
                        n_batches=Data_obj[i]["n_test_batches"],
                        experimentID=ExperimentID[i],
                        device=device,
                        n_traj_samples=3,
                        kl_coef=kl_coef)

                step = itr * args.batch_size
                for i in range(num_gpus):
                    cm, conf_fig = plot_confusion_matrix(
                        label_dict[i]["correct_labels"],
                        label_dict[i]["predict_labels"],
                        Data_obj[i]["dataset_obj"].label_list,
                        tensor_name='dev/cm')
                    Validationwriter[i].add_figure("Validation_Confusionmatrix", conf_fig, step)

                    # Ground truth vs. predictions, per device. sklearn metrics take
                    # (y_true, y_pred); swapping them exchanges precision and recall.
                    y_true_train = torch.argmax(batch_dict[i]['labels'], dim=1).cpu()
                    # NOTE(review): argmax over dim 2 then squeeze assumes a leading
                    # singleton sample dimension in label_predictions -- confirm.
                    y_pred_train = torch.argmax(
                        train_res[i]['label_predictions'], dim=2).squeeze().cpu()
                    y_true = label_dict[i]["correct_labels"].cpu()
                    y_pred = label_dict[i]["predict_labels"]

                    # Latest checkpoint, overwritten at every validation.
                    torch.save({
                        'args': args,
                        'state_dict': Model[i].state_dict(),
                    }, Ckpt_path[i])

                    if test_res[i]["accuracy"] > Best_test_acc[i]:
                        Best_test_acc[i] = test_res[i]["accuracy"]
                        Best_test_acc_step[i] = step
                        torch.save({
                            'args': args,
                            'state_dict': Model[i].state_dict(),
                            'cm': cm,
                        }, Top_ckpt_path[i])

                    logdict = {
                        'Classification_accuracy/train': train_res[i]["accuracy"],
                        'Classification_accuracy/validation': test_res[i]["accuracy"],
                        'Classification_accuracy/validation_peak': Best_test_acc[i],
                        'Classification_accuracy/validation_peak_step': Best_test_acc_step[i],
                        'loss/train': train_res[i]["loss"].detach(),
                        'loss/validation': test_res[i]["loss"].detach(),
                    }
                    logdict.update(_classification_metrics('train', y_true_train, y_pred_train))
                    logdict.update(_classification_metrics('validation', y_true, y_pred))
                    # NOTE(review): with several devices every model logs the same keys
                    # at the same step into one wandb run; consider per-device prefixes.
                    wandb.log(logdict, step=step)

        if fine_train_writer:
            for i in range(num_gpus):
                if "loss" in train_res[i]:
                    Validationwriter[i].add_scalar('loss/train', train_res[i]["loss"].detach(),
                                                   itr * args.batch_size)
                if "accuracy" in train_res[i]:
                    Validationwriter[i].add_scalar('Classification_accuracy/train',
                                                   train_res[i]["accuracy"],
                                                   itr * args.batch_size)

        if args.v == 2:
            pbar.set_description(
                "Train Ac: {:.3f} % | Test Ac: {:.3f} %, Peak Test Ac.: {:.3f} % (at {} batches) |"
                .format(train_res[0]["accuracy"] * 100, test_res[0]["accuracy"] * 100,
                        Best_test_acc[0] * 100, Best_test_acc_step[0] // args.batch_size))

        # Drop references to this iteration's batches and validation labels so they
        # can be released before the next iteration refills the lists.
        batch_dict = [None] * num_gpus
        label_dict = [None] * num_gpus

    print(Best_test_acc[0], " at step ", Best_test_acc_step[0])
    return train_res, test_res, Best_test_acc[0], Best_test_acc_step[0]