def test_round_trip():
    """Round-trip a variable's graph and weights through serde and read it back.

    The variable is assigned the value 1, its graph and weights are serialized,
    then both are deserialized against the same transformer; the restored op
    must still read back as 1.
    """
    # Build a 2x3 variable plus an op that assigns 1 to it.
    axis_list = [
        ng.make_axis(name='A', length=2),
        ng.make_axis(name='B', length=3),
    ]
    axes = ng.make_axes(axis_list)
    x_op = ng.variable(axes)
    assign_op = ng.AssignOp(x_op, 1)

    with executor(assign_op) as assign_computation:
        transformer = assign_computation.transformer

        # Run the assignment once so x_op holds a known value, then confirm it.
        assign_computation()
        np.testing.assert_allclose(serde_weights.extract_op(transformer, x_op), 1)

        # In-memory sink for the serialized weights.
        weights_buffer = BytesIO()

        # ## EXAMPLE OF HOW TO FULLY SERIALIZE A GRAPH ###
        serde_weights.serialize_weights(transformer, [x_op], weights_buffer)
        serialized_graph = serde.serialize_graph([x_op])
        # ## /EXAMPLE OF HOW TO FULLY SERIALIZE A GRAPH ###

        # Rewind so deserialization reads the weights from the beginning.
        weights_buffer.seek(0)

        # ## EXAMPLE OF HOW TO FULLY DESERIALIZE A GRAPH ###
        restored_ops = serde.deserialize_graph(serialized_graph)
        serde_weights.deserialize_weights(transformer, restored_ops, weights_buffer)
        # ## /EXAMPLE OF HOW TO FULLY DESERIALIZE A GRAPH ###

        # NOTE(review): assumes the restored variable is the first op returned
        # by deserialize_graph — confirm if more ops are ever serialized here.
        restored_value = serde_weights.extract_op(transformer, restored_ops[0])
        np.testing.assert_allclose(restored_value, 1)
def test_extract_op_executor():
    """Check that serde_weights.extract_op reads back a value set via AssignOp.

    This variant drives the computation through ``executor``. It used to be
    named ``test_extract_op`` as well, so the later ``test_extract_op``
    definition in this module silently replaced it and pytest never
    collected it. The distinct name lets both variants run.
    """
    # Set up a 2x3 variable and an op that assigns 1 to it.
    axes = ng.make_axes([
        ng.make_axis(name='A', length=2),
        ng.make_axis(name='B', length=3),
    ])
    x_op = ng.variable(axes)
    assign_op = ng.AssignOp(x_op, 1)

    # Run the assignment, extract the stored value and compare to all-ones.
    with executor(assign_op) as comp_assignment:
        t = comp_assignment.transformer
        comp_assignment()
        x_out = serde_weights.extract_op(t, x_op)
        # assert_array_equal also checks shape and reports mismatches clearly.
        np.testing.assert_array_equal(x_out, np.ones(axes.lengths))
def test_extract_op(transformer_factory):
    """Assign ones to a variable and verify extract_op returns exactly that.

    Uses an explicitly created transformer, closed on exit via ``closing``.
    """
    # 2x3 variable, filled with 1 by assign_op.
    axis_a = ng.make_axis(name='A', length=2)
    axis_b = ng.make_axis(name='B', length=3)
    axes = ng.make_axes([axis_a, axis_b])
    x_op = ng.variable(axes)
    assign_op = ng.AssignOp(x_op, 1)

    with closing(ngt.make_transformer()) as t:
        assign_computation = t.computation(assign_op)
        assign_computation()

        extracted = serde_weights.extract_op(t, x_op)
        expected = np.ones(axes.lengths)
        assert (extracted == expected).all()
def test_set_op_value(transformer_factory):
    """Inject a value with serde_weights.set_op_value and confirm it sticks.

    A 2x3 variable is created, an all-ones array is written into it on a
    fresh transformer, and extract_op must return that same array.
    """
    axis_a = ng.make_axis(name='A', length=2)
    axis_b = ng.make_axis(name='B', length=3)
    axes = ng.make_axes([axis_a, axis_b])
    x_op = ng.variable(axes)

    with closing(ngt.make_transformer()) as t:
        injected = np.ones(axes.lengths)
        serde_weights.set_op_value(t, x_op, injected)

        # Read the value back out and compare element-wise.
        read_back = serde_weights.extract_op(t, x_op)
        assert (read_back == injected).all()