Example #1
def get_batch(
    flags,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    buffers: Buffers,
    initial_agent_state_buffers,
    timings,
    lock=threading.Lock(),
):
    with lock:
        timings.time("lock")
        indices = [full_queue.get() for _ in range(flags.batch_size)]
        timings.time("dequeue")
    batch = {
        key: torch.stack([buffers[key][m] for m in indices], dim=1) for key in buffers
    }
    initial_agent_state = [
        torch.stack([initial_agent_state_buffers[m][i][0] for m in indices], dim=0)
        for i in range(2)
    ]
    #print("initial_agent_state[0].shape: ", initial_agent_state[0].shape)
    timings.time("batch")
    for m in indices:
        free_queue.put(m)
    timings.time("enqueue")
    batch = {k: t.to(device=flags.device, non_blocking=True) for k, t in batch.items()}
    initial_agent_state = [t.to(device=flags.device, non_blocking=True) for t in initial_agent_state]
    timings.time("device")
    return batch, initial_agent_state
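For context: get_batch() is the learner-side consumer in an IMPALA/torchbeast-style actor-learner setup. Actors fill shared-memory rollout slots and push the slot indices onto full_queue; the learner stacks those slots into a batch (time stays dim 0, batch becomes dim 1) and hands the indices back via free_queue. Below is a minimal sketch, not taken from any of the examples, of how the queues and buffers might be created before get_batch() is called; the specs argument and the flags fields are assumptions for illustration.

import torch
import torch.multiprocessing as mp

def make_queues_and_buffers(flags, specs):
    # One shared-memory tensor per rollout slot and per key, e.g.
    # specs = {"frame": ((T + 1, 4, 84, 84), torch.uint8), "reward": ((T + 1,), torch.float32)}.
    buffers = {
        key: [torch.empty(size=shape, dtype=dtype).share_memory_()
              for _ in range(flags.num_buffers)]
        for key, (shape, dtype) in specs.items()
    }
    free_queue = mp.SimpleQueue()
    full_queue = mp.SimpleQueue()
    # Every slot starts out free; actors move indices to full_queue as they fill slots.
    for m in range(flags.num_buffers):
        free_queue.put(m)
    return free_queue, full_queue, buffers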
Example #2
File: agent.py Project: yilu1021/nle
def get_batch(
        flags,
        free_queue: mp.SimpleQueue,
        full_queue: mp.SimpleQueue,
        buffers,
        initial_agent_state_buffers,
        lock=threading.Lock(),
):
    with lock:
        indices = [full_queue.get() for _ in range(flags.batch_size)]
    batch = {
        key: torch.stack([buffers[key][m] for m in indices], dim=1)
        for key in buffers
    }
    initial_agent_state = (torch.cat(ts, dim=1) for ts in zip(
        *[initial_agent_state_buffers[m] for m in indices]))
    for m in indices:
        free_queue.put(m)
    batch = {
        k: t.to(device=flags.device, non_blocking=True)
        for k, t in batch.items()
    }
    initial_agent_state = tuple(
        t.to(device=flags.device, non_blocking=True)
        for t in initial_agent_state)
    return batch, initial_agent_state
Example #3
def get_batch(
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    buffers: Buffers,
    flags,
    timings,
    lock=threading.Lock()) -> typing.Dict[str, torch.Tensor]:
    with lock:
        timings.time('lock')
        indices = [full_queue.get() for _ in range(flags.batch_size)]
        timings.time('dequeue')
    batch = {
        key: torch.stack([buffers[key][m] for m in indices], dim=1)
        for key in buffers
    }
    timings.time('batch')
    for m in indices:
        free_queue.put(m)
    timings.time('enqueue')
    batch = {
        k: t.to(device=flags.device, non_blocking=True)
        for k, t in batch.items()
    }
    timings.time('device')
    return batch
def get_batch(
    flags,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    buffers: Buffers,
    initial_agent_state_buffers,
    timings,
    lock=threading.Lock(),
):
    with lock:
        timings.time("lock")
        indices = [full_queue.get() for _ in range(flags.batch_size)]
        timings.time("dequeue")
    batch = {
        key: torch.stack([buffers[key][m] for m in indices], dim=1) for key in buffers
    }
    # NOTE: AttentionNet is batch first.
    initial_agent_state = tuple(
        torch.cat(ts, dim=0)
        for ts in zip(*[initial_agent_state_buffers[m] for m in indices])
    )
    timings.time("batch")
    for m in indices:
        free_queue.put(m)
    timings.time("enqueue")
    batch = {k: t.to(device=flags.device, non_blocking=True) for k, t in batch.items()}
    initial_agent_state = tuple(
        t.to(device=flags.device, non_blocking=True) for t in initial_agent_state
    )
    timings.time("device")
    return batch, initial_agent_state
Example #5
def get_batch(
        flags,
        free_queue: mp.SimpleQueue,
        full_queue: mp.SimpleQueue,
        buffers: Buffers,
        timings,
        lock=threading.Lock(),
):
    with lock:
        timings.time("lock")
        indices = [full_queue.get() for _ in range(flags.batch_size)]
        timings.time("dequeue")
    batch = {
        key: torch.stack([buffers[key][m] for m in indices], dim=1)
        for key in buffers
    }
    timings.time("batch")
    for m in indices:
        free_queue.put(m)
    timings.time("enqueue")
    batch = {
        k: t.to(device=flags.device, non_blocking=True)
        for k, t in batch.items()
    }
    timings.time("device")
    return batch
def read_img(path_queue: multiprocessing.JoinableQueue,
             data_queue: multiprocessing.SimpleQueue):
    torch.set_num_threads(1)
    while True:
        img_path = path_queue.get()
        img = Image.open(img_path)
        data_queue.put(T(img))
        path_queue.task_done()
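read_img() above is a plain worker loop: it blocks on path_queue, decodes one image, applies the transform T (defined elsewhere in that project), and pushes the resulting tensor onto data_queue. A hedged usage sketch, with the helper name and the queue wiring assumed rather than taken from the source, could look like this:

import torch.multiprocessing as multiprocessing

def load_images(paths, num_workers=4):
    path_queue = multiprocessing.JoinableQueue()
    data_queue = multiprocessing.SimpleQueue()
    workers = [
        multiprocessing.Process(target=read_img, args=(path_queue, data_queue), daemon=True)
        for _ in range(num_workers)
    ]
    for w in workers:
        w.start()
    for p in paths:
        path_queue.put(p)
    # Drain results while the workers run, then wait for every path to be task_done()'d.
    results = [data_queue.get() for _ in paths]
    path_queue.join()
    return results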
def act(flags, actor_index: int, free_queue: mp.SimpleQueue,
        full_queue: mp.SimpleQueue, model: torch.nn.Module, buffers: Buffers,
        initial_agent_state_buffers, level_name):
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        ######changed next line
        gym_env = create_env(flags, level_name, seed)
        env = environment.Environment(gym_env)
        env_output = env.initial()
        agent_state = model.initial_state(batch_size=1)
        agent_output, unused_state = model(env_output, agent_state)
        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            for i, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][i][...] = tensor

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state = model(env_output, agent_state)

                timings.time("model")

                env_output = env.step(agent_output["action"])

                timings.time("step")

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                timings.time("write")
            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e
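The act() loop above runs in its own process: it repeatedly takes a free slot index, writes a rollout of flags.unroll_length steps into the shared buffers, and publishes the index on full_queue; a None pulled from free_queue is the shutdown sentinel. A minimal launch/shutdown sketch, assuming torchbeast-style flags and the act() signature shown above, might be:

import torch.multiprocessing as mp

def start_actors(flags, model, buffers, initial_agent_state_buffers, level_name):
    ctx = mp.get_context("fork")
    free_queue, full_queue = ctx.SimpleQueue(), ctx.SimpleQueue()
    for m in range(flags.num_buffers):
        free_queue.put(m)          # every rollout slot starts out free
    actors = []
    for actor_index in range(flags.num_actors):
        actor = ctx.Process(
            target=act,
            args=(flags, actor_index, free_queue, full_queue, model, buffers,
                  initial_agent_state_buffers, level_name),
        )
        actor.start()
        actors.append(actor)
    return actors, free_queue, full_queue

def stop_actors(actors, free_queue):
    for _ in actors:
        free_queue.put(None)       # one sentinel per actor ends its while-loop
    for actor in actors:
        actor.join(timeout=1)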
Example #8
File: agent.py Project: zeta1999/nle
def act(
    flags,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers,
    initial_agent_state_buffers,
):
    try:
        logging.info("Actor %i started.", actor_index)

        gym_env = create_env(
            flags.env,
            savedir=flags.rundir,
            archivefile="nethack.%i.%%(pid)i.%%(time)s.zip" % actor_index,
        )
        env = ResettingEnvironment(gym_env)
        env_output = env.initial()
        agent_state = model.initial_state(batch_size=1)
        agent_output, unused_state = model(env_output, agent_state)
        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            for i, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][i][...] = tensor

            # Do new rollout.
            for t in range(flags.unroll_length):
                with torch.no_grad():
                    agent_output, agent_state = model(env_output, agent_state)

                env_output = env.step(agent_output["action"])

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

            full_queue.put(index)

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise
Example #9
File: run_exp.py Project: vzhong/RTFM-1
def act(i: int, free_queue: mp.SimpleQueue, full_queue: mp.SimpleQueue,
        model: torch.nn.Module, buffers: Buffers, flags):
    try:
        logging.info('Actor %i started.', i)
        timings = prof.Timings()  # Keep track of how fast things are.

        gym_env = Net.create_env(flags)
        seed = i ^ int.from_bytes(os.urandom(4), byteorder='little')
        gym_env.seed(seed)
        env = environment.Environment(gym_env)
        env_output = env.initial()
        agent_output = model(env_output)
        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]

            # Do new rollout
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output = model(env_output)

                timings.time('model')

                env_output = env.step(agent_output['action'])

                timings.time('step')

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                timings.time('write')
            full_queue.put(index)

        if i == 0:
            logging.info('Actor %i: %s', i, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error('Exception in worker process %i', i)
        traceback.print_exc()
        print()
        raise e
Example #10
def get_batch(
        flags,
        free_queue: mp.SimpleQueue,
        full_queue: mp.SimpleQueue,
        buffers: Buffers,
        initial_agent_state_buffers,
        timings,
        lock=threading.Lock(),
):
    # need to make sure that we wait until batch_size trajectories/rollouts have been put into the queue
    with lock:
        timings.time("lock")
        # get the indices of actors "offering" trajectories/rollouts to be processed by the learner
        indices = [full_queue.get() for _ in range(flags.batch_size)]
        timings.time("dequeue")

    # create the batch as a dictionary for all the data in the buffers (see act() function for list of
    # keys), where each entry is a tensor of these values stacked across actors along the first dimension,
    # which I believe should be the "batch dimension" (see _format_frame())
    batch = {
        key: torch.stack([buffers[key][m] for m in indices], dim=1)
        for key in buffers
    }

    # similar thing for the initial agent states, where I think the tuples are concatenated to become torch tensors
    initial_agent_state = (torch.cat(ts, dim=1) for ts in zip(
        *[initial_agent_state_buffers[m] for m in indices]))
    timings.time("batch")

    # once the data has been "transferred" into batch and initial_agent_state,
    # signal that the data has been processed to the actors
    for m in indices:
        free_queue.put(m)
    timings.time("enqueue")

    # move the data to the right device (e.g. GPU)
    batch = {
        k: t.to(device=flags.device, non_blocking=True)
        for k, t in batch.items()
    }
    initial_agent_state = tuple(
        t.to(device=flags.device, non_blocking=True)
        for t in initial_agent_state)
    timings.time("device")

    return batch, initial_agent_state
def get_batch(
        flags,
        free_queue: mp.SimpleQueue,
        full_queue: mp.SimpleQueue,
        buffers: Buffers,
        initial_agent_state_buffers,
        timings,
        lock=threading.Lock(),
):
    with lock:
        timings.time("lock")
        indices = [full_queue.get() for _ in range(flags.batch_size)]

        # TODO: Check if emptying full_queue and then re-adding to it takes very long;
        #       it seems like the only way to ensure a batch of similar-length elements.
        # One problem with doing this is that if we get a really short trajectory, we may
        # never end up using it. DON'T CHANGE THIS FOR NOW.

        timings.time("dequeue")
    batch = {
        key: torch.stack([buffers[key][m] for m in indices], dim=1)
        for key in buffers
    }
    initial_agent_state = (torch.cat(ts, dim=1) for ts in zip(
        *[initial_agent_state_buffers[m] for m in indices]))
    timings.time("batch")
    for m in indices:
        free_queue.put(m)
    timings.time("enqueue")
    batch = {
        k: t.to(device=flags.device, non_blocking=True)
        for k, t in batch.items()
    }
    initial_agent_state = tuple(
        t.to(device=flags.device, non_blocking=True)
        for t in initial_agent_state)
    timings.time("device")
    return batch, initial_agent_state
Example #12
def act(i: int, free_queue: mp.SimpleQueue, full_queue: mp.SimpleQueue,
        model: torch.nn.Module, buffers: Buffers, initial_agent_state_buffers, flags):
    try:
        log.info('Actor %i started.', i)
        timings = prof.Timings() 
        
        gym_env = create_env(flags)
        seed = i ^ int.from_bytes(os.urandom(4), byteorder='little')
        gym_env.seed(seed)
        
        if flags.num_input_frames > 1:
            gym_env = FrameStack(gym_env, flags.num_input_frames)  

        env = Environment(gym_env, fix_seed=flags.fix_seed, env_seed=flags.env_seed)

        env_output = env.initial()
        agent_state = model.initial_state(batch_size=1)
        agent_output, unused_state = model(env_output, agent_state)
        

        while True:
            index = free_queue.get()
            if index is None:
                break
            
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            # Use a separate loop variable so the actor index `i` is not shadowed.
            for state_idx, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][state_idx][...] = tensor

            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state = model(env_output, agent_state)

                timings.time('model')

                env_output = env.step(agent_output['action'])

                timings.time('step')

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]
                
                timings.time('write')
            full_queue.put(index)

        if i == 0:
            log.info('Actor %i: %s', i, timings.summary())

    except KeyboardInterrupt:
        pass  
    except Exception as e:
        logging.error('Exception in worker process %i', i)
        traceback.print_exc()
        print()
        raise e
Example #13
def train(config):
    task_queue = SimpleQueue()
    result_queue = SimpleQueue()
    stop = mp.Value('i', False)
    stats = SharedStats(config.state_dim)
    normalizers = [StaticNormalizer(config.state_dim) for _ in range(config.num_workers)]
    for normalizer in normalizers:
        normalizer.offline_stats.load(stats)

    workers = [Worker(id, normalizers[id], task_queue, result_queue, stop, config) for id in range(config.num_workers)]
    for w in workers: w.start()

    opt = cma.CMAOptions()
    opt['tolfun'] = -config.target
    opt['popsize'] = config.pop_size
    opt['verb_disp'] = 0
    opt['verb_log'] = 0
    opt['maxiter'] = sys.maxsize
    es = cma.CMAEvolutionStrategy(config.initial_weight, config.sigma, opt)

    total_steps = 0
    initial_time = time.time()
    training_rewards = []
    training_steps = []
    training_timestamps = []
    test_mean, test_ste = test(config, config.initial_weight, stats)
    logger.info('total steps %d, %f(%f)' % (total_steps, test_mean, test_ste))
    training_rewards.append(test_mean)
    training_steps.append(0)
    training_timestamps.append(0)
    while True:
        solutions = es.ask()
        for id, solution in enumerate(solutions):
            task_queue.put((id, solution))
        while not task_queue.empty():
            continue
        result = []
        while len(result) < len(solutions):
            if result_queue.empty():
                continue
            result.append(result_queue.get())
        result = sorted(result, key=lambda x: x[0])
        total_steps += np.sum([r[2] for r in result])
        cost = [r[1] for r in result]
        best_solution = solutions[np.argmin(cost)]
        elapsed_time = time.time() - initial_time
        test_mean, test_ste = test(config, best_solution, stats)
        logger.info('total steps %d, test %f(%f), best %f, elapsed time %f' %
            (total_steps, test_mean, test_ste, -np.min(cost), elapsed_time))
        training_rewards.append(test_mean)
        training_steps.append(total_steps)
        training_timestamps.append(elapsed_time)
        # with open('data/%s-best_solution_%s.bin' % (TAG, config.task), 'wb') as f:
        #     pickle.dump(solutions[np.argmin(result)], f)
        if config.max_steps and total_steps > config.max_steps:
            stop.value = True
            break

        cost = fitness_shift(cost)
        es.tell(solutions, cost)
        # es.disp()
        for normalizer in normalizers:
            stats.merge(normalizer.online_stats)
            normalizer.online_stats.zero()
        for normalizer in normalizers:
            normalizer.offline_stats.load(stats)

    stop.value = True
    for w in workers: w.join()
    return [training_rewards, training_steps, training_timestamps]
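train() above only shows the master side of the CMA-ES loop; the Worker class it starts is not included. A sketch of a compatible worker is given below; the evaluate() helper, which would roll out the policy defined by a solution vector and return its episode return and step count, is an assumption, not the project's actual code.

import torch.multiprocessing as mp

class Worker(mp.Process):
    def __init__(self, id, normalizer, task_queue, result_queue, stop, config):
        super().__init__()
        self.id = id
        self.normalizer = normalizer
        self.task_queue = task_queue
        self.result_queue = result_queue
        self.stop = stop
        self.config = config

    def run(self):
        while not self.stop.value:
            if self.task_queue.empty():
                continue
            solution_id, solution = self.task_queue.get()
            # evaluate() is assumed: run an episode with these weights and
            # return (episode_return, environment_steps).
            reward, steps = evaluate(solution, self.normalizer, self.config)
            # The master minimizes cost, so report the negated reward.
            self.result_queue.put((solution_id, -reward, steps))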
Example #14
def act(flags, gym_env, actor_index: int, free_queue: mp.SimpleQueue,
        full_queue: mp.SimpleQueue, buffers: Buffers, actor_buffers: Buffers,
        actor_model_queues: List[mp.SimpleQueue],
        actor_env_queues: List[mp.SimpleQueue]):
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        gym_env = gym_env
        #seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        #gym_env.seed(seed)
        if flags.agent in ["CNN"]:
            env = environment.Environment(gym_env, "image")
        elif flags.agent in ["NLM", "KBMLP", "GCN"]:
            if flags.state in ["relative", "integer", "block"]:
                env = environment.Environment(gym_env, "VKB")
            elif flags.state == "absolute":
                env = environment.Environment(gym_env, "absVKB")
        env_output = env.initial()
        for key in env_output:
            actor_buffers[key][actor_index][0] = env_output[key]
        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in actor_buffers:
                buffers[key][index][0] = actor_buffers[key][actor_index][0]

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                actor_model_queues[actor_index].put(actor_index)
                env_info = actor_env_queues[actor_index].get()
                if env_info == "exit":
                    return

                timings.time("model")

                env_output = env.step(actor_buffers["action"][actor_index][0])

                timings.time("step")

                for key in actor_buffers:
                    buffers[key][index][t + 1] = actor_buffers[key][actor_index][0]
                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in env_output:
                    actor_buffers[key][actor_index][0] = env_output[key]

                timings.time("write")

            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e
Example #15
def act(
    flags,
    game_params,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers: Buffers,
):
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        sc_env = init_game(game_params['env'],
                           flags.map_name,
                           random_seed=seed)
        obs_processer = IMPALA_ObsProcesser(action_table=model.action_table,
                                            **game_params['obs_processer'])
        env = environment.Environment(sc_env, obs_processer, seed)
        # initial rollout starts here
        env_output = env.initial()
        with torch.no_grad():
            agent_output = model.actor_step(env_output)

        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                if key not in ['sc_env_action']:  # no need to save this key on buffers
                    buffers[key][index][0, ...] = agent_output[key]

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                env_output = env.step(agent_output["sc_env_action"])

                timings.time("step")

                with torch.no_grad():
                    agent_output = model.actor_step(env_output)

                timings.time("model")

                #env_output = env.step(agent_output["sc_env_action"])

                #timings.time("step")

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    if key not in ['sc_env_action']:  # no need to save this key on buffers
                        buffers[key][index][t + 1, ...] = agent_output[key]
                # env_output will be like
                # s_{0}, ..., s_{T}
                # act_mask_{0}, ..., act_mask_{T}
                # discount_{0}, ..., discount_{T}
                # r_{-1}, ..., r_{T-1}
                # agent_output will be like
                # a_0, ..., a_T with a_t ~ pi(.|s_t)
                # log_pi(a_0|s_0), ..., log_pi(a_T|s_T)
                # so the learner can use (s_i, act_mask_i) to predict log_pi_i
                timings.time("write")
            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e
Example #16
def train(config):
    task_queue = SimpleQueue()
    result_queue = SimpleQueue()
    stop = mp.Value('i', False)
    stats = SharedStats(config.state_dim)
    param = torch.FloatTensor(torch.from_numpy(config.initial_weight))
    param.share_memory_()
    n_params = len(param.numpy().flatten())
    if config.args.noise_type == 'lss':
        noise_sizes = [
            config.state_dim * config.hidden_size,
            config.hidden_size * config.hidden_size,
            config.hidden_size * config.action_dim
        ]
    else:
        noise_sizes = None
    noise_generator = NoiseGenerator(n_params,
                                     config.pop_size,
                                     config.args.noise,
                                     noise_sizes=noise_sizes)
    normalizers = [
        StaticNormalizer(config.state_dim) for _ in range(config.num_workers)
    ]
    for normalizer in normalizers:
        normalizer.offline_stats.load(stats)
    workers = [
        Worker(id, param, normalizers[id], task_queue, result_queue, stop,
               noise_generator, config) for id in range(config.num_workers)
    ]
    for w in workers:
        w.start()

    training_rewards = []
    training_steps = []
    training_timestamps = []
    initial_time = time.time()
    total_steps = 0
    iteration = 0
    while not stop.value:
        test_mean, test_ste = test(config, param.numpy(), stats)
        elapsed_time = time.time() - initial_time
        training_rewards.append(test_mean)
        training_steps.append(total_steps)
        training_timestamps.append(elapsed_time)
        logger.info('Test: total steps %d, %f(%f), elapsed time %d' %
                    (total_steps, test_mean, test_ste, elapsed_time))

        for i in range(config.pop_size):
            task_queue.put(i)
        rewards = []
        epsilons = []
        steps = []
        while len(rewards) < config.pop_size:
            if result_queue.empty():
                continue
            epsilon, fitness, step = result_queue.get()
            epsilons.append(epsilon)
            rewards.append(fitness)
            steps.append(step)

        total_steps += np.sum(steps)
        r_mean = np.mean(rewards)
        r_std = np.std(rewards)
        # rewards = (rewards - r_mean) / r_std
        logger.info('Train: iteration %d, %f(%f)' %
                    (iteration, r_mean, r_std / np.sqrt(config.pop_size)))
        iteration += 1
        # if r_mean > config.target:
        if config.max_steps and total_steps > config.max_steps:
            stop.value = True
            break
        for normalizer in normalizers:
            stats.merge(normalizer.online_stats)
            normalizer.online_stats.zero()
        for normalizer in normalizers:
            normalizer.offline_stats.load(stats)
        if config.args.reward_type == 'rank':
            rewards = fitness_shift(rewards)
        gradient = np.asarray(epsilons) * np.asarray(rewards).reshape((-1, 1))
        gradient = np.mean(gradient, 0) / config.sigma
        gradient -= config.weight_decay * gradient
        if config.args.opt == 'adam':
            gradient = config.opt.update(gradient)
        gradient = torch.FloatTensor(gradient)
        param.add_(config.learning_rate * gradient)

    for w in workers:
        w.join()
    return [training_rewards, training_steps, training_timestamps]
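Both ES training loops above call fitness_shift() on the per-candidate scores before computing an update, but its definition is not shown. A common choice, and only a guess at what these projects use, is the centered-rank transform from OpenAI-style evolution strategies:

import numpy as np

def fitness_shift(rewards):
    rewards = np.asarray(rewards)
    ranks = np.empty(len(rewards), dtype=float)
    ranks[rewards.argsort()] = np.arange(len(rewards))
    # Map ranks onto [-0.5, 0.5], preserving the ordering of the inputs.
    return ranks / max(len(rewards) - 1, 1) - 0.5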
Example #17
def act(
    flags,
    env: str,
    task: int,
    full_action_space: bool,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers: Buffers,
    initial_agent_state_buffers,
):
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        # create the environment from command line parameters
        # => could also create a special one which operates on a list of games (which we need)
        gym_env = create_env(
            env,
            frame_height=flags.frame_height,
            frame_width=flags.frame_width,
            gray_scale=(flags.aaa_input_format == "gray_stack"),
            full_action_space=full_action_space,
            task=task)

        # generate a seed for the environment (NO HUMAN STARTS HERE!), could just
        # use this for all games wrapped by the environment for our application
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)

        # wrap the environment, this is actually probably the point where we could
        # use multiple games, because the other environment is still one from Gym
        env = environment.Environment(gym_env)

        # get the initial frame, reward, done, return, step, last_action
        env_output = env.initial()

        # perform the first step
        agent_state = model.initial_state(batch_size=1)
        agent_output, unused_state = model(env_output, agent_state)
        while True:
            # get a buffer index from the queue for free buffers (?)
            index = free_queue.get()
            # termination signal (?) for breaking out of this loop
            if index is None:
                break

            # Write old rollout end.
            # the keys here are (frame, reward, done, episode_return, episode_step, last_action)
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            # here the keys are (policy_logits, baseline, action)
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            # I think the agent_state is just the RNN/LSTM state (which will be the "initial" state for the next step)
            # not sure why it's needed though because it really just seems to be the initial state before starting to
            # act; however, it might be randomly initialised, which is why we might want it...
            for i, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][i][...] = tensor

            # Do new rollout
            for t in range(flags.unroll_length):
                timings.reset()

                # forward pass without keeping track of gradients to get the agent action
                with torch.no_grad():
                    agent_output, agent_state = model(env_output, agent_state)

                timings.time("model")

                # agent acting in the environment
                env_output = env.step(agent_output["action"])

                timings.time("step")

                # writing the respective outputs of the current step (see above for the list of keys)
                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                timings.time("write")

            # after finishing a trajectory put the index in the "full queue",
            # presumably so that the data can be processed/sent to the learner
            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e
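The comments in the example above ask why agent_state is written into initial_agent_state_buffers at all: the learner needs the recurrent state the actor had at the start of each rollout so it can replay the unroll from the same point. A hedged sketch of how those buffers are typically allocated (one shared-memory copy of the model's initial state per rollout slot; flags.num_buffers is an assumption) looks like:

def create_initial_agent_state_buffers(flags, model):
    initial_agent_state_buffers = []
    for _ in range(flags.num_buffers):
        state = model.initial_state(batch_size=1)
        for t in state:
            t.share_memory_()  # share each state tensor with the actor processes
        initial_agent_state_buffers.append(state)
    return initial_agent_state_buffers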
Example #18
def act(
    flags,
    game_params,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers: Buffers,
    initial_agent_state_buffers,
):
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        sc_env = init_game(game_params['env'], flags.map_name, random_seed=seed)
        obs_processer = IMPALA_ObsProcesser_v2(env=sc_env, action_table=model.action_table, **game_params['obs_processer'])
        env = environment.Environment_v2(sc_env, obs_processer, seed)
        # initial rollout starts here
        env_output = env.initial() 
        new_res = model.spatial_processing_block.new_res
        agent_state = model.spatial_processing_block.conv_lstm._init_hidden(batch_size=1, 
                                                                            image_size=(new_res,new_res)
                                                                           )
        
        with torch.no_grad():
            agent_output, new_agent_state = model.actor_step(env_output, *agent_state[0]) 

        agent_state = agent_state[0] # _init_hidden yields [(h,c)], whereas actor step only (h,c)
        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end. 
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                if key not in ['sc_env_action']: # no need to save this key on buffers
                    buffers[key][index][0, ...] = agent_output[key]
            
            # lstm state in syncro with the environment / input to the agent 
            # that's why agent_state = new_agent_state gets executed afterwards
            initial_agent_state_buffers[index][0][...] = agent_state[0]
            initial_agent_state_buffers[index][1][...] = agent_state[1]
            
            
            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                env_output = env.step(agent_output["sc_env_action"])
                
                timings.time("step")
                
                # update state
                agent_state = new_agent_state 
            
                with torch.no_grad():
                    agent_output, new_agent_state = model.actor_step(env_output, *agent_state)
                
                timings.time("model")
                
                #env_output = env.step(agent_output["sc_env_action"])

                #timings.time("step")

                for key in env_output:
                    buffers[key][index][t+1, ...] = env_output[key] 
                for key in agent_output:
                    if key not in ['sc_env_action']: # no need to save this key on buffers
                        buffers[key][index][t+1, ...] = agent_output[key] 
                # env_output will be like
                # s_{0}, ..., s_{T}
                # act_mask_{0}, ..., act_mask_{T}
                # discount_{0}, ..., discount_{T}
                # r_{-1}, ..., r_{T-1}
                # agent_output will be like
                # a_0, ..., a_T with a_t ~ pi(.|s_t)
                # log_pi(a_0|s_0), ..., log_pi(a_T|s_T)
                # so the learner can use (s_i, act_mask_i) to predict log_pi_i
                timings.time("write")
            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e
Example #19
def act(
    flags,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers: Buffers,
    initial_agent_state_buffers,
):
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        gym_env = create_env(flags)
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)
        env = environment.Environment(gym_env)
        env_output = env.initial()

        agent_state = model.initial_state(batch_size=1)
        mems, mem_padding = None, None
        agent_output, unused_state, mems, mem_padding, _ = model(
            env_output, agent_state, mems, mem_padding)
        while True:
            index = free_queue.get()
            if index is None:
                break

            # explicitly make done False to allow the loop to run
            # No need to set 'done' to True, since we now step out of the done state
            # once we actually arrive at 'done'.
            # env_output['done'] = torch.tensor([0], dtype=torch.uint8)

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            for i, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][i][...] = tensor

            # Do one new rollout, until flags.unroll_length
            t = 0
            while t < flags.unroll_length and not env_output['done'].item():
                # for t in range(flags.unroll_length):
                timings.reset()

                # Removed, since this will never be true (moved to after the loop)
                # if env_output['done'].item():
                #    mems = None

                with torch.no_grad():
                    agent_output, agent_state, mems, mem_padding, _ = model(
                        env_output, agent_state, mems, mem_padding)

                timings.time("model")

                # TODO: Shakti add action repeat?
                env_output = env.step(agent_output["action"])

                timings.time("step")

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                timings.time("write")
                t += 1

            if env_output['done'].item():
                mems = None
                # Take arbitrary step to reset environment
                env_output = env.step(torch.tensor([2]))

            if t != flags.unroll_length:
                # TODO I checked and seems good but Shakti can you check as well?
                buffers['done'][index][t + 1:] = torch.tensor(
                    [True]).repeat(flags.unroll_length - t)

            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        # print()
        raise e
Example #20
File: impala.py Project: johnlime/cleanrl
def act(
    args,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers: Buffers,
    initial_agent_state_buffers,
):
    try:
        logging.info("Actor %i started.", actor_index)
        timings = Timings()  # Keep track of how fast things are.

        gym_env = create_env(args)
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)
        env = Environment(gym_env)
        def make_env(args):
            def thunk():
                env = create_env(args)
                return env
            return thunk
        envs = DummyVecEnv([make_env(args) for i in range(1)])
        
        env_output = env.initial()
        envs.reset()
        agent_state = model.initial_state(batch_size=1)
        agent_output, unused_state = model(env_output, agent_state)
        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            for i, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][i][...] = tensor

            # Do new rollout.
            for t in range(args.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state = model(env_output, agent_state)

                # timings.time("model")

                env_output = env.step(agent_output["action"])
                # env_output = env.step(agent_output["action"])
                # envs.step((torch.randint(0, envs.action_space.n, (envs.num_envs,))).numpy())
                assert agent_output["action"] == env_output["last_action"]
                timings.time("step")

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                timings.time("write")
            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e
class ColorizationModel:
    """Wrapper class implementing network training and prediction.

    This is a wrapper class that composes a PyTorch network with a loss
    function and an optimizer and implements training, prediction,
    checkpointing and logging functionality. Its implementation is kept
    independent of the concrete underlying network.

    """

    CHECKPOINT_PREFIX = 'checkpoint'
    CHECKPOINT_POSTFIX = 'tar'
    CHECKPOINT_ID_FMT = '{:010}'

    def __init__(self,
                 network,
                 loss=None,
                 optimizer=None,
                 lr_scheduler=None,
                 log_config=None,
                 logger=None):
        """Compose a model.

        Note:
            The model is not trainable unless both `loss` and `optimizer` are
            not `None`. A non-trainable model can still be initialized from
            pretrained weights for evaluation or prediction purposes.

            Likewise, logging will only be enabled if both `log_config` and
            `logger` are not `None`.

        Args:
            network (colorization.ColorizationNetwork):
                Network instance.
            loss (torch.nn.Module, optional):
                Training loss function, if this is set to `None`, the model is
                not trainable.
            optimizer (functools.partial, optional):
                Partially applied training optimizer, parameter argument is
                supplied by this constructor, if this is set to `None`, the
                model is not trainable.
            lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional):
                Learning rate scheduler.
            log_config (dict, optional):
                Python `logging` configuration dictionary, if this is set to
                `None`, logging will be disabled.
            logger (str, optional):
                Name of the logger to be utilized (this logger has to be
                configured by `log_config`, pre-existing loggers will not
                function correctly since logging is performed in a separate
                thread).

        """

        self.network = network

        self.loss = loss

        if optimizer is None:
            self.optimizer = None
        else:
            self.optimizer = optimizer(network.parameters())

        if lr_scheduler is None:
            self.lr_scheduler = None
        else:
            self.lr_scheduler = lr_scheduler(self.optimizer)

        self.device = network.device

        self._log_enabled = \
            _mp_spawn and log_config is not None and logger is not None

        if self._log_enabled:
            self._log_config = log_config
            self._logger = logger

    def train(self,
              dataloader,
              iterations,
              iterations_till_checkpoint=None,
              checkpoint_init=None,
              checkpoint_dir=None):
        """Train the models network.

        Args:
            dataloader (torch.utils.data.DataLoder):
                Data loader, should return Lab images of arbitrary shape and
                datatype float32. For efficient training, the data loader
                should be constructed with `pin_memory` set to `True`.
            iterations (int):
                Number of iterations (batches) to run the training for, note
                that training can be picked up at a previous checkpoint later
                on, also note that while training time is specified in
                iterations, checkpoint frequency is specified in epochs to
                avoid issues with data loader shuffling etc.
            iterations_till_checkpoint (int, optional):
                Number of iterations between checkpoints, only meaningful in
                combination with `checkpoint_dir`.
            checkpoint_init (str, optional):
                Previous checkpoint from which to pick up training.
            checkpoint_dir (str, optional):
                Directory in which to save checkpoints, must exist and be
                empty; if this is set to `None`, no checkpoints will be saved.

        """

        # restore from checkpoint
        if checkpoint_init is not None:
            self.load_checkpoint(checkpoint_init, load_optimizer=True)
            iteration_init = self._checkpoint_iteration(checkpoint_init)

            if iteration_init == 'final':
                raise ValueError(
                    "cannot continue training from final checkpoint")

        # validate checkpoint directory
        if checkpoint_dir is not None:
            self._validate_checkpoint_dir(checkpoint_dir,
                                          resuming=(checkpoint_init
                                                    is not None))

        # check whether dataloader has pin_memory set and set image size
        if not dataloader.pin_memory:
            warn("'pin_memory' not set, this will slow down training")

        # switch to training mode (essential for batch normalization)
        self.network.train()

        # create logging thread
        if self._log_enabled:
            self._log_queue = SimpleQueue()

            self._log = Process(target=_log_progress,
                                args=(self._log_config, self._logger,
                                      self._log_queue))

            self._log.start()

        # optimization loop
        if checkpoint_init is None:
            i = 1
        else:
            i = iteration_init + 1

        if self.lr_scheduler is not None:
            self.lr_scheduler.max_epochs = iterations

        done = False
        while not done:
            for img in dataloader:
                # move data to device
                img = img.to(self.device, non_blocking=dataloader.pin_memory)

                # perform parameter update
                self.optimizer.zero_grad()

                q_pred, q_actual = self.network(img)
                loss = self.loss(q_pred, q_actual)
                loss.backward()

                self.optimizer.step()

                # update learning rate
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()

                # display progress
                if self._log_enabled:
                    self._log_queue.put(
                        _LossLogData(i, iterations, loss.item()))

                # periodically save checkpoint
                if checkpoint_dir is not None:
                    if i % iterations_till_checkpoint == 0:
                        self._checkpoint_training(checkpoint_dir, i)

                # increment iteration counter
                i += 1

                if i > iterations:
                    done = True
                    break

        # save final checkpoint
        if checkpoint_dir is not None:
            self._checkpoint_training(checkpoint_dir, 'final')

        # stop logging thread
        if self._log_enabled:
            self._log_queue.put('done')
            self._log.join()

    def predict(self, img):
        """Perform single batch prediction using the current network.

        Args:
            img (torch.Tensor):
                A tensor of shape `(n, 1, h, w)` where `n` is the size of the
                batch to be predicted and `h` and `w` are image dimensions. The
                images should be Lab lightness channels.

        Returns:
            A tensor of shape `(n, 1, h, w)` containing the predicted ab
            channels.

        """

        # switch to evaluation mode
        self.network.eval()

        # move data to device
        img = img.to(self.device)

        # run prediction
        with torch.no_grad():
            img_pred = self.network(img)

        return img_pred

    def save_checkpoint(self, path, save_optimizer=False):
        """Save weights to checkpoint.

        Args:
            path (str):
                Path to the checkpoint.
            save_optimizer (bool):
                If `True`, save optimizer state as well.

        """

        state = {
            'network': self.network.base_network.state_dict(),
        }

        if save_optimizer:
            state['optimizer'] = self.optimizer.state_dict()

        torch.save(state, path)

    def load_checkpoint(self, path, load_optimizer=False):
        """Initialize weights from checkpoint.

        Args:
            path (str):
                Path to the checkpoint.
            load_optimizer (bool):
                If `True`, load optimizer state as well.

        """

        state = torch.load(path, map_location=(lambda storage, _: storage))

        self.network.base_network.load_state_dict(state['network'])

        if load_optimizer:
            self.optimizer.load_state_dict(state['optimizer'])

    @classmethod
    def find_latest_checkpoint(cls, checkpoint_dir):
        """Find the most up to date checkpoint file in a checkpoint directory.

        Args:
            checkpoint_dir (str):
                Directory in which to search for checkpoints, must exist.

        Returns:
            The file path to the latest checkpoint.

        """

        # create list of all checkpoints
        checkpoint_template = '{}_*.{}'.format(cls.CHECKPOINT_PREFIX,
                                               cls.CHECKPOINT_POSTFIX)

        checkpoint_template = os.path.join(checkpoint_dir, checkpoint_template)

        all_checkpoints = sorted(glob(checkpoint_template))

        # find latest checkpoint
        if not all_checkpoints:
            err = "failed to resume training: no previous checkpoints"
            raise ValueError(err)

        return all_checkpoints[-1]

    def _checkpoint_training(self, checkpoint_dir, checkpoint_iterations):
        path = self._checkpoint_path(checkpoint_dir, checkpoint_iterations)

        self.save_checkpoint(path, save_optimizer=True)

        if self._log_enabled:
            fmt = "saved checkpoint '{}'"
            self._log_queue.put(fmt.format(os.path.basename(path)))

    @staticmethod
    def _validate_checkpoint_dir(checkpoint_dir, resuming=False):
        # check existence
        if not os.path.isdir(checkpoint_dir):
            raise ValueError("checkpoint directory must exist")

        # refuse to overwrite checkpoints unless resuming
        if not resuming:
            checkpoint_files = os.listdir(checkpoint_dir)

            if len([f for f in checkpoint_files if not f.startswith('.')]) > 0:
                raise ValueError("checkpoint directory must be empty")

    @classmethod
    def _checkpoint_path(cls, checkpoint_dir, checkpoint_iteration):
        if checkpoint_iteration == 'final':
            checkpoint_id = 'final'
        else:
            checkpoint_id = cls.CHECKPOINT_ID_FMT.format(checkpoint_iteration)

        checkpoint = '{}_{}.{}'.format(cls.CHECKPOINT_PREFIX, checkpoint_id,
                                       cls.CHECKPOINT_POSTFIX)

        return os.path.join(checkpoint_dir, checkpoint)

    @classmethod
    def _checkpoint_iteration(cls, checkpoint_path):
        if checkpoint_path.find('final') != -1:
            return 'final'

        checkpoint_regex = '{}_(\\d+).{}'.format(cls.CHECKPOINT_PREFIX,
                                                 cls.CHECKPOINT_POSTFIX)

        m = re.search(checkpoint_regex, checkpoint_path)
        if m is None:
            err = "invalid training checkpoint naming scheme"
            raise ValueError(err)

        checkpoint_iteration = int(m.group(1))

        return checkpoint_iteration
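A hedged usage sketch for ColorizationModel follows. The concrete network class, loss module, data loader, and paths are placeholders for illustration; the real project presumably supplies its own ColorizationNetwork and colorization loss.

import functools
import torch

network = ColorizationNetwork()  # assumed network exposing .device and .base_network
model = ColorizationModel(
    network,
    loss=torch.nn.CrossEntropyLoss(),                        # stand-in loss module
    optimizer=functools.partial(torch.optim.Adam, lr=3e-4),  # partially applied optimizer
)

# dataloader is an assumed torch.utils.data.DataLoader built with pin_memory=True.
model.train(dataloader,
            iterations=10000,
            iterations_till_checkpoint=1000,
            checkpoint_dir='checkpoints/')

# Predict ab channels for a batch of Lab lightness channels of shape (n, 1, h, w).
ab_pred = model.predict(lightness_batch)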
Example #22
def act(i: int, free_queue: mp.SimpleQueue, full_queue: mp.SimpleQueue,
        model: torch.nn.Module, buffers: Buffers,
        episode_state_count_dict: dict, train_state_count_dict: dict,
        initial_agent_state_buffers, flags):
    try:
        log.info('Actor %i started.', i)
        timings = prof.Timings()

        gym_env = create_env(flags)
        seed = i ^ int.from_bytes(os.urandom(4), byteorder='little')
        gym_env.seed(seed)

        if flags.num_input_frames > 1:
            gym_env = FrameStack(gym_env, flags.num_input_frames)

        env = Environment(gym_env,
                          fix_seed=flags.fix_seed,
                          env_seed=flags.env_seed)

        env_output = env.initial()
        agent_state = model.initial_state(batch_size=1)
        agent_output, unused_state = model(env_output, agent_state)

        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            # Use a separate loop variable so the actor index `i` is not shadowed.
            for state_idx, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][state_idx][...] = tensor

            # Update the episodic state counts
            episode_state_key = tuple(env_output['frame'].view(-1).tolist())
            if episode_state_key in episode_state_count_dict:
                episode_state_count_dict[episode_state_key] += 1
            else:
                episode_state_count_dict.update({episode_state_key: 1})
            buffers['episode_state_count'][index][0, ...] = \
                torch.tensor(1 / np.sqrt(episode_state_count_dict.get(episode_state_key)))

            # Reset the episode state counts when the episode is over
            if env_output['done'][0][0]:
                episode_state_count_dict = dict()

            # Update the training state counts if you're doing count-based exploration
            if flags.model == 'count':
                train_state_key = tuple(env_output['frame'].view(-1).tolist())
                if train_state_key in train_state_count_dict:
                    train_state_count_dict[train_state_key] += 1
                else:
                    train_state_count_dict.update({train_state_key: 1})
                buffers['train_state_count'][index][0, ...] = \
                    torch.tensor(1 / np.sqrt(train_state_count_dict.get(train_state_key)))

            # Do new rollout
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state = model(env_output, agent_state)

                timings.time('model')

                env_output = env.step(agent_output['action'])

                timings.time('step')

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]

                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                # Update the episodic state counts
                episode_state_key = tuple(
                    env_output['frame'].view(-1).tolist())
                if episode_state_key in episode_state_count_dict:
                    episode_state_count_dict[episode_state_key] += 1
                else:
                    episode_state_count_dict.update({episode_state_key: 1})
                buffers['episode_state_count'][index][t + 1, ...] = \
                    torch.tensor(1 / np.sqrt(episode_state_count_dict.get(episode_state_key)))

                # Reset the episode state counts when the episode is over
                if env_output['done'][0][0]:
                    episode_state_count_dict = dict()

                # Update the training state counts if you're doing count-based exploration
                if flags.model == 'count':
                    train_state_key = tuple(
                        env_output['frame'].view(-1).tolist())
                    if train_state_key in train_state_count_dict:
                        train_state_count_dict[train_state_key] += 1
                    else:
                        train_state_count_dict.update({train_state_key: 1})
                    buffers['train_state_count'][index][t + 1, ...] = \
                        torch.tensor(1 / np.sqrt(train_state_count_dict.get(train_state_key)))

                timings.time('write')
            full_queue.put(index)

        if i == 0:
            log.info('Actor %i: %s', i, timings.summary())

    except KeyboardInterrupt:
        pass
    except Exception as e:
        logging.error('Exception in worker process %i', i)
        traceback.print_exc()
        print()
        raise e
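The 1/sqrt(count) values the example writes into buffers['episode_state_count'] and buffers['train_state_count'] are the usual count-based exploration bonus. As a standalone sketch (the helper name is mine, not the project's), the update amounts to:

import numpy as np

def count_bonus(state_key, count_dict):
    # Increment the visit count for this state and return the bonus 1 / sqrt(N(s)).
    count_dict[state_key] = count_dict.get(state_key, 0) + 1
    return 1.0 / np.sqrt(count_dict[state_key])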