def evaluate_inception(self):
        """Generate ``cfg.EVAL.SIZE`` images and print their Inception Score.

        Builds the Inception classifier (softmax over its logits) and an
        inference-mode copy of the generator, restores the generator weights
        from ``cfg.CHECKPOINT_DIR``, then generates images batch by batch from
        standard-normal noise and test-set embeddings. Only whole batches are
        generated, so ``(cfg.EVAL.SIZE // self.bs) * self.bs`` images are scored.

        Raises:
            RuntimeError: if the generator checkpoint cannot be restored.
        """
        eval_cfg = self.cfg.EVAL
        incep_bs = eval_cfg.INCEP_BATCH_SIZE

        # Classifier used to score the generated images.
        incep_logits, _ = load_inception_inference(
            self.sess, self.classes, incep_bs, eval_cfg.INCEP_CHECKPOINT_DIR)
        probs_op = tf.nn.softmax(incep_logits)

        # Generator inputs (noise and conditioning vectors) and its output.
        z_ph = tf.placeholder(tf.float32, [self.bs, self.model.z_dim], name='z')
        cond_ph = tf.placeholder(tf.float32, [self.bs] + [self.model.embed_dim],
                                 name='cond')
        fake_op, _, _ = self.model.generator(z_ph, cond_ph, reuse=False,
                                             is_training=False)

        g_saver = tf.train.Saver(tf.global_variables('g_net'))
        loaded, _ = load(g_saver, self.sess, self.cfg.CHECKPOINT_DIR)
        if not loaded:
            print(" [!] Load failed...")
            raise RuntimeError(
                'Could not load the checkpoints of the generator')
        print(" [*] Load SUCCESS")

        print('Generating x...')

        n_batches = eval_cfg.SIZE // self.bs
        dims = self.model.image_dims
        w, h, c = dims[0], dims[1], dims[2]
        samples = np.zeros((n_batches * self.bs, w, h, c))

        for batch_idx in range(n_batches):
            print("\rGenerating batch %d/%d" % (batch_idx + 1, n_batches),
                  end="",
                  flush=True)
            noise = np.random.normal(0, 1, size=(self.bs, self.model.z_dim))
            _, _, embed, _, _ = self.dataset.test.next_batch(self.bs, 4,
                                                             embeddings=True)
            lo = batch_idx * self.bs
            generated = self.sess.run(fake_op,
                                      feed_dict={z_ph: noise, cond_ph: embed})
            samples[lo:lo + self.bs] = denormalize_images(generated)

        print('\nComputing inception score...')
        mean, std = inception_score.get_inception_score(
            samples, self.sess, incep_bs, 10, probs_op, verbose=True)
        print('Inception Score | mean:', "%.2f" % mean, 'std:', "%.2f" % std)
示例#2
0
    def evaluate_fid(self):
        """Compute and print the FID between generated and real images.

        Activation statistics (``mu`` / ``sigma`` of the Inception
        ``PreLogits`` layer) for the real images are computed once and cached
        at ``cfg.EVAL.ACT_STAT_PATH``. Later calls reuse the cached file.

        The generator is then built in inference mode and restored from
        ``cfg.CHECKPOINT_DIR``. It produces ``(cfg.EVAL.SIZE // self.bs) *
        self.bs`` images from standard-normal noise and test-set embeddings.
        The FID between the two activation distributions is printed.

        If the Frechet distance computation raises, the error is printed and
        the sentinel value 500 is printed instead of the score.

        Raises:
            RuntimeError: if the generator checkpoint cannot be restored.
        """
        incep_batch_size = self.cfg.EVAL.INCEP_BATCH_SIZE
        # NOTE(review): the class count is hard-coded to 20 here, while
        # evaluate_inception passes self.classes. Only PreLogits is used for
        # FID, but the Inception checkpoint must match this logits shape to be
        # restored — confirm for non-flowers datasets.
        _, layers = load_inception_inference(self.sess, 20, incep_batch_size,
                                             self.cfg.EVAL.INCEP_CHECKPOINT_DIR)
        pool3 = layers['PreLogits']
        act_op = tf.reshape(pool3, shape=[incep_batch_size, -1])

        if not os.path.exists(self.cfg.EVAL.ACT_STAT_PATH):
            print('Computing activation statistics for real x')
            fid.compute_and_save_activation_statistics(
                self.cfg.EVAL.R_IMG_PATH, self.sess, incep_batch_size, act_op,
                self.cfg.EVAL.ACT_STAT_PATH, verbose=True)

        print('Loading activation statistics for the real x')
        # np.load on an .npz archive returns an NpzFile that keeps the file
        # open. Read the arrays inside the context manager so it is closed.
        with np.load(self.cfg.EVAL.ACT_STAT_PATH) as stats:
            mu_real = stats['mu']
            sigma_real = stats['sigma']

        z = tf.placeholder(tf.float32, [self.bs, self.model.z_dim], name='z')
        cond = tf.placeholder(tf.float32, [self.bs] + [self.model.embed_dim], name='cond')
        # Build the generator in inference mode, exactly as evaluate_inception does.
        eval_gen, _, _ = self.model.generator(z, cond, reuse=False, is_training=False)

        saver = tf.train.Saver(tf.global_variables('g_net'))
        could_load, _ = load(saver, self.sess, self.cfg.CHECKPOINT_DIR)
        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")
            raise RuntimeError('Could not load the checkpoints of the generator')

        print('Generating x...')

        fid_size = self.cfg.EVAL.SIZE
        n_batches = fid_size // self.bs

        w, h, c = self.model.image_dims[0], self.model.image_dims[1], self.model.image_dims[2]
        samples = np.zeros((n_batches * self.bs, w, h, c))
        for i in range(n_batches):
            start = i * self.bs
            end = start + self.bs

            sample_z = np.random.normal(0, 1, size=(self.bs, self.model.z_dim))
            _, _, embed, _, _ = self.dataset.test.next_batch(self.bs, 4, embeddings=True)

            gen_batch = self.sess.run(eval_gen, feed_dict={z: sample_z, cond: embed})
            samples[start:end] = denormalize_images(gen_batch)

        print('Computing activation statistics for generated x...')
        mu_gen, sigma_gen = fid.calculate_activation_statistics(samples, self.sess, incep_batch_size, act_op,
                                                                verbose=True)
        print("calculate FID:", end=" ", flush=True)
        try:
            fid_value = fid.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
        except Exception as e:
            # Best-effort: report the failure and fall back to a large sentinel
            # score rather than aborting the whole evaluation.
            print(e)
            fid_value = 500

        print(fid_value)
    def __init__(self, sess: tf.Session, model: GanCls, dataset: TextDataset,
                 config):
        """Keep references to the session, model and data, then restore weights.

        Args:
            sess: TensorFlow session the checkpoint is restored into.
            model: the GAN; its ``sampler`` attribute is cached as
                ``self.sampler``.
            dataset: stored as ``self.dataset``.
            config: must provide ``checkpoint_dir``.

        Raises:
            LookupError: when ``load`` cannot restore any checkpoint from
                ``config.checkpoint_dir``.
        """
        self.sess = sess
        self.model = model
        self.sampler = model.sampler
        self.dataset = dataset
        self.config = config

        # A Saver built without a var_list covers every saveable variable.
        self.saver = tf.train.Saver()
        restored, _ = load(self.saver, self.sess, self.config.checkpoint_dir)
        if not restored:
            print(" [!] Load failed...")
            raise LookupError('Could not load any checkpoints')
        print(" [*] Load SUCCESS")
示例#4
0
def check_save():
    """Tell whether a usable save file exists.

    Loads the save data at ``config.SAVE_PATH`` through ``saver.load``.

    Returns:
        False when:
          * the load result compares equal to False (a "no save file found"
            notice is printed first);
          * ``slot1``, ``slot2`` and ``slot3`` are all ``'null'``;
          * any exception is raised along the way.
        True otherwise.
    """
    try:
        loaded = saver.load(config.SAVE_PATH)
        if loaded == False:
            # Tell the player that no save file was found.
            print('未发现存档文件')
            return False
        every_slot_empty = all(loaded[slot] == 'null'
                               for slot in ('slot1', 'slot2', 'slot3'))
        return not every_slot_empty
    except Exception:
        return False
示例#5
0
def load_inception_inference(sess, num_classes, batch_size, checkpoint_dir):
    """Build the Inception network and restore its weights from a checkpoint.

    Creates a float32 input placeholder of shape ``[batch_size, 299, 299, 3]``
    named ``inputs`` and builds ``inception_net`` on it. It then restores every
    global variable in the ``InceptionV3`` scope from ``checkpoint_dir``.

    Args:
        sess: session the weights are restored into.
        num_classes: number of output classes passed to ``inception_net``.
        batch_size: static batch dimension of the input placeholder.
        checkpoint_dir: directory searched for the Inception checkpoint.

    Returns:
        ``(logits, layers)`` exactly as returned by ``inception_net``.

    NOTE(review): a failed restore is only reported on stdout. The graph is
    still returned and its variables are left untouched (possibly
    uninitialised), so callers that need trained weights cannot detect it.
    """
    # Build a graph that computes the logits of the inference model.
    images_in = tf.placeholder(tf.float32, [batch_size, 299, 299, 3],
                               name='inputs')
    logits, layers = inception_net(images_in, num_classes)

    restorer = tf.train.Saver(tf.global_variables('InceptionV3'))
    print('Restoring Inception model from %s' % checkpoint_dir)

    restored, _ = load(restorer, sess, checkpoint_dir)
    print(" [*] Load SUCCESS" if restored else " [!] Load failed...")
    return logits, layers
示例#6
0
def _draw_save_slots():
    """Draw the screenshot, round text and date text of each non-empty save slot.

    Slots are drawn 200 px apart starting at y=90. If drawing one slot raises,
    the error is printed and drawing continues with the next slot.
    """
    top = 90
    for slot in player_runtime.SDATA:
        try:
            if slot['img_path'] != '':
                img_res = slot['img_res']
                round_str = slot['round']
                date_str = slot['date']
                loader.screen.blit(img_res, (15, top))
                round_surf = loader.BAIKE_FONT.render(round_str, True,
                                                      color_rgb.WHITE, None)
                date_surf = loader.BAIKE_FONT.render(date_str, True,
                                                     color_rgb.WHITE, None)
                loader.screen.blit(round_surf, (450, top + 35))
                loader.screen.blit(date_surf, (450, top + 90))
        except Exception as err:
            print('save read err')
            print(err)
        top += 200


def _handle_load_confirm(x, y, is_mouse_down):
    """Draw the load-confirmation dialog and react to its two buttons.

    Confirm loads ``save/slot<INFO['cslot'] + 1>.pkl`` with ``saver.load``. On
    success, the result replaces ``player_runtime.INFO`` and the current page
    code is switched to ``game_page.SCODE``. Cancel clears
    ``INFO['checksover']`` so the slot list is shown again.
    """
    # Prompt text, then the confirm and cancel buttons.
    loader.screen.blit(loader.CHECK_LOAD, (230, 200))
    loader.screen.blit(loader.SURE, (300, 370))
    loader.screen.blit(loader.SURENO, (460, 370))

    if 300 <= x <= 380 and 370 <= y <= 410:
        # Confirm button hovered.
        loader.screen.blit(loader.SELECT_SAVE, (300, 370))
        if is_mouse_down:
            try:
                cslot = player_runtime.INFO['cslot'] + 1
                res = saver.load('save/slot' + str(cslot) + '.pkl')
                # saver.load signals failure by returning False.
                if res == False:
                    print('存档异常')
                else:
                    player_runtime.INFO = res
                    player_runtime.INFO['checksover'] = False
                    player_runtime.INFO['saving'] = False
                    loader.curs_code = game_page.SCODE
            except Exception as err:
                print('load error')
                print(err)
    elif 460 <= x <= 540 and 370 <= y <= 410:
        # Cancel button hovered: clicking returns to browsing the slots.
        loader.screen.blit(loader.SELECT_SAVE, (460, 370))
        if is_mouse_down:
            player_runtime.INFO['checksover'] = False


def dojob(x, y, is_mouse_down, keys):
    """Render one frame of the load-game page and handle mouse input.

    Args:
        x, y: current mouse position.
        is_mouse_down: whether the mouse button is pressed this frame.
        keys: unused here.
    """
    # Page background.
    loader.screen.blit(loader.SAVE_BOARD_TITLE, (0, 0))

    # Back button: clicking it clears INFO['loading'].
    if 20 <= x < 100 and 20 <= y < 60:
        loader.screen.blit(loader.SAVE_BOARD_TITLE_BACK2, (20, 20))
        if is_mouse_down:
            player_runtime.INFO['loading'] = False
    else:
        loader.screen.blit(loader.SAVE_BOARD_TITLE_BACK1, (20, 20))

    slot_tops = (80, 280, 480)  # top y of the three 200 px high slot rows

    if player_runtime.INFO['checksover']:
        # A slot has been chosen: show the rows un-highlighted beneath the
        # confirmation dialog.
        for top in slot_tops:
            loader.screen.blit(loader.SAVE_BOARD_LOT, (0, top))
        _draw_save_slots()
        _handle_load_confirm(x, y, is_mouse_down)
        return

    # Slot selection: highlight the hovered row. Clicking a non-empty slot
    # selects it and opens the confirmation dialog on the next frame.
    for slot_idx, top in enumerate(slot_tops):
        if 0 <= x <= 860 and top <= y < top + 200:
            loader.screen.blit(loader.SAVE_BOARD_LOTA, (0, top))
            if is_mouse_down and player_runtime.SDATA[slot_idx]['img_path'] != '':
                player_runtime.INFO['checksover'] = True
                player_runtime.INFO['cslot'] = slot_idx
        else:
            loader.screen.blit(loader.SAVE_BOARD_LOT, (0, top))

    # Screenshots and details drawn on top of the slot rows.
    _draw_save_slots()
示例#7
0
    def visualize(self, n_random=0, special_positions=None):
        """Restore both GAN stages and write qualitative sample grids to disk.

        Images are written under ``{samples_dir}/{dataset.name}_visual/``.

        Args:
            n_random: number of random test positions to visualise. For each
                one, four outputs are saved: a z-space interpolation, a
                conditioning interpolation between two positions, a captioned
                image, and a Stage I / Stage II comparison. Defaults to 0,
                which skips this part.
            special_positions: test-set positions for which a captioned image
                is saved under ``special_cap``. Defaults to ``[12, 908, 1005]``.

        Raises:
            LookupError: if the Stage I or Stage II generator checkpoint cannot
                be restored.
        """
        # Graph names stay 'z' and 'cond'. The Python names carry a _ph suffix
        # so the numpy conditioning arrays below do not shadow the placeholders.
        z_ph = tf.placeholder(tf.float32, [self.model.batch_size, self.model.z_dim], name='z')
        cond_ph = tf.placeholder(tf.float32, [self.model.batch_size] + [self.model.embed_dim], name='cond')
        gen_stagei, _, _ = self.model.stagei.generator(z_ph, cond_ph, is_training=False)
        gen, _, _ = self.model.generator(gen_stagei, cond_ph, is_training=False)
        gen_no_noise, _, _ = self.model.generator(gen_stagei, cond_ph, is_training=False, reuse=True,
                                                  cond_noise=False)

        saver = tf.train.Saver(tf.global_variables('g_net'))
        could_load, _ = load(saver, self.sess, self.model.stagei.cfg.CHECKPOINT_DIR)
        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")
            raise LookupError('Could not load any checkpoints for stage I')

        saver = tf.train.Saver(tf.global_variables('stageII_g_net'))
        could_load, _ = load(saver, self.sess, self.config.CHECKPOINT_DIR)
        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")
            raise LookupError('Could not load any checkpoints for stage II')

        n_test = self.dataset.test.num_examples
        for idx in range(n_random):
            dataset_pos = np.random.randint(0, n_test)
            dataset_pos2 = np.random.randint(0, n_test)

            # Interpolation in z space:
            # ---------------------------------------------------------------------------------------------------------
            _, cond, _, captions = self.dataset.test.next_batch_test(1, dataset_pos, 1)
            cond = np.squeeze(cond, axis=0)
            caption = captions[0][0]

            samples = gen_noise_interp_img(self.sess, gen_no_noise, cond, self.model.z_dim, self.model.batch_size)
            save_cap_batch(samples, caption, '{}/{}_visual/z_interp/z_interp{}.png'.format(self.samples_dir,
                                                                                           self.dataset.name,
                                                                                           idx))
            # Interpolation in embedding space:
            # ---------------------------------------------------------------------------------------------------------
            _, cond1, _, caps1 = self.dataset.test.next_batch_test(1, dataset_pos, 1)
            _, cond2, _, caps2 = self.dataset.test.next_batch_test(1, dataset_pos2, 1)

            cond1 = np.squeeze(cond1, axis=0)
            cond2 = np.squeeze(cond2, axis=0)
            cap1, cap2 = caps1[0][0], caps2[0][0]

            samples = gen_cond_interp_img(self.sess, gen_no_noise, cond1, cond2, self.model.z_dim,
                                          self.model.batch_size)
            save_interp_cap_batch(samples, cap1, cap2,
                                  '{}/{}_visual/cond_interp/cond_interp{}.png'.format(self.samples_dir,
                                                                                      self.dataset.name,
                                                                                      idx))

            # Generate captioned image
            # ---------------------------------------------------------------------------------------------------------
            _, conditions, _, captions = self.dataset.test.next_batch_test(1, dataset_pos, 1)
            conditions = np.squeeze(conditions, axis=0)
            caption = captions[0][0]
            samples = gen_captioned_img(self.sess, gen, conditions, self.model.z_dim, self.model.batch_size)

            save_cap_batch(samples, caption, '{}/{}_visual/cap/cap{}.png'.format(self.samples_dir,
                                                                                 self.dataset.name, idx))

            # Generate Stage I and Stage II images
            # ---------------------------------------------------------------------------------------------------------
            _, cond, _, captions = self.dataset.test.next_batch_test(self.model.batch_size, dataset_pos, 1)
            cond = np.squeeze(cond, axis=0)
            samples = gen_multiple_stage_img(self.sess, [gen_stagei, gen], cond, self.model.z_dim,
                                             self.model.batch_size, size=128)
            text = "Stage I and Stage II"
            save_cap_batch(samples, text, '{}/{}_visual/stages/stage{}.png'.format(self.samples_dir,
                                                                                   self.dataset.name, idx))

        if special_positions is None:
            # Default positions were listed as bird examples; the flowers list
            # used alongside them was [1126, 908, 398].
            special_positions = [12, 908, 1005]
        for idx, special_pos in enumerate(special_positions):
            print(special_pos)
            # Generate specific image
            # ---------------------------------------------------------------------------------------------------------
            _, conditions, _, captions = self.dataset.test.next_batch_test(1, special_pos, 1)
            conditions = np.squeeze(conditions, axis=0)
            caption = captions[0][0]
            samples = gen_captioned_img(self.sess, gen, conditions, self.model.z_dim, self.model.batch_size)

            save_cap_batch(samples, caption, '{}/{}_visual/special_cap/cap{}.png'.format(self.samples_dir,
                                                                                         self.dataset.name, idx))
示例#8
0
    def train(self):
        """Fine-tune the Inception classifier and checkpoint it periodically.

        Two start modes, selected by ``cfg.TRAIN.RESTORE_PRETRAIN``:
          * Set: the variables in ``self.pretrained_to_restore`` are restored
            from ``cfg.TRAIN.PRETRAINED_CHECKPOINT_DIR``. The remaining layers
            and the optimizer variables are initialised, and training starts
            at step 1.
          * Not set: the full model is restored from ``cfg.CHECKPOINT_DIR``
            and training resumes after the restored step.

        Summaries and a progress line are emitted every
        ``cfg.TRAIN.SUMMARY_PERIOD`` steps. A checkpoint is saved every 200
        steps.

        Raises:
            RuntimeError: if resuming and no checkpoint can be restored.
        """
        self.define_model()
        self.define_summaries()

        start_time = time.time()
        self.saver = tf.train.Saver(
            max_to_keep=self.cfg.TRAIN.CHECKPOINTS_TO_KEEP)

        if self.cfg.TRAIN.RESTORE_PRETRAIN:
            pretrain_saver = tf.train.Saver(self.pretrained_to_restore)

            # Load the pre-trained layers
            pretrain_saver.restore(self.sess,
                                   self.cfg.TRAIN.PRETRAINED_CHECKPOINT_DIR)

            # Initialise the layers that were not restored and the optimizer variables
            self.sess.run(
                tf.variables_initializer(self.not_to_restore + self.opt_vars))
            start_point = 0
        else:
            could_load, checkpoint_counter = load(self.saver, self.sess,
                                                  self.cfg.CHECKPOINT_DIR)
            if could_load:
                start_point = checkpoint_counter
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")
                raise RuntimeError(
                    'Failed to restore the complete Inception model')
        sys.stdout.flush()

        batch_size = self.cfg.TRAIN.BATCH_SIZE
        summary_period = self.cfg.TRAIN.SUMMARY_PERIOD
        # NOTE(review): batches (and hence the epoch size) come from the *test*
        # split — confirm this is intended.
        epoch_size = self.dataset.test.num_examples // batch_size
        # class_to_idx maps raw labels onto the continuous range
        # [0, num_classes), so its size bounds the remapped labels. This
        # replaces the former hard-coded 50, which was only correct for birds.
        num_classes = len(self.class_to_idx)

        for idx in range(start_point + 1, self.cfg.TRAIN.MAX_STEPS):
            epoch = idx // epoch_size

            images, _, _, _, labels = self.dataset.test.next_batch(batch_size,
                                                                   labels=True)

            # Bring the labels into a continuous range: [0, num_classes)
            new_labels = [self.class_to_idx[label] for label in labels]

            # Sanity checks on the batch: images in [-1, 1], labels in range.
            assert np.min(images) >= -1.
            assert np.max(images) <= 1.
            assert np.min(new_labels) >= 0
            assert np.max(new_labels) < num_classes

            feed_dict = {
                self.x: images,
                self.labels: new_labels,
            }

            _, err = self.sess.run([self.opt_step, self.loss],
                                   feed_dict=feed_dict)

            if np.mod(idx, summary_period) == 0:
                summary_str = self.sess.run(self.summary_op,
                                            feed_dict=feed_dict)
                self.writer.add_summary(summary_str, idx)

                print("Epoch: [%2d] [%4d] time: %4.4f, loss: %.8f" %
                      (epoch, idx, time.time() - start_time, err))

            if np.mod(idx, 200) == 0:
                save(self.saver, self.sess, self.cfg.CHECKPOINT_DIR, idx)
            sys.stdout.flush()
示例#9
0
    def train(self):
        """Train the Stage II GAN, sampling and checkpointing periodically.

        Setup:
          * A fixed noise batch and a batch of test embeddings (from a random
            valid test position) are drawn once and reused for every sample
            grid.
          * Stage II networks are restored from ``cfg.CHECKPOINT_DIR`` when
            available, and their step counter is resumed.
          * The Stage I generator is restored from
            ``cfg_stage_i.CHECKPOINT_DIR``.
          * Missing checkpoints only print a warning.

        Each step updates D, then G, with the same noise batch. Sample grids
        are saved every 100 steps, and Stage II checkpoints are saved when
        ``counter % 500 == 2``.
        """
        self.define_losses()
        self.define_summaries()

        sample_z = np.random.normal(0, 1,
                                    (self.model.sample_num, self.model.z_dim))
        # np.random.randint excludes its upper bound, so the position is always
        # a valid index into the test split.
        sample_pos = np.random.randint(0, self.dataset.test.num_examples)
        _, sample_embed, _, captions = self.dataset.test.next_batch_test(
            self.model.sample_num, sample_pos, 1)
        sample_embed = np.squeeze(sample_embed, axis=0)
        print(sample_embed.shape)

        def print_sample_captions():
            # Display the captions of the sampled images
            print('\nCaptions of the sampled images:')
            for caption_idx, caption_batch in enumerate(captions):
                print('{}: {}'.format(caption_idx + 1, caption_batch[0]))
            print()

        print_sample_captions()

        counter = 1
        start_time = time.time()

        # Initialise everything, then overwrite with whatever checkpoints exist.
        tf.global_variables_initializer().run()
        could_load, checkpoint_counter = load(self.stageii_saver, self.sess,
                                              self.cfg.CHECKPOINT_DIR)
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS: Stage II networks are loaded.")
        else:
            print(" [!] Load failed for stage II networks...")

        # Only the Stage I weights are needed. Its checkpoint step belongs to
        # Stage I training and must not replace the Stage II counter, which
        # numbers the Stage II summaries and checkpoints.
        could_load, _ = load(self.stagei_g_saver, self.sess,
                             self.cfg_stage_i.CHECKPOINT_DIR)
        if could_load:
            print(" [*] Load SUCCESS: Stage I generator is loaded")
        else:
            print(
                " [!] WARNING!!! Failed to load the parameters for stage I generator..."
            )

        # Updates per epoch are given by the training data size / batch size
        updates_per_epoch = self.dataset.train.num_examples // self.model.batch_size

        for epoch in range(self.cfg.TRAIN.EPOCH):
            for idx in range(updates_per_epoch):
                images, wrong_images, embed, _, _ = self.dataset.train.next_batch(
                    self.model.batch_size, 4)
                batch_z = np.random.normal(
                    0, 1, (self.model.batch_size, self.model.z_dim))

                # Update D network
                _, err_d_real_match, err_d_real_mismatch, err_d_fake, err_d, summary_str = self.sess.run(
                    [
                        self.D_optim, self.D_real_match_loss,
                        self.D_real_mismatch_loss, self.D_synthetic_loss,
                        self.D_loss, self.D_merged_summ
                    ],
                    feed_dict={
                        self.model.inputs: images,
                        self.model.wrong_inputs: wrong_images,
                        self.model.embed_inputs: embed,
                        self.model.z: batch_z
                    })
                self.writer.add_summary(summary_str, counter)

                # Update G network
                _, err_g, summary_str = self.sess.run(
                    [self.G_optim, self.G_loss, self.G_merged_summ],
                    feed_dict={
                        self.model.z: batch_z,
                        self.model.embed_inputs: embed
                    })
                self.writer.add_summary(summary_str, counter)

                counter += 1
                print(
                    "Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                    % (epoch, idx, updates_per_epoch, time.time() - start_time,
                       err_d, err_g))

                if np.mod(counter, 100) == 0:
                    # Best-effort: a failed sample must not stop training.
                    try:
                        samples = self.sess.run(self.model.sampler,
                                                feed_dict={
                                                    self.model.z_sample:
                                                    sample_z,
                                                    self.model.embed_sample:
                                                    sample_embed,
                                                })
                        save_images(
                            samples, image_manifold_size(samples.shape[0]),
                            '{}train_{:02d}_{:04d}.png'.format(
                                self.cfg.SAMPLE_DIR, epoch, idx))
                        print("[Sample] d_loss: %.8f, g_loss: %.8f" %
                              (err_d, err_g))
                        print_sample_captions()
                    except Exception as e:
                        print("Failed to generate sample image")
                        print(type(e))
                        print(e.args)
                        print(e)

                if np.mod(counter, 500) == 2:
                    save(self.stageii_saver, self.sess,
                         self.cfg.CHECKPOINT_DIR, counter)
示例#10
0
    def train(self):
        """Train one stage (or stage transition) of the progressive GAN.

        Session and restore:
          * Opens its own session with GPU memory growth enabled.
          * For stages other than 1, weights are restored from
            ``self.check_dir_read``: through ``self.restore`` during a
            transition, otherwise through ``self.saver``.
          * The variables reported by ``initialize_uninitialized`` are then
            initialised.

        Steps are always counted from 1 to ``self.steps - 1``. Summaries are
        written every 20 steps. Sample grids are saved every 2000 steps, and
        checkpoints (to ``self.check_dir_write``) every 2000 steps and at the
        last step. The default graph is reset once the session is closed.

        Raises:
            RuntimeError: if a required checkpoint cannot be restored.
        """
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)
            start_point = 0

            if self.stage != 1:
                if self.trans:
                    could_load, _ = load(self.restore, sess, self.check_dir_read)
                    if not could_load:
                        raise RuntimeError('Could not load previous stage during transition')
                else:
                    could_load, _ = load(self.saver, sess, self.check_dir_read)
                    if not could_load:
                        raise RuntimeError('Could not load current stage')

            # Initialise only the variables that were not restored above.
            vars_to_init = initialize_uninitialized(sess)
            sess.run(tf.variables_initializer(vars_to_init))

            sample_z = np.random.normal(0, 1, (self.sample_num, self.z_dim))
            _, sample_cond, _, captions = self.dataset.test.next_batch_test(self.sample_num, 0, 1)
            sample_cond = np.squeeze(sample_cond, axis=0)
            print('Conditionals sampler shape: {}'.format(sample_cond.shape))

            save_captions(self.sample_path, captions)

            if self.trans:
                # The learning rate is constant during the transition period.
                # A warm-up schedule, lr * exp(-2 * (1 - idx / steps) ** 2),
                # was disabled.
                self.lr_inp = self.lr

            epoch_size = self.dataset.train.num_examples // self.batch_size
            start_time = time.time()

            for idx in range(start_point + 1, self.steps):
                epoch = idx // epoch_size

                images, wrong_images, embed, _, _ = self.dataset.train.next_batch(self.batch_size, 4,
                                                                                  wrong_img=True,
                                                                                  embeddings=True)
                batch_z = np.random.normal(0, 1, (self.batch_size, self.z_dim))
                eps = np.random.uniform(0., 1., size=(self.batch_size, 1, 1, 1))

                feed_dict = {
                    self.x: images,
                    self.learning_rate: self.lr_inp,
                    self.x_mismatch: wrong_images,
                    self.cond: embed,
                    self.z: batch_z,
                    self.epsilon: eps,
                    self.z_sample: sample_z,
                    self.cond_sample: sample_cond,
                    self.iter: idx,
                }

                _, err_d = sess.run([self.D_optim, self.D_loss], feed_dict=feed_dict)
                _, err_g = sess.run([self.G_optim, self.G_loss], feed_dict=feed_dict)

                if np.mod(idx, 20) == 0:
                    summary_str = sess.run(self.summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, idx)

                    print("Epoch: [%2d] [%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                          % (epoch, idx, time.time() - start_time, err_d, err_g))

                if np.mod(idx, 2000) == 0:
                    # Best-effort: a failed sample must not stop training.
                    try:
                        samples = sess.run(self.sampler, feed_dict={
                                                    self.z_sample: sample_z,
                                                    self.cond_sample: sample_cond})
                        samples = np.clip(samples, -1., 1.)
                        # Keep large outputs small on disk: only 4 samples above 256 px.
                        if self.out_size > 256:
                            samples = samples[:4]

                        save_images(samples, get_balanced_factorization(samples.shape[0]),
                                    '{}train_{:02d}_{:04d}.png'.format(self.sample_path, epoch, idx))

                    except Exception as e:
                        print("Failed to generate sample image")
                        print(type(e))
                        print(e.args)
                        print(e)

                if np.mod(idx, 2000) == 0 or idx == self.steps - 1:
                    save(self.saver, sess, self.check_dir_write, idx)
                sys.stdout.flush()

        tf.reset_default_graph()
示例#11
0
    def train(self):
        """Run the conditional GAN training loop.

        Setup:
          * Initialises all variables, then resumes from
            ``cfg.CHECKPOINT_DIR`` when a checkpoint is found (otherwise
            starts at step 1).
          * A fixed noise batch and the test conditioning vectors at position
            0 are used for every sample grid; their captions are saved to
            ``cfg.SAMPLE_DIR``.

        Each step:
          * Runs the discriminator update together with ``kt_optim``.
          * Runs the generator update every ``cfg.TRAIN.N_CRITIC`` steps.
          * Both learning rates are scaled by 0.95 for every
            ``N_CRITIC * 10000`` steps taken.

        Summaries, sample grids and checkpoints are produced on their
        respective periods.
        """
        self.define_summaries()

        train_cfg = self.cfg.TRAIN
        model = self.model
        self.saver = tf.train.Saver(max_to_keep=train_cfg.CHECKPOINTS_TO_KEEP)

        sample_z = np.random.normal(0, 1, (model.sample_num, model.z_dim))
        _, sample_cond, _, captions = self.dataset.test.next_batch_test(model.sample_num, 0, 1)
        sample_cond = np.squeeze(sample_cond, axis=0)
        print('Conditionals sampler shape: {}'.format(sample_cond.shape))

        save_captions(self.cfg.SAMPLE_DIR, captions)

        start_time = time.time()
        tf.global_variables_initializer().run()

        restored, ckpt_step = load(self.saver, self.sess, self.cfg.CHECKPOINT_DIR)
        start_point = ckpt_step if restored else 0
        print(" [*] Load SUCCESS" if restored else " [!] Load failed...")
        sys.stdout.flush()

        n_critic = train_cfg.N_CRITIC
        summary_period = train_cfg.SUMMARY_PERIOD
        steps_per_epoch = self.dataset.train.num_examples // model.batch_size

        for step in range(start_point + 1, train_cfg.MAX_STEPS):
            epoch = step // steps_per_epoch

            images, wrong_images, embed, _, _ = self.dataset.train.next_batch(
                model.batch_size, 1, embeddings=True, wrong_img=True)
            batch_z = np.random.normal(0, 1, (model.batch_size, model.z_dim))
            eps = np.random.uniform(0., 1., size=(model.batch_size, 1, 1, 1))
            # Step-wise learning-rate decay factor.
            decay = 0.95 ** ((step // n_critic) // 10000)

            feed_dict = {
                model.learning_rate_d: self.lr_d * decay,
                model.learning_rate_g: self.lr_g * decay,
                model.x: images,
                model.x_mismatch: wrong_images,
                model.cond: embed,
                model.z: batch_z,
                model.epsilon: eps,
                model.z_sample: sample_z,
                model.cond_sample: sample_cond,
                model.iter: step,
            }

            # Critic (and kt) update every step.
            _, _, err_d = self.sess.run([model.D_optim, model.kt_optim, model.D_loss],
                                        feed_dict=feed_dict)

            # Generator update once every n_critic steps.
            if step % n_critic == 0:
                _, err_g = self.sess.run([model.G_optim, model.G_loss],
                                         feed_dict=feed_dict)

            if np.mod(step, summary_period) == 0:
                self.writer.add_summary(self.sess.run(self.summary_op, feed_dict=feed_dict), step)

            if np.mod(step, train_cfg.SAMPLE_PERIOD) == 0:
                # Best-effort: a failed sample must not stop training.
                try:
                    samples = self.sess.run(model.sampler,
                                            feed_dict={
                                                model.z_sample: sample_z,
                                                model.cond_sample: sample_cond,
                                            })
                    out_path = '{}train_{:02d}_{:04d}.png'.format(self.cfg.SAMPLE_DIR, epoch, step)
                    save_images(samples, get_balanced_factorization(samples.shape[0]), out_path)
                except Exception as e:
                    print("Failed to generate sample image")
                    print(type(e))
                    print(e.args)
                    print(e)

            if np.mod(step, 500) == 2:
                save(self.saver, self.sess, self.cfg.CHECKPOINT_DIR, step)
            sys.stdout.flush()
示例#12
0
                      sample_path=None,
                      log_dir=None,
                      stage=stage[i],
                      trans=False,
                      build_model=False)

        cond = tf.placeholder(tf.float32, [None, 1024], name='cond')
        z = tf.placeholder(tf.float32, [None, z_dim], name='z')
        gen_op, _, _ = pggan.generator(z, cond, stages=stage[i], t=False)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            saver = tf.train.Saver(tf.global_variables('g_net'))
            could_load = load(saver, sess, pggan_checkpoint_dir_read)
            if not could_load:
                raise RuntimeError('Could not load stage %d' % stage[i])

            samples = sess.run(gen_op,
                               feed_dict={
                                   'z:0': z_sample,
                                   'cond:0': conditions
                               })
            samples = np.clip(samples, -1., 1.)
            all_samples.append(samples)

        tf.reset_default_graph()

    all_samples = gen_pggan_sample(all_samples, 128, interp='nearest')
示例#13
0
	def evaluate_inception(self):
		"""Compute the Inception Score of images from the stacked stage I + stage II generators.

		The stage II generator is built on top of the stage I generator's output,
		and both are conditioned on sentence embeddings that ``self.Retrieval``
		computes from ``self.test_data_loader`` features. ``cfg.EVAL.SIZE // self.bs``
		batches are generated, resized for Inception, classified, and the mean/std
		Inception Score is printed.

		Raises:
			ValueError: if ``cfg.EVAL.SIZE`` is smaller than one batch (no
				predictions would be produced).
			RuntimeError: if the stage I or stage II checkpoints cannot be loaded.
		"""
		size = self.cfg.EVAL.SIZE
		n_batches = size // self.bs
		# Fail fast: with zero batches np.concatenate would fail with an obscure error.
		if n_batches == 0:
			raise ValueError('cfg.EVAL.SIZE ({}) must be at least the batch size ({})'.format(size, self.bs))

		incep_batch_size = self.cfg.EVAL.INCEP_BATCH_SIZE
		logits, _ = load_inception_inference(self.sess, self.cfg.EVAL.NUM_CLASSES, incep_batch_size,
											 self.cfg.EVAL.INCEP_CHECKPOINT_DIR)
		pred_op = tf.nn.softmax(logits)

		# The sampled noise must match the stage I generator's z placeholder.
		z_dim = self.model.stagei.z_dim
		z = tf.placeholder(tf.float32, [self.bs, z_dim], name='z')
		cond = tf.placeholder(tf.float32, [self.bs] + [self.model.stagei.embed_dim], name='cond')
		stagei_gen, _, _ = self.model.stagei.generator(z, cond, reuse=False, is_training=False)
		eval_gen, _, _ = self.model.generator(stagei_gen, cond, reuse=False, is_training=False)
		self.Retrieval.eval(self.bs)

		# Stage I generator plus the variables under the 'vf_', 'sf_' and 'att'
		# scopes (presumably the retrieval network -- confirm) share one checkpoint.
		saver = tf.train.Saver(tf.global_variables('g_net') + tf.global_variables('vf_') +
							   tf.global_variables('sf_') + tf.global_variables('att'))
		could_load, _ = load(saver, self.sess, self.model.stagei.cfg.CHECKPOINT_DIR)

		if could_load:
			print(" [*] Load SUCCESS")
		else:
			print(" [!] Load failed...")
			raise RuntimeError('Could not load the checkpoints of stage I')

		saver = tf.train.Saver(tf.global_variables('stageII_g_net'))
		could_load, _ = load(saver, self.sess, self.cfg.CHECKPOINT_DIR)
		if could_load:
			print(" [*] Load SUCCESS")
		else:
			print(" [!] Load failed...")
			raise RuntimeError('Could not load the checkpoints of stage II')

		print('Generating batches...')

		all_preds = []
		for i in range(n_batches):
			print("\rGenerating batch %d/%d" % (i + 1, n_batches), end="", flush=True)

			sample_z = np.random.normal(0, 1, size=(self.bs, z_dim))
			im_feats, sent_feats, _ = self.test_data_loader.get_batch(i, self.bs, phase = 'incep')

			# The generator condition is the retrieval network's sentence embedding.
			sent_emb = self.sess.run(self.Retrieval.sent_embed_tensor,
									 feed_dict={
										 self.Retrieval.image_placeholder_test: im_feats,
										 self.Retrieval.sent_placeholder_test: sent_feats,
									 })
			gen_batch = self.sess.run(eval_gen, feed_dict={z: sample_z, cond: sent_emb})

			# Map generated images back to pixel range and resize them to
			# Inception's 299x299 input.
			samples = denormalize_images(gen_batch)
			incep_samples = np.empty((self.bs, 299, 299, 3))
			for sample_idx in range(self.bs):
				incep_samples[sample_idx] = prep_incep_img(samples[sample_idx])

			# Run prediction for current batch
			pred = self.sess.run(pred_op, feed_dict={'inputs:0': incep_samples})
			all_preds.append(pred)

		# Stack the per-batch predictions along the batch axis.
		all_preds = np.concatenate(all_preds, 0)

		print('\nComputing inception score...')
		mean, std = inception_score.get_inception_from_predictions(all_preds, 10)
		print('Inception Score | mean:', "%.2f" % mean, 'std:', "%.2f" % std)
示例#14
0
def generate_sequences(fuzzing_requests, checkers, fuzzing_jobs=1):
    """ Implements core restler algorithm.

    Each generation extends every sequence in the collection by one request
    template and renders the result. A pass ends at max_sequence_length, on
    timeout, or when the collection is exhausted; in 'random-walk' mode the
    whole pass is repeated until a timeout occurs.

    @param fuzzing_requests: The collection of requests that will be fuzzed
    @type  fuzzing_requests: FuzzingRequestCollection
    @param checkers: The list of checkers to apply
    @type  checkers: list[Checker]
    @param fuzzing_jobs: Optional number of fuzzing jobs for parallel fuzzing.
                            Default value passed is one (sequential fuzzing).
    @type  fuzzing_jobs: Int

    @return: Total number of sequences summed over all rendered generations.
             Returns None when there are no requests to fuzz, and the result
             of generate_sequences_directed_smoketest in
             'directed-smoke-test' mode.
    @rtype : int or None

    """
    if not fuzzing_requests.size:
        return

    logger.create_network_log(logger.LOG_TYPE_TESTING)

    fuzzing_mode = Settings().fuzzing_mode
    max_len = Settings().max_sequence_length
    if fuzzing_mode == 'directed-smoke-test':
        return generate_sequences_directed_smoketest(fuzzing_requests,
                                                     checkers)

    if fuzzing_jobs > 1:
        render = render_parallel
        global_lock = multiprocessing.Lock()
        fuzzing_pool = ThreadPool(fuzzing_jobs)
    else:
        global_lock = None
        fuzzing_pool = None
        render = render_sequential

    should_stop = False
    timeout_reached = False
    seq_collection_exhausted = False
    num_total_sequences = 0
    # try/finally guarantees the worker pool is shut down even when an
    # unexpected exception escapes the fuzzing loop.
    try:
        while not should_stop:

            seq_collection = [sequences.Sequence()]
            # Only for bfs: If any checkpoint file is available, load state of
            # latest generation. Note that it only makes sense to use checkpoints
            # for the bfs exploration method, since it is the only systemic and
            # exhaustive method.
            min_len = 0
            if fuzzing_mode == 'bfs':
                req_collection = GrammarRequestCollection()
                monitor = Monitor()
                req_collection, seq_collection, fuzzing_requests, monitor, min_len =\
                    saver.load(req_collection, seq_collection, fuzzing_requests, monitor)
                requests.GlobalRequestCollection.Instance(
                )._req_collection = req_collection
                # NOTE(review): this is a module-level function, so `__instance`
                # is NOT name-mangled here. If FuzzingMonitor declares
                # `__instance` in its class body, the real attribute is
                # `_FuzzingMonitor__instance` and this assignment does not
                # replace the singleton -- verify against FuzzingMonitor.
                fuzzing_monitor.FuzzingMonitor.__instance = monitor
            # Repeat external loop only for random walk
            if fuzzing_mode != 'random-walk':
                should_stop = True

            # Initialize fuzzing schedule
            logger.write_to_main(f"Setting fuzzing schemes: {fuzzing_mode}")
            fuzzing_schedule = {length: fuzzing_mode
                                for length in range(min_len, max_len)}

            # print general request-related stats
            logger.print_req_collection_stats(
                fuzzing_requests,
                GrammarRequestCollection().candidate_values_pool)

            for length in range(min_len, max_len):
                # we can set this without locking, since no one else writes (main
                # driver is single-threaded) and every potential worker will just
                # read-access this value.
                generation = length + 1
                fuzzing_mode = fuzzing_schedule[length]

                # extend sequences with new request templates
                seq_collection = extend(seq_collection, fuzzing_requests,
                                        global_lock)
                print(f"{formatting.timestamp()}: Generation: {generation} ")

                logger.write_to_main(
                    f"{formatting.timestamp()}: Generation: {generation} / "
                    f"Sequences Collection Size: {len(seq_collection)} "
                    f"(After {fuzzing_schedule[length]} Extend)")

                # render templates
                try:
                    seq_collection_exhausted = False
                    seq_collection = render(seq_collection, fuzzing_pool, checkers,
                                            generation, global_lock)

                except TimeOutException:
                    logger.write_to_main("Timed out...")
                    timeout_reached = True
                    seq_collection_exhausted = True
                    # Increase fuzzing generation after timeout because the code
                    # that does it would have never been reached. This is done so
                    # the previous generation's test summary is logged correctly.
                    Monitor().current_fuzzing_generation += 1

                except ExhaustSeqCollectionException:
                    logger.write_to_main("Exhausted collection...")
                    seq_collection = []
                    seq_collection_exhausted = True

                logger.write_to_main(
                    f"{formatting.timestamp()}: Generation: {generation} / "
                    f"Sequences Collection Size: {len(seq_collection)} "
                    f"(After {fuzzing_schedule[length]} Render)")

                # saving latest state
                saver.save(GrammarRequestCollection(), seq_collection,
                           fuzzing_requests, Monitor(), generation)

                # Print stats for iteration of the current generation
                logger.print_generation_stats(GrammarRequestCollection(),
                                              Monitor(), global_lock)

                num_total_sequences += len(seq_collection)

                logger.print_request_rendering_stats(
                    GrammarRequestCollection().candidate_values_pool,
                    fuzzing_requests, Monitor(),
                    Monitor().num_fully_rendered_requests(
                        fuzzing_requests.all_requests), generation, global_lock)

                if timeout_reached or seq_collection_exhausted:
                    # Only a timeout ends the outer (random-walk) loop; an
                    # exhausted collection just restarts the pass.
                    if timeout_reached:
                        should_stop = True
                    break
            logger.write_to_main("--\n")
    finally:
        if fuzzing_pool is not None:
            fuzzing_pool.close()
            fuzzing_pool.join()

    return num_total_sequences
    def visualize(self, num_random_samples=0, special_positions=(1112, 900, 398)):
        """Restore the generator and write qualitative visualisation images.

        For each of ``num_random_samples`` random test positions this saves a
        z-space interpolation, a conditional-embedding interpolation and a
        captioned batch. It then saves a captioned batch for every test index in
        ``special_positions`` and finally one image of generated samples next to
        their closest dataset neighbours.

        Args:
            num_random_samples: number of random test positions to visualise.
                The default of 0 reproduces the previous hard-coded
                ``range(0)``, i.e. the random visualisations are skipped.
            special_positions: test-set indices rendered to
                ``special_cap/cap{i}.png`` (``i`` is the position in this
                sequence).

        Raises:
            LookupError: if no generator checkpoint can be restored.
        """
        z_ph = tf.compat.v1.placeholder(tf.float32, [None, self.model.z_dim],
                                        name='z')
        cond_ph = tf.compat.v1.placeholder(tf.float32,
                                           [None] + [self.model.embed_dim],
                                           name='cond')
        gen = self.model.generator(z_ph, cond_ph, is_training=False)

        saver = tf.compat.v1.train.Saver(
            tf.compat.v1.global_variables('g_net'))
        could_load, _ = load(saver, self.sess, self.config.CHECKPOINT_DIR)
        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")
            raise LookupError('Could not load any checkpoints')

        # Also used by the closest-neighbour section below; overwritten by
        # each random sample when num_random_samples > 0.
        dataset_pos = np.random.randint(0, self.dataset.test.num_examples)
        for idx in range(num_random_samples):
            dataset_pos = np.random.randint(0, self.dataset.test.num_examples)

            # Interpolation in z space:
            # ---------------------------------------------------------------------------------------------------------
            # Distinct name so the numpy conditions do not shadow the placeholder.
            _, cond_vals, _, captions = self.dataset.test.next_batch_test(
                1, dataset_pos, 1)
            cond_vals = np.squeeze(cond_vals, axis=0)
            caption = captions[0][0]

            samples = gen_noise_interp_img(self.sess, gen, cond_vals,
                                           self.model.z_dim,
                                           self.model.batch_size)
            save_cap_batch(
                samples, caption,
                '{}/{}_visual/z_interp/z_interp{}.png'.format(
                    self.samples_dir, self.dataset.name, idx))

            # Interpolation in embedding space:
            # ---------------------------------------------------------------------------------------------------------
            _, cond_vals, _, caps = self.dataset.test.next_batch_test(
                2, dataset_pos, 1)
            cond_vals = np.squeeze(cond_vals, axis=0)
            cond1, cond2 = cond_vals[0], cond_vals[1]
            cap1, cap2 = caps[0][0], caps[1][0]

            samples = gen_cond_interp_img(self.sess, gen, cond1, cond2,
                                          self.model.z_dim,
                                          self.model.batch_size)
            save_interp_cap_batch(
                samples, cap1, cap2,
                '{}/{}_visual/cond_interp/cond_interp{}.png'.format(
                    self.samples_dir, self.dataset.name, idx))

            # Generate captioned image
            # ---------------------------------------------------------------------------------------------------------
            _, conditions, _, captions = self.dataset.test.next_batch_test(
                1, dataset_pos, 1)
            conditions = np.squeeze(conditions, axis=0)
            caption = captions[0][0]
            samples = gen_captioned_img(self.sess, gen, conditions,
                                        self.model.z_dim,
                                        self.model.batch_size)

            save_cap_batch(
                samples, caption,
                '{}/{}_visual/cap/cap{}.png'.format(self.samples_dir,
                                                    self.dataset.name, idx))

        for idx, special_pos in enumerate(special_positions):
            print(special_pos)
            # Generate specific image
            # ---------------------------------------------------------------------------------------------------------
            _, conditions, _, captions = self.dataset.test.next_batch_test(
                1, special_pos, 1)
            conditions = np.squeeze(conditions, axis=0)
            caption = captions[0][0]
            samples = gen_captioned_img(self.sess, gen, conditions,
                                        self.model.z_dim,
                                        self.model.batch_size)

            save_cap_batch(
                samples, caption, '{}/{}_visual/special_cap/cap{}.png'.format(
                    self.samples_dir, self.dataset.name, idx))

        # Generate some images and their closest neighbours
        # ---------------------------------------------------------------------------------------------------------
        _, conditions, _, _ = self.dataset.test.next_batch_test(
            self.model.batch_size, dataset_pos, 1)
        # axis=0 (as in every other call site) keeps the batch axis even when
        # batch_size == 1; a bare squeeze would drop it.
        conditions = np.squeeze(conditions, axis=0)
        samples, neighbours = gen_closest_neighbour_img(
            self.sess, gen, conditions, self.model.z_dim,
            self.model.batch_size, self.dataset)
        batch = np.concatenate([samples, neighbours])
        text = 'Generated images (first row) and their closest neighbours (second row)'
        save_cap_batch(
            batch, text,
            '{}/{}_visual/neighb/neighb.png'.format(self.samples_dir,
                                                    self.dataset.name))
示例#16
0
    def train(self):
        """Train the GAN epoch by epoch, resuming from the latest checkpoint if any.

        Each batch gets one discriminator step followed by one generator step,
        both logged to ``self.writer``. The learning rate is halved every 100
        epochs. Every 500 updates a sample grid is written (best-effort) and a
        checkpoint is saved.
        """
        self.define_losses()
        self.define_summaries()

        model = self.model
        sess = self.sess

        # Fixed noise/embeddings used for every sample grid.
        sample_z = np.random.normal(0, 1, (model.sample_num, model.z_dim))
        _, raw_sample_embed, _, captions = self.dataset.test.next_batch_test(
            model.sample_num, 0, 1)
        sample_embed = np.squeeze(raw_sample_embed, axis=0)
        print(sample_embed.shape)

        save_captions(self.cfg.SAMPLE_DIR, captions)

        step = 1
        t0 = time.time()

        restored, restored_step = load(self.saver, sess, self.cfg.CHECKPOINT_DIR)
        if not restored:
            print(" [!] Load failed...")
        else:
            step = restored_step
            print(" [*] Load SUCCESS")

        initialize_uninitialized(sess)

        # One update per training batch; resume at the epoch the step counter falls in.
        steps_per_epoch = self.dataset.train.num_examples // model.batch_size
        first_epoch = step // steps_per_epoch

        for epoch in range(first_epoch, self.cfg.TRAIN.EPOCH):
            lr_value = self.lr * (0.5**(epoch // 100))

            for idx in range(steps_per_epoch):
                images, wrong_images, embed, _, _ = self.dataset.train.next_batch(
                    model.batch_size, 4, embeddings=True, wrong_img=True)
                noise = np.random.normal(0, 1, (model.batch_size, model.z_dim))

                feeds = {
                    self.learning_rate: lr_value,
                    model.inputs: images,
                    model.wrong_inputs: wrong_images,
                    model.embed_inputs: embed,
                    model.z: noise,
                }

                _, d_err, d_summary = sess.run(
                    [self.D_optim, self.D_loss, self.D_merged_summ], feed_dict=feeds)
                self.writer.add_summary(d_summary, step)

                _, g_err, g_summary = sess.run(
                    [self.G_optim, self.G_loss, self.G_merged_summ], feed_dict=feeds)
                self.writer.add_summary(g_summary, step)

                step += 1
                elapsed = time.time() - t0
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                      % (epoch, idx, steps_per_epoch, elapsed, d_err, g_err))

                if np.mod(step, 500) != 0:
                    continue

                # Sampling failures are reported but never stop training; the
                # checkpoint below is written either way.
                try:
                    grid = sess.run(model.sampler,
                                    feed_dict={
                                        model.z_sample: sample_z,
                                        model.embed_sample: sample_embed,
                                    })
                    out_path = '{}train_{:02d}_{:04d}.png'.format(
                        self.cfg.SAMPLE_DIR, epoch, idx)
                    save_images(grid, get_balanced_factorization(grid.shape[0]), out_path)
                except Exception as err:
                    print("Failed to generate sample image")
                    print(type(err))
                    print(err.args)
                    print(err)

                save(self.saver, sess, self.cfg.CHECKPOINT_DIR, step)
示例#17
0
	def train(self):
		"""Jointly train the stage II GAN and the retrieval network.

		Restores the stage II networks (``self.stageii_saver``) and the stage I
		generator (``self.stagei_g_saver``) from their checkpoint directories,
		then for every batch runs one discriminator, one generator and one
		retrieval (R) update. The learning rate fed to ``self.learning_rate`` is
		halved every 100 epochs. Every 1000 updates a sample grid conditioned on
		retrieval sentence embeddings is written (best-effort), a checkpoint is
		saved whenever ``counter % 500 == 2``, and ``self.ret_eval`` runs every
		50 epochs (skipping epoch 0).
		"""
		self.define_losses()
		self.define_summaries()

		sample_z = np.random.normal(0, 1, (self.model.sample_num, self.model.z_dim))
		# sample_embed is only printed below; the sample grids are conditioned on
		# retrieval embeddings of im_feats_test / sent_feats_test instead.
		_, sample_embed, _, captions = self.dataset.test.next_batch_test(self.model.sample_num, 0, 1)
		# Fixed retrieval test features for the periodic sample grids (labels_test is not used here).
		im_feats_test, sent_feats_test, labels_test = self.test_data_loader.get_batch(0,self.cfg.RETRIEVAL.SAMPLE_NUM,\
														image_aug = self.cfg.RETRIEVAL.IMAGE_AUG, phase='test')
		sample_embed = np.squeeze(sample_embed, axis=0)
		print(sample_embed.shape)

		save_captions(self.cfg.SAMPLE_DIR, captions)

		counter = 1
		start_time = time.time()

		# Stage II checkpoint determines the resume step counter.
		could_load, checkpoint_counter = load(self.stageii_saver, self.sess, self.cfg.CHECKPOINT_DIR)
		if could_load:
			counter = checkpoint_counter
			print(" [*] Load SUCCESS: Stage II networks are loaded.")
		else:
			print(" [!] Load failed for stage II networks...")

		# Stage I generator comes from its own checkpoint directory; a failure is only a warning.
		could_load, checkpoint_counter = load(self.stagei_g_saver, self.sess, self.cfg_stage_i.CHECKPOINT_DIR)
		if could_load:
			print(" [*] Load SUCCESS: Stage I generator is loaded")
		else:
			print(" [!] WARNING!!! Failed to load the parameters for stage I generator...")

		# Initialize only variables that neither checkpoint restored.
		initialize_uninitialized(self.sess)

		# Updates per epoch are given by the training data size / batch size
		updates_per_epoch = self.dataset.train.num_examples // self.model.batch_size
		epoch_start = counter // updates_per_epoch

		for epoch in range(epoch_start, self.cfg.TRAIN.EPOCH):
			cen_epoch = epoch // 100

			for idx in range(0, updates_per_epoch):
				# NOTE(review): `embed` is fetched but not fed (the embed_inputs
				# entries below are commented out).
				images, wrong_images, embed, _, _ = self.dataset.train.next_batch(self.model.batch_size, 1,
																				  embeddings=True,
																				  wrong_img=True)
				batch_z = np.random.normal(0, 1, (self.model.batch_size, self.model.z_dim))

				# Retrieval data loader
				# idx < updates_per_epoch, so idx % updates_per_epoch == idx and this
				# reshuffles the retrieval data once at the start of every epoch.
				if idx % updates_per_epoch == 0:
					self.R_loader.shuffle_inds()
				
				im_feats, sent_feats, labels = self.R_loader.get_batch(idx % updates_per_epoch,\
								self.cfg.RETRIEVAL.BATCH_SIZE, image_aug = self.cfg.RETRIEVAL.IMAGE_AUG)

				# One feed dict shared by the D, G and R updates below.
				feed_dict = {
					self.learning_rate: self.lr * (0.5**cen_epoch),
					self.model.inputs: images,
					self.model.wrong_inputs: wrong_images,
					# self.model.embed_inputs: embed,
					# self.model.embed_inputs: self.txt_emb,
					self.model.z: batch_z,
					self.Retrieval.image_placeholder : im_feats, 
					self.Retrieval.sent_placeholder : sent_feats,
					self.Retrieval.label_placeholder : labels
				}

				# Update D network
				_, err_d, summary_str = self.sess.run([self.D_optim, self.D_loss, self.D_merged_summ],
													  feed_dict=feed_dict)
				self.writer.add_summary(summary_str, counter)

				# Update G network
				_, err_g, summary_str = self.sess.run([self.G_optim, self.G_loss, self.G_merged_summ],
													  feed_dict=feed_dict)
				self.writer.add_summary(summary_str, counter)
				
				# Update R network
				_, err_r, summary_str = self.sess.run([self.R_optim, self.R_loss, self.R_loss_summ],
													  feed_dict=feed_dict)
				self.writer.add_summary(summary_str, counter)

				counter += 1
				print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f, r_loss: %.8f"
					  % (epoch, idx, updates_per_epoch,
						 time.time() - start_time, err_d, err_g, err_r))

				# Best-effort sample grid every 1000 updates; failures are printed, not raised.
				if np.mod(counter, 1000) == 0:
					try:
						# pdb.set_trace()
						# NOTE(review): Retrieval.eval() and model.eval(...) are invoked on
						# every sampling step; if they build new graph ops, the graph grows
						# each time -- confirm they are idempotent.
						self.Retrieval.eval()
						sent_emb = self.sess.run(self.Retrieval.sent_embed_tensor,
												feed_dict={
															self.Retrieval.image_placeholder_test: im_feats_test,
															self.Retrieval.sent_placeholder_test: sent_feats_test,
														  })
						self.model.eval(sent_emb)
						samples = self.sess.run(self.model.sampler,
												feed_dict={
															self.model.z_sample: sample_z,
															# self.model.embed_sample: sample_embed,
															self.model.embed_sample: sent_emb,
														  })
						save_images(samples, get_balanced_factorization(samples.shape[0]),
									'{}train_{:02d}_{:04d}.png'.format(self.cfg.SAMPLE_DIR, epoch, idx))
					except Exception as e:
						print("Failed to generate sample image")
						print(type(e))
						print(e.args)
						print(e)

				# Checkpoint the stage II networks whenever counter % 500 == 2.
				if np.mod(counter, 500) == 2:
					save(self.stageii_saver, self.sess, self.cfg.CHECKPOINT_DIR, counter)

			# Periodic retrieval evaluation every 50 epochs (not at epoch 0).
			if np.mod(epoch, 50) == 0 and epoch!=0:
				self.ret_eval(epoch)