Exemplo n.º 1
0
    # ########## SET UP DATASET ########## #
    test_dataset_args = {
        "inputs_path": INPUTS_PATH,
        "departure_cameras_path": DEPARTURE_CAMERA_PATH,
        "targets_path": TARGETS_PATH,
        "fold": fold,
        "phase": "test",
        "batch_size": BATCH_SIZE,
        "shuffle": False,
        "num_workers": NUM_WORKERS,
        "flatten_inputs": False,
        "multi_target": True,
    }

    test_loader = get_coordinate_trajectory_dataset(**test_dataset_args)

    # ########## SET UP MODEL ########## #
    encoder = RecurrentEncoder(**encoder_args).to(DEVICE)
    decoder = RecurrentDecoder(**decoder_args).to(DEVICE)
    encoder.load_state_dict(
        torch.load(
            os.path.join(MODEL_LOAD_PATH,
                         "encoder_fold_" + str(fold) + ".weights")))
    decoder.load_state_dict(
        torch.load(
            os.path.join(MODEL_LOAD_PATH,
                         "decoder_fold_" + str(fold) + ".weights")))

    loss_function = nn.BCELoss()
Exemplo n.º 2
0
    }

    test_dataset_args = {
        "inputs_path": CROSS_VALIDATION_COORDINATE_TRAJECTORIES_PATH,
        "departure_cameras_path": CROSS_VALIDATION_DEPARTURE_CAMERAS_PATH,
        "targets_path": CROSS_VALIDATION_WHEN_TARGETS_PATH,
        "fold": fold,
        "phase": "test",
        "batch_size": BATCH_SIZE,
        "shuffle": False,
        "num_workers": NUM_WORKERS,
        "flatten_inputs": False,
        "flatten_targets": False,
    }

    train_loader = get_coordinate_trajectory_dataset(**train_dataset_args)
    val_loader = get_coordinate_trajectory_dataset(**val_dataset_args)
    test_loader = get_coordinate_trajectory_dataset(**test_dataset_args)

    # ########## SET UP MODEL ########## #
    encoder = RecurrentEncoder(**encoder_args).to(DEVICE)
    decoder = RecurrentDecoder(**decoder_args).to(DEVICE)
    params = list(encoder.parameters()) + list(decoder.parameters())

    optimizer = optim.Adam(params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    loss_function = nn.BCELoss()

    # ########## TRAIN AND EVALUATE ########## #
    best_ap = 0

    for epoch in range(NUM_EPOCHS):