Exemplo n.º 1
0
    def search(self, n_mcts, Env, mcts_env, H=30):
        """Run ``n_mcts`` search iterations starting from the root node.

        On the first call a ``ThompsonSamplingState`` root is created from
        ``self.root_index``; on later calls the existing root is reused and
        detached from its former parent action.

        Args:
            n_mcts: number of search iterations (traces) to run.
            Env: the real environment. Non-Atari games deep-copy it for every
                trace; Atari games snapshot it once and restore the snapshot
                into ``mcts_env`` before every trace.
            mcts_env: rollout environment (only used as-is for Atari games; it
                is replaced by a fresh deep copy of ``Env`` otherwise).
            H: forwarded unchanged to ``self.search_iteration`` (presumably a
                depth horizon -- confirm against ``search_iteration``).

        Raises:
            ValueError: if the root state is terminal.
        """
        if self.root is not None:
            self.root.parent_action = None  # keep searching from the current root
        else:
            self.root = ThompsonSamplingState(
                self.root_index,
                r=0.0,
                terminal=False,
                parent_action=None,
                na=self.na,
                model=self)  # , signature=mcts_env.get_signature()
        if self.root.terminal:
            raise ValueError("Can't do tree search from a terminal state")

        atari = is_atari_game(Env)
        # Atari envs are rewound from a single root snapshot instead of deep-copied.
        snapshot = copy_atari_state(Env) if atari else None

        for _ in range(n_mcts):
            if atari:
                restore_atari_state(mcts_env, snapshot)
            else:
                mcts_env = copy.deepcopy(Env)  # fresh rollout copy of the real env
            mcts_env.seed()  # re-seed the rollout env (no explicit seed value)
            self.search_iteration(mcts_env, self.root, 0, H)
    def search(self, n_mcts, c, env, mcts_env):
        """Run ``n_mcts`` select/expand/back-up traces from the root.

        Each trace descends with ``select(c=c)`` while the chosen action already
        has a ``child_state``. It stops at a terminal state, or at the first
        action without a child, which is then expanded via
        ``add_child_state(..., self.model)``. The leaf value ``V`` is then
        discounted by ``self.gamma`` and propagated back to the root.

        Args:
            n_mcts: number of traces to run.
            c: exploration constant passed to ``State.select``.
            env: the real environment (deep-copied per trace, or snapshotted
                once for Atari games).
            mcts_env: rollout environment, only used as-is for Atari games.

        Raises:
            ValueError: if the root state is terminal.
        """
        if self.root is not None:
            self.root.parent_action = None  # continue from current root
        else:
            self.root = State(
                self.root_index,
                r=0.0,
                terminal=False,
                parent_action=None,
                na=self.na,
                bootstrap_last_state_value=self.bootstrap_last_state_value,
                model=self.model)  # initialize new root
        if self.root.terminal:
            raise ValueError("Can't do tree search from a terminal state")

        atari = is_atari_game(env)
        # Atari envs are restored from one root snapshot instead of deep-copied.
        snapshot = copy_atari_state(env) if atari else None

        for _ in range(n_mcts):
            node = self.root
            if atari:
                restore_atari_state(mcts_env, snapshot)
            else:
                mcts_env = copy.deepcopy(env)  # copy original env to roll out from

            # Selection, ending with a single expansion
            while not node.terminal:
                action = node.select(c=c)
                s1, reward, done, _info = mcts_env.step(action.index)
                if not hasattr(action, 'child_state'):
                    node = action.add_child_state(s1, reward, done, self.model)
                    break
                node = action.child_state

            # Back-up: discounted return accumulated from the leaf towards the root
            ret = node.V
            while node.parent_action is not None:
                ret = node.r + self.gamma * ret
                edge = node.parent_action
                edge.update(ret)
                node = edge.parent_state
                node.update()
Exemplo n.º 3
0
    def search(self, n_mcts, c, Env, mcts_env, budget, max_depth=200):
        """Run MCTS traces from the root until the simulation ``budget`` is exhausted.

        ``n_mcts`` is accepted for interface compatibility but is not used.
        Termination is driven by ``budget``:
          * it is decremented when a trace reaches an already-expanded terminal state;
          * it is replaced by the value returned from ``add_child_state`` on expansion.

        The exploration bias is ``c``. When ``self.depth_based_bias`` is set, it
        becomes ``c * gamma**depth / (1 - gamma)`` instead.

        Args:
            n_mcts: unused.
            c: base exploration constant.
            Env: the real environment (deep-copied per trace, or snapshotted once
                for Atari games).
            mcts_env: rollout environment; also handed to a newly created root.
            budget: simulation budget for this call.
            max_depth: depth limit; the remaining depth is passed to
                ``add_child_state``.

        Raises:
            ValueError: if the root state is terminal.
        """
        if self.root is not None:
            self.root.parent_action = None  # continue from current root
        else:
            self.root = State(self.root_index, r=0.0, terminal=False, parent_action=None, na=self.na, env=mcts_env, budget=budget)
        if self.root.terminal:
            raise ValueError("Can't do tree search from a terminal state")

        atari = is_atari_game(Env)
        # Atari envs are rewound from a single root snapshot instead of deep-copied.
        snapshot = copy_atari_state(Env) if atari else None

        while budget > 0:
            node = self.root
            if atari:
                restore_atari_state(mcts_env, snapshot)
            else:
                mcts_env = copy.deepcopy(Env)  # copy original Env to roll out from

            depth = 0
            while not node.terminal:
                if self.depth_based_bias:
                    bias = c * self.gamma ** depth / (1 - self.gamma)
                else:
                    bias = c
                action = node.select(c=bias)
                depth += 1
                s1, reward, done, _info = mcts_env.step(action.index)
                if not hasattr(action, 'child_state'):
                    # expand; add_child_state hands back the updated budget
                    node, budget = action.add_child_state(s1, reward, done, budget, env=mcts_env, max_depth=max_depth - depth)
                    break
                node = action.child_state
                if node.terminal:
                    budget -= 1

            # Back-up: terminal states contribute only their immediate reward
            ret = node.V
            node.update()
            while node.parent_action is not None:
                ret = node.r if node.terminal else node.r + self.gamma * ret
                edge = node.parent_action
                edge.update(ret)
                node = edge.parent_state
                node.update()
def train(game, n_ep, n_mcts, max_ep_len, lr, c, gamma, data_size, batch_size,
          temp, n_hidden_layers, n_hidden_units):
    """Outer training loop.

    For each of ``n_ep`` episodes:
      1. play one episode, choosing every move with an MCTS search of
         ``n_mcts`` traces guided by ``model``;
      2. store the root ``(state, v, pi)`` targets in a replay buffer;
      3. train ``model`` for one pass over the shuffled buffer.

    Returns:
        ``(episode_returns, timepoints, a_best, seed_best, R_best)``.
        ``a_best`` and ``seed_best`` are the action list and env seed of the
        highest-return episode. They are ``None`` if no episode was played.
    """
    episode_returns = []  # total return of every episode
    timepoints = []  # cumulative step count (incl. search steps) per episode
    # Environments
    if game == "teaching":
        env = TeachingEnv()
        bootstrap_last_state_value = False
    else:
        env = make_game(game)
        bootstrap_last_state_value = True
    is_atari = is_atari_game(env)
    # Atari search rolls out in a separate env restored from snapshots;
    # other games pass None and let the search deep-copy `env`.
    mcts_env = make_game(game) if is_atari else None

    replay_buffer = ReplayBuffer(max_size=data_size, batch_size=batch_size)
    model = Model(env=env,
                  lr=lr,
                  n_hidden_layers=n_hidden_layers,
                  n_hidden_units=n_hidden_units)
    t_total = 0  # total steps
    # `np.Inf` was removed in NumPy 2.0; `np.inf` works in every version.
    R_best = -np.inf
    # Bound up front so the final return works even when no episode is played.
    a_best = None
    seed_best = None

    start_training = time.time()

    for ep in range(n_ep):
        start = time.time()
        s = env.reset()
        R = 0.0  # Total return counter
        a_store = []
        seed = np.random.randint(1e7)  # draw some Env seed
        # NOTE(review): seeding happens after reset(), so `seed` may not
        # reproduce this episode's initial state -- confirm this is intended.
        env.seed(seed)
        if is_atari:
            mcts_env.reset()
            mcts_env.seed(seed)

        # the object responsible for MCTS searches
        mcts = MCTS(root_index=s,
                    root=None,
                    model=model,
                    na=model.action_dim,
                    gamma=gamma,
                    bootstrap_last_state_value=bootstrap_last_state_value)

        if game == "teaching":
            iterator = tqdm(range(env.t_max))
        else:
            iterator = range(max_ep_len)
        for _ in iterator:
            # MCTS step
            mcts.search(n_mcts=n_mcts, c=c, env=env,
                        mcts_env=mcts_env)  # perform a forward search
            state, pi, v = mcts.return_results(temp)  # extract the root output
            replay_buffer.store((state, v, pi))

            # Make the true step, sampling from the root policy
            a = np.random.choice(len(pi), p=pi)
            a_store.append(a)
            s1, r, terminal, _ = env.step(a)
            R += r
            t_total += n_mcts  # total number of environment steps (counts the mcts steps)

            if terminal:
                break
            mcts.forward(a, s1)  # re-root the tree at the new state

        # Finished episode
        episode_returns.append(R)  # store the total episode return
        timepoints.append(t_total)  # store the timestep count of the episode return
        store_safely({'R': episode_returns, 't': timepoints})

        if R > R_best:
            a_best = a_store
            seed_best = seed
            R_best = R
        print(f'Finished episode {ep}, total return: {np.round(R,2)}, '
              f'time episode: {time.time()-start:.1f} sec, '
              f'total time since: {time.time()-start_training:.1f} sec')
        # Train: one pass over the shuffled replay buffer
        replay_buffer.shuffle()
        for sb, vb, pib in replay_buffer:
            model.train_on_example(sb=sb, vb=vb, pib=pib)
    # Return results
    return episode_returns, timepoints, a_best, seed_best, R_best
    def search(self, n_mcts, c, Env, mcts_env, budget, max_depth=200):
        """Run progressive-widening MCTS traces from the root until ``budget`` is spent.

        ``n_mcts`` is accepted for interface compatibility but is not used.

        At each visited action, ``k = ceil(beta * n**alpha)`` is compared with the
        action's number of children:
          * if ``k >= n_children``, the real rollout env is stepped, which costs one
            unit of budget. The resulting state is either matched to an existing
            child or expanded via ``add_child_state``, which returns the updated
            budget.
          * otherwise, an existing child is sampled and the rollout env is reset
            to that child's stored signature.

        Args:
            n_mcts: unused.
            c: base exploration constant; with ``self.depth_based_bias`` the bias
                becomes ``c * gamma**depth / (1 - gamma)``.
            Env: the real environment (deep-copied per trace, or snapshotted once
                for Atari games).
            mcts_env: rollout environment (replaced by a deep copy of ``Env`` for
                non-Atari games).
            budget: simulation budget for this call.
            max_depth: depth limit; the remaining depth is passed to
                ``add_child_state``.

        Raises:
            ValueError: if the root state is terminal.
            SystemExit: with status 1 if the copied rollout env's signature does
                not match ``Env``.
        """
        is_atari = is_atari_game(Env)
        if is_atari:
            snapshot = copy_atari_state(Env)  # for Atari: snapshot the root at the beginning
        else:
            mcts_env = copy.deepcopy(Env)  # copy original Env to roll out from

        # Check that the environment has been copied correctly
        sig1 = mcts_env.get_signature()
        sig2 = Env.get_signature()
        copied_ok = sig1.keys() == sig2.keys() and all(
            np.array_equal(sig1[key], sig2[key]) for key in sig1)
        if not copied_ok:
            print("Something wrong while copying the environment")
            print(sig1.keys(), sig2.keys())
            # Non-zero status: the site-provided `exit()` would report success (0)
            # and is unavailable when Python runs without the `site` module.
            raise SystemExit(1)

        if self.root is None:
            # initialize new root
            self.root = StochasticState(self.root_index, r=0.0, terminal=False, parent_action=None, na=self.na,
                                        signature=Env.get_signature(), env=mcts_env, budget=budget)
        else:
            self.root.parent_action = None  # continue from current root
        if self.root.terminal:
            raise ValueError("Can't do tree search from a terminal state")

        while budget > 0:
            state = self.root  # reset to root for new trace
            if not is_atari:
                mcts_env = copy.deepcopy(Env)  # copy original Env to roll out from
            else:
                restore_atari_state(mcts_env, snapshot)
            mcts_env.seed()
            st = 0
            while not state.terminal:
                bias = c * self.gamma ** st / (1 - self.gamma) if self.depth_based_bias else c
                action = state.select(c=bias)
                st += 1
                k = np.ceil(self.beta * action.n ** self.alpha)
                if k >= action.n_children:
                    # Widening allowed: sample a real transition (costs budget).
                    s1, r, t, _ = mcts_env.step(action.index)
                    budget -= 1
                    child_ind = action.get_state_ind(s1)
                    if child_ind != -1:
                        state = action.child_states[child_ind]  # select
                        state.r = r
                    else:
                        state, budget = action.add_child_state(s1, r, t, mcts_env.get_signature(), budget, env=mcts_env,
                                                               max_depth=max_depth - st)  # expand
                        break
                else:
                    # No widening: revisit an existing child and sync the env to it.
                    state = action.sample_state()
                    mcts_env.set_signature(state.signature)
                    if state.terminal:
                        budget -= 1

            # Back-up: terminal states contribute only their immediate reward
            R = state.V
            state.update()
            while state.parent_action is not None:  # loop back-up until root is reached
                if not state.terminal:
                    R = state.r + self.gamma * R
                else:
                    R = state.r
                action = state.parent_action
                action.update(R)
                state = action.parent_state
                state.update()
Exemplo n.º 6
0
def agent(game,
          n_ep,
          n_mcts,
          max_ep_len,
          lr,
          c,
          gamma,
          data_size,
          batch_size,
          temp,
          n_hidden_layers,
          n_hidden_units,
          stochastic=False,
          eval_freq=-1,
          eval_episodes=100,
          alpha=0.6,
          out_dir='../',
          pre_process=None,
          visualize=False):
    """Outer training loop with optional periodic greedy evaluation.

    Evaluation runs every ``eval_freq`` episodes, and only when ``eval_freq > 0``:
      * the current network is evaluated for ``eval_episodes`` episodes with
        greedy (``temp=0``) MCTS through ``eval_policy``;
      * the summary statistics are appended to ``offline_scores`` and saved to
        ``<out_dir>/offline_scores.npy``.

    Every episode, one training episode is then played:
      * actions are sampled from the MCTS root policy at temperature ``temp``;
      * ``(state, V, pi)`` targets are stored in a ``Database``;
      * the model is trained for one pass over it and saved to
        ``<out_dir>/model``.

    Returns:
        ``(episode_returns, timepoints, a_best, seed_best, R_best, offline_scores)``.
        ``a_best`` and ``seed_best`` are ``None`` if no episode was played.
    """
    if pre_process is not None:
        pre_process()

    # tf.reset_default_graph()

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    episode_returns = []  # storage
    timepoints = []
    # Environments
    Env = make_game(game)
    is_atari = is_atari_game(Env)
    # Separate rollout env only for Atari; other games pass None and the
    # search deep-copies `Env` itself.
    mcts_env = make_game(game) if is_atari else None
    online_scores = []
    offline_scores = []
    mcts_params = dict(gamma=gamma)
    if stochastic:
        mcts_params['alpha'] = alpha
        mcts_maker = MCTSStochastic
    else:
        mcts_maker = MCTS

    D = Database(max_size=data_size, batch_size=batch_size)
    model = Model(Env=Env,
                  lr=lr,
                  n_hidden_layers=n_hidden_layers,
                  n_hidden_units=n_hidden_units)
    t_total = 0  # total steps
    # `np.Inf` was removed in NumPy 2.0; `np.inf` works in every version.
    R_best = -np.inf
    # Bound up front so the final return works even when no episode is played.
    a_best = None
    seed_best = None

    with tf.Session() as sess:
        model.sess = sess
        sess.run(tf.global_variables_initializer())

        for ep in range(n_ep):
            # Guard the modulo: eval_freq defaults to -1, and 0 would divide by zero.
            is_eval_episode = eval_freq > 0 and ep % eval_freq == 0
            if is_eval_episode:
                print(
                    'Evaluating policy for {} episodes!'.format(eval_episodes))
                seed = np.random.randint(1e7)  # draw some Env seed
                Env.seed(seed)
                s = Env.reset()
                mcts = mcts_maker(root_index=s,
                                  root=None,
                                  model=model,
                                  na=model.action_dim,
                                  **mcts_params)
                env_wrapper = EnvEvalWrapper()
                env_wrapper.mcts = mcts
                starting_states = []

                def reset_env():
                    # New evaluation episode: fresh search tree for the new start state.
                    s = Env.reset()
                    env_wrapper.mcts = mcts_maker(root_index=s,
                                                  root=None,
                                                  model=model,
                                                  na=model.action_dim,
                                                  **mcts_params)
                    starting_states.append(s)
                    if env_wrapper.curr_probs is not None:
                        env_wrapper.episode_probabilities.append(
                            env_wrapper.curr_probs)
                    env_wrapper.curr_probs = []
                    return s

                def forward(a, s, r):
                    env_wrapper.mcts.forward(a, s, r)

                env_wrapper.reset = reset_env
                env_wrapper.step = Env.step
                env_wrapper.forward = forward
                env_wrapper.episode_probabilities = []
                env_wrapper.curr_probs = None

                def pi_wrapper(ob):
                    # `mcts_env` is only read here: assigning to it inside this
                    # function would make it local and raise UnboundLocalError
                    # for Atari games. It is already None for non-Atari games.
                    env_wrapper.mcts.search(n_mcts=n_mcts,
                                            c=c,
                                            Env=Env,
                                            mcts_env=mcts_env)
                    state, pi, V = env_wrapper.mcts.return_results(temp=0)
                    env_wrapper.curr_probs.append(pi)
                    return np.argmax(pi)  # greedy action during evaluation

                rews, lens = eval_policy(pi_wrapper,
                                         env_wrapper,
                                         n_episodes=eval_episodes,
                                         verbose=True)
                offline_scores.append([
                    np.min(rews),
                    np.max(rews),
                    np.mean(rews),
                    np.std(rews),
                    len(rews),
                    np.mean(lens)
                ])
                np.save(os.path.join(out_dir, 'offline_scores.npy'), offline_scores)

            s = Env.reset()
            R = 0.0  # Total return counter
            a_store = []
            seed = np.random.randint(1e7)  # draw some Env seed
            Env.seed(seed)
            if is_atari:
                mcts_env.reset()
                mcts_env.seed(seed)
            if is_eval_episode:
                print("Collecting %d episodes" % eval_freq)
            mcts = mcts_maker(
                root_index=s,
                root=None,
                model=model,
                na=model.action_dim,
                **mcts_params)  # the object responsible for MCTS searches
            for _ in range(max_ep_len):
                # MCTS step
                mcts.search(n_mcts=n_mcts, c=c, Env=Env,
                            mcts_env=mcts_env)  # perform a forward search
                if visualize:
                    mcts.visualize()
                state, pi, V = mcts.return_results(
                    temp)  # extract the root output
                D.store((state, V, pi))

                # Make the true step, sampling from the root policy
                a = np.random.choice(len(pi), p=pi)
                a_store.append(a)
                s1, r, terminal, _ = Env.step(a)
                R += r
                t_total += n_mcts  # total number of environment steps (counts the mcts steps)

                if terminal:
                    break
                mcts.forward(a, s1, r)  # re-root the tree at the new state

            # Finished episode
            episode_returns.append(R)  # store the total episode return
            online_scores.append(R)
            timepoints.append(t_total)  # store the timestep count of the episode return
            store_safely(out_dir, 'result', {
                'R': episode_returns,
                't': timepoints
            })
            np.save(os.path.join(out_dir, 'online_scores.npy'), online_scores)

            if R > R_best:
                a_best = a_store
                seed_best = seed
                R_best = R

            # Train: one pass over the reshuffled database
            D.reshuffle()
            try:
                for sb, Vb, pib in D:
                    model.train(sb, Vb, pib)
            except Exception as e:
                # Best effort: a failed training pass must not abort the run,
                # but report why instead of hiding the error.
                print('Training failed in episode {}: {!r}'.format(ep, e))
            model.save(os.path.join(out_dir, 'model'))
    # Return results
    return episode_returns, timepoints, a_best, seed_best, R_best, offline_scores
Exemplo n.º 7
0
def agent(game, n_ep, n_mcts, max_ep_len, lr, c, gamma, data_size, batch_size,
          temp, n_hidden_layers, n_hidden_units):
    """Outer training loop: MCTS self-play followed by network training.

    For each of ``n_ep`` episodes:
      1. play one episode, choosing every move with an MCTS search of
         ``n_mcts`` traces, sampling from the root policy at temperature
         ``temp``;
      2. store the ``(state, V, pi)`` targets in a ``Database``;
      3. train the model for one pass over the reshuffled data.

    Returns:
        ``(episode_returns, timepoints, a_best, seed_best, R_best)``.
        ``a_best`` and ``seed_best`` are ``None`` if no episode was played.
    """
    # tf.reset_default_graph()
    episode_returns = []  # storage
    timepoints = []
    # Environments
    Env = make_game(game)
    is_atari = is_atari_game(Env)
    # Separate rollout env only for Atari; other games pass None and the
    # search deep-copies `Env` itself.
    mcts_env = make_game(game) if is_atari else None

    D = Database(max_size=data_size, batch_size=batch_size)
    model = Model(Env=Env,
                  lr=lr,
                  n_hidden_layers=n_hidden_layers,
                  n_hidden_units=n_hidden_units)
    t_total = 0  # total steps
    # `np.Inf` was removed in NumPy 2.0; `np.inf` works in every version.
    R_best = -np.inf
    # Bound up front so the final return works even when no episode is played.
    a_best = None
    seed_best = None

    with tf.Session() as sess:
        model.sess = sess
        sess.run(tf.global_variables_initializer())
        for ep in range(n_ep):
            start = time.time()
            s = Env.reset()
            R = 0.0  # Total return counter
            a_store = []
            seed = np.random.randint(1e7)  # draw some Env seed
            Env.seed(seed)
            if is_atari:
                mcts_env.reset()
                mcts_env.seed(seed)

            mcts = MCTS(
                root_index=s,
                root=None,
                model=model,
                na=model.action_dim,
                gamma=gamma)  # the object responsible for MCTS searches
            for _ in range(max_ep_len):
                # MCTS step
                mcts.search(n_mcts=n_mcts, c=c, Env=Env,
                            mcts_env=mcts_env)  # perform a forward search
                state, pi, V = mcts.return_results(
                    temp)  # extract the root output
                D.store((state, V, pi))

                # Make the true step, sampling from the root policy
                a = np.random.choice(len(pi), p=pi)
                a_store.append(a)
                s1, r, terminal, _ = Env.step(a)
                R += r
                t_total += n_mcts  # total number of environment steps (counts the mcts steps)

                if terminal:
                    break
                mcts.forward(a, s1)  # re-root the tree at the new state

            # Finished episode
            episode_returns.append(R)  # store the total episode return
            timepoints.append(t_total)  # store the timestep count of the episode return
            store_safely(os.getcwd(), 'result', {
                'R': episode_returns,
                't': timepoints
            })

            if R > R_best:
                a_best = a_store
                seed_best = seed
                R_best = R
            print('Finished episode {}, total return: {}, total time: {} sec'.
                  format(ep, np.round(R, 2), np.round((time.time() - start),
                                                      1)))
            # Train: one pass over the reshuffled database
            D.reshuffle()
            for sb, Vb, pib in D:
                model.train(sb, Vb, pib)
    # Return results
    return episode_returns, timepoints, a_best, seed_best, R_best
Exemplo n.º 8
0
def agent(game,
          n_ep,
          n_mcts,
          max_ep_len,
          lr,
          c,
          gamma,
          data_size,
          batch_size,
          temp,
          n_hidden_layers,
          n_hidden_units,
          stochastic=False,
          eval_freq=-1,
          eval_episodes=100,
          alpha=0.6,
          n_epochs=100,
          c_dpw=1,
          numpy_dump_dir='../',
          pre_process=None,
          visualize=False,
          game_params={},
          parallelize_evaluation=False,
          mcts_only=False,
          particles=0,
          show_plots=False,
          n_workers=1,
          use_sampler=False,
          budget=np.inf,
          unbiased=False,
          biased=False,
          max_workers=100,
          variance=False,
          depth_based_bias=False,
          scheduler_params=None,
          out_dir=None,
          render=False,
          second_version=False,
          third_version=False):
    visualizer = None

    # if particles:
    #     parallelize_evaluation = False  # Cannot run parallelized evaluation with particle filtering

    if not mcts_only:
        from mcts import MCTS
        from mcts_dpw import MCTSStochastic
    elif particles:
        if unbiased:
            from particle_filtering.ol_uct import OL_MCTS
        elif biased:
            if second_version:
                from particle_filtering.pf_uct_2 import PFMCTS2 as PFMCTS
            elif third_version:
                from particle_filtering.pf_uct_3 import PFMCTS3 as PFMCTS
            else:
                from particle_filtering.pf_uct import PFMCTS
        else:
            from particle_filtering.pf_mcts_edo import PFMCTS
    else:
        from pure_mcts.mcts import MCTS
        from pure_mcts.mcts_dpw import MCTSStochastic

    if parallelize_evaluation:
        print("The evaluation will be parallel")

    parameter_list = {
        "game": game,
        "n_ep": n_ep,
        "n_mcts": n_mcts,
        "max_ep_len": max_ep_len,
        "lr": lr,
        "c": c,
        "gamma": gamma,
        "data_size": data_size,
        "batch_size": batch_size,
        "temp": temp,
        "n_hidden_layers": n_hidden_layers,
        "n_hidden_units": n_hidden_units,
        "stochastic": stochastic,
        "eval_freq": eval_freq,
        "eval_episodes": eval_episodes,
        "alpha": alpha,
        "n_epochs": n_epochs,
        "out_dir": numpy_dump_dir,
        "pre_process": pre_process,
        "visualize": visualize,
        "game_params": game_params,
        "n_workers": n_workers,
        "use_sampler": use_sampler,
        "variance": variance,
        "depth_based_bias": depth_based_bias,
        "unbiased": unbiased,
        "second_version": second_version,
        'third_version': third_version
    }
    if out_dir is not None:
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        with open(os.path.join(out_dir, "parameters.txt"), 'w') as d:
            d.write(json.dumps(parameter_list))
    #logger = Logger(parameter_list, game, show=show_plots)

    if DEBUG_TAXI:
        from utils.visualization.taxi import TaxiVisualizer
        with open(game_params["grid"]) as f:
            m = f.readlines()
            matrix = []
            for r in m:
                row = []
                for ch in r.strip('\n'):
                    row.append(ch)
                matrix.append(row)
            visualizer = TaxiVisualizer(matrix)
            f.close()
            exit()
    ''' Outer training loop '''
    if pre_process is not None:
        pre_process()

    # numpy_dump_dir = logger.numpy_dumps_dir
    #
    # if not os.path.exists(numpy_dump_dir):
    #     os.makedirs(numpy_dump_dir)

    episode_returns = []  # storage
    timepoints = []

    # Environments
    if game == 'Trading-v0':
        game_params['save_dir'] = out_dir  #logger.save_dir
    Env = make_game(game, game_params)
    num_actions = Env.action_space.n
    sampler = None
    if use_sampler and not (unbiased or biased):

        def make_pi(action_space):
            def pi(s):
                return np.random.randint(low=0, high=action_space.n)

            return pi

        def make_env():
            return make_game(game, game_params)

        sampler = ParallelSampler(make_pi=make_pi,
                                  make_env=make_env,
                                  n_particles=particles,
                                  n_workers=n_workers,
                                  seed=10)

    is_atari = is_atari_game(Env)
    mcts_env = make_game(game, game_params) if is_atari else None
    online_scores = []
    offline_scores = []

    # Setup the parameters for generating the search environments

    if game == "RaceStrategy-v1":
        mcts_maker, mcts_params, c_dpw = load_race_agents_config(
            'envs/configs/race_strategy_full.json', gamma)

    else:
        mcts_params = dict(gamma=gamma)
        if particles:
            if not (biased or unbiased):
                mcts_params['particles'] = particles
                mcts_params['sampler'] = sampler
            elif biased:
                mcts_params['alpha'] = alpha
                mcts_maker = PFMCTS

            mcts_params['depth_based_bias'] = depth_based_bias
            if unbiased:
                mcts_params['variance'] = variance
                mcts_maker = OL_MCTS

        elif stochastic:
            mcts_params['alpha'] = alpha
            mcts_params['depth_based_bias'] = depth_based_bias
            mcts_maker = MCTSStochastic
        else:
            mcts_maker = MCTS

    # Prepare the database for storing training data to be sampled
    db = Database(max_size=data_size, batch_size=batch_size)

    # TODO extract dimensions to avoid allocating model
    # Setup the model
    model_params = {
        "Env": Env,
        "lr": lr,
        "n_hidden_layers": n_hidden_layers,
        "n_hidden_units": n_hidden_units,
        "joint_networks": True
    }

    model_wrapper = ModelWrapper(**model_params)

    t_total = 0  # total steps
    R_best = -np.Inf
    a_best = None
    seed_best = None

    # Variables for storing values to be plotted
    avgs = []
    stds = []

    # Run the episodes
    for ep in range(n_ep):

        if DEBUG_TAXI:
            visualizer.reset()

        ##### Policy evaluation step #####
        if eval_freq > 0 and ep % eval_freq == 0:  # and ep > 0
            print(
                '--------------------------------\nEvaluating policy for {} episodes!'
                .format(eval_episodes))
            seed = np.random.randint(1e7)  # draw some Env seed
            Env.seed(seed)
            s = Env.reset()

            if parallelize_evaluation:
                penv = None
                pgame = {
                    "game_maker": make_game,
                    "game": game,
                    "game_params": game_params
                }
            else:
                penv = Env
                pgame = None

            model_file = os.path.join(out_dir, "model.h5")

            # model_wrapper.save(model_file)

            if game == "RaceStrategy-v1":
                env_wrapper = RaceWrapper(s,
                                          mcts_maker,
                                          model_file,
                                          model_params,
                                          mcts_params,
                                          is_atari,
                                          n_mcts,
                                          budget,
                                          mcts_env,
                                          c_dpw,
                                          temp,
                                          env=penv,
                                          game_maker=pgame,
                                          mcts_only=mcts_only,
                                          scheduler_params=scheduler_params)
            else:
                env_wrapper = Wrapper(s,
                                      mcts_maker,
                                      model_file,
                                      model_params,
                                      mcts_params,
                                      is_atari,
                                      n_mcts,
                                      budget,
                                      mcts_env,
                                      c_dpw,
                                      temp,
                                      env=penv,
                                      game_maker=pgame,
                                      mcts_only=mcts_only,
                                      scheduler_params=scheduler_params)

            # Run the evaluation
            if parallelize_evaluation:
                total_reward, reward_per_timestep, lens, action_counts = \
                    parallelize_eval_policy(env_wrapper, n_episodes=eval_episodes, verbose=False, max_len=max_ep_len,
                                            max_workers=max_workers, out_dir=out_dir)
            else:
                total_reward, reward_per_timestep, lens, action_counts = \
                    eval_policy(env_wrapper, n_episodes=eval_episodes, verbose=False, max_len=max_ep_len,
                                visualize=visualize, out_dir=out_dir, render=render)

            # offline_scores.append([np.min(rews), np.max(rews), np.mean(rews), np.std(rews),
            #                        len(rews), np.mean(lens)])

            offline_scores.append(
                [total_reward, reward_per_timestep, lens, action_counts])

            #np.save(numpy_dump_dir + '/offline_scores.npy', offline_scores)

            # Store and plot data
            avgs.append(np.mean(total_reward))
            stds.append(np.std(total_reward))

            #logger.plot_evaluation_mean_and_variance(avgs, stds)

        ##### Policy improvement step #####

        if not mcts_only:

            start = time.time()
            s = start_s = Env.reset()
            R = 0.0  # Total return counter
            a_store = []
            seed = np.random.randint(1e7)  # draw some Env seed
            Env.seed(seed)
            if is_atari:
                mcts_env.reset()
                mcts_env.seed(seed)

            if eval_freq > 0 and ep % eval_freq == 0:
                print("\nCollecting %d episodes" % eval_freq)
            mcts = mcts_maker(
                root_index=s,
                root=None,
                model=model_wrapper,
                na=model_wrapper.action_dim,
                **mcts_params)  # the object responsible for MCTS searches

            print("\nPerforming MCTS steps\n")

            ep_steps = 0
            start_targets = []

            for st in range(max_ep_len):

                print_step = max_ep_len // 10
                if st % print_step == 0:
                    print('Step ' + str(st + 1) + ' of ' + str(max_ep_len))

                # MCTS step
                if not is_atari:
                    mcts_env = None
                mcts.search(n_mcts=n_mcts, c=c, Env=Env,
                            mcts_env=mcts_env)  # perform a forward search

                if visualize:
                    mcts.visualize()

                state, pi, V = mcts.return_results(
                    temp)  # extract the root output

                # Save targets for starting state to debug
                if np.array_equal(start_s, state):
                    if DEBUG:
                        print("Pi target for starting state:", pi)
                    start_targets.append((V, pi))
                db.store((state, V, pi))

                # Make the true step
                a = np.random.choice(len(pi), p=pi)
                a_store.append(a)

                s1, r, terminal, _ = Env.step(a)

                # Perform command line visualization if necessary
                if DEBUG_TAXI:
                    olds, olda = copy.deepcopy(s1), copy.deepcopy(a)
                    visualizer.visualize_taxi(olds, olda)
                    print("Reward:", r)

                R += r
                t_total += n_mcts  # total number of environment steps (counts the mcts steps)
                ep_steps = st + 1

                if terminal:
                    break  # Stop the episode if we encounter a terminal state
                else:
                    mcts.forward(a, s1, r)  # Otherwise proceed

            # Finished episode
            if DEBUG:
                print("Train episode return:", R)
                print("Train episode actions:", a_store)
            episode_returns.append(R)  # store the total episode return
            online_scores.append(R)
            timepoints.append(
                t_total)  # store the timestep count of the episode return
            #store_safely(numpy_dump_dir, '/result', {'R': episode_returns, 't': timepoints})
            #np.save(numpy_dump_dir + '/online_scores.npy', online_scores)

            if DEBUG or True:
                print(
                    'Finished episode {} in {} steps, total return: {}, total time: {} sec'
                    .format(ep, ep_steps, np.round(R, 2),
                            np.round((time.time() - start), 1)))
            # Plot the online return over training episodes

            #logger.plot_online_return(online_scores)

            if R > R_best:
                a_best = a_store
                seed_best = seed
                R_best = R

            print()

            # Train only if the model has to be used
            if not mcts_only:
                # Train
                try:
                    print("\nTraining network")
                    ep_V_loss = []
                    ep_pi_loss = []

                    for _ in range(n_epochs):
                        # Reshuffle the dataset at each epoch
                        db.reshuffle()

                        batch_V_loss = []
                        batch_pi_loss = []

                        # Batch training
                        for sb, Vb, pib in db:

                            if DEBUG:
                                print("sb:", sb)
                                print("Vb:", Vb)
                                print("pib:", pib)

                            loss = model_wrapper.train(sb, Vb, pib)

                            batch_V_loss.append(loss[1])
                            batch_pi_loss.append(loss[2])

                        ep_V_loss.append(mean(batch_V_loss))
                        ep_pi_loss.append(mean(batch_pi_loss))

                    # Plot the loss over training epochs

                    #logger.plot_loss(ep, ep_V_loss, ep_pi_loss)

                except Exception as e:
                    print("Something wrong while training:", e)

                # model.save(out_dir + 'model')

                # Plot the loss over different episodes
                #logger.plot_training_loss_over_time()

                pi_start = model_wrapper.predict_pi(start_s)
                V_start = model_wrapper.predict_V(start_s)

                print("\nStart policy: ", pi_start)
                print("Start value:", V_start)

                #logger.log_start(ep, pi_start, V_start, start_targets)

    # Return results
    if use_sampler:
        sampler.close()
    return episode_returns, timepoints, a_best, seed_best, R_best, offline_scores
    def search(self, n_mcts, c, Env, mcts_env, max_depth=200):
        """Perform ``n_mcts`` MCTS traces from the root.

        Each trace starts from the root. It uses a fresh deep copy of ``Env``
        for non-Atari games, or the root snapshot restored into ``mcts_env``
        for Atari. The trace descends the tree until it expands a new child
        state or reaches a terminal state. The leaf value is then backed up
        to the root.

        Args:
            n_mcts: number of traces to run.
            c: constant passed to ``state.select(c=c)``; also used as the
                coefficient of the child-state widening rule below.
            Env: environment whose current state corresponds to the root.
            mcts_env: rollout environment for Atari games (the root snapshot
                is restored into it). For other games it is ignored and
                replaced by deep copies of ``Env``.
            max_depth: forwarded to the ``StochasticState`` created when no
                root exists yet.

        Raises:
            ValueError: if the root state is terminal.
        """
        is_atari = is_atari_game(Env)
        if is_atari:
            # Snapshot the root once; every trace restores it into mcts_env.
            snapshot = copy_atari_state(Env)
        else:
            # Sanity check that deep-copying reproduces the env state. Only
            # meaningful here, where a copy is actually made. np.array_equal
            # copes with ndarray states (a bare ``!=`` raises "ambiguous truth
            # value" on multi-element arrays); getattr tolerates envs that
            # have no ``_state`` attribute.
            mcts_env = copy.deepcopy(Env)
            if not np.array_equal(getattr(mcts_env, '_state', None),
                                  getattr(Env, '_state', None)):
                print("Copying went wrong")

        if self.root is None:
            # initialize new root
            self.root = StochasticState(self.root_index,
                                        r=0.0,
                                        terminal=False,
                                        parent_action=None,
                                        na=self.na,
                                        model=self.model,
                                        signature=Env.get_signature(),
                                        max_depth=max_depth)
        else:
            self.root.parent_action = None  # continue from current root
        if self.root.terminal:
            raise ValueError("Can't do tree search from a terminal state")

        for _ in range(n_mcts):
            state = self.root  # reset to root for new trace
            if is_atari:
                restore_atari_state(mcts_env, snapshot)
            else:
                # copy original Env to rollout from
                mcts_env = copy.deepcopy(Env)
            mcts_env.seed()

            # Selection / expansion
            while not state.terminal:
                action = state.select(c=c)
                # Widening rule: take a real env step while the number of
                # child states is at most ceil(c * action.n ** alpha);
                # otherwise resample an existing child state.
                # (action.n is presumably the action's visit count.)
                k = np.ceil(c * action.n**self.alpha)
                if k >= action.n_children:
                    s1, r, t, _ = mcts_env.step(action.index)
                    child_ind = action.get_state_ind(s1)
                    if child_ind != -1:
                        # Outcome already in the tree: follow it.
                        state = action.child_states[child_ind]  # select
                        state.r = r
                    else:
                        state = action.add_child_state(
                            s1, r, t, self.model,
                            mcts_env.get_signature())  # expand
                        break
                else:
                    state = action.sample_state()
                    # Move the rollout env to the sampled child's state.
                    mcts_env.set_signature(state.signature)

            # Back-up: start from the leaf's value estimate and discount the
            # accumulated return by gamma at every non-terminal step up.
            R = state.V
            state.update()
            while state.parent_action is not None:  # loop back-up until root is reached
                if not state.terminal:
                    R = state.r + self.gamma * R
                else:
                    R = state.r
                action = state.parent_action
                action.update(R)
                state = action.parent_state
                state.update()
Exemplo n.º 10
0
def agent(game, n_ep, n_mcts, max_ep_len, lr, c, gamma, data_size, batch_size,
          temp, n_hidden_layers, n_hidden_units, n_epochs=1):
    """Outer training loop: alternate MCTS-driven episodes and model training.

    For each of ``n_ep`` episodes, act in the environment by running an MCTS
    search at every step and sampling an action from the returned root
    policy. Each ``(state, V, pi)`` root target is stored in a ``Database``.
    After the episode, the model is trained on the reshuffled database.

    Args:
        game: name passed to ``make_game``.
        n_ep: number of episodes.
        n_mcts: MCTS traces per real environment step.
        max_ep_len: maximum number of real steps per episode.
        lr, n_hidden_layers, n_hidden_units: forwarded to ``Model``.
        c: constant forwarded to ``mcts.search``.
        gamma: discount forwarded to ``MCTS``.
        data_size, batch_size: forwarded to ``Database``.
        temp: temperature forwarded to ``mcts.return_results``.
        n_epochs: passes over the database after each episode (default 1,
            matching the original behaviour). The database is reshuffled
            before every pass.

    Returns:
        ``(episode_returns, timepoints, a_best, seed_best, r_best)``:
        the return of every episode; the cumulative step count after each
        episode, where each real step counts ``n_mcts``; and the action
        sequence, env seed and return of the best episode so far.
    """
    seed_best = None
    a_best = None
    episode_returns = []  # storage
    timepoints = []
    # environments: Atari games get a dedicated rollout env; otherwise None
    # is handed to mcts.search.
    env = make_game(game)
    is_atari = is_atari_game(env)
    mcts_env = make_game(game) if is_atari else None

    database = Database(max_size=data_size, batch_size=batch_size)
    model = Model(env=env,
                  lr=lr,
                  n_hidden_layers=n_hidden_layers,
                  n_hidden_units=n_hidden_units)
    t_total = 0  # total steps
    # np.Inf was removed in NumPy 2.0; np.inf works on every version.
    r_best = -np.inf

    for ep in range(n_ep):
        start = time.time()
        s = env.reset()
        r2 = 0.0  # Total return counter
        a_store = []
        seed = np.random.randint(1e7)  # draw some env seed
        env.seed(seed)
        if is_atari:
            mcts_env.reset()
            mcts_env.seed(seed)

        mcts = MCTS(root_index=s,
                    model=model,
                    na=model.action_dim,
                    gamma=gamma)  # the object responsible for MCTS searches
        for _ in range(max_ep_len):
            # MCTS step
            mcts.search(n_mcts=n_mcts, c=c, env=env,
                        mcts_env=mcts_env)  # perform a forward search
            state, pi, v = mcts.return_results(temp)  # extract the root output
            database.store((state, v, pi))

            # Make the true step
            a = np.random.choice(len(pi), p=pi)
            a_store.append(a)
            s1, r, terminal, _ = env.step(a)
            r2 += r
            # total number of environment steps (counts the mcts steps)
            t_total += n_mcts

            if terminal:
                break
            mcts.forward(a, s1)  # re-root the tree on the real next state

        # Finished episode
        episode_returns.append(r2)  # store the total episode return
        timepoints.append(
            t_total)  # store the timestep count of the episode return
        store_safely(os.getcwd(), 'result', {
            'r': episode_returns,
            't': timepoints
        })

        if r2 > r_best:
            a_best = a_store
            seed_best = seed
            r_best = r2
        print(
            'Finished episode {}, total return: {}, total time: {} sec'.format(
                ep, np.round(r2, 2), np.round((time.time() - start), 1)))

        # Train: reshuffle before every pass over the stored targets.
        for _ in range(n_epochs):
            database.reshuffle()
            for sb, v_batch, pi_batch in database:
                model.train(sb, v_batch, pi_batch)
    # return results
    return episode_returns, timepoints, a_best, seed_best, r_best