def val_test(args):
    """Re-run validation for a (possibly restored) VQ-VAE and log the results.

    Builds the VQ-VAE, its optimizer and -- unless ``args.recons_loss`` is
    ``"mse"`` -- the same reconstruction discriminator/comparator the training
    entry point builds, optionally restores them from
    ``./models/<output_folder>`` and then runs validation passes, logging
    image grids, losses and latent metrics to ``./logs/<output_folder>``.

    Args:
        args: parsed command-line namespace (output_folder, hidden_size, k,
            enc_type, dec_type, recons_loss, img_res, input_type, device, lr,
            disc_lr, weights, ...).

    Raises:
        ValueError: if ``args.recons_loss`` names no known reconstruction loss.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    _, valid_loader, test_loader = train_util.get_dataloaders(args)
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    input_dim = 3
    model = VectorQuantizedVAE(input_dim, args.hidden_size, args.k,
                               args.enc_type, args.dec_type)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Mirror the discriminator set built for training so checkpoints trained
    # with comparator losses can be restored and evaluated as well.
    discriminators = {}
    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res, args.input_type)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res, args.input_type)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type)
        else:
            raise ValueError(f"unsupported recons_loss: {args.recons_loss!r}")
        recons_disc.to(args.device)
        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr, amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer, discriminators)
    else:
        start_epoch = 0

    # Validation passes run up to this fixed epoch index, not args.num_epochs.
    last_val_epoch = 4
    best_loss = torch.tensor(np.inf)
    for epoch in tqdm(range(start_epoch, last_val_epoch), file=sys.stdout):
        val_loss_dict, z = train_util.test(get_losses, model, valid_loader, args,
                                           discriminators, True)
        step = epoch + 1
        train_util.log_recons_img_grid(recons_input_img, model, step, args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, step, args.device, writer)
        train_util.log_losses("val", val_loss_dict, step, writer)
        train_util.log_latent_metrics("val", z, step, writer)
        # NOTE(review): best_loss is never updated in this function; whether
        # save_state tracks the best loss itself is not visible here.
        train_util.save_state(model, optimizer, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, save_filename)
        print(val_loss_dict)
def main(args):
    """Train an ACAI autoencoder with an interpolation discriminator.

    Builds the ACAI model, its interpolation discriminator and -- unless
    ``args.recons_loss`` is ``"mse"`` -- a reconstruction discriminator or
    comparator.  Optionally restores all of them from
    ``./models/<output_folder>`` and then trains, validates, saves image grids
    and checkpoints once per epoch, logging to ``./logs/<output_folder>``.

    Args:
        args: parsed command-line namespace.

    Raises:
        ValueError: if ``args.recons_loss`` names no known reconstruction loss.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    train_loader, val_loader, test_loader = train_util.get_dataloaders(args)
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    input_dim = 3
    model = ACAI(args.img_res, input_dim, args.hidden_size,
                 args.enc_type, args.dec_type).to(args.device)
    disc = Discriminator(input_dim, args.img_res, args.input_type).to(args.device)
    disc_opt = torch.optim.Adam(disc.parameters(), lr=args.disc_lr, amsgrad=True)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)

    discriminators = {"interp_disc": [disc, disc_opt]}
    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res, args.input_type)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res, args.input_type)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type)
        else:
            # Previously this fell through to an UnboundLocalError below.
            raise ValueError(f"unsupported recons_loss: {args.recons_loss!r}")
        recons_disc.to(args.device)
        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr, amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    # Save reconstructions of the fixed test images once before training.
    train_util.save_recons_img_grid("test", recons_input_img, model, 0, args)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, opt, discriminators)
    else:
        start_epoch = 0

    best_loss = torch.tensor(np.inf)
    # Continue from the epoch returned by load_state rather than restarting at 0.
    for epoch in range(start_epoch, args.num_epochs):
        print("Epoch {}:".format(epoch))
        train(model, opt, train_loader, args, discriminators, writer)

        val_loss_dict, z = train_util.test(get_losses, model, val_loader, args, discriminators)
        step = epoch + 1
        train_util.log_losses("val", val_loss_dict, step, writer)
        train_util.log_latent_metrics("val", z, step, writer)
        train_util.save_recons_img_grid("val", recons_input_img, model, step, args)
        train_util.save_interp_img_grid("val", recons_input_img, model, step, args)
        train_util.save_state(model, opt, discriminators, val_loss_dict["recons_loss"],
                              best_loss, args.recons_loss, epoch, save_filename)
def main(args):
    """Train a VAE with an optional reconstruction discriminator/comparator.

    Builds the VAE and -- unless ``args.recons_loss`` is ``"mse"`` -- a
    reconstruction discriminator or comparator, wraps them in DataParallel
    when more than one GPU is visible, derives the log/model directories
    (defaulting ``args.output_folder`` when it is ``None``), optionally
    restores state and then trains with per-epoch validation and image-grid
    dumps.

    Args:
        args: parsed command-line namespace.

    Raises:
        ValueError: if ``args.recons_loss`` names no known reconstruction loss.
    """
    train_loader, val_loader, test_loader = train_util.get_dataloaders(args)

    input_dim = 3
    model = VAE(input_dim, args.hidden_size, args.enc_type, args.dec_type)
    # NOTE(review): uses the module-level LR constant rather than args.lr as
    # the other entry points do -- confirm this is intended.
    opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True)

    discriminators = {}
    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res, args.input_type)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res, args.input_type)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type)
        else:
            raise ValueError(f"unsupported recons_loss: {args.recons_loss!r}")
        recons_disc.to(args.device)
        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr, amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    if torch.cuda.device_count() > 1:
        model = train_util.ae_data_parallel(model)
        for disc in discriminators:
            discriminators[disc][0] = torch.nn.DataParallel(discriminators[disc][0])
    model.to(args.device)

    if args.output_folder is None:
        model_name = f"vae_{args.recons_loss}"
        args.output_folder = os.path.join(
            model_name, args.dataset,
            f"depth_{args.enc_type}_{args.dec_type}_hs_{args.img_res}_{args.hidden_size}")

    log_save_path = os.path.join("./logs", args.output_folder)
    model_save_path = os.path.join("./models", args.output_folder)
    os.makedirs(log_save_path, exist_ok=True)
    print(f"log:{log_save_path}", file=sys.stderr)
    sys.stderr.flush()
    os.makedirs(model_save_path, exist_ok=True)

    writer = SummaryWriter(log_save_path)
    print(f"train loader length:{len(train_loader)}", file=sys.stderr)

    if args.weights == "load":
        start_epoch = train_util.load_state(model_save_path, model, opt, discriminators)
    else:
        start_epoch = 0

    recons_input_img = train_util.log_input_img_grid(test_loader, writer)
    train_util.save_recons_img_grid("val", recons_input_img, model, 0, args)

    # A fresh run still starts at epoch 1; a resumed run continues from the
    # epoch returned by load_state instead of silently restarting.
    # NOTE(review): this loop never checkpoints the model (no save_state call).
    for epoch in range(max(start_epoch, 1), args.num_epochs):
        print("Epoch {}:".format(epoch))
        train(model, opt, train_loader)
        curr_loss = val(model, val_loader)
        print(f"epoch val loss:{curr_loss}", file=sys.stderr)
        sys.stderr.flush()
        train_util.save_recons_img_grid("val", recons_input_img, model, epoch + 1, args)
        train_util.save_interp_img_grid("val", recons_input_img, model, epoch + 1, args)
def main(args):
    """Train a VQ-VAE with an optional reconstruction discriminator/comparator.

    Builds the VQ-VAE and -- unless ``args.recons_loss`` is ``"mse"`` -- a
    reconstruction discriminator or comparator, each with a ReduceLROnPlateau
    scheduler.  Models are wrapped in DataParallel when more than one GPU is
    visible, optionally restored from ``./models/<output_folder>``, and then
    trained for ``args.num_epochs`` epochs with per-epoch validation,
    TensorBoard logging (``./logs/<output_folder>``), checkpointing and
    learning-rate scheduling on the validation losses.

    Args:
        args: parsed command-line namespace.

    Raises:
        ValueError: if ``args.recons_loss`` names no known reconstruction loss.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    train_loader, valid_loader, test_loader = train_util.get_dataloaders(args)

    num_channels = 3
    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k,
                               args.enc_type, args.dec_type)
    model.to(args.device)

    discriminators = {}
    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(num_channels, args.img_res, args.input_type)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(num_channels * 2, args.img_res, args.input_type)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                num_channels * 2, args.img_res, args.input_type)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                num_channels * 2, args.img_res, args.input_type)
        else:
            raise ValueError(f"unsupported recons_loss: {args.recons_loss!r}")
        recons_disc.to(args.device)
        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr, amsgrad=True)
        recons_disc_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            recons_disc_opt, "min", patience=args.lr_patience, factor=0.5,
            threshold=args.threshold, threshold_mode="abs", min_lr=1e-7)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    ae_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, "min", patience=args.lr_patience, factor=0.5,
        threshold=args.threshold, threshold_mode="abs", min_lr=1e-7)

    if torch.cuda.device_count() > 1:
        model = train_util.ae_data_parallel(model)
        for disc in discriminators:
            discriminators[disc][0] = torch.nn.DataParallel(discriminators[disc][0])

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    # Fixed test images for TensorBoard and their reconstructions at step 0,
    # logged once after the model has reached its final device/wrapping.
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)
    train_util.log_recons_img_grid(recons_input_img, model, 0, args.device, writer)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer, discriminators)
    else:
        start_epoch = 0

    best_loss = torch.tensor(np.inf)
    for epoch in tqdm(range(start_epoch, args.num_epochs), file=sys.stdout):
        try:
            train(epoch, train_loader, model, optimizer, args, writer, discriminators)
        except RuntimeError as err:
            # Report the failure (presumably e.g. CUDA OOM) on stderr and exit
            # with status 0.
            print("".join(traceback.TracebackException.from_exception(err).format()),
                  file=sys.stderr)
            print("*******", file=sys.stderr)
            print(err, file=sys.stderr)
            print(f"batch_size:{args.batch_size}", file=sys.stderr)
            sys.exit(0)

        val_loss_dict, z = train_util.test(get_losses, model, valid_loader, args, discriminators)
        step = epoch + 1
        train_util.log_recons_img_grid(recons_input_img, model, step, args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, step, args.device, writer)
        train_util.log_losses("val", val_loss_dict, step, writer)
        train_util.log_latent_metrics("val", z, step, writer)
        train_util.save_state(model, optimizer, discriminators, val_loss_dict["recons_loss"],
                              best_loss, args.recons_loss, epoch, save_filename)

        # Plateau schedulers are driven by the validation losses.
        ae_lr_scheduler.step(val_loss_dict["recons_loss"])
        if args.recons_loss != "mse":
            recons_disc_lr_scheduler.step(val_loss_dict["recons_disc_loss"])
def val_test(args):
    """Re-run validation for a (possibly restored) PCIE model and log results.

    Collects the perturbation types enabled on ``args``, builds the PCIE
    model with its interpolation discriminator, optional reconstruction,
    prior and per-perturbation discriminators, optionally restores them from
    ``./models/<output_folder>`` and runs validation passes that log losses
    and image grids to ``./logs/<output_folder>``.

    Args:
        args: parsed command-line namespace; each key of
            ``train_util.perturb_dict`` is read as a flag on it.

    Raises:
        ValueError: if ``args.recons_loss`` names no known reconstruction loss.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)
    d_args = vars(args)

    # Enabled perturbation types are the perturb_dict keys whose flag is set.
    perturb_types = [ptype for ptype in train_util.perturb_dict if d_args[ptype]]
    num_perturb_types = len(perturb_types)

    _, valid_loader, test_loader = train_util.get_dataloaders(args, perturb_types)
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    input_dim = 3
    model = PCIE(args.img_res, input_dim, args.hidden_size, num_perturb_types,
                 args.enc_type, args.dec_type)
    interp_disc = Discriminator(input_dim, args.img_res, args.input_type).to(args.device)
    interp_disc_opt = torch.optim.Adam(interp_disc.parameters(), lr=args.disc_lr, amsgrad=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    discriminators = {"interp_disc": [interp_disc, interp_disc_opt]}

    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res, args.input_type)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res, args.input_type)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type)
        else:
            raise ValueError(f"unsupported recons_loss: {args.recons_loss!r}")
        recons_disc.to(args.device)
        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr, amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    latent_res = args.img_res // args.scale_factor
    if args.prior_loss == "gan":
        prior_disc = Latent2ClassDiscriminator(args.hidden_size, latent_res)
        prior_disc_opt = torch.optim.Adam(prior_disc.parameters(), lr=args.disc_lr, amsgrad=True)
        discriminators["prior_disc"] = [prior_disc, prior_disc_opt]

    print("pertrub gans")
    if args.perturb_feat_gan:
        # One latent discriminator per enabled perturbation type: binary for
        # 2-class perturbations, multi-class otherwise.
        for perturb_type in train_util.perturb_dict:
            if not d_args[perturb_type]:
                continue
            num_classes = train_util.perturb_dict[perturb_type].num_class
            if num_classes == 2:
                print(f"perturb:{d_args[perturb_type]}\ttype: two ")
                pert_disc = Latent2ClassDiscriminator(args.hidden_size, latent_res)
            else:
                print(f"perturb:{d_args[perturb_type]}\ttype: multi ")
                pert_disc = LatentMultiClassDiscriminator(args.hidden_size, latent_res,
                                                          num_classes)
            pert_disc_opt = torch.optim.Adam(pert_disc.parameters(),
                                             lr=args.disc_lr, amsgrad=True)
            discriminators[f"{perturb_type}_disc"] = (pert_disc, pert_disc_opt)
    print("perrturb gans set")

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer, discriminators)
    else:
        start_epoch = 0

    best_loss = torch.tensor(np.inf)
    for epoch in tqdm(range(start_epoch, 4), file=sys.stdout):
        # train_util.test returns (loss_dict, latents); only the losses are used.
        val_loss_dict, _ = train_util.test(get_losses, model, valid_loader, args,
                                           discriminators, True)
        # With initial (not loaded) weights, stop at epoch 1 before logging it.
        if args.weights == "init" and epoch == 1:
            break
        step = epoch + 1
        train_util.log_losses("val", val_loss_dict, step, writer)
        train_util.log_recons_img_grid(recons_input_img, model, step, args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, step, args.device, writer)
        train_util.save_state(model, optimizer, discriminators, val_loss_dict["recons_loss"],
                              best_loss, args.recons_loss, epoch, save_filename)
        print(val_loss_dict)
def main(args):
    """Train a PCIE model with its auxiliary discriminators.

    Builds the PCIE model, an interpolation discriminator and, depending on
    ``args``, a reconstruction discriminator/comparator, a latent prior
    discriminator and one latent discriminator per enabled perturbation type.
    Optionally restores state from ``./models/<output_folder>``, then trains,
    validates, saves image grids and checkpoints once per epoch.

    Args:
        args: parsed command-line namespace; each key of
            ``train_util.perturb_dict`` is read as a flag on it.

    Raises:
        ValueError: if ``args.recons_loss`` names no known reconstruction loss.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)
    d_args = vars(args)

    pert_types = train_util.get_perturb_types(args)
    train_loader, valid_loader, test_loader = train_util.get_dataloaders(args, pert_types)
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    input_dim = 3
    model = PCIE(args.img_res, input_dim, args.hidden_size, len(pert_types),
                 args.enc_type, args.dec_type)
    interp_disc = Discriminator(input_dim, args.img_res, args.input_type).to(args.device)
    interp_disc_opt = torch.optim.Adam(interp_disc.parameters(), lr=args.disc_lr, amsgrad=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    discriminators = {"interp_disc": [interp_disc, interp_disc_opt]}

    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res, args.input_type)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res, args.input_type)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type)
        elif "comp_6" in args.recons_loss:
            if "color" in args.recons_loss:
                recons_disc = FullPermutationColorComparator(
                    input_dim * 2, args.img_res, args.input_type)
            else:
                recons_disc = FullPermutationComparator(
                    input_dim * 2, args.img_res, args.input_type)
        else:
            # Previously this fell through to an UnboundLocalError below.
            raise ValueError(f"unsupported recons_loss: {args.recons_loss!r}")
        recons_disc.to(args.device)
        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr, amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    latent_res = args.img_res // args.scale_factor
    if args.prior_loss == "gan":
        prior_disc = Latent2ClassDiscriminator(args.hidden_size, latent_res)
        prior_disc_opt = torch.optim.Adam(prior_disc.parameters(), lr=args.disc_lr, amsgrad=True)
        discriminators["prior_disc"] = [prior_disc, prior_disc_opt]

    if args.perturb_feat_gan:
        # One latent discriminator per enabled perturbation type: binary for
        # 2-class perturbations, multi-class otherwise.
        for perturb_type in train_util.perturb_dict:
            if not d_args[perturb_type]:
                continue
            num_class = train_util.perturb_dict[perturb_type].num_class
            if num_class == 2:
                pert_disc = Latent2ClassDiscriminator(args.hidden_size, latent_res)
            else:
                pert_disc = LatentMultiClassDiscriminator(args.hidden_size, latent_res, num_class)
            pert_disc_opt = torch.optim.Adam(pert_disc.parameters(),
                                             lr=args.disc_lr, amsgrad=True)
            discriminators[f"{perturb_type}_disc"] = (pert_disc, pert_disc_opt)

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer, discriminators)
    else:
        start_epoch = 0

    best_loss = torch.tensor(np.inf)
    for epoch in tqdm(range(start_epoch, args.num_epochs), file=sys.stdout):
        # for CUDA OOM error, prevents running dependency job on slurm which
        # is meant to run on timeout
        try:
            train(epoch, train_loader, model, optimizer, args, writer, discriminators)
        except RuntimeError as err:
            print("".join(traceback.TracebackException.from_exception(err).format()),
                  file=sys.stderr)
            print("*******", file=sys.stderr)
            print(err, file=sys.stderr)
            sys.exit(0)

        val_loss_dict, _ = ae_test(get_losses, model, valid_loader, args, discriminators)
        step = epoch + 1
        train_util.log_losses("val", val_loss_dict, step, writer)
        train_util.save_recons_img_grid("test", recons_input_img, model, step, args)
        train_util.save_interp_img_grid("test", recons_input_img, model, step, args)
        train_util.save_state(model, optimizer, discriminators, val_loss_dict["recons_loss"],
                              best_loss, args.recons_loss, epoch, save_filename)
def main(args):
    """Save an image grid comparing reconstructions of trained PCIE variants.

    Takes up to ``args.num_img`` images from the first test batch and
    reconstructs them with several trained PCIE checkpoints.  The variants
    differ only in the setting named by ``args.compare``:

    * ``"recons_loss"``  -- mse, comp, comp_2_adv, comp_2_dc, comp_6_adv, comp_6_dc
    * ``"prior_loss"``   -- kl_div, gan
    * ``"perturb_gan"``  -- perturbation-feature GAN on / off

    All other settings come from ``args``.  The inputs and each variant's
    reconstructions are concatenated and written to
    ``./results/cmp_<compare>_<labels>.png`` with ``args.num_img`` images per
    grid row.

    Args:
        args: parsed command-line namespace.

    Raises:
        ValueError: if ``args.compare`` is not one of the modes above.
    """
    pert_types = train_util.get_perturb_types(args)
    _, _, test_loader = train_util.get_dataloaders(args, pert_types)
    recons_input_img = next(iter(test_loader))[0].to(args.device)[:args.num_img]

    input_dim = 3
    model = PCIE(args.img_res, input_dim, args.hidden_size, len(pert_types),
                 args.enc_type, args.dec_type)
    model.to(args.device)

    # Each variant is (recons_loss, perturb_feat_gan, prior_loss).
    if args.compare == "recons_loss":
        labels = ["mse", "comp", "comp_2_adv", "comp_2_dc", "comp_6_adv", "comp_6_dc"]
        variants = [(loss, args.perturb_feat_gan, args.prior_loss) for loss in labels]
    elif args.compare == "prior_loss":
        labels = ["kl_div", "gan"]
        variants = [(args.recons_loss, args.perturb_feat_gan, prior) for prior in labels]
    elif args.compare == "perturb_gan":
        flags = [True, False]
        labels = [str(flag) for flag in flags]
        variants = [(args.recons_loss, flag, args.prior_loss) for flag in flags]
    else:
        # Previously an unknown mode failed later with a NameError on fname.
        raise ValueError(f"unsupported compare mode: {args.compare!r}")

    # NOTE(review): the model is used in whatever train/eval mode load_model
    # leaves it in; consider model.eval() if it contains batch-norm/dropout.
    rows = [recons_input_img]
    for recons_loss, perturb_feat_gan, prior_loss in variants:
        save_dir = os.path.join(
            "./models",
            compute_save_dir(recons_loss, perturb_feat_gan, pert_types, prior_loss,
                             args.enc_type, args.dec_type, args.img_res, args.hidden_size))
        train_util.load_model(save_dir, model)
        # Pure inference: no autograd graph is needed for the saved images.
        with torch.no_grad():
            recons, _ = model(recons_input_img)
        rows.append(recons)
    recons_img = torch.cat(rows, dim=0)

    fname = '-'.join(labels)
    result_dir = "./results"
    os.makedirs(result_dir, exist_ok=True)
    result_path = os.path.join(result_dir, f"cmp_{args.compare}_{fname}.png")
    torchvision.utils.save_image(recons_img, result_path, nrow=args.num_img)
def main(args):
    """Train a VAE with an optional reconstruction discriminator/comparator.

    Builds the VAE and -- unless ``args.recons_loss`` is ``"mse"`` -- a
    reconstruction discriminator or comparator, wraps them in DataParallel
    when more than one GPU is visible, derives the log/model directories
    (defaulting ``args.output_folder`` when it is ``None``), optionally
    restores state and then trains for ``args.num_epochs`` epochs with
    per-epoch validation, image grids, TensorBoard logging and checkpointing.

    Args:
        args: parsed command-line namespace.

    Raises:
        ValueError: if ``args.recons_loss`` names no known reconstruction loss.
    """
    input_dim = 3
    model = VAE(input_dim, args.hidden_size, args.enc_type, args.dec_type)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)

    discriminators = {}
    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res, args.input_type)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res, args.input_type)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type)
        else:
            # Previously this fell through to an UnboundLocalError below.
            raise ValueError(f"unsupported recons_loss: {args.recons_loss!r}")
        recons_disc.to(args.device)
        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr, amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    if torch.cuda.device_count() > 1:
        model = train_util.ae_data_parallel(model)
        for disc in discriminators:
            discriminators[disc][0] = torch.nn.DataParallel(discriminators[disc][0])

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)
    print("model built", file=sys.stderr)

    train_loader, val_loader, test_loader = train_util.get_dataloaders(args)
    print("loaders acquired", file=sys.stderr)

    if args.output_folder is None:
        model_name = f"vae_{args.recons_loss}"
        args.output_folder = os.path.join(
            model_name, args.dataset,
            f"depth_{args.enc_type}_{args.dec_type}_hs_{args.img_res}_{args.hidden_size}")

    log_save_path = os.path.join("./logs", args.output_folder)
    model_save_path = os.path.join("./models", args.output_folder)
    os.makedirs(log_save_path, exist_ok=True)
    print(f"log:{log_save_path}", file=sys.stderr)
    sys.stderr.flush()
    os.makedirs(model_save_path, exist_ok=True)

    writer = SummaryWriter(log_save_path)
    print(f"train loader length:{len(train_loader)}", file=sys.stderr)

    best_loss = torch.tensor(np.inf)
    if args.weights == "load":
        start_epoch = train_util.load_state(model_save_path, model, opt, discriminators)
    else:
        start_epoch = 0

    recons_input_img = train_util.log_input_img_grid(test_loader, writer)
    train_util.log_recons_img_grid(recons_input_img, model, 0, args.device, writer)

    for epoch in range(start_epoch, args.num_epochs):
        try:
            train(model, train_loader, opt, epoch, writer, args, discriminators)
        except RuntimeError as err:
            # Report the failure on stderr and exit with status 0.
            print("".join(traceback.TracebackException.from_exception(err).format()),
                  file=sys.stderr)
            print("*******", file=sys.stderr)
            print(err, file=sys.stderr)
            sys.exit(0)

        val_loss_dict, z = train_util.test(get_losses, model, val_loader, args, discriminators)
        print(f"epoch loss:{val_loss_dict['recons_loss'].item()}")

        step = epoch + 1
        train_util.save_recons_img_grid("test", recons_input_img, model, step, args)
        train_util.save_interp_img_grid("test", recons_input_img, model, step, args)
        train_util.log_losses("val", val_loss_dict, step, writer)
        train_util.log_latent_metrics("val", z, step, writer)
        train_util.save_state(model, opt, discriminators, val_loss_dict["recons_loss"],
                              best_loss, args.recons_loss, epoch, model_save_path)