예제 #1
0
    def test_indices(self):
        """Only the configured sample indices are written for each split.

        Uses ``viz_config_indices.yaml`` and checks that every epoch directory
        contains exactly ``0_0.ply`` and ``1_1.ply`` for ``train`` and only
        ``0_0.ply`` for ``test``.
        """
        # Dense mock batch: tensors laid out as (batch_size, num_points, C).
        mock_data = Data()
        mock_data.pos = torch.zeros((batch_size, num_points, 3))
        mock_data.y = torch.zeros((batch_size, num_points, 1))
        mock_data.pred = torch.zeros((batch_size, num_points, 1))
        data = {"mock_date": mock_data}

        # The visualizer tests all use this same scratch directory. Start from an
        # empty one and register its removal up front, so files left behind by a
        # failing test can never leak into the os.listdir() assertions of another.
        self.run_path = os.path.join(DIR, "test_viz")
        shutil.rmtree(self.run_path, ignore_errors=True)
        os.makedirs(self.run_path)
        self.addCleanup(shutil.rmtree, self.run_path, ignore_errors=True)

        mock_num_batches = {"train": 9, "test": 3, "val": 0}
        config = OmegaConf.load(os.path.join(DIR, "test_config/viz/viz_config_indices.yaml"))
        visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path)

        for epoch in range(epochs):
            run(9, visualizer, epoch, "train", data)
            run(3, visualizer, epoch, "test", data)
            run(0, visualizer, epoch, "val", data)

        expected_files = {"train": {"1_1.ply", "0_0.ply"}, "test": {"0_0.ply"}}
        for split, expected in expected_files.items():
            for epoch in range(epochs):
                split_dir = os.path.join(self.run_path, "viz", str(epoch), split)
                self.assertEqual(expected, set(os.listdir(split_dir)))
예제 #2
0
    def test_save_all(self):
        """With ``viz_config_save_all.yaml`` every epoch writes at least ``num_samples`` train files."""
        # Flat (PyG-style) mock batch: num_points * batch_size rows plus a `batch` vector.
        mock_data = Data()
        mock_data.pos = torch.zeros((num_points * batch_size, 3))
        mock_data.y = torch.zeros((num_points * batch_size, 1))
        mock_data.pred = torch.zeros((num_points * batch_size, 1))
        mock_data.batch = torch.zeros((num_points * batch_size))
        # The first num_points rows get batch index 1; all other rows stay 0.
        mock_data.batch[:num_points] = 1
        data = {"mock_date": mock_data}

        # The visualizer tests all use this same scratch directory. Start from an
        # empty one and register its removal up front, so files left behind by a
        # failing test can never leak into the os.listdir() assertions of another.
        self.run_path = os.path.join(DIR, "test_viz")
        shutil.rmtree(self.run_path, ignore_errors=True)
        os.makedirs(self.run_path)
        self.addCleanup(shutil.rmtree, self.run_path, ignore_errors=True)

        epochs = 2
        num_samples = 100
        mock_num_batches = {"train": num_samples}

        config = OmegaConf.load(os.path.join(DIR, "test_config/viz/viz_config_save_all.yaml"))
        visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path)

        for epoch in range(epochs):
            run(num_samples // batch_size, visualizer, epoch, "train", data)

        for epoch in range(epochs):
            train_dir = os.path.join(self.run_path, "viz", str(epoch), "train")
            self.assertGreaterEqual(len(set(os.listdir(train_dir))), num_samples)
예제 #3
0
    def test_pyg_data(self):
        """The non-deterministic config picks different samples across epochs.

        Of the 9 epochs after epoch 0, at least 4 must share no saved file name
        with epoch 0.
        """
        # Flat (PyG-style) mock batch: num_points * batch_size rows plus a `batch` vector.
        mock_data = Data()
        mock_data.pos = torch.zeros((num_points * batch_size, 3))
        mock_data.y = torch.zeros((num_points * batch_size, 1))
        mock_data.pred = torch.zeros((num_points * batch_size, 1))
        mock_data.batch = torch.zeros((num_points * batch_size))
        # The first num_points rows get batch index 1; all other rows stay 0.
        mock_data.batch[:num_points] = 1
        data = {"mock_date": mock_data}

        # The visualizer tests all use this same scratch directory. Start from an
        # empty one and register its removal up front, so files left behind by a
        # failing test can never leak into the os.listdir() assertions of another.
        self.run_path = os.path.join(DIR, "test_viz")
        shutil.rmtree(self.run_path, ignore_errors=True)
        os.makedirs(self.run_path)
        self.addCleanup(shutil.rmtree, self.run_path, ignore_errors=True)

        epochs = 10
        num_batches = 100
        mock_num_batches = {"train": num_batches}

        config = OmegaConf.load(os.path.join(DIR, "test_config/viz/viz_config_non_deterministic.yaml"))
        visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path)

        for epoch in range(epochs):
            run(num_batches, visualizer, epoch, "train", data)

        first_epoch = set(os.listdir(os.path.join(self.run_path, "viz", "0", "train")))
        disjoint_epochs = 0
        for epoch in range(1, epochs):
            current = set(os.listdir(os.path.join(self.run_path, "viz", str(epoch), "train")))
            if not first_epoch & current:
                disjoint_epochs += 1
        self.assertGreaterEqual(disjoint_epochs, 4)
예제 #4
0
    def test_dense_data(self):
        """The deterministic config saves the same number of files in every epoch.

        For each split and output format (``ply``, ``las``), every later epoch
        must contain as many files as epoch 0.
        """
        # Dense mock batch: tensors laid out as (batch_size, num_points, C).
        mock_data = Data()
        mock_data.pos = torch.zeros((batch_size, num_points, 3))
        mock_data.y = torch.zeros((batch_size, num_points, 1))
        mock_data.pred = torch.zeros((batch_size, num_points, 1))
        data = {"mock_date": mock_data}

        # The visualizer tests all use this same scratch directory. Start from an
        # empty one and register its removal up front, so files left behind by a
        # failing test can never leak into the os.listdir() assertions of another.
        self.run_path = os.path.join(DIR, "test_viz")
        shutil.rmtree(self.run_path, ignore_errors=True)
        os.makedirs(self.run_path)
        self.addCleanup(shutil.rmtree, self.run_path, ignore_errors=True)

        mock_num_batches = {"train": 9, "test": 3, "val": 0}
        config = OmegaConf.load(
            os.path.join(DIR, "test_config/viz/viz_config_deterministic.yaml"))
        visualizer = Visualizer(config.visualization, mock_num_batches,
                                batch_size, self.run_path, None)

        for epoch in range(epochs):
            run(9, visualizer, epoch, "train", data)
            run(3, visualizer, epoch, "test", data)
            run(0, visualizer, epoch, "val", data)

        # `fmt` rather than `format` to avoid shadowing the builtin.
        for split in ["train", "test"]:
            for fmt in ["ply", "las"]:
                reference = os.listdir(
                    os.path.join(self.run_path, "viz", "0", split, fmt))
                for epoch in range(1, epochs):
                    current = os.listdir(
                        os.path.join(self.run_path, "viz", str(epoch), split, fmt))
                    self.assertEqual(len(reference), len(current))
예제 #5
0
    def test_indices(self):
        """Only the configured sample indices are written, named per epoch.

        For each epoch, split and format, the file stems must be exactly
        ``<epoch>_<index>`` for the configured indices (``1_1``/``0_0`` for
        train, ``0_0`` for test). ``las`` output additionally has a
        ``<epoch>_<index>_gt`` file for each of them.
        """
        # Dense mock batch: tensors laid out as (batch_size, num_points, C).
        mock_data = Data()
        mock_data.pos = torch.zeros((batch_size, num_points, 3))
        mock_data.y = torch.zeros((batch_size, num_points, 1))
        mock_data.pred = torch.zeros((batch_size, num_points, 1))
        data = {"mock_date": mock_data}

        # The visualizer tests all use this same scratch directory. Start from an
        # empty one and register its removal up front, so files left behind by a
        # failing test can never leak into the os.listdir() assertions of another.
        self.run_path = os.path.join(DIR, "test_viz")
        shutil.rmtree(self.run_path, ignore_errors=True)
        os.makedirs(self.run_path)
        self.addCleanup(shutil.rmtree, self.run_path, ignore_errors=True)

        mock_num_batches = {"train": 9, "test": 3, "val": 0}
        config = OmegaConf.load(
            os.path.join(DIR, "test_config/viz/viz_config_indices.yaml"))
        visualizer = Visualizer(config.visualization, mock_num_batches,
                                batch_size, self.run_path, None)

        for epoch in range(epochs):
            run(9, visualizer, epoch, "train", data)
            run(3, visualizer, epoch, "test", data)
            run(0, visualizer, epoch, "val", data)

        targets = {"train": {"1_1", "0_0"}, "test": {"0_0"}}
        for split, indices in targets.items():
            for epoch in range(epochs):
                # `fmt` rather than `format` to avoid shadowing the builtin.
                for fmt in ["ply", "las"]:
                    files = os.listdir(
                        os.path.join(self.run_path, "viz", str(epoch), split, fmt))
                    stems = {os.path.splitext(filename)[0] for filename in files}

                    # The current epoch is prepended to every expected name.
                    expected = {"%d_%s" % (epoch, idx) for idx in indices}
                    if fmt == "las":
                        # las output also carries a ground-truth file per sample.
                        expected |= {"%s_gt" % name for name in expected}

                    self.assertEqual(expected, stems)
예제 #6
0
    def test_empty(self):
        """Nothing is written under ``viz`` when every batch carries no data."""
        # Every run() call below receives an empty data dict.
        data = {}

        # The visualizer tests all use this same scratch directory. Start from an
        # empty one and register its removal up front, so files left behind by a
        # failing test can never leak into the os.listdir() assertions of another.
        self.run_path = os.path.join(DIR, "test_viz")
        shutil.rmtree(self.run_path, ignore_errors=True)
        os.makedirs(self.run_path)
        self.addCleanup(shutil.rmtree, self.run_path, ignore_errors=True)

        mock_num_batches = {"train": 9, "test": 3, "val": 0}
        config = OmegaConf.load(os.path.join(DIR, "test_config/viz/viz_config_indices.yaml"))
        visualizer = Visualizer(config.visualization, mock_num_batches, batch_size, self.run_path)

        for epoch in range(epochs):
            run(9, visualizer, epoch, "train", data)
            run(3, visualizer, epoch, "test", data)
            # NOTE(review): mock_num_batches declares 0 val batches but 2 are run
            # here; confirm this mismatch is intentional.
            run(2, visualizer, epoch, "val", data)

        self.assertEqual(os.listdir(os.path.join(self.run_path, "viz")), [])
예제 #7
0
    def _initialize_trainer(self):
        """Build the trainer's runtime state from ``self._cfg``.

        In order: CUDNN flag, device, checkpoint, dataset and model (restored
        from the checkpoint when it is not empty, otherwise built from config),
        dataloaders, selection stage, tracker, and optionally a visualizer.
        Results are stored on ``self`` (``_device``, ``_checkpoint``,
        ``_dataset``, ``_model``, ``_tracker``, ``_visualizer``).
        """
        # Enable CUDNN BACKEND
        torch.backends.cudnn.enabled = self.enable_cudnn

        if not self.has_training:
            # No separate training section: alias the whole config as
            # `training`, so every `self._cfg.training.*` lookup below reads
            # from the top level of the config.
            self._cfg.training = self._cfg
            resume = bool(self._cfg.checkpoint_dir)
        else:
            resume = bool(self._cfg.training.checkpoint_dir)

        # Get device: `training.cuda` is used as a GPU index; any value > -1
        # selects that GPU when CUDA is available, otherwise fall back to CPU.
        if self._cfg.training.cuda > -1 and torch.cuda.is_available():
            device = "cuda"
            torch.cuda.set_device(self._cfg.training.cuda)
        else:
            device = "cpu"
        self._device = torch.device(device)
        log.info("DEVICE : {}".format(self._device))

        # Profiling
        if self.profiling:
            # Set the num_workers as torch.utils.bottleneck doesn't work well with it
            self._cfg.training.num_workers = 0

        # Start Wandb if public. Non-public runs are launched further down,
        # after the tracker has been created.
        if self.wandb_log:
            Wandb.launch(self._cfg, self._cfg.wandb.public and self.wandb_log)

        # Checkpoint

        self._checkpoint: ModelCheckpoint = ModelCheckpoint(
            self._cfg.training.checkpoint_dir,
            self._cfg.model_name,
            self._cfg.training.weight_name,
            run_config=self._cfg,
            resume=resume,
        )

        # Create model and datasets
        if not self._checkpoint.is_empty:
            # Restore: the dataset is rebuilt from the data config stored in the
            # checkpoint, and the model is created through the checkpoint.
            self._dataset: BaseDataset = instantiate_dataset(
                self._checkpoint.data_config)
            self._model: BaseModel = self._checkpoint.create_model(
                self._dataset, weight_name=self._cfg.training.weight_name)
        else:
            # Fresh run: build from the current config. The model receives a
            # deep copy, so anything instantiate_model changes in it does not
            # leak back into self._cfg.
            self._dataset: BaseDataset = instantiate_dataset(self._cfg.data)
            self._model: BaseModel = instantiate_model(
                copy.deepcopy(self._cfg), self._dataset)
            # Second argument is True when the selected device is CUDA.
            self._model.instantiate_optimizers(self._cfg, "cuda" in device)
            self._model.set_pretrained_weights()
            if not self._checkpoint.validate(self._dataset.used_properties):
                log.warning(
                    "The model will not be able to be used from pretrained weights without the corresponding dataset. Current properties are {}"
                    .format(self._dataset.used_properties))
        # Record the dataset properties on the checkpoint in both branches.
        self._checkpoint.dataset_properties = self._dataset.used_properties

        log.info(self._model)

        self._model.log_optimizers()
        # Model size = number of trainable (requires_grad) parameter elements.
        log.info(
            "Model size = %i",
            sum(param.numel() for param in self._model.parameters()
                if param.requires_grad))

        # Set dataloaders
        self._dataset.create_dataloaders(
            self._model,
            self._cfg.training.batch_size,
            self._cfg.training.shuffle,
            self._cfg.training.num_workers,
            self.precompute_multi_scale,
        )
        log.info(self._dataset)

        # Verify attributes in dataset, using the first training sample.
        # NOTE(review): this indexes train_dataset unconditionally; confirm it
        # exists for evaluation-only configs (has_training False).
        self._model.verify_data(self._dataset.train_dataset[0])

        # Choose selection stage (defaults to "" when the config has none)
        selection_stage = getattr(self._cfg, "selection_stage", "")
        self._checkpoint.selection_stage = self._dataset.resolve_saving_stage(
            selection_stage)
        self._tracker: BaseTracker = self._dataset.get_tracker(
            self.wandb_log, self.tensorboard_log)

        # Launch Wandb for non-public runs (public runs were launched above).
        if self.wandb_log:
            Wandb.launch(self._cfg, not self._cfg.wandb.public
                         and self.wandb_log)

        # Run training / evaluation
        self._model = self._model.to(self._device)
        if self.has_visualization:
            # Visualizer output is rooted at the current working directory.
            self._visualizer = Visualizer(self._cfg.visualization,
                                          self._dataset.num_batches,
                                          self._dataset.batch_size,
                                          os.getcwd())
예제 #8
0
def main(cfg):
    """Build everything described by ``cfg`` and hand it to ``run``.

    Sets up device and CUDNN, the checkpoint, and the dataset and model
    (restored from the checkpoint when it is not empty, otherwise built from
    ``cfg``). Then it creates the dataloaders, tracker and visualizer, starts
    the run, and finally clears Hydra's global state. Always returns 0.
    """
    # Turn struct mode off so getattr/hasattr work on the config.
    OmegaConf.set_struct(cfg, False)
    if cfg.pretty_print:
        print(cfg.pretty())

    # CUDA is used only when it is available and the config asks for it.
    use_cuda = torch.cuda.is_available() and cfg.training.cuda
    if use_cuda:
        compute_device = torch.device("cuda")
    else:
        compute_device = torch.device("cpu")
    log.info("DEVICE : {}".format(compute_device))

    torch.backends.cudnn.enabled = cfg.training.enable_cudnn

    # torch.utils.bottleneck does not cope well with dataloader workers.
    if getattr(cfg.debugging, "profiling", False):
        cfg.training.num_workers = 0

    # Public wandb runs start now; private ones start once the tracker exists.
    launch_wandb(cfg, cfg.wandb.public and cfg.wandb.log)

    ckpt_dir = cfg.training.checkpoint_dir
    weight_name = cfg.training.weight_name
    ckpt = ModelCheckpoint(
        ckpt_dir,
        cfg.model_name,
        weight_name,
        run_config=cfg,
        resume=bool(ckpt_dir),
    )

    if ckpt.is_empty:
        dataset = instantiate_dataset(cfg.data)
        net = instantiate_model(cfg, dataset)
        net.instantiate_optimizers(cfg)
    else:
        dataset = instantiate_dataset(ckpt.data_config)
        net = ckpt.create_model(dataset, weight_name=weight_name)

    log.info(net)
    net.log_optimizers()
    trainable_count = sum(p.numel() for p in net.parameters() if p.requires_grad)
    log.info("Model size = %i", trainable_count)

    dataset.create_dataloaders(
        net,
        cfg.training.batch_size,
        cfg.training.shuffle,
        cfg.training.num_workers,
        cfg.training.precompute_multi_scale,
    )
    log.info(dataset)

    ckpt.selection_stage = dataset.resolve_saving_stage(
        getattr(cfg, "selection_stage", ""))
    tracker: BaseTracker = dataset.get_tracker(
        net, dataset, cfg.wandb.log, cfg.tensorboard.log)

    launch_wandb(cfg, not cfg.wandb.public and cfg.wandb.log)

    net = net.to(compute_device)
    visualizer = Visualizer(
        cfg.visualization, dataset.num_batches, dataset.batch_size, os.getcwd())
    run(cfg, net, dataset, compute_device, tracker, ckpt, visualizer)

    # Clear Hydra's global state (workaround; see the linked issue).
    # https://github.com/facebookresearch/hydra/issues/440
    hydra._internal.hydra.GlobalHydra.get_state().clear()
    return 0