Example 1: reference torchvision detection training loop with optional distributed training, checkpoint resume, and per-epoch evaluation
def main(args):
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Data loading code
    print("Loading data")

    dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True), args.data_path)
    dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False), args.data_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(
            train_sampler, args.batch_size, drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=args.workers,
        collate_fn=utils.collate_fn)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1,
        sampler=test_sampler, num_workers=args.workers,
        collate_fn=utils.collate_fn)

    print("Creating model")
    model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes,
                                                              pretrained=args.pretrained)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        evaluate(model, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
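            # set_epoch re-seeds the DistributedSampler so each epoch draws a
            # new shuffle that stays consistent across processes.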
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
        lr_scheduler.step()
        if args.output_dir:
            utils.save_on_master({
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'args': args,
                'epoch': epoch},
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))

        # evaluate after every epoch
        evaluate(model, data_loader_test, device=device)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
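Every DataLoader in these examples passes utils.collate_fn, but the helper itself is never shown. For detection loaders it is conventionally the torchvision reference implementation; the following is a minimal sketch under that assumption, not the repository's confirmed code.

def collate_fn(batch):
    # Detection images come in varying sizes, so instead of stacking
    # tensors we return the images and targets as two parallel tuples.
    return tuple(zip(*batch))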
Example 2: SSM-style active learning with two-stage sample selection (uncertainty filtering plus image cross-validation)
def main(args):
    torch.cuda.set_device(0)
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    print(args)

    device = torch.device(args.device)

    # Data loading code
    print("Loading data")
    if 'voc2007' in args.dataset:
        dataset, num_classes = get_dataset(args.dataset, "trainval", get_transform(train=True), args.data_path)
        dataset_test, _ = get_dataset(args.dataset, "test", get_transform(train=False), args.data_path)
    else:
        dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True), args.data_path)
        dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False), args.data_path)
    if 'voc' in args.dataset:
        init_num = 500
        budget_num = 500
        if 'retina' in args.model:
            init_num = 1000
            budget_num = 500
    else:
        init_num = 5000
        budget_num = 1000
    print("Creating data loaders")
    num_images = len(dataset)
    indices = list(range(num_images))
    random.shuffle(indices)
    labeled_set = indices[:init_num]
    unlabeled_set = list(set(indices) - set(labeled_set))
    train_sampler = SubsetRandomSampler(labeled_set)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = DataLoader(dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers,
                                  collate_fn=utils.collate_fn)

    # SSM parameters
    gamma = 0.15
    clslambda = np.array([-np.log(0.9)] * (num_classes - 1))
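    # One threshold per foreground class (background excluded), initialized
    # to -log(0.9): the loss of a prediction made with 90% confidence.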
    # Start the active-learning cycles
    for cycle in range(args.cycles):
        if args.aspect_ratio_group_factor >= 0:
            group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
            train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
        else:
            train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)

        data_loader = torch.utils.data.DataLoader(dataset, batch_sampler=train_batch_sampler, num_workers=args.workers,
                                                  collate_fn=utils.collate_fn)

        print("Creating model")
        if 'voc' in args.dataset:
            if 'faster' in args.model:
                task_model = fasterrcnn_resnet50_fpn_ssm(num_classes=num_classes, min_size=600, max_size=1000)
            elif 'retina' in args.model:
                task_model = retinanet_resnet50_fpn_ssm(num_classes=num_classes, min_size=600, max_size=1000)
        else:
            if 'faster' in args.model:
                task_model = fasterrcnn_resnet50_fpn_ssm(num_classes=num_classes, min_size=800, max_size=1333)
            elif 'retina' in args.model:
                task_model = retinanet_resnet50_fpn_ssm(num_classes=num_classes, min_size=800, max_size=1333)
        task_model.to(device)

        if not args.init and cycle == 0 and args.skip:
            if 'faster' in args.model:
                checkpoint = torch.load(os.path.join(args.first_checkpoint_path,
                                                     '{}_frcnn_1st.pth'.format(args.dataset)), map_location='cpu')
            elif 'retina' in args.model:
                checkpoint = torch.load(os.path.join(args.first_checkpoint_path,
                                                     '{}_retinanet_1st.pth'.format(args.dataset)), map_location='cpu')
            task_model.load_state_dict(checkpoint['model'])
            if args.test_only:
                if 'coco' in args.dataset:
                    coco_evaluate(task_model, data_loader_test)
                elif 'voc' in args.dataset:
                    # task_model.ssm_mode(False)
                    voc_evaluate(task_model, data_loader_test, args.dataset, path=args.results_path)
                return
            # if 'coco' in args.dataset:
            #     coco_evaluate(task_model, data_loader_test)
            # elif 'voc' in args.dataset:
            #     voc_evaluate(task_model, data_loader_test, args.dataset)
            print("Getting stability")
            random.shuffle(unlabeled_set)
            if 'coco' in args.dataset:
                subset = unlabeled_set[:10000]
            else:
                subset = unlabeled_set
            unlabeled_loader = DataLoader(dataset, batch_size=1, sampler=SubsetSequentialSampler(subset),
                                          num_workers=args.workers,
                                          # keep subset's order so outputs map back to its indices
                                          pin_memory=True, collate_fn=utils.collate_fn)
            print("Getting detections from unlabeled set")
            allScore, allBox, allY, al_idx = get_uncertainty(task_model, unlabeled_loader)
            al_idx = [subset[i] for i in al_idx]
            # al_idx = subset[:budget_num]
            cls_sum = 0
            cls_loss_sum = np.zeros((num_classes - 1,))
            print("First stage results: unlabeled set: {}, to-be-labeled set: {}".format(len(subset), len(al_idx)))
            if len(al_idx) >= budget_num:
                al_idx = al_idx[:budget_num]
                labeled_set += al_idx
                subset = list(set(subset) - set(al_idx))
                print("Labeled set size: {}".format(len(set(labeled_set))))
                print("First stage results: unlabeled set: {}, to-be-labeled set: {}".format(len(subset), len(al_idx)))
                # Create a new dataloader for the updated labeled dataset
                train_sampler = SubsetRandomSampler(labeled_set)
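                # EMA-style update: pull each class threshold toward
                # -log(softmax(mean per-class loss)); gamma anneals toward 1.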
                clslambda = 0.9 * clslambda - 0.1 * np.log(softmax(cls_loss_sum / (cls_sum + 1e-30)))
                gamma = min(gamma + 0.05, 1)
                unlabeled_set = list(set(unlabeled_set) - set(al_idx))
                continue
            subset = list(set(subset) - set(al_idx))
            print("Image cross validation")
            for i in range(len(subset)):
                if len(al_idx) >= budget_num:
                    break
                cls_sum += len(allBox[i])
                for j, box in enumerate(allBox[i]):
                    if len(al_idx) >= budget_num:
                        break
                    score = allScore[i][j]
                    label = torch.tensor(allY[i][j]).cuda()
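                    # Cross-entropy written for labels y in {-1, +1}:
                    # (1+y)/2 selects -log(p) for positives and (1-y)/2
                    # selects -log(1-p) for negatives; 1e-30 guards log(0).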
                    y = label.cpu().numpy()
                    p = score.cpu().numpy()
                    loss = -((1 + y) / 2 * np.log(p) + (1 - y) / 2 * np.log(1 - p + 1e-30))
                    cls_loss_sum += loss
                    v, v_val = judge_uv(loss, gamma, clslambda)
                    if v:
                        # print(label)
                        if torch.sum(label == 1) == 1 and torch.where(label == 1)[0] != 0:
                            # run image cross-validation on this box
                            pre_cls = torch.where(label == 1)[0]
                            pre_box = box
                            curr_ind = [subset[i]]
                            curr_sampler = SubsetSequentialSampler(curr_ind)
                            curr_loader = DataLoader(dataset, batch_size=1, sampler=curr_sampler,
                                                     num_workers=args.workers, pin_memory=True,
                                                     collate_fn=utils.collate_fn)
                            labeled_sampler = SubsetRandomSampler(labeled_set)
                            labeled_loader = DataLoader(dataset, batch_size=1, sampler=labeled_sampler,
                                                        num_workers=args.workers, pin_memory=True,
                                                        collate_fn=utils.collate_fn)
                            cross_validate, _ = image_cross_validation(
                                task_model, curr_loader, labeled_loader, pre_box, pre_cls)
                            if not cross_validate:
                                al_idx.append(subset[i])
                                break
                        else:
                            continue
                    else:
                        al_idx.append(subset[i])
                        break
            # Update the labeled dataset and the unlabeled dataset, respectively
            print("Second stage results: unlabeled set: {}, to-be-labeled set: {}".format(len(subset), len(set(al_idx))))
            subset = list(set(subset) - set(al_idx))
            if len(al_idx) > budget_num:
                al_idx = al_idx[:budget_num]
            if len(al_idx) < budget_num:
                al_idx += list(set(subset) - set(al_idx))[:budget_num - len(al_idx)]
            labeled_set += al_idx
            unlabeled_set = list(set(unlabeled_set) - set(al_idx))
            print("Second stage results: unlabeled set: {}, to-be-labeled set: {}".format(len(subset), len(set(al_idx))))
            # Create a new dataloader for the updated labeled dataset
            train_sampler = SubsetRandomSampler(labeled_set)
            clslambda = 0.9 * clslambda - 0.1 * np.log(softmax(cls_loss_sum / (cls_sum + 1e-30)))
            gamma = min(gamma + 0.05, 1)
            continue

        params = [p for p in task_model.parameters() if p.requires_grad]
        task_optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
        task_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(task_optimizer, milestones=args.lr_steps,
                                                                 gamma=args.lr_gamma)

        print("Start training")
        start_time = time.time()
        for epoch in range(args.start_epoch, args.total_epochs):
            train_one_epoch(task_model, task_optimizer, data_loader, device, cycle, epoch, args.print_freq)
            task_lr_scheduler.step()
            # evaluate after the final epoch
            if (epoch + 1) == args.total_epochs:
                if 'coco' in args.dataset:
                    coco_evaluate(task_model, data_loader_test)
                elif 'voc' in args.dataset:
                    task_model.ssm_mode(False)
                    voc_evaluate(task_model, data_loader_test, args.dataset, path=args.results_path)
        # if not args.skip and cycle == 0:
        #     utils.save_on_master({
        #         'model': task_model.state_dict(), 'args': args},
        #         os.path.join(args.first_checkpoint_path, '{}_frcnn_1st.pth'.format(args.dataset)))
        random.shuffle(unlabeled_set)
        if 'coco' in args.dataset:
            subset = unlabeled_set[:10000]
        else:
            subset = unlabeled_set
        unlabeled_loader = DataLoader(dataset, batch_size=1, sampler=SubsetSequentialSampler(subset),
                                      num_workers=args.workers,
                                      # keep subset's order so outputs map back to its indices
                                      pin_memory=True, collate_fn=utils.collate_fn)
        print("Getting detections from unlabeled set")
        allScore, allBox, allY, al_idx = get_uncertainty(task_model, unlabeled_loader)
        al_idx = [subset[i] for i in al_idx]
        cls_sum = 0
        cls_loss_sum = np.zeros((num_classes - 1,))
        print("First stage results: unlabeled set: {}, to-be-labeled set: {}".format(len(subset), len(set(al_idx))))
        if len(al_idx) >= budget_num:
            al_idx = al_idx[:budget_num]
            labeled_set += al_idx
            print("Labeled set size: {}".format(len(set(labeled_set))))
            subset = list(set(subset) - set(al_idx))
            print("First stage results: unlabeled set: {}, to-be-labeled set: {}".format(len(subset), len(set(al_idx))))
            # Create a new dataloader for the updated labeled dataset
            train_sampler = SubsetRandomSampler(labeled_set)
            clslambda = 0.9 * clslambda - 0.1 * np.log(softmax(cls_loss_sum / (cls_sum + 1e-30)))
            gamma = min(gamma + 0.05, 1)
            unlabeled_set = list(set(unlabeled_set) - set(al_idx))
            continue
        subset = list(set(subset) - set(al_idx))
        print("Image cross validation")
        for i in range(len(subset)):
            if len(al_idx) >= budget_num:
                break
            cls_sum += len(allBox[i])
            for j, box in enumerate(allBox[i]):
                if len(al_idx) >= budget_num:
                    break
                score = allScore[i][j]
                label = torch.tensor(allY[i][j]).cuda()
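                # Same {-1, +1}-label cross-entropy as in the first-stage loop.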
                y = label.cpu().numpy()
                p = score.cpu().numpy()
                loss = -((1 + y) / 2 * np.log(p) + (1 - y) / 2 * np.log(1 - p + 1e-30))
                cls_loss_sum += loss
                v, v_val = judge_uv(loss, gamma, clslambda)
                if v:
                    if torch.sum(label == 1) == 1 and torch.where(label == 1)[0] != 0:
                        # run image cross-validation on this box
                        pre_cls = torch.where(label == 1)[0]
                        pre_box = box
                        curr_ind = [subset[i]]
                        curr_sampler = SubsetSequentialSampler(curr_ind)
                        curr_loader = DataLoader(dataset, batch_size=1, sampler=curr_sampler,
                                                 num_workers=args.workers, pin_memory=True, collate_fn=utils.collate_fn)
                        labeled_sampler = SubsetRandomSampler(labeled_set)
                        labeled_loader = DataLoader(dataset, batch_size=1, sampler=labeled_sampler,
                                                    num_workers=args.workers, pin_memory=True,
                                                    collate_fn=utils.collate_fn)
                        cross_validate, _ = image_cross_validation(
                            task_model, curr_loader, labeled_loader, pre_box, pre_cls)
                        if not cross_validate:
                            al_idx.append(subset[i])
                            break
                else:
                    al_idx.append(subset[i])
                    break
        # Update the labeled dataset and the unlabeled dataset, respectively
        print("Second stage results: unlabeled set: {}, tobe labeled set: {}".format(len(subset), len(set(al_idx))))
        subset = list(set(subset) - set(al_idx))
        if len(al_idx) > budget_num:
            al_idx = al_idx[:budget_num]
        if len(al_idx) < budget_num:
            al_idx += list(set(subset) - set(al_idx))[:budget_num - len(al_idx)]
        labeled_set += al_idx
        unlabeled_set = list(set(unlabeled_set) - set(al_idx))
        print("Second stage results: unlabeled set: {}, tobe labeled set: {}".format(len(subset), len(set(al_idx))))
        # Create a new dataloader for the updated labeled dataset
        train_sampler = SubsetRandomSampler(labeled_set)
        clslambda = 0.9 * clslambda - 0.1 * np.log(softmax(cls_loss_sum / (cls_sum + 1e-30)))
        gamma = min(gamma + 0.05, 1)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))
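Examples 2, 3, 5, and 7 iterate unlabeled pools with SubsetSequentialSampler, which torch.utils.data does not provide. Below is a minimal sketch consistent with its use here (a fixed index list visited in order), assuming the repository's own helper behaves the same way.

import torch.utils.data

class SubsetSequentialSampler(torch.utils.data.Sampler):
    # Visit a fixed list of dataset indices in their given order, so that
    # per-image outputs can be matched back to entries of subset.
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)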
Example 3: active learning with an auxiliary LossNet used to rank unlabeled images
def main(args):
    torch.cuda.set_device(0)
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    print(args)

    device = torch.device(args.device)

    # Data loading code
    print("Loading data")

    if 'voc2007' in args.dataset:
        dataset, num_classes = get_dataset(args.dataset, "trainval",
                                           get_transform(train=True),
                                           args.data_path)
        dataset_test, _ = get_dataset(args.dataset, "test",
                                      get_transform(train=False),
                                      args.data_path)
    else:
        dataset, num_classes = get_dataset(args.dataset, "train",
                                           get_transform(train=True),
                                           args.data_path)
        dataset_test, _ = get_dataset(args.dataset, "val",
                                      get_transform(train=False),
                                      args.data_path)
    if 'voc' in args.dataset:
        init_num = 500
        budget_num = 500
        if 'retina' in args.model:
            init_num = 1000
            budget_num = 500
    else:
        init_num = 5000
        budget_num = 1000
    print("Creating data loaders")
    num_images = len(dataset)
    indices = list(range(num_images))
    random.shuffle(indices)
    labeled_set = indices[:init_num]
    unlabeled_set = list(set(indices) - set(labeled_set))
    train_sampler = SubsetRandomSampler(labeled_set)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = DataLoader(dataset_test,
                                  batch_size=1,
                                  sampler=test_sampler,
                                  num_workers=args.workers,
                                  collate_fn=utils.collate_fn)
    for cycle in range(args.cycles):
        if args.aspect_ratio_group_factor >= 0:
            group_ids = create_aspect_ratio_groups(
                dataset, k=args.aspect_ratio_group_factor)
            train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids,
                                                      args.batch_size)
        else:
            train_batch_sampler = torch.utils.data.BatchSampler(
                train_sampler, args.batch_size, drop_last=True)

        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=train_batch_sampler,
            num_workers=args.workers,
            collate_fn=utils.collate_fn)

        print("Creating model")
        if 'voc' in args.dataset:
            if 'faster' in args.model:
                task_model = fasterrcnn_resnet50_fpn_feature(
                    num_classes=num_classes, min_size=600, max_size=1000)
            elif 'retina' in args.model:
                task_model = retinanet_resnet50_fpn(num_classes=num_classes,
                                                    min_size=600,
                                                    max_size=1000)
        else:
            if 'faster' in args.model:
                task_model = fasterrcnn_resnet50_fpn_feature(
                    num_classes=num_classes, min_size=800, max_size=1333)
            elif 'retina' in args.model:
                task_model = retinanet_resnet50_fpn(num_classes=num_classes,
                                                    min_size=800,
                                                    max_size=1333)
        task_model.to(device)

        params = [p for p in task_model.parameters() if p.requires_grad]
        task_optimizer = torch.optim.SGD(params,
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)
        task_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            task_optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
        ll_model = lossnet.LossNet()
        ll_model.to(device)
        params_ll = [p for p in ll_model.parameters() if p.requires_grad]
        ll_optimizer = torch.optim.SGD(params_ll,
                                       lr=args.lr,
                                       momentum=args.momentum,
                                       weight_decay=args.weight_decay)
        ll_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            ll_optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
        # Start the active-learning cycles
        if args.test_only:
            if 'coco' in args.dataset:
                coco_evaluate(task_model, data_loader_test, feature=True)
            elif 'voc' in args.dataset:
                voc_evaluate(task_model, data_loader_test, args.dataset, True)
            return
        print("Start training")
        start_time = time.time()
        for epoch in range(args.start_epoch, args.total_epochs):
            train_one_epoch(task_model, task_optimizer, ll_model, ll_optimizer,
                            data_loader, device, cycle, epoch, args.print_freq)
            task_lr_scheduler.step()
            ll_lr_scheduler.step()
            # evaluate after the final epoch
            if (epoch + 1) == args.total_epochs:
                if 'coco' in args.dataset:
                    coco_evaluate(task_model, data_loader_test, feature=True)
                elif 'voc' in args.dataset:
                    voc_evaluate(task_model,
                                 data_loader_test,
                                 args.dataset,
                                 True,
                                 path=args.results_path)
        random.shuffle(unlabeled_set)
        if 'coco' in args.dataset:
            subset = unlabeled_set[:10000]
        else:
            subset = unlabeled_set
        unlabeled_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            sampler=SubsetSequentialSampler(subset),
            num_workers=args.workers,
            # keep subset's order so outputs map back to its indices
            pin_memory=True,
            collate_fn=utils.collate_fn)
        uncertainty = get_uncertainty(task_model, ll_model, unlabeled_loader)
        labeled_loader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            sampler=SubsetSequentialSampler(labeled_set),
            num_workers=args.workers,
            # keep the labeled set's order so outputs map back to its indices
            pin_memory=True,
            collate_fn=utils.collate_fn)
        u = get_uncertainty(task_model, ll_model, labeled_loader)
        # with open("vis/ll_labeled_metric_{}_{}_{}.pkl".format(args.model, args.dataset, cycle),
        #           "wb") as fp:  # Pickling
        #     pickle.dump(u, fp)
        arg = np.argsort(uncertainty)
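        # np.argsort is ascending, so the last budget_num entries of arg
        # point at the most uncertain images in subset.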
        # with open("vis/ll_unlabeled_metric_{}_{}_{}.pkl".format(args.model, args.dataset, cycle),
        #           "wb") as fp:  # Pickling
        #     pickle.dump(torch.tensor(uncertainty)[arg][-1 * budget_num:].numpy(), fp)
        # Update the labeled dataset and the unlabeled dataset, respectively
        labeled_set += list(
            torch.tensor(subset)[arg][-1 * budget_num:].numpy())
        labeled_set = list(set(labeled_set))
        # with open("vis/ll_{}_{}_{}.txt".format(args.model, args.dataset, cycle), "wb") as fp:  # Pickling
        #     pickle.dump(labeled_set, fp)
        unlabeled_set = list(set(indices) - set(labeled_set))

        # Create a new dataloader for the updated labeled dataset
        train_sampler = SubsetRandomSampler(labeled_set)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))
Example 4: active-learning baseline that adds randomly chosen samples each cycle
def main(args):
    torch.cuda.set_device(0)
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Data loading code
    print("Loading data")

    if 'voc2007' in args.dataset:
        dataset, num_classes = get_dataset(args.dataset, "trainval",
                                           get_transform(train=True),
                                           args.data_path)
        dataset_test, _ = get_dataset(args.dataset, "test",
                                      get_transform(train=False),
                                      args.data_path)
    else:
        dataset, num_classes = get_dataset(args.dataset, "train",
                                           get_transform(train=True),
                                           args.data_path)
        dataset_test, _ = get_dataset(args.dataset, "val",
                                      get_transform(train=False),
                                      args.data_path)
    print("Creating data loaders")
    num_images = len(dataset)
    if 'voc' in args.dataset:
        init_num = 1000
        budget_num = 1000
        if 'retina' in args.model:
            init_num = 1000
            budget_num = 500
    else:
        init_num = 5000
        budget_num = 1000
    indices = list(range(num_images))
    random.shuffle(indices)
    labeled_set = indices[:init_num]
    unlabeled_set = indices[init_num:]
    train_sampler = SubsetRandomSampler(labeled_set)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = DataLoader(dataset_test,
                                  batch_size=1,
                                  sampler=test_sampler,
                                  num_workers=args.workers,
                                  collate_fn=utils.collate_fn)
    for cycle in range(args.cycles):
        if args.aspect_ratio_group_factor >= 0:
            group_ids = create_aspect_ratio_groups(
                dataset, k=args.aspect_ratio_group_factor)
            train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids,
                                                      args.batch_size)
        else:
            train_batch_sampler = torch.utils.data.BatchSampler(
                train_sampler, args.batch_size, drop_last=True)

        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=train_batch_sampler,
            num_workers=args.workers,
            collate_fn=utils.collate_fn)

        print("Creating model")
        if 'voc' in args.dataset:
            if 'faster' in args.model:
                task_model = fasterrcnn_resnet50_fpn(num_classes=num_classes,
                                                     min_size=600,
                                                     max_size=1000)
            elif 'retina' in args.model:
                task_model = retinanet_resnet50_fpn(num_classes=num_classes,
                                                    min_size=600,
                                                    max_size=1000)
        else:
            if 'faster' in args.model:
                task_model = fasterrcnn_resnet50_fpn(num_classes=num_classes,
                                                     min_size=800,
                                                     max_size=1333)
            elif 'retina' in args.model:
                task_model = retinanet_resnet50_fpn(num_classes=num_classes,
                                                    min_size=800,
                                                    max_size=1333)
        task_model.to(device)
        if not args.init and cycle == 0 and args.skip:
            if 'faster' in args.model:
                checkpoint = torch.load(os.path.join(
                    args.first_checkpoint_path,
                    '{}_frcnn_1st.pth'.format(args.dataset)),
                                        map_location='cpu')
            elif 'retina' in args.model:
                checkpoint = torch.load(os.path.join(
                    args.first_checkpoint_path,
                    '{}_retinanet_1st.pth'.format(args.dataset)),
                                        map_location='cpu')
            task_model.load_state_dict(checkpoint['model'])
            # if 'coco' in args.dataset:
            #     coco_evaluate(task_model, data_loader_test)
            # elif 'voc' in args.dataset:
            #     voc_evaluate(task_model, data_loader_test, args.dataset)
            print("Getting stability")
            random.shuffle(unlabeled_set)
            if 'coco' in args.dataset:
                subset = unlabeled_set[:5000]
            else:
                subset = unlabeled_set
            # Update the labeled dataset and the unlabeled dataset, respectively
            labeled_set += subset[:budget_num]
            labeled_set = list(set(labeled_set))
            # with open("vis/cycle_{}.txt".format(cycle), "rb") as fp:  # Unpickling
            #     labeled_set = pickle.load(fp)
            unlabeled_set = list(set(indices) - set(labeled_set))

            # Create a new dataloader for the updated labeled dataset
            train_sampler = SubsetRandomSampler(labeled_set)
            continue
        params = [p for p in task_model.parameters() if p.requires_grad]
        task_optimizer = torch.optim.SGD(params,
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)
        task_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            task_optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
        # Start the active-learning cycles
        if args.test_only:
            if 'coco' in args.dataset:
                coco_evaluate(task_model, data_loader_test)
            elif 'voc' in args.dataset:
                voc_evaluate(task_model, data_loader_test, args.dataset)
            return
        print("Start training")
        start_time = time.time()
        for epoch in range(args.start_epoch, args.total_epochs):
            train_one_epoch(task_model, task_optimizer, data_loader, device,
                            cycle, epoch, args.print_freq)
            task_lr_scheduler.step()
            # evaluate after the final epoch
            if (epoch + 1) == args.total_epochs:
                if 'coco' in args.dataset:
                    coco_evaluate(task_model, data_loader_test)
                elif 'voc' in args.dataset:
                    voc_evaluate(task_model,
                                 data_loader_test,
                                 args.dataset,
                                 path=args.results_path)
        if not args.skip and cycle == 0:
            if 'faster' in args.model:
                utils.save_on_master(
                    {
                        'model': task_model.state_dict(),
                        'args': args
                    },
                    os.path.join(args.first_checkpoint_path,
                                 '{}_frcnn_1st.pth'.format(args.dataset)))
            elif 'retina' in args.model:
                utils.save_on_master(
                    {
                        'model': task_model.state_dict(),
                        'args': args
                    },
                    os.path.join(args.first_checkpoint_path,
                                 '{}_retinanet_1st.pth'.format(args.dataset)))
        random.shuffle(unlabeled_set)
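        # No acquisition model in this example: taking the first budget_num
        # indices of the shuffled pool is uniform random sampling.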
        # Update the labeled dataset and the unlabeled dataset, respectively
        labeled_set += unlabeled_set[:budget_num]
        labeled_set = list(set(labeled_set))
        unlabeled_set = unlabeled_set[budget_num:]
        # Create a new dataloader for the updated labeled dataset
        train_sampler = SubsetRandomSampler(labeled_set)
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))
Example 5: active learning with a VAE and a discriminator trained alongside the detector to select samples
def main(args):
    torch.cuda.set_device(0)
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    print(args)

    device = torch.device(args.device)

    # Data loading code
    print("Loading data")
    if 'voc2007' in args.dataset:
        dataset, num_classes = get_dataset(args.dataset, "trainval",
                                           get_transform(train=True),
                                           args.data_path)
        dataset_test, _ = get_dataset(args.dataset, "test",
                                      get_transform(train=False),
                                      args.data_path)
    else:
        dataset, num_classes = get_dataset(args.dataset, "train",
                                           get_transform(train=True),
                                           args.data_path)
        dataset_test, _ = get_dataset(args.dataset, "val",
                                      get_transform(train=False),
                                      args.data_path)
    if 'voc' in args.dataset:
        init_num = 500
        budget_num = 500
        if 'retina' in args.model:
            init_num = 1000
            budget_num = 500
    else:
        init_num = 5000
        budget_num = 1000
    print("Creating data loaders")
    num_images = len(dataset)
    indices = list(range(num_images))
    random.shuffle(indices)
    labeled_set = indices[:init_num]
    unlabeled_set = list(set(indices) - set(labeled_set))
    train_sampler = SubsetRandomSampler(labeled_set)
    unlabeled_sampler = SubsetRandomSampler(unlabeled_set)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = DataLoader(dataset_test,
                                  batch_size=1,
                                  sampler=test_sampler,
                                  num_workers=args.workers,
                                  collate_fn=utils.collate_fn)
    for cycle in range(args.cycles):
        if args.aspect_ratio_group_factor >= 0:
            group_ids = create_aspect_ratio_groups(
                dataset, k=args.aspect_ratio_group_factor)
            train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids,
                                                      args.batch_size)
            unlabeled_batch_sampler = GroupedBatchSampler(
                unlabeled_sampler, group_ids, args.batch_size)
        else:
            train_batch_sampler = torch.utils.data.BatchSampler(
                train_sampler, args.batch_size, drop_last=True)
            unlabeled_batch_sampler = torch.utils.data.BatchSampler(
                unlabeled_sampler, args.batch_size, drop_last=True)
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=train_batch_sampler,
            num_workers=args.workers,
            collate_fn=utils.collate_fn)
        unlabeled_dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=unlabeled_batch_sampler,
            num_workers=args.workers,
            collate_fn=utils.collate_fn)
        print("Creating model")
        if 'voc' in args.dataset:
            if 'faster' in args.model:
                task_model = fasterrcnn_resnet50_fpn(num_classes=num_classes,
                                                     min_size=600,
                                                     max_size=1000)
            elif 'retina' in args.model:
                task_model = retinanet_resnet50_fpn(num_classes=num_classes,
                                                    min_size=600,
                                                    max_size=1000)
        else:
            if 'faster' in args.model:
                task_model = fasterrcnn_resnet50_fpn(num_classes=num_classes,
                                                     min_size=800,
                                                     max_size=1333)
            elif 'retina' in args.model:
                task_model = retinanet_resnet50_fpn(num_classes=num_classes,
                                                    min_size=800,
                                                    max_size=1333)
        task_model.to(device)

        params = [p for p in task_model.parameters() if p.requires_grad]
        task_optimizer = torch.optim.SGD(params,
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)
        task_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            task_optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
        vae = VAE()
        params = [p for p in vae.parameters() if p.requires_grad]
        vae_optimizer = torch.optim.SGD(params,
                                        lr=args.lr / 10,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
        vae_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            vae_optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
        torch.nn.utils.clip_grad_value_(vae.parameters(), 1e5)
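        # Note: clip_grad_value_ clips whatever is stored in .grad at call
        # time; invoked once here, before any backward pass, it does not
        # bound gradients during the training loop.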

        vae.to(device)
        discriminator = Discriminator()
        params = [p for p in discriminator.parameters() if p.requires_grad]
        discriminator_optimizer = torch.optim.SGD(
            params,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay)
        discriminator_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            discriminator_optimizer,
            milestones=args.lr_steps,
            gamma=args.lr_gamma)
        discriminator.to(device)
        # Start the active-learning cycles
        if args.test_only:
            if 'coco' in args.dataset:
                coco_evaluate(task_model, data_loader_test)
            elif 'voc' in args.dataset:
                voc_evaluate(task_model,
                             data_loader_test,
                             args.dataset,
                             False,
                             path=args.results_path)
            return
        print("Start training")
        start_time = time.time()
        for epoch in range(args.start_epoch, args.total_epochs):
            train_one_epoch(task_model, task_optimizer, vae, vae_optimizer,
                            discriminator, discriminator_optimizer,
                            data_loader, unlabeled_dataloader, device, cycle,
                            epoch, args.print_freq)
            task_lr_scheduler.step()
            vae_lr_scheduler.step()
            discriminator_lr_scheduler.step()
            # evaluate after the final epoch
            if (epoch + 1) == args.total_epochs:
                if 'coco' in args.dataset:
                    coco_evaluate(task_model, data_loader_test)
                elif 'voc' in args.dataset:
                    voc_evaluate(task_model,
                                 data_loader_test,
                                 args.dataset,
                                 False,
                                 path=args.results_path)
        # Update the labeled dataset and the unlabeled dataset, respectively
        random.shuffle(unlabeled_set)
        if 'coco' in args.dataset:
            subset = unlabeled_set[:10000]
        else:
            subset = unlabeled_set
        unlabeled_loader = DataLoader(dataset,
                                      batch_size=1,
                                      sampler=SubsetSequentialSampler(subset),
                                      num_workers=args.workers,
                                      pin_memory=True,
                                      collate_fn=utils.collate_fn)
        tobe_labeled_inds = sample_for_labeling(vae, discriminator,
                                                unlabeled_loader, budget_num)
        tobe_labeled_set = [subset[i] for i in tobe_labeled_inds]
        labeled_set += tobe_labeled_set
        unlabeled_set = list(set(unlabeled_set) - set(tobe_labeled_set))
        # Create a new dataloader for the updated labeled dataset
        train_sampler = SubsetRandomSampler(labeled_set)
        unlabeled_sampler = SubsetRandomSampler(unlabeled_set)
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))
Example 6: fine-tuning a pretrained Faster R-CNN on Visual Genome detection
def main(args):
    utils.init_distributed_mode(args)
    print(args)
    device = torch.device(args.device)
    print("Loading data")
    # pdb.set_trace()
    transform = build_transforms(cfg, is_train=True)
    transform_test = build_transforms(cfg, is_train=False)

    train_data = VisualGenomeDataset(args.data_dir,
                                     task='detection',
                                     split='train',
                                     transforms=transform)
    test_data = VisualGenomeDataset(args.data_dir,
                                    task='detection',
                                    split='test',
                                    transforms=transform_test)
    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            test_data)
    else:
        train_sampler = torch.utils.data.RandomSampler(train_data)
        test_sampler = torch.utils.data.SequentialSampler(test_data)

    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(
            train_data, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids,
                                                  args.batch_size)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler,
                                                            args.batch_size,
                                                            drop_last=True)

    train_data_loader = torch.utils.data.DataLoader(
        train_data,
        batch_sampler=train_batch_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn)
    test_data_loader = torch.utils.data.DataLoader(test_data,
                                                   batch_size=1,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   collate_fn=utils.collate_fn)

    print("Creating model")
    model = fasterrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features,
                                                      cfg.NUM_CLASSES)
    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params,
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
    # lr_scheduler = torch.optim.lr_scheduler.StepLR(
    #     optimizer, step_size=8, gamma=0.5)
    last_epoch = 0
    if args.resume:
        print("from checkpoint*************")
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        last_epoch = lr_scheduler.last_epoch
    if args.test_only:
        evaluate(model, test_data_loader, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(last_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, optimizer, train_data_loader, device, epoch,
                        args.print_freq)
        # lr_scheduler.step()
        # if args.output_dir:
        #     utils.save_on_master({
        #         'model': model_without_ddp.state_dict(),
        #         'optimizer': optimizer.state_dict(),
        #         'lr_scheduler': lr_scheduler.state_dict(),
        #         'args': args},
        #         os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
        # evaluate(model, test_data_loader, device=device)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Example 7: active learning driven by uncertainty estimated across data augmentations
def main(args):
    torch.cuda.set_device(0)
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    print(args)

    device = torch.device(args.device)

    # Data loading code
    print("Loading data")

    if 'voc2007' in args.dataset:
        dataset, num_classes = get_dataset(args.dataset, "trainval", get_transform(train=True), args.data_path)
        dataset_aug, _ = get_dataset(args.dataset, "trainval", None, args.data_path)
        dataset_test, _ = get_dataset(args.dataset, "test", get_transform(train=False), args.data_path)
    else:
        dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True), args.data_path)
        dataset_aug, _ = get_dataset(args.dataset, "train", None, args.data_path)
        dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False), args.data_path)

    print("Creating data loaders")
    num_images = len(dataset)
    if 'voc' in args.dataset:
        init_num = 500
        budget_num = 500
        if 'retina' in args.model:
            init_num = 1000
            budget_num = 500
    else:
        init_num = 5000
        budget_num = 1000
    indices = list(range(num_images))
    random.shuffle(indices)
    labeled_set = indices[:init_num]
    unlabeled_set = list(set(indices) - set(labeled_set))
    train_sampler = SubsetRandomSampler(labeled_set)
    data_loader_test = DataLoader(dataset_test, batch_size=1, sampler=SequentialSampler(dataset_test),
                                  num_workers=args.workers, collate_fn=utils.collate_fn)
    augs = []
    if 'F' in args.augs:
        augs.append('flip')
    if 'C' in args.augs:
        augs.append('cut_out')
    if 'D' in args.augs:
        augs.append('smaller_resize')
    if 'R' in args.augs:
        augs.append('rotation')
    if 'G' in args.augs:
        augs.append('ga')
    if 'S' in args.augs:
        augs.append('sp')
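    # e.g. args.augs == "FCR" selects the flip, cut_out, and rotation augmentations.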
    for cycle in range(args.cycles):
        if args.aspect_ratio_group_factor >= 0:
            group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
            train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
        else:
            train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)

        data_loader = torch.utils.data.DataLoader(dataset, batch_sampler=train_batch_sampler, num_workers=args.workers,
                                                  collate_fn=utils.collate_fn)

        print("Creating model")
        if 'voc' in args.dataset:
            if 'faster' in args.model:
                task_model = fasterrcnn_resnet50_fpn_feature(num_classes=num_classes, min_size=600, max_size=1000)
            elif 'retina' in args.model:
                task_model = retinanet_resnet50_fpn_cal(num_classes=num_classes, min_size=600, max_size=1000)
        else:
            if 'faster' in args.model:
                task_model = fasterrcnn_resnet50_fpn_feature(num_classes=num_classes, min_size=800, max_size=1333)
            elif 'retina' in args.model:
                task_model = retinanet_resnet50_fpn_cal(num_classes=num_classes, min_size=800, max_size=1333)
        task_model.to(device)
        if cycle == 0 and args.skip:
            if 'faster' in args.model:
                checkpoint = torch.load(os.path.join(args.first_checkpoint_path,
                                                     '{}_frcnn_1st.pth'.format(args.dataset)), map_location='cpu')
            elif 'retina' in args.model:
                checkpoint = torch.load(os.path.join(args.first_checkpoint_path,
                                                     '{}_retinanet_1st.pth'.format(args.dataset)), map_location='cpu')
            task_model.load_state_dict(checkpoint['model'])
            if args.test_only:
                if 'coco' in args.dataset:
                    coco_evaluate(task_model, data_loader_test)
                elif 'voc' in args.dataset:
                    voc_evaluate(task_model, data_loader_test, args.dataset, False, path=args.results_path)
                return
            print("Getting stability")
            random.shuffle(unlabeled_set)
            if 'coco' in args.dataset:
                subset = unlabeled_set[:10000]
            else:
                subset = unlabeled_set
            if not args.no_mutual:
                unlabeled_loader = DataLoader(dataset_aug, batch_size=1, sampler=SubsetSequentialSampler(subset),
                                              num_workers=args.workers, pin_memory=True, collate_fn=utils.collate_fn)
                uncertainty, _cls_corrs = get_uncertainty(task_model, unlabeled_loader, augs, num_classes)
                arg = np.argsort(np.array(uncertainty))
                cls_corrs_set = arg[:int(args.mr * budget_num)]
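                # argsort is ascending: cls_corrs_set holds the mr * budget_num
                # images with the smallest scores from get_uncertainty.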
                cls_corrs = [_cls_corrs[i] for i in cls_corrs_set]
                labeled_loader = DataLoader(dataset_aug, batch_size=1, sampler=SubsetSequentialSampler(labeled_set),
                                            num_workers=args.workers, pin_memory=True, collate_fn=utils.collate_fn)
                tobe_labeled_set = cls_kldiv(labeled_loader, cls_corrs, budget_num, cycle)
                # Update the labeled dataset and the unlabeled dataset, respectively
                tobe_labeled_set = list(torch.tensor(subset)[arg][tobe_labeled_set].numpy())
                labeled_set += tobe_labeled_set
                unlabeled_set = list(set(indices) - set(labeled_set))
            else:
                unlabeled_loader = DataLoader(dataset_aug, batch_size=1, sampler=SubsetSequentialSampler(subset),
                                              num_workers=args.workers, pin_memory=True, collate_fn=utils.collate_fn)
                uncertainty, _ = get_uncertainty(task_model, unlabeled_loader, augs, num_classes)
                arg = np.argsort(np.array(uncertainty))
                # Update the labeled dataset and the unlabeled dataset, respectively
                labeled_set += list(torch.tensor(subset)[arg][:budget_num].numpy())
                labeled_set = list(set(labeled_set))
                unlabeled_set = list(set(indices) - set(labeled_set))

            # Create a new dataloader for the updated labeled dataset
            train_sampler = SubsetRandomSampler(labeled_set)
            continue
        params = [p for p in task_model.parameters() if p.requires_grad]
        task_optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
        task_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(task_optimizer, milestones=args.lr_steps,
                                                                 gamma=args.lr_gamma)
        # Start the active-learning cycles
        if args.test_only:
            if 'coco' in args.dataset:
                coco_evaluate(task_model, data_loader_test)
            elif 'voc' in args.dataset:
                voc_evaluate(task_model, data_loader_test, args.dataset, False, path=args.results_path)
            return
        print("Start training")
        start_time = time.time()
        for epoch in range(args.start_epoch, args.total_epochs):
            train_one_epoch(task_model, task_optimizer, data_loader, device, cycle, epoch, args.print_freq)
            task_lr_scheduler.step()
            # evaluate after the final epoch
            if (epoch + 1) == args.total_epochs:
                if 'coco' in args.dataset:
                    coco_evaluate(task_model, data_loader_test)
                elif 'voc' in args.dataset:
                    voc_evaluate(task_model, data_loader_test, args.dataset, False, path=args.results_path)
        if not args.skip and cycle == 0:
            if 'faster' in args.model:
                utils.save_on_master({
                    'model': task_model.state_dict(), 'args': args},
                    os.path.join(args.first_checkpoint_path, '{}_frcnn_1st.pth'.format(args.dataset)))
            elif 'retina' in args.model:
                utils.save_on_master({
                    'model': task_model.state_dict(), 'args': args},
                    os.path.join(args.first_checkpoint_path, '{}_retinanet_1st.pth'.format(args.dataset)))
        random.shuffle(unlabeled_set)
        if 'coco' in args.dataset:
            subset = unlabeled_set[:10000]
        else:
            subset = unlabeled_set
        print("Getting stability")
        if not args.no_mutual:
            unlabeled_loader = DataLoader(dataset_aug, batch_size=1, sampler=SubsetSequentialSampler(subset),
                                          num_workers=args.workers, pin_memory=True, collate_fn=utils.collate_fn)
            uncertainty, _cls_corrs = get_uncertainty(task_model, unlabeled_loader, augs, num_classes)
            # labeled_loader = DataLoader(dataset_aug, batch_size=1, sampler=SubsetSequentialSampler(labeled_set),
            #                             num_workers=args.workers, pin_memory=True, collate_fn=utils.collate_fn)
            arg = np.argsort(np.array(uncertainty))
            cls_corrs_set = arg[:int(args.mr * budget_num)]
            cls_corrs = [_cls_corrs[i] for i in cls_corrs_set]
            labeled_loader = DataLoader(dataset_aug, batch_size=1, sampler=SubsetSequentialSampler(labeled_set),
                                        num_workers=args.workers, pin_memory=True, collate_fn=utils.collate_fn)
            tobe_labeled_set = cls_kldiv(labeled_loader, cls_corrs, budget_num, cycle)
            # Update the labeled dataset and the unlabeled dataset, respectively
            tobe_labeled_set = list(torch.tensor(subset)[arg][tobe_labeled_set].numpy())
            labeled_set += tobe_labeled_set
            unlabeled_set = list(set(indices) - set(labeled_set))
        else:
            unlabeled_loader = DataLoader(dataset_aug, batch_size=1, sampler=SubsetSequentialSampler(subset),
                                          num_workers=args.workers, pin_memory=True, collate_fn=utils.collate_fn)
            uncertainty, _ = get_uncertainty(task_model, unlabeled_loader, augs, num_classes)
            arg = np.argsort(np.array(uncertainty))
            # Update the labeled dataset and the unlabeled dataset, respectively
            labeled_set += list(torch.tensor(subset)[arg][:budget_num].numpy())
            labeled_set = list(set(labeled_set))
            unlabeled_set = list(set(indices) - set(labeled_set))
        # Create a new dataloader for the updated labeled dataset
        train_sampler = SubsetRandomSampler(labeled_set)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))