Ejemplo n.º 1
0
    def __init__(self, args):
        """Set up the toy trainer: fetch the data loaders, record the
        latent dimension on ``args``, then delegate to the base trainer."""
        loaders = hp.get_data_loader(args, args.trainer.b_size,
                                     args.system.num_workers)
        (self.train_loader, self.test_loader,
         self.valid_loader, self.input_dims) = loaders
        # For this toy setup the latent dimension equals the data dimension.
        args.Z_dim = self.input_dims

        self.args = args
        super(TrainerToy, self).__init__(args)
Ejemplo n.º 2
0
    def __init__(self, args):
        """Set up the EBM trainer: load data, record the latent dimension,
        run the base constructor, then optionally combine the discriminator
        with the generator and build a contrastive-divergence sampler."""
        self.args = args
        data = hp.get_data_loader(self.args, self.args.b_size,
                                  self.args.num_workers)
        (self.train_loader, self.test_loader,
         self.valid_loader, self.input_dims) = data
        args.Z_dim = int(self.input_dims)
        # Number of training examples, read straight off the dataset tensor.
        self.dataset_size = int(self.train_loader.dataset.X.shape[0])
        super(TrainerEBM, self).__init__(args)

        if self.args.combined_discriminator:
            # Fold the generator into the energy model's discriminator.
            self.discriminator = models.energy_model.CombinedDiscriminator(
                self.discriminator, self.generator)

        if self.args.criterion == 'cd':
            # Contrastive divergence drives its own latent-space sampler.
            latent_sampler = get_latent_sampler(self.args, self.discriminator,
                                                self.args.Z_dim, self.device)
            self.cd_sampler = samplers.ContrastiveDivergenceSampler(
                self.noise_gen, latent_sampler)
Ejemplo n.º 3
0
    def build_model(self):
        """Construct every component of the trainer from ``self.args``.

        Builds: data loaders, generator, discriminator (energy model),
        latent noise source, fixed evaluation latents/velocities,
        optimizers and schedulers (train mode only), the latent potential
        and sampler, optional FID evaluation machinery, DataParallel
        wrapping, and bookkeeping attributes. Mutates ``self`` only.
        """
        # Data loaders and the input dimensionality of the dataset.
        self.train_loader, self.test_loader, self.valid_loader, self.input_dims = hp.get_data_loader(
            self.args, self.args.b_size, self.args.num_workers)

        # Core networks plus the prior over latent codes.
        self.generator = hp.get_base(self.args, self.input_dims, self.device)
        self.discriminator = hp.get_energy(self.args, self.input_dims,
                                           self.device)
        self.noise_gen = hp.get_latent_noise(self.args, self.args.Z_dim,
                                             self.device)
        # A fixed batch of 64 latents, e.g. for reproducible sample grids.
        self.fixed_latents = self.noise_gen.sample([64])
        # Pre-sample evaluation latents in chunks of sample_b_size; the
        # "+ 1" over-shoots, and the slice below trims to exactly
        # fid_samples rows.
        self.eval_latents = torch.cat([
            self.noise_gen.sample([self.args.sample_b_size]).cpu()
            for b in range(
                int(self.args.fid_samples / self.args.sample_b_size) + 1)
        ],
                                      dim=0)
        self.eval_latents = self.eval_latents[:self.args.fid_samples]
        # Matching zero-initialized velocities (same chunking and trim).
        self.eval_velocity = torch.cat([
            torch.zeros([self.args.sample_b_size, self.eval_latents.shape[1]
                         ]).cpu()
            for b in range(
                int(self.args.fid_samples / self.args.sample_b_size) + 1)
        ],
                                       dim=0)
        self.eval_velocity = self.eval_velocity[:self.args.fid_samples]
        # load models if path exists, define log partition if using kale and add to discriminator
        self.d_params = list(
            filter(lambda p: p.requires_grad, self.discriminator.parameters()))
        if self.args.g_path is not None:
            self.load_generator()
            self.generator.eval()
        if self.args.d_path is not None:
            self.load_discriminator()
            self.discriminator.eval()

        # NOTE(review): this `else` pairs with the d_path check above, so
        # when a discriminator checkpoint IS given, self.log_partition is
        # never assigned here, yet the train branch below builds
        # optim_partition from it — presumably load_discriminator()
        # restores it; confirm.
        else:
            if self.args.criterion == 'kale':
                # KALE learns the log-partition jointly with the energy.
                self.log_partition = nn.Parameter(
                    torch.zeros(1).to(self.device))
                self.d_params.append(self.log_partition)
            else:
                # Otherwise keep a constant, non-trainable zero.
                self.log_partition = Variable(
                    torch.zeros(1, requires_grad=False)).to(self.device)

        if self.mode == 'train':
            # optimizers
            self.optim_d = hp.get_optimizer(self.args, 'discriminator',
                                            self.d_params)
            self.optim_g = hp.get_optimizer(self.args, 'generator',
                                            self.generator.parameters())
            self.optim_partition = hp.get_optimizer(self.args, 'discriminator',
                                                    [self.log_partition])
            # schedulers
            self.scheduler_d = hp.get_scheduler(self.args, self.optim_d)
            self.scheduler_g = hp.get_scheduler(self.args, self.optim_g)
            self.scheduler_partition = hp.get_scheduler(
                self.args, self.optim_partition)
            self.loss = hp.get_loss(self.args)

            # Step counters and running losses for logging.
            self.counter = 0
            self.g_counter = 0
            self.g_loss = torch.tensor(0.)
            self.d_loss = torch.tensor(0.)

        # Latent-space potential for MCMC sampling; the variant depends on
        # which sampler was requested in the config.
        if self.args.latent_sampler in ['imh', 'dot', 'spherelangevin']:
            self.latent_potential = samplers.Independent_Latent_potential(
                self.generator, self.discriminator, self.noise_gen)
        elif self.args.latent_sampler in ['zero_temperature_langevin']:
            self.latent_potential = samplers.Cold_Latent_potential(
                self.generator, self.discriminator)
        else:
            self.latent_potential = samplers.Latent_potential(
                self.generator, self.discriminator, self.noise_gen,
                self.args.temperature)

        self.latent_sampler = hp.get_latent_sampler(self.args,
                                                    self.latent_potential,
                                                    self.args.Z_dim,
                                                    self.device)
        if self.args.eval_fid:
            self.eval_fid = True
            print('==> Loading inception network...')
            block_idx = cp.InceptionV3.BLOCK_INDEX_BY_DIM[2048]
            self.fid_model = cp.InceptionV3([block_idx]).to(self.device)
            # NOTE(review): the FIDScheduler built on the next line is
            # immediately overwritten by the MMDScheduler — the first
            # assignment looks like dead code or a leftover toggle; confirm
            # which scheduler is intended.
            self.fid_scheduler = FIDScheduler(self.args)
            self.fid_scheduler = MMDScheduler(self.args, self.device)
            self.fid_scheduler.init_trainer(self)
            self.fid_train = -1.
        else:
            self.eval_fid = False

        # Wrap both networks in DataParallel when multiple GPUs are
        # available and data parallelism was requested.
        dev_count = torch.cuda.device_count()
        if self.args.dataparallel and dev_count > 1:
            self.generator = torch.nn.DataParallel(self.generator,
                                                   device_ids=list(
                                                       range(dev_count)))
            self.discriminator = torch.nn.DataParallel(self.discriminator,
                                                       device_ids=list(
                                                           range(dev_count)))
        # Bookkeeping: accumulated losses and cached evaluation statistics
        # (filled in later during training/evaluation).
        self.accum_loss_g = []
        self.accum_loss_d = []
        self.true_train_scores = None
        self.true_valid_scores = None
        self.true_train_mu = None
        self.true_train_sigma = None
        self.true_valid_mu = None
        self.true_valid_sigma = None
        self.kids = None
Ejemplo n.º 4
0
def main(batch_size, epochs, learning_rate, beta1, beta2, data_path,
         num_workers):
    """Train a CycleGAN on the summer/winter image domains and plot losses.

    Args:
        batch_size: mini-batch size for both domain loaders.
        epochs: number of training epochs passed to ``training_loop``.
        learning_rate, beta1, beta2: Adam hyper-parameters shared by all
            three optimizers.
        data_path: root directory containing the per-domain image folders.
        num_workers: accepted for interface compatibility; unused here.
    """
    # Create train and test dataloaders for images from the two domains X and Y
    # image_type = directory names for our data
    dataloader_X, test_dataloader_X = get_data_loader(image_type='summer',
                                                      image_dir=data_path,
                                                      batch_size=batch_size)
    dataloader_Y, test_dataloader_Y = get_data_loader(image_type='winter',
                                                      image_dir=data_path,
                                                      batch_size=batch_size)

    # call the function to get models
    G_XtoY, G_YtoX, D_X, D_Y = create_model()

    # print all of the models
    print_models(G_XtoY, G_YtoX, D_X, D_Y)

    # Both generators share a single optimizer, so pool their parameters.
    g_params = list(G_XtoY.parameters()) + list(G_YtoX.parameters())

    # Create optimizers for the generators and discriminators
    g_optimizer = optim.Adam(g_params, learning_rate, [beta1, beta2])
    d_x_optimizer = optim.Adam(D_X.parameters(), learning_rate, [beta1, beta2])
    d_y_optimizer = optim.Adam(D_Y.parameters(), learning_rate, [beta1, beta2])

    # train the network
    losses = training_loop(G_XtoY,
                           G_YtoX,
                           D_X,
                           D_Y,
                           g_optimizer,
                           d_x_optimizer,
                           d_y_optimizer,
                           dataloader_X,
                           dataloader_Y,
                           test_dataloader_X,
                           test_dataloader_Y,
                           epochs=epochs)

    # Plot the three loss curves (rows of the transposed loss array).
    fig, ax = plt.subplots(figsize=(12, 8))
    losses = np.array(losses)
    print(losses)
    plt.plot(losses.T[0], label='Discriminator, X', alpha=0.5)
    plt.plot(losses.T[1], label='Discriminator, Y', alpha=0.5)
    plt.plot(losses.T[2], label='Generators', alpha=0.5)
    plt.title("Training Losses")
    plt.legend()

    import matplotlib.image as mpimg

    # helper visualization code
    def view_samples(iteration, sample_dir='samples_cyclegan'):
        """Show the X->Y and Y->X sample images saved at *iteration*."""
        # samples are named by iteration
        path_XtoY = os.path.join(sample_dir,
                                 'sample-{:06d}-X-Y.png'.format(iteration))
        path_YtoX = os.path.join(sample_dir,
                                 'sample-{:06d}-Y-X.png'.format(iteration))

        # Read in those samples. The original bare ``except`` swallowed the
        # failure and then crashed with a NameError on the undefined images;
        # catch only file-read errors and bail out early instead.
        try:
            x2y = mpimg.imread(path_XtoY)
            y2x = mpimg.imread(path_YtoX)
        except OSError:
            print('Invalid number of iterations.')
            return

        fig, (ax1, ax2) = plt.subplots(figsize=(18, 20),
                                       nrows=2,
                                       ncols=1,
                                       sharey=True,
                                       sharex=True)
        ax1.imshow(x2y)
        ax1.set_title('X to Y')
        ax2.imshow(y2x)
        ax2.set_title('Y to X')

    # view the samples saved at iteration 1
    view_samples(1, 'samples_cyclegan')