def test_combine_dicts_with_disjoint_key_sets():
    """combine_dicts merges dicts whose key sets do not overlap."""
    merged = py_datastructures.combine_dicts(
        [{"d1_k1": "d1_v1"}, {"d2_k1": "d2_v1"}]
    )
    expected = {"d1_k1": "d1_v1", "d2_k1": "d2_v1"}
    # Compare via canonical (key-sorted) JSON serialization, as the original did.
    assert json.dumps(merged, sort_keys=True) == json.dumps(expected, sort_keys=True)
def run_from_parser_json_prepend(cls, parser, cl_args):
    """Parse CLI args for ``cls``, optionally prepending defaults imported from JSON files.

    Recognizes two meta-arguments:
      --ZZsrc:       path(s) to JSON config files whose key/value pairs are injected
                     as CLI arguments ahead of the user's own args
      --ZZoverrides: keys from the imported JSONs that the user is allowed to
                     override explicitly on the command line

    Args:
        parser: an argparse parser to extend and parse with.
        cl_args: list of raw CLI tokens, or None to fall back to sys.argv[1:].

    Returns:
        An instance of ``cls`` populated from the parsed arguments.

    Raises:
        RuntimeError: if an imported key is also given explicitly on the command
            line without being listed in --ZZoverrides.
    """
    parser.add_argument("--ZZsrc", type=str, action="append")
    parser.add_argument("--ZZoverrides", type=str, nargs="+")
    # Pre-parse only to extract the ZZ meta-args; unknown args are tolerated here.
    # NOTE: parse_known_args(None) already falls back to sys.argv[1:] internally,
    # so doing the None-substitution afterwards mirrors argparse's own behavior.
    pre_args, _ = parser.parse_known_args(cl_args)
    if cl_args is None:
        cl_args = sys.argv[1:]
    if pre_args.ZZsrc is not None:
        # Import configs from ZZsrc JSONs; strict=True means duplicate keys
        # across the JSON files raise rather than silently merge.
        imported_dict_ls = [read_json(path) for path in pre_args.ZZsrc]
        combined_imported_dict = combine_dicts(imported_dict_ls, strict=True)
        # Record which args are going to be overridden.
        # raw_overrides holds bare keys; overrides holds the "--key" CLI forms.
        if pre_args.ZZoverrides is not None:
            raw_overrides = pre_args.ZZoverrides
            overrides = [f"--{k}" for k in raw_overrides]
        else:
            raw_overrides = overrides = []
        attr_dict = cls.get_attr_dict()
        added_args = []
        for k, v in combined_imported_dict.items():
            formatted_k = f"--{k}"
            # Ensure that args from imported, which are not specified to be overridden,
            # aren't explicitly specified on the command line.
            if formatted_k in cl_args and formatted_k not in overrides:
                raise RuntimeError(f"Attempting to override {formatted_k}")
            # Special handling for store_true args: emit the bare flag only when
            # the imported value is truthy and the key isn't user-overridable.
            if cls._is_store_true_arg(attr_dict[k]):
                if v and k not in raw_overrides:
                    added_args.append(formatted_k)
            else:
                added_args.append(formatted_k)
                added_args.append(str(v))
        # Imported args come first so later (user-supplied) occurrences can win
        # where argparse lets the last occurrence take precedence.
        submitted_args = added_args + cl_args
    else:
        # --ZZoverrides is meaningless without --ZZsrc.
        assert pre_args.ZZoverrides is None
        submitted_args = cl_args
    update_parser(
        parser=parser,
        class_with_attributes=cls,
    )
    result, _ = read_parser(
        parser=parser,
        class_with_attributes=cls,
        skip_non_class_attributes=["ZZsrc", "ZZoverrides"],
        args=submitted_args,
    )
    assert isinstance(result, cls)
    return result
def collate_fn(cls, batch):  # cls.collate_fn
    """Collate a list of {"data_row", "metadata"} dicts into (cls.Batch, remainder).

    Fields annotated as torch.FloatTensor on cls.Batch are coerced to float32;
    collated keys not consumed by cls.Batch are returned as the remainder dict.
    """
    elem = batch[0]
    if not isinstance(elem, Mapping):  # dict
        raise TypeError(f"Unknown type for collate_fn {type(elem)}")
    assert set(elem.keys()) == {"data_row", "metadata"}
    data_rows = [example["data_row"] for example in batch]
    metadata = [example["metadata"] for example in batch]
    collated_data_rows = {}
    for field_name in data_rows[0].to_dict():
        collated_data_rows[field_name] = flat_collate_fn(
            [getattr(row, field_name) for row in data_rows]
        )
    combined = combine_dicts([collated_data_rows, metadata_collate_fn(metadata)])
    batch_kwargs = {}
    for field, field_type in cls.Batch.get_annotations().items():
        value = combined.pop(field)
        if field_type == torch.FloatTensor:
            # Ensure that floats stay as float32
            value = value.float()
        batch_kwargs[field] = value
    # Whatever keys remain after popping the Batch fields are handed back.
    return cls.Batch(**batch_kwargs), combined
def test_combine_dicts_with_overlapping_key_sets():
    """combine_dicts raises RuntimeError when the input dicts share a key."""
    duplicated = {"k1": "v1"}
    with pytest.raises(RuntimeError):
        py_datastructures.combine_dicts([duplicated, dict(duplicated)])