Example #1
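Two methods from a TF 1.x model class. conv_block stacks 2D convolution units (conv -> ReLU -> optional batch norm and dropout -> max pooling) and tracks how far the time axis is downsampled; pooling_layer collapses the time axis with either attention pooling or plain max pooling. The excerpt assumes import logging, import tensorflow as tf (1.x API), and a project-local common_layers module providing conv2d, max_pool, and attention.
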
    def conv_block(self, inputs, depthwise=False):
        ''' 2D conv stack: conv -> ReLU -> optional batch norm / dropout -> max pool. '''
        filters = self.netconf['filters']
        logging.info("filters : {}".format(filters))
        filters_size = self.netconf['filter_size']
        logging.info("filters_size : {}".format(filters_size))
        filters_strides = self.netconf['filter_stride']
        logging.info("filters_strides : {}".format(filters_strides))
        pools_size = self.netconf['pool_size']
        logging.info("pools_size : {}".format(pools_size))

        layer_num = len(filters)
        assert layer_num == len(filters_size)
        assert layer_num == len(filters_strides)
        assert layer_num == len(pools_size)

        channels = [self.input_channels] + filters
        logging.info("channels : {}".format(channels))

        downsample_input_len = self.input_len
        with tf.variable_scope('cnn'):
            x = tf.identity(inputs)
            for index, filt in enumerate(filters):
                unit_name = 'unit-' + str(index + 1)
                with tf.variable_scope(unit_name):
                    if depthwise:
                        x = tf.layers.separable_conv2d(
                            x,
                            filters=filt,
                            kernel_size=filters_size[index],
                            strides=filters_strides[index],
                            padding='same',
                            name=unit_name)
                    else:
                        cnn_name = 'cnn-' + str(index + 1)
                        x = common_layers.conv2d(x, cnn_name,
                                                 filters_size[index],
                                                 channels[index],
                                                 channels[index + 1],
                                                 filters_strides[index])
                    x = tf.nn.relu(x)
                    if self.netconf['use_bn']:
                        bn_name = 'bn' + str(index + 1)
                        x = tf.layers.batch_normalization(x,
                                                          axis=-1,
                                                          momentum=0.9,
                                                          training=self.train,
                                                          name=bn_name)
                    if self.netconf['use_dropout']:
                        x = tf.layers.dropout(x,
                                              self.netconf['dropout_rate'],
                                              training=self.train)
                    x = common_layers.max_pool(x, pools_size[index],
                                               pools_size[index])
                    # each unit shrinks the time axis by its pool size; note this
                    # is true division, so cast to int before using it as a shape
                    downsample_input_len = downsample_input_len / pools_size[index][0]

        return x, downsample_input_len
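
    # Worked example of the returned length (values assumed for illustration):
    # with input_len = 200 and pools_size = [[2, 2], [2, 2]], each unit halves
    # the time axis, so conv_block returns downsample_input_len = 200/2/2 = 50.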

    def pooling_layer(self, x, time_len):
        ''' Pool over the time axis, via attention or plain max pooling. '''
        with tf.variable_scope('time_pooling'):
            if self.attention:
                x, self.alphas = common_layers.attention(
                    x, self.netconf['attention_size'], return_alphas=True)
                # alphas: [batch, time, 1] -> [1, batch, time, 1] -> [1, time, batch, 1]
                tf.summary.image(
                    'alignment',
                    tf.transpose(tf.expand_dims(self.alphas, 0), [0, 2, 1, 3]))
            else:
                if self.netconf['use_lstm_layer']:
                    # merge forward/backward LSTM outputs: [batch, time, 2 * cell_num]
                    x = tf.concat(x, 2)
                # [batch, time, dim] -> [batch, time, dim, 1]
                x = tf.expand_dims(x, axis=-1)
                seq_len = time_len
                x = common_layers.max_pool(x, ksize=[seq_len, 1], strides=[seq_len, 1])
                if self.netconf['use_lstm_layer']:
                    x = tf.reshape(x, [-1, 2 * self.netconf['cell_num']])
                else:
                    x = tf.reshape(x, [-1, self.netconf['linear_num']])
            return x
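
A minimal usage sketch; everything below is assumed for illustration (the model
instance, the config values, and the input shape), and the projection between
the two methods is elided:

netconf = {
    'filters': [32, 64],
    'filter_size': [[3, 3], [3, 3]],
    'filter_stride': [[1, 1], [1, 1]],
    'pool_size': [[2, 2], [2, 2]],
    'use_bn': True,
    'use_dropout': True,
    'dropout_rate': 0.1,
    'use_lstm_layer': False,
    'linear_num': 128,
    'attention_size': 64,
}

# inputs: [batch, time, feature, channel], e.g. 200 frames of 40-dim fbank
inputs = tf.placeholder(tf.float32, [None, 200, 40, 1])

# model is a hypothetical instance configured with the dict above
x, time_len = model.conv_block(inputs, depthwise=False)  # time_len = 200/2/2 = 50
# ... reshape/project x to [batch, time_len, linear_num] before pooling ...
embedding = model.pooling_layer(x, time_len)  # -> [batch, linear_num]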