示例#1
0
    def test_import_object__numpy(self):
        """Check that `import_object` can resolve 'numpy.array'

        The imported callable is used to build an array, which is compared
        against one built by calling `np.array` directly.
        """

        values = [0, 1, 2, 3]
        expected = np.array(values)

        array_factory = import_object(object_importpath='numpy.array')
        actual = array_factory(values)

        assert isinstance(actual, np.ndarray)
        assert np.array_equal(actual, expected)
def main():
    """Train AlexNet on ImageNet

    Parses the command line arguments, reads the YAML training config found
    at `args.fpath_config`, imports the job class named by the config's
    `job_importpath` entry, instantiates it with the full config, and runs it.
    """

    args = parse_args()
    with open(args.fpath_config) as f:
        # `yaml.load` without an explicit `Loader` is deprecated as of
        # PyYAML 5.1 and raises a TypeError as of PyYAML 6.0. `safe_load`
        # works across versions and only constructs plain Python objects.
        training_config = yaml.safe_load(f)

    TrainingJob = import_object(training_config['job_importpath'])
    training_job = TrainingJob(training_config)
    training_job.run()
示例#3
0
    def test_import_object__dev_env(self, monkeypatch):
        """Check `import_object` against a patched `utils.dev_env.get`

        `utils.dev_env.get` is replaced with a stand-in that always returns
        the same directory path, so the object returned by `import_object`
        must yield that path no matter which group is requested.
        """
        def fake_get(group, key):
            """Stand-in for `utils.dev_env.get`; ignores its arguments"""
            return '/data/imagenet'

        monkeypatch.setattr('utils.dev_env.get', fake_get)
        imported_get = import_object(object_importpath='utils.dev_env.get')

        for group in ('imagenet', 'mpii'):
            assert imported_get(group, 'dirpath_data') == '/data/imagenet'
示例#4
0
    def _parse_transformations(self, transformations):
        """Parse the provided transformations into the expected format

        When passed into the dataset transformers (whose import paths are
        listed below), `transformations` is expected to be a list of two
        element tuples, where each tuple contains a transformation function to
        apply as the first element and function kwargs as the second element.
        When they are parsed from the config (and passed into this function),
        they are a list of dictionaries. This function mostly reformats them to
        the format expected by the following dataset transformer classes:
        - training.pytorch.dataset_transformer.PyTorchDataSetTransformer
        - training.tf.data_loader.TFDataLoader

        :param transformations: holds the transformations to apply to each
         batch of data, where each transformation is specified as a dictionary
         with the key equal to the importpath of the callable transformation
         and the value equal to a dictionary holding keyword arguments for the
         callable
        :type transformations: list[dict]
        :return: parsed transformations reformatted for the dataset transformer
         classes
        :type transformations: list[tuple]
        """

        processed_transformations = []
        for transformation in transformations:
            assert len(transformation) == 1
            transformation_fn_importpath = list(transformation.keys())[0]
            transformation_config = list(transformation.values())[0]

            transformation_fn = import_object(transformation_fn_importpath)
            processed_transformation_config = {}
            for param, arguments in transformation_config.items():
                value = arguments['value']
                if arguments.get('import'):
                    value = import_object(value)
                processed_transformation_config[param] = value
            processed_transformations.append(
                (transformation_fn, processed_transformation_config))

        return processed_transformations
示例#5
0
    def _instantiate_network(self):
        """Build the network to train from the `network` config section

        The `network` section of `self.config` must provide:
        - str importpath: import path of the network class to train
        - dict init_params: keyword arguments forwarded unchanged to the
          `__init__` of that class

        :return: network for training
        :rtype: object
        """

        spec = self.config['network']

        network_cls = import_object(spec['importpath'])
        init_kwargs = spec['init_params']
        return network_cls(**init_kwargs)
    def _instantiate_dataset(self, set_name):
        """Return a dataset object to be used as an iterator during training

        The dataset that is returned should be able to be directly passed into
        the `train` method of whatever trainer class is specified in
        `self.config`, as either the `train_dataset` or `validation_dataset`
        argument.

        :param set_name: set to return the dataset for, one of
         {'train', 'validation'}
        :type set_name: str
        :return: two element tuple holding an iterable over the dataset for
         `set_name`, as well as the number of batches in a single pass over the
         dataset
        :rtype: tuple
        """

        assert set_name in {'train', 'validation'}
        dataset_spec = self.config['dataset']

        fpath_df_obs_key = 'fpath_df_{}'.format(set_name)
        if fpath_df_obs_key not in dataset_spec:
            if set_name == 'train':
                raise RuntimeError
            return None, None
        fpath_df_obs = dataset_spec[fpath_df_obs_key]
        df_obs = pd.read_csv(fpath_df_obs)

        dataset_importpath = dataset_spec['importpath']
        DataSet = import_object(dataset_importpath)

        dataset = DataSet(df_obs=df_obs, **dataset_spec['init_params'])
        transformations_key = '{}_transformations'.format(set_name)
        transformations = dataset_spec[transformations_key]
        transformations = self._parse_transformations(transformations)

        loader = TFDataLoader(dataset, transformations)
        loading_params = dataset_spec['{}_loading_params'.format(set_name)]
        dataset_gen = loader.get_infinite_iter(**loading_params)
        n_batches = len(loader.numpy_dataset) // loading_params['batch_size']

        return dataset_gen, n_batches
示例#7
0
    def _instantiate_trainer(self):
        """Build the trainer that runs training from the `trainer` config

        The `trainer` section of `self.config` must provide:
        - str importpath: import path of the trainer class to use
        - dict init_params: keyword arguments forwarded unchanged to the
          `__init__` of that class

        In addition to `init_params`, the trainer receives
        `self.dirpath_job` as its `dirpath_save` keyword argument.

        :return: trainer to run training
        :rtype: object
        """

        spec = self.config['trainer']

        trainer_cls = import_object(spec['importpath'])
        init_kwargs = spec['init_params']
        save_dirpath = self.dirpath_job
        return trainer_cls(**init_kwargs, dirpath_save=save_dirpath)
示例#8
0
    def test_import_object__pandas(self):
        """Check that `import_object` can resolve 'pandas.DataFrame'

        The imported class is used to build a dataframe, which is compared
        against one built by calling `pd.DataFrame` directly on the same rows.
        """

        # Three rows: {'a': 0, 'b': 1, 'c': 2} through {'a': 2, 'b': 3, 'c': 4}
        rows = [
            {'a': start, 'b': start + 1, 'c': start + 2}
            for start in range(3)
        ]
        expected = pd.DataFrame(rows)

        dataframe_cls = import_object(object_importpath='pandas.DataFrame')
        actual = dataframe_cls(rows)

        assert isinstance(actual, pd.DataFrame)
        assert actual.equals(expected)