def _make_gluon_function_evaluation_rand_param(self, dtype, broadcastable):
        """Build a Gluon-backed function whose parameter is a random variable.

        A small ``HybridBlock`` computing ``gemm2(a, b) + const`` is wrapped in
        an ``MXFusionGluonFunction`` and applied to two ``Variable`` inputs of
        shapes ``(3, 4)`` and ``(4, 5)``. The block's ``const`` parameter is
        supplied through the ``dot_const`` keyword as a variable created by
        ``Normal.define_variable``.

        :param dtype: data type used for the block's ``const`` parameter and
            forwarded to ``MXFusionGluonFunction``.
        :param broadcastable: forwarded unchanged to ``MXFusionGluonFunction``.
        :returns: the ``factor`` attribute of the wrapped function's output.
        """
        class Dot(HybridBlock):
            # Matrix product of the two inputs plus a single learnable offset.
            def __init__(self, params=None, prefix=None):
                super(Dot, self).__init__(params=params, prefix=prefix)
                with self.name_scope():
                    # One-element parameter, zero-initialised.
                    self.const = self.params.get('const',
                                                 shape=(1, ),
                                                 dtype=dtype,
                                                 init=Zero())

            def hybrid_forward(self, F, a, b, const):
                return F.broadcast_add(F.linalg.gemm2(a, b), const)

        dot = Dot(prefix='dot_')
        dot.initialize()
        dot.hybridize()
        # NOTE: a leftover debug print of dot.collect_params() was removed so
        # the helper no longer writes to stdout on every call.
        func_wrapper = MXFusionGluonFunction(dot,
                                             1,
                                             dtype=dtype,
                                             broadcastable=broadcastable)
        from mxfusion.components.distributions.normal import Normal
        rand_var = Normal.define_variable(shape=(1, ))
        # The keyword matches the block prefix 'dot_' followed by the parameter
        # name 'const' (presumably how the wrapper names block parameters --
        # confirm against MXFusionGluonFunction).
        out = func_wrapper(Variable(shape=(3, 4)),
                           Variable(shape=(4, 5)),
                           dot_const=rand_var)

        return out.factor
Exemplo n.º 2
0
    def _make_gluon_function_evaluation(self, dtype, broadcastable):
        """Build a parameter-free Gluon matrix-product function.

        Wraps a ``HybridBlock`` that computes ``gemm2(a, b)`` in an
        ``MXFusionGluonFunction`` with one output, applies it to two
        ``Variable`` inputs of shapes ``(3, 4)`` and ``(4, 5)``, and returns
        the ``factor`` attribute of the result.

        :param dtype: forwarded unchanged to ``MXFusionGluonFunction``.
        :param broadcastable: forwarded unchanged to ``MXFusionGluonFunction``.
        :returns: the ``factor`` attribute of the wrapped function's output.
        """
        class Dot(HybridBlock):
            # Plain matrix product of the two inputs; no parameters.
            def hybrid_forward(self, F, a, b):
                product = F.linalg.gemm2(a, b)
                return product

        gluon_block = Dot(prefix='dot')
        gluon_block.initialize()
        gluon_block.hybridize()

        wrapped = MXFusionGluonFunction(
            gluon_block, 1, dtype=dtype, broadcastable=broadcastable)

        # Inputs are created after the wrapper, left operand first, matching
        # the original construction order.
        lhs = Variable(shape=(3, 4))
        rhs = Variable(shape=(4, 5))
        result = wrapped(lhs, rhs)

        return result.factor
Exemplo n.º 3
0
    def test_gluon_parameters(self):
        """Check that the Gluon function's parameters reach the inference.

        Builds a model with one ``(1, 1)`` input passed through an
        ``MXFusionGluonFunction`` around ``self.net``, runs forward sampling
        with the input observed, and asserts that the uuid of every entry in
        the function's ``parameters`` is a key of ``infr.params.param_dict``.
        """
        self.setUp()

        model = Model()
        model.x = Variable(shape=(1, 1))
        model.f = MXFusionGluonFunction(self.net, num_outputs=1)
        model.y = model.f(model.x)

        algorithm = ForwardSamplingAlgorithm(model, observed=[model.x])
        inference = Inference(algorithm)
        inference.run(x=mx.nd.ones((1, 1)))

        registered = inference.params.param_dict
        assert all(
            param.uuid in registered
            for param in model.f.parameters.values()
        )
Exemplo n.º 4
0
 def test_success(self):
     """Wrap ``self.net`` in a Gluon function and apply it to a Variable."""
     # Fixture is (re)built explicitly before use.
     self.setUp()
     # Single-output function around the network created by the fixture.
     f = MXFusionGluonFunction(self.net, num_outputs=1)
     # Input variable with no shape given.
     x = Variable()
     y = f(x)