示例#1
0
    def __post_init__(self) -> None:
        """Validate the channel configuration once the instance is built.

        Raises:
            ValueError: if ``image_channels`` or ``ground_truth_channels`` is
                empty (falsy).
        """
        # mask_channel is excluded from the populated-fields check (helper
        # semantics inferred from its name - confirm against common_util).
        optional_fields = ["mask_channel"]
        common_util.check_properties_are_not_none(self, ignore=optional_fields)

        has_image_channels = bool(self.image_channels)
        if not has_image_channels:
            raise ValueError("image_channels cannot be empty")

        has_ground_truth_channels = bool(self.ground_truth_channels)
        if not has_ground_truth_channels:
            raise ValueError("ground_truth_channels cannot be empty")
 def __post_init__(self) -> None:
     """Validate the instance and reject over-long checkpoint paths.

     Raises:
         ValueError: if any entry of ``checkpoint_paths`` is flagged by
             ``is_long_path``.
     """
     check_properties_are_not_none(self)
     # Collect every offending path so the error reports all of them at once.
     long_paths = [
         path for path in self.checkpoint_paths if is_long_path(path)
     ]
     if not long_paths:
         return
     raise ValueError(
         f"Following paths: {long_paths} are greater than {MAX_PATH_LENGTH}"
     )
示例#3
0
    def __post_init__(self) -> None:
        """Validate the fields and the mask/labels center-crop sizes.

        The size comparison is delegated to ``ml_util.check_size_matches``
        (presumably raises on a mismatch - confirm against the helper).
        """
        # Every property is passed through the populated-fields check.
        common_util.check_properties_are_not_none(self)

        # The mask and labels center crops must agree on the dimensions
        # reported by _get_matching_dimensions().
        mask_crop = self.mask_center_crop
        labels_crop = self.labels_center_crop
        dims_to_match = self._get_matching_dimensions()
        ml_util.check_size_matches(
            arg1=mask_crop,
            arg2=labels_crop,
            matching_dimensions=dims_to_match,
        )
示例#4
0
    def __post_init__(self) -> None:
        """Validate the fields and compare the image size to mask and labels.

        Size comparisons are delegated to ``ml_util.check_size_matches``
        (presumably raises on a mismatch - confirm against the helper).
        """
        # Every property is passed through the populated-fields check.
        common_util.check_properties_are_not_none(self)

        # The image is compared first against the mask, then against the
        # labels; the matching dimensions are re-queried for each comparison.
        for counterpart in (self.mask, self.labels):
            ml_util.check_size_matches(
                arg1=self.image,
                arg2=counterpart,
                matching_dimensions=self._get_matching_dimensions(),
            )
    def __post_init__(self) -> None:
        """Validate the fields and that all per-epoch sequences align.

        Raises:
            Exception: if ``train_results_per_epoch``,
                ``val_results_per_epoch`` and ``learning_rates_per_epoch`` do
                not all have the same length.
        """
        common_util.check_properties_are_not_none(self)

        lengths = (
            len(self.train_results_per_epoch),
            len(self.val_results_per_epoch),
            len(self.learning_rates_per_epoch),
        )
        # A chained ``a != b != c`` means ``a != b and b != c``: it would only
        # fire when both adjacent pairs differ and would accept e.g. (3, 3, 5).
        # Requiring a single distinct length rejects any mismatch.
        if len(set(lengths)) != 1:
            raise Exception("train_results_per_epoch must be the same length as val_results_per_epoch found "
                            "and learning_rates_per_epoch, found: train_metrics_per_epoch={}, "
                            "val_metrics_per_epoch={}, learning_rates_per_epoch={}"
                            .format(*lengths))
    def __post_init__(self) -> None:
        """Validate the fields and compare the image to prediction and labels.

        Size comparisons are delegated to ``ml_util.check_size_matches``
        (presumably raises on a mismatch - confirm against the helper).
        """
        common_util.check_properties_are_not_none(self)

        # One row per comparison against the image:
        # (other tensor, dim1, dim2, matching_dimensions).
        comparisons = (
            (self.prediction, 3, 3, []),
            (self.labels, 3, 4, [-1, -2, -3]),
        )
        for other, image_dim, other_dim, dims_to_match in comparisons:
            ml_util.check_size_matches(
                arg1=self.image,
                arg2=other,
                dim1=image_dim,
                dim2=other_dim,
                matching_dimensions=dims_to_match,
            )
示例#7
0
    def __post_init__(self) -> None:
        """Validate the fields and the train/test/val subject split.

        Raises:
            ValueError: if any subject appears in more than one split, or if
                ``allow_empty`` is falsy and the train or val split has no
                subjects.
        """
        common_util.check_properties_are_not_none(self)
        # perform dataset split validity assertions
        unique_train, unique_test, unique_val = self.unique_subjects()
        train_set, test_set, val_set = set(unique_train), set(unique_test), set(unique_val)
        # A three-way set intersection only reports subjects present in ALL
        # splits; a subject shared by just two splits would slip through.
        # Collect the overlap of every pair of splits instead.
        intersection = (train_set & test_set) | (train_set & val_set) | (test_set & val_set)

        if len(intersection) != 0:
            raise ValueError("Train, Test, and Val splits must have no intersection, found: {}".format(intersection))

        # The test split is deliberately not part of the emptiness check.
        if (not self.allow_empty) and any(len(x) == 0 for x in (unique_train, unique_val)):
            raise ValueError("train_ids({}), val_ids({}) must have at least one value"
                             .format(len(unique_train), len(unique_val)))
    def __post_init__(self) -> None:
        """Validate the channel configuration once the instance is built.

        Raises:
            ValueError: if ``image_channels`` or ``ground_truth_channels`` is
                empty (falsy), or if ``ground_truth_channels`` contains
                ``None`` while ``allow_incomplete_labels`` is falsy.
        """
        # mask_channel is excluded from the populated-fields check.
        optional_fields = ["mask_channel"]
        common_util.check_properties_are_not_none(self, ignore=optional_fields)

        has_image_channels = bool(self.image_channels)
        if not has_image_channels:
            raise ValueError("image_channels cannot be empty")

        has_ground_truth_channels = bool(self.ground_truth_channels)
        if not has_ground_truth_channels:
            raise ValueError("ground_truth_channels cannot be empty")

        # None entries are only tolerated when incomplete labels are allowed.
        has_missing_labels = self.ground_truth_channels.count(None) > 0
        if has_missing_labels and not self.allow_incomplete_labels:
            raise ValueError("all ground_truth_channels must be provided")
示例#9
0
    def __post_init__(self) -> None:
        """Validate that model_outputs and labels agree field by field.

        Raises:
            ValueError: if the ``data`` lengths differ, if ``batch_sizes``,
                ``sorted_indices`` or ``unsorted_indices`` are not equal, or
                if ``subject_ids`` is provided with an unexpected length.
        """
        # subject_ids is excluded from the populated-fields check.
        check_properties_are_not_none(self, ignore=["subject_ids"])

        same_length = len(self.model_outputs.data) == len(self.labels.data)
        if not same_length:
            raise ValueError(
                "model_outputs and labels must have the same length, "
                f"found {len(self.model_outputs.data)} and {len(self.labels.data)}"
            )

        same_batch_sizes = torch.equal(self.model_outputs.batch_sizes,
                                       self.labels.batch_sizes)
        if not same_batch_sizes:
            raise ValueError(
                "batch_sizes for model_outputs and labels must be equal, "
                f"found {self.model_outputs.batch_sizes} and {self.labels.batch_sizes}"
            )

        same_sorted = torch.equal(self.model_outputs.sorted_indices,
                                  self.labels.sorted_indices)
        if not same_sorted:
            raise ValueError(
                "sorted_indices for model_outputs and labels must be equal, "
                f"found {self.model_outputs.sorted_indices} and {self.labels.sorted_indices}"
            )

        same_unsorted = torch.equal(self.model_outputs.unsorted_indices,
                                    self.labels.unsorted_indices)
        if not same_unsorted:
            raise ValueError(
                "unsorted_indices for model_outputs and labels must be equal, "
                f"found {self.model_outputs.unsorted_indices} and {self.labels.unsorted_indices}"
            )

        # The largest entry of labels.batch_sizes is the expected number of
        # subject ids; it is computed even when subject_ids is absent.
        _expected_subjects = self.labels.batch_sizes.max().item()
        if self.subject_ids is None:
            return
        if len(self.subject_ids) != _expected_subjects:
            raise ValueError(
                f"expected {_expected_subjects} subject_ids but found {len(self.subject_ids)}"
            )
示例#10
0
    def __post_init__(self) -> None:
        """Validate the fields and the train/test/val split.

        Raises:
            ValueError: if a subject appears in more than one split, if
                ``group_column`` is set and a group value appears in more than
                one split, or if ``allow_empty`` is falsy and the train or val
                split has no subjects.
        """
        common_util.check_properties_are_not_none(self)

        def pairwise_intersection(*collections: Iterable) -> Set:
            """Return every element found in at least two of the collections."""
            shared: Set = set()
            as_sets = [set(collection) for collection in collections]
            for left, right in combinations(as_sets, 2):
                shared.update(left & right)
            return shared

        # Subjects must not be shared between any two splits.
        unique_train, unique_test, unique_val = self.unique_subjects()
        intersection = pairwise_intersection(
            unique_train, unique_test, unique_val)
        if len(intersection) != 0:
            message = "Train, Test, and Val splits must have no intersection, found: {}"
            raise ValueError(message.format(intersection))

        # When a group column is configured, its values must not be shared
        # between splits either.
        if self.group_column is not None:
            column = self.group_column
            groups_train = self.train[column].unique()
            groups_test = self.test[column].unique()
            groups_val = self.val[column].unique()
            group_intersection = pairwise_intersection(
                groups_train, groups_test, groups_val)
            if len(group_intersection) != 0:
                message = "Train, Test, and Val splits must have no intersecting groups, found: {}"
                raise ValueError(message.format(group_intersection))

        # The test split is deliberately not part of the emptiness check.
        if not self.allow_empty:
            if len(unique_train) == 0 or len(unique_val) == 0:
                message = "train_ids({}), val_ids({}) must have at least one value"
                raise ValueError(
                    message.format(len(unique_train), len(unique_val)))
 def _validate(self) -> None:
     """Validate the fields and require at least one checkpoints root.

     Raises:
         ValueError: if ``checkpoints_roots`` has length zero.
     """
     check_properties_are_not_none(self)
     root_count = len(self.checkpoints_roots)
     if root_count == 0:
         raise ValueError("checkpoints_roots must not be empty")
示例#12
0
 def __post_init__(self) -> None:
     """Run the populated-fields check, skipping torch_cuda_random_state."""
     # torch_cuda_random_state is the only field excluded from the check.
     optional_fields = ["torch_cuda_random_state"]
     common_util.check_properties_are_not_none(self, ignore=optional_fields)
 def __post_init__(self) -> None:
     """Run the populated-fields check on this instance after construction.

     Validation is delegated to ``common_util.check_properties_are_not_none``
     (by its name it rejects fields left as ``None`` - confirm against the
     helper).
     """
     common_util.check_properties_are_not_none(self)
 def __post_init__(self) -> None:
     """Run the populated-fields check on this instance after construction.

     Validation is delegated to ``check_properties_are_not_none`` (by its
     name it rejects fields left as ``None`` - confirm against the helper).
     """
     check_properties_are_not_none(self)
示例#15
0
 def __post_init__(self) -> None:
     """Run the populated-fields check, skipping the optional fields."""
     # dataset_csv_file and run_recovery_id are excluded from the check.
     optional_fields = ["dataset_csv_file", "run_recovery_id"]
     common_util.check_properties_are_not_none(self, ignore=optional_fields)
 def __post_init__(self) -> None:
     """Run the populated-fields check, skipping the metrics field."""
     # metrics is the only field excluded from the check.
     optional_fields = ["metrics"]
     common_util.check_properties_are_not_none(self, ignore=optional_fields)