Example #1
    def tdnn_block(self, inputs):
        ''' TDNN layers. '''
        # Default to 'splice_layer': it runs faster and supports discrete contexts, for now.
        tdnn_method = self.netconf.get('tdnn_method', 'splice_layer')
        tdnn_contexts = self.netconf['tdnn_contexts']
        logging.info("tdnn_contexts : {}".format(tdnn_contexts))
        tdnn_dims = self.netconf['tdnn_dims']
        logging.info("tdnn_dims : {}".format(tdnn_dims))

        layer_num = len(tdnn_contexts)
        assert layer_num == len(tdnn_dims)

        channels = [self.input_channels] + tdnn_dims
        logging.info("tdnn_channels : {}".format(channels))

        # The time dimension may be dynamic, so read it from tf.shape at runtime;
        # the width and channel dimensions are static.
        input_h_t = tf.shape(inputs)[1]
        input_w = inputs.shape[2]
        input_c = inputs.shape[3]
        if tdnn_method == 'conv1d':
            # NHWC -> NW'C, W' = H * W
            inputs = tf.reshape(inputs, [-1, input_h_t * input_w, input_c])
            last_w = channels[0]
        else:
            # Keep the time axis and flatten the features: NHWC -> NH(W*C).
            inputs = tf.reshape(inputs, [-1, input_h_t, input_w * input_c])
            last_w = input_w * input_c

        downsample_input_len = self.input_len
        with tf.variable_scope('tdnn'):
            x = tf.identity(inputs)
            for index in range(layer_num):
                unit_name = 'unit-' + str(index + 1)
                with tf.variable_scope(unit_name):
                    tdnn_name = 'tdnn-' + str(index + 1)
                    x = common_layers.tdnn(x,
                                           tdnn_name,
                                           last_w,
                                           tdnn_contexts[index],
                                           channels[index + 1],
                                           has_bias=True,
                                           method=tdnn_method)
                    last_w = channels[index + 1]
                    x = tf.nn.relu(x)
                    if self.netconf['use_bn']:
                        bn_name = 'bn' + str(index + 1)
                        # tf.layers.batch_normalization registers its moving-average
                        # updates in tf.GraphKeys.UPDATE_OPS; the train op must run them.
                        x = tf.layers.batch_normalization(x,
                                                          axis=-1,
                                                          momentum=0.9,
                                                          training=self.train,
                                                          name=bn_name)
                    if self.netconf['use_dropout']:
                        x = tf.layers.dropout(x,
                                              self.netconf['dropout_rate'],
                                              training=self.train)
                    # The TDNN stack keeps the time resolution, so the
                    # downsampled length passes through unchanged.

        return x, downsample_input_len
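
The layer itself is delegated to common_layers.tdnn, which is not shown here. As a mental model for the default 'splice_layer' method, a TDNN layer gathers frames at a set of time offsets and applies one shared dense projection to their concatenation. The sketch below is a hypothetical reconstruction under that assumption; splice_tdnn is not the real common_layers.tdnn, and its edge handling may differ.

import tensorflow as tf

def splice_tdnn(x, name, in_dim, context, out_dim, has_bias=True):
    ''' Hypothetical splice-based TDNN layer.
        x: [B, T, in_dim]; context: frame offsets, e.g. [-2, 0, 2]. '''
    with tf.variable_scope(name):
        t = tf.shape(x)[1]
        spliced = []
        for offset in context:
            # Shift the time axis by `offset`, clamping indices at the
            # boundaries (edge frames are repeated).
            indices = tf.clip_by_value(tf.range(t) + offset, 0, t - 1)
            spliced.append(tf.gather(x, indices, axis=1))
        # Concatenate along features: [B, T, len(context) * in_dim].
        spliced = tf.concat(spliced, axis=-1)
        w = tf.get_variable('w', [len(context) * in_dim, out_dim])
        y = tf.einsum('btd,do->bto', spliced, w)
        if has_bias:
            b = tf.get_variable('b', [out_dim],
                                initializer=tf.zeros_initializer())
            y = y + b
        return y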
Example #2
    def tdnn_block(self, inputs):
        ''' TDNN layers. '''
        tdnn_contexts = self.netconf['tdnn_contexts']
        logging.info("tdnn_contexts : {}".format(tdnn_contexts))
        tdnn_dims = self.netconf['tdnn_dims']
        logging.info("tdnn_dims : {}".format(tdnn_dims))

        layer_num = len(tdnn_contexts)
        assert layer_num == len(tdnn_dims)

        channels = [self.input_channels] + tdnn_dims
        logging.info("tdnn_channels : {}".format(channels))

        # NHWC -> NW'C, W' = H * W. This uses static shapes only, so the
        # time dimension must be known at graph-build time.
        _, input_h, input_w, input_c = inputs.shape.as_list()
        inputs = tf.reshape(inputs, [-1, input_h * input_w, input_c])

        downsample_input_len = self.input_len
        with tf.variable_scope('tdnn'):
            x = tf.identity(inputs)
            for index in range(layer_num):
                unit_name = 'unit-' + str(index + 1)
                with tf.variable_scope(unit_name):
                    tdnn_name = 'tdnn-' + str(index + 1)
                    use_bn = self.netconf['use_bn']
                    # Batch norm's learned beta makes a preceding bias
                    # redundant, so the bias is enabled only without it.
                    has_bias = not use_bn
                    x = common_layers.tdnn(x,
                                           tdnn_name,
                                           channels[index],
                                           tdnn_contexts[index],
                                           channels[index + 1],
                                           has_bias=has_bias)
                    x = tf.nn.relu(x)
                    if use_bn:
                        bn_name = 'bn' + str(index + 1)
                        x = tf.layers.batch_normalization(x,
                                                          axis=-1,
                                                          momentum=0.9,
                                                          training=self.train,
                                                          name=bn_name)
                    if self.netconf['use_dropout']:
                        x = tf.layers.dropout(x,
                                              self.netconf['dropout_rate'],
                                              training=self.train)
                    # The TDNN stack keeps the time resolution, so the
                    # downsampled length passes through unchanged.

        return x, downsample_input_len
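
Example #1 also accepts tdnn_method='conv1d'. When the context offsets are evenly spaced, the same layer can be viewed as a 1-D convolution over time, with the spacing mapping to a dilation rate. A minimal sketch under that assumption (again hypothetical, not the real common_layers.tdnn):

import tensorflow as tf

def conv1d_tdnn(x, name, context, out_dim, has_bias=True):
    ''' Hypothetical conv1d view of a TDNN layer.
        x: [B, T, D]; context must be evenly spaced,
        e.g. [-2, 0, 2] -> kernel_size=3, dilation_rate=2. '''
    steps = sorted(context)
    dilation = steps[1] - steps[0] if len(steps) > 1 else 1
    # 'same' padding keeps the time length, matching the unchanged
    # downsample_input_len returned by tdnn_block above.
    return tf.layers.conv1d(x,
                            filters=out_dim,
                            kernel_size=len(steps),
                            dilation_rate=dilation,
                            padding='same',
                            use_bias=has_bias,
                            name=name)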