Example #1
def construct_ff_model(model, weights, X_NOISE, X, ACTION_LABEL,
                       ACTION_NOISE_LABEL, optimizer):
    target_vars = {}
    x_mods = []

    x_pred = model.forward(X[:, 0, 0], weights, action_label=ACTION_LABEL)
    loss_total = tf.reduce_mean(tf.square(x_pred - X[:, 1, 0]))

    dyn_model = TrajInverseDynamics(dim_input=FLAGS.latent_dim,
                                    dim_output=FLAGS.action_dim)
    weights = dyn_model.construct_weights(scope="inverse_dynamics",
                                          weights=weights)
    output_action = dyn_model.forward(X, weights)
    dyn_loss = tf.reduce_mean(tf.square(output_action - ACTION_LABEL))
    dyn_dist = tf.reduce_mean(tf.abs(output_action - ACTION_LABEL))
    target_vars['dyn_loss'] = dyn_loss
    target_vars['dyn_dist'] = dyn_dist

    dyn_optimizer = AdamOptimizer(1e-3)
    gvs = dyn_optimizer.compute_gradients(dyn_loss)
    dyn_train_op = dyn_optimizer.apply_gradients(gvs)

    gvs = optimizer.compute_gradients(loss_total)
    gvs = [(k, v) for (k, v) in gvs if k is not None]
    print("Applying gradients...")
    grads, vs = zip(*gvs)

    def filter_grad(g, v):
        return tf.clip_by_value(g, -1e5, 1e5)

    capped_gvs = [(filter_grad(grad, var), var) for grad, var in gvs]
    gvs = capped_gvs
    train_op = optimizer.apply_gradients(gvs)

    if not FLAGS.gt_inverse_dynamics:
        train_op = tf.group(train_op, dyn_train_op)

    target_vars['train_op'] = train_op

    target_vars['loss_ml'] = tf.zeros(1)
    target_vars['loss_total'] = loss_total
    target_vars['gvs'] = gvs
    target_vars['loss_energy'] = tf.zeros(1)

    target_vars['weights'] = weights
    target_vars['X'] = X
    target_vars['X_NOISE'] = X_NOISE
    target_vars['energy_pos'] = tf.zeros(1)
    target_vars['energy_neg'] = tf.zeros(1)
    target_vars['x_grad'] = tf.zeros(1)
    target_vars['action_grad'] = tf.zeros(1)
    target_vars['x_mod'] = tf.zeros(1)
    target_vars['x_off'] = tf.zeros(1)
    target_vars['temp'] = FLAGS.temperature
    target_vars['ACTION_LABEL'] = ACTION_LABEL
    target_vars['ACTION_NOISE_LABEL'] = ACTION_NOISE_LABEL

    return target_vars
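
Example #1 above builds the purely feed-forward variant of the dynamics model: a one-step forward-prediction loss, an auxiliary inverse-dynamics regression, and gradients clipped by value before being applied. A minimal, self-contained sketch of that training setup, using plain TF1 ops and hypothetical placeholder shapes in place of the repository's model, weights, and FLAGS objects:

import tensorflow as tf

# Hypothetical shapes; the real code derives these from FLAGS.latent_dim / FLAGS.action_dim.
X_T = tf.placeholder(tf.float32, (None, 4))      # state at time t
X_T1 = tf.placeholder(tf.float32, (None, 4))     # state at time t + 1
ACTION = tf.placeholder(tf.float32, (None, 2))   # ground-truth action

# Stand-in for model.forward(...): a single dense layer predicting the next state.
x_pred = tf.layers.dense(tf.concat([X_T, ACTION], axis=1), 4)
loss_total = tf.reduce_mean(tf.square(x_pred - X_T1))

optimizer = tf.train.AdamOptimizer(1e-3)
gvs = optimizer.compute_gradients(loss_total)
gvs = [(g, v) for (g, v) in gvs if g is not None]
# Element-wise clipping, mirroring filter_grad in the example.
capped_gvs = [(tf.clip_by_value(g, -1e5, 1e5), v) for g, v in gvs]
train_op = optimizer.apply_gradients(capped_gvs)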
Example #2
def main():

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    logger = TensorBoardOutputFormat(logdir)

    config = tf.ConfigProto()

    sess = tf.Session(config=config)
    LABEL = None
    print("Loading data...")
    if FLAGS.dataset == 'cubes':
        dataset = Cubes(cond_idx=FLAGS.cond_idx)
        test_dataset = dataset

        if FLAGS.cond_idx == 0:
            label_size = 2
        elif FLAGS.cond_idx == 1:
            label_size = 1
        elif FLAGS.cond_idx == 2:
            label_size = 3
        elif FLAGS.cond_idx == 3:
            label_size = 20

        LABEL = tf.placeholder(shape=(None, label_size), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, label_size), dtype=tf.float32)
    elif FLAGS.dataset == 'color':
        dataset = CubesColor()
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 301), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 301), dtype=tf.float32)
        label_size = 301
    elif FLAGS.dataset == 'pos':
        dataset = CubesPos()
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        label_size = 2
    elif FLAGS.dataset == "pairs":
        dataset = Pairs(cond_idx=0)
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 6), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 6), dtype=tf.float32)
        label_size = 6
    elif FLAGS.dataset == "continual":
        dataset = CubesContinual()
        test_dataset = dataset

        if FLAGS.prelearn_model_shape:
            LABEL = tf.placeholder(shape=(None, 20), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 20), dtype=tf.float32)
            label_size = 20
        else:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            label_size = 2

    elif FLAGS.dataset == "cross":
        dataset = CubesCrossProduct(FLAGS.ratio, cond_size=FLAGS.cond_size, cond_pos=FLAGS.cond_pos, joint_baseline=FLAGS.joint_baseline)
        test_dataset = dataset

        if FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            label_size = 1
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            label_size = 2

        if FLAGS.joint_baseline:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            label_size = 3

    elif FLAGS.dataset == 'celeba':
        dataset = CelebA(cond_idx=FLAGS.celeba_cond_idx)
        test_dataset = dataset
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)

        model = ResNet128(
            num_channels=channel_num,
            num_filters=64,
            classes=2)

    if FLAGS.joint_baseline:
        # Other stuff for joint model
        optimizer = AdamOptimizer(FLAGS.lr, beta1=0.99, beta2=0.999)

        X = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        X_NOISE = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        ATTENTION_MASK = tf.placeholder(shape=(None, 64, 64, FLAGS.cond_func), dtype=tf.float32)
        NOISE = tf.placeholder(shape=(None, 128), dtype=tf.float32)
        HIER_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)

        channel_num = 3

        model = CubesNetGen(num_channels=channel_num, label_size=label_size)
        weights = model.construct_weights('context_0')
        output = model.forward(NOISE, weights, reuse=False, label=LABEL)
        print(output.get_shape())
        mse_loss = tf.reduce_mean(tf.square(output - X))
        gvs = optimizer.compute_gradients(mse_loss)
        gvs = [(k, v) for (k, v) in gvs if k is not None]
        train_op = optimizer.apply_gradients(gvs)

        target_vars = {}
        target_vars['train_op'] = train_op
        target_vars['X'] = X
        target_vars['X_NOISE'] = X_NOISE
        target_vars['ATTENTION_MASK'] = ATTENTION_MASK
        target_vars['eps_begin'] = tf.zeros(1)
        target_vars['gvs'] = gvs
        target_vars['energy_pos'] = tf.zeros(1)
        target_vars['energy_neg'] = tf.zeros(1)
        target_vars['loss_energy'] = tf.zeros(1)
        target_vars['loss_ml'] = tf.zeros(1)
        target_vars['total_loss'] = mse_loss
        target_vars['attention_mask'] = tf.zeros(1)
        target_vars['attention_grad'] = tf.zeros(1)
        target_vars['x_off'] = tf.reduce_mean(tf.abs(output - X))
        target_vars['x_mod'] = tf.zeros(1)
        target_vars['x_grad'] = tf.zeros(1)
        target_vars['NOISE'] = NOISE
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['HIER_LABEL'] = HIER_LABEL

        data_loader = DataLoader(
            dataset,
            batch_size=FLAGS.batch_size,
            num_workers=FLAGS.data_workers,
            drop_last=True,
            shuffle=True)
    else:
        print("label size here ", label_size)
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        HEIR_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        ATTENTION_MASK = tf.placeholder(shape=(None, 64, 64, FLAGS.cond_func), dtype=tf.float32)

        if FLAGS.dataset != "celeba":
            model = CubesNet(num_channels=channel_num, label_size=label_size)

        heir_model = HeirNet(num_channels=FLAGS.cond_func)

        models_pretrain = []
        if FLAGS.prelearn_model:
            model_prelearn = CubesNet(num_channels=channel_num, label_size=FLAGS.prelearn_label)
            weights = model_prelearn.construct_weights('context_1')
            LABEL_PRELEARN = tf.placeholder(shape=(None, FLAGS.prelearn_label), dtype=tf.float32)
            models_pretrain.append((model_prelearn, weights, LABEL_PRELEARN))

            cubes_logdir = osp.join(FLAGS.logdir, FLAGS.prelearn_exp)
            if (FLAGS.prelearn_iter != -1 or not FLAGS.train):
                model_file = osp.join(cubes_logdir, 'model_{}'.format(FLAGS.prelearn_iter))
                resume_itr = FLAGS.resume_iter
                # saver.restore(sess, model_file)

                v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1))
                v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list}
                saver = tf.train.Saver(v_map)
                saver.restore(sess, model_file)

        if FLAGS.prelearn_model_shape:
            model_prelearn = CubesNet(num_channels=channel_num, label_size=FLAGS.prelearn_label_shape)
            weights = model_prelearn.construct_weights('context_2')
            LABEL_PRELEARN = tf.placeholder(shape=(None, FLAGS.prelearn_label_shape), dtype=tf.float32)
            models_pretrain.append((model_prelearn, weights, LABEL_PRELEARN))

            cubes_logdir = osp.join(FLAGS.logdir, FLAGS.prelearn_exp_shape)
            if (FLAGS.prelearn_iter_shape != -1 or not FLAGS.train):
                model_file = osp.join(cubes_logdir, 'model_{}'.format(FLAGS.prelearn_iter_shape))
                resume_itr = FLAGS.resume_iter
                # saver.restore(sess, model_file)

                v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2))
                v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list}
                saver = tf.train.Saver(v_map)
                saver.restore(sess, model_file)

        print("Done loading...")

        data_loader = DataLoader(
            dataset,
            batch_size=FLAGS.batch_size,
            num_workers=FLAGS.data_workers,
            drop_last=True,
            shuffle=True)

        batch_size = FLAGS.batch_size

        weights = model.construct_weights('context_0')

        if FLAGS.heir_mask:
            weights = heir_model.construct_weights('heir_0', weights=weights)

        Y = tf.placeholder(shape=(None), dtype=tf.int32)

        # Variables to run in training

        X_SPLIT = tf.split(X, FLAGS.num_gpus)
        X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
        LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
        LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
        LABEL_SPLIT_INIT = list(LABEL_SPLIT)
        attention_mask = ATTENTION_MASK
        tower_grads = []
        tower_gen_grads = []
        x_mod_list = []

        optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.99)

        for j in range(FLAGS.num_gpus):

            x_mod = X_SPLIT[j]
            if FLAGS.comb_mask:
                steps = tf.constant(0)
                c = lambda i, x: tf.less(i, FLAGS.num_steps)

                def langevin_attention_step(counter, attention_mask):
                    attention_mask = attention_mask + tf.random_normal(tf.shape(attention_mask), mean=0.0, stddev=0.01)
                    energy_noise = energy_start = model.forward(
                                x_mod,
                                weights,
                                attention_mask,
                                label=LABEL_SPLIT[j],
                                reuse=True,
                                stop_at_grad=False,
                                stop_batch=True)

                    if FLAGS.heir_mask:
                        energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                        energy_noise = energy_noise + energy_heir

                    attention_grad = tf.gradients(
                        FLAGS.temperature * energy_noise, [attention_mask])[0]
                    energy_noise_old = energy_noise

                    # Clip gradient norm for now
                    attention_mask = attention_mask - (FLAGS.attention_lr) * attention_grad
                    attention_mask = tf.layers.average_pooling2d(attention_mask, (3, 3), 1, padding='SAME')
                    attention_mask = tf.stop_gradient(attention_mask)

                    counter = counter + 1

                    return counter, attention_mask

                steps, attention_mask = tf.while_loop(c, langevin_attention_step, (steps, attention_mask))

                # attention_mask = tf.Print(attention_mask, [attention_mask])

                energy_pos = model.forward(
                        X_SPLIT[j],
                        weights,
                        tf.stop_gradient(attention_mask),
                        label=LABEL_POS_SPLIT[j],
                        stop_at_grad=False)

                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_pos = energy_heir + energy_pos

            else:
                energy_pos = model.forward(
                        X_SPLIT[j],
                        weights,
                        attention_mask,
                        label=LABEL_POS_SPLIT[j],
                        stop_at_grad=False)

                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_pos = energy_heir + energy_pos

            print("Building graph...")
            x_mod = x_orig = X_NOISE_SPLIT[j]

            x_grads = []

            loss_energys = []

            eps_begin = tf.zeros(1)

            steps = tf.constant(0)
            c_cond = lambda i, x, y: tf.less(i, FLAGS.num_steps)

            def langevin_step(counter, x_mod, attention_mask):

                lr = FLAGS.step_lr

                x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
                                                 mean=0.0,
                                                 stddev=0.001 * FLAGS.rescale * FLAGS.noise_scale)
                attention_mask = attention_mask + tf.random_normal(tf.shape(attention_mask), mean=0.0, stddev=0.01)

                energy_noise = model.forward(
                            x_mod,
                            weights,
                            attention_mask,
                            label=LABEL_SPLIT[j],
                            reuse=True,
                            stop_at_grad=False,
                            stop_batch=True)

                if FLAGS.prelearn_model:
                    for m_i, w_i, l_i in models_pretrain:
                        energy_noise = energy_noise + m_i.forward(
                                    x_mod,
                                    w_i,
                                    attention_mask,
                                    label=l_i,
                                    reuse=True,
                                    stop_at_grad=False,
                                    stop_batch=True)


                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_noise = energy_heir + energy_noise

                x_grad, attention_grad = tf.gradients(
                    FLAGS.temperature * energy_noise, [x_mod, attention_mask])

                if not FLAGS.comb_mask:
                    attention_grad = tf.zeros(1)
                energy_noise_old = energy_noise

                if FLAGS.proj_norm != 0.0:
                    if FLAGS.proj_norm_type == 'l2':
                        x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                    elif FLAGS.proj_norm_type == 'li':
                        x_grad = tf.clip_by_value(
                            x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
                    else:
                        print("Other types of projection are not supported!!!")
                        assert False

                # Clip gradient norm for now
                x_last = x_mod - (lr) * x_grad

                if FLAGS.comb_mask:
                    attention_mask = attention_mask - FLAGS.attention_lr * attention_grad
                    attention_mask = tf.layers.average_pooling2d(attention_mask, (3, 3), 1, padding='SAME')
                    attention_mask = tf.stop_gradient(attention_mask)

                x_mod = x_last
                x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

                counter = counter + 1

                return counter, x_mod, attention_mask


            steps, x_mod, attention_mask = tf.while_loop(c_cond, langevin_step, (steps, x_mod, attention_mask))

            attention_mask = tf.stop_gradient(attention_mask)
            # attention_mask = tf.Print(attention_mask, [attention_mask])

            energy_eval = model.forward(x_mod, weights, attention_mask, label=LABEL_SPLIT[j],
                                        stop_at_grad=False, reuse=True)
            x_grad, attention_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod, attention_mask])
            x_grads.append(x_grad)

            energy_neg = model.forward(
                    tf.stop_gradient(x_mod),
                    weights,
                    tf.stop_gradient(attention_mask),
                    label=LABEL_SPLIT[j],
                    stop_at_grad=False,
                    reuse=True)

            if FLAGS.heir_mask:
                energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                energy_neg = energy_heir + energy_neg


            temp = FLAGS.temperature

            x_off = tf.reduce_mean(
                tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

            loss_energy = model.forward(
                x_mod,
                weights,
                attention_mask,
                reuse=True,
                label=LABEL,
                stop_grad=True)

            print("Finished processing loop construction ...")

            target_vars = {}

            if FLAGS.antialias:
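                # NOTE: stride_3 is assumed to be an anti-aliasing kernel defined elsewhere in the repository (not shown in this snippet).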
                antialias = tf.tile(stride_3, (1, 1, tf.shape(x_mod)[3], tf.shape(x_mod)[3]))
                inp = tf.nn.conv2d(x_mod, antialias, [1, 2, 2, 1], padding='SAME')

            test_x_mod = x_mod

            if FLAGS.cclass or FLAGS.model_cclass:
                label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
                label_prob = label_sum / tf.reduce_sum(label_sum)
                label_ent = -tf.reduce_sum(label_prob *
                                           tf.math.log(label_prob + 1e-7))
            else:
                label_ent = tf.zeros(1)

            target_vars['label_ent'] = label_ent

            if FLAGS.train:
                if FLAGS.objective == 'logsumexp':
                    pos_term = temp * energy_pos
                    energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                    coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                    norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                    pos_loss = tf.reduce_mean(temp * energy_pos)
                    neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                    loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
                elif FLAGS.objective == 'cd':
                    pos_loss = tf.reduce_mean(temp * energy_pos)
                    neg_loss = -tf.reduce_mean(temp * energy_neg)
                    loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
                elif FLAGS.objective == 'softplus':
                    loss_ml = FLAGS.ml_coeff * \
                        tf.nn.softplus(temp * (energy_pos - energy_neg))

                loss_total = tf.reduce_mean(loss_ml)

                if not FLAGS.zero_kl:
                    loss_total = loss_total + tf.reduce_mean(loss_energy)

                loss_total = loss_total + \
                    FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

                print("Started gradient computation...")
                gvs = optimizer.compute_gradients(loss_total)
                gvs = [(k, v) for (k, v) in gvs if k is not None]

                print("Applying gradients...")

                tower_grads.append(gvs)

                print("Finished applying gradients.")

                target_vars['loss_ml'] = loss_ml
                target_vars['total_loss'] = loss_total
                target_vars['loss_energy'] = loss_energy
                target_vars['weights'] = weights
                target_vars['gvs'] = gvs

            target_vars['X'] = X
            target_vars['Y'] = Y
            target_vars['LABEL'] = LABEL
            target_vars['HIER_LABEL'] = HEIR_LABEL
            target_vars['LABEL_POS'] = LABEL_POS
            target_vars['X_NOISE'] = X_NOISE
            target_vars['energy_pos'] = energy_pos
            target_vars['attention_grad'] = attention_grad

            if len(x_grads) >= 1:
                target_vars['x_grad'] = x_grads[-1]
                target_vars['x_grad_first'] = x_grads[0]
            else:
                target_vars['x_grad'] = tf.zeros(1)
                target_vars['x_grad_first'] = tf.zeros(1)

            target_vars['x_mod'] = x_mod
            target_vars['x_off'] = x_off
            target_vars['temp'] = temp
            target_vars['energy_neg'] = energy_neg
            target_vars['test_x_mod'] = test_x_mod
            target_vars['eps_begin'] = eps_begin
            target_vars['ATTENTION_MASK'] = ATTENTION_MASK
            target_vars['models_pretrain'] = models_pretrain
            if FLAGS.comb_mask:
                target_vars['attention_mask'] = tf.nn.softmax(attention_mask)
            else:
                target_vars['attention_mask'] = tf.zeros(1)

        if FLAGS.train:
            grads = average_gradients(tower_grads)
            train_op = optimizer.apply_gradients(grads)
            target_vars['train_op'] = train_op

    # sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(
        max_to_keep=30, keep_checkpoint_every_n_hours=6)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    if (FLAGS.resume_iter != -1 or not FLAGS.train):
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    print("Initializing variables...")

    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        train(target_vars, saver, sess,
              logger, data_loader, resume_itr,
              logdir)

    test(target_vars, saver, sess, logger, data_loader)
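
The core of the example above is the Langevin sampler built with tf.while_loop: inject Gaussian noise, take a gradient step downhill on the energy, and clip the sample back into the data range. A stripped-down sketch of that loop, with a toy quadratic energy standing in for model.forward and made-up step sizes:

import tensorflow as tf

def energy(x):
    # Toy stand-in for model.forward(x, weights, ...): a quadratic bowl.
    return tf.reduce_sum(tf.square(x), axis=1)

x_init = tf.placeholder(tf.float32, (None, 64 * 64 * 3))
num_steps, step_lr, noise_std = 20, 10.0, 0.005

def langevin_step(counter, x_mod):
    x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=noise_std)
    x_grad = tf.gradients(tf.reduce_sum(energy(x_mod)), [x_mod])[0]
    x_mod = x_mod - step_lr * x_grad            # gradient descent on the energy
    x_mod = tf.clip_by_value(x_mod, 0.0, 1.0)   # keep samples in the data range
    return counter + 1, x_mod

cond = lambda i, x: tf.less(i, num_steps)
_, x_sample = tf.while_loop(cond, langevin_step, (tf.constant(0), x_init))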
Example #3
def main():
    print("Local rank: ", hvd.local_rank(), hvd.size())

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if hvd.rank() == 0:
        if not osp.exists(logdir):
            os.makedirs(logdir)
        logger = TensorBoardOutputFormat(logdir)
    else:
        logger = None

    LABEL = None
    print("Loading data...")
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
        test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        if FLAGS.large_model:
            model = ResNet32Large(num_channels=channel_num,
                                  num_filters=128,
                                  train=True)
        elif FLAGS.larger_model:
            model = ResNet32Larger(num_channels=channel_num, num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(num_channels=channel_num, num_filters=192)
        else:
            model = ResNet32(num_channels=channel_num, num_filters=128)

    elif FLAGS.dataset == 'imagenet':
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet32Wider(num_channels=channel_num, num_filters=256)

    elif FLAGS.dataset == 'imagenetfull':
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet128(num_channels=channel_num, num_filters=64)

    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(rescale=FLAGS.rescale)
        test_dataset = dataset
        channel_num = 1
        X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        model = MnistNet(num_channels=channel_num,
                         num_filters=FLAGS.num_filters)

    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites(cond_shape=FLAGS.cond_shape,
                           cond_size=FLAGS.cond_size,
                           cond_pos=FLAGS.cond_pos,
                           cond_rot=FLAGS.cond_rot)
        test_dataset = dataset
        channel_num = 1

        X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

        if FLAGS.dpos_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.dsize_only:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.drot_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.cond_shape:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_rot:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        else:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)

        model = DspritesNet(num_channels=channel_num,
                            num_filters=FLAGS.num_filters,
                            cond_size=FLAGS.cond_size,
                            cond_shape=FLAGS.cond_shape,
                            cond_pos=FLAGS.cond_pos,
                            cond_rot=FLAGS.cond_rot)

    print("Done loading...")

    if FLAGS.dataset == "imagenetfull":
        # In the case of full ImageNet, use the custom TensorFlow dataloader
        data_loader = TFImagenetLoader('train',
                                       FLAGS.batch_size,
                                       hvd.rank(),
                                       hvd.size(),
                                       rescale=FLAGS.rescale)
    else:
        data_loader = DataLoader(dataset,
                                 batch_size=FLAGS.batch_size,
                                 num_workers=FLAGS.data_workers,
                                 drop_last=True,
                                 shuffle=True)

    batch_size = FLAGS.batch_size

    weights = [model.construct_weights('context_0')]

    Y = tf.placeholder(shape=(None), dtype=tf.int32)

    # Variables to run in training
    X_SPLIT = tf.split(X, FLAGS.num_gpus)
    X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
    LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
    LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
    LABEL_SPLIT_INIT = list(LABEL_SPLIT)
    tower_grads = []
    tower_gen_grads = []
    x_mod_list = []

    optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer = hvd.DistributedOptimizer(optimizer)

    for j in range(FLAGS.num_gpus):

        if FLAGS.model_cclass:
            ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
            label_tensor = tf.Variable(tf.convert_to_tensor(np.reshape(
                np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
                (FLAGS.batch_size * 10, 10)),
                                                            dtype=tf.float32),
                                       trainable=False,
                                       dtype=tf.float32)
            x_split = tf.tile(
                tf.reshape(X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)),
                (1, 10, 1, 1, 1))
            x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
            energy_pos = model.forward(x_split,
                                       weights[0],
                                       label=label_tensor,
                                       stop_at_grad=False)

            energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
            energy_partition_est = tf.reduce_logsumexp(energy_pos_full,
                                                       axis=1,
                                                       keepdims=True)
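            # Gumbel-max sampling: pick a label for each image with probability proportional to exp(-energy).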
            uniform = tf.random_uniform(tf.shape(energy_pos_full))
            label_tensor = tf.argmax(-energy_pos_full -
                                     tf.log(-tf.log(uniform)) -
                                     energy_partition_est,
                                     axis=1)
            label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
            label = tf.Print(label, [label_tensor, energy_pos_full])
            LABEL_SPLIT[j] = label
            energy_pos = tf.concat(energy_pos, axis=0)
        else:
            energy_pos = [
                model.forward(X_SPLIT[j],
                              weights[0],
                              label=LABEL_POS_SPLIT[j],
                              stop_at_grad=False)
            ]
            energy_pos = tf.concat(energy_pos, axis=0)

        print("Building graph...")
        x_mod = x_orig = X_NOISE_SPLIT[j]

        x_grads = []

        energy_negs = []
        loss_energys = []

        energy_negs.extend([
            model.forward(tf.stop_gradient(x_mod),
                          weights[0],
                          label=LABEL_SPLIT[j],
                          stop_at_grad=False,
                          reuse=True)
        ])
        eps_begin = tf.zeros(1)

        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)

        def langevin_step(counter, x_mod):
            x_mod = x_mod + tf.random_normal(
                tf.shape(x_mod),
                mean=0.0,
                stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = energy_start = tf.concat([
                model.forward(x_mod,
                              weights[0],
                              label=LABEL_SPLIT[j],
                              reuse=True,
                              stop_at_grad=False,
                              stop_batch=True)
            ],
                                                    axis=0)

            x_grad, label_grad = tf.gradients(FLAGS.temperature * energy_noise,
                                              [x_mod, LABEL_SPLIT[j]])
            energy_noise_old = energy_noise

            lr = FLAGS.step_lr

            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm,
                                              FLAGS.proj_norm)
                else:
                    print("Other types of projection are not supported!!!")
                    assert False

            # Clip gradient norm for now
            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - (lr) * x_grad

            x_mod = x_last
            x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

            counter = counter + 1

            return counter, x_mod

        steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))

        energy_eval = model.forward(x_mod,
                                    weights[0],
                                    label=LABEL_SPLIT[j],
                                    stop_at_grad=False,
                                    reuse=True)
        x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        energy_negs.append(
            model.forward(tf.stop_gradient(x_mod),
                          weights[0],
                          label=LABEL_SPLIT[j],
                          stop_at_grad=False,
                          reuse=True))

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        loss_energy = model.forward(x_mod,
                                    weights[0],
                                    reuse=True,
                                    label=LABEL,
                                    stop_grad=True)

        print("Finished processing loop construction ...")

        target_vars = {}

        if FLAGS.cclass or FLAGS.model_cclass:
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(
                label_prob * tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:

            if FLAGS.objective == 'logsumexp':
                pos_term = temp * energy_pos
                energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'cd':
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = -tf.reduce_mean(temp * energy_neg)
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'softplus':
                loss_ml = FLAGS.ml_coeff * \
                    tf.nn.softplus(temp * (energy_pos - energy_neg))

            loss_total = tf.reduce_mean(loss_ml)

            if not FLAGS.zero_kl:
                loss_total = loss_total + tf.reduce_mean(loss_energy)

            loss_total = loss_total + \
                FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

            print("Started gradient computation...")
            gvs = optimizer.compute_gradients(loss_total)
            gvs = [(k, v) for (k, v) in gvs if k is not None]

            print("Applying gradients...")

            tower_grads.append(gvs)

            print("Finished applying gradients.")

            target_vars['loss_ml'] = loss_ml
            target_vars['total_loss'] = loss_total
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin

    if FLAGS.train:
        grads = average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(grads)
        target_vars['train_op'] = train_op

    config = tf.ConfigProto()

    if hvd.size() > 1:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(max_to_keep=30,
                                    keep_checkpoint_every_n_hours=6)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    sess.run(hvd.broadcast_global_variables(0))
    print("Initializing variables...")

    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        print("Training phase")
        train(target_vars, saver, sess, logger, data_loader, resume_itr,
              logdir)
    print("Testing phase")
    test(target_vars, saver, sess, logger, data_loader)
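
Both training examples above turn the Langevin negatives into a contrastive objective; with FLAGS.objective == 'cd' this amounts to pushing data energies down, sample energies up, and L2-regularizing both terms. A minimal sketch of just that loss, with placeholder energies standing in for the model's forward passes:

import tensorflow as tf

# energy_pos / energy_neg would come from model.forward(...) on real data and on
# Langevin samples respectively; placeholders are used here purely for illustration.
energy_pos = tf.placeholder(tf.float32, (None, 1))
energy_neg = tf.placeholder(tf.float32, (None, 1))
temp, ml_coeff, l2_coeff = 1.0, 1.0, 1.0

# Contrastive-divergence-style surrogate: lower energy on data, raise it on samples.
pos_loss = tf.reduce_mean(temp * energy_pos)
neg_loss = -tf.reduce_mean(temp * energy_neg)
loss_ml = ml_coeff * (pos_loss + neg_loss)

# L2 penalty keeps both energy magnitudes bounded.
loss_total = loss_ml + l2_coeff * (
    tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))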
Example #4
def main():
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if not osp.exists(logdir):
        os.makedirs(logdir)
    logger = TensorBoardOutputFormat(logdir)

    datasource = FLAGS.datasource

    def make_env(rank):
        def _thunk():
            # Make the environments non-stoppable for now
            if datasource == "maze":
                env = Maze(end=[0.7, -0.8],
                           start=[-0.85, -0.85],
                           random_starts=False)
            elif datasource == "point":
                env = Point(end=[0.5, 0.5],
                            start=[0.0, 0.0],
                            random_starts=True)
            elif datasource == "reacher":
                env = Reacher(end=[0.7, 0.5], eps=0.01)
            env.seed(rank)
            env = Monitor(env,
                          os.path.join("/tmp", str(rank)),
                          allow_early_resets=True)
            return env

        return _thunk

    env = SubprocVecEnv(
        [make_env(i + FLAGS.seed) for i in range(FLAGS.num_env)])

    if FLAGS.datasource == 'point' or FLAGS.datasource == 'maze' or FLAGS.datasource == 'reacher':
        if FLAGS.ff_model:
            model = TrajFFDynamics(dim_input=FLAGS.latent_dim,
                                   dim_output=FLAGS.latent_dim)
        else:
            model = TrajNetLatentFC(dim_input=FLAGS.latent_dim)

        X_NOISE = tf.placeholder(shape=(None, FLAGS.total_frame,
                                        FLAGS.input_objects, FLAGS.latent_dim),
                                 dtype=tf.float32)
        X = tf.placeholder(shape=(None, FLAGS.total_frame, FLAGS.input_objects,
                                  FLAGS.latent_dim),
                           dtype=tf.float32)

        if FLAGS.cond:
            ACTION_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        else:
            ACTION_LABEL = None

        ACTION_NOISE_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        ACTION_PLAN = tf.placeholder(shape=(None, FLAGS.plan_steps + 1, 2),
                                     dtype=tf.float32)

        X_START = tf.placeholder(shape=(None, 1, FLAGS.input_objects,
                                        FLAGS.latent_dim),
                                 dtype=tf.float32)
        X_PLAN = tf.placeholder(shape=(None, FLAGS.plan_steps,
                                       FLAGS.input_objects, FLAGS.latent_dim),
                                dtype=tf.float32)

        if FLAGS.datasource == 'reacher':
            X_END = tf.placeholder(shape=(None, 1, FLAGS.input_objects, 2),
                                   dtype=tf.float32)
        else:
            X_END = tf.placeholder(shape=(None, 1, FLAGS.input_objects,
                                          FLAGS.latent_dim),
                                   dtype=tf.float32)
    else:
        raise AssertionError("Unsupported data source")

    weights = model.construct_weights(action_size=FLAGS.action_dim)
    optimizer = AdamOptimizer(1e-2, beta1=0.0, beta2=0.999)

    if FLAGS.ff_model:
        target_vars = construct_ff_model(model, weights, X_NOISE, X,
                                         ACTION_LABEL, ACTION_NOISE_LABEL,
                                         optimizer)
        target_vars = construct_ff_plan_model(model,
                                              weights,
                                              X_PLAN,
                                              X_START,
                                              X_END,
                                              ACTION_PLAN,
                                              target_vars=target_vars)
    else:
        target_vars = construct_model(model, weights, X_NOISE, X, ACTION_LABEL,
                                      ACTION_NOISE_LABEL, optimizer)
        target_vars = construct_plan_model(model,
                                           weights,
                                           X_PLAN,
                                           X_START,
                                           X_END,
                                           ACTION_PLAN,
                                           target_vars=target_vars)

    sess = tf.InteractiveSession()
    saver = loader = tf.train.Saver(max_to_keep=10,
                                    keep_checkpoint_every_n_hours=2)

    tf.global_variables_initializer().run()
    print("Initializing variables...")

    if FLAGS.resume_iter != -1 or not FLAGS.train:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        saver.restore(sess, model_file)

    train(target_vars, saver, sess, logger, FLAGS.resume_iter, env)
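
make_env above follows the usual vectorized-environment pattern: each call returns a zero-argument thunk so SubprocVecEnv can construct and seed the environment inside its own worker process. A tiny illustration of the closure pattern, with a dummy environment class standing in for Maze/Point/Reacher:

class DummyEnv:
    # Stand-in for Maze / Point / Reacher; it only records its seed.
    def seed(self, s):
        self._seed = s

def make_env(rank):
    def _thunk():
        # Built lazily so each worker process creates its own instance.
        env = DummyEnv()
        env.seed(rank)
        return env
    return _thunk

env_fns = [make_env(i) for i in range(4)]
envs = [fn() for fn in env_fns]   # SubprocVecEnv would invoke these in subprocesses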
Example #5
def construct_model(model, weights, X_NOISE, X, ACTION_LABEL,
                    ACTION_NOISE_LABEL, optimizer):
    target_vars = {}
    x_mods = []

    energy_pos = model.forward(X, weights, action_label=ACTION_LABEL)
    energy_noise = energy_start = model.forward(X_NOISE,
                                                weights,
                                                reuse=True,
                                                stop_at_grad=True,
                                                action_label=ACTION_LABEL)

    x_mod = X_NOISE

    x_grads = []
    x_ees = []

    if not FLAGS.gt_inverse_dynamics:
        dyn_model = TrajInverseDynamics(dim_input=FLAGS.latent_dim,
                                        dim_output=FLAGS.action_dim)
        weights = dyn_model.construct_weights(scope="inverse_dynamics",
                                              weights=weights)

    steps = tf.constant(0)
    c = lambda i, x, y: tf.less(i, FLAGS.num_steps)

    def mcmc_step(counter, x_mod, action_label):
        x_mod = x_mod + tf.random_normal(
            tf.shape(x_mod), mean=0.0, stddev=0.01)
        action_label = action_label + tf.random_normal(
            tf.shape(action_label), mean=0.0, stddev=0.01)
        energy_noise = model.forward(x_mod,
                                     weights,
                                     action_label=action_label,
                                     reuse=True)
        lr = FLAGS.step_lr

        x_grad = tf.gradients(FLAGS.temperature * energy_noise, [x_mod])[0]

        x_mod = x_mod - lr * x_grad

        if FLAGS.cond:
            x_grad, action_grad = tf.gradients(
                FLAGS.temperature * energy_noise, [x_mod, action_label])
        else:
            x_grad, action_grad = tf.gradients(
                FLAGS.temperature * energy_noise, [x_mod])[0], tf.zeros(1)

        action_label = action_label - FLAGS.step_lr * action_grad

        x_mod = tf.clip_by_value(x_mod, -1.2, 1.2)
        action_label = tf.clip_by_value(action_label, -1.2, 1.2)

        counter = counter + 1

        return counter, x_mod, action_label

    steps, x_mod, action_label = tf.while_loop(
        c, mcmc_step, (steps, x_mod, ACTION_NOISE_LABEL))

    target_vars['x_mod'] = x_mod
    temp = FLAGS.temperature

    loss_energy = temp * model.forward(
        x_mod, weights, reuse=True, action_label=action_label, stop_grad=True)
    x_mod = tf.stop_gradient(x_mod)
    action_label = tf.stop_gradient(action_label)

    energy_neg = model.forward(x_mod,
                               weights,
                               action_label=action_label,
                               reuse=True)

    if FLAGS.cond:
        x_grad, action_grad = tf.gradients(FLAGS.temperature * energy_neg,
                                           [x_mod, action_label])
    else:
        x_grad, action_grad = tf.gradients(FLAGS.temperature * energy_neg,
                                           [x_mod])[0], tf.zeros(1)

    x_off = tf.reduce_mean(tf.square(x_mod - X))

    if FLAGS.train:
        pos_loss = tf.reduce_mean(temp * energy_pos)
        neg_loss = -tf.reduce_mean(temp * energy_neg)
        loss_ml = (pos_loss + tf.reduce_sum(neg_loss))
        loss_total = tf.reduce_mean(loss_ml)
        loss_total = loss_total + \
                     (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

    if not FLAGS.gt_inverse_dynamics:
        output_action = dyn_model.forward(X, weights)
        dyn_loss = tf.reduce_mean(tf.square(output_action - ACTION_LABEL))
        dyn_dist = tf.reduce_mean(tf.abs(output_action - ACTION_LABEL))
        target_vars['dyn_loss'] = dyn_loss
        target_vars['dyn_dist'] = dyn_dist

        dyn_optimizer = AdamOptimizer(1e-3)
        gvs = dyn_optimizer.compute_gradients(dyn_loss)
        dyn_train_op = dyn_optimizer.apply_gradients(gvs)
    else:
        target_vars['dyn_loss'] = tf.zeros(1)
        target_vars['dyn_dist'] = tf.zeros(1)

    if FLAGS.train:
        print("Started gradient computation...")
        gvs = optimizer.compute_gradients(loss_total)
        gvs = [(k, v) for (k, v) in gvs if k is not None]
        print("Applying gradients...")
        grads, vs = zip(*gvs)

        def filter_grad(g, v):
            return tf.clip_by_value(g, -1e5, 1e5)

        capped_gvs = [(filter_grad(grad, var), var) for grad, var in gvs]
        gvs = capped_gvs
        train_op = optimizer.apply_gradients(gvs)

        if not FLAGS.gt_inverse_dynamics:
            train_op = tf.group(train_op, dyn_train_op)

        target_vars['train_op'] = train_op

        print("Finished applying gradients.")
        target_vars['loss_ml'] = loss_ml
        target_vars['loss_total'] = loss_total
        target_vars['gvs'] = gvs
        target_vars['loss_energy'] = loss_energy

    target_vars['weights'] = weights
    target_vars['X'] = X
    target_vars['X_NOISE'] = X_NOISE
    target_vars['energy_pos'] = energy_pos
    target_vars['energy_neg'] = energy_neg
    target_vars['x_grad'] = x_grad
    target_vars['action_grad'] = action_grad
    target_vars['x_mod'] = x_mod
    target_vars['x_off'] = x_off
    target_vars['temp'] = temp
    target_vars['ACTION_LABEL'] = ACTION_LABEL
    target_vars['ACTION_NOISE_LABEL'] = ACTION_NOISE_LABEL
    target_vars['idyn_model'] = dyn_model

    return target_vars
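
When FLAGS.gt_inverse_dynamics is disabled, construct_model trains the energy model and the inverse-dynamics head with two separate optimizers whose updates are fused into a single op via tf.group. A small self-contained sketch of that two-optimizer pattern, using toy linear models rather than the repository's networks:

import tensorflow as tf

x = tf.placeholder(tf.float32, (None, 4))
y = tf.placeholder(tf.float32, (None, 2))

w_main = tf.get_variable("w_main", (4, 2))
w_aux = tf.get_variable("w_aux", (4, 2))

main_loss = tf.reduce_mean(tf.square(tf.matmul(x, w_main) - y))
aux_loss = tf.reduce_mean(tf.square(tf.matmul(x, w_aux) - y))

# Two independent optimizers, one grouped op that runs both updates per step.
main_op = tf.train.AdamOptimizer(1e-2).minimize(main_loss, var_list=[w_main])
aux_op = tf.train.AdamOptimizer(1e-3).minimize(aux_loss, var_list=[w_aux])
train_op = tf.group(main_op, aux_op)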
def gentest(sess, kvs, data, latents, save_exp_dir):
    X_NOISE = kvs['X_NOISE']
    LABEL_SIZE = kvs['LABEL_SIZE']
    LABEL_SHAPE = kvs['LABEL_SHAPE']
    LABEL_POS = kvs['LABEL_POS']
    LABEL_ROT = kvs['LABEL_ROT']
    model_size = kvs['model_size']
    model_shape = kvs['model_shape']
    model_pos = kvs['model_pos']
    model_rot = kvs['model_rot']
    weight_size = kvs['weight_size']
    weight_shape = kvs['weight_shape']
    weight_pos = kvs['weight_pos']
    weight_rot = kvs['weight_rot']
    X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

    datafull = data
    # Test generalization to combinations by using slices of both training distributions
    x_final = X_NOISE
    x_mod_size = X_NOISE
    x_mod_pos = X_NOISE

    for i in range(FLAGS.num_steps):

        # use cond_pos

        energies = []
        x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
        e_noise = model_pos.forward(x_final, weight_pos, label=LABEL_POS)

        # energies.append(e_noise)
        x_grad = tf.gradients(e_noise, [x_final])[0]
        x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
        x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
        x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)

        if FLAGS.joint_shape:
            # use cond_shape
            e_noise = model_shape.forward(x_mod_pos, weight_shape, label=LABEL_SHAPE)
        elif FLAGS.joint_rot:
            e_noise = model_rot.forward(x_mod_pos, weight_rot, label=LABEL_ROT)
        else:
            # use cond_size
            e_noise = model_size.forward(x_mod_pos, weight_size, label=LABEL_SIZE)

        # energies.append(e_noise)
        # energy_stack = tf.concat(energies, axis=1)
        # energy_stack = tf.reduce_logsumexp(-1*energy_stack, axis=1)
        # energy_stack = tf.reduce_sum(energy_stack, axis=1)

        x_grad = tf.gradients(e_noise, [x_mod_pos])[0]
        x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
        x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)

        # for x_mod_size
        # use cond_size
        # e_noise = model_size.forward(x_mod_size, weight_size, label=LABEL_SIZE)
        # x_grad = tf.gradients(e_noise, [x_mod_size])[0]
        # x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005)
        # x_mod_size = x_mod_size - FLAGS.step_lr * x_grad
        # x_mod_size = tf.clip_by_value(x_mod_size, 0, 1)

        # # use cond_pos
        # e_noise = model_pos.forward(x_mod_size, weight_pos, label=LABEL_POS)
        # x_grad = tf.gradients(e_noise, [x_mod_size])[0]
        # x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005)
        # x_mod_size = x_mod_size - FLAGS.step_lr * tf.stop_gradient(x_grad)
        # x_mod_size = tf.clip_by_value(x_mod_size, 0, 1)

    x_mod = x_mod_pos
    x_final = x_mod


    if FLAGS.joint_shape:
        loss_kl = model_shape.forward(x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True) + \
                  model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)

        energy_pos = model_shape.forward(X, weight_shape, reuse=True, label=LABEL_SHAPE) + \
                      model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)

        energy_neg = model_shape.forward(tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE) + \
                      model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
    elif FLAGS.joint_rot:
        loss_kl = model_rot.forward(x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True) + \
                  model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)

        energy_pos = model_rot.forward(X, weight_rot, reuse=True, label=LABEL_ROT) + \
                      model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)

        energy_neg = model_rot.forward(tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT) + \
                      model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
    else:
        loss_kl = model_size.forward(x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True) + \
                    model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)

        energy_pos = model_size.forward(X, weight_size, reuse=True, label=LABEL_SIZE) + \
                      model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)

        energy_neg = model_size.forward(tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE) + \
                      model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)

    energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
    coeff = tf.stop_gradient(tf.exp(-energy_neg_reduced))
    norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
    neg_loss = coeff * (-1*energy_neg) / norm_constant

    loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg)
    loss_total = loss_ml + tf.reduce_mean(loss_kl) + 1 * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))

    optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999)
    gvs = optimizer.compute_gradients(loss_total)
    gvs = [(k, v) for (k, v) in gvs if k is not None]
    train_op = optimizer.apply_gradients(gvs)

    vs = optimizer.variables()
    sess.run(tf.variables_initializer(vs))

    dataloader = DataLoader(DSpritesGen(data, latents), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)

    x_off = tf.reduce_mean(tf.square(x_mod - X))

    itr = 0
    saver = tf.train.Saver()
    x_mod = None


    if FLAGS.train:
        replay_buffer = ReplayBuffer(10000)
        for _ in range(1):


            for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
                data_corrupt = data_corrupt.numpy()[:, :, :]
                data = data.numpy()[:, :, :]

                if x_mod is not None:
                    replay_buffer.add(x_mod)
                    replay_batch = replay_buffer.sample(FLAGS.batch_size)
                    replay_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95)
                    data_corrupt[replay_mask] = replay_batch[replay_mask]

                if FLAGS.joint_shape:
                    feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SHAPE: label_size, LABEL_POS: label_pos}
                elif FLAGS.joint_rot:
                    feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_ROT: label_size, LABEL_POS: label_pos}
                else:
                    feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SIZE: label_size, LABEL_POS: label_pos}

                _, off_value, e_pos, e_neg, x_mod = sess.run([train_op, x_off, energy_pos, energy_neg, x_final], feed_dict=feed_dict)
                itr += 1

                if itr % 10 == 0:
                    print("x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(off_value, e_pos.mean(), e_neg.mean(), itr))

                if itr == FLAGS.break_steps:
                    break


        saver.save(sess, osp.join(save_exp_dir, 'model_gentest'))

    saver.restore(sess, osp.join(save_exp_dir, 'model_gentest'))

    l = latents

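    # Select the slice of dSprites latents used for the generalization evaluation.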
    if FLAGS.joint_shape:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
    elif FLAGS.joint_rot:
        mask_gen = (l[:, 1] == 1) * (l[:, 2] == 0.5)
    else:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))

    data_gen = datafull[mask_gen]
    latents_gen = latents[mask_gen]

    losses = []

    for dat, latent in zip(np.array_split(data_gen, 120), np.array_split(latents_gen, 120)):
        x = 0.5 + np.random.randn(*dat.shape)

        if FLAGS.joint_shape:
            feed_dict = {LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
        elif FLAGS.joint_rot:
            feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1), LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
        else:
            feed_dict = {LABEL_SIZE: latent[:, 2:3], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}

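        # Refine twice: feed each sampled output back in as the next initialization.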
        for i in range(2):
            x = sess.run([x_final], feed_dict=feed_dict)[0]
            feed_dict[X_NOISE] = x

        loss = sess.run([x_off], feed_dict=feed_dict)[0]
        losses.append(loss)

    print("Mean MSE loss of {} ".format(np.mean(losses)))

    data_try = data_gen[:10]
    data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
    latent_scale = latents_gen[:10, 2:3]
    latent_pos = latents_gen[:10, 4:]

    if FLAGS.joint_shape:
        feed_dict = {X_NOISE: data_init, LABEL_SHAPE: np.eye(3)[latents_gen[:10, 1].astype(np.int32) - 1], LABEL_POS: latent_pos}
    elif FLAGS.joint_rot:
        feed_dict = {LABEL_ROT: np.concatenate([np.cos(latents_gen[:10, 3:4]), np.sin(latents_gen[:10, 3:4])], axis=1), LABEL_POS: latent_pos, X_NOISE: data_init}
    else:
        feed_dict = {X_NOISE: data_init, LABEL_SIZE: latent_scale, LABEL_POS: latent_pos}

    x_output = sess.run([x_final], feed_dict=feed_dict)[0]

    if FLAGS.joint_shape:
        im_name = "size_shape_combine_gentest.png"
    else:
        im_name = "size_scale_combine_gentest.png"

    x_output_wrap = np.ones((10, 66, 66))
    data_try_wrap = np.ones((10, 66, 66))

    x_output_wrap[:, 1:-1, 1:-1] = x_output
    data_try_wrap[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)
    print("Successfully saved images at {}".format(impath))


def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
    # tf.reset_default_graph()

    if FLAGS.joint_shape:
        model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=5)
        LABEL = tf.placeholder(shape=(None, 5), dtype=tf.float32)
    else:
        model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=3)
        LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)

    weights_baseline = model_baseline.construct_weights('context_baseline_{}'.format(frac))

    X_feed = tf.placeholder(shape=(None, 2*FLAGS.num_filters), dtype=tf.float32)
    X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

    X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL)
    loss_sq = tf.reduce_mean(tf.square(X_out - X_label))

    optimizer = AdamOptimizer(1e-3)
    gvs = optimizer.compute_gradients(loss_sq)
    gvs = [(k, v) for (k, v) in gvs if k is not None]
    train_op = optimizer.apply_gradients(gvs)

    dataloader = DataLoader(DSpritesGen(data, latents, frac=frac), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)

    datafull = data

    itr = 0
    saver = tf.train.Saver()

    vs = optimizer.variables()
    sess.run(tf.global_variables_initializer())

    if FLAGS.train:
        for _ in range(5):
            for data_corrupt, data, label_size, label_pos in tqdm(dataloader):

                data_corrupt = data_corrupt.numpy()
                label_size, label_pos = label_size.numpy(), label_pos.numpy()

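                # The baseline maps a random latent vector plus the (size, position)
                # label straight to an image and is trained with an MSE loss.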
                data_corrupt = np.random.randn(data_corrupt.shape[0], 2*FLAGS.num_filters)
                label_comb = np.concatenate([label_size, label_pos], axis=1)

                feed_dict = {X_feed: data_corrupt, X_label: data, LABEL: label_comb}

                output = [loss_sq, train_op]

                loss, _ = sess.run(output, feed_dict=feed_dict)

                itr += 1

        saver.save(sess, osp.join(save_exp_dir, 'model_genbaseline'))

    saver.restore(sess, osp.join(save_exp_dir, 'model_genbaseline'))

    l = latents

    if FLAGS.joint_shape:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
    else:
        mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))

    data_gen = datafull[mask_gen]
    latents_gen = latents[mask_gen]
    losses = []

    for dat, latent in zip(np.array_split(data_gen, 10), np.array_split(latents_gen, 10)):
        data_init = np.random.randn(dat.shape[0], 2*FLAGS.num_filters)

        if FLAGS.joint_shape:
            latent_size = np.eye(3)[latent[:, 1].astype(np.int32) - 1]
            latent_pos = latent[:, 4:6]
            latent = np.concatenate([latent_size, latent_pos], axis=1)
            feed_dict = {X_feed: data_init, LABEL: latent, X_label: dat}
        else:
            feed_dict = {X_feed: data_init, LABEL: latent[:, [2,4,5]], X_label: dat}
        loss = sess.run([loss_sq], feed_dict=feed_dict)[0]
        # print(loss)
        losses.append(loss)

    print("Overall MSE for generalization of {} for fraction of {}".format(np.mean(losses), frac))


    data_try = data_gen[:10]
    data_init = np.random.randn(10, 2*FLAGS.num_filters)

    if FLAGS.joint_shape:
        latent_scale = np.eye(3)[latents_gen[:10, 1].astype(np.int32) - 1]
        latent_pos = latents_gen[:10, 4:]
    else:
        latent_scale = latents_gen[:10, 2:3]
        latent_pos = latents_gen[:10, 4:]

    latent_tot = np.concatenate([latent_scale, latent_pos], axis=1)

    feed_dict = {X_feed: data_init, LABEL: latent_tot}
    x_output = sess.run([X_out], feed_dict=feed_dict)[0]
    x_output = np.clip(x_output, 0, 1)

    im_name = "size_scale_combine_genbaseline.png"

    x_output_wrap = np.ones((10, 66, 66))
    data_try_wrap = np.ones((10, 66, 66))

    x_output_wrap[:, 1:-1, 1:-1] = x_output
    data_try_wrap[:, 1:-1, 1:-1] = data_try

    im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
    impath = osp.join(save_exp_dir, im_name)
    imsave(impath, im_output)
    print("Successfully saved images at {}".format(impath))

    return np.mean(losses)
Example #8
0
def main():
    print("Local rank: ", hvd.local_rank(), hvd.size())
    FLAGS.exp = FLAGS.exp + '_' + FLAGS.divergence

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    if hvd.rank() == 0:
        if not osp.exists(logdir):
            os.makedirs(logdir)
        logger = TensorBoardOutputFormat(logdir)
    else:
        logger = None

    print("Loading data...")
    dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
    test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
    channel_num = 3

    X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
    X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
    LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
    LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

    if FLAGS.large_model:
        model = ResNet32Large(
            num_channels=channel_num,
            num_filters=128,
            train=True)
        model_dis = ResNet32Large(
            num_channels=channel_num,
            num_filters=128,
            train=True)
    elif FLAGS.larger_model:
        model = ResNet32Larger(
            num_channels=channel_num,
            num_filters=128)
        model_dis = ResNet32Larger(
            num_channels=channel_num,
            num_filters=128)
    elif FLAGS.wider_model:
        model = ResNet32Wider(
            num_channels=channel_num,
            num_filters=256)
        model_dis = ResNet32Wider(
            num_channels=channel_num,
            num_filters=256)
    else:
        model = ResNet32(
            num_channels=channel_num,
            num_filters=128)
        model_dis = ResNet32(
            num_channels=channel_num,
            num_filters=128)

    print("Done loading...")

    grad_exp, conjugate_grad_exp = get_divergence_funcs(FLAGS.divergence)

    data_loader = DataLoader(
        dataset,
        batch_size=FLAGS.batch_size,
        num_workers=FLAGS.data_workers,
        drop_last=True,
        shuffle=True)

    weights = [model.construct_weights('context_energy'), model_dis.construct_weights('context_dis')]

    Y = tf.placeholder(shape=(None), dtype=tf.int32)

    # Variables to run in training
    X_SPLIT = tf.split(X, FLAGS.num_gpus)
    X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
    LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
    LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
    LABEL_SPLIT_INIT = list(LABEL_SPLIT)
    tower_grads = []
    tower_grads_dis = []
    tower_grads_l2 = []
    tower_grads_dis_l2 = []

    optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer = hvd.DistributedOptimizer(optimizer)

    optimizer_dis = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer_dis = hvd.DistributedOptimizer(optimizer_dis)

    for j in range(FLAGS.num_gpus):

        energy_pos = [
            model.forward(
                X_SPLIT[j],
                weights[0],
                label=LABEL_POS_SPLIT[j],
                stop_at_grad=False)]
        energy_pos = tf.concat(energy_pos, axis=0)

        score_pos = [
            model_dis.forward(
                X_SPLIT[j],
                weights[1],
                label=LABEL_POS_SPLIT[j],
                stop_at_grad=False)]
        score_pos = tf.concat(score_pos, axis=0)

        print("Building graph...")
        x_mod = x_orig = X_NOISE_SPLIT[j]

        x_grads = []

        energy_negs = []
        loss_energys = []

        energy_negs.extend([model.forward(tf.stop_gradient(
            x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)])
        eps_begin = tf.zeros(1)

        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)

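        # One Langevin step: inject Gaussian noise, descend the energy gradient
        # (or run HMC), then clip the sample back to the valid pixel range.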
        def langevin_step(counter, x_mod):
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
                                             mean=0.0,
                                             stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = energy_start = tf.concat(
                [model.forward(
                        x_mod,
                        weights[0],
                        label=LABEL_SPLIT[j],
                        reuse=True,
                        stop_at_grad=False,
                        stop_batch=True)],
                axis=0)

            x_grad, label_grad = tf.gradients(energy_noise, [x_mod, LABEL_SPLIT[j]])
            energy_noise_old = energy_noise

            lr = FLAGS.step_lr

            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(
                        x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
                else:
                    raise ValueError("Unsupported projection norm type: {}".format(FLAGS.proj_norm_type))

            # Clip gradient norm for now
            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - (lr) * x_grad

            x_mod = x_last
            x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

            counter = counter + 1

            return counter, x_mod

        steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))

        energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j],
                                    stop_at_grad=False, reuse=True)
        x_grad = tf.gradients(energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        energy_negs.append(
            model.forward(
                tf.stop_gradient(x_mod),
                weights[0],
                label=LABEL_SPLIT[j],
                stop_at_grad=False,
                reuse=True))

        score_neg = model_dis.forward(
                tf.stop_gradient(x_mod),
                weights[1],
                label=LABEL_SPLIT[j],
                stop_at_grad=False,
                reuse=True)

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        loss_energy = model.forward(
            x_mod,
            weights[0],
            reuse=True,
            label=LABEL,
            stop_grad=True)

        print("Finished processing loop construction ...")

        target_vars = {}

        if FLAGS.cclass or FLAGS.model_cclass:
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(label_prob *
                                       tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:

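            # f-divergence training: grad_exp / conjugate_grad_exp are the functions
            # returned by get_divergence_funcs for the chosen divergence. loss_dis
            # updates the auxiliary score network (context_dis) and loss_model updates
            # the energy model (context_energy); both carry an L2 penalty on their outputs.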
            loss_dis = - (tf.reduce_mean(grad_exp(score_pos + energy_pos)) - tf.reduce_mean(conjugate_grad_exp(score_neg + energy_neg)))
            loss_dis = loss_dis + FLAGS.l2_coeff * (tf.reduce_mean(tf.square(score_pos)) + tf.reduce_mean(tf.square(score_neg)))
            l2_dis = FLAGS.l2_coeff * (tf.reduce_mean(tf.square(score_pos)) + tf.reduce_mean(tf.square(score_neg)))

            loss_model = tf.reduce_mean(grad_exp(score_pos + energy_pos)) + \
                         tf.reduce_mean(energy_neg * tf.stop_gradient(conjugate_grad_exp(score_neg + energy_neg))) - \
                         tf.reduce_mean(energy_neg) * tf.stop_gradient(tf.reduce_mean(conjugate_grad_exp(score_neg + energy_neg))) - \
                         tf.reduce_mean(conjugate_grad_exp(score_neg + energy_neg))
            loss_model = loss_model + FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))
            l2_model = FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))

            print("Started gradient computation...")
            model_vars = [var for var in tf.trainable_variables() if 'context_energy' in var.name]
            print("model var number", len(model_vars))
            dis_vars = [var for var in tf.trainable_variables() if 'context_dis' in var.name]
            print("discriminator var number", len(dis_vars))

            gvs = optimizer.compute_gradients(loss_model, model_vars)
            gvs = [(k, v) for (k, v) in gvs if k is not None]
            tower_grads.append(gvs)

            gvs = optimizer.compute_gradients(l2_model, model_vars)
            gvs = [(k, v) for (k, v) in gvs if k is not None]
            tower_grads_l2.append(gvs)

            gvs_dis = optimizer_dis.compute_gradients(loss_dis, dis_vars)
            gvs_dis = [(k, v) for (k, v) in gvs_dis if k is not None]
            tower_grads_dis.append(gvs_dis)

            gvs_dis = optimizer_dis.compute_gradients(l2_dis, dis_vars)
            gvs_dis = [(k, v) for (k, v) in gvs_dis if k is not None]
            tower_grads_dis_l2.append(gvs_dis)

            print("Finished computing gradients.")

            target_vars['total_loss'] = loss_model
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin
        target_vars['score_neg'] = score_neg
        target_vars['score_pos'] = score_pos

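    # Average the per-GPU tower gradients and build one apply op per objective.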
    if FLAGS.train:
        grads_model = average_gradients(tower_grads)
        train_op_model = optimizer.apply_gradients(grads_model)
        target_vars['train_op_model'] = train_op_model

        grads_model_l2 = average_gradients(tower_grads_l2)
        train_op_model_l2 = optimizer.apply_gradients(grads_model_l2)
        target_vars['train_op_model_l2'] = train_op_model_l2

        grads_model_dis = average_gradients(tower_grads_dis)
        train_op_dis = optimizer_dis.apply_gradients(grads_model_dis)
        target_vars['train_op_dis'] = train_op_dis

        grads_model_dis_l2 = average_gradients(tower_grads_dis_l2)
        train_op_dis_l2 = optimizer_dis.apply_gradients(grads_model_dis_l2)
        target_vars['train_op_dis_l2'] = train_op_dis_l2

    config = tf.ConfigProto()

    if hvd.size() > 1:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(max_to_keep=500)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        saver.restore(sess, model_file)
        # optimistic_restore(sess, model_file)

    print("Initializing variables...")
    print("Start broadcast")
    sess.run(hvd.broadcast_global_variables(0))
    print("End broadcast")

    if FLAGS.train:
        train(target_vars, saver, sess,
              logger, data_loader, resume_itr,
              logdir)

    test(target_vars, saver, sess, logger, data_loader)
Example #9
0
def construct_steps_dual(weights_pos, weights_color, model_pos, model_color, target_vars):
    steps = tf.constant(0)
    STEPS = target_vars['STEPS']
    c = lambda i, x: tf.less(i, STEPS)
    X, Y_first, Y_second = target_vars['X'], target_vars['Y_first'], target_vars['Y_second']
    X_feed = target_vars['X_feed']
    attention_mask = tf.zeros(1)

    def langevin_step(counter, x_mod):
        x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
                                         mean=0.0,
                                         stddev=0.001)
        # latent = latent + tf.random_normal(tf.shape(latent), mean=0.0, stddev=0.01)
        energy_noise = model_pos.forward(
                    x_mod,
                    weights_pos,
                    label=Y_first,
                    reuse=True,
                    stop_at_grad=False,
                    stop_batch=True,
                    attention_mask=attention_mask)
        # energy_noise = tf.Print(energy_noise, [energy_noise, x_mod, latent])
        x_grad = tf.gradients(energy_noise, [x_mod])[0]

        x_mod = x_mod - FLAGS.step_lr * x_grad
        x_mod = tf.clip_by_value(x_mod, 0, 1)

        x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.001)

        energy_noise = 1.0 * model_color.forward(
                    x_mod,
                    weights_color,
                    label=Y_second,
                    reuse=True,
                    stop_at_grad=False,
                    stop_batch=True,
                    attention_mask=attention_mask)

        x_grad = tf.gradients(energy_noise, [x_mod])[0]

        x_mod = x_mod - FLAGS.step_lr * x_grad
        x_mod = tf.clip_by_value(x_mod, 0, 1)

        counter = counter + 1
        # counter = tf.Print(counter, [counter], message="step")

        return counter, x_mod

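    # Joint sampling: sum the two energies so a single gradient step descends the
    # product of the position and color models.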
    def langevin_merge_step(counter, x_mod):
        x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
                                         mean=0.0,
                                         stddev=0.001)
        # latent = latent + tf.random_normal(tf.shape(latent), mean=0.0, stddev=0.01)
        energy_noise = 1 * model_pos.forward(
                    x_mod,
                    weights_pos,
                    label=Y_first,
                    reuse=True,
                    stop_at_grad=False,
                    stop_batch=True,
                    attention_mask=attention_mask) + \
                    1.0 * model_color.forward(
                    x_mod,
                    weights_color,
                    label=Y_second,
                    reuse=True,
                    stop_at_grad=False,
                    stop_batch=True,
                    attention_mask=attention_mask)

        # energy_noise = tf.Print(energy_noise, [energy_noise, x_mod, latent])
        x_grad = tf.gradients(energy_noise, [x_mod])[0]

        x_mod = x_mod - FLAGS.step_lr * x_grad
        x_mod = tf.clip_by_value(x_mod, 0, 1)

        counter = counter + 1

        return counter, x_mod

    steps, x_mod = tf.while_loop(c, langevin_merge_step, (steps, X))

    energy_pos = model_pos.forward(
                tf.stop_gradient(x_mod),
                weights_pos,
                label=Y_first,
                reuse=True,
                stop_at_grad=False,
                stop_batch=True,
                attention_mask=attention_mask)

    energy_color = model_color.forward(
                tf.stop_gradient(x_mod),
                weights_color,
                label=Y_second,
                reuse=True,
                stop_at_grad=False,
                stop_batch=True,
                attention_mask=attention_mask)

    energy_plus_pos = model_pos.forward(
                X_feed,
                weights_pos,
                label=Y_first,
                reuse=True,
                stop_at_grad=False,
                stop_batch=True,
                attention_mask=attention_mask)

    energy_plus_color = model_color.forward(
                X_feed,
                weights_color,
                label=Y_second,
                reuse=True,
                stop_at_grad=False,
                stop_batch=True,
                attention_mask=attention_mask)

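    # Finetuning loss: lower both energies on real images (X_feed), raise them on the
    # jointly sampled x_mod, and penalize the squared energies.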
    energy_neg = -tf.reduce_mean(tf.reduce_mean(energy_pos) + tf.reduce_mean(energy_color))
    energy_plus = tf.reduce_mean(tf.reduce_mean(energy_plus_pos) + tf.reduce_mean(energy_plus_color))
    loss_l2 = tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_color)) + tf.reduce_mean(tf.square(energy_plus_pos)) + tf.reduce_mean(tf.square(energy_plus_color))

    loss_total = energy_plus + energy_neg + loss_l2
    optimizer = AdamOptimizer(1e-4, beta1=0.0, beta2=0.99)
    gvs = optimizer.compute_gradients(loss_total)
    train_op = optimizer.apply_gradients(gvs)

    x_off = tf.reduce_mean(tf.abs(X_feed - x_mod))

    target_vars['X_final'] = x_mod
    target_vars['x_off'] = x_off
    target_vars['train_op'] = train_op
    target_vars['energy_neg'] = -energy_neg
    target_vars['energy_pos'] = energy_plus