Code example #1
def uw3_trainer_params(with_validation=False,
                       with_split=False,
                       preload=True,
                       debug=False):
    p = CalamariTestScenario.default_trainer_params()
    p.scenario.debug_graph_construction = debug
    p.force_eager = debug

    train = FileDataParams(
        images=glob_all(
            [os.path.join(this_dir, "data", "uw3_50lines", "train", "*.png")]),
        preload=preload,
    )
    if with_split:
        p.gen = CalamariSplitTrainerPipelineParams(validation_split_ratio=0.2,
                                                   train=train)
    elif with_validation:
        p.gen.val.images = glob_all(
            [os.path.join(this_dir, "data", "uw3_50lines", "test", "*.png")])
        p.gen.val.preload = preload
        p.gen.train = train
        p.gen.__post_init__()
    else:
        p.gen = CalamariTrainOnlyPipelineParams(train=train)

    p.gen.setup.val.batch_size = 1
    p.gen.setup.val.num_processes = 1
    p.gen.setup.train.batch_size = 1
    p.gen.setup.train.num_processes = 1
    post_init(p)
    return p
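
A minimal usage sketch (not part of the original listing), mirroring the architecture tests in code examples #10, #13, #15 and #16; the import path of main is an assumption:

# Hypothetical usage of uw3_trainer_params(); main is the Calamari training entry
# point as called in the tests below (assumed import path).
import tempfile

from calamari_ocr.scripts.train import main  # assumed location of the training entry point

trainer_params = uw3_trainer_params(with_validation=True)
with tempfile.TemporaryDirectory() as d:
    trainer_params.output_dir = d  # write the trained model into a throwaway directory
    main(trainer_params)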
Code example #2
def update_model(params: dict, path: str):
    logger.info(f"Updating model at {path}")
    trainer_params = TrainerParams.from_dict(params)
    post_init(trainer_params)
    scenario_params = trainer_params.scenario
    scenario = CalamariScenario(scenario_params)
    scenario.setup()
    input_layers = scenario.data.create_input_layers()
    outputs = scenario.model.build(input_layers)
    pred_model = keras.models.Model(inputs=input_layers, outputs=outputs)
    with h5py.File(path + '.h5', 'r') as f:
        if 'layer_names' not in f.attrs and 'model_weights' in f:
            f = f['model_weights']
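        # The Calamari graph layer is expected at index 2 of the prediction model.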
        graph: Graph = pred_model.layers[2]
        load_weights_from_hdf5_group(
            f, [l for l in graph.layer_instances if len(l.weights) > 0] +
            [graph.logits])

    logger.info(f"Writing converted model at {path}.tmp.h5")
    pred_model.save(path + '.tmp.h5', include_optimizer=False)
    logger.info(f"Attempting to load converted model at {path}.tmp.h5")
    keras.models.load_model(
        path + '.tmp.h5',
        custom_objects=CalamariScenario.model_cls().all_custom_objects())
    logger.info(f"Replacing old model at {path}.h5")
    os.remove(path + '.h5')
    os.rename(path + '.tmp.h5', path + '.h5')
    logger.info(f"New model successfully written")
    keras.backend.clear_session()
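
A hedged sketch of how update_model() might be invoked; the checkpoint prefix is a placeholder, and the params dict is assumed to come from the trainer-params JSON stored next to the .h5 weights (the naming convention shown in code example #8):

# Hypothetical call; "text.ckpt" is a placeholder checkpoint prefix.
import json

with open("text.ckpt.json") as f:
    params = json.load(f)          # trainer params stored next to the weights
update_model(params, "text.ckpt")  # converts and replaces text.ckpt.h5 in place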
Code example #3
File: scenariobase.py  Project: Planet-AI-GmbH/tfaip
 def create_lav(cls, lav_params: "LAVParams",
                scenario_params: TScenarioParams) -> "LAV":
     post_init(lav_params)
     post_init(scenario_params)
     return cls.lav_cls()(
         params=lav_params,
         scenario_params=scenario_params,
         **cls.static_lav_kwargs(scenario_params),
     )
Code example #4
def setup_trainer_params(train_dataset, output_dir, debug=False):
    p = CalamariOmmr4allEnsembleScenario.default_trainer_params()
    p.force_eager = debug
    p.gen.train = train_dataset
    # p.gen.setup.val = val_dataset
    p.output_dir = output_dir
    p.best_model_prefix = "text"
    p.write_checkpoints = False
    post_init(p)
    return p
Code example #5
    def __init__(self, params: TDP, **kwargs):
        assert len(
            kwargs) == 0, f"Not all kwargs processed by subclasses: {kwargs}"
        post_init(params)
        self._params = params
        self.resources = ResourceManager(params.resource_base_path)
        self.resources.register_all(params)
        post_init(params)

        self._pipelines: Dict[PipelineMode, DataPipeline] = {}
Code example #6
def setup_trainer_params(preload=True, debug=False):
    p = CalamariTestEnsembleScenario.default_trainer_params()
    p.force_eager = debug

    p.gen.train = FileDataParams(
        images=glob_all([os.path.join(this_dir, "data", "uw3_50lines", "train", "*.png")]),
        preload=preload,
    )

    post_init(p)
    return p
Code example #7
File: predictorbase.py  Project: Planet-AI-GmbH/tfaip
 def __init__(self, params: PredictorParams, data: "DataBase", **kwargs):
     assert len(
         kwargs) == 0, f"Not all kwargs processed by subclasses: {kwargs}"
     post_init(params)
     self._params = params
     self._params.pipeline.mode = PipelineMode.PREDICTION
     if params.include_targets:
         self._params.pipeline.mode = PipelineMode.EVALUATION
     self.device_config = DeviceConfig(self._params.device)
     self._data = data
     self.benchmark_results = BenchmarkResults()
     self._keras_model: Optional[keras.Model] = None
Code example #8
    def __post_init__(self):
        self.scenario.default_serve_dir = f"{self.best_model_prefix}.ckpt.h5"
        self.scenario.trainer_params_filename = f"{self.best_model_prefix}.ckpt.json"
        self.early_stopping.best_model_name = ""

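        # Propagate the ensemble size (number of folds) and the input channels
        # from the scenario into the train and validation generators.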
        self.gen.train_gen().n_folds = self.scenario.model.ensemble
        self.gen.train_gen().channels = self.scenario.data.input_channels
        if self.gen.val_gen() is not None:
            self.gen.val_gen().channels = self.scenario.data.input_channels
            self.gen.val_gen().n_folds = self.scenario.model.ensemble

        if self.network:
            self.scenario.model.layers = graph_params_from_definition_string(
                self.network)
            post_init(self.scenario.model)
Code example #9
    def default_trainer_params(cls):
        p = super().default_trainer_params()
        p.gen.setup.val.batch_size = 1
        p.gen.setup.val.num_processes = 1
        p.gen.setup.train.batch_size = 1
        p.gen.setup.train.num_processes = 1
        p.epochs = 70
        # p.samples_per_epoch = 2
        p.scenario.data.pre_proc.run_parallel = False
        # p.model.ensemble = 1
        post_init(p)
        return p
Code example #10
 def test_concat_cnn_architecture(self):
     trainer_params = uw3_trainer_params()
     trainer_params.scenario.model.layers = [
         Conv2DLayerParams(filters=10),
         MaxPool2DLayerParams(),
         DilatedBlockLayerParams(filters=10),
         TransposedConv2DLayerParams(filters=10),
         ConcatLayerParams(
             concat_indices=[1, 4]
         ),  # corresponds to output of first and fourth layer
         Conv2DLayerParams(filters=10),
         BiLSTMLayerParams(hidden_nodes=10),
     ]
     post_init(trainer_params)
     cmd_line_trainer_params = parse_args([
         "--network",
         "conv=10,pool=2x2,db=10:2,tconv=10,concat=1:4,conv=10,lstm=10"
     ])
     self.assertDictEqual(trainer_params.scenario.model.to_dict(),
                          cmd_line_trainer_params.scenario.model.to_dict())
     cmd_line_trainer_params = parse_args([
         "--model.layers",
         "Conv",
         "Pool",
         "DilatedBlock",
         "TConv",
         "Concat",
         "Conv",
         "BiLSTM",
         "--model.layers.0.filters",
         "10",
         "--model.layers.2.filters",
         "10",
         "--model.layers.3.filters",
         "10",
         "--model.layers.4.concat_indices",
         "1",
         "4",
         "--model.layers.5.filters",
         "10",
         "--model.layers.6.hidden_nodes",
         "10",
     ])
     self.assertDictEqual(trainer_params.scenario.model.to_dict(),
                          cmd_line_trainer_params.scenario.model.to_dict())
     with tempfile.TemporaryDirectory() as d:
         trainer_params.output_dir = d
         main(trainer_params)
Code example #11
 def default_trainer_params(cls):
     p = super().default_trainer_params()
     p.scenario.model.layers = [
         Conv2DLayerParams(filters=2),
         MaxPool2DLayerParams(pool_size=IntVec2D(4, 4)),
         BiLSTMLayerParams(hidden_nodes=2),
         DropoutLayerParams(rate=0.5),
     ]
     p.gen.setup.val.batch_size = 1
     p.gen.setup.val.num_processes = 1
     p.gen.setup.train.batch_size = 1
     p.gen.setup.train.num_processes = 1
     p.epochs = 1
     p.samples_per_epoch = 2
     p.scenario.data.pre_proc.run_parallel = False
     post_init(p)
     return p
Code example #12
File: nashi_client.py  Project: andbue/nashi
def get_preproc_image():
    data_params = Data.default_params()
    data_params.skip_invalid_gt = False
    data_params.pre_proc.run_parallel = False
    data_params.pre_proc.processors = data_params.pre_proc.processors[:-1]
    for p in data_params.pre_proc.processors_of_type(FinalPreparationProcessorParams):
        p.pad = 0
    post_init(data_params)
    pl = Data(data_params).create_pipeline(DataPipelineParams, None)
    pl.mode = PipelineMode.PREDICTION
    preproc = data_params.pre_proc.create(pl)

    def pp(image):
        its = InputSample(
            image, None, SampleMeta("001", fold_id="01")
        ).to_input_target_sample()
        s = preproc.apply_on_sample(its)
        return s.inputs

    return pp
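
A small usage sketch for the factory above (not from the original source); the synthetic white line stands in for a real grayscale text-line image:

# Hypothetical usage; a blank 48x512 grayscale line replaces a real image.
import numpy as np

preproc_image = get_preproc_image()
line = np.full((48, 512), 255, dtype=np.uint8)
inputs = preproc_image(line)  # preprocessed inputs ready for prediction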
Code example #13
 def test_pure_cnn_architecture(self):
     trainer_params = uw3_trainer_params()
     trainer_params.scenario.model.layers = [
         Conv2DLayerParams(filters=10),
         MaxPool2DLayerParams(),
         Conv2DLayerParams(filters=20,
                           strides=IntVec2D(2, 2),
                           kernel_size=IntVec2D(4, 4)),
         Conv2DLayerParams(filters=30),
     ]
     post_init(trainer_params)
     cmd_line_trainer_params = parse_args(
         ["--network", "conv=10,pool=2x2,conv=20:4x4:2x2,conv=30"])
     self.assertDictEqual(trainer_params.scenario.model.to_dict(),
                          cmd_line_trainer_params.scenario.model.to_dict())
     cmd_line_trainer_params = parse_args([
         "--model.layers",
         "Conv",
         "Pool",
         "Conv",
         "Conv",
         "--model.layers.0.filters",
         "10",
         "--model.layers.2.filters",
         "20",
         "--model.layers.2.kernel_size.x",
         "4",
         "--model.layers.2.kernel_size.y",
         "4",
         "--model.layers.2.strides.x",
         "2",
         "--model.layers.2.strides.y",
         "2",
         "--model.layers.3.filters",
         "30",
     ])
     self.assertDictEqual(trainer_params.scenario.model.to_dict(),
                          cmd_line_trainer_params.scenario.model.to_dict())
     with tempfile.TemporaryDirectory() as d:
         trainer_params.output_dir = d
         main(trainer_params)
Code example #14
File: nashi_client.py  Project: andbue/nashi
def get_preproc_text(rtl=False):
    data_params = Data.default_params()
    data_params.skip_invalid_gt = False
    data_params.pre_proc.run_parallel = False

    if rtl:
        for p in data_params.pre_proc.processors_of_type(BidiTextProcessorParams):
            p.bidi_direction = BidiDirection.RTL
    post_init(data_params)

    pl = Data(data_params).create_pipeline(DataPipelineParams, None)
    pl.mode = PipelineMode.TARGETS
    preproc = data_params.pre_proc.create(pl)

    def pp(text):
        its = InputSample(
            None, text, SampleMeta("001", fold_id="01")
        ).to_input_target_sample()
        s = preproc.apply_on_sample(its)
        return s.targets

    return pp
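
The text-side counterpart, again as a hypothetical sketch; the sample string is arbitrary:

# Hypothetical usage; pass rtl=True for right-to-left scripts handled by the Bidi processor.
preproc_text = get_preproc_text(rtl=False)
targets = preproc_text("A ground-truth transcription line.")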
Code example #15
 def test_dilated_block_architecture(self):
     trainer_params = uw3_trainer_params()
     trainer_params.scenario.model.layers = [
         Conv2DLayerParams(filters=10),
         MaxPool2DLayerParams(strides=IntVec2D(2, 2)),
         DilatedBlockLayerParams(filters=10),
         DilatedBlockLayerParams(filters=10),
         Conv2DLayerParams(filters=10),
     ]
     post_init(trainer_params)
     cmd_line_trainer_params = parse_args(
         ["--network", "conv=10,pool=2x2:2x2,db=10:2,db=10:2,conv=10"])
     self.assertDictEqual(trainer_params.scenario.model.to_dict(),
                          cmd_line_trainer_params.scenario.model.to_dict())
     cmd_line_trainer_params = parse_args([
         "--model.layers",
         "Conv",
         "Pool",
         "DilatedBlock",
         "DilatedBlock",
         "Conv",
         "--model.layers.0.filters",
         "10",
         "--model.layers.1.strides",
         "2",
         "2",
         "--model.layers.2.filters",
         "10",
         "--model.layers.3.filters",
         "10",
         "--model.layers.4.filters",
         "10",
     ])
     self.assertDictEqual(trainer_params.scenario.model.to_dict(),
                          cmd_line_trainer_params.scenario.model.to_dict())
     with tempfile.TemporaryDirectory() as d:
         trainer_params.output_dir = d
         main(trainer_params)
Code example #16
 def test_pure_lstm_architecture(self):
     trainer_params = uw3_trainer_params()
     trainer_params.scenario.model.layers = [
         BiLSTMLayerParams(hidden_nodes=10),
         BiLSTMLayerParams(hidden_nodes=20),
     ]
     post_init(trainer_params)
     cmd_line_trainer_params = parse_args(["--network", "lstm=10,lstm=20"])
     self.assertDictEqual(trainer_params.scenario.model.to_dict(),
                          cmd_line_trainer_params.scenario.model.to_dict())
     cmd_line_trainer_params = parse_args([
         "--model.layers",
         "BiLSTM",
         "BiLSTM",
         "--model.layers.0.hidden_nodes",
         "10",
         "--model.layers.1.hidden_nodes",
         "20",
     ])
     self.assertDictEqual(trainer_params.scenario.model.to_dict(),
                          cmd_line_trainer_params.scenario.model.to_dict())
     with tempfile.TemporaryDirectory() as d:
         trainer_params.output_dir = d
         main(trainer_params)
Code example #17
File: scenariobase.py  Project: Planet-AI-GmbH/tfaip
    def create_predictor(cls, model: str,
                         params: "PredictorParams") -> "Predictor":
        post_init(params)
        scenario_params = cls.params_from_path(model)
        data_params = scenario_params.data
        post_init(data_params)
        predictor = cls.predictor_cls()(
            params,
            cls.data_cls()(data_params,
                           **cls.static_data_kwargs(scenario_params)),
            **cls.static_predictor_kwargs(scenario_params),
        )
        model_cls = cls.model_cls()
        run_eagerly = params.run_eagerly
        if isinstance(model, str):
            model = keras.models.load_model(
                os.path.join(model, "serve"),
                compile=False,
                custom_objects=model_cls.all_custom_objects()
                if run_eagerly else model_cls.base_custom_objects(),
            )

        predictor.set_model(model)
        return predictor
Code example #18
File: scenariobase.py  Project: Planet-AI-GmbH/tfaip
 def create_multi_lav(
     cls,
     lav_params: "LAVParams",
     scenario_params: TScenarioParams,
     predictor_params: Optional["PredictorParams"] = None,
 ):
     post_init(lav_params)
     post_init(scenario_params)
     post_init(predictor_params)
     return MultiLAV(
         lav_params,
         scenario_params,
         predictor_params=predictor_params
         or cls.multi_predictor_cls().params_cls()(),
         **cls.static_lav_kwargs(scenario_params),
     )
Code example #19
File: evaluator.py  Project: Planet-AI-GmbH/tfaip
 def __init__(self, params: TP, **kwargs):
     assert len(
         kwargs) == 0, f"Not all kwargs processed by subclasses: {kwargs}"
     post_init(params)
     self.params = params
Code example #20
File: scenariobase.py  Project: Planet-AI-GmbH/tfaip
 def create_trainer(cls,
                    trainer_params: "TrainerParams",
                    restore=False) -> "Trainer":
     post_init(trainer_params)
     return cls.trainer_cls()(trainer_params, cls(trainer_params.scenario),
                              restore)
Code example #21
File: version3_4to5.py  Project: jacektl/calamari
def migrate3to5(trainer_params: dict) -> dict:
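    # Map version-3 processor names to their fully-qualified parameter classes in the version-5 format.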
    convert_processor_name = {
        "CenterNormalizer":
        "calamari_ocr.ocr.dataset.imageprocessors.center_normalizer:CenterNormalizerProcessorParams",
        "DataRangeNormalizer":
        "calamari_ocr.ocr.dataset.imageprocessors.data_range_normalizer:DataRangeProcessorParams",
        "FinalPreparation":
        "calamari_ocr.ocr.dataset.imageprocessors.final_preparation:FinalPreparationProcessorParams",
        "AugmentationProcessor":
        "calamari_ocr.ocr.dataset.imageprocessors.augmentation:AugmentationProcessorParams",
        "StripTextProcessor":
        "calamari_ocr.ocr.dataset.textprocessors.basic_text_processors:StripTextProcessorParams",
        "BidiTextProcessor":
        "calamari_ocr.ocr.dataset.textprocessors.basic_text_processors:BidiTextProcessorParams",
        "TextNormalizer":
        "calamari_ocr.ocr.dataset.textprocessors.text_normalizer:TextNormalizerProcessorParams",
        "TextRegularizer":
        "calamari_ocr.ocr.dataset.textprocessors.text_regularizer:TextRegularizerProcessorParams",
        "PrepareSampleProcessor":
        "calamari_ocr.ocr.dataset.imageprocessors.preparesample:PrepareSampleProcessorParams",
        "ReshapeOutputsProcessor":
        "calamari_ocr.ocr.dataset.postprocessors.reshape:ReshapeOutputsProcessorParams",
        "CTCDecoderProcessor":
        "calamari_ocr.ocr.dataset.postprocessors.ctcdecoder:CTCDecoderProcessorParams",
    }

    for name in ["scenario"]:
        rename(trainer_params, name + "_params", name)

    for name in ["model", "data"]:
        rename(trainer_params["scenario"], name + "_params", name)

    scenario = trainer_params["scenario"]
    scenario["data"]["__cls__"] = "calamari_ocr.ocr.dataset.params:DataParams"
    scenario["model"]["__cls__"] = "calamari_ocr.ocr.model.params:ModelParams"

    data = scenario["data"]
    rename(data, "line_height_", "line_height")
    rename(data, "skip_invalid_gt_", "skip_invalid_gt")
    rename(data, "resource_base_path_", "resource_base_path")
    rename(data, "pre_processors_", "pre_proc")
    rename(data, "post_processors_", "post_proc")
    rename(data["pre_proc"], "sample_processors", "processors")
    rename(data["post_proc"], "sample_processors", "processors")
    data["post_proc"][
        "__cls__"] = "tfaip.data.pipeline.processor.params:SequentialProcessorPipelineParams"
    data["pre_proc"][
        "__cls__"] = "tfaip.data.pipeline.processor.params:SequentialProcessorPipelineParams"
    for proc in data["pre_proc"]["processors"] + data["post_proc"][
            "processors"]:
        if "args" in proc:
            args = proc["args"]
            if args:
                for k, v in args.items():
                    proc[k] = v
            del proc["args"]
        name = proc["name"]
        del proc["name"]
        proc["__cls__"] = convert_processor_name[name]

    migrate_model_params(scenario["model"])

    params = CalamariScenario.params_from_dict(scenario)
    post_init(params)
    scenario = params.to_dict()

    return {"scenario": scenario}
Code example #22
File: scenariobase.py  Project: Planet-AI-GmbH/tfaip
 def create_multi_predictor(
         cls, paths: List[str],
         params: "PredictorParams") -> "MultiModelPredictor":
     post_init(params)
     predictor_cls = cls.multi_predictor_cls()
     return predictor_cls.from_paths(paths, params, cls)
Code example #23
File: scenariobase.py  Project: Planet-AI-GmbH/tfaip
 def create_evaluator(cls, params: EvaluatorParams,
                      **kwargs) -> EvaluatorBase:
     post_init(params)
     if cls.evaluator_cls() is None:
         raise NotImplementedError
     return cls.evaluator_cls()(params=params, **kwargs)
Code example #24
def migrate3to4(trainer_params: dict) -> dict:
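    # Map version-3 processor names to their fully-qualified parameter classes in the version-4 format.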
    convert_processor_name = {
        "CenterNormalizer":
        "calamari_ocr.ocr.dataset.imageprocessors.center_normalizer:CenterNormalizerProcessorParams",
        "DataRangeNormalizer":
        "calamari_ocr.ocr.dataset.imageprocessors.data_range_normalizer:DataRangeProcessorParams",
        "FinalPreparation":
        "calamari_ocr.ocr.dataset.imageprocessors.final_preparation:FinalPreparationProcessorParams",
        "AugmentationProcessor":
        "calamari_ocr.ocr.dataset.imageprocessors.augmentation:AugmentationProcessorParams",
        "StripTextProcessor":
        "calamari_ocr.ocr.dataset.textprocessors.basic_text_processors:StripTextProcessorParams",
        "TextNormalizer":
        "calamari_ocr.ocr.dataset.textprocessors.text_normalizer:TextNormalizerProcessorParams",
        "TextRegularizer":
        "calamari_ocr.ocr.dataset.textprocessors.text_regularizer:TextRegularizerProcessorParams",
        "PrepareSampleProcessor":
        "calamari_ocr.ocr.dataset.imageprocessors.preparesample:PrepareSampleProcessorParams",
        "ReshapeOutputsProcessor":
        "calamari_ocr.ocr.dataset.postprocessors.reshape:ReshapeOutputsProcessorParams",
        "CTCDecoderProcessor":
        "calamari_ocr.ocr.dataset.postprocessors.ctcdecoder:CTCDecoderProcessorParams",
    }

    for name in ['scenario']:
        rename(trainer_params, name + '_params', name)

    for name in ['model', 'data']:
        rename(trainer_params['scenario'], name + '_params', name)

    scenario = trainer_params['scenario']
    scenario['data']['__cls__'] = 'calamari_ocr.ocr.dataset.params:DataParams'
    scenario['model']["__cls__"] = "calamari_ocr.ocr.model.params:ModelParams"

    data = scenario['data']
    rename(data, 'line_height_', 'line_height')
    rename(data, 'skip_invalid_gt_', 'skip_invalid_gt')
    rename(data, 'resource_base_path_', 'resource_base_path')
    rename(data, 'pre_processors_', 'pre_proc')
    rename(data, 'post_processors_', 'post_proc')
    rename(data['pre_proc'], 'sample_processors', 'processors')
    rename(data['post_proc'], 'sample_processors', 'processors')
    data['post_proc'][
        '__cls__'] = "tfaip.data.pipeline.processor.params:SequentialProcessorPipelineParams"
    data['pre_proc'][
        '__cls__'] = "tfaip.data.pipeline.processor.params:SequentialProcessorPipelineParams"
    for proc in data['pre_proc']['processors'] + data['post_proc'][
            'processors']:
        if 'args' in proc:
            args = proc['args']
            if args:
                for k, v in args.items():
                    proc[k] = v
            del proc['args']
        name = proc['name']
        del proc['name']
        proc['__cls__'] = convert_processor_name[name]

    migrate_model_params(scenario['model'])

    params = CalamariScenario.params_from_dict(scenario)
    post_init(params)
    scenario = params.to_dict()

    return {'scenario': scenario}
Code example #25
    def run(self):
        # temporary dir
        temporary_dir = self.params.temporary_dir
        if temporary_dir is None:
            temporary_dir = tempfile.mkdtemp(prefix="calamari")
        else:
            temporary_dir = os.path.abspath(temporary_dir)

        if not os.path.exists(temporary_dir):
            os.makedirs(temporary_dir)

        # Compute the files in the cross fold (create a CrossFold)
        fold_file = os.path.join(temporary_dir, "folds.json")
        cross_fold = CrossFold(
            n_folds=self.params.n_folds,
            data_generator_params=self.params.trainer.gen.train,
            output_dir=temporary_dir,
            progress_bar=self.params.trainer.progress_bar)
        cross_fold.write_folds_to_json(fold_file)

        # Create the json argument file for each individual training
        run_args = []
        seed = self.params.trainer.random_seed or -1
        folds_to_run = self.params.single_fold if len(
            self.params.single_fold) > 0 else range(len(cross_fold.folds))
        for fold in folds_to_run:
            train_files = cross_fold.train_files(fold)
            test_files = cross_fold.test_files(fold)
            path = os.path.join(temporary_dir, "fold_{}.json".format(fold))
            with open(path, 'w') as f:
                trainer_params = deepcopy(self.params.trainer)
                trainer_params.gen = CalamariDefaultTrainerPipelineParams(
                    train=trainer_params.gen.train,
                    val=deepcopy(trainer_params.gen.train),
                    setup=trainer_params.gen.setup,
                )
                if cross_fold.is_h5_dataset:
                    tp = trainer_params.gen.train.to_dict()
                    del tp['__cls__']
                    tp["files"] = train_files
                    trainer_params.gen.train = Hdf5.from_dict(tp)
                    vp = trainer_params.gen.val.to_dict()
                    del vp['__cls__']
                    vp['files'] = test_files
                    trainer_params.gen.val = Hdf5.from_dict(vp)
                else:
                    trainer_params.gen.train.images = train_files
                    trainer_params.gen.val.images = test_files
                    trainer_params.gen.val.gt_extension = trainer_params.gen.train.gt_extension

                trainer_params.scenario.id = fold
                trainer_params.progress_bar_mode = 2
                trainer_params.output_dir = os.path.join(
                    temporary_dir, "fold_{}".format(fold))
                trainer_params.early_stopping.best_model_output_dir = self.params.best_models_dir
                trainer_params.early_stopping.best_model_name = ''
                best_model_prefix = self.params.best_model_label.format(
                    id=fold)
                trainer_params.best_model_prefix = best_model_prefix

                if self.params.visible_gpus:
                    assert trainer_params.device.gpus is None, "Using visible_gpus with device.gpus is not supported"
                    trainer_params.device.gpus = [
                        self.params.visible_gpus[fold %
                                                 len(self.params.visible_gpus)]
                    ]

                if seed >= 0:
                    trainer_params.random_seed = seed + fold

                if len(self.params.weights) == 1:
                    trainer_params.warmstart.model = self.params.weights[0]
                elif len(self.params.weights) > 1:
                    trainer_params.warmstart.model = self.params.weights[fold]

                # start from scratch via None
                if trainer_params.warmstart.model:
                    if len(
                            trainer_params.warmstart.model.strip()
                    ) == 0 or trainer_params.warmstart.model.upper() == "NONE":
                        trainer_params.warmstart.model = None
                    else:
                        # access model once to upgrade the model if necessary
                        # (can not be performed in parallel if multiple folds use the same model)
                        SavedCalamariModel(trainer_params.warmstart.model)

                post_init(trainer_params)

                json.dump(
                    trainer_params.to_dict(),
                    f,
                    indent=4,
                )

            run_args.append({
                "json": path,
                "args": trainer_params,
                "id": fold,
                'train_script': self.train_script_path,
                'run': self.params.run,
                'verbose': True
            })

        # Launch the individual processes for each training
        with multiprocessing.pool.ThreadPool(
                processes=self.params.max_parallel_models) as pool:
            # workaround to forward keyboard interrupt
            pool.map_async(train_individual_model, run_args).get()

        if not self.params.keep_temporary_files:
            import shutil
            shutil.rmtree(temporary_dir)