def unsup_finetune(model, FLAGS_model):
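    # Linear-probe evaluation: train a linear classifier on frozen EBM
    # features to measure representation quality.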
    train_dataset = Cifar10(FLAGS, train=False)
    train_dataloader = DataLoader(train_dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=False)

    classifier = ModelLinear(FLAGS_model).cuda()
    optimizer = optim.Adam(classifier.parameters(),
                          lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    running_accuracy = []

    count = 0
    for i in range(100):
        for data_corrupt, data, label_gt in tqdm(train_dataloader):
            data = data.permute(0, 3, 1, 2).float().cuda()
            target = label_gt.long().cuda()

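            # Extract features from the frozen EBM (no gradients) and
            # L2-normalize them; only the linear head below receives updates.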
            with torch.no_grad():
                model_feat = model.compute_feat(data, None)
                model_feat = F.normalize(model_feat, dim=-1, p=2)

            output = classifier(model_feat)

            loss = criterion(output, target)
            optimizer.zero_grad()

            loss.backward()
            optimizer.step()
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            running_accuracy.append(acc1.item())
            running_accuracy = running_accuracy[-10:]
            print(count, loss.item(), np.mean(running_accuracy))

            count += 1
def energyevalmix(model):
    # dataset = Cifar100(FLAGS, train=False)
    # dataset = Svhn(train=False)
    # dataset = Textures(train=True)
    # dataset = Cifar10(FLAGS, train=False)
    dataset = CelebaSmall()
    test_dataloader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=False)

    train_dataset = Cifar10(FLAGS, train=False)
    train_dataloader = DataLoader(train_dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, shuffle=True, drop_last=False)

    test_iter = iter(test_dataloader)

    probs = []
    labels = []
    negs = []
    pos = []
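    # Out-of-distribution detection: score in-distribution (CIFAR-10) batches
    # and OOD batches by their negative energy, then measure separability with
    # AUROC below. Lower energy should mean "more in-distribution".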
    for data_corrupt, data, label_gt in tqdm(train_dataloader):
        _, data_mix, _ = next(test_iter)
        # data_mix = data_mix[:data.shape[0]]
        # data_mix_permute = torch.cat([data_mix[1:data.shape[0]], data_mix[:1]], dim=0)
        # data_mix = (data_mix + data_mix_permute) / 2.

        data = data.permute(0, 3, 1, 2).float().cuda()
        data_mix = data_mix.permute(0, 3, 1, 2).float().cuda()

        pos_energy = model.forward(data, None).detach().cpu().numpy().mean(axis=-1)
        neg_energy = model.forward(data_mix, None).detach().cpu().numpy().mean(axis=-1)
        print("pos_energy", pos_energy.mean())
        print("neg_energy", neg_energy.mean())

        probs.extend(list(-1*pos_energy))
        probs.extend(list(-1*neg_energy))
        pos.extend(list(-1*pos_energy))
        negs.extend(list(-1*neg_energy))
        labels.extend([1]*pos_energy.shape[0])
        labels.extend([0]*neg_energy.shape[0])

    pos, negs = np.array(pos), np.array(negs)
    np.save("pos.npy", pos)
    np.save("neg.npy", negs)
    auroc = sk.roc_auc_score(labels, probs)
    print("Roc score of {}".format(auroc))
Example #3
def dataset_iterator(args):
    if args.dataset == 'mnist':
        train_gen, dev_gen, test_gen = Mnist.load(args.batch_size,
                                                  args.batch_size)
    elif args.dataset == 'cifar10':
        data_dir = '../../../images/cifar-10-batches-py/'
        train_gen, dev_gen = Cifar10.load(args.batch_size, data_dir)
        test_gen = None
    elif args.dataset == 'imagenet':
        data_dir = '../../../images/imagenet12/imagenet_val_png/'
        train_gen, dev_gen = Imagenet.load(args.batch_size, data_dir)
        test_gen = None
    elif args.dataset == 'raise':
        data_dir = '../../../images/raise/'
        train_gen, dev_gen = Raise.load(args.batch_size, data_dir)
        test_gen = None
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    return (train_gen, dev_gen, test_gen)
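# A minimal usage sketch (hypothetical `args`; any argparse.Namespace with
# `dataset` and `batch_size` attributes works):
#
#   args = argparse.Namespace(dataset='cifar10', batch_size=64)
#   train_gen, dev_gen, test_gen = dataset_iterator(args)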
Example #4
def main():
    print("Local rank: ", hvd.local_rank(), hvd.size())

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if hvd.rank() == 0:
        if not osp.exists(logdir):
            os.makedirs(logdir)
        logger = TensorBoardOutputFormat(logdir)
    else:
        logger = None

    LABEL = None
    print("Loading data...")
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
        test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        if FLAGS.large_model:
            model = ResNet32Large(num_channels=channel_num,
                                  num_filters=128,
                                  train=True)
        elif FLAGS.larger_model:
            model = ResNet32Larger(num_channels=channel_num, num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(num_channels=channel_num, num_filters=192)
        else:
            model = ResNet32(num_channels=channel_num, num_filters=128)

    elif FLAGS.dataset == 'imagenet':
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet32Wider(num_channels=channel_num, num_filters=256)

    elif FLAGS.dataset == 'imagenetfull':
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet128(num_channels=channel_num, num_filters=64)

    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(rescale=FLAGS.rescale)
        test_dataset = dataset
        channel_num = 1
        X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        model = MnistNet(num_channels=channel_num,
                         num_filters=FLAGS.num_filters)

    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites(cond_shape=FLAGS.cond_shape,
                           cond_size=FLAGS.cond_size,
                           cond_pos=FLAGS.cond_pos,
                           cond_rot=FLAGS.cond_rot)
        test_dataset = dataset
        channel_num = 1

        X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

        if FLAGS.dpos_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.dsize_only:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.drot_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.cond_shape:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_rot:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        else:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)

        model = DspritesNet(num_channels=channel_num,
                            num_filters=FLAGS.num_filters,
                            cond_size=FLAGS.cond_size,
                            cond_shape=FLAGS.cond_shape,
                            cond_pos=FLAGS.cond_pos,
                            cond_rot=FLAGS.cond_rot)

    print("Done loading...")

    if FLAGS.dataset == "imagenetfull":
        # In the case of full ImageNet, use the custom TensorFlow data loader
        data_loader = TFImagenetLoader('train',
                                       FLAGS.batch_size,
                                       hvd.rank(),
                                       hvd.size(),
                                       rescale=FLAGS.rescale)
    else:
        data_loader = DataLoader(dataset,
                                 batch_size=FLAGS.batch_size,
                                 num_workers=FLAGS.data_workers,
                                 drop_last=True,
                                 shuffle=True)

    batch_size = FLAGS.batch_size

    weights = [model.construct_weights('context_0')]

    Y = tf.placeholder(shape=(None), dtype=tf.int32)

    # Variables to run in training
    X_SPLIT = tf.split(X, FLAGS.num_gpus)
    X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
    LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
    LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
    LABEL_SPLIT_INIT = list(LABEL_SPLIT)
    tower_grads = []
    tower_gen_grads = []
    x_mod_list = []

    optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer = hvd.DistributedOptimizer(optimizer)

    for j in range(FLAGS.num_gpus):

        if FLAGS.model_cclass:
            ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
            label_tensor = tf.Variable(tf.convert_to_tensor(np.reshape(
                np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
                (FLAGS.batch_size * 10, 10)),
                                                            dtype=tf.float32),
                                       trainable=False,
                                       dtype=tf.float32)
            x_split = tf.tile(
                tf.reshape(X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)),
                (1, 10, 1, 1, 1))
            x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
            energy_pos = model.forward(x_split,
                                       weights[0],
                                       label=label_tensor,
                                       stop_at_grad=False)

            energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
            energy_partition_est = tf.reduce_logsumexp(energy_pos_full,
                                                       axis=1,
                                                       keepdims=True)
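            # Gumbel-max trick: argmax(-energy + Gumbel noise) draws a label
            # from the softmax distribution over negative energies; subtracting
            # the per-row partition estimate does not change the argmax.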
            uniform = tf.random_uniform(tf.shape(energy_pos_full))
            label_tensor = tf.argmax(-energy_pos_full -
                                     tf.log(-tf.log(uniform)) -
                                     energy_partition_est,
                                     axis=1)
            label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
            label = tf.Print(label, [label_tensor, energy_pos_full])
            LABEL_SPLIT[j] = label
            energy_pos = tf.concat(energy_pos, axis=0)
        else:
            energy_pos = [
                model.forward(X_SPLIT[j],
                              weights[0],
                              label=LABEL_POS_SPLIT[j],
                              stop_at_grad=False)
            ]
            energy_pos = tf.concat(energy_pos, axis=0)

        print("Building graph...")
        x_mod = x_orig = X_NOISE_SPLIT[j]

        x_grads = []

        energy_negs = []
        loss_energys = []

        energy_negs.extend([
            model.forward(tf.stop_gradient(x_mod),
                          weights[0],
                          label=LABEL_SPLIT[j],
                          stop_at_grad=False,
                          reuse=True)
        ])
        eps_begin = tf.zeros(1)

        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)
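        # langevin_step below runs one step of Langevin dynamics: inject
        # Gaussian noise, then descend the energy gradient (optionally
        # projected), keeping samples clipped to the valid pixel range.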

        def langevin_step(counter, x_mod):
            x_mod = x_mod + tf.random_normal(
                tf.shape(x_mod),
                mean=0.0,
                stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = energy_start = tf.concat([
                model.forward(x_mod,
                              weights[0],
                              label=LABEL_SPLIT[j],
                              reuse=True,
                              stop_at_grad=False,
                              stop_batch=True)
            ],
                                                    axis=0)

            x_grad, label_grad = tf.gradients(FLAGS.temperature * energy_noise,
                                              [x_mod, LABEL_SPLIT[j]])
            energy_noise_old = energy_noise

            lr = FLAGS.step_lr

            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm,
                                              FLAGS.proj_norm)
                else:
                    print("Other types of projection are not supported!!!")
                    assert False

            # Clip gradient norm for now
            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - (lr) * x_grad

            x_mod = x_last
            x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

            counter = counter + 1

            return counter, x_mod

        steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))

        energy_eval = model.forward(x_mod,
                                    weights[0],
                                    label=LABEL_SPLIT[j],
                                    stop_at_grad=False,
                                    reuse=True)
        x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        energy_negs.append(
            model.forward(tf.stop_gradient(x_mod),
                          weights[0],
                          label=LABEL_SPLIT[j],
                          stop_at_grad=False,
                          reuse=True))

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        loss_energy = model.forward(x_mod,
                                    weights[0],
                                    reuse=True,
                                    label=LABEL,
                                    stop_grad=True)

        print("Finished processing loop construction ...")

        target_vars = {}

        if FLAGS.cclass or FLAGS.model_cclass:
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(
                label_prob * tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:

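            # Three maximum-likelihood surrogates: 'logsumexp' reweights
            # negatives by self-normalized importance weights, 'cd' is the
            # standard contrastive-divergence surrogate (lower energy on data,
            # raise it on samples), and 'softplus' is a soft margin between
            # positive and negative energies.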
            if FLAGS.objective == 'logsumexp':
                pos_term = temp * energy_pos
                energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'cd':
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = -tf.reduce_mean(temp * energy_neg)
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'softplus':
                loss_ml = FLAGS.ml_coeff * \
                    tf.nn.softplus(temp * (energy_pos - energy_neg))

            loss_total = tf.reduce_mean(loss_ml)

            if not FLAGS.zero_kl:
                loss_total = loss_total + tf.reduce_mean(loss_energy)

            loss_total = loss_total + \
                FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

            print("Started gradient computation...")
            gvs = optimizer.compute_gradients(loss_total)
            gvs = [(k, v) for (k, v) in gvs if k is not None]

            print("Applying gradients...")

            tower_grads.append(gvs)

            print("Finished applying gradients.")

            target_vars['loss_ml'] = loss_ml
            target_vars['total_loss'] = loss_total
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin

    if FLAGS.train:
        grads = average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(grads)
        target_vars['train_op'] = train_op

    config = tf.ConfigProto()

    if hvd.size() > 1:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(max_to_keep=30,
                                    keep_checkpoint_every_n_hours=6)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    sess.run(hvd.broadcast_global_variables(0))
    print("Initializing variables...")

    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        print("Training phase")
        train(target_vars, saver, sess, logger, data_loader, resume_itr,
              logdir)
    print("Testing phase")
    test(target_vars, saver, sess, logger, data_loader)
Example #5
def compute_inception(sess, target_vars):
    X_START = target_vars['X_START']
    Y_GT = target_vars['Y_GT']
    X_finals = target_vars['X_finals']
    NOISE_SCALE = target_vars['NOISE_SCALE']
    energy_noise = target_vars['energy_noise']

    size = FLAGS.im_number
    num_steps = size // 1000
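    # Images are generated in chunks of 1000, one replay buffer per chunk.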

    images = []
    test_ims = []
    test_images = []


    if FLAGS.dataset == "cifar10":
        test_dataset = Cifar10(full=True, noise=False)
    elif FLAGS.dataset == "celeba":
        dataset = CelebA()
    elif FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull":
        test_dataset = Imagenet(train=False)

    if FLAGS.dataset != "imagenetfull":
        test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=4, shuffle=True, drop_last=False)
    else:
        test_dataloader = TFImagenetLoader('test', FLAGS.batch_size, 0, 1)

    for data_corrupt, data, label_gt in tqdm(test_dataloader):
        data = data.numpy()
        test_ims.extend(list(rescale_im(data)))

        if FLAGS.dataset == "imagenetfull" and len(test_ims) > 60000:
            test_ims = test_ims[:60000]
            break


    # n = min(len(images), len(test_ims))
    print(len(test_ims))
    # fid = get_fid_score(test_ims[:30000], test_ims[-30000:])
    # print("Base FID of score {}".format(fid))

    if FLAGS.dataset == "cifar10":
        classes = 10
    else:
        classes = 1000

    if FLAGS.dataset == "imagenetfull":
        n = 128
    else:
        n = 32

    for j in range(num_steps):
        itr = int(1000 / 500 * FLAGS.repeat_scale)
        data_buffer = InceptionReplayBuffer(1000)
        curr_index = 0

        identity = np.eye(classes)

        test_steps = range(300, itr, 20)

        for i in tqdm(range(itr)):
            model_index = curr_index % len(X_finals)
            x_final = X_finals[model_index]

            noise_scale = [1]
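            # Persistent chains: until the buffer is full, start from uniform
            # noise; afterwards resume buffered samples, randomly
            # reinitializing ~1% of the images and ~10% of the labels to keep
            # the chains diverse.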
            if len(data_buffer) < 1000:
                x_init = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
                label = np.random.randint(0, classes, (FLAGS.batch_size))
                label = identity[label]
                x_new = sess.run([x_final], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})[0]
                data_buffer.add(x_new, label)
            else:
                (x_init, label), idx = data_buffer.sample(FLAGS.batch_size)
                keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99)
                label_keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9)
                label_corrupt = np.random.randint(0, classes, (FLAGS.batch_size))
                label_corrupt = identity[label_corrupt]
                x_init_corrupt = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))

                if i < itr - FLAGS.nomix:
                    x_init[keep_mask] = x_init_corrupt[keep_mask]
                    label[label_keep_mask] = label_corrupt[label_keep_mask]
                # else:
                #     noise_scale = [0.7]

                x_new, e_noise = sess.run([x_final, energy_noise], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})
                data_buffer.set_elms(idx, x_new, label)

            curr_index += 1

        ims = np.array(data_buffer._storage[:1000])
        ims = rescale_im(ims)
        test_images.extend(list(ims))

    saveim = osp.join(FLAGS.logdir, FLAGS.exp, "test{}.png".format(FLAGS.resume_iter))
    row = 15
    col = 20
    ims = ims[:row * col]
    if FLAGS.dataset != "imagenetfull":
        im_panel = ims.reshape((row, col, 32, 32, 3)).transpose((0, 2, 1, 3, 4)).reshape((32*row, 32*col, 3))
    else:
        im_panel = ims.reshape((row, col, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((128*row, 128*col, 3))
    imsave(saveim, im_panel)

    splits = max(1, len(test_images) // 5000)
    score, std = get_inception_score(test_images, splits=splits)
    print("Inception score of {} with std of {}".format(score, std))

    # FID score
    # n = min(len(images), len(test_ims))
    fid = get_fid_score(test_images, test_ims)
    print("FID of score {}".format(fid))
Example #6
def main():

    # Initialize dataset
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3
        dim_input = 32 * 32 * 3
    elif FLAGS.dataset == 'imagenet':
        dataset = ImagenetClass()
        channel_num = 3
        dim_input = 64 * 64 * 3
    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(train=False, rescale=FLAGS.rescale)
        channel_num = 1
        dim_input = 28 * 28 * 1
    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites()
        channel_num = 1
        dim_input = 64 * 64 * 1
    elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
        dataset = Box2D()

    dim_output = 1
    data_loader = DataLoader(dataset,
                             batch_size=FLAGS.batch_size,
                             num_workers=FLAGS.data_workers,
                             drop_last=False,
                             shuffle=True)

    if FLAGS.dataset == 'mnist':
        model = MnistNet(num_channels=channel_num)
    elif FLAGS.dataset == 'cifar10':
        if FLAGS.large_model:
            model = ResNet32Large(num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(num_filters=192)
        else:
            model = ResNet32(num_channels=channel_num, num_filters=128)
    elif FLAGS.dataset == 'dsprites':
        model = DspritesNet(num_channels=channel_num,
                            num_filters=FLAGS.num_filters)

    weights = model.construct_weights('context_{}'.format(0))

    config = tf.ConfigProto()
    sess = tf.Session(config=config)
    saver = loader = tf.train.Saver(max_to_keep=10)

    sess.run(tf.global_variables_initializer())
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)

    model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
    resume_itr = FLAGS.resume_iter

    if FLAGS.resume_iter != "-1":
        optimistic_restore(sess, model_file)
    else:
        print("WARNING, YOU ARE NOT LOADING A SAVE FILE")
    # saver.restore(sess, model_file)

    chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(
        model, weights, FLAGS.batch_size, temp=FLAGS.temperature)
    print("Finished constructing ancestral sample ...................")

    if FLAGS.dataset != "gauss":
        comb_weights_cum = []
        batch_size = tf.shape(x_init)[0]
        label_tiled = tf.tile(label_default, (batch_size, 1))
        e_compute = -FLAGS.temperature * model.forward(
            x_init, weights, label=label_tiled)
        e_pos_list = []

        for data_corrupt, data, label_gt in tqdm(data_loader):
            e_pos = sess.run([e_compute], {x_init: data})[0]
            e_pos_list.extend(list(e_pos))

        print(len(e_pos_list))
        print("Positive sample probability ", np.mean(e_pos_list),
              np.std(e_pos_list))

    if FLAGS.dataset == "2d":
        alr = 0.0045
    elif FLAGS.dataset == "gauss":
        alr = 0.0085
    elif FLAGS.dataset == "mnist":
        alr = 0.0065
        #90 alr = 0.0035
    else:
        # alr = 0.0125
        if FLAGS.rescale == 8:
            alr = 0.0085
        else:
            alr = 0.0045


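    # Annealed importance sampling: anneal alpha from 0 to 1 and accumulate
    # chain weights. The forward pass yields a stochastic lower bound on the
    # log partition function; the reversed pass yields an upper bound.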
    for i in range(1):
        tot_weight = 0
        for j in tqdm(range(1, FLAGS.pdist + 1)):
            if j == 1:
                if FLAGS.dataset == "cifar10":
                    x_curr = np.random.uniform(0,
                                               FLAGS.rescale,
                                               size=(FLAGS.batch_size, 32, 32,
                                                     3))
                elif FLAGS.dataset == "gauss":
                    x_curr = np.random.uniform(0,
                                               FLAGS.rescale,
                                               size=(FLAGS.batch_size,
                                                     FLAGS.gauss_dim))
                elif FLAGS.dataset == "mnist":
                    x_curr = np.random.uniform(0,
                                               FLAGS.rescale,
                                               size=(FLAGS.batch_size, 28, 28))
                else:
                    x_curr = np.random.uniform(0,
                                               FLAGS.rescale,
                                               size=(FLAGS.batch_size, 2))

            alpha_prev = (j - 1) / FLAGS.pdist
            alpha_new = j / FLAGS.pdist
            cweight, x_curr = sess.run(
                [chain_weights, x], {
                    a_prev: alpha_prev,
                    a_new: alpha_new,
                    x_init: x_curr,
                    approx_lr: alr * (5**(2.5 * -alpha_prev))
                })
            tot_weight = tot_weight + cweight

        print("Total values of lower value based off forward sampling",
              np.mean(tot_weight), np.std(tot_weight))

        tot_weight = 0

        for j in tqdm(range(FLAGS.pdist, 0, -1)):
            alpha_new = (j - 1) / FLAGS.pdist
            alpha_prev = j / FLAGS.pdist
            cweight, x_curr = sess.run(
                [chain_weights, x], {
                    a_prev: alpha_prev,
                    a_new: alpha_new,
                    x_init: x_curr,
                    approx_lr: alr * (5**(2.5 * -alpha_prev))
                })
            tot_weight = tot_weight - cweight

        print("Total values of upper value based off backward sampling",
              np.mean(tot_weight), np.std(tot_weight))

    def test_unconditional(self, data_loader):
        print("Testing phase")
        inception_score = test(self.target_vars,
                               self.saver,
                               self.sess,
                               self.logger,
                               data_loader)
        return inception_score

if __name__ == "__main__":

    print("Loading data...")
    path = "/home/abhi/Documents/courses/UofT/CSC2506/project/data/cifar10"
    train_dataset = Cifar10(train=True, augment=FLAGS.augment, rescale=FLAGS.rescale, path=path)
    train_dataset_1 = torch.utils.data.Subset(train_dataset, list(range(0, 1000, 2)))
    print("Length of train_dataset:%d"%len(train_dataset_1))
    data_loader = DataLoader(train_dataset_1, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=True, shuffle=True)   
    print("Done loading...")

    base_functions = [elu, gelu, linear, relu, selu, sigmoid, softplus, swish,
                      tanh, atan, cos, erf, sin, sqrt]
    base_operations = [maximum, minimum, add, subtract]

    evo = EvolutionaryAlgorithm(base_functions, base_operations,
                                min_depth=1, max_depth=3, pop_size=10)
    initial_gen = evo.create_generation()
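    # `compile` is assumed to be a DEAP-style gp.compile (not the builtin),
    # turning the evolved expression tree into a callable activation function.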
    custom_act = compile(initial_gen[0], pset=evo.pset)

    ebm_prob = EBMProbML(tf.nn.leaky_relu)
    # ebm_prob = EBMProbML(custom_act)
    train_inc_score = ebm_prob.train_unconditional(data_loader)
def main():

    if FLAGS.dataset == "cifar10":
        dataset = Cifar10(train=True, noise=False)
        test_dataset = Cifar10(train=False, noise=False)
    else:
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)

    if FLAGS.svhn:
        dataset = Svhn(train=True)
        test_dataset = Svhn(train=False)

    if FLAGS.task == 'latent':
        dataset = DSprites()
        test_dataset = dataset

    dataloader = DataLoader(dataset,
                            batch_size=FLAGS.batch_size,
                            num_workers=FLAGS.data_workers,
                            shuffle=True,
                            drop_last=True)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=FLAGS.batch_size,
                                 num_workers=FLAGS.data_workers,
                                 shuffle=True,
                                 drop_last=True)

    hidden_dim = 128

    if FLAGS.large_model:
        model = ResNet32Large(num_filters=hidden_dim)
    elif FLAGS.larger_model:
        model = ResNet32Larger(num_filters=hidden_dim)
    elif FLAGS.wider_model:
        if FLAGS.dataset == 'imagenet':
            model = ResNet32Wider(num_filters=196, train=False)
        else:
            model = ResNet32Wider(num_filters=256, train=False)
    else:
        model = ResNet32(num_filters=hidden_dim)

    if FLAGS.task == 'latent':
        model = DspritesNet()

    weights = model.construct_weights('context_{}'.format(0))

    total_parameters = 0
    for variable in tf.compat.v1.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    config = tf.compat.v1.ConfigProto()
    sess = tf.compat.v1.InteractiveSession()

    if FLAGS.task == 'latent':
        X = tf.compat.v1.placeholder(shape=(None, 64, 64), dtype=tf.float32)
    else:
        X = tf.compat.v1.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)

    if FLAGS.dataset == "cifar10":
        Y = tf.compat.v1.placeholder(shape=(None, 10), dtype=tf.float32)
        Y_GT = tf.compat.v1.placeholder(shape=(None, 10), dtype=tf.float32)
    elif FLAGS.dataset == "imagenet":
        Y = tf.compat.v1.placeholder(shape=(None, 1000), dtype=tf.float32)
        Y_GT = tf.compat.v1.placeholder(shape=(None, 1000), dtype=tf.float32)

    target_vars = {'X': X, 'Y': Y, 'Y_GT': Y_GT}

    if FLAGS.task == 'label':
        construct_label(weights, X, Y, Y_GT, model, target_vars)
    elif FLAGS.task == 'labelfinetune':
        construct_finetune_label(
            weights,
            X,
            Y,
            Y_GT,
            model,
            target_vars,
        )
    elif FLAGS.task == 'energyeval' or FLAGS.task == 'mixenergy':
        construct_energy(weights, X, Y, Y_GT, model, target_vars)
    elif FLAGS.task in ('anticorrupt', 'boxcorrupt', 'crossclass',
                        'cycleclass', 'democlass', 'nearestneighbor'):
        construct_steps(weights, X, Y_GT, model, target_vars)
    elif FLAGS.task == 'latent':
        construct_latent(weights, X, Y_GT, model, target_vars)

    sess.run(tf.compat.v1.global_variables_initializer())
    saver = loader = tf.compat.v1.train.Saver(max_to_keep=10)
    savedir = osp.join('cachedir', FLAGS.exp)
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if not osp.exists(logdir):
        os.makedirs(logdir)

    initialize()
    if FLAGS.resume_iter != -1:
        model_file = osp.join(savedir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter

        if FLAGS.task in ('label', 'boxcorrupt', 'labelfinetune',
                          'energyeval', 'crossclass', 'mixenergy'):
            optimistic_restore(sess, model_file)
            # saver.restore(sess, model_file)
        else:
            # optimistic_restore(sess, model_file)
            saver.restore(sess, model_file)

    if FLAGS.task == 'label':
        if FLAGS.labelgrid:
            vals = []
            if FLAGS.lnorm == -1:
                for i in range(31):
                    accuracies = label(dataloader,
                                       test_dataloader,
                                       target_vars,
                                       sess,
                                       l1val=i)
                    vals.append(accuracies)
            elif FLAGS.lnorm == 2:
                for i in range(0, 100, 5):
                    accuracies = label(dataloader,
                                       test_dataloader,
                                       target_vars,
                                       sess,
                                       l2val=i)
                    vals.append(accuracies)

            np.save("result_{}_{}.npy".format(FLAGS.lnorm, FLAGS.exp), vals)
        else:
            label(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'labelfinetune':
        labelfinetune(dataloader,
                      test_dataloader,
                      target_vars,
                      sess,
                      savedir,
                      saver,
                      l1val=FLAGS.lival,
                      l2val=FLAGS.l2val)
    elif FLAGS.task == 'energyeval':
        energyeval(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'mixenergy':
        energyevalmix(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'anticorrupt':
        anticorrupt(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'boxcorrupt':
        # boxcorrupt(test_dataloader, weights, model, target_vars, logdir, sess)
        boxcorrupt(test_dataloader, dataloader, weights, model, target_vars,
                   logdir, sess)
    elif FLAGS.task == 'crossclass':
        crossclass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'cycleclass':
        cycleclass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'democlass':
        democlass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'nearestneighbor':
        # print(dir(dataset))
        # print(type(dataset))
        nearest_neighbor(dataset.data.train_data / 255, sess, target_vars,
                         logdir)
    elif FLAGS.task == 'latent':
        latent(test_dataloader, weights, model, target_vars, sess)
Example #9
from datetime import datetime
from torch import nn, optim
from torch.utils import data
from data import Cifar10, FashionMNIST
from model import MLP, ConvNet, OneLayer
from utils import train, test
from config import Config as config

print(f"{datetime.now().ctime()} - Start Loading Dataset...")
if config.dataset == "cifar10":
    train_dataset = Cifar10(config.root, train=True)
    test_dataset = Cifar10(config.root, train=False)
elif config.dataset == "fashionmnist":
    train_dataset = FashionMNIST(config.root, train=True)
    test_dataset = FashionMNIST(config.root, train=False)
train_dataloader = data.DataLoader(train_dataset,
                                   config.batch_size,
                                   shuffle=True,
                                   num_workers=2)
test_dataloader = data.DataLoader(test_dataset,
                                  config.batch_size,
                                  shuffle=False,
                                  num_workers=2)
print(f"{datetime.now().ctime()} - Finish Loading Dataset")

print(
    f"{datetime.now().ctime()} - Start Creating Net, Criterion, Optimizer and Scheduler..."
)
if config.model == "mlp":
    net = MLP(config.cifar10_input_size, config.num_classes)
elif config.model == "convnet":
Example #10
from config import DefaultConfig
from data import Cifar10
import models
from Trainer import Trainer
import csv
import matplotlib.pyplot as plt
import os
from torchvision import datasets, transforms

opt = DefaultConfig()

if not os.path.exists(opt.train_data_root):
    from precifar import *
    Precifar()

train_data = Cifar10(opt.train_data_root, train=True)
val_data = Cifar10(opt.train_data_root, train=False)
test_data = Cifar10(opt.test_data_root, test=True)

Model = getattr(models, opt.model)()
if opt.load_model_path:
    Model.load(opt.load_model_path)
Cifar_Trainer = Trainer(Model, opt)

Cifar_Trainer.train(train_data, val_data)
Model.save(opt.save_model_path)

results, confusion_matrix, accuracy = Cifar_Trainer.test(test_data)

with open(opt.result_file, 'wt', newline='', encoding='utf-8') as csvfile:
    try:
Example #11
def main():
    print("Local rank: ", hvd.local_rank(), hvd.size())
    FLAGS.exp = FLAGS.exp + '_' + FLAGS.divergence

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if hvd.rank() == 0:
        if not osp.exists(logdir):
            os.makedirs(logdir)
        logger = TensorBoardOutputFormat(logdir)
    else:
        logger = None

    print("Loading data...")
    dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
    test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
    channel_num = 3

    X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
    X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
    LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
    LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

    if FLAGS.large_model:
        model = ResNet32Large(
            num_channels=channel_num,
            num_filters=128,
            train=True)
        model_dis = ResNet32Large(
            num_channels=channel_num,
            num_filters=128,
            train=True)
    elif FLAGS.larger_model:
        model = ResNet32Larger(
            num_channels=channel_num,
            num_filters=128)
        model_dis = ResNet32Larger(
            num_channels=channel_num,
            num_filters=128)
    elif FLAGS.wider_model:
        model = ResNet32Wider(
            num_channels=channel_num,
            num_filters=256)
        model_dis = ResNet32Wider(
            num_channels=channel_num,
            num_filters=256)
    else:
        model = ResNet32(
            num_channels=channel_num,
            num_filters=128)
        model_dis = ResNet32(
            num_channels=channel_num,
            num_filters=128)

    print("Done loading...")

    grad_exp, conjugate_grad_exp = get_divergence_funcs(FLAGS.divergence)
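    # grad_exp / conjugate_grad_exp are presumably the derivative of the
    # chosen f-divergence generator and of its convex conjugate; together they
    # define the variational (f-GAN style) objective built below.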

    data_loader = DataLoader(
        dataset,
        batch_size=FLAGS.batch_size,
        num_workers=FLAGS.data_workers,
        drop_last=True,
        shuffle=True)

    weights = [model.construct_weights('context_energy'), model_dis.construct_weights('context_dis')]

    Y = tf.placeholder(shape=(None), dtype=tf.int32)

    # Variables to run in training
    X_SPLIT = tf.split(X, FLAGS.num_gpus)
    X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
    LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
    LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
    LABEL_SPLIT_INIT = list(LABEL_SPLIT)
    tower_grads = []
    tower_grads_dis = []
    tower_grads_l2 = []
    tower_grads_dis_l2 = []

    optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer = hvd.DistributedOptimizer(optimizer)

    optimizer_dis = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer_dis = hvd.DistributedOptimizer(optimizer_dis)

    for j in range(FLAGS.num_gpus):

        energy_pos = [
            model.forward(
                X_SPLIT[j],
                weights[0],
                label=LABEL_POS_SPLIT[j],
                stop_at_grad=False)]
        energy_pos = tf.concat(energy_pos, axis=0)

        score_pos = [
            model_dis.forward(
                X_SPLIT[j],
                weights[1],
                label=LABEL_POS_SPLIT[j],
                stop_at_grad=False)]
        score_pos = tf.concat(score_pos, axis=0)

        print("Building graph...")
        x_mod = x_orig = X_NOISE_SPLIT[j]

        x_grads = []

        energy_negs = []
        loss_energys = []

        energy_negs.extend([model.forward(tf.stop_gradient(
            x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)])
        eps_begin = tf.zeros(1)

        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)

        def langevin_step(counter, x_mod):
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
                                             mean=0.0,
                                             stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = energy_start = tf.concat(
                [model.forward(
                        x_mod,
                        weights[0],
                        label=LABEL_SPLIT[j],
                        reuse=True,
                        stop_at_grad=False,
                        stop_batch=True)],
                axis=0)

            x_grad, label_grad = tf.gradients(energy_noise, [x_mod, LABEL_SPLIT[j]])
            energy_noise_old = energy_noise

            lr = FLAGS.step_lr

            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(
                        x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
                else:
                    print("Other types of projection are not supported!!!")
                    assert False

            # Clip gradient norm for now
            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - (lr) * x_grad

            x_mod = x_last
            x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

            counter = counter + 1

            return counter, x_mod

        steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))

        energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j],
                                    stop_at_grad=False, reuse=True)
        x_grad = tf.gradients(energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        energy_negs.append(
            model.forward(
                tf.stop_gradient(x_mod),
                weights[0],
                label=LABEL_SPLIT[j],
                stop_at_grad=False,
                reuse=True))

        score_neg = model_dis.forward(
                tf.stop_gradient(x_mod),
                weights[1],
                label=LABEL_SPLIT[j],
                stop_at_grad=False,
                reuse=True)

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        loss_energy = model.forward(
            x_mod,
            weights[0],
            reuse=True,
            label=LABEL,
            stop_grad=True)

        print("Finished processing loop construction ...")

        target_vars = {}

        if FLAGS.cclass or FLAGS.model_cclass:
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(label_prob *
                                       tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:

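            # Variational f-divergence objective: the discriminator maximizes
            # E_data[grad_exp(T)] - E_model[conjugate_grad_exp(T)] with
            # T = score + energy; the stop_gradient terms in loss_model appear
            # to act as a control variate for the sampler's gradient.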
            loss_dis = - (tf.reduce_mean(grad_exp(score_pos + energy_pos)) - tf.reduce_mean(conjugate_grad_exp(score_neg + energy_neg)))
            loss_dis = loss_dis + FLAGS.l2_coeff * (tf.reduce_mean(tf.square(score_pos)) + tf.reduce_mean(tf.square(score_neg)))
            l2_dis = FLAGS.l2_coeff * (tf.reduce_mean(tf.square(score_pos)) + tf.reduce_mean(tf.square(score_neg)))

            loss_model = tf.reduce_mean(grad_exp(score_pos + energy_pos)) + \
                         tf.reduce_mean(energy_neg * tf.stop_gradient(conjugate_grad_exp(score_neg + energy_neg))) - \
                         tf.reduce_mean(energy_neg) * tf.stop_gradient(tf.reduce_mean(conjugate_grad_exp(score_neg + energy_neg))) - \
                         tf.reduce_mean(conjugate_grad_exp(score_neg + energy_neg))
            loss_model = loss_model + FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))
            l2_model = FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))

            print("Started gradient computation...")
            model_vars = [var for var in tf.trainable_variables() if 'context_energy' in var.name]
            print("model var number", len(model_vars))
            dis_vars = [var for var in tf.trainable_variables() if 'context_dis' in var.name]
            print("discriminator var number", len(dis_vars))

            gvs = optimizer.compute_gradients(loss_model, model_vars)
            gvs = [(k, v) for (k, v) in gvs if k is not None]
            tower_grads.append(gvs)

            gvs = optimizer.compute_gradients(l2_model, model_vars)
            gvs = [(k, v) for (k, v) in gvs if k is not None]
            tower_grads_l2.append(gvs)

            gvs_dis = optimizer_dis.compute_gradients(loss_dis, dis_vars)
            gvs_dis = [(k, v) for (k, v) in gvs_dis if k is not None]
            tower_grads_dis.append(gvs_dis)

            gvs_dis = optimizer_dis.compute_gradients(l2_dis, dis_vars)
            gvs_dis = [(k, v) for (k, v) in gvs_dis if k is not None]
            tower_grads_dis_l2.append(gvs_dis)

            print("Finished applying gradients.")

            target_vars['total_loss'] = loss_model
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin
        target_vars['score_neg'] = score_neg
        target_vars['score_pos'] = score_pos

    if FLAGS.train:
        grads_model = average_gradients(tower_grads)
        train_op_model = optimizer.apply_gradients(grads_model)
        target_vars['train_op_model'] = train_op_model

        grads_model_l2 = average_gradients(tower_grads_l2)
        train_op_model_l2 = optimizer.apply_gradients(grads_model_l2)
        target_vars['train_op_model_l2'] = train_op_model_l2

        grads_model_dis = average_gradients(tower_grads_dis)
        train_op_dis = optimizer_dis.apply_gradients(grads_model_dis)
        target_vars['train_op_dis'] = train_op_dis

        grads_model_dis_l2 = average_gradients(tower_grads_dis_l2)
        train_op_dis_l2 = optimizer_dis.apply_gradients(grads_model_dis_l2)
        target_vars['train_op_dis_l2'] = train_op_dis_l2

    config = tf.ConfigProto()

    if hvd.size() > 1:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(max_to_keep=500)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        saver.restore(sess, model_file)
        # optimistic_restore(sess, model_file)

    sess.run(hvd.broadcast_global_variables(0))
    print("Initializing variables...")

    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        train(target_vars, saver, sess,
              logger, data_loader, resume_itr,
              logdir)

    test(target_vars, saver, sess, logger, data_loader)
def main():

    # Initialize dataset
    dataset = Cifar10(FLAGS, train=False, rescale=FLAGS.rescale)
    data_loader = DataLoader(dataset,
                             batch_size=FLAGS.batch_size,
                             num_workers=4,
                             drop_last=False,
                             shuffle=True)

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    model_path = osp.join(logdir, "model_best.pth")
    checkpoint = torch.load(model_path)
    FLAGS_model = checkpoint['FLAGS']
    model = ResNetModel(FLAGS_model).eval().cuda()

    if FLAGS.ema:
        model.load_state_dict(checkpoint['ema_model_state_dict_0'])
    else:
        model.load_state_dict(checkpoint['model_state_dict_0'])

    print("Finished constructing ancestral sample ...................")
    e_pos_list = []

    for data_corrupt, data, label_gt in tqdm(data_loader):
        data = data.permute(0, 3, 1, 2).contiguous().cuda()
        energy = model.forward(data, None)
        energy = -FLAGS.temperature * energy.squeeze().detach().cpu().numpy()
        e_pos_list.extend(list(energy))

    print(len(e_pos_list))
    print("Positive sample probability ", np.mean(e_pos_list),
          np.std(e_pos_list))

    # alr = 0.0065
    alr = 10.0
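    # Same annealed-importance-sampling scheme as above: the forward pass
    # lower-bounds log Z, the reversed pass upper-bounds it.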
    for i in range(1):
        tot_weight = 0
        for j in tqdm(range(1, FLAGS.pdist + 1)):

            if j == 1:
                x_curr = torch.rand(FLAGS.batch_size, 3, 32, 32).cuda()

            alpha_prev = (j - 1) / FLAGS.pdist
            alpha_new = j / FLAGS.pdist
            cweight, x_curr = ancestral_sample(model,
                                               x_curr,
                                               alpha_prev,
                                               alpha_new,
                                               FLAGS,
                                               FLAGS.batch_size,
                                               FLAGS.pdist,
                                               temp=FLAGS.temperature,
                                               approx_lr=alr)
            tot_weight = tot_weight + cweight.detach()
            x_curr = x_curr.detach()

        tot_weight = tot_weight.detach().cpu().float().numpy()
        print("Total values of lower value based off forward sampling",
              np.mean(tot_weight), np.std(tot_weight))

        tot_weight = 0
        x_curr = x_curr.detach()

        for j in tqdm(range(FLAGS.pdist, 0, -1)):
            alpha_new = (j - 1) / FLAGS.pdist
            alpha_prev = j / FLAGS.pdist
            cweight, x_curr = ancestral_sample(model,
                                               x_curr,
                                               alpha_prev,
                                               alpha_new,
                                               FLAGS,
                                               FLAGS.batch_size,
                                               FLAGS.pdist,
                                               temp=FLAGS.temperature,
                                               approx_lr=alr)
            tot_weight = tot_weight - cweight.detach()
            x_curr = x_curr.detach()

        tot_weight = tot_weight.detach().cpu().float().numpy()
        print("Total values of upper value based off backward sampling",
              np.mean(tot_weight), np.std(tot_weight))
Example #13
def main_single(gpu, FLAGS):
    if FLAGS.slurm:
        init_distributed_mode(FLAGS)

    os.environ['MASTER_ADDR'] = FLAGS.master_addr
    os.environ['MASTER_PORT'] = FLAGS.port

    rank_idx = FLAGS.node_rank * FLAGS.gpus + gpu
    world_size = FLAGS.nodes * FLAGS.gpus
    print("Values of args: ", FLAGS)

    if world_size > 1:
        if FLAGS.slurm:
            dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank_idx)
        else:
            dist.init_process_group(backend='nccl', init_method='tcp://localhost:1700', world_size=world_size, rank=rank_idx)

    if FLAGS.dataset == "cifar10":
        train_dataset = Cifar10(FLAGS)
        valid_dataset = Cifar10(FLAGS, train=False, augment=False)
        test_dataset = Cifar10(FLAGS, train=False, augment=False)
    elif FLAGS.dataset == "stl":
        train_dataset = STLDataset(FLAGS)
        valid_dataset = STLDataset(FLAGS, train=False)
        test_dataset = STLDataset(FLAGS, train=False)
    elif FLAGS.dataset == "object":
        train_dataset = ObjectDataset(FLAGS.cond_idx)
        valid_dataset = ObjectDataset(FLAGS.cond_idx)
        test_dataset = ObjectDataset(FLAGS.cond_idx)
    elif FLAGS.dataset == "imagenet":
        train_dataset = ImageNet()
        valid_dataset = ImageNet()
        test_dataset = ImageNet()
    elif FLAGS.dataset == "mnist":
        train_dataset = Mnist(train=True)
        valid_dataset = Mnist(train=False)
        test_dataset = Mnist(train=False)
    elif FLAGS.dataset == "celeba":
        train_dataset = CelebAHQ(cond_idx=FLAGS.cond_idx)
        valid_dataset = CelebAHQ(cond_idx=FLAGS.cond_idx)
        test_dataset = CelebAHQ(cond_idx=FLAGS.cond_idx)
    elif FLAGS.dataset == "lsun":
        train_dataset = LSUNBed(cond_idx=FLAGS.cond_idx)
        valid_dataset = LSUNBed(cond_idx=FLAGS.cond_idx)
        test_dataset = LSUNBed(cond_idx=FLAGS.cond_idx)
    else:
        assert False

    train_dataloader = DataLoader(train_dataset, num_workers=FLAGS.data_workers, batch_size=FLAGS.batch_size, shuffle=True, drop_last=True)
    valid_dataloader = DataLoader(valid_dataset, num_workers=FLAGS.data_workers, batch_size=FLAGS.batch_size, shuffle=True, drop_last=True)
    test_dataloader = DataLoader(test_dataset, num_workers=FLAGS.data_workers, batch_size=FLAGS.batch_size, shuffle=True, drop_last=True)

    FLAGS_OLD = FLAGS

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    best_inception = 0.0

    if FLAGS.resume_iter != 0:
        model_path = osp.join(logdir, "model_{}.pth".format(FLAGS.resume_iter))
        checkpoint = torch.load(model_path)
        best_inception = checkpoint['best_inception']
        FLAGS = checkpoint['FLAGS']

        FLAGS.resume_iter = FLAGS_OLD.resume_iter
        FLAGS.nodes = FLAGS_OLD.nodes
        FLAGS.gpus = FLAGS_OLD.gpus
        FLAGS.node_rank = FLAGS_OLD.node_rank
        FLAGS.master_addr = FLAGS_OLD.master_addr
        FLAGS.train = FLAGS_OLD.train
        FLAGS.num_steps = FLAGS_OLD.num_steps
        FLAGS.step_lr = FLAGS_OLD.step_lr
        FLAGS.batch_size = FLAGS_OLD.batch_size
        FLAGS.ensembles = FLAGS_OLD.ensembles
        FLAGS.kl_coeff = FLAGS_OLD.kl_coeff
        FLAGS.repel_im = FLAGS_OLD.repel_im
        FLAGS.save_interval = FLAGS_OLD.save_interval

        # copy every restored flag back onto the live command-line namespace;
        # argparse namespaces do not support item assignment, so use setattr
        for key in dir(FLAGS):
            if "__" not in key:
                setattr(FLAGS_OLD, key, getattr(FLAGS, key))

        FLAGS = FLAGS_OLD

    if FLAGS.dataset == "cifar10":
        model_fn = ResNetModel
    elif FLAGS.dataset == "stl":
        model_fn = ResNetModel
    elif FLAGS.dataset == "object":
        model_fn = CelebAModel
    elif FLAGS.dataset == "mnist":
        model_fn = MNISTModel
    elif FLAGS.dataset == "celeba":
        model_fn = CelebAModel
    elif FLAGS.dataset == "lsun":
        model_fn = CelebAModel
    elif FLAGS.dataset == "imagenet":
        model_fn = ImagenetModel
    else:
        assert False

    models = [model_fn(FLAGS).train() for _ in range(FLAGS.ensembles)]
    models_ema = [model_fn(FLAGS).train() for _ in range(FLAGS.ensembles)]

    torch.cuda.set_device(gpu)
    if FLAGS.cuda:
        models = [model.cuda(gpu) for model in models]
        models_ema = [model_ema.cuda(gpu) for model_ema in models_ema]

    if FLAGS.gpus > 1:
        sync_model(models)

    parameters = []
    for model in models:
        parameters.extend(list(model.parameters()))

    optimizer = Adam(parameters, lr=FLAGS.lr, betas=(0.0, 0.9), eps=1e-8)

    ema_model(models, models_ema, mu=0.0)

    logger = TensorBoardOutputFormat(logdir)

    it = FLAGS.resume_iter

    if not osp.exists(logdir):
        os.makedirs(logdir)

    checkpoint = None
    if FLAGS.resume_iter != 0:
        model_path = osp.join(logdir, "model_{}.pth".format(FLAGS.resume_iter))
        checkpoint = torch.load(model_path)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        for i, (model, model_ema) in enumerate(zip(models, models_ema)):
            model.load_state_dict(checkpoint['model_state_dict_{}'.format(i)])
            model_ema.load_state_dict(checkpoint['ema_model_state_dict_{}'.format(i)])


    print("New Values of args: ", FLAGS)

    pytorch_total_params = sum(p.numel() for p in models[0].parameters() if p.requires_grad)
    print("Number of trainable parameters per model", pytorch_total_params)

    train(models, models_ema, optimizer, logger, train_dataloader, FLAGS.resume_iter, logdir, FLAGS, gpu, best_inception)
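# Hypothetical launcher (not shown in the source): main_single takes the GPU
# index as its first argument, so a multi-GPU run would presumably be started
# with torch.multiprocessing.spawn along these lines.
import torch.multiprocessing as mp

def launch(FLAGS):
    if FLAGS.gpus > 1:
        # spawn one worker process per GPU; each is invoked as main_single(gpu, FLAGS)
        mp.spawn(main_single, args=(FLAGS,), nprocs=FLAGS.gpus)
    else:
        main_single(0, FLAGS)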
Example #14
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--policy-lr", type=float, default=5e-4, help="learning rate")
    parser.add_argument("--num-gpu-core", type=int, default=1, help="Number of GPU cores to use")
    parser.add_argument("--num-cpu-core", type=int, default=4, help="Number of CPU cores to use")
    parser.add_argument("--inter-op-parallelism-threads", type=int, default=4,
                        help="0 means the system picks an appropriate number.")
    parser.add_argument("--intra-op-parallelism-threads", type=int, default=4,
                        help="0 means the system picks an appropriate number.")
    parser.add_argument("--use-fp16", default=False, action='store_true', help="whether use float16 as default")
    args = parser.parse_args()

    print("Available devices:")
    print(device_lib.list_local_devices())

    config_dict = load_cfg_from_yaml('config/CIFAR10/R_50.yaml')
    config_dict = merge_cfg_from_args(config_dict, args)
    lr = config_dict["policy_lr"]
    intra_op_threads = config_dict['intra_op_parallelism_threads']
    inter_op_threads = config_dict['inter_op_parallelism_threads']
    use_fp16 = config_dict['use_fp16']
    cpu_num = config_dict["num_cpu_core"]
    gpu_num = config_dict["num_gpu_core"]
    MAX_EPOCH = config_dict['TRAIN']['MAX_EPOCH']
    device_id = -1

    config = tf.ConfigProto(
        # device_count limits the number of CPUs being used, not the number of cores or threads.
        device_count={'CPU': cpu_num, 'GPU': gpu_num},
        inter_op_parallelism_threads=inter_op_threads,  # parallelism across independent operations
        intra_op_parallelism_threads=intra_op_threads,  # parallelism within a single operation, e.g., reduce_sum
        log_device_placement=True,  # log the GPU or CPU device that is assigned to an operation
        allow_soft_placement=True,  # use soft constraints for the device placement
    )
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.9

    with tf.Session(config=config) as sess:
        # input for resnet
        tf_x = tf.placeholder(dtype=data_type(use_fp16), shape=[None, 32, 32, 3], name='tf_x')
        tf_y = tf.placeholder(dtype=data_type(use_fp16), shape=[None, 10], name='tf_y')
        # output of resnet
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            with tf.device(next_device(device_id, config_dict, use_cpu=False)):
                resnet_out, end_points = resnet_v2.resnet_v2_50(tf_x, num_classes=10, is_training=False)
        # loss
        cross_entropy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=resnet_out, labels=tf_y))
        # tf_loss_summary = tf.summary.scalar('loss', cross_entropy_loss)
        # optimizer
        tf_lr = tf.placeholder(dtype=data_type(use_fp16), shape=None, name='learning_rate')  # more flexible learning rate
        optimizer = tf.train.AdamOptimizer(tf_lr)
        grads_and_vars = optimizer.compute_gradients(cross_entropy_loss)
        train_step = optimizer.apply_gradients(grads_and_vars)  # reuse the gradients computed above instead of recomputing them

        correct_prediction = tf.equal(tf.argmax(resnet_out, 1), tf.argmax(tf_y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
        # tf_accuracy_summary = tf.summary.scalar('accuracy', accuracy)

        # initialize variables
        sess.run(tf.global_variables_initializer())

        # Name scope allows you to group various summaries together
        # Summaries having the same name_scope will be displayed on the same row
        with tf.name_scope('performance'):
            # Summaries need to be displayed
            # Whenever you need to record the loss, feed the mean loss to this placeholder
            tf_loss_ph = tf.placeholder(dtype=data_type(use_fp16), shape=None, name='loss_summary')
            # Create a scalar summary object for the loss so it can be displayed
            tf_loss_summary = tf.summary.scalar('loss', tf_loss_ph)

            # Whenever you need to record the loss, feed the mean test accuracy to this placeholder
            tf_accuracy_ph = tf.placeholder(dtype=data_type(use_fp16), shape=None, name='accuracy_summary')
            # Create a scalar summary object for the accuracy so it can be displayed
            tf_accuracy_summary = tf.summary.scalar('accuracy', tf_accuracy_ph)

        # Gradient norm summary
        # tf_gradnorm_summary tracks the root-mean-square of the gradients of the last
        # layer of the network (sqrt of the mean squared gradient). The gradient norm is
        # a good indicator of whether the weights are being properly updated: a very
        # small norm can indicate vanishing gradients, while a very large one can
        # indicate exploding gradients.
        t_layer = len(grads_and_vars) - 2  # index of the last layer
        for i_layer, (g, v) in enumerate(grads_and_vars):
            if i_layer == t_layer:
                with tf.name_scope('gradients_norm'):
                    tf_last_grad_norm = tf.sqrt(tf.reduce_mean(g ** 2))
                    tf_gradnorm_summary = tf.summary.scalar('grad_norm', tf_last_grad_norm)
                    break

        # A summary for each weight in each layer of ResNet50
        all_summaries = []
        for weight_name in end_points:
            try:
                name_scope = weight_name.split("bottleneck_v2/")[0] + weight_name.split("bottleneck_v2/")[1]
            except IndexError:
                # the endpoint name contains no "bottleneck_v2/" scope
                name_scope = weight_name.split("bottleneck_v2/")[0]
            with tf.name_scope(name_scope):
                weight = end_points[weight_name]
                # Create a histogram summary for this weight tensor so it can be displayed
                tf_w_hist = tf.summary.histogram('weights_hist', tf.reshape(weight, [-1]))
                all_summaries.append([tf_w_hist])
        # Merge all parameter histogram summaries together
        tf_weight_summaries = tf.summary.merge(all_summaries)

        # Merge all summaries together
        # the following two statements are equal
        # [1] merge_op = tf.summary.merge_all()
        # [2] merge_op = tf.summary.merge([tf_loss_summary, tf_accuracy_summary, tf_gradnorm_summary, ...])
        merge_op = tf.summary.merge([tf_loss_summary, tf_accuracy_summary])
        # separate tensorboard, i.e., one log file, one folder.
        train_writer = tf.summary.FileWriter('logs/train', sess.graph)
        test_writer = tf.summary.FileWriter('logs/test', sess.graph)

        # get cifar10 dataset
        cifar10_data = Cifar10('dataset/CIFAR10/', config=config_dict)
        test_images, test_labels = cifar10_data.test_data()

        # training
        start_time = time.time()
        batch_counter = 0
        while cifar10_data.epochs_completed < MAX_EPOCH:
            batch_xs, batch_ys = cifar10_data.next_train_batch()
            batch_counter += 1
            sess.run(train_step, feed_dict={tf_x: batch_xs, tf_y: batch_ys, tf_lr: lr})

            # calculate the train_accuracy for one batch
            if batch_counter % 100 == 0:
                train_accuracy, loss = sess.run(
                    [accuracy, cross_entropy_loss], feed_dict={tf_x: batch_xs, tf_y: batch_ys, tf_lr: lr})
                train_summary = sess.run(
                    merge_op, feed_dict={tf_loss_ph: loss, tf_accuracy_ph: train_accuracy})
                train_writer.add_summary(train_summary, batch_counter)
                print("----- epoch {} batch {} training accuracy {} loss {}".
                      format(cifar10_data.epochs_completed, batch_counter, train_accuracy, loss))

                end_time = time.time()
                print("time: {}".format(end_time - start_time))
                start_time = end_time

            # calculate the gradient norm summary and weight histogram
            if batch_counter % 100 == 0:
                gn_summary, wb_summary = sess.run(
                    [tf_gradnorm_summary, tf_weight_summaries], feed_dict={tf_x: batch_xs, tf_y: batch_ys, tf_lr: lr})
                train_writer.add_summary(gn_summary, batch_counter)
                train_writer.add_summary(wb_summary, batch_counter)

            # calculate the test_accuracy
            if batch_counter % 1000 == 0:
                # Test_accuracy
                test_accuracy, test_loss = sess.run(
                    [accuracy, cross_entropy_loss], feed_dict={tf_x: test_images, tf_y: test_labels, tf_lr: lr})
                test_summary = sess.run(
                    merge_op, feed_dict={tf_loss_ph: test_loss, tf_accuracy_ph: test_accuracy})
                test_writer.add_summary(test_summary, batch_counter)
                print("----- test accuracy {} test loss {}".format(test_accuracy, test_loss))

        # Overall test accuracy
        overall_accuracy = accuracy.eval(feed_dict={tf_x: test_images, tf_y: test_labels, tf_lr: lr})
        print("\nOverall test accuracy {}".format(overall_accuracy))

        save_cfg_to_yaml(config_dict, 'logs/current_config.yaml')
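# The data_type helper is used for every placeholder above but is not shown
# here; a plausible definition, assuming it only toggles the dtype on the
# --use-fp16 flag:
import tensorflow as tf

def data_type(use_fp16):
    return tf.float16 if use_fp16 else tf.float32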
Example #15
def compute_inception(model):
    size = FLAGS.im_number
    num_steps = size // 1000

    images = []
    test_ims = []

    if FLAGS.dataset == "cifar10":
        test_dataset = Cifar10(FLAGS)
    elif FLAGS.dataset == "celeba":
        test_dataset = CelebAHQ()
    elif FLAGS.dataset == "mnist":
        test_dataset = Mnist(train=True)

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=FLAGS.batch_size,
                                 num_workers=4,
                                 shuffle=True,
                                 drop_last=False)

    if FLAGS.dataset == "cifar10":
        for data_corrupt, data, label_gt in tqdm(test_dataloader):
            data = data.numpy()
            test_ims.extend(list(rescale_im(data)))

            if len(test_ims) > 10000:
                break
    elif FLAGS.dataset == "mnist":
        for data_corrupt, data, label_gt in tqdm(test_dataloader):
            data = data.numpy()
            test_ims.extend(list(np.tile(rescale_im(data), (1, 1, 3))))

            if len(test_ims) > 10000:
                break

    test_ims = test_ims[:10000]

    classes = 10

    print(FLAGS.batch_size)
    data_buffer = None

    for j in range(num_steps):
        itr = int(1000 / 500 * FLAGS.repeat_scale)

        if data_buffer is None:
            data_buffer = InceptionReplayBuffer(1000)

        curr_index = 0

        identity = np.eye(classes)

        if FLAGS.dataset == "celeba":
            n = 128
            c = 3
        elif FLAGS.dataset == "mnist":
            n = 28
            c = 1
        else:
            n = 32
            c = 3

        for i in tqdm(range(itr)):
            noise_scale = [1]
            if len(data_buffer) < 1000:
                x_init = np.random.uniform(0, 1, (FLAGS.batch_size, c, n, n))
                label = np.random.randint(0, classes, (FLAGS.batch_size))

                x_init = torch.Tensor(x_init).cuda()
                label = identity[label]
                label = torch.Tensor(label).cuda()

                x_new, _ = gen_image(label, FLAGS, model, x_init,
                                     FLAGS.num_steps)
                x_new = x_new.detach().cpu().numpy()
                label = label.detach().cpu().numpy()
                data_buffer.add(x_new, label)
            else:
                if i < itr - FLAGS.nomix:
                    (x_init, label), idx = data_buffer.sample(
                        FLAGS.batch_size, transform=FLAGS.transform)
                else:
                    if FLAGS.dataset == "celeba":
                        n = 20
                    else:
                        n = 2

                    ix = i % n
                    start_idx = (1000 // n) * ix
                    end_idx = (1000 // n) * (ix + 1)
                    (x_init, label) = data_buffer._encode_sample(
                        list(range(start_idx, end_idx)), transform=False)
                    idx = list(range(start_idx, end_idx))

                x_init = torch.Tensor(x_init).cuda()
                label = torch.Tensor(label).cuda()
                x_new, energy = gen_image(label, FLAGS, model, x_init,
                                          FLAGS.num_steps)
                energy = energy.cpu().detach().numpy()
                x_new = x_new.cpu().detach().numpy()
                label = label.cpu().detach().numpy()
                data_buffer.set_elms(idx, x_new, label)

                if FLAGS.im_number != 50000:
                    print(np.mean(energy), np.std(energy))

            curr_index += 1

        ims = np.array(data_buffer._storage[:1000])
        ims = rescale_im(ims).transpose((0, 2, 3, 1))

        if FLAGS.dataset == "mnist":
            ims = np.tile(ims, (1, 1, 1, 3))

        images.extend(list(ims))

    random.shuffle(images)
    saveim = osp.join('sandbox_cachedir', FLAGS.exp,
                      "test{}.png".format(FLAGS.idx))

    if FLAGS.dataset == "cifar10":
        rix = np.random.permutation(1000)[:100]
        ims = ims[rix]
        im_panel = ims.reshape((10, 10, 32, 32, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((320, 320, 3))
        imsave(saveim, im_panel)

        print("Saved image!!!!")
        splits = max(1, len(images) // 5000)
        score, std = get_inception_score(images, splits=splits)
        print("Inception score of {} with std of {}".format(score, std))

        # FID score
        n = min(len(images), len(test_ims))
        fid = get_fid_score(images, test_ims)
        print("FID of score {}".format(fid))

    elif FLAGS.dataset == "mnist":

        ims = ims[:100]
        im_panel = ims.reshape((10, 10, 28, 28, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((280, 280, 3))
        imsave(saveim, im_panel)

        print("Saved image!!!!")
        splits = max(1, len(images) // 5000)
        # score, std = get_inception_score(images, splits=splits)
        # print("Inception score of {} with std of {}".format(score, std))

        # FID score
        n = min(len(images), len(test_ims))
        fid = get_fid_score(images, test_ims)
        print("FID of score {}".format(fid))

    elif FLAGS.dataset == "celeba":

        ims = ims[:25]
        im_panel = ims.reshape((5, 5, 128, 128, 3)).transpose(
            (0, 2, 1, 3, 4)).reshape((5 * 128, 5 * 128, 3))
        imsave(saveim, im_panel)
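# rescale_im is assumed (its definition is not shown here) to convert float
# images in [0, 1] into uint8 pixels in [0, 255] so they can be tiled into
# panels and passed to imsave and the Inception/FID scorers; a minimal sketch:
import numpy as np

def rescale_im(im):
    return np.clip(im * 255.0, 0.0, 255.0).astype(np.uint8)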