Code example #1
File: main.py    Project: wwxFromTju/gym-city
def main():
    import random
    import gym_micropolis
    import game_of_life

    args = get_args()
    args.log_dir = args.save_dir + '/logs'
    assert args.algo in ['a2c', 'ppo', 'acktr']
    if args.recurrent_policy:
        assert args.algo in ['a2c', 'ppo'], \
            'Recurrent policy is not implemented for ACKTR'

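    # One policy update consumes num_steps transitions from each of num_processes envs,
    # so the total number of updates is num_frames / (num_steps * num_processes).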
    num_updates = int(args.num_frames) // args.num_steps // args.num_processes

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    graph_name = args.save_dir.split('trained_models/')[1].replace('/', ' ')

    actor_critic = None
    agent = None
    past_steps = 0
    try:
        os.makedirs(args.log_dir)
    except OSError:
        files = glob.glob(os.path.join(args.log_dir, '*.monitor.csv'))
        if args.overwrite:
            for f in files:
                os.remove(f)
    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    if args.vis:
        from visdom import Visdom
        viz = Visdom(port=args.port)
        win = None
        win_eval = None
    if 'GameOfLife' in args.env_name:
        print('env name: {}'.format(args.env_name))
        num_actions = 1
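    # Build the vectorized training environments; per-process monitor logs go to args.log_dir.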
    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
                         args.gamma, args.log_dir, args.add_timestep, device,
                         False, None, args=args)

    if isinstance(envs.observation_space, gym.spaces.Discrete):
        num_inputs = envs.observation_space.n
        in_w = 1
        in_h = 1
    elif isinstance(envs.observation_space, gym.spaces.Box):
        if len(envs.observation_space.shape) == 3:
            in_w = envs.observation_space.shape[1]
            in_h = envs.observation_space.shape[2]
        else:
            in_w = 1
            in_h = 1
        num_inputs = envs.observation_space.shape[0]
    if isinstance(envs.action_space, gym.spaces.Discrete):
        out_w = 1
        out_h = 1
        if 'Micropolis' in args.env_name: # num_actions for other env types is set below
            if args.power_puzzle:
                num_actions = 1
            else:
                num_actions = 19 # TODO: have this already from env
        elif 'GameOfLife' in args.env_name:
            num_actions = 1
        else:
            num_actions = envs.action_space.n
    elif isinstance(envs.action_space, gym.spaces.Box):
        if len(envs.action_space.shape) == 3:
            out_w = envs.action_space.shape[1]
            out_h = envs.action_space.shape[2]
        elif len(envs.action_space.shape) == 1:
            out_w = 1
            out_h = 1
        num_actions = envs.action_space.shape[-1]
    print('num actions {}'.format(num_actions))

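    # Build the actor-critic policy. With args.auto_expand the net is first constructed one
    # recursion shallower, then grown by one rec after the saved weights are loaded below.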
    if args.auto_expand:
        args.n_recs -= 1
    actor_critic = Policy(envs.observation_space.shape, envs.action_space,
        base_kwargs={'map_width': args.map_width, 'num_actions': num_actions,
            'recurrent': args.recurrent_policy,
            'in_w': in_w, 'in_h': in_h, 'num_inputs': num_inputs,
            'out_w': out_w, 'out_h': out_h},
                     curiosity=args.curiosity, algo=args.algo,
                     model=args.model, args=args)
    if args.auto_expand:
        args.n_recs += 1

    evaluator = None

    if not agent:
        agent = init_agent(actor_critic, args)

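    # Resume from a saved checkpoint (<save_dir or load_dir>/<env_name>.tar) unless args.overwrite is set.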
    if args.load_dir:
        saved_model = os.path.join(args.load_dir, args.env_name + '.tar')
    else:
        saved_model = os.path.join(args.save_dir, args.env_name + '.tar')
    vec_norm = get_vec_normalize(envs)
    if os.path.exists(saved_model) and not args.overwrite:
        checkpoint = torch.load(saved_model)
        saved_args = checkpoint['args']
        actor_critic.load_state_dict(checkpoint['model_state_dict'])
        actor_critic.to(device)
        agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if args.auto_expand:
            if not args.n_recs - saved_args.n_recs == 1:
                print('can expand by 1 rec only from saved model, not {}'.format(args.n_recs - saved_args.n_recs))
                raise Exception
            actor_critic.base.auto_expand()
            print('expanded net: \n{}'.format(actor_critic.base))
        past_steps = checkpoint['past_steps']
        ob_rms = checkpoint['ob_rms']
        print('Resuming from step {}'.format(past_steps))

        if vec_norm is not None:
            vec_norm.eval()
            vec_norm.ob_rms = ob_rms
        saved_args.num_frames = args.num_frames
        saved_args.vis_interval = args.vis_interval
        saved_args.eval_interval = args.eval_interval
        saved_args.overwrite = args.overwrite
        saved_args.n_recs = args.n_recs
        saved_args.intra_shr = args.intra_shr
        saved_args.inter_shr = args.inter_shr
        saved_args.map_width = args.map_width
        saved_args.render = args.render
        saved_args.print_map = args.print_map
        saved_args.load_dir = args.load_dir
        saved_args.experiment_name = args.experiment_name
        saved_args.log_dir = args.log_dir
        saved_args.save_dir = args.save_dir
        args = saved_args
    actor_critic.to(device)

    if 'LSTM' in args.model:
        recurrent_hidden_state_size = actor_critic.base.get_recurrent_state_size()
    else:
        recurrent_hidden_state_size = actor_critic.recurrent_hidden_state_size
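    # Rollout storage keeps num_steps transitions per process; the curiosity variant
    # additionally stores ICM feature states and predictions.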
    if args.curiosity:
        rollouts = CuriosityRolloutStorage(args.num_steps, args.num_processes,
                            envs.observation_space.shape, envs.action_space,
                            recurrent_hidden_state_size, actor_critic.base.feature_state_size(), args=args)
    else:
        rollouts = RolloutStorage(args.num_steps, args.num_processes,
                            envs.observation_space.shape, envs.action_space,
                            recurrent_hidden_state_size, args=args)

    obs = envs.reset()
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    model = actor_critic.base
    reset_eval = False
    plotter = None
    if args.model == 'FractalNet' or args.model == 'fractal':
        n_cols = model.n_cols
        if args.rule == 'wide1' and args.n_recs > 3:
            col_step = 3
        else:
            col_step = 1
    else:
        n_cols = 0
        col_step = 1
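    # Main training loop: collect num_steps transitions per process, then perform one update.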
    for j in range(past_steps, num_updates):
        if reset_eval:
            print('post eval reset')
            obs = envs.reset()
            rollouts.obs[0].copy_(obs)
            rollouts.to(device)
            reset_eval = False
        if args.model == 'FractalNet' and args.drop_path:
            model.set_drop_path()
        if args.model == 'fixed' and model.RAND:
            model.num_recursions = random.randint(1, model.map_width * 2)
        player_act = None
        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                if args.render:
                    if not ('Micropolis' in args.env_name or 'GameOfLife' in args.env_name):
                        if args.num_processes == 1:
                            envs.venv.venv.render()
                        else:
                            envs.render()
                            envs.venv.venv.render()
                value, action, action_log_probs, recurrent_hidden_states = actor_critic.act(
                        rollouts.obs[step],
                        rollouts.recurrent_hidden_states[step],
                        rollouts.masks[step],
                        player_act=player_act,
                        icm_enabled=args.curiosity,
                        deterministic=False)

            # Observe reward and next obs
            obs, reward, done, infos = envs.step(action)

            player_act = None
            if args.render:
                if infos[0]:
                    if 'player_move' in infos[0].keys():
                        player_act = infos[0]['player_move']
            if args.curiosity:
                # run icm
                with torch.no_grad():
                    # NOTE: action_bin (a binary/one-hot encoding of the chosen action) is
                    # assumed to be defined elsewhere in the original module; it does not
                    # appear in this excerpt.
                    feature_state, feature_state_pred, action_dist_pred = actor_critic.icm_act(
                            (rollouts.obs[step], obs, action_bin))

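                # Intrinsic reward = eta/2 * squared error of the ICM forward-model prediction.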
                intrinsic_reward = args.eta * ((feature_state - feature_state_pred).pow(2)).sum() / 2.
                if args.no_reward:
                    reward = 0
                reward += intrinsic_reward.cpu()

            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            if args.curiosity:
                rollouts.insert(obs, recurrent_hidden_states, action, action_log_probs, value, reward, masks,
                                feature_state, feature_state_pred, action_bin, action_dist_pred)
            else:
                rollouts.insert(obs, recurrent_hidden_states, action, action_log_probs, value, reward, masks)

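        # Bootstrap the value of the last observation, compute returns (optionally with GAE),
        # and run one policy update on the collected rollout.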
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1],
                                                rollouts.recurrent_hidden_states[-1],
                                                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)
        if args.curiosity:
            value_loss, action_loss, dist_entropy, fwd_loss, inv_loss = agent.update(rollouts)
        else:
            value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        total_num_steps = (j + 1) * args.num_processes * args.num_steps

        if not dist_entropy:
            dist_entropy = 0
        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            end = time.time()
            print("Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n \
dist entropy {:.1f}, val/act loss {:.1f}/{:.1f},".
                format(j, total_num_steps,
                       int((total_num_steps - past_steps * args.num_processes * args.num_steps) / (end - start)),
                       len(episode_rewards),
                       np.mean(episode_rewards),
                       np.median(episode_rewards),
                       np.min(episode_rewards),
                       np.max(episode_rewards), dist_entropy,
                       value_loss, action_loss))
            if args.curiosity:
                print("fwd/inv icm loss {:.1f}/{:.1f}\n".
                format(
                       fwd_loss, inv_loss))

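        # Periodic evaluation: col_idx covers each fractal column plus -1 for the full net.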
        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            if evaluator is None:
                evaluator = Evaluator(args, actor_critic, device, envs=envs, vec_norm=vec_norm)


            model = evaluator.actor_critic.base

            col_idx = [-1, *range(0, n_cols, col_step)]
            for i in col_idx:
                evaluator.evaluate(column=i)
            # making sure the evaluator plots the '-1'st column (the overall net)

            if args.vis: #and j % args.vis_interval == 0:
                try:
                    # Sometimes monitor doesn't properly flush the outputs
                    win_eval = evaluator.plotter.visdom_plot(viz, win_eval, evaluator.eval_log_dir, graph_name,
                                  args.algo, args.num_frames, n_graphs= col_idx)
                except IOError:
                    pass
            reset_eval = True

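        # Periodically save model weights, optimizer state, ob_rms, and args to <save_dir>/<env_name>.tar.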
        if j % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # Make a CPU copy of the model for saving
            ob_rms = getattr(get_vec_normalize(envs), 'ob_rms', None)
            save_model = copy.deepcopy(actor_critic)
            save_agent = copy.deepcopy(agent)
            if args.cuda:
                save_model.cpu()
            optim_save = save_agent.optimizer.state_dict()

            # experimental:
            torch.save({
                'past_steps': next(iter(agent.optimizer.state_dict()['state'].values()))['step'],
                'model_state_dict': save_model.state_dict(),
                'optimizer_state_dict': optim_save,
                'ob_rms': ob_rms,
                'args': args
                }, os.path.join(save_path, args.env_name + ".tar"))


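        # Plot training curves to visdom every vis_interval updates.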
        if args.vis and j % args.vis_interval == 0:
            if plotter is None:
                plotter = Plotter(n_cols, args.log_dir, args.num_processes)
            try:
                # Sometimes monitor doesn't properly flush the outputs
                win = plotter.visdom_plot(viz, win, args.log_dir, graph_name,
                                  args.algo, args.num_frames)
            except IOError:
                pass
Code example #2
File: train.py    Project: njustesen/gym-city
    def train(self):
        evaluator = self.evaluator
        episode_rewards = self.episode_rewards
        args = self.args
        actor_critic = self.actor_critic
        rollouts = self.rollouts
        agent = self.agent
        envs = self.envs
        plotter = self.plotter
        n_train = self.n_train
        start = self.start
        n_cols = self.n_cols
        model = self.model
        device = self.device
        vec_norm = self.vec_norm
        n_frames = self.n_frames
        if self.reset_eval:
            obs = envs.reset()
            rollouts.obs[0].copy_(obs)
            rollouts.to(device)
            self.reset_eval = False
        if args.model == 'FractalNet' and args.drop_path:
            model.set_drop_path()
        if args.model == 'fixed' and model.RAND:
            model.num_recursions = random.randint(1, model.map_width * 2)
        self.player_act = None
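        # Collect one rollout of num_steps environment steps via self.step().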
        for self.n_step in range(args.num_steps):
            # Sample actions
            self.step()

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)
        if args.curiosity:
            value_loss, action_loss, dist_entropy, fwd_loss, inv_loss = agent.update(
                rollouts)
        else:
            value_loss, action_loss, dist_entropy = agent.update(rollouts)
        envs.dist_entropy = dist_entropy

        rollouts.after_update()

        total_num_steps = (n_train + 1) * args.num_processes * args.num_steps

        if not dist_entropy:
            dist_entropy = 0

        if args.log and n_train % args.log_interval == 0 and len(
                episode_rewards) > 1:
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.6f}/{:.6f}, min/max reward {:.6f}/{:.6f}\n \
dist entropy {:.6f}, val/act loss {:.6f}/{:.6f},".format(
                    n_train, total_num_steps,
                    int((self.n_frames - self.past_frames) / (end - start)),
                    len(episode_rewards), round(np.mean(episode_rewards), 6),
                    round(np.median(episode_rewards), 6),
                    round(np.min(episode_rewards), 6),
                    round(np.max(episode_rewards), 6), round(dist_entropy, 6),
                    round(value_loss, 6), round(action_loss, 6)))
            if args.curiosity:
                print("fwd/inv icm loss {:.1f}/{:.1f}\n".format(
                    fwd_loss, inv_loss))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and n_train % args.eval_interval == 0):
            if evaluator is None:
                evaluator = Evaluator(args,
                                      actor_critic,
                                      device,
                                      envs=envs,
                                      vec_norm=vec_norm,
                                      fieldnames=self.fieldnames)
                self.evaluator = evaluator

            col_idx = [-1, *range(0, n_cols, self.col_step)]
            for i in col_idx:
                evaluator.evaluate(column=i)
            # making sure the evaluator plots the '-1'st column (the overall net)
            viz = self.viz
            win_eval = self.win_eval
            graph_name = self.graph_name
            if args.vis:  #and n_train % args.vis_interval == 0:
                try:
                    # Sometimes monitor doesn't properly flush the outputs
                    win_eval = evaluator.plotter.visdom_plot(
                        viz,
                        win_eval,
                        evaluator.eval_log_dir,
                        graph_name,
                        args.algo,
                        args.num_frames,
                        n_graphs=col_idx)
                except IOError:
                    pass
            self.reset_eval = True

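        # Checkpoint everything needed to resume (model, optimizer, ob_rms, args) via get_save_dict().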
        if args.save and n_train % args.save_interval == 0 and args.save_dir != "":
            save_path = os.path.join(args.save_dir)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            # Make a CPU copy of the model for saving
            ob_rms = getattr(get_vec_normalize(envs), 'ob_rms', None)
            save_model = copy.deepcopy(actor_critic)
            save_agent = copy.deepcopy(agent)
            if args.cuda:
                save_model.cpu()
            optim_save = save_agent.optimizer.state_dict()
            self.agent = agent
            self.save_model = save_model
            self.optim_save = optim_save
            self.args = args
            self.ob_rms = ob_rms
            torch.save(self.get_save_dict(),
                       os.path.join(save_path, args.env_name + ".tar"))

            print('model saved at {}'.format(save_path))

        if args.vis and n_train % args.vis_interval == 0:
            if plotter is None:
                plotter = Plotter(n_cols, args.log_dir, args.num_processes)
            try:
                # Sometimes monitor doesn't properly flush the outputs
                viz = self.viz
                win = self.win
                graph_name = self.graph_name
                win = plotter.visdom_plot(viz, win, args.log_dir, graph_name,
                                          args.algo, args.num_frames)
            except IOError:
                pass
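
Both examples persist training state as a plain dictionary in <save_dir>/<env_name>.tar. The sketch below is not part of the repository; the checkpoint path, the env name, and the commented-out rebuild steps are illustrative assumptions. It only shows how such a checkpoint can be inspected or reloaded with the keys the examples actually save:

import torch

# Hypothetical checkpoint path; both examples save to <save_dir>/<env_name>.tar
checkpoint = torch.load('trained_models/MicropolisEnv-v0.tar', map_location='cpu')

saved_args = checkpoint['args']  # the argparse namespace stored at save time
print('checkpoint taken after {} optimizer steps'.format(checkpoint['past_steps']))

# To resume training or evaluate, rebuild the policy and agent exactly as in the
# examples above, then restore the saved state:
# actor_critic = Policy(...)                                  # same kwargs as in main()
# actor_critic.load_state_dict(checkpoint['model_state_dict'])
# agent = init_agent(actor_critic, saved_args)
# agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# get_vec_normalize(envs).ob_rms = checkpoint['ob_rms']       # observation normalization stats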