Example #1
def combine_weights(path):

	"""
	 	reference - qd212
		average ckpt weights under the given path
	"""

	ckpt_path_list = [os.path.join(path, ep) for ep in os.listdir(path)]
	ckpt_state_dict_list = [Checkpoint.load(ckpt_path).model.state_dict()
		for ckpt_path in ckpt_path_list]

	model = Checkpoint.load(ckpt_path_list[0]).model
	mean_state_dict = model.state_dict()
	for key in mean_state_dict.keys():
		mean_state_dict[key] = 1. * (sum(d[key] for d in ckpt_state_dict_list)
			/ len(ckpt_state_dict_list))

	model.load_state_dict(mean_state_dict)

	return model
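A minimal usage sketch of the averaging helper above, assuming a directory that holds one saved Checkpoint per epoch (the paths below are hypothetical):

import torch

# hypothetical experiment layout: one Checkpoint folder per epoch under 'epochs/'
averaged_model = combine_weights('checkpoints/expt-001/epochs')
torch.save(averaged_model.state_dict(), 'checkpoints/expt-001/averaged.pt')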
Example #2
def main():

	# load config
	parser = argparse.ArgumentParser(description='Seq2seq Evaluation')
	parser = load_arguments(parser)
	args = vars(parser.parse_args())
	config = validate_config(args)

	# load src-tgt pair
	test_path_src = config['test_path_src']
	test_path_tgt = config['test_path_tgt'] # dummy
	if test_path_tgt is None:
		test_path_tgt = test_path_src
	test_path_out = config['test_path_out']
	load_dir = config['load']
	max_seq_len = config['max_seq_len']
	batch_size = config['batch_size']
	beam_width = config['beam_width']
	use_gpu = config['use_gpu']
	seqrev = config['seqrev']
	use_type = config['use_type']

	if not os.path.exists(test_path_out):
		os.makedirs(test_path_out)
	config_save_dir = os.path.join(test_path_out, 'eval.cfg')
	save_config(config, config_save_dir)

	# set eval mode: 1 = translate; 2 = translate_batch; 3 = attention plot; 4 = translate with teacher forcing
	MODE = config['eval_mode']
	if MODE == 3:
		max_seq_len = 32
		batch_size = 1
		beam_width = 1
		use_gpu = False

	# check device:
	device = check_device(use_gpu)
	print('device: {}'.format(device))

	# load model
	latest_checkpoint_path = load_dir
	resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
	model = resume_checkpoint.model.to(device)
	vocab_src = resume_checkpoint.input_vocab
	vocab_tgt = resume_checkpoint.output_vocab
	print('Model dir: {}'.format(latest_checkpoint_path))
	print('Model loaded')

	# load test_set
	test_set = Dataset(test_path_src, test_path_tgt,
						vocab_src_list=vocab_src, vocab_tgt_list=vocab_tgt,
						seqrev=seqrev,
						max_seq_len=max_seq_len,
						batch_size=batch_size,
						use_gpu=use_gpu,
						use_type=use_type)
	print('Test dir: {}'.format(test_path_src))
	print('Testset loaded')
	sys.stdout.flush()

	# run eval
	if MODE == 1: # FR
		translate(test_set, model, test_path_out, use_gpu,
			max_seq_len, beam_width, device, seqrev=seqrev)

	elif MODE == 2:
		translate_batch(test_set, model, test_path_out, use_gpu,
			max_seq_len, beam_width, device, seqrev=seqrev)

	elif MODE == 3:
		# plotting
		att_plot(test_set, model, test_path_out, use_gpu,
			max_seq_len, beam_width, device)

	elif MODE == 4: # TF
		translate_tf(test_set, model, test_path_out, use_gpu,
			max_seq_len, beam_width, device, seqrev=seqrev)
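check_device is not shown in this listing; a plausible minimal sketch, assuming it simply maps the use_gpu flag to a torch.device, is:

import torch

def check_device(use_gpu):
	# assumed helper: fall back to CPU when CUDA is not available
	if use_gpu and torch.cuda.is_available():
		return torch.device('cuda')
	return torch.device('cpu')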
Example #3
    def train(self,
              train_set,
              model,
              num_epochs=5,
              optimizer=None,
              dev_set=None):
        """
			Run training for a given model.
			Args:
				train_set: dataset
				dev_set: dataset, optional
				model: model to run training on; if `self.load_dir` is set, it is
				   overwritten by the model loaded from that checkpoint.
				num_epochs (int, optional): number of epochs to run (default 5)
				optimizer (seq2seq.optim.Optimizer, optional): optimizer for training
				   (default: Optimizer(torch.optim.Adam, max_grad_norm=5))

			Returns:
				model (seq2seq.models): trained model.
		"""

        torch.cuda.empty_cache()
        if self.load_dir is not None:
            latest_checkpoint_path = self.load_dir
            self.logger.info('resuming {} ...'.format(latest_checkpoint_path))
            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
            model = resume_checkpoint.model
            self.logger.info(model)
            self.optimizer = resume_checkpoint.optimizer

            # A workaround to set the optimizer parameters properly
            resume_optim = self.optimizer.optimizer
            defaults = resume_optim.param_groups[0]
            defaults.pop('params', None)
            defaults.pop('initial_lr', None)
            self.optimizer.optimizer = resume_optim.__class__(
                model.parameters(), **defaults)

            model.set_idmap(train_set.src_word2id, train_set.src_id2word)
            for name, param in model.named_parameters():
                self.logger.info('{}:{}'.format(name, param.size()))

            # start from prev
            start_epoch = resume_checkpoint.epoch
            step = resume_checkpoint.step

            # just for the sake of finetuning
            # start_epoch = 1
            # step = 0
        else:
            start_epoch = 1
            step = 0
            self.logger.info(model)

            for name, param in model.named_parameters():
                self.logger.info('{}:{}'.format(name, param.size()))

            if optimizer is None:
                optimizer = Optimizer(torch.optim.Adam(model.parameters(),
                                                       lr=self.learning_rate),
                                      max_grad_norm=self.max_grad_norm)
            self.optimizer = optimizer

        self.logger.info("Optimizer: %s, Scheduler: %s" %
                         (self.optimizer.optimizer, self.optimizer.scheduler))

        self._train_epoches(train_set,
                            model,
                            num_epochs,
                            start_epoch,
                            step,
                            dev_set=dev_set)

        return model
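The optimizer "workaround" above re-creates an optimizer of the same class with the same hyperparameters so that it tracks the parameters of the freshly loaded model rather than the stale tensors stored in the checkpoint. The same idea as a standalone sketch (the helper name is illustrative, not from the source):

import torch

def rebind_optimizer(optimizer, model):
    # rebuild the optimizer around the current model's parameters,
    # reusing the hyperparameters stored in its first param group
    defaults = dict(optimizer.param_groups[0])
    defaults.pop('params', None)
    defaults.pop('initial_lr', None)
    return optimizer.__class__(model.parameters(), **defaults)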
Example #4
    def _train_epoches(self,
                       train_set,
                       model,
                       n_epochs,
                       start_epoch,
                       start_step,
                       dev_set=None):

        log = self.logger

        las_print_loss_total = 0  # Reset every print_every
        step = start_step
        step_elapsed = 0
        prev_acc = 0.0
        count_no_improve = 0
        count_num_rollback = 0
        ckpt = None

        # ******************** [loop over epochs] ********************
        for epoch in range(start_epoch, n_epochs + 1):

            for param_group in self.optimizer.optimizer.param_groups:
                log.info('epoch:{} lr: {}'.format(epoch, param_group['lr']))
                lr_curr = param_group['lr']

            # ----------construct batches-----------
            log.info('--- construct train set ---')
            train_set.construct_batches(is_train=True)
            if dev_set is not None:
                log.info('--- construct dev set ---')
                dev_set.construct_batches(is_train=True)

            # --------print info for each epoch----------
            steps_per_epoch = len(train_set.iter_loader)
            total_steps = steps_per_epoch * n_epochs
            log.info("steps_per_epoch {}".format(steps_per_epoch))
            log.info("total_steps {}".format(total_steps))

            log.debug(" --------- Epoch: %d, Step: %d ---------" %
                      (epoch, step))
            mem_kb, mem_mb, mem_gb = get_memory_alloc()
            mem_mb = round(mem_mb, 2)
            log.info('Memory used: {0:.2f} MB'.format(mem_mb))
            self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
            sys.stdout.flush()

            # ******************** [loop over batches] ********************
            model.train(True)
            trainiter = iter(train_set.iter_loader)
            for idx in range(steps_per_epoch):

                # load batch items
                batch_items = next(trainiter)

                # update macro count
                step += 1
                step_elapsed += 1

                # Get loss
                losses = self._train_batch(model, batch_items, train_set, step,
                                           total_steps)

                las_loss = losses['las_loss']
                las_print_loss_total += las_loss

                if step % self.print_every == 0 and step_elapsed > self.print_every:
                    las_print_loss_avg = las_print_loss_total / self.print_every
                    las_print_loss_total = 0

                    log_msg = 'Progress: %d%%, Train las: %.4f'\
                     % (step / total_steps * 100, las_print_loss_avg)

                    log.info(log_msg)
                    self.writer.add_scalar('train_las_loss',
                                           las_print_loss_avg,
                                           global_step=step)

                # Checkpoint
                if step % self.checkpoint_every == 0 or step == total_steps:

                    # save criteria
                    if dev_set is not None:
                        dev_accs, dev_losses = self._evaluate_batches(
                            model, dev_set)
                        las_loss = dev_losses['las_loss']
                        las_acc = dev_accs['las_acc']
                        log_msg = 'Progress: %d%%, Dev las loss: %.4f, accuracy: %.4f'\
                         % (step / total_steps * 100, las_loss, las_acc)
                        log.info(log_msg)
                        self.writer.add_scalar('dev_las_loss',
                                               las_loss,
                                               global_step=step)
                        self.writer.add_scalar('dev_las_acc',
                                               las_acc,
                                               global_step=step)

                        accuracy = las_acc
                        # save
                        if prev_acc < accuracy:
                            # save best model
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_src)

                            saved_path = ckpt.save(self.expt_dir)
                            log.info('saving at {} ... '.format(saved_path))
                            # reset
                            prev_acc = accuracy
                            count_no_improve = 0
                            count_num_rollback = 0
                        else:
                            count_no_improve += 1

                        # roll back
                        if count_no_improve > self.max_count_no_improve:
                            # resuming
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if latest_checkpoint_path is not None:
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                log.info(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # A workaround to set the optimizer parameters properly
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)

                            # reset
                            count_no_improve = 0
                            count_num_rollback += 1

                        # update learning rate
                        if count_num_rollback > self.max_count_num_rollback:

                            # roll back
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if latest_checkpoint_path is not None:
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                log.info(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # A workaround to set the optimizer parameters properly
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)

                            # decrease lr
                            for param_group in self.optimizer.optimizer.param_groups:
                                param_group['lr'] *= 0.5
                                lr_curr = param_group['lr']
                                log.info('reducing lr ...')
                                log.info('step:{} - lr: {}'.format(
                                    step, param_group['lr']))

                            # check early stop
                            if lr_curr < 0.125 * self.learning_rate:
                                log.info('early stop ...')
                                break

                            # reset
                            count_no_improve = 0
                            count_num_rollback = 0

                        model.train(mode=True)
                        if ckpt is None:
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)
                        ckpt.rm_old(self.expt_dir, keep_num=self.keep_num)
                        log.info('n_no_improve {}, num_rollback {}'.format(
                            count_no_improve, count_num_rollback))

                    sys.stdout.flush()

            else:
                if dev_set is None:
                    # save every epoch if no dev_set
                    ckpt = Checkpoint(model=model,
                                      optimizer=self.optimizer,
                                      epoch=epoch,
                                      step=step,
                                      input_vocab=train_set.vocab_src,
                                      output_vocab=train_set.vocab_src)
                    saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                    log.info('saving at {} ... '.format(saved_path))
                    continue

                else:
                    continue

            # break nested for loop
            break
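The `else` clause attached to the batch loop above, together with the trailing `break`, is the standard Python pattern for escaping the outer epoch loop once the inner loop is broken (for example on early stopping). In isolation the pattern looks like this:

for epoch in range(1, 4):
    for step in range(10):
        if epoch == 2 and step == 5:
            break          # leave the inner loop early (e.g. early stop)
    else:
        continue           # inner loop finished normally: move on to the next epoch
    break                  # inner loop was broken: leave the outer loop as well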
Example #5
    def train(self,
              train_set,
              model,
              num_epochs=5,
              resume=False,
              optimizer=None,
              dev_set=None):
        """
			Run training for a given model.
			Args:
				train_set: dataset
				dev_set: dataset, optional
				model: model to run training on, if `resume=True`, it would be
				   overwritten by the model loaded from the latest checkpoint.
				num_epochs (int, optional): number of epochs to run (default 5)
				resume(bool, optional): resume training with the latest checkpoint, (default False)
				optimizer (seq2seq.optim.Optimizer, optional): optimizer for training
				   (default: Optimizer(torch.optim.Adam, max_grad_norm=5))

			Returns:
				model (seq2seq.models): trained model.
		"""

        self.logger.info(
            'MAX_COUNT_NO_IMPROVE: {}'.format(MAX_COUNT_NO_IMPROVE))
        self.logger.info(
            'MAX_COUNT_NUM_ROLLBACK: {}'.format(MAX_COUNT_NUM_ROLLBACK))

        torch.cuda.empty_cache()
        if resume:
            # latest_checkpoint_path = Checkpoint.get_latest_checkpoint(self.load_dir)
            latest_checkpoint_path = self.load_dir
            print('resuming {} ...'.format(latest_checkpoint_path))
            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
            model = resume_checkpoint.model
            print(model)
            self.optimizer = resume_checkpoint.optimizer

            # A workaround to set the optimizer parameters properly
            resume_optim = self.optimizer.optimizer
            defaults = resume_optim.param_groups[0]
            defaults.pop('params', None)
            defaults.pop('initial_lr', None)
            self.optimizer.optimizer = resume_optim.__class__(
                model.parameters(), **defaults)

            # start_epoch = resume_checkpoint.epoch
            # step = resume_checkpoint.step
            model.set_idmap(train_set.src_word2id, train_set.src_id2word)
            model.reset_batch_size(train_set.batch_size)

            for name, param in model.named_parameters():
                self.logger.info('{}:{}'.format(name, param.size()))

            # just for the sake of finetuning
            start_epoch = 1
            step = 0

        else:
            start_epoch = 1
            step = 0
            print(model)

            for name, param in model.named_parameters():
                self.logger.info('{}:{}'.format(name, param.size()))

            if optimizer is None:
                optimizer = Optimizer(
                    torch.optim.Adam(model.parameters(),
                                     lr=self.learning_rate),
                    max_grad_norm=self.max_grad_norm)  # 5 -> 1
            self.optimizer = optimizer

        self.logger.info("Optimizer: %s, Scheduler: %s" %
                         (self.optimizer.optimizer, self.optimizer.scheduler))

        self._train_epoches(train_set,
                            model,
                            num_epochs,
                            start_epoch,
                            step,
                            dev_set=dev_set)

        return model
Example #6
def att_plot(test_set, load_dir, plot_path, use_gpu, max_seq_len, beam_width):

	"""
		generate attention alignment plots
		Args:
			test_set: test dataset
			load_dir: model dir
			use_gpu: on gpu/cpu
			max_seq_len
		Returns:

	"""

	# check device
	print('cuda available: {}'.format(torch.cuda.is_available()))
	use_gpu = use_gpu and torch.cuda.is_available()
	device = torch.device('cuda') if use_gpu else torch.device('cpu')

	# load model
	# latest_checkpoint_path = Checkpoint.get_latest_checkpoint(load_dir)
	# latest_checkpoint_path = Checkpoint.get_thirdlast_checkpoint(load_dir)
	latest_checkpoint_path = load_dir
	resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

	model = resume_checkpoint.model.to(device)
	print('Model dir: {}'.format(latest_checkpoint_path))
	print('Model loaded')

	# reset batch_size:
	model.reset_max_seq_len(max_seq_len)
	model.reset_use_gpu(use_gpu)
	model.reset_batch_size(test_set.batch_size)

	# in plotting mode always turn off beam search
	model.set_beam_width(beam_width=0)
	model.check_var('ptr_net')
	print('max seq len {}'.format(model.max_seq_len))
	print('ptr_net {}'.format(model.ptr_net))

	# load test
	if test_set.attkey_path is None:
		test_batches, vocab_size = test_set.construct_batches(is_train=False)
	else:
		test_batches, vocab_size = test_set.construct_batches_with_ddfd_prob(is_train=False)

	# start eval
	model.eval()
	match = 0
	total = 0
	count = 0
	with torch.no_grad():
		for batch in test_batches:

			src_ids = batch['src_word_ids']
			src_lengths = batch['src_sentence_lengths']
			tgt_ids = batch['tgt_word_ids']
			tgt_lengths = batch['tgt_sentence_lengths']
			src_probs = None
			if 'src_ddfd_probs' in batch:
				src_probs = batch['src_ddfd_probs']
				src_probs = _convert_to_tensor(src_probs, use_gpu).unsqueeze(2)

			src_ids = _convert_to_tensor(src_ids, use_gpu)
			tgt_ids = _convert_to_tensor(tgt_ids, use_gpu)

			decoder_outputs, decoder_hidden, other = model(src_ids, tgt_ids,
															is_training=False,
															att_key_feats=src_probs,
															beam_width=0)
			# Evaluation
			# default batch_size = 1
			# attention: 31 * [1 x 1 x 32] ( tgt_len(query_len) * [ batch_size x 1 x src_len(key_len)] )
			attention = other['attention_score']
			seqlist = other['sequence'] # traverse over time not batch
			bsize = test_set.batch_size
			max_seq = test_set.max_seq_len
			vocab_size = len(test_set.tgt_word2id)
			for idx in range(len(decoder_outputs)): # loop over max_seq
				step = idx
				step_output = decoder_outputs[idx] # 64 x vocab_size
				# count correct
				target = tgt_ids[:, step+1]
				non_padding = target.ne(PAD)
				correct = seqlist[step].view(-1).eq(target).masked_select(non_padding).sum().item()
				match += correct
				total += non_padding.sum().item()

			# Print sentence by sentence
			srcwords = _convert_to_words_batchfirst(src_ids, test_set.src_id2word)
			refwords = _convert_to_words_batchfirst(tgt_ids[:,1:], test_set.tgt_id2word)
			seqwords = _convert_to_words(seqlist, test_set.tgt_id2word)
			# print(type(attention))
			# print(len(attention))
			# print(type(attention[0]))
			# print(attention[0].size())
			# input('...')
			n_q = len(attention)
			n_k = attention[0].size(2)
			b_size = attention[0].size(0)
			att_score = torch.empty(n_q, n_k, dtype=torch.float)
			# att_score = np.empty([n_q, n_k])

			for i in range(len(seqwords)): # loop over sentences
				outline_src = ' '.join(srcwords[i])
				outline_ref = ' '.join(refwords[i])
				outline_gen = ' '.join(seqwords[i])
				print('SRC: {}'.format(outline_src))
				print('REF: {}'.format(outline_ref))
				print('GEN: {}'.format(outline_gen))
				for j in range(len(attention)):
					# i: idx of batch
					# j: idx of query
					gen = seqwords[i][j]
					ref = refwords[i][j]
					att = attention[j][i]
					# record att scores
					att_score[j] = att

					# print('REF:GEN - {}:{}'.format(ref,gen))
					# print('{}th ATT size: {}'.format(j, attention[j][i].size()))
					# print(att)
					# print(torch.argmax(att))
					# print(sum(sum(att)))
					# input('Press enter to continue ...')

				# plotting
				# print(att_score)
				loc_eos_k = srcwords[i].index('</s>') + 1
				loc_eos_q = seqwords[i].index('</s>') + 1
				loc_eos_ref = refwords[i].index('</s>') + 1
				print('eos_k: {}, eos_q: {}'.format(loc_eos_k, loc_eos_q))
				att_score_trim = att_score[:loc_eos_q, :loc_eos_k] # each row (each query) sum up to 1
				print(att_score_trim)
				print('\n')
				# import pdb; pdb.set_trace()

				choice = input('Plot or not ? - y/n\n')
				if choice:
					if choice.lower()[0] == 'y':
						print('plotting ...')
						plot_dir = os.path.join(plot_path, '{}.png'.format(count))
						src = srcwords[i][:loc_eos_k]
						hyp = seqwords[i][:loc_eos_q]
						ref = refwords[i][:loc_eos_ref]
						# x-axis: src; y-axis: hyp
						# plot_alignment(att_score_trim.numpy(), plot_dir, src=src, hyp=hyp, ref=ref)
						plot_alignment(att_score_trim.numpy(), plot_dir, src=src, hyp=hyp, ref=None) # no ref
						count += 1
						input('Press enter to continue ...')


	if total == 0:
		accuracy = float('nan')
	else:
		accuracy = match / total
	print(accuracy)
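plot_alignment itself is not included in this listing; a minimal matplotlib sketch of what such a helper might do, with the signature assumed from the call above (source words on the x-axis, hypothesis words on the y-axis), is:

import matplotlib.pyplot as plt

def plot_alignment(att, out_path, src=None, hyp=None, ref=None):
	# assumed helper: draw the (query x key) attention matrix as a heatmap
	fig, ax = plt.subplots()
	ax.imshow(att, aspect='auto', origin='upper')
	if src is not None:
		ax.set_xticks(range(len(src)))
		ax.set_xticklabels(src, rotation=90)
	if hyp is not None:
		ax.set_yticks(range(len(hyp)))
		ax.set_yticklabels(hyp)
	fig.savefig(out_path, bbox_inches='tight')
	plt.close(fig)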
Example #7
def translate(test_set,
              model,
              test_path_out,
              use_gpu,
              max_seq_len,
              beam_width,
              device,
              seqrev=False,
              gen_mode='ASR',
              lm_mode='null',
              history='HYP'):
    """
		no reference tgt given - Run translation.
		Args:
			test_set: test dataset
				src, tgt using the same dir
			test_path_out: output dir
			load_dir: model dir
			use_gpu: on gpu/cpu
	"""

    modes = '_'.join([model.mode, gen_mode])
    # reset max_len
    if 'ASR' in modes or 'ST' in modes:
        model.las.decoder.max_seq_len = 150
    if 'MT' in modes:
        model.enc_src.expand_time(150)
    if 'ST' in modes or 'MT' in modes:
        model.dec_tgt.expand_time(max_seq_len)

    print('max seq len {}'.format(max_seq_len))
    sys.stdout.flush()

    # load lm
    mode = lm_mode.split('_')[0]
    if mode == 'null':
        lm_model = None
        LM_PATH = None
    elif mode == 's-4g':
        corpus = lm_mode.split('_')[1]
        LM_BASE = '/home/alta/BLTSpeaking/exp-ytl28/projects/lib/lms/pkl/'
        if corpus == 'ted':
            LM_PATH = os.path.join(LM_BASE, 'idlm-ted-train.pkl')
        elif corpus == 'mustc':
            LM_PATH = os.path.join(LM_BASE, 'idlm-mustc-train.pkl')
        with open(LM_PATH, 'rb') as fin:
            lm_model = pickle.load(fin)
    elif mode == 's-rnn':
        corpus = lm_mode.split('_')[1]
        LM_BASE = '/home/alta/BLTSpeaking/exp-ytl28/projects/rnnlm'
        if corpus == 'ted':
            LM_PATH = os.path.join(
                LM_BASE, 'models/ted-v001/checkpoints-combine/combine')
        elif corpus == 'mustc':
            LM_PATH = os.path.join(
                LM_BASE, 'models/mustc-v001/checkpoints-combine/combine')
        ckpt = Checkpoint.load(LM_PATH)
        lm_model = ckpt.model.to(device)
    print('LM {} - {} loaded'.format(lm_mode, LM_PATH))

    # load test
    test_set.construct_batches(is_train=False)
    evaliter = iter(test_set.iter_loader)

    print('num batches: {}'.format(len(evaliter)))
    with open(os.path.join(test_path_out, 'translate.txt'),
              'w',
              encoding="utf8") as f:
        model.eval()
        with torch.no_grad():
            for idx in range(len(evaliter)):

                print(idx + 1, len(evaliter))
                batch_items = next(evaliter)

                # load data
                src_ids = batch_items['srcid'][0]
                src_lengths = batch_items['srclen']
                tgt_ids = batch_items['tgtid'][0]
                tgt_lengths = batch_items['tgtlen']
                acous_feats = batch_items['acous_feat'][0]
                acous_lengths = batch_items['acouslen']

                src_len = max(src_lengths)
                tgt_len = max(tgt_lengths)
                acous_len = max(acous_lengths)
                src_ids = src_ids[:, :src_len].to(device=device)
                tgt_ids = tgt_ids.to(device=device)
                acous_feats = acous_feats.to(device=device)

                # chunk the batch: roughly one minibatch per 100 target tokens (ceiling division)
                n_minibatch = int(tgt_len / 100) + int(tgt_len % 100 > 0)
                minibatch_size = int(src_ids.size(0) / n_minibatch)
                n_minibatch = int(src_ids.size(0) / minibatch_size) + \
                 (src_ids.size(0) % minibatch_size > 0)

                for j in range(n_minibatch):

                    st = j * minibatch_size
                    ed = min((j + 1) * minibatch_size, src_ids.size(0))
                    src_ids_sub = src_ids[st:ed, :]
                    tgt_ids_sub = tgt_ids[st:ed, :]
                    acous_feats_sub = acous_feats[st:ed, :]
                    acous_lengths_sub = acous_lengths[st:ed]
                    print('minibatch: ', st, ed, src_ids.size(0))

                    time1 = time.time()
                    if history == 'HYP':
                        preds = model.forward_translate(
                            acous_feats=acous_feats_sub,
                            acous_lens=acous_lengths_sub,
                            src=src_ids_sub,
                            beam_width=beam_width,
                            use_gpu=use_gpu,
                            max_seq_len=max_seq_len,
                            mode=gen_mode,
                            lm_mode=lm_mode,
                            lm_model=lm_model)
                    elif history == 'REF':
                        preds = model.forward_translate_refen(
                            acous_feats=acous_feats_sub,
                            acous_lens=acous_lengths_sub,
                            src=src_ids_sub,
                            beam_width=beam_width,
                            use_gpu=use_gpu,
                            max_seq_len=max_seq_len,
                            mode=gen_mode,
                            lm_mode=lm_mode,
                            lm_model=lm_model)
                    time2 = time.time()
                    print('comp time: ', time2 - time1)

                    # ------ debug ------
                    # import pdb; pdb.set_trace()
                    # out_dict = model.forward_eval(acous_feats=acous_feats_sub,
                    # 	acous_lens=acous_lengths_sub, src=src_ids_sub,
                    # 	use_gpu=use_gpu, mode=gen_mode)
                    # -------------------

                    # write to file
                    if gen_mode == 'MT' or gen_mode == 'ST':
                        seqlist = preds[:, 1:]
                        seqwords = _convert_to_words_batchfirst(
                            seqlist, test_set.tgt_id2word)
                        use_type = 'char'
                    elif gen_mode == 'AE' or gen_mode == 'ASR':
                        seqlist = preds
                        seqwords = _convert_to_words_batchfirst(
                            seqlist, test_set.src_id2word)
                        use_type = 'word'

                    for i in range(len(seqwords)):
                        words = []
                        for word in seqwords[i]:
                            if word == '<pad>':
                                continue
                            elif word == '<spc>':
                                words.append(' ')
                            elif word == '</s>':
                                break
                            else:
                                words.append(word)
                        if len(words) == 0:
                            outline = ''
                        else:
                            if seqrev:
                                words = words[::-1]
                            if use_type == 'word':
                                outline = ' '.join(words)
                            elif use_type == 'char':
                                outline = ''.join(words)
                        f.write('{}\n'.format(outline))

                        # import pdb; pdb.set_trace()
                    sys.stdout.flush()
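The minibatch loop above is a ceiling-division split along the batch dimension; the same bookkeeping can be written compactly as follows (illustrative helper, not part of the source):

def split_minibatches(n_items, minibatch_size):
    # start/end indices of chunks covering n_items; the last chunk may be smaller
    n_minibatch = (n_items + minibatch_size - 1) // minibatch_size
    return [(j * minibatch_size, min((j + 1) * minibatch_size, n_items))
            for j in range(n_minibatch)]

# split_minibatches(10, 4) -> [(0, 4), (4, 8), (8, 10)]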
Example #8
    def _train_epochs(self,
                      train_set,
                      model,
                      n_epochs,
                      start_epoch,
                      start_step,
                      dev_set=None):

        log = self.logger

        print_loss_total = 0  # Reset every print_every
        step = start_step
        step_elapsed = 0
        prev_acc = 0.0
        prev_bleu = 0.0
        count_no_improve = 0
        count_num_rollback = 0
        ckpt = None

        # loop over epochs
        for epoch in range(start_epoch, n_epochs + 1):

            # update lr
            if self.lr_warmup_steps != 0:
                self.optimizer.optimizer = self.lr_scheduler(
                    self.optimizer.optimizer,
                    step,
                    init_lr=self.learning_rate_init,
                    peak_lr=self.learning_rate,
                    warmup_steps=self.lr_warmup_steps)

            # print lr
            for param_group in self.optimizer.optimizer.param_groups:
                log.info('epoch:{} lr: {}'.format(epoch, param_group['lr']))
                lr_curr = param_group['lr']

            # construct batches - allow re-shuffling of data
            log.info('--- construct train set ---')
            train_set.construct_batches(is_train=True)
            if dev_set is not None:
                log.info('--- construct dev set ---')
                dev_set.construct_batches(is_train=False)

            # print info
            steps_per_epoch = len(train_set.iter_loader)
            total_steps = steps_per_epoch * n_epochs
            log.info("steps_per_epoch {}".format(steps_per_epoch))
            log.info("total_steps {}".format(total_steps))

            log.info(" ---------- Epoch: %d, Step: %d ----------" %
                     (epoch, step))
            mem_kb, mem_mb, mem_gb = get_memory_alloc()
            mem_mb = round(mem_mb, 2)
            log.info('Memory used: {0:.2f} MB'.format(mem_mb))
            self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
            sys.stdout.flush()

            # loop over batches
            model.train(True)
            trainiter = iter(train_set.iter_loader)
            for idx in range(steps_per_epoch):

                # load batch items
                batch_items = next(trainiter)

                # update macro count
                step += 1
                step_elapsed += 1

                if self.lr_warmup_steps != 0:
                    self.optimizer.optimizer = self.lr_scheduler(
                        self.optimizer.optimizer,
                        step,
                        init_lr=self.learning_rate_init,
                        peak_lr=self.learning_rate,
                        warmup_steps=self.lr_warmup_steps)

                # Get loss
                loss = self._train_batch(model, batch_items, train_set, step,
                                         total_steps)
                print_loss_total += loss

                if step % self.print_every == 0 and step_elapsed > self.print_every:
                    print_loss_avg = print_loss_total / self.print_every
                    print_loss_total = 0

                    log_msg = 'Progress: %d%%, Train nll: %.4f' % (
                        step / total_steps * 100, print_loss_avg)

                    log.info(log_msg)
                    self.writer.add_scalar('train_loss',
                                           print_loss_avg,
                                           global_step=step)

                # Checkpoint
                if step % self.checkpoint_every == 0 or step == total_steps:

                    # save criteria
                    if dev_set is not None:
                        losses, metrics = self._evaluate_batches(
                            model, dev_set)

                        loss = losses['nll_loss']
                        accuracy = metrics['accuracy']
                        bleu = metrics['bleu']
                        log_msg = 'Progress: %d%%, Dev loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                            step / total_steps * 100, loss, accuracy, bleu)
                        log.info(log_msg)
                        self.writer.add_scalar('dev_loss',
                                               loss,
                                               global_step=step)
                        self.writer.add_scalar('dev_acc',
                                               accuracy,
                                               global_step=step)
                        self.writer.add_scalar('dev_bleu',
                                               bleu,
                                               global_step=step)

                        # save condition
                        cond_acc = (prev_acc <= accuracy)
                        cond_bleu = (((prev_acc <= accuracy) and (bleu < 0.1))
                                     or prev_bleu <= bleu)

                        # save
                        if self.eval_metric == 'tokacc':
                            save_cond = cond_acc
                        elif self.eval_metric == 'bleu':
                            save_cond = cond_bleu
                        if save_cond:
                            # save best model
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)

                            saved_path = ckpt.save(self.expt_dir)
                            log.info('saving at {} ... '.format(saved_path))
                            # reset
                            prev_acc = accuracy
                            prev_bleu = bleu
                            count_no_improve = 0
                            count_num_rollback = 0
                        else:
                            count_no_improve += 1

                        # roll back
                        if count_no_improve > self.max_count_no_improve:
                            # no roll back - break after self.max_count_no_improve epochs
                            if self.max_count_num_rollback == 0:
                                break
                            # resuming
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if latest_checkpoint_path is not None:
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                log.info(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # A workaround to set the optimizer parameters properly
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)

                            # reset
                            count_no_improve = 0
                            count_num_rollback += 1

                        # update learning rate
                        if count_num_rollback > self.max_count_num_rollback:

                            # roll back
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if latest_checkpoint_path is not None:
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                log.info(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # A workaround to set the optimizer parameters properly
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)

                            # decrease lr
                            for param_group in self.optimizer.optimizer.param_groups:
                                param_group['lr'] *= 0.5
                                lr_curr = param_group['lr']
                                log.info('reducing lr ...')
                                log.info('step:{} - lr: {}'.format(
                                    step, param_group['lr']))

                            # check early stop
                            if lr_curr <= 0.125 * self.learning_rate:
                                log.info('early stop ...')
                                break

                            # reset
                            count_no_improve = 0
                            count_num_rollback = 0

                        model.train(mode=True)
                        if ckpt is None:
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)
                        ckpt.rm_old(self.expt_dir, keep_num=self.keep_num)
                        log.info('n_no_improve {}, num_rollback {}'.format(
                            count_no_improve, count_num_rollback))

                    sys.stdout.flush()

            else:
                if dev_set is None:
                    # save every epoch if no dev_set
                    ckpt = Checkpoint(model=model,
                                      optimizer=self.optimizer,
                                      epoch=epoch,
                                      step=step,
                                      input_vocab=train_set.vocab_src,
                                      output_vocab=train_set.vocab_tgt)
                    saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                    log.info('saving at {} ... '.format(saved_path))
                    continue

                else:
                    continue

            # break nested for loop
            break
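The lr_scheduler used above is not part of this listing; a plausible sketch, assuming a linear warm-up from init_lr to peak_lr over warmup_steps and a constant rate afterwards (the real decay schedule may differ), is:

def lr_scheduler(optimizer, step, init_lr, peak_lr, warmup_steps):
    # assumed behaviour: linear warm-up, then hold the peak learning rate
    if step < warmup_steps:
        lr = init_lr + (peak_lr - init_lr) * step / warmup_steps
    else:
        lr = peak_lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer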
Example #9
    def train(self,
              train_set,
              model,
              num_epochs=5,
              optimizer=None,
              dev_set=None,
              grab_memory=True):
        """
			Run training for a given model.
			Args:
				train_set: dataset
				dev_set: dataset, optional
				model: model to run training on; if the load mode is 'resume' or
				   'restart', it is overwritten by the model loaded from `self.load_dir`.
				num_epochs (int, optional): number of epochs to run
				optimizer (seq2seq.optim.Optimizer, optional): optimizer for training

			Returns:
				model (seq2seq.models): trained model.
		"""

        if 'resume' in self.load_mode or 'restart' in self.load_mode:

            assert self.load_dir is not None

            # resume training
            latest_checkpoint_path = self.load_dir
            self.logger.info('{} {} ...'.format(self.load_mode,
                                                latest_checkpoint_path))
            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
            model = resume_checkpoint.model
            self.logger.info(model)
            self.optimizer = resume_checkpoint.optimizer
            if self.optimizer is None:
                self.optimizer = Optimizer(torch.optim.Adam(
                    model.parameters(), lr=self.learning_rate_init),
                                           max_grad_norm=self.max_grad_norm)

            # A workaround to set the optimizer parameters properly
            resume_optim = self.optimizer.optimizer
            defaults = resume_optim.param_groups[0]
            defaults.pop('params', None)
            defaults.pop('initial_lr', None)
            self.optimizer.optimizer = resume_optim.__class__(
                model.parameters(), **defaults)

            for name, param in model.named_parameters():
                self.logger.info('{}:{}'.format(name, param.size()))

            # set step/epoch
            if 'resume' in self.load_mode:
                # start from prev
                start_epoch = resume_checkpoint.epoch  # start from the saved epoch!
                step = resume_checkpoint.step  # start from the saved step!
            elif 'restart' in self.load_mode:
                # just for the sake of finetuning
                start_epoch = 1
                step = 0

        else:
            start_epoch = 1
            step = 0
            self.logger.info(model)

            for name, param in model.named_parameters():
                self.logger.info('{}:{}'.format(name, param.size()))

            if optimizer is None:
                optimizer = Optimizer(torch.optim.Adam(
                    model.parameters(), lr=self.learning_rate_init),
                                      max_grad_norm=self.max_grad_norm)
            self.optimizer = optimizer

        self.logger.info("Optimizer: %s, Scheduler: %s" %
                         (self.optimizer.optimizer, self.optimizer.scheduler))

        # reserve memory
        # import pdb; pdb.set_trace()
        if self.device == torch.device('cuda') and grab_memory:
            reserve_memory(device_id=self.gpu_id)

        self._train_epochs(train_set,
                           model,
                           num_epochs,
                           start_epoch,
                           step,
                           dev_set=dev_set)

        return model
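reserve_memory is likewise not shown here; one common way to "grab" GPU memory up front, and a plausible sketch of what it might do, is to allocate and immediately release a large dummy tensor so that the CUDA caching allocator keeps the block (the size and default name are assumptions):

import torch

def reserve_memory(device_id=0, n_bytes=4 * 1024 ** 3):
    # assumed helper: claim roughly n_bytes on the given GPU via a throwaway tensor;
    # the caching allocator retains the block after the tensor is freed
    tmp = torch.empty(n_bytes // 4, dtype=torch.float32,
                      device='cuda:{}'.format(device_id))
    del tmp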
Example #10
def debug_beam_search(test_set, load_dir, use_gpu, max_seq_len, beam_width):
    """
		with reference tgt given - debug beam search.
		Args:
			test_set: test dataset
			load_dir: model dir
			use_gpu: on gpu/cpu
		Returns:
			accuracy (excluding PAD tokens)
	"""

    # load model
    # latest_checkpoint_path = Checkpoint.get_latest_checkpoint(load_dir)
    # latest_checkpoint_path = Checkpoint.get_thirdlast_checkpoint(load_dir)
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

    model = resume_checkpoint.model
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model loaded')

    # reset batch_size:
    model.reset_max_seq_len(max_seq_len)
    model.reset_use_gpu(use_gpu)
    model.reset_batch_size(test_set.batch_size)
    print('max seq len {}'.format(model.max_seq_len))
    sys.stdout.flush()

    # load test
    if test_set.attkey_path is None:
        test_batches, vocab_size = test_set.construct_batches(is_train=False)
    else:
        test_batches, vocab_size = test_set.construct_batches_with_ddfd_prob(
            is_train=False)

    model.eval()
    match = 0
    total = 0
    with torch.no_grad():
        for batch in test_batches:

            src_ids = batch['src_word_ids']
            src_lengths = batch['src_sentence_lengths']
            tgt_ids = batch['tgt_word_ids']
            tgt_lengths = batch['tgt_sentence_lengths']
            src_probs = None
            if 'src_ddfd_probs' in batch:
                src_probs = batch['src_ddfd_probs']
                src_probs = _convert_to_tensor(src_probs, use_gpu).unsqueeze(2)

            src_ids = _convert_to_tensor(src_ids, use_gpu)
            tgt_ids = _convert_to_tensor(tgt_ids, use_gpu)

            decoder_outputs, decoder_hidden, other = model(
                src_ids,
                tgt_ids,
                is_training=False,
                att_key_feats=src_probs,
                beam_width=beam_width)

            # Evaluation
            seqlist = other['sequence']  # traverse over time not batch
            if beam_width > 1:
                # print('dict:sequence')
                # print(len(seqlist))
                # print(seqlist[0].size())

                full_seqlist = other['topk_sequence']
                # print('dict:topk_sequence')
                # print(len(full_seqlist))
                # print((full_seqlist[0]).size())
                # input('...')
                seqlists = []
                for i in range(beam_width):
                    seqlists.append([seq[:, i] for seq in full_seqlist])

                # print(decoder_outputs[0].size())
                # print('tgt id size {}'.format(tgt_ids.size()))
                # input('...')

                decoder_outputs = decoder_outputs[:-1]
                # print(len(decoder_outputs))

            for step, step_output in enumerate(
                    decoder_outputs):  # loop over time steps
                target = tgt_ids[:, step + 1]
                non_padding = target.ne(PAD)
                # print('step', step)
                # print('target', target)
                # print('hyp', seqlist[step])
                # if beam_width > 1:
                # 	print('full_seqlist', full_seqlist[step])
                # input('...')
                correct = seqlist[step].view(-1).eq(target).masked_select(
                    non_padding).sum().item()
                match += correct
                total += non_padding.sum().item()

            # write to file
            refwords = _convert_to_words_batchfirst(tgt_ids[:, 1:],
                                                    test_set.tgt_id2word)
            seqwords = _convert_to_words(seqlist, test_set.tgt_id2word)
            seqwords_list = []
            for i in range(beam_width):
                seqwords_list.append(
                    _convert_to_words(seqlists[i], test_set.tgt_id2word))

            for i in range(len(seqwords)):
                outline_ref = ' '.join(refwords[i])
                print('REF', outline_ref)
                outline_hyp = ' '.join(seqwords[i])
                # print(outline_hyp)
                outline_hyps = []
                for j in range(beam_width):
                    outline_hyps.append(' '.join(seqwords_list[j][i]))
                    print('{}th'.format(j), outline_hyps[-1])

                # skip padding sentences in batch (num_sent % batch_size != 0)
                # if src_lengths[i] == 0:
                # 	continue
                # words = []
                # for word in seqwords[i]:
                # 	if word == '<pad>':
                # 		continue
                # 	elif word == '</s>':
                # 		break
                # 	else:
                # 		words.append(word)
                # if len(words) == 0:
                # 	outline = ''
                # else:
                # 	outline = ' '.join(words)

                input('...')

            sys.stdout.flush()

        if total == 0:
            accuracy = float('nan')
        else:
            accuracy = match / total

    return accuracy
Example #11
def acous_att_plot(test_set, load_dir, plot_path, use_gpu, max_seq_len,
                   beam_width):
    """
		generate attention alignment plots
		Args:
			test_set: test dataset
			load_dir: model dir
			use_gpu: on gpu/cpu
			max_seq_len
		Returns:

	"""

    # import pdb; pdb.set_trace()
    use_gpu = False
    device = torch.device('cpu')  # plotting runs on CPU; device was otherwise undefined

    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

    model = resume_checkpoint.model
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model loaded')

    # reset batch_size:
    model.cpu()
    model.reset_max_seq_len(max_seq_len)
    print('max seq len {}'.format(model.max_seq_len))
    sys.stdout.flush()

    # load test
    test_set.construct_batches(is_train=False)
    evaliter = iter(test_set.iter_loader)
    total = 0
    print('total #batches: {}'.format(len(evaliter)))

    # start eval
    count = 0
    model.eval()
    with torch.no_grad():
        for idx in range(len(evaliter)):

            batch_items = next(evaliter)
            src_ids = batch_items[0][0].to(device=device)
            src_lengths = batch_items[1]
            acous_feats = batch_items[2][0].to(device=device)
            acous_lengths = batch_items[3]
            labs = batch_items[4][0].to(device=device)
            acous_times = batch_items[5]

            batch_size = src_ids.size(0)
            seq_len = int(max(src_lengths))
            acous_len = int(max(acous_lengths))

            decoder_outputs, decoder_hidden, ret_dict = model(
                src_ids,
                acous_feats=acous_feats,
                acous_times=acous_times,
                is_training=False,
                use_gpu=use_gpu,
                beam_width=beam_width)
            # attention: [32 x ?] (batch_size x src_len x acous_len(key_len))
            # default batch_size = 1
            i = 0
            bsize = test_set.batch_size
            max_seq = test_set.max_seq_len
            vocab_size = len(test_set.src_word2id)

            if not model.add_times:
                # Print sentence by sentence
                # import pdb; pdb.set_trace()
                attention = torch.cat(ret_dict['attention_score'], dim=1)[i]

                seqlist = ret_dict['sequence']
                seqwords = _convert_to_words(seqlist, test_set.src_id2word)
                outline_gen = ' '.join(seqwords[i])
                srcwords = _convert_to_words_batchfirst(
                    src_ids, test_set.src_id2word)
                outline_src = ' '.join(srcwords[i])
                print('SRC: {}'.format(outline_src))
                print('GEN: {}'.format(outline_gen))

                # plotting
                # import pdb; pdb.set_trace()
                loc_eos_k = srcwords[i].index('</s>') + 1
                print('eos_k: {}'.format(loc_eos_k))
                # loc_eos_m = seqwords[i].index('</s>') + 1
                loc_eos_m = len(seqwords[i])
                print('eos_m: {}'.format(loc_eos_m))

                att_score_trim = attention[:loc_eos_m, :]  # each row (each query) sums up to 1
                print('att size: {}'.format(att_score_trim.size()))
                # print('\n')

                choice = input('Plot or not ? - y/n\n')
                if choice:
                    if choice.lower()[0] == 'y':
                        print('plotting ...')
                        plot_dir = os.path.join(plot_path,
                                                '{}.png'.format(count))
                        src = srcwords[i][:loc_eos_m]
                        gen = seqwords[i][:loc_eos_m]

                        # x-axis: acous; y-axis: src
                        plot_attention(att_score_trim.numpy(),
                                       plot_dir,
                                       gen,
                                       words_right=src)  # no ref
                        count += 1
                        input('Press enter to continue ...')

            else:
                # import pdb; pdb.set_trace()
                attention = torch.cat(ret_dict['attention_score'], dim=0)

                srcwords = _convert_to_words_batchfirst(
                    src_ids, test_set.src_id2word)
                outline_src = ' '.join(srcwords[i])
                print('SRC: {}'.format(outline_src))
                loc_eos_k = srcwords[i].index('</s>') + 1
                print('eos_k: {}'.format(loc_eos_k))
                att_score_trim = attention[:loc_eos_k, :]  # each row (each query) sums up to 1
                print('att size: {}'.format(att_score_trim.size()))

                choice = input('Plot or not ? - y/n\n')
                if choice:
                    if choice.lower()[0] == 'y':
                        print('plotting ...')
                        plot_dir = os.path.join(plot_path,
                                                '{}.png'.format(count))
                        src = srcwords[i][:loc_eos_k]

                        # x-axis: acous; y-axis: src
                        plot_attention(att_score_trim.numpy(),
                                       plot_dir,
                                       src,
                                       words_right=src)  # no ref
                        count += 1
                        input('Press enter to continue ...')
Example #12
def translate_acous(test_set,
                    load_dir,
                    test_path_out,
                    use_gpu,
                    max_seq_len,
                    beam_width,
                    seqrev=False):
    """
		No reference tgt given - run translation.
		Args:
			test_set: test dataset
				src, tgt using the same dir
			test_path_out: output dir
			load_dir: model dir
			use_gpu: on gpu/cpu
	"""

    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

    model = resume_checkpoint.model.to(device)
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model loaded')

    # reset batch_size:
    model.reset_max_seq_len(max_seq_len)
    print('max seq len {}'.format(model.max_seq_len))
    sys.stdout.flush()

    # load test
    test_set.construct_batches(is_train=False)
    evaliter = iter(test_set.iter_loader)
    total = 0
    print('total #batches: {}'.format(len(evaliter)))

    f2 = open(os.path.join(test_path_out, 'translate.tsv'), 'w')
    with open(os.path.join(test_path_out, 'translate.txt'),
              'w',
              encoding="utf8") as f:
        model.eval()
        with torch.no_grad():
            for idx in range(len(evaliter)):
                batch_items = next(evaliter)
                src_ids = batch_items[0][0].to(device=device)
                src_lengths = batch_items[1]
                acous_feats = batch_items[2][0].to(device=device)
                acous_lengths = batch_items[3]
                labs = batch_items[4][0].to(device=device)
                acous_times = batch_items[5]

                batch_size = src_ids.size(0)
                seq_len = int(max(src_lengths))
                acous_len = int(max(acous_lengths))

                decoder_outputs, decoder_hidden, other = model(
                    src_ids,
                    acous_feats=acous_feats,
                    acous_times=acous_times,
                    is_training=False,
                    use_gpu=use_gpu,
                    beam_width=beam_width)

                # memory usage
                mem_kb, mem_mb, mem_gb = get_memory_alloc()
                mem_mb = round(mem_mb, 2)
                print('Memory used: {0:.2f} MB'.format(mem_mb))
                batch_size = src_ids.size(0)

                model.check_var('add_times', False)
                if not model.add_times:

                    # write to file
                    # import pdb; pdb.set_trace()
                    seqlist = other['sequence']
                    seqwords = _convert_to_words(seqlist, test_set.src_id2word)

                    # print las output
                    model.check_var('use_type', 'char')
                    total += len(seqwords)
                    if model.use_type == 'char':
                        for i in range(len(seqwords)):
                            # skip padding sentences in batch (num_sent % batch_size != 0)
                            if src_lengths[i] == 0:
                                continue
                            words = []
                            for word in seqwords[i]:
                                if word == '<pad>':
                                    continue
                                elif word == '</s>':
                                    break
                                elif word == '<spc>':
                                    words.append(' ')
                                else:
                                    words.append(word)
                            if len(words) == 0:
                                outline = ''
                            else:
                                if seqrev:
                                    words = words[::-1]
                                outline = ''.join(words)
                            f.write('{}\n'.format(outline))

                    elif model.use_type == 'word' or model.use_type == 'bpe':
                        for i in range(len(seqwords)):
                            # skip padding sentences in batch (num_sent % batch_size != 0)
                            if src_lengths[i] == 0:
                                continue
                            words = []
                            for word in seqwords[i]:
                                if word == '<pad>':
                                    continue
                                elif word == '</s>':
                                    break
                                else:
                                    words.append(word)
                            if len(words) == 0:
                                outline = ''
                            else:
                                if seqrev:
                                    words = words[::-1]
                                outline = ' '.join(words)
                            f.write('{}\n'.format(outline))

                srcwords = _convert_to_words_batchfirst(
                    src_ids, test_set.src_id2word)
                dd_ps = other['classify_prob']

                # print dd output
                for i in range(len(srcwords)):
                    for j in range(len(srcwords[i])):
                        word = srcwords[i][j]
                        prob = dd_ps[i][j].data[0]
                        if word == '<pad>':
                            break
                        elif word == '</s>':
                            break
                        else:
                            f2.write('{}\t{}\n'.format(word, prob))
                    f2.write('\n')

                sys.stdout.flush()

    print('total #sent: {}'.format(total))
    f2.close()
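
The char/word post-processing loops above (and in several later examples) differ only in how tokens are joined; a small helper capturing that shared logic might look like the sketch below (hypothetical, not part of the original code; the special tokens are the ones handled above).

def ids_to_outline(seq_words, use_type='word', seqrev=False):
    # Turn one decoded token sequence into an output line: stop at '</s>',
    # skip '<pad>', map '<spc>' to a space for character-level output,
    # optionally reverse, then join without spaces for 'char' models.
    words = []
    for word in seq_words:
        if word == '<pad>':
            continue
        elif word == '</s>':
            break
        elif word == '<spc>' and use_type == 'char':
            words.append(' ')
        else:
            words.append(word)
    if not words:
        return ''
    if seqrev:
        words = words[::-1]
    sep = '' if use_type == 'char' else ' '
    return sep.join(words)

# e.g. f.write('{}\n'.format(ids_to_outline(seqwords[i], model.use_type, seqrev)))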
Example #13
def translate(test_set,
              load_dir,
              test_path_out,
              use_gpu,
              max_seq_len,
              beam_width,
              seqrev=False):
    """
		No reference tgt given - run translation.
		Args:
			test_set: test dataset
				src, tgt using the same dir
			test_path_out: output dir
			load_dir: model dir
			use_gpu: on gpu/cpu
	"""

    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

    model = resume_checkpoint.model.to(device)
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model loaded')

    # reset batch_size:
    model.reset_max_seq_len(max_seq_len)
    print('max seq len {}'.format(model.max_seq_len))
    sys.stdout.flush()

    # load test
    test_set.construct_batches(is_train=False)
    evaliter = iter(test_set.iter_loader)
    total = 0

    with open(os.path.join(test_path_out, 'translate.tsv'),
              'w',
              encoding="utf8") as f:
        model.eval()
        with torch.no_grad():
            for idx in range(len(evaliter)):
                batch_items = next(evaliter)
                src_ids = batch_items[0][0].to(device=device)
                src_lengths = batch_items[1]
                labs = batch_items[4][0].to(device=device)

                batch_size = src_ids.size(0)
                seq_len = int(max(src_lengths))

                decoder_outputs, decoder_hidden, other = model(
                    src_ids,
                    is_training=False,
                    use_gpu=use_gpu,
                    beam_width=beam_width)

                # memory usage
                mem_kb, mem_mb, mem_gb = get_memory_alloc()
                mem_mb = round(mem_mb, 2)
                print('Memory used: {0:.2f} MB'.format(mem_mb))
                batch_size = src_ids.size(0)

                # write to file
                # import pdb; pdb.set_trace()
                srcwords = _convert_to_words_batchfirst(
                    src_ids, test_set.src_id2word)
                dd_ps = other['classify_prob']

                # print dd output
                total += len(srcwords)
                for i in range(len(srcwords)):
                    for j in range(len(srcwords[i])):
                        word = srcwords[i][j]
                        prob = dd_ps[i][j].data[0]
                        if word == '<pad>':
                            break
                        elif word == '</s>':
                            break
                        else:
                            f.write('{}\t{}\n'.format(word, prob))
                    f.write('\n')
                sys.stdout.flush()

    print('total #sent: {}'.format(total))
    def _train_epoches(self,
                       train_sets,
                       model,
                       n_epochs,
                       start_epoch,
                       start_step,
                       dev_sets=None):

        # load datasets
        train_set_asr = train_sets['asr']
        dev_set_asr = dev_sets['asr']
        train_set_mt = train_sets['mt']
        dev_set_mt = dev_sets['mt']

        log = self.logger

        print_loss_ae_total = 0  # Reset every print_every
        print_loss_asr_total = 0
        print_loss_mt_total = 0
        print_loss_kl_total = 0
        print_loss_l2_total = 0

        step = start_step
        step_elapsed = 0
        prev_acc = 0.0
        prev_bleu = 0.0
        count_no_improve = 0
        count_num_rollback = 0
        ckpt = None

        # loop over epochs
        for epoch in range(start_epoch, n_epochs + 1):

            # update lr
            if self.lr_warmup_steps != 0:
                self.optimizer.optimizer = self.lr_scheduler(
                    self.optimizer.optimizer,
                    step,
                    init_lr=self.learning_rate_init,
                    peak_lr=self.learning_rate,
                    warmup_steps=self.lr_warmup_steps)
            # print lr
            for param_group in self.optimizer.optimizer.param_groups:
                log.info('epoch:{} lr: {}'.format(epoch, param_group['lr']))
                lr_curr = param_group['lr']

            # construct batches - allow re-shuffling of data
            log.info('--- construct train set ---')
            train_set_asr.construct_batches(is_train=True)
            train_set_mt.construct_batches(is_train=True)
            if dev_set_asr is not None:
                log.info('--- construct dev set ---')
                dev_set_asr.construct_batches(is_train=False)
                dev_set_mt.construct_batches(is_train=False)

            # print info
            steps_per_epoch_asr = len(train_set_asr.iter_loader)
            steps_per_epoch_mt = len(train_set_mt.iter_loader)
            steps_per_epoch = min(steps_per_epoch_asr, steps_per_epoch_mt)
            total_steps = steps_per_epoch * n_epochs
            log.info("steps_per_epoch {}".format(steps_per_epoch))
            log.info("total_steps {}".format(total_steps))

            log.info(" ---------- Epoch: %d, Step: %d ----------" %
                     (epoch, step))
            mem_kb, mem_mb, mem_gb = get_memory_alloc()
            mem_mb = round(mem_mb, 2)
            log.info('Memory used: {0:.2f} MB'.format(mem_mb))
            self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
            sys.stdout.flush()

            # loop over batches
            model.train(True)
            trainiter_asr = iter(train_set_asr.iter_loader)
            trainiter_mt = iter(train_set_mt.iter_loader)
            for idx in range(steps_per_epoch):

                # load batch items
                batch_items_asr = next(trainiter_asr)
                batch_items_mt = next(trainiter_mt)

                # update macro count
                step += 1
                step_elapsed += 1

                if self.lr_warmup_steps != 0:
                    self.optimizer.optimizer = self.lr_scheduler(
                        self.optimizer.optimizer,
                        step,
                        init_lr=self.learning_rate_init,
                        peak_lr=self.learning_rate,
                        warmup_steps=self.lr_warmup_steps)

                # Get loss
                losses = self._train_batch(model, batch_items_asr,
                                           batch_items_mt, step, total_steps)
                loss_ae = losses['nll_loss_ae']
                loss_asr = losses['nll_loss_asr']
                loss_mt = losses['nll_loss_mt']
                loss_kl = losses['kl_loss']
                loss_l2 = losses['l2_loss']

                print_loss_ae_total += loss_ae
                print_loss_asr_total += loss_asr
                print_loss_mt_total += loss_mt
                print_loss_kl_total += loss_kl
                print_loss_l2_total += loss_l2

                if step % self.print_every == 0 and step_elapsed > self.print_every:
                    print_loss_ae_avg = print_loss_ae_total / self.print_every
                    print_loss_ae_total = 0
                    print_loss_asr_avg = print_loss_asr_total / self.print_every
                    print_loss_asr_total = 0
                    print_loss_mt_avg = print_loss_mt_total / self.print_every
                    print_loss_mt_total = 0
                    print_loss_kl_avg = print_loss_kl_total / self.print_every
                    print_loss_kl_total = 0
                    print_loss_l2_avg = print_loss_l2_total / self.print_every
                    print_loss_l2_total = 0

                    log_msg = 'Progress: %d%%, Train nll_ae: %.4f, nll_asr: %.4f, ' % (
                        step / total_steps * 100, print_loss_ae_avg,
                        print_loss_asr_avg)
                    log_msg += 'Train nll_mt: %.4f, l2: %.4f, kl_en: %.4f' % (
                        print_loss_mt_avg, print_loss_l2_avg,
                        print_loss_kl_avg)
                    log.info(log_msg)

                    self.writer.add_scalar('train_loss_ae',
                                           print_loss_ae_avg,
                                           global_step=step)
                    self.writer.add_scalar('train_loss_asr',
                                           print_loss_asr_avg,
                                           global_step=step)
                    self.writer.add_scalar('train_loss_mt',
                                           print_loss_mt_avg,
                                           global_step=step)
                    self.writer.add_scalar('train_loss_kl',
                                           print_loss_kl_avg,
                                           global_step=step)
                    self.writer.add_scalar('train_loss_l2',
                                           print_loss_l2_avg,
                                           global_step=step)

                # Checkpoint
                if step % self.checkpoint_every == 0 or step == total_steps:

                    # save criteria
                    if dev_set_asr is not None:
                        losses, metrics = self._evaluate_batches(
                            model, dev_set_asr, dev_set_mt)

                        loss_kl = losses['kl_loss']
                        loss_l2 = losses['l2_loss']
                        loss_ae = losses['nll_loss_ae']
                        accuracy_ae = metrics['accuracy_ae']
                        bleu_ae = metrics['bleu_ae']
                        loss_asr = losses['nll_loss_asr']
                        accuracy_asr = metrics['accuracy_asr']
                        bleu_asr = metrics['bleu_asr']
                        loss_mt = losses['nll_loss_mt']
                        accuracy_mt = metrics['accuracy_mt']
                        bleu_mt = metrics['bleu_mt']

                        log_msg = 'Progress: %d%%, Dev AE loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                            step / total_steps * 100, loss_ae, accuracy_ae,
                            bleu_ae)
                        log.info(log_msg)
                        log_msg = 'Progress: %d%%, Dev ASR loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                            step / total_steps * 100, loss_asr, accuracy_asr,
                            bleu_asr)
                        log.info(log_msg)
                        log_msg = 'Progress: %d%%, Dev MT loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                            step / total_steps * 100, loss_mt, accuracy_mt,
                            bleu_mt)
                        log.info(log_msg)
                        log_msg = 'Progress: %d%%, Dev En KL loss: %.4f, L2 loss: %.4f' % (
                            step / total_steps * 100, loss_kl, loss_l2)
                        log.info(log_msg)

                        self.writer.add_scalar('dev_loss_l2',
                                               loss_l2,
                                               global_step=step)
                        self.writer.add_scalar('dev_loss_kl',
                                               loss_kl,
                                               global_step=step)
                        self.writer.add_scalar('dev_loss_ae',
                                               loss_ae,
                                               global_step=step)
                        self.writer.add_scalar('dev_acc_ae',
                                               accuracy_ae,
                                               global_step=step)
                        self.writer.add_scalar('dev_bleu_ae',
                                               bleu_ae,
                                               global_step=step)
                        self.writer.add_scalar('dev_loss_asr',
                                               loss_asr,
                                               global_step=step)
                        self.writer.add_scalar('dev_acc_asr',
                                               accuracy_asr,
                                               global_step=step)
                        self.writer.add_scalar('dev_bleu_asr',
                                               bleu_asr,
                                               global_step=step)
                        self.writer.add_scalar('dev_loss_mt',
                                               loss_mt,
                                               global_step=step)
                        self.writer.add_scalar('dev_acc_mt',
                                               accuracy_mt,
                                               global_step=step)
                        self.writer.add_scalar('dev_bleu_mt',
                                               bleu_mt,
                                               global_step=step)

                        # save - use ASR res
                        accuracy_ave = (accuracy_asr / 4.0 + accuracy_mt) / 2.0
                        bleu_ave = (bleu_asr / 4.0 + bleu_mt) / 2.0
                        if ((prev_acc < accuracy_ave) and
                            (bleu_ave < 0.1)) or prev_bleu < bleu_ave:

                            # save best model - using bleu as metric
                            ckpt = Checkpoint(
                                model=model,
                                optimizer=self.optimizer,
                                epoch=epoch,
                                step=step,
                                input_vocab=train_set_asr.vocab_src,
                                output_vocab=train_set_asr.vocab_tgt)

                            saved_path = ckpt.save(self.expt_dir)
                            log.info('saving at {} ... '.format(saved_path))
                            # reset
                            prev_acc = accuracy_ave
                            prev_bleu = bleu_ave
                            count_no_improve = 0
                            count_num_rollback = 0
                        else:
                            count_no_improve += 1

                        # roll back
                        if count_no_improve > self.max_count_no_improve:
                            # break after self.max_count_no_improve epochs
                            if self.max_count_num_rollback == 0:
                                break
                            # resuming
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if type(latest_checkpoint_path) != type(None):
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                log.info(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # A workaround to set optimizer parameters properly
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)

                            # reset
                            count_no_improve = 0
                            count_num_rollback += 1

                        # update learning rate
                        if count_num_rollback > self.max_count_num_rollback:

                            # roll back
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if type(latest_checkpoint_path) != type(None):
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                log.info(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # A workaround to set optimizer parameters properly
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)

                            # decrease lr
                            for param_group in self.optimizer.optimizer.param_groups:
                                param_group['lr'] *= 0.5
                                lr_curr = param_group['lr']
                                log.info('reducing lr ...')
                                log.info('step:{} - lr: {}'.format(
                                    step, param_group['lr']))

                            # check early stop
                            if lr_curr <= 0.125 * self.learning_rate:
                                log.info('early stop ...')
                                break

                            # reset
                            count_no_improve = 0
                            count_num_rollback = 0

                        model.train(mode=True)
                        if ckpt is not None:
                            ckpt.rm_old(self.expt_dir, keep_num=self.keep_num)
                        log.info('n_no_improve {}, num_rollback {}'.format(
                            count_no_improve, count_num_rollback))

                    sys.stdout.flush()

            else:
                if dev_set_asr is None:
                    # save every epoch if no dev_set
                    ckpt = Checkpoint(model=model,
                                      optimizer=self.optimizer,
                                      epoch=epoch,
                                      step=step,
                                      input_vocab=train_set_asr.vocab_src,
                                      output_vocab=train_set_asr.vocab_tgt)
                    saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                    log.info('saving at {} ... '.format(saved_path))
                    continue

                else:
                    continue

            # break nested for loop
            break
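
The lr_scheduler above is called as lr_scheduler(optimizer, step, init_lr=..., peak_lr=..., warmup_steps=...) and is expected to return the optimizer. Its implementation is not shown in this listing; a minimal linear warm-up sketch consistent with that call signature could be:

def linear_warmup_scheduler(optimizer, step, init_lr, peak_lr, warmup_steps):
    # Interpolate the learning rate linearly from init_lr to peak_lr over
    # warmup_steps, then hold it at peak_lr. Mutates param_groups in place
    # and returns the optimizer, matching the call sites in _train_epoches.
    if warmup_steps <= 0:
        return optimizer
    frac = min(float(step) / float(warmup_steps), 1.0)
    lr = init_lr + (peak_lr - init_lr) * frac
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer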
Example #15
def translate(test_set, load_dir, test_path_out, use_gpu,
	max_seq_len, beam_width, device, seqrev=False):

	"""
		No reference tgt given - run translation.
		Args:
			test_set: test dataset
				src, tgt using the same dir
			test_path_out: output dir
			load_dir: model dir
			use_gpu: on gpu/cpu
	"""
	# import pdb; pdb.set_trace()

	# load model
	latest_checkpoint_path = load_dir
	resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

	model = resume_checkpoint.model.to(device)
	print('Model dir: {}'.format(latest_checkpoint_path))
	print('Model loaded')

	# reset batch_size:
	model.max_seq_len = max_seq_len
	model.enc.expand_time(max_seq_len)
	model.dec.expand_time(max_seq_len)
	print('max seq len {}'.format(model.max_seq_len))
	sys.stdout.flush()

	# load test
	test_set.construct_batches(is_train=False)
	evaliter = iter(test_set.iter_loader)
	print('num batches: {}'.format(len(evaliter)))

	with open(os.path.join(test_path_out, 'translate.txt'), 'w', encoding="utf8") as f:
		model.eval()
		with torch.no_grad():
			for idx in range(len(evaliter)):

				batch_items = next(evaliter)

				# load data
				src_ids = batch_items['srcid'][0]
				src_lengths = batch_items['srclen']
				tgt_ids = batch_items['tgtid'][0]
				tgt_lengths = batch_items['tgtlen']
				src_len = max(src_lengths)
				tgt_len = max(tgt_lengths)
				src_ids = src_ids[:,:src_len].to(device=device)
				tgt_ids = tgt_ids.to(device=device)

				# import pdb; pdb.set_trace()
				# split minibatch to avoid OOM
				# if idx < 12: continue

				n_minibatch = int(tgt_len / 100 + (tgt_len % 100 > 0))
				minibatch_size = int(src_ids.size(0) / n_minibatch)
				n_minibatch = int(src_ids.size(0) / minibatch_size +
					(src_ids.size(0) % minibatch_size > 0))

				for j in range(n_minibatch):

					print(idx+1, len(evaliter), '-', j+1, n_minibatch)

					st = j * minibatch_size
					ed = min((j+1) * minibatch_size, src_ids.size(0))
					src_ids_sub = src_ids[st:ed,:]

					time1 = time.time()
					if next(model.parameters()).is_cuda:
						preds = model.forward_translate(src=src_ids_sub,
								beam_width=beam_width, use_gpu=use_gpu)
					else:
						preds = model.forward_translate_fast(src=src_ids_sub,
								beam_width=beam_width, use_gpu=use_gpu)
					time2 = time.time()
					print(time2-time1)

					# write to file
					seqlist = preds[:,1:]
					seqwords = _convert_to_words_batchfirst(seqlist, test_set.tgt_id2word)

					# import pdb; pdb.set_trace()

					for i in range(len(seqwords)):
						if src_lengths[i] == 0:
							continue
						words = []
						for word in seqwords[i]:
							if word == '<pad>':
								continue
							elif word == '<spc>':
								words.append(' ')
							elif word == '</s>':
								break
							else:
								words.append(word)
						if len(words) == 0:
							outline = ''
						else:
							if seqrev:
								words = words[::-1]
							if test_set.use_type == 'word':
								outline = ' '.join(words)
							elif test_set.use_type == 'char':
								outline = ''.join(words)
						f.write('{}\n'.format(outline))

					sys.stdout.flush()
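
The mini-batch splitting above (n_minibatch / minibatch_size / st / ed) can be factored into a small generator; a sketch consistent with that arithmetic (hypothetical helper, not in the original code):

def chunk_indices(batch_size, minibatch_size):
    # Yield (start, end) pairs covering range(batch_size) in chunks, mirroring
    # the st/ed computation in the translate loop above.
    minibatch_size = max(1, int(minibatch_size))
    n_minibatch = batch_size // minibatch_size + int(batch_size % minibatch_size > 0)
    for j in range(n_minibatch):
        st = j * minibatch_size
        ed = min((j + 1) * minibatch_size, batch_size)
        yield st, ed

# e.g. for st, ed in chunk_indices(src_ids.size(0), minibatch_size):
#          src_ids_sub = src_ids[st:ed, :]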
Example #16
def translate(test_set, load_dir, test_path_out,
	use_gpu, max_seq_len, beam_width, mode='gec', seqrev=False):

	"""
		No reference tgt given - run translation.
		Args:
			test_set: test dataset
				src, tgt using the same dir
			test_path_out: output dir
			load_dir: model dir
			use_gpu: on gpu/cpu
	"""

	# load model
	# latest_checkpoint_path = Checkpoint.get_latest_checkpoint(load_dir)
	# latest_checkpoint_path = Checkpoint.get_thirdlast_checkpoint(load_dir)
	latest_checkpoint_path = load_dir
	resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

	model = resume_checkpoint.model
	print('Model dir: {}'.format(latest_checkpoint_path))
	print('Model loaded')

	# reset batch_size:
	model.reset_max_seq_len(max_seq_len)
	model.reset_use_gpu(use_gpu)
	model.reset_batch_size(test_set.batch_size)
	# fix compatibility
	model.check_classvar('shared_embed')
	model.check_classvar('additional_key_size')
	model.check_classvar('gec_num_bilstm_dec')
	model.check_classvar('num_unilstm_enc')
	model.check_classvar('residual')
	model.check_classvar('add_discriminator')
	if model.ptr_net == 'none': model.ptr_net = 'null'
	model.to(device)

	print('max seq len {}'.format(model.max_seq_len))
	sys.stdout.flush()

	# load test
	print('--- construct gec test set ---')
	if type(test_set.tsv_path) == type(None):
		test_batches, vocab_size = test_set.construct_batches(is_train=False)
	else:
		test_batches, vocab_size = test_set.construct_batches_with_ddfd_prob(is_train=False)

	with open(os.path.join(test_path_out, 'translate.txt'), 'w', encoding="utf8") as f:
		model.eval()
		match = 0
		total = 0
		with torch.no_grad():
			for batch in test_batches:

				src_ids = batch['src_word_ids']
				src_lengths = batch['src_sentence_lengths']
				src_probs = None
				if 'src_ddfd_probs' in batch and model.dd_additional_key_size > 0:
					src_probs =  batch['src_ddfd_probs']
					src_probs = _convert_to_tensor(src_probs, use_gpu).unsqueeze(2)

				src_ids = _convert_to_tensor(src_ids, use_gpu)
				tgt_ids = None

				if mode.lower() == 'gec' and beam_width <= 1:
					gec_dd_decoder_outputs, gec_dd_dec_hidden, gec_dd_ret_dict, \
						decoder_outputs, dec_hidden, ret_dict = \
							model.gec_eval(src_ids, tgt_ids, is_training=False,
									gec_dd_att_key_feats=src_probs, beam_width=beam_width)
				elif mode.lower() == 'gec' and beam_width > 1:
						decoder_outputs, dec_hidden, ret_dict = \
							model.gec_eval(src_ids, tgt_ids, is_training=False,
									gec_dd_att_key_feats=src_probs, beam_width=beam_width)
				elif mode.lower() == 'dd':
					decoder_outputs, dec_hidden, ret_dict = \
							model.dd_eval(src_ids, tgt_ids, is_training=False,
										dd_att_key_feats=src_probs, beam_width=beam_width)
				else:
					assert False, 'Unrecognised eval mode - choose from gec/dd'

				# memory usage
				mem_kb, mem_mb, mem_gb = get_memory_alloc()
				mem_mb = round(mem_mb, 2)
				print('Memory used: {0:.2f} MB'.format(mem_mb))

				# gec output write to file
				# import pdb; pdb.set_trace()
				seqlist = ret_dict['sequence']
				seqwords = _convert_to_words(seqlist, test_set.src_id2word)
				for i in range(len(seqwords)):
					# skip padding sentences in batch (num_sent % batch_size != 0)
					if src_lengths[i] == 0:
						continue
					words = []
					for word in seqwords[i]:
						if word == '<pad>':
							continue
						elif word == '</s>':
							break
						else:
							words.append(word)
					if len(words) == 0:
						outline = ''
					else:
						if seqrev:
							words = words[::-1]
						outline = ' '.join(words)
					f.write('{}\n'.format(outline))
					# if i == 0:
					# 	print(outline)
				sys.stdout.flush()
def evaluate(test_set, load_dir, test_path_out, use_gpu, max_seq_len, beam_width, seqrev=False):

	"""
		With reference tgt given - run translation and compute accuracy.
		Args:
			test_set: test dataset
			test_path_out: output dir
			load_dir: model dir
			use_gpu: on gpu/cpu
		Returns:
			accuracy (excluding PAD tokens)
	"""

	# load model
	# latest_checkpoint_path = Checkpoint.get_latest_checkpoint(load_dir)
	# latest_checkpoint_path = Checkpoint.get_thirdlast_checkpoint(load_dir)
	latest_checkpoint_path = load_dir
	resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

	model = resume_checkpoint.model.to(device)
	print('Model dir: {}'.format(latest_checkpoint_path))
	print('Model loaded')

	# reset batch_size:
	model.reset_max_seq_len(max_seq_len)
	model.reset_use_gpu(use_gpu)
	model.reset_batch_size(test_set.batch_size)
	model.set_beam_width(beam_width)
	model.check_var('ptr_net')
	print('max seq len {}'.format(model.max_seq_len))
	sys.stdout.flush()

	# load test
	if type(test_set.attkey_path) == type(None):
		test_batches, vocab_size = test_set.construct_batches(is_train=False)
	else:
		test_batches, vocab_size = test_set.construct_batches_with_ddfd_prob(is_train=False)


	# f = open(os.path.join(test_path_out, 'test.txt'), 'w')
	with open(os.path.join(test_path_out, 'translate.txt'), 'w', encoding="utf8") as f:
		model.eval()
		match = 0
		total = 0
		with torch.no_grad():
			for batch in test_batches:

				src_ids = batch['src_word_ids']
				src_lengths = batch['src_sentence_lengths']
				tgt_ids = batch['tgt_word_ids']
				tgt_lengths = batch['tgt_sentence_lengths']
				src_probs = None
				if 'src_ddfd_probs' in batch:
					src_probs =  batch['src_ddfd_probs']
					src_probs = _convert_to_tensor(src_probs, use_gpu).unsqueeze(2)

				src_ids = _convert_to_tensor(src_ids, use_gpu)
				tgt_ids = _convert_to_tensor(tgt_ids, use_gpu)

				decoder_outputs, decoder_hidden, other = model(src_ids, tgt_ids,
																is_training=False,
																att_key_feats=src_probs,
																beam_width=beam_width)

				# Evaluation
				seqlist = other['sequence'] # traverse over time not batch
				if beam_width > 1:
					full_seqlist = other['topk_sequence']
					decoder_outputs = decoder_outputs[:-1]
				for step, step_output in enumerate(decoder_outputs):
					target = tgt_ids[:, step+1]
					non_padding = target.ne(PAD)
					correct = (seqlist[step].view(-1).eq(target)
						.masked_select(non_padding).sum().item())
					match += correct
					total += non_padding.sum().item()

				# write to file
				seqwords = _convert_to_words(seqlist, test_set.tgt_id2word)
				for i in range(len(seqwords)):
					# skip padding sentences in batch (num_sent % batch_size != 0)
					if src_lengths[i] == 0:
						continue
					words = []
					for word in seqwords[i]:
						if word == '<pad>':
							continue
						elif word == '</s>':
							break
						else:
							words.append(word)
					if len(words) == 0:
						outline = ''
					else:
						if seqrev:
							words = words[::-1]
						outline = ' '.join(words)
					f.write('{}\n'.format(outline))
					if i == 0:
						print(outline)
				sys.stdout.flush()

		if total == 0:
			accuracy = float('nan')
		else:
			accuracy = match / total

	return accuracy
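
The match/total bookkeeping above can also be written as a standalone scorer; a sketch assuming a list of per-step prediction tensors and batch-first target ids, as in the loop above (hypothetical helper):

def token_accuracy(seqlist, tgt_ids, pad_id):
    # Token-level accuracy over non-PAD positions. seqlist is a list of
    # tensors (one per decoding step, shape [batch]); predictions at step t
    # are scored against tgt_ids[:, t + 1], as in evaluate() above.
    match, total = 0, 0
    for step, pred in enumerate(seqlist):
        target = tgt_ids[:, step + 1]
        non_padding = target.ne(pad_id)
        match += pred.view(-1).eq(target).masked_select(non_padding).sum().item()
        total += non_padding.sum().item()
    return float('nan') if total == 0 else match / total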
Example #18
def main():

	# load config
	parser = argparse.ArgumentParser(description='Seq2seq Evaluation')
	parser = load_arguments(parser)
	args = vars(parser.parse_args())
	config = validate_config(args)

	# load src-tgt pair
	test_path_src = config['test_path_src']
	test_path_tgt = test_path_src
	test_path_out = config['test_path_out']
	load_dir = config['load']
	max_seq_len = config['max_seq_len']
	batch_size = config['batch_size']
	beam_width = config['beam_width']
	use_gpu = config['use_gpu']
	seqrev = config['seqrev']
	use_type = config['use_type']

	# set test mode: 1 = translate; 2 = output posterior; 3 = save comb ckpt
	MODE = config['eval_mode']
	if MODE != 3:
		if not os.path.exists(test_path_out):
			os.makedirs(test_path_out)
		config_save_dir = os.path.join(test_path_out, 'eval.cfg')
		save_config(config, config_save_dir)

	# check device:
	device = check_device(use_gpu)
	print('device: {}'.format(device))

	# load model
	latest_checkpoint_path = load_dir
	resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
	model = resume_checkpoint.model.to(device)
	vocab_src = resume_checkpoint.input_vocab
	vocab_tgt = resume_checkpoint.output_vocab
	print('Model dir: {}'.format(latest_checkpoint_path))
	print('Model loaded')

	# combine model
	if type(config['combine_path']) != type(None):
		model = combine_weights(config['combine_path'])

	# load test_set
	test_set = Dataset(test_path_src, test_path_tgt,
						vocab_src_list=vocab_src, vocab_tgt_list=vocab_tgt,
						seqrev=seqrev,
						max_seq_len=900,
						batch_size=batch_size,
						use_gpu=use_gpu,
						use_type=use_type)
	print('Test dir: {}'.format(test_path_src))
	print('Testset loaded')
	sys.stdout.flush()

	# run eval
	if MODE == 1:
		translate(test_set, model, test_path_out, use_gpu,
			max_seq_len, beam_width, device, seqrev=seqrev)

	elif MODE == 2: # output posterior
		translate_logp(test_set, model, test_path_out, use_gpu,
			max_seq_len, device, seqrev=seqrev)

	elif MODE == 3: # save combined model
		ckpt = Checkpoint(model=model,
				   optimizer=None, epoch=0, step=0,
				   input_vocab=test_set.vocab_src,
				   output_vocab=test_set.vocab_tgt)
		saved_path = ckpt.save_customise(
			os.path.join(config['combine_path'].strip('/')+'-combine','combine'))
		log_ckpts(config['combine_path'], config['combine_path'].strip('/')+'-combine')
		print('saving at {} ... '.format(saved_path))
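
For reference, the config keys consumed by main() above form a flat mapping; a hypothetical example is sketched below (only the key names come from the code above, the values are placeholders).

# Hypothetical eval config; key names mirror those read in main() above.
example_config = {
    'test_path_src': 'data/test.src',        # placeholder path
    'test_path_out': 'results/eval-run/',    # placeholder path
    'load': 'checkpoints/run1/best/',        # placeholder path
    'combine_path': None,                    # or a checkpoint dir for MODE 3
    'max_seq_len': 200,
    'batch_size': 50,
    'beam_width': 1,
    'use_gpu': True,
    'seqrev': False,
    'use_type': 'word',
    'eval_mode': 1,                          # see the MODE dispatch in main() above
}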
def translate(test_set, load_dir, test_path_out, use_gpu, max_seq_len, beam_width, seqrev=False):

	"""
		No reference tgt given - run translation.
		Args:
			test_set: test dataset
				src, tgt using the same dir
			test_path_out: output dir
			load_dir: model dir
			use_gpu: on gpu/cpu
	"""

	# load model
	# latest_checkpoint_path = Checkpoint.get_latest_checkpoint(load_dir)
	# latest_checkpoint_path = Checkpoint.get_thirdlast_checkpoint(load_dir)
	latest_checkpoint_path = load_dir
	resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

	model = resume_checkpoint.model.to(device)
	print('Model dir: {}'.format(latest_checkpoint_path))
	print('Model loaded')

	# reset batch_size:
	model.reset_max_seq_len(max_seq_len)
	model.reset_use_gpu(use_gpu)
	model.reset_batch_size(test_set.batch_size)
	model.check_var('ptr_net')
	print('max seq len {}'.format(model.max_seq_len))
	sys.stdout.flush()

	# load test
	if type(test_set.attkey_path) == type(None):
		test_batches, vocab_size = test_set.construct_batches(is_train=False)
	else:
		test_batches, vocab_size = test_set.construct_batches_with_ddfd_prob(is_train=False)

	# f = open(os.path.join(test_path_out, 'translate.txt'), 'w') -> use proper encoding
	with open(os.path.join(test_path_out, 'translate.txt'), 'w', encoding="utf8") as f:
		model.eval()
		match = 0
		total = 0
		with torch.no_grad():
			for batch in test_batches:

				src_ids = batch['src_word_ids']
				src_lengths = batch['src_sentence_lengths']
				src_probs = None
				if 'src_ddfd_probs' in batch:
					src_probs =  batch['src_ddfd_probs']
					src_probs = _convert_to_tensor(src_probs, use_gpu).unsqueeze(2)

				src_ids = _convert_to_tensor(src_ids, use_gpu)
				decoder_outputs, decoder_hidden, other = model(src=src_ids,
																is_training=False,
																att_key_feats=src_probs,
																beam_width=beam_width)
				# memory usage
				mem_kb, mem_mb, mem_gb = get_memory_alloc()
				mem_mb = round(mem_mb, 2)
				print('Memory used: {0:.2f} MB'.format(mem_mb))

				# write to file
				seqlist = other['sequence']
				seqwords = _convert_to_words(seqlist, test_set.src_id2word)
				for i in range(len(seqwords)):
					# skip padding sentences in batch (num_sent % batch_size != 0)
					if src_lengths[i] == 0:
						continue
					words = []
					for word in seqwords[i]:
						if word == '<pad>':
							continue
						elif word == '</s>':
							break
						else:
							words.append(word)
					if len(words) == 0:
						outline = ''
					else:
						if seqrev:
							words = words[::-1]
						outline = ' '.join(words)
					f.write('{}\n'.format(outline))
					# if i == 0:
					# 	print(outline)
				sys.stdout.flush()
    def train(self,
              train_sets,
              model,
              num_epochs=5,
              optimizer=None,
              dev_sets=None,
              grab_memory=True):
        """
			Run training for a given model.
			Args:
				train_set: dataset
				dev_set: dataset, optional
				model: model to run training on
				optimizer (seq2seq.optim.Optimizer, optional): optimizer for training
			Returns:
				model (seq2seq.models): trained model.
		"""

        if 'resume' in self.load_mode or 'restart' in self.load_mode:

            # resume training
            latest_checkpoint_path = self.load_dir
            self.logger.info('resuming {} ...'.format(latest_checkpoint_path))
            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
            model = resume_checkpoint.model
            self.logger.info(model)
            self.optimizer = resume_checkpoint.optimizer

            # A workaround to set optimizer parameters properly
            resume_optim = self.optimizer.optimizer
            defaults = resume_optim.param_groups[0]
            defaults.pop('params', None)
            defaults.pop('initial_lr', None)
            self.optimizer.optimizer = resume_optim.__class__(
                model.parameters(), **defaults)

            # set freeze param
            for name, param in model.named_parameters():
                log = self.logger.info('{}:{}'.format(name, param.size()))

                # various mode
                if self.load_mode == 'ASR-resume' and self.load_freeze:
                    # freeze LAS
                    if 'las' in name:
                        log = self.logger.info('frozen')
                        param.requires_grad = False

                elif self.load_mode == 'AE-ASR-resume' and self.load_freeze:
                    # freeze LAS and EN embedder
                    if 'las' in name or 'enc_embedder' in name:
                        log = self.logger.info('frozen')
                        param.requires_grad = False

            # set step/epoch
            if 'resume' in self.load_mode:
                # start from prev
                start_epoch = resume_checkpoint.epoch  # start from the saved epoch!
                step = resume_checkpoint.step  # start from the saved step!
            elif 'restart' in self.load_mode:
                # just for the sake of finetuning
                start_epoch = 1
                step = 0

        else:

            # all are init from start
            if self.load_mode == 'LAS':
                """
					load LAS pyramidal LSTM from old dir
					freeze: only the pyramidal LSTMs in AcousEnc
				"""

                las_checkpoint_path = self.load_dir
                self.logger.info('loading Pyramidal lstm {} ...'.format(
                    las_checkpoint_path))
                las_checkpoint = Checkpoint.load(las_checkpoint_path)
                las_model = las_checkpoint.model
                # assign param
                for name, param in model.named_parameters():
                    loaded = False
                    log = self.logger.info('{}:{}'.format(name, param.size()))
                    for las_name, las_param in las_model.named_parameters():
                        # las_name = encoder.acous_enc_l1.weight_ih_l0
                        # name = las.enc.acous_enc_l1.weight_ih_l0
                        name_init = '.'.join(name.split('.')[0:2])
                        name_rest = '.'.join(name.split('.')[2:])
                        las_name_rest = '.'.join(las_name.split('.')[1:])
                        if name_init == 'las.encoder' and name_rest == las_name_rest:
                            assert param.data.size() == las_param.data.size(), \
                             'las_name {} {} : name {} {}'.format(las_name,
                              las_param.data.size(), name, param.data.size())
                            param.data = las_param.data
                            self.logger.info('loading {}'.format(las_name))
                            loaded = True
                            if self.load_freeze:
                                self.logger.info('frozen')
                                param.requires_grad = False
                            else:
                                self.logger.info('not frozen')
                    if not loaded:
                        self.logger.info('not preloaded - {}'.format(name))
                    # import pdb; pdb.set_trace()

            elif self.load_mode == 'ASR':
                """
					load ASR model: compatible only with asr-v001 model
					freeze: AcousEnc+EnDec (Entire LAS model)
				"""

                asr_checkpoint_path = self.load_dir
                self.logger.info(
                    'loading ASR {} ...'.format(asr_checkpoint_path))
                asr_checkpoint = Checkpoint.load(asr_checkpoint_path)
                asr_model = asr_checkpoint.model
                # assign param
                for name, param in model.named_parameters():
                    loaded = False
                    log = self.logger.info('{}:{}'.format(name, param.size()))
                    # name = las.encoder.acous_enc_l1.weight_ih_l0
                    name_init = '.'.join(name.split('.')[0:1])
                    name_rest = '.'.join(name.split('.')[1:])

                    for asr_name, asr_param in asr_model.named_parameters():
                        if name_init == 'las' and name == asr_name:
                            assert param.data.size() == asr_param.data.size()
                            param.data = asr_param.data
                            loaded = True
                            self.logger.info('loading {}'.format(asr_name))
                            if self.load_freeze:  # freezing embedder too
                                self.logger.info('frozen')
                                param.requires_grad = False
                            else:
                                self.logger.info('not frozen')
                    if not loaded:
                        # make exception for las dec embedder
                        if name == 'las.decoder.embedder.weight':
                            model.las.decoder.embedder.weight.data = \
                             asr_model.enc_embedder.weight.data
                            self.logger.info('assigning {} with {}'.format(
                                'las.decoder.embedder.weight',
                                'enc_embedder.weight'))
                            if self.load_freeze:
                                self.logger.info('frozen')
                                param.requires_grad = False
                            else:
                                self.logger.info('not frozen')
                        else:
                            self.logger.info('not preloaded - {}'.format(name))
                    # import pdb; pdb.set_trace()

            elif self.load_mode == 'ASR-PARTIAL':
                """
					load ASR model: compatible only with (ted-)asr-v001 model
					freeze: AcousEnc (LAS model excluding las.decoder.acous_out)
				"""

                asr_checkpoint_path = self.load_dir
                self.logger.info(
                    'loading ASR {} ...'.format(asr_checkpoint_path))
                asr_checkpoint = Checkpoint.load(asr_checkpoint_path)
                asr_model = asr_checkpoint.model
                # assign param
                for name, param in model.named_parameters():
                    loaded = False
                    log = self.logger.info('{}:{}'.format(name, param.size()))
                    # name = las.encoder.acous_enc_l1.weight_ih_l0
                    name_init = '.'.join(name.split('.')[0:1])
                    name_rest = '.'.join(name.split('.')[1:])

                    for asr_name, asr_param in asr_model.named_parameters():
                        if name_init == 'las' and name == asr_name:
                            assert param.data.size() == asr_param.data.size()
                            param.data = asr_param.data
                            loaded = True
                            self.logger.info('loading {}'.format(asr_name))
                            if self.load_freeze and ('las.decoder.acous_out'
                                                     not in name):
                                self.logger.info('frozen')
                                param.requires_grad = False
                            else:
                                self.logger.info('not frozen')
                    if not loaded:
                        # make exception for las dec embedder
                        if name == 'las.decoder.embedder.weight':
                            model.las.decoder.embedder.weight.data = \
                             asr_model.enc_embedder.weight.data
                            self.logger.info('assigning {} with {}'.format(
                                'las.decoder.embedder.weight',
                                'enc_embedder.weight'))
                            if self.load_freeze:
                                self.logger.info('frozen')
                                param.requires_grad = False
                            else:
                                self.logger.info('not frozen')
                        else:
                            self.logger.info('not preloaded - {}'.format(name))
                    # import pdb; pdb.set_trace()

            elif self.load_mode == 'AE-ASR':
                """
					load AE-ASR model
					freeze: entire AcousEnc+EnDec+EnEnc
				"""

                aeasr_checkpoint_path = self.load_dir
                self.logger.info('loading AE-ASR model {} ...'.format(
                    aeasr_checkpoint_path))
                aeasr_checkpoint = Checkpoint.load(aeasr_checkpoint_path)
                aeasr_model = aeasr_checkpoint.model
                # assign param
                for name, param in model.named_parameters():
                    loaded = False
                    log = self.logger.info('{}:{}'.format(name, param.size()))
                    name_init = '.'.join(name.split('.')[0:1])
                    for aeasr_name, aeasr_param in aeasr_model.named_parameters(
                    ):
                        if name == aeasr_name and (name_init == 'las'
                                                   or name_init
                                                   == 'enc_embedder'):
                            assert param.data.size() == aeasr_param.data.size()
                            param.data = aeasr_param.data
                            self.logger.info('loading {}'.format(aeasr_name))
                            loaded = True
                            if self.load_freeze:  # freezing embedder too
                                self.logger.info('frozen')
                                param.requires_grad = False
                            else:
                                self.logger.info('not frozen')
                    if not loaded:
                        self.logger.info('not preloaded - {}'.format(name))

            elif self.load_mode == 'AE-ASR-MT':
                """
					load general models
					freeze: entire AcousEnc+EnDec+EnEnc
				"""

                checkpoint_path = self.load_dir
                self.logger.info(
                    'loading model {} ...'.format(checkpoint_path))
                checkpoint = Checkpoint.load(checkpoint_path)
                load_model = checkpoint.model
                # assign param
                for name, param in model.named_parameters():
                    loaded = False
                    log = self.logger.info('{}:{}'.format(name, param.size()))
                    name_init = '.'.join(name.split('.')[:1])
                    for load_name, load_param in load_model.named_parameters():
                        if name == load_name:
                            assert param.data.size() == load_param.data.size()
                            param.data = load_param.data
                            self.logger.info('loading {}'.format(load_name))
                            loaded = True
                            if self.load_freeze and (name_init == 'las'
                                                     or name_init
                                                     == 'enc_embedder'):
                                self.logger.info('frozen')
                                param.requires_grad = False
                            else:
                                self.logger.info('not frozen')
                    if not loaded:
                        self.logger.info('not preloaded - {}'.format(name))

            elif self.load_mode == 'AE-ASR-MT-PARTIAL':
                """
					load general models
					freeze: AcousEnc(only LAS-LSTM, exclude LAS-DEC)
					won't train: EnDec+EnEnc
					train: DeDec
				"""

                checkpoint_path = self.load_dir
                self.logger.info(
                    'loading model {} ...'.format(checkpoint_path))
                checkpoint = Checkpoint.load(checkpoint_path)
                load_model = checkpoint.model
                # assign param
                for name, param in model.named_parameters():
                    loaded = False
                    log = self.logger.info('{}:{}'.format(name, param.size()))
                    name_init = '.'.join(name.split('.')[:2])
                    for load_name, load_param in load_model.named_parameters():
                        if name == load_name:
                            assert param.data.size() == load_param.data.size()
                            param.data = load_param.data
                            self.logger.info('loading {}'.format(load_name))
                            loaded = True
                            if self.load_freeze and (name_init
                                                     == 'las.encoder'):
                                self.logger.info('frozen')
                                param.requires_grad = False
                            else:
                                self.logger.info('not frozen')
                    if not loaded:
                        self.logger.info('not preloaded - {}'.format(name))

            elif type(self.load_dir) != type(None):
                """ load general models """

                checkpoint_path = self.load_dir
                self.logger.info(
                    'loading model {} ...'.format(checkpoint_path))
                checkpoint = Checkpoint.load(checkpoint_path)
                load_model = checkpoint.model
                # assign param
                for name, param in model.named_parameters():
                    loaded = False
                    log = self.logger.info('{}:{}'.format(name, param.size()))
                    for load_name, load_param in load_model.named_parameters():
                        if name == load_name:
                            assert param.data.size() == load_param.data.size()
                            param.data = load_param.data
                            self.logger.info('loading {}'.format(load_name))
                            loaded = True
                            if self.load_freeze:
                                self.logger.info('frozen')
                                param.requires_grad = False
                            else:
                                self.logger.info('not frozen')
                    if not loaded:
                        self.logger.info('not preloaded - {}'.format(name))

            else:
                # not loading pre-trained model
                for name, param in model.named_parameters():
                    log = self.logger.info('{}:{}'.format(name, param.size()))

            # init opt
            if optimizer is None:
                optimizer = Optimizer(torch.optim.Adam(
                    model.parameters(), lr=self.learning_rate_init),
                                      max_grad_norm=self.max_grad_norm)
            self.optimizer = optimizer
            start_epoch = 1
            step = 0

        # train epochs
        self.logger.info("Optimizer: %s, Scheduler: %s" %
                         (self.optimizer.optimizer, self.optimizer.scheduler))

        # reserve memory
        # import pdb; pdb.set_trace()
        if self.device == torch.device('cuda') and grab_memory:
            reserve_memory(device_id=self.gpu_id)

        # training
        self._train_epoches(train_sets,
                            model,
                            num_epochs,
                            start_epoch,
                            step,
                            dev_sets=dev_sets)

        return model
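
The load modes above all repeat one pattern: copy parameters whose names (and shapes) match between the loaded checkpoint and the current model, then freeze a subset selected by name prefix. A condensed sketch of that pattern (hypothetical helper; the per-mode logging and embedder special cases are omitted):

def load_matching_params(model, ckpt_model, freeze_prefixes=(), logger=None):
    # Copy parameters from ckpt_model into model wherever name and shape
    # match; freeze any copied parameter whose name starts with one of
    # freeze_prefixes. Returns the names that were not preloaded.
    ckpt_params = dict(ckpt_model.named_parameters())
    not_loaded = []
    for name, param in model.named_parameters():
        src = ckpt_params.get(name)
        if src is not None and src.data.size() == param.data.size():
            param.data = src.data
            if any(name.startswith(p) for p in freeze_prefixes):
                param.requires_grad = False
            if logger is not None:
                logger.info('loading {}'.format(name))
        else:
            not_loaded.append(name)
    return not_loaded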
Example #21
def main():

    # load config
    parser = argparse.ArgumentParser(description='Evaluation')
    parser = load_arguments(parser)
    args = vars(parser.parse_args())
    config = validate_config(args)

    # load src-tgt pair
    test_path_src = config['test_path_src']
    test_path_tgt = config['test_path_tgt']
    if type(test_path_tgt) == type(None):
        test_path_tgt = test_path_src

    test_path_out = config['test_path_out']
    test_acous_path = config['test_acous_path']
    acous_norm_path = config['acous_norm_path']

    load_dir = config['load']
    max_seq_len = config['max_seq_len']
    batch_size = config['batch_size']
    beam_width = config['beam_width']
    use_gpu = config['use_gpu']
    seqrev = config['seqrev']
    use_type = config['use_type']

    # set test mode
    MODE = config['eval_mode']
    if MODE != 2:
        if not os.path.exists(test_path_out):
            os.makedirs(test_path_out)
        config_save_dir = os.path.join(test_path_out, 'eval.cfg')
        save_config(config, config_save_dir)

    # check device:
    device = check_device(use_gpu)
    print('device: {}'.format(device))

    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
    model = resume_checkpoint.model.to(device)
    vocab_src = resume_checkpoint.input_vocab
    vocab_tgt = resume_checkpoint.output_vocab
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model loaded')

    # combine model
    if type(config['combine_path']) != type(None):
        model = combine_weights(config['combine_path'])
    # import pdb; pdb.set_trace()

    # load test_set
    test_set = Dataset(
        path_src=test_path_src,
        path_tgt=test_path_tgt,
        vocab_src_list=vocab_src,
        vocab_tgt_list=vocab_tgt,
        use_type=use_type,
        acous_path=test_acous_path,
        seqrev=seqrev,
        acous_norm=config['acous_norm'],
        acous_norm_path=config['acous_norm_path'],
        acous_max_len=6000,  # max 50k for mustc trainset
        max_seq_len_src=900,
        max_seq_len_tgt=900,  # max 2.5k for mustc trainset
        batch_size=batch_size,
        mode='ST',
        use_gpu=use_gpu)

    print('Test dir: {}'.format(test_path_src))
    print('Testset loaded')
    sys.stdout.flush()

    # '{AE|ASR|MT|ST}-{REF|HYP}'
    if len(config['gen_mode'].split('-')) == 2:
        gen_mode = config['gen_mode'].split('-')[0]
        history = config['gen_mode'].split('-')[1]
    elif len(config['gen_mode'].split('-')) == 1:
        gen_mode = config['gen_mode']
        history = 'HYP'

    # add external language model
    lm_mode = config['lm_mode']

    # run eval:
    if MODE == 1:
        translate(test_set,
                  model,
                  test_path_out,
                  use_gpu,
                  max_seq_len,
                  beam_width,
                  device,
                  seqrev=seqrev,
                  gen_mode=gen_mode,
                  lm_mode=lm_mode,
                  history=history)

    elif MODE == 2:  # save combined model
        ckpt = Checkpoint(model=model,
                          optimizer=None,
                          epoch=0,
                          step=0,
                          input_vocab=test_set.vocab_src,
                          output_vocab=test_set.vocab_tgt)
        saved_path = ckpt.save_customise(
            os.path.join(config['combine_path'].rstrip('/') + '-combine',
                         'combine'))
        log_ckpts(config['combine_path'],
                  config['combine_path'].rstrip('/') + '-combine')
        print('saving at {} ... '.format(saved_path))

    elif MODE == 3:
        plot_emb(test_set, model, test_path_out, use_gpu, max_seq_len, device)

    elif MODE == 4:
        gather_emb(test_set, model, test_path_out, use_gpu, max_seq_len,
                   device)

    elif MODE == 5:
        compute_kl(test_set, model, test_path_out, use_gpu, max_seq_len,
                   device)
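
The two-branch split of config['gen_mode'] above only covers strings with zero or one '-' separator; any other shape leaves gen_mode and history unset. A small sketch of a more defensive parse with the same 'HYP' default (the helper name parse_gen_mode is hypothetical, not part of the codebase):

def parse_gen_mode(gen_mode_str):
    """Split '<task>-<history>' into its parts, defaulting history to 'HYP'."""
    task, sep, history = gen_mode_str.partition('-')
    return task, (history if sep else 'HYP')


# e.g. parse_gen_mode('ST-REF') -> ('ST', 'REF'); parse_gen_mode('ASR') -> ('ASR', 'HYP')
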
Example #22
0
    def _train_epoches(self,
                       train_set,
                       model,
                       n_epochs,
                       start_epoch,
                       start_step,
                       dev_set=None):

        log = self.logger

        print_loss_total = 0  # Reset every print_every
        epoch_loss_total = 0  # Reset every epoch
        att_print_loss_total = 0  # Reset every print_every
        att_epoch_loss_total = 0  # Reset every epoch
        attcls_print_loss_total = 0  # Reset every print_every
        attcls_epoch_loss_total = 0  # Reset every epoch

        step = start_step
        step_elapsed = 0
        prev_acc = 0.0
        count_no_improve = 0
        count_num_rollback = 0
        ckpt = None

        # ******************** [loop over epochs] ********************
        for epoch in range(start_epoch, n_epochs + 1):

            for param_group in self.optimizer.optimizer.param_groups:
                print('epoch:{} lr: {}'.format(epoch, param_group['lr']))
                lr_curr = param_group['lr']

            # ----------construct batches-----------
            # allow re-shuffling of data
            if type(train_set.attkey_path) == type(None):
                print('--- construct train set ---')
                train_batches, vocab_size = train_set.construct_batches(
                    is_train=True)
                if dev_set is not None:
                    print('--- construct dev set ---')
                    dev_batches, vocab_size = dev_set.construct_batches(
                        is_train=False)
            else:
                print('--- construct train set ---')
                train_batches, vocab_size = train_set.construct_batches_with_ddfd_prob(
                    is_train=True)
                if dev_set is not None:
                    print('--- construct dev set ---')
                    assert type(dev_set.attkey_path) != type(
                        None), 'Dev set missing ddfd probabilities'
                    dev_batches, vocab_size = dev_set.construct_batches_with_ddfd_prob(
                        is_train=False)

            # --------print info for each epoch----------
            steps_per_epoch = len(train_batches)
            total_steps = steps_per_epoch * n_epochs
            log.info("steps_per_epoch {}".format(steps_per_epoch))
            log.info("total_steps {}".format(total_steps))

            log.debug(
                " ----------------- Epoch: %d, Step: %d -----------------" %
                (epoch, step))
            mem_kb, mem_mb, mem_gb = get_memory_alloc()
            mem_mb = round(mem_mb, 2)
            print('Memory used: {0:.2f} MB'.format(mem_mb))
            self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
            sys.stdout.flush()

            # ******************** [loop over batches] ********************
            model.train(True)
            for batch in train_batches:

                # update macro count
                step += 1
                step_elapsed += 1

                # load data
                src_ids = batch['src_word_ids']
                src_lengths = batch['src_sentence_lengths']
                tgt_ids = batch['tgt_word_ids']
                tgt_lengths = batch['tgt_sentence_lengths']

                src_probs = None
                src_labs = None
                if 'src_ddfd_probs' in batch and model.additional_key_size > 0:
                    src_probs = batch['src_ddfd_probs']
                    src_probs = _convert_to_tensor(src_probs,
                                                   self.use_gpu).unsqueeze(2)
                if 'src_ddfd_labs' in batch:
                    src_labs = batch['src_ddfd_labs']
                    src_labs = _convert_to_tensor(src_labs,
                                                  self.use_gpu).unsqueeze(2)

                # sanity check src-tgt pair
                if step == 1:
                    print('--- Check src tgt pair ---')
                    log_msgs = check_srctgt(src_ids, tgt_ids,
                                            train_set.src_id2word,
                                            train_set.tgt_id2word)
                    for log_msg in log_msgs:
                        sys.stdout.buffer.write(log_msg)

                # convert variable to tensor
                src_ids = _convert_to_tensor(src_ids, self.use_gpu)
                tgt_ids = _convert_to_tensor(tgt_ids, self.use_gpu)

                # Get loss
                loss, att_loss, attcls_loss = self._train_batch(
                    src_ids,
                    tgt_ids,
                    model,
                    step,
                    total_steps,
                    src_probs=src_probs,
                    src_labs=src_labs)

                print_loss_total += loss
                epoch_loss_total += loss
                att_print_loss_total += att_loss
                att_epoch_loss_total += att_loss
                attcls_print_loss_total += attcls_loss
                attcls_epoch_loss_total += attcls_loss

                if step % self.print_every == 0 and step_elapsed > self.print_every:
                    print_loss_avg = print_loss_total / self.print_every
                    att_print_loss_avg = att_print_loss_total / self.print_every
                    attcls_print_loss_avg = attcls_print_loss_total / self.print_every
                    print_loss_total = 0
                    att_print_loss_total = 0
                    attcls_print_loss_total = 0

                    log_msg = 'Progress: %d%%, Train nll: %.4f, att: %.4f, attcls: %.4f' % (
                        step / total_steps * 100, print_loss_avg,
                        att_print_loss_avg, attcls_print_loss_avg)
                    # print(log_msg)
                    log.info(log_msg)
                    self.writer.add_scalar('train_loss',
                                           print_loss_avg,
                                           global_step=step)
                    self.writer.add_scalar('att_train_loss',
                                           att_print_loss_avg,
                                           global_step=step)
                    self.writer.add_scalar('attcls_train_loss',
                                           attcls_print_loss_avg,
                                           global_step=step)

                # Checkpoint
                if step % self.checkpoint_every == 0 or step == total_steps:

                    # save criteria
                    if dev_set is not None:
                        dev_loss, accuracy, dev_attlosses = \
                            self._evaluate_batches(model, dev_batches, dev_set)
                        dev_attloss = dev_attlosses['att_loss']
                        dev_attclsloss = dev_attlosses['attcls_loss']
                        log_msg = 'Progress: %d%%, Dev loss: %.4f, accuracy: %.4f, att: %.4f, attcls: %.4f' % (
                            step / total_steps * 100, dev_loss, accuracy,
                            dev_attloss, dev_attclsloss)
                        log.info(log_msg)
                        self.writer.add_scalar('dev_loss',
                                               dev_loss,
                                               global_step=step)
                        self.writer.add_scalar('dev_acc',
                                               accuracy,
                                               global_step=step)
                        self.writer.add_scalar('att_dev_loss',
                                               dev_attloss,
                                               global_step=step)
                        self.writer.add_scalar('attcls_dev_loss',
                                               dev_attclsloss,
                                               global_step=step)

                        # save
                        if prev_acc < accuracy:
                            # save best model
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)

                            saved_path = ckpt.save(self.expt_dir)
                            print('saving at {} ... '.format(saved_path))
                            # reset
                            prev_acc = accuracy
                            count_no_improve = 0
                            count_num_rollback = 0
                        else:
                            count_no_improve += 1

                        # roll back
                        if count_no_improve > MAX_COUNT_NO_IMPROVE:
                            # resuming
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if type(latest_checkpoint_path) != type(None):
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                print(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # A workaround to rebind the optimizer to the
                                # rolled-back model's parameters
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)
                                # start_epoch = resume_checkpoint.epoch
                                # step = resume_checkpoint.step

                            # reset
                            count_no_improve = 0
                            count_num_rollback += 1

                        # update learning rate
                        if count_num_rollback > MAX_COUNT_NUM_ROLLBACK:

                            # roll back
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if type(latest_checkpoint_path) != type(None):
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                print(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # A workaround to rebind the optimizer to the
                                # rolled-back model's parameters
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)
                                start_epoch = resume_checkpoint.epoch
                                step = resume_checkpoint.step

                            # decrease lr
                            for param_group in self.optimizer.optimizer.param_groups:
                                param_group['lr'] *= 0.5
                                lr_curr = param_group['lr']
                                print('reducing lr ...')
                                print('step:{} - lr: {}'.format(
                                    step, param_group['lr']))

                            # check early stop
                            if lr_curr < 0.000125:
                                print('early stop ...')
                                break

                            # reset
                            count_no_improve = 0
                            count_num_rollback = 0

                        model.train(mode=True)
                        if ckpt is None:
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)
                            saved_path = ckpt.save(self.expt_dir)
                        ckpt.rm_old(self.expt_dir, keep_num=KEEP_NUM)
                        print('n_no_improve {}, num_rollback {}'.format(
                            count_no_improve, count_num_rollback))
                    sys.stdout.flush()

            else:
                # the batch loop finished without an early-stop break
                if dev_set is None:
                    # save every epoch if no dev_set
                    ckpt = Checkpoint(model=model,
                                      optimizer=self.optimizer,
                                      epoch=epoch,
                                      step=step,
                                      input_vocab=train_set.vocab_src,
                                      output_vocab=train_set.vocab_tgt)
                    saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                    print('saving at {} ... '.format(saved_path))

                if step_elapsed == 0:
                    continue

                # log the average training loss for the finished epoch
                epoch_loss_avg = epoch_loss_total / min(steps_per_epoch,
                                                        step - start_step)
                epoch_loss_total = 0
                log_msg = "Finished epoch %d: Train %s: %.4f" % (
                    epoch, self.loss.name, epoch_loss_avg)
                log.info('\n')
                log.info(log_msg)
                continue

            # the batch loop was exited via `break` (early stop): stop training
            break
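
The checkpointing block above drives three mechanisms from dev accuracy: saving on improvement, rolling back to the latest checkpoint after MAX_COUNT_NO_IMPROVE stagnant evaluations, and halving the learning rate (with an early-stop floor of 0.000125) after MAX_COUNT_NUM_ROLLBACK rollbacks. The class below is a simplified standalone sketch of just that schedule; the class name and the constant values are assumptions, and the actual model/optimizer reloading done on rollback is omitted here.

MAX_COUNT_NO_IMPROVE = 3    # assumed value; defined elsewhere in the real code
MAX_COUNT_NUM_ROLLBACK = 3  # assumed value; defined elsewhere in the real code


class RollbackSchedule:
    """Track dev-accuracy stagnation and decide when to save, roll back, decay the lr, or stop."""

    def __init__(self, lr_init, lr_min=0.000125):
        self.count_no_improve = 0
        self.count_num_rollback = 0
        self.lr = lr_init
        self.lr_min = lr_min

    def step(self, improved):
        """Return one of 'save', 'wait', 'rollback', 'decay_lr', 'stop'."""
        if improved:
            self.count_no_improve = 0
            self.count_num_rollback = 0
            return 'save'
        self.count_no_improve += 1
        if self.count_no_improve <= MAX_COUNT_NO_IMPROVE:
            return 'wait'
        # too many evaluations without improvement: roll back to the last saved model
        self.count_no_improve = 0
        self.count_num_rollback += 1
        if self.count_num_rollback <= MAX_COUNT_NUM_ROLLBACK:
            return 'rollback'
        # too many rollbacks: halve the learning rate, stop once it falls below the floor
        self.count_num_rollback = 0
        self.lr *= 0.5
        return 'stop' if self.lr < self.lr_min else 'decay_lr'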