def main():
    """Build the evaluation graph for the task selected via FLAGS and run it.

    All configuration comes from absl FLAGS: dataset choice, model
    architecture, task name, and checkpoint-resume options.  Side effects:
    constructs a TF1 graph and InteractiveSession, optionally restores
    weights from ``cachedir/<exp>``, then dispatches to one of the
    task-specific evaluation routines (label, energyeval, latent, ...).
    """
    # --- dataset selection; later flags deliberately override earlier ones ---
    if FLAGS.dataset == "cifar10":
        dataset = Cifar10(train=True, noise=False)
        test_dataset = Cifar10(train=False, noise=False)
    else:
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)

    if FLAGS.svhn:
        dataset = Svhn(train=True)
        test_dataset = Svhn(train=False)

    if FLAGS.task == 'latent':
        # dSprites has no train/test split; evaluate on the same data.
        dataset = DSprites()
        test_dataset = dataset

    dataloader = DataLoader(dataset,
                            batch_size=FLAGS.batch_size,
                            num_workers=FLAGS.data_workers,
                            shuffle=True,
                            drop_last=True)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=FLAGS.batch_size,
                                 num_workers=FLAGS.data_workers,
                                 shuffle=True,
                                 drop_last=True)

    hidden_dim = 128

    # --- model selection ---
    if FLAGS.large_model:
        model = ResNet32Large(num_filters=hidden_dim)
    elif FLAGS.larger_model:
        model = ResNet32Larger(num_filters=hidden_dim)
    elif FLAGS.wider_model:
        if FLAGS.dataset == 'imagenet':
            model = ResNet32Wider(num_filters=196, train=False)
        else:
            model = ResNet32Wider(num_filters=256, train=False)
    else:
        model = ResNet32(num_filters=hidden_dim)

    if FLAGS.task == 'latent':
        # The latent task always uses the dSprites architecture,
        # regardless of the model flags above.
        model = DspritesNet()

    weights = model.construct_weights('context_{}'.format(0))

    # Count trainable parameters.  TensorShape.num_elements() works on both
    # TF1 (Dimension dims) and TF2 (int dims); the old `dim.value` loop
    # breaks under TF2 where dims are plain ints.
    total_parameters = sum(
        variable.get_shape().num_elements()
        for variable in tf.compat.v1.trainable_variables())
    print("Model has a total of {} parameters".format(total_parameters))

    # FIX: the config object was previously created but never passed to the
    # session, so any proto options it carried were silently ignored.
    config = tf.compat.v1.ConfigProto()
    sess = tf.compat.v1.InteractiveSession(config=config)

    if FLAGS.task == 'latent':
        # dSprites images are single-channel 64x64.
        X = tf.compat.v1.placeholder(shape=(None, 64, 64), dtype=tf.float32)
    else:
        X = tf.compat.v1.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)

    if FLAGS.dataset == "cifar10":
        Y = tf.compat.v1.placeholder(shape=(None, 10), dtype=tf.float32)
        Y_GT = tf.compat.v1.placeholder(shape=(None, 10), dtype=tf.float32)
    elif FLAGS.dataset == "imagenet":
        Y = tf.compat.v1.placeholder(shape=(None, 1000), dtype=tf.float32)
        Y_GT = tf.compat.v1.placeholder(shape=(None, 1000), dtype=tf.float32)
    else:
        # FIX: previously Y/Y_GT were left undefined for any other dataset
        # value (e.g. when running the dSprites latent task), which raised
        # a NameError when building target_vars below.  Default to a
        # 10-class label placeholder; the latent path only threads Y_GT
        # through to construct_latent.
        Y = tf.compat.v1.placeholder(shape=(None, 10), dtype=tf.float32)
        Y_GT = tf.compat.v1.placeholder(shape=(None, 10), dtype=tf.float32)

    target_vars = {'X': X, 'Y': Y, 'Y_GT': Y_GT}

    # --- graph construction, dispatched on task ---
    if FLAGS.task == 'label':
        construct_label(weights, X, Y, Y_GT, model, target_vars)
    elif FLAGS.task == 'labelfinetune':
        construct_finetune_label(
            weights,
            X,
            Y,
            Y_GT,
            model,
            target_vars,
        )
    elif FLAGS.task in ('energyeval', 'mixenergy'):
        construct_energy(weights, X, Y, Y_GT, model, target_vars)
    elif FLAGS.task in ('anticorrupt', 'boxcorrupt', 'crossclass',
                        'cycleclass', 'democlass', 'nearestneighbor'):
        construct_steps(weights, X, Y_GT, model, target_vars)
    elif FLAGS.task == 'latent':
        construct_latent(weights, X, Y_GT, model, target_vars)

    sess.run(tf.compat.v1.global_variables_initializer())
    saver = tf.compat.v1.train.Saver(max_to_keep=10)
    savedir = osp.join('cachedir', FLAGS.exp)
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if not osp.exists(logdir):
        os.makedirs(logdir)

    initialize()
    if FLAGS.resume_iter != -1:
        model_file = osp.join(savedir, 'model_{}'.format(FLAGS.resume_iter))

        if FLAGS.task in ('label', 'boxcorrupt', 'labelfinetune',
                          'energyeval', 'crossclass', 'mixenergy'):
            # These tasks may add variables not present in the checkpoint,
            # so only restore the intersection.
            optimistic_restore(sess, model_file)
        else:
            saver.restore(sess, model_file)

    # --- evaluation, dispatched on task ---
    if FLAGS.task == 'label':
        if FLAGS.labelgrid:
            # Sweep the projection-ball radius and record accuracy per value.
            vals = []
            if FLAGS.lnorm == -1:
                for i in range(31):
                    accuracies = label(dataloader,
                                       test_dataloader,
                                       target_vars,
                                       sess,
                                       l1val=i)
                    vals.append(accuracies)
            elif FLAGS.lnorm == 2:
                for i in range(0, 100, 5):
                    accuracies = label(dataloader,
                                       test_dataloader,
                                       target_vars,
                                       sess,
                                       l2val=i)
                    vals.append(accuracies)

            np.save("result_{}_{}.npy".format(FLAGS.lnorm, FLAGS.exp), vals)
        else:
            label(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'labelfinetune':
        labelfinetune(dataloader,
                      test_dataloader,
                      target_vars,
                      sess,
                      savedir,
                      saver,
                      # NOTE(review): 'lival' looks like a typo for 'l1val'
                      # (compare l2val below) — confirm against the flag
                      # definitions before changing.
                      l1val=FLAGS.lival,
                      l2val=FLAGS.l2val)
    elif FLAGS.task == 'energyeval':
        energyeval(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'mixenergy':
        energyevalmix(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'anticorrupt':
        anticorrupt(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'boxcorrupt':
        boxcorrupt(test_dataloader, dataloader, weights, model, target_vars,
                   logdir, sess)
    elif FLAGS.task == 'crossclass':
        crossclass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'cycleclass':
        cycleclass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'democlass':
        democlass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'nearestneighbor':
        # Pixel values are rescaled from [0, 255] to [0, 1].
        nearest_neighbor(dataset.data.train_data / 255, sess, target_vars,
                         logdir)
    elif FLAGS.task == 'latent':
        latent(test_dataloader, weights, model, target_vars, sess)
def _restore_scope(sess, scope_idx, save_path):
    """Restore variables living under ``context_{scope_idx}`` from a
    checkpoint whose variables were saved under scope ``context_0``.

    All four conditional models are trained separately, each saving its
    weights from scope 'context_0'; here they coexist in one graph under
    scopes 'context_0' .. 'context_3', so checkpoint names are remapped.
    """
    scope = 'context_{}'.format(scope_idx)
    v_list = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope=scope)
    # Remap this scope's variable names back to 'context_0' and strip the
    # ':0' tensor-output suffix to get checkpoint keys.
    v_map = {v.name.replace(scope, 'context_0')[:-2]: v for v in v_list}
    tf.compat.v1.train.Saver(v_map).restore(sess, save_path)


def main():  # noqa: F811
    """Load dSprites, build four conditionally-trained energy models
    (size / shape / position / rotation), restore their checkpoints, and
    dispatch to the FLAGS.task evaluation routine.

    NOTE(review): this redefines the ``main()`` defined earlier in the
    file — only this definition is callable; confirm that is intended.
    """
    data = np.load(FLAGS.dsprites_path)['imgs']
    latents = np.load(FLAGS.dsprites_path)['latents_values']

    # Fixed seed so the data/latent permutation is reproducible across runs.
    np.random.seed(1)
    idx = np.random.permutation(data.shape[0])
    data = data[idx]
    latents = latents[idx]

    # tf.compat.v1 is used throughout for consistency with the rest of the
    # file (the bare tf.Session / tf.placeholder aliases are TF1-only).
    config = tf.compat.v1.ConfigProto()
    sess = tf.compat.v1.Session(config=config)

    # One model per conditioning factor, each in its own variable scope.
    model_size = DspritesNet(num_filters=FLAGS.num_filters, cond_size=True)
    weight_size = model_size.construct_weights('context_0')

    model_shape = DspritesNet(num_filters=FLAGS.num_filters, cond_shape=True)
    weight_shape = model_shape.construct_weights('context_1')

    model_pos = DspritesNet(num_filters=FLAGS.num_filters, cond_pos=True)
    weight_pos = model_pos.construct_weights('context_2')

    model_rot = DspritesNet(num_filters=FLAGS.num_filters, cond_rot=True)
    weight_rot = model_rot.construct_weights('context_3')

    sess.run(tf.compat.v1.global_variables_initializer())

    # Restore each factor's checkpoint only when its conditioning flag is
    # set.  (Previously this was four copy-pasted stanzas; the pos/rot
    # Savers were also built even when unused.)
    if FLAGS.cond_scale:
        _restore_scope(sess, 0, osp.join(
            FLAGS.logdir, FLAGS.exp_size, 'model_{}'.format(FLAGS.resume_size)))
    if FLAGS.cond_shape:
        _restore_scope(sess, 1, osp.join(
            FLAGS.logdir, FLAGS.exp_shape, 'model_{}'.format(FLAGS.resume_shape)))
    if FLAGS.cond_pos:
        _restore_scope(sess, 2, osp.join(
            FLAGS.logdir, FLAGS.exp_pos, 'model_{}'.format(FLAGS.resume_pos)))
    if FLAGS.cond_rot:
        _restore_scope(sess, 3, osp.join(
            FLAGS.logdir, FLAGS.exp_rot, 'model_{}'.format(FLAGS.resume_rot)))

    # Placeholders for the sampling chain and the per-factor labels.
    X_NOISE = tf.compat.v1.placeholder(shape=(None, 64, 64), dtype=tf.float32)
    LABEL_SIZE = tf.compat.v1.placeholder(shape=(None, 1), dtype=tf.float32)
    LABEL_SHAPE = tf.compat.v1.placeholder(shape=(None, 3), dtype=tf.float32)
    LABEL_POS = tf.compat.v1.placeholder(shape=(None, 2), dtype=tf.float32)
    LABEL_ROT = tf.compat.v1.placeholder(shape=(None, 2), dtype=tf.float32)

    # Bundle of everything the task routines need.
    kvs = {
        'X_NOISE': X_NOISE,
        'LABEL_SIZE': LABEL_SIZE,
        'LABEL_SHAPE': LABEL_SHAPE,
        'LABEL_POS': LABEL_POS,
        'LABEL_ROT': LABEL_ROT,
        'model_size': model_size,
        'model_shape': model_shape,
        'model_pos': model_pos,
        'model_rot': model_rot,
        'weight_size': weight_size,
        'weight_shape': weight_shape,
        'weight_pos': weight_pos,
        'weight_rot': weight_rot,
    }

    save_exp_dir = osp.join(
        FLAGS.savedir, '{}_{}_joint'.format(FLAGS.exp_size, FLAGS.exp_shape))
    if not osp.exists(save_exp_dir):
        os.makedirs(save_exp_dir)

    if FLAGS.task == 'conceptcombine':
        conceptcombine(sess, kvs, data, latents, save_exp_dir)
    elif FLAGS.task == 'labeldiscover':
        labeldiscover(sess, kvs, data, latents, save_exp_dir)
    elif FLAGS.task == 'gentest':
        save_exp_dir = osp.join(
            FLAGS.savedir, '{}_{}_gen'.format(FLAGS.exp_size, FLAGS.exp_pos))
        if not osp.exists(save_exp_dir):
            os.makedirs(save_exp_dir)

        gentest(sess, kvs, data, latents, save_exp_dir)
    elif FLAGS.task == 'genbaseline':
        save_exp_dir = osp.join(
            FLAGS.savedir,
            '{}_{}_gen_baseline'.format(FLAGS.exp_size, FLAGS.exp_pos))
        if not osp.exists(save_exp_dir):
            os.makedirs(save_exp_dir)

        if FLAGS.plot_curve:
            # Sweep the corruption fraction from 0.0 to 1.0 and record MSE.
            mse_losses = []
            for frac in [i / 10 for i in range(11)]:
                mse_loss = genbaseline(
                    sess, kvs, data, latents, save_exp_dir, frac=frac)
                mse_losses.append(mse_loss)
            np.save("mse_baseline_comb.npy", mse_losses)
        else:
            genbaseline(sess, kvs, data, latents, save_exp_dir)