Example #1
    def _get_single_double_config(self, nocc, nvirt):
        """Get the confs of the single + double

        Args:
            nelec (int): number of electrons in the active space
            norb (int) : number of orbitals in the active space
        """

        _gs_up = list(range(self.nup))
        _gs_down = list(range(self.ndown))
        cup, cdown = self._get_single_config(nocc, nvirt)
        cup = cup.tolist()
        cdown = cdown.tolist()

        idx_occ_up = list(
            range(self.nup - 1, self.nup - 1 - nocc, -1))
        idx_vrt_up = list(range(self.nup, self.nup + nvirt, 1))

        idx_occ_down = list(range(
            self.ndown - 1, self.ndown - 1 - nocc, -1))
        idx_vrt_down = list(range(self.ndown, self.ndown + nvirt, 1))

        # ground, single and double with 1 elec excited per spin
        for iocc_up in idx_occ_up:
            for ivirt_up in idx_vrt_up:

                for iocc_down in idx_occ_down:
                    for ivirt_down in idx_vrt_down:

                        _xt_up = self._create_excitation(
                            _gs_up.copy(), iocc_up, ivirt_up)
                        _xt_down = self._create_excitation(
                            _gs_down.copy(), iocc_down, ivirt_down)
                        cup, cdown = self._append_excitations(
                            cup, cdown, _xt_up, _xt_down)

        # double with 2 elec excited on spin up
        for occ1, occ2 in torch.combinations(torch.as_tensor(idx_occ_up), r=2):
            for vrt1, vrt2 in torch.combinations(torch.as_tensor(idx_vrt_up), r=2):
                _xt_up = self._create_excitation(
                    _gs_up.copy(), occ1, vrt2)
                _xt_up = self._create_excitation(_xt_up, occ2, vrt1)
                cup, cdown = self._append_excitations(
                    cup, cdown, _xt_up, _gs_down)

        # double with 2 elec excited on spin down
        for occ1, occ2 in torch.combinations(torch.as_tensor(idx_occ_down), r=2):
            for vrt1, vrt2 in torch.combinations(torch.as_tensor(idx_vrt_down), r=2):

                _xt_down = self._create_excitation(
                    _gs_down.copy(), occ1, vrt2)
                _xt_down = self._create_excitation(
                    _xt_down, occ2, vrt1)
                cup, cdown = self._append_excitations(
                    cup, cdown, _gs_up, _xt_down)

        return (torch.LongTensor(cup), torch.LongTensor(cdown))
Example #2
    def similarity_computing_inner(self, doc_output, segment_idx):
        similarities = []
        cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        seg_outputs = []
        index = 0
        doc_output = F.softmax(doc_output, dim=1)  # explicit dim avoids the implicit-dim deprecation warning
        for i in range(len(segment_idx)):
            if i == 0:
                seg_output = doc_output[0:segment_idx[i] + 1, :]
            elif i == len(segment_idx) - 1:
                seg_output = doc_output[segment_idx[i - 1] + 1:, :]
            else:
                seg_output = doc_output[segment_idx[i - 1] + 1:segment_idx[i] +
                                        1, :]
            seg_outputs.append(seg_output)

        for i in range(len(seg_outputs)):
            sent_idx = maybe_cuda(
                torch.LongTensor([k for k in range(seg_outputs[i].size()[0])]))
            if seg_outputs[i].size()[0] > 1:
                pairs = torch.combinations(sent_idx)
                pair_sims = []
                for p in pairs:
                    pair_sims.append(
                        cos(seg_outputs[i][p[0], :].unsqueeze(0),
                            seg_outputs[i][p[1], :].unsqueeze(0)))
                similarities.append(sum(pair_sims) / len(pair_sims))
            else:
                continue

        return Variable(maybe_cuda(torch.Tensor(similarities)))
Example #3
    def __init__(self, hparams):
        super(HeteroConv, self).__init__()

        self.hparams = hparams

        concatenation_factor = (
            3 if (self.hparams["aggregation"] in ["sum_max", "mean_max", "mean_sum"]) else 2
        )

        # Make module list
        self.node_encoders = nn.ModuleList([
            make_mlp(
                concatenation_factor * hparams["hidden"],
                [hparams["hidden"]] * hparams["nb_node_layer"],
                output_activation=hparams["output_activation"],
                hidden_activation=hparams["hidden_activation"],
                layer_norm=hparams["layernorm"],
                batch_norm=hparams["batchnorm"],
            ) for _ in hparams["model_ids"]
        ])

        # Make edge encoder combos (this is an N-choose-2 with replacement situation)
        self.all_combos = torch.combinations(torch.arange(len(self.hparams["model_ids"])), r=2, with_replacement=True)
    
        self.edge_encoders = nn.ModuleList([
            make_mlp(
                3 * hparams["hidden"],
                [hparams["hidden"]] * hparams["nb_edge_layer"],
                layer_norm=hparams["layernorm"],
                batch_norm=hparams["batchnorm"],
                output_activation=hparams["output_activation"],
                hidden_activation=hparams["hidden_activation"],
            ) for _ in self.all_combos
        ])
Example #4
def multiview_covariance_matrix(dims,
                                constructor,
                                options=None,
                                symmetric=True,
                                **kwargs):
    num_views = dims.size(0)
    indices = torch.arange(num_views, dtype=torch.long)
    diag_indices = indices.repeat(2, 1).t_()
    upper_indices = torch.combinations(indices, r=2)
    diag_and_upper_indices = torch.cat([diag_indices, upper_indices])
    sum_dimensions = dims.sum()

    # avoid mutating a shared default dict across calls
    options = {} if options is None else dict(options)
    options.update(kwargs)
    out = torch.empty(sum_dimensions, sum_dimensions, **options)
    for j, r in diag_and_upper_indices:
        j_indices = slice(dims[:j].sum(), dims[:j + 1].sum())
        r_indices = slice(dims[:r].sum(), dims[:r + 1].sum())
        jr_cross_covariance = constructor(j, r)
        if not torch.is_tensor(
                jr_cross_covariance) or jr_cross_covariance.nelement() == 1:
            jr_cross_covariance = torch.empty(
                dims[j], dims[r], **options).fill_(jr_cross_covariance)
        elif jr_cross_covariance.nelement() == 0:
            jr_cross_covariance = torch.zeros(dims[j], dims[r], **options)
        out[j_indices, r_indices] = jr_cross_covariance
        if j != r:
            out[r_indices, j_indices] = jr_cross_covariance.t(
            ) if symmetric else constructor(r, j)
    return out
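As a quick illustration, a minimal call with a toy constructor (the constructor below is a hypothetical stand-in; any callable returning a (dims[j], dims[r]) block, a scalar, or an empty tensor works):

dims = torch.tensor([2, 3])

def toy_constructor(j, r):
    # hypothetical block builder: constant cross-covariance blocks
    return torch.ones(int(dims[j]), int(dims[r])) * float(j + r + 1)

cov = multiview_covariance_matrix(dims, toy_constructor)
print(cov.shape)  # torch.Size([5, 5]); upper blocks are mirrored below the diagonal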
Example #5
    def __init__(self, hparams):
        super(HeteroEncoder, self).__init__()

        self.hparams = hparams

        # Make module list
        self.node_encoders = nn.ModuleList([
            make_mlp(
                model["num_features"],
                [hparams["hidden"]] * hparams["nb_node_layer"],
                output_activation=hparams["output_activation"],
                hidden_activation=hparams["hidden_activation"],
                layer_norm=hparams["layernorm"],
                batch_norm=hparams["batchnorm"],
            ) for model in hparams["model_ids"]
        ])

        # Make edge encoder combos (this is an N-choose-2 with replacement situation)
        self.all_combos = torch.combinations(
            torch.arange(len(self.hparams["model_ids"])),
            r=2,
            with_replacement=True)

        self.edge_encoders = nn.ModuleList([
            make_mlp(
                2 * hparams["hidden"],
                [hparams["hidden"]] * hparams["nb_edge_layer"],
                layer_norm=hparams["layernorm"],
                batch_norm=hparams["batchnorm"],
                output_activation=hparams["output_activation"],
                hidden_activation=hparams["hidden_activation"],
            ) for _ in self.all_combos
        ])
Example #6
    def get_triplets(self, embeddings, labels):

        # calculate distances between embedded vectors
        distance_matrix = pdist(embeddings)
        # logging.debug(distance_matrix)

        # triplets..
        triplets = []

        # get unique labels in this batch
        unique_labels = labels.unique()
        for anchor_label in unique_labels:

            # anchor indices
            anchor_mask = labels == anchor_label
            anchor_indices = (anchor_mask == 1).nonzero()[:, 0]

            # skip this label if it has fewer than two samples (no positive pair)
            if anchor_indices.size(0) < 2:
                continue

            # negative indices
            negative_indices = (anchor_mask == 0).nonzero()[:, 0]

            # generate anchor-positive pairs using combination
            anchor_positive_pairs = torch.combinations(anchor_indices)

            # for each pair
            for ap_pair in anchor_positive_pairs:
                anchor_idx = ap_pair[0]
                positive_idx = ap_pair[1]

                # distance between anchor-positive
                ap_distance = distance_matrix[anchor_idx, positive_idx]

                # distanceS between anchor-negativeS
                an_distances = distance_matrix[anchor_idx][negative_indices]

                # loss_valueS = ap - anS + margin
                loss_values = ap_distance - an_distances + self.margin

                # for each fn in func_list (ordered by high priority)
                negative_idx = None
                for fn in self.func_list:
                    negative_idx = fn(loss_values, self.margin)
                    if negative_idx is not None:
                        break

                # if no negative_idx is found, just pick a random one
                if negative_idx is None:
                    logging.debug("No negative idx found. Picking random..")
                    choice = torch.randint(0, negative_indices.size(0), (1,))
                    negative_idx = negative_indices[choice].view(-1).squeeze(0)

                triplet = torch.stack([anchor_idx, positive_idx, negative_idx], 0)
                triplets.append(triplet)

        triplet_tensor = torch.stack(triplets)
        logging.debug(triplet_tensor.size())
        return triplet_tensor
Example #7
    def _create_pixel_pairs(self, input):
        batch_size = input.shape[0]

        # Divide image into nxn blocks
        n = 16
        unfold = torch.nn.Unfold(kernel_size=(n, n), stride=n)
        input_blocks = unfold(input)  # [B, 256, 196]
        num_patches = input_blocks.shape[-1]

        flattened_blocks = input_blocks.view(-1)
        total_patches = batch_size * input_blocks.shape[-1]

        # Randomly sample 1 pixel from each block
        random_pixel_idxs = torch.Tensor([
            256 * i + torch.randint(0, 256, size=(1, ))
            for i in range(total_patches)
        ])
        pixel_samples = flattened_blocks[random_pixel_idxs.type(torch.long)]
        pixel_samples_batched = pixel_samples.view(batch_size, num_patches)

        # Create combinations of each pixel pair index
        pixel_pairs = torch.stack([
            torch.combinations(pixel_samples_batched[i])
            for i in range(batch_size)
        ])
        pixel_pairs = pixel_pairs.view(-1, 2)

        return pixel_pairs[..., 0], pixel_pairs[..., 1]
Example #8
def erdos_renyi_graph(num_nodes, edge_prob, directed=False):
    r"""Returns the :obj:`edge_index` of a random Erdos-Renyi graph.

    Args:
        num_nodes (int): The number of nodes.
        edge_prob (float): Probability of an edge.
        directed (bool, optional): If set to :obj:`True`, will return a
            directed graph. (default: :obj:`False`)
    """

    if directed:
        idx = torch.arange((num_nodes - 1) * num_nodes)
        idx = idx.view(num_nodes - 1, num_nodes)
        idx = idx + torch.arange(1, num_nodes).view(-1, 1)
        idx = idx.view(-1)
    else:
        idx = torch.combinations(torch.arange(num_nodes))

    # Filter edges.
    mask = torch.rand(idx.size(0)) < edge_prob
    idx = idx[mask]

    if directed:
        row = idx // num_nodes
        col = idx % num_nodes
        edge_index = torch.stack([row, col], dim=0)
    else:
        edge_index = to_undirected(idx.t(), num_nodes)

    return edge_index
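A minimal usage sketch (assuming to_undirected is importable as in the source module, e.g. from torch_geometric.utils):

edge_index = erdos_renyi_graph(num_nodes=5, edge_prob=0.3)
print(edge_index.shape)  # torch.Size([2, E]) for a random number of edges E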
Example #9
    def __init__(self, hparams):
        """
        Initialise the Lightning Module that can scan over different filter training regimes.
        """
        super().__init__(hparams)

        self.all_combos = torch.combinations(
            torch.arange(len(self.hparams["model_ids"])),
            r=2,
            with_replacement=True)

        # Still need this??
        self.vol_matrix = get_vol_matrix(
            self.all_combos,
            [model_id["volume_ids"] for model_id in self.hparams["model_ids"]])

        self.edge_encoders = nn.ModuleList([
            make_mlp(
                hparams["model_ids"][combo[0]]["num_features"] +
                hparams["model_ids"][combo[1]]["num_features"] +
                2 * hparams["cell_channels"],
                [
                    hparams["hidden"] // (2**i)
                    for i in range(hparams["nb_layer"])
                ] + [1],
                layer_norm=hparams["layernorm"],
                batch_norm=hparams["batchnorm"],
                output_activation=None,
                hidden_activation=hparams["hidden_activation"],
            ) for combo in self.all_combos
        ])
Example #10
def IP_loss(phase, mask):
    osci_num = phase.shape[1]

    phase_sin = torch.sin(phase)
    phase_cos = torch.cos(phase)

    masked_sin = phase_sin.unsqueeze(1) * mask
    masked_cos = phase_cos.unsqueeze(1) * mask
    # mask.shape=(batch, groups, N)

    product1 = torch.matmul(masked_sin.unsqueeze(3), masked_sin.unsqueeze(2)) +\
               torch.matmul(masked_cos.unsqueeze(3), masked_cos.unsqueeze(2))

    sync_loss = (product1.sum(3).sum(2) - osci_num) / (osci_num**2 - osci_num)

    product2 = torch.matmul(masked_sin.unsqueeze(2).unsqueeze(4), masked_sin.unsqueeze(1).unsqueeze(3)) +\
               torch.matmul(masked_cos.unsqueeze(2).unsqueeze(4), masked_cos.unsqueeze(1).unsqueeze(3))

    product2 = product2.sum(2, keepdim=True) - product1.unsqueeze(2)
    desync_loss = product2.squeeze().sum(3).sum(2).sum(
        1) / 2 / torch.combinations(torch.arange(osci_num)).shape[0]

    sync_loss_mean = torch.exp(-sync_loss.mean())
    desync_loss_mean = torch.exp(desync_loss.mean())

    tot_loss_mean = 0.1 * sync_loss_mean + desync_loss_mean
    return tot_loss_mean, sync_loss_mean, desync_loss_mean
Example #11
    def _combinations(self, tensor, dim=0):
        n = tensor.shape[dim]
        if n == 0:
            return tensor, tensor
        r = torch.arange(n, dtype=torch.long, device=tensor.device)
        index1, index2 = torch.combinations(r).unbind(-1)
        return tensor.index_select(dim, index1), \
            tensor.index_select(dim, index2)
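For intuition, the indexing trick at the core of _combinations, shown standalone: torch.combinations on an arange yields all index pairs, and unbind(-1) splits them into two parallel index vectors:

import torch

index1, index2 = torch.combinations(torch.arange(3)).unbind(-1)
print(index1)  # tensor([0, 0, 1])
print(index2)  # tensor([1, 2, 2])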
Example #12
def triu_index(num_species):
    species = torch.arange(num_species)
    species1, species2 = torch.combinations(species, r=2, with_replacement=True).unbind(-1)
    pair_index = torch.arange(species1.shape[0])
    ret = torch.zeros(num_species, num_species, dtype=torch.long)
    ret[species1, species2] = pair_index
    ret[species2, species1] = pair_index
    return ret
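triu_index maps an unordered species pair to the row index of that pair in the with_replacement combination list, e.g.:

print(triu_index(3))
# tensor([[0, 1, 2],
#         [1, 3, 4],
#         [2, 4, 5]])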
Example #13
    def __init__(self,
                 points_per_patch=256,
                 patch_radius=[0.05],
                 dim_pts=3,
                 num_gpts=128,
                 dim_gpts=1,
                 hyps=64,
                 inlier_params=[0.01, 0.5],
                 use_mask=False,
                 sym_op='max',
                 ith=0,
                 use_point_stn=True,
                 use_feat_stn=True,
                 decoder='PointPredNet',
                 device=0,
                 normal_loss='ms_euclidean',
                 seed=3627474):
        '''
        Constructor.

        hyps -- number of plane hypotheses sampled for each patch
        inlier_params -- packs the scoring parameters:
            inlier_thresh -- threshold used in the soft inlier count
            inlier_beta -- scaling factor within the sigmoid of the soft inlier count
            inlier_alpha -- scaling factor for the soft inlier scores (controls the peakiness of the hypothesis distribution)
        '''
        super(WDSAC, self).__init__()
        self.hyps = hyps
        self.num_gpts = num_gpts

        if len(inlier_params) == 2:
            self.scorer = Scorer_dist(inlier_params[0])
        else:
            self.scorer = Scorer_inlier_count(inlier_params[0],
                                              inlier_params[1])
        self.inlier_alpha = inlier_params[-1]

        self.normal_loss = normal_loss
        self.use_point_stn = use_point_stn
        self.device = device

        torch.manual_seed(seed)
        gpts_idx = torch.arange(num_gpts)
        self.idx_combi = torch.combinations(gpts_idx, 3).to(device)
        #self.plane_weight = torch.ones(self.idx_combi.size(0), dtype=torch.float).to(device)

        self.wpcp = PCPNet(num_pts=points_per_patch,
                           dim_pts=dim_pts,
                           num_gpts=points_per_patch,
                           dim_gpts=1,
                           use_point_stn=use_point_stn,
                           use_feat_stn=use_feat_stn,
                           device=device,
                           b_pred=True,
                           use_mask=False,
                           sym_op=sym_op,
                           ith=0)
Example #14
def main(args):
    res = []
       
    combos = torch.combinations(torch.arange(0, args.n_modules), args.n_active).numpy().tolist()
    for i, subset in enumerate(combos):
        if args.head_id in subset:
            res.append(str(i))

    print("{}".format(" ".join(res)))
Example #15
	def take_all_pwrs(self, vec, pwr):
		#todo: vectorize (kinda)
		combins=torch.combinations(vec, r=pwr, with_replacement=True)
		out=torch.ones(combins.size()[0]).to(device).to(torch.float)
		for i in torch.t(combins).to(device).to(torch.float):
			out *= i
		if pwr == 1:
			return out
		else:
			return torch.cat((out,self.take_all_pwrs(vec, pwr-1)))
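The recursion enumerates all monomials of vec from degree pwr down to 1. The core step, shown standalone (without the module-level device the method assumes):

import torch

vec = torch.tensor([2.0, 3.0])  # [a, b]
combins = torch.combinations(vec, r=2, with_replacement=True)
print(combins.prod(dim=1))  # tensor([4., 6., 9.]) == [a*a, a*b, b*b]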
Example #16
def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
    """Compute pairs of atoms that are neighbors
    Copyright 2018- Xiang Gao and other ANI developers
    (https://github.com/aiqm/torchani/blob/master/torchani/aev.py)

    Arguments:
        padding_mask (:class:`torch.Tensor`): boolean tensor of shape
            (molecules, atoms) for padding mask. 1 == is padding.
        coordinates (:class:`torch.Tensor`): tensor of shape
            (molecules, atoms, 3) for atom coordinates.
        cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three vectors
            defining unit cell: tensor([[x1, y1, z1], [x2, y2, z2], [x3, y3, z3]])
        cutoff (float): the cutoff inside which atoms are considered pairs
        shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts
    """
    # type: (Tensor, Tensor, Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor]

    coordinates = coordinates.detach()
    cell = cell.detach()
    num_atoms = padding_mask.shape[0]
    all_atoms = torch.arange(num_atoms, device=cell.device)

    # Step 2: center cell
    p1_center, p2_center = torch.combinations(all_atoms).unbind(-1)
    shifts_center = shifts.new_zeros(p1_center.shape[0], 3)

    # Step 3: cells with shifts
    # shape convention (shift index, molecule index, atom index, 3)
    num_shifts = shifts.shape[0]
    all_shifts = torch.arange(num_shifts, device=cell.device)
    shift_index, p1, p2 = torch.cartesian_prod(all_shifts, all_atoms, all_atoms).unbind(
        -1
    )
    shifts_outside = shifts.index_select(0, shift_index)

    # Step 4: combine results for all cells
    shifts_all = torch.cat([shifts_center, shifts_outside])
    p1_all = torch.cat([p1_center, p1])
    p2_all = torch.cat([p2_center, p2])

    shift_values = torch.mm(shifts_all.to(cell.dtype), cell)

    # step 5, compute distances, and find all pairs within cutoff
    distances = (coordinates[p1_all] - coordinates[p2_all] + shift_values).norm(2, -1)

    padding_mask = (padding_mask[p1_all]) | (padding_mask[p2_all])
    distances.masked_fill_(padding_mask, math.inf)
    in_cutoff = torch.nonzero(distances < cutoff, as_tuple=False)
    pair_index = in_cutoff.squeeze()
    atom_index1 = p1_all[pair_index]
    atom_index2 = p2_all[pair_index]
    shifts = shifts_all.index_select(0, pair_index)
    return atom_index1, atom_index2, shifts
Example #17
def get_iterator(N, get_all=True):
    if get_all:
        # every possible pair
        total = (N * (N - 1) // 2)
        iters = torch.combinations(torch.arange(N), 2).to(DEVICE).long()
    else:
        total = 10000
        iters = torch.from_numpy(np.random.randint(
            N, size=(total, 2))).to(DEVICE).long()

    return total, iters
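A quick sanity check (assuming the module-level DEVICE is set, e.g. DEVICE = torch.device('cpu')):

total, iters = get_iterator(4)
print(total)        # 6 == 4 * 3 // 2
print(iters.shape)  # torch.Size([6, 2])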
Example #18
def test_sample_hyp(tries_in=9999):
    gpu_idx = -3
    device = torch.device("cpu" if gpu_idx < 0 else "cuda:%d" % gpu_idx)
    batchsize = 2

    rng = np.random.RandomState(3627474)
    num_gpts = 256
    pts = torch.rand(batchsize, num_gpts, 3)
    hyps = 32
    gpts_idx = torch.arange(num_gpts)
    combi = torch.combinations(gpts_idx, 3).to(device)

    t = time.process_time()
    tries = tries_in
    while tries:
        tmp = torch.randint(0, combi.size(0), (hyps * batchsize, ))
        index = combi[tmp, :].view(batchsize, hyps * 3, 1)
        if tries:
            tries -= 1
        else:
            break
    print(time.process_time() - t)

    plane_weight = torch.ones(combi.size(0), dtype=torch.float).to(device)
    t = time.process_time()
    tries = tries_in
    while tries:
        tmp = torch.multinomial(plane_weight,
                                hyps * batchsize,
                                replacement=False)
        index = combi[tmp, :].view(batchsize, hyps * 3, 1)
        if tries:
            tries -= 1
        else:
            break
    print(time.process_time() - t)

    t = time.process_time()
    tries = tries_in
    while tries:
        index = torch.stack([
            torch.from_numpy(
                np.stack([
                    rng.choice(pts.size(1), 3, replace=False)
                    for _ in range(hyps)
                ]).reshape(-1))  # np.stack needs a list, not a generator
            for _ in range(batchsize)
        ])
        index = index.view(batchsize, hyps * 3, 1)
        if tries:
            tries -= 1
        else:
            break
    print(time.process_time() - t)
Example #19
def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor,
                   shifts: Tensor,
                   cutoff: float) -> Tuple[Tensor, Tensor, Tensor]:
    """Compute pairs of atoms that are neighbors

    Arguments:
        padding_mask (:class:`torch.Tensor`): boolean tensor of shape
            (molecules, atoms) for padding mask. 1 == is padding.
        coordinates (:class:`torch.Tensor`): tensor of shape
            (molecules, atoms, 3) for atom coordinates.
        cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three vectors
            defining unit cell: tensor([[x1, y1, z1], [x2, y2, z2], [x3, y3, z3]])
        cutoff (float): the cutoff inside which atoms are considered pairs
        shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts
    """
    coordinates = coordinates.detach()
    cell = cell.detach()
    num_atoms = padding_mask.shape[1]
    all_atoms = torch.arange(num_atoms, device=cell.device)

    # Step 2: center cell
    p1_center, p2_center = torch.combinations(all_atoms).unbind(-1)
    shifts_center = shifts.new_zeros((p1_center.shape[0], 3))

    # Step 3: cells with shifts
    # shape convention (shift index, molecule index, atom index, 3)
    num_shifts = shifts.shape[0]
    all_shifts = torch.arange(num_shifts, device=cell.device)
    shift_index, p1, p2 = torch.cartesian_prod(all_shifts, all_atoms,
                                               all_atoms).unbind(-1)
    shifts_outside = shifts.index_select(0, shift_index)

    # Step 4: combine results for all cells
    shifts_all = torch.cat([shifts_center, shifts_outside])
    p1_all = torch.cat([p1_center, p1])
    p2_all = torch.cat([p2_center, p2])
    shift_values = shifts_all.to(cell.dtype) @ cell

    # step 5, compute distances, and find all pairs within cutoff
    distances = (coordinates.index_select(1, p1_all) -
                 coordinates.index_select(1, p2_all) + shift_values).norm(
                     2, -1)
    padding_mask = (padding_mask.index_select(
        1, p1_all)) | (padding_mask.index_select(1, p2_all))
    distances.masked_fill_(padding_mask, math.inf)
    in_cutoff = (distances <= cutoff).nonzero()
    molecule_index, pair_index = in_cutoff.unbind(1)
    molecule_index *= num_atoms
    atom_index1 = p1_all[pair_index]
    atom_index2 = p2_all[pair_index]
    shifts = shifts_all.index_select(0, pair_index)
    return molecule_index + atom_index1, molecule_index + atom_index2, shifts
Example #20
    def forward(self, vertex_pred, vertex_gt, c_pred, points, target, model_points, model_kp, idx, target_r, target_t):
        vertex_loss = smooth_l1_loss(vertex_pred.view(1, self.num_pt_mesh, -1), vertex_gt.view(1, self.num_pt_mesh, -1))

        kp_set = vertex_pred + points.repeat(1, 1, 9).view(1, points.shape[1], 9, 3)
        confidence = c_pred / (0.00001 + torch.sum(c_pred, 1))
        points_pred = torch.sum(confidence * kp_set, 1)

        all_index = torch.combinations(torch.arange(9), 3)
        all_r, all_t = batch_least_square(model_kp.squeeze()[all_index, :], points_pred.squeeze()[all_index, :], torch.ones([all_index.shape[0], 3]).cuda())
        all_e = calculate_error(all_r, all_t, model_points, points)
        e = all_e.unsqueeze(0).unsqueeze(2)
        w = torch.softmax(1 / e, 1).squeeze().unsqueeze(1)
        all_qua = tgm.rotation_matrix_to_quaternion(torch.cat((all_r, torch.tensor([0., 0., 1.]).cuda().unsqueeze(1).repeat(all_index.shape[0], 1, 1)), dim=2))
        pred_qua = torch.sum(w * all_qua, 0)
        pred_r = pred_qua.view(1, 1, -1)
        bs, num_p, _ = pred_r.size()
        pred_r = pred_r / (torch.norm(pred_r, dim=2).view(1, 1, 1))
        pred_r = torch.cat(((1.0 - 2.0 * (pred_r[:, :, 2] ** 2 + pred_r[:, :, 3] ** 2)).view(bs, num_p, 1), \
                            (2.0 * pred_r[:, :, 1] * pred_r[:, :, 2] - 2.0 * pred_r[:, :, 0] * pred_r[:, :, 3]).view(bs, num_p,1), \
                            (2.0 * pred_r[:, :, 0] * pred_r[:, :, 2] + 2.0 * pred_r[:, :, 1] * pred_r[:, :, 3]).view(bs, num_p, 1), \
                            (2.0 * pred_r[:, :, 1] * pred_r[:, :, 2] + 2.0 * pred_r[:, :, 3] * pred_r[:, :, 0]).view(bs, num_p, 1), \
                            (1.0 - 2.0 * (pred_r[:, :, 1] ** 2 + pred_r[:, :, 3] ** 2)).view(bs, num_p, 1), \
                            (-2.0 * pred_r[:, :, 0] * pred_r[:, :, 1] + 2.0 * pred_r[:, :, 2] * pred_r[:, :, 3]).view(bs, num_p, 1), \
                            (-2.0 * pred_r[:, :, 0] * pred_r[:, :, 2] + 2.0 * pred_r[:, :, 1] * pred_r[:, :, 3]).view(bs, num_p, 1), \
                            (2.0 * pred_r[:, :, 0] * pred_r[:, :, 1] + 2.0 * pred_r[:, :, 2] * pred_r[:, :, 3]).view(bs, num_p, 1), \
                            (1.0 - 2.0 * (pred_r[:, :, 1] ** 2 + pred_r[:, :, 2] ** 2)).view(bs, num_p, 1)), dim=2).contiguous().view(bs * num_p, 3, 3)
        pred_r = pred_r.squeeze()
        pred_t = torch.sum(w * all_t, 0)

        target_r = target_r.squeeze()
        target_t = target_t.squeeze()
        pose_loss = torch.norm(pred_t - target_t) + 0.01 * torch.norm(torch.mm(pred_r, torch.transpose(target_r, 1, 0)) - torch.eye(3).cuda())

        pred = torch.mm(model_points[0], torch.transpose(pred_r, 1, 0)) + pred_t
        knn = KNN(k=1, transpose_mode=True)
        if idx in self.sym_list:
            dist, inds = knn(pred.unsqueeze(0), target.unsqueeze(0))
            dis = torch.mean(dist.squeeze())
        else:
            dis = torch.mean(torch.norm(pred - target[0], dim=1), dim=0)

        ori_r = torch.unsqueeze(pred_r, 0).cuda()
        ori_t = torch.unsqueeze(pred_t, 0).cuda()
        ori_t = ori_t.repeat(self.num_pt_mesh, 1).contiguous().view(1, self.num_pt_mesh, 3)
        new_points = torch.bmm((points - ori_t), ori_r).contiguous()

        new_target = torch.bmm((target - ori_t), ori_r).contiguous()

        del knn

        return vertex_loss, pose_loss, dis, new_points.detach(), new_target.detach()
Example #21
def fast_pcmvda_loss(Ys, y, y_unique=None, beta=1, q=1):
    if len(set(X.size(1) for X in Ys)) > 1:
        raise ValueError(
            f"pc-MvDA only works on projected data with same dimensions, "
            f"got dimensions of {tuple(X.size(1) for X in Ys)}.")
    y_unique_present = torch.unique(y)
    if y_unique is None:
        y_unique = y_unique_present
        y_present_mask = torch.zeros_like(y_unique, dtype=torch.bool)
        y_present_mask[y_unique_present.tolist()] = 1
    else:
        y_present_mask = torch.ones_like(y_unique, dtype=torch.bool)
    y_present_mask = torch.where(y_present_mask)[0]

    options = dict(dtype=Ys[0].dtype, device=Ys[0].device)
    num_views = len(Ys)
    num_components = y.size(0)
    num_classes = y_unique.size(0)
    ecs = class_vectors(y, y_unique).to(dtype=options['dtype'])
    ucs = torch.stack(class_means(Ys, ecs))
    us = ucs.mean(0)
    y_unique_counts = ecs.sum(1)
    out_dimension = Ys[0].size(1)

    pairs = torch.combinations(y_unique_present, r=2)
    class_Sw = torch.zeros(num_classes, out_dimension, out_dimension,
                           **options)

    for ci in y_unique_present:
        for vj in range(num_views):
            for k in torch.where(ecs[ci])[0]:
                d_ijk = Ys[vj][k] - us[ci]
                class_Sw[ci] += d_ijk.unsqueeze(1) @ d_ijk.unsqueeze(0)
    Sw = class_Sw[y_present_mask].sum(0)

    out = torch.tensor(0, **options)
    for ca, cb in pairs:
        Sw_ab = beta * (y_unique_counts[ca] * class_Sw[ca] +
                        y_unique_counts[cb] * class_Sw[cb])
        Sw_ab.div_(y_unique_counts[ca] + y_unique_counts[cb]).add_(
            (1 - beta) * Sw)

        du_ab = sum(uc[ca] for uc in ucs).sub(sum(
            uc[cb] for uc in ucs)).div_(num_views).unsqueeze_(0)
        Sb_ab = du_ab.t().mm(du_ab)
        out += y_unique_counts[ca] * y_unique_counts[cb] * (
            torch.trace(Sb_ab) / torch.trace(Sw_ab)).pow_(-q)
    out /= num_components * num_components
    return out
Example #22
def _gen_pairs(input, dim=-2, reducer=lambda a, b: ((a - b)**2).sum(dim=-1)):
    """Generates all pairs of different rows and then applies the reducer
    Args:
        input: a tensor
        dim: a dimension to generate pairs across
        reducer: a function applied to each generated pair of rows (beyond just concat)
    Returns:
        for default args, for A x B x C input, will output A x (B choose 2)
    """
    n = input.size()[dim]
    pair_range = torch.arange(n)  # renamed to avoid shadowing the builtin range
    idx = torch.combinations(pair_range).to(input).long()
    left = input.index_select(dim, idx[:, 0])
    right = input.index_select(dim, idx[:, 1])
    return reducer(left, right)
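A minimal usage sketch with the default squared-distance reducer:

x = torch.randn(4, 3)
d = _gen_pairs(x)       # squared distances over all 4-choose-2 row pairs
print(d.shape)          # torch.Size([6])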
Example #23
    def forward(self, xs, y_true):
        _, x = xs

        bs = x.shape[0]

        idx = torch.combinations(torch.arange(bs), 2).long().cuda()  # arange replaces the deprecated torch.range

        x1 = x[idx[:, 0]]
        x2 = x[idx[:, 1]]

        y1 = y_true[idx[:, 0]]
        y2 = y_true[idx[:, 1]]

        y = 2 * ((y1 == y2).float()) - 1

        loss = self.cos_dis(x1, x2, y)
        return loss
Example #24
def stochastic_blockmodel_graph(block_sizes, edge_probs, directed=False):
    r"""Returns the :obj:`edge_index` of a stochastic blockmodel graph.

    Args:
        block_sizes ([int] or LongTensor): The sizes of blocks.
        edge_probs ([[float]] or FloatTensor): The density of edges going
            from each block to each other block. Must be symmetric if the
            graph is undirected.
        directed (bool, optional): If set to :obj:`True`, will return a
            directed graph. (default: :obj:`False`)
    """

    size, prob = block_sizes, edge_probs

    if not torch.is_tensor(size):
        size = torch.tensor(size, dtype=torch.long)
    if not torch.is_tensor(prob):
        prob = torch.tensor(prob, dtype=torch.float)

    assert size.dim() == 1
    assert prob.dim() == 2 and prob.size(0) == prob.size(1)
    assert size.size(0) == prob.size(0)
    if not directed:
        assert torch.allclose(prob, prob.t())

    node_idx = torch.cat([size.new_full((b, ), i) for i, b in enumerate(size)])
    num_nodes = node_idx.size(0)

    if directed:
        idx = torch.arange((num_nodes - 1) * num_nodes)
        idx = idx.view(num_nodes - 1, num_nodes)
        idx = idx + torch.arange(1, num_nodes).view(-1, 1)
        idx = idx.view(-1)
        row = idx // num_nodes
        col = idx % num_nodes
    else:
        row, col = torch.combinations(torch.arange(num_nodes)).t()

    mask = torch.bernoulli(prob[node_idx[row], node_idx[col]]).to(torch.bool)
    edge_index = torch.stack([row[mask], col[mask]], dim=0)

    if not directed:
        edge_index = to_undirected(edge_index, num_nodes)

    return edge_index
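A minimal call (as above, to_undirected must be importable, e.g. from torch_geometric.utils):

edge_index = stochastic_blockmodel_graph(
    block_sizes=[3, 4],
    edge_probs=[[0.8, 0.1], [0.1, 0.7]])
print(edge_index.shape)  # torch.Size([2, E]) for a random number of edges E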
Example #25
    def update(self, y_pred: torch.Tensor, s_c: torch.Tensor):
        assert y_pred.shape == s_c.shape
        if self.k is None:
            order = torch.argsort(input=y_pred, descending=True)
        else:
            sequence_length = y_pred.shape[0]
            if sequence_length < self.k:
                k = sequence_length
            else:
                k = self.k
            _, order = torch.topk(input=y_pred, k=k, largest=True)

        senti_score = torch.take(s_c, order)
        senti_score = torch.combinations(senti_score)
        senti_score = torch.abs(torch.diff(senti_score, dim=-1)) / 2
        senti_score = torch.sum(senti_score) / senti_score.size(0)
        self.ils_senti += senti_score
        self.count += 1.0
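The pairwise step in isolation: torch.combinations over the top-k sentiment scores followed by diff yields |s_i - s_j| / 2 for every pair:

s = torch.tensor([0.2, 0.8, 0.5])
pairs = torch.combinations(s)  # tensor([[0.2, 0.8], [0.2, 0.5], [0.8, 0.5]])
print(torch.abs(torch.diff(pairs, dim=-1)) / 2)
# tensor([[0.3000], [0.1500], [0.1500]])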
Example #26
    def __init__(self,
                 args,
                 npoints=8,
                 input=64,
                 hidden=64,
                 hidden2=256,
                 sz=28,
                 sigma2=1e-2,
                 allow_points=True,
                 hard=False):
        super().__init__(args)

        self.hard = hard

        # build the coordinate grid:
        r = torch.linspace(-1, 1, sz)
        c = torch.linspace(-1, 1, sz)
        grid = torch.meshgrid(r, c)
        grid = torch.stack(grid, dim=2)
        self.register_buffer("grid", grid)

        # if we allow points, we compute the upper-triangular part of the symmetric connection
        # matrix including the diagonal. If points are not allowed, we don't need the diagonal values
        # as they would be implicitly zero
        if allow_points:
            nlines = int((npoints**2 + npoints) / 2)
        else:
            nlines = int(npoints * (npoints - 1) / 2)
        self.coordpairs = torch.combinations(torch.arange(0,
                                                          npoints,
                                                          dtype=torch.long),
                                             r=2,
                                             with_replacement=allow_points)

        # shared part of the encoder
        self.enc1 = nn.Sequential(nn.Linear(input, hidden), nn.ReLU(),
                                  nn.Linear(hidden, hidden2), nn.ReLU())
        # second part for computing npoints 2d coordinates (using tanh because we use a -1..1 grid)
        self.enc_pts = nn.Sequential(nn.Linear(hidden2, npoints * 2),
                                     nn.Tanh())
        # second part for computing upper triangular part of the connection matrix
        self.enc_con = nn.Sequential(nn.Linear(hidden2, nlines), nn.Sigmoid())

        self.sigma2 = sigma2
Example #27
def get_pair_indices(inputs: Tensor, ordered_pair: bool = False) -> Tensor:
    """
    Get pair indices between each element in input tensor

    Args:
        inputs: input tensor
        ordered_pair: if True, will return ordered pairs. (e.g. both inputs[i,j] and inputs[j,i] are included)

    Returns: a tensor of shape (K, 2) where K = choose(len(inputs),2) if ordered_pair is False.
        Else K = 2 * choose(len(inputs),2). Each row corresponds to two indices in inputs.

    """
    indices = torch.combinations(torch.arange(len(inputs)), r=2)

    if ordered_pair:
        # make pairs ordered (e.g. both (0,1) and (1,0) are included)
        indices = torch.cat((indices, indices[:, [1, 0]]), dim=0)

    return indices
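For example:

x = torch.randn(3, 5)
print(get_pair_indices(x))
# tensor([[0, 1],
#         [0, 2],
#         [1, 2]])
print(get_pair_indices(x, ordered_pair=True).shape)  # torch.Size([6, 2])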
Example #28
def compute_action(obs,
                   user_choice_model: UserChoiceModel,
                   q_model: QModel,
                   slate_size: int = 3):
    user = pack_state_user(obs)
    doc = pack_state_doc(obs)
    user = torch.tensor(user).unsqueeze(0)
    doc = torch.tensor(doc).unsqueeze(0)

    with torch.no_grad():
        scores = user_choice_model(user, doc).squeeze(0)  # shape=[num_docs+1]
        scores = torch.exp(scores - torch.max(scores, dim=-1)[0])
        scores_doc = scores[:-1]  # shape=[num_docs]
        score_no_click = scores[-1]  # shape=[]
        q_values = q_model(user, doc).squeeze(0)  # shape=[num_docs+1]
        q_values_doc = q_values[:-1]  # shape=[num_docs]
        q_values_no_click = q_values[-1]  # shape=[]

    num_docs = len(obs["doc"])
    indices = torch.tensor(np.arange(num_docs)).long()
    slates = torch.combinations(indices,
                                r=slate_size)  # shape=[num_slates, slate_size]
    num_slates, _ = slates.shape

    slate_decomp_q_values = torch.gather(
        q_values_doc.unsqueeze(0).expand(num_slates, num_docs), 1,
        slates)  # shape=[num_slates, slate_size]
    slate_scores = torch.gather(
        scores_doc.unsqueeze(0).expand(num_slates, num_docs), 1,
        slates)  # shape=[num_slates, slate_size]
    slate_q_values = (slate_decomp_q_values * slate_scores +
                      q_values_no_click * score_no_click).sum(dim=1) / (
                          slate_scores.sum(dim=1) + score_no_click
                      )  # shape=[num_slates]

    idx = np.argmax(slate_q_values.detach().numpy())
    selected = slates[idx].detach().numpy().tolist()
    action = tuple(selected)
    print("compute_action",
          q_values.detach().numpy(),
          scores.detach().numpy(), action)
    return action
Example #29
def main(args):
    subsets = {}
    modules = {}

    reindex_fn = lambda x: [x]  # must return a list: its result is flattened below
    if args.n_modules is not None and args.n_active is not None:
        print("n-modules and n-active set... using module reindexing...", file=sys.stderr)

        # We use torch to guarantee that the subset indexing will be same as with our model
        combos = torch.combinations(torch.arange(0, args.n_modules), args.n_active).numpy().tolist()
        reindex_fn = lambda x: combos[x]

    n_lines = 0
    with open(args.input_file, "r") as fh:
        for line in fh:
            n_lines += 1

            line = line.strip().split("\t")
            curr_modules = [int(x) for x in line[-1].split(",")]
            curr_modules = [y for x in curr_modules for y in reindex_fn(x)]
            curr_modules = [str(x) for x in curr_modules]

            s = ",".join(curr_modules)
            if s in subsets:
                subsets[s] += 1
            else:
                subsets[s] = 1

            for m in curr_modules:
                if m in modules:
                    modules[m] += 1
                else:
                    modules[m] = 1
    print("N-subsets\t{}".format(len(subsets.keys())))
    print("N-modules\t{}".format(len(modules.keys())))
    print("N-ratio\t{}".format(len(subsets.keys()) / float(len(combos))))

    for k, v in sorted(subsets.items(), key=lambda item: item[1], reverse=True):
        print("S-{}\t{}\t{}".format(k, v, (v / float(n_lines))))

    for k, v in sorted(modules.items(), key=lambda item: item[1], reverse=True):
        print("M-{}\t{}\t{}".format(k, v, (v / float(n_lines))))
Example #30
    def forward(self, node_features):
        """
        FORWARD ROUTINE
            - computes Adjacency Matrix
        """
        n_nodes = node_features.shape[-1]
        index = torch.arange(n_nodes)
        index_combinations = torch.combinations(index, r=2)
        adjacency_matrix = torch.zeros(node_features.shape[0], n_nodes,
                                       n_nodes)

        for i, j in index_combinations:
            pair_features = torch.cat(
                (node_features[:, :, i], node_features[:, :, j]), dim=1)
            edge_score = torch.sigmoid(self.fc1(pair_features)).squeeze()
            adjacency_matrix[:, i, j] = adjacency_matrix[:, j, i] = edge_score
        return adjacency_matrix