def _ood_mix_loader():
    """Return a shuffled DataLoader over the mix dataset selected by FLAGS, or None.

    Precedence is svhnmix, then cifar100mix, then texturemix. None means no
    dataset-backed mix source is enabled (randommix / neighbour averaging).
    """
    if FLAGS.svhnmix:
        dataset = Svhn(train=False)
    elif FLAGS.cifar100mix:
        dataset = Cifar100(train=False)
    elif FLAGS.texturemix:
        dataset = Textures()
    else:
        return None
    return DataLoader(dataset,
                      batch_size=FLAGS.batch_size,
                      num_workers=FLAGS.data_workers,
                      shuffle=True,
                      drop_last=False)


def _cycle_image_batches(loader):
    """Yield the image field of every ``(_, images, _)`` batch of *loader* forever.

    The loader is re-iterated whenever it is exhausted. This lets a mix dataset
    that is smaller than the evaluation set keep supplying negatives instead
    of raising StopIteration mid-evaluation. Batches are yielded as NumPy arrays.

    Raises:
        ValueError: if one full pass over *loader* produced no batches
            (otherwise this generator would spin forever).
    """
    while True:
        produced = False
        for _, images, _ in loader:
            produced = True
            yield np.asarray(images)
        if not produced:
            raise ValueError("mix dataloader produced no batches")


def _take_images(batches, count):
    """Return exactly *count* images drawn from the *batches* iterator.

    A short batch (e.g. the last one of an epoch, since the mix loader uses
    drop_last=False) is topped up from the next batch. Surplus images of the
    last batch pulled are discarded, matching the original ``[:n]`` slicing.
    """
    chunks = []
    collected = 0
    # Always pull at least one batch so the result keeps the image shape even
    # when count == 0.
    while True:
        chunk = next(batches)
        chunks.append(chunk)
        collected += chunk.shape[0]
        if collected >= count:
            break
    return np.concatenate(chunks, axis=0)[:count]


def energyevalmix(dataloader, test_dataloader, target_vars, sess):
    """Measure how well the model's energy separates test images from "mix" images.

    For every batch of ``test_dataloader`` a batch of negatives with the same
    number of images is built according to FLAGS:
    - svhnmix / cifar100mix / texturemix: images from the Svhn / Cifar100 /
      Textures datasets;
    - randommix: Gaussian noise with mean 0.5 and std 0.5;
    - otherwise: each test image averaged with its successor in the batch.

    Both batches are scored with ``target_vars['energy']``. Negated energy is
    the score for an AUROC in which test images are the positive class.

    With FLAGS.cclass, every image is scored under each of the 10 one-hot
    labels, and its energy is the minimum over those labels.

    Args:
        dataloader: not used here; accepted so the call matches the other
            task entry points invoked from main().
        test_dataloader: yields ``(corrupt, data, label)`` batches whose
            ``data`` exposes ``.numpy()``; the first element is ignored.
        target_vars: dict holding placeholders 'X' and 'Y_GT' and the
            'energy' tensor.
        sess: TF session used to evaluate 'energy'.

    Returns:
        The AUROC score, which is also printed. As a side effect, the negated
        energies are saved to pos.npy and neg.npy in the working directory.
    """
    X = target_vars['X']
    Y_GT = target_vars['Y_GT']
    energy = target_vars['energy']
    num_classes = 10  # cclass mode enumerates np.eye(10) label one-hots

    mix_loader = _ood_mix_loader()
    mix_batches = (_cycle_image_batches(mix_loader)
                   if mix_loader is not None else None)

    probs = []
    labels = []
    negs = []
    pos = []
    for _, data, label_gt in tqdm(test_dataloader):
        data = data.numpy()
        batch_size = data.shape[0]

        if mix_batches is not None:
            data_mix = _take_images(mix_batches, batch_size)
        elif FLAGS.randommix:
            data_mix = np.random.randn(*data.shape) * 0.5 + 0.5
        else:
            # Image i is averaged with image (i + 1) % batch_size.
            data_mix = (data + np.roll(data, -1, axis=0)) / 2

        if FLAGS.cclass:
            # It's unfair to take a random class, so every image is paired
            # with each one-hot label; the repeats are consecutive per image,
            # which is what the reshape(-1, num_classes) below relies on.
            label_gt = np.tile(np.eye(num_classes), (batch_size, 1))
            data = np.repeat(data, num_classes, axis=0)
            data_mix = np.repeat(data_mix, num_classes, axis=0)

        pos_energy = sess.run([energy], {X: data, Y_GT: label_gt})[0]
        neg_energy = sess.run([energy], {X: data_mix, Y_GT: label_gt})[0]

        if FLAGS.cclass:
            pos_energy = pos_energy.reshape(-1, num_classes).min(axis=1)
            neg_energy = neg_energy.reshape(-1, num_classes).min(axis=1)

        # Score = negative energy, so real images (label 1) should score higher.
        probs.extend(list(-1 * pos_energy))
        probs.extend(list(-1 * neg_energy))
        pos.extend(list(-1 * pos_energy))
        negs.extend(list(-1 * neg_energy))
        labels.extend([1] * pos_energy.shape[0])
        labels.extend([0] * neg_energy.shape[0])

    pos, negs = np.array(pos), np.array(negs)
    np.save("pos.npy", pos)
    np.save("neg.npy", negs)
    auroc = sk.roc_auc_score(labels, probs)
    print("Roc score of {}".format(auroc))
    return auroc
def main():
    """Build data, model and session from FLAGS, then run the task FLAGS.task.

    Steps:
    1. Choose the train/test datasets and wrap them in shuffled DataLoaders.
    2. Build the model and its weights, and create the input placeholders.
    3. Construct the task-specific graph and initialise variables.
    4. If FLAGS.resume_iter != -1, restore ``cachedir/<exp>/model_<resume_iter>``.
    5. Dispatch to the task function.

    Raises:
        ValueError: if FLAGS.dataset is neither "cifar10" nor "imagenet",
            because the label placeholders need a known number of classes.
    """
    # Tasks that build their graph with construct_steps().
    step_tasks = ('anticorrupt', 'boxcorrupt', 'crossclass', 'cycleclass',
                  'democlass', 'nearestneighbor')
    # Tasks whose checkpoint is loaded with optimistic_restore() rather than
    # saver.restore().
    optimistic_restore_tasks = ('label', 'boxcorrupt', 'labelfinetune',
                                'energyeval', 'crossclass', 'mixenergy')

    # Base dataset choice; the svhn flag and the latent task override it.
    if FLAGS.dataset == "cifar10":
        dataset = Cifar10(train=True, noise=False)
        test_dataset = Cifar10(train=False, noise=False)
    else:
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)

    if FLAGS.svhn:
        dataset = Svhn(train=True)
        test_dataset = Svhn(train=False)

    if FLAGS.task == 'latent':
        # The latent task uses one DSprites dataset for both loaders.
        dataset = DSprites()
        test_dataset = dataset

    dataloader = DataLoader(dataset,
                            batch_size=FLAGS.batch_size,
                            num_workers=FLAGS.data_workers,
                            shuffle=True,
                            drop_last=True)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=FLAGS.batch_size,
                                 num_workers=FLAGS.data_workers,
                                 shuffle=True,
                                 drop_last=True)

    hidden_dim = 128

    if FLAGS.large_model:
        model = ResNet32Large(num_filters=hidden_dim)
    elif FLAGS.larger_model:
        model = ResNet32Larger(num_filters=hidden_dim)
    elif FLAGS.wider_model:
        wide_filters = 196 if FLAGS.dataset == 'imagenet' else 256
        model = ResNet32Wider(num_filters=wide_filters, train=False)
    else:
        model = ResNet32(num_filters=hidden_dim)

    if FLAGS.task == 'latent':
        model = DspritesNet()

    weights = model.construct_weights('context_{}'.format(0))

    # num_elements() works for both TF1 shapes (Dimension entries) and
    # TF2-behaviour shapes. The latter iterate as plain ints, which have no
    # `.value` attribute.
    total_parameters = sum(
        variable.get_shape().num_elements()
        for variable in tf.compat.v1.trainable_variables())
    print("Model has a total of {} parameters".format(total_parameters))

    sess = tf.compat.v1.InteractiveSession()

    if FLAGS.task == 'latent':
        X = tf.compat.v1.placeholder(shape=(None, 64, 64), dtype=tf.float32)
    else:
        X = tf.compat.v1.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)

    label_dims = {"cifar10": 10, "imagenet": 1000}
    if FLAGS.dataset not in label_dims:
        raise ValueError(
            "Unsupported FLAGS.dataset {!r}; expected one of {}".format(
                FLAGS.dataset, sorted(label_dims)))
    num_labels = label_dims[FLAGS.dataset]
    Y = tf.compat.v1.placeholder(shape=(None, num_labels), dtype=tf.float32)
    Y_GT = tf.compat.v1.placeholder(shape=(None, num_labels), dtype=tf.float32)

    target_vars = {'X': X, 'Y': Y, 'Y_GT': Y_GT}

    if FLAGS.task == 'label':
        construct_label(weights, X, Y, Y_GT, model, target_vars)
    elif FLAGS.task == 'labelfinetune':
        construct_finetune_label(weights, X, Y, Y_GT, model, target_vars)
    elif FLAGS.task in ('energyeval', 'mixenergy'):
        construct_energy(weights, X, Y, Y_GT, model, target_vars)
    elif FLAGS.task in step_tasks:
        construct_steps(weights, X, Y_GT, model, target_vars)
    elif FLAGS.task == 'latent':
        construct_latent(weights, X, Y_GT, model, target_vars)

    sess.run(tf.compat.v1.global_variables_initializer())
    saver = tf.compat.v1.train.Saver(max_to_keep=10)
    savedir = osp.join('cachedir', FLAGS.exp)
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if not osp.exists(logdir):
        os.makedirs(logdir)

    initialize()
    if FLAGS.resume_iter != -1:
        model_file = osp.join(savedir, 'model_{}'.format(FLAGS.resume_iter))
        if FLAGS.task in optimistic_restore_tasks:
            # NOTE(review): optimistic_restore is defined elsewhere; presumably
            # it restores only variables present in both graph and checkpoint.
            optimistic_restore(sess, model_file)
        else:
            saver.restore(sess, model_file)

    if FLAGS.task == 'label':
        if FLAGS.labelgrid:
            # Sweep label()'s l1val (lnorm == -1) or l2val (lnorm == 2)
            # argument and save every result.
            vals = []
            if FLAGS.lnorm == -1:
                for i in range(31):
                    accuracies = label(dataloader,
                                       test_dataloader,
                                       target_vars,
                                       sess,
                                       l1val=i)
                    vals.append(accuracies)
            elif FLAGS.lnorm == 2:
                for i in range(0, 100, 5):
                    accuracies = label(dataloader,
                                       test_dataloader,
                                       target_vars,
                                       sess,
                                       l2val=i)
                    vals.append(accuracies)

            np.save("result_{}_{}.npy".format(FLAGS.lnorm, FLAGS.exp), vals)
        else:
            label(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'labelfinetune':
        # NOTE(review): l1val is fed from FLAGS.lival — confirm that flag name
        # is intentional and not a typo for l1val.
        labelfinetune(dataloader,
                      test_dataloader,
                      target_vars,
                      sess,
                      savedir,
                      saver,
                      l1val=FLAGS.lival,
                      l2val=FLAGS.l2val)
    elif FLAGS.task == 'energyeval':
        energyeval(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'mixenergy':
        energyevalmix(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'anticorrupt':
        anticorrupt(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'boxcorrupt':
        boxcorrupt(test_dataloader, dataloader, weights, model, target_vars,
                   logdir, sess)
    elif FLAGS.task == 'crossclass':
        crossclass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'cycleclass':
        cycleclass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'democlass':
        democlass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'nearestneighbor':
        nearest_neighbor(dataset.data.train_data / 255, sess, target_vars,
                         logdir)
    elif FLAGS.task == 'latent':
        latent(test_dataloader, weights, model, target_vars, sess)
))

# Build the model named by args.model. Module <name> must define a callable
# <name> that takes the existing T and returns the new T.
# importlib.import_module + getattr replace the former `exec` strings, so the
# command-line value is only ever used as a module/attribute name and is never
# executed as source. This form runs on both Python 2.7 and Python 3.
import importlib

model_builder = getattr(importlib.import_module(args.model), args.model)
T = model_builder(T)
T.sess.run(tf.global_variables_initializer())

if args.model != 'classifier':
    # Warm-start the trainable variables under the 'enc' scope from the
    # newest checkpoint in ./save (presumably written by a classifier run —
    # confirm).
    path = tf.train.latest_checkpoint('save')
    if path is None:
        raise ValueError(
            "No checkpoint found in 'save' to restore 'enc' variables from")
    restorer = tf.train.Saver(tf.get_collection('trainable_variables', 'enc'))
    restorer.restore(T.sess, path)

#############
# Load data #
#############
mnist = Mnist(size=32)
svhn = Svhn(size=32)

#########
# Train #
#########
bs = 100  # minibatch size
iterep = 600  # iterations counted as one epoch
n_epoch = 5000 if args.model != 'classifier' else 17
epoch = 0
# NOTE(review): presumably phase=1 selects training mode — confirm in the model.
feed_dict = {T.phase: 1}
saver = tf.train.Saver()

# A single pre-formatted argument prints identically on Python 2 and 3.
print("Batch size: {}".format(bs))
print("Iterep: {}".format(iterep))
print("Total iterations: {}".format(n_epoch * iterep))
print("Log directory: {}".format(get_log_dir()))