Example #1
    def setUp(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField(batch_first=True)
        tgt = TargetField(batch_first=True)
        self.dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'),
            format='tsv',
            fields=[('src', src), ('tgt', tgt)],
        )
        src.build_vocab(self.dataset)
        tgt.build_vocab(self.dataset)

        self.data_iterator = torchtext.data.BucketIterator(
            dataset=self.dataset,
            batch_size=64,
            sort=False,
            sort_within_batch=True,
            sort_key=lambda x: len(x.src),
            repeat=False)

        encoder = EncoderRNN(len(src.vocab), 10, 10, 10, rnn_cell='lstm')
        decoder = DecoderRNN(len(tgt.vocab),
                             10,
                             10,
                             tgt.sos_id,
                             tgt.eos_id,
                             rnn_cell='lstm')
        self.seq2seq = Seq2seq(encoder, decoder)

        for param in self.seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)
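Not part of the original test, but a minimal sketch of how the pieces built in this setUp fit together if a few lines were appended at the end of it. It assumes the IBM pytorch-seq2seq forward signature model(input_var, input_lengths, target_var) and relies on SourceField enabling include_lengths (verified by the field tests further down):

        # Hedged smoke test: drive a single batch through the model.
        # batch.src is a (tensor, lengths) pair because SourceField sets
        # include_lengths=True.
        for batch in self.data_iterator:
            input_var, input_lengths = batch.src
            target_var = batch.tgt
            decoder_outputs, decoder_hidden, other = self.seq2seq(
                input_var, input_lengths.tolist(), target_var)
            break  # one batch is enough for a smoke test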
Example #2
def get_train_valid_tests(train_path,
                          valid_path,
                          test_paths,
                          max_len,
                          src_vocab=50000,
                          tgt_vocab=50000):
    """Gets the formatted train, valid, and test data."""
    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    src = SourceField()
    tgt = TargetField(include_eos=True)
    fields = [('src', src), ('tgt', tgt)]
    train = torchtext.data.TabularDataset(path=train_path,
                                          format='tsv',
                                          fields=fields,
                                          filter_pred=len_filter)
    valid = torchtext.data.TabularDataset(path=valid_path,
                                          format='tsv',
                                          fields=fields,
                                          filter_pred=len_filter)

    tests = []
    for t in test_paths:
        tests.append(
            torchtext.data.TabularDataset(path=t,
                                          format='tsv',
                                          fields=fields,
                                          filter_pred=len_filter))

    src.build_vocab(train, max_size=src_vocab)
    tgt.build_vocab(train, max_size=tgt_vocab)

    return train, valid, tests, src, tgt
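A hypothetical call site for the helper above; the .tsv paths are placeholders rather than files from the original project:

# Placeholder paths; any tab-separated src<TAB>tgt file works here.
train, valid, tests, src, tgt = get_train_valid_tests(
    'data/train.tsv',
    'data/valid.tsv',
    ['data/test1.tsv', 'data/test2.tsv'],
    max_len=50)
print(len(src.vocab), len(tgt.vocab), len(train), len(valid))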
Example #3
    def setUp(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField()
        tgt = TargetField()
        self.dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'), format='tsv',
            fields=[('src', src), ('tgt', tgt)],
        )
        src.build_vocab(self.dataset)
        tgt.build_vocab(self.dataset)
Example #4
    def setUp(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField(batch_first=True)
        tgt = TargetField(batch_first=True)
        self.dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'), format='tsv',
            fields=[('src', src), ('tgt', tgt)],
        )
        src.build_vocab(self.dataset)
        tgt.build_vocab(self.dataset)

        self.data_iterator = torchtext.data.BucketIterator(
            dataset=self.dataset, batch_size=4,
            sort=False, sort_within_batch=True,
            sort_key=lambda x: len(x.src),
            repeat=False)
Example #5
    def setUp(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField()
        tgt = TargetField()
        self.dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'), format='tsv',
            fields=[('src', src), ('tgt', tgt)],
        )
        src.build_vocab(self.dataset)
        tgt.build_vocab(self.dataset)

        encoder = EncoderRNN(len(src.vocab), 10, 10, 10, rnn_cell='lstm')
        decoder = DecoderRNN(len(tgt.vocab), 10, 10, tgt.sos_id, tgt.eos_id, rnn_cell='lstm')
        self.seq2seq = Seq2seq(encoder, decoder)

        for param in self.seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)
Example #6
def load_test_data(opt):
    src = SourceField()
    tgt = TargetField()
    tabular_data_fields = [('src', src), ('tgt', tgt)]

    max_len = opt.max_len

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    # generate training and testing data
    test = torchtext.data.TabularDataset(path=opt.test,
                                         format='tsv',
                                         fields=tabular_data_fields,
                                         filter_pred=len_filter)

    return test, src, tgt
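A hedged usage sketch: opt normally comes from an argparse parser, so a bare Namespace stands in, and the path is a placeholder. Note that load_test_data builds no vocabularies; the fields' vocab attributes are attached from a checkpoint, as in the evaluation snippets below:

from argparse import Namespace

# Placeholder options; the real script fills these from the CLI.
opt = Namespace(test='data/test.tsv', max_len=50)
test, src, tgt = load_test_data(opt)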
Example #7
def prepare_iters(parameters,
                  train_path,
                  test_paths,
                  valid_path,
                  batch_size,
                  eval_batch_size=512):
    src = SourceField(batch_first=True)
    tgt = TargetField(include_eos=False, batch_first=True)
    tabular_data_fields = [('src', src), ('tgt', tgt)]

    max_len = parameters['max_len']

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    # generate training and testing data
    train = get_standard_iter(torchtext.data.TabularDataset(
        path=train_path,
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                              batch_size=batch_size)

    dev = get_standard_iter(torchtext.data.TabularDataset(
        path=valid_path,
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                            batch_size=eval_batch_size)

    monitor_data = OrderedDict()
    for dataset in test_paths:
        m = get_standard_iter(torchtext.data.TabularDataset(
            path=dataset,
            format='tsv',
            fields=tabular_data_fields,
            filter_pred=len_filter),
                              batch_size=eval_batch_size)
        monitor_data[dataset] = m

    return src, tgt, train, dev, monitor_data
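get_standard_iter is a helper from the surrounding codebase; assuming it wraps each TabularDataset in a BucketIterator-style iterator, the return values can be consumed directly. A hedged sketch with placeholder paths and settings:

src, tgt, train, dev, monitor_data = prepare_iters(
    parameters={'max_len': 50},      # placeholder settings
    train_path='data/train.tsv',     # placeholder paths
    test_paths=['data/test.tsv'],
    valid_path='data/valid.tsv',
    batch_size=64)
for name, it in monitor_data.items():
    print(name, len(it))             # assumes the iterator supports len()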
Example #8
def prepare_iters(opt):

    use_output_eos = not opt.ignore_output_eos
    src = SourceField(batch_first=True)
    tgt = TargetField(include_eos=use_output_eos, batch_first=True)
    tabular_data_fields = [('src', src), ('tgt', tgt)]

    max_len = opt.max_len

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    # generate training and testing data
    train = get_standard_iter(torchtext.data.TabularDataset(
        path=opt.train,
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                              batch_size=opt.batch_size)

    if opt.dev:
        dev = get_standard_iter(torchtext.data.TabularDataset(
            path=opt.dev,
            format='tsv',
            fields=tabular_data_fields,
            filter_pred=len_filter),
                                batch_size=opt.eval_batch_size)
    else:
        dev = None

    monitor_data = OrderedDict()
    for dataset in opt.monitor:
        m = get_standard_iter(torchtext.data.TabularDataset(
            path=dataset,
            format='tsv',
            fields=tabular_data_fields,
            filter_pred=len_filter),
                              batch_size=opt.eval_batch_size)
        monitor_data[dataset] = m

    return src, tgt, train, dev, monitor_data
Example #9
    @classmethod
    def setUpClass(cls):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField(batch_first=True)
        trg = TargetField(batch_first=True)
        dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'),
            format='tsv',
            fields=[('src', src), ('trg', trg)],
        )
        src.build_vocab(dataset)
        trg.build_vocab(dataset)

        encoder = EncoderRNN(len(src.vocab), 5, 10, 10, rnn_cell='lstm')
        decoder = DecoderRNN(len(trg.vocab),
                             10,
                             10,
                             trg.sos_id,
                             trg.eos_id,
                             rnn_cell='lstm')
        seq2seq = Seq2seq(encoder, decoder)
        cls.predictor = Predictor(seq2seq, src.vocab, trg.vocab)
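A hedged follow-up a test method might add, assuming the IBM pytorch-seq2seq Predictor API, where predict takes a list of source tokens and returns the decoded target tokens:

    def test_predict(self):
        # The input tokens are arbitrary placeholders.
        tgt_seq = self.predictor.predict(['how', 'are', 'you'])
        self.assertTrue(len(tgt_seq) > 0)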
Example #10
def prepare_iters(opt):

    src = SourceField(batch_first=True)
    tgt = TargetField(batch_first=True, include_eos=True)
    tabular_data_fields = [('src', src), ('tgt', tgt)]

    max_len = opt.max_len

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    ds = '100K'
    if opt.mini:
        ds = '10K'

    # generate training and testing data
    train = get_standard_iter(torchtext.data.TabularDataset(
        path='data/pcfg_set/{}/train.tsv'.format(ds),
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                              batch_size=opt.batch_size)

    dev = get_standard_iter(torchtext.data.TabularDataset(
        path='data/pcfg_set/{}/dev.tsv'.format(ds),
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                            batch_size=opt.eval_batch_size)

    monitor_data = OrderedDict()
    m = get_standard_iter(torchtext.data.TabularDataset(
        path='data/pcfg_set/{}/test.tsv'.format(ds),
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                          batch_size=opt.eval_batch_size)
    monitor_data['Test'] = m

    return src, tgt, train, dev, monitor_data
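A hedged invocation of the function above; opt is faked with a Namespace, and mini=True selects the 10K split per the hard-coded data/pcfg_set paths:

from argparse import Namespace

# All values are placeholders except mini, which picks the 10K dataset.
opt = Namespace(max_len=128, mini=True, batch_size=64, eval_batch_size=512)
src, tgt, train, dev, monitor_data = prepare_iters(opt)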
Example #11
    logging.info("Cuda device set to %i" % opt.cuda_device)
    torch.cuda.set_device(opt.cuda_device)

#################################################################################
# load model

logging.info("loading checkpoint from {}".format(
    os.path.join(opt.checkpoint_path)))
checkpoint = Checkpoint.load(opt.checkpoint_path)
seq2seq = checkpoint.model
input_vocab = checkpoint.input_vocab
output_vocab = checkpoint.output_vocab

############################################################################
# Prepare dataset and loss
src = SourceField()
tgt = TargetField(output_eos_used)

tabular_data_fields = [('src', src), ('tgt', tgt)]

if opt.use_attention_loss or opt.attention_method == 'hard':
    attn = AttentionField(use_vocab=False, ignore_index=IGNORE_INDEX)
    tabular_data_fields.append(('attn', attn))

src.vocab = input_vocab
tgt.vocab = output_vocab
tgt.eos_id = tgt.vocab.stoi[tgt.SYM_EOS]
tgt.sos_id = tgt.vocab.stoi[tgt.SYM_SOS]
max_len = opt.max_len
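The excerpt stops after max_len; a hedged continuation, mirroring the other evaluation snippets in this collection, would filter by length and load the test set with the restored vocabularies:

def len_filter(example):
    return len(example.src) <= max_len and len(example.tgt) <= max_len

# opt.test is assumed to point at a TSV file, as in load_test_data above.
test = torchtext.data.TabularDataset(path=opt.test,
                                     format='tsv',
                                     fields=tabular_data_fields,
                                     filter_pred=len_filter)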

Example #12
    def test_sourcefield(self):
        field = SourceField()
        self.assertTrue(isinstance(field, torchtext.data.Field))
        self.assertTrue(field.batch_first)
        self.assertTrue(field.include_lengths)
Example #13
    def test_sourcefield_with_wrong_setting(self):
        field = SourceField(batch_first=False, include_lengths=False)
        self.assertTrue(isinstance(field, torchtext.data.Field))
        self.assertTrue(field.batch_first)
        self.assertTrue(field.include_lengths)
Example #14
    logging.info("Cuda device set to %i" % opt.cuda_device)
    torch.cuda.set_device(opt.cuda_device)

##########################################################################
# load model

logging.info("loading checkpoint from {}".format(
    os.path.join(opt.checkpoint_path)))
checkpoint = Checkpoint.load(opt.checkpoint_path)
seq2seq = checkpoint.model
input_vocab = checkpoint.input_vocab
output_vocab = checkpoint.output_vocab

############################################################################
# Prepare dataset and loss
src = SourceField(batch_first=True)
tgt = TargetField(output_eos_used, batch_first=True)

tabular_data_fields = [('src', src), ('tgt', tgt)]

src.vocab = input_vocab
tgt.vocab = output_vocab
tgt.eos_id = tgt.vocab.stoi[tgt.SYM_EOS]
tgt.sos_id = tgt.vocab.stoi[tgt.SYM_SOS]
max_len = opt.max_len


def len_filter(example):
    return len(example.src) <= max_len and len(example.tgt) <= max_len

Example #15
if opt.attention:
    if not opt.attention_method:
        logging.info("No attention method provided. Using DOT method.")
        opt.attention_method = 'dot'

# Set random seed
if opt.random_seed:
    random.seed(opt.random_seed)
    np.random.seed(opt.random_seed)
    torch.manual_seed(opt.random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(opt.random_seed)

############################################################################
# Prepare dataset
src = SourceField(lower=opt.lower, preprocessing=lambda seq: seq + ['<eos>'])
tgt = TargetField(include_eos=use_output_eos, lower=opt.lower)

tabular_data_fields = [('src', src), ('tgt', tgt)]

max_len = opt.max_len


def len_filter(example):
    return len(example.src) <= max_len and len(example.tgt) <= max_len


# generate training and testing data
if opt.augmented_input:
    dataset_class = TabularAugmentedDataset
else:
    dataset_class = torchtext.data.TabularDataset
Example #16
if opt.attention and not opt.attention_method:
    parser.error("Attention turned on, but no attention method provided")

if torch.cuda.is_available():
    logging.info("Cuda device set to %i" % opt.cuda_device)
    torch.cuda.set_device(opt.cuda_device)

if opt.attention:
    if not opt.attention_method:
        logging.info("No attention method provided. Using DOT method.")
        opt.attention_method = 'dot'

############################################################################
# Prepare dataset
src = SourceField(lower=opt.lower)
tgt = TargetField(include_eos=use_output_eos, lower=opt.lower)

tabular_data_fields = [('src', src), ('tgt', tgt)]

max_len = opt.max_len


def len_filter(example):
    return len(example.src) <= max_len and len(example.tgt) <= max_len


# generate training and testing data
train = torchtext.data.TabularDataset(path=opt.train,
                                      format='tsv',
                                      fields=tabular_data_fields,
                                      filter_pred=len_filter)