コード例 #1
0
    def collect_from_bigtable():
        """Read trajectories from Bigtable and pack them into one ``tf.data.Dataset``.

        For each of ``args.train_epochs`` epochs the current global iterator
        value is read and ``cbt_read_rows`` is asked for ``args.train_steps``
        rows under ``args.prefix``. Judging by the progress label, these are
        presumably the rows just below the iterator; confirm against
        ``cbt_read_rows``. Every row's ``trajectory:traj`` and
        ``trajectory:info`` cells are parsed into ``Trajectory`` / ``Info``
        protobufs and turned into per-step transitions.

        Transitions from all rows of all epochs are concatenated along the step
        axis. ``next_obs`` for step t is the observation of step t + 1 within
        the same trajectory. The last step's ``next_obs`` wraps around to the
        first observation and is flagged with ``next_mask == 0``. Every other
        step has ``next_mask == 1``.

        Returns:
            tf.data.Dataset: one element per step,
            ``(obs, action, reward, next_obs, next_mask)``; rewards are float32.

        Raises:
            ValueError: if no non-empty trajectory was read, so there is no
                data to build a dataset from.
        """
        obs_parts, action_parts, reward_parts = [], [], []
        next_obs_parts, next_mask_parts = [], []

        for _ in range(args.train_epochs):
            # FETCH DATA
            global_i = cbt_global_iterator(cbt_table)
            rows = cbt_read_rows(cbt_table, args.prefix, args.train_steps,
                                 global_i)

            for row in tqdm(
                    rows,
                    "Trajectories {} - {}".format(global_i - args.train_steps,
                                                  global_i - 1)):
                # DESERIALIZE DATA
                bytes_traj = row.cells['trajectory']['traj'.encode()][0].value
                bytes_info = row.cells['trajectory']['info'.encode()][0].value
                traj, info = Trajectory(), Info()
                traj.ParseFromString(bytes_traj)
                info.ParseFromString(bytes_info)

                # An empty trajectory has no transitions, and next_mask[-1]
                # below would raise IndexError on it.
                if info.num_steps == 0:
                    continue

                # FORMAT DATA
                # Cast to int: np.append returns a float array when
                # vector_obs_spec is empty, and reshape rejects float dims.
                traj_shape = np.append(info.num_steps,
                                       info.vector_obs_spec).astype(int)
                obs = np.asarray(traj.vector_obs).reshape(traj_shape)
                # Successor of step t is obs[t + 1]; the last entry wraps
                # around to obs[0], so it is masked out with next_mask = 0.
                next_obs = np.roll(obs, shift=-1, axis=0)
                next_mask = np.ones(info.num_steps)
                next_mask[-1] = 0

                # Accumulate instead of overwriting, so every trajectory read
                # (not just the last one) ends up in the dataset.
                obs_parts.append(obs)
                action_parts.append(np.asarray(traj.actions))
                reward_parts.append(np.asarray(traj.rewards))
                next_obs_parts.append(next_obs)
                next_mask_parts.append(next_mask)

        if not obs_parts:
            raise ValueError("No trajectories were read from Bigtable.")

        # from_tensor_slices takes ONE nested structure; a tuple makes each
        # dataset element an (obs, action, reward, next_obs, next_mask) tuple.
        dataset = tf.data.Dataset.from_tensor_slices(
            (np.concatenate(obs_parts),
             np.concatenate(action_parts),
             np.concatenate(reward_parts).astype(np.float32),
             np.concatenate(next_obs_parts),
             np.concatenate(next_mask_parts)))
        return dataset
コード例 #2
0
ファイル: dqn_agent.py プロジェクト: pootitan/rab-bigtable-rl
    def train_input_fn(self):
        """Build the dataset that ``estimator.train()`` consumes.

        Rows are pulled from Bigtable one at a time. The walk starts at
        ``cbt_global_iterator(self.cbt_table) - 1`` and moves down one row
        index per step. Each row's trajectory is decoded and handed to
        ``self.exp_buff``. Fetching stops as soon as ``self.exp_buff.size``
        reaches ``self.exp_buff.max_size``.

        NOTE(review): there is no lower bound on the row index. If the table
        cannot supply enough rows to fill the buffer, the outcome depends on
        ``cbt_read_row``; confirm.

        Returns:
            A ``tf.data.Dataset`` built from the experience buffer as
            ``((obs, next_obs), (actions, rewards, next_mask))`` slices. It is
            shuffled with a buffer of ``self.exp_buff.max_size``, repeated
            indefinitely and batched by ``self.batch_size``.
        """
        timing = self.log_time is True
        if timing:
            self.time_logger.reset()

        def decode(row):
            # Both protobufs are stored in the 'trajectory' column family;
            # the first cell ([0]) of each column is used.
            raw_traj = row.cells['trajectory']['traj'.encode()][0].value
            raw_info = row.cells['trajectory']['info'.encode()][0].value
            traj, info = Trajectory(), Info()
            traj.ParseFromString(raw_traj)
            info.ParseFromString(raw_info)
            return traj, info

        newest = cbt_global_iterator(self.cbt_table) - 1
        offset = -1
        self.exp_buff.reset()
        full = False
        while not full:
            offset += 1

            # FETCH ROW FROM CBT
            row = cbt_read_row(self.cbt_table, self.prefix, newest - offset)
            if timing:
                self.time_logger.log(0)  # checkpoint: row fetched

            # DESERIALIZE DATA
            traj, info = decode(row)

            # FORMAT DATA: visual observations are stored flat; restore the
            # (num_steps, *visual_obs_spec) layout before buffering.
            shape = np.append(info.num_steps, info.visual_obs_spec).astype(int)
            frames = np.asarray(traj.visual_obs).reshape(shape)
            self.exp_buff.add_trajectory(frames, traj.actions, traj.rewards,
                                         info.num_steps)
            if timing:
                self.time_logger.log(1)  # checkpoint: trajectory buffered

            full = self.exp_buff.size >= self.exp_buff.max_size

        # `offset` is the offset of the last row actually read.
        print("-> Fetched trajectories {} - {}".format(newest - offset, newest))

        buff = self.exp_buff
        features = (buff.obs, buff.next_obs)
        labels = (buff.actions, buff.rewards, buff.next_mask)
        dataset = tf.data.Dataset.from_tensor_slices((features, labels))
        dataset = dataset.shuffle(buff.max_size).repeat().batch(self.batch_size)

        if timing:
            self.time_logger.log(2)  # checkpoint: dataset built

        return dataset
コード例 #3
0
                new_obs, reward, done, info = env.step(action)

                observations.append(obs)
                actions.append(action)
                rewards.append(reward)
                if args.log_time is True: time_logger.log(2)
                obs = np.asarray(new_obs).astype(bytes)
                if args.log_time is True: time_logger.log(3)
            visual_obs_keys = [
                '{}_visual_{}'.format(row_key, x)
                for x in range(len(observations))
            ]
            if args.log_time is True: time_logger.log(4)

            #BUILD PB2 OBJECTS
            traj, info, visual_obs = Trajectory(), Info(), []
            traj.visual_obs_key.extend(visual_obs_keys)
            traj.actions.extend(actions)
            traj.rewards.extend(rewards)
            info.vector_obs_spec.extend(observations[0].shape)
            info.num_steps = len(actions)
            index = 0
            if args.log_time is True: time_logger.log(5)
            for ob in observations:
                visual_ob = Visual_obs()
                visual_ob.data.extend(np.asarray(ob).flatten().astype(bytes))
                row = cbt_table_visual.row(visual_obs_keys[index])
                index += 1
                row.set_cell(column_family_id='trajectory',
                             column='data'.encode(),
                             value=visual_ob.SerializeToString())
コード例 #4
0
        print("Table doesn't exist.")
        exit()
    else:
        print("Table found.")

    #TRAINING LOOP
    for i in tqdm(range(5000), "Training"):
        #QUERY TABLE FOR PARTIAL ROWS
        regex_filter = '^cartpole_trajectory_{}$'.format(i)
        row_filter = row_filters.RowKeyRegexFilter(regex_filter)
        filtered_rows = table.read_rows(filter_=row_filter)

        for row in filtered_rows:
            bytes_traj = row.cells['trajectory']['traj'.encode()][0].value
            bytes_info = row.cells['trajectory']['info'.encode()][0].value
            traj, info = Trajectory(), Info()
            traj.ParseFromString(bytes_traj)
            info.ParseFromString(bytes_info)

            traj_shape = np.append(np.array(info.num_steps), np.array(info.vector_obs_spec))
            observations = np.array(traj.vector_obs).reshape(traj_shape)
            traj_obs = np.rollaxis(np.array([observations, np.roll(observations, 1)]), 0 , 2)
            traj_actions = np.rollaxis(np.array([traj.actions, np.roll(traj.actions, 1)]), 0 , 2)
            traj_rewards = np.rollaxis(np.array([traj.rewards, np.roll(traj.rewards, 1)]), 0 , 2)
            traj_discounts = np.ones((info.num_steps,2))

            traj_obs = tf.constant(traj_obs, dtype=tf.float32)
            traj_actions = tf.constant(traj_actions, dtype=tf.int32)
            policy_info = ()
            traj_rewards = tf.constant(traj_rewards, dtype=tf.float32)
            traj_discounts = tf.constant(traj_discounts, dtype=tf.float32)