Exemple #1
0
def test_step_end_save_graph():
    """Check that only the first step_end call writes the graph meta file."""
    graph_file = './test_files/test-graph.meta'
    ckpt_config = CheckpointConfig(save_checkpoint_steps=16,
                                   save_checkpoint_seconds=0,
                                   keep_checkpoint_max=5,
                                   keep_checkpoint_per_n_minutes=0)
    cb_params = _InternalCallbackParam()
    loss_net = LossNet()
    # Run the network once on random data before handing it to the callback.
    data = Tensor(
        np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
    label = Tensor(np.random.randint(0, 3, [1, 3]).astype(np.float32))
    loss_net(data, label)

    cb_params.train_network = loss_net
    cb_params.epoch_num = 10
    cb_params.cur_epoch_num = 5
    cb_params.cur_step_num = 0
    cb_params.batch_num = 32

    checkpoint_cb = ModelCheckpoint(prefix="test",
                                    directory='./test_files',
                                    config=ckpt_config)
    run_context = RunContext(cb_params)
    checkpoint_cb.begin(run_context)

    # First step_end: the graph file is expected to appear.
    checkpoint_cb.step_end(run_context)
    assert os.path.exists(graph_file)

    # Remove the file (making it writable first) so a re-save would be visible.
    if os.path.exists(graph_file):
        os.chmod(graph_file, stat.S_IWRITE)
        os.remove(graph_file)

    # Second step_end: the graph file must not be written again.
    checkpoint_cb.step_end(run_context)
    assert not os.path.exists(graph_file)
Exemple #2
0
def test_save_checkpoint():
    """Run ModelCheckpoint begin/step_end on a one-step training cell and clean up."""
    model_file = './test_files/test_ckpt-model.pkl'
    ckpt_config = CheckpointConfig(save_checkpoint_steps=16,
                                   save_checkpoint_seconds=0,
                                   keep_checkpoint_max=5,
                                   keep_checkpoint_per_n_minutes=0)
    cb_params = _InternalCallbackParam()

    # Assemble net -> loss wrapper -> single-step training cell.
    net = Net()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_cell = TrainOneStepCell(WithLossCell(net, loss_fn), optimizer)

    cb_params.train_network = train_cell
    cb_params.epoch_num = 10
    cb_params.cur_epoch_num = 5
    cb_params.cur_step_num = 0
    cb_params.batch_num = 32

    checkpoint_cb = ModelCheckpoint(prefix="test_ckpt",
                                    directory='./test_files',
                                    config=ckpt_config)
    run_context = RunContext(cb_params)
    checkpoint_cb.begin(run_context)
    checkpoint_cb.step_end(run_context)

    # Remove the model file if one was produced, making it writable first.
    if os.path.exists(model_file):
        os.chmod(model_file, stat.S_IWRITE)
        os.remove(model_file)
Exemple #3
0
def test_checkpoint_save_ckpt_seconds():
    """Drive two ModelCheckpoint callbacks that share a save_checkpoint_seconds config."""
    ckpt_config = CheckpointConfig(save_checkpoint_steps=16,
                                   save_checkpoint_seconds=100,
                                   keep_checkpoint_max=0,
                                   keep_checkpoint_per_n_minutes=1)
    first_cb = ModelCheckpoint(config=ckpt_config)
    cb_params = _InternalCallbackParam()

    # Assemble net -> loss wrapper -> single-step training cell.
    net = Net()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_cell = TrainOneStepCell(WithLossCell(net, loss_fn), optimizer)

    cb_params.train_network = train_cell
    cb_params.epoch_num = 10
    cb_params.cur_epoch_num = 4
    cb_params.cur_step_num = 128
    cb_params.batch_num = 32
    run_context = RunContext(cb_params)
    first_cb.begin(run_context)
    first_cb.step_end(run_context)

    # A second callback reuses the same run context at an earlier epoch/step.
    second_cb = ModelCheckpoint(config=ckpt_config)
    cb_params.cur_epoch_num = 1
    cb_params.cur_step_num = 16
    second_cb.begin(run_context)
    second_cb.step_end(run_context)
Exemple #4
0
def test_internal_callback_param():
    """Attributes assigned on _InternalCallbackParam can be read back unchanged."""
    int_value, str_value = 1, "abc"
    params = _InternalCallbackParam()
    params.member1 = int_value
    params.member2 = str_value
    assert params.member1 == int_value
    assert params.member2 == str_value
Exemple #5
0
    def _train(self,
               epoch,
               train_dataset,
               callbacks=None,
               dataset_sink_mode=True):
        """
        Training.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                                     loss_fn, a tuple with multiple data (data1, data2, data3, ...) will be
                                     returned and passed to the network. Otherwise, a tuple (data, label) will
                                     be returned, and the data and label are passed to the network and loss
                                     function respectively.
            callbacks (list): List of callback objects which should be executed while training. Default: None.
            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
                                      In pynative mode the training process falls back to running
                                      without dataset sink.
        """
        epoch = check_int_positive(epoch)
        self._train_network.set_train()

        if self._parameter_broadcast:
            self._train_network.set_broadcast_flag()

        # Collect the parameters that are handed to every callback.
        cb_params = _InternalCallbackParam()
        cb_params.train_network = self._train_network
        cb_params.epoch_num = epoch
        cb_params.batch_num = train_dataset.get_dataset_size()
        cb_params.mode = "train"
        cb_params.loss_fn = self._loss_fn
        cb_params.optimizer = self._optimizer
        cb_params.parallel_mode = self._parallel_mode
        cb_params.device_number = self._device_number
        cb_params.train_dataset = train_dataset
        cb_params.list_callback = callbacks

        with _CallbackManager(callbacks) as list_callback:
            use_sink = dataset_sink_mode
            # The execution mode is only queried when sink mode was requested.
            if use_sink and context.get_context("mode") == context.PYNATIVE_MODE:
                # Adjacent literals are concatenated verbatim, so the first one
                # must end with a space to keep the two sentences apart.
                logger.warning(
                    "The pynative mode cannot support dataset sink mode currently. "
                    "So the training process will be performed with dataset not sink."
                )
                use_sink = False

            if use_sink:
                self._train_dataset_sink_process(epoch, train_dataset,
                                                 list_callback, cb_params)
            else:
                self._train_process(epoch, train_dataset, list_callback,
                                    cb_params)
Exemple #6
0
    def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True):
        """
        Evaluation API where the iteration is controlled by python front-end.

        Configure to pynative mode, the evaluation will be performed with dataset non-sink mode.

        Note:
            CPU is not supported when dataset_sink_mode is true.
            If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features
            of data will be transferred one by one. The limitation of data transmission per time is 256M.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            callbacks (list): List of callback objects which should be executed
                              during evaluation. Default: None.
            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.

        Returns:
            Dict, returns the loss value & metrics values for the model in test mode.

        Raises:
            ValueError: If no metric functions are configured on the model.

        Examples:
            >>> dataset = get_dataset()
            >>> net = Net()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
            >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
            >>> model.eval(dataset)
        """
        check_bool(dataset_sink_mode)
        _device_number_check(self._parallel_mode, self._device_number)
        if not self._metric_fns:
            raise ValueError("metric fn can not be None or empty.")

        # Parameters shared with every callback during this evaluation run.
        cb_params = _InternalCallbackParam()
        cb_params.eval_network = self._eval_network
        cb_params.valid_dataset = valid_dataset
        cb_params.batch_num = valid_dataset.get_dataset_size()
        cb_params.mode = "eval"
        cb_params.cur_step_num = 0

        # Switch the evaluation network out of training mode.
        eval_net = self._eval_network
        eval_net.set_train(mode=False)
        eval_net.phase = 'eval'

        self._clear_metrics()

        with _CallbackManager(callbacks) as list_callback:
            # Both processing paths take the same arguments; pick one by sink mode.
            run_eval = (self._eval_dataset_sink_process
                        if dataset_sink_mode else self._eval_process)
            return run_eval(valid_dataset, list_callback, cb_params)
Exemple #7
0
def test_loss_monitor_normal_mode():
    """Call every LossMonitor hook directly (non-sink mode) for a single step."""
    params = _InternalCallbackParam()
    run_context = RunContext(params)
    monitor = LossMonitor(1)
    params.cur_epoch_num = 4
    params.cur_step_num = 1
    params.batch_num = 1
    params.net_outputs = Tensor(2.0)

    # Invoke the hooks in the order a training loop would.
    hooks = (
        monitor.begin,
        monitor.epoch_begin,
        monitor.step_begin,
        monitor.step_end,
        monitor.epoch_end,
        monitor.end,
    )
    for hook in hooks:
        hook(run_context)
Exemple #8
0
def test_loss_monitor_sink_mode():
    """Drive LossMonitor hooks through _CallbackManager (sink mode)."""
    params = _InternalCallbackParam()
    params.cur_epoch_num = 4
    params.cur_step_num = 2
    params.batch_num = 2
    params.net_outputs = Tensor(2.0)
    run_context = RunContext(params)
    monitor = LossMonitor(1)

    # Hook names listed in the order a training loop would trigger them.
    hook_names = ("begin", "epoch_begin", "step_begin",
                  "step_end", "epoch_end", "end")
    with _CallbackManager([monitor]) as manager:
        for hook_name in hook_names:
            getattr(manager, hook_name)(run_context)
Exemple #9
0
def test_RunContext():
    """RunContext rejects a non-param argument and records a stop request."""
    # An int is not an acceptable constructor argument.
    with pytest.raises(TypeError):
        RunContext(666)

    params = _InternalCallbackParam()
    params.member1 = 1
    params.member2 = "abc"

    ctx = RunContext(params)
    ctx.original_args()
    # The parameters are still intact after original_args() was called.
    assert params.member1 == 1
    assert params.member2 == "abc"

    ctx.request_stop()
    assert ctx.get_stop_requested()