def test_hierarchical_dataset_invalid_json_fail(hierarchical_dataset_fields):
    """Loading malformed JSON text must raise ``JSONDecodeError``."""
    with pytest.raises(JSONDecodeError):
        parser = HierarchicalDataset.get_default_dict_parser("children")
        HierarchicalDataset.from_json(
            dataset=INVALID_JSON,
            fields=hierarchical_dataset_fields,
            parser=parser,
        )
def test_hierarchical_dataset_json_root_element_not_list_fail(
    hierarchical_dataset_fields,
):
    """A JSON document whose root is not a list must be rejected with ``ValueError``."""
    with pytest.raises(ValueError):
        parser = HierarchicalDataset.get_default_dict_parser("children")
        HierarchicalDataset.from_json(
            dataset=JSON_ROOT_NOT_LIST,
            fields=hierarchical_dataset_fields,
            parser=parser,
        )
def hierarchical_dataset_2(hierarchical_dataset_fields, hierarchical_dataset_parser):
    """Build a dataset from the second example JSON; fields are not finalized here."""
    return HierarchicalDataset.from_json(
        dataset=HIERARCHIAL_DATASET_JSON_EXAMPLE_2,
        fields=hierarchical_dataset_fields,
        parser=hierarchical_dataset_parser,
    )
def hierarchical_dataset(hierarchical_dataset_fields, hierarchical_dataset_parser):
    """Build a dataset from the primary example JSON without finalizing its fields."""
    dataset = HierarchicalDataset.from_json(
        HIERARCHIAL_DATASET_JSON_EXAMPLE,
        hierarchical_dataset_fields,
        hierarchical_dataset_parser,
    )
    return dataset
def _get_node_context(self, node):
    """
    Generates a list of examples that make up the context of the provided
    node, truncated to adhere to 'context_max_depth' and
    'context_max_length' limitations.

    Parameters
    ----------
    node : Node
        The Hierarchical dataset node the context should be retrieved for.

    Returns
    -------
    list(Example)
        A list of examples that make up the context of the provided node,
        truncated to adhere to 'context_max_depth' and 'context_max_length'
        limitations. When 'context_max_length' is set, only the last
        'context_max_length' context examples are kept (a limit of 0 keeps
        none). The node's own example is always appended as the final
        element.
    """
    context_iterator = HierarchicalDataset._get_node_context(
        node, self._context_max_depth)
    context = list(context_iterator)

    if self._context_max_length is not None:
        # Keep only the trailing `context_max_length` examples. The start
        # index is computed explicitly rather than slicing with
        # `-max_length`, because `-0 == 0` and `context[-0:]` would keep the
        # WHOLE context when the limit is 0.
        start = max(len(context) - self._context_max_length, 0)
        context = context[start:]

    # add the example to the end of its own context
    context.append(node.example)
    return context
def hierarchical_dataset(hierarchical_dataset_fields, hierarchical_dataset_parser):
    """Build a dataset from the primary example JSON and finalize its fields."""
    loaded = HierarchicalDataset.from_json(
        dataset=HIERARCHIAL_DATASET_JSON_EXAMPLE,
        fields=hierarchical_dataset_fields,
        parser=hierarchical_dataset_parser,
    )
    # Finalize here so consumers receive a dataset whose fields are already
    # finalized.
    loaded.finalize_fields()
    return loaded
def test_hierarchical_dataset_finalize_fields(hierarchical_dataset_parser):
    """``finalize_fields`` must finalize the vocab attached to every field."""
    # Insertion order ("name", then "number") matches the construction order
    # of the vocabs and fields.
    vocabs = {"name": Vocab(), "number": Vocab()}
    fields = {
        field_name: Field(
            field_name, keep_raw=True, tokenizer=None, numericalizer=vocab
        )
        for field_name, vocab in vocabs.items()
    }

    dataset = HierarchicalDataset.from_json(
        HIERARCHIAL_DATASET_JSON_EXAMPLE,
        fields,
        hierarchical_dataset_parser,
    )
    dataset.finalize_fields()

    for vocab in vocabs.values():
        assert vocab.is_finalized
def hierarchical_dataset_parser():
    """Return the default dict parser configured with the ``"children"`` key.

    NOTE(review): presumably ``"children"`` names the key holding child nodes
    in the example JSON; confirm against ``get_default_dict_parser``.
    """
    children_key = "children"
    return HierarchicalDataset.get_default_dict_parser(children_key)