Example #1
from collections import OrderedDict

import torchtext

# SourceField, TargetField and get_standard_iter are project-specific helpers,
# assumed to be importable from the surrounding codebase (torchtext Field
# subclasses plus an iterator factory).

def prepare_iters(parameters,
                  train_path,
                  test_paths,
                  valid_path,
                  batch_size,
                  eval_batch_size=512):
    src = SourceField(batch_first=True)
    tgt = TargetField(include_eos=False, batch_first=True)
    tabular_data_fields = [('src', src), ('tgt', tgt)]

    max_len = parameters['max_len']

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    # generate training and testing data
    train = get_standard_iter(torchtext.data.TabularDataset(
        path=train_path,
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                              batch_size=batch_size)

    dev = get_standard_iter(torchtext.data.TabularDataset(
        path=valid_path,
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                            batch_size=eval_batch_size)

    monitor_data = OrderedDict()
    for dataset in test_paths:
        m = get_standard_iter(torchtext.data.TabularDataset(
            path=dataset,
            format='tsv',
            fields=tabular_data_fields,
            filter_pred=len_filter),
                              batch_size=eval_batch_size)
        monitor_data[dataset] = m

    return src, tgt, train, dev, monitor_data
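
A minimal usage sketch for the function above; the paths, max_len and batch size are illustrative placeholders, not values taken from the example itself:

# Hypothetical call; every path and size below is a placeholder.
src, tgt, train_iter, dev_iter, monitor_data = prepare_iters(
    parameters={'max_len': 50},
    train_path='data/train.tsv',
    test_paths=['data/test_easy.tsv', 'data/test_hard.tsv'],
    valid_path='data/dev.tsv',
    batch_size=128)

# monitor_data maps each test path to its evaluation iterator.
for name, it in monitor_data.items():
    print(name, it)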
Example #2
# As in Example #1, torchtext, OrderedDict, SourceField, TargetField and
# get_standard_iter are assumed to be imported from the surrounding project.
def prepare_iters(opt):

    use_output_eos = not opt.ignore_output_eos
    src = SourceField(batch_first=True)
    tgt = TargetField(include_eos=use_output_eos, batch_first=True)
    tabular_data_fields = [('src', src), ('tgt', tgt)]

    max_len = opt.max_len

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    # generate training and testing data
    train = get_standard_iter(torchtext.data.TabularDataset(
        path=opt.train,
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                              batch_size=opt.batch_size)

    if opt.dev:
        dev = get_standard_iter(torchtext.data.TabularDataset(
            path=opt.dev,
            format='tsv',
            fields=tabular_data_fields,
            filter_pred=len_filter),
                                batch_size=opt.eval_batch_size)
    else:
        dev = None

    monitor_data = OrderedDict()
    for dataset in opt.monitor:
        m = get_standard_iter(torchtext.data.TabularDataset(
            path=dataset,
            format='tsv',
            fields=tabular_data_fields,
            filter_pred=len_filter),
                              batch_size=opt.eval_batch_size)
        monitor_data[dataset] = m

    return src, tgt, train, dev, monitor_data
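
This variant reads everything from a single `opt` object, typically an argparse namespace. A hedged sketch of the attributes it expects; the concrete values are illustrative assumptions:

from argparse import Namespace

# Hypothetical options object; attribute names mirror what prepare_iters reads.
opt = Namespace(
    train='data/train.tsv',      # training TSV (required)
    dev='data/dev.tsv',          # optional; dev iterator is None when falsy
    monitor=['data/test.tsv'],   # extra TSVs evaluated during training
    max_len=50,
    batch_size=128,
    eval_batch_size=512,
    ignore_output_eos=False)     # target include_eos is the negation of this flag

src, tgt, train_iter, dev_iter, monitor_data = prepare_iters(opt)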
Example #3
# Same assumed imports and project-specific helpers as in the previous examples.
def prepare_iters(opt):

    src = SourceField(batch_first=True)
    tgt = TargetField(batch_first=True, include_eos=True)
    tabular_data_fields = [('src', src), ('tgt', tgt)]

    max_len = opt.max_len

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    ds = '100K'
    if opt.mini:
        ds = '10K'

    # generate training and testing data
    train = get_standard_iter(torchtext.data.TabularDataset(
        path='data/pcfg_set/{}/train.tsv'.format(ds),
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                              batch_size=opt.batch_size)

    dev = get_standard_iter(torchtext.data.TabularDataset(
        path='data/pcfg_set/{}/dev.tsv'.format(ds),
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                            batch_size=opt.eval_batch_size)

    monitor_data = OrderedDict()
    m = get_standard_iter(torchtext.data.TabularDataset(
        path='data/pcfg_set/{}/test.tsv'.format(ds),
        format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter),
                          batch_size=opt.eval_batch_size)
    monitor_data['Test'] = m

    return src, tgt, train, dev, monitor_data
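
Here the PCFG SET paths are hard-coded and only the dataset size is switched via `opt.mini`. A sketch of the options this variant actually reads; the values are illustrative assumptions:

from argparse import Namespace

# Hypothetical options; only these four attributes are used by this variant.
opt = Namespace(
    mini=True,          # use the 10K subset instead of the 100K one
    max_len=100,
    batch_size=128,
    eval_batch_size=512)

src, tgt, train_iter, dev_iter, monitor_data = prepare_iters(opt)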