def add_binary_tree_6(depth, parent_nodes, num_leafs):
    """Build the tree for two consecutive while loops below ``parent_nodes[0]``.

    While-loop 1 is a chain of ``depth`` while1/while0 branchings (built with
    ``add_binary_tree_block_6``) followed by a forced while0 node. Every
    while0 node of loop 1, including the forced one, then receives its own
    while-loop 2 chain, built the same way and also ending in a forced while0
    node.

    Only loop-2 exit nodes stay childless, so they are the leaves. Each of the
    ``depth + 1`` loop-2 chains has ``depth`` while0 exits plus one forced
    exit, i.e. ``(depth + 1) ** 2`` leaves in total, and all of them are
    counted.

    Args:
        depth: number of loop iterations unrolled for each of the two loops.
        parent_nodes: list whose first element is the node to build under;
            further elements are ignored.
        num_leafs: one-element list used as a mutable leaf counter;
            ``num_leafs[0]`` is incremented in place.
    """
    parent_node = parent_nodes[0]

    # while_loop_1: keep following the while1 branch and collect each while0
    # exit, because while_loop_2 starts below every one of them
    loop_2_roots = []
    for _i in range(depth):
        parent_node, while_node_0 = add_binary_tree_block_6(parent_node)
        loop_2_roots.append(while_node_0)

    # forced exit of while_loop_1 after `depth` iterations
    terminate_while_node_0 = MctsTreeNode(parent=parent_node,
                                          children=[],
                                          val=0)
    parent_node.children = [terminate_while_node_0]
    loop_2_roots.append(terminate_while_node_0)

    # while_loop_2 below every exit of while_loop_1
    for loop_2_node in loop_2_roots:
        for _i in range(depth):
            loop_2_node, _exit_node = add_binary_tree_block_6(loop_2_node)
            # the while0 exit of loop 2 never gets children, so it is a leaf
            # (counted the same way add_binary_tree_2 counts its while0 exits)
            num_leafs[0] += 1

        # forced exit of while_loop_2, also a leaf
        terminate_while_node_0 = MctsTreeNode(parent=loop_2_node,
                                              children=[],
                                              val=0)
        loop_2_node.children = [terminate_while_node_0]
        num_leafs[0] += 1
def add_binary_tree_7(depth, parent_nodes, num_leafs):
    """Build a while-loop tree with an if inside the loop and one after it.

    The frontier starts at ``parent_nodes``. It is expanded ``depth`` times
    with ``add_binary_tree_block_7``, which yields the if-branch nodes to
    continue from and the while0 exit node of that iteration. Every node left
    on the frontier afterwards gets a forced while0 exit. Finally, every exit
    node (regular and forced) gets a val-1/val-0 branching; those branch nodes
    are the leaves and are counted in ``num_leafs``.

    Args:
        depth: number of loop iterations to unroll.
        parent_nodes: nodes to grow the tree below.
        num_leafs: one-element list used as a mutable leaf counter.
    """
    frontier = parent_nodes
    loop_exits = []
    for _level in range(depth):
        expanded = []
        for node in frontier:
            continuations, exits = add_binary_tree_block_7(node)
            expanded.extend(continuations)
            loop_exits.extend(exits)
        frontier = expanded

    # forced loop exit below each node still on the frontier
    forced_exits = []
    for node in frontier:
        forced_exit = MctsTreeNode(parent=node, children=[], val=0)
        node.children = [forced_exit]
        forced_exits.append(forced_exit)

    # trailing branching after the loop; its two children are leaves
    for exit_node in loop_exits + forced_exits:
        branch_true = MctsTreeNode(parent=exit_node, children=[], val=1)
        branch_false = MctsTreeNode(parent=exit_node, children=[], val=0)
        exit_node.children = [branch_true, branch_false]
        num_leafs[0] += 2
def add_binary_tree_block_6(parent_node):
    """Attach a while1/while0 pair under ``parent_node``.

    Returns:
        A tuple ``(continue_node, exit_node)``: the val-1 node to keep
        building the loop chain from, and the val-0 exit node.
    """
    continue_node = MctsTreeNode(parent=parent_node, children=[], val=1)
    exit_node = MctsTreeNode(parent=parent_node, children=[], val=0)
    parent_node.children = [continue_node, exit_node]
    return continue_node, exit_node
def add_binary_tree_block_5(parent_node):
    """Attach an if1/if0 pair under ``parent_node``.

    Returns:
        A new list ``[if1_node, if0_node]`` (distinct from
        ``parent_node.children``) of the nodes to continue building from.
    """
    branches = [MctsTreeNode(parent=parent_node, children=[], val=flag)
                for flag in (1, 0)]
    parent_node.children = branches
    # hand back a separate list so callers cannot alias the children list
    return list(branches)
def add_binary_tree_block_2(parent_node, num_leafs):
    """Attach a while1/while0 pair under ``parent_node``, counting the exit.

    The val-0 (while0) node is left without children and is counted as a
    leaf in ``num_leafs[0]``.

    Returns:
        The val-1 (while1) node to continue the loop chain from.
    """
    continue_node, exit_node = (
        MctsTreeNode(parent=parent_node, children=[], val=flag)
        for flag in (1, 0))
    parent_node.children = [continue_node, exit_node]
    num_leafs[0] += 1
    return continue_node
def grow_type5_loc(node, max_depth):
    """Give ``node`` an "if1"/"if0" pair of children one level deeper.

    ``max_depth`` is not used here; it is presumably accepted so that all
    grow_* functions share a signature (confirm against the dispatcher).
    """
    node.children = [
        MctsTreeNode(node_type=kind,
                     parent=node,
                     children=[],
                     val=flag,
                     depth=(node.depth + 1))
        for kind, flag in (("if1", 1), ("if0", 0))
    ]
def grow_type3_loc(node, max_depth):
    """Give ``node`` a "while1"/"while0" pair of children one level deeper.

    ``max_depth`` is not used here; it is presumably accepted so that all
    grow_* functions share a signature (confirm against the dispatcher).
    """
    node.children = [
        MctsTreeNode(node_type=kind,
                     parent=node,
                     children=[],
                     val=flag,
                     depth=(node.depth + 1))
        for kind, flag in (("while1", 1), ("while0", 0))
    ]
def generate_bin_tree(tree_type, depth, use_subtree_flag):
    """Eagerly build the complete MCTS tree for ``tree_type``.

    The root ("root") gets the location nodes from ``get_loc_nodes``: a
    single "fixed_loc_dir" node when ``use_subtree_flag`` is set, otherwise
    one node per direction choice. The tree builder returned by
    ``get_tree_gen_func(tree_type)`` is then run below each location node.

    Returns:
        ``(root, leaf_count)``, where ``leaf_count`` is the total accumulated
        by the tree builders.
    """
    leaf_counter = [0]
    root = MctsTreeNode(parent=None, children=[], val="root")
    root.children = get_loc_nodes(root, tree_type, use_subtree_flag)

    for loc_node in root.children:
        build_subtree = get_tree_gen_func(tree_type)
        build_subtree(depth=depth,
                      parent_nodes=[loc_node],
                      num_leafs=leaf_counter)
        print("created tree for {}".format(loc_node.val))

    return root, leaf_counter[0]
def add_binary_tree_block_4(parent_node, num_leafs):
    """Attach one loop iteration whose body holds a nested a/bc branching.

    Under ``parent_node`` a while1/while0 pair is added; the while0 node is a
    leaf and is counted in ``num_leafs[0]``. The while1 node gets an a1/a0
    pair, and the a0 node gets a bc1/bc0 pair.

    Returns:
        ``[a1_node, bc1_node, bc0_node]``: the nodes to continue from.
    """
    def _split(node):
        # attach a val-1 / val-0 pair of children to `node`
        one = MctsTreeNode(parent=node, children=[], val=1)
        zero = MctsTreeNode(parent=node, children=[], val=0)
        node.children = [one, zero]
        return one, zero

    loop_body, _loop_exit = _split(parent_node)
    num_leafs[0] += 1
    a_true, a_false = _split(loop_body)
    bc_true, bc_false = _split(a_false)
    return [a_true, bc_true, bc_false]
def get_loc_nodes(parent_node, tree_type, use_subtree_flag):
    """Create one location child of ``parent_node`` per direction choice.

    With ``use_subtree_flag`` set, a single "fixed_loc_dir" node is created.
    Otherwise the choices come from ``get_direction_choices(tree_type)``.
    The nodes are returned but not attached to ``parent_node.children``.
    """
    choices = (["fixed_loc_dir"] if use_subtree_flag
               else get_direction_choices(tree_type))
    return [MctsTreeNode(parent=parent_node, children=[], val=choice)
            for choice in choices]
def grow_type2_while1(node, max_depth):
    """Grow the children of a while node one level deeper.

    Below ``max_depth`` the node gets "while1" (continue) and "while0" (exit)
    children. At or beyond ``max_depth`` only the terminating "while0" child
    is added.
    """
    grown = []
    if node.depth < max_depth:
        grown.append(MctsTreeNode(node_type="while1",
                                  parent=node,
                                  children=[],
                                  val=1,
                                  depth=(node.depth + 1)))
    # the while0 exit exists at every depth
    grown.append(MctsTreeNode(node_type="while0",
                              parent=node,
                              children=[],
                              val=0,
                              depth=(node.depth + 1)))
    node.children = grown
def add_binary_tree_2(depth, parent_nodes, num_leafs):
    """Build a single while-loop chain of ``depth`` iterations.

    The chain hangs below ``parent_nodes[0]``. Each iteration's while0 exit
    is counted as a leaf by ``add_binary_tree_block_2``. A forced while0 exit
    is appended after the last iteration and counted too.
    """
    tail = parent_nodes[0]
    for _step in range(depth):
        tail = add_binary_tree_block_2(tail, num_leafs)

    forced_exit = MctsTreeNode(parent=tail, children=[], val=0)
    tail.children = [forced_exit]
    num_leafs[0] += 1
def get_loc_nodes_onthefly(parent_node, use_subtree_flag):
    """Create "loc" nodes one level below ``parent_node`` (lazy-tree variant).

    With ``use_subtree_flag`` set, a single "fixed_loc_dir" node is created.
    Otherwise there is one node per direction choice. The nodes are returned
    but not attached to ``parent_node.children``.
    """
    if use_subtree_flag:
        choices = ["fixed_loc_dir"]
    else:
        # NOTE(review): get_loc_nodes passes tree_type to
        # get_direction_choices; confirm the no-argument call here is intended
        choices = get_direction_choices()

    return [MctsTreeNode(node_type="loc",
                         parent=parent_node,
                         children=[],
                         val=choice,
                         depth=(parent_node.depth + 1))
            for choice in choices]
def add_binary_tree_3(depth, parent_nodes, num_leafs):
    """Build a while-loop tree with an if inside the loop, level by level.

    Each level expands every frontier node with ``add_binary_tree_block_3``,
    which counts that iteration's while0 exit as a leaf. After ``depth``
    levels, every remaining frontier node gets a forced while0 exit that is
    counted as a leaf too.
    """
    frontier = parent_nodes
    for _level in range(depth):
        expanded = []
        for node in frontier:
            expanded.extend(add_binary_tree_block_3(node, num_leafs))
        frontier = expanded

    for node in frontier:
        forced_exit = MctsTreeNode(parent=node, children=[], val=0)
        node.children = [forced_exit]
        num_leafs[0] += 1
def add_binary_tree_block_7(parent_node):
    """Attach one loop iteration whose body is an if branching.

    Under ``parent_node`` a while1/while0 pair is added; the while1 node then
    gets an if1/if0 pair.

    Returns:
        ``([if1_node, if0_node], [while0_node])``: the nodes to continue from
        and the loop-exit node of this iteration.
    """
    def _split(node):
        # attach a val-1 / val-0 pair of children to `node`
        one = MctsTreeNode(parent=node, children=[], val=1)
        zero = MctsTreeNode(parent=node, children=[], val=0)
        node.children = [one, zero]
        return one, zero

    loop_body, loop_exit = _split(parent_node)
    if_true, if_false = _split(loop_body)
    return [if_true, if_false], [loop_exit]
def add_binary_tree_block_3(parent_node, num_leafs):
    """Attach one loop iteration whose body is an if branching.

    Under ``parent_node`` a while1/while0 pair is added; the while0 node is a
    leaf and is counted in ``num_leafs[0]``. The while1 node gets an if1/if0
    pair.

    Returns:
        ``[if1_node, if0_node]``: the nodes to continue from.
    """
    def _split(node):
        # attach a val-1 / val-0 pair of children to `node`
        one = MctsTreeNode(parent=node, children=[], val=1)
        zero = MctsTreeNode(parent=node, children=[], val=0)
        node.children = [one, zero]
        return one, zero

    loop_body, _loop_exit = _split(parent_node)
    num_leafs[0] += 1
    if_true, if_false = _split(loop_body)
    return [if_true, if_false]
def train_mcts_tree(tree_type, num_train_iterations, input_code,
                    input_task_data, env_type, env_type_initial, Z, expl_const,
                    program, depth):
    """Run ``num_train_iterations`` MCTS iterations over a lazily grown tree.

    Each iteration walks from the root to a node without children. While no
    node with ``num_iterations == 0`` has been picked, children are chosen by
    ``select_node`` (selection stage). The first such node fixes
    ``backprop_idx``, the index in the walk from which backpropagation
    updates apply. After that, children are chosen with ``random=True``
    (simulation stage). Every visited node is grown with ``grow_node`` so the
    next step has children to choose from.

    If ``select_node`` returns ``None``, the current node is marked
    ``valid = False``. The walk is then counted as invalid and retried
    without advancing the iteration counter.

    The last node of a completed walk caches its ``score_dict``. A later walk
    ending at the same node reuses that score instead of calling
    ``score_mcts_trace`` again.

    Returns:
        ``(score_table_hits, unique_traces, mcts_info, cnt_node_grown,
        all_traces)``:
        ``score_table_hits`` counts cached-score reuses;
        ``unique_traces`` holds the walks that were scored fresh;
        ``mcts_info`` summarises the run counters;
        ``cnt_node_grown`` is the one-element counter passed to ``grow_node``;
        ``all_traces`` is currently always empty.
    """
    # create root node of mcts tree
    root = MctsTreeNode(node_type="root",
                        parent=None,
                        children=[],
                        val="root",
                        depth=0)

    # number of times a score was reused from a cached leaf
    score_table_hits = 0
    unique_traces = []
    # kept in the return value for callers; nothing appends to it currently
    all_traces = []
    cnt_invalid_traces = 0
    cnt_node_grown = [0]
    # 1-based counter of completed (valid) iterations
    t = 1

    # grow the root once; deeper nodes are grown as walks reach them
    grow_node(root, tree_type, cnt_node_grown, depth)

    # The bar advances once per completed iteration, so its total is exactly
    # num_train_iterations. The context manager closes it even on error.
    with tqdm(total=num_train_iterations) as pbar:
        # NOTE(review): invalid walks do not advance t, so this loop never
        # terminates if every walk reaches a node without valid children.
        while t <= num_train_iterations:
            # the parent's num_iterations is used in node selection
            root.num_iterations += 1

            curr_node = root
            mcts_trace = []
            # index in mcts_trace (inclusive) where backpropagation starts,
            # i.e. the node from which simulation begins
            backprop_idx = None
            trace_is_valid = True

            # walk from the root down to a node without children
            while curr_node.children != []:
                if backprop_idx is None:
                    # selection stage
                    selected_node = select_node(curr_node.children, mcts_trace,
                                                input_code, tree_type, env_type,
                                                env_type_initial, expl_const,
                                                program)
                else:
                    # simulation stage: pick children at random
                    selected_node = select_node(curr_node.children,
                                                mcts_trace,
                                                input_code,
                                                tree_type,
                                                env_type,
                                                env_type_initial,
                                                expl_const,
                                                program,
                                                random=True)

                if selected_node is None:
                    curr_node.valid = False
                    trace_is_valid = False
                    break

                # expansion: the first never-visited node. Its index equals
                # len(mcts_trace) because it is appended right below.
                if backprop_idx is None and selected_node.num_iterations == 0:
                    backprop_idx = len(mcts_trace)

                mcts_trace.append(selected_node)
                curr_node = selected_node
                # grow curr_node for the next selection / simulation step
                grow_node(curr_node, tree_type, cnt_node_grown, depth)

            if not trace_is_valid:
                # retry without incrementing the iteration counter t
                cnt_invalid_traces += 1
                continue

            # compute the score of the walk, reusing the cached leaf score
            # if this leaf was picked before
            leaf = mcts_trace[-1]
            if leaf.num_picked == 0:
                score_dict = score_mcts_trace(mcts_trace, tree_type,
                                              input_code, input_task_data,
                                              env_type, env_type_initial, Z,
                                              program)
                leaf.score_dict = score_dict
                unique_traces.append(mcts_trace)
                score = score_dict["score_total"]
            else:
                score_table_hits += 1
                score = leaf.score_dict["score_total"]
            leaf.num_picked += 1

            # backpropagation: update average scores, then iteration counts
            update_node_scores(mcts_trace, score, backprop_idx)
            update_node_num_iterations(mcts_trace, backprop_idx)

            t += 1
            pbar.update(1)

    mcts_info = {
        "mcts_run_totaliters": num_train_iterations,
        "mcts_run_num_unique_traces": len(unique_traces),
        "mcts_run_num_repeat_traces": score_table_hits,
        "mcts_run_num_invalid_traces": cnt_invalid_traces
    }

    return score_table_hits, unique_traces, mcts_info, cnt_node_grown, all_traces