def get_WideDeep_net(config):
    """
    Build the training and evaluation networks of the wide&deep model.

    A single WideDeepModel instance is created from ``config`` and shared by
    both returned networks.

    Args:
        config: Configuration object forwarded to WideDeepModel and
            NetWithLossClass.

    Returns:
        Tuple ``(train_net, eval_net)``: the loss network wrapped in
        VirtualDatasetCellTriple and then TrainStepWrap, and the
        PredictWithSigmoid network wrapped in VirtualDatasetCellTriple.
    """
    backbone = WideDeepModel(config)

    # Training side: backbone -> loss -> virtual dataset -> train step.
    train_net = TrainStepWrap(
        VirtualDatasetCellTriple(NetWithLossClass(backbone, config))
    )

    # Evaluation side reuses the very same backbone instance.
    eval_net = VirtualDatasetCellTriple(PredictWithSigmoid(backbone))

    return train_net, eval_net
def get_WideDeep_net(config):
    """
    Build the training and evaluation networks of the wide&deep model.

    A single WideDeepModel instance is created from ``config`` and shared by
    both returned networks.

    Args:
        config: Configuration object forwarded to WideDeepModel and
            NetWithLossClass; its ``host_device_mix`` attribute is coerced to
            bool and passed to TrainStepWrap.

    Returns:
        Tuple ``(train_net, eval_net)``.
    """
    backbone = WideDeepModel(config)

    # Training side: backbone -> loss -> virtual dataset -> train step.
    # The positional argument is evaluated before the keyword, so the loss
    # network is still built before ``config.host_device_mix`` is read.
    train_net = TrainStepWrap(
        VirtualDatasetCellTriple(NetWithLossClass(backbone, config)),
        host_device_mix=bool(config.host_device_mix),
    )

    # Evaluation side reuses the very same backbone instance.
    eval_net = VirtualDatasetCellTriple(PredictWithSigmoid(backbone))

    return train_net, eval_net
Пример #3
0
def get_wide_deep_net(config):
    """
    Get network of wide&deep model.

    Args:
        config: Configuration object forwarded to WideDeepModel and
            NetWithLossClass. This function additionally reads
            ``config.parameter_server`` (coerced to bool) and
            ``config.vocab_cache_size``.

    Returns:
        Tuple ``(train_net, eval_net)``. Both are built on the same
        WideDeepModel instance. When ``config.vocab_cache_size > 0`` the loss
        network and the eval network are each wrapped in
        VirtualDatasetCellTriple.
    """
    # Derive the flag from config instead of relying on a name defined outside
    # this function: previously ``cache_enable`` was neither a parameter nor a
    # local, so the function raised NameError unless a global of that name
    # existed. This is the same expression already passed to TrainStepWrap.
    # NOTE(review): if a module-level ``cache_enable`` existed with a different
    # definition, confirm this matches the intended value.
    cache_enable = config.vocab_cache_size > 0

    wide_deep_net = WideDeepModel(config)
    loss_net = NetWithLossClass(wide_deep_net, config)
    if cache_enable:
        loss_net = VirtualDatasetCellTriple(loss_net)
    train_net = TrainStepWrap(loss_net,
                              parameter_server=bool(config.parameter_server),
                              cache_enable=cache_enable)
    eval_net = PredictWithSigmoid(wide_deep_net)
    if cache_enable:
        eval_net = VirtualDatasetCellTriple(eval_net)
    return train_net, eval_net
Пример #4
0
def test_virtualdataset_cell_3_inputs():
    """
    Compile a three-input network wrapped in VirtualDatasetCellTriple.

    The network computes matmul2(gelu(matmul1(x, y)), b), is wrapped in
    NetWithLoss, VirtualDatasetCellTriple and GradWrap, and is compiled with
    parallel_mode="auto_parallel" on 8 devices (global rank 0). There are no
    explicit assertions; the test succeeds if compilation does not raise.
    """
    class Net(nn.Cell):
        """MatMul -> Gelu -> MatMul cell taking three inputs."""

        def __init__(self, strategy0, strategy1, strategy2, strategy3):
            # strategy0 is accepted but not used by this cell.
            super().__init__()
            self.matmul1 = P.MatMul().set_strategy(strategy1)
            self.matmul2 = P.MatMul().set_strategy(strategy2)
            self.gelu = P.Gelu().set_strategy(strategy3)

        def construct(self, x, y, b):
            hidden = self.matmul1(x, y)
            activated = self.gelu(hidden)
            return self.matmul2(activated, b)

    # All strategies are None here; build the wrapped network before touching
    # the global context, as the original ordering does.
    backbone = Net(None, None, None, None)
    with_virtual_dataset = VirtualDatasetCellTriple(NetWithLoss(backbone))
    net = GradWrap(with_virtual_dataset)

    context.set_context(save_graphs=True)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    context.set_auto_parallel_context(device_num=8, global_rank=0)

    # Shapes chain as [128, 32] @ [32, 64] -> [128, 64] @ [64, 2048].
    lhs = Tensor(np.ones([128, 32]), dtype=ms.float32)
    rhs = Tensor(np.ones([32, 64]), dtype=ms.float32)
    proj = Tensor(np.ones([64, 2048]), dtype=ms.float32)

    net.set_auto_parallel()
    _executor.compile(net, lhs, rhs, proj)