Exemplo n.º 1
0
def check_model(model_name, weights, fold):
    """
    Helper test method to do a simple sanity check of model predictions.

    Builds the model registered under ``model_name`` (with its base model
    locked), loads ``weights`` into it, and then, for every batch yielded by
    the fold's test generator, prints the ground truth next to the predicted
    score for each class and displays each input frame converted back to an
    image.

    :param model_name: name of the model (a key of ``MODELS``)
    :param weights: model weights file path (converted with ``str()``)
    :param fold: fold to check
    """
    model_info = MODELS[model_name]
    model = model_info.factory(lock_base_model=True)
    model.load_weights(str(weights), by_name=False)

    dataset = SingleFrameCNNDataset(
        preprocess_input_func=model_info.preprocess_input,
        batch_size=1,
        validation_batch_size=1,
        fold=fold)

    # NOTE(review): if generate_test() is an endless generator, this loop runs
    # until interrupted — confirm against SingleFrameCNNDataset.
    for X, y in dataset.generate_test():
        pred = model.predict_on_batch(X)
        # Walk the frames actually returned instead of a configured batch size,
        # so a short final batch cannot raise IndexError; report gt/pred for
        # every frame that is shown, not only the first one.
        for frame in range(len(X)):
            print()
            for i, cls in enumerate(CLASSES):
                print(f'gt: {y[frame, i]}  pred: {pred[frame, i]:.03f}  {cls}')
            plt.imshow(utils.preprocessed_input_to_img_resnet(X[frame]))
            plt.show()
Exemplo n.º 2
0
def check_generator(use_test):
    """
    Helper function used to visually check X and y vectors generated by SingleFrameCNNDataset.

    For each batch, prints the 1-based batch number and the time spent
    producing that batch, then prints the label vector and shows the image of
    every frame in the batch.

    :param use_test: true if validation samples are checked, false for training samples.
    """
    dataset = SingleFrameCNNDataset(
        preprocess_input_func=preprocess_input_resnet50,
        batch_size=2,
        validation_batch_size=2,
        fold=1)

    gen = dataset.generate_test() if use_test else dataset.generate()

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # so it is the right clock for measuring intervals.
    start_time = time.perf_counter()
    for batch_id, (X, y) in enumerate(gen, start=1):
        elapsed_time = time.perf_counter() - start_time
        print(f'{batch_id} {elapsed_time:.3}')
        # Use the real number of frames: the test generator is sized by
        # validation_batch_size, not batch_size.
        for frame in range(len(X)):
            print(y[frame])
            plt.imshow(utils.preprocessed_input_to_img_resnet(X[frame]))
            plt.show()
        # Restart the timer only after the plots are closed (plt.show() may
        # block), so the next reading measures batch generation alone rather
        # than including the time spent viewing images.
        start_time = time.perf_counter()