def model(sess, image):
    global _inception_initialized
    network_fn = _get_model(reuse=_inception_initialized)
    size = network_fn.default_image_size
    preprocessed = _preprocess(image, size, size)
    logits, _ = network_fn(preprocessed)
    logits = logits[:, 1:]  # ignore background class
    predictions = tf.argmax(logits, 1)

    if not _inception_initialized:
        optimistic_restore(sess, INCEPTION_CHECKPOINT_PATH)
        _inception_initialized = True

    return logits, predictions
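
# A hedged usage sketch for model() above; the session setup, placeholder
# shape, and feed below are illustrative, not from the source:
#
#   sess = tf.Session()
#   image = tf.placeholder(tf.float32, (None, 299, 299, 3))
#   logits, predictions = model(sess, image)
#   preds = sess.run(predictions, {image: image_batch})
#
# Because the background class is sliced off, `predictions` index the 1000
# foreground classes.
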
def _init_model(sess, checkpoint_name=None):
    global _model_func
    global _obs_shape
    global _model_opt

    if checkpoint_name is None:
        checkpoint_name = _PIXELCNN_CHECKPOINT_NAME
    checkpoint_path = os.path.join(DATA_DIR, checkpoint_name)

    x_init = tf.placeholder(tf.float32, (1,) + _obs_shape)
    model = _model_func(x_init, init=True, dropout_p=0.5, **_model_opt)
    # XXX need to add a scope argument to optimistic_restore and filter for
    # things that start with "{scope}/", so we can filter for "model/", because
    # the pixelcnn checkpoint has some random unscoped stuff like 'Variable'
    optimistic_restore(sess, checkpoint_path)
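
# A minimal sketch of the scope-filtered restore the XXX note above asks for,
# assuming optimistic_restore matches variables against the checkpoint by name
# and shape. The name and signature below are illustrative, not part of the
# project's API.
def optimistic_restore_scoped(sess, checkpoint_path, scope):
    reader = tf.train.NewCheckpointReader(checkpoint_path)
    ckpt_shapes = reader.get_variable_to_shape_map()
    restore_vars = []
    for var in tf.global_variables():
        name = var.name.split(':')[0]
        # Keep only variables under "{scope}/", skipping unscoped checkpoint
        # entries such as a bare 'Variable'.
        if not name.startswith(scope + '/'):
            continue
        if name in ckpt_shapes and var.get_shape().as_list() == ckpt_shapes[name]:
            restore_vars.append(var)
    tf.train.Saver(restore_vars).restore(sess, checkpoint_path)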
def main():

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    logger = TensorBoardOutputFormat(logdir)

    config = tf.ConfigProto()

    sess = tf.Session(config=config)
    LABEL = None
    print("Loading data...")
    if FLAGS.dataset == 'cubes':
        dataset = Cubes(cond_idx=FLAGS.cond_idx)
        test_dataset = dataset

        if FLAGS.cond_idx == 0:
            label_size = 2
        elif FLAGS.cond_idx == 1:
            label_size = 1
        elif FLAGS.cond_idx == 2:
            label_size = 3
        elif FLAGS.cond_idx == 3:
            label_size = 20

        LABEL = tf.placeholder(shape=(None, label_size), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, label_size), dtype=tf.float32)
    elif FLAGS.dataset == 'color':
        dataset = CubesColor()
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 301), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 301), dtype=tf.float32)
        label_size = 301
    elif FLAGS.dataset == 'pos':
        dataset = CubesPos()
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        label_size = 2
    elif FLAGS.dataset == "pairs":
        dataset = Pairs(cond_idx=0)
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 6), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 6), dtype=tf.float32)
        label_size = 6
    elif FLAGS.dataset == "continual":
        dataset = CubesContinual()
        test_dataset = dataset

        if FLAGS.prelearn_model_shape:
            LABEL = tf.placeholder(shape=(None, 20), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 20), dtype=tf.float32)
            label_size = 20
        else:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            label_size = 2

    elif FLAGS.dataset == "cross":
        dataset = CubesCrossProduct(FLAGS.ratio, cond_size=FLAGS.cond_size, cond_pos=FLAGS.cond_pos, joint_baseline=FLAGS.joint_baseline)
        test_dataset = dataset

        if FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            label_size = 1
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            label_size = 2

        if FLAGS.joint_baseline:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            label_size = 3

    elif FLAGS.dataset == 'celeba':
        dataset = CelebA(cond_idx=FLAGS.celeba_cond_idx)
        test_dataset = dataset
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)

        model = ResNet128(
            num_channels=channel_num,
            num_filters=64,
            classes=2)

    if FLAGS.joint_baseline:
        # Other stuff for joint model
        optimizer = AdamOptimizer(FLAGS.lr, beta1=0.99, beta2=0.999)

        X = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        X_NOISE = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        ATTENTION_MASK = tf.placeholder(shape=(None, 64, 64, FLAGS.cond_func), dtype=tf.float32)
        NOISE = tf.placeholder(shape=(None, 128), dtype=tf.float32)
        HIER_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)

        channel_num = 3

        model = CubesNetGen(num_channels=channel_num, label_size=label_size)
        weights = model.construct_weights('context_0')
        output = model.forward(NOISE, weights, reuse=False, label=LABEL)
        print(output.get_shape())
        mse_loss = tf.reduce_mean(tf.square(output - X))
        gvs = optimizer.compute_gradients(mse_loss)
        train_op = optimizer.apply_gradients(gvs)
        gvs = [(k, v) for (k, v) in gvs if k is not None]

        target_vars = {}
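        # The joint baseline trains the generator with a plain MSE loss; the
        # zero tensors below stand in for the EBM outputs so the shared
        # train/test loops can fetch the same keys.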
        target_vars['train_op'] = train_op
        target_vars['X'] = X
        target_vars['X_NOISE'] = X_NOISE
        target_vars['ATTENTION_MASK'] = ATTENTION_MASK
        target_vars['eps_begin'] = tf.zeros(1)
        target_vars['gvs'] = gvs
        target_vars['energy_pos'] = tf.zeros(1)
        target_vars['energy_neg'] = tf.zeros(1)
        target_vars['loss_energy'] = tf.zeros(1)
        target_vars['loss_ml'] = tf.zeros(1)
        target_vars['total_loss'] = mse_loss
        target_vars['attention_mask'] = tf.zeros(1)
        target_vars['attention_grad'] = tf.zeros(1)
        target_vars['x_off'] = tf.reduce_mean(tf.abs(output - X))
        target_vars['x_mod'] = tf.zeros(1)
        target_vars['x_grad'] = tf.zeros(1)
        target_vars['NOISE'] = NOISE
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['HIER_LABEL'] = HIER_LABEL

        data_loader = DataLoader(
            dataset,
            batch_size=FLAGS.batch_size,
            num_workers=FLAGS.data_workers,
            drop_last=True,
            shuffle=True)
    else:
        print("label size here ", label_size)
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        HEIR_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        ATTENTION_MASK = tf.placeholder(shape=(None, 64, 64, FLAGS.cond_func), dtype=tf.float32)

        if FLAGS.dataset != "celeba":
            model = CubesNet(num_channels=channel_num, label_size=label_size)

        heir_model = HeirNet(num_channels=FLAGS.cond_func)

        models_pretrain = []
        if FLAGS.prelearn_model:
            model_prelearn = CubesNet(num_channels=channel_num, label_size=FLAGS.prelearn_label)
            weights = model_prelearn.construct_weights('context_1')
            LABEL_PRELEARN = tf.placeholder(shape=(None, FLAGS.prelearn_label), dtype=tf.float32)
            models_pretrain.append((model_prelearn, weights, LABEL_PRELEARN))

            cubes_logdir = osp.join(FLAGS.logdir, FLAGS.prelearn_exp)
            if (FLAGS.prelearn_iter != -1 or not FLAGS.train):
                model_file = osp.join(cubes_logdir, 'model_{}'.format(FLAGS.prelearn_iter))
                resume_itr = FLAGS.resume_iter
                # saver.restore(sess, model_file)

                v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1))
                v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list}
                saver = tf.train.Saver(v_map)
                saver.restore(sess, model_file)

        if FLAGS.prelearn_model_shape:
            model_prelearn = CubesNet(num_channels=channel_num, label_size=FLAGS.prelearn_label_shape)
            weights = model_prelearn.construct_weights('context_2')
            LABEL_PRELEARN = tf.placeholder(shape=(None, FLAGS.prelearn_label_shape), dtype=tf.float32)
            models_pretrain.append((model_prelearn, weights, LABEL_PRELEARN))

            cubes_logdir = osp.join(FLAGS.logdir, FLAGS.prelearn_exp_shape)
            if (FLAGS.prelearn_iter_shape != -1 or not FLAGS.train):
                model_file = osp.join(cubes_logdir, 'model_{}'.format(FLAGS.prelearn_iter_shape))
                resume_itr = FLAGS.resume_iter
                # saver.restore(sess, model_file)

                v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2))
                v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list}
                saver = tf.train.Saver(v_map)
                saver.restore(sess, model_file)

        print("Done loading...")

        data_loader = DataLoader(
            dataset,
            batch_size=FLAGS.batch_size,
            num_workers=FLAGS.data_workers,
            drop_last=True,
            shuffle=True)

        batch_size = FLAGS.batch_size

        weights = model.construct_weights('context_0')

        if FLAGS.heir_mask:
            weights = heir_model.construct_weights('heir_0', weights=weights)

        Y = tf.placeholder(shape=(None), dtype=tf.int32)

        # Variables to run in training

        X_SPLIT = tf.split(X, FLAGS.num_gpus)
        X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
        LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
        LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
        LABEL_SPLIT_INIT = list(LABEL_SPLIT)
        attention_mask = ATTENTION_MASK
        tower_grads = []
        tower_gen_grads = []
        x_mod_list = []

        optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.99)

        for j in range(FLAGS.num_gpus):

            x_mod = X_SPLIT[j]
            if FLAGS.comb_mask:
                steps = tf.constant(0)
                c = lambda i, x: tf.less(i, FLAGS.num_steps)

                def langevin_attention_step(counter, attention_mask):
                    attention_mask = attention_mask + tf.random_normal(tf.shape(attention_mask), mean=0.0, stddev=0.01)
                    energy_noise = energy_start = model.forward(
                                x_mod,
                                weights,
                                attention_mask,
                                label=LABEL_SPLIT[j],
                                reuse=True,
                                stop_at_grad=False,
                                stop_batch=True)

                    if FLAGS.heir_mask:
                        energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                        energy_noise = energy_noise + energy_heir

                    attention_grad = tf.gradients(
                        FLAGS.temperature * energy_noise, [attention_mask])[0]
                    energy_noise_old = energy_noise

                    # Clip gradient norm for now
                    attention_mask = attention_mask - (FLAGS.attention_lr) * attention_grad
                    attention_mask = tf.layers.average_pooling2d(attention_mask, (3, 3), 1, padding='SAME')
                    attention_mask = tf.stop_gradient(attention_mask)

                    counter = counter + 1

                    return counter, attention_mask

                steps, attention_mask = tf.while_loop(c, langevin_attention_step, (steps, attention_mask))

                # attention_mask = tf.Print(attention_mask, [attention_mask])

                energy_pos = model.forward(
                        X_SPLIT[j],
                        weights,
                        tf.stop_gradient(attention_mask),
                        label=LABEL_POS_SPLIT[j],
                        stop_at_grad=False)

                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_pos = energy_heir + energy_pos

            else:
                energy_pos = model.forward(
                        X_SPLIT[j],
                        weights,
                        attention_mask,
                        label=LABEL_POS_SPLIT[j],
                        stop_at_grad=False)

                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_pos = energy_heir + energy_pos

            print("Building graph...")
            x_mod = x_orig = X_NOISE_SPLIT[j]

            x_grads = []

            loss_energys = []

            eps_begin = tf.zeros(1)

            steps = tf.constant(0)
            c_cond = lambda i, x, y: tf.less(i, FLAGS.num_steps)

            def langevin_step(counter, x_mod, attention_mask):
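                # One Langevin iteration: perturb with Gaussian noise, step
                # against the gradient of the temperature-scaled energy, and
                # clip back to [0, FLAGS.rescale]; when comb_mask is set, the
                # attention mask gets the analogous update plus average
                # pooling.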

                lr = FLAGS.step_lr

                x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
                                                 mean=0.0,
                                                 stddev=0.001 * FLAGS.rescale * FLAGS.noise_scale)
                attention_mask = attention_mask + tf.random_normal(tf.shape(attention_mask), mean=0.0, stddev=0.01)

                energy_noise = model.forward(
                            x_mod,
                            weights,
                            attention_mask,
                            label=LABEL_SPLIT[j],
                            reuse=True,
                            stop_at_grad=False,
                            stop_batch=True)

                if FLAGS.prelearn_model:
                    for m_i, w_i, l_i in models_pretrain:
                        energy_noise = energy_noise + m_i.forward(
                                    x_mod,
                                    w_i,
                                    attention_mask,
                                    label=l_i,
                                    reuse=True,
                                    stop_at_grad=False,
                                    stop_batch=True)


                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_noise = energy_heir + energy_noise

                x_grad, attention_grad = tf.gradients(
                    FLAGS.temperature * energy_noise, [x_mod, attention_mask])

                if not FLAGS.comb_mask:
                    attention_grad = tf.zeros(1)
                energy_noise_old = energy_noise

                if FLAGS.proj_norm != 0.0:
                    if FLAGS.proj_norm_type == 'l2':
                        x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                    elif FLAGS.proj_norm_type == 'li':
                        x_grad = tf.clip_by_value(
                            x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
                    else:
                        raise ValueError(
                            "Other types of projection are not supported.")

                # Clip gradient norm for now
                x_last = x_mod - (lr) * x_grad

                if FLAGS.comb_mask:
                    attention_mask = attention_mask - FLAGS.attention_lr * attention_grad
                    attention_mask = tf.layers.average_pooling2d(attention_mask, (3, 3), 1, padding='SAME')
                    attention_mask = tf.stop_gradient(attention_mask)

                x_mod = x_last
                x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

                counter = counter + 1

                return counter, x_mod, attention_mask


            steps, x_mod, attention_mask = tf.while_loop(c_cond, langevin_step, (steps, x_mod, attention_mask))

            attention_mask = tf.stop_gradient(attention_mask)
            # attention_mask = tf.Print(attention_mask, [attention_mask])

            energy_eval = model.forward(x_mod, weights, attention_mask, label=LABEL_SPLIT[j],
                                        stop_at_grad=False, reuse=True)
            x_grad, attention_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod, attention_mask])
            x_grads.append(x_grad)

            energy_neg = model.forward(
                    tf.stop_gradient(x_mod),
                    weights,
                    tf.stop_gradient(attention_mask),
                    label=LABEL_SPLIT[j],
                    stop_at_grad=False,
                    reuse=True)

            if FLAGS.heir_mask:
                energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                energy_neg = energy_heir + energy_neg


            temp = FLAGS.temperature

            x_off = tf.reduce_mean(
                tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

            loss_energy = model.forward(
                x_mod,
                weights,
                attention_mask,
                reuse=True,
                label=LABEL,
                stop_grad=True)

            print("Finished processing loop construction ...")

            target_vars = {}

            if FLAGS.antialias:
                antialias = tf.tile(stride_3, (1, 1, tf.shape(x_mod)[3], tf.shape(x_mod)[3]))
                inp = tf.nn.conv2d(x_mod, antialias, [1, 2, 2, 1], padding='SAME')

            test_x_mod = x_mod

            if FLAGS.cclass or FLAGS.model_cclass:
                label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
                label_prob = label_sum / tf.reduce_sum(label_sum)
                label_ent = -tf.reduce_sum(label_prob *
                                           tf.math.log(label_prob + 1e-7))
            else:
                label_ent = tf.zeros(1)

            target_vars['label_ent'] = label_ent

            if FLAGS.train:
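                # 'logsumexp' reweights each negative sample with
                # stop-gradient softmax weights exp(-temp * E_neg) / Z (a
                # self-normalized importance-weighted negative term); 'cd' is
                # plain contrastive divergence, mean(E_pos) - mean(E_neg).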
                if FLAGS.objective == 'logsumexp':
                    pos_term = temp * energy_pos
                    energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                    coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                    norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                    pos_loss = tf.reduce_mean(temp * energy_pos)
                    neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                    loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
                elif FLAGS.objective == 'cd':
                    pos_loss = tf.reduce_mean(temp * energy_pos)
                    neg_loss = -tf.reduce_mean(temp * energy_neg)
                    loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
                elif FLAGS.objective == 'softplus':
                    loss_ml = FLAGS.ml_coeff * \
                        tf.nn.softplus(temp * (energy_pos - energy_neg))

                loss_total = tf.reduce_mean(loss_ml)

                if not FLAGS.zero_kl:
                    loss_total = loss_total + tf.reduce_mean(loss_energy)

                loss_total = loss_total + \
                    FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

                print("Started gradient computation...")
                gvs = optimizer.compute_gradients(loss_total)
                gvs = [(k, v) for (k, v) in gvs if k is not None]

                print("Applying gradients...")

                tower_grads.append(gvs)

                print("Finished applying gradients.")

                target_vars['loss_ml'] = loss_ml
                target_vars['total_loss'] = loss_total
                target_vars['loss_energy'] = loss_energy
                target_vars['weights'] = weights
                target_vars['gvs'] = gvs

            target_vars['X'] = X
            target_vars['Y'] = Y
            target_vars['LABEL'] = LABEL
            target_vars['HIER_LABEL'] = HEIR_LABEL
            target_vars['LABEL_POS'] = LABEL_POS
            target_vars['X_NOISE'] = X_NOISE
            target_vars['energy_pos'] = energy_pos
            target_vars['attention_grad'] = attention_grad

            if len(x_grads) >= 1:
                target_vars['x_grad'] = x_grads[-1]
                target_vars['x_grad_first'] = x_grads[0]
            else:
                target_vars['x_grad'] = tf.zeros(1)
                target_vars['x_grad_first'] = tf.zeros(1)

            target_vars['x_mod'] = x_mod
            target_vars['x_off'] = x_off
            target_vars['temp'] = temp
            target_vars['energy_neg'] = energy_neg
            target_vars['test_x_mod'] = test_x_mod
            target_vars['eps_begin'] = eps_begin
            target_vars['ATTENTION_MASK'] = ATTENTION_MASK
            target_vars['models_pretrain'] = models_pretrain
            if FLAGS.comb_mask:
                target_vars['attention_mask'] = tf.nn.softmax(attention_mask)
            else:
                target_vars['attention_mask'] = tf.zeros(1)

        if FLAGS.train:
            grads = average_gradients(tower_grads)
            train_op = optimizer.apply_gradients(grads)
            target_vars['train_op'] = train_op

    # sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(
        max_to_keep=30, keep_checkpoint_every_n_hours=6)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    if (FLAGS.resume_iter != -1 or not FLAGS.train):
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    print("Initializing variables...")

    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        train(target_vars, saver, sess,
              logger, data_loader, resume_itr,
              logdir)

    test(target_vars, saver, sess, logger, data_loader)
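
# A minimal sketch of the `average_gradients` helper used above, assuming it
# follows the classic TensorFlow multi-tower pattern of averaging each
# variable's gradient across the per-GPU (grad, var) lists; illustrative, not
# the project's actual implementation.
def average_gradients(tower_grads):
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars: ((grad_gpu0, var), (grad_gpu1, var), ...)
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), 0)
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads
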
def main():
    print("Local rank: ", hvd.local_rank(), hvd.size())

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if hvd.rank() == 0:
        if not osp.exists(logdir):
            os.makedirs(logdir)
        logger = TensorBoardOutputFormat(logdir)
    else:
        logger = None

    LABEL = None
    print("Loading data...")
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
        test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        if FLAGS.large_model:
            model = ResNet32Large(num_channels=channel_num,
                                  num_filters=128,
                                  train=True)
        elif FLAGS.larger_model:
            model = ResNet32Larger(num_channels=channel_num, num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(num_channels=channel_num, num_filters=192)
        else:
            model = ResNet32(num_channels=channel_num, num_filters=128)

    elif FLAGS.dataset == 'imagenet':
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet32Wider(num_channels=channel_num, num_filters=256)

    elif FLAGS.dataset == 'imagenetfull':
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet128(num_channels=channel_num, num_filters=64)

    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(rescale=FLAGS.rescale)
        test_dataset = dataset
        channel_num = 1
        X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        model = MnistNet(num_channels=channel_num,
                         num_filters=FLAGS.num_filters)

    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites(cond_shape=FLAGS.cond_shape,
                           cond_size=FLAGS.cond_size,
                           cond_pos=FLAGS.cond_pos,
                           cond_rot=FLAGS.cond_rot)
        test_dataset = dataset
        channel_num = 1

        X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

        if FLAGS.dpos_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.dsize_only:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.drot_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.cond_shape:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_rot:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        else:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)

        model = DspritesNet(num_channels=channel_num,
                            num_filters=FLAGS.num_filters,
                            cond_size=FLAGS.cond_size,
                            cond_shape=FLAGS.cond_shape,
                            cond_pos=FLAGS.cond_pos,
                            cond_rot=FLAGS.cond_rot)

    print("Done loading...")

    if FLAGS.dataset == "imagenetfull":
        # In the case of full ImageNet, use the custom TensorFlow data loader
        data_loader = TFImagenetLoader('train',
                                       FLAGS.batch_size,
                                       hvd.rank(),
                                       hvd.size(),
                                       rescale=FLAGS.rescale)
    else:
        data_loader = DataLoader(dataset,
                                 batch_size=FLAGS.batch_size,
                                 num_workers=FLAGS.data_workers,
                                 drop_last=True,
                                 shuffle=True)

    batch_size = FLAGS.batch_size

    weights = [model.construct_weights('context_0')]

    Y = tf.placeholder(shape=(None), dtype=tf.int32)

    # Variables to run in training
    X_SPLIT = tf.split(X, FLAGS.num_gpus)
    X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
    LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
    LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
    LABEL_SPLIT_INIT = list(LABEL_SPLIT)
    tower_grads = []
    tower_gen_grads = []
    x_mod_list = []

    optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer = hvd.DistributedOptimizer(optimizer)

    for j in range(FLAGS.num_gpus):

        if FLAGS.model_cclass:
            ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
            label_tensor = tf.Variable(tf.convert_to_tensor(np.reshape(
                np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
                (FLAGS.batch_size * 10, 10)),
                                                            dtype=tf.float32),
                                       trainable=False,
                                       dtype=tf.float32)
            x_split = tf.tile(
                tf.reshape(X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)),
                (1, 10, 1, 1, 1))
            x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
            energy_pos = model.forward(x_split,
                                       weights[0],
                                       label=label_tensor,
                                       stop_at_grad=False)

            energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
            energy_partition_est = tf.reduce_logsumexp(energy_pos_full,
                                                       axis=1,
                                                       keepdims=True)
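            # Gumbel-max sampling: argmax_i(-E_i - log(-log(U_i))) draws a
            # label from softmax(-E); subtracting the per-row log-partition
            # estimate only shifts the logits and leaves the argmax unchanged.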
            uniform = tf.random_uniform(tf.shape(energy_pos_full))
            label_tensor = tf.argmax(-energy_pos_full -
                                     tf.log(-tf.log(uniform)) -
                                     energy_partition_est,
                                     axis=1)
            label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
            label = tf.Print(label, [label_tensor, energy_pos_full])
            LABEL_SPLIT[j] = label
            energy_pos = tf.concat(energy_pos, axis=0)
        else:
            energy_pos = [
                model.forward(X_SPLIT[j],
                              weights[0],
                              label=LABEL_POS_SPLIT[j],
                              stop_at_grad=False)
            ]
            energy_pos = tf.concat(energy_pos, axis=0)

        print("Building graph...")
        x_mod = x_orig = X_NOISE_SPLIT[j]

        x_grads = []

        energy_negs = []
        loss_energys = []

        energy_negs.extend([
            model.forward(tf.stop_gradient(x_mod),
                          weights[0],
                          label=LABEL_SPLIT[j],
                          stop_at_grad=False,
                          reuse=True)
        ])
        eps_begin = tf.zeros(1)

        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)

        def langevin_step(counter, x_mod):
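            # One sampler step: add Gaussian noise, then either take a
            # gradient step on the temperature-scaled energy or, when
            # FLAGS.hmc is set, hand off to the HMC transition; samples are
            # clipped to [0, FLAGS.rescale].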
            x_mod = x_mod + tf.random_normal(
                tf.shape(x_mod),
                mean=0.0,
                stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = energy_start = tf.concat([
                model.forward(x_mod,
                              weights[0],
                              label=LABEL_SPLIT[j],
                              reuse=True,
                              stop_at_grad=False,
                              stop_batch=True)
            ],
                                                    axis=0)

            x_grad, label_grad = tf.gradients(FLAGS.temperature * energy_noise,
                                              [x_mod, LABEL_SPLIT[j]])
            energy_noise_old = energy_noise

            lr = FLAGS.step_lr

            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm,
                                              FLAGS.proj_norm)
                else:
                    raise ValueError(
                        "Other types of projection are not supported.")

            # Clip gradient norm for now
            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - (lr) * x_grad

            x_mod = x_last
            x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

            counter = counter + 1

            return counter, x_mod

        steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))

        energy_eval = model.forward(x_mod,
                                    weights[0],
                                    label=LABEL_SPLIT[j],
                                    stop_at_grad=False,
                                    reuse=True)
        x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        energy_negs.append(
            model.forward(tf.stop_gradient(x_mod),
                          weights[0],
                          label=LABEL_SPLIT[j],
                          stop_at_grad=False,
                          reuse=True))

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        loss_energy = model.forward(x_mod,
                                    weights[0],
                                    reuse=True,
                                    label=LABEL,
                                    stop_grad=True)

        print("Finished processing loop construction ...")

        target_vars = {}

        if FLAGS.cclass or FLAGS.model_cclass:
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(
                label_prob * tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:

            if FLAGS.objective == 'logsumexp':
                pos_term = temp * energy_pos
                energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'cd':
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = -tf.reduce_mean(temp * energy_neg)
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'softplus':
                loss_ml = FLAGS.ml_coeff * \
                    tf.nn.softplus(temp * (energy_pos - energy_neg))

            loss_total = tf.reduce_mean(loss_ml)

            if not FLAGS.zero_kl:
                loss_total = loss_total + tf.reduce_mean(loss_energy)

            loss_total = loss_total + \
                FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

            print("Started gradient computation...")
            gvs = optimizer.compute_gradients(loss_total)
            gvs = [(k, v) for (k, v) in gvs if k is not None]

            print("Applying gradients...")

            tower_grads.append(gvs)

            print("Finished applying gradients.")

            target_vars['loss_ml'] = loss_ml
            target_vars['total_loss'] = loss_total
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin

    if FLAGS.train:
        grads = average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(grads)
        target_vars['train_op'] = train_op

    config = tf.ConfigProto()

    if hvd.size() > 1:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(max_to_keep=30,
                                    keep_checkpoint_every_n_hours=6)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    sess.run(hvd.broadcast_global_variables(0))
    print("Initializing variables...")

    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        print("Training phase")
        train(target_vars, saver, sess, logger, data_loader, resume_itr,
              logdir)
    print("Testing phase")
    test(target_vars, saver, sess, logger, data_loader)
def main():

    # Initialize dataset
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3
        dim_input = 32 * 32 * 3
    elif FLAGS.dataset == 'imagenet':
        dataset = ImagenetClass()
        channel_num = 3
        dim_input = 64 * 64 * 3
    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(train=False, rescale=FLAGS.rescale)
        channel_num = 1
        dim_input = 28 * 28 * 1
    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites()
        channel_num = 1
        dim_input = 64 * 64 * 1
    elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
        dataset = Box2D()

    dim_output = 1
    data_loader = DataLoader(dataset,
                             batch_size=FLAGS.batch_size,
                             num_workers=FLAGS.data_workers,
                             drop_last=False,
                             shuffle=True)

    if FLAGS.dataset == 'mnist':
        model = MnistNet(num_channels=channel_num)
    elif FLAGS.dataset == 'cifar10':
        if FLAGS.large_model:
            model = ResNet32Large(num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(num_filters=192)
        else:
            model = ResNet32(num_channels=channel_num, num_filters=128)
    elif FLAGS.dataset == 'dsprites':
        model = DspritesNet(num_channels=channel_num,
                            num_filters=FLAGS.num_filters)

    weights = model.construct_weights('context_{}'.format(0))

    config = tf.ConfigProto()
    sess = tf.Session(config=config)
    saver = loader = tf.train.Saver(max_to_keep=10)

    sess.run(tf.global_variables_initializer())
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)

    model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
    resume_itr = FLAGS.resume_iter

    if FLAGS.resume_iter != "-1":
        optimistic_restore(sess, model_file)
    else:
        print("WARNING, YOU ARE NOT LOADING A SAVE FILE")
    # saver.restore(sess, model_file)

    chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(
        model, weights, FLAGS.batch_size, temp=FLAGS.temperature)
    print("Finished constructing ancestral sample ...................")

    if FLAGS.dataset != "gauss":
        comb_weights_cum = []
        batch_size = tf.shape(x_init)[0]
        label_tiled = tf.tile(label_default, (batch_size, 1))
        e_compute = -FLAGS.temperature * model.forward(
            x_init, weights, label=label_tiled)
        e_pos_list = []

        for data_corrupt, data, label_gt in tqdm(data_loader):
            e_pos = sess.run([e_compute], {x_init: data})[0]
            e_pos_list.extend(list(e_pos))

        print(len(e_pos_list))
        print("Positive sample probability ", np.mean(e_pos_list),
              np.std(e_pos_list))

    if FLAGS.dataset == "2d":
        alr = 0.0045
    elif FLAGS.dataset == "gauss":
        alr = 0.0085
    elif FLAGS.dataset == "mnist":
        alr = 0.0065
        # 90: alr = 0.0035
    else:
        # alr = 0.0125
        if FLAGS.rescale == 8:
            alr = 0.0085
        else:
            alr = 0.0045


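    # Annealed importance sampling: the forward sweep of the annealing
    # parameter (0 -> 1) accumulates chain weights estimating a lower bound,
    # and the reverse sweep (1 -> 0) the matching upper bound, which is what
    # the two prints below report.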
    for i in range(1):
        tot_weight = 0
        for j in tqdm(range(1, FLAGS.pdist + 1)):
            if j == 1:
                if FLAGS.dataset == "cifar10":
                    x_curr = np.random.uniform(0,
                                               FLAGS.rescale,
                                               size=(FLAGS.batch_size, 32, 32,
                                                     3))
                elif FLAGS.dataset == "gauss":
                    x_curr = np.random.uniform(0,
                                               FLAGS.rescale,
                                               size=(FLAGS.batch_size,
                                                     FLAGS.gauss_dim))
                elif FLAGS.dataset == "mnist":
                    x_curr = np.random.uniform(0,
                                               FLAGS.rescale,
                                               size=(FLAGS.batch_size, 28, 28))
                else:
                    x_curr = np.random.uniform(0,
                                               FLAGS.rescale,
                                               size=(FLAGS.batch_size, 2))

            alpha_prev = (j - 1) / FLAGS.pdist
            alpha_new = j / FLAGS.pdist
            cweight, x_curr = sess.run(
                [chain_weights, x], {
                    a_prev: alpha_prev,
                    a_new: alpha_new,
                    x_init: x_curr,
                    approx_lr: alr * (5**(2.5 * -alpha_prev))
                })
            tot_weight = tot_weight + cweight

        print("Total values of lower value based off forward sampling",
              np.mean(tot_weight), np.std(tot_weight))

        tot_weight = 0

        for j in tqdm(range(FLAGS.pdist, 0, -1)):
            alpha_new = (j - 1) / FLAGS.pdist
            alpha_prev = j / FLAGS.pdist
            cweight, x_curr = sess.run(
                [chain_weights, x], {
                    a_prev: alpha_prev,
                    a_new: alpha_new,
                    x_init: x_curr,
                    approx_lr: alr * (5**(2.5 * -alpha_prev))
                })
            tot_weight = tot_weight - cweight

        print("Total values of upper value based off backward sampling",
              np.mean(tot_weight), np.std(tot_weight))
def main():

    if FLAGS.dataset == "cifar10":
        dataset = Cifar10(train=True, noise=False)
        test_dataset = Cifar10(train=False, noise=False)
    else:
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)

    if FLAGS.svhn:
        dataset = Svhn(train=True)
        test_dataset = Svhn(train=False)

    if FLAGS.task == 'latent':
        dataset = DSprites()
        test_dataset = dataset

    dataloader = DataLoader(dataset,
                            batch_size=FLAGS.batch_size,
                            num_workers=FLAGS.data_workers,
                            shuffle=True,
                            drop_last=True)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=FLAGS.batch_size,
                                 num_workers=FLAGS.data_workers,
                                 shuffle=True,
                                 drop_last=True)

    hidden_dim = 128

    if FLAGS.large_model:
        model = ResNet32Large(num_filters=hidden_dim)
    elif FLAGS.larger_model:
        model = ResNet32Larger(num_filters=hidden_dim)
    elif FLAGS.wider_model:
        if FLAGS.dataset == 'imagenet':
            model = ResNet32Wider(num_filters=196, train=False)
        else:
            model = ResNet32Wider(num_filters=256, train=False)
    else:
        model = ResNet32(num_filters=hidden_dim)

    if FLAGS.task == 'latent':
        model = DspritesNet()

    weights = model.construct_weights('context_{}'.format(0))

    total_parameters = 0
    for variable in tf.compat.v1.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    config = tf.compat.v1.ConfigProto()
    sess = tf.compat.v1.InteractiveSession()

    if FLAGS.task == 'latent':
        X = tf.compat.v1.placeholder(shape=(None, 64, 64), dtype=tf.float32)
    else:
        X = tf.compat.v1.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)

    if FLAGS.dataset == "cifar10":
        Y = tf.compat.v1.placeholder(shape=(None, 10), dtype=tf.float32)
        Y_GT = tf.compat.v1.placeholder(shape=(None, 10), dtype=tf.float32)
    elif FLAGS.dataset == "imagenet":
        Y = tf.compat.v1.placeholder(shape=(None, 1000), dtype=tf.float32)
        Y_GT = tf.compat.v1.placeholder(shape=(None, 1000), dtype=tf.float32)

    target_vars = {'X': X, 'Y': Y, 'Y_GT': Y_GT}

    if FLAGS.task == 'label':
        construct_label(weights, X, Y, Y_GT, model, target_vars)
    elif FLAGS.task == 'labelfinetune':
        construct_finetune_label(
            weights,
            X,
            Y,
            Y_GT,
            model,
            target_vars,
        )
    elif FLAGS.task == 'energyeval' or FLAGS.task == 'mixenergy':
        construct_energy(weights, X, Y, Y_GT, model, target_vars)
    elif FLAGS.task in ('anticorrupt', 'boxcorrupt', 'crossclass',
                        'cycleclass', 'democlass', 'nearestneighbor'):
        construct_steps(weights, X, Y_GT, model, target_vars)
    elif FLAGS.task == 'latent':
        construct_latent(weights, X, Y_GT, model, target_vars)

    sess.run(tf.compat.v1.global_variables_initializer())
    saver = loader = tf.compat.v1.train.Saver(max_to_keep=10)
    savedir = osp.join('cachedir', FLAGS.exp)
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if not osp.exists(logdir):
        os.makedirs(logdir)

    initialize()
    if FLAGS.resume_iter != -1:
        model_file = osp.join(savedir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter

        if FLAGS.task in ('label', 'boxcorrupt', 'labelfinetune',
                          'energyeval', 'crossclass', 'mixenergy'):
            optimistic_restore(sess, model_file)
            # saver.restore(sess, model_file)
        else:
            # optimistic_restore(sess, model_file)
            saver.restore(sess, model_file)

    if FLAGS.task == 'label':
        if FLAGS.labelgrid:
            vals = []
            if FLAGS.lnorm == -1:
                for i in range(31):
                    accuracies = label(dataloader,
                                       test_dataloader,
                                       target_vars,
                                       sess,
                                       l1val=i)
                    vals.append(accuracies)
            elif FLAGS.lnorm == 2:
                for i in range(0, 100, 5):
                    accuracies = label(dataloader,
                                       test_dataloader,
                                       target_vars,
                                       sess,
                                       l2val=i)
                    vals.append(accuracies)

            np.save("result_{}_{}.npy".format(FLAGS.lnorm, FLAGS.exp), vals)
        else:
            label(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'labelfinetune':
        labelfinetune(dataloader,
                      test_dataloader,
                      target_vars,
                      sess,
                      savedir,
                      saver,
                      l1val=FLAGS.lival,
                      l2val=FLAGS.l2val)
    elif FLAGS.task == 'energyeval':
        energyeval(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'mixenergy':
        energyevalmix(dataloader, test_dataloader, target_vars, sess)
    elif FLAGS.task == 'anticorrupt':
        anticorrupt(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'boxcorrupt':
        # boxcorrupt(test_dataloader, weights, model, target_vars, logdir, sess)
        boxcorrupt(test_dataloader, dataloader, weights, model, target_vars,
                   logdir, sess)
    elif FLAGS.task == 'crossclass':
        crossclass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'cycleclass':
        cycleclass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'democlass':
        democlass(test_dataloader, weights, model, target_vars, logdir, sess)
    elif FLAGS.task == 'nearestneighbor':
        # print(dir(dataset))
        # print(type(dataset))
        nearest_neighbor(dataset.data.train_data / 255, sess, target_vars,
                         logdir)
    elif FLAGS.task == 'latent':
        latent(test_dataloader, weights, model, target_vars, sess)
def main(argv=()):
    del argv

    batch_size = FLAGS.batch_size_per_gpu * FLAGS.num_gpus

    data_stream_init = utils.setup_data_stream_genome(
        "train",
        batch_size=FLAGS.init_batch_size,
        image_res=FLAGS.image_res,
    )
    (image_init_batch, class_init_batch, box_init_batch) = data_stream_init

    data_stream_train = utils.setup_data_stream_genome(
        "train", batch_size=batch_size, image_res=FLAGS.image_res)
    (image_train_batch, class_train_batch, box_train_batch) = data_stream_train

    data_stream_val = utils.setup_data_stream_genome("val",
                                                     batch_size=batch_size,
                                                     image_res=FLAGS.image_res)
    (image_val_batch, class_val_batch, box_val_batch) = data_stream_val

    def model_template(images, labels, boxes, stage):
        return models.model_detection(images, labels, boxes, stage)

    model_factory = tf.make_template("detection", model_template)

    tf.GLOBAL = {}

    # Init
    tf.GLOBAL["init"] = True
    tf.GLOBAL["dropout"] = 0.0

    with tf.device("/cpu:0"):
        _ = model_factory(image_init_batch, [class_init_batch], box_init_batch,
                          0)
    ## Train
    tf.GLOBAL["init"] = False
    tf.GLOBAL["dropout"] = 0.5

    imgs_train = tf.split(image_train_batch, FLAGS.num_gpus, 0)
    class_train = tf.split(class_train_batch, FLAGS.num_gpus, 0)
    boxes_train = tf.split(box_train_batch, FLAGS.num_gpus, 0)

    min_stage = tf.placeholder(shape=[], dtype=tf.int32)
    stage_train = tf.random_uniform([], min_stage, 5, dtype=tf.int32)
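    # Each training step trains a single stage sampled uniformly from
    # [min_stage, 5), so one train op covers all five detection stages.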

    loss_train = 0.0
    for i in range(FLAGS.num_gpus):
        with tf.device("gpu:%i" % i if FLAGS.mode == "gpu" else "/cpu:0"):
            _, loss = model_factory(imgs_train[i], [class_train[i]],
                                    boxes_train[i], stage_train)
            loss_train = loss_train + loss

    loss_train /= FLAGS.num_gpus

    # Optimization
    learning_rate = tf.Variable(0.0001)
    update_lr = learning_rate.assign(FLAGS.decay * learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate, 0.95, 0.9995)
    train_step = optimizer.minimize(loss_train,
                                    colocate_gradients_with_ops=True)

    train_bpd_ph = tf.placeholder(shape=[], dtype=tf.float32)
    summary_train = {
        i: tf.summary.scalar("train_bpd_stage%i" % i, train_bpd_ph)
        for i in range(5)
    }

    ## Val
    tf.GLOBAL["init"] = False
    tf.GLOBAL["dropout"] = 0.0

    imgs_val = tf.split(image_val_batch, FLAGS.num_gpus, 0)
    class_val = tf.split(class_val_batch, FLAGS.num_gpus, 0)
    boxes_val = tf.split(box_val_batch, FLAGS.num_gpus, 0)
    stage_val = tf.random_uniform([], 0, 5, dtype=tf.int32)

    loss_val = 0.0
    label_p_val, point_p_val = [], []
    for i in range(FLAGS.num_gpus):
        with tf.device("gpu:%i" % i if FLAGS.mode == "gpu" else "/cpu:0"):
            [label_p_v,
             point_p_v], loss = model_factory(imgs_val[i], [class_val[i]],
                                              boxes_val[i], stage_val)
            loss_val = loss_val + loss
            label_p_val.append(label_p_v)
            point_p_val.append(point_p_v)

    loss_val /= FLAGS.num_gpus
    label_p_val = [tf.concat(l, axis=0) for l in zip(*label_p_val)]
    point_p_val = [tf.concat(l, axis=0) for l in zip(*point_p_val)]

    val_bpd_ph = tf.placeholder(shape=[], dtype=tf.float32)
    summary_val = {
        i: tf.summary.scalar("val_bpd_stage%i" % i, val_bpd_ph)
        for i in range(5)
    }

    # Counters
    global_step, val_step = tf.Variable(1), tf.Variable(1)
    update_global_step = global_step.assign_add(1)
    update_val_step = val_step.assign_add(1)

    ## Inits
    var_init_1 = [
        v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if v.name.find("image_parser") >= 0
    ]
    var_init_2 = [
        v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if v.name.find("detector") >= 0
    ]
    var_rest = list(
        set(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) -
        set(var_init_1 + var_init_2))

    init_ops = [
        tf.initialize_variables(v_l)
        for v_l in [var_init_1, var_init_2, var_rest]
    ]

    ####
    image_summary_placeholder = tf.placeholder(dtype=tf.float32)
    image_summary_sample_val = tf.summary.image("validation_samples",
                                                image_summary_placeholder,
                                                max_outputs=256)
    saver = tf.train.Saver()

    # tf.get_default_graph().finalize()
    with tf.Session() as sess:
        with queues.QueueRunners(sess):

            default_model_meta = os.path.join(FLAGS.tb_log_dir, "main",
                                              "model.ckpt.meta")
            default_model_file = os.path.join(FLAGS.tb_log_dir, "main",
                                              "model.ckpt")
            rerun = False
            if tf.gfile.Exists(default_model_meta):
                print("Model is loading...")
                saver.restore(sess, default_model_file)
                rerun = True
            else:
                # Initialization (split into multiple steps due to a bug
                # in TensorFlow)
                _ = [sess.run(init_op) for init_op in init_ops]

                if FLAGS.use_pretrained:
                    utils.optimistic_restore(sess, "")
                    sess.run(global_step.assign(1))
                    sess.run(val_step.assign(1))
                    sess.run(learning_rate.assign(0.0001))

            # Summary writers
            summary_writer_main = tf.summary.FileWriter(
                "%s/%s" % (FLAGS.tb_log_dir, "main"), sess.graph)
            # Visualize validation GT

            # Fetch a validation batch up front; it is reused later for the
            # detection visualization, so it must exist even when resuming.
            (imgs_sample, box_cls_sample, boxes_sample) = sess.run(
                [image_val_batch, class_val_batch, box_val_batch])

            if not rerun:

                boxes_sample = np.concatenate(
                    [boxes_sample[..., :2][:, None],
                     boxes_sample[..., 2:][:, None]], 1)
                imgs_with_box = utils.visualize(imgs_sample, box_cls_sample,
                                                boxes_sample, utils.LABEL_MAP)

                s = sess.run(
                    image_summary_sample_val,
                    {image_summary_placeholder: np.array(imgs_with_box)})
                summary_writer_main.add_summary(s, 0)

            # Run training

            n_iter_train = (utils.SST_COUNTS["train"] // batch_size
                            if FLAGS.iter_cap <= 0 else FLAGS.iter_cap)
            n_iter_val = (utils.SST_COUNTS["val"] // batch_size
                          if FLAGS.iter_cap <= 0 else FLAGS.iter_cap)

            max_iter = FLAGS.num_epochs * n_iter_train

            buf_loss = defaultdict(list)
            val_i = 0
            while not FLAGS.run_test:

                # Training step
                (_, loss_v, stage_v, train_i, val_i) = sess.run([
                    train_step, loss_train, stage_train, global_step, val_step
                ], {min_stage: 0})

                buf_loss[stage_v].append(loss_v)

                # Update global counter and learning rate
                sess.run([update_global_step, update_lr])

                # Log training error
                if train_i % FLAGS.log_training_loss == 0:

                    for i in range(5):
                        s = sess.run(summary_train[i],
                                     {train_bpd_ph: np.mean(buf_loss[i])})
                        summary_writer_main.add_summary(s, train_i)
                    buf_loss = defaultdict(list)

                # Log val error and visualize samples
                if train_i % FLAGS.log_val_loss == 0:

                    buf_loss = defaultdict(list)
                    for i in range(n_iter_val):
                        loss_v, stage_v = sess.run([loss_val, stage_val])
                        buf_loss[stage_v].append(loss_v)

                    for i in range(5):
                        s = sess.run(summary_val[i],
                                     {val_bpd_ph: np.mean(buf_loss[i])})
                        summary_writer_main.add_summary(s, val_i)
                    buf_loss = defaultdict(list)

                    # Sample detections

                    label_np = np.zeros((batch_size, 41))
                    boxes_np = np.zeros((batch_size, 56, 56, 4))

                    # stage 0: predict the class label
                    l = sess.run(
                        label_p_val, {
                            image_val_batch: imgs_sample,
                            class_val_batch: label_np,
                            box_val_batch: boxes_np,
                            stage_val: 0
                        })[0]
                    l = np.argmax(l, axis=1)
                    label_np[range(batch_size), l] = 1

                    # stages 1-4: fill one box channel per pass
                    for ii in range(4):
                        l = sess.run(
                            point_p_val, {
                                image_val_batch: imgs_sample,
                                class_val_batch: label_np,
                                box_val_batch: boxes_np,
                                stage_val: ii + 1
                            })[ii]
                        l = (l == np.amax(l, axis=(1, 2),
                                          keepdims=True)).astype("int32")
                        boxes_np[:, :, :, ii:ii + 1] = l

                    # vis
                    boxes_np = np.concatenate([
                        boxes_np[..., :2][:, None], boxes_np[..., 2:][:, None]
                    ], 1)
                    imgs_with_box = utils.visualize(imgs_sample, label_np,
                                                    boxes_np, utils.LABEL_MAP)

                    # Note: this builds a new image-summary op on every
                    # validation round, so the graph keeps growing.
                    image_summary_det = tf.summary.image(
                        "detection_samples%i" % val_i,
                        image_summary_placeholder,
                        max_outputs=256)

                    s = sess.run(
                        image_summary_det,
                        {image_summary_placeholder: np.array(imgs_with_box)})
                    summary_writer_main.add_summary(s, 0)

                    # Save model
                    saver.save(
                        sess,
                        os.path.join(FLAGS.tb_log_dir, "main", "model.ckpt"))
                    saver.save(
                        sess,
                        os.path.join(FLAGS.tb_log_dir, "main",
                                     "model%i.ckpt" % val_i))

                    sess.run([update_val_step])

                # Terminate
                if train_i > max_iter:
                    break

            if FLAGS.run_test:
                pass
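Every example in this listing calls an optimistic_restore helper whose body is not shown. A minimal sketch of the usual TF1 implementation — restore only the variables whose names and shapes also exist in the checkpoint — is below; this is an assumed, typical version, and the exact helper in each repository may differ:

def optimistic_restore(session, save_file):
    # Sketch only (assumed implementation): restore just the variables
    # present in both the current graph and the checkpoint, matching by
    # name and shape; everything else keeps its initialized value.
    reader = tf.train.NewCheckpointReader(save_file)
    saved_shapes = reader.get_variable_to_shape_map()
    restore_vars = []
    for var in tf.global_variables():
        name = var.name.split(':')[0]
        if name in saved_shapes and \
                var.get_shape().as_list() == saved_shapes[name]:
            restore_vars.append(var)
    tf.train.Saver(restore_vars).restore(session, save_file)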
Exemple #9
def train(args):
    global num_gpu
    global batch_size

    num_gpu = 1
    batch_size = per_gpu_batch_size * num_gpu

    lr_init = config.TRAIN.lr_init
    pwc_lr_init = config.TRAIN.pwc_lr_init

    record_reader = RecordReader(config.TRAIN.tf_records_path)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
        opt = tf.train.AdamOptimizer(lr_v, beta1=beta1, beta2=beta2)
        pwc_lr_v = tf.Variable(pwc_lr_init, trainable=False)
        pwcnet_opt = tf.train.AdamOptimizer(pwc_lr_v, beta1=beta1, beta2=beta2)

    # allow_pickle is required by newer NumPy versions to load this dict
    vgg_data_dict = np.load(config.TRAIN.vgg19_npy_path,
                            encoding='latin1',
                            allow_pickle=True).item()

    (first_img_t, mid_img_t, end_img_t,
     s_img_t) = record_reader.read_and_decode()
    (first_img_t_batch, mid_img_t_batch, end_img_t_batch,
     s_img_t_batch) = tf.train.shuffle_batch(
         [first_img_t, mid_img_t, end_img_t, s_img_t],
         batch_size=batch_size,
         capacity=12000,
         min_after_dequeue=160,
         num_threads=4)

    reuse_all = False

    tower_grads, tower_pwc_grads = [], []
    tower_loss = []

    for d in range(0, num_gpu):
        print("dealing {}th gpu".format(d))
        with tf.device('/gpu:%s' % d):
            with tf.name_scope('%s_%s' % ('tower', d)):
                print("build model!!!")
                tot_loss_gpu, summary \
                    = build_model(first_img_t_batch, mid_img_t_batch, end_img_t_batch, s_img_t_batch, vgg_data_dict, reuse_all=reuse_all)

                if not reuse_all:
                    vars_trainable = get_variables_with_name(
                        name='stabnet', exclude_name='pwcnet', train_only=True)
                    grads = opt.compute_gradients(tot_loss_gpu,
                                                  var_list=vars_trainable)
                    pwc_vars_trainable = get_variables_with_name(
                        name='pwcnet', exclude_name='stabnet', train_only=True)
                    # presumably intended to use pwcnet_opt; the computed
                    # gradient values are identical either way
                    pwc_grads = pwcnet_opt.compute_gradients(
                        tot_loss_gpu, var_list=pwc_vars_trainable)

                # Clip gradients by norm. grads/pwc_grads are only rebuilt
                # while reuse_all is False, so this relies on num_gpu == 1
                # (set above).
                for i, (g, v) in enumerate(grads):
                    if g is not None:
                        grads[i] = (tf.clip_by_norm(g, 5), v)
                for i, (g, v) in enumerate(pwc_grads):
                    if g is not None:
                        pwc_grads[i] = (tf.clip_by_norm(g, 5), v)

                tower_grads.append(grads)
                tower_pwc_grads.append(pwc_grads)
                tower_loss.append(tot_loss_gpu)

                reuse_all = True
        if num_gpu == 1:
            with tf.device('/gpu:0'):
                mse_loss = tf.reduce_mean(tf.stack(tower_loss, 0), 0)
                mean_grads = average_gradients(tower_grads)
                mean_pwc_grads = average_gradients(tower_pwc_grads)
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               scope='.*?stabnet')
                update_pwc_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                   scope='.*?pwcnet')
                with tf.control_dependencies(update_ops):
                    minimize_op = opt.apply_gradients(mean_grads)
                with tf.control_dependencies(update_pwc_ops):
                    minimize_pwc_op = pwcnet_opt.apply_gradients(
                        mean_pwc_grads)

        else:
            mse_loss = tf.reduce_mean(tf.stack(tower_loss, 0), 0)
            mean_grads = average_gradients(tower_grads)
            mean_pwc_grads = average_gradients(tower_pwc_grads)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           scope='.*?stabnet')
            update_pwc_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               scope='.*?pwcnet')
            with tf.control_dependencies(update_ops):
                minimize_op = opt.apply_gradients(mean_grads)
            with tf.control_dependencies(update_pwc_ops):
                minimize_pwc_op = pwcnet_opt.apply_gradients(mean_pwc_grads)

        print('trainable variables:')
        print(vars_trainable)
        print('pwc trainable variables:')
        print(pwc_vars_trainable)

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    if debug:
        sess = tf_debug.LocalCLIDebugWrapperSession(sess)

    # sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver(max_to_keep=200)
    lr_str = timestamp + ' ' + get_config(config) + ',gn:{}'.format(num_gpu)
    if not os.path.exists(checkpoint_path + lr_str):
        os.makedirs(checkpoint_path + lr_str)

    if args.pretrained:
        print('restore path from : ',
              checkpoint_path + args.lr_str + '/stab.ckpt-' + str(args.modeli))
        saver.restore(
            sess,
            checkpoint_path + args.lr_str + '/stab.ckpt-' + str(args.modeli))

    summary_ops = tf.summary.merge(summary)
    summary_writer = tf.summary.FileWriter(
        checkpoint_path + lr_str + '/summary', sess.graph)

    len_train = config.TRAIN.len_train
    n_epoch = 250

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # after the initializer, so the restored PWC weights are not overwritten
    optimistic_restore(sess, pwc_opt.ckpt_path)

    for epoch in range(0, n_epoch):
        if epoch < pwc_freeze_epoch:
            pwc_lr_init = 0.0
            sess.run(tf.assign(pwc_lr_v,
                               pwc_lr_init))  # freeze the optical flow net
            log = ' ** pwc net new learning rate: %f ' % (pwc_lr_init)
            print(log)

        if epoch >= pwc_freeze_epoch:
            pwc_lr_init = config.TRAIN.pwc_lr_init
            cur_lr = pwc_lr_init
            sess.run(tf.assign(pwc_lr_v, cur_lr))
            log = ' ** pwc net new learning rate: %f ' % (cur_lr)
            print(log)

        if epoch >= pwc_lr_stable_epoch:
            pwc_lr_init = config.TRAIN.pwc_lr_init
            cur_lr = linear_lr(pwc_lr_init, pwc_decay_ratio,
                               epoch - pwc_lr_stable_epoch)
            sess.run(tf.assign(pwc_lr_v, cur_lr))
            log = ' ** pwc net new learning rate: %f ' % (cur_lr)
            print(log)

        if epoch >= lr_stable_epoch:
            lr_init = config.TRAIN.lr_init
            cur_lr = linear_lr(lr_init, decay_ratio, epoch - lr_stable_epoch)
            sess.run(tf.assign(lr_v, cur_lr))
            log = ' ** stab net new learning rate: %f' % (cur_lr)
            print(log)

        sys.stdout.flush()

        epoch_time = time.time()
        n_iter = int(len_train / batch_size)
        for it in range(n_iter):
            errM, _, _, summary = sess.run(
                [mse_loss, minimize_op, minimize_pwc_op, summary_ops])
            global_it = it + n_iter * epoch
            if global_it % 10 == 0:
                summary_writer.add_summary(summary, global_it)

            print("Epoch [%2d/%2d] %4d time: %4.4fs, loss: %5.5f" %
                  (epoch, n_epoch, it, time.time() - epoch_time, errM))
            sys.stdout.flush()

            epoch_time = time.time()

            if global_it % 1000 == 0:
                saver.save(sess,
                           checkpoint_path + lr_str + '/stab.ckpt',
                           global_step=global_it)

    coord.request_stop()
    coord.join(threads)
    sess.close()
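The helpers average_gradients and linear_lr used above are defined elsewhere in the repository. Plausible sketches of both follow, assuming the standard multi-tower averaging pattern and a linear decay clipped at zero; the originals may differ in detail:

def average_gradients(tower_grads):
    # Sketch (assumed): tower_grads holds one list of (grad, var) pairs
    # per GPU, with the same variable order in every tower.
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        grads = [g for g, _ in grad_and_vars if g is not None]
        grad = tf.reduce_mean(tf.stack(grads, 0), 0)
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads


def linear_lr(lr_init, decay_ratio, epochs_past):
    # Sketch (assumed): linear decay after the stable phase, never negative.
    return max(lr_init * (1.0 - decay_ratio * epochs_past), 0.0)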
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('name', nargs='*')
    parser.add_argument('--eval', dest='eval_only', action='store_true')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--resume', nargs='*')
    args = parser.parse_args()

    if args.test:
        args.eval_only = True
    src = open('model.py').read()
    if args.name:
        name = ' '.join(args.name)
    else:
        from datetime import datetime
        name = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    target_name = os.path.join('logs', '{}'.format(name))
    writer.add_text('Log Name:', name)
    if not args.test:
        # target_name won't be used in test mode
        print('will save to {}'.format(target_name))
    if args.resume:
        logs = torch.load(' '.join(args.resume))
        # hacky way to tell the VQA classes that they should use the vocab without passing more params around
        #data.preloaded_vocab = logs['vocab']

    cudnn.benchmark = True

    if not args.eval_only:
        train_loader = data.get_loader(train=True)
    if not args.test:
        val_loader = data.get_loader(val=True)
    else:
        val_loader = data.get_loader(test=True)

    net = model.Net(val_loader.dataset.num_tokens).cuda()
    # restore transfer learning
    # 'data/vgrel-29.tar' for 36
    # 'data/vgrel-19.tar' for 10-100
    if config.output_size == 36:
        print("load data/vgrel-29(transfer36).tar")
        ckpt = torch.load('data/vgrel-29(transfer36).tar')
    else:
        print("load data/vgrel-19(transfer110).tar")
        ckpt = torch.load('data/vgrel-19(transfer110).tar')
    
    utils.optimistic_restore(net.tree_lstm.gen_tree_net, ckpt['state_dict'])
    
    if config.use_rl:
        for p in net.parameters():
            p.requires_grad = False
        for p in net.tree_lstm.gen_tree_net.parameters():
            p.requires_grad = True
    
    optimizer = optim.Adam([p for p in net.parameters() if p.requires_grad], lr=config.initial_lr)
    scheduler = lr_scheduler.ExponentialLR(optimizer, 0.5**(1 / config.lr_halflife))
    start_epoch = 0
    if args.resume:
        net.load_state_dict(logs['weights'])
        #optimizer.load_state_dict(logs['optimizer'])
        start_epoch = int(logs['epoch']) + 1

    tracker = utils.Tracker()
    config_as_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')}
    print(config_as_dict)
    best_accuracy = -1

    for i in range(start_epoch, config.epochs):
        if not args.eval_only:
            run(net, train_loader, optimizer, scheduler, tracker, train=True, prefix='train', epoch=i)
        # i % 1 != 0 is always False, so this skips validation for
        # epochs 1-19
        if i % 1 != 0 or (i > 0 and i < 20):
            r = [[-1], [-1], [-1]]
        else:
            r = run(net, val_loader, optimizer, scheduler, tracker, train=False, prefix='val', epoch=i, has_answers=not args.test)

        if not args.test:
            results = {
                'name': name,
                'tracker': tracker.to_dict(),
                'config': config_as_dict,
                'weights': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': i,
                'eval': {
                    'answers': r[0],
                    'accuracies': r[1],
                    'idx': r[2],
                },
                'vocab': val_loader.dataset.vocab,
                'src': src,
                'setting': exp_setting,
            }
            current_ac = sum(r[1]) / len(r[1])
            if current_ac > best_accuracy:
                best_accuracy = current_ac
                print('updating best model, current accuracy:', current_ac)
                torch.save(results, target_name + '_best.pth')
            if i % 1 == 0:
                torch.save(results, target_name + '_' + str(i) + '.pth')

        else:
            # in test mode, save a results file in the format accepted by the submission server
            answer_index_to_string = {a: s for s, a in val_loader.dataset.answer_to_index.items()}
            results = []
            for answer, index in zip(r[0], r[2]):
                answer = answer_index_to_string[answer.item()]
                qid = val_loader.dataset.question_ids[index]
                entry = {
                    'question_id': qid,
                    'answer': answer,
                }
                results.append(entry)
            with open('results.json', 'w') as fd:
                json.dump(results, fd)

        if args.eval_only:
            break
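Here utils.optimistic_restore is the PyTorch variant: it takes a module and a state dict rather than a session and a checkpoint path. A minimal sketch under that assumption (the repository's helper may also log missing or mismatched keys):

def optimistic_restore(network, state_dict):
    # Sketch only: copy every checkpoint entry whose name and shape match
    # a parameter/buffer of the target network; skip everything else.
    own_state = network.state_dict()
    for name, param in state_dict.items():
        if name in own_state and own_state[name].shape == param.shape:
            own_state[name].copy_(param)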
Exemple #11
    img_mid = np.expand_dims(img_mid, 0)
    img_end = np.expand_dims(img_end, 0)

    # warped = test_flow_warp(img_first, img_s)
    img_int, img_out, [warped_first, warped_end,
                       warped_mid] = training_stab_model(img_first,
                                                         img_s,
                                                         img_end,
                                                         img_mid,
                                                         training=False)
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))

    sess.run(tf.global_variables_initializer())

    optimistic_restore(sess, pwc_opt.ckpt_path)  # must come after the global initializer, otherwise the restored weights get overwritten

    [warped_first, warped_end] = sess.run([warped_first, warped_end])
    # import pdb; pdb.set_trace();
    warped_first = warped_first[0][:, :, ::-1]
    warped_end = warped_end[0][:, :, ::-1]

    # warped_np = np.clip(warped_np, 0, 1.)
    cv2.imwrite('warped_first.png',
                np.array(warped_first * 255).astype(np.uint8))
    cv2.imwrite('warped_end.png', np.array(warped_end * 255).astype(np.uint8))
    # test_training_model()
    # test_testing_model()
    # input_tensor_batch = tf.random_uniform(shape=[2, 16, 16, 3])
    # # out = resnet(input_tensor_batch, 5)
    # out = make_unet(input_tensor_batch)