def forward(self, x, F):
    # Pool x down to F's spatial resolution so the two can be fused.
    feature_size = F.size(-1)
    x = nn.functional.adaptive_avg_pool2d(x, feature_size)
    x = x.view(self.batch_size, self.mc, -1, feature_size, feature_size)
    F = F.view(self.batch_size, self.mc, -1, feature_size, feature_size)
    # Concatenate group-wise along the channel dimension, then flatten back.
    F_y = torch.cat([x, F], 2)
    # F_y = torch.transpose(F_y, 1, 2).contiguous()
    # F_y = torch.add(x, F)
    F_y = F_y.view(self.batch_size, -1, feature_size, feature_size)
    return F_y
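A quick shape check of the fusion above, with hypothetical sizes standing in for self.batch_size and self.mc:

import torch

batch_size, mc = 2, 4                        # stand-ins for self.batch_size, self.mc
x = torch.randn(batch_size, mc * 8, 7, 7)    # pooled input, 8 channels per group
F = torch.randn(batch_size, mc * 8, 7, 7)    # feature map at the same resolution
x = x.view(batch_size, mc, -1, 7, 7)
F = F.view(batch_size, mc, -1, 7, 7)
F_y = torch.cat([x, F], 2).view(batch_size, -1, 7, 7)
print(F_y.shape)  # torch.Size([2, 64, 7, 7]) -- channels concatenated group-wise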
Example #2
    def write(self, z, F):
        # z: [batch_size, hidden_size]
        # F: [batch_size, s_size * r_size, t_size]  (s/r axes flattened)

        write_vars = self.W_write(z)
        # write_vars: [batch_size, n_writes * (s_size + r_size + t_size + b_size)]

        write_sizes = [self.s_size, self.r_size, self.t_size, self.b_size]
        write_list = torch.split(write_vars, sum(write_sizes), dim=1)
        # write_list: list of [batch_size, s_size + r_size + t_size + b_size]

        # multiple writes at once
        # scale = 1./self.t_size
        scale = 1. / (3 * self.t_size)
        for write_idx, write_el in enumerate(write_list):
            s, r, t, b = torch.split(write_el, write_sizes, dim=1)
            s = torch.tanh(s)
            r = torch.tanh(r)
            t = torch.tanh(t)
            # *: [batch_size, *_size]

            b = torch.sigmoid(b + 1)
            # b: [batch_size, 1]

            # sr = torch.einsum("bs,br->bsr", s, r)
            sr = s.unsqueeze(2) * r.unsqueeze(1)
            # sr: [batch_size, s_size, r_size]
            sr = sr.view(sr.shape[0], -1)
            # sr: [batch_size, s_size * r_size]
            # v = torch.einsum("bsr,bsrv->bv", sr, F)
            v = torch.matmul(sr.unsqueeze(1), F).squeeze(1)

            # v: [batch_size, t_size]
            new_v = b.view(-1, 1) * (t - v)

            # new_v: [batch_size, t_size]
            # F = F + torch.einsum("bsr,bv->bsrv", sr, new_v * scale)
            delta = sr.unsqueeze(2) * new_v.unsqueeze(1)
            F = F + delta

        # scale F down if norm is > 1
        F_norm = F.view(F.shape[0], -1).norm(dim=-1)
        F_norm = torch.relu(F_norm - 1) + 1
        F = F / F_norm.unsqueeze(1).unsqueeze(1)

        return F
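A minimal sketch of a single write step on dummy tensors (shapes are illustrative; F is the flattened relation memory described above):

import torch

B, s_size, r_size, t_size = 2, 3, 4, 5
F = torch.zeros(B, s_size * r_size, t_size)           # relation memory
s = torch.randn(B, s_size)
r = torch.randn(B, r_size)
t = torch.randn(B, t_size)

sr = (s.unsqueeze(2) * r.unsqueeze(1)).view(B, -1)    # outer product, flattened
v = torch.matmul(sr.unsqueeze(1), F).squeeze(1)       # read the current binding
F = F + sr.unsqueeze(2) * (t - v).unsqueeze(1)        # move the binding toward t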
Example #3
    def forward(self, F, pred, seed):
        b, c, h, w = pred.size()

        F = self.bn(self.conv(F))
        F = nn.functional.adaptive_max_pool2d(F, (h, w))
        F = F.view(b, -1, h * w)
        # Pairwise affinity between spatial positions, normalized with softmax.
        W = torch.bmm(F.transpose(1, 2), F)
        P = self.softmax(W)

        if self.clamp:
            self.alpha.data = torch.clamp(self.alpha.data, 0, 1)
            self.beta.data = torch.clamp(self.beta.data, 0, 1)

        pred_vec = pred.view(b, c, -1)
        out_vec = torch.bmm(P, pred_vec.transpose(1, 2)).transpose(1, 2).contiguous()
        # 1 / (1 + exp(a)) == sigmoid(-a) == 1 - sigmoid(a), so the gating simplifies to:
        alpha_w = torch.sigmoid(self.alpha)
        beta_w = torch.sigmoid(self.beta)
        out = (1 - beta_w) * ((1 - alpha_w) * out_vec.view(b, c, h, w)
                              + alpha_w * seed) + beta_w * pred
        return out, P
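A quick numerical check of the gating identity used in the rewrite above:

import torch

a = torch.randn(5)
assert torch.allclose(1 / (1 + torch.exp(a)), torch.sigmoid(-a))
assert torch.allclose(torch.exp(a) / (1 + torch.exp(a)), torch.sigmoid(a))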
Example #4
import torch


def VF_adjacency_matrix(V, F):
    """
    Input:
    V: (N, 3) vertex positions
    F: (M, 3) triangle vertex indices
    Output:
    C: (N, M) vertex-face adjacency matrix
    """
    # tensor type and device
    device = V.device
    dtype = V.dtype

    VF_adj = torch.zeros((V.shape[0], F.shape[0]), dtype=dtype, device=device)
    v_idx = F.view(-1)
    # face index for each of a face's 3 vertices: [0, 0, 0, 1, 1, 1, ...]
    f_idx = torch.repeat_interleave(
        torch.arange(F.shape[0], device=device), 3)

    VF_adj[v_idx, f_idx] = 1
    return VF_adj
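A quick usage example on a toy mesh (two triangles sharing an edge):

V = torch.rand(4, 3)                        # 4 vertices
F = torch.tensor([[0, 1, 2], [1, 2, 3]])    # 2 triangular faces
print(VF_adjacency_matrix(V, F))
# tensor([[1., 0.],
#         [1., 1.],
#         [1., 1.],
#         [0., 1.]])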
Example #5
    def wireframe_rendering(self, x_in):
        def k(x):
            # Tent (triangle) kernel: 1 at x == 0, falling linearly to 0 at |x| >= 1.
            return torch.relu(1 - torch.abs(x))

        w = 28
        h = 28

        p = x_in[:, :, 0]
        # Clone so the in-place rescaling below does not mutate x_in.
        theta = x_in[:, :, 1:].clone()

        batch_size, num_elements, geo_size = theta.shape

        theta[:, :, 0] *= w
        theta[:, :, 1] *= h

        assert (p.shape[0] == batch_size and p.shape[1] == num_elements)

        x = np.repeat(np.arange(w), h).reshape(w, h)
        y = np.transpose(x)

        x_tensor = torch.from_numpy(x)
        y_tensor = torch.from_numpy(y)
        x_tensor = x_tensor.view(1, w, h)
        y_tensor = y_tensor.view(1, w, h)

        base_tensor = torch.cat([x_tensor, y_tensor
                                 ]).type(torch.FloatTensor).to(self.device)
        base_tensor = base_tensor.repeat(batch_size * num_elements, 1, 1, 1)
        theta = theta.view(batch_size * num_elements, geo_size, 1, 1)
        p = p.view(batch_size, num_elements, 1, 1)

        # Separable tent footprint around each element's control point.
        F = k(base_tensor[:, 0, :, :] -
              theta[:, 0]) * k(base_tensor[:, 1, :, :] - theta[:, 1])
        F = F.view(batch_size, num_elements, w, h)

        p_times_F = p * F

        # Composite all elements into one image via the per-pixel maximum.
        I = torch.max(p_times_F, dim=1)[0]

        I = I.view(batch_size, 1, w, h)

        return I
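A standalone sketch of the tent-kernel rasterization for a single control point (device handling omitted; sizes are illustrative):

import torch

def k(x):
    return torch.relu(1 - torch.abs(x))

w = h = 5
xs = torch.arange(w).float().view(w, 1).expand(w, h)  # row coordinate grid
ys = torch.arange(h).float().view(1, h).expand(w, h)  # column coordinate grid
cx, cy = 2.3, 1.7                                     # one control point
footprint = k(xs - cx) * k(ys - cy)                   # bilinear bump around (cx, cy)
print(footprint)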
Example #6
# Expected-gradient-length acquisition: score each pool point by the squared
# L2 norm of the per-example gradient induced by the MC-sampled squared error.
def expected_gradient_length(X_pool, model, nb_ech):

    x_pool = torch.tensor(X_pool, dtype=torch.float).cuda()

    F, activations, preactivations = model.forward_gradients(x_pool)

    F = F.view(x_pool.shape[0], 1, 1).expand(x_pool.shape[0], 1, nb_ech)

    model.train()

    Fhat = torch.stack([model.forward(x_pool) for _ in range(nb_ech)], dim=-1)

    Cost = (1. / nb_ech) * torch.pow(F - Fhat, 2).sum()

    preactivations_grads = torch.autograd.grad(Cost, preactivations)

    gradients = goodfellow_backprop(activations, preactivations_grads)
    squared_L2 = torch.zeros(x_pool.shape[0], requires_grad=False).cuda()

    for gradient in gradients:
        Sij = torch.pow(gradient, 2).sum(dim=tuple(range(1, gradient.dim())))
        squared_L2 += Sij

    return squared_L2
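The final loop reduces each gradient block to one squared L2 norm per example, summing over all trailing dimensions; a minimal sketch of that reduction:

import torch

g = torch.randn(5, 3, 4)  # (num_examples, *param_dims) gradient block
sq = g.pow(2).sum(dim=tuple(range(1, g.dim())))
print(sq.shape)           # torch.Size([5]) -- one squared norm per example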
Example #7
def batched_pooling(blocks, verts_pos, img_info):
    # convert vertex positions to x,y coordinates in the image, scaled to fractions of image dimension

    cam_mat, cam_pos = batch_camera_info(img_info)

    A = ((verts_pos * .57) - cam_pos.unsqueeze(1))
    B = (cam_mat.permute(0, 2, 1))

    pt_trans = torch.matmul(A, B)
    X = pt_trans[:, :, 0]
    Y = pt_trans[:, :, 1]
    Z = pt_trans[:, :, 2]
    F = 248  # camera focal length in pixels (for the 224 x 224 image)

    h = (-Y) / (-Z) * F + 224 / 2.0
    w = X / (-Z) * F + 224 / 2.0
    xs = h / 223.
    ys = w / 223.

    full_features = None
    batch_size = verts_pos.shape[0]
    for block in blocks:
        # scale coordinates to block dimensions/resolution
        dim = block.shape[-1]

        cur_xs = torch.clamp(xs * dim, 0, dim - 1)
        cur_ys = torch.clamp(ys * dim, 0, dim - 1)

        # this is basically bilinear interpolation of the 4 closest feature vectors to where the vertex lands in the block
        # https://en.wikipedia.org/wiki/Bilinear_interpolation
        x1s, y1s, x2s, y2s = torch.floor(cur_xs), torch.floor(
            cur_ys), torch.ceil(cur_xs), torch.ceil(cur_ys)
        A = x2s - cur_xs
        B = cur_xs - x1s
        G = y2s - cur_ys
        H = cur_ys - y1s

        x1s = x1s.long()
        y1s = y1s.long()
        x2s = x2s.long()
        y2s = y2s.long()

        flat_block = block.permute(1, 0, 2,
                                   3).contiguous().view(block.shape[1], -1)
        block_upper = torch.arange(
            0, verts_pos.shape[0]).cuda().unsqueeze(-1).expand(
                batch_size, verts_pos.shape[1])

        selection = ((block_upper * dim * dim) + (x1s * dim) + y1s).view(-1)
        C = torch.index_select(flat_block, 1, selection)
        C = C.view(-1, batch_size, verts_pos.shape[1]).permute(1, 0, 2)
        selection = ((block_upper * dim * dim) + (x1s * dim) + y2s).view(-1)
        D = torch.index_select(flat_block, 1, selection)
        D = D.view(-1, batch_size, verts_pos.shape[1]).permute(1, 0, 2)
        selection = ((block_upper * dim * dim) + (x2s * dim) + y1s).view(-1)
        E = torch.index_select(flat_block, 1, selection)
        E = E.view(-1, batch_size, verts_pos.shape[1]).permute(1, 0, 2)
        selection = ((block_upper * dim * dim) + (x2s * dim) + y2s).view(-1)
        F = torch.index_select(flat_block, 1, selection)
        F = F.view(-1, batch_size, verts_pos.shape[1]).permute(1, 0, 2)

        section1 = A.unsqueeze(1) * C * G.unsqueeze(1)
        section2 = H.unsqueeze(1) * D * A.unsqueeze(1)
        section3 = G.unsqueeze(1) * E * B.unsqueeze(1)
        section4 = B.unsqueeze(1) * F * H.unsqueeze(1)

        features = (section1 + section2 + section3 + section4)
        features = features.permute(0, 2, 1)

        if full_features is None:
            full_features = features
        else:
            full_features = torch.cat((full_features, features), dim=2)

    return full_features
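For reference, the hand-rolled bilinear lookup above is essentially what torch.nn.functional.grid_sample provides; a minimal sketch of the equivalent call (the coordinate conventions differ slightly, so this is an approximation, not a drop-in replacement):

import torch
import torch.nn.functional as nnf

block = torch.randn(2, 8, 16, 16)  # (batch, channels, H, W) feature map
xs = torch.rand(2, 100)            # fractional row coordinates in [0, 1]
ys = torch.rand(2, 100)            # fractional column coordinates in [0, 1]

# grid_sample expects (x, y) pairs in [-1, 1], with x indexing width.
grid = torch.stack([ys, xs], dim=-1) * 2 - 1
feats = nnf.grid_sample(block, grid.unsqueeze(1), mode='bilinear',
                        align_corners=True)
feats = feats.squeeze(2).permute(0, 2, 1)  # (batch, 100, channels)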
Example #8
    def routine(self,
                α,
                x,
                β,
                y,
                class_xy=None,
                class_yx=None,
                class_xx=None,
                class_yy=None,
                mask_diagonal=False,
                **kwargs):

        N, D = x.shape
        M, _ = y.shape

        if self.debias:
            C_xx = self.calculate_cost(x, x.detach(), class_xx)
            C_yy = self.calculate_cost(y, y.detach(), class_yy)

        else:
            C_xx, C_yy = None, None

        C_xy = self.calculate_cost(x, y.detach(), class_xy)
        C_yx = self.calculate_cost(y, x.detach(), class_yx)

        # str(x.dtype)[6:] strips the "torch." prefix ("torch.float32" -> "float32").
        softmin = partial(self.softmin_online,
                          log_conv=self.logconv(C_xy, dtype=str(x.dtype)[6:]))

        diameter, ε, ε_s, ρ = scaling_parameters(x, y, self.p, self.blur,
                                                 self.reach, self.diameter,
                                                 self.scaling)

        a_x, b_y, a_y, b_x = sinkhorn_loop(softmin,
                                           log_weights(α),
                                           log_weights(β),
                                           C_xx,
                                           C_yy,
                                           C_xy,
                                           C_yx,
                                           ε_s,
                                           ρ,
                                           debias=self.debias)

        F, G = sinkhorn_cost(ε,
                             ρ,
                             α,
                             β,
                             a_x,
                             b_y,
                             a_y,
                             b_x,
                             debias=self.debias,
                             potentials=True)

        a_i = α.view(-1, 1)
        b_j = β.view(1, -1)
        F_i, G_j = F.view(-1, 1), G.view(1, -1)

        cost = (F_i + G_j).mean()

        # coupling calculation
        F_i = LazyTensor(F_i.view(-1, 1, 1))
        G_j = LazyTensor(G_j.view(1, -1, 1))
        a_i = LazyTensor(a_i.view(-1, 1, 1))
        b_j = LazyTensor(b_j.view(1, -1, 1))

        if len(C_xy) == 2:
            x, y = C_xy
            C_ij = self.cost(x, y)
        else:
            x, y, Z, L = C_xy
            C_ij = self.cost(x, y, class_values=[Z, L])

        coupling = ((F_i + G_j - C_ij) / self.eps).exp() * (a_i * b_j)

        return cost, coupling, C_ij, [F_i, G_j, a_i, b_j]
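For intuition, the LazyTensor coupling above corresponds to this dense computation (dummy shapes; eps plays the role of self.eps):

import torch

F_i = torch.randn(4, 1)          # dual potential on x, one value per point
G_j = torch.randn(1, 5)          # dual potential on y
a_i = torch.full((4, 1), 1 / 4)  # uniform source weights
b_j = torch.full((1, 5), 1 / 5)  # uniform target weights
C = torch.rand(4, 5)             # cost matrix C(x_i, y_j)
eps = 0.1

# Primal transport plan recovered from the dual potentials.
pi = ((F_i + G_j - C) / eps).exp() * (a_i * b_j)
print(pi.shape)  # torch.Size([4, 5])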
Example #9
    def routine(self,
                α,
                x,
                β,
                y,
                class_xy=None,
                class_yx=None,
                class_xx=None,
                class_yy=None,
                mask_diagonal=False,
                **kwargs):

        if x.ndim == 2:
            x = x.unsqueeze(0)

        if y.ndim == 2:
            y = y.unsqueeze(0)

        B, N, D = x.shape
        _, M, _ = y.shape

        if self.debias:
            C_xx = self.calculate_cost(x,
                                       x.detach(),
                                       class_xx,
                                       mask_diagonal=mask_diagonal)
            C_yy = self.calculate_cost(y,
                                       y.detach(),
                                       class_yy,
                                       mask_diagonal=mask_diagonal)

        else:
            C_xx, C_yy = None, None

        C_xy = self.calculate_cost(x,
                                   y.detach(),
                                   class_xy,
                                   mask_diagonal=mask_diagonal)  # (B, N, M)
        C_yx = self.calculate_cost(y,
                                   x.detach(),
                                   class_yx,
                                   mask_diagonal=mask_diagonal)  # (B, M, N)

        diameter, ε, ε_s, ρ = scaling_parameters(x, y, self.p, self.blur,
                                                 self.reach, self.diameter,
                                                 self.scaling)

        a_x, b_y, a_y, b_x = sinkhorn_loop(softmin_tensorized,
                                           log_weights(α),
                                           log_weights(β),
                                           C_xx,
                                           C_yy,
                                           C_xy,
                                           C_yx,
                                           ε_s,
                                           ρ,
                                           debias=self.debias)

        F, G = sinkhorn_cost(ε,
                             ρ,
                             α,
                             β,
                             a_x,
                             b_y,
                             a_y,
                             b_x,
                             batch=True,
                             debias=self.debias,
                             potentials=True)

        # NB: these flat views assume an effective batch size of 1
        # (hence the squeeze(0) on the returned coupling and cost matrix below).
        a_i = α.view(-1, 1)
        b_j = β.view(1, -1)
        F_i, G_j = F.view(-1, 1), G.view(1, -1)
        cost = (F_i + G_j).mean()
        coupling = ((F_i + G_j - C_xy) / self.eps).exp() * (a_i * b_j)

        return cost, coupling.squeeze(0), C_xy.squeeze(0), [F_i, G_j, a_i, b_j]
Example #10
    def pooling(self, blocks, verts_pos, debug=False):
        # convert vertex positions to x,y coordinates in the image, scaled to fractions of image dimension
        ext_verts_pos = torch.cat(
            (verts_pos,
             torch.FloatTensor(
                 np.ones([verts_pos.shape[0], verts_pos.shape[1], 1])).cuda()),
            dim=-1)
        ext_verts_pos = torch.matmul(ext_verts_pos, self.matrix.permute(1, 0))
        xs = ext_verts_pos[:, :, 1] / ext_verts_pos[:, :, 2] / 256.
        ys = ext_verts_pos[:, :, 0] / ext_verts_pos[:, :, 2] / 256.

        full_features = None
        batch_size = verts_pos.shape[0]

        # check that the camera projection covers the image
        if debug:
            dim = 256
            xs = (torch.clamp(xs * dim, 0,
                              dim - 1).data.cpu().numpy()).astype(np.uint8)
            ys = (torch.clamp(ys * dim, 0,
                              dim - 1).data.cpu().numpy()).astype(np.uint8)
            for ex in range(blocks.shape[0]):
                img = blocks[ex].permute(1, 2, 0).data.cpu().numpy()[:, :, :3]
                for x, y in zip(xs[ex], ys[ex]):
                    img[x, y, 0] = 1
                    img[x, y, 1] = 0
                    img[x, y, 2] = 0

                from PIL import Image
                Image.fromarray(
                    (img * 255).astype(np.uint8)).save('results/temp.png')
                print('saved')
                input()

        for block in blocks:
            # scale projected vertex points to dimension of current feature map
            dim = block.shape[-1]
            cur_xs = torch.clamp(xs * dim, 0, dim - 1)
            cur_ys = torch.clamp(ys * dim, 0, dim - 1)

            # https://en.wikipedia.org/wiki/Bilinear_interpolation
            x1s, y1s, x2s, y2s = torch.floor(cur_xs), torch.floor(
                cur_ys), torch.ceil(cur_xs), torch.ceil(cur_ys)
            A = x2s - cur_xs
            B = cur_xs - x1s
            G = y2s - cur_ys
            H = cur_ys - y1s

            x1s = x1s.long()
            y1s = y1s.long()
            x2s = x2s.long()
            y2s = y2s.long()

            # flatten batch of feature maps to make vectorization easier
            flat_block = block.permute(1, 0, 2, 3).contiguous().view(
                block.shape[1], -1)
            block_idx = torch.arange(
                0, verts_pos.shape[0]).cuda().unsqueeze(-1).expand(
                    batch_size, verts_pos.shape[1])
            block_idx = block_idx * dim * dim

            selection = (block_idx + (x1s * dim) + y1s).view(-1)
            C = torch.index_select(flat_block, 1, selection)
            C = C.view(-1, batch_size, verts_pos.shape[1]).permute(1, 0, 2)
            selection = (block_idx + (x1s * dim) + y2s).view(-1)
            D = torch.index_select(flat_block, 1, selection)
            D = D.view(-1, batch_size, verts_pos.shape[1]).permute(1, 0, 2)
            selection = (block_idx + (x2s * dim) + y1s).view(-1)
            E = torch.index_select(flat_block, 1, selection)
            E = E.view(-1, batch_size, verts_pos.shape[1]).permute(1, 0, 2)
            selection = (block_idx + (x2s * dim) + y2s).view(-1)
            F = torch.index_select(flat_block, 1, selection)
            F = F.view(-1, batch_size, verts_pos.shape[1]).permute(1, 0, 2)

            section1 = A.unsqueeze(1) * C * G.unsqueeze(1)
            section2 = H.unsqueeze(1) * D * A.unsqueeze(1)
            section3 = G.unsqueeze(1) * E * B.unsqueeze(1)
            section4 = B.unsqueeze(1) * F * H.unsqueeze(1)

            features = (section1 + section2 + section3 + section4)
            features = features.permute(0, 2, 1)

            if full_features is None:
                full_features = features
            else:
                full_features = torch.cat((full_features, features), dim=2)

        return full_features
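A minimal sketch of the homogeneous projection step at the top of pooling; treating self.matrix as a 3 x 4 camera matrix is an assumption about the surrounding class:

import torch

matrix = torch.eye(3, 4)                  # hypothetical 3 x 4 projection matrix
verts = torch.rand(1, 10, 3) + 1.0        # (batch, num_verts, xyz), depths kept positive
ones = torch.ones(1, 10, 1)
ext = torch.matmul(torch.cat((verts, ones), dim=-1), matrix.permute(1, 0))
xs = ext[:, :, 1] / ext[:, :, 2] / 256.   # perspective divide, scale to [0, 1]
ys = ext[:, :, 0] / ext[:, :, 2] / 256.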