def decoder(self, z, name='decoder', is_reuse=False):
    with tf.variable_scope(name) as scope:
        if is_reuse:
            scope.reuse_variables()
        tf_utils.print_activations(z)

        # 1st hidden layer
        h0_linear = tf_utils.linear(z, self.n_hidden, name='h0_linear')
        h0_tanh = tf_utils.tanh(h0_linear, name='h0_tanh')
        h0_drop = tf.nn.dropout(h0_tanh, keep_prob=self.keep_prob_tfph, name='h0_drop')
        tf_utils.print_activations(h0_drop)

        # 2nd hidden layer
        h1_linear = tf_utils.linear(h0_drop, self.n_hidden, name='h1_linear')
        h1_elu = tf_utils.elu(h1_linear, name='h1_elu')
        h1_drop = tf.nn.dropout(h1_elu, keep_prob=self.keep_prob_tfph, name='h1_drop')
        tf_utils.print_activations(h1_drop)

        # 3rd hidden layer: project to the flattened image and squash to [0, 1]
        h2_linear = tf_utils.linear(h1_drop, self.output_dim, name='h2_linear')
        h2_sigmoid = tf_utils.sigmoid(h2_linear, name='h2_sigmoid')
        tf_utils.print_activations(h2_sigmoid)

        # reshape the flat output back to image dimensions
        output = tf.reshape(h2_sigmoid, [-1, *self.image_size])
        tf_utils.print_activations(output)

    return output
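
# --- Illustration (standalone sketch, not part of the source) ---
# Why keep_prob_tfph is a placeholder rather than a constant: the same graph
# then runs with dropout on during training and off at evaluation time, just
# by changing the feed. A minimal TF1 sketch assuming only TensorFlow/NumPy:
import numpy as np
import tensorflow as tf

x = tf.constant(np.ones((1, 8), dtype=np.float32))
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
dropped = tf.nn.dropout(x, keep_prob=keep_prob)

with tf.Session() as sess:
    train_out = sess.run(dropped, feed_dict={keep_prob: 0.5})  # ~half zeroed, survivors scaled by 1/0.5
    test_out = sess.run(dropped, feed_dict={keep_prob: 1.0})   # identity: dropout disabled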
def __call__(self, x):
    with tf.variable_scope(self.name, reuse=self.reuse):
        tf_utils.print_activations(x)

        # conv: (N, H, W, 3) -> (N, H/2, W/2, 64)
        output = tf_utils.conv2d(x, self.ndf, k_h=4, k_w=4, d_h=2, d_w=2,
                                 padding='SAME', name='conv0_conv2d')
        output = tf_utils.lrelu(output, name='conv0_lrelu', is_print=True)

        for idx, hidden_dim in enumerate(self.hidden_dims[1:]):
            # conv: halve the spatial size and widen to hidden_dim channels
            output = tf_utils.conv2d(output, hidden_dim, k_h=4, k_w=4, d_h=2, d_w=2,
                                     padding='SAME', name='conv{}_conv2d'.format(idx + 1))
            output = tf_utils.norm(output, _type=self.norm, _ops=self._ops,
                                   name='conv{}_norm'.format(idx + 1))
            output = tf_utils.lrelu(output, name='conv{}_lrelu'.format(idx + 1), is_print=True)

        # conv: (N, H/16, W/16, 512) -> (N, H/16, W/16, 1)
        output = tf_utils.conv2d(output, 1, k_h=4, k_w=4, d_h=1, d_w=1,
                                 padding='SAME', name='conv4_conv2d')

    # set reuse=True for next call
    self.reuse = True
    self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

    return tf_utils.sigmoid(output), output
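
# --- Illustration (standalone sketch, not part of the source) ---
# The self.reuse flip at the end is what lets one discriminator object be
# called on both real and fake batches while sharing a single set of weights.
# A minimal sketch of the same pattern with one dense layer (`TinyNet` is a
# hypothetical name):
import tensorflow as tf

class TinyNet(object):
    def __init__(self, name='tiny'):
        self.name = name
        self.reuse = False

    def __call__(self, x):
        with tf.variable_scope(self.name, reuse=self.reuse):
            out = tf.layers.dense(x, 1, name='fc')
        self.reuse = True  # second call reuses 'tiny/fc' instead of recreating it
        return out

net = TinyNet()
real_logits = net(tf.zeros([4, 8]))  # creates the tiny/fc variables
fake_logits = net(tf.ones([4, 8]))   # reuses the same variables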
def network(self, inputs, name=None):
    with tf.variable_scope(name):
        tf_utils.print_activations(inputs)

        # input to the main recurrent layers: masked 7x7 convolution (mask "A"
        # hides the current pixel so the model cannot see the value it predicts)
        output = tf_utils.conv2d_mask(inputs, 2 * self.hidden_dims, [7, 7],
                                      mask_type="A", name='inputConv1')

        # main recurrent layers
        if self.flags.model == 'pixelcnn':
            for idx in range(self.recurrent_length):
                output = tf_utils.conv2d_mask(output, self.hidden_dims, [3, 3],
                                              mask_type="B", name='mainConv{}'.format(idx + 2))
                output = tf_utils.relu(output, name='mainRelu{}'.format(idx + 2))
        elif self.flags.model == 'diagonal_bilstm':
            for idx in range(self.recurrent_length):
                output = self.diagonal_bilstm(output, name='BiLSTM{}'.format(idx + 2))
        elif self.flags.model == 'row_lstm':
            raise NotImplementedError
        else:
            raise NotImplementedError

        # output recurrent layers
        for idx in range(self.out_recurrent_length):
            output = tf_utils.conv2d_mask(output, self.hidden_dims, [1, 1],
                                          mask_type="B", name='outputConv{}'.format(idx + 1))
            output = tf_utils.relu(output, name='outputRelu{}'.format(idx + 1))

        # TODO: for color images, implement a 256-way softmax for each RGB channel here
        output = tf_utils.conv2d_mask(output, self.img_size[2], [1, 1],
                                      mask_type="B", name='outputConv3')

    return tf_utils.sigmoid(output), output
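
# --- Illustration (standalone sketch, not part of the source) ---
# What conv2d_mask is assumed to do internally: multiply the kernel by a
# binary mask that zeroes the weights at and after the centre pixel, so each
# output depends only on pixels above and to the left. Mask "A" (first layer
# only) also hides the centre pixel itself; mask "B" keeps it. `causal_mask`
# is a hypothetical helper name:
import numpy as np

def causal_mask(k_h, k_w, mask_type):
    mask = np.ones((k_h, k_w), dtype=np.float32)
    c_h, c_w = k_h // 2, k_w // 2
    offset = 0 if mask_type == 'A' else 1      # "A" also masks the centre
    mask[c_h, c_w + offset:] = 0.0             # centre row, at/right of centre
    mask[c_h + 1:, :] = 0.0                    # all rows below the centre
    return mask

# causal_mask(3, 3, 'A')        causal_mask(3, 3, 'B')
# [[1. 1. 1.]                   [[1. 1. 1.]
#  [1. 0. 0.]                    [1. 1. 0.]
#  [0. 0. 0.]]                   [0. 0. 0.]]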
def __call__(self, i_to_s, state, name='DiagonalBiLSTMCell'):
    # state packs cell and hidden state side by side, each [batch, height * hidden_dims]
    c_prev = tf.slice(state, begin=[0, 0], size=[-1, self._num_units])
    h_prev = tf.slice(state, begin=[0, self._num_units], size=[-1, self._num_units])

    # i_to_s: [batch, 4 * height * hidden_dims]
    input_size = i_to_s.get_shape().with_rank(2)[1]
    if input_size.value is None:
        raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    with tf.variable_scope(name):
        # state-to-state component (K^ss * h_{i-1}): a kernel-size-2 1-D convolution
        # along the column, producing a 4h x n x n tensor
        # [batch, height, 1, hidden_dims]
        conv1d_inputs = tf.reshape(h_prev, [-1, self._height, 1, self._hidden_dims],
                                   name='conv1d_inputs')
        # [batch, height, 1, hidden_dims * 4]
        conv_s_to_s = tf_utils.conv1d(conv1d_inputs, 4 * self._hidden_dims,
                                      kernel_size=2, name='s_to_s')
        # [batch, height * hidden_dims * 4]
        s_to_s = tf.reshape(conv_s_to_s, [-1, self._height * self._hidden_dims * 4])

        # NOTE: a single sigmoid is applied to all four pre-activations here;
        # a standard LSTM would apply tanh to the new input g instead
        lstm_matrix = tf_utils.sigmoid(s_to_s + i_to_s)

        # o = output gate, f = forget gate, i = input gate, g = new input
        o, f, i, g = tf.split(lstm_matrix, 4, axis=1)

        c = f * c_prev + i * g
        h = o * tf_utils.tanh(c)

    new_state = tf.concat([c, h], axis=1)
    return h, new_state
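
# --- Illustration (standalone sketch, not part of the source) ---
# The cell above assumes its input was "skewed" beforehand: row i of the
# feature map is shifted right by i positions so that each image diagonal
# lands in one column, and the size-2 column convolution in s_to_s then sees
# exactly the two previously computed neighbours. A minimal NumPy sketch of
# that skew, with `skew` being a hypothetical helper name:
import numpy as np

def skew(x):
    """[H, W] -> [H, W + H - 1]: shift row i right by i positions."""
    h, w = x.shape
    out = np.zeros((h, w + h - 1), dtype=x.dtype)
    for i in range(h):
        out[i, i:i + w] = x[i]
    return out

# skew(np.arange(9).reshape(3, 3)) gives:
# [[0 1 2 0 0]
#  [0 3 4 5 0]
#  [0 0 6 7 8]]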
def __call__(self, x):
    with tf.variable_scope(self.name, reuse=self.reuse):
        tf_utils.print_activations(x)

        # (N, H, W, C) -> (N, H/2, W/2, 64)
        conv1 = tf_utils.conv2d(x, self.ndf, k_h=4, k_w=4, d_h=2, d_w=2,
                                padding='SAME', name='conv1_conv')
        conv1 = tf_utils.lrelu(conv1, name='conv1_lrelu', is_print=True)

        # (N, H/2, W/2, 64) -> (N, H/4, W/4, 128)
        conv2 = tf_utils.conv2d(conv1, 2 * self.ndf, k_h=4, k_w=4, d_h=2, d_w=2,
                                padding='SAME', name='conv2_conv')
        conv2 = tf_utils.norm(conv2, _type='instance', _ops=self._ops, name='conv2_norm')
        conv2 = tf_utils.lrelu(conv2, name='conv2_lrelu', is_print=True)

        # (N, H/4, W/4, 128) -> (N, H/8, W/8, 256)
        conv3 = tf_utils.conv2d(conv2, 4 * self.ndf, k_h=4, k_w=4, d_h=2, d_w=2,
                                padding='SAME', name='conv3_conv')
        conv3 = tf_utils.norm(conv3, _type='instance', _ops=self._ops, name='conv3_norm')
        conv3 = tf_utils.lrelu(conv3, name='conv3_lrelu', is_print=True)

        # (N, H/8, W/8, 256) -> (N, H/16, W/16, 512)
        conv4 = tf_utils.conv2d(conv3, 8 * self.ndf, k_h=4, k_w=4, d_h=2, d_w=2,
                                padding='SAME', name='conv4_conv')
        conv4 = tf_utils.norm(conv4, _type='instance', _ops=self._ops, name='conv4_norm')
        conv4 = tf_utils.lrelu(conv4, name='conv4_lrelu', is_print=True)

        # (N, H/16, W/16, 512) -> (N, H/16, W/16, 1)
        conv5 = tf_utils.conv2d(conv4, 1, k_h=4, k_w=4, d_h=1, d_w=1,
                                padding='SAME', name='conv5_conv', is_print=True)

        if self.use_sigmoid:
            output = tf_utils.sigmoid(conv5, name='output_sigmoid', is_print=True)
        else:
            output = tf.identity(conv5, name='output_without_sigmoid')

    # set reuse=True for next call
    self.reuse = True
    self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)

    return output
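
# --- Illustration (standalone sketch, not part of the source) ---
# Why use_sigmoid is a flag: a vanilla GAN discriminator is trained with
# sigmoid cross-entropy on the raw logits (the sigmoid is folded into the
# loss for numerical stability), while an LSGAN discriminator regresses the
# unsquashed output toward 0/1 targets, so no sigmoid is wanted. The patch
# shape below is only an example:
import tensorflow as tf

logits = tf.placeholder(tf.float32, [None, 30, 30, 1], name='d_logits')

# vanilla GAN: binary cross-entropy on logits, real target = 1
gan_real_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits),
                                            logits=logits))

# LSGAN: least-squares loss directly on the linear output
lsgan_real_loss = tf.reduce_mean(tf.squared_difference(logits, 1.0))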