Example #1
    def __init__(self, output_dist, latent_spec, batch_size, image_shape, network_type):
        """
        :type output_dist: Distribution
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :type network_type: string
        """
        self.output_dist = output_dist
        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product([x for x, reg in latent_spec if not reg])
        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape
        assert all(isinstance(x, (Gaussian, Categorical, Bernoulli)) for x in self.reg_latent_dist.dists)

        self.reg_cont_latent_dist = Product([x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([x for x in self.reg_latent_dist.dists if isinstance(x, (Categorical, Bernoulli))])

        image_size = image_shape[0]
        if network_type == "mnist":
            with tf.variable_scope("d_net"):
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     apply(leaky_rectify).
                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(leaky_rectify))
                self.discriminator_template = shared_template.custom_fully_connected(1)
                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(128).
                     fc_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                self.generator_template = \
                    (pt.template("input").
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     custom_fully_connected(image_size // 4 * image_size // 4 * 128).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, image_size // 4, image_size // 4, 128]).
                     custom_deconv2d([0, image_size // 2, image_size // 2, 64], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).
                     custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                     flatten())
        else:
            raise NotImplementedError
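The three attributes built here (discriminator_template, encoder_template, generator_template) are deferred PrettyTensor templates: no TensorFlow variables exist until construct() binds the "input" key, which is what lets the discriminator and the encoder share the convolutional trunk. A minimal sketch of how they might be bound later inside this class; z_var, x_var and latent_dim are illustrative names, not part of the original code:

    # Hedged sketch: binding the templates defined above. `z_var`, `x_var` and
    # `latent_dim` are assumptions used only for illustration.
    z_var = tf.placeholder(tf.float32, [batch_size, latent_dim])
    x_var = tf.placeholder(tf.float32, [batch_size] + list(image_shape))

    fake_x = self.generator_template.construct(input=z_var)           # flattened generated images
    real_logits = self.discriminator_template.construct(input=x_var)  # shared trunk + 1-unit head
    reg_flat = self.encoder_template.construct(input=fake_x)          # shared trunk + encoder head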
Example #2
    def d_encode_image(self):
        node1_0 = \
            (pt.template("input").
             custom_conv2d(self.df_dim, k_h=4, k_w=4).
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 2, k_h=4, k_w=4).
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 4, k_h=4, k_w=4).
             conv_batch_norm().
             custom_conv2d(self.df_dim * 8, k_h=4, k_w=4).
             conv_batch_norm())
        node1_1 = \
            (node1_0.
             custom_conv2d(self.df_dim * 2, k_h=1, k_w=1, d_h=1, d_w=1).
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 8, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm())

        node1 = \
            (node1_0.
             apply(tf.add, node1_1).
             apply(leaky_rectify, leakiness=0.2))

        return node1
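node1_0 feeds both the 1x1/3x3 side branch node1_1 and the final sum, so this is a residual block: apply() hands the current tensor to the given op as its first argument, which makes apply(tf.add, node1_1) an element-wise skip connection before the last leaky ReLU. Roughly, in plain TensorFlow terms (main and side stand in for the two constructed branch tensors; an illustration, not the library's internals):

    # Illustrative equivalent of the last two chained calls in d_encode_image:
    summed = tf.add(main, side)                # skip connection: main branch + side branch
    out = leaky_rectify(summed, leakiness=0.2)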
Example #3
    def discriminator(self):
        template = \
            (pt.template("input").  # 128*9*4*4
             custom_conv2d(self.df_dim * 8, k_h=1, k_w=1, d_h=1, d_w=1).  # 128*8*4*4
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             # custom_fully_connected(1))
             custom_conv2d(1, k_h=self.s16, k_w=self.s16, d_h=self.s16, d_w=self.s16))

        return template
Example #4
  def testGraphMatchesImmediate(self):
    """Ensures that the vars line up between the two modes."""
    with tf.Graph().as_default():
      input_pt = prettytensor.wrap(self.input)
      self.BuildLargishGraph(input_pt)
      normal_names = sorted([v.name for v in tf.all_variables()])

    with tf.Graph().as_default():
      template = prettytensor.template('input')
      self.BuildLargishGraph(template).construct(
          input=prettytensor.wrap(self.input))
      template_names = sorted([v.name for v in tf.all_variables()])

    self.assertSequenceEqual(normal_names, template_names)
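BuildLargishGraph and self.input are test fixtures that are not shown in this snippet. A hypothetical sketch of such a helper, purely for illustration (it is not the actual PrettyTensor test code), just stacks enough named layers that several variables get created in both modes:

  def BuildLargishGraph(self, input_layer):
    # Hypothetical fixture: a few named layers so variable names can be compared
    # between immediate mode and template mode.
    return (input_layer.
            flatten().
            fully_connected(100, name='fc1').
            fully_connected(50, name='fc2').
            fully_connected(10, name='logits', activation_fn=None))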
Example #5
 def d_encode_image_simple(self):
     template = \
         (pt.template("input").
          custom_conv2d(self.df_dim, k_h=4, k_w=4).
          apply(leaky_rectify, leakiness=0.2).
          custom_conv2d(self.df_dim * 2, k_h=4, k_w=4).
          conv_batch_norm().
          apply(leaky_rectify, leakiness=0.2).
          custom_conv2d(self.df_dim * 4, k_h=4, k_w=4).
          conv_batch_norm().
          apply(leaky_rectify, leakiness=0.2).
          custom_conv2d(self.df_dim * 8, k_h=4, k_w=4).
          conv_batch_norm().
          apply(leaky_rectify, leakiness=0.2))
     return template
Example #6
  def testGraphMatchesImmediate(self):
    """Ensures that the vars line up between the two modes."""
    with tf.Graph().as_default():
      input_pt = prettytensor.wrap(
          tf.constant(self.input_data, dtype=tf.float32))
      self.BuildLargishGraph(input_pt)
      normal_names = sorted([v.name for v in tf.global_variables()])

    with tf.Graph().as_default():
      template = prettytensor.template('input')
      self.BuildLargishGraph(template).construct(input=prettytensor.wrap(
          tf.constant(self.input_data, dtype=tf.float32)))
      template_names = sorted([v.name for v in tf.global_variables()])

    self.assertSequenceEqual(normal_names, template_names)
Example #7
 def shared_net(self):
     shared_template = \
         (pt.template("input").
          reshape([-1] + list(self.image_shape)).
          custom_conv2d(self.df_dim, name='d_h0_conv', k_h=self.k_h, k_w=self.k_w).
          apply(leaky_rectify).
          custom_conv2d(self.df_dim*2, name='d_h1_conv', k_h=self.k_h, k_w=self.k_w).
          conv_batch_norm().
          apply(leaky_rectify).
          custom_conv2d(self.df_dim*4, name='d_h2_conv', k_h=self.k_h, k_w=self.k_w).
          conv_batch_norm().
          apply(leaky_rectify).
          custom_conv2d(self.df_dim*8, name='d_h3_conv', k_h=self.k_h, k_w=self.k_w).
          conv_batch_norm().
          apply(leaky_rectify))
     return shared_template
Example #8
    def d_encode_image_simple(self):
        template = \
            (pt.template("input").
             custom_conv2d(self.df_dim, k_h=4, k_w=4).
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 2, k_h=4, k_w=4).
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 4, k_h=4, k_w=4).
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 8, k_h=4, k_w=4).
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2))

        return template
Example #9
 def infoGAN_mnist_net(self, image_shape):
     image_size = image_shape[0]
     generator_template = \
         (pt.template("input").
          custom_fully_connected(1024).
          fc_batch_norm().
          apply(tf.nn.relu).
          custom_fully_connected(image_size // 4 * image_size // 4 * 128).
          fc_batch_norm().
          apply(tf.nn.relu).
          reshape([-1, image_size // 4, image_size // 4, 128]).
          custom_deconv2d([0, image_size // 2, image_size // 2, 64], k_h=4, k_w=4).
          conv_batch_norm().
          apply(tf.nn.relu).
          custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
          flatten())
     return generator_template
Example #10
    def hr_d_encode_image(self):
        node1_0 = \
            (pt.template("input").  # 4s * 4s * 3
             custom_conv2d(self.df_dim, k_h=4, k_w=4).  # 2s * 2s * df_dim
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 2, k_h=4, k_w=4).  # s * s * df_dim*2
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 4, k_h=4, k_w=4).  # s2 * s2 * df_dim*4
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 8, k_h=4, k_w=4).  # s4 * s4 * df_dim*8
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 16, k_h=4, k_w=4).  # s8 * s8 * df_dim*16
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 32, k_h=4, k_w=4).  # s16 * s16 * df_dim*32
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 16, k_h=1, k_w=1, d_h=1, d_w=1).  # s16 * s16 * df_dim*16
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 8, k_h=1, k_w=1, d_h=1, d_w=1).  # s16 * s16 * df_dim*8
             conv_batch_norm())
        node1_1 = \
            (node1_0.
             custom_conv2d(self.df_dim * 2, k_h=1, k_w=1, d_h=1, d_w=1).
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 8, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm())

        node1 = \
            (node1_0.
             apply(tf.add, node1_1).
             apply(leaky_rectify, leakiness=0.2))

        return node1
Example #11
def conv_ae(scope,
            filter_no,
            img_length=64,
            bottleneck=4,
            channel=3,
            act_fn=tf.nn.relu,
            last_act=tf.tanh):
    with tf.variable_scope(scope):
        with pt.defaults_scope(activation_fn=act_fn):
            layer = pt.template('batch').conv2d(4,
                                                filter_no,
                                                stride=2,
                                                name='conv1')
            img_length >>= 1
            i = 0
            while img_length > bottleneck:
                filter_no <<= 1
                img_length >>= 1
                layer = layer.conv2d(4,
                                     filter_no,
                                     stride=2,
                                     name='conv%d' % (i + 2))
                i += 1

            for j in range(i):
                filter_no >>= 1
                img_length <<= 1
                layer = layer.deconv2d(4,
                                       filter_no,
                                       [-1, img_length, img_length, filter_no],
                                       stride=2,
                                       name='deconv%d' % (j + 1))

            img_length <<= 1
            return layer.deconv2d(4,
                                  channel,
                                  [-1, img_length, img_length, channel],
                                  stride=2,
                                  name='deconv%d' % (i + 1),
                                  activation_fn=last_act)
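With the defaults (img_length=64, bottleneck=4, filter_no doubling on the way down and halving on the way back up), the encoder strides 64→32→16→8→4 and the decoder mirrors that back to 64. A hedged usage sketch; the images tensor and the filter count are assumptions, not part of the original function:

    # Illustrative use of conv_ae: `images` is assumed to be a [batch, 64, 64, 3]
    # float tensor scaled to [-1, 1] to match the tanh output.
    ae = conv_ae('autoencoder', filter_no=32)
    reconstruction = ae.construct(batch=images)   # binds pt.template('batch')
    loss = tf.reduce_mean(tf.squared_difference(reconstruction, images))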
Example #12
    def shared_net(self):
        shared_template = \
            (pt.template("input").
             reshape([-1] + list(self.image_shape)).
             custom_conv2d(self.df_dim, name='d_h0_conv', k_h=self.k_h, k_w=self.k_w).
             apply(leaky_rectify).
             custom_conv2d(self.df_dim*2, name='d_h1_conv', k_h=self.k_h, k_w=self.k_w).
             conv_batch_norm(addNoise=self.addNoise).
             apply(leaky_rectify).
             custom_conv2d(self.df_dim*4, name='d_h2_conv', k_h=self.k_h, k_w=self.k_w).
             conv_batch_norm(addNoise=self.addNoise).
             apply(leaky_rectify).
             custom_conv2d(self.df_dim*8, name='d_h3_conv', k_h=self.k_h, k_w=self.k_w).
             conv_batch_norm(addNoise=self.addNoise).
             apply(leaky_rectify,name="OutDiscriminator"))
        #.custom_fully_connected(512))
        self.intermLayer = (shared_template.as_layer())

        #shared_template=shared_template.apply(leaky_rectify,name="OutDiscriminator")

        return shared_template
Example #13
def conv_gen_bn(scope,
                filter_no,
                z_dim,
                img_size=64,
                bottleneck=4,
                channel=3,
                bn_arg=False,
                act_fn=tf.nn.relu,
                last_act=tf.tanh):
    if not bn_arg:
        bias = tf.zeros_initializer()
    else:
        bias = None
    with tf.variable_scope(scope):
        with pt.defaults_scope(activation_fn=act_fn, batch_normalize=bn_arg):
            layer = pt.template('batch').reshape((-1, 1, 1, z_dim)) \
                .deconv2d(bottleneck, filter_no, [-1, bottleneck, bottleneck, filter_no], stride=1,
                          edges=pt.pretty_tensor_class.PAD_VALID, name='deconv1', bias=bias)

            img_length = bottleneck
            i = 2
            while img_length < img_size / 2:
                filter_no >>= 1
                img_length <<= 1
                layer = layer.deconv2d(4,
                                       filter_no,
                                       [-1, img_length, img_length, filter_no],
                                       stride=2,
                                       name='deconv%d' % i,
                                       bias=bias)
                i += 1

            img_length <<= 1
            return layer.deconv2d(4,
                                  channel,
                                  [-1, img_length, img_length, channel],
                                  stride=2,
                                  activation_fn=last_act,
                                  name='deconv%d' % i,
                                  batch_normalize=False)
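Note the bias handling: when batch_normalize=True is pushed through defaults_scope, bias=None is passed (presumably disabling the bias term), since batch norm's learned shift makes a per-layer bias redundant; without batch norm, a zeros-initialized bias is kept. A hedged usage sketch with assumed sizes (z, the latent width, and the filter count are illustrative):

    # Illustrative use of conv_gen_bn: 100-d noise -> 64x64x3 image in [-1, 1].
    # `batch_size` is assumed to be defined elsewhere.
    z = tf.random_normal([batch_size, 100])
    gen = conv_gen_bn('generator', filter_no=512, z_dim=100, bn_arg=True)
    fake_images = gen.construct(batch=z)   # spatial sizes: 4 -> 8 -> 16 -> 32 -> 64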
Example #14
    def gen_net(self, image_shape):
        sx = image_shape[0]
        sy = image_shape[1]
        sx2, sx4, sx8, sx16 = int(np.ceil(sx*1.0/2)), int(np.ceil(sx*1.0/4)), int(np.ceil(sx*1.0/8)), int(np.ceil(sx*1.0/16))
        sy2, sy4, sy8, sy16 = int(np.ceil(sy*1.0/2)), int(np.ceil(sy*1.0/4)), int(np.ceil(sy*1.0/8)), int(np.ceil(sy*1.0/16))

        if (sx == 96) and (sy == 96):
            self.k_h = self.k_w = 5
        elif (sx < 20) or (sy < 20):
            self.k_h = self.k_w = 2
        else:
            self.k_h = self.k_w = 3

        generator_template = \
            (pt.template("input").
             custom_fully_connected(self.gf_dim*8*sx16*sy16, scope='g_h0_lin').
             fc_batch_norm().
             apply(tf.nn.relu).
             reshape([-1, sx16, sy16, self.gf_dim * 8]).
             custom_deconv2d([self.batch_size, sx8, sy8, self.gf_dim*4],
                             name='g_h1', k_h=self.k_h, k_w=self.k_w,useResize=self.improved).
             conv_batch_norm().
             apply(tf.nn.relu).
             custom_deconv2d([self.batch_size, sx4, sy4, self.gf_dim*2],
                             name='g_h2', k_h=self.k_h, k_w=self.k_w,useResize=self.improved).
             conv_batch_norm().
             apply(tf.nn.relu).
             custom_deconv2d([self.batch_size, sx2, sy2, self.gf_dim*1],
                             name='g_h3', k_h=self.k_h,useResize=self.improved).
             conv_batch_norm().
             apply(tf.nn.relu).
             custom_deconv2d([self.batch_size, sx, sy, self.c_dim],
                             name='g_h4', k_h=self.k_h, k_w=self.k_w,useResize=self.improved).
             apply(tf.nn.tanh,name="OutGenerator"))

        return generator_template
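The kernel size adapts to the resolution (5x5 at 96x96, 2x2 below 20 pixels, otherwise 3x3), and each deconv doubles the ceil-divided spatial sizes back up to the full image. A small worked check of the shape progression for a 96x96 input (illustrative only):

    import numpy as np
    sx = 96
    sx16 = int(np.ceil(sx / 16.0))              # 6
    print([sx16 * 2 ** i for i in range(5)])    # [6, 12, 24, 48, 96]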
Example #15
def discriminator_template():
    num_filters = FLAGS.discrim_filter_base
    with tf.variable_scope('discriminator'):
        tmp = pt.template('input')
        for i in xrange(input.NUM_LEVELS):
            if i > 0:
                tmp = tmp.dropout(FLAGS.keep_prob)
            tmp = tmp.conv2d(5, num_filters)
            if i > 0:
                tmp = tmp.batch_normalize()
            tmp = tmp.apply(discrim_activation_fn).max_pool(2, 2)
            num_filters *= 2
        tmp = tmp.flatten()
        features = tmp

        minibatch_discrim = features.minibatch_discrimination(100)

        for i in xrange(FLAGS.discrim_fc_layers - 1):
            tmp = tmp.fully_connected(
                FLAGS.discrim_fc_size).apply(discrim_activation_fn)
        tmp = tmp.concat(1, [minibatch_discrim]).fully_connected(1)
        output = tmp

    return output
Example #16
    sampled_state = (pt.wrap(
        tf.zeros([FLAGS.batch_size, FLAGS.rnn_size], tf.float32)), )

    sampled_tensors = []
    glimpse_tensors = []
    write_tensors = []
    params_tensors = []

    loss = 0.0
    with tf.variable_scope("model"):
        with pt.defaults_scope(activation_fn=tf.nn.elu,
                               batch_normalize=True,
                               learned_moments_update_rate=0.1,
                               variance_epsilon=0.001,
                               scale_after_normalization=True):
            # Encoder RNN (Eq. 5)
            encoder_template = (pt.template('input').gru_cell(
                num_units=FLAGS.rnn_size, state=pt.UnboundVariable('state')))

            # Projection of encoder RNN output (Eq. 1-2)
            encoder_proj_template = (pt.template('input').fully_connected(
                FLAGS.hidden_size * 2, activation_fn=None))

            # Params of read from decoder RNN output (Eq. 21)
            decoder_read_params_template = (
                pt.template('input').fully_connected(5, activation_fn=None))

            # Decoder RNN (Eq. 7)
            decoder_template = (pt.template('input').gru_cell(
                num_units=FLAGS.rnn_size, state=pt.UnboundVariable('state')))

            # Projection of decoder RNN output (Eq. 18)
            decoder_proj_template = (pt.template('input').fully_connected(
Example #17
    sampled_state = (pt.wrap(tf.zeros([FLAGS.batch_size, FLAGS.rnn_size], tf.float32)),)

    sampled_tensors = []
    glimpse_tensors = []
    write_tensors = []
    params_tensors = []

    loss = 0.0
    with tf.variable_scope("model"):
        with pt.defaults_scope(activation_fn=tf.nn.elu,
                               batch_normalize=True,
                               learned_moments_update_rate=0.1,
                               variance_epsilon=0.001,
                               scale_after_normalization=True):
            # Encoder RNN (Eq. 5)
            encoder_template = (pt.template('input').
                                gru_cell(num_units=FLAGS.rnn_size, state=pt.UnboundVariable('state')))

            # Projection of encoder RNN output (Eq. 1-2)
            encoder_proj_template = (pt.template('input').
                                     fully_connected(FLAGS.hidden_size * 2, activation_fn=None))

            # Params of read from decoder RNN output (Eq. 21)
            decoder_read_params_template = (pt.template('input').
                                            fully_connected(5, activation_fn=None))

            # Decoder RNN (Eq. 7)
            decoder_template = (pt.template('input').
                                gru_cell(num_units=FLAGS.rnn_size, state=pt.UnboundVariable('state')))

            # Projection of decoder RNN output (Eq. 18)
            decoder_proj_template = (pt.template('input').
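Defining the recurrent cells with state=pt.UnboundVariable('state') leaves the state as a second template parameter, so the same GRU weights are reused at every DRAW step by re-constructing the template with the previous state. A hedged sketch of that unrolling (read_t, T and the handling of the returned value are assumptions; the real loop also runs the decoder, the read/write attention and the loss):

            # Illustrative unrolling only; not the original training loop.
            state = tf.zeros([FLAGS.batch_size, FLAGS.rnn_size], tf.float32)
            for t in range(T):                    # T = number of glimpses (assumed)
                encoded = encoder_template.construct(input=read_t, state=state)
                # ... project `encoded`, sample z_t, run decoder_template the same way,
                # and update `state` for the next step ...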
Example #18
    def build_network(self, scope_suffix=''):
        with tf.variable_scope("d_net{}".format(scope_suffix)):
            paddings = [[0, 0], [1, 1], [1, 1], [0, 0]]
            mode = "CONSTANT"
            self.discriminator_template = \
                (pt.template("input").
                 reshape([-1] + list(self.output_shape)).
                 custom_conv2d(64, k_h=4, k_w=4, d_h=2, d_w=2).
                 apply(leaky_rectify, 0.2).
                 custom_conv2d(128, k_h=4, k_w=4, d_h=2, d_w=2).
                 conv_instance_norm().
                 apply(leaky_rectify, 0.2).
                 custom_conv2d(256, k_h=4, k_w=4, d_h=2, d_w=2).
                 conv_instance_norm().
                 apply(leaky_rectify, 0.2).
                 apply(tf.pad, paddings, mode).
                 custom_conv2d(512, k_h=4, k_w=4, d_h=2, d_w=2, padding='VALID').
                 conv_instance_norm().
                 apply(leaky_rectify, 0.2).
                 apply(tf.pad, paddings, mode).
                 custom_conv2d(1, k_h=4, k_w=4, d_h=1, d_w=1, padding='VALID'))

        with tf.variable_scope("g_net{}".format(scope_suffix)):
            # TODO: Add reflection padding
            # with apply(tf.pad([[?,?], [?,?]], 'REFLECT'))
            # and padding='VALID' in custom_conv_2d
            paddings = [[0, 0], [1, 1], [1, 1], [0, 0]]
            mode = "REFLECT"
            self.generator_template = \
                (pt.template("input").
                 reshape([-1] + list(self.input_shape)).
                 apply(tf.pad, [[0, 0], [3, 3], [3, 3], [0, 0]], mode).
                 custom_conv2d(32, k_h=7, k_w=7, d_h=1, d_w=1, padding='VALID').
                 conv_instance_norm().
                 apply(tf.nn.relu).
                 apply(tf.pad, paddings, mode).
                 custom_conv2d(64, k_h=3, k_w=3, d_h=2, d_w=2, padding='VALID').
                 conv_instance_norm().
                 apply(tf.nn.relu).
                 apply(tf.pad, paddings, mode).
                 custom_conv2d(128, k_h=3, k_w=3, d_h=2, d_w=2, padding='VALID').
                 conv_instance_norm().
                 apply(tf.nn.relu).
                 custom_residual(paddings, mode, custom_conv2d, 128, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID').
                 custom_residual(paddings, mode, custom_conv2d, 128, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID').
                 custom_residual(paddings, mode, custom_conv2d, 128, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID').
                 custom_residual(paddings, mode, custom_conv2d, 128, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID').
                 custom_residual(paddings, mode, custom_conv2d, 128, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID').
                 custom_residual(paddings, mode, custom_conv2d, 128, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID').
                 custom_residual(paddings, mode, custom_conv2d, 128, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID').
                 custom_residual(paddings, mode, custom_conv2d, 128, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID').
                 custom_residual(paddings, mode, custom_conv2d, 128, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID').
                 custom_deconv2d([0, self.output_size // 2, self.output_size // 2, 64],
                                 k_h=3, k_w=3, d_h=2, d_w=2,  padding='SAME'). # 'VALID' or 'SAME' ?
                 conv_instance_norm().
                 apply(tf.nn.relu).
                 custom_deconv2d([0, self.output_size, self.output_size, 32],
                                 k_h=3, k_w=3, d_h=2, d_w=2,  padding='SAME'). # 'VALID' or 'SAME' ?
                 conv_instance_norm().
                 apply(tf.nn.relu).
                 apply(tf.pad, [[0, 0], [3, 3], [3, 3], [0, 0]], mode).
                 custom_conv2d(3, k_h=7, k_w=7, d_h=1, d_w=1, padding='VALID').
                 # conv_instance_norm().
                 # apply(tf.nn.relu).
                 apply(tf.nn.tanh).
                 flatten())
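Throughout this network an explicit apply(tf.pad, paddings, mode) followed by padding='VALID' stands in for 'SAME' padding, but with the padding contents under control: reflection padding in the generator avoids the border artifacts of zero padding, while the discriminator's last layers use constant (zero) padding. For a stride-1 3x3 conv the pattern is roughly (an illustrative sketch in plain TensorFlow, with x and kernel assumed):

            # Reflect-pad by 1 on each side, then a VALID conv keeps the spatial size,
            # i.e. a 'SAME'-sized output without zero-filled borders.
            padded = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT")
            y = tf.nn.conv2d(padded, kernel, strides=[1, 1, 1, 1], padding="VALID")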
Example #19
 def context_embedding(self):
     template = (pt.template("input").
                 custom_fully_connected(self.ef_dim).
                 apply(leaky_rectify, leakiness=0.2))
     return template
Example #20
 def Template(self, key):
   return prettytensor.template(key, self.bookkeeper)
Example #21
    def __init__(self, output_dist, latent_spec, batch_size, image_shape,
                 network_type):
        """
        :type output_dist: Distribution
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :type network_type: string
        """
        self.output_dist = output_dist
        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product(
            [x for x, reg in latent_spec if not reg])
        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape
        assert all(
            isinstance(x, (Gaussian, Categorical, Bernoulli))
            for x in self.reg_latent_dist.dists)

        self.reg_cont_latent_dist = Product(
            [x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([
            x for x in self.reg_latent_dist.dists
            if isinstance(x, (Categorical, Bernoulli))
        ])

        image_size = image_shape[0]
        if network_type == "mnist":
            with tf.variable_scope("d_net"):
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     apply(leaky_rectify).
                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(leaky_rectify))
                self.discriminator_template = shared_template.custom_fully_connected(
                    1)

                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(128).
                     fc_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                #SOMEWHAT CONSISTENT. MIGHT CHANGE
                self.generator_template = \
                    (pt.template("input").
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     custom_fully_connected(image_size // 4 * image_size // 4 * 128).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, image_size // 4, image_size // 4, 128]).
                     custom_deconv2d([0, image_size // 2, image_size // 2, 64], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).
                     custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                     flatten())

        #HEART!!!
        elif network_type == 'heart':
            with tf.variable_scope("d_net"):
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     apply(leaky_rectify).
                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(leaky_rectify))
                self.discriminator_template = shared_template.custom_fully_connected(
                    1)
                #THIS ENCODER DOESN'T SEEM CONSISTENT WITH FACES. THAT'S OKAY. WILL
                #TRY ANYWAY.
                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(128).
                     fc_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                self.generator_template = \
                    (pt.template("input").
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     custom_fully_connected(image_size // 4 * image_size // 4 * 128).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, image_size // 4, image_size // 4, 128]).
                     custom_deconv2d([0, image_size // 2, image_size // 2, 64], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).
                     #THIS CONV APPEARS TO BE EXTRA. WILL KEEP ANYWAY
                     custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                     flatten())
        else:
            raise NotImplementedError
Example #22
 def context_embedding(self):
     template = (pt.template("input").custom_fully_connected(
         self.ef_dim).apply(leaky_rectify, leakiness=0.2))
     return template
Example #23
    def __init__(self, output_dist, latent_spec, batch_size, image_shape,
                 network_type):
        """
        :type output_dist: Distribution
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :type network_type: string
        """
        self.output_dist = output_dist
        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product(
            [x for x, reg in latent_spec if not reg])
        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape
        assert all(
            isinstance(x, (Gaussian, Categorical, Bernoulli))
            for x in self.reg_latent_dist.dists)

        self.reg_cont_latent_dist = Product(
            [x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([
            x for x in self.reg_latent_dist.dists
            if isinstance(x, (Categorical, Bernoulli))
        ])

        image_size = image_shape[0]
        if network_type == "mnist":
            with tf.variable_scope("d_net"):
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     apply(leaky_rectify).
                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(leaky_rectify))
                self.discriminator_template = shared_template.custom_fully_connected(
                    1)
                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(128).
                     fc_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                self.generator_template = \
                    (pt.template("input").
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     custom_fully_connected(image_size // 4 * image_size // 4 * 128).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, image_size // 4, image_size // 4, 128]).
                     custom_deconv2d([0, image_size // 2, image_size // 2, 64], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).
                     custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                     flatten())

        elif network_type == "celebA":
            with tf.variable_scope("d_net"):
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     apply(leaky_rectify).
                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).
                     custom_conv2d(256, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify))
                self.discriminator_template = shared_template.custom_fully_connected(
                    1)
                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(128).
                     fc_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                self.generator_template = \
                    (pt.template("input").
                     custom_fully_connected(image_size // 16 * image_size // 16 * 448).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, image_size // 16, image_size // 16, 448]).
                     # I am *pretty sure* each of these dimensions grows by 2x
                     # because the stride==2.
                     custom_deconv2d([0, image_size // 8, image_size // 8, 256], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).
                     custom_deconv2d([0, image_size // 4, image_size // 4, 128], k_h=4, k_w=4).
                     apply(tf.nn.relu).
                     custom_deconv2d([0, image_size // 2, image_size // 2, 64], k_h=4, k_w=4).
                     apply(tf.nn.relu).
                     custom_deconv2d([0, image_size // 1, image_size // 1, 3], k_h=4, k_w=4).
                     apply(tf.nn.tanh).
                     flatten())

        elif network_type == "face":
            with tf.variable_scope("d_net"):
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     apply(leaky_rectify).
                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(leaky_rectify))
                self.discriminator_template = shared_template.custom_fully_connected(
                    1)
                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                self.generator_template = \
                    (pt.template("input").
                     custom_fully_connected(1024).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     custom_fully_connected(image_size // 4 * image_size // 4 * 128).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, image_size // 4, image_size // 4, 128]).
                     custom_deconv2d([0, image_size // 2, image_size // 2, 64], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).
                     custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                     apply(tf.nn.sigmoid).
                     flatten())
        else:
            raise NotImplementedError
Example #24
    def __init__(self, output_dist, latent_spec, batch_size, image_shape,
                 network_type):
        """
        :type output_dist: Distribution
        :type latent_spec: list[(Distribution, bool)]
        :type batch_size: int
        :type network_type: string
        """
        self.output_dist = output_dist
        pstr('output_dist', self.output_dist)
        self.latent_spec = latent_spec
        self.latent_dist = Product([x for x, _ in latent_spec])
        pstr('latent_dist', self.latent_dist)
        pstr('x in latent_spec', [x for x, _ in self.latent_spec])
        pstr('xreg in latent_spec', [xreg for _, xreg in self.latent_spec])
        #for x in enumerate(self.latent_spec):
        #   print '------------------------'
        #   for y in enumerate(x):
        #      pstrall('x----reg',y)

        self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
        self.nonreg_latent_dist = Product(
            [x for x, reg in latent_spec if not reg])
        self.batch_size = batch_size
        self.network_type = network_type
        self.image_shape = image_shape
        assert all(
            isinstance(x, (Gaussian, Categorical, Bernoulli))
            for x in self.reg_latent_dist.dists)
        #for x in self.reg_latent_dist.dists:
        #   pstr('x in reg_latent_dist.dists',x)

        self.reg_cont_latent_dist = Product(
            [x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
        self.reg_disc_latent_dist = Product([
            x for x in self.reg_latent_dist.dists
            if isinstance(x, (Categorical, Bernoulli))
        ])

        pstr('image_shape', image_shape)
        pstr('image_shape[0]', image_shape[0])
        image_size = image_shape[0]

        #self.image_shape = (178, 218, 1)

        if network_type == "mnist":
            with tf.variable_scope("d_net"):
                shared_template = \
                    (pt.template("input").
                     reshape([-1] + list(image_shape)).
                     custom_conv2d(64, k_h=4, k_w=4).
                     #conv_batch_norm().
                     apply(leaky_rectify).

                     custom_conv2d(128, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).

                     custom_conv2d(256, k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(leaky_rectify).

                     #custom_fully_connected(1024).
                     #fc_batch_norm().
                     #apply(leaky_rectify).
                     custom_conv2d(512, k_h=4, k_w=4))
                #conv_batch_norm().
                #apply(leaky_rectify2))

                #linear

                #apply(tf.nn.sigmoid))
                self.discriminator_template = shared_template.custom_fully_connected(
                    1)
                self.encoder_template = \
                    (shared_template.
                     custom_fully_connected(128).
                     fc_batch_norm().
                     apply(leaky_rectify).
                     custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

            with tf.variable_scope("g_net"):
                s = self.image_shape[0]
                s2, s4, s8, s16, s32 = int(s / 2), int(s / 4), int(s / 8), int(
                    s / 16), int(s / 32)
                self.generator_template = \
                    (pt.template("input").

                     custom_fully_connected(s16 * s16 * 512).
                     fc_batch_norm().
                     apply(tf.nn.relu).
                     reshape([-1, s16, s16,  512]).

                     #custom_fully_connected(s32 * s32 * 1024).
                     #fc_batch_norm().
                     #apply(tf.nn.relu).
                     #reshape([-1, s32, s32,  1024]).

                     #custom_deconv2d([0, s16, s16,  512], k_h=4, k_w=4).
                     #conv_batch_norm().
                     #apply(tf.nn.relu).

                     custom_deconv2d([0, s8, s8,  256], k_h=4, k_w=4).
                     conv_batch_norm().
                     apply(tf.nn.relu).

                     custom_deconv2d([0, s4, s4, 128], k_h=4, k_w=4).
                     #conv_batch_norm().
                     apply(tf.nn.relu).

                     custom_deconv2d([0, s2, s2, 64], k_h=4, k_w=4).
                     #conv_batch_norm().
                     apply(tf.nn.relu).
                     custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                     apply(tf.nn.tanh))

        else:
            raise NotImplementedError