Code example #1
    def train(self):
        console_header = 'Epoch\tTrain_Loss\tTrain_Accuracy\tTest_Accuracy\tEpoch_Runtime\tLearning_Rate'
        print_to_console(console_header)
        print_to_logfile(self._logfile, console_header, init=True)

        for t in range(self._start_epoch, self._epochs):
            epoch_start = time.time()
            self._scheduler.step(epoch=t)
            # reset average meters
            self._train_loss.reset()
            self._train_accuracy.reset()

            self._net.train(True)
            self.single_epoch_training(t)
            test_accuracy = evaluate(self._test_loader, self._net)

            lr = get_lr_from_optimizer(self._optimizer)

            if test_accuracy > self._best_accuracy:
                self._best_accuracy = test_accuracy
                self._best_epoch = t + 1
                torch.save(self._net.state_dict(), 'model/step{}_best_epoch.pth'.format(self._step))
                # print('*', end='')
            epoch_end = time.time()
            single_epoch_runtime = epoch_end - epoch_start
            # Logging
            console_content = '{:05d}\t{:10.4f}\t{:14.4f}\t{:13.4f}\t{:13.2f}\t{:13.1e}'.format(
                t + 1, self._train_loss.avg, self._train_accuracy.avg, test_accuracy, single_epoch_runtime, lr)
            print_to_console(console_content)
            print_to_logfile(self._logfile, console_content, init=False)

            # save checkpoint
            save_checkpoint({
                'epoch': t + 1,
                'state_dict': self._net.state_dict(),
                'best_epoch': self._best_epoch,
                'best_accuracy': self._best_accuracy,
                'optimizer': self._optimizer.state_dict(),
                'step': self._step,
                'scheduler': self._scheduler.state_dict(),
                'memory_pool': self.memory_pool,
            })

        console_content = 'Best at epoch {}, test accuracy is {}'.format(self._best_epoch, self._best_accuracy)
        print_to_console(console_content)

        # rename log file, stats files and model
        os.rename(self._logfile, self._logfile.replace('.txt', '-{}_{}_{}_{:.4f}.txt'.format(
            self._config['net'], self._config['batch_size'], self._config['lr'], self._best_accuracy)))
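
The training loop above relies on a few meter objects (self._train_loss, self._train_accuracy, and later self._epoch_train_time) that expose reset(), update(value, n) and an .avg attribute. Their implementation is not part of these snippets; a minimal sketch following the common PyTorch-examples "average meter" pattern, offered here only as an assumption about what those helpers do, could look like this:

class AverageMeter:
    """Keep a running average of a scalar such as loss, accuracy, or runtime."""

    def __init__(self):
        self.reset()

    def reset(self):
        # clear all accumulated statistics, e.g. at the start of each epoch
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val is the newest measurement, n the number of samples it covers
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count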
Code example #2
    def single_epoch_training(self, epoch, log_iter=True, log_freq=100):
        if epoch >= self.T_k:
            stats_log_path = 'stats/drop_n_reuse_stats_epoch{:03d}.csv'.format(epoch+1)
            stats_log_header = 'clean_sample_num,reusable_sample_num,irrelevant_sample_num'
            print_to_logfile(stats_log_path, stats_log_header, init=True, end='\n')
        for it, (x, y, indices) in enumerate(self._train_loader):
            s = time.time()

            x = x.cuda()
            y = y.cuda()
            self._optimizer.zero_grad()
            logits = self._net(x)
            losses, ce_loss = std_loss(logits, y, indices, self.T_k, epoch, self.memory_pool,
                                       eps=self._config['eps'])
            loss = losses.mean()

            self.memory_pool.update(indices=indices, losses=ce_loss.detach().data.cpu(),
                                    scores=F.softmax(logits, dim=1).detach().data.cpu(),
                                    labels=y.detach().data.cpu())

            train_accuracy = accuracy(logits, y, topk=(1,))

            self._train_loss.update(loss.item(), x.size(0))
            self._train_accuracy.update(train_accuracy[0], x.size(0))

            loss.backward()
            self._optimizer.step()

            e = time.time()
            self._epoch_train_time.update(e-s, 1)
            if (log_iter and (it+1) % log_freq == 0) or (it+1 == len(self._train_loader)):
                console_content = 'Epoch:[{0:03d}/{1:03d}]  Iter:[{2:04d}/{3:04d}]  ' \
                                  'Train Accuracy :[{4:6.2f}]  Loss:[{5:4.4f}]  ' \
                                  'Iter Runtime:[{6:6.2f}]'.format(epoch + 1, self._epochs, it + 1,
                                                                   len(self._train_loader),
                                                                   self._train_accuracy.avg,
                                                                   self._train_loss.avg, self._epoch_train_time.avg)
                print_to_console(console_content)
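
single_epoch_training assumes a memory pool object with an update(indices=..., losses=..., scores=..., labels=...) method, and std_loss later reads its most_prob_labels attribute to relabel uncertain samples. The actual class is not included in these snippets; the sketch below is one plausible minimal version (the class name MemoryPool and the moving-average momentum are assumptions) that smooths the softmax scores across epochs and derives pseudo-labels from them:

import torch


class MemoryPool:
    """Per-sample prediction statistics kept across epochs (illustrative sketch)."""

    def __init__(self, num_samples, num_classes, momentum=0.9):
        self.momentum = momentum
        self.scores = torch.zeros(num_samples, num_classes)       # smoothed softmax scores
        self.losses = torch.zeros(num_samples)                     # last observed CE loss per sample
        self.labels = torch.zeros(num_samples, dtype=torch.long)   # last observed (noisy) label

    def update(self, indices, losses, scores, labels):
        # blend the new softmax scores into the running average for these samples
        self.scores[indices] = (self.momentum * self.scores[indices]
                                + (1.0 - self.momentum) * scores)
        self.losses[indices] = losses
        self.labels[indices] = labels

    @property
    def most_prob_labels(self):
        # pseudo-label of a sample = class with the highest smoothed score
        return self.scores.argmax(dim=1)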
Code example #3
File: cot_train.py  Project: yyy11178/CRSSC
    def single_epoch_training(self, epoch, log_iter=True, log_freq=200):
        if epoch >= self.T_k:
            stats_log_path1 = 'stats/net1_drop_n_reuse_stats_epoch{:03d}.csv'.format(
                epoch + 1)
            stats_log_path2 = 'stats/net2_drop_n_reuse_stats_epoch{:03d}.csv'.format(
                epoch + 1)
            stats_log_header = 'clean_sample_num,reusable_sample_num,irrelevant_sample_num'
            print_to_logfile(stats_log_path1,
                             stats_log_header,
                             init=True,
                             end='\n')
            print_to_logfile(stats_log_path2,
                             stats_log_header,
                             init=True,
                             end='\n')

        for it, (x, y, indices) in enumerate(self._train_loader):
            s = time.time()

            x = x.cuda()
            y = y.cuda()
            self._optimizer1.zero_grad()
            self._optimizer2.zero_grad()
            logits1 = self._net1(x)
            logits2 = self._net2(x)
            losses1, ce_loss1, losses2, ce_loss2 = \
                cot_std_loss(logits1, logits2, y, indices, self.T_k, epoch,
                             self.memory_pool1, self.memory_pool2, eps=self._config['eps'])
            loss1 = losses1.mean()
            loss2 = losses2.mean()

            self.memory_pool1.update(indices=indices,
                                     losses=ce_loss1.detach().data.cpu(),
                                     scores=F.softmax(
                                         logits1, dim=1).detach().data.cpu(),
                                     labels=y.detach().data.cpu())
            self.memory_pool2.update(indices=indices,
                                     losses=ce_loss2.detach().data.cpu(),
                                     scores=F.softmax(
                                         logits2, dim=1).detach().data.cpu(),
                                     labels=y.detach().data.cpu())

            train_accuracy1 = accuracy(logits1, y, topk=(1, ))
            train_accuracy2 = accuracy(logits2, y, topk=(1, ))

            self._train_loss1.update(loss1.item(), losses1.size(0))
            self._train_loss2.update(loss2.item(), losses2.size(0))
            self._train_accuracy1.update(train_accuracy1[0], x.size(0))
            self._train_accuracy2.update(train_accuracy2[0], x.size(0))

            loss1.backward()
            loss2.backward()
            self._optimizer1.step()
            self._optimizer2.step()

            e = time.time()
            self._epoch_train_time.update(e - s, 1)
            if (log_iter and (it + 1) % log_freq == 0) or (it + 1 == len(
                    self._train_loader)):
                console_content = 'Epoch:[{:03d}/{:03d}]  Iter:[{:04d}/{:04d}]  ' \
                                  'Train Accuracy1 :[{:6.2f}]  Train Accuracy2 :[{:6.2f}]  ' \
                                  'Loss1:[{:4.4f}]  Loss2:[{:4.4f}]  ' \
                                  'Iter Runtime:[{:6.2f}]'.format(epoch + 1, self._epochs, it + 1,
                                                                  len(self._train_loader),
                                                                  self._train_accuracy1.avg, self._train_accuracy2.avg,
                                                                  self._train_loss1.avg, self._train_loss2.avg,
                                                                  self._epoch_train_time.avg)
                print_to_console(console_content)
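
Both training loops iterate over batches of the form (x, y, indices): the data loader has to return each sample's position in the dataset so that the memory pools can be updated and queried per sample. If the underlying dataset only yields (image, label) pairs, a thin wrapper along these lines is enough (the name IndexedDataset is an assumption; the project may build the index into its own dataset class instead):

from torch.utils.data import Dataset


class IndexedDataset(Dataset):
    """Wrap a (sample, label) dataset so every batch also carries the sample indices."""

    def __init__(self, base_dataset):
        self.base = base_dataset

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        x, y = self.base[index]
        return x, y, index  # this index is what memory_pool.update is keyed on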
Code example #4
File: cot_train.py  Project: yyy11178/CRSSC
    def train(self):
        console_header = 'Epoch\tTrain_Loss1\tTrain_Loss2\tTrain_Accuracy1\tTrain_Accuracy2\t' \
                         'Test_Accuracy1\tTest_Accuracy2\tEpoch_Runtime\tLearning_Rate1\tLearning_Rate2'
        print_to_console(console_header)
        print_to_logfile(self._logfile, console_header, init=True)

        for t in range(self._start_epoch, self._epochs):
            epoch_start = time.time()
            self._scheduler1.step(epoch=t)
            self._scheduler2.step(epoch=t)
            # reset average meters
            self._train_loss1.reset()
            self._train_loss2.reset()
            self._train_accuracy1.reset()
            self._train_accuracy2.reset()

            self._net1.train(True)
            self._net2.train(True)
            self.single_epoch_training(t)
            test_accuracy1 = evaluate(self._test_loader, self._net1)
            test_accuracy2 = evaluate(self._test_loader, self._net2)

            lr1 = get_lr_from_optimizer(self._optimizer1)
            lr2 = get_lr_from_optimizer(self._optimizer2)

            if test_accuracy1 > self._best_accuracy1:
                self._best_accuracy1 = test_accuracy1
                self._best_epoch1 = t + 1
                torch.save(
                    self._net1.state_dict(),
                    'model/net1_step{}_best_epoch.pth'.format(self._step))
            if test_accuracy2 > self._best_accuracy2:
                self._best_accuracy2 = test_accuracy2
                self._best_epoch2 = t + 1
                torch.save(
                    self._net2.state_dict(),
                    'model/net2_step{}_best_epoch.pth'.format(self._step))

            epoch_end = time.time()
            single_epoch_runtime = epoch_end - epoch_start
            # Logging
            console_content = '{:05d}\t{:10.4f}\t{:10.4f}\t{:14.4f}\t{:14.4f}\t' \
                              '{:13.4f}\t{:13.4f}\t{:13.2f}\t' \
                              '{:13.1e}\t{:13.1e}'.format(t + 1, self._train_loss1.avg, self._train_loss2.avg,
                                                          self._train_accuracy1.avg, self._train_accuracy2.avg,
                                                          test_accuracy1, test_accuracy2,
                                                          single_epoch_runtime, lr1, lr2)
            print_to_console(console_content)
            print_to_logfile(self._logfile, console_content, init=False)

            # save checkpoint
            save_checkpoint({
                'epoch': t + 1,
                'state_dict1': self._net1.state_dict(),
                'state_dict2': self._net2.state_dict(),
                'best_epoch1': self._best_epoch1,
                'best_epoch2': self._best_epoch2,
                'best_accuracy1': self._best_accuracy1,
                'best_accuracy2': self._best_accuracy2,
                'optimizer1': self._optimizer1.state_dict(),
                'optimizer2': self._optimizer2.state_dict(),
                'step': self._step,
                'scheduler1': self._scheduler1.state_dict(),
                'scheduler2': self._scheduler2.state_dict(),
                'memory_pool1': self.memory_pool1,
                'memory_pool2': self.memory_pool2,
            })

        console_content = 'Net1: Best at epoch {}, test accuracy is {}'.format(
            self._best_epoch1, self._best_accuracy1)
        print_to_console(console_content)
        console_content = 'Net2: Best at epoch {}, test accuracy is {}'.format(
            self._best_epoch2, self._best_accuracy2)
        print_to_console(console_content)

        # rename log file
        os.rename(
            self._logfile,
            self._logfile.replace(
                '.txt', '-{}_{}_{}_{:.4f}_{:.4f}.txt'.format(
                    self._config['net'], self._config['batch_size'],
                    self._config['lr'], self._best_accuracy1,
                    self._best_accuracy2)))
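
Every epoch the co-training loop saves a checkpoint dictionary holding both networks, optimizers, schedulers, memory pools and the best-so-far statistics. The matching resume step is not shown in these snippets; assuming save_checkpoint simply passes the dictionary to torch.save and the file lands at a path such as model/checkpoint.pth (the helper name and path below are hypothetical), restoring the trainer state might look like this:

import torch


def resume_from_checkpoint(trainer, path='model/checkpoint.pth'):
    # Restore everything that the checkpoint dictionary above contains.
    ckpt = torch.load(path, map_location='cpu')
    trainer._net1.load_state_dict(ckpt['state_dict1'])
    trainer._net2.load_state_dict(ckpt['state_dict2'])
    trainer._optimizer1.load_state_dict(ckpt['optimizer1'])
    trainer._optimizer2.load_state_dict(ckpt['optimizer2'])
    trainer._scheduler1.load_state_dict(ckpt['scheduler1'])
    trainer._scheduler2.load_state_dict(ckpt['scheduler2'])
    trainer.memory_pool1 = ckpt['memory_pool1']
    trainer.memory_pool2 = ckpt['memory_pool2']
    trainer._best_epoch1, trainer._best_epoch2 = ckpt['best_epoch1'], ckpt['best_epoch2']
    trainer._best_accuracy1, trainer._best_accuracy2 = ckpt['best_accuracy1'], ckpt['best_accuracy2']
    trainer._step = ckpt['step']
    trainer._start_epoch = ckpt['epoch']  # train() then continues from this epoch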
Code example #5
File: loss.py  Project: yyy11178/CRSSC
def cot_std_loss(logits1,
                 logits2,
                 labels,
                 indices,
                 T_k,
                 epoch,
                 memory_pool1,
                 memory_pool2,
                 eps=0.1):
    ce_losses1 = label_smoothing_cross_entropy(logits1,
                                               labels,
                                               epsilon=eps,
                                               reduction='none')  # (N,)
    ce_losses2 = label_smoothing_cross_entropy(logits2,
                                               labels,
                                               epsilon=eps,
                                               reduction='none')  # (N,)

    # in the first T_k epochs, train with the entire training set
    if epoch < T_k:
        # print('using naive CE', end=' <--- ')
        return ce_losses1, ce_losses1, ce_losses2, ce_losses2

    # after T_k epochs, start dividing training set into clean / uncertain / irrelevant
    ind_loss_sorted1 = torch.argsort(ce_losses1.data)
    ind_loss_sorted2 = torch.argsort(ce_losses2.data)
    num_remember1 = torch.nonzero(ce_losses1 < ce_losses1.mean()).shape[0]
    num_remember2 = torch.nonzero(ce_losses2 < ce_losses2.mean()).shape[0]

    # print(' ---> {:2d}, {:2d}'.format(num_remember1, num_remember2), end=', ')
    stats_log_path1 = 'stats/net1_drop_n_reuse_stats_epoch{:03d}.csv'.format(
        epoch + 1)
    stats_log_path2 = 'stats/net2_drop_n_reuse_stats_epoch{:03d}.csv'.format(
        epoch + 1)
    print_to_logfile(stats_log_path1,
                     '{:03d}'.format(num_remember1),
                     init=False,
                     end=',')
    print_to_logfile(stats_log_path2,
                     '{:03d}'.format(num_remember2),
                     init=False,
                     end=',')

    ind_clean1 = ind_loss_sorted1[:num_remember1]
    ind_clean2 = ind_loss_sorted2[:num_remember2]
    ind_forget1 = ind_loss_sorted1[num_remember1:]
    ind_forget2 = ind_loss_sorted2[num_remember2:]
    logits_clean1 = logits1[ind_clean2]
    logits_clean2 = logits2[ind_clean1]
    labels_clean1 = labels[ind_clean2]
    labels_clean2 = labels[ind_clean1]

    logits_final1 = logits_clean1
    logits_final2 = logits_clean2
    labels_final1 = labels_clean1
    labels_final2 = labels_clean2

    if ind_forget1.shape[0] > 1:
        # for samples with high loss
        #   high loss, high std --> mislabeling
        #   high loss, low std  --> irrelevant category
        # indices_forget1 = indices[ind_forget1]
        logits_forget1 = logits1[ind_forget1]
        pred_distribution1 = F.softmax(logits_forget1, dim=1)
        batch_std1 = pred_distribution1.std(dim=1)

        flag1 = F.softmax(logits_clean1, dim=1).std(dim=1).mean().item()
        # print('{:.5f}'.format(flag), end='*****')

        batch_std_sorted1, ind_std_sorted1 = torch.sort(batch_std1.data,
                                                        descending=True)
        ind_split1 = split_set(batch_std_sorted1, flag1)
        if ind_split1 is None:
            ind_split1 = -1
            # print('{} == {}'.format(batch_std_sorted, ind_split), end=' ---> ')

        # uncertain could be either mislabeled or hard example
        ind_uncertain1 = ind_std_sorted1[:(ind_split1 + 1)]

        # print('{:2d}/{:2d}'.format(ind_split1 + 1, logits1.shape[0] - num_remember1), end=' <--- ')
        print_to_logfile(
            stats_log_path1, '{:03d},{:03d}'.format(
                ind_split1 + 1,
                logits1.shape[0] - num_remember1 - ind_split1 - 1))

        ind_mislabeled1 = ind_forget1[ind_uncertain1]
        logits_mislabeled2 = logits2[ind_mislabeled1]
        indices_mislabeled2 = indices[ind_mislabeled1]
        labels_mislabeled2 = memory_pool2.most_prob_labels[
            indices_mislabeled2].to(logits_mislabeled2.device)

        logits_final2 = torch.cat((logits_final2, logits_mislabeled2), dim=0)
        labels_final2 = torch.cat((labels_final2, labels_mislabeled2), dim=0)
    else:
        # no (or only one) high-loss sample for net1: log zeros to keep the CSV aligned
        print_to_logfile(
            stats_log_path1,
            '{:03d},{:03d}'.format(0, logits1.shape[0] - num_remember1))

    if ind_forget2.shape[0] > 1:
        # for samples with high loss
        #   high loss, high std --> mislabeling
        #   high loss, low std  --> irrelevant category
        # indices_forget2 = indices[ind_forget2]
        logits_forget2 = logits2[ind_forget2]
        pred_distribution2 = F.softmax(logits_forget2, dim=1)
        batch_std2 = pred_distribution2.std(dim=1)

        flag2 = F.softmax(logits_clean2, dim=1).std(dim=1).mean().item()
        # print('{:.5f}'.format(flag), end='*****')

        batch_std_sorted2, ind_std_sorted2 = torch.sort(batch_std2.data,
                                                        descending=True)
        ind_split2 = split_set(batch_std_sorted2, flag2)
        if ind_split2 is None:
            ind_split2 = -1
            # print('{} == {}'.format(batch_std_sorted, ind_split), end=' ---> ')

        # uncertain could be either mislabeled or hard example
        ind_uncertain2 = ind_std_sorted2[:(ind_split2 + 1)]

        # print('{:2d}/{:2d}'.format(ind_split2 + 1, logits2.shape[0] - num_remember2), end=' <--- ')
        print_to_logfile(
            stats_log_path2, '{:03d},{:03d}'.format(
                ind_split2 + 1,
                logits2.shape[0] - num_remember2 - ind_split2 - 1))

        ind_mislabeled2 = ind_forget2[ind_uncertain2]
        logits_mislabeled1 = logits1[ind_mislabeled2]
        indices_mislabeled1 = indices[ind_mislabeled2]
        labels_mislabeled1 = memory_pool1.most_prob_labels[
            indices_mislabeled1].to(logits_mislabeled1.device)

        logits_final1 = torch.cat((logits_final1, logits_mislabeled1), dim=0)
        labels_final1 = torch.cat((labels_final1, labels_mislabeled1), dim=0)
    else:
        # no (or only one) high-loss sample for net2: log zeros to keep the CSV aligned
        print_to_logfile(
            stats_log_path2,
            '{:03d},{:03d}'.format(0, logits2.shape[0] - num_remember2))

    losses1 = label_smoothing_cross_entropy(logits_final1,
                                            labels_final1,
                                            epsilon=eps,
                                            reduction='none')
    losses2 = label_smoothing_cross_entropy(logits_final2,
                                            labels_final2,
                                            epsilon=eps,
                                            reduction='none')
    return losses1, ce_losses1, losses2, ce_losses2
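
Both loss functions are built on label_smoothing_cross_entropy(logits, labels, epsilon, reduction='none'), which has to return one loss value per sample so that the losses can be sorted and filtered. Its definition is not part of these snippets; a minimal sketch of the usual label-smoothing cross entropy with that signature, given only as an assumption, would be:

import torch.nn.functional as F


def label_smoothing_cross_entropy(logits, labels, epsilon=0.1, reduction='none'):
    """Cross entropy against labels smoothed towards the uniform distribution."""
    log_probs = F.log_softmax(logits, dim=1)                               # (N, C)
    nll = -log_probs.gather(dim=1, index=labels.unsqueeze(1)).squeeze(1)   # (N,)
    smooth = -log_probs.mean(dim=1)                                        # (N,) uniform component
    losses = (1.0 - epsilon) * nll + epsilon * smooth
    if reduction == 'mean':
        return losses.mean()
    if reduction == 'sum':
        return losses.sum()
    return losses  # per-sample losses, shape (N,)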
Code example #6
File: loss.py  Project: yyy11178/CRSSC
def std_loss(logits, labels, indices, T_k, epoch, memory_pool, eps=0.1):
    ce_losses = label_smoothing_cross_entropy(logits,
                                              labels,
                                              epsilon=eps,
                                              reduction='none')

    # in the first T_k epochs, train with the entire training set
    if epoch < T_k:
        # print('using naive CE', end=' <--- ')
        return ce_losses, ce_losses

    # after T_k epochs, start dividing training set into clean / uncertain / irrelevant
    ind_loss_sorted = torch.argsort(ce_losses.data)
    num_remember = torch.nonzero(ce_losses < ce_losses.mean()).shape[0]

    # print(' ---> {:2d}'.format(num_remember), end=', ')
    stats_log_path = 'stats/drop_n_reuse_stats_epoch{:03d}.csv'.format(epoch +
                                                                       1)
    print_to_logfile(stats_log_path,
                     '{:03d}'.format(num_remember),
                     init=False,
                     end=',')

    ind_clean = ind_loss_sorted[:num_remember]
    ind_forget = ind_loss_sorted[num_remember:]
    logits_clean = logits[ind_clean]
    labels_clean = labels[ind_clean]

    if ind_forget.shape[0] > 1:
        # for samples with high loss
        #   high loss, high std --> mislabeling
        #   high loss, low std  --> irrelevant category
        indices_forget = indices[ind_forget]
        logits_forget = logits[ind_forget]
        pred_distribution = F.softmax(logits_forget, dim=1)
        batch_std = pred_distribution.std(dim=1)

        flag = F.softmax(logits_clean, dim=1).std(dim=1).mean().item()
        # print('{:.5f}'.format(flag), end='*****')

        batch_std_sorted, ind_std_sorted = torch.sort(batch_std.data,
                                                      descending=True)
        ind_split = split_set(batch_std_sorted, flag)
        if ind_split is None:
            ind_split = -1
        # print('{} == {}'.format(batch_std_sorted, ind_split), end=' ---> ')

        # uncertain could be either mislabeled or hard example
        ind_uncertain = ind_std_sorted[:(ind_split + 1)]

        # print('{:2d}/{:2d}'.format(ind_split+1, logits.shape[0] - num_remember), end=' <--- ')
        print_to_logfile(stats_log_path,
                         '{:03d},{:03d}'.format(
                             ind_split + 1,
                             logits.shape[0] - num_remember - ind_split - 1),
                         init=False,
                         end='\n')

        logits_mislabeled = logits_forget[ind_uncertain]
        indices_mislabeled = indices_forget[ind_uncertain]
        labels_mislabeled = memory_pool.most_prob_labels[
            indices_mislabeled].to(logits_mislabeled.device)

        logits_final = torch.cat((logits_clean, logits_mislabeled), dim=0)
        labels_final = torch.cat((labels_clean, labels_mislabeled), dim=0)
    else:
        # print('{:2d}/{:2d}'.format(0, logits.shape[0] - num_remember), end=' <--- ')
        print_to_logfile(stats_log_path,
                         '{:03d},{:03d}'.format(0, logits.shape[0] -
                                                num_remember),
                         init=False,
                         end='\n')
        logits_final = logits_clean
        labels_final = labels_clean
    std_losses = label_smoothing_cross_entropy(logits_final,
                                               labels_final,
                                               epsilon=eps,
                                               reduction='none')
    return std_losses, ce_losses
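
std_loss and cot_std_loss both rely on split_set(batch_std_sorted, flag), where batch_std_sorted holds the prediction standard deviations of the high-loss samples in descending order and flag is the mean standard deviation of the clean predictions. The helper is not included in these snippets; judging from how its result is used (positions [:ind_split + 1] are treated as uncertain/reusable, None means nothing passes the threshold), one plausible sketch is:

import torch


def split_set(std_sorted_desc, threshold):
    """Return the last position whose std exceeds threshold, or None if no position does.

    std_sorted_desc is a 1-D tensor sorted in descending order, so positions
    0..ind_split count as uncertain (possibly mislabeled but reusable) samples,
    while the remaining positions are treated as irrelevant.
    """
    above = torch.nonzero(std_sorted_desc > threshold).flatten()
    if above.numel() == 0:
        return None
    return above[-1].item()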