def forward(self, src, src_lens=None, hidden=None, use_gpu=True):
    """
    Args:
        src: list of src word_ids [batch_size, seq_len, word_ids]
    """
    device = check_device(use_gpu)

    # src mask
    mask_src = src.data.eq(PAD)
    batch_size = src.size(0)
    seq_len = src.size(1)

    # convert id to embedding
    emb_src = self.embedding_dropout(self.embedder_enc(src))

    # run enc
    # bilstm: pack padded seq. for bilstm (removes the impact of padding)
    if src_lens is not None:
        src_lens = torch.cat(src_lens)
        emb_src_pack = torch.nn.utils.rnn.pack_padded_sequence(
            emb_src, src_lens, batch_first=True, enforce_sorted=False)
        enc_outputs_pack, enc_hidden = self.enc(emb_src_pack, hidden)
        enc_outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(
            enc_outputs_pack, batch_first=True)
    else:
        enc_outputs, enc_hidden = self.enc(emb_src, hidden)

    # unilstm
    enc_outputs = self.dropout(enc_outputs)\
        .view(batch_size, seq_len, enc_outputs.size(-1))
    if self.num_unilstm_enc != 0:
        if not self.residual:
            enc_hidden_uni_init = None
            enc_outputs, enc_hidden_uni = self.enc_uni(
                enc_outputs, enc_hidden_uni_init)
            enc_outputs = self.dropout(enc_outputs).view(
                batch_size, seq_len, enc_outputs.size(-1))
        else:
            enc_hidden_uni_init = None
            enc_hidden_uni_lis = []
            for i in range(self.num_unilstm_enc):
                enc_inputs = enc_outputs
                enc_func = getattr(self.enc_uni, 'l' + str(i))
                enc_outputs, enc_hidden_uni = enc_func(
                    enc_inputs, enc_hidden_uni_init)
                enc_hidden_uni_lis.append(enc_hidden_uni)
                if i < self.num_unilstm_enc - 1:
                    # no residual for the last layer
                    enc_outputs = enc_outputs + enc_inputs
                enc_outputs = self.dropout(enc_outputs).view(
                    batch_size, seq_len, enc_outputs.size(-1))

    return enc_outputs
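# A minimal sketch (not part of the model) of the pack/pad round trip used
# above, to make the padding behaviour of the biLSTM easy to verify in
# isolation. All sizes and lengths below are illustrative assumptions.
def _demo_pack_pad_roundtrip():
    import torch
    rnn = torch.nn.LSTM(8, 16, batch_first=True, bidirectional=True)
    emb = torch.randn(3, 5, 8)              # batch of 3, padded to len 5
    lens = torch.tensor([5, 3, 2])          # true per-sequence lengths
    packed = torch.nn.utils.rnn.pack_padded_sequence(
        emb, lens, batch_first=True, enforce_sorted=False)
    out_packed, _ = rnn(packed)
    out, out_lens = torch.nn.utils.rnn.pad_packed_sequence(
        out_packed, batch_first=True)
    assert out.shape == (3, 5, 32)          # 2 directions x 16 hidden
    assert bool((out[1, 3:] == 0).all())    # positions past each length are zeroed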
def __init__(self, expt_dir='experiment', load_dir=None,
             checkpoint_every=100, print_every=100, batch_size=256,
             use_gpu=False, learning_rate=0.00001, learning_rate_init=0.0005,
             lr_warmup_steps=16000, max_grad_norm=1.0, eval_with_mask=True,
             max_count_no_improve=2, max_count_num_rollback=2, keep_num=1,
             normalise_loss=True, minibatch_split=1):

    self.use_gpu = use_gpu
    self.device = check_device(self.use_gpu)

    self.optimizer = None
    self.checkpoint_every = checkpoint_every
    self.print_every = print_every

    self.learning_rate = learning_rate
    self.learning_rate_init = learning_rate_init
    self.lr_warmup_steps = lr_warmup_steps
    if self.lr_warmup_steps == 0:
        assert self.learning_rate == self.learning_rate_init

    self.max_grad_norm = max_grad_norm
    self.eval_with_mask = eval_with_mask
    self.max_count_no_improve = max_count_no_improve
    self.max_count_num_rollback = max_count_num_rollback
    self.keep_num = keep_num
    self.normalise_loss = normalise_loss

    if not os.path.isabs(expt_dir):
        expt_dir = os.path.join(os.getcwd(), expt_dir)
    self.expt_dir = expt_dir
    if not os.path.exists(self.expt_dir):
        os.makedirs(self.expt_dir)
    self.load_dir = load_dir

    self.logger = logging.getLogger(__name__)
    self.writer = torch.utils.tensorboard.writer.SummaryWriter(
        log_dir=self.expt_dir)

    self.minibatch_split = minibatch_split
    self.batch_size = batch_size
    self.minibatch_size = int(self.batch_size / self.minibatch_split)  # to be changed if OOM
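# The warmup bookkeeping above implies a schedule that moves from
# learning_rate_init towards learning_rate over lr_warmup_steps, and is a
# no-op when lr_warmup_steps == 0 (hence the assert). A plausible linear
# ramp is sketched below as a hypothetical stand-in; the actual schedule
# lives elsewhere in the trainer, so treat this shape as an assumption.
def _demo_warmup_lr(step, lr, lr_init, warmup_steps):
    if warmup_steps == 0:
        return lr                            # constant schedule
    frac = min(step / warmup_steps, 1.0)     # 0 -> 1 over the warmup window
    return lr_init + frac * (lr - lr_init)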
def main():
    # load config
    parser = argparse.ArgumentParser(description='Seq2seq Evaluation')
    parser = load_arguments(parser)
    args = vars(parser.parse_args())
    config = validate_config(args)

    # load src-tgt pair
    test_path_src = config['test_path_src']
    test_path_tgt = config['test_path_tgt']
    path_vocab_src = config['path_vocab_src']
    path_vocab_tgt = config['path_vocab_tgt']
    test_path_out = config['test_path_out']
    load_dir = config['load']
    max_seq_len = config['max_seq_len']
    batch_size = config['batch_size']
    beam_width = config['beam_width']
    use_gpu = config['use_gpu']
    seqrev = config['seqrev']
    use_type = config['use_type']

    if not os.path.exists(test_path_out):
        os.makedirs(test_path_out)
    config_save_dir = os.path.join(test_path_out, 'eval.cfg')
    save_config(config, config_save_dir)

    # set test mode: 1 = translate; 2 = plot
    MODE = config['eval_mode']

    # check device
    device = check_device(use_gpu)
    print('device: {}'.format(device))

    # load test_set
    test_set = Dataset(test_path_src, test_path_tgt,
                       path_vocab_src, path_vocab_tgt,
                       seqrev=seqrev,
                       max_seq_len=max_seq_len,
                       batch_size=batch_size,
                       use_gpu=use_gpu,
                       use_type=use_type)
    print('Testset loaded')
    sys.stdout.flush()

    # run eval
    if MODE == 1:
        translate(test_set, load_dir, test_path_out, use_gpu,
                  max_seq_len, beam_width, device, seqrev=seqrev)
def forward_train(self, src, tgt, debug_flag=False, use_gpu=True):
    """
    train enc + dec
    note: outputs are only meaningful up to the second-to-last element,
        i.e. b x (len-1), e.g. [b, :-1] for preds -
        src: w1 w2 w3 <EOS> <PAD> <PAD> <PAD>
        ref: BOS w1 w2 w3 <EOS> <PAD> <PAD>
        tgt: w1 w2 w3 <EOS> <PAD> <PAD> dummy
        ref starts with BOS; the last element has no ref!
    """
    # note: adding .type(torch.uint8) to be compatible with pytorch 1.1!

    # check gpu
    global device
    device = check_device(use_gpu)

    # run transformer
    src_mask = _get_pad_mask(src).to(device=device).type(
        torch.uint8)  # b x len
    tgt_mask = (_get_pad_mask(tgt).to(device=device).type(torch.uint8)
                & _get_subsequent_mask(self.max_seq_len).type(
                    torch.uint8).to(device=device))

    # b x len x dim_model
    if self.enc_emb_proj_flag:
        emb_src = self.enc_emb_proj(
            self.embedding_dropout(self.enc_embedder(src)))
    else:
        emb_src = self.embedding_dropout(self.enc_embedder(src))
    if self.dec_emb_proj_flag:
        emb_tgt = self.dec_emb_proj(
            self.embedding_dropout(self.dec_embedder(tgt)))
    else:
        emb_tgt = self.embedding_dropout(self.dec_embedder(tgt))

    enc_outputs, *_ = self.enc(emb_src, src_mask=src_mask)  # b x len x dim_model
    dec_outputs, *_ = self.dec(emb_tgt, enc_outputs,
                               tgt_mask=tgt_mask, src_mask=src_mask)

    logits = self.out(dec_outputs)  # b x len x vocab_size
    logps = torch.log_softmax(logits, dim=2)
    preds = logps.data.topk(1)[1]

    return preds, logps, dec_outputs
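# A sketch of the two masks combined above, under the assumption (common in
# transformer code) that _get_pad_mask marks non-PAD positions and
# _get_subsequent_mask is a lower-triangular causal mask. The helpers and
# shapes below are hypothetical stand-ins; the repo's own helpers may use
# different conventions.
def _demo_masks():
    import torch
    PAD = 0
    tgt = torch.tensor([[5, 6, 7, PAD]])                     # b=1, len=4
    pad_mask = (tgt != PAD).unsqueeze(-2)                    # 1 x 1 x 4
    causal = torch.tril(torch.ones(4, 4, dtype=torch.bool))  # 4 x 4
    tgt_mask = pad_mask & causal                             # 1 x 4 x 4
    # row i may attend to positions <= i, and never to PAD columns
    return tgt_mask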
def __init__(self, expt_dir='experiment', load_dir=None,
             batch_size=64, minibatch_partition=20,
             checkpoint_every=100, print_every=100,
             learning_rate=0.001, eval_with_mask=True,
             scheduled_sampling=False, teacher_forcing_ratio=1.0,
             use_gpu=False, max_grad_norm=1.0,
             max_count_no_improve=3, max_count_num_rollback=3,
             keep_num=2, normalise_loss=True):

    self.use_gpu = use_gpu
    self.device = check_device(self.use_gpu)

    self.optimizer = None
    self.checkpoint_every = checkpoint_every
    self.print_every = print_every

    self.learning_rate = learning_rate
    self.max_grad_norm = max_grad_norm
    self.eval_with_mask = eval_with_mask
    self.scheduled_sampling = scheduled_sampling
    self.teacher_forcing_ratio = teacher_forcing_ratio
    self.max_count_no_improve = max_count_no_improve
    self.max_count_num_rollback = max_count_num_rollback
    self.keep_num = keep_num
    self.normalise_loss = normalise_loss

    if not os.path.isabs(expt_dir):
        expt_dir = os.path.join(os.getcwd(), expt_dir)
    self.expt_dir = expt_dir
    if not os.path.exists(self.expt_dir):
        os.makedirs(self.expt_dir)
    self.load_dir = load_dir

    self.logger = logging.getLogger(__name__)
    self.writer = torch.utils.tensorboard.writer.SummaryWriter(
        log_dir=self.expt_dir)

    self.batch_size = batch_size
    self.minibatch_partition = minibatch_partition
    self.minibatch_size = int(self.batch_size / self.minibatch_partition)
def forward(self, src, src_lens, hidden=None, use_gpu=True):
    """
    Args:
        src: list of src word_ids [batch_size, seq_len, word_ids]
    """
    out_dict = {}
    device = check_device(use_gpu)

    # src mask
    mask_src = src.data.eq(PAD)
    batch_size = src.size(0)
    seq_len = src.size(1)

    # convert id to embedding
    emb_src = self.embedding_dropout(self.embedder(src))

    # run lstm: packing + unpacking
    src_lens = torch.cat(src_lens)
    emb_src_pack = torch.nn.utils.rnn.pack_padded_sequence(
        emb_src, src_lens, batch_first=True, enforce_sorted=False)
    lstm_outputs_pack, lstm_hidden = self.lstm(emb_src_pack, hidden)
    lstm_outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(
        lstm_outputs_pack, batch_first=True)
    lstm_outputs = self.dropout(lstm_outputs)\
        .view(batch_size, seq_len, lstm_outputs.size(-1))

    # generate predictions
    logits = self.out(lstm_outputs)
    logps = F.log_softmax(logits, dim=2)
    symbols = logps.topk(1)[1]
    out_dict['sequence'] = symbols

    return logps, out_dict
def main():
    # load config
    parser = argparse.ArgumentParser(description='Seq2seq Evaluation')
    parser = load_arguments(parser)
    args = vars(parser.parse_args())
    config = validate_config(args)

    # load src-tgt pair
    test_path_src = config['test_path_src']
    test_path_tgt = config['test_path_tgt']
    # dummy tgt defaults to src when not given
    if test_path_tgt is None:
        test_path_tgt = test_path_src
    test_path_out = config['test_path_out']
    load_dir = config['load']
    max_seq_len = config['max_seq_len']
    batch_size = config['batch_size']
    beam_width = config['beam_width']
    use_gpu = config['use_gpu']
    seqrev = config['seqrev']
    use_type = config['use_type']

    if not os.path.exists(test_path_out):
        os.makedirs(test_path_out)
    config_save_dir = os.path.join(test_path_out, 'eval.cfg')
    save_config(config, config_save_dir)

    # set test mode: 1 = FR translate; 2 = batched translate;
    # 3 = attention plot; 4 = TF translate
    MODE = config['eval_mode']
    if MODE == 3:
        max_seq_len = 32
        batch_size = 1
        beam_width = 1
        use_gpu = False

    # check device
    device = check_device(use_gpu)
    print('device: {}'.format(device))

    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
    model = resume_checkpoint.model.to(device)
    vocab_src = resume_checkpoint.input_vocab
    vocab_tgt = resume_checkpoint.output_vocab
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model loaded')

    # load test_set
    test_set = Dataset(test_path_src, test_path_tgt,
                       vocab_src_list=vocab_src, vocab_tgt_list=vocab_tgt,
                       seqrev=seqrev,
                       max_seq_len=max_seq_len,
                       batch_size=batch_size,
                       use_gpu=use_gpu,
                       use_type=use_type)
    print('Test dir: {}'.format(test_path_src))
    print('Testset loaded')
    sys.stdout.flush()

    # run eval
    if MODE == 1:  # FR
        translate(test_set, model, test_path_out, use_gpu,
                  max_seq_len, beam_width, device, seqrev=seqrev)
    elif MODE == 2:
        translate_batch(test_set, model, test_path_out, use_gpu,
                        max_seq_len, beam_width, device, seqrev=seqrev)
    elif MODE == 3:  # plotting
        att_plot(test_set, model, test_path_out, use_gpu,
                 max_seq_len, beam_width, device)
    elif MODE == 4:  # TF
        translate_tf(test_set, model, test_path_out, use_gpu,
                     max_seq_len, beam_width, device, seqrev=seqrev)
def main():
    # load config
    parser = argparse.ArgumentParser(description='LAS + NMT Training')
    parser = load_arguments(parser)
    args = vars(parser.parse_args())
    config = validate_config(args)

    # set random seed
    if config['random_seed'] is not None:
        set_global_seeds(config['random_seed'])

    # record config
    if not os.path.isabs(config['save']):
        config_save_dir = os.path.join(os.getcwd(), config['save'])
    if not os.path.exists(config['save']):
        os.makedirs(config['save'])

    # resume or not
    if config['load'] is not None and config['load_mode'] == 'resume':
        config_save_dir = os.path.join(config['save'], 'model-cont.cfg')
    else:
        config_save_dir = os.path.join(config['save'], 'model.cfg')
    save_config(config, config_save_dir)

    loss_coeff = {}
    loss_coeff['nll_asr'] = config['loss_nll_asr_coeff']
    loss_coeff['nll_mt'] = config['loss_nll_mt_coeff']
    loss_coeff['nll_st'] = config['loss_nll_st_coeff']

    # construct trainer
    Trainer = globals()['Trainer_{}'.format(config['mode'])]
    t = Trainer(expt_dir=config['save'],
                load_dir=config['load'],
                load_mode=config['load_mode'],
                load_freeze=config['load_freeze'],
                batch_size=config['batch_size'],
                minibatch_partition=config['minibatch_partition'],
                checkpoint_every=config['checkpoint_every'],
                print_every=config['print_every'],
                learning_rate=config['learning_rate'],
                learning_rate_init=config['learning_rate_init'],
                lr_warmup_steps=config['lr_warmup_steps'],
                eval_with_mask=config['eval_with_mask'],
                use_gpu=config['use_gpu'],
                gpu_id=config['gpu_id'],
                max_grad_norm=config['max_grad_norm'],
                max_count_no_improve=config['max_count_no_improve'],
                max_count_num_rollback=config['max_count_num_rollback'],
                keep_num=config['keep_num'],
                normalise_loss=config['normalise_loss'],
                loss_coeff=loss_coeff)

    # vocab
    path_vocab_src = config['path_vocab_src']
    path_vocab_tgt = config['path_vocab_tgt']

    # ----- 3WAY -----
    train_set = None
    dev_set = None
    mode = config['mode']
    if 'ST' in mode:
        # load train set
        if config['st_train_path_src']:
            t.logger.info(' -- load ST train set -- ')
            train_path_src = config['st_train_path_src']
            train_path_tgt = config['st_train_path_tgt']
            train_acous_path = config['st_train_acous_path']
            train_set = Dataset(path_src=train_path_src,
                                path_tgt=train_path_tgt,
                                path_vocab_src=path_vocab_src,
                                path_vocab_tgt=path_vocab_tgt,
                                use_type=config['use_type'],
                                acous_path=train_acous_path,
                                seqrev=config['seqrev'],
                                acous_norm=config['las_acous_norm'],
                                acous_norm_path=config['st_acous_norm_path'],
                                acous_max_len=config['las_acous_max_len'],
                                max_seq_len_src=config['max_seq_len_src'],
                                max_seq_len_tgt=config['max_seq_len_tgt'],
                                batch_size=config['batch_size'],
                                data_ratio=config['st_data_ratio'],
                                use_gpu=config['use_gpu'],
                                mode='ST',
                                logger=t.logger)
            vocab_size_enc = len(train_set.vocab_src)
            vocab_size_dec = len(train_set.vocab_tgt)
            src_word2id = train_set.src_word2id
            tgt_word2id = train_set.tgt_word2id
            src_id2word = train_set.src_id2word
            tgt_id2word = train_set.tgt_id2word

        # load dev set
        if config['st_dev_path_src']:
            t.logger.info(' -- load ST dev set -- ')
            dev_path_src = config['st_dev_path_src']
            dev_path_tgt = config['st_dev_path_tgt']
            dev_acous_path = config['st_dev_acous_path']
            dev_set = Dataset(path_src=dev_path_src,
                              path_tgt=dev_path_tgt,
                              path_vocab_src=path_vocab_src,
                              path_vocab_tgt=path_vocab_tgt,
                              use_type=config['use_type'],
                              acous_path=dev_acous_path,
                              acous_norm_path=config['st_acous_norm_path'],
                              acous_max_len=config['las_acous_max_len'],
                              seqrev=config['seqrev'],
                              acous_norm=config['las_acous_norm'],
                              max_seq_len_src=config['max_seq_len_src'],
                              max_seq_len_tgt=config['max_seq_len_tgt'],
                              batch_size=config['batch_size'],
                              use_gpu=config['use_gpu'],
                              mode='ST',
                              logger=t.logger)
        else:
            dev_set = None

    # ----- ASR -----
    asr_train_set = None
    asr_dev_set = None
    if 'ASR' in mode:
        # load train set
        if config['asr_train_path_src']:
            t.logger.info(' -- load ASR train set -- ')
            asr_train_path_src = config['asr_train_path_src']
            asr_train_acous_path = config['asr_train_acous_path']
            asr_train_set = Dataset(
                path_src=asr_train_path_src,
                path_tgt=None,
                path_vocab_src=path_vocab_src,
                path_vocab_tgt=path_vocab_tgt,
                use_type=config['use_type'],
                acous_path=asr_train_acous_path,
                acous_norm_path=config['asr_train_acous_norm_path'],
                seqrev=config['seqrev'],
                acous_norm=config['las_acous_norm'],
                acous_max_len=config['las_acous_max_len'],
                max_seq_len_src=config['max_seq_len_src'],
                max_seq_len_tgt=config['max_seq_len_tgt'],
                batch_size=config['batch_size'],
                data_ratio=config['asr_data_ratio'],
                use_gpu=config['use_gpu'],
                mode='ASR',
                logger=t.logger)
            vocab_size_enc = len(asr_train_set.vocab_src)
            vocab_size_dec = len(asr_train_set.vocab_tgt)
            src_word2id = asr_train_set.src_word2id
            tgt_word2id = asr_train_set.tgt_word2id
            src_id2word = asr_train_set.src_id2word
            tgt_id2word = asr_train_set.tgt_id2word

        # load dev set
        if config['asr_dev_path_src']:
            t.logger.info(' -- load ASR dev set -- ')
            asr_dev_path_src = config['asr_dev_path_src']
            asr_dev_acous_path = config['asr_dev_acous_path']
            asr_dev_set = Dataset(
                path_src=asr_dev_path_src,
                path_tgt=None,
                path_vocab_src=path_vocab_src,
                path_vocab_tgt=path_vocab_tgt,
                use_type=config['use_type'],
                acous_path=asr_dev_acous_path,
                acous_norm_path=config['asr_dev_acous_norm_path'],
                acous_max_len=config['las_acous_max_len'],
                seqrev=config['seqrev'],
                acous_norm=config['las_acous_norm'],
                max_seq_len_src=config['max_seq_len_src'],
                max_seq_len_tgt=config['max_seq_len_tgt'],
                batch_size=config['batch_size'],
                use_gpu=config['use_gpu'],
                mode='ASR',
                logger=t.logger)
        else:
            asr_dev_set = None

    # ----- MT -----
    mt_train_set = None
    mt_dev_set = None
    if 'MT' in mode:
        # load train set
        if config['mt_train_path_src']:
            t.logger.info(' -- load MT train set -- ')
            mt_train_path_src = config['mt_train_path_src']
            mt_train_path_tgt = config['mt_train_path_tgt']
            mt_train_set = Dataset(path_src=mt_train_path_src,
                                   path_tgt=mt_train_path_tgt,
                                   path_vocab_src=path_vocab_src,
                                   path_vocab_tgt=path_vocab_tgt,
                                   use_type=config['use_type'],
                                   acous_path=None,
                                   acous_norm_path=None,
                                   seqrev=config['seqrev'],
                                   acous_norm=config['las_acous_norm'],
                                   acous_max_len=config['las_acous_max_len'],
                                   max_seq_len_src=config['max_seq_len_src'],
                                   max_seq_len_tgt=config['max_seq_len_tgt'],
                                   batch_size=config['batch_size'],
                                   data_ratio=config['mt_data_ratio'],
                                   use_gpu=config['use_gpu'],
                                   mode='MT',
                                   logger=t.logger)
            vocab_size_enc = len(mt_train_set.vocab_src)
            vocab_size_dec = len(mt_train_set.vocab_tgt)
            src_word2id = mt_train_set.src_word2id
            tgt_word2id = mt_train_set.tgt_word2id
            src_id2word = mt_train_set.src_id2word
            tgt_id2word = mt_train_set.tgt_id2word

        # load dev set
        if config['mt_dev_path_src']:
            t.logger.info(' -- load MT dev set -- ')
            mt_dev_path_src = config['mt_dev_path_src']
            mt_dev_path_tgt = config['mt_dev_path_tgt']
            mt_dev_set = Dataset(path_src=mt_dev_path_src,
                                 path_tgt=mt_dev_path_tgt,
                                 path_vocab_src=path_vocab_src,
                                 path_vocab_tgt=path_vocab_tgt,
                                 use_type=config['use_type'],
                                 acous_path=None,
                                 acous_norm_path=None,
                                 acous_max_len=config['las_acous_max_len'],
                                 seqrev=config['seqrev'],
                                 acous_norm=config['las_acous_norm'],
                                 max_seq_len_src=config['max_seq_len_src'],
                                 max_seq_len_tgt=config['max_seq_len_tgt'],
                                 batch_size=config['batch_size'],
                                 use_gpu=config['use_gpu'],
                                 mode='MT',
                                 logger=t.logger)
        else:
            mt_dev_set = None

    # collect all datasets
    train_sets = {}
    dev_sets = {}
    train_sets['st'] = train_set
    train_sets['asr'] = asr_train_set
    train_sets['mt'] = mt_train_set
    dev_sets['st'] = dev_set
    dev_sets['asr'] = asr_dev_set
    dev_sets['mt'] = mt_dev_set

    # device
    device = check_device(config['use_gpu'])
    t.logger.info('device:{}'.format(device))

    # construct nmt model
    seq2seq = Seq2seq(vocab_size_enc, vocab_size_dec,
                      share_embedder=config['share_embedder'],
                      enc_embedding_size=config['embedding_size_enc'],
                      dec_embedding_size=config['embedding_size_dec'],
                      load_embedding_src=config['load_embedding_src'],
                      load_embedding_tgt=config['load_embedding_tgt'],
                      num_heads=config['num_heads'],
                      dim_model=config['dim_model'],
                      dim_feedforward=config['dim_feedforward'],
                      enc_layers=config['enc_layers'],
                      dec_layers=config['dec_layers'],
                      embedding_dropout=config['embedding_dropout'],
                      dropout=config['dropout'],
                      max_seq_len_src=config['max_seq_len_src'],
                      max_seq_len_tgt=config['max_seq_len_tgt'],
                      act=config['act'],
                      enc_word2id=src_word2id,
                      dec_word2id=tgt_word2id,
                      enc_id2word=src_id2word,
                      dec_id2word=tgt_id2word,
                      transformer_type=config['transformer_type'],
                      enc_emb_proj=config['enc_emb_proj'],
                      dec_emb_proj=config['dec_emb_proj'],
                      #
                      acous_dim=config['las_acous_dim'],
                      acous_hidden_size=config['las_acous_hidden_size'],
                      #
                      mode=config['mode'],
                      load_mode=config['load_mode'])
    seq2seq = seq2seq.to(device=device)

    # run training
    seq2seq = t.train(train_sets, seq2seq,
                      num_epochs=config['num_epochs'],
                      dev_sets=dev_sets,
                      grab_memory=config['grab_memory'])
def forward_translate_fast(self, src, beam_width=1,
                           penalty_factor=1, use_gpu=True):
    """ requires large memory - run on cpu """

    # check gpu
    global device
    device = check_device(use_gpu)

    # run dd
    src_mask = _get_pad_mask(src).type(torch.uint8).to(
        device=device)  # b x 1 x len
    if self.enc_emb_proj_flag:
        emb_src = self.enc_emb_proj(self.enc_embedder(src))
    else:
        emb_src = self.enc_embedder(src)
    enc_outputs, *_ = self.enc(emb_src, src_mask=src_mask)  # b x len x dim_model

    batch = src.size(0)
    length_in = src.size(1)
    length_out = self.max_seq_len

    eos_mask = torch.BoolTensor([False]).repeat(
        batch * beam_width).to(device=device)
    len_map = torch.Tensor([1]).repeat(batch * beam_width).to(device=device)
    preds = torch.Tensor([BOS]).repeat(batch, 1).type(
        torch.LongTensor).to(device=device)

    # repeat beam_width times
    # a b c d -> aaa bbb ccc ddd
    # b x 1 x len -> (b x beam_width) x 1 x len
    src_mask_expand = src_mask.repeat(1, beam_width, 1).view(-1, 1, length_in)
    # b x len x dim_model -> (b x beam_width) x len x dim_model
    enc_outputs_expand = enc_outputs.repeat(1, beam_width, 1).view(
        -1, length_in, self.dim_model)
    # (b x beam_width) x len
    preds_expand = preds.repeat(1, beam_width).view(-1, preds.size(-1))
    # (b x beam_width)
    scores_expand = torch.Tensor([0]).repeat(batch * beam_width).type(
        torch.FloatTensor).to(device=device)

    # loop over sequence length
    for i in range(1, self.max_seq_len):  # gen: 0-30; ref: 1-31

        # get k candidates for each beam, k^2 candidates in total (k=beam_width)
        tgt_mask_expand = (
            _get_pad_mask(preds_expand).type(torch.uint8).to(device=device)
            & _get_subsequent_mask(preds_expand.size(-1)).type(
                torch.uint8).to(device=device))
        if self.dec_emb_proj_flag:
            emb_tgt_expand = self.dec_emb_proj(self.dec_embedder(preds_expand))
        else:
            emb_tgt_expand = self.dec_embedder(preds_expand)

        if i == 1:
            cache_decslf = None
            cache_encdec = None
        dec_output_expand, *_, cache_decslf, cache_encdec = self.dec(
            emb_tgt_expand, enc_outputs_expand,
            tgt_mask=tgt_mask_expand, src_mask=src_mask_expand,
            decode_speedup=True,
            cache_decslf=cache_decslf, cache_encdec=cache_encdec)
        logit_expand = self.out(dec_output_expand)
        # (b x beam_width) x len x vocab_size
        logp_expand = torch.log_softmax(logit_expand, dim=2)
        # (b x beam_width) x len x beam_width
        score_expand, pred_expand = logp_expand.data.topk(beam_width)

        # select current slice
        dec_output = dec_output_expand[:, i - 1]  # (b x beam_width) x dim_model (unused)
        logp = logp_expand[:, i - 1, :]  # (b x beam_width) x vocab_size (unused)
        pred = pred_expand[:, i - 1]  # (b x beam_width) x beam_width
        score = score_expand[:, i - 1]  # (b x beam_width) x beam_width

        # select k candidates from the k^2 candidates
        if i == 1:
            # initial state, keep the first k candidates
            # b x (beam_width x beam_width) -> b x beam_width -> (b x beam_width) x 1
            score_select = scores_expand + score.reshape(batch, -1)[:, :beam_width]\
                .contiguous().view(-1)
            scores_expand = score_select
            pred_select = pred.reshape(
                batch, -1)[:, :beam_width].contiguous().view(-1)
            preds_expand = torch.cat(
                (preds_expand, pred_select.unsqueeze(-1)), dim=1)
        else:
            # keep only 1 candidate when hitting eos
            # (b x beam_width) x beam_width
            eos_mask_expand = eos_mask.reshape(-1, 1).repeat(1, beam_width)
            eos_mask_expand[:, 0] = False
            # (b x beam_width) x beam_width
            score_temp = scores_expand.reshape(-1, 1) + score.masked_fill(
                eos_mask.reshape(-1, 1), 0).masked_fill(eos_mask_expand, -1e9)
            # length penalty
            score_temp = score_temp / (len_map.reshape(-1, 1) ** penalty_factor)

            # select top k from k^2
            # (b x beam_width^2 -> b x beam_width)
            score_select, pos = score_temp.reshape(batch, -1).topk(beam_width)
            scores_expand = score_select.view(-1) * (len_map.reshape(
                -1, 1) ** penalty_factor).view(-1)

            # select the correct elements according to pos
            # (torch.arange replaces the deprecated torch.range; same values)
            pos = (pos + torch.arange(
                0, batch * (beam_width ** 2),
                beam_width ** 2).to(device=device).reshape(batch, 1)).long()
            r_idxs, c_idxs = pos // beam_width, pos % beam_width  # b x beam_width
            pred_select = pred[r_idxs, c_idxs].view(-1)  # b x beam_width -> (b x beam_width)

            # copy the corresponding previous tokens
            preds_expand[:, :i] = preds_expand[r_idxs.view(-1), :i]  # (b x beam_width) x i
            # set the best tokens for this beam search step
            preds_expand = torch.cat(
                (preds_expand, pred_select.unsqueeze(-1)), dim=1)

            # locate the eos in the generated sequences
            # eos_mask = (pred_select == EOS) + eos_mask  # >=pt1.3
            eos_mask = ((pred_select == EOS).type(torch.uint8)
                        + eos_mask.type(torch.uint8)).type(torch.bool).type(
                            torch.uint8)  # >=pt1.1
            len_map = len_map + torch.Tensor([1]).repeat(
                batch * beam_width).to(device=device).masked_fill(eos_mask, 0)

        # early stop
        if sum(eos_mask.int()) == eos_mask.size(0):
            break

    # select the best candidate
    preds = preds_expand.reshape(
        batch, -1)[:, :self.max_seq_len].contiguous()  # b x len
    scores = scores_expand.reshape(batch, -1)[:, 0].contiguous()  # b

    # select the worst candidate
    # preds = preds_expand.reshape(batch, -1)\
    #     [:, (beam_width - 1) * length:(beam_width) * length].contiguous()  # b x len
    # scores = scores_expand.reshape(batch, -1)[:, -1].contiguous()  # b

    return preds
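# A tiny worked example of the flat-index arithmetic above (values assumed).
# With batch=2, beam_width=2, topk over the flattened b x k^2 scores returns
# positions in [0, k^2); adding per-batch offsets of k^2 turns them into row
# (beam) and column (candidate) indices into the (b x k) x k tensors.
def _demo_beam_index_math():
    import torch
    batch, k = 2, 2
    pos = torch.tensor([[1, 2], [0, 3]])            # b x k, from topk
    offset = torch.arange(0, batch * k ** 2, k ** 2).reshape(batch, 1)
    pos = (pos + offset).long()                     # [[1, 2], [4, 7]]
    r_idxs, c_idxs = pos // k, pos % k              # beam row, candidate col
    assert r_idxs.tolist() == [[0, 1], [2, 3]]
    assert c_idxs.tolist() == [[1, 0], [0, 1]]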
def __init__(
        self,
        # params
        vocab_size,
        embedding_size=200,
        acous_hidden_size=256,
        acous_att_mode='bahdanau',
        hidden_size_dec=200,
        hidden_size_shared=200,
        num_unilstm_dec=4,
        use_type='char',
        #
        embedding_dropout=0,
        dropout=0.0,
        residual=True,
        batch_first=True,
        max_seq_len=32,
        load_embedding=None,
        word2id=None,
        id2word=None,
        hard_att=False,
        use_gpu=False):

    super(Dec, self).__init__()
    device = check_device(use_gpu)

    # define model
    self.acous_hidden_size = acous_hidden_size
    self.acous_att_mode = acous_att_mode
    self.hidden_size_dec = hidden_size_dec
    self.hidden_size_shared = hidden_size_shared
    self.num_unilstm_dec = num_unilstm_dec

    # define var
    self.hard_att = hard_att
    self.residual = residual
    self.max_seq_len = max_seq_len
    self.use_type = use_type

    # use shared embedding + vocab
    self.vocab_size = vocab_size
    self.embedding_size = embedding_size
    self.load_embedding = load_embedding
    self.word2id = word2id
    self.id2word = id2word

    # define operations
    self.embedding_dropout = nn.Dropout(embedding_dropout)
    self.dropout = nn.Dropout(dropout)

    # ------- load embeddings --------
    if self.load_embedding:
        embedding_matrix = np.random.rand(self.vocab_size, self.embedding_size)
        embedding_matrix = load_pretrained_embedding(
            self.word2id, embedding_matrix, self.load_embedding)
        embedding_matrix = torch.FloatTensor(embedding_matrix)
        self.embedder = nn.Embedding.from_pretrained(
            embedding_matrix, freeze=False, sparse=False, padding_idx=PAD)
    else:
        self.embedder = nn.Embedding(self.vocab_size, self.embedding_size,
                                     sparse=False, padding_idx=PAD)

    # ------ define acous att --------
    dropout_acous_att = dropout
    self.acous_hidden_size_att = 0  # ignored with bilinear
    self.acous_key_size = self.acous_hidden_size * 2  # acous feats
    self.acous_value_size = self.acous_hidden_size * 2  # acous feats
    self.acous_query_size = self.hidden_size_dec  # use dec(words) as query
    self.acous_att = AttentionLayer(self.acous_query_size,
                                    self.acous_key_size,
                                    value_size=self.acous_value_size,
                                    mode=self.acous_att_mode,
                                    dropout=dropout_acous_att,
                                    query_transform=False,
                                    output_transform=False,
                                    hidden_size=self.acous_hidden_size_att,
                                    use_gpu=use_gpu,
                                    hard_att=False)

    # ------ define acous out --------
    self.acous_ffn = nn.Linear(
        self.acous_hidden_size * 2 + self.hidden_size_dec,
        self.hidden_size_shared, bias=False)
    self.acous_out = nn.Linear(
        self.hidden_size_shared, self.vocab_size, bias=True)

    # ------ define acous dec -------
    # embedding_size_dec + self.hidden_size_shared [200+200] -> hidden_size_dec [200]
    if not self.residual:
        self.dec = torch.nn.LSTM(
            self.embedding_size + self.hidden_size_shared,
            self.hidden_size_dec,
            num_layers=self.num_unilstm_dec, batch_first=batch_first,
            bias=True, dropout=dropout, bidirectional=False)
    else:
        self.dec = nn.Module()
        self.dec.add_module(
            'l0',
            torch.nn.LSTM(self.embedding_size + self.hidden_size_shared,
                          self.hidden_size_dec, num_layers=1,
                          batch_first=batch_first,
                          bias=True, dropout=dropout, bidirectional=False))
        for i in range(1, self.num_unilstm_dec):
            self.dec.add_module(
                'l' + str(i),
                torch.nn.LSTM(self.hidden_size_dec, self.hidden_size_dec,
                              num_layers=1, batch_first=batch_first,
                              bias=True, dropout=dropout, bidirectional=False))
def forward_eval(self, src, debug_flag=False, use_gpu=True):
    """
    eval enc + dec (beam_width = 1)
    all outputs follow:
        tgt: <BOS> w1 w2 w3 <EOS> <PAD>
        gen:       w1 w2 w3 <EOS> <PAD> <PAD>
        shift by 1, i.e.
            used input = <BOS> w1 <PAD> <PAD>
            gen output = dummy w2 dummy
        update prediction: assign w2 (output[1]) to be input[2]
    """
    # check gpu
    global device
    device = check_device(use_gpu)

    batch = src.size(0)
    length_out = self.max_seq_len

    # run enc dec
    eos_mask = torch.BoolTensor([False]).repeat(batch).to(device=device)
    src_mask = _get_pad_mask(src).type(torch.uint8).to(device=device)
    if self.enc_emb_proj_flag:
        emb_src = self.enc_emb_proj(
            self.embedding_dropout(self.enc_embedder(src)))
    else:
        emb_src = self.embedding_dropout(self.enc_embedder(src))
    enc_outputs, enc_var = self.enc(emb_src, src_mask=src_mask)

    # record
    logps = torch.Tensor([-1e-4]).repeat(
        batch, length_out, self.dec_vocab_size).type(
            torch.FloatTensor).to(device=device)
    dec_outputs = torch.Tensor([0]).repeat(
        batch, length_out, self.dim_model).type(
            torch.FloatTensor).to(device=device)
    preds_save = torch.Tensor([PAD]).repeat(batch, length_out).type(
        torch.LongTensor).to(device=device)  # used to update pred history

    # start from length = 1
    preds = torch.Tensor([BOS]).repeat(batch, 1).type(
        torch.LongTensor).to(device=device)
    preds_save[:, 0] = preds[:, 0]

    for i in range(1, self.max_seq_len):  # gen: 0-30; ref: 1-31
        tgt_mask = (
            _get_pad_mask(preds).type(torch.uint8).to(device=device)
            & _get_subsequent_mask(preds.size(-1)).type(
                torch.uint8).to(device=device))
        if self.dec_emb_proj_flag:
            emb_tgt = self.dec_emb_proj(self.dec_embedder(preds))
        else:
            emb_tgt = self.dec_embedder(preds)
        dec_output, dec_var, *_ = self.dec(
            emb_tgt, enc_outputs, tgt_mask=tgt_mask, src_mask=src_mask)
        logit = self.out(dec_output)
        logp = torch.log_softmax(logit, dim=2)
        pred = logp.data.topk(1)[1]  # b x :i

        # eos_mask = (pred[:, i-1].squeeze(1) == EOS) + eos_mask  # >=pt1.3
        eos_mask = ((pred[:, i - 1].squeeze(1) == EOS).type(torch.uint8)
                    + eos_mask.type(torch.uint8)).type(torch.bool).type(
                        torch.uint8)  # >=pt1.1

        # b x len x dim_model - [:,0,:] is dummy 0's
        dec_outputs[:, i, :] = dec_output[:, i - 1]
        # b x len x vocab_size - [:,0,:] is dummy -1e-4's
        # individual logps
        logps[:, i, :] = logp[:, i - 1, :]
        # b x len - [:,0] is BOS
        preds_save[:, i] = pred[:, i - 1].view(-1)
        # append current pred, length+1
        preds = torch.cat((preds, pred[:, i - 1]), dim=1)

        if sum(eos_mask.int()) == eos_mask.size(0):
            if length_out != preds.size(1):
                dummy = torch.Tensor([PAD]).repeat(
                    batch, length_out - preds.size(1)).type(
                        torch.LongTensor).to(device=device)
                preds = torch.cat((preds, dummy), dim=1)  # pad to max length
            break

    if not debug_flag:
        return preds, logps, dec_outputs
    else:
        return preds, logps, dec_outputs, enc_var, dec_var
def forward_eval_fast(self, src, debug_flag=False, use_gpu=True):
    """ requires large memory - run on cpu """
    # check gpu
    global device
    device = check_device(use_gpu)

    batch = src.size(0)
    length_out = self.max_seq_len

    # run enc dec
    src_mask = _get_pad_mask(src).type(torch.uint8).to(device=device)
    if self.enc_emb_proj_flag:
        emb_src = self.enc_emb_proj(
            self.embedding_dropout(self.enc_embedder(src)))
    else:
        emb_src = self.embedding_dropout(self.enc_embedder(src))
    enc_outputs, enc_var = self.enc(emb_src, src_mask=src_mask)

    # record
    logps = torch.Tensor([-1e-4]).repeat(
        batch, length_out, self.dec_vocab_size).type(
            torch.FloatTensor).to(device=device)
    dec_outputs = torch.Tensor([0]).repeat(
        batch, length_out, self.dim_model).type(
            torch.FloatTensor).to(device=device)
    preds_save = torch.Tensor([PAD]).repeat(batch, length_out).type(
        torch.LongTensor).to(device=device)  # used to update pred history

    # start from length = 1
    preds = torch.Tensor([BOS]).repeat(batch, 1).type(
        torch.LongTensor).to(device=device)
    preds_save[:, 0] = preds[:, 0]

    for i in range(1, self.max_seq_len):  # gen: 0-30; ref: 1-31
        tgt_mask = (
            _get_pad_mask(preds).type(torch.uint8).to(device=device)
            & _get_subsequent_mask(preds.size(-1)).type(
                torch.uint8).to(device=device))
        if self.dec_emb_proj_flag:
            emb_tgt = self.dec_emb_proj(self.dec_embedder(preds))
        else:
            emb_tgt = self.dec_embedder(preds)
        if i == 1:
            cache_decslf = None
            cache_encdec = None
        dec_output, dec_var, *_, cache_decslf, cache_encdec = self.dec(
            emb_tgt, enc_outputs, tgt_mask=tgt_mask, src_mask=src_mask,
            decode_speedup=True,
            cache_decslf=cache_decslf, cache_encdec=cache_encdec)
        logit = self.out(dec_output)
        logp = torch.log_softmax(logit, dim=2)
        pred = logp.data.topk(1)[1]  # b x :i

        # b x len x dim_model - [:,0,:] is dummy 0's
        dec_outputs[:, i, :] = dec_output[:, i - 1]
        # b x len x vocab_size - [:,0,:] is dummy -1e-4's
        # individual logps
        logps[:, i, :] = logp[:, i - 1, :]
        # b x len - [:,0] is BOS
        preds_save[:, i] = pred[:, i - 1].view(-1)
        # append current pred, length+1
        preds = torch.cat((preds, pred[:, i - 1]), dim=1)

    if not debug_flag:
        return preds, logps, dec_outputs
    else:
        return preds, logps, dec_outputs, enc_var, dec_var
def forward(self, acous_outputs, acous_lens=None, tgt=None,
            hidden=None, is_training=False, teacher_forcing_ratio=0.0,
            beam_width=1, use_gpu=False):
    """
    Args:
        acous_outputs: [batch_size, acous_len / 8, self.acous_hidden_size * 2]
        tgt: list of word_ids [b x seq_len]
        hidden: initial hidden state
        is_training: whether in eval or train mode
        teacher_forcing_ratio: default at 1 - always teacher forcing
    Returns:
        decoder_outputs: list of step_output
            - log predicted_softmax [batch_size, 1, vocab_size_dec] * (T-1)
    """
    global device
    device = check_device(use_gpu)

    # 0. init var
    ret_dict = dict()
    ret_dict[KEY_ATTN_SCORE] = []
    decoder_outputs = []
    sequence_symbols = []
    batch_size = acous_outputs.size(0)

    if tgt is None:
        tgt = torch.Tensor([BOS]).repeat(
            batch_size, self.max_seq_len).type(
                torch.LongTensor).to(device=device)
    max_seq_len = tgt.size(1)
    lengths = np.array([max_seq_len] * batch_size)

    # 1. convert id to embedding
    emb_tgt = self.embedding_dropout(self.embedder(tgt))

    # 2. att inputs: keys and values
    att_keys = acous_outputs
    att_vals = acous_outputs

    # generate acous mask: True for trailing 0's
    if acous_lens is not None:
        # lengths are reduced by a factor of 8 by the pyramidal encoder
        lens = torch.cat([elem + 8 - elem % 8 for elem in acous_lens]) / 8
        max_acous_len = acous_outputs.size(1)
        # mask = True over trailing 0's
        mask = torch.arange(max_acous_len).to(device=device).expand(
            batch_size, max_acous_len) >= lens.unsqueeze(1).to(device=device)
    else:
        mask = None

    # 3. init hidden states
    dec_hidden = None

    # 4. run dec + att + shared + output
    """
    teacher_forcing_ratio = 1.0 -> always teacher forcing
    E.g.:
        acous        = [acous_len/8]
        tgt_chunk in = w1 w2 w3 </s> <pad> <pad> <pad>  [max_seq_len]
        predicted    = w1 w2 w3 </s> <pad> <pad> <pad>  [max_seq_len]
    """
    # LAS under teacher forcing
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    # beam search decoding
    if not is_training and beam_width > 1:
        decoder_outputs, decoder_hidden, metadata = \
            self.beam_search_decoding(att_keys, att_vals, dec_hidden,
                                      mask, beam_width=beam_width)
        return decoder_outputs, decoder_hidden, metadata

    # no beam search decoding
    tgt_chunk = self.embedder(
        torch.Tensor([BOS]).repeat(batch_size, 1).type(
            torch.LongTensor).to(device=device))  # BOS
    cell_value = torch.FloatTensor([0])\
        .repeat(batch_size, 1, self.hidden_size_shared).to(device=device)
    prev_c = torch.FloatTensor([0]).repeat(
        batch_size, 1, max_seq_len).to(device=device)

    attn_outputs = []
    for idx in range(max_seq_len):
        predicted_logsoftmax, dec_hidden, step_attn, c_out, cell_value, attn_output = \
            self.forward_step(self.acous_att, self.acous_ffn, self.acous_out,
                              att_keys, att_vals, tgt_chunk, cell_value,
                              dec_hidden, mask, prev_c)
        predicted_logsoftmax = predicted_logsoftmax.squeeze(1)  # [b, vocab_size]
        step_output = predicted_logsoftmax
        symbols, decoder_outputs, sequence_symbols, lengths = \
            self.decode(idx, step_output, decoder_outputs,
                        sequence_symbols, lengths)
        prev_c = c_out
        if use_teacher_forcing:
            tgt_chunk = emb_tgt[:, idx].unsqueeze(1)
        else:
            tgt_chunk = self.embedder(symbols)
        ret_dict[KEY_ATTN_SCORE].append(step_attn)
        attn_outputs.append(attn_output)

    ret_dict[KEY_SEQUENCE] = sequence_symbols
    ret_dict[KEY_LENGTH] = lengths.tolist()
    ret_dict[KEY_ATTN_OUT] = attn_outputs

    return decoder_outputs, dec_hidden, ret_dict
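# A small sketch of the trailing-padding mask built above: each acoustic
# length is rounded up to a multiple of 8, divided by 8 (the pyramid
# reduction), and positions at or beyond that reduced length are masked.
# The values below are illustrative assumptions.
def _demo_acous_mask():
    import torch
    lens = torch.tensor([20., 9.])                  # raw frame counts
    lens = (lens + 8 - lens % 8) / 8                # -> [3., 2.] reduced steps
    max_len = 4                                     # padded encoder length
    mask = torch.arange(max_len).expand(2, max_len) >= lens.unsqueeze(1)
    assert mask.tolist() == [[False, False, False, True],
                             [False, False, True, True]]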
def calc_score(self, att_query, att_keys, prev_c=None, use_gpu=True):
    """
    att_query: b x t_q x n_q (inference: t_q=1)
    att_keys:  b x t_k x n_k
    return:    b x t_q x t_k

    'dot_prod':  att = q * k^T
    'bahdanau':  att = W * tanh(Uq + Vk + b)
    'loc_based': att = a * exp[ -b(c-j)^2 ]
        j - key idx; i - query idx
        prev_c - c_(i-1)
        a0, b0, c0 parameterised by q, k - (Uq_i + Vk_j + b)
        a = exp(a0), b = exp(b0), c = prev_c + exp(c0)
    """
    device = check_device(use_gpu)

    b = att_query.size(0)
    t_q = att_query.size(1)  # = 1 if in inference mode
    t_k = att_keys.size(1)
    n_q = att_query.size(2)
    n_k = att_keys.size(2)
    c_out = None  # placeholder

    if self.mode == 'bahdanau':
        att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n_q)
        att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n_k)
        wq = self.linear_att_q(att_query).view(b, t_q, t_k, self.hidden_size)
        uk = self.linear_att_k(att_keys).view(b, t_q, t_k, self.hidden_size)
        sum_qk = wq + uk
        out = self.linear_att_o(torch.tanh(sum_qk)).view(b, t_q, t_k)

    elif self.mode == 'hybrid':
        att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n_q)
        att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n_k)

        if not hasattr(self, 'linear_att_ao'):
            # to work with the old att setup
            self.hidden_size = 1  # fix
            a_wq = self.linear_att_aq(att_query).view(b, t_q, t_k, self.hidden_size)
            a_uk = self.linear_att_ak(att_keys).view(b, t_q, t_k, self.hidden_size)
            a_sum_qk = a_wq + a_uk
            a_out = torch.exp(torch.tanh(a_sum_qk)).view(b, t_q, t_k)

            b_wq = self.linear_att_bq(att_query).view(b, t_q, t_k, self.hidden_size)
            b_uk = self.linear_att_bk(att_keys).view(b, t_q, t_k, self.hidden_size)
            b_sum_qk = b_wq + b_uk
            b_out = torch.exp(torch.tanh(b_sum_qk)).view(b, t_q, t_k)

            c_wq = self.linear_att_cq(att_query).view(b, t_q, t_k, self.hidden_size)
            c_uk = self.linear_att_ck(att_keys).view(b, t_q, t_k, self.hidden_size)
            c_sum_qk = c_wq + c_uk
            c_out = torch.exp(torch.tanh(c_sum_qk)).view(b, t_q, t_k)
        else:
            # new setup by default
            a_wq = self.linear_att_aq(att_query).view(b, t_q, t_k, self.hidden_size)
            a_uk = self.linear_att_ak(att_keys).view(b, t_q, t_k, self.hidden_size)
            a_sum_qk = a_wq + a_uk
            a_out = torch.exp(self.linear_att_ao(
                torch.tanh(a_sum_qk))).view(b, t_q, t_k)

            b_wq = self.linear_att_bq(att_query).view(b, t_q, t_k, self.hidden_size)
            b_uk = self.linear_att_bk(att_keys).view(b, t_q, t_k, self.hidden_size)
            b_sum_qk = b_wq + b_uk
            b_out = torch.exp(self.linear_att_bo(
                torch.tanh(b_sum_qk))).view(b, t_q, t_k)

            c_wq = self.linear_att_cq(att_query).view(b, t_q, t_k, self.hidden_size)
            c_uk = self.linear_att_ck(att_keys).view(b, t_q, t_k, self.hidden_size)
            c_sum_qk = c_wq + c_uk
            c_out = torch.exp(self.linear_att_co(
                torch.tanh(c_sum_qk))).view(b, t_q, t_k)

        if t_q != 1:
            # teacher forcing mode - t_q != 1
            key_indices = torch.arange(t_k).repeat(b, t_q).view(b, t_q, t_k)\
                .type(torch.FloatTensor).to(device=device)
            c_curr = torch.FloatTensor([0]).repeat(b, t_q, t_k).to(device=device)
            for i in range(t_q):
                c_temp = torch.sum(c_out[:, :i + 1, :], dim=1)
                c_curr[:, i, :] = c_temp
            out = a_out * torch.exp(-b_out * torch.pow((c_curr - key_indices), 2))
        else:
            # inference mode: t_q = 1
            key_indices = torch.arange(t_k).repeat(b, 1).view(b, 1, t_k)\
                .type(torch.FloatTensor).to(device=device)
            c_out = prev_c + c_out
            out = a_out * torch.exp(-b_out * torch.pow((c_out - key_indices), 2))

    elif self.mode == 'bilinear':
        wk = self.linear_att_w(att_keys).view(b, t_k, n_q)
        out = torch.bmm(att_query, wk.transpose(1, 2))

    elif self.mode == 'dot_prod':
        assert n_q == n_k, 'Dot_prod attention - query, key size must agree!'
        out = torch.bmm(att_query, att_keys.transpose(1, 2))

    return out, c_out
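# A worked toy example of the location-based score used in 'hybrid' mode
# (all numbers assumed): with a > 0, b > 0 and a monotonically increasing
# centre c, the score peaks at the key index closest to c and decays as a
# Gaussian in the index distance.
def _demo_loc_score():
    import torch
    t_k = 5
    j = torch.arange(t_k, dtype=torch.float)    # key indices 0..4
    a, b_coef, c = 1.0, 0.5, 2.3                # amplitude, width, centre
    score = a * torch.exp(-b_coef * (c - j) ** 2)
    assert int(score.argmax()) == 2             # peak at the index nearest c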
def forward_train(self, src, tgt=None, acous_feats=None, acous_lens=None,
                  mode='ST', use_gpu=True, lm_mode='null', lm_model=None):
    """
    mode: ASR  acous -> src
          AE   src   -> src
          ST   acous -> tgt
          MT   src   -> tgt
    """
    # note: adding .type(torch.uint8) to be compatible with pytorch 1.1!
    out_dict = {}

    # check gpu
    global device
    device = check_device(use_gpu)

    # check mode
    mode = mode.upper()
    assert src is not None
    if 'ST' in mode or 'ASR' in mode:
        assert acous_feats is not None
    if 'ST' in mode or 'MT' in mode:
        assert tgt is not None

    if 'ASR' in mode:
        """
        acous -> EN: RNN
        in : length reduced fbk features
        out: w1 w2 w3 <EOS> <PAD> <PAD>    #=6
        """
        emb_src, logps_src, preds_src, lengths = self._encoder_acous(
            acous_feats, acous_lens, device, use_gpu, tgt=src,
            is_training=True, teacher_forcing_ratio=1.0,
            lm_mode=lm_mode, lm_model=lm_model)

        # output dict
        out_dict['emb_asr'] = emb_src  # dynamic
        out_dict['preds_asr'] = preds_src
        out_dict['logps_asr'] = logps_src
        out_dict['lengths_asr'] = lengths

    if 'MT' in mode:
        """
        EN -> DE: Transformer
        src: <BOS> w1 w2 w3 <EOS> <PAD> <PAD>         #=7
        mid:       w1 w2 w3 <EOS> <PAD> <PAD>         #=6
        out:       c1 c2 c3 <EOS> <PAD> <PAD> [dummy] #=7
        note: add average dynamic embedding to static embedding
        """
        # get tgt emb
        tgt_mask, emb_tgt = self._get_tgt_emb(tgt, device)
        # get src emb
        src_trim = self._pre_proc_src(src, device)
        emb_dyn_ave = self.EMB_DYN_AVE
        emb_dyn_ave_expand = emb_dyn_ave.repeat(
            src_trim.size(0), src_trim.size(1), 1).to(device=device)
        src_mask, emb_src, src_mask_input = self._get_src_emb(
            src_trim, emb_dyn_ave_expand, device)
        # encode
        enc_outputs = self._encoder_en(
            emb_src, src_mask=src_mask_input)  # b x len x dim_model
        # decode
        dec_outputs_tgt, logits_tgt, logps_tgt, preds_tgt, _ = \
            self._decoder_de(emb_tgt, enc_outputs,
                             tgt_mask=tgt_mask, src_mask=src_mask_input)

        # output dict
        out_dict['emb_mt'] = emb_src  # combined
        out_dict['preds_mt'] = preds_tgt
        out_dict['logps_mt'] = logps_tgt

    if 'ST' in mode:
        """
        acous -> DE: Transformer
        in : length reduced fbk features
        mid: w1 w2 w3 <EOS> <PAD> <PAD>         #=6
        out: c1 c2 c3 <EOS> <PAD> <PAD> [dummy] #=7
        """
        # get tgt emb
        tgt_mask, emb_tgt = self._get_tgt_emb(tgt, device)
        # run ASR
        if 'ASR' in mode:
            emb_src_dyn = out_dict['emb_asr']
            lengths = out_dict['lengths_asr']
        # else:  # use free running if no 'ASR'
        #     emb_src_dyn, _, _, lengths = self._encoder_acous(
        #         acous_feats, acous_lens, device, use_gpu, tgt=src,
        #         is_training=True, teacher_forcing_ratio=1.0)
        else:
            # use free running if no 'ASR'
            emb_src_dyn, _, _, lengths = self._encoder_acous(
                acous_feats, acous_lens, device, use_gpu,
                is_training=False, teacher_forcing_ratio=0.0,
                lm_mode=lm_mode, lm_model=lm_model)
        # get combined embedding
        src_trim = self._pre_proc_src(src, device)
        _, emb_src, _ = self._get_src_emb(src_trim, emb_src_dyn, device)
        # get mask
        max_len = emb_src.size(1)
        lengths = torch.LongTensor(lengths)
        src_mask_input = (torch.arange(max_len).expand(
            len(lengths), max_len) < lengths.unsqueeze(1)).unsqueeze(1).to(
                device=device)
        # encode
        enc_outputs = self._encoder_en(
            emb_src, src_mask=src_mask_input)  # b x len x dim_model
        # decode
        dec_outputs_tgt, logits_tgt, logps_tgt, preds_tgt, _ = \
            self._decoder_de(emb_tgt, enc_outputs,
                             tgt_mask=tgt_mask, src_mask=src_mask_input)

        # output dict
        out_dict['emb_st'] = emb_src  # combined
        out_dict['preds_st'] = preds_tgt
        out_dict['logps_st'] = logps_tgt

    return out_dict
def forward_translate_refen(self, acous_feats=None, acous_lens=None,
                            src=None, beam_width=1, penalty_factor=1,
                            use_gpu=True, max_seq_len=900, mode='ST',
                            lm_mode='null', lm_model=None):
    """
    run inference - with beam search
    (same output format as in forward_eval)
    """
    # check gpu
    global device
    device = check_device(use_gpu)

    if mode == 'ASR':
        _, _, preds_src, _ = self._encoder_acous(
            acous_feats, acous_lens, device, use_gpu, tgt=src,
            is_training=False, teacher_forcing_ratio=1.0,
            lm_mode=lm_mode, lm_model=lm_model)
        preds = preds_src

    elif mode == 'MT':
        batch = src.size(0)

        # txt encoder
        src_trim = self._pre_proc_src(src, device)
        emb_dyn_ave = self.EMB_DYN_AVE
        emb_dyn_ave_expand = emb_dyn_ave.repeat(
            src_trim.size(0), src_trim.size(1), 1).to(device=device)
        src_mask, emb_src, src_mask_input = self._get_src_emb(
            src_trim, emb_dyn_ave_expand, device)
        enc_outputs = self._encoder_en(emb_src, src_mask=src_mask_input)
        length_in = enc_outputs.size(1)

        # prep
        eos_mask, len_map, preds, enc_outputs_expand, preds_expand, \
            scores_expand, src_mask_input_expand = self._prep_translate(
                batch, beam_width, device, length_in,
                enc_outputs, src_mask_input)

        # loop over sequence length
        for i in range(1, max_seq_len):
            tgt_mask_expand, emb_tgt_expand = self._get_tgt_emb(
                preds_expand, device)
            dec_output_expand, logit_expand, logp_expand, pred_expand, score_expand = \
                self._decoder_de(emb_tgt_expand, enc_outputs_expand,
                                 tgt_mask=tgt_mask_expand,
                                 src_mask=src_mask_input_expand,
                                 beam_width=beam_width)
            scores_expand, preds_expand, eos_mask, len_map, flag = \
                self._step_translate(i, batch, beam_width, device,
                                     dec_output_expand, logp_expand,
                                     pred_expand, score_expand,
                                     preds_expand, scores_expand,
                                     eos_mask, len_map, penalty_factor)
            if flag == 1:
                break

        # select the best candidate
        preds = preds_expand.reshape(
            batch, -1)[:, :max_seq_len].contiguous()  # b x len
        scores = scores_expand.reshape(batch, -1)[:, 0].contiguous()  # b

    elif mode == 'ST':
        batch = acous_feats.size(0)

        # get embedding
        emb_src_dyn, _, preds_src, lengths = self._encoder_acous(
            acous_feats, acous_lens, device, use_gpu, tgt=src,
            is_training=False, teacher_forcing_ratio=1.0,
            lm_mode=lm_mode, lm_model=lm_model)
        src_trim = self._pre_proc_src(src, device)
        _, emb_src, _ = self._get_src_emb(src_trim, emb_src_dyn, device)  # use ref

        # get mask
        max_len = emb_src.size(1)
        lengths = torch.LongTensor(lengths)
        src_mask_input = (torch.arange(max_len).expand(
            len(lengths), max_len) < lengths.unsqueeze(1)).unsqueeze(1).to(
                device=device)

        # encode
        enc_outputs = self._encoder_en(
            emb_src, src_mask=src_mask_input)  # b x len x dim_model
        length_in = enc_outputs.size(1)

        # prep
        eos_mask, len_map, preds, enc_outputs_expand, preds_expand, \
            scores_expand, src_mask_input_expand = self._prep_translate(
                batch, beam_width, device, length_in,
                enc_outputs, src_mask_input)

        # loop over sequence length
        for i in range(1, max_seq_len):
            # get k candidates for each beam, k^2 in total (k=beam_width)
            tgt_mask_expand, emb_tgt_expand = self._get_tgt_emb(
                preds_expand, device)
            dec_output_expand, logit_expand, logp_expand, pred_expand, score_expand = \
                self._decoder_de(emb_tgt_expand, enc_outputs_expand,
                                 tgt_mask=tgt_mask_expand,
                                 src_mask=src_mask_input_expand,
                                 beam_width=beam_width)
            scores_expand, preds_expand, eos_mask, len_map, flag = \
                self._step_translate(i, batch, beam_width, device,
                                     dec_output_expand, logp_expand,
                                     pred_expand, score_expand,
                                     preds_expand, scores_expand,
                                     eos_mask, len_map, penalty_factor)
            if flag == 1:
                break

        # select the best candidate
        preds = preds_expand.reshape(
            batch, -1)[:, :max_seq_len].contiguous()  # b x len
        scores = scores_expand.reshape(batch, -1)[:, 0].contiguous()  # b

    return preds
def forward(self, query, keys, values=None, prev_c=None, use_gpu=True):
    """
    query(out): b x t_q x n_q
    keys(in):   b x t_k x n_k (usually: n_k >= n_v - keys are richer)
    vals(in):   b x t_k x n_v
    context:    b x t_q x output_size
    scores:     b x t_q x t_k
    prev_c: for loc_based attention; None otherwise
    c_out:  for loc_based attention; None otherwise
    in general
        n_q = embedding_dim
        n_k = size of key vectors
        n_v = size of value vectors
    """
    device = check_device(use_gpu)

    if not self.batch_first:
        keys = keys.transpose(0, 1)
        if values is not None:
            values = values.transpose(0, 1)
        if query.dim() == 3:
            query = query.transpose(0, 1)

    if query.dim() == 2:
        single_query = True
        query = query.unsqueeze(1)
    else:
        single_query = False

    values = keys if values is None else values  # b x t_k x n_v/n_k

    b = query.size(0)
    t_k = keys.size(1)
    t_q = query.size(1)

    if hasattr(self, 'linear_q'):
        att_query = self.linear_q(query)
    else:
        att_query = query

    # b x t_q x t_k
    scores, c_out = self.calc_score(att_query, keys, prev_c, use_gpu=use_gpu)

    if self.mask is not None:
        mask = self.mask.unsqueeze(1).expand(b, t_q, t_k)
        scores.masked_fill_(mask, -1e12)

    # normalize the scores OR use hard attention
    if hasattr(self, 'hard_att') and self.hard_att:
        # top_idx = torch.argmax(scores, dim=2)  # unused
        scores_view = scores.view(-1, t_k)
        scores_hard = (scores_view == scores_view.max(
            dim=1, keepdim=True)[0]).view_as(scores)
        scores_hard = scores_hard.type(torch.FloatTensor)
        total_score = torch.sum(scores_hard, dim=2)
        total_score = total_score.view(b, t_q, 1).repeat(
            1, 1, t_k).view_as(scores)
        scores_normalized = (scores_hard / total_score).to(device=device)
    else:
        scores_normalized = F.softmax(scores, dim=2)

    # context = the weighted average of the attention inputs
    # scores_normalized = self.dropout(scores_normalized)  # b x t_q x t_k
    context = torch.bmm(scores_normalized, values)  # b x t_q x n_v

    if hasattr(self, 'linear_out'):
        context = self.linear_out(torch.cat([query, context], 2))
        if self.output_nonlinearity == 'tanh':
            context = torch.tanh(context)
        elif self.output_nonlinearity == 'relu':
            context = F.relu(context, inplace=True)

    if single_query:
        context = context.squeeze(1)
        scores_normalized = scores_normalized.squeeze(1)
    elif not self.batch_first:
        context = context.transpose(0, 1)
        scores_normalized = scores_normalized.transpose(0, 1)

    return context, scores_normalized, c_out
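# A minimal sketch of the mask-then-softmax pattern used above (shapes
# assumed): masked positions are filled with a large negative value so they
# receive (near-)zero attention weight after the softmax.
def _demo_masked_softmax():
    import torch
    import torch.nn.functional as F
    scores = torch.tensor([[[0.5, 1.0, 0.2]]])       # b=1, t_q=1, t_k=3
    mask = torch.tensor([[[False, False, True]]])    # last key is padding
    scores = scores.masked_fill(mask, -1e12)
    weights = F.softmax(scores, dim=2)
    assert weights[0, 0, 2].item() == 0.0            # padded key ignored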
def main():
    # load config
    parser = argparse.ArgumentParser(description='LAS Training')
    parser = load_arguments(parser)
    args = vars(parser.parse_args())
    config = validate_config(args)

    # set random seed
    if config['random_seed'] is not None:
        set_global_seeds(config['random_seed'])

    # record config
    if not os.path.isabs(config['save']):
        config_save_dir = os.path.join(os.getcwd(), config['save'])
    if not os.path.exists(config['save']):
        os.makedirs(config['save'])

    # resume or not
    if config['load'] is not None:
        config_save_dir = os.path.join(config['save'], 'model-cont.cfg')
    else:
        config_save_dir = os.path.join(config['save'], 'model.cfg')
    save_config(config, config_save_dir)

    # construct trainer
    t = Trainer(expt_dir=config['save'],
                load_dir=config['load'],
                batch_size=config['batch_size'],
                minibatch_partition=config['minibatch_partition'],
                checkpoint_every=config['checkpoint_every'],
                print_every=config['print_every'],
                learning_rate=config['learning_rate'],
                eval_with_mask=config['eval_with_mask'],
                scheduled_sampling=config['scheduled_sampling'],
                teacher_forcing_ratio=config['teacher_forcing_ratio'],
                use_gpu=config['use_gpu'],
                max_grad_norm=config['max_grad_norm'],
                max_count_no_improve=config['max_count_no_improve'],
                max_count_num_rollback=config['max_count_num_rollback'],
                keep_num=config['keep_num'],
                normalise_loss=config['normalise_loss'])

    # vocab
    path_vocab_src = config['path_vocab_src']

    # load train set
    train_path_src = config['train_path_src']
    train_acous_path = config['train_acous_path']
    train_set = Dataset(train_path_src,
                        path_vocab_src=path_vocab_src,
                        use_type=config['use_type'],
                        acous_path=train_acous_path,
                        seqrev=config['seqrev'],
                        acous_norm=config['acous_norm'],
                        acous_norm_path=config['acous_norm_path'],
                        max_seq_len=config['max_seq_len'],
                        batch_size=config['batch_size'],
                        acous_max_len=config['acous_max_len'],
                        use_gpu=config['use_gpu'],
                        logger=t.logger)
    vocab_size = len(train_set.vocab_src)

    # load dev set
    if config['dev_path_src']:
        dev_path_src = config['dev_path_src']
        dev_acous_path = config['dev_acous_path']
        dev_set = Dataset(dev_path_src,
                          path_vocab_src=path_vocab_src,
                          use_type=config['use_type'],
                          acous_path=dev_acous_path,
                          acous_norm_path=config['acous_norm_path'],
                          seqrev=config['seqrev'],
                          acous_norm=config['acous_norm'],
                          max_seq_len=config['max_seq_len'],
                          batch_size=config['batch_size'],
                          acous_max_len=config['acous_max_len'],
                          use_gpu=config['use_gpu'],
                          logger=t.logger)
    else:
        dev_set = None

    # construct model
    las_model = LAS(vocab_size,
                    embedding_size=config['embedding_size'],
                    acous_hidden_size=config['acous_hidden_size'],
                    acous_att_mode=config['acous_att_mode'],
                    hidden_size_dec=config['hidden_size_dec'],
                    hidden_size_shared=config['hidden_size_shared'],
                    num_unilstm_dec=config['num_unilstm_dec'],
                    #
                    acous_dim=config['acous_dim'],
                    acous_norm=config['acous_norm'],
                    spec_aug=config['spec_aug'],
                    batch_norm=config['batch_norm'],
                    enc_mode=config['enc_mode'],
                    use_type=config['use_type'],
                    #
                    embedding_dropout=config['embedding_dropout'],
                    dropout=config['dropout'],
                    residual=config['residual'],
                    batch_first=config['batch_first'],
                    max_seq_len=config['max_seq_len'],
                    load_embedding=config['load_embedding'],
                    word2id=train_set.src_word2id,
                    id2word=train_set.src_id2word,
                    use_gpu=config['use_gpu'])

    device = check_device(config['use_gpu'])
    t.logger.info('device:{}'.format(device))
    las_model = las_model.to(device=device)

    # run training
    las_model = t.train(train_set, las_model,
                        num_epochs=config['num_epochs'], dev_set=dev_set)
def forward(self, enc_outputs, src, tgt=None, hidden=None,
            is_training=False, teacher_forcing_ratio=1.0,
            beam_width=1, use_gpu=True):
    """
    Args:
        enc_outputs: [batch_size, max_seq_len, self.hidden_size_enc * 2]
        tgt: list of tgt word_ids
        hidden: initial hidden state
        is_training: whether in eval or train mode
        teacher_forcing_ratio: default at 1 - always teacher forcing
    Returns:
        decoder_outputs: list of step_output
            - log predicted_softmax [batch_size, 1, vocab_size_dec] * (T-1)
        ret_dict
    """
    global device
    device = check_device(use_gpu)

    # 0. init var
    ret_dict = dict()
    ret_dict[KEY_ATTN_SCORE] = []
    decoder_outputs = []
    sequence_symbols = []
    batch_size = enc_outputs.size(0)

    if tgt is None:
        tgt = torch.Tensor([BOS]).repeat(
            (batch_size, self.max_seq_len)).type(
                torch.LongTensor).to(device=device)
    max_seq_len = tgt.size(1)
    lengths = np.array([max_seq_len] * batch_size)

    # 1. convert id to embedding
    emb_tgt = self.embedding_dropout(self.embedder_dec(tgt))

    # 2. att inputs: keys and values
    mask_src = src.data.eq(PAD)
    att_keys = enc_outputs
    att_vals = enc_outputs

    # 3. init hidden states
    dec_hidden = None

    # decoder
    def decode(step, step_output, step_attn):
        """
        Greedy decoding
        Note: it should generate EOS, PAD as used in training tgt
        Args:
            step: step idx
            step_output: log predicted_softmax - [batch_size, 1, vocab_size_dec]
            step_attn: attention scores
                - (batch_size x tgt_len(query_len) x src_len(key_len)
        Returns:
            symbols: most probable symbol_id [batch_size, 1]
        """
        ret_dict[KEY_ATTN_SCORE].append(step_attn)
        decoder_outputs.append(step_output)
        symbols = decoder_outputs[-1].topk(1)[1]
        sequence_symbols.append(symbols)
        eos_batches = torch.max(symbols.data.eq(EOS), symbols.data.eq(PAD))
        # eos_batches = symbols.data.eq(PAD)
        if eos_batches.dim() > 0:
            eos_batches = eos_batches.cpu().view(-1).numpy()
            update_idx = ((lengths > step) & eos_batches) != 0
            lengths[update_idx] = len(sequence_symbols)
        return symbols

    # 4. run dec + att + shared + output
    """
    teacher_forcing_ratio = 1.0 -> always teacher forcing
    E.g.: (shift-by-1)
        emb_tgt      = <s> w1 w2 w3 </s> <pad> <pad> <pad>  [max_seq_len]
        tgt_chunk in = <s> w1 w2 w3 </s> <pad> <pad>        [max_seq_len - 1]
        predicted    =     w1 w2 w3 </s> <pad> <pad> <pad>  [max_seq_len - 1]
    """
    if not is_training:
        use_teacher_forcing = False
    elif random.random() < teacher_forcing_ratio:
        use_teacher_forcing = True
    else:
        use_teacher_forcing = False

    # beam search decoding
    if not is_training and beam_width > 1:
        decoder_outputs, decoder_hidden, metadata = \
            self.beam_search_decoding(att_keys, att_vals, dec_hidden,
                                      mask_src, beam_width=beam_width,
                                      device=device)
        return decoder_outputs, decoder_hidden, metadata

    # greedy search decoding
    tgt_chunk = emb_tgt[:, 0].unsqueeze(1)  # BOS
    cell_value = torch.FloatTensor([0]).repeat(
        batch_size, 1, self.hidden_size_shared).to(device=device)
    prev_c = torch.FloatTensor([0]).repeat(
        batch_size, 1, max_seq_len).to(device=device)

    for idx in range(max_seq_len - 1):
        predicted_logsoftmax, dec_hidden, step_attn, c_out, cell_value = \
            self.forward_step(att_keys, att_vals, tgt_chunk, cell_value,
                              dec_hidden, mask_src, prev_c)
        predicted_logsoftmax = predicted_logsoftmax.squeeze(1)  # [b, vocab_size]
        step_output = predicted_logsoftmax
        symbols = decode(idx, step_output, step_attn)
        prev_c = c_out
        if use_teacher_forcing:
            tgt_chunk = emb_tgt[:, idx + 1].unsqueeze(1)
        else:
            tgt_chunk = self.embedder_dec(symbols)

    ret_dict[KEY_SEQUENCE] = sequence_symbols
    ret_dict[KEY_LENGTH] = lengths.tolist()

    return decoder_outputs, dec_hidden, ret_dict
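# A toy illustration of the shift-by-1 teacher forcing documented above
# (token ids assumed): with teacher forcing the decoder consumes the gold
# prefix, otherwise it consumes its own previous prediction; either way
# inputs and targets are offset by one position.
def _demo_teacher_forcing_shift():
    BOS, EOS, PAD = 2, 3, 0
    tgt = [BOS, 11, 12, 13, EOS, PAD]       # embedded as emb_tgt
    inputs = tgt[:-1]                       # <s> w1 w2 w3 </s>
    targets = tgt[1:]                       # w1 w2 w3 </s> <pad>
    assert len(inputs) == len(targets)      # predicted over max_seq_len - 1 steps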
def __init__(
        self,
        # params
        acous_dim=26,
        acous_hidden_size=256,
        #
        acous_norm=False,
        spec_aug=False,
        batch_norm=False,
        enc_mode='pyramid',
        #
        dropout=0.0,
        batch_first=True,
        use_gpu=False):

    super(Enc, self).__init__()
    device = check_device(use_gpu)

    # define model
    self.acous_dim = acous_dim
    self.acous_hidden_size = acous_hidden_size

    # tuning
    self.acous_norm = acous_norm
    self.spec_aug = spec_aug
    self.batch_norm = batch_norm
    self.enc_mode = enc_mode

    # define operations
    self.dropout = nn.Dropout(dropout)

    # ------ define acous enc -------
    if self.enc_mode == 'pyramid':
        self.acous_enc_l1 = torch.nn.LSTM(
            self.acous_dim, self.acous_hidden_size,
            num_layers=1, batch_first=batch_first,
            bias=True, dropout=dropout, bidirectional=True)
        self.acous_enc_l2 = torch.nn.LSTM(
            self.acous_hidden_size * 4, self.acous_hidden_size,
            num_layers=1, batch_first=batch_first,
            bias=True, dropout=dropout, bidirectional=True)
        self.acous_enc_l3 = torch.nn.LSTM(
            self.acous_hidden_size * 4, self.acous_hidden_size,
            num_layers=1, batch_first=batch_first,
            bias=True, dropout=dropout, bidirectional=True)
        self.acous_enc_l4 = torch.nn.LSTM(
            self.acous_hidden_size * 4, self.acous_hidden_size,
            num_layers=1, batch_first=batch_first,
            bias=True, dropout=dropout, bidirectional=True)
        if self.batch_norm:
            self.bn1 = nn.BatchNorm1d(self.acous_hidden_size * 2)
            self.bn2 = nn.BatchNorm1d(self.acous_hidden_size * 2)
            self.bn3 = nn.BatchNorm1d(self.acous_hidden_size * 2)
            self.bn4 = nn.BatchNorm1d(self.acous_hidden_size * 2)
    elif self.enc_mode == 'cnn':
        # todo
        pass
def forward(self, acous_feats, acous_lens=None, is_training=False,
        hidden=None, use_gpu=False):
    """
    Args:
        acous_feats: list of acoustic features [b x acous_len x ?]
    """
    device = check_device(use_gpu)

    batch_size = acous_feats.size(0)
    acous_len = acous_feats.size(1)

    # pre-process acoustics
    if is_training:
        acous_feats = self.pre_process_acous(acous_feats)

    # run acous enc - pyramidal
    acous_hidden_init = None
    if self.enc_mode == 'pyramid':
        # layer1
        # pack to rnn packed seq obj
        # (lengths rounded up to a multiple of 8 to survive the three
        # halvings below; a length already divisible by 8 gains a full
        # extra 8 frames of padding)
        acous_lens_l1 = torch.cat(
            [elem + 8 - elem % 8 for elem in acous_lens])
        acous_feats_pack = torch.nn.utils.rnn.pack_padded_sequence(
            acous_feats, acous_lens_l1, batch_first=True,
            enforce_sorted=False)
        # run lstm
        acous_outputs_l1_pack, acous_hidden_l1 = self.acous_enc_l1(
            acous_feats_pack, acous_hidden_init)  # b x acous_len x 2dim
        # unpack
        acous_outputs_l1, _ = torch.nn.utils.rnn.pad_packed_sequence(
            acous_outputs_l1_pack, batch_first=True)
        # dropout
        acous_outputs_l1 = self.dropout(acous_outputs_l1)\
            .reshape(batch_size, acous_len, acous_outputs_l1.size(-1))
        # batch norm
        if self.batch_norm:
            acous_outputs_l1 = self.bn1(
                acous_outputs_l1.permute(0, 2, 1)).permute(0, 2, 1)
        # reduce length: concat adjacent frames
        acous_inputs_l2 = acous_outputs_l1.reshape(
            batch_size, int(acous_len / 2),
            2 * acous_outputs_l1.size(-1))  # b x acous_len/2 x 4dim

        # layer2
        # integer division keeps the lengths integral for
        # pack_padded_sequence (plain '/' yields floats on newer pytorch)
        acous_lens_l2 = acous_lens_l1 // 2
        acous_inputs_l2_pack = torch.nn.utils.rnn.pack_padded_sequence(
            acous_inputs_l2, acous_lens_l2, batch_first=True,
            enforce_sorted=False)
        acous_outputs_l2_pack, acous_hidden_l2 = self.acous_enc_l2(
            acous_inputs_l2_pack, acous_hidden_init)  # b x acous_len/2 x 2dim
        acous_outputs_l2, _ = torch.nn.utils.rnn.pad_packed_sequence(
            acous_outputs_l2_pack, batch_first=True)
        acous_outputs_l2 = self.dropout(acous_outputs_l2)\
            .reshape(batch_size, int(acous_len / 2), acous_outputs_l2.size(-1))
        if self.batch_norm:
            acous_outputs_l2 = self.bn2(
                acous_outputs_l2.permute(0, 2, 1)).permute(0, 2, 1)
        acous_inputs_l3 = acous_outputs_l2.reshape(
            batch_size, int(acous_len / 4),
            2 * acous_outputs_l2.size(-1))  # b x acous_len/4 x 4dim

        # layer3
        acous_lens_l3 = acous_lens_l2 // 2
        acous_inputs_l3_pack = torch.nn.utils.rnn.pack_padded_sequence(
            acous_inputs_l3, acous_lens_l3, batch_first=True,
            enforce_sorted=False)
        acous_outputs_l3_pack, acous_hidden_l3 = self.acous_enc_l3(
            acous_inputs_l3_pack, acous_hidden_init)  # b x acous_len/4 x 2dim
        acous_outputs_l3, _ = torch.nn.utils.rnn.pad_packed_sequence(
            acous_outputs_l3_pack, batch_first=True)
        acous_outputs_l3 = self.dropout(acous_outputs_l3)\
            .reshape(batch_size, int(acous_len / 4), acous_outputs_l3.size(-1))
        if self.batch_norm:
            acous_outputs_l3 = self.bn3(
                acous_outputs_l3.permute(0, 2, 1)).permute(0, 2, 1)
        acous_inputs_l4 = acous_outputs_l3.reshape(
            batch_size, int(acous_len / 8),
            2 * acous_outputs_l3.size(-1))  # b x acous_len/8 x 4dim

        # layer4
        acous_lens_l4 = acous_lens_l3 // 2
        acous_inputs_l4_pack = torch.nn.utils.rnn.pack_padded_sequence(
            acous_inputs_l4, acous_lens_l4, batch_first=True,
            enforce_sorted=False)
        acous_outputs_l4_pack, acous_hidden_l4 = self.acous_enc_l4(
            acous_inputs_l4_pack, acous_hidden_init)  # b x acous_len/8 x 2dim
        acous_outputs_l4, _ = torch.nn.utils.rnn.pad_packed_sequence(
            acous_outputs_l4_pack, batch_first=True)
        acous_outputs_l4 = self.dropout(acous_outputs_l4)\
            .reshape(batch_size, int(acous_len / 8), acous_outputs_l4.size(-1))
        if self.batch_norm:
            acous_outputs_l4 = self.bn4(
                acous_outputs_l4.permute(0, 2, 1)).permute(0, 2, 1)

        acous_outputs = acous_outputs_l4

    elif self.enc_mode == 'cnn':
        # todo
        pass

    return acous_outputs
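# A minimal standalone sketch of the pyramidal time reduction above: the
# three reshapes between the four LSTMs each halve the sequence length and
# double the feature dimension, so frames are reduced 8x overall (hence the
# lengths rounded up to a multiple of 8). Shapes below are illustrative.
import torch

b, T, d = 2, 64, 512
x = torch.randn(b, T, d)
lens = torch.tensor([64, 40])
for _ in range(3):
    x = x.reshape(b, x.size(1) // 2, 2 * x.size(2))
    lens = lens // 2
print(x.shape)   # torch.Size([2, 8, 4096])
print(lens)      # tensor([8, 5])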
def forward_train(self, src, tgt=None, acous_feats=None, acous_lens=None,
        mode='ST', use_gpu=True, lm_mode='null', lm_model=None):
    """
    mode:   ASR     acous -> src
            AE      src -> src
            ST      acous -> tgt
            MT      src -> tgt
    """
    # for backward compatibility
    self.check_var('perturb_emb', False)

    # note: adding .type(torch.uint8) to be compatible with pytorch 1.1!
    out_dict = {}

    # check gpu
    global device
    device = check_device(use_gpu)

    # check mode
    mode = mode.upper()
    assert src is not None
    if 'ST' in mode or 'ASR' in mode:
        assert acous_feats is not None
    if 'ST' in mode or 'MT' in mode:
        assert tgt is not None

    if 'ASR' in mode:
        """
        acous -> EN: RNN
            in : length reduced fbk features
            out: w1 w2 w3 <EOS> <PAD> <PAD>     #=6
        """
        emb_src, logps_src, preds_src, lengths = self._encoder_acous(
            acous_feats, acous_lens, device, use_gpu, tgt=src,
            is_training=True, teacher_forcing_ratio=1.0,
            lm_mode=lm_mode, lm_model=lm_model)

        # output dict
        out_dict['emb_asr'] = emb_src
        out_dict['preds_asr'] = preds_src
        out_dict['logps_asr'] = logps_src
        out_dict['lengths_asr'] = lengths

    if 'AE' in mode:
        """
        EN -> EN: Embedder
            src: <BOS> w1 w2 w3 <EOS> <PAD> <PAD>   #=7
            mid: w1 w2 w3 <EOS> <PAD> <PAD>         #=6
            out: w1 w2 w3 <EOS> <PAD> <PAD>         #=6
        """
        if 'ASR' in mode:
            src_trim = out_dict['preds_asr'].squeeze(2)
        else:
            # run asr: by default
            _, _, preds_src, _ = self._encoder_acous(
                acous_feats, acous_lens, device, use_gpu, tgt=src,
                is_training=True, teacher_forcing_ratio=1.0,
                lm_mode=lm_mode, lm_model=lm_model)
            src_trim = preds_src.squeeze(2)
            # from src: use with NLL loss
            # src_trim = self._pre_proc_src(src, device)

        src_mask, emb_src, src_mask_input = self._get_src_emb(src_trim, device)
        logits_src, logps_src, preds_src, _ = self._decoder_en(emb_src)

        # output dict
        out_dict['emb_ae'] = emb_src
        out_dict['refs_ae'] = src_trim
        out_dict['preds_ae'] = preds_src
        out_dict['logps_ae'] = logps_src

    if 'MT' in mode:
        """
        EN -> DE: Transformer
            src: <BOS> w1 w2 w3 <EOS> <PAD> <PAD>       #=7
            mid: w1 w2 w3 <EOS> <PAD> <PAD>             #=6
            out: c1 c2 c3 <EOS> <PAD> <PAD> [dummy]     #=7
        optional: add perturbation to embeddings
        """
        # get tgt emb
        tgt_mask, emb_tgt = self._get_tgt_emb(tgt, device)
        # get static src emb
        src_trim = self._pre_proc_src(src, device)
        src_mask, emb_src, src_mask_input = self._get_src_emb(src_trim, device)
        # perturb emb
        if self.perturb_emb:
            emb_src = self._perturb_emb(emb_src, device)

        # encode
        enc_outputs = self._encoder_en(
            emb_src, src_mask=src_mask_input)  # b x len x dim_model
        # decode
        dec_outputs_tgt, logits_tgt, logps_tgt, preds_tgt, _ = \
            self._decoder_de(emb_tgt, enc_outputs,
                tgt_mask=tgt_mask, src_mask=src_mask_input)

        # output dict
        out_dict['emb_mt'] = emb_src
        out_dict['preds_mt'] = preds_tgt
        out_dict['logps_mt'] = logps_tgt

    if 'ST' in mode:
        """
        acous -> DE: Transformer
            in : length reduced fbk features
            mid: w1 w2 w3 <EOS> <PAD> <PAD>             #=6
            out: c1 c2 c3 <EOS> <PAD> <PAD> [dummy]     #=7
        """
        # get tgt emb
        tgt_mask, emb_tgt = self._get_tgt_emb(tgt, device)
        # get dynamic src emb
        if 'ASR' in mode:
            emb_src = out_dict['emb_asr']
            lengths = out_dict['lengths_asr']
        else:
            # use free running if no 'ASR'
            emb_src, _, _, lengths = self._encoder_acous(
                acous_feats, acous_lens, device, use_gpu,
                is_training=False, teacher_forcing_ratio=0.0,
                lm_mode=lm_mode, lm_model=lm_model)

        # get mask
        max_len = emb_src.size(1)
        lengths = torch.LongTensor(lengths)
        src_mask_input = (torch.arange(max_len).expand(
            len(lengths), max_len) < lengths.unsqueeze(1))\
            .unsqueeze(1).to(device=device)

        # encode
        enc_outputs = self._encoder_en(
            emb_src, src_mask=src_mask_input)  # b x len x dim_model
        # decode
        dec_outputs_tgt, logits_tgt, logps_tgt, preds_tgt, _ = \
            self._decoder_de(emb_tgt, enc_outputs,
                tgt_mask=tgt_mask, src_mask=src_mask_input)

        # output dict
        out_dict['emb_st'] = emb_src
        out_dict['preds_st'] = preds_tgt
        out_dict['logps_st'] = logps_tgt

    return out_dict
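# A minimal standalone sketch of the length-to-mask trick used above:
# broadcasting torch.arange against per-example lengths yields a boolean
# mask that is True over real positions and False over padding.
import torch

lengths = torch.LongTensor([3, 5, 2])
max_len = 5
mask = torch.arange(max_len).expand(len(lengths), max_len) < lengths.unsqueeze(1)
print(mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True],
#         [ True,  True, False, False, False]])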
def main():
    # load config
    parser = argparse.ArgumentParser(description='Seq2seq Training')
    parser = load_arguments(parser)
    args = vars(parser.parse_args())
    config = validate_config(args)

    # set random seed
    if config['random_seed'] is not None:
        set_global_seeds(config['random_seed'])

    # record config
    if not os.path.isabs(config['save']):
        config_save_dir = os.path.join(os.getcwd(), config['save'])
    if not os.path.exists(config['save']):
        os.makedirs(config['save'])

    # loading old models
    if config['load']:
        print('loading {} ...'.format(config['load']))
        config_save_dir = os.path.join(config['save'], 'model-cont.cfg')
    else:
        config_save_dir = os.path.join(config['save'], 'model.cfg')
    save_config(config, config_save_dir)

    # construct trainer
    t = Trainer(expt_dir=config['save'],
        load_dir=config['load'],
        load_mode=config['load_mode'],
        batch_size=config['batch_size'],
        checkpoint_every=config['checkpoint_every'],
        print_every=config['print_every'],
        eval_mode=config['eval_mode'],
        eval_metric=config['eval_metric'],
        learning_rate=config['learning_rate'],
        learning_rate_init=config['learning_rate_init'],
        lr_warmup_steps=config['lr_warmup_steps'],
        eval_with_mask=config['eval_with_mask'],
        use_gpu=config['use_gpu'],
        gpu_id=config['gpu_id'],
        max_grad_norm=config['max_grad_norm'],
        max_count_no_improve=config['max_count_no_improve'],
        max_count_num_rollback=config['max_count_num_rollback'],
        keep_num=config['keep_num'],
        normalise_loss=config['normalise_loss'],
        minibatch_split=config['minibatch_split'])

    # load train set
    train_path_src = config['train_path_src']
    train_path_tgt = config['train_path_tgt']
    path_vocab_src = config['path_vocab_src']
    path_vocab_tgt = config['path_vocab_tgt']
    train_set = Dataset(train_path_src, train_path_tgt,
        path_vocab_src=path_vocab_src, path_vocab_tgt=path_vocab_tgt,
        seqrev=config['seqrev'],
        max_seq_len=config['max_seq_len'],
        batch_size=config['batch_size'],
        data_ratio=config['data_ratio'],
        use_gpu=config['use_gpu'],
        logger=t.logger,
        use_type=config['use_type'],
        use_type_src=config['use_type_src'])

    vocab_size_enc = len(train_set.vocab_src)
    vocab_size_dec = len(train_set.vocab_tgt)

    # load dev set
    if config['dev_path_src'] and config['dev_path_tgt']:
        dev_path_src = config['dev_path_src']
        dev_path_tgt = config['dev_path_tgt']
        dev_set = Dataset(dev_path_src, dev_path_tgt,
            path_vocab_src=path_vocab_src, path_vocab_tgt=path_vocab_tgt,
            seqrev=config['seqrev'],
            max_seq_len=config['max_seq_len'],
            batch_size=config['batch_size'],
            use_gpu=config['use_gpu'],
            logger=t.logger,
            use_type=config['use_type'],
            use_type_src=config['use_type_src'])
    else:
        dev_set = None

    # construct model
    seq2seq = Seq2seq(vocab_size_enc, vocab_size_dec,
        share_embedder=config['share_embedder'],
        enc_embedding_size=config['embedding_size_enc'],
        dec_embedding_size=config['embedding_size_dec'],
        load_embedding_src=config['load_embedding_src'],
        load_embedding_tgt=config['load_embedding_tgt'],
        num_heads=config['num_heads'],
        dim_model=config['dim_model'],
        dim_feedforward=config['dim_feedforward'],
        enc_layers=config['enc_layers'],
        dec_layers=config['dec_layers'],
        embedding_dropout=config['embedding_dropout'],
        dropout=config['dropout'],
        max_seq_len=config['max_seq_len'],
        act=config['act'],
        enc_word2id=train_set.src_word2id,
        dec_word2id=train_set.tgt_word2id,
        enc_id2word=train_set.src_id2word,
        dec_id2word=train_set.tgt_id2word,
        transformer_type=config['transformer_type'])

    t.logger.info("total #parameters: {}".format(sum(
        p.numel() for p in seq2seq.parameters() if p.requires_grad)))

    device = check_device(config['use_gpu'])
    t.logger.info('device: {}'.format(device))
    seq2seq = seq2seq.to(device=device)

    # run training
    seq2seq = t.train(train_set, seq2seq,
        num_epochs=config['num_epochs'], dev_set=dev_set,
        grab_memory=config['grab_memory'])
def forward_eval(self, src=None, acous_feats=None, acous_lens=None,
        mode='ST', use_gpu=True, lm_mode='null', lm_model=None):
    """
    beam_width = 1
    note: the output sequence differs from training
    if using the transformer model
    """
    out_dict = {}

    # check gpu
    global device
    device = check_device(use_gpu)

    # check mode
    mode = mode.upper()
    if 'ST' in mode or 'ASR' in mode:
        assert acous_feats is not None
        batch = acous_feats.size(0)
    if 'MT' in mode or 'AE' in mode:
        assert src is not None
        batch = src.size(0)

    length_out_src = self.max_seq_len_src
    length_out_tgt = self.max_seq_len_tgt

    if 'ASR' in mode:
        """
        acous -> EN: RNN
            in : length reduced fbk features
            out: w1 w2 w3 <EOS> <PAD> <PAD>     #=6
        """
        # run asr
        emb_src, logps_src, preds_src, lengths = self._encoder_acous(
            acous_feats, acous_lens, device, use_gpu,
            is_training=False, teacher_forcing_ratio=0.0,
            lm_mode=lm_mode, lm_model=lm_model)

        # output dict
        out_dict['emb_asr'] = emb_src
        out_dict['preds_asr'] = preds_src
        out_dict['logps_asr'] = logps_src
        out_dict['lengths_asr'] = lengths

    if 'MT' in mode:
        """
        EN -> DE: Transformer
            in : <BOS> w1 w2 w3 <EOS> <PAD> <PAD>   #=7
            mid: w1 w2 w3 <EOS> <PAD> <PAD> <PAD>   #=7
            out: <BOS> c1 c2 c3 <EOS> <PAD> <PAD>   #=7
        """
        # get src emb
        src_trim = self._pre_proc_src(src, device)
        emb_dyn_ave = self.EMB_DYN_AVE
        emb_dyn_ave_expand = emb_dyn_ave.repeat(
            src_trim.size(0), src_trim.size(1), 1).to(device=device)
        src_mask, emb_src, src_mask_input = self._get_src_emb(
            src_trim, emb_dyn_ave_expand, device)

        # encoder
        enc_outputs = self._encoder_en(
            emb_src, src_mask=src_mask_input)  # b x len x dim_model

        # prep
        # (note: preds_save_tgt appears twice in this unpack;
        # the later value overwrites the earlier one)
        eos_mask_tgt, logps_tgt, dec_outputs_tgt, preds_save_tgt, \
            preds_tgt, preds_save_tgt = self._prep_eval(
                batch, length_out_tgt, self.dec_vocab_size, device)

        for i in range(1, self.max_seq_len_tgt):
            tgt_mask, emb_tgt = self._get_tgt_emb(preds_tgt, device)
            dec_output_tgt, logit_tgt, logp_tgt, pred_tgt, _ = \
                self._decoder_de(emb_tgt, enc_outputs,
                    tgt_mask=tgt_mask, src_mask=src_mask_input)
            eos_mask_tgt, dec_outputs_tgt, logps_tgt, preds_save_tgt, \
                preds_tgt, flag = self._step_eval(
                    i, eos_mask_tgt, dec_output_tgt, logp_tgt, pred_tgt,
                    dec_outputs_tgt, logps_tgt, preds_save_tgt, preds_tgt,
                    batch, length_out_tgt)
            if flag == 1:
                break

        # output dict
        out_dict['emb_mt'] = emb_src
        out_dict['preds_mt'] = preds_tgt
        out_dict['logps_mt'] = logps_tgt

    if 'ST' in mode:
        """
        acous -> DE: Transformer
            in : length reduced fbk features
            out: <BOS> c1 c2 c3 <EOS> <PAD> <PAD>   #=7
        """
        # get embedding
        if 'ASR' in mode:
            preds_src = out_dict['preds_asr']
            emb_src_dyn = out_dict['emb_asr']
            lengths = out_dict['lengths_asr']
        else:
            emb_src_dyn, _, preds_src, lengths = self._encoder_acous(
                acous_feats, acous_lens, device, use_gpu,
                is_training=False, teacher_forcing_ratio=0.0,
                lm_mode=lm_mode, lm_model=lm_model)
        _, emb_src, _ = self._get_src_emb(
            preds_src.squeeze(2), emb_src_dyn, device)

        # get mask
        max_len = emb_src.size(1)
        lengths = torch.LongTensor(lengths)
        src_mask_input = (torch.arange(max_len).expand(
            len(lengths), max_len) < lengths.unsqueeze(1))\
            .unsqueeze(1).to(device=device)

        # encode
        enc_outputs = self._encoder_en(
            emb_src, src_mask=src_mask_input)  # b x len x dim_model

        # prep
        eos_mask_tgt, logps_tgt, dec_outputs_tgt, preds_save_tgt, \
            preds_tgt, preds_save_tgt = self._prep_eval(
                batch, length_out_tgt, self.dec_vocab_size, device)

        for i in range(1, self.max_seq_len_tgt):
            tgt_mask, emb_tgt = self._get_tgt_emb(preds_tgt, device)
            dec_output_tgt, logit_tgt, logp_tgt, pred_tgt, _ = \
                self._decoder_de(emb_tgt, enc_outputs,
                    tgt_mask=tgt_mask, src_mask=src_mask_input)
            eos_mask_tgt, dec_outputs_tgt, logps_tgt, preds_save_tgt, \
                preds_tgt, flag = self._step_eval(
                    i, eos_mask_tgt, dec_output_tgt, logp_tgt, pred_tgt,
                    dec_outputs_tgt, logps_tgt, preds_save_tgt, preds_tgt,
                    batch, length_out_tgt)
            if flag == 1:
                break

        # output dict
        out_dict['emb_st'] = emb_src
        out_dict['preds_st'] = preds_tgt
        out_dict['logps_st'] = logps_tgt

    return out_dict
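# A minimal standalone sketch (hypothetical decoder_step, not the repo's
# _decoder_de/_step_eval API) of the greedy transformer decoding loop in
# forward_eval: grow the prediction prefix one token at a time, re-run the
# decoder on the whole prefix, and stop early once every sequence in the
# batch has emitted EOS (analogous to the flag == 1 break above).
import torch

def greedy_decode(decoder_step, batch, max_len, bos_id, eos_id, device):
    # decoder_step(prefix) -> per-position logits [batch, cur_len, vocab]
    preds = torch.full((batch, 1), bos_id, dtype=torch.long, device=device)
    finished = torch.zeros(batch, dtype=torch.bool, device=device)
    for _ in range(1, max_len):
        logits = decoder_step(preds)
        next_tok = logits[:, -1].argmax(-1)           # greedy pick
        next_tok = next_tok.masked_fill(finished, eos_id)
        preds = torch.cat([preds, next_tok.unsqueeze(1)], dim=1)
        finished |= next_tok.eq(eos_id)
        if finished.all():
            break
    return preds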
def main():
    # load config
    parser = argparse.ArgumentParser(description='Seq2seq Evaluation')
    parser = load_arguments(parser)
    args = vars(parser.parse_args())
    config = validate_config(args)

    # load src-tgt pair
    test_path_src = config['test_path_src']
    test_path_tgt = test_path_src
    test_path_out = config['test_path_out']
    load_dir = config['load']
    max_seq_len = config['max_seq_len']
    batch_size = config['batch_size']
    beam_width = config['beam_width']
    use_gpu = config['use_gpu']
    seqrev = config['seqrev']
    use_type = config['use_type']

    # set test mode: 1 = translate; 2 = plot; 3 = save comb ckpt
    MODE = config['eval_mode']
    if MODE != 3:
        if not os.path.exists(test_path_out):
            os.makedirs(test_path_out)
        config_save_dir = os.path.join(test_path_out, 'eval.cfg')
        save_config(config, config_save_dir)

    # check device:
    device = check_device(use_gpu)
    print('device: {}'.format(device))

    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
    model = resume_checkpoint.model.to(device)
    vocab_src = resume_checkpoint.input_vocab
    vocab_tgt = resume_checkpoint.output_vocab
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model loaded')

    # combine model
    if config['combine_path'] is not None:
        model = combine_weights(config['combine_path'])

    # load test_set
    test_set = Dataset(test_path_src, test_path_tgt,
        vocab_src_list=vocab_src, vocab_tgt_list=vocab_tgt,
        seqrev=seqrev,
        max_seq_len=900,
        batch_size=batch_size,
        use_gpu=use_gpu,
        use_type=use_type)
    print('Test dir: {}'.format(test_path_src))
    print('Testset loaded')
    sys.stdout.flush()

    # run eval
    if MODE == 1:
        translate(test_set, model, test_path_out, use_gpu,
            max_seq_len, beam_width, device, seqrev=seqrev)

    elif MODE == 2:
        # output posterior
        translate_logp(test_set, model, test_path_out, use_gpu,
            max_seq_len, device, seqrev=seqrev)

    elif MODE == 3:
        # save combined model
        ckpt = Checkpoint(model=model,
            optimizer=None, epoch=0, step=0,
            input_vocab=test_set.vocab_src,
            output_vocab=test_set.vocab_tgt)
        saved_path = ckpt.save_customise(
            os.path.join(config['combine_path'].strip('/') + '-combine',
                'combine'))
        log_ckpts(config['combine_path'],
            config['combine_path'].strip('/') + '-combine')
        print('saving at {} ... '.format(saved_path))
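# A minimal standalone sketch of checkpoint weight averaging, one plausible
# reading of the combine_weights call above (the repo's actual implementation
# is not shown here): load several state dicts and average each parameter
# tensor elementwise.
import torch

def average_checkpoints(paths):
    avg = None
    for p in paths:
        state = torch.load(p, map_location='cpu')
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg[k] += v.float()
    return {k: v / len(paths) for k, v in avg.items()}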
def main():
    # load config
    parser = argparse.ArgumentParser(description='Evaluation')
    parser = load_arguments(parser)
    args = vars(parser.parse_args())
    config = validate_config(args)

    # load src-tgt pair
    test_path_src = config['test_path_src']
    test_path_tgt = config['test_path_tgt']
    if test_path_tgt is None:
        test_path_tgt = test_path_src
    test_path_out = config['test_path_out']
    test_acous_path = config['test_acous_path']
    acous_norm_path = config['acous_norm_path']

    load_dir = config['load']
    max_seq_len = config['max_seq_len']
    batch_size = config['batch_size']
    beam_width = config['beam_width']
    use_gpu = config['use_gpu']
    seqrev = config['seqrev']
    use_type = config['use_type']

    # set test mode
    MODE = config['eval_mode']
    if MODE != 2:
        if not os.path.exists(test_path_out):
            os.makedirs(test_path_out)
        config_save_dir = os.path.join(test_path_out, 'eval.cfg')
        save_config(config, config_save_dir)

    # check device:
    device = check_device(use_gpu)
    print('device: {}'.format(device))

    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
    model = resume_checkpoint.model.to(device)
    vocab_src = resume_checkpoint.input_vocab
    vocab_tgt = resume_checkpoint.output_vocab
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model loaded')

    # combine model
    if config['combine_path'] is not None:
        model = combine_weights(config['combine_path'])

    # load test_set
    test_set = Dataset(
        path_src=test_path_src,
        path_tgt=test_path_tgt,
        vocab_src_list=vocab_src,
        vocab_tgt_list=vocab_tgt,
        use_type=use_type,
        acous_path=test_acous_path,
        seqrev=seqrev,
        acous_norm=config['acous_norm'],
        acous_norm_path=config['acous_norm_path'],
        acous_max_len=6000,       # max 50k for mustc trainset
        max_seq_len_src=900,
        max_seq_len_tgt=900,      # max 2.5k for mustc trainset
        batch_size=batch_size,
        mode='ST',
        use_gpu=use_gpu)
    print('Test dir: {}'.format(test_path_src))
    print('Testset loaded')
    sys.stdout.flush()

    # gen_mode format: '{AE|ASR|MT|ST}-{REF|HYP}'
    if len(config['gen_mode'].split('-')) == 2:
        gen_mode = config['gen_mode'].split('-')[0]
        history = config['gen_mode'].split('-')[1]
    elif len(config['gen_mode'].split('-')) == 1:
        gen_mode = config['gen_mode']
        history = 'HYP'

    # add external language model
    lm_mode = config['lm_mode']

    # run eval:
    if MODE == 1:
        translate(test_set, model, test_path_out, use_gpu,
            max_seq_len, beam_width, device, seqrev=seqrev,
            gen_mode=gen_mode, lm_mode=lm_mode, history=history)

    elif MODE == 2:
        # save combined model
        ckpt = Checkpoint(model=model,
            optimizer=None, epoch=0, step=0,
            input_vocab=test_set.vocab_src,
            output_vocab=test_set.vocab_tgt)
        saved_path = ckpt.save_customise(
            os.path.join(config['combine_path'].strip('/') + '-combine',
                'combine'))
        log_ckpts(config['combine_path'],
            config['combine_path'].strip('/') + '-combine')
        print('saving at {} ... '.format(saved_path))

    elif MODE == 3:
        plot_emb(test_set, model, test_path_out, use_gpu, max_seq_len, device)

    elif MODE == 4:
        gather_emb(test_set, model, test_path_out, use_gpu, max_seq_len, device)

    elif MODE == 5:
        compute_kl(test_set, model, test_path_out, use_gpu, max_seq_len, device)
def main():
    # load config
    parser = argparse.ArgumentParser(description='Seq2seq Training')
    parser = load_arguments(parser)
    args = vars(parser.parse_args())
    config = validate_config(args)

    # set random seed
    if config['random_seed'] is not None:
        set_global_seeds(config['random_seed'])

    # record config
    if not os.path.isabs(config['save']):
        config_save_dir = os.path.join(os.getcwd(), config['save'])
    if not os.path.exists(config['save']):
        os.makedirs(config['save'])

    # resume or not
    if config['load']:
        resume = True
        print('resuming {} ...'.format(config['load']))
        config_save_dir = os.path.join(config['save'], 'model-cont.cfg')
    else:
        resume = False
        config_save_dir = os.path.join(config['save'], 'model.cfg')
    save_config(config, config_save_dir)

    # construct trainer
    t = Trainer(expt_dir=config['save'],
        load_dir=config['load'],
        batch_size=config['batch_size'],
        checkpoint_every=config['checkpoint_every'],
        print_every=config['print_every'],
        learning_rate=config['learning_rate'],
        eval_with_mask=config['eval_with_mask'],
        scheduled_sampling=config['scheduled_sampling'],
        teacher_forcing_ratio=config['teacher_forcing_ratio'],
        use_gpu=config['use_gpu'],
        max_grad_norm=config['max_grad_norm'],
        max_count_no_improve=config['max_count_no_improve'],
        max_count_num_rollback=config['max_count_num_rollback'],
        keep_num=config['keep_num'],
        normalise_loss=config['normalise_loss'],
        minibatch_split=config['minibatch_split'])

    # load train set
    train_path_src = config['train_path_src']
    train_path_tgt = config['train_path_tgt']
    path_vocab_src = config['path_vocab_src']
    path_vocab_tgt = config['path_vocab_tgt']
    train_set = Dataset(train_path_src, train_path_tgt,
        path_vocab_src=path_vocab_src, path_vocab_tgt=path_vocab_tgt,
        seqrev=config['seqrev'],
        max_seq_len=config['max_seq_len'],
        batch_size=config['batch_size'],
        use_gpu=config['use_gpu'],
        logger=t.logger,
        use_type=config['use_type'])

    vocab_size_enc = len(train_set.vocab_src)
    vocab_size_dec = len(train_set.vocab_tgt)

    # load dev set
    if config['dev_path_src'] and config['dev_path_tgt']:
        dev_path_src = config['dev_path_src']
        dev_path_tgt = config['dev_path_tgt']
        dev_set = Dataset(dev_path_src, dev_path_tgt,
            path_vocab_src=path_vocab_src, path_vocab_tgt=path_vocab_tgt,
            seqrev=config['seqrev'],
            max_seq_len=config['max_seq_len'],
            batch_size=config['batch_size'],
            use_gpu=config['use_gpu'],
            logger=t.logger,
            use_type=config['use_type'])
    else:
        dev_set = None

    # construct model
    seq2seq = Seq2seq(vocab_size_enc, vocab_size_dec,
        share_embedder=config['share_embedder'],
        embedding_size_enc=config['embedding_size_enc'],
        embedding_size_dec=config['embedding_size_dec'],
        embedding_dropout=config['embedding_dropout'],
        hidden_size_enc=config['hidden_size_enc'],
        num_bilstm_enc=config['num_bilstm_enc'],
        num_unilstm_enc=config['num_unilstm_enc'],
        hidden_size_dec=config['hidden_size_dec'],
        num_unilstm_dec=config['num_unilstm_dec'],
        hidden_size_att=config['hidden_size_att'],
        hidden_size_shared=config['hidden_size_shared'],
        dropout=config['dropout'],
        residual=config['residual'],
        batch_first=config['batch_first'],
        max_seq_len=config['max_seq_len'],
        load_embedding_src=config['load_embedding_src'],
        load_embedding_tgt=config['load_embedding_tgt'],
        src_word2id=train_set.src_word2id,
        tgt_word2id=train_set.tgt_word2id,
        src_id2word=train_set.src_id2word,
        tgt_id2word=train_set.tgt_id2word,
        att_mode=config['att_mode'])

    device = check_device(config['use_gpu'])
    t.logger.info('device: {}'.format(device))
    seq2seq = seq2seq.to(device=device)

    # run training
    seq2seq = t.train(train_set, seq2seq,
        num_epochs=config['num_epochs'], resume=resume, dev_set=dev_set)