Example #1
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    print("=> creating model '{}'".format(args.arch))
    num_classes = 100 if args.dataset == 'cifar100' else 10
    use_norm = args.loss_type == 'LDAM'
    model = models.__dict__[args.arch](num_classes=num_classes,
                                       use_norm=use_norm)

    # create two optimizers - one for feature extractor and one for classifier
    feat_params = []
    feat_params_names = []
    cls_params = []
    cls_params_names = []
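    # learnable per-class epsilons: optimized together with the feature extractor
    # (they are added to feat_optim below) and consumed by the DRO loss in create_loss_list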
    learnable_epsilons = torch.nn.Parameter(torch.ones(num_classes))
    for name, params in model.named_parameters():
        if params.requires_grad:
            if "linear" in name:
                cls_params_names += [name]
                cls_params += [params]
            else:
                feat_params_names += [name]
                feat_params += [params]
    print("Create Feat Optimizer")
    print(f"\tRequires Grad:{feat_params_names}")
    feat_optim = torch.optim.SGD(feat_params + [learnable_epsilons],
                                 args.feat_lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
    print("Create Feat Optimizer")
    print(f"\tRequires Grad:{cls_params_names}")
    cls_optim = torch.optim.SGD(cls_params,
                                args.cls_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume or args.evaluation:
        curr_store_name = args.store_name
        if not args.evaluation and args.pretrained:
            curr_store_name = os.path.join(curr_store_name, os.path.pardir)
        filename = '%s/%s/ckpt.best.pth.tar' % (args.root_model,
                                                curr_store_name)
        if os.path.isfile(filename):
            print("=> loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location=f"cuda:{args.gpu}")
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                filename, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(filename))

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()

    cudnn.benchmark = True
    # Data loading code
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(),
        transforms.ToTensor(),
        Cutout(
            n_holes=1, length=16
        ),  # (https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py)
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    if args.dataset == 'cifar10':
        original_train_dataset = IMBALANCECIFAR10(root='./data',
                                                  imb_type=args.imb_type,
                                                  imb_factor=args.imb_factor,
                                                  rand_number=args.rand_number,
                                                  train=True,
                                                  download=True,
                                                  transform=transform_val)
        augmented_train_dataset = IMBALANCECIFAR10(
            root='./data',
            imb_type=args.imb_type,
            imb_factor=args.imb_factor,
            rand_number=args.rand_number,
            train=True,
            download=True,
            transform=transform_train
            if not args.evaluation else transform_val)
        val_dataset = datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform_val)
    elif args.dataset == 'cifar100':
        original_train_dataset = IMBALANCECIFAR100(
            root='./data',
            imb_type=args.imb_type,
            imb_factor=args.imb_factor,
            rand_number=args.rand_number,
            train=True,
            download=True,
            transform=transform_val)
        augmented_train_dataset = IMBALANCECIFAR100(
            root='./data',
            imb_type=args.imb_type,
            imb_factor=args.imb_factor,
            rand_number=args.rand_number,
            train=True,
            download=True,
            transform=transform_train
            if not args.evaluation else transform_val)
        val_dataset = datasets.CIFAR100(root='./data',
                                        train=False,
                                        download=True,
                                        transform=transform_val)
    else:
        warnings.warn('Dataset is not listed')
        return

    cls_num_list = augmented_train_dataset.get_cls_num_list()
    args.cls_num_list = cls_num_list

    train_labels = np.array(augmented_train_dataset.get_targets()).astype(int)
    # calculate balanced weights
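    # balanced_weights ~ inverse class frequency (sklearn 'balanced'); lt_weights below are
    # the raw class counts normalized by the size of the largest class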
    # keyword arguments are required by newer scikit-learn versions
    balanced_weights = torch.tensor(class_weight.compute_class_weight(
        class_weight='balanced', classes=np.unique(train_labels), y=train_labels),
                                    dtype=torch.float).cuda(args.gpu)
    lt_weights = torch.tensor(cls_num_list).float() / max(cls_num_list)

    def create_sampler(args_str):
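        # expected spec (inferred from the split below): "resample,<samples_per_class>",
        # e.g. "resample,4"; anything else returns None and the loader falls back to shuffling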
        if args_str is not None and "resample" in args_str:
            sampler_type, n_resample = args_str.split(",")
            return ClassAwareSampler(train_labels,
                                     num_samples_cls=int(n_resample))
        return None

    original_train_loader = torch.utils.data.DataLoader(
        original_train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # feature extractor dataloader
    feat_sampler = create_sampler(args.feat_sampler)
    feat_train_loader = torch.utils.data.DataLoader(
        augmented_train_dataset,
        batch_size=args.batch_size,
        shuffle=(feat_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=feat_sampler)

    if args.evaluation:
        # evaluate on validation set
        # calculate centroids on the train
        _, train_features, train_targets, _ = validate(original_train_loader,
                                                       model,
                                                       0,
                                                       args,
                                                       train_labels,
                                                       flag="train",
                                                       save_out=True)
        # validate
        validate(val_loader,
                 model,
                 0,
                 args,
                 train_labels,
                 flag="val",
                 save_out=True,
                 base_features=train_features,
                 base_targets=train_targets)
        return

    # create losses
    def create_loss_list(args_str):
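        # loss spec: comma-separated entries of "ce" or "robust_loss", optionally suffixed with
        # "_bal" (balanced class weights) or "_lt" (long-tail frequency weights),
        # e.g. "ce_bal,robust_loss"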
        loss_ls = []
        loss_str_ls = args_str.split(",")
        for loss_str in loss_str_ls:
            c_weights = None
            prefix = ""
            if "_bal" in loss_str:
                c_weights = balanced_weights
                prefix = "Balanced "
                loss_str = loss_str.split("_bal")[0]
            if "_lt" in loss_str:
                c_weights = lt_weights
                prefix = "Longtailed "
                loss_str = loss_str.split("_")[0]
            if loss_str == "ce":
                print(f"{prefix}CE", end=",")
                loss_ls += [
                    nn.CrossEntropyLoss(weight=c_weights).cuda(args.gpu)
                ]
            elif loss_str == "robust_loss":
                print(f"{prefix}Robust Loss", end=",")
                loss_ls += [
                    DROLoss(temperature=args.temperature,
                            base_temperature=args.temperature,
                            class_weights=c_weights,
                            epsilons=learnable_epsilons)
                ]
        print()
        return loss_ls

    feat_losses = create_loss_list(args.feat_loss)
    cls_losses = create_loss_list(args.cls_loss)

    # init log for training
    if not args.evaluation:
        log_training = open(
            os.path.join(args.root_log, args.store_name, 'log_train.csv'), 'w')
        log_testing = open(
            os.path.join(args.root_log, args.store_name, 'log_test.csv'), 'w')
        with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
                  'w') as f:
            f.write(str(args))
        tf_writer = None

    best_acc1 = 0
    best_acc_contrastive = 0
    for epoch in range(args.start_epoch, args.epochs):
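        # two-stage schedule: train only the feature extractor for the first
        # (epochs - balanced_clf_nepochs) epochs, then reload the best checkpoint and
        # train only the classifier for the remaining epochs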
        print("=============== Extract Train Centroids ===============")
        _, train_features, train_targets, _ = validate(feat_train_loader,
                                                       model,
                                                       epoch,
                                                       args,
                                                       train_labels,
                                                       log_training,
                                                       tf_writer,
                                                       flag="train",
                                                       verbose=True)

        if epoch < args.epochs - args.balanced_clf_nepochs:
            print("=============== Train Feature Extractor ===============")
            freeze_layers(model, fe_bool=True, cls_bool=False)
            train(feat_train_loader, model, feat_losses, epoch, feat_optim,
                  args, train_features, train_targets)

        else:
            if epoch == args.epochs - args.balanced_clf_nepochs:
                print(
                    "================ Loading Best Feature Extractor ================="
                )
                # load best model
                curr_store_name = args.store_name
                filename = '%s/%s/ckpt.best.pth.tar' % (args.root_model,
                                                        curr_store_name)
                checkpoint = torch.load(
                    filename, map_location=f"cuda:{args.gpu}")['state_dict']
                model.load_state_dict(checkpoint)

            print("=============== Train Classifier ===============")
            freeze_layers(model, fe_bool=False, cls_bool=True)
            train(feat_train_loader, model, cls_losses, epoch, cls_optim, args)

        print("=============== Extract Train Centroids ===============")
        _, train_features, train_targets, _ = validate(original_train_loader,
                                                       model,
                                                       epoch,
                                                       args,
                                                       train_labels,
                                                       log_training,
                                                       tf_writer,
                                                       flag="train",
                                                       verbose=False)

        print("=============== Validate ===============")
        acc1, _, _, contrastive_acc = validate(val_loader,
                                               model,
                                               epoch,
                                               args,
                                               train_labels,
                                               log_testing,
                                               tf_writer,
                                               flag="val",
                                               base_features=train_features,
                                               base_targets=train_targets)
        if epoch < args.epochs - args.balanced_clf_nepochs:
            is_best = contrastive_acc > best_acc_contrastive
            best_acc_contrastive = max(contrastive_acc, best_acc_contrastive)
        else:
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

        print(
            f"Best Contrastive Acc: {best_acc_contrastive}, Best Cls Acc: {best_acc1}"
        )
        log_testing.write(
            f"Best Contrastive Acc: {best_acc_contrastive}, Best Cls Acc: {best_acc1}\n"
        )
        log_testing.flush()
        save_checkpoint(
            args, {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1
            }, is_best)
Example #2
criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)     # , eta_min=1e-8
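# cosine-annealed SGD learning rate over 200 epochs (eta_min left at its default of 0)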

print('==> Preparing data..')

transform_train = transforms.Compose(
    [
        transforms.Resize((new_image_size, new_image_size)),
        transforms.RandomCrop(new_image_size, padding=4),  # resolution
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(),
        transforms.ToTensor(),
        Cutout(n_holes=1, length=16),  # (https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py)
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

transform_test = transforms.Compose([
    transforms.Resize((new_image_size, new_image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)

def main():
    print(args)

    if not osp.exists(args.dir):
        os.makedirs(args.dir)

    if args.use_gpu:
        torch.cuda.set_device(args.gpu)
        cudnn.enabled = True
        cudnn.benchmark = True

    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    # seed the Python, NumPy and torch RNGs so runs are reproducible
    random.seed(args.manualSeed)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)

    labeled_size = args.label_num + args.val_num

    num_classes = 10
    data_dir = '../cifar10_data/'

    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2470, 0.2435, 0.2616])

    # transform is implemented inside zca dataloader
    dataloader = cifar.CIFAR10
    if args.auto:
        transform_train = transforms.Compose([
            transforms.RandomCrop(
                32, padding=4, fill=128
            ),  # fill parameter needs torchvision installed from source
            transforms.RandomHorizontalFlip(),
            CIFAR10Policy(),
            transforms.ToTensor(),
            Cutout(
                n_holes=1, length=16
            ),  # (https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py)
            normalize
        ])
    else:
        transform_train = transforms.Compose([
            transforms.RandomCrop(
                32, padding=4, fill=128
            ),  # fill parameter needs torchvision installed from source
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])

    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    base_dataset = datasets.CIFAR10(data_dir, train=True, download=True)
    train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(
        base_dataset.targets, int(args.label_num / 10))
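    # (train_val_split presumably expects the number of labeled samples per class:
    #  label_num spread evenly over the 10 CIFAR-10 classes)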

    labelset = CIFAR10_labeled(data_dir,
                               train_labeled_idxs,
                               train=True,
                               transform=transform_train)
    labelset2 = CIFAR10_labeled(data_dir,
                                train_labeled_idxs,
                                train=True,
                                transform=transform_test)
    unlabelset = CIFAR10_labeled(data_dir,
                                 train_unlabeled_idxs,
                                 train=True,
                                 transform=transform_train)
    unlabelset2 = CIFAR10_labeled(data_dir,
                                  train_unlabeled_idxs,
                                  train=True,
                                  transform=transform_test)
    validset = CIFAR10_labeled(data_dir,
                               val_idxs,
                               train=True,
                               transform=transform_test)
    testset = CIFAR10_labeled(data_dir, train=False, transform=transform_test)

    label_y = np.array(labelset.targets).astype(np.int32)
    unlabel_y = np.array(unlabelset.targets).astype(np.int32)
    unlabel_num = unlabel_y.shape[0]

    label_loader = torch.utils.data.DataLoader(labelset,
                                               batch_size=args.batch_size,
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               drop_last=True)

    label_loader2 = torch.utils.data.DataLoader(
        labelset2,
        batch_size=args.eval_batch_size,
        num_workers=args.num_workers,
        pin_memory=True)

    unlabel_loader = torch.utils.data.DataLoader(
        unlabelset,
        batch_size=args.eval_batch_size,
        num_workers=args.num_workers,
        pin_memory=True)

    unlabel_loader2 = torch.utils.data.DataLoader(
        unlabelset2,
        batch_size=args.eval_batch_size,
        num_workers=args.num_workers,
        pin_memory=True)

    validloader = torch.utils.data.DataLoader(validset,
                                              batch_size=args.eval_batch_size,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.eval_batch_size,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    #initialize models
    model1 = create_model(args.num_classes, args.model)
    model2 = create_model(args.num_classes, args.model)
    ema_model = create_model(args.num_classes, args.model)

    if args.use_gpu:
        model1 = model1.cuda()
        model2 = model2.cuda()
        ema_model = ema_model.cuda()

    for param in ema_model.parameters():
        param.detach_()

    df = pd.DataFrame()
    stats_path = osp.join(args.dir, 'stats.txt')
    '''if prop > args.scale:
        prop = args.scale'''

    optimizer1 = AdamW(model1.parameters(), lr=args.lr)

    if args.init1 and osp.exists(args.init1):
        model1.load_state_dict(
            torch.load(args.init1, map_location='cuda:{}'.format(args.gpu)))

    ema_optimizer = WeightEMA(model1, ema_model, alpha=args.ema_decay)
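    # WeightEMA maintains ema_model as an exponential moving average of model1's weights,
    # which is why ema_model's parameters are detached above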

    if args.init and osp.exists(args.init):
        model1.load_state_dict(
            torch.load(args.init, map_location='cuda:{}'.format(args.gpu)))

    _, best_acc = evaluate(validloader, ema_model, prefix='val')

    best_ema_path = osp.join(args.dir, 'best_ema.pth')
    best_model1_path = osp.join(args.dir, 'best_model1.pth')
    best_model2_path = osp.join(args.dir, 'best_model2.pth')
    init_path = osp.join(args.dir, 'init_ema.pth')
    init_path1 = osp.join(args.dir, 'init1.pth')
    init_path2 = osp.join(args.dir, 'init2.pth')
    torch.save(ema_model.state_dict(), init_path)
    torch.save(model1.state_dict(), init_path1)
    torch.save(model2.state_dict(), init_path2)
    torch.save(ema_model.state_dict(), best_ema_path)
    torch.save(model1.state_dict(), best_model1_path)
    skip_model2 = False
    end_iter = False

    confident_indices = np.array([], dtype=np.int64)
    all_indices = np.arange(unlabel_num).astype(np.int64)
    #no_help_indices = np.array([]).astype(np.int64)
    pseudo_labels = np.zeros(all_indices.shape, dtype=np.int32)

    steps_per_epoch = len(label_loader)
    max_epoch = args.steps // steps_per_epoch
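    # supervised warm-up: train on the labeled set for ~80% of max_epoch at the base learning
    # rate, then for the remaining ~20% after dropping the Adam lr to 0.2x (below)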

    logger = logging.getLogger('init')
    file_handler = logging.FileHandler(osp.join(args.dir, 'init.txt'))
    logger.addHandler(file_handler)
    logger.setLevel(logging.INFO)

    for epoch in range(max_epoch * 4 // 5):
        if args.mix:
            train_init_mix(label_loader,
                           model1,
                           optimizer1,
                           ema_optimizer,
                           steps_per_epoch,
                           epoch,
                           logger=logger)
        else:
            train_init(label_loader,
                       model1,
                       optimizer1,
                       ema_optimizer,
                       steps_per_epoch,
                       epoch,
                       logger=logger)

        if epoch % 10 == 0:
            val_loss, val_acc = evaluate(validloader, ema_model, logger,
                                         'valid')
            if val_acc >= best_acc:
                best_acc = val_acc
                evaluate(testloader, ema_model, logger, 'test')
                torch.save(ema_model.state_dict(), init_path)
                torch.save(model1.state_dict(), init_path1)

    adjust_learning_rate_adam(optimizer1, args.lr * 0.2)

    for epoch in range(max_epoch // 5):
        if args.mix:
            train_init_mix(label_loader,
                           model1,
                           optimizer1,
                           ema_optimizer,
                           steps_per_epoch,
                           epoch,
                           logger=logger)
        else:
            train_init(label_loader,
                       model1,
                       optimizer1,
                       ema_optimizer,
                       steps_per_epoch,
                       epoch,
                       logger=logger)

        if epoch % 10 == 0:
            val_loss, val_acc = evaluate(validloader, ema_model, logger,
                                         'valid')
            if val_acc >= best_acc:
                best_acc = val_acc
                evaluate(testloader, ema_model, logger, 'test')
                torch.save(ema_model.state_dict(), init_path)
                torch.save(model1.state_dict(), init_path1)

    ema_model.load_state_dict(torch.load(init_path))
    model1.load_state_dict(torch.load(init_path1))

    logger.info('init train finished')
    evaluate(validloader, ema_model, logger, 'valid')
    evaluate(testloader, ema_model, logger, 'test')

    for i_round in range(args.round):
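        # each self-training round: (1) pseudo-label the unlabeled pool with the EMA teacher,
        # (2) train model2 on labeled + pseudo-labeled data, (3) keep only the most confident
        # pseudo-labels per class, (4) refine model1 / ema_model on that confident subset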
        mask = np.zeros(all_indices.shape, dtype=bool)
        mask[confident_indices] = True
        other_indices = all_indices[~mask]

        optimizer2 = AdamW(model2.parameters(), lr=args.lr)

        logger = logging.getLogger('model2_round_{}'.format(i_round))
        file_handler = logging.FileHandler(
            osp.join(args.dir, 'model2_round_{}.txt'.format(i_round)))
        logger.addHandler(file_handler)
        logger.setLevel(logging.INFO)

        if args.auto:
            # single pass with the test-time transform when the AutoAugment pipeline is used
            probs = predict_probs(ema_model, unlabel_loader2)
        else:
            # otherwise average K stochastic passes over the augmented unlabeled loader
            probs = np.zeros((unlabel_num, args.num_classes))
            for i in range(args.K):
                probs += predict_probs(ema_model, unlabel_loader)
            probs /= args.K

        pseudo_labels[other_indices] = probs.argmax(axis=1).astype(
            np.int32)[other_indices]
        #pseudo_labels = probs.argmax(axis=1).astype(np.int32)

        df2 = create_basic_stats_dataframe()
        df2['iter'] = i_round
        df2['train_acc'] = accuracy_score(unlabel_y, pseudo_labels)
        df = pd.concat([df, df2], ignore_index=True)  # DataFrame.append was removed in pandas 2.x
        df.to_csv(stats_path, index=False)

        #phase2: train model2
        unlabelset.targets = pseudo_labels.copy()
        trainset = ConcatDataset([labelset, unlabelset])

        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=args.batch_size2,
                                                  num_workers=args.num_workers,
                                                  pin_memory=True,
                                                  shuffle=True)

        model2.load_state_dict(torch.load(init_path2))
        best_val_epoch = 0
        best_model2_acc = 0

        steps_per_epoch = len(trainloader)
        max_epoch2 = args.steps2 // steps_per_epoch

        for epoch in range(max_epoch2):
            train_model2(trainloader, model2, optimizer2, epoch, logger=logger)

            val_loss, val_acc = evaluate(validloader, model2, logger, 'val')

            if val_acc >= best_model2_acc:
                best_model2_acc = val_acc
                best_val_epoch = epoch
                torch.save(model2.state_dict(), best_model2_path)
                evaluate(testloader, model2, logger, 'test')

            if (epoch - best_val_epoch) * steps_per_epoch > args.stop_steps2:
                break

        df.loc[df['iter'] == i_round, 'valid_acc'] = best_model2_acc
        df.loc[df['iter'] == i_round, 'valid_epoch'] = best_val_epoch
        df.to_csv(stats_path, index=False)

        model2.load_state_dict(torch.load(best_model2_path))
        logger.info('model2 train finished')

        evaluate(trainloader, model2, logger, 'train')

        evaluate(validloader, model2, logger, 'val')
        evaluate(label_loader2, model2, logger, 'reward')
        evaluate(testloader, model2, logger, 'test')
        #phase3: get confidence of unlabeled data by labeled data, split confident and unconfident data
        '''if args.auto:
            probs  = predict_probs(model2,unlabel_loader2)
        else:
            probs = np.zeros((unlabel_num,args.num_classes))
            for i in range(args.K):
                probs += predict_probs(model2, unlabel_loader)
            probs /= args.K'''

        probs = predict_probs(model2, unlabel_loader2)
        new_pseudo_labels = probs.argmax(axis=1)

        confidences = probs[all_indices, pseudo_labels]

        # grow the confident pool each round: 'exp' multiplies the (labeled + confident) count
        # by (1 + scale), 'linear' adds a fixed fraction (scale) of the unlabeled set
        if args.schedule == 'exp':
            confident_num = int((len(confident_indices) + args.label_num) *
                                (1 + args.scale)) - args.label_num
        elif args.schedule == 'linear':
            confident_num = len(confident_indices) + int(
                unlabel_num * args.scale)

        old_confident_indices = confident_indices.copy()
        confident_indices = np.array([], dtype=np.int64)

        for j in range(args.num_classes):
            j_cands = (pseudo_labels == j)
            k_size = int(min(confident_num // args.num_classes, j_cands.sum()))
            logger.info('class: {}, confident size: {}'.format(j, k_size))
            if k_size > 0:
                # keep the k_size most confident samples currently pseudo-labeled as class j
                j_idx_top = all_indices[j_cands][
                    confidences[j_cands].argsort()[-k_size:]]
                confident_indices = np.concatenate(
                    (confident_indices, j_idx_top))
        '''new_confident_indices = np.intersect1d(new_confident_indices, np.setdiff1d(new_confident_indices, no_help_indices))
        new_confident_indices = new_confident_indices[(-confidences[new_confident_indices]).argsort()]
        confident_indices = np.concatenate((old_confident_indices, new_confident_indices))'''

        acc = accuracy_score(unlabel_y[confident_indices],
                             pseudo_labels[confident_indices])
        logger.info('confident data num:{}, prop: {:4f}, acc: {:4f}'.format(
            len(confident_indices),
            len(confident_indices) / len(unlabel_y), acc))
        '''if len(old_confident_indices) > 0:
            acc = accuracy_score(unlabel_y[old_confident_indices],pseudo_labels[old_confident_indices])        
            logger.info('old confident data prop: {:4f}, acc: {:4f}'.format(len(old_confident_indices)/len(unlabel_y), acc))

        acc = accuracy_score(unlabel_y[new_confident_indices],pseudo_labels[new_confident_indices])
        logger.info('new confident data prop: {:4f}, acc: {:4f}'.format(len(new_confident_indices)/len(unlabel_y), acc))'''

        #unlabelset.train_labels_ul = pseudo_labels.copy()
        confident_dataset = torch.utils.data.Subset(unlabelset,
                                                    confident_indices)

        #phase4: refine model1 by confident data and reward data
        #train_dataset = torch.utils.data.ConcatDataset([confident_dataset,labelset])

        logger = logging.getLogger('model1_round_{}'.format(i_round))
        file_handler = logging.FileHandler(
            osp.join(args.dir, 'model1_round_{}.txt'.format(i_round)))
        logger.addHandler(file_handler)
        logger.setLevel(logging.INFO)

        best_val_epoch = 0
        evaluate(validloader, ema_model, logger, 'valid')
        evaluate(testloader, ema_model, logger, 'test')

        optimizer1 = AdamW(model1.parameters(), lr=args.lr)

        confident_dataset = torch.utils.data.Subset(unlabelset,
                                                    confident_indices)
        trainloader = torch.utils.data.DataLoader(confident_dataset,
                                                  batch_size=args.batch_size,
                                                  num_workers=args.num_workers,
                                                  shuffle=True,
                                                  drop_last=True)

        #steps_per_epoch = len(iter(trainloader))
        steps_per_epoch = 200
        max_epoch1 = args.steps1 // steps_per_epoch

        for epoch in range(max_epoch1):
            '''current_num = int(cal_consistency_weight( (epoch + 1) * steps_per_epoch, init_ep=0, end_ep=args.stop_steps1//2, init_w=start_num, end_w=end_num))            
            current_confident_indices = confident_indices[:current_num]
            logger.info('current num: {}'.format(current_num))'''
            if args.mix:
                train_model1_mix(label_loader,
                                 trainloader,
                                 model1,
                                 optimizer1,
                                 ema_model,
                                 ema_optimizer,
                                 steps_per_epoch,
                                 epoch,
                                 logger=logger)
            else:
                train_model1(label_loader,
                             trainloader,
                             model1,
                             optimizer1,
                             ema_model,
                             ema_optimizer,
                             steps_per_epoch,
                             epoch,
                             logger=logger)

            val_loss, val_acc = evaluate(validloader, ema_model, logger,
                                         'valid')
            if val_acc >= best_acc:
                best_acc = val_acc
                best_val_epoch = epoch
                evaluate(testloader, ema_model, logger, 'test')
                torch.save(model1.state_dict(), best_model1_path)
                torch.save(ema_model.state_dict(), best_ema_path)

            if (epoch - best_val_epoch) * steps_per_epoch > args.stop_steps1:
                break

        ema_model.load_state_dict(torch.load(best_ema_path))
        model1.load_state_dict(torch.load(best_model1_path))

        logger.info('model1 train finished')
        evaluate(validloader, ema_model, logger, 'valid')
        evaluate(testloader, ema_model, logger, 'test')
        '''no_help_indices = np.concatenate((no_help_indices,confident_indices[current_num:]))
        confident_indices = confident_indices[:current_num]'''

        if len(confident_indices) >= len(all_indices):
            break