def get_loss_original(self, image_a_pred, image_b_pred, matches_a,
                          matches_b, non_matches_a, non_matches_b,
                          M_margin=0.5, non_match_loss_weight=1.0):

        # this is pegged to its implementation at sha 87abdb63bb5b99d9632f5c4360b5f6f1cf54245f
        # (keep the code below unchanged so results stay comparable with that version)
        """
        Computes the original pixelwise contrastive loss of a Dense Correspondence Network.

        DCN = Dense Correspondence Network
        num_matches = number of matches
        num_non_matches = number of non-matches
        W = image width
        H = image height
        D = descriptor dimension

        match_loss = (1/num_matches) * sum over matches of ||descriptor_a - descriptor_b||_2^2
        non_match_loss = non_match_loss_weight * (1/num_non_matches) * sum over non-matches of
                         max(0, M_margin - ||descriptor_a - descriptor_b||_2^2)
        loss = match_loss + non_match_loss

        Note that the margin is compared against the *squared* L2 distance.

        :param image_a_pred: Output of DCN network on image A.
        :type image_a_pred: torch.Variable(torch.FloatTensor) shape [1, W * H, D]
        :param image_b_pred: same as image_a_pred, for image B
        :type image_b_pred: torch.Variable(torch.FloatTensor) shape [1, W * H, D]
        :param matches_a: flat pixel indices into dim 1 of image_a_pred, shape [num_matches,];
            a (u,v) pair is mapped to (u,v) ---> image_width * v + u
        :type matches_a: torch.Variable(torch.LongTensor)
        :param matches_b: same as matches_a, but indexing image_b_pred; pairs are
            compared elementwise, so it must line up with matches_a
        :type matches_b: torch.Variable(torch.LongTensor)
        :param non_matches_a: flat pixel indices into dim 1 of image_a_pred, shape
            [num_non_matches,], using the same image_width * v + u encoding
        :type non_matches_a: torch.Variable(torch.LongTensor)
        :param non_matches_b: same as non_matches_a, but indexing image_b_pred
        :type non_matches_b: torch.Variable(torch.LongTensor)
        :param M_margin: margin applied to the squared descriptor distance of non-matches
        :param non_match_loss_weight: multiplier applied to the non-match term
        :return: loss, match_loss, non_match_loss
        :rtype: tuple of scalar tensors (0-dim with modern PyTorch's ``.sum()``; older
            docs for this function described them as torch.Size([1]))
        :raises ZeroDivisionError: if matches_a or non_matches_a is empty
        """

        num_matches = matches_a.size()[0]
        num_non_matches = non_matches_a.size()[0]


        # Gather descriptors at the matched pixels (dim 1 is the flattened pixel axis).
        matches_a_descriptors = torch.index_select(image_a_pred, 1, matches_a)
        matches_b_descriptors = torch.index_select(image_b_pred, 1, matches_b)

        # Sum of squared differences over every matched pixel and descriptor channel,
        # averaged over the number of matches.
        match_loss = 1.0/num_matches * (matches_a_descriptors - matches_b_descriptors).pow(2).sum()

        # add loss via non_matches
        non_matches_a_descriptors = torch.index_select(image_a_pred, 1, non_matches_a)
        non_matches_b_descriptors = torch.index_select(image_b_pred, 1, non_matches_b)
        # Squared L2 distance per non-match pair (summed over the descriptor dim 2).
        pixel_wise_loss = (non_matches_a_descriptors - non_matches_b_descriptors).pow(2).sum(dim=2)
        # Hinge: M_margin - distance, clipped below at zero by the torch.max below.
        pixel_wise_loss = torch.add(torch.neg(pixel_wise_loss), M_margin)
        zeros_vec = torch.zeros_like(pixel_wise_loss)
        non_match_loss = non_match_loss_weight * 1.0/num_non_matches * torch.max(zeros_vec, pixel_wise_loss).sum()

        loss = match_loss + non_match_loss

        return loss, match_loss, non_match_loss
    def forward(self, s1, s2, s1_hidden, s2_hidden):
        """Score a sentence pair with a shared LSTM encoder plus co-attention.

        Both sentences go through ``self.embedding``, the same ``self.lstm1`` and
        ``self.layer_norm``.  A co-attention matrix is built by ``self.CosineSim``
        over every (s1 position, s2 position) pair and used to attend each sentence
        over the other.  The pair feature is
        ``[s1_representation + s2_representation ; (s1_representation - s2_representation) ** 2]``,
        which goes through ``self.mlp1`` + ReLU, ``self.output`` and a sigmoid.

        :param s1: first sentence input, passed to ``self.embedding``.
        :param s2: second sentence input, passed to ``self.embedding``.
        :param s1_hidden: initial state for ``self.lstm1`` on s1.
        :param s2_hidden: initial state for ``self.lstm1`` on s2.
        :return: sigmoid of ``self.output(...)``.

        NOTE(review): as written this raises ``NameError`` at the merge step:
        ``s1_representation`` / ``s2_representation`` are only assigned inside the
        commented-out alternatives below, so one of them must be re-enabled before
        this runs.  ``s1_cnn``/``s2_cnn`` and ``s1_attentive``/``s2_attentive`` are
        likewise computed but only consumed by commented-out code.
        """
        s1_embeded = self.embedding(s1)
        s2_embeded = self.embedding(s2)

        # CNN features; only referenced by commented-out code below.
        s1_cnn, s2_cnn = self.cnn(s1_embeded, s2_embeded)

        # The same LSTM (lstm1) encodes both sentences.
        s1_output, s1_hidden = self.lstm1(s1_embeded, s1_hidden)
        s2_output, s2_hidden = self.lstm1(s2_embeded, s2_hidden)
        s1_output = self.layer_norm(s1_output)
        s2_output = self.layer_norm(s2_output)
        # Co-attention weight computation.
        '''co-attention 权重计算方式'''
        # Broadcast both outputs to (s1.size(0), s1.size(1), s2.size(1), hidden) so every
        # s1 position is paired with every s2 position.
        s1_expanded = s1_output.unsqueeze(2).expand(s1_output.size()[0],
                                                    s1_output.size()[1],
                                                    s2_output.size()[1],
                                                    s1_output.size()[2])
        s2_expanded = s2_output.unsqueeze(1).expand(s2_output.size()[0],
                                                    s1_output.size()[1],
                                                    s2_output.size()[1],
                                                    s2_output.size()[2])

        # presumably CosineSim reduces the last (hidden) dim -> (batch, len1, len2); TODO confirm
        cosine_sim = self.CosineSim(s1_expanded, s2_expanded)
        # s1 attends over s2 and, via the transposed matrix, s2 attends over s1.
        s1_attentive = torch.matmul(self.softmax(cosine_sim), s2_output)
        s2_attentive = torch.matmul(self.softmax(cosine_sim.transpose(1, 2)),
                                    s1_output)
        s1_attentive = self.layer_norm(s1_attentive)
        s2_attentive = self.layer_norm(s2_attentive)

        # s1_concat = torch.cat((s1_output, s1_attentive), 2)
        # s2_concat = torch.cat((s2_output, s2_attentive), 2)

        # s1_representation, _ = torch.max(s1_concat, 1)
        # s2_representation, _ = torch.max(s2_concat, 1)

        # s1_representation_mean = torch.mean(s1_concat, 1, keepdim=False)
        # s2_representation_mean = torch.mean(s2_concat, 1, keepdim=False)

        #s1_representation = torch.cat((s1_representation, s1_cnn), 1)
        #s2_representation = torch.cat((s2_representation, s2_cnn), 1)
        #s1_representation = torch.cat((s1_representation_max, s1_representation_mean), 1)
        #s2_representation = torch.cat((s2_representation_max, s2_representation_mean), 1)
        # Alternative: attention weights between each s1 hidden state and the mean s2 hidden state.
        '''计算s1的每个hidden和s2的hidden_mean之间的注意力权重'''
        # s2_hidden_mean = torch.mean(s2_output, 1, keepdim=True)
        # s2_hidden_mean = torch.transpose(s2_hidden_mean, 1, 2).contiguous()
        # s1_attention_weight = torch.matmul(s1_output, s2_hidden_mean)
        # s1_attention_weight = self.softmax(s1_attention_weight)
        # s1_attention_weight = torch.transpose(s1_attention_weight, 1, 2).contiguous()
        # s1_representation = torch.matmul(s1_attention_weight, s1_output).squeeze()
        # Alternative: attention weights between each s2 hidden state and the mean s1 hidden state.
        '''计算s2的每个hidden和s1的hidden_mean之间的注意力权重'''
        # s1_hidden_mean = torch.mean(s1_output, 1, keepdim=True)
        # s1_hidden_mean = torch.transpose(s1_hidden_mean, 1, 2).contiguous()
        # s2_attention_weight = torch.matmul(s2_output, s1_hidden_mean)
        # s2_attention_weight = self.softmax(s2_attention_weight)
        # s2_attention_weight = torch.transpose(s2_attention_weight, 1, 2).contiguous()
        # s2_representation = torch.matmul(s2_attention_weight, s2_output).squeeze()
        # Alternative: attention weights between each s1 hidden state and the last s2 hidden state.
        '''计算s1的每个hidden和s2的最后一个hidden之间的注意力权重'''
        # s2_last_hidden = s2_hidden[0]
        # s2_last_hidden = torch.transpose(s2_last_hidden, 0, 1).contiguous()  # Note: always call contiguous() after transpose
        # s2_last_hidden_size = s2_last_hidden.size()
        # s2_last_hidden = s2_last_hidden.view(s2_last_hidden_size[0], -1, 1)
        # s1_attention_weight = torch.matmul(s1_output, s2_last_hidden)
        # s1_attention_weight = self.softmax(s1_attention_weight)
        # s1_attention_weight = torch.transpose(s1_attention_weight, 1, 2).contiguous()
        # s1_representation = torch.matmul(s1_attention_weight, s1_output).squeeze()
        # Alternative: attention weights between each s2 hidden state and the last s1 hidden state.
        '''计算s2的每个hidden和s1的最后一个hidden之间的注意力权重'''
        # s1_last_hidden = s1_hidden[0]
        # s1_last_hidden = torch.transpose(s1_last_hidden, 0, 1).contiguous()
        # s1_last_hidden_size = s1_last_hidden.size()
        # s1_last_hidden = s1_last_hidden.view(s1_last_hidden_size[0], -1, 1)
        # s2_attention_weight = torch.matmul(s2_output, s1_last_hidden)
        # s2_attention_weight = self.softmax(s2_attention_weight)
        # s2_attention_weight = torch.transpose(s2_attention_weight, 1, 2).contiguous()
        # s2_representation = torch.matmul(s2_attention_weight, s2_output).squeeze()
        # Alternative: self-attention over s1.
        '''计算s1的self-attention'''
        # s1_H = s1_output
        # s1_h1 = self.S1(s1_H)
        # s1_h1 = self.tanh(s1_h1)
        # s1_h2 = self.S2(s1_h1)
        # attention_weight = self.softmax(s1_h2)
        # attention_weight = torch.transpose(attention_weight, 1, 2)
        # s1_representation = torch.matmul(attention_weight, s1_H).squeeze()
        # Alternative: self-attention over s2.
        '''计算s2的self-attention'''
        # s2_H = s2_output
        # s2_h1 = self.S1(s2_H)
        # s2_h1 = self.tanh(s2_h1)
        # s2_h2 = self.S2(s2_h1)
        # attention_weight = self.softmax(s2_h2)
        # attention_weight = torch.transpose(attention_weight, 1, 2)
        # s2_representation = torch.matmul(attention_weight, s2_H).squeeze()
        '''merged = [ s1 + s2 ; pow((s1 - s2), 2) ]'''
        # NOTE(review): s1_representation / s2_representation are undefined here
        # (every assignment above is commented out) -> NameError.
        merge_add = torch.add(s1_representation, s2_representation)
        neg_s2 = torch.neg(s2_representation)
        merge_minus = torch.add(s1_representation, neg_s2)
        merge_minus = torch.pow(merge_minus, 2)
        merged = torch.cat((merge_add, merge_minus), 1)
        '''merged = [ s1 ; s2 ; s1 + s2 ; s1 - s2 ; |s1 - s2| ]'''
        # merge_add = torch.add(s1_representation, s2_representation)
        # neg_s2 = torch.neg(s2_representation)
        # merge_minus = torch.add(s1_representation, neg_s2)
        # merge_abs = torch.abs(merge_minus)
        # merged = torch.cat((s1_representation, s2_representation, merge_add, merge_minus, merge_abs), 1)

        if self.use_cuda:
            merged = merged.cuda()

        # merged = self.dropout(merged)
        output = self.relu(self.mlp1(merged))
        # output_mlp2 = self.relu(self.mlp2(output_mlp1))
        # output_mlp = self.dropout(output_mlp)

        output = self.output(output)
        output = self.sigmoid(output)
        return output
Exemple #3
0
 def forward(self, x):
     """Return the elementwise negation of ``x`` computed with ``torch.neg``."""
     negated = torch.neg(x)
     return negated
Exemple #4
0
 def pattern(x):
     """Return ``torch.neg(x) + torch.relu(x)``; neg is evaluated before relu."""
     # NOTE(review): looks like a torch.fx pattern; keep the explicit torch.* calls.
     negated = torch.neg(x)
     rectified = torch.relu(x)
     return negated + rectified
Exemple #5
0
 def pattern(x):
     """Return ``torch.neg(x)`` (a single neg call)."""
     result = torch.neg(x)
     return result
Exemple #6
0
 def comparison(x):
     """Return ``torch.neg(sigmoid(x)) - sigmoid(x)``, i.e. ``-2 * sigmoid(x)``."""
     squashed = torch.sigmoid(x)
     flipped = torch.neg(squashed)
     return flipped - squashed
Exemple #7
0
 def pattern(x):
     """Negate ``x`` once and add the result to itself with ``torch.add``."""
     negated = torch.neg(x)
     doubled = torch.add(negated, negated)
     return doubled
Exemple #8
0
 def forward(self, x):
     """Return ``torch.neg(x) + torch.relu(x)`` (neg evaluated first)."""
     negated = torch.neg(x)
     rectified = torch.relu(x)
     return negated + rectified
Exemple #9
0
def neg(*, input):
    """Keyword-only wrapper around ``torch.neg``: returns ``-input`` elementwise."""
    return torch.neg(input=input)
Exemple #10
0
	def pairwise_word_interaction(self,out0,out1, target_A, target_B):
		"""Build the 13-channel word-by-word similarity cube for two sentences.

		``out0`` / ``out1`` are flattened with ``view(size(0), size(2))``, presumably
		(seq_len, 1, 2 * hidden_dim) bi-RNN outputs -- TODO confirm against the caller.
		Each is split by ``self.unpack(..., half_dim=self.hidden_dim)`` into forward and
		backward halves.  For each of four views (forward, backward, full bidirectional,
		forward + backward) and every position pair (i, j) three similarities are
		computed: dot product, negated ``F.pairwise_distance`` and
		``F.cosine_similarity``.  A constant all-ones channel is appended.

		:param target_A: unused.
		:param target_B: unused.
		:return: ``(simCube, extra_loss)``; ``simCube`` has shape (13, len0, len1) with
			channel order [cos, -dist, dot] for fw, bw, sum, bi, then ones;
			``extra_loss`` is always 0.
		"""
		extra_loss = 0
		h_fw_0, h_bw_0 = self.unpack(out0.view(out0.size(0), out0.size(2)), half_dim=self.hidden_dim)
		h_fw_1, h_bw_1 = self.unpack(out1.view(out1.size(0), out1.size(2)), half_dim=self.hidden_dim)
		h_bi_0 = out0.view(out0.size(0), out0.size(2))
		h_bi_1 = out1.view(out1.size(0), out1.size(2))
		h_sum_0 = h_fw_0 + h_bw_0
		h_sum_1 = h_fw_1 + h_bw_1
		len0 = h_fw_0.size(0)
		len1 = h_fw_1.size(0)

		def _pair_rows(a, b):
			# Row k = i * len1 + j holds a[i] and b[j] -- the same row order the old
			# nested loop produced, but built in one step instead of O(len0 * len1)
			# re-allocating torch.cat calls.
			dim = a.size(1)
			rows_a = a.unsqueeze(1).expand(len0, len1, dim).reshape(len0 * len1, dim)
			rows_b = b.unsqueeze(0).expand(len0, len1, b.size(1)).reshape(len0 * len1, b.size(1))
			return rows_a, rows_b

		def _dot(a, b):
			# (1, len0, len1) matrix of dot products a[i] . b[j].
			return torch.unsqueeze(torch.mm(a, torch.transpose(b, 0, 1)), 0)

		def _neg_dist_and_cos(a, b):
			# Negated L2 distance and cosine similarity for every (i, j), each (1, len0, len1).
			rows_a, rows_b = _pair_rows(a, b)
			neg_dist = torch.neg(F.pairwise_distance(rows_a, rows_b)).view(1, len0, len1)
			cos = F.cosine_similarity(rows_a, rows_b).view(1, len0, len1)
			return neg_dist, cos

		simCube1 = _dot(h_fw_0, h_fw_1)
		simCube2 = _dot(h_bw_0, h_bw_1)
		simCube3 = _dot(h_bi_0, h_bi_1)
		simCube4 = _dot(h_sum_0, h_sum_1)
		simCube5, simCube9 = _neg_dist_and_cos(h_fw_0, h_fw_1)
		simCube6, simCube10 = _neg_dist_and_cos(h_bw_0, h_bw_1)
		simCube7, simCube11 = _neg_dist_and_cos(h_bi_0, h_bi_1)
		simCube8, simCube12 = _neg_dist_and_cos(h_sum_0, h_sum_1)
		# Constant all-ones channel, created on the inputs' own device instead of guessing
		# from torch.cuda.is_available() (which mismatched CPU inputs on a CUDA host).
		simCube13 = h_bi_0.new_ones((1, len0, len1))
		simCube = torch.cat((simCube9, simCube5, simCube1, simCube10, simCube6, simCube2,
							 simCube12, simCube8, simCube4, simCube11, simCube7, simCube3,
							 simCube13), 0)
		return simCube, extra_loss
Exemple #11
0
 def pattern(a):
     """Return ``torch.neg(a)``."""
     negated = torch.neg(a)
     return negated
Exemple #12
0
 def forward(x):
     """Negate ``x`` once and return ``torch.add`` of that result with itself."""
     negated = torch.neg(x)
     doubled = torch.add(negated, negated)
     return doubled
Exemple #13
0
 def forward(self, output, target):
     """Weighted binary cross-entropy averaged over all elements.

     Computes ``-mean(wts[0] * t * log(p) + wts[1] * (1 - t) * log(1 - p))`` with
     ``p = output`` and ``t = target``.

     Both log terms are clamped to be >= -100, the safeguard ``torch.nn.BCELoss``
     uses, so predictions of exactly 0 or 1 give a finite loss.  Without it,
     ``log(0)`` is ``-inf`` and ``0 * -inf`` turns into ``nan``.

     :param output: predicted probabilities, expected in [0, 1].
     :param target: binary targets (converted with ``.float()``).
     :return: scalar tensor, the negated weighted mean log-likelihood.
     """
     log_p = torch.clamp(torch.log(output), min=-100).float()
     log_not_p = torch.clamp(torch.log(1 - output), min=-100).float()
     loss = self.wts[0] * (target.float() * log_p) + self.wts[1] * (
         (1 - target).float() * log_not_p)
     return torch.neg(torch.mean(loss))
Exemple #14
0
def generate(
    model,
    inputs=None,
    task_ids=None,
    tokens_to_generate=0,
    all_probs=False,
    temperature=1.0,
    add_BOS=False,
    top_k=0,
    top_p=0.0,
    greedy=False,
    repetition_penalty=1.0,
    min_tokens_to_generate=0,
) -> OutputType:
    """
    Generate text with ``model`` while keeping all distributed ranks in sync.

    Rank 0 prepares the context tensors, tokenizing ``inputs`` unless it is already a
    tensor tuple.  It then broadcasts them with the sampling settings via
    ``send_generate_info``.  Every other rank ignores its own arguments and takes the
    settings from ``receive_generate_info``.  All ranks then call ``synced_generate``.

    Args:
        model (NLPModel): text generative model
        inputs (Union[tuple, List[str]]): if it is a tuple, it is assumed to be (context_tokens_tensor, context_length_tensor). Otherwise it is a list of prompt text strings
        task_ids (Tensor): used to specify the task when generating with p-tuned/prompt-tuned models (optional, default=None; a dummy tensor of -1s is used when omitted)
        tokens_to_generate (int): The maximum number of tokens to be generated.
        all_probs (bool): Return the log prob for all the tokens
        temperature (float): sampling temperature
        add_BOS (bool): add the bos token at the beginning of the prompt
        top_k (int): The number of highest probability vocabulary tokens to keep for top-k-filtering.
        top_p (float): If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
        greedy (bool): If True use greedy decoding, otherwise sample (implemented by synced_generate)
        repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty
        min_tokens_to_generate (int): The minimum number of tokens to be generated
    Returns:
        OutputType: a dictionary with the keys below, or None when ``synced_generate``
        returns None:
            sentences: List[str], output sentences
            tokens: List[List[str]], output sentences broken into tokens
            logprob: List[Tensor], log prob of generated tokens
            full_logprob: List[Tensor], log prob of all the tokens in the vocab
            token_ids: List[List[int]], output sentence token ids
            offsets: List[List[int]], start position of each token (cumulative length of
                the preceding tokens)
    """
    model.eval()
    tokenizer = model.tokenizer
    if torch.distributed.get_rank() == 0:
        if isinstance(inputs, tuple):
            context_tokens_tensor, context_length_tensor = inputs
        else:
            context_tokens_tensor, context_length_tensor = tokenize_batch(
                tokenizer, inputs, tokens_to_generate, add_BOS)
        if task_ids is None:
            # Make a dummy tensor of -1s that won't be used during generation.
            # Use ``.device`` (not ``get_device()``, which returns -1 for CPU tensors
            # and is rejected by ``Tensor.to``).
            task_ids = torch.full(
                (context_tokens_tensor.size(0),), -1,
                dtype=torch.int64, device=context_tokens_tensor.device)

        send_generate_info(
            context_tokens_tensor,
            context_length_tensor,
            task_ids,
            tokens_to_generate,
            all_probs,
            temperature,
            top_k,
            top_p,
            greedy,
            repetition_penalty,
            min_tokens_to_generate,
        )
    else:
        # NOTE(review): this unpacks the length tensor before the tokens tensor, the
        # reverse of the send order above; it must match receive_generate_info's
        # return order -- confirm.
        (
            context_length_tensor,
            context_tokens_tensor,
            task_ids,
            tokens_to_generate,
            all_probs,
            temperature,
            top_k,
            top_p,
            greedy,
            repetition_penalty,
            min_tokens_to_generate,
        ) = receive_generate_info()

    output = synced_generate(
        model,
        context_tokens_tensor,
        context_length_tensor,
        task_ids,
        tokens_to_generate,
        all_probs,
        temperature,
        top_k=top_k,
        top_p=top_p,
        greedy=greedy,
        repetition_penalty=repetition_penalty,
        min_tokens_to_generate=min_tokens_to_generate,
    )
    if output is None:
        return None

    decode_tokens, output_logits, full_logits = output
    resp_sentences = []
    resp_sentences_seg = []

    decode_tokens = decode_tokens.cpu().numpy().tolist()
    for decode_token in decode_tokens:
        sentence = tokenizer.ids_to_text(decode_token)
        resp_sentences.append(sentence)
        if isinstance(tokenizer, TabularTokenizer):
            words = tokenizer.text_to_tokens(sentence)
        else:
            words = []
            for token in decode_token:
                # Skip any soft prompt pseudo tokens
                if token not in tokenizer.tokenizer.decoder:
                    continue
                word = tokenizer.tokenizer.decoder[token]
                # presumably byte-level BPE: map symbols back to bytes, then to text
                word = bytearray([
                    tokenizer.tokenizer.byte_decoder[c] for c in word
                ]).decode('utf-8', errors='replace')
                words.append(word)
        resp_sentences_seg.append(words)

    # offsets calculation: each token starts where the previous ones end.
    all_offsets = []
    for item in resp_sentences_seg:
        offsets = [0]
        for token in item[:-1]:
            offsets.append(offsets[-1] + len(token))
        all_offsets.append(offsets)

    output = {}
    output['sentences'] = resp_sentences
    output['tokens'] = resp_sentences_seg
    output['logprob'] = output_logits
    output['full_logprob'] = full_logits
    output['token_ids'] = decode_tokens
    output['offsets'] = all_offsets
    return output
 def forward(self, x):
     """Negate ``x`` if its elements sum to a positive value, otherwise return ``x`` itself."""
     should_flip = bool(torch.sum(x) > 0)
     return torch.neg(x) if should_flip else x
Exemple #16
0
 def comparison(x, y):
     """Return ``2 * (-y + -x)``, computed as ``torch.add(v, v)`` with ``v = neg(y) + neg(x)``."""
     neg_y = torch.neg(y)
     neg_x = torch.neg(x)
     total = neg_y + neg_x
     return torch.add(total, total)
 def loss_ML_sam(output, target, sam):
     """Negative ``sam``-weighted Bernoulli log-likelihood, summed over all elements.

     ``output`` is clamped to [1e-5, 1 - 1e-5] before taking logs.
     """
     eps = 1e-5
     p = torch.clamp(output, eps, 1 - eps)
     pos_term = target * torch.log(p)
     neg_term = (1 - target) * torch.log(1 - p)
     weighted = sam * (pos_term + neg_term)
     return torch.neg(torch.sum(weighted))
Exemple #18
0
 def forward(self, x):
     """Negate ``x`` once and return ``torch.add`` of that result with itself."""
     negated = torch.neg(x)
     doubled = torch.add(negated, negated)
     return doubled
Exemple #19
0
     ex_indices = [i for i in range(0, len(train_ys))]
     random.shuffle(ex_indices)
     total_loss = 0.0
     for idx in ex_indices:
         x = form_input(train_xs[idx])
         y = train_ys[idx]
         # Build one-hot representation of y
         y_onehot = torch.zeros(num_classes)
         y_onehot.scatter_(0, torch.from_numpy(np.asarray(y,
                                                          dtype=np.long)),
                           1)
         # Zero out the gradients from the FFNN object. *THIS IS VERY IMPORTANT TO DO BEFORE CALLING BACKWARD()*
         ffnn.zero_grad()
         probs = ffnn.forward(x)
         # Can also use built-in NLLLoss as a shortcut here (takes log probabilities) but we're being explicit here
         loss = torch.neg(torch.log(probs)).dot(y_onehot)
         total_loss += loss
         loss.backward()
         optimizer.step()
     print("Loss on epoch %i: %f" % (epoch, total_loss))
 # Evaluate on the train set
 train_correct = 0
 for idx in range(0, len(train_xs)):
     # Note that we only feed in the x, not the y, since we're not training. We're also extracting different
     # quantities from the running of the computation graph, namely the probabilities, prediction, and z
     x = form_input(train_xs[idx])
     y = train_ys[idx]
     probs = ffnn.forward(x)
     prediction = torch.argmax(probs)
     if y == prediction:
         train_correct += 1
Exemple #20
0
 def forward(self, x):
     """Return ``torch.neg(relu(x)) - relu(x)``, i.e. ``-2 * relu(x)``."""
     rectified = torch.relu(x)
     flipped = torch.neg(rectified)
     return flipped - rectified
Exemple #21
0
        tanh_1 = torch.tanh(cat_1);  cat_1 = None
        neg_1 = torch.neg(tanh_1);  tanh_1 = None
        return neg_1

"""

# Create a graph independently of symbolic tracing
graph = Graph()

# Create raw Nodes
raw1 = graph.placeholder("x")
raw2 = graph.placeholder("y")

# Initialize Proxies using the raw Nodes
y = Proxy(raw1)
z = Proxy(raw2)

# Create other operations using the Proxies `y` and `z`
a = torch.cat([y, z])
b = torch.tanh(a)
c = torch.neg(b)

# Create a new output Node and add it to the Graph. By doing this, the
# Graph will contain all the Nodes we just created (since they're all
# linked to the output Node)
graph.output(c.node)

# Wrap our created Graph in a GraphModule to get a final, runnable
# `nn.Module` instance
mod = GraphModule(torch.nn.Module(), graph)
Exemple #22
0
 def comparison(x):
     """Return ``torch.add(v, v)`` where ``v = torch.neg(x) + torch.relu(x)``."""
     negated = torch.neg(x)
     rectified = torch.relu(x)
     total = negated + rectified
     return torch.add(total, total)
Exemple #23
0
    def forward(self, input, mask):
        """Fuse six feature levels through mask-aware multi-scale convolutions.

        Per the original comments, levels 0-2 form the texture/detail branch and
        levels 3-5 the structure branch.  Levels are first resampled to a common
        size (32 x 32 for the upsampled ones).  Each branch is then reduced by
        ``self.down`` and run through ``cov_3``/``cov_5``/``cov_7`` together with the
        inverted mask.  The branches are fused and equalized by ``self.base``, and
        the result is resampled back to every level and added to the input as a
        residual.

        :param input: indexable collection of six feature maps; ``input[2]`` is 4-D.
        :param mask: raw mask, converted by ``cal_feat_mask(mask, 3, 1)``.
        :return: ``[outs, [structure_fused, detail_fused]]`` where ``outs`` holds the
            six refined maps.
        """
        mask = cal_feat_mask(mask, 3, 1)
        # input[2]:256 32 32 (per original note)
        b, c, h, w = input[2].size()
        # 1 - mask, broadcast to the shape of input[2].
        inv_mask = torch.add(torch.neg(mask.float()), 1).expand(b, c, h, w)

        # Shared activation on each of the six levels, in order.
        acts = [self.activation(input[k]) for k in range(6)]
        # Bring every level to a common spatial size.
        resized = [
            self.down_128(acts[0]),
            self.down_64(acts[1]),
            self.down_32(acts[2]),
            self.up(acts[3], (32, 32)),
            self.up(acts[4], (32, 32)),
            self.up(acts[5], (32, 32)),
        ]

        # First three levels: texture/detail; last three: structure.
        detail = torch.cat(resized[:3], 1)
        structure = torch.cat(resized[3:], 1)

        structure_in = [self.down(structure), inv_mask]
        detail_in = [self.down(detail), inv_mask]

        def multiscale_fill(pair):
            # Run the 3/5/7 convolutions, keep each first output, fuse, and reduce.
            fused = torch.cat([self.cov_3(pair)[0], self.cov_5(pair)[0], self.cov_7(pair)[0]], 1)
            return self.down(fused)

        detail_fused = multiscale_fill(detail_in)
        structure_fused = multiscale_fill(structure_in)

        # Fuse both branches, then equalize the features.
        final = self.base(self.fuse(torch.cat([structure_fused, detail_fused], 1)))

        # Resample back to each level and add the input as a residual.
        outs = [
            self.up_128(final, (128, 128)) + input[0],
            self.up_64(final, (64, 64)) + input[1],
            self.up_32(final, (32, 32)) + input[2],
            self.down_16(final) + input[3],
            self.down_8(final) + input[4],
            self.down_4(final) + input[5],
        ]
        return [outs, [structure_fused, detail_fused]]
Exemple #24
0
 def forward(self, x):
     """Return ``torch.add(v, v)`` where ``v = torch.neg(x) + torch.relu(x)``."""
     negated = torch.neg(x)
     rectified = torch.relu(x)
     total = negated + rectified
     return torch.add(total, total)
Exemple #25
0
 def test_neg(x, y):
     c = torch.neg(torch.add(x, y))
     return c
Exemple #26
0
            outputs = discriminator(torch.cat((generated_outputs, train_noisy), dim=1), ref_batch)
            noisy_loss = torch.mean(outputs ** 2)  # L2 loss - we want them all to be 0
            noisy_loss.backward()

            d_loss = clean_loss + noisy_loss
            d_optimizer.step()  # update parameters

            # TRAIN G so that D recognizes G(z) as real
            generator.zero_grad()
            generated_outputs = generator(train_noisy, z)
            gen_noise_pair = torch.cat((generated_outputs, train_noisy), dim=1)
            outputs = discriminator(gen_noise_pair, ref_batch)

            g_loss_ = 0.5 * torch.mean((outputs - 1.0) ** 2)
            # L1 loss between generated output and clean sample
            l1_dist = torch.abs(torch.add(generated_outputs, torch.neg(train_clean)))
            g_cond_loss = 100 * torch.mean(l1_dist)  # conditional loss
            g_loss = g_loss_ + g_cond_loss

            # backprop + optimize
            g_loss.backward()
            g_optimizer.step()
            train_bar.set_description(
                'Epoch {}: d_clean_loss {:.4f}, d_noisy_loss {:.4f}, g_loss {:.4f}, g_conditional_loss {:.4f}'
                    .format(epoch + 1, clean_loss.item(), noisy_loss.item(), g_loss.item(), g_cond_loss.item()))
        if not os.path.exists(opt.loss_path):
            with open(opt.loss_path, 'w') as f:
                f.write('epoch,g_loss_,g_cond_loss,total_g_loss,clean_loss,noisy_loss,total_d_loss\n')
        with open(opt.loss_path, 'a') as file:
            file.write('{},{},{},{},{},{},{}\n'.format(epoch+1, g_loss_, g_cond_loss, g_loss, clean_loss, noisy_loss, d_loss))
Exemple #27
0
# mul/multiply
torch.mul(torch.randn(3), 100)
torch.multiply(torch.randn(4, 1), torch.randn(1, 4))

# mvlgamma (inputs must exceed (p - 1) / 2; uniform(1, 2) satisfies that for p = 2)
torch.mvlgamma(torch.empty(2, 3).uniform_(1, 2), 2)

# nan_to_num (operate on the tensor defined here, `w`)
w = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
torch.nan_to_num(w)
torch.nan_to_num(w, nan=2.0)
torch.nan_to_num(w, nan=2.0, posinf=1.0)

# neg/negative
torch.neg(torch.randn(5))

# nextafter (floating-point inputs; integer tensors are not supported)
eps = torch.finfo(torch.float32).eps
torch.nextafter(torch.tensor([1.0, 2.0]),
                torch.tensor([2.0, 1.0])) == torch.tensor([eps + 1, 2 - eps])

# polygamma
torch.polygamma(1, torch.tensor([1, 0.5]))
torch.polygamma(2, torch.tensor([1, 0.5]))
torch.polygamma(3, torch.tensor([1, 0.5]))
torch.polygamma(4, torch.tensor([1, 0.5]))

# pow (self-contained operand, like the other examples)
torch.pow(torch.randn(4), 2)
torch.pow(torch.arange(1., 5.), torch.arange(1., 5.))
Exemple #28
0
 def forward(self, x):
     """Return ``-self.submod(relu(x) + self.attr)``."""
     shifted = x.relu() + self.attr
     return torch.neg(self.submod(shifted))
 def test_non_promoting_ops(self, device):
     x = torch.ones(4, dtype=torch.double, device=device)
     self.assertRaises(RuntimeError,
                       lambda: torch.neg(torch.ones(4, dtype=torch.float, device=device), out=x))
     self.assertRaises(RuntimeError,
                       lambda: torch.lerp(x, torch.ones(4, dtype=torch.float, device=device), 1))
Exemple #30
0
 def replacement(x):
     """Return ``torch.neg(x)``."""
     negated = torch.neg(x)
     return negated
 def backward(ctx, grad_output):
     """Return gradients ``(-(sum_cr / eta) * (gbest - yy), grad_output, None, None, None)``.

     Reads ``yy`` (first of exactly two saved tensors), ``ctx.sum_cr``, ``ctx.eta``
     and ``ctx.gbest``.  ``grad_output`` is passed through unchanged as the second
     gradient; the last three inputs get no gradient.
     """
     yy, _unused_pred = ctx.saved_tensors
     scale = ctx.sum_cr / ctx.eta
     grad_input = torch.neg(scale * (ctx.gbest - yy))
     return grad_input, grad_output, None, None, None