Code example #1
File: Model_Lung.py  Project: Mullans/TF_Toolkit
    def load_weights(self, version=None, path=None):
        """Load model weights from a saved version of this model or from an explicit file path."""
        if path is None:
            path = self.model_dir / version / 'model_weights.h5'
        path = gouda.GoudaPath(path)
        if path.exists():
            self.model.load_weights(path)
        else:
            raise ValueError("No weights found at: {}".format(path.abspath))
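
A minimal usage sketch for the method above. The class name LungModel, the group/name values, and the 'best' version label are assumptions (they are not shown in the excerpt); the sketch assumes weights were previously saved under model_dir/<version>/model_weights.h5.

# Hedged usage sketch; LungModel, the group/name values, and the paths are assumptions.
model = LungModel(model_group='lungs', model_name='baseline')
model.load_weights(version='best')                  # resolves model_dir / 'best' / 'model_weights.h5'
model.load_weights(path='/tmp/model_weights.h5')    # or point directly at an existing weights file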
Code example #2
    def load_weights(self, model_version=None, weights_path=None):
        """Load model weights from either a version of the current group/model or from a file path"""
        if weights_path is None:
            if model_version is None:
                raise ValueError("model_version must be specified if weights_path is not used")
            weights_path = project_path('results',
                                        self.model_args['model_group'],
                                        self.model_args['model_name'],
                                        model_version,
                                        'model_weights.h5')
            if not weights_path.exists():
                raise ValueError("No file found at {}".format(weights_path.abspath))
        weights_path = gouda.GoudaPath(weights_path)
        self.model.load_weights(weights_path.abspath)
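
The project_path helper used above is not shown in this excerpt. A minimal placeholder consistent with how it is called here (join path segments under the project root and return a gouda.GoudaPath so .exists() and .abspath work) might look like the following; the PROJECT_ROOT default is an assumption.

import os
import gouda

PROJECT_ROOT = os.getcwd()  # assumption: the project root is the current working directory

def project_path(*parts):
    # Join the given segments under the project root and wrap them in a GoudaPath.
    return gouda.GoudaPath(os.path.join(PROJECT_ROOT, *parts))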
Code example #3
File: Model_Lung.py  Project: Mullans/TF_Toolkit
    def __init__(self,
                 model_name='default',
                 model_group='default',
                 model_type='template',
                 filter_scale=0,
                 num_outputs=2,
                 input_shape=[512, 512, 1],
                 load_args=False,
                 **kwargs):
        """Initialize a model for the network

        Parameters
        ----------
        model_name : str
            The name of the model to use - should define the model level parameters
        model_group : str
            The group of models to use - should define the model structure or data paradigm
        filter_scale : int
            The scaling factor to use for the model layers (scales by powers of 2)
        input_shape : tuple of ints
            The shape of data to be passed to the model (not including batch size)
        load_args : bool
            Whether to use pre-existing arguments for the given model group+name
        """
        if input_shape is None and load_args is False:
            raise ValueError("Input shape cannot be None for model object")
        self.model_dir = gouda.GoudaPath(
            gouda.ensure_dir(RESULTS_DIR, model_group, model_name))
        if load_args:
            if self.model_dir('model_args.json').exists():
                self.model_args = gouda.load_json(
                    self.model_dir('model_args.json'))
            else:
                raise ValueError(
                    "Cannot find model args for model {}/{}".format(
                        model_group, model_name))
        else:
            self.model_args = {
                'model_name': model_name,
                'model_group': model_group,
                'model_type': model_type,
                'filter_scale': filter_scale,
                'input_shape': input_shape,
            }
            for key in kwargs:
                self.model_args[key] = kwargs[key]
            gouda.save_json(self.model_args, self.model_dir('model_args.json'))
        K.clear_session()
        self.model = lookup_model(model_type)(**self.model_args)
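
A short usage sketch for this constructor. The class name LungModel and the group/name values are assumptions; the point is that a first call writes model_args.json under RESULTS_DIR/<group>/<name>/, and a later call with load_args=True rebuilds the model from those saved arguments.

# Hedged sketch; LungModel and the group/name values are assumptions.
new_model = LungModel(model_name='scale2', model_group='axial',
                      model_type='template', filter_scale=2)   # writes model_args.json
reloaded = LungModel(model_name='scale2', model_group='axial',
                     load_args=True)                           # rebuilds from the saved arguments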
Code example #4
File: Model_Neuron.py  Project: Mullans/TF_Toolkit
    def __init__(self,
                 model_name='default',
                 model_group='default',
                 model_type='multires',
                 filter_scale=0,
                 out_layers=1,
                 out_classes=2,
                 input_shape=[1024, 1360, 1],
                 patch_out=False,
                 load_args=False,
                 **kwargs):
        K.clear_session()
        self.loaded = False

        self.model_dir = gouda.GoudaPath(
            gouda.ensure_dir(RESULTS_DIR, model_group, model_name))
        args_path = self.model_dir / 'model_args.json'
        if load_args:
            if not args_path.exists():
                raise ValueError("No model arguments found at path: {}".format(
                    args_path.abspath))
            self.model_args = gouda.load_json(args_path)
        else:
            self.model_args = {
                'model_name': model_name,
                'model_group': model_group,
                'model_type': model_type,
                'filter_scale': filter_scale,
                'out_layers': out_layers,
                'out_classes': out_classes,
                'input_shape': input_shape,
                'patch_out': patch_out
            }
            for key in kwargs:
                self.model_args[key] = kwargs[key]
            gouda.save_json(self.model_args, args_path)

        model_func = get_model_func(self.model_args['model_type'])
        self.model = model_func(**self.model_args)
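
A sketch of the load_args error path above: if load_args=True but no model_args.json has been saved for that group/name, the constructor raises. The class name NeuronModel and the group/name values are assumptions.

# Hedged sketch; NeuronModel and the group/name values are assumptions.
try:
    model = NeuronModel(model_name='unsaved', model_group='cortex', load_args=True)
except ValueError as err:
    print(err)   # "No model arguments found at path: ..."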
Code example #5
    def __init__(
            self,
            model_group='default',
            model_name='default',
            project_dir=None,
            load_args=False,
            overwrite_args=False,
            # distributed=False,
            **kwargs):
        if project_dir is None:
            # Generally the location the code is called from
            project_dir = os.getcwd()
        K.clear_session()
        self.model_args = {
            'model_name': model_name,
            'model_group': model_group,
            'train_step': 'default',
            'val_step': 'default',
            'lr_type': 'base',
            'model_func': None
        }
        for key in kwargs:
            self.model_args[key] = kwargs[key]

        # self.is_distributed = distributed
        self.model = None

        self.results_dir = gouda.GoudaPath(os.path.join(
            project_dir, 'Results'))
        gouda.ensure_dir(self.results_dir)
        group_dir = self.results_dir / model_group
        gouda.ensure_dir(group_dir)
        self.model_dir = group_dir / model_name
        gouda.ensure_dir(self.model_dir)
        args_path = self.model_dir('model_args.json')
        if load_args:
            self.load_args(args_path)
        if overwrite_args or not args_path.exists():
            self.save_args()
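
A usage sketch for this constructor. The class name CoreModel, the project_dir value, and the lr_type override are assumptions; on init it ensures the Results/<model_group>/<model_name>/ tree under project_dir and writes model_args.json when it is missing or overwrite_args=True.

# Hedged sketch; CoreModel, the path, and the lr_type value are assumptions.
trainer = CoreModel(model_group='segmentation', model_name='baseline',
                    project_dir='/work/demo', lr_type='cosine')  # extra kwargs land in model_args
# Ensured on init: /work/demo/Results/segmentation/baseline/model_args.json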