Example #1
    def build_model(cls, n_features):
        n_labels = len(cls.label_names)

        input_shape = (None, n_features)
        model_input = Input(input_shape, name='input')
        layer = model_input
        layer = LSTM(units=128,
                     dropout=0.05,
                     recurrent_dropout=0.35,
                     return_sequences=True,
                     input_shape=input_shape,
                     name='lstm_1')(layer)
        layer = LSTM(units=32,
                     dropout=0.05,
                     recurrent_dropout=0.35,
                     return_sequences=False,
                     name='lstm_2')(layer)
        layer = Dense(units=n_labels, activation='softmax')(layer)

        model_output = layer
        model = Model(model_input, model_output)
        model.compile(loss='categorical_crossentropy',
                      optimizer=Adam(),
                      metrics=[
                          'accuracy',
                          TopKCategoricalAccuracy(3, name='top3-accuracy')
                      ])
        return model
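The excerpt above assumes the standard tf.keras imports; a minimal header that would make it self-contained (an assumption, since the original file's imports are not shown) could be:

# Imports assumed by the snippet above (tf.keras paths; not the original file's header)
from tensorflow.keras.layers import Input, LSTM, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import TopKCategoricalAccuracy

With TopKCategoricalAccuracy(3, ...), a prediction counts as correct whenever the true label is among the three highest-scoring classes.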
Example #2
    def build_model(cls, n_features):
        n_labels = len(cls.label_names)

        print('Building model...')
        input_shape = (None, n_features)
        model_input = Input(input_shape, name='input')
        layer = model_input
        for i in range(N_LAYERS):
            layer = Convolution1D(filters=CONV_FILTER_COUNT,
                                  kernel_size=FILTER_LENGTH,
                                  name='convolution_' + str(i + 1))(layer)
            layer = BatchNormalization(momentum=0.9)(layer)
            layer = Activation('relu')(layer)
            layer = MaxPooling1D(2)(layer)
            layer = Dropout(0.5)(layer)

        layer = TimeDistributed(Dense(n_labels))(layer)
        time_distributed_merge_layer = Lambda(
            function=lambda x: K.mean(x, axis=1),
            output_shape=lambda shape: (shape[0], ) + shape[2:],
            name='output_merged')
        layer = time_distributed_merge_layer(layer)
        layer = Activation('softmax', name='output_realtime')(layer)
        model_output = layer
        model = Model(model_input, model_output)
        model.compile(loss='categorical_crossentropy',
                      optimizer=Adam(learning_rate=0.001),
                      metrics=[
                          'accuracy',
                          TopKCategoricalAccuracy(3, name='top3-accuracy')
                      ])
        return model
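N_LAYERS, CONV_FILTER_COUNT and FILTER_LENGTH are module-level constants that are not part of the excerpt, and K refers to the Keras backend. Plausible placeholder values, given only for illustration, would be:

from tensorflow.keras import backend as K

# Hypothetical values for the constants used above; the original project's
# settings are not shown in the excerpt.
N_LAYERS = 3             # number of Conv1D/BatchNorm/ReLU/pooling blocks
CONV_FILTER_COUNT = 256  # filters per convolution
FILTER_LENGTH = 5        # convolution kernel size

The Lambda layer averages the TimeDistributed Dense outputs over the time axis before the final softmax, so the model emits a single label distribution per sequence rather than one per frame.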
Example #3
def perform_grid_search(opt, lr, bs, model):
    # Cross-tune every combination of batch size, optimizer and learning rate
    # with three nested for loops.
    for batch in bs:
        train_batches = prepare(training_data, batch_size=batch)
        test_batches = prepare(test_data, batch_size=batch)
        for optimizer in opt:
            for learning_rate in lr:
                model.compile(optimizer=optimizer(learning_rate=learning_rate),
                              loss='categorical_crossentropy',
                              metrics=['accuracy', TopKCategoricalAccuracy(k=3)])
                model_history = model.fit(train_batches, validation_data=test_batches,
                                          epochs=50, callbacks=[early_stopping])
                history = model_history.history
                plot_train_and_validation_results(history['loss'], history['val_loss'],
                                                  'Training and Validation Loss', range(0, 51, 10))
                plot_train_and_validation_results(history['accuracy'], history['val_accuracy'],
                                                  'Training and Validation Accuracy', range(0, 51, 10))
                plot_train_and_validation_results(history['top_k_categorical_accuracy'],
                                                  history['val_top_k_categorical_accuracy'],
                                                  'Top 3 Training and Validation Accuracy', range(0, 51, 10))
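A call to this grid search might look like the following; the optimizer classes, learning rates and batch sizes are hypothetical placeholders, not values from the original project:

from tensorflow.keras.optimizers import Adam, RMSprop

# Hypothetical grid: every (batch size, optimizer, learning rate) combination
# is compiled and trained in turn on an already-built Keras model.
perform_grid_search(opt=[Adam, RMSprop],
                    lr=[1e-3, 1e-4],
                    bs=[32, 64],
                    model=model)

Note that the same model instance is reused across iterations, so its weights are not re-initialised between configurations; each run continues from the previous one's state.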
Example #4
def run(args, use_gpu=True):

    # saving
    save_path = os.path.join(os.getcwd(), 'models')
    if not os.path.isdir(save_path):
        os.mkdir(save_path)

    model = lipnext(inputDim=256,
                    hiddenDim=512,
                    nClasses=args.nClasses,
                    frameLen=29,
                    alpha=args.alpha)
    #model = tf.keras.Sequential([
    #    tf.keras.layers.Flatten(),
    #    tf.keras.layers.Dense(args.nClasses)
    # ])

    if args.train:
        mode = "train"
    else:
        mode = "test"

#    train_list = glob.glob("./test_tfrecord_ACTUALLY_color/*.tfrecords")
#    val_list = glob.glob("./test_tfrecord_ACTUALLY_color/*.tfrecords")
#    test_list = glob.glob("./test_tfrecord_ACTUALLY_color/*.tfrecords")
#train_list = glob.glob("/mnt/disks/data/dataset/lipread_tfrecords/*/train/*.tfrecords")
#val_list = glob.glob("/mnt/disks/data/dataset/lipread_tfrecords/*/val/*.tfrecords")
#test_list = glob.glob("/mnt/disks/data/dataset/lipread_tfrecords/*/test/*.tfrecords")

    with open(args.labels) as f:
        labels = f.read().splitlines()
    train_list = []
    val_list = []
    test_list = []
    for word in labels:
        # print(word)
        train_list.extend(glob.glob(args.dataset + word +
                                    '/train/*.tfrecords'))
        val_list.extend(glob.glob(args.dataset + word + '/val/*.tfrecords'))
        test_list.extend(glob.glob(args.dataset + word + '/test/*.tfrecords'))
    # randomly shuffle *_list
    np.random.shuffle(train_list)
    np.random.shuffle(val_list)
    np.random.shuffle(test_list)

    if mode == "train":
        dataset = tf.data.TFRecordDataset(train_list)
        val_dataset = tf.data.TFRecordDataset(val_list)
    else:
        dataset = tf.data.TFRecordDataset(test_list)
    # print("raw_dataset: ", dataset)

    if mode == "train":
        dataset = dataset.map(_parse_function)
        val_dataset = val_dataset.map(_parse_function)
        #check_dataset(dataset, "train_test_images", 'RGB')

        dataset = dataset.map(_train_preprocess_function)
        val_dataset = val_dataset.map(_test_preprocess_function)

        dataset = dataset.map(
            lambda x, y: _normalize_function(x, y, args.nClasses))
        val_dataset = val_dataset.map(
            lambda x, y: _normalize_function(x, y, args.nClasses))

        dataset = dataset.batch(args.batch_size, drop_remainder=True)
        val_dataset = val_dataset.batch(args.batch_size, drop_remainder=True)
        # check_dataset(dataset)

        dataset = dataset.map(lambda x, y: (x[:, :, ::2, ::2, :], y))
        val_dataset = val_dataset.map(lambda x, y: (x[:, :, ::2, ::2, :], y))
        # check_dataset(dataset)


#        dataset = dataset.map(lambda x, y: ( tf.reshape(x, [-1, x.shape[2], x.shape[3], x.shape[4]]), y))
#        val_dataset = val_dataset.map(lambda x, y: ( tf.reshape(x, [-1, x.shape[2], x.shape[3], x.shape[4]]), y))

    else:
        dataset = dataset.map(_parse_function)
        #check_dataset(dataset, "test_test_images", 'RGB')

        dataset = dataset.map(_test_preprocess_function)

        dataset = dataset.map(
            lambda x, y: _normalize_function(x, y, args.nClasses))

        dataset = dataset.batch(1, drop_remainder=True)

        dataset = dataset.map(lambda x, y: (x[:, :, ::2, ::2, :], y))

    #check_dataset(dataset, "test_test_images_processed", 'L')
    model.compile(optimizer=Adam(learning_rate=args.lr),
                  loss=CategoricalCrossentropy(from_logits=True),
                  metrics=[
                      'accuracy',
                      TopKCategoricalAccuracy(3),
                      keras.metrics.CategoricalAccuracy()
                  ])

    run_dir = args.save_path + datetime.now().strftime("%Y%m%d-%H%M%S")
    print(run_dir, "_-----------------------------------------------")
    callbacks = [
        # Interrupt training if `val_loss` stops improving for over 2 epochs
        # tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
        # Learning rate scheduler
        tf.keras.callbacks.LearningRateScheduler(
            lambda e: lr_scheduler(e, args.lr, args.lr_sleep_epochs)),
        # Save checkpoints
        tf.keras.callbacks.ModelCheckpoint(filepath=run_dir +
                                           '/runs/{epoch}/checkpoint',
                                           save_weights_only=True,
                                           save_best_only=True,
                                           monitor='val_categorical_accuracy',
                                           mode='max'),
        # Write TensorBoard logs to `./logs` directory
        tf.keras.callbacks.TensorBoard(log_dir=run_dir + '/logs',
                                       profile_batch=0)
    ]
    file_writer = tf.summary.create_file_writer(run_dir + '/logs/metrics')
    file_writer.set_as_default()

    if args.checkpoint:
        print("Loading model from: ", args.checkpoint)
        #model.load_weights(args.checkpoint)
        status = model.load_weights(args.checkpoint).expect_partial()
        print(f'STATUS: {status.assert_existing_objects_matched()}')

        #model = tf.keras.models.load_model(args.checkpoint)
    else:
        print("Model training from scratch -")

    def rep_dataset():
        for i in range(10):
            image = next(iter(dataset))
            yield [image[0]]
        #return [dataset.__iter__().next()[0]]

    print(f' REP DATASET: {rep_dataset()}')

    if mode == "train":
        #model.evaluate(val_dataset)
        model.fit(dataset,
                  epochs=args.epochs,
                  callbacks=callbacks,
                  validation_data=val_dataset)
        #assert False
        '''model.fit(dataset, epochs=1, callbacks=callbacks, steps_per_epoch=1)
        #model.save('../saved_model')
        #tf.keras.models.save_model(model, '../saved_model')
        #model.save_weights(args.save_path +'/final_weights/Conv3D_model')
        # Convert the model.
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        # This ensures that if any ops can't be quantized, the converter throws an error
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        # These set the input and output tensors to uint8
        converter.inference_input_type = tf.float32 # uint8
        converter.inference_output_type = tf.uint8 # uint8
        # And this sets the representative dataset so we can quantize the activations
        converter.representative_dataset = rep_dataset
        print('\n\nstarting conversion\n\n')
        tflite_model = converter.convert()

        print('\n\nmodel converted\n\n')

        # Save the TF Lite model.
        with tf.io.gfile.GFile('model_quantized_float.tflite', 'wb') as f:
              f.write(tflite_model)'''
    else:
        model.evaluate(dataset)
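The LearningRateScheduler callback wraps an lr_scheduler function that is not part of the excerpt. One hypothetical implementation consistent with the arguments it receives (epoch index, base rate, warm-hold epochs) is:

# Hypothetical scheduler, not the project's actual function: hold the base
# learning rate for `sleep_epochs` epochs, then decay it exponentially.
def lr_scheduler(epoch, base_lr, sleep_epochs=5):
    if epoch < sleep_epochs:
        return base_lr
    return base_lr * (0.95 ** (epoch - sleep_epochs))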
Example #5
    cnn4.add(Flatten())

    cnn4.add(Dense(512, activation='relu'))
    cnn4.add(BatchNormalization())
    cnn4.add(Dropout(0.5))

    cnn4.add(Dense(128, activation='relu'))
    cnn4.add(BatchNormalization())
    cnn4.add(Dropout(0.5))

    cnn4.add(Dense(14, activation='softmax'))

    cnn4.compile(optimizer='adam',
                 loss='categorical_crossentropy',
                 metrics=['accuracy', TopKCategoricalAccuracy(k=3)])

    train_datagen = ImageDataGenerator(rotation_range=30.,
                                       shear_range=0.2,
                                       zoom_range=0.2,
                                       width_shift_range=0.2,
                                       height_shift_range=0.2,
                                       horizontal_flip=True,
                                       rescale=1. / 255)
    test_datagen = ImageDataGenerator(rescale=1. / 255)

    train_iterator = train_datagen.flow_from_directory(
        base_path + 'train/',
        class_mode='categorical',
        target_size=(200, 200),
        batch_size=16)
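The excerpt stops after building the training iterator. A plausible continuation (hypothetical; the validation directory name and epoch count are assumptions) would feed the augmented batches to the model:

# Hypothetical continuation of the excerpt above
validation_iterator = test_datagen.flow_from_directory(
    base_path + 'validation/',        # assumed directory layout
    class_mode='categorical',
    target_size=(200, 200),
    batch_size=16)

history = cnn4.fit(train_iterator,
                   validation_data=validation_iterator,
                   epochs=50)                            # assumed epoch count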
Example #6
    x = model_resnet.output

    x = Dense(512, activation='elu', kernel_regularizer=l2(0.001))(x)
    y = Dense(14, activation='softmax', name='img')(x)

    final_model = Model(inputs=model_resnet.input, outputs=y)

    print(final_model.summary())

    opt = SGD(learning_rate=0.0001, momentum=0.9, nesterov=True)

    final_model.compile(
        optimizer=opt,
        loss={'img': 'categorical_crossentropy'},
        metrics={'img': ['accuracy', TopKCategoricalAccuracy(k=3)]})

    train_datagen = ImageDataGenerator(rotation_range=30.,
                                       shear_range=0.2,
                                       zoom_range=0.2,
                                       width_shift_range=0.2,
                                       height_shift_range=0.2,
                                       horizontal_flip=True,
                                       rescale=1. / 255,
                                       validation_split=0.2)

    train_iterator = train_datagen.flow_from_directory(
        base_path,
        class_mode='categorical',
        target_size=(224, 224),
        subset='training')
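model_resnet is defined outside the excerpt. One common way such a base is built, and an assumption rather than the original code, is a headless ResNet50 with global average pooling so its output is a flat feature vector that feeds the Dense head directly:

from tensorflow.keras.applications import ResNet50

# Assumed construction of the backbone used above: ImageNet weights, no
# classification head, and 'avg' pooling to flatten the spatial dimensions.
model_resnet = ResNet50(weights='imagenet',
                        include_top=False,
                        pooling='avg',
                        input_shape=(224, 224, 3))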
Example #7
    #input_shape, output_shape)
    for layer in model.layers[:-2]:
        layer.trainable = False

    for layer in model.layers[-2:]:
        layer.trainable = True

    model_checkpoint = ModelCheckpoint(filepath='weights_{}.hdf5'.format(
        args.experiment),
                                       monitor='val_loss')
    early_stopping = EarlyStopping(monitor='val_loss', patience=6)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3)
    callbacks = [model_checkpoint, early_stopping, reduce_lr]

    top3_metric = TopKCategoricalAccuracy(k=3, name='top_3_cat_acc')
    precision_metric = Precision()
    recall_metric = Recall()
    metrics = ['accuracy', top3_metric, precision_metric, recall_metric]

    model.compile(optimizer=Adam(0.0001, 0.9, 0.99),
                  loss='categorical_crossentropy',
                  metrics=metrics)
    history = model.fit_generator(train_flow,
                                  steps_per_epoch=params['batches_per_epoch'],
                                  epochs=args.max_epochs,
                                  validation_data=valid_flow,
                                  workers=4,
                                  callbacks=callbacks)
    #history = model.fit_generator(train_flow, steps_per_epoch=params['batches_per_epoch'], epochs=4,
    #                              validation_data=valid_flow, workers=4, callbacks=callbacks)
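Model.fit_generator is deprecated in TensorFlow 2.x; the same training run can be written with Model.fit, which accepts Python generators and Sequence objects directly:

# Equivalent call with the non-deprecated API (TF 2.x)
history = model.fit(train_flow,
                    steps_per_epoch=params['batches_per_epoch'],
                    epochs=args.max_epochs,
                    validation_data=valid_flow,
                    workers=4,
                    callbacks=callbacks)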
Example #8
File: train.py  Project: abc-3/utils
def train_model(dataset, pretrained=False):
    """ Train Tensorflow-Keras model """

    LOGGER.info('\n---------------- Training Starting -----------------')
    LOGGER.info("Model Name: {}".format(Config['model_name']))
    LOGGER.info('\nData, Configuration, & Model Parameters')
    LOGGER.info('-----------------------------------------')
    LOGGER.info("Seed: {}".format(Config['seed']))
    LOGGER.info("Train/Test split: {}".format(Config['train_eval_split']))
    LOGGER.info("Max sequence: {}".format(Config['max_sequence_size']))
    LOGGER.info("Max labels: {}".format(Config['max_label_size']))
    LOGGER.info("Embedding dim: {}".format(Config['embedding_dim']))
    LOGGER.info("Binary threshold: {}".format(Config['binary_threshold']))
    LOGGER.info("Batch size: {}".format(Config['batch_size']))
    LOGGER.info("N epochs: {}".format(Config['n_epochs']))
    LOGGER.info("Eval @ k: {}".format(Config['evaluation@k']))
    LOGGER.info("Patience: {}".format(Config['patience']))
    LOGGER.info("Encoer type: {}".format(Config['model_encoder']))
    if Config['model_encoder'] == 'cnns':
        LOGGER.info("Num hidden units: {}".format(
            Config['cnns']['hidden_units_size']))
        LOGGER.info("Learning rate: {}".format(
            Config['cnns']['learning_rate']))
        LOGGER.info("Dropout: {}".format(Config['cnns']['spatial_dropout']))
    if Config['model_encoder'] == 'grus':
        LOGGER.info("Num hidden units: {}".format(
            Config['grus']['hidden_units_size']))
        LOGGER.info("Num hidden layers: {}".format(
            Config['grus']['n_hidden_layers']))
        LOGGER.info("Learning rate: {}".format(
            Config['grus']['learning_rate']))
        LOGGER.info("Dropout: {}".format(Config['grus']['spatial_dropout']))
    if pretrained:
        LOGGER.info('Loading Pretrained Model: {}'.format(
            Config['pretrained_model']))

    # Retrieve training / validation data
    train_X = dataset['training_dataset']['train_X']
    train_y = dataset['training_dataset']['train_y']
    dev_X = dataset['evaluation_dataset']['eval_X']
    dev_y = dataset['evaluation_dataset']['eval_y']
    # For training final model
    train_X = np.concatenate((train_X, dev_X))
    train_y = np.concatenate((train_y, dev_y))

    train_generator = BatchGenerator(train_X,
                                     train_y,
                                     batch_size=Config['batch_size'])
    val_generator = BatchGenerator(dev_X,
                                   dev_y,
                                   batch_size=Config['batch_size'])

    # Initialize the model
    if pretrained:
        LOGGER.info('\nLoading Model: {}'.format(BASE + Config['model_loc']))
        model = tf.saved_model.load(BASE + Config['model_loc'])
    else:
        word_vectors = np.array([np.array(x) for x in dataset['word_vectors']])
        # Model class from model.py
        model = LWAN(n_classes=train_y.shape[1], emb_weights=word_vectors)
    optimizer = Adam(learning_rate=Config['cnns']['learning_rate'])
    loss_fn = BinaryCrossentropy()
    #metric = BinaryAccuracy()
    metric = TopKCategoricalAccuracy()
    model.compile(optimizer, loss_fn, metric)

    # Start Training
    LOGGER.info('\n ------ Fitting model ------ ')
    ckpt_path = os.path.join(BASE + Config['model_location'],
                             Config['model_name'], Config['timestamp'],
                             'checkpoints', 'model-{epoch:02d}-{val_loss:.2f}')
    early_stopping = EarlyStopping(monitor='val_loss',
                                   patience=Config['patience'],
                                   restore_best_weights=True)
    model_checkpoint = ModelCheckpoint(filepath=ckpt_path,
                                       monitor='val_loss',
                                       mode='auto',
                                       save_freq='epoch',
                                       verbose=1,
                                       save_best_only=True,
                                       save_weights_only=False)
    start_time = time.time()  # start training
    fit_history = model.fit(train_generator,
                            validation_data=val_generator,
                            workers=os.cpu_count(),
                            epochs=Config['n_epochs'],
                            callbacks=[early_stopping, model_checkpoint])
    total_time = time.time() - start_time
    LOGGER.info('\nTotal Training Time: {} secs'.format(total_time))

    # Record final loss
    LOGGER.info('\n ----- Final Loss & Best Epoch ----- ')
    best_epoch = np.argmin(fit_history.history['val_loss']) + 1
    n_epochs = len(fit_history.history['val_loss'])
    val_loss_per_epoch = '- ' + ' '.join(
        '-' if fit_history.history['val_loss'][i] <
        np.min(fit_history.history['val_loss'][:i]) else '+'
        for i in range(1, len(fit_history.history['val_loss'])))
    LOGGER.info('Val loss per epoch: {}\n'.format(val_loss_per_epoch))
    LOGGER.info('\nBest epoch: {}/{}'.format(best_epoch, n_epochs))

    # Save model
    LOGGER.info('\n ----- Saving Final Model ------ ')
    best_val_loss = fit_history.history['val_loss'][best_epoch - 1]
    model_path = os.path.join(
        BASE + Config['model_location'], Config['model_name'],
        Config['timestamp'],
        '{}_FINAL_epoch_{:02d}_loss_{:.5f}'.format(Config['model_name'],
                                                   best_epoch, best_val_loss))
    LOGGER.info('Model Path: {}'.format(model_path))
    tf.saved_model.save(model, model_path)
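TopKCategoricalAccuracy() is instantiated in this example with its default of k=5. Since the configuration already logs an 'evaluation@k' value, the metric could plausibly be tied to it; a sketch of that wiring (an assumption about intent, not the original code):

# Hypothetical: make the metric's k follow the configured evaluation cutoff
k = Config['evaluation@k']
metric = TopKCategoricalAccuracy(k=k, name='top_{}_accuracy'.format(k))
model.compile(optimizer, loss_fn, metrics=[metric])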
Example #9

q_aware_model = tf.keras.models.clone_model(
    model,
    clone_function=apply_quantization_to_dense,
)
adam = optimizers.Adam(learning_rate=lr,
                       beta_1=0.9,
                       beta_2=0.99,
                       epsilon=None,
                       decay=1e-5,
                       amsgrad=False)
q_aware_model.compile(optimizer=adam,
                      loss='categorical_crossentropy',
                      metrics=[
                          TopKCategoricalAccuracy(k=1, name='accuracy'),
                          TopKCategoricalAccuracy(k=3, name='top3_accuracy'),
                          TopKCategoricalAccuracy(k=5, name='top5_accuracy')
                      ])

# Define callbacks
model_folder = os.path.dirname(model_path)
model_filename = os.path.basename(model_path)
output_filename = model_filename[:model_filename.index('.hdf5')] + '_quant.hdf5'
checkpoint = ModelCheckpoint(filepath=os.path.join(model_folder,
                                                   output_filename),
                             save_best_only=True,
                             monitor='val_accuracy',
                             save_weights_only=False,
                             verbose=0)
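apply_quantization_to_dense, the clone_function used above, is not shown in the excerpt. The usual pattern from the TensorFlow Model Optimization toolkit annotates Dense layers and leaves everything else untouched; a sketch of that pattern, not necessarily this project's exact definition:

import tensorflow as tf
import tensorflow_model_optimization as tfmot

def apply_quantization_to_dense(layer):
    # Annotate only Dense layers for quantization; clone_model keeps every
    # other layer unchanged.
    if isinstance(layer, tf.keras.layers.Dense):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
    return layer

In the documented workflow the annotated clone is then passed through tfmot.quantization.keras.quantize_apply before compiling, which wraps the annotated layers with quantization-aware training logic.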
Example #10
def run(args, use_gpu=True):

    # saving
    save_path = os.path.join(os.getcwd(), 'models')
    if not os.path.isdir(save_path):
        os.mkdir(save_path)

    model = lipnext(inputDim=256,
                    hiddenDim=512,
                    nClasses=args.nClasses,
                    frameLen=29,
                    alpha=args.alpha)
    #model = tf.keras.Sequential([
    #    tf.keras.layers.Flatten(),
    #    tf.keras.layers.Dense(1)
    #])

    if args.train:
        mode = "train"
    else:
        mode = "test"


#    train_list = glob.glob("./test_tfrecord_ACTUALLY_color/*.tfrecords")
#    val_list = glob.glob("./test_tfrecord_ACTUALLY_color/*.tfrecords")
#    test_list = glob.glob("./test_tfrecord_ACTUALLY_color/*.tfrecords")
#train_list = glob.glob("/mnt/disks/data/dataset/lipread_tfrecords/*/train/*.tfrecords")
#val_list = glob.glob("/mnt/disks/data/dataset/lipread_tfrecords/*/val/*.tfrecords")
#test_list = glob.glob("/mnt/disks/data/dataset/lipread_tfrecords/*/test/*.tfrecords")

    with open(args.labels) as f:
        labels = f.read().splitlines()
    train_list = []
    val_list = []
    test_list = []
    for word in labels:
        print(word)
        train_list.extend(glob.glob(args.dataset + word +
                                    '/train/*.tfrecords'))
        val_list.extend(glob.glob(args.dataset + word + '/val/*.tfrecords'))
        test_list.extend(glob.glob(args.dataset + word + '/test/*.tfrecords'))
    # randomly shuffle *_list
    np.random.shuffle(train_list)
    np.random.shuffle(val_list)
    np.random.shuffle(test_list)

    if mode == "train":
        dataset = tf.data.TFRecordDataset(train_list)
        val_dataset = tf.data.TFRecordDataset(val_list)
    else:
        dataset = tf.data.TFRecordDataset(test_list)
    print("raw_dataset: ", dataset)

    if mode == "train":
        dataset = dataset.map(_parse_function)
        val_dataset = val_dataset.map(_parse_function)
        #check_dataset(dataset, "train_test_images", 'RGB')

        dataset = dataset.map(_train_preprocess_function)
        val_dataset = val_dataset.map(_test_preprocess_function)
        #check_dataset(dataset, "train_test_images_processed", 'L')

        dataset = dataset.map(
            lambda x, y: _normalize_function(x, y, args.nClasses))
        val_dataset = val_dataset.map(
            lambda x, y: _normalize_function(x, y, args.nClasses))

        dataset = dataset.shuffle(500).batch(args.batch_size,
                                             drop_remainder=True)
        val_dataset = val_dataset.batch(args.batch_size, drop_remainder=True)

    else:
        dataset = dataset.map(_parse_function)
        #check_dataset(dataset, "test_test_images", 'RGB')

        dataset = dataset.map(_test_preprocess_function)

        dataset = dataset.map(
            lambda x, y: _normalize_function(x, y, args.nClasses))

        dataset = dataset.batch(args.batch_size)

    #check_dataset(dataset, "test_test_images_processed", 'L')
    model.compile(optimizer=Adam(learning_rate=args.lr),
                  loss=CategoricalCrossentropy(from_logits=True),
                  metrics=[
                      'accuracy',
                      TopKCategoricalAccuracy(3),
                      keras.metrics.CategoricalAccuracy()
                  ])

    run_dir = args.save_path + datetime.now().strftime("%Y%m%d-%H%M%S")
    callbacks = [
        # Interrupt training if `val_loss` stops improving for over 2 epochs
        # tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
        # Learning rate scheduler
        tf.keras.callbacks.LearningRateScheduler(
            lambda e: lr_scheduler(e, args.lr)),
        # Save checkpoints
        tf.keras.callbacks.ModelCheckpoint(
            filepath=run_dir + '/checkpoints/Conv3D_model_{epoch}',
            #save_weights_only=True,
            save_best_only=True,
            monitor='val_loss'),
        # Write TensorBoard logs to `./logs` directory
        tf.keras.callbacks.TensorBoard(log_dir=run_dir + '/logs')
    ]
    file_writer = tf.summary.create_file_writer(run_dir + '/logs/metrics')
    file_writer.set_as_default()

    if args.checkpoint:
        print("Loading model from: ", args.checkpoint)
        #model.load_weights(args.checkpoint)
        model = tf.keras.models.load_model(args.checkpoint)
    else:
        print("Model training from scratch")

    if mode == "train":
        model.fit(dataset,
                  epochs=args.epochs,
                  callbacks=callbacks,
                  validation_data=val_dataset)
        #model.save_weights(args.save_path +'/final_weights/Conv3D_model')
    else:
        model.evaluate(dataset)