def get_learner_config(self, pipeline):
    """Assemble the semantic segmentation learner config for *pipeline*.

    A ``SemanticSegmentationDataConfig`` is default-constructed and then
    populated field by field: the chip URI, class names/colors and training
    chip size come from *pipeline*; the augmentors come from this object.
    That data config is combined with this object's model, solver, test-mode
    and tensorboard settings into a ``SemanticSegmentationLearnerConfig``.

    Args:
        pipeline: pipeline config; ``chip_uri``, ``dataset.class_config``
            (``names``/``colors``), ``train_chip_sz`` and ``train_uri`` are
            read from it.

    Returns:
        The ``SemanticSegmentationLearnerConfig``, after ``update()`` has
        been called on it.
    """
    class_config = pipeline.dataset.class_config

    data_cfg = SemanticSegmentationDataConfig()
    data_cfg.uri = pipeline.chip_uri
    data_cfg.class_names = class_config.names
    data_cfg.class_colors = class_config.colors
    data_cfg.img_sz = pipeline.train_chip_sz
    data_cfg.augmentors = self.augmentors

    learner_kwargs = dict(
        data=data_cfg,
        model=self.model,
        solver=self.solver,
        test_mode=self.test_mode,
        output_uri=pipeline.train_uri,
        log_tensorboard=self.log_tensorboard,
        run_tensorboard=self.run_tensorboard,
    )
    learner_cfg = SemanticSegmentationLearnerConfig(**learner_kwargs)
    # Let the config's own update() hook run before handing it to the caller.
    learner_cfg.update()
    return learner_cfg
    def get_learner_config(self, pipeline):
        """Assemble and validate the semantic segmentation learner config.

        The data config is built in one constructor call from *pipeline*
        (chip URI, class names/colors, chip size, image channels, image and
        label formats, channel display groups) and from this object (worker
        count, augmentors, base/aug transforms, plot options). It is then
        wrapped with this object's model, solver, test-mode and tensorboard
        settings into a ``SemanticSegmentationLearnerConfig``.

        Args:
            pipeline: pipeline config providing the chip/training URIs,
                dataset (class config and ``img_channels``), chip size,
                image/label formats and channel display groups.

        Returns:
            The ``SemanticSegmentationLearnerConfig`` after ``update()`` and
            ``validate_config()`` have been called on it.
        """
        dataset = pipeline.dataset
        class_config = dataset.class_config

        data_kwargs = dict(
            uri=pipeline.chip_uri,
            class_names=class_config.names,
            class_colors=class_config.colors,
            img_sz=pipeline.train_chip_sz,
            img_channels=dataset.img_channels,
            img_format=pipeline.img_format,
            label_format=pipeline.label_format,
            num_workers=self.num_workers,
            augmentors=self.augmentors,
            base_transform=self.base_transform,
            aug_transform=self.aug_transform,
            plot_options=self.plot_options,
            channel_display_groups=pipeline.channel_display_groups,
        )
        data_cfg = SemanticSegmentationDataConfig(**data_kwargs)

        learner_cfg = SemanticSegmentationLearnerConfig(
            data=data_cfg,
            model=self.model,
            solver=self.solver,
            test_mode=self.test_mode,
            output_uri=pipeline.train_uri,
            log_tensorboard=self.log_tensorboard,
            run_tensorboard=self.run_tensorboard,
        )
        # Run the config's update hook first, then its validation, in that order.
        learner_cfg.update()
        learner_cfg.validate_config()
        return learner_cfg
Example #3
0
 def get_learner_config(self, pipeline):
     """Wrap this object's data/model/solver settings in a learner config.

     Unlike variants that derive the data config from *pipeline*, this one
     passes ``self.data`` through unchanged; only the output URI
     (``pipeline.train_uri``) comes from *pipeline*.

     Args:
         pipeline: pipeline config; only ``train_uri`` is read.

     Returns:
         The ``SemanticSegmentationLearnerConfig`` after ``update()`` and
         ``validate_config()`` have been called on it.
     """
     learner_kwargs = dict(
         data=self.data,
         model=self.model,
         solver=self.solver,
         test_mode=self.test_mode,
         output_uri=pipeline.train_uri,
         log_tensorboard=self.log_tensorboard,
         run_tensorboard=self.run_tensorboard,
     )
     learner_cfg = SemanticSegmentationLearnerConfig(**learner_kwargs)
     # Run the config's update hook first, then its validation, in that order.
     learner_cfg.update()
     learner_cfg.validate_config()
     return learner_cfg