Exemplo n.º 1
0
def val(model, val_loader, criterion, epoch, writer, use_CUDA = True):
    """Evaluate ``model`` on every batch of ``val_loader`` without gradients.

    For each batch, the ``criterion`` loss and the ``accuracy`` of the softmax
    prediction against the label are fed to two ``ScalarLogger`` instances.
    Both loggers are then written to ``writer`` under the ``'val'`` tag for
    ``epoch``, and their averages are printed.

    ``use_CUDA`` is accepted but not used by this function.

    Returns:
        ``(accuracy, loss)`` -- the two loggers' ``avg()`` values.
    """
    model.eval()
    acc_log = ScalarLogger(prefix = 'accuracy')
    loss_log = ScalarLogger(prefix = 'loss')

    with torch.no_grad():
        for batch_x, batch_y, _ in val_loader:
            batch_x = to_var(batch_x, requires_grad = False)
            batch_y = to_var(batch_y, requires_grad = False).long()

            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)
            batch_acc = accuracy(torch.softmax(logits, 1), batch_y)
            acc_log.update(batch_acc)
            loss_log.update(batch_loss)

    # Same write order as before: accuracy first, then loss.
    for log in (acc_log, loss_log):
        log.write(writer, 'val', epoch)

    mean_acc = acc_log.avg()
    mean_loss = loss_log.avg()
    print("Validation Epoch: {}, Accuracy: {}, Losses: {}".format(epoch, mean_acc, mean_loss))
    return mean_acc, mean_loss
Exemplo n.º 2
0
def train(model, input_channel, optimizer, criterion, train_loader, val_loader, epoch, writer, args, use_CUDA = True, clamp = False, num_classes = 10):
    """Run one plain (un-weighted) training epoch of ``model``.

    Each batch is optimised with ``optimizer`` on the per-sample
    cross-entropy, summed and divided by the batch size (i.e. the batch
    mean). The batch accuracy of the softmax prediction is accumulated in a
    ``ScalarLogger``, written to ``writer`` under the ``'train'`` tag for
    ``epoch``, and printed.

    ``input_channel``, ``criterion``, ``val_loader``, ``args``, ``use_CUDA``,
    ``clamp`` and ``num_classes`` are accepted but not used by this function.

    Returns:
        The epoch's average training accuracy (``accuracy_logger.avg()``).
    """
    model.train()
    # Per-sample losses; reduced manually below.
    # ``reduction='none'`` replaces the deprecated ``reduce=False``.
    meta_criterion = nn.CrossEntropyLoss(reduction='none')
    accuracy_logger = ScalarLogger(prefix = 'accuracy')

    # The third item (clean label) is not needed for un-weighted training.
    for (inputs, label, _) in train_loader:
        inputs = to_var(inputs, requires_grad = False)
        label = to_var(label, requires_grad = False).long()
        output = model(inputs)
        loss = meta_criterion(output, label).sum() / inputs.shape[0]
        prediction = torch.softmax(output, 1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        top1 = accuracy(prediction, label)
        accuracy_logger.update(top1)

    accuracy_logger.write(writer, 'train', epoch)

    train_accuracy = accuracy_logger.avg()
    print("Training Epoch: {}, Accuracy: {}".format(epoch, train_accuracy))
    return train_accuracy
Exemplo n.º 3
0
def train(model,
          vnet,
          input_channel,
          optimizers,
          optimizer_vnet,
          components,
          criterion,
          train_loader,
          val_loader,
          epoch,
          writer,
          args,
          use_CUDA=True,
          clamp=False,
          num_classes=10):
    """Run one training epoch where a weighting network ``vnet`` supplies
    per-sample loss weights.

    For every training batch:

    1. A fresh copy of ``model`` (``meta_model``, built with ``get_model`` and
       loaded from ``model.state_dict()``) computes the per-sample training
       loss. ``vnet`` maps the detached losses to weights ``eps``. Column 0
       weights the backbone loss and column 1 the fc loss.
    2. The backbone look-ahead: the backbone takes a virtual 0.001-step. Its
       summed loss on a validation batch is back-propagated and
       ``optimizer_vnet`` steps.
    3. The fc look-ahead: the same is done for a virtual step of the ``fc``
       head only.
    4. ``vnet`` (without gradients) weights the real per-sample training loss.
       The weights are optionally clamped at 0 and L1-normalised per column.
       Each component ``c`` in ``components`` is then updated with
       ``optimizers[c]`` on the weighted loss sum.

    Per-component weights (also split by the noisy-label mask
    ``label != real``), per-component losses and training accuracy are
    written to ``writer``. ``criterion`` is accepted but not used.

    NOTE(review): the weights are only assigned to ``w['backbone']`` and
    ``w['fc']``, so ``components`` presumably must be a subset of
    ``{'backbone', 'fc'}``. Confirm with callers.

    Returns:
        The epoch's average training accuracy (``accuracy_logger.avg()``).
    """
    model.train()
    iter_val_loader = iter(val_loader)
    # ``reduction='none'`` replaces the deprecated ``reduce=False``.
    meta_criterion = nn.CrossEntropyLoss(reduction='none')
    noisy_labels = []
    true_labels = []

    w = defaultdict()
    w_logger = defaultdict()
    losses_logger = defaultdict()
    accuracy_logger = ScalarLogger(prefix='accuracy')
    for c in components:
        w[c] = None
        w_logger[c] = WLogger()
        losses_logger[c] = ScalarLogger(prefix='loss')

    for (inputs, label, real) in train_loader:
        # Kept to build the noisy-label mask for the weight logs.
        noisy_labels.append(label)
        true_labels.append(real)

        # Fresh copy of the current model for the virtual look-ahead steps.
        meta_model = get_model(args,
                               num_classes=num_classes,
                               input_channel=input_channel)
        meta_model.load_state_dict(model.state_dict())
        if use_CUDA:
            meta_model = meta_model.cuda()

        val_input, val_label, iter_val_loader = get_val_samples(
            iter_val_loader, val_loader)
        inputs = to_var(inputs, requires_grad=False)
        label = to_var(label, requires_grad=False).long()
        val_input = to_var(val_input, requires_grad=False)
        val_label = to_var(val_label, requires_grad=False).long()

        meta_output = meta_model(inputs)
        cost = meta_criterion(meta_output, label)
        cost_v = torch.reshape(cost, (len(cost), 1))
        # Expected shape (N, 2): column 0 weights the backbone loss, column 1
        # the fc loss. vnet sees detached losses, but its output stays in the
        # graph so the validation loss can be back-propagated into vnet.
        eps = vnet(cost_v.data)

        meta_loss_backbone = (cost * eps[:, 0]).sum()
        meta_loss_fc = (cost * eps[:, 1]).sum()
        meta_model.zero_grad()

        grads_backbone = torch.autograd.grad(
            meta_loss_backbone, (meta_model.backbone.parameters()),
            create_graph=True,
            retain_graph=True)
        grads_fc = torch.autograd.grad(meta_loss_fc,
                                       (meta_model.fc.parameters()),
                                       create_graph=True)

        # Backbone look-ahead: virtual step on the backbone only, then train
        # vnet on the resulting validation loss.
        meta_model.backbone.update_params(0.001, source_params=grads_backbone)
        meta_val_feature = torch.flatten(meta_model.backbone(val_input), 1)
        meta_val_output = meta_model.fc(meta_val_feature)
        meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
        # TODO: the optional KL regulariser (args.with_kl / args.reg_start via
        # sample_wise_kl) is temporarily disabled in this vnet variant.
        optimizer_vnet.zero_grad()
        meta_val_loss.backward(retain_graph=True)
        optimizer_vnet.step()

        # FC look-ahead: reset to the current weights, then take a virtual
        # step on the fc head only.
        meta_model.load_state_dict(model.state_dict())
        meta_model.fc.update_params(0.001, source_params=grads_fc)
        meta_val_output = meta_model(val_input)
        meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
        optimizer_vnet.zero_grad()
        # NOTE(review): this backward runs through ``eps``. vnet produced
        # ``eps`` *before* the previous ``optimizer_vnet.step()`` modified
        # vnet's parameters in place. If vnet's graph saved any of those
        # parameters, autograd raises an in-place-modification error here.
        # Otherwise the gradient is evaluated at the pre-step parameters.
        # Confirm against vnet's architecture.
        meta_val_loss.backward(retain_graph=True)
        optimizer_vnet.step()

        output = model(inputs)
        losses = defaultdict()
        loss = meta_criterion(output, label)
        loss_v = torch.reshape(loss, (len(loss), 1))
        with torch.no_grad():
            w_ = vnet(loss_v)
            if clamp:
                w_ = torch.clamp(w_, min=0)
            # L1-normalise each column by *dividing* by its norm, so the
            # per-sample weights are preserved. All-zero columns (possible
            # after clamping) are left as-is rather than becoming 0/0 = NaN.
            col_norms = torch.sum(torch.abs(w_), dim=0, keepdim=True)
            col_norms = torch.where(col_norms > 0, col_norms,
                                    torch.ones_like(col_norms))
            w_ = w_ / col_norms
            w['backbone'] = w_[:, 0]
            w['fc'] = w_[:, 1]

        prediction = torch.softmax(output, 1)
        for c in components:
            w_logger[c].update(w[c])
            losses[c] = (loss * w[c]).sum()
            optimizers[c].zero_grad()
            # NOTE(review): an earlier component's ``step()`` updates
            # parameters in place before this backward traverses the same
            # forward graph. Verify that no stepped parameter is saved by
            # that graph (component order / optimizer param groups).
            losses[c].backward(retain_graph=True)
            optimizers[c].step()
            losses_logger[c].update(losses[c])

        top1 = accuracy(prediction, label)
        accuracy_logger.update(top1)

    noisy_labels = torch.cat(noisy_labels)
    true_labels = torch.cat(true_labels)
    # True where the training label differs from the clean label.
    mask = (noisy_labels != true_labels).cpu().numpy()
    for c in components:
        w_logger[c].write(writer, c, epoch)
        w_logger[c].mask_write(writer, c, epoch, mask)
        losses_logger[c].write(writer, c, epoch)

    accuracy_logger.write(writer, 'train', epoch)

    train_accuracy = accuracy_logger.avg()
    print("Training Epoch: {}, Accuracy: {}".format(epoch, train_accuracy))
    return train_accuracy
Exemplo n.º 4
0
def _weights_from_eps_grad(grad_eps, clamp):
    """Turn the meta-gradient w.r.t. the per-sample perturbation ``eps`` into
    L1-normalised per-sample loss weights.

    The weight is ``-grad_eps``, clamped to be non-negative when ``clamp`` is
    true, then divided by its L1 norm. When the norm is zero (e.g. every
    entry was clamped away), the zero weights are returned as-is instead of
    producing NaN from 0 / 0.
    """
    weights = torch.clamp(-grad_eps, min=0) if clamp else -grad_eps
    norm = torch.sum(torch.abs(weights))
    if norm != 0:
        weights = weights / norm
    return weights


def train(model,
          input_channel,
          optimizers,
          criterion,
          components,
          train_loader,
          val_loader,
          epoch,
          writer,
          args,
          use_CUDA=True,
          clamp=False,
          num_classes=10):
    """Run one training epoch with per-sample weights derived from
    meta-gradients.

    For each training batch, a copy of ``model`` (``meta_model``, built with
    ``get_model`` and loaded from ``model.state_dict()``) takes a virtual
    0.001-step on the per-sample training loss scaled by a zero perturbation
    ``eps``. The gradient of the validation-batch loss w.r.t. ``eps`` becomes
    the per-sample weights (see ``_weights_from_eps_grad``):

    * ``'all'`` in ``components``: one look-ahead step over all parameters
      gives ``w['all']``. When ``'fc'`` (or otherwise ``'backbone'``) is also
      present, that component receives the negative part of the weights and
      ``w['all']`` keeps the non-negative part.
    * Otherwise, ``'backbone'`` and ``'fc'`` are both required. Separate
      look-ahead steps on the backbone and on the fc head give
      ``w['backbone']`` and ``w['fc']``. When ``args.with_kl`` is set, from
      epoch ``args.reg_start`` onwards, ``sample_wise_kl`` of the train vs.
      val backbone features is subtracted from the backbone validation loss.

    Each component ``c`` is then trained with ``optimizers[c]`` on the
    weighted sum of per-sample losses. Weights (also split by the
    noisy-label mask ``label != real``), per-component losses and accuracy
    are written to ``writer``. ``criterion`` is accepted but not used.

    Returns:
        The epoch's average training accuracy (``accuracy_logger.avg()``).
    """
    # Validate the component combination once, before any work is done.
    if 'all' in components:
        assert (clamp and len(components)
                == 1) or (len(components) > 1), "Error combination"
    else:
        assert ('backbone' in components) and ('fc' in components), \
            "components must contain 'backbone' and 'fc' when 'all' is absent"

    model.train()
    iter_val_loader = iter(val_loader)
    # ``reduction='none'`` replaces the deprecated ``reduce=False``.
    meta_criterion = nn.CrossEntropyLoss(reduction='none')
    noisy_labels = []
    true_labels = []

    w = defaultdict()
    w_logger = defaultdict()
    losses_logger = defaultdict()
    accuracy_logger = ScalarLogger(prefix='accuracy')
    for c in components:
        w[c] = None
        w_logger[c] = WLogger()
        losses_logger[c] = ScalarLogger(prefix='loss')

    for (inputs, label, real) in train_loader:
        # Kept to build the noisy-label mask for the weight logs.
        noisy_labels.append(label)
        true_labels.append(real)

        # Fresh copy of the current model for the virtual look-ahead steps.
        meta_model = get_model(args,
                               num_classes=num_classes,
                               input_channel=input_channel)
        meta_model.load_state_dict(model.state_dict())
        if use_CUDA:
            meta_model = meta_model.cuda()

        val_input, val_label, iter_val_loader = get_val_samples(
            iter_val_loader, val_loader)
        inputs = to_var(inputs, requires_grad=False)
        label = to_var(label, requires_grad=False).long()
        val_input = to_var(val_input, requires_grad=False)
        val_label = to_var(val_label, requires_grad=False).long()

        meta_output = meta_model(inputs)
        cost = meta_criterion(meta_output, label)
        # Zero perturbation: the loss value is unchanged, but the validation
        # loss can be differentiated w.r.t. each sample's weight.
        eps = to_var(torch.zeros(cost.size()))
        meta_loss = (cost * eps).sum()
        meta_model.zero_grad()

        if 'all' in components:
            grads = torch.autograd.grad(meta_loss, (meta_model.parameters()),
                                        create_graph=True)
            meta_model.update_params(0.001, source_params=grads)

            meta_val_output = meta_model(val_input)
            meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
            grad_eps = torch.autograd.grad(meta_val_loss,
                                           eps,
                                           only_inputs=True)[0]
            w['all'] = _weights_from_eps_grad(grad_eps, clamp)

            # Split the weights: the extra component gets the negative part,
            # 'all' keeps the non-negative part. torch.clamp returns a new
            # tensor, so no explicit copy is needed.
            if 'fc' in components:
                w['fc'] = torch.clamp(w['all'], max=0)
                w['all'] = torch.clamp(w['all'], min=0)
            elif 'backbone' in components:
                w['backbone'] = torch.clamp(w['all'], max=0)
                w['all'] = torch.clamp(w['all'], min=0)

        else:
            grads_backbone = torch.autograd.grad(
                meta_loss, (meta_model.backbone.parameters()),
                create_graph=True,
                retain_graph=True)
            grads_fc = torch.autograd.grad(meta_loss,
                                           (meta_model.fc.parameters()),
                                           create_graph=True)

            # Backbone look-ahead: virtual step on the backbone only.
            meta_model.backbone.update_params(0.001,
                                              source_params=grads_backbone)
            meta_val_feature = torch.flatten(meta_model.backbone(val_input), 1)
            # The classifier head consumes backbone features, not raw inputs.
            meta_val_output = meta_model.fc(meta_val_feature)
            meta_val_loss = meta_criterion(meta_val_output, val_label).sum()

            if args.with_kl and args.reg_start <= epoch:
                train_feature = torch.flatten(meta_model.backbone(inputs), 1)
                # Out-of-place, so no tensor in the autograd graph is mutated.
                meta_val_loss = meta_val_loss - sample_wise_kl(
                    train_feature, meta_val_feature)

            grad_eps = torch.autograd.grad(meta_val_loss,
                                           eps,
                                           only_inputs=True,
                                           retain_graph=True)[0]
            w['backbone'] = _weights_from_eps_grad(grad_eps, clamp)

            # FC look-ahead: reset to the current weights, then take a
            # virtual step on the fc head only.
            meta_model.load_state_dict(model.state_dict())
            meta_model.fc.update_params(0.001, source_params=grads_fc)
            meta_val_output = meta_model(val_input)
            meta_val_loss = meta_criterion(meta_val_output, val_label).sum()
            grad_eps = torch.autograd.grad(meta_val_loss,
                                           eps,
                                           only_inputs=True,
                                           retain_graph=True)[0]
            w['fc'] = _weights_from_eps_grad(grad_eps, clamp)

        output = model(inputs)
        prediction = torch.softmax(output, 1)
        # Per-sample loss is shared by every component; compute it once.
        per_sample_loss = meta_criterion(output, label)
        loss = defaultdict()
        for c in components:
            w_logger[c].update(w[c])
            loss[c] = (per_sample_loss * w[c]).sum()
            optimizers[c].zero_grad()
            # NOTE(review): an earlier component's ``step()`` updates
            # parameters in place before this backward traverses the same
            # forward graph. Verify that no stepped parameter is saved by
            # that graph (component order / optimizer param groups).
            loss[c].backward(retain_graph=True)
            optimizers[c].step()
            losses_logger[c].update(loss[c])

        top1 = accuracy(prediction, label)
        accuracy_logger.update(top1)

    noisy_labels = torch.cat(noisy_labels)
    true_labels = torch.cat(true_labels)
    # True where the training label differs from the clean label.
    mask = (noisy_labels != true_labels).cpu().numpy()
    for c in components:
        w_logger[c].write(writer, c, epoch)
        w_logger[c].mask_write(writer, c, epoch, mask)
        losses_logger[c].write(writer, c, epoch)

    accuracy_logger.write(writer, 'train', epoch)

    train_accuracy = accuracy_logger.avg()
    print("Training Epoch: {}, Accuracy: {}".format(epoch, train_accuracy))
    return train_accuracy