Example 1
def compute_inception_score(n):
    # Draw n samples in chunks of 100 from the (global) TF session/op,
    # reshape them to NCHW, rescale, and hand them to the Inception scorer.
    all_samples = []
    for i in range(int(n / 100)):
        all_samples.append(session.run(samples_100))
    all_samples = np.concatenate(all_samples, axis=0)
    all_samples = all_samples.reshape((-1, 3, 32, 32))
    all_samples = scale_value(all_samples, [-1.0, 1.0])
    print(all_samples.shape)
    return get_inception_score(all_samples)
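Both this example and the next rely on a scale_value helper that is not shown here. A minimal sketch, assuming it linearly rescales an array into the given target range (the repository's actual version may instead assume a fixed input range):

import numpy as np

def scale_value(x, target_range):
    """Hypothetical helper: min-max rescale `x` into [lo, hi]."""
    lo, hi = target_range
    x = np.asarray(x, dtype=np.float64)
    x_min, x_max = x.min(), x.max()
    if x_max == x_min:
        return np.full_like(x, lo)
    return lo + (x - x_min) * (hi - lo) / (x_max - x_min)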
Example 2
def compute_metric(generator_model):
    global best_icp
    sample_size = 20000
    # Sample from the generator, rescale to [-1, 1], and move channels first
    # before handing the images to the Inception scorer.
    noise = np.random.normal(size=(sample_size, 100))
    art_images = generator_model.predict(noise)
    art_images = scale_value(art_images, [-1.0, 1.0])
    art_images = np.transpose(art_images, (0, 3, 1, 2))
    (icp_mean, icp_std) = get_inception_score(art_images)
    if icp_mean > best_icp:
        best_icp = icp_mean
    print('Inception score: ', icp_mean)
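By definition, the Inception Score these snippets log is exp(E_x[KL(p(y|x) || p(y))]), usually averaged over several splits of the sample set. The sketch below computes that quantity from pre-computed softmax outputs with plain NumPy; the real get_inception_score implementations additionally run the Inception network to obtain those probabilities:

import numpy as np

def inception_score_from_probs(probs, splits=10):
    """Return (mean, std) of exp(mean_x KL(p(y|x) || p(y))) over `splits` chunks.

    probs: array of shape (N, num_classes) containing softmax outputs.
    """
    scores = []
    for chunk in np.array_split(probs, splits):
        p_y = chunk.mean(axis=0, keepdims=True)  # marginal label distribution
        kl = chunk * (np.log(chunk + 1e-12) - np.log(p_y + 1e-12))
        scores.append(np.exp(kl.sum(axis=1).mean()))
    return float(np.mean(scores)), float(np.std(scores))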
Example 3
def test(target_vars, saver, sess, logger, dataloader):
    X_NOISE = target_vars['X_NOISE']
    X = target_vars['X']
    Y = target_vars['Y']
    LABEL = target_vars['LABEL']
    energy_start = target_vars['energy_start']
    x_mod = target_vars['x_mod']
    x_mod = target_vars['test_x_mod']  # the test-time sampling op supersedes x_mod
    energy_neg = target_vars['energy_neg']

    np.random.seed(1)
    random.seed(1)

    output = [x_mod, energy_start, energy_neg]

    dataloader_iterator = iter(dataloader)
    data_corrupt, data, label = next(dataloader_iterator)
    data_corrupt, data, label = data_corrupt.numpy(), data.numpy(), label.numpy()

    orig_im = try_im = data_corrupt

    if FLAGS.cclass:
        try_im, energy_orig, energy = sess.run(output, {
            X_NOISE: orig_im,
            Y: label[0:1],
            LABEL: label
        })
    else:
        try_im, energy_orig, energy = sess.run(output, {
            X_NOISE: orig_im,
            Y: label[0:1]
        })

    orig_im = rescale_im(orig_im)
    try_im = rescale_im(try_im)
    actual_im = rescale_im(data)

    for i, (im, energy_i, t_im, energy, label_i, actual_im_i) in enumerate(
            zip(orig_im, energy_orig, try_im, energy, label, actual_im)):
        label_i = np.array(label_i)

        # Build a three-panel comparison image: init | sample | ground truth.
        shape = im.shape
        new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
        size = shape[1]
        new_im[:, :size] = im
        new_im[:, size:2 * size] = t_im
        new_im[:, 2 * size:] = actual_im_i

        if FLAGS.cclass:
            label_i = np.where(label_i == 1)[0][0]
            if FLAGS.dataset == 'cifar10':
                log_image(new_im,
                          logger,
                          '{}_{:.4f}_now_{:.4f}_{}'.format(
                              i, energy_i[0], energy[0], cifar10_map[label_i]),
                          step=i)
            else:
                log_image(new_im,
                          logger,
                          '{}_{:.4f}_now_{:.4f}_{}'.format(
                              i, energy_i[0], energy[0], label_i),
                          step=i)
        else:
            log_image(new_im,
                      logger,
                      '{}_{:.4f}_now_{:.4f}'.format(i, energy_i[0], energy[0]),
                      step=i)

    test_ims = list(try_im)
    real_ims = list(actual_im)

    for i in tqdm(range(50000 // FLAGS.batch_size + 1)):
        try:
            data_corrupt, data, label = next(dataloader_iterator)
        except BaseException:
            dataloader_iterator = iter(dataloader)
            data_corrupt, data, label = next(dataloader_iterator)

        data_corrupt, data, label = data_corrupt.numpy(), data.numpy(), label.numpy()

        if FLAGS.cclass:
            try_im, energy_orig, energy = sess.run(output, {
                X_NOISE: data_corrupt,
                Y: label[0:1],
                LABEL: label
            })
        else:
            try_im, energy_orig, energy = sess.run(output, {
                X_NOISE: data_corrupt,
                Y: label[0:1]
            })

        try_im = rescale_im(try_im)
        real_im = rescale_im(data)

        test_ims.extend(list(try_im))
        real_ims.extend(list(real_im))

    score, std = get_inception_score(test_ims)
    print("!!!Inception score of {} with std of {}".format(score, std))
Example 4
def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
    X = target_vars['X']
    Y = target_vars['Y']
    X_NOISE = target_vars['X_NOISE']
    train_op = target_vars['train_op']
    energy_pos = target_vars['energy_pos']
    energy_neg = target_vars['energy_neg']
    loss_energy = target_vars['loss_energy']
    loss_ml = target_vars['loss_ml']
    loss_total = target_vars['total_loss']
    gvs = target_vars['gvs']
    x_grad = target_vars['x_grad']
    x_grad_first = target_vars['x_grad_first']
    x_off = target_vars['x_off']
    temp = target_vars['temp']
    x_mod = target_vars['x_mod']
    LABEL = target_vars['LABEL']
    LABEL_POS = target_vars['LABEL_POS']
    weights = target_vars['weights']
    test_x_mod = target_vars['test_x_mod']
    eps = target_vars['eps_begin']
    label_ent = target_vars['label_ent']

    if FLAGS.use_attention:
        gamma = weights[0]['atten']['gamma']
    else:
        gamma = tf.zeros(1)

    val_output = [test_x_mod]

    gvs_dict = dict(gvs)

    log_output = [
        train_op, energy_pos, energy_neg, eps, loss_energy, loss_ml,
        loss_total, x_grad, x_off, x_mod, gamma, x_grad_first, label_ent,
        *gvs_dict.keys()
    ]
    output = [train_op, x_mod]

    replay_buffer = ReplayBuffer(10000)
    itr = resume_iter
    x_mod = None
    gd_steps = 1

    dataloader_iterator = iter(dataloader)
    best_inception = 0.0

    for epoch in range(FLAGS.epoch_num):
        print("Training epoch:%d" % epoch)
        for data_corrupt, data, label in dataloader:
            data_corrupt = data_corrupt.numpy()
            data_corrupt_init = data_corrupt.copy()

            data = data.numpy()
            label = label.numpy()

            label_init = label.copy()

            if FLAGS.mixup:
                idx = np.random.permutation(data.shape[0])
                lam = np.random.beta(1, 1, size=(data.shape[0], 1, 1, 1))
                data = data * lam + data[idx] * (1 - lam)

            if FLAGS.replay_batch and (x_mod is not None):
                replay_buffer.add(compress_x_mod(x_mod))

                if len(replay_buffer) > FLAGS.batch_size:
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_batch = decompress_x_mod(replay_batch)
                    replay_mask = (np.random.uniform(0, FLAGS.rescale,
                                                     FLAGS.batch_size) > 0.05)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

            if FLAGS.pcd:
                if x_mod is not None:
                    data_corrupt = x_mod

            feed_dict = {X_NOISE: data_corrupt, X: data, Y: label}

            if FLAGS.cclass:
                feed_dict[LABEL] = label
                feed_dict[LABEL_POS] = label_init

            if itr % FLAGS.log_interval == 0:
                _, e_pos, e_neg, eps, loss_e, loss_ml, loss_total, x_grad, x_off, x_mod, gamma, x_grad_first, label_ent, * \
                    grads = sess.run(log_output, feed_dict)

                kvs = {}
                kvs['e_pos'] = e_pos.mean()
                kvs['e_pos_std'] = e_pos.std()
                kvs['e_neg'] = e_neg.mean()
                kvs['e_diff'] = kvs['e_pos'] - kvs['e_neg']
                kvs['e_neg_std'] = e_neg.std()
                kvs['temp'] = temp
                kvs['loss_e'] = loss_e.mean()
                kvs['eps'] = eps.mean()
                kvs['label_ent'] = label_ent
                kvs['loss_ml'] = loss_ml.mean()
                kvs['loss_total'] = loss_total.mean()
                kvs['x_grad'] = np.abs(x_grad).mean()
                kvs['x_grad_first'] = np.abs(x_grad_first).mean()
                kvs['x_off'] = x_off.mean()
                kvs['iter'] = itr
                kvs['gamma'] = gamma

                for v, k in zip(grads, [v.name for v in gvs_dict.values()]):
                    kvs[k] = np.abs(v).max()

                string = "Obtained a total of "
                for key, value in kvs.items():
                    string += "{}: {}, ".format(key, value)

                if hvd.rank() == 0:
                    print(string)
                    logger.writekvs(kvs)
            else:
                _, x_mod = sess.run(output, feed_dict)

            if itr % FLAGS.save_interval == 0 and hvd.rank() == 0:
                saver.save(
                    sess,
                    osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(itr)))

            if itr % FLAGS.test_interval == 0 and hvd.rank() == 0 and FLAGS.dataset != '2d':
                try_im = x_mod
                orig_im = data_corrupt.squeeze()
                actual_im = rescale_im(data)

                orig_im = rescale_im(orig_im)
                try_im = rescale_im(try_im).squeeze()

                for i, (im, t_im, actual_im_i) in enumerate(
                        zip(orig_im[:20], try_im[:20], actual_im)):
                    shape = orig_im.shape[1:]
                    new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
                    size = shape[1]
                    new_im[:, :size] = im
                    new_im[:, size:2 * size] = t_im
                    new_im[:, 2 * size:] = actual_im_i

                    log_image(new_im,
                              logger,
                              'train_gen_{}'.format(itr),
                              step=i)

                test_im = x_mod

                try:
                    data_corrupt, data, label = next(dataloader_iterator)
                except BaseException:
                    dataloader_iterator = iter(dataloader)
                    data_corrupt, data, label = next(dataloader_iterator)

                data_corrupt = data_corrupt.numpy()

                if FLAGS.replay_batch and (
                        x_mod is not None) and len(replay_buffer) > 0:
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_batch = decompress_x_mod(replay_batch)
                    replay_mask = (np.random.uniform(0, 1, (FLAGS.batch_size))
                                   > 0.05)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

                if FLAGS.dataset in ('cifar10', 'imagenet', 'imagenetfull'):
                    n = 128

                    if FLAGS.dataset == "imagenetfull":
                        n = 32

                    if len(replay_buffer) > n:
                        data_corrupt = decompress_x_mod(
                            replay_buffer.sample(n))
                    elif FLAGS.dataset == 'imagenetfull':
                        data_corrupt = np.random.uniform(
                            0, FLAGS.rescale, (n, 128, 128, 3))
                    else:
                        data_corrupt = np.random.uniform(
                            0, FLAGS.rescale, (n, 32, 32, 3))

                    if FLAGS.dataset == 'cifar10':
                        label = np.eye(10)[np.random.randint(0, 10, (n))]
                    else:
                        label = np.eye(1000)[np.random.randint(0, 1000, (n))]

                feed_dict[X_NOISE] = data_corrupt

                feed_dict[X] = data

                if FLAGS.cclass:
                    feed_dict[LABEL] = label

                test_x_mod = sess.run(val_output, feed_dict)

                try_im = test_x_mod
                orig_im = data_corrupt.squeeze()
                actual_im = rescale_im(data.numpy())

                orig_im = rescale_im(orig_im)
                try_im = rescale_im(try_im).squeeze()

                for i, (im, t_im, actual_im_i) in enumerate(
                        zip(orig_im[:20], try_im[:20], actual_im)):

                    shape = orig_im.shape[1:]
                    new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
                    size = shape[1]
                    new_im[:, :size] = im
                    new_im[:, size:2 * size] = t_im
                    new_im[:, 2 * size:] = actual_im_i
                    log_image(new_im, logger, 'val_gen_{}'.format(itr), step=i)

                score, std = get_inception_score(list(try_im), splits=1)
                print("///Inception score of {} with std of {}".format(
                    score, std))
                kvs = {}
                kvs['inception_score'] = score
                kvs['inception_score_std'] = std
                logger.writekvs(kvs)

                if score > best_inception:
                    best_inception = score
                    saver.save(sess,
                               osp.join(FLAGS.logdir, FLAGS.exp, 'model_best'))

            if itr > 60000 and FLAGS.dataset == "mnist":
                assert False  # hard stop: end MNIST training after 60k iterations
            itr += 1
            print("Training iteration:%d" % itr)

    saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(itr)))
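The replay logic above uses ReplayBuffer, compress_x_mod and decompress_x_mod, none of which are defined in the snippet. A minimal sketch, assuming the buffer is a fixed-size store of uint8-compressed MCMC samples with uniform sampling; the classes in the source repository may differ in detail:

import random
import numpy as np

def compress_x_mod(x_mod):
    # Assumption: samples live in [0, 1] and are stored as uint8 to save memory.
    return (np.clip(x_mod, 0, 1) * 255).astype(np.uint8)

def decompress_x_mod(x_mod):
    # Assumed inverse: back to float in [0, 1], with dequantization noise.
    return x_mod.astype(np.float32) / 255.0 + np.random.uniform(
        0, 1.0 / 256.0, x_mod.shape)

class ReplayBuffer(object):
    """Fixed-size buffer of past negative samples, overwritten FIFO-style."""

    def __init__(self, size):
        self._storage = []
        self._maxsize = size
        self._next_idx = 0

    def __len__(self):
        return len(self._storage)

    def add(self, samples):
        for s in samples:
            if self._next_idx >= len(self._storage):
                self._storage.append(s)
            else:
                self._storage[self._next_idx] = s
            self._next_idx = (self._next_idx + 1) % self._maxsize

    def sample(self, batch_size):
        idxs = [random.randrange(len(self._storage)) for _ in range(batch_size)]
        return np.stack([self._storage[i] for i in idxs])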
Example 5
def compute_inception(sess, target_vars):
    X_START = target_vars['X_START']
    Y_GT = target_vars['Y_GT']
    X_finals = target_vars['X_finals']
    NOISE_SCALE = target_vars['NOISE_SCALE']
    energy_noise = target_vars['energy_noise']

    size = FLAGS.im_number
    num_steps = size // 1000

    images = []
    test_ims = []
    test_images = []


    if FLAGS.dataset == "cifar10":
        test_dataset = Cifar10(full=True, noise=False)
    elif FLAGS.dataset == "celeba":
        dataset = CelebA()
    elif FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull":
        test_dataset = Imagenet(train=False)

    if FLAGS.dataset != "imagenetfull":
        test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=4, shuffle=True, drop_last=False)
    else:
        test_dataloader = TFImagenetLoader('test', FLAGS.batch_size, 0, 1)

    for data_corrupt, data, label_gt in tqdm(test_dataloader):
        data = data.numpy()
        test_ims.extend(list(rescale_im(data)))

        if FLAGS.dataset == "imagenetfull" and len(test_ims) > 60000:
            test_ims = test_ims[:60000]
            break


    # n = min(len(images), len(test_ims))
    print(len(test_ims))
    # fid = get_fid_score(test_ims[:30000], test_ims[-30000:])
    # print("Base FID of score {}".format(fid))

    if FLAGS.dataset == "cifar10":
        classes = 10
    else:
        classes = 1000

    if FLAGS.dataset == "imagenetfull":
        n = 128
    else:
        n = 32

    for j in range(num_steps):
        itr = int(1000 / 500 * FLAGS.repeat_scale)
        data_buffer = InceptionReplayBuffer(1000)
        curr_index = 0

        identity = np.eye(classes)

        test_steps = range(300, itr, 20)

        for i in tqdm(range(itr)):
            model_index = curr_index % len(X_finals)
            x_final = X_finals[model_index]

            noise_scale = [1]
            if len(data_buffer) < 1000:
                x_init = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
                label = np.random.randint(0, classes, (FLAGS.batch_size))
                label = identity[label]
                x_new = sess.run([x_final], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})[0]
                data_buffer.add(x_new, label)
            else:
                (x_init, label), idx = data_buffer.sample(FLAGS.batch_size)
                keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99)
                label_keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9)
                label_corrupt = np.random.randint(0, classes, (FLAGS.batch_size))
                label_corrupt = identity[label_corrupt]
                x_init_corrupt = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))

                if i < itr - FLAGS.nomix:
                    x_init[keep_mask] = x_init_corrupt[keep_mask]
                    label[label_keep_mask] = label_corrupt[label_keep_mask]
                # else:
                #     noise_scale = [0.7]

                x_new, e_noise = sess.run([x_final, energy_noise], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})
                data_buffer.set_elms(idx, x_new, label)

            curr_index += 1

        ims = np.array(data_buffer._storage[:1000])
        ims = rescale_im(ims)
        test_images.extend(list(ims))

    saveim = osp.join(FLAGS.logdir, FLAGS.exp, "test{}.png".format(FLAGS.resume_iter))
    row = 15
    col = 20
    ims = ims[:row * col]
    if FLAGS.dataset != "imagenetfull":
        im_panel = ims.reshape((row, col, 32, 32, 3)).transpose((0, 2, 1, 3, 4)).reshape((32*row, 32*col, 3))
    else:
        im_panel = ims.reshape((row, col, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((128*row, 128*col, 3))
    imsave(saveim, im_panel)

    splits = max(1, len(test_images) // 5000)
    score, std = get_inception_score(test_images, splits=splits)
    print("Inception score of {} with std of {}".format(score, std))

    # FID score
    # n = min(len(images), len(test_ims))
    fid = get_fid_score(test_images, test_ims)
    print("FID of score {}".format(fid))
Example 6
def run(config):
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # Optionally, get the configuration from the state dict. This allows for
    # recovery of the config provided only a state dict and experiment name,
    # and can be convenient for writing less verbose sample shell scripts.
    if config['config_from_name']:
        utils.load_weights(None,
                           None,
                           state_dict,
                           config['weights_root'],
                           config['experiment_name'],
                           config['load_weights'],
                           None,
                           strict=False,
                           load_optim=False)
        # Ignore items which we might want to overwrite from the command line
        for item in state_dict['config']:
            if item not in [
                    'z_var', 'base_root', 'batch_size', 'G_batch_size',
                    'use_ema', 'G_eval_mode'
            ]:
                config[item] = state_dict['config'][item]

    # update config (see train.py for explanation)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['n_channels'] = utils.nchannels_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    config = utils.update_config_roots(config)
    config['skip_init'] = True
    config['no_optim'] = True
    device = 'cuda'

    # Seed RNG
    # utils.seed_rng(config['seed'])

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    G = model.Generator(**config).cuda()
    utils.count_parameters(G)

    # In some cases we need to load D
    # NOTE: the leading `True or` makes this condition always true, so D is
    # always constructed regardless of which error metrics were requested.
    if True or config['get_test_error'] or config['get_train_error'] or config[
            'get_self_error'] or config['get_generator_error']:
        disc_config = config.copy()
        if config['mh_csc_loss'] or config['mh_loss']:
            disc_config['output_dim'] = disc_config['n_classes'] + 1
        D = model.Discriminator(**disc_config).to(device)

        def get_n_correct_from_D(x, y):
            """Get the number of correct "classifications" from D.

            y: the correct labels

            In the case of projection discrimination we have to pass in all the
            labels as conditionings to get the class-specific affinity.
            """
            x = x.to(device)
            if config['model'] == 'BigGAN':  # projection discrimination case
                if not config['get_self_error']:
                    y = y.to(device)
                yhat = D(x, y)
                for i in range(1, config['n_classes']):
                    yhat_ = D(x, ((y + i) % config['n_classes']))
                    yhat = torch.cat([yhat, yhat_], 1)
                preds_ = yhat.data.max(1)[1].cpu()
                return preds_.eq(0).cpu().sum()
            else:  # the mh gan case
                if not config['get_self_error']:
                    y = y.to(device)
                yhat = D(x)
                preds_ = yhat[:, :config['n_classes']].data.max(1)[1]
                return preds_.eq(y.data).cpu().sum()

    # Load weights
    print('Loading weights...')
    # Here is where we deal with the ema--load ema weights or load normal weights
    utils.load_weights(G if not (config['use_ema']) else None,
                       D,
                       state_dict,
                       config['weights_root'],
                       experiment_name,
                       config['load_weights'],
                       G if config['ema'] and config['use_ema'] else None,
                       strict=False,
                       load_optim=False)
    # Update batch size setting used for G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'],
                               z_var=config['z_var'])

    if config['G_eval_mode']:
        print('Putting G in eval mode..')
        G.eval()
    else:
        print('G is in %s mode...' % ('training' if G.training else 'eval'))

    sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
    brief_expt_name = config['experiment_name'][-30:]

    # load results dict always
    HIST_FNAME = 'scoring_hist.npy'

    def load_or_make_hist(d):
        """Load the scoring history file in directory `d`, or create a new one."""
        if not os.path.isdir(d):
            raise Exception('%s is not a valid directory' % d)
        f = os.path.join(d, HIST_FNAME)
        if os.path.isfile(f):
            return np.load(f, allow_pickle=True).item()
        else:
            return defaultdict(dict)

    hist_dir = os.path.join(config['weights_root'], config['experiment_name'])
    hist = load_or_make_hist(hist_dir)

    if config['get_test_error'] or config['get_train_error']:
        loaders = utils.get_data_loaders(
            **{
                **config, 'batch_size': config['batch_size'],
                'start_itr': state_dict['itr'],
                'use_test_set': config['get_test_error']
            })
        acc_type = 'Test' if config['get_test_error'] else 'Train'

        pbar = tqdm(loaders[0])
        loader_total = len(loaders[0]) * config['batch_size']
        sample_todo = min(config['sample_num_error'], loader_total)
        print('Getting %s error across %i examples' % (acc_type, sample_todo))
        correct = 0
        total = 0

        with torch.no_grad():
            for i, (x, y) in enumerate(pbar):
                correct += get_n_correct_from_D(x, y)
                total += config['batch_size']
                if loader_total > total and total >= config['sample_num_error']:
                    print('Quitting early...')
                    break

        accuracy = float(correct) / float(total)
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']][acc_type] = accuracy
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

        print('[%s][%06d] %s accuracy: %f.' %
              (brief_expt_name, state_dict['itr'], acc_type, accuracy * 100))

    if config['get_self_error']:
        n_used_imgs = config['sample_num_error']
        correct = 0
        imageSize = config['resolution']
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                      (brief_expt_name, state_dict['itr'])):
            with torch.no_grad():
                images, y = sample()
                correct += get_n_correct_from_D(images, y)

        accuracy = float(correct) / float(n_used_imgs)
        print('[%s][%06d] %s accuracy: %f.' %
              (brief_expt_name, state_dict['itr'], 'Self', accuracy * 100))
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']]['Self'] = accuracy
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

    if config['get_generator_error']:

        if config['dataset'] == 'C10':
            from classification.models.densenet import DenseNet121
            from torchvision import transforms
            compnet = DenseNet121()
            compnet = torch.nn.DataParallel(compnet)
            #checkpoint = torch.load(os.path.join('/scratch0/ilya/locDoc/classifiers/densenet121','ckpt_47.t7'))
            checkpoint = torch.load(
                os.path.join(
                    '/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/cifar/densenet121',
                    'ckpt_47.t7'))
            compnet.load_state_dict(checkpoint['net'])
            compnet = compnet.to(device)
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])
        elif config['dataset'] == 'C100':
            from classification.models.densenet import DenseNet121
            from torchvision import transforms
            compnet = DenseNet121(num_classes=100)
            compnet = torch.nn.DataParallel(compnet)
            checkpoint = torch.load(
                os.path.join(
                    '/scratch0/ilya/locDoc/classifiers/cifar100/densenet121',
                    'ckpt.copy.t7'))
            #checkpoint = torch.load(os.path.join('/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/cifar100/densenet121','ckpt.copy.t7'))
            compnet.load_state_dict(checkpoint['net'])
            compnet = compnet.to(device)
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.507, 0.487, 0.441),
                                     (0.267, 0.256, 0.276)),
            ])
        elif config['dataset'] == 'STL48':
            from classification.models.wideresnet import WideResNet48
            from torchvision import transforms
            checkpoint = torch.load(
                os.path.join(
                    '/fs/vulcan-scratch/ilyak/locDoc/experiments/classifiers/stl/mixmatch_48',
                    'model_best.pth.tar'))
            compnet = WideResNet48(num_classes=10)
            compnet = compnet.to(device)
            for param in compnet.parameters():
                param.detach_()
            compnet.load_state_dict(checkpoint['ema_state_dict'])
            compnet.eval()
            minimal_trans = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])
        else:
            raise ValueError('Dataset %s has no comparison network.' %
                             config['dataset'])

        n_used_imgs = 10000
        correct = 0
        mean_label = np.zeros(config['n_classes'])
        imageSize = config['resolution']
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                      (brief_expt_name, state_dict['itr'])):
            with torch.no_grad():
                images, y = sample()
                fake = images.data.cpu().numpy()
                fake = np.floor((fake + 1) * 255 / 2.0).astype(np.uint8)
                fake_input = np.zeros(fake.shape)
                for bi in range(fake.shape[0]):
                    fake_input[bi] = minimal_trans(np.moveaxis(
                        fake[bi], 0, -1))
                images.data.copy_(torch.from_numpy(fake_input))
                lab = compnet(images).max(1)[1]
                mean_label += np.bincount(lab.data.cpu(),
                                          minlength=config['n_classes'])
                correct += int((lab == y).sum().cpu())

        accuracy = float(correct) / float(n_used_imgs)
        mean_label_normalized = mean_label / float(n_used_imgs)

        print(
            '[%s][%06d] %s accuracy: %f.' %
            (brief_expt_name, state_dict['itr'], 'Generator', accuracy * 100))
        hist = load_or_make_hist(hist_dir)
        hist[state_dict['itr']]['Generator'] = accuracy
        hist[state_dict['itr']]['Mean_Label'] = mean_label_normalized
        np.save(os.path.join(hist_dir, HIST_FNAME), hist)

    if config['accumulate_stats']:
        print('Accumulating standing stats across %d accumulations...' %
              config['num_standing_accumulations'])
        utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    # Sample a number of images and save them to an NPZ, for use with TF-Inception
    if config['sample_npz']:
        # Lists to hold images and labels for images
        x, y = [], []
        print('Sampling %d images and saving them to npz...' %
              config['sample_num_npz'])
        for i in trange(
                int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))):
            with torch.no_grad():
                images, labels = sample()
            x += [np.uint8(255 * (images.cpu().numpy() + 1) / 2.)]
            y += [labels.cpu().numpy()]
        x = np.concatenate(x, 0)[:config['sample_num_npz']]
        y = np.concatenate(y, 0)[:config['sample_num_npz']]
        print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
        npz_filename = '%s/%s/samples.npz' % (config['samples_root'],
                                              experiment_name)
        print('Saving npz to %s...' % npz_filename)
        np.savez(npz_filename, **{'x': x, 'y': y})

    if config['official_FID']:
        f = np.load(config['dataset_is_fid'])
        # this is for using the downloaded one from
        # https://github.com/bioinf-jku/TTUR
        #mdata, sdata = f['mu'][:], f['sigma'][:]

        # this one is for my format files
        mdata, sdata = f['mfid'], f['sfid']

    # Sample a number of images and stick them in memory, for use with TF-Inception official_IS and official_FID
    data_gen_necessary = False
    if config['sample_np_mem']:
        is_saved = int('IS' in hist[state_dict['itr']])
        is_todo = int(config['official_IS'])
        fid_saved = int('FID' in hist[state_dict['itr']])
        fid_todo = int(config['official_FID'])
        data_gen_necessary = config['overwrite'] or (is_todo > is_saved) or (
            fid_todo > fid_saved)
    if config['sample_np_mem'] and data_gen_necessary:
        n_used_imgs = 50000
        imageSize = config['resolution']
        x = np.empty((n_used_imgs, imageSize, imageSize, 3), dtype=np.uint8)
        for l in tqdm(range(n_used_imgs // G_batch_size),
                      desc='Generating [%s][%06d]' %
                      (brief_expt_name, state_dict['itr'])):
            start = l * G_batch_size
            end = start + G_batch_size

            with torch.no_grad():
                images, labels = sample()
            fake = np.uint8(255 * (images.cpu().numpy() + 1) / 2.)
            x[start:end] = np.moveaxis(fake, 1, -1)
            #y += [labels.cpu().numpy()]

    if config['official_IS']:
        if (not ('IS' in hist[state_dict['itr']])) or config['overwrite']:
            mis, sis = iscore.get_inception_score(x)
            print('[%s][%06d] IS mu: %f. IS sigma: %f.' %
                  (brief_expt_name, state_dict['itr'], mis, sis))
            hist = load_or_make_hist(hist_dir)
            hist[state_dict['itr']]['IS'] = [mis, sis]
            np.save(os.path.join(hist_dir, HIST_FNAME), hist)
        else:
            mis, sis = hist[state_dict['itr']]['IS']
            print(
                '[%s][%06d] Already done (skipping...): IS mu: %f. IS sigma: %f.'
                % (brief_expt_name, state_dict['itr'], mis, sis))

    if config['official_FID']:
        import tensorflow as tf

        def fid_ms_for_imgs(images, mem_fraction=0.5):
            gpu_options = tf.GPUOptions(
                per_process_gpu_memory_fraction=mem_fraction)
            inception_path = fid.check_or_download_inception(None)
            fid.create_inception_graph(
                inception_path)  # load the graph into the current TF graph
            with tf.Session(config=tf.ConfigProto(
                    gpu_options=gpu_options)) as sess:
                sess.run(tf.global_variables_initializer())
                mu_gen, sigma_gen = fid.calculate_activation_statistics(
                    images, sess, batch_size=100)
            return mu_gen, sigma_gen

        if (not ('FID' in hist[state_dict['itr']])) or config['overwrite']:
            m1, s1 = fid_ms_for_imgs(x)
            fid_value = fid.calculate_frechet_distance(m1, s1, mdata, sdata)
            print('[%s][%06d] FID: %f' %
                  (brief_expt_name, state_dict['itr'], fid_value))
            hist = load_or_make_hist(hist_dir)
            hist[state_dict['itr']]['FID'] = fid_value
            np.save(os.path.join(hist_dir, HIST_FNAME), hist)
        else:
            fid_value = hist[state_dict['itr']]['FID']
            print('[%s][%06d] Already done (skipping...): FID: %f' %
                  (brief_expt_name, state_dict['itr'], fid_value))

    # Prepare sample sheets
    if config['sample_sheets']:
        print('Preparing conditional sample sheets...')
        folder_number = config['sample_sheet_folder_num']
        if folder_number == -1:
            folder_number = config['load_weights']
        utils.sample_sheet(
            G,
            classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
            num_classes=config['n_classes'],
            samples_per_class=10,
            parallel=config['parallel'],
            samples_root=config['samples_root'],
            experiment_name=experiment_name,
            folder_number=folder_number,
            z_=z_,
        )
    # Sample interp sheets
    if config['sample_interps']:
        print('Preparing interp sheets...')
        folder_number = config['sample_sheet_folder_num']
        if folder_number == -1:
            folder_number = config['load_weights']
        for fix_z, fix_y in zip([False, False, True], [False, True, False]):
            utils.interp_sheet(G,
                               num_per_sheet=16,
                               num_midpoints=8,
                               num_classes=config['n_classes'],
                               parallel=config['parallel'],
                               samples_root=config['samples_root'],
                               experiment_name=experiment_name,
                               folder_number=int(folder_number),
                               sheet_number=0,
                               fix_z=fix_z,
                               fix_y=fix_y,
                               device='cuda')
    # Sample random sheet
    if config['sample_random']:
        print('Preparing random sample sheet...')
        images, labels = sample()
        torchvision.utils.save_image(
            images.float(),
            '%s/%s/%s.jpg' %
            (config['samples_root'], experiment_name, config['load_weights']),
            nrow=int(G_batch_size**0.5),
            normalize=True)

    # Prepare a simple function get metrics that we use for trunc curves
    def get_metrics():
        # Get Inception Score and FID
        get_inception_metrics = inception_utils.prepare_inception_metrics(
            config['dataset'], config['parallel'], config['no_fid'])
        sample = functools.partial(utils.sample,
                                   G=G,
                                   z_=z_,
                                   y_=y_,
                                   config=config)
        IS_mean, IS_std, FID = get_inception_metrics(
            sample,
            config['num_inception_images'],
            num_splits=10,
            prints=False)
        # Prepare output string
        outstring = 'Using %s weights ' % ('ema'
                                           if config['use_ema'] else 'non-ema')
        outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else
                                       'training')
        outstring += 'with noise variance %3.3f, ' % z_.var
        outstring += 'over %d images, ' % config['num_inception_images']
        if config['accumulate_stats'] or not config['G_eval_mode']:
            outstring += 'with batch size %d, ' % G_batch_size
        if config['accumulate_stats']:
            outstring += 'using %d standing stat accumulations, ' % config[
                'num_standing_accumulations']
        outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (
            state_dict['itr'], IS_mean, IS_std, FID)
        print(outstring)

    if config['sample_inception_metrics']:
        print('Calculating Inception metrics...')
        get_metrics()

    # Sample truncation curve stuff. This is basically the same as the inception metrics code
    if config['sample_trunc_curves']:
        start, step, end = [
            float(item) for item in config['sample_trunc_curves'].split('_')
        ]
        print(
            'Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...'
            % (start, step, end))
        for var in np.arange(start, end + step, step):
            z_.var = var
            # Optionally comment this out if you want to run with standing stats
            # accumulated at one z variance setting
            if config['accumulate_stats']:
                utils.accumulate_standing_stats(
                    G, z_, y_, config['n_classes'],
                    config['num_standing_accumulations'])
            get_metrics()
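The truncation-curve branch parses sample_trunc_curves as a 'start_step_end' string and sweeps the latent variance over that range. A tiny illustration of the parsing and the resulting sweep values (the spec string here is only an example):

import numpy as np

spec = '0.2_0.2_1.0'   # illustrative "start_step_end" value for sample_trunc_curves
start, step, end = [float(item) for item in spec.split('_')]
variances = np.arange(start, end + step, step)
print(variances)       # the z variances that the loop above would iterate over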
Example 7
def train(target_vars, saver, sess, logger, dataloaders, test_dataloaders,
          resume_iter, logdir):
    X = target_vars['X']
    Y = target_vars['Y']
    X_NOISE = target_vars['X_NOISE']
    train_op = target_vars['train_op']
    energy_pos = target_vars['energy_pos']
    energy_neg = target_vars['energy_neg']
    loss_energy = target_vars['loss_energy']
    loss_ml = target_vars['loss_ml']
    loss_total = target_vars['total_loss']
    gvs = target_vars['gvs']
    x_grad = target_vars['x_grad']
    x_grad_first = target_vars['x_grad_first']
    x_off = target_vars['x_off']
    temp = target_vars['temp']
    x_mod = target_vars['x_mod']
    LABEL = target_vars['LABEL']
    LABEL_POS = target_vars['LABEL_POS']
    weights = target_vars['weights']
    test_x_mod = target_vars['test_x_mod']
    eps = target_vars['eps_begin']
    label_ent = target_vars['label_ent']

    set_seed(0)
    np.random.seed(0)
    random.seed(0)

    if FLAGS.use_attention:
        gamma = weights[0]['atten']['gamma']
    else:
        gamma = tf.zeros(1)

    val_output = [test_x_mod]

    gvs_dict = dict(gvs)

    log_output = [
        train_op, energy_pos, energy_neg, eps, loss_energy, loss_ml,
        loss_total, x_grad, x_off, x_mod, gamma, x_grad_first, label_ent,
        *gvs_dict.keys()
    ]
    output = [train_op, x_mod]

    replay_buffer = ReplayBuffer(10000)
    itr = resume_iter
    x_mod = None
    gd_steps = 1

    err_message = 'Total number of epochs should be divisible by the number of CL tasks.'
    assert FLAGS.epoch_num % FLAGS.num_tasks == 0, err_message
    epochs_per_task = FLAGS.epoch_num // FLAGS.num_tasks // FLAGS.num_cycles

    for task_index, dataloader in enumerate(dataloaders):
        dataloader_iterator = iter(dataloader)
        best_inception = 0.0

        for epoch in range(1, epochs_per_task + 1):
            for data_corrupt, data, label in dataloader:
                print('Iter: {}; Epoch: {}/{}; Task: {}/{}'.format(
                    itr, epoch + (task_index * epochs_per_task),
                    FLAGS.epoch_num, task_index + 1, FLAGS.num_tasks))
                data_corrupt = data_corrupt.numpy()
                data_corrupt_init = data_corrupt.copy()

                data = data.numpy()
                label = label.numpy()

                label_init = label.copy()

                if FLAGS.mixup:
                    idx = np.random.permutation(data.shape[0])
                    lam = np.random.beta(1, 1, size=(data.shape[0], 1, 1, 1))
                    data = data * lam + data[idx] * (1 - lam)

                if FLAGS.replay_batch and (x_mod is not None):
                    replay_buffer.add(compress_x_mod(x_mod))

                    if len(replay_buffer) > FLAGS.batch_size:
                        replay_batch = replay_buffer.sample(FLAGS.batch_size)
                        replay_batch = decompress_x_mod(replay_batch)
                        replay_mask = (np.random.uniform(
                            0, FLAGS.rescale, FLAGS.batch_size) > 0.05)
                        data_corrupt[replay_mask] = replay_batch[replay_mask]

                if FLAGS.pcd:
                    if x_mod is not None:
                        data_corrupt = x_mod

                feed_dict = {X_NOISE: data_corrupt, X: data, Y: label}

                if FLAGS.cclass:
                    feed_dict[LABEL] = label
                    feed_dict[LABEL_POS] = label_init

                if itr % FLAGS.log_interval == 0:
                    _, e_pos, e_neg, eps, loss_e, loss_ml, loss_total, x_grad, x_off, x_mod, gamma, x_grad_first, label_ent, * \
                        grads = sess.run(log_output, feed_dict)

                    kvs = {}
                    kvs['e_pos'] = e_pos.mean()
                    kvs['e_pos_std'] = e_pos.std()
                    kvs['e_neg'] = e_neg.mean()
                    kvs['e_diff'] = kvs['e_pos'] - kvs['e_neg']
                    kvs['e_neg_std'] = e_neg.std()
                    kvs['temp'] = temp
                    kvs['loss_e'] = loss_e.mean()
                    kvs['eps'] = eps.mean()
                    kvs['label_ent'] = label_ent
                    kvs['loss_ml'] = loss_ml.mean()
                    kvs['loss_total'] = loss_total.mean()
                    kvs['x_grad'] = np.abs(x_grad).mean()
                    kvs['x_grad_first'] = np.abs(x_grad_first).mean()
                    kvs['x_off'] = x_off.mean()
                    kvs['iter'] = itr
                    kvs['gamma'] = gamma

                    for v, k in zip(grads,
                                    [v.name for v in gvs_dict.values()]):
                        kvs[k] = np.abs(v).max()

                    string = "Obtained a total of "
                    for key, value in kvs.items():
                        string += "{}: {}, ".format(key, value)

                    if hvd.rank() == 0:
                        print(string)
                        logger.writekvs(kvs)
                        for key, value in kvs.items():
                            neptune.log_metric(key, x=itr, y=value)

                else:
                    _, x_mod = sess.run(output, feed_dict)

                if itr % FLAGS.save_interval == 0 and hvd.rank() == 0:
                    saver.save(
                        sess,
                        osp.join(FLAGS.logdir, FLAGS.exp,
                                 'model_{}'.format(itr)))

                if itr % FLAGS.test_interval == 0 and hvd.rank() == 0 and FLAGS.dataset != '2d':
                    if FLAGS.dataset == 'cifar10':
                        cifar10_map = {
                            0: 'airplane',
                            1: 'automobile',
                            2: 'bird',
                            3: 'cat',
                            4: 'deer',
                            5: 'dog',
                            6: 'frog',
                            7: 'horse',
                            8: 'ship',
                            9: 'truck'
                        }

                        imgs = data
                        labels = np.argmax(label, axis=1)
                        for idx, img in enumerate(imgs[:20, :, :, :]):
                            neptune.log_image(
                                'input_images',
                                rescale_im(imgs[idx]),
                                description=str(int(labels[idx])) + ': ' +
                                cifar10_map[int(labels[idx])])

                    if FLAGS.evaluate:
                        print('Test.')
                        train_acc = test_accuracy(target_vars, saver, sess,
                                                  logger, test_dataloaders[0])
                        test_acc = test_accuracy(target_vars, saver, sess,
                                                 logger, test_dataloaders[1])
                        neptune.log_metric('train_accuracy',
                                           x=itr,
                                           y=train_acc)
                        neptune.log_metric('test_accuracy', x=itr, y=test_acc)

                    try_im = x_mod
                    orig_im = data_corrupt.squeeze()
                    actual_im = rescale_im(data)

                    orig_im = rescale_im(orig_im)
                    try_im = rescale_im(try_im).squeeze()

                    for i, (im, t_im, actual_im_i) in enumerate(
                            zip(orig_im[:20], try_im[:20], actual_im)):
                        shape = orig_im.shape[1:]
                        new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
                        size = shape[1]
                        new_im[:, :size] = im
                        new_im[:, size:2 * size] = t_im
                        new_im[:, 2 * size:] = actual_im_i

                        log_image(new_im,
                                  logger,
                                  'train_gen_{}'.format(itr),
                                  step=i)
                        neptune.log_image(
                            'train_gen',
                            x=new_im,
                            description='train_gen_iter:{}_idx:{}'.format(
                                itr, i))
                    test_im = x_mod

                    try:
                        data_corrupt, data, label = next(dataloader_iterator)
                    except BaseException:
                        dataloader_iterator = iter(dataloader)
                        data_corrupt, data, label = next(dataloader_iterator)

                    data_corrupt = data_corrupt.numpy()

                    if FLAGS.replay_batch and (
                            x_mod is not None) and len(replay_buffer) > 0:
                        replay_batch = replay_buffer.sample(FLAGS.batch_size)
                        replay_batch = decompress_x_mod(replay_batch)
                        replay_mask = (np.random.uniform(
                            0, 1, (FLAGS.batch_size)) > 0.05)
                        data_corrupt[replay_mask] = replay_batch[replay_mask]

                    if FLAGS.dataset in ('cifar10', 'imagenet', 'imagenetfull'):
                        n = 128

                        if FLAGS.dataset == "imagenetfull":
                            n = 32

                        if len(replay_buffer) > n:
                            data_corrupt = decompress_x_mod(
                                replay_buffer.sample(n))
                        elif FLAGS.dataset == 'imagenetfull':
                            data_corrupt = np.random.uniform(
                                0, FLAGS.rescale, (n, 128, 128, 3))
                        else:
                            data_corrupt = np.random.uniform(
                                0, FLAGS.rescale, (n, 32, 32, 3))

                        if FLAGS.dataset == 'cifar10':
                            label = np.eye(10)[np.random.randint(0, 10, (n))]
                        else:
                            label = np.eye(1000)[np.random.randint(
                                0, 1000, (n))]

                    feed_dict[X_NOISE] = data_corrupt

                    feed_dict[X] = data

                    if FLAGS.cclass:
                        feed_dict[LABEL] = label

                    test_x_mod = sess.run(val_output, feed_dict)

                    try_im = test_x_mod
                    orig_im = data_corrupt.squeeze()
                    actual_im = rescale_im(data.numpy())

                    orig_im = rescale_im(orig_im)
                    try_im = rescale_im(try_im).squeeze()

                    for i, (im, t_im, actual_im_i) in enumerate(
                            zip(orig_im[:20], try_im[:20], actual_im)):

                        shape = orig_im.shape[1:]
                        new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
                        size = shape[1]
                        new_im[:, :size] = im
                        new_im[:, size:2 * size] = t_im
                        new_im[:, 2 * size:] = actual_im_i
                        log_image(new_im,
                                  logger,
                                  'val_gen_{}'.format(itr),
                                  step=i)
                        neptune.log_image(
                            'val_gen',
                            new_im,
                            description='val_gen_iter:{}_idx:{}'.format(
                                itr, i))

                    score, std = get_inception_score(list(try_im), splits=1)
                    print("Inception score of {} with std of {}".format(
                        score, std))
                    kvs = {}
                    kvs['inception_score'] = score
                    kvs['inception_score_std'] = std
                    logger.writekvs(kvs)
                    for key, value in kvs.items():
                        neptune.log_metric(key, x=itr, y=value)

                    if score > best_inception:
                        best_inception = score
                        saver.save(
                            sess,
                            osp.join(FLAGS.logdir, FLAGS.exp, 'model_best'))

                if itr > 600000 and FLAGS.dataset == "mnist":
                    assert False  # hard stop: end MNIST training after 600k iterations
                itr += 1

        saver.save(sess,
                   osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(itr)))
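Several of these training loops apply input mixup when FLAGS.mixup is set (the np.random.beta lines). As a standalone illustration of that augmentation step on a batch of images; note that, unlike classic mixup, these loops mix only the images and leave the labels untouched:

import numpy as np

def mixup_batch(data, alpha=1.0):
    """Convex-combine each image with a randomly permuted partner image.

    data: float array of shape (B, H, W, C). Returns the mixed batch.
    """
    idx = np.random.permutation(data.shape[0])
    lam = np.random.beta(alpha, alpha, size=(data.shape[0], 1, 1, 1))
    return data * lam + data[idx] * (1 - lam)

batch = np.random.uniform(0, 1, (8, 32, 32, 3)).astype(np.float32)
print(mixup_batch(batch).shape)  # (8, 32, 32, 3)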
Example 8
def train(target_vars, saver, sess, logger, dataloader, resume_iter, logdir):
    X = target_vars['X']
    Y = target_vars['Y']
    X_NOISE = target_vars['X_NOISE']
    train_op_model = target_vars['train_op_model']
    train_op_dis = target_vars['train_op_dis']
    energy_pos = target_vars['energy_pos']
    energy_neg = target_vars['energy_neg']
    score_pos = target_vars['score_pos']
    score_neg = target_vars['score_neg']
    loss_energy = target_vars['loss_energy']
    loss_total = target_vars['total_loss']
    gvs = target_vars['gvs']
    x_grad = target_vars['x_grad']
    x_grad_first = target_vars['x_grad_first']
    x_off = target_vars['x_off']
    temp = target_vars['temp']
    x_mod = target_vars['x_mod']
    LABEL = target_vars['LABEL']
    LABEL_POS = target_vars['LABEL_POS']
    weights = target_vars['weights']
    test_x_mod = target_vars['test_x_mod']
    eps = target_vars['eps_begin']
    label_ent = target_vars['label_ent']
    train_op_model_l2 = target_vars['train_op_model_l2']
    train_op_dis_l2 = target_vars['train_op_dis_l2']

    output = [train_op_model, x_mod]

    if FLAGS.use_attention:
        gamma = weights[0]['atten']['gamma']
    else:
        gamma = tf.zeros(1)

    val_output = [test_x_mod]

    gvs_dict = dict(gvs)

    # log_output = [
    #     train_op,
    #     energy_pos,
    #     energy_neg,
    #     eps,
    #     loss_energy,
    #     loss_total,
    #     x_grad,
    #     x_off,
    #     x_mod,
    #     gamma,
    #     x_grad_first,
    #     label_ent,
    #     *gvs_dict.keys()]

    replay_buffer = ReplayBuffer(10000)
    itr = resume_iter
    x_mod = None
    gd_steps = 1

    dataloader_iterator = iter(dataloader)
    best_inception = 0.0
    save_interval = FLAGS.save_interval

    for epoch in range(FLAGS.epoch_num):
        for data_corrupt, data, label in dataloader:
            data_corrupt = data_corrupt.numpy()
            data_corrupt_init = data_corrupt.copy()

            data = data.numpy()
            label = label.numpy()

            label_init = label.copy()

            if FLAGS.mixup:
                idx = np.random.permutation(data.shape[0])
                lam = np.random.beta(1, 1, size=(data.shape[0], 1, 1, 1))
                data = data * lam + data[idx] * (1 - lam)

            if FLAGS.replay_batch and (x_mod is not None):
                replay_buffer.add(compress_x_mod(x_mod))

                if len(replay_buffer) > FLAGS.batch_size:
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_batch = decompress_x_mod(replay_batch)
                    replay_mask = (
                        np.random.uniform(
                            0,
                            FLAGS.rescale,
                            FLAGS.batch_size) > 0.05)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

            if FLAGS.pcd:
                if x_mod is not None:
                    data_corrupt = x_mod

            feed_dict = {X_NOISE: data_corrupt, X: data, Y: label}

            if FLAGS.cclass:
                feed_dict[LABEL] = label
                feed_dict[LABEL_POS] = label_init

            if itr > 10:
                # Train discriminator
                _ = sess.run(train_op_dis, feed_dict)

                # Train model
                _, x_mod = sess.run(output, feed_dict)
            else:
                _, _ = sess.run([train_op_dis_l2, train_op_model_l2], feed_dict)
                energy_neg_, energy_pos_, score_neg_, score_pos_ = sess.run([energy_neg, energy_pos, score_neg, score_pos], feed_dict)
                print(np.mean(energy_neg_), np.mean(energy_pos_), np.mean(score_neg_), np.mean(score_pos_))

            if itr > 30000:
                save_interval = 100

            # if itr % save_interval == 0 and hvd.rank() == 0:
            #     saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(itr)))

            if itr and itr % FLAGS.test_interval == 0 and hvd.rank() == 0 and FLAGS.dataset != '2d':
                try_im = x_mod
                orig_im = data_corrupt.squeeze()
                actual_im = rescale_im(data)

                orig_im = rescale_im(orig_im)
                try_im = rescale_im(try_im).squeeze()

                for i, (im, t_im, actual_im_i) in enumerate(
                        zip(orig_im[:20], try_im[:20], actual_im)):
                    shape = orig_im.shape[1:]
                    new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
                    size = shape[1]
                    new_im[:, :size] = im
                    new_im[:, size:2 * size] = t_im
                    new_im[:, 2 * size:] = actual_im_i

                    log_image(
                        new_im, logger, 'train_gen_{}'.format(itr), step=i)

                test_im = x_mod

                try:
                    data_corrupt, data, label = next(dataloader_iterator)
                except BaseException:
                    dataloader_iterator = iter(dataloader)
                    data_corrupt, data, label = next(dataloader_iterator)

                data_corrupt = data_corrupt.numpy()

                if FLAGS.replay_batch and (
                        x_mod is not None) and len(replay_buffer) > 0:
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_batch = decompress_x_mod(replay_batch)
                    replay_mask = np.random.uniform(0, 1, FLAGS.batch_size) > 0.05
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

                if FLAGS.dataset in ('cifar10', 'imagenet', 'imagenetfull', 'celeba'):
                    n = 128

                    if FLAGS.dataset == "imagenetfull":
                        n = 32

                    if len(replay_buffer) > n:
                        data_corrupt = decompress_x_mod(replay_buffer.sample(n))
                    elif FLAGS.dataset == 'imagenetfull':
                        data_corrupt = np.random.uniform(
                            0, FLAGS.rescale, (n, 128, 128, 3))
                    else:
                        data_corrupt = np.random.uniform(
                            0, FLAGS.rescale, (n, 32, 32, 3))

                    if FLAGS.dataset == 'cifar10':
                        label = np.eye(10)[np.random.randint(0, 10, (n))]
                    elif FLAGS.dataset == 'celeba':
                        label = np.array([1] * n).reshape((n, 1))
                    else:
                        label = np.eye(1000)[np.random.randint(0, 1000, n)]

                feed_dict[X_NOISE] = data_corrupt

                feed_dict[X] = data

                if FLAGS.cclass:
                    feed_dict[LABEL] = label

                test_x_mod = sess.run(val_output, feed_dict)

                try_im = test_x_mod
                orig_im = data_corrupt.squeeze()
                actual_im = rescale_im(data.numpy())

                orig_im = rescale_im(orig_im)
                try_im = rescale_im(try_im).squeeze()

                for i, (im, t_im, actual_im_i) in enumerate(
                        zip(orig_im[:20], try_im[:20], actual_im)):

                    shape = orig_im.shape[1:]
                    new_im = np.zeros((shape[0], shape[1] * 3, *shape[2:]))
                    size = shape[1]
                    new_im[:, :size] = im
                    new_im[:, size:2 * size] = t_im
                    new_im[:, 2 * size:] = actual_im_i
                    log_image(
                        new_im, logger, 'val_gen_{}'.format(itr), step=i)

                score, std = get_inception_score(list(try_im), splits=1)
                print("Iteration {}: Inception score of {} with std of {}".format(itr, score, std))
                kvs = {}
                kvs['inception_score'] = score
                kvs['inception_score_std'] = std
                logger.writekvs(kvs)

                if score > best_inception:
                    best_inception = score
                    saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, 'model_best'))
                    saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(itr)))

            if itr > 60000 and FLAGS.dataset == "mnist":
                assert False
            itr += 1

    saver.save(sess, osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(itr)))
Example #9
0
def train(models, models_ema, optimizer, logger, dataloader, resume_iter, logdir, FLAGS, rank_idx, best_inception):

    torch.cuda.set_device(rank_idx)

    if FLAGS.replay_batch:
        if FLAGS.reservoir:
            replay_buffer = ReservoirBuffer(FLAGS.buffer_size, FLAGS.transform, FLAGS.dataset)
        else:
            replay_buffer = ReplayBuffer(FLAGS.buffer_size, FLAGS.transform, FLAGS.dataset)

    if rank_idx == 0:
        from inception import get_inception_score

    itr = resume_iter
    im_neg = None
    gd_steps = 1

    optimizer.zero_grad()

    num_steps = FLAGS.num_steps

    if FLAGS.cuda:
        dev = torch.device("cuda:{}".format(rank_idx))
    else:
        dev = torch.device("cpu")

    for epoch in range(FLAGS.epoch_num):
        tock = time.time()
        for data_corrupt, data, label in dataloader:
            label = label.float().cuda(rank_idx)
            data = data.permute(0, 3, 1, 2).float().contiguous()

            # Generate samples to evaluate inception score
            if itr % FLAGS.save_interval == 0:
                if FLAGS.dataset == "cifar10":
                    data_corrupt = torch.Tensor(np.random.uniform(0.0, 1.0, (128, 32, 32, 3)))
                    repeat = 128 // FLAGS.batch_size + 1
                    label = torch.cat([label] * repeat, axis=0)
                    label = label[:128]
                elif FLAGS.dataset == "celeba":
                    data_corrupt = torch.Tensor(np.random.uniform(0.0, 1.0, (data.shape[0], 128, 128, 3)))
                    label = label[:data.shape[0]]
                    data_corrupt = data_corrupt[:label.shape[0]]
                elif FLAGS.dataset == "stl":
                    data_corrupt = torch.Tensor(np.random.uniform(0.0, 1.0, (32, 48, 48, 3)))
                    label = label[:32]
                    data_corrupt = data_corrupt[:label.shape[0]]
                elif FLAGS.dataset == "lsun":
                    data_corrupt = torch.Tensor(np.random.uniform(0.0, 1.0, (32, 128, 128, 3)))
                    label = label[:32]
                    data_corrupt = data_corrupt[:label.shape[0]]
                elif FLAGS.dataset == "imagenet":
                    data_corrupt = torch.Tensor(np.random.uniform(0.0, 1.0, (32, 128, 128, 3)))
                    label = label[:32]
                    data_corrupt = data_corrupt[:label.shape[0]]
                elif FLAGS.dataset == "object":
                    data_corrupt = torch.Tensor(np.random.uniform(0.0, 1.0, (32, 128, 128, 3)))
                    label = label[:32]
                    data_corrupt = data_corrupt[:label.shape[0]]
                elif FLAGS.dataset == "mnist":
                    data_corrupt = torch.Tensor(np.random.uniform(0.0, 1.0, (32, 28, 28, 1)))
                    label = label[:32]
                    data_corrupt = data_corrupt[:label.shape[0]]
                else:
                    assert False

            data_corrupt = data_corrupt.float().permute(0, 3, 1, 2).contiguous()
            data = data.cuda(rank_idx)
            data_corrupt = data_corrupt.cuda(rank_idx)

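            # Once the buffer holds a full batch, re-initialize ~95% of the batch from stored negatives; the rest keep their fresh initialization.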
            if FLAGS.replay_batch and len(replay_buffer) >= FLAGS.batch_size:
                replay_batch, idxs = replay_buffer.sample(data_corrupt.size(0))
                replay_batch = decompress_x_mod(replay_batch)
                replay_mask = np.random.uniform(0, 1, data_corrupt.size(0)) > 0.05
                data_corrupt[replay_mask] = torch.Tensor(replay_batch[replay_mask]).cuda(rank_idx)
            else:
                idxs = None

            ix = random.randint(0, len(models) - 1)
            model = models[ix]

            if FLAGS.hmc:
                if itr % FLAGS.save_interval == 0:
                    im_neg, im_samples, x_grad, v = gen_hmc_image(label, FLAGS, model, data_corrupt, num_steps, sample=True)
                else:
                    im_neg, x_grad, v = gen_hmc_image(label, FLAGS, model, data_corrupt, num_steps)
            else:
                if itr % FLAGS.save_interval == 0:
                    im_neg, im_neg_kl, im_samples, x_grad = gen_image(label, FLAGS, model, data_corrupt, num_steps, sample=True)
                else:
                    im_neg, im_neg_kl, x_grad = gen_image(label, FLAGS, model, data_corrupt, num_steps)

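            # Contrastive objective: lower the energy of real data and raise the energy of negatives, with an L2 penalty on both energy magnitudes.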
            energy_pos = model.forward(data, label[:data.size(0)])
            energy_neg = model.forward(im_neg.clone(), label)

            if FLAGS.replay_batch and (im_neg is not None):
                replay_buffer.add(compress_x_mod(im_neg.detach().cpu().numpy()))

            loss = energy_pos.mean() - energy_neg.mean()
            loss = loss + (torch.pow(energy_pos, 2).mean() + torch.pow(energy_neg, 2).mean())

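            # Optional KL loss: push down the energy of the KL samples and, with repel_im, push negatives apart from their nearest replay-buffer neighbors.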
            if FLAGS.kl:
                model.requires_grad_(False)
                loss_kl = model.forward(im_neg_kl, label)
                model.requires_grad_(True)
                loss = loss + FLAGS.kl_coeff * loss_kl.mean()

                if FLAGS.repel_im:
                    start = timeit.default_timer()
                    bs = im_neg_kl.size(0)

                    if FLAGS.dataset in ["celeba", "imagenet", "object", "lsun", "stl"]:
                        im_neg_kl = im_neg_kl.contiguous()

                    im_flat = torch.clamp(im_neg_kl.view(bs, -1), 0, 1)

                    if FLAGS.dataset == "cifar10":
                        if len(replay_buffer) > 1000:
                            compare_batch, idxs = replay_buffer.sample(100, no_transform=False)
                            compare_batch = decompress_x_mod(compare_batch)
                            compare_batch = torch.Tensor(compare_batch).cuda(rank_idx)
                            compare_flat = compare_batch.view(100, -1)

                            dist_matrix = torch.norm(im_flat[:, None, :] - compare_flat[None, :, :], p=2, dim=-1)
                            loss_repel = torch.log(dist_matrix.min(dim=1)[0]).mean()
                            loss = loss - 0.3 * loss_repel
                        else:
                            loss_repel = torch.zeros(1)
                    else:
                        if len(replay_buffer) > 1000:
                            compare_batch, idxs = replay_buffer.sample(100, no_transform=False, downsample=True)
                            compare_batch = decompress_x_mod(compare_batch)
                            compare_batch = torch.Tensor(compare_batch).cuda(rank_idx)
                            compare_flat = compare_batch.view(100, -1)
                            dist_matrix = torch.norm(im_flat[:, None, :] - compare_flat[None, :, :], p=2, dim=-1)
                            loss_repel = torch.log(dist_matrix.min(dim=1)[0]).mean()
                        else:
                            loss_repel = torch.zeros(1).cuda(rank_idx)

                        loss = loss - 0.3 * loss_repel

                    end = timeit.default_timer()
                else:
                    loss_repel = torch.zeros(1)

            else:
                loss_kl = torch.zeros(1)
                loss_repel = torch.zeros(1)

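            # HMC regularizer: penalize the cosine alignment between the momentum and the image gradient so they stay near-orthogonal.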
            if FLAGS.hmc:
                v_flat = v.view(v.size(0), -1)
                im_grad_flat = x_grad.view(x_grad.size(0), -1)
                dot_product = F.normalize(v_flat, dim=1) * F.normalize(im_grad_flat, dim=1)
                hmc_loss = torch.abs(dot_product.sum(dim=1)).mean()
                loss = loss + 0.01 * hmc_loss
            else:
                hmc_loss = torch.zeros(1)

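            # When enabled, measure and log the gradient norms contributed separately by the maximum-likelihood and KL losses.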
            if FLAGS.log_grad and len(replay_buffer) > 1000:
                loss_kl = loss_kl - 0.1 * loss_repel
                loss_kl = loss_kl.mean()
                loss_ml = energy_pos.mean() - energy_neg.mean()

                loss_ml.backward(retain_graph=True)
                ele = []

                for param in model.parameters():
                    if param.grad is not None:
                        ele.append(torch.norm(param.grad.data))

                ele = torch.stack(ele, dim=0)
                ml_grad = torch.mean(ele)
                model.zero_grad()

                loss_kl.backward(retain_graph=True)
                ele = []

                for param in model.parameters():
                    if param.grad is not None:
                        ele.append(torch.norm(param.grad.data))

                ele = torch.stack(ele, dim=0)
                kl_grad = torch.mean(ele)
                model.zero_grad()

            else:
                ml_grad = None
                kl_grad = None

            loss.backward()

            if FLAGS.gpus > 1:
                average_gradients(models)

            for model in models:
                clip_grad_norm(model.parameters(), 0.5)

            optimizer.step()
            optimizer.zero_grad()

            ema_model(models, models_ema)

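            # Abort training if the positive energy becomes NaN or its magnitude exceeds 10.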
            if torch.isnan(energy_pos.mean()):
                assert False

            if torch.abs(energy_pos.mean()) > 10.0:
                assert False

            if itr % FLAGS.log_interval == 0 and rank_idx==0:
                tick = time.time()
                kvs = {}
                kvs['e_pos'] = energy_pos.mean().item()
                kvs['e_pos_std'] = energy_pos.std().item()
                kvs['e_neg'] = energy_neg.mean().item()
                kvs['kl_mean'] = loss_kl.mean().item()
                kvs['loss_repel'] = loss_repel.mean().item()
                kvs['e_neg_std'] = energy_neg.std().item()
                kvs['e_diff'] = kvs['e_pos'] - kvs['e_neg']
                kvs['x_grad'] = np.abs(x_grad.detach().cpu().numpy()).mean()
                kvs['iter'] = itr
                kvs['hmc_loss'] = hmc_loss.item()
                kvs['num_steps'] = num_steps
                kvs['t_diff'] = tick - tock

                if FLAGS.replay_batch:
                    kvs['length'] = len(replay_buffer)

                if (ml_grad is not None):
                    kvs['kl_grad'] = kl_grad
                    kvs['ml_grad'] = ml_grad

                string = "Obtained a total of "
                for key, value in kvs.items():
                    string += "{}: {}, ".format(key, value)

                print(string)
                logger.writekvs(kvs)
                tock = tick

            if itr % FLAGS.save_interval == 0 and rank_idx == 0 and (FLAGS.save_interval != 0):
                model_path = osp.join(logdir, "model_{}.pth".format(itr))
                ckpt = {'optimizer_state_dict': optimizer.state_dict(),
                            'FLAGS': FLAGS, 'best_inception': best_inception}

                for i in range(FLAGS.ensembles):
                    ckpt['model_state_dict_{}'.format(i)] = models[i].state_dict()
                    ckpt['ema_model_state_dict_{}'.format(i)] = models_ema[i].state_dict()

                torch.save(ckpt, model_path)

            if itr % FLAGS.save_interval == 0 and rank_idx == 0:
                im_samples = im_samples[::10]
                im_samples_total = torch.stack(im_samples, dim=1).detach().cpu().permute(0, 1, 3, 4, 2).numpy()
                try_im = im_neg
                orig_im = data_corrupt
                actual_im = rescale_im(data.detach().permute(0, 2, 3, 1).cpu().numpy())

                orig_im = rescale_im(orig_im.detach().permute(0, 2, 3, 1).cpu().numpy())
                try_im = rescale_im(try_im.detach().permute(0, 2, 3, 1).cpu().numpy()).squeeze()
                im_samples_total = rescale_im(im_samples_total)

                for i, (im, sample_im, actual_im_i) in enumerate(
                        zip(orig_im[:20], im_samples_total[:20], actual_im)):
                    shape = orig_im.shape[1:]
                    new_im = np.zeros((shape[0], shape[1] * (2 + sample_im.shape[0]), *shape[2:]))
                    size = shape[1]
                    new_im[:, :size] = im

                    for j, sample_i in enumerate(sample_im):
                        new_im[:, (j+1) * size:(j+2) * size] = sample_i

                    new_im[:, -size:] = actual_im_i

                    log_image(
                        new_im, logger, 'train_gen_{}'.format(itr), step=i)


                if rank_idx == 0:
                    score, std = get_inception_score(list(try_im), splits=1)
                    print("Inception score of {} with std of {}".format(
                            score, std))
                    kvs = {}
                    kvs['inception_score'] = score
                    kvs['inception_score_std'] = std
                    logger.writekvs(kvs)

                    if score > best_inception:
                        model_path = osp.join(logdir, "model_best.pth")
                        torch.save(ckpt, model_path)
                        best_inception = score

            itr += 1
Example #10
0
def conceptcombineeval(model_list, select_idx):
    dataset = ImageNet()
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=64,
                                             shuffle=True,
                                             num_workers=4)

    n = 64
    labels = []

    for six in select_idx:
        six = np.random.permutation(1000)[:n]
        print(six)
        label_batch = np.eye(1000)[six]
        # label_ix = np.eye(2)[six]
        # label_batch = np.tile(label_ix[None, :], (n, 1))
        label = torch.Tensor(label_batch).cuda()
        labels.append(label)

    def get_color_distortion(s=1.0):
        # s is the strength of color distortion.
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s,
                                              0.4 * s)
        rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
        rnd_gray = transforms.RandomGrayscale(p=0.2)
        color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
        return color_distort

    color_transform = get_color_distortion(0.5)

    im_size = 128
    transform = transforms.Compose([
        transforms.RandomResizedCrop(im_size, scale=(0.3, 1.0)),
        transforms.RandomHorizontalFlip(), color_transform,
        transforms.ToTensor()
    ])

    gt_ims = []
    fake_ims = []

    label_embed = torch.eye(1000).cuda()
    im = None

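    # Sample images whose summed energy across the selected models is low, collecting real and generated images for Inception/FID scoring.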
    for _, data, label in tqdm(dataloader):
        print(label)
        gt_ims.extend(list((data.numpy() * 255).astype(np.uint8)))

        if im is None:
            im = torch.rand(n, 3, 128, 128).cuda()

        im_noise = torch.randn_like(im).detach()
        # First get good initializations for sampling
        for _ in range(5):
            for i in range(60):
                label = torch.randperm(1000).to(im.device)[:n]
                label = label_embed[label]
                im_noise.normal_()
                im = im + 0.001 * im_noise
                # im.requires_grad = True
                im.requires_grad_(requires_grad=True)
                energy = 0

                for model, label in zip(model_list, labels):
                    energy = model.forward(im, label) + energy

                # print("step: ", i, energy.mean())
                im_grad = torch.autograd.grad([energy.sum()], [im])[0]

                im = im - FLAGS.step_lr * im_grad
                im = im.detach()

                im = torch.clamp(im, 0, 1)

            im = im.detach().cpu().numpy().transpose((0, 2, 3, 1))
            im = (im * 255).astype(np.uint8)

            ims = []
            for i in range(im.shape[0]):
                im_i = np.array(transform(Image.fromarray(np.array(im[i]))))
                ims.append(im_i)

            im = torch.Tensor(np.array(ims)).cuda()

        # Then refine the images

        for i in range(FLAGS.num_steps):
            im_noise.normal_()
            im = im + 0.001 * im_noise
            # im.requires_grad = True
            im.requires_grad_(requires_grad=True)
            energy = 0

            label = torch.randperm(1000).to(im.device)[:n]
            label = label_embed[label]

            for model, label in zip(model_list, labels):
                energy = model.forward(im, label) + energy

            print("step: ", i, energy.mean())
            im_grad = torch.autograd.grad([energy.sum()], [im])[0]

            im = im - FLAGS.step_lr * im_grad
            im = im.detach()

            im = torch.clamp(im, 0, 1)

        im_cpu = im.detach().cpu()
        ims = list((im_cpu.numpy().transpose(
            (0, 2, 3, 1)) * 255).astype(np.uint8))

        fake_ims.extend(ims)
        if len(gt_ims) > 50000:
            break

    splits = max(1, len(fake_ims) // 5000)
    score, std = get_inception_score(fake_ims, splits=splits)
    print("inception score {}, with std {} ".format(score, std))
    get_fid_score(gt_ims, fake_ims)
    import pdb
    pdb.set_trace()
    print("here")
Example #11
0
def ComputeInception(images):
    images = ((images + 1) / 2.0) * 255.0
    images = images.astype(np.uint8)
    IS = inception.get_inception_score(images)
    return IS
Example #12
0
def compute_inception(model):
    size = FLAGS.im_number
    num_steps = size // 1000

    images = []
    test_ims = []

    if FLAGS.dataset == "cifar10":
        test_dataset = Cifar10(FLAGS)
    elif FLAGS.dataset == "celeba":
        test_dataset = CelebAHQ()
    elif FLAGS.dataset == "mnist":
        test_dataset = Mnist(train=True)

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=FLAGS.batch_size,
                                 num_workers=4,
                                 shuffle=True,
                                 drop_last=False)

    if FLAGS.dataset == "cifar10":
        for data_corrupt, data, label_gt in tqdm(test_dataloader):
            data = data.numpy()
            test_ims.extend(list(rescale_im(data)))

            if len(test_ims) > 10000:
                break
    elif FLAGS.dataset == "mnist":
        for data_corrupt, data, label_gt in tqdm(test_dataloader):
            data = data.numpy()
            test_ims.extend(list(np.tile(rescale_im(data), (1, 1, 3))))

            if len(test_ims) > 10000:
                break

    test_ims = test_ims[:10000]

    classes = 10

    print(FLAGS.batch_size)
    data_buffer = None

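    # Maintain a 1000-image buffer of samples and repeatedly refine its contents with gen_image before computing the scores.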
    for j in range(num_steps):
        itr = int(1000 / 500 * FLAGS.repeat_scale)

        if data_buffer is None:
            data_buffer = InceptionReplayBuffer(1000)

        curr_index = 0

        identity = np.eye(classes)

        if FLAGS.dataset == "celeba":
            n = 128
            c = 3
        elif FLAGS.dataset == "mnist":
            n = 28
            c = 1
        else:
            n = 32
            c = 3

        for i in tqdm(range(itr)):
            noise_scale = [1]
            if len(data_buffer) < 1000:
                x_init = np.random.uniform(0, 1, (FLAGS.batch_size, c, n, n))
                label = np.random.randint(0, classes, (FLAGS.batch_size))

                x_init = torch.Tensor(x_init).cuda()
                label = identity[label]
                label = torch.Tensor(label).cuda()

                x_new, _ = gen_image(label, FLAGS, model, x_init,
                                     FLAGS.num_steps)
                x_new = x_new.detach().cpu().numpy()
                label = label.detach().cpu().numpy()
                data_buffer.add(x_new, label)
            else:
                if i < itr - FLAGS.nomix:
                    (x_init, label), idx = data_buffer.sample(
                        FLAGS.batch_size, transform=FLAGS.transform)
                else:
                    if FLAGS.dataset == "celeba":
                        n = 20
                    else:
                        n = 2

                    ix = i % n
                    # for i in range(n):
                    start_idx = (1000 // n) * ix
                    end_idx = (1000 // n) * (ix + 1)
                    (x_init, label) = data_buffer._encode_sample(
                        list(range(start_idx, end_idx)), transform=False)
                    idx = list(range(start_idx, end_idx))

                x_init = torch.Tensor(x_init).cuda()
                label = torch.Tensor(label).cuda()
                x_new, energy = gen_image(label, FLAGS, model, x_init,
                                          FLAGS.num_steps)
                energy = energy.cpu().detach().numpy()
                x_new = x_new.cpu().detach().numpy()
                label = label.cpu().detach().numpy()
                data_buffer.set_elms(idx, x_new, label)

                if FLAGS.im_number != 50000:
                    print(np.mean(energy), np.std(energy))

            curr_index += 1

        ims = np.array(data_buffer._storage[:1000])
        ims = rescale_im(ims).transpose((0, 2, 3, 1))

        if FLAGS.dataset == "mnist":
            ims = np.tile(ims, (1, 1, 1, 3))

        images.extend(list(ims))

    random.shuffle(images)
    saveim = osp.join('sandbox_cachedir', FLAGS.exp,
                      "test{}.png".format(FLAGS.idx))

    if FLAGS.dataset == "cifar10":
        rix = np.random.permutation(1000)[:100]
        ims = ims[rix]
        im_panel = ims.reshape((10, 10, 32, 32, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((320, 320, 3))
        imsave(saveim, im_panel)

        print("Saved image!!!!")
        splits = max(1, len(images) // 5000)
        score, std = get_inception_score(images, splits=splits)
        print("Inception score of {} with std of {}".format(score, std))

        # FID score
        n = min(len(images), len(test_ims))
        fid = get_fid_score(images, test_ims)
        print("FID of score {}".format(fid))

    elif FLAGS.dataset == "mnist":
        # ims = ims[:100]
        # im_panel = ims.reshape((10, 10, 32, 32, 3)).transpose((0, 2, 1, 3, 4)).reshape((320, 320, 3))
        # imsave(saveim, im_panel)

        ims = ims[:100]
        im_panel = ims.reshape((10, 10, 28, 28, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((280, 280, 3))
        imsave(saveim, im_panel)

        print("Saved image!!!!")
        splits = max(1, len(images) // 5000)
        # score, std = get_inception_score(images, splits=splits)
        # print("Inception score of {} with std of {}".format(score, std))

        # FID score
        n = min(len(images), len(test_ims))
        fid = get_fid_score(images, test_ims)
        print("FID of score {}".format(fid))

    elif FLAGS.dataset == "celeba":

        ims = ims[:25]
        im_panel = ims.reshape((5, 5, 128, 128, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((5 * 128, 5 * 128, 3))
        imsave(saveim, im_panel)