def test_operators_variable_builtins(self, mxf_operator, mxnet_operator, inputs, case):
    """
    Check that ``mxf_operator`` applied to Variables evaluates to the same
    result as the matching Python operator (or ``transpose``) applied to
    Variables.

    mxf_operator: callable invoked with one or two ``Variable`` objects to
        build the model variable under test.
    mxnet_operator: not used by this test body.
    inputs: runtime values bound to the factor inputs, in factor-input order
        (presumably mx.nd arrays -- confirm against the parametrization).
    case: name of the reference expression to build; one of "add", "sub",
        "mul", "div", "pow" or "transpose".

    Raises ValueError when ``case`` is not one of the names above, so that a
    bad parametrization is reported directly rather than as a failure on the
    unset model attribute.
    """
    # Model built through the operator function under test.
    m = Model()
    v1 = Variable()
    v2 = Variable()
    variables = [v1, v2] if len(inputs) > 1 else [v1]
    m.r = mxf_operator(*variables)
    # Each factor input is indexed with [1] to reach the variable whose uuid
    # keys the runtime value; the i-th factor input receives inputs[i].
    variables_rt = {v[1].uuid: inputs[i] for i, v in enumerate(m.r.factor.inputs)}
    r_eval = m.r.factor.eval(mx.nd, variables=variables_rt)

    # Reference model built with the Python operators on Variables.
    m2 = Model()
    v12 = Variable()
    v22 = Variable()
    reference_builders = {
        "add": lambda a, b: a + b,
        "sub": lambda a, b: a - b,
        "mul": lambda a, b: a * b,
        "div": lambda a, b: a / b,
        "pow": lambda a, b: a ** b,
        # transpose is unary: the second variable is deliberately ignored.
        "transpose": lambda a, b: transpose(a),
    }
    if case not in reference_builders:
        raise ValueError("Unsupported operator case: %r" % (case,))
    m2.r = reference_builders[case](v12, v22)
    variables_rt2 = {v[1].uuid: inputs[i] for i, v in enumerate(m2.r.factor.inputs)}
    p_eval = m2.r.factor.eval(mx.nd, variables=variables_rt2)

    assert np.allclose(r_eval.asnumpy(), p_eval.asnumpy()), (r_eval, p_eval)
def _test_operator(self, operator, inputs, properties=None):
    """
    Build a one-operator model and return the evaluation of its factor.

    inputs are mx.nd.array
    properties are just the operator properties needed at model def time;
    None is treated as "no properties".
    """
    if properties is None:
        properties = {}

    model = Model()
    # One placeholder Variable per runtime input.
    placeholders = [Variable() for _ in inputs]
    model.r = operator(*placeholders, **properties)

    # Bind the i-th runtime input to the uuid of the i-th factor input; each
    # factor input is indexed with [1] to reach the variable carrying the uuid.
    bindings = {}
    for position, factor_input in enumerate(list(model.r.factor.inputs)):
        bindings[factor_input[1].uuid] = inputs[position]

    return model.r.factor.eval(mx.nd, variables=bindings)
def test_operator_replicate(self, mxf_operator, mxnet_operator, inputs, properties):
    """
    Check that the object returned by ``Model.extract_distribution_of`` for
    an operator variable evaluates to the same values as the original.

    mxf_operator: callable invoked with one ``Variable`` per input, plus
        ``properties`` as keyword arguments, to build the model variable.
    mxnet_operator: not used by this test body.
    inputs: runtime values bound to the factor inputs, in factor-input order.
    properties: keyword arguments for ``mxf_operator``; None means none.
    """
    if properties is None:
        properties = {}

    def evaluate(result_variable):
        # Bind the i-th runtime input to the uuid of the i-th factor input,
        # then evaluate the factor with those bindings.
        bindings = {}
        for position, factor_input in enumerate(list(result_variable.factor.inputs)):
            bindings[factor_input[1].uuid] = inputs[position]
        return result_variable.factor.eval(mx.nd, variables=bindings)

    model = Model()
    placeholders = [Variable() for _ in inputs]
    model.r = mxf_operator(*placeholders, **properties)
    original_eval = evaluate(model.r)

    # Evaluate the extracted copy with the same runtime inputs.
    replica = model.extract_distribution_of(model.r)
    replica_eval = evaluate(replica)

    assert np.allclose(original_eval.asnumpy(), replica_eval.asnumpy()), (original_eval, replica_eval)