Example #1
    model_path = checkpoint_path + args.resume
    if os.path.isfile(model_path):
        print('==> loading checkpoint {}'.format(args.resume))
        checkpoint = torch.load(model_path)
        start_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['net'])
        print('==> loaded checkpoint {} (epoch {})'.format(
            args.resume, checkpoint['epoch']))
    else:
        print('==> no checkpoint found at {}'.format(args.resume))

# define loss function (use label smoothing only when the flag is on)
if args.label_smooth == 'on':
    criterion_id = CrossEntropyLabelSmooth(n_class)
else:
    criterion_id = nn.CrossEntropyLoss()

if args.method == 'agw':
    criterion_tri = TripletLoss_WRT()
else:
    loader_batch = args.batch_size * args.num_pos
    #criterion_tri= OriTripletLoss(batch_size=loader_batch, margin=args.margin)
    criterion_tri = CenterTripletLoss(batch_size=loader_batch,
                                      margin=args.margin)

criterion_id.to(device)
criterion_tri.to(device)

if args.optim == 'sgd':
    if args.pcb == 'on':
        ignored_params = list(map(id, net.local_conv_list.parameters())) \
def train(parameters: Dict[str, float]) -> nn.Module:
    global args 
    print("====", args.focus,  "=====")
    torch.manual_seed(args.seed)
    # args.gpu_devices = "0,1"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    cudnn.benchmark = True
    torch.cuda.manual_seed_all(args.seed)
    
    dataset = data_manager.init_dataset(name=args.dataset, sampling= args.sampling)
    transform_test = transforms.Compose([
        transforms.Resize((args.height, args.width), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


    pin_memory = True if use_gpu else False
    transform_train = transforms.Compose([
                transforms.Resize((args.height, args.width), interpolation=3),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.Pad(10),
                Random2DTranslation(args.height, args.width),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

    batch_size = int(round(parameters.get("batch_size", 32)))
    base_learning_rate = 0.00035
    # weight_decay = 0.0005
    alpha = parameters.get("alpha", 1.2)
    sigma = parameters.get("sigma", 0.8)
    l = parameters.get("l", 0.5)
    beta_ratio = parameters.get("beta_ratio", 0.5)
    gamma = parameters.get("gamma", 0.1)
    margin = parameters.get("margin", 0.3)
    weight_decay = parameters.get("weight_decay", 0.0005)
    lamb = 0.3 
    
    num_instances = 4
    pin_memory = True
    trainloader = DataLoader(
    VideoDataset(dataset.train, seq_len=args.seq_len, sample='random',transform=transform_train),
    sampler=RandomIdentitySampler(dataset.train, num_instances=args.num_instances),
    batch_size=batch_size, num_workers=args.workers,
    pin_memory=pin_memory, drop_last=True,
    )

    if args.dataset == 'mars_subset' :
        validation_loader = DataLoader(
            VideoDataset(dataset.val, seq_len=8, sample='random', transform=transform_test),
            batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
            pin_memory=pin_memory, drop_last=False,
        )
    else:
        queryloader = DataLoader(
            VideoDataset(dataset.val_query, seq_len=args.seq_len, sample='dense_subset', transform=transform_test),
            batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
            pin_memory=pin_memory, drop_last=False,
        )
        galleryloader = DataLoader(
            VideoDataset(dataset.val_gallery, seq_len=args.seq_len, sample='dense_subset', transform=transform_test),
            batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
            pin_memory=pin_memory, drop_last=False,
        )

    criterion_htri = TripletLoss(margin, 'cosine')
    criterion_xent = CrossEntropyLabelSmooth(dataset.num_train_pids)
    criterion_center_loss = CenterLoss(use_gpu=1)
    criterion_osm_caa = OSM_CAA_Loss(alpha=alpha , l=l , osm_sigma=sigma )
    args.arch = "ResNet50ta_bt"
    model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={'xent', 'htri'})
    if use_gpu:
        model = nn.DataParallel(model).cuda()
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        params += [{"params": [value], "lr": base_learning_rate, "weight_decay": weight_decay}]

    optimizer = torch.optim.Adam(params)
    scheduler = WarmupMultiStepLR(optimizer, milestones=[40, 70], gamma=gamma, warmup_factor=0.01, warmup_iters=10)
    optimizer_center = torch.optim.SGD(criterion_center_loss.parameters(), lr=0.5)
    start_epoch = args.start_epoch
    best_rank1 = -np.inf
    num_epochs = 121
    
    if 'mars' not in args.dataset :
        num_epochs = 121
    # test_rerank(model, queryloader, galleryloader, args.pool, use_gpu, lamb=lamb , parameters=parameters)
    for epoch in range (num_epochs):
        vals = train_model(model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu , optimizer_center , criterion_center_loss, criterion_osm_caa, beta_ratio)
        if math.isnan(vals[0]):
            return 0
        scheduler.step()
        if epoch % 40 == 0:
            print("TripletLoss {:.6f} OSM Loss {:.6f} Cross_entropy {:.6f} Total Loss {:.6f}".format(vals[1], vals[3], vals[2], vals[0]))
    
    if args.dataset == 'mars_subset' :
        result1 = test_validation(model, validation_loader, args.pool, use_gpu,  parameters=parameters)
        del validation_loader
    else:
        result1= test_rerank(model, queryloader, galleryloader, args.pool, use_gpu, lamb=lamb , parameters=parameters)    
        del queryloader
        del galleryloader
    del trainloader 
    del model
    del criterion_htri
    del criterion_xent
    del criterion_center_loss
    del criterion_osm_caa
    del optimizer
    del optimizer_center
    del scheduler
    return result1
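
Every example on this page assumes a CrossEntropyLabelSmooth class supplied by the surrounding project; none of the snippets show its definition. The following is only a minimal sketch of the implementation commonly used in such codebases (uniform label smoothing with factor epsilon over num_classes), not necessarily the exact class each repository ships:

import torch
import torch.nn as nn


class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy against uniformly smoothed targets:
    q = (1 - epsilon) * one_hot(y) + epsilon / num_classes (assumed implementation)."""

    def __init__(self, num_classes, epsilon=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        # inputs: (batch, num_classes) raw logits; targets: (batch,) integer class labels
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        return (-smoothed * log_probs).sum(dim=1).mean()


# Illustrative call with made-up shapes (751 identities, batch of 32):
# criterion = CrossEntropyLabelSmooth(num_classes=751, epsilon=0.1)
# loss = criterion(torch.randn(32, 751), torch.randint(0, 751, (32,)))
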
Example #3
def main():
    global args, best_prec1
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        # pass pretrained=True so the flag actually loads pretrained weights
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.gpu is not None:
        model = model.cuda(args.gpu)
    elif args.distributed:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    else:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    print(model)

    # count the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # define loss function (criterion) and optimizer
    criterion = CrossEntropyLabelSmooth(num_classes=1000, epsilon=0.1)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        eval_start = time.time()
        _, _ = validate(val_loader, model, criterion)
        eval_end = time.time()
        print('Validation time: {:.4f} h'.format((eval_end - eval_start) / 3600))
        return

    directory = "runs/%s/" % (args.arch + '_' + args.action)
    if not os.path.exists(directory):
        os.makedirs(directory)

    Loss_plot = {}
    train_prec1_plot = {}
    train_prec5_plot = {}
    val_prec1_plot = {}
    val_prec5_plot = {}

    for epoch in range(args.start_epoch, args.epochs):
        start_time = time.time()
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        # train(train_loader, model, criterion, optimizer, epoch)
        loss_temp, train_prec1_temp, train_prec5_temp = train(
            train_loader, model, criterion, optimizer, epoch)
        Loss_plot[epoch] = loss_temp
        train_prec1_plot[epoch] = train_prec1_temp
        train_prec5_plot[epoch] = train_prec5_temp

        # evaluate on validation set
        # prec1 = validate(val_loader, model, criterion)
        prec1, prec5 = validate(val_loader, model, criterion)
        val_prec1_plot[epoch] = prec1
        val_prec5_plot[epoch] = prec5

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)

        if is_best:
            best_epoch = epoch + 1
            best_prec5 = prec5
        print(
            ' * BestPrec so far@1 {top1:.3f} @5 {top5:.3f} in epoch {best_epoch}'
            .format(top1=best_prec1, top5=best_prec5, best_epoch=best_epoch))

        data_save(directory + 'Loss_plot.txt', Loss_plot)
        data_save(directory + 'train_prec1.txt', train_prec1_plot)
        data_save(directory + 'train_prec5.txt', train_prec5_plot)
        data_save(directory + 'val_prec1.txt', val_prec1_plot)
        data_save(directory + 'val_prec5.txt', val_prec5_plot)

        end_time = time.time()
        time_value = (end_time - start_time) / 3600
        print("-" * 80)
        print(time_value)
        print("-" * 80)
Example #4
def train(model, model_ema, memorybank, labeled_eval_loader_train,
          unlabeled_eval_loader_test, unlabeled_eval_loader_train, args):
    labeled_train_loader = CIFAR10Loader_iter(root=args.dataset_root,
                                              batch_size=args.batch_size // 2,
                                              split='train',
                                              aug='twice',
                                              shuffle=True,
                                              target_list=range(
                                                  args.num_labeled_classes))
    optimizer = SGD(model.parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                           step_size=args.step_size,
                                           gamma=args.gamma)
    criterion1 = nn.CrossEntropyLoss()
    criterion2 = BCE()
    criterion3 = CrossEntropyLabelSmooth(
        num_classes=args.num_unlabeled_classes)

    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        model_ema.train()
        exp_lr_scheduler.step()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(
            epoch, args.rampup_length)

        iters = 400

        if epoch % 5 == 0:
            args.head = 'head2'
            feats, feats_mb, _ = test(model_ema, unlabeled_eval_loader_train,
                                      args)
            feats = F.normalize(torch.cat(feats, dim=0), dim=1)
            feats_mb = F.normalize(torch.cat(feats_mb, dim=0), dim=1)
            cluster = faiss.Kmeans(512, 5, niter=300, verbose=True, gpu=True)
            moving_avg_features = feats.numpy()
            cluster.train(moving_avg_features)
            _, labels_ = cluster.index.search(moving_avg_features, 1)
            labels = labels_ + 5
            target_label = labels.reshape(-1).tolist()

            # centers=faiss.vector_to_array(cluster.centroids).reshape(5, 512)
            centers = cluster.centroids

            # Memory bank by zkc
            # if epoch == 0: memorybank.features = torch.cat((F.normalize(torch.tensor(centers).cuda(), dim=1), feats), dim=0).cuda()
            # memorybank.labels = torch.cat((torch.arange(args.num_unlabeled_classes), torch.Tensor(target_label).long())).cuda()
            if epoch == 0: memorybank.features = feats_mb.cuda()
            memorybank.labels = torch.Tensor(
                labels_.reshape(-1).tolist()).long().cuda()

            model.memory.prototypes[args.num_labeled_classes:] = F.normalize(
                torch.tensor(centers).cuda(), dim=1)
            model_ema.memory.prototypes[args.num_labeled_classes:] = F.normalize(
                torch.tensor(centers).cuda(), dim=1)

            feats, _, labels = test(model_ema, labeled_eval_loader_train, args)
            feats = F.normalize(torch.cat(feats, dim=0), dim=1)
            centers = torch.zeros(args.num_labeled_classes, 512)
            for i in range(args.num_labeled_classes):
                idx = torch.where(torch.tensor(labels) == i)[0]
                centers[i] = torch.mean(feats[idx], 0)
            model.memory.prototypes[:args.num_labeled_classes] = centers.cuda()
            model_ema.memory.prototypes[:args.num_labeled_classes] = centers.cuda()

            unlabeled_train_loader = CIFAR10Loader_iter(
                root=args.dataset_root,
                batch_size=args.batch_size // 2,
                split='train',
                aug='twice',
                shuffle=True,
                target_list=range(args.num_labeled_classes, num_classes),
                new_labels=target_label)
            # model.head2.weight.data.copy_(
            #     torch.from_numpy(F.normalize(target_centers, axis=1)).float().cuda())

        # labeled_train_loader.new_epoch()
        # unlabeled_train_loader.new_epoch()
        # for batch_idx,_ in enumerate(range(iters)):
        #     ((x_l, x_bar_l), label_l, idx) = labeled_train_loader.next()
        #     ((x_u, x_bar_u), label_u, idx) = unlabeled_train_loader.next()
        for batch_idx, (((x_l, x_bar_l), label_l, idx_l),
                        ((x_u, x_bar_u), label_u, idx_u)) in enumerate(
                            zip(labeled_train_loader, unlabeled_train_loader)):

            x = torch.cat([x_l, x_u], dim=0)
            x_bar = torch.cat([x_bar_l, x_bar_u], dim=0)
            label = torch.cat([label_l, label_u], dim=0)
            idx = torch.cat([idx_l, idx_u], dim=0)
            x, x_bar, label = x.to(device), x_bar.to(device), label.to(device)

            output1, output2, feat, feat_mb = model(x)
            output1_bar, output2_bar, _, _ = model(x_bar)

            with torch.no_grad():
                output1_ema, output2_ema, feat_ema, feat_mb_ema = model_ema(x)
                output1_bar_ema, output2_bar_ema, _, _ = model_ema(x_bar)
            prob1 = F.softmax(output1, dim=1)
            prob1_bar = F.softmax(output1_bar, dim=1)
            prob2 = F.softmax(output2, dim=1)
            prob2_bar = F.softmax(output2_bar, dim=1)
            prob1_ema = F.softmax(output1_ema, dim=1)
            prob1_bar_ema = F.softmax(output1_bar_ema, dim=1)
            prob2_ema = F.softmax(output2_ema, dim=1)
            prob2_bar_ema = F.softmax(output2_bar_ema, dim=1)

            mask_lb = label < args.num_labeled_classes

            loss_ce_label = criterion1(output1[mask_lb], label[mask_lb])
            loss_ce_unlabel = criterion1(output2[~mask_lb], label[~mask_lb])

            # memory-bank instance loss is disabled here; the original call was:
            # memorybank(feat_mb[~mask_lb], feat_mb_ema[~mask_lb], label[~mask_lb], idx[~mask_lb])
            loss_in_unlabel = torch.tensor(0)

            loss_ce = loss_ce_label + loss_ce_unlabel

            loss_bce = rank_bce(criterion2, feat, mask_lb, prob2, prob2_bar)

            consistency_loss = F.mse_loss(prob1, prob1_bar) + F.mse_loss(
                prob2, prob2_bar)
            consistency_loss_ema = F.mse_loss(
                prob1, prob1_bar_ema) + F.mse_loss(prob2, prob2_bar_ema)

            loss = loss_ce + loss_bce + w * consistency_loss + w * consistency_loss_ema + loss_in_unlabel

            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            _update_ema_variables(model, model_ema, 0.99,
                                  epoch * iters + batch_idx)

            if batch_idx % 200 == 0:
                print(
                    'Train Epoch: {}, iter {}/{} unl-CE Loss: {:.4f}, unl-instance Loss: {:.4f}, l-CE Loss: {:.4f}, BCE Loss: {:.4f}, CL Loss: {:.4f}, Avg Loss: {:.4f}'
                    .format(epoch, batch_idx, 400, loss_ce_unlabel.item(),
                            loss_in_unlabel.item(), loss_ce_label.item(),
                            loss_bce.item(), consistency_loss.item(),
                            loss_record.avg))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(
            epoch, loss_record.avg))
        # print('test on labeled classes')
        # args.head = 'head1'
        # test(model, labeled_eval_loader_test, args)

        # print('test on unlabeled classes')
        args.head = 'head2'
        # test(model, unlabeled_eval_loader_train, args)
        test(model_ema, unlabeled_eval_loader_test, args)
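
The per-batch teacher update above calls _update_ema_variables, which is not shown. A minimal sketch, assuming the standard mean-teacher exponential moving average (alpha=0.99 in the call above); the repository's version may differ:

def _update_ema_variables(model, ema_model, alpha, global_step):
    # Assumed mean-teacher EMA: ramp the decay up from 0 so the teacher
    # tracks the student closely during the first steps.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
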
def main():
    global args, best_prec1, best_prec5, best_epoch

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.pretrained:
        log_string("=> using pre-trained model")
        # assumes res2net50 accepts a pretrained flag that loads ImageNet weights
        model = res2net50(pretrained=True)
    else:
        log_string("=> creating model")
        model = res2net50()

    if args.gpu is not None:
        model = model.cuda(args.gpu)
    elif args.distributed:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            #####
            gpus = "1,2,3,4"
            os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
            os.environ["CUDA_VISIBLE_DEVICES"] = gpus
            log_string("IN DataParallel mode.")
            #####
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = CrossEntropyLabelSmooth(num_classes=1000, epsilon=0.1)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            log_string("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            log_string("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            log_string("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1, prec5 = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
        if is_best:
            best_epoch = epoch + 1
            best_prec5 = prec5
        log_string(
            ' * BestPrec so far@1 {top1:.3f} @5 {top5:.3f} in epoch {best_epoch}'
            .format(top1=best_prec1, top5=best_prec5, best_epoch=best_epoch))
Example #6
def train_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    if not args.ssl == 0:
        netR = network.feat_classifier(type='linear',
                                       class_num=4,
                                       bottleneck_dim=2 *
                                       args.bottleneck).cuda()
        netR_dict, acc_rot = train_target_rot(args)
        netR.load_state_dict(netR_dict)

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_B.pt'
    netB.load_state_dict(torch.load(modelpath))
    modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath))
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    if not args.ssl == 0:
        for k, v in netR.named_parameters():
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        netR.train()

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target_u"])
    interval_iter = max_iter // args.interval
    iter_num = 0

    while iter_num < max_iter:
        optimizer.zero_grad()
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except (NameError, StopIteration):
            iter_test = iter(dset_loaders["target_u"])
            inputs_test, _, tar_idx = next(iter_test)

        try:
            inputs_target_l, labels_target_l, tar_l_idx = next(iter_target_l)
        except (NameError, StopIteration):
            iter_target_l = iter(dset_loaders["target_l"])
            inputs_target_l, labels_target_l, tar_l_idx = next(iter_target_l)

        if inputs_test.size(0) == 1 or inputs_target_l.size(0) == 1:
            continue

        if iter_num % interval_iter == 0 and args.cls_par > 0:
            netF.eval()
            netB.eval()
            mem_label = obtain_label(dset_loaders['test'], netF, netB, netC,
                                     args, dset_loaders['target_l'])
            mem_label = torch.from_numpy(mem_label).cuda()
            netF.train()
            netB.train()

        inputs_test = inputs_test.cuda()

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)
        if args.cls_par > 0:
            pred = mem_label[tar_idx]

        features_test = netB(netF(inputs_test))
        outputs_test = netC(features_test)

        if args.cls_par > 0:
            classifier_loss = nn.CrossEntropyLoss()(outputs_test, pred)
            classifier_loss *= args.cls_par
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        inputs_target_l = inputs_target_l.cuda()
        features_target_l = netB(netF(inputs_target_l))
        outputs_target_l = netC(features_target_l)
        classifier_loss += CrossEntropyLabelSmooth(num_classes=args.class_num,
                                                   epsilon=0.1)(
                                                       outputs_target_l,
                                                       labels_target_l.cuda())

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax *
                                          torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        classifier_loss.backward()

        if not args.ssl == 0:
            r_labels_target = np.random.randint(0, 4, len(inputs_test))
            r_inputs_target = rotation.rotate_batch_with_labels(
                inputs_test, r_labels_target)
            r_labels_target = torch.from_numpy(r_labels_target).cuda()
            r_inputs_target = r_inputs_target.cuda()

            f_outputs = netB(netF(inputs_test))
            f_outputs = f_outputs.detach()
            f_r_outputs = netB(netF(r_inputs_target))
            r_outputs_target = netR(torch.cat((f_outputs, f_r_outputs), 1))

            rotation_loss = args.ssl * nn.CrossEntropyLoss()(r_outputs_target,
                                                             r_labels_target)
            rotation_loss.backward()

        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netB, netC,
                                  False)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                args.name, iter_num, max_iter, acc_s_te)

            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netB.train()

    if args.issave:
        torch.save(
            netF.state_dict(),
            osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
        torch.save(
            netB.state_dict(),
            osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
        torch.save(
            netC.state_dict(),
            osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netB, netC
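
The information-maximization term above relies on loss.Entropy(softmax_out) from a project-local loss module that is not shown. Assuming the usual SHOT-style definition, it is simply the per-sample Shannon entropy of the softmax outputs; a minimal sketch:

import torch


def Entropy(probs, eps=1e-5):
    # probs: (batch, num_classes) softmax probabilities; returns the
    # per-sample entropy -sum_k p_k * log(p_k) (assumed implementation).
    return -(probs * torch.log(probs + eps)).sum(dim=1)
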
Example #7
            loss = nn.CrossEntropyLoss()(out, y)
            _, pred = torch.max(out.data, 1)
            total += y.size(0)
            correct += (pred == y).squeeze().sum().cpu().numpy()
            valid_loss += loss.item()
    valid_acc = correct / total
    valid_loss /= step
    logx.msg("valid_acc:" + str(valid_acc))
    return valid_acc, valid_loss


if __name__ == '__main__':

    my_model = models.resnet50(pretrained=False)
    my_model = my_model.to(device)
    criterion = CrossEntropyLabelSmooth(num_classes=2, epsilon=0.1)
    optimizer = optim.Adam(my_model.parameters())
    normMean = [0.5964188, 0.4566936, 0.3908954]
    normStd = [0.2590655, 0.2314241, 0.2269535]
    # train_transformer = transforms.Compose([
    #     transforms.Resize(225),
    #     transforms.CenterCrop(200),
    #     transforms.RandomHorizontalFlip(p=0.5),
    #     transforms.RandomRotation(degrees=(5, 10)),
    #     transforms.ToTensor(),
    #     transforms.Normalize(normMean, normStd)
    # ])
    train_transformer = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(normMean, normStd)])
    valid_transformer = transforms.Compose(
Example #8
def train_target(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.norm_layer == 'batchnorm':
        norm_layer = nn.BatchNorm2d
    elif args.norm_layer == 'groupnorm':

        def gn_helper(planes):
            return nn.GroupNorm(8, planes)

        norm_layer = gn_helper
    if args.net[0:3] == 'res':
        if '26' in args.net:
            netF = network.ResCifarBase(26, norm_layer=norm_layer)
            args.bottleneck = netF.in_features // 2
        else:
            netF = network.ResBase(res_name=args.net, args=args)
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net)

    # print(args.ssl_before_btn)
    if args.ssl_before_btn:
        netH = network.ssl_head(ssl_task=args.ssl_task,
                                feature_dim=netF.in_features,
                                embedding_dim=args.embedding_dim)
    else:
        netH = network.ssl_head(ssl_task=args.ssl_task,
                                feature_dim=args.bottleneck,
                                embedding_dim=args.embedding_dim)
    if args.bottleneck != 0:
        netB = network.feat_bootleneck(type=args.classifier,
                                       feature_dim=netF.in_features,
                                       bottleneck_dim=args.bottleneck,
                                       norm_btn=args.norm_btn)
        if args.reset_running_stats and args.classifier == 'bn':
            netB.norm.running_mean.fill_(0.)
            netB.norm.running_var.fill_(1.)

        if args.reset_bn_params and args.classifier == 'bn':
            netB.norm.weight.data.fill_(1.)
            netB.norm.bias.data.fill_(0.)
        netC = network.feat_classifier(type=args.layer,
                                       class_num=args.class_num,
                                       bottleneck_dim=args.bottleneck,
                                       bias=args.classifier_bias,
                                       temp=args.angular_temp,
                                       args=args)
    else:
        netB = nn.Identity()
        netC = network.feat_classifier(type=args.layer,
                                       class_num=args.class_num,
                                       bottleneck_dim=netF.in_features,
                                       bias=args.classifier_bias,
                                       temp=args.angular_temp,
                                       args=args)

    modelpath = args.output_dir_src + '/source_F.pt'
    netF.load_state_dict(torch.load(modelpath), strict=False)
    modelpath = args.output_dir_src + '/source_H.pt'
    netH.load_state_dict(torch.load(modelpath), strict=False)
    try:
        modelpath = args.output_dir_src + '/source_B.pt'
        netB.load_state_dict(torch.load(modelpath), strict=False)
    except:
        print('Skipped loading btn for version compatibility')
    modelpath = args.output_dir_src + '/source_C.pt'
    netC.load_state_dict(torch.load(modelpath), strict=False)
    netC.eval()
    for k, v in netC.named_parameters():
        v.requires_grad = False

    if args.dataparallel:
        netF = nn.DataParallel(netF).cuda()
        netH = nn.DataParallel(netH).cuda()
        netB = nn.DataParallel(netB).cuda()
        netC = nn.DataParallel(netC).cuda()
    else:
        netF.cuda()
        netH.cuda()
        netB.cuda()
        netC.cuda()

    param_group = []
    for k, v in netF.named_parameters():
        if args.lr_decay1 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay1}]
        else:
            v.requires_grad = False
    for k, v in netB.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False
    for k, v in netH.named_parameters():
        if args.lr_decay2 > 0:
            param_group += [{'params': v, 'lr': args.lr * args.lr_decay2}]
        else:
            v.requires_grad = False

    if args.ssl_task in ['simclr', 'crs']:
        ssl_loss_fn = NTXentLoss(args.batch_size, args.temperature,
                                 True).cuda()
    elif args.ssl_task in ['supcon', 'crsc']:
        ssl_loss_fn = SupConLoss(temperature=args.temperature,
                                 base_temperature=args.temperature).cuda()
    elif args.ssl_task == 'ls_supcon':
        ssl_loss_fn = LabelSmoothedSCLLoss(args.batch_size, args.temperature,
                                           args.class_num, args.ssl_smooth)

    if args.cr_weight > 0:
        if args.cr_metric == 'cos':
            dist = nn.CosineSimilarity(dim=1).cuda()
        elif args.cr_metric == 'l1':
            dist = nn.PairwiseDistance(p=1)
        elif args.cr_metric == 'l2':
            dist = nn.PairwiseDistance(p=2)
        elif args.cr_metric == 'bce':
            dist = nn.BCEWithLogitsLoss(reduction='sum').cuda()
        elif args.cr_metric == 'kl':
            dist = nn.KLDivLoss(reduction='sum').cuda()

    use_second_pass = (args.ssl_task in ['simclr', 'supcon', 'ls_supcon']) and (args.ssl_weight > 0)
    use_third_pass = (args.cr_weight > 0) or (args.ssl_task in ['crsc', 'crs'] and args.ssl_weight > 0) or args.cls3

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    max_iter = args.max_epoch * len(dset_loaders["target"])
    interval_iter = max_iter // args.interval
    iter_num = 0

    centroid = None

    while iter_num < max_iter:
        try:
            inputs_test, _, tar_idx = next(iter_test)
        except (NameError, StopIteration):
            iter_test = iter(dset_loaders["target"])
            inputs_test, _, tar_idx = next(iter_test)

        # inputs_test may be a tensor or a list of augmented views
        try:
            if inputs_test.size(0) == 1:
                continue
        except AttributeError:
            if inputs_test[0].size(0) == 1:
                continue

        if iter_num % interval_iter == 0 and (
                args.cls_par > 0
                or args.ssl_task in ['supcon', 'ls_supcon', 'crsc']):
            netF.eval()
            netH.eval()
            netB.eval()
            if centroid is None or args.recompute_centroid:
                mem_label, mem_conf, centroid, labelset = obtain_label(
                    dset_loaders['pl'], netF, netH, netB, netC, args)
                mem_label = torch.from_numpy(mem_label).cuda()
            else:
                pass

            netF.train()
            netH.train()
            netB.train()

        inputs_test1 = None
        inputs_test2 = None
        inputs_test3 = None

        pred = mem_label[tar_idx]

        if type(inputs_test) is list:
            inputs_test1 = inputs_test[0].cuda()
            inputs_test2 = inputs_test[1].cuda()
            if len(inputs_test) == 3:
                inputs_test3 = inputs_test[2].cuda()
        else:
            inputs_test1 = inputs_test.cuda()

        if args.layer in ['add_margin', 'arc_margin', 'sphere'] and args.use_margin_forward:
            labels_forward = pred
        else:
            labels_forward = None

        if inputs_test is not None:
            f1 = netF(inputs_test1)
            b1 = netB(f1)
            outputs_test = netC(b1, labels_forward)
        if use_second_pass:
            f2 = netF(inputs_test2)
            b2 = netB(f2)
        if use_third_pass:
            if args.sg3:
                with torch.no_grad():
                    f3 = netF(inputs_test3)
                    b3 = netB(f3)
                    c3 = netC(b3, labels_forward)
                    conf = torch.max(F.softmax(c3, dim=1), dim=1)[0]
            else:
                f3 = netF(inputs_test3)
                b3 = netB(f3)
                c3 = netC(b3, labels_forward)
                conf = torch.max(F.softmax(c3, dim=1), dim=1)[0]

        iter_num += 1
        lr_scheduler(args,
                     optimizer,
                     iter_num=iter_num,
                     max_iter=max_iter,
                     gamma=args.gamma,
                     power=args.power)

        pred = compute_pl(args, b3, centroid, labelset)

        if args.cr_weight > 0:
            if args.cr_site == 'feat':
                f_hard = f1
                f_weak = f3
            elif args.cr_site == 'btn':
                f_hard = b1
                f_weak = b3
            elif args.cr_site == 'cls':
                f_hard = outputs_test
                f_weak = c3
                if args.cr_metric != 'cos':
                    f_hard = F.softmax(f_hard, dim=-1)
                    f_weak = F.softmax(f_weak, dim=-1)
            else:
                raise NotImplementedError

        if args.cls_par > 0:
            # with torch.no_grad():
            #    conf, _ = torch.max(F.softmax(outputs_test, dim=-1), dim=-1)
            #    conf = conf.cpu().numpy()
            conf_cls = mem_conf[tar_idx]

            #pred = mem_label[tar_idx]
            if args.cls_smooth > 0:
                classifier_loss = CrossEntropyLabelSmooth(
                    num_classes=args.class_num, epsilon=args.cls_smooth)(
                        outputs_test[conf_cls >= args.conf_threshold],
                        pred[conf_cls >= args.conf_threshold])
            else:
                classifier_loss = nn.CrossEntropyLoss()(
                    outputs_test[conf_cls >= args.conf_threshold],
                    pred[conf_cls >= args.conf_threshold])
            if args.cls3:
                if args.cls_smooth > 0:
                    classifier_loss = CrossEntropyLabelSmooth(
                        num_classes=args.class_num, epsilon=args.cls_smooth)(
                            c3[conf_cls >= args.conf_threshold],
                            pred[conf_cls >= args.conf_threshold])
                else:
                    classifier_loss = nn.CrossEntropyLoss()(
                        c3[conf_cls >= args.conf_threshold],
                        pred[conf_cls >= args.conf_threshold])
            classifier_loss *= args.cls_par
            if iter_num < interval_iter and args.dset == "visda-c":
                classifier_loss *= 0
        else:
            classifier_loss = torch.tensor(0.0).cuda()

        if args.ent:
            softmax_out = nn.Softmax(dim=1)(outputs_test)
            entropy_loss = torch.mean(loss.Entropy(softmax_out))
            if args.gent:
                msoftmax = softmax_out.mean(dim=0)
                gentropy_loss = torch.sum(-msoftmax *
                                          torch.log(msoftmax + args.epsilon))
                entropy_loss -= gentropy_loss
            im_loss = entropy_loss * args.ent_par
            classifier_loss += im_loss

        if args.ssl_weight > 0:
            if args.ssl_before_btn:
                z1 = netH(f1, args.norm_feat)
                if use_second_pass:
                    z2 = netH(f2, args.norm_feat)
                if use_third_pass:
                    z3 = netH(f3, args.norm_feat)
            else:
                z1 = netH(b1, args.norm_feat)
                if use_second_pass:
                    z2 = netH(b2, args.norm_feat)
                if use_third_pass:
                    z3 = netH(b3, args.norm_feat)

            if args.ssl_task == 'simclr':
                ssl_loss = ssl_loss_fn(z1, z2)
            elif args.ssl_task == 'supcon':
                z = torch.cat([z1.unsqueeze(1), z2.unsqueeze(1)], dim=1)
                pl = mem_label[tar_idx]
                ssl_loss = ssl_loss_fn(z, pl)
            elif args.ssl_task == 'ls_supcon':
                pl = mem_label[tar_idx]
                ssl_loss = ssl_loss_fn(z1, z2, pl).squeeze()
            elif args.ssl_task == 'crsc':
                z = torch.cat([z1.unsqueeze(1), z3.unsqueeze(1)], dim=1)
                pl = mem_label[tar_idx]
                ssl_loss = ssl_loss_fn(z, pl)
            elif args.ssl_task == 'crs':
                ssl_loss = ssl_loss_fn(z1, z3)
            classifier_loss += args.ssl_weight * ssl_loss

        if args.cr_weight > 0:
            try:
                cr_loss = dist(f_hard[conf >= args.cr_threshold],
                               f_weak[conf >= args.cr_threshold]).mean()

                if args.cr_metric == 'cos':
                    cr_loss *= -1
            except:
                print('Error computing CR loss')
                cr_loss = torch.tensor(0.0).cuda()
            classifier_loss += args.cr_weight * cr_loss

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        centroid = update_centroid(args, b3, centroid, c3, labelset)

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netH.eval()
            netB.eval()
            if args.dset in ['visda-c', 'CIFAR-10-C', 'CIFAR-100-C']:
                acc_s_te, acc_list = cal_acc(dset_loaders['test'], netF, netH,
                                             netB, netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te) + '\n' + acc_list
            else:
                acc_s_te, _ = cal_acc(dset_loaders['test'], netF, netH, netB,
                                      netC, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name, iter_num, max_iter, acc_s_te)

            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')
            netF.train()
            netH.train()
            netB.train()

    if args.issave:
        if args.dataparallel:
            torch.save(
                netF.module.state_dict(),
                osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
            torch.save(
                netH.module.state_dict(),
                osp.join(args.output_dir, "target_H_" + args.savename + ".pt"))
            torch.save(
                netB.module.state_dict(),
                osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
            torch.save(
                netC.module.state_dict(),
                osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))
        else:
            torch.save(
                netF.state_dict(),
                osp.join(args.output_dir, "target_F_" + args.savename + ".pt"))
            torch.save(
                netH.state_dict(),
                osp.join(args.output_dir, "target_H_" + args.savename + ".pt"))
            torch.save(
                netB.state_dict(),
                osp.join(args.output_dir, "target_B_" + args.savename + ".pt"))
            torch.save(
                netC.state_dict(),
                osp.join(args.output_dir, "target_C_" + args.savename + ".pt"))

    return netF, netH, netB, netC
Example #9
    alpha = 1.5747007053351507
    l = 0.5241677630566622
    margin = 0.040520629258433416
    beta_ratio = 0.7103921571238655
    gamma = 0.368667605025003
    weight_decay = 0.014055481861393148

args.arch = "ResNet50ta_bt"
model = models.init_model(name=args.arch,
                          num_classes=dataset.num_train_pids,
                          loss={'xent', 'htri'})
print("Model size: {:.5f}M".format(
    sum(p.numel() for p in model.parameters()) / 1000000.0))

criterion_htri = TripletLoss(margin, 'cosine')
criterion_xent = CrossEntropyLabelSmooth(dataset.num_train_pids)
criterion_center_loss = CenterLoss(use_gpu=use_gpu)

if args.use_OSMCAA:
    print("USING OSM LOSS")
    print("config, alpha = %f  sigma = %f  l=%f" % (alpha, sigma, l))
    criterion_osm_caa = OSM_CAA_Loss(alpha=alpha, l=l, osm_sigma=sigma)
else:
    criterion_osm_caa = None

if args.cl_centers:
    print("USING CL CENTERS")
    print("config, alpha = %f  sigma = %f  l=%f" % (alpha, sigma, l))
    criterion_osm_caa = OSM_CAA_Loss(alpha=alpha, l=l, osm_sigma=sigma)

base_learning_rate = 0.00035
Example #10
def train_source(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.norm_layer == 'batchnorm':
        norm_layer = nn.BatchNorm2d
    elif args.norm_layer == 'groupnorm':

        def gn_helper(planes):
            return nn.GroupNorm(8, planes)

        norm_layer = gn_helper
    if args.net[0:3] == 'res':
        if '26' in args.net:
            netF = network.ResCifarBase(26, norm_layer=norm_layer)
            args.bottleneck = netF.in_features // 2
        else:
            netF = network.ResBase(res_name=args.net, args=args)
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net)

    if args.ssl_before_btn:
        netH = network.ssl_head(ssl_task=args.ssl_task,
                                feature_dim=netF.in_features,
                                embedding_dim=args.embedding_dim)
    else:
        netH = network.ssl_head(ssl_task=args.ssl_task,
                                feature_dim=args.bottleneck,
                                embedding_dim=args.embedding_dim)
    if args.bottleneck != 0:
        netB = network.feat_bootleneck(type=args.classifier,
                                       feature_dim=netF.in_features,
                                       bottleneck_dim=args.bottleneck,
                                       norm_btn=args.norm_btn)
        netC = network.feat_classifier(type=args.layer,
                                       class_num=args.class_num,
                                       bottleneck_dim=args.bottleneck,
                                       bias=args.classifier_bias,
                                       temp=args.angular_temp,
                                       args=args)
    else:
        netB = nn.Identity()
        netC = network.feat_classifier(type=args.layer,
                                       class_num=args.class_num,
                                       bottleneck_dim=netF.in_features,
                                       bias=args.classifier_bias,
                                       temp=args.angular_temp,
                                       args=args)

    if args.dataparallel:
        netF = nn.DataParallel(netF).cuda()
        netH = nn.DataParallel(netH).cuda()
        netB = nn.DataParallel(netB).cuda()
        netC = nn.DataParallel(netC).cuda()
    else:
        netF.cuda()
        netH.cuda()
        netB.cuda()
        netC.cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        if args.separate_wd and ('bias' in k or 'norm' in k):
            param_group += [{
                'params': v,
                'lr': learning_rate * 0.1,
                'weight_decay': 0
            }]
        else:
            param_group += [{
                'params': v,
                'lr': learning_rate * 0.1,
                'weight_decay': args.weight_decay
            }]
    for k, v in netH.named_parameters():
        if args.separate_wd and ('bias' in k or 'norm' in k):
            param_group += [{
                'params': v,
                'lr': learning_rate,
                'weight_decay': 0
            }]
        else:
            param_group += [{
                'params': v,
                'lr': learning_rate,
                'weight_decay': args.weight_decay
            }]
    for k, v in netB.named_parameters():
        if args.separate_wd and ('bias' in k or 'norm' in k):
            param_group += [{
                'params': v,
                'lr': learning_rate,
                'weight_decay': 0
            }]
        else:
            param_group += [{
                'params': v,
                'lr': learning_rate,
                'weight_decay': args.weight_decay
            }]
    for k, v in netC.named_parameters():
        if args.separate_wd and ('bias' in k or 'norm' in k):
            param_group += [{
                'params': v,
                'lr': learning_rate,
                'weight_decay': 0
            }]
        else:
            param_group += [{
                'params': v,
                'lr': learning_rate,
                'weight_decay': args.weight_decay
            }]

    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0
    if args.class_stratified:
        max_iter = args.max_epoch * len(
            dset_loaders["source_tr"].batch_sampler)
    else:
        max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    interval_iter = max_iter // 10
    iter_num = 0
    epoch = 0

    netF.train()
    netH.train()
    netB.train()
    netC.train()

    if args.use_focal_loss:
        cls_loss_fn = FocalLoss(alpha=args.focal_alpha,
                                gamma=args.focal_gamma,
                                reduction='mean')
    else:
        if args.ce_weighting:
            w = torch.Tensor(args.ce_weight).cuda()
            w.requires_grad = False
            if args.smooth == 0:
                cls_loss_fn = nn.CrossEntropyLoss(weight=w).cuda()
            else:
                cls_loss_fn = CrossEntropyLabelSmooth(
                    num_classes=args.class_num, epsilon=args.smooth,
                    weight=w).cuda()
        else:
            if args.smooth == 0:
                cls_loss_fn = nn.CrossEntropyLoss().cuda()
            else:
                cls_loss_fn = CrossEntropyLabelSmooth(
                    num_classes=args.class_num, epsilon=args.smooth).cuda()

    if args.ssl_task in ['simclr', 'crs']:
        if args.use_new_ntxent:
            ssl_loss_fn = SupConLoss(temperature=args.temperature,
                                     base_temperature=args.temperature).cuda()
        else:
            ssl_loss_fn = NTXentLoss(args.batch_size, args.temperature,
                                     True).cuda()
    elif args.ssl_task in ['supcon', 'crsc']:
        ssl_loss_fn = SupConLoss(temperature=args.temperature,
                                 base_temperature=args.temperature).cuda()
    elif args.ssl_task == 'ls_supcon':
        ssl_loss_fn = LabelSmoothedSCLLoss(args.batch_size, args.temperature,
                                           args.class_num, args.ssl_smooth)

    if args.cr_weight > 0:
        if args.cr_metric == 'cos':
            dist = nn.CosineSimilarity(dim=1).cuda()
        elif args.cr_metric == 'l1':
            dist = nn.PairwiseDistance(p=1).cuda()
        elif args.cr_metric == 'l2':
            dist = nn.PairwiseDistance(p=2).cuda()
        elif args.cr_metric == 'bce':
            dist = nn.BCEWithLogitsLoss(reduction='sum').cuda()
        elif args.cr_metric == 'kl':
            dist = nn.KLDivLoss(reduction='sum').cuda()
        elif args.cr_metric == 'js':
            dist = JSDivLoss(reduction='sum').cuda()

    use_second_pass = (args.ssl_task in ['simclr', 'supcon', 'ls_supcon']
                       and args.ssl_weight > 0)
    use_third_pass = (args.cr_weight > 0
                      or (args.ssl_task in ['crsc', 'crs'] and args.ssl_weight > 0)
                      or args.cls3)
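    # These flags decide which extra forward passes the loop below runs: the
    # first view always feeds netF -> netB -> netC for the classification
    # loss, the second view is only needed by the contrastive SSL objectives,
    # and the third ("weak", cf. f_weak below) view serves consistency
    # regularization, the 'crs'/'crsc' variants, or an extra classification
    # term when cls3 is set; its softmax confidence masks the CR loss.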

    while iter_num < max_iter:
        try:
            inputs_source, labels_source = next(iter_source)
        except:
            iter_source = iter(dset_loaders["source_tr"])
            if args.class_stratified:
                dset_loaders["source_tr"].batch_sampler.set_epoch(epoch)
            epoch += 1
            inputs_source, labels_source = next(iter_source)

        try:
            if inputs_source.size(0) == 1:
                continue
        except:
            if inputs_source[0].size(0) == 1:
                continue

        iter_num += 1
        lr_scheduler(args, optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source1 = None
        inputs_source2 = None
        inputs_source3 = None
        labels_source = labels_source.cuda()

        if args.layer in ['add_margin', 'arc_margin', 'shpere']:
            labels_forward = labels_source
        else:
            labels_forward = None

        if type(inputs_source) is list:
            inputs_source1 = inputs_source[0].cuda()
            inputs_source2 = inputs_source[1].cuda()
            if len(inputs_source) == 3:
                inputs_source3 = inputs_source[2].cuda()
        else:
            inputs_source1 = inputs_source.cuda()

        if inputs_source1 is not None:
            f1 = netF(inputs_source1)
            b1 = netB(f1)
            outputs_source = netC(b1, labels_forward)
        if use_second_pass:
            f2 = netF(inputs_source2)
            b2 = netB(f2)
        if use_third_pass:
            if args.sg3:
                with torch.no_grad():
                    f3 = netF(inputs_source3)
                    b3 = netB(f3)
                    c3 = netC(b3, labels_forward)
                    conf = torch.max(F.softmax(c3, dim=1), dim=1)[0]
            else:
                f3 = netF(inputs_source3)
                b3 = netB(f3)
                c3 = netC(b3, labels_forward)
                conf = torch.max(F.softmax(c3, dim=1), dim=1)[0]

        if args.cr_weight > 0:
            if args.cr_site == 'feat':
                f_hard = f1
                f_weak = f3
            elif args.cr_site == 'btn':
                f_hard = b1
                f_weak = b3
            elif args.cr_site == 'cls':
                f_hard = outputs_source
                f_weak = c3
                if args.cr_metric in ['kl', 'js']:
                    f_hard = F.softmax(f_hard, dim=-1)
                if args.cr_metric in ['bce', 'kl', 'js']:
                    f_weak = F.softmax(f_weak, dim=-1)
            else:
                raise NotImplementedError

        classifier_loss = cls_loss_fn(outputs_source, labels_source)
        if args.cls3:
            classifier_loss += cls_loss_fn(c3, labels_source)

        if args.ssl_weight > 0:
            if args.ssl_before_btn:
                z1 = netH(f1, args.norm_feat)
                if use_second_pass:
                    z2 = netH(f2, args.norm_feat)
                if use_third_pass:
                    z3 = netH(f3, args.norm_feat)
            else:
                z1 = netH(b1, args.norm_feat)
                if use_second_pass:
                    z2 = netH(b2, args.norm_feat)
                if use_third_pass:
                    z3 = netH(b3, args.norm_feat)

            if args.ssl_task == 'simclr':
                if args.use_new_ntxent:
                    z = torch.cat([z1.unsqueeze(1), z2.unsqueeze(1)], dim=1)
                    ssl_loss = ssl_loss_fn(z)
                else:
                    ssl_loss = ssl_loss_fn(z1, z2)
            elif args.ssl_task == 'supcon':
                z = torch.cat([z1.unsqueeze(1), z2.unsqueeze(1)], dim=1)
                ssl_loss = ssl_loss_fn(z, labels=labels_source)
            elif args.ssl_task == 'ls_supcon':
                ssl_loss = ssl_loss_fn(z1, z2, labels_source)
            elif args.ssl_task == 'crsc':
                z = torch.cat([z1.unsqueeze(1), z3.unsqueeze(1)], dim=1)
                ssl_loss = ssl_loss_fn(z, labels_source)
            elif args.ssl_task == 'crs':
                if args.use_new_ntxent:
                    z = torch.cat([z1.unsqueeze(1), z3.unsqueeze(1)], dim=1)
                    ssl_loss = ssl_loss_fn(z)
                else:
                    ssl_loss = ssl_loss_fn(z1, z3)
        else:
            ssl_loss = torch.tensor(0.0).cuda()

        if args.cr_weight > 0:
            try:
                cr_loss = dist(f_hard[conf <= args.cr_threshold],
                               f_weak[conf <= args.cr_threshold]).mean()

                if args.cr_metric == 'cos':
                    cr_loss *= -1
            except:
                print('Error computing CR loss')
                cr_loss = torch.tensor(0.0).cuda()
        else:
            cr_loss = torch.tensor(0.0).cuda()

        if args.ent_weight > 0:
            softmax_out = nn.Softmax(dim=1)(outputs_source)
            entropy_loss = torch.mean(Entropy(softmax_out))
            classifier_loss += args.ent_weight * entropy_loss

        if args.gent_weight > 0:
            softmax_out = nn.Softmax(dim=1)(outputs_source)
            msoftmax = softmax_out.mean(dim=0)
            gentropy_loss = torch.sum(-msoftmax *
                                      torch.log(msoftmax + args.epsilon))
            classifier_loss -= args.gent_weight * gentropy_loss

        loss = classifier_loss + args.ssl_weight * ssl_loss + args.cr_weight * cr_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netH.eval()
            netB.eval()
            netC.eval()
            if args.dset == 'visda-c':
                acc_s_te, acc_list = cal_acc(dset_loaders['source_te'], netF,
                                             netH, netB, netC, args, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name_src, iter_num, max_iter,
                    acc_s_te) + '\n' + acc_list
            else:
                acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netH,
                                      netB, netC, args, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name_src, iter_num, max_iter, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if acc_s_te >= acc_init:
                acc_init = acc_s_te

                if args.dataparallel:
                    best_netF = netF.module.state_dict()
                    best_netH = netH.module.state_dict()
                    best_netB = netB.module.state_dict()
                    best_netC = netC.module.state_dict()
                else:
                    best_netF = netF.state_dict()
                    best_netH = netH.state_dict()
                    best_netB = netB.state_dict()
                    best_netC = netC.state_dict()

            netF.train()
            netH.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
    torch.save(best_netH, osp.join(args.output_dir_src, "source_H.pt"))
    torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
    torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))

    return netF, netH, netB, netC
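The example above leans on a few helpers that it does not define: op_copy and lr_scheduler for the optimizer, and Entropy for the 'ent' regularizer. Below is a minimal sketch of those helpers, assuming the SHOT-style conventions their call sites suggest; signatures, constants, and the extra args parameter of this example's lr_scheduler are assumptions, not the original code.

import torch


def op_copy(optimizer):
    # Cache each param group's initial learning rate so the scheduler can
    # rescale it on every iteration.
    for param_group in optimizer.param_groups:
        param_group['lr0'] = param_group['lr']
    return optimizer


def lr_scheduler(optimizer, iter_num, max_iter, gamma=10, power=0.75):
    # Polynomial decay: lr = lr0 * (1 + gamma * iter_num / max_iter) ** -power.
    # (The variant above also receives args first; per-group weight decay is
    # assumed to stay untouched so the separate_wd groups keep their setting.)
    decay = (1 + gamma * iter_num / max_iter) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr0'] * decay
    return optimizer


def Entropy(input_):
    # Per-sample entropy of softmax probabilities, summed over classes.
    epsilon = 1e-5
    return -(input_ * torch.log(input_ + epsilon)).sum(dim=1)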
Example #11
0
def train_source_simp(args):
    dset_loaders = data_load(args)
    if args.net_src[0:3] == 'res':
        netF = network.ResBase(res_name=args.net_src).cuda()
    netC = network.feat_classifier_simpl(class_num=args.class_num,
                                         feat_dim=netF.in_features).cuda()

    param_group = []
    learning_rate = args.lr_src
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 0.1}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0
    max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    interval_iter = max_iter // 10
    iter_num = 0

    netF.train()
    netC.train()

    while iter_num < max_iter:
        try:
            inputs_source, labels_source = next(iter_source)
        except:
            iter_source = iter(dset_loaders["source_tr"])
            inputs_source, labels_source = next(iter_source)

        if inputs_source.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source, labels_source = inputs_source.cuda(
        ), labels_source.cuda()
        outputs_source = netC(netF(inputs_source))
        classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num,
                                                  epsilon=0.1)(outputs_source,
                                                               labels_source)

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netC.eval()
            acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, None, netC,
                                  False)
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                args.name_src, iter_num, max_iter, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if acc_s_te >= acc_init:
                acc_init = acc_s_te
                best_netF = netF.state_dict()
                best_netC = netC.state_dict()

            netF.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
    torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))

    return netF, netC
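Every example here trains the source model with CrossEntropyLabelSmooth. The following is a minimal sketch of a label-smoothed cross entropy in the standard formulation, with an optional per-class weight as used in Example #10; it is an assumption about the class, not its original code.

import torch
import torch.nn as nn


class CrossEntropyLabelSmooth(nn.Module):
    # Cross entropy against uniformly smoothed targets:
    # q = (1 - epsilon) * one_hot(y) + epsilon / num_classes.
    def __init__(self, num_classes, epsilon=0.1, weight=None):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.weight = weight
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        soft_targets = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        loss = -soft_targets * log_probs
        if self.weight is not None:
            loss = loss * self.weight.unsqueeze(0)
        return loss.sum(dim=1).mean()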
Example #12
0
def train_source(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res' or args.net[0:3] == 'vgg':
        if args.net[0:3] == 'res':
            netF = network.ResBase(res_name=args.net).cuda()
        else:
            netF = network.VGGBase(vgg_name=args.net).cuda()
        netB = network.feat_bootleneck(type=args.classifier, feature_dim=netF.in_features,
                                       bottleneck_dim=args.bottleneck).cuda()   # classifier: bn
        netC = network.feat_classifier(type=args.layer, class_num=args.class_num,
                                       bottleneck_dim=args.bottleneck).cuda()   # layer: wn

        if args.resume:
            args.modelpath = args.output_dir_src + '/source_F.pt'
            netF.load_state_dict(torch.load(args.modelpath))
            args.modelpath = args.output_dir_src + '/source_B.pt'
            netB.load_state_dict(torch.load(args.modelpath))
            args.modelpath = args.output_dir_src + '/source_C.pt'
            netC.load_state_dict(torch.load(args.modelpath))

        param_group = []
        learning_rate = args.lr
        for k, v in netF.named_parameters():
            param_group += [{'params': v, 'lr': learning_rate * 0.1}]
        for k, v in netB.named_parameters():
            param_group += [{'params': v, 'lr': learning_rate}]
        for k, v in netC.named_parameters():
            param_group += [{'params': v, 'lr': learning_rate}]
        optimizer = optim.SGD(param_group)
        optimizer = op_copy(optimizer)

    acc_init = 0.
    max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    print_loss_interval = 25
    interval_iter = 100
    iter_num = 0

    if args.net[0:3] == 'res' or args.net[0:3] == 'vgg':
        netF.train()
        netB.train()
        netC.train()

    while iter_num < max_iter:
        try:
            inputs_source, labels_source = next(iter_source)
        except:
            iter_source = iter(dset_loaders["source_tr"])
            inputs_source, labels_source = next(iter_source)

        if inputs_source.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source, labels_source = inputs_source.cuda(), labels_source.cuda()
        outputs_source = netC(netB(netF(inputs_source)))
        classifier_loss = CrossEntropyLabelSmooth(num_classes=args.class_num, epsilon=args.smooth)(outputs_source,
                                                                                                   labels_source)
        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % print_loss_interval == 0:
            print("Iter:{:>4d}/{} | Classification loss on Source: {:.2f}".format(iter_num, max_iter,
                                                                              classifier_loss.item()))
        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            netC.eval()
            if args.dset == 'VISDA-RSUT' or args.dset == 'VISDA-RSUT-50' or args.dset == 'VISDA-RSUT-10':
                # The small classes in VisDA-C (RSUT) still have relatively many samples.
                # Safe to use per-class average accuracy.
                acc_s_te, acc_list, acc_cls_avg_te = cal_acc(dset_loaders['source_te'], netF, netB, netC,
                                                             per_class_flag=True, visda_flag=True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}'.format(args.name_src, iter_num, max_iter,
                                                                acc_s_te, acc_cls_avg_te) + '\n' + acc_list
                cur_acc = acc_cls_avg_te

            else:
                if args.trte == 'stratified':
                    # Stratified cross validation ensures the existence of every class in the validation set.
                    # Safe to use per-class average accuracy.
                    acc_s_te, acc_cls_avg_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC,
                                                          per_class_flag=True, visda_flag=False)
                    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%, Cls Avg Acc = {:.2f}'.format(args.name_src,
                                                                iter_num, max_iter, acc_s_te, acc_cls_avg_te)
                    cur_acc = acc_cls_avg_te
                else:
                    # Conventional cross validation may lead to the absence of certain classes in validation set,
                    # esp. when the dataset includes some very small classes, e.g., Office-Home (RSUT), DomainNet.
                    # Use overall accuracy to avoid 'nan' issue.
                    acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB, netC,
                                          per_class_flag=False, visda_flag=False)
                    log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(args.name_src, iter_num, max_iter, acc_s_te)
                    cur_acc = acc_s_te

            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if cur_acc >= acc_init and iter_num >= 3 * len(dset_loaders["source_tr"]):
                # first 3 epochs: not stable yet
                acc_init = cur_acc
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()

            netF.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
    torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
    torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))

    return netF, netB, netC
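The comments in this example explain when per-class average accuracy is safe to report and why it can become nan otherwise. A minimal sketch of the distinction follows (a hypothetical helper, not the repository's cal_acc).

import torch


def overall_and_class_average_accuracy(all_outputs, all_labels):
    # Overall accuracy vs. mean of per-class recalls. Averaging over a fixed
    # list of classes would yield nan for classes absent from the validation
    # split, which is the failure mode described above; this sketch averages
    # only over classes that are actually present.
    preds = all_outputs.argmax(dim=1)
    overall = (preds == all_labels).float().mean() * 100
    per_class = [(preds[all_labels == c] == c).float().mean()
                 for c in torch.unique(all_labels)]
    return overall.item(), (torch.stack(per_class).mean() * 100).item()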
Example #13
0
def train_source(args):
    dset_loaders = data_load(args)
    ## set base network
    if args.net[0:3] == 'res':
        netF = network.ResBase(res_name=args.net).cuda()
    elif args.net[0:3] == 'vgg':
        netF = network.VGGBase(vgg_name=args.net).cuda()

    netB = network.feat_bootleneck(type=args.classifier,
                                   feature_dim=netF.in_features,
                                   bottleneck_dim=args.bottleneck).cuda()
    netC = network.feat_classifier(type=args.layer,
                                   class_num=args.class_num,
                                   bottleneck_dim=args.bottleneck).cuda()

    param_group = []
    learning_rate = args.lr
    for k, v in netF.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate * 0.1}]
    for k, v in netB.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    for k, v in netC.named_parameters():
        param_group += [{'params': v, 'lr': learning_rate}]
    optimizer = optim.SGD(param_group)
    optimizer = op_copy(optimizer)

    acc_init = 0
    max_iter = args.max_epoch * len(dset_loaders["source_tr"])
    interval_iter = max_iter // 10
    iter_num = 0

    netF.train()
    netB.train()
    netC.train()

    # wandb watching
    wandb.watch(netF)
    wandb.watch(netB)
    wandb.watch(netC)

    while iter_num < max_iter:
        try:
            inputs_source, labels_source = next(iter_source)
        except:
            iter_source = iter(dset_loaders["source_tr"])
            inputs_source, labels_source = next(iter_source)

        if inputs_source.size(0) == 1:
            continue

        iter_num += 1
        lr_scheduler(optimizer, iter_num=iter_num, max_iter=max_iter)

        inputs_source, labels_source = inputs_source.cuda(
        ), labels_source.cuda()
        outputs_source = netC(netB(netF(inputs_source)))
        classifier_loss = CrossEntropyLabelSmooth(
            num_classes=args.class_num, epsilon=args.smooth)(outputs_source,
                                                             labels_source)

        optimizer.zero_grad()
        classifier_loss.backward()
        optimizer.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            netF.eval()
            netB.eval()
            netC.eval()
            if args.dset == 'VISDA18' or args.dset == 'VISDA-C':
                acc_s_te, acc_list = cal_acc(dset_loaders['source_te'], netF,
                                             netB, netC, True)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name_src, iter_num, max_iter,
                    acc_s_te) + '\n' + acc_list
                wandb.log({"accuracy": acc_s_te})
            else:
                acc_s_te, _ = cal_acc(dset_loaders['source_te'], netF, netB,
                                      netC, False)
                log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%'.format(
                    args.name_src, iter_num, max_iter, acc_s_te)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str + '\n')

            if acc_s_te >= acc_init:
                acc_init = acc_s_te
                best_netF = netF.state_dict()
                best_netB = netB.state_dict()
                best_netC = netC.state_dict()

            netF.train()
            netB.train()
            netC.train()

    torch.save(best_netF, osp.join(args.output_dir_src, "source_F.pt"))
    torch.save(best_netB, osp.join(args.output_dir_src, "source_B.pt"))
    torch.save(best_netC, osp.join(args.output_dir_src, "source_C.pt"))

    return netF, netB, netC
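Example #13 is Example #12 plus Weights & Biases instrumentation (wandb.watch on the three networks and wandb.log for the VisDA accuracy). A minimal, self-contained sketch of that logging API is shown below; the project name and logged values are placeholders.

import wandb

# Offline mode so the sketch runs without an account or API key.
run = wandb.init(project="source-pretraining", mode="offline")
for iter_num in (100, 200, 300):
    run.log({"accuracy": 85.0, "iter": iter_num})
run.finish()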
Example #14
0
train = process_dir(txt_path='/data/zhoumi/datasets/train_data/train.txt')

if use_triplet:
    train_data = DataLoader(train_datasets, sampler=RandomIdentitySampler_new(train, NUM_CLASSES, 4),
                            batch_size=BATCH_SIZE, pin_memory=True, num_workers=8, drop_last=True)
else:
    train_data = DataLoader(train_datasets, batch_size=BATCH_SIZE, pin_memory=True, shuffle=True)

test_data = DataLoader(MyDataset(txt_path='/data/zhoumi/datasets/train_data/val.txt', transform=test_transform),
                       batch_size=TEST_BATCH_SIZE, pin_memory=True)

optimizer = optim.Adam(params=model.parameters(), lr=1e-4, weight_decay=5e-4)
# optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)

# define loss function
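# The label-smoothed cross entropy supervises the identity classifier,
# CenterLoss pulls each 1792-d embedding towards its class center, and
# TripletLoss (margin 0.3) compares positive/negative pairs, which is
# presumably why the RandomIdentitySampler_new loader above draws several
# images per identity when use_triplet is enabled.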
xent_criterion = CrossEntropyLabelSmooth(NUM_CLASSES)
center_criterion = CenterLoss(NUM_CLASSES, feat_dim=1792)
triplet_criterion = TripletLoss(margin=0.3)

best_model = model
best_acc = 0
print(len(test_data) * TEST_BATCH_SIZE, len(train_data))

model = model.cuda()

for epoch in range(MAX_EPOC):
    lr = adjust_lr(epoch)
    for p in optimizer.param_groups:
        p['lr'] = lr

    for i, inputs in enumerate(train_data):