示例#1
0
def test_round_trip():
    """Round-trip a variable through graph and weight serialization.

    The variable is assigned the value 1, its weights and graph are
    serialized, and both are deserialized again; the value extracted through
    the first deserialized op must still be 1.
    """
    axis_a = ng.make_axis(name='A', length=2)
    axis_b = ng.make_axis(name='B', length=3)
    variable = ng.variable(ng.make_axes([axis_a, axis_b]))
    fill_with_one = ng.AssignOp(variable, 1)

    with executor(fill_with_one) as run_assign:
        transformer = run_assign.transformer

        # Run the assignment so the variable holds a value to read back.
        run_assign()
        np.testing.assert_allclose(serde_weights.extract_op(transformer, variable), 1)

        # --- How to fully serialize a graph: weights go to a stream,
        # --- the graph itself is returned by serialize_graph.
        weight_stream = BytesIO()
        serde_weights.serialize_weights(transformer, [variable], weight_stream)
        serialized_graph = serde.serialize_graph([variable])

        # Rewind so deserialization reads the stream from its start.
        weight_stream.seek(0)

        # --- How to fully deserialize a graph: rebuild the ops first,
        # --- then load the weights for them.
        restored_ops = serde.deserialize_graph(serialized_graph)
        serde_weights.deserialize_weights(transformer, restored_ops, weight_stream)

        restored_value = serde_weights.extract_op(transformer, restored_ops[0])
        np.testing.assert_allclose(restored_value, 1)
示例#2
0
    def setup_restore(self, transformer, computation, filename):
        """
        Build the computation that assigns saved values back to weight variables.

        Values are read from the save file, matched by tensor name against the
        tensors reachable from the computation's root ops, and turned into one
        AssignOp per match. The resulting computation is stored in
        ``self.setter``.

        Arguments:
            transformer : transformer where the weights will be restored
            computation (ComputationOp or dict of Ops):
                          A ComputationOp or dictionary of output Ops of interest.
            filename: name of file with saved weights
        """
        def match_ops(tensors, values):
            """
            Traverse the ops reachable from ``values`` (through ``args`` and
            ``all_deps``) and return a dict mapping each restorable tensor to
            its loaded value from ``tensors``.
            """
            matched = dict()
            pending = set(values)
            seen = set()

            def match_op(op_to_add):
                """
                Record the loaded value for the tensor behind ``op_to_add`` when
                it is a persistent assignable tensor that is neither a constant
                nor a placeholder.
                """
                if not isinstance(op_to_add, ng.TensorValueOp):
                    return
                tensor = op_to_add.tensor
                if not isinstance(tensor, ng.AssignableTensorOp):
                    return
                if not tensor.is_persistent:
                    return
                # Constants and placeholders are never restored from the file.
                if tensor.is_constant or tensor.is_placeholder:
                    return
                try:
                    matched[tensor] = tensors[tensor.name]
                except KeyError:
                    # A weight absent from the file is reported, not fatal.
                    print(
                        "Warning: Missing weight in save file: "
                        + tensor.name)

            while pending:
                current = pending.pop()
                match_op(current)
                seen.add(current)
                # Queue every neighbour that has not been processed yet.
                pending.update(arg for arg in current.args if arg not in seen)
                pending.update(dep for dep in current.all_deps if dep not in seen)
            return matched

        # Read the saved values, then pair them with tensors in the graph.
        saved_values = SaverFile(filename).read_values()
        matches = match_ops(saved_values, get_root_ops(computation))
        restore_ops = [ng.AssignOp(tensor, value) for tensor, value in matches.items()]
        self.setter = transformer.computation(restore_ops)
示例#3
0
def test_extract_op():
    """Check that ``extract_op`` reads back the value assigned to a variable."""
    axis_a = ng.make_axis(name='A', length=2)
    axis_b = ng.make_axis(name='B', length=3)
    axes = ng.make_axes([axis_a, axis_b])
    variable = ng.variable(axes)
    fill_with_one = ng.AssignOp(variable, 1)

    # Run the assignment, then pull the variable's value out of the transformer.
    with executor(fill_with_one) as run_assign:
        transformer = run_assign.transformer
        run_assign()
        extracted = serde_weights.extract_op(transformer, variable)

    # Every element must equal 1 after assigning the scalar 1.
    expected = np.ones(axes.lengths)
    assert (extracted == expected).all()
示例#4
0
def test_variable():
    """Save a variable's value, overwrite it, then restore it from the save file.

    A 10x3 variable is assigned random values and saved with ``Saver``. In a
    second transformer the variable is reassigned new random values, read
    back, restored from the file, and read again. The saved, reassigned and
    restored readings are then compared against the expected arrays.

    The save file is removed in a ``finally`` clause so that a failure in any
    save/restore step does not leave ``test_variable.npz`` behind.
    """
    input_axes = ng.make_axes([
        ng.make_axis(10),
        ng.make_axis(3)
    ])
    var = ng.variable(axes=input_axes)
    assign_val = np.random.rand(10, 3)
    var_assign = ng.AssignOp(tensor=var, val=assign_val)
    # Run the assignment and then return the variable's value.
    var_seq = ng.sequential([var_assign, var])
    var_comp = ng.computation(var_seq, "all")
    results = dict()
    weight_saver = Saver()
    # The saver is handed the bare name; the file cleaned up below is the
    # same name with an ".npz" suffix.
    save_name = "test_variable"
    saved_path = save_name + ".npz"
    try:
        with closing(ngt.make_transformer()) as transformer:
            var_func = transformer.add_computation(var_comp)
            weight_saver.setup_save(transformer=transformer, computation=var_comp)
            results['saved'] = var_func().copy()
            weight_saver.save(filename=save_name)

        reassign_val = np.random.rand(10, 3)
        var_reassign = ng.AssignOp(tensor=var, val=reassign_val)

        var_recomp = ng.computation(var_reassign, "all")
        var_read = ng.computation(var, "all")
        with closing(ngt.make_transformer()) as restore_transformer:
            var_recompfunc = restore_transformer.add_computation(var_recomp)
            weight_saver.setup_restore(transformer=restore_transformer,
                                       computation=var_recomp,
                                       filename=save_name)
            var_readfunc = restore_transformer.add_computation(var_read)
            # Overwrite the variable, then restore it from the saved file.
            var_recompfunc()
            results['reassigned'] = var_readfunc().copy()
            weight_saver.restore()
            results['restored'] = var_readfunc().copy()
    finally:
        # Clean up even when a step above raised; the existence check keeps a
        # missing file from masking the original error.
        if os.path.exists(saved_path):
            os.remove(saved_path)
    assert np.allclose(results['saved'], assign_val, atol=0)
    assert np.allclose(results['reassigned'], reassign_val, atol=0)
    assert np.allclose(results['saved'], results['restored'], atol=0)
示例#5
0
def test_extract_op(transformer_factory):
    """Check that ``extract_op`` reads back the value assigned to a variable.

    ``transformer_factory`` is not used in the body; presumably it is a test
    fixture that configures the transformer backend (confirm against the
    test suite's conftest).
    """
    axis_a = ng.make_axis(name='A', length=2)
    axis_b = ng.make_axis(name='B', length=3)
    axes = ng.make_axes([axis_a, axis_b])
    variable = ng.variable(axes)
    fill_with_one = ng.AssignOp(variable, 1)

    # Run the assignment on a fresh transformer, then extract the stored value.
    with closing(ngt.make_transformer()) as transformer:
        run_assign = transformer.computation(fill_with_one)
        run_assign()
        extracted = serde_weights.extract_op(transformer, variable)

    # Every element must equal 1 after assigning the scalar 1.
    expected = np.ones(axes.lengths)
    assert (extracted == expected).all()
示例#6
0
def set_op_value(transformer, op, value):
    """
    Assign ``value`` to ``op`` by building an AssignOp computation on
    ``transformer`` and running it immediately. Nothing is returned.
    """
    run_assignment = transformer.computation(ng.AssignOp(op, value))
    run_assignment()
示例#7
0
def assign_ops(ops, values):
    """
    Wrap each (op, value) pair in an AssignOp and combine them with
    ``ng.sequential``.

    Pairs are formed with ``zip``, so surplus entries in the longer of the
    two inputs are ignored.
    """
    assignments = []
    for target, new_value in zip(ops, values):
        assignments.append(ng.AssignOp(target, new_value))
    return ng.sequential(assignments)