Example #1
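# The excerpts below are fragments of a larger training script; they assume
# roughly these imports plus globals such as net, optimizer, train_dataloader,
# val_dataloader, test_dataloader, device, PARAMS, BATCH_SIZE, NUM_TRAINDATA,
# lr_scheduler, log_dir, verbose, save_model and evaluate. This header is a
# best-guess reconstruction, not part of the original code.
import math
import os
import random
import sys
import time
from collections import OrderedDict
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter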
                                                  batch_size=BATCH_SIZE,
                                                  shuffle=False)
    meta_dataloader = torch.utils.data.DataLoader(meta_dataset,
                                                  batch_size=BATCH_SIZE,
                                                  shuffle=False,
                                                  drop_last=True)
    NUM_METADATA = len(meta_dataset)
    noisy_idx, clean_labels = get_synthetic_idx(
        dataset,
        args.seed,
        args.metadata_num,
        0,
        noise_type,
        noise_ratio,
    )
    net = get_model(dataset, framework).to(device)
    if (device.type == 'cuda') and (ngpu > 1):
        net = nn.DataParallel(net, list(range(ngpu)))
    lr_scheduler = get_lr_scheduler(dataset)
    optimizer = optim.SGD(net.parameters(),
                          lr=lr_scheduler(0),
                          momentum=0.9,
                          weight_decay=1e-4)
    logsoftmax = nn.LogSoftmax(dim=1).to(device)
    softmax = nn.Softmax(dim=1).to(device)

    print(
        "Dataset: {}, Model: {}, Device: {}, Batch size: {}, #GPUS to run: {}".
        format(dataset, model_name, device, BATCH_SIZE, ngpu))
    if dataset in DATASETS_SMALL:
        print("Noise type: {}, Noise ratio: {}".format(noise_type,
Example #2
def metaweightnet():
    '''
    2019 - NIPS - Meta-weight-net: Learning an explicit mapping for sample weighting
    github repo: https://github.com/xjtushujun/meta-weight-net
    '''
    train_meta_loader = val_dataloader
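    # Note: the clean validation split doubles as the meta set here, matching
    # the paper's use of a small trusted set for meta-learning the weighting.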

    class VNet(nn.Module):
        def __init__(self, input, hidden, output):
            super(VNet, self).__init__()
            self.linear1 = nn.Linear(input, hidden)
            self.relu1 = nn.ReLU(inplace=True)
            self.linear2 = nn.Linear(hidden, output)

        def forward(self, x, weights=None):
            if weights is None:
                x = self.linear1(x)
                x = self.relu1(x)
                out = self.linear2(x)
                return torch.sigmoid(out)
            else:
                # functional path with explicit weights; keys follow this
                # module's named_parameters()
                x = F.linear(x, weights['linear1.weight'], weights['linear1.bias'])
                feat = F.threshold(x, 0, 0, inplace=True)
                out = F.linear(feat, weights['linear2.weight'], weights['linear2.bias'])
                return torch.sigmoid(out)
    vnet = VNet(1, 100, 1).to(device)
    optimizer_vnet = torch.optim.Adam(vnet.parameters(), 1e-3, weight_decay=1e-4)
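    # Illustrative note (not in the original): vnet maps a per-sample loss to a
    # weight in (0, 1); e.g. vnet(torch.ones(4, 1, device=device)) returns a
    # (4, 1) tensor of weights.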

    test_acc_best = 0
    val_acc_best = 0
    epoch_best = 0

    for epoch in range(PARAMS[dataset]['epochs']): 
        start_epoch = time.time()
        train_accuracy = AverageMeter()
        train_loss = AverageMeter()
        meta_loss = AverageMeter()

        # set learning rate
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_scheduler(epoch)

        train_meta_loader_iter = iter(train_meta_loader)
        for batch_idx, (inputs, targets) in enumerate(train_dataloader):
            start = time.time()

            net.train()
            inputs, targets = inputs.to(device), targets.to(device)
            meta_model = get_model(dataset,framework).to(device)
            meta_model.load_state_dict(net.state_dict())
            outputs = meta_model(inputs)

            cost = F.cross_entropy(outputs, targets, reduction='none')
            cost_v = torch.reshape(cost, (len(cost), 1))
            v_lambda = vnet(cost_v.data)
            l_f_meta = torch.sum(cost_v * v_lambda)/len(cost_v)
            meta_model.zero_grad()
            
            grads = torch.autograd.grad(l_f_meta, meta_model.parameters(), create_graph=True, retain_graph=True, only_inputs=True)
            fast_weights = OrderedDict((name, param - lr_scheduler(epoch)*grad) for ((name, param), grad) in zip(meta_model.named_parameters(), grads))
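            # fast_weights is a one-step "virtual" SGD update of the meta-model;
            # create_graph=True keeps it differentiable so the meta loss below
            # can backpropagate through this step into vnet.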

            try:
                inputs_val, targets_val = next(train_meta_loader_iter)
            except StopIteration:
                train_meta_loader_iter = iter(train_meta_loader)
                inputs_val, targets_val = next(train_meta_loader_iter)
            inputs_val, targets_val = inputs_val.to(device), targets_val.type(torch.long).to(device)
            y_g_hat = meta_model.forward(inputs_val, fast_weights)
            l_g_meta = F.cross_entropy(y_g_hat, targets_val)

            optimizer_vnet.zero_grad()
            l_g_meta.backward()
            optimizer_vnet.step()

            outputs = net(inputs)
            cost_w = F.cross_entropy(outputs, targets, reduction='none')
            cost_v = torch.reshape(cost_w, (len(cost_w), 1))

            with torch.no_grad():
                w_new = vnet(cost_v)
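            # the weights are computed under no_grad: the main update below
            # treats them as constants and does not touch vnet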

            loss = torch.sum(cost_v * w_new)/len(cost_v)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            del grads

            _, predicted = torch.max(outputs.data, 1)
            train_accuracy.update(predicted.eq(targets.data).cpu().sum().item(), targets.size(0)) 
            train_loss.update(loss.item())
            meta_loss.update(l_g_meta.item(), targets.size(0))

            if verbose == 2:
                sys.stdout.write("Progress: {:6.5f}, Accuracy: {:5.4f}, Loss: {:5.4f}, Process time:{:5.4f}   \r"
                                .format(batch_idx*BATCH_SIZE/NUM_TRAINDATA, train_accuracy.percentage, train_loss.avg, time.time()-start))
        if verbose == 2:
            sys.stdout.flush()

        # evaluate on validation and test data
        val_accuracy, val_loss = evaluate(net, val_dataloader, F.cross_entropy)
        test_accuracy, test_loss = evaluate(net, test_dataloader, F.cross_entropy)
        if val_accuracy > val_acc_best: 
            val_acc_best = val_accuracy
            test_acc_best = test_accuracy
            epoch_best = epoch

        summary_writer.add_scalar('train_loss', train_loss.avg, epoch)
        summary_writer.add_scalar('test_loss', test_loss, epoch)
        summary_writer.add_scalar('train_accuracy', train_accuracy.percentage, epoch)
        summary_writer.add_scalar('test_accuracy', test_accuracy, epoch)
        summary_writer.add_scalar('val_loss', val_loss, epoch)
        summary_writer.add_scalar('val_accuracy', val_accuracy, epoch)
        summary_writer.add_scalar('test_accuracy_best', test_acc_best, epoch)
        summary_writer.add_scalar('val_accuracy_best', val_acc_best, epoch)

        if verbose != 0:
            template = 'Epoch {}, Loss: {:7.4f}, Accuracy: {:5.3f}, Val Loss: {:7.4f}, Val Accuracy: {:5.3f}, Test Loss: {:7.4f}, Test Accuracy: {:5.3f}, lr: {:7.6f} Time: {:3.1f}({:3.2f})'
            print(template.format(epoch + 1, train_loss.avg, train_accuracy.percentage, val_loss, val_accuracy, test_loss, test_accuracy, lr_scheduler(epoch), time.time()-start_epoch, (time.time()-start_epoch)/3600))
        save_model(net, epoch)

    print('Train acc: {:5.3f}, Val acc: {:5.3f}-{:5.3f}, Test acc: {:5.3f}-{:5.3f} / Train loss: {:7.4f}, Val loss: {:7.4f}, Test loss: {:7.4f} / Best epoch: {}'.format(
        train_accuracy.percentage, val_accuracy, val_acc_best, test_accuracy, test_acc_best, train_loss.avg, val_loss, test_loss, epoch_best
    ))
    summary_writer.close()
    torch.save(net.state_dict(), os.path.join(log_dir, 'saved_model.pt'))
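
# A minimal sketch of the AverageMeter helper these snippets rely on, inferred
# from usage (update(value, n), .avg, .percentage, .reset()); the original
# implementation may differ.
class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # val is a batch statistic (e.g. a correct-prediction count or a loss);
        # n is how many samples it covers
        self.sum += val
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)

    @property
    def percentage(self):
        return 100.0 * self.avg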
Example #3
def mlnt(criterion, consistent_criterion, start_iter=500, mid_iter=2000, eps=0.99, args_alpha=1, num_fast=10, perturb_ratio=0.5, meta_lr=0.2):
    '''
    2019 - CVPR - Learning to Learn from Noisy Labeled Data
    github repo: https://github.com/LiJunnan1992/MLNT
    '''
    # get model 
    path = Path(log_dir)
    tch_net = get_model(dataset,framework).to(device)
    pretrain_net = get_model(dataset,framework).to(device)

    ce_base_folder = os.path.join(path.parent.parent, 'cross_entropy')
    for f in os.listdir(ce_base_folder):
        ce_path = os.path.join(ce_base_folder, f, 'saved_model.pt')
        if os.path.exists(ce_path):
            print('Loading base model from: {}'.format(ce_path))
            pretrain_net.load_state_dict(torch.load(ce_path, map_location=device))
            break

    # tensorboard
    summary_writer = SummaryWriter(log_dir)
    init = True

    test_acc_best = 0
    val_acc_best = 0
    epoch_best = 0

    for epoch in range(PARAMS[dataset]['epochs']): 
        start_epoch = time.time()
        train_accuracy = AverageMeter()
        train_loss = AverageMeter()

        net.train()
        tch_net.train()
        
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_scheduler(epoch)
        
        for batch_idx, (inputs, targets) in enumerate(train_dataloader):
            start = time.time()
            inputs, targets = inputs.to(device), targets.to(device) 
            optimizer.zero_grad()
            outputs = net(inputs)  # forward propagation
            
            class_loss = criterion(outputs, targets)  # Loss
            class_loss.backward(retain_graph=True)  

            if batch_idx>start_iter or epoch>1:
                if batch_idx>mid_iter or epoch>1:
                    eps=0.999
                    alpha = args_alpha
                else:
                    u = (batch_idx-start_iter)/(mid_iter-start_iter)
                    alpha = args_alpha*math.exp(-5*(1-u)**2)          
            
                if init:
                    init = False
                    for param,param_tch in zip(net.parameters(),tch_net.parameters()): 
                        param_tch.data.copy_(param.data)                    
                else:
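                    # exponential moving average teacher:
                    # param_tch <- eps * param_tch + (1 - eps) * param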
                    for param,param_tch in zip(net.parameters(),tch_net.parameters()):
                        param_tch.data.mul_(eps).add_(param.data, alpha=1 - eps)
                
                _,feats = pretrain_net(inputs,get_feat=True)
                tch_outputs = tch_net(inputs,get_feat=False)
                p_tch = F.softmax(tch_outputs,dim=1)
                p_tch.detach_()
                
                for i in range(num_fast):
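                    # build a synthetic noisy label set: a perturb_ratio fraction
                    # of samples take the label of a random one of their 10
                    # nearest neighbors in the pretrained feature space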
                    targets_fast = targets.clone()
                    randidx = torch.randperm(targets.size(0))
                    for n in range(int(targets.size(0)*perturb_ratio)):
                        num_neighbor = 10
                        idx = randidx[n]
                        feat = feats[idx]
                        # broadcast this sample's feature against all features in the batch
                        feat = feat.data.expand(targets.size(0), feat.size(0))
                        dist = torch.sum((feat-feats)**2,dim=1)
                        _, neighbor = torch.topk(dist.data,num_neighbor+1,largest=False)
                        targets_fast[idx] = targets[neighbor[random.randint(1,num_neighbor)]]
                        
                    fast_loss = criterion(outputs,targets_fast)

                    grads = torch.autograd.grad(fast_loss, net.parameters(), create_graph=True, retain_graph=True, only_inputs=True)

                    fast_weights = OrderedDict((name, param - meta_lr * grad) for ((name, param), grad) in zip(net.named_parameters(), grads))
                    
                    fast_out = net.forward(inputs,fast_weights)  
        
                    logp_fast = F.log_softmax(fast_out,dim=1)                
                    consistent_loss = consistent_criterion(logp_fast,p_tch)
                    consistent_loss = consistent_loss*alpha/num_fast 
                    consistent_loss.backward(retain_graph=True)
                    del grads, fast_weights
                    
            optimizer.step() # Optimizer update 

            _, predicted = torch.max(outputs.data, 1)
            train_accuracy.update(predicted.eq(targets.data).cpu().sum().item(), targets.size(0)) 
            train_loss.update(class_loss.item())

            if verbose == 2:
                sys.stdout.write("Progress: {:6.5f}, Accuracy: {:5.4f}, Loss: {:5.4f}, Process time:{:5.4f}   \r"
                                .format(batch_idx*BATCH_SIZE/NUM_TRAINDATA, train_accuracy.percentage, train_loss.avg, time.time()-start))
        if verbose == 2:
            sys.stdout.flush()
                
        # evaluate on validation and test data
        val_accuracy, val_loss = evaluate(net, val_dataloader, criterion)
        test_accuracy, test_loss = evaluate(net, test_dataloader, criterion)
        if val_accuracy > val_acc_best: 
            val_acc_best = val_accuracy
            test_acc_best = test_accuracy
            epoch_best = epoch

        summary_writer.add_scalar('train_loss', train_loss.avg, epoch)
        summary_writer.add_scalar('test_loss', test_loss, epoch)
        summary_writer.add_scalar('train_accuracy', train_accuracy.percentage, epoch)
        summary_writer.add_scalar('test_accuracy', test_accuracy, epoch)
        summary_writer.add_scalar('val_loss', val_loss, epoch)
        summary_writer.add_scalar('val_accuracy', val_accuracy, epoch)
        summary_writer.add_scalar('test_accuracy_best', test_acc_best, epoch)
        summary_writer.add_scalar('val_accuracy_best', val_acc_best, epoch)

        if verbose != 0:
            template = 'Epoch {}, Loss: {:7.4f}, Accuracy: {:5.3f}, Val Loss: {:7.4f}, Val Accuracy: {:5.3f}, Test Loss: {:7.4f}, Test Accuracy: {:5.3f}, lr: {:7.6f} Time: {:3.1f}({:3.2f})'
            print(template.format(epoch + 1, train_loss.avg, train_accuracy.percentage, val_loss, val_accuracy, test_loss, test_accuracy, lr_scheduler(epoch), time.time()-start_epoch, (time.time()-start_epoch)/3600))
        save_model(net, epoch)

    print('Train acc: {:5.3f}, Val acc: {:5.3f}-{:5.3f}, Test acc: {:5.3f}-{:5.3f} / Train loss: {:7.4f}, Val loss: {:7.4f}, Test loss: {:7.4f} / Best epoch: {}'.format(
        train_accuracy.percentage, val_accuracy, val_acc_best, test_accuracy, test_acc_best, train_loss.avg, val_loss, test_loss, epoch_best
    ))
    summary_writer.close()
    torch.save(tch_net.state_dict(), os.path.join(log_dir, 'saved_model.pt'))
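
# A minimal sketch of the evaluate helper the excerpts assume, inferred from
# usage (returns (accuracy, mean loss)); the original implementation may
# differ, e.g. in whether accuracy is a fraction or a percentage.
def evaluate(net, dataloader, criterion):
    net.eval()
    correct, total, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss_sum += criterion(outputs, targets).item() * targets.size(0)
            correct += outputs.argmax(dim=1).eq(targets).sum().item()
            total += targets.size(0)
    return 100.0 * correct / total, loss_sum / total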
Example #4
def coteaching(criterion):
    '''
    2018 - NIPS - Co-teaching: Robust training of deep neural networks with extremely noisy labels
    github repo: https://github.com/bhanML/Co-teaching
    '''
    # get model 
    net1 = get_model(dataset,framework).to(device)
    net2 = get_model(dataset,framework).to(device)
    optimizer1 = optim.SGD(net1.parameters(), lr=lr_scheduler(0), momentum=0.9, weight_decay=1e-4)
    optimizer2 = optim.SGD(net2.parameters(), lr=lr_scheduler(0), momentum=0.9, weight_decay=1e-4)

    train_loss1 = AverageMeter()
    train_accuracy1 = AverageMeter()
    train_loss2 = AverageMeter()
    train_accuracy2 = AverageMeter()

    # calculate forget rates for each epoch (from the original code)
    forget_rate = 0.2
    num_graduals = 10
    exponent = 0.2

    forget_rates = np.ones(PARAMS[dataset]['epochs'])*forget_rate
    forget_rates[:num_graduals] = np.linspace(0, forget_rate**exponent, num_graduals)
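    # the drop rate ramps linearly from 0 to forget_rate**exponent over the
    # first num_graduals epochs, then stays at forget_rate; remember_rate
    # below is its complement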

    test_acc_best = 0
    val_acc_best = 0
    epoch_best = 0

    for epoch in range(PARAMS[dataset]['epochs']):
        start_epoch = time.time()
        # Reset the metrics at the start of the next epoch
        train_loss1.reset()
        train_accuracy1.reset()
        train_loss2.reset()
        train_accuracy2.reset()
        remember_rate = 1 - forget_rates[epoch]

        for param_group in optimizer1.param_groups:
            param_group['lr'] = lr_scheduler(epoch)
        for param_group in optimizer2.param_groups:
            param_group['lr'] = lr_scheduler(epoch)

        for batch_idx, (images, labels) in enumerate(train_dataloader):
            start = time.time()
            images, labels = images.to(device), labels.to(device)
            # use the actual batch size so a smaller final batch is handled
            num_remember = int(remember_rate * images.size(0))
            
            with torch.no_grad():
                # select small-loss samples according to model 1
                # (cross_entropy expects logits, so no softmax beforehand)
                net1.eval()
                cross_entropy = F.cross_entropy(net1(images), labels, reduction='none')
                batch_idx1 = np.argsort(cross_entropy.cpu().numpy())[:num_remember]
                # select small-loss samples according to model 2
                net2.eval()
                cross_entropy = F.cross_entropy(net2(images), labels, reduction='none')
                batch_idx2 = np.argsort(cross_entropy.cpu().numpy())[:num_remember]

            # train net1
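            # (cross-update: each network trains on the small-loss samples
            # selected by its peer, so their errors do not reinforce each other)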
            net1.train()
            optimizer1.zero_grad()
            outputs = net1(images[batch_idx2,:])
            loss1 = criterion(outputs, labels[batch_idx2])
            loss1.backward()
            optimizer1.step()
            _, predicted = torch.max(outputs.data, 1)
            train_accuracy1.update(predicted.eq(labels[batch_idx2].data).cpu().sum().item(), labels[batch_idx2].size(0))
            train_loss1.update(loss1.item(), images.size(0))
            # train net2
            net2.train()
            optimizer2.zero_grad()
            outputs = net2(images[batch_idx1,:])
            loss2 = criterion(outputs, labels[batch_idx1])
            loss2.backward()
            optimizer2.step()
            _, predicted = torch.max(outputs.data, 1)
            train_accuracy2.update(predicted.eq(labels[batch_idx1].data).cpu().sum().item(), labels[batch_idx1].size(0))
            train_loss2.update(loss2.item(), images.size(0))

            if verbose == 2:
                sys.stdout.write("Progress: {:6.5f}, Accuracy1: {:5.4f}, Loss1: {:5.4f}, Accuracy2: {:5.4f}, Loss2: {:5.4f}, Process time:{:5.4f}   \r"
                             .format(batch_idx*BATCH_SIZE/NUM_TRAINDATA, 
                             train_accuracy1.avg, train_loss1.avg, train_accuracy2.avg, train_loss2.avg, time.time()-start))
        if verbose == 2:
            sys.stdout.flush()

        # evaluate on validation and test data
        val_accuracy1, val_loss1 = evaluate(net1, val_dataloader, criterion)
        test_accuracy1, test_loss1 = evaluate(net1, test_dataloader, criterion)
        # evaluate on validation and test data
        val_accuracy2, val_loss2 = evaluate(net2, val_dataloader, criterion)
        test_accuracy2, test_loss2 = evaluate(net2, test_dataloader, criterion)

        if max(val_accuracy1, val_accuracy2) > val_acc_best: 
            val_acc_best = max(val_accuracy1, val_accuracy2)
            test_acc_best = max(test_accuracy1, test_accuracy2)
            epoch_best = epoch

        summary_writer.add_scalar('train_loss1', train_loss1.avg, epoch)
        summary_writer.add_scalar('train_accuracy1', train_accuracy1.percentage, epoch)
        summary_writer.add_scalar('val_loss1', val_loss1, epoch)
        summary_writer.add_scalar('val_accuracy1', val_accuracy1, epoch)
        summary_writer.add_scalar('test_loss1', test_loss1, epoch)
        summary_writer.add_scalar('test_accuracy1', test_accuracy1, epoch)
        summary_writer.add_scalar('train_loss2', train_loss2.avg, epoch)
        summary_writer.add_scalar('train_accuracy2', train_accuracy2.percentage, epoch)
        summary_writer.add_scalar('val_loss2', val_loss2, epoch)
        summary_writer.add_scalar('val_accuracy2', val_accuracy2, epoch)
        summary_writer.add_scalar('test_loss2', test_loss2, epoch)
        summary_writer.add_scalar('test_accuracy2', test_accuracy2, epoch)

        summary_writer.add_scalar('train_loss', min(train_loss1.avg, train_loss2.avg), epoch)
        summary_writer.add_scalar('train_accuracy', max(train_accuracy1.percentage, train_accuracy2.percentage), epoch)
        summary_writer.add_scalar('val_loss', min(val_loss1, val_loss2), epoch)
        summary_writer.add_scalar('val_accuracy', max(val_accuracy1, val_accuracy2), epoch)
        summary_writer.add_scalar('test_loss', min(test_loss1, test_loss2), epoch)
        summary_writer.add_scalar('test_accuracy', max(test_accuracy1, test_accuracy2), epoch)
        summary_writer.add_scalar('test_accuracy_best', test_acc_best, epoch)
        summary_writer.add_scalar('val_accuracy_best', val_acc_best, epoch)

        if verbose != 0:
            end = time.time()
            template = 'Model1 - Epoch {}, Loss: {:7.4f}, Accuracy: {:5.3f}, Val Loss: {:7.4f}, Val Accuracy: {:5.3f}, Test Loss: {:7.4f}, Test Accuracy: {:5.3f}, lr: {:7.6f} Time: {:3.1f}({:3.2f})'
            print(template.format(epoch + 1,
                                    train_loss1.avg,
                                    train_accuracy1.percentage,
                                    val_loss1,
                                    val_accuracy1,
                                    test_loss1,
                                    test_accuracy1,
                                    lr_scheduler(epoch),
                                    end-start_epoch, (end-start_epoch)/3600))
            template = 'Model2 - Epoch {}, Loss: {:7.4f}, Accuracy: {:5.3f}, Val Loss: {:7.4f}, Val Accuracy: {:5.3f}, Test Loss: {:7.4f}, Test Accuracy: {:5.3f}, lr: {:7.6f} Time: {:3.1f}({:3.2f})'
            print(template.format(epoch + 1,
                                train_loss2.avg,
                                train_accuracy2.percentage,
                                val_loss2,
                                val_accuracy2,
                                test_loss2,
                                test_accuracy2,
                                lr_scheduler(epoch),
                                end-start_epoch, (end-start_epoch)/3600))
        save_model(net1, epoch, 'model1')
        save_model(net2, epoch, 'model2')

    print('Train acc: {:5.3f}, Val acc: {:5.3f}-{:5.3f}, Test acc: {:5.3f}-{:5.3f} / Train loss: {:7.4f}, Val loss: {:7.4f}, Test loss: {:7.4f} / Best epoch: {}'.format(
        max(train_accuracy1.percentage, train_accuracy2.percentage), max(val_accuracy1, val_accuracy2), val_acc_best, max(test_accuracy1, test_accuracy2), test_acc_best, 
        min(train_loss1.avg, train_loss2.avg),  min(val_loss1, val_loss2), min(test_loss1, test_loss2), epoch_best
    ))
    torch.save(net1.state_dict(), os.path.join(log_dir, 'saved_model1.pt'))
    torch.save(net2.state_dict(), os.path.join(log_dir, 'saved_model2.pt'))