    def network(self, inputs, name=None):
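        """Builds the PixelCNN/PixelRNN body: a masked 7x7 input convolution,
        the recurrent core selected by `self.flags.model`, and the 1x1 output
        layers. Returns (sigmoid(logits), logits)."""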
        with tf.variable_scope(name):
            tf_utils.print_activations(inputs)

            # input to the main recurrent layers: 7x7 masked conv (mask "A",
            # so a pixel never sees itself), producing 2h feature maps
            output = tf_utils.conv2d_mask(inputs,
                                          2 * self.hidden_dims, [7, 7],
                                          mask_type="A",
                                          name='inputConv1')

            # main recurrent layers
            if self.flags.model == 'pixelcnn':
                for idx in range(self.recurrent_length):
                    output = tf_utils.conv2d_mask(
                        output,
                        self.hidden_dims, [3, 3],
                        mask_type="B",
                        name='mainConv{}'.format(idx + 2))
                    output = tf_utils.relu(output,
                                           name='mainRelu{}'.format(idx + 2))
            elif self.flags.model == 'diagonal_bilstm':
                for idx in range(self.recurrent_length):
                    output = self.diagonal_bilstm(
                        output, name='BiLSTM{}'.format(idx + 2))
            elif self.flags.model == 'row_lstm':
                raise NotImplementedError('row_lstm is not implemented yet')
            else:
                raise ValueError(
                    'unknown model: {}'.format(self.flags.model))

            # output layers: 1x1 masked convolutions with ReLU
            for idx in range(self.out_recurrent_length):
                output = tf_utils.conv2d_mask(
                    output,
                    self.hidden_dims, [1, 1],
                    mask_type="B",
                    name='outputConv{}'.format(idx + 1))
                output = tf_utils.relu(output,
                                       name='outputRelu{}'.format(idx + 1))

            # TODO: for color images, implement a 256-way softmax for each RGB channel here
            # final 1x1 masked conv producing the per-pixel logits; numbering
            # continues the outputConv sequence so scope names stay unique if
            # out_recurrent_length changes
            output_logits = tf_utils.conv2d_mask(
                output,
                self.img_size[2], [1, 1],
                mask_type="B",
                name='outputConv{}'.format(self.out_recurrent_length + 1))

            return tf_utils.sigmoid(output_logits), output_logits
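
    # For reference, a minimal sketch of how the "A"/"B" convolution masks used
    # by tf_utils.conv2d_mask are assumed to be built (van den Oord et al.,
    # 2016). Mask "A" also zeroes the centre tap so the first layer cannot see
    # the pixel it predicts; mask "B" keeps the centre. The helper name and
    # signature below are hypothetical, not part of tf_utils; e.g.
    # _make_conv_mask_sketch(7, 7, "A") gives the input-layer mask.
    @staticmethod
    def _make_conv_mask_sketch(kernel_h, kernel_w, mask_type):
        import numpy as np  # local import keeps the sketch self-contained
        mask = np.ones((kernel_h, kernel_w), dtype=np.float32)
        center_h, center_w = kernel_h // 2, kernel_w // 2
        mask[center_h, center_w + 1:] = 0.  # taps right of the centre
        mask[center_h + 1:, :] = 0.         # rows below the centre
        if mask_type == "A":
            mask[center_h, center_w] = 0.   # "A" blocks the centre tap too
        return mask
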
    def diagonal_bilstm(self, inputs, name='diagonal_bilstm'):
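        """Diagonal BiLSTM layer: runs a diagonal LSTM left-to-right and
        right-to-left, adds residual connections, and merges the two streams
        with the backward map shifted down one row (Pixel RNN, sec. 3.1)."""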
        with tf.variable_scope(name):
            # forward stream scans the image left-to-right
            output_state_fw = self.diagonal_lstm(inputs,
                                                 name='output_state_fw')
            # backward stream: flip along the width, scan, then flip back
            output_state_bw = tf_utils.reverse(
                self.diagonal_lstm(tf_utils.reverse(inputs),
                                   name='output_state_bw'))

            # residual connections: 1x1 masked conv on each stream, then a
            # skip connection from the block input
            residual_state_fw = tf_utils.conv2d_mask(output_state_fw,
                                                     2 * self.hidden_dims,
                                                     [1, 1],
                                                     mask_type="B",
                                                     name='residual_fw')
            output_state_fw = residual_state_fw + inputs

            residual_state_bw = tf_utils.conv2d_mask(output_state_bw,
                                                     2 * self.hidden_dims,
                                                     [1, 1],
                                                     mask_type="B",
                                                     name='residual_bw')
            output_state_bw = residual_state_bw + inputs

            # shift the backward output map down one row so its contribution
            # at row i only covers rows above i, keeping the merged layer
            # causal in raster order
            _, height, _, _ = output_state_bw.get_shape().as_list()
            output_state_bw_except_last = tf.slice(output_state_bw,
                                                   [0, 0, 0, 0],
                                                   [-1, height - 1, -1, -1])
            output_state_bw_only_last = tf.slice(output_state_bw,
                                                 [0, height - 1, 0, 0],
                                                 [-1, -1, -1, -1])
            dummy_zeros = tf.zeros_like(output_state_bw_only_last)

            output_state_bw_with_last_zeros = tf.concat(
                [dummy_zeros, output_state_bw_except_last], axis=1)
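            # (equivalently, as a sketch: pad one zero row on top and crop,
            #  tf.pad(output_state_bw, [[0, 0], [1, 0], [0, 0], [0, 0]])[:, :height, :, :])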

            return output_state_fw + output_state_bw_with_last_zeros

    def diagonal_lstm(self, inputs, name='diagonal_lstm'):
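        """Diagonal LSTM: skews the input so each diagonal becomes a column,
        runs an LSTM column-by-column with dynamic_rnn, then unskews the
        output back to the original layout."""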
        with tf.variable_scope(name):
            skewed_inputs = tf_utils.skew(inputs, name='skewed_i')

            # input-to-state component (K_is * x_i): a 1x1 masked convolution
            # producing the 4h feature maps for the four LSTM gates
            input_to_state = tf_utils.conv2d_mask(skewed_inputs,
                                                  4 * self.hidden_dims, [1, 1],
                                                  mask_type="B",
                                                  name="i_to_s")
            # column-major order so the RNN steps over (skewed) columns:
            # [batch, skewed_width, height, 4 * hidden_dims]
            column_wise_inputs = tf.transpose(input_to_state, [0, 2, 1, 3])

            _, width, height, channel = tf_utils.get_shape(column_wise_inputs)
            # flatten each column into one RNN time step:
            # [batch, max_time (skewed width), height * 4 * hidden_dims]
            rnn_inputs = tf.reshape(column_wise_inputs,
                                    [-1, width, height * channel])
            cell = DiagonalLSTMCell(self.hidden_dims, height, channel)
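            # DiagonalLSTMCell is assumed to be this repo's custom RNNCell: it
            # consumes one flattened skewed column per step and carries the
            # whole column's hidden/cell state ([batch, height * hidden_dims])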

            # run the cell over the skewed columns:
            # [batch, max_time, height * hidden_dims]
            outputs, _ = tf.nn.dynamic_rnn(cell=cell,
                                           inputs=rnn_inputs,
                                           dtype=tf.float32)

            # restore the spatial axes: [batch, width, height, hidden_dims]
            width_first_output = tf.reshape(
                outputs, [-1, width, height, self.hidden_dims])

            skewed_outputs = tf.transpose(
                width_first_output,
                [0, 2, 1, 3])  # [batch, height, width, hidden_dims]
            outputs = tf_utils.unskew(skewed_outputs)

        return outputs
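
    # A minimal sketch of what tf_utils.skew / tf_utils.unskew are assumed to
    # do (an assumption, not the repo's actual implementation): skewing offsets
    # row i by i columns so every diagonal of the input becomes a column that
    # the diagonal LSTM above can process in one step; unskewing inverts it.
    @staticmethod
    def _skew_sketch(inputs):
        # inputs: [batch, height, width, channels]
        _, height, width, _ = inputs.get_shape().as_list()
        rows = tf.split(inputs, height, axis=1)
        # shift row i right by i columns; result width is width + height - 1
        padded = [tf.pad(row, [[0, 0], [0, 0], [i, height - 1 - i], [0, 0]])
                  for i, row in enumerate(rows)]
        return tf.concat(padded, axis=1)

    @staticmethod
    def _unskew_sketch(skewed):
        # skewed: [batch, height, width + height - 1, channels]
        _, height, skewed_width, _ = skewed.get_shape().as_list()
        width = skewed_width - height + 1
        rows = tf.split(skewed, height, axis=1)
        # row i of the original image sits at columns i .. i + width - 1
        return tf.concat([tf.slice(row, [0, 0, i, 0], [-1, -1, width, -1])
                          for i, row in enumerate(rows)], axis=1)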