Code example #1
    def test_targetfield(self):
        field = TargetField()
        self.assertTrue(isinstance(field, torchtext.data.Field))
        self.assertTrue(field.batch_first)

        processed = field.preprocessing([None])
        self.assertEqual(processed, ['<sos>', None, '<eos>'])
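The test above pins down TargetField's contract: whatever the preprocessing step receives comes back framed by the '<sos>' and '<eos>' markers. A minimal sketch of the same behaviour on ordinary tokens (not part of the original test suite):

from machine.dataset import TargetField  # import as shown in code example #18

field = TargetField()
print(field.preprocessing(['how', 'are', 'you']))
# expected, by the assertion above: ['<sos>', 'how', 'are', 'you', '<eos>']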
Code example #2
    def setUp(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField(batch_first=True)
        tgt = TargetField(batch_first=True)
        self.dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'),
            format='tsv',
            fields=[('src', src), ('tgt', tgt)],
        )
        src.build_vocab(self.dataset)
        tgt.build_vocab(self.dataset)

        self.data_iterator = torchtext.data.BucketIterator(
            dataset=self.dataset,
            batch_size=64,
            sort=False,
            sort_within_batch=True,
            sort_key=lambda x: len(x.src),
            repeat=False)

        encoder = EncoderRNN(len(src.vocab), 10, 10, 10, rnn_cell='lstm')
        decoder = DecoderRNN(len(tgt.vocab),
                             10,
                             10,
                             tgt.sos_id,
                             tgt.eos_id,
                             rnn_cell='lstm')
        self.seq2seq = Seq2seq(encoder, decoder)

        for param in self.seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)
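A hedged sketch of consuming the iterator built in this setUp (these examples target torchtext's legacy 0.x API): each batch exposes the declared field names as attributes, and with batch_first=True the tensors come out as (batch, seq_len). Whether batch.src also carries lengths depends on SourceField's include_lengths setting, which is an assumption here, not something shown on this page.

# Hedged usage sketch, not part of the original test.
for batch in self.data_iterator:
    src_tensor = batch.src  # (batch, seq_len); a (tensor, lengths) pair if include_lengths=True
    tgt_tensor = batch.tgt  # target ids, framed by <sos>/<eos>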
Code example #3
    def test_targetfield_with_other_setting(self):
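        # TargetField overrides batch_first to True regardless of the value
        # passed in, and composes the user's preprocessing with its own
        # <sos>/<eos> wrapping; both behaviours are checked below.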
        field = TargetField(batch_first=False, preprocessing=lambda seq: seq + seq)
        self.assertTrue(isinstance(field, torchtext.data.Field))
        self.assertTrue(field.batch_first)

        processed = field.preprocessing([None])
        self.assertEqual(processed, ['<sos>', None, None, '<eos>'])
Code example #4
def get_train_valid_tests(train_path,
                          valid_path,
                          test_paths,
                          max_len,
                          src_vocab=50000,
                          tgt_vocab=50000):
    """Gets the formatted train, valid, and test data."""
    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    src = SourceField()
    tgt = TargetField(include_eos=True)
    fields = [('src', src), ('tgt', tgt)]
    train = torchtext.data.TabularDataset(path=train_path,
                                          format='tsv',
                                          fields=fields,
                                          filter_pred=len_filter)
    valid = torchtext.data.TabularDataset(path=valid_path,
                                          format='tsv',
                                          fields=fields,
                                          filter_pred=len_filter)

    tests = []
    for t in test_paths:
        tests.append(
            torchtext.data.TabularDataset(path=t,
                                          format='tsv',
                                          fields=fields,
                                          filter_pred=len_filter))

    src.build_vocab(train, max_size=src_vocab)
    tgt.build_vocab(train, max_size=tgt_vocab)

    return train, valid, tests, src, tgt
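A hypothetical invocation of the helper above; the paths and max_len are placeholders, not values from the original project:

# Hypothetical call, illustrating the return contract of get_train_valid_tests.
train, valid, tests, src, tgt = get_train_valid_tests(
    'data/train.tsv', 'data/valid.tsv', ['data/test1.tsv', 'data/test2.tsv'],
    max_len=50)
print(len(src.vocab), len(tgt.vocab), len(train), len(valid), len(tests))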
Code example #5
    def setUp(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField()
        tgt = TargetField()
        self.dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'), format='tsv',
            fields=[('src', src), ('tgt', tgt)],
        )
        src.build_vocab(self.dataset)
        tgt.build_vocab(self.dataset)
Code example #6
    def test_targetfield_specials(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        data_path = os.path.join(test_path, 'data/eng-fra.txt')
        field = TargetField()
        train = torchtext.data.TabularDataset(
            path=data_path, format='tsv',
            fields=[('src', torchtext.data.Field()), ('trg', field)]
        )
        # Before build_vocab, the special-token ids are unset.
        self.assertTrue(field.sos_id is None)
        self.assertTrue(field.eos_id is None)
        field.build_vocab(train)
        # build_vocab assigns them from the new vocabulary.
        self.assertFalse(field.sos_id is None)
        self.assertFalse(field.eos_id is None)
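What build_vocab fills in can be made explicit. This sketch mirrors the manual wiring in code examples #14 and #17, where a restored vocabulary forces the same lookups to be done by hand:

# Sketch: after build_vocab, the special-token ids are plain lookups into
# the vocabulary's string-to-index table.
assert field.sos_id == field.vocab.stoi[field.SYM_SOS]
assert field.eos_id == field.vocab.stoi[field.SYM_EOS]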
Code example #7
    def setUp(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField(batch_first=True)
        tgt = TargetField(batch_first=True)
        self.dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'), format='tsv',
            fields=[('src', src), ('tgt', tgt)],
        )
        src.build_vocab(self.dataset)
        tgt.build_vocab(self.dataset)

        self.data_iterator = torchtext.data.BucketIterator(
            dataset=self.dataset, batch_size=4,
            sort=False, sort_within_batch=True,
            sort_key=lambda x: len(x.src),
            repeat=False)
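Note on the iterator settings: sort=False with sort_within_batch=True keeps the batch order random across the dataset while sorting each individual batch by decreasing source length (the sort_key), the layout torch's pack_padded_sequence expects if the encoder packs its inputs.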
Code example #8
    def setUp(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField()
        tgt = TargetField()
        self.dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'), format='tsv',
            fields=[('src', src), ('tgt', tgt)],
        )
        src.build_vocab(self.dataset)
        tgt.build_vocab(self.dataset)

        encoder = EncoderRNN(len(src.vocab), 10, 10, 10, rnn_cell='lstm')
        decoder = DecoderRNN(len(tgt.vocab), 10, 10, tgt.sos_id, tgt.eos_id, rnn_cell='lstm')
        self.seq2seq = Seq2seq(encoder, decoder)

        for param in self.seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)
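The uniform initialisation over (-0.08, 0.08) is the range popularised by Sutskever et al.'s original sequence-to-sequence work and is the conventional starting point for LSTM seq2seq models like this one.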
Code example #9
    @classmethod
    def setUpClass(cls):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField(batch_first=True)
        trg = TargetField(batch_first=True)
        dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'),
            format='tsv',
            fields=[('src', src), ('trg', trg)],
        )
        src.build_vocab(dataset)
        trg.build_vocab(dataset)

        encoder = EncoderRNN(len(src.vocab), 5, 10, 10, rnn_cell='lstm')
        decoder = DecoderRNN(len(trg.vocab),
                             10,
                             10,
                             trg.sos_id,
                             trg.eos_id,
                             rnn_cell='lstm')
        seq2seq = Seq2seq(encoder, decoder)
        cls.predictor = Predictor(seq2seq, src.vocab, trg.vocab)
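A hypothetical test method using this fixture (the source tokens are placeholders, and Predictor.predict returning a list of decoded target tokens is an assumption based on how the fixture is set up, not something shown on this page):

    def test_predict(self):
        # Hypothetical: decode a made-up source sequence.
        tokens = self.predictor.predict(['how', 'are', 'you'])
        self.assertTrue(isinstance(tokens, list))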
Code example #10
def load_test_data(opt):
    src = SourceField()
    tgt = TargetField()
    tabular_data_fields = [('src', src), ('tgt', tgt)]

    max_len = opt.max_len

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    # generate training and testing data
    test = torchtext.data.TabularDataset(path=opt.test,
                                         format='tsv',
                                         fields=tabular_data_fields,
                                         filter_pred=len_filter)

    return test, src, tgt
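Unlike get_train_valid_tests in code example #4, load_test_data never calls build_vocab: the returned src and tgt fields carry no vocabulary, and the caller is expected to attach one, typically restored from a checkpoint as code examples #14, #16, and #17 do.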
Code example #11
File: train_model.py (project: m0re4u/machine)
def prepare_iters(opt):

    use_output_eos = not opt.ignore_output_eos
    src = SourceField(batch_first=True)
    tgt = TargetField(include_eos=use_output_eos, batch_first=True)
    tabular_data_fields = [('src', src), ('tgt', tgt)]

    max_len = opt.max_len

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    # generate training and testing data
    train = get_standard_iter(torchtext.data.TabularDataset(
        path=opt.train,
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                              batch_size=opt.batch_size)

    if opt.dev:
        dev = get_standard_iter(torchtext.data.TabularDataset(
            path=opt.dev,
            format='tsv',
            fields=tabular_data_fields,
            filter_pred=len_filter),
                                batch_size=opt.eval_batch_size)
    else:
        dev = None

    monitor_data = OrderedDict()
    for dataset in opt.monitor:
        m = get_standard_iter(torchtext.data.TabularDataset(
            path=dataset,
            format='tsv',
            fields=tabular_data_fields,
            filter_pred=len_filter),
                              batch_size=opt.eval_batch_size)
        monitor_data[dataset] = m

    return src, tgt, train, dev, monitor_data
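get_standard_iter is machine's helper for turning a TabularDataset into a batch iterator. Judging by the explicit constructions in code examples #2 and #7 it presumably wraps a BucketIterator with the usual sort-within-batch settings, but that is an inference; the implementation is not shown on this page.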
Code example #12
File: test_tasks.py (project: m0re4u/machine)
def prepare_iters(parameters,
                  train_path,
                  test_paths,
                  valid_path,
                  batch_size,
                  eval_batch_size=512):
    src = SourceField(batch_first=True)
    tgt = TargetField(include_eos=False, batch_first=True)
    tabular_data_fields = [('src', src), ('tgt', tgt)]

    max_len = parameters['max_len']

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    # generate training and testing data
    train = get_standard_iter(torchtext.data.TabularDataset(
        path=train_path,
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                              batch_size=batch_size)

    dev = get_standard_iter(torchtext.data.TabularDataset(
        path=valid_path,
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                            batch_size=eval_batch_size)

    monitor_data = OrderedDict()
    for dataset in test_paths:
        m = get_standard_iter(torchtext.data.TabularDataset(
            path=dataset,
            format='tsv',
            fields=tabular_data_fields,
            filter_pred=len_filter),
                              batch_size=eval_batch_size)
        monitor_data[dataset] = m

    return src, tgt, train, dev, monitor_data
Code example #13
File: main.py (project: gautierdag/pcfg-attention)
def prepare_iters(opt):

    src = SourceField(batch_first=True)
    tgt = TargetField(batch_first=True, include_eos=True)
    tabular_data_fields = [('src', src), ('tgt', tgt)]

    max_len = opt.max_len

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    ds = '100K'
    if opt.mini:
        ds = '10K'

    # generate training and testing data
    train = get_standard_iter(torchtext.data.TabularDataset(
        path='data/pcfg_set/{}/train.tsv'.format(ds),
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                              batch_size=opt.batch_size)

    dev = get_standard_iter(torchtext.data.TabularDataset(
        path='data/pcfg_set/{}/dev.tsv'.format(ds),
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                            batch_size=opt.eval_batch_size)

    monitor_data = OrderedDict()
    m = get_standard_iter(torchtext.data.TabularDataset(
        path='data/pcfg_set/{}/test.tsv'.format(ds),
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                          batch_size=opt.eval_batch_size)
    monitor_data['Test'] = m

    return src, tgt, train, dev, monitor_data
Code example #14
if torch.cuda.is_available():
    torch.cuda.set_device(opt.cuda_device)

#################################################################################
# load model

logging.info("loading checkpoint from {}".format(
    os.path.join(opt.checkpoint_path)))
checkpoint = Checkpoint.load(opt.checkpoint_path)
seq2seq = checkpoint.model
input_vocab = checkpoint.input_vocab
output_vocab = checkpoint.output_vocab

############################################################################
# Prepare dataset and loss
src = SourceField()
tgt = TargetField(output_eos_used)

tabular_data_fields = [('src', src), ('tgt', tgt)]

if opt.use_attention_loss or opt.attention_method == 'hard':
    attn = AttentionField(use_vocab=False, ignore_index=IGNORE_INDEX)
    tabular_data_fields.append(('attn', attn))

src.vocab = input_vocab
tgt.vocab = output_vocab
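# The checkpoint restores the vocabularies but not the special-token ids that
# build_vocab would normally set, so they are re-derived by hand: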
tgt.eos_id = tgt.vocab.stoi[tgt.SYM_EOS]
tgt.sos_id = tgt.vocab.stoi[tgt.SYM_SOS]
max_len = opt.max_len


def len_filter(example):
    return len(example.src) <= max_len and len(example.tgt) <= max_len
Code example #15
        "Can't use attention loss in combination with non-differentiable hard attention method."
    )

if torch.cuda.is_available():
    logging.info("Cuda device set to %i" % opt.cuda_device)
    torch.cuda.set_device(opt.cuda_device)

if opt.attention:
    if not opt.attention_method:
        logging.info("No attention method provided. Using DOT method.")
        opt.attention_method = 'dot'

############################################################################
# Prepare dataset
src = SourceField()
tgt = TargetField(include_eos=use_output_eos)

tabular_data_fields = [('src', src), ('tgt', tgt)]

if opt.use_attention_loss or opt.attention_method == 'hard':
    attn = AttentionField(use_vocab=False, ignore_index=IGNORE_INDEX)
    tabular_data_fields.append(('attn', attn))

max_len = opt.max_len


def len_filter(example):
    return len(example.src) <= max_len and len(example.tgt) <= max_len


# generate training and testing data
Code example #16
if torch.cuda.is_available():
    torch.cuda.set_device(opt.cuda_device)

#################################################################################
# load model

logging.info("loading checkpoint from {}".format(
    os.path.join(opt.checkpoint_path)))
checkpoint = Checkpoint.load(opt.checkpoint_path)
seq2seq = checkpoint.model
input_vocab = checkpoint.input_vocab
output_vocab = checkpoint.output_vocab

############################################################################
# Prepare dataset and loss
src = SourceField()
tgt = TargetField()
src.vocab = input_vocab
tgt.vocab = output_vocab
max_len = opt.max_len


def len_filter(example):
    return len(example.src) <= max_len and len(example.tgt) <= max_len


# generate test set
test = torchtext.data.TabularDataset(path=opt.test_data,
                                     format='tsv',
                                     fields=[('src', src), ('tgt', tgt)],
                                     filter_pred=len_filter)
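The restored test set still needs batching before evaluation. A sketch mirroring the iterator settings of code examples #2 and #7 (the batch size is a placeholder):

# Sketch: batch the test set the same way the earlier examples batch theirs.
test_iterator = torchtext.data.BucketIterator(
    dataset=test, batch_size=64,
    sort=False, sort_within_batch=True,
    sort_key=lambda x: len(x.src),
    repeat=False)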
Code example #17
File: evaluate.py (project: m0re4u/machine)
if torch.cuda.is_available():
    torch.cuda.set_device(opt.cuda_device)

##########################################################################
# load model

logging.info("loading checkpoint from {}".format(
    os.path.join(opt.checkpoint_path)))
checkpoint = Checkpoint.load(opt.checkpoint_path)
seq2seq = checkpoint.model
input_vocab = checkpoint.input_vocab
output_vocab = checkpoint.output_vocab

############################################################################
# Prepare dataset and loss
src = SourceField(batch_first=True)
tgt = TargetField(output_eos_used, batch_first=True)

tabular_data_fields = [('src', src), ('tgt', tgt)]

src.vocab = input_vocab
tgt.vocab = output_vocab
tgt.eos_id = tgt.vocab.stoi[tgt.SYM_EOS]
tgt.sos_id = tgt.vocab.stoi[tgt.SYM_SOS]
max_len = opt.max_len


def len_filter(example):
    return len(example.src) <= max_len and len(example.tgt) <= max_len


def get_standard_batch_iterator(data, batch_size):
Code example #18
    from machine.dataset import SourceField, TargetField

    # train
    combine_src_tgt_files('data/pcfg_set/100K/random_split/train.src',
                          'data/pcfg_set/100K/random_split/train.tgt',
                          'data/pcfg_set/100K/train.tsv')
    # dev
    combine_src_tgt_files('data/pcfg_set/100K/random_split/dev.src',
                          'data/pcfg_set/100K/random_split/dev.tgt',
                          'data/pcfg_set/100K/dev.tsv')
    # test
    combine_src_tgt_files('data/pcfg_set/100K/random_split/test.src',
                          'data/pcfg_set/100K/random_split/test.tgt',
                          'data/pcfg_set/100K/test.tsv')

    use_output_eos = False
    src = SourceField(batch_first=True)
    tgt = TargetField(include_eos=use_output_eos, batch_first=True)
    tabular_data_fields = [('src', src), ('tgt', tgt)]

    max_len = 50

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    train = torchtext.data.TabularDataset(path='data/pcfg_set/10K/train.tsv',
                                          format='tsv',
                                          fields=tabular_data_fields,
                                          filter_pred=len_filter)
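combine_src_tgt_files is called above but not defined in the excerpt. A hypothetical implementation consistent with the tab-separated format every example on this page loads (the name matches the call sites; the behaviour is an assumption):

# Hypothetical helper: zip a .src and a .tgt file line by line into one
# tab-separated file, matching the 'tsv' format used throughout.
def combine_src_tgt_files(src_path, tgt_path, out_path):
    with open(src_path) as fs, open(tgt_path) as ft, open(out_path, 'w') as fo:
        for s, t in zip(fs, ft):
            fo.write('{}\t{}\n'.format(s.rstrip('\n'), t.rstrip('\n')))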