import tensorflow as tf  # these excerpts use the TF 1.x graph-mode API

# `ops` (cnn_2d, cnn_2d_trans, dense, lrelu, batch_norm) and `mc` (network
# hyper-parameters such as conv_1, stride_1, dense_1, batch_size) are
# project-local modules whose imports are omitted from these excerpts.


def get_agent(x, reuse=False):
    """
    Generate the CNN agent
    :param x: tensor, Input frames concatenated along axis 3
    :param reuse: bool, True -> Reuse weight variables
                        False -> Create new ones
    :return: Tensor, logits for each valid action
    """
    if reuse:
        tf.get_variable_scope().reuse_variables()

    x = tf.divide(x, 255.0, name='Normalize')
    conv_1 = tf.nn.relu(
        ops.cnn_2d(x,
                   weight_shape=mc.conv_1,
                   strides=mc.stride_1,
                   name='conv_1'))
    conv_2 = tf.nn.relu(
        ops.cnn_2d(conv_1,
                   weight_shape=mc.conv_2,
                   strides=mc.stride_2,
                   name='conv_2'))
    conv_3 = tf.nn.relu(
        ops.cnn_2d(conv_2,
                   weight_shape=mc.conv_3,
                   strides=mc.stride_3,
                   name='conv_3'))
    conv_3_r = tf.reshape(conv_3, [-1, 7 * 7 * 64], name='reshape')
    dense_1 = tf.nn.relu(
        ops.dense(conv_3_r, 7 * 7 * 64, mc.dense_1, name='dense_1'))
    output = ops.dense(dense_1, mc.dense_1, mc.dense_2, name='dense_2')
    return output
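# Hedged usage sketch (not part of the original example): the 7 * 7 * 64
# reshape above implies a DQN-style 84x84x4 frame stack, so the agent could be
# wired up roughly as follows. Placeholder names and shapes are assumptions.
#
#     frames = tf.placeholder(tf.float32, [None, 84, 84, 4], name='frames')
#     with tf.variable_scope('agent'):
#         logits = get_agent(frames)                    # creates the variables
#         logits_again = get_agent(frames, reuse=True)  # shares them
#     greedy_action = tf.argmax(logits, axis=1)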
    def model(self, x, action, reuse=False):
        if reuse:
            tf.get_variable_scope().reuse_variables()

        # TODO: Use a better network for video frame prediction
        # Encoder
        x = tf.divide(x, 255.0)
        conv_1 = tf.nn.relu(
            ops.cnn_2d(x,
                       weight_shape=[6, 6, 4, 64],
                       strides=[1, 2, 2, 1],
                       name='conv_1'))
        conv_2 = tf.nn.relu(
            ops.cnn_2d(conv_1,
                       weight_shape=[6, 6, 64, 64],
                       strides=[1, 2, 2, 1],
                       name='conv_2',
                       padding="SAME"))
        conv_3 = tf.nn.relu(
            ops.cnn_2d(conv_2,
                       weight_shape=[6, 6, 64, 64],
                       strides=[1, 2, 2, 1],
                       name='conv_3',
                       padding="SAME"))
        conv_3_flatten = tf.reshape(conv_3, shape=[-1, 6400], name='reshape_1')
        dense_1 = tf.nn.relu(
            ops.dense(conv_3_flatten, 6400, 1024, name='dense_1'))
        dense_2 = ops.dense(dense_1, 1024, 2048, name='dense_2')
        action_dense_1 = ops.dense(action, 4, 2048, name='action_dense_1')
        dense_2_action = tf.multiply(dense_2,
                                     action_dense_1,
                                     name='dense_2_action')
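
        # The action vector (4 -> 2048 embedding) gates the 2048-dim state
        # code elementwise, so the decoder below is conditioned on the chosen
        # action (action-conditional next-frame prediction).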

        # Decoder
        dense_3 = ops.dense(dense_2_action, 2048, 1024, name='dense_3')
        dense_4 = tf.nn.relu(
            ops.dense(dense_3, 1024, 11 * 11 * 64, name='dense_4'))
        dense_4_reshaped = tf.reshape(dense_4,
                                      shape=[self.batch_size, 11, 11, 64],
                                      name='dense_4_reshaped')
        conv_t_1 = tf.nn.relu(
            ops.cnn_2d_trans(dense_4_reshaped,
                             weight_shape=[6, 6, 64, 64],
                             strides=[1, 2, 2, 1],
                             output_shape=[self.batch_size, 21, 21, 64],
                             name='conv_t_1'))
        conv_t_2 = tf.nn.relu(
            ops.cnn_2d_trans(conv_t_1,
                             weight_shape=[6, 6, 64, 64],
                             strides=[1, 2, 2, 1],
                             output_shape=[self.batch_size, 42, 42, 64],
                             name='conv_t_2'))
        output = ops.cnn_2d_trans(conv_t_2,
                                  weight_shape=[6, 6, 1, 64],
                                  strides=[1, 2, 2, 1],
                                  output_shape=[self.batch_size, 84, 84, 1],
                                  name='output_image')
        return output
def discriminator(x, reuse=False):
    if reuse:
        tf.get_variable_scope().reuse_variables()

    conv_1 = ops.lrelu(
        ops.batch_norm(ops.cnn_2d(x,
                                  weight_shape=[4, 4, 6, 64],
                                  strides=[1, 2, 2, 1],
                                  name='dis_conv_1'),
                       center=True,
                       scale=True,
                       is_training=True,
                       scope='dis_batch_Norm_1'))
    conv_2 = ops.lrelu(
        ops.batch_norm(ops.cnn_2d(conv_1,
                                  weight_shape=[4, 4, 64, 128],
                                  strides=[1, 2, 2, 1],
                                  name='dis_conv_2'),
                       center=True,
                       scale=True,
                       is_training=True,
                       scope='dis_batch_Norm_2'))
    conv_3 = ops.lrelu(
        ops.batch_norm(ops.cnn_2d(conv_2,
                                  weight_shape=[4, 4, 128, 256],
                                  strides=[1, 2, 2, 1],
                                  name='dis_conv_3'),
                       center=True,
                       scale=True,
                       is_training=True,
                       scope='dis_batch_Norm_3'))
    conv_4 = ops.lrelu(
        ops.batch_norm(ops.cnn_2d(conv_3,
                                  weight_shape=[4, 4, 256, 512],
                                  strides=[1, 2, 2, 1],
                                  name='dis_conv_4'),
                       center=True,
                       scale=True,
                       is_training=True,
                       scope='dis_batch_Norm_4'))
    conv_5 = ops.lrelu(
        ops.batch_norm(ops.cnn_2d(conv_4,
                                  weight_shape=[4, 4, 512, 512],
                                  strides=[1, 2, 2, 1],
                                  name='dis_conv_5'),
                       center=True,
                       scale=True,
                       is_training=True,
                       scope='dis_batch_Norm_5'))
    conv_6 = ops.lrelu(
        ops.batch_norm(ops.cnn_2d(conv_5,
                                  weight_shape=[4, 4, 512, 512],
                                  strides=[1, 2, 2, 1],
                                  name='dis_conv_6'),
                       center=True,
                       scale=True,
                       is_training=True,
                       scope='dis_batch_Norm_6'))
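    # NOTE (assumption): conv_6 is still a 4-D feature map here, so `ops.dense`
    # is presumably expected to flatten it; the 5 * 6 input size would then
    # correspond to conv_6's spatial extent for this project's input resolution
    # (it omits the 512 channels, which may be handled inside `ops.dense` or be
    # an inconsistency to watch for).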
    output = ops.dense(conv_6, 5 * 6, 1, name='dis_output')
    return output
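

# Hedged usage sketch (not in the original): the first conv expects a
# 6-channel input, i.e. presumably two RGB images concatenated along the
# channel axis. The scope, function, and argument names below are assumptions.
def _discriminator_pair_logits(real_pair, fake_pair):
    """Score a real and a generated pair with shared discriminator weights."""
    with tf.variable_scope('discriminator'):
        d_real = discriminator(real_pair)              # creates dis_* variables
        d_fake = discriminator(fake_pair, reuse=True)  # reuses them
    return d_real, d_fake
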
def generator(x, reuse=False):
    if reuse:
        tf.get_variable_scope().reuse_variables()

    # Encoder
    conv_1 = ops.lrelu(
        ops.cnn_2d(x,
                   weight_shape=[4, 4, 3, 64],
                   strides=[1, 2, 2, 1],
                   name='g_e_conv_1'))
    conv_2 = ops.lrelu(
        ops.batch_norm(ops.cnn_2d(conv_1,
                                  weight_shape=[4, 4, 64, 128],
                                  strides=[1, 2, 2, 1],
                                  name='g_e_conv_2'),
                       center=True,
                       scale=True,
                       is_training=True,
                       scope='g_e_batch_Norm_2'))
    conv_3 = ops.lrelu(
        ops.batch_norm(ops.cnn_2d(conv_2,
                                  weight_shape=[4, 4, 128, 256],
                                  strides=[1, 2, 2, 1],
                                  name='g_e_conv_3'),
                       center=True,
                       scale=True,
                       is_training=True,
                       scope='g_e_batch_Norm_3'))
    conv_4 = ops.lrelu(
        ops.batch_norm(ops.cnn_2d(conv_3,
                                  weight_shape=[4, 4, 256, 512],
                                  strides=[1, 2, 2, 1],
                                  name='g_e_conv_4'),
                       center=True,
                       scale=True,
                       is_training=True,
                       scope='g_e_batch_Norm_4'))
    conv_5 = ops.lrelu(
        ops.batch_norm(ops.cnn_2d(conv_4,
                                  weight_shape=[4, 4, 512, 512],
                                  strides=[1, 2, 2, 1],
                                  name='g_e_conv_5'),
                       center=True,
                       scale=True,
                       is_training=True,
                       scope='g_e_batch_Norm_5'))
    conv_6 = ops.lrelu(
        ops.batch_norm(ops.cnn_2d(conv_5,
                                  weight_shape=[4, 4, 512, 512],
                                  strides=[1, 2, 2, 1],
                                  name='g_e_conv_6'),
                       center=True,
                       scale=True,
                       is_training=True,
                       scope='g_e_batch_Norm_6'))
    conv_7 = ops.lrelu(
        ops.batch_norm(ops.cnn_2d(conv_6,
                                  weight_shape=[4, 4, 512, 512],
                                  strides=[1, 2, 2, 1],
                                  name='g_e_conv_7'),
                       center=True,
                       scale=True,
                       is_training=True,
                       scope='g_e_batch_Norm_7'))
    conv_8 = ops.lrelu(
        ops.batch_norm(ops.cnn_2d(conv_7,
                                  weight_shape=[4, 4, 512, 512],
                                  strides=[1, 2, 2, 1],
                                  name='g_e_conv_8'),
                       center=True,
                       scale=True,
                       is_training=True,
                       scope='g_e_batch_Norm_8'))

    # Decoder
    dconv_1 = ops.lrelu(
        tf.nn.dropout(ops.batch_norm(ops.cnn_2d_trans(
            conv_8,
            weight_shape=[2, 2, 512, 512],
            strides=[1, 2, 2, 1],
            output_shape=[
                mc.batch_size,
                conv_8.get_shape()[1].value + 1,
                conv_8.get_shape()[2].value + 1, 512
            ],
            name='g_d_dconv_1'),
                                     center=True,
                                     scale=True,
                                     is_training=True,
                                     scope='g_d_batch_Norm_1'),
                      keep_prob=0.5))
    dconv_1 = tf.concat([dconv_1, conv_7], axis=3)
    dconv_2 = ops.lrelu(
        tf.nn.dropout(ops.batch_norm(ops.cnn_2d_trans(
            dconv_1,
            weight_shape=[4, 4, 512, 1024],
            strides=[1, 2, 2, 1],
            output_shape=[
                mc.batch_size,
                dconv_1.get_shape()[1].value * 2 - 1,
                dconv_1.get_shape()[2].value * 2, 512
            ],
            name='g_d_dconv_2'),
                                     center=True,
                                     scale=True,
                                     is_training=True,
                                     scope='g_d_batch_Norm_2'),
                      keep_prob=0.5))
    dconv_2 = tf.concat([dconv_2, conv_6], axis=3)
    dconv_3 = ops.lrelu(
        tf.nn.dropout(ops.batch_norm(ops.cnn_2d_trans(
            dconv_2,
            weight_shape=[4, 4, 512, 1024],
            strides=[1, 2, 2, 1],
            output_shape=[
                mc.batch_size,
                dconv_2.get_shape()[1].value * 2 - 1,
                dconv_2.get_shape()[2].value * 2 - 1, 512
            ],
            name='g_d_dconv_3'),
                                     center=True,
                                     scale=True,
                                     is_training=True,
                                     scope='g_d_batch_Norm_3'),
                      keep_prob=0.5))
    dconv_3 = tf.concat([dconv_3, conv_5], axis=3)
    dconv_4 = ops.lrelu(
        ops.batch_norm(ops.cnn_2d_trans(dconv_3,
                                        weight_shape=[4, 4, 512, 1024],
                                        strides=[1, 2, 2, 1],
                                        output_shape=[
                                            mc.batch_size,
                                            dconv_3.get_shape()[1].value * 2,
                                            dconv_3.get_shape()[2].value * 2,
                                            512
                                        ],
                                        name='g_d_dconv_4'),
                       center=True,
                       scale=True,
                       is_training=True,
                       scope='g_d_batch_Norm_4'))
    dconv_4 = tf.concat([dconv_4, conv_4], axis=3)
    dconv_5 = ops.lrelu(
        ops.batch_norm(ops.cnn_2d_trans(dconv_4,
                                        weight_shape=[4, 4, 256, 1024],
                                        strides=[1, 2, 2, 1],
                                        output_shape=[
                                            mc.batch_size,
                                            dconv_4.get_shape()[1].value * 2,
                                            dconv_4.get_shape()[2].value * 2,
                                            256
                                        ],
                                        name='g_d_dconv_5'),
                       center=True,
                       scale=True,
                       is_training=True,
                       scope='g_d_batch_Norm_5'))
    dconv_5 = tf.concat([dconv_5, conv_3], axis=3)
    dconv_6 = ops.lrelu(
        ops.batch_norm(ops.cnn_2d_trans(dconv_5,
                                        weight_shape=[4, 4, 128, 512],
                                        strides=[1, 2, 2, 1],
                                        output_shape=[
                                            mc.batch_size,
                                            dconv_5.get_shape()[1].value * 2,
                                            dconv_5.get_shape()[2].value * 2,
                                            128
                                        ],
                                        name='g_d_dconv_6'),
                       center=True,
                       scale=True,
                       is_training=True,
                       scope='g_d_batch_Norm_6'))
    dconv_6 = tf.concat([dconv_6, conv_2], axis=3)
    dconv_7 = ops.lrelu(
        ops.batch_norm(ops.cnn_2d_trans(dconv_6,
                                        weight_shape=[4, 4, 64, 256],
                                        strides=[1, 2, 2, 1],
                                        output_shape=[
                                            mc.batch_size,
                                            dconv_6.get_shape()[1].value * 2,
                                            dconv_6.get_shape()[2].value * 2,
                                            64
                                        ],
                                        name='g_d_dconv_7'),
                       center=True,
                       scale=True,
                       is_training=True,
                       scope='g_d_batch_Norm_7'))
    dconv_7 = tf.concat([dconv_7, conv_1], axis=3)
    dconv_8 = tf.nn.tanh(
        ops.cnn_2d_trans(dconv_7,
                         weight_shape=[4, 4, 3, 128],
                         strides=[1, 2, 2, 1],
                         output_shape=[
                             mc.batch_size,
                             dconv_7.get_shape()[1].value * 2,
                             dconv_7.get_shape()[2].value * 2, 3
                         ],
                         name='g_d_dconv_8'))
    return dconv_8
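

# Hedged wiring sketch (not in the original): with a 3-channel input, a
# 3-channel tanh output, and a 6-channel discriminator input, this looks like a
# pix2pix-style setup where the discriminator scores (input, target) pairs
# concatenated on the channel axis. The loss form and all names below are
# assumptions.
def _gan_losses(source, target):
    fake = generator(source)
    d_real = discriminator(tf.concat([source, target], axis=3))
    d_fake = discriminator(tf.concat([source, fake], axis=3), reuse=True)
    d_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_real), logits=d_real)
        + tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(d_fake), logits=d_fake))
    g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_fake), logits=d_fake))
    return fake, d_loss, g_loss
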
Example #5
    def generator(self, x, action, reuse=False):
        if reuse:
            tf.get_variable_scope().reuse_variables()

        # TODO: Use a better network for video frame prediction

        x = tf.divide(x, 255.0)
        # Encoder
        conv_1 = ops.lrelu(
            ops.cnn_2d(x,
                       weight_shape=[4, 4, 4, 64],
                       strides=[1, 2, 2, 1],
                       padding="SAME",
                       name='g_e_conv_1'))
        conv_2 = ops.lrelu(
            ops.batch_norm(ops.cnn_2d(conv_1,
                                      weight_shape=[4, 4, 64, 128],
                                      strides=[1, 2, 2, 1],
                                      padding="SAME",
                                      name='g_e_conv_2'),
                           center=True,
                           scale=True,
                           is_training=True,
                           scope='g_e_batch_Norm_2'))
        conv_3 = ops.lrelu(
            ops.batch_norm(ops.cnn_2d(conv_2,
                                      weight_shape=[4, 4, 128, 256],
                                      strides=[1, 2, 2, 1],
                                      padding="SAME",
                                      name='g_e_conv_3'),
                           center=True,
                           scale=True,
                           is_training=True,
                           scope='g_e_batch_Norm_3'))
        conv_4 = ops.lrelu(
            ops.batch_norm(ops.cnn_2d(conv_3,
                                      weight_shape=[4, 4, 256, 512],
                                      strides=[1, 2, 2, 1],
                                      padding="SAME",
                                      name='g_e_conv_4'),
                           center=True,
                           scale=True,
                           is_training=True,
                           scope='g_e_batch_Norm_4'))
        conv_5 = ops.lrelu(
            ops.batch_norm(ops.cnn_2d(conv_4,
                                      weight_shape=[4, 4, 512, 512],
                                      strides=[1, 2, 2, 1],
                                      padding="SAME",
                                      name='g_e_conv_5'),
                           center=True,
                           scale=True,
                           is_training=True,
                           scope='g_e_batch_Norm_5'))
        conv_6 = ops.lrelu(
            ops.batch_norm(ops.cnn_2d(conv_5,
                                      weight_shape=[4, 4, 512, 512],
                                      strides=[1, 2, 2, 1],
                                      padding="SAME",
                                      name='g_e_conv_6'),
                           center=True,
                           scale=True,
                           is_training=True,
                           scope='g_e_batch_Norm_6'))
        conv_6_reshaped = tf.reshape(conv_6, [-1, 2 * 2 * 512],
                                     name='g_conv_6_reshape')

        action_dense_1 = ops.dense(action, 4, 2048, name='g_action_dense_1')
        action_dense_2 = tf.multiply(conv_6_reshaped,
                                     action_dense_1,
                                     name='g_action_dense_2')

        action_dense_2_reshaped = tf.reshape(action_dense_2, [-1, 2, 2, 512])
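
        # As in the `model` network above, the action embedding (4 -> 2048)
        # gates the flattened 2x2x512 bottleneck elementwise before it is
        # reshaped back to spatial form for the transposed-conv decoder.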

        # Decoder
        dconv_1 = ops.lrelu(
            ops.batch_norm(ops.cnn_2d_trans(
                action_dense_2_reshaped,
                weight_shape=[2, 2, 512, 512],
                strides=[1, 2, 2, 1],
                output_shape=[
                    self.batch_size,
                    action_dense_2_reshaped.get_shape()[1].value * 2 - 1,
                    action_dense_2_reshaped.get_shape()[2].value * 2 - 1, 512
                ],
                name='g_d_dconv_1'),
                           center=True,
                           scale=True,
                           is_training=True,
                           scope='g_d_batch_Norm_1'))
        dconv_1 = tf.concat([dconv_1, conv_5], axis=3)
        dconv_2 = ops.lrelu(
            ops.batch_norm(ops.cnn_2d_trans(
                dconv_1,
                weight_shape=[4, 4, 512, 1024],
                strides=[1, 2, 2, 1],
                output_shape=[
                    self.batch_size,
                    dconv_1.get_shape()[1].value * 2,
                    dconv_1.get_shape()[2].value * 2, 512
                ],
                name='g_d_dconv_2'),
                           center=True,
                           scale=True,
                           is_training=True,
                           scope='g_d_batch_Norm_2'))
        dconv_2 = tf.concat([dconv_2, conv_4], axis=3)
        dconv_3 = ops.lrelu(
            ops.batch_norm(ops.cnn_2d_trans(
                dconv_2,
                weight_shape=[4, 4, 256, 1024],
                strides=[1, 2, 2, 1],
                output_shape=[
                    self.batch_size,
                    dconv_2.get_shape()[1].value * 2 - 1,
                    dconv_2.get_shape()[2].value * 2 - 1, 256
                ],
                name='g_d_dconv_3'),
                           center=True,
                           scale=True,
                           is_training=True,
                           scope='g_d_batch_Norm_3'))
        dconv_3 = tf.concat([dconv_3, conv_3], axis=3)
        dconv_4 = ops.lrelu(
            ops.batch_norm(ops.cnn_2d_trans(
                dconv_3,
                weight_shape=[4, 4, 128, 512],
                strides=[1, 2, 2, 1],
                output_shape=[
                    self.batch_size,
                    dconv_3.get_shape()[1].value * 2 - 1,
                    dconv_3.get_shape()[2].value * 2 - 1, 128
                ],
                name='g_d_dconv_4'),
                           center=True,
                           scale=True,
                           is_training=True,
                           scope='g_d_batch_Norm_4'))
        dconv_4 = tf.concat([dconv_4, conv_2], axis=3)
        dconv_5 = ops.lrelu(
            ops.batch_norm(ops.cnn_2d_trans(
                dconv_4,
                weight_shape=[4, 4, 64, 256],
                strides=[1, 2, 2, 1],
                output_shape=[
                    self.batch_size,
                    dconv_4.get_shape()[1].value * 2,
                    dconv_4.get_shape()[2].value * 2, 64
                ],
                name='g_d_dconv_5'),
                           center=True,
                           scale=True,
                           is_training=True,
                           scope='g_d_batch_Norm_5'))
        dconv_5 = tf.concat([dconv_5, conv_1], axis=3)
        output = tf.nn.tanh(
            ops.cnn_2d_trans(dconv_5,
                             weight_shape=[4, 4, 1, 128],
                             strides=[1, 2, 2, 1],
                             output_shape=[
                                 self.batch_size,
                                 dconv_5.get_shape()[1].value * 2,
                                 dconv_5.get_shape()[2].value * 2, 1
                             ],
                             name='g_output'))
        return output
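
    # Hedged usage note (not part of the original): the first conv expects a
    # 4-channel frame stack and the decoder's doubling pattern brings the 2x2
    # bottleneck back to roughly 84x84, so within its class this method is
    # presumably used as something like
    #     next_frame = self.generator(state, action)
    #     next_frame_r = self.generator(state, action, reuse=True)
    # with `state` of shape [batch_size, 84, 84, 4], `action` a 4-dim action
    # vector, and a tanh output of shape [batch_size, 84, 84, 1].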