Esempio n. 1
0
def test_point_generator(x, task, stage):
    """Check the single batch produced by ``PointGenerator`` on pack ``x``.

    With shuffling disabled and ``batch_size=3`` the generator is expected
    to produce exactly one batch whose rows follow the relation order of
    the fixture pack.
    """
    shuffle = False
    batch_size = 3
    generator = PointGenerator(x, task, batch_size, stage, shuffle)
    assert len(generator) == 1
    # Distinct names keep the ``x`` fixture from being shadowed, and
    # ``next`` fails loudly if the generator yields nothing (the former
    # ``for ...: break`` loop silently passed in that case).
    batch_x, batch_y = next(iter(generator))
    assert batch_x['ids'].tolist() == [['qid0', 'did0'], ['qid1', 'did1'],
                                       ['qid1', 'did0']]
    assert batch_x['text_left'].tolist() == [[1, 2], [2, 3], [2, 3]]
    assert batch_x['text_right'].tolist() == [[2, 3, 4], [3, 4, 5],
                                              [2, 3, 4]]
    if stage == 'test':
        assert batch_y is None
    # Compare by type and class count instead of ``==`` against a new
    # instance, which only matches if Classification defines value
    # equality (not visible here).
    elif (stage == 'train' and isinstance(task, tasks.Classification)
          and task.num_classes == 3):
        assert batch_y.tolist() == [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
Esempio n. 2
0
    def __init__(self,
                 inputs: datapack.DataPack,
                 task: engine.BaseTask = None,
                 batch_size: int = 32,
                 stage: str = 'train',
                 shuffle: bool = True):
        """Construct the point generator.

        :param inputs: the output generated by :class:`DataPack`; its
            ``relation``, ``left`` and ``right`` attributes are stored, and
            the length of ``relation`` is passed to the base class
            (presumably as the number of instances).
        :param task: an instance of :class:`engine.BaseTask`. When omitted
            or ``None``, a new ``tasks.Classification(2)`` is created.
        :param batch_size: number of instances in a batch.
        :param stage: generation stage, e.g. ``'train'`` or ``'test'``;
            forwarded to the base generator.
        :param shuffle: whether to shuffle the instances while generating a
            batch.
        """
        if task is None:
            # Build the default per instance: a default-argument expression
            # is evaluated once at definition time, so every generator
            # created without ``task`` would otherwise share one object.
            task = tasks.Classification(2)
        self._relation = inputs.relation
        self._task = task
        self._left = inputs.left
        self._right = inputs.right
        super().__init__(batch_size, len(self._relation), stage, shuffle)
Esempio n. 3
0
# Each entry is ``(model_class, param_overrides)``. ``param_overrides`` is
# either ``None`` (keep the model's defaults) or a dict whose items are
# written into ``model.params`` by the ``raw_model`` fixture, e.g.
# ``(models.DenseBaselineModel, {"num_dense_units": 512})``.
model_setups = [
    (models.NaiveModel, None),
    (models.DenseBaselineModel, None),
    (models.DSSMModel, None),
]


@pytest.fixture(scope='module', params=[1, 32])
def num_samples(request):
    """Yield each configured sample count (1, then 32) to dependent tests."""
    sample_count = request.param
    return sample_count


@pytest.fixture(scope='module', params=[
    tasks.Classification(num_classes=2),
    tasks.Classification(num_classes=16),
    tasks.Ranking(),
])
def task(request):
    """Provide each task instance (binary, 16-way classification, ranking)."""
    current_task = request.param
    return current_task


@pytest.fixture(params=model_setups)
def raw_model(request):
    """Instantiate a model class and apply its parameter overrides.

    ``request.param`` is a ``(model_class, custom_kwargs)`` pair from
    ``model_setups``; each override is stored in ``model.params``.
    """
    model_class, custom_kwargs = request.param
    model = model_class()
    # A falsy ``custom_kwargs`` (e.g. ``None``) means "no overrides".
    overrides = custom_kwargs or {}
    for name, value in overrides.items():
        model.params[name] = value
    return model
Esempio n. 4
0
def test_classification_num_classes(arg):
    """A Classification task reports the class count it was built with."""
    classification = tasks.Classification(num_classes=arg)
    assert arg == classification.num_classes
Esempio n. 5
0
def test_classification_instantiation_failure(arg):
    """Building a Classification task with an invalid ``arg`` must raise."""
    # NOTE(review): ``Exception`` is very broad and would also accept e.g. a
    # NameError from a typo; narrow it once the raised type is confirmed.
    with pytest.raises(Exception):
        _ = tasks.Classification(num_classes=arg)
Esempio n. 6
0
def x():
    """Build a tiny DataPack: two queries, two documents, three relations."""
    relation = pd.DataFrame(
        [['qid0', 'did0', 0], ['qid1', 'did1', 1], ['qid1', 'did0', 2]],
        columns=['id_left', 'id_right', 'label'],
    )
    left = pd.DataFrame(
        [['qid0', [1, 2]], ['qid1', [2, 3]]],
        columns=['id_left', 'text_left'],
    ).set_index('id_left')
    right = pd.DataFrame(
        [['did0', [2, 3, 4]], ['did1', [3, 4, 5]]],
        columns=['id_right', 'text_right'],
    ).set_index('id_right')
    context = {'vocab_size': 6, 'fill_word': 6}
    return DataPack(relation=relation, left=left, right=right,
                    context=context)


@pytest.fixture(scope='module', params=[
    tasks.Classification(num_classes=3),
    tasks.Ranking(),
])
def task(request):
    """Provide a 3-class classification task and a ranking task in turn."""
    current_task = request.param
    return current_task


@pytest.fixture(scope='module', params=['train', 'test'])
def stage(request):
    """Provide each generator stage name ('train', then 'test')."""
    stage_name = request.param
    return stage_name


def test_point_generator(x, task, stage):
    """Check the single batch produced by ``PointGenerator`` on pack ``x``.

    ``x`` holds three relations, so with ``batch_size=3`` and shuffling
    disabled the generator should yield exactly one batch in relation
    order. (Previously this test only constructed the generator and
    asserted nothing.)
    """
    shuffle = False
    batch_size = 3
    generator = PointGenerator(x, task, batch_size, stage, shuffle)
    assert len(generator) == 1
    # ``next`` fails loudly if the generator yields no batch.
    batch_x, batch_y = next(iter(generator))
    assert batch_x['ids'].tolist() == [['qid0', 'did0'], ['qid1', 'did1'],
                                       ['qid1', 'did0']]
    assert batch_x['text_left'].tolist() == [[1, 2], [2, 3], [2, 3]]
    assert batch_x['text_right'].tolist() == [[2, 3, 4], [3, 4, 5],
                                              [2, 3, 4]]
    if stage == 'test':
        assert batch_y is None
    # Labels 0, 1, 2 in ``x`` become one-hot rows for a 3-class task.
    elif (stage == 'train' and isinstance(task, tasks.Classification)
          and task.num_classes == 3):
        assert batch_y.tolist() == [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
Esempio n. 7
0
def _guess_task(train_pack):
    if np.issubdtype(train_pack.relation['label'].dtype, np.number):
        return tasks.Ranking()
    elif np.issubdtype(train_pack.relation['label'].dtype, list):
        num_classes = int(train_pack.relation['label'].apply(len).max())
        return tasks.Classification(num_classes)