예제 #1
0
    def __init__(self, config, dirpath_save, device=None):
        """Set up the trainer from its configuration

        The following keys are required in `config`:
        - str optimizer: optimizer to use when training the network
        - str loss: loss function to use when training the network
        - int batch_size: batch size to use during training
        - int n_epochs: number of epochs to train for

        :param config: specifies the configuration of the trainer
        :type config: dict
        :param dirpath_save: directory path to save the model to during
         training
        :type dirpath_save: str
        :param device: specifies a GPU to use for training
        :type device: str or torch.device
        """

        validate_config(config, self.required_config_keys)

        # each training setting is stored as an attribute named after its key
        for setting in ('optimizer', 'loss', 'batch_size', 'n_epochs'):
            setattr(self, setting, config[setting])

        self.device = device
        self.dirpath_save = dirpath_save
    def __init__(self, df_obs, config):
        """Init

        `config` must contain the following keys:
        - int height: height to reshape the images to
        - int width: width to reshape the images to

        :param config: specifies the configuration of the dataset
        :type config: dict
        :param df_obs: holds the filepath to the input image ('fpath_image')
         and the target label for the image ('label'); additional columns are
         allowed, and the 'label' column is cast in place to
         `self.sample_types['label']`
        :type df_obs: pandas.DataFrame
        :raises KeyError: if `df_obs` is missing the 'fpath_image' or the
         'label' column
        """

        validate_config(config, self.required_config_keys)

        # Both columns must be present. Testing whether the columns are a
        # proper subset of the required set would let a frame that has
        # unrelated columns, but lacks a required one, slip through.
        required_columns = {'fpath_image', 'label'}
        if not required_columns.issubset(df_obs.columns):
            msg = ('df_obs must have an \'fpath_image\' and \'label\' '
                   'column, and only {} columns were given.').format(
                       df_obs.columns)
            raise KeyError(msg)

        self.df_obs = df_obs
        self.config = config
        # NOTE: this assigns into the DataFrame that was passed in, so the
        # caller's object sees the cast label column as well
        self.df_obs['label'] = (self.df_obs['label'].astype(
            self.sample_types['label']))
예제 #3
0
    def test_validate_config__bad(self):
        """Check that `validate_config` raises when a required key is absent"""

        full_config = {0: 1, 'test': 'testy_test', (0, 1): 4}
        required_keys = list(full_config)

        # drop each required key in turn and expect the validation to fail
        for missing_key in required_keys:
            partial_config = {
                key: value
                for key, value in full_config.items()
                if key != missing_key
            }

            with pytest.raises(KeyError):
                validate_config(partial_config, required_keys)
예제 #4
0
    def test_validate_config__good(self):
        """Check `validate_config` when no required keys are missing"""

        candidate_keys = [(1, 2), 'required1', 85, 'required2', 4, (2, 3, 4)]

        # try a few random selections of required keys, each fully present
        n_trials = 3
        for _ in range(n_trials):
            required_keys = random.sample(candidate_keys, 3)
            config = dict.fromkeys(required_keys, 'sentinel_value')

            validate_config(config, required_keys)
예제 #5
0
    def __init__(self, config):
        """Validate and store the network configuration

        The following keys are required in `config`:
        - int height: height of the input to the network
        - int width: width of the input to the network
        - int n_channels: number of channels of the input
        - int n_classes: number of classes in the output layer

        :param config: specifies the configuration for the network
        :type config: dict
        """

        required_keys = self.required_config_keys
        validate_config(config, required_keys)

        # the config is kept as-is (not copied) once it has been validated
        self.config = config
예제 #6
0
    def __init__(self, config):
        """Validate the configuration and build the network's layers

        The following keys are required in `config`:
        - int n_channels: number of channels of the input
        - int n_classes: number of classes in the output layer

        :param config: specifies the configuration for the network
        :type config: dict
        """

        # the parent class is initialized before any attribute is assigned
        super().__init__()

        required_keys = self.required_config_keys
        validate_config(config, required_keys)

        # `_set_layers` runs only after the validated config is stored
        self.config = config
        self._set_layers()
예제 #7
0
    def __init__(self, config):
        """Set up the training job and write its config to the job directory

        Required keys in `config`:
        - dict network: specifies the network class to train as well as how to
          build it; see the `_instantiate_network` method for details
        - dict trainer: specifies the trainer class to train with; see the
          `_instantiate_trainer` method for details
        - dict dataset: specifies the training and validation dataset classes
          to train with, as well as how to load the data from the datasets; see
          the `_instantiate_dataset` method in child classes for details

        Optional keys in `config`:
        - str 'job_name': optional name given to the job; the timestamp of when
          the job started will be appended to the job_name to uniquely identify
          the directory name the job will be saved to
        - str 'dirpath_jobs': optional directory path to save job directory in,
          resulting in the job being saved to 'dirpath_jobs/dirname_job';
          defaults to `os.environ['HOME']/training_jobs`
        - int gpu_id: GPU to run the job on; defaults to None, which means the
          job runs on the CPU

        See the `_parse_dirpath_job` method for details on where the results of
        the training job will be stored.

        :param config: config file specifying a training job to run
        :type config: dict
        """

        validate_config(config, self.required_config_keys)
        self.config = config
        self.dirpath_job = self._parse_dirpath_job()

        # expose only the requested GPU through CUDA_VISIBLE_DEVICES, if one
        # was given
        gpu_id = self.config.get('gpu_id')
        self.gpu_id = gpu_id
        if gpu_id is not None:
            os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)

        # keep a copy of the config inside the job directory
        with open(os.path.join(self.dirpath_job, 'config.yml'), 'w') as f:
            yaml.dump(self.config, f, default_flow_style=False)