Exemplo n.º 1
0
    def test_train_model_runs_successfully_for_simplified_case(self):
        """Smoke-test the training pipeline in debug mode.

        Calls ``train_model`` with the paths and network configuration stored
        on the test instance, then checks that each expected output file is
        present in the model directory. This only verifies that the pipeline
        runs to completion; it says nothing about the quality of the model.
        """
        train_model(
            self.trainingPath,
            self.modelPath,
            self.config_network,
            debug_mode=True,
        )

        # Snapshot of the model directory once training has returned.
        produced = os.listdir(self.modelPath)

        required = (
            "checkpoint",
            "config_network.json",
            "evolution_stats.pkl",
            "evolution.pkl",
            "model.ckpt.data-00000-of-00001",
            "model.ckpt.index",
            "model.ckpt.meta",
            "report.txt",
        )

        for required_name in required:
            assert required_name in produced
Exemplo n.º 2
0
def compute_training(configfile, path_trainingset, path_model, path_model_init=None, gpu_per=1.0):
    """Load a network configuration from the model folder and launch training.

    Args:
        configfile: Name of the JSON configuration file, looked up inside
            ``path_model``.
        path_trainingset: Path to the training set; forwarded to
            ``train_model`` unchanged.
        path_model: Folder that holds the configuration file; also forwarded
            to ``train_model``.
        path_model_init: Accepted for interface compatibility; not used by
            this function.
        gpu_per: Forwarded to ``train_model`` as its ``gpu_per`` keyword.

    Side effects:
        Changes the process working directory to ``sys.path[0]``.
    """
    # Pin the working directory to sys.path[0] (normally the directory of the
    # launching script) so the paths below resolve consistently.
    os.chdir(sys.path[0])

    config_file_path = os.path.join(path_model, configfile)
    with open(config_file_path, 'r') as config_stream:
        config_network = json.load(config_stream)

    # The training entry point is imported at call time, not at module import.
    from AxonDeepSeg.train_network import train_model
    train_model(path_trainingset, path_model, config_network, gpu_per=gpu_per)
Exemplo n.º 3
0
    def test_train_model_runs_successfully_for_simplified_case(self):
        """Smoke-test the training pipeline in debug mode.

        Calls ``train_model`` with the stringified training and model paths
        and the network configuration stored on the test instance, then
        checks that each expected output file is present in the model
        directory. This only verifies that the pipeline runs to completion;
        it says nothing about the quality of the model.
        """
        training_dir = str(self.trainingPath)
        model_dir = str(self.modelPath)
        train_model(training_dir, model_dir, self.config_network, debug_mode=True)

        # Names of the entries found in the model directory after training.
        produced = [entry.name for entry in self.modelPath.iterdir()]

        required = (
            "checkpoint",
            "config_network.json",
            "model.ckpt.data-00000-of-00001",
            "model.ckpt.index",
            "model.ckpt.meta",
        )

        for required_name in required:
            assert required_name in produced
Exemplo n.º 4
0
    # Weighted cost parameters:
    "weighted_cost-activate": True,
    "weighted_cost-balanced_activate": True,
    "weighted_cost-balanced_weights": [1.1, 1, 1.3],
    "weighted_cost-boundaries_sigma": 2,
    "weighted_cost-boundaries_activate": False,

    # Data augmentation parameters:
    "da-type": "all",
    "da-2-random_rotation-activate": False,
    "da-5-noise_addition-activate": False,
    "da-3-elastic-activate": True,
    "da-0-shifting-activate": True,
    "da-4-flipping-activate": True,
    "da-1-rescaling-activate": False
}

# Location of the persisted network configuration for this model.
config_path = path_model + filename

# On the first run there is no configuration file yet: persist the in-script
# defaults so the model folder records the settings it was trained with.
if not os.path.exists(config_path):
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)

# Always read the configuration back from disk, so an existing file takes
# precedence over the in-script defaults.
with open(config_path, 'r') as fd:
    config_network = json.load(fd)

tf.reset_default_graph()

# Train with the configuration that was actually loaded. Passing `config`
# here would silently ignore an existing configuration file and leave the
# file on disk out of sync with the trained model.
train_model(path_training, path_model, config_network)