Ejemplo n.º 1
0
def translate_batch(test_set, model, test_path_out, use_gpu,
	max_seq_len, beam_width, device, seqrev=False):

	"""
		No reference tgt given - run translation on one slice of batches.

		Translates batches [iter_idx * per_iter, (iter_idx + 1) * per_iter)
		of `test_set` and writes one hypothesis per non-empty source sentence
		to '<test_path_out>/<iter_idx:04d>.txt' (utf-8).

		Args:
			test_set: test dataset (src, tgt using the same dir); its
				construct_batches, iter_loader, batch_size, tgt_id2word and
				use_type ('word' or 'char') are used
			model: already-loaded model; its max_seq_len is overwritten
			test_path_out: output dir
			use_gpu: on gpu/cpu (forwarded to the model call)
			max_seq_len: maximum decoding length
			beam_width: beam size forwarded to the model call
			device: torch device the source ids are moved to
			seqrev: if True, reverse each hypothesis before joining
		Raises:
			ValueError: if test_set.use_type is neither 'word' nor 'char'
	"""

	# 'word' outputs are space separated, 'char' outputs are concatenated.
	# Checked up front: any other value used to leave `outline` unbound
	# (or stale from the previous sentence) when writing the output.
	joiners = {'word': ' ', 'char': ''}
	if test_set.use_type not in joiners:
		raise ValueError('Unrecognised use_type {!r} - choose from word/char'
			.format(test_set.use_type))
	joiner = joiners[test_set.use_type]

	# reset max_seq_len:
	model.max_seq_len = max_seq_len
	print('max seq len {}'.format(model.max_seq_len))
	sys.stdout.flush()

	# load test
	test_set.construct_batches(is_train=False)
	evaliter = iter(test_set.iter_loader)
	n_total = len(evaliter)
	print('num batches: {}'.format(n_total))
	print('batch_size: {}'.format(test_set.batch_size))

	# select batch slice: the corpus is split into chunks of per_iter
	# batches and iter_idx picks the chunk translated by this run
	iter_idx = 0
	per_iter = 500 # 1892809 lines; 100/batch; 38 iterations
	st = iter_idx * per_iter
	ed = min((iter_idx + 1) * per_iter, n_total)
	out_path = os.path.join(test_path_out, '{:04d}.txt'.format(iter_idx))

	model.eval()
	# `with` guarantees the output file is closed (it previously leaked)
	with torch.no_grad(), open(out_path, 'w', encoding="utf8") as f:
		for idx in range(n_total):
			# next() builtin: torch DataLoader iterators no longer have .next()
			batch_items = next(evaliter)
			if idx < st:
				continue
			if idx >= ed:
				break
			print(idx, ed)

			# load data - trim padding columns beyond the longest source
			src_lengths = batch_items['srclen']
			src_len = max(src_lengths)
			src_ids = batch_items['srcid'][0][:, :src_len].to(device=device)

			decoder_outputs, decoder_hidden, other = model(src=src_ids, src_lens=src_lengths,
				is_training=False, beam_width=beam_width, use_gpu=use_gpu)

			# memory usage
			_, mem_mb, _ = get_memory_alloc()
			print('Memory used: {0:.2f} MB'.format(round(mem_mb, 2)))

			# write to file
			seqwords = _convert_to_words(other['sequence'], test_set.tgt_id2word)
			for i, seq in enumerate(seqwords):
				# skip zero-length sources (batch padding sentences)
				if src_lengths[i] == 0:
					continue
				words = []
				for word in seq:
					if word == '</s>':
						break
					if word == '<pad>':
						continue
					# '<spc>' marks an explicit space token
					words.append(' ' if word == '<spc>' else word)
				if seqrev:
					words = words[::-1]
				f.write('{}\n'.format(joiner.join(words)))

			sys.stdout.flush()
def translate(test_set, load_dir, test_path_out, use_gpu, max_seq_len, beam_width, seqrev=False):

	"""
		No reference tgt given - run translation.

		Loads the checkpoint at `load_dir`, decodes every batch of `test_set`
		and writes one hypothesis per non-empty source sentence to
		'<test_path_out>/translate.txt' (utf-8).

		Args:
			test_set: test dataset
				src, tgt using the same dir
			load_dir: model dir (used directly as the checkpoint path)
			test_path_out: output dir
			use_gpu: on gpu/cpu
			max_seq_len: maximum decoding length
			beam_width: beam size passed to the model
			seqrev: if True, reverse each hypothesis before writing
	"""

	# load model - load_dir is used directly as the checkpoint path
	latest_checkpoint_path = load_dir
	resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

	# `device` comes from the enclosing module scope (not a parameter)
	model = resume_checkpoint.model.to(device)
	print('Model dir: {}'.format(latest_checkpoint_path))
	print('Model laoded')

	# reset batch_size:
	model.reset_max_seq_len(max_seq_len)
	model.reset_use_gpu(use_gpu)
	model.reset_batch_size(test_set.batch_size)
	model.check_var('ptr_net')
	print('max seq len {}'.format(model.max_seq_len))
	sys.stdout.flush()

	# load test - the ddfd-prob variant is used when attkey_path is set
	if test_set.attkey_path is None:
		test_batches, _ = test_set.construct_batches(is_train=False)
	else:
		test_batches, _ = test_set.construct_batches_with_ddfd_prob(is_train=False)

	with open(os.path.join(test_path_out, 'translate.txt'), 'w', encoding="utf8") as f:
		model.eval()
		with torch.no_grad():
			for batch in test_batches:

				src_lengths = batch['src_sentence_lengths']
				src_probs = None
				if 'src_ddfd_probs' in batch:
					src_probs = _convert_to_tensor(batch['src_ddfd_probs'], use_gpu).unsqueeze(2)
				src_ids = _convert_to_tensor(batch['src_word_ids'], use_gpu)

				decoder_outputs, decoder_hidden, other = model(src=src_ids,
																is_training=False,
																att_key_feats=src_probs,
																beam_width=beam_width)
				# memory usage
				_, mem_mb, _ = get_memory_alloc()
				print('Memory used: {0:.2f} MB'.format(round(mem_mb, 2)))

				# write to file
				seqwords = _convert_to_words(other['sequence'], test_set.src_id2word)
				for i, seq in enumerate(seqwords):
					# skip padding sentences in batch (num_sent % batch_size != 0)
					if src_lengths[i] == 0:
						continue
					words = []
					for word in seq:
						if word == '</s>':
							break
						if word != '<pad>':
							words.append(word)
					if seqrev:
						words.reverse()
					f.write('{}\n'.format(' '.join(words)))
				sys.stdout.flush()
Ejemplo n.º 3
0
def translate(test_set, load_dir, test_path_out,
	use_gpu, max_seq_len, beam_width, mode='gec', seqrev=False):

	"""
		No reference tgt given - run translation with the gec or dd decoder.

		Loads the checkpoint at `load_dir`, decodes every batch of `test_set`
		and writes one hypothesis per non-empty source sentence to
		'<test_path_out>/translate.txt' (utf-8).

		Args:
			test_set: test dataset
				src, tgt using the same dir
			load_dir: model dir (used directly as the checkpoint path)
			test_path_out: output dir
			use_gpu: on gpu/cpu
			max_seq_len: maximum decoding length
			beam_width: beam size; in 'gec' mode, beam_width <= 1 and > 1
				take different gec_eval return signatures
			mode: 'gec' or 'dd' (case-insensitive) - which eval path to run
			seqrev: if True, reverse each hypothesis before writing
		Raises:
			ValueError: if mode is not gec/dd (checked before any loading)
	"""

	# Validate with a real exception before doing any work: the former
	# `assert False` is stripped under `python -O`, leaving ret_dict unbound.
	eval_mode = mode.lower()
	if eval_mode not in ('gec', 'dd'):
		raise ValueError(
			'Unrecognised eval mode {!r} - choose from gec/dd'.format(mode))

	# load model - load_dir is used directly as the checkpoint path
	latest_checkpoint_path = load_dir
	resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

	model = resume_checkpoint.model
	print('Model dir: {}'.format(latest_checkpoint_path))
	print('Model laoded')

	# reset batch_size:
	model.reset_max_seq_len(max_seq_len)
	model.reset_use_gpu(use_gpu)
	model.reset_batch_size(test_set.batch_size)
	# fix compatibility (presumably with checkpoints from older model versions)
	for var_name in ('shared_embed', 'additional_key_size', 'gec_num_bilstm_dec',
		'num_unilstm_enc', 'residual', 'add_discriminator'):
		model.check_classvar(var_name)
	if model.ptr_net == 'none':
		model.ptr_net = 'null'
	# `device` comes from the enclosing module scope (not a parameter)
	model.to(device)

	print('max seq len {}'.format(model.max_seq_len))
	sys.stdout.flush()

	# load test - the ddfd-prob variant is used when tsv_path is set
	print('--- constrcut gec test set ---')
	if test_set.tsv_path is None:
		test_batches, _ = test_set.construct_batches(is_train=False)
	else:
		test_batches, _ = test_set.construct_batches_with_ddfd_prob(is_train=False)

	with open(os.path.join(test_path_out, 'translate.txt'), 'w', encoding="utf8") as f:
		model.eval()
		with torch.no_grad():
			for batch in test_batches:

				src_lengths = batch['src_sentence_lengths']
				# ddfd probs are passed only when the model declares a positive
				# dd_additional_key_size
				src_probs = None
				if 'src_ddfd_probs' in batch and model.dd_additional_key_size > 0:
					src_probs = _convert_to_tensor(batch['src_ddfd_probs'], use_gpu).unsqueeze(2)
				src_ids = _convert_to_tensor(src_ids_raw, use_gpu) if False else _convert_to_tensor(batch['src_word_ids'], use_gpu)
				tgt_ids = None  # no reference target at translation time

				if eval_mode == 'dd':
					_, _, ret_dict = model.dd_eval(src_ids, tgt_ids, is_training=False,
						dd_att_key_feats=src_probs, beam_width=beam_width)
				elif beam_width <= 1:
					# greedy gec_eval returns 6 values: the gec_dd outputs first,
					# then the gec decoder outputs; only the last ret_dict is used
					_, _, _, _, _, ret_dict = model.gec_eval(src_ids, tgt_ids,
						is_training=False, gec_dd_att_key_feats=src_probs,
						beam_width=beam_width)
				else:
					_, _, ret_dict = model.gec_eval(src_ids, tgt_ids,
						is_training=False, gec_dd_att_key_feats=src_probs,
						beam_width=beam_width)

				# memory usage
				_, mem_mb, _ = get_memory_alloc()
				print('Memory used: {0:.2f} MB'.format(round(mem_mb, 2)))

				# output write to file
				seqwords = _convert_to_words(ret_dict['sequence'], test_set.src_id2word)
				for i, seq in enumerate(seqwords):
					# skip padding sentences in batch (num_sent % batch_size != 0)
					if src_lengths[i] == 0:
						continue
					words = []
					for word in seq:
						if word == '</s>':
							break
						if word != '<pad>':
							words.append(word)
					if seqrev:
						words.reverse()
					f.write('{}\n'.format(' '.join(words)))
				sys.stdout.flush()
Ejemplo n.º 4
0
    def _train_epoches(self,
                       train_set,
                       model,
                       n_epochs,
                       start_epoch,
                       start_step,
                       dev_set=None):
        """Train the model epoch by epoch with dev-based checkpointing.

        Every ``self.checkpoint_every`` steps (and at the final step), when a
        dev set is given, the model is evaluated and a checkpoint is saved if
        dev ``las_acc`` strictly improved. After more than
        ``self.max_count_no_improve`` evaluations without improvement the
        model and optimizer are rolled back to the latest checkpoint; after
        more than ``self.max_count_num_rollback`` rollbacks the learning rate
        is halved, and training stops early once it falls below
        ``0.125 * self.learning_rate``. Without a dev set a checkpoint is
        saved at the end of every epoch.

        Args:
            train_set: training dataset (``construct_batches``,
                ``iter_loader``, ``vocab_src`` and ``vocab_tgt`` are used).
            model: model to train.
            n_epochs: last epoch number (inclusive).
            start_epoch: first epoch number.
            start_step: global step to resume counting from.
            dev_set: optional dev dataset used for model selection.
        """
        log = self.logger

        def _rollback(cur_model, epoch, step):
            # Reload model/optimizer from the latest checkpoint in
            # self.expt_dir; returns cur_model unchanged if there is none.
            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                self.expt_dir)
            if latest_checkpoint_path is None:
                return cur_model
            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
            log.info('epoch:{} step: {} - rolling back {} ...'.format(
                epoch, step, latest_checkpoint_path))
            new_model = resume_checkpoint.model
            self.optimizer = resume_checkpoint.optimizer
            # A walk around to set optimizing parameters properly: rebuild
            # the wrapped optimizer over the reloaded model's parameters,
            # reusing the first param group's hyper-parameters
            resume_optim = self.optimizer.optimizer
            defaults = resume_optim.param_groups[0]
            defaults.pop('params', None)
            defaults.pop('initial_lr', None)
            self.optimizer.optimizer = resume_optim.__class__(
                new_model.parameters(), **defaults)
            return new_model

        las_print_loss_total = 0  # Reset every print_every
        step = start_step
        step_elapsed = 0
        prev_acc = 0.0
        count_no_improve = 0
        count_num_rollback = 0
        ckpt = None

        # ******************** [loop over epochs] ********************
        for epoch in range(start_epoch, n_epochs + 1):

            for param_group in self.optimizer.optimizer.param_groups:
                log.info('epoch:{} lr: {}'.format(epoch, param_group['lr']))
                lr_curr = param_group['lr']

            # ----------construct batches-----------
            log.info('--- construct train set ---')
            train_set.construct_batches(is_train=True)
            if dev_set is not None:
                log.info('--- construct dev set ---')
                # NOTE(review): dev set built with is_train=True here, unlike
                # the other trainer variant (is_train=False) - confirm intended
                dev_set.construct_batches(is_train=True)

            # --------print info for each epoch----------
            steps_per_epoch = len(train_set.iter_loader)
            total_steps = steps_per_epoch * n_epochs
            log.info("steps_per_epoch {}".format(steps_per_epoch))
            log.info("total_steps {}".format(total_steps))

            log.debug(" --------- Epoch: %d, Step: %d ---------" %
                      (epoch, step))
            _, mem_mb, _ = get_memory_alloc()
            mem_mb = round(mem_mb, 2)
            log.info('Memory used: {0:.2f} MB'.format(mem_mb))
            self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
            sys.stdout.flush()

            # ******************** [loop over batches] ********************
            model.train(True)
            trainiter = iter(train_set.iter_loader)
            for _ in range(steps_per_epoch):

                # load batch items - next() builtin: torch DataLoader
                # iterators no longer provide a .next() method
                batch_items = next(trainiter)

                # update macro count
                step += 1
                step_elapsed += 1

                # Get loss
                losses = self._train_batch(model, batch_items, train_set, step,
                                           total_steps)

                las_loss = losses['las_loss']
                las_print_loss_total += las_loss

                if step % self.print_every == 0 and step_elapsed > self.print_every:
                    las_print_loss_avg = las_print_loss_total / self.print_every
                    las_print_loss_total = 0

                    log_msg = 'Progress: %d%%, Train las: %.4f'\
                     % (step / total_steps * 100, las_print_loss_avg)

                    log.info(log_msg)
                    self.writer.add_scalar('train_las_loss',
                                           las_print_loss_avg,
                                           global_step=step)

                # Checkpoint
                if step % self.checkpoint_every == 0 or step == total_steps:

                    # save criteria
                    if dev_set is not None:
                        dev_accs, dev_losses = self._evaluate_batches(
                            model, dev_set)
                        las_loss = dev_losses['las_loss']
                        las_acc = dev_accs['las_acc']
                        log_msg = 'Progress: %d%%, Dev las loss: %.4f, accuracy: %.4f'\
                         % (step / total_steps * 100, las_loss, las_acc)
                        log.info(log_msg)
                        self.writer.add_scalar('dev_las_loss',
                                               las_loss,
                                               global_step=step)
                        self.writer.add_scalar('dev_las_acc',
                                               las_acc,
                                               global_step=step)

                        accuracy = las_acc
                        # save
                        if prev_acc < accuracy:
                            # save best model
                            # NOTE(review): output_vocab is vocab_src here but
                            # vocab_tgt in the fallback checkpoint below -
                            # confirm which one is intended
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_src)

                            saved_path = ckpt.save(self.expt_dir)
                            log.info('saving at {} ... '.format(saved_path))
                            # reset
                            prev_acc = accuracy
                            count_no_improve = 0
                            count_num_rollback = 0
                        else:
                            count_no_improve += 1

                        # roll back
                        if count_no_improve > self.max_count_no_improve:
                            model = _rollback(model, epoch, step)
                            # reset
                            count_no_improve = 0
                            count_num_rollback += 1

                        # update learning rate
                        if count_num_rollback > self.max_count_num_rollback:

                            # roll back
                            model = _rollback(model, epoch, step)

                            # decrease lr
                            for param_group in self.optimizer.optimizer.param_groups:
                                param_group['lr'] *= 0.5
                                lr_curr = param_group['lr']
                                log.info('reducing lr ...')
                                log.info('step:{} - lr: {}'.format(
                                    step, param_group['lr']))

                            # check early stop
                            if lr_curr < 0.125 * self.learning_rate:
                                log.info('early stop ...')
                                break

                            # reset
                            count_no_improve = 0
                            count_num_rollback = 0

                        model.train(mode=True)
                        if ckpt is None:
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)
                        ckpt.rm_old(self.expt_dir, keep_num=self.keep_num)
                        log.info('n_no_improve {}, num_rollback {}'.format(
                            count_no_improve, count_num_rollback))

                    sys.stdout.flush()

            else:
                # batch loop ran to completion (no early stop)
                if dev_set is None:
                    # save every epoch if no dev_set
                    ckpt = Checkpoint(model=model,
                                      optimizer=self.optimizer,
                                      epoch=epoch,
                                      step=step,
                                      input_vocab=train_set.vocab_src,
                                      output_vocab=train_set.vocab_src)
                    saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                    log.info('saving at {} ... '.format(saved_path))
                continue

            # early stop inside the batch loop: break the epoch loop too
            break
Ejemplo n.º 5
0
    def _train_epochs(self,
                      train_set,
                      model,
                      n_epochs,
                      start_epoch,
                      start_step,
                      dev_set=None):

        log = self.logger

        print_loss_total = 0  # Reset every print_every
        step = start_step
        step_elapsed = 0
        prev_acc = 0.0
        prev_bleu = 0.0
        count_no_improve = 0
        count_num_rollback = 0
        ckpt = None

        # loop over epochs
        for epoch in range(start_epoch, n_epochs + 1):

            # update lr
            if self.lr_warmup_steps != 0:
                self.optimizer.optimizer = self.lr_scheduler(
                    self.optimizer.optimizer,
                    step,
                    init_lr=self.learning_rate_init,
                    peak_lr=self.learning_rate,
                    warmup_steps=self.lr_warmup_steps)

            # print lr
            for param_group in self.optimizer.optimizer.param_groups:
                log.info('epoch:{} lr: {}'.format(epoch, param_group['lr']))
                lr_curr = param_group['lr']

            # construct batches - allow re-shuffling of data
            log.info('--- construct train set ---')
            train_set.construct_batches(is_train=True)
            if dev_set is not None:
                log.info('--- construct dev set ---')
                dev_set.construct_batches(is_train=False)

            # print info
            steps_per_epoch = len(train_set.iter_loader)
            total_steps = steps_per_epoch * n_epochs
            log.info("steps_per_epoch {}".format(steps_per_epoch))
            log.info("total_steps {}".format(total_steps))

            log.info(" ---------- Epoch: %d, Step: %d ----------" %
                     (epoch, step))
            mem_kb, mem_mb, mem_gb = get_memory_alloc()
            mem_mb = round(mem_mb, 2)
            log.info('Memory used: {0:.2f} MB'.format(mem_mb))
            self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
            sys.stdout.flush()

            # loop over batches
            model.train(True)
            trainiter = iter(train_set.iter_loader)
            for idx in range(steps_per_epoch):

                # load batch items
                batch_items = trainiter.next()

                # update macro count
                step += 1
                step_elapsed += 1

                if self.lr_warmup_steps != 0:
                    self.optimizer.optimizer = self.lr_scheduler(
                        self.optimizer.optimizer,
                        step,
                        init_lr=self.learning_rate_init,
                        peak_lr=self.learning_rate,
                        warmup_steps=self.lr_warmup_steps)

                # Get loss
                loss = self._train_batch(model, batch_items, train_set, step,
                                         total_steps)
                print_loss_total += loss

                if step % self.print_every == 0 and step_elapsed > self.print_every:
                    print_loss_avg = print_loss_total / self.print_every
                    print_loss_total = 0

                    log_msg = 'Progress: %d%%, Train nlll: %.4f' % (
                        step / total_steps * 100, print_loss_avg)

                    log.info(log_msg)
                    self.writer.add_scalar('train_loss',
                                           print_loss_avg,
                                           global_step=step)

                # Checkpoint
                if step % self.checkpoint_every == 0 or step == total_steps:

                    # save criteria
                    if dev_set is not None:
                        losses, metrics = self._evaluate_batches(
                            model, dev_set)

                        loss = losses['nll_loss']
                        accuracy = metrics['accuracy']
                        bleu = metrics['bleu']
                        log_msg = 'Progress: %d%%, Dev loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                            step / total_steps * 100, loss, accuracy, bleu)
                        log.info(log_msg)
                        self.writer.add_scalar('dev_loss',
                                               loss,
                                               global_step=step)
                        self.writer.add_scalar('dev_acc',
                                               accuracy,
                                               global_step=step)
                        self.writer.add_scalar('dev_bleu',
                                               bleu,
                                               global_step=step)

                        # save condition
                        cond_acc = (prev_acc <= accuracy)
                        cond_bleu = (((prev_acc <= accuracy) and (bleu < 0.1))
                                     or prev_bleu <= bleu)

                        # save
                        if self.eval_metric == 'tokacc':
                            save_cond = cond_acc
                        elif self.eval_metric == 'bleu':
                            save_cond = cond_bleu
                        if save_cond:
                            # save best model
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)

                            saved_path = ckpt.save(self.expt_dir)
                            log.info('saving at {} ... '.format(saved_path))
                            # reset
                            prev_acc = accuracy
                            prev_bleu = bleu
                            count_no_improve = 0
                            count_num_rollback = 0
                        else:
                            count_no_improve += 1

                        # roll back
                        if count_no_improve > self.max_count_no_improve:
                            # no roll back - break after self.max_count_no_improve epochs
                            if self.max_count_num_rollback == 0:
                                break
                            # resuming
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if type(latest_checkpoint_path) != type(None):
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                log.info(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # A walk around to set optimizing parameters properly
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)

                            # reset
                            count_no_improve = 0
                            count_num_rollback += 1

                        # update learning rate
                        if count_num_rollback > self.max_count_num_rollback:

                            # roll back
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if type(latest_checkpoint_path) != type(None):
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                log.info(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # A walk around to set optimizing parameters properly
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)

                            # decrease lr
                            for param_group in self.optimizer.optimizer.param_groups:
                                param_group['lr'] *= 0.5
                                lr_curr = param_group['lr']
                                log.info('reducing lr ...')
                                log.info('step:{} - lr: {}'.format(
                                    step, param_group['lr']))

                            # check early stop
                            if lr_curr <= 0.125 * self.learning_rate:
                                log.info('early stop ...')
                                break

                            # reset
                            count_no_improve = 0
                            count_num_rollback = 0

                        model.train(mode=True)
                        if ckpt is None:
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)
                        ckpt.rm_old(self.expt_dir, keep_num=self.keep_num)
                        log.info('n_no_improve {}, num_rollback {}'.format(
                            count_no_improve, count_num_rollback))

                    sys.stdout.flush()

            else:
                if dev_set is None:
                    # save every epoch if no dev_set
                    ckpt = Checkpoint(model=model,
                                      optimizer=self.optimizer,
                                      epoch=epoch,
                                      step=step,
                                      input_vocab=train_set.vocab_src,
                                      output_vocab=train_set.vocab_tgt)
                    saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                    log.info('saving at {} ... '.format(saved_path))
                    continue

                else:
                    continue

            # break nested for loop
            break
Ejemplo n.º 6
0
def _hyp_tokens_to_line(tokens, use_type, seqrev=False):
    """Turn one decoded token sequence into a single output line.

    Tokens are consumed in order: '<pad>' is skipped and '</s>' ends the
    hypothesis. For ``use_type == 'char'`` the '<spc>' token becomes a literal
    space and the kept tokens are concatenated with no separator; for any
    other ``use_type`` the kept tokens are joined with single spaces.

    Args:
        tokens: iterable of token strings for one sentence.
        use_type: 'char', 'word' or 'bpe' (the only values the caller passes).
        seqrev: if True, the kept tokens are reversed before joining.

    Returns:
        str: the output line without a trailing newline ('' if nothing kept).
    """
    char_level = use_type == 'char'
    kept = []
    for tok in tokens:
        if tok == '<pad>':
            continue
        if tok == '</s>':
            break
        kept.append(' ' if (char_level and tok == '<spc>') else tok)
    if seqrev:
        kept.reverse()
    return ('' if char_level else ' ').join(kept)


def translate_acous(test_set,
                    load_dir,
                    test_path_out,
                    use_gpu,
                    max_seq_len,
                    beam_width,
                    seqrev=False):
    """Decode ``test_set`` with the acoustic model stored in ``load_dir``.

    No reference target is needed. Two files are written to
    ``test_path_out``:

    * ``translate.txt``: one hypothesis line per non-padding sentence. It is
      only filled when ``model.add_times`` is False and ``model.use_type`` is
      'char', 'word' or 'bpe'.
    * ``translate.tsv``: for every sentence, one ``word<TAB>prob`` line per
      source token (up to the first '<pad>' or '</s>'), taken from
      ``other['classify_prob']``, followed by a blank line.

    Args:
        test_set: test dataset (src, tgt read from the same dir); must provide
            ``construct_batches``, ``iter_loader`` and ``src_id2word``.
        load_dir: checkpoint directory the model is loaded from.
        test_path_out: output directory.
        use_gpu: forwarded to the model's forward call.
        max_seq_len: maximum decoding length set on the model.
        beam_width: beam width forwarded to the model.
        seqrev: if True, hypothesis tokens are reversed before joining.

    Note:
        ``device`` is not a parameter; it is resolved from module scope.
    """

    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

    model = resume_checkpoint.model.to(device)
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model laoded')

    # reset max decoding length
    model.reset_max_seq_len(max_seq_len)
    print('max seq len {}'.format(model.max_seq_len))
    sys.stdout.flush()

    # load test
    test_set.construct_batches(is_train=False)
    evaliter = iter(test_set.iter_loader)
    total = 0
    print('total #batches: {}'.format(len(evaliter)))

    # Both outputs are managed by `with` so they are closed even if decoding
    # raises. The tsv is UTF-8 like the txt, so writing non-ASCII tokens does
    # not depend on the platform's default encoding.
    tsv_path = os.path.join(test_path_out, 'translate.tsv')
    txt_path = os.path.join(test_path_out, 'translate.txt')
    with open(tsv_path, 'w', encoding="utf8") as f2, \
            open(txt_path, 'w', encoding="utf8") as f:
        model.eval()
        with torch.no_grad():
            for _ in range(len(evaliter)):
                # builtin next() works on any iterator; the .next() method is
                # a Python 2 alias that newer DataLoader iterators dropped
                batch_items = next(evaliter)
                src_ids = batch_items[0][0].to(device=device)
                src_lengths = batch_items[1]
                acous_feats = batch_items[2][0].to(device=device)
                acous_times = batch_items[5]

                _, _, other = model(
                    src_ids,
                    acous_feats=acous_feats,
                    acous_times=acous_times,
                    is_training=False,
                    use_gpu=use_gpu,
                    beam_width=beam_width)

                # memory usage
                _, mem_mb, _ = get_memory_alloc()
                print('Memory used: {0:.2f} MB'.format(round(mem_mb, 2)))

                model.check_var('add_times', False)
                if not model.add_times:
                    # write decoded hypotheses (mapped with the src vocab)
                    seqwords = _convert_to_words(other['sequence'],
                                                 test_set.src_id2word)

                    model.check_var('use_type', 'char')
                    total += len(seqwords)
                    if model.use_type in ('char', 'word', 'bpe'):
                        for i, tokens in enumerate(seqwords):
                            # skip padding sentences in batch (num_sent % batch_size != 0)
                            if src_lengths[i] == 0:
                                continue
                            f.write('{}\n'.format(_hyp_tokens_to_line(
                                tokens, model.use_type, seqrev)))

                srcwords = _convert_to_words_batchfirst(
                    src_ids, test_set.src_id2word)
                dd_ps = other['classify_prob']

                # print dd output: "word\tprob" per token until the first
                # '<pad>' or '</s>', then a blank line per sentence
                for i, sent in enumerate(srcwords):
                    for j, word in enumerate(sent):
                        if word in ('<pad>', '</s>'):
                            break
                        f2.write('{}\t{}\n'.format(word, dd_ps[i][j].data[0]))
                    f2.write('\n')

                sys.stdout.flush()

    print('total #sent: {}'.format(total))
Ejemplo n.º 7
0
def translate(test_set,
              load_dir,
              test_path_out,
              use_gpu,
              max_seq_len,
              beam_width,
              seqrev=False):
    """Run the model stored in ``load_dir`` over ``test_set`` and dump the
    per-token classification probabilities.

    No reference target is needed. ``translate.tsv`` in ``test_path_out``
    receives, for every sentence, one ``word<TAB>prob`` line per source token
    (up to the first '<pad>' or '</s>'), taken from
    ``other['classify_prob']``, followed by a blank line.

    Args:
        test_set: test dataset (src, tgt read from the same dir); must provide
            ``construct_batches``, ``iter_loader`` and ``src_id2word``.
        load_dir: checkpoint directory the model is loaded from.
        test_path_out: output directory.
        use_gpu: forwarded to the model's forward call.
        max_seq_len: maximum decoding length set on the model.
        beam_width: beam width forwarded to the model.
        seqrev: accepted for signature compatibility; unused here.

    Note:
        ``device`` is not a parameter; it is resolved from module scope.
    """

    # load model
    latest_checkpoint_path = load_dir
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)

    model = resume_checkpoint.model.to(device)
    print('Model dir: {}'.format(latest_checkpoint_path))
    print('Model laoded')

    # reset max decoding length
    model.reset_max_seq_len(max_seq_len)
    print('max seq len {}'.format(model.max_seq_len))
    sys.stdout.flush()

    # load test
    test_set.construct_batches(is_train=False)
    evaliter = iter(test_set.iter_loader)
    total = 0

    with open(os.path.join(test_path_out, 'translate.tsv'),
              'w',
              encoding="utf8") as f:
        model.eval()
        with torch.no_grad():
            for _ in range(len(evaliter)):
                # builtin next() works on any iterator; the .next() method is
                # a Python 2 alias that newer DataLoader iterators dropped
                batch_items = next(evaliter)
                src_ids = batch_items[0][0].to(device=device)

                _, _, other = model(
                    src_ids,
                    is_training=False,
                    use_gpu=use_gpu,
                    beam_width=beam_width)

                # memory usage
                _, mem_mb, _ = get_memory_alloc()
                print('Memory used: {0:.2f} MB'.format(round(mem_mb, 2)))

                srcwords = _convert_to_words_batchfirst(
                    src_ids, test_set.src_id2word)
                dd_ps = other['classify_prob']

                # print dd output: "word\tprob" per token until the first
                # '<pad>' or '</s>', then a blank line per sentence
                total += len(srcwords)
                for i, sent in enumerate(srcwords):
                    for j, word in enumerate(sent):
                        if word in ('<pad>', '</s>'):
                            break
                        f.write('{}\t{}\n'.format(word, dd_ps[i][j].data[0]))
                    f.write('\n')
                sys.stdout.flush()

    print('total #sent: {}'.format(total))
    def _train_epoches(self,
                       train_sets,
                       model,
                       n_epochs,
                       start_epoch,
                       start_step,
                       dev_sets=None):
        """Jointly train ``model`` on paired ASR and MT batches.

        Each step draws one batch from the ASR train set and one from the MT
        train set. An epoch has ``min(len(asr_loader), len(mt_loader))``
        steps.

        Every ``self.print_every`` steps (after the first window), the
        averaged train losses (ae, asr, mt, kl, l2) are logged and written to
        ``self.writer``.

        Every ``self.checkpoint_every`` steps (and at ``total_steps``), when a
        dev set is given, the model is evaluated:

        * A checkpoint is saved whenever the weighted dev score improves.
        * Otherwise the no-improvement counter grows, which can trigger a
          rollback to the latest checkpoint, halving of the learning rate, or
          early stopping.

        Without a dev set, a checkpoint is saved at the end of every epoch.

        Args:
            train_sets: dict with keys 'asr' and 'mt' (training datasets).
            model: the model to train.
            n_epochs: last epoch number (inclusive).
            start_epoch: first epoch number.
            start_step: global step to resume counting from.
            dev_sets: optional dict with keys 'asr' and 'mt'; None disables
                dev evaluation.

        NOTE: a model reloaded by a rollback is only rebound locally; it is
        not returned to the caller.
        """

        # load datasets; dev_sets defaults to None, so only subscript it when
        # it was actually given
        train_set_asr = train_sets['asr']
        train_set_mt = train_sets['mt']
        if dev_sets is None:
            dev_set_asr = None
            dev_set_mt = None
        else:
            dev_set_asr = dev_sets['asr']
            dev_set_mt = dev_sets['mt']

        log = self.logger

        def _roll_back(model, epoch, step):
            """Reload model and optimizer from the latest checkpoint.

            Returns the reloaded model, or ``model`` unchanged when no
            checkpoint exists in ``self.expt_dir`` yet. Replaces
            ``self.optimizer`` as a side effect.
            """
            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                self.expt_dir)
            if latest_checkpoint_path is None:
                return model
            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
            log.info('epoch:{} step: {} - rolling back {} ...'.format(
                epoch, step, latest_checkpoint_path))
            model = resume_checkpoint.model
            self.optimizer = resume_checkpoint.optimizer
            # A walk around to set optimizing parameters properly: rebuild the
            # wrapped optimizer over the reloaded model's parameters, reusing
            # the hyper-parameters stored in its first param group
            resume_optim = self.optimizer.optimizer
            defaults = resume_optim.param_groups[0]
            defaults.pop('params', None)
            defaults.pop('initial_lr', None)
            self.optimizer.optimizer = resume_optim.__class__(
                model.parameters(), **defaults)
            return model

        print_loss_ae_total = 0  # Reset every print_every
        print_loss_asr_total = 0
        print_loss_mt_total = 0
        print_loss_kl_total = 0
        print_loss_l2_total = 0

        step = start_step
        step_elapsed = 0
        prev_acc = 0.0
        prev_bleu = 0.0
        count_no_improve = 0
        count_num_rollback = 0
        ckpt = None

        # loop over epochs
        for epoch in range(start_epoch, n_epochs + 1):

            # update lr
            if self.lr_warmup_steps != 0:
                self.optimizer.optimizer = self.lr_scheduler(
                    self.optimizer.optimizer,
                    step,
                    init_lr=self.learning_rate_init,
                    peak_lr=self.learning_rate,
                    warmup_steps=self.lr_warmup_steps)
            # print lr
            for param_group in self.optimizer.optimizer.param_groups:
                log.info('epoch:{} lr: {}'.format(epoch, param_group['lr']))
                lr_curr = param_group['lr']

            # construct batches - allow re-shuffling of data
            log.info('--- construct train set ---')
            train_set_asr.construct_batches(is_train=True)
            train_set_mt.construct_batches(is_train=True)
            if dev_set_asr is not None:
                log.info('--- construct dev set ---')
                dev_set_asr.construct_batches(is_train=False)
                dev_set_mt.construct_batches(is_train=False)

            # print info; an epoch is bounded by the shorter of the two loaders
            steps_per_epoch = min(len(train_set_asr.iter_loader),
                                  len(train_set_mt.iter_loader))
            total_steps = steps_per_epoch * n_epochs
            log.info("steps_per_epoch {}".format(steps_per_epoch))
            log.info("total_steps {}".format(total_steps))

            log.info(" ---------- Epoch: %d, Step: %d ----------" %
                     (epoch, step))
            _, mem_mb, _ = get_memory_alloc()
            mem_mb = round(mem_mb, 2)
            log.info('Memory used: {0:.2f} MB'.format(mem_mb))
            self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
            sys.stdout.flush()

            # loop over batches
            model.train(True)
            trainiter_asr = iter(train_set_asr.iter_loader)
            trainiter_mt = iter(train_set_mt.iter_loader)
            for _ in range(steps_per_epoch):

                # load batch items; builtin next() works on any iterator (the
                # .next() method is a Python 2 alias newer loaders dropped)
                batch_items_asr = next(trainiter_asr)
                batch_items_mt = next(trainiter_mt)

                # update macro count
                step += 1
                step_elapsed += 1

                if self.lr_warmup_steps != 0:
                    self.optimizer.optimizer = self.lr_scheduler(
                        self.optimizer.optimizer,
                        step,
                        init_lr=self.learning_rate_init,
                        peak_lr=self.learning_rate,
                        warmup_steps=self.lr_warmup_steps)

                # Get loss
                losses = self._train_batch(model, batch_items_asr,
                                           batch_items_mt, step, total_steps)
                print_loss_ae_total += losses['nll_loss_ae']
                print_loss_asr_total += losses['nll_loss_asr']
                print_loss_mt_total += losses['nll_loss_mt']
                print_loss_kl_total += losses['kl_loss']
                print_loss_l2_total += losses['l2_loss']

                if step % self.print_every == 0:
                    # the first (possibly partial) window is not logged
                    if step_elapsed > self.print_every:
                        print_loss_ae_avg = print_loss_ae_total / self.print_every
                        print_loss_asr_avg = print_loss_asr_total / self.print_every
                        print_loss_mt_avg = print_loss_mt_total / self.print_every
                        print_loss_kl_avg = print_loss_kl_total / self.print_every
                        print_loss_l2_avg = print_loss_l2_total / self.print_every

                        log_msg = 'Progress: %d%%, Train nlll_ae: %.4f, nlll_asr: %.4f, ' % (
                            step / total_steps * 100, print_loss_ae_avg,
                            print_loss_asr_avg)
                        log_msg += 'Train nlll_mt: %.4f, l2: %.4f, kl_en: %.4f' % (
                            print_loss_mt_avg, print_loss_l2_avg,
                            print_loss_kl_avg)
                        log.info(log_msg)

                        for tag, value in (
                                ('train_loss_ae', print_loss_ae_avg),
                                ('train_loss_asr', print_loss_asr_avg),
                                ('train_loss_mt', print_loss_mt_avg),
                                ('train_loss_kl', print_loss_kl_avg),
                                ('train_loss_l2', print_loss_l2_avg)):
                            self.writer.add_scalar(tag, value,
                                                   global_step=step)

                    # Reset at every print boundary, logged or not, so each
                    # logged average covers exactly the last print_every
                    # steps (otherwise the skipped window leaks into the
                    # first logged one and inflates it).
                    print_loss_ae_total = 0
                    print_loss_asr_total = 0
                    print_loss_mt_total = 0
                    print_loss_kl_total = 0
                    print_loss_l2_total = 0

                # Checkpoint
                if step % self.checkpoint_every == 0 or step == total_steps:

                    # save criteria
                    if dev_set_asr is not None:
                        losses, metrics = self._evaluate_batches(
                            model, dev_set_asr, dev_set_mt)

                        loss_kl = losses['kl_loss']
                        loss_l2 = losses['l2_loss']
                        loss_ae = losses['nll_loss_ae']
                        accuracy_ae = metrics['accuracy_ae']
                        bleu_ae = metrics['bleu_ae']
                        loss_asr = losses['nll_loss_asr']
                        accuracy_asr = metrics['accuracy_asr']
                        bleu_asr = metrics['bleu_asr']
                        loss_mt = losses['nll_loss_mt']
                        accuracy_mt = metrics['accuracy_mt']
                        bleu_mt = metrics['bleu_mt']

                        progress = step / total_steps * 100
                        log.info('Progress: %d%%, Dev AE loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                            progress, loss_ae, accuracy_ae, bleu_ae))
                        log.info('Progress: %d%%, Dev ASR loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                            progress, loss_asr, accuracy_asr, bleu_asr))
                        log.info('Progress: %d%%, Dev MT loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                            progress, loss_mt, accuracy_mt, bleu_mt))
                        log.info('Progress: %d%%, Dev En KL loss: %.4f, L2 loss: %.4f' % (
                            progress, loss_kl, loss_l2))

                        for tag, value in (('dev_loss_l2', loss_l2),
                                           ('dev_loss_kl', loss_kl),
                                           ('dev_loss_ae', loss_ae),
                                           ('dev_acc_ae', accuracy_ae),
                                           ('dev_bleu_ae', bleu_ae),
                                           ('dev_loss_asr', loss_asr),
                                           ('dev_acc_asr', accuracy_asr),
                                           ('dev_bleu_asr', bleu_asr),
                                           ('dev_loss_mt', loss_mt),
                                           ('dev_acc_mt', accuracy_mt),
                                           ('dev_bleu_mt', bleu_mt)):
                            self.writer.add_scalar(tag, value,
                                                   global_step=step)

                        # save: score is the mean of (ASR metric / 4) and
                        # the MT metric; accuracy only decides while the
                        # averaged bleu is still below 0.1
                        accuracy_ave = (accuracy_asr / 4.0 + accuracy_mt) / 2.0
                        bleu_ave = (bleu_asr / 4.0 + bleu_mt) / 2.0
                        if ((prev_acc < accuracy_ave) and
                            (bleu_ave < 0.1)) or prev_bleu < bleu_ave:

                            # save best model
                            ckpt = Checkpoint(
                                model=model,
                                optimizer=self.optimizer,
                                epoch=epoch,
                                step=step,
                                input_vocab=train_set_asr.vocab_src,
                                output_vocab=train_set_asr.vocab_tgt)

                            saved_path = ckpt.save(self.expt_dir)
                            log.info('saving at {} ... '.format(saved_path))
                            # reset
                            prev_acc = accuracy_ave
                            prev_bleu = bleu_ave
                            count_no_improve = 0
                            count_num_rollback = 0
                        else:
                            count_no_improve += 1

                        # roll back
                        if count_no_improve > self.max_count_no_improve:
                            # stop training if rollbacks are disabled
                            if self.max_count_num_rollback == 0:
                                break
                            model = _roll_back(model, epoch, step)
                            # reset
                            count_no_improve = 0
                            count_num_rollback += 1

                        # too many rollbacks: roll back once more, then
                        # halve the learning rate
                        if count_num_rollback > self.max_count_num_rollback:
                            model = _roll_back(model, epoch, step)

                            # decrease lr
                            for param_group in self.optimizer.optimizer.param_groups:
                                param_group['lr'] *= 0.5
                                lr_curr = param_group['lr']
                                log.info('reducing lr ...')
                                log.info('step:{} - lr: {}'.format(
                                    step, param_group['lr']))

                            # check early stop
                            if lr_curr <= 0.125 * self.learning_rate:
                                log.info('early stop ...')
                                break

                            # reset
                            count_no_improve = 0
                            count_num_rollback = 0

                        model.train(mode=True)
                        if ckpt is not None:
                            ckpt.rm_old(self.expt_dir, keep_num=self.keep_num)
                        log.info('n_no_improve {}, num_rollback {}'.format(
                            count_no_improve, count_num_rollback))

                    sys.stdout.flush()

            else:
                # batch loop finished without `break`
                if dev_set_asr is None:
                    # save every epoch if no dev_set
                    ckpt = Checkpoint(model=model,
                                      optimizer=self.optimizer,
                                      epoch=epoch,
                                      step=step,
                                      input_vocab=train_set_asr.vocab_src,
                                      output_vocab=train_set_asr.vocab_tgt)
                    saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                    log.info('saving at {} ... '.format(saved_path))
                continue

            # a `break` in the batch loop (early stop) also ends the epochs
            break
Ejemplo n.º 9
0
    def _train_epoches(self,
                       train_set,
                       model,
                       n_epochs,
                       start_epoch,
                       start_step,
                       dev_set=None):

        log = self.logger

        print_loss_total = 0  # Reset every print_every
        epoch_loss_total = 0  # Reset every epoch
        att_print_loss_total = 0  # Reset every print_every
        att_epoch_loss_total = 0  # Reset every epoch
        attcls_print_loss_total = 0  # Reset every print_every
        attcls_epoch_loss_total = 0  # Reset every epoch

        step = start_step
        step_elapsed = 0
        prev_acc = 0.0
        count_no_improve = 0
        count_num_rollback = 0
        ckpt = None

        # ******************** [loop over epochs] ********************
        for epoch in range(start_epoch, n_epochs + 1):

            for param_group in self.optimizer.optimizer.param_groups:
                print('epoch:{} lr: {}'.format(epoch, param_group['lr']))
                lr_curr = param_group['lr']

            # ----------construct batches-----------
            # allow re-shuffling of data
            if type(train_set.attkey_path) == type(None):
                print('--- construct train set ---')
                train_batches, vocab_size = train_set.construct_batches(
                    is_train=True)
                if dev_set is not None:
                    print('--- construct dev set ---')
                    dev_batches, vocab_size = dev_set.construct_batches(
                        is_train=False)
            else:
                print('--- construct train set ---')
                train_batches, vocab_size = train_set.construct_batches_with_ddfd_prob(
                    is_train=True)
                if dev_set is not None:
                    print('--- construct dev set ---')
                    assert type(dev_set.attkey_path) != type(
                        None), 'Dev set missing ddfd probabilities'
                    dev_batches, vocab_size = dev_set.construct_batches_with_ddfd_prob(
                        is_train=False)

            # --------print info for each epoch----------
            steps_per_epoch = len(train_batches)
            total_steps = steps_per_epoch * n_epochs
            log.info("steps_per_epoch {}".format(steps_per_epoch))
            log.info("total_steps {}".format(total_steps))

            log.debug(
                " ----------------- Epoch: %d, Step: %d -----------------" %
                (epoch, step))
            mem_kb, mem_mb, mem_gb = get_memory_alloc()
            mem_mb = round(mem_mb, 2)
            print('Memory used: {0:.2f} MB'.format(mem_mb))
            self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
            sys.stdout.flush()

            # ******************** [loop over batches] ********************
            model.train(True)
            for batch in train_batches:

                # update macro count
                step += 1
                step_elapsed += 1

                # load data
                src_ids = batch['src_word_ids']
                src_lengths = batch['src_sentence_lengths']
                tgt_ids = batch['tgt_word_ids']
                tgt_lengths = batch['tgt_sentence_lengths']

                src_probs = None
                src_labs = None
                if 'src_ddfd_probs' in batch and model.additional_key_size > 0:
                    src_probs = batch['src_ddfd_probs']
                    src_probs = _convert_to_tensor(src_probs,
                                                   self.use_gpu).unsqueeze(2)
                if 'src_ddfd_labs' in batch:
                    src_labs = batch['src_ddfd_labs']
                    src_labs = _convert_to_tensor(src_labs,
                                                  self.use_gpu).unsqueeze(2)

                # sanity check src-tgt pair
                if step == 1:
                    print('--- Check src tgt pair ---')
                    log_msgs = check_srctgt(src_ids, tgt_ids,
                                            train_set.src_id2word,
                                            train_set.tgt_id2word)
                    for log_msg in log_msgs:
                        sys.stdout.buffer.write(log_msg)

                # convert variable to tensor
                src_ids = _convert_to_tensor(src_ids, self.use_gpu)
                tgt_ids = _convert_to_tensor(tgt_ids, self.use_gpu)

                # Get loss
                loss, att_loss, attcls_loss = self._train_batch(
                    src_ids,
                    tgt_ids,
                    model,
                    step,
                    total_steps,
                    src_probs=src_probs,
                    src_labs=src_labs)

                print_loss_total += loss
                epoch_loss_total += loss
                att_print_loss_total += att_loss
                att_epoch_loss_total += att_loss
                attcls_print_loss_total += attcls_loss
                attcls_epoch_loss_total += attcls_loss

                if step % self.print_every == 0 and step_elapsed > self.print_every:
                    print_loss_avg = print_loss_total / self.print_every
                    att_print_loss_avg = att_print_loss_total / self.print_every
                    attcls_print_loss_avg = attcls_print_loss_total / self.print_every
                    print_loss_total = 0
                    att_print_loss_total = 0
                    attcls_print_loss_total = 0

                    log_msg = 'Progress: %d%%, Train nlll: %.4f, att: %.4f, attcls: %.4f' % (
                        step / total_steps * 100, print_loss_avg,
                        att_print_loss_avg, attcls_print_loss_avg)
                    # print(log_msg)
                    log.info(log_msg)
                    self.writer.add_scalar('train_loss',
                                           print_loss_avg,
                                           global_step=step)
                    self.writer.add_scalar('att_train_loss',
                                           att_print_loss_avg,
                                           global_step=step)
                    self.writer.add_scalar('attcls_train_loss',
                                           attcls_print_loss_avg,
                                           global_step=step)

                # Checkpoint
                if step % self.checkpoint_every == 0 or step == total_steps:

                    # save criteria
                    if dev_set is not None:
                        dev_loss, accuracy, dev_attlosses = \
                         self._evaluate_batches(model, dev_batches, dev_set)
                        dev_attloss = dev_attlosses['att_loss']
                        dev_attclsloss = dev_attlosses['attcls_loss']
                        log_msg = 'Progress: %d%%, Dev loss: %.4f, accuracy: %.4f, att: %.4f, attcls: %.4f' % (
                            step / total_steps * 100, dev_loss, accuracy,
                            dev_attloss, dev_attclsloss)
                        log.info(log_msg)
                        self.writer.add_scalar('dev_loss',
                                               dev_loss,
                                               global_step=step)
                        self.writer.add_scalar('dev_acc',
                                               accuracy,
                                               global_step=step)
                        self.writer.add_scalar('att_dev_loss',
                                               dev_attloss,
                                               global_step=step)
                        self.writer.add_scalar('attcls_dev_loss',
                                               dev_attclsloss,
                                               global_step=step)

                        # save
                        if prev_acc < accuracy:
                            # save best model
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)

                            saved_path = ckpt.save(self.expt_dir)
                            print('saving at {} ... '.format(saved_path))
                            # reset
                            prev_acc = accuracy
                            count_no_improve = 0
                            count_num_rollback = 0
                        else:
                            count_no_improve += 1

                        # roll back
                        if count_no_improve > MAX_COUNT_NO_IMPROVE:
                            # resuming
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if type(latest_checkpoint_path) != type(None):
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                print(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # A walk around to set optimizing parameters properly
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim\
                                 .__class__(model.parameters(), **defaults)
                                # start_epoch = resume_checkpoint.epoch
                                # step = resume_checkpoint.step

                            # reset
                            count_no_improve = 0
                            count_num_rollback += 1

                        # update learning rate
                        if count_num_rollback > MAX_COUNT_NUM_ROLLBACK:

                            # roll back
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if type(latest_checkpoint_path) != type(None):
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                print(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # A walk around to set optimizing parameters properly
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim\
                                 .__class__(model.parameters(), **defaults)
                                start_epoch = resume_checkpoint.epoch
                                step = resume_checkpoint.step

                            # decrease lr
                            for param_group in self.optimizer.optimizer.param_groups:
                                param_group['lr'] *= 0.5
                                lr_curr = param_group['lr']
                                print('reducing lr ...')
                                print('step:{} - lr: {}'.format(
                                    step, param_group['lr']))

                            # check early stop
                            if lr_curr < 0.000125:
                                print('early stop ...')
                                break

                            # reset
                            count_no_improve = 0
                            count_num_rollback = 0

                        model.train(mode=True)
                        if ckpt is None:
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)
                            saved_path = ckpt.save(self.expt_dir)
                        ckpt.rm_old(self.expt_dir, keep_num=KEEP_NUM)
                        print('n_no_improve {}, num_rollback {}'.format(
                            count_no_improve, count_num_rollback))
                    sys.stdout.flush()

            else:
                if dev_set is None:
                    # save every epoch if no dev_set
                    ckpt = Checkpoint(model=model,
                                      optimizer=self.optimizer,
                                      epoch=epoch,
                                      step=step,
                                      input_vocab=train_set.vocab_src,
                                      output_vocab=train_set.vocab_tgt)
                    # saved_path = ckpt.save(self.expt_dir)
                    saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                    print('saving at {} ... '.format(saved_path))
                    continue

                else:
                    continue
            # break nested for loop
            break

            if step_elapsed == 0: continue
            epoch_loss_avg = epoch_loss_total / min(steps_per_epoch,
                                                    step - start_step)
            epoch_loss_total = 0
            log_msg = "Finished epoch %d: Train %s: %.4f" % (
                epoch, self.loss.name, epoch_loss_avg)

            log.info('\n')
            log.info(log_msg)