Example #1
def conv_layer(layer_name, input, in_dim, in_ch, out_dim, out_size, summary_conv=False):
    with tf.name_scope(layer_name):
        # Initialize weights and bias
        W_conv = weight_variable_with_decay([in_dim, in_dim, in_ch, out_dim], 0.004)
        b_conv = bias_variable([out_dim])

        # Log weights and bias
        tf.summary.histogram("weights", W_conv)
        tf.summary.histogram("biases", b_conv)

        # Draw weights in 8x8 grid for the first conv layer
        if summary_conv:
            kernel_grid = put_kernels_on_grid(W_conv, (8, 8))
            tf.summary.image("kernel", kernel_grid, max_outputs=1)

        # Draw conv activation in 8x8 grid
        activation = tf.nn.bias_add(conv2d(input, W_conv), b_conv)
        # Only draw the activation for the first image in a batch
        activation_sample = tf.slice(activation, [0, 0, 0, 0], [1, out_size, out_size, out_dim])
        activation_grid = put_kernels_on_grid(tf.transpose(activation_sample, [1, 2, 0, 3]), (8, 8))
        tf.summary.image("conv/activatins", activation_grid, max_outputs=1)

        # Draw relu activation in 8x8 grid
        activation = tf.nn.relu(activation)
        # Only draw the activation for the first image in a batch
        activation_sample = tf.slice(activation, [0, 0, 0, 0], [1, out_size, out_size, out_dim])
        activation_grid = put_kernels_on_grid(tf.transpose(activation_sample, [1, 2, 0, 3]), (8, 8))
        tf.summary.image("relu/activatins", activation_grid, max_outputs=1)

        # 2x2 max pooling
        pool = max_pool_2x2(activation)

        return tf.nn.lrn(pool, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm')
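Every example on this page assumes a put_kernels_on_grid helper that tiles a 4-D tensor into a single image for tf.summary.image. Its actual implementation is not shown here (and the single-argument call in Example #4 suggests the real helper can also infer the grid shape on its own); the following is only a minimal sketch, assuming an input of shape [height, width, channels, rows*cols] and an explicit grid:
import tensorflow as tf


def put_kernels_on_grid(kernel, grid=(8, 8), pad=1):
    # Minimal sketch (assumed behaviour, not the original implementation):
    # tile a [H, W, C, rows*cols] tensor into one [1, rows*H', cols*W', C]
    # image so it can be passed to tf.summary.image.
    rows, cols = grid

    # Scale values to [0, 1] so the summary renders as a visible image.
    x_min = tf.reduce_min(kernel)
    x_max = tf.reduce_max(kernel)
    kernel = (kernel - x_min) / (x_max - x_min + 1e-8)

    # One-pixel border between tiles.
    kernel = tf.pad(kernel, [[pad, pad], [pad, pad], [0, 0], [0, 0]])
    h, w, c, n = [int(d) for d in kernel.get_shape()]
    assert n == rows * cols, 'grid must match the number of kernels'

    x = tf.transpose(kernel, [3, 0, 1, 2])          # [n, h, w, c]
    x = tf.reshape(x, [rows, cols, h, w, c])
    x = tf.transpose(x, [0, 2, 1, 3, 4])            # [rows, h, cols, w, c]
    x = tf.reshape(x, [1, rows * h, cols * w, c])   # single image in a batch
    return x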
Example #2
def conv_layer(layer_name,
               input,
               in_dim,
               in_ch,
               out_dim,
               out_size,
               summary_conv=False):
    with tf.name_scope(layer_name):
        # Initialize weights and bias
        W_conv = weight_variable_with_decay([in_dim, in_dim, in_ch, out_dim],
                                            0.004)
        b_conv = bias_variable([out_dim])

        # Log weights and bias
        tf.summary.histogram("weights", W_conv)
        tf.summary.histogram("biases", b_conv)

        # Draw weights in 8x8 grid for the first conv layer
        if summary_conv:
            kernel_grid = put_kernels_on_grid(W_conv, (8, 8))
            tf.summary.image("kernel", kernel_grid, max_outputs=1)

        # Draw conv activation in 8x8 grid
        activation = tf.nn.bias_add(conv2d(input, W_conv), b_conv)
        # Only draw the activation for the first image in a batch
        activation_sample = tf.slice(activation, [0, 0, 0, 0],
                                     [1, out_size, out_size, out_dim])
        activation_grid = put_kernels_on_grid(
            tf.transpose(activation_sample, [1, 2, 0, 3]), (8, 8))
        tf.summary.image("conv/activatins", activation_grid, max_outputs=1)

        # Draw relu activation in 8x8 grid
        activation = tf.nn.relu(activation)
        # Only draw the activation for the first image in a batch
        activation_sample = tf.slice(activation, [0, 0, 0, 0],
                                     [1, out_size, out_size, out_dim])
        activation_grid = put_kernels_on_grid(
            tf.transpose(activation_sample, [1, 2, 0, 3]), (8, 8))
        tf.summary.image("relu/activatins", activation_grid, max_outputs=1)

        # 2x2 max pooling
        pool = max_pool_2x2(activation)

        return tf.nn.lrn(pool,
                         4,
                         bias=1.0,
                         alpha=0.001 / 9.0,
                         beta=0.75,
                         name='norm')
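Both conv_layer listings above also depend on a few helpers that are not part of the snippet. A minimal sketch, assuming they follow the standard TF1 tutorial definitions (the real project may differ):
import tensorflow as tf


def weight_variable_with_decay(shape, wd):
    # Assumed: truncated-normal weights plus an L2 penalty collected into
    # the 'losses' collection for the total loss.
    var = tf.Variable(tf.truncated_normal(shape, stddev=5e-2), name='weights')
    if wd is not None:
        tf.add_to_collection(
            'losses', tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss'))
    return var


def bias_variable(shape):
    return tf.Variable(tf.constant(0.1, shape=shape), name='biases')


def conv2d(x, W):
    # 'SAME' padding keeps the spatial size, matching conv_layer's out_size.
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')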
Example #3
def main(_):
    cifar10 = cifar.Cifar()
    cifar10.ReadDataSets(one_hot=True, raw=True)

    # Create the model
    x = tf.placeholder(tf.float32, [None, 3072])

    x_image = tf.transpose(tf.reshape(x, [-1, 3, 32, 32]), [0, 2, 3, 1])
    image_flat = tf.reshape(x_image, [BATCH_SIZE, 32 * 32 * 3])

    image_grid = put_kernels_on_grid(
        tf.transpose(tf.reshape(x, [-1, 3, 32, 32]), [2, 3, 1, 0]), (8, 8))
    tf.summary.image("images", image_grid, max_outputs=1)

    z_mean, z_stddev = recognition(x_image)

    samples = tf.random_normal([BATCH_SIZE, LATENT_VAR_NUM],
                               0,
                               1,
                               dtype=tf.float32)
    guessed_z = tf.add(tf.multiply(samples, z_stddev), z_mean)

    generated_images = generation(guessed_z)

    generated_flat = tf.reshape(generated_images, [BATCH_SIZE, 32 * 32 * 3])
    tf.summary.histogram('generated_flat', generated_flat)

    generation_loss = -tf.reduce_sum(
        image_flat * tf.log(1e-8 + generated_flat) +
        (1 - image_flat) * tf.log(1e-8 + 1 - generated_flat), 1)
    tf.summary.histogram('generation_loss', generation_loss)

    latent_loss = 0.5 * tf.reduce_sum(
        tf.square(z_mean) + tf.square(z_stddev) - tf.log(tf.square(z_stddev)) -
        1, 1)

    cost = tf.reduce_mean(generation_loss + latent_loss)
    tf.summary.scalar('loss', cost)

    global_step = tf.Variable(0, trainable=False)
    lr = learning_rate(global_step)

    train_step = tf.train.AdamOptimizer(lr).minimize(cost,
                                                     global_step=global_step)

    sess = tf.InteractiveSession()

    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter('train', sess.graph)

    sess.run(tf.global_variables_initializer())

    for i in range(EPOCH):
        batch = cifar10.train.next_batch(BATCH_SIZE)
        if i % 100 == 0:
            print("step %d" % i)
        summary, _ = sess.run([merged, train_step], feed_dict={x: batch[0]})
        train_writer.add_summary(summary, i)

    print("Done")
Example #4
    def __init__(self, env, observations, timer, params):
        self.regularizer = tf.contrib.layers.l2_regularizer(scale=1e-10)

        img_model_name = params.image_model_name
        fc_layers = params.fc_layers
        fc_size = params.fc_size
        lowdim_model_name = params.lowdim_model_name
        past_frames = params.stack_past_frames

        image_obs = has_image_observations(env.observation_space.spaces['obs'])
        num_actions = env.action_space.n

        if image_obs:
            # convolutions
            if img_model_name == 'convnet_simple':
                conv_filters = self._convnet_simple(observations, [(32, 3, 2)] * 4)
            else:
                raise Exception('Unknown model name')

            encoded_input = tf.contrib.layers.flatten(conv_filters)
        else:
            # low-dimensional input
            if lowdim_model_name == 'simple_fc':
                frames = tf.split(observations, past_frames, axis=1)
                fc_encoder = tf.make_template('fc_encoder', self._fc_frame_encoder, create_scope_now_=True)
                encoded_frames = [fc_encoder(frame) for frame in frames]
                encoded_input = tf.concat(encoded_frames, axis=1)
            else:
                raise Exception('Unknown lowdim model name')

        if params.ignore_timer:
            timer = tf.multiply(timer, 0.0)

        encoded_input_with_timer = tf.concat([encoded_input, tf.expand_dims(timer, 1)], axis=1)

        fc = encoded_input_with_timer
        for _ in range(fc_layers - 1):
            fc = dense(fc, fc_size, self.regularizer)

        # fully-connected layers to generate actions
        actions_fc = dense(fc, fc_size // 2, self.regularizer)
        self.actions = tf.contrib.layers.fully_connected(actions_fc, num_actions, activation_fn=None)
        self.best_action_deterministic = tf.argmax(self.actions, axis=1)
        self.actions_prob_distribution = CategoricalProbabilityDistribution(self.actions)
        self.act = self.actions_prob_distribution.sample()

        value_fc = dense(fc, fc_size // 2, self.regularizer)
        self.value = tf.squeeze(tf.contrib.layers.fully_connected(value_fc, 1, activation_fn=None), axis=[1])

        if image_obs:
            # summaries
            with tf.variable_scope('conv1', reuse=True):
                weights = tf.get_variable('weights')
            with tf.name_scope('a2c_agent_summary_conv'):
                if weights.shape[2].value in [1, 3, 4]:
                    tf.summary.image('conv1/kernels', put_kernels_on_grid(weights), max_outputs=1)

        log.info('Total parameters in the model: %d', count_total_parameters())
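dense(), the frame/image encoders and count_total_parameters() are defined elsewhere in the agent's codebase. A minimal sketch of the dense() helper only, assuming a ReLU layer with L2-regularized weights:
import tensorflow as tf


def dense(x, layer_size, regularizer):
    # Hypothetical fully connected helper matching the calls above
    # (hidden layers, actions_fc and value_fc).
    return tf.contrib.layers.fully_connected(
        x,
        layer_size,
        activation_fn=tf.nn.relu,
        weights_regularizer=regularizer,
    )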
Example #5
def main(_):
    cifar10 = cifar.Cifar()
    cifar10.ReadDataSets(one_hot=True)

    keep_prob = tf.placeholder(tf.float32)

    # Create the model
    x = tf.placeholder(tf.float32, [None, 3, 32, 32])

    # Define loss and optimizer
    y_ = tf.placeholder(tf.float32, [None, 10])

    x_image = tf.transpose(x, [0, 2, 3, 1])

    tf.summary.image("images", x_image, max_outputs=1)

    h_pool1 = conv_layer("conv_layer1", x_image, 5, 3, 64, 32, summary_conv=True)
    h_pool2 = conv_layer("conv_layer2", h_pool1, 5, 64, 64, 16)

    h_conv3_flat = tf.reshape(h_pool2, [-1, 8 * 8 * 64])

    h_fc1 = fc_layer('fc_layer1', h_conv3_flat, 8 * 8 * 64, 384, activation=True)
    h_fc2 = fc_layer('fc_layer2', h_fc1, 384, 192, activation=True)
    y_conv = fc_layer('fc_layer3', h_fc2, 192, 10, activation=False)

    global_step = tf.Variable(0, trainable=False)
    lr = learning_rate(global_step)

    total_loss = loss(y_conv, y_)
    optimizer = tf.train.AdamOptimizer(lr)
    grads_and_vars = optimizer.compute_gradients(total_loss)
    with tf.name_scope("conv_layer1_grad"):
        kernel_grad_grid = put_kernels_on_grid(grads_and_vars[0][0], (8, 8))
        tf.summary.image("weight_grad", kernel_grad_grid, max_outputs=1)

    train_step = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
    correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    sess = tf.InteractiveSession()

    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter('train', sess.graph)

    sess.run(tf.global_variables_initializer())

    for i in range(EPOCH):
        batch = cifar10.train.next_batch(BATCH_SIZE)
        if i % 100 == 0:
            test_accuracy = accuracy.eval(feed_dict={x: cifar10.test.images, y_: cifar10.test.labels})
            print("step %d, test accuracy %g" % (i, test_accuracy))
        summary, _ = sess.run([merged, train_step], feed_dict={x: batch[0], y_: batch[1]})
        train_writer.add_summary(summary, i)

    print("test accuracy %g" % accuracy.eval(feed_dict={
        x: cifar10.test.images, y_: cifar10.test.labels}))
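fc_layer(), loss() and learning_rate() are imported from the surrounding project. A minimal sketch, assuming a plain fully connected layer, softmax cross-entropy combined with the collected weight-decay terms, and an exponentially decaying schedule:
import tensorflow as tf


def fc_layer(layer_name, input, in_dim, out_dim, activation=False):
    # Hypothetical fully connected helper matching the calls above.
    with tf.name_scope(layer_name):
        W = tf.Variable(tf.truncated_normal([in_dim, out_dim], stddev=0.04),
                        name='weights')
        b = tf.Variable(tf.constant(0.1, shape=[out_dim]), name='biases')
        pre_activation = tf.matmul(input, W) + b
        return tf.nn.relu(pre_activation) if activation else pre_activation


def loss(logits, labels):
    # Softmax cross-entropy over one-hot labels plus any weight-decay terms
    # collected by weight_variable_with_decay.
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    tf.add_to_collection('losses', cross_entropy)
    return tf.add_n(tf.get_collection('losses'), name='total_loss')


def learning_rate(global_step):
    # Hypothetical exponentially decaying schedule.
    return tf.train.exponential_decay(1e-3, global_step,
                                      decay_steps=1000, decay_rate=0.96,
                                      staircase=True)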
Example #6
def generation(z):
    with tf.variable_scope("generation"):
        z_develop = fc_layer('z_matrix', z, LATENT_VAR_NUM, 8 * 8 * 32)
        z_matrix = tf.nn.relu(tf.reshape(z_develop, [BATCH_SIZE, 8, 8, 32]))
        tf.summary.histogram('z_matrix', z_matrix)
        h1 = tf.nn.relu(
            conv_transpose("g_h1", z_matrix, 32, 16, [BATCH_SIZE, 16, 16, 16]))
        tf.summary.histogram('h1', h1)
        h2 = tf.nn.sigmoid(
            conv_transpose("g_h2", h1, 16, 3, [BATCH_SIZE, 32, 32, 3]))
        tf.summary.histogram('h2', h2)

        kernel_grad_grid = put_kernels_on_grid(tf.transpose(h2, [1, 2, 3, 0]),
                                               (8, 8))
        tf.summary.image("gen_images", kernel_grad_grid, max_outputs=1)

        return h2
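conv_transpose() is another project helper that is not shown. A minimal sketch that matches the calls in generation(), assuming a 5x5 kernel with stride 2 so the spatial size doubles at each step:
import tensorflow as tf


def conv_transpose(name, x, in_ch, out_ch, output_shape):
    # Hypothetical transposed-convolution helper: filter shape is
    # [height, width, output_channels, in_channels] for conv2d_transpose.
    with tf.variable_scope(name):
        W = tf.get_variable(
            'weights', [5, 5, out_ch, in_ch],
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        b = tf.get_variable('biases', [out_ch],
                            initializer=tf.constant_initializer(0.0))
        deconv = tf.nn.conv2d_transpose(x, W, output_shape,
                                        strides=[1, 2, 2, 1], padding='SAME')
        return tf.nn.bias_add(deconv, b)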
Example #7
def main(_):
    cifar10 = cifar.Cifar()
    cifar10.ReadDataSets(one_hot=True)

    keep_prob = tf.placeholder(tf.float32)

    # Create the model
    x = tf.placeholder(tf.float32, [None, 3, 32, 32])

    # Define loss and optimizer
    y_ = tf.placeholder(tf.float32, [None, 10])

    x_image = tf.transpose(x, [0, 2, 3, 1])

    tf.summary.image("images", x_image, max_outputs=1)

    h_pool1 = conv_layer("conv_layer1",
                         x_image,
                         5,
                         3,
                         64,
                         32,
                         summary_conv=True)
    h_pool2 = conv_layer("conv_layer2", h_pool1, 5, 64, 64, 16)

    h_conv3_flat = tf.reshape(h_pool2, [-1, 8 * 8 * 64])

    h_fc1 = fc_layer('fc_layer1',
                     h_conv3_flat,
                     8 * 8 * 64,
                     384,
                     activation=True)
    h_fc2 = fc_layer('fc_layer2', h_fc1, 384, 192, activation=True)
    y_conv = fc_layer('fc_layer3', h_fc2, 192, 10, activation=False)

    global_step = tf.Variable(0, trainable=False)
    lr = learning_rate(global_step)

    total_loss = loss(y_conv, y_)
    optimizer = tf.train.AdamOptimizer(lr)
    grads_and_vars = optimizer.compute_gradients(total_loss)
    with tf.name_scope("conv_layer1_grad"):
        kernel_grad_grid = put_kernels_on_grid(grads_and_vars[0][0], (8, 8))
        tf.summary.image("weight_grad", kernel_grad_grid, max_outputs=1)

    train_step = optimizer.apply_gradients(grads_and_vars,
                                           global_step=global_step)
    correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    sess = tf.InteractiveSession()

    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter('train', sess.graph)

    sess.run(tf.global_variables_initializer())

    for i in range(EPOCH):
        batch = cifar10.train.next_batch(BATCH_SIZE)
        if i % 100 == 0:
            test_accuracy = accuracy.eval(feed_dict={
                x: cifar10.test.images,
                y_: cifar10.test.labels
            })
            print("step %d, test accuracy %g" % (i, test_accuracy))
        summary, _ = sess.run([merged, train_step],
                              feed_dict={
                                  x: batch[0],
                                  y_: batch[1]
                              })
        train_writer.add_summary(summary, i)

    print("test accuracy %g" % accuracy.eval(feed_dict={
        x: cifar10.test.images,
        y_: cifar10.test.labels
    }))
Example #8
def get_model(sess, image_shape=(32, 32, 3), gf_dim=64, df_dim=64, batch_size=64,
              name="autoencoder"):
    K.set_session(sess)
    with tf.variable_scope(name):
        # sizes
        ch = image_shape[2]
        rows = [4, 8, 16, 32]
        cols = [4, 8, 16, 32]

        # nets
        G = generator(batch_size, gf_dim, ch, rows, cols)
        G.compile("sgd", "mse")
        g_vars = G.trainable_weights
        print "G.shape: ", G.output_shape

        E = encoder(batch_size, df_dim, ch, rows, cols)
        E.compile("sgd", "mse")
        e_vars = E.trainable_weights
        print "E.shape: ", E.output_shape

        D = discriminator(batch_size, df_dim, ch, rows, cols)
        D.compile("sgd", "mse")
        d_vars = D.trainable_weights
        print "D.shape: ", D.output_shape

        Z2 = Input(batch_shape=(batch_size, z_dim), name='more_noise')
        Z = G.input
        Img = D.input
        image_grid = put_kernels_on_grid(tf.transpose(Img, [1, 2, 3, 0]), (8, 8))
        sum_img = tf.summary.image("Img", image_grid, max_outputs=1)
        G_train = G(Z)
        E_mean, E_logsigma = E(Img)
        G_dec = G(E_mean + Z2 * E_logsigma)
        D_fake, F_fake = D(G_train)
        D_dec_fake, F_dec_fake = D(G_dec)
        D_legit, F_legit = D(Img)

        # costs
        recon_vs_gan = 1e-6
        like_loss = tf.reduce_mean(tf.square(F_legit - F_dec_fake)) / 2.
        kl_loss = tf.reduce_mean(-E_logsigma + .5 * (-1 + tf.exp(2. * E_logsigma) + tf.square(E_mean)))

        d_loss_legit = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_legit, labels=tf.ones_like(D_legit)))
        d_loss_fake1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake, labels=tf.zeros_like(D_fake)))
        d_loss_fake2 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_dec_fake, labels=tf.zeros_like(D_dec_fake)))
        d_loss_fake = d_loss_fake1 + d_loss_fake2
        d_loss = d_loss_legit + d_loss_fake

        g_loss1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake, labels=tf.ones_like(D_fake)))
        g_loss2 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_dec_fake, labels=tf.ones_like(D_dec_fake)))
        g_loss = g_loss1 + g_loss2 + recon_vs_gan * like_loss
        e_loss = kl_loss + like_loss

        # optimizers
        print "Generator variables:"
        for v in g_vars:
            print v.name
        print "Discriminator variables:"
        for v in d_vars:
            print v.name
        print "Encoder variables:"
        for v in e_vars:
            print v.name

        e_optim = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(e_loss, var_list=e_vars)
        d_optim = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars)
        g_optim = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars)
        sess.run(tf.global_variables_initializer())

    # summaries
    sum_d_loss_legit = tf.summary.scalar("d_loss_legit", d_loss_legit)
    sum_d_loss_fake = tf.summary.scalar("d_loss_fake", d_loss_fake)
    sum_d_loss = tf.summary.scalar("d_loss", d_loss)
    sum_g_loss = tf.summary.scalar("g_loss", g_loss)
    sum_e_loss = tf.summary.scalar("e_loss", e_loss)
    sum_e_mean = tf.summary.histogram("e_mean", E_mean)
    sum_e_sigma = tf.summary.histogram("e_sigma", tf.exp(E_logsigma))
    sum_Z = tf.summary.histogram("Z", Z)
    image_grid = put_kernels_on_grid(tf.transpose(G_train, [1, 2, 3, 0]), (8, 8))
    sum_gen = tf.summary.image("G", image_grid, max_outputs=1)
    image_grid = put_kernels_on_grid(tf.transpose(G_dec, [1, 2, 3, 0]), (8, 8))
    sum_dec = tf.summary.image("E", image_grid, max_outputs=1)

    g_sum = tf.summary.merge([sum_Z, sum_gen, sum_d_loss_fake, sum_g_loss, sum_img])
    e_sum = tf.summary.merge([sum_dec, sum_e_loss, sum_e_mean, sum_e_sigma])
    d_sum = tf.summary.merge([sum_d_loss_legit, sum_d_loss])
    writer = tf.summary.FileWriter("train", sess.graph)

    # functions
    def train_d(images, z, counter, sess=sess):
        z2 = np.random.normal(0., 1., z.shape)
        outputs = [d_loss, d_loss_fake, d_loss_legit, d_sum, d_optim]
        images = np.transpose(np.reshape(images, (-1, 3, 32, 32)), (0, 2, 3, 1))
        with tf.control_dependencies(outputs):
            updates = [tf.assign(p, new_p) for (p, new_p) in D.updates]
        outs = sess.run(outputs + updates, feed_dict={Img: images, Z: z, Z2: z2, K.learning_phase(): 1})
        dl, dlf, dll, sums = outs[:4]
        writer.add_summary(sums, counter)
        return dl, dlf, dll

    def train_g(images, z, counter, sess=sess):
        # generator
        z2 = np.random.normal(0., 1., z.shape)
        outputs = [g_loss, G_train, g_sum, g_optim]
        images = np.transpose(np.reshape(images, (-1, 3, 32, 32)), (0, 2, 3, 1))
        with tf.control_dependencies(outputs):
            updates = [tf.assign(p, new_p) for (p, new_p) in G.updates]
        outs = sess.run(outputs + updates, feed_dict={Img: images, Z: z, Z2: z2, K.learning_phase(): 1})
        gl, samples, sums = outs[:3]
        writer.add_summary(sums, counter)
        # encoder
        outputs = [e_loss, G_dec, e_sum, e_optim]
        with tf.control_dependencies(outputs):
            updates = [tf.assign(p, new_p) for (p, new_p) in E.updates]
        outs = sess.run(outputs + updates, feed_dict={Img: images, Z: z, Z2: z2, K.learning_phase(): 1})
        gl, samples, sums = outs[:3]
        writer.add_summary(sums, counter)

        return gl, samples, images

    def sampler(z, x):
        code = E.predict(x, batch_size=batch_size)[0]
        out = G.predict(code, batch_size=batch_size)
        return out, x

    return train_g, train_d, sampler, [G, D, E]
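get_model() only builds the graph and returns closures; the training script, z_dim, learning_rate and beta1 live elsewhere. A hypothetical driver loop, assuming a next_cifar_batch() loader that yields flat [batch, 3072] CIFAR-10 arrays:
import numpy as np
import tensorflow as tf

sess = tf.Session()
train_g, train_d, sampler, (G, D, E) = get_model(sess)

for step in range(10000):
    images = next_cifar_batch(64)                 # assumed loader, [64, 3072]
    z = np.random.normal(0., 1., (64, z_dim))     # z_dim assumed module const
    d_loss, d_loss_fake, d_loss_legit = train_d(images, z, step)
    g_loss, samples, reals = train_g(images, z, step)
    if step % 100 == 0:
        print("step %d d_loss %.3f g_loss %.3f" % (step, d_loss, g_loss))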
Example #9
def get_model(sess,
              image_shape=(32, 32, 3),
              gf_dim=64,
              df_dim=64,
              batch_size=64,
              name="autoencoder"):
    K.set_session(sess)
    with tf.variable_scope(name):
        # sizes
        ch = image_shape[2]
        rows = [4, 8, 16, 32]
        cols = [4, 8, 16, 32]

        # nets
        G = generator(batch_size, gf_dim, ch, rows, cols)
        G.compile("sgd", "mse")
        g_vars = G.trainable_weights
        print "G.shape: ", G.output_shape

        E = encoder(batch_size, df_dim, ch, rows, cols)
        E.compile("sgd", "mse")
        e_vars = E.trainable_weights
        print "E.shape: ", E.output_shape

        D = discriminator(batch_size, df_dim, ch, rows, cols)
        D.compile("sgd", "mse")
        d_vars = D.trainable_weights
        print "D.shape: ", D.output_shape

        Z2 = Input(batch_shape=(batch_size, z_dim), name='more_noise')
        Z = G.input
        Img = D.input
        image_grid = put_kernels_on_grid(tf.transpose(Img, [1, 2, 3, 0]),
                                         (8, 8))
        sum_img = tf.summary.image("Img", image_grid, max_outputs=1)
        G_train = G(Z)
        E_mean, E_logsigma = E(Img)
        G_dec = G(E_mean + Z2 * E_logsigma)
        D_fake, F_fake = D(G_train)
        D_dec_fake, F_dec_fake = D(G_dec)
        D_legit, F_legit = D(Img)

        # costs
        recon_vs_gan = 1e-6
        like_loss = tf.reduce_mean(tf.square(F_legit - F_dec_fake)) / 2.
        kl_loss = tf.reduce_mean(
            -E_logsigma + .5 *
            (-1 + tf.exp(2. * E_logsigma) + tf.square(E_mean)))

        d_loss_legit = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=D_legit, labels=tf.ones_like(D_legit)))
        d_loss_fake1 = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=D_fake, labels=tf.zeros_like(D_fake)))
        d_loss_fake2 = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=D_dec_fake, labels=tf.zeros_like(D_dec_fake)))
        d_loss_fake = d_loss_fake1 + d_loss_fake2
        d_loss = d_loss_legit + d_loss_fake

        g_loss1 = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=D_fake, labels=tf.ones_like(D_fake)))
        g_loss2 = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=D_dec_fake, labels=tf.ones_like(D_dec_fake)))
        g_loss = g_loss1 + g_loss2 + recon_vs_gan * like_loss
        e_loss = kl_loss + like_loss

        # optimizers
        print "Generator variables:"
        for v in g_vars:
            print v.name
        print "Discriminator variables:"
        for v in d_vars:
            print v.name
        print "Encoder variables:"
        for v in e_vars:
            print v.name

        e_optim = tf.train.AdamOptimizer(learning_rate,
                                         beta1=beta1).minimize(e_loss,
                                                               var_list=e_vars)
        d_optim = tf.train.AdamOptimizer(learning_rate,
                                         beta1=beta1).minimize(d_loss,
                                                               var_list=d_vars)
        g_optim = tf.train.AdamOptimizer(learning_rate,
                                         beta1=beta1).minimize(g_loss,
                                                               var_list=g_vars)
        sess.run(tf.global_variables_initializer())

    # summaries
    sum_d_loss_legit = tf.summary.scalar("d_loss_legit", d_loss_legit)
    sum_d_loss_fake = tf.summary.scalar("d_loss_fake", d_loss_fake)
    sum_d_loss = tf.summary.scalar("d_loss", d_loss)
    sum_g_loss = tf.summary.scalar("g_loss", g_loss)
    sum_e_loss = tf.summary.scalar("e_loss", e_loss)
    sum_e_mean = tf.summary.histogram("e_mean", E_mean)
    sum_e_sigma = tf.summary.histogram("e_sigma", tf.exp(E_logsigma))
    sum_Z = tf.summary.histogram("Z", Z)
    image_grid = put_kernels_on_grid(tf.transpose(G_train, [1, 2, 3, 0]),
                                     (8, 8))
    sum_gen = tf.summary.image("G", image_grid, max_outputs=1)
    image_grid = put_kernels_on_grid(tf.transpose(G_dec, [1, 2, 3, 0]), (8, 8))
    sum_dec = tf.summary.image("E", image_grid, max_outputs=1)

    g_sum = tf.summary.merge(
        [sum_Z, sum_gen, sum_d_loss_fake, sum_g_loss, sum_img])
    e_sum = tf.summary.merge([sum_dec, sum_e_loss, sum_e_mean, sum_e_sigma])
    d_sum = tf.summary.merge([sum_d_loss_legit, sum_d_loss])
    writer = tf.summary.FileWriter("train", sess.graph)

    # functions
    def train_d(images, z, counter, sess=sess):
        z2 = np.random.normal(0., 1., z.shape)
        outputs = [d_loss, d_loss_fake, d_loss_legit, d_sum, d_optim]
        images = np.transpose(np.reshape(images, (-1, 3, 32, 32)),
                              (0, 2, 3, 1))
        with tf.control_dependencies(outputs):
            updates = [tf.assign(p, new_p) for (p, new_p) in D.updates]
        outs = sess.run(outputs + updates,
                        feed_dict={
                            Img: images,
                            Z: z,
                            Z2: z2,
                            K.learning_phase(): 1
                        })
        dl, dlf, dll, sums = outs[:4]
        writer.add_summary(sums, counter)
        return dl, dlf, dll

    def train_g(images, z, counter, sess=sess):
        # generator
        z2 = np.random.normal(0., 1., z.shape)
        outputs = [g_loss, G_train, g_sum, g_optim]
        images = np.transpose(np.reshape(images, (-1, 3, 32, 32)),
                              (0, 2, 3, 1))
        with tf.control_dependencies(outputs):
            updates = [tf.assign(p, new_p) for (p, new_p) in G.updates]
        outs = sess.run(outputs + updates,
                        feed_dict={
                            Img: images,
                            Z: z,
                            Z2: z2,
                            K.learning_phase(): 1
                        })
        gl, samples, sums = outs[:3]
        writer.add_summary(sums, counter)
        # encoder
        outputs = [e_loss, G_dec, e_sum, e_optim]
        with tf.control_dependencies(outputs):
            updates = [tf.assign(p, new_p) for (p, new_p) in E.updates]
        outs = sess.run(outputs + updates,
                        feed_dict={
                            Img: images,
                            Z: z,
                            Z2: z2,
                            K.learning_phase(): 1
                        })
        gl, samples, sums = outs[:3]
        writer.add_summary(sums, counter)

        return gl, samples, images

    def sampler(z, x):
        code = E.predict(x, batch_size=batch_size)[0]
        out = G.predict(code, batch_size=batch_size)
        return out, x

    return train_g, train_d, sampler, [G, D, E]