def get_WideDeep_net(config):
    """
    Build the wide&deep training and evaluation networks.

    Args:
        config: model configuration, forwarded to ``WideDeepModel`` and
            ``NetWithLossClass``.

    Returns:
        tuple: ``(train_net, eval_net)``.
            ``train_net`` is ``TrainStepWrap`` around the loss network, which
            is itself wrapped in ``VirtualDatasetCellTriple``.
            ``eval_net`` is ``PredictWithSigmoid`` wrapped in
            ``VirtualDatasetCellTriple``.
    """
    backbone = WideDeepModel(config)

    # Training path: backbone -> loss -> VirtualDatasetCellTriple -> train step.
    # Nested calls evaluate innermost first, so objects are built in the
    # same order as a step-by-step construction.
    train_net = TrainStepWrap(
        VirtualDatasetCellTriple(NetWithLossClass(backbone, config)))

    # Evaluation path reuses the same backbone instance as training.
    eval_net = VirtualDatasetCellTriple(PredictWithSigmoid(backbone))
    return train_net, eval_net
def get_WideDeep_net(config):
    """
    Get network of wide&deep model.

    Args:
        config: model configuration; forwarded to ``WideDeepModel`` and
            ``NetWithLossClass``, and its ``host_device_mix`` attribute is
            coerced with ``bool`` and passed to ``TrainStepWrap``.

    Returns:
        tuple: ``(train_net, eval_net)``. Both networks share one
            ``WideDeepModel`` instance, and each is wrapped in
            ``VirtualDatasetCellTriple``.
    """
    backbone = WideDeepModel(config)
    wrapped_loss = VirtualDatasetCellTriple(NetWithLossClass(backbone, config))
    # host_device_mix is read only after the loss network exists, matching
    # the original construction order.
    train_net = TrainStepWrap(
        wrapped_loss,
        host_device_mix=bool(config.host_device_mix),
    )

    eval_net = VirtualDatasetCellTriple(PredictWithSigmoid(backbone))
    return train_net, eval_net
def get_wide_deep_net(config):
    """
    Get network of wide&deep model.

    Args:
        config: model configuration. Besides being forwarded to
            ``WideDeepModel`` and ``NetWithLossClass``, its
            ``parameter_server`` and ``vocab_cache_size`` attributes are read
            here.

    Returns:
        tuple: ``(train_net, eval_net)``. When ``config.vocab_cache_size > 0``,
            the loss network and the evaluation network are each wrapped in
            ``VirtualDatasetCellTriple``; otherwise they are used unwrapped.
    """
    # The original body read ``cache_enable`` without defining it, relying on
    # a name from an outer scope (NameError if absent). Derive it from
    # ``config`` using the same condition already passed to TrainStepWrap, so
    # the wrapping decisions and the train-step flag cannot disagree.
    cache_enable = config.vocab_cache_size > 0

    wide_deep_net = WideDeepModel(config)
    loss_net = NetWithLossClass(wide_deep_net, config)
    if cache_enable:
        loss_net = VirtualDatasetCellTriple(loss_net)
    train_net = TrainStepWrap(loss_net,
                              parameter_server=bool(config.parameter_server),
                              cache_enable=cache_enable)
    eval_net = PredictWithSigmoid(wide_deep_net)
    if cache_enable:
        eval_net = VirtualDatasetCellTriple(eval_net)
    return train_net, eval_net
def test_virtualdataset_cell_3_inputs():
    """
    Compile a three-input network wrapped in VirtualDatasetCellTriple.

    The network computes ``matmul2(gelu(matmul1(x, y)), b)``. Every strategy
    is passed as ``None``, and compilation runs with
    ``parallel_mode="auto_parallel"``, ``device_num=8`` and ``global_rank=0``.
    """
    class ThreeInputNet(nn.Cell):
        def __init__(self, strategy0, strategy1, strategy2, strategy3):
            super().__init__()
            # strategy0 is accepted but not used by this cell.
            self.matmul1 = P.MatMul().set_strategy(strategy1)
            self.matmul2 = P.MatMul().set_strategy(strategy2)
            self.gelu = P.Gelu().set_strategy(strategy3)

        def construct(self, x, y, b):
            hidden = self.gelu(self.matmul1(x, y))
            return self.matmul2(hidden, b)

    # The wrapped net is built before any context is configured, as before.
    net = GradWrap(VirtualDatasetCellTriple(
        NetWithLoss(ThreeInputNet(None, None, None, None))))

    context.set_context(save_graphs=True)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    # Inputs: x (128x32), y (32x64), b (64x2048), all ones in float32.
    input_shapes = ([128, 32], [32, 64], [64, 2048])
    x, y, b = (Tensor(np.ones(shape), dtype=ms.float32) for shape in input_shapes)

    net.set_auto_parallel()
    _executor.compile(net, x, y, b)