def val_test(args, end_epoch=4):
    """Run validation-only passes and log the results.

    Builds the VQ-VAE and, when ``args.recons_loss == "gan"``, a
    reconstruction discriminator; optionally restores a checkpoint; then
    evaluates on the validation loader once per epoch index, logging image
    grids, losses and latent metrics through the summary writer and calling
    ``train_util.save_state`` after every pass. No training step is run.

    Args:
        args: Run configuration. Attributes read here: ``output_folder``,
            ``hidden_size``, ``k``, ``enc_type``, ``dec_type``, ``lr``,
            ``recons_loss``, ``img_res``, ``input_type``, ``disc_lr``,
            ``device`` and ``weights``.
        end_epoch: Exclusive upper bound of the epoch counter; passes run
            for ``range(start_epoch, end_epoch)``. Defaults to 4, the
            value that used to be hard-coded.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    _, valid_loader, test_loader = train_util.get_dataloaders(args)
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    input_dim = 3
    model = VectorQuantizedVAE(input_dim, args.hidden_size, args.k,
                               args.enc_type, args.dec_type)
    # The optimizer is not stepped here; it is only handed to
    # load_state / save_state.
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    discriminators = {}
    if args.recons_loss == "gan":
        recons_disc = Discriminator(
            input_dim, args.img_res, args.input_type).to(args.device)
        recons_disc_opt = torch.optim.Adam(
            recons_disc.parameters(), lr=args.disc_lr, amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    # NOTE: no multi-GPU wrapping is applied to the model in this function.
    model.to(args.device)

    if args.weights == "load":
        start_epoch = train_util.load_state(
            save_filename, model, optimizer, discriminators)
    else:
        start_epoch = 0

    # NOTE(review): best_loss is never updated in this function, so
    # save_state always receives +inf -- confirm that is intended.
    best_loss = torch.tensor(np.inf)

    # Stays None when start_epoch >= end_epoch and no pass is executed.
    val_loss_dict = None
    for epoch in tqdm(range(start_epoch, end_epoch), file=sys.stdout):
        # NOTE(review): the meaning of the trailing positional ``True`` is
        # defined by train_util.test, which is not visible from here.
        val_loss_dict, z = train_util.test(
            get_losses, model, valid_loader, args, discriminators, True)

        step = epoch + 1
        train_util.log_recons_img_grid(
            recons_input_img, model, step, args.device, writer)
        train_util.log_interp_img_grid(
            recons_input_img, model, step, args.device, writer)
        train_util.log_losses("val", val_loss_dict, step, writer)
        train_util.log_latent_metrics("val", z, step, writer)
        train_util.save_state(
            model, optimizer, discriminators, val_loss_dict["recons_loss"],
            best_loss, args.recons_loss, epoch, save_filename)

    # Print the losses of the last pass, if any pass ran.
    if val_loss_dict is not None:
        print(val_loss_dict)
def _build_recons_disc(args, input_dim):
    """Instantiate the reconstruction discriminator for ``args.recons_loss``.

    The checks run in this order: exact ``"gan"``, exact ``"comp"``, then
    substring matches on ``"comp_2"`` and ``"comp_6"``.

    Raises:
        ValueError: If ``args.recons_loss`` matches none of the known kinds.
    """
    kind = args.recons_loss
    if kind == "gan":
        return Discriminator(input_dim, args.img_res, args.input_type)
    if kind == "comp":
        return AnchorComparator(input_dim * 2, args.img_res, args.input_type)
    if "comp_2" in kind:
        return ClubbedPermutationComparator(
            input_dim * 2, args.img_res, args.input_type)
    if "comp_6" in kind:
        return FullPermutationComparator(
            input_dim * 2, args.img_res, args.input_type)
    raise ValueError(f"unsupported recons_loss: {kind!r}")


def main(args):
    """Train the VQ-VAE, validating, logging and checkpointing every epoch.

    Unless ``args.recons_loss == "mse"``, a reconstruction discriminator is
    built with its own Adam optimizer and plateau LR scheduler. With more
    than one CUDA device the model and discriminators are wrapped for data
    parallelism. Each epoch runs ``train``, evaluates on the validation
    loader, logs image grids, losses and latent metrics, calls
    ``train_util.save_state`` and steps the LR schedulers on the validation
    losses.

    Args:
        args: Run configuration. Attributes read here: ``output_folder``,
            ``hidden_size``, ``k``, ``enc_type``, ``dec_type``, ``device``,
            ``recons_loss``, ``img_res``, ``input_type``, ``disc_lr``,
            ``lr``, ``lr_patience``, ``threshold``, ``weights``,
            ``num_epochs`` and ``batch_size``.

    Raises:
        ValueError: If ``args.recons_loss`` is not ``"mse"`` and names no
            known discriminator kind.
        SystemExit: With status 1 when ``train`` raises ``RuntimeError``.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    train_loader, valid_loader, test_loader = train_util.get_dataloaders(args)

    input_dim = 3
    model = VectorQuantizedVAE(input_dim, args.hidden_size, args.k,
                               args.enc_type, args.dec_type)
    model.to(args.device)

    discriminators = {}
    recons_disc_lr_scheduler = None  # stays None for the plain "mse" loss
    if args.recons_loss != "mse":
        recons_disc = _build_recons_disc(args, input_dim).to(args.device)
        recons_disc_opt = torch.optim.Adam(
            recons_disc.parameters(), lr=args.disc_lr, amsgrad=True)
        recons_disc_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            recons_disc_opt, "min", patience=args.lr_patience, factor=0.5,
            threshold=args.threshold, threshold_mode="abs", min_lr=1e-7)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    ae_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, "min", patience=args.lr_patience, factor=0.5,
        threshold=args.threshold, threshold_mode="abs", min_lr=1e-7)

    if torch.cuda.device_count() > 1:
        model = train_util.ae_data_parallel(model)
        for disc in discriminators:
            discriminators[disc][0] = torch.nn.DataParallel(
                discriminators[disc][0])

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    # Fixed images for Tensorboard: log the input grid and the step-0
    # reconstruction once, before any checkpoint is restored.
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)
    train_util.log_recons_img_grid(
        recons_input_img, model, 0, args.device, writer)

    if args.weights == "load":
        start_epoch = train_util.load_state(
            save_filename, model, optimizer, discriminators)
    else:
        start_epoch = 0

    # NOTE(review): best_loss is never updated in this function, so
    # save_state always receives +inf -- confirm that is intended.
    best_loss = torch.tensor(np.inf)

    for epoch in tqdm(range(start_epoch, args.num_epochs), file=sys.stdout):
        try:
            train(epoch, train_loader, model, optimizer, args, writer,
                  discriminators)
        except RuntimeError as err:
            print("".join(
                traceback.TracebackException.from_exception(err).format()),
                file=sys.stderr)
            print("*******")
            print(err, file=sys.stderr)
            print(f"batch_size:{args.batch_size}", file=sys.stderr)
            # Training failed: report it with a non-zero exit status.
            sys.exit(1)

        val_loss_dict, z = train_util.test(
            get_losses, model, valid_loader, args, discriminators)

        step = epoch + 1
        train_util.log_recons_img_grid(
            recons_input_img, model, step, args.device, writer)
        train_util.log_interp_img_grid(
            recons_input_img, model, step, args.device, writer)
        train_util.log_losses("val", val_loss_dict, step, writer)
        train_util.log_latent_metrics("val", z, step, writer)
        train_util.save_state(
            model, optimizer, discriminators, val_loss_dict["recons_loss"],
            best_loss, args.recons_loss, epoch, save_filename)

        # Early stopping is currently disabled. To re-enable it, set
        # ``stop_patience = args.stop_patience`` before the loop and restore:
        # if val_loss_dict["recons_loss"] - best_loss < args.threshold:
        #     stop_patience -= 1
        # else:
        #     stop_patience = args.stop_patience
        # if stop_patience == 0:
        #     print("training early stopped!")
        #     break

        ae_lr_scheduler.step(val_loss_dict["recons_loss"])
        if recons_disc_lr_scheduler is not None:
            recons_disc_lr_scheduler.step(val_loss_dict["recons_disc_loss"])