def as_symbol_block_predictor(
    self,
    batch: Optional[DataBatch] = None,
    dataset: Optional[Dataset] = None,
) -> SymbolBlockPredictor:
    """Convert this predictor's hybrid network into a SymbolBlockPredictor.

    Parameters
    ----------
    batch
        Example input batch. The entries named in ``self.input_names`` are
        passed as ``data_batch`` to ``hybrid_block_to_symbol_block``. If
        omitted, the first batch produced from ``dataset`` is used instead.
    dataset
        Dataset to draw an example batch from when ``batch`` is None. It is
        loaded with this predictor's input transform, batch size, context
        and dtype.

    Returns
    -------
    SymbolBlockPredictor
        Predictor wrapping the converted network. It is configured with this
        predictor's input names, batch size, prediction length, freq, ctx,
        input transform, lead time, forecast generator, output transform
        and dtype.

    Raises
    ------
    ValueError
        If neither ``batch`` nor ``dataset`` is given, or if ``dataset``
        yields no batches.
    """
    if batch is None:
        if dataset is None:
            # Without this check, None would reach InferenceDataLoader and
            # fail far from the actual mistake.
            raise ValueError(
                "Either `batch` or `dataset` must be provided to build a "
                "SymbolBlockPredictor."
            )
        data_loader = InferenceDataLoader(
            dataset,
            transform=self.input_transform,
            batch_size=self.batch_size,
            stack_fn=partial(batchify, ctx=self.ctx, dtype=self.dtype),
        )
        try:
            batch = next(iter(data_loader))
        except StopIteration:
            # A bare StopIteration is confusing here, and inside generator
            # callers it would surface as RuntimeError (PEP 479).
            raise ValueError(
                "`dataset` produced no batches; cannot build a "
                "SymbolBlockPredictor from an empty dataset."
            ) from None

    # Perform the conversion inside this predictor's context (self.ctx).
    with self.ctx:
        symbol_block_net = hybrid_block_to_symbol_block(
            hb=self.prediction_net,
            data_batch=[batch[k] for k in self.input_names],
        )

    return SymbolBlockPredictor(
        input_names=self.input_names,
        prediction_net=symbol_block_net,
        batch_size=self.batch_size,
        prediction_length=self.prediction_length,
        freq=self.freq,
        ctx=self.ctx,
        input_transform=self.input_transform,
        lead_time=self.lead_time,
        forecast_generator=self.forecast_generator,
        output_transform=self.output_transform,
        dtype=self.dtype,
    )
def test_symb_block_export_import_nested_array(block_type, hybridize) -> None:
    """A block fed a flat array plus a list of arrays keeps its outputs
    after conversion with ``util.hybrid_block_to_symbol_block``."""
    flat_input = mx.nd.array([1, 2, 3])
    nested_inputs = [
        mx.nd.array([1, 5, 5]),
        mx.nd.array([2, 3, 3]),
    ]

    net = block_type()
    net.collect_params().initialize()
    if hybridize:
        net.hybridize()

    # One forward pass before conversion; presumably needed so a hybridized
    # block has run once -- NOTE(review): confirm it is required.
    net(flat_input, nested_inputs)

    symbol_net = util.hybrid_block_to_symbol_block(net, [flat_input, nested_inputs])

    converted_out = symbol_net(flat_input, nested_inputs).asnumpy()
    reference_out = net(flat_input, nested_inputs).asnumpy()
    assert np.allclose(converted_out, reference_out)
def as_symbol_block_predictor(self, batch: DataBatch) -> SymbolBlockPredictor:
    """Wrap this predictor's network, converted to a symbol block, in a
    SymbolBlockPredictor.

    Parameters
    ----------
    batch
        Example inputs. The entries named in ``self.input_names`` are
        handed to ``hybrid_block_to_symbol_block`` as ``data_batch``.

    Returns
    -------
    SymbolBlockPredictor
        New predictor around the converted network. Input names, batch
        size, prediction length, freq, ctx, input transform, forecast
        generator, output transform and float type are copied from ``self``.
    """
    net_to_convert = self.prediction_net
    example_inputs = [batch[name] for name in self.input_names]
    converted_net = hybrid_block_to_symbol_block(
        hb=net_to_convert,
        data_batch=example_inputs,
    )

    return SymbolBlockPredictor(
        input_names=self.input_names,
        prediction_net=converted_net,
        batch_size=self.batch_size,
        prediction_length=self.prediction_length,
        freq=self.freq,
        ctx=self.ctx,
        input_transform=self.input_transform,
        forecast_generator=self.forecast_generator,
        output_transform=self.output_transform,
        float_type=self.float_type,
    )