def main():
    pyplot.style.use("bmh")
    base_path = Path.home() / "Data" / "PennFudanPed"

    save_model_path = PROJECT_APP_PATH.user_data / "models" / "penn_fudan_ped_seg.model"
    train_model = False
    eval_model = not train_model
    SEED = 87539842
    batch_size = 8
    num_workers = 1  # os.cpu_count()
    learning_rate = 0.01
    torch_seed(SEED)

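    # Training/validation DataLoaders over the PennFudan pedestrian segmentation dataset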
    train_set = PennFudanDataset(base_path, Split.Training)
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers)
    valid_loader = DataLoader(
        PennFudanDataset(base_path, Split.Validation),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

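    # Hourglass model with skip connections; channel counts come from the dataset's predictor/response shapes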
    model = SkipHourglassFission(
        input_channels=train_set.predictor_shape[-1],
        output_heads=(train_set.response_shape[-1], ),
        encoding_depth=1,
    )
    model.to(global_torch_device())

    if train_model:
        if save_model_path.exists():
            model.load_state_dict(torch.load(str(save_model_path)))
            print("loading saved model")

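        # Train with BCE + Dice loss, SGD and a cosine-annealed learning rate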
        with TorchTrainSession(model):
            criterion = BCEDiceLoss(eps=1.0)
            optimiser = torch.optim.SGD(model.parameters(), lr=learning_rate)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimiser, T_max=7, eta_min=learning_rate / 100, last_epoch=-1)

            model = train_person_segmenter(
                model,
                train_loader,
                valid_loader,
                criterion,
                optimiser,
                scheduler,
                save_model_path,
            )

    if eval_model:
        if save_model_path.exists():
            model.load_state_dict(torch.load(str(save_model_path)))
            print("loading saved model")

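        # Evaluate on CPU without gradients, collecting resized target masks and sigmoid probability maps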
        with TorchDeviceSession(global_torch_device(cuda_if_available=False),
                                model):
            with torch.no_grad():
                with TorchCacheSession():
                    with TorchEvalSession(model):
                        valid_masks = []
                        a = (350, 525)  # size that masks and probabilities are resized to
                        tr = min(len(valid_loader.dataset) * 4, 2000)
                        probabilities = numpy.zeros((tr, *a),
                                                    dtype=numpy.float32)
                        sample_i = 0
                        for data, target in tqdm(valid_loader):
                            data = data.to(global_torch_device())
                            target = target.cpu().detach().numpy()
                            outputs, *_ = model(data)
                            outputs = torch.sigmoid(outputs).cpu().detach().numpy()
                            for p in range(data.shape[0]):
                                output, mask = outputs[p], target[p]
                                for m in mask:
                                    valid_masks.append(cv2_resize(m, a))
                                for probability in output:
                                    probabilities[sample_i, :, :] = cv2_resize(
                                        probability, a)
                                    sample_i += 1
                                if sample_i >= tr - 1:
                                    break
                            if sample_i >= tr - 1:
                                break

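                        # Compare the first few target masks with the predicted probability maps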
                        f, ax = pyplot.subplots(2, 3, figsize=(24, 12))

                        for i in range(3):
                            ax[0, i].imshow(valid_masks[i], vmin=0, vmax=1)
                            ax[0, i].set_title("Target", fontsize=14)

                            ax[1, i].imshow(probabilities[i], vmin=0, vmax=1)
                            ax[1, i].set_title("Prediction", fontsize=14)

                        pyplot.show()
Example 2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        action="store_false",
        help="skip training and evaluate the newest saved model instead",
    )
    options = parser.parse_args()

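    # Hyper-parameters (inline comments show alternative values)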
    seed = 42
    batch_size = 8  # 12
    depth = 4  # 5
    segmentation_channels = 3
    tqdm.monitor_interval = 0
    learning_rate = 3e-3
    lr_sch_step_size = int(1000 // batch_size) + 4
    lr_sch_gamma = 0.1
    model_start_channels = 16

    home_path = Path.home() / "Models" / "Vision"
    base_path = home_path / str(time.time())
    best_model_path = "INTERRUPTED_BEST.pth"
    interrupted_path = str(base_path / best_model_path)

    writer = TensorBoardPytorchWriter(str(base_path))
    env = CameraObservationWrapper()

    torch.manual_seed(seed)
    env.seed(seed)

    device = global_torch_device()

    aeu_model = SkipHourglassFission(
        segmentation_channels,
        (segmentation_channels,),
        encoding_depth=depth,
        start_channels=model_start_channels,
    )
    aeu_model = aeu_model.to(global_torch_device())

    optimizer_ft = optim.Adam(aeu_model.parameters(), lr=learning_rate)

    exp_lr_scheduler = lr_scheduler.StepLR(
        optimizer_ft, step_size=lr_sch_step_size, gamma=lr_sch_gamma
    )

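    # Batched camera observations streamed from the Neodroid environment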
    data_iter = iter(neodroid_camera_data_iterator(env, device, batch_size))

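    # Train by default; pass -i to skip training and evaluate the newest saved model instead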
    if options.i:
        trained_aeu_model = train_model(
            aeu_model,
            data_iter,
            optimizer_ft,
            exp_lr_scheduler,
            writer,
            interrupted_path,
        )
        test_model(trained_aeu_model, data_iter)
    else:
        _list_of_files = list(home_path.glob("*"))
        latest_model_path = (
            str(max(_list_of_files, key=os.path.getctime)) + f"/{best_model_path}"
        )
        print("loading previous model: " + latest_model_path)
        test_model(aeu_model, data_iter, load_path=latest_model_path)

    torch.cuda.empty_cache()
    env.close()
    writer.close()
Example 3
def train_mnist(load_earlier: bool = False, train: bool = True, denoise: bool = True):
    """Train a skip-connected hourglass model on MNIST.

    :param load_earlier: load the newest interrupted checkpoint before continuing
    :type load_earlier: bool
    :param train: run training; if False, only run inference
    :type train: bool
    :param denoise: forwarded to the training/inference routines to enable denoising
    :type denoise: bool"""
    seed = 251645
    batch_size = 32

    tqdm.monitor_interval = 0
    learning_rate = 3e-3
    lr_sch_step_size = int(10e4 // batch_size)
    lr_sch_gamma = 0.1
    unet_depth = 3
    unet_start_channels = 16
    input_channels = 1
    output_channels = (input_channels,)

    home_path = PROJECT_APP_PATH
    model_file_ending = ".model"
    model_base_path = ensure_existence(PROJECT_APP_PATH.user_data / "unet_mnist")
    interrupted_name = "INTERRUPTED_BEST"
    interrupted_path = model_base_path / f"{interrupted_name}{model_file_ending}"

    torch.manual_seed(seed)

    device = global_torch_device()

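    # Scale MNIST images to [0, 1] and round them so the targets are binary masks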
    img_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            MinMaxNorm(),
            transforms.Lambda(lambda tensor: torch.round(tensor)),
            # transforms.RandomErasing()
        ]
    )
    dataset = MNIST(
        PROJECT_APP_PATH.user_data / "mnist", transform=img_transform, download=True
    )
    data_iter = iter(
        cycle(DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True))
    )
    data_iter = to_device_iterator(data_iter, device)

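    # Hourglass model with one output head matching the single input channel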
    model = SkipHourglassFission(
        input_channels=input_channels,
        output_heads=output_channels,
        encoding_depth=unet_depth,
        start_channels=unet_start_channels,
    ).to(global_torch_device())

    optimiser_ft = optim.Adam(model.parameters(), lr=learning_rate)

    exp_lr_scheduler = optim.lr_scheduler.StepLR(
        optimiser_ft, step_size=lr_sch_step_size, gamma=lr_sch_gamma
    )

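    # Optionally resume from the newest INTERRUPTED_BEST checkpoint under model_base_path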
    if load_earlier:
        _list_of_files = list(
            model_base_path.rglob(f"{interrupted_name}{model_file_ending}")
        )
        if not len(_list_of_files):
            print(
                f"found no trained models under {model_base_path}{os.path.sep}**{os.path.sep}{interrupted_name}{model_file_ending}"
            )
            exit(1)
        latest_model_path = str(max(_list_of_files, key=os.path.getctime))
        print(f"loading previous model: {latest_model_path}")
        if latest_model_path is not None:
            model.load_state_dict(torch.load(latest_model_path))

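    # Train (logging to TensorBoard) and save the final weights, or run inference only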
    if train:
        with TensorBoardPytorchWriter(home_path.user_log / str(time.time())) as writer:
            model = training(
                model,
                data_iter,
                optimiser_ft,
                exp_lr_scheduler,
                writer,
                interrupted_path,
                denoise=denoise,
            )
            torch.save(
                model.state_dict(),
                model_base_path / f"unet_mnist_final{model_file_ending}",
            )
    else:
        inference(model, data_iter, denoise=denoise)

    torch.cuda.empty_cache()
Example 4
def main(
    base_path: Path = Path.home() / "Data" / "Datasets" / "PennFudanPed",
    train_model: bool = True,
    load_prev_model: bool = True,
    writer: Writer = TensorBoardPytorchWriter(PROJECT_APP_PATH.user_log /
                                              "instanced_person_segmentation" /
                                              f"{time.time()}"),
):
    """ """

    # base_path = Path("/") / "encrypted_disk" / "heider" / "Data" / "PennFudanPed"
    base_path: Path = Path.home() / "Data3" / "PennFudanPed"
    # base_path = Path('/media/heider/OS/Users/Christian/Data/Datasets/')  / "PennFudanPed"
    pyplot.style.use("bmh")

    save_model_path = (
        ensure_existence(PROJECT_APP_PATH.user_data / "models") /
        "instanced_penn_fudan_ped_seg.model")

    eval_model = not train_model
    SEED = 9221
    batch_size = 32
    num_workers = 0
    encoding_depth = 2
    learning_rate = 6e-6  # sequence 6e-2 6e-3 6e-4 6e-5

    seed_stack(SEED)

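    # Use the dataset's instanced return variant (per-instance person masks)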
    train_set = PennFudanDataset(
        base_path,
        SplitEnum.training,
        return_variant=PennFudanDataset.PennFudanReturnVariantEnum.instanced,
    )

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers)
    valid_loader = DataLoader(
        PennFudanDataset(
            base_path,
            SplitEnum.validation,
            return_variant=PennFudanDataset.PennFudanReturnVariantEnum.instanced,
        ),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    model = SkipHourglassFission(
        input_channels=train_set.predictor_shape[-1],
        output_heads=(train_set.response_shape[-1], ),
        encoding_depth=encoding_depth,
    )
    model.to(global_torch_device())

    if load_prev_model and save_model_path.exists():
        model.load_state_dict(torch.load(str(save_model_path)))
        print("loading saved model")

    if train_model:
        with TorchTrainSession(model):
            criterion = BCEDiceLoss()
            # optimiser = torch.optim.SGD(model.parameters(), lr=learning_rate)
            optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)
            # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            #     optimiser, T_max=7, eta_min=learning_rate / 100, last_epoch=-1
            # )

            model = train_person_segmentor(
                model,
                train_loader,
                valid_loader,
                criterion,
                optimiser,
                save_model_path=save_model_path,
                learning_rate=learning_rate,
                writer=writer,
            )

    if eval_model:
        validate_model(model, valid_loader)