Example #1
    def __getitem__(self, index):
        if self._first:
            self._first = False
        seq_i = self.db.sequence_index_from_sample_index(index)
        with self.timers.timed('state'):
            current_state = self.db.get_db_item('image_crops', index)
            # (N, H, W, C) crops -> (N, C, H, W), scaled to [0, 1]
            current_state = current_state.transpose(
                (0, 3, 1, 2)).astype(np.float32) / 255.

        num_actions = 4
        num_objects = current_state.shape[0]
        gi = self.db.eos_index_from_sample_index(index)
        assert (index + 1 < gi and gi - index < 100)
        # sample a goal index uniformly from [index + 2, gi]
        gi = int(np.random.randint(index + 2, gi + 1))
        action_sequence = self.db.get_db_item_list('actions', index, gi)
        seq_len = gi - index
        action_sequence_label = np.zeros((seq_len, num_objects * num_actions),
                                         dtype=np.float32)
        for i, a in enumerate(action_sequence):
            action_sequence_label[i][a[2]] = 1
        action_sequence_label = action_sequence_label.reshape(
            (seq_len, num_objects, num_actions))

        next_state = self.db.get_db_item('image_crops', gi - 1)
        next_state = next_state.transpose(
            (0, 3, 1, 2)).astype(np.float32) / 255.

        with self.timers.timed('sample'):
            sample = {
                'states': tu.to_tensor(current_state),
                'next_states': tu.to_tensor(next_state),
                'action_sequence': tu.to_tensor(action_sequence_label),
                'seq_idx': seq_i
            }
        return sample
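The action encoding above (also used in Example #2 below) treats the stored entry a[2] as a flat index into a num_objects * num_actions grid. A minimal standalone sketch of that assumption, with hypothetical sizes and NumPy only:

import numpy as np

# hypothetical sizes; in the dataset code they come from current_state.shape[0]
num_objects, num_actions = 3, 4
flat_action_index = 6  # assumed meaning of a[2]: flat (object, action) index

label = np.zeros(num_objects * num_actions, dtype=np.float32)
label[flat_action_index] = 1
label = label.reshape((num_objects, num_actions))

obj_idx, act_idx = np.argwhere(label == 1)[0]
print(obj_idx, act_idx)  # -> 1 2 (object 1, action 2)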
Example #2
    def __getitem__(self, index):
        if self._first:
            self._first = False
        seq_i = self.db.sequence_index_from_sample_index(index)
        with self.timers.timed('state'):
            current_state = self.db.get_db_item('image_crops', index)
            current_state = current_state.transpose(
                (0, 3, 1, 2)).astype(np.float32) / 255.
            next_state = self.db.get_db_item('image_crops', index + 1)
            next_state = next_state.transpose(
                (0, 3, 1, 2)).astype(np.float32) / 255.

        num_actions = 4
        num_objects = current_state.shape[0]
        action = self.db.get_db_item('actions', index)
        action_label = np.zeros(num_objects * num_actions, dtype=np.float32)
        action_label[action[2]] = 1
        action_label = action_label.reshape((num_objects, num_actions))
        with self.timers.timed('sample'):
            sample = {
                'states': tu.to_tensor(current_state),
                'next_states': tu.to_tensor(next_state),
                'action_labels': tu.to_tensor(action_label),
                'seq_idx': seq_i
            }
        return sample
Example #3
def collate_torch_graphs(node_feat, edge_feat, node_index_list,
                         edge_index_list):
    """
    Collate a list of graphs and their features.

    :param node_feat: torch.Tensor of shape [N1 + N2 + ..., D1]
    :param edge_feat: torch.Tensor of shape [E1 + E2 + ..., D2]
    :param node_index_list: a list of node index arrays (one numpy array per graph)
    :param edge_index_list: a list of edge index arrays (one numpy array per graph)
    :return: a collated graph of type torch_geometric.data.Data
    """

    node_feat_list, edge_feat_list = split_graph_feature(
        node_feat, edge_feat, node_index_list, edge_index_list)

    graphs = []
    # TODO: vectorize this
    for nf, ef, n_idx, e_idx in zip(node_feat_list, edge_feat_list,
                                    node_index_list, edge_index_list):
        # add supernode to the graph
        supernode_clique = np.tile(n_idx[None, ...], (len(e_idx), 1))
        sn_n_idx, sn_e_idx = add_supernodes(n_idx, e_idx, supernode_clique)
        sn_feat = torch.cat([nf, ef], dim=0)
        torch_e_idx = to_tensor(sn_e_idx).long().t().contiguous().to(
            node_feat.device)
        graphs.append(Data(x=sn_feat, edge_index=torch_e_idx))

    batched_graphs = batch.Batch.from_data_list(graphs)

    num_node = [n.shape[0] for n in node_index_list]
    num_edge = [e.shape[0] for e in edge_index_list]
    assert (batched_graphs.x.shape[0] == (np.sum(num_node) + np.sum(num_edge)))

    return batched_graphs
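For reference, a minimal standalone sketch of the batching step that collate_torch_graphs builds on, using plain torch_geometric with two toy graphs and no supernodes:

import torch
from torch_geometric.data import Data, Batch

# two toy graphs: 2 nodes / 1 edge and 3 nodes / 2 edges, 4-d node features
g1 = Data(x=torch.randn(2, 4), edge_index=torch.tensor([[0], [1]]))
g2 = Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1], [1, 2]]))

batched = Batch.from_data_list([g1, g2])
print(batched.x.shape)     # torch.Size([5, 4]) -- node features stacked
print(batched.edge_index)  # g2's edge indices shifted by g1's node count
print(batched.batch)       # graph id of every node: tensor([0, 0, 1, 1, 1])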
Example #4
    def __getitem__(self, index):
        seq_i = self.db.sequence_index_from_sample_index(index)
        with self.timers.timed('state'):
            action = self.db.get_db_item('actions', index)
            current_state = self.db.get_db_item('object_state_flat',
                                                index).astype(np.float32)

        plan_sample = self.get_plan_sample(index)
        with self.timers.timed('sample'):
            sample = {
                'states': tu.to_tensor(current_state),
                'action_labels': tu.to_tensor(np.array(action[0])),
                'num_entities': tu.to_tensor(np.array(current_state.shape[0])),
                'seq_idx': seq_i
            }
        sample.update(plan_sample)
        return sample
Example #5
    def __getitem__(self, index):
        if self._first:
            self._first = False
        seq_i = self.db.sequence_index_from_sample_index(index)
        plan_sample = self.get_plan_sample(index)
        with self.timers.timed('state'):
            current_state = self.db.get_db_item('image_crops', index)
            current_state = current_state.transpose(
                (0, 3, 1, 2)).astype(np.float32) / 255.
        with self.timers.timed('sample'):
            sample = {'states': tu.to_tensor(current_state), 'seq_idx': seq_i}
        sample.update(plan_sample)
        return sample
Example #6
def positional_encoding(batch_size, seq_len, enc_dim, device=None):
    """
    Positional encoding with wave function
    :param batch_size: batch size
    :param seq_len: sequence length to encode
    :param enc_dim: dimension of the encoding
    :param device: optional device for the returned tensor
    :return: [B, T, D]
    """
    pos_i = np.tile(np.arange(seq_len)[:, None], (1, enc_dim)).astype(np.float32)
    enc_i = np.tile(np.arange(enc_dim)[None, :], (seq_len, 1)).astype(np.float32)
    pos_enc = np.zeros((seq_len, enc_dim), dtype=np.float32)
    pos_enc[:, ::2] = np.sin(pos_i[:, ::2] / (np.power(10000, 2 * enc_i[:, ::2] / enc_dim)))
    pos_enc[:, 1::2] = np.cos(pos_i[:, 1::2] / (np.power(10000, 2 * (enc_i[:, 1::2] + 1) / enc_dim)))
    pos_enc = np.tile(pos_enc[None, ...], (batch_size, 1, 1))
    return to_tensor(pos_enc, device=device)
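A standalone NumPy sketch of the same even/odd sin-cos fill, with the repo-specific to_tensor wrapper and batch tiling dropped, just to show the per-position table it produces:

import numpy as np

def sinusoid_table(seq_len, enc_dim):
    # mirrors the fill above; returns a [T, D] float32 table
    pos = np.arange(seq_len)[:, None].astype(np.float32)
    idx = np.arange(enc_dim)[None, :].astype(np.float32)
    enc = np.zeros((seq_len, enc_dim), dtype=np.float32)
    enc[:, ::2] = np.sin(pos / np.power(10000., 2 * idx[:, ::2] / enc_dim))
    enc[:, 1::2] = np.cos(pos / np.power(10000., 2 * (idx[:, 1::2] + 1) / enc_dim))
    return enc

table = sinusoid_table(seq_len=5, enc_dim=8)
print(table.shape)  # (5, 8); positional_encoding above tiles this to [B, T, D]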
Example #7
    def __getitem__(self, index):
        if self._first:
            self._first = False
        seq_i = self.db.sequence_index_from_sample_index(index)
        with self.timers.timed('state'):
            current_state = self.db.get_db_item('gt_state', index)
            num_objects = current_state.shape[0]
            full_edges = self.graph_edges(num_objects)
            types = self.db.get_db_item('object_type_indices', index)
            num_types = self.db.get_db_item('num_object_types', index)[0]
            sym_state = self.db.get_db_item('symbolic_state', index)
            current_state, entity_states = make_bullet_gt_input(
                current_state, sym_state, full_edges, types, num_types)

        plan_sample = self.get_plan_sample(index)
        with self.timers.timed('sample'):
            sample = {
                'states': tu.to_tensor(current_state),
                'entity_states': tu.to_tensor(entity_states),
                'num_entities': tu.to_tensor(np.array(entity_states.shape[0])),
                'seq_idx': seq_i
            }
        sample.update(plan_sample)
        return sample
Example #8
    def _serialize_subgoals(self, entity_state, object_state, curr_goal):
        assert (len(curr_goal.shape) == 3)
        assert (len(entity_state.shape) == 3)
        assert (entity_state.shape[1] == curr_goal.shape[1])
        state_np = tu.to_numpy(entity_state)[0]
        num_predicate = curr_goal.shape[-1]
        curr_goal_np = tu.to_numpy(curr_goal[0])
        goal_object_index, goal_predicates_index = np.where(curr_goal_np != 2)

        goal_index = np.stack([goal_object_index,
                               goal_predicates_index]).transpose()
        goal_predicates_value = curr_goal_np[(goal_object_index,
                                              goal_predicates_index)]
        num_goal = goal_index.shape[0]
        # predict satisfaction
        sat_state_inputs = state_np[goal_object_index]
        sat_predicate_mask = npu.to_onehot(goal_predicates_index,
                                           num_predicate).astype(np.float32)
        sat_predicate = sat_predicate_mask.copy()
        sat_predicate[sat_predicate_mask.astype(bool)] = goal_predicates_value

        sat_state_inputs = tu.to_tensor(sat_state_inputs[None, ...],
                                        device=entity_state.device)
        sat_sym_inputs_np = np.concatenate((sat_predicate, sat_predicate_mask),
                                           axis=-1)[None, ...]
        sat_sym_inputs = tu.to_tensor(sat_sym_inputs_np,
                                      device=entity_state.device)
        sat_preds = tu.to_numpy(
            self.forward_sat(sat_state_inputs,
                             sat_sym_inputs).argmax(-1))[0]  # [ng]
        assert (sat_preds.shape[0] == num_goal)
        if self.verbose and self.env is not None:
            for sat_p, sat_m, sp, oi in zip(sat_predicate, sat_predicate_mask,
                                            sat_preds, goal_object_index):
                sat_pad = np.hstack([[oi], sat_p, sat_m, [sp]])
                print('[bp] sat: ',
                      self.env.deserialize_satisfied_entry(sat_pad))

        # Construct dependency graphs
        nodes, edges = construct_full_graph(num_goal)
        # edges index into goal_index, i.e. into the [object_idx, predicate_idx] pairs
        src_object_index = goal_object_index[edges[:, 0]]  # object index of each edge source
        tgt_object_index = goal_object_index[edges[:, 1]]
        src_inputs = state_np[src_object_index]  # list of object states
        tgt_inputs = state_np[tgt_object_index]

        # predicate values for each edge source
        src_predicate_value = goal_predicates_value[edges[:, 0]]
        src_predicate_index = goal_predicates_index[edges[:, 0]]
        tgt_predicate_value = goal_predicates_value[edges[:, 1]]
        tgt_predicate_index = goal_predicates_index[edges[:, 1]]
        src_predicate_mask = npu.to_onehot(src_predicate_index,
                                           num_predicate).astype(np.float32)
        src_predicate = np.zeros_like(src_predicate_mask)
        src_predicate[src_predicate_mask.astype(bool)] = src_predicate_value
        tgt_predicate_mask = npu.to_onehot(tgt_predicate_index,
                                           num_predicate).astype(np.float32)
        tgt_predicate = np.zeros_like(tgt_predicate_mask)
        tgt_predicate[tgt_predicate_mask.astype(bool)] = tgt_predicate_value
        # dependency_inputs_np = np.concatenate(
        #     (src_inputs, tgt_inputs, src_predicate, src_predicate_mask, tgt_predicate, tgt_predicate_mask), axis=-1)
        dependency_state_inputs_np = np.concatenate((src_inputs, tgt_inputs),
                                                    axis=-1)
        dependency_sym_inputs_np = np.concatenate(
            (src_predicate, src_predicate_mask, tgt_predicate,
             tgt_predicate_mask),
            axis=-1)

        if dependency_state_inputs_np.shape[0] > 0:
            dependency_state_inputs = tu.to_tensor(
                dependency_state_inputs_np,
                device=entity_state.device).unsqueeze(0)
            dependency_sym_inputs = tu.to_tensor(
                dependency_sym_inputs_np,
                device=entity_state.device).unsqueeze(0)
            deps_preds = tu.to_numpy(
                self.forward_dep(dependency_state_inputs,
                                 dependency_sym_inputs).argmax(-1))[0]
            dep_graph_edges = edges[deps_preds > 0]
        else:
            dep_graph_edges = np.array([])
        sorted_goal_groups = sort_goal_graph(dep_graph_edges, nodes)
        focus_group_idx = None
        for gg in reversed(sorted_goal_groups):
            if not np.any(sat_preds[gg]):  # if unsatisfied
                focus_group_idx = gg
                break

        if focus_group_idx is None:
            return None, 'NETWORK_ALL_SATISFIED'

        # focus_group_idx is a list of goal index
        focus_group_np = np.ones_like(curr_goal_np) * 2
        for fg_idx in focus_group_idx:
            fg_obj_i, fg_pred_i = goal_index[fg_idx]
            focus_group_np[fg_obj_i, fg_pred_i] = curr_goal_np[fg_obj_i,
                                                               fg_pred_i]
        focus_group = tu.to_tensor(focus_group_np,
                                   device=entity_state.device).unsqueeze(0)
        focus_group = tu.to_onehot(focus_group, 3)
        return focus_group, -1
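For intuition, the goal extraction at the top of _serialize_subgoals keeps only entries whose value differs from 2 (and the focus group is likewise padded with 2 before the 3-class one-hot), which suggests 2 marks "not part of the goal". A small hypothetical illustration:

import numpy as np

# hypothetical goal matrix: rows = objects, columns = predicates,
# 0/1 = required predicate value, 2 = assumed to mean "unspecified"
curr_goal_np = np.array([[2, 1, 2],
                         [0, 2, 2]])

goal_object_index, goal_predicates_index = np.where(curr_goal_np != 2)
goal_predicates_value = curr_goal_np[(goal_object_index, goal_predicates_index)]

print(goal_object_index)      # [0 1]
print(goal_predicates_index)  # [1 0]
print(goal_predicates_value)  # [1 0]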