コード例 #1
0
ファイル: syngen_tests.py プロジェクト: vanishinggrad/nvm
            def callback(layer_name, data):
                """Track the gate layer and check other layers against ``trace``.

                For the gate layer ``"gh"`` the decoded token updates the
                shared ``test_state``:

                - a ``"start"`` token sets ``start`` and advances the step
                  counter ``t``; any other token clears ``start``;
                - a ``"?"`` token increments ``unk_count`` and any other token
                  resets it to 0; a ``"?"`` arriving while ``unk_count``
                  already exceeds 20 prints a message, interrupts the engine
                  and sets ``failed``.

                For every other layer, and only while ``start`` is set, the
                decoded token is compared with ``trace[t][layer_name]``,
                provided step ``t`` exists in ``trace`` and that step lists
                this layer.  A mismatch is printed, the engine is interrupted
                and ``failed`` is set.
                """
                if layer_name == "gh":
                    gate = vm.net.layers[layer_name].coder.decode(data)

                    # "start" opens a new step; any other gate token closes it.
                    test_state["start"] = bool(gate == "start")
                    if test_state["start"]:
                        test_state["t"] += 1

                    # Count consecutive unknown gate states.
                    if gate == "?":
                        if test_state["unk_count"] > 20:
                            print("Gate mechanism derailed!")
                            interrupt_engine()
                            test_state["failed"] = True
                        test_state["unk_count"] += 1
                    else:
                        test_state["unk_count"] = 0
                    return

                # Non-gate layers are only checked while a step is active.
                if not test_state["start"]:
                    return

                observed = vm.net.layers[layer_name].coder.decode(data)
                step = test_state["t"]
                if step >= len(trace):
                    return

                expected_tokens = trace[step]
                if layer_name not in expected_tokens:
                    return

                expected = expected_tokens[layer_name]
                if observed != expected:
                    print("Trace mismatch!")
                    print(step, layer_name, observed, expected)
                    interrupt_engine()
                    test_state["failed"] = True
コード例 #2
0
ファイル: syngen_tests.py プロジェクト: vanishinggrad/nvm
 def consumer(output):
     """Occasionally consume *output* and verify it against ``write_stream``.

     On roughly 1% of calls (``random() < 0.01``) the output is accepted:
     the next value is drawn from ``write_stream``.  On a mismatch the
     function prints a message, sets ``self.failed`` and interrupts the
     engine.

     Returns True when the output was consumed, False otherwise.
     """
     # Decline most calls without touching the expected stream.
     if random() >= 0.01:
         return False

     expected = next(write_stream)
     if output != expected:
         print("Stream mismatch!")
         self.failed = True
         interrupt_engine()
     return True
コード例 #3
0
 def input_callback(layer_name, data):
     """Write the next scheduled input for *layer_name* into *data* in place.

     Draws the next item from ``input_iters[layer_name]``.  A ``None`` item
     leaves *data* untouched.  Any other item is encoded with the layer's
     coder, passed through the layer's activation ``g`` and copied into
     *data* with ``np.copyto``.

     Exhausting the iterator interrupts the engine.  Any other error is
     reported with its full traceback and also interrupts the engine; it is
     not re-raised.
     """
     try:
         inp = next(input_iters[layer_name])
         if inp is not None:
             pattern = self.acts[layer_name].g(
                 self.coders[layer_name].encode(inp))
             np.copyto(data, pattern.flat)
     except StopIteration:
         interrupt_engine()
     except Exception:
         # print(e) alone drops the exception type and location; for example,
         # a KeyError for an unknown layer would print only the quoted key.
         # Imported here so the per-call success path pays nothing.
         import traceback
         print(traceback.format_exc(), end="")
         interrupt_engine()
コード例 #4
0
ファイル: syngen_nvm.py プロジェクト: vanishinggrad/nvm
    def checker_callback(ID, size, ptr):
        """Cross-check one layer's engine activity against the Python model.

        *ID* indexes ``layer_names``.  *size* and *ptr* are wrapped in a
        ``FloatArray`` (presumably a float buffer of length *size* at address
        *ptr*; confirm against its definition).  When ``ID == 0`` the
        reference model ``nvmnet`` is advanced with ``tick()`` before the
        comparison.

        Both the engine vector and ``nvmnet.activity`` for the layer are
        decoded with the layer's coder.  If the tokens differ, the function
        prints the tokens and the maximum absolute element difference, then
        interrupts the engine.
        """
        # Layer 0 triggers the model step (presumably the first layer
        # reported in each engine cycle).
        if ID == 0:
            nvmnet.tick()

        name = layer_names[ID]
        layer_coder = nvmnet.layers[name].coder

        engine_vec = FloatArray(size, ptr).to_np_array()
        engine_token = layer_coder.decode(engine_vec)
        model_vec = nvmnet.activity[name]
        model_token = layer_coder.decode(model_vec)

        if model_token != engine_token:
            # Reshape the engine vector to the model's layout before diffing.
            worst = np.fabs(engine_vec.reshape(model_vec.shape) - model_vec).max()

            print("Mismatch detected in nvm_checker!")
            print("%4s: %12s | %12s (res=%f)" %
                  (name, engine_token, model_token, worst))

            interrupt_engine()
コード例 #5
0
        def output_callback(layer_name, data):
            """Score the output on *layer_name* against the next expected token.

            Draws the next expected token from ``output_iters[layer_name]``.
            A ``None`` token means nothing is checked on this call.  Otherwise
            *data* is decoded with the layer's coder and the counters in
            ``test_state[layer_name]`` are updated:

            - ``total``: number of checked outputs;
            - ``correct``: number of exact decoded-token matches;
            - ``w_correct``: 1 for an exact match, otherwise the fraction of
              elements of *data* whose sign agrees with the expected token's
              encoded and activated pattern.

            Exhausting the iterator interrupts the engine.  Any other error is
            reported with its full traceback and also interrupts the engine;
            it is not re-raised.
            """
            try:
                tok = next(output_iters[layer_name])
                if tok is not None:
                    out = self.coders[layer_name].decode(data)
                    counters = test_state[layer_name]

                    if out == tok:
                        counters["w_correct"] += 1
                        counters["correct"] += 1
                    else:
                        # Partial credit: share of elements with the expected sign.
                        target = self.acts[layer_name].g(
                            self.coders[layer_name].encode(tok))
                        matches = np.sum(np.sign(data) == np.sign(target).flat)
                        counters["w_correct"] += matches / data.size
                    counters["total"] += 1
            except StopIteration:
                interrupt_engine()
            except Exception:
                # print(e) alone drops the exception type and location; for
                # example, a KeyError for an unknown layer would print only the
                # quoted key.  Imported here so the success path pays nothing.
                import traceback
                print(traceback.format_exc(), end="")
                interrupt_engine()
コード例 #6
0
ファイル: syngen_matlab.py プロジェクト: vanishinggrad/nvm
def kill():
    """Interrupt the engine and wait for the module-level ``thread`` to end.

    Always calls ``interrupt_engine()``.  It then joins ``thread`` if one is
    set and resets it to ``None``.  Calling ``kill()`` again after a previous
    call, or whenever ``thread`` is ``None``, is a no-op apart from the
    interrupt.
    """
    global thread
    interrupt_engine()
    # thread is None after an earlier kill(); joining None would raise
    # AttributeError.
    if thread is not None:
        thread.join()
    thread = None
コード例 #7
0
ファイル: syngen_nvm.py プロジェクト: vanishinggrad/nvm
 def exit_callback(ID, size, ptr):
     """Interrupt the engine when the first value in the buffer is positive.

     *size* and *ptr* are wrapped in a ``FloatArray`` and only element 0 is
     inspected.  *ID* is accepted for the callback signature but not used.
     """
     signal = FloatArray(size, ptr)[0]
     # Written as "not > 0.0" rather than "<= 0.0" so NaN never triggers it.
     if not signal > 0.0:
         return
     interrupt_engine()