def setup_test_data(self, test_data_config: Optional[DictConfig] = None):
        """Build the test dataloader and store it on ``self._test_dl``.

        Args:
            test_data_config: Config of the test split. When ``None``,
                ``self._cfg.test_ds`` is used instead.
        """
        cfg = self._cfg.test_ds if test_data_config is None else test_data_config

        dataset_cfg = self._cfg.dataset
        labels_path = os.path.join(dataset_cfg.data_dir, cfg.labels_file)

        # The return value is intentionally discarded: the call is kept only
        # for whatever checks/side effects get_label_ids performs in
        # non-training mode.
        # NOTE(review): presumably it validates the test labels against the
        # already known ``label_ids`` mapping -- confirm against the helper.
        get_label_ids(
            label_file=labels_path,
            is_training=False,
            pad_label=dataset_cfg.pad_label,
            label_ids_dict=self._cfg.label_ids,
            get_weights=False,
        )

        self._test_dl = self._setup_dataloader_from_config(cfg=cfg)
    def setup_training_data(self,
                            train_data_config: Optional[DictConfig] = None):
        """Build the training dataloader and store it on ``self._train_dl``.

        Besides creating the dataloader this method:

        * makes sure ``self._cfg.class_labels`` exists (older configs do not
          define it),
        * stores the label mapping returned by ``get_label_ids`` in
          ``self._cfg.label_ids`` and the class weights in
          ``self.class_weights``,
        * registers the label ids file as the
          ``class_labels.class_labels_file`` artifact.

        Args:
            train_data_config: Config of the training split. When ``None``,
                ``self._cfg.train_ds`` is used instead.
        """
        if train_data_config is None:
            train_data_config = self._cfg.train_ds

        labels_file = os.path.join(self._cfg.dataset.data_dir,
                                   train_data_config.labels_file)

        # Compatibility with older (pre-1.0.0.b3) configs that have no
        # ``class_labels`` section: add one with the default file name.
        if (not hasattr(self._cfg, "class_labels")
                or self._cfg.class_labels is None):
            # Struct mode rejects new keys, so it is lifted for the
            # assignment; ``finally`` guarantees the config is locked again
            # even if the assignment raises.
            OmegaConf.set_struct(self._cfg, False)
            try:
                self._cfg.class_labels = OmegaConf.create(
                    {'class_labels_file': 'label_ids.csv'})
            finally:
                OmegaConf.set_struct(self._cfg, True)

        label_ids, label_ids_filename, self.class_weights = get_label_ids(
            label_file=labels_file,
            is_training=True,
            pad_label=self._cfg.dataset.pad_label,
            label_ids_dict=self._cfg.label_ids,
            get_weights=True,
            class_labels_file_artifact=self._cfg.class_labels.class_labels_file,
        )
        # Save the label map to the config.
        self._cfg.label_ids = OmegaConf.create(label_ids)

        self.register_artifact('class_labels.class_labels_file',
                               label_ids_filename)
        self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config)
Beispiel #3
0
    def setup_training_data(self,
                            train_data_config: Optional[DictConfig] = None):
        """Build the training dataloader and store it on ``self._train_dl``.

        The label mapping returned by ``get_label_ids`` is written to
        ``self._cfg.label_ids``, the third returned value is kept in
        ``self.class_weights``, and the label ids file is registered as the
        ``label_ids.csv`` artifact.

        Args:
            train_data_config: Config of the training split. When ``None``,
                ``self._cfg.train_ds`` is used instead.
        """
        split_cfg = train_data_config
        if split_cfg is None:
            split_cfg = self._cfg.train_ds

        labels_path = os.path.join(self._cfg.dataset.data_dir,
                                   split_cfg.labels_file)

        # Weights are requested only when the dataset config selects the
        # 'weighted_loss' class balancing mode.
        label_info = get_label_ids(
            label_file=labels_path,
            is_training=True,
            pad_label=self._cfg.dataset.pad_label,
            label_ids_dict=self._cfg.label_ids,
            get_weights=self._cfg.dataset.class_balancing == 'weighted_loss',
        )
        label_map, label_map_path, self.class_weights = label_info

        # Persist the label map in the config and expose its file as an
        # artifact.
        self._cfg.label_ids = OmegaConf.create(label_map)
        self.register_artifact('label_ids.csv', label_map_path)

        self._train_dl = self._setup_dataloader_from_config(cfg=split_cfg)