def test_learn_to_add_three():
    """End-to-end training test: the sketch must learn to add 3.

    The code contains two `choose` slots, each picking between `1+` and `2+`
    (with `2+` defined as `1+ 1+`).  For every sampled input x the target is
    x + 3, so training should drive the slots towards a combination summing
    to three.  Passes when the final per-batch loss drops below 0.47.
    """
    logging.basicConfig(level=logging.DEBUG)
    n_slots = 10
    interpreter = si.SimpleInterpreter(10, 15, 15, n_slots, parallel_branches=True)

    # Same trainable sketch in every batch slot.
    sketch = ": 2+ 1+ 1+ ; { choose 1+ 2+ } { choose 1+ 2+ } "
    for slot in range(n_slots):
        interpreter.load_code(sketch, slot)

    final_state = interpreter.execute(10)[-1]
    loss = L2Loss(final_state, interpreter)
    l2_loss = loss.l2_loss  # / batch_size

    train_op = tf.train.AdamOptimizer(learning_rate=0.1).minimize(l2_loss)

    sess = tf.Session()
    sess.run(tf.initialize_all_variables())

    mean_loss = 0
    for epoch in range(100):
        # Fresh random digit per slot; target is always input + 3.
        for slot in range(n_slots):
            x = np.random.randint(0, 10)
            interpreter.load_stack([x], slot)
            loss.load_target_stack([x + 3], slot)
        mean_loss, _ = sess.run([l2_loss, train_op], loss.current_feed_dict())
        mean_loss /= n_slots
        print(epoch, mean_loss)
    assert mean_loss < 0.47
def test_learn_to_add_digits():
    """Train a neural observe/manipulate slot to add the top two stack digits.

    The sketch reads the top two data-stack cells (D0 and D-1), passes them
    through a sigmoid + linear(10) layer and writes the output back to D0;
    `SWAP DROP` then discards the remaining operand.  Training sweeps all
    pairs (x, y) with x, y in 0..9 and targets x + y.

    NOTE(review): this test has no final assertion — it only prints per-step
    diagnostics, so it checks that training runs, not that it converges.
    """
    batch_size = 1
    interpreter = si.SimpleInterpreter(3, 20, 15, batch_size,
                                       parallel_branches=True,
                                       do_normalise_pointer=True)
    # Load the same trainable sketch into every batch slot.
    for batch in range(0, batch_size):
        interpreter.load_code(
            "{ observe D0 D-1 -> sigmoid -> linear 10 -> manipulate D0 } SWAP DROP",
            batch)
    trace = interpreter.execute(5)
    state = trace[-1]  # final machine state; compared against the target stack
    loss = L2Loss(state, interpreter)
    l2_loss = loss.l2_loss / batch_size
    opt = tf.train.AdamOptimizer(learning_rate=0.1)
    opt_op = opt.minimize(l2_loss)
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    # presampled = [np.random.randint(0, 10) for i in range(0, 10)]
    xs = [i for i in range(0, 10)]
    ys = [i for i in range(0, 10)]
    for epoch in range(0, 100):
        # for i in range(0, batch_size):
        # x = xs[i]
        # y = ys[i]
        # random.shuffle(xs)
        for i in range(0, 10):
            # print("-" * 50)
            # print(i)
            # random.shuffle(ys)
            # Fill every batch slot with an (x, y) pair and its target sum.
            for j in range(0, batch_size):
                # x = xs[i]
                # y = ys[i]
                x = xs[i]
                y = ys[j]
                interpreter.load_stack([x, y], j)
                loss.load_target_stack([(x + y)], j)
            # Single fetch: loss, pretty-printed stacks/pointers, and the
            # training op (its result is discarded via the trailing `_`).
            current_loss, data_stack, data_stack_pointer, return_stack, return_stack_pointer, _ = sess.run(
                [l2_loss,
                 edsm.pretty_print_buffer(state.data_stack),
                 edsm.pretty_print_value(state.data_stack_pointer),
                 edsm.pretty_print_buffer(state.return_stack),
                 edsm.pretty_print_value(state.return_stack_pointer),
                 opt_op],
                loss.current_feed_dict())
            print("I")
            print(i)
            print("Diff")
            print(sess.run(edsm.pretty_print_buffer(loss.data_stack_diff),
                           loss.current_feed_dict()))
            print(sess.run(edsm.pretty_print_value(loss.data_stack_pointer_diff),
                           loss.current_feed_dict()))
            print("D")
            print(data_stack)
            print("DP")
            print(data_stack_pointer)
            print("R")
            print(return_stack)
            print("DP")
            print(return_stack_pointer)
            print("Loss")
            print(current_loss)
def test_halt():
    """Check that test-time execution of bubble sort halts early.

    Runs a hand-written Forth bubble sort on a random sequence and asserts
    that the interpreter reports termination (``test_halt=True``) in fewer
    than 200 steps, well under the generous step budget passed to
    ``execute_test_time``.
    """
    # Runtime Forth program — whitespace-separated words; BUBBLE performs one
    # recursive bubbling pass, SORT drives it over the whole sequence.
    bubble = """ : BUBBLE DUP IF >R OVER OVER < IF SWAP THEN R> SWAP >R 1- BUBBLE R> ELSE DROP THEN ; : SORT 1- DUP 0 DO >R R@ BUBBLE R> LOOP DROP ; SORT """
    max_length = 5
    # Stack/value sizes are derived from max_length so the sequence plus its
    # length marker (and intermediate values) fit on the machine's stacks.
    interpreter = si.SimpleInterpreter(stack_size=2 * max_length + 10,
                                       value_size=2 * max_length + 1,
                                       min_return_width=25,
                                       batch_size=1,
                                       parallel_branches=True,
                                       collapse_forth=True,
                                       test_time_stack_size=2 * max_length + 10)
    interpreter.load_code(bubble, 0)
    sess = tf.Session()
    print("max steps", 10 * max_length * (max_length + 15) + 5)
    # Random input digits in [2, 11), followed by the sequence length, which
    # SORT consumes as its element count.
    input_seq = rand.randint(2, 11, max_length)
    input_seq = np.append(input_seq, len(input_seq))
    interpreter.test_time_load_stack(input_seq, 0)
    test_trace, (next_state, step) = \
        interpreter.execute_test_time(sess, 10 * max_length * (max_length + 15) + 5,
                                      use_argmax_pointers=False,
                                      use_argmax_stacks=False,
                                      test_halt=True)
    print(step)
    assert step < 200
def test_test_time_execution():
    """Smoke test for discrete (argmax-pointer) test-time execution.

    Loads the literal program ``8 2 -``, runs it for at most ten steps with
    halt detection enabled, and pretty-prints the resulting data stack,
    stack pointer and program counter.  Prints only — no assertion.
    """
    sess = tf.Session()
    interpreter = si.SimpleInterpreter(5, 20, 5, 1,
                                       collapse_forth=True,
                                       parallel_branches=True,
                                       merge_pipelines_=True)
    interpreter.load_code("8 2 -")

    run_trace, (_, n_steps) = interpreter.execute_test_time(
        sess, 10, use_argmax_pointers=True, test_halt=True)

    last = run_trace[-1]
    edsm.print_dsm_state_np(
        data_stack=last[interpreter.test_time_data_stack],
        data_stack_pointer=last[interpreter.test_time_data_stack_pointer],
        pc=last[interpreter.test_time_pc],
        interpreter=interpreter)
    print(len(run_trace))
def evaluate_specs(stack_size, value_size, min_return_width, steps, specs,
                   debug=False, parallel_branches=True):
    """Run a batch of (input, output, code) specs and assert near-zero loss.

    Each spec is a triple ``(input_stack, target_stack, forth_code)``; one
    batch slot is allocated per spec.  The interpreter executes ``steps``
    steps, a cross-entropy loss compares the final state against the target
    stacks, and the function asserts ``abs(result) < eps`` (``eps`` is a
    module-level constant defined elsewhere in the file).  With ``debug=True``
    every intermediate machine state is pretty-printed.
    """
    interpreter = si.SimpleInterpreter(stack_size, value_size, min_return_width,
                                       len(specs),
                                       parallel_branches=parallel_branches)
    # One batch slot per spec: first the code...
    for batch, (_, _, code) in enumerate(specs):
        interpreter.load_code(code, batch)
    trace = interpreter.execute(steps)
    loss = CrossEntropyLoss(trace[-1], interpreter)
    # ...then the input stack and the expected output stack.
    for batch, (input_, output_, _) in enumerate(specs):
        interpreter.load_stack(input_, batch)
        loss.load_target_stack(output_, batch)
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    result = sess.run(loss.loss, loss.current_feed_dict())
    if debug:
        # Pretty-print the full machine state after every execution step.
        for i in range(0, steps):
            print("-" * 50)
            print("Step {}".format(i))
            print("Data Stack")
            state = trace[i]
            print(
                sess.run(edsm.pretty_print_buffer(state.data_stack),
                         interpreter.current_feed_dict()))
            print("Data Stack Pointer")
            print(
                sess.run(edsm.pretty_print_value(state.data_stack_pointer),
                         interpreter.current_feed_dict()))
            print("Return Stack")
            print(
                sess.run(edsm.pretty_print_buffer(state.return_stack),
                         interpreter.current_feed_dict()))
            print("Return Stack Pointer")
            print(
                sess.run(edsm.pretty_print_value(state.return_stack_pointer),
                         interpreter.current_feed_dict()))
            # print("Heap")
            # print(sess.run(edsm.pretty_print_buffer(state.heap), interpreter.current_feed_dict()))
            print("PC")
            pc = sess.run(edsm.pretty_print_value(state.pc),
                          interpreter.current_feed_dict())
            print(pc)
            # print("Current Word:")
            # for batch in range(0, interpreter.batch_size):
            #     print("Batch {}".format(batch))
            #     for word_index in range(0, state.code_size):
            #         score = pc[batch, word_index]
            #         if score > 0.5:
            #             word = interpreter.words[interpreter.final_code[batch][word_index]]
            #             print("{} {}".format(score, word))
        print("-" * 50)
        print("Loss: {}".format(result))
        print("Target")
        print(
            sess.run(
                edsm.pretty_print_buffer(loss.data_stack_target_placeholder),
                loss.current_feed_dict()))
        print(
            sess.run(
                edsm.pretty_print_value(
                    loss.data_stack_pointer_target_placeholder),
                loss.current_feed_dict()))
    assert abs(result) < eps