Code example #1
File: storm.py  Project: RF5/scudstorm
def train(env, n_envs, no_op_vec, resume_training):
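    """Top-level GA training loop: mutate truncation survivors into the next
    generation, evaluate fitness against a reference bot, run an elite
    battle-royale, and periodically rotate and save the refbot queue."""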
    print(str('=' * 50) + '\n' + 'Initializing agents\n' + str('=' * 50))
    ##############################
    ## Summary buckets
    #failed_episodes = 0
    #early_episodes = 0
    refbot_back_ind = 1
    elite_overthrows = 0
    elite = None
    starting_gen = 0  # default starting generation number; overwritten when resuming
    ## Setting up logs
    writer = summary.create_file_writer(util.get_logdir('train12A'),
                                        flush_millis=10000)
    writer.set_as_default()
    global_step = tf.train.get_or_create_global_step()

    ## TODO: change agent layers to use xavier initializer
    agents = [Scud(name=str(i), debug=scud_debug) for i in range(n_population)]
    total_steps = 0

    elite_moving_average = metrics.MovingAverage(
        elite_score_moving_avg_periods)
    next_generation = [
        Scud(name=str(i) + 'next', debug=scud_debug)
        for i in range(n_population)
    ]

    refbot_queue = [
        Scud(name='refbot' + str(i), debug=scud_debug)
        for i in range(refbot_queue_length)
    ]
    for i, bot in enumerate(refbot_queue):
        bot.refbot_position = i
    refbot = refbot_queue[0]

    ## DOES NOT WORK WITH EAGER EXECUTION
    # with summary.always_record_summaries():
    #     summary.graph(agents[0].model.graph)
    total_s = Stopwatch()
    ########################################
    ## Restoring from last training session
    if resume_training:
        # loading up config from last train finish
        print("Restoring progress config from last run...")
        config_path = os.path.join(util.get_savedir(), 'progress.json')
        with open(config_path, 'r') as f:
            conf = json.load(f)

        starting_gen = conf['gen_at_end'] + 1
        elite_overthrows = conf['elite_overthrows']
        total_steps = conf['total_steps']
        total_s.startime = conf['clock_start_time']
        global_step.assign(starting_gen)

        # Loading truncs, elite and refbot
        print(
            str('=' * 50) + '\n' + '>> STORM >> Resuming training.\n' +
            str('=' * 50))
        trunc_names = os.listdir(util.get_savedir('truncFinals'))
        trunc_names = sorted(trunc_names, reverse=True)

        for j in range(trunc_size):
            if j < len(trunc_names):
                agents[j + 1].load(util.get_savedir('truncFinals'),
                                   trunc_names[j])
            else:
                print("Skipping loading trunc agent for j = ", j)

        refbot_names = os.listdir(util.get_savedir('refbots'))
        refbot_names = sorted(refbot_names, reverse=False)
        refbot_q_names = refbot_names[-refbot_queue_length:]
        # sec = 0
        # for i in range(5, refbot_queue_length):
        #     refbot_queue[i].load(util.get_savedir('refbots'), refbot_q_names[sec])
        #     refbot_queue[i].refbot_position = i
        #     sec = sec + 1
        for i in range(refbot_queue_length):
            refbot_queue[i].load(util.get_savedir('refbots'),
                                 refbot_q_names[i])
            refbot_queue[i].refbot_position = i

        elite = agents[0]
        elite.load(util.get_savedir(), 'elite')

        print(">> STORM >> Successfully restored from last checkpoints")

    print(
        str('=' * 50) + '\n' + 'Beginning training (at gen ' +
        str(starting_gen) + ')\n' + str('=' * 50))
    s = Stopwatch()

    #partition_stopwatch = Stopwatch()
    for g in range(starting_gen, starting_gen + n_generations):
        #####################
        ## Hyperparameter annealing
        # gamma = gamma_func((g+1)/n_generations)

        #####################
        ## GA Algorithm
        for i in range(n_population):
            if g == 0:
                break
            else:
                kappa = random.sample(agents[0:trunc_size], 1)
                mutate(kappa[0], next_generation[i], g)
        #partition_stopwatch.lap('mutation')
        # swap agents and the next gen's agents. i.e set next gen agents to be current agents to evaluate
        agents, next_generation = next_generation, agents

        # evaluate fitness on each agent in population
        try:
            agents, additional_steps, rollout_info = evaluate_fitness(
                env, agents, refbot, debug=False)
        except KeyboardInterrupt as e:
            print(
                "Received keyboard interrupt {}. Saving and then closing env.".
                format(e))
            break
        total_steps += additional_steps

        # sort them based on final discounted reward
        agents = sorted(agents,
                        key=lambda agent: agent.fitness_score,
                        reverse=True)

        #partition_stopwatch.lap('fitness evaluation + sorting')

        ##################################
        ## Summary information
        with summary.always_record_summaries():
            sc_vec = [a.fitness_score for a in agents]
            summary.scalar('rewards/mean', np.mean(sc_vec))
            summary.scalar('rewards/max', agents[0].fitness_score)
            summary.scalar('rewards/min', agents[-1].fitness_score)
            summary.scalar('rewards/var', np.var(sc_vec))
            summary.scalar('rewards/trunc_mean', np.mean(sc_vec[:trunc_size]))
            summary.scalar('hyperparameters/gamma', gamma)

            summary.scalar('main_rollout/agentWins', rollout_info['agentWins'])
            summary.scalar('main_rollout/refbotWins',
                           rollout_info['refbotWins'])
            summary.scalar('main_rollout/ties', rollout_info['ties'])
            summary.scalar('main_rollout/early_eps', rollout_info['early_eps'])
            summary.scalar('main_rollout/failed_eps',
                           rollout_info['failed_eps'])

            if len(rollout_info['ep_lengths']) > 0:
                mean_ep_lengg = np.mean(rollout_info['ep_lengths'])
                summary.histogram('main_rollout/ep_lengths',
                                  rollout_info['ep_lengths'])
                summary.scalar('main_rollout/mean_ep_length', mean_ep_lengg)
                print("Mean ep length: ", mean_ep_lengg)

            if len(rollout_info['agent_actions']) > 0:
                summary.histogram('main_rollout/agent_a0',
                                  rollout_info['agent_actions'])
                summary.histogram('main_rollout/agent_a0_first15steps',
                                  rollout_info['agent_early_actions'])

        print("Main stats: agent wins - {} | refbot wins - {} | Early - {}".
              format(rollout_info['agentWins'], rollout_info['refbotWins'],
                     rollout_info['early_eps']))
        for a in agents[:5]:
            print(a.name, " with fitness score: ", a.fitness_score)

        ############################################
        ## Evaluating elite candidates to find elite

        #partition_stopwatch.lap('summaries 1')
        # setup next generation parents / elite agents
        if g == 0:
            if not resume_training:
                elite_candidates = set(agents[0:n_elite_in_royale])
            else:
                elite_candidates = set(agents[0:n_elite_in_royale - 1]) | set([
                    elite,
                ])
        else:
            elite_candidates = set(agents[0:n_elite_in_royale - 1]) | set([
                elite,
            ])
        # finding next elite by battling proposed elite candidates for some additional rounds
        #print("Evaluating elite agent...")
        inds = np.random.random_integers(0, refbot_queue_length - 1, 4)
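        # note: np.random.random_integers samples inclusive of both endpoints and
        # is deprecated in newer NumPy; np.random.randint(0, refbot_queue_length, 4)
        # would be the modern equivalent of the call above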
        refbots_for_elite = [refbot_queue[lolno] for lolno in inds]
        elo_ags, additional_steps, rollout_info = evaluate_fitness(
            env,
            elite_candidates,
            refbots_for_elite,
            runs=elite_additional_episodes)
        total_steps += additional_steps
        elo_ags = sorted(elo_ags,
                         key=lambda agent: agent.fitness_score,
                         reverse=True)
        if elite != elo_ags[0]:
            elite_overthrows += 1
        elite = elo_ags[0]

        #partition_stopwatch.lap('elite battle royale')

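        # Move the new elite to the front of the population; the winner may be
        # last generation's elite, which is not in `agents`, hence the
        # ValueError fallback below.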
        try:
            agents.remove(elite)
            agents = [
                elite,
            ] + agents
        except ValueError:
            agents = [
                elite,
            ] + agents[:len(agents) - 1]

        print("Elite stats: agent wins - {} | refbot wins - {} | Early - {}".
              format(rollout_info['agentWins'], rollout_info['refbotWins'],
                     rollout_info['early_eps']))
        for i, a in enumerate(elo_ags):
            print('Elite stats: pos', i, '; name: ', a.name,
                  " ; fitness score: ", a.fitness_score)

        ############################
        ## Summary information 2
        with summary.always_record_summaries():
            elite_moving_average.push(elite.fitness_score)
            summary.scalar('rewards/elite_moving_average',
                           elite_moving_average.value())
            summary.scalar('rewards/elite_score', elite.fitness_score)
            summary.scalar('rewards/stable_mean',
                           np.mean([a.fitness_score for a in elo_ags]))
            summary.scalar('time/wall_clock_time', total_s.deltaT())
            summary.scalar('time/single_gen_time', s.deltaT())
            summary.scalar('time/total_game_steps', total_steps)
            summary.scalar('time/elite_overthrows', elite_overthrows)

            summary.scalar('elite_rollout/agentWins',
                           rollout_info['agentWins'])
            summary.scalar('elite_rollout/refbotWins',
                           rollout_info['refbotWins'])
            summary.scalar('elite_rollout/ties', rollout_info['ties'])
            summary.scalar('elite_rollout/early_eps',
                           rollout_info['early_eps'])
            summary.scalar('elite_rollout/failed_eps',
                           rollout_info['failed_eps'])

            if len(rollout_info['ep_lengths']) > 0:
                mean_ep_lengE = np.mean(rollout_info['ep_lengths'])
                summary.histogram('elite_rollout/ep_lengths',
                                  rollout_info['ep_lengths'])
                summary.scalar('elite_rollout/mean_ep_length', mean_ep_lengE)
                print("Elite mean ep length: ", mean_ep_lengE)

            if len(rollout_info['agent_actions']) > 0:
                summary.histogram('elite_rollout/agent_a0',
                                  rollout_info['agent_actions'])
                summary.histogram('elite_rollout/agent_a0_first15steps',
                                  rollout_info['agent_early_actions'])

            summary.scalar('hyperparameters/refbot_back_ind', refbot_back_ind)

        #################################
        ## Replacing reference bot
        if g % replace_refbot_every == 0:
            toback = refbot
            del refbot_queue[0]

            refbot_back_ind = np.random.random_integers(
                0, refbot_queue_length - 1)
            print(
                str('=' * 50) + '\n' +
                '>> STORM >> Upgrading refbot (to pos ' +
                str(refbot_back_ind) + ') now.\n' + str('=' * 50))
            #good_params = agents[trunc_size-1].get_flat_weights()
            good_params = agents[np.random.random_integers(
                0, trunc_size - 1)].get_flat_weights()
            toback.set_flat_weights(good_params)

            refbot_queue.append(toback)
            #refbot = refbot_queue[0]
            ################
            ## Sampling the refbot uniformly from the past <refbot_queue_length> generations' agents
            refbot = refbot_queue[refbot_back_ind]

            for meme_review, inner_refbot in enumerate(refbot_queue):
                inner_refbot.refbot_position = meme_review

            #for bot in refbot_queue:
            #    print("Bot ", bot.name, ' now has refbot pos: ', bot.refbot_position)

        #################################
        ## Saving agents periodically
        if g % save_elite_every == 0 and g != 0:
            elite.save(util.get_savedir('checkpoints'),
                       'gen' + str(g) + 'elite')
            if refbot_queue_length < 5:
                for refAgent in refbot_queue:
                    refAgent.save(
                        util.get_savedir('refbots'),
                        'gen' + str(g) + 'pos' + str(refAgent.refbot_position))

            if trunc_size < 5:
                for i, truncAgent in enumerate(agents[:trunc_size]):
                    truncAgent.save(util.get_savedir('truncs'),
                                    'gen' + str(g) + 'agent' + str(i))

        global_step.assign_add(1)

        print(
            str('=' * 50) + '\n' + 'Generation ' + str(g) + '. Took ' +
            s.delta + ' (total: ' + total_s.delta + ')\n' + str('=' * 50))
        s.reset()
        #partition_stopwatch.lap('summaries 2 and updates/saves')

    ###############################
    ## Shutdown behavior

    #print("PARTITION STOPWATCH RESULTS:") # last i checked runtime is *dominated*
    #partition_stopwatch.print_results()
    elite.save(util.get_savedir(), elite_savename)
    summary.flush()
    for i, ag in enumerate(agents[:trunc_size]):
        ag.save(util.get_savedir('truncFinals'), 'finalTrunc' + str(i))

    print("End refbot queue: ", len(refbot_queue))
    for identity, refAgent in enumerate(refbot_queue):
        refAgent.save(util.get_savedir('refbots'),
                      'finalRefbot{:03d}'.format(identity))

    ##########################
    ## Saving progress.json
    conf = {}
    conf['gen_at_end'] = g
    conf['gamma_at_end'] = gamma
    conf['elite_overthrows'] = elite_overthrows
    conf['total_steps'] = total_steps
    conf['clock_start_time'] = total_s.startime
    path = os.path.join(util.get_savedir(), 'progress.json')
    with open(path, 'w') as config_file:
        config_file.write(json.dumps(conf))
    print(">> STORM >> Saved progress.config to: ", path)
Code example #2
class Env:
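    '''File-based environment wrapper around one game-runner jar plus its
    paired reference bot. All signalling with the jar wrappers happens via
    small text files inside self.run_path (see setup_directory and step).'''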

    def __init__(self, name, debug):
        self.name = name
        self.debug = debug
        self.setup_directory()
        self.score = 0
        self.score_delta = 0
        # setting up jar runner
        self.needs_reset = True
        self.pid = None
        self.done = False
        self.prev_obs = get_initial_obs(1)[0][0]
        self.clock = Stopwatch()
        self.step_num = 0

    def setup_directory(self):
        # creates the dirs responsible for this env, 
        # and moves a copy of the runner and config to that location
        
        print("Setting up file directory for " + self.name + " with pid " + str(os.getpid()))
        basedir = os.path.dirname(os.path.abspath(__file__)) # now in scudstorm dir
        self.run_path = os.path.join(basedir, 'runs', self.name)
        if os.path.isdir(self.run_path):
            shutil.rmtree(self.run_path)
        self.wrapper_path = os.path.join(self.run_path, 'jar_wrapper.py')
        os.makedirs(self.run_path, exist_ok=True)
        jarpath = os.path.join(basedir, jar_name)
        copy2(jarpath, self.run_path)
        config_path = os.path.join(basedir, config_name)
        copy2(config_path, self.run_path)
        wrapper_path = os.path.join(basedir, 'common', 'jar_wrapper.py')
        copy2(wrapper_path, self.run_path)
        botdir = os.path.join(basedir, bot_file_name)
        copy2(botdir, self.run_path)
        copy2(os.path.join(basedir, game_config_name), self.run_path)

        # Copying over reference bot
        self.refbot_path = os.path.join(self.run_path, 'refbot')

        if os.path.isdir(self.refbot_path):
            shutil.rmtree(self.refbot_path)
        refbotdir = os.path.join(basedir, 'refbot')
        shutil.copytree(refbotdir, self.refbot_path)

        self.in_file = os.path.join(self.run_path, wrapper_out_filename)
        self.state_file = os.path.join(self.run_path, state_name)
        self.bot_file = os.path.join(self.run_path, bot_file_name)
        self.proc = None
        self.refenv = RefEnv(self, debug=self.debug)
        
        with open(self.in_file, 'w') as f:
            f.write('0')
        # run path should now have the jar, config and jar wrapper files

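    # Handshake assumed throughout this class (inferred from this file, not
    # from jar_wrapper.py itself): '0' marks the idle state written above,
    # step() writes '2' to request a new turn, and the jar wrapper is expected
    # to write '1' back to the same file once the turn has been processed.
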
    def step(self, action, ref_act):
        #############################
        ## Maintenance of the process
        if self.needs_reset:
            self.reset()

        #######################
        ## Debug messages
        if self.debug:
            with open(os.path.join(self.run_path, 'mylog.txt'), 'a') as f:
                f.write(str(time.time()) + "\t-->Wanting to do op:!!!\t" + str(action) + '\t' + str(ref_act) + '\n')

            with open(os.path.join(self.refbot_path, 'mylog.txt'), 'a') as f:
                f.write(str(time.time()) + "\t-->Wanting to do op:!!!\t" + str(ref_act) + '\n')
        ep_info = {}
        #######################
        ## Writing actions
        x2, y2, build2 = ref_act
        write_prep_action(x2, y2, build2, path=self.refbot_path, debug=self.debug)

        x, y, build = action
        write_prep_action(x, y, build, path=self.run_path, debug=self.debug)

        #######################
        ## Signalling to jar wrappers to begin their running step
        with open(self.in_file, 'w') as f:
            # we want start of a new step
            if self.debug:
                print(">> pyenv {} >> writing 2 to file {}".format(self.name, self.in_file))
            f.write('2')

        with open(self.refenv.in_file, 'w') as f:
            # we want start of a new step
            if self.debug:
                print(">> pyenv {} >> writing 2 to file {}".format(self.refenv.name, self.refenv.in_file))
            f.write('2')

        #######################
        ## Checking if episode ended early
        if self.proc.poll() is not None and self.done:
            # env ended last step, so reset:
            if self.debug:
                print(">> PYENV ", self.name ," >>  Ended early")
            cntrl_obj = ControlObject('EARLY')
            tp = np.concatenate([np.asarray([cntrl_obj,]), np.asarray([cntrl_obj,])], axis=-1)
            return tp, np.concatenate([np.asarray([cntrl_obj,]), np.asarray([cntrl_obj,])], axis=-1), ep_info

        #######################
        ## Taking step

        # Vars for Env
        obs = None
        should_load_obs = False
        reward = None
        # Vars for ref env
        ref_obs = None
        should_load_obs2 = False
        ref_reward = None  # initialised up front so the dense-reward return below cannot raise NameError
        # Waiting for responses from the jar wrappers
        stopw = Stopwatch()
        failure = False
        while True:
            if not should_load_obs:
                with open(self.in_file, 'r') as ff:
                    k = ff.read()
                    try:
                        k = int(k)
                    except ValueError:
                        continue
                    if k == 1:
                        #print("just wrote 0 to the ", self.out_file)
                        # a new turn has just been processed
                        should_load_obs = True

            if not should_load_obs2:
                with open(self.refenv.in_file, 'r') as ff:
                    k2 = ff.read()
                    try:
                        k2 = int(k2)
                    except ValueError:
                        continue
                    if k2 == 1:
                        #print("just wrote 0 to the ", self.out_file)
                        # a new turn has just been processed
                        should_load_obs2 = True

            if should_load_obs and should_load_obs2:
                break
            
            if self.proc.poll() is not None and not self.done:
                #ep ended early.
                if self.debug:
                    print("PYENV: >> GAME ENDING EARLY FOR THE FIRST TIME")
                self.done = True

                valid, reason = is_valid_action(action, self.prev_obs)
                obs = self.load_state()
                self.prev_obs = obs
                
                ep_info['n_steps'] = self.step_num

                ep_info['valid'] = bool(valid)

                ref_obs = self.refenv.load_state()
                if obs['players'][0]['playerType'] == 'A':
                    a_hp = obs['players'][0]['health']
                    b_hp = obs['players'][1]['health']
                else:
                    a_hp = obs['players'][1]['health']
                    b_hp = obs['players'][0]['health']
                k = np.asarray([obs,])
                u = np.asarray([ref_obs,])
                return_obs = np.concatenate([k, u], axis=-1)
                if reward_mode == 'dense':
                    win_reward = dense_win_reward
                    lose_reward = -1 * dense_win_reward
                else:
                    win_reward = binary_win_reward
                    lose_reward = -1 * binary_win_reward
                if a_hp > b_hp:
                    # player a wins
                    ep_info['winner'] = 'A'
                    return return_obs, np.concatenate([np.asarray([win_reward,]), np.asarray([lose_reward,])], axis=-1), ep_info
                elif a_hp < b_hp:
                    ep_info['winner'] = 'B'
                    return return_obs, np.concatenate([np.asarray([lose_reward,]), np.asarray([win_reward,])], axis=-1), ep_info
                else:
                    ep_info['winner'] = 'TIE'
                    return return_obs, np.concatenate([np.asarray([0.0,]), np.asarray([0.0,])], axis=-1), ep_info

            if stopw.deltaT() > 3:
                # we have waited more than 3s, game clearly ended
                self.needs_reset = True
                failure = True
                print('pyenv: env ' + str(self.name) + ' with pid ' + str(self.pid) + ' encountered error. (', should_load_obs, ',',should_load_obs2, ')' , time.time())
                break

            time.sleep(0.01)
        # TODO: possibly pre-parse obs here and derive a reward from it?
        
        #########################
        ## Loading the obs if the jars ended properly
        #ref_obs, _ = self.refenv.step(ref_act)
        if should_load_obs:
            valid, reason = is_valid_action(action, self.prev_obs)
            obs = self.load_state()
            self.prev_obs = obs

            ep_info['valid'] = bool(valid)
        if should_load_obs2:
            ref_obs = self.refenv.load_state()

        if obs is None and self.debug:
            print(">> PY_ENV >> MAIN OBS IS NONE (", self.name, ")")

        if ref_obs is None:
            print(">> PY_ENV >> REF OBS IS NONE. (", self.name, ")")

        if failure:
            cntrl_obj = ControlObject('FAILURE')
            tp = np.concatenate([np.asarray([cntrl_obj,]), np.asarray([cntrl_obj,])], axis=-1)
            return tp, np.concatenate([np.asarray([cntrl_obj,]), np.asarray([cntrl_obj,])], axis=-1), ep_info

        # print('-----A------------->', obs['players'][0]['health'])
        # print('-----B------------->', obs['players'][1]['health'])
        self.step_num += 1

        ########################
        ## Forming rewards and packaging the obs into a good numpy form
        if obs is not None:
            # Infer reward:
            #reward = float(obs['players'][0]['score']) - float(obs['players'][1]['score'])
            curS = float(obs['players'][0]['score']) * general_reward_scaling_factor
            self.score_delta = curS - self.score
            reward = self.score_delta + per_step_reward_penalty
            self.score = curS

        if ref_obs is not None:
            curS2 = float(ref_obs['players'][0]['score']) * general_reward_scaling_factor
            self.refenv.score_delta = curS2 - self.refenv.score
            ref_reward = self.refenv.score_delta + per_step_reward_penalty
            self.refenv.score = curS2

        k = np.asarray([obs,])
        u = np.asarray([ref_obs,])
        return_obs = np.concatenate([k, u], axis=-1)
        if reward_mode == 'dense':
            return return_obs, np.concatenate([np.asarray([reward,]), np.asarray([ref_reward,])], axis=-1), ep_info
        elif reward_mode == 'binary':
            return return_obs, np.concatenate([np.asarray([binary_step_penalty,]), np.asarray([binary_step_penalty,])], axis=-1), ep_info
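
        # Usage sketch (assumption, matching the shapes returned above): a
        # driver would call
        #     obs_pair, reward_pair, ep_info = env.step((x, y, b), (rx, ry, rb))
        # and get back length-2 arrays holding [agent, refbot] entries, or
        # ControlObject markers for early-ended / failed episodes.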


    def load_state(self):
        '''
        Gets the current Game State json file.
        '''
        while not os.path.isfile(self.state_file):
            if self.debug:
                print(">> PYENV >> waiting for state file ", self.state_file, ' to appear')
            time.sleep(0.01)

        while True:
            try:
                with open(self.state_file, 'r') as f:
                    k = json.load(f)
                break
            except json.decoder.JSONDecodeError as e:
                # the jar is still writing the state file; retry shortly
                if self.debug:
                    print(">> PYENV >> Failed to decode json state! Got error ", e)
                time.sleep(0.01)

        return k

    def get_obs(self):
        this_obs = self.load_state()
        refbot_obs = self.refenv.load_state()
        x = np.asarray([this_obs,])
        y = np.asarray([refbot_obs,])

        return np.concatenate([x, y], axis=-1)

    def reset(self):
        self.step_num = 0
        if self.debug:
            with open(os.path.join(self.run_path, 'mylog.txt'), 'a') as f:
                f.write(str(time.time()) + "\t-->RESETTING!!!\n")

            with open(os.path.join(self.refbot_path, 'mylog.txt'), 'a') as f:
                f.write(str(time.time()) + "\t-->RESETTING!!!\n")

        if self.proc is not None:
            self.proc.terminate()
            self.proc.wait()
        self.needs_reset = False
        self.done = False
        time.sleep(0.01)
        # trying to kill jar wrapper of this env
        pid_file = os.path.join(self.run_path, 'wrapper_pid.txt')
        if os.path.isfile(pid_file):
            while True:
                with open(pid_file, 'r') as f:
                    try:
                        wrapper_pid = int(f.read())
                    except ValueError:
                        continue
                # a recorded pid of 0 means there is no wrapper left to kill;
                # either way, fall through so the jar below is still relaunched
                if wrapper_pid != 0:
                    try:
                        os.kill(wrapper_pid, signal.SIGTERM)
                    except (PermissionError, ProcessLookupError) as e:
                        if self.debug:
                            print(">> PYENV ", self.name, " >> Attempted to close wrapper pid ", wrapper_pid, " but got ERROR ", e)
                break
        else:
            if self.debug:
                print(">> PYENV >> Attempted to close wrapper pid but the wrapper pid file was not found ")
        ## Trying to prevent reset bugs from cropping up
        # if os.path.isdir(self.refbot_path):
        #     shutil.rmtree(self.refbot_path)
        # refbotdir = os.path.join(basedir, 'refbot')
        # shutil.copytree(refbotdir, self.refbot_path)

        ## Trying to kill jar wrapper of ref env
        refpid_file = os.path.join(self.refbot_path, 'wrapper_pid.txt')
        if os.path.isfile(refpid_file):
            while True:
                with open(refpid_file, 'r') as f:
                    try:
                        wrapper_pid2 = int(f.read())
                    except ValueError:
                        continue
                # as above: pid 0 means nothing to kill, but keep resetting
                if wrapper_pid2 != 0:
                    try:
                        os.kill(wrapper_pid2, signal.SIGTERM)
                    except (PermissionError, ProcessLookupError) as e:
                        if self.debug:
                            print(">> PYENV ", self.name, " >> Attempted to close refbot wrapper pid ", wrapper_pid2, " but got ERROR ", e)
                break
        else:
            if self.debug:
                print(">> PYENV >> Attempted to close refbot wrapper pid but the wrapper pid file was not found ")
        time.sleep(0.01)
        
        #######################
        ## Flushing the matchlogs folder if the env has been alive for over half an hour
        if self.clock.deltaT() >= 1800:
            print(">> PYENV {} >> Env alive for over half an hour, flushing (deleting) matchlogs folder".format(self.name))
            self.cleanup()
            self.clock.reset()
            print("Cleand.")

        command = 'java -jar ' + os.path.join(self.run_path, jar_name)

        if sys.platform == "win32":
            she = False
        else:
            she = True

        if self.debug:
            self.proc = subprocess.Popen(command, shell=she , stdout=subprocess.PIPE, cwd=self.run_path)
            print("Opened process: ", str(command), " with pid ", self.proc.pid)
        else:
            self.proc = subprocess.Popen(command, shell=she, stdout=subprocess.DEVNULL, cwd=self.run_path)
        
        self.pid = self.proc.pid
        time.sleep(0.01)

        return True

    def close(self):
        if self.debug:
            print("Closing env ", self.name)
        # clean up after itself
        
        if self.pid is not None:
            self.needs_reset = True
            self.proc.terminate()
            self.proc.wait()
        else:
            return None

        time.sleep(0.1)
        pid_file = os.path.join(self.run_path, 'wrapper_pid.txt')
        if os.path.isfile(pid_file):
            while True:
                with open(pid_file, 'r') as f:
                    try:
                        wrapper_pid = int(f.read())
                    except ValueError:
                        continue
                if wrapper_pid != 0:
                    try:
                        os.kill(wrapper_pid, signal.SIGTERM)
                    except (PermissionError, ProcessLookupError) as e:
                        if self.debug:
                            print(">> PYENV ", self.name, " >> Attempted to close wrapper pid ", wrapper_pid, " but got ERROR ", e)
                break
        else:
            print(">> PYENV >> Attempted to close wrapper pid but the wrapper pid file was not found ")
        time.sleep(0.1)

        refpid_file = os.path.join(self.refbot_path, 'wrapper_pid.txt')
        if os.path.isfile(refpid_file):
            while True:
                with open(refpid_file, 'r') as f:
                    try:
                        wrapper_pid2 = int(f.read())
                    except ValueError:
                        continue
                if wrapper_pid2 != 0:
                    try:
                        os.kill(wrapper_pid2, signal.SIGTERM)
                    except (PermissionError, ProcessLookupError) as e:
                        if self.debug:
                            print(">> PYENV ", self.name, " >> Attempted to close refbot wrapper pid ", wrapper_pid2, " but got ERROR ", e)
                break
        else:
            if self.debug:
                print(">> PYENV >> Attempted to close refbot wrapper pid but the wrapper pid file was not found ")
        time.sleep(0.1)

        self.pid = None
        return True
        
    def cleanup(self):
        log_path = os.path.join(self.run_path, 'matchlogs')

        if self.debug:
            print("Removing folder: ", log_path)
        try:
            if not keep_log_folder_override:
                shutil.rmtree(log_path)
            else:
                print(">> PYENV >> OVERRIDE - Keeping log files.")
            time.sleep(0.1)
        except Exception:
            print(">> PYENV >> Exception occured while removing matchlogs folder")