Esempio n. 1
0
    def test_extend_vocab_1(self):
        """Extend a vocab built from no sources with a dataset's tokens.

        Builds the field's vocab without any data, attaches a placeholder
        300-column vectors tensor, then extends the vocab from the sample
        table using the sample fastText vectors. The vocab is expected to
        end up with 6 entries and a matching ``[6, 300]`` vectors tensor.
        """
        vectors_cache_dir = '.cache'
        # Start from a clean slate so stale cache contents cannot affect the run.
        if os.path.exists(vectors_cache_dir):
            shutil.rmtree(vectors_cache_dir)

        try:
            mf = MatchingField()
            lf = MatchingField(id=True, sequential=False)
            fields = [('id', lf), ('left_a', mf), ('right_a', mf),
                      ('label', lf)]
            col_naming = {
                'id': 'id',
                'label': 'label',
                'left': 'left_',
                'right': 'right_'
            }

            # Point the vectors loader at the local sample file via a file: URL.
            pathdir = os.path.abspath(
                os.path.join(test_dir_path, 'test_datasets'))
            filename = 'fasttext_sample.vec'
            vec_path = os.path.join(pathdir, filename)
            url_base = urljoin('file:', pathname2url(vec_path))
            vecs = Vectors(name=filename, cache=vectors_cache_dir,
                           url=url_base)

            data_path = os.path.join(test_dir_path, 'test_datasets',
                                     'sample_table_small.csv')
            md = MatchingDataset(fields, col_naming, path=data_path)

            mf.build_vocab()
            # Zero-initialised placeholder: torch.Tensor(n, d) would leave the
            # memory uninitialised, which makes the starting vectors arbitrary.
            mf.vocab.vectors = torch.zeros(len(mf.vocab.itos), 300)
            mf.extend_vocab(md, vectors=vecs)
            self.assertEqual(len(mf.vocab.itos), 6)
            self.assertEqual(mf.vocab.vectors.size(), torch.Size([6, 300]))
        finally:
            # Always remove the cache directory so the test leaves nothing
            # behind, even when an assertion above fails. Errors are ignored
            # so a cleanup problem cannot mask the real test failure.
            shutil.rmtree(vectors_cache_dir, ignore_errors=True)
Esempio n. 2
0
def test_class_matching_dataset():
    """Check the field-bookkeeping attributes set up by MatchingDataset.

    Loads the small sample table with one left and one right text field and
    verifies the id/label names and the derived left, right, combined and
    canonical text-field lists.
    """
    left_field = MatchingField()
    right_field = MatchingField()
    fields = [("left_a", left_field), ("right_a", right_field)]
    col_naming = {
        "id": "id",
        "label": "label",
        "left": "left",
        "right": "right",
    }
    csv_path = os.path.join(
        test_dir_path, "test_datasets", "sample_table_small.csv"
    )
    md = MatchingDataset(fields, col_naming, path=csv_path)

    # (attribute name, expected value), checked in this order.
    expectations = (
        ("id_field", "id"),
        ("label_field", "label"),
        ("all_left_fields", ["left_a"]),
        ("all_right_fields", ["right_a"]),
        ("all_text_fields", ["left_a", "right_a"]),
        ("canonical_text_fields", ["_a"]),
    )
    for attribute, expected in expectations:
        assert getattr(md, attribute) == expected
Esempio n. 3
0
def process_unlabeled(path, trained_model, ignore_columns=None):
    """Creates a dataset object for an unlabeled dataset.

    Args:
        path (str): The full path to the unlabeled data file (not just the directory).
        trained_model (:class:`~deepmatcher.MatchingModel`): The trained model.
            The model is aware of the configuration of the training
            data on which it was trained, and so this method reuses the same
            configuration for the unlabeled data.
        ignore_columns (list): A list of columns to ignore in the unlabeled CSV file.
            When ``None``, the ignore list recorded in the trained model's
            metadata is used.

    Returns:
        The :class:`MatchingDataset` built from ``path``, with its ``vocabs``
        attribute populated for every text field of the training data.

    Raises:
        ValueError: If the file at ``path`` contains no header row.
        AssertionError: If the text fields of the unlabeled data differ from
            those the model was trained on.

    """
    with io.open(path, encoding="utf8") as f:
        # A default keeps an empty file from leaking a bare StopIteration.
        header = next(unicode_csv_reader(f), None)
    if header is None:
        raise ValueError(
            "Unlabeled data file {!r} is empty; expected a CSV header "
            "row.".format(path))

    train_info = trained_model.meta
    if ignore_columns is None:
        ignore_columns = train_info.ignore_columns
    # Copy so the training metadata is not mutated; unlabeled data has no label.
    column_naming = dict(train_info.column_naming)
    column_naming["label"] = None

    fields = _make_fields(
        header,
        column_naming["id"],
        column_naming["label"],
        ignore_columns,
        train_info.lowercase,
        train_info.tokenize,
        train_info.include_lengths,
    )

    begin = timer()
    dataset_args = {"fields": fields, "column_naming": column_naming}
    dataset = MatchingDataset(path=path, **dataset_args)

    # Make sure we have the same attributes. Raised explicitly (rather than
    # via an `assert` statement) so the check still runs under `python -O`;
    # the exception type is kept as AssertionError for existing callers.
    dataset_text_fields = set(dataset.all_text_fields)
    trained_text_fields = set(train_info.all_text_fields)
    if dataset_text_fields != trained_text_fields:
        raise AssertionError(
            "Text fields of the unlabeled data {} do not match those of the "
            "training data {}.".format(
                sorted(dataset_text_fields), sorted(trained_text_fields)))

    after_load = timer()
    logger.info("Data load time: {}s".format(after_load - begin))

    # Keyed by field object so a field shared by several columns is
    # processed only once (the last column name listed for it is used).
    reverse_fields_dict = {field: name for name, field in fields}
    for field, name in reverse_fields_dict.items():
        if field is not None and field.use_vocab:
            # Copy over vocab from original train data.
            field.vocab = copy.deepcopy(train_info.vocabs[name])
            # Then extend the vocab.
            field.extend_vocab(
                dataset,
                vectors=train_info.embeddings,
                cache=train_info.embeddings_cache,
            )

    dataset.vocabs = {
        name: dataset.fields[name].vocab
        for name in train_info.all_text_fields
    }

    after_vocab = timer()
    logger.info("Vocab update time: {}s".format(after_vocab - after_load))

    return dataset