Esempio n. 1
0
def val_test(args, num_epochs=4):
    """Run validation-only passes of a VectorQuantizedVAE and log the results.

    Builds the model (plus a reconstruction ``Discriminator`` when
    ``args.recons_loss == "gan"``) and optionally restores a checkpoint from
    ``./models/<args.output_folder>`` when ``args.weights == "load"``.

    For every epoch in ``range(start_epoch, num_epochs)`` it does the
    following:

    * evaluates on the validation loader via ``train_util.test``;
    * logs reconstruction/interpolation grids, losses and latent metrics to
      TensorBoard under ``./logs/<args.output_folder>``;
    * calls ``train_util.save_state``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``output_folder``, ``hidden_size``, ``k``, ``enc_type``,
            ``dec_type``, ``lr``, ``recons_loss``, ``img_res``,
            ``input_type``, ``disc_lr``, ``device`` and ``weights``. It is
            also forwarded to ``train_util`` helpers.
        num_epochs: Exclusive upper bound of the epoch range. Defaults to 4,
            the value that used to be hard-coded.

    Returns:
        None. Prints the last validation loss dict, or a notice when no
        epoch was run (e.g. a loaded checkpoint is already at or past
        ``num_epochs``).
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    _, valid_loader, test_loader = train_util.get_dataloaders(args)
    # Fixed images from the test loader, reused for every logged grid.
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    input_dim = 3
    model = VectorQuantizedVAE(input_dim, args.hidden_size, args.k,
                               args.enc_type, args.dec_type)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Maps a name to [discriminator module, its optimizer].
    # NOTE(review): only the "gan" discriminator is built here, unlike
    # main(). Loading a checkpoint trained with a comparator loss may
    # therefore not restore it; confirm against train_util.load_state.
    discriminators = {}

    if args.recons_loss == "gan":
        recons_disc = Discriminator(input_dim, args.img_res,
                                    args.input_type).to(args.device)
        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    model.to(args.device)
    for disc, _disc_opt in discriminators.values():
        disc.to(args.device)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer,
                                            discriminators)
    else:
        start_epoch = 0

    # NOTE(review): best_loss is never updated in this function, so
    # save_state always receives inf; confirm whether save_state tracks
    # the best loss itself.
    best_loss = torch.tensor(np.inf)
    # Stays None if the epoch range below is empty.
    val_loss_dict = None
    for epoch in tqdm(range(start_epoch, num_epochs), file=sys.stdout):
        # The trailing True is a train_util.test flag that main() does not
        # pass; its meaning is not visible here (TODO confirm).
        val_loss_dict, z = train_util.test(get_losses, model, valid_loader,
                                           args, discriminators, True)

        train_util.log_recons_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        train_util.save_state(model, optimizer, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, save_filename)

    if val_loss_dict is None:
        print(f"val_test: no validation epoch was run "
              f"(start_epoch={start_epoch}, num_epochs={num_epochs})")
    else:
        print(val_loss_dict)
Esempio n. 2
0
def _build_recons_disc(args, input_dim):
    """Create the reconstruction discriminator/comparator named by ``args.recons_loss``.

    The returned module has already been moved to ``args.device``.

    Raises:
        ValueError: If ``args.recons_loss`` matches none of the known
            variants.
    """
    # Order matters: the exact match "comp" is tested before the substring
    # tests for the "comp_2" / "comp_6" variants.
    if args.recons_loss == "gan":
        disc = Discriminator(input_dim, args.img_res, args.input_type)
    elif args.recons_loss == "comp":
        disc = AnchorComparator(input_dim * 2, args.img_res, args.input_type)
    elif "comp_2" in args.recons_loss:
        disc = ClubbedPermutationComparator(input_dim * 2, args.img_res,
                                            args.input_type)
    elif "comp_6" in args.recons_loss:
        disc = FullPermutationComparator(input_dim * 2, args.img_res,
                                         args.input_type)
    else:
        raise ValueError(
            f"unsupported recons_loss for a discriminator: {args.recons_loss!r}")
    return disc.to(args.device)


def _make_plateau_scheduler(optimizer, args):
    """Build the ReduceLROnPlateau scheduler shared by the AE and discriminator optimizers."""
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        "min",
        patience=args.lr_patience,
        factor=0.5,
        threshold=args.threshold,
        threshold_mode="abs",
        min_lr=1e-7)


def main(args):
    """Train a VectorQuantizedVAE, validating, logging and checkpointing every epoch.

    Setup:

    * builds the VAE and, unless ``args.recons_loss == "mse"``, a
      reconstruction discriminator/comparator with its own Adam optimizer;
    * attaches a ReduceLROnPlateau scheduler to each optimizer;
    * wraps modules in DataParallel when several GPUs are available and
      ``args.device`` is a CUDA device;
    * optionally restores a checkpoint when ``args.weights == "load"``.

    Each epoch then does the following:

    * trains;
    * validates;
    * logs image grids, losses and latent metrics to TensorBoard under
      ``./logs/<args.output_folder>``;
    * saves state to ``./models/<args.output_folder>``;
    * steps the schedulers on the validation losses.

    Args:
        args: Parsed command-line namespace. It is forwarded to ``train``
            and the ``train_util`` helpers.

    Raises:
        ValueError: If ``args.recons_loss`` is neither "mse" nor a known
            discriminator variant.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    train_loader, valid_loader, test_loader = train_util.get_dataloaders(args)

    # Input channel count shared by the VAE and the discriminators.
    input_dim = 3
    model = VectorQuantizedVAE(input_dim, args.hidden_size, args.k,
                               args.enc_type, args.dec_type)
    model.to(args.device)

    # Maps a name to [discriminator module, its optimizer].
    discriminators = {}
    recons_disc_lr_scheduler = None
    if args.recons_loss != "mse":
        recons_disc = _build_recons_disc(args, input_dim)
        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)
        recons_disc_lr_scheduler = _make_plateau_scheduler(recons_disc_opt,
                                                           args)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    ae_lr_scheduler = _make_plateau_scheduler(optimizer, args)

    # DataParallel requires its module on a CUDA device, so only wrap when
    # training on CUDA. Otherwise a CPU run on a multi-GPU host would fail.
    if torch.cuda.device_count() > 1 and str(args.device).startswith("cuda"):
        model = train_util.ae_data_parallel(model)
        for entry in discriminators.values():
            entry[0] = torch.nn.DataParallel(entry[0])

    model.to(args.device)
    for disc, _disc_opt in discriminators.values():
        disc.to(args.device)

    # Fixed images for TensorBoard. Log the untrained reconstructions once,
    # at step 0.
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)
    train_util.log_recons_img_grid(recons_input_img, model, 0, args.device,
                                   writer)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer,
                                            discriminators)
    else:
        start_epoch = 0

    # NOTE(review): best_loss is never updated here, so save_state always
    # receives inf; confirm whether save_state tracks the best loss itself.
    # TODO: early stopping on val recons_loss (args.stop_patience) is
    # currently disabled.
    best_loss = torch.tensor(np.inf)
    for epoch in tqdm(range(start_epoch, args.num_epochs), file=sys.stdout):

        try:
            train(epoch, train_loader, model, optimizer, args, writer,
                  discriminators)
        except RuntimeError as err:
            traceback.print_exc(file=sys.stderr)
            print("*******")
            print(err, file=sys.stderr)
            # Batch size is reported because the usual cause is likely an
            # out-of-memory error.
            print(f"batch_size:{args.batch_size}", file=sys.stderr)
            # NOTE(review): exits with status 0 even though training failed.
            # This looks deliberate (e.g. for sweep scripts); confirm.
            sys.exit(0)

        val_loss_dict, z = train_util.test(get_losses, model, valid_loader,
                                           args, discriminators)

        train_util.log_recons_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        train_util.save_state(model, optimizer, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, save_filename)

        ae_lr_scheduler.step(val_loss_dict["recons_loss"])
        if recons_disc_lr_scheduler is not None:
            recons_disc_lr_scheduler.step(val_loss_dict["recons_disc_loss"])