コード例 #1
0
    def __init__(self, *args, **kwargs):
        """Set up the data module from positional and keyword arguments.

        The optional ``threshold`` keyword is read from ``kwargs`` (``None``
        when absent). The transform set-up helper and both mixin initializers
        run before ``clean_kwargs(kwargs)``, and ``LightningDataModule`` is
        initialized last, after that call.
        """
        self._threshold = kwargs.get("threshold")
        self.__instantiate_transform(kwargs)

        # Each base initializer receives the same arguments.
        BaseDatasetSamplerMixin.__init__(self, *args, **kwargs)
        BaseTasksMixin.__init__(self, *args, **kwargs)

        # NOTE(review): presumably clean_kwargs strips the keys that
        # LightningDataModule does not accept -- confirm against its definition.
        self.clean_kwargs(kwargs)
        LightningDataModule.__init__(self, *args, **kwargs)

        # No dataset is attached to any split at construction time.
        self.dataset_train = self.dataset_val = self.dataset_test = None

        # Default private settings.
        self._seed, self._num_workers = 42, 2
        self._shuffle, self._drop_last, self._pin_memory = True, False, True
        self._follow_batch = []

        self._hyper_parameters = {}
コード例 #2
0
 def __init__(
     self,
     train_sequences: List[str],
     validation_sequences: List[str],
     alphabet: AlphabetDataLoader,
     masking_ratio: float,
     masking_prob: float,
     random_token_prob: float,
     num_workers: int,
     toks_per_batch: int,
     crop_sizes: Tuple[int, int] = (512, 1024),
 ):
     """Store the data module configuration on private attributes.

     Every argument is kept as given on the attribute of the same name
     prefixed with an underscore; nothing is validated or transformed here.

     Args:
         train_sequences: Sequences for the training split.
         validation_sequences: Sequences for the validation split.
         alphabet: Alphabet object, stored on ``_alphabet``.
         masking_ratio: Stored on ``_masking_ratio``.
         masking_prob: Stored on ``_masking_prob``.
         random_token_prob: Stored on ``_random_token_prob``.
         num_workers: Stored on ``_num_workers``.
         toks_per_batch: Stored on ``_toks_per_batch``.
         crop_sizes: Pair of ints, ``(512, 1024)`` by default.
     """
     LightningDataModule.__init__(self)

     # Raw sequences for the two splits.
     self._train_sequences, self._validation_sequences = (
         train_sequences,
         validation_sequences,
     )
     self._alphabet = alphabet

     # Masking settings.
     self._masking_ratio, self._masking_prob, self._random_token_prob = (
         masking_ratio,
         masking_prob,
         random_token_prob,
     )

     # Loading and batching settings.
     self._num_workers, self._toks_per_batch = num_workers, toks_per_batch
     self._crop_sizes = crop_sizes
コード例 #3
0
    def __post_init__(self,
                      observation_space: gym.Space = None,
                      action_space: gym.Space = None,
                      reward_space: gym.Space = None):
        """ Initializes the fields of the setting that weren't set from the
        command-line.

        The train/val/test transform fields are each flattened (a list whose
        only item is itself a list is replaced by that inner list), composed
        into a single `Compose`, and forwarded to
        `LightningDataModule.__init__`. A transform field that is `None` is
        treated as "no transforms" instead of raising a `TypeError`.

        Args:
            observation_space: Stored on `_observation_space`.
            action_space: Stored on `_action_space`.
            reward_space: Stored on `_reward_space`.
        """
        logger.debug("__post_init__ of Setting")

        def _flatten(transforms):
            """Return the transforms to compose for one split."""
            if transforms is None:
                # Nothing was given for this split: compose an empty pipeline.
                return []
            if len(transforms) == 1 and isinstance(transforms[0], list):
                # A list wrapping a single list of transforms: unwrap it.
                return transforms[0]
            return transforms

        # Actually compose the list of Transforms or callables into a single transform.
        self.train_transforms: Compose = Compose(_flatten(self.train_transforms))
        self.val_transforms: Compose = Compose(_flatten(self.val_transforms))
        self.test_transforms: Compose = Compose(_flatten(self.test_transforms))

        LightningDataModule.__init__(self,
            train_transforms=self.train_transforms,
            val_transforms=self.val_transforms,
            test_transforms=self.test_transforms,
        )

        self._observation_space = observation_space
        self._action_space = action_space
        self._reward_space = reward_space

        # TODO: It's a bit confusing to also have a `config` attribute on the
        # Setting. Might want to change this a bit.
        self.config: Config = None

        # The per-split environments start out unset.
        self.train_env: Environment = None  # type: ignore
        self.val_env: Environment = None  # type: ignore
        self.test_env: Environment = None  # type: ignore
コード例 #4
0
ファイル: setting.py プロジェクト: gopeshh/Sequoia
    def __post_init__(
        self,
        observation_space: gym.Space = None,
        action_space: gym.Space = None,
        reward_space: gym.Space = None,
    ):
        """Finish setting up the fields that the command-line did not provide.

        Normalizes the `transforms`, `train_transforms`, `val_transforms` and
        `test_transforms` fields, wraps the per-split ones in `Compose`,
        forwards them to `LightningDataModule.__init__`, stores the given
        spaces, and resets `config` and the per-split environments to None.
        """
        logger.debug("__post_init__ of Setting")
        split_fields = ("train_transforms", "val_transforms", "test_transforms")

        # BUG: simple-parsing sometimes parses a list with a single item, itself the
        # list of transforms. Not sure if this still happens.
        for field_name in split_fields:
            parsed = getattr(self, field_name)
            wraps_single_list = (
                isinstance(parsed, list)
                and len(parsed) == 1
                and isinstance(parsed[0], list)
            )
            if wraps_single_list:
                setattr(self, field_name, parsed[0])

        # Use these two transforms by default if no transforms are passed at all.
        # TODO: Remove this after the competition perhaps.
        if self.transforms is None and all(
            getattr(self, field_name) is None for field_name in split_fields
        ):
            self.transforms = Compose([Transforms.to_tensor, Transforms.three_channels])

        # If the constructor is called with just the `transforms` argument, like this:
        # <SomeSetting>(dataset="bob", transforms=foo_transform)
        # Then we use this value as the default for the train, val and test transforms.
        if self.transforms and not any(
            getattr(self, field_name) for field_name in split_fields
        ):
            if not isinstance(self.transforms, list):
                self.transforms = Compose([self.transforms])
            for field_name in split_fields:
                setattr(self, field_name, self.transforms.copy())

        # A lone transform (anything that is not a list) becomes a one-item list.
        for field_name in split_fields:
            current = getattr(self, field_name)
            if current is not None and not isinstance(current, list):
                setattr(self, field_name, [current])

        # Actually compose the list of Transforms or callables into a single transform.
        for field_name in split_fields:
            setattr(self, field_name, Compose(getattr(self, field_name) or []))

        LightningDataModule.__init__(
            self,
            **{field_name: getattr(self, field_name) for field_name in split_fields},
        )

        self._observation_space = observation_space
        self._action_space = action_space
        self._reward_space = reward_space

        # TODO: It's a bit confusing to also have a `config` attribute on the
        # Setting. Might want to change this a bit.
        self.config: Config = None

        # The per-split environments start out unset.
        self.train_env: Environment = None  # type: ignore
        self.val_env: Environment = None  # type: ignore
        self.test_env: Environment = None  # type: ignore