def test_combine_dicts_with_disjoint_key_sets():
    """Merging dicts that share no keys yields the union of their items."""
    combined_dict = py_datastructures.combine_dicts([
        {"d1_k1": "d1_v1"},
        {"d2_k1": "d2_v1"},
    ])
    # Dict equality is already insertion-order independent, so there is no
    # need to serialise both sides with sorted keys; comparing directly also
    # lets pytest report a per-key diff on failure.
    assert combined_dict == {"d1_k1": "d1_v1", "d2_k1": "d2_v1"}
Example #2
0
    def run_from_parser_json_prepend(cls, parser, cl_args):
        """Parse ``cl_args`` into a ``cls`` instance, prepending args from JSON configs.

        Registers two options on ``parser``:

        * ``--ZZsrc PATH`` (repeatable): JSON files read with ``read_json`` and
          merged with ``combine_dicts(..., strict=True)``; each merged key/value
          becomes ``--key value`` tokens placed *before* ``cl_args``.
        * ``--ZZoverrides NAME [NAME ...]``: imported keys that ``cl_args`` is
          allowed to specify again. Only valid together with ``--ZZsrc``.

        Args:
            parser: argparse parser; the two options above are added to it, and
                it is then passed to ``update_parser`` and ``read_parser``.
            cl_args: argument tokens, or ``None`` to use ``sys.argv[1:]``.

        Returns:
            The ``cls`` instance returned by ``read_parser``.

        Raises:
            RuntimeError: if an imported key appears in ``cl_args`` (as ``--key``
                or ``--key=value``) without being listed in ``--ZZoverrides``.
            AssertionError: if ``--ZZoverrides`` is given without ``--ZZsrc``
                (only while assertions are enabled).
        """
        parser.add_argument("--ZZsrc", type=str, action="append")
        parser.add_argument("--ZZoverrides", type=str, nargs="+")
        # Resolve the default before parsing so the same token list is used for
        # the pre-parse and for building submitted_args; copying to a list also
        # lets callers pass any sequence (e.g. a tuple) to the concatenation below.
        if cl_args is None:
            cl_args = sys.argv[1:]
        cl_args = list(cl_args)
        pre_args, _ = parser.parse_known_args(cl_args)
        if pre_args.ZZsrc is not None:
            # Import configs from ZZsrc JSONs
            imported_dict_ls = [read_json(path) for path in pre_args.ZZsrc]
            combined_imported_dict = combine_dicts(imported_dict_ls, strict=True)

            # Record which args are going to be overridden
            if pre_args.ZZoverrides is not None:
                raw_overrides = pre_args.ZZoverrides
            else:
                raw_overrides = []
            overrides = {f"--{k}" for k in raw_overrides}

            # Option names explicitly present in cl_args. argparse accepts both
            # "--key value" and "--key=value", so strip any "=value" suffix;
            # matching only exact "--key" tokens let "--key=value" slip past the
            # override check below.
            explicit_flags = {
                tok.split("=", 1)[0] for tok in cl_args if tok.startswith("--")
            }

            attr_dict = cls.get_attr_dict()
            added_args = []
            for k, v in combined_imported_dict.items():
                formatted_k = f"--{k}"
                # Imported args not listed in ZZoverrides must not also be given
                # explicitly on the command line.
                if formatted_k in explicit_flags and formatted_k not in overrides:
                    raise RuntimeError(f"Attempting to override {formatted_k}")

                if cls._is_store_true_arg(attr_dict[k]):
                    # store_true flags take no value token: emit the bare flag only
                    # when the imported value is truthy and the key is not being
                    # overridden (so the command line alone decides it).
                    if v and k not in raw_overrides:
                        added_args.append(formatted_k)
                else:
                    # These tokens precede cl_args; presumably the option uses
                    # argparse's default "store" action, where the last occurrence
                    # wins, so an allowed command-line override takes precedence.
                    added_args.extend([formatted_k, str(v)])
            submitted_args = added_args + cl_args
        else:
            assert pre_args.ZZoverrides is None, "--ZZoverrides requires --ZZsrc"
            submitted_args = cl_args
        update_parser(
            parser=parser,
            class_with_attributes=cls,
        )
        result, _ = read_parser(
            parser=parser,
            class_with_attributes=cls,
            skip_non_class_attributes=["ZZsrc", "ZZoverrides"],
            args=submitted_args,
        )
        assert isinstance(result, cls)
        return result
Example #3
0
 def collate_fn(cls, batch):
     """Collate ``{"data_row", "metadata"}`` examples into ``(cls.Batch, leftovers)``.

     Each data-row attribute (the keys of the first row's ``to_dict()``) is
     gathered across the batch and passed to ``flat_collate_fn``; metadata is
     collated by ``metadata_collate_fn``. Both results are merged with
     ``combine_dicts``. Every annotated field of ``cls.Batch`` is popped from the
     merged dict (fields annotated ``torch.FloatTensor`` are cast with
     ``.float()``), and whatever remains is returned alongside the batch.

     Raises:
         TypeError: if the first element of ``batch`` is not a Mapping.
     """
     first = batch[0]
     if not isinstance(first, Mapping):
         raise TypeError(f"Unknown type for collate_fn {type(first)}")
     assert set(first.keys()) == {"data_row", "metadata"}

     rows = [example["data_row"] for example in batch]
     metas = [example["metadata"] for example in batch]

     collated_rows = {}
     for name in rows[0].to_dict():
         collated_rows[name] = flat_collate_fn([getattr(row, name) for row in rows])

     merged = combine_dicts([collated_rows, metadata_collate_fn(metas)])

     batch_kwargs = {}
     for field, field_type in cls.Batch.get_annotations().items():
         value = merged.pop(field)
         if field_type == torch.FloatTensor:
             # Keep float fields as float32 rather than whatever dtype collation produced.
             value = value.float()
         batch_kwargs[field] = value
     # Entries not consumed by Batch fields are handed back to the caller.
     return cls.Batch(**batch_kwargs), merged
def test_combine_dicts_with_overlapping_key_sets():
    """combine_dicts raises RuntimeError when inputs share a key, even with equal values."""
    dicts_sharing_a_key = [{"k1": "v1"}, {"k1": "v1"}]
    with pytest.raises(RuntimeError):
        py_datastructures.combine_dicts(dicts_sharing_a_key)