def test_round_trip():
    """
    Serialize a variable's graph and weights into memory, deserialize them,
    and check the deserialized op reads back the value assigned originally.
    """
    # Build a 2x3 variable and an op that assigns 1 to it, so there is a
    # known value to read back out.
    axes = ng.make_axes([
        ng.make_axis(name='A', length=2),
        ng.make_axis(name='B', length=3),
    ])
    x_op = ng.variable(axes)
    assign_op = ng.AssignOp(x_op, 1)

    with executor(assign_op) as assign_computation:
        transformer = assign_computation.transformer

        # Run the assignment, then sanity-check the value before serializing.
        assign_computation()
        np.testing.assert_allclose(
            serde_weights.extract_op(transformer, x_op), 1)

        buffer = BytesIO()

        # ## EXAMPLE OF HOW TO FULLY SERIALIZE A GRAPH ###
        serde_weights.serialize_weights(transformer, [x_op], buffer)
        graph_string = serde.serialize_graph([x_op])
        # ## /EXAMPLE OF HOW TO FULLY SERIALIZE A GRAPH ###

        # Rewind so deserialization reads the weights from the beginning.
        buffer.seek(0)

        # ## EXAMPLE OF HOW TO FULLY DESERIALIZE A GRAPH ###
        new_ops = serde.deserialize_graph(graph_string)
        serde_weights.deserialize_weights(transformer, new_ops, buffer)
        # ## /EXAMPLE OF HOW TO FULLY DESERIALIZE A GRAPH ###

        np.testing.assert_allclose(
            serde_weights.extract_op(transformer, new_ops[0]), 1)
def setup_restore(self, transformer, computation, filename):
    """
    Build the computation that loads saved weights back into variables.

    Saved values are read from ``filename`` with ``SaverFile.read_values()``.
    Starting from the root ops of ``computation``, the graph is walked through
    both ``args`` and ``all_deps``. Every ``TensorValueOp`` whose tensor is a
    persistent, non-constant, non-placeholder ``AssignableTensorOp`` is matched
    by ``tensor.name`` against the saved values. One ``AssignOp`` is created per
    matched tensor, and the compiled computation is stored in ``self.setter``.

    Arguments:
        transformer : transformer where the weights will be restored
        computation (ComputationOp or dict of Ops): A ComputationOp or
            dictionary of output Ops of interest.
        filename: name of file with saved weights
    """
    def collect_weights(saved_values, roots):
        """Map each restorable tensor reachable from ``roots`` to its saved value."""
        matched = dict()
        pending = set(roots)
        seen = set()
        while pending:
            current = pending.pop()
            if isinstance(current, ng.TensorValueOp):
                candidate = current.tensor
                # Only persistent variables are restored; constants and
                # placeholders are skipped.
                if (isinstance(candidate, ng.AssignableTensorOp)
                        and candidate.is_persistent
                        and not candidate.is_constant
                        and not candidate.is_placeholder):
                    try:
                        matched[candidate] = saved_values[candidate.name]
                    except KeyError:
                        # Best-effort: a missing weight is reported, not fatal.
                        print(
                            "Warning: Missing weight in save file: " + candidate.name)
            seen.add(current)
            # Follow both data inputs and extra dependencies.
            pending.update(dep for dep in current.args if dep not in seen)
            pending.update(dep for dep in current.all_deps if dep not in seen)
        return matched

    saved_values = SaverFile(filename).read_values()
    weights = collect_weights(saved_values, get_root_ops(computation))
    self.setter = transformer.computation(
        [ng.AssignOp(target, value) for target, value in weights.items()])
def test_extract_op():
    """extract_op should return the value most recently assigned to a variable."""
    # A 2x3 variable that will be filled with ones by the assignment below.
    axes = ng.make_axes([
        ng.make_axis(name='A', length=2),
        ng.make_axis(name='B', length=3),
    ])
    x_op = ng.variable(axes)
    assign_op = ng.AssignOp(x_op, 1)

    with executor(assign_op) as comp_assignment:
        transformer = comp_assignment.transformer
        comp_assignment()
        extracted = serde_weights.extract_op(transformer, x_op)
        expected = np.ones(axes.lengths)
        assert (extracted == expected).all()
def test_variable():
    """
    Round-trip a variable through ``Saver``.

    The test saves the variable's value from one transformer, overwrites the
    variable in a second transformer, restores it, and checks that the restored
    value equals the saved one. The save file is always removed, even if a step
    raises, so a failing run does not leave a stray file in the working
    directory.
    """
    input_axes = ng.make_axes([
        ng.make_axis(10),
        ng.make_axis(3)
    ])
    var = ng.variable(axes=input_axes)
    assign_val = np.random.rand(10, 3)
    var_assign = ng.AssignOp(tensor=var, val=assign_val)
    var_seq = ng.sequential([var_assign, var])
    var_comp = ng.computation(var_seq, "all")

    results = dict()
    weight_saver = Saver()
    # save()/setup_restore() receive "test_variable", while the file cleaned up
    # is "test_variable.npz"; presumably the saver appends the extension.
    save_path = "test_variable.npz"
    try:
        with closing(ngt.make_transformer()) as transformer:
            var_func = transformer.add_computation(var_comp)
            weight_saver.setup_save(transformer=transformer, computation=var_comp)
            results['saved'] = var_func().copy()
            weight_saver.save(filename="test_variable")

        reassign_val = np.random.rand(10, 3)
        var_reassign = ng.AssignOp(tensor=var, val=reassign_val)
        var_recomp = ng.computation(var_reassign, "all")
        var_read = ng.computation(var, "all")

        with closing(ngt.make_transformer()) as restore_transformer:
            var_recompfunc = restore_transformer.add_computation(var_recomp)
            weight_saver.setup_restore(transformer=restore_transformer,
                                       computation=var_recomp,
                                       filename="test_variable")
            var_readfunc = restore_transformer.add_computation(var_read)
            var_recompfunc()
            results['reassigned'] = var_readfunc().copy()
            weight_saver.restore()
            results['restored'] = var_readfunc().copy()
    finally:
        # Clean up on every path. The existence check keeps a failed save from
        # masking the original exception with FileNotFoundError.
        if os.path.exists(save_path):
            os.remove(save_path)

    assert np.allclose(results['saved'], assign_val, atol=0)
    assert np.allclose(results['reassigned'], reassign_val, atol=0)
    assert np.allclose(results['saved'], results['restored'], atol=0)
def test_extract_op(transformer_factory):
    """
    extract_op should return the value most recently assigned to a variable.

    NOTE(review): another ``test_extract_op`` (taking no fixture) also appears
    in this source. If both live in the same module, this definition shadows
    the earlier one and only this one is collected — confirm and rename one.
    """
    # A 2x3 variable that will be filled with ones by the assignment below.
    axes = ng.make_axes([
        ng.make_axis(name='A', length=2),
        ng.make_axis(name='B', length=3),
    ])
    x_op = ng.variable(axes)
    assign_op = ng.AssignOp(x_op, 1)

    with closing(ngt.make_transformer()) as transformer:
        run_assignment = transformer.computation(assign_op)
        run_assignment()
        extracted = serde_weights.extract_op(transformer, x_op)
        expected = np.ones(axes.lengths)
        assert (extracted == expected).all()
def set_op_value(transformer, op, value):
    """
    Set ``op`` to ``value`` (e.g. a numpy array).

    A one-off ``AssignOp`` computation is compiled on ``transformer`` and
    executed immediately. Returns ``None``.
    """
    assignment = ng.AssignOp(op, value)
    run_assignment = transformer.computation(assignment)
    run_assignment()
def assign_ops(ops, values):
    """
    Return an ``ng.sequential`` of ``AssignOp(op, value)`` for each pair.

    Pairs are taken in input order via ``zip``, so extra items in the longer
    of ``ops`` / ``values`` are ignored.
    """
    return ng.sequential(
        [ng.AssignOp(target, val) for target, val in zip(ops, values)])