def _make_gluon_function_evaluation_rand_param(self, dtype, broadcastable):
    """Build a Gluon-function factor whose block parameter is a random variable.

    A small hybridized block computing ``gemm2(a, b) + const`` is wrapped in
    an ``MXFusionGluonFunction`` and applied to two ``Variable`` inputs of
    shapes ``(3, 4)`` and ``(4, 5)``. A Normal variable of shape ``(1,)`` is
    passed through the ``dot_const`` keyword.

    :param dtype: dtype used for the block's ``const`` parameter and
        forwarded to ``MXFusionGluonFunction``.
    :param broadcastable: forwarded to ``MXFusionGluonFunction``.
    :returns: the ``factor`` attribute of the wrapper's output.
    """

    class Dot(HybridBlock):
        """Matrix product of two inputs plus a broadcast constant."""

        def __init__(self, params=None, prefix=None):
            super(Dot, self).__init__(params=params, prefix=prefix)
            with self.name_scope():
                # Single-element parameter, zero-initialized.
                self.const = self.params.get(
                    'const', shape=(1, ), dtype=dtype, init=Zero())

        def hybrid_forward(self, F, a, b, const):
            return F.broadcast_add(F.linalg.gemm2(a, b), const)

    dot = Dot(prefix='dot_')
    dot.initialize()
    dot.hybridize()

    func_wrapper = MXFusionGluonFunction(
        dot, 1, dtype=dtype, broadcastable=broadcastable)

    # Local import, kept at function scope as in the original helper.
    from mxfusion.components.distributions.normal import Normal

    rand_var = Normal.define_variable(shape=(1, ))
    # NOTE(review): the keyword ``dot_const`` looks like the block prefix
    # ``dot_`` joined with the parameter name ``const`` -- confirm against
    # MXFusionGluonFunction's keyword handling.
    out = func_wrapper(
        Variable(shape=(3, 4)), Variable(shape=(4, 5)), dot_const=rand_var)
    return out.factor
def _make_gluon_function_evaluation(self, dtype, broadcastable):
    """Build a Gluon-function factor around a plain matrix-product block.

    A hybridized block computing ``gemm2(a, b)`` is wrapped in an
    ``MXFusionGluonFunction`` with one output and applied to two
    ``Variable`` inputs of shapes ``(3, 4)`` and ``(4, 5)``.

    :param dtype: forwarded to ``MXFusionGluonFunction``.
    :param broadcastable: forwarded to ``MXFusionGluonFunction``.
    :returns: the ``factor`` attribute of the wrapper's output.
    """

    class Dot(HybridBlock):
        """Matrix product of the two inputs."""

        def hybrid_forward(self, F, a, b):
            return F.linalg.gemm2(a, b)

    block = Dot(prefix='dot')
    block.initialize()
    block.hybridize()

    wrapper = MXFusionGluonFunction(
        block,
        1,
        dtype=dtype,
        broadcastable=broadcastable,
    )

    left = Variable(shape=(3, 4))
    right = Variable(shape=(4, 5))
    result = wrapper(left, right)
    return result.factor
def test_gluon_parameters(self):
    """Check that the Gluon function's parameters appear in the inference params.

    Builds a model with one observed variable ``x`` fed through an
    ``MXFusionGluonFunction`` wrapping ``self.net``, runs forward sampling
    once, and asserts that the uuid of every entry in ``m.f.parameters`` is
    a key of ``infr.params.param_dict``.
    """
    self.setUp()

    model = Model()
    model.x = Variable(shape=(1, 1))
    model.f = MXFusionGluonFunction(self.net, num_outputs=1)
    model.y = model.f(model.x)

    algorithm = ForwardSamplingAlgorithm(model, observed=[model.x])
    inference = Inference(algorithm)
    inference.run(x=mx.nd.ones((1, 1)))

    assert all(
        param.uuid in inference.params.param_dict
        for param in model.f.parameters.values()
    )
def test_success(self):
    """Smoke test: wrap ``self.net`` and call the wrapper on a fresh Variable.

    No assertion is made on the result; the test only requires that
    construction and the call complete without raising.
    """
    self.setUp()
    wrapper = MXFusionGluonFunction(self.net, num_outputs=1)
    placeholder = Variable()
    # The returned value is not inspected.
    _ = wrapper(placeholder)