예제 #1
0
    def test_search_modules(self):
        # Every scenario builds the same object; only the module search path varies.
        spec = {
            'class_type': 'datetime',
            'year': 2018,
            'month': 9,
            'day': 8,
        }

        # Without search_modules, the bare class name 'datetime' is rejected.
        with pytest.raises(ValueError):
            instantiate_from_dict(dict(spec))

        # Searching a module that does not provide the class fails the lookup.
        with pytest.raises(ObjectNotFoundError):
            instantiate_from_dict(dict(spec), search_modules=['os'])

        # Searching the right module resolves the class and builds the object.
        result = instantiate_from_dict(dict(spec), search_modules=['datetime'])

        assert result == datetime.datetime(2018, 9, 8)
예제 #2
0
    def _construct_features_map(
            cls,
            features_map: Dict[str, Union[Dict, FeatureTypeBase]]) -> Dict[str, FeatureTypeBase]:
        """
        Create example feature objects from a map of feature definitions.

        Each value in `features_map` is either an already-built feature object, which is kept
        as is, or a dictionary describing the type and parameters, which is instantiated with
        `instantiate_from_dict`. The reconstruction process is similar to the reconstruction
        followed by variant definition for Model and Feeders.

        The optional reserved key `module_path` names the module searched for feature classes
        (default: `'tensorflow'`). It is not treated as a feature. The input mapping is not
        modified.

        :param features_map: Example features or dictionaries with their definitions mapped in
            unique keys, plus an optional `module_path` entry.
        :return: The example feature objects mapped in their key, preserving the same order as
            in `features_map` (without the `module_path` entry).
        """
        # Read the reserved key instead of deleting it. The former in-place `del` mutated the
        # caller's dict, so a second call on the same mapping silently fell back to the default
        # module.
        module_path = features_map.get('module_path', 'tensorflow')

        features_map_objects = {}
        for fname, fmap in features_map.items():
            if fname == 'module_path':
                continue

            if isinstance(fmap, dict):
                features_map_objects[fname] = instantiate_from_dict(
                    fmap,
                    search_modules=[module_path])
            else:
                # Already an instantiated feature object; pass it through unchanged.
                features_map_objects[fname] = fmap

        return features_map_objects
예제 #3
0
    def create_exporters(self, exporters_config_path: str) -> List[Exporter]:
        """
        Create fresh Exporter object instances based on the definition of the variant.

        Every entry under `exporters_config_path` is instantiated with `instantiate_from_dict`,
        searching the module named by the entry's `module_path` key. Its `module_path` and
        `serving_feeder` keys are not passed on.

        When an entry declares `serving_input_receiver_fn`, that value is treated as the name of
        a method on the feeder built from the entry's `serving_feeder` definition. The bound
        method is then passed to the exporter in place of the name. The variant definition
        itself is left unchanged.

        :param exporters_config_path: A dot-separated path in the variant definition
            that holds the exporter configuration. Usually a value of `train.exporters`

        :return: The instantiated Exporter objects, one per configured exporter, in the
            iteration order of the configuration mapping
        """
        exporters = []

        for name, definition in self.get(exporters_config_path).items():
            # Work on a shallow copy. Writing the resolved function back into the variant
            # definition made a second call fail: `getattr` would then receive a function
            # instead of a method name.
            exporter_def = dict(definition)

            if 'serving_input_receiver_fn' in exporter_def:
                # Derive the feeder path from the requested config path instead of
                # hard-coding `train.exporters`.
                feeder = self.create_feeder(
                    f'{exporters_config_path}.{name}.serving_feeder')
                exporter_def['serving_input_receiver_fn'] = getattr(
                    feeder, exporter_def['serving_input_receiver_fn'])

            exporters.append(
                instantiate_from_dict(
                    exporter_def,
                    search_modules=[exporter_def['module_path']],
                    exclude_keys=['module_path', 'serving_feeder']))

        return exporters
예제 #4
0
    def create_feeder(self, feeder_config_path: str) -> FeederBase:
        """
        Create a fresh Feeder object instance based on the definition of the variant.

        The feeder class is looked up in the module named by the feeder definition's own
        `module_path`. If that key is absent, the first of `train.train_feeder`,
        `train.eval_feeder` or `predict.predict_feeder` present in the variant supplies its
        `module_path` instead. A warning is logged when the resulting object is not a
        `FeederBase` instance; it is returned anyway.

        :param feeder_config_path: A dot-separated path in the variant definition
            that holds the feeder configuration. Usually a value of `train.train_feeder` or
            `train.eval_feeder`

        :return: The instantiated Feeder object
        :raises ValueError: If the feeder definition has no `module_path` and the variant
            defines none of the standard feeders.
        """
        feeder_definition = self.get(feeder_config_path)

        if 'module_path' in feeder_definition:
            # The feeder's own module wins. Previously the module of the first standard feeder
            # present was always used, so e.g. `train.eval_feeder` was resolved against
            # `train.train_feeder.module_path` whenever a train feeder existed.
            search_modules = [feeder_definition['module_path']]
        elif self.has('train.train_feeder'):
            search_modules = [self.get('train.train_feeder.module_path')]
        elif self.has('train.eval_feeder'):
            search_modules = [self.get('train.eval_feeder.module_path')]
        elif self.has('predict.predict_feeder'):
            search_modules = [self.get('predict.predict_feeder.module_path')]
        else:
            raise ValueError(
                'Variant must have one of the following: \n ' +
                'train.train_feeder, train.eval_feeder, predict.predict_feeder'
            )

        feeder = instantiate_from_dict(
            feeder_definition,
            search_modules=search_modules,
            exclude_keys=['model_dir', 'module_path', 'class_type'])

        if not isinstance(feeder, FeederBase):
            # `Logger.warn` is a deprecated alias of `Logger.warning`.
            logger.warning(
                f"Class '{feeder.__class__.__name__}' that is used as feeder "
                "type is not subclass of FeederBase.")
        return feeder
예제 #5
0
    def test_class_name_key(self):
        date_kwargs = {'year': 2018, 'month': 9, 'day': 8}

        # Under the default class-name key ('class_type' in the other tests), a class
        # reference stored under any other key is not recognised.
        with pytest.raises(InstantiationError):
            instantiate_from_dict({'wrong_type': 'datetime:datetime', **date_kwargs})

        # Naming the custom key via class_name_key makes the same reference resolvable.
        built = instantiate_from_dict(
            {'_class_': 'datetime:datetime', **date_kwargs},
            class_name_key='_class_')

        assert built == datetime.datetime(2018, 9, 8)
예제 #6
0
    def test_excluded_params(self):
        """
        A key the target class does not accept makes instantiation fail, unless that key is
        listed in `exclude_keys`; the excluded variant must build the expected object.
        """
        with pytest.raises(InstantiationError):
            instantiate_from_dict({
                'class_type': 'datetime:datetime',
                'year': 2018,
                'month': 9,
                'day': 8,
                'wrong': 15
            })

        # Check the constructed value, not merely that no exception was raised.
        dt = instantiate_from_dict(
            {
                'class_type': 'datetime:datetime',
                'year': 2018,
                'month': 9,
                'day': 8,
                'wrong': 15
            },
            exclude_keys=['wrong'])

        assert dt == datetime.datetime(2018, 9, 8)
예제 #7
0
 def test_simple_scenario(self):
     # A 'datetime:datetime' class reference resolves without any search_modules.
     fields = dict(year=2018, month=9, day=8, hour=7, minute=6, second=5)
     dt = instantiate_from_dict(dict(class_type='datetime:datetime', **fields))

     expected = datetime.datetime(2018, 9, 8, 7, 6, 5)
     assert expected == dt
예제 #8
0
    def create_model(self) -> ModelBase:
        """
        Create a fresh model instance based on the definition of the variant.

        The `model` section of the variant is instantiated with `instantiate_from_dict`,
        searching the module named by `model.module_path`. A warning is logged when the
        resulting object is not a `ModelBase` instance; it is returned anyway.

        :return: The instantiated Model object
        """
        # NOTE(review): unlike the feeder/exporter factories, no exclude_keys are passed here,
        # so `module_path` may reach the model constructor — confirm this is intended.
        model = instantiate_from_dict(
            self.get('model'), search_modules=[self.get('model.module_path')])
        if not isinstance(model, ModelBase):
            # `Logger.warn` is a deprecated alias of `Logger.warning`.
            logger.warning(
                f"Class '{model.__class__.__name__}' that is used as model type is not subclass of ModelBase."
            )
        return model