Example #1
  def __init__(self, model, src_file=None, trg_file=None, dev_every=0,
               batcher=bare(SrcBatcher, batch_size=32), loss_calculator=None,
               run_for_epochs=None, lr_decay=1.0, lr_decay_times=3, patience=1,
               initial_patience=None, dev_tasks=None, restart_trainer=False,
               reload_command=None, name=None, sample_train_sents: Optional[int] = None,
               max_num_train_sents=None, max_src_len=None, max_trg_len=None):
    self.src_file = src_file
    self.trg_file = trg_file
    self.dev_tasks = dev_tasks

    if lr_decay > 1.0 or lr_decay <= 0.0:
      raise RuntimeError("illegal lr_decay, must satisfy: 0.0 < lr_decay <= 1.0")
    self.lr_decay = lr_decay
    self.patience = patience
    self.initial_patience = initial_patience
    self.lr_decay_times = lr_decay_times
    self.restart_trainer = restart_trainer
    self.run_for_epochs = run_for_epochs

    self.early_stopping_reached = False
    # training state
    self.training_state = TrainingState()

    self.reload_command = reload_command

    self.model = model
    self.loss_calculator = loss_calculator or LossCalculator(MLELoss())

    self.sample_train_sents = sample_train_sents
    self.max_num_train_sents = max_num_train_sents
    self.max_src_len = max_src_len
    self.max_trg_len = max_trg_len

    self.batcher = batcher
    self.logger = BatchLossTracker(self, dev_every, name)
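The constructor above falls back to LossCalculator(MLELoss()) when no loss_calculator is given. A minimal sketch of both call patterns, assuming the class is xnmt's SimpleTrainingRegimen as instantiated in Examples #5 and #8 (my_model stands in for a constructed DefaultTranslator):

# Default: MLE loss is used because loss_calculator is omitted.
regimen = xnmt.training_regimen.SimpleTrainingRegimen(
    model=my_model,
    src_file="examples/data/head.ja",
    trg_file="examples/data/head.en")

# Explicit: pass a LossCalculator to override the default.
regimen = xnmt.training_regimen.SimpleTrainingRegimen(
    model=my_model,
    src_file="examples/data/head.ja",
    trg_file="examples/data/head.en",
    loss_calculator=LossCalculator(MLELoss()))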
Example #2
 def test_single(self):
   dy.renew_cg()
   train_loss = self.model.calc_loss(src=self.src_data[0],
                                     trg=self.trg_data[0],
                                     loss_calculator=LossCalculator()).value()
   dy.renew_cg()
   self.model.initialize_generator(beam=1)
   outputs = self.model.generate_output(self.src_data[0], 0,
                                        forced_trg_ids=self.trg_data[0])
   self.assertAlmostEqual(-outputs[0].score, train_loss, places=4)
Example #3
    def assert_single_loss_equals_batch_loss(self,
                                             model,
                                             pad_src_to_multiple=1):
        """
    Tests whether single loss equals batch loss.
    Here we don't truncate the target side and use masking.
    """
        batch_size = 5
        src_sents = self.src_data[:batch_size]
        src_min = min([len(x) for x in src_sents])
        src_sents_trunc = [s[:src_min] for s in src_sents]
        for single_sent in src_sents_trunc:
            single_sent[src_min - 1] = Vocab.ES
            while len(single_sent) % pad_src_to_multiple != 0:
                single_sent.append(Vocab.ES)
        trg_sents = self.trg_data[:batch_size]
        trg_max = max([len(x) for x in trg_sents])
        trg_masks = Mask(np.zeros([batch_size, trg_max]))
        for i in range(batch_size):
            for j in range(len(trg_sents[i]), trg_max):
                trg_masks.np_arr[i, j] = 1.0
        trg_sents_padded = [[w for w in s] + [Vocab.ES] * (trg_max - len(s))
                            for s in trg_sents]

        single_loss = 0.0
        for sent_id in range(batch_size):
            dy.renew_cg()
            train_loss = model.calc_loss(
                src=src_sents_trunc[sent_id],
                trg=trg_sents[sent_id],
                loss_calculator=LossCalculator()).value()
            single_loss += train_loss

        dy.renew_cg()

        batched_loss = model.calc_loss(
            src=mark_as_batch(src_sents_trunc),
            trg=mark_as_batch(trg_sents_padded, trg_masks),
            loss_calculator=LossCalculator()).value()
        self.assertAlmostEqual(single_loss, sum(batched_loss), places=4)
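For illustration only, this is the target-side mask the loop above would produce for a toy batch with target lengths 3 and 5, written in plain numpy and independent of xnmt's Mask class:

import numpy as np

trg_lens = [3, 5]
trg_max = max(trg_lens)                    # 5
mask = np.zeros([len(trg_lens), trg_max])
for i, length in enumerate(trg_lens):
    mask[i, length:] = 1.0                 # 1.0 marks padding, 0.0 marks real tokens
# mask == [[0., 0., 0., 1., 1.],
#          [0., 0., 0., 0., 0.]]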
Example #4
File: eval_task.py  Project: bastings/xnmt
 def __init__(self, src_file, ref_file, model=Ref("model"),
               batcher=Ref("train.batcher", default=None),
               loss_calculator=None, max_src_len=None, max_trg_len=None,
               desc=None):
   self.model = model
   self.loss_calculator = loss_calculator or LossCalculator(MLELoss())
   self.src_file = src_file
   self.ref_file = ref_file
   self.batcher = batcher
   self.src_data = None
   self.max_src_len = max_src_len
   self.max_trg_len = max_trg_len
   self.desc = desc
Example #5
 def test_overfitting(self):
     layer_dim = 16
     batcher = SrcBatcher(batch_size=10, break_ties_randomly=False)
     train_args = {}
     train_args['src_file'] = "examples/data/head.ja"
     train_args['trg_file'] = "examples/data/head.en"
     train_args['loss_calculator'] = LossCalculator()
     train_args['model'] = DefaultTranslator(
         src_reader=PlainTextReader(),
         trg_reader=PlainTextReader(),
         src_embedder=SimpleWordEmbedder(vocab_size=100, emb_dim=layer_dim),
         encoder=BiLSTMSeqTransducer(input_dim=layer_dim,
                                     hidden_dim=layer_dim),
         attender=MlpAttender(input_dim=layer_dim,
                              state_dim=layer_dim,
                              hidden_dim=layer_dim),
         trg_embedder=SimpleWordEmbedder(vocab_size=100, emb_dim=layer_dim),
         decoder=MlpSoftmaxDecoder(input_dim=layer_dim,
                                   trg_embed_dim=layer_dim,
                                   rnn_layer=UniLSTMSeqTransducer(
                                       input_dim=layer_dim,
                                       hidden_dim=layer_dim,
                                       decoder_input_dim=layer_dim,
                                       yaml_path="model.decoder.rnn_layer"),
                                   mlp_layer=MLP(
                                       input_dim=layer_dim,
                                       hidden_dim=layer_dim,
                                       decoder_rnn_dim=layer_dim,
                                       vocab_size=100,
                                       yaml_path="model.decoder.rnn_layer"),
                                   bridge=CopyBridge(dec_dim=layer_dim,
                                                     dec_layers=1)),
     )
     train_args['dev_tasks'] = [
         LossEvalTask(model=train_args['model'],
                      src_file="examples/data/head.ja",
                      ref_file="examples/data/head.en",
                      batcher=batcher)
     ]
     train_args['run_for_epochs'] = 1
     train_args['trainer'] = AdamTrainer(alpha=0.1)
     train_args['batcher'] = batcher
     training_regimen = xnmt.training_regimen.SimpleTrainingRegimen(
         **train_args)
     for _ in range(50):
         training_regimen.run_training(save_fct=lambda: None,
                                       update_weights=True)
     self.assertAlmostEqual(0.0,
                            training_regimen.logger.epoch_loss.sum() /
                            training_regimen.logger.epoch_words,
                            places=2)
Example #6
    def assert_single_loss_equals_batch_loss(self,
                                             model,
                                             pad_src_to_multiple=1):
        """
    Tests whether single loss equals batch loss.
    Truncating src / trg sents to same length so no masking is necessary
    """
        batch_size = 5
        src_sents = self.src_data[:batch_size]
        src_min = min([len(x) for x in src_sents])
        src_sents_trunc = [s[:src_min] for s in src_sents]
        for single_sent in src_sents_trunc:
            single_sent[src_min - 1] = Vocab.ES
            while len(single_sent) % pad_src_to_multiple != 0:
                single_sent.append(Vocab.ES)
        trg_sents = self.trg_data[:batch_size]
        trg_min = min([len(x) for x in trg_sents])
        trg_sents_trunc = [s[:trg_min] for s in trg_sents]
        for single_sent in trg_sents_trunc:
            single_sent[trg_min - 1] = Vocab.ES

        single_loss = 0.0
        for sent_id in range(batch_size):
            dy.renew_cg()
            train_loss = model.calc_loss(
                src=src_sents_trunc[sent_id],
                trg=trg_sents_trunc[sent_id],
                loss_calculator=LossCalculator()).value()
            single_loss += train_loss

        dy.renew_cg()

        batched_loss = model.calc_loss(
            src=mark_as_batch(src_sents_trunc),
            trg=mark_as_batch(trg_sents_trunc),
            loss_calculator=LossCalculator()).value()
        self.assertAlmostEqual(single_loss, sum(batched_loss), places=4)
Example #7
    def test_single(self):
        dy.renew_cg()
        self.model.initialize_generator()
        outputs = self.model.generate_output(
            self.training_corpus.train_src_data[0],
            0,
            forced_trg_ids=self.training_corpus.train_trg_data[0])
        output_score = outputs[0].score

        dy.renew_cg()
        train_loss = self.model.calc_loss(
            src=self.training_corpus.train_src_data[0],
            trg=outputs[0].actions,
            loss_calculator=LossCalculator()).value()

        self.assertAlmostEqual(-output_score, train_loss, places=5)
Example #8
 def test_train_dev_loss_equal(self):
     layer_dim = 512
     batcher = SrcBatcher(batch_size=5, break_ties_randomly=False)
     train_args = {}
     train_args['src_file'] = "examples/data/head.ja"
     train_args['trg_file'] = "examples/data/head.en"
     train_args['loss_calculator'] = LossCalculator()
     train_args['model'] = DefaultTranslator(
         src_reader=PlainTextReader(),
         trg_reader=PlainTextReader(),
         src_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
         encoder=BiLSTMSeqTransducer(input_dim=layer_dim,
                                     hidden_dim=layer_dim),
         attender=MlpAttender(input_dim=layer_dim,
                              state_dim=layer_dim,
                              hidden_dim=layer_dim),
         trg_embedder=SimpleWordEmbedder(emb_dim=layer_dim, vocab_size=100),
         decoder=MlpSoftmaxDecoder(input_dim=layer_dim,
                                   lstm_dim=layer_dim,
                                   mlp_hidden_dim=layer_dim,
                                   trg_embed_dim=layer_dim,
                                   vocab_size=100,
                                   bridge=CopyBridge(dec_layers=1,
                                                     dec_dim=layer_dim)),
     )
     train_args['dev_tasks'] = [
         LossEvalTask(model=train_args['model'],
                      src_file="examples/data/head.ja",
                      ref_file="examples/data/head.en",
                      batcher=batcher)
     ]
     train_args['trainer'] = None
     train_args['batcher'] = batcher
     train_args['run_for_epochs'] = 1
     training_regimen = xnmt.training_regimen.SimpleTrainingRegimen(
         **train_args)
     training_regimen.run_training(save_fct=lambda: None,
                                   update_weights=False)
     self.assertAlmostEqual(training_regimen.logger.epoch_loss.sum() /
                            training_regimen.logger.epoch_words,
                            training_regimen.logger.dev_score.loss,
                            places=5)
Example #9
 def test_overfitting(self):
     self.exp_global = ExpGlobal(
         dynet_param_collection=NonPersistentParamCollection(), dropout=0.0)
     self.exp_global.default_layer_dim = 16
     batcher = SrcBatcher(batch_size=10, break_ties_randomly=False)
     train_args = {}
     train_args['src_file'] = "examples/data/head.ja"
     train_args['trg_file'] = "examples/data/head.en"
     train_args['loss_calculator'] = LossCalculator()
     train_args['model'] = DefaultTranslator(
         src_reader=PlainTextReader(),
         trg_reader=PlainTextReader(),
         src_embedder=SimpleWordEmbedder(self.exp_global, vocab_size=100),
         encoder=BiLSTMSeqTransducer(self.exp_global),
         attender=MlpAttender(self.exp_global),
         trg_embedder=SimpleWordEmbedder(self.exp_global, vocab_size=100),
         decoder=MlpSoftmaxDecoder(self.exp_global,
                                   vocab_size=100,
                                   bridge=CopyBridge(
                                       exp_global=self.exp_global,
                                       dec_layers=1)),
     )
     train_args['dev_tasks'] = [
         LossEvalTask(model=train_args['model'],
                      src_file="examples/data/head.ja",
                      ref_file="examples/data/head.en",
                      batcher=batcher)
     ]
     train_args['run_for_epochs'] = 1
     train_args['trainer'] = AdamTrainer(self.exp_global, alpha=0.1)
     train_args['batcher'] = batcher
     training_regimen = xnmt.training_regimen.SimpleTrainingRegimen(
         exp_global=self.exp_global, **train_args)
     training_regimen.exp_global = self.exp_global
     for _ in range(50):
         training_regimen.run_training(save_fct=lambda: None,
                                       update_weights=True)
     self.assertAlmostEqual(0.0,
                            training_regimen.logger.epoch_loss.sum() /
                            training_regimen.logger.epoch_words,
                            places=2)
Example #10
 def test_overfitting(self):
     self.model_context = ModelContext()
     self.model_context.dynet_param_collection = NonPersistentParamCollection(
     )
     self.model_context.default_layer_dim = 16
     train_args = {}
     training_corpus = BilingualTrainingCorpus(
         train_src="examples/data/head.ja",
         train_trg="examples/data/head.en",
         dev_src="examples/data/head.ja",
         dev_trg="examples/data/head.en")
     train_args['corpus_parser'] = BilingualCorpusParser(
         training_corpus=training_corpus,
         src_reader=PlainTextReader(),
         trg_reader=PlainTextReader())
     train_args['loss_calculator'] = LossCalculator()
     train_args['model'] = DefaultTranslator(
         src_embedder=SimpleWordEmbedder(self.model_context,
                                         vocab_size=100),
         encoder=BiLSTMSeqTransducer(self.model_context),
         attender=MlpAttender(self.model_context),
         trg_embedder=SimpleWordEmbedder(self.model_context,
                                         vocab_size=100),
         decoder=MlpSoftmaxDecoder(self.model_context, vocab_size=100),
     )
     train_args['run_for_epochs'] = 1
     train_args['trainer'] = AdamTrainer(self.model_context, alpha=0.1)
     train_args['batcher'] = SrcBatcher(batch_size=10,
                                        break_ties_randomly=False)
     training_regimen = xnmt.training_regimen.SimpleTrainingRegimen(
         yaml_context=self.model_context, **train_args)
     training_regimen.model_context = self.model_context
     for _ in range(50):
         training_regimen.run_training(update_weights=True)
     self.assertAlmostEqual(
         0.0,
         training_regimen.logger.epoch_loss.loss_values['loss'] /
         training_regimen.logger.epoch_words,
         places=2)
Example #11
File: inference.py  Project: nvog/xnmt
    def __call__(self,
                 corpus_parser,
                 generator,
                 batcher,
                 src_file=None,
                 trg_file=None,
                 candidate_id_file=None):
        """
    :param src_file: path of input src file to be translated
    :param trg_file: path of file where trg translatons will be written
    :param batcher:
    :param candidate_id_file: if we are doing something like retrieval where we select from fixed candidates, sometimes we want to limit our candidates to a certain subset of the full set. this setting allows us to do this.
    :param model_elements: If None, the model will be loaded from model_file. If set, should equal (corpus_parser, generator).
    """
        args = dict(model_file=self.model_file,
                    src_file=src_file or self.src_file,
                    trg_file=trg_file or self.trg_file,
                    ref_file=self.ref_file,
                    max_src_len=self.max_src_len,
                    input_format=self.input_format,
                    post_process=self.post_process,
                    candidate_id_file=candidate_id_file,
                    report_path=self.report_path,
                    report_type=self.report_type,
                    beam=self.beam,
                    max_len=self.max_len,
                    len_norm_type=self.len_norm_type,
                    mode=self.mode)

        is_reporting = issubclass(
            generator.__class__,
            Reportable) and args["report_path"] is not None
        # Corpus
        src_corpus = list(corpus_parser.src_reader.read_sents(
            args["src_file"]))
        # Get reference if it exists and is necessary
        if args["mode"] == "forced" or args["mode"] == "forceddebug":
            if args["ref_file"] == None:
                raise RuntimeError(
                    "When performing {} decoding, must specify reference file".
                    format(args["mode"]))
            ref_corpus = list(
                corpus_parser.trg_reader.read_sents(args["ref_file"]))
        else:
            ref_corpus = None
        # Vocab
        src_vocab = corpus_parser.src_reader.vocab if hasattr(
            corpus_parser.src_reader, "vocab") else None
        trg_vocab = corpus_parser.trg_reader.vocab if hasattr(
            corpus_parser.trg_reader, "vocab") else None
        # Perform initialization
        generator.set_train(False)
        generator.initialize_generator(**args)

        # TODO: Structure it better. not only Translator can have post processes
        if issubclass(generator.__class__, Translator):
            generator.set_post_processor(self.get_output_processor())
            generator.set_trg_vocab(trg_vocab)
            generator.set_reporting_src_vocab(src_vocab)

        if is_reporting:
            generator.set_report_resource("src_vocab", src_vocab)
            generator.set_report_resource("trg_vocab", trg_vocab)

        # If we're debugging, calculate the loss for each target sentence
        ref_scores = None
        if args["mode"] == 'forceddebug':
            some_batcher = xnmt.batcher.InOrderBatcher(32)  # Arbitrary
            batched_src, batched_ref = some_batcher.pack(
                src_corpus, ref_corpus)
            ref_scores = []
            for src, ref in zip(batched_src, batched_ref):
                dy.renew_cg()
                loss_expr = generator.calc_loss(
                    src, ref, loss_calculator=LossCalculator())
                ref_scores.extend(loss_expr.value())
            ref_scores = [-x for x in ref_scores]

        # Perform generation of output
        with io.open(args["trg_file"], 'wt', encoding='utf-8'
                     ) as fp:  # Saving the translated output to a trg file
            src_ret = []
            for i, src in enumerate(src_corpus):
                batcher.add_single_batch(src_curr=[src],
                                         trg_curr=None,
                                         src_ret=src_ret,
                                         trg_ret=None)
                src = src_ret.pop()[0]
                # Do the decoding
                if args["max_src_len"] is not None and len(
                        src) > args["max_src_len"]:
                    output_txt = NO_DECODING_ATTEMPTED
                else:
                    dy.renew_cg()
                    ref_ids = ref_corpus[i] if ref_corpus is not None else None
                    output = generator.generate_output(src,
                                                       i,
                                                       forced_trg_ids=ref_ids)
                    # If debugging forced decoding, make sure it matches
                    if ref_scores is not None and (
                            abs(output[0].score - ref_scores[i]) /
                            abs(ref_scores[i])) > 1e-5:
                        print(
                            'Forced decoding score {} and loss {} do not match at sentence {}'
                            .format(output[0].score, ref_scores[i], i))
                    output_txt = output[0].plaintext
                # Printing to trg file
                fp.write(u"{}\n".format(output_txt))
Example #12
    def __call__(self,
                 generator,
                 src_file=None,
                 trg_file=None,
                 candidate_id_file=None):
        """
    Args:
      generator (GeneratorModel): the model to be used
      src_file (str): path of input src file to be translated
      trg_file (str): path of file where trg translatons will be written
      candidate_id_file (str): if we are doing something like retrieval where we select from fixed candidates, sometimes we want to limit our candidates to a certain subset of the full set. this setting allows us to do this.
    """
        args = dict(src_file=src_file or self.src_file,
                    trg_file=trg_file or self.trg_file,
                    ref_file=self.ref_file,
                    max_src_len=self.max_src_len,
                    post_process=self.post_process,
                    candidate_id_file=candidate_id_file,
                    report_path=self.report_path,
                    report_type=self.report_type,
                    beam=self.beam,
                    max_len=self.max_len,
                    len_norm_type=self.len_norm_type,
                    mode=self.mode)

        is_reporting = issubclass(
            generator.__class__,
            Reportable) and args["report_path"] is not None
        # Corpus
        src_corpus = list(generator.src_reader.read_sents(args["src_file"]))
        # Get reference if it exists and is necessary
        if args["mode"] == "forced" or args["mode"] == "forceddebug" or args[
                "mode"] == "score":
            if args["ref_file"] == None:
                raise RuntimeError(
                    "When performing {} decoding, must specify reference file".
                    format(args["mode"]))
            score_src_corpus = []
            ref_corpus = []
            with open(args["ref_file"], "r", encoding="utf-8") as fp:
                for line in fp:
                    if args["mode"] == "score":
                        nbest = line.split("|||")
                        assert len(
                            nbest
                        ) > 1, "When performing scoring, ref_file must have nbest format 'index ||| hypothesis'"
                        src_index = int(nbest[0].strip())
                        assert src_index < len(
                            src_corpus
                        ), "The src_file has only {} instances, nbest file has invalid src_index {}".format(
                            len(src_corpus), src_index)
                        score_src_corpus.append(src_corpus[src_index])
                        trg_input = generator.trg_reader.read_sent(
                            nbest[1].strip())
                    else:
                        trg_input = generator.trg_reader.read_sent(line)
                    ref_corpus.append(trg_input)
            if args["mode"] == "score":
                src_corpus = score_src_corpus
            else:
                if self.max_len and any(
                        len(s) > self.max_len for s in ref_corpus):
                    logger.warning(
                        "Forced decoding with some targets being longer than max_len. Increase max_len to avoid unexpected behavior."
                    )
        else:
            ref_corpus = None
        # Vocab
        src_vocab = generator.src_reader.vocab if hasattr(
            generator.src_reader, "vocab") else None
        trg_vocab = generator.trg_reader.vocab if hasattr(
            generator.trg_reader, "vocab") else None
        # Perform initialization
        generator.set_train(False)
        generator.initialize_generator(**args)

        if hasattr(generator, "set_post_processor"):
            generator.set_post_processor(self.get_output_processor())
        if hasattr(generator, "set_trg_vocab"):
            generator.set_trg_vocab(trg_vocab)
        if hasattr(generator, "set_reporting_src_vocab"):
            generator.set_reporting_src_vocab(src_vocab)

        if is_reporting:
            generator.set_report_resource("src_vocab", src_vocab)
            generator.set_report_resource("trg_vocab", trg_vocab)

        # If we're debugging, calculate the loss for each target sentence
        ref_scores = None
        if args["mode"] == 'forceddebug' or args["mode"] == 'score':
            some_batcher = xnmt.batcher.InOrderBatcher(32)  # Arbitrary
            batched_src, batched_ref = some_batcher.pack(
                src_corpus, ref_corpus)
            ref_scores = []
            for src, ref in zip(batched_src, batched_ref):
                dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                            check_validity=settings.CHECK_VALIDITY)
                loss_expr = generator.calc_loss(
                    src, ref, loss_calculator=LossCalculator())
                if isinstance(loss_expr.value(), Iterable):
                    ref_scores.extend(loss_expr.value())
                else:
                    ref_scores.append(loss_expr.value())
            ref_scores = [-x for x in ref_scores]

        # Make the parent directory if necessary
        make_parent_dir(args["trg_file"])

        # Perform generation of output
        if args["mode"] != 'score':
            with open(args["trg_file"], 'wt', encoding='utf-8'
                      ) as fp:  # Saving the translated output to a trg file
                src_ret = []
                for i, src in enumerate(src_corpus):
                    # This is necessary when the batcher does some sort of pre-processing, e.g.
                    # when the batcher pads to a particular number of dimensions
                    if self.batcher:
                        self.batcher.add_single_batch(src_curr=[src],
                                                      trg_curr=None,
                                                      src_ret=src_ret,
                                                      trg_ret=None)
                        src = src_ret.pop()[0]
                    # Do the decoding
                    if args["max_src_len"] is not None and len(
                            src) > args["max_src_len"]:
                        output_txt = NO_DECODING_ATTEMPTED
                    else:
                        dy.renew_cg(
                            immediate_compute=settings.IMMEDIATE_COMPUTE,
                            check_validity=settings.CHECK_VALIDITY)
                        ref_ids = ref_corpus[i] if ref_corpus is not None else None
                        output = generator.generate_output(
                            src, i, forced_trg_ids=ref_ids)
                        # If debugging forced decoding, make sure it matches
                        if ref_scores is not None and (
                                abs(output[0].score - ref_scores[i]) /
                                abs(ref_scores[i])) > 1e-5:
                            logger.error(
                                f'Forced decoding score {output[0].score} and loss {ref_scores[i]} do not match at sentence {i}'
                            )
                        output_txt = output[0].plaintext
                    # Printing to trg file
                    fp.write(f"{output_txt}\n")
        else:
            with open(args["trg_file"], 'wt', encoding='utf-8') as fp:
                with open(args["ref_file"], "r", encoding="utf-8") as nbest_fp:
                    for nbest, score in zip(nbest_fp, ref_scores):
                        fp.write("{} ||| score={}\n".format(
                            nbest.strip(), score))
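In score mode the ref_file is read in n-best format, 'index ||| hypothesis'. A minimal sketch of the per-line parsing performed in the loop above (the sample line is invented):

line = "3 ||| this is a hypothesis"
nbest = line.split("|||")
assert len(nbest) > 1, "ref_file must have nbest format 'index ||| hypothesis'"
src_index = int(nbest[0].strip())   # 3, an index into src_corpus
hypothesis = nbest[1].strip()       # "this is a hypothesis", passed to trg_reader.read_sent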
Example #13
File: training_task.py  Project: nvog/xnmt
    def __init__(self,
                 yaml_context,
                 corpus_parser,
                 model,
                 glob={},
                 dev_every=0,
                 batcher=None,
                 loss_calculator=None,
                 pretrained_model_file="",
                 src_format="text",
                 run_for_epochs=None,
                 lr_decay=1.0,
                 lr_decay_times=3,
                 patience=1,
                 initial_patience=None,
                 dev_metrics="",
                 schedule_metric="loss",
                 restart_trainer=False,
                 reload_command=None,
                 name=None,
                 inference=None):
        """
    :param yaml_context:
    :param corpus_parser: an input.InputReader object
    :param model: a generator.GeneratorModel object
    :param dev_every (int): dev checkpoints every n sentences (0 for only after epoch)
    :param batcher: Type of batcher. Defaults to SrcBatcher of batch size 32.
    :param loss_calculator:
    :param pretrained_model_file: Path of pre-trained model file
    :param src_format: Format of input data: text/contvec
    :param lr_decay (float):
    :param lr_decay_times (int):  Early stopping after decaying learning rate a certain number of times
    :param patience (int): apply LR decay after dev scores haven't improved over this many checkpoints
    :param initial_patience (int): if given, allows adjusting patience for the first LR decay
    :param dev_metrics: Comma-separated list of evaluation metrics (bleu/wer/cer)
    :param schedule_metric: determine learning schedule based on this dev_metric (loss/bleu/wer/cer)
    :param restart_trainer: Restart trainer (useful for Adam) and revert weights to best dev checkpoint when applying LR decay (https://arxiv.org/pdf/1706.09733.pdf)
    :param reload_command: Command to change the input data after each epoch.
                           --epoch EPOCH_NUM will be appended to the command.
                           To just reload the data after each epoch set the command to 'true'.
    :param name: will be prepended to log outputs if given
    :param inference: used for inference during dev checkpoints if dev_metrics are specified
    """
        assert yaml_context is not None
        self.yaml_context = yaml_context
        self.model_file = self.yaml_context.dynet_param_collection.model_file
        self.yaml_serializer = YamlSerializer()

        if lr_decay > 1.0 or lr_decay <= 0.0:
            raise RuntimeError(
                "illegal lr_decay, must satisfy: 0.0 < lr_decay <= 1.0")
        self.lr_decay = lr_decay
        self.patience = patience
        self.initial_patience = initial_patience
        self.lr_decay_times = lr_decay_times
        self.restart_trainer = restart_trainer
        self.run_for_epochs = run_for_epochs

        self.early_stopping_reached = False
        # training state
        self.training_state = TrainingState()

        self.evaluators = [
            s.lower() for s in dev_metrics.split(",") if s.strip() != ""
        ]
        if schedule_metric.lower() not in self.evaluators:
            self.evaluators.append(schedule_metric.lower())
        if "loss" not in self.evaluators: self.evaluators.append("loss")
        if dev_metrics:
            self.inference = inference or SimpleInference()

        self.reload_command = reload_command
        if reload_command is not None:
            self._augmentation_handle = None
            self._augment_data_initial()

        self.model = model
        self.corpus_parser = corpus_parser
        self.loss_calculator = loss_calculator or LossCalculator(MLELoss())
        self.pretrained_model_file = pretrained_model_file
        if self.pretrained_model_file:
            self.yaml_context.dynet_param_collection.load_from_data_file(
                self.pretrained_model_file + '.data')

        self.batcher = batcher or SrcBatcher(32)
        if src_format == "contvec":
            self.batcher.pad_token = np.zeros(self.model.src_embedder.emb_dim)
        self.pack_batches()
        self.logger = BatchLossTracker(self, dev_every, name)

        self.schedule_metric = schedule_metric.lower()
Example #14
  def __init__(self, model, src_file=None, trg_file=None, dev_every=0,
               batcher=bare(SrcBatcher, batch_size=32), loss_calculator=None,
               run_for_epochs=None, lr_decay=1.0, lr_decay_times=3, patience=1,
               initial_patience=None, dev_tasks=None, restart_trainer=False,
               reload_command=None, name=None, sample_train_sents=None,
               max_num_train_sents=None, max_src_len=None, max_trg_len=None,
               exp_global=Ref(Path("exp_global"))):
    """
    Args:
      exp_global:
      model: a generator.GeneratorModel object
      src_file: The file for the source data.
      trg_file: The file for the target data.
      dev_every (int): dev checkpoints every n sentences (0 for only after epoch)
      batcher: Type of batcher
      loss_calculator:
      lr_decay (float):
      lr_decay_times (int):  Early stopping after decaying learning rate a certain number of times
      patience (int): apply LR decay after dev scores haven't improved over this many checkpoints
      initial_patience (int): if given, allows adjusting patience for the first LR decay
      dev_tasks: A list of tasks to run on the development set
      restart_trainer: Restart trainer (useful for Adam) and revert weights to best dev checkpoint when applying LR decay (https://arxiv.org/pdf/1706.09733.pdf)
      reload_command: Command to change the input data after each epoch.
                           --epoch EPOCH_NUM will be appended to the command.
                           To just reload the data after each epoch set the command to 'true'.
      sample_train_sents:
      max_num_train_sents:
      max_src_len:
      max_trg_len:
      name: will be prepended to log outputs if given
    """
    self.exp_global = exp_global
    self.model_file = self.exp_global.dynet_param_collection.model_file
    self.src_file = src_file
    self.trg_file = trg_file
    self.dev_tasks = dev_tasks

    if lr_decay > 1.0 or lr_decay <= 0.0:
      raise RuntimeError("illegal lr_decay, must satisfy: 0.0 < lr_decay <= 1.0")
    self.lr_decay = lr_decay
    self.patience = patience
    self.initial_patience = initial_patience
    self.lr_decay_times = lr_decay_times
    self.restart_trainer = restart_trainer
    self.run_for_epochs = run_for_epochs

    self.early_stopping_reached = False
    # training state
    self.training_state = TrainingState()

    self.reload_command = reload_command

    self.model = model
    self.loss_calculator = loss_calculator or LossCalculator(MLELoss())

    self.sample_train_sents = sample_train_sents
    self.max_num_train_sents = max_num_train_sents
    self.max_src_len = max_src_len
    self.max_trg_len = max_trg_len

    self.batcher = batcher
    self.logger = BatchLossTracker(self, dev_every, name)
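The pattern shared by all the examples above is to renew the DyNet computation graph, call calc_loss with a LossCalculator, and read the scalar result with .value(). A minimal end-to-end sketch under the same assumptions as Examples #2 and #3; my_model, src_sent and trg_sent are placeholders, and the import path for LossCalculator and MLELoss is an assumption:

import dynet as dy
from xnmt.loss_calculator import LossCalculator, MLELoss  # module path assumed

dy.renew_cg()
loss_expr = my_model.calc_loss(src=src_sent,
                               trg=trg_sent,
                               loss_calculator=LossCalculator(MLELoss()))
train_loss = loss_expr.value()  # scalar loss for one sentence pair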