示例#1
0
    def setup(self, stage: Optional[str] = None):
        """Build the synop sequence dataset and split it into train/val/test.

        When ``config.experiment.use_gfs_data`` is set, a
        ``Sequence2SequenceWithGFSDataset`` is built from the GFS target data,
        plus the GFS input data when ``gfs_train_params`` is configured.
        Otherwise a plain ``Sequence2SequenceDataset`` is used. The dataset
        receives the synop mean/std before being split. The validation subset
        is also exposed as the test subset.

        Args:
            stage: Not used by this implementation.

        Raises:
            RuntimeError: If the constructed dataset contains no samples.
        """
        if self.config.experiment.use_gfs_data:
            # The first returned item (synop inputs) is not needed here.
            _, all_gfs_input_data, gfs_target_data = self.prepare_dataset_for_gfs()

            dataset_args = [self.config, self.synop_data, self.synop_data_indices,
                            gfs_target_data]
            # GFS input features are only passed when GFS training params exist.
            if self.gfs_train_params is not None:
                dataset_args.append(all_gfs_input_data)
            dataset = Sequence2SequenceWithGFSDataset(*dataset_args)
        else:
            dataset = Sequence2SequenceDataset(self.config, self.synop_data,
                                               self.synop_data_indices)

        if len(dataset) == 0:
            raise RuntimeError(
                "There are no valid samples in the dataset! Please check your run configuration"
            )

        dataset.set_mean(self.synop_mean)
        dataset.set_std(self.synop_std)
        # split_dataset only receives a sequence length for multi-step sequences.
        self.dataset_train, self.dataset_val = split_dataset(
            dataset,
            self.config.experiment.val_split,
            sequence_length=self.sequence_length
            if self.sequence_length > 1 else None)
        self.dataset_test = self.dataset_val
    def setup(self, stage: Optional[str] = None):
        """Build a synop + CMAX dataset and split it into train/val/test.

        The synop part depends on ``config.experiment.use_gfs_data``:
        - When set, it is a ``Sequence2SequenceWithGFSDataset``, with GFS input
          data added when ``gfs_train_params`` is configured.
        - Otherwise, it is a ``Sequence2SequenceDataset``.

        The synop part is combined with a normalized ``CMAXDataset`` over
        ``self.cmax_IDs`` via ``ConcatDatasets``. In the GFS branch,
        ``self.cmax_IDs`` is first reassigned with the entries at
        ``self.removed_dataset_indices`` dropped. The validation subset is also
        exposed as the test subset.

        Args:
            stage: Not used by this implementation.

        Raises:
            AssertionError: If the number of CMAX IDs does not match the number
                of synop samples.
            RuntimeError: If the combined dataset contains no samples.
        """
        if self.config.experiment.use_gfs_data:
            synop_inputs, all_gfs_input_data, gfs_target_data = self.prepare_dataset_for_gfs()

            # Build the lookup set once so each membership test is O(1) instead
            # of a scan over the removed indices for every CMAX ID.
            removed_indices = set(self.removed_dataset_indices)
            self.cmax_IDs = [item for index, item in enumerate(self.cmax_IDs)
                             if index not in removed_indices]

            assert len(self.cmax_IDs) == len(synop_inputs), (
                f"Number of CMAX IDs ({len(self.cmax_IDs)}) does not match "
                f"number of synop inputs ({len(synop_inputs)})")

            if self.gfs_train_params is not None:
                synop_dataset = Sequence2SequenceWithGFSDataset(
                    self.config, self.synop_data, self.synop_data_indices,
                    gfs_target_data, all_gfs_input_data)
            else:
                synop_dataset = Sequence2SequenceWithGFSDataset(
                    self.config, self.synop_data, self.synop_data_indices,
                    gfs_target_data)
        else:
            assert len(self.cmax_IDs) == len(self.synop_data_indices), (
                f"Number of CMAX IDs ({len(self.cmax_IDs)}) does not match "
                f"number of synop samples ({len(self.synop_data_indices)})")

            synop_dataset = Sequence2SequenceDataset(
                self.config, self.synop_data, self.synop_data_indices)

        dataset = ConcatDatasets(
            synop_dataset,
            CMAXDataset(config=self.config, IDs=self.cmax_IDs, normalize=True))

        # Fail early with a clear message rather than splitting an empty dataset.
        if len(dataset) == 0:
            raise RuntimeError(
                "There are no valid samples in the dataset! Please check your run configuration"
            )

        # Presumably one mean/std entry per concatenated component; the CMAX part
        # gets 0 (its dataset is constructed with normalize=True) — TODO confirm.
        dataset.set_mean([self.synop_mean, 0])
        dataset.set_std([self.synop_std, 0])
        self.dataset_train, self.dataset_val = split_dataset(
            dataset, self.config.experiment.val_split,
            sequence_length=self.sequence_length if self.sequence_length > 1 else None)
        self.dataset_test = self.dataset_val
 def setup(self, stage: Optional[str] = None):
     """Create the GFS-date-limited sequence dataset and split it.

     ``self.labels`` is used as the synop data and ``self.dates`` as the dates.
     The validation subset is reused as the test subset.

     Args:
         stage: Not used by this implementation.
     """
     full_dataset = SequenceLimitedToGFSDatesDataset(
         config=self.config, synop_data=self.labels, dates=self.dates)
     val_split = self.config.experiment.val_split
     # Only multi-step sequences pass their length on to split_dataset.
     if self.sequence_length > 1:
         split_seq_len = self.sequence_length
     else:
         split_seq_len = None
     train_subset, val_subset = split_dataset(
         full_dataset, val_split, sequence_length=split_seq_len)
     self.dataset_train = train_subset
     self.dataset_val = val_subset
     self.dataset_test = val_subset
 def setup(self, stage: Optional[str] = None):
     """Build the multi-channel spatial dataset and derive its splits.

     ``self.IDs`` is passed as the training IDs and ``self.labels`` as the
     labels. The validation subset doubles as the test subset.

     Args:
         stage: Not used by this implementation.
     """
     spatial_dataset = MultiChannelSpatialDataset(
         config=self.config,
         train_IDs=self.IDs,
         labels=self.labels,
     )
     val_split = self.config.experiment.val_split
     # A sequence length is forwarded only for sequences longer than one step.
     split_kwargs = {
         "sequence_length": self.sequence_length if self.sequence_length > 1 else None,
     }
     subsets = split_dataset(spatial_dataset, val_split, **split_kwargs)
     self.dataset_train, self.dataset_val = subsets
     self.dataset_test = self.dataset_val
示例#5
0
 def setup(self, stage: Optional[str] = None):
     """Create the single-GFS-point dataset and derive train/val/test splits.

     The validation subset is also exposed as the test subset.

     Args:
         stage: Not used by this implementation.
     """
     gfs_point_dataset = SingleGFSPointDataset(config=self.config)
     train_subset, val_subset = split_dataset(
         gfs_point_dataset,
         self.config.experiment.val_split,
     )
     self.dataset_train = train_subset
     self.dataset_val = val_subset
     self.dataset_test = val_subset
示例#6
0
 def setup(self, stage: Optional[str] = None):
     """Build the normalized spatial-subregion dataset and split it.

     ``self.IDs`` is passed as the training IDs and ``self.labels`` as the
     labels. The validation subset is reused as the test subset.

     Args:
         stage: Not used by this implementation.
     """
     subregion_dataset = MultiChannelSpatialSubregionDataset(
         config=self.config,
         train_IDs=self.IDs,
         labels=self.labels,
         normalize=True,
     )
     train_subset, val_subset = split_dataset(
         subregion_dataset, self.config.experiment.val_split)
     self.dataset_train = train_subset
     self.dataset_val = val_subset
     self.dataset_test = val_subset