Example #1
    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        row, col = edge_index
        out = torch.cat([x[col], edge_attr], dim=1)
        out = self.node_mlp_1(out)

        if self.aggregation == "mean":
            out = scatter_mean(out, row, dim=0, dim_size=x.size(0))
        elif self.aggregation == "min":
            out, _ = scatter_min(out, row, dim=0, dim_size=x.size(0))
        elif self.aggregation == "max":
            out, _ = scatter_max(out, row, dim=0, dim_size=x.size(0))
        elif self.aggregation == "minmax":
            out = torch.cat([
                scatter_min(out, row, dim=0, dim_size=x.size(0))[0],
                scatter_max(out, row, dim=0, dim_size=x.size(0))[0]
            ], dim=1)
        else:
            raise ValueError("Unknown aggregation type: {}".format(self.aggregation)) 

        out = torch.cat([x, out, u[batch]], dim=1)
        return self.node_mlp_2(out)
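The scatter call is the heart of this node model: per-edge messages are averaged onto their source node, and nodes with no incident edges receive zeros. A minimal sketch with illustrative values (not part of the original module):

import torch
from torch_scatter import scatter_mean

messages = torch.tensor([[1.0], [3.0], [5.0]])  # one row per edge, [E, F]
row = torch.tensor([0, 0, 1])                   # source node of each edge
# Node 0 receives edges 0 and 1, node 1 receives edge 2; node 2 gets zeros.
pooled = scatter_mean(messages, row, dim=0, dim_size=3)
print(pooled)  # tensor([[2.], [5.], [0.]])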
Example #2
    def assign_edge_labels(self):
        """
        Assigns self.graph_obj edge labels (tensor with shape (num_edges,)), with labels defined according to the
        network flow MOT formulation
        """

        ids = torch.as_tensor(self.graph_df.id.values,
                              device=self.graph_obj.edge_index.device)
        per_edge_ids = torch.stack([
            ids[self.graph_obj.edge_index[0]],
            ids[self.graph_obj.edge_index[1]]
        ])
        same_id = (per_edge_ids[0]
                   == per_edge_ids[1]) & (per_edge_ids[0] != -1)
        same_ids_ixs = torch.where(same_id)
        same_id_edges = self.graph_obj.edge_index.T[same_id].T

        time_dists = torch.abs(same_id_edges[0] - same_id_edges[1])

        # For every node, we get the index of the node in the future (resp. past) with the same id that is closest in time
        future_mask = same_id_edges[0] < same_id_edges[1]
        active_fut_edges = scatter_min(time_dists[future_mask],
                                       same_id_edges[0][future_mask],
                                       dim=0,
                                       dim_size=self.graph_obj.num_nodes)[1]
        original_node_ixs = torch.cat(
            (same_id_edges[1][future_mask],
             torch.as_tensor([-1], device=same_id.device)
             ))  # -1 at the end for nodes that were not present
        active_fut_edges = original_node_ixs[
            active_fut_edges]  # Recover the node index of the closest future node with the same id
        fut_edge_is_active = active_fut_edges[
            same_id_edges[0]] == same_id_edges[1]

        # Analogous for past edges
        past_mask = same_id_edges[0] > same_id_edges[1]
        active_past_edges = scatter_min(time_dists[past_mask],
                                        same_id_edges[0][past_mask],
                                        dim=0,
                                        dim_size=self.graph_obj.num_nodes)[1]
        original_node_ixs = torch.cat(
            (same_id_edges[1][past_mask],
             torch.as_tensor([-1], device=same_id.device)
             ))  # -1 at the end for nodes that were not present
        active_past_edges = original_node_ixs[active_past_edges]
        past_edge_is_active = active_past_edges[
            same_id_edges[0]] == same_id_edges[1]

        # Recover the ixs of active edges in the original edge_index tensor
        active_edge_ixs = same_ids_ixs[0][past_edge_is_active
                                          | fut_edge_is_active]
        self.graph_obj.edge_labels = torch.zeros_like(same_id,
                                                      dtype=torch.float)
        self.graph_obj.edge_labels[active_edge_ixs] = 1
        self.graph_obj.tracking_id = ids
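The appended -1 sentinel works because of how scatter_min reports argmins: output slots that receive no input hold src.size(0) in the arg tensor, i.e. one past the last valid index (torch_scatter 2.x behavior; compare the fill-value test in Example #9). A small illustration:

import torch
from torch_scatter import scatter_min

dists = torch.tensor([4.0, 1.0, 2.0])
groups = torch.tensor([0, 0, 2])                 # group 1 is empty
_, argmin = scatter_min(dists, groups, dim=0, dim_size=3)
print(argmin)                                    # tensor([1, 3, 2]); 3 == dists.numel()
targets = torch.cat((torch.tensor([10, 11, 12]), torch.tensor([-1])))
print(targets[argmin])                           # tensor([11, -1, 12]); -1 marks the empty group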
Example #3
    def scatter_distribution(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                            out: Optional[torch.Tensor] = None,
                            dim_size: Optional[int] = None,
                            unbiased: bool = True) -> torch.Tensor:

        if out is not None:
            dim_size = out.size(dim)

        if dim < 0:
            dim = src.dim() + dim

        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1

        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        summ = tmp.clone()
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)

        var = (src - mean.gather(dim, index))
        var = var * var
        var = scatter_sum(var, index, dim, out, dim_size)

        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)
        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]

        return torch.cat([summ, mean, var, maximum, minimum], dim=1)
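A hypothetical call to scatter_distribution above (assuming it is in scope as a plain function): five statistics are concatenated along dim=1, so the output width is 5 * F. Note the sketch leaves out=None; reusing a single out buffer would feed the accumulated var sums into the later scatter_max/scatter_min calls.

import torch

src = torch.randn(6, 4)                     # 6 items, 4 features each
index = torch.tensor([0, 0, 1, 1, 1, 2])    # 3 groups
stats = scatter_distribution(src, index, dim=0)
print(stats.shape)                          # torch.Size([3, 20])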
Example #4
    def forward(self, data):
        data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
        weight = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        data.edge_attr = None
        data = max_pool(cluster, data, transform=transform)

        data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))
        weight = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        x, batch = max_pool_x(cluster, data.x, data.batch)

        #x = global_mean_pool(x, batch)
        x_min = torch_scatter.scatter_min(x, batch, dim=0)[0]
        gather_idxs = batch.expand(x.shape[1], -1).t()
        gather_mins = torch.gather(x_min, 0, gather_idxs)
        s = F.relu(-gather_mins)
        x = x + s
        x = self.aggregator(x, batch)
        s_out = self.aggregator(s, batch)
        x = x - s_out

        x = F.elu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        return F.log_softmax(self.fc2(x), dim=1)
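The block between max_pool_x and the aggregator implements a min-shift: each graph's features are shifted by the negative part of their per-graph minimum so the aggregator only sees non-negative values, and the aggregated shift is subtracted afterwards. The shift in isolation (illustrative values; x_min[batch] is equivalent to the expand/gather in the original):

import torch
import torch_scatter

x = torch.tensor([[-2.0], [1.0], [3.0]])
batch = torch.tensor([0, 0, 1])
x_min = torch_scatter.scatter_min(x, batch, dim=0)[0]   # [[-2.], [3.]]
s = torch.relu(-x_min)[batch]                           # per-node shift
print(x + s)                                            # tensor([[0.], [3.], [3.]])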
Example #5
def take_action_deterministic_batch_dqn(target_net, player, batch_instances):

    with torch.no_grad():
        # We compute the target values
        batch = batch_instances.G_torch.batch
        mask_values = batch_instances.J.eq(0)[:, 0]
        action_values = target_net(
            batch_instances.G_torch,
            batch_instances.n_nodes,
            batch_instances.Omegas,
            batch_instances.Phis,
            batch_instances.Lambdas,
            batch_instances.Omegas_norm,
            batch_instances.Phis_norm,
            batch_instances.Lambdas_norm,
            batch_instances.J,
        )
        action_values = action_values[mask_values]
        batch = batch[mask_values]
        # if it's the turn of the attacker
        if player == 1:
            # we take the argmin
            values, actions = scatter_min(action_values, batch, dim=0)
        else:
            # we take the argmax
            values, actions = scatter_max(action_values, batch, dim=0)

    return actions.view(-1).tolist()
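The returned argmin/argmax tensor indexes into the flattened (masked) batch of node values, so each entry is directly a position within the concatenated batch. A toy illustration:

import torch
from torch_scatter import scatter_min

action_values = torch.tensor([[0.3], [0.1], [0.7], [0.2]])
batch = torch.tensor([0, 0, 1, 1])      # two graphs, two nodes each
values, actions = scatter_min(action_values, batch, dim=0)
print(actions.view(-1).tolist())        # [1, 3]: best node per graph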
Example #6
    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        row, col = edge_index

        # define interaction tensor; every pair contains features from input and
        # output node together with
        #out = torch.cat([x[row], x[col], edge_attr], dim=1)
        out = torch.cat([x[row], x[col]], dim=1)
        #print("node pre", x.shape, out.shape)

        # take the interaction feature tensor and embed it into another tensor
        #out = self.node_mlp_1(out)
        out = self.mlp(out)
        #print("node mlp", out.shape)

        # compute the mean, max and min of the embedded feature tensors for each node
        out1 = scatter_mean(out, col, dim=0, dim_size=x.size(0))
        out3 = scatter_max(out, col, dim=0, dim_size=x.size(0))[0]
        out4 = scatter_min(out, col, dim=0, dim_size=x.size(0))[0]

        # every node contains a feature tensor with the pooling of the messages from
        # neighbors, its own state, and a global feature
        out = torch.cat([x, out1, out3, out4, u[batch]], dim=1)
        #print("node post", out.shape)

        #return self.node_mlp_2(out)
        return out
Example #7
    def __call__(self, data, norm=True):
        row, col = data.edge_index
        N = data.num_nodes

        deg = degree(row, N, dtype=torch.float)
        if norm:
            deg = deg / deg.max()
        deg_col = deg[col]

        min_deg, _ = scatter_min(deg_col, row, dim_size=N)
        min_deg[min_deg > 10000] = 0
        max_deg, _ = scatter_max(deg_col, row, dim_size=N)
        max_deg[max_deg < -10000] = 0
        mean_deg = scatter_mean(deg_col, row, dim_size=N)
        std_deg = scatter_std(deg_col, row, dim_size=N)

        x = torch.stack([deg, min_deg, max_deg, mean_deg, std_deg], dim=1)

        if data.x is not None:
            data.x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x
            data.x = torch.cat([data.x, x], dim=-1)
        else:
            data.x = x

        return data
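The post-processing of min_deg and max_deg exists because scatter_min fills output slots with no inputs (isolated nodes) with the dtype's largest value, and scatter_max with its lowest (compare the fill-value test in Example #9); the comparisons against +/-10000 reset those sentinels to zero:

import torch
from torch_scatter import scatter_min

deg_col = torch.tensor([2.0, 3.0])
row = torch.tensor([0, 0])              # node 1 has no outgoing edge
min_deg, _ = scatter_min(deg_col, row, dim_size=2)
print(min_deg)                          # tensor([2.0000e+00, 3.4028e+38])
min_deg[min_deg > 10000] = 0
print(min_deg)                          # tensor([2., 0.])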
Example #8
    def aggregate(self,
                  inputs: Tensor,
                  index: Tensor,
                  dim_size: Optional[int] = None) -> Tensor:

        if self.aggr == 'softmax':
            out = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
            return scatter(inputs * out,
                           index,
                           dim=self.node_dim,
                           dim_size=dim_size,
                           reduce='sum')

        elif self.aggr == 'softmax_sg':
            out = scatter_softmax(inputs * self.t, index,
                                  dim=self.node_dim).detach()
            return scatter(inputs * out,
                           index,
                           dim=self.node_dim,
                           dim_size=dim_size,
                           reduce='sum')
        elif self.aggr == 'stat':
            _mean = scatter_mean(inputs,
                                 index,
                                 dim=self.node_dim,
                                 dim_size=dim_size)
            _std = scatter_std(inputs,
                               index,
                               dim=self.node_dim,
                               dim_size=dim_size).detach()
            _min = scatter_min(inputs,
                               index,
                               dim=self.node_dim,
                               dim_size=dim_size)[0]
            _max = scatter_max(inputs,
                               index,
                               dim=self.node_dim,
                               dim_size=dim_size)[0]

            _mean = _mean.unsqueeze(dim=-1)
            _std = _std.unsqueeze(dim=-1)
            _min = _min.unsqueeze(dim=-1)
            _max = _max.unsqueeze(dim=-1)

            stat = torch.cat([_mean, _std, _min, _max], dim=-1)
            stat = self.lin_stat(stat)
            stat = stat.squeeze(dim=-1)
            return stat

        else:
            min_value, max_value = 1e-7, 1e1
            torch.clamp_(inputs, min_value, max_value)
            out = scatter(torch.pow(inputs, self.p),
                          index,
                          dim=self.node_dim,
                          dim_size=dim_size,
                          reduce='mean')
            torch.clamp_(out, min_value, max_value)
            return torch.pow(out, 1 / self.p)
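The 'softmax' branch computes per-group attention weights and then a weighted sum per group; t acts as a temperature. A minimal sketch with 1-D inputs (illustrative, not the original module):

import torch
from torch_scatter import scatter, scatter_softmax

inputs = torch.tensor([1.0, 2.0, 5.0])
index = torch.tensor([0, 0, 1])
t = 1.0
w = scatter_softmax(inputs * t, index, dim=0)     # weights sum to 1 per group
out = scatter(inputs * w, index, dim=0, dim_size=2, reduce='sum')
print(out)                                        # tensor([1.7311, 5.0000])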
Example #9
def test_min_fill_value():
    src = torch.Tensor([[-2, 0, -1, -4, -3], [0, -2, -1, -3, -4]])
    index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])

    out, _ = scatter_min(src, index)

    v = torch.finfo(torch.float).max
    assert out.tolist() == [[v, v, -4, -3, -2, 0], [-2, -4, -3, v, v, v]]
Example #10
def update_time_(node_time_dict, index, node_type, num_nodes):
    node_time_dict[node_type] = node_time_dict[node_type].clone()
    # NOTE: `edge_label_time` is captured from the enclosing scope in the
    # original source; it is not an argument of this helper.
    node_time, _ = scatter_min(edge_label_time,
                               index,
                               dim=0,
                               dim_size=num_nodes)
    # NOTE We assume that node_time is always less than edge_time.
    index_unique = index.unique()
    node_time_dict[node_type][index_unique] = node_time[index_unique]
Example #11
    def forward(self, batch_size, encode_coordinates, agent_encodings):
        channel = agent_encodings.shape[-1]
        pool_vector = agent_encodings.transpose(1, 0) # [C X D]

        init_map_ts = torch.zeros((channel, batch_size*self.pooling_size*self.pooling_size), device=self.device) # [C X B*H*W]
        out, _ = ts.scatter_min(src=pool_vector, index=encode_coordinates, out=init_map_ts) # [C X B*H*W]
        out, _ = ts.scatter_max(src=pool_vector, index=encode_coordinates, out=out) # [C X B*H*W]

        out = out.reshape((channel, batch_size, self.pooling_size, self.pooling_size)) # [C X B X H X W]
        out = out.permute((1, 0, 2, 3)) # [B X C X H X W]

        return out
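Passing out= pre-seeds the reduction: values already in the buffer take part in the min/max comparisons, so the zero-initialized map above means scatter_min can never return a value above 0 before scatter_max overwrites the buffer. A small illustration of that behavior:

import torch
import torch_scatter as ts

src = torch.tensor([[3.0, -2.0]])
index = torch.tensor([[0, 0]])
buf = torch.zeros(1, 1)
out, _ = ts.scatter_min(src=src, index=index, out=buf)
print(out)   # tensor([[-2.]]): the initial 0 participated and beat the 3.0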
Example #12
    def forward(self, data):
        # device = self.device
        # mode   = self.mode
        k = self.k
        device = self.device
        pos_idx = self.pos_idx
        x, edge_index, batch = data.x, data.edge_index, data.batch
        edge_index = knn_graph(x=x[:, pos_idx], k=k, batch=batch).to(device)
        x = self.GGconv1(x, edge_index)
        x = self.relu(x)

        x = self.nn1(x)
        x = self.relu(x)

        y = self.resblock1(x)
        x = x + y

        z = self.resblock2(x)
        x = x + z

        del y, z

        x = self.nn2(x)
        x = self.relu(x)

        x = self.GGconv2(x, edge_index)
        x = self.relu(x)

        p = self.resblock3(x)
        x = x + p

        o = self.resblock4(x)
        x = x + o
        del p, o

        x = self.nn3(x)
        x = self.relu(x)

        a, _ = scatter_max(x, batch, dim=0)
        b, _ = scatter_min(x, batch, dim=0)
        c = scatter_sum(x, batch, dim=0)
        d = scatter_mean(x, batch, dim=0)
        x = torch.cat((a, b, c, d), dim=1)
        # print ("cat size",x.size())
        del a, b, c, d

        x = self.nncat(x)
        x = self.relu(x)
        # if(torch.sum(torch.isnan(x)) != 0):
        # print('NAN ENCOUNTERED AT NN2')

        # print ("xsize %s batchsize %s a size %s b size %s y size %s end forward" %(x.size(),batch.size(),a.size(),b.size(),data.y[:,0].size()))
        return x
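The readout concatenates four per-graph statistics, so the pooled width is four times the node-feature width. In isolation:

import torch
from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum

x = torch.randn(5, 8)
batch = torch.tensor([0, 0, 0, 1, 1])   # two graphs
pooled = torch.cat((scatter_max(x, batch, dim=0)[0],
                    scatter_min(x, batch, dim=0)[0],
                    scatter_sum(x, batch, dim=0),
                    scatter_mean(x, batch, dim=0)), dim=1)
print(pooled.shape)                     # torch.Size([2, 32])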
Example #13
    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.
        out1 = scatter_mean(x, batch, dim=0)
        out3 = scatter_max(x, batch, dim=0)[0]
        out4 = scatter_min(x, batch, dim=0)[0]
        out = torch.cat([u, out1, out3, out4], dim=1)
        #print("global pre", out.shape, x.shape, u.shape)
        out = self.global_mlp(out)
        #print("global post", out.shape)
        return out
Example #14
def correctness(dataset):
    group, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
    dim_size = rowptr.size(0) - 1

    for size in sizes:
        try:
            x = torch.randn((row.size(0), size), device=args.device)
            x = x.squeeze(-1) if size == 1 else x

            out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
            out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
            out3 = segment_csr(x, rowptr, reduce='add')

            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

            out1 = scatter_mean(x, row, dim=0, dim_size=dim_size)
            out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean')
            out3 = segment_csr(x, rowptr, reduce='mean')

            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

            x = x.abs_().mul_(-1)

            out1, _ = scatter_min(x, row, 0, torch.zeros_like(out1))
            out2, _ = segment_coo(x, row, reduce='min')
            out3, _ = segment_csr(x, rowptr, reduce='min')

            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

            x = x.abs_()

            out1, _ = scatter_max(x, row, 0, torch.zeros_like(out1))
            out2, _ = segment_coo(x, row, reduce='max')
            out3, _ = segment_csr(x, rowptr, reduce='max')

            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise RuntimeError(e)
            torch.cuda.empty_cache()
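The three reductions agree because they describe the same grouping in different encodings: scatter uses per-element COO rows, segment_coo sorted COO rows, and segment_csr a compressed row pointer. How the pointer relates to the sorted rows (illustrative):

import torch
from torch_scatter import scatter_add, segment_csr

row = torch.tensor([0, 0, 1, 2, 2, 2])   # sorted COO row indices
rowptr = torch.tensor([0, 2, 3, 6])      # CSR: rowptr[i]..rowptr[i+1] spans row i
x = torch.randn(6, 4)
out_coo = scatter_add(x, row, dim=0, dim_size=3)
out_csr = segment_csr(x, rowptr, reduce='sum')
assert torch.allclose(out_coo, out_csr, atol=1e-6)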
Example #15
    def forward(self, data):
        k = self.k
        device = self.device
        mode = self.mode
        pos_idx = self.pos_idx
        # changing x dtype to float, change back after saving graphs properly
        x, edge_index, batch = data.x, data.edge_index, data.batch

        edge_index = knn_graph(x=x[:, pos_idx], k=k, batch=batch).to(device)

        a = self.conv_add(x, edge_index)

        edge_index = knn_graph(x=a[:, pos_idx], k=k, batch=batch).to(device)
        # check if this recalculation of edge indices is correct, maybe you
        # can do it over all of x
        b = self.conv_add2(a, edge_index)

        edge_index = knn_graph(x=b[:, pos_idx], k=k, batch=batch).to(device)

        c = self.conv_add3(b, edge_index)

        edge_index = knn_graph(x=c[:, pos_idx], k=k, batch=batch).to(device)

        d = self.conv_add4(c, edge_index)

        x = torch.cat((x, a, b, c, d), dim=1)
        del a, b, c, d
        x = self.nn1(x)
        x = self.relu(x)
        x = self.nn2(x)

        a, _ = scatter_max(x, batch, dim=0)
        b, _ = scatter_min(x, batch, dim=0)
        c = scatter_sum(x, batch, dim=0)
        d = scatter_mean(x, batch, dim=0)
        x = torch.cat((a, b, c, d), dim=1)

        x = self.relu(x)
        x = self.nn3(x)

        x = self.relu(x)
        x = self.nn4(x)

        if mode == 'angle':
            x[:, 0] = self.tanh(x[:, 0])
            x[:, 1] = self.tanh(x[:, 1])

        return x
Example #16
    def scatter_distribution(src: torch.Tensor,
                             index: torch.Tensor,
                             dim: int = -1,
                             out: Optional[torch.Tensor] = None,
                             dim_size: Optional[int] = None,
                             unbiased: bool = True) -> torch.Tensor:

        if out is not None:
            dim_size = out.size(dim)

        if dim < 0:
            dim = src.dim() + dim

        count_dim = dim
        if index.dim() <= dim:
            count_dim = index.dim() - 1

        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

        index = broadcast(index, src, dim)
        tmp = scatter_sum(src, index, dim, dim_size=dim_size)
        count = broadcast(count, tmp, dim).clamp(1)
        mean = tmp.div(count)

        src_minus_mean = (src - mean.gather(dim, index))
        var = src_minus_mean * src_minus_mean
        var = scatter_sum(var, index, dim, out, dim_size)

        if unbiased:
            count = count.sub(1).clamp_(1)
        var = var.div(count)

        skew = src_minus_mean * src_minus_mean * src_minus_mean / (
            var.gather(dim, index) + 1e-7)**(1.5)
        kurtosis = (src_minus_mean * src_minus_mean * src_minus_mean *
                    src_minus_mean) / (var * var + 1e-7).gather(dim, index)

        skew = scatter_sum(skew, index, dim, out, dim_size)
        kurtosis = scatter_sum(kurtosis, index, dim, out, dim_size)

        skew = skew.div(count)
        kurtosis = kurtosis.div(count)

        maximum = scatter_max(src, index, dim, out, dim_size)[0]
        minimum = scatter_min(src, index, dim, out, dim_size)[0]

        return torch.cat([mean, var, skew, kurtosis, maximum, minimum], dim=1)
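A hypothetical call to this extended variant (assuming the function is in scope as a plain function): six statistics are concatenated, so the output width is 6 * F.

import torch

src = torch.randn(10, 3)
index = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2])
out = scatter_distribution(src, index, dim=0)
print(out.shape)   # torch.Size([3, 18])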
Example #17
    def __call__(self, data: Data):
        pos = data.pos
        edges = data.edge_index

        if self.weighted_normals:
            normals = weighted_normals(data.face_normals, data.face_areas,
                                       data.faces, len(data.pos))
        else:
            normals = data.vertex_normals

        diffs = pos[edges[0]] - pos[edges[1]]
        projectors = make_projector(normals)

        projected = torch.einsum('pij, pj -> pi', projectors[edges[0]],
                                 pos[edges[1]])
        projected /= projected.norm(dim=1, keepdim=True)

        gauge, _ = scatter_min(edges[1], edges[0])
        e1 = projected[gauge]
        e2 = torch.cross(e1, normals)
        log_map = projected * diffs.norm(dim=1, keepdim=True)

        theta_x = torch.einsum('pi, pi -> p', e1[edges[0]], log_map)
        theta_y = torch.einsum('pi, pi -> p', e2[edges[0]], log_map)
        theta = torch.atan2(theta_y, theta_x)

        axis = torch.cross(normals[edges[1]], normals[edges[0]])
        alpha = torch.einsum('pi, pi -> p', normals[edges[0]],
                             normals[edges[1]]).clamp(-1, 1)
        alpha = torch.acos(alpha)
        rotvec = (alpha.unsqueeze(dim=-1) * axis).numpy()
        rotation = R.from_rotvec(rotvec)
        g_x = rotation.apply(e1[edges[1]].numpy())
        g_x = torch.einsum('pi, pi -> p', torch.FloatTensor(g_x), e1[edges[0]])
        g_y = rotation.apply(e2[edges[1]].numpy())
        g_y = torch.einsum('pi, pi -> p', torch.FloatTensor(g_y), e2[edges[0]])
        g = torch.atan2(g_y, g_x)

        del data.vertex_normals
        del data.faces
        del data.face_normals
        del data.face_areas
        if self.distance:
            data.distance = diffs.norm(dim=1)
        data.g = g
        data.theta = theta
        return data
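The gauge line picks, for every source vertex, the lowest-index neighbor (a scatter_min over the target indices, grouped by source); that neighbor's projected direction then serves as the reference axis e1. The selection in isolation:

import torch
from torch_scatter import scatter_min

edges = torch.tensor([[0, 0, 1],    # source vertex of each edge
                      [2, 1, 0]])   # target vertex of each edge
gauge, _ = scatter_min(edges[1], edges[0])
print(gauge)   # tensor([1, 0]): vertex 0 -> neighbor 1, vertex 1 -> neighbor 0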
Example #18
def take_action_deterministic_batch(target_net,
                                    player,
                                    next_player,
                                    rewards,
                                    next_afterstates,
                                    weights=None,
                                    id_graphs=None,
                                    **kwargs):
    """Take actions in batch"""

    if id_graphs is None:
        n_nodes = sum([len(afterstate) for afterstate in next_afterstates])
        id_graphs = torch.zeros(size=(n_nodes, ), dtype=torch.int64).to(device)
    # if the game is finished in the next turn
    # we know what is the best action to take
    # because we have the true rewards available
    if next_player == 3:
        # the targets are the true values
        targets = rewards
    # if it's not the end state,
    # we sample from the values
    else:
        with torch.no_grad():
            # Create a Batch of graphs
            G_torch = Batch.from_data_list(next_afterstates).to(device)
            # We compute the target values
            targets = target_net(G_torch, **kwargs)
    if weights is not None:
        weights_tensor = torch.tensor(weights, dtype=torch.float).view(
            targets.size()).to(device)
        target_decision = targets + weights_tensor
    else:
        target_decision = targets
    # if it's the turn of the attacker
    if player == 1:
        # we take the argmin
        _, actions = scatter_min(target_decision, id_graphs, dim=0)
    else:
        # we take the argmax
        _, actions = scatter_max(target_decision, id_graphs, dim=0)
    values = targets[actions[:, 0]]

    return actions.view(-1).tolist(), targets, values.view(-1).tolist()
Example #19
def get_depot_info(beam, graph):
    """
    Finds for each group (set of visited nodes) in the beam the lowest cost to return to the depot
    This is useful since any non-dominated (lowest cost) expansion via the depot must necessarily also
    arrive at the depot at lowest cost (since remaining demand is reset at depot, only look at cost)
    :param beam:
    :param graph:
    :return:
    """
    # Get total distance to depot for each entry in group, for first action current is undefined, don't add
    beam_cost_at_depot = beam.cost if beam.current is None else beam.cost + graph.cost_to_depot[
        beam.batch_ids, beam.current.long()]
    if beam.sort_by == 'group_idx':
        group_min_cost_at_depot, group_idx_min_cost_at_depot = segment_min_coo(beam_cost_at_depot, beam.group_idx)
    else:
        group_min_cost_at_depot, group_idx_min_cost_at_depot = scatter_min(beam_cost_at_depot, beam.group_idx)
    beam_min_cost_at_depot = group_min_cost_at_depot.gather(0, beam.group_idx)
    beam_idx_min_cost_at_depot = group_idx_min_cost_at_depot.gather(0, beam.group_idx)
    return group_min_cost_at_depot, group_idx_min_cost_at_depot, beam_min_cost_at_depot, beam_idx_min_cost_at_depot
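gather(0, beam.group_idx) broadcasts each group statistic back to every beam entry in that group (equivalent to plain indexing with group_idx). The pattern in isolation:

import torch
from torch_scatter import scatter_min

cost = torch.tensor([5.0, 2.0, 7.0])
group_idx = torch.tensor([0, 0, 1])
group_min, group_argmin = scatter_min(cost, group_idx)
beam_min = group_min.gather(0, group_idx)   # same as group_min[group_idx]
print(beam_min)                             # tensor([2., 2., 7.])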
Example #20
    def forward(self, batched_data):
        h_node = self.gnn_node(batched_data)
        if self.graph_pooling == 'laf' and isinstance(self.pool,
                                                      ScatterAggregationLayer):
            x_min = torch_scatter.scatter_min(h_node,
                                              batched_data.batch,
                                              dim=0)[0]
            gather_idxs = batched_data.batch.expand(h_node.shape[1], -1).t()
            gather_mins = torch.gather(x_min, 0, gather_idxs)
            s = F.relu(-gather_mins)
            h_node = h_node + s
            out = self.pool(h_node, batched_data.batch)
            s_out = self.pool(s, batched_data.batch)
            h_graph = out - s_out
        elif self.graph_pooling == 'laf' and isinstance(
                self.pool, ScatterExponentialLAF):
            h_graph = self.pool(h_node, batched_data.batch)
        else:
            h_graph = self.pool(h_node, batched_data.batch)
        return self.graph_pred_linear(h_graph)
Example #21
    def __call__(self, data):
        row, col = data.edge_index
        N = data.num_nodes

        deg = degree(row, N, dtype=torch.float)
        deg_col = deg[col]

        value = 1e16
        min_deg, _ = scatter_min(deg_col, row, dim_size=N, fill_value=value)
        min_deg[min_deg == value] = 0
        max_deg, _ = scatter_max(deg_col, row, dim_size=N)
        mean_deg = scatter_mean(deg_col, row, dim_size=N)
        std_deg = scatter_std(deg_col, row, dim_size=N)

        x = torch.stack([deg, min_deg, max_deg, mean_deg, std_deg], dim=1)

        if data.x is not None:
            data.x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x
            data.x = torch.cat([data.x, x], dim=-1)
        else:
            data.x = x

        return data
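The fill_value keyword in this variant belongs to the older torch_scatter 1.x API and was dropped in 2.x. A 2.x-compatible sketch of the same effect (an assumption, using the out-of-range argmin to locate empty slots, as in Example #2):

import torch
from torch_scatter import scatter_min

deg_col = torch.tensor([1.0, 4.0])
row = torch.tensor([0, 0])                 # node 1 has no edges
min_deg, argmin = scatter_min(deg_col, row, dim_size=2)
min_deg[argmin == deg_col.size(0)] = 0     # reset slots nothing mapped to
print(min_deg)                             # tensor([1., 0.])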
Example #22
    def predict(self, inputs, labels):
        K = inputs['intrinsics']
        extrinsics = inputs['extrinsics']
        depths = inputs['depth']
        depth_mask = inputs['depth_mask']
        target_T = inputs['target_T']
        segs = inputs['seg']
        if self.ind is not None:
            depths = depths[:, self.ind:self.ind + 1]
            depth_mask = depth_mask[:, self.ind:self.ind + 1]
            target_T = target_T[:, self.ind:self.ind + 1]
            segs = segs[:, self.ind:self.ind + 1]

        # Step 1: back project 2d points into 3D using formula:
        # pts3d = depths * K^-1 * pts2d
        b, inp_t, height, width = depths.shape
        vs, us = torch.meshgrid(
            torch.arange(height, dtype=torch.float, device=depths.device),
            torch.arange(width, dtype=torch.float, device=depths.device),
        )

        pts2d = torch.cat([
            us.reshape(-1, 1),
            vs.reshape(-1, 1),
            torch.ones(height * width, 1, dtype=torch.float, device=us.device)
        ],
                          dim=-1).unsqueeze(0).expand(b, -1, -1)
        K_inv = torch.inverse(K).reshape(b, 1, 3, 3)

        # [b, 1, 3, 3] x [b, hw, 3, 1]. After squeeze result is [b, hw, 3]
        pts3d_c = (K_inv @ pts2d.unsqueeze(-1)).squeeze(-1)
        pts3d_c = pts3d_c.unsqueeze(1) * depths.reshape(b, inp_t, -1, 1)
        pts3d_c = torch.cat([
            pts3d_c,
            torch.ones(b, inp_t, height * width, 1, device=K.device)
        ],
                            dim=-1)

        # Step 2: convert camera points (in RDF) to vehicle points (in FLU)
        # Here, pts3d is [b, inp_t, h*w, 4]
        pts3d_v = extrinsics.view(b, 1, 1, 4, 4) @ pts3d_c.unsqueeze(-1)

        # Step 3: transform points such that they lie in the final frame's
        # vehicle coordinate system
        # target_T shape: [b, inp_t, 4, 4]
        result_pts3d_v = target_T.unsqueeze(2) @ pts3d_v

        # Step 4: Project points to 2d (by first transforming to camera coordinates)
        result_pts3d_c = torch.inverse(extrinsics).reshape(b, 1, 1, 4,
                                                           4) @ result_pts3d_v
        result_pts3d_c = result_pts3d_c[:, :, :, :3] / result_pts3d_c[:, :, :,
                                                                      3:4]
        result_depths = result_pts3d_c[:, :, :, 2]
        result2d = K.view(b, 1, 1, 3, 3) @ result_pts3d_c
        result2d = result2d[:, :, :, :2] / result2d[:, :, :, 2:3]
        #result2d = result2d.squeeze(-1).round().long()

        result2d = result2d.squeeze(-1)
        # Valid points have the following properties:
        # - They correspond to valid input depth values
        # - the depth values are > 0 (i.e. they lie in front of the camera)
        # - The u/v coordinates lie within the image
        inbounds_mask = (result2d[:, :, :, 0] >= 0) & \
                        (result2d[:, :, :, 0] < width) & \
                        (result2d[:, :, :, 1] >= 0) & \
                        (result2d[:, :, :, 1] < height)
        result_mask = (depth_mask.view(b, inp_t, height * width) *
                       (result_depths.squeeze(-1) > 0)) & inbounds_mask

        # We need to translate our 2d predictions (which currently take the form
        # [batch, num_predicted_points, 2] and represent the u/v coordinates for each
        # point in the final camera frame) to the actual image, only keeping points
        # with valid depths and moreover keeping the closest valid point.

        # We do this using scatter (a good overview of how this works can be seen at
        # https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html#torch_scatter.scatter)

        # First: find the points with the smallest depth at each result location
        result_mask = result_mask.reshape(b, inp_t * height * width)
        result_depths = result_depths.reshape(b, inp_t * height * width)

        # Make sure we never select an invalid point for a location when a
        # valid point exists
        result_depths[~result_mask] = result_depths.max() + 1
        result2d = result2d.reshape(b, inp_t * height * width, 2)
        result2d_0 = torch.stack([
            result2d[:, :, 0].floor().long(), result2d[:, :, 1].floor().long()
        ],
                                 dim=-1)
        result2d_1 = torch.stack([
            result2d[:, :, 0].floor().long(), result2d[:, :, 1].ceil().long()
        ],
                                 dim=-1)
        result2d_2 = torch.stack([
            result2d[:, :, 0].ceil().long(), result2d[:, :, 1].floor().long()
        ],
                                 dim=-1)
        result2d_3 = torch.stack(
            [result2d[:, :, 0].ceil().long(), result2d[:, :, 1].ceil().long()],
            dim=-1)

        result2d = torch.cat([result2d_0, result2d_1, result2d_2, result2d_3],
                             dim=1)
        result2d[:, :, 0].clamp_(0, width - 1)
        result2d[:, :, 1].clamp_(0, height - 1)
        result_depths = result_depths.repeat(1, 4)

        scatter_inds = result2d[:, :, 1] * width + result2d[:, :, 0]
        _, argmin = torch_scatter.scatter_min(result_depths,
                                              scatter_inds,
                                              -1,
                                              dim_size=inp_t * height * width *
                                              4)
        tmp_mask = (argmin < inp_t * height * width * 4)
        ind0 = tmp_mask.nonzero()[:, 0]
        ind1 = argmin[tmp_mask]
        tgt_ind1 = tmp_mask.nonzero()[:, 1]

        if self.is_img:
            final_seg = torch.zeros(b,
                                    height * width,
                                    3,
                                    dtype=segs.dtype,
                                    device=K.device)
            segs = segs.reshape(b, inp_t * height * width, 3).repeat(1, 4, 1)
        else:
            final_seg = torch.zeros(b,
                                    height * width,
                                    dtype=segs.dtype,
                                    device=K.device)
            segs = segs.reshape(b, inp_t * height * width).repeat(1, 4)

        # The following ensures we don't copy a prediction from an "invalid" point
        segs[~result_mask.repeat(1, 4)] = 0
        final_seg[ind0, tgt_ind1] = segs[ind0, ind1]

        final_depths = torch.zeros(b,
                                   height * width,
                                   dtype=result_depths.dtype,
                                   device=K.device).fill_(-1)
        final_depths[ind0, tgt_ind1] = result_depths[ind0, ind1]
        if self.is_img:
            final_seg = final_seg.view(b, height, width, 3)
        else:
            final_seg = final_seg.view(b, height, width)

        result_dict = {
            'seg': final_seg,
            'result2d': result2d[:, :inp_t * height * width].reshape(
                b, inp_t, height, width, 2),
            'depth': final_depths.view(b, height, width),
        }
        return result_dict
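The scatter_min over depths acts as a z-buffer: for every target pixel it keeps the index of the closest projected point, and pixels that nothing maps to are detected via the out-of-range argmin. The core of that step in isolation:

import torch
import torch_scatter

depths = torch.tensor([3.0, 1.0, 2.0])
pixel = torch.tensor([0, 0, 1])            # flattened target pixel per point
_, argmin = torch_scatter.scatter_min(depths, pixel, dim_size=4)
visible = argmin[argmin < depths.size(0)]  # drop pixels no point reached
print(visible)                             # tensor([1, 2]): closest point per hit pixel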
Example #23
    def forward(self, data):
        """
        Provides a fractional solution to the data association problem.
        First, node and edge features are independently encoded by the encoder network. Then, they are iteratively
        'combined' for a fixed number of steps via the Message Passing Network (self.MPNet). Finally, they are
        classified independently by the classifiernetwork.
        Args:
            data: object containing attribues
              - x: node features matrix
              - edge_index: tensor with shape [2, M], with M being the number of edges, indicating nonzero entries in the
                graph adjacency (i.e. edges) (i.e. sparse adjacency)
              - edge_attr: edge features matrix (sorted by edge apperance in edge_index)

        Returns:
            classified_edges: list of unnormalized node probabilites after each MP step
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        x_is_img = len(x.shape) == 4
        if self.node_cnn is not None and x_is_img:
            x = self.node_cnn(x)

            emb_dists = nn.functional.pairwise_distance(x[edge_index[0]], x[edge_index[1]]).view(-1, 1)
            edge_attr = torch.cat((edge_attr, emb_dists), dim = 1)

        # Encoding features step
        latent_edge_feats, latent_node_feats = self.encoder(edge_attr, x)
        initial_edge_feats = latent_edge_feats
        initial_node_feats = latent_node_feats

        # During training, the feature vectors that the MPNetwork outputs for the last self.num_class_steps message
        # passing steps are classified in order to compute the loss.
        first_class_step = self.num_enc_steps - self.num_class_steps + 1
        first_attention_step = self.num_enc_steps - self.num_attention_steps + 1

        outputs_dict = {'classified_edges': []}
        if self.use_attention:
            outputs_dict['att_coefficients'] = []
        if self.graph_pruning:
            outputs_dict['mask'] = []
        

        mask = torch.full((edge_index.shape[1],), True, dtype=torch.bool)
        for step in range(1, self.num_enc_steps + 1):
            # Reattach the initially encoded embeddings before the update
            if self.reattach_initial_edges:
                latent_edge_feats = torch.cat((initial_edge_feats, latent_edge_feats), dim=1)              # [M,16]+[M,16] -> [M,32]
            if self.reattach_initial_nodes:
                latent_node_feats = torch.cat((initial_node_feats, latent_node_feats), dim=1)

            # Message Passing Step
            if self.use_attention:
                if self.graph_pruning:
                    a = torch.zeros(self.attention_head_num,edge_index.shape[1]).cuda()
                    edge_feats = torch.zeros(latent_edge_feats.shape[0],16).cuda()
                    latent_node_feats, edge_feats[mask],a_masked = self.MPNet(latent_node_feats, edge_index.T[mask].T, latent_edge_feats[mask])
                    latent_edge_feats = edge_feats
                    a.T[mask] = a_masked.T
                else:
                    latent_node_feats, latent_edge_feats,a = self.MPNet(latent_node_feats, edge_index, latent_edge_feats)
            else:
                if self.graph_pruning:
                    edge_feats = torch.zeros(latent_edge_feats.shape[0],16).cuda()
                    latent_node_feats, edge_feats[mask] = self.MPNet(latent_node_feats, edge_index.T[mask].T, latent_edge_feats[mask])
                    latent_edge_feats = edge_feats
                else: 
                    latent_node_feats, latent_edge_feats = self.MPNet(latent_node_feats, edge_index, latent_edge_feats)

            if step >= first_class_step:
                # Classification Step
                logits, _ = self.classifier(latent_edge_feats)
                pruning_this_step = self.graph_pruning and step >= self.first_prune_step and step < self.num_enc_steps
                
                if self.use_attention and step >= first_attention_step:
                    outputs_dict['att_coefficients'].append(a)

                if pruning_this_step:
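                    # NOTE: `probabilities` below was computed at the end of the
                    # previous classification step; it is reassigned after pruning.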
                    if self.prune_mode == "classifier naive":
                        valid_pro = probabilities[mask]
                        topk_mask = torch.full((valid_pro.shape[0],), True, dtype=torch.bool)
                        _, indice = torch.topk(valid_pro, int(len(valid_pro) * self.prune_factor), largest=False)
                        topk_mask[indice] = False
                    elif self.prune_mode == "classifier node wise":
                        valid_pro = probabilities[mask]
                        valid_idx = edge_index[0][mask]

                        topk_mask = torch.ones(len(valid_pro), dtype=torch.bool)
                        valid_pro_copy = valid_pro.clone()

                        k = torch.ones(len(valid_idx)).cuda()
                        k = torch.max(scatter_add(k, valid_idx))
                        k = int(k * self.prune_factor)
                        for i in range(k):
                            _, argmin = torch_scatter.scatter_min(valid_pro_copy, valid_idx)
                            neighbor = scatter_add(topk_mask.long().cuda(), valid_idx)
                            argmin = argmin[neighbor > self.prune_min_edge]
                            topk_mask[argmin] = False
                            valid_pro_copy[argmin] = 2
                    mask[mask == True] = topk_mask
                    outputs_dict['mask'].append(mask.clone())

                probabilities = torch.zeros_like(logits.view(-1))
                probabilities[mask] = torch.sigmoid(logits.view(-1)[mask])
                outputs_dict['classified_edges'].append(probabilities)

        if self.num_enc_steps == 0:
            dec_edge_feats, _ = self.classifier(latent_edge_feats)
            dec_edge_feats = torch.sigmoid(dec_edge_feats)
            outputs_dict['classified_edges'].append(dec_edge_feats)
        return outputs_dict
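A sketch of the "classifier node wise" pruning loop above: each pass drops the lowest-probability remaining edge per node, skipping nodes that would fall below a minimum edge count (toy values on CPU instead of .cuda()):

import torch
import torch_scatter
from torch_scatter import scatter_add

probs = torch.tensor([0.9, 0.1, 0.4, 0.2])
src_nodes = torch.tensor([0, 0, 1, 1])
keep = torch.ones(4, dtype=torch.bool)
_, argmin = torch_scatter.scatter_min(probs, src_nodes)
degree = scatter_add(keep.long(), src_nodes)
argmin = argmin[degree > 1]      # only prune nodes with edges to spare
keep[argmin] = False
print(keep)                      # tensor([ True, False,  True, False])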
Example #24
    def train_step(model, optimizer, train_iterator, args, step, writer):
        optimizer.zero_grad()
        x_scores, x_relations, y_scores, y_relations, mask_relations, w_scores, w_relations, berts, edge_indices, softmax_edge_indices, n_program, max_y_score_len, mask_relations_class, question_indices, step_indices, noisy_mask_relations = train_iterator.next_supervised(
        )

        if args.cuda:
            x_scores = x_scores.cuda()
            x_relations = x_relations.cuda()
            y_scores = y_scores.cuda()
            y_relations = y_relations.cuda()
            mask_relations = mask_relations.cuda()
            w_scores = w_scores.cuda()
            w_relations = w_relations.cuda()
            berts = berts.cuda()
            edge_indices = edge_indices.cuda()
            softmax_edge_indices = softmax_edge_indices.cuda()
            mask_relations_class = mask_relations_class.cuda()
            question_indices = question_indices.cuda()
            step_indices = step_indices.cuda()
            noisy_mask_relations = noisy_mask_relations.cuda()

        scores, relations = model(x_scores, x_relations, berts, edge_indices,
                                  softmax_edge_indices, n_program,
                                  max_y_score_len)

        score_loss = torch.nn.CrossEntropyLoss(reduction='none')(scores,
                                                                 y_scores)

        if args.train_with_masking:
            relations = torch.where(
                mask_relations_class, relations,
                torch.tensor(-float('inf')).to(relations.device))
            relation_loss = torch.nn.CrossEntropyLoss(reduction='none')(
                relations, y_relations) * mask_relations
        else:
            relation_loss = torch.nn.CrossEntropyLoss(reduction='none')(
                relations, y_relations) * mask_relations

        relation_loss = relation_loss[noisy_mask_relations]
        all_loss = score_loss + args.relation_coeff * relation_loss
        all_loss = torch_scatter.scatter_add(
            all_loss,
            step_indices[1],
            dim=0,
            dim_size=torch.max(step_indices[1]) + 1)
        loss, _ = torch_scatter.scatter_min(
            all_loss,
            question_indices[1],
            dim=0,
            dim_size=torch.max(question_indices[1]) + 1)

        loss = torch.mean(loss)
        score_loss = torch.mean(
            torch_scatter.scatter_min(torch_scatter.scatter_add(
                score_loss,
                step_indices[1],
                dim=0,
                dim_size=torch.max(step_indices[1]) + 1),
                                      question_indices[1],
                                      dim=0,
                                      dim_size=torch.max(question_indices[1]) +
                                      1)[0])
        relation_loss = torch.mean(
            torch_scatter.scatter_min(torch_scatter.scatter_add(
                relation_loss,
                step_indices[1],
                dim=0,
                dim_size=torch.max(step_indices[1]) + 1),
                                      question_indices[1],
                                      dim=0,
                                      dim_size=torch.max(question_indices[1]) +
                                      1)[0])

        loss.backward()

        optimizer.step()
        log = {
            'supervised_loss': loss.item(),
            'supervised_score_loss': score_loss.item(),
            'supervised_relation_loss': relation_loss.item(),
        }

        for metric in log:
            writer.add_scalar(metric, log[metric], step)

        return log
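The loss reduction composes two scatters: scatter_add sums per-step losses within each candidate program, then scatter_min keeps, per question, the cheapest program (a hard min over program hypotheses). In miniature:

import torch
import torch_scatter

step_loss = torch.tensor([1.0, 2.0, 0.5, 4.0])
program_id = torch.tensor([0, 0, 1, 1])          # step -> program
per_program = torch_scatter.scatter_add(step_loss, program_id, dim=0)
question_id = torch.tensor([0, 0])               # program -> question
per_question, _ = torch_scatter.scatter_min(per_program, question_id, dim=0)
print(per_question)                              # tensor([3.]): min(3.0, 4.5)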
Example #25
def train_dqn(batch_size,
              size_test_data,
              lr,
              betas,
              n_episode,
              update_target,
              n_time_instance_seen,
              eps_end,
              eps_decay,
              eps_start,
              dim_embedding,
              dim_values,
              dim_hidden,
              n_heads,
              n_att_layers,
              n_pool,
              alpha,
              p,
              n_free_min,
              n_free_max,
              d_edge_min,
              d_edge_max,
              Omega_max,
              Phi_max,
              Lambda_max,
              weighted,
              w_max=1,
              directed=False,
              num_workers=0,
              resume_training=False,
              path_train="",
              path_test_data=None,
              exact_protection=False,
              rate_display=200,
              batch_unroll=128):
    """Train a DQN to solve the MCN problem"""

    # Gather the hyperparameters
    dict_args = locals()
    # Gather the date as a string
    date_str = (datetime.now().strftime('%b') + str(datetime.now().day) + "_" +
                str(datetime.now().hour) + "-" + str(datetime.now().minute) +
                "-" + str(datetime.now().second))
    # Tensorboard init
    writer = SummaryWriter()
    # Init the counts
    count_steps = 0
    count_instances = 0
    # Compute n_max
    n_max = n_free_max + Omega_max + Phi_max + Lambda_max
    max_budget = Omega_max + Phi_max + Lambda_max
    list_players = [2] * Lambda_max + [1] * Phi_max + [0] * Omega_max
    # Compute the size of the memory
    size_memory = batch_size * n_time_instance_seen
    # Init the value net
    value_net = DQN(
        dim_input=5,
        dim_embedding=dim_embedding,
        dim_values=dim_values,
        dim_hidden=dim_hidden,
        n_heads=n_heads,
        n_att_layers=n_att_layers,
        n_pool=n_pool,
        K=n_max,
        alpha=alpha,
        p=p,
        weighted=weighted,
    ).to(device)
    # Initialize the optimizer
    optimizer = optim.Adam(value_net.parameters(), lr=lr, betas=betas)
    # Initialize the memory
    replay_memory_states = []
    replay_memory_actions = []
    replay_memory_afterstates = []
    replay_memory_rewards = []
    count_memory = 0
    # If resume training
    if resume_training:
        # load the state dicts of the optimizer and value_net
        value_net, optimizer = load_training_param(value_net, optimizer,
                                                   path_train)
    # Init the target net
    target_net = DQN(
        dim_input=5,
        dim_embedding=dim_embedding,
        dim_values=dim_values,
        dim_hidden=dim_hidden,
        n_heads=n_heads,
        n_att_layers=n_att_layers,
        n_pool=n_pool,
        K=n_max,
        alpha=alpha,
        p=p,
        weighted=weighted,
    ).to(device)
    target_net.load_state_dict(value_net.state_dict())
    target_net.eval()
    # in order to use the current value_net during training for an evaluation task,
    # we first create a second instance of ValueNet in which we will load the
    # state_dicts of the learning value_net before each use
    value_net_bis = DQN(
        dim_input=5,
        dim_embedding=dim_embedding,
        dim_values=dim_values,
        dim_hidden=dim_hidden,
        n_heads=n_heads,
        n_att_layers=n_att_layers,
        n_pool=n_pool,
        K=n_max,
        alpha=alpha,
        p=p,
        weighted=weighted,
    ).to(device)
    # generate the test set
    test_set_generators = load_create_test_set_dqn(
        n_free_min, n_free_max, d_edge_min, d_edge_max, Omega_max, Phi_max,
        Lambda_max, weighted, w_max, directed, size_test_data, path_test_data,
        batch_size, num_workers)
    losses_test = [0] * max_budget

    print("Number of parameters to train = %2d \n" % count_param_NN(value_net))

    for episode in tqdm(range(n_episode)):
        # Sample a random batch of instances from where to begin
        list_instances = generate_random_batch_instance(
            batch_unroll,
            n_free_min,
            n_free_max,
            d_edge_min,
            d_edge_max,
            Omega_max,
            Phi_max,
            Lambda_max,
            Budget_target=max_budget,
            weighted=weighted,
            w_max=w_max,
            directed=directed,
        )
        # Initialize the environment
        env = EnvironmentDQN(list_instances)
        # Init the list of instances for the episode
        current_states = None
        current_actions = None
        current_rewards = None
        cpt_budget = 0
        # Unroll the episode
        while env.Budget >= 1:
            last_states = current_states
            current_states = env.list_instance_torch
            action = sample_action_batch_dqn(value_net, env.player,
                                             env.batch_instance_torch, eps_end,
                                             eps_decay, eps_start, count_steps)
            env.step(action)
            last_actions = current_actions
            current_actions = action
            last_rewards = current_rewards
            current_rewards = env.rewards
            cpt_budget += 1

            # if we have the couples (state, afterstates) available
            if cpt_budget > 1:
                n_visited = 0
                for k in range(batch_unroll):
                    if len(replay_memory_states) < size_memory:
                        replay_memory_states.append(None)
                        replay_memory_afterstates.append(None)
                        replay_memory_actions.append(None)
                        replay_memory_rewards.append(None)
                    replay_memory_states[count_memory %
                                         size_memory] = last_states[k]
                    replay_memory_afterstates[count_memory %
                                              size_memory] = current_states[k]
                    replay_memory_rewards[count_memory %
                                          size_memory] = last_rewards[k]
                    n_free = int(torch.sum(last_states[k].J.eq(0)[:, 0]))
                    replay_memory_actions[
                        count_memory %
                        size_memory] = last_actions[k] - n_visited
                    n_visited += n_free
                    count_memory += 1
            # If we are in the last step, we push to memory the end rewards
            if env.Budget == 0 and cpt_budget > 1:
                n_visited = 0
                for k in range(batch_unroll):
                    if len(replay_memory_states) < size_memory:
                        replay_memory_states.append(None)
                        replay_memory_afterstates.append(None)
                        replay_memory_actions.append(None)
                        replay_memory_rewards.append(None)
                    replay_memory_states[count_memory %
                                         size_memory] = current_states[k]
                    # doesn't matter what we put in the afterstates here
                    replay_memory_afterstates[count_memory %
                                              size_memory] = current_states[k]
                    replay_memory_rewards[count_memory %
                                          size_memory] = current_rewards[k]
                    n_free = int(torch.sum(current_states[k].J.eq(0)[:, 0]))
                    replay_memory_actions[
                        count_memory %
                        size_memory] = current_actions[k] - n_visited
                    n_visited += n_free
                    count_memory += 1

            # if there is enough new instances in memory
            if count_memory > size_memory:
                # create a list of randomly shuffled indices to sample a batch from
                memory_size = len(replay_memory_states)
                id_batch = random.sample(range(memory_size), batch_size)
                # gather the states, afterstates, actions and rewards
                list_states = [replay_memory_states[k] for k in id_batch]
                list_afterstates = [
                    replay_memory_afterstates[k] for k in id_batch
                ]
                list_actions = [replay_memory_actions[k] for k in id_batch]
                list_rewards = [replay_memory_rewards[k] for k in id_batch]
                # recover the actions id in the batch
                n_visited = 0
                list_actions_new = []
                for k in range(len(list_actions)):
                    n_free = int(torch.sum(list_states[k].J.eq(0)[:, 0]))
                    list_actions_new.append(list_actions[k] + n_visited)
                    n_visited += n_free
                # create the tensors
                batch_states = collate_fn(list_states)
                batch_afterstates = collate_fn(list_afterstates)
                batch_actions = torch.tensor(list_actions_new,
                                             dtype=torch.long).view(
                                                 [len(list_actions),
                                                  1]).to(device)
                batch_rewards = torch.tensor(
                    list_rewards,
                    dtype=torch.float).view([len(list_rewards), 1]).to(device)
                # Compute the approximate values
                action_values = value_net(
                    batch_states.G_torch,
                    batch_states.n_nodes,
                    batch_states.Omegas,
                    batch_states.Phis,
                    batch_states.Lambdas,
                    batch_states.Omegas_norm,
                    batch_states.Phis_norm,
                    batch_states.Lambdas_norm,
                    batch_states.J,
                )
                # mask the attacked nodes
                mask_values = batch_states.J.eq(0)[:, 0]
                action_values = action_values[mask_values]
                # Gather the approximate values
                approx_values = action_values.gather(0, batch_actions)
                # compute the masks to apply to the target
                mask_attack = batch_states.next_player.eq(1)[:, 0]
                mask_exact = batch_states.next_player.eq(3)[:, 0]

                # Compute the approximate targets
                with torch.no_grad():
                    target_values = target_net(
                        batch_afterstates.G_torch,
                        batch_afterstates.n_nodes,
                        batch_afterstates.Omegas,
                        batch_afterstates.Phis,
                        batch_afterstates.Lambdas,
                        batch_afterstates.Omegas_norm,
                        batch_afterstates.Phis_norm,
                        batch_afterstates.Lambdas_norm,
                        batch_afterstates.J,
                    ).detach()

                    batch = batch_afterstates.G_torch.batch
                    mask_J = batch_afterstates.J.eq(0)[:, 0]
                    # mask the attacked nodes
                    batch = batch[mask_J]
                    target_values = target_values[mask_J]
                    # Compute the min and max
                    val_min, _ = scatter_min(target_values, batch, dim=0)
                    val_max, _ = scatter_max(target_values, batch, dim=0)
                    # create the target tensor
                    target = val_max
                    target[mask_attack] = val_min[mask_attack]
                    target[mask_exact] = batch_rewards[mask_exact]

                # Init the optimizer
                optimizer.zero_grad()
                # Compute the loss of the batch
                loss = torch.sqrt(torch.mean((approx_values - target)**2))
                # Update the parameters of the Value_net
                loss.backward()
                optimizer.step()
                # Every 20 steps, sync the frozen eval copy value_net_bis
                # and compute the losses on the test set with it
                if count_steps % 20 == 0:
                    value_net_bis.load_state_dict(value_net.state_dict())
                    value_net_bis.eval()
                    losses_test = compute_loss_test_dqn(
                        test_set_generators,
                        list_players,
                        value_net=value_net_bis)
                    for k in range(len(losses_test)):
                        name_loss = 'Loss test budget ' + str(k + 1)
                        writer.add_scalar(name_loss, float(losses_test[k]),
                                          count_steps)
                # Update the tensorboard
                writer.add_scalar("Loss", float(loss), count_steps)
                count_steps += 1

                # Update the target net
                if count_steps % update_target == 0:
                    target_net.load_state_dict(value_net.state_dict())
                    target_net.eval()

                # Saves model every rate_display steps
                if count_steps % rate_display == 0:
                    save_models(date_str, dict_args, value_net, optimizer,
                                count_steps)
                    print(
                        " \n Episode: %2d/%2d" %
                        (episode * batch_size, n_episode),
                        " \n Loss of the current value net: %f" % float(loss),
                        " \n Losses on test set: ",
                        losses_test,
                    )
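The minimax target above is the core scatter trick in this snippet: per-graph minima and maxima over masked node values, selected per graph depending on whose turn it is. A minimal, self-contained sketch of that pattern with made-up tensors (not the repository's API; torch_scatter's functional interface is assumed):

import torch
from torch_scatter import scatter_max, scatter_min

# Hypothetical batch of two graphs: per-node values and graph ids.
node_values = torch.tensor([[1.0], [3.0], [2.0], [5.0], [4.0]])
batch = torch.tensor([0, 0, 0, 1, 1])

# One row per graph: the best case for each player.
val_min, _ = scatter_min(node_values, batch, dim=0)  # [[1.], [4.]]
val_max, _ = scatter_max(node_values, batch, dim=0)  # [[3.], [5.]]

# Graphs where the attacker moves next take the min, the others the max
# (mask_attack is a hypothetical per-graph boolean mask).
mask_attack = torch.tensor([True, False])
target = val_max.clone()
target[mask_attack] = val_min[mask_attack]  # -> [[1.], [5.]]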
Example #26
0
def compute_loss_test_dqn(test_set_generators,
                          list_players,
                          value_net=None,
                          list_experts=None,
                          id_to_test=None):
    """Compute the list of losses of the value_net or the list_of_experts
    over the list of exactly solved datasets that constitutes the test set"""

    list_losses = []
    with torch.no_grad():
        if id_to_test is None:
            iterator = range(len(test_set_generators))
        else:
            iterator = [id_to_test]
        for k in iterator:
            target = []
            val_approx = []
            player = list_players[k]
            # Default to None so the check below also covers the case
            # where neither experts nor a value_net are provided
            target_net = None
            if list_experts is not None:
                try:
                    target_net = list_experts[k]
                except IndexError:
                    target_net = None
            elif value_net is not None:
                target_net = value_net
            if target_net is None:
                list_losses.append(0)
            else:
                for i_batch, batch_instances in enumerate(
                        test_set_generators[k]):
                    batch = batch_instances.G_torch.batch
                    mask_values = batch_instances.J.eq(0)[:, 0]
                    action_values = target_net(
                        batch_instances.G_torch,
                        batch_instances.n_nodes,
                        batch_instances.Omegas,
                        batch_instances.Phis,
                        batch_instances.Lambdas,
                        batch_instances.Omegas_norm,
                        batch_instances.Phis_norm,
                        batch_instances.Lambdas_norm,
                        batch_instances.J,
                    )
                    action_values = action_values[mask_values]
                    batch = batch[mask_values]
                    # if it's the turn of the attacker
                    if player == 1:
                        # we take the argmin
                        values, actions = scatter_min(action_values,
                                                      batch,
                                                      dim=0)
                    else:
                        # we take the argmax
                        values, actions = scatter_max(action_values,
                                                      batch,
                                                      dim=0)
                    val_approx.append(values)
                    target.append(batch_instances.target)
                # Compute the loss
                target = torch.cat(target)
                val_approx = torch.cat(val_approx)
                loss_target_net = float(
                    torch.sqrt(torch.mean(
                        (val_approx[:, 0] - target[:, 0])**2)))
                list_losses.append(loss_target_net)

    return list_losses
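The argmin/argmax branch above relies on scatter_min / scatter_max returning both the reduced values and the positions they came from, which is what lets the loss compare one value per graph against the exact targets. A small sketch with made-up Q-values:

import torch
from torch_scatter import scatter_max

# Hypothetical node-level Q-values for two graphs in one batch.
action_values = torch.tensor([[0.2], [0.9], [0.4], [0.7], [0.1]])
batch = torch.tensor([0, 0, 0, 1, 1])

values, actions = scatter_max(action_values, batch, dim=0)
# values  -> [[0.9], [0.7]]  best value per graph
# actions -> [[1], [3]]      row of that value in action_values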
Example #27
0
src = torch.Tensor([[2, 1, 1, 4, 2], [1, 2, 1, 2, 4]]).float()
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_ones((2, 6))

out = scatter_div(src, index, out=out)

print(out)
# tensor([[1.0000, 1.0000, 0.2500, 0.5000, 0.5000, 1.0000],
# [0.5000, 0.2500, 0.5000, 1.0000, 1.0000, 1.0000]])

# max / min / mean (and other reductions)
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out, argmax = scatter_max(src, index)
print(out, argmax)
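# Expected per-group maxima (slots with no source entry keep the fill
# value, which older torch_scatter releases default to 0 or the dtype
# minimum depending on version; argmax is -1 there):
# row 0: index 2 -> 4, index 3 -> 3, index 4 -> max(2, 1) = 2, index 5 -> 0
# row 1: index 0 -> max(0, 2) = 2, index 1 -> 4, index 2 -> max(1, 3) = 3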

out, argmin = scatter_min(src, index)
print(out, argmin)
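# Expected per-group minima:
# row 0: index 2 -> 4, index 3 -> 3, index 4 -> min(2, 1) = 1, index 5 -> 0
# row 1: index 0 -> min(0, 2) = 0, index 1 -> 4, index 2 -> min(1, 3) = 1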

out = scatter_mean(src, index)
print(out)
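# Expected per-group means (empty slots stay at the default fill of 0):
# row 0: index 2 -> 4, index 3 -> 3, index 4 -> (2 + 1) / 2 = 1.5, index 5 -> 0
# row 1: index 0 -> (0 + 2) / 2 = 1, index 1 -> 4, index 2 -> (1 + 3) / 2 = 2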

out = scatter_mul(src, index)
print(out)
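# Expected per-group products (empty slots keep the multiplicative fill of 1):
# row 0: index 2 -> 4, index 3 -> 3, index 4 -> 2 * 1 = 2, index 5 -> 0
# row 1: index 0 -> 0 * 2 = 0, index 1 -> 4, index 2 -> 1 * 3 = 3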

out = scatter_std(src, index)
print(out)
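# Per-group standard deviations; the exact numbers depend on the version's
# unbiased default and are ill-defined for single-element groups, so they
# are not listed here.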

out = scatter_sub(src, index)
print(out)
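# scatter_sub accumulates negated values into the zero-initialized output,
# i.e. the negative of scatter_add:
# row 0: index 2 -> -4, index 3 -> -3, index 4 -> -(2 + 1) = -3, index 5 -> 0
# row 1: index 0 -> -(0 + 2) = -2, index 1 -> -4, index 2 -> -(1 + 3) = -4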
Example #28
0
    def forward(self, data, tvol = None):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch 
        xinit= x.clone()
        row, col = edge_index
        mask = get_mask(x,edge_index,1).to(x.dtype).unsqueeze(-1)

        x = self.conv1(x, edge_index)
        xpostconv1 = x.detach() 
        x = x*mask
        for conv, bn in zip(self.convs, self.bns):
            if(x.dim()>1):
                x = x + conv(x, edge_index)
                mask = get_mask(mask,edge_index,1).to(x.dtype)
                x = x*mask
                x = bn(x)


        x = self.conv2(x, edge_index)
        mask = get_mask(mask,edge_index,1).to(x.dtype)
        x = x*mask
        xpostconvs = x.detach()
        #
        x = F.leaky_relu(self.lin1(x)) 
        x = x*mask
        x = self.bn2(x)

        xpostlin1 = x.detach()
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.leaky_relu(self.lin2(x)) 
        x = x*mask
        

        xprethresh = x.detach()
        N_size = x.shape[0]    
        batch_max = scatter_max(x, batch, 0, dim_size= N_size)[0]
        batch_max = torch.index_select(batch_max, 0, batch)
        batch_min = scatter_min(x, batch, 0, dim_size= N_size)[0]
        batch_min = torch.index_select(batch_min, 0, batch)
        
        #min-max normalize       
        x = (x-batch_min)/(batch_max+1e-6-batch_min)
        x = x*mask + mask*1e-6
        

        # add the initial dirac back into the set
        x = x + xinit.unsqueeze(-1)
        
        #calculate
        x2 = x.detach()              
        r, c = edge_index
        tv = total_var(x, edge_index, batch)
        deg = degree(r).unsqueeze(-1) 
        conduct_1 = tv
        totalvol = scatter_add(deg.detach()*torch.ones_like(x, device=device), batch, 0)+1e-6
        totalcard = scatter_add(torch.ones_like(x, device=device), batch, 0)+1e-6
        
                
        #receptive field
        recvol_hard = scatter_add(deg*mask.float(), batch, 0, dim_size = batch.max().item()+1)+1e-6 
        reccard_hard = scatter_add(mask.float(), batch, 0, dim_size = batch.max().item()+1)+1e-6 
        
        assert recvol_hard.mean()/totalvol.mean() <=1, "Something went wrong! Receptive field is larger than total volume."
        target = torch.zeros_like(totalvol)
        
        #generate target vol
        if tvol is None:
            feasible_vols = data.recfield_vol/data.total_vol
            target = torch.rand_like(feasible_vols, device=device)*feasible_vols*0.85 + 0.1
            target = target.squeeze(-1)*totalvol.squeeze(-1)
        else:
            target = tvol*totalvol.squeeze(-1)
        a = torch.ones((batch.max().item()+1,1), device = device)
        xfilt = x
                
        
        ###############################################################################
        #iterative rescaling
        counter_no2 = 0
        for iteration in range(self.num_iterations):
            counter_no2 += 1
            keep = (((a[batch]*xfilt)<1).to(x.dtype))

            
            x_k, d_k, d_nk = xfilt*keep*mask, deg*keep*mask, deg*(1-keep)*mask
            
            
            diff = target.unsqueeze(-1) - scatter_add(d_nk, batch, 0)
            dot = scatter_add(x_k*d_k, batch, 0)
            a = diff/(dot+1e-5)
            volcur = (scatter_add(torch.clamp(a[batch]*xfilt,max = 1., min = 0.)*deg,batch,0))

            volcheck = (torch.abs(target - volcur.squeeze(-1))>0.1)
            checki = torch.abs(target.squeeze(-1)-volcur.squeeze(-1))>0.01

            targetcheck = torch.abs(volcur.squeeze(-1) - target)
            
            check = (targetcheck<= self.elasticity*target).to(x.dtype)

            if check.sum() >= batch.max().item() + 1:
                break
        
        probs = torch.clamp(a[batch]*x*mask, max = 1., min = 0.)
        ###############################################################################

            
            
        #collect useful numbers    
        x2 =  ((probs - torch.rand_like(x, device = device))>0).float()         
        vol_1 = scatter_add(probs*deg, batch, 0)+1e-6
        card_1 = scatter_add(probs, batch,0) 
        rec_field = scatter_add(mask, batch, 0)+1e-6
        cut_size = scatter_add(x2, batch, 0)
        tv_hard = total_var(x2, edge_index, batch)
        vol_hard = scatter_add(deg*x2, batch, 0, dim_size = batch.max().item()+1)+1e-6 
        conduct_hard = tv_hard/vol_hard         
        rec_field_ratio = cut_size/rec_field
        rec_field_volratio = vol_hard/recvol_hard
        total_vol_ratio = vol_hard/totalvol
        
        #calculate loss
        expected_cut = scatter_add(probs*deg, batch, 0) - scatter_add((probs[row]*probs[col]), batch[row], 0)   
        loss = expected_cut   


        #return dict 
        retdict = {}
        retdict["output"] = [probs.squeeze(-1),"hist"]   #output
        #retdict["|Expected_vol - Target|"]= [targetcheck, "sequence"] #absolute distance from targetvol
        retdict["Expected_volume"] = [vol_1.mean(),"sequence"] #volume
        retdict["Expected_cardinality"] = [card_1.mean(),"sequence"]
        retdict["volume_hard"] = [vol_hard.mean(),"sequence"] #volume2
        #retdict["cut1"] = [tv.mean(),"sequence"] #cut1
        retdict["cut_hard"] = [tv_hard.mean(),"sequence"] #cut1
        retdict["Average cardinality ratio of receptive field "] = [rec_field_ratio.mean(),"sequence"] 
        retdict["Recfield volume/Total volume"] = [recvol_hard.mean()/totalvol.mean(), "sequence"]
        retdict["Average ratio of receptive field volume"]= [rec_field_volratio.mean(),'sequence']
        retdict["Average ratio of total volume"]= [total_vol_ratio.mean(),'sequence']
        retdict["mask"] = [mask, "aux"] #mask
        retdict["xinit"] = [xinit,"hist"] #layer input diracs
        retdict["xpostlin1"] = [xpostlin1.mean(1),"hist"] #after first linear layer
        retdict["xprethresh"] = [xprethresh.mean(1),"hist"] #pre thresholding activations 195 x 1
        retdict["lossvol"] = [lossvol.mean(),"sequence"] #volume constraint
        retdict["losscard"] = [losscard.mean(),"sequence"] #cardinality constraint
        retdict["loss"] = [loss.mean().squeeze(),"sequence"] #final loss

        return retdict
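Examples #28-#30 all use the same per-graph min-max normalization: reduce to one extremum per graph, then broadcast back to the nodes with index_select. A standalone sketch with made-up tensors (note the examples pass dim_size=N_size, the node count, which over-allocates the scatter output; one slot per graph suffices):

import torch
from torch_scatter import scatter_max, scatter_min

x = torch.tensor([[2.0], [6.0], [4.0], [1.0], [3.0]])  # per-node scores
batch = torch.tensor([0, 0, 0, 1, 1])                  # graph ids

# One min/max per graph, broadcast back to the nodes.
batch_max = torch.index_select(scatter_max(x, batch, dim=0)[0], 0, batch)
batch_min = torch.index_select(scatter_min(x, batch, dim=0)[0], 0, batch)

# Rescale every node to [0, 1] within its own graph; the 1e-6 guards
# against a zero range, as in the examples above.
x_norm = (x - batch_min) / (batch_max + 1e-6 - batch_min)
# x_norm ≈ [[0.], [1.], [0.5], [0.], [1.]]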
Example #29
0
    def forward(self, data, edge_dropout = None, penalty_coefficient = 0.25):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        num_graphs = batch.max().item() + 1
        row, col = edge_index     
        total_num_edges = edge_index.shape[1]
        N_size = x.shape[0]

        
        if edge_dropout is not None:
            edge_index = dropout_adj(edge_index, edge_attr = (torch.ones(edge_index.shape[1], device=device)).long(), p = edge_dropout, force_undirected=True)[0]
            edge_index = add_remaining_self_loops(edge_index, num_nodes = batch.shape[0])[0]
                
        reduced_num_edges = edge_index.shape[1]
        current_edge_percentage = (reduced_num_edges/total_num_edges)
        no_loop_index,_ = remove_self_loops(edge_index)  
        no_loop_row, no_loop_col = no_loop_index

        xinit= x.clone()
        x = x.unsqueeze(-1)
        mask = get_mask(x,edge_index,1).to(x.dtype)
        x = F.leaky_relu(self.conv1(x, edge_index))# +x
        x = x*mask
        x = self.gnorm(x)
        x = self.bn1(x)
        
            
        for conv, bn in zip(self.convs, self.bns):
            if(x.dim()>1):
                x =  x+F.leaky_relu(conv(x, edge_index))
                mask = get_mask(mask,edge_index,1).to(x.dtype)
                x = x*mask
                x = self.gnorm(x)
                x = bn(x)

        xpostconvs = x.detach()
        #
        x = F.leaky_relu(self.lin1(x)) 
        x = x*mask


        xpostlin1 = x.detach()
        x = F.leaky_relu(self.lin2(x)) 
        x = x*mask


        #calculate min and max
        batch_max = scatter_max(x, batch, 0, dim_size= N_size)[0]
        batch_max = torch.index_select(batch_max, 0, batch)        
        batch_min = scatter_min(x, batch, 0, dim_size= N_size)[0]
        batch_min = torch.index_select(batch_min, 0, batch)

        #min-max normalize
        x = (x-batch_min)/(batch_max+1e-6-batch_min)
        probs=x
           
        x2 = x.detach()              
        deg = degree(row).unsqueeze(-1) 
        totalvol = scatter_add(deg.detach()*torch.ones_like(x, device=device), batch, 0)+1e-6
        totalcard = scatter_add(torch.ones_like(x, device=device), batch, 0)+1e-6               
        x2 =  ((x2 - torch.rand_like(x, device = device))>0).float()    
        vol_1 = scatter_add(probs*deg, batch, 0)+1e-6
        card_1 = scatter_add(probs, batch,0)            
        set_size = scatter_add(x2, batch, 0)
        vol_hard = scatter_add(deg*x2, batch, 0, dim_size = batch.max().item()+1)+1e-6 
        total_vol_ratio = vol_hard/totalvol
        
        
        #calculating the terms for the expected distance between clique and graph
        pairwise_prodsums = torch.zeros(num_graphs, device = device)
        for graph in range(num_graphs):
            batch_graph = (batch==graph)
            pairwise_prodsums[graph] = (torch.conv1d(probs[batch_graph].unsqueeze(-1), probs[batch_graph].unsqueeze(-1))).sum()/2
        
        
        ###calculate loss terms
        self_sums = scatter_add((probs*probs), batch, 0, dim_size = num_graphs)
        expected_weight_G = scatter_add(probs[no_loop_row]*probs[no_loop_col], batch[no_loop_row], 0, dim_size = num_graphs)/2.
        expected_clique_weight = (pairwise_prodsums.unsqueeze(-1) - self_sums)/1.
        expected_distance = (expected_clique_weight - expected_weight_G)        
        
        
        ###useful numbers 
        max_set_weight = (scatter_add(torch.ones_like(x)[no_loop_row], batch[no_loop_row], 0, dim_size = num_graphs)/2).squeeze(-1)                
        set_weight = (scatter_add(x2[no_loop_row]*x2[no_loop_col], batch[no_loop_row], 0, dim_size = num_graphs)/2)+1e-6
        clique_edges_hard = (set_size*(set_size-1)/2) +1e-6
        clique_dist_hard = set_weight/clique_edges_hard
        clique_check = ((clique_edges_hard != clique_edges_hard))  # NaN check (x != x is True only for NaN)
        setedge_check = ((set_weight != set_weight))  # NaN check
        
        assert ((clique_dist_hard>=1.1).sum())<=1e-6, "Invalid set vol/clique vol ratio."

        ###calculate loss
        expected_loss = (penalty_coefficient)*expected_distance*0.5 - 0.5*expected_weight_G  
        

        loss = expected_loss


        retdict = {}
        
        retdict["output"] = [probs.squeeze(-1),"hist"]   #output
        retdict["Expected_cardinality"] = [card_1.mean(),"sequence"]
        retdict["Expected_cardinality_hist"] = [card_1,"hist"]
        retdict["losses histogram"] = [loss.squeeze(-1),"hist"]
        retdict["Set sizes"] = [set_size.squeeze(-1),"hist"]
        retdict["volume_hard"] = [vol_hard.mean(),"aux"] #volume2
        retdict["cardinality_hard"] = [set_size[0],"sequence"] #volumeq
        retdict["Expected weight(G)"]= [expected_weight_G.mean(), "sequence"]
        retdict["Expected maximum weight"] = [expected_clique_weight.mean(),"sequence"]
        retdict["Expected distance"]= [expected_distance.mean(), "sequence"]
        retdict["Currvol/Cliquevol"] = [clique_dist_hard.mean(),'sequence']
        retdict["Currvol/Cliquevol all graphs in batch"] = [clique_dist_hard.squeeze(-1),'hist']
        retdict["Average ratio of total volume"]= [total_vol_ratio.mean(),'sequence']
        retdict["cardinalities"] = [cardinalities.squeeze(-1),"hist"]
        retdict["Current edge percentage"] = [torch.tensor(current_edge_percentage),'sequence']
        retdict["loss"] = [loss.mean().squeeze(),"sequence"] #final loss

        return retdict
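The per-graph Python loop over torch.conv1d above materializes an n x n outer product just to sum it. Since sum_i sum_j p_i p_j = (sum_i p_i)^2, the same quantity can be computed without the loop; a hedged sketch (assumes probs has shape [N, 1] as in these examples; the helper name is hypothetical):

import torch
from torch_scatter import scatter_add

def pairwise_prodsums_vectorized(probs, batch, num_graphs):
    # (sum_i p_i)^2 equals the full double sum over all ordered pairs,
    # which is exactly what the conv1d outer product computes before
    # the division by 2.
    sums = scatter_add(probs, batch, dim=0, dim_size=num_graphs)  # [B, 1]
    return (sums * sums).squeeze(-1) / 2                          # [B]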
Example #30
0
    def forward(self, data, edge_dropout=None, penalty_coefficient=0.25):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        num_graphs = batch.max().item() + 1
        row, col = edge_index
        total_num_edges = edge_index.shape[1]
        N_size = x.shape[0]

        if edge_dropout is not None:
            edge_index = dropout_adj(
                edge_index,
                edge_attr=(torch.ones(edge_index.shape[1],
                                      device=device)).long(),
                p=edge_dropout,
                force_undirected=True)[0]
            edge_index = add_remaining_self_loops(edge_index,
                                                  num_nodes=batch.shape[0])[0]

        reduced_num_edges = edge_index.shape[1]
        current_edge_percentage = (reduced_num_edges / total_num_edges)
        no_loop_index, _ = remove_self_loops(edge_index)
        no_loop_row, no_loop_col = no_loop_index

        xinit = x.clone()
        x = x.unsqueeze(-1)
        mask = get_mask(x, edge_index, 1).to(x.dtype)
        x = F.leaky_relu(self.conv1(x, edge_index))  # +x
        x = x * mask
        x = self.gnorm(x)
        x = self.bn1(x)

        for conv, bn in zip(self.convs, self.bns):
            if (x.dim() > 1):
                x = x + F.leaky_relu(conv(x, edge_index))
                mask = get_mask(mask, edge_index, 1).to(x.dtype)
                x = x * mask
                x = self.gnorm(x)
                x = bn(x)

        xpostconvs = x.detach()
        #
        x = F.leaky_relu(self.lin1(x))
        x = x * mask

        xpostlin1 = x.detach()
        x = F.leaky_relu(self.lin2(x))
        x = x * mask

        #calculate min and max
        batch_max = scatter_max(x, batch, 0, dim_size=N_size)[0]
        batch_max = torch.index_select(batch_max, 0, batch)
        batch_min = scatter_min(x, batch, 0, dim_size=N_size)[0]
        batch_min = torch.index_select(batch_min, 0, batch)

        #min-max normalize
        x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
        probs = x

        #calculating the terms for the expected distance between clique and graph
        pairwise_prodsums = torch.zeros(num_graphs, device=device)
        for graph in range(num_graphs):
            batch_graph = (batch == graph)
            pairwise_prodsums[graph] = (torch.conv1d(
                probs[batch_graph].unsqueeze(-1),
                probs[batch_graph].unsqueeze(-1))).sum() / 2

        ###calculate loss terms
        self_sums = scatter_add((probs * probs), batch, 0, dim_size=num_graphs)
        expected_weight_G = scatter_add(
            probs[no_loop_row] * probs[no_loop_col],
            batch[no_loop_row],
            0,
            dim_size=num_graphs) / 2.
        expected_clique_weight = (pairwise_prodsums.unsqueeze(-1) -
                                  self_sums) / 1.
        expected_distance = (expected_clique_weight - expected_weight_G)

        ###calculate loss
        expected_loss = (penalty_coefficient
                         ) * expected_distance * 0.5 - 0.5 * expected_weight_G

        loss = expected_loss

        retdict = {}

        retdict["output"] = [probs.squeeze(-1), "hist"]  #output
        retdict["losses histogram"] = [loss.squeeze(-1), "hist"]
        retdict["Expected weight(G)"] = [expected_weight_G.mean(), "sequence"]
        retdict["Expected maximum weight"] = [
            expected_clique_weight.mean(), "sequence"
        ]
        retdict["Expected distance"] = [expected_distance.mean(), "sequence"]
        retdict["loss"] = [loss.mean().squeeze(), "sequence"]  #final loss

        return retdict