def network(self, inputs, name=None):
    with tf.variable_scope(name):
        tf_utils.print_activations(inputs)

        # input of main recurrent layers
        output = tf_utils.conv2d_mask(inputs, 2 * self.hidden_dims, [7, 7],
                                      mask_type="A", name='inputConv1')

        # main recurrent layers
        if self.flags.model == 'pixelcnn':
            for idx in range(self.recurrent_length):
                output = tf_utils.conv2d_mask(output, self.hidden_dims, [3, 3],
                                              mask_type="B",
                                              name='mainConv{}'.format(idx + 2))
                output = tf_utils.relu(output, name='mainRelu{}'.format(idx + 2))
        elif self.flags.model == 'diagonal_bilstm':
            for idx in range(self.recurrent_length):
                output = self.diagonal_bilstm(output, name='BiLSTM{}'.format(idx + 2))
        elif self.flags.model == 'row_lstm':
            raise NotImplementedError
        else:
            raise NotImplementedError

        # output recurrent layers
        for idx in range(self.out_recurrent_length):
            output = tf_utils.conv2d_mask(output, self.hidden_dims, [1, 1],
                                          mask_type="B",
                                          name='outputConv{}'.format(idx + 1))
            output = tf_utils.relu(output, name='outputRelu{}'.format(idx + 1))

        # TODO: for color images, implement a 256-way softmax for each RGB channel here
        output = tf_utils.conv2d_mask(output, self.img_size[2], [1, 1],
                                      mask_type="B", name='outputConv3')
        # output = tf_utils.sigmoid(output_logits, name='output_sigmoid')

        return tf_utils.sigmoid(output), output
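
# A minimal usage sketch (an assumption, not part of this model class): `network` returns
# (sigmoid activations, raw logits). For binarized images the logits are typically trained
# with a per-pixel sigmoid cross-entropy; the function name, `targets`, and the RMSProp
# learning rate below are illustrative only, assuming `tensorflow` is imported as `tf`.
def binary_pixel_loss_sketch(logits, targets, learning_rate=1e-3):
    # logits, targets: [batch, height, width, channels]; targets are 0/1 pixel values
    loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits))
    train_op = tf.train.RMSPropOptimizer(learning_rate).minimize(loss)
    return loss, train_op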
def diagonal_bilstm(self, inputs, name='diagonal_bilstm'):
    with tf.variable_scope(name):
        output_state_fw = self.diagonal_lstm(inputs, name='output_state_fw')
        output_state_bw = tf_utils.reverse(
            self.diagonal_lstm(tf_utils.reverse(inputs), name='output_state_bw'))

        # Residual connection part
        residual_state_fw = tf_utils.conv2d_mask(output_state_fw, 2 * self.hidden_dims,
                                                 [1, 1], mask_type="B", name='residual_fw')
        output_state_fw = residual_state_fw + inputs

        residual_state_bw = tf_utils.conv2d_mask(output_state_bw, 2 * self.hidden_dims,
                                                 [1, 1], mask_type="B", name='residual_bw')
        output_state_bw = residual_state_bw + inputs

        # As in the PixelRNN paper, the backward (right-to-left) output map is shifted
        # down by one row (zeros inserted at the top, last row dropped) before being
        # added to the forward map, so the backward contribution to a row comes only
        # from the rows above it.
        batch, height, width, channel = output_state_bw.get_shape().as_list()
        output_state_bw_except_last = tf.slice(output_state_bw, [0, 0, 0, 0],
                                               [-1, height - 1, -1, -1])
        output_state_bw_only_last = tf.slice(output_state_bw, [0, height - 1, 0, 0],
                                             [-1, -1, -1, -1])
        dummy_zeros = tf.zeros_like(output_state_bw_only_last)
        output_state_bw_with_last_zeros = tf.concat(
            [dummy_zeros, output_state_bw_except_last], axis=1)

        return output_state_fw + output_state_bw_with_last_zeros
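
# A small self-contained sketch (not part of the class) demonstrating the slice/concat
# shift used in `diagonal_bilstm` on a toy tensor: a zero row is inserted at the top and
# the original last row is dropped, i.e. the map is moved down by one row. The function
# name and toy shape are assumptions for illustration.
def shift_down_one_row_sketch():
    x = tf.reshape(tf.range(6, dtype=tf.float32), [1, 3, 2, 1])  # [batch, height, width, channel]
    height = x.get_shape().as_list()[1]
    except_last = tf.slice(x, [0, 0, 0, 0], [-1, height - 1, -1, -1])
    only_last = tf.slice(x, [0, height - 1, 0, 0], [-1, -1, -1, -1])
    shifted = tf.concat([tf.zeros_like(only_last), except_last], axis=1)
    return shifted  # row 0 is all zeros, rows 1..2 hold the original rows 0..1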
def diagonal_lstm(self, inputs, name='diagonal_lstm'):
    with tf.variable_scope(name):
        skewed_inputs = tf_utils.skew(inputs, name='skewed_i')

        # input-to-state (K_is * x_i): 1x1 convolution, generating a 4h x n x n tensor
        input_to_state = tf_utils.conv2d_mask(skewed_inputs, 4 * self.hidden_dims, [1, 1],
                                              mask_type="B", name="i_to_s")

        # [batch, width, height, hidden_dims * 4]
        column_wise_inputs = tf.transpose(input_to_state, [0, 2, 1, 3])
        batch, width, height, channel = tf_utils.get_shape(column_wise_inputs)

        # [batch, max_time, height * hidden_dims * 4]
        rnn_inputs = tf.reshape(column_wise_inputs, [-1, width, height * channel])
        # rnn_input_list = [tf.squeeze(rnn_input, axis=[1])
        #                   for rnn_input in tf.split(rnn_inputs, width, axis=1)]

        cell = DiagonalLSTMCell(self.hidden_dims, height, channel)

        # [batch, width, height * hidden_dims]
        outputs, states = tf.nn.dynamic_rnn(cell=cell, inputs=rnn_inputs, dtype=tf.float32)
        packed_outputs = outputs

        # [batch, width, height, hidden_dims]
        width_first_output = tf.reshape(packed_outputs,
                                        [-1, width, height, self.hidden_dims])

        # [batch, height, width, hidden_dims]
        skewed_outputs = tf.transpose(width_first_output, [0, 2, 1, 3])
        outputs = tf_utils.unskew(skewed_outputs)

        return outputs
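
# The skew/unskew helpers live in tf_utils and are not shown here. Below is a minimal
# NumPy sketch of the operation as described in the PixelRNN paper (an assumption about
# what tf_utils implements, not the repo's code): row i is offset i positions to the
# right so that image diagonals become columns, which the diagonal LSTM above can then
# sweep one column (time step) at a time.
import numpy as np

def skew_sketch(x):
    # x: [batch, height, width, channels] -> [batch, height, width + height - 1, channels]
    batch, height, width, channels = x.shape
    out = np.zeros((batch, height, width + height - 1, channels), dtype=x.dtype)
    for row in range(height):
        out[:, row, row:row + width, :] = x[:, row, :, :]
    return out

def unskew_sketch(x, width):
    # inverse of skew_sketch: recover the original [batch, height, width, channels] map
    batch, height, _, channels = x.shape
    out = np.zeros((batch, height, width, channels), dtype=x.dtype)
    for row in range(height):
        out[:, row, :, :] = x[:, row, row:row + width, :]
    return out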