def load_model(data_gen: AudioGenerator, model_builder: ModelBuilder):
    """Rebuild a model from ``model_builder`` and restore its saved weights.

    The architecture is created with ``input_shape=(None, data_gen.input_dim)``
    (``None`` leading dimension) and ``output_dim=29``. Weights are read from
    ``results/Spec <model name>.h5`` when ``data_gen.spectrogram`` is truthy,
    otherwise from ``results/MFCC <model name>.h5``.

    Returns:
        The rebuilt model with its weights loaded (no CTC loss is attached here).
    """
    restored = model_builder.model(input_shape=(None, data_gen.input_dim), output_dim=29)
    feature_prefix = "Spec " if data_gen.spectrogram else "MFCC "
    weights_file = 'results/' + feature_prefix + restored.name + '.h5'
    restored.load_weights(weights_file)
    return restored
def train_model(audio_gen: AudioGenerator,
                model_builder: ModelBuilder,
                optimizer=None,
                epochs=30,
                verbose=0,
                loss_limit=400):
    """Build, CTC-wrap, compile and train a model on batches from ``audio_gen``.

    The network returned by ``model_builder.model`` is renamed to
    ``"Spec <name>"`` or ``"MFCC <name>"`` (depending on
    ``audio_gen.spectrogram``). The best checkpoint (``save_best_only=True``)
    is written to ``results/<new name>.h5``.

    Args:
        audio_gen: batch source. Read here: ``input_dim``, ``max_length``,
            ``spectrogram``, ``minibatch_size``, ``train_audio_paths``,
            ``valid_audio_paths``, ``next_train()`` and ``next_valid()``.
        model_builder: factory whose ``model(input_shape=..., output_dim=...)``
            returns the network to train.
        optimizer: optimizer passed to ``compile``. ``None`` (the default)
            creates a fresh ``SGD(lr=0.02, decay=1e-6, momentum=0.9,
            nesterov=True, clipnorm=5)`` for this call.
        epochs: number of training epochs.
        verbose: forwarded to ``fit_generator``. When ``verbose > 0`` the
            callbacks are the checkpoint plus ``TerminateOnNaN``; otherwise
            the checkpoint plus a ``MetricsLogger``.
        loss_limit: forwarded to ``MetricsLogger`` (only used when
            ``verbose <= 0``).

    Returns:
        The prefixed model name, which is also the checkpoint file stem.

    Training errors are not propagated. On ``KeyboardInterrupt``,
    ``display.clear_output(wait=True)`` is called. Any other exception has its
    traceback printed. In every case the summary of the network without the
    CTC wrapper is printed before returning.
    """
    if optimizer is None:
        # Build a fresh optimizer per call. A default-argument instance would be
        # created once at definition time, and its internal state would be
        # shared by every model trained with the default.
        optimizer = SGD(lr=0.02, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)

    # When max_length is None this is (None, input_dim), so no separate branch
    # is needed for that case.
    model = model_builder.model(input_shape=(audio_gen.max_length, audio_gen.input_dim),
                                output_dim=29)
    model_name = ("Spec" if audio_gen.spectrogram else "MFCC") + " " + model.name
    model.name = model_name
    save_model_path = 'results/' + model_name + '.h5'

    # Only whole minibatches are counted per epoch.
    # NOTE(review): these become 0 when a split has fewer samples than
    # minibatch_size; confirm fit_generator handles that for your Keras version.
    steps_per_epoch = len(audio_gen.train_audio_paths) // audio_gen.minibatch_size
    validation_steps = len(audio_gen.valid_audio_paths) // audio_gen.minibatch_size

    # Keep the bare network for the summary; add_ctc_loss returns the model
    # that is actually compiled and trained.
    pre_model = model
    model = add_ctc_loss(model)

    # The CTC loss is produced by the model's 'ctc' output itself, so the Keras
    # loss function just passes the prediction through.
    model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=optimizer)

    os.makedirs('results', exist_ok=True)

    model_checkpoint = ModelCheckpoint(filepath=save_model_path, verbose=0, save_best_only=True)
    if verbose > 0:
        callbacks = [model_checkpoint, TerminateOnNaN()]
    else:
        callbacks = [model_checkpoint,
                     MetricsLogger(model_name=model_name, n_epochs=epochs, loss_limit=loss_limit)]

    try:
        model.fit_generator(generator=audio_gen.next_train(), steps_per_epoch=steps_per_epoch,
                            epochs=epochs, validation_data=audio_gen.next_valid(),
                            validation_steps=validation_steps,
                            callbacks=callbacks, verbose=verbose)
    except KeyboardInterrupt:
        # Manual stop: clear the (notebook) output and fall through to the summary.
        display.clear_output(wait=True)
    except Exception:
        # Best-effort: report the original traceback but still return the name.
        traceback.print_exc()
    finally:
        pre_model.summary()
    return model_name