def __init__(self, model, src_file=None, trg_file=None, dev_every=0,
             batcher=bare(SrcBatcher, batch_size=32), loss_calculator=None,
             run_for_epochs=None, lr_decay=1.0, lr_decay_times=3, patience=1,
             initial_patience=None, dev_tasks=None, restart_trainer=False,
             reload_command=None, name=None,
             sample_train_sents: Optional[int] = None,
             max_num_train_sents=None, max_src_len=None, max_trg_len=None):
  self.src_file = src_file
  self.trg_file = trg_file
  self.dev_tasks = dev_tasks
  if lr_decay > 1.0 or lr_decay <= 0.0:
    raise RuntimeError("illegal lr_decay, must satisfy: 0.0 < lr_decay <= 1.0")
  self.lr_decay = lr_decay
  self.patience = patience
  self.initial_patience = initial_patience
  self.lr_decay_times = lr_decay_times
  self.restart_trainer = restart_trainer
  self.run_for_epochs = run_for_epochs
  self.early_stopping_reached = False
  # training state
  self.training_state = TrainingState()
  self.reload_command = reload_command
  self.model = model
  self.loss_calculator = loss_calculator or LossCalculator(MLELoss())
  self.sample_train_sents = sample_train_sents
  self.max_num_train_sents = max_num_train_sents
  self.max_src_len = max_src_len
  self.max_trg_len = max_trg_len
  self.batcher = batcher
  self.logger = BatchLossTracker(self, dev_every, name)
def test_single(self):
  dy.renew_cg()
  train_loss = self.model.calc_loss(src=self.src_data[0],
                                    trg=self.trg_data[0],
                                    loss_calculator=LossCalculator()).value()
  dy.renew_cg()
  self.model.initialize_generator(beam=1)
  outputs = self.model.generate_output(self.src_data[0], 0,
                                       forced_trg_ids=self.trg_data[0])
  # Forced decoding with beam=1 scores the reference itself, so the negated
  # hypothesis score must equal the training loss on the same sentence pair.
  self.assertAlmostEqual(-outputs[0].score, train_loss, places=4)
def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1):
  """
  Tests whether single loss equals batch loss.
  Here we don't truncate the target side and use masking.
  """
  batch_size = 5
  src_sents = self.src_data[:batch_size]
  src_min = min([len(x) for x in src_sents])
  src_sents_trunc = [s[:src_min] for s in src_sents]
  for single_sent in src_sents_trunc:
    single_sent[src_min - 1] = Vocab.ES
    while len(single_sent) % pad_src_to_multiple != 0:
      single_sent.append(Vocab.ES)

  trg_sents = self.trg_data[:batch_size]
  trg_max = max([len(x) for x in trg_sents])
  trg_masks = Mask(np.zeros([batch_size, trg_max]))
  for i in range(batch_size):
    for j in range(len(trg_sents[i]), trg_max):
      trg_masks.np_arr[i, j] = 1.0
  trg_sents_padded = [[w for w in s] + [Vocab.ES] * (trg_max - len(s))
                      for s in trg_sents]

  single_loss = 0.0
  for sent_id in range(batch_size):
    dy.renew_cg()
    train_loss = model.calc_loss(src=src_sents_trunc[sent_id],
                                 trg=trg_sents[sent_id],
                                 loss_calculator=LossCalculator()).value()
    single_loss += train_loss

  dy.renew_cg()
  batched_loss = model.calc_loss(src=mark_as_batch(src_sents_trunc),
                                 trg=mark_as_batch(trg_sents_padded, trg_masks),
                                 loss_calculator=LossCalculator()).value()
  self.assertAlmostEqual(single_loss, sum(batched_loss), places=4)
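# Illustration (not part of the original tests): the Mask convention used above,
# where np_arr holds 0.0 for real tokens and 1.0 for padded positions, so loss
# terms at padded positions can be zeroed out. For hypothetical target lengths
# [3, 5, 2] the loop above builds the array shown below.
import numpy as np

lengths = [3, 5, 2]                  # assumed toy lengths; trg_max == 5
mask = np.zeros([len(lengths), max(lengths)])
for i, l in enumerate(lengths):
  mask[i, l:] = 1.0                  # positions past the true length are padding
# mask -> [[0. 0. 0. 1. 1.]
#          [0. 0. 0. 0. 0.]
#          [0. 0. 1. 1. 1.]]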
def __init__(self, src_file, ref_file, model=Ref("model"),
             batcher=Ref("train.batcher", default=None), loss_calculator=None,
             max_src_len=None, max_trg_len=None, desc=None):
  self.model = model
  self.loss_calculator = loss_calculator or LossCalculator(MLELoss())
  self.src_file = src_file
  self.ref_file = ref_file
  self.batcher = batcher
  self.src_data = None
  self.max_src_len = max_src_len
  self.max_trg_len = max_trg_len
  self.desc = desc
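# Hedged usage sketch (not from the source): constructing a LossEvalTask
# directly, mirroring how the tests below wire it into dev_tasks. `model` and
# `batcher` are assumed to be built elsewhere, as in those tests.
dev_task = LossEvalTask(model=model,
                        src_file="examples/data/head.ja",
                        ref_file="examples/data/head.en",
                        batcher=batcher,
                        desc="dev loss on the head.ja/head.en example data")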
def test_overfitting(self):
  layer_dim = 16
  batcher = SrcBatcher(batch_size=10, break_ties_randomly=False)
  train_args = {}
  train_args['src_file'] = "examples/data/head.ja"
  train_args['trg_file'] = "examples/data/head.en"
  train_args['loss_calculator'] = LossCalculator()
  train_args['model'] = DefaultTranslator(
    src_reader=PlainTextReader(),
    trg_reader=PlainTextReader(),
    src_embedder=SimpleWordEmbedder(vocab_size=100, emb_dim=layer_dim),
    encoder=BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim),
    attender=MlpAttender(input_dim=layer_dim, state_dim=layer_dim, hidden_dim=layer_dim),
    trg_embedder=SimpleWordEmbedder(vocab_size=100, emb_dim=layer_dim),
    decoder=MlpSoftmaxDecoder(input_dim=layer_dim,
                              trg_embed_dim=layer_dim,
                              rnn_layer=UniLSTMSeqTransducer(input_dim=layer_dim,
                                                             hidden_dim=layer_dim,
                                                             decoder_input_dim=layer_dim,
                                                             yaml_path="model.decoder.rnn_layer"),
                              mlp_layer=MLP(input_dim=layer_dim,
                                            hidden_dim=layer_dim,
                                            decoder_rnn_dim=layer_dim,
                                            vocab_size=100,
                                            yaml_path="model.decoder.rnn_layer"),
                              bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
  )
  train_args['dev_tasks'] = [LossEvalTask(model=train_args['model'],
                                          src_file="examples/data/head.ja",
                                          ref_file="examples/data/head.en",
                                          batcher=batcher)]
  train_args['run_for_epochs'] = 1
  train_args['trainer'] = AdamTrainer(alpha=0.1)
  train_args['batcher'] = batcher
  training_regimen = xnmt.training_regimen.SimpleTrainingRegimen(**train_args)
  for _ in range(50):
    training_regimen.run_training(save_fct=lambda: None, update_weights=True)
  self.assertAlmostEqual(0.0,
                         training_regimen.logger.epoch_loss.sum() / training_regimen.logger.epoch_words,
                         places=2)
def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1):
  """
  Tests whether single loss equals batch loss.
  Truncating src / trg sents to same length so no masking is necessary.
  """
  batch_size = 5
  src_sents = self.src_data[:batch_size]
  src_min = min([len(x) for x in src_sents])
  src_sents_trunc = [s[:src_min] for s in src_sents]
  for single_sent in src_sents_trunc:
    single_sent[src_min - 1] = Vocab.ES
    while len(single_sent) % pad_src_to_multiple != 0:
      single_sent.append(Vocab.ES)

  trg_sents = self.trg_data[:batch_size]
  trg_min = min([len(x) for x in trg_sents])
  trg_sents_trunc = [s[:trg_min] for s in trg_sents]
  for single_sent in trg_sents_trunc:
    single_sent[trg_min - 1] = Vocab.ES

  single_loss = 0.0
  for sent_id in range(batch_size):
    dy.renew_cg()
    train_loss = model.calc_loss(src=src_sents_trunc[sent_id],
                                 trg=trg_sents_trunc[sent_id],
                                 loss_calculator=LossCalculator()).value()
    single_loss += train_loss

  dy.renew_cg()
  batched_loss = model.calc_loss(src=mark_as_batch(src_sents_trunc),
                                 trg=mark_as_batch(trg_sents_trunc),
                                 loss_calculator=LossCalculator()).value()
  self.assertAlmostEqual(single_loss, sum(batched_loss), places=4)
def test_single(self):
  dy.renew_cg()
  self.model.initialize_generator()
  outputs = self.model.generate_output(self.training_corpus.train_src_data[0], 0,
                                       forced_trg_ids=self.training_corpus.train_trg_data[0])
  output_score = outputs[0].score

  dy.renew_cg()
  train_loss = self.model.calc_loss(src=self.training_corpus.train_src_data[0],
                                    trg=outputs[0].actions,
                                    loss_calculator=LossCalculator()).value()
  self.assertAlmostEqual(-output_score, train_loss, places=5)
def test_train_dev_loss_equal(self):
  layer_dim = 512
  batcher = SrcBatcher(batch_size=5, break_ties_randomly=False)
  train_args = {}
  train_args['src_file'] = "examples/data/head.ja"
  train_args['trg_file'] = "examples/data/head.en"
  train_args['loss_calculator'] = LossCalculator()
  train_args['model'] = DefaultTranslator(
    src_reader=PlainTextReader(),
    trg_reader=PlainTextReader(),
    src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
    encoder=BiLSTMSeqTransducer(input_dim=layer_dim, hidden_dim=layer_dim),
    attender=MlpAttender(input_dim=layer_dim, state_dim=layer_dim, hidden_dim=layer_dim),
    trg_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
    decoder=MlpSoftmaxDecoder(input_dim=layer_dim,
                              lstm_dim=layer_dim,
                              mlp_hidden_dim=layer_dim,
                              trg_embed_dim=layer_dim,
                              vocab_size=100,
                              bridge=CopyBridge(dec_layers=1, dec_dim=layer_dim)),
  )
  train_args['dev_tasks'] = [LossEvalTask(model=train_args['model'],
                                          src_file="examples/data/head.ja",
                                          ref_file="examples/data/head.en",
                                          batcher=batcher)]
  train_args['trainer'] = None
  train_args['batcher'] = batcher
  train_args['run_for_epochs'] = 1
  training_regimen = xnmt.training_regimen.SimpleTrainingRegimen(**train_args)
  training_regimen.run_training(save_fct=lambda: None, update_weights=False)
  # With update_weights=False the parameters never change, so the training loss
  # and the dev loss computed on the same data must coincide.
  self.assertAlmostEqual(training_regimen.logger.epoch_loss.sum() / training_regimen.logger.epoch_words,
                         training_regimen.logger.dev_score.loss,
                         places=5)
def test_overfitting(self):
  self.exp_global = ExpGlobal(dynet_param_collection=NonPersistentParamCollection(),
                              dropout=0.0)
  self.exp_global.default_layer_dim = 16
  batcher = SrcBatcher(batch_size=10, break_ties_randomly=False)
  train_args = {}
  train_args['src_file'] = "examples/data/head.ja"
  train_args['trg_file'] = "examples/data/head.en"
  train_args['loss_calculator'] = LossCalculator()
  train_args['model'] = DefaultTranslator(
    src_reader=PlainTextReader(),
    trg_reader=PlainTextReader(),
    src_embedder=SimpleWordEmbedder(self.exp_global, vocab_size=100),
    encoder=BiLSTMSeqTransducer(self.exp_global),
    attender=MlpAttender(self.exp_global),
    trg_embedder=SimpleWordEmbedder(self.exp_global, vocab_size=100),
    decoder=MlpSoftmaxDecoder(self.exp_global, vocab_size=100,
                              bridge=CopyBridge(exp_global=self.exp_global, dec_layers=1)),
  )
  train_args['dev_tasks'] = [LossEvalTask(model=train_args['model'],
                                          src_file="examples/data/head.ja",
                                          ref_file="examples/data/head.en",
                                          batcher=batcher)]
  train_args['run_for_epochs'] = 1
  train_args['trainer'] = AdamTrainer(self.exp_global, alpha=0.1)
  train_args['batcher'] = batcher
  training_regimen = xnmt.training_regimen.SimpleTrainingRegimen(exp_global=self.exp_global,
                                                                 **train_args)
  training_regimen.exp_global = self.exp_global
  for _ in range(50):
    training_regimen.run_training(save_fct=lambda: None, update_weights=True)
  self.assertAlmostEqual(0.0,
                         training_regimen.logger.epoch_loss.sum() / training_regimen.logger.epoch_words,
                         places=2)
def test_overfitting(self):
  self.model_context = ModelContext()
  self.model_context.dynet_param_collection = NonPersistentParamCollection()
  self.model_context.default_layer_dim = 16
  train_args = {}
  training_corpus = BilingualTrainingCorpus(train_src="examples/data/head.ja",
                                            train_trg="examples/data/head.en",
                                            dev_src="examples/data/head.ja",
                                            dev_trg="examples/data/head.en")
  train_args['corpus_parser'] = BilingualCorpusParser(training_corpus=training_corpus,
                                                      src_reader=PlainTextReader(),
                                                      trg_reader=PlainTextReader())
  train_args['loss_calculator'] = LossCalculator()
  train_args['model'] = DefaultTranslator(
    src_embedder=SimpleWordEmbedder(self.model_context, vocab_size=100),
    encoder=BiLSTMSeqTransducer(self.model_context),
    attender=MlpAttender(self.model_context),
    trg_embedder=SimpleWordEmbedder(self.model_context, vocab_size=100),
    decoder=MlpSoftmaxDecoder(self.model_context, vocab_size=100),
  )
  train_args['run_for_epochs'] = 1
  train_args['trainer'] = AdamTrainer(self.model_context, alpha=0.1)
  train_args['batcher'] = SrcBatcher(batch_size=10, break_ties_randomly=False)
  training_regimen = xnmt.training_regimen.SimpleTrainingRegimen(yaml_context=self.model_context,
                                                                 **train_args)
  training_regimen.model_context = self.model_context
  for _ in range(50):
    training_regimen.run_training(update_weights=True)
  self.assertAlmostEqual(0.0,
                         training_regimen.logger.epoch_loss.loss_values['loss'] / training_regimen.logger.epoch_words,
                         places=2)
def __call__(self, corpus_parser, generator, batcher, src_file=None,
             trg_file=None, candidate_id_file=None):
  """
  :param src_file: path of input src file to be translated
  :param trg_file: path of file where trg translations will be written
  :param batcher: batcher used to pre-process source sentences before decoding
  :param candidate_id_file: if we are doing something like retrieval where we
         select from fixed candidates, sometimes we want to limit our candidates
         to a certain subset of the full set. this setting allows us to do this.
  """
  args = dict(model_file=self.model_file,
              src_file=src_file or self.src_file,
              trg_file=trg_file or self.trg_file,
              ref_file=self.ref_file,
              max_src_len=self.max_src_len,
              input_format=self.input_format,
              post_process=self.post_process,
              candidate_id_file=candidate_id_file,
              report_path=self.report_path,
              report_type=self.report_type,
              beam=self.beam,
              max_len=self.max_len,
              len_norm_type=self.len_norm_type,
              mode=self.mode)
  is_reporting = issubclass(generator.__class__, Reportable) and args["report_path"] is not None
  # Corpus
  src_corpus = list(corpus_parser.src_reader.read_sents(args["src_file"]))
  # Get reference if it exists and is necessary
  if args["mode"] == "forced" or args["mode"] == "forceddebug":
    if args["ref_file"] is None:
      raise RuntimeError("When performing {} decoding, must specify reference file".format(args["mode"]))
    ref_corpus = list(corpus_parser.trg_reader.read_sents(args["ref_file"]))
  else:
    ref_corpus = None
  # Vocab
  src_vocab = corpus_parser.src_reader.vocab if hasattr(corpus_parser.src_reader, "vocab") else None
  trg_vocab = corpus_parser.trg_reader.vocab if hasattr(corpus_parser.trg_reader, "vocab") else None
  # Perform initialization
  generator.set_train(False)
  generator.initialize_generator(**args)
  # TODO: Structure it better. Not only Translator can have post processes.
  if issubclass(generator.__class__, Translator):
    generator.set_post_processor(self.get_output_processor())
    generator.set_trg_vocab(trg_vocab)
    generator.set_reporting_src_vocab(src_vocab)
  if is_reporting:
    generator.set_report_resource("src_vocab", src_vocab)
    generator.set_report_resource("trg_vocab", trg_vocab)
  # If we're debugging, calculate the loss for each target sentence
  ref_scores = None
  if args["mode"] == 'forceddebug':
    some_batcher = xnmt.batcher.InOrderBatcher(32)  # Arbitrary
    batched_src, batched_ref = some_batcher.pack(src_corpus, ref_corpus)
    ref_scores = []
    for src, ref in zip(batched_src, batched_ref):
      dy.renew_cg()
      loss_expr = generator.calc_loss(src, ref, loss_calculator=LossCalculator())
      ref_scores.extend(loss_expr.value())
    ref_scores = [-x for x in ref_scores]
  # Perform generation of output
  with io.open(args["trg_file"], 'wt', encoding='utf-8') as fp:  # Saving the translated output to a trg file
    src_ret = []
    for i, src in enumerate(src_corpus):
      batcher.add_single_batch(src_curr=[src], trg_curr=None, src_ret=src_ret, trg_ret=None)
      src = src_ret.pop()[0]
      # Do the decoding
      if args["max_src_len"] is not None and len(src) > args["max_src_len"]:
        output_txt = NO_DECODING_ATTEMPTED
      else:
        dy.renew_cg()
        ref_ids = ref_corpus[i] if ref_corpus is not None else None
        output = generator.generate_output(src, i, forced_trg_ids=ref_ids)
        # If debugging forced decoding, make sure it matches
        if ref_scores is not None and (abs(output[0].score - ref_scores[i]) / abs(ref_scores[i])) > 1e-5:
          print('Forced decoding score {} and loss {} do not match at sentence {}'.format(
            output[0].score, ref_scores[i], i))
        output_txt = output[0].plaintext
      # Printing to trg file
      fp.write(u"{}\n".format(output_txt))
def __call__(self, generator, src_file=None, trg_file=None, candidate_id_file=None):
  """
  Args:
    generator (GeneratorModel): the model to be used
    src_file (str): path of input src file to be translated
    trg_file (str): path of file where trg translations will be written
    candidate_id_file (str): if we are doing something like retrieval where we
      select from fixed candidates, sometimes we want to limit our candidates
      to a certain subset of the full set. this setting allows us to do this.
  """
  args = dict(src_file=src_file or self.src_file,
              trg_file=trg_file or self.trg_file,
              ref_file=self.ref_file,
              max_src_len=self.max_src_len,
              post_process=self.post_process,
              candidate_id_file=candidate_id_file,
              report_path=self.report_path,
              report_type=self.report_type,
              beam=self.beam,
              max_len=self.max_len,
              len_norm_type=self.len_norm_type,
              mode=self.mode)
  is_reporting = issubclass(generator.__class__, Reportable) and args["report_path"] is not None
  # Corpus
  src_corpus = list(generator.src_reader.read_sents(args["src_file"]))
  # Get reference if it exists and is necessary
  if args["mode"] == "forced" or args["mode"] == "forceddebug" or args["mode"] == "score":
    if args["ref_file"] is None:
      raise RuntimeError("When performing {} decoding, must specify reference file".format(args["mode"]))
    score_src_corpus = []
    ref_corpus = []
    with open(args["ref_file"], "r", encoding="utf-8") as fp:
      for line in fp:
        if args["mode"] == "score":
          nbest = line.split("|||")
          assert len(nbest) > 1, "When performing scoring, ref_file must have nbest format 'index ||| hypothesis'"
          src_index = int(nbest[0].strip())
          assert src_index < len(src_corpus), \
            "The src_file has only {} instances, nbest file has invalid src_index {}".format(len(src_corpus), src_index)
          score_src_corpus.append(src_corpus[src_index])
          trg_input = generator.trg_reader.read_sent(nbest[1].strip())
        else:
          trg_input = generator.trg_reader.read_sent(line)
        ref_corpus.append(trg_input)
    if args["mode"] == "score":
      src_corpus = score_src_corpus
    else:
      if self.max_len and any(len(s) > self.max_len for s in ref_corpus):
        logger.warning("Forced decoding with some targets being longer than max_len. "
                       "Increase max_len to avoid unexpected behavior.")
  else:
    ref_corpus = None
  # Vocab
  src_vocab = generator.src_reader.vocab if hasattr(generator.src_reader, "vocab") else None
  trg_vocab = generator.trg_reader.vocab if hasattr(generator.trg_reader, "vocab") else None
  # Perform initialization
  generator.set_train(False)
  generator.initialize_generator(**args)
  if hasattr(generator, "set_post_processor"):
    generator.set_post_processor(self.get_output_processor())
  if hasattr(generator, "set_trg_vocab"):
    generator.set_trg_vocab(trg_vocab)
  if hasattr(generator, "set_reporting_src_vocab"):
    generator.set_reporting_src_vocab(src_vocab)
  if is_reporting:
    generator.set_report_resource("src_vocab", src_vocab)
    generator.set_report_resource("trg_vocab", trg_vocab)
  # If we're debugging, calculate the loss for each target sentence
  ref_scores = None
  if args["mode"] == 'forceddebug' or args["mode"] == 'score':
    some_batcher = xnmt.batcher.InOrderBatcher(32)  # Arbitrary
    batched_src, batched_ref = some_batcher.pack(src_corpus, ref_corpus)
    ref_scores = []
    for src, ref in zip(batched_src, batched_ref):
      dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                  check_validity=settings.CHECK_VALIDITY)
      loss_expr = generator.calc_loss(src, ref, loss_calculator=LossCalculator())
      if isinstance(loss_expr.value(), Iterable):
        ref_scores.extend(loss_expr.value())
      else:
        ref_scores.append(loss_expr.value())
    ref_scores = [-x for x in ref_scores]
  # Make the parent directory if necessary
  make_parent_dir(args["trg_file"])
  # Perform generation of output
  if args["mode"] != 'score':
    with open(args["trg_file"], 'wt', encoding='utf-8') as fp:  # Saving the translated output to a trg file
      src_ret = []
      for i, src in enumerate(src_corpus):
        # This is necessary when the batcher does some sort of pre-processing, e.g.
        # when the batcher pads to a particular number of dimensions
        if self.batcher:
          self.batcher.add_single_batch(src_curr=[src], trg_curr=None, src_ret=src_ret, trg_ret=None)
          src = src_ret.pop()[0]
        # Do the decoding
        if args["max_src_len"] is not None and len(src) > args["max_src_len"]:
          output_txt = NO_DECODING_ATTEMPTED
        else:
          dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                      check_validity=settings.CHECK_VALIDITY)
          ref_ids = ref_corpus[i] if ref_corpus is not None else None
          output = generator.generate_output(src, i, forced_trg_ids=ref_ids)
          # If debugging forced decoding, make sure it matches
          if ref_scores is not None and (abs(output[0].score - ref_scores[i]) / abs(ref_scores[i])) > 1e-5:
            logger.error(f'Forced decoding score {output[0].score} and loss {ref_scores[i]} do not match at sentence {i}')
          output_txt = output[0].plaintext
        # Printing to trg file
        fp.write(f"{output_txt}\n")
  else:
    with open(args["trg_file"], 'wt', encoding='utf-8') as fp:
      with open(args["ref_file"], "r", encoding="utf-8") as nbest_fp:
        for nbest, score in zip(nbest_fp, ref_scores):
          fp.write("{} ||| score={}\n".format(nbest.strip(), score))
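# Illustration (not from the source file): the n-best format that "score" mode
# expects in ref_file, per the assertion above. Each line is
# "src_index ||| hypothesis", e.g.:
#
#   0 ||| the cat sat on the mat
#   0 ||| a cat sat on the mat
#   1 ||| he went to school today
#
# A minimal sketch of the per-line parsing, matching the split/strip logic above:
line = "0 ||| the cat sat on the mat"
nbest = line.split("|||")
src_index = int(nbest[0].strip())   # index into src_corpus -> 0
hypothesis = nbest[1].strip()       # scored against src_corpus[src_index]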
def __init__(self, yaml_context, corpus_parser, model, glob={}, dev_every=0,
             batcher=None, loss_calculator=None, pretrained_model_file="",
             src_format="text", run_for_epochs=None, lr_decay=1.0,
             lr_decay_times=3, patience=1, initial_patience=None,
             dev_metrics="", schedule_metric="loss", restart_trainer=False,
             reload_command=None, name=None, inference=None):
  """
  :param yaml_context:
  :param corpus_parser: an input.InputReader object
  :param model: a generator.GeneratorModel object
  :param dev_every (int): dev checkpoints every n sentences (0 for only after epoch)
  :param batcher: Type of batcher. Defaults to SrcBatcher of batch size 32.
  :param loss_calculator:
  :param pretrained_model_file: Path of pre-trained model file
  :param src_format: Format of input data: text/contvec
  :param lr_decay (float):
  :param lr_decay_times (int): Early stopping after decaying learning rate a certain number of times
  :param patience (int): apply LR decay after dev scores haven't improved over this many checkpoints
  :param initial_patience (int): if given, allows adjusting patience for the first LR decay
  :param dev_metrics: Comma-separated list of evaluation metrics (bleu/wer/cer)
  :param schedule_metric: determine learning schedule based on this dev_metric (loss/bleu/wer/cer)
  :param restart_trainer: Restart trainer (useful for Adam) and revert weights to best dev checkpoint
         when applying LR decay (https://arxiv.org/pdf/1706.09733.pdf)
  :param reload_command: Command to change the input data after each epoch.
         --epoch EPOCH_NUM will be appended to the command.
         To just reload the data after each epoch set the command to 'true'.
  :param name: will be prepended to log outputs if given
  :param inference: used for inference during dev checkpoints if dev_metrics are specified
  """
  assert yaml_context is not None
  self.yaml_context = yaml_context
  self.model_file = self.yaml_context.dynet_param_collection.model_file
  self.yaml_serializer = YamlSerializer()
  if lr_decay > 1.0 or lr_decay <= 0.0:
    raise RuntimeError("illegal lr_decay, must satisfy: 0.0 < lr_decay <= 1.0")
  self.lr_decay = lr_decay
  self.patience = patience
  self.initial_patience = initial_patience
  self.lr_decay_times = lr_decay_times
  self.restart_trainer = restart_trainer
  self.run_for_epochs = run_for_epochs
  self.early_stopping_reached = False
  # training state
  self.training_state = TrainingState()
  self.evaluators = [s.lower() for s in dev_metrics.split(",") if s.strip() != ""]
  if schedule_metric.lower() not in self.evaluators:
    self.evaluators.append(schedule_metric.lower())
  if "loss" not in self.evaluators:
    self.evaluators.append("loss")
  if dev_metrics:
    self.inference = inference or SimpleInference()
  self.reload_command = reload_command
  if reload_command is not None:
    self._augmentation_handle = None
    self._augment_data_initial()
  self.model = model
  self.corpus_parser = corpus_parser
  self.loss_calculator = loss_calculator or LossCalculator(MLELoss())
  self.pretrained_model_file = pretrained_model_file
  if self.pretrained_model_file:
    self.yaml_context.dynet_param_collection.load_from_data_file(self.pretrained_model_file + '.data')
  self.batcher = batcher or SrcBatcher(32)
  if src_format == "contvec":
    self.batcher.pad_token = np.zeros(self.model.src_embedder.emb_dim)
  self.pack_batches()
  self.logger = BatchLossTracker(self, dev_every, name)
  self.schedule_metric = schedule_metric.lower()
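# Standalone sketch (plain Python, no xnmt state) of how the constructor above
# derives self.evaluators from a hypothetical dev_metrics string: metrics are
# split on commas, the schedule metric is guaranteed to be present, and "loss"
# is always tracked.
dev_metrics, schedule_metric = "bleu,wer", "bleu"
evaluators = [s.lower() for s in dev_metrics.split(",") if s.strip() != ""]
if schedule_metric.lower() not in evaluators:
  evaluators.append(schedule_metric.lower())
if "loss" not in evaluators:
  evaluators.append("loss")
assert evaluators == ["bleu", "wer", "loss"]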
def __init__(self, model, src_file=None, trg_file=None, dev_every=0,
             batcher=bare(SrcBatcher, batch_size=32), loss_calculator=None,
             run_for_epochs=None, lr_decay=1.0, lr_decay_times=3, patience=1,
             initial_patience=None, dev_tasks=None, restart_trainer=False,
             reload_command=None, name=None, sample_train_sents=None,
             max_num_train_sents=None, max_src_len=None, max_trg_len=None,
             exp_global=Ref(Path("exp_global"))):
  """
  Args:
    exp_global:
    model: a generator.GeneratorModel object
    src_file: The file for the source data.
    trg_file: The file for the target data.
    dev_every (int): dev checkpoints every n sentences (0 for only after epoch)
    batcher: Type of batcher
    loss_calculator:
    lr_decay (float):
    lr_decay_times (int): Early stopping after decaying learning rate a certain number of times
    patience (int): apply LR decay after dev scores haven't improved over this many checkpoints
    initial_patience (int): if given, allows adjusting patience for the first LR decay
    dev_tasks: A list of tasks to run on the development set
    restart_trainer: Restart trainer (useful for Adam) and revert weights to best dev checkpoint
      when applying LR decay (https://arxiv.org/pdf/1706.09733.pdf)
    reload_command: Command to change the input data after each epoch.
      --epoch EPOCH_NUM will be appended to the command.
      To just reload the data after each epoch set the command to 'true'.
    sample_train_sents:
    max_num_train_sents:
    max_src_len:
    max_trg_len:
    name: will be prepended to log outputs if given
  """
  self.exp_global = exp_global
  self.model_file = self.exp_global.dynet_param_collection.model_file
  self.src_file = src_file
  self.trg_file = trg_file
  self.dev_tasks = dev_tasks
  if lr_decay > 1.0 or lr_decay <= 0.0:
    raise RuntimeError("illegal lr_decay, must satisfy: 0.0 < lr_decay <= 1.0")
  self.lr_decay = lr_decay
  self.patience = patience
  self.initial_patience = initial_patience
  self.lr_decay_times = lr_decay_times
  self.restart_trainer = restart_trainer
  self.run_for_epochs = run_for_epochs
  self.early_stopping_reached = False
  # training state
  self.training_state = TrainingState()
  self.reload_command = reload_command
  self.model = model
  self.loss_calculator = loss_calculator or LossCalculator(MLELoss())
  self.sample_train_sents = sample_train_sents
  self.max_num_train_sents = max_num_train_sents
  self.max_src_len = max_src_len
  self.max_trg_len = max_trg_len
  self.batcher = batcher
  self.logger = BatchLossTracker(self, dev_every, name)
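# Hedged usage sketch of the constructor above, following the pattern of the
# overfitting tests in this section. `model`, `batcher`, and `dev_tasks` are
# assumed to be constructed as in those tests; parameter values are illustrative.
regimen = SimpleTrainingRegimen(model=model,
                                src_file="examples/data/head.ja",
                                trg_file="examples/data/head.en",
                                batcher=batcher,
                                dev_tasks=dev_tasks,
                                run_for_epochs=1,
                                lr_decay=0.5,      # must satisfy 0.0 < lr_decay <= 1.0
                                lr_decay_times=3,  # early-stop after the third decay
                                patience=1)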