Example #1
0
    def build_tower_graph(id_):
        """Build the graph shard for one GPU tower.

        Returns (grads, [lower_bound, lower_bound_pairwise]) where grads are
        the gradients of -(lower_bound + lower_bound_pairwise).
        """

        def shard(t):
            # Contiguous slice of the batch for this tower. Both endpoints
            # use (id_ * size) // num_gpus so the shards stay disjoint and
            # exhaustive even when size % num_gpus != 0.
            size = tf.shape(t)[0]
            lo = id_ * size // FLAGS.num_gpus
            hi = (id_ + 1) * size // FLAGS.num_gpus
            return t[lo:hi]

        tower_x = shard(x)
        tower_font_source = shard(font_source)
        tower_char_source = shard(char_source)
        n = tf.shape(tower_x)[0]
        # Leading size-1 axis acts as the particle axis for zs.iwae.
        x_obs = tf.tile(tf.expand_dims(tower_x, 0), [1, 1, 1])

        def log_joint(observed):
            # Joint log-density log p(z_char) + log p(z_font) + log p(x|z).
            decoder, _ = VLAE(observed, n, is_training)
            log_pz_char, log_pz_font, log_px_z = decoder.local_log_prob(
                ['z_char', 'z_font', 'x'])
            return log_pz_char + log_pz_font + log_px_z

        # Posteriors inferred from the tower's own inputs.
        encoder, _, _ = q_net(None, tower_x, is_training)
        qz_samples_font, log_qz_font = encoder.query(
            'z_font', outputs=True, local_log_prob=True)
        qz_samples_char, log_qz_char = encoder.query(
            'z_char', outputs=True, local_log_prob=True)

        # Posteriors inferred from the font/char source images.
        encoder, _, _ = q_net(None, tower_font_source, is_training)
        qz_samples_font_source, log_qz_font_source = encoder.query(
            'z_font', outputs=True, local_log_prob=True)
        encoder, _, _ = q_net(None, tower_char_source, is_training)
        qz_samples_char_source, log_qz_char_source = encoder.query(
            'z_char', outputs=True, local_log_prob=True)

        latent = {'z_font': [qz_samples_font, log_qz_font],
                  'z_char': [qz_samples_char, log_qz_char]}
        lower_bound = tf.reduce_mean(
            zs.iwae(log_joint, {'x': x_obs}, latent, axis=0))

        # Pairwise term: same observation scored under the source posteriors.
        latent_pairwise = {
            'z_font': [qz_samples_font_source, log_qz_font_source],
            'z_char': [qz_samples_char_source, log_qz_char_source]}
        lower_bound_pairwise = pairwise_alpha * tf.reduce_mean(
            zs.iwae(log_joint, {'x': x_obs}, latent_pairwise, axis=0))

        grads = optimizer.compute_gradients(
            -lower_bound - lower_bound_pairwise)
        return grads, [lower_bound, lower_bound_pairwise]
Example #2
0
	def build_tower_graph(id_):
		"""Build the graph shard for one GPU tower.

		Returns:
			(grads, [lower_bound, info_char, info_font]) -- `grads` holds the
			full-model gradients of -lower_bound followed by encoder-only
			gradients of -(info_char + info_font).
		"""
		# Contiguous batch shard for this tower; (id_ * size) // num_gpus on
		# both endpoints keeps shards disjoint and exhaustive even when the
		# batch size is not divisible by num_gpus.
		tower_x = x[id_ * tf.shape(x)[0] // FLAGS.num_gpus:
					(id_ + 1) * tf.shape(x)[0] // FLAGS.num_gpus]
		tower_font_source = font_source[
			id_ * tf.shape(font_source)[0] // FLAGS.num_gpus:
			(id_ + 1) * tf.shape(font_source)[0] // FLAGS.num_gpus]
		tower_char_source = char_source[
			id_ * tf.shape(char_source)[0] // FLAGS.num_gpus:
			(id_ + 1) * tf.shape(char_source)[0] // FLAGS.num_gpus]
		n = tf.shape(tower_x)[0]
		# Leading size-1 axis is the particle axis for zs.iwae; the
		# [1, 1, 1] tile suggests rank-2 inputs -- TODO confirm upstream.
		x_obs = tf.tile(tf.expand_dims(tower_x, 0), [1, 1, 1])
		char_obs = tf.tile(tf.expand_dims(tower_char_source, 0), [1, 1, 1])
		font_obs = tf.tile(tf.expand_dims(tower_font_source, 0), [1, 1, 1])

		def log_joint(observed):
			# Joint log-density under the decoder; only the priors of the
			# latents actually present in `observed` contribute.
			decoder, _ = VLAE(observed, n, is_training)
			has_char = 'z_char' in observed
			has_font = 'z_font' in observed
			if has_char and has_font:
				log_pz_char, log_pz_font, log_px_z = decoder.local_log_prob(
					['z_char', 'z_font', 'x'])
				return log_pz_char + log_pz_font + log_px_z
			if has_char:
				log_pz_char, log_px_z = decoder.local_log_prob(['z_char', 'x'])
				return log_pz_char + log_px_z
			if has_font:
				log_pz_font, log_px_z = decoder.local_log_prob(['z_font', 'x'])
				return log_pz_font + log_px_z

		# One encoder pass on tower_x serves both latents (the original built
		# a second, identical q_net on the same input).
		encoder, _, _ = q_net(None, tower_x, is_training)
		qz_samples_font, log_qz_font = encoder.query('z_font', outputs=True,
													 local_log_prob=True)
		qz_samples_char, log_qz_char = encoder.query('z_char', outputs=True,
													 local_log_prob=True)
		lower_bound = tf.reduce_mean(
			zs.iwae(log_joint, {'x': x_obs},
					{'z_font': [qz_samples_font, log_qz_font],
					 'z_char': [qz_samples_char, log_qz_char]}, axis=0))

		info_char = tf.reduce_mean(
			zs.iwae(log_joint, {'x': char_obs},
					{'z_char': [qz_samples_char, log_qz_char]}, axis=0))

		# BUG FIX: info_font previously reused the *char* posterior
		# ({'z_char': ...}); the font bound must score font_obs under z_font.
		info_font = tf.reduce_mean(
			zs.iwae(log_joint, {'x': font_obs},
					{'z_font': [qz_samples_font, log_qz_font]}, axis=0))

		encoder_var_list = tf.get_collection(
			tf.GraphKeys.TRAINABLE_VARIABLES, scope='encoder')

		grads = optimizer.compute_gradients(-lower_bound)
		encoder_grads = optimizer.compute_gradients(
			-(info_char + info_font), var_list=encoder_var_list)
		return grads + encoder_grads, [lower_bound, info_char, info_font]
Example #3
0
    def build_tower_graph(x, id_):
        """Per-GPU shard: IWAE lower bound for the code-conditioned VAE.

        Returns (grads, lower_bound).
        """

        def shard(t):
            # This tower's contiguous share of the batch; both endpoints use
            # (id_ * size) // num_gpus so shards partition the batch exactly.
            size = tf.shape(t)[0]
            return t[id_ * size // FLAGS.num_gpus:
                     (id_ + 1) * size // FLAGS.num_gpus]

        tower_x_orig = shard(x_orig)  # sliced as in the original (unused here)
        tower_x = shard(x)
        tower_code = shard(code)
        n = tf.shape(tower_x)[0]
        # Leading size-1 axis is the particle axis for zs.iwae.
        x_obs = tf.tile(tf.expand_dims(tower_x, 0), [1, 1, 1])

        def log_joint(observed):
            # Joint log-density log p(z) + log p(x|z, code).
            model, _ = vae(observed, n, tower_code, is_training)
            log_pz, log_px_z = model.local_log_prob(['z', 'x'])
            return log_pz + log_px_z

        encoder, _ = q_net(None, tower_x, tower_code, is_training)
        qz_samples, log_qz = encoder.query(
            'z', outputs=True, local_log_prob=True)

        lower_bound = tf.reduce_mean(zs.iwae(
            log_joint, {'x': x_obs}, {'z': [qz_samples, log_qz]}, axis=0))

        grads = optimizer.compute_gradients(-lower_bound)
        return grads, lower_bound
    def build_tower_graph(x, id_):
        """Per-GPU shard: IWAE bound plus a total-variation smoothness
        penalty on the reconstruction decoded from the *source* pair.

        Returns (grads, [lower_bound, tv_loss]).
        """
        # Contiguous batch shard; (id_ * size) // num_gpus on both endpoints
        # keeps shards disjoint/exhaustive when size % num_gpus != 0.
        tower_x_orig = x_orig[id_ * tf.shape(x_orig)[0] // FLAGS.num_gpus:
                              (id_ + 1) * tf.shape(x_orig)[0] // FLAGS.num_gpus]
        # BUG FIX: the start index previously used tf.shape(x_orig); slice
        # x_source by its own batch size on both endpoints.
        tower_x_source = x_source[
            id_ * tf.shape(x_source)[0] // FLAGS.num_gpus:
            (id_ + 1) * tf.shape(x_source)[0] // FLAGS.num_gpus]
        tower_x = x[id_ * tf.shape(x)[0] // FLAGS.num_gpus:
                    (id_ + 1) * tf.shape(x)[0] // FLAGS.num_gpus]
        tower_code = code[id_ * tf.shape(code)[0] // FLAGS.num_gpus:
                          (id_ + 1) * tf.shape(code)[0] // FLAGS.num_gpus]
        tower_code_source = code_source[
            id_ * tf.shape(code_source)[0] // FLAGS.num_gpus:
            (id_ + 1) * tf.shape(code_source)[0] // FLAGS.num_gpus]
        n = tf.shape(tower_x)[0]
        # Leading size-1 axis is the particle axis for zs.iwae.
        x_obs = tf.tile(tf.expand_dims(tower_x, 0), [1, 1, 1])

        def log_joint(observed):
            # Joint log-density log p(z) + log p(x|z, code).
            decoder, _ = vae(observed, n, tower_code, is_training)
            log_pz, log_px_z = decoder.local_log_prob(['z', 'x'])
            return log_pz + log_px_z

        # Posterior inferred from the source image/code pair, then decoded
        # with the target code (style transfer setup).
        encoder, _ = q_net(None, tower_x_source, tower_code_source, is_training)
        qz_samples, log_qz = encoder.query('z', outputs=True,
                                           local_log_prob=True)

        _, train_x_recon = vae({'z': qz_samples}, n, tower_code, is_training)
        train_x_recon = tf.reshape(train_x_recon, [-1, n_xl, n_xl, 1])
        # Anisotropic total-variation penalty: L2 of vertical and horizontal
        # neighbor differences, scaled by args.tv_alpha.
        tv_loss = (tf.nn.l2_loss(train_x_recon[:, 1:, :, :] -
                                 train_x_recon[:, :(n_xl - 1), :, :]) / n_xl +
                   tf.nn.l2_loss(train_x_recon[:, :, 1:, :] -
                                 train_x_recon[:, :, :(n_xl - 1), :]) / n_xl
                   ) * args.tv_alpha

        lower_bound = tf.reduce_mean(
            zs.iwae(log_joint, {'x': x_obs}, {'z': [qz_samples, log_qz]},
                    axis=0))

        grads = optimizer.compute_gradients(-lower_bound + tv_loss)
        return grads, [lower_bound, tv_loss]
Example #5
0
                    tf.int32)
    # Placeholder for a batch of integer-coded inputs of width n_x.
    x = tf.placeholder(tf.int32, shape=[None, n_x], name='x')
    # Replicate the batch along a leading particle axis for the IWAE bound.
    x_obs = tf.tile(tf.expand_dims(x, 0), [n_particles, 1, 1])
    n = tf.shape(x)[0]

    def log_joint(observed):
        # Joint log-density log p(z) + log p(x|z) under the generative model.
        model = vae(observed, n, n_x, n_z, n_particles)
        log_pz, log_px_z = model.local_log_prob(['z', 'x'])
        return log_pz + log_px_z

    # Variational posterior q(z|x); IWAE averages over particle axis 0.
    variational = q_net({}, x, n_z, n_particles)
    qz_samples, log_qz = variational.query('z',
                                           outputs=True,
                                           local_log_prob=True)
    lower_bound = tf.reduce_mean(
        zs.iwae(log_joint, {'x': x_obs}, {'z': [qz_samples, log_qz]}, axis=0))

    # Maximize the bound with Adam; learning rate is fed at run time.
    learning_rate_ph = tf.placeholder(tf.float32, shape=[], name='lr')
    optimizer = tf.train.AdamOptimizer(learning_rate_ph, epsilon=1e-4)
    grads = optimizer.compute_gradients(-lower_bound)
    infer = optimizer.apply_gradients(grads)

    # Print every trainable variable as a quick model-size sanity check.
    params = tf.trainable_variables()
    for i in params:
        print(i.name, i.get_shape())

    saver = tf.train.Saver(max_to_keep=10)

    # Run the inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
Example #6
0
    def build_tower_graph(id_):
        """Build the graph shard for one GPU tower.

        Behavior depends on args.mode:
          - 'font': train only the font encoder/decoder (VAE_font).
          - 'char': train the char encoder plus the joint decoder (VAE).
          - otherwise: train both encoders plus the joint decoder.

        Returns (grads, [lower_bound]).
        """
        # Contiguous batch shard for this tower; (id_ * size) // num_gpus on
        # both endpoints keeps shards disjoint and exhaustive even when the
        # batch size is not divisible by num_gpus.
        tower_x = x[id_ * tf.shape(x)[0] // FLAGS.num_gpus:(id_ + 1) *
                    tf.shape(x)[0] // FLAGS.num_gpus]
        tower_font_source = font_source[id_ * tf.shape(font_source)[0] //
                                        FLAGS.num_gpus:(id_ + 1) *
                                        tf.shape(font_source)[0] //
                                        FLAGS.num_gpus]
        tower_char_source = char_source[id_ * tf.shape(char_source)[0] //
                                        FLAGS.num_gpus:(id_ + 1) *
                                        tf.shape(char_source)[0] //
                                        FLAGS.num_gpus]
        tower_code = code[id_ * tf.shape(code)[0] // FLAGS.num_gpus:(id_ + 1) *
                          tf.shape(code)[0] // FLAGS.num_gpus]
        n = tf.shape(tower_x)[0]
        # Leading size-1 axis is the particle axis for zs.iwae.
        x_obs = tf.tile(tf.expand_dims(tower_x, 0), [1, 1, 1])

        if args.mode == 'font':

            def log_joint(observed):
                # log p(z_font) + log p(x|z_font, code) under the font VAE.
                decoder, _ = VAE_font(observed, n, tower_code, is_training)
                log_pz, log_px_z = decoder.local_log_prob(['z_font', 'x'])
                return log_pz + log_px_z

            encoder, _ = q_net(None, tower_font_source, is_training)
            qz_samples, log_qz = encoder.query('z_font',
                                               outputs=True,
                                               local_log_prob=True)
            # NOTE(review): results are discarded; presumably these calls
            # exist only to construct the char/joint-model variables so the
            # full variable set is present in this mode -- confirm.
            _, _ = q_net_char(None, tower_char_source, is_training)
            _, _ = VAE(None, tf.shape(tower_char_source)[0], is_training)

            lower_bound = tf.reduce_mean(
                zs.iwae(log_joint, {'x': x_obs},
                        {'z_font': [qz_samples, log_qz]},
                        axis=0))
            # Only the font encoder/decoder variables receive gradients.
            gen_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                                      scope='encoder_font') + \
                                    tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                                      scope='decoder_font')

            grads = optimizer.compute_gradients(-lower_bound,
                                                var_list=gen_var_list)

        elif args.mode == 'char':

            def log_joint(observed):
                # Full joint: log p(z_font) + log p(z_char) + log p(x|z).
                decoder, _, = VAE(observed, n, is_training)
                log_pz_font, log_pz_char, log_px_z = decoder.local_log_prob(
                    ['z_font', 'z_char', 'x'])
                return log_pz_font + log_pz_char + log_px_z

            # Posteriors come from the font/char *source* images.
            encoder_font, _ = q_net_font(None, tower_font_source, is_training)
            qz_samples_font, log_qz_font = encoder_font.query(
                'z_font', outputs=True, local_log_prob=True)
            encoder_char, _ = q_net_char(None, tower_char_source, is_training)
            qz_samples_char, log_qz_char = encoder_char.query(
                'z_char', outputs=True, local_log_prob=True)

            lower_bound = tf.reduce_mean(
                zs.iwae(log_joint, {'x': x_obs}, {
                    'z_font': [qz_samples_font, log_qz_font],
                    'z_char': [qz_samples_char, log_qz_char]
                },
                        axis=0))
            # Char encoder + joint decoder are trained; font encoder frozen.
            gen_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                             scope='encoder_char') + \
                           tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                             scope='decoder_all')

            grads = optimizer.compute_gradients(-lower_bound,
                                                var_list=gen_var_list)

        else:

            def log_joint(observed):
                # Full joint: log p(z_font) + log p(z_char) + log p(x|z).
                decoder, _, = VAE(observed, n, is_training)
                log_pz_font, log_pz_char, log_px_z = decoder.local_log_prob(
                    ['z_font', 'z_char', 'x'])
                return log_pz_font + log_pz_char + log_px_z

            # Posteriors come from the tower's own input x in this mode.
            encoder_font, _ = q_net_font(None, tower_x, is_training)
            qz_samples_font, log_qz_font = encoder_font.query(
                'z_font', outputs=True, local_log_prob=True)
            encoder_char, _ = q_net_char(None, tower_x, is_training)
            qz_samples_char, log_qz_char = encoder_char.query(
                'z_char', outputs=True, local_log_prob=True)

            lower_bound = tf.reduce_mean(
                zs.iwae(log_joint, {'x': x_obs}, {
                    'z_font': [qz_samples_font, log_qz_font],
                    'z_char': [qz_samples_char, log_qz_char]
                },
                        axis=0))
            # Both encoders and the joint decoder are trained end to end.
            gen_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                             scope='encoder_font') + \
                           tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                             scope='encoder_char') + \
                           tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                             scope='decoder_all')

            grads = optimizer.compute_gradients(-lower_bound,
                                                var_list=gen_var_list)

        return grads, [lower_bound]