示例#1
0
    def step_end(self, run_context):
        """
        Save the checkpoint at the end of step.

        On the first call (tracked by ``self._graph_saved``) the training
        network's graph is exported to ``<directory>/<prefix>-graph.meta``.
        In graph mode, an existing file at that path is removed before the
        export. Any thread named ``"asyn_save_ckpt"`` is joined before
        ``self._save_ckpt`` runs (presumably so an earlier asynchronous save
        completes first).

        When running as a parameter server, ``self._prefix`` is tagged with
        ``"PServer_<rank>_"``. The tag is applied only once: this method runs
        every step, and prepending unconditionally would grow the prefix on
        every call.

        Args:
            run_context (RunContext): Context of the train running.
        """
        if _is_role_pserver():
            ps_tag = "PServer_" + str(_get_ps_mode_rank()) + "_"
            # step_end is called every step; only prepend the tag once.
            if not self._prefix.startswith(ps_tag):
                self._prefix = ps_tag + self._prefix
        cb_params = run_context.original_args()
        _make_directory(self._directory)
        # save graph (only once)
        if not self._graph_saved:
            graph_file_name = os.path.join(self._directory,
                                           self._prefix + '-graph.meta')
            if os.path.isfile(graph_file_name) and context.get_context(
                    "mode") == context.GRAPH_MODE:
                os.remove(graph_file_name)
            _save_graph(cb_params.train_network, graph_file_name)
            self._graph_saved = True
        # Wait for any running "asyn_save_ckpt" thread before saving again.
        # ``Thread.name`` replaces ``getName()`` (deprecated since Python 3.10).
        for thread in threading.enumerate():
            if thread.name == "asyn_save_ckpt":
                thread.join()
        self._save_ckpt(cb_params)
示例#2
0
    def step_end(self, run_context):
        """
        Save the checkpoint at the end of step.

        The first invocation also exports the training network's graph to
        ``<directory>/<prefix>-graph.meta`` and records that in
        ``self._graph_saved``. Every invocation ends by handing the callback
        parameters to ``self._save_ckpt``.

        Args:
            run_context (RunContext): Context of the train running.
        """
        params = run_context.original_args()
        graph_pending = not self._graph_saved
        if graph_pending:
            # The graph is exported a single time per callback instance.
            meta_path = os.path.join(self._directory, self._prefix + '-graph.meta')
            _save_graph(params.train_network, meta_path)
            self._graph_saved = True
        self._save_ckpt(params)
示例#3
0
    def step_end(self, run_context):
        """
        Save the checkpoint at the end of step.

        On the first call (tracked by ``self._graph_saved``) the training
        network's graph is exported to ``<directory>/<prefix>-graph.meta``;
        every call then passes the callback parameters to ``self._save_ckpt``.

        When running as a parameter server, ``self._prefix`` is tagged with
        ``"PServer_<rank>_"``. The tag is applied only once: this method runs
        every step, and prepending unconditionally would grow the prefix on
        every call.

        Args:
            run_context (RunContext): Context of the train running.
        """
        if _is_role_pserver():
            ps_tag = "PServer_" + str(_get_ps_mode_rank()) + "_"
            # step_end is called every step; only prepend the tag once.
            if not self._prefix.startswith(ps_tag):
                self._prefix = ps_tag + self._prefix
        cb_params = run_context.original_args()
        # save graph (only once)
        if not self._graph_saved:
            graph_file_name = os.path.join(self._directory,
                                           self._prefix + '-graph.meta')
            _save_graph(cb_params.train_network, graph_file_name)
            self._graph_saved = True
        self._save_ckpt(cb_params)
示例#4
0
def test_save_graph():
    """ test_exec_save_graph

    Run a single-op ``TensorAdd`` network forward once, export its graph with
    ``_save_graph`` and check that the ``.meta`` file was written.

    Any pre-existing output file is deleted before the export, so a leftover
    from an earlier run cannot satisfy the existence check. The file is
    deleted again in a ``finally`` block, so the working directory is left
    clean even when the check fails.
    """
    import os
    import stat

    class Net1(nn.Cell):
        def __init__(self):
            super(Net1, self).__init__()
            self.add = P.TensorAdd()

        def construct(self, x, y):
            z = self.add(x, y)
            return z

    def _remove_if_present(path):
        # Make the file owner-writable before deleting it, in case the saved
        # graph file is not writable (TODO confirm what _save_graph sets).
        if os.path.exists(path):
            os.chmod(path, stat.S_IWRITE)
            os.remove(path)

    output_file = "net-graph.meta"
    _remove_if_present(output_file)
    net = Net1()
    net.set_train()
    x = Tensor(np.random.rand(2, 1, 2, 3).astype(np.float32))
    y = Tensor(np.array([1.2]).astype(np.float32))
    # Run the network once before exporting its graph.
    net(x, y)
    try:
        _save_graph(network=net, file_name=output_file)
        assert os.path.exists(output_file)
    finally:
        _remove_if_present(output_file)
示例#5
0
def test_save_graph():
    """ test_exec_save_graph

    Run a single-op ``Add`` network forward once, export its graph with
    ``_save_graph`` and check that the ``.meta`` file was written.

    Any pre-existing output file is deleted before the export, so a leftover
    from an earlier failed run cannot make the existence check pass
    spuriously. The file is deleted again in a ``finally`` block, so the
    working directory is left clean even when the check fails.
    """
    class Net1(nn.Cell):
        def __init__(self):
            super(Net1, self).__init__()
            self.add = P.Add()

        def construct(self, x, y):
            z = self.add(x, y)
            return z

    def _remove_if_present(path):
        # Make the file owner-writable before deleting it, in case the saved
        # graph file is not writable (TODO confirm what _save_graph sets).
        if os.path.exists(path):
            os.chmod(path, stat.S_IWRITE)
            os.remove(path)

    output_file = "net-graph.meta"
    _remove_if_present(output_file)
    net = Net1()
    net.set_train()
    x = Tensor(np.random.rand(2, 1, 2, 3).astype(np.float32))
    y = Tensor(np.array([1.2]).astype(np.float32))
    # Run the network once before exporting its graph.
    net(x, y)
    try:
        _save_graph(network=net, file_name=output_file)
        assert os.path.exists(output_file)
    finally:
        _remove_if_present(output_file)