Example #1
import os

import dill

# Assumes `Seq2SeqDataset` (project-specific) and `BucketIterator`
# (torchtext) are already imported and that `device` is defined at
# module level.
def load_data(train_src, train_tgt, val_src, val_tgt, batch_size=64, save_path="checkpoint"):
    # prepare the training dataset
    print("Reading data...")
    train = Seq2SeqDataset.from_file(train_src, train_tgt)

    print("Building vocab...")
    train.build_vocab(max_size=300)

    val = Seq2SeqDataset.from_file(val_src, val_tgt, share_fields_from=train)

    src_vocab = train.src_field.vocab
    tgt_vocab = train.tgt_field.vocab

    # save vocab so it can be restored at inference time
    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, "vocab.src"), "wb") as f:
        dill.dump(src_vocab, f)
    with open(os.path.join(save_path, "vocab.tgt"), "wb") as f:
        dill.dump(tgt_vocab, f)

    print("Source vocab size:", len(src_vocab))
    print("Target vocab size:", len(tgt_vocab))

    # data iterators
    # sort=False and shuffle=False keep iteration cheap in time and memory;
    # sort_within_batch=True still orders each batch by source length,
    # which packed-sequence RNNs expect
    train_iterator = BucketIterator(dataset=train, batch_size=batch_size,
                                    sort=False, sort_within_batch=True,
                                    sort_key=lambda x: len(x.src),
                                    shuffle=False, device=device)
    val_iterator = BucketIterator(dataset=val, batch_size=batch_size, train=False,
                                  sort=False, sort_within_batch=True,
                                  sort_key=lambda x: len(x.src),
                                  shuffle=False, device=device)

    return src_vocab, tgt_vocab, train_iterator, val_iterator
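
Since the vocabularies are pickled with dill, they can be restored later, for example at inference time. A minimal sketch, assuming the same save_path layout as above; load_vocab is a hypothetical helper, not part of the original code:

import os

import dill

def load_vocab(save_path="checkpoint"):
    # restore the vocab objects written by load_data()
    with open(os.path.join(save_path, "vocab.src"), "rb") as f:
        src_vocab = dill.load(f)
    with open(os.path.join(save_path, "vocab.tgt"), "rb") as f:
        tgt_vocab = dill.load(f)
    return src_vocab, tgt_vocab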
Example #2
    # Assumes module-level imports of `random`, `gc`, and `torch`, plus
    # `Seq2SeqDataset`, `BucketIterator`, and a module-level `device`.
    def train_in_parts(self, train_parts, val, val_iterator, batch_size, start_epoch=0, print_every=100):
        for epoch in range(start_epoch, self.n_epochs):
            # reshuffle the order of the training parts each epoch
            random.shuffle(train_parts)

            for train_src_, train_tgt_ in train_parts:
                # build the dataset for this part; sharing fields with `val`
                # keeps the vocabulary consistent across all parts
                print("Training part [{}] with target [{}]...".format(train_src_, train_tgt_))
                train_ = Seq2SeqDataset.from_file(train_src_, train_tgt_, share_fields_from=val)

                # create iterator
                train_iterator_ = BucketIterator(dataset=train_, batch_size=batch_size,
                                                 sort=False, sort_within_batch=True,
                                                 sort_key=lambda x: len(x.src),
                                                 shuffle=True, device=device)
                # train
                self._train_epoch(epoch, train_iterator_, train=True, print_every=print_every)

                # free this part's dataset and iterator before loading the next one
                del train_
                del train_iterator_
                gc.collect()

            # save a checkpoint after each epoch
            self.save(epoch)

            # evaluate on the validation set after each epoch (no gradient tracking)
            with torch.no_grad():
                self._train_epoch(epoch, val_iterator, train=False, print_every=print_every)
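
A hypothetical call site for train_in_parts, to show the expected shape of train_parts; the file names, the trainer object, and the validation setup are assumptions, not part of the original snippet:

# each part is a (source_file, target_file) pair; parts are loaded one at a time
train_parts = [
    ("train.part0.src", "train.part0.tgt"),
    ("train.part1.src", "train.part1.tgt"),
]

val = Seq2SeqDataset.from_file("val.src", "val.tgt")
val.build_vocab(max_size=300)  # the training parts will share these fields
val_iterator = BucketIterator(dataset=val, batch_size=64, train=False,
                              sort=False, sort_within_batch=True,
                              sort_key=lambda x: len(x.src),
                              shuffle=False, device=device)

trainer.train_in_parts(train_parts, val, val_iterator, batch_size=64)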