Exemple #1
0
def val_test(args, num_epochs=4):
    """Repeatedly evaluate a VQ-VAE on the validation split and log results.

    Builds the model (plus a GAN reconstruction discriminator when
    ``args.recons_loss == "gan"``), optionally restores state from
    ``./models/<output_folder>``, then for every epoch in
    ``[start_epoch, num_epochs)`` runs ``train_util.test`` on the validation
    loader, logs reconstructions, interpolations, losses and latent metrics to
    TensorBoard under ``./logs/<output_folder>`` and calls
    ``train_util.save_state``. No training step is performed.

    Args:
        args: parsed command-line namespace (reads ``output_folder``,
            ``hidden_size``, ``k``, ``enc_type``, ``dec_type``, ``lr``,
            ``recons_loss``, ``img_res``, ``input_type``, ``disc_lr``,
            ``device`` and ``weights``).
        num_epochs: exclusive upper bound of the epoch loop; defaults to 4,
            the value that used to be hard-coded.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    # Only the validation and test splits are needed here.
    _, valid_loader, test_loader = train_util.get_dataloaders(args)
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    input_dim = 3
    model = VectorQuantizedVAE(input_dim, args.hidden_size, args.k,
                               args.enc_type, args.dec_type)
    # Move the model before building its optimizer (torch.optim recommendation).
    model.to(args.device)
    # The optimizer is never stepped here; it exists so load_state/save_state
    # can restore and persist its state alongside the model.
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    discriminators = {}
    if args.recons_loss == "gan":
        recons_disc = Discriminator(input_dim, args.img_res,
                                    args.input_type).to(args.device)
        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    for disc_module, _disc_opt in discriminators.values():
        disc_module.to(args.device)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer,
                                            discriminators)
    else:
        start_epoch = 0

    # NOTE(review): best_loss is never updated here, so save_state always
    # receives +inf; confirm whether save_state tracks the best loss itself.
    best_loss = torch.tensor(np.inf)
    # Stays None if the loop does not run (start_epoch >= num_epochs), which
    # previously crashed the final print with a NameError.
    val_loss_dict = None
    for epoch in tqdm(range(start_epoch, num_epochs), file=sys.stdout):
        val_loss_dict, z = train_util.test(get_losses, model, valid_loader,
                                           args, discriminators, True)

        train_util.log_recons_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        train_util.save_state(model, optimizer, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, save_filename)

    if val_loss_dict is not None:
        print(val_loss_dict)
Exemple #2
0
def train(dataset_dir, output_dir, num_epochs=10):
    """Train a discriminator and a generator adversarially on LFW images.

    Args:
        dataset_dir: directory handed to ``LFWLoader``.
        output_dir: directory that receives the per-epoch pickled modules
            ``G_<epoch>.pt`` and ``D_<epoch>.pt``; created if missing.
        num_epochs: number of passes over the loader; defaults to 10, the
            value that used to be hard-coded.
    """
    if torch.cuda.is_available():
        # New float tensors (including the labels below) are created on the GPU.
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    discriminator = Discriminator()
    discriminator.to(device)
    discriminator.train()
    generator = Generator()
    generator.to(device)
    generator.train()

    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=0.0001)
    optimizer_generator = torch.optim.Adam(generator.parameters(), lr=0.0001)
    loss_func = nn.BCELoss()

    loader = LFWLoader(dataset_dir)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    # The targets never change, so build them once rather than every step.
    # NOTE(review): one-element targets assume the discriminator returns a
    # single score per call -- confirm against Discriminator's output shape.
    real_label = torch.tensor([1.0])
    fake_label = torch.tensor([0.0])

    for epoch in range(num_epochs):
        for image in tqdm(loader):
            # Discriminator on a real image (pixel values scaled by 1/255).
            outputs = discriminator(torch.tensor(image, dtype=torch.float32) / 255)
            loss = loss_func(outputs, real_label)
            optimizer_discriminator.zero_grad()
            loss.backward()
            optimizer_discriminator.step()

            # Discriminator on a generated image; detached so G gets no gradient.
            outputs = discriminator(generator(generate_random_seed(100)).detach())
            loss = loss_func(outputs, fake_label)
            optimizer_discriminator.zero_grad()
            loss.backward()
            optimizer_discriminator.step()

            # Generator: push D towards labelling a fresh sample as real.
            outputs = discriminator(generator(generate_random_seed(100)))
            loss = loss_func(outputs, real_label)
            optimizer_generator.zero_grad()
            loss.backward()
            optimizer_generator.step()

        # Whole modules (not just state_dicts) are pickled each epoch.
        torch.save(generator, os.path.join(output_dir, 'G_%d.pt' % epoch))
        torch.save(discriminator, os.path.join(output_dir, 'D_%d.pt' % epoch))
Exemple #3
0
def main(args):
    """Train a VAE and save validation reconstructions/interpolations per epoch.

    Unless ``args.recons_loss == "mse"``, a reconstruction discriminator or
    comparator selected by ``args.recons_loss`` is built. Logs go under
    ``./logs/<output_folder>`` and models under ``./models/<output_folder>``;
    ``output_folder`` is derived from the model settings when it is None.

    Raises:
        ValueError: if ``args.recons_loss`` is neither ``"mse"`` nor one of
            the supported discriminator/comparator losses.
    """
    train_loader, val_loader, test_loader = train_util.get_dataloaders(args)
    input_dim = 3
    model = VAE(input_dim, args.hidden_size, args.enc_type, args.dec_type)
    opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True)

    discriminators = {}

    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        else:
            # Fail clearly instead of an UnboundLocalError on recons_disc.
            raise ValueError(
                f"unsupported recons_loss: {args.recons_loss!r}")

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr, amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    if torch.cuda.device_count() > 1:
        model = train_util.ae_data_parallel(model)
        for disc in discriminators:
            discriminators[disc][0] = torch.nn.DataParallel(
                discriminators[disc][0])

    model.to(args.device)

    model_name = f"vae_{args.recons_loss}"
    if args.output_folder is None:
        args.output_folder = os.path.join(
            model_name, args.dataset,
            f"depth_{args.enc_type}_{args.dec_type}_hs_{args.img_res}_{args.hidden_size}")

    log_save_path = os.path.join("./logs", args.output_folder)
    model_save_path = os.path.join("./models", args.output_folder)

    if not os.path.exists(log_save_path):
        # exist_ok tolerates another process creating the directory meanwhile.
        os.makedirs(log_save_path, exist_ok=True)
        print(f"log:{log_save_path}", file=sys.stderr)
        sys.stderr.flush()
    os.makedirs(model_save_path, exist_ok=True)

    writer = SummaryWriter(log_save_path)

    print(f"train loader length:{len(train_loader)}", file=sys.stderr)

    if args.weights == "load":
        # NOTE(review): the epoch returned by load_state is discarded and the
        # loop below always restarts at epoch 1 -- confirm for resumed runs.
        train_util.load_state(model_save_path, model, opt, discriminators)

    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    train_util.save_recons_img_grid("val", recons_input_img, model, 0, args)

    # NOTE(review): `discriminators` is not passed to train(), so any
    # discriminator built above is never updated in this loop -- confirm.
    for epoch in range(1, args.num_epochs):
        print("Epoch {}:".format(epoch))
        train(model, opt, train_loader)
        curr_loss = val(model, val_loader)

        print(f"epoch val loss:{curr_loss}", file=sys.stderr)
        sys.stderr.flush()
        train_util.save_recons_img_grid("val", recons_input_img, model,
                                        epoch + 1, args)
        train_util.save_interp_img_grid("val", recons_input_img, model,
                                        epoch + 1, args)
Exemple #4
0
def main(args):
    """Train a VQ-VAE, optionally against a learned reconstruction critic.

    Unless ``args.recons_loss == "mse"``, a discriminator/comparator chosen
    by ``args.recons_loss`` is trained alongside the autoencoder. Each epoch
    trains, evaluates on the validation split, logs images, losses and latent
    metrics to TensorBoard under ``./logs/<output_folder>``, saves state to
    ``./models/<output_folder>`` and steps the ReduceLROnPlateau schedulers.

    Raises:
        ValueError: if ``args.recons_loss`` names no supported loss.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    train_loader, valid_loader, test_loader = train_util.get_dataloaders(args)

    # Single channel-count constant shared by the VQ-VAE and discriminators.
    input_dim = 3
    model = VectorQuantizedVAE(input_dim, args.hidden_size, args.k,
                               args.enc_type, args.dec_type)
    model.to(args.device)

    discriminators = {}
    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        else:
            # Fail clearly instead of an UnboundLocalError on recons_disc.
            raise ValueError(
                f"unsupported recons_loss: {args.recons_loss!r}")

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)
        recons_disc_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            recons_disc_opt,
            "min",
            patience=args.lr_patience,
            factor=0.5,
            threshold=args.threshold,
            threshold_mode="abs",
            min_lr=1e-7)

        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    ae_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        "min",
        patience=args.lr_patience,
        factor=0.5,
        threshold=args.threshold,
        threshold_mode="abs",
        min_lr=1e-7)

    if torch.cuda.device_count() > 1:
        model = train_util.ae_data_parallel(model)
        for disc in discriminators:
            discriminators[disc][0] = torch.nn.DataParallel(
                discriminators[disc][0])

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    # Fixed images for TensorBoard, logged exactly once now that the models
    # are in their final device / DataParallel form (logging them before the
    # discriminators were built as well duplicated the epoch-0 entries).
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)
    train_util.log_recons_img_grid(recons_input_img, model, 0, args.device,
                                   writer)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer,
                                            discriminators)
    else:
        start_epoch = 0

    # NOTE(review): best_loss is never updated in this loop, so save_state
    # always receives +inf; confirm whether save_state tracks the best itself.
    best_loss = torch.tensor(np.inf)
    for epoch in tqdm(range(start_epoch, args.num_epochs), file=sys.stdout):

        try:
            train(epoch, train_loader, model, optimizer, args, writer,
                  discriminators)
        except RuntimeError as err:
            print("".join(
                traceback.TracebackException.from_exception(err).format()),
                  file=sys.stderr)
            print("*******", file=sys.stderr)
            print(err, file=sys.stderr)
            print(f"batch_size:{args.batch_size}", file=sys.stderr)
            # sys.exit instead of the site-provided `exit` builtin, which is
            # missing under `python -S` and in some embedded interpreters.
            # NOTE(review): status 0 reports success although training
            # failed; kept in case a sweep script relies on it.
            sys.exit(0)

        val_loss_dict, z = train_util.test(get_losses, model, valid_loader,
                                           args, discriminators)

        train_util.log_recons_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)
        train_util.log_interp_img_grid(recons_input_img, model, epoch + 1,
                                       args.device, writer)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        train_util.save_state(model, optimizer, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, save_filename)

        ae_lr_scheduler.step(val_loss_dict["recons_loss"])
        if args.recons_loss != "mse":
            recons_disc_lr_scheduler.step(val_loss_dict["recons_disc_loss"])
    def test_gan(self):
        """Check that a D step updates only D and a G step updates only G."""

        def _params_identical(net, snapshot):
            # True when every parameter of `net` equals its copy in `snapshot`.
            return all(
                torch.equal(current, saved)
                for current, saved in zip(net.parameters(),
                                          snapshot.parameters()))

        lambda_gp = 0.1
        Tensor = torch.FloatTensor
        generator = Generator(2, 2)
        discriminator = Discriminator(2)
        adam_kwargs = dict(lr=0.0001, betas=(0.5, 0.999))
        optimizer_G = torch.optim.Adam(generator.parameters(), **adam_kwargs)
        optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                       **adam_kwargs)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer_D.zero_grad()

        # A random 64x2 "real" batch, latent codes and the 0/1 targets.
        real_data = Tensor(np.random.normal(0, 1, (64, 2)))
        z = Tensor(np.random.normal(0, 1, (64, 2)))
        batch_size = real_data.size(0)
        real_target = Tensor(batch_size, 1).fill_(1.0)
        fake_target = Tensor(batch_size, 1).fill_(0.0)

        # ---- discriminator update ----
        g_snapshot = deepcopy(generator)
        d_snapshot = deepcopy(discriminator)

        fake_data = generator(z)
        real_validity = discriminator(real_data)
        fake_validity = discriminator(fake_data)
        gradient_penalty = compute_gradient_penalty(discriminator,
                                                    real_data.data,
                                                    fake_data.data, Tensor)
        d_loss = (criterion(real_validity, real_target)
                  + criterion(fake_validity, fake_target)
                  + lambda_gp * gradient_penalty)
        d_loss.backward()
        optimizer_D.step()

        self.assertTrue(_params_identical(generator, g_snapshot))
        self.assertFalse(_params_identical(discriminator, d_snapshot))
        optimizer_G.zero_grad()

        # ---- generator update ----
        g_snapshot = deepcopy(generator)
        d_snapshot = deepcopy(discriminator)

        fake_validity = discriminator(generator(z))
        g_loss = criterion(fake_validity, real_target)
        g_loss.backward()
        optimizer_G.step()

        self.assertFalse(_params_identical(generator, g_snapshot))
        self.assertTrue(_params_identical(discriminator, d_snapshot))
Exemple #6
0
def train(rank: int, cfg: DictConfig):
    """Train the generator and discriminator, optionally across several GPUs.

    Args:
        rank: process index; also selects ``cuda:<rank>`` when CUDA is
            available. Only rank 0 prints progress, writes checkpoints and
            TensorBoard summaries, and runs validation.
        cfg: OmegaConf config with ``train``, ``model``, ``opt``, ``dataset``
            and ``data`` sections.
    """
    print(OmegaConf.to_yaml(cfg))

    if cfg.train.n_gpu > 1:
        init_process_group(backend=cfg.train.dist_config['dist_backend'],
                           init_method=cfg.train.dist_config['dist_url'],
                           world_size=cfg.train.dist_config['world_size'] *
                           cfg.train.n_gpu,
                           rank=rank)

    device = torch.device(
        'cuda:{:d}'.format(rank) if torch.cuda.is_available() else 'cpu')

    generator = Generator(sum(cfg.model.feature_dims), *cfg.model.cond_dims,
                          **cfg.model.generator).to(device)
    discriminator = Discriminator(**cfg.model.discriminator).to(device)

    if rank == 0:
        print(generator)
        os.makedirs(cfg.train.ckpt_dir, exist_ok=True)
        print("checkpoints directory : ", cfg.train.ckpt_dir)

    # Default to "no checkpoint": on non-zero ranks the directory may not
    # exist yet, which otherwise leaves cp_g / cp_do unbound.
    cp_g = cp_do = None
    if os.path.isdir(cfg.train.ckpt_dir):
        cp_g = scan_checkpoint(cfg.train.ckpt_dir, 'g_')
        # Must match the 'do_' prefix used when saving checkpoints below.
        cp_do = scan_checkpoint(cfg.train.ckpt_dir, 'do_')

    steps = 1
    if cp_g is None or cp_do is None:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        discriminator.load_state_dict(state_dict_do['discriminator'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if cfg.train.n_gpu > 1:
        generator = DistributedDataParallel(generator,
                                            device_ids=[rank]).to(device)
        discriminator = DistributedDataParallel(discriminator,
                                                device_ids=[rank]).to(device)

    optim_g = RAdam(generator.parameters(), cfg.opt.lr, betas=cfg.opt.betas)
    optim_d = RAdam(discriminator.parameters(),
                    cfg.opt.lr,
                    betas=cfg.opt.betas)

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=cfg.opt.lr_decay, last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=cfg.opt.lr_decay, last_epoch=last_epoch)

    train_filelist = load_dataset_filelist(cfg.dataset.train_list)
    trainset = FeatureDataset(cfg.dataset, train_filelist, cfg.data)
    train_sampler = DistributedSampler(
        trainset) if cfg.train.n_gpu > 1 else None
    # DataLoader rejects shuffle=True together with a sampler; with DDP the
    # DistributedSampler does the shuffling (re-seeded via set_epoch).
    train_loader = DataLoader(trainset,
                              batch_size=cfg.train.batch_size,
                              num_workers=cfg.train.num_workers,
                              shuffle=train_sampler is None,
                              sampler=train_sampler,
                              pin_memory=True,
                              drop_last=True)

    if rank == 0:
        val_filelist = load_dataset_filelist(cfg.dataset.test_list)
        valset = FeatureDataset(cfg.dataset,
                                val_filelist,
                                cfg.data,
                                segmented=False)
        # Validation walks the whole validation set in order; it must not
        # reuse train_sampler, which indexes the *training* set.
        val_loader = DataLoader(valset,
                                batch_size=1,
                                num_workers=cfg.train.num_workers,
                                shuffle=False,
                                pin_memory=True)

        sw = SummaryWriter(os.path.join(cfg.train.ckpt_dir, 'logs'))

    generator.train()
    discriminator.train()
    for epoch in range(max(0, last_epoch), cfg.train.epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch + 1))

        if cfg.train.n_gpu > 1:
            train_sampler.set_epoch(epoch)

        for y, x_noised_features, x_noised_cond in train_loader:
            if rank == 0:
                start_b = time.time()

            y = y.to(device, non_blocking=True)
            x_noised_features = x_noised_features.transpose(1, 2).to(
                device, non_blocking=True)
            x_noised_cond = x_noised_cond.to(device, non_blocking=True)
            z1 = torch.randn(cfg.train.batch_size,
                             cfg.model.cond_dims[1],
                             device=device)
            z2 = torch.randn(cfg.train.batch_size,
                             cfg.model.cond_dims[1],
                             device=device)

            y_hat1 = generator(x_noised_features, x_noised_cond, z=z1)
            y_hat2 = generator(x_noised_features, x_noised_cond, z=z2)

            # Discriminator step: y_hat1 is detached so this update does not
            # backpropagate into the generator.
            real_scores = discriminator(y)
            fake_scores = discriminator(y_hat1.detach())
            d_loss = discriminator_loss(real_scores, fake_scores)

            optim_d.zero_grad()
            d_loss.backward()
            optim_d.step()

            # Generator step. The non-detached y_hat1 is re-scored with the
            # updated discriminator: scores of the detached copy carry no
            # gradient to the generator, and their graph references weights
            # that optim_d.step() has just modified in place.
            g_stft_loss = criterion(y, y_hat1) + criterion(
                y, y_hat2) - criterion(y_hat1, y_hat2)
            g_adv_loss = adversarial_loss(discriminator(y_hat1))
            g_loss = g_adv_loss + g_stft_loss

            optim_g.zero_grad()
            g_loss.backward()
            optim_g.step()

            if rank == 0:
                # STDOUT logging
                if steps % cfg.train.stdout_interval == 0:
                    with torch.no_grad():
                        print(
                            'Steps : {:d}, Gen Loss Total : {:4.3f}, STFT Error : {:4.3f}, s/b : {:4.3f}'
                            .format(steps, g_loss, g_stft_loss,
                                    time.time() - start_b))

                # checkpointing
                if steps % cfg.train.checkpoint_interval == 0:
                    ckpt_dir = "{}/g_{:08d}".format(cfg.train.ckpt_dir, steps)
                    save_checkpoint(
                        ckpt_dir, {
                            'generator':
                            (generator.module if cfg.train.n_gpu > 1 else
                             generator).state_dict()
                        })
                    ckpt_dir = "{}/do_{:08d}".format(cfg.train.ckpt_dir, steps)
                    save_checkpoint(
                        ckpt_dir, {
                            'discriminator':
                            (discriminator.module if cfg.train.n_gpu > 1 else
                             discriminator).state_dict(),
                            'optim_g':
                            optim_g.state_dict(),
                            'optim_d':
                            optim_d.state_dict(),
                            'steps':
                            steps,
                            'epoch':
                            epoch
                        })

                # Tensorboard summary logging
                if steps % cfg.train.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", g_loss, steps)
                    sw.add_scalar("training/gen_stft_error", g_stft_loss,
                                  steps)

                # Validation
                if steps % cfg.train.validation_interval == 0:
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    num_val = 0
                    with torch.no_grad():
                        for j, (y_val, x_val_features,
                                x_val_cond) in enumerate(val_loader):
                            # The target must live on the same device as y_hat.
                            y_val = y_val.to(device)
                            y_hat = generator(
                                x_val_features.transpose(1, 2).to(device),
                                x_val_cond.to(device))
                            val_err_tot += criterion(y_val, y_hat).item()
                            num_val += 1

                            if j <= 4:
                                sw.add_audio('generated/y_hat_{}'.format(j),
                                             y_hat[0], steps,
                                             cfg.data.sample_rate)
                                sw.add_audio('gt/y_{}'.format(j), y_val[0],
                                             steps, cfg.data.sample_rate)

                    # An empty validation set would otherwise divide by an
                    # unbound loop index.
                    if num_val:
                        sw.add_scalar("validation/stft_error",
                                      val_err_tot / num_val, steps)

                    generator.train()

            steps += 1

        scheduler_g.step()
        scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))
Exemple #7
0
# NOTE(review): presumably loss_fn is the adversarial (BCE) loss and loss_rec
# the reconstruction (MSE) loss; their use lies outside this excerpt.
loss_fn = nn.BCELoss()
loss_rec = nn.MSELoss()

controller = Segmentation(args.state_dim, args.policy_hidden_1,
                          args.option_num)
encoder = MLPEncoder(args.option_dim, 2 * args.option_dim, 1, 0, True)
decoder = MLPDecoder(args.option_dim, 1, 2 * args.option_dim,
                     2 * args.option_dim, 2 * args.option_dim, 0, False)
actor = Actor_with_option(args.state_dim, args.option_dim, args.action_dim)
discriminator = Discriminator(args.state_dim, args.action_dim)

# Move every trainable module to the GPU *before* building the optimizers,
# as the torch.optim documentation recommends. `policy` is not defined in
# this script section; `actor` appears to be the intended module.
# NOTE(review): confirm no separate `policy` object exists elsewhere.
# The discriminator is moved as well so all modules share one device.
if args.CUDA:
    for module in (actor, encoder, decoder, controller, discriminator):
        module.cuda()

# One optimizer jointly updates encoder, decoder, actor and controller.
optimizer = optim.Adam(list(encoder.parameters()) +
                       list(decoder.parameters()) + list(actor.parameters()) +
                       list(controller.parameters()),
                       lr=args.lr,
                       betas=(0.5, 0.999))
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                           lr=1e-4,
                                           betas=(0.5, 0.999))
# Separate optimizer over the controller alone, with a higher learning rate.
warm_start_optimizer = optim.Adam(list(controller.parameters()),
                                  lr=1e-2,
                                  betas=(0.5, 0.999))

meta_train_tasks = [0, 1, 2, 4]
meta_test_tasks = [5, 6]
Exemple #8
0
# Script fragment: set up a 64x64 CelebA GAN run and begin the training loop.
# NOTE(review): the loop body is truncated in this excerpt -- Dx and Gz are
# computed but not yet used here. DOWNLOAD, BATCH_SIZE, Z_DIM and EPOCHS are
# presumably module-level constants defined elsewhere; confirm.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Resize + center-crop to 64x64, convert to a tensor, then rescale.
t = transforms.Compose([transforms.Resize(64),
                        transforms.CenterCrop(64),
                        transforms.ToTensor(),
                        lambda x: x * 2 - 1])  # map ToTensor's [0, 1] range to [-1, 1]

# NOTE(review): hard-coded Windows dataset root; consider making it configurable.
dataset = torchvision.datasets.CelebA('G:/Datasets', download=DOWNLOAD, transform=t)
# drop_last=True keeps every batch at exactly BATCH_SIZE, matching the
# fixed-size noise batch sampled in the loop below.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

G = Generator(Z_DIM).to(device)
D = Discriminator().to(device)
d_loss = D_loss().to(device)
g_loss = G_loss().to(device)

# The discriminator uses a 10x smaller learning rate than the generator.
optim_D = torch.optim.Adam(D.parameters(), lr=1e-5, betas=(0.5, 0.999))
optim_G = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.999))

d_count = 0  # number of discriminator iterations performed so far
for e in range(EPOCHS):
    for x, _ in dataloader:
        # Re-enable D's gradients (presumably switched off during a generator
        # step later in the truncated loop body -- confirm).
        for p in D.parameters():
            p.requires_grad = True

        x = x.to(device)
        D.zero_grad()
        d_count += 1

        Dx = D(x)  # discriminator output on the real batch
        z = torch.randn((BATCH_SIZE, Z_DIM)).to(device)
        Gz = G(z)  # generated batch from fresh noise
Exemple #9
0
class Trainer(object):
    """WGAN-GP trainer for a ``Generator`` / ``Discriminator`` pair.

    Hyper-parameters are copied from ``config``. :meth:`train` runs the
    optimisation loop, logging to a visdom server and periodically saving
    sample images and checkpoints under ``<path>/<version>``.
    """

    def __init__(self, data_loader, config):
        """Store the data loader and hyper-parameters, then build the models.

        Args:
            data_loader: iterable yielding ``(images, labels)`` batches; the
                labels are ignored.
            config: object exposing the options read below as attributes.
        """
        self.dataloader = data_loader
        self.imsize = config.imsize
        self.batch_size = config.batch_size
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.g_dim = config.g_dim
        self.d_dim = config.d_dim
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.lambda_gp = config.lambda_gp

        self.z_dim = config.z_dim
        self.num_iters = config.total_step
        self.num_iters_decay = config.iter_start_decay
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.lr_update_step = config.lr_iter_decay
        self.lr_decay = config.lr_decay
        self.model_save_step = config.model_save_step
        self.resume_iters = config.resume_iter
        self.version = config.version

        # First GPU when available, otherwise CPU. All tensors below use
        # self.device instead of hard-coded .cuda() calls.
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

        self.sample_path = os.path.join(config.sample_path, self.version)
        self.model_save_dir = os.path.join(config.model_save_path,
                                           self.version)
        self.build_model()

    def build_model(self):
        """Create G, D and their Adam optimizers, then move both nets to the device."""
        self.G = Generator(image_size=self.imsize,
                           z_dim=self.z_dim,
                           conv_dim=self.g_dim)
        self.D = Discriminator(conv_dim=self.d_dim)

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr,
                                            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr,
                                            [self.beta1, self.beta2])
        self.print_network(self.G, 'G')
        self.print_network(self.D, 'D')

        self.G.to(self.device)
        self.D.to(self.device)

    def reset_grad(self):
        """Zero the gradients held by both optimizers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def update_lr(self, g_lr, d_lr):
        """Set the learning rate of every param group of G's and D's optimizers."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def print_network(self, model, name):
        """Print ``model``, its ``name`` and its total parameter count."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, resume_iters):
        """Load G and D weights saved at iteration ``resume_iters``."""
        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir,
                              '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir,
                              '{}-D.ckpt'.format(resume_iters))
        # map_location keeps tensors on the storage they are loaded into (CPU).
        self.G.load_state_dict(
            torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(
            torch.load(D_path, map_location=lambda storage, loc: storage))

    def gradient_penalty(self, y, x):
        """Return ``mean((||dy/dx||_2 - 1) ** 2)`` over the batch.

        Args:
            y: critic outputs computed from ``x``.
            x: inputs with ``requires_grad`` set; flattened per sample.
        """
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm - 1)**2)

    def train(self):
        """Run the WGAN-GP loop from ``resume_iters`` (if set) to ``num_iters``."""
        loss = {}
        vis = visdom.Visdom()

        # save_image / torch.save fail if their target directories are missing.
        os.makedirs(self.sample_path, exist_ok=True)
        os.makedirs(self.model_save_dir, exist_ok=True)

        data_iter = iter(self.dataloader)
        # Current learning rates; decayed multiplicatively in place below.
        g_lr = self.g_lr
        d_lr = self.d_lr
        fixed_z = torch.randn(self.batch_size, self.z_dim).to(self.device)

        start_iters = 0
        if self.resume_iters:
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)

        print('start training')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):
            # Restart the loader when exhausted; other errors now propagate
            # instead of being masked by a bare except.
            try:
                x_mb, _ = next(data_iter)
            except StopIteration:
                data_iter = iter(self.dataloader)
                x_mb, _ = next(data_iter)
            x_mb = x_mb.to(self.device)
            z = torch.randn(x_mb.size(0), self.z_dim).to(self.device)

            # ---- train the discriminator ----
            # Detached: the D update needs no graph through G (G's gradients
            # were discarded by reset_grad anyway), and the tensor is later
            # handed to visdom, which needs a grad-free tensor.
            x_fake = self.G(z).detach()
            d_real = self.D(x_mb)
            d_fake = self.D(x_fake)
            d_loss_real = -torch.mean(d_real)
            d_loss_fake = torch.mean(d_fake)
            # Per-sample mixing weight, broadcast over (C, H, W).
            alpha = torch.rand(x_mb.size(0), 1, 1, 1).to(self.device)
            # interpolate between real data and fake data
            x_hat = (alpha * x_mb.data +
                     (1 - alpha) * x_fake.data).requires_grad_(True)
            out_src = self.D(x_hat)

            d_loss_gp = self.gradient_penalty(out_src, x_hat)
            d_loss = d_loss_real + d_loss_fake + self.lambda_gp * d_loss_gp

            d_loss.backward()
            self.d_optimizer.step()
            self.reset_grad()

            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_gp'] = d_loss_gp.item()

            # ---- train the generator ----
            d_fake = self.D(self.G(z))
            g_loss = -torch.mean(d_fake)

            g_loss.backward()
            self.g_optimizer.step()
            self.reset_grad()

            loss['G/loss'] = g_loss.item()
            if (i + 1) % self.log_step == 0:
                # visualize real and fake imgs
                vis.images((x_fake + 1) / 2, win='fake_imgs')
                vis.images((x_mb + 1) / 2, win='real_imgs')
                # Whole seconds give a stable H:MM:SS string without slicing.
                et = str(datetime.timedelta(
                    seconds=int(time.time() - start_time)))
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                opts = dict(title='Losses',
                            width=13,
                            height=10,
                            legend=list(loss.keys()))
                vis.line(Y=[list(loss.values())],
                         X=[np.ones(len(loss)) * (i + 1)],
                         win='Losses',
                         update='append',
                         opts=opts)
                print(log)

            # Compound the decay on the current rates. Multiplying the base
            # rates each time would reapply one and the same decay forever.
            if (i + 1) % self.lr_update_step == 0 and (
                    i + 1) > self.num_iters_decay:
                g_lr *= self.lr_decay
                d_lr *= self.lr_decay
                self.update_lr(g_lr, d_lr)
                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(
                    g_lr, d_lr))

            # Sample images (no autograd graph needed)
            if (i + 1) % self.sample_step == 0:
                with torch.no_grad():
                    fake_images = self.G(fixed_z)
                save_image(
                    denorm(fake_images.data),
                    os.path.join(self.sample_path,
                                 '{}_fake.png'.format(i + 1)))

            if (i + 1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir,
                                      '{}-G.ckpt'.format(i + 1))
                D_path = os.path.join(self.model_save_dir,
                                      '{}-D.ckpt'.format(i + 1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                print('Saved model checkpoints into {}...'.format(
                    self.model_save_dir))
Exemple #10
0
class Solver(object):
    """Train a text-conditioned GAN built from ``Generator``/``Discriminator``.

    ``G`` is called as ``G(text_embedding, noise)`` and ``D`` as
    ``D(image, text_embedding)``.  Every batch yielded by ``dataloader`` is
    indexed with the keys ``'right_embd'``, ``'wrong_embd'`` and
    ``'real_img'``.
    """

    def __init__(self, config, dataloader):
        """Read hyper-parameters from ``config`` and build the networks.

        Args:
            config: options object; every attribute read below must exist.
            dataloader: re-iterable source of batch dicts supporting ``len()``.
        """
        self.dataloader = dataloader
        self.data_size = config.data_size
        self.loss_type = config.loss_type
        self.G_lr = config.G_lr
        self.D_lr = config.D_lr
        self.beta1 = config.momentum  # passed to Adam as beta1
        self.batch_size = config.batch_size
        self.max_epoch = config.max_epoch
        self.z_dim = config.z_dim
        self.lr_update_step = config.lr_update_step
        self.lr_decay_after = config.lr_decay_after
        self.lr_decay = config.lr_decay
        # path
        self.sample_path = os.path.join(config.main_path, 'samples')
        self.ckpt_path = os.path.join(config.main_path, 'checkpoints')
        # misc
        self.log_step = config.log_step
        self.eval_step = config.eval_step
        self.save_step = config.save_step

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.build_model()

    def build_model(self):
        """Create G, D, their Adam optimizers and the loss; move nets to device."""
        self.G = Generator()
        self.D = Discriminator()

        self.G_optim = optim.Adam(self.G.parameters(), self.G_lr,
                                  (self.beta1, 0.999))
        self.D_optim = optim.Adam(self.D.parameters(), self.D_lr,
                                  (self.beta1, 0.999))

        # Only 'BCEwL' is implemented.  'WGAN' / 'WGAN+' (and unknown values)
        # leave criterion as None so train() fails with a clear error instead
        # of an AttributeError in the middle of the first step.
        self.criterion = None
        if self.loss_type == 'BCEwL':
            self.criterion = nn.BCEWithLogitsLoss()

        # Filled in by train(): first batch + noise reused for sample grids.
        self.fixed_sample = None
        self.fixed_noise = None

        print('Change CPU mode to GPU mode...')
        self.G.to(self.device)
        self.D.to(self.device)
        print('Creating models are success...')

    def restore_model(self, resume_iters):
        """Load G/D weights saved by train() at iteration ``resume_iters``."""
        print('Load the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.ckpt_path, '{}-G.ckpt'.format(resume_iters))
        D_path = os.path.join(self.ckpt_path, '{}-D.ckpt'.format(resume_iters))
        # map_location lets GPU-saved checkpoints load on a CPU-only host.
        self.G.load_state_dict(torch.load(G_path, map_location=self.device))
        self.D.load_state_dict(torch.load(D_path, map_location=self.device))

    def _next_batch(self, data_iter):
        """Return ``(batch, iterator)``, restarting the loader when exhausted."""
        try:
            return next(data_iter), data_iter
        except StopIteration:
            data_iter = iter(self.dataloader)
            return next(data_iter), data_iter

    def _save_samples(self, step):
        """Save a grid alternating 8 real and 8 generated images per row pair.

        Uses ``fixed_sample`` / ``fixed_noise`` so grids are comparable across
        steps.  Only whole groups of 8 are drawn; nothing is saved when the
        fixed batch holds fewer than 8 samples.
        """
        with torch.no_grad():
            fake_img = self.G(self.fixed_sample['right_embd'].to(self.device),
                              self.fixed_noise).cpu()
            real_img = self.fixed_sample['real_img'].cpu()
            num_rows = min(self.batch_size, real_img.size(0),
                           fake_img.size(0)) // 8
            img_list = []
            for row in range(num_rows):
                start = row * 8
                img_list.extend(real_img[start:start + 8])
                img_list.extend(fake_img[start:start + 8])
            if not img_list:
                return
            sample_name = os.path.join(self.sample_path,
                                       '{}iter.jpg'.format(step))
            save_image(make_grid(img_list), sample_name)
        print('Save generated sample results {}iter.jpg into {}...'.format(
            step, self.sample_path))

    def train(self):
        """Alternate D and G updates for ``max_epoch * len(dataloader)`` steps.

        D learns to score {real image, right text} as real, and both
        {real image, wrong text} and {fake image, right text} as fake.  G learns
        to make D score {fake image, right text} as real.  Logging, sample grids,
        checkpoints and learning-rate decay fire every ``log_step`` /
        ``eval_step`` / ``save_step`` / ``lr_update_step`` iterations.

        Raises:
            NotImplementedError: if ``loss_type`` has no criterion.
            ValueError: if the dataloader is empty.
        """
        if getattr(self, 'criterion', None) is None:
            raise NotImplementedError(
                'loss_type {!r} is not implemented'.format(
                    getattr(self, 'loss_type', None)))
        num_batches = len(self.dataloader)
        if num_batches == 0:
            raise ValueError('dataloader is empty')
        iters = self.max_epoch * num_batches

        data_iter = iter(self.dataloader)
        self.fixed_sample = next(data_iter)
        # Sized from the fixed batch: it can be smaller than batch_size.
        self.fixed_noise = torch.randn(self.fixed_sample['right_embd'].size(0),
                                       self.z_dim).to(self.device)
        num_data = 0
        start_time = time.time()
        print('Start training...')
        for i in range(iters):
            step = i + 1
            # Cycle through the loader, restarting it when exhausted (this
            # also works for a single-batch loader).
            sample, data_iter = self._next_batch(data_iter)
            # Load data.
            right_embd = sample['right_embd'].to(self.device)
            wrong_embd = sample['wrong_embd'].to(self.device)
            cur_batch = right_embd.size(0)
            z_noise = torch.randn(cur_batch, self.z_dim).to(self.device)
            real_img = sample['real_img'].to(self.device)
            fake_img = self.G(right_embd, z_noise)
            num_data += cur_batch
            T = torch.ones([cur_batch, 1, 1, 1]).to(self.device)
            F = torch.zeros([cur_batch, 1, 1, 1]).to(self.device)

            ## Train Discriminator.
            sr = self.D(real_img, right_embd)  # {real image, right text}
            rr_loss = self.criterion(sr, T)
            sw = self.D(real_img, wrong_embd)  # {real image, wrong text}
            rw_loss = self.criterion(sw, F)
            # detach(): the D update must not push gradients into G.
            sf = self.D(fake_img.detach(), right_embd)  # {fake, right text}
            fr_loss = self.criterion(sf, F)
            d_loss = rr_loss + rw_loss + fr_loss
            self.D_optim.zero_grad()
            d_loss.backward()
            self.D_optim.step()
            loss = {
                'D/rr_loss': rr_loss.item(),
                'D/rw_loss': rw_loss.item(),
                'D/fr_loss': fr_loss.item(),
                'D/d_loss': d_loss.item(),
            }

            ## Train Generator.
            sf = self.D(fake_img, right_embd)
            g_loss = self.criterion(sf, T)
            self.G_optim.zero_grad()
            g_loss.backward()
            self.G_optim.step()
            loss['G/g_loss'] = g_loss.item()

            ## Print training information.
            if step % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                logs = "Elapsed [{}], Iter [{}/{}], Epoch [{}/{}]".format(
                    et, step, iters, step / num_batches, self.max_epoch)
                logs += ", Dataset [{}/{}]".format(num_data % self.data_size,
                                                   self.data_size)
                for tag, value in loss.items():
                    logs += ', {} [{:.4f}]'.format(tag, value)
                print(logs)

            ## Debug sample images.
            if step % self.eval_step == 0:
                self._save_samples(step)

            ## Save model checkpoints.
            if step % self.save_step == 0:
                G_path = os.path.join(self.ckpt_path, '{}-G.ckpt'.format(step))
                D_path = os.path.join(self.ckpt_path, '{}-D.ckpt'.format(step))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.D.state_dict(), D_path)
                print('Save model checkpoints into {}...'.format(
                    self.ckpt_path))

            ## Decay learning rates (only once past lr_decay_after).
            if (step % self.lr_update_step == 0
                    and step >= self.lr_decay_after):
                self.G_lr = self.G_lr * self.lr_decay
                self.D_lr = self.D_lr * self.lr_decay
                for param_group in self.G_optim.param_groups:
                    param_group['lr'] = self.G_lr
                for param_group in self.D_optim.param_groups:
                    param_group['lr'] = self.D_lr
                print('Decay learning rates, g_lr: {}, d_lr: {}...'.format(
                    self.G_lr, self.D_lr))

    def test(self):
        pass
Exemple #11
0
def val_test(args):
    """Validate a PCIE model together with every discriminator enabled in args.

    Builds the model and its discriminators and optionally restores a
    checkpoint (``args.weights == "load"``).  It then runs validation passes
    for epochs ``start_epoch`` .. 3 and, at the end, logs the last losses and
    image grids and saves the state.

    Raises:
        ValueError: if ``args.recons_loss`` names no known comparator.
        RuntimeError: if no validation pass ran (``start_epoch >= 4``).
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    d_args = vars(args)

    # Perturbation types switched on by their flag in args.
    perturb_types = [
        perturb_type for perturb_type in train_util.perturb_dict
        if d_args[perturb_type]
    ]
    num_perturb_types = len(perturb_types)

    train_loader, valid_loader, test_loader = train_util.get_dataloaders(
        args, perturb_types)

    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    input_dim = 3
    model = PCIE(args.img_res, input_dim, args.hidden_size, num_perturb_types,
                 args.enc_type, args.dec_type)
    interp_disc = Discriminator(input_dim, args.img_res,
                                args.input_type).to(args.device)
    interp_disc_opt = torch.optim.Adam(interp_disc.parameters(),
                                       lr=args.disc_lr,
                                       amsgrad=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # name -> [network, optimizer]; always lists so entries stay mutable.
    discriminators = {"interp_disc": [interp_disc, interp_disc_opt]}

    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        else:
            raise ValueError(
                "Unsupported recons_loss: {!r}".format(args.recons_loss))

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    if args.prior_loss == "gan":
        prior_disc = Latent2ClassDiscriminator(
            args.hidden_size, args.img_res // args.scale_factor)
        prior_disc_opt = torch.optim.Adam(prior_disc.parameters(),
                                          lr=args.disc_lr,
                                          amsgrad=True)
        discriminators["prior_disc"] = [prior_disc, prior_disc_opt]

    print("pertrub gans")

    if args.perturb_feat_gan:
        for perturb_type in train_util.perturb_dict:
            if not d_args[perturb_type]:
                continue
            num_classes = train_util.perturb_dict[perturb_type].num_class
            # Binary perturbations get a 2-class latent discriminator,
            # everything else a multi-class one.
            if num_classes == 2:
                print(f"perturb:{d_args[perturb_type]}\ttype: two ")
                pert_disc = Latent2ClassDiscriminator(
                    args.hidden_size, args.img_res // args.scale_factor)
            else:
                print(f"perturb:{d_args[perturb_type]}\ttype: multi ")
                pert_disc = LatentMultiClassDiscriminator(
                    args.hidden_size, args.img_res // args.scale_factor,
                    num_classes)
            pert_disc_opt = torch.optim.Adam(pert_disc.parameters(),
                                             lr=args.disc_lr,
                                             amsgrad=True)
            discriminators[f"{perturb_type}_disc"] = [pert_disc,
                                                      pert_disc_opt]

    print("perrturb gans set")

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer,
                                            discriminators)
    else:
        start_epoch = 0

    best_loss = torch.tensor(np.inf)
    val_loss_dict = None
    for epoch in tqdm(range(start_epoch, 4), file=sys.stdout):
        result = train_util.test(get_losses, model, valid_loader, args,
                                 discriminators, True)
        # NOTE(review): other call sites unpack train_util.test(...) as
        # (loss_dict, z); accept either form so the lookups below work.
        val_loss_dict = result[0] if isinstance(result, tuple) else result
        if args.weights == "init" and epoch == 1:
            epoch += 1
            break

    if val_loss_dict is None:
        raise RuntimeError(
            "no validation epoch ran (start_epoch={} >= 4)".format(
                start_epoch))

    # All artefacts of this run are tagged with the same step (epoch + 1).
    train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
    train_util.log_recons_img_grid(recons_input_img, model, epoch + 1,
                                   args.device, writer)
    train_util.log_interp_img_grid(recons_input_img, model, epoch + 1,
                                   args.device, writer)

    train_util.save_state(model, optimizer, discriminators,
                          val_loss_dict["recons_loss"], best_loss,
                          args.recons_loss, epoch, save_filename)

    print(val_loss_dict)
Exemple #12
0
def main(args):
    """Train a PCIE model with all discriminators enabled in ``args``.

    Each epoch trains, validates with ``ae_test``, logs validation losses,
    saves reconstruction/interpolation grids and checkpoints the state.
    A ``RuntimeError`` during training (e.g. CUDA OOM) is printed to stderr
    and the process exits with status 0.

    Raises:
        ValueError: if ``args.recons_loss`` names no known comparator.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    d_args = vars(args)
    pert_types = train_util.get_perturb_types(args)
    train_loader, valid_loader, test_loader = train_util.get_dataloaders(
        args, pert_types)
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    num_perturb_types = len(pert_types)

    input_dim = 3
    model = PCIE(args.img_res, input_dim, args.hidden_size, num_perturb_types,
                 args.enc_type, args.dec_type)
    interp_disc = Discriminator(input_dim, args.img_res,
                                args.input_type).to(args.device)
    interp_disc_opt = torch.optim.Adam(interp_disc.parameters(),
                                       lr=args.disc_lr,
                                       amsgrad=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # name -> [network, optimizer]; always lists so entries stay mutable.
    discriminators = {"interp_disc": [interp_disc, interp_disc_opt]}

    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            if "color" in args.recons_loss:
                recons_disc = FullPermutationColorComparator(
                    input_dim * 2, args.img_res,
                    args.input_type).to(args.device)
            else:
                recons_disc = FullPermutationComparator(
                    input_dim * 2, args.img_res,
                    args.input_type).to(args.device)
        else:
            raise ValueError(
                "Unsupported recons_loss: {!r}".format(args.recons_loss))

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    if args.prior_loss == "gan":
        prior_disc = Latent2ClassDiscriminator(
            args.hidden_size, args.img_res // args.scale_factor)
        prior_disc_opt = torch.optim.Adam(prior_disc.parameters(),
                                          lr=args.disc_lr,
                                          amsgrad=True)
        discriminators["prior_disc"] = [prior_disc, prior_disc_opt]

    if args.perturb_feat_gan:
        for perturb_type in train_util.perturb_dict:
            if not d_args[perturb_type]:
                continue
            num_class = train_util.perturb_dict[perturb_type].num_class
            # Binary perturbations get a 2-class latent discriminator,
            # everything else a multi-class one.
            if num_class == 2:
                pert_disc = Latent2ClassDiscriminator(
                    args.hidden_size, args.img_res // args.scale_factor)
            else:
                pert_disc = LatentMultiClassDiscriminator(
                    args.hidden_size, args.img_res // args.scale_factor,
                    num_class)
            pert_disc_opt = torch.optim.Adam(pert_disc.parameters(),
                                             lr=args.disc_lr,
                                             amsgrad=True)
            discriminators[f"{perturb_type}_disc"] = [pert_disc,
                                                      pert_disc_opt]

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, optimizer,
                                            discriminators)
    else:
        start_epoch = 0

    best_loss = torch.tensor(np.inf)
    for epoch in tqdm(range(start_epoch, args.num_epochs), file=sys.stdout):

        # For CUDA OOM errors: exit with status 0, which prevents running the
        # slurm dependency job that is meant to run on timeout.
        try:
            train(epoch, train_loader, model, optimizer, args, writer,
                  discriminators)
        except RuntimeError as err:
            print("".join(
                traceback.TracebackException.from_exception(err).format()),
                  file=sys.stderr)
            print("*******", file=sys.stderr)
            print(err, file=sys.stderr)
            sys.exit(0)

        print("out of train")
        val_loss_dict, _ = ae_test(get_losses, model, valid_loader, args,
                                   discriminators)
        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.save_recons_img_grid("test", recons_input_img, model,
                                        epoch + 1, args)
        train_util.save_interp_img_grid("test", recons_input_img, model,
                                        epoch + 1, args)
        train_util.save_state(model, optimizer, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, save_filename)
Exemple #13
0
    discriminator.cuda()

# SmallGAN draws a larger prior batch (prior_factor times batch_size); the
# training loop below then reduces it with greedy_core_set.
if opt.learning_type == 'smallgan':
    prior_batch_size = opt.batch_size * opt.prior_factor
else:
    prior_batch_size = opt.batch_size

# Valid dataset: fixed latent vectors drawn uniformly from [-1, 1).
valid_samples = np.random.uniform(low=-1,
                                  high=1,
                                  size=(opt.valid_size, opt.latent_dim))
valid_dataloader = torch.utils.data.DataLoader(
    valid_samples,
    shuffle=False,
    batch_size=opt.valid_batch_size,
    num_workers=opt.n_cpu,
)

# Optimizers: generator and discriminator share the same Adam settings.
adam_kwargs = dict(lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_G = torch.optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_kwargs)

criterion = nn.BCEWithLogitsLoss()

# Train
for global_step in range(opt.train_steps):
    # Configure input and apply GreedyCoreset: the next real batch is pulled
    # from train_data and, for 'smallgan', reduced to opt.batch_size points.
    real_data = Tensor(next(train_data))
    if opt.learning_type == 'smallgan':
        real_data = greedy_core_set(real_data, opt.batch_size)
    optimizer_D.zero_grad()

    # Sample noise as generator input and apply GreedyCoreset: draw
    # prior_batch_size uniform latents in [-1, 1) and, for 'smallgan', keep
    # an opt.batch_size core-set of them.
    z = Tensor(np.random.uniform(-1, 1, (prior_batch_size, opt.latent_dim)))
    if opt.learning_type == 'smallgan':
        z = greedy_core_set(z, opt.batch_size)
Exemple #14
0
                                           batch_size=64,
                                           shuffle=False,
                                           **kwargs)

# NOTE(review): the root directory says fashion_mnist but the dataset class is
# datasets.MNIST. Confirm FashionMNIST was not intended, and that the training
# loader above uses the same class.
test_loader = torch.utils.data.DataLoader(datasets.MNIST(
    '../data/fashion_mnist/', train=False, transform=transforms.ToTensor()),
                                          batch_size=32,
                                          shuffle=False,
                                          **kwargs)
# Materialise every test batch once so it can be reused without reloading.
test_data = list(test_loader)

netG = Generator().cuda()
netD = Discriminator().cuda()

optimizerG = torch.optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
optimizerD = torch.optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))

# Constant +1 / -1 float32 tensors on the current CUDA device.  Built with
# torch.tensor rather than the deprecated torch.cuda.FloatTensor constructor.
one = torch.tensor([1.0], device='cuda')
mone = one * -1


def train(epoch):
    """Run one training epoch over ``train_loader``.

    NOTE(review): the visible body stops after ``netD.zero_grad()``;
    ``train_loss`` and ``start_time`` are never used and nothing is returned.
    Confirm the rest of the loop body against the original source.
    """
    train_loss = []
    for batch_idx, (data, _) in enumerate(train_loader):
        start_time = time.time()

        # Skip every batch whose size is not exactly 64 (e.g. a short final
        # batch) so each step sees the same batch size.
        if data.size(0) != 64:
            continue
        x_real = Variable(data, requires_grad=False).cuda()

        netD.zero_grad()
Exemple #15
0
def main(args):
    """Train a VAE, optionally with a reconstruction discriminator/comparator.

    Each epoch trains, validates, saves reconstruction/interpolation grids,
    logs losses and latent metrics, and checkpoints the state via
    ``train_util``.  A ``RuntimeError`` during training is printed to stderr
    and the process exits with status 0.

    Raises:
        ValueError: if ``args.recons_loss`` names no known comparator.
    """
    input_dim = 3
    model = VAE(input_dim, args.hidden_size, args.enc_type, args.dec_type)

    opt = torch.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)

    # name -> [network, optimizer]; lists so the network can be re-wrapped.
    discriminators = {}

    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        else:
            raise ValueError(
                "Unsupported recons_loss: {!r}".format(args.recons_loss))

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    if torch.cuda.device_count() > 1:
        model = train_util.ae_data_parallel(model)
        for disc in discriminators:
            discriminators[disc][0] = torch.nn.DataParallel(
                discriminators[disc][0])

    model.to(args.device)
    for disc in discriminators:
        discriminators[disc][0].to(args.device)

    print("model built", file=sys.stderr)
    train_loader, val_loader, test_loader = train_util.get_dataloaders(args)
    print("loaders acquired", file=sys.stderr)

    model_name = f"vae_{args.recons_loss}"
    if args.output_folder is None:
        args.output_folder = os.path.join(
            model_name, args.dataset,
            f"depth_{args.enc_type}_{args.dec_type}_hs_{args.img_res}_{args.hidden_size}"
        )

    log_save_path = os.path.join("./logs", args.output_folder)
    model_save_path = os.path.join("./models", args.output_folder)

    if not os.path.exists(log_save_path):
        os.makedirs(log_save_path, exist_ok=True)
        print(f"log:{log_save_path}", file=sys.stderr)
        sys.stderr.flush()
    os.makedirs(model_save_path, exist_ok=True)

    writer = SummaryWriter(log_save_path)

    print(f"train loader length:{len(train_loader)}", file=sys.stderr)
    best_loss = torch.tensor(np.inf)

    if args.weights == "load":
        start_epoch = train_util.load_state(model_save_path, model, opt,
                                            discriminators)
    else:
        start_epoch = 0

    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    train_util.log_recons_img_grid(recons_input_img, model, 0, args.device,
                                   writer)

    for epoch in range(start_epoch, args.num_epochs):

        try:
            train(model, train_loader, opt, epoch, writer, args,
                  discriminators)
        except RuntimeError as err:
            print("".join(
                traceback.TracebackException.from_exception(err).format()),
                  file=sys.stderr)
            print("*******", file=sys.stderr)
            print(err, file=sys.stderr)
            sys.exit(0)

        val_loss_dict, z = train_util.test(get_losses, model, val_loader, args,
                                           discriminators)
        print(f"epoch loss:{val_loss_dict['recons_loss'].item()}")

        train_util.save_recons_img_grid("test", recons_input_img, model,
                                        epoch + 1, args)
        train_util.save_interp_img_grid("test", recons_input_img, model,
                                        epoch + 1, args)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)
        train_util.save_state(model, opt, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, model_save_path)
Exemple #16
0
def main(args):
    """Train an ACAI autoencoder with an interpolation discriminator.

    When ``args.weights == "load"``, training resumes from the epoch returned
    by ``train_util.load_state``.  Each epoch trains, validates, logs losses
    and latent metrics, saves image grids and checkpoints the state.

    Raises:
        ValueError: if ``args.recons_loss`` names no known comparator.
    """
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}'.format(args.output_folder)

    train_loader, val_loader, test_loader = train_util.get_dataloaders(args)
    recons_input_img = train_util.log_input_img_grid(test_loader, writer)

    input_dim = 3
    model = ACAI(args.img_res, input_dim, args.hidden_size, args.enc_type,
                 args.dec_type).to(args.device)
    disc = Discriminator(input_dim, args.img_res,
                         args.input_type).to(args.device)
    disc_opt = torch.optim.Adam(disc.parameters(),
                                lr=args.disc_lr,
                                amsgrad=True)

    opt = torch.optim.Adam(model.parameters(), lr=args.lr)

    # name -> [network, optimizer]
    discriminators = {"interp_disc": [disc, disc_opt]}
    if args.recons_loss != "mse":
        if args.recons_loss == "gan":
            recons_disc = Discriminator(input_dim, args.img_res,
                                        args.input_type).to(args.device)
        elif args.recons_loss == "comp":
            recons_disc = AnchorComparator(input_dim * 2, args.img_res,
                                           args.input_type).to(args.device)
        elif "comp_2" in args.recons_loss:
            recons_disc = ClubbedPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        elif "comp_6" in args.recons_loss:
            recons_disc = FullPermutationComparator(
                input_dim * 2, args.img_res, args.input_type).to(args.device)
        else:
            raise ValueError(
                "Unsupported recons_loss: {!r}".format(args.recons_loss))

        recons_disc_opt = torch.optim.Adam(recons_disc.parameters(),
                                           lr=args.disc_lr,
                                           amsgrad=True)
        discriminators["recons_disc"] = [recons_disc, recons_disc_opt]

    # Generate the samples first once
    train_util.save_recons_img_grid("test", recons_input_img, model, 0, args)

    if args.weights == "load":
        start_epoch = train_util.load_state(save_filename, model, opt,
                                            discriminators)
    else:
        start_epoch = 0

    best_loss = torch.tensor(np.inf)
    # Resume from start_epoch so a loaded run does not restart at epoch 0.
    for epoch in range(start_epoch, args.num_epochs):
        print("Epoch {}:".format(epoch))
        train(model, opt, train_loader, args, discriminators, writer)

        val_loss_dict, z = train_util.test(get_losses, model, val_loader, args,
                                           discriminators)

        train_util.log_losses("val", val_loss_dict, epoch + 1, writer)
        train_util.log_latent_metrics("val", z, epoch + 1, writer)

        train_util.save_recons_img_grid("val", recons_input_img, model,
                                        epoch + 1, args)
        train_util.save_interp_img_grid("val", recons_input_img, model,
                                        epoch + 1, args)

        train_util.save_state(model, opt, discriminators,
                              val_loss_dict["recons_loss"], best_loss,
                              args.recons_loss, epoch, save_filename)
Exemple #17
0
class Trainer(object):
    """Train an encoder / per-attribute transformer / reconstructor GAN.

    The model factorises attribute editing into an encoder ``E``, three
    attribute transformers (hair colour, gender, smiling), one shared
    reconstructor ``R`` and one discriminator per attribute.  Each
    discriminator returns a pair ``(src, cls)``: a realness score and
    attribute logits.  Discriminators are trained with a WGAN-GP style loss
    (``-mean(real) + mean(fake) + lambda_gp * gp``) plus a BCE attribute
    classification term; see :meth:`train`.
    """

    def __init__(self, celeba_loader, config):
        """Copy hyper-parameters from ``config`` and build all networks.

        Args:
            celeba_loader: iterable yielding ``(images, labels)`` batches.
                :meth:`train` slices the label columns as hair colour
                ``[0:3]``, gender ``[3:4]`` and smiling ``[4:]``.
            config: namespace providing every attribute read below.
        """
        # miscellaneous
        self.use_tensorboard = config.use_tensorboard
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # data loader
        self.dataload = celeba_loader

        # model configurations
        self.c64 = config.c64
        self.c256 = config.c256
        self.c2048 = config.c2048
        self.rb6 = config.rb6
        self.attr_dim = config.attr_dim
        self.hair_dim = config.hair_dim

        # training configurations
        self.selected_attrs = config.selected_attrs
        self.train_iters = config.train_iters
        self.num_iters_decay = config.num_iters_decay
        self.n_critic = config.n_critic
        self.d_lr = config.d_lr
        self.r_lr = config.r_lr
        self.t_lr = config.t_lr
        self.e_lr = config.e_lr
        self.decay_rate = config.decay_rate
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.lambda_cls = config.lambda_cls
        self.lambda_cyc = config.lambda_cyc
        self.lambda_gp = config.lambda_gp

        # test configurations
        self.test_iters = config.test_iters

        # directories
        self.sample_dir = config.sample_dir
        self.model_save_dir = config.model_save_dir
        self.result_dir = config.result_dir
        self.log_dir = config.log_dir

        # step size
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step

        # initial models
        self.build_models()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_models(self):
        """Create networks and their Adam optimisers, print them, move them to the device."""
        self.E = Encoder(self.c64, self.rb6)
        self.T_Hair = Transformer(self.hair_dim, self.c256, self.rb6)
        self.T_Gender = Transformer(self.attr_dim, self.c256, self.rb6)
        self.T_Smailing = Transformer(self.attr_dim, self.c256, self.rb6)
        self.R = Reconstructor(self.c256)
        self.D_Hair = Discriminator(self.hair_dim, self.c64)
        self.D_Gender = Discriminator(self.attr_dim, self.c64)
        self.D_Smailing = Discriminator(self.attr_dim, self.c64)

        betas = [self.beta1, self.beta2]
        self.e_optim = torch.optim.Adam(self.E.parameters(), self.e_lr, betas)
        self.th_optim = torch.optim.Adam(self.T_Hair.parameters(), self.t_lr, betas)
        self.tg_optim = torch.optim.Adam(self.T_Gender.parameters(), self.t_lr, betas)
        self.ts_optim = torch.optim.Adam(self.T_Smailing.parameters(), self.t_lr, betas)
        self.r_optim = torch.optim.Adam(self.R.parameters(), self.r_lr, betas)
        self.dh_optim = torch.optim.Adam(self.D_Hair.parameters(), self.d_lr, betas)
        self.dg_optim = torch.optim.Adam(self.D_Gender.parameters(), self.d_lr, betas)
        self.ds_optim = torch.optim.Adam(self.D_Smailing.parameters(), self.d_lr, betas)

        self.print_network(self.E, 'Encoder')
        self.print_network(self.T_Hair, 'Transformer for Hair Color')
        self.print_network(self.T_Gender, 'Transformer for Gender')
        self.print_network(self.T_Smailing, 'Transformer for Smailing')
        self.print_network(self.R, 'Reconstructor')
        self.print_network(self.D_Hair, 'D for Hair Color')
        self.print_network(self.D_Gender, 'D for Gender')
        self.print_network(self.D_Smailing, 'D for Smailing')

        # Module.to() moves parameters in place, so the optimisers created
        # above keep referencing the right tensors.
        for net in (self.E, self.T_Hair, self.T_Gender, self.T_Smailing,
                    self.R, self.D_Gender, self.D_Smailing, self.D_Hair):
            net.to(self.device)

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = sum(p.numel() for p in model.parameters())

        print(name)
        print("The number of parameters: {}".format(num_params))
        print(model)

    def build_tensorboard(self):
        """Build a tensorboard logger."""
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: mean((L2_norm(dy/dx) - 1)**2).

        ``x`` must require grad and ``y`` must be computed from it.  The
        graph is kept (``create_graph``) so the penalty itself is
        differentiable w.r.t. the discriminator parameters.
        """
        weight = torch.ones_like(y)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm - 1)**2)

    def reset_grad(self):
        """Zero the gradients of every optimiser."""
        for optimizer in (self.e_optim, self.th_optim, self.tg_optim,
                          self.ts_optim, self.r_optim, self.dh_optim,
                          self.dg_optim, self.ds_optim):
            optimizer.zero_grad()

    def update_lr(self, e_lr, d_lr, r_lr, t_lr):
        """Set the learning rates of all optimisers.

        Args:
            e_lr: encoder learning rate.
            d_lr: learning rate of the three discriminators.
            r_lr: reconstructor learning rate.
            t_lr: learning rate of the three transformers.
        """
        schedule = (
            (self.e_optim, e_lr),
            (self.dh_optim, d_lr),
            (self.dg_optim, d_lr),
            (self.ds_optim, d_lr),
            (self.r_optim, r_lr),
            (self.th_optim, t_lr),
            (self.tg_optim, t_lr),
            (self.ts_optim, t_lr),
        )
        for optimizer, lr in schedule:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    def create_labels(self, c_org, c_dim=5, selected_attrs=None):
        """Generate target domain labels for debugging and testing.

        Returns ``c_dim`` copies of ``c_org`` (moved to ``self.device``).
        Copy ``i`` is edited in one of two ways.  If attribute ``i`` is a hair
        colour, that colour is set to 1 and the other hair colours to 0.
        Otherwise column ``i`` is flipped (0 -> 1, non-zero -> 0).
        ``selected_attrs=None`` is treated as "no hair-colour attributes".
        """
        hair_color_indices = [
            i for i, attr_name in enumerate(selected_attrs or [])
            if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair']
        ]

        c_trg_list = []
        for i in range(c_dim):
            c_trg = c_org.clone()
            if i in hair_color_indices:  # Set one hair color to 1 and the rest to 0.
                c_trg[:, i] = 1
                for j in hair_color_indices:
                    if j != i:
                        c_trg[:, j] = 0
            else:
                c_trg[:, i] = (c_trg[:, i] == 0)  # Reverse attribute value.

            c_trg_list.append(c_trg.to(self.device))
        return c_trg_list

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def train(self):
        """Run the alternating discriminator / encoder-transformer-reconstructor loop.

        Every iteration updates the three discriminators.  Every ``n_critic``
        iterations, the encoder, transformers and reconstructor are updated.
        Sampling, logging, checkpointing and learning-rate decay run every
        ``sample_step``, ``log_step``, ``model_save_step`` and
        ``lr_update_step`` iterations respectively.
        """
        data_loader = self.dataload

        # Fetch fixed inputs for debugging.
        data_iter = iter(data_loader)
        x_fixed, c_org = next(data_iter)
        x_fixed = x_fixed.to(self.device)
        c_fixed_list = self.create_labels(c_org, 5, self.selected_attrs)

        # Current (decayed) learning rates.  The configured self.*_lr values
        # are kept unchanged and used as the linear decay unit.
        d_lr = self.d_lr
        r_lr = self.r_lr
        t_lr = self.t_lr
        e_lr = self.e_lr

        # Start training
        print('Starting point==============================')
        start_time = time.time()

        for i in range(0, self.train_iters):
            # =========================================================== #
            #                  1. Preprocess input data                   #
            # =========================================================== #

            # Fetch real images and labels, restarting the loader once it is
            # exhausted.  Only StopIteration is caught: a bare ``except``
            # would also swallow data-loading errors and KeyboardInterrupt.
            try:
                x_real, label_real = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x_real, label_real = next(data_iter)

            # Target ("fake") labels are the real labels shuffled across the batch.
            rand_idx = torch.randperm(label_real.size(0))
            label_feak = label_real[rand_idx]

            x_real = x_real.to(self.device)
            # Label columns: hair colour [0:3], gender [3:4], smiling [4:].
            label_h_real = label_real[:, 0:3].to(self.device)
            label_h_feak = label_feak[:, 0:3].to(self.device)
            label_g_real = label_real[:, 3:4].to(self.device)
            label_g_feak = label_feak[:, 3:4].to(self.device)
            label_s_real = label_real[:, 4:].to(self.device)
            label_s_feak = label_feak[:, 4:].to(self.device)

            # =========================================================== #
            #                 2. Train the discriminators                 #
            # =========================================================== #

            # Compute loss with real images
            h_src, h_cls = self.D_Hair(x_real)
            d_h_loss_real = -torch.mean(h_src)
            d_h_loss_cls = F.binary_cross_entropy_with_logits(h_cls, label_h_real, reduction='sum') / h_cls.size(0)

            g_src, g_cls = self.D_Gender(x_real)
            d_g_loss_real = -torch.mean(g_src)
            d_g_loss_cls = F.binary_cross_entropy_with_logits(g_cls, label_g_real, reduction='sum') / g_cls.size(0)

            s_src, s_cls = self.D_Smailing(x_real)
            d_s_loss_real = -torch.mean(s_src)
            d_s_loss_cls = F.binary_cross_entropy_with_logits(s_cls, label_s_real, reduction='sum') / s_cls.size(0)

            # Generate fake images.  Everything is detached so the
            # discriminator update does not build gradients for E/T/R.
            features = self.E(x_real)
            t_h_features = self.T_Hair(features.detach(), label_h_feak)
            t_g_features = self.T_Gender(features.detach(), label_g_feak)
            t_s_features = self.T_Smailing(features.detach(), label_s_feak)
            x_h_feak = self.R(t_h_features.detach())
            x_g_feak = self.R(t_g_features.detach())
            x_s_feak = self.R(t_s_features.detach())

            # Compute loss with fake images
            h_src, h_cls = self.D_Hair(x_h_feak.detach())
            d_h_loss_fake = torch.mean(h_src)

            g_src, g_cls = self.D_Gender(x_g_feak.detach())
            d_g_loss_fake = torch.mean(g_src)

            s_src, s_cls = self.D_Smailing(x_s_feak.detach())
            d_s_loss_fake = torch.mean(s_src)

            # Gradient penalty on random interpolations between real and fake
            # images (one shared mixing coefficient per sample).
            alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
            x_h_hat = (alpha * x_real.data + (1 - alpha) * x_h_feak.data).requires_grad_(True)
            x_g_hat = (alpha * x_real.data + (1 - alpha) * x_g_feak.data).requires_grad_(True)
            x_s_hat = (alpha * x_real.data + (1 - alpha) * x_s_feak.data).requires_grad_(True)

            out_src, _ = self.D_Hair(x_h_hat)
            d_h_loss_gp = self.gradient_penalty(out_src, x_h_hat)
            out_src, _ = self.D_Gender(x_g_hat)
            d_g_loss_gp = self.gradient_penalty(out_src, x_g_hat)
            out_src, _ = self.D_Smailing(x_s_hat)
            d_s_loss_gp = self.gradient_penalty(out_src, x_s_hat)

            # Backward and optimize
            d_loss = d_h_loss_real + d_g_loss_real + d_s_loss_real + \
                     d_h_loss_fake + d_g_loss_fake + d_s_loss_fake + \
                     self.lambda_gp * (d_h_loss_gp + d_g_loss_gp + d_s_loss_gp) + \
                     self.lambda_cls * (d_h_loss_cls + d_g_loss_cls + d_s_loss_cls)

            self.reset_grad()
            d_loss.backward()
            self.dh_optim.step()
            self.dg_optim.step()
            self.ds_optim.step()

            # Logging
            loss = {}
            loss['D/h_loss_real'] = d_h_loss_real.item()
            loss['D/g_loss_real'] = d_g_loss_real.item()
            loss['D/s_loss_real'] = d_s_loss_real.item()
            loss['D/h_loss_fake'] = d_h_loss_fake.item()
            loss['D/g_loss_fake'] = d_g_loss_fake.item()
            loss['D/s_loss_fake'] = d_s_loss_fake.item()
            loss['D/h_loss_cls'] = d_h_loss_cls.item()
            loss['D/g_loss_cls'] = d_g_loss_cls.item()
            loss['D/s_loss_cls'] = d_s_loss_cls.item()
            loss['D/h_loss_gp'] = d_h_loss_gp.item()
            loss['D/g_loss_gp'] = d_g_loss_gp.item()
            loss['D/s_loss_gp'] = d_s_loss_gp.item()

            # =========================================================== #
            #     3. Train the encoder, transformers and reconstructor    #
            # =========================================================== #

            if (i + 1) % self.n_critic == 0:
                # Generate fake images (not detached this time).
                features = self.E(x_real)
                t_h_features = self.T_Hair(features, label_h_feak)
                t_g_features = self.T_Gender(features, label_g_feak)
                t_s_features = self.T_Smailing(features, label_s_feak)
                x_h_feak = self.R(t_h_features)
                x_g_feak = self.R(t_g_features)
                x_s_feak = self.R(t_s_features)

                # Adversarial + classification loss with fake images
                h_src, h_cls = self.D_Hair(x_h_feak)
                etr_h_loss_fake = -torch.mean(h_src)
                etr_h_loss_cls = F.binary_cross_entropy_with_logits(h_cls, label_h_feak, reduction='sum') / h_cls.size(0)

                g_src, g_cls = self.D_Gender(x_g_feak)
                etr_g_loss_fake = -torch.mean(g_src)
                etr_g_loss_cls = F.binary_cross_entropy_with_logits(g_cls, label_g_feak, reduction='sum') / g_cls.size(0)

                s_src, s_cls = self.D_Smailing(x_s_feak)
                etr_s_loss_fake = -torch.mean(s_src)
                etr_s_loss_cls = F.binary_cross_entropy_with_logits(s_cls, label_s_feak, reduction='sum') / s_cls.size(0)

                # Real -> E -> R -> Real reconstruction (L1) loss
                x_re = self.R(features)
                er_loss_cyc = torch.mean(torch.abs(x_re - x_real))

                # Feature cycle: T(E(x)) should match E(R(T(E(x)))) (L1)
                h_fake_features = self.E(x_h_feak)
                g_fake_features = self.E(x_g_feak)
                s_fake_features = self.E(x_s_feak)

                etr_h_loss_cyc = torch.mean(torch.abs(t_h_features - h_fake_features))
                etr_g_loss_cyc = torch.mean(torch.abs(t_g_features - g_fake_features))
                etr_s_loss_cyc = torch.mean(torch.abs(t_s_features - s_fake_features))

                # Backward and optimize
                etr_loss = etr_h_loss_fake + etr_g_loss_fake + etr_s_loss_fake + \
                           self.lambda_cls * (etr_h_loss_cls + etr_g_loss_cls + etr_s_loss_cls) + \
                           self.lambda_cyc * (er_loss_cyc + etr_h_loss_cyc + etr_g_loss_cyc + etr_s_loss_cyc)

                self.reset_grad()
                etr_loss.backward()
                self.e_optim.step()
                self.th_optim.step()
                self.tg_optim.step()
                self.ts_optim.step()
                self.r_optim.step()

                # Logging.
                loss['ETR/h_loss_fake'] = etr_h_loss_fake.item()
                loss['ETR/g_loss_fake'] = etr_g_loss_fake.item()
                loss['ETR/s_loss_fake'] = etr_s_loss_fake.item()
                loss['ETR/h_loss_cls'] = etr_h_loss_cls.item()
                loss['ETR/g_loss_cls'] = etr_g_loss_cls.item()
                loss['ETR/s_loss_cls'] = etr_s_loss_cls.item()
                loss['ER/er_loss_cyc'] = er_loss_cyc.item()
                loss['ETR/h_loss_cyc'] = etr_h_loss_cyc.item()
                loss['ETR/g_loss_cyc'] = etr_g_loss_cyc.item()
                loss['ETR/s_loss_cyc'] = etr_s_loss_cyc.item()

            # =========================================================== #
            #                      4. Miscellaneous                       #
            # =========================================================== #

            # Translate fixed images for debugging: each target label set is
            # applied through the hair -> gender -> smiling transformers.
            if (i + 1) % self.sample_step == 0:
                with torch.no_grad():
                    x_fake_list = [x_fixed]
                    for c_fixed in c_fixed_list:
                        xf = self.E(x_fixed)
                        xth = self.T_Hair(xf, c_fixed[:, 0:3])
                        xtg = self.T_Gender(xth, c_fixed[:, 3:4])
                        xts = self.T_Smailing(xtg, c_fixed[:, 4:5])
                        x_fake_list.append(self.R(xts))
                    x_concat = torch.cat(x_fake_list, dim=3)
                    sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i + 1))
                    save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(sample_path))

            # Print out training information.
            if (i + 1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i + 1, self.train_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

            # save model checkpoints
            if (i + 1) % self.model_save_step == 0:
                E_path = os.path.join(self.model_save_dir, '{}-E.ckpt'.format(i + 1))
                D_h_path = os.path.join(self.model_save_dir, '{}-D_h.ckpt'.format(i + 1))
                D_g_path = os.path.join(self.model_save_dir, '{}-D_g.ckpt'.format(i + 1))
                D_s_path = os.path.join(self.model_save_dir, '{}-D_s.ckpt'.format(i + 1))
                R_path = os.path.join(self.model_save_dir, '{}-R.ckpt'.format(i + 1))
                T_h_path = os.path.join(self.model_save_dir, '{}-T_h.ckpt'.format(i + 1))
                T_g_path = os.path.join(self.model_save_dir, '{}-T_g.ckpt'.format(i + 1))
                T_s_path = os.path.join(self.model_save_dir, '{}-T_s.ckpt'.format(i + 1))
                torch.save(self.E.state_dict(), E_path)
                torch.save(self.D_Hair.state_dict(), D_h_path)
                torch.save(self.D_Gender.state_dict(), D_g_path)
                torch.save(self.D_Smailing.state_dict(), D_s_path)
                torch.save(self.R.state_dict(), R_path)
                torch.save(self.T_Hair.state_dict(), T_h_path)
                torch.save(self.T_Gender.state_dict(), T_g_path)
                torch.save(self.T_Smailing.state_dict(), T_s_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Linear learning-rate decay.  Clamp at zero: once the schedule
            # runs past ``decay_rate`` steps, a negative rate would turn the
            # optimisers into gradient ascent.
            if (i + 1) % self.lr_update_step == 0 and (i + 1) > self.num_iters_decay:
                e_lr = max(e_lr - self.e_lr / float(self.decay_rate), 0.0)
                d_lr = max(d_lr - self.d_lr / float(self.decay_rate), 0.0)
                r_lr = max(r_lr - self.r_lr / float(self.decay_rate), 0.0)
                t_lr = max(t_lr - self.t_lr / float(self.decay_rate), 0.0)
                self.update_lr(e_lr, d_lr, r_lr, t_lr)
                print('Decayed learning rates, e_lr: {}, d_lr: {}, r_lr: {}, t_lr: {}.'.format(e_lr, d_lr, r_lr, t_lr))
Exemple #18
0
class BaseGAN(torch.nn.Module):
    """Container for a generator, a discriminator and their optimisers.

    Subclasses implement :meth:`discriminator_loss` and
    :meth:`generator_loss`.  :meth:`fit` performs one discriminator update
    followed by one generator update on a batch.
    """

    def __init__(
        self,
        generator_and_opt: "Optional[Tuple[Generator, torch.optim.Optimizer]]" = None,
        discriminator_and_opt: "Optional[Tuple[Discriminator, torch.optim.Optimizer]]" = None,
        input_size: int = None,
        hidden_channel: int = 128,
        latent_dim: int = 100,
        learning_rate: float = 1e-4,
    ):
        """Attach or build the two networks.

        For each network, a supplied ``(model, optimizer)`` pair is used
        as-is.  Otherwise, if ``input_size`` is given, the network is built
        here with an Adam optimiser at ``learning_rate``.  When neither is
        given, the network attribute is left as ``None`` so that a subclass
        can build it after calling ``super().__init__()``.

        Args:
            generator_and_opt: optional ``(generator, optimizer)`` pair.
            discriminator_and_opt: optional ``(discriminator, optimizer)`` pair.
            input_size: size passed to ``Generator`` / ``Discriminator`` when
                they are built here.
            hidden_channel: hidden width for networks built here.
            latent_dim: generator latent size for a generator built here.
            learning_rate: Adam learning rate for networks built here.
        """
        # Module.__init__ must run before sub-modules are assigned so they
        # get registered (parameters(), state_dict(), .to(), ...).
        super().__init__()
        self.generator = None
        self.discriminator = None
        # Generator
        if generator_and_opt is not None:
            self.generator, self.generator_opt = generator_and_opt
        elif input_size is not None:
            self.generator = Generator(input_size=input_size,
                                       latent_dim=latent_dim,
                                       hidden_channel=hidden_channel)
            self.generator_opt = torch.optim.Adam(
                self.generator.parameters(), learning_rate)
        # Discriminator
        if discriminator_and_opt is not None:
            self.discriminator, self.discriminator_opt = discriminator_and_opt
        elif input_size is not None:
            self.discriminator = Discriminator(
                input_size=input_size, hidden_channel=hidden_channel)
            self.discriminator_opt = torch.optim.Adam(
                self.discriminator.parameters(), learning_rate)

    def discriminator_loss(self,
                           x: torch.Tensor) -> "Tuple[torch.Tensor, torch.Tensor]":
        """Return ``(real_loss, fake_loss)`` for the discriminator; override in subclasses."""
        return None, None

    def generator_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Return the generator loss; override in subclasses."""
        return None

    def fit_batch(self, engine: "Engine", batch: Optional[Union[tuple,
                                                                list]]) -> dict:
        """Engine-style process function: ignore ``engine`` and delegate to :meth:`fit`."""
        return self.fit(batch)

    def fit(self, batch: Optional[Union[tuple, list]]) -> dict:
        """Run one discriminator step then one generator step.

        Args:
            batch: ``(x, label)`` pair; the label is ignored.

        Returns:
            ``{"D_loss": float, "G_loss": float}``.
        """
        self.generator.train()
        self.discriminator.train()
        self.generator_opt.zero_grad()
        self.discriminator_opt.zero_grad()

        device = next(self.generator.parameters()).device
        x, _ = batch
        x = x.to(device)
        #
        # 1. get discriminator loss and update discriminator
        #
        real_D_loss, fake_D_loss = self.discriminator_loss(x)
        D_loss = real_D_loss + fake_D_loss
        D_loss.backward()
        self.discriminator_opt.step()
        #
        # 2. get generator loss and update generator
        #
        # If the subclass did not detach the fake samples in
        # discriminator_loss, D_loss.backward() also filled the generator's
        # gradients (with the opposite objective).  Clear them so only the
        # generator loss drives this step.
        self.generator_opt.zero_grad()
        G_loss = self.generator_loss(x)
        G_loss.backward()
        self.generator_opt.step()
        return {
            "D_loss": float(D_loss),
            "G_loss": float(G_loss),
        }
Exemple #19
0
class Solver(object):
    """Train and sample a least-squares GAN with a gradient penalty.

    The discriminator is trained with ``mean((D(real) - 1)**2) +
    mean(D(fake)**2)`` plus a gradient penalty.  The generator is trained
    with ``mean((D(G(z)) - 1)**2)``, with five generator updates per
    discriminator update.
    """

    def __init__(self, config, data_loader):
        """Copy hyper-parameters from ``config`` and build the models.

        When ``config.epoch`` is truthy, :meth:`build_model` resumes from the
        checkpoints saved under that epoch number.
        """
        self.generator = None
        self.discriminator = None
        self.g_optimizer = None
        self.d_optimizer = None
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.z_dim = config.z_dim
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.image_size = config.image_size
        self.data_loader = data_loader
        self.num_epochs = config.num_epochs
        self.batch_size = config.batch_size
        self.sample_size = config.sample_size
        self.lr = config.lr
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.sample_path = config.sample_path
        self.model_path = config.model_path
        self.epoch = config.epoch
        self.build_model()

        self.plotter = Plotter()

    def build_model(self):
        """Build generator and discriminator (and optionally load a checkpoint)."""
        self.generator = Generator(z_dim=self.z_dim)
        print(count_parameters(self.generator))
        self.discriminator = Discriminator()
        print(count_parameters(self.discriminator))
        self.g_optimizer = optim.Adam(self.generator.parameters(),
                                      self.lr, (self.beta1, self.beta2))
        self.d_optimizer = optim.Adam(self.discriminator.parameters(),
                                      self.lr*1, (self.beta1, self.beta2))

        if self.epoch:
            g_path = os.path.join(self.model_path, 'generator-%d.pkl' % self.epoch)
            d_path = os.path.join(self.model_path, 'discriminator-%d.pkl' % self.epoch)
            g_optim_path = os.path.join(self.model_path, 'gen-optim-%d.pkl' % self.epoch)
            d_optim_path = os.path.join(self.model_path, 'dis-optim-%d.pkl' % self.epoch)
            self.generator.load_state_dict(torch.load(g_path))
            self.discriminator.load_state_dict(torch.load(d_path))
            self.g_optimizer.load_state_dict(torch.load(g_optim_path))
            self.d_optimizer.load_state_dict(torch.load(d_optim_path))

        if torch.cuda.is_available():
            self.generator.cuda()
            self.discriminator.cuda()

    def to_variable(self, x):
        """Convert tensor to variable (moving it to the GPU when available)."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x)

    def to_data(self, x):
        """Convert variable to a CPU tensor."""
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data

    def reset_grad(self):
        """Zero the gradient buffers."""
        self.discriminator.zero_grad()
        self.generator.zero_grad()

    def denorm(self, x):
        """Convert range (-1, 1) to (0, 1)"""
        out = (x + 1) / 2
        return out.clamp(0, 1)

    def train(self):
        """Train generator and discriminator.

        When resuming (``self.epoch`` set), epochs continue from
        ``self.epoch``.  Checkpoints are written every 5 epochs and named
        with ``epoch + 1``, which is the value ``config.epoch`` must take to
        resume from them.
        """
        fixed_noise = self.to_variable(torch.randn(self.batch_size, self.z_dim))
        total_step = len(self.data_loader)
        for epoch in range(self.epoch, self.epoch + self.num_epochs) if self.epoch else range(self.num_epochs):
            for i, images in enumerate(self.data_loader):
                # Skip incomplete batches (e.g. the last one of an epoch).
                if len(images) != self.batch_size:
                    continue

                for p in self.discriminator.parameters():
                    p.requires_grad = True
                #===================== Train D =====================#
                # No retain_grad() on ``images``: they do not require grad,
                # recent PyTorch raises for that case, and their gradient is
                # never read.
                images = self.to_variable(images)
                batch_size = images.size(0)
                noise = self.to_variable(torch.randn(batch_size, self.z_dim))

                # Train D to recognize real images as real.
                outputs = self.discriminator(images)
                real_loss = torch.mean((outputs - 1) ** 2)      # L2 loss instead of Binary cross entropy loss (this is optional for stable training)
                # Train D to recognize fake images as fake.
                fake_images = self.generator(noise)
                fake_images.retain_grad()  # its .grad is plotted below
                outputs = self.discriminator(fake_images)
                fake_loss = torch.mean(outputs ** 2)

                # gradient penalty
                gp_loss = calc_gradient_penalty(self.discriminator, images, fake_images)

                # Backprop + optimize
                d_loss = fake_loss + real_loss + gp_loss
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()
                if i % 10 == 0:
                    self.plotter.draw_activations(fake_images.grad[0], original=fake_images[0])

                g_losses = []
                # Freeze D so the generator steps do not accumulate D grads.
                for p in self.discriminator.parameters():
                    p.requires_grad = False
                #===================== Train G =====================#
                for g_batch in range(5):
                    noise = self.to_variable(torch.randn(batch_size, self.z_dim))

                    # Train G so that D recognizes G(z) as real.
                    fake_images = self.generator(noise)
                    outputs = self.discriminator(fake_images)
                    g_loss = torch.mean((outputs - 1) ** 2)
                    # Backprop + optimize
                    self.reset_grad()
                    g_loss.backward()
                    self.g_optimizer.step()

                    g_losses.append("%.3f" % g_loss.item())
                # print the log info
                if (i+1) % self.log_step == 0:
                    print('Epoch [%d/%d], Step[%d/%d], d_real_loss: %.4f, '
                          'd_fake_loss: %.4f, gp_loss: %s, g_loss: %s'
                          %(epoch+1, self.num_epochs, i+1, total_step,
                            real_loss.item(), fake_loss.item(), gp_loss.item(), ", ".join(g_losses)))

                # save the sampled images
                if i % self.sample_step == 0:
                    print("saving samples")
                    # Inference only: no autograd graph is needed.
                    with torch.no_grad():
                        fake_images = self.generator(fixed_noise)
                    os.makedirs(self.sample_path, exist_ok=True)
                    torchvision.utils.save_image(self.denorm(fake_images.data),
                        os.path.join(self.sample_path,
                                     'fake_samples-%d-%d.png' %(epoch+1, i+1)))

            # save the model parameters every 5 epochs
            if epoch % 5 == 0:
                # makedirs: also creates missing parents and tolerates races.
                os.makedirs(self.model_path, exist_ok=True)
                g_path = os.path.join(self.model_path, 'generator-%d.pkl' %(epoch+1))
                d_path = os.path.join(self.model_path, 'discriminator-%d.pkl' %(epoch+1))
                g_optim_path = os.path.join(self.model_path, 'gen-optim-%d.pkl' % (epoch + 1))
                d_optim_path = os.path.join(self.model_path, 'dis-optim-%d.pkl' % (epoch + 1))
                torch.save(self.generator.state_dict(), g_path)
                torch.save(self.discriminator.state_dict(), d_path)
                torch.save(self.g_optimizer.state_dict(), g_optim_path)
                torch.save(self.d_optimizer.state_dict(), d_optim_path)

    def sample(self):
        """Load the ``num_epochs`` checkpoint and save a grid of generated images."""
        # Load trained parameters
        g_path = os.path.join(self.model_path, 'generator-%d.pkl' % self.num_epochs)
        d_path = os.path.join(self.model_path, 'discriminator-%d.pkl' % self.num_epochs)
        self.generator.load_state_dict(torch.load(g_path))
        self.discriminator.load_state_dict(torch.load(d_path))
        self.generator.eval()
        self.discriminator.eval()

        # Sample the images
        noise = self.to_variable(torch.randn(self.sample_size, self.z_dim))
        with torch.no_grad():
            fake_images = self.generator(noise)
        os.makedirs(self.sample_path, exist_ok=True)
        sample_path = os.path.join(self.sample_path, 'fake_samples-final.png')
        torchvision.utils.save_image(self.denorm(fake_images.data), sample_path, nrow=12)

        print("Saved sampled images to '%s'" %sample_path)
Exemple #20
0
                                  batch_size=batch_size,
                                  shuffle=True)
test_dataset_loader = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 shuffle=False)

generator_A2B = Generator(24).to("cuda")
generator_B2A = Generator(24).to("cuda")
discriminator_A = Discriminator(1).to("cuda")
discriminator_B = Discriminator(1).to("cuda")

generator_params = [generator_A2B.parameters(), generator_B2A.parameters()]
generator_optimizer = torch.optim.Adam(itertools.chain(*generator_params),
                                       lr=generator_lr)
discriminator_params = [
    discriminator_A.parameters(),
    discriminator_B.parameters()
]
discriminator_optimizer = torch.optim.Adam(
    itertools.chain(*discriminator_params), lr=discriminator_lr)

for epoch in range(num_epochs):
    print("Epoch ", epoch)
    for i, sample in enumerate(train_dataset_loader):

        # Learning rate adjustment snippet
        # TODO: len(train_dataset_loader) or len(train_dataset)
        num_iterations = (len(train_dataset_loader) // batch_size) * epoch + i
        if num_iterations > 10000:
            lambda_identity = 0
        if num_iterations > start_decay:
Exemple #21
0
class Model():
    """Adversarial trainer pairing a generator ``G`` with a discriminator ``D``.

    ``G`` maps ``batch['low_res']`` images towards ``batch['high_res']``;
    ``D`` scores images as real (target 1) or generated (target 0) with BCE.
    The generator objective combines a pixel MSE, an MSE between frozen
    VGG-19 feature maps ("content" loss, weight 0.006) and the adversarial
    BCE term (weight 1e-3).

    Relies on the module-level names ``device``, ``log_message``,
    ``Generator``, ``Discriminator`` and ``models`` (torchvision).

    ``args`` attributes read: ``load_model``, ``epochs``, ``eval_epochs``,
    ``save_epochs``, ``pretrain_epochs``, ``save_dir``, ``generate_dir``.
    """

    # Per-channel normalization applied before the pretrained VGG-19
    # (torchvision's ImageNet mean/std).
    _VGG_MEAN = (0.485, 0.456, 0.406)
    _VGG_STD = (0.229, 0.224, 0.225)

    def __init__(self, args):
        self.args = args

        self.pretrained = False  # set once _pretrain() has run; checkpointed
        self.epoch = 0
        self.G = Generator()
        self.D = Discriminator()
        self.g_optimizer = optim.Adam(self.G.parameters(), lr=1E-4)
        self.d_optimizer = optim.Adam(self.D.parameters(), lr=1E-4)
        # StepLR with its default gamma (0.1); train() steps both schedulers
        # once per epoch, so the LR is divided by 10 every 40 epochs.
        self.g_scheduler = optim.lr_scheduler.StepLR(self.g_optimizer,
                                                     step_size=40)
        self.d_scheduler = optim.lr_scheduler.StepLR(self.d_optimizer,
                                                     step_size=40)
        self.train_losses = []  # one [g_loss, d_loss] pair per epoch
        self.val_losses = []  # one [g_loss, d_loss] pair per evaluation

        if args.load_model:
            self._load_state(args.load_model)

        # VGG-19 `features` minus its final max-pool (layers 0..35, ending at
        # the ReLU after the last conv); frozen, used only for the content loss.
        vgg19_layers = list(models.vgg19(pretrained=True).features)[:36]
        self.vgg19 = nn.Sequential(*vgg19_layers).eval()
        for param in self.vgg19.parameters():
            param.requires_grad = False

        self.mse_loss = torch.nn.MSELoss()
        self.bce_loss = torch.nn.BCELoss()

    def train(self, train_dataloader, val_dataloader=None):
        """Pretrain G if needed, then train G and D adversarially.

        Runs until ``self.epoch`` reaches ``args.epochs`` (a resumed checkpoint
        continues where it stopped). Every ``args.eval_epochs`` epochs the
        validation losses are logged when ``val_dataloader`` is given. A
        checkpoint is written after pretraining, every ``args.save_epochs``
        epochs, and once more at the end.
        """
        self.D.to(device)
        self.G.to(device)
        self.vgg19.to(device)

        # Phase 1: pixel-loss pretraining of the generator (runs once; the
        # `pretrained` flag is restored from checkpoints).
        if not self.pretrained:
            log_message("Starting pretraining")
            self._pretrain(train_dataloader)
            self._save_state()

            if val_dataloader:
                val_g_loss, _ = self.evaluate(val_dataloader)
                log_message("Pretrain G loss: {:.4f}".format(val_g_loss))

        # Phase 2: adversarial training.
        log_message("Starting training")
        while self.epoch < self.args.epochs:
            self.D.train()
            self.G.train()

            g_loss, d_loss = self._run_epoch(train_dataloader, train=True)

            self.train_losses.append([g_loss, d_loss])
            self.g_scheduler.step()
            self.d_scheduler.step()
            self.epoch += 1
            log_message("Epoch: {}/{}".format(self.epoch, self.args.epochs))

            train_string = "Train G loss: {:.4f} | Train D loss: {:.4f}".format(
                g_loss, d_loss)
            if self.epoch % self.args.eval_epochs == 0 and val_dataloader:
                val_g_loss, val_d_loss = self.evaluate(val_dataloader)
                self.val_losses.append([val_g_loss, val_d_loss])
                train_string += " | Val G loss: {:.4f} | Val D loss: {:.4f}".format(
                    val_g_loss, val_d_loss)
            log_message(train_string)

            if self.epoch % self.args.save_epochs == 0:
                self._save_state()

        log_message("Finished training")
        self._save_state()

    def evaluate(self, dataloader):
        """Return mean (G loss, D loss) over ``dataloader`` without training."""
        self.D.eval()
        self.G.eval()

        with torch.no_grad():
            return self._run_epoch(dataloader, train=False)

    def generate(self, dataloader, max_images=12):
        """Write side-by-side sample images into ``args.generate_dir``.

        For each of the first ``max_images`` samples drawn from ``dataloader``
        writes ``{k}_naive.png`` (bicubic upscale of the low-res input to the
        target's size), ``{k}_fake.png`` (G's output) and ``{k}_real.png``
        (the high-res target), with ``k`` counting samples across batches so
        that small batches do not overwrite each other.
        """
        def to_image(tensor):
            # assumes a (C, H, W) tensor with values in [-1, 1]
            array = tensor.data.cpu().numpy()
            array = array.transpose((1, 2, 0))
            array = np.clip(255.0 * (array + 1) / 2, 0, 255)
            return Image.fromarray(np.uint8(array))

        self.D.eval()
        self.G.eval()

        os.makedirs(self.args.generate_dir, exist_ok=True)
        if max_images <= 0:
            return

        num_saved = 0
        with torch.no_grad():
            for batch in dataloader:
                low_res = batch['low_res'].to(device)
                hi_res = batch['high_res']
                generated = self.G(low_res)

                for i in range(len(generated)):
                    fake_im = to_image(generated[i])
                    real_im = to_image(hi_res[i])

                    # NOTE(review): the low-res input is scaled as 255 * x
                    # (treated as [0, 1]) while to_image assumes [-1, 1] —
                    # confirm the input value range.
                    naive = np.clip(
                        255.0 * low_res[i].data.cpu().numpy().transpose(
                            (1, 2, 0)), 0, 255)
                    naive = Image.fromarray(np.uint8(naive))
                    # Upscale to the real image's size so the three are comparable.
                    naive = naive.resize(real_im.size, Image.BICUBIC)

                    naive.save(
                        os.path.join(self.args.generate_dir,
                                     "{}_naive.png".format(num_saved)))
                    fake_im.save(
                        os.path.join(self.args.generate_dir,
                                     "{}_fake.png".format(num_saved)))
                    real_im.save(
                        os.path.join(self.args.generate_dir,
                                     "{}_real.png".format(num_saved)))

                    num_saved += 1
                    if num_saved >= max_images:
                        return

    def _load_state(self, fname):
        """Restore models, optimizers, schedulers and bookkeeping from ``fname``.

        Expects a checkpoint written by ``_save_state``. Tensors are loaded onto
        the GPU when CUDA is available, else the CPU; optimizer state tensors
        are then moved to the module-level ``device``.
        NOTE: ``torch.load`` unpickles the file — only load trusted checkpoints.
        """
        if torch.cuda.is_available():
            map_location = lambda storage, loc: storage.cuda()
        else:
            map_location = 'cpu'
        state = torch.load(fname, map_location=map_location)

        self.pretrained = state["pretrained"]
        self.epoch = state["epoch"]
        self.train_losses = state["train_losses"]
        self.val_losses = state["val_losses"]
        self.G.load_state_dict(state["G"])
        self.D.load_state_dict(state["D"])
        self.g_optimizer.load_state_dict(state["g_optimizer"])
        self.d_optimizer.load_state_dict(state["d_optimizer"])
        self.g_scheduler.load_state_dict(state["g_scheduler"])
        self.d_scheduler.load_state_dict(state["d_scheduler"])

        # Move per-parameter optimizer state (e.g. Adam moments) to `device`,
        # presumably where the model parameters end up.
        for optimizer in (self.d_optimizer, self.g_optimizer):
            for param_state in optimizer.state.values():
                for key, value in param_state.items():
                    if torch.is_tensor(value):
                        param_state[key] = value.to(device)

    def _save_state(self):
        """Write a checkpoint to ``<args.save_dir>/save_<epoch>.pkl``.

        The directory (including missing parents) is created if needed.
        """
        os.makedirs(self.args.save_dir, exist_ok=True)

        fname = os.path.join(self.args.save_dir, "save_%d.pkl" % self.epoch)
        state = {
            "pretrained": self.pretrained,
            "epoch": self.epoch,
            "G": self.G.state_dict(),
            "D": self.D.state_dict(),
            "g_optimizer": self.g_optimizer.state_dict(),
            "d_optimizer": self.d_optimizer.state_dict(),
            "g_scheduler": self.g_scheduler.state_dict(),
            "d_scheduler": self.d_scheduler.state_dict(),
            "train_losses": self.train_losses,
            "val_losses": self.val_losses
        }
        torch.save(state, fname)

    def _pretrain(self, dataloader):
        """Train G alone on pixel MSE for ``args.pretrain_epochs`` epochs."""
        self.G.train()
        for i in range(self.args.pretrain_epochs):
            log_message("Pretrain Epoch: {}/{}".format(
                i, self.args.pretrain_epochs))
            for batch in dataloader:
                low_res = batch['low_res'].to(device)
                high_res = batch['high_res'].to(device)

                self.g_optimizer.zero_grad()
                generated = self.G(low_res)
                g_loss = self.mse_loss(generated, high_res)
                g_loss.backward()
                self.g_optimizer.step()

        self.pretrained = True

    def _vgg_features(self, image):
        """Map a batch of images to frozen VGG-19 feature maps.

        Assumes ``image`` is (B, 3, H, W) with values in [-1, 1]; it is rescaled
        to [0, 1] and normalized with ``_VGG_MEAN`` / ``_VGG_STD``.
        """
        mean = torch.tensor(self._VGG_MEAN, device=image.device).view(1, 3, 1, 1)
        std = torch.tensor(self._VGG_STD, device=image.device).view(1, 3, 1, 1)
        image = (image + 1) / 2
        return self.vgg19((image - mean) / std)

    def _run_epoch(self, dataloader, train):
        """Run one pass over ``dataloader`` and return mean (G loss, D loss).

        When ``train`` is true, D and then G take one optimizer step per batch;
        otherwise losses are only computed (``evaluate`` wraps this call in
        ``torch.no_grad()``). An empty dataloader yields ``nan`` means.
        """
        g_losses, d_losses = [], []

        for batch in dataloader:
            low_res = batch['low_res'].to(device)
            high_res = batch['high_res'].to(device)

            batch_size = high_res.size(0)
            real = torch.ones((batch_size, 1), device=device)
            fake = torch.zeros((batch_size, 1), device=device)

            # --- Discriminator ---
            # G's output is only D's "fake" input here, so it is produced
            # without autograd tracking: d_loss backpropagates into D alone
            # rather than also through G (those G gradients were discarded by
            # g_optimizer.zero_grad() anyway).
            with torch.no_grad():
                generated = self.G(low_res)
            self.d_optimizer.zero_grad()

            real_loss = self.bce_loss(self.D(high_res), real)
            fake_loss = self.bce_loss(self.D(generated), fake)
            d_loss = real_loss + fake_loss
            d_losses.append(d_loss.item())

            if train:
                d_loss.backward()
                self.d_optimizer.step()

            # --- Generator ---
            # Fresh, grad-tracked forward pass scored by the just-updated D.
            generated = self.G(low_res)
            self.g_optimizer.zero_grad()

            pixel_loss = self.mse_loss(high_res, generated)
            content_loss = self.mse_loss(self._vgg_features(high_res),
                                         self._vgg_features(generated))
            adversarial_loss = self.bce_loss(self.D(generated), real)
            g_loss = pixel_loss + 0.006 * content_loss + 1E-3 * adversarial_loss
            g_losses.append(g_loss.item())

            if train:
                g_loss.backward()
                self.g_optimizer.step()

        return np.mean(g_losses), np.mean(d_losses)
Exemple #22
0
class BiGAN(nn.Module):
    """Bidirectional GAN: generator ``_G``, encoder ``_E``, joint discriminator ``_D``.

    ``_D`` scores (image, latent) pairs. Pairs from the encoder path (real
    image, latent inferred by ``_E``) get target 1 and pairs from the
    generator path (``_G``'s image, sampled latent) get target 0, using an
    MSE (least-squares) adversarial criterion. ``_G`` and ``_E`` are
    optimized jointly to invert those targets.

    Networks, loss, optimizers and schedulers are only built when
    ``config.work_type == 'train'``. Relies on the module-level names
    ``Generator``, ``Encoder``, ``Discriminator``, ``lr_scheduler`` and
    ``vutils``.
    """

    def __init__(self, config):
        super(BiGAN, self).__init__()

        self._work_type = config.work_type
        self._epochs = config.epochs
        self._batch_size = config.batch_size

        self._encoder_lr = config.encoder_lr
        self._generator_lr = config.generator_lr
        self._discriminator_lr = config.discriminator_lr
        self._latent_dim = config.latent_dim
        self._weight_decay = config.weight_decay

        self._img_shape = (config.input_size, config.input_size)
        self._img_save_path = config.image_save_path
        self._model_save_path = config.model_save_path
        self._device = config.device

        if self._work_type == 'train':
            # Least-squares adversarial criterion.
            self._adversarial_criterion = torch.nn.MSELoss()

            self._G = Generator(self._latent_dim, self._img_shape).to(self._device)
            self._E = Encoder(self._latent_dim, self._img_shape).to(self._device)
            self._D = Discriminator(self._latent_dim, self._img_shape).to(self._device)

            self._G.apply(self.weights_init)
            self._E.apply(self.weights_init)
            self._D.apply(self.discriminator_weights_init)

            # G and E are trained jointly by one optimizer (two param groups
            # sharing lr/betas/weight_decay).
            # NOTE(review): config.encoder_lr is stored but never used — the
            # encoder trains at generator_lr. Confirm this is intended.
            self._G_optimizer = torch.optim.Adam(
                [{'params': self._G.parameters()}, {'params': self._E.parameters()}],
                lr=self._generator_lr, betas=(0.5, 0.999),
                weight_decay=self._weight_decay)
            self._D_optimizer = torch.optim.Adam(
                self._D.parameters(), lr=self._discriminator_lr, betas=(0.5, 0.999))

            # Both learning rates are multiplied by 0.99 after every epoch.
            self._G_scheduler = lr_scheduler.ExponentialLR(self._G_optimizer, gamma=0.99)
            self._D_scheduler = lr_scheduler.ExponentialLR(self._D_optimizer, gamma=0.99)

    def _sample_latent(self, n):
        """Draw ``n`` standard-normal latent vectors (NumPy global RNG) as float32 on the device."""
        noise = np.random.normal(0, 1, (n, self._latent_dim))
        return torch.as_tensor(noise, dtype=torch.float32, device=self._device)

    def train(self, train_loader):
        """Train G/E and D for ``self._epochs`` epochs over ``train_loader``.

        ``train_loader`` must yield ``(images, labels)`` pairs; labels are
        ignored. Losses are printed every 100 steps, sample images are saved
        every 400 steps, and G/E/D weights every 100 epochs.

        This method shadows ``nn.Module.train(mode)``, which ``nn.Module.eval()``
        calls with a bool; a bool argument is therefore forwarded to
        ``nn.Module.train`` so mode switching keeps working.
        """
        if isinstance(train_loader, bool):
            return super().train(train_loader)

        criterion = self._adversarial_criterion
        n_total_steps = len(train_loader)
        for epoch in range(self._epochs):
            for i, (images, _) in enumerate(train_loader):
                # Adversarial ground truths.
                valid = torch.ones(images.size(0), 1, device=self._device)
                fake = torch.zeros(images.size(0), 1, device=self._device)

                # Flatten each image into a vector before encoding.
                images = images.reshape(-1, int(np.prod(self._img_shape))).to(self._device)

                # ---------------------
                # Train generator + encoder
                # ---------------------
                # z_ is the latent vector the encoder infers for the real images.
                original_img, z_ = self._E(images)
                predict_encoder = self._D(original_img, z_)

                gen_img, z = self._G(self._sample_latent(images.shape[0]))
                predict_generator = self._D(gen_img, z)

                # G/E try to make D swap its targets on both paths.
                G_loss = (criterion(predict_generator, valid)
                          + criterion(predict_encoder, fake)) * 0.5

                self._G_optimizer.zero_grad()
                G_loss.backward()
                self._G_optimizer.step()

                # ---------------------
                # Train discriminator
                # ---------------------
                # Fresh G/E outputs computed without autograd tracking, so
                # D_loss backpropagates into D only.
                with torch.no_grad():
                    gen_img, z = self._G(self._sample_latent(images.shape[0]))
                    original_img, z_ = self._E(images)
                predict_encoder = self._D(original_img, z_)
                predict_generator = self._D(gen_img, z)

                D_loss = (criterion(predict_encoder, valid)
                          + criterion(predict_generator, fake)) * 0.5

                self._D_optimizer.zero_grad()
                D_loss.backward()
                self._D_optimizer.step()

                if i % 100 == 0:
                    print(f'Epoch [{epoch+1}/{self._epochs}], Step [{i+1}/{n_total_steps}]')
                    print(f'Generator Loss: {G_loss.item():.4f} Discriminator Loss: {D_loss.item():.4f}')

                if i % 400 == 0:
                    vutils.save_image(gen_img.unsqueeze(1).cpu().data[:64, ], f'{self._img_save_path}/E{epoch}_Iteration{i}_fake.png')
                    vutils.save_image(original_img.unsqueeze(1).cpu().data[:64, ], f'{self._img_save_path}/E{epoch}_Iteration{i}_real.png')
                    print('image saved')
                    print('')

            # Decay learning rates after this epoch's optimizer steps; stepping
            # the schedulers before any optimizer.step() skips the initial LR.
            self._G_scheduler.step()
            self._D_scheduler.step()

            if epoch % 100 == 0:
                torch.save(self._G.state_dict(), f'{self._model_save_path}/netG_{epoch}epoch.pth')
                torch.save(self._E.state_dict(), f'{self._model_save_path}/netE_{epoch}epoch.pth')
                torch.save(self._D.state_dict(), f'{self._model_save_path}/netD_{epoch}epoch.pth')

    @staticmethod
    def _init_layer(m, bn_std):
        """Zero Linear/BatchNorm biases and draw BatchNorm weights from N(1, bn_std).

        Layers without those parameters (``bias=False``, ``affine=False``)
        are left untouched; Linear weights keep PyTorch's default init.
        """
        classname = m.__class__.__name__
        bias = getattr(m, 'bias', None)
        if 'BatchNorm' in classname:
            weight = getattr(m, 'weight', None)
            if weight is not None:
                weight.data.normal_(1.0, bn_std)
            if bias is not None:
                bias.data.fill_(0)
        elif 'Linear' in classname:
            if bias is not None:
                bias.data.fill_(0)

    def weights_init(self, m):
        """``Module.apply`` initializer for G and E (BatchNorm weight std 0.02)."""
        self._init_layer(m, 0.02)

    def discriminator_weights_init(self, m):
        """``Module.apply`` initializer for D (BatchNorm weight std 0.5)."""
        self._init_layer(m, 0.5)