Пример #1
0
    def from_params(cls: Type[T], params: Params, **extras) -> T:
        """Build an instance of ``cls`` (or of a registered implementation) from ``params``.

        ``params`` is first passed through ``normalize_params``. Then:

        * if ``cls`` is a base registrable, the implementation is selected by the
          ``"type"`` entry popped from ``params`` (falling back to the first name
          returned by ``get_available()``) and built through the constructor
          registered under that name;
        * otherwise ``cls`` itself is instantiated, with keyword arguments derived
          from its ``__init__`` via ``create_kwargs`` (or none at all when the
          class does not define its own ``__init__``).

        :param params: configuration to consume.
        :param extras: extra keyword arguments handed to ``create_kwargs`` /
            ``get_extras``.
        :return: the constructed object.
        :raises ConfigurationError: if ``cls`` is a base registrable for which
            nothing has been registered.
        """
        from dh_segment_torch.config.registrable import Registrable

        params = normalize_params(params)

        if not is_base_registrable(cls):
            # Concrete class: build it directly from its own ``__init__``.
            if cls.__init__ == object.__init__:
                # No custom initializer, hence no parameter may be supplied.
                init_kwargs: Dict[str, Any] = {}
                params.assert_empty(cls.__name__)
            else:
                init_kwargs = create_kwargs(cls.__init__, cls, params, **extras)
            return cls(**init_kwargs)

        # Base registrable: dispatch to one of the registered implementations.
        if Registrable._register.get(cls) is None:
            raise ConfigurationError(
                "Tried to construct an abstract registrable class that has nothing registered."
            )
        registrable_cls = cast(Type[Registrable], cls)
        available = registrable_cls.get_available()

        # When no explicit "type" is given, the first available name is used.
        type_name = params.pop("type", available[0])

        subclass, constructor = registrable_cls.get(type_name)
        if has_from_params(subclass):
            init_kwargs = create_kwargs(constructor, subclass, params, **extras)
        else:
            # The subclass does not take part in the from_params machinery:
            # hand it the raw params merged with the extras selected for it.
            extras = get_extras(subclass, extras)
            init_kwargs = {**params, **extras}
        build = registrable_cls.get_constructor(type_name)
        return build(**init_kwargs)
Пример #2
0
    def test_basic_load_dataset(self):
        """Build an ``image_csv`` and a ``patches_folder`` dataset from params and check their basics."""
        fixtures_dir = self.FIXTURES_ROOT / "dataset"
        multiclass_dir = fixtures_dir / "multiclass"
        sample_keys = ("image", "label", "shape")

        # --- CSV-backed image dataset -------------------------------------
        csv_config = Params({
            "type": "image_csv",
            "csv_filename": multiclass_dir / "train.csv",
            "base_dir": multiclass_dir,
        })
        dataset = Dataset.from_params(csv_config)
        for expected_class in (Dataset, ImageDataset):
            assert isinstance(dataset, expected_class)
        assert len(dataset) == 10
        assert dataset.num_images == 10
        sample = dataset[0]
        for key in sample_keys:
            assert key in sample

        # --- Folder-backed patches dataset --------------------------------
        patches_size = 32
        patches_config = Params({
            "type": "patches_folder",
            "folder": fixtures_dir / "multilabel",
            "patches_size": patches_size,
        })
        dataset = Dataset.from_params(patches_config)
        for expected_class in (IterableDataset, PatchesDataset):
            assert isinstance(dataset, expected_class)
        assert dataset.num_images == 15
        # Only the first yielded sample is inspected.
        for sample in dataset:
            break
        for key in sample_keys:
            assert key in sample
        assert sample["image"].shape[1] == patches_size
Пример #3
0
def normalize_params(params: Union[Params, str, Dict]):
    """Coerce the shorthand forms of a configuration into ``Params``.

    A bare string is treated as a type name and wrapped as ``{"type": <string>}``;
    a plain ``dict`` is wrapped in ``Params``. Anything else is returned untouched.
    """
    # Step 1: a lone string is shorthand for "this type, no other settings".
    expanded = Params({"type": params}) if isinstance(params, str) else params
    # Step 2: plain dictionaries are promoted to ``Params``.
    return Params(expanded) if isinstance(expanded, dict) else expanded
Пример #4
0
    def test_patches_transform(self):
        """Build a ``patches_csv`` dataset with pre/post patch transforms and check the first sample's keys."""
        patches_size = 32
        multiclass_dir = self.FIXTURES_ROOT / "dataset" / "multiclass"

        resize = {"type": "fixed_size_resize", "output_size": 1e5}
        label_assignment = {
            "type": "assign_label",
            "colors_array": [[0, 0, 0], [255, 0, 0], [0, 0, 255]],
        }
        config = {
            "type": "patches_csv",
            "csv_filename": multiclass_dir / "train.csv",
            "base_dir": multiclass_dir,
            "patches_size": patches_size,
            "pre_patches_compose": {"transforms": [resize]},
            # Transforms given as bare strings (type name only).
            "post_patches_compose": {"transforms": ["gaussian_blur", "flip"]},
            "assign_transform": label_assignment,
        }

        dataset = Dataset.from_params(Params(config))
        # Only the first yielded sample is inspected.
        for sample in dataset:
            break
        for key in ("image", "label", "shape"):
            assert key in sample
Пример #5
0
    def test_transform_dataset(self):
        """Build an ``image_csv`` dataset with a transform pipeline and check a sample.

        Also checks that passing a ``set`` as the ``compose`` configuration is
        rejected with a ``TypeError``.
        """
        multiclass_dir = self.FIXTURES_ROOT / "dataset" / "multiclass"
        image_path = multiclass_dir / "images" / "image_001.png"
        first_image = cv2.imread(str(image_path))

        transforms = [
            {"type": "fixed_size_resize", "output_size": 1e5},
            "gaussian_blur",
            "flip",
        ]
        config = {
            "type": "image_csv",
            "csv_filename": multiclass_dir / "train.csv",
            "base_dir": multiclass_dir,
            "compose": {"transforms": transforms},
            "assign_transform": {
                "type": "assign_label",
                "colors_array": [[0, 0, 0], [255, 0, 0], [0, 0, 255]],
            },
        }

        dataset = Dataset.from_params(Params(config))
        sample = dataset[0]

        for key in ("image", "label", "shape"):
            assert key in sample

        # Every label value must be one of the three class indices.
        label_values = set(sample["label"].unique().numpy().tolist())
        assert len(label_values.difference([0, 1, 2])) == 0

        with pytest.raises(TypeError):
            invalid_config = {
                "type": "image_csv",
                "csv_filename": multiclass_dir / "train.csv",
                "base_dir": multiclass_dir,
                # A set is not a valid compose configuration.
                "compose": {"blur"},
            }

            Dataset.from_params(Params(invalid_config))
Пример #6
0
def create_kwargs(constructor: Callable[..., T], cls: Type[T], params: Params,
                  **extras) -> Dict[str, Any]:
    """Assemble the keyword arguments needed to call ``constructor``.

    The parameters reported by ``infer_params(constructor, cls)`` are visited in
    order; each one (except ``self`` and a ``**kwargs`` catch-all) is resolved
    through ``pop_construct_param``. Values that are the very same object as the
    parameter's declared default are left out, so the constructor applies its
    own default.

    :param constructor: callable whose parameters are to be filled.
    :param cls: class being built; passed to ``infer_params`` and its name is
        passed to ``params.assert_empty``.
    :param params: configuration the values are popped from.
    :param extras: additional values forwarded to ``pop_construct_param``.
    :return: mapping of parameter name to constructed value.
    """
    # ``typing.Any`` (not the builtin function ``any``) is the intended value type.
    kwargs: Dict[str, Any] = {}

    for param_name, param in infer_params(constructor, cls).items():
        if param_name == "self" or param.kind == param.VAR_KEYWORD:
            continue

        constructed_param = pop_construct_param(param_name, param, params,
                                                **extras)

        # Identity check: only skip when the default object itself came back.
        if constructed_param is not param.default:
            kwargs[param_name] = constructed_param

    # Called once every parameter has been popped (presumably rejects leftover
    # keys; the behaviour lives in ``Params.assert_empty``).
    params.assert_empty(cls.__name__)
    return kwargs
Пример #7
0
def pop_construct_param(param_name: str, param: inspect.Parameter,
                        params: Params, **extras):
    """Pop the raw value for ``param_name`` from ``params`` and construct it.

    Resolution order:

    1. If the name is supplied in ``extras`` and absent from ``params``, the
       extra is returned as is.
    2. Otherwise the value is popped from ``params``; when the parameter
       declares a default, that default is used if the key is missing.
    3. A popped value of ``None`` is returned directly; anything else is handed
       to ``construct_param``.

    :param param_name: name of the constructor parameter.
    :param param: its ``inspect.Parameter`` description.
    :param params: configuration the raw value is popped from.
    :param extras: already-built values that may satisfy the parameter.
    :return: the value to pass to the constructor for this parameter.
    """
    if param_name in extras:
        if param_name not in params:
            return extras[param_name]
        logger.warning(
            f"Parameter {param_name} was found in extras, which is not required, "
            "but was also found in params. Using params value, but this can "
            "lead to unexpected results")

    # Identity rather than ``!=``: a default may define rich comparisons
    # (e.g. array-likes) whose result is not a plain, truth-testable bool.
    optional = param.default is not param.empty
    if optional:
        popped_params = params.pop(param_name, param.default)
    else:
        popped_params = params.pop(param_name)

    # An explicit (or defaulted) None needs no construction.
    if popped_params is None:
        return None

    return construct_param(param_name, param, popped_params, **extras)
Пример #8
0
    def test_assign_multilabel(self):
        """Build an ``image_csv`` dataset using ``assign_multilabel`` and check the first sample's keys."""
        multilabel_dir = self.FIXTURES_ROOT / "dataset" / "multilabel"

        transforms = [
            {"type": "fixed_size_resize", "output_size": 1e5},
            {"type": "gaussian_blur"},
        ]
        # One colour per row, paired with the one-hot label row at the same index.
        multilabel_assignment = {
            "type": "assign_multilabel",
            "colors_array": [
                [0, 0, 0],
                [255, 0, 0],
                [0, 0, 255],
                [128, 0, 128],
            ],
            "onehot_label_array": [[0, 0], [1, 0], [0, 1], [1, 1]],
        }
        config = {
            "type": "image_csv",
            "csv_filename": multilabel_dir / "train.csv",
            "base_dir": multilabel_dir,
            "compose": {"transforms": transforms},
            "assign_transform": multilabel_assignment,
        }

        dataset = Dataset.from_params(Params(config))
        sample = dataset[0]

        for key in ("image", "label", "shape"):
            assert key in sample
Пример #9
0
    type=str,
    nargs="?",
    default=None,
    help="trainer checkpoint to resume from",
)
# Optional flag taking at most one value (nargs="?"); None when not supplied.
parser.add_argument(
    "--model-checkpoint",
    type=str,
    nargs="?",
    default=None,
    help="model checkpoint to resume from",
)

if __name__ == "__main__":
    # Load the experiment configuration from the file given on the command line.
    args = parser.parse_args()
    params = Params.from_file(args.config)

    # Make sure the output directory exists and store a copy of the
    # configuration next to the model outputs.
    model_out_dir = params.get("model_out_dir", "./model")
    os.makedirs(model_out_dir, exist_ok=True)
    params.to_file(os.path.join(model_out_dir, "config.json"))

    # The experiment name is popped before the snapshot is taken, so it is
    # presumably absent from `config` below.
    exp_name = params.pop("experiment_name", "dhSegment_experiment")
    config = params.as_dict()

    # Register the configuration with the experiment object.
    ex = Experiment(exp_name)
    ex.add_config(config)

    # The trainer gets its own deep copy of the configuration, independent of
    # the `config` object registered above.
    trainer = Trainer.from_params(params,
                                  exp_name=exp_name,
                                  config=deepcopy(config))
Пример #10
0
def _dummy_param_for(annotation) -> inspect.Parameter:
    """Wrap a bare annotation in a throwaway parameter so it can be passed back to ``construct_param``."""
    return inspect.Parameter(
        "dummy",
        kind=inspect.Parameter.VAR_KEYWORD,
        annotation=annotation)


def construct_param(param_name: str, param: inspect.Parameter, params: Params,
                    **extras):
    """Turn the raw configuration value ``params`` into the object ``param`` expects.

    The target type is taken from ``infer_type(param)`` and, for generic
    annotations, from the origin/arguments of ``param.annotation``. Types
    exposing ``from_params`` are built recursively, ``Lazy[...]`` annotations
    are wrapped in a deferred constructor, primitives are type-checked, and
    mapping/tuple/set/list/``Union`` annotations are handled element by element.
    Values whose annotation matches none of these are returned unchanged
    (``Params`` instances are converted with ``as_dict()``) after a warning.

    :param param_name: name of the parameter, used in error messages; nested
        elements are reported as ``name.key`` / ``name.index``.
    :param param: ``inspect.Parameter`` describing the expected type and default.
    :param params: raw value popped from the configuration.
    :param extras: extra keyword arguments forwarded to nested constructions.
    :return: the constructed value.
    :raises TypeError: if ``params`` does not have the shape required by the
        annotation (wrong primitive type, not a mapping/sequence, ...).
    :raises ConfigurationError: if no member of a ``Union`` annotation could be
        constructed.
    """

    param_type = infer_type(param)

    origin, args = get_origin_args(param.annotation)

    if is_optional(param) and params is None:
        return params

    if has_from_params(param_type):
        if params is param.default:
            return param.default
        elif params is not None:
            params = normalize_params(params)
            param_type_as_from_params = cast(Type[FromParams], param_type)
            sub_extras = get_extras(param_type_as_from_params, extras)
            return param_type_as_from_params.from_params(params, **sub_extras)
        # ``params`` is None here: control leaves the chain and the function
        # returns None implicitly.
    elif origin == Lazy:
        if params is param.default:
            return Lazy(lambda **kwargs: param.default)
        sub_param_type_as_from_params = cast(Type[FromParams], args[0])
        sub_extras = get_extras(sub_param_type_as_from_params, extras)
        # Each invocation works on its own deep copy of the params; keyword
        # arguments given at call time take precedence over the extras.
        return Lazy(lambda **kwargs: sub_param_type_as_from_params.from_params(
            params=deepcopy(params), **{
                **sub_extras,
                **kwargs
            }))

    elif param_type in {int, bool}:
        if type(params) in {int, bool}:
            return param_type(params)
        else:
            raise TypeError(
                f"Expected {param_name} to be a {param_type.__name__}.")
    elif param_type == float:
        if type(params) in {int, float}:
            return params
        else:
            raise TypeError(
                f"Expected {param_name} to be either a float or a int.")
    elif param_type == str:
        if type(params) is str or isinstance(params, Path):
            return str(params)
        else:
            # Report the parameter name, like every other branch does.
            raise TypeError(f"Expected {param_name} to be a string.")
    elif (origin in {collections.abc.Mapping, Mapping, Dict, dict}
          and len(args) == 2 and can_construct(args[-1])):
        if not isinstance(params, collections.abc.Mapping):
            raise TypeError(f"Expected {param_name} to be a mapping.")
        value_class_as_param = _dummy_param_for(args[-1])
        # Keys are kept as they are; only the values are constructed.
        return {
            key_param: construct_param(
                f"{param_name}.{key_param}",
                value_class_as_param,
                value_params,
                **extras,
            )
            for key_param, value_params in params.items()
        }
    elif origin in {Tuple, tuple} and all(can_construct(arg) for arg in args):
        if not isinstance(params, collections.abc.Sequence):
            raise TypeError(f"Expected {param_name} to be a sequence.")
        values = list(iterate_not_string(params))
        value_classes = list(args)
        if value_classes and value_classes[-1] is Ellipsis:
            # Variable-length tuple (``Tuple[T, ...]``): pad the type list so
            # that zipping does not cut the values off after two elements.
            value_classes.extend(
                [Ellipsis] * (len(values) - len(value_classes)))
        new_tuple = []
        prev_value_class = None
        for i, (value_class, value_params) in enumerate(
                zip(value_classes, values)):
            if value_class is Ellipsis:
                # An ellipsis repeats the element type that precedes it.
                value_class = prev_value_class
            new_tuple.append(
                construct_param(f"{param_name}.{i}",
                                _dummy_param_for(value_class), value_params,
                                **extras))
            prev_value_class = value_class
        return tuple(new_tuple)
    elif origin in {Set, set} and len(args) == 1 and can_construct(args[0]):
        if not isinstance(params, collections.abc.Set) and not isinstance(
                params, collections.abc.Sequence):
            raise TypeError(
                f"Expected {param_name} to be a set or a sequence.")
        value_class_as_param = _dummy_param_for(args[0])
        return {
            construct_param(f"{param_name}.{i}", value_class_as_param,
                            value_params, **extras)
            for i, value_params in enumerate(iterate_not_string(params))
        }
    elif (origin in {collections.abc.Iterable, Iterable, List, list}
          and len(args) == 1 and can_construct(args[0])):
        if not isinstance(params, collections.abc.Sequence):
            raise TypeError(f"Expected {param_name} to be a sequence.")
        value_class_as_param = _dummy_param_for(args[0])
        return [
            construct_param(f"{param_name}.{i}", value_class_as_param,
                            value_params, **extras)
            for i, value_params in enumerate(iterate_not_string(params))
        ]
    elif origin == Union:
        # Try every member type in declaration order; the first one that can
        # be constructed wins. A failed attempt may have consumed parts of
        # ``params``, so a pristine copy is restored before the next one.
        backup_params = deepcopy(params)
        all_errors = []
        for value_class in args:
            try:
                return construct_param(param_name,
                                       _dummy_param_for(value_class), params,
                                       **extras)
            except (
                    AttributeError,
                    ValueError,
                    TypeError,
                    ConfigurationError,
                    RegistrableError,
            ) as e:
                params = deepcopy(backup_params)
                all_errors.append(e)
        raise ConfigurationError(
            f"Failed to construct {param_name} with type {param_type} and {params}, got errors: {all_errors}"
        )
    else:
        logger.warning(
            f"The params {str(params)} were not matched, returning them as is."
        )
        if isinstance(params, Params):
            return params.as_dict()
        else:
            return params