Example #1
def burn_in(senv, replay_mem):
    '''Fill the replay memory with random-action rollouts before training starts.'''
    tic = time.time()
    for i_idx in range(FLAGS.burn_in_length):
        if i_idx % 10 == 0 and i_idx != 0:
            toc = time.time()
            log_string('Burning in {}/{} sequences, time taken: {}s'.format(
                i_idx, FLAGS.burn_in_length, toc - tic))
            tic = time.time()
        state, model_id = senv.reset(True)
        actions = []
        # Buffers for per-step RGB frames and 3x4 camera matrices over one episode.
        RGB_temp_list = np.zeros(
            (FLAGS.max_episode_length, FLAGS.resolution, FLAGS.resolution, 3),
            dtype=np.float32)
        R_list = np.zeros((FLAGS.max_episode_length, 3, 4), dtype=np.float32)

        RGB_temp_list[0, ...], _ = replay_mem.read_png_to_uint8(
            state[0][0], state[1][0], model_id)
        R_list[0, ...] = replay_mem.get_R(state[0][0], state[1][0])

        # Roll out the rest of the episode with uniformly random actions.
        for e_idx in range(FLAGS.max_episode_length - 1):
            actions.append(np.random.randint(FLAGS.action_num))
            state, next_state, done, model_id = senv.step(actions[-1])
            RGB_temp_list[e_idx + 1, ...], _ = replay_mem.read_png_to_uint8(
                next_state[0], next_state[1], model_id)
            R_list[e_idx + 1, ...] = replay_mem.get_R(next_state[0],
                                                      next_state[1])
            if done:
                # Append the final view to the trajectory and store it in the replay memory.
                traj_state = state
                traj_state[0] += [next_state[0]]
                traj_state[1] += [next_state[1]]
                temp_traj = trajectData(traj_state, actions, model_id)
                replay_mem.append(temp_traj)
                break
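
A minimal usage sketch, not from the original source; ShapeNetEnv and ReplayMemory are hypothetical stand-ins for whatever environment and replay-memory classes the surrounding project actually defines.

senv = ShapeNetEnv(FLAGS)          # renders views of a 3D model (assumed class name)
replay_mem = ReplayMemory(FLAGS)   # stores trajectData entries (assumed class name)
burn_in(senv, replay_mem)          # pre-fill memory with FLAGS.burn_in_length random rollouts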
Example #2
def burn_in(senv, replay_mem):
    '''Fill the replay memory with random-action rollouts before training starts.'''
    # Pinhole camera intrinsics (focal length 420, principal point (112, 112)),
    # tiled to one copy per time step: shape (1, max_episode_length, 3, 3).
    K_single = np.asarray([[420.0, 0.0, 112.0], [0.0, 420.0, 112.0],
                           [0.0, 0.0, 1.0]])
    K_list = np.tile(K_single[None, None, ...],
                     (1, FLAGS.max_episode_length, 1, 1))
    tic = time.time()
    for i_idx in range(FLAGS.burn_in_length):
        if i_idx % 10 == 0 and i_idx != 0:
            toc = time.time()
            log_string('Burning in {}/{} sequences, time taken: {}s'.format(
                i_idx, FLAGS.burn_in_length, toc - tic))
            tic = time.time()
        state, model_id = senv.reset(True)
        actions = []
        RGB_temp_list = np.zeros(
            (FLAGS.max_episode_length, FLAGS.resolution, FLAGS.resolution, 3),
            dtype=np.float32)
        R_list = np.zeros((FLAGS.max_episode_length, 3, 4), dtype=np.float32)
        #vox_temp = np.zeros((FLAGS.voxel_resolution, FLAGS.voxel_resolution, FLAGS.voxel_resolution),
        #    dtype=np.float32)

        RGB_temp_list[0, ...], _ = replay_mem.read_png_to_uint8(
            state[0][0], state[1][0], model_id)
        R_list[0, ...] = replay_mem.get_R(state[0][0], state[1][0])
        #vox_temp_list = replay_mem.get_vox_pred(RGB_temp_list, R_list, K_list, 0)
        #vox_temp = np.squeeze(vox_temp_list[0, ...])
        ## run simulations and get memories
        for e_idx in range(FLAGS.max_episode_length - 1):
            actions.append(np.random.randint(FLAGS.action_num))
            state, next_state, done, model_id = senv.step(actions[-1])
            RGB_temp_list[e_idx + 1, ...], _ = replay_mem.read_png_to_uint8(
                next_state[0], next_state[1], model_id)
            R_list[e_idx + 1, ...] = replay_mem.get_R(next_state[0],
                                                      next_state[1])
            ## TODO: update vox_temp
            #vox_temp_list = replay_mem.get_vox_pred(RGB_temp_list, R_list, K_list, e_idx+1)
            #vox_temp = np.squeeze(vox_temp_list[e_idx+1, ...])
            if done:
                traj_state = state
                traj_state[0] += [next_state[0]]
                traj_state[1] += [next_state[1]]
                #rewards = replay_mem.get_seq_rewards(RGB_temp_list, R_list, K_list, model_id)
                #print 'rewards: {}'.format(rewards)
                temp_traj = trajectData(traj_state, actions, model_id)
                replay_mem.append(temp_traj)
                break
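
For reference, a standalone NumPy check (not from the original source) of the intrinsics tiling used above:

import numpy as np

K_single = np.asarray([[420.0, 0.0, 112.0],
                       [0.0, 420.0, 112.0],
                       [0.0, 0.0, 1.0]])
# Two leading axes give shape (1, 1, 3, 3); tiling replicates the same
# intrinsics once per episode step along the second axis.
max_episode_length = 4  # assumed value, for illustration only
K_list = np.tile(K_single[None, None, ...], (1, max_episode_length, 1, 1))
print(K_list.shape)  # -> (1, 4, 3, 3)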
Example #3
    def go(self,
           i_idx,
           verbose=True,
           add_to_mem=True,
           mode='active',
           is_train=True):
        '''Does one rollout and returns (mvnet_input, actions).'''

        state, model_id = self.env.reset(is_train, i_idx)
        actions = []
        mvnet_input = MVInputs(self.FLAGS, batch_size=1)

        mvnet_input.put(self.single_input_for_state(state), episode_idx=0)

        for e_idx in range(1, self.FLAGS.max_episode_length):

            tic = time.time()
            if mode == 'active':
                #if np.random.uniform(0, 1) < self.FLAGS.epsilon:
                #    probs = [1.0/8]*8
                #    agent_action = np.random.choice(self.env.action_space_n, p=probs)
                #else:
                agent_action = self.agent.select_action(mvnet_input,
                                                        e_idx - 1,
                                                        is_training=is_train)
            elif mode == 'random':
                probs = [1.0 / 8] * 8
                agent_action = np.random.choice(self.env.action_space_n,
                                                p=probs)
            elif mode == 'nolimit':
                agent_action = 0
            elif mode == 'oneway':
                if len(actions) == 0:
                    # First step: pick one of a restricted set of directions,
                    # then repeat it for the rest of the episode.
                    agent_action = np.random.choice([0, 1, 4, 7])
                else:
                    agent_action = actions[0]
            else:
                raise ValueError('unknown mode: {}'.format(mode))

            actions.append(agent_action)
            if mode != 'nolimit':
                state, next_state, done, model_id = self.env.step(actions[-1])
            else:
                state, next_state, done, model_id = self.env.step(actions[-1],
                                                                  nolimit=True)

            mvnet_input.put(self.single_input_for_state(next_state),
                            episode_idx=e_idx)

            if verbose:
                log_string(
                    'Iter: {}, e_idx: {}, azim: {}, elev: {}, model_id: {}, time: {}s'
                    .format(i_idx, e_idx, next_state[0], next_state[1],
                            model_id,
                            time.time() - tic))

            if done:
                traj_state = state
                traj_state[0] += [next_state[0]]
                traj_state[1] += [next_state[1]]

                if add_to_mem:
                    temp_traj = trajectData(traj_state, actions, model_id)
                    self.mem.append(temp_traj)

                self.last_trajectory = traj_state
                break

        return mvnet_input, actions
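
A minimal usage sketch, not from the original source; `rollout` is assumed to be an instance of the class that defines go() above, already constructed with its environment, agent, replay memory, and FLAGS.

for i_idx in range(100):
    mvnet_input, actions = rollout.go(i_idx, verbose=False, add_to_mem=True,
                                      mode='random', is_train=True)
    # mvnet_input holds one episode of per-view inputs; actions lists the moves taken.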