Example #1
    def visualize(self):
        # TODO: Solve bug with the generator which generates unmatched images.
        sample_z = np.random.uniform(-1,
                                     1,
                                     size=(self.model.sample_num,
                                           self.model.z_dim))
        # Sample conditioning embeddings from the test split; the random index
        # must be drawn from the same split that is queried.
        _, sample_embed, _, captions = self.dataset.test.next_batch_test(
            self.model.sample_num, randint(0, self.dataset.test.num_examples),
            1)
        sample_embed = np.squeeze(sample_embed, axis=0)

        samples = self.sess.run(self.model.sampler,
                                feed_dict={
                                    self.model.z_sample: sample_z,
                                    self.model.phi_sample: sample_embed,
                                })

        fake_img = samples[0]
        closest_img = closest_image(fake_img, self.dataset)
        closest_pair = np.array([fake_img, closest_img])

        save_images(
            closest_pair, image_manifold_size(closest_pair.shape[0]),
            './{}/{}/{}/test5.png'.format(self.config.test_dir,
                                          self.model.name, self.dataset.name))
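Every example on this page leans on a DCGAN-style save_images(images, size, image_path) helper, and several also call image_manifold_size(num_images) to pick a grid shape. Neither is defined here; the following is a minimal sketch of what they typically do (tiling a batch of [-1, 1] float images into a grid), written under those assumptions rather than copied from any one repository.

import numpy as np
import imageio

def image_manifold_size(num_images):
    # Factor the batch size into the most square (rows, cols) grid possible.
    manifold_h = int(np.floor(np.sqrt(num_images)))
    manifold_w = int(np.ceil(np.sqrt(num_images)))
    assert manifold_h * manifold_w == num_images
    return manifold_h, manifold_w

def save_images(images, size, image_path):
    # Assumed behavior: tile images of shape (N, H, W, C) with values in
    # [-1, 1] into a size[0] x size[1] grid and write it to image_path.
    images = (images + 1.0) / 2.0
    h, w = images.shape[1], images.shape[2]
    grid = np.zeros((h * size[0], w * size[1], images.shape[3]))
    for idx, image in enumerate(images):
        col, row = idx % size[1], idx // size[1]
        grid[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = image
    imageio.imwrite(image_path, (grid * 255).astype(np.uint8).squeeze())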
Example #2
    def visualize_results(self, epoch, fix=True):
        self.G.eval()
        output_path = '/'.join(
            [self.result_dir, self.dataset, self.model_name])
        if not os.path.exists(output_path):
            os.makedirs(output_path)

        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        if fix:
            """ fixed noise """
            samples = self.G(self.sample_z_)
        else:
            """ random noise """
            if self.gpu_mode:
                sample_z_ = Variable(
                    torch.rand((self.batch_size, self.z_dim)).cuda())
            else:
                sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)))

            samples = self.G(sample_z_)

        if self.gpu_mode:
            samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
        else:
            samples = samples.data.numpy().transpose(0, 2, 3, 1)

        utils.save_images(
            samples[:image_frame_dim * image_frame_dim, :, :, :],
            [image_frame_dim, image_frame_dim], output_path + '/' +
            self.model_name + '_epoch%03d' % epoch + '.png')
Example #3
def main(args):
    test_x, test_y = load_image(args.image_path)

    test_inp = to_tensor(test_x.astype(np.float32))
    test_target = to_tensor(test_y.astype(np.float32))

    generator = Generator().to("cuda")

    start_t = time.time()
    pretrain_model = flow.load(args.model_path)
    generator.load_state_dict(pretrain_model)
    end_t = time.time()
    print("load params time : {}".format(end_t - start_t))

    start_t = time.time()
    generator.eval()
    with flow.no_grad():
        gout = to_numpy(generator(test_inp), False)
    end_t = time.time()
    print("infer time : {}".format(end_t - start_t))

    # save images
    save_images(
        gout,
        test_inp.numpy(),
        test_target.numpy(),
        path="./testimage.png",
        plot_size=1,
    )
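to_tensor and to_numpy above are project-local helpers, not part of oneflow itself. A plausible minimal sketch, assuming the boolean second argument of to_numpy toggles denormalization of the generator's tanh-range output (device transfer is elided for brevity):

import oneflow as flow

def to_tensor(arr):
    # Host numpy array -> oneflow tensor.
    return flow.Tensor(arr)

def to_numpy(tensor, denormalize=True):
    # Tensor -> host numpy array; optionally map [-1, 1] to [0, 1].
    arr = tensor.numpy()
    return (arr + 1.0) / 2.0 if denormalize else arr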
Example #4
    def save_images(self, data):
        images = data['img']
        paths = os.path.join(self.opt.save_dir, self.opt.object)
        paths = os.path.join(paths, "result")
        anomaly_img = utils.compare_images(images,
                                           self.generated_imgs,
                                           threshold=self.opt.threshold)
        utils.save_images(anomaly_img, paths, data)
Example #5
    def train(self, mode='train'):
        print('Beginning Training: ')
        # with self.sess as sess:
        sess = self.sess
        tf.global_variables_initializer().run(session=self.sess)

        if mode == 'test' or mode == 'validation':
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(self.model_dir)
            # print(checkpoint)
            self.saver.restore(sess, checkpoint)
        else:
            # checkpoint = tf.train.latest_checkpoint(self.model_dir)
            # if checkpoint:
            #     self.saver.restore(sess, checkpoint)
            #     print("Restored from checkpoint")
            counter = 0
            ep = 0
            could_load, checkpoint_counter, checkpoint_epoch = self.load(self.model_dir)
            if could_load:
                ep = checkpoint_epoch
                counter = checkpoint_counter
                print("Successfully loaded checkpoint")
            else:
                print("Failed to load checkpoint")
            start_time = time.time()
            labels = open("samples/labels.txt", "a")
            for epoch in tqdm(range(ep, self.epochs)):
                for step in tqdm(range(counter, self.max_steps)):
                    # Debugging leftover: skip a window of steps, only logging batch shapes.
                    if step - counter in range(2101, 2116):
                        x, _ = sess.run([self.x, self.real_labels])
                        print(x.shape)
                        continue
                    for _ in range(5):
                        # self.x, self.real_labels = self.iter.get_next()
                        _, disc_loss, x = self.sess.run([self.disc_step, self.d_loss, self.x])
                        if step - counter > 2110:
                            print(x.shape)
                        # _ = self.sess.run([self.disc_gp_step])

                    # self.x, self.real_labels = self.iter.get_next()
                    _, gen_loss = self.sess.run([self.gen_step, self.g_loss])

                    if step % 100 == 0:
                        print("Time: {:.4f}, Epoch: {}, Step: {}, Generator Loss: {:.4f}, Discriminator Loss: {:.4f}"
                              .format(time.time() - start_time, epoch, step, gen_loss, disc_loss))
                        fake_im, real_im, fake_l, real_l = sess.run([self.fake_image, self.x, self.fake_labels, self.real_labels])
                        save_images(fake_im, image_manifold_size(fake_im.shape[0]),
                                    './samples/train_{:02d}_{:06d}.png'.format(epoch, step))
                        save_images(real_im, image_manifold_size(real_im.shape[0]),
                                    './samples/train_{:02d}_{:06d}_real.png'.format(epoch, step))
                        labels.write("{:02d}_{:06d}:\nReal Labels -\n{}\nFake Labels -\n{}\n".format(epoch, step, str(real_l), str(fake_l)))
                        print('Translated images and saved..!')

                    if step % 200 == 0:
                        self.save(self.model_dir, step, epoch)
                        print("Checkpoint saved")
                counter = 0
            labels.close()
Example #6
    def _eval_generator_and_save_images(self, epoch_idx):
        results = self._eval_generator()
        save_images(
            results,
            to_numpy(self.fixed_inp, False),
            to_numpy(self.fixed_target, False),
            path=os.path.join(self.test_images_path,
                              "testimage_{:02d}.png".format(epoch_idx + 1)),
        )
Example #7
    def test(self, args):
        if args.which_direction == 'AtoB':
            sample_files = glob('./datasets/{}/*.*'.format(self.dataset_dir +
                                                           '/testA_b'))
        elif args.which_direction == 'BtoA':
            sample_files = glob('./datasets/{}/*.*'.format(self.dataset_dir +
                                                           '/testB'))
        else:
            raise Exception('--which_direction must be AtoB or BtoA')

        # write html for visual comparison
        index_path = os.path.join(
            args.test_dir, '{0}_index.html'.format(args.which_direction))
        index = open(index_path, "w")
        index.write("<html><body><table><tr>")
        index.write("<th>name</th><th>input</th><th>output</th></tr>")

        if args.which_direction == 'AtoB':
            out_var, in_var = self.testB, self.test_A
        else:
            out_var, in_var = self.testA, self.test_B

        for sample_file in sample_files:
            print('Processing image: ' + sample_file)
            sample_image = [load_test_data(sample_file, args.fine_size)]
            sample_image = np.array(sample_image).astype(np.float32)
            new_shape = list(sample_image.shape) + [1]
            sample_image = np.reshape(sample_image, newshape=new_shape)
            sample_image = sample_image[:, :, :, :self.input_c_dim]
            test_path = os.path.join(args.test_dir, args.dataset_dir)
            if not os.path.exists(test_path):
                os.makedirs(test_path)
            image_path = os.path.join(
                args.test_dir, args.dataset_dir,
                '{0}_{1}'.format(args.which_direction,
                                 os.path.basename(sample_file)))
            fake_img = self.sess.run(out_var, feed_dict={in_var: sample_image})
            save_images(fake_img, [1, 1], image_path)
            index.write("<tr>")
            index.write("<td>%s</td>" % os.path.basename(image_path))
            index.write("<td><img src='%s'></td>" %
                        (sample_file if os.path.isabs(sample_file) else
                         ('..' + os.path.sep + sample_file)))
            index.write("<td><img src='%s'></td>" %
                        (image_path if os.path.isabs(image_path) else
                         ('..' + os.path.sep + image_path)))
            index.write("</tr>")
        index.write("</table></body></html>")
        index.close()
Example #8
    def infer(self, image, save_path, bright_diff=0, is_grayscale=True):
        # read image
        if isinstance(image, str):
            img = imread(image, is_grayscale=is_grayscale)
        else:
            img = image
        img = cv2.resize(img, dsize=(256, 256))
        img = np.reshape(img, newshape=(img.shape[0], img.shape[1], 1))
        gen_avatar = self.sess.run(self.testB, feed_dict={self.test_A: [img]})
        if save_path is not None:
            save_images(gen_avatar + bright_diff,
                        size=[1, 1],
                        image_path=save_path)
        gen_avatar = np.reshape(gen_avatar,
                                newshape=list(gen_avatar.shape[1:-1]))
        gen_avatar = inverse_transform(gen_avatar)

        return gen_avatar
Example #9
    def sample_model(self, sample_dir, epoch, idx):
        data_a = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testA'))
        data_b = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testB'))
        np.random.shuffle(data_a)
        np.random.shuffle(data_b)
        batch_files = list(
            zip(data_a[:self.batch_size], data_b[:self.batch_size]))
        sample_images = [
            load_train_data(batch_file, is_testing=True)
            for batch_file in batch_files
        ]
        sample_images = np.array(sample_images).astype(np.float32)
        fake_a, fake_b = self.sess.run(
            [self.fake_A, self.fake_B],
            feed_dict={self.real_data: sample_images})
        save_images(fake_a, [self.batch_size, 1],
                    './{}/A_{:02d}_{:04d}.jpg'.format(sample_dir, epoch, idx))
        save_images(fake_b, [self.batch_size, 1],
                    './{}/B_{:02d}_{:04d}.jpg'.format(sample_dir, epoch, idx))
Example #10
    def visualize_data(self, epoch, sample_max=5):
        if self.dataset_name in ['mnist', 'fashion-mnist', 'celeba']:
            tot_num_samples = min(self.sample_num, self.batch_size)
            image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

            samples = self.sess.run(self.ds.denorm_img(self.inputs))

            save_images(
                samples[:image_frame_dim * image_frame_dim, :, :, :],
                [image_frame_dim, image_frame_dim],
                check_folder(self.result_dir + '/' + self.model_dir) + '/' +
                self.model_name + '_epoch%03d' % epoch + '_training_data.png')

            if self.bot is not None:
                self.bot.send_file(
                    os.path.join(
                        self.result_dir, self.model_dir, self.model_name +
                        '_epoch%03d' % epoch + '_training_data.png'))
        else:
            raise NotImplementedError
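check_folder, used here and in the next two examples, is the usual create-if-missing directory helper from generative-model-collections-style codebases; a one-line sketch of the assumed behavior:

import os

def check_folder(log_dir):
    # Create the directory if it does not already exist, then return the path.
    os.makedirs(log_dir, exist_ok=True)
    return log_dir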
Example #11
    def save_test_sample(self, epoch, batch_number):
        samples = self.sess.run(self.fake_images)

        tot_num_samples = min(self.sample_num, self.batch_size)
        manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
        manifold_w = int(np.floor(np.sqrt(tot_num_samples)))

        save_images(
            samples[:manifold_h * manifold_w, :, :, :],
            [manifold_h, manifold_w],
            os.path.join(
                check_folder(
                    os.path.join(os.getcwd(), self.result_dir,
                                 self.model_dir)), self.model_name +
                '_train_{:04d}_{:04d}.png'.format(epoch, batch_number)))

        if self.bot is not None:
            self.bot.send_file(
                os.path.join(
                    os.getcwd(), self.result_dir, self.model_dir,
                    self.model_name +
                    '_train_{:04d}_{:04d}.png'.format(epoch, batch_number)))
Example #12
    def visualize_results(self, epoch):
        if self.dataset_name in ['mnist', 'fashion-mnist', 'celeba']:
            tot_num_samples = min(self.sample_num, self.batch_size)
            image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))
            """ random condition, random noise """
            samples = self.sess.run(self.fake_images)

            save_images(
                samples[:image_frame_dim * image_frame_dim, :, :, :],
                [image_frame_dim, image_frame_dim],
                os.path.join(
                    check_folder(
                        os.path.join(os.getcwd(), self.result_dir,
                                     self.model_dir)), self.model_name +
                    '_epoch%03d' % epoch + '_test_all_classes.png'))

            if self.bot is not None:
                self.bot.send_file(
                    os.path.join(
                        os.getcwd(), self.result_dir, self.model_dir,
                        self.model_name + '_epoch%03d' % epoch +
                        '_test_all_classes.png'))
        else:
            raise NotImplementedError
Example #13
    def test(self):

        opt = self.opt

        gpu_ids = range(torch.cuda.device_count())
        print('Number of GPUs in use {}'.format(gpu_ids))

        iteration = 0

        if torch.cuda.device_count() > 1:
            vae = nn.DataParallel(VAE(hallucination=self.useHallucination,
                                      opt=opt,
                                      refine=self.refine,
                                      bg=128,
                                      fg=896),
                                  device_ids=gpu_ids).cuda()
        else:
            vae = VAE(hallucination=self.useHallucination, opt=opt).cuda()

        print(self.jobname)

        if self.load:
            model_name = '../pretrained_models/cityscapes/refine_genmask_w_mask_two_path_096000.pth.tar'
            # model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load)

            print("loading model from {}".format(model_name))

            state_dict = torch.load(model_name)
            if torch.cuda.device_count() > 1:
                vae.module.load_state_dict(state_dict['vae'])
            else:
                vae.load_state_dict(state_dict['vae'])

        z_noise = torch.ones(1, 1024).normal_()

        for data, bg_mask, fg_mask, paths in tqdm(iter(self.testloader)):
            # Set to evaluation mode (randomly sample z from the whole distribution)
            vae.eval()

            # If test on generated images
            # data = data.unsqueeze(1)
            # data = data.repeat(1, opt.num_frames, 1, 1, 1)

            frame1 = data[:, 0, :, :, :]
            noise_bg = torch.randn(frame1.size())
            z_m = Vb(z_noise.repeat(frame1.size()[0] * 8, 1))

            y_pred_before_refine, y_pred, mu, logvar, flow, flowback, mask_fw, mask_bw = vae(
                frame1, data, bg_mask, fg_mask, noise_bg, z_m)

            utils.save_samples(data,
                               y_pred_before_refine,
                               y_pred,
                               flow,
                               mask_fw,
                               mask_bw,
                               iteration,
                               self.sampledir,
                               opt,
                               eval=True,
                               useMask=True,
                               grid=[4, 4])
            '''save images'''
            utils.save_images(self.output_image_dir, data, y_pred, paths, opt)
            utils.save_images(self.output_image_dir_before, data,
                              y_pred_before_refine, paths, opt)

            data = data.cpu().data.transpose(2, 3).transpose(3, 4).numpy()
            utils.save_gif(
                data * 255, opt.num_frames, [4, 4],
                self.sampledir + '/{:06d}_real.gif'.format(iteration))
            '''save flows'''
            utils.save_flows(self.output_fw_flow_dir, flow, paths)
            utils.save_flows(self.output_bw_flow_dir, flowback, paths)
            '''save occlusion maps'''
            utils.save_occ_map(self.output_fw_mask_dir, mask_fw, paths)
            utils.save_occ_map(self.output_bw_mask_dir, mask_bw, paths)

            iteration += 1
Example #14
    def train(self):
        self.define_losses()
        self.define_summaries()

        sample_z = np.random.normal(0, 1,
                                    (self.model.sample_num, self.model.z_dim))
        _, sample_embed, _, captions = self.dataset.test.next_batch_test(
            self.model.sample_num, randint(0, self.dataset.test.num_examples),
            1)
        sample_embed = np.squeeze(sample_embed, axis=0)
        print(sample_embed.shape)

        # Display the captions of the sampled images
        print('\nCaptions of the sampled images:')
        for caption_idx, caption_batch in enumerate(captions):
            print('{}: {}'.format(caption_idx + 1, caption_batch[0]))
        print()

        counter = 1
        start_time = time.time()

        # Try to load the parameters of the stage II networks
        tf.global_variables_initializer().run()
        could_load, checkpoint_counter = load(self.stageii_saver, self.sess,
                                              self.cfg.CHECKPOINT_DIR)
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS: Stage II networks are loaded.")
        else:
            print(" [!] Load failed for stage II networks...")

        could_load, checkpoint_counter = load(self.stagei_g_saver, self.sess,
                                              self.cfg_stage_i.CHECKPOINT_DIR)
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS: Stage I generator is loaded")
        else:
            print(
                " [!] WARNING!!! Failed to load the parameters for stage I generator..."
            )

        for epoch in range(self.cfg.TRAIN.EPOCH):
            # Updates per epoch are given by the training data size / batch size
            updates_per_epoch = self.dataset.train.num_examples // self.model.batch_size

            for idx in range(0, updates_per_epoch):
                images, wrong_images, embed, _, _ = self.dataset.train.next_batch(
                    self.model.batch_size, 4)
                batch_z = np.random.normal(
                    0, 1, (self.model.batch_size, self.model.z_dim))

                # Update D network
                _, err_d_real_match, err_d_real_mismatch, err_d_fake, err_d, summary_str = self.sess.run(
                    [
                        self.D_optim, self.D_real_match_loss,
                        self.D_real_mismatch_loss, self.D_synthetic_loss,
                        self.D_loss, self.D_merged_summ
                    ],
                    feed_dict={
                        self.model.inputs: images,
                        self.model.wrong_inputs: wrong_images,
                        self.model.embed_inputs: embed,
                        self.model.z: batch_z
                    })
                self.writer.add_summary(summary_str, counter)

                # Update G network
                _, err_g, summary_str = self.sess.run(
                    [self.G_optim, self.G_loss, self.G_merged_summ],
                    feed_dict={
                        self.model.z: batch_z,
                        self.model.embed_inputs: embed
                    })
                self.writer.add_summary(summary_str, counter)

                counter += 1
                print(
                    "Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                    % (epoch, idx, updates_per_epoch, time.time() - start_time,
                       err_d, err_g))

                if np.mod(counter, 100) == 0:
                    try:
                        samples = self.sess.run(self.model.sampler,
                                                feed_dict={
                                                    self.model.z_sample:
                                                    sample_z,
                                                    self.model.embed_sample:
                                                    sample_embed,
                                                })
                        save_images(
                            samples, image_manifold_size(samples.shape[0]),
                            '{}train_{:02d}_{:04d}.png'.format(
                                self.cfg.SAMPLE_DIR, epoch, idx))
                        print("[Sample] d_loss: %.8f, g_loss: %.8f" %
                              (err_d, err_g))

                        # Display the captions of the sampled images
                        print('\nCaptions of the sampled images:')
                        for caption_idx, caption_batch in enumerate(captions):
                            print('{}: {}'.format(caption_idx + 1,
                                                  caption_batch[0]))
                        print()
                    except Exception as e:
                        print("Failed to generate sample image")
                        print(type(e))
                        print(e.args)
                        print(e)

                if np.mod(counter, 500) == 2:
                    save(self.stageii_saver, self.sess,
                         self.cfg.CHECKPOINT_DIR, counter)
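The load/save pair used above (and again in Examples #15, #19, #24, and #25) wraps tf.train.Saver checkpointing. A minimal sketch, assuming the global step is encoded at the end of the checkpoint filename:

import os
import re
import tensorflow as tf

def load(saver, sess, checkpoint_dir):
    # Restore the latest checkpoint; return (success, step parsed from its name).
    ckpt = tf.train.latest_checkpoint(checkpoint_dir)
    if ckpt is None:
        return False, 0
    saver.restore(sess, ckpt)
    match = re.search(r'(\d+)$', ckpt)
    return True, int(match.group(1)) if match else 0

def save(saver, sess, checkpoint_dir, step):
    # Write a checkpoint tagged with the current global step.
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    saver.save(sess, os.path.join(checkpoint_dir, 'model'), global_step=step)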
Example #15
    def train(self):
        self.define_summaries()

        self.saver = tf.train.Saver(max_to_keep=self.cfg.TRAIN.CHECKPOINTS_TO_KEEP)

        sample_z = np.random.normal(0, 1, (self.model.sample_num, self.model.z_dim))
        _, sample_cond, _, captions = self.dataset.test.next_batch_test(self.model.sample_num, 0, 1)
        # _, sample_cond, _, captions = self.dataset.test.next_batch_test(self.model.sample_num, 1, 1)
        sample_cond = np.squeeze(sample_cond, axis=0)
        print('Conditionals sampler shape: {}'.format(sample_cond.shape))

        save_captions(self.cfg.SAMPLE_DIR, captions)

        start_time = time.time()
        tf.global_variables_initializer().run()

        could_load, checkpoint_counter = load(self.saver, self.sess, self.cfg.CHECKPOINT_DIR)
        if could_load:
            start_point = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_point = 0
            print(" [!] Load failed...")
        sys.stdout.flush()

        for idx in range(start_point + 1, self.cfg.TRAIN.MAX_STEPS):
            epoch_size = self.dataset.train.num_examples // self.model.batch_size
            epoch = idx // epoch_size

            images, wrong_images, embed, _, _ = self.dataset.train.next_batch(self.model.batch_size, 1, embeddings=True,
                                                                              wrong_img=True)
            batch_z = np.random.normal(0, 1, (self.model.batch_size, self.model.z_dim))
            eps = np.random.uniform(0., 1., size=(self.model.batch_size, 1, 1, 1))
            n_critic = self.cfg.TRAIN.N_CRITIC
            kiter = (idx // n_critic) // 10000

            feed_dict = {
                self.model.learning_rate_d: self.lr_d * (0.95**kiter),
                self.model.learning_rate_g: self.lr_g * (0.95**kiter),
                self.model.x: images,
                self.model.x_mismatch: wrong_images,
                self.model.cond: embed,
                self.model.z: batch_z,
                self.model.epsilon: eps,
                self.model.z_sample: sample_z,
                self.model.cond_sample: sample_cond,
                self.model.iter: idx,
            }

            _, _, err_d = self.sess.run([self.model.D_optim, self.model.kt_optim, self.model.D_loss],
                                         feed_dict=feed_dict)

            if idx % n_critic == 0:
                _, err_g = self.sess.run([self.model.G_optim, self.model.G_loss],
                                         feed_dict=feed_dict)

            summary_period = self.cfg.TRAIN.SUMMARY_PERIOD
            if np.mod(idx, summary_period) == 0:
                summary_str = self.sess.run(self.summary_op, feed_dict=feed_dict)
                self.writer.add_summary(summary_str, idx)

            if np.mod(idx, self.cfg.TRAIN.SAMPLE_PERIOD) == 0:
                try:
                    samples = self.sess.run(self.model.sampler,
                                            feed_dict={
                                                self.model.z_sample: sample_z,
                                                self.model.cond_sample: sample_cond,
                                            })
                    save_images(samples, get_balanced_factorization(samples.shape[0]),
                                '{}train_{:02d}_{:04d}.png'.format(self.cfg.SAMPLE_DIR, epoch, idx))

                except Exception as e:
                    print("Failed to generate sample image")
                    print(type(e))
                    print(e.args)
                    print(e)

            if np.mod(idx, 500) == 2:
                save(self.saver, self.sess, self.cfg.CHECKPOINT_DIR, idx)
            sys.stdout.flush()
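get_balanced_factorization is another grid-shape helper; unlike the floor/ceil pair in image_manifold_size, it presumably never asserts, so it also works for clipped sample batches. A hedged sketch returning the most balanced (rows, cols) split:

import numpy as np

def get_balanced_factorization(num_images):
    # Most balanced pair (rows, cols) with rows * cols == num_images.
    for rows in range(int(np.sqrt(num_images)), 0, -1):
        if num_images % rows == 0:
            return [rows, num_images // rows]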
Example #16
def train(cfg):
    '''
    Main training loop: loads the dataset, model, optimizers, and
    checkpoints, then runs training with periodic validation peeks.
    '''
    print(json.dumps(cfg, sort_keys=True, indent=4))

    use_cuda = cfg['use-cuda']

    _, _, train_dl, val_dl = utils.get_data_loaders(cfg)

    model = utils.get_model(cfg)
    if use_cuda:
        model = model.cuda()
    model = utils.init_weights(model, cfg)

    # Get pretrained models, optimizers and loss functions
    optim = utils.get_optimizers(model, cfg)
    model, optim, metadata = utils.load_ckpt(model, optim, cfg)
    loss_fn = utils.get_losses(cfg)

    # Set up random seeds
    seed = np.random.randint(2**32)
    ckpt = 0
    if metadata is not None:
        seed = metadata['seed']
        ckpt = metadata['ckpt']

    # Get schedulers after getting checkpoints
    scheduler = utils.get_schedulers(optim, cfg, ckpt)
    # Print optimizer state
    print(optim)

    # Get loss file handle to dump logs to
    if not os.path.exists(cfg['save-path']):
        os.makedirs(cfg['save-path'])
    lossesfile = open(os.path.join(cfg['save-path'], 'losses.txt'), 'a+')

    # Random seed according to what the saved model is
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Run training loop
    num_epochs = cfg['train']['num-epochs']
    for epoch in range(num_epochs):
        # Run the main training loop
        model.train()
        for data in train_dl:
            # zero out the grads
            optim.zero_grad()

            # Change to required device
            for key, value in data.items():
                data[key] = Variable(value)
                if use_cuda:
                    data[key] = data[key].cuda()

            # Get all outputs
            outputs = model(data)
            loss_val = loss_fn(outputs, data, cfg)

            # print it
            print('Epoch: {}, step: {}, loss: {}'.format(
                epoch, ckpt,
                loss_val.data.cpu().numpy()))

            # Log into the file after some epochs
            if ckpt % cfg['train']['step-log'] == 0:
                lossesfile.write('Epoch: {}, step: {}, loss: {}\n'.format(
                    epoch, ckpt,
                    loss_val.data.cpu().numpy()))

            # Backward
            loss_val.backward()
            optim.step()

            # Update schedulers
            scheduler.step()

            # Peek into the validation set
            ckpt += 1
            if ckpt % cfg['peek-validation'] == 0:
                model.eval()
                with torch.no_grad():
                    for val_data in val_dl:
                        # Change to required device
                        for key, value in val_data.items():
                            val_data[key] = Variable(value)
                            if use_cuda:
                                val_data[key] = val_data[key].cuda()

                        # Get all outputs
                        outputs = model(val_data)
                        loss_val = loss_fn(outputs, val_data, cfg)

                        print('Validation loss: {}'.format(
                            loss_val.data.cpu().numpy()))

                        lossesfile.write('Validation loss: {}\n'.format(
                            loss_val.data.cpu().numpy()))
                        utils.save_images(val_data, outputs, cfg, ckpt)
                        break
                model.train()
            # Save checkpoint
            utils.save_ckpt((model, optim), cfg, ckpt, seed)

    lossesfile.close()
Example #17
    opt.serial_batches = True
    dataset = dataset.DatasetLoader(opt)
    dataset_size = len(dataset)
    print('The number of training images = %d' % dataset_size)

    # model setup
    model = ddpm.DDPModel(opt)
    model.setup(opt)               # regular setup: load and print networks; create schedulers

    # save options in the checkpoints directory
    cfg.save()

    # save directory setup
    save_dir = os.path.join(opt.results_dir, opt.name)  # save all the images to save_dir
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        print('Directory created: %s' % save_dir)

    # sampling
    total_iters = 0                # the total number of images processed

    for i, data in enumerate(dataset):
        if i >= opt.num_test:
            break
        model.set_input(data)         # unpack data from dataset and apply preprocessing
        outputs = model.interpolate()
        utils.save_images(outputs, [os.path.join(save_dir, model.get_interpolate_filename())])
        total_iters += opt.batch_size
        if total_iters % opt.print_freq == 0:    # periodically report sampling progress
            print('interpolate %d images in %s' % (total_iters, opt.results_dir))
Example #18
def test(args,
         datasets,
         loaders,
         model,
         writer,
         batch_idx=-1,
         prefix="test",
         save_on_batch=None):
    start = time.time()
    model = model.eval()
    # Upper bound 9 so the (bidx % 10 == save_on_batch) check below can match.
    save_on_batch = random.randint(
        0, 9) if save_on_batch is None else save_on_batch
    with torch.no_grad():
        for s, loader in loaders.items():
            not_dumped_images = True
            losses = defaultdict(list)
            for bidx, (x, svec, tvec, cvec) in enumerate(loader):
                batch_size = x.shape[0]
                r_labels = torch.ones((batch_size),
                                      dtype=torch.float,
                                      device=args.device)
                f_labels = torch.zeros((batch_size),
                                       dtype=torch.float,
                                       device=args.device)
                if args.cuda:
                    x = x.cuda()
                    svec = svec.cuda()
                    tvec = tvec.cuda()
                    cvec = cvec.cuda()
                # Generator
                z = torch.randn(batch_size,
                                model.LATENT_SIZE).float().to(args.device)
                gz = model.generator(z, svec, tvec, cvec)
                xs = [
                    ('real-', x),
                    ('fake-', gz),
                ]
                for tag, xin in xs:
                    rf, spred, tpred, cpred = model.discriminator(
                        xin, predict_s=r_labels, svec=svec, tvec=tvec)
                    losses[tag + "dreal"].append(
                        torch.log(1e-8 + rf).mean().item())
                    losses[tag + "dfake"].append(
                        torch.log(1e-8 + 1 - rf).mean().item())
                    losses[tag + "study acc"].append(
                        (torch.argmax(spred.detach(),
                                      dim=1) == svec).float().mean().item())
                    losses[tag + "study ce"].append(
                        F.cross_entropy(spred, svec).item())
                    losses[tag + "task acc"].append(
                        (torch.argmax(tpred.detach(),
                                      dim=1) == tvec).float().mean().item())
                    losses[tag + "task ce"].append(
                        F.cross_entropy(tpred, tvec).item())
                    losses[tag + "contrast acc"].append(
                        (torch.argmax(cpred.detach(),
                                      dim=1) == cvec).float().mean().item())
                    losses[tag + "contrast ce"].append(
                        F.cross_entropy(cpred, cvec).item())

                losses["disc loss"] = losses["real-dreal"][-1] + losses[
                    "fake-dfake"][-1]
                losses["gen loss"] = -losses["fake-dreal"][-1]

                # Save gz
                if not_dumped_images and (bidx % 10 == save_on_batch):
                    not_dumped_images = False
                    batch_dir = os.path.join(args.image_dir,
                                             "batch{}".format(batch_idx))
                    os.makedirs(batch_dir, exist_ok=True)
                    for i in range(0, gz.shape[0],
                                   10):  # For each image in the batch
                        utils.save_images(
                            [gz[i]],
                            ["generated"],
                            os.path.join(
                                batch_dir, "gen_{}_{}_{}_{}.png".format(
                                    args.meta['i2s'][svec[i].item()],
                                    args.meta['i2t'][tvec[i].item()],
                                    args.meta['i2c'][cvec[i].item()], i)),
                            indexes=[1],
                            nrows=1,
                            # mu_=datasets[s].mu,
                            # std_=datasets[s].std,
                        )
                # Log stuff
            writer.add_scalars((prefix + "-" if prefix != "" else "") + s,
                               {k: np.mean(v)
                                for k, v in losses.items()}, batch_idx)
    model = model.train()
    print("{}testing took {}s".format(prefix + " " if prefix != "" else "",
                                      time.time() - start))
Example #19
    def train(self):
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
            start_point = 0

            if self.stage != 1:
                if self.trans:
                    could_load, _ = load(self.restore, sess, self.check_dir_read)
                    if not could_load:
                        raise RuntimeError('Could not load previous stage during transition')
                else:
                    could_load, _ = load(self.saver, sess, self.check_dir_read)
                    if not could_load:
                        raise RuntimeError('Could not load current stage')

            # variables to init
            vars_to_init = initialize_uninitialized(sess)
            sess.run(tf.variables_initializer(vars_to_init))

            sample_z = np.random.normal(0, 1, (self.sample_num, self.z_dim))
            _, sample_cond, _, captions = self.dataset.test.next_batch_test(self.sample_num, 0, 1)
            sample_cond = np.squeeze(sample_cond, axis=0)
            print('Conditionals sampler shape: {}'.format(sample_cond.shape))

            save_captions(self.sample_path, captions)
            start_time = time.time()

            for idx in range(start_point + 1, self.steps):
                if self.trans:
                    # Learning-rate scaling over the transition period is
                    # currently disabled (the exponential factor is commented out).
                    p = idx / self.steps
                    self.lr_inp = self.lr  # * np.exp(-2 * np.square(1 - p))

                epoch_size = self.dataset.train.num_examples // self.batch_size
                epoch = idx // epoch_size

                images, wrong_images, embed, _, _ = self.dataset.train.next_batch(self.batch_size, 4,
                                                                                  wrong_img=True,
                                                                                  embeddings=True)
                batch_z = np.random.normal(0, 1, (self.batch_size, self.z_dim))
                eps = np.random.uniform(0., 1., size=(self.batch_size, 1, 1, 1))

                feed_dict = {
                    self.x: images,
                    self.learning_rate: self.lr_inp,
                    self.x_mismatch: wrong_images,
                    self.cond: embed,
                    self.z: batch_z,
                    self.epsilon: eps,
                    self.z_sample: sample_z,
                    self.cond_sample: sample_cond,
                    self.iter: idx,
                }

                _, err_d = sess.run([self.D_optim, self.D_loss], feed_dict=feed_dict)
                _, err_g = sess.run([self.G_optim, self.G_loss], feed_dict=feed_dict)

                if np.mod(idx, 20) == 0:
                    summary_str = sess.run(self.summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, idx)

                    print("Epoch: [%2d] [%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                          % (epoch, idx, time.time() - start_time, err_d, err_g))

                if np.mod(idx, 2000) == 0:
                    try:
                        samples = sess.run(self.sampler, feed_dict={
                                                    self.z_sample: sample_z,
                                                    self.cond_sample: sample_cond})
                        samples = np.clip(samples, -1., 1.)
                        if self.out_size > 256:
                            samples = samples[:4]

                        save_images(samples, get_balanced_factorization(samples.shape[0]),
                                    '{}train_{:02d}_{:04d}.png'.format(self.sample_path, epoch, idx))

                    except Exception as e:
                        print("Failed to generate sample image")
                        print(type(e))
                        print(e.args)
                        print(e)

                if np.mod(idx, 2000) == 0 or idx == self.steps - 1:
                    save(self.saver, sess, self.check_dir_write, idx)
                sys.stdout.flush()

        tf.reset_default_graph()

Example #20
if __name__ == '__main__':
    # option setup
    cfg = config.Config()
    modify_commandline_options(cfg.parser)
    opt = cfg.parse()

    # model setup
    model = ddpm.DDPModel(opt)
    model.setup(opt)               # regular setup: load and print networks; create schedulers

    # save options in the checkpoints directory
    cfg.save()

    # save directory setup
    save_dir = os.path.join(opt.results_dir, opt.name)  # save all the images to save_dir
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        print('Directory created: %s' % save_dir)

    # sampling
    total_iters = 0                # the total number of images sampled

    for test in range(opt.num_test):
        outputs = model.sample()
        utils.save_images(outputs, [os.path.join(save_dir, 'sample_%d.png' % (total_iters+n)) for n in range(opt.batch_size)])
        total_iters += opt.batch_size
        if total_iters % opt.print_freq == 0:    # periodically report sampling progress
            print('sampled %d images in %s' % (total_iters, opt.results_dir))
Example #21
    def test(self):

        opt = self.opt

        gpu_ids = range(torch.cuda.device_count())
        print('Number of GPUs in use {}'.format(gpu_ids))

        iteration = 0

        if torch.cuda.device_count() > 1:
            vae = nn.DataParallel(VAE(hallucination=self.useHallucination,
                                      opt=opt,
                                      refine=self.refine,
                                      bg=128,
                                      fg=896),
                                  device_ids=gpu_ids).cuda()
        else:
            vae = VAE(hallucination=self.useHallucination, opt=opt).cuda()

        print(self.jobname)

        if self.load:
            # model_name = '../' + self.jobname + '/{:06d}_model.pth.tar'.format(self.iter_to_load)
            model_name = '../pretrained_models/cityscapes/refine_genmask_w_mask_two_path_096000.pth.tar'

            print("loading model from {}".format(model_name))

            state_dict = torch.load(model_name)
            if torch.cuda.device_count() > 1:
                vae.module.load_state_dict(state_dict['vae'])
            else:
                vae.load_state_dict(state_dict['vae'])

        z_noise = torch.ones(1, 1024).normal_()

        for data, bg_mask, fg_mask, paths in tqdm(iter(self.testloader)):
            # Set to evaluation mode (randomly sample z from the whole distribution)
            vae.eval()

            # If test on generated images
            # data = data.unsqueeze(1)
            # data = data.repeat(1, opt.num_frames, 1, 1, 1)

            frame1 = data[:, 0, :, :, :]
            noise_bg = torch.randn(frame1.size())
            z_m = Vb(z_noise.repeat(frame1.size()[0] * 8, 1))

            y_pred_before_refine, y_pred, flow, flowback, mask_fw, mask_bw, warped_mask_bg, warped_mask_fg = vae(
                frame1, data, bg_mask, fg_mask, noise_bg, z_m)
            '''iterative generation'''

            for i in range(5):
                noise_bg = torch.randn(frame1.size())

                y_pred_before_refine_1, y_pred_1, flow_1, flowback_1, mask_fw_1, mask_bw_1, warped_mask_bg, warped_mask_fg = vae(
                    y_pred[:, -1, ...], y_pred, warped_mask_bg, warped_mask_fg,
                    noise_bg, z_m)

                y_pred_before_refine = torch.cat(
                    [y_pred_before_refine, y_pred_before_refine_1], 1)
                y_pred = torch.cat([y_pred, y_pred_1], 1)
                flow = torch.cat([flow, flow_1], 2)
                flowback = torch.cat([flowback, flowback_1], 2)
                mask_fw = torch.cat([mask_fw, mask_fw_1], 1)
                mask_bw = torch.cat([mask_bw, mask_bw_1], 1)

            print(y_pred_before_refine.size())

            utils.save_samples(data,
                               y_pred_before_refine,
                               y_pred,
                               flow,
                               mask_fw,
                               mask_bw,
                               iteration,
                               self.sampledir,
                               opt,
                               eval=True,
                               useMask=True,
                               grid=[4, 4])

            '''save images'''
            utils.save_images(self.output_image_dir, data, y_pred, paths, opt)
            utils.save_images(self.output_image_dir_before, data,
                              y_pred_before_refine, paths, opt)

            iteration += 1
Example #22
    def train(self):
        # clip
        #tmp1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminatorYdomain')
        #tmp2 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminatorXdomain')
        #tmp_vars = tmp1+tmp2
        #self.clip_D = [var.assign(tf.clip_by_value(var, -0.01, 0.01)) for var in tmp_vars]

        last_G_loss = 0.0
        last_D_loss = 0.0
        self.saver = tf.train.Saver()
        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        fig_count = 0
        self.sess.run(tf.global_variables_initializer())
        X_domain_image_val, Y_domain_image_val = self.data.validation_data()

        for epoch in range(1, self.options.epoch + 1):
            # update D
            for _ in tqdm(
                    range(
                        int(self.data._num_examples /
                            self.options.batch_size))):
                #self.sess.run(self.clip_D)
                X_domain_image, Y_domain_image = self.data.next_batch(
                    self.options.batch_size)

                #_ , img =self.sess.run(
                #   [self.D_solver,self.domain_X_decode_Y],
                #   feed_dict={self.X: X_domain_image, self.Y: Y_domain_image})
                _, img = self.sess.run([self.D_solver, self.fake_B],
                                       feed_dict={
                                           self.X: X_domain_image,
                                           self.Y: Y_domain_image
                                       })
                # update G
                # for _ in range(int(self.data._num_examples/self.options.batch_size)):
                #     X_domain_image, Y_domain_image = self.data.next_batch(self.options.batch_size)
                self.sess.run(self.G_solver,
                              feed_dict={
                                  self.X: X_domain_image,
                                  self.Y: Y_domain_image
                              })
            '''
                # fake img label to train G
                self.sess.run(
                    self.C_fake_solver,
                    feed_dict={self.y: y_b, self.z: sample_z(batch_size, self.z_dim)})
            '''
            # save images and the model; print losses
            if epoch % 5 == 0 or epoch < 100:
                D_loss_curr, G_loss_curr = 0.0, 0.0
                for i in tqdm(range(X_domain_image_val.shape[0])):
                    if i % 100 == 0:
                        #Y_domain_image_ = self.sess.run(self.domain_X_decode_Y,
                        #                               feed_dict={self.X: X_domain_image_val[i][np.newaxis,:,:,:],
                        #                                         self.Y: Y_domain_image_val[i][np.newaxis,:,:,:]
                        #                                        }
                        Y_domain_image_ = self.sess.run(
                            self.fake_B,
                            feed_dict={
                                self.X:
                                X_domain_image_val[i][np.newaxis, :, :, :],
                                self.Y:
                                Y_domain_image_val[i][np.newaxis, :, :, :]
                            })
                        save_images(
                            Y_domain_image_, [1, 1],
                            './{}/A_{:02d}_{:04d}.jpg'.format(
                                'train_output', epoch, i))
                    D_loss_curr += self.sess.run(
                        self.D_loss,
                        feed_dict={
                            self.X: X_domain_image_val[i][np.newaxis, :, :, :],
                            self.Y: Y_domain_image_val[i][np.newaxis, :, :, :]
                        })
                    G_loss_curr += self.sess.run(
                        self.G_loss,
                        feed_dict={
                            self.X: X_domain_image_val[i][np.newaxis, :, :, :],
                            self.Y: Y_domain_image_val[i][np.newaxis, :, :, :]
                        })
                print('Epoch: %d; D loss: %10.3f; G_loss: %10.3f' %
                      (epoch, D_loss_curr / X_domain_image_val.shape[0],
                       G_loss_curr / X_domain_image_val.shape[0]))

                # if epoch % 500 == 0:
                #     y_s = sample_y(16, self.y_dim, fig_count % 10)
                #     samples = self.sess.run(self.G_sample, feed_dict={self.y: y_s, self.z: sample_z(16, self.z_dim)})
                #
                #     fig = self.data.data2fig(samples)
                #     plt.savefig('{}/{}_{}.png'.format(sample_dir, str(fig_count).zfill(3), str(fig_count % 10)),
                #                 bbox_inches='tight')
                #     fig_count += 1
                #     plt.close(fig)

            #if  epoch%(self.options.save_freq) == 0:
            if epoch in [2, 5, 10, 20, 50, 100, 150, 200]:
                ckpt_path = os.path.join(
                    os.getcwd(), 'model_save', 'xgan%s_d%s_g%s.ckpt' %
                    (epoch, D_loss_curr, G_loss_curr))
                self.saver.save(self.sess, ckpt_path)
                print('model saved at %s' % ckpt_path)
Example #23
def test_(test_iter, net, experiment_dir_final, loss_function, loss_type,
          void_labels, save_test_images, n_classes):
    ckt_names = ['best_jaccard.t7']

    for ckt_name in ckt_names:
        print('Testing checkpoint ' + ckt_name)
        checkpoint = torch.load(
            os.path.join(experiment_dir_final, 'checkpoint', ckt_name))
        print('Checkpoint loaded for testing...')
        net.load_state_dict(checkpoint['net'])

        net.eval()
        test_loss = 0
        total = 0
        # Create the confusion matrix
        cm = np.zeros((n_classes, n_classes))
        nTest = test_iter.nbatches
        for batch_idx in range(nTest):
            all_data = test_iter.next()
            data_ = all_data[0]
            target_ = all_data[1]

            data, target = data_.transpose((0, 3, 1, 2)), target_.transpose(
                (0, 3, 1, 2))
            data, target = torch.from_numpy(data), torch.from_numpy(target)
            data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)
            output = net(data)

            target = target.type(torch.LongTensor).cuda()
            _, target_indices = torch.max(target, 1)
            _, output_indices = torch.max(output, 1)
            flattened_output = output_indices.view(-1)
            flattened_target = target_indices.view(-1)

            loss = loss_function(output, target_indices)

            cm = confusion_matrix(cm,
                                  flattened_output.data.cpu().numpy(),
                                  flattened_target.data.cpu().numpy(),
                                  n_classes)

            test_loss += loss.item()  # .data[0] indexing on a 0-dim tensor is deprecated
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)

            progress_bar(batch_idx, test_iter.nbatches,
                         'Test loss: %.3f' % (test_loss / (batch_idx + 1)))

            if save_test_images:
                save_images(data_, target_, output, experiment_dir_final,
                            batch_idx, void_labels)

            del output, loss, flattened_output, output_indices

        jaccard_per_class, jaccard, accuracy = compute_metrics(cm)
        metrics_string = print_metrics(test_loss, nTest, n_classes,
                                       jaccard_per_class, jaccard, accuracy)
        print(metrics_string)
Example #24
    def train(self):
        self.define_losses()
        self.define_summaries()

        sample_z = np.random.normal(0, 1,
                                    (self.model.sample_num, self.model.z_dim))
        _, sample_embed, _, captions = self.dataset.test.next_batch_test(
            self.model.sample_num, 0, 1)
        sample_embed = np.squeeze(sample_embed, axis=0)
        print(sample_embed.shape)

        save_captions(self.cfg.SAMPLE_DIR, captions)

        counter = 1
        start_time = time.time()

        could_load, checkpoint_counter = load(self.saver, self.sess,
                                              self.cfg.CHECKPOINT_DIR)
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        initialize_uninitialized(self.sess)

        # Updates per epoch are given by the training data size / batch size
        updates_per_epoch = self.dataset.train.num_examples // self.model.batch_size
        epoch_start = counter // updates_per_epoch

        for epoch in range(epoch_start, self.cfg.TRAIN.EPOCH):
            cen_epoch = epoch // 100

            for idx in range(0, updates_per_epoch):
                images, wrong_images, embed, _, _ = self.dataset.train.next_batch(
                    self.model.batch_size, 4, embeddings=True, wrong_img=True)
                batch_z = np.random.normal(
                    0, 1, (self.model.batch_size, self.model.z_dim))

                feed_dict = {
                    self.learning_rate: self.lr * (0.5**cen_epoch),
                    self.model.inputs: images,
                    self.model.wrong_inputs: wrong_images,
                    self.model.embed_inputs: embed,
                    self.model.z: batch_z,
                }

                # Update D network
                _, err_d, summary_str = self.sess.run(
                    [self.D_optim, self.D_loss, self.D_merged_summ],
                    feed_dict=feed_dict)
                self.writer.add_summary(summary_str, counter)

                # Update G network
                _, err_g, summary_str = self.sess.run(
                    [self.G_optim, self.G_loss, self.G_merged_summ],
                    feed_dict=feed_dict)
                self.writer.add_summary(summary_str, counter)

                counter += 1
                print(
                    "Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                    % (epoch, idx, updates_per_epoch, time.time() - start_time,
                       err_d, err_g))

                if np.mod(counter, 500) == 0:
                    try:
                        samples = self.sess.run(self.model.sampler,
                                                feed_dict={
                                                    self.model.z_sample:
                                                    sample_z,
                                                    self.model.embed_sample:
                                                    sample_embed,
                                                })
                        save_images(
                            samples,
                            get_balanced_factorization(samples.shape[0]),
                            '{}train_{:02d}_{:04d}.png'.format(
                                self.cfg.SAMPLE_DIR, epoch, idx))
                    except Exception as e:
                        print("Failed to generate sample image")
                        print(type(e))
                        print(e.args)
                        print(e)

                if np.mod(counter, 500) == 0:
                    save(self.saver, self.sess, self.cfg.CHECKPOINT_DIR,
                         counter)
Example #25
	def train(self):
		self.define_losses()
		self.define_summaries()

		sample_z = np.random.normal(0, 1, (self.model.sample_num, self.model.z_dim))
		_, sample_embed, _, captions = self.dataset.test.next_batch_test(self.model.sample_num, 0, 1)
		im_feats_test, sent_feats_test, labels_test = self.test_data_loader.get_batch(0,self.cfg.RETRIEVAL.SAMPLE_NUM,\
														image_aug = self.cfg.RETRIEVAL.IMAGE_AUG, phase='test')        
		sample_embed = np.squeeze(sample_embed, axis=0)
		print(sample_embed.shape)

		save_captions(self.cfg.SAMPLE_DIR, captions)

		counter = 1
		start_time = time.time()

		could_load, checkpoint_counter = load(self.stageii_saver, self.sess, self.cfg.CHECKPOINT_DIR)
		if could_load:
			counter = checkpoint_counter
			print(" [*] Load SUCCESS: Stage II networks are loaded.")
		else:
			print(" [!] Load failed for stage II networks...")

		could_load, checkpoint_counter = load(self.stagei_g_saver, self.sess, self.cfg_stage_i.CHECKPOINT_DIR)
		if could_load:
			print(" [*] Load SUCCESS: Stage I generator is loaded")
		else:
			print(" [!] WARNING!!! Failed to load the parameters for stage I generator...")

		initialize_uninitialized(self.sess)

		# Updates per epoch are given by the training data size / batch size
		updates_per_epoch = self.dataset.train.num_examples // self.model.batch_size
		epoch_start = counter // updates_per_epoch

		for epoch in range(epoch_start, self.cfg.TRAIN.EPOCH):
			cen_epoch = epoch // 100

			for idx in range(0, updates_per_epoch):
				images, wrong_images, embed, _, _ = self.dataset.train.next_batch(self.model.batch_size, 1,
																				  embeddings=True,
																				  wrong_img=True)
				batch_z = np.random.normal(0, 1, (self.model.batch_size, self.model.z_dim))

				# Retrieval data loader
				if idx % updates_per_epoch == 0:
					self.R_loader.shuffle_inds()
				
				im_feats, sent_feats, labels = self.R_loader.get_batch(idx % updates_per_epoch,\
								self.cfg.RETRIEVAL.BATCH_SIZE, image_aug = self.cfg.RETRIEVAL.IMAGE_AUG)                

				feed_dict = {
					self.learning_rate: self.lr * (0.5**cen_epoch),
					self.model.inputs: images,
					self.model.wrong_inputs: wrong_images,
					# self.model.embed_inputs: embed,
					# self.model.embed_inputs: self.txt_emb,
					self.model.z: batch_z,
					self.Retrieval.image_placeholder : im_feats, 
					self.Retrieval.sent_placeholder : sent_feats,
					self.Retrieval.label_placeholder : labels
				}

				# Update D network
				_, err_d, summary_str = self.sess.run([self.D_optim, self.D_loss, self.D_merged_summ],
													  feed_dict=feed_dict)
				self.writer.add_summary(summary_str, counter)

				# Update G network
				_, err_g, summary_str = self.sess.run([self.G_optim, self.G_loss, self.G_merged_summ],
													  feed_dict=feed_dict)
				self.writer.add_summary(summary_str, counter)
				
				# Update R network
				_, err_r, summary_str = self.sess.run([self.R_optim, self.R_loss, self.R_loss_summ],
													  feed_dict=feed_dict)
				self.writer.add_summary(summary_str, counter)                 

				counter += 1
				print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f, r_loss: %.8f"
					  % (epoch, idx, updates_per_epoch,
						 time.time() - start_time, err_d, err_g, err_r))

				if np.mod(counter, 1000) == 0:
					try:
						# pdb.set_trace()
						self.Retrieval.eval()
						sent_emb = self.sess.run(self.Retrieval.sent_embed_tensor,
												feed_dict={
															self.Retrieval.image_placeholder_test: im_feats_test,
															self.Retrieval.sent_placeholder_test: sent_feats_test,
														  })
						self.model.eval(sent_emb)								  
						samples = self.sess.run(self.model.sampler,
												feed_dict={
															self.model.z_sample: sample_z,
															# self.model.embed_sample: sample_embed,
															self.model.embed_sample: sent_emb,
														  })
						save_images(samples, get_balanced_factorization(samples.shape[0]),
									'{}train_{:02d}_{:04d}.png'.format(self.cfg.SAMPLE_DIR, epoch, idx))
					except Exception as e:
						print("Failed to generate sample image")
						print(type(e))
						print(e.args)
						print(e)

				if np.mod(counter, 500) == 2:
					save(self.stageii_saver, self.sess, self.cfg.CHECKPOINT_DIR, counter)

			if np.mod(epoch, 50) == 0 and epoch != 0:
				self.ret_eval(epoch)