Example #1
def test(flags, num_episodes: int = 10):
    if flags.xpid is None:
        checkpointpath = "./latest/model.tar"
    else:
        checkpointpath = os.path.expandvars(
            os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
        )

    gym_env = create_env(flags)
    env = environment.Environment(gym_env)
    model = AtariNet(gym_env.observation_space.shape, gym_env.action_space.n,
                     flags.use_lstm, flags.use_resnet)
    model.eval()
    checkpoint = torch.load(checkpointpath, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])

    observation = env.initial()
    returns = []

    while len(returns) < num_episodes:
        if flags.mode == "test_render":
            env.gym_env.render()
        agent_outputs = model(observation)
        policy_outputs, _ = agent_outputs
        observation = env.step(policy_outputs["action"])
        if observation["done"].item():
            returns.append(observation["episode_return"].item())
            logging.info(
                "Episode ended after %d steps. Return: %.1f",
                observation["episode_step"].item(),
                observation["episode_return"].item(),
            )
    env.close()
    logging.info(
        "Average returns over %i episodes: %.1f", num_episodes, sum(returns) / len(returns)
    )
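A minimal invocation sketch for the test() function above; the flags namespace and its attribute values are hypothetical placeholders for whatever the surrounding script builds with its own argument parser.

import argparse

# Hypothetical flags object; the real script constructs this from its CLI parser,
# and create_env(flags) will typically need further attributes (e.g. the env name).
flags = argparse.Namespace(
    xpid=None,            # None -> checkpoint is read from ./latest/model.tar
    savedir="~/logs",     # only read when xpid is set
    mode="test",          # "test_render" would additionally render each frame
    use_lstm=False,
    use_resnet=False,
)
test(flags, num_episodes=5)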
Example #2
def task_messenger(task):
    slzr = serializer.StandardSerializer()
    scheduler = SingleTaskScheduler(task)
    env = environment.Environment(slzr,
                                  scheduler,
                                  max_reward_per_task=float("inf"),
                                  byte_mode=True)
    return EnvironmentByteMessenger(env, slzr)
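A brief usage sketch for task_messenger() above, assuming the micro task module from Example #8 is importable in the same codebase; nothing beyond constructing the messenger is shown, since the messenger API itself is not visible on this page.

# Hypothetical usage: run a single micro task behind a byte-level messenger.
task = micro.Micro1Task()
messenger = task_messenger(task)
# messenger is an EnvironmentByteMessenger wrapping an Environment driven by a
# SingleTaskScheduler that schedules only this task.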
Example #3
    def init_env(self, task, success_threshold=2):
        slzr = serializer.StandardSerializer()
        scheduler = ConsecutiveTaskScheduler([task], success_threshold)
        env = environment.Environment(slzr,
                                      scheduler,
                                      max_reward_per_task=float("inf"),
                                      byte_mode=True)
        messenger = EnvironmentByteMessenger(env, slzr)
        return (scheduler, messenger)
Example #4
def act(i: int, free_queue: mp.SimpleQueue, full_queue: mp.SimpleQueue,
        model: torch.nn.Module, buffers: Buffers, flags):
    try:
        logging.info('Actor %i started.', i)
        timings = prof.Timings()  # Keep track of how fast things are.

        gym_env = Net.create_env(flags)
        seed = i ^ int.from_bytes(os.urandom(4), byteorder='little')
        gym_env.seed(seed)
        env = environment.Environment(gym_env)
        env_output = env.initial()
        agent_output = model(env_output)
        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]

            # Do new rollout
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output = model(env_output)

                timings.time('model')

                env_output = env.step(agent_output['action'])

                timings.time('step')

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                timings.time('write')
            full_queue.put(index)

        if i == 0:
            logging.info('Actor %i: %s', i, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error('Exception in worker process %i', i)
        traceback.print_exc()
        print()
        raise e
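A hedged sketch of how an actor such as the one above is usually launched; it mirrors the fork-context Process and SimpleQueue setup from Example #14, adapted to this act() signature, with model and buffer construction elided.

import torch.multiprocessing as mp

ctx = mp.get_context("fork")
free_queue = ctx.SimpleQueue()
full_queue = ctx.SimpleQueue()

actor_processes = []
for i in range(flags.num_actors):
    # model and buffers are assumed to already live in shared memory
    # (model.share_memory() and shared tensors, as in Example #14).
    actor = ctx.Process(
        target=act,
        args=(i, free_queue, full_queue, model, buffers, flags),
    )
    actor.start()
    actor_processes.append(actor)

# The learner hands buffer indices to the actors via free_queue and receives
# filled rollouts back through full_queue.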
Example #5
    def testDynRegistering(self):
        class TestTask(task.Task):
            def __init__(self, *args, **kwargs):
                super(TestTask, self).__init__(*args, **kwargs)
                self.start_handled = False
                self.end_handled = False

            @task.on_start()
            def start_handler(self, event):
                try:
                    self.add_handler(task.on_ended()(self.end_handler.im_func))
                except AttributeError:  # Python 3
                    self.add_handler(task.on_ended()(
                        self.end_handler.__func__))
                self.start_handled = True

            def end_handler(self, event):
                self.end_handled = True

        tt = TestTask(max_time=10)
        env = environment.Environment(SerializerMock(),
                                      SingleTaskScheduler(tt))
        tt.start(env)
        env._register_task_triggers(tt)
        # End cannot be handled
        self.assertFalse(env.raise_event(task.Ended()))
        self.assertFalse(tt.end_handled)
        # Start should be handled
        self.assertTrue(env.raise_event(task.Start()))
        # The start handler should have been executed
        self.assertTrue(tt.start_handled)
        # Now the End should be handled
        self.assertTrue(env.raise_event(task.Ended()))
        # The end handler should have been executed
        self.assertTrue(tt.end_handled)
        env._deregister_task_triggers(tt)
        # Start should not be handled anymore
        self.assertFalse(env.raise_event(task.Start()))
        tt.end_handled = False
        # End should not be handled anymore
        self.assertFalse(env.raise_event(task.Ended()))
        self.assertFalse(tt.end_handled)
        # Register them again! mwahaha (evil laugh) -- lol
        env._register_task_triggers(tt)
        # End is still not handled: its handler is only added when the start handler runs
        self.assertFalse(env.raise_event(task.Ended()))
        self.assertFalse(tt.end_handled)
        # Deregister them again! mwahahahaha (more evil laugh)
        env._deregister_task_triggers(tt)
        self.assertFalse(env.raise_event(task.Ended()))
        self.assertFalse(tt.end_handled)
Example #6
    def testAllInputs(self):
        env = environment.Environment(serializer.StandardSerializer(),
                                      SingleTaskScheduler(NullTask()),
                                      byte_mode=True)
        learner = TryAllInputsLearner()
        s = session.Session(env, learner)

        def on_time_updated(t):
            if t >= 600:
                s.stop()

        s.total_time_updated.register(on_time_updated)

        s.run()
Example #7
    def testLimitReward(self):
        env = environment.Environment(serializer.StandardSerializer(),
                                      SingleTaskScheduler(NullTask()))
        learner = LearnerMock()
        s = session.Session(env, learner)

        def on_time_updated(t):
            if t >= 20:
                s.stop()

        s.total_time_updated.register(on_time_updated)

        s.run()
        self.assertLessEqual(s._total_reward, 10)
Example #8
    def perform_setup(self, success_threshold=2):
        slzr = serializer.StandardSerializer()
        self.tasks = [
            micro.Micro1Task(),
            micro.Micro2Task(),
            micro.Micro3Task(),
            micro.Micro4Task(),
            micro.Micro5Sub1Task()
        ]
        self.scheduler = ConsecutiveTaskScheduler(self.tasks,
                                                  success_threshold)
        self.env = environment.Environment(slzr,
                                           self.scheduler,
                                           max_reward_per_task=float("inf"),
                                           byte_mode=True)
        self.messenger = EnvironmentByteMessenger(self.env, slzr)
Example #9
def task_messenger(task_funct, world_funct=None):
    '''
    Yields an EnvironmentMessenger for interacting with the created task.
    Args:
        task_funct (callable): takes an optional world and returns a task
            object.
        world_funct (callable): takes no arguments and returns a world
            object.
    '''
    slzr = serializer.StandardSerializer()
    if world_funct:
        world = world_funct()
        task = task_funct(world)
    else:
        task = task_funct()
    scheduler = SingleTaskScheduler(task)
    env = environment.Environment(slzr, scheduler)
    m = EnvironmentMessenger(env, slzr)
    yield m
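Since task_messenger() above yields instead of returning, it is presumably wrapped as a context manager or pytest-style fixture elsewhere in the codebase; the sketch below assumes a contextlib.contextmanager wrapper and uses a hypothetical task factory.

# Assumes task_messenger is decorated with @contextlib.contextmanager.
with task_messenger(lambda: micro.Micro1Task()) as m:
    pass  # interact with the task through the EnvironmentMessenger m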
Example #10
    def __init__(self, flags, actor_index):
        self.flags = flags
        self.actor_index = actor_index
        self.logger = logging.getLogger(__name__)
        self.logger.error("Actor %i started.", self.actor_index)

        self.timings = prof.Timings()  # Keep track of how fast things are.

        self.gym_env = create_env(flags)
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        self.gym_env.seed(seed)
        self.env = environment.Environment(self.gym_env)
        self.env_output = self.env.initial()

        self.net = AtariNet(self.gym_env.observation_space.shape[0],
                            self.gym_env.action_space.n, self.flags.use_lstm,
                            self.flags.use_resnet)
        self.agent_state = self.net.initial_state(batch_size=1)
        self.agent_output, _ = self.net(self.env_output, self.agent_state)
        self.params_idx = -1
Example #11
    def testRegistering(self):
        class TestTask(task.Task):
            def __init__(self, *args, **kwargs):
                super(TestTask, self).__init__(*args, **kwargs)
                self.handled = False

            @task.on_start()
            def start_handler(self, event):
                self.handled = True

        tt = TestTask(max_time=10)
        env = environment.Environment(SerializerMock(),
                                      SingleTaskScheduler(tt))
        tt.start(env)
        env._register_task_triggers(tt)
        # Start should be handled
        self.assertTrue(env.raise_event(task.Start()))
        # The start handler should have been executed
        self.assertTrue(tt.handled)
        env._deregister_task_triggers(tt)
        # Start should not be handled anymore
        self.assertFalse(env.raise_event(task.Start()))
Example #12
def act(
    flags,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers: Buffers,
    initial_agent_state_buffers,
):
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        gym_env = create_env(flags)
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)
        env = environment.Environment(gym_env)
        env_output = env.initial()

        agent_state = model.initialize(env_output, batch_size=1)
        agent_output, unused_state = model.act(env_output, agent_state)
        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            for i, tensor in enumerate(agent_state["core_state"]):
                initial_agent_state_buffers[index][i][...] = tensor

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state = model.act(
                        env_output, agent_state)

                timings.time("model")

                env_output = env.step(agent_output["action"])

                timings.time("step")

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                timings.time("write")
            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e
Example #13
def test(flags, num_eps: int = 1000):
    from rtfm import featurizer as X
    gym_env = Net.create_env(flags)
    if flags.mode == 'test_render':
        gym_env.featurizer = X.Concat([gym_env.featurizer, X.Terminal()])
    env = environment.Environment(gym_env)

    if not flags.random_agent:
        model = Net.make(flags, gym_env)
        model.eval()
        if flags.xpid is None:
            checkpointpath = './results_latest/model.tar'
        else:
            checkpointpath = os.path.expandvars(
                os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid,
                                                 'model.tar')))
        checkpoint = torch.load(checkpointpath, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

    observation = env.initial()
    returns = []
    won = []
    entropy = []
    ep_len = []
    while len(won) < num_eps:
        done = False
        steps = 0
        while not done:
            if flags.random_agent:
                action = torch.zeros(1, 1, dtype=torch.int32)
                action[0][0] = random.randint(0, gym_env.action_space.n - 1)
                observation = env.step(action)
            else:
                agent_outputs = model(observation)
                observation = env.step(agent_outputs['action'])
                policy = F.softmax(agent_outputs['policy_logits'], dim=-1)
                log_policy = F.log_softmax(agent_outputs['policy_logits'], dim=-1)
                e = -torch.sum(policy * log_policy, dim=-1)
                entropy.append(e.mean(0).item())

            steps += 1
            done = observation['done'].item()
            if observation['done'].item():
                returns.append(observation['episode_return'].item())
                won.append(observation['reward'][0][0].item() > 0.5)
                ep_len.append(steps)
                # logging.info('Episode ended after %d steps. Return: %.1f',
                #              observation['episode_step'].item(),
                #              observation['episode_return'].item())
            if flags.mode == 'test_render':
                sleep_seconds = os.environ.get('DELAY', '0.3')
                time.sleep(float(sleep_seconds))

                if observation['done'].item():
                    print('Done: {}'.format('You won!!' if won[-1] else 'You lost!!'))
                    print('Episode steps: {}'.format(observation['episode_step']))
                    print('Episode return: {}'.format(observation['episode_return']))
                    done_seconds = os.environ.get('DONE', None)
                    if done_seconds is None:
                        print('Press Enter to continue')
                        input()
                    else:
                        time.sleep(float(done_seconds))

    env.close()
    logging.info(
        'Average returns over %i episodes: %.2f. Win rate: %.2f. '
        'Entropy: %.2f. Len: %.2f',
        num_eps,
        sum(returns) / len(returns),
        sum(won) / len(returns),
        sum(entropy) / max(1, len(entropy)),
        sum(ep_len) / len(ep_len))
Example #14
def train(flags):  # pylint: disable=too-many-branches, too-many-statements
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(xpid=flags.xpid,
                                     xp_args=flags.__dict__,
                                     rootdir=flags.savedir)
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" %
                           (flags.savedir, flags.xpid, "model.tar")))

    if flags.num_buffers is None:  # Set sensible default for num_buffers.
        flags.num_buffers = max(2 * flags.num_actors, flags.batch_size)
    if flags.num_actors >= flags.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")
    if flags.num_buffers < flags.batch_size:
        raise ValueError("num_buffers should be at least batch_size")

    T = flags.unroll_length
    B = flags.batch_size

    flags.device = None
    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.device = torch.device("cpu")

    gym_env = environment.create_gym_env(flags, seed=flags.seed)

    env = environment.Environment(flags, gym_env)
    env.initial()

    kg_model = load_kg_model_for_env(flags, gym_env)
    model = model_for_env(flags, gym_env, kg_model)

    buffers = create_buffers(flags, gym_env.observation_space,
                             model.num_actions)
    cumulative_steps = torch.zeros(1, dtype=int).share_memory_()

    model.share_memory()

    ctx = mp.get_context("fork")

    tester_processes = []
    if flags.test_interval > 0:
        splits = ['test', 'train']
        for split in splits:
            tester = ctx.Process(
                target=test,
                args=(flags, env, cumulative_steps, model, split, plogger),
            )
            tester_processes.append(tester)
            tester.start()

    # Add initial RNN state.
    initial_agent_state_buffers = []
    for _ in range(flags.num_buffers):
        state = model.initial_state(batch_size=1)
        for t in state:
            t.share_memory_()
        initial_agent_state_buffers.append(state)

    actor_processes = []
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    for i in range(flags.num_actors):
        actor = ctx.Process(
            target=act,
            args=(
                flags,
                env,
                i,
                free_queue,
                full_queue,
                model,
                buffers,
                initial_agent_state_buffers,
            ),
        )
        actor.start()
        actor_processes.append(actor)

    learner_model = model_for_env(flags, gym_env,
                                  kg_model).to(device=flags.device)

    optimizer = torch.optim.RMSprop(
        learner_model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    logger = logging.getLogger("logfile")
    stat_keys = [
        "total_loss",
        "mean_episode_return",
        "pg_loss",
        "baseline_loss",
        "entropy_loss",
    ]
    logger.info("# Step\t%s", "\t".join(stat_keys))

    step, stats = 0, {}

    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal step, stats, cumulative_steps
        timings = prof.Timings()
        while step < flags.total_steps:
            timings.reset()
            batch, agent_state = get_batch(
                flags,
                free_queue,
                full_queue,
                buffers,
                initial_agent_state_buffers,
                timings,
            )
            stats = learn(flags, model, learner_model, batch, agent_state,
                          optimizer, scheduler)
            timings.time("learn")
            with lock:
                to_log = dict(step=step)
                to_log.update({k: stats[k] for k in stat_keys})
                plogger.log(to_log)
                step += T * B
                cumulative_steps[0] = step

        if i == 0:
            logging.info("Batch and learn: %s", timings.summary())

    for m in range(flags.num_buffers):
        free_queue.put(m)

    threads = []
    for i in range(flags.num_learner_threads):
        thread = threading.Thread(target=batch_and_learn,
                                  name="batch-and-learn-%d" % i,
                                  args=(i, ))
        thread.start()
        threads.append(thread)

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "flags": vars(flags),
            },
            checkpointpath,
        )

    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer()
        while step < flags.total_steps:
            start_step = step
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > 10 * 60:  # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timer()

            sps = (step - start_step) / (timer() - start_time)
            if stats.get("episode_returns", None):
                mean_return = ("Return per episode: %.1f. " %
                               stats["mean_episode_return"])
            else:
                mean_return = ""
            total_loss = stats.get("total_loss", float("inf"))
            logging.info(
                "Steps %i @ %.1f SPS. Loss %f. %sStats:\n%s",
                step,
                sps,
                total_loss,
                mean_return,
                pprint.pformat(stats),
            )
    except KeyboardInterrupt:
        return  # Try joining actors then quit.
    else:
        for thread in threads:
            thread.join()
        logging.info("Learning finished after %d steps.", step)
    finally:
        for _ in range(flags.num_actors):
            free_queue.put(None)
        for actor in actor_processes:
            actor.join(timeout=1)
        for tester in tester_processes:
            tester.join(timeout=1)

    checkpoint()
    plogger.close()
Example #15
def act(flags, actor_index: int, channel):
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        gym_env = create_env(flags)
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)

        flags.num_actions = gym_env.action_space.n

        flags.device = torch.device("cpu")

        model = Net(gym_env.observation_space.shape,
                    num_actions=flags.num_actions).to(device=flags.device)
        #model.eval()

        buffers = create_buffers(flags, gym_env.observation_space.shape,
                                 model.num_actions)

        env = environment.Environment(gym_env)
        env_output = env.initial()
        agent_state = model.initial_state(batch_size=1)
        agent_output, unused_state = model(env_output, agent_state)

        # pull index for model update

        pull_index = 0

        while True:
            # Write old rollout end.
            for key in env_output:
                buffers[key][0][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][0][0, ...] = agent_output[key]

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    if (flags.cut_layer < model.total_cut_layers - 1):
                        inter_tensors, inter_T, inter_B = model(
                            env_output, agent_state, cut_layer=flags.cut_layer)
                        agent_output, agent_state = rpcenv.inference_send(
                            inter_tensors, agent_state, flags.cut_layer,
                            inter_T, inter_B, env_output["reward"], channel)
                    else:
                        agent_output, agent_state = model(
                            env_output, agent_state)

                timings.time("model")

                env_output = env.step(agent_output["action"])

                timings.time("step")

                for key in env_output:
                    buffers[key][0][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][0][t + 1, ...] = agent_output[key]

                timings.time("write")

            # rpc send buffers to learner
            rpcenv.upload_trajectory(actor_index, buffers, channel)

            pull_index += 1

            if pull_index == flags.batch_size:
                parameters = rpcenv.pull_model(actor_index, channel)
                logging.info("Updating model from learner.")
                model.load_state_dict(parameters)
                logging.info("Updated model at episode step %i",
                             env_output["episode_step"])
                logging.info("Episode return: %f",
                             env_output["episode_return"])

                # pull index for model update

                pull_index = 0

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e