def test_point_generator(x, task, stage):
    """Check the single unshuffled batch produced by ``PointGenerator``.

    The generator is built with ``batch_size=3`` and ``shuffle=False``. The test
    expects exactly one batch whose ids and left/right texts follow the relation
    row order of the ``x`` fixture. In the ``'test'`` stage the labels are
    expected to be ``None``. In the ``'train'`` stage with a 3-class
    ``Classification`` task they are expected to be one-hot rows.
    """
    batch_size, shuffle = 3, False
    expected_ids = [['qid0', 'did0'], ['qid1', 'did1'], ['qid1', 'did0']]
    expected_left = [[1, 2], [2, 3], [2, 3]]
    expected_right = [[2, 3, 4], [3, 4, 5], [2, 3, 4]]
    expected_one_hot = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]

    generator = PointGenerator(x, task, batch_size, stage, shuffle)
    assert len(generator) == 1

    # Only the first batch is inspected; the loop exits after one pass.
    # Distinct names avoid shadowing the ``x`` fixture.
    for batch_inputs, batch_labels in generator:
        assert batch_inputs['ids'].tolist() == expected_ids
        assert batch_inputs['text_left'].tolist() == expected_left
        assert batch_inputs['text_right'].tolist() == expected_right
        if stage == 'test':
            assert batch_labels is None
        elif stage == 'train' and task == tasks.Classification(num_classes=3):
            assert batch_labels.tolist() == expected_one_hot
        break
def __init__(
    self,
    inputs: datapack.DataPack,
    task: engine.BaseTask = tasks.Classification(2),
    batch_size: int = 32,
    stage: str = 'train',
    shuffle: bool = True
):
    """Construct the point generator.

    :param inputs: the output generated by :class:`DataPack`; its
        ``relation``, ``left`` and ``right`` attributes are kept.
    :param task: an instance of :class:`engine.BaseTask`.
        NOTE: the default ``tasks.Classification(2)`` is evaluated once, when
        this function is defined, so every generator created without an
        explicit ``task`` shares that same object.
    :param batch_size: number of instances in a batch.
    :param stage: stage name (e.g. ``'train'`` or ``'test'``), forwarded to
        the base generator.
    :param shuffle: whether to shuffle the instances while generating a batch.
    """
    self._task = task
    self._relation = inputs.relation
    self._left, self._right = inputs.left, inputs.right
    # The base generator is sized by the number of relation rows.
    num_instances = len(inputs.relation)
    super().__init__(batch_size, num_instances, stage, shuffle)
# Each entry is (model_class, custom_kwargs). ``custom_kwargs`` is either None
# (use the model's defaults) or a dict of params to override, for example
# (models.DenseBaselineModel, {"num_dense_units": 512}).
model_setups = [
    (models.NaiveModel, None),
    (models.DenseBaselineModel, None),
    (models.DSSMModel, None),
]


@pytest.fixture(scope='module', params=[1, 32])
def num_samples(request):
    """Parametrized sample count: 1 and 32."""
    return request.param


@pytest.fixture(scope='module', params=[
    tasks.Classification(num_classes=2),
    tasks.Classification(num_classes=16),
    tasks.Ranking(),
])
def task(request):
    """Parametrize over 2- and 16-class classification and a ranking task."""
    return request.param


@pytest.fixture(params=model_setups)
def raw_model(request):
    """Instantiate a model class from ``model_setups`` and apply its overrides.

    Each override is written into ``model.params`` before the model is returned.
    """
    model_cls, overrides = request.param
    model = model_cls()
    # A falsy ``overrides`` (None or {}) leaves the defaults untouched.
    for name, value in (overrides or {}).items():
        model.params[name] = value
    return model
def test_classification_num_classes(arg):
    """``Classification`` should expose the ``num_classes`` it was built with."""
    expected = arg
    classification_task = tasks.Classification(num_classes=expected)
    assert classification_task.num_classes == expected
def test_classification_instantiation_failure(arg):
    """Building ``Classification`` with the parametrized ``arg`` must raise."""
    rejected_value = arg
    with pytest.raises(Exception):
        tasks.Classification(num_classes=rejected_value)
def x():
    """Build a tiny ``DataPack``: three relation rows over two queries and two docs.

    ``left`` is indexed by ``id_left`` and ``right`` by ``id_right``.
    """
    relation = pd.DataFrame(
        [['qid0', 'did0', 0], ['qid1', 'did1', 1], ['qid1', 'did0', 2]],
        columns=['id_left', 'id_right', 'label'],
    )
    left = pd.DataFrame(
        [['qid0', [1, 2]], ['qid1', [2, 3]]],
        columns=['id_left', 'text_left'],
    ).set_index('id_left')
    right = pd.DataFrame(
        [['did0', [2, 3, 4]], ['did1', [3, 4, 5]]],
        columns=['id_right', 'text_right'],
    ).set_index('id_right')
    context = {'vocab_size': 6, 'fill_word': 6}
    return DataPack(relation=relation, left=left, right=right, context=context)


@pytest.fixture(scope='module', params=[
    tasks.Classification(num_classes=3),
    tasks.Ranking(),
])
def task(request):
    """Parametrize over a 3-class classification task and a ranking task."""
    return request.param


@pytest.fixture(scope='module', params=['train', 'test'])
def stage(request):
    """Parametrize over the ``'train'`` and ``'test'`` stages."""
    return request.param


def test_point_generator(x, task, stage):
    """Construct an unshuffled ``PointGenerator`` with a batch size of 3."""
    # NOTE(review): nothing is asserted here; confirm whether the checks on
    # the generated batch are missing.
    batch_size, shuffle = 3, False
    generator = PointGenerator(x, task, batch_size, stage, shuffle)
def _guess_task(train_pack): if np.issubdtype(train_pack.relation['label'].dtype, np.number): return tasks.Ranking() elif np.issubdtype(train_pack.relation['label'].dtype, list): num_classes = int(train_pack.relation['label'].apply(len).max()) return tasks.Classification(num_classes)