def as_symbol_block_predictor(
    self,
    batch: Optional[DataBatch] = None,
    dataset: Optional[Dataset] = None,
) -> SymbolBlockPredictor:
    """
    Return a ``SymbolBlockPredictor`` built from this predictor.

    The prediction network is handed to ``hybrid_block_to_symbol_block``
    together with an example batch; every other setting (input names,
    batch size, prediction length, frequency, context, transformations,
    lead time, forecast generator and dtype) is copied over unchanged.

    Parameters
    ----------
    batch
        Example batch, indexed by the entries of ``self.input_names``. When
        omitted, the first batch produced from ``dataset`` is used instead.
    dataset
        Dataset from which an example batch is drawn. Only consulted when
        ``batch`` is ``None``.

    Returns
    -------
    SymbolBlockPredictor
        Predictor wrapping the converted network.

    Raises
    ------
    ValueError
        If neither ``batch`` nor ``dataset`` is given, or if ``dataset``
        does not yield a single batch.
    """
    if batch is None:
        # Fail early with a clear message rather than letting ``None`` reach
        # the data loader and surface as an unrelated error.
        if dataset is None:
            raise ValueError(
                "as_symbol_block_predictor requires either `batch` or "
                "`dataset` to be provided."
            )

        data_loader = InferenceDataLoader(
            dataset,
            transform=self.input_transform,
            batch_size=self.batch_size,
            stack_fn=partial(batchify, ctx=self.ctx, dtype=self.dtype),
        )
        try:
            batch = next(iter(data_loader))
        except StopIteration:
            # A bare StopIteration escaping from here is confusing (and
            # becomes a RuntimeError if the caller is a generator).
            raise ValueError(
                "`dataset` did not yield any batch; cannot build an "
                "example input for the symbol block conversion."
            ) from None

    with self.ctx:
        symbol_block_net = hybrid_block_to_symbol_block(
            hb=self.prediction_net,
            data_batch=[batch[k] for k in self.input_names],
        )

    return SymbolBlockPredictor(
        input_names=self.input_names,
        prediction_net=symbol_block_net,
        batch_size=self.batch_size,
        prediction_length=self.prediction_length,
        freq=self.freq,
        ctx=self.ctx,
        input_transform=self.input_transform,
        lead_time=self.lead_time,
        forecast_generator=self.forecast_generator,
        output_transform=self.output_transform,
        dtype=self.dtype,
    )
def test_symb_block_export_import_nested_array(block_type, hybridize) -> None:
    """
    Check that a block taking a nested-list input survives conversion.

    An instance of ``block_type`` is initialized (and optionally hybridized),
    passed through ``hybrid_block_to_symbol_block`` and then compared against
    the original block on the same inputs.
    """
    flat_input = mx.nd.array([1, 2, 3])
    nested_input = [mx.nd.array([1, 5, 5]), mx.nd.array([2, 3, 3])]

    net = block_type()
    net.collect_params().initialize()
    if hybridize:
        net.hybridize()

    # One forward pass before the conversion, exactly as in the original test.
    net(flat_input, nested_input)

    exported = hybrid_block_to_symbol_block(net, [flat_input, nested_input])

    # Evaluate the converted block first, then the reference block.
    exported_out = exported(flat_input, nested_input).asnumpy()
    reference_out = net(flat_input, nested_input).asnumpy()
    assert np.allclose(exported_out, reference_out)