def decorator(loop_body):
    """Run ``loop_body`` for ``n_loops`` iterations spread over ``n_threads`` threads.

    NOTE(review): ``n_loops``, ``n_threads``, ``n_parallel``, ``initializer``,
    ``reducer`` and ``thread_mem_req`` are free variables supplied by the
    enclosing function, which is not visible here.

    The threads using rows ``remainder .. n_threads - 1`` of ``args`` each run
    ``thread_rounds`` iterations; the threads using rows ``0 .. remainder - 1``
    run ``thread_rounds + 1``. The start offsets are chosen so that the ranges
    are contiguous and together cover ``0 .. n_loops - 1``.
    Each row of ``args`` passes per-thread parameters:

    * column 0 -- index of the thread's first iteration,
    * column 1 -- address of the thread's reduction-state array,
    * columns 2.. -- thread-local ``regint`` memory (size ``thread_mem_req[regint]``).

    After all threads are joined, each thread's partial state is folded into
    ``state`` with ``reducer``.

    :param loop_body: called as ``loop_body(index)``, or as
        ``loop_body(index, thread_mem)`` when ``thread_mem_req`` is non-empty.
    :returns: a zero-argument callable returning ``untuplify(state)``.
    :raises CompilerError: if ``thread_mem_req`` requests a type other than
        ``regint``.
    """
    # Floor division: Python 3 ``/`` would give a float here, breaking the
    # start-offset arithmetic and iteration counts below.
    thread_rounds = n_loops // n_threads
    remainder = n_loops % n_threads
    for t in thread_mem_req:
        if t != regint:
            raise CompilerError('Not implemented for other than regint')
    args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci')
    state = tuple(initializer())
    # Element type of the per-thread state arrays, taken from the *initial*
    # state: this is the type the tapes (f) and make_array() use, so the same
    # type must be used when reading the arrays back for the reduction.
    state_type = type(state[0]) if state else cint

    def f(inc):
        # Body of one thread tape; ``inc`` is 1 for threads running an extra
        # iteration. Each get_arg() call is kept in the original order.
        if thread_mem_req:
            thread_mem = Array(thread_mem_req[regint], regint,
                               args[get_arg()].address + 2)
        mem_state = Array(len(state), state_type, args[get_arg()][1])
        base = args[get_arg()][0]

        @map_reduce_single(n_parallel, thread_rounds + inc,
                           initializer, reducer, mem_state)
        def loop_iteration(i):
            if thread_mem_req:
                return loop_body(base + i, thread_mem)
            return loop_body(base + i)

    prog = get_program()
    threads = []
    if thread_rounds:
        # Threads with the standard share use args rows remainder..n_threads-1.
        tape = prog.new_tape(f, (0,), 'multithread')
        for i in range(n_threads - remainder):
            mem_state = make_array(initializer())
            args[remainder + i][0] = i * thread_rounds
            if len(mem_state):
                args[remainder + i][1] = mem_state.address
            threads.append(prog.run_tape(tape, remainder + i))
    if remainder:
        # Threads with one extra iteration use args rows 0..remainder-1 and
        # start after the iterations covered by the threads above.
        tape1 = prog.new_tape(f, (1,), 'multithread1')
        for i in range(remainder):
            mem_state = make_array(initializer())
            args[i][0] = (n_threads - remainder + i) * thread_rounds + i
            if len(mem_state):
                args[i][1] = mem_state.address
            threads.append(prog.run_tape(tape1, i))
    for thread in threads:
        prog.join_tape(thread)
    if state:
        # Same order as the threads were launched.
        rows = []
        if thread_rounds:
            rows += range(remainder, n_threads)
        if remainder:
            rows += range(remainder)
        for row in rows:
            state = reducer(Array(len(state), state_type, args[row][1]), state)

    def returner():
        return untuplify(state)
    return returner
def test_argmax():
    """Check argmax_over_fracs on three hand-built secret matrices.

    Every input row is a triple. In each case the expected row has the largest
    first-column / second-column ratio (9/4, 2/3 and 1/9 respectively).
    NOTE(review): presumably that is the fraction argmax_over_fracs compares;
    confirm against its definition.
    """
    # 4x3 input: the winner is the second row.
    mat = input_as_mat([[1, 1, 0], [9, 4, 1], [2, 1, 2], [1, 2, 3]])
    result = argmax_over_fracs(mat)
    runtime_assert_arr_equals([9, 4, 1], result, default_test_name())

    # 64x3 input: row k is (1, k + 2, k); row 30 is then overwritten.
    n_rows = 64
    mat = Matrix(n_rows, 3, sint)
    for k in range(n_rows):
        mat[k][0], mat[k][1], mat[k][2] = 1, k + 2, k
    mat[30][0], mat[30][1], mat[30][2] = 2, 3, 3
    result = argmax_over_fracs(mat)
    runtime_assert_arr_equals([2, 3, 3], result, default_test_name())

    # 2x3 input: the second row has ratio 0/1, so the first row is expected.
    mat = input_as_mat([[1, 9, 0], [0, 1, 1]])
    result = argmax_over_fracs(mat)
    runtime_assert_arr_equals([1, 9, 0], result, default_test_name())