Example No. 1
    def updateGradInput(self, input, y):
        v1 = input[0]
        v2 = input[1]

        gw1 = self.gradInput[0]
        gw2 = self.gradInput[1]
        gw1.resize_as_(v1).copy_(v2)
        gw2.resize_as_(v1).copy_(v1)

        torch.mul(self.w1, self.w22, out=self.buffer)
        gw1.addcmul_(-1, self.buffer.expand_as(v1), v1)
        gw1.mul_(self.w.expand_as(v1))

        torch.mul(self.w1, self.w32, out=self.buffer)
        gw2.addcmul_(-1, self.buffer.expand_as(v1), v2)
        gw2.mul_(self.w.expand_as(v1))

        # self._idx = self._outputs <= 0
        torch.le(self._outputs, 0, out=self._idx)
        self._idx = self._idx.view(-1, 1).expand(gw1.size())
        gw1[self._idx] = 0
        gw2[self._idx] = 0

        torch.eq(y, 1, out=self._idx)
        self._idx = self._idx.view(-1, 1).expand(gw2.size())
        gw1[self._idx] = gw1[self._idx].mul_(-1)
        gw2[self._idx] = gw2[self._idx].mul_(-1)

        if self.sizeAverage:
            gw1.div_(y.size(0))
            gw2.div_(y.size(0))

        return self.gradInput
Example No. 2
    def __iter__(self):
        for batch in self.data:
            batch_size = len(batch)
            batch = list(zip(*batch))
            if self.eval:
                assert len(batch) == 7
            else:
                assert len(batch) == 9

            context_len = max(len(x) for x in batch[0])
            context_id = torch.LongTensor(batch_size, context_len).fill_(0)
            context_order = torch.LongTensor(batch_size, context_len).fill_(0)

            for i, doc in enumerate(batch[0]):
                context_id[i, :len(doc)] = torch.LongTensor(doc)
                context_order[i, :len(doc)] = torch.from_numpy(np.arange(1, len(doc) + 1))
            feature_len = len(batch[1][0][0])
            context_feature = torch.Tensor(batch_size, context_len, feature_len).fill_(0)
            for i, doc in enumerate(batch[1]):
                for j, feature in enumerate(doc):
                    context_feature[i, j, :] = torch.Tensor(feature)

            context_tag = torch.LongTensor(batch_size, context_len).fill_(0)
            for i, doc in enumerate(batch[2]):
                context_tag[i, :len(doc)] = torch.LongTensor(doc)

            context_ent = torch.LongTensor(batch_size, context_len).fill_(0)
            for i, doc in enumerate(batch[3]):
                context_ent[i, :len(doc)] = torch.LongTensor(doc)
            question_len = max(len(x) for x in batch[4])
            question_id = torch.LongTensor(batch_size, question_len).fill_(0)
            question_order = torch.LongTensor(batch_size, question_len).fill_(0)
            for i, doc in enumerate(batch[4]):
                question_id[i, :len(doc)] = torch.LongTensor(doc)
                question_order[i, :len(doc)] = torch.from_numpy(np.arange(1, len(doc) + 1))

            context_mask = torch.eq(context_id, 0)
            question_mask = torch.eq(question_id, 0)
            if not self.eval:
                y_s = torch.LongTensor(batch[5])
                y_e = torch.LongTensor(batch[6])
            text = list(batch[-2])
            span = list(batch[-1])
            if self.gpu:
                context_id = context_id.pin_memory()
                context_feature = context_feature.pin_memory()
                context_tag = context_tag.pin_memory()
                context_ent = context_ent.pin_memory()
                context_mask = context_mask.pin_memory()
                question_id = question_id.pin_memory()
                question_mask = question_mask.pin_memory()
                context_order = context_order.pin_memory()
                question_order = question_order.pin_memory()

            if self.eval:
                yield (context_id, context_feature, context_tag, context_ent, context_mask,
                       question_id, question_mask, context_order, question_order, text, span)
            else:
                yield (context_id, context_feature, context_tag, context_ent, context_mask,
                       question_id, question_mask, context_order, question_order, y_s, y_e, text, span)
Example No. 3
def calc_precision(pred, label):
    t1 = torch.topk(pred, 1)[-1]
    t5 = torch.topk(pred, 5)[-1]
    mask_1 = torch.eq(t1, label.view(-1, 1))
    mask_5 = torch.eq(t5, label.view(-1, 1))
    t1_error = 1 - len(t1[mask_1]) / len(label)
    t5_error = 1 - len(t5[mask_5]) / len(label)
    return t1_error, t5_error
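
A minimal smoke test for calc_precision (shapes and values here are illustrative assumptions): the top-5 error can never exceed the top-1 error, since the top-5 hits are a superset of the top-1 hits.

import torch

pred = torch.randn(8, 10)            # logits for 8 samples, 10 classes
label = torch.randint(0, 10, (8,))   # ground-truth class indices
t1_error, t5_error = calc_precision(pred, label)
assert 0.0 <= t5_error <= t1_error <= 1.0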
Example No. 4
    def updateOutput(self, input, y):
        input1, input2 = input[0], input[1]

        # keep backward compatibility
        if self.buffer is None:
            self.buffer = input1.new()
            self.w1 = input1.new()
            self.w22 = input1.new()
            self.w = input1.new()
            self.w32 = input1.new()
            self._outputs = input1.new()

            # comparison operators behave differently from cuda/c implementations
            # TODO: verify name
            if input1.type() == 'torch.cuda.FloatTensor':
                self._idx = torch.cuda.ByteTensor()
            else:
                self._idx = torch.ByteTensor()

        torch.mul(input1, input2, out=self.buffer)
        torch.sum(self.buffer, 1, out=self.w1, keepdim=True)

        epsilon = 1e-12
        torch.mul(input1, input1, out=self.buffer)
        torch.sum(self.buffer, 1, out=self.w22, keepdim=True).add_(epsilon)
        # self._outputs is also used as a temporary buffer
        self._outputs.resize_as_(self.w22).fill_(1)
        torch.div(self._outputs, self.w22, out=self.w22)
        self.w.resize_as_(self.w22).copy_(self.w22)

        torch.mul(input2, input2, out=self.buffer)
        torch.sum(self.buffer, 1, out=self.w32, keepdim=True).add_(epsilon)
        torch.div(self._outputs, self.w32, out=self.w32)
        self.w.mul_(self.w32)
        self.w.sqrt_()

        torch.mul(self.w1, self.w, out=self._outputs)
        self._outputs = self._outputs.select(1, 0)

        torch.eq(y, -1, out=self._idx)
        self._outputs[self._idx] = self._outputs[self._idx].add_(-self.margin).clamp_(min=0)
        torch.eq(y, 1, out=self._idx)
        self._outputs[self._idx] = self._outputs[self._idx].mul_(-1).add_(1)

        self.output = self._outputs.sum().item()

        if self.sizeAverage:
            self.output = self.output / y.size(0)

        return self.output
Example No. 5
def evaluate(attention_model, x_test, y_test):
    """
    Cross-validation results.

    Args:
        attention_model : {object} model
        x_test          : {nplist} x_test
        y_test          : {nplist} y_test

    Returns:
        cv-accuracy
    """
    attention_model.batch_size = x_test.shape[0]
    attention_model.hidden_state = attention_model.init_hidden()
    x_test_var = Variable(torch.from_numpy(x_test).type(torch.LongTensor))
    y_test_pred, _ = attention_model(x_test_var)
    if bool(attention_model.type):
        y_preds = torch.max(y_test_pred, 1)[1]
        y_test_var = Variable(torch.from_numpy(y_test).type(torch.LongTensor))
    else:
        y_preds = torch.round(y_test_pred.type(torch.DoubleTensor).squeeze(1))
        y_test_var = Variable(torch.from_numpy(y_test).type(torch.DoubleTensor))
    return torch.eq(y_preds, y_test_var).data.sum() / x_test_var.size(0)
Example No. 6
    def forward(self, input, target):
        buffer = input.new()
        buffer.resize_as_(input).copy_(input)
        buffer[torch.eq(target, -1.)] = 0
        output = buffer.sum()

        buffer.fill_(self.margin).add_(-1, input)
        buffer.clamp_(min=0)
        buffer[torch.eq(target, 1.)] = 0
        output += buffer.sum()

        if self.size_average:
            output = output / input.nelement()

        self.save_for_backward(input, target)
        return input.new((output,))
Example No. 7
    def forward(self, input, target):
        y_true = target.int().unsqueeze(-1)
        same_id = torch.eq(y_true, y_true.t()).type_as(input)

        pos_mask = same_id
        neg_mask = 1 - same_id

        def _mask_max(input_tensor, mask, axis=None, keepdims=False):
            input_tensor = input_tensor - 1e6 * (1 - mask)
            _max, _idx = torch.max(input_tensor, dim=axis, keepdim=keepdims)
            return _max, _idx

        def _mask_min(input_tensor, mask, axis=None, keepdims=False):
            input_tensor = input_tensor + 1e6 * (1 - mask)
            _min, _idx = torch.min(input_tensor, dim=axis, keepdim=keepdims)
            return _min, _idx

        # output[i, j] = || feature[i, :] - feature[j, :] ||_2
        dist_squared = torch.sum(input ** 2, dim=1, keepdim=True) + \
                       torch.sum(input.t() ** 2, dim=0, keepdim=True) - \
                       2.0 * torch.matmul(input, input.t())
        dist = dist_squared.clamp(min=1e-16).sqrt()

        pos_max, pos_idx = _mask_max(dist, pos_mask, axis=-1)
        neg_min, neg_idx = _mask_min(dist, neg_mask, axis=-1)

        # loss(x, y) = max(0, -y * (x1 - x2) + margin)
        y = torch.ones(same_id.size()[0]).to(DEVICE)
        return F.margin_ranking_loss(neg_min.float(),
                                     pos_max.float(),
                                     y,
                                     self.margin,
                                     self.size_average)
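
For reference, a worked check of the margin ranking loss used above (values are illustrative): with target y = 1 the loss is max(0, -(x1 - x2) + margin), which pushes the hardest-negative distance above the hardest-positive distance.

import torch
import torch.nn.functional as F

x1 = torch.tensor([0.5])   # neg_min: smallest distance to a negative
x2 = torch.tensor([1.0])   # pos_max: largest distance to a positive
y = torch.tensor([1.0])
print(F.margin_ranking_loss(x1, x2, y, margin=0.3))  # max(0, -(0.5 - 1.0) + 0.3) = tensor(0.8000)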
Example No. 8
def test(net, testloader, config):
    total, correct = 0.0, 0.0
    for i, data in enumerate(testloader):
        # Get inputs
        X, S1, S2, labels = data
        if X.size()[0] != config.batch_size:
            continue  # Drop those data, if not enough for a batch
        # Send Tensors to GPU if available
        if use_GPU:
            X = X.cuda()
            S1 = S1.cuda()
            S2 = S2.cuda()
            labels = labels.cuda()
        # Wrap to autograd.Variable
        X, S1, S2 = Variable(X), Variable(S1), Variable(S2)
        # Forward pass
        outputs, predictions = net(X, S1, S2, config)
        # Select actions with max scores(logits)
        _, predicted = torch.max(outputs, dim=1, keepdim=True)
        # Unwrap autograd.Variable to Tensor
        predicted = predicted.data
        # Compute test accuracy
        correct += (torch.eq(torch.squeeze(predicted), labels)).sum()
        total += labels.size()[0]
    print('Test Accuracy: {:.2f}%'.format(100 * (correct / total)))
Example No. 9
    def test_rescale_torch_tensor(self):
        rows, cols = 3, 5
        original_tensor = torch.randint(low=10, high=40, size=(rows, cols)).float()
        prev_max_tensor = torch.ones(1, 5) * 40.0
        prev_min_tensor = torch.ones(1, 5) * 10.0
        new_min_tensor = torch.ones(1, 5) * -1.0
        new_max_tensor = torch.ones(1, 5).float()

        print("Original tensor: ", original_tensor)
        rescaled_tensor = rescale_torch_tensor(
            original_tensor,
            new_min_tensor,
            new_max_tensor,
            prev_min_tensor,
            prev_max_tensor,
        )
        print("Rescaled tensor: ", rescaled_tensor)
        reconstructed_original_tensor = rescale_torch_tensor(
            rescaled_tensor,
            prev_min_tensor,
            prev_max_tensor,
            new_min_tensor,
            new_max_tensor,
        )
        print("Reconstructed Original tensor: ", reconstructed_original_tensor)

        comparison_tensor = torch.eq(original_tensor, reconstructed_original_tensor)
        self.assertEqual(torch.sum(comparison_tensor).item(), rows * cols)
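
The round-trip in this test pins down what rescale_torch_tensor has to do; a plausible reference implementation (an assumption for illustration, not necessarily the library's actual code) is the per-column linear map:

def rescale_torch_tensor(t, new_min, new_max, prev_min, prev_max):
    # Linearly map each column of t from [prev_min, prev_max] onto [new_min, new_max].
    return (t - prev_min) / (prev_max - prev_min) * (new_max - new_min) + new_min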
Example No. 10
def train():
    epoch_num, loss_sum, cort_num_sum = 0, 0.0, 0
    for epoch in epoches:
        epoch_num += 1
        inputs = Variable(epoch[0])
        target = Variable(epoch[1])

        output = model(inputs)
        loss = criterion(output, target)
        # reset gradients
        optimizer.zero_grad()
        # backward pass
        loss.backward()
        # update parameters
        optimizer.step()

        # get training information
        loss_sum += loss.data[0]
        _, pred = torch.max(output.data, 1)

        num_correct = torch.eq(pred, epoch[1]).sum()
        cort_num_sum += num_correct

    loss_avg = loss_sum / float(epoch_num)
    cort_num_avg = cort_num_sum / float(epoch_num) / float(epoch_size)
    return loss_avg, cort_num_avg
Example No. 11
    def test_serialization_built_vocab(self):
        self.write_test_ppid_dataset(data_format="tsv")
        question_field = data.Field(sequential=True)
        tsv_fields = [("id", None), ("q1", question_field),
                      ("q2", question_field), ("label", None)]
        tsv_dataset = data.TabularDataset(
            path=self.test_ppid_dataset_path, format="tsv",
            fields=tsv_fields)

        question_field.build_vocab(tsv_dataset)

        question_pickle_filename = "question.pl"
        question_pickle_path = os.path.join(self.test_dir, question_pickle_filename)
        torch.save(question_field, question_pickle_path)

        loaded_question_field = torch.load(question_pickle_path)

        assert loaded_question_field == question_field

        test_example_data = [["When", "do", "you", "use", "シ",
                              "instead", "of", "し?"],
                             ["What", "is", "2+2", "<pad>", "<pad>",
                              "<pad>", "<pad>", "<pad>"],
                             ["Here", "is", "a", "sentence", "with",
                              "some", "oovs", "<pad>"]]

        # Test results of numericalization
        original_numericalization = question_field.numericalize(test_example_data)
        pickled_numericalization = loaded_question_field.numericalize(test_example_data)

        assert torch.all(torch.eq(original_numericalization, pickled_numericalization))
Example No. 12
    def forward(self, output, context):
        batch_size = output.size(0)
        hidden_size = output.size(2)
        input_size = context.size(1)

        # (batch, out_len, dim) * (batch, in_len, dim) -> (batch, out_len, in_len)
        attn = torch.bmm(output, context.transpose(1, 2))
        mask = torch.eq(attn, 0).data.byte()
        attn.data.masked_fill_(mask, -float('inf'))
        attn = F.softmax(attn.view(-1, input_size), dim=1).view(batch_size, -1, input_size)

        # (batch, out_len, in_len) * (batch, in_len, dim) -> (batch, out_len, dim)
        mix = torch.bmm(attn, context)

        # concat -> (batch, out_len, 2*dim)
        combined = torch.cat((mix, output), dim=2)

        # output -> (batch, out_len, dim)
        output = F.tanh(self.linear_out(combined.view(-1, 2 * hidden_size))).view(batch_size, -1, hidden_size)


        if not output.is_contiguous():
            output = output.contiguous()

        return output, attn
Example No. 13
    def test_local_var_binary_methods(self):
        ''' Unit tests for methods mentioned on issue 1385
            https://github.com/OpenMined/PySyft/issues/1385'''
        x = torch.FloatTensor([1, 2, 3, 4])
        y = torch.FloatTensor([[1, 2, 3, 4]])
        z = torch.matmul(x, y.t())
        assert torch.equal(z, torch.FloatTensor([30]))
        z = torch.add(x, y)
        assert torch.equal(z, torch.FloatTensor([[2, 4, 6, 8]]))
        x = torch.FloatTensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
        y = torch.FloatTensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
        z = torch.cross(x, y, dim=1)
        assert torch.equal(z, torch.FloatTensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]]))
        x = torch.FloatTensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
        y = torch.FloatTensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
        z = torch.dist(x, y)
        t = torch.FloatTensor([z])
        assert torch.equal(t, torch.FloatTensor([0.]))
        x = torch.FloatTensor([1, 2, 3])
        y = torch.FloatTensor([1, 2, 3])
        z = torch.dot(x, y)
        t = torch.FloatTensor([z])
        assert torch.equal(t, torch.FloatTensor([14]))
        z = torch.eq(x, y)
        assert torch.equal(z, torch.ByteTensor([1, 1, 1]))
        z = torch.ge(x, y)
        assert torch.equal(z, torch.ByteTensor([1, 1, 1]))
Example No. 14
def rpn_bbox_loss(target_bbox, rpn_match, rpn_bbox, config):
    """Return the RPN bounding box loss graph.

    config: the model config object.
    target_bbox: [batch, max positive anchors, (dy, dx, log(dh), log(dw))].
        Uses 0 padding to fill in unsed bbox deltas.
    rpn_match: [batch, anchors, 1]. Anchor match type. 1=positive,
               -1=negative, 0=neutral anchor.
    rpn_bbox: [batch, anchors, (dy, dx, log(dh), log(dw))]
    """
    # Positive anchors contribute to the loss, but negative and
    # neutral anchors (match value of 0 or -1) don't.
    indices = torch.eq(rpn_match, 1)
    rpn_bbox = torch.masked_select(rpn_bbox, indices)
    batch_counts = torch.sum(indices.float(), dim=1)

    outputs = []
    for i in range(config.IMAGES_PER_GPU):
        outputs.append(target_bbox[i, torch.arange(int(batch_counts[i].cpu().data.numpy()[0])).type(torch.cuda.LongTensor)])

    target_bbox = torch.cat(outputs, dim=0)

    loss = F.smooth_l1_loss(rpn_bbox, target_bbox, size_average=True)
    return loss
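
A shape-only sketch of how the inputs line up (the config stub and sizes are assumptions; as written, the function also requires a GPU because of torch.cuda.LongTensor):

import torch

class DummyConfig:          # hypothetical stand-in for the model config
    IMAGES_PER_GPU = 2

rpn_match = torch.tensor([[[1], [0], [-1], [1]],
                          [[0], [1], [0], [0]]])   # [batch, anchors, 1]
rpn_bbox = torch.randn(2, 4, 4)                    # [batch, anchors, 4 deltas]
target_bbox = torch.randn(2, 4, 4)                 # zero-padded positive targets
# loss = rpn_bbox_loss(target_bbox, rpn_match, rpn_bbox, DummyConfig())  # CUDA only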
Example No. 15
    def test_train(self):
        self._metric.train()
        calls = [[torch.FloatTensor([0.0]), torch.LongTensor([0])],
                 [torch.FloatTensor([0.0, 0.1, 0.2, 0.3]), torch.LongTensor([0, 1, 2, 3])]]
        for i in range(len(self._states)):
            self._metric.process(self._states[i])
        self.assertEqual(2, len(self._metric_function.call_args_list))
        for i in range(len(self._metric_function.call_args_list)):
            self.assertTrue(torch.eq(self._metric_function.call_args_list[i][0][0], calls[i][0]).all())
            self.assertTrue(torch.lt(torch.abs(torch.add(self._metric_function.call_args_list[i][0][1], -calls[i][1])), 1e-12).all())
        self._metric_function.reset_mock()
        self._metric.process_final({})

        self._metric_function.assert_called_once()
        self.assertTrue(torch.eq(self._metric_function.call_args_list[0][0][1], torch.LongTensor([0, 1, 2, 3, 4])).all())
        self.assertTrue(torch.lt(torch.abs(torch.add(self._metric_function.call_args_list[0][0][0], -torch.FloatTensor([0.0, 0.1, 0.2, 0.3, 0.4]))), 1e-12).all())
Example No. 16
def knn(Mxx, Mxy, Myy, k, sqrt):
    n0 = Mxx.size(0)
    n1 = Myy.size(0)
    label = torch.cat((torch.ones(n0),torch.zeros(n1)))
    M = torch.cat((torch.cat((Mxx,Mxy),1), torch.cat((Mxy.transpose(0,1),Myy), 1)), 0)
    if sqrt:
        M = M.abs().sqrt()
    INFINITY = float('inf')
    val, idx = (M+torch.diag(INFINITY*torch.ones(n0+n1))).topk(k, 0, False)

    count = torch.zeros(n0+n1)
    for i in range(0,k):
        count = count + label.index_select(0,idx[i])
    pred = torch.ge(count, (float(k)/2)*torch.ones(n0+n1)).float()

    s = Score_knn()
    s.tp = (pred*label).sum()
    s.fp = (pred*(1-label)).sum()
    s.fn = ((1-pred)*label).sum()
    s.tn = ((1-pred)*(1-label)).sum()
    s.precision = s.tp/(s.tp+s.fp)
    s.recall = s.tp/(s.tp+s.fn)
    s.acc_t = s.tp/(s.tp+s.fn)
    s.acc_f = s.tn/(s.tn+s.fp)
    s.acc = torch.eq(label, pred).float().mean()
    s.k = k 

    return s
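
A hedged usage sketch (the distance construction is an assumption; knn only needs the three pairwise-distance blocks, and Score_knn is the result container defined alongside it in the same source):

import torch

x = torch.randn(100, 8)                  # samples from distribution P
y = torch.randn(120, 8)                  # samples from distribution Q
Mxx = torch.cdist(x, x) ** 2
Mxy = torch.cdist(x, y) ** 2
Myy = torch.cdist(y, y) ** 2
s = knn(Mxx, Mxy, Myy, k=5, sqrt=True)   # leave-one-out 5-NN two-sample test
print(s.acc)                             # ~0.5 when P and Q are indistinguishable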
Example No. 17
    def updateOutput(self, input, y):
        if self.buffer is None:
            self.buffer = input.new()
        self.buffer.resize_as_(input).copy_(input)
        self.buffer[torch.eq(y, -1.)] = 0
        self.output = self.buffer.sum()

        self.buffer.fill_(self.margin).add_(-1, input)
        self.buffer.clamp_(min=0)
        self.buffer[torch.eq(y, 1.)] = 0
        self.output = self.output + self.buffer.sum()

        if self.sizeAverage:
            self.output = self.output / input.nelement()

        return self.output
Example No. 18
    def test_remote_var_binary_methods(self):
        ''' Unit tests for methods mentioned on issue 1385
            https://github.com/OpenMined/PySyft/issues/1385'''
        hook = TorchHook(verbose=False)
        local = hook.local_worker
        remote = VirtualWorker(hook, 1)
        local.add_worker(remote)

        x = Var(torch.FloatTensor([1, 2, 3, 4])).send(remote)
        y = Var(torch.FloatTensor([[1, 2, 3, 4]])).send(remote)
        z = torch.matmul(x, y.t())
        assert (torch.equal(z.get(), Var(torch.FloatTensor([30]))))
        z = torch.add(x, y)
        assert (torch.equal(z.get(), Var(torch.FloatTensor([[2, 4, 6, 8]]))))
        x = Var(torch.FloatTensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])).send(remote)
        y = Var(torch.FloatTensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])).send(remote)
        z = torch.cross(x, y, dim=1)
        assert (torch.equal(z.get(), Var(torch.FloatTensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]]))))
        x = Var(torch.FloatTensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])).send(remote)
        y = Var(torch.FloatTensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])).send(remote)
        z = torch.dist(x, y)
        assert (torch.equal(z.get(), Var(torch.FloatTensor([0.]))))
        x = Var(torch.FloatTensor([1, 2, 3])).send(remote)
        y = Var(torch.FloatTensor([1, 2, 3])).send(remote)
        z = torch.dot(x, y)
        assert (torch.equal(z.get(), Var(torch.FloatTensor([14]))))
        z = torch.eq(x, y)
        assert (torch.equal(z.get(), Var(torch.ByteTensor([1, 1, 1]))))
        z = torch.ge(x, y)
        assert (torch.equal(z.get(), Var(torch.ByteTensor([1, 1, 1]))))
Example No. 19
    def test_serialization(self):
        nesting_field = data.Field(batch_first=True)
        field = data.NestedField(nesting_field)
        ex1 = data.Example.fromlist(["john loves mary"], [("words", field)])
        ex2 = data.Example.fromlist(["mary cries"], [("words", field)])
        dataset = data.Dataset([ex1, ex2], [("words", field)])
        field.build_vocab(dataset)
        examples_data = [
            [
                ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
                ["<w>"] + list("john") + ["</w>", "<cpad>"],
                ["<w>"] + list("loves") + ["</w>"],
                ["<w>"] + list("mary") + ["</w>", "<cpad>"],
                ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
            ],
            [
                ["<w>", "<s>", "</w>"] + ["<cpad>"] * 4,
                ["<w>"] + list("mary") + ["</w>", "<cpad>"],
                ["<w>"] + list("cries") + ["</w>"],
                ["<w>", "</s>", "</w>"] + ["<cpad>"] * 4,
                ["<cpad>"] * 7,
            ]
        ]

        field_pickle_filename = "char_field.pl"
        field_pickle_path = os.path.join(self.test_dir, field_pickle_filename)
        torch.save(field, field_pickle_path)

        loaded_field = torch.load(field_pickle_path)
        assert loaded_field == field

        original_numericalization = field.numericalize(examples_data)
        pickled_numericalization = loaded_field.numericalize(examples_data)

        assert torch.all(torch.eq(original_numericalization, pickled_numericalization))
Example No. 20
    def batch_log_pdf(self, x):
        """
        Ref: :py:meth:`pyro.distributions.distribution.Distribution.batch_log_pdf`
        """
        v = self.v
        v = v.expand(self.shape(x))
        batch_shape = self.batch_shape(x) + (1,)
        return torch.sum(torch.eq(x, v).float().log(), -1).contiguous().view(batch_shape)
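
The torch.eq(...).float().log() trick yields log 1 = 0 where x matches v and log 0 = -inf elsewhere, so the summed log-pdf is 0 only when every component matches: a point-mass (delta) distribution. A quick check:

import torch

x = torch.tensor([1., 2., 3.])
v = torch.tensor([1., 2., 4.])
print(torch.eq(x, v).float().log())        # tensor([0., 0., -inf])
print(torch.eq(x, v).float().log().sum())  # tensor(-inf): x != v overall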
Example No. 21
    def updateGradInput(self, input, y):
        self.gradInput.resize_as_(input).copy_(y)
        self.gradInput[torch.mul(torch.eq(y, -1), torch.gt(input, self.margin))] = 0

        if self.sizeAverage:
            self.gradInput.mul_(1. / input.nelement())

        return self.gradInput
Example No. 22
    def forward(ctx, input, target, margin, size_average):
        ctx.margin = margin
        ctx.size_average = size_average
        buffer = input.new()
        buffer.resize_as_(input).copy_(input)
        buffer[torch.eq(target, -1.)] = 0
        output = buffer.sum()

        buffer.fill_(ctx.margin).add_(-1, input)
        buffer.clamp_(min=0)
        buffer[torch.eq(target, 1.)] = 0
        output += buffer.sum()

        if ctx.size_average:
            output = output / input.nelement()

        ctx.save_for_backward(input, target)
        return input.new((output,))
Example No. 23
    def landmark_loss(self, gt_label, gt_landmark, pred_landmark):
        mask = torch.eq(gt_label, -2)

        chose_index = torch.nonzero(mask.data)
        chose_index = torch.squeeze(chose_index)

        valid_gt_landmark = gt_landmark[chose_index, :]
        valid_pred_landmark = pred_landmark[chose_index, :]
        return self.loss_landmark(valid_pred_landmark, valid_gt_landmark)
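
A small check of the labelling convention assumed here (gt_label == -2 marks landmark samples in this MTCNN-style code):

import torch

gt_label = torch.tensor([1, -2, 0, -2])
mask = torch.eq(gt_label, -2)
print(torch.squeeze(torch.nonzero(mask)))  # tensor([1, 3]): the rows fed to the landmark loss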
Example No. 24
    def test_validate(self):
        self._metric.eval()
        for i in range(len(self._states)):
            self._metric.process(self._states[i])
        self._metric_function.assert_not_called()
        self._metric.process_final_validate({})

        self._metric_function.assert_called_once()
        self.assertTrue(torch.eq(self._metric_function.call_args_list[0][0][1], torch.LongTensor([0, 1, 2, 3, 4])).all())
        self.assertTrue(torch.lt(torch.abs(torch.add(self._metric_function.call_args_list[0][0][0], -torch.FloatTensor([0.0, 0.1, 0.2, 0.3, 0.4]))), 1e-12).all())
Example No. 25
    def forward(ctx, input1, input2, y, margin, size_average):
        ctx.margin = margin
        ctx.size_average = size_average
        ctx.w1 = input1.new()
        ctx.w22 = input1.new()
        ctx.w = input1.new()
        ctx.w32 = input1.new()
        ctx._outputs = input1.new()

        _idx = input1.new().byte()

        buffer = torch.mul(input1, input2)
        torch.sum(buffer, 1, out=ctx.w1, keepdim=True)

        epsilon = 1e-12
        torch.mul(input1, input1, out=buffer)
        torch.sum(buffer, 1, out=ctx.w22, keepdim=True).add_(epsilon)

        ctx._outputs.resize_as_(ctx.w22).fill_(1)
        torch.div(ctx._outputs, ctx.w22, out=ctx.w22)
        ctx.w.resize_as_(ctx.w22).copy_(ctx.w22)

        torch.mul(input2, input2, out=buffer)
        torch.sum(buffer, 1, out=ctx.w32, keepdim=True).add_(epsilon)
        torch.div(ctx._outputs, ctx.w32, out=ctx.w32)
        ctx.w.mul_(ctx.w32)
        ctx.w.sqrt_()

        torch.mul(ctx.w1, ctx.w, out=ctx._outputs)
        ctx._outputs = ctx._outputs.select(1, 0)

        torch.eq(y, -1, out=_idx)
        ctx._outputs[_idx] = ctx._outputs[_idx].add_(-ctx.margin).clamp_(min=0)
        torch.eq(y, 1, out=_idx)
        ctx._outputs[_idx] = ctx._outputs[_idx].mul_(-1).add_(1)

        output = ctx._outputs.sum()

        if ctx.size_average:
            output = output / y.size(0)

        ctx.save_for_backward(input1, input2, y)
        return input1.new((output,))
Example No. 26
    def forward(self, input1, input2, y):
        self.w1 = input1.new()
        self.w22 = input1.new()
        self.w = input1.new()
        self.w32 = input1.new()
        self._outputs = input1.new()

        _idx = input1.new().byte()

        buffer = torch.mul(input1, input2)
        torch.sum(buffer, 1, out=self.w1, keepdim=True)

        epsilon = 1e-12
        torch.mul(input1, input1, out=buffer)
        torch.sum(buffer, 1, out=self.w22, keepdim=True).add_(epsilon)

        self._outputs.resize_as_(self.w22).fill_(1)
        torch.div(self._outputs, self.w22, out=self.w22)
        self.w.resize_as_(self.w22).copy_(self.w22)

        torch.mul(input2, input2, out=buffer)
        torch.sum(buffer, 1, out=self.w32, keepdim=True).add_(epsilon)
        torch.div(self._outputs, self.w32, out=self.w32)
        self.w.mul_(self.w32)
        self.w.sqrt_()

        torch.mul(self.w1, self.w, out=self._outputs)
        self._outputs = self._outputs.select(1, 0)

        torch.eq(y, -1, out=_idx)
        self._outputs[_idx] = self._outputs[_idx].add_(-self.margin).clamp_(min=0)
        torch.eq(y, 1, out=_idx)
        self._outputs[_idx] = self._outputs[_idx].mul_(-1).add_(1)

        output = self._outputs.sum()

        if self.size_average:
            output = output / y.size(0)

        self.save_for_backward(input1, input2, y)
        return input1.new((output,))
Example No. 27
    def backward(self, grad_output):
        input, target = self.saved_tensors
        grad_input = input.new().resize_as_(input).copy_(target)
        grad_input[torch.mul(torch.eq(target, -1), torch.gt(input, self.margin))] = 0

        if self.size_average:
            grad_input.mul_(1. / input.nelement())

        if grad_output[0] != 1:
            grad_input.mul_(grad_output[0])

        return grad_input, None
Example No. 28
    def compute_accuracy(self, prob_cls, gt_cls):
        # we only need detections whose label is >= 0
        prob_cls = torch.squeeze(prob_cls)
        mask = torch.ge(gt_cls, 0)
        # get the valid elements
        valid_gt_cls = torch.masked_select(gt_cls, mask)
        valid_prob_cls = torch.masked_select(prob_cls, mask)
        size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0])
        prob_ones = torch.ge(valid_prob_cls, 0.6).float()
        right_ones = torch.eq(prob_ones, valid_gt_cls.float()).float()

        return torch.div(torch.mul(torch.sum(right_ones), float(1.0)), float(size))
Example No. 29
def compute_accuracy(y_pred, y_true, mask_index):
    y_pred, y_true = normalize_sizes(y_pred, y_true)

    _, y_pred_indices = y_pred.max(dim=1)

    correct_indices = torch.eq(y_pred_indices, y_true).float()
    valid_indices = torch.ne(y_true, mask_index).float()

    n_correct = (correct_indices * valid_indices).sum().item()
    n_valid = valid_indices.sum().item()

    return n_correct / n_valid * 100
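
The masking step is the interesting part: positions whose target equals mask_index (typically padding) are excluded from both numerator and denominator. A worked check with an assumed mask_index of 0:

import torch

y_pred_indices = torch.tensor([3, 1, 0, 2])
y_true = torch.tensor([3, 1, 0, 0])                  # last position is padding
correct = torch.eq(y_pred_indices, y_true).float()   # [1., 1., 1., 0.]
valid = torch.ne(y_true, 0).float()                  # [1., 1., 0., 0.]
print((correct * valid).sum() / valid.sum() * 100)   # tensor(100.): padding ignored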
Example No. 30
def accuracy(cls_score, cls_labels):
    class_dim = cls_score.dim() - 1
    argmax = torch.max(torch.nn.functional.softmax(cls_score, dim=class_dim), class_dim)[1]
    accuracy = torch.mean(torch.eq(argmax, cls_labels.long()).float())
    return accuracy
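
Since softmax is monotonic, it leaves the argmax unchanged; a quick check of accuracy() with illustrative values:

import torch

cls_score = torch.tensor([[2.0, 0.1], [0.2, 1.5], [3.0, 0.0]])
cls_labels = torch.tensor([0, 1, 1])
print(accuracy(cls_score, cls_labels))  # tensor(0.6667): two of three correct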

# class detector_loss(torch.nn.Module):
#     def __init__(self, do_loss_cls=True, do_loss_bbox=True, do_accuracy_cls=True):
#         super(detector_loss, self).__init__()
#         # Flags
#         self.do_loss_cls = do_loss_cls
#         self.do_loss_bbox = do_loss_bbox
#         self.do_accuracy_cls = do_accuracy_cls
#         # Dicts for losses 
#         # self.losses={}
#         # if do_loss_cls:
#         #     self.losses['loss_cls']=0
#         # if do_loss_bbox:
#         #     self.losses['loss_bbox']=0
#         # # Dicts for metrics       
#         # self.metrics={}
#         # if do_accuracy_cls:
#         #     self.metrics['accuracy_cls']=0

#     def forward(self,
#             cls_score,
#             cls_labels,
#             bbox_pred,
#             bbox_targets,
#             bbox_inside_weights,
#             bbox_outside_weights):

#         # compute losses
#         losses=[]
#         if self.do_loss_cls:
#             loss_cls = cross_entropy(cls_score,cls_labels.long())
#             losses.append(loss_cls)
#         if self.do_loss_bbox:
#             loss_bbox = smooth_L1(bbox_pred,bbox_targets,bbox_inside_weights,bbox_outside_weights)
#             losses.append(loss_bbox)

#         # # compute metrics
#         # if self.do_accuracy_cls:
#         #     self.metrics['accuracy_cls'] = accuracy(cls_score,cls_labels.long())

#         # sum total loss
#         #loss = torch.sum(torch.cat(tuple([v.unsqueeze(0) for v in losses]),0))        

#         # loss.register_hook(printmax)

#         return tuple(losses)
        
Example No. 31
    def forward(self, classifications, regressions, anchors, annotations):
        alpha = 0.25
        gamma = 2.0
        batch_size = classifications.shape[0]
        classification_losses = []
        regression_losses = []

        anchor = anchors[0, :, :]

        anchor_widths = anchor[:, 2] - anchor[:, 0]
        anchor_heights = anchor[:, 3] - anchor[:, 1]
        anchor_ctr_x = anchor[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y = anchor[:, 1] + 0.5 * anchor_heights

        for j in range(batch_size):

            classification = classifications[j, :, :]
            regression = regressions[j, :, :]

            bbox_annotation = annotations[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]

            if bbox_annotation.shape[0] == 0:
                regression_losses.append(torch.tensor(0).float().to(DEVICE))
                classification_losses.append(torch.tensor(0).float().to(DEVICE))

                continue

            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            IoU = calc_iou(
                anchors[0, :, :],
                bbox_annotation[:, :4])  # num_anchors x num_annotations

            IoU_max, IoU_argmax = torch.max(IoU, dim=1)  # num_anchors x 1

            # import pdb
            # pdb.set_trace()

            # compute the loss for classification
            targets = torch.ones(classification.shape) * -1
            targets = targets.to(DEVICE)

            targets[torch.lt(IoU_max, 0.4), :] = 0

            positive_indices = torch.ge(IoU_max, 0.5)

            num_positive_anchors = positive_indices.sum()

            assigned_annotations = bbox_annotation[IoU_argmax, :]

            targets[positive_indices, :] = 0
            targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1

            alpha_factor = torch.ones(targets.shape).to(DEVICE) * alpha

            alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor,
                                       1. - alpha_factor)
            focal_weight = torch.where(torch.eq(targets, 1.),
                                       1. - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

            bce = -(targets * torch.log(classification) +
                    (1.0 - targets) * torch.log(1.0 - classification))

            # cls_loss = focal_weight * torch.pow(bce, gamma)
            cls_loss = focal_weight * bce

            cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss,
                                   torch.zeros(cls_loss.shape).to(DEVICE))

            classification_losses.append(
                cls_loss.sum() /
                torch.clamp(num_positive_anchors.float(), min=1.0))

            # compute the loss for regression

            if positive_indices.sum() > 0:
                assigned_annotations = assigned_annotations[positive_indices, :]

                anchor_widths_pi = anchor_widths[positive_indices]
                anchor_heights_pi = anchor_heights[positive_indices]
                anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
                anchor_ctr_y_pi = anchor_ctr_y[positive_indices]

                gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0]
                gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
                gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths
                gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights

                # clip widths to 1
                gt_widths = torch.clamp(gt_widths, min=1)
                gt_heights = torch.clamp(gt_heights, min=1)

                targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
                targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
                targets_dw = torch.log(gt_widths / anchor_widths_pi)
                targets_dh = torch.log(gt_heights / anchor_heights_pi)

                targets = torch.stack(
                    (targets_dx, targets_dy, targets_dw, targets_dh))
                targets = targets.t()

                targets = targets / torch.Tensor([[0.1, 0.1, 0.2, 0.2]]).to(DEVICE)

                negative_indices = 1 - positive_indices

                regression_diff = torch.abs(targets -
                                            regression[positive_indices, :])

                regression_loss = torch.where(
                    torch.le(regression_diff, 1.0 / 9.0),
                    0.5 * 9.0 * torch.pow(regression_diff, 2),
                    regression_diff - 0.5 / 9.0)
                regression_losses.append(regression_loss.mean())
            else:
                regression_losses.append(torch.tensor(0).float().to(DEVICE))

        return torch.stack(classification_losses).mean(dim=0, keepdim=True), \
               torch.stack(regression_losses).mean(dim=0, keepdim=True)
Example No. 32
    def forward(self, classifications, regressions, anchors, annotations,
                **kwargs):
        alpha = 0.25
        gamma = 2.0
        batch_size = classifications.shape[0]
        classification_losses = []
        regression_losses = []

        anchor = anchors[0, :, :]  # assuming all image sizes are the same, which they are
        dtype = anchors.dtype

        anchor_widths = anchor[:, 2] - anchor[:, 0]
        anchor_heights = anchor[:, 3] - anchor[:, 1]
        anchor_ctr_x = anchor[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y = anchor[:, 1] + 0.5 * anchor_heights

        for j in range(batch_size):

            classification = classifications[j, :, :]
            regression = regressions[j, :, :]

            bbox_annotation = annotations[j]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]

            if bbox_annotation.shape[0] == 0:
                if torch.cuda.is_available():
                    regression_losses.append(torch.tensor(0).to(dtype).cuda())
                    classification_losses.append(torch.tensor(0).to(dtype).cuda())
                else:
                    regression_losses.append(torch.tensor(0).to(dtype))
                    classification_losses.append(torch.tensor(0).to(dtype))

                continue

            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4])

            IoU_max, IoU_argmax = torch.max(IoU, dim=1)

            # compute the loss for classification
            targets = torch.ones_like(classification) * -1
            if torch.cuda.is_available():
                targets = targets.cuda()

            targets[torch.lt(IoU_max, 0.4), :] = 0

            positive_indices = torch.ge(IoU_max, 0.5)

            num_positive_anchors = positive_indices.sum()

            assigned_annotations = bbox_annotation[IoU_argmax, :]

            targets[positive_indices, :] = 0
            targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1

            alpha_factor = torch.ones_like(targets) * alpha
            if torch.cuda.is_available():
                alpha_factor = alpha_factor.cuda()

            alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor,
                                       1. - alpha_factor)
            focal_weight = torch.where(torch.eq(targets, 1.),
                                       1. - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

            bce = -(targets * torch.log(classification) +
                    (1.0 - targets) * torch.log(1.0 - classification))

            cls_loss = focal_weight * bce

            zeros = torch.zeros_like(cls_loss)
            if torch.cuda.is_available():
                zeros = zeros.cuda()
            cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros)

            classification_losses.append(
                cls_loss.sum() /
                torch.clamp(num_positive_anchors.to(dtype), min=1.0))

            if positive_indices.sum() > 0:
                assigned_annotations = assigned_annotations[positive_indices, :]

                anchor_widths_pi = anchor_widths[positive_indices]
                anchor_heights_pi = anchor_heights[positive_indices]
                anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
                anchor_ctr_y_pi = anchor_ctr_y[positive_indices]

                gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0]
                gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
                gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths
                gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights

                # efficientdet style
                gt_widths = torch.clamp(gt_widths, min=1)
                gt_heights = torch.clamp(gt_heights, min=1)

                targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
                targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
                targets_dw = torch.log(gt_widths / anchor_widths_pi)
                targets_dh = torch.log(gt_heights / anchor_heights_pi)

                targets = torch.stack(
                    (targets_dy, targets_dx, targets_dh, targets_dw))
                targets = targets.t()

                regression_diff = torch.abs(targets -
                                            regression[positive_indices, :])

                regression_loss = torch.where(
                    torch.le(regression_diff, 1.0 / 9.0),
                    0.5 * 9.0 * torch.pow(regression_diff, 2),
                    regression_diff - 0.5 / 9.0)
                regression_losses.append(regression_loss.mean())
            else:
                if torch.cuda.is_available():
                    regression_losses.append(torch.tensor(0).to(dtype).cuda())
                else:
                    regression_losses.append(torch.tensor(0).to(dtype))

        # debug
        imgs = kwargs.get('imgs', None)
        if imgs is not None:
            regressBoxes = BBoxTransform()
            clipBoxes = ClipBoxes()
            obj_list = kwargs.get('obj_list', None)
            out = postprocess(
                imgs.detach(),
                torch.stack([anchors[0]] * imgs.shape[0], 0).detach(),
                regressions.detach(), classifications.detach(), regressBoxes,
                clipBoxes, 0.5, 0.5)
            imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()
            imgs = ((imgs * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) *
                    255).astype(np.uint8)
            imgs = [cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in imgs]
            display(out, imgs, obj_list, imshow=False, imwrite=True)

        return torch.stack(classification_losses).mean(dim=0, keepdim=True), \
               torch.stack(regression_losses).mean(dim=0, keepdim=True)
Example No. 33
    def test_without_empty_list(self):
        for device in self.devices:
            s = '''
                0 1 0 0
                0 1 1 0
                1 2 -1 0
                2
            '''
            scores = torch.tensor([1, 2, 3],
                                  dtype=torch.float32,
                                  device=device,
                                  requires_grad=True)
            scores_copy = scores.detach().clone().requires_grad_(True)
            src = k2.Fsa.from_str(s).to(device)
            src.scores = scores
            # see https://git.io/Jufpl
            src.attr1 = "hello"
            src.attr2 = "k2"
            float_attr = torch.tensor([0.1, 0.2, 0.3],
                                      dtype=torch.float32,
                                      requires_grad=True,
                                      device=device)

            src.float_attr = float_attr.detach().clone().requires_grad_(True)
            src.int_attr = torch.tensor([1, 2, 3],
                                        dtype=torch.int32,
                                        device=device)
            src.ragged_attr = k2.RaggedTensor([[10, 20], [30, 40, 50],
                                               [60, 70]]).to(device)

            ragged_arc, arc_map = _k2.remove_epsilon(src.arcs, src.properties)
            # see https://git.io/Jufpe
            dest = k2.utils.fsa_from_unary_function_ragged(
                src, ragged_arc, arc_map)
            assert dest.attr1 == src.attr1
            assert dest.attr2 == src.attr2

            expected_arc_map = k2.RaggedTensor([[1], [0, 2], [2]]).to(device)
            assert arc_map == expected_arc_map

            expected_int_attr = k2.RaggedTensor([[2], [1, 3], [3]]).to(device)
            assert dest.int_attr == expected_int_attr

            expected_ragged_attr = k2.RaggedTensor([[30, 40, 50],
                                                    [10, 20, 60, 70],
                                                    [60, 70]]).to(device)
            assert dest.ragged_attr == expected_ragged_attr

            expected_float_attr = torch.empty_like(dest.float_attr)
            expected_float_attr[0] = float_attr[1]
            expected_float_attr[1] = float_attr[0] + float_attr[2]
            expected_float_attr[2] = float_attr[2]

            assert torch.all(torch.eq(dest.float_attr, expected_float_attr))

            expected_scores = torch.empty_like(dest.scores)
            expected_scores[0] = scores_copy[1]
            expected_scores[1] = scores_copy[0] + scores_copy[2]
            expected_scores[2] = scores_copy[2]

            assert torch.all(torch.eq(dest.scores, expected_scores))

            scale = torch.tensor([10, 20, 30]).to(float_attr)

            (dest.float_attr * scale).sum().backward()
            (expected_float_attr * scale).sum().backward()
            assert torch.all(torch.eq(src.float_attr.grad, float_attr.grad))

            (dest.scores * scale).sum().backward()
            (expected_scores * scale).sum().backward()
            assert torch.all(torch.eq(scores.grad, scores_copy.grad))
Example No. 34
    def __call__(self, inputs, labels, train_iter, epoch):
        debug = False
        visualize_training = False
        tmp_show_epoches_list = []

        # if show_epochs_list is empty, all epochs should be plotted. Therefore, add the current epoch to the list
        if not self.show_epochs_list:
            tmp_show_epoches_list.append(epoch)
        else:
            tmp_show_epoches_list = self.show_epochs_list

        # check if epoch should be visualized
        if epoch in tmp_show_epoches_list:
            # print the model's parameter dimensions etc in the first iter
            if (train_iter == 0 and epoch == 0):
                debug = True
            # visualize training on the last iteration in that epoch
            elif (train_iter == 1 and epoch == 0) or (train_iter == self.max_train_iters):
                visualize_training = True

        # for nitorch models which have a 'debug' and 'visualize_training' switch in the
        # forward() method

        if (isinstance(self.model, nn.DataParallel)):
            self.model.module.set_debug(debug)
        else:
            self.model.set_debug(debug)

        outputs, encoder_out = self.model(inputs)

        if (visualize_training):
            # check if result should be plotted in PDF
            if self.plot_pdf_path != "":
                pp = PdfPages(
                    os.path.join(
                        self.plot_pdf_path,
                        "training_epoch_" + str(epoch) + "_visualization.pdf"))
            else:
                pp = None

            # show only the first image in the batch
            if pp is None:
                # input image
                show_brain(inputs[0].squeeze().cpu().detach().numpy(),
                           draw_cross=False,
                           cmap=self.cmap)
                plt.suptitle("Input image")
                plt.show()
                if (not torch.all(torch.eq(inputs[0], labels[0]))):
                    show_brain(labels[0].squeeze().cpu().detach().numpy(),
                               draw_cross=False,
                               cmap=self.cmap)
                    plt.suptitle("Expected reconstruction")
                    plt.show()
                # reconstructed image
                show_brain(outputs[0].squeeze().cpu().detach().numpy(),
                           draw_cross=False,
                           cmap=self.cmap)
                plt.suptitle("Reconstructed Image")
                plt.show()
                # statistics
                print(
                    "\nStatistics of expected reconstruction:\n(min, max)=({:.4f}, {:.4f})\nmean={:.4f}\nstd={:.4f}"
                    .format(labels[0].min(), labels[0].max(), labels[0].mean(),
                            labels[0].std()))
                print(
                    "\nStatistics of Reconstructed image:\n(min, max)=({:.4f}, {:.4f})\nmean={:.4f}\nstd={:.4f}"
                    .format(outputs[0].min(), outputs[0].max(),
                            outputs[0].mean(), outputs[0].std()))
                # feature maps
                visualize_feature_maps(encoder_out[0])
                plt.suptitle("Encoder output")
                plt.show()
            else:
                # input image
                fig = show_brain(inputs[0].squeeze().cpu().detach().numpy(),
                                 draw_cross=False,
                                 return_fig=True,
                                 cmap=self.cmap)
                plt.suptitle("Input image")
                pp.savefig(fig)
                plt.close(fig)
                if (not torch.all(torch.eq(inputs[0], labels[0]))):
                    fig = show_brain(
                        labels[0].squeeze().cpu().detach().numpy(),
                        draw_cross=False,
                        cmap=self.cmap)
                    plt.suptitle("Expected reconstruction")
                    pp.savefig(fig)
                    plt.close(fig)
                # reconstructed image
                fig = show_brain(outputs[0].squeeze().cpu().detach().numpy(),
                                 draw_cross=False,
                                 return_fig=True,
                                 cmap=self.cmap)
                plt.suptitle("Reconstructed Image")
                pp.savefig(fig)
                plt.close(fig)
                # feature maps
                if self.plotFeatures:
                    fig = visualize_feature_maps(encoder_out[0],
                                                 return_fig=True)
                    plt.suptitle("Encoder output")
                    pp.savefig(fig)
                    plt.close(fig)

            # close the PDF
            if pp is not None:
                pp.close()

        if (isinstance(self.model, nn.DataParallel)):
            self.model.module.set_debug(False)
        else:
            self.model.set_debug(False)

        return outputs
Example No. 35
    def Actor_Critic(self, V_es, taken_actions, action_log_policies):
        cfg = self.cfg
        l = cfg.gamma
        n = cfg.n_step
        if n < 0:
            print "INFO: 1 <= n step !"
            exit()

        #Building gamma matrix to calculate return for each step.
        powers = np.arange(cfg.max_s_len)
        bases = np.full((1, cfg.max_s_len), l)
        rows = np.power(bases, powers)
        inverse_rows = 1.0 / rows
        inverse_cols = inverse_rows.reshape((cfg.max_s_len, 1))
        gammaM = np.tril(np.triu(np.multiply(inverse_cols, rows)), k=n - 1)
        gM_tensor = torch.from_numpy(gammaM.T).float()
        """
            for n = 3, gamma=0.9
            gM_tensor:

            array(
                    [[1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0. ],
                    [0.9 , 1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
                    [0.81, 0.9 , 1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
                    [0.  , 0.81, 0.9 , 1.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
                    [0.  , 0.  , 0.81, 0.9 , 1.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
                    [0.  , 0.  , 0.  , 0.81, 0.9 , 1.  , 0.  , 0.  , 0.  , 0.  ],
                    [0.  , 0.  , 0.  , 0.  , 0.81, 0.9 , 1.  , 0.  , 0.  , 0.  ],
                    [0.  , 0.  , 0.  , 0.  , 0.  , 0.81, 0.9 , 1.  , 0.  , 0.  ],
                    [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.81, 0.9 , 1.  , 0.  ],
                    [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.81, 0.9 , 1.  ]
                    ])
        """

        if hasCuda:
            gM = Variable(gM_tensor.cuda(), requires_grad=False)
        else:
            gM = Variable(gM_tensor, requires_grad=False)

        tag = Variable(cfg.B['tag'].cuda()) if hasCuda else Variable(
            cfg.B['tag'])
        w_mask = Variable(cfg.B['w_mask'].cuda()) if hasCuda else Variable(
            cfg.B['w_mask'])

        is_true_tag = torch.eq(taken_actions, tag)
        #0/1 reward (hamming loss) for each prediction.
        rewards = is_true_tag.float() * w_mask
        V_es = V_es * w_mask
        Returns = torch.matmul(rewards, gM)
        for i in range(cfg.max_s_len - n):
            Returns[:, i].data = Returns[:, i].data + (l ** n) * V_es[:, i + n].data

        advantages = Returns - V_es
        pos_neq = torch.ge(advantages, 0.0).float()
        signs = torch.eq(pos_neq, rewards).float()

        #Do not back propagate through Returns and V_es!
        biased_advantages = signs * advantages
        if hasCuda:
            deltas = Variable(biased_advantages.data.cuda(),
                              requires_grad=False)
        else:
            deltas = Variable(biased_advantages.data, requires_grad=False)

        rlloss = -torch.mean(
            torch.mean(action_log_policies * deltas * w_mask, dim=1), dim=0)
        vloss = self.V_loss(Returns, V_es)
        return rlloss, vloss
Example No. 36
    def test_unsqueeze(self):
        """ torch.unsqueeze の使い方を確認する.
        参考文献. https://github.com/zhouhaoyi/Informer2020/blob/main/models/attn.py
        """
        # バッチサイズ1, 系列長5, ヘッド数1, 特徴次元数4
        q = torch.tensor([[
            [[1., 2., 3., 4.]],
            [[5., 6., 7., 8.]],
            [[1., 2., 1., 2.]],
            [[3., 4., 3., 4.]],
            [[5., 6., 5., 6.]],
        ]])
        k = torch.tensor([[
            [[0.1, 0.2, 0.3, 0.4]],
            [[0.1, 0.0, 0.0, 0.0]],
            [[0.0, 0.1, 0.0, 0.0]],
            [[0.0, 0.0, 0.1, 0.0]],
            [[0.0, 0.0, 0.0, 0.1]],
        ]])
        assert k.shape == torch.Size([1, 5, 1, 4])

        # transpose the sequence-length and head axes
        # --> batch, heads, sequence length, feature dim
        q = q.transpose(2, 1)
        k = k.transpose(2, 1)
        assert k.shape == torch.Size([1, 1, 5, 4])

        B, H, L_K, E = k.shape
        _, _, L_Q, _ = q.shape
        assert k.unsqueeze(-1).shape == torch.Size([1, 1, 5, 4, 1])
        assert k.unsqueeze(-2).shape == torch.Size([1, 1, 5, 1, 4])
        assert k.unsqueeze(-3).shape == torch.Size([1, 1, 1, 5, 4])

        # insert an axis and copy so the shape becomes
        # batch, heads, seq len*, seq len, feature dim
        k_expand = k.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        assert k_expand.shape == torch.Size([1, 1, 5, 5, 4])
        assert torch.all(torch.eq(k[0][0], k_expand[0][0][0]))
        assert torch.all(torch.eq(k[0][0], k_expand[0][0][1]))
        assert torch.all(torch.eq(k[0][0], k_expand[0][0][2]))
        assert torch.all(torch.eq(k[0][0], k_expand[0][0][3]))
        assert torch.all(torch.eq(k[0][0], k_expand[0][0][4]))

        # Ordinary self-attention would have to compute 5x5 entries.
        # To save computation, thin the k side down to 3 rows.
        # The thinning differs for each row of q, so prepare 5 sets of 3 choices.
        sample_k = 3
        torch.manual_seed(0)
        index_sample = torch.randint(L_K, (L_Q, sample_k))
        assert index_sample.shape == torch.Size([5, 3])  # 5 sets of 3 choices
        k_sample = k_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
        assert k_sample.shape == torch.Size([1, 1, 5, 3, 4])

        assert q.unsqueeze(-2).shape == torch.Size([1, 1, 5, 1, 4])  # unpack q one row at a time.
        assert k_sample.transpose(-2, -1).shape == torch.Size([1, 1, 5, 4, 3])  # a different 4x3 block for each row.
        assert torch.matmul(q.unsqueeze(-2), k_sample.transpose(-2, -1)).shape == torch.Size([1, 1, 5, 1, 3])
        q_k_sample = torch.matmul(q.unsqueeze(-2), k_sample.transpose(-2, -1)).squeeze()
        assert q_k_sample.shape == torch.Size([5, 3])

        # Keep the 3 rows whose estimated cross-entropy against the uniform distribution is largest.
        n_top = 3
        M = q_k_sample.max(-1)[0] - torch.div(q_k_sample.sum(-1), L_K)
        assert M.shape == torch.Size([5])
        M_top = M.topk(n_top, sorted=False)[1]  # topk returns (values, indices); take [1] for the indices.
        assert M_top.shape == torch.Size([3])
        assert torch.all(torch.eq(M_top, torch.tensor([4, 1, 0])))

        q_reduce = q[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :]
        assert q_reduce.shape == torch.Size([1, 1, 3, 4])
        q_k = torch.matmul(q_reduce, k.transpose(-2, -1))
        assert q_k.shape == torch.Size([1, 1, 3, 5])

        # Mask self-attention to future positions.
        v = torch.ones([1, 5, 1, 4])
        v = v.transpose(2, 1)
        B, H, L_V, D = v.shape
        _mask = torch.ones(L_Q, q_k.shape[-1], dtype=torch.bool).triu(1)  # matrix that is True above the diagonal
        _mask_ex = _mask[None, None, :].expand(B, H, L_Q, q_k.shape[-1])
        assert torch.all(torch.eq(
            _mask_ex,
            torch.tensor([[
                [False, True, True, True, True],
                [False, False, True, True, True],
                [False, False, False, True, True],
                [False, False, False, False, True],
                [False, False, False, False, False],
            ]])
        ))
        indicator = _mask_ex[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :]
        mask = indicator.view(q_k.shape)
        assert torch.all(torch.eq(
            mask,
            torch.tensor([[
                [False, False, False, False, False],
                [False, False, True, True, True],
                [False, True, True, True, True],
            ]])
        ))
        q_k.masked_fill_(mask, float('-inf'))
        attn = torch.softmax(q_k, dim=-1)

        # Apply the self-attention to v.
        # For rows that were thinned out, v simply becomes the cumulative sum up to that row.
        # The thinned-out attention rows are assigned a uniform distribution over the row, but it is not applied here.
        context = v.cumsum(dim=-2)
        context[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :] \
            = torch.matmul(attn, v).type_as(context)
        attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn)
        attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :] = attn
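# Side note (not part of the original test): expand creates a broadcast view
# rather than copying data, so k_expand above shares storage with k.
import torch

k = torch.arange(6.).reshape(1, 1, 2, 3)
k_expand = k.unsqueeze(-3).expand(1, 1, 4, 2, 3)
print(k.data_ptr() == k_expand.data_ptr())  # True: same underlying memory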
Ejemplo n.º 37
def main():
    # Use a GPU if available, as it should be faster.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device: " + str(device))

    # Load the training dataset, and create a data loader to generate a batch.
    textField = PreProcessing.text_field
    labelField = data.Field(sequential=False)

    train, dev = IMDB.splits(textField,
                             labelField,
                             train="train",
                             validation="dev")

    textField.build_vocab(train, dev, vectors=GloVe(name="6B", dim=50))
    labelField.build_vocab(train, dev)

    trainLoader, testLoader = data.BucketIterator.splits(
        (train, dev),
        shuffle=True,
        batch_size=64,
        sort_key=lambda x: len(x.text),
        sort_within_batch=True)

    net = Network().to(device)
    criterion = lossFunc()
    optimiser = topti.Adam(
        net.parameters(), lr=0.0005,
        weight_decay=1e-4)  # Minimise the loss using the Adam algorithm.

    for epoch in range(10):
        running_loss = 0
        total = 0
        num_correct = 0

        for i, batch in enumerate(trainLoader):
            # Get a batch and potentially send it to GPU memory.
            inputs, length, labels = textField.vocab.vectors[batch.text[0]].to(
                device), batch.text[1].to(device), batch.label.type(
                    torch.FloatTensor).to(device)

            labels -= 1

            # PyTorch calculates gradients by accumulating contributions to them (useful for
            # RNNs).  Hence we must manually set them to zero before calculating them.
            optimiser.zero_grad()

            # Forward pass through the network.
            output = net(inputs, length)

            predicted = torch.round(torch.sigmoid(output.detach())).view(-1)
            num_correct += torch.sum(torch.eq(labels, predicted)).item()

            total += labels.size(0)

            loss = criterion(output, labels)

            # Calculate gradients.
            loss.backward()

            # Minimise the loss according to the gradient.
            optimiser.step()

            running_loss += loss.item()

            if i % 32 == 31:
                print("Epoch: %2d, Batch: %4d, Loss: %.3f" %
                      (epoch + 1, i + 1, running_loss / 32))
                #print("Train acc: %2f" % (num_correct/total))
                running_loss = 0

    num_correct = 0

    # Save the model
    torch.save(net.state_dict(), "./model.pth")
    print("Saved model")

    # Evaluate network on the test dataset.  We aren't calculating gradients, so disable autograd to speed up
    # computations and reduce memory usage.
    with torch.no_grad():
        for batch in testLoader:
            # Get a batch and potentially send it to GPU memory.
            inputs, length, labels = textField.vocab.vectors[batch.text[0]].to(
                device), batch.text[1].to(device), batch.label.type(
                    torch.FloatTensor).to(device)

            labels -= 1

            # Get predictions
            outputs = torch.sigmoid(net(inputs, length)).view(-1)
            predicted = torch.round(outputs).view(-1)

            num_correct += torch.sum(torch.eq(labels, predicted)).item()

    accuracy = 100 * num_correct / len(dev)

    print(f"Classification accuracy: {accuracy}")
Ejemplo n.º 38
    for i, j in enumerate(all_end_pos_seq):
        if j in real_end_pos:
            label[i] = 1
    batch = seq.reshape(batch.shape)
    return batch, label


total = len(test_data)
results = []
with torch.no_grad():
    for i in tqdm(range(0, total, batch_size)):
        sample = test_data[i:i + batch_size]
        context_raw = [x[0] for x in sample]
        paras = [x[1] for x in sample]
        batch, label = get_train_data(context_raw, doc_max_length_size)
        batch = torch.LongTensor(batch)
        mask_idx = torch.eq(batch, vocab_size)
        answer_logits = model([batch.cuda(), None])
        end_num = mask_idx.sum(1).data.numpy().tolist()
        answer_logits = answer_logits.cpu().data.numpy().tolist()
        start = 0
        for one_sent_end_num, para in zip(end_num, paras):
            pred = answer_logits[start:start + one_sent_end_num]
            results.append([pred, para])
            start += one_sent_end_num
    threshold = 0.5
    precision, recall, f1, macro_f1, accuracy = evaluate_comqa(results)
    print(
        'Test results:\nthreshold:{}\nprecision:{}\nrecall:{}\nf1:{}\nmacro_f1:{}\naccuracy:{}'
        .format(threshold, precision, recall, f1, macro_f1, accuracy))
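# Minimal illustration (toy values) of the sentinel bookkeeping above:
# torch.eq marks positions equal to vocab_size, and sum(1) counts how many
# such positions each row contains.
import torch

sentinel = 9
batch = torch.tensor([[1, 9, 3, 9], [9, 2, 2, 2]])
mask_idx = torch.eq(batch, sentinel)
print(mask_idx.sum(1).tolist())  # [2, 1]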
Ejemplo n.º 39
    def finetunning(self, x_spt, y_spt, x_qry, y_qry):
        querysz = x_qry.size(0)

        corrects = [0 for _ in range(self.update_step_test + 1)]

        # in order not to ruin the state of running_mean/variance and bn_weight/bias,
        # we fine-tune on a copy of the model instead of self.net
        net = deepcopy(self.net)

        # 1. run the i-th task and compute loss for k=0
        logits = net(x_spt, dropout_training=True)
        loss = F.cross_entropy(logits, y_spt)
        grad = torch.autograd.grad(loss, net.parameters())
        fast_weights = list(
            map(lambda p: p[1] - self.update_lr * p[0],
                zip(grad, net.parameters())))

        # this is the loss and accuracy before first update
        with torch.no_grad():
            # [setsz, nway]
            logits_q = net(x_qry,
                           net.parameters(),
                           bn_training=True,
                           dropout_training=False)
            # [setsz]
            #                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            pred_q = logits_q.argmax(dim=1)
            # scalar
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[0] = corrects[0] + correct

        # this is the loss and accuracy after the first update
        with torch.no_grad():
            # [setsz, nway]
            logits_q = net(x_qry,
                           fast_weights,
                           bn_training=True,
                           dropout_training=False)
            # [setsz]
            #                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            pred_q = logits_q.argmax(dim=1)
            # scalar
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[1] = corrects[1] + correct

        for k in range(1, self.update_step_test):
            # 1. run the i-th task and compute loss for k=1~K-1
            logits = net(x_spt,
                         fast_weights,
                         bn_training=True,
                         dropout_training=True)
            loss = F.cross_entropy(logits, y_spt)
            # 2. compute grad on theta_pi
            grad = torch.autograd.grad(loss, fast_weights)
            # 3. theta_pi = theta_pi - train_lr * grad
            fast_weights = list(
                map(lambda p: p[1] - self.update_lr * p[0],
                    zip(grad, fast_weights)))

            logits_q = net(x_qry,
                           fast_weights,
                           bn_training=True,
                           dropout_training=False)
            # loss_q is overwritten; only the loss from the last update step is kept.
            loss_q = F.cross_entropy(logits_q, y_qry)

            with torch.no_grad():
                #                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                pred_q = logits_q.argmax(dim=1)
                correct = torch.eq(pred_q,
                                   y_qry).sum().item()  # .item() yields a Python scalar
                corrects[k + 1] = corrects[k + 1] + correct

        del net

        accs = np.array(corrects) / querysz

        return accs
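# A minimal sketch (toy tensors, names assumed) of the functional update used
# above: torch.autograd.grad returns gradients without touching .grad, so the
# fast weights are new tensors and the original parameters stay untouched.
import torch
import torch.nn.functional as F

w = torch.randn(3, 2, requires_grad=True)           # a single "parameter"
x, y = torch.randn(5, 2), torch.randint(0, 3, (5,))
loss = F.cross_entropy(x @ w.t(), y)
(grad_w,) = torch.autograd.grad(loss, (w,))
fast_w = w - 0.01 * grad_w                           # one inner-loop step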
Ejemplo n.º 40
            for i, data in enumerate(testloader):
                input_pc, input_sn, input_label, input_node, input_node_knn_I = data
                model.set_input(input_pc, input_sn, input_label, input_node,
                                input_node_knn_I)
                model.test_model()

                batch_amount += input_label.size()[0]

                # # accumulate loss
                model.test_loss += model.loss.detach() * input_label.size()[0]

                # # accumulate accuracy
                _, predicted_idx = torch.max(model.score.data,
                                             dim=1,
                                             keepdim=False)
                correct_mask = torch.eq(predicted_idx,
                                        model.input_label).float()
                test_accuracy = torch.mean(correct_mask).cpu()
                model.test_accuracy += test_accuracy * input_label.size()[0]

            model.test_loss /= batch_amount
            model.test_accuracy /= batch_amount
            if model.test_accuracy.item() > best_accuracy:
                best_accuracy = model.test_accuracy.item()
            print('Tested network. So far best: %f' % best_accuracy)

            # save network
            if opt.classes == 10:
                saving_acc_threshold = 0.930
            else:
                saving_acc_threshold = 0.918
            if model.test_accuracy.item() > saving_acc_threshold:
Ejemplo n.º 41
    def test_sparse_multihead_attention(self):
        attn_weights = torch.randn(1, 8, 8)
        bidirectional_sparse_mask = torch.tensor(
            [[0, 0, 0, 0, 0, float('-inf'),
              float('-inf'), 0],
             [0, 0, 0, 0, 0, float('-inf'),
              float('-inf'), 0],
             [0, 0, 0, 0, 0, float('-inf'),
              float('-inf'), 0],
             [0, 0, 0, 0, 0, float('-inf'),
              float('-inf'), 0],
             [float('-inf'),
              float('-inf'),
              float('-inf'), 0, 0, 0, 0, 0],
             [float('-inf'),
              float('-inf'),
              float('-inf'), 0, 0, 0, 0, 0],
             [float('-inf'),
              float('-inf'),
              float('-inf'), 0, 0, 0, 0, 0],
             [float('-inf'),
              float('-inf'),
              float('-inf'), 0, 0, 0, 0, 0]])

        bidirectional_attention = SparseMultiheadAttention(
            16, 1, stride=4, expressivity=1, is_bidirectional=True)
        bidirectional_attention_sparse_mask = bidirectional_attention.buffered_sparse_mask(
            attn_weights, 8, 8)
        self.assertTrue(
            torch.all(
                torch.eq(bidirectional_attention_sparse_mask,
                         bidirectional_sparse_mask)))

        sparse_mask = torch.tensor([
            [
                0,
                float('-inf'),
                float('-inf'),
                float('-inf'),
                float('-inf'),
                float('-inf'),
                float('-inf'),
                float('-inf')
            ],
            [
                0, 0,
                float('-inf'),
                float('-inf'),
                float('-inf'),
                float('-inf'),
                float('-inf'),
                float('-inf')
            ],
            [
                0, 0, 0,
                float('-inf'),
                float('-inf'),
                float('-inf'),
                float('-inf'),
                float('-inf')
            ],
            [
                0, 0, 0, 0,
                float('-inf'),
                float('-inf'),
                float('-inf'),
                float('-inf')
            ],
            [0, 0, 0, 0, 0,
             float('-inf'),
             float('-inf'),
             float('-inf')],
            [
                float('-inf'),
                float('-inf'),
                float('-inf'), 0, 0, 0,
                float('-inf'),
                float('-inf')
            ],
            [
                float('-inf'),
                float('-inf'),
                float('-inf'), 0, 0, 0, 0,
                float('-inf')
            ],
            [float('-inf'),
             float('-inf'),
             float('-inf'), 0, 0, 0, 0, 0],
        ])

        attention = SparseMultiheadAttention(16,
                                             1,
                                             stride=4,
                                             expressivity=1,
                                             is_bidirectional=False)
        attention_sparse_mask = attention.buffered_sparse_mask(
            attn_weights, 8, 8)

        self.assertTrue(torch.all(torch.eq(attention_sparse_mask, sparse_mask)))
Ejemplo n.º 42
 def loss_oriproto(self, sample_inputs, n_xs, n_xq, n_class, n_channles, n_size, protos, temperature):
         xs = Variable(sample_inputs[:,:n_xs,:]) # support set
         xq = Variable(sample_inputs[:,n_xs:,:]) # query set
         protos = Variable(protos)
         
         #print('protos')
         #print(protos.size())
         
         n_class = xs.size(0)
         n_support = xs.size(1)
         n_query = xq.size(1)
         n_proto = protos.size(0)
         n_prevtask = n_proto // n_class  # one proto per class for each previous task; integer division so range() below works
         
         target_inds = torch.arange(0, n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).long()
         target_inds = Variable(target_inds, requires_grad=False)
 
         if xq.is_cuda:
             target_inds = target_inds.cuda()
             protos = protos.cuda()
             
         xs = xs.reshape(n_class*n_support, n_channles, n_size, n_size)
         xq = xq.reshape(n_class * n_query, n_channles, n_size, n_size)        
         
         x = torch.cat((xs,xq), 0)
         z = self.encoder.forward(x)        
         z_dim = z.size(-1)
 
         zq = z[n_class*n_support:]
         z_proto = z[:n_class*n_support].view(n_class, n_support, z_dim).mean(1)
         
         #z_proto: n_class, z_dim
         #zq: n_class*n_query, z_dim
         dists = euclidean_dist(zq, z_proto,temperature)
 
         #dists: n_class*n_query, n_class
         log_p_y = F.log_softmax(-dists, dim=1).view(n_class, n_query, -1)
 
         #log_p_y: n_class, n_query, n_class (normalized from 0 to 1)
         #target_inds: n_class, n_query, 1
         loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
         #print(loss_val)
         #pick the values of the ground truth index and calculate cross entropy loss
         _, y_hat = log_p_y.max(2)
         
         #y_hat: [n_class, n_query] ->  index of max value
         acc_val = torch.eq(y_hat, target_inds.squeeze()).float().mean()
         
         ## we calculate losses for each previous task
         z_proto_p = defaultdict(list)
         dists_p = defaultdict(list)
         log_p_y_p = defaultdict(list)
         y_hat_p = defaultdict(list)
         loss_val_p = defaultdict(list)
         acc_val_p = defaultdict(list)
         
         if n_prevtask > 0:
             protos_c = protos.view(-1,n_channles, n_size, n_size)
             z_protos = self.encoder.forward(protos_c)
             z_protos = z_protos.view(n_class*n_prevtask, -1, z_dim).mean(1)
         
         for t in range(n_prevtask): 
             
             z_proto_p[t] = z_protos[t*n_class:(t+1)*n_class,:]
             z_proto_p[t] = z_proto_p[t]
             #print(z_proto_p[t].size())
             dists_p[t] = euclidean_dist(zq, z_proto_p[t],temperature)
             #print(dists_p[t].size())
             log_p_y_p[t] = F.log_softmax(-dists_p[t], dim=1).view(n_class, n_query, -1)
             #print(log_p_y_p[t].size())
             loss_val_p[t] = - log_p_y_p[t].gather(2, target_inds).squeeze().view(-1).mean()
             _, y_hat_p[t] = log_p_y_p[t].max(2)
             acc_val_p[t] = torch.eq(y_hat_p[t], target_inds.squeeze()).float().mean()
             #print('loss_val_p and acc_val_p')
             #print(loss_val_p[t])
             #print(acc_val_p[t])   
         
         loss_total = loss_val
         acc_total = acc_val
         for t in range(n_prevtask):
             loss_total = loss_total + loss_val_p[t]
             acc_total = acc_total + acc_val_p[t]
             
             
         return loss_total, {
             'loss': loss_total.item()/(n_prevtask+1),
             'acc': acc_total.item()/(n_prevtask+1)
         }
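# euclidean_dist is not defined in this snippet; a plausible sketch matching
# the call sites above (the temperature scaling is an assumption):
import torch

def euclidean_dist(x, y, temperature=1.0):
    # x: [n, d], y: [m, d] -> [n, m] squared distances, scaled by temperature
    n, m = x.size(0), y.size(0)
    x = x.unsqueeze(1).expand(n, m, -1)
    y = y.unsqueeze(0).expand(n, m, -1)
    return torch.pow(x - y, 2).sum(2) / temperature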
Ejemplo n.º 43
def main(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    my_experiment = experiment(args.name, args, "./results/", args.commit)
    writer = SummaryWriter(my_experiment.path + "tensorboard")

    logger = logging.getLogger('experiment')
    logger.setLevel(logging.INFO)
    total_clases = 10

    frozen_layers = []
    for temp in range(args.rln * 2):
        frozen_layers.append("vars." + str(temp))
    logger.info("Frozen layers = %s", " ".join(frozen_layers))

    final_results_all = []
    total_clases = args.schedule
    for tot_class in total_clases:
        lr_list = [
            0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001,
            0.000003
        ]
        for aoo in range(0, args.runs):

            keep = np.random.choice(list(range(20)), tot_class, replace=False)
            #

            dataset = imgnet.MiniImagenet(args.dataset_path,
                                          mode='test',
                                          elem_per_class=30,
                                          classes=keep,
                                          seed=aoo)

            dataset_test = imgnet.MiniImagenet(args.dataset_path,
                                               mode='test',
                                               elem_per_class=30,
                                               test=args.test,
                                               classes=keep,
                                               seed=aoo)

            # Iterators used for evaluation

            iterator = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=128,
                                                   shuffle=True,
                                                   num_workers=1)
            iterator_sorted = torch.utils.data.DataLoader(dataset,
                                                          batch_size=1,
                                                          shuffle=args.iid,
                                                          num_workers=1)

            #
            print(args)

            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')

            results_mem_size = {}

            for mem_size in [args.memory]:
                max_acc = -10
                max_lr = -10
                for lr in lr_list:

                    print(lr)
                    # for lr in [0.001, 0.0003, 0.0001, 0.00003, 0.00001]:
                    maml = torch.load(args.model, map_location='cpu')

                    if args.scratch:
                        config = mf.ModelFactory.get_model("na", args.dataset)
                        maml = learner.Learner(config)
                        # maml = MetaLearingClassification(args, config).to(device).net

                    maml = maml.to(device)

                    for name, param in maml.named_parameters():
                        param.learn = True

                    for name, param in maml.named_parameters():
                        # logger.info(name)
                        if name in frozen_layers:
                            # logger.info("Freeezing name %s", str(name))
                            param.learn = False
                            # logger.info(str(param.requires_grad))
                        else:
                            if args.reset:
                                w = nn.Parameter(torch.ones_like(param))
                                # logger.info("W shape = %s", str(len(w.shape)))
                                if len(w.shape) > 1:
                                    torch.nn.init.kaiming_normal_(w)
                                else:
                                    w = nn.Parameter(torch.zeros_like(param))
                                param.data = w
                                param.learn = True

                    frozen_layers = []
                    for temp in range(args.rln * 2):
                        frozen_layers.append("vars." + str(temp))

                    torch.nn.init.kaiming_normal_(maml.parameters()[-2])
                    w = nn.Parameter(torch.zeros_like(maml.parameters()[-1]))
                    maml.parameters()[-1].data = w

                    for n, a in maml.named_parameters():
                        n = n.replace(".", "_")
                        # logger.info("Name = %s", n)
                        if n == "vars_14":
                            w = nn.Parameter(torch.ones_like(a))
                            # logger.info("W shape = %s", str(w.shape))
                            torch.nn.init.kaiming_normal_(w)
                            a.data = w
                        if n == "vars_15":
                            w = nn.Parameter(torch.zeros_like(a))
                            a.data = w

                    correct = 0

                    for img, target in iterator:
                        with torch.no_grad():
                            img = img.to(device)
                            target = target.to(device)
                            logits_q = maml(img,
                                            vars=None,
                                            bn_training=False,
                                            feature=False)
                            pred_q = (logits_q).argmax(dim=1)
                            correct += torch.eq(pred_q,
                                                target).sum().item() / len(img)

                    logger.info("Pre-epoch accuracy %s",
                                str(correct / len(iterator)))

                    filter_list = [
                        "vars.0", "vars.1", "vars.2", "vars.3", "vars.4",
                        "vars.5"
                    ]

                    logger.info("Filter list = %s", ",".join(filter_list))
                    list_of_names = list(
                        map(
                            lambda x: x[1],
                            list(
                                filter(lambda x: x[0] not in filter_list,
                                       maml.named_parameters()))))

                    list_of_params = list(
                        filter(lambda x: x.learn, maml.parameters()))
                    list_of_names = list(
                        filter(lambda x: x[1].learn, maml.named_parameters()))
                    if args.scratch or args.no_freeze:
                        print("Empty filter list")
                        list_of_params = maml.parameters()
                    #
                    for x in list_of_names:
                        logger.info("Unfrozen layer = %s", str(x[0]))
                    opt = torch.optim.Adam(list_of_params, lr=lr)
                    res_sampler = rep.ReservoirSampler(mem_size)
                    for _ in range(0, args.epoch):
                        for img, y in iterator_sorted:
                            if mem_size > 0:
                                res_sampler.update_buffer(zip(img, y))
                                res_sampler.update_observations(len(img))
                                img = img.to(device)
                                y = y.to(device)
                                img2, y2 = res_sampler.sample_buffer(8)
                                img2 = img2.to(device)
                                y2 = y2.to(device)

                                img = torch.cat([img, img2], dim=0)
                                y = torch.cat([y, y2], dim=0)
                            else:
                                img = img.to(device)
                                y = y.to(device)

                            pred = maml(img)
                            opt.zero_grad()
                            loss = F.cross_entropy(pred, y)
                            loss.backward()
                            opt.step()

                    logger.info("Result after one epoch for LR = %f", lr)
                    correct = 0
                    for img, target in iterator:
                        img = img.to(device)
                        target = target.to(device)
                        logits_q = maml(img,
                                        vars=None,
                                        bn_training=False,
                                        feature=False)

                        pred_q = (logits_q).argmax(dim=1)
                        # print("Pred=", pred_q)
                        # print("Target=", target)
                        correct += torch.eq(pred_q,
                                            target).sum().item() / len(img)

                    logger.info(str(correct / len(iterator)))
                    if (correct / len(iterator) > max_acc):
                        max_acc = correct / len(iterator)
                        max_lr = lr

                lr_list = [max_lr]
                results_mem_size[mem_size] = (max_acc, max_lr)
                logger.info("Final Max Result = %s", str(max_acc))
                writer.add_scalar('/finetune/best_' + str(aoo), max_acc,
                                  tot_class)
            final_results_all.append((tot_class, results_mem_size))
            print("A=  ", results_mem_size)
            logger.info("Final results = %s", str(results_mem_size))

            my_experiment.results["Final Results"] = final_results_all
            my_experiment.store_json()
            print("FINAL RESULTS = ", final_results_all)
    writer.close()
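# rep.ReservoirSampler above is project-specific; for reference, classic
# reservoir sampling (which it presumably approximates) keeps a uniform
# random sample over a stream with a fixed-size buffer:
import random

class Reservoir:
    def __init__(self, capacity):
        self.capacity, self.buffer, self.seen = capacity, [], 0

    def update(self, item):
        self.seen += 1
        if len(self.buffer) < self.capacity:
            self.buffer.append(item)
        else:
            j = random.randrange(self.seen)  # replace with prob capacity/seen
            if j < self.capacity:
                self.buffer[j] = item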
Ejemplo n.º 44
 def forward(self, embeddings, labels):
     d = self.pdist(embeddings, squared=False)
     pos = torch.eq(*[labels.unsqueeze(dim).expand_as(d) for dim in [0, 1]]).type_as(d)
     neg_i = torch.mul((self.params_dict['margin'] - d).exp(), 1 - pos).sum(1).expand_as(d).clamp(min=1e-12)
     return torch.sum(F.relu(pos.triu(1) * ((neg_i + neg_i.t()).log() + d)).pow(2)) / (pos.sum() - len(d) + 1e-8)
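# Minimal illustration of the torch.eq broadcasting trick above: comparing a
# labels row-vector against a column-vector yields the same-label pair mask.
import torch

labels = torch.tensor([0, 1, 0])
pos = torch.eq(labels.unsqueeze(0).expand(3, 3),
               labels.unsqueeze(1).expand(3, 3))
print(pos.int())  # 1 where labels[i] == labels[j]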
Ejemplo n.º 45
    def forward(self, x_spt, y_spt, x_qry, y_qry):
        task_num, setsz, w, h = x_spt.size()
        querysz = x_qry.size(1)

        losses_q = [0 for _ in range(self.update_step + 1)]  # losses_q[i] is the loss on step i
        corrects = [0 for _ in range(self.update_step + 1)]

        for i in range(task_num):

            # 1. run the i-th task and compute loss for k=0
            logits = self.net(x_spt[i],
                              self.net.parameters(),
                              bn_training=True,
                              dropout_training=True)
            loss = F.cross_entropy(logits, y_spt[i])
            grad = torch.autograd.grad(loss, self.net.parameters())
            fast_weights = list(
                map(lambda p: p[1] - self.update_lr * p[0],
                    zip(grad, self.net.parameters())))

            with torch.no_grad():
                # [setsz, nway]
                logits_q = self.net(x_qry[i],
                                    self.net.parameters(),
                                    bn_training=True,
                                    dropout_training=False)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[0] += loss_q

                #                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                pred_q = logits_q.argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[0] = corrects[0] + correct

            # this is the loss and accuracy after the first update
            with torch.no_grad():
                # [setsz, nway]
                logits_q = self.net(x_qry[i],
                                    fast_weights,
                                    bn_training=True,
                                    dropout_training=False)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[1] += loss_q
                # [setsz]
                #                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                pred_q = logits_q.argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[1] = corrects[1] + correct

            for k in range(1, self.update_step):
                # 1. run the i-th task and compute loss for k=1~K-1
                logits = self.net(x_spt[i],
                                  fast_weights,
                                  bn_training=True,
                                  dropout_training=True)
                loss = F.cross_entropy(logits, y_spt[i])
                # 2. compute grad on theta_pi
                grad = torch.autograd.grad(loss, fast_weights)
                # 3. theta_pi = theta_pi - train_lr * grad
                fast_weights = list(
                    map(lambda p: p[1] - self.update_lr * p[0],
                        zip(grad, fast_weights)))

                logits_q = self.net(x_qry[i],
                                    fast_weights,
                                    bn_training=True,
                                    dropout_training=True)
                # loss_q is overwritten; only the loss from the last update step is kept.
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[k + 1] += loss_q

                with torch.no_grad():
                    #                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    pred_q = logits_q.argmax(dim=1)
                    correct = torch.eq(
                        pred_q, y_qry[i]).sum().item()  # .item() yields a Python scalar
                    corrects[k + 1] = corrects[k + 1] + correct

        # end of all tasks
        # sum over all losses on query set across all tasks
        loss_q = losses_q[-1] / task_num

        # optimize theta parameters
        self.meta_optim.zero_grad()
        loss_q.backward()
        # print('meta update')
        # for p in self.net.parameters()[:5]:
        # 	print(torch.norm(p).item())
        self.meta_optim.step()

        accs = np.array(corrects) / (querysz * task_num)

        return accs, losses_q
Ejemplo n.º 46
        inputs = torch.from_numpy(
            resample(inputs, int(LIBRISPEECH_SAMPLING_RATE * n_seconds / downsampling), axis=1)
        ).reshape((batchsize, 1, int(LIBRISPEECH_SAMPLING_RATE * n_seconds / downsampling)))

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels.reshape((batchsize, 1)).cuda().double())
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        running_correct_samples += torch.eq((outputs[:, 0] > 0.5).cpu(), labels.byte()).numpy().sum()
        if i % evaluate_every_n_batches == evaluate_every_n_batches - 1:  # print every 'print_every' mini-batches
            val_acc = evaluate(model, testloader, preprocessor)

            # return model to training mode
            model.train()

            print('[%d, %5d, %.1f] loss: %.3f acc: %.3f val_acc: %.3f' %
                  (epoch + 1, i + 1, time.time() - t0,
                   running_loss / evaluate_every_n_batches,
                   running_correct_samples * 1. / (evaluate_every_n_batches * batchsize),
                   val_acc))
            running_loss = 0.0
            running_correct_samples = 0

            val_acc_values.append(val_acc)
Ejemplo n.º 47
def train(weight_decay,types,epochs = 20,pretrained_path=None):
        
    name = str(weight_decay)+'+'+types+'Stacking'
    
    model = Stack(4,3)
    if pretrained_path != None:
        model.load_state_dict(torch.load(pretrained_path))
    print("training")
#    optimizer = optim.SGD(model.parameters(), lr=1e-3)
    optimizer = optim.Adam(model.parameters(),weight_decay=weight_decay)
    model = model.to(device)
    
    F1Max = 0.0
    for epoch in range(epochs):
        running_loss,acc,precision,recall,F1,linkRecall = test(model)
        if(F1 > F1Max ):
            F1Max = F1
            torch.save(model.state_dict(),"model/"+name)
            with open("model/"+name+".txt","w") as file:
                file.write('[%d] loss: %.4f , acc: %.4f , precision: %.4f, recall: %.4f,F1: %.4f,LinkRecall: %.4f Testing'
                           %(epoch+1,running_loss,acc,precision,recall,F1,linkRecall))
        number = 0
        running_loss = 0.0
        acc = 0.0
        H = 0
        S = 0
        common = 0
        for i in range(len(sent_list)):
            pred_span=[]
            for k in range(len(sent_list[i])):
                optimizer.zero_grad()
                
                n = len(sent_list[i][k])
                number += n
                
                StackTensor = sent_list[i][k]
                labels = sent_labels[i][k]
                
                StackOutput = model(StackTensor)
                loss = F.nll_loss(StackOutput,labels)
                
                
                _,pred = torch.max(StackOutput,dim=1)
                acc += torch.sum(torch.eq(pred,labels).float()).data
                
                loss.backward()
                optimizer.step()
                for indexs in convert(pred.data):
                    pred_span.append([train_span_list[i][k][indexs[0]][0],train_span_list[i][k][indexs[1]][1]])
                running_loss += n*loss.data
            S += len(pred_span)
            H += len(train_labels_span[i])
            common += metrics(pred_span,train_labels_span[i])
        #        print(pred)
        print(S,H,common)
        if(S != 0):
            precision = common/S
        else:
            precision = 0.0
        recall = common/H
        if(common==0):
            F1 = 0.0
        else:
            F1 = 2*recall*precision/float(recall+precision)
        print('[%d] loss: %.4f , acc: %.4f , precision: %.4f, recall: %.4f,F1: %.4f'
              %(epoch+1,running_loss/number,acc/number,precision,recall,F1))
Ejemplo n.º 48
def get_n_valid(x: Tensor) -> Tensor:
    # NaN is the only value not equal to itself, so torch.eq(x, x) is False
    # exactly at NaN entries; summing the float mask counts the valid entries.
    return torch.sum(torch.eq(x, x).float())
Ejemplo n.º 49
 def find_index(tensor_row, x):
     # for c in range(self.city_num):
     #     if tensor_row[c] == x:
     #         return c
     ind = torch.eq(tensor_row, x).nonzero().squeeze()
     return ind
Ejemplo n.º 50
def model_test(model, input_data, label_data, is_debug=False, line=0.1):
    total_num = len(input_data)

    correct_num = 0
    st2_num = 0
    st3_num = 0
    total_num = len(input_data)
    st1 = np.zeros(64)
    st2 = np.zeros(64)
    st3 = np.zeros(64)
    data_sca = np.zeros(64)

    zero_num = 0
    for step in range(len(input_data)):
        if is_debug:
            print(
                '\n<----------------------------------------------------------',
                step)
        x = input_data[step:(step + 1)]
        y = label_data[step:(step + 1)]
        max_y_index = 0
        # find the index of the maximum value of y
        for i in range(len(y[0])):
            if y[0][i] >= y[0][max_y_index]:
                max_y_index = i

        # if not (max_y_index >= 18 and max_y_index <= 22):  # drop the low cases
        #     total_num = total_num - 1
        #     continue

        data_sca[max_y_index] = data_sca[max_y_index] + 1
        x = np.array(x).reshape(1, SEQ_LEN, 1)
        y = np.array(y).reshape(1, SEQ_LEN, 1)
        x = torch.FloatTensor(x).cuda(0)
        y = torch.ByteTensor(y).cuda(0)
        prediction, _ = model(x, None)

        predict = torch.sigmoid(prediction) > line
        max_predict_index = 0
        mm = 0
        for i in range(1, len(predict.view(-1))):
            if predict.view(-1)[i] > mm:
                max_predict_index = i
                mm = predict.view(-1)[i]

        if max_predict_index == 0:
            zero_num = zero_num + 1

        t = label_data[step:(step + 1)]
        # t=
        t = np.array(t[0]).tolist()
        pd = predict.view(-1).data.cpu().numpy()
        pd = np.array(pd).tolist()
        result = torch.eq(y, predict)
        res = result.view(-1).data.cpu().numpy()
        res = np.array(res).tolist()
        if is_debug:
            print('target:    ', t)
            print('predict:   ', pd)
            print('difference:', res)

        # allow the prediction to deviate by up to two positions on either side
        if standard_define.is_satisfied_standard2(pd, max_y_index):
            st2_num = st2_num + 1
            st2[max_y_index] = st2[max_y_index] + 1

        if standard_define.is_satisfied_standard3(pd, max_y_index):
            st3_num = st3_num + 1
            st3[max_y_index] = st3[max_y_index] + 1

        accuracy = torch.sum(result) / torch.sum(torch.ones(y.shape))
        accuracy = accuracy.data.cpu().numpy()
        correct = torch.eq(torch.sum(~result), 0)
        if is_debug:
            print('accuracy: ', accuracy)
            print('correct: {}'.format(correct))
        if correct == 1:
            # right_index = int(max_y_index/10)
            st1[max_y_index] = st1[max_y_index] + 1
            correct_num = correct_num + 1  # standard 1: exact match
        if is_debug:
            print(
                '-------------------------------------------------------------->\n'
            )

    if is_debug:
        print('total:', (step + 1), ' | correct_num:', correct_num,
              '| complete_correct_rate:', correct_num / total_num,
              '| st2_num: ', st2_num, ' |st2_rate: ', st2_num / total_num,
              '| st3_num: ', st3_num, ' |st3_rate: ', st3_num / total_num)
        print('st1 : ', st1)
        print('data_sca : ', data_sca)

        right_distribute.distribute_cv(st1, data_sca, 36, 'st1 fully-correct prediction distribution')
        right_distribute.distribute_cv(st2, data_sca, 36, 'st2 relative prediction accuracy')
        right_distribute.distribute_cv(st3, data_sca, 36, 'st3 relative prediction accuracy')

    return correct_num / total_num, st2_num / total_num, st3_num / total_num
Ejemplo n.º 51
    def test_top_k_top_p_filtering(self):
        logits = torch.tensor(
            [
                [
                    8.2220991,  # 3rd highest value; idx. 0
                    -0.5620044,
                    5.23229752,
                    4.0386393,
                    -6.8798378,
                    -0.54785802,
                    -3.2012153,
                    2.92777176,
                    1.88171953,
                    7.35341276,  # 5th highest value; idx. 9
                    8.43207833,  # 2nd highest value; idx. 10
                    -9.85711836,
                    -5.96209236,
                    -1.13039161,
                    -7.1115294,
                    -0.8369633,
                    -5.3186408,
                    7.06427407,
                    0.81369344,
                    -0.82023817,
                    -5.9179796,
                    0.58813443,
                    -6.99778438,
                    4.71551189,
                    -0.18771637,
                    7.44020759,  # 4th highest value; idx. 25
                    9.38450987,  # 1st highest value; idx. 26
                    2.12662941,
                    -9.32562038,
                    2.35652522,
                ],  # cumulative prob of 5 highest values <= 0.6
                [
                    0.58425518,
                    4.53139238,
                    -5.57510464,
                    -6.28030699,
                    -7.19529503,
                    -4.02122551,
                    1.39337037,
                    -6.06707057,
                    1.59480517,
                    -9.643119,
                    0.03907799,
                    0.67231762,
                    -8.88206726,
                    6.27115922,  # 4th highest value; idx. 13
                    2.28520723,
                    4.82767506,
                    4.30421368,
                    8.8275313,  # 2nd highest value; idx. 17
                    5.44029958,  # 5th highest value; idx. 18
                    -4.4735794,
                    7.38579536,  # 3rd highest value; idx. 20
                    -2.91051663,
                    2.61946077,
                    -2.5674762,
                    -9.48959302,
                    -4.02922645,
                    -1.35416918,
                    9.67702323,  # 1st highest value; idx. 27
                    -5.89478553,
                    1.85370467,
                ],  # cumulative prob of 5 highest values <= 0.6
            ],
            dtype=torch.float,
            device=torch_device,
        )

        non_inf_expected_idx = torch.tensor(
            [[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
            dtype=torch.long,
            device=torch_device,
        )  # expected non filtered idx as noted above

        non_inf_expected_output = torch.tensor(
            [
                8.2221,
                7.3534,
                8.4321,
                7.4402,
                9.3845,
                6.2712,
                8.8275,
                5.4403,
                7.3858,
                9.6770,
            ],  # expected non filtered values as noted above
            dtype=torch.float,
            device=torch_device,
        )

        output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
        non_inf_output = output[output != -float("inf")].to(device=torch_device)
        non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)

        self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
        self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
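# For reference, a compact nucleus (top-p) filter in the spirit of the
# function under test; this simplified sketch omits top_k and
# min_tokens_to_keep, so it is not the transformers implementation:
import torch

def top_p_filter(logits, top_p=0.6, filter_value=-float("inf")):
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    remove = cum_probs > top_p
    remove[..., 1:] = remove[..., :-1].clone()  # keep the token that crosses top_p
    remove[..., 0] = False                      # always keep the top token
    mask = remove.scatter(-1, sorted_idx, remove)  # back to unsorted order
    return logits.masked_fill(mask, filter_value)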
Ejemplo n.º 52
gruLayer = nn.GRU(
                input_size = inputSize,
                hidden_size = hiddenSize,
                num_layers = 1,
                batch_first= True
                )

X = torch.rand(nSample, seqLen, inputSize)
h_0 = torch.rand(1, nSample, hiddenSize)

Y, h_1 = gruLayer(X, h_0)        
# Y.size()          torch.Size([100, 15, 10])  which is (batch, seq_len, hidden_size)
# h_1.size()        torch.Size([1, 100, 10])   which is (num_layers, batch, hidden_size)

# we can verify that the last output of Y equals h_1
print(torch.all(torch.eq(Y[:,-1,:].squeeze(), h_1.squeeze())))


## We can also reproduce the GRU layer with a GRU cell
## input of shape (batch, input_size): tensor containing input features
## hidden of shape (batch, hidden_size): tensor containing the initial hidden state for each element in the batch. Defaults to zero if not provided.
## output of shape (batch, hidden_size): tensor containing the next hidden state for each element in the batch
gruCell = nn.GRUCell(
                input_size = inputSize,
                hidden_size = hiddenSize
                )

# initializing the GRUCell parameters requires wrapping the copied weights in torch.nn.Parameter
gruCell.weight_ih = torch.nn.Parameter(gruLayer.state_dict()["weight_ih_l0"])
gruCell.weight_hh = torch.nn.Parameter(gruLayer.state_dict()["weight_hh_l0"])
gruCell.bias_ih = torch.nn.Parameter(gruLayer.state_dict()["bias_ih_l0"])
gruCell.bias_hh = torch.nn.Parameter(gruLayer.state_dict()["bias_hh_l0"])
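# Completing the sketch (assumed continuation of the snippet above): stepping
# the cell over time reproduces the layer's output sequence.
h = h_0.squeeze(0)                 # (batch, hidden_size)
outs = []
for t in range(seqLen):
    h = gruCell(X[:, t, :], h)
    outs.append(h)
Y_cell = torch.stack(outs, dim=1)  # (batch, seq_len, hidden_size)
print(torch.allclose(Y, Y_cell, atol=1e-6))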
Ejemplo n.º 53
 def __call__(self, output, label):
     prediction = (output > 0.5).float()
     hit_rate = torch.eq(prediction, label).float().mean()
     return hit_rate
Ejemplo n.º 54
    def train_nmt_full(self, output_prob_file, n_train_epochs):
        hparams = copy.deepcopy(self.hparams)

        hparams.train_nmt = True
        hparams.output_prob_file = output_prob_file
        hparams.n_train_epochs = n_train_epochs

        model = self.nmt_model
        optim = self.nmt_optim
        #optim = torch.optim.Adam(trainable_params)
        #step = 0
        #cur_attempt = 0
        #lr = hparams.lr

        trainable_params = [p for p in model.parameters() if p.requires_grad]
        num_params = count_params(trainable_params)
        print("Model has {0} params".format(num_params))

        print("-" * 80)
        print("start training...")
        start_time = log_start_time = time.time()
        target_words, total_loss, total_corrects = 0, 0, 0
        target_rules, target_total, target_eos = 0, 0, 0
        total_word_loss, total_rule_loss, total_eos_loss = 0, 0, 0
        model.train()
        #i = 0
        #epoch = 0
        update_batch_size = 0
        #for (x_train, x_mask, x_count, x_len, x_pos_emb_idxs, y_train, y_mask, y_count, y_len, y_pos_emb_idxs, batch_size, lan_id, eop) in data_util.next_nmt_train():
        for (x_train, x_mask, x_count, x_len, x_pos_emb_idxs, y_train, y_mask,
             y_count, y_len, y_pos_emb_idxs, batch_size, lan_id, eop, eob,
             save_grad) in self.data_loader.next_sample_nmt_train(
                 self.featurizer, self.actor):
            self.step += 1
            target_words += (y_count - batch_size)
            logits = model.forward(x_train,
                                   x_mask,
                                   x_len,
                                   x_pos_emb_idxs,
                                   y_train[:, :-1],
                                   y_mask[:, :-1],
                                   y_len,
                                   y_pos_emb_idxs, [], [],
                                   file_idx=[],
                                   step=self.step,
                                   x_rank=[])
            logits = logits.view(-1, hparams.trg_vocab_size)
            labels = y_train[:, 1:].contiguous().view(-1)

            cur_nmt_loss = torch.nn.functional.cross_entropy(
                logits,
                labels,
                ignore_index=self.hparams.pad_id,
                reduction="none")
            total_loss += cur_nmt_loss.sum().item()
            cur_nmt_loss = cur_nmt_loss.view(batch_size, -1).sum(-1).div_(
                batch_size * hparams.update_batch)

            if save_grad and not self.hparams.not_train_score:
                #save the gradients to nmt moving average
                for batch_id in range(batch_size):
                    batch_lan_id = lan_id[batch_id]
                    cur_nmt_loss[batch_id].backward(retain_graph=True)
                    optim.save_gradients(batch_lan_id)
            else:
                cur_nmt_loss = cur_nmt_loss.sum()
                cur_nmt_loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       hparams.clip_grad)

            mask = (labels == hparams.pad_id)
            _, preds = torch.max(logits, dim=1)
            cur_tr_acc = torch.eq(preds, labels).int().masked_fill_(mask,
                                                                    0).sum()

            total_corrects += cur_tr_acc.item()

            if self.step % hparams.update_batch == 0:
                optim.step()
                optim.zero_grad()
                optim.zero_prev_grad()
                update_batch_size = 0
                if self.hparams.cosine_schedule_max_step:
                    self.scheduler.step()
            # clean up GPU memory
            if self.step % hparams.clean_mem_every == 0:
                gc.collect()
            if eop:
                if (self.epoch +
                        1) % (self.hparams.agent_checkpoint_every) == 0:
                    agent_name = "actor_" + str(
                        (self.epoch + 1) //
                        self.hparams.agent_checkpoint_every) + ".pt"
                    agent_save_checkpoint(self.actor, hparams,
                                          hparams.output_dir, agent_name)
                self.epoch += 1
                if self.hparams.cosine_schedule_max_step and self.hparams.schedule_restart:
                    self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                        self.nmt_optim, self.hparams.cosine_schedule_max_step)
            #  get_grad_cos_all(model, data, crit)
            if (self.step / hparams.update_batch) % hparams.log_every == 0:
                curr_time = time.time()
                since_start = (curr_time - start_time) / 60.0
                elapsed = (curr_time - log_start_time) / 60.0
                log_string = "ep={0:<3d}".format(self.epoch)
                log_string += " steps={0:<6.2f}".format(
                    (self.step / hparams.update_batch) / 1000)
                if self.hparams.cosine_schedule_max_step:
                    log_string += " lr={0:<9.7f}".format(
                        self.scheduler.get_lr()[0])
                else:
                    log_string += " lr={0:<9.7f}".format(self.lr)
                log_string += " loss={0:<7.2f}".format(
                    cur_nmt_loss.sum().item())
                log_string += " |g|={0:<5.2f}".format(grad_norm)

                log_string += " ppl={0:<8.2f}".format(
                    np.exp(total_loss / target_words))
                log_string += " acc={0:<5.4f}".format(total_corrects /
                                                      target_words)

                log_string += " wpm(k)={0:<5.2f}".format(target_words /
                                                         (1000 * elapsed))
                log_string += " time(min)={0:<5.2f}".format(since_start)
                print(log_string)
            if hparams.eval_end_epoch:
                if eop:
                    eval_now = True
                else:
                    eval_now = False
            elif (self.step / hparams.update_batch) % hparams.eval_every == 0:
                eval_now = True
            else:
                eval_now = False
            if eval_now:
                based_on_bleu = hparams.eval_bleu and self.best_val_ppl[
                    0] is not None and self.best_val_ppl[
                        0] <= hparams.ppl_thresh
                with torch.no_grad():
                    val_ppl, val_bleu, ppl_list, bleu_list = eval(
                        model,
                        self.data_loader,
                        self.step,
                        hparams,
                        hparams,
                        eval_bleu=based_on_bleu,
                        valid_batch_size=hparams.valid_batch_size,
                        tr_logits=logits)
                for i in range(len(ppl_list)):
                    save_bleu, save_ppl = False, False
                    if based_on_bleu:
                        if self.best_val_bleu[i] is None or self.best_val_bleu[
                                i] <= bleu_list[i]:
                            save_bleu = True
                            self.best_val_bleu[i] = bleu_list[i]
                            self.cur_attempt = 0
                        else:
                            save_bleu = False
                            self.cur_attempt += 1
                    if self.best_val_ppl[i] is None or self.best_val_ppl[
                            i] >= ppl_list[i]:
                        save_ppl = True
                        self.best_val_ppl[i] = ppl_list[i]
                        self.cur_attempt = 0
                    else:
                        save_ppl = False
                        self.cur_attempt += 1
                    if save_bleu or save_ppl:
                        if save_bleu:
                            if len(ppl_list) > 1:
                                nmt_save_checkpoint([
                                    self.step, self.best_val_ppl,
                                    self.best_val_bleu, self.cur_attempt,
                                    self.lr, self.epoch
                                ],
                                                    model,
                                                    optim,
                                                    hparams,
                                                    hparams.output_dir +
                                                    "dev{}".format(i),
                                                    self.actor,
                                                    self.actor_optim,
                                                    prefix="bleu_")
                            else:
                                nmt_save_checkpoint([
                                    self.step, self.best_val_ppl,
                                    self.best_val_bleu, self.cur_attempt,
                                    self.lr, self.epoch
                                ],
                                                    model,
                                                    optim,
                                                    hparams,
                                                    hparams.output_dir,
                                                    self.actor,
                                                    self.actor_optim,
                                                    prefix="bleu_")
                        if save_ppl:
                            if len(ppl_list) > 1:
                                nmt_save_checkpoint([
                                    self.step, self.best_val_ppl,
                                    self.best_val_bleu, self.cur_attempt,
                                    self.lr, self.epoch
                                ],
                                                    model,
                                                    optim,
                                                    hparams,
                                                    hparams.output_dir +
                                                    "dev{}".format(i),
                                                    self.actor,
                                                    self.actor_optim,
                                                    prefix="ppl_")
                            else:
                                nmt_save_checkpoint([
                                    self.step, self.best_val_ppl,
                                    self.best_val_bleu, self.cur_attempt,
                                    self.lr, self.epoch
                                ],
                                                    model,
                                                    optim,
                                                    hparams,
                                                    hparams.output_dir,
                                                    self.actor,
                                                    self.actor_optim,
                                                    prefix="ppl_")
                    elif not hparams.lr_schedule and self.step >= hparams.n_warm_ups:
                        self.lr = self.lr * hparams.lr_dec
                        set_lr(optim, self.lr)
                # reset counter after eval
                log_start_time = time.time()
                target_words = total_corrects = total_loss = 0
                target_rules = target_total = target_eos = 0
                total_word_loss = total_rule_loss = total_eos_loss = 0
            if hparams.patience >= 0:
                if self.cur_attempt > hparams.patience: break
            elif hparams.n_train_epochs > 0:
                if self.epoch >= hparams.n_train_epochs: break
            else:
                if self.step > hparams.n_train_steps: break
            if eob: break
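Stripped of the checkpointing detail, the control flow above is a standard patience scheme: reset a counter on improvement, otherwise decay the learning rate and stop once the counter exceeds the patience budget. A minimal standalone sketch of that scheme (the names best_ppl, attempts, patience, and lr_dec are illustrative, not the original hparams fields):

import torch  # not needed for the sketch itself; shown for consistency

# Minimal sketch of the patience-based early-stopping scheme above.
# All names here are illustrative, not the original hparams fields.
def patience_update(val_ppl, state, patience=5, lr_dec=0.5):
    if val_ppl < state['best_ppl']:
        state['best_ppl'] = val_ppl      # improvement: checkpoint and reset
        state['attempts'] = 0
        return False
    state['attempts'] += 1               # no improvement: decay LR instead
    state['lr'] *= lr_dec
    return state['attempts'] > patience  # True means stop training

The caller saves a checkpoint whenever an improvement resets the counter, mirroring the save_bleu/save_ppl branches above.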
Ejemplo n.º 55
0
    def evaluate_batch(self, batch):
        prediction = self.predict(batch.text)
        labels = batch.label.type('torch.LongTensor')
        # torch.eq compares elementwise; summing the matches counts correct predictions
        correct = torch.sum(torch.eq(prediction, labels)).float()
        accuracy = float(correct / labels.shape[0])
        return accuracy, prediction, labels
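As a quick sanity check, the torch.eq accuracy pattern used here can be exercised with made-up tensors (values below are illustrative):

import torch

# Made-up predictions and labels to exercise the accuracy pattern above.
prediction = torch.tensor([1, 0, 2, 1])
labels = torch.tensor([1, 0, 1, 1])
correct = torch.sum(torch.eq(prediction, labels)).float()
print(float(correct / labels.shape[0]))  # 0.75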
Ejemplo n.º 56
0
    def __iter__(self):
        while self.offset < len(self):
            batch = self.data[self.offset]
            batch_size = len(batch)
            batch_dict = {}

            # doc
            doc_len = max(len(x['doc_tok']) for x in batch)  ## longest document length in the batch
            doc_id = torch.LongTensor(batch_size, doc_len).fill_(0)
            doc_tag = torch.LongTensor(batch_size, doc_len).fill_(0)
            doc_ent = torch.LongTensor(batch_size, doc_len).fill_(0)
            feature_len = len(
                eval(batch[0]['doc_fea'])[0]) if len(batch[0].get(
                    'doc_fea', [])) > 0 else 0  ## 4 (TF ratio, exact match, lowercase match, lemma match)
            doc_feature = torch.Tensor(batch_size, doc_len,
                                       feature_len).fill_(0)
            doc_cid = torch.LongTensor(
                batch_size, doc_len,
                ELMoCharacterMapper.max_word_length).fill_(0)  # elmo

            # query
            query_len = max(len(x['query_tok'])
                            for x in batch)  ## longest query length in the batch
            query_id = torch.LongTensor(batch_size, query_len).fill_(0)
            query_cid = torch.LongTensor(
                batch_size, query_len,
                ELMoCharacterMapper.max_word_length).fill_(0)  # elmo

            for i, sample in enumerate(batch):
                ## doc features: token id, POS tag, entity, TF ratio, exact match, lowercase match, lemma match
                doc_tok = sample['doc_tok']
                doc_select_len = min(len(doc_tok), doc_len)
                # if self.is_train:
                #     doc_tok = self.__random_select__(doc_tok)         ## mask
                doc_id[i, :doc_select_len] = torch.LongTensor(
                    doc_tok[:doc_select_len])
                doc_tag[i, :doc_select_len] = torch.LongTensor(
                    sample['doc_pos'][:doc_select_len])
                doc_ent[i, :doc_select_len] = torch.LongTensor(
                    sample['doc_ner'][:doc_select_len])
                # TF ratio, exact match, lowercase match, lemma match
                for j, feature in enumerate(eval(sample['doc_fea'])):
                    if j >= doc_select_len:
                        break
                    doc_feature[i, j, :] = torch.Tensor(feature)
                # elmo
                doc_ctok = sample['doc_ctok']
                for j, w in enumerate(batch_to_ids(doc_ctok)[0].tolist()):
                    if j >= doc_select_len:
                        break
                    doc_cid[i, j, :len(w)] = torch.LongTensor(w)

                ## query features: token id, elmo
                query_tok = sample['query_tok']
                query_select_len = min(len(query_tok), query_len)
                # if self.is_train:
                #     query_tok = self.__random_select__(query_tok)
                query_id[i, :query_select_len] = torch.LongTensor(
                    query_tok[:query_select_len])
                # elmo
                query_ctok = sample['query_ctok']
                for j, w in enumerate(batch_to_ids(query_ctok)[0].tolist()):
                    if j >= query_select_len:
                        break
                    query_cid[i, j, :len(w)] = torch.LongTensor(w)

            batch_dict['uids'] = [sample['uid'] for sample in batch]
            batch_dict['doc_tok'] = doc_id
            batch_dict['doc_mask'] = torch.eq(doc_id, 1)
            batch_dict['doc_pos'] = doc_tag
            batch_dict['doc_ner'] = doc_ent
            batch_dict['doc_fea'] = doc_feature
            batch_dict['doc_ctok'] = doc_cid  # elmo
            batch_dict['query_tok'] = query_id
            batch_dict['query_mask'] = torch.eq(query_id, 1)
            batch_dict['query_ctok'] = query_cid  # elmo
            if self.is_train:
                batch_dict['start'] = torch.LongTensor(
                    [sample['start'] for sample in batch])
                batch_dict['end'] = torch.LongTensor(
                    [sample['end'] for sample in batch])
                batch_dict['label'] = torch.FloatTensor(
                    [sample['label'] for sample in batch])
            else:
                batch_dict['text'] = [sample['context'] for sample in batch]
                batch_dict['span'] = [sample['span'] for sample in batch]

            self.offset += 1
            yield batch_dict
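The doc_mask/query_mask construction flags padding positions by comparing every token id against a padding index with torch.eq. A minimal sketch of the same idea (this sketch assumes pad id 0; note the batcher above compares against id 1 instead):

import torch

# True marks padding positions; pad id 0 is assumed for this sketch.
doc_id = torch.tensor([[5, 9, 2, 0, 0],
                       [7, 3, 0, 0, 0]])
doc_mask = torch.eq(doc_id, 0)
# tensor([[False, False, False,  True,  True],
#         [False, False,  True,  True,  True]])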
Ejemplo n.º 57
0
def train_epoch(device,
                logger,
                epoch,
                trainer,
                train_ds,
                val_ds,
                train_batch_size,
                val_batch_size,
                num_workers,
                save_every,
                eval_every,
                save_imgs_every,
                train_eval_indices,
                val_eval_indices,
                tb_log_every=100,
                tb_log_enc_every=500,
                n_au_steps=1,
                dbg=False):
    """Run one training epoch of the authenticator/impersonator pair."""
    # log buffers
    au_loss_buffer = []
    au_loss_on_real_buffer = []
    au_loss_on_fake_buffer = []
    au_reg_buffer = []
    au_out_on_real_buffer = []
    au_out_on_fake_buffer = []
    au_pred_on_real_buffer = []
    au_pred_on_fake_buffer = []
    im_loss_buffer = []

    trainloader = DataLoader(train_ds,
                             batch_size=train_batch_size,
                             shuffle=True,
                             num_workers=num_workers,
                             drop_last=True)
    num_iters = 50 if dbg else len(trainloader)
    iter_bar = tqdm(itertools.islice(trainloader, num_iters),
                    total=num_iters,
                    desc='Training')
    for batch_idx, data_batch in enumerate(iter_bar):
        # step
        trainer.module.do_global_step()
        trainer.module.update_learning_rate()

        # data
        real_sample = data_batch["real_sample"].to(device)
        leaked_sample = data_batch["leaked_sample"].to(device)
        si_sample = data_batch["si_sample"].to(device)
        global_step = trainer.module.global_step

        # impersonator train step
        if (global_step + 1) % n_au_steps == 0:
            im_loss, fake_sample, _ = im_train_step(
                trainer=trainer,
                leaked_sample=leaked_sample,
                si_sample=si_sample)
        else:
            im_loss, fake_sample, _ = im_eval_step(trainer=trainer,
                                                   leaked_sample=leaked_sample,
                                                   si_sample=si_sample)
        im_loss_buffer.append(im_loss.view(1))

        # authenticator train step
        au_loss, au_loss_on_real, au_loss_on_fake, au_reg, au_out_on_real, au_out_on_fake, au_pred_on_real, au_pred_on_fake, fake_sample = au_train_step(
            trainer=trainer,
            real_sample=real_sample,
            fake_sample=fake_sample,
            si_sample=si_sample)

        # log stats
        au_loss_buffer.append(au_loss.view(1))
        au_loss_on_real_buffer.append(au_loss_on_real.view(1))
        au_loss_on_fake_buffer.append(au_loss_on_fake.view(1))
        au_reg_buffer.append(au_reg.view(1))
        au_out_on_real_buffer.append(au_out_on_real.view(1))
        au_out_on_fake_buffer.append(au_out_on_fake.view(1))
        au_pred_on_real_buffer.append(au_pred_on_real.view(-1))
        au_pred_on_fake_buffer.append(au_pred_on_fake.view(-1))

        if global_step % tb_log_every == 0:
            # lr
            logger.add_scalar(category='lr',
                              k='au',
                              v=trainer.module.au_lr,
                              global_step=global_step)
            logger.add_scalar(category='lr',
                              k='im',
                              v=trainer.module.im_lr,
                              global_step=global_step)
            logger.add_scalar(category='lr',
                              k='im_lm',
                              v=trainer.module.im_noise_mapping_lr,
                              global_step=global_step)

            # losses
            logger.add_scalar(category='train_losses',
                              k='dis_loss',
                              v=torch.cat(au_loss_buffer).mean().item(),
                              global_step=global_step)
            logger.add_scalar(
                category='train_losses',
                k='dis_loss_on_real',
                v=torch.cat(au_loss_on_real_buffer).mean().item(),
                global_step=global_step)
            logger.add_scalar(
                category='train_losses',
                k='dis_loss_on_fake',
                v=torch.cat(au_loss_on_fake_buffer).mean().item(),
                global_step=global_step)
            logger.add_scalar(category='train_losses',
                              k='dis_reg',
                              v=torch.cat(au_reg_buffer).mean().item(),
                              global_step=global_step)

            logger.add_scalar(category='train_au_out',
                              k='au_out_on_real',
                              v=torch.cat(au_out_on_real_buffer).mean().item(),
                              global_step=global_step)
            logger.add_scalar(category='train_au_out',
                              k='au_out_on_fake',
                              v=torch.cat(au_out_on_fake_buffer).mean().item(),
                              global_step=global_step)

            # acc
            au_acc_on_real = torch.cat(au_pred_on_real_buffer).to(
                torch.float).mean()
            au_acc_on_fake = torch.eq(torch.cat(au_pred_on_fake_buffer),
                                      0).to(torch.float).mean()
            au_acc = 0.5 * (au_acc_on_real + au_acc_on_fake)

            logger.add_scalar(category='train_accuracy',
                              k='dis_acc',
                              v=au_acc.item(),
                              global_step=global_step)
            logger.add_scalar(category='train_accuracy',
                              k='dis_acc_on_real',
                              v=au_acc_on_real.item(),
                              global_step=global_step)
            logger.add_scalar(category='train_accuracy',
                              k='dis_acc_on_fake',
                              v=au_acc_on_fake.item(),
                              global_step=global_step)

            # im
            if len(im_loss_buffer) > 0:
                logger.add_scalar(category='train_losses',
                                  k='gen_loss',
                                  v=torch.cat(im_loss_buffer).mean().item(),
                                  global_step=global_step)

            # clear buffers
            au_loss_buffer = []
            au_loss_on_real_buffer = []
            au_loss_on_fake_buffer = []
            au_reg_buffer = []
            au_out_on_real_buffer = []
            au_out_on_fake_buffer = []
            au_pred_on_real_buffer = []
            au_pred_on_fake_buffer = []
            im_loss_buffer = []

        # log encodings
        if global_step % tb_log_enc_every == 0:
            with torch.no_grad():
                # authenticator
                au_real_src = trainer.module.authenticator.src_encode_sample(
                    real_sample)
                au_si_src = trainer.module.authenticator.src_encode_sample(
                    si_sample)
                au_fake_src = trainer.module.authenticator.src_encode_sample(
                    fake_sample)

                au_real_env = trainer.module.authenticator.env_encode_sample(
                    real_sample)
                au_si_env = trainer.module.authenticator.env_encode_sample(
                    si_sample)
                au_fake_env = trainer.module.authenticator.env_encode_sample(
                    fake_sample)

                # mean
                logger.add_scalar(
                    category='train-au_src_mean',
                    k='abs[real-si]',
                    v=torch.abs(au_real_src.mean(1) -
                                au_si_src.mean(1)).mean().item(),
                    global_step=global_step)
                logger.add_scalar(
                    category='train-au_src_mean',
                    k='abs[fake-si]',
                    v=torch.abs(au_fake_src.mean(1) -
                                au_si_src.mean(1)).mean().item(),
                    global_step=global_step)

                logger.add_scalar(
                    category='train-au_env_mean',
                    k='abs[real-si]',
                    v=torch.abs(au_real_env.mean(1) -
                                au_si_env.mean(1)).mean().item(),
                    global_step=global_step)
                logger.add_scalar(
                    category='train-au_env_mean',
                    k='abs[fake-si]',
                    v=torch.abs(au_fake_env.mean(1) -
                                au_si_env.mean(1)).mean().item(),
                    global_step=global_step)

                # std
                au_real_src_std = mb.custom_std(au_real_src).mean().item()
                au_si_src_std = mb.custom_std(au_si_src).mean().item()
                au_fake_src_std = mb.custom_std(au_fake_src).mean().item()
                logger.add_scalar(category='train-au_src_std',
                                  k='real',
                                  v=au_real_src_std,
                                  global_step=global_step)
                logger.add_scalar(category='train-au_src_std',
                                  k='si',
                                  v=au_si_src_std,
                                  global_step=global_step)
                logger.add_scalar(category='train-au_src_std',
                                  k='fake',
                                  v=au_fake_src_std,
                                  global_step=global_step)

                au_real_env_std = mb.custom_std(au_real_env).mean().item()
                au_si_env_std = mb.custom_std(au_si_env).mean().item()
                au_fake_env_std = mb.custom_std(au_fake_env).mean().item()
                logger.add_scalar(category='train-au_env_std',
                                  k='real',
                                  v=au_real_env_std,
                                  global_step=global_step)
                logger.add_scalar(category='train-au_env_std',
                                  k='si',
                                  v=au_si_env_std,
                                  global_step=global_step)
                logger.add_scalar(category='train-au_env_std',
                                  k='fake',
                                  v=au_fake_env_std,
                                  global_step=global_step)

        if (global_step % save_every == 0):
            trainer.module.save(epoch=epoch)

        if global_step % save_imgs_every == 0:
            sample_and_save_imgs(device=device,
                                 logger=logger,
                                 trainer=trainer,
                                 ds=train_ds,
                                 ds_prefix='train',
                                 indices=train_eval_indices,
                                 dbg=dbg)
            sample_and_save_imgs(device=device,
                                 logger=logger,
                                 trainer=trainer,
                                 ds=val_ds,
                                 ds_prefix='val',
                                 indices=val_eval_indices,
                                 dbg=dbg)

        if global_step % eval_every == 0:
            eval_step(device=device,
                      trainer=trainer,
                      ds=val_ds,
                      logger=logger,
                      batch_size=val_batch_size)
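The accuracy bookkeeping in the tb_log_every branch follows the usual discriminator convention: predictions on real samples should be 1 and predictions on fake samples should be 0, hence the torch.eq(·, 0) on the fake side. A small self-contained illustration (prediction values are made up):

import torch

# Made-up authenticator predictions: real should be 1, fake should be 0.
pred_on_real = torch.tensor([1, 1, 0, 1])
pred_on_fake = torch.tensor([0, 1, 0, 0])
acc_on_real = pred_on_real.to(torch.float).mean()               # 0.75
acc_on_fake = torch.eq(pred_on_fake, 0).to(torch.float).mean()  # 0.75
acc = 0.5 * (acc_on_real + acc_on_fake)                         # 0.75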
Ejemplo n.º 58
0
def run_epoch(
        evaluation,
        hierarchical, sigma, alpha,                           # Model parameters
        model, dataloader, optimizer,                         # Training objects
        iteration, epoch_loss, epoch_accuracy, sub_accuracy,  # Intermediate results
        lambda_dict):
    """
    Runs through the entire dataset once; updates the model only if evaluation=False.
    Args:
        Input:
            evaluation : Boolean; if set to True, the model will not be updated
                         and the data will not be warped before the forward pass
            hierarchical : Boolean; whether the model is hierarchical
            sigma, alpha : parameters for elastic image warping during training
            model : a PrototypeModel or HierarchyModel
            dataloader : a DataLoader object; this function will go through all its data
            optimizer : optimizer object
            iteration, epoch_loss, epoch_accuracy, sub_accuracy : intermediate results
            lambda_dict : all lambdas used for calculating the loss function
        Output:
            The output consists of 4 scalars, representing:
            iteration : number of batches seen so far
            epoch_loss, epoch_accuracy : loss and accuracy over this epoch
            sub_accuracy : equal to 0 if hierarchical=False
    """

    for _, (images, labels) in enumerate(dataloader):
        # Up the iteration by 1
        iteration += 1

        # Transform images, then port to GPU
        if not evaluation:
            images = batch_elastic_transform(images, sigma, alpha, 28, 28)
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        oh_labels = one_hot(labels)

        # Forward pass
        if hierarchical:
            _, decoding, (sub_c, sup_c, r_1, r_2, r_3, r_4) = model.forward(images)
        else:
            _, decoding, (r_1, r_2, c) = model.forward(images)

        # Calculate loss: Crossentropy + Reconstruction + R1 + R2
        # Crossentropy h(f(x)) and y
        ce = nn.CrossEntropyLoss()
        # reconstruction error g(f(x)) and x
        subtr = (decoding - images).view(-1, 28*28)
        re = torch.mean(torch.norm(subtr, dim=1))

        # Paper does 20 * ce and lambda_n = 1 for each regularization term
        # Calculate loss and get accuracy etc.

        if hierarchical:
            sup_ce = ce(sup_c, torch.argmax(oh_labels, dim=1))
            # Extra cross entropy for second linear layer
            sub_ce = ce(sub_c, torch.argmax(oh_labels, dim=1))

            # Actual loss
            loss = lambda_dict['lambda_class_sup'] * sup_ce + \
                lambda_dict['lambda_ae'] * re + \
                lambda_dict['lambda_class_sub'] * sub_ce + \
                lambda_dict['lambda_r1'] * r_1 + \
                lambda_dict['lambda_r2'] * r_2 + \
                lambda_dict['lambda_r3'] * r_3 + \
                lambda_dict['lambda_r4'] * r_4

            # For super prototype cross entropy term
            epoch_loss += loss.item()
            preds = torch.argmax(sup_c, dim=1)
            corr = torch.sum(torch.eq(preds, labels))
            size = labels.shape[0]
            epoch_accuracy += corr.item()/size

            # Also for sub prototype cross entropy term
            subpreds = torch.argmax(sub_c, dim=1)
            subcorr = torch.sum(torch.eq(subpreds, labels))
            sub_accuracy += subcorr.item()/size
        else:
            crossentropy_loss = ce(c, torch.argmax(oh_labels, dim=1))
            loss = lambda_dict['lambda_class'] * crossentropy_loss + \
            lambda_dict['lambda_ae'] * re + \
            lambda_dict['lambda_r1'] * r_1 +  \
            lambda_dict['lambda_r2'] * r_2

            # For prototype cross entropy term
            epoch_loss += loss.item()
            preds = torch.argmax(c, dim=1)
            corr = torch.sum(torch.eq(preds, labels))
            size = labels.shape[0]
            epoch_accuracy += corr.item()/size

        # Do backward pass and ADAM steps
        if not evaluation:
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    return iteration, epoch_loss, epoch_accuracy, sub_accuracy
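Since every term in the hierarchical loss is just a lambda-weighted summand, the long expression above can be collapsed into a sum over a mapping from lambda keys to loss terms. A sketch, assuming the keys pair off with the terms exactly as written above (illustrative only, not the original code):

# Compact equivalent of the weighted hierarchical loss above (illustrative).
terms = {
    'lambda_class_sup': sup_ce, 'lambda_ae': re,
    'lambda_class_sub': sub_ce, 'lambda_r1': r_1,
    'lambda_r2': r_2, 'lambda_r3': r_3, 'lambda_r4': r_4,
}
loss = sum(lambda_dict[k] * term for k, term in terms.items())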
Ejemplo n.º 59
0
def eval_step(device, trainer, ds, logger, batch_size):
    """Evaluate the authenticator/impersonator pair on the given dataset."""
    # stats list
    au_loss_list = []
    au_loss_on_real_list = []
    au_loss_on_fake_list = []
    au_out_on_real_list = []
    au_out_on_fake_list = []
    au_acc_list = []
    au_acc_on_real_list = []
    au_acc_on_fake_list = []
    im_loss_list = []
    global_step = trainer.module.get_global_step()

    dataloader = DataLoader(ds,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0,
                            drop_last=True)
    num_iters = len(dataloader)
    iter_bar = tqdm(itertools.islice(dataloader, num_iters),
                    total=num_iters,
                    desc='Eval')
    for batch_idx, data_batch in enumerate(iter_bar):
        # data
        real_sample = data_batch["real_sample"].to(device)
        leaked_sample = data_batch["leaked_sample"].to(device)
        si_sample = data_batch["si_sample"].to(device)

        # impersonator train step
        im_loss, fake_sample, _ = im_eval_step(trainer=trainer,
                                               leaked_sample=leaked_sample,
                                               si_sample=si_sample)

        # authenticator train step
        au_loss, au_loss_on_real, au_loss_on_fake, au_reg, au_out_on_real, au_out_on_fake, au_pred_on_real, au_pred_on_fake, fake_sample = au_eval_step(
            trainer=trainer,
            real_sample=real_sample,
            fake_sample=fake_sample,
            si_sample=si_sample)
        # acc
        au_acc_on_real = au_pred_on_real.to(torch.float).mean()
        au_acc_on_fake = torch.eq(au_pred_on_fake, 0).to(torch.float).mean()
        au_acc = 0.5 * (au_acc_on_real + au_acc_on_fake)

        au_loss_list.append(au_loss.view(1))
        au_loss_on_real_list.append(au_loss_on_real.view(1))
        au_loss_on_fake_list.append(au_loss_on_fake.view(1))
        au_out_on_real_list.append(au_out_on_real.view(1))
        au_out_on_fake_list.append(au_out_on_fake.view(1))
        au_acc_list.append(au_acc.view(1))
        au_acc_on_real_list.append(au_acc_on_real.view(1))
        au_acc_on_fake_list.append(au_acc_on_fake.view(1))
        im_loss_list.append(im_loss.view(1))

    # log
    logger.add_scalar(category='eval losses',
                      k='dis loss',
                      v=torch.cat(au_loss_list).mean().item(),
                      global_step=global_step)
    logger.add_scalar(category='eval losses',
                      k='dis loss on real',
                      v=torch.cat(au_loss_on_real_list).mean().item(),
                      global_step=global_step)
    logger.add_scalar(category='eval losses',
                      k='dis loss on fake',
                      v=torch.cat(au_loss_on_fake_list).mean().item(),
                      global_step=global_step)
    logger.add_scalar(category='eval au out',
                      k='au out on real',
                      v=torch.cat(au_out_on_real_list).mean().item(),
                      global_step=global_step)
    logger.add_scalar(category='eval au out',
                      k='au out on fake',
                      v=torch.cat(au_out_on_fake_list).mean().item(),
                      global_step=global_step)
    logger.add_scalar(category='eval accuracy',
                      k='dis acc',
                      v=torch.cat(au_acc_list).mean().item(),
                      global_step=global_step)
    logger.add_scalar(category='eval accuracy',
                      k='dis acc on real',
                      v=torch.cat(au_acc_on_real_list).mean().item(),
                      global_step=global_step)
    logger.add_scalar(category='eval accuracy',
                      k='dis acc on fake',
                      v=torch.cat(au_acc_on_fake_list).mean().item(),
                      global_step=global_step)
    logger.add_scalar(category='eval losses',
                      k='gen loss',
                      v=torch.cat(im_loss_list).mean().item(),
                      global_step=global_step)
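Both this function and train_epoch accumulate per-batch scalars as one-element tensors and reduce them with torch.cat(...).mean().item() at log time; a minimal sketch of that buffering pattern (values made up):

import torch

# Per-batch scalars stored as 1-element tensors, averaged at log time.
loss_list = [torch.tensor(0.9).view(1), torch.tensor(0.7).view(1)]
mean_loss = torch.cat(loss_list).mean().item()  # 0.8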
Ejemplo n.º 60
0
def cal_acc(outputs, labels):
    # torch.max over dim 1 returns (values, indices); [1] selects the argmax class
    pred_labels = torch.max(outputs, 1)[1]
    equality = torch.eq(pred_labels, labels).float()
    accuracy = torch.mean(equality)
    return accuracy
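A quick usage check of cal_acc with made-up logits:

import torch

# Row-wise argmax of the made-up logits is [1, 0, 1]; labels are [1, 0, 0].
outputs = torch.tensor([[0.1, 0.9],
                        [0.8, 0.2],
                        [0.3, 0.7]])
labels = torch.tensor([1, 0, 0])
print(cal_acc(outputs, labels))  # tensor(0.6667)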