Example #1
def test_and_sample(cfg, model, device, test_loader, height, width, losses,
                    params, epoch):
    test_loss = 0

    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            normalized_images = images.float() / (cfg.color_levels - 1)
            outputs = model(normalized_images, labels)

            # Accumulate a scalar per batch; accumulating unreduced tensors
            # breaks when the final batch is smaller than the rest.
            test_loss += F.cross_entropy(outputs, images, reduction='mean').item()

    test_loss /= len(test_loader)

    wandb.log({"Test loss": test_loss})
    print("Average test loss: {}".format(test_loss))

    losses.append(test_loss)
    params.append(model.state_dict())

    samples = model.sample((3, height, width),
                           cfg.epoch_samples,
                           device=device)
    save_samples(samples, TRAIN_SAMPLES_DIR,
                 'epoch{}_samples.png'.format(epoch + 1))
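The snippet assumes images already holds integer level indices in [0, cfg.color_levels - 1]: the rescaled copy feeds the network, while the raw indices serve as the cross-entropy target. A minimal sketch of that quantization step (the 8-bit input range is an assumption):

import torch

def quantize(images_uint8, color_levels):
    # Map 8-bit pixel values to integer level indices in [0, color_levels - 1].
    levels = (images_uint8.long() * (color_levels - 1)) // 255
    network_input = levels.float() / (color_levels - 1)  # rescaled to [0, 1]
    return levels, network_input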
Example #2
    def _save_samples(self):
        rows, columns = 5, 5
        noise = np.random.normal(0, 1, (rows * columns, self._latent_dim))
        generated_transactions = self._generator.predict(noise)

        filenames = [
            self._img_dir + ('/%07d.png' % self._epoch),
            self._img_dir + '/last.png'
        ]
        utils.save_samples(generated_transactions, rows, columns, filenames)
Example #3
def train(data_root, model, total_epoch, batch_size, lrate):

    X, Z, Lr = model.inputs()
    d_loss, g_loss = model.loss(X, Z)
    d_opt, g_opt = model.optimizer(d_loss, g_loss, Lr)
    g_sample = model.sample(Z)
    sample_size = batch_size
    test_noise = utils.get_noise(sample_size, n_noise)
    epoch_drop = 3

    iterator, image_count = ImageIterator(data_root, batch_size,
                                          model.image_size,
                                          model.image_channels).get_iterator()
    next_element = iterator.get_next()

    total_batch = int(image_count / batch_size)
    #learning_rate = lrate
    #G_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(iterator.initializer)

        for epoch in range(total_epoch):
            learning_rate = lrate * \
                            math.pow(0.2, math.floor((epoch + 1) / epoch_drop))
            for step in range(total_batch):
                batch_x = sess.run(next_element)
                batch_z = utils.get_noise(batch_size, n_noise)

                _, loss_val_D = sess.run([d_opt, d_loss],
                                         feed_dict={
                                             X: batch_x,
                                             Z: batch_z,
                                             Lr: learning_rate
                                         })
                _, loss_val_G = sess.run([g_opt, g_loss],
                                         feed_dict={
                                             Z: batch_z,
                                             Lr: learning_rate
                                         })

                if step % 300 == 0:
                    #sample_size = 10
                    #noise = get_noise(sample_size, n_noise)
                    samples = sess.run(g_sample, feed_dict={Z: test_noise})
                    title = 'samples/%05d_%05d.png' % (epoch, step)
                    utils.save_samples(title, samples)

                    print('Epoch:', '%04d' % epoch,
                          '%05d/%05d' % (step, total_batch),
                          'D loss: {:.4}'.format(loss_val_D),
                          'G loss: {:.4}'.format(loss_val_G))
            saver.save(sess, './models/dcgan', global_step=epoch)
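The schedule above cuts the learning rate by a factor of 5 every epoch_drop epochs. A standalone sketch of the same decay, with illustrative numbers:

import math

def decayed_lr(lrate, epoch, epoch_drop=3, factor=0.2):
    # lr is multiplied by `factor` once per `epoch_drop` epochs
    return lrate * math.pow(factor, math.floor((epoch + 1) / epoch_drop))

# e.g. lrate=1e-3: epochs 0-1 -> 1e-3, epochs 2-4 -> 2e-4, epochs 5-7 -> 4e-5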
Example #4
def main():
    parser = argparse.ArgumentParser(description='PixelCNN')

    parser.add_argument('--causal-ksize', type=int, default=7,
                        help='Kernel size of causal convolution')
    parser.add_argument('--hidden-ksize', type=int, default=7,
                        help='Kernel size of hidden layers convolutions')

    parser.add_argument('--color-levels', type=int, default=2,
                        help='Number of levels to quantisize value of each channel of each pixel into')

    parser.add_argument('--hidden-fmaps', type=int, default=30,
                        help='Number of feature maps in hidden layer')
    parser.add_argument('--out-hidden-fmaps', type=int, default=10,
                        help='Number of feature maps in outer hidden layer')
    parser.add_argument('--hidden-layers', type=int, default=6,
                        help='Number of layers of gated convolutions with mask of type "B"')

    parser.add_argument('--cuda', type=str2bool, default=True,
                        help='Flag indicating whether CUDA should be used')
    parser.add_argument('--model-path', '-m',
                        help="Path to model's saved parameters")
    parser.add_argument('--output-fname', type=str, default='samples.png',
                        help='Name of output file (.png format)')

    parser.add_argument('--label', '--l', type=int, default=-1,
                        help='Label of sampled images. -1 indicates random labels.')

    parser.add_argument('--count', '-c', type=int, default=64,
                        help='Number of images to generate')
    parser.add_argument('--height', type=int, default=28, help='Output image height')
    parser.add_argument('--width', type=int, default=28, help='Output image width')

    cfg = parser.parse_args()
    OUTPUT_FILENAME = cfg.output_fname

    model = PixelCNN(cfg=cfg)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() and cfg.cuda else "cpu")
    model.to(device)

    model.load_state_dict(torch.load(cfg.model_path, map_location=device))

    label = None if cfg.label == -1 else cfg.label
    samples = model.sample((3, cfg.height, cfg.width), cfg.count, label=label, device=device)
    save_samples(samples, OUTPUT_DIRNAME, OUTPUT_FILENAME)
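Assuming the script above is saved as sample.py (the file name is an assumption), generation could be invoked like:

python sample.py -m trained_model.pth -c 64 --height 28 --width 28 --label 3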
Example #5
    def train(self):
        def preprocess_func(x):
            return tf.cast(x, tf.float32)/127.5 - 1.

        logfile_path = os.path.join(self.results_dir, 'training_base_logfile.txt')
        if not os.path.exists(logfile_path):
            open(logfile_path, "w").close()  # create the log file once

        while True:
            for dataset_path in sorted(os.listdir(self.data_dir)):
                dataset = np.load(os.path.join(self.data_dir, dataset_path))
                dataset = tf.data.Dataset.from_tensor_slices(dataset) \
                    .shuffle(self.batch_size * 5) \
                    .batch(self.batch_size, drop_remainder=True) \
                    .map(preprocess_func)
                dataset = self.strategy.experimental_distribute_dataset(dataset)
                
                for x in dataset:
                    self.train_step(x)
                    
                    i = self.i
                    
                    if i % 25 == 0:  # update the moving average only every 25 steps, as it is slow
                        self._update_moving_avg()
                    if i % 1000 == 0:
                        loss, gradnorm = to_numpy(self.train_loss.result()), to_numpy(self.train_gradnorm.result())
                        result_str = "Iteration: %d, Time Elapsed: %0.1f,  Loss: %0.4f, gradnorm: %0.4f" % (i, time()-self.starttime, loss, gradnorm)
                        print(result_str)
                        logfile = open(logfile_path, "a")
                        logfile.write(result_str + "\n")
                        logfile.close()
                        self.train_loss.reset_states()
                        self.train_gradnorm.reset_states()
                    if i % 10000 == 0:
                        self.ema_model.set_weights(self.moving_avg_weights)
                        self.save_objects(i//1000)
                        _, samples_ema = self.generate_samples(64)
                        savepath = self._getp('samples_ema_{}k.jpg'.format(i//1000), folder='results')
                        save_samples(samples_ema, savepath)

                    if i >= self.max_iterations:
                        print("Training complete. Recover your models from the {} folder and your training results from the {} folder".format(self.model_dir, self.results_dir))
                        print("Training was done on a maximum of {} iterations".format(i))
                        return True

                    self.i += 1
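_update_moving_avg and moving_avg_weights are not shown in this listing. A minimal sketch of the exponential-moving-average bookkeeping they imply, assuming the live network is self.model and a 0.999 decay:

    def _update_moving_avg(self, decay=0.999):
        # Blend the current weights into the running average, tensor by tensor.
        self.moving_avg_weights = [
            decay * avg_w + (1.0 - decay) * w
            for avg_w, w in zip(self.moving_avg_weights, self.model.get_weights())
        ]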
Example #6
    def test(self):

        init = tf.global_variables_initializer()  # initialize_all_variables() is deprecated

        with tf.Session() as sess:
            sess.run(init)

            self.saver.restore(sess, self.model_path)
            sample_z = np.random.uniform(-1,
                                         1,
                                         size=[self.batch_size, self.z_dim])
            sample_exec_time = sample_output(self.batch_size)
            sample_feature = sess.run(self.fake_feature_vector,
                                      feed_dict={
                                          self.z: sample_z,
                                          self.y: sample_exec_time
                                      })
            save_samples(sample_feature, sample_exec_time)
            print("Test finish!")
Example #7
def test(ckpt_root, model, batch_size):

    X, Z, Lr, Kt = model.inputs()
    g_sample = model.sample(Z, reuse=False)
    sample_size = batch_size
    test_noise = utils.get_uniform_noise(sample_size, n_noise)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        #sess.run(tf.global_variables_initializer())
        saver.restore(sess, ckpt_root)

        samples = sess.run(g_sample, feed_dict={Z: test_noise})
        date = datetime.datetime.now()
        title = 'samples/%s.png' % date
        utils.save_samples(title, samples)
    np.savetxt('samples/test_noise%s.txt' % date,
               test_noise,
               fmt='%2.5f',
               delimiter=', ')
Example #8
def sample(model, config, samples_dir, texture_path,
           n_samples=20, n_z_samples=5, inverse=False):
    utils.makedirs(samples_dir)

    imgs = None
    if inverse:
        img_files = sorted(os.listdir(texture_path))[:n_samples]
        img_files = [texture_path + file for file in img_files]
        imgs = get_images(img_files)
        imgs = [get_random_patch(img, config.npx) for img in imgs]
        imgs = [np.reshape(img, (1,) + img.shape) for img in imgs]
    all_samples = []
    if inverse:
        X = np.concatenate(imgs, axis=0)
        # only the deterministic variant's output is used below
        global_noise = model.generate_z_det(X)
        z_samples = utils.sample_noise_tensor(config, n_z_samples, config.zx,
                                              global_noise=global_noise,
                                              per_each=True)
        gen_samples = model.generate_det(z_samples)
        all_samples = [[img, list(gen_samples[n_z_samples*i:n_z_samples*(i+1)])]
                       for i, img in enumerate(imgs)]
        all_samples = [np.concatenate(np.concatenate(samples, axis=0), axis=2)
                       for samples in all_samples]
        all_samples = [np.concatenate(all_samples, axis=1)]
        utils.save_samples(samples_dir, all_samples, ['inv_gens'])
        all_samples = []
    for i in range(n_samples):
        global_noise = np.random.uniform(-1., 1., (1, config.nz_global, 1, 1))
        z_samples = utils.sample_noise_tensor(config, n_z_samples, config.zx,
                                              global_noise=global_noise)
        gen_samples = model.generate_det(z_samples)
        gen_samples = np.concatenate(gen_samples, axis=2)
        all_samples.append(gen_samples)
    all_samples = [np.concatenate(all_samples, axis=1)]
    utils.save_samples(samples_dir, all_samples, ['gens'])
Example #9
def test(ckpt_root, model, batch_size):

    X, Z, Lr = model.inputs()
    g_sample = model.sample(Z, reuse=False)
    sample_size = batch_size
    test_noise = utils.get_noise(sample_size, n_noise)

    # TEST DCGAN interpolation
    '''
    test_noise[0] = np.array([[-0.01403, 0.56896, 1.41881, 0.02516, -1.36731, -0.92614, 1.17105, -0.17130, -0.11242, 0.35453, 0.06243, 0.41190, -0.18923, -1.10846, -0.10500, 0.65989, -0.19307, -0.32606, -1.77017, -0.38637, -0.82117, 0.53288, -0.38393, 1.16999, 0.02266, 0.36757, 0.13555, -1.06630, 0.00951, -0.04134, -0.29982, -0.83991, -0.04059, -0.56064, 0.39640, 0.29686, 0.42023, -1.15875, 0.19443, -0.89730, 0.37836, -2.48704, -0.03874, 0.04086, -0.35425, -0.02359, 0.56843, -0.45289, 1.79295, 0.98343, -0.99543, 0.70134, -1.43882, -0.10630, -0.39800, -1.90689, -0.16606, 0.01075, 0.11386, 0.08757, 0.25799, 1.06645, 0.07529, 1.17719, 1.38717, -0.93715, 0.60258, 0.64817, -0.70972, 1.49177, -0.58564, -1.47612, -0.49625, 2.30098, -0.08210, -0.22495, -0.47805, -0.72601, 0.58665, -0.63158, 0.04414, -0.05951, -0.92667, -0.07905, -2.26017, -0.29677, 0.93230, -0.06546, -0.46701, 1.49024, 0.01060, -0.86621, -0.65857, 0.42297, -1.43760, 0.53813, -0.13808, 0.23095, -0.78151, 0.63207
]])
    test_noise[8] = np.array([[-0.30745, -1.80994, 0.84740, -0.01723, -0.25759, 1.62209, -0.01877, 1.31540, 1.80470, -1.76964, 2.06064, -0.62803, -0.94382, 0.85376, -0.26913, 0.69890, 1.52500, -0.62958, -0.97269, 1.81976, 1.46848, -0.10180, -0.14649, 0.82289, -0.21654, 0.63229, -0.61106, -0.84134, 0.95145, -0.84128, -0.02509, -0.14419, -0.46364, -0.00298, -0.23900, -1.37273, -1.16797, 1.10777, -1.56686, -0.60846, -0.18123, 1.95980, 0.58466, 0.64532, 1.01655, -1.00187, 0.07544, 0.31779, 1.55344, -0.41186, -0.14158, 0.07359, -1.02670, -0.14173, 0.10773, -0.64202, -1.56408, -1.96202, -0.13097, -0.05426, -2.26692, 0.04790, 0.03724, -0.55998, 0.11415, -1.97006, 0.56635, -1.29249, 0.32449, 0.37213, -0.77510, -0.09502, 2.44859, 0.68632, 0.48752, 0.18134, -1.14473, -0.09552, -0.62953, 0.28095, -1.04062, 0.39957, -1.39301, -0.29697, -0.99899, 1.91437, -1.94361, -0.38661, -0.04163, -0.09743, -0.87291, 1.00404, 0.51789, -0.78019, 1.43526, 0.16111, 1.26596, -0.12284, 0.74221, 1.53793
]])
    test_noise[1] = (7*test_noise[0]+1*test_noise[8])/8
    test_noise[2] = (6*test_noise[0]+2*test_noise[8])/8
    test_noise[3] = (5*test_noise[0]+3*test_noise[8])/8
    test_noise[4] = (4*test_noise[0]+4*test_noise[8])/8
    test_noise[5] = (3*test_noise[0]+5*test_noise[8])/8
    test_noise[6] = (2*test_noise[0]+6*test_noise[8])/8
    test_noise[7] = (1*test_noise[0]+7*test_noise[8])/8
    '''

    saver = tf.train.Saver()

    with tf.Session() as sess:
        #sess.run(tf.global_variables_initializer())
        saver.restore(sess, ckpt_root)

        samples = sess.run(g_sample, feed_dict={Z: test_noise})
        date = datetime.datetime.now()
        title = 'samples/%s.png' % date
        utils.save_samples(title, samples)
    np.savetxt('samples/test_noise%s.txt' % date,
               test_noise,
               fmt='%2.5f',
               delimiter=', ')
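The commented-out block above builds interpolation frames by hand with fixed eighth-step weights. An equivalent construction with np.linspace (a sketch; shapes follow the test_noise rows used above):

import numpy as np

def interpolate_noise(z_start, z_end, steps=9):
    # Linearly interpolate between two latent vectors, endpoints included.
    alphas = np.linspace(0.0, 1.0, steps).reshape(-1, 1)
    return (1.0 - alphas) * z_start + alphas * z_end

# e.g. test_noise[0:9] = interpolate_noise(test_noise[0], test_noise[8])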
Example #10
    G_costs.append(G_cost_epoch_avg)

    LOGGER.info(
        "{} D_cost_train:{:.4f} | D_wass_train:{:.4f} | D_cost_valid:{:.4f} | D_wass_valid:{:.4f} | "
        "G_cost:{:.4f}".format(time_since(start), D_cost_train_epoch_avg,
                               D_wass_train_epoch_avg, D_cost_valid_epoch_avg,
                               D_wass_valid_epoch_avg, G_cost_epoch_avg))

    # Generate audio samples.
    if epoch % epochs_per_sample == 0:
        LOGGER.info("Generating samples...")
        sample_out = netG(sample_noise_Var)
        if cuda:
            sample_out = sample_out.cpu()
        sample_out = sample_out.data.numpy()
        save_samples(sample_out, epoch, output_dir)

    # TODO
    # Early stopping by Inception Score(IS)

LOGGER.info('>>>>>>>Training finished !<<<<<<<')

# Save model
LOGGER.info("Saving models...")
netD_path = os.path.join(output_dir, "discriminator.pkl")
netG_path = os.path.join(output_dir, "generator.pkl")
torch.save(netD.state_dict(),
           netD_path,
           pickle_protocol=pickle.HIGHEST_PROTOCOL)
torch.save(netG.state_dict(),
           netG_path,
           pickle_protocol=pickle.HIGHEST_PROTOCOL)
Example #11
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    paths = [make_dirs(path) for path in paths]

    # Prepare Data Loader #
    train_horse_loader, train_zebra_loader = get_horse2zebra_loader('train', config.batch_size)
    val_horse_loader, val_zebra_loader = get_horse2zebra_loader('test', config.batch_size)
    total_batch = min(len(train_horse_loader), len(train_zebra_loader))

    # Image Pool #
    masked_fake_A_pool = ImageMaskPool(config.pool_size)
    masked_fake_B_pool = ImageMaskPool(config.pool_size)

    # Prepare Networks #
    Attn_A = Attention()
    Attn_B = Attention()
    G_A2B = Generator()
    G_B2A = Generator()
    D_A = Discriminator()
    D_B = Discriminator()

    networks = [Attn_A, Attn_B, G_A2B, G_B2A, D_A, D_B]
    for network in networks:
        network.to(device)

    # Loss Function #
    criterion_Adversarial = nn.MSELoss()
    criterion_Cycle = nn.L1Loss()

    # Optimizers #
    D_optim = torch.optim.Adam(chain(D_A.parameters(), D_B.parameters()), lr=config.lr, betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(chain(Attn_A.parameters(), Attn_B.parameters(), G_A2B.parameters(), G_B2A.parameters()), lr=config.lr, betas=(0.5, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_A_losses, D_B_losses = [], []
    G_A_losses, G_B_losses = [], []

    # Train #
    print("Training Unsupervised Attention-Guided GAN started with total epoch of {}.".format(config.num_epochs))

    for epoch in range(config.num_epochs):

        for i, (real_A, real_B) in enumerate(zip(train_horse_loader, train_zebra_loader)):

            # Data Preparation #
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            ###################
            # Train Generator #
            ###################

            set_requires_grad([D_A, D_B], requires_grad=False)

            # Adversarial Loss using real A #
            attn_A = Attn_A(real_A)
            fake_B = G_A2B(real_A)

            masked_fake_B = fake_B * attn_A + real_A * (1-attn_A)

            masked_fake_B *= attn_A
            prob_real_A = D_A(masked_fake_B)
            real_labels = torch.ones(prob_real_A.size()).to(device)

            G_loss_A = criterion_Adversarial(prob_real_A, real_labels)

            # Adversarial Loss using real B #
            attn_B = Attn_B(real_B)
            fake_A = G_B2A(real_B)

            masked_fake_A = fake_A * attn_B + real_B * (1-attn_B)

            masked_fake_A *= attn_B
            prob_real_B = D_B(masked_fake_A)
            real_labels = torch.ones(prob_real_B.size()).to(device)

            G_loss_B = criterion_Adversarial(prob_real_B, real_labels)

            # Cycle Consistency Loss using real A #
            attn_ABA = Attn_B(masked_fake_B)
            fake_ABA = G_B2A(masked_fake_B)
            masked_fake_ABA = fake_ABA * attn_ABA + masked_fake_B * (1 - attn_ABA)

            # Cycle Consistency Loss using real B #
            attn_BAB = Attn_A(masked_fake_A)
            fake_BAB = G_A2B(masked_fake_A)
            masked_fake_BAB = fake_BAB * attn_BAB + masked_fake_A * (1 - attn_BAB)

            # Cycle Consistency Loss #
            G_cycle_loss_A = config.lambda_cycle * criterion_Cycle(masked_fake_ABA, real_A)
            G_cycle_loss_B = config.lambda_cycle * criterion_Cycle(masked_fake_BAB, real_B)

            # Total Generator Loss #
            G_loss = G_loss_A + G_loss_B + G_cycle_loss_A + G_cycle_loss_B

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad([D_A, D_B], requires_grad=True)

            # Train Discriminator A on real B (D_A judges domain-B images) #
            prob_real_A = D_A(real_B)
            real_labels = torch.ones(prob_real_A.size()).to(device)
            D_loss_real_A = criterion_Adversarial(prob_real_A, real_labels)

            # Add Pooling #
            masked_fake_B, attn_A = masked_fake_B_pool.query(masked_fake_B, attn_A)
            masked_fake_B *= attn_A

            # Train Discriminator A using fake B #
            prob_fake_B = D_A(masked_fake_B.detach())
            fake_labels = torch.zeros(prob_fake_B.size()).to(device)
            D_loss_fake_A = criterion_Adversarial(prob_fake_B, fake_labels)

            D_loss_A = (D_loss_real_A + D_loss_fake_A).mean()

            # Train Discriminator B on real A (D_B judges domain-A images) #
            prob_real_B = D_B(real_A)
            real_labels = torch.ones(prob_real_B.size()).to(device)
            D_loss_real_B = criterion_Adversarial(prob_real_B, real_labels)

            # Add Pooling #
            masked_fake_A, attn_B = masked_fake_A_pool.query(masked_fake_A, attn_B)
            masked_fake_A *= attn_B

            # Train Discriminator B using fake A #
            prob_fake_A = D_B(masked_fake_A.detach())
            fake_labels = torch.zeros(prob_fake_A.size()).to(device)
            D_loss_fake_B = criterion_Adversarial(prob_fake_A, fake_labels)

            D_loss_B = (D_loss_real_B + D_loss_fake_B).mean()

            # Calculate Total Discriminator Loss #
            D_loss = D_loss_A + D_loss_B

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            # Add items to Lists #
            D_A_losses.append(D_loss_A.item())
            D_B_losses.append(D_loss_B.item())
            G_A_losses.append(G_loss_A.item())
            G_B_losses.append(G_loss_B.item())

            ####################
            # Print Statistics #
            ####################

            if (i+1) % config.print_every == 0:
                print("UAG-GAN | Epoch [{}/{}] | Iteration [{}/{}] | D A Losses {:.4f} | D B Losses {:.4f} | G A Losses {:.4f} | G B Losses {:.4f}".
                      format(epoch+1, config.num_epochs, i+1, total_batch, np.average(D_A_losses), np.average(D_B_losses), np.average(G_A_losses), np.average(G_B_losses)))

                # Save Sample Images #
                save_samples(val_horse_loader, val_zebra_loader, G_A2B, G_B2A, Attn_A, Attn_B, epoch, config.samples_path)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(G_A2B.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Generator_A2B_Epoch_{}.pkl'.format(epoch+1)))
            torch.save(G_B2A.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Generator_B2A_Epoch_{}.pkl'.format(epoch+1)))
            torch.save(Attn_A.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Attention_A_Epoch_{}.pkl'.format(epoch+1)))
            torch.save(Attn_B.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Attention_B_Epoch_{}.pkl'.format(epoch+1)))

    # Make a GIF file #
    make_gifs_train("UAG-GAN", config.samples_path)

    # Plot Losses #
    plot_losses(D_A_losses, D_B_losses, G_A_losses, G_B_losses, config.num_epochs, config.plots_path)

    print("Training finished.")
Example #12
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights, and Plots Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    paths = [make_dirs(path) for path in paths]

    # Prepare Data Loader #
    train_loader_selfie, train_loader_anime = get_selfie2anime_loader(
        'train', config.batch_size)
    total_batch = max(len(train_loader_selfie), len(train_loader_anime))

    test_loader_selfie, test_loader_anime = get_selfie2anime_loader(
        'test', config.val_batch_size)

    # Prepare Networks #
    D_A = Discriminator(num_layers=7)
    D_B = Discriminator(num_layers=7)
    L_A = Discriminator(num_layers=5)
    L_B = Discriminator(num_layers=5)
    G_A2B = Generator(image_size=config.crop_size,
                      num_blocks=config.num_blocks)
    G_B2A = Generator(image_size=config.crop_size,
                      num_blocks=config.num_blocks)

    networks = [D_A, D_B, L_A, L_B, G_A2B, G_B2A]

    for network in networks:
        network.to(device)

    # Loss Function #
    Adversarial_loss = nn.MSELoss()
    Cycle_loss = nn.L1Loss()
    BCE_loss = nn.BCEWithLogitsLoss()

    # Optimizers #
    D_optim = torch.optim.Adam(chain(D_A.parameters(), D_B.parameters(),
                                     L_A.parameters(), L_B.parameters()),
                               lr=config.lr,
                               betas=(0.5, 0.999),
                               weight_decay=0.0001)
    G_optim = torch.optim.Adam(chain(G_A2B.parameters(), G_B2A.parameters()),
                               lr=config.lr,
                               betas=(0.5, 0.999),
                               weight_decay=0.0001)

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Rho Clipper to constraint the value of rho in AdaILN and ILN #
    Rho_Clipper = RhoClipper(0, 1)

    # Lists #
    D_losses = []
    G_losses = []

    # Train #
    print("Training U-GAT-IT started with total epoch of {}.".format(
        config.num_epochs))
    for epoch in range(config.num_epochs):

        for i, (selfie, anime) in enumerate(
                zip(train_loader_selfie, train_loader_anime)):

            # Data Preparation #
            real_A = selfie.to(device)
            real_B = anime.to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad([D_A, D_B, L_A, L_B], requires_grad=True)

            # Forward Data #
            fake_B, _, _ = G_A2B(real_A)
            fake_A, _, _ = G_B2A(real_B)

            G_real_A, G_real_A_cam, _ = D_A(real_A)
            L_real_A, L_real_A_cam, _ = L_A(real_A)
            G_real_B, G_real_B_cam, _ = D_B(real_B)
            L_real_B, L_real_B_cam, _ = L_B(real_B)

            G_fake_A, G_fake_A_cam, _ = D_A(fake_A)
            L_fake_A, L_fake_A_cam, _ = L_A(fake_A)
            G_fake_B, G_fake_B_cam, _ = D_B(fake_B)
            L_fake_B, L_fake_B_cam, _ = L_B(fake_B)

            # Adversarial Loss of Discriminator #
            real_labels = torch.ones(G_real_A.shape).to(device)
            D_ad_real_loss_GA = Adversarial_loss(G_real_A, real_labels)

            fake_labels = torch.zeros(G_fake_A.shape).to(device)
            D_ad_fake_loss_GA = Adversarial_loss(G_fake_A, fake_labels)

            D_ad_loss_GA = D_ad_real_loss_GA + D_ad_fake_loss_GA

            real_labels = torch.ones(G_real_A_cam.shape).to(device)
            D_ad_cam_real_loss_GA = Adversarial_loss(G_real_A_cam, real_labels)

            fake_labels = torch.zeros(G_fake_A_cam.shape).to(device)
            D_ad_cam_fake_loss_GA = Adversarial_loss(G_fake_A_cam, fake_labels)

            D_ad_cam_loss_GA = D_ad_cam_real_loss_GA + D_ad_cam_fake_loss_GA

            real_labels = torch.ones(G_real_B.shape).to(device)
            D_ad_real_loss_GB = Adversarial_loss(G_real_B, real_labels)

            fake_labels = torch.zeros(G_fake_B.shape).to(device)
            D_ad_fake_loss_GB = Adversarial_loss(G_fake_B, fake_labels)

            D_ad_loss_GB = D_ad_real_loss_GB + D_ad_fake_loss_GB

            real_labels = torch.ones(G_real_B_cam.shape).to(device)
            D_ad_cam_real_loss_GB = Adversarial_loss(G_real_B_cam, real_labels)

            fake_labels = torch.zeros(G_fake_B_cam.shape).to(device)
            D_ad_cam_fake_loss_GB = Adversarial_loss(G_fake_B_cam, fake_labels)

            D_ad_cam_loss_GB = D_ad_cam_real_loss_GB + D_ad_cam_fake_loss_GB

            # Adversarial Loss of L #
            real_labels = torch.ones(L_real_A.shape).to(device)
            D_ad_real_loss_LA = Adversarial_loss(L_real_A, real_labels)

            fake_labels = torch.zeros(L_fake_A.shape).to(device)
            D_ad_fake_loss_LA = Adversarial_loss(L_fake_A, fake_labels)

            D_ad_loss_LA = D_ad_real_loss_LA + D_ad_fake_loss_LA

            real_labels = torch.ones(L_real_A_cam.shape).to(device)
            D_ad_cam_real_loss_LA = Adversarial_loss(L_real_A_cam, real_labels)

            fake_labels = torch.zeros(L_fake_A_cam.shape).to(device)
            D_ad_cam_fake_loss_LA = Adversarial_loss(L_fake_A_cam, fake_labels)

            D_ad_cam_loss_LA = D_ad_cam_real_loss_LA + D_ad_cam_fake_loss_LA

            real_labels = torch.ones(L_real_B.shape).to(device)
            D_ad_real_loss_LB = Adversarial_loss(L_real_B, real_labels)

            fake_labels = torch.zeros(L_fake_B.shape).to(device)
            D_ad_fake_loss_LB = Adversarial_loss(L_fake_B, fake_labels)

            D_ad_loss_LB = D_ad_real_loss_LB + D_ad_fake_loss_LB

            real_labels = torch.ones(L_real_B_cam.shape).to(device)
            D_ad_cam_real_loss_LB = Adversarial_loss(L_real_B_cam, real_labels)

            fake_labels = torch.zeros(L_fake_B_cam.shape).to(device)
            D_ad_cam_fake_loss_LB = Adversarial_loss(L_fake_B_cam, fake_labels)

            D_ad_cam_loss_LB = D_ad_cam_real_loss_LB + D_ad_cam_fake_loss_LB

            # Calculate Each Discriminator Loss #
            D_loss_A = D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA
            D_loss_B = D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB

            # Calculate Total Discriminator Loss #
            D_loss = D_loss_A + D_loss_B

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            ###################
            # Train Generator #
            ###################

            set_requires_grad([D_A, D_B, L_A, L_B], requires_grad=False)

            # Forward Data #
            fake_B, fake_B_cam, _ = G_A2B(real_A)
            fake_A, fake_A_cam, _ = G_B2A(real_B)

            fake_ABA, _, _ = G_B2A(fake_B)
            fake_BAB, _, _ = G_A2B(fake_A)

            # Identity mappings use the generator that targets the same domain
            fake_A2A, fake_A2A_cam, _ = G_B2A(real_A)
            fake_B2B, fake_B2B_cam, _ = G_A2B(real_B)

            G_fake_A, G_fake_A_cam, _ = D_A(fake_A)
            L_fake_A, L_fake_A_cam, _ = L_A(fake_A)
            G_fake_B, G_fake_B_cam, _ = D_B(fake_B)
            L_fake_B, L_fake_B_cam, _ = L_B(fake_B)

            # Adversarial Loss of Generator #
            real_labels = torch.ones(G_fake_A.shape).to(device)
            G_adv_fake_loss_A = Adversarial_loss(G_fake_A, real_labels)

            real_labels = torch.ones(G_fake_A_cam.shape).to(device)
            G_adv_cam_fake_loss_A = Adversarial_loss(G_fake_A_cam, real_labels)

            G_adv_loss_A = G_adv_fake_loss_A + G_adv_cam_fake_loss_A

            real_labels = torch.ones(G_fake_B.shape).to(device)
            G_adv_fake_loss_B = Adversarial_loss(G_fake_B, real_labels)

            real_labels = torch.ones(G_fake_B_cam.shape).to(device)
            G_adv_cam_fake_loss_B = Adversarial_loss(G_fake_B_cam, real_labels)

            G_adv_loss_B = G_adv_fake_loss_B + G_adv_cam_fake_loss_B

            # Adversarial Loss of L #
            real_labels = torch.ones(L_fake_A.shape).to(device)
            L_adv_fake_loss_A = Adversarial_loss(L_fake_A, real_labels)

            real_labels = torch.ones(L_fake_A_cam.shape).to(device)
            L_adv_cam_fake_loss_A = Adversarial_loss(L_fake_A_cam, real_labels)

            L_adv_loss_A = L_adv_fake_loss_A + L_adv_cam_fake_loss_A

            real_labels = torch.ones(L_fake_B.shape).to(device)
            L_adv_fake_loss_B = Adversarial_loss(L_fake_B, real_labels)

            real_labels = torch.ones(L_fake_B_cam.shape).to(device)
            L_adv_cam_fake_loss_B = Adversarial_loss(L_fake_B_cam, real_labels)

            L_adv_loss_B = L_adv_fake_loss_B + L_adv_cam_fake_loss_B

            # Cycle Consistency Loss #
            G_recon_loss_A = Cycle_loss(fake_ABA, real_A)
            G_recon_loss_B = Cycle_loss(fake_BAB, real_B)

            G_identity_loss_A = Cycle_loss(fake_A2A, real_A)
            G_identity_loss_B = Cycle_loss(fake_B2B, real_B)

            G_cycle_loss_A = G_recon_loss_A + G_identity_loss_A
            G_cycle_loss_B = G_recon_loss_B + G_identity_loss_B

            # CAM Loss #
            real_labels = torch.ones(fake_A_cam.shape).to(device)
            G_cam_real_loss_A = BCE_loss(fake_A_cam, real_labels)

            fake_labels = torch.zeros(fake_A2A_cam.shape).to(device)
            G_cam_fake_loss_A = BCE_loss(fake_A2A_cam, fake_labels)

            G_cam_loss_A = G_cam_real_loss_A + G_cam_fake_loss_A

            real_labels = torch.ones(fake_B_cam.shape).to(device)
            G_cam_real_loss_B = BCE_loss(fake_B_cam, real_labels)

            fake_labels = torch.zeros(fake_B2B_cam.shape).to(device)
            G_cam_fake_loss_B = BCE_loss(fake_B2B_cam, fake_labels)

            G_cam_loss_B = G_cam_real_loss_B + G_cam_fake_loss_B

            # Calculate Each Generator Loss #
            G_loss_A = G_adv_loss_A + L_adv_loss_A + config.lambda_cycle * G_cycle_loss_A + config.lambda_cam * G_cam_loss_A
            G_loss_B = G_adv_loss_B + L_adv_loss_B + config.lambda_cycle * G_cycle_loss_B + config.lambda_cam * G_cam_loss_B

            # Calculate Total Generator Loss #
            G_loss = G_loss_A + G_loss_B

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Apply Rho Clipper to Generators #
            G_A2B.apply(Rho_Clipper)
            G_B2A.apply(Rho_Clipper)

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "U-GAT-IT | Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(G_losses)))

                # Save Sample Images #
                save_samples(test_loader_selfie, G_A2B, epoch,
                             config.samples_path)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                D_A.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_D_A_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                D_B.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_D_B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                L_A.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_L_A_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                L_B.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_L_B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_A2B.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_G_A2B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_B2A.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_G_B2A_Epoch_{}.pkl'.format(epoch + 1)))

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.num_epochs, config.plots_path)

    # Make a GIF file #
    make_gifs_train('U-GAT-IT', config.samples_path)

    print("Training finished.")
Example #13
            # raw_data.append(spikes)
            # Reshape to required format
            data = data.reshape((1, imageSize, imageSize))
            recv_data[rank - 1] = data
            comm.Send(recv_data, dest=0)
# print('done generating data, in sec: {}'.format(time.time() - t))

if rank == 0:
    for i in range(1, size):
        comm.Recv(recv_data, source=i)
        binned_data[i - 1] = recv_data[i - 1]

if comm.rank == 0:
    normed_data = np.empty((num_samples, 1, imageSize, imageSize),
                           dtype=np.float32)
    for i, bst in enumerate(binned_data):
        # Normalize data
        normed_data[i] = np.divide(bst, np.max(bst))
        # data = (data - data.mean()) / data.std()

    save_dict['binned_data'] = binned_data
    save_dict['normed_data'] = normed_data
    save_dict['num_samples'] = num_samples
    save_dict['imageSize'] = imageSize
    # save_dict['spikes'] = raw_data
    save_dict['data_type'] = data_type
    utils.save_samples(save_dict,
                       path='.',
                       filename='data_NS{}_IS{}_type-{}.npy'.format(
                           num_samples, imageSize, data_type))
Example #14
                loss += value

            loss.backward()
            optimizer.step()

            print("[Epoch %d/%d] [Batch %d/%d] [loss: %f] " %
                  (epoch, args.epoch, idx, len(train_loader), loss))

        if epoch % args.val_epoch == 0:
            mse, ssim = utils.evaluate(test_loader, unet, args)
            print("mse = {}, ssim = {}".format(mse, ssim))
            score = 1 - mse / 100.0 + ssim
            if score > best_score:
                # if ssim > best_ssim:
                torch.save(
                    unet.state_dict(),
                    os.path.join(args.save_dir, 'unet_{}.pth'.format(epoch)))
                best_score = score
                # best_ssim = ssim
                utils.save_samples(test_loader, unet, save_samples_dir, epoch)

            print('ssim = {}, mse = {}, score = {}, best_score = {} '.format(
                ssim, mse, score, best_score))

    print('best_score = {} '.format(best_score))
'''
Reference:
https://github.com/ceshine/fast-neural-style/blob/master/notebooks/01-image-style-transfer.ipynb

'''
Example #15
    def test(self, training=False):
        # input images for the model to train on
        x = tf.placeholder(tf.float32,
                           shape=(None, self.height, self.width,
                                  self.channels),
                           name='inputs')

        # correct images for the model to check against when learning
        if (self.config == '--MNIST'):
            y = tf.placeholder(tf.float32,
                               shape=(None, self.height, self.width,
                                      self.channels),
                               name='correct_images')
        elif (self.config == '--CIFAR' or self.config == '--FREY'):
            y = tf.placeholder(tf.float32,
                               shape=(None, self.height, self.width,
                                      256 * self.channels),
                               name='correct_images')

        # model is in the training phase
        is_training = tf.placeholder(tf.bool, name='training')

        # build out the network architecture
        self.build_network(x, y, is_training, 'test')

        saver = tf.train.Saver()

        with tf.Session() as sess:
            saver.restore(sess, '/tmp/model.ckpt')

            loss = []
            NLL = []
            for i in range(self.test_inputs.shape[0]):
                if (self.config == '--MNIST'):
                    binaryImage = binarize(self.test_inputs[i])
                    binaryImage = np.reshape(binaryImage, (1, 28, 28, 1))
                    loss.append(
                        sess.run(self.loss,
                                 feed_dict={
                                     x: binaryImage,
                                     y: binaryImage,
                                     is_training: training
                                 }))
                    NLL.append(
                        sess.run(self.NLL,
                                 feed_dict={
                                     x: binaryImage,
                                     y: binaryImage,
                                     is_training: training
                                 }))
                elif (self.config == '--FREY'):
                    image = self.test_inputs[i]
                    label = sess.run(
                        tf.one_hot(np.reshape(image, (-1, 28, 20)),
                                   256,
                                   axis=-1))
                    image = 2. * (np.reshape(
                        image, (1, 28, 20, 1)).astype('float32') / 255.) - 1.
                    loss.append(
                        sess.run(self.loss,
                                 feed_dict={
                                     x: image,
                                     y: label,
                                     is_training: training
                                 }))
                    NLL.append(
                        sess.run(self.NLL,
                                 feed_dict={
                                     x: image,
                                     y: label,
                                     is_training: training
                                 }))

            loss = np.mean(loss)
            NLL = np.mean(NLL)
            print("Test loss: ", loss, "NLL: ", NLL)

            #save_samples(self.test_inputs[:22], self.height, self.width)

            # remove bottom half from images
            images = trim_images(self.test_inputs)

            if (self.config == '--MNIST'):
                images = images[:22]
                self.test_inputs = self.test_inputs[:22]
            elif (self.config == '--CIFAR' or self.config == '--FREY'):
                images = 2 * (images[:22].astype('float32') / 255.) - 1.
                self.test_inputs = sess.run(
                    tf.one_hot(np.reshape(self.test_inputs[:22], (-1, 28, 20)),
                               256,
                               axis=-1))

            #save_samples(images, self.height, self.width)

            # use model to generate bottom half of images
            for i in range(self.height // 2, self.height):
                for j in range(self.width):
                    for k in range(self.channels):
                        probs = sess.run(self.pred,
                                         feed_dict={
                                             x: images,
                                             y: self.test_inputs,
                                             is_training: training
                                         })

                        if (self.config == '--MNIST'):
                            sample = binarize(probs)
                            images[:, i, j, k] = sample[:, i, j, k]
                        elif (self.config == '--FREY'):
                            sample = sample_categorical(probs[:, i, j])
                            images[:, i, j,
                                   k] = 2 * (sample.astype('float32') /
                                             255.) - 1.

            save_samples(images, self.height, self.width)
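binarize and sample_categorical are assumed helpers; plausible minimal versions consistent with how they are called above:

import numpy as np

def binarize(probs):
    # Sample Bernoulli pixels from per-pixel probabilities.
    return (np.random.uniform(size=probs.shape) < probs).astype('float32')

def sample_categorical(probs):
    # probs: (..., 256) distribution over pixel intensities; inverse-CDF sampling.
    flat = probs.reshape(-1, probs.shape[-1])
    cumsum = np.cumsum(flat, axis=-1)
    u = np.random.uniform(size=(flat.shape[0], 1))
    samples = (cumsum < u).sum(axis=-1)
    return samples.reshape(probs.shape[:-1])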
Example #16
                              '%s/fake_samples_epoch_%03d.png' %
                              (opt.outf, epoch),
                              normalize=False)
            '''
            1-dim: list with all the data, listed according the epochs
            2-dim: list containing lists of integer and tuple,
                integer indicates the epoch, the tuple contains the step index
                and the output data,
                [int, tuple]
            3-dim: tuple of integer and data as torch.FloatTensor, the integer
                indicates the step index of the corresponding batch in the loop
                (int, FloatTensor)
            4-dim: FloatTensor of shape 64x1x64x64:
            5-dim:
                64 samples x
                Channel number (here always only 1) x
                64x64 normalized binned data
            '''
            save_dict['fake_data'][epoch].append((i, fake.data))

    # do checkpointing
    checkpoint_path = os.path.join(opt.outf, 'checkpoints')
    if not os.path.exists(checkpoint_path):
        os.mkdir(checkpoint_path)
    torch.save(netG.state_dict(),
               '%s/netG_epoch_%d.pth' % (checkpoint_path, epoch))
    torch.save(netD.state_dict(),
               '%s/netD_epoch_%d.pth' % (checkpoint_path, epoch))
utils.save_samples(save_dict, path=opt.outf, filename='results.npy')
writer.close()
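Given the nesting documented in the docstring above, the saved file can be reloaded and indexed like this (a sketch assuming utils.save_samples wraps np.save):

import numpy as np

results = np.load('results.npy', allow_pickle=True).item()
# results['fake_data'][epoch] is a list of (step_index, FloatTensor) tuples
step_idx, fake_batch = results['fake_data'][0][0]
print(fake_batch.shape)  # 64 x 1 x 64 x 64 per the docstring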
Example #17
    def train(self):
        def preprocess_func(x, y):
            x = tf.cast(x, tf.float32)
            y = tf.cast(y, tf.float32) / 127.5 - 1.
            return x, y

        logfile_path = os.path.join(self.results_dir,
                                    'training_stg2_logfile.txt')
        if not os.path.exists(logfile_path):
            open(logfile_path, "w").close()  # create the log file once

        lsdir_x = [f for f in os.listdir(self.stg2_data_dir) if 'data_x' in f]
        lsdir_y = [f for f in os.listdir(self.stg2_data_dir) if 'data_y' in f]
        lsdir_x, lsdir_y = sorted(lsdir_x), sorted(lsdir_y)
        assert len(lsdir_x) == len(
            lsdir_y
        ), "You are missing an x or y dataset shard. You may have deleted it."
        while True:
            for xtr_path, ytr_path in zip(lsdir_x, lsdir_y):
                x_tr = np.load(os.path.join(self.stg2_data_dir, xtr_path))
                y_tr = np.load(os.path.join(self.stg2_data_dir, ytr_path))
                dataset = tf.data.Dataset.from_tensor_slices(
                    (x_tr, y_tr)).shuffle(self.batch_size * 5).batch(
                        self.batch_size,
                        drop_remainder=True).map(preprocess_func)
                dataset = self.strategy.experimental_distribute_dataset(
                    dataset)

                for x, y in dataset:
                    self.train_step(x, y)

                    i = self.i

                    if i % 25 == 0:  # update the moving average only every 25 steps, as it is slow
                        self._update_moving_avg()
                    if i % 1000 == 0:
                        loss, gradnorm = to_numpy(
                            self.train_loss.result()), to_numpy(
                                self.train_gradnorm.result())
                        result_str = "Iteration: %d, Time Elapsed: %0.1f,  Loss: %0.4f, gradnorm: %0.4f" % (
                            i, time() - self.starttime, loss, gradnorm)
                        print(result_str)
                        logfile = open(logfile_path, "a")
                        logfile.write(result_str + "\n")
                        logfile.close()
                        self.train_loss.reset_states()
                        self.train_gradnorm.reset_states()
                    if i % 10000 == 0:
                        self.ema_model.set_weights(self.moving_avg_weights)
                        self.save_objects(i // 1000, stg2=True)
                        _, samples_ema = self.generate_samples(64)
                        savepath = self._getp('samples_ema_{}k.jpg'.format(
                            i // 1000),
                                              folder='results')
                        save_samples(samples_ema, savepath)

                    if i >= self.max_iterations:
                        print(
                            "Training complete. Recover your models from the {} folder and your training results from the {} folder"
                            .format(self.model_dir, self.results_dir))
                        print(
                            "Training was done on a maximum of {} iterations".
                            format(i))
                        return True

                    self.i += 1
Example #18
    writer.add_scalar("loss/penalty", penalty, iters)
    writer.add_scalar("loss/mi", loss_mi, iters)

    if iters % args.log_interval == 0:
        print(
            "Train Iter: {}/{} ({:.0f}%)\t"
            "D_costs: {} G_costs: {} Time: {:5.3f}".format(
                iters,
                args.iters,
                100. * iters / args.iters,
                np.asarray(e_costs).mean(0),
                np.asarray(g_costs).mean(0),
                (time.time() - start_time) / args.log_interval,
            )
        )
        img = save_samples(netG, args)
        writer.add_image("samples/generated", img, iters)

        e_costs = []
        g_costs = []
        start_time = time.time()

    if iters % args.save_interval == 0:
        netG.eval()
        print("-" * 100)
        n_modes, kld = mc_eval.count_modes(netG)
        print("-" * 100)
        netG.train()

        writer.add_scalar("metrics/mode_count", n_modes, iters)
        writer.add_scalar("metrics/kl_divergence", kld, iters)
Example #19
def train(model, config, logger, options, model_dir, samples_dir,
          inverse=False, save_step=10, n_samples=20):
    utils.makedirs(samples_dir)

    losses = defaultdict(list)
    for epoch in tqdm(range(options.n_epochs), file=sys.stdout):
        logger.info("Epoch {}".format(epoch))

        samples_generator = config.data_iter(
            options.data, options.b_size, inverse=inverse, n_samples=n_samples)

        entropy_epoch = []

        for it in tqdm(range(options.n_iters), file=sys.stdout):
            Z_global = None
            if inverse:
                Z_global = np.random.uniform(
                    -1., 1., (options.b_size, config.nz_global, 1, 1))
            Z_samples = utils.sample_noise_tensor(
                config, options.b_size, config.zx, global_noise=Z_global)

            X_samples = next(samples_generator)
            if it % (config.k + 1) != 0:
                if inverse == 0:
                    losses['G_iter'].append(model.train_g(Z_samples))
                elif inverse == 1:
                    losses['G_iter'].append(model.train_g(X_samples, Z_samples,
                                                          Z_global))
                elif inverse >= 2:
                    # losses['G_iter'].append(model.train_g(X_samples[0],
                    #                                       Z_samples))
                    G_loss, entropy = model.train_g(X_samples[0], Z_samples)
                    losses['G_iter'].append(G_loss)
                    entropy_epoch.append(entropy)

            else:
                if inverse == 0:
                    losses['D_iter'].append(model.train_d(X_samples, Z_samples))
                elif inverse == 1:
                    losses['D_iter'].append(model.train_d(X_samples, Z_samples,
                                                          Z_global))
                elif inverse >= 2:
                    losses['D_iter'].append(model.train_d(
                        X_samples[0], X_samples[1], Z_samples))
        msg = "Gloss = {}, Dloss = {}"
        msg += "\n e_min = {}, e_max = {}, e_mean = {}, e_med = {}"
        e_min, e_max, e_mean, e_med = [f(entropy_epoch) for f in
                                       [np.min, np.max, np.mean, np.median]]
        losses['G_epoch'].append(np.mean(losses['G_iter'][-options.n_iters:]))
        losses['D_epoch'].append(np.mean(losses['D_iter'][-options.n_iters:]))
        logger.info(msg.format(losses['G_epoch'][-1], losses['D_epoch'][-1],
                    e_min, e_max, e_mean, e_med))

        X = next(samples_generator)
        real_samples, gen_samples, large_sample = utils.sample_after_iteration(
            model, X, inverse, config, options.b_size)
        utils.save_samples(
            samples_dir, [real_samples, gen_samples, large_sample],
            ['real', 'gen', 'large'], epoch=epoch)
        utils.save_plots(samples_dir, losses, epoch, options.n_iters)

        if (epoch+1) % save_step == 0:
            model_file = 'epoch_{:04d}.model'.format(epoch)
            model.save(os.path.join(model_dir, model_file))
Example #20
def train(data_root, model, total_epoch, batch_size, lrate):

    X, Z, Lr, Kt = model.inputs()
    d_loss, g_loss, real_loss, fake_loss = model.loss(X, Z, Kt)
    d_opt, g_opt = model.optimizer(d_loss, g_loss, Lr)
    g_sample = model.sample(Z)
    sample_size = batch_size
    # use uniform noise to match the batches fed during training below
    test_noise = utils.get_uniform_noise(sample_size, n_noise)
    epoch_drop = 3

    _lambda = 0.001
    _gamma = 0.5
    _kt = 0.0

    iterator, image_count = ImageIterator(data_root, batch_size,
                                          model.image_size,
                                          model.image_channels).get_iterator()
    next_element = iterator.get_next()

    measure = real_loss + tf.abs(_gamma * real_loss - fake_loss)
    tf.summary.scalar('measure', measure)

    merged = tf.summary.merge_all()

    total_batch = int(image_count / batch_size)
    #learning_rate = lrate
    #G_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter('./logs', sess.graph)
        sess.run(tf.global_variables_initializer())
        sess.run(iterator.initializer)

        for epoch in range(total_epoch):
            learning_rate = lrate * \
                            math.pow(0.2, math.floor((epoch + 1) / epoch_drop))
            for step in range(total_batch):
                batch_x = sess.run(next_element)
                batch_z = utils.get_uniform_noise(batch_size, n_noise)

                _, val_real_loss = sess.run([d_opt, real_loss],
                                            feed_dict={
                                                X: batch_x,
                                                Z: batch_z,
                                                Lr: learning_rate,
                                                Kt: _kt
                                            })
                _, val_fake_loss = sess.run([g_opt, fake_loss],
                                            feed_dict={
                                                Z: batch_z,
                                                Lr: learning_rate,
                                                Kt: _kt
                                            })

                _kt = _kt + _lambda * (_gamma * val_real_loss - val_fake_loss)

                if step % 300 == 0:
                    summary = sess.run(merged,
                                       feed_dict={
                                           X: batch_x,
                                           Z: batch_z,
                                           Lr: learning_rate,
                                           Kt: _kt
                                       })
                    train_writer.add_summary(summary,
                                             epoch * total_batch + step)

                    val_measure = val_real_loss + np.abs(
                        _gamma * val_real_loss - val_fake_loss)

                    print('Epoch:', '%04d' % epoch,
                          '%05d/%05d' % (step, total_batch),
                          'measure: {:.4}'.format(val_measure))

                    #sample_size = 10
                    #noise = get_noise(sample_size, n_noise)
                    samples = sess.run(g_sample, feed_dict={Z: test_noise})
                    title = 'samples/%05d_%05d.png' % (epoch, step)
                    utils.save_samples(title, samples)

            saver.save(sess, './models/began', global_step=epoch)
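The loop above implements the BEGAN balance term: k_t is nudged so the generator loss tracks gamma times the real loss, and the printed measure combines both. In isolation (note the BEGAN paper additionally clips k_t to [0, 1], which this snippet omits):

def began_update(kt, real_loss, fake_loss, lam=0.001, gamma=0.5):
    # Proportional control keeping fake_loss near gamma * real_loss.
    kt = kt + lam * (gamma * real_loss - fake_loss)
    measure = real_loss + abs(gamma * real_loss - fake_loss)
    return kt, measure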
Example #21
                    type=str,
                    default='training_results/0000-00-00_00-00-00',
                    help='model save path.')
parser.add_argument('--save_latent',
                    action='store_true',
                    default=False,
                    help='save latent variables')

options = parser.parse_args()
print(options)

if __name__ == "__main__":
    test_info, test_data, _ = load_samples(options.data_path)
    model = load_model(options.model_path)
    predictions = model.predict(test_data)
    save_samples(os.path.join(options.save_path, 'recons'), predictions,
                 test_info)

    #print(test_data.shape)
    save_recon_error(options.save_path, test_data, predictions, test_info)

    if options.save_latent:
        encoder = Model(inputs=model.get_layer('sequential_1').model.input,
                        outputs=model.get_layer('sequential_1').model.output)
        latents = encoder.predict(test_data)
        save_latent_variables(os.path.join(options.save_path, 'latents'),
                              latents, test_info)

# ()()
# ('') HAANJU.YOO
Exemple #22
0
def train_wgan(model_gen,
               model_dis,
               train_gen,
               valid_data,
               test_data,
               num_epochs,
               batches_per_epoch,
               batch_size,
               output_dir=None,
               lmbda=0.1,
               use_cuda=True,
               discriminator_updates=5,
               epochs_per_sample=10,
               sample_size=20,
               lr=1e-4,
               beta_1=0.5,
               beta_2=0.9,
               latent_dim=100):

    if use_cuda:
        model_gen = model_gen.cuda()
        model_dis = model_dis.cuda()

    # Initialize optimizers for each model
    optimizer_gen = optim.Adam(model_gen.parameters(),
                               lr=lr,
                               betas=(beta_1, beta_2))
    optimizer_dis = optim.Adam(model_dis.parameters(),
                               lr=lr,
                               betas=(beta_1, beta_2))

    # Sample noise used for seeing the evolution of generated output samples throughout training
    sample_noise = torch.Tensor(sample_size, latent_dim).uniform_(-1, 1)
    if use_cuda:
        sample_noise = sample_noise.cuda()
    sample_noise_v = autograd.Variable(sample_noise)

    samples = {}
    history = []

    train_iter = iter(train_gen)
    valid_data_v = np_to_input_var(valid_data['X'], use_cuda)
    test_data_v = np_to_input_var(test_data['X'], use_cuda)

    # Loop over the dataset multiple times
    for epoch in range(num_epochs):
        LOGGER.info("Epoch: {}/{}".format(epoch + 1, num_epochs))

        epoch_history = []

        for batch_idx in range(batches_per_epoch):

            # Set model parameters to require gradients to be computed and stored
            for p in model_dis.parameters():
                p.requires_grad = True

            # Initialize the metrics for this batch
            batch_history = {'discriminator': [], 'generator': {}}

            # Discriminator Training Phase:
            # -> Train discriminator k times
            for iter_d in range(discriminator_updates):
                # Get real examples
                real_data_v = np_to_input_var(next(train_iter)['X'], use_cuda)

                # Get noise
                noise = torch.Tensor(batch_size, latent_dim).uniform_(-1, 1)
                if use_cuda:
                    noise = noise.cuda()
                noise_v = autograd.Variable(
                    noise, volatile=True)  # totally freeze model_gen

                # Compute discriminator loss terms on this training batch
                D_cost_train, D_wass_train = compute_discr_loss_terms(
                    model_dis,
                    model_gen,
                    real_data_v,
                    noise_v,
                    batch_size,
                    latent_dim,
                    lmbda,
                    use_cuda,
                    compute_grads=True)

                # Update the discriminator
                optimizer_dis.step()

                D_cost_valid, D_wass_valid = compute_discr_loss_terms(
                    model_dis,
                    model_gen,
                    valid_data_v,
                    noise_v,
                    batch_size,
                    latent_dim,
                    lmbda,
                    use_cuda,
                    compute_grads=False)

                if use_cuda:
                    D_cost_train = D_cost_train.cpu()
                    D_cost_valid = D_cost_valid.cpu()
                    D_wass_train = D_wass_train.cpu()
                    D_wass_valid = D_wass_valid.cpu()

                batch_history['discriminator'].append({
                    'cost':
                    D_cost_train.data.numpy()[0],
                    'wasserstein_cost':
                    D_wass_train.data.numpy()[0],
                    'cost_validation':
                    D_cost_valid.data.numpy()[0],
                    'wasserstein_cost_validation':
                    D_wass_valid.data.numpy()[0]
                })

            ############################
            # (2) Update G network
            ###########################

            # Prevent discriminator from computing gradients, since
            # we are only updating the generator
            for p in model_dis.parameters():
                p.requires_grad = False

            G_cost = compute_gener_loss_terms(model_dis,
                                              model_gen,
                                              batch_size,
                                              latent_dim,
                                              use_cuda,
                                              compute_grads=True)

            # Update generator
            optimizer_gen.step()

            if use_cuda:
                G_cost = G_cost.cpu()

            # Record generator loss
            batch_history['generator']['cost'] = G_cost.data.numpy()[0]

            # Record batch metrics
            epoch_history.append(batch_history)

        # Record epoch metrics
        history.append(epoch_history)

        LOGGER.info(pprint.pformat(epoch_history[-1]))

        if (epoch + 1) % epochs_per_sample == 0:
            # Generate outputs for fixed latent samples
            LOGGER.info('Generating samples...')
            samp_output = model_gen(sample_noise_v)
            if use_cuda:
                samp_output = samp_output.cpu()

            samples[epoch + 1] = samp_output.data.numpy()
            if output_dir:
                LOGGER.info('Saving samples...')
                save_samples(samples[epoch + 1], epoch + 1, output_dir)

    ## Get final discriminator loss
    # Get noise
    noise = torch.Tensor(batch_size, latent_dim).uniform_(-1, 1)
    if use_cuda:
        noise = noise.cuda()
    noise_v = autograd.Variable(noise,
                                volatile=True)  # totally freeze model_gen

    # Compute final discriminator loss terms on the held-out test set
    D_cost_test, D_wass_test = compute_discr_loss_terms(model_dis,
                                                        model_gen,
                                                        test_data_v,
                                                        noise_v,
                                                        batch_size,
                                                        latent_dim,
                                                        lmbda,
                                                        use_cuda,
                                                        compute_grads=False)

    D_cost_valid, D_wass_valid = compute_discr_loss_terms(model_dis,
                                                          model_gen,
                                                          valid_data_v,
                                                          noise_v,
                                                          batch_size,
                                                          latent_dim,
                                                          lmbda,
                                                          use_cuda,
                                                          compute_grads=False)

    if use_cuda:
        D_cost_test = D_cost_test.cpu()
        D_cost_valid = D_cost_valid.cpu()
        D_wass_test = D_wass_test.cpu()
        D_wass_valid = D_wass_valid.cpu()

    final_discr_metrics = {
        'cost_validation': D_cost_valid.data.numpy()[0],
        'wasserstein_cost_validation': D_wass_valid.data.numpy()[0],
        'cost_test': D_cost_test.data.numpy()[0],
        'wasserstein_cost_test': D_wass_test.data.numpy()[0],
    }

    return model_gen, model_dis, history, final_discr_metrics, samples
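
compute_discr_loss_terms is not shown in this listing, but the lmbda argument suggests it adds the standard WGAN-GP gradient penalty to the critic loss. A sketch of that term in current PyTorch:

import torch
import torch.autograd as autograd


def gradient_penalty(model_dis, real, fake, lmbda):
    # Random interpolates between real and fake samples, one alpha per sample.
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
    d_out = model_dis(x_hat)
    grads = autograd.grad(outputs=d_out, inputs=x_hat,
                          grad_outputs=torch.ones_like(d_out),
                          create_graph=True, retain_graph=True)[0]
    # Penalize the squared deviation of the per-sample gradient norm from 1.
    norms = grads.reshape(grads.size(0), -1).norm(2, dim=1)
    return lmbda * ((norms - 1) ** 2).mean()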
Exemple #23
0
def validate_decoder(target_path, generation_path):
    from numpy import linalg as LA
    import numpy as np

    target_info, targets, _ = load_samples(target_path)
    generate_info, generates, _ = load_samples(generation_path)

    if len(targets) != len(generates):
        return False

    for i in range(len(targets)):
        if LA.norm(np.subtract(targets[i], generates[i])) > 0.001:
            return False
    return True
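
validate_decoder compares the two sample sets one pair at a time; an equivalent vectorized check, as a sketch:

import numpy as np


def validate_decoder_vec(targets, generates, tol=1e-3):
    # Same check as validate_decoder above, vectorized: every per-sample
    # L2 distance must stay within the tolerance.
    t, g = np.asarray(targets), np.asarray(generates)
    if t.shape != g.shape:
        return False
    dists = np.linalg.norm(t.reshape(len(t), -1) - g.reshape(len(g), -1), axis=1)
    return bool(np.all(dists <= tol))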


if __name__ == "__main__":
    test_info, test_data = load_latent_vectors(options.data_path)
    model = load_model(options.model_path)

    decoder = Model(inputs=model.get_layer('sequential_2').model.input,
                    outputs=model.get_layer('sequential_2').model.output)
    generations = decoder.predict(test_data)
    save_samples(os.path.join(options.save_path, 'generates'), generations,
                 test_info)

    print(
        validate_decoder(os.path.join(options.save_path, 'recons'),
                         os.path.join(options.save_path, 'generates')))
Exemple #24
0
    def train(self, dataset):

        try:
            temp = set(tf.global_variables())
        except AttributeError:  # pre-1.0 TensorFlow lacks tf.global_variables
            temp = set(tf.all_variables())

        print("Creating optimizer")
        d_optim = tf.train.AdamOptimizer(Options.lrate_d, beta1=Options.beta1_d) \
              .minimize(self.d_loss, var_list=self.d_vars)
        g_optim = tf.train.AdamOptimizer(Options.lrate_g, beta1=Options.beta1_g) \
              .minimize(self.g_loss, var_list=self.g_vars)
        print("Optimizer created")

        with tf.Session() as sess:
            self.sess = sess
            try:
                print("Restoring from checkpoint")
                _dir = os.path.join(Options.checkpoint_dir, self.prefix())
                a = glob.glob(os.path.join(_dir, '*/'))
                a = [int(i.split('/')[-2]) for i in a]
                a.sort()
                _dir = os.path.join(_dir, str(a[-1]), 'model.ckpt')
                self.saver.restore(self.sess, _dir)
                print("Model restored")
                try:
                    self.sess.run(
                        tf.variables_initializer(
                            set(tf.global_variables()) - temp))
                except AttributeError:  # fall back to the pre-1.0 initializer API
                    self.sess.run(
                        tf.initialize_variables(
                            set(tf.all_variables()) - temp))
            except:  # no usable checkpoint; initialize from scratch instead
                print("Initializing")
                try:
                    tf.global_variables_initializer().run()
                except AttributeError:  # pre-1.0 initializer API
                    tf.initialize_all_variables().run()
                print("Initialized")

            print("Merging summaries")
            self.g_sum = md.merge_summary([
                self.z_sum, self.d__sum, self.d_loss_fake_sum, self.g_loss_sum
            ])
            self.d_sum = md.merge_summary([
                self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum
            ])
            self.writer = md.SummaryWriter("./logs", self.sess.graph)
            print("Summaries merged")

            sample_z = np.random.uniform(-1,
                                         1,
                                         size=(self.sample_size, self.z_dim))

            counter = 0
            terrD_fake = 0.0
            terrD_real = 0.0
            terrG = 0.0

            s_begin = time.time()
            c_begin = time.time()
            p_begin = time.time()

            print("Starting training epoch")
            for epoch in range(Options.train_epochs):
                for sub_data in dataset.train_iter():

                    batch_z = np.random.uniform(
                        -1, 1,
                        [Options.batch_size, self.z_dim]).astype(np.float32)

                    # Update D network
                    _, summary_str = self.sess.run([d_optim, self.d_sum],
                                                   feed_dict={
                                                       self.videos: sub_data,
                                                       self.z: batch_z
                                                   })
                    self.writer.add_summary(summary_str, counter)

                    # Update G network
                    _, summary_str = self.sess.run([g_optim, self.g_sum],
                                                   feed_dict={self.z: batch_z})
                    self.writer.add_summary(summary_str, counter)

                    errD_fake = self.d_loss_fake.eval({self.z: batch_z})
                    errD_real = self.d_loss_real.eval({self.videos: sub_data})
                    errG = self.g_loss.eval({self.z: batch_z})

                    terrD_real += errD_real
                    terrD_fake += errD_fake
                    terrG += errG

                    counter += 1

                    if time.time() - p_begin > Options.print_time:
                        print(
                            "Epoch: [%d], d_loss_fake: [%.6f]--[%.4f], d_loss_real: [%.6f]--[%.4f], g_loss: [%.6f]--[%.4f]"
                            % (epoch, terrD_fake / counter, errD_fake,
                               terrD_real / counter, errD_real,
                               terrG / counter, errG))
                        p_begin = time.time()

                    if time.time() - s_begin > Options.sampler_time:
                        samples, d_loss, g_loss = self.sess.run(
                            [self.sampler, self.d_loss, self.g_loss],
                            feed_dict={
                                self.z: sample_z,
                                self.videos: sub_data
                            })
                        utils.save_samples(samples, epoch, counter)
                        print("[Sample] d_loss: %.8f, g_loss: %.8f" %
                              (d_loss, g_loss))
                        s_begin = time.time()

                    if time.time() - c_begin > Options.checkpoint_time:
                        print("Checkpointing")
                        if not os.path.exists(Options.checkpoint_dir):
                            os.makedirs(Options.checkpoint_dir)
                        _dir = os.path.join(Options.checkpoint_dir,
                                            self.prefix())
                        self.save(_dir, epoch, self.sess)
                        c_begin = time.time()

                        counter = 0
                        terrD_fake = 0.0
                        terrD_real = 0.0
                        terrG = 0.0
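
The restore block above locates the newest numbered checkpoint directory with glob and a manual sort. The same lookup as a small helper, a sketch:

import glob
import os


def latest_checkpoint_path(base_dir):
    # Numbered epoch subdirectories live under base_dir; the highest one wins.
    # Returns None when nothing has been saved yet.
    runs = sorted(int(p.rstrip('/').split('/')[-1])
                  for p in glob.glob(os.path.join(base_dir, '*/')))
    return os.path.join(base_dir, str(runs[-1]), 'model.ckpt') if runs else None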
Exemple #25
0
    def train(self):
        """
        Draws samples from the data distribution and the noise distribution,
        and alternates between optimizing the parameters of the discriminator
        and the generator.
        """
        # TODO add loop over epochs and batches instead single iterations
        with tf.Session() as session:
            tf.global_variables_initializer().run()

            if pre_train:
                # pretraining discriminator
                for step in range(self.pre_train_steps):
                    # d = (np.random.random(self.batch_size) - 0.5) * 10.0
                    d = self.data[self.pre_train_steps + step]
                    # labels = norm.pdf(d, loc=self.data.mu, scale=self.data.sigma)
                    # labels = self.data[step]
                    pretrain_loss, _ = session.run(
                        [self.pre_loss, self.pre_opt], {
                            self.pre_input: np.reshape(d, self.input_shape)
                        })
                self.weightsD = session.run(self.d_pre_params)

                # copy weights from pre-training over to new D network
                for i, v in enumerate(self.d_params):
                    session.run(tf.assign(v, self.weightsD[i]))
                    # session.run(v.assign(self.weightsD[i]))
            writer = tf.summary.FileWriter("./logs", session.graph)

            g_sum = tf.summary.merge(
                [self.z_sum, self.d2_sum, self.G_sum, self.loss_d_fake_sum,
                 self.loss_g_sum])
            d_sum = tf.summary.merge(
                [self.z_sum, self.d1_sum, self.loss_d_real_sum,
                 self.loss_d_sum])

            for step in range(self.num_steps):
                print('iteration at step {}'.format(step))
                # update discriminator
                # x = np.sort(self.data[self.num_steps + step])
                x = self.data[step]
                # TODO try different noise sources
                # z = self.gen.sample_int(self.input_shape[0])
                z = self.gen.binned_samples(self.input_shape)
                loss_d, summary, _ = session.run(
                    [self.loss_d, d_sum, self.opt_d],
                    {
                     self.x: x[np.newaxis, :, :, np.newaxis],
                     self.z: z,
                     self.is_training: True
                    })
                writer.add_summary(summary, step)

                # update generator
                # z = self.gen.sample_int(self.input_shape[0])
                z = self.gen.binned_samples(self.input_shape)
                loss_g, summary, _ = session.run(
                    [self.loss_g, g_sum, self.opt_g],
                    {
                     self.z: np.reshape(z, self.input_shape),
                     self.is_training: True
                    })
                writer.add_summary(summary, step)

                errD_fake = self.loss_d_fake.eval(
                    {self.z: z, self.is_training: False})
                errD_real = self.loss_d_real.eval(
                    {self.x: x[np.newaxis, :, :, np.newaxis], self.is_training: False})
                errG = self.loss_g.eval(
                    {self.z: z, self.is_training: False})
                print('loss_d {}, errD_real+errD_fake {}'.format(loss_d, errD_real + errD_fake))
                print('loss_g {}, errG {}'.format(loss_g, errG))

                self.loss_d_plot.append(loss_d)
                self.loss_g_plot.append(loss_g)
                if step % self.log_every == 0:
                    # print('{}: loss_d: {}\t loss_g: {}'.format(step, loss_d,
                    #                                            loss_g))
                    samples, d_loss, g_loss = session.run(
                        [self.G, self.loss_d, self.loss_g],
                        feed_dict={self.z: z,
                                   self.x: x[np.newaxis, :, :, np.newaxis],
                                   self.is_training: False}
                    )
                    now = datetime.datetime.now().isoformat()
                    save_samples(samples, self.samples_dir,
                                 'train_{}_{}.npy'.format(step, now))

                if self.anim_path:
                    self.anim_frames.append(plots.samples(session, save=False,
                                                          lower_range=self.gen.lower_range,
                                                          upper_range=self.gen.upper_range,
                                                          batch_size=self.input_shape[0],
                                                          D1=self.D1,
                                                          G=self.G,
                                                          x=self.x,
                                                          data=self.data,
                                                          z=self.z))

                if (step % 500) == 0:
                    save(self.checkpoint_dir, step, self.saver, session,
                         'GAN.model')

            if self.anim_path:
                plots.save_animation(self.anim_path, self.anim_frames,
                                     lower_range=self.gen.lower_range,
                                     upper_range=self.gen.upper_range)
            else:
                # TODO find a proper plotting mechanism
                pass
                # plots.plot_distributions(session, save=False,
                #                          lower_range=self.gen.lower_range,
                #                          upper_range=self.gen.upper_range,
                #                          batch_size=self.input_shape[0],
                #                          D1=self.D1,
                #                          G=self.G,
                #                          x=self.x,
                #                          z=self.z,
                #                          data=self.data)
                # plots.plot_training_loss(self.loss_g_plot, self.loss_d_plot)

            # merged = tf.summary.merge_all()
            # writer = tf.summary.FileWriter('logs', session.graph)
            writer.close()
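
The pretraining block above copies the pretrained discriminator weights with one session.run per tf.assign. A TF1-style sketch that batches every assignment into a single run:

import tensorflow as tf


def copy_weights(session, variables, values):
    # One session.run for all assignments instead of one run per variable.
    session.run([tf.assign(var, val) for var, val in zip(variables, values)])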
Exemple #26
0
def main():
    data_name = "augmented_1"
    data_path = os.path.join("./data", data_name)
    csv_name = data_name + ".csv"
    train_df = pd.read_csv(os.path.join(data_path, csv_name))

    keypoint_names = list(
        map(lambda x: x[:-2],
            train_df.columns.to_list()[1::2]))
    keypoint_flip_map = [
        ("left_eye", "right_eye"),
        ("left_ear", "right_ear"),
        ("left_shoulder", "right_shoulder"),
        ("left_elbow", "right_elbow"),
        ("left_wrist", "right_wrist"),
        ("left_hip", "right_hip"),
        ("left_knee", "right_knee"),
        ("left_ankle", "right_ankle"),
        ("left_palm", "right_palm"),
        ("left_instep", "right_instep"),
    ]

    image_list = train_df.iloc[:, 0].to_numpy()
    keypoints_list = train_df.iloc[:, 1:].to_numpy()
    train_imgs, valid_imgs, train_keypoints, valid_keypoints = train_val_split(
        image_list, keypoints_list, random_state=42)

    image_set = {"train": train_imgs, "valid": valid_imgs}
    keypoints_set = {"train": train_keypoints, "valid": valid_keypoints}

    hyper_params = {
        "augmented_ver": data_name,
        "learning_rate": 0.001,
        "num_epochs": 10000,
        "batch_size": 256,
        "description": "Final training"
    }

    for phase in ["train", "valid"]:
        DatasetCatalog.register(
            "keypoints_" + phase,
            lambda phase=phase: get_data_dicts(data_path, image_set[phase],
                                               keypoints_set[phase]))
        MetadataCatalog.get("keypoints_" + phase).set(thing_classes=["human"])
        MetadataCatalog.get("keypoints_" +
                            phase).set(keypoint_names=keypoint_names)
        MetadataCatalog.get("keypoints_" +
                            phase).set(keypoint_flip_map=keypoint_flip_map)
        MetadataCatalog.get("keypoints_" + phase).set(evaluator_type="coco")

    cfg = get_cfg()
    cfg.merge_from_file(
        model_zoo.get_config_file(
            "COCO-Keypoints/keypoint_rcnn_X_101_32x8d_FPN_3x.yaml"))
    cfg.DATASETS.TRAIN = ("keypoints_train", )
    cfg.DATASETS.TEST = ("keypoints_valid", )
    cfg.DATALOADER.NUM_WORKERS = 16  # On Windows, this must be 0.
    cfg.SOLVER.IMS_PER_BATCH = 2  # Effective mini-batch size is SOLVER.IMS_PER_BATCH * ROI_HEADS.BATCH_SIZE_PER_IMAGE.
    cfg.SOLVER.BASE_LR = hyper_params["learning_rate"]  # Learning Rate.
    cfg.SOLVER.MAX_ITER = hyper_params["num_epochs"]  # Max iteration.
    cfg.SOLVER.GAMMA = 0.8
    cfg.SOLVER.STEPS = [
        3000, 4000, 5000, 6000, 7000, 8000
    ]  # The iteration number to decrease learning rate by GAMMA.
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(
        "COCO-Keypoints/keypoint_rcnn_X_101_32x8d_FPN_3x.yaml")
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = hyper_params[
        "batch_size"]  # Use to calculate RPN loss.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
    cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS = 24
    cfg.TEST.KEYPOINT_OKS_SIGMAS = np.ones(24, dtype=float).tolist()  # flat list of K sigmas; a (24, 1) shape breaks COCO keypoint evaluation
    cfg.TEST.EVAL_PERIOD = 5000  # Evaluation runs every cfg.TEST.EVAL_PERIOD iterations.
    cfg.OUTPUT_DIR = os.path.join("./output", data_name)

    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=False)
    trainer.train()

    # Inference should use the config with parameters that are used in training
    # cfg now already contains everything we've set previously. We changed it a little bit for inference:
    cfg.MODEL.WEIGHTS = os.path.join(
        cfg.OUTPUT_DIR, "model_final.pth")  # path to the model we just trained
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7  # set a custom testing threshold
    predictor = DefaultPredictor(cfg)

    test_dir = os.path.join("data", "test_imgs")
    test_list = os.listdir(test_dir)
    test_list.sort()
    except_list = []

    files = []
    preds = []
    for file in tqdm(test_list):
        filepath = os.path.join(test_dir, file)
        im = cv2.imread(filepath)
        outputs = predictor(im)
        outputs = outputs["instances"].to("cpu").get("pred_keypoints").numpy()
        files.append(file)
        pred = []
        try:
            for out in outputs[0]:
                pred.extend([float(e) for e in out[:2]])
        except IndexError:
            pred.extend([0] * 48)
            except_list.append(filepath)
        preds.append(pred)

    df_sub = pd.read_csv("./data/sample_submission.csv")
    df = pd.DataFrame(columns=df_sub.columns)
    df["image"] = files
    df.iloc[:, 1:] = preds

    df.to_csv(os.path.join(cfg.OUTPUT_DIR, f"{data_name}_submission.csv"),
              index=False)
    if except_list:
        print(
            "No keypoints were detected in the following images; "
            "their rows in the submission are filled with 0."
        )
        print(*except_list)
    save_samples(cfg.OUTPUT_DIR,
                 test_dir,
                 os.path.join(cfg.OUTPUT_DIR, f"{data_name}_submission.csv"),
                 mode="random",
                 size=5)
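
The try/except around outputs[0] above flattens the keypoints of the first detected instance and falls back to 48 zeros. The same logic as a standalone helper, a sketch:

import numpy as np


def keypoints_to_row(pred_keypoints, num_keypoints=24):
    # pred_keypoints has shape (num_instances, K, 3); keep x, y of the first
    # instance as [x1, y1, ..., xK, yK], or all zeros when nothing was detected.
    if len(pred_keypoints) == 0:
        return [0.0] * (num_keypoints * 2)
    return [float(v) for v in np.asarray(pred_keypoints)[0, :, :2].reshape(-1)]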
Exemple #27
0
def train(opt, dataloader_X, dataloader_Y, test_dataloader_X,
          test_dataloader_Y):

    input_shape = (None, None, opt["channels"])

    norm_type = 'batchnorm'
    if opt["batch_size"] == 1:
        norm_type = 'instancenorm'

    # models
    G_YtoX = model.Generator(input_shape, opt, norm_type)
    G_XtoY = model.Generator(input_shape, opt, norm_type)
    D_X = model.Discriminator(input_shape, norm_type)
    D_Y = model.Discriminator(input_shape, norm_type)

    # summary
    G_YtoX.summary()
    D_X.summary()

    # optimizers
    G_YtoX_optimizer = tf.optimizers.Adam(opt["lr_g"], beta_1=opt["b1"])
    G_XtoY_optimizer = tf.optimizers.Adam(opt["lr_g"], beta_1=opt["b1"])

    D_X_optimizer = tf.optimizers.Adam(opt["lr_d"], beta_1=opt["b1"])
    D_Y_optimizer = tf.optimizers.Adam(opt["lr_d"], beta_1=opt["b1"])

    # set checkpoint
    ckpt, ckpt_manager = model.set_checkpoint(opt, G_YtoX, G_XtoY, D_X, D_Y,
                                              G_YtoX_optimizer,
                                              G_XtoY_optimizer, D_X_optimizer,
                                              D_Y_optimizer)

    test_X = next(iter(test_dataloader_X))
    test_Y = next(iter(test_dataloader_Y))

    iterations = 0
    total_disc_loss_list, total_gen_loss_list, test_loss_list = utils.load_losses_list(
        opt)
    mean_cycle_loss_test = 0
    count = 0

    for epoch in range(opt["epoch"], opt["n_epochs"] + 1):
        start = time.time()

        for image_x, image_y in tf.data.Dataset.zip(
            (dataloader_X, dataloader_Y)):

            total_disc_loss, total_gen_loss = train_step(
                image_x, image_y, G_YtoX, G_XtoY, D_X, D_Y, G_YtoX_optimizer,
                G_XtoY_optimizer, D_X_optimizer, D_Y_optimizer, opt)
            total_disc_loss_list.append(total_disc_loss)
            total_gen_loss_list.append(total_gen_loss)
            if iterations % 200 == 0:
                print('iteration:{}'.format(iterations))
                # Using a consistent image so that the progress of the model is visible.
                cycle_loss_test = utils.save_samples(epoch, test_Y, test_X,
                                                     G_YtoX, G_XtoY, opt)
                mean_cycle_loss_test += cycle_loss_test
                count += 1
                utils.plot_losses(total_disc_loss_list, total_gen_loss_list,
                                  opt, epoch, "Epoch:{}".format(epoch))

            iterations += 1

        clear_output(wait=True)
        # Using a consistent image so that the progress of the model is visible; only one plot is kept per epoch.
        cycle_loss_test = utils.save_samples(epoch, test_Y, test_X, G_YtoX,
                                             G_XtoY, opt)
        mean_cycle_loss_test += cycle_loss_test
        count += 1
        test_loss_list.append(mean_cycle_loss_test / count)
        mean_cycle_loss_test = 0
        count = 0

        utils.plot_losses(total_disc_loss_list, total_gen_loss_list, opt,
                          epoch, "Epoch:{}".format(epoch))
        utils.plot_test_loss(test_loss_list, opt, epoch)
        utils.save_losses_list(total_disc_loss_list, total_gen_loss_list,
                               test_loss_list, opt)

        print("Sample saved...")
        if (epoch + 1) % opt["checkpoint_interval"] == 0:
            ckpt_save_path = ckpt_manager.save()
            print('Saving checkpoint for epoch {} at {}'.format(
                epoch, ckpt_save_path))

        print('Time taken for epoch {} is {} sec\n'.format(
            epoch,
            time.time() - start))
    return G_YtoX, G_XtoY, D_X, D_Y
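
train_step and the loss that utils.save_samples reports are not shown in this listing; for a CycleGAN the test metric is presumably the cycle-consistency loss. A sketch with the usual weighting:

import tensorflow as tf


def cycle_consistency_loss(real_image, cycled_image, lam=10.0):
    # L1 distance between an image and its round trip through both generators,
    # weighted by lam (10 is the CycleGAN paper's default).
    return lam * tf.reduce_mean(tf.abs(real_image - cycled_image))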
Exemple #28
0
        norm_data = np.empty((len(encoded_data), 1, imageSize, imageSize),
                             dtype=np.float32)
        norm_value = []
        for i, ed in enumerate(encoded_data):
            dat = np.array(ed).reshape((1, imageSize, imageSize))
            norm_data[i] = np.divide(dat, np.max(dat))
            norm_value.append(np.max(dat))
        save_dict['normed_data'] = norm_data
        save_dict['num_samples'] = num_samples
        save_dict['imageSize'] = imageSize
        save_dict['spikes'] = raw_data[:100]
        save_dict['encoded_data'] = encoded_data
        save_dict['normed_values'] = norm_value

    else:
        save_dict['binned_data'] = binned_data
        save_dict['normed_data'] = norm_data
        save_dict['num_samples'] = num_samples
        save_dict['imageSize'] = imageSize
        save_dict['spikes'] = raw_data
        save_dict['data_type'] = opt.data_type

    if opt.filename:
        fname = opt.filename
    else:
        fname = 'data_NS{}_IS{}_type-{}_encoded-{}_rate{}.npy'.format(
            num_samples, imageSize, opt.data_type, opt.encoding, ARRAY_ID)
    u = str(uuid.uuid4())
    path = os.path.join(opt.path, u)
    utils.save_samples(save_dict, path=path, filename=fname)
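
The encoded branch above normalizes each sample by its own maximum and stores that maximum in normed_values. Undoing the normalization later would look like this sketch:

import numpy as np


def denormalize(normed_data, normed_values):
    # Each sample was divided by its own maximum; multiply it back in.
    return np.stack([sample * peak
                     for sample, peak in zip(normed_data, normed_values)])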
Exemple #29
0
    def train(self):

        opti_D = tf.train.AdamOptimizer(learning_rate=self.learn_rate,
                                        beta1=0.5).minimize(
                                            self.D_loss, var_list=self.d_var)
        opti_G = tf.train.AdamOptimizer(learning_rate=self.learn_rate,
                                        beta1=0.5).minimize(
                                            self.G_loss, var_list=self.g_var)
        init = tf.global_variables_initializer()

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            sess.run(init)
            summary_writer = tf.summary.FileWriter(self.log_dir,
                                                   graph=sess.graph)

            step = 0
            while step <= 10000:

                realbatch_array, real_labels = self.data_ob.getNext_batch(step)

                # Get the z
                batch_z = np.random.uniform(-1,
                                            1,
                                            size=[self.batch_size, self.z_dim])

                _, summary_str = sess.run(
                    [opti_D, self.merged_summary_op_d],
                    feed_dict={
                        self.real_feature_vector: realbatch_array,
                        self.z: batch_z,
                        self.y: real_labels
                    })
                summary_writer.add_summary(summary_str, step)

                _, summary_str = sess.run([opti_G, self.merged_summary_op_g],
                                          feed_dict={
                                              self.z: batch_z,
                                              self.y: real_labels
                                          })
                summary_writer.add_summary(summary_str, step)

                if step % 50 == 0:

                    D_loss = sess.run(self.D_loss,
                                      feed_dict={
                                          self.real_feature_vector:
                                          realbatch_array,
                                          self.z: batch_z,
                                          self.y: real_labels
                                      })
                    fake_loss = sess.run(self.G_loss,
                                         feed_dict={
                                             self.z: batch_z,
                                             self.y: real_labels
                                         })
                    print("Step %d: D: loss = %.7f G: loss=%.7f " %
                          (step, D_loss, fake_loss))

                if np.mod(step, 50) == 1 and step != 0:
                    sample_exec_time = sample_output(
                        batch_size=self.batch_size,
                        min=self.y_min,
                        max=self.y_max)
                    sample_feature = sess.run(self.fake_feature_vector,
                                              feed_dict={
                                                  self.z: batch_z,
                                                  self.y: sample_exec_time
                                              })
                    save_samples(sample_feature, sample_exec_time)
                    self.saver.save(sess, self.model_path)

                step = step + 1

            save_path = self.saver.save(sess, self.model_path)
            print("Model saved in file: %s" % save_path)
    def train(self, ImageModel, netG, netD, audio_encoder, ImageGeneratorModel,
              ImageDiscriminatorModel, args):
        if self.use_cuda:
            netG.cuda()
            netD.cuda()
            ImageModel.cuda()
            audio_encoder.cuda()
            ImageGeneratorModel.cuda()
            ImageDiscriminatorModel.cuda()

        logger = Logger(self.log_folder)
        ##audio part
        optimizerG = optim.Adam(netG.parameters(),
                                lr=float(args['--learning_rate']),
                                betas=(float(args['--beta1']),
                                       float(args['--beta2'])))
        optimizerD = optim.Adam(netD.parameters(),
                                lr=float(args['--learning_rate']),
                                betas=(float(args['--beta1']),
                                       float(args['--beta2'])))
        opt_generator = optim.Adam(ImageModel.parameters(),
                                   lr=0.0002,
                                   betas=(0.5, 0.999),
                                   weight_decay=0.00001)

        optimizerG_Image = optim.Adam(ImageGeneratorModel.parameters(),
                                      lr=float(args['--learning_rate']),
                                      betas=(float(args['--beta1']),
                                             float(args['--beta2'])))
        optimizerD_image = optim.Adam(ImageDiscriminatorModel.parameters(),
                                      lr=float(args['--learning_rate']),
                                      betas=(float(args['--beta1']),
                                             float(args['--beta2'])))
        opt_encoder = optim.Adam(audio_encoder.parameters(),
                                 lr=float(args['--learning_rate']),
                                 betas=(float(args['--beta1']),
                                        float(args['--beta2'])))

        # opt_encoder = optim.Adam(audio_encoder.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=0.00001)
        # =============Train===============
        lmbda = float(args['--lmbda'])
        epochs_per_sample = int(args['--epochs_per_sample'])
        output_dir = 'outputs'
        history = []
        D_costs_train = []
        D_wasses_train = []
        D_costs_valid = []
        D_wasses_valid = []
        G_costs = []
        # Image-side (D2/G2) histories; they are appended to at the end of every
        # epoch below, so they must be initialized here as well.
        D2_costs_train = []
        D2_wasses_train = []
        D2_costs_valid = []
        D2_wasses_valid = []
        G2_costs = []
        BATCH_NUM = 350
        epochs = 1800
        start = time.time()
        self.LOGGER_audio.info(
            'Starting training...EPOCHS={}, BATCH_SIZE={}, BATCH_NUM={}'.
            format(epochs, self.audio_batch_size, BATCH_NUM))
        self.LOGGER_image.info(
            'Starting training...EPOCHS={}, BATCH_SIZE={}, BATCH_NUM={}'.
            format(epochs, self.image_batch_size, BATCH_NUM))

        # training loop

        def sample_fake_image_batch(batch_size, real_img):
            return generator.sample_images(batch_size, real_img)

        def sample_fake_video_batch(batch_size, real_video):
            return generator.sample_videos(batch_size, real_video)

        def init_logs():
            return {'l_gen': 0, 'l_image_dis': 0, 'l_video_dis': 0}

        batch_num = 0

        logs = init_logs()

        start_time = time.time()
        epoch = 0
        gc.collect()
        while True:
            gc.collect()
            epoch = epoch + 1
            self.LOGGER_audio.info("{} Epoch: {}/{}".format(
                time_since(start), epoch, epochs))
            D_cost_train_epoch = []
            D_wass_train_epoch = []
            D_cost_valid_epoch = []
            D_wass_valid_epoch = []
            G_cost_epoch = []

            self.LOGGER_image.info("{} Epoch: {}/{}".format(
                time_since(start), epoch, epochs))

            D2_cost_train_epoch = []
            D2_wass_train_epoch = []
            D2_cost_valid_epoch = []
            D2_wass_valid_epoch = []
            G2_cost_epoch = []

            for i in range(1, BATCH_NUM + 1):
                # sample real data
                sample_real_batch = self.sample_real_image_batch()
                batch_real_audio = sample_real_batch['audio'].cpu()
                batch_image = Variable(sample_real_batch['images'],
                                       requires_grad=False)
                # print(batch_image.shape)
                batch_image = batch_image.cuda()

                for p in netD.parameters():
                    p.requires_grad = True

                one = torch.Tensor([1]).float()
                neg_one = one * -1
                if self.use_cuda:
                    one = one.cuda()
                    neg_one = neg_one.cuda()
                # (1) Train Discriminator

                for iter_dis in range(1):
                    netD.zero_grad()

                    z = nn.init.normal(torch.Tensor(self.image_batch_size,
                                                    512))
                    if self.use_cuda:
                        z = z.cuda()
                    z = Variable(z)

                    real_data_Var = numpy_to_var(batch_real_audio,
                                                 self.use_cuda)
                    # print(batch_image.shape)
                    # print(type(batch_image))
                    batch_image = batch_image.cuda()

                    real_img_Var = Variable(batch_image)

                    # a) compute loss contribution from real training data
                    D_real = netD(real_data_Var)
                    D_real = D_real.mean()
                    #print('D_real',D_real)  # avg loss
                    D_real.backward(neg_one)  # loss * -1

                    #print('real_img_Var',real_img_Var.shape)
                    D_real_img = ImageDiscriminatorModel(real_img_Var)
                    D_real_img = D_real_img.mean()  # avg loss
                    #print('D_real_img',D_real_img)
                    # NOTE: re-wrapping .data in a new Variable detaches the graph,
                    # so this backward() pushes no gradient into ImageDiscriminatorModel.
                    D_real_img = Variable(D_real_img.data,
                                          requires_grad=True)  #Added
                    D_real_img.backward(neg_one)  # loss * -1

                    # b) compute loss contribution from generated data, then backprop.
                    features = ImageModel(batch_image)
                    fk_audio = netG(z, features)
                    fk_audio = autograd.Variable(fk_audio.data)
                    #print(fk_audio.shape)(16, 1, 16384)

                    D_fake = netD(fk_audio)
                    D_fake = D_fake.mean()
                    D_fake.backward(one)

                    #print(batch_real_audio.shape)#16*16384
                    batch_real_audio = batch_real_audio.unsqueeze(1)
                    #print(batch_real_audio.shape)#16*1*16384
                    # print(fk_audio.shape)
                    # print(type(fk_audio))
                    # print(batch_real_audio.shape)
                    # print(type(batch_real_audio))
                    # print('audio_encoder',audio_encoder)
                    audio_features = audio_encoder(batch_real_audio)
                    #print('audio_features',audio_features.shape)
                    #audio_features=audio_encoder(fk_audio)
                    audio_features = audio_features.unsqueeze(2).unsqueeze(3)
                    fk_image = ImageGeneratorModel(audio_features)
                    fk_image = autograd.Variable(fk_image.data)
                    D_fake_image = ImageDiscriminatorModel(fk_image)
                    D_fake_image = D_fake_image.mean()
                    # Same caveat: the .data re-wrap below blocks gradient flow into
                    # ImageDiscriminatorModel for the fake-image term as well.
                    D_fake_image = Variable(D_fake_image.data,
                                            requires_grad=True)  #Added
                    D_fake_image.backward(one)
                    #print('real_img_Var',type(real_img_Var.data))#16x3x224x224

                    # c) compute gradient penalty and backprop
                    gradient_penalty = calc_gradient_penalty(
                        netD,
                        real_data_Var.data,
                        fk_audio.data,
                        self.image_batch_size,
                        lmbda,
                        use_cuda=self.use_cuda)
                    gradient_penalty.backward(one)

                    ################################# 16*3*224*224 to 16*3*64*64 or do nn.AvgPool2d
                    LRTrans = transforms.Compose([
                        transforms.Scale(64, Image.BICUBIC),  # Scale is the legacy name for transforms.Resize
                        transforms.ToTensor(),
                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                    ])
                    real_img_Var_64 = torch.zeros([16, 3, 64, 64])
                    for j in range(16):
                        #print('j',j)
                        real_img_Var_64[j] = torch.FloatTensor(
                            LRTrans(
                                Image.fromarray(real_img_Var.data.cpu().mul(
                                    0.5).add(0.5).mul(255).byte()[j].transpose(
                                        0, 2).transpose(0,
                                                        1).numpy())).numpy())
                    #print('testing',real_img_Var.data.cpu().mul(0.5).add(0.5).mul(255).byte()[j].shape)
                    #print("real_img_Var_64",real_img_Var_64.shape)
                    real_img_Var_64 = real_img_Var_64.cuda()
                    ######################################
                    gradient_penalty_2 = calc_gradient_penalty_2(
                        ImageDiscriminatorModel,
                        real_img_Var_64.data,
                        fk_image.data,
                        self.audio_batch_size,
                        lmbda,
                        use_cuda=self.use_cuda)
                    gradient_penalty_2.backward(one)

                    # Compute cost * Wassertein loss..
                    D_cost_train = D_fake - D_real + gradient_penalty
                    D_wass_train = D_real - D_fake

                    D2_cost_train = D_fake_image - D_real_img + gradient_penalty_2
                    D2_wass_train = D_real_img - D_fake_image

                    # Update gradient of discriminator.
                    optimizerD.step()

                    optimizerD_image.step()

                    #############################
                    # (2) Compute Valid data
                    #############################
                    netD.zero_grad()
                    ImageDiscriminatorModel.zero_grad()
                    batch_real_audio = batch_real_audio.squeeze(1)
                    valid_data_Var = numpy_to_var(batch_real_audio,
                                                  self.use_cuda)
                    D_real_valid = netD(valid_data_Var)
                    D_real_valid = D_real_valid.mean()  # avg loss

                    #valid_data_Var_2 = numpy_to_var(batch_image, self.use_cuda)# can substitute this with below two lines
                    batch_image = batch_image.cuda()
                    valid_data_Var_2 = Variable(batch_image)

                    D2_real_valid = ImageDiscriminatorModel(valid_data_Var_2)
                    D2_real_valid = D2_real_valid.mean()  # avg loss

                    # b) compute loss contribution from generated data, then backprop.
                    fake_valid = netG(z, features)
                    D_fake_valid = netD(fake_valid)
                    D_fake_valid = D_fake_valid.mean()

                    fake_valid_2 = ImageGeneratorModel(audio_features)
                    D2_fake_valid = ImageDiscriminatorModel(fake_valid_2)
                    D2_fake_valid = D2_fake_valid.mean()

                    # c) compute gradient penalty and backprop
                    gradient_penalty_valid = calc_gradient_penalty(
                        netD,
                        valid_data_Var.data,
                        fake_valid.data,
                        self.image_batch_size,
                        lmbda,
                        use_cuda=self.use_cuda)

                    ################################# 16*3*224*224 to 16*3*64*64 or do nn.AvgPool2d
                    LRTrans = transforms.Compose([
                        transforms.Scale(64, Image.BICUBIC),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                    ])
                    valid_data_Var_2_64 = torch.zeros([16, 3, 64, 64])
                    for j in range(16):
                        valid_data_Var_2_64[j] = torch.FloatTensor(
                            LRTrans(
                                Image.fromarray(
                                    valid_data_Var_2.data.cpu().mul(0.5).add(
                                        0.5).mul(255).byte()[j].transpose(
                                            0,
                                            2).transpose(0,
                                                         1).numpy())).numpy())
                    #print('testing',real_img_Var.data.cpu().mul(0.5).add(0.5).mul(255).byte()[j].shape)
                    #print("valid_data_Var_2_64",valid_data_Var_2_64.shape)
                    valid_data_Var_2_64 = valid_data_Var_2_64.cuda()
                    ######################################

                    gradient_penalty_valid_2 = calc_gradient_penalty_2(
                        ImageDiscriminatorModel,
                        valid_data_Var_2_64.data,
                        fake_valid_2.data,
                        self.audio_batch_size,
                        lmbda,
                        use_cuda=self.use_cuda)

                    # Compute metrics and record in batch history.
                    D_cost_valid = D_fake_valid - D_real_valid + gradient_penalty_valid
                    D_wass_valid = D_real_valid - D_fake_valid

                    D2_cost_valid = D2_fake_valid - D2_real_valid + gradient_penalty_valid_2
                    D2_wass_valid = D2_real_valid - D2_fake_valid

                    if self.use_cuda:
                        D_cost_train = D_cost_train.cpu()
                        D_wass_train = D_wass_train.cpu()
                        D_cost_valid = D_cost_valid.cpu()
                        D_wass_valid = D_wass_valid.cpu()

                        D2_cost_train = D2_cost_train.cpu()
                        D2_wass_train = D2_wass_train.cpu()
                        D2_cost_valid = D2_cost_valid.cpu()
                        D2_wass_valid = D2_wass_valid.cpu()

                    # Record costs
                    D_cost_train_epoch.append(D_cost_train.data.numpy())
                    D_wass_train_epoch.append(D_wass_train.data.numpy())
                    D_cost_valid_epoch.append(D_cost_valid.data.numpy())
                    D_wass_valid_epoch.append(D_wass_valid.data.numpy())

                    D2_cost_train_epoch.append(D2_cost_train.data.numpy())
                    D2_wass_train_epoch.append(D2_wass_train.data.numpy())
                    D2_cost_valid_epoch.append(D2_cost_valid.data.numpy())
                    D2_wass_valid_epoch.append(D2_wass_valid.data.numpy())

                    #############################
                # (3) Train Generator
                #############################
                # Prevent discriminator update.
                for p in netD.parameters():
                    p.requires_grad = False

                for p in ImageDiscriminatorModel.parameters():
                    p.requires_grad = False

                # Reset generator gradients
                netG.zero_grad()
                fk_audio = netG(z, features)
                # fake = autograd.Variable(fk_img.data)
                # fake = netG(fk_img2)
                G = netD(fk_audio)
                G = G.mean()
                # print('audio_cond',audio_cond.shape)

                # Update gradients.
                G.backward(neg_one)
                G_cost = -G

                optimizerG.step()
                opt_generator.step()

                ImageGeneratorModel.zero_grad()
                fk_img = ImageGeneratorModel(audio_features)
                # fake = autograd.Variable(fk_img.data)
                # fake = netG(fk_img2)
                G2 = ImageDiscriminatorModel(fk_img)
                G2 = G2.mean()
                # print('audio_cond',audio_cond.shape)

                # Update gradients.
                G2.backward(neg_one)
                G2_cost = -G2

                optimizerG_Image.step()
                opt_encoder.step()

                # Record costs
                if self.use_cuda:
                    G_cost = G_cost.cpu()
                G_cost_epoch.append(G_cost.data.numpy())

                if i % (BATCH_NUM // 5) == 0:
                    self.LOGGER_audio.info(
                        "{} Epoch={} Batch: {}/{} D_c:{:.4f} | D_w:{:.4f} | G:{:.4f}"
                        .format(time_since(start), epoch, i, BATCH_NUM,
                                D_cost_train.data.numpy(),
                                D_wass_train.data.numpy(),
                                G_cost.data.numpy()))

                if self.use_cuda:
                    G2_cost = G2_cost.cpu()
                G2_cost_epoch.append(G2_cost.data.numpy())

                if i % (BATCH_NUM // 5) == 0:
                    self.LOGGER_image.info(
                        "{} Epoch={} Batch: {}/{} D2_c:{:.4f} | D2_w:{:.4f} | G2:{:.4f}"
                        .format(time_since(start), epoch, i, BATCH_NUM,
                                D2_cost_train.data.numpy(),
                                D2_wass_train.data.numpy(),
                                G2_cost.data.numpy()))

            # Save the average cost of batches in every epoch.
            D_cost_train_epoch_avg = sum(D_cost_train_epoch) / float(
                len(D_cost_train_epoch))
            D_wass_train_epoch_avg = sum(D_wass_train_epoch) / float(
                len(D_wass_train_epoch))
            D_cost_valid_epoch_avg = sum(D_cost_valid_epoch) / float(
                len(D_cost_valid_epoch))
            D_wass_valid_epoch_avg = sum(D_wass_valid_epoch) / float(
                len(D_wass_valid_epoch))
            G_cost_epoch_avg = sum(G_cost_epoch) / float(len(G_cost_epoch))

            D_costs_train.append(D_cost_train_epoch_avg)
            D_wasses_train.append(D_wass_train_epoch_avg)
            D_costs_valid.append(D_cost_valid_epoch_avg)
            D_wasses_valid.append(D_wass_valid_epoch_avg)
            G_costs.append(G_cost_epoch_avg)

            # Save the average cost of batches in every epoch.
            D2_cost_train_epoch_avg = sum(D2_cost_train_epoch) / float(
                len(D2_cost_train_epoch))
            D2_wass_train_epoch_avg = sum(D2_wass_train_epoch) / float(
                len(D2_wass_train_epoch))
            D2_cost_valid_epoch_avg = sum(D2_cost_valid_epoch) / float(
                len(D2_cost_valid_epoch))
            D2_wass_valid_epoch_avg = sum(D2_wass_valid_epoch) / float(
                len(D2_wass_valid_epoch))
            G2_cost_epoch_avg = sum(G2_cost_epoch) / float(len(G2_cost_epoch))

            D2_costs_train.append(D2_cost_train_epoch_avg)
            D2_wasses_train.append(D2_wass_train_epoch_avg)
            D2_costs_valid.append(D2_cost_valid_epoch_avg)
            D2_wasses_valid.append(D2_wass_valid_epoch_avg)
            G2_costs.append(G2_cost_epoch_avg)

            self.LOGGER_audio.info(
                "{} D_cost_train:{:.4f} | D_wass_train:{:.4f} | D_cost_valid:{:.4f} | D_wass_valid:{:.4f} | "
                "G_cost:{:.4f}".format(time_since(start),
                                       D_cost_train_epoch_avg,
                                       D_wass_train_epoch_avg,
                                       D_cost_valid_epoch_avg,
                                       D_wass_valid_epoch_avg,
                                       G_cost_epoch_avg))

            self.LOGGER_image.info(
                "{} D2_cost_train:{:.4f} | D2_wass_train:{:.4f} | D2_cost_valid:{:.4f} | D2_wass_valid:{:.4f} | "
                "G2_cost:{:.4f}".format(time_since(start),
                                        D2_cost_train_epoch_avg,
                                        D2_wass_train_epoch_avg,
                                        D2_cost_valid_epoch_avg,
                                        D2_wass_valid_epoch_avg,
                                        G2_cost_epoch_avg))

            # Generate audio samples.
            if epoch % epochs_per_sample == 0:
                self.LOGGER_audio.info("Generating samples...")

                self.LOGGER_image.info("Generating image samples...")
                # batch_real_image_val,image_enumerator=sample_real_image_batch(image_enumerator)
                # batch_real_audio_val=batch_real_image_val['audio'].cpu()
                # sample_test_Var = numpy_to_var(batch_real_audio_val, cuda)
                torch.save(
                    ImageModel,
                    os.path.join(self.log_folder,
                                 '%05d_ImageModel.pytorch' % epoch))
                torch.save(
                    netG,
                    os.path.join(self.log_folder, '%05d_netG.pytorch' % epoch))
                torch.save(
                    netD,
                    os.path.join(self.log_folder,
                                 'I%05d_netD.pytorch' % epoch))

                torch.save(
                    audio_encoder,
                    os.path.join(self.log_folder,
                                 '%05d_audio_encoder.pytorch' % epoch))
                torch.save(
                    ImageGeneratorModel,
                    os.path.join(self.log_folder,
                                 '%05d_ImageGeneratorModel.pytorch' % epoch))
                torch.save(
                    ImageDiscriminatorModel,
                    os.path.join(
                        self.log_folder,
                        'I%05d_ImageDiscriminatorModel.pytorch' % epoch))

                for iii in range(1, 2):
                    # sample real data
                    sample_test_batch = self.sample_test_image_batch()
                    batch_test_audio = sample_test_batch['audio'].cpu()
                    batch_image_test = Variable(sample_test_batch['images'],
                                                requires_grad=False)
                    # print(batch_image.shape)
                    batch_image_test = batch_image_test.cuda()
                    features_test = ImageModel(batch_image_test)  # condition on the held-out test images
                    z = nn.init.normal(torch.Tensor(self.image_batch_size,
                                                    512))
                    if self.use_cuda:
                        z = z.cuda()
                    z = Variable(z)
                    sample_out = netG(z, features_test)
                    if self.use_cuda:
                        sample_out = sample_out.cpu()
                    sample_out = sample_out.data.numpy()
                    save_samples(sample_out, epoch, output_dir)
                    gc.collect()
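
time_since is used by the loggers throughout this example but is not part of the listing; a plausible sketch of such a helper:

import time


def time_since(start):
    # Elapsed wall-clock time since 'start', formatted as minutes and seconds.
    elapsed = time.time() - start
    minutes = int(elapsed // 60)
    return '{}m {:.0f}s'.format(minutes, elapsed - minutes * 60)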