"targets_path": which_targets_path,
        "fold": fold,
        "heatmap_size": heatmap_size,
        "heatmap_smoothing_sigma": heatmap_smoothing_sigma,
        "phase": "test",
        "batch_size": BATCH_SIZE,
        "shuffle": False,
        "num_workers": NUM_WORKERS,
        "multi_target": True,
    }

    test_loader = get_trajectory_tensor_dataset(**test_dataset_args)

    # ########## SET UP MODEL ########## #
    encoder = CNN_2D_1D_Encoder(**encoder_args).to(DEVICE)
    decoder = FullyConnectedTrajectoryTensorClassifier(**decoder_args).to(DEVICE)
    encoder.load_state_dict(torch.load(os.path.join(MODEL_LOAD_PATH, "encoder_fold_" + str(fold) + ".weights")))
    decoder.load_state_dict(torch.load(os.path.join(MODEL_LOAD_PATH, "decoder_fold_" + str(fold) + ".weights")))

    params = list(encoder.parameters()) + list(decoder.parameters())

    loss_function = nn.BCELoss()

    # ########## TRAIN AND EVALUATE ########## #
    best_ap = 0

    test_args = {
        "encoder": encoder,
        "decoder": decoder,
        "device": DEVICE,
        "test_loader": test_loader,
Beispiel #2
0
            "fold": fold,
            "heatmap_size": heatmap_size,
            "heatmap_smoothing_sigma": heatmap_smoothing_sigma,
            "phase": "test",
            "batch_size": BATCH_SIZE,
            "shuffle": False,
            "num_workers": NUM_WORKERS,
        }

        train_loader = get_trajectory_tensor_dataset(**train_dataset_args)
        val_loader = get_trajectory_tensor_dataset(**val_dataset_args)
        test_loader = get_trajectory_tensor_dataset(**test_dataset_args)

        # ########## SET UP MODEL ########## #
        encoder = CNN_3D_Encoder(**encoder_args).to(DEVICE)
        decoder = FullyConnectedTrajectoryTensorClassifier(
            **decoder_args).to(DEVICE)
        params = list(encoder.parameters()) + list(decoder.parameters())

        optimizer = optim.Adam(params,
                               lr=LEARNING_RATE,
                               weight_decay=WEIGHT_DECAY)
        loss_function = nn.BCELoss()

        # ########## TRAIN AND EVALUATE ########## #
        best_ap = 0

        for epoch in range(NUM_EPOCHS):
            print("----------- EPOCH " + str(epoch) + " -----------")

            trainer_args = {
                "encoder": encoder,