Example #1
def val(model, val_loader, criterion, epoch, writer, use_CUDA = True):
    model.eval()
    accuracy_logger = ScalarLogger(prefix = 'accuracy')
    losses_logger = ScalarLogger(prefix = 'loss')
    with torch.no_grad():
        for (input, label, _) in val_loader:
            input = to_var(input, requires_grad = False)
            label = to_var(label, requires_grad = False).long()

            output = model(input)
            loss = criterion(output, label)
            prediction = torch.softmax(output, 1)
            top1 = accuracy(prediction, label)
            accuracy_logger.update(top1)
            losses_logger.update(loss)

    accuracy_logger.write(writer, 'val', epoch)
    losses_logger.write(writer, 'val', epoch)
    accuracy_ = accuracy_logger.avg()
    losses = losses_logger.avg()
    print("Validation Epoch: {}, Accuracy: {}, Losses: {}".format(epoch, accuracy_, losses))
    return accuracy_, losses
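All of the examples on this page call small project utilities that are not reproduced here (to_var, ScalarLogger, WLogger, AverageMeter, accuracy, get_val_samples, update_params, and so on). As a rough orientation only, below is a minimal, hypothetical sketch of three of them, inferred from how they are used in Examples #1-#3; the actual implementations in the projects may differ, and the PU-learning examples (#4 and #6) use a different accuracy helper that returns positive/negative accuracies and a count.

import torch

def to_var(x, requires_grad=True):
    # assumed helper: move a tensor to the GPU when available and set requires_grad
    if torch.cuda.is_available():
        x = x.cuda()
    return x.clone().detach().requires_grad_(requires_grad)

class ScalarLogger:
    # assumed helper: running average of a scalar, written to a TensorBoard SummaryWriter
    def __init__(self, prefix='scalar'):
        self.prefix = prefix
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += float(value)
        self.count += 1

    def avg(self):
        return self.total / max(self.count, 1)

    def write(self, writer, tag, epoch):
        writer.add_scalar('{}/{}'.format(self.prefix, tag), self.avg(), epoch)

def accuracy(prediction, label):
    # assumed helper: top-1 accuracy from per-class probabilities
    return (prediction.argmax(dim=1) == label).float().mean().item()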
Example #2
def train(model, input_channel, optimizer, criterion, train_loader, val_loader, epoch, writer, args, use_CUDA = True, clamp = False, num_classes = 10):
    model.train()
    accs = []
    losses_w1 = []
    losses_w2 = []
    iter_val_loader = iter(val_loader)
    meta_criterion = nn.CrossEntropyLoss(reduction = 'none')
    index = 0
    noisy_labels = []
    true_labels = []

    w = defaultdict()
    w_logger = defaultdict()
    losses_logger = defaultdict()
    accuracy_logger = ScalarLogger(prefix = 'accuracy')
         
    for (input, label, real) in train_loader:
        noisy_labels.append(label)
        true_labels.append(real)
        input = to_var(input, requires_grad = False)
        label = to_var(label, requires_grad = False).long()
        index += 1
        output = model(input)
        loss = meta_criterion(output, label).sum() / input.shape[0]
        prediction = torch.softmax(output, 1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        top1 = accuracy(prediction, label)
        accuracy_logger.update(top1)
        
    noisy_labels = torch.cat(noisy_labels)
    true_labels = torch.cat(true_labels)
    mask = (noisy_labels != true_labels).cpu().numpy()

    accuracy_logger.write(writer, 'train', epoch)
    
    print("Training Epoch: {}, Accuracy: {}".format(epoch, accuracy_logger.avg()))
    return accuracy_logger.avg()
Example #3
def train(model,
          input_channel,
          optimizers,
          criterion,
          components,
          train_loader,
          val_loader,
          epoch,
          writer,
          args,
          use_CUDA=True,
          clamp=False,
          num_classes=10):
    model.train()
    accs = []
    losses_w1 = []
    losses_w2 = []
    iter_val_loader = iter(val_loader)
    meta_criterion = nn.CrossEntropyLoss(reduction='none')
    index = 0

    w = defaultdict()
    w_logger = defaultdict()
    losses_logger = defaultdict()
    accuracy_logger = ScalarLogger(prefix='accuracy')
    for c in components:
        w[c] = None
        w_logger[c] = WLogger()
        losses_logger[c] = ScalarLogger(prefix='loss')

    w_all = []

    store_input = None
    store_label = None
    store_real = None
    for (input, label, real) in train_loader:
        if store_input is None:
            store_input = input
            store_label = label

        meta_model = get_model(args,
                               num_classes=num_classes,
                               input_channel=input_channel)
        meta_model.load_state_dict(model.state_dict())
        if use_CUDA:
            meta_model = meta_model.cuda()

        val_input, val_label, iter_val_loader = get_val_samples(
            iter_val_loader, val_loader)
        store_input = to_var(store_input, requires_grad=False)
        store_label = to_var(store_label, requires_grad=False).long()
        val_input = to_var(val_input, requires_grad=False)
        val_label = to_var(val_label, requires_grad=False).long()

        meta_output = meta_model(store_input)
        cost = meta_criterion(meta_output, store_label)
        eps = to_var(torch.zeros(cost.size()))
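        # eps: per-sample perturbation weights, initialised to zero; the gradient of the
        # validation loss w.r.t. eps below estimates each training sample's influence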
        meta_loss = (cost * eps).sum()
        meta_model.zero_grad()

        if 'all' in components:
            grads = torch.autograd.grad(meta_loss, (meta_model.parameters()),
                                        create_graph=True)
            meta_model.update_params(0.001, source_params=grads)

            meta_val_output = meta_model(val_input)
            meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
            grad_eps = torch.autograd.grad(meta_val_loss,
                                           eps,
                                           only_inputs=True)[0]
            if clamp:
                w['all'] = torch.clamp(-grad_eps, min=0)
            else:
                w['all'] = -grad_eps

            norm = torch.sum(abs(w['all']))
            assert (clamp and len(components)
                    == 1) or (len(components) > 1), "Error combination"
            w['all'] = w['all'] / norm
            if ('fc' in components):
                w['fc'] = copy.deepcopy(w['all'])
                w['fc'] = torch.clamp(w['fc'], max=0)
                w['all'] = torch.clamp(w['all'], min=0)
            elif ('backbone' in components):
                w['backbone'] = copy.deepcopy(w['all'])
                w['backbone'] = torch.clamp(w['backbone'], max=0)
                w['all'] = torch.clamp(w['all'], min=0)

        else:
            assert ('backbone' in components) and ('fc' in components)

            grads_backbone = torch.autograd.grad(
                meta_loss, (meta_model.backbone.parameters()),
                create_graph=True,
                retain_graph=True)
            grads_fc = torch.autograd.grad(meta_loss,
                                           (meta_model.fc.parameters()),
                                           create_graph=True)

            # Backbone Grads
            meta_model.backbone.update_params(0.001,
                                              source_params=grads_backbone)
            meta_val_output = meta_model(val_input)
            meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
            grad_eps = torch.autograd.grad(meta_val_loss,
                                           eps,
                                           only_inputs=True,
                                           retain_graph=True)[0]
            if clamp:
                w['backbone'] = torch.clamp(-grad_eps, min=0)
            else:
                w['backbone'] = -grad_eps
            norm = torch.sum(abs(w['backbone']))
            w['backbone'] = w['backbone'] / norm

            # FC backward
            meta_model.load_state_dict(model.state_dict())
            meta_model.fc.update_params(0.001, source_params=grads_fc)
            meta_val_output = meta_model(val_input)
            meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
            grad_eps = torch.autograd.grad(meta_val_loss,
                                           eps,
                                           only_inputs=True,
                                           retain_graph=True)[0]

            if clamp:
                w['fc'] = torch.clamp(-grad_eps, min=0)
            else:
                w['fc'] = -grad_eps
            norm = torch.sum(abs(w['fc']))
            w['fc'] = w['fc'] / norm

        w_all.append(w['all'].detach().cpu().numpy().reshape(128, 1))
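        # note: this assumes 'all' is among the components and a fixed batch size of 128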

    w_all = np.concatenate(w_all, axis=1)

    assert w_all.shape[0] == 128
    pickle.dump(w_all, open('w.npy', 'wb'))
    print(np.std(w_all, axis=1))
    print(np.mean(w_all, axis=1))
    print(np.std(w_all, axis=1) / np.mean(w_all, axis=1))
Example #4
def train_with_meta(clean_loader,
                    noisy_loader,
                    test_loader,
                    model,
                    ema_model,
                    criterion,
                    consistency_criterion,
                    optimizer,
                    scheduler,
                    epoch,
                    warmup=False,
                    self_paced_pick=0):

    global step, switched, single_epoch_steps
    pacc = AverageMeter()
    nacc = AverageMeter()
    pnacc = AverageMeter()
    cov0 = AverageMeter()
    cov1 = AverageMeter()
    wilcoxon = AverageMeter()
    model.train()
    ema_model.train()
    consistency_weight = get_current_consistency_weight(epoch - 30)
    resultt = np.zeros(61000)
    if not warmup: scheduler.step()

    print("Learning rate is {}".format(optimizer.param_groups[0]['lr']))
    if (clean_loader):
        for i, (X, _, Y, T, ids, _) in enumerate(clean_loader):
            # measure data loading time
            xeps = torch.ones((X.shape[0], 2)).cuda() * 1e-10
            xeps[:, 1] = 0
            xeps = xeps.cuda()
            X = X.cuda()
            Y = Y.cuda().float()
            T = T.cuda().long()

            if args.dataset == 'mnist':
                X = X.view(X.shape[0], -1)

            # compute output
            output = model(X)
            with torch.no_grad():
                ema_output = ema_model(X)

            consistency_loss = consistency_weight * \
            consistency_criterion(output, ema_output) / X.shape[0]
            #if epoch >= args.self_paced_start: criterion.update_p(0.5)
            _, loss = criterion(output, Y, eps=1)  # compute the loss using the PU labels

            #print(output)
            # measure accuracy and record loss

            if check_mean_teacher(epoch):
                predictions = torch.sign(ema_output).long()  # use the teacher's output as the prediction
            else:
                predictions = torch.sign(output).long()  # otherwise use the model's own output

            smx = torch.sigmoid(output)  # sigmoid probabilities
            #print(smx)
            smx = torch.cat([1 - smx, smx], dim=1)  # stack into a two-column prediction
            smxY = ((Y + 1) // 2).long()  # map the ±1 labels to 0/1 class indices

            if args.soft_label:
                xent = -torch.sum(smx * torch.log(smx + xeps), dim=1)
                aux = xent.mean()
            else:
                aux = F.cross_entropy(smx + xeps, smxY)  # cross-entropy loss
            loss = aux
            if check_mean_teacher(epoch):
                loss += consistency_loss
            if args.soft_label:
                detach = xent.detach().cpu().numpy()
            else:
                detach = 0

            optimizer.zero_grad()
            #if not np.any(np.isnan(detach)):
            if True:
                loss.backward()
                optimizer.step()

            if np.any(np.isnan(detach)):
                for param in model.parameters():
                    print('clean_data')
                    print(param)
                    print('clean_grad')
                    print(param.grad)

            if check_mean_teacher(epoch) and (
                (i + 1) % int(single_epoch_steps / 2 - 1)) == 0:
                update_ema_variables(model, ema_model, args.ema_decay,
                                     step)  # update the EMA parameters
                step += 1

            pacc_, nacc_, pnacc_, psize = accuracy(predictions,
                                                   T)  # accuracy measured against the true labels T
            if np.any(torch.isnan(output).cpu().numpy()):
                print(output)
                print("clean interrupt")
                raise NotImplementedError
            pacc.update(pacc_, psize)
            nacc.update(nacc_, X.size(0) - psize)
            pnacc.update(pnacc_, X.size(0))

        print('Epoch Clean : [{0}]\t'
              'PACC {pacc.val:.3f} ({pacc.avg:.3f})\t'
              'NACC {nacc.val:.3f} ({nacc.avg:.3f})\t'
              'PNACC {pnacc.val:.3f} ({pnacc.avg:.3f})\t'.format(epoch,
                                                                 pacc=pacc,
                                                                 nacc=nacc,
                                                                 pnacc=pnacc))

    if epoch <= args.noisy_stop:
        meta_step = 0

        for i, (X, Y, _, T, ids, p) in enumerate(noisy_loader):
            xeps = torch.ones((X.shape[0], 2)).cuda() * 1e-10
            xeps[:, 1] = 0
            meta_step += 1
            if args.dataset == 'cifar':
                meta_net = create_cifar_model()
            else:
                meta_net = create_model()
            meta_net.load_state_dict(model.state_dict())
            if torch.cuda.is_available():
                meta_net.cuda()

            X = X.cuda()
            Y = Y.cuda().float()
            T = T.cuda().long()
            p = p.cuda().float()
            if args.dataset == 'mnist':
                X = X.view(X.shape[0], -1)

            y_f_hat = meta_net(X)
            prob = torch.sigmoid(y_f_hat)
            prob = torch.cat([1 - prob, prob], dim=1)

            cost1 = torch.sum(prob * torch.log(prob + xeps), dim=1)
            eps = to_var(torch.zeros(cost1.shape[0], 2))
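            # eps[:, 0] perturbs the PU criterion and eps[:, 1] weights cost1; the gradient of
            # the validation loss w.r.t. eps below gives the per-sample weights w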
            cost2 = criterion(y_f_hat, Y, eps=eps[:, 0])
            l_f_meta = (cost1 * eps[:, 1]).mean() + cost2[1]
            meta_net.zero_grad()

            grads = torch.autograd.grad(l_f_meta,
                                        meta_net.parameters(),
                                        create_graph=True)
            meta_net.update_params(0.001, source_params=grads)
            val_data, val_Y, _, val_labels, val_ids, val_p = next(
                iter(test_loader))
            veps = torch.ones((val_data.shape[0], 2)).cuda() * 1e-10
            veps[:, 1] = 0
            val_data = to_var(val_data, requires_grad=False)
            if args.dataset == 'mnist':
                val_data = val_data.view(-1, 784)

            val_labels = to_var(val_labels, requires_grad=False).float()
            y_g_hat = meta_net(val_data)

            val_prob = torch.sigmoid(y_g_hat)
            val_prob = torch.cat([1 - val_prob, val_prob], dim=1)

            val_xent = -torch.sum(val_prob * torch.log(val_prob + veps), dim=1)
            val_xent[torch.isnan(val_xent)] = 0
            l_g_meta = val_xent.mean()

            grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0]
            w = torch.clamp(-grad_eps, min=1e-6)
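            # negative meta-gradient: eps entries whose increase would reduce the validation
            # objective receive larger weights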
            del meta_net
            for j in range(w.shape[0]):
                if Y[j] == -1:
                    if torch.sum(w[:, 1]) >= 4:
                        w[j, 0] = 1
                        w[j, 1] = 0
                    else:
                        w[j, :] = w[j, :] / torch.sum(w[j, :])

                else:
                    w[j, 0] = 1
                    w[j, 1] = 0
            #w[:, 0] = 1
            #w[:, 1] = 0
            w = w.cuda().detach()
            output = model(X)
            with torch.no_grad():
                ema_output = ema_model(X)

            consistency_loss = consistency_weight * \
            consistency_criterion(output, ema_output) / X.shape[0]

            #if epoch >= args.self_paced_start: criterion.update_p(0.5)
            _, loss = criterion(output, Y, eps=w[:, 0])

            prob = torch.sigmoid(output)
            prob = torch.cat([1 - prob, prob], dim=1)
            xent = -torch.sum(prob * torch.log(prob + xeps), dim=1)
            detach = xent.detach().cpu().numpy()
            #xent[torch.isnan(xent)] = 0
            if check_mean_teacher(epoch) and not warmup:
                total_loss = loss + consistency_loss + (xent * w[:, 1]).mean()
                predictions = torch.sign(ema_output).long()
            else:
                predictions = torch.sign(output).long()
                total_loss = (xent * w[:, 1]).mean() + loss

            cov_0 = ((prob[:, 0] * w[:, 0]).mean() -
                     prob[:, 0].float().mean() * w[:, 0].mean()) / torch.sqrt(
                         torch.var(prob[:, 0].float()) *
                         torch.var(w[:, 0].float() + 1e-6))
            cov_1 = ((prob[:, 0] * w[:, 1]).mean() -
                     prob[:, 0].float().mean() * w[:, 1].mean()) / torch.sqrt(
                         torch.var(prob[:, 0].float()) *
                         torch.var(w[:, 1].float() + 1e-6))
            #if epoch >= args.self_paced_start

            if check_noisy(epoch):

                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

                if np.any(np.isnan(detach)):
                    for param in model.parameters():
                        print('noisy_data')
                        print(param)
                        print('noisy_grad')
                        print(param.grad)
                    print(w)
                    print(prob)
                    print(xent)

            if check_mean_teacher(epoch) and (
                (i + 1) % int(single_epoch_steps / 2 - 1)) == 0 and not warmup:
                update_ema_variables(model, ema_model, args.ema_decay, step)
                step += 1

            pacc_, nacc_, pnacc_, psize = accuracy(predictions, T)
            if np.any(torch.isnan(output).cpu().numpy()):
                print('noisy_interrupt')
                print(output)
                raise NotImplementedError
            last_loss = loss
            last_total_loss = total_loss
            last_predictions = predictions
            last_output = output
            last_prob = prob
            last_xent = xent
            pacc.update(pacc_, psize)
            nacc.update(nacc_, X.size(0) - psize)
            pnacc.update(pnacc_, X.size(0))
            cov0.update(cov_0)
            cov1.update(cov_1)
        print('Noisy Epoch: [{0}]\t'
              'PACC {pacc.val:.3f} ({pacc.avg:.3f})\t'
              'NACC {nacc.val:.3f} ({nacc.avg:.3f})\t'
              'PNACC {pnacc.val:.3f} ({pnacc.avg:.3f})\t'
              'COV0 {cov0.val:.3f} ({cov0.avg:.3f})\t'
              'COV1 {cov1.val:.3f} ({cov1.avg:.3f})\t'.format(epoch,
                                                              pacc=pacc,
                                                              nacc=nacc,
                                                              pnacc=pnacc,
                                                              cov0=cov0,
                                                              cov1=cov1))
        #raise NotImplementedError

    return pacc.avg, nacc.avg, pnacc.avg
Example #5
def train(model,
          vnet,
          input_channel,
          optimizers,
          optimizer_vnet,
          components,
          criterion,
          train_loader,
          val_loader,
          epoch,
          writer,
          args,
          use_CUDA=True,
          clamp=False,
          num_classes=10):
    model.train()
    accs = []
    losses_w1 = []
    losses_w2 = []
    iter_val_loader = iter(val_loader)
    meta_criterion = nn.CrossEntropyLoss(reduction='none')
    index = 0
    noisy_labels = []
    true_labels = []

    w = defaultdict()
    w_logger = defaultdict()
    losses_logger = defaultdict()
    accuracy_logger = ScalarLogger(prefix='accuracy')
    for c in components:
        w[c] = None
        w_logger[c] = WLogger()
        losses_logger[c] = ScalarLogger(prefix='loss')

    for (input, label, real) in train_loader:
        noisy_labels.append(label)
        true_labels.append(real)

        meta_model = get_model(args,
                               num_classes=num_classes,
                               input_channel=input_channel)
        meta_model.load_state_dict(model.state_dict())
        if use_CUDA:
            meta_model = meta_model.cuda()

        val_input, val_label, iter_val_loader = get_val_samples(
            iter_val_loader, val_loader)
        input = to_var(input, requires_grad=False)
        label = to_var(label, requires_grad=False).long()
        val_input = to_var(val_input, requires_grad=False)
        val_label = to_var(val_label, requires_grad=False).long()

        meta_output = meta_model(input)
        cost = meta_criterion(meta_output, label)
        #eps = to_var(torch.zeros(cost.size()))
        cost_v = torch.reshape(cost, (len(cost), 1))
        eps = vnet(cost_v.data)  # shape: (N, 2)
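        # vnet maps each per-sample loss to two weights: column 0 drives the backbone
        # meta-update, column 1 the fc meta-update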

        meta_loss_backbone = (cost * eps[:, 0]).sum()
        meta_loss_fc = (cost * eps[:, 1]).sum()
        meta_model.zero_grad()

        grads_backbone = torch.autograd.grad(
            meta_loss_backbone, (meta_model.backbone.parameters()),
            create_graph=True,
            retain_graph=True)
        grads_fc = torch.autograd.grad(meta_loss_fc,
                                       (meta_model.fc.parameters()),
                                       create_graph=True)

        # Backbone Grads
        meta_model.backbone.update_params(0.001, source_params=grads_backbone)
        meta_val_feature = torch.flatten(meta_model.backbone(val_input), 1)
        meta_val_output = meta_model.fc(meta_val_feature)
        meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
        ''' TODO: temporarily removed
        if args.with_kl and args.reg_start <= epoch:
            train_feature = torch.flatten(meta_model.backbone(input), 1)
            meta_val_loss -= sample_wise_kl(train_feature, meta_val_feature)
                    
        grad_eps = torch.autograd.grad(meta_val_loss, eps, only_inputs = True, retain_graph = True)[0]
        if clamp:
            w['backbone'] = torch.clamp(-grad_eps, min = 0)
        else:
            w['backbone'] = -grad_eps
        norm = torch.sum(abs(w['backbone']))
        w['backbone'] = w['backbone'] / norm
        '''
        optimizer_vnet.zero_grad()
        meta_val_loss.backward(retain_graph=True)
        optimizer_vnet.step()

        # FC backward
        meta_model.load_state_dict(model.state_dict())
        meta_model.fc.update_params(0.001, source_params=grads_fc)
        meta_val_output = meta_model(val_input)
        meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
        '''
        grad_eps = torch.autograd.grad(meta_val_loss, eps, only_inputs = True, retain_graph = True)[0]
        
        if clamp:
            w['fc'] = torch.clamp(-grad_eps, min = 0)
        else:
            w['fc'] = -grad_eps
        norm = torch.sum(abs(w['fc']))
        w['fc'] = w['fc'] / norm
        '''
        optimizer_vnet.zero_grad()
        meta_val_loss.backward(retain_graph=True)
        optimizer_vnet.step()

        index += 1
        output = model(input)
        losses = defaultdict()
        loss = meta_criterion(output, label)
        loss_v = torch.reshape(loss, (len(loss), 1))
        with torch.no_grad():
            w_ = vnet(loss_v)
            if clamp:
                w_ = torch.clamp(w_, min=0)
            for i in range(w_.shape[1]):
                w_[:, i] = w_[:, i] / torch.sum(torch.abs(w_[:, i]))
            w['backbone'] = w_[:, 0]
            w['fc'] = w_[:, 1]

        prediction = torch.softmax(output, 1)
        for c in components:
            w_logger[c].update(w[c])
            losses[c] = (loss * w[c]).sum()
            optimizers[c].zero_grad()
            losses[c].backward(retain_graph=True)
            optimizers[c].step()
            losses_logger[c].update(losses[c])

        top1 = accuracy(prediction, label)
        accuracy_logger.update(top1)

    noisy_labels = torch.cat(noisy_labels)
    true_labels = torch.cat(true_labels)
    mask = (noisy_labels != true_labels).cpu().numpy()
    for c in components:
        w_logger[c].write(writer, c, epoch)
        w_logger[c].mask_write(writer, c, epoch, mask)
        losses_logger[c].write(writer, c, epoch)

    accuracy_logger.write(writer, 'train', epoch)

    print("Training Epoch: {}, Accuracy: {}".format(epoch,
                                                    accuracy_logger.avg()))
    return accuracy_logger.avg()
Example #6
def train_with_meta(clean1_loader, noisy1_loader, clean2_loader, noisy2_loader, test_loader, model1, model2, ema_model1, ema_model2, criterion, consistency_criterion, optimizer1, scheduler1, optimizer2, scheduler2, epoch, warmup = False, self_paced_pick = 0):

    global step, switched
    batch_time = AverageMeter()
    data_time = AverageMeter()
    #losses = AverageMeter()
    pacc1 = AverageMeter()
    nacc1 = AverageMeter()
    pnacc1 = AverageMeter()
    pacc2 = AverageMeter()
    nacc2 = AverageMeter()
    pnacc2 = AverageMeter()
    pacc3 = AverageMeter()
    nacc3 = AverageMeter()
    pnacc3 = AverageMeter()
    pacc4 = AverageMeter()
    nacc4 = AverageMeter()
    pnacc4 = AverageMeter()
    model1.train()
    model2.train()
    ema_model1.train()
    ema_model2.train()
    end = time.time()
    consistency_weight = get_current_consistency_weight(epoch - 30)
    if not warmup: 
        scheduler1.step()
        scheduler2.step()
    resultt = np.zeros(61000)

    if clean1_loader: 
        for i, (X, left, right, _, Y, T, ids) in enumerate(clean1_loader):
            # measure data loading time
            
            data_time.update(time.time() - end)
            X = X.cuda(args.gpu)
            left = left.cuda(args.gpu)
            right = right.cuda(args.gpu)
            Y = Y.cuda(args.gpu).float()
            T = T.cuda(args.gpu).long()
            # compute output
            output1 = model1(X, left, right)
            output2 = model2(X, left, right)
            with torch.no_grad():
                ema_output1 = ema_model1(X, left, right)

            consistency_loss = consistency_weight * \
            consistency_criterion(output1, ema_output1) / X.shape[0]

            predictiont1   = torch.sign(ema_output1).long()
            predictions1 = torch.sign(output1).long() # the student's own prediction

            smx1 = torch.sigmoid(output1) # sigmoid probabilities
            smx1 = torch.cat([1 - smx1, smx1], dim=1) # stack into a two-column prediction
            smxY = ((Y + 1) // 2).long() # map the ±1 labels to 0/1 class indices
            smx2 = torch.sigmoid(output2) # sigmoid probabilities
            smx2 = torch.cat([1 - smx2, smx2], dim=1) # stack into a two-column prediction

            smx1[smx1 > 1-1e-7] = 1-1e-7
            smx1[smx1 < 1e-7] = 1e-7
            smx2[smx2 > 1-1e-7] = 1-1e-7
            smx2[smx2 < 1e-7] = 1e-7

            if args.soft_label:
                aux1 = - torch.sum(smx1 * torch.log(smx1)) / smx1.shape[0]
                aux2 = - torch.sum(smx2 * torch.log(smx2)) / smx2.shape[0]
            else:
                smxY = smxY.float()
                smxY = smxY.view(-1, 1)
                smxY = torch.cat([1 - smxY, smxY], dim = 1)
                aux1 = - torch.sum(smxY * torch.log(smx1)) / smxY.shape[0] # cross-entropy loss
                aux2 = - torch.sum(smxY * torch.log(smx2)) / smxY.shape[0] # cross-entropy loss

            loss = aux1

            if check_mean_teacher(epoch):
                loss += consistency_loss

            optimizer1.zero_grad()
            loss.backward()
            optimizer1.step()
            
            pacc_1, nacc_1, pnacc_1, psize = accuracy(predictions1, T) # accuracy measured against the true labels T
            pacc_3, nacc_3, pnacc_3, psize = accuracy(predictiont1, T) 
            pacc1.update(pacc_1, psize)
            nacc1.update(nacc_1, X.size(0) - psize)
            pnacc1.update(pnacc_1, X.size(0))
            pacc3.update(pacc_3, psize)
            nacc3.update(nacc_3, X.size(0) - psize)
            pnacc3.update(pnacc_3, X.size(0))
        
    if clean2_loader:
        for i, (X, left, right,  _, Y, T, ids) in enumerate(clean2_loader):
            # measure data loading time
            
            data_time.update(time.time() - end)
            X = X.cuda(args.gpu)
            left = left.cuda(args.gpu)
            right = right.cuda(args.gpu)
            Y = Y.cuda(args.gpu).float()
            T = T.cuda(args.gpu).long()
            # compute output
            output1 = model1(X, left, right)
            output2 = model2(X, left, right)
            with torch.no_grad():
                ema_output2 = ema_model2(X, left, right)

            consistency_loss = consistency_weight * \
            consistency_criterion(output2, ema_output2) / X.shape[0]

            predictiont2 = torch.sign(ema_output2).long()
            predictions2 = torch.sign(output2).long()

            smx1 = torch.sigmoid(output1) # sigmoid probabilities
            smx1 = torch.cat([1 - smx1, smx1], dim=1) # stack into a two-column prediction

            smxY = ((Y + 1) // 2).long() # map the ±1 labels to 0/1 class indices

            smx2 = torch.sigmoid(output2) # sigmoid probabilities
            smx2 = torch.cat([1 - smx2, smx2], dim=1) # stack into a two-column prediction

            smx1[smx1 > 1-1e-7] = 1-1e-7
            smx1[smx1 < 1e-7] = 1e-7
            smx2[smx2 > 1-1e-7] = 1-1e-7
            smx2[smx2 < 1e-7] = 1e-7

            if args.soft_label:
                aux1 = - torch.sum(smx1 * torch.log(smx1)) / smx1.shape[0]
                aux2 = - torch.sum(smx2 * torch.log(smx2)) / smx2.shape[0]
            else:
                smxY = smxY.float()
                smxY = smxY.view(-1, 1)
                smxY = torch.cat([1 - smxY, smxY], dim = 1)
                aux1 = - torch.sum(smxY * torch.log(smx1)) / smxY.shape[0] # cross-entropy loss
                aux2 = - torch.sum(smxY * torch.log(smx2)) / smxY.shape[0] # cross-entropy loss

            loss = aux2

            if check_mean_teacher(epoch):
                loss += consistency_loss
                
            optimizer2.zero_grad()
            loss.backward()
            optimizer2.step()

            pacc_2, nacc_2, pnacc_2, psize = accuracy(predictions2, T)
            pacc_4, nacc_4, pnacc_4, psize = accuracy(predictiont2, T) 
            pacc2.update(pacc_2, psize)
            nacc2.update(nacc_2, X.size(0) - psize)
            pnacc2.update(pnacc_2, X.size(0))
            pacc4.update(pacc_4, psize)
            nacc4.update(nacc_4, X.size(0) - psize)
            pnacc4.update(pnacc_4, X.size(0))

        if check_mean_teacher(epoch):
            update_ema_variables(model1, ema_model1, args.ema_decay, step) # update the EMA parameters
            update_ema_variables(model2, ema_model2, args.ema_decay, step)
            step += 1

        print('Epoch Clean : [{0}]\t'
                'PACC1 {pacc1.val:.3f} ({pacc1.avg:.3f})\t'
                'NACC1 {nacc1.val:.3f} ({nacc1.avg:.3f})\t'
                'PNACC1 {pnacc1.val:.3f} ({pnacc1.avg:.3f})\t'
                'PACC2 {pacc2.val:.3f} ({pacc2.avg:.3f})\t'
                'NACC2 {nacc2.val:.3f} ({nacc2.avg:.3f})\t'
                'PNACC2 {pnacc2.val:.3f} ({pnacc2.avg:.3f})\t'
                'PACC3 {pacc3.val:.3f} ({pacc3.avg:.3f})\t'
                'NACC3 {nacc3.val:.3f} ({nacc3.avg:.3f})\t'
                'PNACC3 {pnacc3.val:.3f} ({pnacc3.avg:.3f})\t'
                'PACC4 {pacc4.val:.3f} ({pacc4.avg:.3f})\t'
                'NACC4 {nacc4.val:.3f} ({nacc4.avg:.3f})\t'
                'PNACC4 {pnacc4.val:.3f} ({pnacc4.avg:.3f})\t'.format(
                epoch, pacc1=pacc1, nacc1=nacc1, pnacc1=pnacc1, 
                pacc2=pacc2, nacc2=nacc2, pnacc2=pnacc2, pacc3=pacc3, nacc3=nacc3, pnacc3=pnacc3, 
                pacc4=pacc4, nacc4=nacc4, pnacc4=pnacc4))
    
    for i, (X, left, right, Y, _, T, ids) in enumerate(noisy1_loader):
        X = X.cuda(args.gpu)
        left = left.cuda(args.gpu)
        right = right.cuda(args.gpu)
        Y = Y.cuda(args.gpu).float()
        T = T.cuda(args.gpu).long()
        meta_net = create_lenet_model()
        meta_net.load_state_dict(model1.state_dict())
        if torch.cuda.is_available():
            meta_net.cuda()

        y_f_hat = meta_net(X, left, right)
        prob = torch.sigmoid(y_f_hat)
        prob = torch.cat([1-prob, prob], dim=1)

        cost1 = torch.sum(prob * torch.log(prob + 1e-10), dim = 1)
        eps = to_var(torch.zeros(cost1.shape[0], 2))
        cost2 = criterion(y_f_hat, Y, eps = eps[:, 0])
        l_f_meta = (cost1 * eps[:, 1]).mean() + cost2[1]
        meta_net.zero_grad()
        
        grads = torch.autograd.grad(l_f_meta, meta_net.parameters(), create_graph = True, allow_unused=True)
        meta_net.update_params(0.001, source_params = grads)
        val_data, val_left, val_right, val_Y, _, val_labels, val_ids = next(iter(test_loader))

        val_data = to_var(val_data, requires_grad = False)
        val_left = to_var(val_left, requires_grad = False)
        val_right = to_var(val_right, requires_grad = False)
        val_labels = to_var(val_labels, requires_grad=False).float()
        y_g_hat = meta_net(val_data, val_left, val_right)
        
        val_prob = torch.sigmoid(y_g_hat)
        val_prob = torch.cat([1 - val_prob, val_prob], dim=1)

        l_g_meta = -torch.mean(torch.sum(val_prob * torch.log(val_prob + 1e-10), dim = 1)) * 2
        
        grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0] 
        #print(grad_eps) 
        w = torch.clamp(-grad_eps, min = 1e-10)
        acount = 0
        bcount = 0
        ccount = 0
        dcount = 0
        for j in range(w.shape[0]):
            if Y[j] == -1:
                if torch.sum(w[:, 1]) >= 4:
                    w[j, 0] = 1
                    w[j, 1] = 0
                else:
                    w[j, :] = w[j, :] / torch.sum(w[j, :])
                
            else:
                w[j, 0] = 1
                w[j, 1] = 0
            #w[:, 0] = 1
            #w[:, 1] = 0
            w = w.cuda().detach()
        # compute output
        output1 = model1(X, left, right)
        output2 = model2(X, left, right)

        with torch.no_grad():
            ema_output1 = ema_model1(X, left, right)
        #if epoch >= args.self_paced_start: criterion.update_p(0.5)
        _, loss = criterion(output1, Y, eps = w[:, 0])
        consistency_loss = consistency_weight * \
        consistency_criterion(output1, ema_output1) / X.shape[0]
        #print(loss1)

        predictions1 = torch.sign(output1).long()
        predictiont1 = torch.sign(ema_output1).long()

        smx1 = torch.sigmoid(output1) # sigmoid probabilities
        smx1 = torch.cat([1 - smx1, smx1], dim=1) # stack into a two-column prediction

        smxY = ((Y + 1) // 2).long() # map the ±1 labels to 0/1 class indices

        smx2 = torch.sigmoid(output2) # sigmoid probabilities
        smx2 = torch.cat([1 - smx2, smx2], dim=1) # stack into a two-column prediction

        xent = -torch.sum(smx1 * torch.log(smx1 + 1e-10), dim = 1)
         
        if args.type_noisy == 'mu' and check_mean_teacher(epoch):
            aux = F.mse_loss(smx1[:, 0], smx2[:, 0].detach())
            if aux < loss * args.alpha:
                loss += aux

        if check_mean_teacher(epoch):
            loss += consistency_loss
            
        loss += (xent * w[:, 1]).mean()
        optimizer1.zero_grad()
        loss.backward()
        optimizer1.step()
        pacc_3, nacc_3, pnacc_3, psize = accuracy(predictiont1, T) 
        pacc_1, nacc_1, pnacc_1, psize = accuracy(predictions1, T) # accuracy measured against the true labels T
        pacc1.update(pacc_1, psize)
        nacc1.update(nacc_1, X.size(0) - psize)
        pnacc1.update(pnacc_1, X.size(0))
        pacc3.update(pacc_3, psize)
        nacc3.update(nacc_3, X.size(0) - psize)
        pnacc3.update(pnacc_3, X.size(0))

    for i, (X, left, right, Y, _, T, ids) in enumerate(noisy2_loader):

        X = X.cuda(args.gpu)
        left = left.cuda(args.gpu)
        right = right.cuda(args.gpu)
        Y = Y.cuda(args.gpu).float()
        T = T.cuda(args.gpu).long()
        meta_net = create_lenet_model()
        meta_net.load_state_dict(model2.state_dict())
        if torch.cuda.is_available():
            meta_net.cuda()

        y_f_hat = meta_net(X, left, right)
        prob = torch.sigmoid(y_f_hat)
        prob = torch.cat([1-prob, prob], dim=1)

        cost1 = torch.sum(prob * torch.log(prob + 1e-10), dim = 1)
        eps = to_var(torch.zeros(cost1.shape[0], 2))
        cost2 = criterion(y_f_hat, Y, eps = eps[:, 0])
        l_f_meta = (cost1 * eps[:, 1]).mean() + cost2[1]
        meta_net.zero_grad()
        
        grads = torch.autograd.grad(l_f_meta, meta_net.parameters(), create_graph = True)
        meta_net.update_params(0.001, source_params = grads)
        val_data, val_left, val_right, val_Y, _, val_labels, val_ids = next(iter(test_loader))

        val_data = to_var(val_data, requires_grad = False)
        val_left = to_var(val_left, requires_grad = False)
        val_right = to_var(val_right, requires_grad = False)
        val_labels = to_var(val_labels, requires_grad=False).float()
        
        y_g_hat = meta_net(val_data, val_left, val_right)
        
        val_prob = torch.sigmoid(y_g_hat)
        val_prob = torch.cat([1 - val_prob, val_prob], dim=1)

        l_g_meta = -torch.mean(torch.sum(val_prob * torch.log(val_prob + 1e-10), dim = 1)) * 2
        
        grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0] 
        #print(grad_eps) 
        w = torch.clamp(-grad_eps, min = 1e-10)
        acount = 0
        bcount = 0
        ccount = 0
        dcount = 0
        for j in range(w.shape[0]):
            if Y[j] == -1:
                if torch.sum(w[:, 1]) >= 4:
                    w[j, 0] = 1
                    w[j, 1] = 0
                else:
                    w[j, :] = w[j, :] / torch.sum(w[j, :])
                
            else:
                w[j, 0] = 1
                w[j, 1] = 0
            #w[:, 0] = 1
            #w[:, 1] = 0
            w = w.cuda().detach()
        # compute output
        output1 = model1(X, left, right)
        output2 = model2(X, left, right)
        with torch.no_grad():
            ema_output2 = ema_model2(X, left, right)

        _, loss = criterion(output2, Y, eps = w[:, 0])
        consistency_loss = consistency_weight * \
        consistency_criterion(output2, ema_output2) / X.shape[0]
        #print(loss2)
        predictions2 = torch.sign(output2).long()
        predictiont2 = torch.sign(ema_output2).long()

        smx1 = torch.sigmoid(output1) # sigmoid probabilities
        smx1 = torch.cat([1 - smx1, smx1], dim=1) # stack into a two-column prediction

        smxY = ((Y + 1) // 2).long() # map the ±1 labels to 0/1 class indices

        smx2 = torch.sigmoid(output2) # sigmoid probabilities
        smx2 = torch.cat([1 - smx2, smx2], dim=1) # stack into a two-column prediction
        xent = -torch.sum(smx2 * torch.log(smx2 + 1e-10), dim = 1)

        if args.type_noisy == 'mu' and check_mean_teacher(epoch):
            aux = F.mse_loss(smx2[:, 0], smx1[:, 0].detach())
            if aux < loss * args.alpha:
                loss += aux

        if check_mean_teacher(epoch):
            loss += consistency_loss
        loss += (xent * w[:, 1]).mean()
        optimizer2.zero_grad()
        loss.backward()
        optimizer2.step()

        pacc_2, nacc_2, pnacc_2, psize = accuracy(predictions2, T)
        pacc_4, nacc_4, pnacc_4, psize = accuracy(predictiont2, T) 
        pacc2.update(pacc_2, psize)
        nacc2.update(nacc_2, X.size(0) - psize)
        pnacc2.update(pnacc_2, X.size(0))
        pacc4.update(pacc_4, psize)
        nacc4.update(nacc_4, X.size(0) - psize)
        pnacc4.update(pnacc_4, X.size(0))

    if check_mean_teacher(epoch):
        update_ema_variables(model1, ema_model1, args.ema_decay, step) # update the EMA parameters
        update_ema_variables(model2, ema_model2, args.ema_decay, step)
        step += 1

    print('Epoch Noisy : [{0}]\t'
            'PACC1 {pacc1.val:.3f} ({pacc1.avg:.3f})\t'
            'NACC1 {nacc1.val:.3f} ({nacc1.avg:.3f})\t'
            'PNACC1 {pnacc1.val:.3f} ({pnacc1.avg:.3f})\t'
            'PACC2 {pacc2.val:.3f} ({pacc2.avg:.3f})\t'
            'NACC2 {nacc2.val:.3f} ({nacc2.avg:.3f})\t'
            'PNACC2 {pnacc2.val:.3f} ({pnacc2.avg:.3f})\t'
            'PACC3 {pacc3.val:.3f} ({pacc3.avg:.3f})\t'
            'NACC3 {nacc3.val:.3f} ({nacc3.avg:.3f})\t'
            'PNACC3 {pnacc3.val:.3f} ({pnacc3.avg:.3f})\t'
            'PACC4 {pacc4.val:.3f} ({pacc4.avg:.3f})\t'
            'NACC4 {nacc4.val:.3f} ({nacc4.avg:.3f})\t'
            'PNACC4 {pnacc4.val:.3f} ({pnacc4.avg:.3f})\t'.format(
            epoch, pacc1=pacc1, nacc1=nacc1, pnacc1=pnacc1, 
            pacc2=pacc2, nacc2=nacc2, pnacc2=pnacc2, pacc3=pacc3, nacc3=nacc3, pnacc3=pnacc3, 
            pacc4=pacc4, nacc4=nacc4, pnacc4=pnacc4))

    
    return pacc1.avg, nacc1.avg, pnacc1.avg
Example #7
File: main.py Project: xxchenxx/NoisyLabel
def train(model,
          input_channel,
          optimizers,
          criterion,
          components,
          train_loader,
          val_loader,
          epoch,
          writer,
          args,
          use_CUDA=True,
          clamp=False,
          num_classes=10):
    model.train()
    accs = []
    losses_w1 = []
    losses_w2 = []
    iter_val_loader = iter(val_loader)
    meta_criterion = nn.CrossEntropyLoss(reduction='none')
    index = 0
    noisy_labels = []
    true_labels = []

    w = defaultdict()
    w_logger = defaultdict()
    losses_logger = defaultdict()
    accuracy_logger = ScalarLogger(prefix='accuracy')
    for c in components:
        w[c] = None
        w_logger[c] = WLogger()
        losses_logger[c] = ScalarLogger(prefix='loss')

    for (input, label, real) in train_loader:
        noisy_labels.append(label)
        true_labels.append(real)

        meta_model = get_model(args,
                               num_classes=num_classes,
                               input_channel=input_channel)
        meta_model.load_state_dict(model.state_dict())
        if use_CUDA:
            meta_model = meta_model.cuda()

        val_input, val_label, iter_val_loader = get_val_samples(
            iter_val_loader, val_loader)
        input = to_var(input, requires_grad=False)
        label = to_var(label, requires_grad=False).long()
        val_input = to_var(val_input, requires_grad=False)
        val_label = to_var(val_label, requires_grad=False).long()

        meta_output = meta_model(input)
        cost = meta_criterion(meta_output, label)
        eps = to_var(torch.zeros(cost.size()))
        meta_loss = (cost * eps).sum()
        meta_model.zero_grad()

        if 'all' in components:
            grads = torch.autograd.grad(meta_loss, (meta_model.parameters()),
                                        create_graph=True)
            meta_model.update_params(0.001, source_params=grads)

            meta_val_output = meta_model(val_input)
            meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
            grad_eps = torch.autograd.grad(meta_val_loss,
                                           eps,
                                           only_inputs=True)[0]
            if clamp:
                w['all'] = torch.clamp(-grad_eps, min=0)
            else:
                w['all'] = -grad_eps

            norm = torch.sum(abs(w['all']))
            assert (clamp and len(components)
                    == 1) or (len(components) > 1), "Error combination"
            w['all'] = w['all'] / norm
            if ('fc' in components):
                w['fc'] = copy.deepcopy(w['all'])
                w['fc'] = torch.clamp(w['fc'], max=0)
                w['all'] = torch.clamp(w['all'], min=0)
            elif ('backbone' in components):
                w['backbone'] = copy.deepcopy(w['all'])
                w['backbone'] = torch.clamp(w['backbone'], max=0)
                w['all'] = torch.clamp(w['all'], min=0)

        else:
            assert ('backbone' in components) and ('fc' in components)

            grads_backbone = torch.autograd.grad(
                meta_loss, (meta_model.backbone.parameters()),
                create_graph=True,
                retain_graph=True)
            grads_fc = torch.autograd.grad(meta_loss,
                                           (meta_model.fc.parameters()),
                                           create_graph=True)
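            # the backbone and the fc head each get a separate one-step meta-update, producing
            # separate per-sample weight vectors w['backbone'] and w['fc']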

            # Backbone Grads
            meta_model.backbone.update_params(0.001,
                                              source_params=grads_backbone)
            meta_val_feature = torch.flatten(meta_model.backbone(val_input), 1)
            meta_val_output = meta_model.fc(meta_val_feature)
            meta_val_loss = meta_criterion(meta_val_output, val_label).sum()

            if args.with_kl and args.reg_start <= epoch:
                train_feature = torch.flatten(meta_model.backbone(input), 1)
                meta_val_loss -= sample_wise_kl(train_feature,
                                                meta_val_feature)

            grad_eps = torch.autograd.grad(meta_val_loss,
                                           eps,
                                           only_inputs=True,
                                           retain_graph=True)[0]
            if clamp:
                w['backbone'] = torch.clamp(-grad_eps, min=0)
            else:
                w['backbone'] = -grad_eps
            norm = torch.sum(abs(w['backbone']))
            w['backbone'] = w['backbone'] / norm

            # FC backward
            meta_model.load_state_dict(model.state_dict())
            meta_model.fc.update_params(0.001, source_params=grads_fc)
            meta_val_output = meta_model(val_input)
            meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
            grad_eps = torch.autograd.grad(meta_val_loss,
                                           eps,
                                           only_inputs=True,
                                           retain_graph=True)[0]

            if clamp:
                w['fc'] = torch.clamp(-grad_eps, min=0)
            else:
                w['fc'] = -grad_eps
            norm = torch.sum(abs(w['fc']))
            w['fc'] = w['fc'] / norm

        index += 1
        output = model(input)
        loss = defaultdict()
        prediction = torch.softmax(output, 1)
        for c in components:
            w_logger[c].update(w[c])
            loss[c] = (meta_criterion(output, label) * w[c]).sum()
            optimizers[c].zero_grad()
            loss[c].backward(retain_graph=True)
            optimizers[c].step()
            losses_logger[c].update(loss[c])

        top1 = accuracy(prediction, label)
        accuracy_logger.update(top1)

    noisy_labels = torch.cat(noisy_labels)
    true_labels = torch.cat(true_labels)
    mask = (noisy_labels != true_labels).cpu().numpy()
    for c in components:
        w_logger[c].write(writer, c, epoch)
        w_logger[c].mask_write(writer, c, epoch, mask)
        losses_logger[c].write(writer, c, epoch)

    accuracy_logger.write(writer, 'train', epoch)

    print("Training Epoch: {}, Accuracy: {}".format(epoch,
                                                    accuracy_logger.avg()))
    return accuracy_logger.avg()