示例#1
0
def main(args):
    """Train an ACAI autoencoder with an interpolation discriminator.

    Builds the data loaders, the ACAI model, the interpolation discriminator
    and, unless ``args.recons_loss == "mse"``, a reconstruction
    discriminator/comparator. Optionally resumes from a checkpoint, then runs
    the train / validate / log / checkpoint loop.

    Args:
        args: Parsed command-line namespace. Attributes read here include
            ``output_folder``, ``img_res``, ``hidden_size``, ``enc_type``,
            ``dec_type``, ``device``, ``input_type``, ``disc_lr``, ``lr``,
            ``recons_loss``, ``weights`` and ``num_epochs``.

    Raises:
        ValueError: If ``args.recons_loss`` is not a supported loss name.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    train_loader, val_loader, test_loader = train_util.get_dataloaders(args)
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    input_dim = 3
    model = ACAI(args.img_res, input_dim, args.hidden_size, args.enc_type,
                 args.dec_type).to(args.device)
    disc = Discriminator(input_dim, args.img_res,
                         args.input_type).to(args.device)
    disc_opt = torch.optim.Adam(disc.parameters(),
                                lr=args.disc_lr,
                                amsgrad=True)

    opt = torch.optim.Adam(model.parameters(), lr=args.lr)

    discriminators = {"interp_disc": [disc, disc_opt]}
    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        else:
            # Previously this fell through to an unbound `recons_disc`.
            raise ValueError(
                f"Unsupported recons_loss: {args.recons_loss!r}")

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)

        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    # Generate the samples first once
    train_util.save_recons_img_grid("test", recons_input_img, model, 0, args)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, opt,
                                            discriminators)
    else:
        start_epoch = 0

    # NOTE(review): best_loss is never updated in this loop; presumably
    # save_state compares against it — confirm whether its result should be
    # assigned back.
    best_loss = torch.tensor(np.inf)
    # Resume from the epoch returned by load_state rather than restarting at 0.
    for epoch in range(start_epoch, args.num_epochs):
        print("Epoch {}:".format(epoch))
        train(model, opt, train_loader, args, discriminators, writer)

        val_loss_dict, z = train_util.test(get_losses, model, val_loader, args,
                                           discriminators)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)

        train_util.save_recons_img_grid("val", recons_input_img, model,
                                        epoch + 1, args)
        train_util.save_interp_img_grid("val", recons_input_img, model,
                                        epoch + 1, args)

        train_util.save_state(model, opt, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, save_filename)
示例#2
0
def model_summary(model_type,
                  img_res,
                  hidden_size,
                  enc_type,
                  dec_type,
                  loss,
                  batch_size,
                  device=torch.device("cuda:1"),
                  verbose=True):
    """Measure the memory footprint of an autoencoder and its discriminators.

    Builds the requested model and, where applicable, the interpolation and
    reconstruction discriminators. Runs ``torchsummary.summary_string`` on
    each component and parses the "Params size (MB)" and
    "Forward/backward pass size (MB)" lines out of each report.

    Args:
        model_type: One of ``"acai"``, ``"vqvae"`` or ``"vae"``.
        img_res: Image side length; the encoder and discriminators are
            summarized on a ``(3, img_res, img_res)`` input.
        hidden_size: Latent channel count; the decoder is summarized on a
            ``(hidden_size, img_res // 4, img_res // 4)`` input.
        enc_type: Encoder type/depth identifier passed to the model.
        dec_type: Decoder type/depth identifier passed to the model.
        loss: Reconstruction loss name. ``"gan"``, ``"comp"``, or any name
            containing ``"comp_2"``/``"comp_6"`` adds a reconstruction
            discriminator; any other value (e.g. ``"mse"``) adds none.
        batch_size: Batch size used for the activation-size estimate.
        device: Device the modules are built on and summarized with.
        verbose: If true, print the summaries and the parsed sizes.

    Returns:
        ``(encoder, decoder, discriminators)``. ``encoder`` and ``decoder``
        are ``{"params": MB, "forward": MB}`` dicts. ``discriminators`` maps
        ``"interp_disc"`` and/or ``"recons_disc"`` to ``(params, forward)``
        tuples; the reconstruction discriminator's forward size is doubled.

    Raises:
        ValueError: If ``model_type`` is not recognized.
    """
    params_pattern = re.compile(r"Params size \(MB\):(.*)\n")
    forward_pattern = re.compile(r"Forward/backward pass size \(MB\):(.*)\n")
    input_dim = 3
    enc_input_size = (input_dim, img_res, img_res)
    dec_input_size = (hidden_size, img_res // 4, img_res // 4)

    def _summary(module, input_size):
        """Return the torchsummary text report for ``module``."""
        report, _ = torchsummary.summary_string(module,
                                                input_size,
                                                device=device,
                                                batch_size=batch_size)
        return report

    def _sizes(report):
        """Parse ``(params MB, forward/backward MB)`` from a report."""
        return (float(params_pattern.search(report).group(1)),
                float(forward_pattern.search(report).group(1)))

    if verbose:
        print(f"model:{model_type}")
        print(f"depth:{enc_type}_{dec_type}")

    if model_type == "acai":
        model = ACAI(img_res, input_dim, hidden_size, enc_type,
                     dec_type).to(device)
    elif model_type == "vqvae":
        model = VectorQuantizedVAE(input_dim,
                                   hidden_size,
                                   enc_type=enc_type,
                                   dec_type=dec_type).to(device)
    elif model_type == "vae":
        model = VAE(input_dim,
                    hidden_size,
                    enc_type=enc_type,
                    dec_type=dec_type).to(device)
    else:
        raise ValueError(f"Unknown model_type: {model_type!r}")

    encoder_summary = _summary(model.encoder, enc_input_size)
    decoder_summary = _summary(model.decoder, dec_input_size)
    if verbose:
        print(encoder_summary)
        print(decoder_summary)

    discriminators = {}

    if model_type == "acai":
        interp_disc = Discriminator(input_dim, img_res, "image").to(device)
        discriminators["interp_disc"] = _sizes(
            _summary(interp_disc, enc_input_size))

    if loss == "gan":
        recons_disc = Discriminator(input_dim, img_res, "image")
    elif loss == "comp":
        recons_disc = AnchorComparator(input_dim * 2, img_res, "image")
    elif "comp_2" in loss:
        recons_disc = ClubbedPermutationComparator(input_dim * 2, img_res,
                                                   "image")
    elif "comp_6" in loss:
        recons_disc = FullPermutationComparator(input_dim * 2, img_res,
                                                "image")
    else:
        recons_disc = None

    # Un-doubled (params, forward) of the reconstruction discriminator, kept
    # separately so the verbose report shows the single-pass figure.
    recons_sizes = None
    if recons_disc is not None:
        # NOTE(review): comparators are built for input_dim * 2 channels but
        # are summarized on the 3-channel enc_input_size (unchanged from the
        # original) — confirm this matches their forward().
        recons_sizes = _sizes(_summary(recons_disc.to(device),
                                       enc_input_size))
        recons_params, recons_forward = recons_sizes
        # Forward size is doubled — presumably because the discriminator sees
        # both real and reconstructed inputs; TODO confirm.
        discriminators["recons_disc"] = (recons_params, 2 * recons_forward)

    encoder_param_size, encoder_forward_size = _sizes(encoder_summary)
    decoder_param_size, decoder_forward_size = _sizes(decoder_summary)

    if verbose:
        # Report each discriminator from its own measurement; previously the
        # interpolation line reused whichever discriminator was summarized last.
        if "interp_disc" in discriminators:
            disc_param_size, disc_forward_size = discriminators["interp_disc"]
            print(
                f"discriminator:\n\tparams:{disc_param_size}\n\tforward:{disc_forward_size}"
            )

        if recons_sizes is not None:
            disc_param_size, disc_forward_size = recons_sizes
            print(
                f"reconstruction discriminator:\n\tparams:{disc_param_size}\n\tforward:{disc_forward_size}"
            )

        print(
            f"encoder:\n\tparams:{encoder_param_size}\n\tforward:{encoder_forward_size}"
        )
        print(
            f"decoder:\n\tparams:{decoder_param_size}\n\tforward:{decoder_forward_size}"
        )

    encoder = {"params": encoder_param_size, "forward": encoder_forward_size}
    decoder = {"params": decoder_param_size, "forward": decoder_forward_size}

    return encoder, decoder, discriminators
示例#3
0
def main(args):
    """Train a VAE, optionally with an adversarial/comparator reconstruction loss.

    Builds loaders, the VAE and (unless ``args.recons_loss == "mse"``) a
    reconstruction discriminator. Wraps the modules in DataParallel when more
    than one CUDA device is visible and creates the log/model directories.
    Optionally resumes from a checkpoint, then trains and validates for the
    remaining epochs, saving reconstruction/interpolation grids each epoch.

    Args:
        args: Parsed command-line namespace (reads ``hidden_size``,
            ``enc_type``, ``dec_type``, ``recons_loss``, ``img_res``,
            ``input_type``, ``device``, ``disc_lr``, ``output_folder``,
            ``dataset``, ``weights``, ``num_epochs``). ``args.output_folder``
            is filled in when it is None.

    Raises:
        ValueError: If ``args.recons_loss`` is not a supported loss name.
    """
    train_loader, val_loader, test_loader = train_util.get_dataloaders(args)
    input_dim = 3
    model = VAE(input_dim, args.hidden_size, args.enc_type, args.dec_type)
    # NOTE(review): LR is not defined in this view — presumably a module-level
    # constant. The sibling VAE trainer uses args.lr here; confirm intent.
    opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True)

    discriminators = {}

    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res, args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(input_dim * 2, args.img_res, args.input_type).to(args.device)
        else:
            raise ValueError(f"Unsupported recons_loss: {args.recons_loss!r}")

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(), lr=args.disc_lr, amsgrad=True)

        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    if torch.cuda.device_count() > 1:
        model = train_util.ae_data_parallel(model)
        for disc in discriminators:
            discriminators[disc][0] = torch.nn.DataParallel(discriminators[disc][0])

    model.to(args.device)

    model_name = f"vae_{args.recons_loss}"
    if args.output_folder is None:
        args.output_folder = os.path.join(model_name, args.dataset, f"depth_{args.enc_type}_{args.dec_type}_hs_{args.img_res}_{args.hidden_size}")

    log_save_path = os.path.join("./logs", args.output_folder)
    model_save_path = os.path.join("./models", args.output_folder)

    if not os.path.exists(log_save_path):
        os.makedirs(log_save_path)
        print(f"log:{log_save_path}", file=sys.stderr)
        sys.stderr.flush()
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)

    writer = SummaryWriter(log_save_path)

    print(f"train loader length:{len(train_loader)}", file=sys.stderr)
    best_loss = torch.tensor(np.inf)

    if args.weights == "load":
        start_epoch = train_util.load_state(model_save_path, model, opt, discriminators)
    else:
        start_epoch = 0

    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    train_util.save_recons_img_grid("val", recons_input_img, model, 0, args)

    # Previously range(1, num_epochs): ignored the resumed start epoch and ran
    # one epoch fewer than requested.
    for epoch in range(start_epoch, args.num_epochs):
        print("Epoch {}:".format(epoch))
        train(model, opt, train_loader)
        curr_loss = val(model, val_loader)

        print(f"epoch val loss:{curr_loss}", file=sys.stderr)
        sys.stderr.flush()
        train_util.save_recons_img_grid("val", recons_input_img, model, epoch + 1, args)
        train_util.save_interp_img_grid("val", recons_input_img, model, epoch + 1, args)
示例#4
0
def main(args):
    """Train a VQ-VAE with an optional adversarial/comparator reconstruction loss.

    Builds loaders, the VectorQuantizedVAE and, unless
    ``args.recons_loss == "mse"``, a reconstruction discriminator with its own
    ReduceLROnPlateau scheduler. Wraps the modules in DataParallel when more
    than one CUDA device is visible and optionally resumes from a checkpoint.
    Then runs the train / validate / log / checkpoint loop, stepping both LR
    schedulers on validation losses.

    Args:
        args: Parsed command-line namespace (reads ``output_folder``,
            ``hidden_size``, ``k``, ``enc_type``, ``dec_type``, ``device``,
            ``img_res``, ``input_type``, ``recons_loss``, ``disc_lr``, ``lr``,
            ``lr_patience``, ``threshold``, ``weights``, ``num_epochs``,
            ``batch_size``).

    Raises:
        ValueError: If ``args.recons_loss`` is not a supported loss name.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    train_loader, valid_loader, test_loader = train_util.get_dataloaders(args)

    num_channels = 3
    model = VectorQuantizedVAE(num_channels, args.hidden_size, args.k,
                               args.enc_type, args.dec_type)
    model.to(args.device)

    discriminators = {}

    input_dim = 3
    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        else:
            raise ValueError(
                f"Unsupported recons_loss: {args.recons_loss!r}")

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)
        recons_disc_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            recons_disc_opt,
            "min",
            patience=args.lr_patience,
            factor=0.5,
            threshold=args.threshold,
            threshold_mode="abs",
            min_lr=1e-7)

        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    ae_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        "min",
        patience=args.lr_patience,
        factor=0.5,
        threshold=args.threshold,
        threshold_mode="abs",
        min_lr=1e-7)

    if torch.cuda.device_count() > 1:
        model = train_util.ae_data_parallel(model)
        for disc in discriminators:
            discriminators[disc][0] = torch.nn.DataParallel(
                discriminators[disc][0])

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    # Fixed images for TensorBoard plus the untrained model's reconstructions
    # (step 0). Previously these were logged twice.
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)
    train_util.log_recons_img_grid(recons_input_img, model, 0, args.device,
                                   writer)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer,
                                            discriminators)
    else:
        start_epoch = 0

    best_loss = torch.tensor(np.inf)
    for epoch in tqdm(range(start_epoch, args.num_epochs), file=sys.stdout):

        try:
            train(epoch, train_loader, model, optimizer, args, writer,
                  discriminators)
        except RuntimeError as err:
            # Report the failure (e.g. CUDA OOM) with the batch size in use,
            # then stop; exit status 0 is kept as in the original.
            print("".join(
                traceback.TracebackException.from_exception(err).format()),
                  file=sys.stderr)
            print("*******", file=sys.stderr)
            print(err, file=sys.stderr)
            print(f"batch_size:{args.batch_size}", file=sys.stderr)
            sys.exit(0)

        val_loss_dict, z = train_util.test(get_losses, model, valid_loader,
                                           args, discriminators)

        train_util.log_recons_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        train_util.save_state(model, optimizer, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, save_filename)

        # TODO: early stopping on val recons_loss (args.stop_patience) is
        # currently disabled.

        ae_lr_scheduler.step(val_loss_dict["recons_loss"])
        if args.recons_loss != "mse":
            recons_disc_lr_scheduler.step(val_loss_dict["recons_disc_loss"])
示例#5
0
文件: ae.py 项目: sidwa/ae_thesis
def val_test(args):
    """Evaluate a PCIE model on the validation set, then log and checkpoint it.

    Builds the same model and discriminator set as the PCIE trainer:
    interpolation, optional reconstruction, optional prior, and optional
    per-perturbation discriminators. Optionally loads a checkpoint and runs
    ``train_util.test`` repeatedly up to a fixed epoch bound. Then logs the
    last validation losses and image grids and saves the state.

    Args:
        args: Parsed command-line namespace. Boolean attributes named after
            the keys of ``train_util.perturb_dict`` select perturbation types.

    Raises:
        ValueError: If ``args.recons_loss`` is not a supported loss name, or
            if the loaded start epoch leaves no evaluation round to run.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    d_args = vars(args)

    # Perturbation types enabled on the command line, in perturb_dict order.
    perturb_types = [
        perturb_type for perturb_type in train_util.perturb_dict
        if d_args[perturb_type]
    ]
    num_perturb_types = len(perturb_types)

    train_loader, valid_loader, test_loader = train_util.get_dataloaders(
        args, perturb_types)

    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    input_dim = 3
    model = PCIE(args.img_res, input_dim, args.hidden_size, num_perturb_types,
                 args.enc_type, args.dec_type)
    interp_disc = Discriminator(input_dim, args.img_res,
                                args.input_type).to(args.device)
    interp_disc_opt = torch.optim.Adam(interp_disc.parameters(),
                                       lr=args.disc_lr,
                                       amsgrad=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    discriminators = {"interp_disc": [interp_disc, interp_disc_opt]}

    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        else:
            raise ValueError(
                f"Unsupported recons_loss: {args.recons_loss!r}")

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)

        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    if args.prior_loss == "gan":
        prior_disc = Latent2ClassDiscriminator(
            args.hidden_size, args.img_res // args.scale_factor)
        prior_disc_opt = torch.optim.Adam(prior_disc.parameters(),
                                          lr=args.disc_lr,
                                          amsgrad=True)

        discriminators["prior_disc"] = [prior_disc, prior_disc_opt]

    print("pertrub gans")

    if args.perturb_feat_gan:
        for perturb_type in train_util.perturb_dict:
            if not d_args[perturb_type]:
                continue
            num_classes = train_util.perturb_dict[perturb_type].num_class
            # (A leftover pdb.set_trace() breakpoint was removed here.)
            if num_classes == 2:
                print(f"perturb:{d_args[perturb_type]}\ttype: two ")
                pert_disc = Latent2ClassDiscriminator(
                    args.hidden_size, args.img_res // args.scale_factor)
            else:
                print(f"perturb:{d_args[perturb_type]}\ttype: multi ")
                pert_disc = LatentMultiClassDiscriminator(
                    args.hidden_size, args.img_res // args.scale_factor,
                    num_classes)
            pert_disc_opt = torch.optim.Adam(pert_disc.parameters(),
                                             lr=args.disc_lr,
                                             amsgrad=True)

            # Stored as a list like every other entry, so callers that replace
            # entry [0] (e.g. for DataParallel) work uniformly.
            discriminators[f"{perturb_type}_disc"] = [pert_disc,
                                                      pert_disc_opt]

    print("perrturb gans set")

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer,
                                            discriminators)
    else:
        start_epoch = 0

    # Evaluation rounds run for epochs in [start_epoch, end_epoch).
    end_epoch = 4
    if start_epoch >= end_epoch:
        # Without this check the loop below would not run and the logging
        # afterwards would fail on unbound val_loss_dict / epoch.
        raise ValueError(
            f"start_epoch {start_epoch} leaves no evaluation round "
            f"(must be < {end_epoch})")

    best_loss = torch.tensor(np.inf)
    for epoch in tqdm(range(start_epoch, end_epoch), file=sys.stdout):
        # NOTE(review): other call sites unpack train_util.test into
        # (loss_dict, z); here the extra True argument presumably makes it
        # return only the loss dict — confirm.
        val_loss_dict = train_util.test(get_losses, model, valid_loader, args,
                                        discriminators, True)
        if args.weights == "init" and epoch == 1:
            epoch += 1
            break

    train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
    train_util.log_recons_img_grid(recons_input_img, model, epoch, args.device,
                                   writer)
    train_util.log_interp_img_grid(recons_input_img, model, epoch + 1,
                                   args.device, writer)

    train_util.save_state(model, optimizer, discriminators,
                          val_loss_dict["recons_loss"], best_loss,
                          args.recons_loss, epoch, save_filename)

    print(val_loss_dict)
示例#6
0
文件: ae.py 项目: sidwa/ae_thesis
def main(args):
    """Train a PCIE autoencoder with its adversarial discriminators.

    Builds loaders, the PCIE model, the interpolation discriminator, and
    optionally reconstruction, prior and per-perturbation discriminators.
    Optionally resumes from a checkpoint, then runs the train / validate /
    save-images / checkpoint loop.

    Args:
        args: Parsed command-line namespace. Boolean attributes named after
            the keys of ``train_util.perturb_dict`` select perturbation types.

    Raises:
        ValueError: If ``args.recons_loss`` is not a supported loss name.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    d_args = vars(args)
    pert_types = train_util.get_perturb_types(args)
    train_loader, valid_loader, test_loader = train_util.get_dataloaders(
        args, pert_types)
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    num_perturb_types = len(pert_types)

    input_dim = 3
    model = PCIE(args.img_res, input_dim, args.hidden_size, num_perturb_types,
                 args.enc_type, args.dec_type)
    interp_disc = Discriminator(input_dim, args.img_res,
                                args.input_type).to(args.device)
    interp_disc_opt = torch.optim.Adam(interp_disc.parameters(),
                                       lr=args.disc_lr,
                                       amsgrad=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    discriminators = {"interp_disc": [interp_disc, interp_disc_opt]}

    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            if "color" in args.recons_loss:
                recons_disc = FullPermutationColorComparator(
                    input_dim * 2, args.img_res,
                    args.input_type).to(args.device)
            else:
                recons_disc = FullPermutationComparator(
                    input_dim * 2, args.img_res,
                    args.input_type).to(args.device)
        else:
            raise ValueError(
                f"Unsupported recons_loss: {args.recons_loss!r}")

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)
        # NOTE: no LR scheduler is stepped for the reconstruction
        # discriminator in this trainer (the one created here was unused).

        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    if args.prior_loss == "gan":
        prior_disc = Latent2ClassDiscriminator(
            args.hidden_size, args.img_res // args.scale_factor)
        prior_disc_opt = torch.optim.Adam(prior_disc.parameters(),
                                          lr=args.disc_lr,
                                          amsgrad=True)

        discriminators["prior_disc"] = [prior_disc, prior_disc_opt]

    if args.perturb_feat_gan:
        for perturb_type in train_util.perturb_dict:
            if not d_args[perturb_type]:
                continue
            num_class = train_util.perturb_dict[perturb_type].num_class
            if num_class == 2:
                pert_disc = Latent2ClassDiscriminator(
                    args.hidden_size, args.img_res // args.scale_factor)
            else:
                pert_disc = LatentMultiClassDiscriminator(
                    args.hidden_size, args.img_res // args.scale_factor,
                    num_class)
            pert_disc_opt = torch.optim.Adam(pert_disc.parameters(),
                                             lr=args.disc_lr,
                                             amsgrad=True)

            # List (not tuple) like every other entry in discriminators.
            discriminators[f"{perturb_type}_disc"] = [pert_disc,
                                                      pert_disc_opt]

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer,
                                            discriminators)
    else:
        start_epoch = 0

    best_loss = torch.tensor(np.inf)
    for epoch in tqdm(range(start_epoch, args.num_epochs), file=sys.stdout):

        # For CUDA OOM errors: exit with status 0 so the slurm dependency job
        # that is meant to run only on timeout is not triggered.
        try:
            train(epoch, train_loader, model, optimizer, args, writer,
                  discriminators)
        except RuntimeError as err:
            print("".join(
                traceback.TracebackException.from_exception(err).format()),
                  file=sys.stderr)
            print("*******", file=sys.stderr)
            print(err, file=sys.stderr)
            sys.exit(0)

        print("out of train")
        val_loss_dict, _ = ae_test(get_losses, model, valid_loader, args,
                                   discriminators)
        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.save_recons_img_grid("test", recons_input_img, model,
                                        epoch + 1, args)
        train_util.save_interp_img_grid("test", recons_input_img, model,
                                        epoch + 1, args)
        train_util.save_state(model, optimizer, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, save_filename)
示例#7
0
文件: vae.py 项目: sidwa/ae_thesis
def main(args):
    """Train a VAE with an optional adversarial/comparator reconstruction loss.

    Builds the VAE and (unless ``args.recons_loss == "mse"``) a
    reconstruction discriminator. Wraps them in DataParallel when more than
    one CUDA device is visible, then loads data and creates the log/model
    directories. Optionally resumes from a checkpoint, then runs the train /
    validate / save-images / log / checkpoint loop.

    Args:
        args: Parsed command-line namespace (reads ``hidden_size``,
            ``enc_type``, ``dec_type``, ``lr``, ``recons_loss``, ``img_res``,
            ``input_type``, ``device``, ``disc_lr``, ``output_folder``,
            ``dataset``, ``weights``, ``num_epochs``). ``args.output_folder``
            is filled in when it is None.

    Raises:
        ValueError: If ``args.recons_loss`` is not a supported loss name.
    """
    input_dim = 3
    model = VAE(input_dim, args.hidden_size, args.enc_type, args.dec_type)

    opt = torch.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)

    discriminators = {}

    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        else:
            raise ValueError(
                f"Unsupported recons_loss: {args.recons_loss!r}")

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)

        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    if torch.cuda.device_count() > 1:
        model = train_util.ae_data_parallel(model)
        for disc in discriminators:
            discriminators[disc][0] = torch.nn.DataParallel(
                discriminators[disc][0])

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    print("model built", file=sys.stderr)
    train_loader, val_loader, test_loader = train_util.get_dataloaders(args)
    print("loaders acquired", file=sys.stderr)

    model_name = f"vae_{args.recons_loss}"
    if args.output_folder is None:
        args.output_folder = os.path.join(
            model_name, args.dataset,
            f"depth_{args.enc_type}_{args.dec_type}_hs_{args.img_res}_{args.hidden_size}"
        )

    log_save_path = os.path.join("./logs", args.output_folder)
    model_save_path = os.path.join("./models", args.output_folder)

    if not os.path.exists(log_save_path):
        os.makedirs(log_save_path)
        print(f"log:{log_save_path}", file=sys.stderr)
        sys.stderr.flush()
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)

    writer = SummaryWriter(log_save_path)

    print(f"train loader length:{len(train_loader)}", file=sys.stderr)
    best_loss = torch.tensor(np.inf)

    if args.weights == "load":
        start_epoch = train_util.load_state(model_save_path, model, opt,
                                            discriminators)
    else:
        start_epoch = 0

    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    train_util.log_recons_img_grid(recons_input_img, model, 0, args.device,
                                   writer)

    for epoch in range(start_epoch, args.num_epochs):

        try:
            train(model, train_loader, opt, epoch, writer, args,
                  discriminators)
        except RuntimeError as err:
            # Report the failure (e.g. CUDA OOM) and stop; exit status 0 is
            # kept as in the original.
            print("".join(
                traceback.TracebackException.from_exception(err).format()),
                  file=sys.stderr)
            print("*******", file=sys.stderr)
            print(err, file=sys.stderr)
            sys.exit(0)

        val_loss_dict, z = train_util.test(get_losses, model, val_loader, args,
                                           discriminators)
        print(f"epoch loss:{val_loss_dict['recons_loss'].item()}")

        train_util.save_recons_img_grid("test", recons_input_img, model,
                                        epoch + 1, args)
        train_util.save_interp_img_grid("test", recons_input_img, model,
                                        epoch + 1, args)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        train_util.save_state(model, opt, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, model_save_path)