def run(model: BaseModel, dataset: BaseDataset, device, output_path, cfg):
    # Set dataloaders
    num_fragment = dataset.num_fragment
    if cfg.data.is_patch:
        for i in range(num_fragment):
            dataset.set_patches(i)
            dataset.create_dataloaders(
                model,
                cfg.batch_size,
                False,
                cfg.num_workers,
                False,
            )
            loader = dataset.test_dataloaders()[0]
            features = []
            scene_name, pc_name = dataset.get_name(i)

            with Ctq(loader) as tq_test_loader:
                for data in tq_test_loader:
                    # pcd = open3d.geometry.PointCloud()
                    # pcd.points = open3d.utility.Vector3dVector(data.pos[0].numpy())
                    # open3d.visualization.draw_geometries([pcd])
                    with torch.no_grad():
                        model.set_input(data, device)
                        model.forward()
                        features.append(model.get_output().cpu())
            features = torch.cat(features, 0).numpy()
            log.info("save {} from {} in  {}".format(pc_name, scene_name,
                                                     output_path))
            save(output_path, scene_name, pc_name,
                 dataset.base_dataset[i].to("cpu"), features)
    else:
        dataset.create_dataloaders(
            model,
            1,
            False,
            cfg.num_workers,
            False,
        )
        loader = dataset.test_dataloaders()[0]
        with Ctq(loader) as tq_test_loader:
            for i, data in enumerate(tq_test_loader):
                # With a batch size of 1, the loop index i is the fragment
                # index, so the names can be recovered as in the patch branch.
                scene_name, pc_name = dataset.get_name(i)
                with torch.no_grad():
                    model.set_input(data, device)
                    model.forward()
                    features = model.get_output()[0]  # batch of 1
                    save(output_path, scene_name, pc_name, data.to("cpu"),
                         features)
Example #2
def train_epoch(epoch, model: BaseModel, dataset, device: str,
                tracker: BaseTracker, checkpoint: ModelCheckpoint, log):
    model.train()
    tracker.reset("train")
    train_loader = dataset.train_dataloader()

    iter_data_time = time.time()
    with Ctq(train_loader) as tq_train_loader:
        for i, data in enumerate(tq_train_loader):
            data = data.to(device)  # This takes time

            model.set_input(data)
            t_data = time.time() - iter_data_time

            iter_start_time = time.time()
            model.optimize_parameters(dataset.batch_size)

            if i % 10 == 0:
                tracker.track(model)

            tq_train_loader.set_postfix(**tracker.get_metrics(),
                                        data_loading=float(t_data),
                                        iteration=float(time.time() -
                                                        iter_start_time),
                                        color=COLORS.TRAIN_COLOR)
            iter_data_time = time.time()

    metrics = tracker.publish()
    checkpoint.save_best_models_under_current_metrics(model, metrics)
    log.info("Learning rate = %f" % model.learning_rate)
Example #3
def eval_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    early_break: bool,
):
    model.eval()
    tracker.reset("val")
    visualizer.reset(epoch, "val")
    loader = dataset.val_dataloader()
    with Ctq(loader) as tq_val_loader:
        for data in tq_val_loader:
            data = data.to(device)
            with torch.no_grad():
                model.set_input(data)
                model.forward()

            tracker.track(model)
            tq_val_loader.set_postfix(**tracker.get_metrics(),
                                      color=COLORS.VAL_COLOR)

            if visualizer.is_active:
                visualizer.save_visuals(model.get_current_visuals())

            if early_break:
                break

    metrics = tracker.publish(epoch)
    tracker.print_summary()
    checkpoint.save_best_models_under_current_metrics(model, metrics)
Example #4
def run_epoch(model: BaseModel, loader, device: str, num_batches: int):
    model.eval()
    with Ctq(loader) as tq_loader:
        for batch_idx, data in enumerate(tq_loader):
            if batch_idx < num_batches:
                process(model, data, device)
            else:
                break
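Example #4 calls a process helper that is not shown on this page. A minimal, assumed implementation following the inference pattern of the other snippets would be:

# Hypothetical `process` helper (assumption): move the batch to the device
# and run a forward pass without gradients, as the other examples do.
def process(model, data, device):
    data = data.to(device)
    with torch.no_grad():
        model.set_input(data)
        model.forward()
    return model.get_output()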
Example #5
def eval_epoch(model: BaseModel, dataset, device, tracker: BaseTracker, checkpoint: ModelCheckpoint):
    tracker.reset("val")
    loader = dataset.val_dataloader
    with Ctq(loader) as tq_val_loader:
        for data in tq_val_loader:
            with torch.no_grad():
                model.set_input(data, device)
                model.forward()

            tracker.track(model)
            tq_val_loader.set_postfix(**tracker.get_metrics(), color=COLORS.VAL_COLOR)

    tracker.print_summary()
Example #6
def test_epoch(model: BaseModel, dataset, device, tracker: BaseTracker, checkpoint: ModelCheckpoint):
    tracker.reset("test")
    loader = dataset.test_dataloader()
    with Ctq(loader) as tq_test_loader:
        for data in tq_test_loader:
            data = data.to(device)
            with torch.no_grad():
                model.set_input(data)
                model.forward()

            tracker.track(model)
            tq_test_loader.set_postfix(**tracker.get_metrics(), color=COLORS.TEST_COLOR)

    tracker.print_summary()
Example #7
def run(model: BaseModel, dataset: BaseDataset, device, output_path):
    loaders = dataset.test_dataloaders
    predicted = {}
    for loader in loaders:
        loader.dataset.name
        with Ctq(loader) as tq_test_loader:
            for data in tq_test_loader:
                with torch.no_grad():
                    model.set_input(data, device)
                    model.forward()
                predicted = {
                    **predicted,
                    **dataset.predict_original_samples(data, model.conv_type,
                                                       model.get_output())
                }

    save(output_path, predicted)
Example #8
def test_epoch(model: BaseModel, dataset, device, tracker: BaseTracker, checkpoint: ModelCheckpoint):

    loaders = dataset.test_dataloaders

    for loader in loaders:
        stage_name = loader.dataset.name
        tracker.reset(stage_name)
        with Ctq(loader) as tq_test_loader:
            for data in tq_test_loader:
                with torch.no_grad():
                    model.set_input(data, device)
                    model.forward()

                tracker.track(model)
                tq_test_loader.set_postfix(**tracker.get_metrics(), color=COLORS.TEST_COLOR)

            tracker.print_summary()
Example #9
def run(model: BaseModel, dataset: BaseDataset, device, output_path):
    loaders = dataset.test_dataloaders()
    predicted = {}
    for idx, loader in enumerate(loaders):
        dataset.get_test_dataset_name(idx)
        with Ctq(loader) as tq_test_loader:
            for data in tq_test_loader:
                data = data.to(device)
                with torch.no_grad():
                    model.set_input(data)
                    model.forward()
                predicted = {
                    **predicted,
                    **dataset.predict_original_samples(data, model.conv_type,
                                                       model.get_output())
                }

    save(output_path, predicted)
Example #10
def eval_epoch(model: BaseModel, dataset, device, tracker: BaseTracker,
               checkpoint: ModelCheckpoint, log):
    model.eval()
    tracker.reset("val")
    loader = dataset.val_dataloader()
    with Ctq(loader) as tq_val_loader:
        for data in tq_val_loader:
            data = data.to(device)
            with torch.no_grad():
                model.set_input(data)
                model.forward()

            tracker.track(model)
            tq_val_loader.set_postfix(**tracker.get_metrics(),
                                      color=COLORS.VAL_COLOR)

    metrics = tracker.publish()
    tracker.print_summary()
    checkpoint.save_best_models_under_current_metrics(model, metrics)
Example #11
def train_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device: str,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    early_break: bool,
):
    model.train()
    tracker.reset("train")
    visualizer.reset(epoch, "train")
    train_loader = dataset.train_dataloader

    iter_data_time = time.time()
    with Ctq(train_loader) as tq_train_loader:
        for i, data in enumerate(tq_train_loader):
            model.set_input(data, device)
            t_data = time.time() - iter_data_time

            iter_start_time = time.time()
            model.optimize_parameters(epoch, dataset.batch_size)
            if i % 10 == 0:
                tracker.track(model)

            tq_train_loader.set_postfix(**tracker.get_metrics(),
                                        data_loading=float(t_data),
                                        iteration=float(time.time() -
                                                        iter_start_time),
                                        color=COLORS.TRAIN_COLOR)

            if visualizer.is_active:
                visualizer.save_visuals(model.get_current_visuals())

            iter_data_time = time.time()

            if early_break:
                break

    metrics = tracker.publish(epoch)
    checkpoint.save_best_models_under_current_metrics(model, metrics)
    log.info("Learning rate = %f" % model.learning_rate)
Example #12
def test_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    early_break: bool,
):
    model.eval()

    loaders = dataset.test_dataloaders()

    for idx, loader in enumerate(loaders):
        stage_name = dataset.get_test_dataset_name(idx)
        tracker.reset(stage_name)
        visualizer.reset(epoch, stage_name)
        with Ctq(loader) as tq_test_loader:
            for data in tq_test_loader:
                data = data.to(device)
                with torch.no_grad():
                    model.set_input(data)
                    model.forward()

                tracker.track(model)
                tq_test_loader.set_postfix(**tracker.get_metrics(),
                                           color=COLORS.TEST_COLOR)

                if visualizer.is_active:
                    visualizer.save_visuals(model.get_current_visuals())

                if early_break:
                    break

        metrics = tracker.publish(epoch)
        tracker.print_summary()
        checkpoint.save_best_models_under_current_metrics(model, metrics)