Пример #1
0
    def _build_eval_network(self, metrics, eval_network, eval_indexes):
        """Build the network used for evaluation.

        ``metrics`` is resolved through ``get_metrics`` and stored in
        ``self._metric_fns``. When that result is empty the method returns
        immediately and neither ``self._eval_network`` nor
        ``self._eval_indexes`` is assigned here.

        Args:
            metrics: Metric specification forwarded to ``get_metrics``.
            eval_network: Caller-supplied evaluation network, or None to wrap
                ``self._network`` and ``self._loss_fn`` in ``nn.WithEvalCell``.
            eval_indexes: None, or a list of exactly three items. Only
                validated when ``eval_network`` is supplied.

        Raises:
            ValueError: If ``eval_network`` is given and ``eval_indexes`` is
                neither None nor a three-element list, or if ``eval_network``
                is None and ``self._loss_fn`` is None.
        """
        self._metric_fns = get_metrics(metrics)
        if not self._metric_fns:
            return

        if eval_network is not None:
            if eval_indexes is not None and not (isinstance(
                    eval_indexes, list) and len(eval_indexes) == 3):
                # Adjacent string literals are used instead of a backslash
                # continuation inside the literal: the continuation kept the
                # next line's indentation as part of the error message.
                raise ValueError(
                    "Eval_indexes must be a list or None. If eval_indexes is "
                    "a list, length of it must be three. But got {}".format(
                        eval_indexes))

            self._eval_network = eval_network
            self._eval_indexes = eval_indexes
        else:
            if self._loss_fn is None:
                raise ValueError("loss_fn can not be None.")
            # The third argument is True only when the amp level is "O2".
            # NOTE(review): presumably this toggles an fp32 cast inside
            # WithEvalCell -- confirm against the WithEvalCell signature.
            self._eval_network = nn.WithEvalCell(self._network, self._loss_fn,
                                                 self._amp_level == "O2")
            self._eval_indexes = [0, 1, 2]

        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
                                   ParallelMode.AUTO_PARALLEL):
            # The virtual-dataset wrapper is applied only when an optimizer is
            # set; set_auto_parallel() is called in either case.
            if self._optimizer:
                self._eval_network = _VirtualDatasetCell(self._eval_network)
            self._eval_network.set_auto_parallel()
Пример #2
0
def test_scalar_output_auto():
    """Check the strategy reported for the VirtualOutput op in auto-parallel mode.

    Compiles a ``WithEvalCell``-wrapped ``ParallelMulNet`` with an 8-device
    auto-parallel context and expects exactly one strategy entry whose key
    matches ``'VirtualOutput-op'``, with ``8`` as its ``[0][0]`` element.
    """
    context.set_auto_parallel_context(
        device_num=8,
        global_rank=0,
        parallel_mode="auto_parallel",
        full_batch=False,
    )

    network = ParallelMulNet()
    criterion = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
    eval_cell = nn.WithEvalCell(network, criterion)

    inputs = Tensor(np.ones([4096, 1, 2, 1024]).astype(np.float32) * 0.01)
    labels = Tensor(np.ones([4096, 250]).astype(np.float32) * 0.01)
    strategies = compile_graph_two_input(inputs, labels, eval_cell)

    # Collect the strategies of every op whose name mentions VirtualOutput.
    virtual_output_strategies = [
        strategy
        for op_name, strategy in strategies.items()
        if re.search('VirtualOutput-op', op_name)
    ]
    for strategy in virtual_output_strategies:
        assert strategy[0][0] == 8
    assert len(virtual_output_strategies) == 1
Пример #3
0
    def _build_eval_network(self, metrics, eval_network, eval_indexes):
        """Build the network used for evaluation.

        ``metrics`` is resolved through ``get_metrics`` and stored in
        ``self._metric_fns``. When that result is empty the method returns
        immediately and neither ``self._eval_network`` nor
        ``self._eval_indexes`` is assigned here.

        Args:
            metrics: Metric specification forwarded to ``get_metrics``.
            eval_network: Caller-supplied evaluation network, or None to wrap
                ``self._network`` and ``self._loss_fn`` in ``nn.WithEvalCell``.
            eval_indexes: None, or a list of exactly three items. Only
                validated when ``eval_network`` is supplied.

        Raises:
            ValueError: If ``eval_network`` is given and ``eval_indexes`` is
                neither None nor a three-element list, or if ``eval_network``
                is None and ``self._loss_fn`` is None.
        """
        self._metric_fns = get_metrics(metrics)
        if not self._metric_fns:
            return

        if eval_network is not None:
            if eval_indexes is not None and not (isinstance(
                    eval_indexes, list) and len(eval_indexes) == 3):
                # Adjacent string literals are used instead of a backslash
                # continuation inside the literal: the continuation kept the
                # next line's indentation as part of the error message.
                raise ValueError(
                    "Eval_indexes must be a list or None. If eval_indexes is "
                    "a list, length of it must be three. But got {}".format(
                        eval_indexes))

            self._eval_network = eval_network
            self._eval_indexes = eval_indexes
        else:
            if self._loss_fn is None:
                raise ValueError("loss_fn can not be None.")
            # Default evaluation wrapper around the training network and loss;
            # all three default indexes are used with it.
            self._eval_network = nn.WithEvalCell(self._network, self._loss_fn)
            self._eval_indexes = [0, 1, 2]
Пример #4
0
    # Build the learning-rate values with get_lr using the 'poly' decay mode,
    # then wrap the result in a Tensor for the optimizer.
    # NOTE(review): lr_init, lr_end, warmup_epochs, epoch_size and step_size
    # are defined outside this fragment.
    lr_max = 0.1
    lr = get_lr(lr_init=lr_init, lr_end=lr_end, lr_max=lr_max,
                warmup_epochs=warmup_epochs, total_epochs=epoch_size, steps_per_epoch=step_size,
                lr_decay_mode='poly')
    lr = Tensor(lr)

    # Optimizer hyper-parameters.
    # NOTE(review): loss_scale and weight_decay are assigned here but are not
    # passed to anything in the visible lines -- confirm whether they are used
    # further down or are leftovers.
    loss_scale = 1024
    momentum = 0.9
    weight_decay = 1e-4

    # Loss function and optimizer; only parameters with requires_grad set are
    # handed to Momentum.
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, momentum)
    
    # Explicit evaluation cell; its third argument is True when AMP_LEVEL is
    # "O2" or "O3". The same cell is passed to Model together with
    # eval_indexes=[0, 1, 2].
    eval_net = nn.WithEvalCell(net, loss, AMP_LEVEL in ["O2", "O3"])
    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},amp_level=AMP_LEVEL, eval_network=eval_net, 
              eval_indexes=[0, 1, 2], keep_batchnorm_fp32=False)

    # Callbacks: timing and loss monitoring are always registered.
    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossMonitor()

    cb = [time_cb, loss_cb]
    # save_checkpoint is a truthy constant, so the checkpoint branch below
    # always runs.
    save_checkpoint = 5
    if save_checkpoint:
        # Save a checkpoint every save_checkpoint_epochs * step_size steps.
        save_checkpoint_epochs = 5
        keep_checkpoint_max = 10
        config_ck = CheckpointConfig(save_checkpoint_steps=save_checkpoint_epochs * step_size,
                                     keep_checkpoint_max=keep_checkpoint_max)
        # NOTE(review): ckpt_cb is not appended to cb within the visible
        # lines -- presumably that happens just after this fragment; verify.
        ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)