Example #1
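A unit test for auto-parallel strategy search: it compiles a network containing two subgraphs on 8 devices in auto_parallel mode and checks the sharding strategy chosen for every operator.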
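# Imports assumed for this excerpt (paths follow the MindSpore parallel unit
# tests and may differ between versions); Net, NetWithLoss, TrainStepWarp and
# set_multi_subgraphs are helpers defined or imported elsewhere in the test module.
import numpy as np
import mindspore as ms
from mindspore import Tensor, context
from mindspore.common.api import _executor
from mindspore.parallel._utils import _reset_op_id as reset_op_id
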
def test_double_subgraphs():
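    # Allow multiple parallel subgraphs and dump the compiled graphs (save_graphs=True).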
    set_multi_subgraphs()
    context.set_context(save_graphs=True)
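    # Search operator strategies automatically across 8 devices; this process is global rank 0.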
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net = TrainStepWarp(NetWithLoss(Net()))
    net.set_auto_parallel()

    x = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32)
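    # reset_op_id() restarts the operator id counter (op0, op1, ...), then the graph
    # is compiled and the sharding strategies chosen by auto-parallel are read back.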
    reset_op_id()
    _executor.compile(net, x, phase='train')
    strategies = _executor._get_strategy(net)
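    # Every operator is expected to be split 8 ways along its first dimension.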
    expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op0': [[8, 1, 1, 1]],
                           'Default/network-NetWithLoss/net-Net/ReLU-op1': [[8, 1, 1, 1]],
                           'Default/network-NetWithLoss/net-Net/Mul-op2': [[8, 1, 1, 1], [8, 1, 1, 1]],
                           'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]],
                           'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]]}
    assert strategies == expected_strategies
Example #2
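The training entry point of the Wide&Deep script. The excerpt opens mid-statement, inside the call (presumably ModelCheckpoint) that builds ckpoint_cb; it then wires up the callbacks and, in the __main__ block, configures graph mode, memory, sparsity and multi-subgraph support before picking between semi-auto and fully automatic parallelism.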
                                 config=ckptconfig)
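    # Persist the parallel strategy chosen for each operator to the strategy checkpoint file.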
    context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt)
    callback_list = [
        TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback
    ]
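    # In host-device mixed mode the checkpoint callback is skipped and dataset sinking is disabled.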
    if not host_device_mix:
        callback_list.append(ckpoint_cb)
    model.train(epochs,
                ds_train,
                callbacks=callback_list,
                dataset_sink_mode=(not host_device_mix))


if __name__ == "__main__":
    wide_deep_config = WideDeepConfig()
    wide_deep_config.argparse_init()
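    # Graph mode with IR dumps enabled; reserve up to 24 GB of variable memory and enable sparse tensors.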
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=wide_deep_config.device_target,
                        save_graphs=True)
    context.set_context(variable_memory_max_size="24GB")
    context.set_context(enable_sparse=True)
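    # Allow multiple parallel subgraphs and initialize the distributed communication backend.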
    set_multi_subgraphs()
    init()
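    # Host-device mixed training uses semi-auto parallel; otherwise strategies are
    # searched fully automatically. mirror_mean=True averages gradients across devices.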
    if wide_deep_config.host_device_mix == 1:
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
    else:
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
    train_and_eval(wide_deep_config)