def train_model(train, dev, loss_func='categorical_crossentropy'):
    # metaparams
    metaparams = {
        "optimizer": "Adam",  # SGD, Nadam
        "learning_rate_max": 0.001,  # up to 0.0001
        "learning_rate_min": 0.00001,  # up to 0.000001
        "lr_scheduller": "reduceLRonPlateau",  # "reduceLRonPlateau"
        "annealing_period": 10,
        "epochs": 100,
        "batch_size": 4,
        "architecture": "LSTM_no_attention",
        "dataset": "NoXi",
        'type_of_labels': 'sequence_to_one',
        "num_classes": 5,
        'num_embeddings': 256,
        'num_layers': 3,
        'num_neurons': 256,
        'window_length': 40,
        'window_shift': 10,
        'window_stride': 1,
    }
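    # Assumed window semantics (not spelled out in the original code): each training
    # sample is a window of window_length=40 consecutive embedding frames sampled with
    # window_stride=1, consecutive windows start window_shift=10 frames apart
    # (roughly 30 overlapping frames), and each window receives a single label
    # ('sequence_to_one').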

    # initialization of Weights and Biases (wandb.init is omitted in this snippet)

    # Metaparams initialization
    metrics = ['accuracy']
    if metaparams['lr_scheduller'] == 'Cyclic':
        lr_scheduller = get_annealing_LRreduce_callback(
            highest_lr=metaparams['learning_rate_max'],
            lowest_lr=metaparams['learning_rate_min'],
            annealing_period=metaparams['annealing_period'])
    elif metaparams['lr_scheduller'] == 'reduceLRonPlateau':
        lr_scheduller = get_reduceLRonPlateau_callback(
            monitoring_loss='val_loss',
            reduce_factor=0.1,
            num_patient_epochs=5,
            min_lr=metaparams['learning_rate_min'])
    else:
        raise Exception("You passed wrong lr_scheduller.")

    if metaparams['optimizer'] == 'Adam':
        optimizer = tf.keras.optimizers.Adam(metaparams['learning_rate_max'])
    elif metaparams['optimizer'] == 'Nadam':
        optimizer = tf.keras.optimizers.Nadam(metaparams['learning_rate_max'])
    elif metaparams['optimizer'] == 'SGD':
        optimizer = tf.keras.optimizers.SGD(metaparams['learning_rate_max'])
    else:
        raise Exception("You passed wrong optimizer name.")

    # class weights: inverse-frequency weights computed from the one-hot label counts,
    # normalized to sum to 1 (similar in spirit to sklearn's 'balanced' weighting)
    class_weights = pd.concat(
        train.values(),
        axis=0).iloc[:, -metaparams['num_classes']:].values.sum(axis=0)
    class_weights = class_weights.sum() / (metaparams['num_classes'] +
                                           class_weights)
    class_weights = class_weights / class_weights.sum()
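    # Worked example (illustrative numbers): for per-class counts [100, 50, 25, 15, 10]
    # the expression above gives 200 / (5 + counts) = [1.90, 3.64, 6.67, 10.0, 13.33],
    # which is then normalized to sum to 1, so rarer classes get larger weights.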

    # loss function
    if loss_func == 'categorical_crossentropy':
        loss = tf.keras.losses.categorical_crossentropy
        train_class_weights = {
            i: class_weights[i]
            for i in range(metaparams['num_classes'])
        }
    elif loss_func == 'focal_loss':
        focal_loss_gamma = 2
        loss = categorical_focal_loss(alpha=class_weights,
                                      gamma=focal_loss_gamma)
        train_class_weights = None
    else:
        raise AttributeError(
            'The passed loss function name is not acceptable. Possible variants are categorical_crossentropy or focal_loss.'
        )
    # model initialization
    model = create_sequence_model(
        num_classes=metaparams['num_classes'],
        neurons_on_layer=(metaparams['num_neurons'],) * metaparams['num_layers'],
        input_shape=(metaparams['window_length'],
                     metaparams['num_embeddings']))
    # layer freezing is not applied here; just list the layers for reference
    for i, layer in enumerate(model.layers):
        print("%i:%s" % (i, layer.name))

    # model compilation
    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
    model.summary()

    # create DataLoaders (DataGenerator)
    train_data_loader = create_generator_from_pd_dictionary(
        embeddings_dict=train,
        num_classes=metaparams['num_classes'],
        type_of_labels=metaparams['type_of_labels'],
        window_length=metaparams['window_length'],
        window_shift=metaparams['window_shift'],
        window_stride=metaparams['window_stride'],
        batch_size=metaparams['batch_size'],
        shuffle=True,
        preprocessing_function=None,
        clip_values=None,
        cache_loaded_seq=None)

    # create the dev (validation) data loader; note that the windows do not overlap
    # (window_shift equals window_length)
    dev_data_loader = create_generator_from_pd_dictionary(
        embeddings_dict=dev,
        num_classes=metaparams['num_classes'],
        type_of_labels=metaparams['type_of_labels'],
        window_length=metaparams['window_length'],
        window_shift=metaparams['window_length'],
        window_stride=metaparams['window_stride'],
        batch_size=metaparams['batch_size'],
        shuffle=False,
        preprocessing_function=None,
        clip_values=None,
        cache_loaded_seq=None)

    # create Keras Callbacks for monitoring learning rate and metrics on val_set
    lr_monitor_callback = WandB_LR_log_callback()
    val_metrics = {
        'val_recall': partial(recall_score, average='macro'),
        'val_precision': partial(precision_score, average='macro'),
        'val_f1_score': partial(f1_score, average='macro')
    }
    val_metrics_callback = WandB_val_metrics_callback(
        dev_data_loader, val_metrics, metric_to_monitor='val_recall')
    early_stopping_callback = EarlyStopping(monitor='val_loss',
                                            patience=10,
                                            verbose=1)

    # train process
    print("Loss used:%s" % (loss))
    print("SEQUENCE TO ONE MODEL")
    print(metaparams['batch_size'])
    print("--------------------")
    model.fit(train_data_loader,
              epochs=metaparams['epochs'],
              class_weight=train_class_weights,
              validation_data=dev_data_loader,
              callbacks=[
                  lr_scheduller, early_stopping_callback, lr_monitor_callback,
                  val_metrics_callback
              ])
    # clear RAM
    del train_data_loader, dev_data_loader
    del model
    gc.collect()
    tf.keras.backend.clear_session()
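
# Hedged usage sketch for the sequence model above (illustrative only): `train` and
# `dev` are assumed to be dicts mapping a session id to a DataFrame whose first
# `num_embeddings` columns are frame embeddings and whose last `num_classes` columns
# are one-hot labels. The column names and toy data below are assumptions, not the
# original NoXi loading pipeline.
import numpy as np
import pandas as pd

def _toy_session(num_frames=120, num_embeddings=256, num_classes=5):
    # build one synthetic session: random embeddings plus random one-hot labels
    rng = np.random.default_rng(0)
    embeddings = rng.normal(size=(num_frames, num_embeddings))
    labels = np.eye(num_classes)[rng.integers(0, num_classes, size=num_frames)]
    columns = ["emb_%i" % i for i in range(num_embeddings)] + \
              ["class_%i" % i for i in range(num_classes)]
    return pd.DataFrame(np.hstack([embeddings, labels]), columns=columns)

toy_train = {"train_session_%i" % i: _toy_session() for i in range(3)}
toy_dev = {"dev_session_%i" % i: _toy_session() for i in range(2)}
train_model(toy_train, toy_dev, loss_func='focal_loss')
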
def train_model(train, dev, loss_func='categorical_crossentropy'):
    """ Creates and trains on the NoXi dataset the Keras Tensorflow model.
            Here, the model is MobileNetv3.
            During the training all metaparams will be logged using the Weights and Biases library.
            Also, different augmentation methods will be applied (see down to the function).
            Overall, the function is designed only for the usage with Weights and Biases library.

        :param train: pd.DataFrame
                    Pandas DataFrame with the following columns: [filename, class] or [filename, class_0, class_1, ...].
                    Train dataset.
        :param dev: pd.DataFrame
                    Pandas DataFrame with the following columns: [filename, class] or [filename, class_0, class_1, ...].
                    Development dataset.
        :param loss_func: str
                    Type of the loss function to be applied. Either "categorical_crossentropy" or "focal_loss".
        :return: None
        """
    # metaparams
    metaparams = {
        "optimizer": "Adam",  # SGD, Nadam
        "learning_rate_max": 0.001,  # up to 0.0001
        "learning_rate_min": 0.00001,  # up to 0.000001
        "lr_scheduller": "Cyclic",  # "reduceLRonPlateau"
        "annealing_period": 5,
        "epochs": 30,
        "batch_size": 256,
        "augmentation_rate": 0.1,  # 0.2, 0.3
        "architecture": "MobileNetv3_256_Dense",
        "dataset": "NoXi",
        "num_classes": 5
    }

    augmentation_methods = [
        partial(random_rotate90_image,
                probability=metaparams["augmentation_rate"]),
        partial(random_flip_vertical_image,
                probability=metaparams["augmentation_rate"]),
        partial(random_flip_horizontal_image,
                probability=metaparams["augmentation_rate"]),
        partial(random_crop_image,
                probability=metaparams["augmentation_rate"]),
        partial(random_change_brightness_image,
                probability=metaparams["augmentation_rate"],
                min_max_delta=0.35),
        partial(random_change_contrast_image,
                probability=metaparams["augmentation_rate"],
                min_factor=0.5,
                max_factor=1.5),
        partial(random_change_saturation_image,
                probability=metaparams["augmentation_rate"],
                min_factor=0.5,
                max_factor=1.5),
        partial(random_worse_quality_image,
                probability=metaparams["augmentation_rate"],
                min_factor=25,
                max_factor=99),
        partial(random_convert_to_grayscale_image,
                probability=metaparams["augmentation_rate"])
    ]
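    # Each entry above is an image -> image transform wrapped with functools.partial so
    # that it fires with probability `augmentation_rate` (0.1 here); the data generator
    # presumably applies them per loaded image when augmentation is enabled.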

    # initialization of Weights and Biases
    wandb.init(project="VGGFace2_FtF_training", config=metaparams)
    config = wandb.config

    # Metaparams initialization
    metrics = ['accuracy']
    if config.lr_scheduller == 'Cyclic':
        lr_scheduller = get_annealing_LRreduce_callback(
            highest_lr=config.learning_rate_max,
            lowest_lr=config.learning_rate_min,
            annealing_period=config.annealing_period)
    elif config.lr_scheduller == 'reduceLRonPlateau':
        lr_scheduller = get_reduceLRonPlateau_callback(
            monitoring_loss='val_loss',
            reduce_factor=0.1,
            num_patient_epochs=4,
            min_lr=config.learning_rate_min)
    else:
        raise Exception("You passed wrong lr_scheduller.")

    if config.optimizer == 'Adam':
        optimizer = tf.keras.optimizers.Adam(config.learning_rate_max)
    elif config.optimizer == 'Nadam':
        optimizer = tf.keras.optimizers.Nadam(config.learning_rate_max)
    elif config.optimizer == 'SGD':
        optimizer = tf.keras.optimizers.SGD(config.learning_rate_max)
    else:
        raise Exception("You passed wrong optimizer name.")

    # class weights
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=np.unique(np.argmax(train.iloc[:, 1:].values, axis=1)),
        y=np.argmax(train.iloc[:, 1:].values, axis=1))
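    # Illustrative check: compute_class_weight('balanced', ...) returns
    # n_samples / (n_classes * count_per_class); e.g. for per-class counts
    # [100, 50, 25, 15, 10] it yields [0.4, 0.8, 1.6, 2.67, 4.0].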

    # loss function
    if loss_func == 'categorical_crossentropy':
        loss = tf.keras.losses.categorical_crossentropy
        train_class_weights = {
            i: class_weights[i]
            for i in range(config.num_classes)
        }
    elif loss_func == 'focal_loss':
        focal_loss_gamma = 2
        loss = categorical_focal_loss(alpha=class_weights,
                                      gamma=focal_loss_gamma)
        train_class_weights = None
    else:
        raise AttributeError(
            'The passed loss function name is not acceptable. Possible variants are categorical_crossentropy or focal_loss.'
        )
    wandb.config.update({'loss': loss})
    # model initialization
    model = create_MobileNetv3_model(num_classes=config.num_classes)
    # layer freezing is not applied here (see the commented-out loop below)

    for i, layer in enumerate(model.layers):
        print("%i:%s" % (i, layer.name))

    # for i in range(75): # up to block 8
    #    model.layers[i].trainable = False

    # model compilation
    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
    model.summary()

    # create DataLoaders (DataGenerator)
    train_data_loader = get_tensorflow_generator(
        paths_and_labels=train,
        batch_size=metaparams["batch_size"],
        augmentation=True,
        augmentation_methods=augmentation_methods,
        preprocessing_function=preprocess_data_MobileNetv3,
        clip_values=None,
        cache_loaded_images=False)
    # transform labels in dev data to one-hot encodings
    dev = dev.copy(deep=True)
    dev = pd.concat([dev, pd.get_dummies(dev['class'], dtype="float32")],
                    axis=1).drop(columns=['class'])

    dev_data_loader = get_tensorflow_generator(
        paths_and_labels=dev,
        batch_size=metaparams["batch_size"],
        augmentation=False,
        augmentation_methods=None,
        preprocessing_function=preprocess_data_MobileNetv3,
        clip_values=None,
        cache_loaded_images=False)

    # create Keras Callbacks for monitoring learning rate and metrics on val_set
    lr_monitor_callback = WandB_LR_log_callback()
    val_metrics = {
        'val_recall': partial(recall_score, average='macro'),
        'val_precision': partial(precision_score, average='macro'),
        'val_f1_score': partial(f1_score, average='macro')
    }
    val_metrics_callback = WandB_val_metrics_callback(
        dev_data_loader, val_metrics, metric_to_monitor='val_recall')
    early_stopping_callback = EarlyStopping(monitor='val_loss',
                                            patience=7,
                                            verbose=1)

    # train process
    print("Loss used:%s" % (loss))
    print("MobileNetv3, LAYERS ARE NOT FROZEN")
    print(config.batch_size)
    print("--------------------")
    model.fit(train_data_loader,
              epochs=config.epochs,
              class_weight=train_class_weights,
              validation_data=dev_data_loader,
              callbacks=[
                  WandbCallback(), lr_scheduller, early_stopping_callback,
                  lr_monitor_callback, val_metrics_callback
              ])
    # clear RAM
    del train_data_loader, dev_data_loader
    del model
    gc.collect()
    tf.keras.backend.clear_session()
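
# Hedged usage sketch for the function above (illustrative only): `train` is assumed
# to carry one-hot labels in the columns after 'filename', while `dev` keeps a single
# integer 'class' column (the function one-hot encodes it itself). The file paths and
# toy frames below are placeholders, not real NoXi data.
import numpy as np
import pandas as pd

toy_train_frames = pd.DataFrame(np.eye(5, dtype="float32"),
                                columns=["class_%i" % i for i in range(5)])
toy_train_frames.insert(0, "filename", ["frames/img_%04i.png" % i for i in range(5)])
toy_dev_frames = pd.DataFrame({
    "filename": ["frames/img_%04i.png" % i for i in range(5, 10)],
    "class": [0, 1, 2, 3, 4],
})
train_model(toy_train_frames, toy_dev_frames, loss_func='categorical_crossentropy')
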
Example #3
    logger.write('Additional info:%s\n' %
                 'VGGFace2 model with 1024-512-128-4 dense layers. Engagement recognition task with 4 classes.')

    # create callbacks
    callbacks = [
        validation_with_generator_callback_multilabel(
            test_gen,
            metrics=(partial(f1_score, average='macro'), accuracy_score,
                     partial(recall_score, average='macro')),
            num_label_types=4,
            num_metric_to_set_weights=2,
            logger=logger)
    ]

    # create metrics
    metrics = [tf.keras.metrics.CategoricalAccuracy(), tf.keras.metrics.Recall()]

    # define focal loss
    losses = {
        'dense_3': categorical_focal_loss(alpha=class_weights, gamma=focal_loss_gamma),
    }
    """loss_weights={
        'dense_2': 1.0,
        'dense_4': 0.33,
        'dense_6': 0.33,
        'dense_8': 0.33
    }"""
    #losses=tf.keras.losses.categorical_crossentropy
    tf.keras.utils.plot_model(model, 'model.png')
    model = train_model(train_gen, model, optimizer, losses, epochs,
                        dev_gen, metrics, callbacks, path_to_save_results='results')
    model.save_weights(os.path.join(path_to_save_model_and_results, "model_weights.h5"))
    logger.close()

def train_model(*, path_to_save_model_and_results: str, epochs: int,
                highest_lr: float, lowest_lr: float, num_frames_in_seq: int,
                focal_loss_gamma: float, class_weights: List[float], train_gen,
                dev_gen, test_gen):
    # create output path
    if not os.path.exists(path_to_save_model_and_results):
        os.makedirs(path_to_save_model_and_results)

    # create logger
    logger_dev = open(os.path.join(path_to_save_model_and_results,
                                   'val_logs.txt'),
                      mode='w')
    logger_dev.close()
    logger_dev = open(os.path.join(path_to_save_model_and_results,
                                   'val_logs.txt'),
                      mode='a')
    # write training params and all important information:
    logger_dev.write('# Train params:\n')
    logger_dev.write('Database:%s\n' % "DAiSEE")
    logger_dev.write('Epochs:%i\n' % epochs)
    logger_dev.write('Highest_lr:%f\n' % highest_lr)
    logger_dev.write('Lowest_lr:%f\n' % lowest_lr)
    logger_dev.write('num_frames_in_seq:%i\n' % num_frames_in_seq)
    logger_dev.write('Loss:focal loss (gamma=%s)\n' % focal_loss_gamma)
    logger_dev.write('Class_weights:%s\n' % class_weights)
    logger_dev.write(
        'Additional info:%s\n' %
        'AttVGGFace2 and EMOVGGFace2 embeddings + FAU, then - 3 LSTMs with self-attention. 4 engagement classes with focal loss'
    )

    # create logger
    logger_test = open(os.path.join(path_to_save_model_and_results,
                                    'test_logs.txt'),
                       mode='w')
    logger_test.close()
    logger_test = open(os.path.join(path_to_save_model_and_results,
                                    'test_logs.txt'),
                       mode='a')
    # write training params and all important information:
    logger_test.write('# Train params:\n')
    logger_test.write('Database:%s\n' % "DAiSEE")
    logger_test.write('Epochs:%i\n' % epochs)
    logger_test.write('Highest_lr:%f\n' % highest_lr)
    logger_test.write('Lowest_lr:%f\n' % lowest_lr)
    logger_test.write('num_frames_in_seq:%i\n' % num_frames_in_seq)
    logger_test.write('Loss:focal loss (gamma=%s)\n' % focal_loss_gamma)
    logger_test.write('Class_weights:%s\n' % class_weights)
    logger_test.write(
        'Additional info:%s\n' %
        'AttVGGFace2 and EMOVGGFace2 embeddings + FAU, then - 3 LSTMs with self-attention. 4 engagement classes with focal loss'
    )

    # create callbacks
    callbacks = [
        validation_with_generator_callback_multilabel(
            dev_gen,
            metrics=(partial(f1_score, average='macro'), accuracy_score,
                     partial(recall_score, average='macro')),
            num_label_types=1,
            num_metric_to_set_weights=2,
            logger=logger_dev),
        validation_with_generator_callback_multilabel(
            test_gen,
            metrics=(partial(f1_score, average='macro'), accuracy_score,
                     partial(recall_score, average='macro')),
            num_label_types=1,
            num_metric_to_set_weights=None,
            logger=logger_test),
        get_annealing_LRreduce_callback(highest_lr, lowest_lr, epochs)
    ]
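    # The first callback evaluates on dev and (via num_metric_to_set_weights=2, i.e. the
    # macro-recall entry of `metrics`) presumably tracks the best epoch's weights; the
    # second only logs test metrics; the third anneals the learning rate from highest_lr
    # to lowest_lr over `epochs`.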

    # create metrics
    metrics = [
        tf.keras.metrics.CategoricalAccuracy(),
        tf.keras.metrics.Recall()
    ]

    # loss
    # define focal loss
    losses = {
        'dense_1':
        categorical_focal_loss(alpha=class_weights, gamma=focal_loss_gamma),
    }
    # optimizer
    optimizer = tf.keras.optimizers.Adam(highest_lr, clipnorm=1.)

    model = tf.keras.Sequential()
    model.add(
        tf.keras.layers.LSTM(
            512,
            input_shape=(num_frames_in_seq, 1571),
            return_sequences=True,
            kernel_regularizer=tf.keras.regularizers.l2(0.0001)))
    model.add(tf.keras.layers.Dropout(0.5))
    model.add(
        _Self_attention_non_local_block_without_shortcut_connection(512,
                                                                    mode='1D'))
    model.add(
        tf.keras.layers.LSTM(
            256,
            return_sequences=True,
            kernel_regularizer=tf.keras.regularizers.l2(0.0001)))
    model.add(tf.keras.layers.Dropout(0.5))
    model.add(
        _Self_attention_non_local_block_without_shortcut_connection(256,
                                                                    mode='1D'))
    model.add(
        tf.keras.layers.LSTM(
            128,
            return_sequences=False,
            kernel_regularizer=tf.keras.regularizers.l2(0.0001)))
    model.add(tf.keras.layers.Dropout(0.5))
    model.add(
        tf.keras.layers.Dense(
            128,
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(0.0001)))
    model.add(tf.keras.layers.Dense(4, activation='softmax'))
    model.compile(optimizer=optimizer, loss=losses, metrics=metrics)
    model.summary()
    model.fit(train_gen, epochs=epochs, callbacks=callbacks)
    model.save_weights(
        os.path.join(path_to_save_model_and_results, "model_weights.h5"))
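
# Hedged usage sketch (illustrative values only): the toy datasets below stand in for
# the real generators and simply yield batches of shape (batch, num_frames_in_seq, 1571)
# with one-hot engagement labels; the hyperparameter values are not the original run's,
# and the custom validation callbacks may expect a different generator interface.
import numpy as np
import tensorflow as tf

def _toy_sequence_dataset(num_samples=32, num_frames=40, num_features=1571, num_classes=4):
    # random sequences with random one-hot labels, batched
    x = np.random.rand(num_samples, num_frames, num_features).astype("float32")
    y = np.eye(num_classes, dtype="float32")[np.random.randint(0, num_classes, num_samples)]
    return tf.data.Dataset.from_tensor_slices((x, y)).batch(8)

train_model(path_to_save_model_and_results='results/lstm_self_attention',
            epochs=2,
            highest_lr=0.001,
            lowest_lr=0.00001,
            num_frames_in_seq=40,
            focal_loss_gamma=2.0,
            class_weights=[0.4, 0.8, 1.6, 2.7, 4.0],
            train_gen=_toy_sequence_dataset(),
            dev_gen=_toy_sequence_dataset(),
            test_gen=_toy_sequence_dataset())
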
Example #5
def train_model(train, dev, loss_func='focal_loss'):
    # metaparams
    metaparams = {
        "optimizer": "Adam",  # SGD, Nadam
        "learning_rate_max": 0.001,  # up to 0.0001
        "learning_rate_min": 0.00001,  # up to 0.000001
        "lr_scheduller": "Cyclic",  # "reduceLRonPlateau"
        "annealing_period": 5,
        "epochs": 30,
        "batch_size": 256,
        "augmentation_rate": 0.1,  # 0.2, 0.3
        "architecture": "MobileNetv3_256_Dense",
        "dataset": "NoXi",
        "num_classes": 5
    }

    augmentation_methods = [
        partial(random_rotate90_image,
                probability=metaparams["augmentation_rate"]),
        partial(random_flip_vertical_image,
                probability=metaparams["augmentation_rate"]),
        partial(random_flip_horizontal_image,
                probability=metaparams["augmentation_rate"]),
        partial(random_crop_image,
                probability=metaparams["augmentation_rate"]),
        partial(random_change_brightness_image,
                probability=metaparams["augmentation_rate"],
                min_max_delta=0.35),
        partial(random_change_contrast_image,
                probability=metaparams["augmentation_rate"],
                min_factor=0.5,
                max_factor=1.5),
        partial(random_change_saturation_image,
                probability=metaparams["augmentation_rate"],
                min_factor=0.5,
                max_factor=1.5),
        partial(random_worse_quality_image,
                probability=metaparams["augmentation_rate"],
                min_factor=25,
                max_factor=99),
        partial(random_convert_to_grayscale_image,
                probability=metaparams["augmentation_rate"])
    ]

    # initialization of Weights and Biases
    # wandb.init(project="VGGFace2_FtF_training", config=metaparams)
    # config = wandb.config

    # Metaparams initialization
    metrics = ['accuracy']
    if metaparams['lr_scheduller'] == 'Cyclic':
        lr_scheduller = get_annealing_LRreduce_callback(
            highest_lr=metaparams['learning_rate_max'],
            lowest_lr=metaparams['learning_rate_min'],
            annealing_period=metaparams['annealing_period'])
    elif metaparams['lr_scheduller'] == 'reduceLRonPlateau':
        lr_scheduller = get_reduceLRonPlateau_callback(
            monitoring_loss='val_loss',
            reduce_factor=0.1,
            num_patient_epochs=4,
            min_lr=metaparams['learning_rate_min'])
    else:
        raise Exception("You passed wrong lr_scheduller.")

    if metaparams['optimizer'] == 'Adam':
        optimizer = tf.keras.optimizers.Adam(metaparams['learning_rate_max'])
    elif metaparams['optimizer'] == 'Nadam':
        optimizer = tf.keras.optimizers.Nadam(metaparams['learning_rate_max'])
    elif metaparams['optimizer'] == 'SGD':
        optimizer = tf.keras.optimizers.SGD(metaparams['learning_rate_max'])
    else:
        raise Exception("You passed wrong optimizer name.")

    # class weights
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=np.unique(np.argmax(train.iloc[:, 1:].values, axis=1)),
        y=np.argmax(train.iloc[:, 1:].values, axis=1))

    # loss function
    if loss_func == 'categorical_crossentropy':
        loss = tf.keras.losses.categorical_crossentropy
        train_class_weights = {
            i: class_weights[i]
            for i in range(metaparams['num_classes'])
        }
    elif loss_func == 'focal_loss':
        focal_loss_gamma = 2
        loss = categorical_focal_loss(alpha=class_weights,
                                      gamma=focal_loss_gamma)
        train_class_weights = None
    else:
        raise AttributeError(
            'The passed loss function name is not acceptable. Possible variants are categorical_crossentropy or focal_loss.'
        )
    # wandb.config.update({'loss': loss})
    # model initialization
    model = create_MobileNetv3_model(num_classes=metaparams['num_classes'])
    # layer freezing is not applied here (see the commented-out loop below)

    for i, layer in enumerate(model.layers):
        print("%i:%s" % (i, layer.name))

    # for i in range(75): # up to block 8
    #    model.layers[i].trainable = False

    # model compilation
    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
    model.summary()

    # create DataLoaders (DataGenerator)
    train_data_loader = get_tensorflow_generator(
        paths_and_labels=train,
        batch_size=metaparams["batch_size"],
        augmentation=True,
        augmentation_methods=augmentation_methods,
        preprocessing_function=preprocess_data_MobileNetv3,
        clip_values=None,
        cache_loaded_images=False)

    dev_for_keras_validation_set = dev.copy(deep=True)
    dev_for_keras_validation_set = pd.concat(
        [dev_for_keras_validation_set,
         pd.get_dummies(dev_for_keras_validation_set['class'], dtype="float32")],
        axis=1).drop(columns=['class'])

    dev_data_loader_for_keras_validation = get_tensorflow_generator(
        paths_and_labels=dev_for_keras_validation_set,
        batch_size=metaparams["batch_size"],
        augmentation=False,
        augmentation_methods=None,
        preprocessing_function=preprocess_data_MobileNetv3,
        clip_values=None,
        cache_loaded_images=False)

    # create Keras Callbacks for monitoring learning rate and metrics on val_set
    lr_monitor_callback = WandB_LR_log_callback()
    val_metrics = {
        'val_recall': partial(recall_score, average='macro'),
        'val_precision': partial(precision_score, average='macro'),
        'val_f1_score': partial(f1_score, average='macro')
    }
    val_metrics_callback = WandB_val_metrics_callback(
        dev_data_loader_for_keras_validation,
        val_metrics,
        metric_to_monitor='val_recall')
    early_stopping_callback = EarlyStopping(monitor='val_loss',
                                            patience=7,
                                            verbose=1)

    # train process
    print("Loss used:%s" % (loss))
    print("MobileNetv3, LAYERS ARE NOT FROZEN")
    print(metaparams['batch_size'])
    print("--------------------")
    model.fit(
        train_data_loader,
        epochs=metaparams['epochs'],
        class_weight=train_class_weights,
        validation_data=dev_data_loader_for_keras_validation,
        callbacks=[  # WandbCallback(),
            lr_scheduller, early_stopping_callback, lr_monitor_callback,
            val_metrics_callback
        ])
    # clear RAM
    del train_data_loader, dev_data_loader_for_keras_validation
    del model
    gc.collect()
    tf.keras.backend.clear_session()