Example #1
def PROTOTYPE_MATRIX_update(loader: DataLoader, model: nn.Module,
                            pseudo_labels: PseudoLabeling, args, epoch):
    """Refresh soft pseudo-labels from the cosine similarity between
    averaged features and the rows of the averaged head weight matrix."""
    model.eval()
    with torch.no_grad():
        # Treat the mean of the two classifier weight matrices as a
        # per-class prototype matrix.
        weight1 = model.get_head_weight1()
        weight2 = model.get_head_weight2()
        matrix = 0.5 * weight1 + 0.5 * weight2  # (num_classes, features_dim)
        for i, (images, target, index) in enumerate(loader):
            images = images.to(device)
            y1, f1, y2, f2 = model(images)
            f = 0.5 * f1 + 0.5 * f2  # (batch_size, features_dim)

            # Cosine distance to every prototype, sharpened into a class
            # distribution; equivalent to exp(-20 * dist) row-normalized.
            f_norm = F.normalize(f, dim=1)
            m_norm = F.normalize(matrix, dim=1)
            dist = 1.0 - f_norm @ m_norm.t()  # (batch_size, num_classes)
            prediction = F.softmax(-20.0 * dist, dim=1)

            if args.prob_ema:
                pseudo_labels.EMA_update_p(prediction, index, epoch)
            else:
                pseudo_labels.update_p(prediction, index)
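
Every updater in these examples writes into a shared PseudoLabeling store whose implementation is not shown. The sketch below is reconstructed from the call sites alone: the method names match usage, but the EMA rule, the storage layout, and the simplified constructor (the real one, per Example #8, also takes tc_ema_gamma and a threshold) are assumptions.

import torch


class PseudoLabelingSketch:
    """Assumed interface of PseudoLabeling, reconstructed from usage."""

    def __init__(self, num_samples, num_classes, prob_ema_gamma, device):
        # One soft-label row per target sample, addressed by dataset index.
        self.p = torch.full((num_samples, num_classes), 1.0 / num_classes,
                            device=device)
        self.weight = torch.zeros(num_samples, device=device)
        self.gamma = prob_ema_gamma

    def update_p(self, prediction, index):
        # Overwrite the stored soft labels for the given samples.
        self.p[index] = prediction

    def EMA_update_p(self, prediction, index, epoch):
        # Exponential moving average; plain overwrite on the first epoch.
        if epoch == 0:
            self.p[index] = prediction
        else:
            self.p[index] = (self.gamma * self.p[index]
                             + (1.0 - self.gamma) * prediction)

    def update_weight(self, weight, index):
        self.weight[index] = weight

    def get_weight(self, index):
        return self.weight[index]

    def get_hard_pseudo_label(self, index):
        # Hard label = argmax of the stored soft labels.
        return self.p[index].argmax(dim=1)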
Example #2
def EPOCH_PROTOTYPE_S_update(source_loader: DataLoader,
                             target_loader: DataLoader, model: nn.Module,
                             pseudo_labels: PseudoLabeling, args, epoch):
    """Build per-class prototypes from labeled source features, then
    pseudo-label target samples by distance to those prototypes."""
    model.eval()
    num_classes = model.num_classes
    features_dim = model.features_dim
    count = [0] * num_classes
    prototype_s = torch.zeros((num_classes, features_dim), device=device)

    with torch.no_grad():
        # Accumulate per-class feature sums over the labeled source set.
        for i, (images, target, index) in enumerate(source_loader):
            images = images.to(device)
            target = target.to(device)
            y, f = model(images)

            for j in range(images.shape[0]):
                count[int(target[j].item())] += 1
                prototype_s[int(target[j].item())] += f[j]

        # Class means; the max() guards classes absent from this epoch.
        for i in range(prototype_s.shape[0]):
            prototype_s[i] = prototype_s[i] / max(count[i], 1)

        # Softmax over negative Euclidean distances to the prototypes.
        for i, (images, target, index) in enumerate(target_loader):
            images = images.to(device)
            y, f = model(images)

            dist = ((f.unsqueeze(1) -
                     prototype_s.unsqueeze(0))**2).sum(2).pow(0.5)
            prediction = F.softmax(-dist, 1)

            if args.prob_ema:
                pseudo_labels.EMA_update_p(prediction, index, epoch)
            else:
                pseudo_labels.update_p(prediction, index)
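
The per-sample Python loop that fills prototype_s above can be replaced by a vectorized accumulation with index_add_. This is a sketch written against the same surrounding names (model, source_loader, device, num_classes, features_dim), not code from the original:

import torch

prototype_s = torch.zeros((num_classes, features_dim), device=device)
count_s = torch.zeros(num_classes, device=device)
with torch.no_grad():
    for images, target, index in source_loader:
        images, target = images.to(device), target.to(device)
        y, f = model(images)
        prototype_s.index_add_(0, target, f)  # per-class feature sums
        count_s.index_add_(0, target,
                           torch.ones_like(target, dtype=f.dtype))
prototype_s /= count_s.clamp_min(1).unsqueeze(1)  # class means, guarded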
Example #3
def uncertainty_update(target_loader: DataLoader, model: nn.Module,
                       pseudo_labels: PseudoLabeling, args, epoch):
    """Monte-Carlo dropout pass: average several stochastic predictions per
    sample and turn their spread into a pseudo-label weight."""
    model.eval()
    model.activate_dropout()  # keep dropout stochastic despite eval()

    with torch.no_grad():
        for i, (images, target, index) in enumerate(target_loader):
            images = images.to(device)
            ys = []
            h = model.backbone_forward(images)
            for j in range(args.uncertainty_sample_num):
                y, _ = model.head_forward(h)
                y = F.softmax(y, dim=1)
                ys.append(y)
            # Stack (not cat) so the sampling dimension stays separate:
            # ys is (uncertainty_sample_num, batch_size, num_classes) and
            # y averages over the stochastic passes, not over the batch.
            ys = torch.stack(ys, 0)
            y = torch.mean(ys, 0)

            if args.prob_type == 'prediction_avg':
                if args.prob_ema:
                    pseudo_labels.EMA_update_p(y, index, epoch)
                else:
                    pseudo_labels.update_p(y, index)

            if args.uncertainty_type == 'predictive_entropy':
                uncertainty = PredictiveEntropy(ys)
            elif args.uncertainty_type == 'mutual_info':
                uncertainty = MutualInfo(ys)
            elif args.uncertainty_type == 'variation_ratio':
                uncertainty = VariationRatio(ys, device)
            else:
                raise ValueError(
                    f'uncertainty type not found: {args.uncertainty_type}')

            # Higher uncertainty should mean a smaller weight, hence the
            # sign flip before storing.
            pseudo_labels.update_weight(-1.0 * uncertainty, index)
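
The three uncertainty helpers are not included in these examples. The sketches below follow the standard MC-dropout definitions and assume ys has shape (uncertainty_sample_num, batch_size, num_classes) of already-softmaxed probabilities; the names and signatures match the call sites, everything else is an assumption.

import torch


def PredictiveEntropy(ys: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Entropy of the mean prediction, H[mean_t p_t], per sample.
    p_mean = ys.mean(dim=0)                              # (batch, classes)
    return -(p_mean * (p_mean + eps).log()).sum(dim=1)   # (batch,)


def MutualInfo(ys: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # BALD score: H[mean_t p_t] - mean_t H[p_t], per sample.
    mean_entropy = -(ys * (ys + eps).log()).sum(dim=2).mean(dim=0)
    return PredictiveEntropy(ys, eps) - mean_entropy


def VariationRatio(ys: torch.Tensor, device) -> torch.Tensor:
    # 1 - (share of stochastic passes that vote for the modal class).
    votes = ys.argmax(dim=2)                             # (T, batch)
    modal = votes.mode(dim=0).values                     # (batch,)
    agree = (votes == modal.unsqueeze(0)).sum(dim=0).float()
    return (1.0 - agree / ys.shape[0]).to(device)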
Example #4
def ITERATION_update_pseudo_label(x_t: torch.Tensor, index_t: torch.Tensor,
                                  model: nn.Module,
                                  pseudo_labels: PseudoLabeling, args, epoch):
    """Refresh pseudo-labels for a single target batch in-iteration."""
    model.eval()
    with torch.no_grad():
        y_t, _ = model(x_t)
        y = F.softmax(y_t, dim=1)
        if args.prob_ema:
            pseudo_labels.EMA_update_p(y, index_t, epoch)
        else:
            pseudo_labels.update_p(y, index_t)
Example #5
def EPOCH_update_pseudo_label(loader: DataLoader, model: nn.Module,
                              pseudo_labels: PseudoLabeling, args, epoch):
    """Refresh pseudo-labels for the whole target set once per epoch."""
    model.eval()
    with torch.no_grad():
        for i, (images, target, index) in enumerate(loader):
            images = images.to(device)
            y, _ = model(images)
            y = F.softmax(y, dim=1)
            if args.prob_ema:
                pseudo_labels.EMA_update_p(y, index, epoch)
            else:
                pseudo_labels.update_p(y, index)
Example #6
def PROTOTYPE_MAXST_update(source_loader: DataLoader, val_loader: DataLoader,
                           model: nn.Module, pseudo_labels: PseudoLabeling,
                           args, epoch):
    """Pseudo-label target samples from both source and target prototypes,
    keeping whichever prediction is more confident for each sample."""
    model.eval()
    num_classes = model.num_classes
    features_dim = model.features_dim
    count_s = [0] * num_classes
    count_t = [0] * num_classes
    prototype_s = torch.zeros((num_classes, features_dim), device=device)
    prototype_t = torch.zeros((num_classes, features_dim), device=device)

    with torch.no_grad():
        # Source prototypes: class means of labeled source features.
        for i, (images, target, index) in enumerate(source_loader):
            images = images.to(device)
            target = target.to(device)
            y1, f1, y2, f2 = model(images)
            f = 0.5 * f1 + 0.5 * f2  # (batch_size, features_dim)

            for j in range(images.shape[0]):
                count_s[int(target[j].item())] += 1
                prototype_s[int(target[j].item())] += f[j]

        for i in range(prototype_s.shape[0]):
            prototype_s[i] = prototype_s[i] / max(count_s[i], 1)

        # Target prototypes: class means grouped by the model's own argmax
        # prediction (careful not to shadow the dataset `index`).
        for i, (images, target, index) in enumerate(val_loader):
            images = images.to(device)
            y1, f1, y2, f2 = model(images)
            f = 0.5 * f1 + 0.5 * f2  # (batch_size, features_dim)
            y = 0.5 * y1 + 0.5 * y2  # (batch_size, num_classes)
            confidence, pred = torch.max(y, dim=1)

            for j in range(images.shape[0]):
                count_t[int(pred[j].item())] += 1
                prototype_t[int(pred[j].item())] += f[j]

        for i in range(prototype_t.shape[0]):
            prototype_t[i] = prototype_t[i] / max(count_t[i], 1)

        # Final pass: softmax over negative Euclidean distances to each
        # prototype set, then keep the more confident prediction per sample.
        for i, (images, target, index) in enumerate(val_loader):
            images = images.to(device)
            y1, f1, y2, f2 = model(images)
            f = 0.5 * f1 + 0.5 * f2

            prediction_s = F.softmax(-torch.cdist(f, prototype_s), dim=1)
            prediction_t = F.softmax(-torch.cdist(f, prototype_t), dim=1)

            use_s = (prediction_s.max(dim=1).values >
                     prediction_t.max(dim=1).values).unsqueeze(1)
            prediction = torch.where(use_s, prediction_s, prediction_t)

            if args.prob_ema:
                pseudo_labels.EMA_update_p(prediction, index, epoch)
            else:
                pseudo_labels.update_p(prediction, index)
Example #7
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: nn.Module, jmmd_loss,
          optimizer: SGD, lr_scheduler: StepwiseLR,
          pseudo_labels: PseudoLabeling, epoch: int, args: argparse.Namespace):
    losses = AverageMeter('Loss', ':3.2f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':5.2f')
    joint_losses = AverageMeter('Joint Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [losses, cls_losses, trans_losses, joint_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    model.train()
    if args.freeze_bn:
        # Keep backbone BatchNorm statistics frozen during adaptation.
        for m in model.backbone.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    for i in range(args.iters_per_epoch):
        lr_scheduler.step()

        x_s, labels_s, index_s = next(train_source_iter)
        x_t, labels_t, index_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        h = model.backbone_forward(x)
        y, f = model.head_forward(h)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        loss = 0.0

        cls_loss = F.cross_entropy(y_s, labels_s)
        loss += cls_loss

        transfer_loss = jmmd_loss((f_s, F.softmax(y_s, dim=1)),
                                  (f_t, F.softmax(y_t, dim=1)))
        loss += transfer_loss * args.lambda1

        if epoch >= args.start_epoch:
            with torch.no_grad():
                pseudo_labels_t = pseudo_labels.get_hard_pseudo_label(index_t)
                weights_t = pseudo_labels.get_weight(index_t)

            # Normalize the stored per-sample weights into a distribution
            # over the batch, sharpened by the temperature.
            weights_t = F.softmax(weights_t / args.temperature, dim=0)

            # Average the disagreement loss over several stochastic forward
            # passes through the MC-dropout head.
            joint_loss = 0.0
            for j in range(args.loss_sample_num):
                y_1, _ = model.head_forward(h)
                y1_s, y1_t = y_1.chunk(2, dim=0)
                y_2, _ = model.head_forward(h)
                y2_s, y2_t = y_2.chunk(2, dim=0)

                joint_loss += CE_disagreement(y1_s, y1_t, y2_s, y2_t, labels_s,
                                              pseudo_labels_t, weights_t)

            joint_loss = joint_loss / args.loss_sample_num
            loss += joint_loss * args.lambda2

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))
        cls_losses.update(cls_loss.item(), x_s.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))
        # joint_loss only exists once pseudo-label training has started.
        if epoch >= args.start_epoch:
            joint_losses.update(joint_loss.item(), x_s.size(0))
        else:
            joint_losses.update(0.0, x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        if args.gradclip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradclip)
        optimizer.step()

        if i % args.print_freq == 0:
            progress.display(i)
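
AverageMeter, ProgressMeter, and accuracy are not shown in these examples; the calls match the standard helpers from the PyTorch ImageNet example. A minimal sketch of accuracy under that assumption (the meters are plain bookkeeping):

import torch


def accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1,)):
    """Top-k accuracy in percent; returns one tensor per k in `topk`."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()                                   # (maxk, batch)
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res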
Example #8
def main(args: argparse.Namespace):
    setup_seed(args.seed)
    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        ResizeImage(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])
    val_transform = transforms.Compose([
        ResizeImage(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(), normalize
    ])

    dataset = datasets.__dict__[args.data]
    args.root = os.path.join(args.root, args.data)
    train_source_dataset = dataset(root=args.root,
                                   task=args.source,
                                   download=False,
                                   transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    train_target_dataset = dataset(root=args.root,
                                   task=args.target,
                                   download=False,
                                   transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    val_dataset = dataset(root=args.root,
                          task=args.target,
                          download=False,
                          transform=val_transform)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.val_batch_size,
                            shuffle=False,
                            num_workers=args.workers)
    if args.data == 'DomainNet':
        test_dataset = dataset(root=args.root,
                               task=args.target,
                               evaluate=True,
                               download=False,
                               transform=val_transform)
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.val_batch_size,
                                 shuffle=False,
                                 num_workers=args.workers)
    else:
        test_loader = val_loader

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # pseudo-label store for the target dataset
    pseudo_labels = PseudoLabeling(len(train_target_dataset),
                                   train_target_dataset.num_classes,
                                   args.prob_ema_gamma, args.tc_ema_gamma,
                                   args.threshold, device)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    num_classes = train_source_dataset.num_classes
    model = MCdropClassifier(backbone=backbone,
                             num_classes=num_classes,
                             bottleneck_dim=args.bottleneck_dim,
                             classifier_width=args.classifier_width,
                             dropout_rate=args.dropout_rate,
                             dropout_type=args.dropout_type).to(device)

    # define loss function
    jmmd_loss = JointMultipleKernelMaximumMeanDiscrepancy(
        kernels=([GaussianKernel(alpha=2**k) for k in range(-3, 2)],
                 (GaussianKernel(sigma=0.92, track_running_stats=False), )),
        linear=args.linear,
        thetas=None).to(device)

    # define optimizer
    parameters = model.get_parameters(args.freeze_backbone,
                                      args.backbone_decay)
    optimizer = SGD(parameters,
                    args.lr,
                    momentum=args.momentum,
                    weight_decay=args.wd,
                    nesterov=True)
    lr_scheduler = StepwiseLR(optimizer,
                             init_lr=args.lr,
                             gamma=args.gamma,
                             decay_rate=args.decay_rate)

    # start training
    best_acc1 = 0.
    best_model = copy.deepcopy(model.state_dict())  # fallback if no epoch improves
    for epoch in range(args.epochs):
        psudo_label_update_and_weight_calculate(train_source_loader,
                                                val_loader, model,
                                                pseudo_labels, args, epoch)
        pseudo_labels.copy_history()

        # train for one epoch
        train(train_source_iter, train_target_iter, model, jmmd_loss,
              optimizer, lr_scheduler, pseudo_labels, epoch, args)

        acc1 = validate(val_loader, model, args)

        # remember best acc@1 and save checkpoint
        if acc1 > best_acc1:
            best_model = copy.deepcopy(model.state_dict())
            best_acc1 = acc1
            print('found new best!')
        print("current best = {:3.3f}".format(best_acc1))

    print("best_acc1 = {:3.3f}".format(best_acc1))

    # evaluate on test set
    model.load_state_dict(best_model)
    acc1 = validate(test_loader, model, args)
    print("test_acc1 = {:3.3f}".format(acc1))