예제 #1
0
    def test_load_optimizer(self, mock_toml_load):
        """
        Check the config of the optimizer built from "keras:Adam".

        The mock is given a side effect that returns its argument unchanged,
        presumably so that the dict below can stand in for a toml file.
        Every entry of the optimizer's config is then compared to the
        expected value.

        """
        def passthrough(file):
            # Hand back whatever was passed in, without loading anything.
            return file

        mock_toml_load.side_effect = passthrough

        compile_section = {
            "optimizer": "keras:Adam",
            "losses": None,
            "lr": 1.0,
        }
        config = {
            "model": {"blocks": None},
            "compile": compile_section,
        }
        optimizer = ModelBuilder(config)._get_optimizer()

        expected = dict(
            name='Adam',
            learning_rate=1.0,
            beta_1=0.8999999761581421,
            beta_2=0.9990000128746033,
            decay=0.0,
            epsilon=1e-07,
            amsgrad=False,
        )

        # Driven by the optimizer's config: a key that is missing from
        # `expected` raises a KeyError here.
        for key, value in optimizer.get_config().items():
            self.assertAlmostEqual(value, expected[key])
예제 #2
0
    def test_model_setup_CNN_model(self):
        """
        Build a model from self.model_file and check its basic properties.

        Compared are the input and output shapes (without the first
        dimension), the number of layers and the epsilon of the optimizer.

        """
        cnn = ModelBuilder(self.model_file).build(self.orga)

        input_shape = cnn.input_shape[1:]
        self.assertEqual(input_shape, self.input_shapes["input_A"])

        output_shape = cnn.output_shape[1:]
        self.assertEqual(output_shape, (2, ))

        n_layers = len(cnn.layers)
        self.assertEqual(n_layers, 14)

        epsilon = cnn.optimizer.epsilon
        self.assertEqual(epsilon, 0.2)
예제 #3
0
    def test_custom_blocks(self, mock_toml_load):
        """
        Check that a block type given as a keyword argument is used.

        "my_custom_block" is handed to the ModelBuilder as layers.Dense, and
        the last layer of the built model must then be a Dense layer with
        the requested number of units.

        """
        # The mock returns its argument unchanged, presumably so that the
        # dict below can be passed in place of a toml file.
        mock_toml_load.side_effect = lambda file: file

        custom_block = {
            "type": "my_custom_block",
            "units": 10,
        }
        config = {"model": {"blocks": [custom_block]}}

        builder = ModelBuilder(config, my_custom_block=layers.Dense)
        model = builder.build_with_input({"a": (10, 1)}, compile_model=False)

        last_layer = model.layers[-1]
        self.assertIsInstance(last_layer, layers.Dense)
        self.assertEqual(last_layer.get_config()["units"], 10)
예제 #4
0
    def test_integration_multi_input_model(self):
        """
        Smoke test: run the whole workflow on dummy data.

        Nothing is asserted; the test only fails if one of the steps
        raises an error. The steps are:

        1. Build a model with the ModelBuilder.
        2. Train and validate for 2 epochs.
        3. Make a new organizer.
        4. Train and validate for 1 epoch, with a learning rate function
           and a sample modifier set in the config.
        5. Call orga.train, orga.validate and orga.predict once each.

        """
        orga = self.make_orga()

        toml_path = os.path.join(self.data_folder, "model_test.toml")
        initial_model = ModelBuilder(toml_path).build(orga)
        orga.train_and_validate(initial_model, epochs=2)

        def lr_schedule(epoch, fileno):
            # Learning rate as a function of the epoch and file number.
            return 0.001 * (epoch + 0.1 * fileno)

        def copy_x_values(info_blob):
            # Return the x_values of the info blob as a new dict.
            samples = info_blob["x_values"]
            return {name: samples[name] for name in samples}

        # New organizer; no model is passed to train_and_validate this time.
        orga = self.make_orga()
        orga.cfg.learning_rate = lr_schedule
        orga.cfg.sample_modifier = copy_x_values
        orga.train_and_validate(epochs=1)

        orga.train()
        orga.validate()
        orga.predict()
예제 #5
0
def train(directory,
          list_file=None,
          config_file=None,
          model_file=None,
          to_epoch=None):
    """
    Set up an Organizer and run its train_and_validate method.

    If the organizer reports no latest epoch, a new model is built with the
    ModelBuilder first. Otherwise, no model is handed over.

    Parameters
    ----------
    directory : str
        First argument of the Organizer. Also the place where "model.toml"
        is searched for with find_file if no model_file is given.
    list_file : str, optional
        Second argument of the Organizer.
    config_file : str, optional
        Third argument of the Organizer.
    model_file : str, optional
        File from which the ModelBuilder builds the model. Only used if a
        new model has to be built.
    to_epoch : optional
        Passed on to orga.train_and_validate.

    Returns
    -------
    The return value of orga.train_and_validate.

    """
    from orcanet.core import Organizer
    from orcanet.model_builder import ModelBuilder
    from orcanet.misc import find_file

    orga = Organizer(directory, list_file, config_file, tf_log_level=1)

    model = None
    if orga.io.get_latest_epoch() is None:
        # No epoch yet, so this is the start of the training
        print("Building new model")
        toml_path = (find_file(directory, "model.toml")
                     if model_file is None else model_file)
        model = ModelBuilder(toml_path).build(orga, verbose=False)

    return orga.train_and_validate(model=model, to_epoch=to_epoch)
예제 #6
0
def orca_train(output_folder,
               list_file,
               config_file,
               model_file,
               recompile_model=False):
    """
    Run orga.train_and_validate with a model made by the ModelBuilder.

    Parameters
    ----------
    output_folder : str
        Path to the folder where everything gets saved to, e.g. the summary
        log file, the plots, the trained models, etc.
    list_file : str
        Path to a list file which contains the paths to all the h5 files
        that should be used for training and validation.
    config_file : str
        Path to a .toml file which overwrites some of the default settings
        for training and validating a model.
    model_file : str
        Path to a file with parameters to build a model of a predefined
        architecture with OrcaNet.
    recompile_model : bool
        If True, the latest saved model is loaded and compiled again.
        Necessary if e.g. the loss_weights are changed during the training.
        Has no effect at the start of a training.

    """
    # The Organizer holds the input data. update_objects then loads the
    # sample-, label-, and dataset-modifiers and the custom objects into it.
    orga = Organizer(output_folder, list_file, config_file, tf_log_level=1)
    update_objects(orga, model_file)

    # None means that no model is handed to train_and_validate.
    model = None
    if orga.io.get_latest_epoch() is None:
        # Start of the training: a compiled model is required. The
        # ModelBuilder constructs it from the toml file, adapted to the
        # datasets (and modifiers) of the orga instance.
        model = ModelBuilder(model_file).build(orga, log_comp_opts=True)

    elif recompile_model is True:
        builder = ModelBuilder(model_file)

        saved_model_path = orga.io.get_model_path(-1, -1)
        saved_model = ks.models.load_model(
            saved_model_path,
            custom_objects=orga.cfg.get_custom_objects(),
        )
        print("Recompiling the saved model")
        model = builder.compile_model(
            saved_model,
            custom_objects=orga.cfg.get_custom_objects(),
        )
        builder.log_model_properties(orga)

    try:
        # Replace the learning rate from the config with a custom schedule
        base_lr = orga.cfg.learning_rate
        orga.cfg.learning_rate = orca_learning_rates(
            base_lr, orga.io.get_no_of_files("train"))
    except NameError:
        # Presumably for the case that orca_learning_rates is not
        # available; the learning rate from the config is kept then.
        pass

    # Start the training
    orga.train_and_validate(model=model)
예제 #7
0
 def test_model_setup_CNN_model_custom_callback(self):
     """
     Check that an optimizer set on the builder is used by the model.

     An SGD instance is assigned to builder.optimizer before building;
     the optimizer of the built model must then be an SGD instance.

     """
     builder = ModelBuilder(self.model_file)
     optimizer_class = ks.optimizers.SGD
     builder.optimizer = optimizer_class()

     built_model = builder.build(self.orga)
     self.assertIsInstance(built_model.optimizer, optimizer_class)