Ejemplo n.º 1
0
def train_image_classification_model(tp: MisoParameters):
    """Train, evaluate, plot and save an image classification model.

    The training mode is selected by ``tp.cnn.id``:

    * ids ending in ``'tl'`` (transfer learning): a head network
      (``generate_tl_head``) turns every image into a feature vector once, a
      tail network (``generate_tl_tail``) is trained on those vectors, and the
      two are then joined with ``combine_tl``. No image augmentation is used.
    * any other id: the network from ``generate(tp)`` is trained end to end on
      (optionally augmented) images.

    Afterwards the model is evaluated on the validation split (if any), the
    inference time is measured, plots / mislabelled examples / a t-SNE
    embedding are written to a timestamped directory under
    ``tp.output.save_dir``, the model is optionally frozen and saved, and the
    saved model is reloaded and compared against the in-memory predictions.

    Both TensorFlow 1.x and 2.x are handled; the major version is read from
    ``tf.__version__``.

    Args:
        tp: Training parameters. ``tp.sanitise()`` is called on it and
            ``tp.dataset.num_classes``, ``tp.augmentation.rotation`` and (when
            it is ``None``) ``tp.description`` are overwritten.

    Returns:
        ``(model, vector_model, ds, result)``: the trained Keras model, the
        model returned by ``generate_vector`` for it, the ``TrainingDataset``
        (already released via ``ds.release()``) and the ``TrainingResult``.
    """
    tf_version = int(tf.__version__.split('.')[0])

    # Hack to make RTX cards work: allocate GPU memory on demand.
    if tf_version == 2:
        physical_devices = tf.config.list_physical_devices('GPU')
        for device in physical_devices:
            tf.config.experimental.set_memory_growth(device, True)
    else:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        # NOTE(review): this session is never handed to Keras (no
        # K.set_session) and K.clear_session() follows, so allow_growth may
        # not reach the session Keras actually uses under TF1 - confirm.
        session = tf.Session(config=config)

    K.clear_session()

    # Clean the training parameters
    tp.sanitise()

    banner = (
        "+---------------------------------------------------------------------------+",
        "| MISO Particle Classification Library                                      |",
        "+---------------------------------------------------------------------------+",
        "| Stable version:                                                           |",
        "| pip install miso2                                                         |",
        "| Development version:                                                      |",
        "| pip install git+http://www.github.com/microfossil/particle-classification |",
        "+---------------------------------------------------------------------------+",
    )
    for banner_line in banner:
        print(banner_line)
    print("Tensorflow version: {}".format(tf.__version__))
    print("-" * 80)
    print("Train information:")
    print("- name: {}".format(tp.name))
    print("- description: {}".format(tp.description))
    print("- CNN type: {}".format(tp.cnn.id))
    print("- image type: {}".format(tp.cnn.img_type))
    print("- image shape: {}".format(tp.cnn.img_shape))
    print()

    # Load data
    ds = TrainingDataset(tp.dataset.source, tp.cnn.img_shape, tp.cnn.img_type,
                         tp.dataset.min_count, tp.dataset.map_others,
                         tp.dataset.val_split, tp.dataset.random_seed,
                         tp.dataset.memmap_directory)
    ds.load()
    tp.dataset.num_classes = ds.num_classes

    # Create save locations
    now = datetime.datetime.now()
    save_dir = os.path.join(tp.output.save_dir,
                            "{0}_{1:%Y%m%d-%H%M%S}".format(tp.name, now))
    os.makedirs(save_dir, exist_ok=True)

    # Validation generators are one-shot under TF2. Under TF1 they repeat,
    # because a one-shot validation set gives end-of-dataset errors there.
    val_one_shot = tf_version == 2

    # ------------------------------------------------------------------------------
    # Transfer learning
    # ------------------------------------------------------------------------------
    if tp.cnn.id.endswith('tl'):
        print('-' * 80)
        print("Transfer learning network training")
        start = time.time()

        # Generate head model
        model_head = generate_tl_head(tp.cnn.id, tp.cnn.img_shape)

        # Run every image through the head once; the tail is then trained on
        # these fixed vectors.
        print("- calculating vectors")
        t = time.time()
        gen = ds.images.create_generator(tp.training.batch_size,
                                         shuffle=False,
                                         one_shot=True)
        if tf_version == 2:
            vectors = model_head.predict(gen.create())
        else:
            vectors = predict_in_batches(model_head, gen.create())
        print("! {}s elapsed, ({}/{} vectors)".format(time.time() - t,
                                                      len(vectors),
                                                      len(ds.images.data)))

        # Generate tail model (input size = vector length) and compile
        model_tail = generate_tl_tail(tp.dataset.num_classes, [
            vectors.shape[-1],
        ])
        model_tail.compile(optimizer='adam',
                           loss='categorical_crossentropy',
                           metrics=['accuracy'])

        # Learning rate scheduler
        alr_cb = AdaptiveLearningRateScheduler(
            nb_epochs=tp.training.alr_epochs,
            nb_drops=tp.training.alr_drops,
            verbose=1)
        print('-' * 80)
        print("Training")

        # Training generator (over the pre-computed vectors)
        train_gen = TFGenerator(
            vectors,
            ds.cls_onehot,
            ds.train_idx,
            tp.training.batch_size,
            shuffle=True,
            one_shot=False,
            undersample=tp.training.use_class_undersampling)

        # Validation generator
        if tp.dataset.val_split > 0:
            val_gen = TFGenerator(vectors,
                                  ds.cls_onehot,
                                  ds.test_idx,
                                  tp.training.batch_size,
                                  shuffle=False,
                                  one_shot=val_one_shot)
            val_data = val_gen.create()
            val_steps = len(val_gen)
        else:
            val_data = None
            val_steps = None

        # Class weights (only if under sampling is not used)
        if tp.training.use_class_weights is True and tp.training.use_class_undersampling is False:
            class_weights = ds.class_weights
            print("- class weights: {}".format(class_weights))
            if tf_version == 2:
                # TF2 Keras expects a {class_index: weight} dict
                class_weights = dict(enumerate(class_weights))
        else:
            class_weights = None
        if tp.training.use_class_undersampling:
            print("- class balancing using random under sampling")

        # Train
        history = model_tail.fit_generator(train_gen.create(),
                                           steps_per_epoch=len(train_gen),
                                           validation_data=val_data,
                                           validation_steps=val_steps,
                                           epochs=tp.training.max_epochs,
                                           verbose=0,
                                           shuffle=False,
                                           max_queue_size=1,
                                           class_weight=class_weights,
                                           callbacks=[alr_cb])
        # Elapsed time
        end = time.time()
        training_time = end - start
        print("- training time: {}s".format(training_time))
        time.sleep(3)

        # Join the trained tail to the head so the final model accepts images
        model = combine_tl(model_head, model_tail)
        model.summary()

        vector_model = generate_vector(model, tp.cnn.id)

    # ------------------------------------------------------------------------------
    # Full network train
    # ------------------------------------------------------------------------------
    else:
        print('-' * 80)
        print("Full network training")
        start = time.time()

        # Generate model
        model = generate(tp)
        model.summary()

        # Augmentation: rotation=True means the full [0, 360] range
        if tp.augmentation.rotation is True:
            tp.augmentation.rotation = [0, 360]
        elif tp.augmentation.rotation is False:
            tp.augmentation.rotation = None
        if tp.training.use_augmentation is True:
            print("- using augmentation")
            augment_fn = aug_all_fn(
                rotation=tp.augmentation.rotation,
                gain=tp.augmentation.gain,
                gamma=tp.augmentation.gamma,
                zoom=tp.augmentation.zoom,
                gaussian_noise=tp.augmentation.gaussian_noise,
                bias=tp.augmentation.bias,
                random_crop=tp.augmentation.random_crop,
                divide=255)
        else:
            print("- NOT using augmentation")
            augment_fn = TFGenerator.map_fn_divide_255

        # Learning rate scheduler
        alr_cb = AdaptiveLearningRateScheduler(
            nb_epochs=tp.training.alr_epochs,
            nb_drops=tp.training.alr_drops,
            verbose=1)

        # Training generator
        train_gen = ds.train_generator(
            batch_size=tp.training.batch_size,
            map_fn=augment_fn,
            undersample=tp.training.use_class_undersampling)

        # Save example of training data
        print(" - saving example training batch")
        training_examples_dir = os.path.join(save_dir, "examples", "training")
        os.makedirs(training_examples_dir, exist_ok=True)
        images, _ = next(iter(train_gen.create()))
        for t_idx, im in enumerate(images):
            # Images are presumably in [0, 1] (both map functions divide by
            # 255). Augmentation such as noise/bias can leave that range, so
            # clip both ends: negative values would wrap around on the uint8
            # cast. np.asarray also accepts (immutable) eager tensors.
            im = np.clip(np.asarray(im) * 255, 0, 255)
            skimage.io.imsave(
                os.path.join(training_examples_dir,
                             "{:03d}.jpg".format(t_idx)), im.astype(np.uint8))

        # Validation generator
        if tp.dataset.val_split > 0:
            # Cap the validation batch size at 16; the original note says
            # larger batches made the validation results "jump around".
            val_gen = ds.test_generator(min(tp.training.batch_size, 16),
                                        shuffle=False,
                                        one_shot=val_one_shot)
            val_data = val_gen.create()
            val_steps = len(val_gen)
        else:
            val_data = None
            val_steps = None

        # Class weights (only if under sampling is not used)
        if tp.training.use_class_weights is True and tp.training.use_class_undersampling is False:
            class_weights = ds.class_weights
            print("- class weights: {}".format(class_weights))
            if tf_version == 2:
                # TF2 Keras expects a {class_index: weight} dict
                class_weights = dict(enumerate(class_weights))
        else:
            class_weights = None
        if tp.training.use_class_undersampling:
            print("- class balancing using random under sampling")

        # Train the model
        history = model.fit_generator(train_gen.create(),
                                      steps_per_epoch=len(train_gen),
                                      validation_data=val_data,
                                      validation_steps=val_steps,
                                      epochs=tp.training.max_epochs,
                                      verbose=0,
                                      shuffle=False,
                                      max_queue_size=1,
                                      class_weight=class_weights,
                                      callbacks=[alr_cb])

        # Elapsed time
        end = time.time()
        training_time = end - start
        print()
        print("Total training time: {}s".format(training_time))
        time.sleep(3)

        # Vector model
        vector_model = generate_vector(model, tp.cnn.id)

    # ------------------------------------------------------------------------------
    # Results
    # ------------------------------------------------------------------------------
    print('-' * 80)
    print("Evaluating model")
    # Accuracy
    if tp.dataset.val_split > 0:
        y_true = ds.cls[ds.test_idx]
        gen = ds.test_generator(tp.training.batch_size,
                                shuffle=False,
                                one_shot=True)
        if tf_version == 2:
            y_prob = model.predict(gen.create())
        else:
            y_prob = predict_in_batches(model, gen.create())
        y_pred = y_prob.argmax(axis=1)
    else:
        y_true = np.asarray([])
        y_prob = np.asarray([])
        y_pred = np.asarray([])
    # Inference time: per-image milliseconds over 3 passes on up to 128 images
    print("- calculating inference time:", end='')
    max_count = np.min([128, len(ds.images.data)])
    inf_times = []
    for _ in range(3):
        gen = ds.images.create_generator(tp.training.batch_size,
                                         idxs=np.arange(max_count),
                                         shuffle=False,
                                         one_shot=True)
        start = time.time()
        if tf_version == 2:
            model.predict(gen.create())
        else:
            predict_in_batches(model, gen.create())
        end = time.time()
        diff = (end - start) / max_count * 1000
        inf_times.append(diff)
        print(" {:.3f}ms".format(diff), end='')
    inference_time = np.median(inf_times)
    print(", median: {}".format(inference_time))
    # Store results
    # - make the history keys the same for TensorFlow 1 ('acc') and 2
    #   ('accuracy'). 'val_accuracy' only exists when validation data was
    #   given, so each key is renamed only if present.
    for tf1_key, tf2_key in (('acc', 'accuracy'), ('val_acc', 'val_accuracy')):
        if tf2_key in history.history:
            history.history[tf1_key] = history.history.pop(tf2_key)
    result = TrainingResult(tp, history, y_true, y_pred, y_prob, ds.cls_labels,
                            training_time, inference_time)
    print("- accuracy {:.2f}".format(result.accuracy * 100))
    print("- mean precision {:.2f}".format(result.mean_precision * 100))
    print("- mean recall {:.2f}".format(result.mean_recall * 100))
    # ------------------------------------------------------------------------------
    # Save results
    # ------------------------------------------------------------------------------
    if tp.description is None:
        tp.description = "{}: {} model trained on data from {} ({} images in {} classes).\n" \
                         "Accuracy: {:.1f} (P: {:.1f}, R: {:.1f}, F1 {:.1f})".format(
            tp.name,
            tp.cnn.id,
            tp.dataset.source,
            len(ds.filenames.filenames),
            len(ds.cls_labels),
            result.accuracy * 100,
            result.mean_precision * 100,
            result.mean_recall * 100,
            result.mean_f1_score * 100)

    # Create model info with all the parameters
    inputs = OrderedDict()
    inputs["image"] = model.inputs[0]
    outputs = OrderedDict()
    outputs["pred"] = model.outputs[0]
    outputs["vector"] = vector_model.outputs[0]
    info = ModelInfo(tp.name, tp.description, tp.cnn.id, now,
                     "frozen_model.pb", tp, inputs, outputs, tp.dataset.source,
                     ds.cls_labels, ds.filenames.cls_counts, "rescale",
                     [255, 0, 1], result.accuracy, result.precision,
                     result.recall, result.f1_score, result.support,
                     result.epochs[-1], training_time, tp.dataset.val_split,
                     inference_time)
    # ------------------------------------------------------------------------------
    # Plots
    # ------------------------------------------------------------------------------
    print("-" * 80)
    print("Plotting")
    if tp.dataset.val_split > 0:
        print("- loss")
        plot_loss_vs_epochs(history)
        plt.savefig(os.path.join(save_dir, "loss_vs_epoch.pdf"))
        print("- accuracy")
        plot_accuracy_vs_epochs(history)
        plt.savefig(os.path.join(save_dir, "accuracy_vs_epoch.pdf"))
        print("- confusion matrix")
        plot_confusion_accuracy_matrix(y_true, y_pred, ds.cls_labels)
        plt.savefig(os.path.join(save_dir, "confusion_matrix.pdf"))
        plt.close('all')

    if tp.output.save_mislabeled is True:
        print("- mislabeled")
        print("- calculating vectors... ", end='')
        gen = ds.images.create_generator(tp.training.batch_size,
                                         shuffle=False,
                                         one_shot=True)
        if tf_version == 2:
            vectors = vector_model.predict(gen.create())
        else:
            vectors = predict_in_batches(vector_model, gen.create())
        print("{} total".format(len(vectors)))
        find_and_save_mislabelled(
            ds.images.data, vectors, ds.cls, ds.cls_labels,
            [os.path.basename(f)
             for f in ds.filenames.filenames], save_dir, 11)

    # t-SNE on a random subset of at most 1024 images
    print("- t-SNE (1024 vectors max)")
    print("- calculating vectors... ", end='')
    idxs = np.random.choice(np.arange(len(ds.images.data)),
                            np.min((1024, len(ds.images.data))),
                            replace=False)
    gen = ds.images.create_generator(tp.training.batch_size,
                                     idxs=idxs,
                                     shuffle=False,
                                     one_shot=True)
    if tf_version == 2:
        vec_subset = vector_model.predict(gen.create())
    else:
        vec_subset = predict_in_batches(vector_model, gen.create())
    X = TSNE(n_components=2).fit_transform(vec_subset)
    plot_embedding(X, ds.cls[idxs], ds.num_classes)
    plt.savefig(os.path.join(save_dir, "tsne.pdf"))
    cls_info = pd.DataFrame({
        "index": range(ds.num_classes),
        "label": ds.cls_labels
    })
    cls_info.to_csv(os.path.join(save_dir, "legend.csv"), sep=';')

    # ------------------------------------------------------------------------------
    # Save model (has to be last thing it seems)
    # ------------------------------------------------------------------------------
    print('-' * 80)
    print("Saving model")
    model_dir = os.path.join(save_dir, "model")

    # Convert to inference mode (to fix TF batch normalisation issues), then
    # freeze and save the graph
    if tp.output.save_model is not None:
        if tf_version == 2:
            inference_model = convert_to_inference_mode_tf2(
                model, lambda: generate(tp))
            tf.saved_model.save(inference_model,
                                os.path.join(save_dir, "model_keras"))
            frozen_func = save_frozen_model_tf2(inference_model, model_dir,
                                                "frozen_model.pb")
            info.inputs["image"] = frozen_func.inputs[0]
            info.outputs["pred"] = frozen_func.outputs[0]
        else:
            inference_model = convert_to_inference_mode(
                model, lambda: generate(tp))
            tf.saved_model.save(inference_model,
                                os.path.join(save_dir, "model_keras"))
            freeze(inference_model, model_dir)

    # Save model info. The directory is created explicitly because the freeze
    # step above (which writes into it) is skipped when save_model is None.
    os.makedirs(model_dir, exist_ok=True)
    info_path = os.path.join(model_dir, "network_info.xml")
    info.save(info_path)

    # ------------------------------------------------------------------------------
    # Confirm model save: reload it and compare predictions on the test set
    # ------------------------------------------------------------------------------
    if tp.output.save_model is not None and tp.dataset.val_split > 0:
        print("-" * 80)
        print("Validate saved model")
        y_pred_old = y_pred
        y_true = ds.cls[ds.test_idx]
        gen = ds.test_generator(32, shuffle=False, one_shot=True)
        y_prob = []
        # Use separate names so the trained `model` returned below is not
        # replaced by the reloaded graph.
        if tf_version == 2:
            saved_model, _, _ = load_from_xml(info_path)
            for b in iter(gen.to_tfdataset()):
                y_prob.append(saved_model(b[0]).numpy())
        else:
            saved_session, input_tensor, output_tensor, _, _ = load_from_xml(
                info_path)
            iterator = iter(gen.tf1_compat_generator())
            for _ in range(len(gen)):
                b = next(iterator)
                y_prob.append(
                    saved_session.run(output_tensor,
                                      feed_dict={input_tensor: b[0]}))
        y_prob = np.concatenate(y_prob, axis=0)
        y_pred = y_prob.argmax(axis=1)
        acc = accuracy_score(y_true, y_pred)
        p, r, f1, _ = precision_recall_fscore_support(y_true, y_pred)
        print(
            "Saved model on test set: acc {:.2f}, prec {:.2f}, rec {:.2f}, f1 {:.2f}"
            .format(acc, np.mean(p), np.mean(r), np.mean(f1)))
        # Fraction of test images where the saved and in-memory models agree
        acc = accuracy_score(y_pred_old, y_pred)
        if acc == 1.0:
            print("Overlap: {:.2f}% - PASSED".format(acc * 100))
        else:
            print("Overlap: {:.2f}% - FAILED".format(acc * 100))

    # ------------------------------------------------------------------------------
    # Clean up
    # ------------------------------------------------------------------------------
    print("- cleaning up")
    ds.release()
    print("- complete")
    print('-' * 80)
    print()
    return model, vector_model, ds, result
def train_image_classification_model(params: dict,
                                     data_source: DataSource = None):
    """Train, evaluate and save an image classifier configured by a dict.

    The training mode is chosen by ``params['type']``:

    * ids ending in ``'tl'`` (transfer learning): train/test vectors are
      computed once with ``generate_tl_head`` and only the tail network from
      ``generate_tl_tail`` is trained on them.
    * any other id: the network from ``generate(params)`` is trained on
      (optionally augmented) images.

    Results (plots, optional mislabelled examples, the frozen model and
    ``network_info.xml``) are written to a timestamped directory under
    ``params['save_dir']``.

    Args:
        params: Configuration dict, modified in place (``class_weights`` and
            ``num_classes`` are set; ``img_channels`` may be overwritten).
            Some keys are read with defaults: ``batch_size`` (64),
            ``max_epochs`` (1000), ``alr_epochs`` (10), ``alr_drops`` (4),
            ``data_min_count`` (40) and ``data_split`` (0.2). Others are
            indexed directly and must be present, e.g.
            ``use_class_weights``, ``save_mislabeled``, ``save_model`` and
            ``delete_mmap_files``.
        data_source: Optional pre-loaded ``DataSource``. When ``None`` one is
            created and loaded from ``params['input_source']``.

    Returns:
        ``(model, vector_model, data_source, result)``. ``model`` is the
        trained model after ``convert_to_inference_mode``.
    """
    K.clear_session()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # NOTE(review): this session is never registered with Keras
    # (no K.set_session), so allow_growth may not apply - confirm.
    session = tf.Session(config=config)

    intro()

    # Params -----------------------------------------------------------------------------------------------------------
    name = params.get('name')
    description = params.get('description')

    # Network
    cnn_type = params.get('type')

    # Input: either img_size = [height, width, channels] or three separate keys
    img_size = params.get('img_size')
    if img_size is not None:
        img_height, img_width, img_channels = img_size
    else:
        img_height = params.get('img_height')
        img_width = params.get('img_width')
        img_channels = params.get('img_channels')

    # Training
    batch_size = params.get('batch_size', 64)
    max_epochs = params.get('max_epochs', 1000)
    alr_epochs = params.get('alr_epochs', 10)
    alr_drops = params.get('alr_drops', 4)

    # Input data
    input_dir = params.get('input_source', None)
    data_min_count = params.get('data_min_count', 40)
    data_split = params.get('data_split', 0.2)
    seed = params.get('seed', None)

    # Output
    output_dir = params.get('save_dir')

    # Image type: rgb, greyscale, greyscale3, rgbd, greyscaled, greyscaledm.
    # When not given it is inferred from the channel count; transfer-learning
    # networks get greyscale images expanded to 3 channels.
    img_type = params.get('img_type', None)
    if img_type is None:
        if img_channels == 3:
            img_type = 'rgb'
        elif img_channels == 1:
            if cnn_type.endswith('tl'):
                img_type = 'greyscale3'
                params['img_channels'] = 3
            else:
                img_type = 'greyscale'
        else:
            raise ValueError("Number of channels must be 1 or 3")
    elif img_type == 'rgbd':
        params['img_channels'] = 4
    elif img_type == 'greyscaled':
        params['img_channels'] = 2
    elif img_type == 'greyscaledm':
        params['img_channels'] = 3

    print('@ Image type: {}'.format(img_type))

    # Data -------------------------------------------------------------------------------------------------------------
    if data_source is None:
        data_source = DataSource()
        data_source.use_mmap = params['use_mmap']
        data_source.set_source(input_dir,
                               data_min_count,
                               mapping=params['class_mapping'],
                               min_count_to_others=params['data_map_others'],
                               mmap_directory=params['mmap_directory'])
        data_source.load_dataset(img_size=(img_height, img_width),
                                 img_type=img_type)
    data_source.split(data_split, seed)

    if params['use_class_weights'] is True:
        params['class_weights'] = data_source.get_class_weights()
        print("@ Class weights are {}".format(params['class_weights']))
    else:
        params['class_weights'] = None
    params['num_classes'] = data_source.num_classes

    if cnn_type.endswith('tl'):
        start = time.time()

        # Generate vectors with the head network
        model_head = generate_tl_head(params)
        print("@ Calculating train vectors")
        t = time.time()
        train_vector = model_head.predict(data_source.train_images)
        print("! {}s elapsed".format(time.time() - t))
        print("@ Calculating test vectors")
        t = time.time()
        test_vector = model_head.predict(data_source.test_images)
        print("! {}s elapsed".format(time.time() - t))
        # Clear
        K.clear_session()

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        session = tf.Session(config=config)

        data_source.train_vectors = train_vector
        data_source.test_vectors = test_vector

        # No augmentation: the vectors are pre-calculated.

        # Model --------------------------------------------------------------------------------------------------------
        print("@ Generating tail")
        model_tail = generate_tl_tail(params, [
            train_vector.shape[1],
        ])
        model_tail.compile(optimizer='adam',
                           loss='categorical_crossentropy',
                           metrics=['accuracy'])

        # Generator ----------------------------------------------------------------------------------------------------
        train_gen = tf_vector_generator(train_vector,
                                        data_source.train_onehots, batch_size)
        test_gen = tf_vector_generator(test_vector, data_source.test_onehots,
                                       batch_size)

        # Training -----------------------------------------------------------------------------------------------------
        alr_cb = AdaptiveLearningRateScheduler(nb_epochs=alr_epochs,
                                               nb_drops=alr_drops)
        print("@ Training")
        validation_data = test_gen if data_split > 0 else None
        # True ceiling division: the last partial batch counts as a step, and
        # a set smaller than batch_size still gets one step instead of zero.
        history = model_tail.fit_generator(
            train_gen,
            steps_per_epoch=math.ceil(len(train_vector) / batch_size),
            validation_data=validation_data,
            validation_steps=math.ceil(len(test_vector) / batch_size),
            epochs=max_epochs,
            verbose=0,
            shuffle=False,
            max_queue_size=1,
            class_weight=params['class_weights'],
            callbacks=[alr_cb])
        end = time.time()
        training_time = end - start
        print("@ Training time: {}s".format(training_time))
        time.sleep(3)

        # Join the trained tail to a freshly generated head (the session was
        # cleared above) so the final model accepts images as input.
        model_head = generate_tl_head(params)
        outputs = model_tail(model_head.output)
        model = Model(model_head.input, outputs)
        model.summary()

        # Vector -------------------------------------------------------------------------------------------------------
        vector_model = generate_vector(model, params)

    else:
        # Model --------------------------------------------------------------------------------------------------------
        print("@ Generating model")
        start = time.time()
        model = generate(params)

        # Augmentation -------------------------------------------------------------------------------------------------
        rotation_range = [0, 360] if params['aug_rotation'] is True else None

        def augment(x):
            """Apply the configured augmentations to image batch ``x``."""
            return augmentation_complete(
                x,
                rotation=rotation_range,
                gain=params['aug_gain'],
                gamma=params['aug_gamma'],
                zoom=params['aug_zoom'],
                gaussian_noise=params['aug_gaussian_noise'],
                bias=params['aug_bias'])

        augment_fn = augment if params['use_augmentation'] is True else None

        # Generator ----------------------------------------------------------------------------------------------------
        train_gen = tf_augmented_image_generator(data_source.train_images,
                                                 data_source.train_onehots,
                                                 batch_size, augment_fn)
        test_gen = image_generator(data_source.test_images,
                                   data_source.test_onehots, batch_size)

        # Training -----------------------------------------------------------------------------------------------------
        alr_cb = AdaptiveLearningRateScheduler(nb_epochs=alr_epochs,
                                               nb_drops=alr_drops)
        print("@ Training")
        validation_data = test_gen if data_split > 0 else None
        # True ceiling division (see the transfer-learning branch)
        history = model.fit_generator(
            train_gen,
            steps_per_epoch=math.ceil(
                len(data_source.train_images) / batch_size),
            validation_data=validation_data,
            validation_steps=math.ceil(
                len(data_source.test_images) / batch_size),
            epochs=max_epochs,
            verbose=0,
            shuffle=False,
            max_queue_size=1,
            class_weight=params['class_weights'],
            callbacks=[alr_cb])
        end = time.time()
        training_time = end - start
        print("@ Training time: {}s".format(training_time))
        time.sleep(3)

        # Vector -------------------------------------------------------------------------------------------------------
        vector_model = generate_vector(model, params)

    # Graphs -----------------------------------------------------------------------------------------------------------
    print("@ Generating results")
    if data_split > 0:
        # Calculate test set scores
        y_true = data_source.test_cls
        y_prob = model.predict(data_source.test_images)
        y_pred = y_prob.argmax(axis=1)
    else:
        y_true = np.asarray([])
        y_prob = np.asarray([])
        y_pred = np.asarray([])

    # Inference time: median per-image milliseconds over several passes on up
    # to 1000 images
    num_timing_runs = 3
    max_count = np.min([1000, len(data_source.images)])
    to_predict = np.copy(data_source.images[0:max_count])

    inf_times = []
    for run in range(num_timing_runs):
        start = time.time()
        model.predict(to_predict)
        end = time.time()
        diff = (end - start) / max_count * 1000
        inf_times.append(diff)
        print("@ Calculating inference time {}/{}: {:.3f}ms".format(
            run + 1, num_timing_runs, diff))
    inference_time = np.median(inf_times)

    # Store results
    result = TrainingResult(params, history, y_true, y_pred, y_prob,
                            data_source.cls_labels, training_time,
                            inference_time)

    # Save the results
    now = datetime.datetime.now()
    save_dir = os.path.join(output_dir,
                            "{0}_{1:%Y%m%d-%H%M%S}".format(name, now))
    os.makedirs(save_dir, exist_ok=True)

    # Plot the graphs
    if data_split > 0:
        plot_loss_vs_epochs(history)
        plt.savefig(os.path.join(save_dir, "loss_vs_epoch.pdf"))
        plot_accuracy_vs_epochs(history)
        plt.savefig(os.path.join(save_dir, "accuracy_vs_epoch.pdf"))
        plot_confusion_accuracy_matrix(y_true, y_pred, data_source.cls_labels)
        plt.savefig(os.path.join(save_dir, "confusion_matrix.pdf"))
        plt.close('all')

    if params['save_mislabeled'] is True:
        print("@ Estimating mislabeled")
        vectors = vector_model.predict(data_source.images)
        find_and_save_mislabelled(data_source.images, vectors, data_source.cls,
                                  data_source.cls_labels,
                                  data_source.get_short_filenames(), save_dir,
                                  11)

    # Save model -------------------------------------------------------------------------------------------------------
    print("@ Saving model")
    # Convert if necessary to fix TF batch normalisation issues
    model = convert_to_inference_mode(model, lambda: generate(params))
    vector_model = generate_vector(model, params)

    # Generate description
    if description is None:
        description = "{}: {} model trained on data from {} ({} images in {} classes).\n" \
                      "Accuracy: {:.1f} (P: {:.1f}, R: {:.1f}, F1 {:.1f})".format(
            name,
            cnn_type,
            input_dir,
            len(data_source.data_df),
            len(data_source.cls_labels),
            result.accuracy * 100,
            result.mean_precision() * 100,
            result.mean_recall() * 100,
            result.mean_f1_score() * 100)

    # Create model info with all the parameters
    inputs = OrderedDict()
    inputs["image"] = model.inputs[0]
    outputs = OrderedDict()
    outputs["pred"] = model.outputs[0]
    outputs["vector"] = vector_model.outputs[0]
    info = ModelInfo(name, description, cnn_type, now, "frozen_model.pb",
                     params, inputs, outputs, input_dir,
                     data_source.cls_labels, data_source.cls_counts, "rescale",
                     [255, 0, 1], result.accuracy, result.precision,
                     result.recall, result.f1_score, result.support,
                     result.epochs[-1], training_time, data_split,
                     inference_time)

    # Freeze and save graph
    model_dir = os.path.join(save_dir, "model")
    if params['save_model'] is not None:
        freeze(model, model_dir, info)

    # Save info. The directory is created explicitly because freeze() (which
    # writes into it) is skipped when save_model is None.
    os.makedirs(model_dir, exist_ok=True)
    info.save(os.path.join(model_dir, "network_info.xml"))

    print("@ Deleting temporary files")
    data_source.delete_memmap_files(del_split=True,
                                    del_source=params['delete_mmap_files'])

    wave()

    print("@ Complete")
    return model, vector_model, data_source, result
def train_image_classification_model(params: dict,
                                     data_source: DataSource = None):
    """Train an image classification network and save the model and results.

    Two training paths are used, selected by ``params['type']``:

    * types ending in ``'tl'`` (transfer learning): a head network from
      ``generate_tl_head`` converts every train/test image to a vector once,
      a tail network from ``generate_tl_tail`` is trained on those vectors,
      and head and tail are then joined into a single image-input model;
    * any other type: the network from ``generate(params)`` is trained
      directly on images, optionally with augmentation.

    Afterwards the test-set predictions, an inference-time estimate, plots
    (when a test split exists), optional mislabelled estimates, the frozen
    graph (when ``params['save_model']`` is not None) and a
    ``network_info.xml`` file are written to a timestamped sub-directory of
    ``params['save_dir']``.

    Args:
        params: training configuration dictionary (keys read include
            ``name``, ``type``, ``img_height``, ``img_width``,
            ``img_channels``, ``batch_size``, ``max_epochs``, ``alr_epochs``,
            ``alr_drops``, ``input_source``, ``data_split``, ``save_dir`` and
            the ``aug_*`` / ``use_*`` flags). It is mutated in place:
            ``img_channels`` may be set to 3, and ``class_weights`` and
            ``num_classes`` are always set.
        data_source: an existing ``DataSource``. When None, a new one is
            created and loaded from ``params['input_source']``.

    Returns:
        Tuple ``(model, vector_model, data_source, result)``: the model after
        ``convert_to_inference_mode``, the vector model generated from it,
        the data source used, and the ``TrainingResult``.
    """
    # Make both backends use the same session
    K.clear_session()

    intro()

    # Params -----------------------------------------------------------------------------------------------------------
    name = params.get('name')
    description = params.get('description')

    # Network
    cnn_type = params.get('type')

    # Input
    img_height = params.get('img_height')
    img_width = params.get('img_width')
    img_channels = params.get('img_channels')

    # Training
    batch_size = params.get('batch_size')
    max_epochs = params.get('max_epochs')
    alr_epochs = params.get('alr_epochs')
    alr_drops = params.get('alr_drops')

    # Input data
    input_dir = params.get('input_source')
    data_min_count = params.get('data_min_count')
    data_split = params.get('data_split')
    data_split_offset = params.get('data_split_offset')
    seed = params.get('seed')

    # Output
    output_dir = params.get('save_dir')

    # Data -------------------------------------------------------------------------------------------------------------
    if img_channels == 3:
        color_mode = 'rgb'
    elif cnn_type.endswith('tl'):
        # Transfer-learning networks are built with 3 input channels here, so
        # greyscale images are loaded in the 'greyscale3' mode and the
        # channel count in params is updated to match.
        color_mode = 'greyscale3'
        params['img_channels'] = 3
    else:
        color_mode = 'greyscale'
    print('Color mode: {}'.format(color_mode))

    if data_source is None:
        data_source = DataSource()
        data_source.use_mmap = params['use_mmap']
        data_source.set_source(input_dir,
                               data_min_count,
                               mapping=params['class_mapping'],
                               min_count_to_others=params['data_map_others'],
                               mmap_directory=params['mmap_directory'])
        data_source.load_dataset(img_size=(img_height, img_width),
                                 prepro_type=None,
                                 prepro_params=(255, 0, 1),
                                 img_type=color_mode,
                                 print_status=True)
    data_source.split(data_split, data_split_offset, seed)

    if params['use_class_weights'] is True:
        params['class_weights'] = data_source.get_class_weights()
        print("@Class weights are {}".format(params['class_weights']))
    else:
        params['class_weights'] = None
    params['num_classes'] = data_source.num_classes

    if cnn_type.endswith('tl'):
        start = time.time()

        # Vectors ------------------------------------------------------------------------------------------------------
        # The images are scaled internally in the network to match the expected preprocessing
        model_head = generate_tl_head(params)
        print("@Calculating train vectors")
        t = time.time()
        train_vector = model_head.predict(data_source.train_images)
        print("!{}s elapsed".format(time.time() - t))
        print("@Calculating test vectors")
        t = time.time()
        test_vector = model_head.predict(data_source.test_images)
        print("!{}s elapsed".format(time.time() - t))
        # Clear
        K.clear_session()

        data_source.train_vectors = train_vector
        data_source.test_vectors = test_vector

        # Augmentation -------------------------------------------------------------------------------------------------
        # No augmentation. In TensorFlow v1.x the batch normalisation layers still compute the batch mean and variance
        # even when set to not trainable, so the vectors produced vary with the batch. Augmentation would require
        # keeping the ResNet network (and its batch normalisation layers) in the training graph, and because of this
        # bug the training performance is poor. Vectors are therefore computed once, above.

        # Model --------------------------------------------------------------------------------------------------------
        print("@Generating tail")
        model_tail = generate_tl_tail(params, [train_vector.shape[1]])
        model_tail.compile(optimizer='adam',
                           loss='categorical_crossentropy',
                           metrics=['accuracy'])

        # Generator ----------------------------------------------------------------------------------------------------
        train_gen = tf_vector_generator(train_vector,
                                        data_source.train_onehots, batch_size)
        test_gen = tf_vector_generator(test_vector, data_source.test_onehots,
                                       batch_size)

        # Training -----------------------------------------------------------------------------------------------------
        alr_cb = AdaptiveLearningRateScheduler(nb_epochs=alr_epochs,
                                               nb_drops=alr_drops)
        print("@Training")
        validation_data = test_gen if data_split > 0 else None
        # True division so that a final partial batch counts as a step
        # (floor division inside ceil made ceil a no-op and could give 0 steps)
        history = model_tail.fit_generator(
            train_gen,
            steps_per_epoch=math.ceil(len(train_vector) / batch_size),
            validation_data=validation_data,
            validation_steps=math.ceil(len(test_vector) / batch_size),
            epochs=max_epochs,
            verbose=0,
            shuffle=False,
            max_queue_size=1,
            class_weight=params['class_weights'],
            callbacks=[alr_cb])
        end = time.time()
        training_time = end - start
        print("@Training time: {}s".format(training_time))
        time.sleep(3)

        # Join ---------------------------------------------------------------------------------------------------------
        # Join the trained tail layers to a freshly generated head so the final model accepts images as input
        model_head = generate_tl_head(params)
        outputs = model_tail(model_head.output)
        model = Model(model_head.input, outputs)
        model.summary()

        # Vector -------------------------------------------------------------------------------------------------------
        vector_model = generate_vector(model, params)

    else:
        # Model --------------------------------------------------------------------------------------------------------
        print("@Generating model")
        start = time.time()
        model = generate(params)

        # Augmentation -------------------------------------------------------------------------------------------------
        if params['aug_rotation'] is True:
            rotation_range = [0, 360]
        else:
            rotation_range = None

        def augment(x):
            return augmentation_complete(
                x,
                rotation=rotation_range,
                gain=params['aug_gain'],
                gamma=params['aug_gamma'],
                zoom=params['aug_zoom'],
                gaussian_noise=params['aug_gaussian_noise'],
                bias=params['aug_bias'])

        augment_fn = augment if params['use_augmentation'] is True else None

        # Generator ----------------------------------------------------------------------------------------------------
        train_gen = tf_augmented_image_generator(data_source.train_images,
                                                 data_source.train_onehots,
                                                 batch_size, augment_fn)
        test_gen = image_generator(data_source.test_images,
                                   data_source.test_onehots, batch_size)

        # Training -----------------------------------------------------------------------------------------------------
        alr_cb = AdaptiveLearningRateScheduler(nb_epochs=alr_epochs,
                                               nb_drops=alr_drops)
        print("@Training")
        validation_data = test_gen if data_split > 0 else None
        # True division so that a final partial batch counts as a step
        history = model.fit_generator(
            train_gen,
            steps_per_epoch=math.ceil(
                len(data_source.train_images) / batch_size),
            validation_data=validation_data,
            validation_steps=math.ceil(
                len(data_source.test_images) / batch_size),
            epochs=max_epochs,
            verbose=0,
            shuffle=False,
            max_queue_size=1,
            class_weight=params['class_weights'],
            callbacks=[alr_cb])
        end = time.time()
        training_time = end - start
        print("@Training time: {}s".format(training_time))
        time.sleep(3)

        # Vector -------------------------------------------------------------------------------------------------------
        vector_model = generate_vector(model, params)

    # Graphs -----------------------------------------------------------------------------------------------------------
    print("@Generating results")
    if data_split > 0:
        # Calculate test set scores
        y_true = data_source.test_cls
        y_prob = model.predict(data_source.test_images)
        y_pred = y_prob.argmax(axis=1)
    else:
        y_true = np.asarray([])
        y_prob = np.asarray([])
        y_pred = np.asarray([])

    # Inference time: per-image prediction time in ms over up to 1000 images,
    # taking the median over several runs
    n_inference_runs = 3
    max_count = np.min([1000, len(data_source.images)])
    to_predict = np.copy(data_source.images[0:max_count])

    inf_times = []
    for i in range(n_inference_runs):
        start = time.time()
        model.predict(to_predict)
        end = time.time()
        diff = (end - start) / max_count * 1000
        inf_times.append(diff)
        print("@Calculating inference time {}/{}: {:.3f}ms".format(
            i + 1, n_inference_runs, diff))
    inference_time = np.median(inf_times)

    # Store results
    result = TrainingResult(params, history, y_true, y_pred, y_prob,
                            data_source.cls_labels, training_time,
                            inference_time)

    # Save the results
    now = datetime.datetime.now()
    save_dir = os.path.join(output_dir,
                            "{0}_{1:%Y%m%d-%H%M%S}".format(name, now))
    os.makedirs(save_dir, exist_ok=True)

    # Plot the graphs
    if data_split > 0:
        plot_loss_vs_epochs(history)
        plt.savefig(os.path.join(save_dir, "loss_vs_epoch.pdf"))
        plot_accuracy_vs_epochs(history)
        plt.savefig(os.path.join(save_dir, "accuracy_vs_epoch.pdf"))
        plot_confusion_accuracy_matrix(y_true, y_pred, data_source.cls_labels)
        plt.savefig(os.path.join(save_dir, "confusion_matrix.pdf"))
        plt.close('all')

    if params['save_mislabeled'] is True:
        print("@Estimating mislabeled")
        vectors = vector_model.predict(data_source.images)
        find_and_save_mislabelled(data_source.images, vectors, data_source.cls,
                                  data_source.cls_labels,
                                  data_source.get_short_filenames(), save_dir,
                                  11)

    # Save model -------------------------------------------------------------------------------------------------------
    print("@Saving model")
    # Convert if necessary to fix TF batch normalisation issues
    model = convert_to_inference_mode(model, lambda: generate(params))
    vector_model = generate_vector(model, params)

    # Generate description
    if description is None:
        description = "{}: {} model trained on data from {} ({} images in {} classes).\n" \
                      "Accuracy: {:.1f} (P: {:.1f}, R: {:.1f}, F1 {:.1f})".format(
            name,
            cnn_type,
            input_dir,
            len(data_source.data_df),
            len(data_source.cls_labels),
            result.accuracy * 100,
            result.mean_precision() * 100,
            result.mean_recall() * 100,
            result.mean_f1_score() * 100)

    # Create model info with all the parameters
    inputs = OrderedDict()
    inputs["image"] = model.inputs[0]
    outputs = OrderedDict()
    outputs["pred"] = model.outputs[0]
    outputs["vector"] = vector_model.outputs[0]
    info = ModelInfo(name, description, cnn_type, now, "frozen_model.pb",
                     params, inputs, outputs, input_dir,
                     data_source.cls_labels, data_source.cls_counts, "rescale",
                     [255, 0, 1], result.accuracy, result.precision,
                     result.recall, result.f1_score, result.support,
                     result.epochs[-1], training_time, params['data_split'],
                     inference_time)

    # Freeze and save graph
    model_dir = os.path.join(save_dir, "model")
    if params['save_model'] is not None:
        freeze(model, model_dir, info)

    # Save info. The model directory must exist even when the graph was not
    # frozen above, otherwise writing network_info.xml fails.
    os.makedirs(model_dir, exist_ok=True)
    info.save(os.path.join(model_dir, "network_info.xml"))

    print("@Deleting temporary files")
    data_source.delete_memmap_files(del_split=True,
                                    del_source=params['delete_mmap_files'])

    wave()

    print("@Complete")
    return model, vector_model, data_source, result