Ejemplo n.º 1
0
    def drop_checkpoint(self, opt, epoch, fields, valid_stats):
        """Write a resumable training checkpoint to disk.

        Model weights (excluding any state-dict entry whose key contains
        'generator') and generator weights are stored under separate keys,
        together with the vocabulary, options, epoch and optimizer. The file
        name encodes the validation accuracy, perplexity and epoch.

        Args:
            opt: option namespace; ``opt.save_model`` is the path prefix of
                the checkpoint file.
            epoch (int): epoch number, stored in the checkpoint and used in
                the file name.
            fields (dict): fields and vocabulary, converted with
                ``inputters.save_fields_to_vocab`` before saving.
            valid_stats: statistics of the last validation run; its
                ``accuracy()`` and ``ppl()`` values appear in the file name.
        """
        def _unwrap(module):
            # nn.DataParallel keeps the wrapped module in ``.module``; use the
            # inner module so the saved state-dict keys are not prefixed.
            if isinstance(module, nn.DataParallel):
                return module.module
            return module

        base_model = _unwrap(self.model)
        base_generator = _unwrap(base_model.generator)

        # Generator weights are saved separately below, so drop them here.
        weights = {}
        for name, tensor in base_model.state_dict().items():
            if 'generator' in name:
                continue
            weights[name] = tensor

        generator_weights = base_generator.state_dict()
        checkpoint = {
            'model': weights,
            'generator': generator_weights,
            'vocab': inputters.save_fields_to_vocab(fields),
            'opt': opt,
            'epoch': epoch,
            'optim': self.optim,
        }

        checkpoint_path = '%s_acc_%.2f_ppl_%.2f_e%d.pt' % (
            opt.save_model, valid_stats.accuracy(), valid_stats.ppl(), epoch)
        torch.save(checkpoint, checkpoint_path)
Ejemplo n.º 2
0
def build_save_vocab(train_dataset, fields, opt):
    """Build vocabularies for ``fields`` and save them with ``torch.save``.

    The output file is ``opt.save_data + '.vocab.pt'``.

    Args:
        train_dataset: training dataset(s) forwarded to
            ``inputters.build_vocab``.
        fields: fields whose vocabularies are built.
        opt: option namespace supplying ``data_type``, ``share_vocab``, the
            source/target vocab paths, sizes and minimum word frequencies,
            and ``save_data``.
    """
    src_settings = (opt.src_vocab, opt.src_vocab_size,
                    opt.src_words_min_frequency)
    tgt_settings = (opt.tgt_vocab, opt.tgt_vocab_size,
                    opt.tgt_words_min_frequency)
    built_fields = inputters.build_vocab(
        train_dataset, fields, opt.data_type, opt.share_vocab,
        *(src_settings + tgt_settings))

    target_path = opt.save_data + '.vocab.pt'
    torch.save(inputters.save_fields_to_vocab(built_fields), target_path)
Ejemplo n.º 3
0
def build_save_vocab(train_dataset, fields, opt):
    """Build vocabularies and save them to ``opt.save_data + '.vocab.pt'``.

    Args:
        train_dataset: training dataset(s) forwarded to
            ``inputters.build_vocab``.
        fields: fields whose vocabularies are built.
        opt: option namespace supplying ``data_type``, ``share_vocab``, the
            source/target vocab paths, sizes and minimum word frequencies,
            ``extra_vocab`` and ``save_data``.
    """
    vocab_args = (
        opt.data_type, opt.share_vocab,
        opt.src_vocab, opt.src_vocab_size, opt.src_words_min_frequency,
        opt.tgt_vocab, opt.tgt_vocab_size, opt.tgt_words_min_frequency,
        opt.extra_vocab,
    )
    built_fields = inputters.build_vocab(train_dataset, fields, *vocab_args)

    # Field objects themselves can't be saved, so only their vocabularies are
    # written; the fields are reconstructed at training time.
    out_path = opt.save_data + '.vocab.pt'
    payload = inputters.save_fields_to_vocab(built_fields)
    torch.save(payload, out_path)
Ejemplo n.º 4
0
def build_save_vocab(train_dataset, data_type, fields, opt):
    """Build vocabularies, log the output path, and save them with torch.

    Args:
        train_dataset: training dataset(s) forwarded to
            ``inputters.build_vocab``.
        data_type: data type forwarded to ``inputters.build_vocab``.
        fields: fields whose vocabularies are built.
        opt: option namespace supplying ``share_vocab``, the source/target
            vocab sizes and minimum word frequencies, and ``save_data``.
    """
    built_fields = inputters.build_vocab(
        train_dataset, data_type, fields,
        opt.share_vocab,
        opt.src_vocab_size, opt.src_words_min_frequency,
        opt.tgt_vocab_size, opt.tgt_words_min_frequency,
    )

    # Field objects can't be saved directly; only their vocabularies are
    # stored and the fields are reconstructed at training time.
    vocab_file = opt.save_data + '.vocab.pt'
    logger.info("saving vocabulary {}".format(vocab_file))
    vocab = inputters.save_fields_to_vocab(built_fields)
    torch.save(vocab, vocab_file)
Ejemplo n.º 5
0
def build_save_vocab(train_dataset, fields, opt):
    """Build text vocabularies and pickle them to disk.

    The data type is fixed to ``"text"``. The result of
    ``inputters.save_fields_to_vocab`` is written with ``pickle.dump`` to
    ``opt.save_data + '.vocab.pt'``, opened in binary write mode.

    Args:
        train_dataset: training dataset(s) forwarded to
            ``inputters.build_vocab``.
        fields: fields whose vocabularies are built.
        opt: option namespace supplying ``share_vocab``, the source/target
            vocab paths, sizes and minimum word frequencies, and
            ``save_data``.
    """
    source_opts = (opt.src_vocab, opt.src_vocab_size,
                   opt.src_words_min_frequency)
    target_opts = (opt.tgt_vocab, opt.tgt_vocab_size,
                   opt.tgt_words_min_frequency)
    built_fields = inputters.build_vocab(
        train_dataset, fields, "text", opt.share_vocab,
        *(source_opts + target_opts))

    # Field objects can't be saved directly; only their vocabularies are
    # stored and the fields are reconstructed at training time.
    out_path = opt.save_data + '.vocab.pt'
    with open(out_path, 'wb') as handle:
        pickle.dump(inputters.save_fields_to_vocab(built_fields), handle)
Ejemplo n.º 6
0
def build_save_vocab(train_dataset, fields, savepath, opt,
                     vocab_size=100, words_min_frequency=1):
    """Build a shared text vocabulary and pickle it to ``savepath/vocab.pt``.

    ``inputters.build_vocab`` is called with ``data_type='text'``,
    ``share_vocab=True`` and empty source/target vocab paths. The same size
    and minimum-frequency settings are used for both source and target.

    Args:
        train_dataset: training dataset(s) forwarded to
            ``inputters.build_vocab``.
        fields: fields whose vocabularies are built.
        savepath (str): directory the ``vocab.pt`` file is written into.
        opt: accepted for interface compatibility; this function does not
            read it.
        vocab_size (int): value passed as both ``src_vocab_size`` and
            ``tgt_vocab_size``. Defaults to 100.
        words_min_frequency (int): value passed as both
            ``src_words_min_frequency`` and ``tgt_words_min_frequency``.
            Defaults to 1.
    """
    # The size/frequency settings used to be hard-coded; they are now
    # parameters whose defaults keep the previous behavior.
    fields = inputters.build_vocab(train_dataset,
                                   fields,
                                   data_type='text',
                                   share_vocab=True,
                                   src_vocab_path='',
                                   src_vocab_size=vocab_size,
                                   src_words_min_frequency=words_min_frequency,
                                   tgt_vocab_path='',
                                   tgt_vocab_size=vocab_size,
                                   tgt_words_min_frequency=words_min_frequency)
    # Field objects can't be saved directly, so only their vocabularies are
    # written; the fields are reconstructed at training time.
    vocab_file = savepath + '/vocab.pt'
    with open(vocab_file, 'wb') as f:
        pickle.dump(inputters.save_fields_to_vocab(fields), f)
Ejemplo n.º 7
0
def build_save_vocab(fields, opt):
    """Build vocabularies via ``inputters.build_vocab_vec`` and save them.

    Args:
        fields: fields whose vocabularies are built.
        opt: option namespace; ``opt.save_data`` is the path prefix of the
            ``.vocab.pt`` output file.
    """
    built_fields = inputters.build_vocab_vec(fields)

    # Only the vocabularies are persisted; the Field objects themselves are
    # reconstructed at training time.
    destination = opt.save_data + '.vocab.pt'
    vocab = inputters.save_fields_to_vocab(built_fields)
    torch.save(vocab, destination)