Example #1
    def write(self, z, time, debug=False):
        # update usage indicator
        self.u = self.u + T.matmul(Variable(T.from_numpy(np.ones((1, Kr), dtype=np.float32))), self.W_predictor)

        # update writing weights
        prev_v_wr = self.v_wr
        v_wr = np.zeros((N_mem, 1), dtype=np.float32)
        if time < N_mem:
            v_wr[time][0] = 1
        else:
            waste_index = int(T.argmin(self.u).data)
            v_wr[waste_index][0] = 1
        self.v_wr = Variable(T.from_numpy(v_wr))

        # writing
        # z: (1, Z_DIM)
        if debug:
            print(self.M)
        if USE_RETROACTIVE:
            # update retroactive weights
            self.v_ret = GAMMA*self.v_ret + (1-GAMMA)*prev_v_wr
            z_wr = T.cat([z, Variable(T.from_numpy(np.zeros((1, Z_DIM), dtype=np.float32)))], 1)
            z_ret = T.cat([Variable(T.from_numpy(np.zeros((1, Z_DIM), dtype=np.float32))), z], 1)
            self.M = self.M + T.matmul(self.v_wr, z_wr) + T.matmul(self.v_ret, z_ret)
        else:
            self.M = self.M + T.matmul(self.v_wr, z)
        if debug:
            return self.M
Example #2
    def forward(self,x):
        max_sample = x.size()[1]
        x = x.view(-1,self.feature_size)
        assignment = th.matmul(x,self.clusters)

        if self.add_batch_norm:
            assignment = self.batch_norm(assignment)

        assignment = F.softmax(assignment,dim=1)
        assignment = assignment.view(-1, max_sample, self.cluster_size)

        a_sum = th.sum(assignment,-2,keepdim=True)
        a = a_sum*self.clusters2

        assignment = assignment.transpose(1,2)

        x = x.view(-1, max_sample, self.feature_size)
        vlad = th.matmul(assignment, x)
        vlad = vlad.transpose(1,2)
        vlad = vlad - a

        # L2 intra norm
        vlad = F.normalize(vlad)
        
        # flattening + L2 norm
        vlad = vlad.view(-1, self.cluster_size*self.feature_size)
        vlad = F.normalize(vlad)

        return vlad
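
The same soft-assignment VLAD aggregation, rewritten as a standalone sketch with hypothetical sizes (clusters and clusters2 stand in for the module's learned parameters):

import torch
import torch.nn.functional as F

B, T, D, K = 2, 5, 16, 8                  # batch, samples, feature_size, cluster_size (hypothetical)
x = torch.randn(B, T, D)                  # input descriptors
clusters = torch.randn(D, K)              # soft-assignment projection (learned parameter)
clusters2 = torch.randn(1, D, K)          # cluster centers (learned parameter)

assignment = F.softmax(x.reshape(-1, D).matmul(clusters), dim=1).view(B, T, K)
a = assignment.sum(dim=-2, keepdim=True) * clusters2   # (B, 1, K) * (1, D, K) -> (B, D, K)
vlad = torch.matmul(assignment.transpose(1, 2), x).transpose(1, 2) - a
vlad = F.normalize(vlad)                     # L2 intra norm over the feature dimension
vlad = F.normalize(vlad.reshape(B, D * K))   # flatten + L2 norm
print(vlad.shape)                            # torch.Size([2, 128])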
Example #3
def train(ep):
    model.train()
    total_loss = 0
    count = 0
    train_idx_list = np.arange(len(X_train), dtype="int32")
    np.random.shuffle(train_idx_list)
    for idx in train_idx_list:
        data_line = X_train[idx]
        x, y = Variable(data_line[:-1]), Variable(data_line[1:])
        if args.cuda:
            x, y = x.cuda(), y.cuda()

        optimizer.zero_grad()
        output = model(x.unsqueeze(0)).squeeze(0)
        loss = -torch.trace(torch.matmul(y, torch.log(output).float().t()) +
                            torch.matmul((1 - y), torch.log(1 - output).float().t()))
        total_loss += loss.data[0]
        count += output.size(0)

        loss.backward()
        # clip gradients only after backward() has populated them
        if args.clip > 0:
            torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
        optimizer.step()
        if idx > 0 and idx % args.log_interval == 0:
            cur_loss = total_loss / count
            print("Epoch {:2d} | lr {:.5f} | loss {:.5f}".format(ep, lr, cur_loss))
            total_loss = 0.0
            count = 0
Example #4
    def forward(self,x):
        max_sample = x.size()[1]
        x = x.view(-1,self.feature_size)
        assignment = th.matmul(x,self.clusters)

        if self.add_batch_norm:
            assignment = self.batch_norm(assignment)

        assignment = F.softmax(assignment, dim=1)
        assignment = assignment.view(-1, max_sample, self.cluster_size)

        assignment = assignment.transpose(1,2)

        x = x.view(-1, max_sample, self.feature_size)
        rvlad = th.matmul(assignment, x)
        rvlad = rvlad.transpose(-1,1)

        # L2 intra norm
        rvlad = F.normalize(rvlad)
        
        # flattening + L2 norm
        rvlad = rvlad.view(-1, self.cluster_size*self.feature_size)
        rvlad = F.normalize(rvlad)

        return rvlad
Example #5
 def forward(self, context, state, input_):
     output = (torch.matmul(context, self._v_c.unsqueeze(1))
               + torch.matmul(state, self._v_s.unsqueeze(1))
               + torch.matmul(input_, self._v_i.unsqueeze(1)))
     if self._b is not None:
         output = output + self._b.unsqueeze(0)
     return output
Example #6
	def forward(self, sequence, graph):
		"""
		Apply self-attention to the sequence, ignores
		the graph
		"""
		sequence = sequence.squeeze(1)	
		
		#get the dimension
		n, d = sequence.size()
		
		#project the sequence into key, value, and query sequences
		keySeq = f.relu(self.keyProj(sequence))
		valueSeq = f.relu(self.valueProj(sequence))
		querySeq = f.relu(self.queryProj(sequence))
		
		#combine query with each key
		#a_ijh = softmax( (q_ih^T k_jh) / sqrt(d) )
		#the result is, row i is the importance of the sequence for key i
		importance = f.softmax(t.matmul(querySeq, keySeq.permute(1,0)) / math.sqrt(d), 0).permute(1,0)

		#apply the importance weights to the value sequence
		attention = t.matmul(valueSeq.permute(1,0), importance).permute(1,0)
	
		#sum the sequence for a complete representation
		final = t.sum(attention, 0)
		
		return attention.unsqueeze(1), final
Example #7
 def score(self, hidden, encoder_output):
     
     if self.method == 'dot':            
         # hidden is 1 by 256
         # encoder_output is 22 by 256
         encoder_output = torch.transpose(encoder_output, 0, 1)
         # encoder_output is 256 by 22
         energy = torch.matmul(hidden, encoder_output)
         return energy
     
     elif self.method == 'general':
         # hidden is 1 by 256
         # encoder_output is 256 by 22
         # encoder_output = torch.transpose(encoder_output, 0, 1)
         hidden = hidden.view(1, -1)
         a = self.attn(encoder_output)
         a = torch.transpose(a, 0, 1)
         energy = torch.matmul(hidden, a)
         return energy
     
     elif self.method == 'concat':
         len_encoder_output = encoder_output.size()[1]
         # hidden is 1 by 256
         # encoder_output is 256 by 22
         hidden = torch.transpose(hidden, 0, 1)
         # hidden is 256 by 1
         hidden = hidden.repeat(1, len_encoder_output)
         # hidden is 256 by 22
         concat = torch.cat((hidden, encoder_output), dim=0)
         # concat is 512 by 22
         # self.attn(concat) --> 256 by 22
         energy = torch.matmul(self.v, F.tanh(self.attn(concat)))
         return energy
Example #8
def grad2():
    W = Variable(torch.rand(2, 2), requires_grad=True)
    W2 = Variable(torch.rand(2, 1), requires_grad=True)
    x1 = Variable(torch.rand(1, 2), requires_grad=True)
    x2 = Variable(torch.rand(1, 2), requires_grad=True)

    print("w: ")
    print(W)
    print("x1: ")
    print(x1)
    print("x2: ")
    print(x2)
    print("--------------------")

    y1 = torch.matmul(torch.matmul(x1, W), W2)
    print(torch.matmul(W, W2))
    # y = Variable(y, requires_grad=True)
    # print("y1:")
    # print(y1)

    y1.backward()
    # print(W.grad)
    print(x1.grad)

    # W.grad.data.zero_()
    # x1.grad.data.zero_()
    y2 = torch.matmul(torch.matmul(x2, W), W2)
    y2.backward()
    # print("y2: ")
    # print(y2)
    # print(W.grad)
    print(x2.grad)
Example #9
    def forward(self, hidden_states, attention_mask):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer
Example #10
 def _attn(self, q, k, v):
     w = torch.matmul(q, k)
     if self.scale:
         w = w / math.sqrt(v.size(-1))
     w = w * self.b + -1e9 * (1 - self.b)  # TF implem method: mask_attn_weights
     w = nn.Softmax(dim=-1)(w)
     w = self.attn_dropout(w)
     return torch.matmul(w, v)
Example #11
def test():
    x = torch.ones(1, 2)
    Sigma = torch.FloatTensor([[1, 0.8], [0.8, 1]])

    z = torch.ones(x.size())
    y = torch.matmul(x, Sigma)
    y = torch.matmul(y, x.t())
    print(y)
Example #12
 def attention_score(attention, query, v, w):
     """ unnormalized attention score"""
     sum_ = attention.unsqueeze(1) + torch.matmul(
         query, w.unsqueeze(0)
     ).unsqueeze(2)  # [B, Nq, Ns, D]
     score = torch.matmul(
         F.tanh(sum_), v.unsqueeze(0).unsqueeze(1).unsqueeze(3)
     ).squeeze(3)  # [B, Nq, Ns]
     return score
Example #13
 def attention(cls, query, key, value, mask=None, dropout=None):
     d_k = query.size(-1)
     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
     if mask is not None:
         scores = scores.masked_fill(mask == 0, -1e9)
     p_attn = F.softmax(scores, dim=-1)
     if dropout is not None:
         p_attn = dropout(p_attn)
     return torch.matmul(p_attn, value), p_attn
Example #14
    def forward(self, matrix_1: torch.Tensor, matrix_2: torch.Tensor) -> torch.Tensor:

        if self._use_input_biases:
            bias1 = matrix_1.new_ones(matrix_1.size()[:-1] + (1,))
            bias2 = matrix_2.new_ones(matrix_2.size()[:-1] + (1,))

            matrix_1 = torch.cat([matrix_1, bias1], -1)
            matrix_2 = torch.cat([matrix_2, bias2], -1)
        intermediate = torch.matmul(matrix_1.unsqueeze(1), self._weight_matrix.unsqueeze(0))
        final = torch.matmul(intermediate, matrix_2.unsqueeze(1).transpose(2, 3))
        return self._activation(final.squeeze(1) + self._bias)
Example #15
 def _prepare(self, attn_mem):
     attn_feat = torch.matmul(attn_mem, self._attn_wm.unsqueeze(0))
     hop_feat = torch.matmul(attn_mem, self._hop_wm.unsqueeze(0))
     bs = attn_mem.size(0)
     n_l, d = self._init_h.size()
     size = (n_l, bs, d)
     lstm_states = (self._init_h.unsqueeze(1).expand(*size).contiguous(),
                    self._init_c.unsqueeze(1).expand(*size).contiguous())
     d = self._init_i.size(0)
     init_i = self._init_i.unsqueeze(0).unsqueeze(1).expand(bs, 1, d)
     return attn_feat, hop_feat, lstm_states, init_i
Example #16
 def distance_calcu(self, query, gallery):
     """
     :param query:
     :param gallery:
     :return:
     """
     query = query.expand_as(gallery).contiguous()
     x = torch.cat([query, gallery, query-gallery], 1)
     W1, W2 = self.adpW(x)
     num = query.size(0)
     dist = torch.norm((torch.matmul(W2, gallery.view(num, -1, 1))+torch.matmul(W1, query.view(num, -1, 1)))  # projected gallery(combin with query)
                       - query.view(num, -1, 1), 2, 1)  # orig query
     return dist
Example #17
def high_dimension_gaussain_energy(x):
    u = High_mu
    Sigma = High_Sigma
    Sigma = torch.inverse(Sigma)

    if isinstance(x, Variable):
        u = Variable(u, requires_grad=True)
        Sigma = Variable(Sigma, requires_grad=True)

    diff = x - u
    temp = 0.5 * torch.matmul(torch.matmul(diff, Sigma), diff.t())

    return temp
Example #18
 def adpW(self,x):
     # x = F.normalize(x)
     x = self.adp_metric_embedding1(x)
     # x = self.adp_metric_embedding1_bn(x)
      x = F.prelu(x, self.prelu_weight)  # F.prelu requires a weight tensor (assumed module parameter)
     x = self.adp_metric_embedding2(x)
     # x = self.adp_metric_embedding2_bn(x)
     diag_matrix = []
     for i in range(x.size(0)):
         diag_matrix.append(torch.diag(x[i,:]))
     x = torch.stack(diag_matrix)
     W = torch.matmul(self.transform_matrix,torch.matmul(x,self.transform_matrix))
     return W
Example #19
 def distance_calcu(self, gallery, query):
     # x = torch.autograd.Variable(torch.cat([gallery,query]))
     gallery = gallery.expand_as(query).contiguous()
     x = torch.cat([gallery, query],1)
     W = self.adpW(x)
     num = query.size(0)
     dist1 = torch.norm(torch.matmul(W,gallery.view(num,-1,1))
                       -torch.matmul(W,query.view(num,-1,1)),2,1)
     x = torch.cat([query,gallery],1)
     W = self.adpW(x)
     dist2 = torch.norm(torch.matmul(W, gallery.view(num, -1, 1))
                       - torch.matmul(W, query.view(num, -1, 1)), 2, 1)
     dist = 0.5*(dist1+dist2)
     return dist
Example #20
def attention(query: torch.Tensor,
              key: torch.Tensor,
              value: torch.Tensor,
              mask: torch.Tensor = None,
              dropout: Callable = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute 'Scaled Dot Product Attention'"""
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
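
A quick shape check for the attention function above; the sizes are hypothetical and the mask uses 1 for positions to keep, 0 for positions to suppress:

import torch

q = torch.randn(2, 4, 10, 64)       # (batch, heads, queries, d_k)
k = torch.randn(2, 4, 12, 64)       # (batch, heads, keys, d_k)
v = torch.randn(2, 4, 12, 64)
mask = torch.ones(2, 1, 10, 12)     # broadcasts over the head dimension
out, p_attn = attention(q, k, v, mask=mask)
print(out.shape, p_attn.shape)      # torch.Size([2, 4, 10, 64]) torch.Size([2, 4, 10, 12])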
Example #21
def max_singular_value(W, u=None, Ip=1):
    """
    power iteration for weight parameter
    """
    #xp = W.data
    if u is None:
        u = torch.FloatTensor(1, W.size(0)).normal_(0, 1).cuda()
    _u = u
    for _ in range(Ip):
        #print(_u.size(), W.size())
        _v = _l2normalize(torch.matmul(_u, W.data), eps=1e-12)
        _u = _l2normalize(torch.matmul(_v, torch.transpose(W.data, 0, 1)), eps=1e-12)
    sigma = torch.matmul(torch.matmul(_v, torch.transpose(W.data, 0, 1)), torch.transpose(_u, 0, 1))
    return sigma, _v
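
A rough sanity check of the power iteration against a full SVD. Here `_l2normalize` is assumed to be plain L2 normalization, `u` is passed explicitly so the sketch stays on CPU, and `torch.linalg.svdvals` assumes a recent PyTorch:

import torch

def _l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)

W = torch.randn(32, 64)
u0 = torch.randn(1, 32)                   # explicit u avoids the .cuda() default
sigma, _v = max_singular_value(W, u=u0, Ip=20)
print(float(sigma), float(torch.linalg.svdvals(W)[0]))   # the two values should roughly agree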
Example #22
def seqAttention(sequence, weights):
	"""
	Applies attention to the given sequence
	"""
	#compute the importance over the sequence
	importance = t.tanh(t.matmul(sequence, weights))

	#compute the attention
	attention = f.softmax(importance, 0)

	tSeq = sequence.permute(1,0)

	#compute and return the representation
	return t.matmul(tSeq, attention)
Example #23
    def AdpsubM(self, x, batchsize):
        batchinputs = x.view(batchsize**2,-1,1)
        # x_constraint_branch = x.detach()
        W = self.adpW(x)
        feature_dim = x.size(1) // 2  # integer division so it can be used as a slice index below
        # W_branch = self.adpW(x_constraint_branch)  ###for weight constrain
        # I = torch.autograd.Variable(torch.eye(feature_dim)).cuda()
        # weight_constraint = torch.norm(torch.matmul(W_branch, W_branch.transpose(1,2)).sub(I), 2)

        batchoutputs1 = torch.matmul(W,batchinputs[:,:feature_dim,:]).view(batchsize**2,-1)
        batchoutputs2 = torch.matmul(W,batchinputs[:,feature_dim:,:]).view(batchsize**2,-1)
        dist = torch.norm(batchoutputs1-batchoutputs2,2,1).view(batchsize,batchsize)
        dist = dist.clamp(min=1e-6)
        return dist
Example #24
def CORAL_loss(source, target):
    d = source.data.shape[1]

    # source covariance
    xm = torch.mean(source, 0, keepdim=True) - source  # center each feature over the batch dimension
    xc = torch.matmul(torch.transpose(xm, 0, 1), xm)

    # target covariance
    xmt = torch.mean(target, 0, keepdim=True) - target
    xct = torch.matmul(torch.transpose(xmt, 0, 1), xmt)
    # frobenius norm between source and target
    loss = (xc - xct).pow(2).sum().sqrt()
    loss = loss/(4*d*d)
    return loss
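
A minimal call with hypothetical feature batches; the two batch sizes may differ, only the feature dimension d must match:

import torch

src = torch.randn(128, 256)    # source-domain features
tgt = torch.randn(64, 256)     # target-domain features
print(CORAL_loss(src, tgt))    # scalar tensor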
Example #25
    def compute_weight(self, module):
        weight = module._parameters[self.name + '_org']
        u = module._buffers[self.name + '_u']
        height = weight.size(0)
        weight_mat = weight.view(height, -1)
        for _ in range(self.n_power_iterations):
            # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
            # are the first left and right singular vectors.
            # This power iteration produces approximations of `u` and `v`.
            v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
            u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)

        sigma = torch.dot(u, torch.matmul(weight_mat, v))
        weight = weight / sigma  # divide out of place so the stored parameter is not mutated on every call
        return weight, u
Example #26
def relative_matmul(x, z, transpose):
    """Helper function for relative positions attention."""
    batch_size = x.shape[0]
    heads = x.shape[1]
    length = x.shape[2]
    x_t = x.permute(2, 0, 1, 3)
    x_t_r = x_t.reshape(length, heads * batch_size, -1)
    if transpose:
        z_t = z.transpose(1, 2)
        x_tz_matmul = torch.matmul(x_t_r, z_t)
    else:
        x_tz_matmul = torch.matmul(x_t_r, z)
    x_tz_matmul_r = x_tz_matmul.reshape(length, batch_size, heads, -1)
    x_tz_matmul_r_t = x_tz_matmul_r.permute(1, 2, 0, 3)
    return x_tz_matmul_r_t
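
A shape sketch of both call modes, assuming `z` holds relative-position embeddings of shape (length, length, head_dim), as in relative-attention implementations:

import torch

B, H, L, Dh = 2, 4, 6, 16
q = torch.randn(B, H, L, Dh)                   # queries
attn = torch.softmax(torch.randn(B, H, L, L), dim=-1)
rel_k = torch.randn(L, L, Dh)                  # relative key embeddings
rel_v = torch.randn(L, L, Dh)                  # relative value embeddings

print(relative_matmul(q, rel_k, True).shape)       # torch.Size([2, 4, 6, 6]), extra score term
print(relative_matmul(attn, rel_v, False).shape)   # torch.Size([2, 4, 6, 16]), extra context term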
Example #27
    def forward(self, input, target):
        y_true = target.int().unsqueeze(-1)
        same_id = torch.eq(y_true, y_true.t()).type_as(input)

        pos_mask = same_id
        neg_mask = 1 - same_id

        def _mask_max(input_tensor, mask, axis=None, keepdims=False):
            input_tensor = input_tensor - 1e6 * (1 - mask)
            _max, _idx = torch.max(input_tensor, dim=axis, keepdim=keepdims)
            return _max, _idx

        def _mask_min(input_tensor, mask, axis=None, keepdims=False):
            input_tensor = input_tensor + 1e6 * (1 - mask)
            _min, _idx = torch.min(input_tensor, dim=axis, keepdim=keepdims)
            return _min, _idx

        # output[i, j] = || feature[i, :] - feature[j, :] ||_2
        dist_squared = torch.sum(input ** 2, dim=1, keepdim=True) + \
                       torch.sum(input.t() ** 2, dim=0, keepdim=True) - \
                       2.0 * torch.matmul(input, input.t())
        dist = dist_squared.clamp(min=1e-16).sqrt()

        pos_max, pos_idx = _mask_max(dist, pos_mask, axis=-1)
        neg_min, neg_idx = _mask_min(dist, neg_mask, axis=-1)

        # loss(x, y) = max(0, -y * (x1 - x2) + margin)
        y = torch.ones(same_id.size()[0]).to(DEVICE)
        return F.margin_ranking_loss(neg_min.float(),
                                     pos_max.float(),
                                     y,
                                     self.margin,
                                     self.size_average)
Example #28
    def __call__(self, spec_f):

        spec_f, is_variable = _check_is_variable(spec_f)
        n_fft = spec_f.size(2)

        m_min = 0. if self.f_min == 0 else 2595 * np.log10(1. + (self.f_min / 700))
        m_max = 2595 * np.log10(1. + (self.f_max / 700))

        m_pts = torch.linspace(m_min, m_max, self.n_mels + 2)
        f_pts = (700 * (10**(m_pts / 2595) - 1))

        bins = torch.floor(((n_fft - 1) * 2) * f_pts / self.sr).long()

        fb = torch.zeros(n_fft, self.n_mels)
        for m in range(1, self.n_mels + 1):
            f_m_minus = bins[m - 1].item()
            f_m = bins[m].item()
            f_m_plus = bins[m + 1].item()

            if f_m_minus != f_m:
                fb[f_m_minus:f_m, m - 1] = (torch.arange(f_m_minus, f_m) - f_m_minus) / (f_m - f_m_minus)
            if f_m != f_m_plus:
                fb[f_m:f_m_plus, m - 1] = (f_m_plus - torch.arange(f_m, f_m_plus)) / (f_m_plus - f_m)

        fb = Variable(fb)
        spec_m = torch.matmul(spec_f, fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
        return spec_m if is_variable else spec_m.data
Example #29
    def routing(self, x, b_IJ, W,batch_size,routing_iter):
        x1 = x.view(batch_size, 256, 1, 6, 6)
        x_tile = x1.repeat(1, 1, 10, 1, 1)
        x_view = x_tile.view(batch_size, 1152, 10, 8, 1)
        stride_i = W.repeat(batch_size, 1, 1, 1, 1)
        stride_j = stride_i.view(batch_size, 1152, 10, 16, 8)
        dot_op = torch.matmul(stride_j, x_view)
        dot_op_stopped = Variable(dot_op.data.clone(), requires_grad=False)

        for r_iter in range(routing_iter):
            id_capsule = F.softmax(b_IJ, dim=2)
            if r_iter == routing_iter - 1:
                route_I = torch.mul(id_capsule, dot_op)
                route_I_sum = torch.sum(route_I, dim=1, keepdim=True) + self.bias
                V_J = squash(route_I_sum,self.epsilon)
            if r_iter < routing_iter - 1:

                dot_op_stopped_tmp = dot_op_stopped.data.numpy()
                dot_op_stopped_tmp = np.reshape(dot_op_stopped_tmp, (batch_size, 1152, 10, 16, 1))
                id_capsule_tmp = id_capsule.data.numpy()
                route_I_tmp = id_capsule_tmp * dot_op_stopped_tmp
                route_I_tmp_sum = np.sum(route_I_tmp, axis=1, keepdims=True) + self.bias.data.numpy()
                V_J_tmp = squash(torch.Tensor(route_I_tmp_sum),self.epsilon)

                V_J_tmp_tiled = np.tile(V_J_tmp.numpy(), (1, 1152, 1, 1, 1))
                dot_op_stopped_tmp = np.reshape(dot_op_stopped_tmp, (batch_size, 1152, 10, 1, 16))

                u_produce_v = np.matmul(dot_op_stopped_tmp, V_J_tmp_tiled)

                b_IJ.data += torch.Tensor(u_produce_v)

        return V_J
Example #30
    def compute_weight(self, module):
        weight = getattr(module, self.name + '_org')
        u = getattr(module, self.name + '_u')
        height = weight.size(0)
        weight_mat = weight.view(height, -1)
        with torch.no_grad():
            for _ in range(self.n_power_iterations):
                # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
                # are the first left and right singular vectors.
                # This power iteration produces approximations of `u` and `v`.
                v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
                u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)

            sigma = torch.dot(u, torch.matmul(weight_mat, v))
        weight = weight / sigma
        return weight, u
Example #31
 def prob_v_given_h(self, h):
     return (torch.matmul(h, self.weights.data, out=None).add_(
         self.visible_bias.data).sigmoid_().clamp_(min=0, max=1))
Example #32
 def forward(self, input):
     x = self.fc(input)
     attn = torch.softmax(torch.matmul(x, x.transpose(1, 2)), 2)
     output = torch.matmul(attn, input)
     return output
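
Wrapped in a minimal module for illustration; the snippet's `self.fc` is assumed to be a same-width linear layer:

import torch
import torch.nn as nn

class SelfAttn(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.fc = nn.Linear(d, d)   # assumed projection layer

    def forward(self, input):
        x = self.fc(input)
        attn = torch.softmax(torch.matmul(x, x.transpose(1, 2)), 2)
        return torch.matmul(attn, input)

print(SelfAttn(8)(torch.randn(2, 5, 8)).shape)   # torch.Size([2, 5, 8])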
Example #33
    def matmul(self, other):
        """Matrix product of two tensors.

        See :func:`torch.matmul`."""
        return torch.matmul(self, other)
Example #34
def angle_axis_to_rotation_matrix(angle_axis: torch.Tensor) -> torch.Tensor:
    r"""Convert 3d vector of axis-angle rotation to 3x3 rotation matrix
    Args:
        angle_axis (torch.Tensor): tensor of 3d vector of axis-angle rotations.
    Returns:
        torch.Tensor: tensor of 3x3 rotation matrices.
    Shape:
        - Input: :math:`(N, 3)`
        - Output: :math:`(N, 3, 3)`
    Example:
        >>> input = torch.rand(1, 3)  # Nx3
        >>> output = kornia.angle_axis_to_rotation_matrix(input)  # Nx3x3
    """
    if not isinstance(angle_axis, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(angle_axis)))

    if not angle_axis.shape[-1] == 3:
        raise ValueError(
            "Input size must be a (*, 3) tensor. Got {}".format(
                angle_axis.shape))

    def _compute_rotation_matrix(angle_axis, theta2, eps=1e-6):
        # We want to be careful to only evaluate the square root if the
        # norm of the angle_axis vector is greater than zero. Otherwise
        # we get a division by zero.
        k_one = 1.0
        theta = torch.sqrt(theta2)
        wxyz = angle_axis / (theta + eps)
        wx, wy, wz = torch.chunk(wxyz, 3, dim=1)
        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)

        r00 = cos_theta + wx * wx * (k_one - cos_theta)
        r10 = wz * sin_theta + wx * wy * (k_one - cos_theta)
        r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta)
        r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta
        r11 = cos_theta + wy * wy * (k_one - cos_theta)
        r21 = wx * sin_theta + wy * wz * (k_one - cos_theta)
        r02 = wy * sin_theta + wx * wz * (k_one - cos_theta)
        r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta)
        r22 = cos_theta + wz * wz * (k_one - cos_theta)
        rotation_matrix = torch.cat(
            [r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1)
        return rotation_matrix.view(-1, 3, 3)

    def _compute_rotation_matrix_taylor(angle_axis):
        rx, ry, rz = torch.chunk(angle_axis, 3, dim=1)
        k_one = torch.ones_like(rx)
        rotation_matrix = torch.cat(
            [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1)
        return rotation_matrix.view(-1, 3, 3)

    # stolen from ceres/rotation.h

    _angle_axis = torch.unsqueeze(angle_axis, dim=1)
    theta2 = torch.matmul(_angle_axis, _angle_axis.transpose(1, 2))
    theta2 = torch.squeeze(theta2, dim=1)

    # compute rotation matrices
    rotation_matrix_normal = _compute_rotation_matrix(angle_axis, theta2)
    rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis)

    # create mask to handle both cases
    eps = 1e-6
    mask = (theta2 > eps).view(-1, 1, 1).to(theta2.device)
    mask_pos = (mask).type_as(theta2)
    mask_neg = (mask == False).type_as(theta2)  # noqa

    # create output pose matrix
    batch_size = angle_axis.shape[0]
    rotation_matrix = torch.eye(3).to(angle_axis.device).type_as(angle_axis)
    rotation_matrix = rotation_matrix.view(1, 3, 3).repeat(batch_size, 1, 1)
    # fill output matrix with masked values
    rotation_matrix[..., :3, :3] = \
        mask_pos * rotation_matrix_normal + mask_neg * rotation_matrix_taylor
    return rotation_matrix  # Nx3x3
Example #35
def apply_self_attention_rules(R_ss, R_sq, cam_ss):
    R_sq_addition = torch.matmul(cam_ss, R_sq)
    R_ss_addition = torch.matmul(cam_ss, R_ss)
    return R_ss_addition, R_sq_addition
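
The helper returns additive relevancy updates, so a typical accumulation step looks like this (shapes hypothetical; `cam_ss` stands for a gradient-weighted attention map):

import torch

L_txt, L_img = 6, 4
cam_ss = torch.rand(L_txt, L_txt)    # gradient-weighted self-attention map
R_ss = torch.eye(L_txt)              # self-relevancy, initialized to identity
R_sq = torch.rand(L_txt, L_img)      # cross-modal relevancy

dR_ss, dR_sq = apply_self_attention_rules(R_ss, R_sq, cam_ss)
R_ss, R_sq = R_ss + dR_ss, R_sq + dR_sq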
Example #36
    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
            col_len, col_num, gt_where, gt_cond, reinforce):
        max_x_len = max(x_len)
        B = len(x_len)
        if reinforce:
            raise NotImplementedError('Our model doesn\'t have RL')

        # Predict the number of conditions
        # First use column embeddings to calculate the initial hidden unit
        # Then run the LSTM and predict condition number.
        e_num_col, col_num = col_name_encode(col_inp_var, col_name_len, col_len, self.cond_num_name_enc)
        num_col_att_val = self.cond_num_col_att(e_num_col).squeeze()
        for idx, num in enumerate(col_num):
            if num < max(col_num):
                num_col_att_val[idx, num:] = -100
        num_col_att = self.softmax(num_col_att_val)
        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
        cond_num_h1 = self.cond_num_col2hid1(K_num_col).view(B, 4, self.N_h//2).transpose(0, 1).contiguous()
        cond_num_h2 = self.cond_num_col2hid2(K_num_col).view(B, 4, self.N_h//2).transpose(0, 1).contiguous()

        h_num_enc, _ = run_lstm(self.cond_num_lstm, x_emb_var, x_len,
                hidden=(cond_num_h1, cond_num_h2))

        num_att_val = self.cond_num_att(h_num_enc).squeeze()

        for idx, num in enumerate(x_len):
            if num < max_x_len:
                num_att_val[idx, num:] = -100
        num_att = self.softmax(num_att_val)

        K_cond_num = (h_num_enc * num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
        cond_num_score = self.cond_num_out(K_cond_num)

        #Predict the columns of conditions
        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len, self.cond_col_name_enc)
        h_col_enc, _ = run_lstm(self.cond_col_lstm, x_emb_var, x_len)
        #h_col_enc = x_emb_var
        if self.use_ca:
            col_att_val = torch.bmm(e_cond_col,
                    self.cond_col_att(h_col_enc).transpose(1, 2))
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    col_att_val[idx, :, num:] = -100
            col_att = self.softmax(col_att_val.view(
                (-1, max_x_len))).view(B, -1, max_x_len)
            K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)
        else:
            col_att_val = self.cond_col_att(h_col_enc).squeeze()
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    col_att_val[idx, num:] = -100
            col_att = self.softmax(col_att_val)
            K_cond_col = (h_col_enc *
                    col_att.unsqueeze(2)).sum(1).unsqueeze(1)

        cond_col_score = self.cond_col_out(self.cond_col_out_K(K_cond_col) +
                self.cond_col_out_col(e_cond_col)).squeeze()
        max_col_num = max(col_num)
        for b, num in enumerate(col_num):
            if num < max_col_num:
                cond_col_score[b, num:] = -100


        #Predict the operator of conditions
        chosen_col_gt = []
        if gt_cond is None:
            cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
            col_scores = cond_col_score.data.cpu().numpy()
            chosen_col_gt = [list(np.argsort(-col_scores[b])[:cond_nums[b]])
                    for b in range(len(cond_nums))]
        else:
            # print gt_cond
            chosen_col_gt = [[x[0] for x in one_gt_cond] for one_gt_cond in gt_cond]

        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len,
                col_len, self.cond_op_name_enc)
        h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len)
        col_emb = []
        for b in range(B):
            cur_col_emb = torch.stack([e_cond_col[b, x]
                for x in chosen_col_gt[b]] + [e_cond_col[b, 0]] *
                (4 - len(chosen_col_gt[b])))  # Pad the columns to maximum (4)
            col_emb.append(cur_col_emb)
        col_emb = torch.stack(col_emb)

        if self.use_ca:
            op_att_val = torch.matmul(self.cond_op_att(h_op_enc).unsqueeze(1),
                    col_emb.unsqueeze(3)).squeeze()
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    op_att_val[idx, :, num:] = -100
            op_att = self.softmax(op_att_val.view(B*4, -1)).view(B, 4, -1)
            K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)
        else:
            op_att_val = self.cond_op_att(h_op_enc).squeeze()
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    op_att_val[idx, num:] = -100
            op_att = self.softmax(op_att_val)
            K_cond_op = (h_op_enc * op_att.unsqueeze(2)).sum(1).unsqueeze(1)

        cond_op_score = self.cond_op_out(self.cond_op_out_K(K_cond_op) +
                self.cond_op_out_col(col_emb)).squeeze()

        #Predict the string of conditions
        h_str_enc, _ = run_lstm(self.cond_str_lstm, x_emb_var, x_len)
        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len,
                col_len, self.cond_str_name_enc)
        col_emb = []
        for b in range(B):
            cur_col_emb = torch.stack([e_cond_col[b, x] for x in chosen_col_gt[b]] +
                                      [e_cond_col[b, 0]] * (4 - len(chosen_col_gt[b])))
            col_emb.append(cur_col_emb)
        col_emb = torch.stack(col_emb)

        if gt_where is not None:
            gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where)
            g_str_s_flat, _ = self.cond_str_decoder(
                    gt_tok_seq.view(B*4, -1, self.max_tok_num))
            g_str_s = g_str_s_flat.contiguous().view(B, 4, -1, self.N_h)

            h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
            g_ext = g_str_s.unsqueeze(3)
            col_ext = col_emb.unsqueeze(2).unsqueeze(2)

            cond_str_score = self.cond_str_out(
                    self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
                    self.cond_str_out_col(col_ext)).squeeze()
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cond_str_score[b, :, :, num:] = -100
        else:
            h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
            col_ext = col_emb.unsqueeze(2).unsqueeze(2)
            scores = []

            t = 0
            init_inp = np.zeros((B*4, 1, self.max_tok_num), dtype=np.float32)
            init_inp[:,0,0] = 1  #Set the <BEG> token
            if self.gpu:
                cur_inp = Variable(torch.from_numpy(init_inp).cuda())
            else:
                cur_inp = Variable(torch.from_numpy(init_inp))
            cur_h = None
            while t < 50:
                if cur_h:
                    g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp, cur_h)
                else:
                    g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp)
                g_str_s = g_str_s_flat.view(B, 4, 1, self.N_h)
                g_ext = g_str_s.unsqueeze(3)

                cur_cond_str_score = self.cond_str_out(
                        self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext)
                        + self.cond_str_out_col(col_ext)).squeeze()
                for b, num in enumerate(x_len):
                    if num < max_x_len:
                        cur_cond_str_score[b, :, num:] = -100
                scores.append(cur_cond_str_score)

                _, ans_tok_var = cur_cond_str_score.view(B*4, max_x_len).max(1)
                ans_tok = ans_tok_var.data.cpu()
                data = torch.zeros(B*4, self.max_tok_num).scatter_(
                        1, ans_tok.unsqueeze(1), 1)
                if self.gpu:  #To one-hot
                    cur_inp = Variable(data.cuda())
                else:
                    cur_inp = Variable(data)
                cur_inp = cur_inp.unsqueeze(1)

                t += 1

            cond_str_score = torch.stack(scores, 2)
            for b, num in enumerate(x_len):
                if num < max_x_len:
                    cond_str_score[b, :, :, num:] = -100  #[B, IDX, T, TOK_NUM]

        cond_score = (cond_num_score,
                cond_col_score, cond_op_score, cond_str_score)

        return cond_score
Example #37
    def _get_svqb(
            self,
            U,  # Tensor
            drop,  # bool
            tau  # float
    ):
        # type: (Tensor, bool, float) -> Tensor
        """Return B-orthonormal U.

        .. note:: When `drop` is `False` then `svqb` is based on the
                  Algorithm 4 from [DuerschPhD2015] that is a slight
                  modification of the corresponding algorithm
                  introduced in [StathopolousWu2002].

        Arguments:

          U (Tensor) : initial approximation, size is (m, n)
          drop (bool) : when True, drop columns whose
                     contribution to `span([U])` is small.
          tau (float) : positive tolerance

        Returns:

          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`), size
                       is (m, n1), where `n1 = n` if `drop` is `False`,
                       otherwise `n1 <= n`.

        """
        if torch.numel(U) == 0:
            return U
        UBU = _utils.qform(self.B, U)
        d = UBU.diagonal(0, -2, -1)

        # Detect and drop exact zero columns from U. While the test
        # `abs(d) == 0` is unlikely to be True for random data, it is
        # possible to construct input data to lobpcg where it will be
        # True leading to a failure (notice the `d ** -0.5` operation
        # in the original algorithm). To prevent the failure, we drop
        # the exact zero columns here and then continue with the
        # original algorithm below.
        nz = torch.where(abs(d) != 0.0)
        assert len(nz) == 1, nz
        if len(nz[0]) < len(d):
            U = U[:, nz[0]]
            if torch.numel(U) == 0:
                return U
            UBU = _utils.qform(self.B, U)
            d = UBU.diagonal(0, -2, -1)
            nz = torch.where(abs(d) != 0.0)
            assert len(nz[0]) == len(d)

        # The original algorithm 4 from [DuerschPhD2015].
        d_col = (d**-0.5).reshape(d.shape[0], 1)
        DUBUD = (UBU * d_col) * _utils.transpose(d_col)
        E, Z = _utils.symeig(DUBUD, eigenvectors=True)
        t = tau * abs(E).max()
        if drop:
            keep = torch.where(E > t)
            assert len(keep) == 1, keep
            E = E[keep[0]]
            Z = Z[:, keep[0]]
            d_col = d_col[keep[0]]
        else:
            E[(torch.where(E < t))[0]] = t

        return torch.matmul(U * _utils.transpose(d_col), Z * E**-0.5)
Example #38
	def __call__(self, data_view1, data_view2):
		H1 = data_view1.view(data_view1.size(0)*data_view1.size(1),data_view1.size(2),data_view1.size(3))
		H2 = data_view2.view(data_view2.size(0)*data_view2.size(1),data_view2.size(2),data_view2.size(3))

		r1 = 1e-4
		r2 = 1e-4
		eps = 1e-12
		corr_sum = 0
		o1 = o2 = H1.size(1)

		m = H1.size(0)
		n = H1.size(1)
		
		H1bar = H1 - (1.0 / m) * H1
		H2bar = H2 - (1.0 / m) * H2
		Hat12 = torch.zeros(m,n,n).cuda()
		Hat11 = torch.zeros(m,n,n).cuda()
		Hat22 = torch.zeros(m,n,n).cuda()
		#Hat12 = torch.zeros(m,n,n)
		#Hat11 = torch.zeros(m,n,n)
		#Hat22 = torch.zeros(m,n,n)


		for i in range(m):
			Hat11[i] = torch.matmul(H1bar[i],H1bar.transpose(1,2)[i])
			Hat12[i] = torch.matmul(H1bar[i],H2bar.transpose(1,2)[i])
			Hat22[i] = torch.matmul(H2bar[i],H2bar.transpose(1,2)[i])

		SigmaHat12 = (1.0 / (m - 1)) * torch.mean(Hat12,dim=0)
		SigmaHat11 = (1.0 / (m - 1)) * torch.mean(Hat11,dim=0)+ r1 * torch.eye(o1, device=self.device)
		SigmaHat22 = (1.0 / (m - 1)) * torch.mean(Hat22,dim=0) + r2 * torch.eye(o2, device=self.device)

		# Calculating the root inverse of covariance matrices by using eigen decomposition
		[D1, V1] = torch.symeig(SigmaHat11, eigenvectors=True)
		[D2, V2] = torch.symeig(SigmaHat22, eigenvectors=True)

		# Added to increase stability
		posInd1 = torch.gt(D1, eps).nonzero()[:, 0]
		D1 = D1[posInd1]
		V1 = V1[:, posInd1]
		posInd2 = torch.gt(D2, eps).nonzero()[:, 0]
		D2 = D2[posInd2]
		V2 = V2[:, posInd2]
		SigmaHat11RootInv = torch.matmul(
			torch.matmul(V1, torch.diag(D1 ** -0.5)), V1.t())
		SigmaHat22RootInv = torch.matmul(
			torch.matmul(V2, torch.diag(D2 ** -0.5)), V2.t())

		Tval = torch.matmul(torch.matmul(SigmaHat11RootInv,
								   SigmaHat12), SigmaHat22RootInv)

		if self.use_all_singular_values:
			# all singular values are used to calculate the correlation
			corr = torch.sqrt(torch.trace(torch.matmul(Tval.t(), Tval)))
		else:
			# just the top self.outdim_size singular values are used
			U, V = torch.symeig(torch.matmul(Tval.t(), Tval), eigenvectors=True)
			U = U[torch.gt(U, eps).nonzero()[:, 0]]
			U = U.topk(self.outdim_size)[0]
			corr = torch.sum(torch.sqrt(U))
		return -corr
Example #39
def _AXWb(A, X, W, b):
    X = th.matmul(X, W)
    Y = th.matmul(A, X.view(X.shape[0], -1)).view_as(X)
    return Y + b
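
A quick shape check with a small hypothetical graph: `A` is a dense adjacency matrix, and `X` carries an extra per-node feature dimension that the view/view_as round trip preserves:

import torch as th

A = th.rand(5, 5)       # dense adjacency for 5 nodes
X = th.rand(5, 3, 4)    # node features with an extra dimension
W = th.rand(4, 4)
b = th.rand(4)
print(_AXWb(A, X, W, b).shape)   # torch.Size([5, 3, 4])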
Example #40
def projection_transR_pytorch(original, proj_matrix):
    ent_embedding_size = original.shape[1]
    rel_embedding_size = proj_matrix.shape[1] // ent_embedding_size
    original = original.view(-1, ent_embedding_size, 1)
    proj_matrix = proj_matrix.view(-1, rel_embedding_size, ent_embedding_size)
    return torch.matmul(proj_matrix, original).view(-1, rel_embedding_size)
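
A shape check with hypothetical embedding sizes: 50-dimensional entities projected into a 30-dimensional relation space, each row of proj_matrix being a flattened 30x50 projection matrix:

import torch

ent = torch.randn(7, 50)            # entity embeddings
proj = torch.randn(7, 30 * 50)      # flattened per-triple projection matrices
print(projection_transR_pytorch(ent, proj).shape)   # torch.Size([7, 30])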
Example #41
print()

# scalar multiplication
print('Scalar Multiplication in Numpy', 3*A)
print('Scalar Multiplication in PyTorch', 3*A_tensor)
print()

# elementwise multiplication
print('Elementwise Multiplication in Numpy', A*B)
print('Elementwise Multiplication in PyTorch', A_tensor*B_tensor)
print()

# matrix multiplication
# this is slightly different from NumPy
print('Matrix Multiplication in Numpy', A@B)
print('Matrix Multiplication in PyTorch', torch.matmul(A_tensor, B_tensor))
print()

# Elementwise comparison
print('NumPy: ', A == B)
print('PyTorch: ', A_tensor == B_tensor)
print()

# Generate a random matrix
C = np.array([[10, 9, 8], [6, 7, 5], [1, 2, 3]])
C_tensor = torch.Tensor(C)

print('C', C)
print()

# Sum along the row
Example #42
    def forward(self, x):
        r"""
        The :func:`~gpytorch.variational.VariationalStrategy.forward` method determines how to marginalize out the
        inducing point function values. Specifically, forward defines how to transform a variational distribution
        over the inducing point values, :math:`q(u)`, in to a variational distribution over the function values at
        specified locations x, :math:`q(f|x)`, by integrating :math:`\int p(f|x, u)q(u)du`

        :param torch.Tensor x: Locations x to get the variational posterior of the function values at.
        :rtype: ~gpytorch.distributions.MultivariateNormal
        :return: The distribution :math:`q(f|x)`
        """
        variational_dist = self.variational_distribution
        inducing_points = self.inducing_points
        if inducing_points.dim() < x.dim():
            inducing_points = inducing_points.expand(*x.shape[:-2], *inducing_points.shape[-2:])
        if len(variational_dist.batch_shape) < x.dim() - 2:
            variational_dist = variational_dist.expand(x.shape[:-2])

        # If our points equal the inducing points, we're done
        if torch.equal(x, inducing_points):
            # De-whiten the prior covar
            prior_covar = self.prior_distribution.lazy_covariance_matrix
            if isinstance(variational_dist.lazy_covariance_matrix, RootLazyTensor):
                predictive_covar = RootLazyTensor(prior_covar @ variational_dist.lazy_covariance_matrix.root.evaluate())
            else:
                predictive_covar = MatmulLazyTensor(prior_covar @ variational_dist.covariance_matrix, prior_covar)

            # Cache some values for the KL divergence
            if self.training:
                self._mean_diff_inv_quad_memo, self._logdet_memo = prior_covar.inv_quad_logdet(
                    (variational_dist.mean - self.prior_distribution.mean), logdet=True
                )

            return MultivariateNormal(variational_dist.mean, predictive_covar)

        # Otherwise, we have to marginalize
        else:
            num_induc = inducing_points.size(-2)
            full_inputs = torch.cat([inducing_points, x], dim=-2)
            full_output = self.model.forward(full_inputs)
            full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix

            # Mean terms
            test_mean = full_mean[..., num_induc:]
            induc_mean = full_mean[..., :num_induc]
            mean_diff = (variational_dist.mean - induc_mean).unsqueeze(-1)

            # Covariance terms
            induc_induc_covar = full_covar[..., :num_induc, :num_induc].add_jitter()
            induc_data_covar = full_covar[..., :num_induc, num_induc:].evaluate()
            data_data_covar = full_covar[..., num_induc:, num_induc:]

            # If we're less than a certain size, we'll compute the Cholesky decomposition of induc_induc_covar
            cholesky = False
            if settings.fast_computations.log_prob.off() or (num_induc <= settings.max_cholesky_size.value()):
                induc_induc_covar = CholLazyTensor(induc_induc_covar.cholesky())
                cholesky = True

            # Cache the CG results
            # Do not use preconditioning for whitened VI, as it does not seem to improve performance.
            with settings.max_preconditioner_size(0):
                with torch.no_grad():
                    eager_rhs = torch.cat([induc_data_covar, mean_diff], -1)
                    solve, probe_vecs, probe_vec_norms, probe_vec_solves, tmats = CachedCGLazyTensor.precompute_terms(
                        induc_induc_covar,
                        eager_rhs.detach(),
                        logdet_terms=(not cholesky),
                        include_tmats=(not settings.skip_logdet_forward.on() and not cholesky),
                    )
                    eager_rhss = [eager_rhs.detach()]
                    solves = [solve.detach()]
                    if settings.skip_logdet_forward.on() and self.training:
                        eager_rhss.append(torch.cat([probe_vecs, eager_rhs], -1))
                        solves.append(torch.cat([probe_vec_solves, solve[..., : eager_rhs.size(-1)]], -1))
                    elif not self.training:
                        eager_rhss.append(eager_rhs[..., :-1])
                        solves.append(solve[..., :-1])

                induc_induc_covar = CachedCGLazyTensor(
                    induc_induc_covar,
                    eager_rhss=eager_rhss,
                    solves=solves,
                    probe_vectors=probe_vecs,
                    probe_vector_norms=probe_vec_norms,
                    probe_vector_solves=probe_vec_solves,
                    probe_vector_tmats=tmats,
                )

            # Compute some terms that will be necessary for the predictive covariance and KL divergence
            if self.training:
                interp_data_data_var_plus_mean_diff_inv_quad, logdet = induc_induc_covar.inv_quad_logdet(
                    torch.cat([induc_data_covar, mean_diff], -1), logdet=True, reduce_inv_quad=False
                )
                interp_data_data_var = interp_data_data_var_plus_mean_diff_inv_quad[..., :-1]
                mean_diff_inv_quad = interp_data_data_var_plus_mean_diff_inv_quad[..., -1]

            # Compute predictive mean
            predictive_mean = torch.add(
                test_mean,
                induc_induc_covar.inv_matmul(mean_diff, left_tensor=induc_data_covar.transpose(-1, -2)).squeeze(-1),
            )

            # Compute the predictive covariance
            is_root_lt = isinstance(variational_dist.lazy_covariance_matrix, RootLazyTensor)
            is_repeated_root_lt = isinstance(
                variational_dist.lazy_covariance_matrix, BatchRepeatLazyTensor
            ) and isinstance(variational_dist.lazy_covariance_matrix.base_lazy_tensor, RootLazyTensor)
            if is_root_lt:
                predictive_covar = RootLazyTensor(
                    induc_data_covar.transpose(-1, -2) @ variational_dist.lazy_covariance_matrix.root.evaluate()
                )
            elif is_repeated_root_lt:
                predictive_covar = RootLazyTensor(
                    induc_data_covar.transpose(-1, -2)
                    @ variational_dist.lazy_covariance_matrix.root_decomposition().root.evaluate()
                )
            else:
                predictive_covar = MatmulLazyTensor(
                    induc_data_covar.transpose(-1, -2),
                    variational_dist.lazy_covariance_matrix @ induc_data_covar,
                )

            if self.training:
                data_covariance = DiagLazyTensor((data_data_covar.diag() - interp_data_data_var).clamp(0, math.inf))
            else:
                neg_induc_data_data_covar = torch.matmul(
                    induc_data_covar.transpose(-1, -2).mul(-1), induc_induc_covar.inv_matmul(induc_data_covar)
                )
                data_covariance = data_data_covar + neg_induc_data_data_covar
            predictive_covar = PsdSumLazyTensor(predictive_covar, data_covariance)

            # Save the logdet, mean_diff_inv_quad, prior distribution for the ELBO
            if self.training:
                self._memoize_cache["prior_distribution_memo"] = MultivariateNormal(induc_mean, induc_induc_covar)
                self._memoize_cache["logdet_memo"] = -logdet
                self._memoize_cache["mean_diff_inv_quad_memo"] = mean_diff_inv_quad

            return MultivariateNormal(predictive_mean, predictive_covar)
Example #43
    def score(self, h, t, r):
        # view as matrix
        R = r.view(*r.shape[:-1], self.embedding_dim, self.embedding_dim)

        product = torch.matmul(h.unsqueeze(-2), R)
        return torch.matmul(product, t.unsqueeze(-1)).squeeze()
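
The score is the bilinear form h^T R t evaluated per triple; a standalone re-derivation with a hypothetical embedding_dim of 16:

import torch

n, dim = 4, 16
h = torch.randn(n, dim)             # head entity embeddings
t = torch.randn(n, dim)             # tail entity embeddings
r = torch.randn(n, dim * dim)       # flattened relation matrices

R = r.view(n, dim, dim)
score = torch.matmul(torch.matmul(h.unsqueeze(-2), R), t.unsqueeze(-1)).squeeze()
print(score.shape)                  # torch.Size([4])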
Example #44
def mtimes(a: torch.Tensor, b: torch.Tensor, conj_a=False, conj_b=False):
    """Complex matrix multiplication of complex tensors.
    The dimensions (-3, -2) are matrix multiplied. -1 is the complex dimension."""

    if is_real(a):
        if a.dim() >= b.dim():
            raise ValueError('Incorrect dimensions.')
        return mtimes_real_complex(a, b, conj_b=conj_b)
    if is_real(b):
        if b.dim() >= a.dim():
            raise ValueError('Incorrect dimensions.')
        return mtimes_complex_real(a, b, conj_a=conj_a)

    if not conj_a and not conj_b:
        return complex(torch.matmul(a[..., 0], b[..., 0]) - torch.matmul(a[..., 1], b[..., 1]),
                       torch.matmul(a[..., 0], b[..., 1]) + torch.matmul(a[..., 1], b[..., 0]))
    if conj_a and not conj_b:
        return complex(torch.matmul(a[..., 0], b[..., 0]) + torch.matmul(a[..., 1], b[..., 1]),
                       torch.matmul(a[..., 0], b[..., 1]) - torch.matmul(a[..., 1], b[..., 0]))
    if not conj_a and conj_b:
        return complex(torch.matmul(a[..., 0], b[..., 0]) + torch.matmul(a[..., 1], b[..., 1]),
                       torch.matmul(a[..., 1], b[..., 0]) - torch.matmul(a[..., 0], b[..., 1]))
    if conj_a and conj_b:
        return complex(torch.matmul(a[..., 0], b[..., 0]) - torch.matmul(a[..., 1], b[..., 1]),
                       -torch.matmul(a[..., 0], b[..., 1]) - torch.matmul(a[..., 1], b[..., 0]))
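
A rough correctness check against PyTorch's native complex matmul, assuming the module's is_real/complex helpers treat a trailing dimension of size 2 as (real, imag) pairs:

import torch

a = torch.randn(3, 4, 2)            # trailing dim holds (real, imag)
b = torch.randn(4, 5, 2)
out = mtimes(a, b)                  # expected shape (3, 5, 2)
ref = torch.view_as_complex(a.contiguous()) @ torch.view_as_complex(b.contiguous())
print(torch.allclose(out, torch.view_as_real(ref), atol=1e-5))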
Example #45
    def generate_transformer_attr(self,
                                  input,
                                  index=None,
                                  method_name="transformer_attr"):
        kwargs = {"alpha": 1}
        output = self.model_usage.forward(input).question_answering_score
        model = self.model_usage.model

        # initialize relevancy matrices
        text_tokens = self.model_usage.text_len
        image_bboxes = self.model_usage.image_boxes_len

        # text self attention matrix
        self.R_t_t = torch.eye(text_tokens, text_tokens).to(model.device)
        # image self attention matrix
        self.R_i_i = torch.eye(image_bboxes, image_bboxes).to(model.device)
        # impact of images on text
        self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device)
        # impact of text on images
        self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device)

        if index is None:
            index = np.argmax(output.cpu().data.numpy(), axis=-1)

        one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
        one_hot[0, index] = 1
        one_hot_vector = one_hot
        one_hot = torch.from_numpy(one_hot).requires_grad_(True)
        one_hot = torch.sum(one_hot.cuda() * output)

        model.zero_grad()
        one_hot.backward(retain_graph=True)
        model.relprop(torch.tensor(one_hot_vector).to(output.device), **kwargs)

        # language self attention
        blocks = model.lxmert.encoder.layer
        for blk in blocks:
            grad = blk.attention.self.get_attn_gradients().detach()
            cam = blk.attention.self.get_attn_cam().detach()
            cam = avg_heads(cam, grad)
            self.R_t_t += torch.matmul(cam, self.R_t_t)

        # image self attention
        blocks = model.lxmert.encoder.r_layers
        for blk in blocks:
            grad = blk.attention.self.get_attn_gradients().detach()
            cam = blk.attention.self.get_attn_cam().detach()
            cam = avg_heads(cam, grad)
            self.R_i_i += torch.matmul(cam, self.R_i_i)

        # cross attn layers
        blocks = model.lxmert.encoder.x_layers
        for i, blk in enumerate(blocks):
            # in the last cross attention module, only the text cross modal
            # attention has an impact on the CLS token, since it's the first
            # token in the language tokens
            if i == len(blocks) - 1:
                break

            # language self attention
            grad = blk.lang_self_att.self.get_attn_gradients().detach()
            cam = blk.lang_self_att.self.get_attn_cam().detach()
            cam = avg_heads(cam, grad)
            self.R_t_t += torch.matmul(cam, self.R_t_t)

            # image self attention
            grad = blk.visn_self_att.self.get_attn_gradients().detach()
            cam = blk.visn_self_att.self.get_attn_cam().detach()
            cam = avg_heads(cam, grad)
            self.R_i_i += torch.matmul(cam, self.R_i_i)

        # take care of last cross attention layer- only text
        blk = model.lxmert.encoder.x_layers[-1]
        # cross attn cam will be the one used for the R_t_i matrix
        cam_t_i = blk.visual_attention.att.get_attn_cam().detach()
        grad_t_i = blk.visual_attention.att.get_attn_gradients().detach()
        cam_t_i = avg_heads(cam_t_i, grad_t_i)
        # self.R_t_i = torch.matmul(self.R_t_t.t(), torch.matmul(cam_t_i, self.R_i_i))
        self.R_t_i = cam_t_i

        # language self attention
        grad = blk.lang_self_att.self.get_attn_gradients().detach()
        cam = blk.lang_self_att.self.get_attn_cam().detach()
        cam = avg_heads(cam, grad)
        self.R_t_t += torch.matmul(cam, self.R_t_t)

        self.R_t_t[0, 0] = 0
        return self.R_t_t, self.R_t_i
Example #46
    def routing(self, x):	
        """
        Routing algorithm for capsule.

        :input: tensor x of shape [128, 8, 1152]

        :return: vector output of capsule j
        """
        batch_size = x.size(0)
        #print ('routing, batch_size:', batch_size)       
        #print (x.shape)
        
        x = x.transpose(1, 2) # dim 1 and dim 2 are swapped. out tensor shape: [128, 1152, 8]

        # Stacking and adding a dimension to a tensor.
        # stack ops output shape: [128, 1152, 10, 8]
        # unsqueeze ops output shape: [128, 1152, 10, 8, 1]
        x = torch.stack([x] * self.num_unit, dim=2).unsqueeze(4)

        # Convert single weight to batch weight.
        # [1 x 1152 x 10 x 16 x 8] to: [128, 1152, 10, 16, 8]
        batch_weight = torch.cat([self.weight] * batch_size, dim=0)

        # u_hat is "prediction vectors" from the capsules in the layer below.
        # Transform inputs by weight matrix.
        # Matrix product of 2 tensors with shape: [128, 1152, 10, 16, 8] x [128, 1152, 10, 8, 1]
        # u_hat shape: [128, 1152, 10, 16, 1]
        u_hat = torch.matmul(batch_weight, x)

        # All the routing logits (b_ij in the paper) are initialized to zero.
        # self.in_channel = primary_unit_size = 32 * 6 * 6 = 1152
        # self.num_unit = num_classes = 10
        # b_ij shape: [1, 1152, 10, 1]
        b_ij = Variable(torch.zeros(1, self.in_channel, self.num_unit, 1))
        if self.cuda_enabled:
            b_ij = b_ij.cuda()

        # From the paper in the "Capsules on MNIST" section,
        # the sample MNIST test reconstructions of a CapsNet with 3 routing iterations.
        num_iterations = self.num_routing

        for iteration in range(num_iterations):
            # Routing algorithm

            # Calculate the routing coefficients, also known as coupling coefficients (c_ij).
            # c_ij shape: [1, 1152, 10, 1]
            c_ij = utils.softmax(b_ij, dim=2)  # Convert routing logits (b_ij) to softmax.
            # c_ij shape from: [128, 1152, 10, 1] to: [128, 1152, 10, 1, 1]
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)

            # Implement equation 2 in the paper.
            # s_j is the total input to capsule j: a weighted sum over all
            # "prediction vectors" u_hat (the predictions u_hat_{j|i} made by capsule i).
            # c_ij * u_hat shape: [128, 1152, 10, 16, 1]
            # s_j output shape: [batch_size=128, 1, 10, 16, 1]
            # Sum of Primary Capsules outputs, 1152D becomes 1D.
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)

            # Squash the vector output of capsule j.
            # v_j shape: [batch_size, weighted sum of PrimaryCaps output,
            #             num_classes, output_unit_size from u_hat, 1]
            # == [128, 1, 10, 16, 1]
            # So, the length of the output vector of a capsule is 16, which is in dim 3.
            v_j = utils.squash(s_j, dim=3)

            # in_channel is 1152.
            # v_j1 shape: [128, 1152, 10, 16, 1]
            v_j1 = torch.cat([v_j] * self.in_channel, dim=1)

            # The agreement.
            # Transpose u_hat with shape [128, 1152, 10, 16, 1] to [128, 1152, 10, 1, 16],
            # so we can do matrix product u_hat and v_j1.
            # u_vj1 shape: [1, 1152, 10, 1]
            u_vj1 = torch.matmul(u_hat.transpose(3, 4), v_j1).squeeze(4).mean(dim=0, keepdim=True)

            # Update routing (b_ij) by adding the agreement to the initial logit.
            b_ij = b_ij + u_vj1

        return v_j.squeeze(1) # shape: [128, 10, 16, 1]
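utils.squash is not shown in this snippet; below is a minimal sketch of the squash nonlinearity from the CapsNet paper (Sabour et al., 2017), assuming the same squash(s, dim) calling convention used above.

import torch

def squash(s, dim=-1, eps=1e-8):
    # v = (||s||^2 / (1 + ||s||^2)) * (s / ||s||), computed along `dim`
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1.0 + squared_norm)
    return scale * s / torch.sqrt(squared_norm + eps)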
Ejemplo n.º 47
0
def D(dat, Theta):
    # Per-sample Bernoulli log-likelihood of a logistic model; p (feature
    # count) and n (sample count) are globals defined elsewhere, and
    # dat[:, p] holds the binary label.
    c = torch.matmul(dat[:, 0:p], Theta.t())
    c = torch.clamp(c, -20, 20)  # guard against overflow in exp(-c)
    out = -1.0 * (1 - dat[:, p].reshape(n, 1)) * c - torch.log(1.0 + torch.exp(-c))
    return out
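As a side note, the clamp can be avoided: since -log(1 + exp(-c)) = logsigmoid(c), the same terms can be computed stably with torch.nn.functional.logsigmoid. A sketch, keeping the module-level p and n from the snippet above:

import torch
import torch.nn.functional as F

def D_stable(dat, Theta):
    c = torch.matmul(dat[:, 0:p], Theta.t())
    y = dat[:, p].reshape(n, 1)
    # -(1 - y) * c - log(1 + exp(-c)), with the log term via logsigmoid
    return -(1.0 - y) * c + F.logsigmoid(c)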
Ejemplo n.º 48
0
    def layer_A1(self):
        G = defaultdict(lambda: {})
        N = self.config.N
        M = self.config.M
        tau = self.config.tau
        mtsp_instance = self.mtsp_instance

        graph = mtsp_instance.graph
        depot = mtsp_instance.depot

        for i, v in enumerate(graph.V):
            for j, u in enumerate(graph.V):
                # append g_ij
                vu_dist = graph[v][u] / 1000
                xv_dist = graph[depot][v] / 1000
                xu_dist = graph[depot][u] / 1000

                a = torch.FloatTensor([vu_dist, xv_dist, xu_dist]).to(device)
                G[v][u] = torch.matmul(self.W1,
                                       F.relu(torch.matmul(self.W2, a)))
                # print(G[v][u])

        T = self.config.T

        # random mu initialization
        mu = {}
        next_mu = {}

        for v in mtsp_instance.remaining_cities:
            mu[v] = torch.rand(M, 1).to(device)
            next_mu[v] = None

        dist_from_robot = mtsp_instance.distance_from_robot()

        # main loop for struct2vec
        for t in range(T):
            for v in mtsp_instance.remaining_cities:

                Z = sum([
                    exp(G[u][v] / tau) for u in mtsp_instance.remaining_cities
                    if v != u
                ])
                p = [
                    exp(G[u][v] / tau) / Z
                    for u in mtsp_instance.remaining_cities if v != u
                ]  # softmax of G given tau

                l_not_weighted = [torch.cat(\
                    (graph[u][v] * F.relu(torch.matmul(self.W5, mu[v])), mu[u])) \
                        for u in mtsp_instance.remaining_cities if u != v]
                # print(l_not_weighted[0].shape, 123213)
                l = sum([p_uv * l_uv for p_uv, l_uv in zip(p, l_not_weighted)])

                a = torch.matmul(self.W3_A1, l)
                next_mu[v] = F.relu(a
                    + torch.matmul(self.W4_A1,
                        torch.FloatTensor([[dist_from_robot[v]]]).to(device)))

            mu = dict(next_mu)  # copy: plain assignment would alias mu and next_mu

        return mu
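The Z and p computation inside the loop is a temperature softmax over the scores G[u][v]. An equivalent vectorized form, assuming the per-city scores are collected as 0-d tensors (temperature_softmax is a hypothetical helper name):

import torch

def temperature_softmax(scores, tau):
    # scores: list of 0-d tensors; returns p_u = exp(s_u / tau) / sum_w exp(s_w / tau)
    return torch.softmax(torch.stack(scores) / tau, dim=0)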
Ejemplo n.º 49
0
def rpsm(cams, heatmaps, boxes, grid_center, limb_length, pairwise_constraint,
         config, **kwargs):
    """
    Args:
        cams : camera parameters for each view
        heatmaps: 2d pose heatmaps (n, k, h, w)
        boxes: on which the heatmaps are computed; n dictionaries
        grid_center: 3d location of the root
        limb_length: template limb length
        pairwise_constraint: pre-computed pairwise terms (iteration 0 psm only)
    Returns:
        pose3d: 3d pose
    """
    image_size = config.NETWORK.IMAGE_SIZE
    heatmap_size = config.NETWORK.HEATMAP_SIZE
    first_nbins = config.PICT_STRUCT.FIRST_NBINS
    recur_nbins = config.PICT_STRUCT.RECUR_NBINS
    recur_depth = config.PICT_STRUCT.RECUR_DEPTH
    grid_size = config.PICT_STRUCT.GRID_SIZE
    tolerance = config.PICT_STRUCT.LIMB_LENGTH_TOLERANCE

    # Iteration 1: discretizing 3d space
    # body = HumanBody()
    # current_device = torch.device('cuda:{}'.format(pairwise_constraint.values()[0].get_device()))
    current_device = kwargs['current_device']
    body = kwargs['human_body']
    grid = compute_grid(grid_size, grid_center, first_nbins)

    heatmaps = torch.as_tensor(heatmaps, dtype=torch.float32).to(
        current_device)  # todo: do this in dataloader
    extra_kwargs = kwargs

    # PSM
    do_bone_vectors = False
    if 'do_bone_vectors' in kwargs:
        if kwargs['do_bone_vectors']:
            do_bone_vectors = True
            bone_vectors = kwargs['bone_vectors']
    if do_bone_vectors:
        # merge limb length pairwise and bone orientation/vector pairwise term
        orient_pairwise = kwargs['orient_pairwise']
        new_pairwise_constraint = {}
        for node in body.skeleton:
            current = node['idx']
            children = node['children']
            bone_index = node['imubone']

            for idx_child, child in enumerate(children):
                constrain_array = pairwise_constraint[(current, child)]
                if bone_index[idx_child] >= 0:  # if this bone has an IMU
                    expect_orient_vector = bone_vectors[bone_index[idx_child]]
                    expect_orient_vector = torch.as_tensor(
                        expect_orient_vector,
                        dtype=torch.float32).to(current_device)
                    norm_expect_orient_vector = expect_orient_vector / (
                        torch.norm(expect_orient_vector) + 1e-9)
                    norm_expect_orient_vector = norm_expect_orient_vector.view(
                        -1)  # (3,)
                    actual_orient_vector = orient_pairwise  # (4096, 4096, 3)
                    cos_theta = torch.matmul(actual_orient_vector,
                                             -norm_expect_orient_vector)
                    # todo: we could add a cos_theta activation func here
                    # actual_orient_vector holds the directions between bin pairs;
                    # norm_expect_orient_vector is the ground-truth (IMU) direction.
                    constrain_array = torch.mul(constrain_array, cos_theta)

                new_pairwise_constraint[(current, child)] = constrain_array
        pairwise_constraint = new_pairwise_constraint
    unary = compute_unary_term(heatmaps, [grid], boxes, cams, image_size)
    pose3d_as_cube_idx = infer(unary, pairwise_constraint, body, config,
                               **extra_kwargs)
    pose3d = get_loc_from_cube_idx([grid], pose3d_as_cube_idx)

    cur_grid_size = grid_size / first_nbins
    for i in range(recur_depth):
        pose3d = recursive_infer(pose3d, cams, heatmaps, boxes, image_size,
                                 heatmap_size, body, limb_length,
                                 cur_grid_size, recur_nbins, tolerance, config,
                                 **extra_kwargs)
        cur_grid_size = cur_grid_size / recur_nbins

    return pose3d
Ejemplo n.º 50
0
    def forward(self):

        A1 = self.layer_A1()
        A2 = self.layer_A2()

        layer_B_input = {}

        for v in self.mtsp_instance.remaining_cities:
            a = A1[v]
            b = A2[v]
            layer_B_input[v] = torch.cat((a, b))

        G = defaultdict(lambda: {})
        N = self.config.N
        M = self.config.M
        tau = self.config.tau
        mtsp_instance = self.mtsp_instance

        graph = mtsp_instance.graph
        depot = mtsp_instance.depot

        for i, v in enumerate(graph.V):
            for j, u in enumerate(graph.V):
                # append g_ij
                vu_dist = graph[v][u] / 1000
                xv_dist = graph[depot][v] / 1000
                xu_dist = graph[depot][u] / 1000

                a = torch.FloatTensor([vu_dist, xv_dist, xu_dist]).to(device)
                G[v][u] = torch.matmul(self.W1,
                                       F.relu(torch.matmul(self.W2, a)))
                # print(G[v][u])

        T = self.config.T

        # random mu initialization
        mu = {}
        next_mu = {}

        for v in mtsp_instance.remaining_cities:
            mu[v] = torch.rand(M, 1).to(device)  # .to(device) added for consistency with layer_A1
            next_mu[v] = None

        # main loop for struct2vec
        for t in range(T):
            for v in mtsp_instance.remaining_cities:

                Z = sum([
                    exp(G[u][v] / tau) for u in mtsp_instance.remaining_cities
                    if v != u
                ])
                p = [
                    exp(G[u][v] / tau) / Z
                    for u in mtsp_instance.remaining_cities if v != u
                ]  # softmax of G given tau

                l_not_weighted = [
                    mu[u] for u in mtsp_instance.remaining_cities if u != v
                ]

                l = sum([p_uv * l_uv for p_uv, l_uv in zip(p, l_not_weighted)])

                next_mu[v] = F.relu(\
                    torch.matmul(self.W3_B, l) \
                    + torch.matmul(self.W4_B, layer_B_input[v]))
                # print(layer_B_input[v].shape, 123)
            mu = dict(next_mu)  # copy: plain assignment would alias mu and next_mu

        return F.relu(
            torch.matmul(self.W7,
                         sum([mu[v] for v in mtsp_instance.remaining_cities])))
Ejemplo n.º 51
0
    def forward(
            self,
            wid_label: torch.LongTensor,  # shape B; string number -> int number
            wiki_title_meta: List[str],  # len-B list
            mention_surface_meta: List[str],  # len-B list
            mention_sent_lt: torch.FloatTensor,  # shape B x seq_len x 300 (word2vec)
            mention_sent_rt: torch.FloatTensor,  # shape B x seq_len x 300 (word2vec)
            type_labels: torch.LongTensor,  # shape B x seq_len --> dense one-hot vector from reader [no padding]
            coherence_labels: List[List[int]],
            mention_sent_lt_len: List[int],
            mention_sent_rt_len: List[int],
            cross_wiki_candidates: List[List[int]],
            cross_wiki_priors: List[List[float]]):
        batch_size = wid_label.shape[0]
        mask_left = self.make_mask(batch_size, mention_sent_lt.shape[1],
                                   mention_sent_lt_len)
        sent_lt_encoded = self.left_seq2vec(
            mention_sent_lt,
            mask=mask_left)  # build mask manually for ArrayField

        mask_right = self.make_mask(batch_size, mention_sent_rt.shape[1],
                                    mention_sent_rt_len)
        sent_rt_encoded = self.right_seq2vec(mention_sent_rt, mask=mask_right)

        sent = torch.cat(
            (sent_lt_encoded, sent_rt_encoded),
            dim=1).to(self.device)  # dim 0 is the batch dim, so concatenate on dim 1

        v_local = self.ff_seq2vecs(sent)  # B x 200 --> B x 100

        if self.use_coherence:
            v_coh_batch = torch.zeros((batch_size, 100)).to(self.device)
            for i in range(batch_size):
                coherence_ind = torch.LongTensor(coherence_labels[i]).to(
                    self.device)
                v_coh1 = self.coherence_embedder(coherence_ind)
                v_coh2 = torch.sum(v_coh1, dim=0).view(1, -1).to(self.device)
                v_coh_batch[i, :] = v_coh2
            v_coh_batch = self.coherence_embedder_relu(
                v_coh_batch)  # todo check

            v_local = torch.cat((v_local, v_coh_batch), dim=1)

            v_local = self.ff_context(v_local)

        if self.use_type:
            v_type = self.ff_type(type_labels)

        loss = 0.0
        v_e_list = []
        for i in range(batch_size):
            entity_id = torch.LongTensor(cross_wiki_candidates[i]).to(
                self.device)  # 1 x C (C > 1)
            # C (the number of candidate classes) changes per example; pad with zero/unk_wid?
            target = torch.LongTensor([0]).to(self.device)
            v_e = self.entity_embedder(entity_id)  # C x 200
            v_local_cur = v_local[i].view(-1, 1)  # 200 x 1
            score = torch.matmul(v_e, v_local_cur).view(1, -1)  # 1 x C
            temp = self.loss_context(score, target)
            loss += temp
            v_e_list.append(v_e)

        output_dict = {
            'loss': loss
        }  #, 'v_local': v_local, 'wid_label': wid_label }
        # 'candidates': cross_wiki_candidates, 'priors': cross_wiki_priors,
        # , 'v_e_list': v_e_list }

        # compute accuracy metrics
        #with torch.no_grad():
        #    max_candidates = max( (len(t) for t in cross_wiki_candidates) )
        #    predictions_per_batch = []
        #    true_labels = wid_label
        #    for i in range(batch_size):   # i is for batch
        #        # list of prior probabilities len > 1    list 1 x C
        #        cur_prior = torch.FloatTensor(cross_wiki_priors[i]).view(1, -1).to(self.device)
        #        cur_ve = v_e_list[i].view(-1, 200)             # C x 200       tensor
        #        cur_vm = v_local[i].view(200, 1)            # 200 x 1    tensor

        #        prob_text = torch.exp(torch.matmul(cur_ve, cur_vm)).view(-1, 1)
        #        prob_text = prob_text / torch.sum(prob_text).to(self.device)     # C x 1

        #        temp = torch.zeros(max_candidates).to(self.device)
        #        prob_text = torch.squeeze(prob_text)
        #        cur_prior = torch.squeeze(cur_prior)
        #        temp[:len(cur_prior)] = cur_prior + prob_text - cur_prior * prob_text       # when this value can be nan?

        #        predictions_per_batch.append(temp)

        #    predictions = torch.stack(predictions_per_batch)
        #    for metric in self.metrics.values():
        #        metric(predictions=predictions,
        #               gold_labels=torch.zeros(len(true_labels)))  # true label is located at zero
        return output_dict
Ejemplo n.º 52
0
def knn(x, k):
    inner = -2 * torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x**2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)
    idx = pairwise_distance.topk(k=k, dim=-1)[1]
    return idx
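knn uses the expansion -||x_i - x_j||^2 = 2 x_i·x_j - ||x_i||^2 - ||x_j||^2, so topk over the negated squared distances returns the k nearest neighbours. A quick self-contained check against the definition above:

import torch

x = torch.randn(2, 3, 10)   # (batch, dims, num_points), as knn expects
idx = knn(x, k=4)           # (2, 10, 4): each point's 4 nearest neighbours
# a point's own index comes first, since its distance 0 is the maximum
assert (idx[..., 0] == torch.arange(10)).all()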
Ejemplo n.º 53
0
    def prob_h_given_v(self, v):
        # p(h = 1 | v) = sigmoid(v W^T + b_h) for an RBM
        return (torch.matmul(v, self.weights.data.t(), out=None).add_(
            self.hidden_bias.data).sigmoid_().clamp_(min=0, max=1))
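Given p(h = 1 | v), a Gibbs step samples binary hidden units with torch.bernoulli. A minimal sketch, assuming the same weights / hidden_bias attributes as above (sample_h_given_v is a hypothetical name):

import torch

def sample_h_given_v(self, v):
    # Bernoulli sample from p(h = 1 | v)
    p_h = self.prob_h_given_v(v)
    return torch.bernoulli(p_h)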
Ejemplo n.º 54
0
    def gnn(self, xs, A, layer):
        for i in range(layer):
            hs = torch.relu(self.W_gnn[i](xs))
            xs = xs + torch.matmul(A, hs)
        # return torch.unsqueeze(torch.sum(xs, 0), 0)
        return torch.unsqueeze(torch.mean(xs, 0), 0)
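Each layer adds neighbour-aggregated messages A·relu(W_i xs) back onto the node features (a residual, GCN-style update) and mean-pools the nodes into a single graph vector. A toy invocation, where model is a hypothetical module exposing the gnn method above and W_gnn is assumed to be an nn.ModuleList of nn.Linear layers:

import torch

n_nodes, dim = 5, 16
xs = torch.randn(n_nodes, dim)                    # node features
A = (torch.rand(n_nodes, n_nodes) < 0.3).float()  # random adjacency matrix
out = model.gnn(xs, A, layer=3)                   # (1, dim) graph embedding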
Ejemplo n.º 55
0
    def forward(self, x, w):
        return torch.matmul(x, w)
Ejemplo n.º 56
0
    def sample_single_animal_x(self, z, x_pre, animal_idx, with_noise,
                               **kwargs):
        """

        :param z: a scalar
        :param x_pre: (1, 2)
        :param animal_idx
        :param with_noise:
        :return: (2, )
        """
        assert x_pre.shape == (1, 2), x_pre.shape

        A = kwargs.get("A_a", None) if animal_idx == 0 else kwargs.get(
            "A_b", None)
        Sigma = kwargs.get("Sigma_a", None) if animal_idx == 0 else kwargs.get(
            "Sigma_b", None)

        if A is None:
            #print("Not using cache. Calculating Sigma, A...")
            Sigma, A = self.get_gp_cache_condition_on_z(x_pre,
                                                        z,
                                                        animal_idx,
                                                        A_only=not with_noise,
                                                        **kwargs)

        assert A.shape == (1, 2, self.n_gps * 2)
        A = torch.squeeze(A, dim=0)

        u = self.us[z, :, 0:2] if animal_idx == 0 else self.us[z, :, 2:4]
        assert u.shape == (self.n_gps, 2), \
            "the correct shape is {}, instead we got {}".format((self.n_gps, 2), u.shape)
        u = torch.reshape(u, (-1, ))  # (n_gps*2, )

        # (2, n_gps*2) * (n_gps*2, 1) ->  (2, 1)
        mu = torch.matmul(A, u[..., None])
        assert mu.shape == (2, 1), mu.shape
        mu = torch.squeeze(mu, dim=-1)  # (2, )
        mu = mu + x_pre[0]
        assert mu.shape == (2, )

        if not with_noise:
            return mu

        assert Sigma.shape == (1, 2, 2)
        Sigma = torch.squeeze(Sigma, dim=0)

        # (2,)
        sigma = torch.exp(
            self.log_sigmas[z, 0:2]) if animal_idx == 0 else torch.exp(
                self.log_sigmas[z, 2:4])
        assert sigma.shape == (2, ), sigma.shape
        sigma = torch.diag(sigma)
        assert sigma.shape == (2, 2), sigma.shape

        cov = Sigma + sigma

        m = MultivariateNormal(mu, cov)
        sample = m.sample()
        assert sample.shape == (2, )

        return sample
Ejemplo n.º 57
0
    def forward(self, x):
        B, C, H, W = x.size()

        x_plus_pos = torch.cat([x, self.pos_emb.repeat(B, 1, 1, 1)], dim=1)
        K = self.K_nn(x_plus_pos).view(B, self.n_heads, self.dim_per_head, H,
                                       W)
        Q = self.Q_nn(x_plus_pos).view(B, self.n_heads, self.dim_per_head, H,
                                       W)
        V = self.V_nn(x_plus_pos).view(B, self.n_heads, self.dim_per_head, H,
                                       W)

        QK_W = torch.matmul(Q.permute(0, 1, 3, 4, 2), K.permute(0, 1, 3, 2, 4))
        # QK_W_old = torch.sum(
        #     Q.view(B, self.n_heads, self.dim_per_head, H, W, 1) *
        #     K.view(B, self.n_heads, self.dim_per_head, H, 1, W), dim=2
        # )
        # assert QK_W.size() == (B, self.n_heads, H, W, W)

        QK_H = torch.matmul(
            Q.permute(0, 1, 4, 3, 2),
            K.permute(0, 1, 4, 2, 3),
        ).permute(0, 1, 3, 2, 4)

        # QK_H_old = torch.sum(
        #     Q.view(B, self.n_heads, self.dim_per_head, H, 1, W) *
        #     K.view(B, self.n_heads, self.dim_per_head, 1, H, W), dim=2
        # ).permute(0, 1, 2, 4, 3)
        # print(torch.mean(torch.abs(QK_H - QK_H_old)))
        # assert QK_H.size() == (B, self.n_heads, H, W, H)

        # Shape: [B, heads, H, W, W+H]
        QK = torch.cat([QK_W, QK_H], dim=-1)
        # assert QK.size() == (B, self.n_heads, H, W, W + H)

        QK = F.softmax(QK / np.sqrt(W + H), dim=-1)

        QK_W = QK[..., :W]
        # assert QK_W.size() == (B, self.n_heads, H, W, W)

        QK_H = QK[..., W:]
        # assert QK_H.size() == (B, self.n_heads, H, W, H)

        # V_W = V.view(B, self.n_heads, self.dim_per_head, H, 1, W)
        # V_H = V.permute(0, 1, 2, 4, 3).view(B, self.n_heads, self.dim_per_head, 1, W, H)
        # A_old = torch.sum(
        #     QK_W.view(B, self.n_heads, 1, H, W, W) * V_W, dim=-1) \
        #     + torch.sum(
        #     QK_H.view(B, self.n_heads, 1, H, W, H) * V_H, dim=-1
        # )

        A_W = torch.matmul(QK_W, V.permute(0, 1, 3, 4,
                                           2)).permute(0, 1, 4, 2, 3)
        QK_H = QK_H.permute(0, 1, 3, 2, 4)
        V_H = V.permute(0, 1, 4, 3, 2)
        A_H = torch.matmul(QK_H, V_H).permute(0, 1, 4, 3, 2)

        A = (A_W + A_H).contiguous()

        # assert A.size() == (B, self.n_heads, self.dim_per_head, H, W)

        A = A.view(B, self.n_heads * self.dim_per_head, H, W)

        return x + self.linear(A)
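The two matmul branches implement axial attention: each position attends only along its row (W keys) and its column (H keys), so the score tensor has H·W·(W+H) entries per head instead of (H·W)^2 for full attention. The permute bookkeeping can be cross-checked with einsum; a sketch for the row branch, with Q, K of shape (B, heads, d, H, W):

import torch

B, n_heads, d, H, W = 2, 4, 8, 6, 5
Q = torch.randn(B, n_heads, d, H, W)
K = torch.randn(B, n_heads, d, H, W)

QK_W = torch.matmul(Q.permute(0, 1, 3, 4, 2), K.permute(0, 1, 3, 2, 4))
QK_W_ref = torch.einsum('bndhw,bndhv->bnhwv', Q, K)  # row-wise scores
assert torch.allclose(QK_W, QK_W_ref, atol=1e-5)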
Ejemplo n.º 58
0
    def generate_rollout(self, input, method_name="rollout"):
        output = self.model_usage.forward(input).question_answering_score
        model = self.model_usage.model

        # initialize relevancy matrices
        text_tokens = self.model_usage.text_len
        image_bboxes = self.model_usage.image_boxes_len

        # text self attention matrix
        self.R_t_t = torch.eye(text_tokens, text_tokens).to(model.device)
        # image self attention matrix
        self.R_i_i = torch.eye(image_bboxes, image_bboxes).to(model.device)
        # impact of images on text
        self.R_t_i = torch.zeros(text_tokens, image_bboxes).to(model.device)
        # impact of text on images
        self.R_i_t = torch.zeros(image_bboxes, text_tokens).to(model.device)

        cams_text = []
        cams_image = []
        # language self attention
        blocks = model.lxmert.encoder.layer
        for blk in blocks:
            cam = blk.attention.self.get_attn().detach()
            cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
            cams_text.append(cam)

        # image self attention
        blocks = model.lxmert.encoder.r_layers
        for blk in blocks:
            cam = blk.attention.self.get_attn().detach()
            cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
            cams_image.append(cam)

        # cross attn layers
        blocks = model.lxmert.encoder.x_layers
        for i, blk in enumerate(blocks):
            # in the last cross attention module, only the text cross modal
            # attention has an impact on the CLS token, since it's the first
            # token in the language tokens
            if i == len(blocks) - 1:
                break

            # language self attention
            cam = blk.lang_self_att.self.get_attn().detach()
            cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
            cams_text.append(cam)

            # image self attention
            cam = blk.visn_self_att.self.get_attn().detach()
            cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
            cams_image.append(cam)

        # take care of last cross attention layer- only text
        blk = model.lxmert.encoder.x_layers[-1]
        # cross attn cam will be the one used for the R_t_i matrix
        cam_t_i = blk.visual_attention.att.get_attn().detach()
        cam_t_i = cam_t_i.reshape(-1, cam_t_i.shape[-2],
                                  cam_t_i.shape[-1]).mean(dim=0)
        self.R_t_t = compute_rollout_attention(copy.deepcopy(cams_text))
        self.R_i_i = compute_rollout_attention(cams_image)
        self.R_t_i = torch.matmul(self.R_t_t.t(),
                                  torch.matmul(cam_t_i, self.R_i_i))
        # language self attention
        cam = blk.lang_self_att.self.get_attn().detach()
        cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1]).mean(dim=0)
        cams_text.append(cam)

        self.R_t_t = compute_rollout_attention(cams_text)

        # disregard the [CLS] token itself
        self.R_t_t[0, 0] = 0
        return self.R_t_t, self.R_t_i
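compute_rollout_attention is defined elsewhere; below is a minimal sketch of the standard attention-rollout recursion (Abnar & Zuidema, 2020) that it is assumed to implement, with residual connections folded in and rows re-normalized:

import torch

def compute_rollout_attention(cams):
    # cams: list of (tokens, tokens) attention maps, one per layer
    num_tokens = cams[0].shape[-1]
    eye = torch.eye(num_tokens, device=cams[0].device)
    rollout = eye
    for cam in cams:
        cam = cam + eye                            # account for residual connections
        cam = cam / cam.sum(dim=-1, keepdim=True)  # re-normalize rows
        rollout = torch.matmul(cam, rollout)
    return rollout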
Ejemplo n.º 59
0
    def forward(self, context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, context_lens, start_mapping, end_mapping, all_mapping, return_yp=False):
        para_size, ques_size, char_size, bsz = context_idxs.size(1), ques_idxs.size(1), context_char_idxs.size(2), context_idxs.size(0)

        context_mask = (context_idxs > 0).float()
        ques_mask = (ques_idxs > 0).float()

        context_ch = self.char_emb(context_char_idxs.contiguous().view(-1, char_size)).view(bsz * para_size, char_size, -1)
        ques_ch = self.char_emb(ques_char_idxs.contiguous().view(-1, char_size)).view(bsz * ques_size, char_size, -1)

        context_ch = self.char_cnn(context_ch.permute(0, 2, 1).contiguous()).max(dim=-1)[0].view(bsz, para_size, -1)
        ques_ch = self.char_cnn(ques_ch.permute(0, 2, 1).contiguous()).max(dim=-1)[0].view(bsz, ques_size, -1)

        context_word = self.word_emb(context_idxs)
        ques_word = self.word_emb(ques_idxs)

#        logging("Start",self.config.save)
#        logging("context word "+str(context_word.shape),self.config.save)
#        logging("context ch"+str(context_ch.shape),self.config.save)
#        logging("question word"+str(ques_word.shape),self.config.save)
#        logging("question ch"+str(context_word.shape),self.config.save)
        
        context_output = torch.cat([context_word, context_ch], dim=2)
        ques_output = torch.cat([ques_word, ques_ch], dim=2)

#        logging("context output"+str(context_output.shape),self.config.save)
#        logging("ques output"+str(ques_output.shape),self.config.save)
#        
        context_output = self.rnn(context_output, context_lens)
        ques_output = self.rnn(ques_output)

#        logging("context output after RNN"+str(context_output.shape),self.config.save)
#        logging("ques output after RNN"+str(ques_output.shape),self.config.save)
#        
        output = self.qc_att(context_output, ques_output, ques_mask)
#        logging("attension output"+str(output.shape),self.config.save)
        output = self.linear_1(output)
#        logging("attension linear output"+str(output.shape),self.config.save)
        
        # Sentence branch
        sentence_embeddings , sentence_mask = self.mean_pooling_module(all_mapping ,output ,self.config.save)
#        logging("sentence embeddings "+str(sentence_embeddings.shape),self.config.save)
        
#        sentence_rnn_out = self.rnn_sentence(sentence_embeddings)
#        logging("sentence rnn output "+str(sentence_rnn_out.shape),self.config.save)
#        logging("sentece mask "+str(sentence_mask.shape),self.config.save)
        
        sentence_self_attension_out = self.self_att_sentences(sentence_embeddings,sentence_embeddings,sentence_mask)
#        logging("sentence self output "+str(sentence_self_attension_out.shape),self.config.save)
#        self.self_att_sentences
        
        sentence_embeddings = self.linear_sent_att(sentence_self_attension_out)
        
        sentence_embeddings = self.rnn_sentence(sentence_embeddings)
        
        # end
        output_t = self.rnn_2(output, context_lens)
#        logging("context lens "+str(context_lens.shape),self.config.save)
#        logging("output size " + str(output.shape),self.config.save)
#        logging("output 2nd RNN"+str(output_t.shape),self.config.save)
        output_t = self.self_att(output_t, output_t, context_mask)
#        logging("context_mask "+str(context_mask.shape),self.config.save)
#        logging("mask example 1 "+str(torch.sum(context_mask[0,:])),self.config.save)
#        logging("mask example 2 "+str(torch.sum(context_mask[1,:])),self.config.save)
#        logging("mask example 3 "+str(torch.sum(context_mask[2,:])),self.config.save)
#        logging("mask example 4 "+str(torch.sum(context_mask[3,:])),self.config.save)

#        logging("self attension output"+str(output_t.shape),self.config.save)
        output_t = self.linear_2(output_t)
#        logging("linear after self attension"+str(output_t.shape),self.config.save)
        
        output = output + output_t
#        logging("sum bi-attension and self-attension"+str(output.shape),self.config.save)
        sp_output = self.rnn_sp(output, context_lens)
#        logging("supporting fact RNN"+str(output.shape),self.config.save)
        
        start_output = torch.matmul(start_mapping.permute(0, 2, 1).contiguous(), sp_output[:,:,self.hidden:])
#        logging("start output sp"+str(start_output.shape),self.config.save)
        end_output = torch.matmul(end_mapping.permute(0, 2, 1).contiguous(), sp_output[:,:,:self.hidden])
#        logging("end output sp"+str(end_output.shape),self.config.save)
        sp_output = torch.cat([start_output, end_output], dim=-1)
#        logging("sp output "+str(sp_output.shape),self.config.save)
        
        sp_output = torch.cat([sp_output, sentence_embeddings], dim=-1)
#        logging("sp output sentence "+str(sp_output_sentence.shape),self.config.save)
        
        sp_output_t = self.linear_sp(sp_output)
#        logging("sp output after linear"+str(sp_output.shape),self.config.save)
        sp_output_aux = Variable(sp_output_t.data.new(sp_output_t.size(0), sp_output_t.size(1), 1).zero_())
#        logging("sp output aux"+str(sp_output_aux.shape),self.config.save)
       
        predict_support = torch.cat([sp_output_aux, sp_output_t], dim=-1).contiguous()
#        logging("predict_support"+str(predict_support.shape),self.config.save)
        
        return predict_support
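start_mapping and end_mapping act as (batch, seq_len, n_sent) indicator matrices, so the transposed matmul gathers each sentence's boundary token states. A toy illustration with made-up sizes:

import torch

seq_len, n_sent, hidden = 6, 2, 4
states = torch.randn(1, seq_len, hidden)
start_mapping = torch.zeros(1, seq_len, n_sent)
start_mapping[0, 0, 0] = 1.0   # sentence 0 starts at token 0
start_mapping[0, 3, 1] = 1.0   # sentence 1 starts at token 3
gathered = torch.matmul(start_mapping.permute(0, 2, 1), states)  # (1, n_sent, hidden)
assert torch.allclose(gathered[0, 1], states[0, 3])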
Ejemplo n.º 60
0
def complex_matmul(A_real, A_imag, B_real, B_imag):
    real = torch.matmul(A_real, B_real) - torch.matmul(A_imag, B_imag)
    imag = torch.matmul(A_real, B_imag) + torch.matmul(A_imag, B_real)
    return (real, imag)
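complex_matmul is the real/imaginary expansion of (A_r + i A_i)(B_r + i B_i); it can be sanity-checked against PyTorch's native complex tensors:

import torch

A_real, A_imag = torch.randn(3, 4), torch.randn(3, 4)
B_real, B_imag = torch.randn(4, 5), torch.randn(4, 5)

real, imag = complex_matmul(A_real, A_imag, B_real, B_imag)
ref = torch.matmul(torch.complex(A_real, A_imag), torch.complex(B_real, B_imag))
assert torch.allclose(real, ref.real, atol=1e-5)
assert torch.allclose(imag, ref.imag, atol=1e-5)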