Example #1
0
        print_single_error(epoch_idx, loss_fn.show(), metric_fn.show())
        lr_scheduler.step()

        if config.TEST:
            test(model, config_test)
        if config.SAVE_MODEL:
            save_checkpoint(epoch_idx, model, output_dir)

    if not config.TEST:
        test(model, config_test)
    if not config.SAVE_MODEL:
        save_checkpoint(epoch_idx, model, output_dir)


if __name__ == "__main__":
    # Script entry point: take the train and test config paths from the
    # command line, parse each YAML file, then hand both configs to train().
    parser = argparse.ArgumentParser(description='train model')
    parser.add_argument(
        '--train', type=str, default="train.yaml", help='train config file',
    )
    parser.add_argument(
        '--test', type=str, default="test.yaml", help='test config file',
    )
    opt = parser.parse_args()

    # The train config is read and parsed before the test config is touched.
    raw_train = read_yaml_config(opt.train)
    config_train = parse_train_config(raw_train)
    raw_test = read_yaml_config(opt.test)
    config_test = parse_test_config(raw_test)

    train(config_train, config_test)
Example #2
0
        A.Normalize(),
        M.MyToTensorV2(),
    ],
                          additional_targets={
                              'right_img': 'image',
                          })

    dataset = LoadImages(config.JSON, transform=transform)

    if not model:
        model = Model()
        model = model.to(DEVICE)
        _, model = load_checkpoint(model, config.CHECKPOINT_FILE, DEVICE)

    model.eval()
    for img, predictions, path in generatePredictions(model, dataset):
        plot_predictions([img], predictions, [path])
        save_predictions([img], predictions, [path])


if __name__ == "__main__":
    # Script entry point: take the detect config path from the command line,
    # parse the YAML file, then run detect() with the resulting config.
    parser = argparse.ArgumentParser(description='run inference on model')
    parser.add_argument(
        '--detect', type=str, default="detect.yaml", help='detect config file',
    )
    opt = parser.parse_args()

    raw_detect = read_yaml_config(opt.detect)
    config_detect = parse_detect_config(raw_detect)

    # No model is passed here; only the config keyword is supplied.
    detect(config=config_detect)
                                      batch_size=config.BATCH_SIZE,
                                      transform=transform,
                                      workers=config.WORKERS,
                                      pin_memory=config.PIN_MEMORY,
                                      shuffle=config.SHUFFLE)

    if not model:
        model = Model()
        model = model.to(DEVICE)
        epoch, model = load_checkpoint(model, config.CHECKPOINT_FILE, DEVICE)

    loss_fn = LossFunction()
    metric_fn = MetricFunction(config.BATCH_SIZE)

    model.eval()
    run_test(model, dataloader, loss_fn, metric_fn)
    print_single_error(epoch, loss_fn.show(), metric_fn.show())


if __name__ == "__main__":
    # Script entry point: take the test config path from the command line,
    # parse the YAML file, then run test() with the resulting config.
    parser = argparse.ArgumentParser(description='test model')
    parser.add_argument(
        '--test', type=str, default="test.yaml", help='test config file',
    )
    opt = parser.parse_args()

    raw_test = read_yaml_config(opt.test)
    config_test = parse_test_config(raw_test)

    # No model is passed here; only the config keyword is supplied.
    test(config=config_test)