示例#1
0
文件: pytorch.py 项目: xiaowu0162/nni
    'max_experiment_duration':
    lambda value: f'{util.parse_time(value)}s' if value is not None else None,
    'experiment_working_directory':
    util.canonical_path
}

# Per-field validators, keyed by config field name.
# Each callable takes the field's value and returns either a plain bool, or a
# (bool, message) tuple where the message text describes the failure.
# NOTE(review): the code that invokes these rules is not visible in this
# chunk; presumably it is the config base class's validation routine.
_validation_rules = dict(
    trial_code_directory=lambda value: (
        Path(value).is_dir(),
        f'"{value}" does not exist or is not directory',
    ),
    trial_concurrency=lambda value: value > 0,
    trial_gpu_number=lambda value: value >= 0,
    # Whatever form the duration is given in, util.parse_time must map it
    # to something greater than zero.
    max_experiment_duration=lambda value: util.parse_time(value) > 0,
    max_trial_number=lambda value: value > 0,
    log_level=lambda value: value in (
        "trace", "debug", "info", "warning", "error", "fatal",
    ),
    # Exact type comparison on purpose: concrete subclasses pass, only the
    # abstract TrainingServiceConfig itself is rejected.
    training_service=lambda value: (
        type(value) is not TrainingServiceConfig,
        'cannot be abstract base class',
    ),
)


def preprocess_model(base_model, trainer, applied_mutators, full_ir=True):
    # TODO: this logic might need to be refactored into execution engine
    if full_ir:
        try:
示例#2
0
    @property
    def _validation_rules(self):
        """Return the module-level ``_validation_rules`` dict.

        The name inside the body resolves to the module global, because
        function scope does not see class attributes. The same dict object is
        handed back on every access; no copy is made.
        """
        rules = _validation_rules
        return rules


# Per-field normalisers, keyed by config field name; each maps a raw value
# to its canonical form.
_canonical_rules = dict(
    trial_code_directory=util.canonical_path,
    # An unset duration stays None. Any other value is re-rendered as the
    # util.parse_time result followed by 's' (presumably seconds; confirm
    # against util.parse_time).
    max_experiment_duration=lambda value: (
        None if value is None else f'{util.parse_time(value)}s'
    ),
    experiment_working_directory=util.canonical_path,
)

# Validation rules looked up by config field name. A rule receives the field
# value and yields a bool, or a (bool, message) pair whose message explains
# why the value was rejected.
# NOTE(review): the consumer of this table lives outside this chunk.
_validation_rules = dict(
    trial_code_directory=lambda value: (
        Path(value).is_dir(),
        f'"{value}" does not exist or is not directory',
    ),
    trial_concurrency=lambda value: value > 0,
    trial_gpu_number=lambda value: value >= 0,
    max_experiment_duration=lambda value: util.parse_time(value) > 0,
    max_trial_number=lambda value: value > 0,
    log_level=lambda value: value in (
        "trace", "debug", "info", "warning", "error", "fatal",
    ),
    # `type(...) is not` rather than isinstance: only instances of the
    # abstract base class itself fail; subclasses are fine.
    training_service=lambda value: (
        type(value) is not TrainingServiceConfig,
        'cannot be abstract base class',
    ),
)

def preprocess_model(base_model, trainer, applied_mutators):
        try:
            script_module = torch.jit.script(base_model)
        except Exception as e:
            _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
            raise e
        base_model_ir = convert_to_graph(script_module, base_model)
        base_model_ir.evaluator = trainer

        # handle inline mutations