Example #1
0
def from_data_list_token(data_list, follow_batch=None):
    """Collate a list of data objects into one ``Batch``, sparing negative indexes.

    This is essentially a copy of PyTorch Geometric's ``Batch.from_data_list``
    with one difference: when a (non-boolean) tensor attribute is shifted by
    the running increment (``data.__inc__``) of the preceding objects, only
    its non-negative entries are shifted; negative entries are kept as-is.

    The objects in ``data_list`` are not modified: shifted tensors are copies.

    Args:
        data_list: Non-empty list of data objects to collate.
        follow_batch: Optional iterable of attribute names for which an extra
            ``"<key>_batch"`` assignment vector is built. Defaults to none.

    Returns:
        The collated batch (``batch.contiguous()``).

    Raises:
        ValueError: If an attribute is neither a tensor, an ``int`` nor a
            ``float``.
    """
    if follow_batch is None:
        follow_batch = []

    keys = list(set.union(*[set(data.keys) for data in data_list]))
    assert "batch" not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in keys}

    for key in keys:
        batch[key] = []

    for key in follow_batch:
        batch["{}_batch".format(key)] = []

    cumsum = {key: 0 for key in keys}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key]
            if torch.is_tensor(item) and item.dtype != torch.bool:
                # Copy first: writing through the mask would otherwise modify
                # the caller's tensor (and shift it again on a second collate).
                item = item.clone()
                mask = item >= 0
                item[mask] = item[mask] + cumsum[key]
            if torch.is_tensor(item):
                size = item.size(data.__cat_dim__(key, item))
            else:
                size = 1
            # Cumulative per-key offsets so the batch can be split again.
            batch.__slices__[key].append(size + batch.__slices__[key][-1])
            cumsum[key] += data.__inc__(key, item)
            batch[key].append(item)

            if key in follow_batch:
                item = torch.full((size,), i, dtype=torch.long)
                batch["{}_batch".format(key)].append(item)

        num_nodes = data.num_nodes
        if num_nodes is not None:
            item = torch.full((num_nodes,), i, dtype=torch.long)
            batch.batch.append(item)

    # Decided by the last data object only, as in the upstream implementation.
    if num_nodes is None:
        batch.batch = None

    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            batch[key] = torch.cat(batch[key], dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, (int, float)):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError("Unsupported attribute type {} : {}".format(type(item), item))

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
    def forward(self, inputs: Batch) -> Batch:
        """Run the relation model on ``inputs`` and wrap the result in a ``Batch``.

        Steps, in order:

        * encode object (node) features with ``self.input_node_model``;
        * encode relation (edge) features with ``self.input_edge_model``;
        * apply ``self.edge_model`` (message passing over ``relation_indexes``);
        * compute ``predicate_scores`` with ``self.output_global_model``.

        The returned batch also carries a fixed set of attributes copied from
        ``inputs``. Its ``__slices__`` table is set so that
        ``Batch.to_data_list()`` can split it back into per-graph objects.
        ``predicate_scores`` reuses the slices of ``relation_indexes``.
        """
        passthrough_keys = (
            "n_edges",
            "n_nodes",
            "object_boxes",
            "object_classes",
            "relation_indexes",
            "object_image_size",
        )

        # Encoders.
        node_repr = self.input_node_model(
            linear_features=inputs.object_linear_features,
            conv_features=inputs.object_conv_features,
        )
        edge_repr = self.input_edge_model(
            linear_features=inputs.relation_linear_features,
        )

        # Message passing.
        edge_repr = self.edge_model(
            nodes=node_repr,
            edges=edge_repr,
            edge_indices=inputs.relation_indexes,
        )

        # Readout.
        scores = self.output_global_model(
            edges=edge_repr,
            edge_indices=inputs.relation_indexes,
            node_to_graph_idx=inputs.batch,
            num_graphs=inputs.num_graphs,
        )

        result = Batch(
            num_nodes=inputs.num_nodes,
            batch=inputs.batch,
            predicate_scores=scores,
            **{key: inputs[key] for key in passthrough_keys},
        )

        slice_table = {"predicate_scores": inputs.__slices__["relation_indexes"]}
        for key in passthrough_keys:
            slice_table[key] = inputs.__slices__[key]
        result.__slices__ = slice_table

        return result
Example #3
0
    def from_data_list(data_list, follow_batch=None):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`.

        Depth-aware variant. Every data object must carry ``edge_index``,
        ``depth_count`` and ``depth_mask``:

        * entries of ``edge_index`` selected by ``depth_mask == d`` (for
          ``d >= 1``) are offset by the sum of ``depth_count[d - 1]`` over the
          preceding data objects;
        * the ``depth_count`` tensors are summed into ``batch.depth_count``.
          Shorter ones are treated as zero-padded up to the largest depth;
        * every other attribute is collected without any offset.

        The input data objects are not modified.

        NOTE(review): ``__slices__`` is only initialised to ``[0]`` per key and
        never filled. Splitting the result with ``to_data_list()`` is
        presumably unsupported; confirm before relying on it.
        """
        if follow_batch is None:
            follow_batch = []

        keys = set.union(*[set(data.keys) for data in data_list])
        keys.remove('depth_count')
        keys = list(keys)
        depth = max(data.depth_count.shape[0] for data in data_list)
        assert 'batch' not in keys

        batch = Batch()
        batch.__data_class__ = data_list[0].__class__
        batch.__slices__ = {key: [0] for key in keys}

        for key in keys:
            batch[key] = []

        for key in follow_batch:
            batch['{}_batch'.format(key)] = []

        def _append(key, item, data, graph_idx):
            # Collect ``item`` and, when requested, its graph-assignment vector.
            batch[key].append(item)
            if key in follow_batch:
                if torch.is_tensor(item):
                    size = item.size(data.__cat_dim__(key, item))
                else:
                    size = 1
                batch['{}_batch'.format(key)].append(
                    torch.full((size, ), graph_idx, dtype=torch.long))

        # cumsum[k]: running total of depth_count[k] over the graphs already seen.
        cumsum = {i: 0 for i in range(depth)}
        depth_count = torch.zeros((depth, ), dtype=torch.long)
        batch.batch = []
        for i, data in enumerate(data_list):
            level_counts = data['depth_count']
            n_levels = level_counts.shape[0]

            # Clone first: the offsets below are applied in place and must not
            # leak into the caller's ``edge_index``.
            edges = data['edge_index'].clone()
            for d in range(1, depth):
                mask = data.depth_mask == d
                edges[mask] += cumsum[d - 1]
                # A graph shallower than ``depth`` contributes 0 at missing levels.
                if d - 1 < n_levels:
                    cumsum[d - 1] += level_counts[d - 1].item()
            _append('edge_index', edges, data, i)
            depth_count[:n_levels] += level_counts

            for key in data.keys:
                if key == 'edge_index' or key == 'depth_count':
                    continue
                _append(key, data[key], data, i)

            num_nodes = data.num_nodes
            if num_nodes is not None:
                item = torch.full((num_nodes, ), i, dtype=torch.long)
                batch.batch.append(item)

        # Decided by the last data object only.
        if num_nodes is None:
            batch.batch = None

        for key in batch.keys:
            item = batch[key][0]
            if torch.is_tensor(item):
                batch[key] = torch.cat(batch[key],
                                       dim=data_list[0].__cat_dim__(key, item))
            elif isinstance(item, (int, float)):
                batch[key] = torch.tensor(batch[key])

        batch.depth_count = depth_count
        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch.contiguous()
Example #4
0
    def from_data_list(data_list, follow_batch=None):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`.

        Every non-boolean tensor attribute is shifted by the running total of
        ``data.__inc__(key, item)`` over the preceding data objects.
        ``__slices__`` records, per key, the cumulative size along the
        concatenation dimension, counting 1 per object for non-tensor values.
        """
        if follow_batch is None:
            follow_batch = []
        keys = list(set.union(*[set(data.keys) for data in data_list]))
        assert 'batch' not in keys
        batch = Batch()
        batch.__data_class__ = data_list[0].__class__
        batch.__slices__ = {key: [0] for key in keys}
        for key in keys:
            batch[key] = []
        for key in follow_batch:
            batch['{}_batch'.format(key)] = []
        cumsum = {key: 0 for key in keys}
        batch.batch = []
        for i, data in enumerate(data_list):
            for key in data.keys:
                item = data[key]
                if torch.is_tensor(item) and item.dtype != torch.bool:
                    # Out-of-place add, so the caller's data objects stay untouched.
                    item = item + cumsum[key]
                if torch.is_tensor(item):
                    size = item.size(data.__cat_dim__(key, data[key]))
                else:
                    size = 1
                batch.__slices__[key].append(size + batch.__slices__[key][-1])
                cumsum[key] = cumsum[key] + data.__inc__(key, item)
                batch[key].append(item)

                if key in follow_batch:
                    item = torch.full((size,), i, dtype=torch.long)
                    batch['{}_batch'.format(key)].append(item)

            num_nodes = data.num_nodes
            if num_nodes is not None:
                item = torch.full((num_nodes,), i, dtype=torch.long)
                batch.batch.append(item)

        # Decided by the last data object only.
        if num_nodes is None:
            batch.batch = None

        # The first object decides the concatenation dimension per key
        # (it may be a Data or ClevrData instance).
        first = data_list[0]
        for key in batch.keys:
            item = batch[key][0]
            # Lazy %-style arguments: messages are formatted only when DEBUG
            # logging is enabled, instead of on every key of every call.
            logger.debug("key = %s", key)
            if torch.is_tensor(item):
                logger.debug("batch[%s]", key)
                logger.debug("item.shape = %s", item.shape)
                batch[key] = torch.cat(batch[key], dim=first.__cat_dim__(key, item))
            elif isinstance(item, (int, float)):
                batch[key] = torch.tensor(batch[key])
        if torch_geometric.is_debug_enabled():
            batch.debug()

        return batch.contiguous()
Example #5
0
    def _keep_top_x_relations(
        self,
        B,
        edge_to_graph_assignment,
        inputs,
        predicate_classes_sorted,
        predicate_scores_sorted,
        relation_scores,
    ):
        """Keep at most ``self.top_x_relations`` relations per graph.

        ``relation_scores`` is ranked within each of the ``B`` graphs, grouped
        by ``edge_to_graph_assignment``, using ``scatter_topk_2d_flat``. Slots
        it marks with -1 (graphs with fewer candidates than
        ``self.top_x_relations``) are dropped.

        Returns a new ``Batch`` holding:

        * the surviving relation scores and their ``relation_indexes`` columns;
        * the matching predicate scores and classes;
        * the per-graph relation count ``n_edges``;
        * object-level attributes copied from ``inputs``.

        Its ``__slices__`` is set so that ``Batch.to_data_list()`` can split it
        per graph.
        """
        (
            topk_scores,
            (topk_relation_idx, _),
            (_, topk_predicate_idx),
        ) = scatter_topk_2d_flat(
            relation_scores,
            edge_to_graph_assignment,
            self.top_x_relations,
            dim_size=B,
            fill_value=float("-inf"),
        )

        # -1 marks padding slots; count real relations per graph.
        relation_is_real = topk_relation_idx != -1
        n_relations = relation_is_real.int().sum(dim=1)
        relation_keep = relation_is_real.flatten()

        kept_scores = topk_scores.flatten()[relation_keep]
        kept_relation_idx = topk_relation_idx.flatten()[relation_keep]
        kept_relation_indexes = inputs.relation_indexes[:, kept_relation_idx]

        # gather cannot take -1 (CUDA rejects it), so clamp to 0 and drop
        # those positions afterwards with the validity mask.
        predicate_keep = topk_predicate_idx.flatten() != -1
        gather_index = topk_predicate_idx.clamp(min=0)
        kept_predicate_scores = predicate_scores_sorted.gather(
            dim=1, index=gather_index).flatten()[predicate_keep]
        kept_predicate_classes = predicate_classes_sorted.gather(
            dim=1, index=gather_index).flatten()[predicate_keep]

        node_level_keys = (
            "n_nodes",
            "batch",
            "object_boxes",
            "object_scores",
            "object_classes",
            "object_image_size",
        )
        relations = Batch(
            num_nodes=inputs.num_nodes,
            n_edges=n_relations,
            relation_scores=kept_scores,
            relation_indexes=kept_relation_indexes,
            predicate_scores=kept_predicate_scores,
            predicate_classes=kept_predicate_classes,
            **{name: inputs[name] for name in node_level_keys},
        )

        # Graph/node attributes keep the input slices; edge attributes are
        # sliced by the new per-graph relation counts.
        edge_slice = [0] + n_relations.cumsum(dim=0).tolist()
        slice_table = {}
        for name in (
            "n_nodes",
            "n_edges",
            "object_boxes",
            "object_scores",
            "object_classes",
            "object_image_size",
        ):
            slice_table[name] = inputs.__slices__[name]
        for name in (
            "relation_indexes",
            "relation_scores",
            "predicate_classes",
            "predicate_scores",
        ):
            slice_table[name] = edge_slice
        relations.__slices__ = slice_table

        return relations