def common_check_2(self,
                   name,
                   fct,
                   weight_name=None,
                   verbose=0,
                   classification=False,
                   rnd=True,
                   second_name='Y_grad',
                   **kwargs):
    """
    Builds the ONNX graph produced by *function_onnx_graph* for the
    loss function *name* and checks that its loss and gradient outputs
    match the reference implementation *fct*, first with the Python
    runtime (:epkg:`OnnxInference`), then with :epkg:`onnxruntime`,
    and finally on the unreduced (per-sample) version of the graph.

    :param name: name of the loss function graph to build
    :param fct: reference implementation returning ``(loss, gradient)``
    :param weight_name: if not None, the graph takes an additional
        per-sample weight input named ``'weight'``
    :param verbose: if > 0, dumps the graphs to disk and prints traces
    :param classification: if True, inputs mimic a binary
        classification problem (labels, raw scores) instead of two
        continuous vectors
    :param rnd: if True, inputs are random, otherwise deterministic
    :param second_name: name of the gradient output in the graph
    :param kwargs: additional arguments sent to *function_onnx_graph*
    """
    onx = function_onnx_graph(name,
                              target_opset=get_max_opset(),
                              dtype=numpy.float32,
                              weight_name=weight_name,
                              **kwargs)
    if verbose > 0:
        with open(name + ".onnx", "wb") as f:
            f.write(onx.SerializeToString())

    # Build the two inputs X1, X2.
    if classification:
        # NOTE(review): this branch ignores *rnd* and always draws
        # random scores and labels -- confirm this is intended.
        N = 10
        p = numpy.random.randn(N, 1).astype(numpy.float32)
        # Force a few boundary / extreme scores.
        p[0, :] = 0
        p[1, :] = 100
        p[2, :] = -100
        p[3, :] = 1
        p[4, :] = -1
        y = (numpy.random.randn(N, 1).astype(numpy.float32) > 0).astype(
            numpy.int64)
        x2 = p
        x1 = y
    else:
        if rnd:
            x1 = numpy.random.randn(10, 1).astype(numpy.float32)
            x2 = numpy.random.randn(10, 1).astype(numpy.float32)
        else:
            x1 = numpy.zeros((10, 1), dtype=numpy.float32)
            x2 = numpy.zeros((10, 1), dtype=numpy.float32) + 1
    if rnd:
        w = numpy.random.rand(10).astype(numpy.float32)
    else:
        w = numpy.zeros(10, dtype=numpy.float32) + 0.2

    # Expected loss and gradient from the reference implementation.
    if weight_name is None:
        exp_loss, exp_grad = fct(x1, x2)
    else:
        exp_loss, exp_grad = fct(x1, x2, w.reshape((-1, 1)))

    # 1. Check with the Python runtime.
    oinf = OnnxInference(onx)
    run_params = dict(verbose=verbose, fLOG=print) if verbose > 0 else {}
    if verbose > 0:
        print(f"\n+++++ name(1)={name!r}")
    if weight_name is None:
        got = oinf.run({'X1': x1, 'X2': x2}, **run_params)
    else:
        got = oinf.run({'X1': x1, 'X2': x2, 'weight': w}, **run_params)
    self.assertEqual(len(exp_grad.shape), 2)
    self.assertEqual(exp_grad.shape[-1], 1)
    self.assertEqualArray(exp_grad, got[second_name], decimal=5)
    self.assertEqualArray(exp_loss, got['Y'], decimal=5)

    # 2. Check with onnxruntime.
    providers = device_to_providers('cpu')
    so = SessionOptions()
    so.log_severity_level = 0 if verbose > 0 else 4
    so.log_verbosity_level = 0 if verbose > 0 else 4
    sess = InferenceSession(onx.SerializeToString(),
                            so,
                            providers=providers)
    if verbose > 0:
        print("+++ run")
    if weight_name is None:
        got = sess.run(None, {'X1': x1, 'X2': x2})
    else:
        got = sess.run(None, {'X1': x1, 'X2': x2, 'weight': w})
    self.assertEqualArray(exp_loss, got[0], decimal=5)
    self.assertEqualArray(exp_grad, got[1], decimal=5)
    if weight_name is not None:
        # When the weight input is omitted, the graph is expected to
        # behave as if the weight were 1.
        if verbose > 0:
            print("+++ run*")
        got = sess.run(None, {'X1': x1, 'X2': x2})
        exp_loss2, exp_grad2 = fct(x1, x2, numpy.array([1],
                                                       dtype=x1.dtype))
        self.assertEqualArray(exp_loss2, got[0], decimal=5)
        self.assertEqualArray(exp_grad2, got[1], decimal=5)

    # 3. Check the unreduced (per-sample score) version of the graph:
    # its 'score' output must sum to the reduced loss.
    if 'grad' in name:
        rew = unreduced_onnx_loss(onx)
        if 'ReduceSum' in str(rew):
            # The rewritten graph must not contain any reduction left.
            # (typo fixed: 'Isse' -> 'Issue')
            raise AssertionError(f"Issue with:\n{rew!r}")
        if verbose > 0:
            with open(name + ".unreduced.onnx", "wb") as f:
                f.write(rew.SerializeToString())

        if verbose > 0:
            print(f"\n+++++ name(2)={name!r}")
        oinf = OnnxInference(rew)
        if weight_name is None:
            got = oinf.run({'X1': x1, 'X2': x2}, **run_params)
        else:
            got = oinf.run({'X1': x1, 'X2': x2, 'weight': w}, **run_params)
        score = got['score']
        self.assertEqual(len(score.shape), 2)
        self.assertEqual(score.shape[0], 10)
        self.assertEqual(score.shape[1], 1)
        self.assertEqualFloat(exp_loss, score.sum())

        sess = InferenceSession(rew.SerializeToString(),
                                so,
                                providers=providers)
        if verbose > 0:
            print("+++ run")
        if weight_name is None:
            got = sess.run(None, {'X1': x1, 'X2': x2})
        else:
            got = sess.run(None, {'X1': x1, 'X2': x2, 'weight': w})
        score = got[0]
        self.assertEqual(len(score.shape), 2)
        self.assertEqual(score.shape[0], 10)
        self.assertEqual(score.shape[1], 1)
        self.assertEqualFloat(exp_loss, score.sum())
# Example #2
    def _create_training_session(self,
                                 training_onnx,
                                 weights_to_train,
                                 loss_output_name='loss',
                                 training_optimizer_name='SGDOptimizer',
                                 device='cpu'):
        """
        Creates an instance of :epkg:`TrainingSession`.

        :param training_onnx: an ONNX graph with a loss function
        :param weights_to_train: list of initializer names to optimize
        :param loss_output_name: output name for the loss
        :param training_optimizer_name: optimizer name
        :param device: one :epkg:`C_OrtDevice` or a string
        :return: an instance of :epkg:`TrainingSession`
        """
        if training_optimizer_name != 'SGDOptimizer':
            raise NotImplementedError(
                "Only the SGDOptimizer is implemented not %r."
                "" % training_optimizer_name)
        ort_parameters = TrainingParameters()
        ort_parameters.loss_output_name = loss_output_name
        ort_parameters.use_mixed_precision = False
        # ort_parameters.world_rank = -1
        # ort_parameters.world_size = 1
        # ort_parameters.gradient_accumulation_steps = 1
        # ort_parameters.allreduce_post_accumulation = False
        # ort_parameters.deepspeed_zero_stage = 0
        # ort_parameters.enable_grad_norm_clip = False
        # ort_parameters.set_gradients_as_graph_outputs = False
        # ort_parameters.use_memory_efficient_gradient = False
        # ort_parameters.enable_adasum = False
        if self.saved_gradient is not None:
            name = self.saved_gradient
            name2 = name + ".training.onnx"
            ort_parameters.model_with_gradient_graph_path = name
            ort_parameters.model_with_training_graph_path = name2

        output_types = {}
        for output in training_onnx.graph.output:
            output_types[output.name] = output.type.tensor_type

        ort_parameters.weights_to_train = set(weights_to_train)
        ort_parameters.training_optimizer_name = training_optimizer_name
        # ort_parameters.lr_params_feed_name = lr_params_feed_name

        ort_parameters.optimizer_attributes_map = {
            name: {}
            for name in weights_to_train
        }
        ort_parameters.optimizer_int_attributes_map = {
            name: {}
            for name in weights_to_train
        }

        session_options = SessionOptions()
        session_options.log_severity_level = 4
        session_options.log_verbosity_level = 4
        # session_options.use_deterministic_compute = True

        providers = device_to_providers(self.device)
        session = TrainingSession(training_onnx.SerializeToString(),
                                  ort_parameters,
                                  session_options,
                                  providers=providers)

        return session