Example #1
0
    def forward(self, output, target):
        """Label-smoothed KL-divergence loss.

        output (FloatTensor): batch_size x n_classes; assumed to already be
            log-probabilities, as required by F.kl_div -- confirm at call site.
        target (LongTensor): batch_size
        """
        # Start from the smoothing prior, replicated once per example.
        smoothed = self.one_hot.repeat(target.size(0), 1)
        # Put the confidence mass on each gold class.
        smoothed.scatter_(1, target.unsqueeze(1), self.confidence)
        # Rows whose target equals the ignore index contribute nothing.
        ignored_rows = (target == self.ignore_index).unsqueeze(1)
        smoothed.masked_fill_(ignored_rows, 0)

        return F.kl_div(output, smoothed, reduction='sum')
Example #2
0
def softmax_kl_loss(input_logits, target_logits):
    """Takes softmax on both sides and returns KL divergence

    Note:
    - Returns the sum over all examples. Divide by the batch size afterwards
      if you want the mean.
    - Sends gradients to inputs but not the targets.
    """
    assert input_logits.size() == target_logits.size()
    input_log_softmax = F.log_softmax(input_logits, dim=1)
    target_softmax = F.softmax(target_logits, dim=1)
    # `size_average=False` is deprecated; `reduction='sum'` is the exact
    # modern equivalent (sum over all elements, no averaging).
    return F.kl_div(input_log_softmax, target_softmax, reduction='sum')
Example #3
0
    def forward(self, output, target):
        """
        output (FloatTensor): batch_size x tgt_vocab_size
        target (LongTensor): batch_size
        """
        # Smoothing prior replicated per example: (batch_size, tgt_vocab_size).
        smooth = self.one_hot.repeat(target.size(0), 1)

        # Gold-standard word positions receive the confidence mass.
        smooth.scatter_(1, target.unsqueeze(-1), self.confidence)

        # Padded entries contribute nothing to the loss.
        pad_rows = (target == self.padding_idx).unsqueeze(-1)
        smooth.masked_fill_(pad_rows, 0.)

        # Per-example KL summed over the vocabulary axis. The leading minus
        # sign is preserved exactly from the original implementation.
        return -F.kl_div(output, smooth, reduction='none').sum(-1)
Example #4
0
    def train(self, epoch):
        """Train one epoch of incremental learning with exemplar replay.

        epoch: current epoch index (used only for the progress printout).

        NOTE(review): relies on `self.incremental_loader`, `self.train_iterator`,
        `self.model`, `self.model_fixed`, `self.loss`, `self.optimizer`,
        `self.distill` and `self.args` being set up by the enclosing trainer;
        none of that setup is visible from this block.
        """
        self.model.train()
        print("Epochs %d"%epoch)
        # Distillation temperature.
        T=2
        
        # Class-index boundaries: [start, mid) are old-task classes,
        # [mid, end) are the classes of the current task.
        tasknum = self.incremental_loader.t
        end = self.incremental_loader.end
        mid = end-self.args.step_size
        start = 0
        
        # Replay loader over stored exemplars of previous tasks.
        exemplar_dataset_loaders = trainer.ExemplarLoader(self.incremental_loader)
        exemplar_iterator = torch.utils.data.DataLoader(exemplar_dataset_loaders,
                                                        batch_size=self.args.replay_batch_size, 
                                                        shuffle=True, drop_last=True, **self.kwargs)
        
        # After the first task, iterate current data and replay in lockstep.
        if tasknum > 0:
            iterator = zip(self.train_iterator, exemplar_iterator)
        else:
            iterator = self.train_iterator
        
        
        for samples in tqdm(iterator):
            if tasknum > 0:
                curr, prev = samples
                
                data, target = curr
                if self.args.ablation == 'None':
                    # Remap current-task labels into [0, end-mid).
                    target = target%(end-mid)
                batch_size = data.shape[0]
                data_r, target_r = prev
                replay_size = data_r.shape[0]
                data, data_r = data.cuda(), data_r.cuda()
                # Current batch first, then the replayed exemplars.
                data = torch.cat((data,data_r))
                target, target_r = target.cuda(), target_r.cuda()
                
            else:
                data, target = samples
                data = data.cuda()
                target = target.cuda()
                    
                batch_size = data.shape[0]
            
            output = self.model(data)
            loss_KD = 0
            
            if self.args.ablation == 'naive':
                # NOTE(review): `target_r` and `replay_size` only exist when
                # tasknum > 0; on the first task this branch would raise
                # NameError. Confirm 'naive' is never used on task 0.
                target = torch.cat((target, target_r))
                loss_CE = self.loss(output[:,:end],target) / (batch_size + replay_size)
            
            else:
            
                loss_CE_curr = 0
                loss_CE_prev = 0

                # CE for the current batch over the current task's logits only.
                curr = output[:batch_size,mid:end]
                loss_CE_curr = self.loss(curr, target)

                if tasknum > 0:
                    # CE for replayed exemplars over the old-class logits.
                    prev = output[batch_size:batch_size+replay_size,start:mid]
                    loss_CE_prev = self.loss(prev, target_r)
                    loss_CE = (loss_CE_curr + loss_CE_prev) / (batch_size + replay_size)
                    
                    # loss_KD: distill old-class behavior from the frozen model.
                    score = self.model_fixed(data)[:,:mid].data
                    if self.distill == 'global':
                        # One KL over all old classes; the T**2 factor keeps
                        # gradient magnitude comparable across temperatures.
                        soft_target = F.softmax(score / T, dim=1)
                        output_log = F.log_softmax(output[:,:mid] / T, dim=1)
                        loss_KD = F.kl_div(output_log, soft_target, reduction='batchmean') * (T ** 2)
                        
                    elif self.distill == 'local':
                        # One KL per past task, restricted to that task's slice.
                        loss_KD = torch.zeros(tasknum).cuda()
                        for t in range(tasknum):

                            # local distillation
                            start_KD = (t) * self.args.step_size
                            end_KD = (t+1) * self.args.step_size

                            soft_target = F.softmax(score[:,start_KD:end_KD] / T, dim=1)
                            output_log = F.log_softmax(output[:,start_KD:end_KD] / T, dim=1)
                            loss_KD[t] = F.kl_div(output_log, soft_target, reduction='batchmean') * (T**2)
                        loss_KD = loss_KD.sum()

                else:
                    loss_CE = loss_CE_curr / batch_size
                
            self.optimizer.zero_grad()
            (loss_KD + loss_CE).backward()
            self.optimizer.step()
Example #5
0
    def forward(self, data):
        """Attention-based graph pooling step.

        data: list [x, A, mask, ?, params_dict, ...] where x is (B, N, C)
        node features, A the adjacency, mask a (B, N)-shaped boolean node
        validity mask, and params_dict carries 'N_nodes' plus optional
        ground-truth attention lists ('node_attn', 'node_attn_eval').

        Returns [x, A, mask, *data[3:]] with attention-weighted features, a
        (possibly reduced) node mask/adjacency, and bookkeeping (losses,
        attention coefficients) appended to params_dict.

        NOTE(review): the semantics of self.pool_type, self.pool_arch and
        self.proj are inferred from usage here only -- confirm against the
        class constructor, which is not visible in this block.
        """
        KL_loss = None
        x, A, mask, _, params_dict = data[:5]

        mask_float = mask.float()
        N_nodes_float = params_dict['N_nodes'].float()
        B, N, C = x.shape
        # Ground-truth (or weakly labeled) node attention, if provided.
        alpha_gt = None
        if 'node_attn' in params_dict:
            if not isinstance(params_dict['node_attn'], list):
                params_dict['node_attn'] = [params_dict['node_attn']]
            alpha_gt = params_dict['node_attn'][-1].view(B, N)
        if 'node_attn_eval' in params_dict:
            if not isinstance(params_dict['node_attn_eval'], list):
                params_dict['node_attn_eval'] = [params_dict['node_attn_eval']]

        # 'gt' pooling (and 'sup' during training) requires ground truth.
        if (self.pool_type[1] == 'gt' or
            (self.pool_type[1] == 'sup'
             and self.training)) and alpha_gt is None:
            raise ValueError(
                'ground truth node attention values node_attn required for %s'
                % self.pool_type)

        if self.pool_type[1] in ['unsup', 'sup']:
            # Predict attention either from the previous layer's features or
            # from a copy of the current features.
            attn_input = data[-1] if self.pool_arch[1] == 'prev' else x.clone()
            if self.pool_arch[0] == 'fc':
                alpha_pre = self.proj(attn_input).view(B, N)
            else:
                alpha_pre = self.proj([attn_input, *data[1:]])[0].view(B, N)
            # softmax with masking out dummy nodes
            alpha_pre = torch.clamp(alpha_pre, -self.clamp_value,
                                    self.clamp_value)
            alpha = normalize_batch(
                self.mask_out(torch.exp(alpha_pre), mask_float).view(B, N))
            if self.pool_type[1] == 'sup' and self.training:
                # Supervise predicted attention against the ground truth;
                # 1e-14 guards log(0) on fully masked entries.
                KL_loss_per_node = self.mask_out(
                    F.kl_div(torch.log(alpha + 1e-14),
                             alpha_gt,
                             reduction='none'), mask_float.view(B, N))
                KL_loss = self.kl_weight * torch.mean(
                    KL_loss_per_node.sum(dim=1) /
                    (N_nodes_float +
                     1e-7))  # mean over nodes, then mean over batches
        else:
            alpha = alpha_gt

        # Weight node features by their attention coefficients.
        x = x * alpha.view(B, N, 1)
        if self.large_graph:
            # For large graphs during training, all alpha values can be very small hindering training
            x = x * N_nodes_float.view(B, 1, 1)
        if self.is_topk:
            # Keep only the top-k fraction of nodes by attention.
            N_remove = torch.round(N_nodes_float * (1 - self.topk_ratio)).long(
            )  # number of nodes to be removed for each graph
            idx = torch.sort(
                alpha, dim=1,
                descending=False)[1]  # indices of alpha in ascending order
            mask = mask.clone().view(B, N)
            for b in range(B):
                idx_b = idx[b, mask[b, idx[
                    b]]]  # take indices of non-dummy nodes for current data example
                mask[b, idx_b[:N_remove[b]]] = 0
        else:
            # Threshold pooling: drop nodes with attention below threshold.
            mask = (mask & (alpha.view_as(mask) > self.threshold)).view(B, N)

        if self.drop_nodes:
            x, A, mask, N_nodes_pooled, idx = self.drop_nodes_edges(x, A, mask)
            if idx is not None and 'node_attn' in params_dict:
                # update ground truth (or weakly labeled) attention for a reduced graph
                params_dict['node_attn'].append(
                    normalize_batch(
                        self.mask_out(torch.gather(alpha_gt, dim=1, index=idx),
                                      mask.float())))
            if idx is not None and 'node_attn_eval' in params_dict:
                # update ground truth (or weakly labeled) attention for a reduced graph
                params_dict['node_attn_eval'].append(
                    normalize_batch(
                        self.mask_out(
                            torch.gather(params_dict['node_attn_eval'][-1],
                                         dim=1,
                                         index=idx), mask.float())))
        else:
            N_nodes_pooled = torch.sum(mask, dim=1).long()  # B
            if 'node_attn' in params_dict:
                params_dict['node_attn'].append(
                    (self.mask_out(params_dict['node_attn'][-1],
                                   mask.float())))
            if 'node_attn_eval' in params_dict:
                params_dict['node_attn_eval'].append(
                    (self.mask_out(params_dict['node_attn_eval'][-1],
                                   mask.float())))

        params_dict['N_nodes'] = N_nodes_pooled

        # Zero out adjacency entries touching any dropped node.
        mask_matrix = mask.unsqueeze(2) & mask.unsqueeze(1)
        A = A * mask_matrix.float()  # or A[~mask_matrix] = 0

        # Add additional losses regularizing the model
        if KL_loss is not None:
            if 'reg' not in params_dict:
                params_dict['reg'] = []
            params_dict['reg'].append(KL_loss)

        # Keep attention coefficients for evaluation
        for key, value in zip(['alpha', 'mask'], [alpha, mask]):
            if key not in params_dict:
                params_dict[key] = []
            params_dict[key].append(value.detach())

        if self.debug and alpha_gt is not None:
            # Diagnostics: how much attention mass lands on nodes that
            # should be kept vs. dropped, and the average pooling ratio.
            idx_correct_pool = (alpha_gt > 0)
            idx_correct_drop = (alpha_gt == 0)
            alpha_correct_pool = alpha[idx_correct_pool].sum(
            ) / N_nodes_float.sum()
            alpha_correct_drop = alpha[idx_correct_drop].sum(
            ) / N_nodes_float.sum()
            ratio_avg = (N_nodes_pooled.float() / N_nodes_float).mean()

            for key, values in zip([
                    'alpha_correct_pool_debug', 'alpha_correct_drop_debug',
                    'ratio_avg_debug'
            ], [alpha_correct_pool, alpha_correct_drop, ratio_avg]):
                if key not in params_dict:
                    params_dict[key] = []
                params_dict[key].append(values.detach())

        return [x, A, mask, *data[3:]]
Example #6
0
def distillation(y, teacher_scores, labels, T, alpha):
    """Knowledge-distillation loss: weighted KD term plus hard-label CE.

    y: student logits (batch x classes).
    teacher_scores: teacher logits, same shape as y.
    labels: gold class indices; entries equal to 255 are ignored by the CE term.
    T: softmax temperature; the KD term is rescaled by T*T so gradient
       magnitude stays comparable across temperatures.
    alpha: interpolation weight between the KD term and the CE term.
    """
    # reduction='batchmean' matches the mathematical KL definition (sum over
    # classes, mean over the batch). The previous implicit default 'mean'
    # averaged over *every* element, silently shrinking the KD term by a
    # factor of n_classes.
    kl_div = F.kl_div(F.log_softmax(y / T, dim=1),
                      F.softmax(teacher_scores / T, dim=1),
                      reduction='batchmean')
    return kl_div * (T * T * 2. * alpha) + F.cross_entropy(y, labels, ignore_index=255) * (1. - alpha)
Example #7
0
    def update(self):
        # discount reward
        # R_i = r_i + GAMMA * R_{i+1}
        r_list = []
        R = 0
        eps = np.finfo(np.float32).eps.item()
        norm = lambda a: (a - a.mean()) / (a.std() + eps)
        for r in self.rewards[::-1]:
            R = r + self.gamma * R
            r_list.append(R)
        r_list = r_list[::-1]
        r_list = torch.tensor(r_list)
        r_list = norm(r_list).to(self.device)

        states = torch.cat(self.states)
        actions = torch.cat(self.actions).to(self.device)
        saved_log_probs = torch.cat(self.saved_log_probs).to(self.device)
        v_list = torch.cat(self.value_list).squeeze(-1)
        advantage = r_list - v_list
        advantage = norm(advantage)

        # compute PG loss
        # loss = sum(-R_i * log(action_prob))
        states = states.detach()
        actions = actions.detach()
        log_prob_actions = saved_log_probs.detach()
        advantage = advantage.detach()
        r_list = r_list.detach()
        
        loss = 0
        for _ in range(self.ppo_steps):
            #get new log prob of actions for all input states
            action_pred, value_pred = self.model(states)
            value_pred = value_pred.squeeze(-1)
            action_prob = F.softmax(action_pred, dim = -1)
            dist = Categorical(action_prob)

            #new log prob using old actions
            new_log_prob_actions = dist.log_prob(actions)

            policy_ratio = (new_log_prob_actions - log_prob_actions).exp()
            policy_loss_1 = policy_ratio * advantage
            # import pdb
            # pdb.set_trace()

            kl=F.kl_div(log_prob_actions,new_log_prob_actions.exp(),reduction='mean')
            if kl < self.kl_target/1.5:
            	self.beta/=2
            elif kl >self.kl_target*1.5:
            	self.beta*=2
            #kl=



            #policy_loss_2 = torch.clamp(policy_ratio, min = 1.0 - self.ppo_clip, max = 1.0 + self.ppo_clip) * advantage

            #policy_loss = - torch.min(policy_loss_1, policy_loss_2).mean()
            policy_loss=-(policy_loss_1-beta*kl).mean()
            value_loss = F.smooth_l1_loss(r_list, value_pred).mean()
            loss += policy_loss.item() + value_loss.item()

            self.optimizer.zero_grad()
            
            policy_loss.backward()
            value_loss.backward()

            self.optimizer.step()
        loss /= self.ppo_steps
        return loss
Example #8
0
def klloss(pred, true):
    """KL-divergence loss wrapper.

    pred: log-probabilities (F.kl_div expects its input in log space).
    true: target probabilities, same shape as pred.

    The elementwise 'mean' reduction (PyTorch's default) is now spelled out
    explicitly instead of relying on the implicit default.
    """
    return F.kl_div(pred, true, reduction='mean')
def train(vocab, train_sets, dev_sets, test_sets, unlabeled_sets):
    """
    train_sets, dev_sets, test_sets: dict[domain] -> AmazonDataset
    For unlabeled domains, no train_sets are available

    Trains a shared feature extractor F_s, per-domain extractors F_d, a
    sentiment classifier C and an adversarial domain classifier D; returns
    the best {'valid': ..., 'test': ...} accuracies (or current accuracies
    when opt.test_only is set).
    """
    # dataset loaders
    train_loaders, unlabeled_loaders = {}, {}
    train_iters, unlabeled_iters = {}, {}
    dev_loaders, test_loaders = {}, {}
    my_collate = utils.sorted_collate if opt.model == 'lstm' else utils.unsorted_collate
    for domain in opt.domains:
        train_loaders[domain] = DataLoader(train_sets[domain],
                                           opt.batch_size,
                                           shuffle=True,
                                           collate_fn=my_collate)
        train_iters[domain] = iter(train_loaders[domain])
    for domain in opt.dev_domains:
        dev_loaders[domain] = DataLoader(dev_sets[domain],
                                         opt.batch_size,
                                         shuffle=False,
                                         collate_fn=my_collate)
        test_loaders[domain] = DataLoader(test_sets[domain],
                                          opt.batch_size,
                                          shuffle=False,
                                          collate_fn=my_collate)
    for domain in opt.all_domains:
        if domain in opt.unlabeled_domains:
            uset = unlabeled_sets[domain]
        else:
            # for labeled domains, consider which data to use as unlabeled set
            if opt.unlabeled_data == 'both':
                uset = ConcatDataset(
                    [train_sets[domain], unlabeled_sets[domain]])
            elif opt.unlabeled_data == 'unlabeled':
                uset = unlabeled_sets[domain]
            elif opt.unlabeled_data == 'train':
                uset = train_sets[domain]
            else:
                raise Exception(
                    f'Unknown options for the unlabeled data usage: {opt.unlabeled_data}'
                )
        unlabeled_loaders[domain] = DataLoader(uset,
                                               opt.batch_size,
                                               shuffle=True,
                                               collate_fn=my_collate)
        unlabeled_iters[domain] = iter(unlabeled_loaders[domain])

    # model
    F_s = None
    F_d = {}
    C, D = None, None
    if opt.model.lower() == 'dan':
        F_s = DanFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                  opt.sum_pooling, opt.dropout, opt.F_bn)
        for domain in opt.domains:
            F_d[domain] = DanFeatureExtractor(vocab, opt.F_layers,
                                              opt.domain_hidden_size,
                                              opt.sum_pooling, opt.dropout,
                                              opt.F_bn)
    elif opt.model.lower() == 'lstm':
        F_s = LSTMFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                   opt.dropout, opt.bdrnn, opt.attn)
        for domain in opt.domains:
            F_d[domain] = LSTMFeatureExtractor(vocab, opt.F_layers,
                                               opt.domain_hidden_size,
                                               opt.dropout, opt.bdrnn,
                                               opt.attn)
    elif opt.model.lower() == 'cnn':
        F_s = CNNFeatureExtractor(vocab, opt.F_layers, opt.shared_hidden_size,
                                  opt.kernel_num, opt.kernel_sizes,
                                  opt.dropout)
        for domain in opt.domains:
            F_d[domain] = CNNFeatureExtractor(vocab, opt.F_layers,
                                              opt.domain_hidden_size,
                                              opt.kernel_num, opt.kernel_sizes,
                                              opt.dropout)
    else:
        raise Exception(f'Unknown model architecture {opt.model}')

    C = SentimentClassifier(opt.C_layers,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.shared_hidden_size + opt.domain_hidden_size,
                            opt.num_labels, opt.dropout, opt.C_bn)
    D = DomainClassifier(opt.D_layers,
                         opt.shared_hidden_size, opt.shared_hidden_size,
                         len(opt.all_domains), opt.loss, opt.dropout, opt.D_bn)

    F_s, C, D = F_s.to(opt.device), C.to(opt.device), D.to(opt.device)
    # nn.Module.to moves parameters in place, so no rebinding is needed.
    for f_d in F_d.values():
        f_d.to(opt.device)
    # optimizers
    optimizer = optim.Adam(itertools.chain(
        *map(list, [F_s.parameters() if F_s else [],
                    C.parameters()] + [f.parameters() for f in F_d.values()])),
                           lr=opt.learning_rate)
    optimizerD = optim.Adam(D.parameters(), lr=opt.D_learning_rate)

    # testing
    if opt.test_only:
        log.info(f'Loading model from {opt.exp3_model_save_file}...')
        if F_s:
            F_s.load_state_dict(
                torch.load(
                    os.path.join(opt.exp3_model_save_file, f'netF_s.pth')))
        for domain in opt.all_domains:
            if domain in F_d:
                F_d[domain].load_state_dict(
                    torch.load(
                        os.path.join(opt.exp3_model_save_file,
                                     f'net_F_d_{domain}.pth')))
        C.load_state_dict(
            torch.load(os.path.join(opt.exp3_model_save_file, f'netC.pth')))
        D.load_state_dict(
            torch.load(os.path.join(opt.exp3_model_save_file, f'netD.pth')))

        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.all_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s,
                                   F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.all_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s,
                                        F_d[domain] if domain in F_d else None,
                                        C)
        avg_test_acc = sum([test_acc[d]
                            for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average test accuracy: {avg_test_acc}')
        return {'valid': acc, 'test': test_acc}

    # training
    best_acc, best_avg_acc = defaultdict(float), 0.0
    for epoch in range(opt.max_epoch):
        F_s.train()
        C.train()
        D.train()
        for f in F_d.values():
            f.train()

        # training accuracy
        correct, total = defaultdict(int), defaultdict(int)
        # D accuracy
        d_correct, d_total = 0, 0
        # conceptually view 1 epoch as 1 epoch of the first domain
        num_iter = len(train_loaders[opt.domains[0]])
        for i in tqdm(range(num_iter)):
            # D iterations
            utils.freeze_net(F_s)
            # Fix: the previous `map(utils.freeze_net, F_d.values())` was a
            # lazy iterator that was never consumed, so the per-domain
            # extractors were never actually frozen. Iterate explicitly.
            for f_d in F_d.values():
                utils.freeze_net(f_d)
            utils.freeze_net(C)
            utils.unfreeze_net(D)
            # WGAN n_critic trick since D trains slower
            n_critic = opt.n_critic
            if opt.wgan_trick:
                if opt.n_critic > 0 and ((epoch == 0 and i < 25)
                                         or i % 500 == 0):
                    n_critic = 100

            for _ in range(n_critic):
                D.zero_grad()
                loss_d = {}
                # train on both labeled and unlabeled domains
                for domain in opt.all_domains:
                    # targets not used
                    d_inputs, _ = utils.endless_get_next_batch(
                        unlabeled_loaders, unlabeled_iters, domain)
                    d_targets = utils.get_domain_label(opt.loss, domain,
                                                       len(d_inputs[1]))
                    shared_feat = F_s(d_inputs)
                    d_outputs = D(shared_feat)
                    # D accuracy
                    _, pred = torch.max(d_outputs, 1)
                    d_total += len(d_inputs[1])
                    if opt.loss.lower() == 'l2':
                        _, tgt_indices = torch.max(d_targets, 1)
                        d_correct += (pred == tgt_indices).sum().item()
                        l_d = functional.mse_loss(d_outputs, d_targets)
                        l_d.backward()
                    else:
                        d_correct += (pred == d_targets).sum().item()
                        l_d = functional.nll_loss(d_outputs, d_targets)
                        l_d.backward()
                    loss_d[domain] = l_d.item()
                optimizerD.step()

            # F&C iteration
            utils.unfreeze_net(F_s)
            # Fix: same lazy-map bug as above for unfreezing.
            for f_d in F_d.values():
                utils.unfreeze_net(f_d)
            utils.unfreeze_net(C)
            utils.freeze_net(D)
            if opt.fix_emb:
                utils.freeze_net(F_s.word_emb)
                for f_d in F_d.values():
                    utils.freeze_net(f_d.word_emb)
            F_s.zero_grad()
            for f_d in F_d.values():
                f_d.zero_grad()
            C.zero_grad()
            for domain in opt.domains:
                inputs, targets = utils.endless_get_next_batch(
                    train_loaders, train_iters, domain)
                targets = targets.to(opt.device)
                shared_feat = F_s(inputs)
                domain_feat = F_d[domain](inputs)
                features = torch.cat((shared_feat, domain_feat), dim=1)
                c_outputs = C(features)
                l_c = functional.nll_loss(c_outputs, targets)
                l_c.backward(retain_graph=True)
                _, pred = torch.max(c_outputs, 1)
                total[domain] += targets.size(0)
                correct[domain] += (pred == targets).sum().item()
            # update F with D gradients on all domains
            for domain in opt.all_domains:
                d_inputs, _ = utils.endless_get_next_batch(
                    unlabeled_loaders, unlabeled_iters, domain)
                shared_feat = F_s(d_inputs)
                d_outputs = D(shared_feat)
                if opt.loss.lower() == 'gr':
                    d_targets = utils.get_domain_label(opt.loss, domain,
                                                       len(d_inputs[1]))
                    l_d = functional.nll_loss(d_outputs, d_targets)
                    if opt.lambd > 0:
                        l_d *= -opt.lambd
                elif opt.loss.lower() == 'bs':
                    d_targets = utils.get_random_domain_label(
                        opt.loss, len(d_inputs[1]))
                    # `size_average=False` is deprecated; reduction='sum'
                    # is the exact equivalent.
                    l_d = functional.kl_div(d_outputs,
                                            d_targets,
                                            reduction='sum')
                    if opt.lambd > 0:
                        l_d *= opt.lambd
                elif opt.loss.lower() == 'l2':
                    d_targets = utils.get_random_domain_label(
                        opt.loss, len(d_inputs[1]))
                    l_d = functional.mse_loss(d_outputs, d_targets)
                    if opt.lambd > 0:
                        l_d *= opt.lambd
                l_d.backward()

            optimizer.step()

        # end of epoch
        log.info('Ending epoch {}'.format(epoch + 1))
        if d_total > 0:
            log.info('D Training Accuracy: {}%'.format(100.0 * d_correct /
                                                       d_total))
        log.info('Training accuracy:')
        log.info('\t'.join(opt.domains))
        log.info('\t'.join(
            [str(100.0 * correct[d] / total[d]) for d in opt.domains]))
        log.info('Evaluating validation sets:')
        acc = {}
        for domain in opt.dev_domains:
            acc[domain] = evaluate(domain, dev_loaders[domain], F_s,
                                   F_d[domain] if domain in F_d else None, C)
        avg_acc = sum([acc[d] for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average validation accuracy: {avg_acc}')
        log.info('Evaluating test sets:')
        test_acc = {}
        for domain in opt.dev_domains:
            test_acc[domain] = evaluate(domain, test_loaders[domain], F_s,
                                        F_d[domain] if domain in F_d else None,
                                        C)
        avg_test_acc = sum([test_acc[d]
                            for d in opt.dev_domains]) / len(opt.dev_domains)
        log.info(f'Average test accuracy: {avg_test_acc}')

        if avg_acc > best_avg_acc:
            log.info(f'New best average validation accuracy: {avg_acc}')
            best_acc['valid'] = acc
            best_acc['test'] = test_acc
            best_avg_acc = avg_acc
            with open(os.path.join(opt.exp3_model_save_file, 'options.pkl'),
                      'wb') as ouf:
                pickle.dump(opt, ouf)
            torch.save(F_s.state_dict(),
                       '{}/netF_s.pth'.format(opt.exp3_model_save_file))
            for d in opt.domains:
                if d in F_d:
                    torch.save(
                        F_d[d].state_dict(),
                        '{}/net_F_d_{}.pth'.format(opt.exp3_model_save_file,
                                                   d))
            torch.save(C.state_dict(),
                       '{}/netC.pth'.format(opt.exp3_model_save_file))
            torch.save(D.state_dict(),
                       '{}/netD.pth'.format(opt.exp3_model_save_file))

    # end of training
    log.info(f'Best average validation accuracy: {best_avg_acc}')
    return best_acc
Example #10
0
    def train(self, epoch, triplet=False):
        """One training epoch with optional margin/triplet loss and KD.

        epoch: current epoch index (compared against args.triplet_epoch).
        triplet: when True and past the first task, adds a margin-ranking
            term pulling features toward class anchors.

        NOTE(review): relies on `self.train_data_iterator`, `self.model`,
        `self.model_fixed`, `self.loss`, `self.mr_loss`, `self.class_anchor`,
        `self.class_corr_prob`, `self.optimizer` and `self.args` defined in
        the enclosing trainer, which is not visible from this block.
        """
        # Distillation temperature.
        T = 2
        self.model.train()
        print("Epochs %d" % epoch)

        # Class-index boundaries: [0, start) old classes, [start, end) new.
        tasknum = self.train_data_iterator.dataset.t
        end = self.train_data_iterator.dataset.end
        start = end - self.args.step_size

        # Old/new loss interpolation weight (fraction of old classes).
        lamb = start / end

        self.training_target = torch.tensor([]).cuda()
        self.training_output = torch.tensor([]).cuda()
        self.training_idx = torch.tensor([]).cuda()

        for data, target, traindata_idx in tqdm(self.train_data_iterator):

            target = target.type(dtype=torch.long)

            data, target = data.cuda(), target.cuda()

            # Split the batch into old-class and new-class samples.
            old_idx = torch.where(target < start)[0]
            new_idx = torch.where(target >= start)[0]

            loss_KD = 0
            loss_CE = 0
            loss_triplet = 0
            loss_mr = 0

            if tasknum > 0 and triplet == True:
                if epoch > self.args.triplet_epoch:
                    # Late phase: pull features toward their own class anchor.

                    target_a = target
                    target_b = self.class_corr_prob(target)
                    output, a_feature = self.model(data, feature_return=True)

                    p_feature = torch.zeros(
                        (data.size()[0], a_feature.size()[1])).cuda()
                    n_feature = torch.zeros(
                        (data.size()[0], a_feature.size()[1])).cuda()

                    for i in range(data.size()[0]):
                        p_feature[i] = self.class_anchor[target_a[i]]
                    p_target = torch.ones(data.size()[0], 1).cuda()
                    # Distance to the positive (own-class) anchor.
                    loss_mr = torch.mean(torch.norm(
                        (a_feature - p_feature))) * self.args.triplet_lam

                    # loss_CE = self.loss(output[:, :end], target) + loss_triplet
                    loss_CE = self.loss(output[:, :end], target)
                else:
                    # Early phase: push new-class features away from the
                    # anchor of their most-confusable class (margin ranking).
                    output, a_feature = self.model(data, feature_return=True)

                    target_a = target
                    target_b = self.class_corr_prob(target)

                    n_feature = torch.zeros((data[new_idx, :].size()[0],
                                             a_feature.size()[1])).cuda()

                    for i in range(data[new_idx, :].size()[0]):
                        n_feature[i] = self.class_anchor[target_b[new_idx][i]]
                    n_target = torch.ones(data[new_idx, :].size()[0], 1).cuda()
                    # Negative labels (-1) for the margin-ranking loss.
                    n_target = n_target * -1

                    dist = torch.norm((a_feature[new_idx, :] - n_feature),
                                      dim=1)

                    loss_mr = self.mr_loss(dist, n_target)

                    loss_CE = self.loss(output[:, :end], target)

            else:
                output, a_feature = self.model(data, feature_return=True)
                loss_CE = self.loss(output[:, :end], target)
                loss_triplet = 0

            _, predicted = output.max(1)
            prob = F.softmax(output, dim=1).cpu()

            if tasknum > 0:
                # Distill old-class logits from the frozen previous model.
                self.model_fixed.eval()
                end_KD = start
                start_KD = end_KD - self.args.step_size

                score = self.model_fixed(data)[:, :end_KD].data

                if self.args.KD == "naive_global":
                    # Single KL over all old classes.
                    soft_target = F.softmax(score / T, dim=1)
                    output_log = F.log_softmax(output[:, :end_KD] / T, dim=1)
                    loss_KD = F.kl_div(output_log,
                                       soft_target,
                                       reduction='batchmean')

                elif self.args.KD == "naive_local":
                    # One KL per past task, restricted to that task's classes.
                    local_KD = torch.zeros(tasknum).cuda()
                    for t in range(tasknum):
                        local_start_KD = (t) * self.args.step_size
                        local_end_KD = (t + 1) * self.args.step_size

                        soft_target = F.softmax(
                            score[:, local_start_KD:local_end_KD] / T, dim=1)
                        output_log = F.log_softmax(
                            output[:, local_start_KD:local_end_KD] / T, dim=1)
                        local_KD[t] = F.kl_div(output_log,
                                               soft_target,
                                               reduction="batchmean")

                    loss_KD = local_KD.sum()

            self.optimizer.zero_grad()

            if (self.args.KD
                    == "PCKD") or (self.args.KD
                                   == "naive_global") or (self.args.KD
                                                          == "naive_local"):

                if triplet:
                    # (lamb * loss_KD + (1 - lamb) * loss_CE + loss_triplet).backward()
                    (loss_KD + loss_CE + loss_triplet + loss_mr).backward()
                else:
                    (lamb * loss_KD + (1 - lamb) * loss_CE).backward()

            elif self.args.KD == "No":
                loss_CE.backward()

            self.optimizer.step()

            # Weight clipping: zero out negative fc weights and the bias.
            weight = self.model.fc.weight.data
            weight[weight < 0] = 0
            self.model.fc.bias.data[:] = 0

        print(self.train_data_iterator.dataset.__len__())
def train(net, dataloader, optimizer, epoch):
    """Train `net` for one epoch with optional self-distillation.

    When `epoch >= args.distill_from` and `args.distill > 0`, the network's
    own eval-mode predictions are used as soft targets and a temperature-
    scaled KL-divergence distillation loss is added to the hard-label
    cross-entropy loss.

    Relies on module-level globals: args, cf, model_name, log_file,
    trainset, device.
    """
    criterion = nn.CrossEntropyLoss()
    net.train()
    hard_loss_sum = 0
    soft_loss_sum = 0
    loss_sum = 0
    correct = 0
    total = 0

    print('\n=> [%s] Training Epoch #%d, lr=%.4f' %
          (model_name, epoch, cf.learning_rate(args.lr, epoch)))
    log_file.write('\n=> [%s] Training Epoch #%d, lr=%.4f\n' %
                   (model_name, epoch, cf.learning_rate(args.lr, epoch)))
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        # obtain soft_target by forwarding data in test mode
        # (no_grad so the teacher predictions carry no gradient)
        if epoch >= args.distill_from and args.distill > 0:
            with torch.no_grad():
                net.eval()
                soft_target = net(inputs)

        net.train()
        optimizer.zero_grad()
        outputs = net(inputs)  # Forward Propagation
        loss = criterion(outputs, targets)  # Hard-label loss
        hard_loss_sum = hard_loss_sum + loss.item() * targets.size(0)

        # compute distillation loss
        if epoch >= args.distill_from and args.distill > 0:
            heat_output = outputs / args.temp
            heat_soft_target = soft_target / args.temp

            # FIX: F.softmax was called without `dim` (deprecated implicit
            # dim selection) and kl_div used the deprecated
            # `size_average=False`; explicit dim=1 and reduction='sum'
            # preserve the value for 2-D logits without warnings.
            distill_loss = F.kl_div(F.log_softmax(heat_output, dim=1),
                                    F.softmax(heat_soft_target, dim=1),
                                    reduction='sum') / targets.size(0)
            soft_loss_sum = soft_loss_sum + distill_loss.item() * targets.size(
                0)

            # Scale by T^2 to keep soft-target gradients comparable across
            # temperatures (standard distillation practice).
            distill_loss = distill_loss * (args.temp * args.temp)
            loss = loss + args.distill * distill_loss

        loss.backward()  # Backward Propagation
        optimizer.step()  # Optimizer update

        loss_sum = loss_sum + loss.item() * targets.size(0)
        _, predicted = torch.max(outputs.detach(), 1)
        total += targets.size(0)
        correct += predicted.eq(targets.detach()).long().sum().item()

        # Abort loudly if training diverged to NaN.
        if math.isnan(loss.item()):
            print('@@@@@@@nan@@@@@@@@@@@@')
            log_file.write('@@@@@@@@@@@nan @@@@@@@@@@@@@\n')
            sys.exit(0)

        sys.stdout.write('\r')
        sys.stdout.write(
            '| Epoch [%3d/%3d] Iter[%3d/%3d]\tLoss: %.4g Acc@1: %.2f%% Hard: %.4g Soft: %.4g'
            % (epoch, args.num_epochs, batch_idx + 1,
               (len(trainset) // args.bs) + 1, loss_sum / total, 100. *
               correct / total, hard_loss_sum / total, soft_loss_sum / total))
        sys.stdout.flush()
    log_file.write(
        '| Epoch [%3d/%3d] \tLoss: %.4f Acc@1: %.2f%% Hard: %.4f Soft: %.8f' %
        (epoch, args.num_epochs, loss_sum / total, 100. * correct / total,
         hard_loss_sum / total, soft_loss_sum / total))
import numpy as np
import math
from copy import deepcopy
import pdb

import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
from utils import get_grad_vector, get_future_step_parameters, add_grad, get_weight_accumelated_gradient_norm, weight_gradient_norm, get_future_step_parameters_with_grads, get_nearbysamples, get_mask_unused_memories, get_weight_norm_diff
from VAE.loss import calculate_loss

#----------
# Functions
def dist_kl(y, t_s):
    """KL(softmax(t_s) || softmax(y)): element-averaged kl_div scaled by
    the batch size of `y`."""
    log_q = F.log_softmax(y, dim=-1)
    p = F.softmax(t_s, dim=-1)
    return F.kl_div(log_q, p, reduction='mean') * y.size(0)


def entropy_fn(x):
    """Per-row negative entropy of softmax(x) (returns -H(p))."""
    p = F.softmax(x, dim=-1)
    return torch.sum(p * F.log_softmax(x, dim=-1), dim=-1)


def cross_entropy(y, t_s):
    """Mean soft cross-entropy of logits `y` against softmax(t_s) targets."""
    log_q = F.log_softmax(y, dim=-1)
    return -torch.sum(log_q * F.softmax(t_s, dim=-1), dim=-1).mean()


mse = torch.nn.MSELoss()


def retrieve_gen_for_cls(args, x, cls, prev_cls, prev_gen):
    """Retrieve generated samples for the classifier.

    NOTE(review): only the gradient-vector extraction is visible here — the
    rest of the body appears truncated; confirm against the original source.
    """
    # Flatten the classifier's current gradients into a single vector.
    grad_vector = get_grad_vector(args, cls.parameters, cls.grad_dims)
Example #13
0
def train(net, train_loader, optimizer):
    """Train for one epoch.

    Default path (JSD consistency): each `images` element holds three views
    (clean, aug1, aug2). Cross-entropy is computed on the clean view only,
    and a Jensen-Shannon-style consistency term — three KL divergences
    against the clamped mixture distribution, weight 12 — is added.
    With `args.no_jsd`, plain cross-entropy on single images is used.

    Returns:
        (loss_ema, acc1_ema, batch_ema): exponential moving averages with
        weight 0.9 on the newest sample.
    """
    net.train()
    data_ema = 0.
    batch_ema = 0.
    loss_ema = 0.
    acc1_ema = 0.
    acc5_ema = 0.

    end = time.time()
    for i, (images, targets) in enumerate(train_loader):
        # Compute data loading time
        data_time = time.time() - end
        optimizer.zero_grad()

        if args.no_jsd:
            images = images.cuda()
            targets = targets.cuda()
            logits = net(images)
            loss = F.cross_entropy(logits, targets)
            acc1, acc5 = accuracy(logits, targets, topk=(1, 5))  # pylint: disable=unbalanced-tuple-unpacking
        else:
            # Single forward pass over the concatenated views, then split
            # the logits back into the three groups.
            images_all = torch.cat(images, 0).cuda()
            targets = targets.cuda()
            logits_all = net(images_all)
            logits_clean, logits_aug1, logits_aug2 = torch.split(
                logits_all, images[0].size(0))

            # Cross-entropy is only computed on clean images
            loss = F.cross_entropy(logits_clean, targets)

            p_clean, p_aug1, p_aug2 = F.softmax(
                logits_clean, dim=1), F.softmax(logits_aug1,
                                                dim=1), F.softmax(logits_aug2,
                                                                  dim=1)

            # Clamp mixture distribution to avoid exploding KL divergence
            p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7,
                                    1).log()
            loss += 12 * (
                F.kl_div(p_mixture, p_clean, reduction='batchmean') +
                F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
                F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.
            acc1, acc5 = accuracy(logits_clean, targets, topk=(1, 5))  # pylint: disable=unbalanced-tuple-unpacking

        loss.backward()
        optimizer.step()

        # Compute batch computation time and update moving averages.
        batch_time = time.time() - end
        end = time.time()

        # EMA updates: new sample weighted 0.9, history weighted 0.1.
        data_ema = data_ema * 0.1 + float(data_time) * 0.9
        batch_ema = batch_ema * 0.1 + float(batch_time) * 0.9
        loss_ema = loss_ema * 0.1 + float(loss) * 0.9
        acc1_ema = acc1_ema * 0.1 + float(acc1) * 0.9
        acc5_ema = acc5_ema * 0.1 + float(acc5) * 0.9

        if i % args.print_freq == 0:
            print(
                'Batch {}/{}: Data Time {:.3f} | Batch Time {:.3f} | Train Loss {:.3f} | Train Acc1 '
                '{:.3f} | Train Acc5 {:.3f}'.format(i, len(train_loader),
                                                    data_ema, batch_ema,
                                                    loss_ema, acc1_ema,
                                                    acc5_ema))

    return loss_ema, acc1_ema, batch_ema
def _kl_div(logit1, logit2):
    return F.kl_div(F.log_softmax(logit1, dim=1),
                    F.softmax(logit2, dim=1),
                    reduction='batchmean')
Example #15
0
 def forward(self, input, target):
     """Label-smoothing loss: KL divergence between the log-softmax of
     `input` and the smoothed one-hot targets, summed over all elements."""
     smoothed = self._smooth_labels(input.size(1), target)
     log_probs = F.log_softmax(input, dim=1)
     return F.kl_div(log_probs, smoothed, reduction='sum')
Example #16
0
def _kl_div(log_probs, probs):
    # pytorch KLDLoss is averaged over all dim if size_average=True
    kld = F.kl_div(log_probs, probs, size_average=False)
    return kld / log_probs.shape[0]
Example #17
0
    def beam_sample(self, src, src_len, dict_spk2idx, tgt, beam_size=1):
        """Decode with beam search, then run the separation model.

        Encodes `src`, steps the decoder while keeping `beam_size`
        hypotheses per batch element, and finally feeds the best hypothesis
        (hidden states or embeddings, depending on config) to
        `self.ss_model` to predict separation maps.

        NOTE(review): uses the legacy `Variable(..., volatile=True)` API and
        a Python-2 `print` statement — this snippet targets Python 2 with an
        old PyTorch release.

        Returns (allHyps, allAttn, allHiddens, predicted_maps).
        """

        src = src.transpose(0, 1)
        #beam_size = self.config.beam_size
        batch_size = src.size(0)

        # (1) Run the encoder on the src. Done!!!!
        if self.use_cuda:
            src = src.cuda()
            src_len = src_len.cuda()

        lengths, indices = torch.sort(src_len, dim=0, descending=True)
        # _, ind = torch.sort(indices)
        # src = Variable(torch.index_select(src, dim=1, index=indices), volatile=True)
        contexts, encState = self.encoder(src, lengths.data.cpu().numpy()[0])

        #  (1b) Initialize for the decoder.
        def var(a):
            # Wrap in a no-grad (volatile) Variable for inference.
            return Variable(a, volatile=True)

        def rvar(a):
            # Tile a tensor beam_size times along the beam dimension.
            return var(a.repeat(1, beam_size, 1))

        def bottle(m):
            # Merge beam and batch dims: (beam, batch, -1) -> (beam*batch, -1).
            return m.view(batch_size * beam_size, -1)

        def unbottle(m):
            # Split the merged dim back into (beam, batch, -1).
            return m.view(beam_size, batch_size, -1)

        # Repeat everything beam_size times.
        contexts = rvar(contexts.data).transpose(0, 1)
        decState = (rvar(encState[0].data), rvar(encState[1].data))
        #decState.repeat_beam_size_times(beam_size)
        beam = [
            models.Beam(beam_size, dict_spk2idx, n_best=1, cuda=self.use_cuda)
            for __ in range(batch_size)
        ]

        # (2) run the decoder to generate sentences, using beam search.

        mask = None
        soft_score = None
        tmp_hiddens = []
        tmp_soft_score = []
        for i in range(self.config.max_tgt_len):

            if all((b.done() for b in beam)):
                break

            # Construct batch x beam_size nxt words.
            # Get all the pending current beam words and arrange for forward.
            inp = var(
                torch.stack([b.getCurrentState()
                             for b in beam]).t().contiguous().view(-1))
            if self.config.schmidt and i > 0:
                assert len(beam[0].sch_hiddens[-1]) == i
                tmp_hiddens = []
                for xxx in range(i):  # hidden states from each step before this sch step
                    one_len = []
                    for bm_idx in range(beam_size):
                        for bs_idx in range(batch_size):
                            one_len.append(
                                beam[bs_idx].sch_hiddens[-1][xxx][bm_idx, :])
                    tmp_hiddens.append(var(torch.stack(one_len)))

            # Run one step.
            output, decState, attn, hidden, emb = self.decoder.sample_one(
                inp, soft_score, decState, tmp_hiddens, contexts, mask)
            # print "sample after decState:",decState[0].data.cpu().numpy().mean()
            if self.config.schmidt:
                tmp_hiddens += [hidden]
            if self.config.ct_recu:
                # Zero out context positions that already received attention.
                contexts = (1 -
                            (attn > 0.003).float()).unsqueeze(-1) * contexts
            soft_score = F.softmax(output)
            if self.config.tmp_score:
                tmp_soft_score += [soft_score]
                if i == 1:
                    # Per-beam KL between the current step's distribution and
                    # the first step's, used below to rescore the beams.
                    kl_loss = np.array([])
                    for kk in range(self.config.beam_size):
                        kl_loss = np.append(
                            kl_loss,
                            F.kl_div(soft_score[kk],
                                     tmp_soft_score[0][kk]).data[0])
                    kl_loss = Variable(
                        torch.from_numpy(kl_loss).float().cuda().unsqueeze(-1))
            predicted = output.max(1)[1]
            if self.config.mask:
                # Accumulate already-predicted tokens to mask them out later.
                if mask is None:
                    mask = predicted.unsqueeze(1).long()
                else:
                    mask = torch.cat((mask, predicted.unsqueeze(1)), 1)
            # decOut: beam x rnn_size

            # (b) Compute a vector of batch*beam word scores.
            if self.config.tmp_score and i == 1:
                output = unbottle(
                    self.log_softmax(output) + self.config.tmp_score * kl_loss)
            else:
                output = unbottle(self.log_softmax(output))
            attn = unbottle(attn)
            hidden = unbottle(hidden)
            emb = unbottle(emb)
            # beam x tgt_vocab

            # (c) Advance each beam.
            # update state
            for j, b in enumerate(beam):
                b.advance(output.data[:, j], attn.data[:, j],
                          hidden.data[:, j], emb.data[:, j])
                b.beam_update(decState,
                              j)  # updates decState in place (by assignment, not by return)
                if self.config.ct_recu:
                    b.beam_update_context(
                        contexts, j)  # updates contexts in place (by assignment, not by return)
            # print "beam after decState:",decState[0].data.cpu().numpy().mean()

        # (3) Package everything up.
        allHyps, allScores, allAttn, allHiddens, allEmbs = [], [], [], [], []

        ind = range(batch_size)
        for j in ind:
            b = beam[j]
            n_best = 1
            scores, ks = b.sortFinished(minimum=n_best)
            hyps, attn, hiddens, embs = [], [], [], []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att, hidden, emb = b.getHyp(times, k)
                if self.config.relitu:
                    relitu_line(626, 1, att[0].cpu().numpy())
                    relitu_line(626, 1, att[1].cpu().numpy())
                hyps.append(hyp)
                attn.append(att.max(1)[1])
                hiddens.append(hidden)
                embs.append(emb)
            allHyps.append(hyps[0])
            allScores.append(scores[0])
            allAttn.append(attn[0])
            allHiddens.append(hiddens[0])
            allEmbs.append(embs[0])
        print allHyps  # Python-2 print statement

        if not self.config.global_emb:
            # Feed decoder hidden states to the separation model.
            outputs = Variable(torch.stack(allHiddens, 0).transpose(
                0, 1))  # to [decLen, bs, dim]
            if not self.config.hidden_mix:
                predicted_maps = self.ss_model(src, outputs[:-1, :], tgt[1:-1])
            else:
                ss_embs = Variable(torch.stack(allEmbs, 0).transpose(
                    0, 1))  # to [decLen, bs, dim]
                mix = torch.cat((outputs[:-1, :], ss_embs[1:]), dim=2)
                predicted_maps = self.ss_model(src, mix, tgt[1:-1])
            if self.config.top1:
                predicted_maps = predicted_maps[:, 0].unsqueeze(1)
        else:
            # Feed decoder embeddings to the separation model instead.
            ss_embs = Variable(torch.stack(allEmbs, 0).transpose(
                0, 1))  # to [decLen, bs, dim]
            if not self.config.top1:
                predicted_maps = self.ss_model(src, ss_embs[1:, :], tgt[1:-1],
                                               dict_spk2idx)
            else:
                predicted_maps = self.ss_model(src, ss_embs[1:2], tgt[1:2])
        return allHyps, allAttn, allHiddens, predicted_maps  #.transpose(0,1)
Example #18
0
def train_sdcn(dataset):
    """Train an SDCN (Structural Deep Clustering Network) on `dataset`.

    Builds a KNN graph, initializes cluster centers via k-means over the
    pretrained autoencoder's latent space, then for 200 epochs minimizes
    0.1 * KL(q || p) + 0.01 * KL(pred || p) + reconstruction MSE, where p
    is the sharpened target distribution recomputed each epoch.

    Relies on module-level globals: args, device, load_graph, SDCN, Adam,
    KMeans, eva, target_distribution.
    """
    # KNN Graph
    adj = load_graph(args.name, args.k)
    adj = adj.cuda()

    model = SDCN(dataset=dataset,
                 n_enc_1=500,
                 n_enc_2=500,
                 n_enc_3=2000,
                 n_dec_1=2000,
                 n_dec_2=500,
                 n_dec_3=500,
                 n_input=args.n_input,
                 n_z=args.n_z,
                 n_clusters=args.n_clusters,
                 v=1.0,
                 adj=adj).to(device)
    # print(model)
    # model.pretrain(args.pretrain_path)
    if args.name == 'cite':
        optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=5e-4)
    if args.name == 'dblp':
        optimizer = Adam(model.parameters(), lr=args.lr)

    # cluster parameter initiate
    data = torch.Tensor(dataset.x).to(device)
    y = dataset.y
    with torch.no_grad():
        _, _, _, _, z = model.ae(data)

    # K-means over the autoencoder latent space seeds the cluster layer.
    kmeans = KMeans(n_clusters=args.n_clusters, n_init=20)
    y_pred = kmeans.fit_predict(z.data.cpu().numpy())
    y_pred_last = y_pred
    model.cluster_layer.data = torch.tensor(kmeans.cluster_centers_).to(device)
    eva(y, y_pred, 'pae')
    # ## Visualization
    # with SummaryWriter(comment='model') as w:
    #     testdata = torch.rand(3327,3703).cuda()
    #     testadj = torch.rand(3327,3327).cuda()
    #     w.add_graph(model,(testdata,testadj))
    # ##
    for epoch in range(200):
        if epoch % 1 == 0:  # always true: target distribution p refreshed every epoch
            # update_interval
            # x_bar: decoder reconstruction
            # q: Student's-t soft cluster assignments
            # pred: class prediction after softmax
            # z: latent representation from the encoder
            _, tmp_q, pred, _, adj_re = model(data, adj)
            tmp_q = tmp_q.data
            p = target_distribution(tmp_q)

            res1 = tmp_q.cpu().numpy().argmax(1)  #Q
            res2 = pred.data.cpu().numpy().argmax(1)  #Z
            res3 = p.data.cpu().numpy().argmax(1)  #P
            # eva(y, res1, str(epoch) + 'Q')
            acc, nmi, ari, f1 = eva(y, res2, str(epoch) + 'Z')
            ## Visualization
            # writer.add_scalar('checkpoints/scalar', acc, epoch)
            # writer.add_scalar('checkpoints/scalar', nmi, epoch)
            # writer.add_scalar('checkpoints/scalar', ari, epoch)
            # writer.add_scalar('checkpoints/scalar', f1, epoch)
            ##
            # eva(y, res3, str(epoch) + 'P')

        # x_bar: decoder reconstruction
        # q: Student's-t soft cluster assignments
        # pred: class prediction after softmax
        # z: latent representation from the encoder

        x_bar, q, pred, _, adj_re = model(data, adj)

        # Clustering loss, prediction-alignment loss, reconstruction loss.
        kl_loss = F.kl_div(q.log(), p, reduction='batchmean')
        ce_loss = F.kl_div(pred.log(), p, reduction='batchmean')
        re_loss = F.mse_loss(x_bar, data)
        # print('kl_loss:{}       ce_loss:{}'.format(kl_loss, ce_loss))
        # print(p.shape)            # torch.Size([3327, 6])
        # print(q.log().shape)      # torch.Size([3327, 6])
        # print(pred.log().shape)   # torch.Size([3327, 6])
        loss = 0.1 * kl_loss + 0.01 * ce_loss + re_loss
        # adj_loss = norm*F.binary_cross_entropy(adj_re.view(-1).cpu(), adj_label.to_dense().view(-1).cpu(), weight = weight_tensor)
        # loss = 0.1 * kl_loss + 0.01 * ce_loss + re_loss + 0.1*adj_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Example #19
0
def kdloss(y, teacher_scores):
    """Knowledge-distillation loss: KL(softmax(teacher) || softmax(y)),
    summed over all elements and averaged over the batch dimension.

    Replaces the deprecated `size_average=False` argument with the
    equivalent `reduction='sum'` (identical value, no deprecation warning).
    """
    p = F.log_softmax(y, dim=1)
    q = F.softmax(teacher_scores, dim=1)
    l_kl = F.kl_div(p, q, reduction='sum') / y.shape[0]
    return l_kl
Example #20
0
def distillation(y, teacher_scores, labels, T, alpha):
    """Hinton-style distillation objective.

    Combines a temperature-scaled KL term (weighted by `alpha`, scaled by
    T^2, summed then averaged over the batch) with hard-label cross-entropy
    (weighted by 1 - alpha).
    """
    soft_targets = F.softmax(teacher_scores / T, dim=1)
    student_log_probs = F.log_softmax(y / T, dim=1)
    l_kl = F.kl_div(student_log_probs, soft_targets,
                    reduction='sum') * (T ** 2) / y.shape[0]
    l_ce = F.cross_entropy(y, labels)
    return l_kl * alpha + l_ce * (1. - alpha)
Example #21
0
 def reg_crit(self, x, y):
     """Regularization criterion: summed KL divergence KL(y || x), where
     `x` is a probability tensor (log is taken here) and `y` is the target
     distribution."""
     log_x = torch.log(x)
     return F.kl_div(log_x, target=y, reduction="sum")
Example #22
0
def main():
    """Semi-supervised training loop.

    Runs `args.total_steps` SGD steps. With `args.mix_up`, pseudo-labels
    for unlabeled images are produced in eval mode, inputs/targets are
    linearly interpolated (mix-up, Beta-sampled weights) and a batch-mean
    KL loss against the interpolated pseudo-targets is minimized;
    otherwise plain cross-entropy on labeled data. Periodically evaluates,
    checkpoints the best model, and logs to tensorboard.

    Relies on module-level globals: args, model, optimizer, train_loader,
    test_loader, beta_distribution, compute_lr, accuracy, evaluate,
    save_checkpoint, logger, writer, AverageMeter.
    """
    data_times, batch_times, losses, acc = [AverageMeter() for _ in range(4)]
    if args.mix_up:
        unlabel_acc = AverageMeter()
    best_acc = 0.
    model.train()
    logger.info("Start training...")
    for step in range(args.start_step, args.total_steps):
        # Load data and distribute to devices
        data_start = time.time()
        if args.mix_up:
            label_img, label_gt, unlabel_img, unlabel_gt = next(train_loader)
        else:
            label_img, label_gt = next(train_loader)

        if args.gpu:
            label_img = label_img.cuda()
            label_gt = label_gt.cuda()
        if args.mix_up:
            if args.gpu:
                unlabel_img = unlabel_img.cuda()
                unlabel_gt = unlabel_gt.cuda()
            _label_gt = F.one_hot(label_gt, num_classes=args.num_classes).float()
        data_end = time.time()

        # Compute learning rate
        lr = compute_lr(step)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        if args.mix_up:
            # Adopt mix-up augmentation: build interpolated inputs and
            # pseudo-targets under no_grad, in eval mode.
            model.eval()
            with torch.no_grad():
                label_pred = model(label_img)
                unlabel_pred = F.softmax(model(unlabel_img), dim=1) 
                alpha = beta_distribution.sample((args.batch_size,))
                if args.gpu:
                    alpha = alpha.cuda()
                _alpha = alpha.view(-1, 1, 1, 1)
                interp_img = (label_img * _alpha + unlabel_img * (1. - _alpha)).detach()
                interp_pseudo_gt = (_label_gt * alpha + unlabel_pred * (1. - alpha)).detach()
            model.train()
            interp_pred = model(interp_img)
            loss = F.kl_div(F.log_softmax(interp_pred, dim=1), interp_pseudo_gt, reduction='batchmean')
        else:
            # Regular label loss
            label_pred = model(label_img)
            loss = F.cross_entropy(label_pred, label_gt, reduction='mean')

        # One SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute accuracy
        top1, = accuracy(label_pred, label_gt, topk=(1,))
        if args.mix_up:
            unlabel_top1, = accuracy(unlabel_pred, unlabel_gt, topk=(1,))
        # Update AverageMeter stats
        data_times.update(data_end - data_start)
        batch_times.update(time.time() - data_end)
        losses.update(loss.item(), label_img.size(0))
        acc.update(top1.item(), label_img.size(0))
        if args.mix_up:
            unlabel_acc.update(unlabel_top1.item(), unlabel_img.size(0))

        # Print and log
        if step % args.print_freq == 0:
            logger.info("Step: [{0:05d}/{1:05d}] Dtime: {dtimes.avg:.3f} Btime: {btimes.avg:.3f} "
                        "loss: {losses.val:.3f} (avg {losses.avg:.3f}) Lacc: {label.val:.3f} (avg {label.avg:.3f}) "
                        "LR: {2:.4f}".format(step, args.total_steps, optimizer.param_groups[0]['lr'],
                                             dtimes=data_times, btimes=batch_times, losses=losses, label=acc))

        # Test and save model
        if (step + 1) % args.test_freq == 0 or step == args.total_steps - 1:
            val_acc = evaluate(test_loader, model)

            # Remember best accuracy and save checkpoint
            is_best = val_acc > best_acc
            if is_best:
                best_acc = val_acc
            logger.info("Best Accuracy: %.5f" % best_acc)
            save_checkpoint({
                'step': step + 1,
                'model': model.state_dict(),
                'best_acc': best_acc,
                'optimizer' : optimizer.state_dict()
                }, is_best, path=args.save_path, filename="checkpoint.pth")

            # Write to tfboard
            writer.add_scalar('train/label-acc', top1.item(), step)
            if args.mix_up:
                writer.add_scalar('train/unlabel-acc', unlabel_top1.item(), step)
            writer.add_scalar('train/loss', loss.item(), step)
            writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], step)
            writer.add_scalar('test/accuracy', val_acc, step)

            # Reset AverageMeters
            losses.reset()
            acc.reset()
            if args.mix_up:
                unlabel_acc.reset()
    def get_q_values(self, batch):
        """Compute per-word Q values and (optionally) KL-based prior rewards.

        States are whole conversations (each with several sentences) and an
        action is a sentence (a series of words). Q values are per word.
        Target Q values are over the next word in the sentence, or, at the
        end of a sentence, the first word in a new sentence after the user
        response.

        Returns:
            (q_values, prior_rewards): q_values is [total_words];
            prior_rewards is None unless model averaging / KL control is on.
        """
        actions = to_var(torch.LongTensor(batch['action']))  # [batch_size]

        # Prepare inputs to Q network: append the taken action (a sentence)
        # to each conversation, and extend the sentence-length lists.
        conversations = [
            np.concatenate((conv, np.atleast_2d(batch['action'][i])))
            for i, conv in enumerate(batch['state'])
        ]
        sent_lens = [
            np.concatenate((lens, np.atleast_1d(batch['action_lens'][i])))
            for i, lens in enumerate(batch['state_lens'])
        ]
        target_conversations = [conv[1:] for conv in conversations]
        conv_lens = [len(c) - 1 for c in conversations]
        if self.config.model not in VariationalModels:
            conversations = [conv[:-1] for conv in conversations]
            sent_lens = np.concatenate([l[:-1] for l in sent_lens])
        else:
            sent_lens = np.concatenate([l for l in sent_lens])
        conv_lens = to_var(torch.LongTensor(conv_lens))

        # Run Q network. Will produce [num_sentences, max sent len, vocab size]
        all_q_values = self.run_seq2seq_model(self.q_net, conversations,
                                              sent_lens, target_conversations,
                                              conv_lens)

        # Index to get only q values for actions taken (last sentence in each
        # conversation)
        start_q = torch.cumsum(
            torch.cat((to_var(conv_lens.data.new(1).zero_()), conv_lens[:-1])),
            0)
        conv_q_values = torch.stack([
            all_q_values[s + l - 1, :, :]
            for s, l in zip(start_q.data.tolist(), conv_lens.data.tolist())
        ], 0)  # [num_sentences, max_sent_len, vocab_size]

        # Limit by actual sentence length (remove padding) and flatten into
        # long list of words
        word_q_values = torch.cat([
            conv_q_values[i, :l, :] for i, l in enumerate(batch['action_lens'])
        ], 0)  # [total words, vocab_size]
        word_actions = torch.cat(
            [actions[i, :l] for i, l in enumerate(batch['action_lens'])],
            0)  # [total words]

        # Extract q values corresponding to actions taken
        q_values = word_q_values.gather(
            1, word_actions.unsqueeze(1)).squeeze()  # [total words]
        """ Compute KL metrics """
        prior_rewards = None

        # Get probabilities from policy network
        q_dists = torch.nn.functional.softmax(word_q_values, 1)
        q_probs = q_dists.gather(1, word_actions.unsqueeze(1)).squeeze()

        with torch.no_grad():
            # Run pretrained prior network.
            # [num_sentences, max sent len, vocab size]
            all_prior_logits = self.run_seq2seq_model(self.pretrained_prior,
                                                      conversations, sent_lens,
                                                      target_conversations,
                                                      conv_lens)

            # Get relevant actions. [num_sentences, max_sent_len, vocab_size]
            conv_prior = torch.stack([
                all_prior_logits[s + l - 1, :, :]
                for s, l in zip(start_q.data.tolist(), conv_lens.data.tolist())
            ], 0)

            # Limit by actual sentence length (remove padding) and flatten.
            # [total words, vocab_size]
            word_prior_logits = torch.cat([
                conv_prior[i, :l, :]
                for i, l in enumerate(batch['action_lens'])
            ], 0)

            # Take the softmax
            prior_dists = torch.nn.functional.softmax(word_prior_logits, 1)

            # FIX: `reduce=False` is deprecated in PyTorch; reduction='none'
            # is the documented equivalent (per-element KL terms).
            kl_div = F.kl_div(q_dists.log(), prior_dists, reduction='none')

            # [total words]
            prior_probs = prior_dists.gather(
                1, word_actions.unsqueeze(1)).squeeze()
            logp_logq = prior_probs.log() - q_probs.log()

            if self.config.model_averaging:
                model_avg_sentences = batch['model_averaged_probs']

                # Convert to tensors and flatten into [num_words]
                word_model_avg = torch.cat([
                    to_var(torch.FloatTensor(m)) for m in model_avg_sentences
                ], 0)

                # Compute KL from model-averaged prior
                prior_rewards = word_model_avg.log() - q_probs.log()

                # Clip because KL should never be negative, so because we
                # are subtracting KL, rewards should never be positive
                prior_rewards = torch.clamp(prior_rewards, max=0.0)

            elif self.config.kl_control and self.config.kl_calc == 'integral':
                # Note: we reward the negative KL divergence to ensure the
                # RL model stays close to the prior
                prior_rewards = -1.0 * torch.sum(kl_div, dim=1)
            elif self.config.kl_control:
                prior_rewards = logp_logq

            if self.config.kl_control:
                prior_rewards = prior_rewards * self.config.kl_weight_c
                self.kl_reward_batch_history.append(
                    torch.sum(prior_rewards).item())

            # Track all metrics
            self.kl_div_batch_history.append(torch.mean(kl_div).item())
            self.logp_batch_history.append(
                torch.mean(prior_probs.log()).item())
            self.logp_logq_batch_history.append(torch.mean(logp_logq).item())

        return q_values, prior_rewards
Example #24
0
def tts_train_loop_af_online(paths: Paths,
                             model: Tacotron,
                             model_tf: Tacotron,
                             optimizer,
                             train_set,
                             lr,
                             train_steps,
                             attn_example,
                             hp=None):
    """Attention-forcing Tacotron training loop (online variant).

    Each step, a teacher-forced reference model `model_tf` produces
    reference attention `attn_ref` (no grad); `model` is then trained to
    match the target mels (L1 on both decoder outputs) plus a KL-based
    attention loss against the smoothed reference attention, weighted by
    `hp.attn_loss_coeff`. Periodically saves checkpoints, attention plots,
    and spectrograms.
    """
    # setattr(model, 'mode', 'attention_forcing')
    # setattr(model, 'mode', 'teacher_forcing')
    # import pdb; pdb.set_trace()

    def smooth(d, eps=float(1e-10)):
        # Mix a tiny uniform mass into each attention row so that the
        # log/KL below never sees an exact zero.
        u = 1.0 / float(d.size()[2])
        return eps * u + (1 - eps) * d

    device = next(
        model.parameters()).device  # use same device as model parameters

    for g in optimizer.param_groups:
        g['lr'] = lr

    total_iters = len(train_set)
    epochs = train_steps // total_iters + 1

    for e in range(1, epochs + 1):

        start = time.time()
        running_loss_out, running_loss_attn = 0, 0

        # Perform 1 epoch
        for i, (x, m, ids, _) in enumerate(train_set, 1):
            # print(i)
            # import pdb; pdb.set_trace()

            x, m = x.to(device), m.to(device)
            # pdb.set_trace()

            # print(model.r, model_tf.r)
            # import pdb; pdb.set_trace()

            # Parallelize model onto GPUS using workaround due to python bug
            if device.type == 'cuda' and torch.cuda.device_count() > 1:
                with torch.no_grad():
                    _, _, attn_ref = data_parallel_workaround(model_tf, x, m)
                m1_hat, m2_hat, attention = data_parallel_workaround(
                    model, x, m, False, attn_ref)
            else:
                # Reference attention from the teacher-forced model, no grad.
                with torch.no_grad():
                    _, _, attn_ref = model_tf(x, m)
                # pdb.set_trace()

                # setattr(model, 'mode', 'teacher_forcing')
                # with torch.no_grad(): _, _, attn_ref = model(x, m)

                # setattr(model, 'mode', 'attention_forcing_online')
                m1_hat, m2_hat, attention = model(x,
                                                  m,
                                                  generate_gta=False,
                                                  attn_ref=attn_ref)
                # m1_hat, m2_hat, attention = model(x, m, generate_gta=False, attn_ref=None)
                # pdb.set_trace()

            # print(x.size())
            # print(m.size())
            # print(m1_hat.size(), m2_hat.size())
            # print(attention.size(), attention.size(1)*model.r)
            # print(attn_ref.size())
            # pdb.set_trace()

            # Output losses on both decoder stages; attention loss is a KL
            # between smoothed predicted and reference attention, summed over
            # the attention axis then averaged.
            m1_loss = F.l1_loss(m1_hat, m)
            m2_loss = F.l1_loss(m2_hat, m)
            attn_loss = F.kl_div(torch.log(smooth(attention)),
                                 smooth(attn_ref),
                                 reduction='none')  # 'batchmean'
            attn_loss = attn_loss.sum(2).mean()
            # attn_loss = F.l1_loss(smooth(attention), smooth(attn_ref))

            loss_out = m1_loss + m2_loss
            loss_attn = attn_loss * hp.attn_loss_coeff
            loss = loss_out + loss_attn

            # if i%100==0:
            #     save_attention(np_now(attn_ref[0][:, :160]), paths.tts_attention/f'asup_{step}_tf')
            #     save_attention(np_now(attention[0][:, :160]), paths.tts_attention/f'asup_{step}_af')

            #     model_tf.r = 2
            #     with torch.no_grad(): _, _, attn_ref = model_tf(x, m)
            #     save_attention(np_now(attn_ref[0][:, :160]), paths.tts_attention/f'asup_{step}_tf_r2')
            #     model_tf.r = model.r
            #     pdb.set_trace()

            optimizer.zero_grad()
            loss.backward()
            if hp.tts_clip_grad_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hp.tts_clip_grad_norm)
                if np.isnan(grad_norm):
                    print('grad_norm was NaN!')

            optimizer.step()

            running_loss_out += loss_out.item()
            avg_loss_out = running_loss_out / i
            running_loss_attn += loss_attn.item()
            avg_loss_attn = running_loss_attn / i

            speed = i / (time.time() - start)

            step = model.get_step()
            k = step // 1000

            if step % hp.tts_checkpoint_every == 0:
                ckpt_name = f'taco_step{k}K'
                save_checkpoint('tts',
                                paths,
                                model,
                                optimizer,
                                name=ckpt_name,
                                is_silent=True)

            # Save attention/spectrogram plots for the designated example id.
            if attn_example in ids:
                idx = ids.index(attn_example)
                save_attention(np_now(attn_ref[idx][:, :160]),
                               paths.tts_attention / f'{step}_tf')
                save_attention(np_now(attention[idx][:, :160]),
                               paths.tts_attention / f'{step}_af')
                save_spectrogram(np_now(m2_hat[idx]),
                                 paths.tts_mel_plot / f'{step}', 600)

            msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss_out: {avg_loss_out:#.4}; Loss_attn: {avg_loss_attn:#.4} | {speed:#.2} steps/s | Step: {k}k | '
            stream(msg)

        # Must save latest optimizer state to ensure that resuming training
        # doesn't produce artifacts
        save_checkpoint('tts', paths, model, optimizer, is_silent=True)
        model.log(paths.tts_log, msg)
        print(' ')
    def _get_loss_f(self, x, y, targeted, reduction):
        #x, y original ref_data / target
        #targeted whether to use a targeted attack or not
        #reduction: reduction to use: 'sum', 'mean', 'none'
        if isinstance(self.loss, str):
            if self.loss.lower() in ['crossentropy', 'ce']:
                if not targeted:
                    l_f = lambda data, data_out: -F.cross_entropy(
                        data_out, y, reduction=reduction)
                else:
                    l_f = lambda data, data_out: F.cross_entropy(
                        data_out, y, reduction=reduction)
            elif self.loss.lower() == 'kl':
                if not targeted:
                    l_f = lambda data, data_out: -reduce(
                        F.kl_div(torch.log_softmax(data_out, dim=1),
                                 y,
                                 reduction='none').sum(dim=1), reduction)
                else:
                    l_f = lambda data, data_out: reduce(
                        F.kl_div(torch.log_softmax(data_out, dim=1),
                                 y,
                                 reduction='none').sum(dim=1), reduction)
            elif self.loss.lower() == 'logitsdiff':
                if not targeted:
                    y_oh = F.one_hot(y, self.num_classes)
                    y_oh = y_oh.float()
                    l_f = lambda data, data_out: -logits_diff_loss(
                        data_out, y_oh, reduction=reduction)
                else:
                    y_oh = F.one_hot(y, self.num_classes)
                    y_oh = y_oh.float()
                    l_f = lambda data, data_out: logits_diff_loss(
                        data_out, y_oh, reduction=reduction)
            elif self.loss.lower() == 'conf':
                if not targeted:
                    l_f = lambda data, data_out: confidence_loss(
                        data_out, y, reduction=reduction)
                else:
                    l_f = lambda data, data_out: -confidence_loss(
                        data_out, y, reduction=reduction)
            elif self.loss.lower() == 'confdiff':
                if not targeted:
                    y_oh = F.one_hot(y, self.num_classes)
                    y_oh = y_oh.float()
                    l_f = lambda data, data_out: -conf_diff_loss(
                        data_out, y_oh, reduction=reduction)
                else:
                    y_oh = F.one_hot(y, self.num_classes)
                    y_oh = y_oh.float()
                    l_f = lambda data, data_out: conf_diff_loss(
                        data_out, y_oh, reduction=reduction)
            else:
                raise ValueError(f'Loss {self.loss} not supported')
        else:
            #custom 5 argument loss
            #(x_adv, x_adv_out, x, y, reduction)
            l_f = lambda data, data_out: self.loss(
                data, data_out, x, y, reduction=reduction)

        return l_f
Example #26
0
def distillation(y, teacher_scores, labels, T, alpha):
    """Knowledge-distillation loss (Hinton et al., 2015).

    y: student logits, (batch, classes).
    teacher_scores: teacher logits, same shape.
    labels: hard ground-truth class indices, (batch,).
    T: softmax temperature; the soft term is scaled by T*T so gradient
       magnitudes stay comparable across temperatures.
    alpha: weight of the soft (teacher) term vs. the hard-label term.
    """
    # dim=1 made explicit (implicit softmax dim is deprecated).
    # reduction='batchmean' is the mathematically correct KL reduction
    # (sum over classes, mean over the batch); the previous default 'mean'
    # also averaged over the class dimension, silently shrinking the loss.
    soft = F.kl_div(F.log_softmax(y / T, dim=1),
                    F.softmax(teacher_scores / T, dim=1),
                    reduction='batchmean')
    return soft * (T * T * 2. * alpha) \
            + F.cross_entropy(y, labels) * (1. - alpha)
Example #27
0
    def trn_step(self, epoch, sample_l, sample_u, scaler=None):
        """Run one semi-supervised training step (consistency training).

        epoch: current epoch; used for the 'gradual' ramp-up of the
            consistency-loss weight.
        sample_l: labeled batch dict with 'image' and 'label' tensors.
        sample_u: unlabeled batch dict with 'image_ori' (original view)
            and 'image_aug' (augmented view).
        scaler: optional AMP GradScaler; when given, the forward pass runs
            under autocast and gradients are unscaled before clipping.

        Returns a dict with labeled loss, total loss, KL consistency loss,
        batch accuracy and batch F1 (losses are returned as tensors).
        """
        self.optim.zero_grad()

        images_l = sample_l['image'].to(self.cfg.device)
        images_o = sample_u["image_ori"].to(self.cfg.device)
        images_a = sample_u["image_aug"].to(self.cfg.device)
        labels = sample_l['label'].to(self.cfg.device)

        batch_s = images_l.size(0)
        # One forward pass over labeled + both unlabeled views, split after.
        images_t = torch.cat([images_l, images_o, images_a])
        if scaler is not None:
            with autocast():
                logits_t = self.model(images_t)

                logits_l = logits_t[:batch_s]
                logits_o, logits_a = logits_t[batch_s:].chunk(2)
                del logits_t

                # Consistency target: distribution of the original view,
                # detached so only the augmented branch receives gradients.
                preds_o = F.softmax(logits_o, dim=-1).detach()
                preds_a = F.log_softmax(logits_a, dim=-1)
                kl_loss = F.kl_div(preds_a, preds_o, reduction='none')
                # Sum over classes, mean over the batch -> scalar.
                kl_loss = torch.mean(torch.sum(kl_loss, dim=-1))

                l_loss = self.trn_crit(logits_l, labels)

                # NOTE(review): kl_loss is already a scalar, so the extra
                # torch.mean below is a no-op (the non-AMP branch omits it).
                # If ratio_mode is neither 'constant' nor 'gradual', t_loss
                # is never assigned and the backward below raises NameError.
                if self.cfg.ratio_mode == 'constant':
                    t_loss = l_loss + self.cfg.ratio * torch.mean(kl_loss)
                elif self.cfg.ratio_mode == "gradual":
                    t_loss = epoch / self.cfg.t_epoch * self.cfg.ratio * torch.mean(
                        kl_loss) + l_loss

            scaler.scale(t_loss).backward()
            # Clipping point -> AGC, acts as a replacement for batchnorm.
            # Gradients must be unscaled before adaptive clipping.
            scaler.unscale_(self.optim)
            if self.cfg.clipping:
                timm.utils.adaptive_clip_grad(self.model.parameters())
            scaler.step(self.optim)
            scaler.update()
        else:
            # Same computation without autocast / gradient scaling.
            logits_t = self.model(images_t)
            logits_l = logits_t[:batch_s]
            logits_o, logits_a = logits_t[batch_s:].chunk(2)
            del logits_t

            preds_o = F.softmax(logits_o, dim=-1).detach()
            preds_a = F.log_softmax(logits_a, dim=-1)
            kl_loss = F.kl_div(preds_a, preds_o, reduction='none')
            kl_loss = torch.mean(torch.sum(kl_loss, dim=-1))

            l_loss = self.trn_crit(logits_l, labels)

            if self.cfg.ratio_mode == 'constant':
                t_loss = l_loss + self.cfg.ratio * kl_loss
            elif self.cfg.ratio_mode == "gradual":
                t_loss = epoch / self.cfg.t_epoch * self.cfg.ratio * kl_loss + l_loss

            t_loss.backward()

            if self.cfg.clipping:
                timm.utils.adaptive_clip_grad(self.model.parameters())

            self.optim.step()

        # Metrics are computed on the labeled logits only.
        batch_acc = self.accuracy(logits_l, labels)
        batch_f1 = self.f1_score(logits_l, labels)
        result = {
            'l_loss': l_loss,
            't_loss': t_loss,
            'kl_loss': kl_loss,
            'batch_acc': batch_acc,
            'batch_f1': batch_f1
        }
        return result
Example #28
0
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, sentence_span_list=None,
                sentence_labels=None, sentence_ids=None, sentence_prob=None, max_sentences=0):
        """Multiple-choice reading-comprehension forward pass with sentence-level
        evidence attention.

        Inputs arrive as (batch, num_choices, seq_len) and are flattened to
        (batch * num_choices, seq_len) before BERT. The BERT hidden state is
        assumed to hold two concatenated "views"; this module keeps only the
        slice belonging to ``self.view_id``. Returns a dict that always has
        'loss' when labels are given, plus detached logits at eval time.
        """
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        seq_output, _ = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
        # The hidden dimension contains two views; select this module's half.
        origin_hidden_size = int(seq_output.size(2) / 2)
        seq_output = seq_output[:, :, self.view_id * origin_hidden_size: (self.view_id + 1) * origin_hidden_size]

        # mask: 1 for masked value and 0 for true value
        # doc, que, doc_mask, que_mask = layers.split_doc_que(sequence_output, token_type_ids, attention_mask)
        doc_sen, que, doc_sen_mask, que_mask, sentence_mask = \
            rep_layers.split_doc_sen_que(seq_output, flat_token_type_ids, flat_attention_mask, sentence_span_list,
                                         max_sentences=max_sentences)
        # doc_sen_mask = 1 - doc_sen_mask
        # que_mask = 1 - que_mask
        # sentence_mask = 1 - sentence_mask
        # assert doc_sen.sum() != torch.nan
        batch, max_sen, doc_len = doc_sen_mask.size()
        assert max_sen == max_sentences

        # Query vector from self-attention over the question tokens.
        que_vec = self.que_self_attn(que, que_mask).view(batch, 1, -1)

        # Word-level attention: question vector vs. every document token.
        doc = doc_sen.reshape(batch, max_sen * doc_len, -1)
        word_sim = self.word_similarity(que_vec, doc).view(batch * max_sen, doc_len)
        doc = doc_sen.reshape(batch * max_sen, doc_len, -1)
        doc_mask = doc_sen_mask.reshape(batch * max_sen, doc_len)
        word_hidden = rep_layers.masked_softmax(word_sim, doc_mask, dim=1).unsqueeze(1).bmm(doc)

        word_hidden = word_hidden.view(batch, max_sen, -1)

        # Sentence-level vectors via self-attention over each sentence.
        doc_vecs = self.doc_sen_self_attn(doc, doc_mask).view(batch, max_sen, -1)

        # Sentence-level attention over sentences, weighted by the question.
        sentence_sim = self.vector_similarity(que_vec, doc_vecs)
        sentence_alpha = rep_layers.masked_softmax(sentence_sim, sentence_mask)
        sentence_hidden = sentence_alpha.bmm(word_hidden).squeeze(1)
        choice_logits = self.classifier(torch.cat([sentence_hidden, que_vec.squeeze(1)], dim=1)).reshape(-1, self.num_choices)

        if self.training:
            output_dict = {}
        else:
            output_dict = {
                'choice_logits': torch.softmax(choice_logits, dim=-1).detach().cpu().float(),
                'sentence_logits': sentence_alpha.reshape(choice_logits.size(0), self.num_choices, max_sen).detach().cpu().float()
            }

        if labels is not None:
            choice_loss = functional.cross_entropy(choice_logits, labels)
            loss = choice_loss
            if self.multi_evidence and sentence_prob is not None:
                # Auxiliary evidence loss against soft sentence probabilities,
                # restricted to positions with positive target probability.
                sentence_prob = sentence_prob.reshape(batch, -1).to(sentence_sim.dtype)
                true_prob_mask = ((sentence_prob > 0).to(dtype=sentence_sim.dtype)) * sentence_mask.to(dtype=sentence_sim.dtype)
                # NOTE(review): F.kl_div expects log-probabilities as its first
                # argument, but sentence_alpha is a (masked) softmax output in
                # probability space — confirm whether a .log() is missing here.
                kl_sentence_loss = functional.kl_div(sentence_alpha * true_prob_mask, sentence_prob * true_prob_mask, reduction='sum')
                output_dict['sentence_loss'] = kl_sentence_loss.item()
                loss += self.evidence_lam * kl_sentence_loss / choice_logits.size(0)
            elif not self.multi_evidence and sentence_ids is not None:
                # logger.info(f'sentence_ids total number: {(sentence_ids != -1).sum().item()}')
                # logger.info('sentence_mask.sum() = ', sentence_mask.sum())
                assert sentence_mask.sum() != 0, sentence_mask.sum()
                # assert all(x < sentence_mask.sum() for x in sentence_ids.view(batch).detach().tolist())
                # NOTE(review): `assertion` is computed but never used/checked.
                assertion = (sentence_ids.view(batch) >= sentence_mask.sum(dim=-1)).sum()
                
                # Hard single-evidence supervision: NLL of the gold sentence id.
                log_masked_sentence_prob = rep_layers.masked_log_softmax(sentence_sim.squeeze(1), sentence_mask)
                sentence_loss = functional.nll_loss(log_masked_sentence_prob, sentence_ids.view(batch), reduction='sum',
                                                    ignore_index=-1)
                # sentence_loss = functional.cross_entropy(sentence_sim.squeeze(1), sentence_ids.view(batch), reduction='sum',
                #                                          ignore_index=-1)
                # logger.info(f'sentence loss: {sentence_loss.item()}')
                loss += self.evidence_lam * sentence_loss / choice_logits.size(0)
                output_dict['sentence_loss'] = sentence_loss.item()
            output_dict['loss'] = loss
        return output_dict
Example #29
0
def train(train_loader, model, optimizer, epoch, lr_schedule, half=False):
    """Train `model` for one epoch using the mixing strategy selected by
    ``configs.TRAIN.methods``.

    Supported methods: 'nat' (plain), 'input' (input mixup), 'cutmix',
    'augmix' (with a Jensen-Shannon consistency term), 'manifold'
    (manifold mixup) and 'graphcut' (saliency-guided mixing via a mask).

    half: when True, losses are scaled through apex ``amp`` for fp16.

    NOTE(review): `criterion`, `criterion_batch`, `configs`, `amp` and the
    helpers (`mixup_box`, `get_mask`, `accuracy`, `AverageMeter`) are
    module-level names defined outside this function.
    """
    # Per-channel normalization constants, broadcast to image shape on GPU.
    mean = torch.Tensor(
        np.array(configs.TRAIN.mean)[:, np.newaxis, np.newaxis])
    mean = mean.expand(3, configs.DATA.crop_size,
                       configs.DATA.crop_size).cuda()
    std = torch.Tensor(np.array(configs.TRAIN.std)[:, np.newaxis, np.newaxis])
    std = std.expand(3, configs.DATA.crop_size, configs.DATA.crop_size).cuda()

    # Initialize the meters
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to train mode
    model.train()
    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        if configs.TRAIN.methods != 'augmix':
            input = input.cuda(non_blocking=True)
        else:
            # augmix loaders yield a tuple of 3 views (clean, aug1, aug2);
            # concatenate them into one batch for a single forward pass.
            input = torch.cat(input, 0).cuda(non_blocking=True)

        target = target.cuda(non_blocking=True)
        data_time.update(time.time() - end)

        # update learning rate (fractional-epoch schedule)
        lr = lr_schedule(epoch + (i + 1) / len(train_loader))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        optimizer.zero_grad()

        # In-place normalization; mixing coefficient from a Beta prior.
        input.sub_(mean).div_(std)
        lam = np.random.beta(configs.TRAIN.alpha, configs.TRAIN.alpha)
        if configs.TRAIN.methods == 'manifold' or configs.TRAIN.methods == 'graphcut':
            # Permute within each quarter of the batch independently.
            permuted_idx1 = np.random.permutation(input.size(0) // 4)
            permuted_idx2 = permuted_idx1 + input.size(0) // 4
            permuted_idx3 = permuted_idx2 + input.size(0) // 4
            permuted_idx4 = permuted_idx3 + input.size(0) // 4
            permuted_idx = np.concatenate(
                [permuted_idx1, permuted_idx2, permuted_idx3, permuted_idx4],
                axis=0)
        else:
            permuted_idx = torch.tensor(np.random.permutation(input.size(0)))

        if configs.TRAIN.methods == 'input':
            # Classic input mixup.
            input = lam * input + (1 - lam) * input[permuted_idx]

        elif configs.TRAIN.methods == 'cutmix':
            input, lam = mixup_box(input, lam=lam, permuted_idx=permuted_idx)

        elif configs.TRAIN.methods == 'augmix':
            logit = model(input)
            logit_clean, logit_aug1, logit_aug2 = torch.split(
                logit,
                logit.size(0) // 3)
            output = logit_clean

            p_clean = F.softmax(logit_clean, dim=1)
            p_aug1 = F.softmax(logit_aug1, dim=1)
            p_aug2 = F.softmax(logit_aug2, dim=1)

            # Jensen-Shannon consistency: KL of each view to the clamped
            # log-mixture distribution.
            p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7,
                                    1).log()
            loss_JSD = 4 * (
                F.kl_div(p_mixture, p_clean, reduction='batchmean') +
                F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
                F.kl_div(p_mixture, p_aug2, reduction='batchmean'))

        elif configs.TRAIN.methods == 'graphcut':
            # First pass computes per-pixel saliency (grad magnitude), which
            # guides the graph-cut mixing mask in the second pass.
            input_var = Variable(input, requires_grad=True)

            output = model(input_var)
            loss_clean = criterion(output, target)

            if half:
                with amp.scale_loss(loss_clean, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_clean.backward()
            unary = torch.sqrt(torch.mean(input_var.grad**2, dim=1))

            block_num = 2**(np.random.randint(1, 5))
            mask = get_mask(input,
                            unary,
                            block_num,
                            permuted_idx,
                            alpha=lam,
                            mean=mean,
                            std=std)
            output, lam = model(input,
                                graphcut=True,
                                permuted_idx=permuted_idx1,
                                block_num=block_num,
                                mask=mask,
                                unary=unary)

        if configs.TRAIN.methods == 'manifold':
            output = model(input,
                           manifold=True,
                           lam=lam,
                           permuted_idx=permuted_idx1)
        elif configs.TRAIN.methods != 'augmix' and configs.TRAIN.methods != 'graphcut':
            output = model(input)

        if configs.TRAIN.methods == 'nat':
            loss = criterion(output, target)
        elif configs.TRAIN.methods == 'augmix':
            loss = criterion(output, target) + loss_JSD
        else:
            # Mixup-style loss: convex combination of losses against the
            # original and permuted targets.
            loss = lam * criterion_batch(output, target) + (
                1 - lam) * criterion_batch(output, target[permuted_idx])
            loss = torch.mean(loss)

        # compute gradient and do SGD step
        #optimizer.zero_grad()
        if half:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        optimizer.step()

        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % configs.TRAIN.print_freq == 0:
            print('Train Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {cls_loss.val:.4f} ({cls_loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                  'LR {lr:.3f}'.format(epoch,
                                       i,
                                       len(train_loader),
                                       batch_time=batch_time,
                                       data_time=data_time,
                                       top1=top1,
                                       top5=top5,
                                       cls_loss=losses,
                                       lr=lr))
            sys.stdout.flush()
Example #30
0
        err_d = err_d_real + err_d_fake
        d_optimizer.step()

        ########################
        # (2) Update G network #
        ########################
        generator.zero_grad()

        output = discriminator(fake_images)
        err_g = criterion(output, real_label)
        D_G_z2 = output.mean().item()

        if args.classifier:
            output = F.log_softmax(classifier(fake_images), dim=1)
            uniform_dist = torch.zeros((true_batch_size, 10), device=device).fill_((1. / 10))
            err_kl = args.kl_beta * F.kl_div(output, uniform_dist, reduction='batchmean')
            err_g += err_kl

        err_g.backward()
        g_optimizer.step()

        if i % 100 == 99:
            if args.classifier:
                print(
                    '[%3d|%3d] Loss_D: %.4f Loss_G: %.4f Loss_C: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                    % (epoch, i + 1, err_d.item(), err_g.item(), err_kl.item(), D_x, D_G_z1, D_G_z2)
                )
            else:
                print(
                    '[%3d|%3d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                    % (epoch, i + 1, err_d.item(), err_g.item(), D_x, D_G_z1, D_G_z2)
def distillation(y, teacher_scores, T):
    """Pure soft-target distillation term (no hard-label component).

    y: student logits, (batch, classes).
    teacher_scores: teacher logits, same shape.
    T: softmax temperature; T**2 rescales gradients to be roughly
       temperature-independent.

    Returns the per-sample mean of the summed KL divergence.
    """
    p = F.log_softmax(y / T, dim=1)
    q = F.softmax(teacher_scores / T, dim=1)
    # reduction='sum' replaces the deprecated size_average=False
    # (identical value); dividing by the batch size gives a per-sample mean.
    l_kl = F.kl_div(p, q, reduction='sum') * (T ** 2) / y.shape[0]
    return l_kl
Example #32
0
def cross_entropy(p, q):
    """Distribution-matching loss between two logit tensors of shape
    (batch, classes).

    NOTE(review): despite the name, this returns sum KL(softmax(p) || softmax(q)),
    i.e. cross-entropy minus the (q-independent) entropy of softmax(p) — the
    gradient w.r.t. q is identical to that of the true cross-entropy.
    """
    # dim=1 made explicit: the implicit softmax dim is deprecated, and this
    # loss operates on (batch, classes) logits.
    return F.kl_div(F.log_softmax(q, dim=1), F.softmax(p, dim=1), reduction='sum')
  def thread_target(i):
    """Per-model worker loop for deep mutual learning with `cfg.num_models`
    peer models trained in lockstep.

    Phase 1 computes each model's own features/losses and publishes its
    probabilities and distance matrices to shared lists; phase 2 (after all
    workers sync via `run_event2`) adds the mutual losses against every
    peer's detached outputs and steps the optimizer. Coordination uses the
    outer-scope events (`exit_event`, `run_event1`, `run_event2`) and shared
    lists (`probs_list`, `g_dist_mat_list`, ..., `done_list1/2`).

    i: index of this worker's model/optimizer/data slice.
    """
    while not exit_event.isSet():
      # If the run event is not set, the thread just waits.
      if not run_event1.wait(0.001): continue

      ######################################
      # Phase 1: Forward and Separate Loss #
      ######################################

      TVT = TVTs[i]
      model_w = model_ws[i]
      ims = ims_list[i]
      labels = labels_list[i]
      optimizer = optimizers[i]

      ims_var = Variable(TVT(torch.from_numpy(ims).float()))
      labels_t = TVT(torch.from_numpy(labels).long())
      labels_var = Variable(labels_t)

      global_feat, local_feat, logits = model_w(ims_var)
      # NOTE(review): implicit softmax dim is deprecated; for 2-D logits it
      # resolves to dim=1.
      probs = F.softmax(logits)
      log_probs = F.log_softmax(logits)

      g_loss, p_inds, n_inds, g_dist_ap, g_dist_an, g_dist_mat = global_loss(
        g_tri_loss, global_feat, labels_t,
        normalize_feature=cfg.normalize_feature)

      if cfg.l_loss_weight == 0:
        l_loss, l_dist_mat = 0, 0
      elif cfg.local_dist_own_hard_sample:
        # Let local distance find its own hard samples.
        l_loss, l_dist_ap, l_dist_an, l_dist_mat = local_loss(
          l_tri_loss, local_feat, None, None, labels_t,
          normalize_feature=cfg.normalize_feature)
      else:
        # Reuse the hard-sample indices chosen by the global loss.
        l_loss, l_dist_ap, l_dist_an = local_loss(
          l_tri_loss, local_feat, p_inds, n_inds, labels_t,
          normalize_feature=cfg.normalize_feature)
        l_dist_mat = 0

      id_loss = 0
      if cfg.id_loss_weight > 0:
        id_loss = id_criterion(logits, labels_var)

      # Publish this model's outputs for the peers' mutual losses.
      probs_list[i] = probs
      g_dist_mat_list[i] = g_dist_mat
      l_dist_mat_list[i] = l_dist_mat

      done_list1[i] = True

      # Wait for event to be set, meanwhile checking if need to exit.
      while True:
        phase2_ready = run_event2.wait(0.001)
        if exit_event.isSet():
          return
        if phase2_ready:
          break

      #####################################
      # Phase 2: Mutual Loss and Backward #
      #####################################

      # Probability Mutual Loss (KL Loss)
      pm_loss = 0
      if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
        for j in range(cfg.num_models):
          if j != i:
            # Third positional arg is size_average=False (deprecated form),
            # i.e. an un-averaged sum; normalized by batch size below.
            pm_loss += F.kl_div(log_probs, TVT(probs_list[j]).detach(), False)
        pm_loss /= 1. * (cfg.num_models - 1) * len(ims)

      # Global Distance Mutual Loss (L2 Loss)
      gdm_loss = 0
      if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
        for j in range(cfg.num_models):
          if j != i:
            gdm_loss += torch.sum(torch.pow(
              g_dist_mat - TVT(g_dist_mat_list[j]).detach(), 2))
        gdm_loss /= 1. * (cfg.num_models - 1) * len(ims) * len(ims)

      # Local Distance Mutual Loss (L2 Loss)
      ldm_loss = 0
      if (cfg.num_models > 1) \
          and cfg.local_dist_own_hard_sample \
          and (cfg.ldm_loss_weight > 0):
        for j in range(cfg.num_models):
          if j != i:
            ldm_loss += torch.sum(torch.pow(
              l_dist_mat - TVT(l_dist_mat_list[j]).detach(), 2))
        ldm_loss /= 1. * (cfg.num_models - 1) * len(ims) * len(ims)

      # Weighted sum of individual and mutual losses.
      loss = g_loss * cfg.g_loss_weight \
             + l_loss * cfg.l_loss_weight \
             + id_loss * cfg.id_loss_weight \
             + pm_loss * cfg.pm_loss_weight \
             + gdm_loss * cfg.gdm_loss_weight \
             + ldm_loss * cfg.ldm_loss_weight

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      ##################################
      # Step Log For One of the Models #
      ##################################

      # These meters are outer-scope variables

      # Just record for the first model
      if i == 0:

        # precision
        g_prec = (g_dist_an > g_dist_ap).data.float().mean()
        # the proportion of triplets that satisfy margin
        g_m = (g_dist_an > g_dist_ap + cfg.global_margin).data.float().mean()
        g_d_ap = g_dist_ap.data.mean()
        g_d_an = g_dist_an.data.mean()

        g_prec_meter.update(g_prec)
        g_m_meter.update(g_m)
        g_dist_ap_meter.update(g_d_ap)
        g_dist_an_meter.update(g_d_an)
        g_loss_meter.update(to_scalar(g_loss))

        if cfg.l_loss_weight > 0:
          # precision
          l_prec = (l_dist_an > l_dist_ap).data.float().mean()
          # the proportion of triplets that satisfy margin
          l_m = (l_dist_an > l_dist_ap + cfg.local_margin).data.float().mean()
          l_d_ap = l_dist_ap.data.mean()
          l_d_an = l_dist_an.data.mean()

          l_prec_meter.update(l_prec)
          l_m_meter.update(l_m)
          l_dist_ap_meter.update(l_d_ap)
          l_dist_an_meter.update(l_d_an)
          l_loss_meter.update(to_scalar(l_loss))

        if cfg.id_loss_weight > 0:
          id_loss_meter.update(to_scalar(id_loss))

        if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
          pm_loss_meter.update(to_scalar(pm_loss))

        if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
          gdm_loss_meter.update(to_scalar(gdm_loss))

        if (cfg.num_models > 1) \
            and cfg.local_dist_own_hard_sample \
            and (cfg.ldm_loss_weight > 0):
          ldm_loss_meter.update(to_scalar(ldm_loss))

        loss_meter.update(to_scalar(loss))

      ###################
      # End Up One Step #
      ###################

      run_event1.clear()
      run_event2.clear()

      done_list2[i] = True
Example #34
0
def train(trainloader, model, criterion, optimizer, epoch, use_cuda, teacher):
    """Train `model` for one epoch while distilling from `teacher`.

    Blends a soft distillation loss (MSE on logits when args.caruana,
    otherwise temperature-scaled KL) with the hard-label `criterion`, and
    masks the gradient of zero-valued conv weights so pruned connections
    stay pruned.

    Returns (average loss, average top-1 accuracy).
    """
    # switch to train mode
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    bar = Bar('Processing', max=len(trainloader))
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            # Fixed: `async=True` is a syntax error on Python 3 (`async` is a
            # keyword); `non_blocking=True` is the supported spelling.
            inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
        inputs, targets = torch.autograd.Variable(
            inputs), torch.autograd.Variable(targets)

        # compute output
        outputs = model(inputs)
        # outputs --> logits batch_size x 10
        teacher_outputs = teacher(inputs)

        if args.caruana:
            # Logit matching (Ba & Caruana): plain MSE on raw logits.
            soft_loss = F.mse_loss(outputs,
                                   teacher_outputs.detach(),
                                   reduction="mean")
        else:
            # Hinton distillation: KL between temperature-softened
            # distributions, summed then normalized by batch size.
            # reduction="sum" replaces the deprecated size_average=False.
            soft_log_probs = F.log_softmax(outputs / args.temperature, dim=1)
            soft_targets = F.softmax(teacher_outputs / args.temperature, dim=1)

            soft_loss = F.kl_div(soft_log_probs,
                                 soft_targets.detach(),
                                 reduction="sum") / outputs.shape[0]

        hard_loss = criterion(outputs, targets)

        # Fixed: the two terms form a convex combination and must be added;
        # the original multiplied them (alpha*soft * (1-alpha)*hard).
        loss = args.alpha * soft_loss + (1 - args.alpha) * hard_loss
        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.data.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        #-----------------------------------------
        # Zero the gradient wherever a conv weight is exactly zero so the
        # optimizer never revives pruned weights.
        for k, m in enumerate(model.modules()):
            # print(k, m)
            if isinstance(m, nn.Conv2d):
                weight_copy = m.weight.data.abs().clone()
                mask = weight_copy.gt(0).float().cuda()
                m.weight.grad.data.mul_(mask)
        #-----------------------------------------
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
            batch=batch_idx + 1,
            size=len(trainloader),
            data=data_time.avg,
            bt=batch_time.avg,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=losses.avg,
            top1=top1.avg,
            top5=top5.avg,
        )
        bar.next()
    bar.finish()
    return (losses.avg, top1.avg)