def from_data_list_token(data_list, follow_batch=()):
    """Collate ``Data`` objects into a ``Batch`` without shifting negative indices.

    This is nearly a copy of PyTorch Geometric's ``Batch.from_data_list``. The
    only difference is how non-boolean tensor attributes are offset by the
    running ``cumsum``: only entries ``>= 0`` receive the offset, and negative
    entries keep their original value. (Presumably negatives are sentinel or
    padding indices -- confirm with the data producer.)

    The input ``Data`` objects are not modified: the offset is applied to a copy
    of each tensor.

    Args:
        data_list: Non-empty sequence of ``Data``-like objects. The value an
            attribute has in ``data_list[0]`` decides how it is collated.
        follow_batch: Keys for which an extra ``"<key>_batch"`` assignment
            vector (graph index per element) is built.

    Returns:
        A contiguous ``Batch`` with ``__slices__`` filled per attribute.

    Raises:
        ValueError: if an attribute's first value is neither a tensor nor an
            int/float.
    """
    keys = list(set.union(*[set(data.keys) for data in data_list]))
    assert "batch" not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in keys}
    for key in keys:
        batch[key] = []
    for key in follow_batch:
        batch["{}_batch".format(key)] = []

    cumsum = {key: 0 for key in keys}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key]
            if torch.is_tensor(item) and item.dtype != torch.bool:
                # Shift a copy: a masked in-place write would modify the
                # caller's Data object (and double-shift on a re-collate).
                item = item.clone()
                mask = item >= 0
                item[mask] = item[mask] + cumsum[key]

            if torch.is_tensor(item):
                size = item.size(data.__cat_dim__(key, item))
            else:
                size = 1
            batch.__slices__[key].append(size + batch.__slices__[key][-1])
            cumsum[key] += data.__inc__(key, item)
            batch[key].append(item)

            if key in follow_batch:
                batch["{}_batch".format(key)].append(
                    torch.full((size,), i, dtype=torch.long))

        num_nodes = data.num_nodes
        if num_nodes is not None:
            batch.batch.append(torch.full((num_nodes,), i, dtype=torch.long))

    # Mirrors upstream PyG: only the last graph's num_nodes is checked here.
    if num_nodes is None:
        batch.batch = None

    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            batch[key] = torch.cat(
                batch[key], dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, (int, float)):
            batch[key] = torch.tensor(batch[key])
        else:
            raise ValueError("Unsupported attribute type {} : {}".format(type(item), item))

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def forward(self, inputs: Batch) -> Batch:
    """Score the relations of a batch of graphs.

    Pipeline:
    1. Encode node features (linear and convolutional inputs).
    2. Encode edge features.
    3. Refine the edge features by message passing over ``relation_indexes``.
    4. Read out per-batch scores with ``output_global_model``.

    Returns:
        A new ``Batch`` holding the readout as ``predicate_scores``, plus a few
        attributes copied from ``inputs``. Its ``__slices__`` are set so that
        ``Batch.to_data_list()`` can split it back into graphs.
        ``predicate_scores`` is sliced like ``relation_indexes``.
    """
    # Input encoders
    encoded_nodes = self.input_node_model(
        linear_features=inputs.object_linear_features,
        conv_features=inputs.object_conv_features,
    )
    encoded_edges = self.input_edge_model(
        linear_features=inputs.relation_linear_features)

    # Message passing
    refined_edges = self.edge_model(
        nodes=encoded_nodes,
        edges=encoded_edges,
        edge_indices=inputs.relation_indexes,
    )

    # Readout
    scores = self.output_global_model(
        edges=refined_edges,
        edge_indices=inputs.relation_indexes,
        node_to_graph_idx=inputs.batch,
        num_graphs=inputs.num_graphs,
    )

    passthrough_keys = (
        "n_edges",
        "n_nodes",
        "object_boxes",
        "object_classes",
        "relation_indexes",
        "object_image_size",
    )
    passthrough = {name: inputs[name] for name in passthrough_keys}
    outputs = Batch(
        num_nodes=inputs.num_nodes,
        batch=inputs.batch,
        predicate_scores=scores,
        **passthrough,
    )

    # Per-graph slice bounds so to_data_list() can undo the collation.
    slices = {"predicate_scores": inputs.__slices__["relation_indexes"]}
    slices.update((name, inputs.__slices__[name]) for name in passthrough_keys)
    outputs.__slices__ = slices
    return outputs
def from_data_list(data_list, follow_batch=()):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects, offsetting ``edge_index``
    level by level using each graph's ``depth_count``.

    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`.

    Entries of ``edge_index`` whose ``depth_mask`` equals ``d`` (for
    ``1 <= d < max depth``) are shifted by the number of level ``d - 1``
    entries contributed by the preceding graphs. Other entries are not
    shifted. (This assumes ``depth_mask`` can boolean-index ``edge_index``,
    i.e. has the same shape -- TODO confirm.) Per-level counts are summed
    into ``batch.depth_count``. Graphs with fewer levels are padded with
    zeros.

    The input ``Data`` objects are not modified.

    Args:
        data_list: Non-empty sequence of ``Data``-like objects, each carrying
            ``edge_index``, ``depth_mask`` and ``depth_count``.
        follow_batch: Keys for which an extra ``"<key>_batch"`` assignment
            vector (graph index per element) is built.

    Returns:
        A contiguous ``Batch``. NOTE: ``__slices__`` are left at ``[0]``, so
        ``to_data_list()`` cannot split this batch back into graphs.
    """
    keys = set.union(*[set(data.keys) for data in data_list])
    keys.remove('depth_count')
    keys = list(keys)
    depth = max(data.depth_count.shape[0] for data in data_list)
    assert 'batch' not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in keys}
    for key in keys:
        batch[key] = []
    for key in follow_batch:
        batch['{}_batch'.format(key)] = []

    def _append(graph_idx, data, key, item):
        # Store one graph's value for `key` plus, if followed, its batch vector.
        batch[key].append(item)
        if key in follow_batch:
            size = item.size(data.__cat_dim__(key, item)) if torch.is_tensor(item) else 1
            batch['{}_batch'.format(key)].append(
                torch.full((size,), graph_idx, dtype=torch.long))

    # cumsum[level]: number of `level` entries already placed in the batch.
    cumsum = [0] * depth
    depth_count = torch.zeros((depth,), dtype=torch.long)
    batch.batch = []
    for i, data in enumerate(data_list):
        graph_depth = data.depth_count.shape[0]

        # Shift a copy: updating data['edge_index'] in place would corrupt the
        # caller's Data objects (and double-shift on a re-collate).
        edges = data['edge_index'].clone()
        for d in range(1, depth):
            mask = data.depth_mask == d
            edges[mask] += cumsum[d - 1]
            # A shallower graph contributes nothing at its missing levels.
            if d - 1 < graph_depth:
                cumsum[d - 1] += data.depth_count[d - 1].item()
        _append(i, data, 'edge_index', edges)

        # Add into the leading levels only, instead of broadcasting a shorter
        # count vector across all levels.
        depth_count[:graph_depth] += data.depth_count

        for key in data.keys:
            if key in ('edge_index', 'depth_count'):
                continue
            _append(i, data, key, data[key])

        num_nodes = data.num_nodes
        if num_nodes is not None:
            batch.batch.append(torch.full((num_nodes,), i, dtype=torch.long))

    if num_nodes is None:
        batch.batch = None

    for key in batch.keys:
        item = batch[key][0]
        if torch.is_tensor(item):
            batch[key] = torch.cat(
                batch[key], dim=data_list[0].__cat_dim__(key, item))
        elif isinstance(item, (int, float)):
            batch[key] = torch.tensor(batch[key])

    batch.depth_count = depth_count

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def from_data_list(data_list, follow_batch=()):
    r"""Constructs a batch object from a python list holding
    :class:`torch_geometric.data.Data` objects.
    The assignment vector :obj:`batch` is created on the fly.
    Additionally, creates assignment batch vectors for each key in
    :obj:`follow_batch`.

    Each non-boolean tensor attribute is offset by the running sum of
    ``data.__inc__(key, item)`` over the preceding graphs.

    Args:
        data_list: Non-empty sequence of ``Data``-like objects. The value an
            attribute has in ``data_list[0]`` decides how it is collated.
        follow_batch: Keys for which an extra ``"<key>_batch"`` assignment
            vector (graph index per element) is built.

    Returns:
        A contiguous ``Batch`` with ``__slices__`` filled per attribute.
        Attributes whose first value is neither a tensor nor an int/float are
        left as Python lists.
    """
    keys = list(set.union(*[set(data.keys) for data in data_list]))
    assert 'batch' not in keys

    batch = Batch()
    batch.__data_class__ = data_list[0].__class__
    batch.__slices__ = {key: [0] for key in keys}
    for key in keys:
        batch[key] = []
    for key in follow_batch:
        batch['{}_batch'.format(key)] = []

    cumsum = {key: 0 for key in keys}
    batch.batch = []
    for i, data in enumerate(data_list):
        for key in data.keys:
            item = data[key]
            if torch.is_tensor(item) and item.dtype != torch.bool:
                item = item + cumsum[key]  # out-of-place: inputs stay untouched
            if torch.is_tensor(item):
                size = item.size(data.__cat_dim__(key, data[key]))
            else:
                size = 1
            batch.__slices__[key].append(size + batch.__slices__[key][-1])
            # Not `+=`: if __inc__ returns a tensor, don't mutate it in place.
            cumsum[key] = cumsum[key] + data.__inc__(key, item)
            batch[key].append(item)

            if key in follow_batch:
                batch['{}_batch'.format(key)].append(
                    torch.full((size,), i, dtype=torch.long))

        num_nodes = data.num_nodes
        if num_nodes is not None:
            batch.batch.append(torch.full((num_nodes,), i, dtype=torch.long))

    if num_nodes is None:
        batch.batch = None

    # The first graph (a Data or ClevrData) decides the concat dimension.
    first = data_list[0]
    for key in batch.keys:
        item = batch[key][0]
        # Lazy %-formatting: no string work unless DEBUG logging is enabled.
        logger.debug("key = %s", key)
        if torch.is_tensor(item):
            logger.debug("item.shape = %s", item.shape)
            batch[key] = torch.cat(batch[key], dim=first.__cat_dim__(key, item))
        elif isinstance(item, (int, float)):
            batch[key] = torch.tensor(batch[key])

    if torch_geometric.is_debug_enabled():
        batch.debug()

    return batch.contiguous()
def _keep_top_x_relations(
    self,
    B,
    edge_to_graph_assignment,
    inputs,
    predicate_classes_sorted,
    predicate_scores_sorted,
    relation_scores,
):
    """Retain the ``self.top_x_relations`` best-scoring relations of every graph.

    Args:
        B: Number of graphs; passed as ``dim_size`` to ``scatter_topk_2d_flat``.
        edge_to_graph_assignment: Graph id of each entry of ``relation_scores``;
            used as the grouping key of the per-graph top-k.
        inputs: Source ``Batch``. ``relation_indexes`` and the copied node-level
            attributes (and their ``__slices__``) are read from it.
        predicate_classes_sorted: Gathered along dim 1 with the top-k
            predicate positions.
        predicate_scores_sorted: Gathered along dim 1 with the top-k
            predicate positions.
        relation_scores: Scores ranked by ``scatter_topk_2d_flat``.

    Returns:
        A ``Batch`` with the kept relations and ``__slices__`` filled so that
        ``Batch.to_data_list()`` can split it back into graphs. A graph keeps
        fewer than ``top_x_relations`` relations when it had fewer candidates.
    """
    top_scores, (top_relation_pos, _), (_, top_predicate_pos) = scatter_topk_2d_flat(
        relation_scores,
        edge_to_graph_assignment,
        self.top_x_relations,
        dim_size=B,
        fill_value=float("-inf"),
    )

    # A position of -1 marks a padding slot for a graph that had fewer
    # candidates than top_x_relations.
    n_relations = (top_relation_pos != -1).int().sum(dim=1)

    flat_relation_pos = top_relation_pos.flatten()
    relation_is_real = flat_relation_pos != -1
    kept_relation_scores = top_scores.flatten()[relation_is_real]
    # Subject/object columns of the kept relations.
    kept_relation_indexes = inputs.relation_indexes[:, flat_relation_pos[relation_is_real]]

    # gather() must not see -1 (CUDA complains), so clamp to 0 and drop the
    # padding slots afterwards.
    # NOTE(review): gather(dim=1) reads output row r from row r of
    # predicate_*_sorted; with dim_size=B the index rows appear to be graphs --
    # confirm predicate_*_sorted rows line up with graphs, not with relations.
    safe_predicate_pos = top_predicate_pos.clamp(min=0)
    predicate_is_real = top_predicate_pos.flatten() != -1
    kept_predicate_scores = predicate_scores_sorted.gather(
        dim=1, index=safe_predicate_pos).flatten()[predicate_is_real]
    kept_predicate_classes = predicate_classes_sorted.gather(
        dim=1, index=safe_predicate_pos).flatten()[predicate_is_real]

    relations = Batch(
        num_nodes=inputs.num_nodes,
        n_edges=n_relations,
        relation_scores=kept_relation_scores,
        relation_indexes=kept_relation_indexes,
        predicate_scores=kept_predicate_scores,
        predicate_classes=kept_predicate_classes,
        **{
            name: inputs[name]
            for name in (
                "n_nodes",
                "batch",
                "object_boxes",
                "object_scores",
                "object_classes",
                "object_image_size",
            )
        },
    )

    # Global and node attributes keep the input slicing; edge attributes are
    # re-sliced by the number of relations each graph kept.
    edge_slice = [0] + n_relations.cumsum(dim=0).tolist()
    inherited_slices = {
        name: inputs.__slices__[name]
        for name in (
            "n_nodes",
            "n_edges",
            "object_boxes",
            "object_scores",
            "object_classes",
            "object_image_size",
        )
    }
    edge_slices = dict.fromkeys(
        ("relation_indexes", "relation_scores", "predicate_classes", "predicate_scores"),
        edge_slice,
    )
    relations.__slices__ = {**inherited_slices, **edge_slices}
    return relations