def __init__(self, output_dist, latent_spec, batch_size, image_shape,
             network_type):
    """Build the InfoGAN discriminator / encoder / generator templates.

    :type output_dist: Distribution
    :type latent_spec: list[(Distribution, bool)]
    :type batch_size: int
    :type image_shape: sequence; image_shape[0] is used as the image side
    :type network_type: string
    :raises NotImplementedError: for any network_type other than "mnist".
    """
    self.output_dist = output_dist
    self.latent_spec = latent_spec
    # Joint latent distribution, then its split by the per-entry flag into
    # regularised (True) and non-regularised (False) parts.
    self.latent_dist = Product([x for x, _ in latent_spec])
    self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
    self.nonreg_latent_dist = Product(
        [x for x, reg in latent_spec if not reg])
    self.batch_size = batch_size
    self.network_type = network_type
    self.image_shape = image_shape
    assert all(isinstance(x, (Gaussian, Categorical, Bernoulli))
               for x in self.reg_latent_dist.dists)
    # Regularised latents grouped into continuous and discrete parts.
    self.reg_cont_latent_dist = Product(
        [x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
    self.reg_disc_latent_dist = Product(
        [x for x in self.reg_latent_dist.dists
         if isinstance(x, (Categorical, Bernoulli))])

    image_size = image_shape[0]
    if network_type == "mnist":
        with tf.variable_scope("d_net"):
            shared_template = \
                (pt.template("input").
                 reshape([-1] + list(image_shape)).
                 custom_conv2d(64, k_h=4, k_w=4).
                 apply(leaky_rectify).
                 custom_conv2d(128, k_h=4, k_w=4).
                 conv_batch_norm().
                 apply(leaky_rectify).
                 custom_fully_connected(1024).
                 fc_batch_norm().
                 apply(leaky_rectify))
            # The discriminator and the encoder share the trunk above.
            self.discriminator_template = \
                shared_template.custom_fully_connected(1)
            # Encoder output is sized to the regularised latent parameters.
            self.encoder_template = \
                (shared_template.
                 custom_fully_connected(128).
                 fc_batch_norm().
                 apply(leaky_rectify).
                 custom_fully_connected(self.reg_latent_dist.dist_flat_dim))
        # Feature-map sides after two / one stride-2 up-sampling stages.
        # Named (instead of the inline `image_size // 4 * image_size // 4`,
        # which floors at the wrong point for sizes not divisible by 4) so
        # the FC width always matches the reshape that follows it.
        s2 = image_size // 2
        s4 = image_size // 4
        with tf.variable_scope("g_net"):
            self.generator_template = \
                (pt.template("input").
                 custom_fully_connected(1024).
                 fc_batch_norm().
                 apply(tf.nn.relu).
                 custom_fully_connected(s4 * s4 * 128).
                 fc_batch_norm().
                 apply(tf.nn.relu).
                 reshape([-1, s4, s4, 128]).
                 custom_deconv2d([0, s2, s2, 64], k_h=4, k_w=4).
                 conv_batch_norm().
                 apply(tf.nn.relu).
                 custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                 flatten())
    else:
        raise NotImplementedError(network_type)
def d_encode_image(self):
    """Return the residual image-encoding template used by the discriminator.

    A trunk of four 4x4 convolutions (df_dim, 2*, 4*, 8*df_dim filters) is
    followed by a 1x1 -> 3x3 -> 3x3 bottleneck branch. The branch output is
    added back onto the trunk and passed through a leaky ReLU (slope 0.2).
    """
    slope = 0.2

    trunk = pt.template("input")
    trunk = trunk.custom_conv2d(self.df_dim, k_h=4, k_w=4)
    trunk = trunk.apply(leaky_rectify, leakiness=slope)
    trunk = trunk.custom_conv2d(self.df_dim * 2, k_h=4, k_w=4)
    trunk = trunk.conv_batch_norm()
    trunk = trunk.apply(leaky_rectify, leakiness=slope)
    trunk = trunk.custom_conv2d(self.df_dim * 4, k_h=4, k_w=4)
    trunk = trunk.conv_batch_norm()
    # NOTE(review): unlike d_encode_image_simple and hr_d_encode_image there
    # is no activation between this stage and the next; preserved as-is
    # because adding one changes the model. Confirm whether intentional.
    trunk = trunk.custom_conv2d(self.df_dim * 8, k_h=4, k_w=4)
    trunk = trunk.conv_batch_norm()

    branch = trunk.custom_conv2d(self.df_dim * 2, k_h=1, k_w=1, d_h=1, d_w=1)
    branch = branch.conv_batch_norm()
    branch = branch.apply(leaky_rectify, leakiness=slope)
    branch = branch.custom_conv2d(self.df_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1)
    branch = branch.conv_batch_norm()
    branch = branch.apply(leaky_rectify, leakiness=slope)
    branch = branch.custom_conv2d(self.df_dim * 8, k_h=3, k_w=3, d_h=1, d_w=1)
    branch = branch.conv_batch_norm()

    summed = trunk.apply(tf.add, branch)
    return summed.apply(leaky_rectify, leakiness=slope)
def discriminator(self):
    """Return the discriminator's output head as a prettytensor template.

    A 1x1 convolution to df_dim * 8 channels, batch norm and leaky ReLU
    (slope 0.2) are followed by a single-filter convolution whose kernel
    size and stride both equal self.s16.
    """
    # Original shape notes: 128*9*4*4 in -> 128*8*4*4 after the 1x1 conv.
    head = pt.template("input")
    head = head.custom_conv2d(self.df_dim * 8, k_h=1, k_w=1, d_h=1, d_w=1)
    head = head.conv_batch_norm()
    head = head.apply(leaky_rectify, leakiness=0.2)
    # Kernel == stride == s16, so the windows do not overlap
    # (a custom_fully_connected(1) head was used here previously).
    window = self.s16
    return head.custom_conv2d(1, k_h=window, k_w=window,
                              d_h=window, d_w=window)
def testGraphMatchesImmediate(self):
    """Ensures that the vars line up between the two modes."""
    # Immediate mode: build directly on the wrapped input tensor.
    with tf.Graph().as_default():
        input_pt = prettytensor.wrap(self.input)
        self.BuildLargishGraph(input_pt)
        # tf.global_variables replaces the deprecated (and later removed)
        # tf.all_variables; the sibling test already uses it.
        normal_names = sorted([v.name for v in tf.global_variables()])

    # Template mode: build a deferred template, then construct it.
    with tf.Graph().as_default():
        template = prettytensor.template('input')
        self.BuildLargishGraph(template).construct(
            input=prettytensor.wrap(self.input))
        template_names = sorted([v.name for v in tf.global_variables()])

    self.assertSequenceEqual(normal_names, template_names)
def d_encode_image_simple(self):
    """Return a plain (non-residual) image-encoding template.

    Four 4x4 convolutions with df_dim, 2*, 4* and 8*df_dim filters, each
    followed by a leaky ReLU (slope 0.2). Every conv except the first is
    batch-normalised before its activation.
    """
    net = pt.template("input")
    for multiplier in (1, 2, 4, 8):
        net = net.custom_conv2d(self.df_dim * multiplier, k_h=4, k_w=4)
        if multiplier > 1:
            net = net.conv_batch_norm()
        net = net.apply(leaky_rectify, leakiness=0.2)
    return net
def testGraphMatchesImmediate(self):
    """Check that immediate and template construction create the same vars."""

    def fresh_input():
        # Created inside whichever graph is currently the default.
        return prettytensor.wrap(
            tf.constant(self.input_data, dtype=tf.float32))

    def variable_names():
        return sorted([v.name for v in tf.global_variables()])

    with tf.Graph().as_default():
        self.BuildLargishGraph(fresh_input())
        normal_names = variable_names()

    with tf.Graph().as_default():
        template = prettytensor.template('input')
        self.BuildLargishGraph(template).construct(input=fresh_input())
        template_names = variable_names()

    self.assertSequenceEqual(normal_names, template_names)
def shared_net(self):
    """Return the convolutional trunk shared by the discriminator heads.

    Reshapes the input to self.image_shape, then applies four convolutions
    named d_h0_conv .. d_h3_conv with df_dim * 2**i filters and kernel
    (self.k_h, self.k_w). Each is followed by leaky ReLU; layers 1-3 are
    batch-normalised first.
    """
    net = pt.template("input").reshape([-1] + list(self.image_shape))
    for idx in range(4):
        net = net.custom_conv2d(self.df_dim * 2 ** idx,
                                name='d_h%d_conv' % idx,
                                k_h=self.k_h, k_w=self.k_w)
        if idx > 0:
            net = net.conv_batch_norm()
        net = net.apply(leaky_rectify)
    return net
def infoGAN_mnist_net(self, image_shape):
    """Return the InfoGAN MNIST-style generator template.

    FC(1024) -> FC(s4*s4*128) -> reshape to (s4, s4, 128) -> deconv to
    (s2, s2, 64) -> deconv to image_shape -> flatten, where s2/s4 are
    image_shape[0] // 2 and // 4.

    :param image_shape: sequence; image_shape[0] is the image side length.
    :returns: prettytensor template with unbound input "input".
    """
    image_size = image_shape[0]
    # `//` keeps these ints on Python 3 (`/` gave floats, which reshape and
    # deconv output shapes reject). Naming s4 also makes the FC width exactly
    # s4 * s4 * 128, matching the reshape even if image_size % 4 != 0.
    s2 = image_size // 2
    s4 = image_size // 4
    generator_template = \
        (pt.template("input").
         custom_fully_connected(1024).
         fc_batch_norm().
         apply(tf.nn.relu).
         custom_fully_connected(s4 * s4 * 128).
         fc_batch_norm().
         apply(tf.nn.relu).
         reshape([-1, s4, s4, 128]).
         custom_deconv2d([0, s2, s2, 64], k_h=4, k_w=4).
         conv_batch_norm().
         apply(tf.nn.relu).
         custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
         flatten())
    return generator_template
def hr_d_encode_image(self):
    """Return the residual image encoder for the high-resolution discriminator.

    Trunk: six 4x4 convolutions (df_dim, 2*, 4*, 8*, 16*, 32*df_dim filters;
    the author's shape notes say each halves the spatial size). Then two 1x1
    convolutions bring channels back to 16* and 8*df_dim. A 1x1 -> 3x3 -> 3x3
    bottleneck branch is added back onto the trunk, followed by a leaky
    ReLU (slope 0.2).
    """
    slope = 0.2

    # First down-sampling conv has no batch norm.
    trunk = (pt.template("input").
             custom_conv2d(self.df_dim, k_h=4, k_w=4).
             apply(leaky_rectify, leakiness=slope))
    for mult in (2, 4, 8, 16, 32):
        trunk = (trunk.
                 custom_conv2d(self.df_dim * mult, k_h=4, k_w=4).
                 conv_batch_norm().
                 apply(leaky_rectify, leakiness=slope))
    # 1x1 channel reduction: 32*df_dim -> 16*df_dim -> 8*df_dim.
    trunk = (trunk.
             custom_conv2d(self.df_dim * 16, k_h=1, k_w=1, d_h=1, d_w=1).
             conv_batch_norm().
             apply(leaky_rectify, leakiness=slope).
             custom_conv2d(self.df_dim * 8, k_h=1, k_w=1, d_h=1, d_w=1).
             conv_batch_norm())

    residual = (trunk.
                custom_conv2d(self.df_dim * 2, k_h=1, k_w=1, d_h=1, d_w=1).
                conv_batch_norm().
                apply(leaky_rectify, leakiness=slope).
                custom_conv2d(self.df_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
                conv_batch_norm().
                apply(leaky_rectify, leakiness=slope).
                custom_conv2d(self.df_dim * 8, k_h=3, k_w=3, d_h=1, d_w=1).
                conv_batch_norm())

    return (trunk.
            apply(tf.add, residual).
            apply(leaky_rectify, leakiness=slope))
def conv_ae(scope, filter_no, img_length=64, bottleneck=4, channel=3,
            act_fn=tf.nn.relu, last_act=tf.tanh):
    """Build a symmetric convolutional auto-encoder template.

    The encoder applies stride-2 4x4 convolutions, halving ``img_length``
    and doubling the filter count after the first, until the side length is
    ``<= bottleneck``. The decoder mirrors it with stride-2 deconvolutions
    and ends in a ``channel``-channel deconvolution activated by
    ``last_act``.

    :param scope: variable scope name wrapping all layers.
    :param filter_no: filters in the first encoder conv.
    :param img_length: input side length.
    :param bottleneck: side length at which down-sampling stops.
    :param channel: output channels of the final layer.
    :param act_fn: default activation for every layer except the last.
    :param last_act: activation of the final deconvolution.
    :returns: prettytensor template with unbound input 'batch'.
    """
    with tf.variable_scope(scope):
        with pt.defaults_scope(activation_fn=act_fn):
            layer = pt.template('batch').conv2d(4, filter_no, stride=2,
                                                name='conv1')
            img_length >>= 1
            n_down = 0  # encoder convs after conv1
            while img_length > bottleneck:
                filter_no <<= 1
                img_length >>= 1
                layer = layer.conv2d(4, filter_no, stride=2,
                                     name='conv%d' % (n_down + 2))
                n_down += 1
            # Mirror the encoder: one deconv per extra encoder conv.
            for j in range(n_down):
                filter_no >>= 1
                img_length <<= 1
                layer = layer.deconv2d(
                    4, filter_no, [-1, img_length, img_length, filter_no],
                    stride=2, name='deconv%d' % (j + 1))
            img_length <<= 1
            # The previously trailing `return cae_tpl` was unreachable (this
            # return always runs first) and referenced an undefined name.
            return layer.deconv2d(
                4, channel, [-1, img_length, img_length, channel],
                stride=2, name='deconv%d' % (n_down + 1),
                activation_fn=last_act)
def shared_net(self):
    """Return the shared discriminator trunk with optionally noisy batch norm.

    Reshapes to self.image_shape, then applies convs d_h0_conv .. d_h3_conv
    (df_dim * 2**i filters, kernel (self.k_h, self.k_w)). Layers 1-3 use
    conv_batch_norm(addNoise=self.addNoise), and each layer ends in a leaky
    ReLU. The last activation is named "OutDiscriminator".

    Side effect: stores shared_template.as_layer() in self.intermLayer.
    """
    net = pt.template("input").reshape([-1] + list(self.image_shape))
    net = net.custom_conv2d(self.df_dim, name='d_h0_conv',
                            k_h=self.k_h, k_w=self.k_w)
    net = net.apply(leaky_rectify)
    for idx in (1, 2, 3):
        net = net.custom_conv2d(self.df_dim * 2 ** idx,
                                name='d_h%d_conv' % idx,
                                k_h=self.k_h, k_w=self.k_w)
        net = net.conv_batch_norm(addNoise=self.addNoise)
        if idx == 3:
            net = net.apply(leaky_rectify, name="OutDiscriminator")
        else:
            net = net.apply(leaky_rectify)
    self.intermLayer = net.as_layer()
    return net
def conv_gen_bn(scope, filter_no, z_dim, img_size=64, bottleneck=4, channel=3,
                bn_arg=False, act_fn=tf.nn.relu, last_act=tf.tanh):
    """Build a deconvolutional generator template with optional batch norm.

    The latent 'batch' input is reshaped to (-1, 1, 1, z_dim). A VALID
    deconv with kernel ``bottleneck`` and stride 1 produces a
    bottleneck x bottleneck map with ``filter_no`` channels. Stride-2 4x4
    deconvs then halve the filter count and double the side while the side
    is below img_size / 2. A final stride-2 deconv to ``channel`` channels
    uses ``last_act`` and never batch-normalises.

    :param bn_arg: default ``batch_normalize`` for every layer but the last.
        When false, deconvs get a zeros-initialised bias; when true,
        bias=None is passed.
    """
    bias = None if bn_arg else tf.zeros_initializer()
    with tf.variable_scope(scope):
        with pt.defaults_scope(activation_fn=act_fn, batch_normalize=bn_arg):
            net = pt.template('batch').reshape((-1, 1, 1, z_dim))
            net = net.deconv2d(bottleneck, filter_no,
                               [-1, bottleneck, bottleneck, filter_no],
                               stride=1,
                               edges=pt.pretty_tensor_class.PAD_VALID,
                               name='deconv1', bias=bias)
            side = bottleneck
            layer_idx = 2
            while side < img_size / 2:
                filter_no >>= 1
                side <<= 1
                net = net.deconv2d(4, filter_no, [-1, side, side, filter_no],
                                   stride=2, name='deconv%d' % layer_idx,
                                   bias=bias)
                layer_idx += 1
            side <<= 1
            return net.deconv2d(4, channel, [-1, side, side, channel],
                                stride=2, activation_fn=last_act,
                                name='deconv%d' % layer_idx,
                                batch_normalize=False)
def gen_net(self, image_shape):
    """Return the generator template for images of shape ``image_shape``.

    The input is projected to a (ceil(sx/16), ceil(sy/16), 8*gf_dim) map,
    then up-sampled by four custom_deconv2d stages (g_h1..g_h4) to
    (sx, sy, c_dim), ending in a tanh named "OutGenerator".

    Side effect: sets self.k_h and self.k_w to the kernel size chosen from
    the image size: 5 for 96x96, 2 if either side is < 20, otherwise 3.

    :param image_shape: sequence; [0] and [1] are the image height/width.
    """
    sx = image_shape[0]
    sy = image_shape[1]

    def _ceil_div(size, factor):
        # Round up so odd sizes still reach (sx, sy) after up-sampling.
        return int(np.ceil(size * 1.0 / factor))

    sx2, sx4, sx8, sx16 = [_ceil_div(sx, f) for f in (2, 4, 8, 16)]
    sy2, sy4, sy8, sy16 = [_ceil_div(sy, f) for f in (2, 4, 8, 16)]

    if sx == 96 and sy == 96:
        self.k_h = self.k_w = 5
    elif sx < 20 or sy < 20:
        self.k_h = self.k_w = 2
    else:
        self.k_h = self.k_w = 3

    generator_template = \
        (pt.template("input").
         custom_fully_connected(self.gf_dim * 8 * sx16 * sy16,
                                scope='g_h0_lin').
         fc_batch_norm().
         apply(tf.nn.relu).
         reshape([-1, sx16, sy16, self.gf_dim * 8]).
         custom_deconv2d([self.batch_size, sx8, sy8, self.gf_dim * 4],
                         name='g_h1', k_h=self.k_h, k_w=self.k_w,
                         useResize=self.improved).
         conv_batch_norm().
         apply(tf.nn.relu).
         custom_deconv2d([self.batch_size, sx4, sy4, self.gf_dim * 2],
                         name='g_h2', k_h=self.k_h, k_w=self.k_w,
                         useResize=self.improved).
         conv_batch_norm().
         apply(tf.nn.relu).
         # k_w was previously omitted on this stage only, so it silently used
         # custom_deconv2d's default width. NOTE: changes g_h3's kernel shape,
         # so old checkpoints for this layer will not load.
         custom_deconv2d([self.batch_size, sx2, sy2, self.gf_dim * 1],
                         name='g_h3', k_h=self.k_h, k_w=self.k_w,
                         useResize=self.improved).
         conv_batch_norm().
         apply(tf.nn.relu).
         custom_deconv2d([self.batch_size, sx, sy, self.c_dim],
                         name='g_h4', k_h=self.k_h, k_w=self.k_w,
                         useResize=self.improved).
         apply(tf.nn.tanh, name="OutGenerator"))
    return generator_template
def discriminator_template():
    """Build the discriminator as a prettytensor template.

    For each of ``input.NUM_LEVELS`` levels: dropout (from level 1 on),
    a 5x5 conv, batch norm (from level 1 on), ``discrim_activation_fn`` and a
    2x2 max-pool; the filter count doubles per level. The flattened features
    feed both a 100-unit minibatch-discrimination branch and
    ``FLAGS.discrim_fc_layers - 1`` fully connected layers. The two are
    concatenated on axis 1 before a single-unit output layer.
    """
    num_filters = FLAGS.discrim_filter_base
    with tf.variable_scope('discriminator'):
        net = pt.template('input')
        # `input` must be a module-level object exposing NUM_LEVELS (it
        # shadows the builtin). `range` replaces `xrange` so this also runs
        # on Python 3.
        for level in range(input.NUM_LEVELS):
            if level > 0:
                net = net.dropout(FLAGS.keep_prob)
            net = net.conv2d(5, num_filters)
            if level > 0:
                net = net.batch_normalize()
            net = net.apply(discrim_activation_fn).max_pool(2, 2)
            num_filters *= 2
        net = net.flatten()
        minibatch_discrim = net.minibatch_discrimination(100)
        for _ in range(FLAGS.discrim_fc_layers - 1):
            net = net.fully_connected(
                FLAGS.discrim_fc_size).apply(discrim_activation_fn)
        output = net.concat(1, [minibatch_discrim]).fully_connected(1)
    return output
tf.zeros([FLAGS.batch_size, FLAGS.rnn_size], tf.float32)), ) sampled_tensors = [] glimpse_tensors = [] write_tensors = [] params_tensors = [] loss = 0.0 with tf.variable_scope("model"): with pt.defaults_scope(activation_fn=tf.nn.elu, batch_normalize=True, learned_moments_update_rate=0.1, variance_epsilon=0.001, scale_after_normalization=True): # Encoder RNN (Eq. 5) encoder_template = (pt.template('input').gru_cell( num_units=FLAGS.rnn_size, state=pt.UnboundVariable('state'))) # Projection of encoder RNN output (Eq. 1-2) encoder_proj_template = (pt.template('input').fully_connected( FLAGS.hidden_size * 2, activation_fn=None)) # Params of read from decoder RNN output (Eq. 21) decoder_read_params_template = ( pt.template('input').fully_connected(5, activation_fn=None)) # Decoder RNN (Eq. 7) decoder_template = (pt.template('input').gru_cell( num_units=FLAGS.rnn_size, state=pt.UnboundVariable('state'))) # Projection of decoder RNN output (Eq. 18) decoder_proj_template = (pt.template('input').fully_connected(
sampled_state = (pt.wrap(tf.zeros([FLAGS.batch_size, FLAGS.rnn_size], tf.float32)),) sampled_tensors = [] glimpse_tensors = [] write_tensors = [] params_tensors = [] loss = 0.0 with tf.variable_scope("model"): with pt.defaults_scope(activation_fn=tf.nn.elu, batch_normalize=True, learned_moments_update_rate=0.1, variance_epsilon=0.001, scale_after_normalization=True): # Encoder RNN (Eq. 5) encoder_template = (pt.template('input'). gru_cell(num_units=FLAGS.rnn_size, state=pt.UnboundVariable('state'))) # Projection of encoder RNN output (Eq. 1-2) encoder_proj_template = (pt.template('input'). fully_connected(FLAGS.hidden_size * 2, activation_fn=None)) # Params of read from decoder RNN output (Eq. 21) decoder_read_params_template = (pt.template('input'). fully_connected(5, activation_fn=None)) # Decoder RNN (Eq. 7) decoder_template = (pt.template('input'). gru_cell(num_units=FLAGS.rnn_size, state=pt.UnboundVariable('state'))) # Projection of decoder RNN output (Eq. 18) decoder_proj_template = (pt.template('input').
def build_network(self, scope_suffix=''):
    """Build the discriminator and generator templates.

    Sets self.discriminator_template: four strided 4x4 convs (64..512
    filters, instance norm after the first, leaky ReLU 0.2), with constant
    padding before the VALID convs. The last is a 1-filter conv with no
    flatten.

    Sets self.generator_template: a reflect-padded 7x7 conv, two stride-2
    downsampling convs, nine residual blocks at 128 channels, two stride-2
    deconvs, a final 7x7 conv to 3 channels and tanh, then flatten.

    :param scope_suffix: appended to the "d_net"/"g_net" variable-scope
        names (presumably so several networks can share one graph).
    """
    pad1 = [[0, 0], [1, 1], [1, 1], [0, 0]]
    pad3 = [[0, 0], [3, 3], [3, 3], [0, 0]]

    with tf.variable_scope("d_net{}".format(scope_suffix)):
        mode = "CONSTANT"
        self.discriminator_template = \
            (pt.template("input").
             reshape([-1] + list(self.output_shape)).
             custom_conv2d(64, k_h=4, k_w=4, d_h=2, d_w=2).
             apply(leaky_rectify, 0.2).
             custom_conv2d(128, k_h=4, k_w=4, d_h=2, d_w=2).
             conv_instance_norm().
             apply(leaky_rectify, 0.2).
             custom_conv2d(256, k_h=4, k_w=4, d_h=2, d_w=2).
             conv_instance_norm().
             apply(leaky_rectify, 0.2).
             apply(tf.pad, pad1, mode).
             custom_conv2d(512, k_h=4, k_w=4, d_h=2, d_w=2, padding='VALID').
             conv_instance_norm().
             apply(leaky_rectify, 0.2).
             apply(tf.pad, pad1, mode).
             custom_conv2d(1, k_h=4, k_w=4, d_h=1, d_w=1, padding='VALID'))

    with tf.variable_scope("g_net{}".format(scope_suffix)):
        mode = "REFLECT"
        net = (pt.template("input").
               reshape([-1] + list(self.input_shape)).
               apply(tf.pad, pad3, mode).
               custom_conv2d(32, k_h=7, k_w=7, d_h=1, d_w=1, padding='VALID').
               conv_instance_norm().
               apply(tf.nn.relu).
               apply(tf.pad, pad1, mode).
               custom_conv2d(64, k_h=3, k_w=3, d_h=2, d_w=2, padding='VALID').
               conv_instance_norm().
               apply(tf.nn.relu).
               apply(tf.pad, pad1, mode).
               custom_conv2d(128, k_h=3, k_w=3, d_h=2, d_w=2, padding='VALID').
               conv_instance_norm().
               apply(tf.nn.relu))
        # Nine identical reflect-padded residual blocks (was nine copies).
        for _ in range(9):
            net = net.custom_residual(pad1, mode, custom_conv2d, 128,
                                      k_h=3, k_w=3, d_h=1, d_w=1,
                                      padding='VALID')
        # `//` keeps the deconv output shape integral on Python 3.
        half_size = self.output_size // 2
        # NOTE(review): the author was unsure whether these deconvs should
        # use 'VALID' or 'SAME' padding; 'SAME' is kept.
        self.generator_template = \
            (net.
             custom_deconv2d([0, half_size, half_size, 64],
                             k_h=3, k_w=3, d_h=2, d_w=2, padding='SAME').
             conv_instance_norm().
             apply(tf.nn.relu).
             custom_deconv2d([0, self.output_size, self.output_size, 32],
                             k_h=3, k_w=3, d_h=2, d_w=2, padding='SAME').
             conv_instance_norm().
             apply(tf.nn.relu).
             apply(tf.pad, pad3, mode).
             custom_conv2d(3, k_h=7, k_w=7, d_h=1, d_w=1, padding='VALID').
             apply(tf.nn.tanh).
             flatten())
def context_embedding(self):
    """Return a template: FC to self.ef_dim units, then leaky ReLU (0.2)."""
    embed = pt.template("input")
    embed = embed.custom_fully_connected(self.ef_dim)
    return embed.apply(leaky_rectify, leakiness=0.2)
def Template(self, key):
    """Create a prettytensor template for ``key`` bound to self.bookkeeper.

    :param key: name of the template's unbound input.
    :returns: ``prettytensor.template(key, self.bookkeeper)``.
    """
    books = self.bookkeeper
    return prettytensor.template(key, books)
def __init__(self, output_dist, latent_spec, batch_size, image_shape,
             network_type):
    """Build the InfoGAN discriminator / encoder / generator templates.

    :type output_dist: Distribution
    :type latent_spec: list[(Distribution, bool)]
    :type batch_size: int
    :type image_shape: sequence; image_shape[0] is used as the image side
    :type network_type: string, "mnist" or "heart"
    :raises NotImplementedError: for any other network_type.
    """
    self.output_dist = output_dist
    self.latent_spec = latent_spec
    # Joint latent distribution, then its split by the per-entry flag into
    # regularised (True) and non-regularised (False) parts.
    self.latent_dist = Product([x for x, _ in latent_spec])
    self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
    self.nonreg_latent_dist = Product(
        [x for x, reg in latent_spec if not reg])
    self.batch_size = batch_size
    self.network_type = network_type
    self.image_shape = image_shape
    assert all(isinstance(x, (Gaussian, Categorical, Bernoulli))
               for x in self.reg_latent_dist.dists)
    self.reg_cont_latent_dist = Product(
        [x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
    self.reg_disc_latent_dist = Product(
        [x for x in self.reg_latent_dist.dists
         if isinstance(x, (Categorical, Bernoulli))])

    image_size = image_shape[0]
    # The "mnist" and "heart" networks are layer-for-layer identical, so
    # they share one construction path.
    if network_type not in ("mnist", "heart"):
        raise NotImplementedError(network_type)

    with tf.variable_scope("d_net"):
        shared_template = \
            (pt.template("input").
             reshape([-1] + list(image_shape)).
             custom_conv2d(64, k_h=4, k_w=4).
             apply(leaky_rectify).
             custom_conv2d(128, k_h=4, k_w=4).
             conv_batch_norm().
             apply(leaky_rectify).
             custom_fully_connected(1024).
             fc_batch_norm().
             apply(leaky_rectify))
        self.discriminator_template = \
            shared_template.custom_fully_connected(1)
        # NOTE(review): the author noted this encoder may not match the
        # face architecture for "heart"; kept unchanged.
        self.encoder_template = \
            (shared_template.
             custom_fully_connected(128).
             fc_batch_norm().
             apply(leaky_rectify).
             custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

    # `//` keeps sizes integral on Python 3; naming s4 keeps the FC width
    # exactly s4 * s4 * 128 so it always matches the reshape.
    s2 = image_size // 2
    s4 = image_size // 4
    with tf.variable_scope("g_net"):
        self.generator_template = \
            (pt.template("input").
             custom_fully_connected(1024).
             fc_batch_norm().
             apply(tf.nn.relu).
             custom_fully_connected(s4 * s4 * 128).
             fc_batch_norm().
             apply(tf.nn.relu).
             reshape([-1, s4, s4, 128]).
             custom_deconv2d([0, s2, s2, 64], k_h=4, k_w=4).
             conv_batch_norm().
             apply(tf.nn.relu).
             custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
             flatten())
def context_embedding(self):
    """Embed the conditioning vector.

    :returns: template applying custom_fully_connected(self.ef_dim) and then
        a leaky ReLU with slope 0.2.
    """
    projected = pt.template("input").custom_fully_connected(self.ef_dim)
    template = projected.apply(leaky_rectify, leakiness=0.2)
    return template
def __init__(self, output_dist, latent_spec, batch_size, image_shape,
             network_type):
    """Build the InfoGAN discriminator / encoder / generator templates.

    :type output_dist: Distribution
    :type latent_spec: list[(Distribution, bool)]
    :type batch_size: int
    :type image_shape: sequence; image_shape[0] is used as the image side
    :type network_type: string, one of "mnist", "celebA", "face"
    :raises NotImplementedError: for any other network_type.
    """
    self.output_dist = output_dist
    self.latent_spec = latent_spec
    # Joint latent distribution, then its split by the per-entry flag into
    # regularised (True) and non-regularised (False) parts.
    self.latent_dist = Product([x for x, _ in latent_spec])
    self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
    self.nonreg_latent_dist = Product(
        [x for x, reg in latent_spec if not reg])
    self.batch_size = batch_size
    self.network_type = network_type
    self.image_shape = image_shape
    assert all(isinstance(x, (Gaussian, Categorical, Bernoulli))
               for x in self.reg_latent_dist.dists)
    self.reg_cont_latent_dist = Product(
        [x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
    self.reg_disc_latent_dist = Product(
        [x for x in self.reg_latent_dist.dists
         if isinstance(x, (Categorical, Bernoulli))])

    image_size = image_shape[0]
    # Integer feature-map sides for the up-sampling stages. `//` keeps them
    # ints on Python 3 (`/` produced floats that reshape/deconv reject), and
    # naming them keeps each FC width equal to the product its reshape
    # expects even when image_size is not divisible by the factor.
    s2 = image_size // 2
    s4 = image_size // 4
    s8 = image_size // 8
    s16 = image_size // 16

    if network_type == "mnist":
        with tf.variable_scope("d_net"):
            shared_template = \
                (pt.template("input").
                 reshape([-1] + list(image_shape)).
                 custom_conv2d(64, k_h=4, k_w=4).
                 apply(leaky_rectify).
                 custom_conv2d(128, k_h=4, k_w=4).
                 conv_batch_norm().
                 apply(leaky_rectify).
                 custom_fully_connected(1024).
                 fc_batch_norm().
                 apply(leaky_rectify))
            self.discriminator_template = \
                shared_template.custom_fully_connected(1)
            self.encoder_template = \
                (shared_template.
                 custom_fully_connected(128).
                 fc_batch_norm().
                 apply(leaky_rectify).
                 custom_fully_connected(self.reg_latent_dist.dist_flat_dim))
        with tf.variable_scope("g_net"):
            self.generator_template = \
                (pt.template("input").
                 custom_fully_connected(1024).
                 fc_batch_norm().
                 apply(tf.nn.relu).
                 custom_fully_connected(s4 * s4 * 128).
                 fc_batch_norm().
                 apply(tf.nn.relu).
                 reshape([-1, s4, s4, 128]).
                 custom_deconv2d([0, s2, s2, 64], k_h=4, k_w=4).
                 conv_batch_norm().
                 apply(tf.nn.relu).
                 custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                 flatten())
    elif network_type == "celebA":
        with tf.variable_scope("d_net"):
            shared_template = \
                (pt.template("input").
                 reshape([-1] + list(image_shape)).
                 custom_conv2d(64, k_h=4, k_w=4).
                 apply(leaky_rectify).
                 custom_conv2d(128, k_h=4, k_w=4).
                 conv_batch_norm().
                 apply(leaky_rectify).
                 custom_conv2d(256, k_h=4, k_w=4).
                 conv_batch_norm().
                 apply(leaky_rectify))
            self.discriminator_template = \
                shared_template.custom_fully_connected(1)
            self.encoder_template = \
                (shared_template.
                 custom_fully_connected(128).
                 fc_batch_norm().
                 apply(leaky_rectify).
                 custom_fully_connected(self.reg_latent_dist.dist_flat_dim))
        with tf.variable_scope("g_net"):
            # Each deconv's output shape doubles the side (s16 -> s8 -> ...),
            # which the author attributes to custom_deconv2d's stride of 2.
            self.generator_template = \
                (pt.template("input").
                 custom_fully_connected(s16 * s16 * 448).
                 fc_batch_norm().
                 apply(tf.nn.relu).
                 reshape([-1, s16, s16, 448]).
                 custom_deconv2d([0, s8, s8, 256], k_h=4, k_w=4).
                 conv_batch_norm().
                 apply(tf.nn.relu).
                 custom_deconv2d([0, s4, s4, 128], k_h=4, k_w=4).
                 apply(tf.nn.relu).
                 custom_deconv2d([0, s2, s2, 64], k_h=4, k_w=4).
                 apply(tf.nn.relu).
                 custom_deconv2d([0, image_size, image_size, 3],
                                 k_h=4, k_w=4).
                 apply(tf.nn.tanh).
                 flatten())
    elif network_type == "face":
        with tf.variable_scope("d_net"):
            shared_template = \
                (pt.template("input").
                 reshape([-1] + list(image_shape)).
                 custom_conv2d(64, k_h=4, k_w=4).
                 apply(leaky_rectify).
                 custom_conv2d(128, k_h=4, k_w=4).
                 conv_batch_norm().
                 apply(leaky_rectify).
                 custom_fully_connected(1024).
                 fc_batch_norm().
                 apply(leaky_rectify))
            self.discriminator_template = \
                shared_template.custom_fully_connected(1)
            self.encoder_template = \
                shared_template.custom_fully_connected(
                    self.reg_latent_dist.dist_flat_dim)
        with tf.variable_scope("g_net"):
            self.generator_template = \
                (pt.template("input").
                 custom_fully_connected(1024).
                 fc_batch_norm().
                 apply(tf.nn.relu).
                 custom_fully_connected(s4 * s4 * 128).
                 fc_batch_norm().
                 apply(tf.nn.relu).
                 reshape([-1, s4, s4, 128]).
                 custom_deconv2d([0, s2, s2, 64], k_h=4, k_w=4).
                 conv_batch_norm().
                 apply(tf.nn.relu).
                 custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                 apply(tf.nn.sigmoid).
                 flatten())
    else:
        raise NotImplementedError(network_type)
def __init__(self, output_dist, latent_spec, batch_size, image_shape,
             network_type):
    """Build the InfoGAN discriminator / encoder / generator templates.

    Also logs the latent configuration through ``pstr`` while building.

    :type output_dist: Distribution
    :type latent_spec: list[(Distribution, bool)]
    :type batch_size: int
    :type image_shape: sequence; image_shape[0] is used as the image side
    :type network_type: string; only "mnist" is supported
    :raises NotImplementedError: for any other network_type.
    """
    self.output_dist = output_dist
    pstr('output_dist', self.output_dist)
    self.latent_spec = latent_spec
    self.latent_dist = Product([x for x, _ in latent_spec])
    pstr('latent_dist', self.latent_dist)
    pstr('x in latent_spec', [x for x, _ in self.latent_spec])
    pstr('xreg in latent_spec', [xreg for _, xreg in self.latent_spec])
    # Split by the per-entry flag into regularised / non-regularised parts.
    self.reg_latent_dist = Product([x for x, reg in latent_spec if reg])
    self.nonreg_latent_dist = Product(
        [x for x, reg in latent_spec if not reg])
    self.batch_size = batch_size
    self.network_type = network_type
    self.image_shape = image_shape
    assert all(isinstance(x, (Gaussian, Categorical, Bernoulli))
               for x in self.reg_latent_dist.dists)
    self.reg_cont_latent_dist = Product(
        [x for x in self.reg_latent_dist.dists if isinstance(x, Gaussian)])
    self.reg_disc_latent_dist = Product(
        [x for x in self.reg_latent_dist.dists
         if isinstance(x, (Categorical, Bernoulli))])
    pstr('image_shape', image_shape)
    pstr('image_shape[0]', image_shape[0])
    # (Removed the unused `image_size` and `s32` locals and the commented-out
    # alternative layers that obscured the active architecture.)

    if network_type == "mnist":
        with tf.variable_scope("d_net"):
            # NOTE(review): the final 512-filter conv has no batch norm or
            # activation before the FC heads; confirm this is intended.
            shared_template = \
                (pt.template("input").
                 reshape([-1] + list(image_shape)).
                 custom_conv2d(64, k_h=4, k_w=4).
                 apply(leaky_rectify).
                 custom_conv2d(128, k_h=4, k_w=4).
                 conv_batch_norm().
                 apply(leaky_rectify).
                 custom_conv2d(256, k_h=4, k_w=4).
                 conv_batch_norm().
                 apply(leaky_rectify).
                 custom_conv2d(512, k_h=4, k_w=4))
            self.discriminator_template = \
                shared_template.custom_fully_connected(1)
            self.encoder_template = \
                (shared_template.
                 custom_fully_connected(128).
                 fc_batch_norm().
                 apply(leaky_rectify).
                 custom_fully_connected(self.reg_latent_dist.dist_flat_dim))
        with tf.variable_scope("g_net"):
            s = self.image_shape[0]
            s2, s4, s8, s16 = int(s / 2), int(s / 4), int(s / 8), int(s / 16)
            # Only the 256-channel stage is batch-normalised after the FC.
            self.generator_template = \
                (pt.template("input").
                 custom_fully_connected(s16 * s16 * 512).
                 fc_batch_norm().
                 apply(tf.nn.relu).
                 reshape([-1, s16, s16, 512]).
                 custom_deconv2d([0, s8, s8, 256], k_h=4, k_w=4).
                 conv_batch_norm().
                 apply(tf.nn.relu).
                 custom_deconv2d([0, s4, s4, 128], k_h=4, k_w=4).
                 apply(tf.nn.relu).
                 custom_deconv2d([0, s2, s2, 64], k_h=4, k_w=4).
                 apply(tf.nn.relu).
                 custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
                 apply(tf.nn.tanh))
    else:
        raise NotImplementedError(network_type)