Exemplo n.º 1
0
def test_radius_graph(dtype, device):
    """Check radius_graph on the four corners of a square, for both flows.

    Adjacent corners are 2 apart and diagonal corners 2*sqrt(2) apart, so
    with r=2 each corner is expected to link only to its two neighbours.
    The 'source_to_target' result is the 'target_to_source' one with every
    pair reversed.
    """
    pos = tensor([
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
    ], dtype, device)

    expected = {(0, 1), (0, 3), (1, 0), (1, 2),
                (2, 1), (2, 3), (3, 0), (3, 2)}
    reversed_expected = {(b, a) for a, b in expected}

    for flow, pairs in (('target_to_source', expected),
                        ('source_to_target', reversed_expected)):
        edge_index = radius_graph(pos, r=2, flow=flow)
        assert to_set(edge_index) == pairs
Exemplo n.º 2
0
    def forward(self, data) -> torch.Tensor:
        """Predict one output per graph in ``data``.

        Builds a radius graph (cutoff 10.0) over ``data.pos``. Each edge is
        described by the spherical harmonics of its displacement combined
        (via ``self.mul``) with a one-hot of the ordered (source type,
        destination type) pair. Two tensor-product message-passing steps
        follow, and node features are then summed per graph.
        """
        # Normalisation constants; these are typical sizes, not measured ones.
        neighbor_norm = 3 ** 0.5  # sqrt(typical number of neighbors)
        graph_norm = 4 ** 0.5  # sqrt(typical number of nodes)
        n_types = self.num_z  # number of atom types

        pos = data.pos
        src, dst = radius_graph(pos, 10.0, data.batch)

        vec = pos[src] - pos[dst]
        sh = o3.spherical_harmonics(self.irreps_sh, vec, normalize=False, normalization='component')

        # Index of the ordered (z_src, z_dst) pair, in [0, n_types**2).
        pair_type = n_types * data.z[src] + data.z[dst]
        pair_onehot = torch.nn.functional.one_hot(pair_type, n_types ** 2).mul(n_types).to(sh.dtype)

        edge_attr = self.mul(pair_onehot, sh)

        # Initial node features: rescaled sum of the neighbours' harmonics.
        h = scatter(sh, dst, dim=0).div(neighbor_norm)
        # Each step: tensor product of the source features with edge_attr,
        # summed onto the destination node.
        h = scatter(self.tp1(h[src], edge_attr), dst, dim=0).div(neighbor_norm)
        h = scatter(self.tp2(h[src], edge_attr), dst, dim=0).div(neighbor_norm)

        # Sum node features per graph.
        return scatter(h, data.batch, dim=0).div(graph_norm)
Exemplo n.º 3
0
    def forward(self, x, batch=None):
        """Build a neighbour graph in a learned space and propagate features.

        Args:
            x: node feature matrix.
            batch: optional batch vector assigning nodes to examples.

        Returns:
            ``(edge_index, out)``. ``edge_index`` is the graph built on the
            learned spatial embedding ``self.lin_s(x)``. ``out`` is
            ``self.lin_fout`` applied to the propagated features
            concatenated with ``x``.

        Raises:
            ValueError: if ``self.neighbor_algo`` is neither ``"knn"`` nor
                ``"radius"``. This is a subclass of ``Exception``, so
                existing ``except Exception`` handlers still catch it.
        """
        spatial = self.lin_s(x)
        to_propagate = self.lin_flr(x)

        if self.neighbor_algo == "knn":
            edge_index = knn_graph(spatial,
                                   self.k,
                                   batch,
                                   loop=False,
                                   flow=self.flow,
                                   cosine=False)
        elif self.neighbor_algo == "radius":
            # self.k doubles as the per-node neighbour cap for the radius search.
            edge_index = radius_graph(spatial,
                                      self.radius,
                                      batch,
                                      loop=False,
                                      flow=self.flow,
                                      max_num_neighbors=self.k)
        else:
            raise ValueError("Unknown neighbor algo {}".format(
                self.neighbor_algo))

        # edge_index[1] is taken as the reference node, edge_index[0] as its neighbour.
        reference = spatial.index_select(0, edge_index[1])
        neighbors = spatial.index_select(0, edge_index[0])

        distancessq = torch.sum((reference - neighbors)**2, dim=-1)
        # Factor 10 gives a better initial spread
        distance_weight = torch.exp(-10. * distancessq)

        prop_feat = self.propagate(edge_index,
                                   x=to_propagate,
                                   edge_weight=distance_weight)

        return edge_index, self.lin_fout(torch.cat([prop_feat, x], dim=-1))
Exemplo n.º 4
0
    def forward(self, data) -> torch.Tensor:
        """Predict one output per graph from node positions alone.

        Nodes within distance 1.1 are connected. Edges carry the
        (unnormalised) spherical harmonics of their displacement. Two
        tensor-product message-passing steps are applied, then node
        features are summed per graph.
        """
        neighbor_scale = 2 ** 0.5  # sqrt of the typical number of neighbours
        graph_scale = 4 ** 0.5  # sqrt of the typical number of nodes

        pos, batch = data.pos, data.batch
        # Index tensors describing every pair of nodes within distance 1.1.
        src, dst = radius_graph(x=pos, r=1.1, batch=batch)
        rel = pos[src] - pos[dst]
        # Left unnormalised so the harmonics stay polynomial in the edge vector.
        sh = o3.spherical_harmonics(l=self.irreps_sh,
                                    x=rel,
                                    normalize=False,
                                    normalization='component')

        # Initial node features: rescaled sum of the neighbours' harmonics.
        feats = scatter(sh, dst, dim=0).div(neighbor_scale)
        # Tensor product of source features with the edge harmonics, summed
        # onto the destination node, applied twice.
        feats = scatter(self.tp1(feats[src], sh), dst, dim=0).div(neighbor_scale)
        feats = scatter(self.tp2(feats[src], sh), dst, dim=0).div(neighbor_scale)

        return scatter(feats, batch, dim=0).div(graph_scale)
Exemplo n.º 5
0
def train(model, optimizer, scheduler, loss_fn, dataloader, epoch,
          device='cuda'):
    """Run one training epoch and step the scheduler on the mean batch loss.

    Args:
        model: called as ``model(x_cont, x_cat, edge_index, batch)``.
        optimizer: optimizer updating ``model``'s parameters.
        scheduler: stepped once per epoch with the mean batch loss.
        loss_fn: called as ``loss_fn(result, data.x, data.y, data.batch)``.
        dataloader: iterable of batches exposing ``x``, ``y`` and ``batch``.
        epoch: epoch number, used only for the log line.
        device: device each batch is moved to. The default ``'cuda'``
            matches the previously hard-coded value.

    NOTE(review): relies on a module-level ``deltaR`` radius defined
    elsewhere in the file.
    """
    model.train()
    batch_losses = []
    loss_avg = utils.RunningAverage()
    with tqdm(total=len(dataloader)) as t:
        for data in dataloader:
            optimizer.zero_grad()
            data = data.to(device)
            # First 8 columns are continuous features, the rest categorical.
            x_cont = data.x[:, :8]
            x_cat = data.x[:, 8:].long()
            # Presumably columns 0/1 are Cartesian x/y and column 3 is eta; confirm.
            phi = torch.atan2(data.x[:, 1], data.x[:, 0])
            etaphi = torch.cat([data.x[:, 3][:, None], phi[:, None]], dim=1)
            # NB: there is a problem right now for comparing hits at the +/- pi boundary
            edge_index = radius_graph(etaphi, r=deltaR, batch=data.batch, loop=True, max_num_neighbors=255)
            result = model(x_cont, x_cat, edge_index, data.batch)
            loss = loss_fn(result, data.x, data.y, data.batch)
            loss.backward()
            optimizer.step()
            # .item() copies the value to the host; do it once per batch.
            loss_value = loss.item()
            batch_losses.append(loss_value)
            loss_avg.update(loss_value)
            t.set_postfix(loss='{:05.3f}'.format(loss_avg()))
            t.update()
    epoch_loss = np.mean(batch_losses)
    scheduler.step(epoch_loss)
    print('Training epoch: {:02d}, MSE: {:.4f}'.format(epoch, epoch_loss))
Exemplo n.º 6
0
    def forward(self, data):
        """Run the configured layers on a radius graph and pool per graph.

        The layer call signature depends on ``self.namemodel`` ("DeepSet",
        "PointNet", "MetaNet", or any other name for plain ``(x,
        edge_index)`` layers). After the layers, node features are pooled
        with add, mean and max pooling. These are concatenated with the
        global features ``u``, stored in ``self.pooled``, and passed
        through ``self.lin``.
        """
        x, pos, batch, u = data.x, data.pos, data.batch, data.u

        # Connect all node pairs within ``self.k_nn`` (used as a radius here).
        # Alternative: knn_graph(pos, k=self.k_nn, batch=batch, loop=self.loop)
        edge_index = radius_graph(pos, r=self.k_nn, batch=batch, loop=self.loop)

        kind = self.namemodel
        for layer in self.layers:
            if kind == "DeepSet":
                x = layer(x)
            elif kind == "PointNet":
                x = layer(x=x, pos=pos, edge_index=edge_index)
            elif kind == "MetaNet":
                x, _, u = layer(x, edge_index, None, u, batch)
            else:
                x = layer(x=x, edge_index=edge_index)
            # Keep the pre-activation output of the most recent layer.
            self.h = x
            x = x.relu()

        pools = [
            global_add_pool(x, batch),  # [num_examples, hidden_channels]
            global_mean_pool(x, batch),
            global_max_pool(x, batch),
            u,
        ]
        self.pooled = torch.cat(pools, dim=1)

        return self.lin(self.pooled)
Exemplo n.º 7
0
def radius(x, r=0.5, loop=False, dtype=None, device=None):
    """Return the radius graph of ``x`` as a ``SparseTensor`` of ones.

    All points are treated as belonging to a single example.

    Args:
        x: 2-D point matrix of shape ``(N, D)``.
        r: neighbourhood radius.
        loop: whether to include self-loops.
        dtype: dtype of the stored edge values (torch default when ``None``).
        device: device of the result. Defaults to ``x.device`` so that
            indices and values always live on the same device.

    Returns:
        ``SparseTensor`` of size ``(N, N)`` with value 1 at each edge.
    """
    N, _ = x.shape  # unpacking also enforces a 2-D input
    if device is None:
        device = x.device
    # The batch vector must live on the same device as ``x``.
    batch = torch.zeros(N, dtype=torch.long, device=x.device)
    edge_index = radius_graph(x, r, batch=batch, loop=loop).to(device)
    edge_val = torch.ones(edge_index.shape[-1], dtype=dtype, device=device)
    return SparseTensor(
        row=edge_index[0], col=edge_index[1], value=edge_val, sparse_sizes=(N, N)
    )
Exemplo n.º 8
0
def test_radius_graph(dtype, device):
    """Check row/col output of radius_graph on a square, for both flows.

    Each corner has exactly two neighbours within r=2. The neighbour column
    is sorted within each node's pair, so the check does not depend on the
    order in which neighbours are returned.
    """
    pos = tensor([
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
    ], dtype, device)

    def sorted_pairs(idx):
        # Group into pairs (two neighbours per node) and sort each pair.
        return idx.view(-1, 2).sort(dim=-1)[0].view(-1).tolist()

    centers = [0, 0, 1, 1, 2, 2, 3, 3]
    neighbours = [1, 3, 0, 2, 1, 3, 0, 2]

    row, col = radius_graph(pos, r=2, flow='target_to_source')
    col_sorted = sorted_pairs(col)
    assert row.tolist() == centers
    assert col_sorted == neighbours

    row, col = radius_graph(pos, r=2, flow='source_to_target')
    row_sorted = sorted_pairs(row)
    assert row_sorted == neighbours
    assert col.tolist() == centers
Exemplo n.º 9
0
    def forward(self, data: Union[Data, Dict[str,
                                             torch.Tensor]]) -> torch.Tensor:
        """Evaluate the network.

        Parameters
        ----------
        data : `torch_geometric.data.Data` or dict
            Must provide ``pos`` (node positions). Optional entries:
            - ``x``: input node features
            - ``z``: node attributes, for instance the atom type
            - ``batch``: the graph each node belongs to (all zeros if absent)

        Returns
        -------
        torch.Tensor
            Per-node outputs of the last layer. When ``self.reduce_output``
            is set, these are summed per graph and divided by
            ``sqrt(self.num_nodes)``.
        """
        pos = data['pos']
        n_nodes = pos.shape[0]
        batch = data['batch'] if 'batch' in data else pos.new_zeros(
            n_nodes, dtype=torch.long)

        graph = radius_graph(pos, self.max_radius, batch)
        src, dst = graph[0], graph[1]
        vec = pos[src] - pos[dst]
        sh = o3.spherical_harmonics(self.irreps_edge_attr,
                                    vec,
                                    True,
                                    normalization='component')
        length = vec.norm(dim=1)
        length_emb = soft_one_hot_linspace(
            x=length,
            start=0.0,
            end=self.max_radius,
            number=self.number_of_basis,
            basis='gaussian',
            cutoff=False).mul(self.number_of_basis**0.5)
        # Harmonics scaled by smooth_cutoff of the normalised edge length.
        edge_attr = smooth_cutoff(length / self.max_radius)[:, None] * sh

        if self.input_has_node_in and 'x' in data:
            assert self.irreps_in is not None
            x = data['x']
        else:
            # No input features: use a constant scalar per node.
            assert self.irreps_in is None
            x = pos.new_ones((n_nodes, 1))

        if self.input_has_node_attr and 'z' in data:
            z = data['z']
        else:
            assert self.irreps_node_attr == o3.Irreps("0e")
            z = pos.new_ones((n_nodes, 1))

        for layer in self.layers:
            x = layer(x, z, src, dst, edge_attr, length_emb)

        if not self.reduce_output:
            return x
        return scatter(x, batch, dim=0).div(self.num_nodes**0.5)
Exemplo n.º 10
0
def test_radius_graph(dtype, device):
    """radius_graph on a square links each corner to its two neighbours.

    The output is passed through ``coalesce`` (defined elsewhere) before
    comparison, which presumably puts the edges in canonical order.
    """
    pos = tensor([[-1, -1], [-1, +1], [+1, +1], [+1, -1]], dtype, device)

    expected = [
        [0, 0, 1, 1, 2, 2, 3, 3],  # row
        [1, 3, 0, 2, 1, 3, 0, 2],  # col
    ]
    edges = radius_graph(pos, r=2)
    assert coalesce(edges).tolist() == expected
Exemplo n.º 11
0
    def forward(self, data) -> torch.Tensor:
        """Predict one scalar per graph from atom types and positions.

        Reads ``data['z']``, ``data['pos']`` and ``data['batch']``. Edges
        connect atoms within ``self.max_radius`` (at most 1000 neighbours
        each). The output of ``self.mp`` is reduced to one value per node,
        optionally shifted by ``self.atomref[z]``, and summed per graph.
        """
        z = data['z']
        pos = data['pos']
        batch = data['batch']

        # The graph
        src, dst = radius_graph(pos,
                                r=self.max_radius,
                                batch=batch,
                                max_num_neighbors=1000)

        # Edge attributes: harmonics of degree 0..sh_lmax of the unit edge vector.
        vec = pos[src] - pos[dst]
        sh = o3.spherical_harmonics(l=range(self.sh_lmax + 1),
                                    x=vec,
                                    normalize=True,
                                    normalization='component')

        # Edge length embedding
        dist_emb = soft_one_hot_linspace(
            vec.norm(dim=1),
            0.0,
            self.max_radius,
            self.num_basis,
            basis='smooth_finite',
            cutoff=True,
        ).mul(self.num_basis**0.5)

        ones = pos.new_ones(pos.shape[0], 1)

        # Lookup by z value: 1->0, 6->1, 7->2, 8->3, 9->4 (presumably H, C,
        # N, O, F). Other values map to -1, which one_hot would reject.
        lookup = z.new_tensor([-1, 0, -1, -1, -1, -1, 1, 2, 3, 4])
        type_onehot = torch.nn.functional.one_hot(lookup[z], 5).mul(5**0.5)

        out = self.mp(node_features=ones,
                      node_attr=type_onehot,
                      edge_src=src,
                      edge_dst=dst,
                      edge_attr=sh,
                      edge_scalars=dist_emb)

        # Combine the two output channels as c0 + c1**2 / 2, one column per node.
        out = (out[:, 0] + out[:, 1].pow(2).mul(0.5)).view(-1, 1)
        out = out.div(self.num_nodes**0.5)

        if self.atomref is not None:
            out = out + self.atomref[z]
        # for target=7, MAE of 75eV

        return scatter(out, batch, dim=0)
Exemplo n.º 12
0
def test_radius_graph(dtype, device):
    """radius_graph on a square returns the expected edges in a fixed order."""
    pos = tensor([
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
    ], dtype, device)
    expected_row = [0, 0, 1, 1, 2, 2, 3, 3]
    expected_col = [1, 3, 0, 2, 1, 3, 0, 2]

    row, col = radius_graph(pos, r=2)

    assert row.tolist() == expected_row
    assert col.tolist() == expected_col
Exemplo n.º 13
0
def test_radius_graph_large(dtype, device):
    """Compare radius_graph with a brute-force reference from scipy's cKDTree.

    ``max_num_neighbors=2000`` exceeds the 1000 points, so no neighbour list
    can be truncated. Self-loops are included on both sides.
    """
    points = torch.randn(1000, 3, dtype=dtype, device=device)

    edge_index = radius_graph(points,
                              r=0.5,
                              flow='target_to_source',
                              loop=True,
                              max_num_neighbors=2000)

    tree = scipy.spatial.cKDTree(points.cpu().numpy())
    neighbour_lists = tree.query_ball_point(points.cpu(), r=0.5)
    truth = {(i, j)
             for i, neighbours in enumerate(neighbour_lists)
             for j in neighbours}

    assert to_set(edge_index.cpu()) == truth
Exemplo n.º 14
0
    def forward(self, inputs):
        """Apply the forward pass of the model.

        Alternates edge/node network updates with re-embedding and
        rebuilding the radius graph in the embedding space.

        Returns:
            ``(edge_scores, spatial, edge_index)``: the edge network applied
            to the final graph, the final spatial embedding, and that graph.
        """
        def build_graph(emb):
            # Neighbours within self.r in embedding space (max 30 per node).
            return radius_graph(emb,
                                r=self.r,
                                batch=inputs.batch,
                                loop=False,
                                max_num_neighbors=30)

        def embed(feats, emb):
            return self.emb_network(
                torch.cat([inputs.x, feats, emb], axis=-1))

        x = inputs.x
        spatial = self.input_spatial_network(x)
        features = self.input_feature_network(x)
        spatial = embed(features, spatial)
        edge_index = build_graph(spatial)

        for _ in range(self.n_graph_iters):
            residual = features
            edge_weights = torch.sigmoid(self.edge_network(features, edge_index))
            features = self.node_network(features, edge_weights, edge_index)
            spatial = embed(features, spatial)
            edge_index = build_graph(spatial)
            # Residual connection around the node update.
            features = residual + features

        return self.edge_network(features, edge_index), spatial, edge_index
Exemplo n.º 15
0
def radius_graph(x,
                 r,
                 batch=None,
                 loop=False,
                 max_num_neighbors=32,
                 flow='source_to_target',
                 num_workers=1):
    r"""Computes graph edges to all points within a given distance.

    Thin wrapper that forwards every argument to
    :func:`torch_cluster.radius_graph`.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        r (float): The radius.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        loop (bool, optional): If :obj:`True`, the graph will contain
            self-loops. (default: :obj:`False`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            return for each element in :obj:`x`. (default: :obj:`32`)
        flow (string, optional): The flow direction when using in combination
            with message passing (:obj:`"source_to_target"` or
            :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
        num_workers (int): Number of workers to use for computation. Has no
            effect in case :obj:`batch` is not :obj:`None`, or the input lies
            on the GPU. (default: :obj:`1`)

    :rtype: :class:`LongTensor`

    Raises:
        ImportError: if ``torch-cluster`` is not installed.

    .. code-block:: python

        import torch
        from torch_geometric.nn import radius_graph

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch = torch.tensor([0, 0, 0, 0])
        # Adjacent corners are 2 apart and diagonals 2*sqrt(2), so r=2.5
        # links every corner to its two nearest corners.
        edge_index = radius_graph(x, r=2.5, batch=batch, loop=False)
    """
    if torch_cluster is None:
        raise ImportError('`radius_graph` requires `torch-cluster`.')

    return torch_cluster.radius_graph(x, r, batch, loop, max_num_neighbors,
                                      flow, num_workers)
Exemplo n.º 16
0
def test():
    """Smoke-test MessagePassing: output shape, equivariance, JIT-ability."""
    from torch_cluster import radius_graph
    from e3nn.util.test import assert_equivariant, assert_auto_jitable

    mp = MessagePassing(
        irreps_node_input="0e",
        irreps_node_hidden="0e + 1e",
        irreps_node_output="1e",
        irreps_node_attr="0e + 1e",
        irreps_edge_attr="1e",
        layers=3,
        fc_neurons=[2, 100],
        num_neighbors=3.0,
    )

    n_nodes = 4
    pos = torch.randn(n_nodes, 3)
    edge_index = radius_graph(pos, 3.0)
    src, dst = edge_index
    n_edges = edge_index.shape[1]
    # Relative position vectors serve as the "1e" edge attributes.
    edge_vec = pos[src] - pos[dst]

    feats = torch.randn(n_nodes, 1)
    attrs = torch.randn(n_nodes, 4)
    scalars = torch.randn(n_edges, 2)

    args = [feats, attrs, src, dst, edge_vec, scalars]
    # "1e" output irreps have dimension 3.
    assert mp(*args).shape == (n_nodes, 3)

    assert_equivariant(
        mp,
        irreps_in=[
            mp.irreps_node_input, mp.irreps_node_attr, None, None,
            mp.irreps_edge_attr, None
        ],
        args_in=list(args),
        irreps_out=[mp.irreps_node_output],
    )

    assert_auto_jitable(mp.layers[0].first)
Exemplo n.º 17
0
    def forward(self, pos, batch):
        """Build the radius graph of ``pos`` and compute edge lengths.

        Args:
            pos: node positions.
            batch: batch vector assigning each node to an example.

        Returns:
            ``(edge_index, edge_weight, edge_vec)``.
            - ``edge_weight``: the length of each edge, in the same dtype as
              ``pos`` on both branches.
            - ``edge_vec``: the displacement ``pos[src] - pos[dst]``,
              returned only when ``self.return_vecs`` is set, otherwise
              ``None``.
            Edges shorter than ``self.cutoff_lower`` are dropped, except
            self-loops when ``self.loop`` is set.
        """
        # Request one neighbour more than allowed so the check below can tell
        # whether the limit actually truncated the search.
        edge_index = radius_graph(
            pos,
            r=self.cutoff_upper,
            batch=batch,
            loop=self.loop,
            max_num_neighbors=self.max_num_neighbors + 1,
        )

        # make sure we didn't miss any neighbors due to max_num_neighbors
        assert not (torch.unique(
            edge_index[0], return_counts=True
        )[1] > self.max_num_neighbors).any(), (
            "The neighbor search missed some atoms due to max_num_neighbors being too low. "
            "Please increase this parameter to include the maximum number of atoms within the cutoff."
        )

        edge_vec = pos[edge_index[0]] - pos[edge_index[1]]

        mask: Optional[torch.Tensor] = None
        if self.loop:
            # mask out self loops when computing distances because
            # the norm of 0 produces NaN gradients
            # NOTE: might influence force predictions as self loop gradients are ignored
            mask = edge_index[0] != edge_index[1]
            # Allocate with the positions' dtype. A default-dtype buffer would
            # silently downcast float64 lengths and disagree with the
            # loop=False branch below.
            edge_weight = torch.zeros(edge_vec.size(0),
                                      device=edge_vec.device,
                                      dtype=edge_vec.dtype)
            edge_weight[mask] = torch.norm(edge_vec[mask], dim=-1)
        else:
            edge_weight = torch.norm(edge_vec, dim=-1)

        lower_mask = edge_weight >= self.cutoff_lower
        if self.loop and mask is not None:
            # keep self loops even though they might be below the lower cutoff
            lower_mask = lower_mask | ~mask
        edge_index = edge_index[:, lower_mask]
        edge_weight = edge_weight[lower_mask]

        if self.return_vecs:
            edge_vec = edge_vec[lower_mask]
            return edge_index, edge_weight, edge_vec
        # TODO: return only `edge_index` and `edge_weight` once
        # Union typing works with TorchScript (https://github.com/pytorch/pytorch/pull/53180)
        return edge_index, edge_weight, None
Exemplo n.º 18
0
    def preprocess(self, data: Union[Data,
                                     Dict[str, torch.Tensor]]) -> torch.Tensor:
        """Build the radius graph and edge vectors consumed by the network.

        ``batch`` defaults to all zeros when ``data`` has no ``'batch'``
        entry.

        Returns a 5-tuple ``(batch, x, edge_src, edge_dst, edge_vec)``.
        NOTE(review): the ``-> torch.Tensor`` annotation does not match the
        tuple that is actually returned.
        """
        pos = data['pos']
        if 'batch' in data:
            batch = data['batch']
        else:
            batch = pos.new_zeros(pos.shape[0], dtype=torch.long)

        # The cap is set to N - 1. Presumably self-loops are off here (the
        # wrapper's default), in which case the cap never drops any edge.
        graph = radius_graph(pos,
                             self.max_radius,
                             batch,
                             max_num_neighbors=len(pos) - 1)
        src, dst = graph[0], graph[1]

        return batch, data['x'], src, dst, pos[src] - pos[dst]
Exemplo n.º 19
0
def prepare_data():
    """Load GeometricShapes, sample point clouds, and visualise two graphs.

    Shape 0 is grouped with a radius graph (r=0.2) and shape 4 with a
    6-nearest-neighbour graph. The edge-index shape of each is printed
    and the points are plotted.

    Returns:
        The dataset, with a ``SamplePoints(num=256)`` transform attached.
    """
    dataset = GeometricShapes(root='data/GeometricShapes')
    print(dataset)

    # Fixed seed, presumably so the sampled point clouds are reproducible.
    torch.manual_seed(42)
    dataset.transform = SamplePoints(num=256)

    groupings = (
        (0, lambda points: radius_graph(points, r=0.2)),
        (4, lambda points: knn_graph(points, k=6)),
    )
    for index, connect in groupings:
        sample = dataset[index]
        sample.edge_index = connect(sample.pos)
        print(sample.edge_index.shape)
        visualize_points(sample.pos, edge_index=sample.edge_index)

    return dataset
Exemplo n.º 20
0
    def forward(self, data) -> torch.Tensor:
        """Predict one output per graph with a conv -> gate -> conv stack.

        Edges connect nodes within 2.5. Edge attributes are the normalised
        spherical harmonics of the displacement. Edge lengths are embedded
        with 3 smooth basis functions on [0.5, 2.5].
        """
        graph_norm = 4 ** 0.5  # sqrt of the typical number of nodes

        pos, batch = data.pos, data.batch
        src, dst = radius_graph(x=pos, r=2.5, batch=batch)
        rel = pos[src] - pos[dst]
        sh = o3.spherical_harmonics(l=self.irreps_sh,
                                    x=rel,
                                    normalize=True,
                                    normalization='component')
        radial = soft_one_hot_linspace(x=rel.norm(dim=1),
                                       start=0.5,
                                       end=2.5,
                                       number=3,
                                       basis='smooth_finite',
                                       cutoff=True) * 3**0.5

        # Initial node features: rescaled sum of incoming edge harmonics.
        h = scatter(sh, dst, dim=0).div(self.num_neighbors**0.5)
        h = self.conv(h, src, dst, sh, radial)
        h = self.gate(h)
        h = self.final(h, src, dst, sh, radial)

        return scatter(h, batch, dim=0).div(graph_norm)
Exemplo n.º 21
0
    def forward(self, x, batch: OptTensor = None):
        """Embed nodes, classify candidate edges, and pool predicted objects.

        Returns:
            ``(out, edge_scores, edge_index, cluster_map, cluster_props,
            cluster_batch)``:
            - ``out``: the second latent embedding.
            - ``edge_scores``: classifier scores for the radius-graph
              candidate edges.
            - ``edge_index``: those candidate edges.
            - ``cluster_map``: a consecutive cluster id per node.
            - ``cluster_props``: per-cluster predictions.
            - ``cluster_batch``: the batch vector of the pooled clusters.
        """
        if batch is None:
            # Treat all nodes as one example.
            batch = x.new_zeros(x.size(0), dtype=torch.int64)

        # [1] Embedding 1: intermediate latent features, refined by EdgeConvs
        # on a kNN graph that is rebuilt before every convolution.
        x_emb = self.inputnet(x)
        for conv in self.edgeconvs:
            knn_edges = knn_graph(x_emb,
                                  self.k,
                                  batch,
                                  loop=False,
                                  flow=conv.flow)
            x_emb = x_emb + conv(x_emb, knn_edges)

        # Embedding 2: final latent-space coordinates.
        out = self.output(x_emb)
        # Candidate edges: neighbours within 0.5 in embedding 2 (at most self.k).
        edge_index = radius_graph(out,
                                  r=0.5,
                                  batch=batch,
                                  max_num_neighbors=self.k,
                                  loop=False)

        # [2] Edge classification from features convolved over the candidates,
        # also fed the (detached) embedding 1 and the raw input.
        x_cat = self.inputnet_cat(x)
        for conv in self.edgecatconvs:
            x_cat = x_cat + conv(torch.cat([x_cat, x_emb.detach(), x], dim=1),
                                 edge_index)

        pair_feats = torch.cat([x_cat[edge_index[0]], x_cat[edge_index[1]]],
                               dim=1)
        edge_scores = self.edge_classifier(pair_feats).squeeze()

        # Keep edges whose highest-scoring class is not class 0 and merge
        # their endpoints. The resulting disjoint sets are the predicted
        # objects.
        objects = UnionFind(x.size()[0])
        good_edges = edge_index[:, torch.argmax(edge_scores, dim=1) > 0]
        for a, b in good_edges.cpu().numpy().T:
            objects.union(a, b)

        roots = np.array([objects.find(i) for i in range(x.shape[0])],
                         dtype=np.int64)
        cluster_map = torch.from_numpy(roots).to(x.device)
        cluster_roots, inverse = torch.unique(cluster_map, return_inverse=True)
        # Relabel the roots as consecutive ids 0..n_clusters-1.
        cluster_map = torch.arange(cluster_roots.size()[0],
                                   dtype=torch.int64,
                                   device=x.device)[inverse]

        # [3] Per-object properties: convolve over the kept edges only, then
        # max-pool the node features within each cluster.
        x_prop = self.inputnet_prop(x)
        for conv in self.propertyconvs:
            x_prop = x_prop + conv(torch.cat([x_prop, x_emb.detach(), x], dim=1),
                                   good_edges)
        props_pooled, cluster_batch = max_pool_x(cluster_map, x_prop, batch)
        cluster_props = self.property_predictor(props_pooled)

        return out, edge_scores, edge_index, cluster_map, cluster_props, cluster_batch
Exemplo n.º 22
0
def balanced_train(model, train_loader, optimizer, loss_fn, m_configs):
    """Train for one epoch on a combined edge-classification + embedding loss.

    For each batch the model returns ``(pred, spatial, e)``: edge logits
    for the edge list ``e`` and a spatial embedding. The losses are:
    - Binary cross-entropy between ``pred`` and "same pid" labels on ``e``.
    - A hinge-embedding loss on squared embedding distances. The pairs are
      the true edges plus fake candidates (pairs within
      ``m_configs["r_train"]`` whose ``pid`` differs).
    The two losses are combined by ``loss_fn``.

    Args:
        model: network returning ``(pred, spatial, e)``.
        train_loader: iterable of batches with ``batch``, ``pid`` and
            ``true_edges`` attributes.
        optimizer: optimizer for the model parameters.
        loss_fn: called with ``[bce_loss, hinge_loss]``.
        m_configs: dict providing ``"r_train"``, ``"weight"``, ``"margin"``
            and ``"reduction"``.

    Returns:
        ``(edge_acc, cluster_pur, total_loss)``:
        - ``edge_acc``: edge accuracy at probability threshold 0.5, or 0.0
          when no predictions were made.
        - ``cluster_pur``: the fraction of hinge pairs sharing a pid.
        - ``total_loss``: the summed batch losses.

    NOTE(review): relies on module-level ``device`` and ``sig`` defined
    elsewhere in the file.
    """
    edge_correct = 0
    total = 0
    cluster_total_true_positive = 0
    # Starts at 1 (as before), presumably to avoid division by zero.
    cluster_total_positive = 1
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        data = batch.to(device)
        pred, spatial, e = model(data)

        # Get fake edge list: spatial neighbours belonging to different pids.
        candidates = radius_graph(
            spatial,
            r=m_configs["r_train"],
            batch=batch.batch,
            loop=False,
            max_num_neighbors=200,
        )
        fake_list = candidates[:, batch.pid[candidates[0]] !=
                               batch.pid[candidates[1]]]

        # Concatenate all candidates: fake edges followed by the true edges.
        e_spatial = torch.cat(
            [fake_list, batch.true_edges.T.to(device)], axis=-1)

        reference = spatial.index_select(0, e_spatial[1])
        neighbors = spatial.index_select(0, e_spatial[0])

        d = torch.sum((reference - neighbors)**2, dim=-1)

        y_edge = batch.pid[e[0]] == batch.pid[e[1]]
        y_cluster = batch.pid[e_spatial[0]] == batch.pid[e_spatial[1]]

        # hinge_embedding_loss targets: +1 for same-pid pairs, -1 otherwise.
        hinge = y_cluster.float()
        hinge[~y_cluster] = -1

        # pos_weight is built on the logits' device to avoid a CPU/GPU mix.
        loss_1 = F.binary_cross_entropy_with_logits(
            pred.float(),
            y_edge.float(),
            pos_weight=torch.tensor(m_configs["weight"], device=pred.device))
        loss_2 = torch.nn.functional.hinge_embedding_loss(
            d,
            hinge,
            margin=m_configs["margin"],
            reduction=m_configs["reduction"])
        loss = loss_fn([loss_1.to(device), loss_2.to(device)])
        total_loss += loss.item()
        loss.backward()
        optimizer.step()

        # Cluster performance: how many hinge pairs share a pid.
        cluster_total_true_positive += y_cluster.float().sum().item()
        cluster_total_positive += max(len(e_spatial[0]), 1)

        edge_correct += ((sig(pred) > 0.5) == y_edge).sum().item()
        total += len(pred)

    edge_acc = edge_correct / total if total else 0.0
    cluster_pur = cluster_total_true_positive / cluster_total_positive

    return edge_acc, cluster_pur, total_loss
Exemplo n.º 23
0
def evaluate(model, test_loader, loss_fn, m_configs):
    """Evaluate edge classification and embedding-space clustering on a test set.

    For every batch the model returns ``(pred, spatial, e)``: edge logits
    ``pred`` for the candidate edges ``e`` and node embeddings ``spatial``.
    A radius graph built on ``spatial`` provides the "cluster" candidate edges.

    Args:
        model: network called as ``model(batch)``.
        test_loader: iterable of batches exposing ``pid``, ``layers``,
            ``batch`` and ``true_edges``.
        loss_fn: callable combining ``[loss_1, loss_2]`` into a scalar tensor.
        m_configs: dict with keys ``"r_val"`` (embedding radius), ``"weight"``
            (BCE ``pos_weight``), ``"margin"`` and ``"reduction"`` (hinge loss).

    Returns:
        ``(edge_acc, edge_pur, edge_eff, cluster_pur, cluster_eff, total_loss)``.
        Any ratio whose denominator is zero is reported as 0.
    """
    edge_correct = 0
    edge_true_positive = 0
    edge_total_true = 0
    edge_total_positive = 0
    total = 0
    cluster_total_true_positive = 0
    cluster_total_true = 0
    cluster_total_positive = 0
    total_loss = 0

    # Pure evaluation: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            pred, spatial, e = model(batch)

            # Candidate cluster edges: pairs closer than r_val in embedding space.
            # (A knn_graph with k=m_configs["k"] is a possible alternative.)
            e_spatial = radius_graph(
                spatial,
                r=m_configs["r_val"],
                batch=batch.batch,
                loop=False,
                max_num_neighbors=200,
            )

            reference = spatial.index_select(0, e_spatial[1])
            neighbors = spatial.index_select(0, e_spatial[0])
            d = torch.sum((reference - neighbors)**2, dim=-1)

            # An edge is "true" when both ends share a pid and the target sits
            # exactly one layer after the source.
            y_edge = (batch.pid[e[0]] == batch.pid[e[1]]) & (
                batch.layers[e[1]] - batch.layers[e[0]] == 1)
            y_cluster = (batch.pid[e_spatial[0]] == batch.pid[e_spatial[1]]) & (
                batch.layers[e_spatial[1]] - batch.layers[e_spatial[0]] == 1)

            # hinge_embedding_loss expects targets in {1, -1}.
            hinge = y_cluster.float()
            hinge[hinge == 0] = -1

            loss_1 = F.binary_cross_entropy_with_logits(
                pred.float(),
                y_edge.float(),
                pos_weight=torch.tensor(m_configs["weight"], device=pred.device))
            loss_2 = torch.nn.functional.hinge_embedding_loss(
                d,
                hinge,
                margin=m_configs["margin"],
                reduction=m_configs["reduction"])
            total_loss += loss_fn([loss_1, loss_2]).item()

            # Cluster performance: every radius-graph edge is a "positive".
            cluster_total_true_positive += y_cluster.sum().item()
            cluster_total_positive += e_spatial.shape[1]
            cluster_total_true += len(batch.true_edges)

            # Edge performance
            edge_true = y_edge
            edge_positive = sig(pred) > 0.5
            edge_correct += (edge_positive == edge_true).sum().item()
            edge_true_positive += (edge_true & edge_positive).sum().item()
            edge_total_true += edge_true.sum().item()
            edge_total_positive += edge_positive.sum().item()

            total += len(pred)

    edge_acc = edge_correct / max(total, 1)
    edge_eff = edge_true_positive / max(edge_total_true, 1)
    edge_pur = edge_true_positive / max(edge_total_positive, 1)

    cluster_eff = cluster_total_true_positive / max(cluster_total_true, 1)
    cluster_pur = cluster_total_true_positive / max(cluster_total_positive, 1)

    return edge_acc, edge_pur, edge_eff, cluster_pur, cluster_eff, total_loss
Exemplo n.º 24
0
    def forward(self, node_atom, node_pos, batch) -> torch.Tensor:
        """Predict one value per graph from atom types and node positions.

        Args:
            node_atom: integer tensor used to index a 10-entry atom-type lookup
                table (presumably atomic numbers — TODO confirm), and
                ``self.atomref`` when that is set.
            node_pos: node positions; edges join nodes closer than
                ``self.max_radius``.
            batch: graph index of every node, used for the radius graph and
                for the final pooling.

        Returns:
            ``scatter(per_node, batch, dim=0)``: one row per graph index,
            multiplied by ``self.scale`` when that is set.
        """
        r_max = self.max_radius

        # Neighbourhood graph, capped at 1000 neighbours per node.
        src, dst = radius_graph(
            node_pos, r=r_max, batch=batch, max_num_neighbors=1000)

        # Angular edge features: spherical harmonics of degrees 0..lmax of the
        # (normalised) relative position vectors.
        rel = node_pos[src] - node_pos[dst]
        sh = o3.spherical_harmonics(
            l=range(self.lmax + 1), x=rel, normalize=True,
            normalization='component')

        # Radial edge features: cosine basis with cutoff=True, so it reaches
        # zero at r_max and no separate smooth cutoff is applied.
        n_basis = self.number_of_basis
        radial = soft_one_hot_linspace(
            rel.norm(dim=1), 0.0, r_max, n_basis,
            basis='cosine', cutoff=True,
        ).mul(n_basis**0.5)

        # Every node starts from the same constant scalar feature.
        features = node_pos.new_ones(node_pos.shape[0], 1)

        # Lookup: index 1 -> class 0, indices 6..9 -> classes 1..4, others -> -1.
        # NOTE(review): one_hot raises on negative class values, so any other
        # atom type aborts the forward pass.
        type_table = node_atom.new_tensor([-1, 0, -1, -1, -1, -1, 1, 2, 3, 4])
        attrs = torch.nn.functional.one_hot(type_table[node_atom], 5).mul(5**0.5)

        out = self.mp(
            node_features=features,
            node_attr=attrs,
            edge_src=src,
            edge_dst=dst,
            edge_attr=sh,
            edge_scalars=radial,
        )

        # Channel 0 plus half the square of channel 1, as a column vector,
        # normalised by sqrt(self.num_nodes).
        per_node = (out[:, 0] + out[:, 1].pow(2).mul(0.5)).view(-1, 1)
        per_node = per_node.div(self.num_nodes**0.5)

        # Optional de-standardisation and per-atom reference offsets.
        if self.mean is not None and self.std is not None:
            per_node = per_node * self.std + self.mean
        if self.atomref is not None:
            per_node = per_node + self.atomref[node_atom]

        # Original author's note: for target=7, MAE of 75eV.
        per_graph = scatter(per_node, batch, dim=0)
        if self.scale is not None:
            per_graph = self.scale * per_graph
        return per_graph
Exemplo n.º 25
0
def _count_in_bins(values, edges):
    """Count the entries of ``values`` in each half-open bin of ``edges``.

    Bin ``i`` is ``[edges[i], edges[i + 1])``; values outside
    ``[edges[0], edges[-1])`` are ignored (``np.histogram`` instead closes its
    last bin on the right).
    """
    values = np.asarray(values)
    return np.array(
        [np.count_nonzero((values >= lo) & (values < hi))
         for lo, hi in zip(edges[:-1], edges[1:])],
        dtype=np.int64)


def _sum_in_bins(keys, weights, edges):
    """Sum ``weights`` over the entries whose ``keys`` fall in each half-open bin."""
    keys = np.asarray(keys)
    weights = np.asarray(weights)
    return np.array(
        [weights[(keys >= lo) & (keys < hi)].sum()
         for lo, hi in zip(edges[:-1], edges[1:])],
        dtype=np.float64)


def _bin_mean(sums, counts):
    """Return ``sums / counts`` per bin; empty bins and NaN results become 0."""
    counts = np.asarray(counts)
    mean = np.divide(sums, counts, out=np.zeros(len(counts)), where=counts > 0)
    return np.nan_to_num(mean)


def plot_weight(model, loss_fn, dataloader, metrics, deltaR, model_dir,
                saveplot, device='cuda'):
    """Profile the per-particle graph weights and MET spectra, then save them.

    For each batch the model runs on an (eta, phi) radius graph. Its
    per-particle outputs ("graph weights") are processed as follows:

    * averaged in bins of pT and |eta| for several particle classes, and of
      PUPPI weight for classes 1, 22 and 130;
    * histogrammed for charged hadrons (211), split by PUPPI weight 0 and 1.

    The qT spectra of the true, graph, PF, PUPPI and DeepMET MET are also
    counted. Everything is written with ``utils.save(weights, 'weight.plt')``.
    The model's train/eval mode is left untouched.

    Args:
        model: (torch.nn.Module) the neural network, called as
            ``model(x_cont, x_cat, edge_index, batch)``.
        loss_fn: unused; kept for interface compatibility.
        dataloader: (DataLoader) yields batches with ``x``, ``y`` and ``batch``.
        metrics: unused; kept for interface compatibility.
        deltaR: radius of the (eta, phi) neighbourhood graph.
        model_dir: unused.
        saveplot: unused.
        device: where batches are moved before the forward pass
            (``'cuda'`` by default).

    Returns:
        The ``(num_particles, 6)`` summary tensor of the LAST batch, with
        columns ``|pdgId|, |pT|, |eta|, |puppiWeight|, graph weight,
        event true qT``. Returns an empty ``(0, 6)`` tensor if ``dataloader``
        is empty.
    """
    weight_arr = [1, 11, 13, 22, 130, 211]
    puppi_keys = (1, 22, 130)
    label = {
        1: 'HF Candidate',
        11: 'Electron',
        13: 'Muon',
        22: 'Gamma',
        130: 'Neutral Hadron',
        211: 'Charged Hadron',
    }
    binedges_list = {
        'Pt':
        np.arange(-0.05, 25.05, 0.1),
        'eta':
        np.arange(-0.1, 5.1, 0.2),
        'Puppi': [
            -0.05, 0.05, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65,
            0.7, 0.75, 0.8, 0.85, 0.9, 1.1
        ],
        'graph_weight':
        np.arange(-0.05, 1.15, 0.01),
        'qT1D':
        np.arange(0, 420, 20),
    }
    qT_names = ('TrueMET', 'GraphMET', 'PFMET', 'PUPPIMET', 'DeepMETResponse',
                'DeepMETResolution')

    def empty_hist(edges_key, dtype=np.float64):
        return np.zeros(len(binedges_list[edges_key]) - 1, dtype=dtype)

    # Per-bin weight sums (float) and entry counts (int); sums become means at the end.
    weight_pt_hist = {label[k]: empty_hist('Pt') for k in weight_arr}
    weight_pt_histN = {label[k]: empty_hist('Pt', np.int64) for k in weight_arr}
    weight_eta_hist = {label[k]: empty_hist('eta') for k in weight_arr}
    weight_eta_histN = {label[k]: empty_hist('eta', np.int64) for k in weight_arr}
    weight_puppi_hist = {label[k]: empty_hist('Puppi') for k in puppi_keys}
    weight_puppi_histN = {
        label[k]: empty_hist('Puppi', np.int64) for k in puppi_keys
    }
    weight_CH_hist = {
        key: empty_hist('graph_weight', np.int64) for key in ('puppi0', 'puppi1')
    }
    weight_qT_hist = {name: empty_hist('qT1D', np.int64) for name in qT_names}

    result = torch.empty(0, 6).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for data in dataloader:
            data = data.to(device)
            x_cont = data.x[:, :8]
            x_cat = data.x[:, 8:].long()
            print("x_cont shape:", (x_cont.shape))
            print("x_cat shape:", (x_cat.shape))
            phi = torch.atan2(data.x[:, 1], data.x[:, 0])
            etaphi = torch.cat([data.x[:, 3][:, None], phi[:, None]], dim=1)
            # NB: there is a problem right now for comparing hits at the +/- pi boundary
            edge_index = radius_graph(etaphi,
                                      r=deltaR,
                                      batch=data.batch,
                                      loop=True,
                                      max_num_neighbors=255)
            result = model(x_cont, x_cat, edge_index, data.batch)

            # qT 1D distributions. True/PF/PUPPI/DeepMET qT come from data.y
            # column pairs; graph MET comes from the weighted particle px/py sums.
            graphMETx = scatter_add(result * data.x[:, 0], data.batch)
            graphMETy = scatter_add(result * data.x[:, 1], data.batch)
            qT_by_name = {
                'TrueMET': torch.sqrt(data.y[:, 0]**2 + data.y[:, 1]**2),
                'GraphMET': torch.sqrt(graphMETx**2 + graphMETy**2),
                'PFMET': torch.sqrt(data.y[:, 2]**2 + data.y[:, 3]**2),
                'PUPPIMET': torch.sqrt(data.y[:, 4]**2 + data.y[:, 5]**2),
                'DeepMETResponse': torch.sqrt(data.y[:, 6]**2 + data.y[:, 7]**2),
                'DeepMETResolution': torch.sqrt(data.y[:, 8]**2 +
                                                data.y[:, 9]**2),
            }
            for name, qT in qT_by_name.items():
                weight_qT_hist[name] += _count_in_bins(
                    qT.detach().cpu().numpy(), binedges_list['qT1D'])

            # x layout: pX,pY,pT,eta,d0,dz,mass,puppiWeight,pdgId,charge,fromPV
            # Per-particle table: |pdgId|, |pT|, |eta|, |puppiWeight|,
            # graph weight, true qT of the particle's event.
            ZQt = torch.gather(torch.sqrt(data.y[:, 0]**2 + data.y[:, 1]**2), 0,
                               data.batch)
            result = torch.stack(
                (torch.abs(x_cat[:, 0]), torch.abs(x_cont[:, 2]),
                 torch.abs(x_cont[:, 3]), torch.abs(x_cont[:, 7]), result, ZQt),
                dim=1)
            table = result.detach().cpu().numpy()
            pdg = table[:, 0]

            # Mean graph weight vs pT and |eta| (and PUPPI weight for some classes).
            for key in weight_arr:
                # The HF-candidate class (1) also collects pdgId 2.
                if key == 1:
                    W_arr = table[(pdg == key) | (pdg == 2)]
                else:
                    W_arr = table[pdg == key]
                name = label[key]
                weight_pt_hist[name] += _sum_in_bins(W_arr[:, 1], W_arr[:, 4],
                                                     binedges_list['Pt'])
                weight_pt_histN[name] += _count_in_bins(W_arr[:, 1],
                                                        binedges_list['Pt'])
                weight_eta_hist[name] += _sum_in_bins(W_arr[:, 2], W_arr[:, 4],
                                                      binedges_list['eta'])
                weight_eta_histN[name] += _count_in_bins(W_arr[:, 2],
                                                         binedges_list['eta'])
                if key in puppi_keys:
                    weight_puppi_hist[name] += _sum_in_bins(
                        W_arr[:, 3], W_arr[:, 4], binedges_list['Puppi'])
                    weight_puppi_histN[name] += _count_in_bins(
                        W_arr[:, 3], binedges_list['Puppi'])

            # Graph-weight distribution of charged hadrons, split by PUPPI weight.
            charged = table[pdg == 211]
            for puppi_value, hist_key in ((0, 'puppi0'), (1, 'puppi1')):
                weight_CH_hist[hist_key] += _count_in_bins(
                    charged[charged[:, 3] == puppi_value, 4],
                    binedges_list['graph_weight'])

    # Turn per-bin sums into means (0 for empty bins) once, after all batches.
    for key in weight_arr:
        name = label[key]
        weight_pt_hist[name] = _bin_mean(weight_pt_hist[name],
                                         weight_pt_histN[name])
        weight_eta_hist[name] = _bin_mean(weight_eta_hist[name],
                                          weight_eta_histN[name])
    for key in puppi_keys:
        name = label[key]
        weight_puppi_hist[name] = _bin_mean(weight_puppi_hist[name],
                                            weight_puppi_histN[name])

    weights = {
        'bin_edges': binedges_list,
        'weight_pt_hist': weight_pt_hist,
        'weight_eta_hist': weight_eta_hist,
        'weight_puppi_hist': weight_puppi_hist,
        # Counts are saved as plain lists of ints.
        'weight_CH_hist': {k: v.tolist() for k, v in weight_CH_hist.items()},
        'weight_qT_hist': {k: v.tolist() for k, v in weight_qT_hist.items()},
    }
    utils.save(weights, 'weight.plt')
    return result
Exemplo n.º 26
0
def _met_resolution_hists(resolutions_arr, qT_arr, max_x, x_n):
    """Build per-qT-bin MET resolution and response histograms.

    Args:
        resolutions_arr: dict mapping a MET flavour to ``[u_perp, u_par, R]``
            per-event arrays (all the same length as ``qT_arr``).
        qT_arr: per-event qT values used for binning.
        max_x: upper edge of the qT range; the lower edge is 0.
        x_n: number of equal-width qT bins in ``[0, max_x)``.

    Returns:
        dict mapping each key of ``resolutions_arr`` to a dict of
        ``np.histogram`` results: ``'u_perp_resolution'``,
        ``'u_perp_scaled_resolution'``, ``'u_par_resolution'``,
        ``'u_par_scaled_resolution'`` and ``'R'``. A resolution is half the
        16-84 % inter-quantile range; the scaled versions divide by the bin's
        mean ``R``. Bins without events hold NaN.
    """
    # x_n + 1 edges so that every one of the x_n histogram bins gets a value.
    bin_edges = np.linspace(0, max_x, x_n + 1)
    inds = np.digitize(np.asarray(qT_arr), bin_edges)
    qT_hist = (bin_edges[1:] + bin_edges[:-1]) / 2.

    def half_width(values):
        # np.quantile raises on an empty array; report an empty bin as NaN.
        if len(values) == 0:
            return np.nan
        return (np.quantile(values, 0.84) - np.quantile(values, 0.16)) / 2.

    def histogram(weights):
        return np.histogram(qT_hist, bins=x_n, range=(0, max_x),
                            weights=weights)

    resolution_hists = {}
    for key, (u_perp_arr, u_par_arr, R_arr) in resolutions_arr.items():
        u_perp_arr = np.asarray(u_perp_arr, dtype=float)
        u_par_arr = np.asarray(u_par_arr, dtype=float)
        R_arr = np.asarray(R_arr, dtype=float)

        u_perp_hist, u_perp_scaled_hist = [], []
        u_par_hist, u_par_scaled_hist = [], []
        R_hist = []
        for i in range(1, len(bin_edges)):
            in_bin = inds == i
            R_i = R_arr[in_bin]
            R_mean = np.mean(R_i) if len(R_i) else np.nan
            u_perp_i = u_perp_arr[in_bin]
            u_par_i = u_par_arr[in_bin]
            R_hist.append(R_mean)
            u_perp_hist.append(half_width(u_perp_i))
            u_perp_scaled_hist.append(half_width(u_perp_i / R_mean))
            u_par_hist.append(half_width(u_par_i))
            u_par_scaled_hist.append(half_width(u_par_i / R_mean))

        resolution_hists[key] = {
            'u_perp_resolution': histogram(u_perp_hist),
            'u_perp_scaled_resolution': histogram(u_perp_scaled_hist),
            'u_par_resolution': histogram(u_par_hist),
            'u_par_scaled_resolution': histogram(u_par_scaled_hist),
            'R': histogram(R_hist),
        }
    return resolution_hists


def evaluate(model, device, loss_fn, dataloader, metrics, deltaR, deltaR_dz,
             model_dir):
    """Evaluate the model over ``dataloader`` and build MET resolution histograms.

    Args:
        model: (torch.nn.Module) the neural network; switched to eval mode.
        device: device each batch is moved to before the forward pass.
        loss_fn: callable ``loss_fn(result, x, y, batch)`` returning a scalar tensor.
        dataloader: (DataLoader) yields batches with ``x``, ``y`` and ``batch``.
        metrics: (dict) must provide ``'resolution'``, called as
            ``metrics['resolution'](result, x, y, batch)`` and returning
            ``(resolutions, qT)``.
        deltaR: radius of the (eta, phi) neighbourhood graph.
        deltaR_dz: only used by the disabled dz-graph code below.
        model_dir: unused.

    Returns:
        ``(metrics_mean, resolution_hists)``. ``metrics_mean['loss']`` is the
        mean batch loss. ``resolution_hists`` has 40 bins of width 10 over
        qT in [0, 400); empty bins hold NaN.
    """
    model.eval()

    loss_avg_arr = []
    qT_arr = []
    resolutions_arr = {
        'MET': [[], [], []],
        'pfMET': [[], [], []],
        'puppiMET': [[], [], []],
    }

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for data in dataloader:
            # Targets with more than 6 columns also carry the DeepMET MET.
            has_deepmet = data.y.size()[1] > 6
            if has_deepmet and 'deepMETResponse' not in resolutions_arr:
                resolutions_arr.update({
                    'deepMETResponse': [[], [], []],
                    'deepMETResolution': [[], [], []]
                })

            data = data.to(device)
            x_cont = data.x[:, :8]  # include puppi (use [:, :7] to drop it)
            x_cat = data.x[:, 8:].long()
            phi = torch.atan2(data.x[:, 1], data.x[:, 0])
            etaphi = torch.cat([data.x[:, 3][:, None], phi[:, None]], dim=1)
            # NB: there is a problem right now for comparing hits at the +/- pi boundary
            edge_index = radius_graph(etaphi,
                                      r=deltaR,
                                      batch=data.batch,
                                      loop=True,
                                      max_num_neighbors=255)
            result = model(x_cont, x_cat, edge_index, data.batch)

            # Disabled alternative adding dz connections:
            #tinf = (torch.ones(len(data.x[:,5]))*float("Inf")).to('cuda')
            #edge_index_dz = radius_graph(torch.where(data.x[:,7]!=0, data.x[:,5], tinf), r=deltaR_dz, batch=data.batch, loop=True, max_num_neighbors=127)
            #cat_edges = torch.cat([edge_index,edge_index_dz],dim=1)
            #result = model(x_cont, x_cat, cat_edges, data.batch)

            loss = loss_fn(result, data.x, data.y, data.batch)

            # compute all metrics on this batch
            resolutions, qT = metrics['resolution'](result, data.x, data.y,
                                                    data.batch)
            for key, parts in resolutions_arr.items():
                for i, part in enumerate(parts):
                    parts[i] = np.concatenate((part, resolutions[key][i]))
            qT_arr = np.concatenate((qT_arr, qT))
            loss_avg_arr.append(loss.item())

    max_x = 400  # max qT value
    x_n = 40  # number of bins
    resolution_hists = _met_resolution_hists(resolutions_arr, qT_arr, max_x,
                                             x_n)

    metrics_mean = {
        'loss': np.mean(loss_avg_arr),
    }
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v)
                                for k, v in metrics_mean.items())
    print("- Eval metrics : " + metrics_string)
    return metrics_mean, resolution_hists
Exemplo n.º 27
0
def _qt_binned_resolution_hists(resolutions_arr, qT_arr, max_x, x_n):
    """Build per-qT-bin MET resolution and response histograms.

    Args:
        resolutions_arr: dict mapping a MET flavour to ``[u_perp, u_par, R]``
            per-event arrays (all the same length as ``qT_arr``).
        qT_arr: per-event qT values used for binning.
        max_x: upper edge of the qT range; the lower edge is 0.
        x_n: number of equal-width qT bins in ``[0, max_x)``.

    Returns:
        dict mapping each key of ``resolutions_arr`` to a dict of
        ``np.histogram`` results: ``'u_perp_resolution'``,
        ``'u_perp_scaled_resolution'``, ``'u_par_resolution'``,
        ``'u_par_scaled_resolution'`` and ``'R'``. A resolution is half the
        16-84 % inter-quantile range; the scaled versions divide by the bin's
        mean ``R``. Bins without events hold NaN.
    """
    # x_n + 1 edges so that every one of the x_n histogram bins gets a value.
    bin_edges = np.linspace(0, max_x, x_n + 1)
    inds = np.digitize(np.asarray(qT_arr), bin_edges)
    qT_hist = (bin_edges[1:] + bin_edges[:-1]) / 2.

    def half_width(values):
        # np.quantile raises on an empty array; report an empty bin as NaN.
        if len(values) == 0:
            return np.nan
        return (np.quantile(values, 0.84) - np.quantile(values, 0.16)) / 2.

    def histogram(weights):
        return np.histogram(qT_hist, bins=x_n, range=(0, max_x),
                            weights=weights)

    resolution_hists = {}
    for key, (u_perp_arr, u_par_arr, R_arr) in resolutions_arr.items():
        u_perp_arr = np.asarray(u_perp_arr, dtype=float)
        u_par_arr = np.asarray(u_par_arr, dtype=float)
        R_arr = np.asarray(R_arr, dtype=float)

        u_perp_hist, u_perp_scaled_hist = [], []
        u_par_hist, u_par_scaled_hist = [], []
        R_hist = []
        for i in range(1, len(bin_edges)):
            in_bin = inds == i
            R_i = R_arr[in_bin]
            R_mean = np.mean(R_i) if len(R_i) else np.nan
            u_perp_i = u_perp_arr[in_bin]
            u_par_i = u_par_arr[in_bin]
            R_hist.append(R_mean)
            u_perp_hist.append(half_width(u_perp_i))
            u_perp_scaled_hist.append(half_width(u_perp_i / R_mean))
            u_par_hist.append(half_width(u_par_i))
            u_par_scaled_hist.append(half_width(u_par_i / R_mean))

        resolution_hists[key] = {
            'u_perp_resolution': histogram(u_perp_hist),
            'u_perp_scaled_resolution': histogram(u_perp_scaled_hist),
            'u_par_resolution': histogram(u_par_hist),
            'u_par_scaled_resolution': histogram(u_par_scaled_hist),
            'R': histogram(R_hist),
        }
    return resolution_hists


def evaluate(model, loss_fn, dataloader, metrics, deltaR, device='cuda'):
    """Evaluate the model over ``dataloader`` and build MET resolution histograms.

    Args:
        model: (torch.nn.Module) the neural network; switched to eval mode.
        loss_fn: callable ``loss_fn(result, x, y, batch)`` returning a scalar tensor.
        dataloader: (DataLoader) yields batches with ``x``, ``y`` and ``batch``.
        metrics: (dict) must provide ``'resolution'``, called as
            ``metrics['resolution'](result, x, y, batch)`` and returning
            ``(resolutions, qT)``.
        deltaR: radius of the (eta, phi) neighbourhood graph.
        device: device each batch is moved to (``'cuda'`` by default).

    Returns:
        ``(metrics_mean, resolution_hists)``. ``metrics_mean['loss']`` is the
        mean batch loss. ``resolution_hists`` has 20 bins of width 25 over
        qT in [0, 500); empty bins hold NaN.
    """
    model.eval()

    loss_avg_arr = []
    qT_arr = []
    resolutions_arr = {
        'MET': [[], [], []],
        'pfMET': [[], [], []],
        'puppiMET': [[], [], []]
    }

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for data in dataloader:
            data = data.to(device)
            x_cont = data.x[:, :8]
            x_cat = data.x[:, 8:].long()
            phi = torch.atan2(data.x[:, 1], data.x[:, 0])
            etaphi = torch.cat([data.x[:, 3][:, None], phi[:, None]], dim=1)
            # NB: there is a problem right now for comparing hits at the +/- pi boundary
            edge_index = radius_graph(etaphi,
                                      r=deltaR,
                                      batch=data.batch,
                                      loop=True,
                                      max_num_neighbors=255)
            result = model(x_cont, x_cat, edge_index, data.batch)
            loss = loss_fn(result, data.x, data.y, data.batch)

            # compute all metrics on this batch
            resolutions, qT = metrics['resolution'](result, data.x, data.y,
                                                    data.batch)
            for key, parts in resolutions_arr.items():
                for i, part in enumerate(parts):
                    parts[i] = np.concatenate((part, resolutions[key][i]))
            qT_arr = np.concatenate((qT_arr, qT))
            loss_avg_arr.append(loss.item())

    resolution_hists = _qt_binned_resolution_hists(resolutions_arr, qT_arr,
                                                   max_x=500, x_n=20)

    metrics_mean = {
        'loss': np.mean(loss_avg_arr),
    }
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v)
                                for k, v in metrics_mean.items())
    print("- Eval metrics : " + metrics_string)
    return metrics_mean, resolution_hists