def _create_prior_net(self, subnet_unit: ReturnnNetwork):
    prior_att_input = self._add_prior_input(subnet_unit)

    # for the first frame in decoding, don't use average but zero always
    is_first_frame = subnet_unit.add_compare_layer('is_first_frame', source=':i', kind='equal', value=0)
    zero_att = subnet_unit.add_eval_layer('zero_att', 'att', eval='tf.zeros_like(source(0))')
    prev_att = subnet_unit.add_switch_layer(
        'prev_att', condition=is_first_frame, true_from=zero_att, false_from=prior_att_input)

    key_names = ['s', 'readout_in', 'readout', 'output_prob']
    for key_name in key_names:
        d = copy.deepcopy(subnet_unit[key_name])
        # update attention input
        new_sources = []
        from_list = d['from']
        if isinstance(from_list, str):
            from_list = [from_list]
        assert isinstance(from_list, list)
        for src in from_list:
            if 'att' in src:
                if src.split(':')[0] == 'prev':
                    assert prev_att not in new_sources
                    new_sources += [prev_att]  # switched based on decoder index
                else:
                    new_sources += [prior_att_input]
            elif src in key_names:
                new_sources += [('prev:' if 'prev' in src else '') + 'prior_{}'.format(src.split(':')[-1])]
            else:
                new_sources += [src]
        d['from'] = new_sources
        subnet_unit['prior_{}'.format(key_name)] = d

    return 'prior_output_prob'
def create(self):
    out_net = ReturnnNetwork()
    pad_left = out_net.add_pad_layer(
        'feedback_pad_left', 'prev:att_weights', axes='s:0',
        padding=((self.filter_size - 1) // 2, 0), value=0)
    pad_right = out_net.add_pad_layer(
        'feedback_pad_right', pad_left, axes='s:0',
        padding=(0, (self.filter_size - 1) // 2), value=0)
    loc_att_conv = out_net.add_conv_layer(
        'loc_att_conv', pad_right, activation=None, with_bias=False,
        filter_size=(self.filter_size,), padding='valid', n_out=self.num_channels, l2=self.l2)
    self.name = out_net.add_linear_layer(
        'weight_feedback', loc_att_conv, activation=None, with_bias=False, n_out=self.enc_key_dim)
    return out_net.get_net()
def _add_prior_input(self, subnet_unit: ReturnnNetwork):
    prior_type = self.prior_lm_opts.get('type', None)
    assert prior_type is not None, 'prior_type not defined'

    if prior_type == 'mini_lstm':
        # add mini lstm layers
        subnet_unit.add_rec_layer(
            'mini_att_lstm', 'prev:' + self.prior_lm_opts.get('target_embed_name', 'target_embed'),
            n_out=self.prior_lm_opts.get('mini_lstm_dim', 50), l2=self.prior_lm_opts.get('l2', 0.0))
        prior_att_input = subnet_unit.add_linear_layer(
            'mini_att', 'mini_att_lstm', activation=None, n_out=512, l2=0.0001)
    elif prior_type == 'zero':
        prior_att_input = subnet_unit.add_eval_layer(
            'zero_att', 'transformer_decoder_01_att', eval='tf.zeros_like(source(0))')
    else:
        raise ValueError()

    return prior_att_input
def create(self):
    out_net = ReturnnNetwork()
    out_net.add_eval_layer(
        'accum_att_weights', ["prev:accum_att_weights", "att_weights", "base:inv_fertility"],
        eval='source(0) + source(1) * source(2) * 0.5',
        out_type={"dim": self.att_num_heads, "shape": (None, self.att_num_heads)})
    self.name = out_net.add_linear_layer(
        'weight_feedback', 'prev:accum_att_weights', n_out=self.enc_key_dim, with_bias=False)
    return out_net.get_net()
def _add_prior_input(self, subnet_unit: ReturnnNetwork):
    prior_type = self.prior_lm_opts['type']
    assert prior_type == 'mini_lstm'

    num_layers = self.prior_lm_opts['dec_layers']
    assert num_layers > 0

    variant = self.prior_lm_opts['mini_lstm_variant']
    assert variant in ['single', 'many']

    if variant == 'single':
        subnet_unit.add_rec_layer(
            'mini_att_lstm', 'prev:' + self.prior_lm_opts.get('target_embed_name', 'target_embed'),
            n_out=self.prior_lm_opts.get('mini_lstm_dim', 50), l2=self.prior_lm_opts.get('l2', 0.0))
    else:
        for i in range(1, num_layers + 1):
            subnet_unit.add_rec_layer(
                'mini_att_lstm_%02i' % i,
                'prev:' + self.prior_lm_opts.get('target_embed_name', 'target_embed'),
                n_out=self.prior_lm_opts.get('mini_lstm_dim', 50), l2=self.prior_lm_opts.get('l2', 0.0))

    for i in range(1, num_layers + 1):
        subnet_unit.add_linear_layer(
            'mini_att_%02i' % i,
            'mini_att_lstm_%02i' % i if variant == 'many' else 'mini_att_lstm',
            activation=None, n_out=512, l2=0.0001)
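# A hypothetical example of the `prior_lm_opts` dict that the `_add_prior_input`
# variants above read from. The keys mirror the lookups in the code ('type',
# 'dec_layers', 'mini_lstm_variant', 'target_embed_name', 'mini_lstm_dim', 'l2',
# and 'scale' used later in the fusion); the concrete values are illustrative only.
example_prior_lm_opts = {
    'type': 'mini_lstm',          # or 'zero' for the zero-attention ILM variant
    'dec_layers': 6,              # number of decoder layers to attach mini-attention to
    'mini_lstm_variant': 'many',  # 'single': one shared mini-LSTM, 'many': one per layer
    'target_embed_name': 'target_embed',
    'mini_lstm_dim': 50,
    'l2': 0.0,
    'scale': 0.3,                 # ILM subtraction scale used in the fusion eval string
}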
class TransformerLM:
    def __init__(self, source='data:delayed', target='data', num_layers=6, ff_dim=4096, att_num_heads=8,
                 out_dim=1024, qk_dim=1024, v_dim=1024, dropout=0.0, att_dropout=0.0, embed_dropout=0.0,
                 embed_dim=128, emb_cpu_lookup=True, forward_weights_init=None, prefix_name=None,
                 use_as_ext_lm=False, vocab_size=None):
        self.source = source
        self.target = target
        self.num_layers = num_layers
        self.ff_dim = ff_dim
        self.att_num_heads = att_num_heads
        self.out_dim = out_dim
        self.qk_dim = qk_dim
        self.v_dim = v_dim
        self.dropout = dropout
        self.embed_dropout = embed_dropout
        self.att_dropout = att_dropout
        self.embed_dim = embed_dim
        self.emb_cpu_lookup = emb_cpu_lookup

        # use this as default for now
        if forward_weights_init is None:
            forward_weights_init = "variance_scaling_initializer(mode='fan_in', distribution='uniform', scale=1.0)"
        self.forward_weights_init = forward_weights_init

        self.use_as_ext_lm = use_as_ext_lm
        self.vocab_size = vocab_size

        if not prefix_name:
            prefix_name = ''
        self.prefix_name = prefix_name

        self.network = ReturnnNetwork()

    def _create_ff_block(self, subnet_unit: ReturnnNetwork, source, prefix):
        prefix = '{}_ff'.format(prefix)
        ln = subnet_unit.add_layer_norm_layer('{}_laynorm'.format(prefix), source)
        conv1 = subnet_unit.add_linear_layer(
            '{}_conv1'.format(prefix), ln, with_bias=True, activation='relu',
            forward_weights_init=self.forward_weights_init, n_out=self.ff_dim)
        conv2 = subnet_unit.add_linear_layer(
            '{}_conv2'.format(prefix), conv1, with_bias=True, activation=None,
            forward_weights_init=self.forward_weights_init, n_out=self.out_dim)
        drop = subnet_unit.add_dropout_layer('{}_drop'.format(prefix), conv2, dropout=self.dropout)
        out = subnet_unit.add_combine_layer('{}_out'.format(prefix), [drop, source], kind='add', n_out=self.out_dim)
        return out

    def _create_masked_mhsa(self, subnet_unit: ReturnnNetwork, source, prefix):
        prefix = '{}_self_att'.format(prefix)
        ln = subnet_unit.add_layer_norm_layer('{}_laynorm'.format(prefix), source)
        att = subnet_unit.add_self_att_layer(
            '{}_att'.format(prefix), ln, forward_weights_init=self.forward_weights_init,
            att_dropout=self.att_dropout, attention_left_only=True, n_out=self.v_dim,
            num_heads=self.att_num_heads, total_key_dim=self.qk_dim)
        lin = subnet_unit.add_linear_layer(
            '{}_lin'.format(prefix), att, n_out=self.out_dim, with_bias=False,
            forward_weights_init=self.forward_weights_init)
        drop = subnet_unit.add_dropout_layer('{}_drop'.format(prefix), lin, dropout=self.dropout)
        out = subnet_unit.add_combine_layer('{}_out'.format(prefix), [drop, source], kind='add', n_out=self.out_dim)
        return out

    def _create_decoder_block(self, subnet_unit: ReturnnNetwork, source, i):
        prefix = self.prefix_name + ('dec_%i' % i)
        masked_mhsa = self._create_masked_mhsa(subnet_unit, source, prefix)
        ff = self._create_ff_block(subnet_unit, masked_mhsa, prefix)
        out = subnet_unit.add_copy_layer(prefix, ff)
        return out

    def create_network(self):
        subnet_unit = ReturnnNetwork()

        target_embed_raw = subnet_unit.add_linear_layer(
            '{}target_embed_raw'.format(self.prefix_name), self.source,
            forward_weights_init=self.forward_weights_init, n_out=self.embed_dim, with_bias=False,
            param_device='CPU' if self.emb_cpu_lookup else None)
        target_embed_with_pos = subnet_unit.add_pos_encoding_layer(
            '{}target_embed_with_pos'.format(self.prefix_name), target_embed_raw)
        target_embed = subnet_unit.add_dropout_layer(
            '{}target_embed'.format(self.prefix_name), target_embed_with_pos, dropout=self.embed_dropout)
        target_embed_lin = subnet_unit.add_linear_layer(
            '{}target_embed_lin'.format(self.prefix_name), target_embed, with_bias=False,
            forward_weights_init=self.forward_weights_init, n_out=self.out_dim)

        x = target_embed_lin
        for i in range(self.num_layers):
            x = self._create_decoder_block(subnet_unit, x, i)

        # final LN
        decoder = subnet_unit.add_layer_norm_layer('{}decoder'.format(self.prefix_name), x)

        subnet_unit.add_softmax_layer(
            '{}output'.format(self.prefix_name), decoder, forward_weights_init=self.forward_weights_init,
            loss='ce', target=self.target, with_bias=True, dropout=self.dropout)

        if self.use_as_ext_lm:
            self.network = copy.deepcopy(subnet_unit)
        else:
            self.network.add_subnet_rec_layer(
                'output', unit=subnet_unit.get_net(), target=self.target, source=self.source)

        return 'output'
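# Minimal usage sketch (an assumption, not part of the original code base): build the
# Transformer LM with `use_as_ext_lm=True` so that `self.network` holds the plain
# (non-rec-wrapped) subnetwork, which can then be passed as 'lm_subnet' in `ext_lm_opts`
# for shallow fusion. All hyperparameter values here are illustrative.
lm = TransformerLM(
    source='data:delayed', target='data', num_layers=6, ff_dim=4096,
    att_num_heads=8, out_dim=1024, qk_dim=1024, v_dim=1024,
    embed_dim=128, dropout=0.0, use_as_ext_lm=True)
lm.create_network()
ext_lm_subnet = lm.network.get_net()  # dict of RETURNN layers, usable as 'lm_subnet'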
def create(self):
    out_net = ReturnnNetwork()

    out_net.add_linear_layer('s_transformed', 's', n_out=self.enc_key_dim, with_bias=False, l2=self.l2)  # project query

    if self.loc_num_channels is not None:
        assert self.loc_filter_size is not None
        weight_feedback = ConvLocAwareness(
            enc_key_dim=self.enc_key_dim, filter_size=self.loc_filter_size,
            num_channels=self.loc_num_channels, l2=self.l2)
    else:
        # additive
        weight_feedback = AdditiveLocAwareness(
            enc_key_dim=self.enc_key_dim, att_num_heads=self.att_num_heads)

    out_net.update(weight_feedback.create())  # add att weight feedback

    out_net.add_combine_layer(
        'energy_in', ['base:enc_ctx', weight_feedback.name, 's_transformed'],
        kind='add', n_out=self.enc_key_dim)  # compute energies
    out_net.add_activation_layer('energy_tanh', 'energy_in', activation='tanh')
    energy = out_net.add_linear_layer(
        'energy', 'energy_tanh', n_out=self.att_num_heads, with_bias=False, l2=self.l2)

    if self.att_dropout:
        att_weights0 = out_net.add_softmax_over_spatial_layer('att_weights0', energy)
        att_weights = out_net.add_dropout_layer(
            'att_weights', att_weights0, dropout=self.att_dropout, dropout_noise_shape={'*': None})
    else:
        att_weights = out_net.add_softmax_over_spatial_layer('att_weights', energy)

    att0 = out_net.add_generic_att_layer('att0', weights=att_weights, base='base:enc_value')
    self.name = out_net.add_merge_dims_layer('att', att0, axes='except_batch')

    return out_net.get_net()
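# Plain-NumPy sketch of the computation the attention graph above describes
# (single head, single batch entry). This is an illustration under assumed shapes,
# not code from the original recipe; the comments name the layers created in `create()`.
import numpy as np

def additive_attention_sketch(enc_ctx, enc_value, s_transformed, weight_feedback, v):
    # enc_ctx:         (T, enc_key_dim)  -> 'base:enc_ctx'
    # enc_value:       (T, value_dim)    -> 'base:enc_value'
    # s_transformed:   (enc_key_dim,)    -> 's_transformed'
    # weight_feedback: (T, enc_key_dim)  -> output of the conv/additive feedback net
    # v:               (enc_key_dim,)    -> weights of the 'energy' linear layer
    energy_in = enc_ctx + weight_feedback + s_transformed  # 'energy_in' (additive combine)
    energy = np.tanh(energy_in) @ v                        # 'energy_tanh' + 'energy'
    att_weights = np.exp(energy - energy.max())
    att_weights /= att_weights.sum()                       # 'att_weights' (softmax over time)
    return att_weights @ enc_value                         # 'att0' (weighted sum of values)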
def _create_external_lm_net(self) -> dict:
    lm_net_out = ReturnnNetwork()

    ext_lm_subnet = self.ext_lm_opts['lm_subnet']
    ext_lm_scale = self.ext_lm_opts['lm_scale']

    assert isinstance(ext_lm_subnet, dict)

    is_recurrent = self.ext_lm_opts.get('is_recurrent', False)
    if is_recurrent:
        lm_output_prob = self.ext_lm_opts['lm_output_prob_name']
        ext_lm_subnet[lm_output_prob]['target'] = self.target
        lm_net_out.update(ext_lm_subnet)  # just append
    else:
        ext_lm_model = self.ext_lm_opts['lm_model']
        lm_net_out.add_subnetwork(
            'lm_output', 'prev:output', subnetwork_net=ext_lm_subnet, load_on_init=ext_lm_model)
        lm_output_prob = lm_net_out.add_activation_layer(
            'lm_output_prob', 'lm_output', activation='softmax', target=self.target)

    fusion_str = 'safe_log(source(0)) + {} * safe_log(source(1))'.format(ext_lm_scale)  # shallow fusion
    fusion_source = [self.am_output_prob, lm_output_prob]

    if self.prior_lm_opts:
        if self.dec_type == 'lstm':
            ilm_decoder = LSTMILMDecoder(self.asr_decoder, self.prior_lm_opts)
        elif self.dec_type == 'transformer':
            ilm_decoder = TransformerMiniLSTMDecoder(self.asr_decoder, self.prior_lm_opts)
        else:
            raise ValueError('dec type: {} is not valid'.format(self.dec_type))

        ilm_decoder.create_network()  # add ILM

        fusion_str += ' - {} * safe_log(source(2))'.format(self.prior_lm_opts['scale'])
        fusion_source += [ilm_decoder.output_prob_name]

    lm_net_out.add_eval_layer('combo_output_prob', source=fusion_source, eval=fusion_str)
    lm_net_out.add_choice_layer(
        'output', 'combo_output_prob', target=self.target, beam_size=self.beam_size,
        initial_output=0, input_type='log_prob')

    return lm_net_out.get_net()
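# Sketch of the per-label score that 'combo_output_prob' evaluates above, i.e. shallow
# fusion with optional internal-LM subtraction:
#   log p_AM + lm_scale * log p_LM - prior_scale * log p_ILM
# This helper is illustrative only; RETURNN's safe_log is approximated here by clipping
# the probabilities before taking the log.
import numpy as np

def combo_score_sketch(p_am, p_lm, p_ilm=None, lm_scale=0.3, prior_scale=0.0, eps=1e-20):
    score = np.log(np.maximum(p_am, eps)) + lm_scale * np.log(np.maximum(p_lm, eps))
    if p_ilm is not None:
        score -= prior_scale * np.log(np.maximum(p_ilm, eps))
    return score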