Example #1
def act(i: int, free_queue: mp.SimpleQueue, full_queue: mp.SimpleQueue,
        model: torch.nn.Module, buffers: Buffers,
        episode_state_count_dict: dict, train_state_count_dict: dict,
        initial_agent_state_buffers, flags):
    try:
        log.info('Actor %i started.', i)
        timings = prof.Timings()

        gym_env = create_env(flags)
        seed = i ^ int.from_bytes(os.urandom(4), byteorder='little')
        gym_env.seed(seed)

        if flags.num_input_frames > 1:
            gym_env = FrameStack(gym_env, flags.num_input_frames)

        env = Environment(gym_env,
                          fix_seed=flags.fix_seed,
                          env_seed=flags.env_seed)

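        # Prime the first transition: reset the environment and run one forward
        # pass so the loop below can write row 0 of a rollout buffer.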
        env_output = env.initial()
        agent_state = model.initial_state(batch_size=1)
        agent_output, unused_state = model(env_output, agent_state)

        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            # Note: use j here; i is the actor index and is read again below.
            for j, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][j][...] = tensor

            # Update the episodic state counts
            episode_state_key = tuple(env_output['frame'].view(-1).tolist())
            episode_state_count_dict[episode_state_key] = \
                episode_state_count_dict.get(episode_state_key, 0) + 1
            buffers['episode_state_count'][index][0, ...] = \
                torch.tensor(1 / np.sqrt(episode_state_count_dict[episode_state_key]))

            # Reset the episode state counts when the episode is over
            if env_output['done'][0][0]:
                episode_state_count_dict = dict()

            # Update the training state counts if you're doing count-based exploration
            if flags.model == 'count':
                train_state_key = tuple(env_output['frame'].view(-1).tolist())
                train_state_count_dict[train_state_key] = \
                    train_state_count_dict.get(train_state_key, 0) + 1
                buffers['train_state_count'][index][0, ...] = \
                    torch.tensor(1 / np.sqrt(train_state_count_dict[train_state_key]))

            # Do new rollout
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state = model(env_output, agent_state)

                timings.time('model')

                env_output = env.step(agent_output['action'])

                timings.time('step')

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]

                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                # Update the episodic state counts
                episode_state_key = tuple(
                    env_output['frame'].view(-1).tolist())
                episode_state_count_dict[episode_state_key] = \
                    episode_state_count_dict.get(episode_state_key, 0) + 1
                buffers['episode_state_count'][index][t + 1, ...] = \
                    torch.tensor(1 / np.sqrt(episode_state_count_dict[episode_state_key]))

                # Reset the episode state counts when the episode is over
                if env_output['done'][0][0]:
                    episode_state_count_dict = dict()

                # Update the training state counts if you're doing count-based exploration
                if flags.model == 'count':
                    train_state_key = tuple(
                        env_output['frame'].view(-1).tolist())
                    train_state_count_dict[train_state_key] = \
                        train_state_count_dict.get(train_state_key, 0) + 1
                    buffers['train_state_count'][index][t + 1, ...] = \
                        torch.tensor(1 / np.sqrt(train_state_count_dict[train_state_key]))

                timings.time('write')
            full_queue.put(index)

        if i == 0:
            log.info('Actor %i: %s', i, timings.summary())

    except KeyboardInterrupt:
        pass
    except Exception:
        logging.error('Exception in worker process %i', i)
        traceback.print_exc()
        print()
        raise
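The two count tables above feed the same exploration bonus: a state visited N times contributes 1 / sqrt(N), so the bonus decays as the state becomes familiar. A minimal sketch of that pattern, using a hypothetical helper name (count_bonus is not part of the source):

import numpy as np
import torch

def count_bonus(counts: dict, key) -> torch.Tensor:
    # Increment the visit count for `key` and return the 1/sqrt(N) bonus
    # that the actor writes into its rollout buffer.
    counts[key] = counts.get(key, 0) + 1
    return torch.tensor(1 / np.sqrt(counts[key]))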
Example #2
def train(flags):
    if flags.xpid is None:
        flags.xpid = 'torchbeast-%s' % time.strftime('%Y%m%d-%H%M%S')
    plogger = file_writer.FileWriter(
        xpid=flags.xpid,
        xp_args=flags.__dict__,
        rootdir=flags.savedir,
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser('%s/%s/%s' %
                           (flags.savedir, flags.xpid, 'model.tar')))

    T = flags.unroll_length
    B = flags.batch_size

    flags.device = None
    if not flags.disable_cuda and torch.cuda.is_available():
        log.info('Using CUDA.')
        flags.device = torch.device('cuda')
    else:
        log.info('Not using CUDA.')
        flags.device = torch.device('cpu')

    env = create_env(flags)
    if flags.num_input_frames > 1:
        env = FrameStack(env, flags.num_input_frames)

    if 'MiniGrid' in flags.env:
        if flags.use_fullobs_policy:
            model = FullObsMinigridPolicyNet(env.observation_space.shape,
                                             env.action_space.n)
        else:
            model = MinigridPolicyNet(env.observation_space.shape,
                                      env.action_space.n)
    else:
        model = MarioDoomPolicyNet(env.observation_space.shape,
                                   env.action_space.n)

    buffers = create_buffers(env.observation_space.shape, model.num_actions,
                             flags)
    if flags.histogram_length:
        action_hist = WindowedHisto(env.action_space.n, flags.histogram_length)
    else:
        action_hist = FlatHisto(env.action_space.n)

    model.share_memory()

    initial_agent_state_buffers = []
    for _ in range(flags.num_buffers):
        state = model.initial_state(batch_size=1)
        for t in state:
            t.share_memory_()
        initial_agent_state_buffers.append(state)

    actor_processes = []
    ctx = mp.get_context('fork')
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    episode_state_count_dict = dict()
    train_state_count_dict = dict()
    position_count = dict()

    for i in range(flags.num_actors):
        actor = ctx.Process(target=act,
                            args=(i, free_queue, full_queue, model, buffers,
                                  episode_state_count_dict,
                                  train_state_count_dict,
                                  initial_agent_state_buffers, flags),
                            daemon=True)
        actor.start()
        actor_processes.append(actor)

    if 'MiniGrid' in flags.env:
        if flags.use_fullobs_policy:
            learner_model = FullObsMinigridPolicyNet(env.observation_space.shape, env.action_space.n)\
                .to(device=flags.device)
        else:
            learner_model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)\
                .to(device=flags.device)
    else:
        learner_model = MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n)\
            .to(device=flags.device)

    optimizer = torch.optim.RMSprop(learner_model.parameters(),
                                    lr=flags.learning_rate,
                                    momentum=flags.momentum,
                                    eps=flags.epsilon,
                                    alpha=flags.alpha)

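    # Anneal the learning rate linearly to zero over flags.total_frames:
    # each scheduler epoch is one learner step that consumes T * B frames,
    # so lr_lambda falls from 1 at frame 0 to 0 at total_frames.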
    def lr_lambda(epoch):
        return 1 - min(epoch * T * B, flags.total_frames) / flags.total_frames

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    logger = logging.getLogger('logfile')
    stat_keys = [
        'total_loss',
        'mean_episode_return',
        'pg_loss',
        'baseline_loss',
        'entropy_loss',
        'mean_rewards',
        'mean_intrinsic_rewards',
        'mean_total_rewards',
    ]

    logger.info('# Step\t%s', '\t'.join(stat_keys))
    frames, stats = 0, {}

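    # Note: the default-argument lock is created once, at definition time, so
    # every batch_and_learn thread shares the same lock for the update below.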
    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal frames, stats
        timings = prof.Timings()
        while frames < flags.total_frames:
            timings.reset()
            batch, agent_state = get_batch(free_queue, full_queue, buffers,
                                           initial_agent_state_buffers, flags,
                                           timings)
            stats = learn(model,
                          learner_model,
                          batch,
                          agent_state,
                          optimizer,
                          scheduler,
                          flags,
                          position_count=position_count,
                          action_hist=action_hist)

            timings.time('learn')
            with lock:
                to_log = dict(frames=frames)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                frames += T * B

        if i == 0:
            log.info('Batch and learn: %s', timings.summary())

    for m in range(flags.num_buffers):
        free_queue.put(m)

    threads = []
    for i in range(flags.num_threads):
        thread = threading.Thread(target=batch_and_learn,
                                  name='batch-and-learn-%d' % i,
                                  args=(i, ))
        thread.start()
        threads.append(thread)

    def checkpoint(frames):
        if flags.disable_checkpoint:
            return
        checkpointpath = os.path.expandvars(
            os.path.expanduser('%s/%s/%s' %
                               (flags.savedir, flags.xpid, 'model.tar')))
        log.info('Saving checkpoint to %s', checkpointpath)
        torch.save(
            {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'position_count': position_count,
                'flags': vars(flags),
            }, checkpointpath)

        # Action histogram logger
        action_hist_path = os.path.expandvars(
            os.path.expanduser('%s/%s/%s' %
                               (flags.savedir, flags.xpid, 'action_hist.tar')))
        try:
            action_hist_list = torch.load(action_hist_path)
        except FileNotFoundError:
            action_hist_list = []

        action_hist_list.append(action_hist.return_full_hist())
        torch.save(action_hist_list, action_hist_path)

    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer()
        while frames < flags.total_frames:

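            # Supervise the actors: replace any process that has died with a
            # fresh one serving the same actor index.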
            add_actor = []
            for actor_num, actor in enumerate(actor_processes):
                if not actor.is_alive():
                    actor.terminate()
                    new_actor = ctx.Process(
                        target=act,
                        args=(actor_num, free_queue, full_queue, model,
                              buffers, episode_state_count_dict,
                              train_state_count_dict,
                              initial_agent_state_buffers, flags),
                        daemon=True)
                    new_actor.start()
                    add_actor.append([actor_num, new_actor])

            for num, actor in add_actor:
                actor_processes[num] = actor

            start_frames = frames
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > flags.save_interval * 60:
                checkpoint(frames)
                last_checkpoint_time = timer()

            fps = (frames - start_frames) / (timer() - start_time)
            if stats.get('episode_returns', None):
                mean_return = 'Return per episode: %.1f. ' % stats[
                    'mean_episode_return']
            else:
                mean_return = ''
            total_loss = stats.get('total_loss', float('inf'))
            log.info('After %i frames: loss %f @ %.1f fps. %sStats:\n%s',
                     frames, total_loss, fps, mean_return,
                     pprint.pformat(stats))

    except KeyboardInterrupt:
        return
    else:
        for thread in threads:
            thread.join(timeout=1)
        log.info('Learning finished after %d frames.', frames)

    finally:
        for _ in range(flags.num_actors):
            free_queue.put(None)
        for actor in actor_processes:
            actor.join(timeout=1)
    checkpoint(frames)
    plogger.close()
    quit()
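The two SimpleQueues implement a ring-buffer handshake between the actors and the learner: train seeds free_queue with every buffer index, an actor pops an index, fills that slot with one unroll, and pushes it to full_queue; once a learner thread has consumed a batch, the indices go back onto free_queue, and a None on free_queue tells an actor to exit. A minimal single-process sketch of the protocol (a hypothetical demo, not part of the source):

import multiprocessing as mp

def queue_handshake_demo(num_buffers: int = 4):
    ctx = mp.get_context('fork')
    free_queue, full_queue = ctx.SimpleQueue(), ctx.SimpleQueue()
    for m in range(num_buffers):
        free_queue.put(m)      # train: every slot starts out free
    index = free_queue.get()   # actor: claim a slot for one rollout
    full_queue.put(index)      # actor: hand the filled slot to the learner
    index = full_queue.get()   # learner: consume the rollout...
    free_queue.put(index)      # ...then recycle the slot
    free_queue.put(None)       # shutdown signal for one actor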
Example #3
def act(i: int, free_queue: mp.SimpleQueue, full_queue: mp.SimpleQueue,
        model: torch.nn.Module, buffers: Buffers, initial_agent_state_buffers, flags):
    try:
        log.info('Actor %i started.', i)
        timings = prof.Timings() 
        
        gym_env = create_env(flags)
        seed = i ^ int.from_bytes(os.urandom(4), byteorder='little')
        gym_env.seed(seed)
        
        if flags.num_input_frames > 1:
            gym_env = FrameStack(gym_env, flags.num_input_frames)  

        env = Environment(gym_env, fix_seed=flags.fix_seed, env_seed=flags.env_seed)

        env_output = env.initial()
        agent_state = model.initial_state(batch_size=1)
        agent_output, unused_state = model(env_output, agent_state)
        

        while True:
            index = free_queue.get()
            if index is None:
                break
            
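            # Write old rollout end.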
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            # Note: use j here; i is the actor index and is read again below.
            for j, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][j][...] = tensor

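            # Do new rollout.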
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state = model(env_output, agent_state)

                timings.time('model')

                env_output = env.step(agent_output['action'])

                timings.time('step')

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]
                
                timings.time('write')
            full_queue.put(index)

        if i == 0:
            log.info('Actor %i: %s', i, timings.summary())

    except KeyboardInterrupt:
        pass  
    except Exception:
        logging.error('Exception in worker process %i', i)
        traceback.print_exc()
        print()
        raise