def main():
    """Train a SqueezeNet classifier on ``MixedObservationWrapper`` observations.

    Command-line flags:

    * ``--inference`` / ``-i``: skip training and only run the prediction preview.
    * ``--continue_training`` / ``-c``: before training, load the state dict from
      the most recently created ``*.model`` file under
      ``PROJECT_APP_PATH.user_data`` (skipped with a message if none exists).
    * ``--no_cuda`` / ``-k``: set the global torch device to CPU.
    * ``--export`` / ``-e``: trace the final model with TorchScript and save it
      to ``resnet18_v.model`` in the current working directory.

    Each run creates a ``str(time.time())``-named directory under both
    ``PROJECT_APP_PATH.user_data`` and ``PROJECT_APP_PATH.user_log``. The path
    ``best_validation_model.model`` inside the former is passed to the training
    routine as ``interrupted_path``; the latter is handed to the TensorBoard
    writer.
    """
    args = argparse.ArgumentParser()
    args.add_argument("--inference", "-i", action="store_true")
    args.add_argument("--continue_training", "-c", action="store_true")
    # store_true: ``options.no_cuda`` is True exactly when the user asked for CPU.
    args.add_argument("--no_cuda", "-k", action="store_true")
    args.add_argument("--export", "-e", action="store_true")
    options = args.parse_args()

    timeas = str(time.time())
    this_model_path = PROJECT_APP_PATH.user_data / timeas
    this_log = PROJECT_APP_PATH.user_log / timeas
    ensure_directory_exist(this_model_path)
    ensure_directory_exist(this_log)

    best_model_name = "best_validation_model.model"
    interrupted_path = str(this_model_path / best_model_name)

    torch.manual_seed(seed)

    if options.no_cuda:
        global_torch_device("cpu")

    with MixedObservationWrapper() as env:
        env.seed(seed)
        train_iter = batch_generator(iter(env), batch_size)
        num_categories = env.sensor("Class").space.discrete_steps
        # NOTE(review): there is no separate validation stream here, so the
        # "test" iterator is the very same object as the training iterator.
        test_iter = train_iter

        model, params_to_update = squeezenet_retrain(num_categories)
        print(params_to_update)

        model = model.to(global_torch_device())

        if options.continue_training:
            previous_models = list(PROJECT_APP_PATH.user_data.rglob("*.model"))
            if previous_models:
                latest_model_path = str(max(previous_models, key=os.path.getctime))
                print(f"loading previous model: {latest_model_path}")
                # map_location lets a checkpoint written on a GPU load on a CPU run.
                model.load_state_dict(
                    torch.load(latest_model_path, map_location=global_torch_device())
                )
            else:
                # max() on an empty list would raise ValueError; fall back instead.
                print(
                    f"no previous model found under {PROJECT_APP_PATH.user_data}; "
                    "starting from the freshly built model"
                )

        criterion = torch.nn.CrossEntropyLoss().to(global_torch_device())

        optimiser_ft = optim.SGD(
            model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=wd
        )
        # Multiply the learning rate by 0.1 every 7 scheduler steps.
        exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimiser_ft, step_size=7, gamma=0.1
        )

        with TensorBoardPytorchWriter(this_log) as writer:
            if not options.inference:
                model = pred_target_train_model(
                    model,
                    train_iter,
                    criterion,
                    optimiser_ft,
                    exp_lr_scheduler,
                    writer,
                    interrupted_path,
                    test_data_iterator=test_iter,
                    num_updates=num_updates,
                )

            # Visual sanity check: predict one batch and show it with its labels.
            inputs, true_label = zip(*next(train_iter))
            rgb_imgs = torch_vision_normalize_batch_nchw(
                uint_nhwc_to_nchw_float_batch(
                    rgb_drop_alpha_batch_nhwc(to_tensor(inputs))
                )
            )

            # Eval mode disables training-only behaviour (e.g. dropout); no_grad
            # avoids building an autograd graph for a forward pass that is never
            # backpropagated. The input is moved to the model's device explicitly.
            model.eval()
            with torch.no_grad():
                predicted = torch.argmax(
                    model(rgb_imgs.to(global_torch_device())), -1
                )
            true_label = to_tensor(true_label, dtype=torch.long)
            print(predicted, true_label)
            horizontal_imshow(
                inputs,
                [f"p:{int(p)},t:{int(t)}" for p, t in zip(predicted, true_label)],
            )
            pyplot.show()

    torch.cuda.empty_cache()

    if options.export:
        with TorchEvalSession(model):
            example = torch.rand(1, 3, 256, 256)
            traced_script_module = torch.jit.trace(model.to("cpu"), example)
            # NOTE(review): the file name mentions resnet18, but the model comes
            # from squeezenet_retrain; kept as-is for anything reading this file.
            traced_script_module.save("resnet18_v.model")


# Example no. 2


def main():
    """Fine-tune a SqueezeNet classifier on MNIST and preview its predictions.

    Command-line flags:

    * ``--inference`` / ``-i``: skip training and only run the prediction preview.
    * ``--continue_training`` / ``-c``: before training, load the state dict from
      the most recently created ``*.model`` file under
      ``PROJECT_APP_PATH.user_data`` (skipped with a message if none exists).
    * ``--no_cuda`` / ``-k``: set the global torch device to CPU.
    * ``--real_data`` / ``-r`` and ``--export`` / ``-e``: accepted on the command
      line but not used by this function.

    Each run creates a ``str(time.time())``-named directory under both
    ``PROJECT_APP_PATH.user_data`` and ``PROJECT_APP_PATH.user_log``. The path
    ``best_validation_model.model`` inside the former is passed to the training
    routine as ``interrupted_path``; the latter is handed to the TensorBoard
    writer.
    """
    args = argparse.ArgumentParser()
    args.add_argument("--inference", "-i", action="store_true")
    args.add_argument("--continue_training", "-c", action="store_true")
    args.add_argument("--real_data", "-r", action="store_true")
    # store_true: ``options.no_cuda`` is True exactly when the user asked for CPU.
    args.add_argument("--no_cuda", "-k", action="store_true")
    args.add_argument("--export", "-e", action="store_true")
    options = args.parse_args()

    # ``--inference`` means: skip training, only show the prediction preview.
    train_model = not options.inference
    timeas = str(time.time())
    this_model_path = PROJECT_APP_PATH.user_data / timeas
    this_log = PROJECT_APP_PATH.user_log / timeas
    ensure_directory_exist(this_model_path)
    ensure_directory_exist(this_log)

    best_model_name = "best_validation_model.model"
    interrupted_path = str(this_model_path / best_model_name)

    torch.manual_seed(seed)

    if options.no_cuda:
        global_torch_device("cpu")

    mnist_root = PROJECT_APP_PATH.user_cache / "mnist"
    dataset = MNISTDataset2(mnist_root, split=Split.Training)
    # NOTE(review): recycle() is assumed to restart the DataLoader once it is
    # exhausted, so next() can be called for any number of updates.
    train_iter = iter(
        recycle(
            DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
        )
    )
    val_iter = iter(
        recycle(
            DataLoader(
                MNISTDataset2(mnist_root, split=Split.Validation),
                batch_size=batch_size,
                shuffle=True,
                pin_memory=True,
            )
        )
    )

    model, params_to_update = squeezenet_retrain(len(dataset.categories))
    print(params_to_update)
    model = model.to(global_torch_device())

    if options.continue_training:
        previous_models = list(PROJECT_APP_PATH.user_data.rglob("*.model"))
        if previous_models:
            latest_model_path = str(max(previous_models, key=os.path.getctime))
            print(f"loading previous model: {latest_model_path}")
            # map_location lets a checkpoint written on a GPU load on a CPU run.
            model.load_state_dict(
                torch.load(latest_model_path, map_location=global_torch_device())
            )
        else:
            # max() on an empty list would raise ValueError; fall back instead.
            print(
                f"no previous model found under {PROJECT_APP_PATH.user_data}; "
                "starting from the freshly built model"
            )

    criterion = torch.nn.CrossEntropyLoss().to(global_torch_device())

    optimizer_ft = optim.SGD(
        model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=wd
    )
    # Multiply the learning rate by 0.1 every 7 scheduler steps.
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer_ft, step_size=7, gamma=0.1
    )

    # The context manager closes the writer even if training raises.
    with TensorBoardPytorchWriter(this_log) as writer:
        if train_model:
            model = predictor_response_train_model(
                model,
                train_iterator=train_iter,
                criterion=criterion,
                optimizer=optimizer_ft,
                scheduler=exp_lr_scheduler,
                writer=writer,
                interrupted_path=interrupted_path,
                val_data_iterator=val_iter,
                num_updates=NUM_UPDATES,
            )

        # Preview predictions on a validation batch rather than a batch the
        # model was just trained on.
        inputs, true_label = next(val_iter)
        # Tile axis 1 three times; presumably MNIST batches are (N, 1, H, W)
        # and SqueezeNet expects 3 input channels — confirm against the dataset.
        inputs = to_tensor(
            inputs, dtype=torch.float, device=global_torch_device()
        ).repeat(1, 3, 1, 1)
        true_label = to_tensor(
            true_label, dtype=torch.long, device=global_torch_device()
        )

        # Eval mode disables training-only behaviour (e.g. dropout); no_grad
        # avoids building an autograd graph for this forward pass.
        model.eval()
        with torch.no_grad():
            predicted = torch.argmax(model(inputs), -1)
        print(predicted, true_label)
        horizontal_imshow(
            inputs, [f"p:{int(p)},t:{int(t)}" for p, t in zip(predicted, true_label)]
        )
        pyplot.show()

    torch_clean_up()


# Example no. 3


def main(train_model=True, continue_training=True, no_cuda=False):
    """Fine-tune the last layer of a SqueezeNet classifier on MNIST.

    After optional training, one validation batch is classified and shown with
    ``p:<predicted>,t:<true>`` captions.

    :param train_model: run the training loop; when False only the prediction
        preview is produced.
    :param continue_training: before training, load the state dict from the
        most recently created ``*.model`` file under
        ``PROJECT_APP_PATH.user_data``. If no such file exists (e.g. on a first
        run) a message is printed and the freshly built model is used.
    :param no_cuda: set the global torch device to CPU.
    """
    timeas = str(time.time())
    this_model_path = PROJECT_APP_PATH.user_data / timeas
    this_log = PROJECT_APP_PATH.user_log / timeas
    ensure_directory_exist(this_model_path)
    ensure_directory_exist(this_log)

    best_model_name = "best_validation_model.model"
    interrupted_path = str(this_model_path / best_model_name)

    torch.manual_seed(seed)

    if no_cuda:
        global_torch_device("cpu")

    mnist_root = PROJECT_APP_PATH.user_cache / "mnist"
    dataset = MNISTDataset2(mnist_root, split=SplitEnum.training)
    # NOTE(review): recycle() is assumed to restart the DataLoader once it is
    # exhausted, so next() can be called for any number of updates.
    train_iter = iter(
        recycle(
            DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
        )
    )
    val_iter = iter(
        recycle(
            DataLoader(
                MNISTDataset2(mnist_root, split=SplitEnum.validation),
                batch_size=batch_size,
                shuffle=True,
                pin_memory=global_pin_memory(0),
            )
        )
    )

    model, _ = squeezenet_retrain(
        len(dataset.categories), train_only_last_layer=True
    )
    model = model.to(global_torch_device())

    if continue_training:
        previous_models = list(PROJECT_APP_PATH.user_data.rglob("*.model"))
        if previous_models:
            latest_model_path = str(max(previous_models, key=os.path.getctime))
            print(f"loading previous model: {latest_model_path}")
            # map_location lets a checkpoint written on a GPU load on a CPU run.
            model.load_state_dict(
                torch.load(latest_model_path, map_location=global_torch_device())
            )
        else:
            # max() on an empty list would raise ValueError; fall back instead.
            print(
                f"no previous model found under {PROJECT_APP_PATH.user_data}; "
                "starting from the freshly built model"
            )

    criterion = torch.nn.CrossEntropyLoss().to(global_torch_device())

    optimiser_ft = optim.SGD(
        model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=wd
    )
    # Multiply the learning rate by 0.1 every 7 scheduler steps.
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimiser_ft, step_size=7, gamma=0.1
    )

    with TensorBoardPytorchWriter(this_log) as writer:
        if train_model:
            model = predictor_response_train_model(
                model,
                train_iterator=train_iter,
                criterion=criterion,
                optimiser=optimiser_ft,
                scheduler=exp_lr_scheduler,
                writer=writer,
                interrupted_path=interrupted_path,
                val_data_iterator=val_iter,
                num_updates=NUM_UPDATES,
            )

        inputs, true_label = next(val_iter)
        # Tile axis 1 three times; presumably MNIST batches are (N, 1, H, W)
        # and SqueezeNet expects 3 input channels — confirm against the dataset.
        inputs = to_tensor(
            inputs, dtype=torch.float, device=global_torch_device()
        ).repeat(1, 3, 1, 1)
        true_label = to_tensor(
            true_label, dtype=torch.long, device=global_torch_device()
        )

        # Eval mode disables training-only behaviour (e.g. dropout); no_grad
        # avoids building an autograd graph for this forward pass.
        model.eval()
        with torch.no_grad():
            predicted = torch.argmax(model(inputs), -1)

        horizontal_imshow(
            [to_pil_image(i) for i in inputs],
            [f"p:{int(p)},t:{int(t)}" for p, t in zip(predicted, true_label)],
            # NOTE(review): 8 columns, presumably sized for a batch of 64 images.
            num_columns=64 // 8,
        )
        pyplot.show()

    torch_clean_up()