Example No. 1
import time

import numpy as np

# ALPGMM ships with the teachDRL package (see Example No. 3). The import path
# for gmm_plot_gif below is an assumption -- adjust it to your checkout.
from teachDRL.teachers.algos.alp_gmm import ALPGMM
from teachDRL.teachers.utils.plot_utils import gmm_plot_gif


def test_alpgmm(env,
                nb_episodes,
                gif=True,
                nb_dims=2,
                score_step=1000,
                verbose=True,
                params={}):
    # Init teacher
    task_generator = ALPGMM([0] * nb_dims, [1] * nb_dims, params=params)

    # Init book keeping
    rewards = []
    scores = []
    bk = {
        'weights': [],
        'covariances': [],
        'means': [],
        'tasks_lps': [],
        'episodes': [],
        'comp_grids': [],
        'comp_xs': [],
        'comp_ys': []
    }

    # Launch run
    for i in range(nb_episodes + 1):
        if (i % score_step) == 0:
            scores.append(env.get_score())
            if nb_dims == 2:
                if verbose:
                    print(env.cube_competence)
            else:
                if verbose:
                    print("it:{}, score:{}".format(i, scores[-1]))

        # Book keeping if ALP-GMM updated its GMM
        if i > 100 and (i % task_generator.fit_rate) == 0 and (gif is True):
            bk['weights'].append(task_generator.gmm.weights_.copy())
            bk['covariances'].append(task_generator.gmm.covariances_.copy())
            bk['means'].append(task_generator.gmm.means_.copy())
            bk['tasks_lps'] = task_generator.tasks_alps
            bk['episodes'].append(i)
            if nb_dims == 2:
                bk['comp_grids'].append(env.cube_competence.copy())
                bk['comp_xs'].append(env.bnds[0].copy())
                bk['comp_ys'].append(env.bnds[1].copy())

        task = task_generator.sample_task()
        reward = env.episode(task)
        task_generator.update(np.array(task), reward)
        rewards.append(reward)

    if gif and nb_dims == 2:
        print('Creating gif...')
        gmm_plot_gif(bk,
                     gifname='alpgmm_' + str(time.time()),
                     gifdir='toy_env_gifs/')
        print('Done (see graphics/toy_env_gifs/ folder)')
    return scores
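
The example above reduces to a simple teacher-student loop: ALP-GMM proposes a task vector with sample_task(), the student runs one episode on it, and the episodic return is fed back with update(). A minimal sketch of that loop, assuming only the ALPGMM interface used in the snippet above and a hypothetical toy_episode stand-in for the environment:

import numpy as np

from teachDRL.teachers.algos.alp_gmm import ALPGMM


def toy_episode(task):
    # Hypothetical student: return is highest near the centre of task space.
    return float(np.exp(-np.sum((np.array(task) - 0.5) ** 2)))


teacher = ALPGMM([0.0, 0.0], [1.0, 1.0])  # 2-D task space in [0, 1]^2
for _ in range(500):
    task = teacher.sample_task()            # teacher proposes a task
    reward = toy_episode(task)              # student attempts it
    teacher.update(np.array(task), reward)  # teacher refits its GMM on learning progress
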
Example No. 2
    def __init__(self):
        # NOTE: assumes module-level imports of OrderedDict (collections) and
        # ALPGMM (teachDRL.teachers.algos.alp_gmm).
        # Anything that must run before the parent constructor has to happen
        # above this call so that the Evaluator is initialized correctly.
        super(Teacher, self).__init__()
        # dictionary of param names to target histories as set by alp_gmm
        self.param_hist = {}
        envs = self.envs
        args = self.args
        env_param_bounds = envs.get_param_bounds()
        # in case we want to change this dynamically in the future (e.g., we may
        # not know how much traffic the agent can possibly produce in Micropolis)
        envs.set_param_bounds(env_param_bounds)  # start with default bounds
        num_env_params = 4
        env_param_ranges = []
        env_param_lw_bounds = []
        env_param_hi_bounds = []
        i = 0
        for k, v in env_param_bounds.items():
            if i < num_env_params:
                env_param_ranges += [abs(v[1] - v[0])]
                env_param_lw_bounds += [v[0]]
                env_param_hi_bounds += [v[1]]
                i += 1
            else:
                break
        alp_gmm = None
        if self.checkpoint:
            alp_gmm = self.checkpoint['alp_gmm']
        if alp_gmm is None:
            alp_gmm = ALPGMM(env_param_lw_bounds, env_param_hi_bounds)
        params_vec = alp_gmm.sample_task()
        self.alp_gmm = alp_gmm

        params = OrderedDict()
        print('\n env_param_bounds', env_param_bounds)
        print(params_vec)
        trial_remaining = args.max_step
        trial_reward = 0

        self.env_param_bounds = env_param_bounds
        self.num_env_params = num_env_params
        self.env_param_ranges = env_param_ranges
        self.params_vec = params_vec
        self.params = params
        self.trial_remaining = args.max_step
        self.trial_reward = trial_reward
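
Example No. 2 stores the sampled params_vec but defers mapping it onto named environment parameters; that mapping is what the training loop in Example No. 4 performs once per trial. A hedged sketch of such a step, assuming an envs.set_params method like the one called in Example No. 4 (the helper name apply_sampled_params is illustrative):

from collections import OrderedDict


def apply_sampled_params(envs, env_param_bounds, params_vec, num_env_params):
    # Map the first num_env_params entries of the sampled vector onto the
    # ordered parameter names, then push them into the environment.
    params = OrderedDict()
    for i, k in enumerate(env_param_bounds):
        if i >= num_env_params:
            break
        params[k] = params_vec[i]
    envs.set_params(params)  # mirrors envs.venv.venv.set_params(...) in Example No. 4
    return params
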
Example No. 3
class ALPGMMTeacher(gym.Wrapper):
    def __init__(self, env, **kwargs):

        from teachDRL.teachers.algos.alp_gmm import ALPGMM

        super(ALPGMMTeacher, self).__init__(env)
        self.cond_bounds = self.env.unwrapped.cond_bounds
        self.midep_trgs = False
        # usable_metrics is expected to be supplied by the wrapped environment
        # and reached here through gym.Wrapper's attribute passthrough.
        env_param_lw_bounds = [self.cond_bounds[k][0] for k in self.usable_metrics]
        env_param_hi_bounds = [self.cond_bounds[k][1] for k in self.usable_metrics]
        self.alp_gmm = ALPGMM(env_param_lw_bounds, env_param_hi_bounds)
        self.trg_vec = None
        self.trial_reward = 0
        self.n_trial_steps = 0

    def reset(self):
        if self.trg_vec is not None:
            if self.n_trial_steps == 0:
                # Guard against division by zero: the env can be reset manually
                # (e.g., from the inference script) before any steps are taken.
                rew = 0
            else:
                rew = self.trial_reward / self.n_trial_steps
            self.alp_gmm.update(self.trg_vec, rew)
        trg_vec = self.alp_gmm.sample_task()
        self.trg_vec = trg_vec
        trgs = {k: trg_vec[i] for (i, k) in enumerate(self.usable_metrics)}
        #       print(trgs)
        self.set_trgs(trgs)
        self.trial_reward = 0
        self.n_trial_steps = 0

        return self.env.reset()

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        self.trial_reward += rew
        self.n_trial_steps += 1

        return obs, rew, done, info
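
Using the wrapper is then a drop-in change to an ordinary gym loop: every reset() reports the mean per-step reward of the finished trial to ALP-GMM and samples fresh targets. A usage sketch, assuming a make_env factory and an environment that supplies the usable_metrics, cond_bounds and set_trgs attributes the wrapper relies on:

env = ALPGMMTeacher(make_env())   # make_env() is a hypothetical env factory
obs = env.reset()                 # samples a target vector and sets the targets
done = False
while not done:
    obs, rew, done, info = env.step(env.action_space.sample())
obs = env.reset()                 # reports mean per-step reward to ALP-GMM and resamples
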
Example No. 4
def main():
    # NOTE: this snippet assumes the surrounding repository's imports are in
    # scope: os, glob, copy, time, collections.deque/OrderedDict, numpy as np,
    # torch, gym, plus project helpers (get_args, make_vec_envs, Policy,
    # init_agent, get_vec_normalize, RolloutStorage, CuriosityRolloutStorage,
    # Evaluator, Plotter, ALPGMM).
    import random
    import gym_micropolis
    import game_of_life

    args = get_args()
    args.log_dir = args.save_dir + '/logs'
    assert args.algo in ['a2c', 'ppo', 'acktr']
    if args.recurrent_policy:
        assert args.algo in ['a2c', 'ppo'], \
            'Recurrent policy is not implemented for ACKTR'
    args.poet = True  # hacky

    num_updates = int(args.num_frames) // args.num_steps // args.num_processes

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    graph_name = args.save_dir.split('trained_models/')[1].replace('/', ' ')

    actor_critic = False
    agent = False
    past_steps = 0
    try:
        os.makedirs(args.log_dir)
    except OSError:
        files = glob.glob(os.path.join(args.log_dir, '*.monitor.csv'))
        for f in files:
            if args.overwrite:
                os.remove(f)
            else:
                pass
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None
        win_eval = None
    if 'GameOfLife' in args.env_name:
        print('env name: {}'.format(args.env_name))
        num_actions = 1
    envs = make_vec_envs(args.env_name,
                         args.seed,
                         args.num_processes,
                         args.gamma,
                         args.log_dir,
                         args.add_timestep,
                         device,
                         False,
                         None,
                         args=args)

    if isinstance(envs.observation_space, gym.spaces.Discrete):
        num_inputs = envs.observation_space.n
    elif isinstance(envs.observation_space, gym.spaces.Box):
        if len(envs.observation_space.shape) == 3:
            in_w = envs.observation_space.shape[1]
            in_h = envs.observation_space.shape[2]
        else:
            in_w = 1
            in_h = 1
        num_inputs = envs.observation_space.shape[0]
    if isinstance(envs.action_space, gym.spaces.Discrete):
        out_w = 1
        out_h = 1
        if 'Micropolis' in args.env_name:  #otherwise it's set
            if args.power_puzzle:
                num_actions = 1
            else:
                num_actions = 19  # TODO: have this already from env
        elif 'GameOfLife' in args.env_name:
            num_actions = 1
        else:
            num_actions = envs.action_space.n
    elif isinstance(envs.action_space, gym.spaces.Box):
        if len(envs.action_space.shape) == 3:
            out_w = envs.action_space.shape[1]
            out_h = envs.action_space.shape[2]
        elif len(envs.action_space.shape) == 1:
            out_w = 1
            out_h = 1
        num_actions = envs.action_space.shape[-1]
    print('num actions {}'.format(num_actions))

    if args.auto_expand:
        args.n_recs -= 1
    actor_critic = Policy(envs.observation_space.shape,
                          envs.action_space,
                          base_kwargs={
                              'map_width': args.map_width,
                              'num_actions': num_actions,
                              'recurrent': args.recurrent_policy,
                              'prebuild': args.prebuild,
                              'in_w': in_w,
                              'in_h': in_h,
                              'num_inputs': num_inputs,
                              'out_w': out_w,
                              'out_h': out_h
                          },
                          curiosity=args.curiosity,
                          algo=args.algo,
                          model=args.model,
                          args=args)
    if args.auto_expand:
        args.n_recs += 1

    evaluator = None

    if not agent:
        agent = init_agent(actor_critic, args)

    #saved_model = os.path.join(args.save_dir, args.env_name + '.pt')
    if args.load_dir:
        saved_model = os.path.join(args.load_dir, args.env_name + '.tar')
    else:
        saved_model = os.path.join(args.save_dir, args.env_name + '.tar')
    vec_norm = get_vec_normalize(envs)
    alp_gmm = None
    if os.path.exists(saved_model) and not args.overwrite:
        checkpoint = torch.load(saved_model)
        saved_args = checkpoint['args']
        actor_critic.load_state_dict(checkpoint['model_state_dict'])
        # .to(device) already covers the CUDA case; an unconditional .cuda()
        # call here would crash on CPU-only machines, so it is omitted.
        actor_critic.to(device)
        #agent = init_agent(actor_critic, saved_args)
        agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if args.auto_expand:
            if args.n_recs - saved_args.n_recs != 1:
                raise Exception(
                    'can expand by 1 rec only from saved model, not {}'.format(
                        args.n_recs - saved_args.n_recs))
            actor_critic.base.auto_expand()
            print('expanded net: \n{}'.format(actor_critic.base))
        past_steps = checkpoint['past_steps']
        ob_rms = checkpoint['ob_rms']

        past_steps = next(iter(
            agent.optimizer.state_dict()['state'].values()))['step']
        print('Resuming from step {}'.format(past_steps))

        #print(type(next(iter((torch.load(saved_model))))))
        #actor_critic, ob_rms = \
        #        torch.load(saved_model)
        #agent = \
        #    torch.load(os.path.join(args.save_dir, args.env_name + '_agent.pt'))
        #if not agent.optimizer.state_dict()['state'].values():
        #    past_steps = 0
        #else:

        #    raise Exception
        alp_gmm = checkpoint['alp_gmm']

        if vec_norm is not None:
            vec_norm.eval()
            vec_norm.ob_rms = ob_rms
        saved_args.num_frames = args.num_frames
        saved_args.vis_interval = args.vis_interval
        saved_args.eval_interval = args.eval_interval
        saved_args.overwrite = args.overwrite
        saved_args.n_recs = args.n_recs
        saved_args.intra_shr = args.intra_shr
        saved_args.inter_shr = args.inter_shr
        saved_args.map_width = args.map_width
        saved_args.render = args.render
        saved_args.print_map = args.print_map
        saved_args.load_dir = args.load_dir
        saved_args.experiment_name = args.experiment_name
        saved_args.log_dir = args.log_dir
        saved_args.save_dir = args.save_dir
        saved_args.num_processes = args.num_processes
        saved_args.n_chan = args.n_chan
        saved_args.prebuild = args.prebuild
        args = saved_args
    actor_critic.to(device)

    if 'LSTM' in args.model:
        recurrent_hidden_state_size = actor_critic.base.get_recurrent_state_size()
    else:
        recurrent_hidden_state_size = actor_critic.recurrent_hidden_state_size
    if args.curiosity:
        rollouts = CuriosityRolloutStorage(
            args.num_steps,
            args.num_processes,
            envs.observation_space.shape,
            envs.action_space,
            recurrent_hidden_state_size,
            actor_critic.base.feature_state_size(),
            args=args)
    else:
        rollouts = RolloutStorage(args.num_steps,
                                  args.num_processes,
                                  envs.observation_space.shape,
                                  envs.action_space,
                                  recurrent_hidden_state_size,
                                  args=args)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    model = actor_critic.base
    reset_eval = False
    plotter = None
    if args.model == 'FractalNet' or args.model == 'fractal':
        n_cols = model.n_cols
        if args.rule == 'wide1' and args.n_recs > 3:
            col_step = 3
        else:
            col_step = 1
    else:
        n_cols = 0
        col_step = 1
    env_param_bounds = envs.venv.venv.get_param_bounds()
    envs.venv.venv.set_param_ranges(env_param_bounds)
    num_env_params = len(env_param_bounds)
    env_param_ranges = [abs(v[1] - v[0]) for k, v in env_param_bounds.items()]
    env_param_lw_bounds = [v[0] for k, v in env_param_bounds.items()]
    env_param_hi_bounds = [v[1] for k, v in env_param_bounds.items()]
    if alp_gmm is None:
        alp_gmm = ALPGMM(env_param_lw_bounds, env_param_hi_bounds)
    params_vec = alp_gmm.sample_task()
    params = OrderedDict()
    print('\n env_param_bounds', env_param_bounds)
    print(params_vec)
    trial_remaining = args.max_step
    trial_reward = 0
    for j in range(past_steps, num_updates):
        if trial_remaining == 0:
            trial_reward = trial_reward / args.num_processes
            alp_gmm.update(params_vec, trial_reward)
            trial_reward = 0
            trial_remaining = args.max_step
            # sample random environment parameters
            params_vec = alp_gmm.sample_task()
            prm_i = 0
            for k, v in env_param_bounds.items():
                params[k] = params_vec[prm_i]
                prm_i += 1
            envs.venv.venv.set_params(params)
        trial_remaining -= args.num_steps
        if reset_eval:
            print('post eval reset')
            obs = envs.reset()
            rollouts.obs[0].copy_(obs)
            rollouts.to(device)
            reset_eval = False
    #if np.random.rand(1) < 0.1:
    #    envs.venv.venv.remotes[1].send(('setRewardWeights', None))
        if args.model == 'FractalNet' and args.drop_path:
            #if args.intra_shr and args.inter_shr:
            #    n_recs = np.randint
            #    model.set_n_recs()
            model.set_drop_path()
        if args.model == 'fixed' and model.RAND:
            model.num_recursions = random.randint(1, model.map_width * 2)
        player_act = None
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                if args.render:
                    if args.num_processes == 1:
                        if not ('Micropolis' in args.env_name
                                or 'GameOfLife' in args.env_name):
                            envs.venv.venv.render()
                        else:
                            pass
                    else:
                        if not ('Micropolis' in args.env_name
                                or 'GameOfLife' in args.env_name):
                            envs.render()
                            envs.venv.venv.render()
                        else:
                            pass
                        #envs.venv.venv.remotes[0].send(('render', None))
                        #envs.venv.venv.remotes[0].recv()
                value, action, action_log_probs, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step],
                    rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step],
                    player_act=player_act,
                    icm_enabled=args.curiosity,
                    deterministic=False)

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            player_act = None
            if args.render:
                if infos[0]:
                    if 'player_move' in infos[0].keys():
                        player_act = infos[0]['player_move']
            if args.curiosity:
                # run icm
                with torch.no_grad():

                    feature_state, feature_state_pred, action_dist_pred = actor_critic.icm_act(
                        (rollouts.obs[step], obs, action_bin))

                intrinsic_reward = args.eta * (
                    (feature_state - feature_state_pred).pow(2)).sum() / 2.
                if args.no_reward:
                    reward = 0
                reward += intrinsic_reward.cpu()

            for info in infos:
                if 'episode' in info.keys():
                    epi_reward = info['episode']['r']
                    episode_rewards.append(epi_reward)
                    trial_reward += epi_reward

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])

            if args.curiosity:
                rollouts.insert(obs, recurrent_hidden_states, action,
                                action_log_probs, value, reward, masks,
                                feature_state, feature_state_pred, action_bin,
                                action_dist_pred)
            else:
                rollouts.insert(obs, recurrent_hidden_states, action,
                                action_log_probs, value, reward, masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)
        if args.curiosity:
            value_loss, action_loss, dist_entropy, fwd_loss, inv_loss = agent.update(
                rollouts)
        else:
            value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        total_num_steps = (j + 1) * args.num_processes * args.num_steps

        if not dist_entropy:
            dist_entropy = 0
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n \
dist entropy {:.1f}, val/act loss {:.1f}/{:.1f},".format(
                    j, total_num_steps,
                    int((total_num_steps -
                         past_steps * args.num_processes * args.num_steps) /
                        (end - start)), len(episode_rewards),
                    np.mean(episode_rewards), np.median(episode_rewards),
                    np.min(episode_rewards), np.max(episode_rewards),
                    dist_entropy, value_loss, action_loss))
            if args.curiosity:
                print("fwd/inv icm loss {:.1f}/{:.1f}\n".format(
                    fwd_loss, inv_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            if evaluator is None:
                evaluator = Evaluator(args,
                                      actor_critic,
                                      device,
                                      envs=envs,
                                      vec_norm=vec_norm)

            model = evaluator.actor_critic.base

            col_idx = [-1, *range(0, n_cols, col_step)]
            for i in col_idx:
                evaluator.evaluate(column=i)
        #num_eval_frames = (args.num_frames // (args.num_steps * args.eval_interval * args.num_processes)) * args.num_processes *  args.max_step
        # making sure the evaluator plots the '-1'st column (the overall net)

            if args.vis:  #and j % args.vis_interval == 0:
                try:
                    # Sometimes monitor doesn't properly flush the outputs
                    win_eval = evaluator.plotter.visdom_plot(
                        viz,
                        win_eval,
                        evaluator.eval_log_dir,
                        graph_name,
                        args.algo,
                        args.num_frames,
                        n_graphs=col_idx)
                except IOError:
                    pass
        #elif args.model == 'fixed' and model.RAND:
        #    for i in model.eval_recs:
        #        evaluator.evaluate(num_recursions=i)
        #    win_eval = visdom_plot(viz, win_eval, evaluator.eval_log_dir, graph_name,
        #                           args.algo, args.num_frames, n_graphs=model.eval_recs)
        #else:
        #    evaluator.evaluate(column=-1)
        #    win_eval = visdom_plot(viz, win_eval, evaluator.eval_log_dir, graph_name,
        #                  args.algo, args.num_frames)
            reset_eval = True

        if j % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # Save a CPU copy of the model; deep-copy so the training weights stay on the GPU
            ob_rms = getattr(get_vec_normalize(envs), 'ob_rms', None)
            save_model = copy.deepcopy(actor_critic)
            save_agent = copy.deepcopy(agent)
            if args.cuda:
                save_model.cpu()
            optim_save = save_agent.optimizer.state_dict()

            # experimental:
            torch.save(
                {
                    'past_steps':
                    next(iter(agent.optimizer.state_dict()['state'].values()))
                    ['step'],
                    'model_state_dict':
                    save_model.state_dict(),
                    'optimizer_state_dict':
                    optim_save,
                    'ob_rms':
                    ob_rms,
                    'args':
                    args,
                    'alp_gmm':
                    alp_gmm
                }, os.path.join(save_path, args.env_name + ".tar"))

        #save_model = [save_model,
        #              getattr(get_vec_normalize(envs), 'ob_rms', None)]

        #torch.save(save_model, os.path.join(save_path, args.env_name + ".pt"))
        #save_agent = copy.deepcopy(agent)

        #torch.save(save_agent, os.path.join(save_path, args.env_name + '_agent.pt'))
        #torch.save(actor_critic.state_dict(), os.path.join(save_path, args.env_name + "_weights.pt"))

        if args.vis and j % args.vis_interval == 0:
            if plotter is None:
                plotter = Plotter(n_cols, args.log_dir, args.num_processes)
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = plotter.visdom_plot(viz, win, args.log_dir, graph_name,
                                          args.algo, args.num_frames)
            except IOError:
                pass
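
The curriculum bookkeeping buried in the training loop of Example No. 4 can be read in isolation: once trial_remaining reaches zero, the accumulated trial reward is averaged over processes, reported to ALP-GMM, and a fresh parameter vector is sampled and pushed into the vectorized environment. A condensed, hedged restatement of that logic (set_env_params stands in for the envs.venv.venv.set_params call; everything else follows the names in the snippet):

def curriculum_tick(alp_gmm, params_vec, trial_reward, trial_remaining,
                    env_param_bounds, num_processes, max_step, set_env_params):
    # One curriculum step, run between policy updates.
    if trial_remaining == 0:
        alp_gmm.update(params_vec, trial_reward / num_processes)  # report mean trial return
        params_vec = alp_gmm.sample_task()                        # sample new task parameters
        params = {k: params_vec[i] for i, k in enumerate(env_param_bounds)}
        set_env_params(params)                                    # push them into the env
        trial_reward = 0
        trial_remaining = max_step
    return params_vec, trial_reward, trial_remaining
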