def from_params(cls: Type[T], params: Params, **extras) -> T:
    """Build an instance of ``cls`` (or of one of its registered subclasses) from ``params``.

    :param params: Configuration for the object; passed through
        ``normalize_params`` before use.
    :param extras: Additional keyword values forwarded to the kwargs-building
        helpers.
    :return: The constructed instance.
    :raises ConfigurationError: If ``cls`` is a base registrable class with
        no entry in ``Registrable._register``.
    """
    # Function-scope import; presumably avoids a circular import between the
    # two config modules — TODO confirm.
    from dh_segment_torch.config.registrable import Registrable

    params = normalize_params(params)

    # Plain (non base-registrable) classes are instantiated directly.
    if not is_base_registrable(cls):
        if cls.__init__ == object.__init__:
            # No custom constructor: nothing may be configured at all.
            kwargs: Dict[str, Any] = {}
            params.assert_empty(cls.__name__)
        else:
            kwargs = create_kwargs(cls.__init__, cls, params, **extras)
        return cls(**kwargs)

    if Registrable._register.get(cls) is None:
        raise ConfigurationError(
            "Tried to construct an abstract registrable class that has nothing registered."
        )

    registrable_cls = cast(Type[Registrable], cls)
    available = registrable_cls.get_available()
    # When no "type" key is given, the first available choice is used.
    subclass_type = params.pop("type", available[0])
    subclass, constructor = registrable_cls.get(subclass_type)

    if has_from_params(subclass):
        kwargs = create_kwargs(constructor, subclass, params, **extras)
    else:
        # The subclass does not take part in from_params: hand it the raw
        # params together with whichever extras get_extras selects.
        extras = get_extras(subclass, extras)
        kwargs = {**params, **extras}

    build = registrable_cls.get_constructor(subclass_type)
    return build(**kwargs)
def test_basic_load_dataset(self):
    """Load an ``image_csv`` and a ``patches_folder`` dataset and check one sample of each."""
    params = {
        "type": "image_csv",
        "csv_filename": self.FIXTURES_ROOT / "dataset" / "multiclass" / "train.csv",
        "base_dir": self.FIXTURES_ROOT / "dataset" / "multiclass",
    }
    dataset = Dataset.from_params(Params(params))
    assert isinstance(dataset, Dataset)
    assert isinstance(dataset, ImageDataset)
    assert len(dataset) == 10
    assert dataset.num_images == 10

    image_sample = dataset[0]
    assert "image" in image_sample
    assert "label" in image_sample
    assert "shape" in image_sample

    patches_size = 32
    params = {
        "type": "patches_folder",
        "folder": self.FIXTURES_ROOT / "dataset" / "multilabel",
        "patches_size": patches_size,
    }
    dataset = Dataset.from_params(Params(params))
    assert isinstance(dataset, IterableDataset)
    assert isinstance(dataset, PatchesDataset)
    assert dataset.num_images == 15

    # Take the first patch explicitly. A `for ...: break` loop would leave
    # the sample of the first dataset bound if this dataset yielded nothing,
    # and the assertions below would then pass on stale data; next() fails
    # loudly with StopIteration instead.
    patch_sample = next(iter(dataset))
    assert "image" in patch_sample
    assert "label" in patch_sample
    assert "shape" in patch_sample
    assert patch_sample["image"].shape[1] == patches_size
def normalize_params(params: Union[Params, str, Dict]):
    """Coerce shorthand configuration into a ``Params`` object.

    A bare string is treated as a type name and becomes
    ``Params({"type": <string>})``; a dict is wrapped in ``Params``. Any
    other value is returned unchanged.
    """
    if isinstance(params, str):
        params = Params({"type": params})
    # Checked after the string conversion on purpose, so the result of that
    # conversion goes through the same dict test as any other input.
    return Params(params) if isinstance(params, dict) else params
def test_patches_transform(self):
    """Build a ``patches_csv`` dataset with pre/post patch transforms and check its first sample."""
    patches_size = 32
    multiclass_dir = self.FIXTURES_ROOT / "dataset" / "multiclass"

    pre_patches_compose = {
        "transforms": [
            {"type": "fixed_size_resize", "output_size": 1e5},
        ]
    }
    # Transforms given as bare strings (type names only).
    post_patches_compose = {"transforms": ["gaussian_blur", "flip"]}
    assign_transform = {
        "type": "assign_label",
        "colors_array": [[0, 0, 0], [255, 0, 0], [0, 0, 255]],
    }

    params = {
        "type": "patches_csv",
        "csv_filename": multiclass_dir / "train.csv",
        "base_dir": multiclass_dir,
        "patches_size": patches_size,
        "pre_patches_compose": pre_patches_compose,
        "post_patches_compose": post_patches_compose,
        "assign_transform": assign_transform,
    }
    dataset = Dataset.from_params(Params(params))

    # Only the first yielded sample is inspected.
    for sample in dataset:
        break

    for expected_key in ("image", "label", "shape"):
        assert expected_key in sample
def test_transform_dataset(self):
    """Check an ``image_csv`` dataset with transforms, and that a malformed ``compose`` is rejected."""
    multiclass_dir = self.FIXTURES_ROOT / "dataset" / "multiclass"

    params = {
        "type": "image_csv",
        "csv_filename": multiclass_dir / "train.csv",
        "base_dir": multiclass_dir,
        "compose": {
            "transforms": [
                {"type": "fixed_size_resize", "output_size": 1e5},
                "gaussian_blur",
                "flip",
            ]
        },
        "assign_transform": {
            "type": "assign_label",
            "colors_array": [[0, 0, 0], [255, 0, 0], [0, 0, 255]],
        },
    }
    dataset = Dataset.from_params(Params(params))
    sample = dataset[0]
    assert "image" in sample
    assert "label" in sample
    assert "shape" in sample

    # Three colors are configured, so only the indices 0..2 may appear.
    label_values = set(sample["label"].unique().numpy().tolist())
    assert label_values.issubset({0, 1, 2})

    # "compose" is deliberately a set literal instead of a mapping.
    invalid_params = {
        "type": "image_csv",
        "csv_filename": multiclass_dir / "train.csv",
        "base_dir": multiclass_dir,
        "compose": {"blur"},
    }
    # Keep only the call under test inside the raises-context, so that a
    # TypeError raised while building the params cannot satisfy the check.
    with pytest.raises(TypeError):
        Dataset.from_params(Params(invalid_params))
def create_kwargs(constructor: Callable[..., T], cls: Type[T], params: Params,
                  **extras) -> Dict[str, Any]:
    """Assemble the keyword arguments needed to call ``constructor``.

    Every parameter reported by ``infer_params(constructor, cls)`` is popped
    from ``params`` and constructed, except ``self`` and a ``**kwargs``-style
    parameter.

    :param constructor: Callable whose parameters are to be filled.
    :param cls: Class being built; its name is used when asserting that
        ``params`` has been fully consumed.
    :param params: Configuration to pop values from (mutated).
    :param extras: Additional values forwarded to ``pop_construct_param``.
    :return: Mapping from parameter name to constructed value.
    """
    # typing.Any, not the builtin function `any`.
    kwargs: Dict[str, Any] = {}
    parameters = infer_params(constructor, cls)
    for param_name, param in parameters.items():
        if param_name == "self" or param.kind == param.VAR_KEYWORD:
            continue
        constructed_param = pop_construct_param(param_name, param, params,
                                                **extras)
        # A value that is the parameter's own default object is left out, so
        # the constructor falls back to that default by itself.
        if constructed_param is not param.default:
            kwargs[param_name] = constructed_param
    # Whatever is left in params matched no constructor parameter.
    params.assert_empty(cls.__name__)
    return kwargs
def pop_construct_param(param_name: str, param: inspect.Parameter,
                        params: Params, **extras):
    """Pop ``param_name`` from ``params`` and construct its value.

    :param param_name: Name of the constructor parameter being filled.
    :param param: Signature entry of that parameter (default, annotation).
    :param params: Configuration to pop the value from (mutated).
    :param extras: Out-of-band values; an extra named ``param_name`` is
        returned as is when ``params`` has no such key.
    :return: The extra, ``None``, or the result of ``construct_param``.
    """
    if param_name in extras:
        if param_name not in params:
            return extras[param_name]
        # Present in both places: the params value wins, but tell the user.
        logger.warning(
            f"Parameter {param_name} was found in extras, which is not required, "
            "but was also found in params. Using params value, but this can "
            "lead to unexpected results")

    # `empty` is a sentinel: compare by identity so that defaults with an
    # unusual __ne__ (e.g. array-like objects) cannot break the test.
    optional = param.default is not param.empty
    if optional:
        popped_params = params.pop(param_name, param.default)
    else:
        # Required parameter: no fallback is supplied to pop().
        popped_params = params.pop(param_name)

    # TODO find cornercase when needed (an explicit None for a Lazy-typed
    # parameter might have to be wrapped in a Lazy instead).
    if popped_params is None:
        return None
    return construct_param(param_name, param, popped_params, **extras)
def test_assign_multilabel(self):
    """Build an ``image_csv`` dataset using an ``assign_multilabel`` transform and check its first sample."""
    multilabel_dir = self.FIXTURES_ROOT / "dataset" / "multilabel"

    compose = {
        "transforms": [
            {"type": "fixed_size_resize", "output_size": 1e5},
            {"type": "gaussian_blur"},
        ]
    }
    # Four colors, each paired with a two-element one-hot label row.
    assign_transform = {
        "type": "assign_multilabel",
        "colors_array": [
            [0, 0, 0],
            [255, 0, 0],
            [0, 0, 255],
            [128, 0, 128],
        ],
        "onehot_label_array": [[0, 0], [1, 0], [0, 1], [1, 1]],
    }

    params = {
        "type": "image_csv",
        "csv_filename": multilabel_dir / "train.csv",
        "base_dir": multilabel_dir,
        "compose": compose,
        "assign_transform": assign_transform,
    }
    dataset = Dataset.from_params(Params(params))

    sample = dataset[0]
    for expected_key in ("image", "label", "shape"):
        assert expected_key in sample
type=str, nargs="?", default=None, help="trainer checkpoint to resume from", ) parser.add_argument( "--model-checkpoint", type=str, nargs="?", default=None, help="model checkpoint to resume from", ) if __name__ == "__main__": args = parser.parse_args() params = Params.from_file(args.config) model_out_dir = params.get("model_out_dir", "./model") os.makedirs(model_out_dir, exist_ok=True) params.to_file(os.path.join(model_out_dir, "config.json")) exp_name = params.pop("experiment_name", "dhSegment_experiment") config = params.as_dict() ex = Experiment(exp_name) ex.add_config(config) trainer = Trainer.from_params(params, exp_name=exp_name, config=deepcopy(config))
def _as_dummy_parameter(annotation) -> inspect.Parameter:
    """Wrap a bare annotation in a placeholder parameter for recursive construction."""
    return inspect.Parameter("dummy",
                             kind=inspect.Parameter.VAR_KEYWORD,
                             annotation=annotation)


def construct_param(param_name: str, param: inspect.Parameter, params: Params,
                    **extras):
    """Turn the raw configuration value ``params`` into the value for ``param``.

    The branch taken depends on the parameter's type: ``from_params``-capable
    classes, ``Lazy`` wrappers, the primitives int/bool/float/str, and the
    mapping, tuple, set, sequence and ``Union`` generics (handled by
    recursing on their element types). Anything else is returned as is.

    :param param_name: Name used in error messages; element positions/keys
        are appended as ``name.index`` / ``name.key`` when recursing.
    :param param: Signature entry providing the annotation and the default.
    :param params: Raw value taken from the configuration.
    :param extras: Out-of-band values passed on to nested constructions.
    :return: The constructed value.
    :raises TypeError: If ``params`` does not fit a primitive or container type.
    :raises ConfigurationError: If no member of a ``Union`` could be built.
    """
    param_type = infer_type(param)
    origin, args = get_origin_args(param.annotation)

    if is_optional(param) and params is None:
        return params

    if has_from_params(param_type):
        if params is param.default:
            return param.default
        elif params is not None:
            params = normalize_params(params)
        param_type_as_from_params = cast(Type[FromParams], param_type)
        sub_extras = get_extras(param_type_as_from_params, extras)
        return param_type_as_from_params.from_params(params, **sub_extras)

    elif origin == Lazy:
        if params is param.default:
            return Lazy(lambda **kwargs: param.default)
        sub_param_type_as_from_params = cast(Type[FromParams], args[0])
        sub_extras = get_extras(sub_param_type_as_from_params, extras)
        # Construction is deferred; params are copied on every call so that
        # repeated calls each consume their own copy.
        return Lazy(lambda **kwargs: sub_param_type_as_from_params.from_params(
            params=deepcopy(params), **{
                **sub_extras,
                **kwargs
            }))

    elif param_type in {int, bool}:
        if type(params) in {int, bool}:
            return param_type(params)
        else:
            raise TypeError(
                f"Expected {param_name} to be a {param_type.__name__}.")

    elif param_type == float:
        if type(params) in {int, float}:
            return params
        else:
            raise TypeError(
                f"Expected {param_name} to be either a float or a int.")

    elif param_type == str:
        if type(params) == str or isinstance(params, Path):
            return str(params)
        else:
            # Name the parameter, consistently with the other branches
            # (the message previously interpolated the offending value).
            raise TypeError(f"Expected {param_name} to be a string.")

    elif (origin in {collections.abc.Mapping, Mapping, Dict, dict}
          and len(args) == 2 and can_construct(args[-1])):
        if not isinstance(params, collections.abc.Mapping):
            raise TypeError(f"Expected {param_name} to be a mapping.")
        value_class_as_param = _as_dummy_parameter(args[-1])
        new_dict = {}
        for key_param, value_params in params.items():
            new_dict[key_param] = construct_param(
                f"{param_name}.{key_param}",
                value_class_as_param,
                value_params,
                **extras,
            )
        return new_dict

    elif origin in {Tuple, tuple} and all(can_construct(arg) for arg in args):
        if not isinstance(params, collections.abc.Sequence):
            raise TypeError(f"Expected {param_name} to be a sequence.")
        # iterate_not_string presumably yields the elements while rejecting
        # plain strings — defined elsewhere, TODO confirm.
        values = list(iterate_not_string(params))
        if len(args) == 2 and args[1] is Ellipsis:
            # Variadic Tuple[X, ...]: every element has type X. Zipping the
            # values against the two-element args would stop after two.
            value_classes = [args[0]] * len(values)
        else:
            value_classes = list(args)
        new_tuple = []
        for i, (value_class, value_params) in enumerate(
                zip(value_classes, values)):
            new_tuple.append(
                construct_param(f"{param_name}.{i}",
                                _as_dummy_parameter(value_class),
                                value_params, **extras))
        return tuple(new_tuple)

    elif origin in {Set, set} and len(args) == 1 and can_construct(args[0]):
        if not isinstance(params, collections.abc.Set) and not isinstance(
                params, collections.abc.Sequence):
            raise TypeError(
                f"Expected {param_name} to be a set or a sequence.")
        value_class_as_param = _as_dummy_parameter(args[0])
        new_set = set()
        for i, value_params in enumerate(iterate_not_string(params)):
            new_set.add(
                construct_param(f"{param_name}.{i}", value_class_as_param,
                                value_params, **extras))
        return new_set

    elif (origin in {collections.abc.Iterable, Iterable, List, list}
          and len(args) == 1 and can_construct(args[0])):
        if not isinstance(params, collections.abc.Sequence):
            raise TypeError(f"Expected {param_name} to be a sequence.")
        value_class_as_param = _as_dummy_parameter(args[0])
        new_list = []
        for i, value_params in enumerate(iterate_not_string(params)):
            new_list.append(
                construct_param(f"{param_name}.{i}", value_class_as_param,
                                value_params, **extras))
        return new_list

    elif origin == Union:
        # Try the member types in declaration order; the first one that can
        # be constructed wins. A failed attempt may have consumed part of
        # params, so they are restored from a backup before the next try.
        backup_params = deepcopy(params)
        all_errors = []
        for value_class in args:
            value_class_as_param = _as_dummy_parameter(value_class)
            try:
                return construct_param(param_name, value_class_as_param,
                                       params, **extras)
            except (
                    AttributeError,
                    ValueError,
                    TypeError,
                    ConfigurationError,
                    RegistrableError,
            ) as e:
                params = deepcopy(backup_params)
                all_errors.append(e)
        raise ConfigurationError(
            f"Failed to construct {param_name} with type {param_type} and {params}, got errors: {all_errors}"
        )

    else:
        logger.warning(
            f"The params {str(params)} were not matched, returning them as is."
        )
        if isinstance(params, Params):
            return params.as_dict()
        else:
            return params