Example #1
    def forward(self, src, src_lens=None, hidden=None, use_gpu=True):
        """
			Args:
				src: list of src word_ids [batch_size, seq_len, word_ids]
		"""

        # import pdb; pdb.set_trace()
        # src_lens=None

        device = check_device(use_gpu)

        # src mask
        mask_src = src.data.eq(PAD)
        batch_size = src.size(0)
        seq_len = src.size(1)

        # convert id to embedding
        emb_src = self.embedding_dropout(self.embedder_enc(src))

        # run enc
        # bilstm: pack padded seq. for bilstm (remove impact of padding)
        if src_lens is not None:
            src_lens = torch.cat(src_lens)
            emb_src_pack = torch.nn.utils.rnn.pack_padded_sequence(
                emb_src, src_lens, batch_first=True, enforce_sorted=False)
            enc_outputs_pack, enc_hidden = self.enc(emb_src_pack, hidden)
            enc_outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(
                enc_outputs_pack, batch_first=True)
        else:
            enc_outputs, enc_hidden = self.enc(emb_src, hidden)
        # unilstm
        enc_outputs = self.dropout(enc_outputs)\
         .view(batch_size, seq_len, enc_outputs.size(-1))

        if self.num_unilstm_enc != 0:
            if not self.residual:
                enc_hidden_uni_init = None
                enc_outputs, enc_hidden_uni = self.enc_uni(
                    enc_outputs, enc_hidden_uni_init)
                enc_outputs = self.dropout(enc_outputs).view(
                    batch_size, seq_len, enc_outputs.size(-1))
            else:
                enc_hidden_uni_init = None
                enc_hidden_uni_lis = []
                for i in range(self.num_unilstm_enc):
                    enc_inputs = enc_outputs
                    enc_func = getattr(self.enc_uni, 'l' + str(i))
                    enc_outputs, enc_hidden_uni = enc_func(
                        enc_inputs, enc_hidden_uni_init)
                    enc_hidden_uni_lis.append(enc_hidden_uni)
                    if i < self.num_unilstm_enc - 1:  # no residual for last layer
                        enc_outputs = enc_outputs + enc_inputs
                    enc_outputs = self.dropout(enc_outputs).view(
                        batch_size, seq_len, enc_outputs.size(-1))

        return enc_outputs
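
A minimal usage sketch for the encoder forward above. The module name `enc`, the PAD id, and the tensor sizes are assumptions for illustration; only the call signature comes from the example.

import torch

# Hypothetical call into the forward() above; `enc`, the PAD id (0 here) and the
# shapes are assumptions, not taken from the project.
src = torch.randint(low=4, high=1000, size=(8, 20))   # batch of 8, seq_len 20 word ids
src[:, -3:] = 0                                       # pretend trailing positions are PAD
src_lens = [torch.tensor([17]) for _ in range(8)]     # per-sentence lengths, later torch.cat'ed

# enc_outputs = enc(src, src_lens=src_lens, use_gpu=False)
# enc_outputs.shape -> [8, 20, enc_output_dim]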
Example #2
    def __init__(self,
                 expt_dir='experiment',
                 load_dir=None,
                 checkpoint_every=100,
                 print_every=100,
                 batch_size=256,
                 use_gpu=False,
                 learning_rate=0.00001,
                 learning_rate_init=0.0005,
                 lr_warmup_steps=16000,
                 max_grad_norm=1.0,
                 eval_with_mask=True,
                 max_count_no_improve=2,
                 max_count_num_rollback=2,
                 keep_num=1,
                 normalise_loss=True,
                 minibatch_split=1):

        self.use_gpu = use_gpu
        self.device = check_device(self.use_gpu)

        self.optimizer = None
        self.checkpoint_every = checkpoint_every
        self.print_every = print_every

        self.learning_rate = learning_rate
        self.learning_rate_init = learning_rate_init
        self.lr_warmup_steps = lr_warmup_steps
        if self.lr_warmup_steps == 0:
            assert self.learning_rate == self.learning_rate_init

        self.max_grad_norm = max_grad_norm
        self.eval_with_mask = eval_with_mask

        self.max_count_no_improve = max_count_no_improve
        self.max_count_num_rollback = max_count_num_rollback
        self.keep_num = keep_num
        self.normalise_loss = normalise_loss

        if not os.path.isabs(expt_dir):
            expt_dir = os.path.join(os.getcwd(), expt_dir)
        self.expt_dir = expt_dir
        if not os.path.exists(self.expt_dir):
            os.makedirs(self.expt_dir)
        self.load_dir = load_dir

        self.logger = logging.getLogger(__name__)
        self.writer = torch.utils.tensorboard.writer.SummaryWriter(
            log_dir=self.expt_dir)

        self.minibatch_split = minibatch_split
        self.batch_size = batch_size
        self.minibatch_size = int(self.batch_size /
                                  self.minibatch_split)  # to be changed if OOM
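
A small note on the last two lines, assuming this `__init__` belongs to a Trainer-style class: `minibatch_split` only divides the configured batch before each update, e.g.

# sketch only; the values and the accumulation behaviour are assumptions
batch_size = 256
minibatch_split = 4                         # raise this if the full batch triggers OOM
minibatch_size = int(batch_size / minibatch_split)
print(minibatch_size)                       # -> 64, presumably processed 4 times per batch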
Example #3
def main():

	# load config
	parser = argparse.ArgumentParser(description='Seq2seq Evaluation')
	parser = load_arguments(parser)
	args = vars(parser.parse_args())
	config = validate_config(args)

	# load src-tgt pair
	test_path_src = config['test_path_src']
	test_path_tgt = config['test_path_tgt']
	path_vocab_src = config['path_vocab_src']
	path_vocab_tgt = config['path_vocab_tgt']
	test_path_out = config['test_path_out']
	load_dir = config['load']
	max_seq_len = config['max_seq_len']
	batch_size = config['batch_size']
	beam_width = config['beam_width']
	use_gpu = config['use_gpu']
	seqrev = config['seqrev']
	use_type = config['use_type']

	if not os.path.exists(test_path_out):
		os.makedirs(test_path_out)
	config_save_dir = os.path.join(test_path_out, 'eval.cfg')
	save_config(config, config_save_dir)

	# set test mode: 1 = translate; 2 = plot
	MODE = config['eval_mode']

	# check device:
	device = check_device(use_gpu)
	print('device: {}'.format(device))

	# load test_set
	test_set = Dataset(test_path_src, test_path_tgt,
						path_vocab_src, path_vocab_tgt,
						seqrev=seqrev,
						max_seq_len=max_seq_len,
						batch_size=batch_size,
						use_gpu=use_gpu,
						use_type=use_type)
	print('Testset loaded')
	sys.stdout.flush()

	# run eval
	if MODE == 1:
		translate(test_set, load_dir, test_path_out, use_gpu,
			max_seq_len, beam_width, device, seqrev=seqrev)
Example #4
    def forward_train(self, src, tgt, debug_flag=False, use_gpu=True):
        """
			train enc + dec
			note: all outputs are useful up to the second-last element, i.e. b x (len-1)
					e.g. take [b, :-1] for preds -
						src:	w1  w2  w3  <EOS> <PAD> <PAD> <PAD>
						ref:	BOS w1  w2  w3  <EOS> <PAD> <PAD>
						tgt:	w1  w2  w3  <EOS> <PAD> <PAD> dummy
						ref starts with BOS; the last element has no reference!
		"""

        # import pdb; pdb.set_trace()
        # note: adding .type(torch.uint8) to be compatible with pytorch 1.1!

        # check gpu
        global device
        device = check_device(use_gpu)

        # run transformer
        src_mask = _get_pad_mask(src).to(device=device).type(
            torch.uint8)  # b x len
        tgt_mask = ((_get_pad_mask(tgt).to(device=device).type(torch.uint8)
                     & _get_subsequent_mask(self.max_seq_len).type(
                         torch.uint8).to(device=device)))

        # b x len x dim_model
        if self.enc_emb_proj_flag:
            emb_src = self.enc_emb_proj(
                self.embedding_dropout(self.enc_embedder(src)))
        else:
            emb_src = self.embedding_dropout(self.enc_embedder(src))
        if self.dec_emb_proj_flag:
            emb_tgt = self.dec_emb_proj(
                self.embedding_dropout(self.dec_embedder(tgt)))
        else:
            emb_tgt = self.embedding_dropout(self.dec_embedder(tgt))

        enc_outputs, *_ = self.enc(emb_src,
                                   src_mask=src_mask)  # b x len x dim_model
        dec_outputs, *_ = self.dec(emb_tgt,
                                   enc_outputs,
                                   tgt_mask=tgt_mask,
                                   src_mask=src_mask)

        logits = self.out(dec_outputs)  # b x len x vocab_size
        logps = torch.log_softmax(logits, dim=2)
        preds = logps.data.topk(1)[1]

        return preds, logps, dec_outputs
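
Per the docstring, outputs are useful only up to the second-last position, so a cross-entropy style loss would plausibly compare `logps[:, :-1]` with `tgt[:, 1:]` while ignoring PAD. The loss wiring is not part of this example; the sketch below is an assumption about it.

import torch
import torch.nn as nn

# illustrative loss wiring only; PAD id, shapes and the shift convention are assumptions
PAD = 0
batch, length, vocab = 2, 7, 100
logps = torch.log_softmax(torch.randn(batch, length, vocab), dim=-1)  # stand-in for model logps
tgt = torch.randint(1, vocab, (batch, length))                        # stand-in for BOS-prefixed refs

criterion = nn.NLLLoss(ignore_index=PAD, reduction='mean')
# position i predicts token i+1: drop the last output and the first reference token
loss = criterion(logps[:, :-1].reshape(-1, vocab), tgt[:, 1:].reshape(-1))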
Example #5
    def __init__(self,
                 expt_dir='experiment',
                 load_dir=None,
                 batch_size=64,
                 minibatch_partition=20,
                 checkpoint_every=100,
                 print_every=100,
                 learning_rate=0.001,
                 eval_with_mask=True,
                 scheduled_sampling=False,
                 teacher_forcing_ratio=1.0,
                 use_gpu=False,
                 max_grad_norm=1.0,
                 max_count_no_improve=3,
                 max_count_num_rollback=3,
                 keep_num=2,
                 normalise_loss=True):

        self.use_gpu = use_gpu
        self.device = check_device(self.use_gpu)

        self.optimizer = None
        self.checkpoint_every = checkpoint_every
        self.print_every = print_every
        self.learning_rate = learning_rate
        self.max_grad_norm = max_grad_norm
        self.eval_with_mask = eval_with_mask
        self.scheduled_sampling = scheduled_sampling
        self.teacher_forcing_ratio = teacher_forcing_ratio

        self.max_count_no_improve = max_count_no_improve
        self.max_count_num_rollback = max_count_num_rollback
        self.keep_num = keep_num
        self.normalise_loss = normalise_loss

        if not os.path.isabs(expt_dir):
            expt_dir = os.path.join(os.getcwd(), expt_dir)
        self.expt_dir = expt_dir
        if not os.path.exists(self.expt_dir):
            os.makedirs(self.expt_dir)
        self.load_dir = load_dir

        self.logger = logging.getLogger(__name__)
        self.writer = torch.utils.tensorboard.writer.SummaryWriter(
            log_dir=self.expt_dir)

        self.batch_size = batch_size
        self.minibatch_partition = minibatch_partition
        self.minibatch_size = int(self.batch_size / self.minibatch_partition)
Example #6
    def forward(self, src, src_lens, hidden=None, use_gpu=True):
        """
			Args:
				src: list of src word_ids [batch_size, seq_len, word_ids]
		"""

        # import pdb; pdb.set_trace()

        out_dict = {}
        device = check_device(use_gpu)

        # src mask
        mask_src = src.data.eq(PAD)
        batch_size = src.size(0)
        seq_len = src.size(1)

        # convert id to embedding
        emb_src = self.embedding_dropout(self.embedder(src))

        # run lstm: packing + unpacking
        src_lens = torch.cat(src_lens)
        emb_src_pack = torch.nn.utils.rnn.pack_padded_sequence(
            emb_src, src_lens, batch_first=True, enforce_sorted=False)
        lstm_outputs_pack, lstm_hidden = self.lstm(emb_src_pack, hidden)
        lstm_outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(
            lstm_outputs_pack, batch_first=True)
        lstm_outputs = self.dropout(lstm_outputs)\
         .view(batch_size, seq_len, lstm_outputs.size(-1))

        # generate predictions
        logits = self.out(lstm_outputs)
        logps = F.log_softmax(logits, dim=2)
        symbols = logps.topk(1)[1]
        out_dict['sequence'] = symbols

        return logps, out_dict
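
The line `symbols = logps.topk(1)[1]` above is simply a per-position argmax kept as a trailing dimension; a tiny self-contained check of that idiom (the shapes here are arbitrary):

import torch

logps = torch.log_softmax(torch.randn(2, 5, 10), dim=2)   # [batch, seq_len, vocab]
symbols = logps.topk(1)[1]                                 # [batch, seq_len, 1]
assert torch.equal(symbols, logps.argmax(dim=2, keepdim=True))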
Example #7
def main():

	# load config
	parser = argparse.ArgumentParser(description='Seq2seq Evaluation')
	parser = load_arguments(parser)
	args = vars(parser.parse_args())
	config = validate_config(args)

	# load src-tgt pair
	test_path_src = config['test_path_src']
	test_path_tgt = config['test_path_tgt'] # dummy
	if test_path_tgt is None:
		test_path_tgt = test_path_src
	test_path_out = config['test_path_out']
	load_dir = config['load']
	max_seq_len = config['max_seq_len']
	batch_size = config['batch_size']
	beam_width = config['beam_width']
	use_gpu = config['use_gpu']
	seqrev = config['seqrev']
	use_type = config['use_type']

	if not os.path.exists(test_path_out):
		os.makedirs(test_path_out)
	config_save_dir = os.path.join(test_path_out, 'eval.cfg')
	save_config(config, config_save_dir)

	# set test mode: 1 = translate; 2 = translate batch; 3 = attention plot; 4 = teacher-forcing translate
	MODE = config['eval_mode']
	if MODE == 3:
		max_seq_len = 32
		batch_size = 1
		beam_width = 1
		use_gpu = False

	# check device:
	device = check_device(use_gpu)
	print('device: {}'.format(device))

	# load model
	latest_checkpoint_path = load_dir
	resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
	model = resume_checkpoint.model.to(device)
	vocab_src = resume_checkpoint.input_vocab
	vocab_tgt = resume_checkpoint.output_vocab
	print('Model dir: {}'.format(latest_checkpoint_path))
	print('Model loaded')

	# load test_set
	test_set = Dataset(test_path_src, test_path_tgt,
						vocab_src_list=vocab_src, vocab_tgt_list=vocab_tgt,
						seqrev=seqrev,
						max_seq_len=max_seq_len,
						batch_size=batch_size,
						use_gpu=use_gpu,
						use_type=use_type)
	print('Test dir: {}'.format(test_path_src))
	print('Testset loaded')
	sys.stdout.flush()

	# run eval
	if MODE == 1: # FR
		translate(test_set, model, test_path_out, use_gpu,
			max_seq_len, beam_width, device, seqrev=seqrev)

	if MODE == 2:
		translate_batch(test_set, model, test_path_out, use_gpu,
			max_seq_len, beam_width, device, seqrev=seqrev)

	elif MODE == 3:
		# plotting
		att_plot(test_set, model, test_path_out, use_gpu,
			max_seq_len, beam_width, device)

	elif MODE == 4: # TF
		translate_tf(test_set, model, test_path_out, use_gpu,
			max_seq_len, beam_width, device, seqrev=seqrev)
Example #8
def main():

    # import pdb; pdb.set_trace()
    # load config
    parser = argparse.ArgumentParser(description='LAS + NMT Training')
    parser = load_arguments(parser)
    args = vars(parser.parse_args())
    config = validate_config(args)

    # set random seed
    if config['random_seed'] is not None:
        set_global_seeds(config['random_seed'])

    # record config
    if not os.path.isabs(config['save']):
        config_save_dir = os.path.join(os.getcwd(), config['save'])
    if not os.path.exists(config['save']):
        os.makedirs(config['save'])

    # resume or not
    if config['load'] is not None and config['load_mode'] == 'resume':
        config_save_dir = os.path.join(config['save'], 'model-cont.cfg')
    else:
        config_save_dir = os.path.join(config['save'], 'model.cfg')
    save_config(config, config_save_dir)

    loss_coeff = {}
    loss_coeff['nll_asr'] = config['loss_nll_asr_coeff']
    loss_coeff['nll_mt'] = config['loss_nll_mt_coeff']
    loss_coeff['nll_st'] = config['loss_nll_st_coeff']

    # construct trainer
    Trainer = globals()['Trainer_{}'.format(config['mode'])]
    t = Trainer(expt_dir=config['save'],
                load_dir=config['load'],
                load_mode=config['load_mode'],
                load_freeze=config['load_freeze'],
                batch_size=config['batch_size'],
                minibatch_partition=config['minibatch_partition'],
                checkpoint_every=config['checkpoint_every'],
                print_every=config['print_every'],
                learning_rate=config['learning_rate'],
                learning_rate_init=config['learning_rate_init'],
                lr_warmup_steps=config['lr_warmup_steps'],
                eval_with_mask=config['eval_with_mask'],
                use_gpu=config['use_gpu'],
                gpu_id=config['gpu_id'],
                max_grad_norm=config['max_grad_norm'],
                max_count_no_improve=config['max_count_no_improve'],
                max_count_num_rollback=config['max_count_num_rollback'],
                keep_num=config['keep_num'],
                normalise_loss=config['normalise_loss'],
                loss_coeff=loss_coeff)

    # vocab
    path_vocab_src = config['path_vocab_src']
    path_vocab_tgt = config['path_vocab_tgt']

    # ----- 3WAY -----
    train_set = None
    dev_set = None
    mode = config['mode']
    if 'ST' in mode:
        # load train set
        if config['st_train_path_src']:
            t.logger.info(' -- load ST train set -- ')
            train_path_src = config['st_train_path_src']
            train_path_tgt = config['st_train_path_tgt']
            train_acous_path = config['st_train_acous_path']
            train_set = Dataset(path_src=train_path_src,
                                path_tgt=train_path_tgt,
                                path_vocab_src=path_vocab_src,
                                path_vocab_tgt=path_vocab_tgt,
                                use_type=config['use_type'],
                                acous_path=train_acous_path,
                                seqrev=config['seqrev'],
                                acous_norm=config['las_acous_norm'],
                                acous_norm_path=config['st_acous_norm_path'],
                                acous_max_len=config['las_acous_max_len'],
                                max_seq_len_src=config['max_seq_len_src'],
                                max_seq_len_tgt=config['max_seq_len_tgt'],
                                batch_size=config['batch_size'],
                                data_ratio=config['st_data_ratio'],
                                use_gpu=config['use_gpu'],
                                mode='ST',
                                logger=t.logger)

            vocab_size_enc = len(train_set.vocab_src)
            vocab_size_dec = len(train_set.vocab_tgt)
            src_word2id = train_set.src_word2id
            tgt_word2id = train_set.tgt_word2id
            src_id2word = train_set.src_id2word
            tgt_id2word = train_set.tgt_id2word

        # load dev set
        if config['st_dev_path_src']:
            t.logger.info(' -- load ST dev set -- ')
            dev_path_src = config['st_dev_path_src']
            dev_path_tgt = config['st_dev_path_tgt']
            dev_acous_path = config['st_dev_acous_path']
            dev_set = Dataset(path_src=dev_path_src,
                              path_tgt=dev_path_tgt,
                              path_vocab_src=path_vocab_src,
                              path_vocab_tgt=path_vocab_tgt,
                              use_type=config['use_type'],
                              acous_path=dev_acous_path,
                              acous_norm_path=config['st_acous_norm_path'],
                              acous_max_len=config['las_acous_max_len'],
                              seqrev=config['seqrev'],
                              acous_norm=config['las_acous_norm'],
                              max_seq_len_src=config['max_seq_len_src'],
                              max_seq_len_tgt=config['max_seq_len_tgt'],
                              batch_size=config['batch_size'],
                              use_gpu=config['use_gpu'],
                              mode='ST',
                              logger=t.logger)
        else:
            dev_set = None

    # ----- ASR -----
    asr_train_set = None
    asr_dev_set = None
    if 'ASR' in mode:
        # load train set
        if config['asr_train_path_src']:
            t.logger.info(' -- load ASR train set -- ')
            asr_train_path_src = config['asr_train_path_src']
            asr_train_acous_path = config['asr_train_acous_path']
            asr_train_set = Dataset(
                path_src=asr_train_path_src,
                path_tgt=None,
                path_vocab_src=path_vocab_src,
                path_vocab_tgt=path_vocab_tgt,
                use_type=config['use_type'],
                acous_path=asr_train_acous_path,
                acous_norm_path=config['asr_train_acous_norm_path'],
                seqrev=config['seqrev'],
                acous_norm=config['las_acous_norm'],
                acous_max_len=config['las_acous_max_len'],
                max_seq_len_src=config['max_seq_len_src'],
                max_seq_len_tgt=config['max_seq_len_tgt'],
                batch_size=config['batch_size'],
                data_ratio=config['asr_data_ratio'],
                use_gpu=config['use_gpu'],
                mode='ASR',
                logger=t.logger)

            vocab_size_enc = len(asr_train_set.vocab_src)
            vocab_size_dec = len(asr_train_set.vocab_tgt)
            src_word2id = asr_train_set.src_word2id
            tgt_word2id = asr_train_set.tgt_word2id
            src_id2word = asr_train_set.src_id2word
            tgt_id2word = asr_train_set.tgt_id2word

        # load dev set
        if config['asr_dev_path_src']:
            t.logger.info(' -- load ASR dev set -- ')
            asr_dev_path_src = config['asr_dev_path_src']
            asr_dev_acous_path = config['asr_dev_acous_path']
            asr_dev_set = Dataset(
                path_src=asr_dev_path_src,
                path_tgt=None,
                path_vocab_src=path_vocab_src,
                path_vocab_tgt=path_vocab_tgt,
                use_type=config['use_type'],
                acous_path=asr_dev_acous_path,
                acous_norm_path=config['asr_dev_acous_norm_path'],
                acous_max_len=config['las_acous_max_len'],
                seqrev=config['seqrev'],
                acous_norm=config['las_acous_norm'],
                max_seq_len_src=config['max_seq_len_src'],
                max_seq_len_tgt=config['max_seq_len_tgt'],
                batch_size=config['batch_size'],
                use_gpu=config['use_gpu'],
                mode='ASR',
                logger=t.logger)
        else:
            asr_dev_set = None

    # ----- MT -----
    mt_train_set = None
    mt_dev_set = None
    if 'MT' in mode:
        # load train set
        if config['mt_train_path_src']:
            t.logger.info(' -- load MT train set -- ')
            mt_train_path_src = config['mt_train_path_src']
            mt_train_path_tgt = config['mt_train_path_tgt']
            mt_train_set = Dataset(path_src=mt_train_path_src,
                                   path_tgt=mt_train_path_tgt,
                                   path_vocab_src=path_vocab_src,
                                   path_vocab_tgt=path_vocab_tgt,
                                   use_type=config['use_type'],
                                   acous_path=None,
                                   acous_norm_path=None,
                                   seqrev=config['seqrev'],
                                   acous_norm=config['las_acous_norm'],
                                   acous_max_len=config['las_acous_max_len'],
                                   max_seq_len_src=config['max_seq_len_src'],
                                   max_seq_len_tgt=config['max_seq_len_tgt'],
                                   batch_size=config['batch_size'],
                                   data_ratio=config['mt_data_ratio'],
                                   use_gpu=config['use_gpu'],
                                   mode='MT',
                                   logger=t.logger)

            vocab_size_enc = len(mt_train_set.vocab_src)
            vocab_size_dec = len(mt_train_set.vocab_tgt)
            src_word2id = mt_train_set.src_word2id
            tgt_word2id = mt_train_set.tgt_word2id
            src_id2word = mt_train_set.src_id2word
            tgt_id2word = mt_train_set.tgt_id2word

        # load dev set
        if config['mt_dev_path_src']:
            t.logger.info(' -- load MT dev set -- ')
            mt_dev_path_src = config['mt_dev_path_src']
            mt_dev_path_tgt = config['mt_dev_path_tgt']
            mt_dev_set = Dataset(path_src=mt_dev_path_src,
                                 path_tgt=mt_dev_path_tgt,
                                 path_vocab_src=path_vocab_src,
                                 path_vocab_tgt=path_vocab_tgt,
                                 use_type=config['use_type'],
                                 acous_path=None,
                                 acous_norm_path=None,
                                 acous_max_len=config['las_acous_max_len'],
                                 seqrev=config['seqrev'],
                                 acous_norm=config['las_acous_norm'],
                                 max_seq_len_src=config['max_seq_len_src'],
                                 max_seq_len_tgt=config['max_seq_len_tgt'],
                                 batch_size=config['batch_size'],
                                 use_gpu=config['use_gpu'],
                                 mode='MT',
                                 logger=t.logger)
        else:
            mt_dev_set = None

    # collect all datasets
    train_sets = {}
    dev_sets = {}
    train_sets['st'] = train_set
    train_sets['asr'] = asr_train_set
    train_sets['mt'] = mt_train_set
    dev_sets['st'] = dev_set
    dev_sets['asr'] = asr_dev_set
    dev_sets['mt'] = mt_dev_set

    # device
    device = check_device(config['use_gpu'])
    t.logger.info('device:{}'.format(device))

    # construct nmt model
    seq2seq = Seq2seq(
        vocab_size_enc,
        vocab_size_dec,
        share_embedder=config['share_embedder'],
        enc_embedding_size=config['embedding_size_enc'],
        dec_embedding_size=config['embedding_size_dec'],
        load_embedding_src=config['load_embedding_src'],
        load_embedding_tgt=config['load_embedding_tgt'],
        num_heads=config['num_heads'],
        dim_model=config['dim_model'],
        dim_feedforward=config['dim_feedforward'],
        enc_layers=config['enc_layers'],
        dec_layers=config['dec_layers'],
        embedding_dropout=config['embedding_dropout'],
        dropout=config['dropout'],
        max_seq_len_src=config['max_seq_len_src'],
        max_seq_len_tgt=config['max_seq_len_tgt'],
        act=config['act'],
        enc_word2id=src_word2id,
        dec_word2id=tgt_word2id,
        enc_id2word=src_id2word,
        dec_id2word=tgt_id2word,
        transformer_type=config['transformer_type'],
        enc_emb_proj=config['enc_emb_proj'],
        dec_emb_proj=config['dec_emb_proj'],
        #
        acous_dim=config['las_acous_dim'],
        acous_hidden_size=config['las_acous_hidden_size'],
        #
        mode=config['mode'],
        load_mode=config['load_mode'])
    seq2seq = seq2seq.to(device=device)

    # run training
    seq2seq = t.train(train_sets,
                      seq2seq,
                      num_epochs=config['num_epochs'],
                      dev_sets=dev_sets,
                      grab_memory=config['grab_memory'])
Example #9
    def forward_translate_fast(self,
                               src,
                               beam_width=1,
                               penalty_factor=1,
                               use_gpu=True):
        """
			require large memory - run on cpu
		"""

        # import pdb; pdb.set_trace()

        # check gpu
        global device
        device = check_device(use_gpu)

        # run dd
        src_mask = _get_pad_mask(src).type(torch.uint8).to(
            device=device)  # b x 1 x len
        if self.enc_emb_proj_flag:
            emb_src = self.enc_emb_proj(self.enc_embedder(src))
        else:
            emb_src = self.enc_embedder(src)
        enc_outputs, *_ = self.enc(emb_src,
                                   src_mask=src_mask)  # b x len x dim_model

        batch = src.size(0)
        length_in = src.size(1)
        length_out = self.max_seq_len

        eos_mask = torch.BoolTensor([False]).repeat(
            batch * beam_width).to(device=device)
        len_map = torch.Tensor([1
                                ]).repeat(batch * beam_width).to(device=device)
        preds = torch.Tensor([BOS]).repeat(batch, 1).type(
            torch.LongTensor).to(device=device)

        # repeat for beam_width times
        # a b c d -> aaa bbb ccc ddd

        # b x 1 x len -> (b x beam_width) x 1 x len
        src_mask_expand = src_mask.repeat(1, beam_width,
                                          1).view(-1, 1, length_in)
        # b x len x dim_model -> (b x beam_width) x len x dim_model
        enc_outputs_expand = enc_outputs.repeat(1, beam_width, 1).view(
            -1, length_in, self.dim_model)
        # (b x beam_width) x len
        preds_expand = preds.repeat(1, beam_width).view(-1, preds.size(-1))
        # (b x beam_width)
        scores_expand = torch.Tensor([0]).repeat(batch * beam_width).type(
            torch.FloatTensor).to(device=device)

        # loop over sequence length
        for i in range(1, self.max_seq_len):

            # gen: 0-30; ref: 1-31
            # import pdb; pdb.set_trace()

            # Get k candidates for each beam, k^2 candidates in total (k=beam_width)
            tgt_mask_expand = ((
                _get_pad_mask(preds_expand).type(torch.uint8).to(device=device)
                & _get_subsequent_mask(preds_expand.size(-1)).type(
                    torch.uint8).to(device=device)))
            if self.dec_emb_proj_flag:
                emb_tgt_expand = self.dec_emb_proj(
                    self.dec_embedder(preds_expand))
            else:
                emb_tgt_expand = self.dec_embedder(preds_expand)

            if i == 1:
                cache_decslf = None
                cache_encdec = None
            dec_output_expand, *_, cache_decslf, cache_encdec = self.dec(
                emb_tgt_expand,
                enc_outputs_expand,
                tgt_mask=tgt_mask_expand,
                src_mask=src_mask_expand,
                decode_speedup=True,
                cache_decslf=cache_decslf,
                cache_encdec=cache_encdec)

            logit_expand = self.out(dec_output_expand)
            # (b x beam_width) x len x vocab_size
            logp_expand = torch.log_softmax(logit_expand, dim=2)
            # (b x beam_width) x len x beam_width
            score_expand, pred_expand = logp_expand.data.topk(beam_width)

            # select current slice
            dec_output = dec_output_expand[:, i -
                                           1]  # (b x beam_width) x dim_model - unused
            logp = logp_expand[:, i -
                               1, :]  # (b x beam_width) x vocab_size - unused
            pred = pred_expand[:, i - 1]  # (b x beam_width) x beam_width
            score = score_expand[:, i - 1]  # (b x beam_width) x beam_width

            # select k candidates from k^2 candidates
            if i == 1:
                # initial state, keep first k candidates
                # b x (beam_width x beam_width) -> b x (beam_width) -> (b x beam_width) x 1
                score_select = scores_expand + score.reshape(batch, -1)[:,:beam_width]\
                 .contiguous().view(-1)
                scores_expand = score_select
                pred_select = pred.reshape(
                    batch, -1)[:, :beam_width].contiguous().view(-1)
                preds_expand = torch.cat(
                    (preds_expand, pred_select.unsqueeze(-1)), dim=1)

            else:
                # keep only 1 candidate when hitting eos
                # (b x beam_width) x beam_width
                eos_mask_expand = eos_mask.reshape(-1, 1).repeat(1, beam_width)
                eos_mask_expand[:, 0] = False
                # (b x beam_width) x beam_width
                score_temp = scores_expand.reshape(-1, 1) + score.masked_fill(
                    eos_mask.reshape(-1, 1), 0).masked_fill(
                        eos_mask_expand, -1e9)
                # length penalty
                score_temp = score_temp / (len_map.reshape(-1, 1)**
                                           penalty_factor)
                # select top k from k^2
                # (b x beam_width^2 -> b x beam_width)
                score_select, pos = score_temp.reshape(batch,
                                                       -1).topk(beam_width)
                scores_expand = score_select.view(-1) * (len_map.reshape(
                    -1, 1)**penalty_factor).view(-1)
                # select correct elements according to pos
                pos = (pos +
                       torch.range(0, (batch - 1) * (beam_width**2),
                                   (beam_width**2)).to(device=device).reshape(
                                       batch, 1)).long()
                r_idxs, c_idxs = pos // beam_width, pos % beam_width  # b x beam_width
                pred_select = pred[r_idxs, c_idxs].view(
                    -1)  # b x beam_width -> (b x beam_width)
                # Copy the corresponding previous tokens.
                preds_expand[:, :i] = preds_expand[r_idxs.view(-1), :
                                                   i]  # (b x beam_width) x i
                # Set the best tokens in this beam search step
                preds_expand = torch.cat(
                    (preds_expand, pred_select.unsqueeze(-1)), dim=1)

            # locate the eos in the generated sequences
            # eos_mask = (pred_select == EOS) + eos_mask # >=pt1.3
            eos_mask = ((pred_select == EOS).type(torch.uint8) +
                        eos_mask.type(torch.uint8)).type(torch.bool).type(
                            torch.uint8)  # >=pt1.1
            len_map = len_map + torch.Tensor([1]).repeat(
                batch * beam_width).to(device=device).masked_fill(eos_mask, 0)

            # early stop
            if sum(eos_mask.int()) == eos_mask.size(0):
                break

        # select the best candidate
        preds = preds_expand.reshape(
            batch, -1)[:, :self.max_seq_len].contiguous()  # b x len
        scores = scores_expand.reshape(batch, -1)[:, 0].contiguous()  # b

        # select the worst candidate
        # preds = preds_expand.reshape(batch, -1)
        # [:, (beam_width - 1)*length : (beam_width)*length].contiguous() # b x len
        # scores = scores_expand.reshape(batch, -1)[:, -1].contiguous() # b

        return preds
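
For reference, the beam bookkeeping above compares candidates by cumulative log-probability divided by length**penalty_factor, then restores the un-normalised sum after the top-k selection. A toy sketch of that normalisation with made-up numbers:

import torch

# toy illustration of the length-penalised comparison used above (values are made up)
penalty_factor = 1.0
scores = torch.tensor([-4.0, -4.5])      # cumulative log-probs of two candidates
lengths = torch.tensor([4.0, 6.0])       # tokens generated so far (len_map above)

normalised = scores / (lengths ** penalty_factor)
best = normalised.argmax()               # the longer hypothesis wins despite the lower raw sum
print(best.item(), normalised.tolist())  # -> 1 [-1.0, -0.75]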
Example #10
File: Dec.py Project: EdieLu/LAS
    def __init__(
            self,
            # params
            vocab_size,
            embedding_size=200,
            acous_hidden_size=256,
            acous_att_mode='bahdanau',
            hidden_size_dec=200,
            hidden_size_shared=200,
            num_unilstm_dec=4,
            use_type='char',
            #
            embedding_dropout=0,
            dropout=0.0,
            residual=True,
            batch_first=True,
            max_seq_len=32,
            load_embedding=None,
            word2id=None,
            id2word=None,
            hard_att=False,
            use_gpu=False):

        super(Dec, self).__init__()
        device = check_device(use_gpu)

        # define model
        self.acous_hidden_size = acous_hidden_size
        self.acous_att_mode = acous_att_mode
        self.hidden_size_dec = hidden_size_dec
        self.hidden_size_shared = hidden_size_shared
        self.num_unilstm_dec = num_unilstm_dec

        # define var
        self.hard_att = hard_att
        self.residual = residual
        self.max_seq_len = max_seq_len
        self.use_type = use_type

        # use shared embedding + vocab
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.load_embedding = load_embedding
        self.word2id = word2id
        self.id2word = id2word

        # define operations
        self.embedding_dropout = nn.Dropout(embedding_dropout)
        self.dropout = nn.Dropout(dropout)

        # ------- load embeddings --------
        if self.load_embedding:
            embedding_matrix = np.random.rand(self.vocab_size,
                                              self.embedding_size)
            embedding_matrix = load_pretrained_embedding(
                self.word2id, embedding_matrix, self.load_embedding)
            embedding_matrix = torch.FloatTensor(embedding_matrix)
            self.embedder = nn.Embedding.from_pretrained(embedding_matrix,
                                                         freeze=False,
                                                         sparse=False,
                                                         padding_idx=PAD)
        else:
            self.embedder = nn.Embedding(self.vocab_size,
                                         self.embedding_size,
                                         sparse=False,
                                         padding_idx=PAD)

        # ------ define acous att --------
        dropout_acous_att = dropout
        self.acous_hidden_size_att = 0  # ignored with bilinear

        self.acous_key_size = self.acous_hidden_size * 2  # acous feats
        self.acous_value_size = self.acous_hidden_size * 2  # acous feats
        self.acous_query_size = self.hidden_size_dec  # use dec(words) as query
        self.acous_att = AttentionLayer(self.acous_query_size,
                                        self.acous_key_size,
                                        value_size=self.acous_value_size,
                                        mode=self.acous_att_mode,
                                        dropout=dropout_acous_att,
                                        query_transform=False,
                                        output_transform=False,
                                        hidden_size=self.acous_hidden_size_att,
                                        use_gpu=use_gpu,
                                        hard_att=False)

        # ------ define acous out --------
        self.acous_ffn = nn.Linear(self.acous_hidden_size * 2 +
                                   self.hidden_size_dec,
                                   self.hidden_size_shared,
                                   bias=False)
        self.acous_out = nn.Linear(self.hidden_size_shared,
                                   self.vocab_size,
                                   bias=True)

        # ------ define acous dec -------
        # embedding_size_dec + self.hidden_size_shared [200+200]-> hidden_size_dec [200]
        if not self.residual:
            self.dec = torch.nn.LSTM(self.embedding_size +
                                     self.hidden_size_shared,
                                     self.hidden_size_dec,
                                     num_layers=self.num_unilstm_dec,
                                     batch_first=batch_first,
                                     bias=True,
                                     dropout=dropout,
                                     bidirectional=False)
        else:
            self.dec = nn.Module()
            self.dec.add_module(
                'l0',
                torch.nn.LSTM(self.embedding_size + self.hidden_size_shared,
                              self.hidden_size_dec,
                              num_layers=1,
                              batch_first=batch_first,
                              bias=True,
                              dropout=dropout,
                              bidirectional=False))

            for i in range(1, self.num_unilstm_dec):
                self.dec.add_module(
                    'l' + str(i),
                    torch.nn.LSTM(self.hidden_size_dec,
                                  self.hidden_size_dec,
                                  num_layers=1,
                                  batch_first=batch_first,
                                  bias=True,
                                  dropout=dropout,
                                  bidirectional=False))
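
When `residual=True`, the decoder LSTM is registered layer by layer under the names 'l0', 'l1', ...; by analogy with the per-layer loop in Example #1, such a stack would be consumed roughly as below. The loop body and the residual placement are assumptions based on that example, not code from this file.

import torch
import torch.nn as nn

# assumed consumption of a per-layer LSTM stack named 'l0', 'l1', ... (sketch only)
dec = nn.Module()
dec.add_module('l0', nn.LSTM(8, 8, num_layers=1, batch_first=True))
dec.add_module('l1', nn.LSTM(8, 8, num_layers=1, batch_first=True))

x = torch.randn(2, 5, 8)                  # [batch, seq_len, feat]
for i in range(2):
    layer = getattr(dec, 'l' + str(i))    # same getattr access as in Example #1
    out, _ = layer(x)
    x = out + x if i < 1 else out         # residual on every layer except the last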
Example #11
    def forward_eval(self, src, debug_flag=False, use_gpu=True):
        """
			eval enc + dec (beam_width = 1)
			all outputs are aligned as follows:
				tgt:	<BOS> w1 w2 w3 <EOS> <PAD>
				gen:		  w1 w2 w3 <EOS> <PAD> <PAD>
				shift by 1, i.e.
					used input = <BOS>	w1 		<PAD> 	<PAD>
					gen output = dummy	w2 		dummy
					update prediction: assign w2(output[1]) to be input[2]
		"""

        # import pdb; pdb.set_trace()

        # check gpu
        global device
        device = check_device(use_gpu)

        batch = src.size(0)
        length_out = self.max_seq_len

        # run enc dec
        eos_mask = torch.BoolTensor([False]).repeat(batch).to(device=device)
        src_mask = _get_pad_mask(src).type(torch.uint8).to(device=device)
        if self.enc_emb_proj_flag:
            emb_src = self.enc_emb_proj(
                self.embedding_dropout(self.enc_embedder(src)))
        else:
            emb_src = self.embedding_dropout(self.enc_embedder(src))
        enc_outputs, enc_var = self.enc(emb_src, src_mask=src_mask)

        # record
        logps = torch.Tensor([-1e-4]).repeat(
            batch, length_out,
            self.dec_vocab_size).type(torch.FloatTensor).to(device=device)
        dec_outputs = torch.Tensor([0]).repeat(
            batch, length_out,
            self.dim_model).type(torch.FloatTensor).to(device=device)
        preds_save = torch.Tensor([PAD]).repeat(batch, length_out).type(
            torch.LongTensor).to(device=device)  # used to update pred history

        # start from length = 1
        preds = torch.Tensor([BOS]).repeat(batch, 1).type(
            torch.LongTensor).to(device=device)
        preds_save[:, 0] = preds[:, 0]

        for i in range(1, self.max_seq_len):

            # gen: 0-30; ref: 1-31
            # import pdb; pdb.set_trace()

            tgt_mask = ((
                _get_pad_mask(preds).type(torch.uint8).to(device=device)
                & _get_subsequent_mask(preds.size(-1)).type(
                    torch.uint8).to(device=device)))
            if self.dec_emb_proj_flag:
                emb_tgt = self.dec_emb_proj(self.dec_embedder(preds))
            else:
                emb_tgt = self.dec_embedder(preds)
            dec_output, dec_var, *_ = self.dec(emb_tgt,
                                               enc_outputs,
                                               tgt_mask=tgt_mask,
                                               src_mask=src_mask)
            logit = self.out(dec_output)
            logp = torch.log_softmax(logit, dim=2)
            pred = logp.data.topk(1)[1]  # b x :i
            # eos_mask = (pred[:, i-1].squeeze(1) == EOS) + eos_mask # >=pt1.3
            eos_mask = ((pred[:, i - 1].squeeze(1) == EOS).type(torch.uint8) +
                        eos_mask.type(torch.uint8)).type(torch.bool).type(
                            torch.uint8)  # >=pt1.1

            # b x len x dim_model - [:,0,:] is dummy 0's
            dec_outputs[:, i, :] = dec_output[:, i - 1]
            # b x len x vocab_size - [:,0,:] is dummy -1e-4's # individual logps
            logps[:, i, :] = logp[:, i - 1, :]
            # b x len - [:,0] is BOS
            preds_save[:, i] = pred[:, i - 1].view(-1)

            # append current pred, length+1
            preds = torch.cat((preds, pred[:, i - 1]), dim=1)

            if sum(eos_mask.int()) == eos_mask.size(0):
                # import pdb; pdb.set_trace()
                if length_out != preds.size(1):
                    dummy = torch.Tensor([PAD]).repeat(
                        batch, length_out - preds.size(1)).type(
                            torch.LongTensor).to(device=device)
                    preds = torch.cat((preds, dummy),
                                      dim=1)  # pad to max length
                break

        if not debug_flag:
            return preds, logps, dec_outputs
        else:
            return preds, logps, dec_outputs, enc_var, dec_var
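
The uint8 round-trip used for `eos_mask` above emulates a logical OR that also works on PyTorch 1.1 (see the '# >=pt1.1' comment). A tiny standalone check of that idiom; the EOS id here is an assumption:

import torch

EOS = 2
pred_step = torch.tensor([5, 2, 7])                 # current-step predictions for a batch of 3
eos_mask = torch.BoolTensor([False, False, True])   # sentences already finished

eos_mask = ((pred_step == EOS).type(torch.uint8) +
            eos_mask.type(torch.uint8)).type(torch.bool).type(torch.uint8)
print(eos_mask.tolist())   # -> [0, 1, 1]: sentence 1 just hit EOS, sentence 2 stays finished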
Example #12
    def forward_eval_fast(self, src, debug_flag=False, use_gpu=True):
        """
			require large memory - run on cpu
		"""

        # import pdb; pdb.set_trace()

        # check gpu
        global device
        device = check_device(use_gpu)

        batch = src.size(0)
        length_out = self.max_seq_len

        # run enc dec
        src_mask = _get_pad_mask(src).type(torch.uint8).to(device=device)
        if self.enc_emb_proj_flag:
            emb_src = self.enc_emb_proj(
                self.embedding_dropout(self.enc_embedder(src)))
        else:
            emb_src = self.embedding_dropout(self.enc_embedder(src))
        enc_outputs, enc_var = self.enc(emb_src, src_mask=src_mask)

        # record
        logps = torch.Tensor([-1e-4]).repeat(
            batch, length_out,
            self.dec_vocab_size).type(torch.FloatTensor).to(device=device)
        dec_outputs = torch.Tensor([0]).repeat(
            batch, length_out,
            self.dim_model).type(torch.FloatTensor).to(device=device)
        preds_save = torch.Tensor([PAD]).repeat(batch, length_out).type(
            torch.LongTensor).to(device=device)  # used to update pred history

        # start from length = 1
        preds = torch.Tensor([BOS]).repeat(batch, 1).type(
            torch.LongTensor).to(device=device)
        preds_save[:, 0] = preds[:, 0]

        for i in range(1, self.max_seq_len):

            # gen: 0-30; ref: 1-31
            # import pdb; pdb.set_trace()

            tgt_mask = ((
                _get_pad_mask(preds).type(torch.uint8).to(device=device)
                & _get_subsequent_mask(preds.size(-1)).type(
                    torch.uint8).to(device=device)))
            if self.dec_emb_proj_flag:
                emb_tgt = self.dec_emb_proj(self.dec_embedder(preds))
            else:
                emb_tgt = self.dec_embedder(preds)
            if i == 1:
                cache_decslf = None
                cache_encdec = None
            dec_output, dec_var, *_, cache_decslf, cache_encdec = self.dec(
                emb_tgt,
                enc_outputs,
                tgt_mask=tgt_mask,
                src_mask=src_mask,
                decode_speedup=True,
                cache_decslf=cache_decslf,
                cache_encdec=cache_encdec)
            logit = self.out(dec_output)
            logp = torch.log_softmax(logit, dim=2)
            pred = logp.data.topk(1)[1]  # b x :i

            # b x len x dim_model - [:,0,:] is dummy 0's
            dec_outputs[:, i, :] = dec_output[:, i - 1]
            # b x len x vocab_size - [:,0,:] is dummy -1e-4's # individual logps
            logps[:, i, :] = logp[:, i - 1, :]
            # b x len - [:,0] is BOS
            preds_save[:, i] = pred[:, i - 1].view(-1)

            # append current pred, length+1
            preds = torch.cat((preds, pred[:, i - 1]), dim=1)

        if not debug_flag:
            return preds, logps, dec_outputs
        else:
            return preds, logps, dec_outputs, enc_var, dec_var
Example #13
File: Dec.py Project: EdieLu/LAS
    def forward(self,
                acous_outputs,
                acous_lens=None,
                tgt=None,
                hidden=None,
                is_training=False,
                teacher_forcing_ratio=0.0,
                beam_width=1,
                use_gpu=False):
        """
			Args:
				acous_outputs: [batch_size, acous_len / 8, self.acous_hidden_size * 2]
				tgt: list of word_ids 			[b x seq_len]
				hidden: initial hidden state
				is_training: whether in eval or train mode
				teacher_forcing_ratio: 1.0 means always teacher forcing
			Returns:
				decoder_outputs: list of step_output -
					log predicted_softmax [batch_size, 1, vocab_size_dec] * (T-1)
		"""

        # import pdb; pdb.set_trace()

        global device
        device = check_device(use_gpu)

        # 0. init var
        ret_dict = dict()
        ret_dict[KEY_ATTN_SCORE] = []

        decoder_outputs = []
        sequence_symbols = []
        batch_size = acous_outputs.size(0)

        if tgt is None:
            tgt = torch.Tensor([BOS]).repeat(
                batch_size,
                self.max_seq_len).type(torch.LongTensor).to(device=device)

        max_seq_len = tgt.size(1)
        lengths = np.array([max_seq_len] * batch_size)

        # 1. convert id to embedding
        emb_tgt = self.embedding_dropout(self.embedder(tgt))

        # 2. att inputs: keys and values
        att_keys = acous_outputs
        att_vals = acous_outputs

        # generate acous mask: True for trailing 0's
        if acous_lens is not None:
            # reduce by 8
            lens = torch.cat([elem + 8 - elem % 8 for elem in acous_lens]) / 8
            max_acous_len = acous_outputs.size(1)
            # mask=True over trailing 0s
            mask = torch.arange(max_acous_len).to(device=device).expand(
                batch_size,
                max_acous_len) >= lens.unsqueeze(1).to(device=device)
        else:
            mask = None

        # 3. init hidden states
        dec_hidden = None

        # 4. run dec + att + shared + output
        """
			teacher_forcing_ratio = 1.0 -> always teacher forcing
			E.g.:
				acous 	        = [acous_len/8]
				tgt_chunk in    = w1 w2 w3 </s> <pad> <pad> <pad>   [max_seq_len]
				predicted       = w1 w2 w3 </s> <pad> <pad> <pad>   [max_seq_len]
		"""

        # LAS under teacher forcing
        use_teacher_forcing = random.random() < teacher_forcing_ratio

        # beam search decoding
        if not is_training and beam_width > 1:
            decoder_outputs, decoder_hidden, metadata = \
             self.beam_search_decoding(att_keys, att_vals, dec_hidden, mask,
                   beam_width=beam_width)
            return decoder_outputs, decoder_hidden, metadata

        # no beam search decoding
        tgt_chunk = self.embedder(
            torch.Tensor([BOS]).repeat(batch_size, 1).type(
                torch.LongTensor).to(device=device))  # BOS
        cell_value = torch.FloatTensor([0])\
         .repeat(batch_size, 1, self.hidden_size_shared).to(device=device)
        prev_c = torch.FloatTensor([0]).repeat(batch_size, 1,
                                               max_seq_len).to(device=device)
        attn_outputs = []
        for idx in range(max_seq_len):
            predicted_logsoftmax, dec_hidden, step_attn, c_out, cell_value, attn_output = \
             self.forward_step(self.acous_att, self.acous_ffn, self.acous_out,
                 att_keys, att_vals, tgt_chunk, cell_value,
                 dec_hidden, mask, prev_c)
            predicted_logsoftmax = predicted_logsoftmax.squeeze(
                1)  # [b, vocab_size]
            step_output = predicted_logsoftmax
            symbols, decoder_outputs, sequence_symbols, lengths = \
             self.decode(idx, step_output, decoder_outputs, sequence_symbols, lengths)
            prev_c = c_out
            if use_teacher_forcing:
                tgt_chunk = emb_tgt[:, idx].unsqueeze(1)
            else:
                tgt_chunk = self.embedder(symbols)
            ret_dict[KEY_ATTN_SCORE].append(step_attn)
            attn_outputs.append(attn_output)

        ret_dict[KEY_SEQUENCE] = sequence_symbols
        ret_dict[KEY_LENGTH] = lengths.tolist()
        ret_dict[KEY_ATTN_OUT] = attn_outputs

        # import pdb; pdb.set_trace()

        return decoder_outputs, dec_hidden, ret_dict
Example #14
    def calc_score(self, att_query, att_keys, prev_c=None, use_gpu=True):
        """
			att_query: 	b x t_q x n_q (inference: t_q=1)
			att_keys:  	b x t_k x n_k
			return:		b x t_q x t_k

			'dot_prod': att = q * k^T
			'bahdanau':	att = W * tanh(Uq + Vk + b)
			'loc_based': att = a * exp[ -b (c-j)^2 ]
							j - key idx
							i - query idx
							prev_c - c_(i-1)
							a0,b0,c0 parameterised by q, k - (Uq_i + Vk_j + b)
							a = exp(a0), b=exp(b0)
							c = prev_c + exp(c0)

		"""

        device = check_device(use_gpu)

        b = att_query.size(0)
        t_q = att_query.size(1)  # = 1 if in inference mode
        t_k = att_keys.size(1)
        n_q = att_query.size(2)
        n_k = att_keys.size(2)

        c_out = None  # placeholder

        if self.mode == 'bahdanau':
            att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n_q)
            att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n_k)
            wq = self.linear_att_q(att_query).view(b, t_q, t_k,
                                                   self.hidden_size)
            uk = self.linear_att_k(att_keys).view(b, t_q, t_k,
                                                  self.hidden_size)
            sum_qk = wq + uk
            out = self.linear_att_o(torch.tanh(sum_qk)).view(b, t_q, t_k)

        elif self.mode == 'hybrid':

            # start_time = time.time()

            att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n_q)
            att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n_k)

            if not hasattr(self,
                           'linear_att_ao'):  # to work with old att setup
                self.hidden_size = 1  # fix

                a_wq = self.linear_att_aq(att_query).view(
                    b, t_q, t_k, self.hidden_size)
                a_uk = self.linear_att_ak(att_keys).view(
                    b, t_q, t_k, self.hidden_size)
                a_sum_qk = a_wq + a_uk
                a_out = torch.exp(torch.tanh(a_sum_qk)).view(b, t_q, t_k)

                b_wq = self.linear_att_bq(att_query).view(
                    b, t_q, t_k, self.hidden_size)
                b_uk = self.linear_att_bk(att_keys).view(
                    b, t_q, t_k, self.hidden_size)
                b_sum_qk = b_wq + b_uk
                b_out = torch.exp(torch.tanh(b_sum_qk)).view(b, t_q, t_k)

                c_wq = self.linear_att_cq(att_query).view(
                    b, t_q, t_k, self.hidden_size)
                c_uk = self.linear_att_ck(att_keys).view(
                    b, t_q, t_k, self.hidden_size)
                c_sum_qk = c_wq + c_uk
                c_out = torch.exp(torch.tanh(c_sum_qk)).view(b, t_q, t_k)

            else:  # new setup by default
                a_wq = self.linear_att_aq(att_query).view(
                    b, t_q, t_k, self.hidden_size)
                a_uk = self.linear_att_ak(att_keys).view(
                    b, t_q, t_k, self.hidden_size)
                a_sum_qk = a_wq + a_uk
                a_out = torch.exp(self.linear_att_ao(
                    torch.tanh(a_sum_qk))).view(b, t_q, t_k)

                b_wq = self.linear_att_bq(att_query).view(
                    b, t_q, t_k, self.hidden_size)
                b_uk = self.linear_att_bk(att_keys).view(
                    b, t_q, t_k, self.hidden_size)
                b_sum_qk = b_wq + b_uk
                b_out = torch.exp(self.linear_att_bo(
                    torch.tanh(b_sum_qk))).view(b, t_q, t_k)

                c_wq = self.linear_att_cq(att_query).view(
                    b, t_q, t_k, self.hidden_size)
                c_uk = self.linear_att_ck(att_keys).view(
                    b, t_q, t_k, self.hidden_size)
                c_sum_qk = c_wq + c_uk
                c_out = torch.exp(self.linear_att_co(
                    torch.tanh(c_sum_qk))).view(b, t_q, t_k)

            # print(time.time() - start_time)

            if t_q != 1:
                # teacher forcing mode - t_q != 1
                key_indices = torch.arange(t_k).repeat(b, t_q).view(b, t_q, t_k)\
                 .type(torch.FloatTensor).to(device=device)
                c_curr = torch.FloatTensor([0]).repeat(b, t_q,
                                                       t_k).to(device=device)

                for i in range(t_q):
                    c_temp = torch.sum(c_out[:, :i + 1, :], dim=1)
                    c_curr[:, i, :] = c_temp
                out = a_out * torch.exp(-b_out * torch.pow(
                    (c_curr - key_indices), 2))

            else:
                # inference mode: t_q = 1
                key_indices = torch.arange(t_k).repeat(b, 1).view(b, 1, t_k)\
                 .type(torch.FloatTensor).to(device=device)

                c_out = prev_c + c_out
                out = a_out * torch.exp(-b_out * torch.pow(
                    (c_out - key_indices), 2))

        elif self.mode == 'bilinear':

            wk = self.linear_att_w(att_keys).view(b, t_k, n_q)
            out = torch.bmm(att_query, wk.transpose(1, 2))

        elif self.mode == 'dot_prod':

            assert n_q == n_k, 'Dot_prod attention - query, key size must agree!'
            out = torch.bmm(att_query, att_keys.transpose(1, 2))

        return out, c_out
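
In the 'hybrid' branch the score follows the location-based form from the docstring, att = a * exp(-b * (c - j)^2), with a, b and the increment of c produced from the query/key projections. A purely numeric sketch of that formula, with the projections replaced by assumed constants:

import torch

# numeric sketch of att = a * exp(-b * (c - j)^2); a, b, c are assumed constants here
t_k = 6
key_indices = torch.arange(t_k, dtype=torch.float)   # j = 0 .. 5
a, b, c = 1.0, 0.5, 2.0

att = a * torch.exp(-b * (c - key_indices) ** 2)
print([round(v, 3) for v in att.tolist()])           # peaks at j = 2 and decays on both sides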
Example #15
    def forward_train(self,
                      src,
                      tgt=None,
                      acous_feats=None,
                      acous_lens=None,
                      mode='ST',
                      use_gpu=True,
                      lm_mode='null',
                      lm_model=None):
        """
			mode: 	ASR 		acous -> src
					AE 			src -> src
					ST			acous -> tgt
					MT 			src -> tgt
		"""

        # import pdb; pdb.set_trace()
        # note: adding .type(torch.uint8) to be compatible with pytorch 1.1!
        out_dict = {}

        # check gpu
        global device
        device = check_device(use_gpu)

        # check mode
        mode = mode.upper()
        assert src is not None
        if 'ST' in mode or 'ASR' in mode:
            assert acous_feats is not None
        if 'ST' in mode or 'MT' in mode:
            assert tgt is not None

        if 'ASR' in mode:
            """
				acous -> EN: RNN
				in : length reduced fbk features
				out: w1 w2 w3 <EOS> <PAD> <PAD> #=6
			"""
            emb_src, logps_src, preds_src, lengths = self._encoder_acous(
                acous_feats,
                acous_lens,
                device,
                use_gpu,
                tgt=src,
                is_training=True,
                teacher_forcing_ratio=1.0,
                lm_mode=lm_mode,
                lm_model=lm_model)

            # output dict
            out_dict['emb_asr'] = emb_src  # dynamic
            out_dict['preds_asr'] = preds_src
            out_dict['logps_asr'] = logps_src
            out_dict['lengths_asr'] = lengths

        if 'MT' in mode:
            """
				EN -> DE: Transformer
				src: <BOS> w1 w2 w3 <EOS> <PAD> <PAD> #=7
				mid: w1 w2 w3 <EOS> <PAD> <PAD> #=6
				out: c1 c2 c3 <EOS> <PAD> <PAD> [dummy] #=7

				note: add average dynamic embedding to static embedding
			"""
            # get tgt emb
            tgt_mask, emb_tgt = self._get_tgt_emb(tgt, device)
            # get src emb
            src_trim = self._pre_proc_src(src, device)
            emb_dyn_ave = self.EMB_DYN_AVE
            emb_dyn_ave_expand = emb_dyn_ave.repeat(src_trim.size(0),
                                                    src_trim.size(1),
                                                    1).to(device=device)
            src_mask, emb_src, src_mask_input = self._get_src_emb(
                src_trim, emb_dyn_ave_expand, device)

            # encode decode
            enc_outputs = self._encoder_en(
                emb_src, src_mask=src_mask_input)  # b x len x dim_model
            # decode
            dec_outputs_tgt, logits_tgt, logps_tgt, preds_tgt, _ = \
             self._decoder_de(emb_tgt, enc_outputs, tgt_mask=tgt_mask, src_mask=src_mask_input)

            # output dict
            out_dict['emb_mt'] = emb_src  # combined
            out_dict['preds_mt'] = preds_tgt
            out_dict['logps_mt'] = logps_tgt

        if 'ST' in mode:
            """
				acous -> DE: Transformer
				in : length reduced fbk features
				mid: w1 w2 w3 <EOS> <PAD> <PAD> #=6
				out: c1 c2 c3 <EOS> <PAD> <PAD> [dummy] #=7
			"""
            # get tgt emb
            tgt_mask, emb_tgt = self._get_tgt_emb(tgt, device)
            # run ASR
            if 'ASR' in mode:
                emb_src_dyn = out_dict['emb_asr']
                lengths = out_dict['lengths_asr']
            # else: # use free running if no 'ASR'
            # 	emb_src_dyn, _, _, lengths = self._encoder_acous(acous_feats, acous_lens,
            # 		device, use_gpu, tgt=src, is_training=True, teacher_forcing_ratio=1.0)
            else:  # use free running if no 'ASR'
                emb_src_dyn, _, _, lengths = self._encoder_acous(
                    acous_feats,
                    acous_lens,
                    device,
                    use_gpu,
                    is_training=False,
                    teacher_forcing_ratio=0.0,
                    lm_mode=lm_mode,
                    lm_model=lm_model)

            # get combined embedding
            src_trim = self._pre_proc_src(src, device)
            _, emb_src, _ = self._get_src_emb(src_trim, emb_src_dyn, device)

            # get mask
            max_len = emb_src.size(1)
            lengths = torch.LongTensor(lengths)
            src_mask_input = (torch.arange(max_len).expand(
                len(lengths), max_len) < lengths.unsqueeze(1)).unsqueeze(1).to(
                    device=device)
            # encode
            enc_outputs = self._encoder_en(
                emb_src, src_mask=src_mask_input)  # b x len x dim_model
            # decode
            dec_outputs_tgt, logits_tgt, logps_tgt, preds_tgt, _ = \
             self._decoder_de(emb_tgt, enc_outputs, tgt_mask=tgt_mask, src_mask=src_mask_input)

            # output dict
            out_dict['emb_st'] = emb_src  # combined
            out_dict['preds_st'] = preds_tgt
            out_dict['logps_st'] = logps_tgt

        return out_dict
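Note: the src_mask_input built from lengths above is the usual length-to-mask trick: position i of sequence j is kept when i < lengths[j]. A small self-contained sketch (toy lengths, illustrative only):

import torch

lengths = torch.LongTensor([3, 5, 2])            # true lengths of three sequences
max_len = 5
# True where the position is a real token, False where it is padding
mask = torch.arange(max_len).expand(len(lengths), max_len) < lengths.unsqueeze(1)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True],
#         [ True,  True, False, False, False]])
mask = mask.unsqueeze(1)                         # b x 1 x len, broadcastable over query positions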
    def forward_translate_refen(self,
                                acous_feats=None,
                                acous_lens=None,
                                src=None,
                                beam_width=1,
                                penalty_factor=1,
                                use_gpu=True,
                                max_seq_len=900,
                                mode='ST',
                                lm_mode='null',
                                lm_model=None):
        """
			run inference with beam search (same output format as in forward_eval)
		"""

        # import pdb; pdb.set_trace()

        # check gpu
        global device
        device = check_device(use_gpu)

        if mode == 'ASR':
            _, _, preds_src, _ = self._encoder_acous(acous_feats,
                                                     acous_lens,
                                                     device,
                                                     use_gpu,
                                                     tgt=src,
                                                     is_training=False,
                                                     teacher_forcing_ratio=1.0,
                                                     lm_mode=lm_mode,
                                                     lm_model=lm_model)

            preds = preds_src

        elif mode == 'MT':
            batch = src.size(0)

            # txt encoder
            src_trim = self._pre_proc_src(src, device)
            emb_dyn_ave = self.EMB_DYN_AVE
            emb_dyn_ave_expand = emb_dyn_ave.repeat(src_trim.size(0),
                                                    src_trim.size(1),
                                                    1).to(device=device)
            src_mask, emb_src, src_mask_input = self._get_src_emb(
                src_trim, emb_dyn_ave_expand, device)
            enc_outputs = self._encoder_en(emb_src, src_mask=src_mask_input)
            length_in = enc_outputs.size(1)

            # prep
            eos_mask, len_map, preds, enc_outputs_expand, preds_expand, \
             scores_expand, src_mask_input_expand = self._prep_translate(
             batch, beam_width, device, length_in, enc_outputs, src_mask_input)

            # loop over sequence length
            for i in range(1, max_seq_len):

                tgt_mask_expand, emb_tgt_expand = self._get_tgt_emb(
                    preds_expand, device)
                dec_output_expand, logit_expand, logp_expand, pred_expand, score_expand = \
                 self._decoder_de(emb_tgt_expand, enc_outputs_expand,
                 tgt_mask=tgt_mask_expand, src_mask=src_mask_input_expand,
                 beam_width=beam_width)

                scores_expand, preds_expand, eos_mask, len_map, flag = \
                 self._step_translate(i, batch, beam_width, device,
                  dec_output_expand, logp_expand, pred_expand, score_expand,
                  preds_expand, scores_expand, eos_mask, len_map, penalty_factor)
                if flag == 1: break

            # select the best candidate
            preds = preds_expand.reshape(
                batch, -1)[:, :max_seq_len].contiguous()  # b x len
            scores = scores_expand.reshape(batch, -1)[:, 0].contiguous()  # b

        elif mode == 'ST':
            batch = acous_feats.size(0)

            # get embedding
            emb_src_dyn, _, preds_src, lengths = self._encoder_acous(
                acous_feats,
                acous_lens,
                device,
                use_gpu,
                tgt=src,
                is_training=False,
                teacher_forcing_ratio=1.0,
                lm_mode=lm_mode,
                lm_model=lm_model)
            src_trim = self._pre_proc_src(src, device)
            _, emb_src, _ = self._get_src_emb(src_trim, emb_src_dyn,
                                              device)  # use ref

            # get mask
            max_len = emb_src.size(1)
            lengths = torch.LongTensor(lengths)
            src_mask_input = (torch.arange(max_len).expand(
                len(lengths), max_len) < lengths.unsqueeze(1)).unsqueeze(1).to(
                    device=device)
            # encode
            enc_outputs = self._encoder_en(
                emb_src, src_mask=src_mask_input)  # b x len x dim_model
            length_in = enc_outputs.size(1)

            # prep
            eos_mask, len_map, preds, enc_outputs_expand, preds_expand, \
             scores_expand, src_mask_input_expand = self._prep_translate(
             batch, beam_width, device, length_in, enc_outputs, src_mask_input)

            # loop over sequence length
            for i in range(1, max_seq_len):

                # import pdb; pdb.set_trace()

                # Get k candidates for each beam, k^2 candidates in total (k=beam_width)
                tgt_mask_expand, emb_tgt_expand = self._get_tgt_emb(
                    preds_expand, device)
                dec_output_expand, logit_expand, logp_expand, pred_expand, score_expand = \
                 self._decoder_de(emb_tgt_expand, enc_outputs_expand,
                 tgt_mask=tgt_mask_expand, src_mask=src_mask_input_expand,
                 beam_width=beam_width)

                scores_expand, preds_expand, eos_mask, len_map, flag = \
                 self._step_translate(i, batch, beam_width, device,
                  dec_output_expand, logp_expand, pred_expand, score_expand,
                  preds_expand, scores_expand, eos_mask, len_map, penalty_factor)
                if flag == 1: break

            # select the best candidate
            preds = preds_expand.reshape(
                batch, -1)[:, :max_seq_len].contiguous()  # b x len
            scores = scores_expand.reshape(batch, -1)[:, 0].contiguous()  # b

        return preds
Exemplo n.º 17
0
    def forward(self, query, keys, values=None, prev_c=None, use_gpu=True):
        """
			query(out):	b x t_q x n_q
			keys(in): 	b x t_k x n_k (usually: n_k >= n_v - keys are richer)
			vals(in): 	b x t_k x n_v
			context:	b x t_q x output_size
			scores: 	b x t_q x t_k

			prev_c: for loc_based attention; None otherwise
			c_out: for loc_based attention; None otherwise

			in general
				n_q = embedding_dim
				n_k = size of key vectors
				n_v = size of value vectors
		"""

        device = check_device(use_gpu)

        if not self.batch_first:
            keys = keys.transpose(0, 1)
            if values is not None:
                values = values.transpose(0, 1)
            if query.dim() == 3:
                query = query.transpose(0, 1)

        if query.dim() == 2:
            single_query = True
            query = query.unsqueeze(1)
        else:
            single_query = False

        values = keys if values is None else values  # b x t_k x n_v/n_k

        b = query.size(0)
        t_k = keys.size(1)
        t_q = query.size(1)

        if hasattr(self, 'linear_q'):
            att_query = self.linear_q(query)
        else:
            att_query = query

        # b x t_q x t_k
        scores, c_out = self.calc_score(att_query,
                                        keys,
                                        prev_c,
                                        use_gpu=use_gpu)

        if self.mask is not None:
            mask = self.mask.unsqueeze(1).expand(b, t_q, t_k)
            scores.masked_fill_(mask, -1e12)

        # Normalize the scores OR use hard attention
        if hasattr(self, 'hard_att'):
            if self.hard_att:
                top_idx = torch.argmax(scores, dim=2)
                scores_view = scores.view(-1, t_k)
                scores_hard = (scores_view == scores_view.max(
                    dim=1, keepdim=True)[0]).view_as(scores)
                scores_hard = scores_hard.type(torch.FloatTensor)
                total_score = torch.sum(scores_hard, dim=2)
                total_score = total_score.view(b, t_q,
                                               1).repeat(1, 1,
                                                         t_k).view_as(scores)
                scores_normalized = (scores_hard /
                                     total_score).to(device=device)
            else:
                scores_normalized = F.softmax(scores, dim=2)
        else:
            scores_normalized = F.softmax(scores, dim=2)
        # print(torch.argmax(scores_normalized[0], dim=1))

        # Context = the weighted average of the attention inputs
        # scores_normalized = self.dropout(scores_normalized) # b x t_q x t_k
        context = torch.bmm(scores_normalized, values)  # b x t_q x n_v

        if hasattr(self, 'linear_out'):
            context = self.linear_out(torch.cat([query, context], 2))
            if self.output_nonlinearity == 'tanh':
                context = torch.tanh(context)  # torch.tanh (F.tanh is deprecated)
            elif self.output_nonlinearity == 'relu':
                context = F.relu(context, inplace=True)

        if single_query:
            context = context.squeeze(1)
            scores_normalized = scores_normalized.squeeze(1)
        elif not self.batch_first:
            context = context.transpose(0, 1)
            scores_normalized = scores_normalized.transpose(0, 1)

        return context, scores_normalized, c_out
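Note: the dot_prod branch of calc_score plus the softmax/bmm steps above amount to standard (unscaled) dot-product attention. A self-contained sketch of that path with toy tensors; the padding mask here is invented purely for illustration.

import torch
import torch.nn.functional as F

b, t_q, t_k, n = 2, 3, 4, 8
query = torch.randn(b, t_q, n)
keys = torch.randn(b, t_k, n)
values = keys                                    # the module falls back to keys when values is None

scores = torch.bmm(query, keys.transpose(1, 2))  # b x t_q x t_k
mask = torch.zeros(b, t_q, t_k, dtype=torch.bool)
mask[:, :, -1] = True                            # pretend the last key position is padding
scores = scores.masked_fill(mask, -1e12)

weights = F.softmax(scores, dim=2)               # normalise over key positions
context = torch.bmm(weights, values)             # b x t_q x n (weighted average of the values)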
Exemplo n.º 18
0
Arquivo: train.py Projeto: EdieLu/LAS
def main():

	# import pdb; pdb.set_trace()
	# load config
	parser = argparse.ArgumentParser(description='LAS Training')
	parser = load_arguments(parser)
	args = vars(parser.parse_args())
	config = validate_config(args)

	# set random seed
	if config['random_seed'] is not None:
		set_global_seeds(config['random_seed'])

	# record config
	if not os.path.isabs(config['save']):
		config_save_dir = os.path.join(os.getcwd(), config['save'])
	if not os.path.exists(config['save']):
		os.makedirs(config['save'])

	# resume or not
	if type(config['load']) != type(None):
		config_save_dir = os.path.join(config['save'], 'model-cont.cfg')
	else:
		config_save_dir = os.path.join(config['save'], 'model.cfg')
	save_config(config, config_save_dir)

	# construct trainer
	t = Trainer(expt_dir=config['save'],
					load_dir=config['load'],
					batch_size=config['batch_size'],
					minibatch_partition=config['minibatch_partition'],
					checkpoint_every=config['checkpoint_every'],
					print_every=config['print_every'],
					learning_rate=config['learning_rate'],
					eval_with_mask=config['eval_with_mask'],
					scheduled_sampling=config['scheduled_sampling'],
					teacher_forcing_ratio=config['teacher_forcing_ratio'],
					use_gpu=config['use_gpu'],
					max_grad_norm=config['max_grad_norm'],
					max_count_no_improve=config['max_count_no_improve'],
					max_count_num_rollback=config['max_count_num_rollback'],
					keep_num=config['keep_num'],
					normalise_loss=config['normalise_loss'])

	# vocab
	path_vocab_src = config['path_vocab_src']

	# load train set
	train_path_src = config['train_path_src']
	train_acous_path = config['train_acous_path']
	train_set = Dataset(train_path_src, path_vocab_src=path_vocab_src,
		use_type=config['use_type'],
		acous_path=train_acous_path,
		seqrev=config['seqrev'],
		acous_norm=config['acous_norm'],
		acous_norm_path=config['acous_norm_path'],
		max_seq_len=config['max_seq_len'],
		batch_size=config['batch_size'],
		acous_max_len=config['acous_max_len'],
		use_gpu=config['use_gpu'],
		logger=t.logger)

	vocab_size = len(train_set.vocab_src)

	# load dev set
	if config['dev_path_src']:
		dev_path_src = config['dev_path_src']
		dev_acous_path = config['dev_acous_path']
		dev_set = Dataset(dev_path_src, path_vocab_src=path_vocab_src,
			use_type=config['use_type'],
			acous_path=dev_acous_path,
			acous_norm_path=config['acous_norm_path'],
			seqrev=config['seqrev'],
			acous_norm=config['acous_norm'],
			max_seq_len=config['max_seq_len'],
			batch_size=config['batch_size'],
			acous_max_len=config['acous_max_len'],
			use_gpu=config['use_gpu'],
			logger=t.logger)
	else:
		dev_set = None

	# construct model
	las_model = LAS(vocab_size,
					embedding_size=config['embedding_size'],
					acous_hidden_size=config['acous_hidden_size'],
					acous_att_mode=config['acous_att_mode'],
					hidden_size_dec=config['hidden_size_dec'],
					hidden_size_shared=config['hidden_size_shared'],
					num_unilstm_dec=config['num_unilstm_dec'],
					#
					acous_dim=config['acous_dim'],
					acous_norm=config['acous_norm'],
					spec_aug=config['spec_aug'],
					batch_norm=config['batch_norm'],
					enc_mode=config['enc_mode'],
					use_type=config['use_type'],
					#
					embedding_dropout=config['embedding_dropout'],
					dropout=config['dropout'],
					residual=config['residual'],
					batch_first=config['batch_first'],
					max_seq_len=config['max_seq_len'],
					load_embedding=config['load_embedding'],
					word2id=train_set.src_word2id,
					id2word=train_set.src_id2word,
					use_gpu=config['use_gpu'])

	device = check_device(config['use_gpu'])
	t.logger.info('device:{}'.format(device))
	las_model = las_model.to(device=device)

	# run training
	las_model = t.train(
		train_set, las_model, num_epochs=config['num_epochs'], dev_set=dev_set)
Exemplo n.º 19
0
	def forward(self, enc_outputs, src, tgt=None,
		hidden=None, is_training=False, teacher_forcing_ratio=1.0,
		beam_width=1, use_gpu=True):

		"""
			Args:
				enc_outputs: [batch_size, max_seq_len, self.hidden_size_enc * 2]
				tgt: list of tgt word_ids
				hidden: initial hidden state
				is_training: whether in eval or train mode
				teacher_forcing_ratio: default at 1 - always teacher forcing
			Returns:
				decoder_outputs: list of step_output - log predicted_softmax
					[batch_size, 1, vocab_size_dec] * (T-1)
				ret_dict
		"""

		# import pdb; pdb.set_trace()

		global device
		device = check_device(use_gpu)

		# 0. init var
		ret_dict = dict()
		ret_dict[KEY_ATTN_SCORE] = []

		decoder_outputs = []
		sequence_symbols = []
		batch_size = enc_outputs.size(0)

		if type(tgt) == type(None): tgt = torch.Tensor([BOS]).repeat(
			(batch_size, self.max_seq_len)).type(torch.LongTensor).to(device=device)
		max_seq_len = tgt.size(1)
		lengths = np.array([max_seq_len] * batch_size)

		# 1. convert id to embedding
		emb_tgt = self.embedding_dropout(self.embedder_dec(tgt))

		# 2. att inputs: keys and values
		mask_src = src.data.eq(PAD)
		att_keys = enc_outputs
		att_vals = enc_outputs

		# 3. init hidden states
		dec_hidden = None

		# decoder
		def decode(step, step_output, step_attn):

			"""
				Greedy decoding
				Note:
					it should generate EOS, PAD as used in training tgt
				Args:
					step: step idx
					step_output: log predicted_softmax -
						[batch_size, 1, vocab_size_dec]
					step_attn: attention scores -
						(batch_size x tgt_len(query_len) x src_len(key_len))
				Returns:
					symbols: most probable symbol_id [batch_size, 1]
			"""

			ret_dict[KEY_ATTN_SCORE].append(step_attn)
			decoder_outputs.append(step_output)
			symbols = decoder_outputs[-1].topk(1)[1]
			sequence_symbols.append(symbols)

			eos_batches = torch.max(symbols.data.eq(EOS), symbols.data.eq(PAD))
			# eos_batches = symbols.data.eq(PAD)
			if eos_batches.dim() > 0:
				eos_batches = eos_batches.cpu().view(-1).numpy()
				update_idx = ((lengths > step) & eos_batches) != 0
				lengths[update_idx] = len(sequence_symbols)

			return symbols

		# 4. run dec + att + shared + output
		"""
			teacher_forcing_ratio = 1.0 -> always teacher forcing

			E.g.: (shift-by-1)
			emb_tgt      = <s> w1 w2 w3 </s> <pad> <pad> <pad>   [max_seq_len]
			tgt_chunk in = <s> w1 w2 w3 </s> <pad> <pad>         [max_seq_len - 1]
			predicted    =     w1 w2 w3 </s> <pad> <pad> <pad>   [max_seq_len - 1]

		"""
		# import pdb; pdb.set_trace()
		if not is_training:
			use_teacher_forcing = False
		elif random.random() < teacher_forcing_ratio:
			use_teacher_forcing = True
		else:
			use_teacher_forcing = False

		# beam search decoding
		if not is_training and beam_width > 1:
			decoder_outputs, decoder_hidden, metadata = \
				self.beam_search_decoding(att_keys, att_vals,
				dec_hidden, mask_src, beam_width=beam_width, device=device)
			return decoder_outputs, decoder_hidden, metadata

		# greedy search decoding
		tgt_chunk = emb_tgt[:, 0].unsqueeze(1) # BOS
		cell_value = torch.FloatTensor([0]).repeat(
			batch_size, 1, self.hidden_size_shared).to(device=device)
		prev_c = torch.FloatTensor([0]).repeat(
			batch_size, 1, max_seq_len).to(device=device)
		for idx in range(max_seq_len - 1):
			predicted_logsoftmax, dec_hidden, step_attn, c_out, cell_value = \
				self.forward_step(att_keys, att_vals, tgt_chunk,
					cell_value, dec_hidden, mask_src, prev_c)
			predicted_logsoftmax = predicted_logsoftmax.squeeze(1) # [b, vocab_size]
			step_output = predicted_logsoftmax
			symbols = decode(idx, step_output, step_attn)
			prev_c = c_out
			if use_teacher_forcing:
				tgt_chunk = emb_tgt[:, idx+1].unsqueeze(1)
			else:
				tgt_chunk = self.embedder_dec(symbols)

		ret_dict[KEY_SEQUENCE] = sequence_symbols
		ret_dict[KEY_LENGTH] = lengths.tolist()

		return decoder_outputs, dec_hidden, ret_dict
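Note: the decode() helper above records, per sequence, the step at which the first EOS or PAD is emitted; lengths stays at max_seq_len for sequences that never finish. A compact standalone sketch of that bookkeeping (the token ids and random predictions are illustrative):

import numpy as np
import torch

PAD, EOS = 0, 2                                   # illustrative ids
batch_size, max_seq_len = 3, 6
lengths = np.array([max_seq_len] * batch_size)    # default: full length

sequence_symbols = []
for step in range(max_seq_len - 1):
    symbols = torch.randint(0, 5, (batch_size, 1))        # stand-in for topk(1)[1]
    sequence_symbols.append(symbols)
    eos_batches = (symbols.eq(EOS) | symbols.eq(PAD)).view(-1).numpy()
    # only sequences that have not finished yet (lengths > step) get their length fixed
    update_idx = ((lengths > step) & eos_batches) != 0
    lengths[update_idx] = len(sequence_symbols)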
Exemplo n.º 20
0
Arquivo: Enc.py Projeto: EdieLu/LAS
    def __init__(
        self,
        # params
        acous_dim=26,
        acous_hidden_size=256,
        #
        acous_norm=False,
        spec_aug=False,
        batch_norm=False,
        enc_mode='pyramid',
        #
        dropout=0.0,
        batch_first=True,
        use_gpu=False,
    ):

        super(Enc, self).__init__()
        device = check_device(use_gpu)

        # define model
        self.acous_dim = acous_dim
        self.acous_hidden_size = acous_hidden_size

        # tuning
        self.acous_norm = acous_norm
        self.spec_aug = spec_aug
        self.batch_norm = batch_norm
        self.enc_mode = enc_mode

        # define operations
        self.dropout = nn.Dropout(dropout)

        # ------ define acous enc -------
        if self.enc_mode == 'pyramid':
            self.acous_enc_l1 = torch.nn.LSTM(self.acous_dim,
                                              self.acous_hidden_size,
                                              num_layers=1,
                                              batch_first=batch_first,
                                              bias=True,
                                              dropout=dropout,
                                              bidirectional=True)
            self.acous_enc_l2 = torch.nn.LSTM(self.acous_hidden_size * 4,
                                              self.acous_hidden_size,
                                              num_layers=1,
                                              batch_first=batch_first,
                                              bias=True,
                                              dropout=dropout,
                                              bidirectional=True)
            self.acous_enc_l3 = torch.nn.LSTM(self.acous_hidden_size * 4,
                                              self.acous_hidden_size,
                                              num_layers=1,
                                              batch_first=batch_first,
                                              bias=True,
                                              dropout=dropout,
                                              bidirectional=True)
            self.acous_enc_l4 = torch.nn.LSTM(self.acous_hidden_size * 4,
                                              self.acous_hidden_size,
                                              num_layers=1,
                                              batch_first=batch_first,
                                              bias=True,
                                              dropout=dropout,
                                              bidirectional=True)
            if self.batch_norm:
                self.bn1 = nn.BatchNorm1d(self.acous_hidden_size * 2)
                self.bn2 = nn.BatchNorm1d(self.acous_hidden_size * 2)
                self.bn3 = nn.BatchNorm1d(self.acous_hidden_size * 2)
                self.bn4 = nn.BatchNorm1d(self.acous_hidden_size * 2)

        elif self.enc_mode == 'cnn':
            # todo
            pass
Exemplo n.º 21
0
Arquivo: Enc.py Projeto: EdieLu/LAS
    def forward(self,
                acous_feats,
                acous_lens=None,
                is_training=False,
                hidden=None,
                use_gpu=False):
        """
			Args:
				acous_feats: list of acoustic features 	[b x acous_len x ?]
		"""

        # import pdb; pdb.set_trace()

        device = check_device(use_gpu)

        batch_size = acous_feats.size(0)
        acous_len = acous_feats.size(1)

        # pre-process acoustics
        if is_training: acous_feats = self.pre_process_acous(acous_feats)

        # run acous enc - pyramidal
        acous_hidden_init = None
        if self.enc_mode == 'pyramid':
            # layer1
            # pack to rnn packed seq obj
            # pad lengths up to a multiple of 8 (the pyramid halves the time axis three times)
            acous_lens_l1 = torch.cat(
                [elem + 8 - elem % 8 for elem in acous_lens])
            acous_feats_pack = torch.nn.utils.rnn.pack_padded_sequence(
                acous_feats,
                acous_lens_l1,
                batch_first=True,
                enforce_sorted=False)
            # run lstm
            acous_outputs_l1_pack, acous_hidden_l1 = self.acous_enc_l1(
                acous_feats_pack, acous_hidden_init)  # b x acous_len x 2dim
            # unpack
            acous_outputs_l1, _ = torch.nn.utils.rnn.pad_packed_sequence(
                acous_outputs_l1_pack, batch_first=True)
            # dropout
            acous_outputs_l1 = self.dropout(acous_outputs_l1)\
             .reshape(batch_size, acous_len, acous_outputs_l1.size(-1))
            # batch norm
            if self.batch_norm:
                acous_outputs_l1 = self.bn1(acous_outputs_l1.permute(0, 2, 1))\
                 .permute(0, 2, 1)
            # reduce length
            acous_inputs_l2 = acous_outputs_l1\
             .reshape(batch_size, int(acous_len/2), 2*acous_outputs_l1.size(-1))
            # b x acous_len/2 x 4dim

            # layer2
            acous_lens_l2 = acous_lens_l1 // 2  # lengths are integers: use floor division
            acous_inputs_l2_pack = torch.nn.utils.rnn.pack_padded_sequence(
                acous_inputs_l2,
                acous_lens_l2,
                batch_first=True,
                enforce_sorted=False)
            acous_outputs_l2_pack, acous_hidden_l2 = self.acous_enc_l2(
                acous_inputs_l2_pack,
                acous_hidden_init)  # b x acous_len/2 x 2dim
            acous_outputs_l2, _ = torch.nn.utils.rnn.pad_packed_sequence(
                acous_outputs_l2_pack, batch_first=True)
            acous_outputs_l2 = self.dropout(acous_outputs_l2)\
             .reshape(batch_size, int(acous_len/2), acous_outputs_l2.size(-1))
            if self.batch_norm:
                acous_outputs_l2 = self.bn2(acous_outputs_l2.permute(0, 2, 1))\
                 .permute(0, 2, 1)
            acous_inputs_l3 = acous_outputs_l2\
             .reshape(batch_size, int(acous_len/4), 2*acous_outputs_l2.size(-1))
            # b x acous_len/4 x 4dim

            # layer3
            acous_lens_l3 = acous_lens_l2 // 2
            acous_inputs_l3_pack = torch.nn.utils.rnn.pack_padded_sequence(
                acous_inputs_l3,
                acous_lens_l3,
                batch_first=True,
                enforce_sorted=False)
            acous_outputs_l3_pack, acous_hidden_l3 = self.acous_enc_l3(
                acous_inputs_l3_pack,
                acous_hidden_init)  # b x acous_len/4 x 2dim
            acous_outputs_l3, _ = torch.nn.utils.rnn.pad_packed_sequence(
                acous_outputs_l3_pack, batch_first=True)
            acous_outputs_l3 = self.dropout(acous_outputs_l3)\
             .reshape(batch_size, int(acous_len/4), acous_outputs_l3.size(-1))
            if self.batch_norm:
                acous_outputs_l3 = self.bn3(acous_outputs_l3.permute(0, 2, 1))\
                 .permute(0, 2, 1)
            acous_inputs_l4 = acous_outputs_l3\
             .reshape(batch_size, int(acous_len/8), 2*acous_outputs_l3.size(-1))
            # b x acous_len/8 x 4dim

            # layer4
            acous_lens_l4 = acous_lens_l3 // 2
            acous_inputs_l4_pack = torch.nn.utils.rnn.pack_padded_sequence(
                acous_inputs_l4,
                acous_lens_l4,
                batch_first=True,
                enforce_sorted=False)
            acous_outputs_l4_pack, acous_hidden_l4 = self.acous_enc_l4(
                acous_inputs_l4_pack,
                acous_hidden_init)  # b x acous_len/8 x 2dim
            acous_outputs_l4, _ = torch.nn.utils.rnn.pad_packed_sequence(
                acous_outputs_l4_pack, batch_first=True)
            acous_outputs_l4 = self.dropout(acous_outputs_l4)\
             .reshape(batch_size, int(acous_len/8), acous_outputs_l4.size(-1))
            if self.batch_norm:
                acous_outputs_l4 = self.bn4(acous_outputs_l4.permute(0, 2, 1))\
                 .permute(0, 2, 1)
            acous_outputs = acous_outputs_l4

        elif self.enc_mode == 'cnn':
            pass  #todo

        # import pdb; pdb.set_trace()
        return acous_outputs
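Note: each pyramid layer above halves the time axis by folding two adjacent frames into one feature vector, which is why layers 2 through 4 take acous_hidden_size * 4 inputs (two concatenated bidirectional outputs of size 2 * hidden). A shape-only sketch of that reshape with toy sizes:

import torch

batch, acous_len, hidden = 2, 16, 8
layer_out = torch.randn(batch, acous_len, 2 * hidden)   # bidirectional LSTM output
# fold frame pairs: time axis halves, feature axis doubles
next_in = layer_out.reshape(batch, acous_len // 2, 4 * hidden)
print(next_in.shape)                                     # torch.Size([2, 8, 32])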
	def forward_train(self, src, tgt=None, acous_feats=None, acous_lens=None,
		mode='ST', use_gpu=True, lm_mode='null', lm_model=None):

		"""
			mode: 	ASR 		acous -> src
					AE 			src -> src
					ST			acous -> tgt
					MT 			src -> tgt
		"""

		# for backward compatibility
		self.check_var('perturb_emb', False)

		# import pdb; pdb.set_trace()
		# note: adding .type(torch.uint8) to be compatible with pytorch 1.1!
		out_dict = {}

		# check gpu
		global device
		device = check_device(use_gpu)

		# check mode
		mode = mode.upper()
		assert type(src) != type(None)
		if 'ST' in mode or 'ASR' in mode:
			assert type(acous_feats) != type(None)
		if 'ST' in mode or 'MT' in mode:
			assert type(tgt) != type(None)

		if 'ASR' in mode:
			"""
				acous -> EN: RNN
				in : length reduced fbk features
				out: w1 w2 w3 <EOS> <PAD> <PAD> #=6
			"""
			emb_src, logps_src, preds_src, lengths = self._encoder_acous(acous_feats, acous_lens,
				device, use_gpu, tgt=src, is_training=True, teacher_forcing_ratio=1.0,
				lm_mode=lm_mode, lm_model=lm_model)

			# output dict
			out_dict['emb_asr'] = emb_src
			out_dict['preds_asr'] = preds_src
			out_dict['logps_asr'] = logps_src
			out_dict['lengths_asr'] = lengths

		if 'AE' in mode:
			"""
				EN -> EN: Embedder
				src: <BOS> w1 w2 w3 <EOS> <PAD> <PAD> #=7
				mid: w1 w2 w3 <EOS> <PAD> <PAD> #=6
				out: w1 w2 w3 <EOS> <PAD> <PAD> #=6
			"""

			if 'ASR' in mode:
				src_trim = out_dict['preds_asr'].squeeze(2)
			else:
				# run asr: by default
				_, _, preds_src, _ = self._encoder_acous(acous_feats, acous_lens,
					device, use_gpu, tgt=src, is_training=True, teacher_forcing_ratio=1.0,
					lm_mode=lm_mode, lm_model=lm_model)
				src_trim = preds_src.squeeze(2)

			# from src: use with NLL loss
			# src_trim = self._pre_proc_src(src, device)

			src_mask, emb_src, src_mask_input = self._get_src_emb(src_trim, device)
			logits_src, logps_src, preds_src, _ = self._decoder_en(emb_src)

			# output dict
			out_dict['emb_ae'] = emb_src
			out_dict['refs_ae'] = src_trim
			out_dict['preds_ae'] = preds_src
			out_dict['logps_ae'] = logps_src

		if 'MT' in mode:
			"""
				EN -> DE: Transformer
				src: <BOS> w1 w2 w3 <EOS> <PAD> <PAD> #=7
				mid: w1 w2 w3 <EOS> <PAD> <PAD> #=6
				out: c1 c2 c3 <EOS> <PAD> <PAD> [dummy] #=7

				optional: add perturbation to embeddings
			"""
			# get tgt emb
			tgt_mask, emb_tgt = self._get_tgt_emb(tgt, device)
			# get static src emb
			src_trim = self._pre_proc_src(src, device)
			src_mask, emb_src, src_mask_input = self._get_src_emb(src_trim, device)
			# perturb emb
			if self.perturb_emb: emb_src = self._perturb_emb(emb_src, device)
			# encode decode
			enc_outputs = self._encoder_en(emb_src, src_mask=src_mask_input) # b x len x dim_model
			# decode
			dec_outputs_tgt, logits_tgt, logps_tgt, preds_tgt, _ = \
				self._decoder_de(emb_tgt, enc_outputs, tgt_mask=tgt_mask, src_mask=src_mask_input)

			# output dict
			out_dict['emb_mt'] = emb_src
			out_dict['preds_mt'] = preds_tgt
			out_dict['logps_mt'] = logps_tgt

		if 'ST' in mode:
			"""
				acous -> DE: Transformer
				in : length reduced fbk features
				mid: w1 w2 w3 <EOS> <PAD> <PAD> #=6
				out: c1 c2 c3 <EOS> <PAD> <PAD> [dummy] #=7
			"""
			# get tgt emb
			tgt_mask, emb_tgt = self._get_tgt_emb(tgt, device)
			# get dynamic src emb
			if 'ASR' in mode:
				emb_src = out_dict['emb_asr']
				lengths = out_dict['lengths_asr']
			# else:
			# 	emb_src, _, _, lengths = self._encoder_acous(acous_feats, acous_lens,
			# 		device, use_gpu, tgt=src, is_training=True, teacher_forcing_ratio=1.0)
			else: # use free running if no 'ASR'
				emb_src, _, _, lengths = self._encoder_acous(acous_feats, acous_lens,
					device, use_gpu, is_training=False, teacher_forcing_ratio=0.0,
					lm_mode=lm_mode, lm_model=lm_model)
			# get mask
			max_len = emb_src.size(1)
			lengths = torch.LongTensor(lengths)
			src_mask_input = (torch.arange(max_len).expand(len(lengths), max_len)
				< lengths.unsqueeze(1)).unsqueeze(1).to(device=device)
			# encode
			enc_outputs = self._encoder_en(emb_src, src_mask=src_mask_input) # b x len x dim_model
			# decode
			dec_outputs_tgt, logits_tgt, logps_tgt, preds_tgt, _ = \
				self._decoder_de(emb_tgt, enc_outputs, tgt_mask=tgt_mask, src_mask=src_mask_input)

			# output dict
			out_dict['emb_st'] = emb_src
			out_dict['preds_st'] = preds_tgt
			out_dict['logps_st'] = logps_tgt

		return out_dict
Exemplo n.º 23
0
def main():

	# load config
	parser = argparse.ArgumentParser(description='Seq2seq Training')
	parser = load_arguments(parser)
	args = vars(parser.parse_args())
	config = validate_config(args)

	# set random seed
	if config['random_seed'] is not None:
		set_global_seeds(config['random_seed'])

	# record config
	if not os.path.isabs(config['save']):
		config_save_dir = os.path.join(os.getcwd(), config['save'])
	if not os.path.exists(config['save']):
		os.makedirs(config['save'])

	# loading old models
	if config['load']:
		print('loading {} ...'.format(config['load']))
		config_save_dir = os.path.join(config['save'], 'model-cont.cfg')
	else:
		config_save_dir = os.path.join(config['save'], 'model.cfg')
	save_config(config, config_save_dir)

	# construct trainer
	t = Trainer(expt_dir=config['save'],
					load_dir=config['load'],
					load_mode=config['load_mode'],
					batch_size=config['batch_size'],
					checkpoint_every=config['checkpoint_every'],
					print_every=config['print_every'],
					eval_mode=config['eval_mode'],
					eval_metric=config['eval_metric'],
					learning_rate=config['learning_rate'],
					learning_rate_init=config['learning_rate_init'],
					lr_warmup_steps=config['lr_warmup_steps'],
					eval_with_mask=config['eval_with_mask'],
					use_gpu=config['use_gpu'],
					gpu_id=config['gpu_id'],
					max_grad_norm=config['max_grad_norm'],
					max_count_no_improve=config['max_count_no_improve'],
					max_count_num_rollback=config['max_count_num_rollback'],
					keep_num=config['keep_num'],
					normalise_loss=config['normalise_loss'],
					minibatch_split=config['minibatch_split']
					)

	# load train set
	train_path_src = config['train_path_src']
	train_path_tgt = config['train_path_tgt']
	path_vocab_src = config['path_vocab_src']
	path_vocab_tgt = config['path_vocab_tgt']
	train_set = Dataset(train_path_src, train_path_tgt,
		path_vocab_src=path_vocab_src, path_vocab_tgt=path_vocab_tgt,
		seqrev=config['seqrev'],
		max_seq_len=config['max_seq_len'],
		batch_size=config['batch_size'],
		data_ratio=config['data_ratio'],
		use_gpu=config['use_gpu'],
		logger=t.logger,
		use_type=config['use_type'],
		use_type_src=config['use_type_src'])

	vocab_size_enc = len(train_set.vocab_src)
	vocab_size_dec = len(train_set.vocab_tgt)

	# load dev set
	if config['dev_path_src'] and config['dev_path_tgt']:
		dev_path_src = config['dev_path_src']
		dev_path_tgt = config['dev_path_tgt']
		dev_set = Dataset(dev_path_src, dev_path_tgt,
			path_vocab_src=path_vocab_src, path_vocab_tgt=path_vocab_tgt,
			seqrev=config['seqrev'],
			max_seq_len=config['max_seq_len'],
			batch_size=config['batch_size'],
			use_gpu=config['use_gpu'],
			logger=t.logger,
			use_type=config['use_type'],
			use_type_src=config['use_type_src'])
	else:
		dev_set = None

	# construct model
	seq2seq = Seq2seq(vocab_size_enc, vocab_size_dec,
					share_embedder=config['share_embedder'],
					enc_embedding_size=config['embedding_size_enc'],
					dec_embedding_size=config['embedding_size_dec'],
					load_embedding_src=config['load_embedding_src'],
					load_embedding_tgt=config['load_embedding_tgt'],
					num_heads=config['num_heads'],
					dim_model=config['dim_model'],
					dim_feedforward=config['dim_feedforward'],
					enc_layers=config['enc_layers'],
					dec_layers=config['dec_layers'],
					embedding_dropout=config['embedding_dropout'],
					dropout=config['dropout'],
					max_seq_len=config['max_seq_len'],
					act=config['act'],
					enc_word2id=train_set.src_word2id,
					dec_word2id=train_set.tgt_word2id,
					enc_id2word=train_set.src_id2word,
					dec_id2word=train_set.tgt_id2word,
					transformer_type=config['transformer_type'])

	# import pdb; pdb.set_trace()
	t.logger.info("total #parameters:{}".format(sum(p.numel() for p in
		seq2seq.parameters() if p.requires_grad)))

	device = check_device(config['use_gpu'])
	t.logger.info('device: {}'.format(device))
	seq2seq = seq2seq.to(device=device)

	# run training
	seq2seq = t.train(train_set, seq2seq, num_epochs=config['num_epochs'],
		dev_set=dev_set, grab_memory=config['grab_memory'])
    def forward_eval(self,
                     src=None,
                     acous_feats=None,
                     acous_lens=None,
                     mode='ST',
                     use_gpu=True,
                     lm_mode='null',
                     lm_model=None):
        """
			beam_width = 1
			note: the output sequence differs from training when using the transformer model
		"""

        # import pdb; pdb.set_trace()
        out_dict = {}

        # check gpu
        global device
        device = check_device(use_gpu)

        # check mode
        mode = mode.upper()
        if 'ST' in mode or 'ASR' in mode:
            assert type(acous_feats) != type(None)
            batch = acous_feats.size(0)
        if 'MT' in mode or 'AE' in mode:
            assert type(src) != type(None)
            batch = src.size(0)

        length_out_src = self.max_seq_len_src
        length_out_tgt = self.max_seq_len_tgt

        if 'ASR' in mode:
            """
				acous -> EN: RNN
				in : length reduced fbk features
				out: w1 w2 w3 <EOS> <PAD> <PAD> #=6
			"""
            # run asr
            emb_src, logps_src, preds_src, lengths = self._encoder_acous(
                acous_feats,
                acous_lens,
                device,
                use_gpu,
                is_training=False,
                teacher_forcing_ratio=0.0,
                lm_mode=lm_mode,
                lm_model=lm_model)

            # output dict
            out_dict['emb_asr'] = emb_src
            out_dict['preds_asr'] = preds_src
            out_dict['logps_asr'] = logps_src
            out_dict['lengths_asr'] = lengths

        if 'MT' in mode:
            """
				EN -> DE: Transformer
				in : <BOS> w1 w2 w3 <EOS> <PAD> <PAD> #=7
				mid: w1 w2 w3 <EOS> <PAD> <PAD> <PAD> #=7
				out: <BOS> c1 c2 c3 <EOS> <PAD> <PAD> #=7
			"""
            # get src emb
            src_trim = self._pre_proc_src(src, device)
            emb_dyn_ave = self.EMB_DYN_AVE
            emb_dyn_ave_expand = emb_dyn_ave.repeat(src_trim.size(0),
                                                    src_trim.size(1),
                                                    1).to(device=device)
            src_mask, emb_src, src_mask_input = self._get_src_emb(
                src_trim, emb_dyn_ave_expand, device)
            # encoder
            enc_outputs = self._encoder_en(
                emb_src, src_mask=src_mask_input)  # b x len x dim_model

            # prep
            eos_mask_tgt, logps_tgt, dec_outputs_tgt, preds_save_tgt, preds_tgt, preds_save_tgt = \
             self._prep_eval(batch, length_out_tgt, self.dec_vocab_size, device)

            for i in range(1, self.max_seq_len_tgt):

                tgt_mask, emb_tgt = self._get_tgt_emb(preds_tgt, device)
                dec_output_tgt, logit_tgt, logp_tgt, pred_tgt, _ = \
                 self._decoder_de(emb_tgt, enc_outputs, tgt_mask=tgt_mask, src_mask=src_mask_input)

                eos_mask_tgt, dec_outputs_tgt, logps_tgt, preds_save_tgt, preds_tgt, flag \
                 = self._step_eval(i, eos_mask_tgt, dec_output_tgt, logp_tgt, pred_tgt,
                  dec_outputs_tgt, logps_tgt, preds_save_tgt, preds_tgt, batch, length_out_tgt)
                if flag == 1: break

            # output dict
            out_dict['emb_mt'] = emb_src
            out_dict['preds_mt'] = preds_tgt
            out_dict['logps_mt'] = logps_tgt

        if 'ST' in mode:
            """
				acous -> DE: Transformer
				in : length reduced fbk features
				out: <BOS> c1 c2 c3 <EOS> <PAD> <PAD> #=7
			"""
            # get embedding
            if 'ASR' in mode:
                preds_src = out_dict['preds_asr']
                emb_src_dyn = out_dict['emb_asr']
                lengths = out_dict['lengths_asr']
            else:
                emb_src_dyn, _, preds_src, lengths = self._encoder_acous(
                    acous_feats,
                    acous_lens,
                    device,
                    use_gpu,
                    is_training=False,
                    teacher_forcing_ratio=0.0,
                    lm_mode=lm_mode,
                    lm_model=lm_model)
            _, emb_src, _ = self._get_src_emb(preds_src.squeeze(2),
                                              emb_src_dyn, device)

            # get mask
            max_len = emb_src.size(1)
            lengths = torch.LongTensor(lengths)
            src_mask_input = (torch.arange(max_len).expand(
                len(lengths), max_len) < lengths.unsqueeze(1)).unsqueeze(1).to(
                    device=device)
            # encode
            enc_outputs = self._encoder_en(
                emb_src, src_mask=src_mask_input)  # b x len x dim_model

            # prep
            eos_mask_tgt, logps_tgt, dec_outputs_tgt, preds_save_tgt, preds_tgt, preds_save_tgt = \
             self._prep_eval(batch, length_out_tgt, self.dec_vocab_size, device)

            for i in range(1, self.max_seq_len_tgt):

                tgt_mask, emb_tgt = self._get_tgt_emb(preds_tgt, device)
                dec_output_tgt, logit_tgt, logp_tgt, pred_tgt, _ = \
                 self._decoder_de(emb_tgt, enc_outputs, tgt_mask=tgt_mask, src_mask=src_mask_input)

                eos_mask_tgt, dec_outputs_tgt, logps_tgt, preds_save_tgt, preds_tgt, flag \
                 = self._step_eval(i, eos_mask_tgt, dec_output_tgt, logp_tgt, pred_tgt,
                  dec_outputs_tgt, logps_tgt, preds_save_tgt, preds_tgt, batch, length_out_tgt)
                if flag == 1: break

            # output dict
            out_dict['emb_st'] = emb_src
            out_dict['preds_st'] = preds_tgt
            out_dict['logps_st'] = logps_tgt

        return out_dict
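Note: forward_eval decodes greedily: the prefix of predictions is re-embedded and re-decoded at every step, the newest token is appended, and decoding stops once every sequence has produced EOS. A minimal standalone sketch of that pattern; dummy_decoder is a placeholder, not the repository API.

import torch

BOS, EOS = 1, 2                                  # illustrative ids
batch, max_len, vocab = 2, 8, 10

def dummy_decoder(prefix):
    # placeholder: random log-probabilities over the vocabulary for the next position
    return torch.log_softmax(torch.randn(prefix.size(0), vocab), dim=-1)

preds = torch.full((batch, 1), BOS, dtype=torch.long)
finished = torch.zeros(batch, dtype=torch.bool)
for _ in range(1, max_len):
    logp = dummy_decoder(preds)                  # b x vocab
    next_tok = logp.argmax(dim=-1, keepdim=True) # greedy choice
    next_tok[finished] = EOS                     # finished sequences keep emitting EOS
    preds = torch.cat([preds, next_tok], dim=1)  # grow the prefix
    finished |= next_tok.squeeze(1).eq(EOS)
    if finished.all():
        break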
Exemplo n.º 25
0
def main():

	# load config
	parser = argparse.ArgumentParser(description='Seq2seq Evaluation')
	parser = load_arguments(parser)
	args = vars(parser.parse_args())
	config = validate_config(args)

	# load src-tgt pair
	test_path_src = config['test_path_src']
	test_path_tgt = test_path_src
	test_path_out = config['test_path_out']
	load_dir = config['load']
	max_seq_len = config['max_seq_len']
	batch_size = config['batch_size']
	beam_width = config['beam_width']
	use_gpu = config['use_gpu']
	seqrev = config['seqrev']
	use_type = config['use_type']

	# set test mode: 1 = translate; 2 = output posterior (logp); 3 = save combined checkpoint
	MODE = config['eval_mode']
	if MODE != 3:
		if not os.path.exists(test_path_out):
			os.makedirs(test_path_out)
		config_save_dir = os.path.join(test_path_out, 'eval.cfg')
		save_config(config, config_save_dir)

	# check device:
	device = check_device(use_gpu)
	print('device: {}'.format(device))

	# load model
	latest_checkpoint_path = load_dir
	resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
	model = resume_checkpoint.model.to(device)
	vocab_src = resume_checkpoint.input_vocab
	vocab_tgt = resume_checkpoint.output_vocab
	print('Model dir: {}'.format(latest_checkpoint_path))
	print('Model loaded')

	# combine model
	if type(config['combine_path']) != type(None):
		model = combine_weights(config['combine_path'])

	# load test_set
	test_set = Dataset(test_path_src, test_path_tgt,
						vocab_src_list=vocab_src, vocab_tgt_list=vocab_tgt,
						seqrev=seqrev,
						max_seq_len=900,
						batch_size=batch_size,
						use_gpu=use_gpu,
						use_type=use_type)
	print('Test dir: {}'.format(test_path_src))
	print('Testset loaded')
	sys.stdout.flush()

	# run eval
	if MODE == 1:
		translate(test_set, model, test_path_out, use_gpu,
			max_seq_len, beam_width, device, seqrev=seqrev)

	elif MODE == 2: # output posterior
		translate_logp(test_set, model, test_path_out, use_gpu,
			max_seq_len, device, seqrev=seqrev)

	elif MODE == 3: # save combined model
		ckpt = Checkpoint(model=model,
				   optimizer=None, epoch=0, step=0,
				   input_vocab=test_set.vocab_src,
				   output_vocab=test_set.vocab_tgt)
		saved_path = ckpt.save_customise(
			os.path.join(config['combine_path'].strip('/')+'-combine','combine'))
		log_ckpts(config['combine_path'], config['combine_path'].strip('/')+'-combine')
		print('saving at {} ... '.format(saved_path))
Exemplo n.º 26
0
def main():

    # load config
    parser = argparse.ArgumentParser(description='Evaluation')
    parser = load_arguments(parser)
    args = vars(parser.parse_args())
    config = validate_config(args)

    # load src-tgt pair
    test_path_src = config['test_path_src']
    test_path_tgt = config['test_path_tgt']
    if type(test_path_tgt) == type(None):
        test_path_tgt = test_path_src

    test_path_out = config['test_path_out']
    test_acous_path = config['test_acous_path']
    acous_norm_path = config['acous_norm_path']

    load_dir = config['load']
    max_seq_len = config['max_seq_len']
    batch_size = config['batch_size']
    beam_width = config['beam_width']
    use_gpu = config['use_gpu']
    seqrev = config['seqrev']
    use_type = config['use_type']

    # set test mode
    MODE = config['eval_mode']
    if MODE != 2:
        if not os.path.exists(test_path_out):
            os.makedirs(test_path_out)
        config_save_dir = os.path.join(test_path_out, 'eval.cfg')
        save_config(config, config_save_dir)

    # check device:
    device = check_device(use_gpu)
    print('device: {}'.format(device))

    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
    model = resume_checkpoint.model.to(device)
    vocab_src = resume_checkpoint.input_vocab
    vocab_tgt = resume_checkpoint.output_vocab
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model loaded')

    # combine model
    if type(config['combine_path']) != type(None):
        model = combine_weights(config['combine_path'])
    # import pdb; pdb.set_trace()

    # load test_set
    test_set = Dataset(
        path_src=test_path_src,
        path_tgt=test_path_tgt,
        vocab_src_list=vocab_src,
        vocab_tgt_list=vocab_tgt,
        use_type=use_type,
        acous_path=test_acous_path,
        seqrev=seqrev,
        acous_norm=config['acous_norm'],
        acous_norm_path=config['acous_norm_path'],
        acous_max_len=6000,  # max 50k for mustc trainset
        max_seq_len_src=900,
        max_seq_len_tgt=900,  # max 2.5k for mustc trainset
        batch_size=batch_size,
        mode='ST',
        use_gpu=use_gpu)

    print('Test dir: {}'.format(test_path_src))
    print('Testset loaded')
    sys.stdout.flush()

    # '{AE|ASR|MT|ST}-{REF|HYP}'
    if len(config['gen_mode'].split('-')) == 2:
        gen_mode = config['gen_mode'].split('-')[0]
        history = config['gen_mode'].split('-')[1]
    elif len(config['gen_mode'].split('-')) == 1:
        gen_mode = config['gen_mode']
        history = 'HYP'

    # add external language model
    lm_mode = config['lm_mode']

    # run eval:
    if MODE == 1:
        translate(test_set,
                  model,
                  test_path_out,
                  use_gpu,
                  max_seq_len,
                  beam_width,
                  device,
                  seqrev=seqrev,
                  gen_mode=gen_mode,
                  lm_mode=lm_mode,
                  history=history)

    elif MODE == 2:  # save combined model
        ckpt = Checkpoint(model=model,
                          optimizer=None,
                          epoch=0,
                          step=0,
                          input_vocab=test_set.vocab_src,
                          output_vocab=test_set.vocab_tgt)
        saved_path = ckpt.save_customise(
            os.path.join(config['combine_path'].strip('/') + '-combine',
                         'combine'))
        log_ckpts(config['combine_path'],
                  config['combine_path'].strip('/') + '-combine')
        print('saving at {} ... '.format(saved_path))

    elif MODE == 3:
        plot_emb(test_set, model, test_path_out, use_gpu, max_seq_len, device)

    elif MODE == 4:
        gather_emb(test_set, model, test_path_out, use_gpu, max_seq_len,
                   device)

    elif MODE == 5:
        compute_kl(test_set, model, test_path_out, use_gpu, max_seq_len,
                   device)
Exemplo n.º 27
0
def main():

	# load config
	parser = argparse.ArgumentParser(description='Seq2seq Training')
	parser = load_arguments(parser)
	args = vars(parser.parse_args())
	config = validate_config(args)

	# set random seed
	if config['random_seed'] is not None:
		set_global_seeds(config['random_seed'])

	# record config
	if not os.path.isabs(config['save']):
		config_save_dir = os.path.join(os.getcwd(), config['save'])
	if not os.path.exists(config['save']):
		os.makedirs(config['save'])

	# resume or not
	if config['load']:
		resume = True
		print('resuming {} ...'.format(config['load']))
		config_save_dir = os.path.join(config['save'], 'model-cont.cfg')
	else:
		resume = False
		config_save_dir = os.path.join(config['save'], 'model.cfg')
	save_config(config, config_save_dir)

	# construct trainer
	t = Trainer(expt_dir=config['save'],
					load_dir=config['load'],
					batch_size=config['batch_size'],
					checkpoint_every=config['checkpoint_every'],
					print_every=config['print_every'],
					learning_rate=config['learning_rate'],
					eval_with_mask=config['eval_with_mask'],
					scheduled_sampling=config['scheduled_sampling'],
					teacher_forcing_ratio=config['teacher_forcing_ratio'],
					use_gpu=config['use_gpu'],
					max_grad_norm=config['max_grad_norm'],
					max_count_no_improve=config['max_count_no_improve'],
					max_count_num_rollback=config['max_count_num_rollback'],
					keep_num=config['keep_num'],
					normalise_loss=config['normalise_loss'],
					minibatch_split=config['minibatch_split'])

	# load train set
	train_path_src = config['train_path_src']
	train_path_tgt = config['train_path_tgt']
	path_vocab_src = config['path_vocab_src']
	path_vocab_tgt = config['path_vocab_tgt']
	train_set = Dataset(train_path_src, train_path_tgt,
		path_vocab_src=path_vocab_src, path_vocab_tgt=path_vocab_tgt,
		seqrev=config['seqrev'],
		max_seq_len=config['max_seq_len'],
		batch_size=config['batch_size'],
		use_gpu=config['use_gpu'],
		logger=t.logger,
		use_type=config['use_type'])

	vocab_size_enc = len(train_set.vocab_src)
	vocab_size_dec = len(train_set.vocab_tgt)

	# load dev set
	if config['dev_path_src'] and config['dev_path_tgt']:
		dev_path_src = config['dev_path_src']
		dev_path_tgt = config['dev_path_tgt']
		dev_set = Dataset(dev_path_src, dev_path_tgt,
			path_vocab_src=path_vocab_src, path_vocab_tgt=path_vocab_tgt,
			seqrev=config['seqrev'],
			max_seq_len=config['max_seq_len'],
			batch_size=config['batch_size'],
			use_gpu=config['use_gpu'],
			logger=t.logger,
			use_type=config['use_type'])
	else:
		dev_set = None

	# construct model
	seq2seq = Seq2seq(vocab_size_enc, vocab_size_dec,
					share_embedder=config['share_embedder'],
					embedding_size_enc=config['embedding_size_enc'],
					embedding_size_dec=config['embedding_size_dec'],
					embedding_dropout=config['embedding_dropout'],
					hidden_size_enc=config['hidden_size_enc'],
					num_bilstm_enc=config['num_bilstm_enc'],
					num_unilstm_enc=config['num_unilstm_enc'],
					hidden_size_dec=config['hidden_size_dec'],
					num_unilstm_dec=config['num_unilstm_dec'],
					hidden_size_att=config['hidden_size_att'],
					hidden_size_shared=config['hidden_size_shared'],
					dropout=config['dropout'],
					residual=config['residual'],
					batch_first=config['batch_first'],
					max_seq_len=config['max_seq_len'],
					load_embedding_src=config['load_embedding_src'],
					load_embedding_tgt=config['load_embedding_tgt'],
					src_word2id=train_set.src_word2id,
					tgt_word2id=train_set.tgt_word2id,
					src_id2word=train_set.src_id2word,
					tgt_id2word=train_set.tgt_id2word,
					att_mode=config['att_mode'])


	device = check_device(config['use_gpu'])
	t.logger.info('device:{}'.format(device))
	seq2seq = seq2seq.to(device=device)

	# run training
	seq2seq = t.train(train_set, seq2seq,
		num_epochs=config['num_epochs'], resume=resume, dev_set=dev_set)