Code Example #1
    def train(self,
              feature_folder,
              semi_feature_folder=None,
              model_name=None,
              input_model_path=None,
              vocal_settings=None):
        """Model training.

        Train a new model or continue to train on a previously trained model.

        Parameters
        ----------
        feature_folder: Path
            Path to the folder containing the generated features.
        semi_feature_folder: Path
            If specified, semi-supervised learning will be leveraged, and the feature
            files contained in this folder will be used as unsupervised data.
        model_name: str
            The name for storing the trained model. If not given, will default to the
            current timestamp.
        input_model_path: Path
            Continue to train on the pre-trained model by specifying the path.
        vocal_settings: VocalSettings
            The configuration instance that holds all relevant settings for
            the life-cycle of building a model.
        """
        settings = self._validate_and_get_settings(vocal_settings)

        if input_model_path is not None:
            logger.info("Continue to train on model: %s", input_model_path)
            model, prev_set = self._load_model(input_model_path)
            settings.model.save_path = prev_set.model.save_path

        logger.info("Constructing dataset instance")
        split = settings.training.steps / (settings.training.steps +
                                           settings.training.val_steps)
        train_feat_files, val_feat_files = get_train_val_feat_file_list(
            feature_folder, split=split)

        output_types = (tf.float32, tf.float32)
        output_shapes = ((settings.training.context_length * 2 + 1, 174, 9),
                         (19, 6))  # noqa: E226
        train_dataset = VocalDatasetLoader(
                ctx_len=settings.training.context_length,
                feature_files=train_feat_files,
                num_samples=settings.training.epoch * settings.training.batch_size * settings.training.steps
            ) \
            .get_dataset(settings.training.batch_size, output_types=output_types, output_shapes=output_shapes)
        val_dataset = VocalDatasetLoader(
                ctx_len=settings.training.context_length,
                feature_files=val_feat_files,
                num_samples=settings.training.epoch * settings.training.val_batch_size * settings.training.val_steps
            ) \
            .get_dataset(settings.training.val_batch_size, output_types=output_types, output_shapes=output_shapes)
        if semi_feature_folder is not None:
            # Semi-supervised learning dataset.
            feat_files = glob.glob(f"{semi_feature_folder}/*.hdf")
            semi_dataset = VocalDatasetLoader(
                    ctx_len=settings.training.context_length,
                    feature_files=feat_files,
                    num_samples=settings.training.epoch * settings.training.batch_size * settings.training.steps
                ) \
                .get_dataset(settings.training.batch_size, output_types=output_types, output_shapes=output_shapes)
            train_dataset = tf.data.Dataset.zip((train_dataset, semi_dataset))

        if input_model_path is None:
            logger.info("Constructing new model")
            model = self.get_model(settings)

        # Notice: the original implementation uses AdamW as the optimizer, which is also available
        # as tensorflow_addons.optimizers.AdamW. However, we found that with AdamW the model fails
        # to converge; the training loss keeps increasing instead.
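        # For reference, a minimal AdamW setup through tensorflow_addons would look
        # roughly like the following (the weight_decay value is a placeholder, not
        # a tuned setting):
        #     import tensorflow_addons as tfa
        #     optimizer = tfa.optimizers.AdamW(
        #         weight_decay=1e-4,
        #         learning_rate=settings.training.init_learning_rate)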
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=settings.training.init_learning_rate)
        model.compile(optimizer=optimizer,
                      loss='bce',
                      metrics=['accuracy', 'binary_accuracy'])

        logger.info("Resolving model output path")
        if model_name is None:
            model_name = str(datetime.now()).replace(" ", "_")
        if not model_name.startswith(settings.model.save_prefix):
            model_name = settings.model.save_prefix + "_" + model_name
        model_save_path = jpath(settings.model.save_path, model_name)
        ensure_path_exists(model_save_path)
        write_yaml(settings.to_json(),
                   jpath(model_save_path, "configurations.yaml"))
        logger.info("Model output to: %s", model_save_path)

        logger.info("Constructing callbacks")
        callbacks = [
            tf.keras.callbacks.EarlyStopping(
                patience=settings.training.early_stop, monitor="val_loss"),
            tf.keras.callbacks.ModelCheckpoint(jpath(model_save_path,
                                                     "weights"),
                                               save_weights_only=True,
                                               monitor="val_loss")
        ]
        logger.info("Callback list: %s", callbacks)

        logger.info("Start training")
        history = model.fit(train_dataset,
                            validation_data=val_dataset,
                            epochs=settings.training.epoch,
                            steps_per_epoch=settings.training.steps,
                            validation_steps=settings.training.val_steps,
                            callbacks=callbacks,
                            use_multiprocessing=True,
                            workers=8)
        return model_save_path, history
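
A minimal usage sketch for the vocal train() method above. Here app stands for an instance of
the enclosing application class and all paths are placeholders; passing semi_feature_folder is
what activates the semi-supervised branch:

    model_path, history = app.train(
        "./train_feature",                          # folder with generated features
        semi_feature_folder="./unlabeled_feature",  # enables semi-supervised learning
        model_name="vocal_semi",
    )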
Code Example #2
    def train(self,
              feature_folder,
              model_name=None,
              input_model_path=None,
              music_settings=None):
        """Model training.

        Train the model from scratch or continue training given a model checkpoint.

        Parameters
        ----------
        feature_folder: Path
            Path to the generated features.
        model_name: str
            The name of the trained model. If not given, will default to the
            current timestamp.
        input_model_path: Path
            Specify the path to the model checkpoint in order to fine-tune
            the model.
        music_settings: MusicSettings
            The configuration that holds all relevant settings for
            the life-cycle of model building.
        """
        settings = self._validate_and_get_settings(music_settings)

        if input_model_path is not None:
            logger.info("Continue to train on model: %s", input_model_path)
            model, prev_set = self._load_model(
                input_model_path, custom_objects=self.custom_objects)
            settings.training.timesteps = prev_set.training.timesteps
            settings.training.label_type = prev_set.training.label_type
            settings.training.channels = prev_set.training.channels
            settings.model.save_path = prev_set.model.save_path
            settings.transcription_mode = prev_set.transcription_mode

        logger.info("Using label type: %s", settings.training.label_type)
        l_type = LabelType(settings.training.label_type)
        settings.transcription_mode = self.label_trans_mode_mapping[
            settings.training.label_type]

        logger.info("Constructing dataset instance")
        split = settings.training.steps / (settings.training.steps +
                                           settings.training.val_steps)
        train_feat_files, val_feat_files = get_train_val_feat_file_list(
            feature_folder, split=split)

        output_types = (tf.float32, tf.float32)
        train_dataset = MusicDatasetLoader(
                l_type.get_conversion_func(),
                feature_files=train_feat_files,
                num_samples=settings.training.batch_size * settings.training.steps,
                timesteps=settings.training.timesteps,
                channels=[FEATURE_NAME_TO_NUMBER[ch_name] for ch_name in settings.training.channels],
                feature_num=settings.training.feature_num
            ) \
            .get_dataset(settings.training.batch_size, output_types=output_types)
        val_dataset = MusicDatasetLoader(
                l_type.get_conversion_func(),
                feature_files=val_feat_files,
                num_samples=settings.training.val_batch_size * settings.training.val_steps,
                timesteps=settings.training.timesteps,
                channels=[FEATURE_NAME_TO_NUMBER[ch_name] for ch_name in settings.training.channels],
                feature_num=settings.training.feature_num
            ) \
            .get_dataset(settings.training.val_batch_size, output_types=output_types)

        if input_model_path is None:
            logger.info("Creating new model with type: %s",
                        settings.model.model_type)
            model_func = {
                "aspp": semantic_segmentation,
                "attn": semantic_segmentation_attn
            }[settings.model.model_type]
            model = model_func(timesteps=settings.training.timesteps,
                               out_class=l_type.get_out_classes(),
                               ch_num=len(settings.training.channels))

        logger.info("Compiling model with loss function type: %s",
                    settings.training.loss_function)
        loss_func = {
            "smooth": lambda y, x: smooth_loss(y, x, total_chs=l_type.get_out_classes()),
            "focal": focal_loss,
            "bce": tf.keras.losses.BinaryCrossentropy()
        }[settings.training.loss_function]
        model.compile(optimizer="adam", loss=loss_func, metrics=['accuracy'])

        logger.info("Resolving model output path")
        if model_name is None:
            model_name = str(datetime.now()).replace(" ", "_")
        if not model_name.startswith(settings.model.save_prefix):
            model_name = settings.model.save_prefix + "_" + model_name
        model_save_path = jpath(settings.model.save_path, model_name)
        ensure_path_exists(model_save_path)
        write_yaml(settings.to_json(),
                   jpath(model_save_path, "configurations.yaml"))
        write_yaml(model.to_yaml(),
                   jpath(model_save_path, "arch.yaml"),
                   dump=False)
        logger.info("Model output to: %s", model_save_path)

        logger.info("Constructing callbacks")
        callbacks = [
            EarlyStopping(patience=settings.training.early_stop),
            ModelCheckpoint(model_save_path, save_weights_only=True)
        ]
        logger.info("Callback list: %s", callbacks)

        logger.info("Start training")
        history = train_epochs(model,
                               train_dataset,
                               validate_dataset=val_dataset,
                               epochs=settings.training.epoch,
                               steps=settings.training.steps,
                               val_steps=settings.training.val_steps,
                               callbacks=callbacks)
        return model_save_path, history
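
A hypothetical fine-tuning sketch for the music train() method. When input_model_path is given,
timesteps, label_type, and channels are inherited from the checkpoint's saved settings, so they
need not be re-specified; app and the paths below are placeholders:

    model_path, history = app.train(
        "./train_feature",
        input_model_path="./checkpoints/music_prev",  # placeholder checkpoint path
        model_name="music_finetune",
    )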
Code Example #3
    def train(self, feature_folder, model_name=None, input_model_path=None, beat_settings=None):
        """Model training.

        Train the model from scratch or continue training given a model checkpoint.

        Parameters
        ----------
        feature_folder: Path
            Path to the generated features.
        model_name: str
            The name of the trained model. If not given, will default to the
            current timestamp.
        input_model_path: Path
            Specify the path to the model checkpoint in order to fine-tune
            the model.
        beat_settings: BeatSettings
            The configuration that holds all relevant settings for
            the life-cycle of model building.
        """
        settings = self._validate_and_get_settings(beat_settings)

        if input_model_path is not None:
            logger.info("Continue to train on model: %s", input_model_path)
            model, prev_set = self._load_model(input_model_path)
            settings.model.from_json(prev_set.model.to_json())
            settings.feature.time_unit = prev_set.feature.time_unit

        logger.info("Constructing dataset instance")
        split = settings.training.steps / (settings.training.steps + settings.training.val_steps)
        train_feat_files, val_feat_files = get_train_val_feat_file_list(feature_folder, split=split)

        output_types = (tf.float32, tf.float32)
        output_shapes = ((settings.model.timesteps, 178), (settings.model.timesteps, 2))
        train_dataset = BeatDatasetLoader(
                feature_files=train_feat_files,
                num_samples=settings.training.epoch * settings.training.batch_size * settings.training.steps,
                slice_hop=settings.model.timesteps // 2
            ) \
            .get_dataset(settings.training.batch_size, output_types=output_types, output_shapes=output_shapes)
        val_dataset = BeatDatasetLoader(
                feature_files=val_feat_files,
                num_samples=settings.training.epoch * settings.training.val_batch_size * settings.training.val_steps,
                slice_hop=settings.model.timesteps // 2
            ) \
            .get_dataset(settings.training.val_batch_size, output_types=output_types, output_shapes=output_shapes)

        if input_model_path is None:
            logger.info("Constructing new %s model for training.", settings.model.model_type)
            model_func = {
                "blstm": self._construct_blstm_model,
                "blstm_attn": self._construct_blstm_attn_model
            }[settings.model.model_type]
            model = model_func(settings)

        logger.info("Compiling model")
        optimizer = tf.keras.optimizers.Adam(learning_rate=settings.training.init_learning_rate)
        loss = lambda y, x: weighted_binary_crossentropy(y, x, down_beat_weight=settings.training.down_beat_weight)
        model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])

        logger.info("Resolving model output path")
        if model_name is None:
            model_name = str(datetime.now()).replace(" ", "_")
        if not model_name.startswith(settings.model.save_prefix):
            model_name = settings.model.save_prefix + "_" + model_name
        model_save_path = jpath(settings.model.save_path, model_name)
        ensure_path_exists(model_save_path)
        write_yaml(settings.to_json(), jpath(model_save_path, "configurations.yaml"))
        write_yaml(model.to_yaml(), jpath(model_save_path, "arch.yaml"), dump=False)
        logger.info("Model output to: %s", model_save_path)

        logger.info("Constructing callbacks")
        callbacks = [
            tf.keras.callbacks.EarlyStopping(
                patience=settings.training.early_stop, monitor="val_loss", restore_best_weights=False
            ),
            tf.keras.callbacks.ModelCheckpoint(
                jpath(model_save_path, "weights.h5"), save_weights_only=True, monitor="val_loss"
            )
        ]
        logger.info("Callback list: %s", callbacks)

        logger.info("Start training")
        history = model.fit(
            train_dataset,
            validation_data=val_dataset,
            epochs=settings.training.epoch,
            steps_per_epoch=settings.training.steps,
            validation_steps=settings.training.val_steps,
            callbacks=callbacks,
            use_multiprocessing=True,
            workers=8
        )
        return model_save_path, history
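
The beat trainer wraps weighted_binary_crossentropy in a lambda so that down_beat_weight comes
from the settings. As a rough illustration only (a plausible form of such a loss, not
necessarily omnizart's actual implementation), the weight scales the down-beat channel of an
element-wise binary crossentropy:

    import tensorflow as tf

    def weighted_bce_sketch(y_true, y_pred, down_beat_weight=5.0, epsilon=1e-7):
        # Assumed label layout: channel 0 = beat, channel 1 = down beat.
        y_pred = tf.clip_by_value(y_pred, epsilon, 1 - epsilon)
        bce = -(y_true * tf.math.log(y_pred) + (1 - y_true) * tf.math.log(1 - y_pred))
        weights = tf.constant([1.0, down_beat_weight])  # per-channel weights
        return tf.reduce_mean(bce * weights)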
Code Example #4
    def train(self,
              feature_folder,
              model_name=None,
              input_model_path=None,
              drum_settings=None):
        """Model training.

        Train a new model or continue to train on a previously trained model.

        Parameters
        ----------
        feature_folder: Path
            Path to the folder containing the generated features.
        model_name: str
            The name for storing the trained model. If not given, will default to the
            current timestamp.
        input_model_path: Path
            Continue to train on the pre-trained model by specifying the path.
        drum_settings: DrumSettings
            The configuration instance that holds all relevant settings for
            the life-cycle of building a model.
        """
        settings = self._validate_and_get_settings(drum_settings)

        if input_model_path is not None:
            logger.info("Continue to train on model: %s", input_model_path)
            model, prev_set = self._load_model(
                input_model_path, custom_objects=self.custom_objects)
            settings.model.save_path = prev_set.model.save_path
            settings.training.init_learning_rate = prev_set.training.init_learning_rate
            settings.training.res_block_num = prev_set.training.res_block_num

        logger.info("Constructing dataset instance")
        split = settings.training.steps / (settings.training.steps +
                                           settings.training.val_steps)
        train_feat_files, val_feat_files = get_train_val_feat_file_list(
            feature_folder, split=split)

        output_types = (tf.float32, tf.float32)
        output_shapes = ([120, 120, 4], [4, 13])
        train_dataset = PopDatasetLoader(
                feature_files=train_feat_files,
                num_samples=settings.training.epoch * settings.training.batch_size * settings.training.steps
            ) \
            .get_dataset(settings.training.batch_size, output_types=output_types, output_shapes=output_shapes)
        val_dataset = PopDatasetLoader(
                feature_files=val_feat_files,
                num_samples=settings.training.epoch * settings.training.val_batch_size * settings.training.val_steps
            ) \
            .get_dataset(settings.training.val_batch_size, output_types=output_types, output_shapes=output_shapes)

        if input_model_path is None:
            logger.info("Constructing new model")
            model = drum_model(
                out_classes=13,
                mini_beat_per_seg=settings.feature.mini_beat_per_segment,
                res_block_num=settings.training.res_block_num)

        optimizer = tf.keras.optimizers.Adam(
            learning_rate=settings.training.init_learning_rate)
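        # NOTE: loss_func is assumed to be imported or defined at module scope;
        # it is not constructed inside this method.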
        model.compile(optimizer=optimizer,
                      loss=loss_func,
                      metrics=["accuracy"])

        logger.info("Resolving model output path")
        if model_name is None:
            model_name = str(datetime.now()).replace(" ", "_")
        if not model_name.startswith(settings.model.save_prefix):
            model_name = settings.model.save_prefix + "_" + model_name
        model_save_path = jpath(settings.model.save_path, model_name)
        ensure_path_exists(model_save_path)
        write_yaml(settings.to_json(),
                   jpath(model_save_path, "configurations.yaml"))
        write_yaml(model.to_yaml(),
                   jpath(model_save_path, "arch.yaml"),
                   dump=False)
        logger.info("Model output to: %s", model_save_path)

        logger.info("Constructing callbacks")
        callbacks = [
            tf.keras.callbacks.EarlyStopping(
                patience=settings.training.early_stop, monitor="val_loss"),
            tf.keras.callbacks.ModelCheckpoint(jpath(model_save_path,
                                                     "weights.h5"),
                                               save_weights_only=True)
        ]
        logger.info("Callback list: %s", callbacks)

        logger.info("Start training")
        history = model.fit(train_dataset,
                            validation_data=val_dataset,
                            epochs=settings.training.epoch,
                            steps_per_epoch=settings.training.steps,
                            validation_steps=settings.training.val_steps,
                            callbacks=callbacks,
                            use_multiprocessing=True,
                            workers=8)
        return model_save_path, history
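
A hypothetical sketch of overriding settings before training the drum model. DrumSettings is
the type named in the docstring; the field values below are placeholders:

    my_settings = DrumSettings()
    my_settings.training.epoch = 10
    my_settings.training.res_block_num = 3
    model_path, history = app.train("./drum_feature", drum_settings=my_settings)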
Code Example #5
    def train(self,
              feature_folder,
              model_name=None,
              input_model_path=None,
              chord_settings=None):
        """Model training.

        Train a new model or continue to train on a pre-trained model.

        Parameters
        ----------
        feature_folder: Path
            Path to the generated features.
        model_name: str
            The name of the trained model. If not given, will default to the
            current timestamp.
        input_model_path: Path
            Specify the path to the pre-trained model if you want to continue
            to fine-tune on the model.
        chord_settings: ChordSettings
            The configuration instance that holds all relevant settings for
            the life-cycle of building a model.
        """
        settings = self._validate_and_get_settings(chord_settings)

        if input_model_path is not None:
            logger.info("Continue to train one model: %s", input_model_path)
            model, _ = self._load_model(input_model_path)

        split = settings.training.steps / (settings.training.steps +
                                           settings.training.val_steps)
        train_feat_files, val_feat_files = get_train_val_feat_file_list(
            feature_folder, split=split)

        output_types = (tf.float32, (tf.int32, tf.int32))
        output_shapes = ([
            settings.feature.num_steps, settings.feature.segment_width * 24
        ], ([settings.feature.num_steps], [settings.feature.num_steps]))
        train_dataset = McGillDatasetLoader(
                feature_files=train_feat_files,
                num_samples=settings.training.epoch * settings.training.batch_size * settings.training.steps
            ) \
            .get_dataset(settings.training.batch_size, output_types=output_types, output_shapes=output_shapes)
        val_dataset = McGillDatasetLoader(
                feature_files=val_feat_files,
                num_samples=settings.training.epoch * settings.training.val_batch_size * settings.training.val_steps
            ) \
            .get_dataset(settings.training.val_batch_size, output_types=output_types, output_shapes=output_shapes)

        if input_model_path is None:
            logger.info("Constructing new model")
            model = self.get_model(settings)

        learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
            settings.training.init_learning_rate,
            decay_steps=settings.training.steps,
            decay_rate=settings.training.learning_rate_decay,
            staircase=True)
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate,
                                             clipvalue=1)
        model.compile(optimizer=optimizer,
                      loss=chord_loss_func,
                      metrics=["accuracy"])

        logger.info("Resolving model output path")
        if model_name is None:
            model_name = str(datetime.now()).replace(" ", "_")
        if not model_name.startswith(settings.model.save_prefix):
            model_name = settings.model.save_prefix + "_" + model_name
        model_save_path = jpath(settings.model.save_path, model_name)
        ensure_path_exists(model_save_path)
        write_yaml(settings.to_json(),
                   jpath(model_save_path, "configurations.yaml"))
        logger.info("Model output to: %s", model_save_path)

        callbacks = [
            tf.keras.callbacks.EarlyStopping(
                patience=settings.training.early_stop, monitor="val_loss"),
            tf.keras.callbacks.ModelCheckpoint(jpath(model_save_path,
                                                     "weights"),
                                               save_weights_only=True,
                                               monitor="val_loss"),
            ReduceSlope()
        ]

        history = model.fit(train_dataset,
                            validation_data=val_dataset,
                            epochs=settings.training.epoch,
                            steps_per_epoch=settings.training.steps,
                            validation_steps=settings.training.val_steps,
                            callbacks=callbacks)
        return history
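
The chord trainer uses a staircase ExponentialDecay schedule: with staircase=True the effective
learning rate is init_lr * decay_rate ** (step // decay_steps), dropping by a factor of
decay_rate once every decay_steps. A standalone sketch with placeholder values:

    import tensorflow as tf

    schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        1e-4, decay_steps=1000, decay_rate=0.96, staircase=True)
    print(float(schedule(0)))     # 1e-04
    print(float(schedule(999)))   # still 1e-04 (staircase)
    print(float(schedule(1000)))  # 9.6e-05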
Code Example #6
File: app.py Project: ykhorzon/omnizart
    def train(self,
              feature_folder,
              model_name=None,
              input_model_path=None,
              patch_cnn_settings=None):
        """Model training.

        Train the model from scratch or continue training given a model checkpoint.

        Parameters
        ----------
        feature_folder: Path
            Path to the generated features.
        model_name: str
            The name of the trained model. If not given, will default to the
            current timestamp.
        input_model_path: Path
            Specify the path to the model checkpoint in order to fine-tune
            the model.
        patch_cnn_settings: PatchCNNSettings
            The configuration that holds all relevant settings for
            the life-cycle of model building.
        """
        settings = self._validate_and_get_settings(patch_cnn_settings)

        if input_model_path is not None:
            logger.info("Continue to train on model: %s", input_model_path)
            model, prev_set = self._load_model(
                input_model_path, custom_objects=self.custom_objects)
            settings.feature.patch_size = prev_set.feature.patch_size

        logger.info("Constructing dataset instance")
        split = settings.training.steps / (settings.training.steps +
                                           settings.training.val_steps)
        train_feat_files, val_feat_files = get_train_val_feat_file_list(
            feature_folder, split=split)

        output_types = (tf.float32, tf.float32)
        output_shapes = ((settings.feature.patch_size,
                          settings.feature.patch_size, 1), (2,))
        train_dataset = PatchCNNDatasetLoader(
                feature_files=train_feat_files,
                num_samples=settings.training.epoch * settings.training.batch_size * settings.training.steps
            ) \
            .get_dataset(settings.training.batch_size, output_types=output_types, output_shapes=output_shapes)
        val_dataset = PatchCNNDatasetLoader(
                feature_files=val_feat_files,
                num_samples=settings.training.epoch * settings.training.val_batch_size * settings.training.val_steps
            ) \
            .get_dataset(settings.training.val_batch_size, output_types=output_types, output_shapes=output_shapes)

        if input_model_path is None:
            logger.info("Constructing new model")
            model = patch_cnn_model(patch_size=settings.feature.patch_size)

        logger.info("Compiling model")
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=settings.training.init_learning_rate)
        model.compile(optimizer=optimizer,
                      loss="categorical_crossentropy",
                      metrics=["accuracy"])

        logger.info("Resolving model output path")
        if model_name is None:
            model_name = str(datetime.now()).replace(" ", "_")
        if not model_name.startswith(settings.model.save_prefix):
            model_name = settings.model.save_prefix + "_" + model_name
        model_save_path = jpath(settings.model.save_path, model_name)
        ensure_path_exists(model_save_path)
        write_yaml(settings.to_json(),
                   jpath(model_save_path, "configurations.yaml"))
        write_yaml(model.to_yaml(),
                   jpath(model_save_path, "arch.yaml"),
                   dump=False)
        logger.info("Model output to: %s", model_save_path)

        logger.info("Constrcuting callbacks")
        callbacks = [
            tf.keras.callbacks.EarlyStopping(
                patience=settings.training.early_stop),
            tf.keras.callbacks.ModelCheckpoint(jpath(model_save_path,
                                                     "weights.h5"),
                                               save_weights_only=True)
        ]
        logger.info("Callback list: %s", callbacks)

        logger.info("Start training")
        history = model.fit(train_dataset,
                            validation_data=val_dataset,
                            epochs=settings.training.epoch,
                            steps_per_epoch=settings.training.steps,
                            validation_steps=settings.training.val_steps,
                            callbacks=callbacks,
                            use_multiprocessing=True,
                            workers=8)
        return model_save_path, history
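
All of these trainers resolve the model output directory the same way. The following standalone
sketch mirrors that logic; the save_prefix and base directory are placeholders:

    from datetime import datetime
    from os.path import join as jpath

    save_prefix = "patch_cnn"  # stands in for settings.model.save_prefix
    model_name = str(datetime.now()).replace(" ", "_")
    if not model_name.startswith(save_prefix):
        model_name = save_prefix + "_" + model_name
    model_save_path = jpath("./models", model_name)
    # e.g. './models/patch_cnn_2021-05-01_12:34:56.789012'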
Code Example #7
File: app.py Project: ykhorzon/omnizart
    def train(self,
              feature_folder,
              model_name=None,
              input_model_path=None,
              vocalcontour_settings=None):
        """Model training.

        Train the model from scratch or continue training given a model checkpoint.

        Parameters
        ----------
        feature_folder: Path
            Path to the generated features.
        model_name: str
            The name of the trained model. If not given, will default to the
            current timestamp.
        input_model_path: Path
            Specify the path to the model checkpoint in order to fine-tune
            the model.
        vocalcontour_settings: VocalContourSettings
            The configuration that holds all relevant settings for
            the life-cycle of model building.
        """
        settings = self._validate_and_get_settings(vocalcontour_settings)

        if input_model_path is not None:
            logger.info("Continue to train one model: %s", input_model_path)
            model, prev_set = self._load_model(input_model_path)
            settings.training.timesteps = prev_set.training.timesteps
            settings.model.save_path = prev_set.model.save_path

        logger.info("Constructing dataset instance")
        split = settings.training.steps / (settings.training.steps +
                                           settings.training.val_steps)
        train_feat_files, val_feat_files = get_train_val_feat_file_list(
            feature_folder, split=split)

        output_types = (tf.float32, tf.float32)
        train_dataset = VocalContourDatasetLoader(
            feature_files=train_feat_files,
            num_samples=settings.training.batch_size * settings.training.steps,
            timesteps=settings.training.timesteps).get_dataset(
                settings.training.batch_size, output_types=output_types)

        val_dataset = VocalContourDatasetLoader(
            feature_files=val_feat_files,
            num_samples=settings.training.val_batch_size *
            settings.training.val_steps,
            timesteps=settings.training.timesteps).get_dataset(
                settings.training.val_batch_size, output_types=output_types)

        if input_model_path is None:
            logger.info("Constructing new model")
            # NOTE: The default dropout rate of ConvBlock differs from the one
            # used in VocalSeg, which is 0.2.
            model = semantic_segmentation(
                multi_grid_layer_n=1,
                feature_num=384,
                ch_num=1,
                timesteps=settings.training.timesteps)
        model.compile(optimizer="adam", loss=focal_loss, metrics=['accuracy'])

        logger.info("Resolving model output path")
        if model_name is None:
            model_name = str(datetime.now()).replace(" ", "_")
        if not model_name.startswith(settings.model.save_prefix):
            model_name = settings.model.save_prefix + "_" + model_name

        model_save_path = jpath(settings.model.save_path, model_name)
        ensure_path_exists(model_save_path)
        write_yaml(settings.to_json(),
                   jpath(model_save_path, "configurations.yaml"))
        write_yaml(model.to_yaml(),
                   jpath(model_save_path, "arch.yaml"),
                   dump=False)
        logger.info("Model output to: %s", model_save_path)

        logger.info("Constructing callbacks")
        callbacks = [
            EarlyStopping(patience=settings.training.early_stop),
            ModelCheckpoint(model_save_path, save_weights_only=True)
        ]
        logger.info("Callback list: %s", callbacks)

        logger.info("Start training")
        history = train_epochs(model,
                               train_dataset,
                               validate_dataset=val_dataset,
                               epochs=settings.training.epoch,
                               steps=settings.training.steps,
                               val_steps=settings.training.val_steps,
                               callbacks=callbacks)

        return model_save_path, history
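
A minimal usage sketch for the vocal-contour trainer; app and the path are placeholders. Unlike
the model.fit examples above, the returned history here comes from the custom train_epochs loop:

    model_path, history = app.train("./vc_feature", model_name="vocal_contour_demo")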