Code Example #1
    def create_training_data_loader(self, dataset, **kwargs):
        instance_splitter = InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=ExpectedNumInstanceSampler(
                num_instances=1,
                min_future=self.prediction_length,
            ),
            past_length=self.context_length + 1,
            future_length=self.prediction_length,
            time_series_fields=[
                FieldName.FEAT_DYNAMIC_REAL,
                FieldName.OBSERVED_VALUES,
            ],
        )
        input_names = get_hybrid_forward_input_names(MyProbTrainRNN)
        return TrainDataLoader(
            dataset=dataset,
            transform=instance_splitter + SelectFields(input_names),
            batch_size=self.batch_size,
            stack_fn=functools.partial(
                batchify, ctx=self.trainer.ctx, dtype=self.dtype
            ),
            decode_fn=functools.partial(as_in_context, ctx=self.trainer.ctx),
            **kwargs,
        )
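The loader returned here can be consumed directly as an iterator over ready-to-stack batches. A minimal usage sketch; `estimator` and `dataset` are hypothetical placeholders, not names defined in the example above:

# Hypothetical usage: `estimator` is an instance of the custom estimator
# defining the method above, `dataset` is a GluonTS Dataset.
loader = estimator.create_training_data_loader(dataset)

# TrainDataLoader cycles through the dataset indefinitely, so pull a
# single batch rather than trying to exhaust the iterator.
batch = next(iter(loader))

# Each batch is a dict of MXNet arrays keyed by the hybrid_forward
# input names kept by SelectFields.
print({name: arr.shape for name, arr in batch.items()})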
Code Example #2
File: _estimator.py Project: pablosteinmetz/gluon-ts
 def create_validation_data_loader(
     self,
     data: Dataset,
     **kwargs,
 ) -> DataLoader:
     input_names = get_hybrid_forward_input_names(DeepFactorTrainingNetwork)
     instance_splitter = self._create_instance_splitter("validation")
     return ValidationDataLoader(
         dataset=data,
         transform=instance_splitter + SelectFields(input_names),
         batch_size=self.batch_size,
         stack_fn=partial(batchify, ctx=self.trainer.ctx, dtype=self.dtype),
     )
Code Example #3
 def create_validation_data_loader(
     self,
     data: Dataset,
     **kwargs,
 ) -> DataLoader:
     input_names = get_hybrid_forward_input_names(CanonicalTrainingNetwork)
     with env._let(max_idle_transforms=maybe_len(data) or 0):
         instance_splitter = self._create_instance_splitter("validation")
     return ValidationDataLoader(
         dataset=data,
         transform=instance_splitter + SelectFields(input_names),
         batch_size=self.batch_size,
         stack_fn=partial(batchify, ctx=self.trainer.ctx, dtype=self.dtype),
     )
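The only difference from Example #2 is the `env._let(max_idle_transforms=...)` guard. As far as I can tell from the GluonTS internals, a transformation that repeatedly yields nothing (for instance, an instance splitter skipping series shorter than the sampling window) eventually trips an "idle transforms" safeguard; temporarily raising the limit to the dataset size keeps short series from aborting loader construction. A hedged sketch of the pattern, assuming `env` and `maybe_len` live in `gluonts.env` and `gluonts.itertools` respectively:

from gluonts.env import env
from gluonts.itertools import maybe_len

# maybe_len(data) returns len(data) when the dataset supports len(),
# else None; the guard lifts the idle-transform limit for this block only.
# `self` stands in for the estimator, as in the examples above.
with env._let(max_idle_transforms=maybe_len(data) or 0):
    instance_splitter = self._create_instance_splitter("validation")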
Code Example #4
 def create_training_data_loader(
     self,
     data: Dataset,
     **kwargs,
 ) -> DataLoader:
     input_names = get_hybrid_forward_input_names(LSTNetTrain)
     instance_splitter = self._create_instance_splitter("training")
     return TrainDataLoader(
         dataset=data,
         transform=instance_splitter + SelectFields(input_names),
         batch_size=self.batch_size,
         stack_fn=partial(batchify, ctx=self.trainer.ctx, dtype=self.dtype),
         decode_fn=partial(as_in_context, ctx=self.trainer.ctx),
         **kwargs,
     )
Code Example #5
File: _estimator.py Project: pablo2909/gluon-ts
 def create_training_data_loader(
     self,
     data: Dataset,
     **kwargs,
 ) -> DataLoader:
     with env._let(max_idle_transforms=maybe_len(data) or 0):
         train_transform = (self._create_instance_splitter("training") +
                            self._create_post_split_transform() +
                            SelectFields(["past_target", "valid_length"]))
     return TrainDataLoader(
         train_transform.apply(Cyclic(data)),
         batch_size=self.batch_size,
         stack_fn=self._stack_fn(),
         decode_fn=partial(as_in_context, ctx=self.trainer.ctx),
     )
Code Example #6
 def create_validation_data_loader(
     self,
     data: Dataset,
     **kwargs,
 ) -> DataLoader:
     with env._let(max_idle_transforms=maybe_len(data) or 0):
         validation_transform = (
             self._create_instance_splitter("validation") +
             self._create_post_split_transform() +
             SelectFields(["past_target", "valid_length"]))
     return ValidationDataLoader(
         validation_transform.apply(data),
         batch_size=self.batch_size,
         stack_fn=self._stack_fn(),
     )
Code Example #7
File: _estimator.py Project: vishalbelsare/gluon-ts
 def create_training_data_loader(
     self,
     data: Dataset,
     **kwargs,
 ) -> DataLoader:
     input_names = get_hybrid_forward_input_names(DeepStateTrainingNetwork)
     with env._let(max_idle_transforms=maybe_len(data) or 0):
         instance_splitter = self._create_instance_splitter("training")
     return TrainDataLoader(
         dataset=data,
         transform=instance_splitter + SelectFields(input_names),
         batch_size=self.batch_size,
         stack_fn=partial(batchify, ctx=self.trainer.ctx, dtype=self.dtype),
         decode_fn=partial(as_in_context, ctx=self.trainer.ctx),
         **kwargs,
     )
Code Example #8
File: estimator.py Project: vishalbelsare/gluon-ts
    def create_validation_data_loader(
        self,
        data: Dataset,
        module: DeepARLightningModule,
        **kwargs,
    ) -> Iterable:
        transformation = self._create_instance_splitter(
            module, "validation") + SelectFields(TRAINING_INPUT_NAMES)

        validation_instances = transformation.apply(data)

        return DataLoader(
            IterableDataset(validation_instances),
            batch_size=self.batch_size,
            **kwargs,
        )
Code Example #9
File: _estimator.py Project: kaleming/gluon-ts
 def create_validation_data_loader(
     self,
     data: Dataset,
     **kwargs,
 ) -> DataLoader:
     validation_transform = (self._create_instance_splitter("validation") +
                             self._create_post_split_transform() +
                             SelectFields(["past_target", "valid_length"]))
     return DataLoader(
         data_iterable=TransformedDataset(data,
                                          validation_transform,
                                          is_train=True),
         batch_size=self.batch_size,
         stack_fn=self._stack_fn(),
         decode_fn=partial(as_in_context, ctx=self.trainer.ctx),
     )
Code Example #10
File: estimator.py Project: vishalbelsare/gluon-ts
    def create_training_data_loader(
        self,
        data: Dataset,
        module: DeepARLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
        transformation = self._create_instance_splitter(
            module, "training") + SelectFields(TRAINING_INPUT_NAMES)

        training_instances = transformation.apply(
            Cyclic(data) if shuffle_buffer_length is None else PseudoShuffled(
                Cyclic(data), shuffle_buffer_length=shuffle_buffer_length))

        return IterableSlice(
            iter(
                DataLoader(
                    IterableDataset(training_instances),
                    batch_size=self.batch_size,
                    **kwargs,
                )),
            self.num_batches_per_epoch,
        )
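Example #10 chains three iterator helpers: `Cyclic` repeats the dataset forever, `PseudoShuffled` shuffles it through a bounded buffer, and `IterableSlice` caps each pass at `num_batches_per_epoch` batches. A minimal sketch of the capping behavior, assuming these helpers live in `gluonts.itertools` as in recent GluonTS releases:

from gluonts.itertools import Cyclic, IterableSlice

stream = Cyclic([1, 2, 3])              # repeats 1, 2, 3, 1, 2, 3, ...
epoch = IterableSlice(iter(stream), 5)  # yields at most 5 items per pass

print(list(epoch))  # [1, 2, 3, 1, 2]
print(list(epoch))  # [3, 1, 2, 3, 1] -- the shared iterator resumes

This is how successive epochs draw fresh batches from a single infinite loader instead of restarting it.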
Code Example #11
    def train_model(
        self,
        training_data: Dataset,
        validation_data: Optional[Dataset] = None,
        num_workers: int = 0,
        prefetch_factor: int = 2,
        shuffle_buffer_length: Optional[int] = None,
        cache_data: bool = False,
        **kwargs,
    ) -> TrainOutput:
        transformation = self.create_transformation()

        trained_net = self.create_training_network(self.trainer.device)

        input_names = get_module_forward_input_names(trained_net)

        with env._let(max_idle_transforms=maybe_len(training_data) or 0):
            training_instance_splitter = self.create_instance_splitter(
                "training")
        training_iter_dataset = TransformedIterableDataset(
            dataset=training_data,
            transform=transformation + training_instance_splitter +
            SelectFields(input_names),
            is_train=True,
            shuffle_buffer_length=shuffle_buffer_length,
            cache_data=cache_data,
        )

        training_data_loader = DataLoader(
            training_iter_dataset,
            batch_size=self.trainer.batch_size,
            num_workers=num_workers,
            prefetch_factor=prefetch_factor,
            pin_memory=True,
            worker_init_fn=self._worker_init_fn,
            **kwargs,
        )

        validation_data_loader = None
        if validation_data is not None:
            with env._let(max_idle_transforms=maybe_len(validation_data) or 0):
                validation_instance_splitter = self.create_instance_splitter(
                    "validation")
            validation_iter_dataset = TransformedIterableDataset(
                dataset=validation_data,
                transform=transformation + validation_instance_splitter +
                SelectFields(input_names),
                is_train=True,
                cache_data=cache_data,
            )
            validation_data_loader = DataLoader(
                validation_iter_dataset,
                batch_size=self.trainer.batch_size,
                num_workers=num_workers,
                prefetch_factor=prefetch_factor,
                pin_memory=True,
                worker_init_fn=self._worker_init_fn,
                **kwargs,
            )

        self.trainer(
            net=trained_net,
            train_iter=training_data_loader,
            validation_iter=validation_data_loader,
        )

        return TrainOutput(
            transformation=transformation,
            trained_net=trained_net,
            predictor=self.create_predictor(transformation, trained_net,
                                            self.trainer.device),
        )
Code Example #12
    def train_model(
        self,
        training_data: Dataset,
        validation_data: Optional[Dataset] = None,
        num_workers: Optional[int] = None,
        num_prefetch: Optional[int] = None,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> TrainOutput:
        transformation = self.create_transformation()

        # ensure that the training network is created within the same MXNet
        # context as the one that will be used during training
        with self.trainer.ctx:
            trained_net = self.create_training_network()

        input_names = get_hybrid_forward_input_names(trained_net)

        training_data_loader = TrainDataLoader(
            dataset=training_data,
            transform=transformation + SelectFields(input_names),
            batch_size=self.batch_size,
            stack_fn=partial(
                batchify,
                ctx=self.trainer.ctx,
                dtype=self.dtype,
            ),
            num_workers=num_workers,
            num_prefetch=num_prefetch,
            shuffle_buffer_length=shuffle_buffer_length,
            decode_fn=partial(as_in_context, ctx=self.trainer.ctx),
            **kwargs,
        )

        validation_data_loader = None
        if validation_data is not None:
            validation_data_loader = ValidationDataLoader(
                dataset=validation_data,
                transform=transformation + SelectFields(input_names),
                batch_size=self.batch_size,
                stack_fn=partial(
                    batchify,
                    ctx=self.trainer.ctx,
                    dtype=self.dtype,
                ),
                num_workers=num_workers,
                num_prefetch=num_prefetch,
                **kwargs,
            )

        self.trainer(
            net=trained_net,
            train_iter=training_data_loader,
            validation_iter=validation_data_loader,
        )

        with self.trainer.ctx:
            # ensure that the prediction network is created within the same MXNet
            # context as the one that was used during training
            return TrainOutput(
                transformation=transformation,
                trained_net=trained_net,
                predictor=self.create_predictor(transformation, trained_net),
            )
Code Example #13
    def train(
        self,
        training_data: Dataset,
        validation_data: Optional[Dataset] = None,
        num_workers: Optional[int] = None,
        num_prefetch: Optional[int] = None,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Predictor:
        has_negative_data = any(np.any(d["target"] < 0) for d in training_data)
        low = -10.0 if has_negative_data else 0
        high = 10.0
        bin_centers = np.linspace(low, high, self.num_bins)
        bin_edges = np.concatenate(
            [[-1e20], (bin_centers[1:] + bin_centers[:-1]) / 2.0, [1e20]]
        )

        logging.info(
            f"using training windows of length = {self.train_window_length}"
        )

        transformation = self.create_transformation(
            bin_edges, pred_length=self.train_window_length
        )

        # ensure that the training network is created within the same MXNet
        # context as the one that will be used during training
        with self.trainer.ctx:
            params = self._get_wavenet_args(bin_centers)
            params.update(pred_length=self.train_window_length)
            trained_net = WaveNet(**params)

        input_names = get_hybrid_forward_input_names(trained_net)

        training_data_loader = TrainDataLoader(
            dataset=training_data,
            transform=transformation + SelectFields(input_names),
            batch_size=self.batch_size,
            stack_fn=partial(batchify, ctx=self.trainer.ctx, dtype=self.dtype),
            num_workers=num_workers,
            num_prefetch=num_prefetch,
            shuffle_buffer_length=shuffle_buffer_length,
            **kwargs,
        )

        validation_data_loader = None
        if validation_data is not None:
            validation_data_loader = ValidationDataLoader(
                dataset=validation_data,
                transform=transformation,
                batch_size=self.batch_size,
                stack_fn=partial(
                    batchify, ctx=self.trainer.ctx, dtype=self.dtype
                ),
                num_workers=num_workers,
                num_prefetch=num_prefetch,
                **kwargs,
            )

        self.trainer(
            net=trained_net,
            train_iter=training_data_loader,
            validation_iter=validation_data_loader,
        )

        # ensure that the prediction network is created within the same MXNet
        # context as the one that was used during training
        with self.trainer.ctx:
            return self.create_predictor(
                transformation, trained_net, bin_centers
            )
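All of the methods above are internal plumbing that an estimator's public `train` entry point wires together; user code normally just calls `train` and receives a `Predictor`. A minimal end-to-end sketch using the MXNet DeepAR estimator (the dataset and hyperparameters are illustrative, not taken from the examples above):

import pandas as pd
from gluonts.dataset.common import ListDataset
from gluonts.model.deepar import DeepAREstimator
from gluonts.mx.trainer import Trainer

# Toy dataset: a single hourly series of 200 points.
train_ds = ListDataset(
    [{"start": pd.Timestamp("2021-01-01"), "target": list(range(200))}],
    freq="H",
)

estimator = DeepAREstimator(
    freq="H",
    prediction_length=24,
    trainer=Trainer(epochs=1),
)

# train() builds the transformation and the data loaders shown above,
# runs the training loop, and returns a ready-to-use Predictor.
predictor = estimator.train(training_data=train_ds)

forecast = next(iter(predictor.predict(train_ds)))
print(forecast.mean.shape)  # (24,)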