def generator(self, batch):
    """Build the grey-input -> colour-video generator graph under scope 'g_'.

    ``batch`` is converted with ``rgb_to_grey`` and encoded by four stride-2
    ``conv2d`` layers (enc_conv1..enc_conv4, 128 -> 256 -> 512 -> 1024
    output channels), each followed by batch norm and ReLU. The deepest map
    is reshaped to a [-1, 2, 4, 4, 512] volume and decoded by four
    ``conv3d_transpose`` layers (g_f_h1..g_f_h4); the first three use batch
    norm + ReLU, the last a tanh.

    NOTE(review): the first convolution declares ``self.frame_size`` input
    channels, so the greyscale batch presumably stacks frames on the channel
    axis — confirm against ``rgb_to_grey``. The reshape needs
    2*4*4*512 = 16384 values per example, matching a 4x4x1024 encoder
    output; confirm the expected input resolution.

    Args:
        batch: input tensor handed to ``rgb_to_grey``.

    Returns:
        Tuple ``(self.fg_fg, gen_reg, variables)``:
          * ``self.fg_fg`` - tanh output, requested as
            [self.batch_size, 32, 64, 64, 3];
          * ``gen_reg`` - mean squared difference between the greyscale
            input and the greyscale of the generated output;
          * ``variables`` - variables collected from the 'g_' scope.
    """
    grey_batch = rgb_to_grey(batch)
    with tf.variable_scope('g_') as vs:
        # ----------------------------- ENCODER -----------------------------
        # Layer 1 uses batch_norm's default (auto-named) scope, unlike the
        # later layers; giving it a scope would rename its variables.
        first = conv2d(grey_batch, self.frame_size, 128,
                       k_h=4, k_w=4, d_w=2, d_h=2, name="enc_conv1")
        self.en_h0 = tf.nn.relu(tf.contrib.layers.batch_norm(first))
        add_activation_summary(self.en_h0)
        print(self.en_h0.get_shape().as_list())

        # (attribute, in channels, out channels, conv name, batch-norm scope)
        encoder_layers = (
            ('en_h1', 128, 256, "enc_conv2", "enc_bn2"),
            ('en_h2', 256, 512, "enc_conv3", "enc_bn3"),
            ('en_h3', 512, 1024, "enc_conv4", "enc_bn4"),
        )
        features = self.en_h0
        for attr, in_ch, out_ch, conv_name, bn_name in encoder_layers:
            features = conv2d(features, in_ch, out_ch,
                              k_h=4, k_w=4, d_w=2, d_h=2, name=conv_name)
            features = tf.nn.relu(
                tf.contrib.layers.batch_norm(features, scope=bn_name))
            setattr(self, attr, features)
            add_activation_summary(features)
            print(features.get_shape().as_list())

        # ---------------------------- GENERATOR ----------------------------
        self.fg_h0 = tf.reshape(self.en_h3, [-1, 2, 4, 4, 512])
        print(self.fg_h0.get_shape().as_list())

        # (attribute, in channels, output shape, conv name, bn scope, relu name)
        decoder_layers = (
            ('fg_h1', 512, [self.batch_size, 4, 8, 8, 256],
             'g_f_h1', 'g_f_bn1', 'g_f_relu1'),
            ('fg_h2', 256, [self.batch_size, 8, 16, 16, 128],
             'g_f_h2', 'g_f_bn2', 'g_f_relu2'),
            ('fg_h3', 128, [self.batch_size, 16, 32, 32, 64],
             'g_f_h3', 'g_f_bn3', 'g_f_relu3'),
        )
        volume = self.fg_h0
        for attr, in_ch, out_shape, conv_name, bn_name, relu_name in decoder_layers:
            volume = conv3d_transpose(volume, in_ch, out_shape, name=conv_name)
            volume = tf.nn.relu(
                tf.contrib.layers.batch_norm(volume, scope=bn_name),
                name=relu_name)
            setattr(self, attr, volume)
            add_activation_summary(volume)
            print(volume.get_shape().as_list())

        # Output layer: transposed conv to 3 channels, squashed with tanh.
        self.fg_h4 = conv3d_transpose(volume, 64, [self.batch_size, 32, 64, 64, 3],
                                      name='g_f_h4')
        self.fg_fg = tf.nn.tanh(self.fg_h4, name='g_f_actvcation')
        print(self.fg_fg.get_shape().as_list())

        # Penalise the squared gap between the grey input and the greyscale
        # of the generated output.
        gen_reg = tf.reduce_mean(
            tf.square(grey_batch - rgb_to_grey(self.fg_fg)))
        variables = tf.contrib.framework.get_variables(vs)
        return self.fg_fg, gen_reg, variables
def generator(self, z):
    """Build the latent-vector -> video generator graph under scope 'g_'.

    ``z`` is projected by ``linear`` to 512*4*4*2 units, reshaped to a
    [-1, 2, 4, 4, 512] volume (batch norm + ReLU), then upsampled by four
    ``conv3d_transpose`` layers: three with batch norm + ReLU and a final
    one with tanh.

    Args:
        z: latent input tensor fed to ``linear``. NOTE(review): its expected
            rank/width is not visible here - confirm against callers.

    Returns:
        Tuple ``(self.fg_fg, variables)``: the tanh output, requested as
        [self.batch_size, 32, 64, 64, 3], and the variables collected from
        the 'g_' scope.
    """
    with tf.variable_scope('g_') as vs:
        # Linear projection; linear(..., with_w=True) returns three values,
        # only the first of which is used here.
        projected, _, _ = linear(z, 512 * 4 * 4 * 2, 'g_f_h0_lin', with_w=True)
        self.z_ = projected
        self.fg_h0 = tf.nn.relu(
            tf.contrib.layers.batch_norm(
                tf.reshape(projected, [-1, 2, 4, 4, 512]), scope='g_f_bn0'),
            name='g_f_relu0')
        add_activation_summary(self.fg_h0)

        # Conv blocks 1-3: conv3d_transpose -> batch norm -> ReLU.
        # (attribute, in channels, output shape, conv name, bn scope, relu name)
        upsampling = (
            ('fg_h1', 512, [self.batch_size, 4, 8, 8, 256],
             'g_f_h1', 'g_f_bn1', 'g_f_relu1'),
            ('fg_h2', 256, [self.batch_size, 8, 16, 16, 128],
             'g_f_h2', 'g_f_bn2', 'g_f_relu2'),
            ('fg_h3', 128, [self.batch_size, 16, 32, 32, 64],
             'g_f_h3', 'g_f_bn3', 'g_f_relu3'),
        )
        net = self.fg_h0
        for attr, in_ch, out_shape, conv_name, bn_name, relu_name in upsampling:
            net = tf.nn.relu(
                tf.contrib.layers.batch_norm(
                    conv3d_transpose(net, in_ch, out_shape, name=conv_name),
                    scope=bn_name),
                name=relu_name)
            setattr(self, attr, net)
            add_activation_summary(net)

        # Conv block 4 (output): conv3d_transpose to 3 channels -> tanh.
        self.fg_h4 = conv3d_transpose(net, 64, [self.batch_size, 32, 64, 64, 3],
                                      name='g_f_h4')
        self.fg_fg = tf.nn.tanh(self.fg_h4, name='g_f_actvcation')
        variables = tf.contrib.framework.get_variables(vs)
        return self.fg_fg, variables
def generatorVid(self, img_batch, reuse=False):
    """Build the image -> video U-Net-style generator under scope 'gen_v'.

    Encoder: seven stride-2 ``conv2d`` layers (enc_conv1..enc_conv7,
    3 -> 32 -> 64 -> 128 -> 256 -> 256 -> 256 -> 256 channels), each
    followed by batch norm and ReLU; results are stored on
    ``self.en_h0`` .. ``self.en_h6``.

    Decoder: ``self.en_h6`` is reshaped to [self.batch_size, 1, 1, 1, 256]
    and upsampled by ``conv3d_transpose`` layers g_f_h1..g_f_h6. Before
    g_f_h2..g_f_h6 the running volume is concatenated on axis 4 with an
    encoder map (en_h5, en_h4, en_h3, en_h2, en_h1) laid out along a new
    axis 1 with 1, 2, 4, 8, 16 frames respectively.

    NOTE(review): the shape notes ([32, 4, 4, 256] for en_h4,
    [32, 2, 2, 256] for en_h5, [32, 1, 1, 256] for en_h6) imply a batch of
    32 and 128x128 inputs - confirm against callers.

    Args:
        img_batch: image tensor; the first convolution declares 3 input
            channels.
        reuse: forwarded to ``tf.variable_scope('gen_v', reuse=reuse)``.

    Returns:
        ``self.fg_vid``, the tanh output requested as
        [self.batch_size, 32, 64, 64, 3].
    """

    def _up_block(inputs, in_ch, out_shape, conv_name, bn_name, relu_name):
        # conv3d_transpose -> batch_norm -> ReLU, then summary + shape print.
        out = conv3d_transpose(inputs, in_ch, out_shape, name=conv_name)
        out = tf.nn.relu(tf.contrib.layers.batch_norm(out, scope=bn_name),
                         name=relu_name)
        add_activation_summary(out)
        print(out.get_shape().as_list())
        return out

    def _repeat_frames(enc, frames, shape):
        # Insert axis 1 and tile it `frames` times; the trailing reshape to
        # `shape` is kept exactly as the original graph had it.
        repeated = tf.tile(tf.expand_dims(enc, axis=1), [1, frames, 1, 1, 1])
        return tf.reshape(repeated, shape)

    with tf.variable_scope('gen_v', reuse=reuse):
        # ----------------------------- ENCODER -----------------------------
        # Layer 1 uses batch_norm's default (auto-named) scope.
        stem = conv2d(img_batch, 3, 32, k_h=4, k_w=4, d_w=2, d_h=2, name="enc_conv1")
        self.en_h0 = tf.nn.relu(tf.contrib.layers.batch_norm(stem))
        add_activation_summary(self.en_h0)
        print(self.en_h0.get_shape().as_list())

        # (attribute, in channels, out channels, conv name, batch-norm scope)
        encoder_layers = (
            ('en_h1', 32, 64, "enc_conv2", "enc_bn2"),
            ('en_h2', 64, 128, "enc_conv3", "enc_bn3"),
            ('en_h3', 128, 256, "enc_conv4", "enc_bn4"),
            ('en_h4', 256, 256, "enc_conv5", "enc_bn5"),  # noted as [32, 4, 4, 256]
            ('en_h5', 256, 256, "enc_conv6", "enc_bn6"),  # noted as [32, 2, 2, 256]
            ('en_h6', 256, 256, "enc_conv7", "enc_bn7"),  # noted as [32, 1, 1, 256]
        )
        feat = self.en_h0
        for attr, in_ch, out_ch, conv_name, bn_name in encoder_layers:
            feat = conv2d(feat, in_ch, out_ch, k_h=4, k_w=4, d_w=2, d_h=2, name=conv_name)
            feat = tf.nn.relu(tf.contrib.layers.batch_norm(feat, scope=bn_name))
            setattr(self, attr, feat)
            add_activation_summary(feat)
            print(feat.get_shape().as_list())

        # ---------------------------- GENERATOR ----------------------------
        bs = self.batch_size
        self.z_ = tf.reshape(self.en_h6, [bs, 1, 1, 1, 256])
        print(self.z_.get_shape().as_list())

        self.fg_h1 = _up_block(self.z_, 256, [bs, 1, 2, 2, 256],
                               'g_f_h1', 'g_f_bn1', 'g_f_relu1')

        # en_h5 contributes a single frame, so a plain reshape is used.
        skip5 = tf.reshape(self.en_h5, [bs, 1, 2, 2, 256])
        self.fg_h2 = _up_block(tf.concat([self.fg_h1, skip5], axis=4), 512,
                               [bs, 2, 4, 4, 256], 'g_f_h2', 'g_f_bn2', 'g_f_relu2')

        skip4 = _repeat_frames(self.en_h4, 2, [bs, 2, 4, 4, 256])
        self.fg_h3 = _up_block(tf.concat([self.fg_h2, skip4], axis=4), 512,
                               [bs, 4, 8, 8, 256], 'g_f_h3', 'g_f_bn3', 'g_f_relu3')

        skip3 = _repeat_frames(self.en_h3, 4, [bs, 4, 8, 8, 256])
        self.fg_h4 = _up_block(tf.concat([self.fg_h3, skip3], axis=4), 512,
                               [bs, 8, 16, 16, 128], 'g_f_h4', 'g_f_bn4', 'g_f_relu4')

        # fg_h4 and the repeated en_h2 each carry 128 channels -> 256 after concat.
        skip2 = _repeat_frames(self.en_h2, 8, [bs, 8, 16, 16, 128])
        self.fg_h5 = _up_block(tf.concat([self.fg_h4, skip2], axis=4), 256,
                               [bs, 16, 32, 32, 64], 'g_f_h5', 'g_f_bn5', 'g_f_relu5')

        # Output layer: no batch norm, tanh activation.
        skip1 = _repeat_frames(self.en_h1, 16, [bs, 16, 32, 32, 64])
        self.fg_h6 = conv3d_transpose(tf.concat([self.fg_h5, skip1], axis=4), 128,
                                      [bs, 32, 64, 64, 3], name='g_f_h6')
        self.fg_vid = tf.nn.tanh(self.fg_h6, name='g_f_actvcation')
        print(self.fg_vid.get_shape().as_list())
        return self.fg_vid
def generator(self, batch):
    """Build the masked-video -> video generator graph under scope 'g_'.

    Stores ``batch`` as ``self.unmasked_video`` and ``self.mask_video(batch)``
    as ``self.masked_video``. The masked video is encoded by four ``conv3d``
    layers (enc_conv1..enc_conv4, 64 -> 128 -> 256 -> 512 channels) with
    batch norm + ReLU. The result is reshaped to [-1, 2, 4, 4, 512] and
    decoded by four ``conv3d_transpose`` layers: three with batch norm +
    ReLU and a final one with tanh.

    Args:
        batch: unmasked input video tensor.

    Returns:
        Tuple ``(self.fg_fg, gen_reg, variables)``: the tanh output
        (requested as [self.batch_size, 32, 64, 64, 3]), the mean squared
        error between it and the unmasked input, and the variables collected
        from the 'g_' scope.
    """
    self.unmasked_video = batch
    self.masked_video = self.mask_video(batch)
    with tf.variable_scope('g_') as vs:
        # ----------------------------- ENCODER -----------------------------
        # Layer 1 uses batch_norm's default (auto-named) scope.
        head = conv3d(self.masked_video, 3, 64, name="enc_conv1")
        self.en_h0 = tf.nn.relu(tf.contrib.layers.batch_norm(head))
        add_activation_summary(self.en_h0)
        print(self.en_h0.get_shape().as_list())

        # (attribute, in channels, out channels, conv name, batch-norm scope)
        encoder_layers = (
            ('en_h1', 64, 128, "enc_conv2", "enc_bn2"),
            ('en_h2', 128, 256, "enc_conv3", "enc_bn3"),
            ('en_h3', 256, 512, "enc_conv4", "enc_bn4"),
        )
        feats = self.en_h0
        for attr, in_ch, out_ch, conv_name, bn_name in encoder_layers:
            feats = conv3d(feats, in_ch, out_ch, name=conv_name)
            feats = tf.nn.relu(tf.contrib.layers.batch_norm(feats, scope=bn_name))
            setattr(self, attr, feats)
            add_activation_summary(feats)
            print(feats.get_shape().as_list())

        # ----------------------------- DECODER -----------------------------
        # NOTE(review): en_h3 may already be [-1, 2, 4, 4, 512] depending on
        # conv3d's strides (not visible here); the reshape is kept as-is.
        self.fg_h0 = tf.reshape(self.en_h3, [-1, 2, 4, 4, 512])
        print(self.fg_h0.get_shape().as_list())

        # (attribute, in channels, output shape, conv name, bn scope, relu name)
        decoder_layers = (
            ('fg_h1', 512, [self.batch_size, 4, 8, 8, 256],
             'g_f_h1', 'g_f_bn1', 'g_f_relu1'),
            ('fg_h2', 256, [self.batch_size, 8, 16, 16, 128],
             'g_f_h2', 'g_f_bn2', 'g_f_relu2'),
            ('fg_h3', 128, [self.batch_size, 16, 32, 32, 64],
             'g_f_h3', 'g_f_bn3', 'g_f_relu3'),
        )
        vol = self.fg_h0
        for attr, in_ch, out_shape, conv_name, bn_name, relu_name in decoder_layers:
            vol = conv3d_transpose(vol, in_ch, out_shape, name=conv_name)
            vol = tf.nn.relu(tf.contrib.layers.batch_norm(vol, scope=bn_name),
                             name=relu_name)
            setattr(self, attr, vol)
            add_activation_summary(vol)
            print(vol.get_shape().as_list())

        # Output layer: transposed conv to 3 channels, squashed with tanh.
        self.fg_h4 = conv3d_transpose(vol, 64, [self.batch_size, 32, 64, 64, 3],
                                      name='g_f_h4')
        self.fg_fg = tf.nn.tanh(self.fg_h4, name='g_f_actvcation')
        print(self.fg_fg.get_shape().as_list())

        # Reconstruction term against the original (unmasked) video.
        gen_reg = tf.reduce_mean(tf.square(self.unmasked_video - self.fg_fg))
        variables = tf.contrib.framework.get_variables(vs)
        return self.fg_fg, gen_reg, variables
def generatorVid(self, img_batch, reuse=False):
    """Build the image -> video generator with residual encoder stages.

    Encoder: five stride-2 ``conv2d`` layers (enc_conv1..enc_conv5,
    3 -> 64 -> 128 -> 256 -> 512 -> 1024 channels), each followed by batch
    norm and ReLU. After enc_conv2, enc_conv3 and enc_conv4 the features
    additionally pass through three ``self.ResidualBlock`` layers.

    Decoder: ``self.en_h4`` is reshaped to [self.batch_size, 2, 4, 4, 512]
    and upsampled by four ``conv3d_transpose`` layers (g_f_h1..g_f_h4); the
    first three use batch norm + ReLU, the last a tanh.

    NOTE(review): that reshape needs 2*4*4*512 = 16384 values per example,
    i.e. a 4x4x1024 encoder output, which five stride-2 convolutions give
    for 128x128 inputs (assuming SAME padding in ``conv2d``) - confirm the
    intended input size.

    Args:
        img_batch: image tensor; the first convolution declares 3 input
            channels.
        reuse: forwarded to ``tf.variable_scope('gen_v', reuse=reuse)``.

    Returns:
        ``self.fg_vid``, the tanh output requested as
        [self.batch_size, 32, 64, 64, 3].
    """

    def _residual_stack(features, name_template, channels):
        # Apply three ResidualBlocks to the tensor transposed with
        # [0, 3, 1, 2] (channel axis moved to position 1), then transpose
        # back with [0, 2, 3, 1]. Uses `range` rather than the
        # Python-2-only `xrange`, which raised NameError on Python 3.
        out = tf.transpose(features, [0, 3, 1, 2])
        for i in range(3):
            out = self.ResidualBlock(name_template.format(i), channels, channels, 3,
                                     out, resample=None)
        return tf.transpose(out, [0, 2, 3, 1])

    with tf.variable_scope('gen_v', reuse=reuse):
        # ----------------------------- ENCODER -----------------------------
        self.en_h0 = conv2d(img_batch, 3, 64, k_h=4, k_w=4, d_w=2, d_h=2, name="enc_conv1")
        # Default (auto-named) batch-norm scope, unlike the later layers;
        # left unchanged so existing variable names are preserved.
        self.en_h0 = tf.nn.relu(tf.contrib.layers.batch_norm(self.en_h0))
        add_activation_summary(self.en_h0)
        print(self.en_h0.get_shape().as_list())

        self.en_h1 = conv2d(self.en_h0, 64, 128, k_h=4, k_w=4, d_w=2, d_h=2, name="enc_conv2")
        self.en_h1 = tf.contrib.layers.batch_norm(self.en_h1, scope="enc_bn2")
        self.en_h1 = tf.nn.relu(self.en_h1)
        add_activation_summary(self.en_h1)
        print(self.en_h1.get_shape().as_list())
        self.en_h1 = _residual_stack(self.en_h1, 'res1.16x16_{}', 128)

        self.en_h2 = conv2d(self.en_h1, 128, 256, k_h=4, k_w=4, d_w=2, d_h=2, name="enc_conv3")
        self.en_h2 = tf.contrib.layers.batch_norm(self.en_h2, scope="enc_bn3")
        self.en_h2 = tf.nn.relu(self.en_h2)
        add_activation_summary(self.en_h2)
        print(self.en_h2.get_shape().as_list())
        self.en_h2 = _residual_stack(self.en_h2, 'res1.16x16_2_{}', 256)

        self.en_h3 = conv2d(self.en_h2, 256, 512, k_h=4, k_w=4, d_w=2, d_h=2, name="enc_conv4")
        self.en_h3 = tf.contrib.layers.batch_norm(self.en_h3, scope="enc_bn4")
        self.en_h3 = tf.nn.relu(self.en_h3)
        add_activation_summary(self.en_h3)
        print(self.en_h3.get_shape().as_list())
        self.en_h3 = _residual_stack(self.en_h3, 'res1.16x16_3_{}', 512)

        self.en_h4 = conv2d(self.en_h3, 512, 1024, k_h=4, k_w=4, d_w=2, d_h=2, name="enc_conv5")
        self.en_h4 = tf.contrib.layers.batch_norm(self.en_h4, scope="enc_bn5")
        self.en_h4 = tf.nn.relu(self.en_h4)
        add_activation_summary(self.en_h4)
        print(self.en_h4.get_shape().as_list())

        # ---------------------------- GENERATOR ----------------------------
        self.z_ = tf.reshape(self.en_h4, [self.batch_size, 2, 4, 4, 512])
        print(self.z_.get_shape().as_list())

        self.fg_h1 = conv3d_transpose(self.z_, 512, [self.batch_size, 4, 8, 8, 256],
                                      name='g_f_h1')
        self.fg_h1 = tf.nn.relu(tf.contrib.layers.batch_norm(self.fg_h1, scope='g_f_bn1'),
                                name='g_f_relu1')
        add_activation_summary(self.fg_h1)
        print(self.fg_h1.get_shape().as_list())

        self.fg_h2 = conv3d_transpose(self.fg_h1, 256, [self.batch_size, 8, 16, 16, 128],
                                      name='g_f_h2')
        self.fg_h2 = tf.nn.relu(tf.contrib.layers.batch_norm(self.fg_h2, scope='g_f_bn2'),
                                name='g_f_relu2')
        add_activation_summary(self.fg_h2)
        print(self.fg_h2.get_shape().as_list())

        self.fg_h3 = conv3d_transpose(self.fg_h2, 128, [self.batch_size, 16, 32, 32, 64],
                                      name='g_f_h3')
        self.fg_h3 = tf.nn.relu(tf.contrib.layers.batch_norm(self.fg_h3, scope='g_f_bn3'),
                                name='g_f_relu3')
        add_activation_summary(self.fg_h3)
        print(self.fg_h3.get_shape().as_list())

        # Output layer: transposed conv to 3 channels, squashed with tanh.
        self.fg_h4 = conv3d_transpose(self.fg_h3, 64, [self.batch_size, 32, 64, 64, 3],
                                      name='g_f_h4')
        self.fg_vid = tf.nn.tanh(self.fg_h4, name='g_f_actvcation')
        print(self.fg_vid.get_shape().as_list())
        return self.fg_vid