Example #1
def _data_transforms(args):

    if 'cifar' in args.dataset:
        norm_mean = [0.49139968, 0.48215827, 0.44653124]
        norm_std = [0.24703233, 0.24348505, 0.26158768]
    else:
        raise KeyError(f'unsupported dataset: {args.dataset}')

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        # transforms.Resize(224, interpolation=3),  # BICUBIC interpolation
        transforms.RandomHorizontalFlip(),
    ])

    if args.autoaugment:
        train_transform.transforms.append(CIFAR10Policy())

    train_transform.transforms.append(transforms.ToTensor())

    if args.cutout:
        train_transform.transforms.append(Cutout(args.cutout_length))

    train_transform.transforms.append(transforms.Normalize(
        norm_mean, norm_std))

    valid_transform = transforms.Compose([
        transforms.Resize(224, interpolation=3),  # BICUBIC interpolation
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])
    return train_transform, valid_transform
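Newer torchvision releases ship the CIFAR-10 AutoAugment policy as a built-in transform, so the third-party CIFAR10Policy is not strictly required there. A minimal sketch of an equivalent training pipeline (same statistics as above; an alternative, not this snippet's original code):

from torchvision import transforms

CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

builtin_train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),  # built-in CIFAR-10 policy
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])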
Example #2
def get_transform(args):
    train_transform = []
    test_transform = []
    train_transform += [
        transforms.RandomCrop(size=args.size, padding=args.padding)
    ]
    if args.dataset != 'svhn':
        train_transform += [transforms.RandomHorizontalFlip()]
    
    if args.autoaugment:
        if args.dataset == 'c10' or args.dataset == 'c100':
            train_transform.append(CIFAR10Policy())
        elif args.dataset == 'svhn':
            train_transform.append(SVHNPolicy())
        else:
            print(f"No AutoAugment for {args.dataset}")

    train_transform += [
        transforms.ToTensor(),
        transforms.Normalize(mean=args.mean, std=args.std)
    ]
    if args.rcpaste:
        train_transform += [RandomCropPaste(size=args.size)]
    
    test_transform += [
        transforms.ToTensor(),
        transforms.Normalize(mean=args.mean, std=args.std)
    ]

    train_transform = transforms.Compose(train_transform)
    test_transform = transforms.Compose(test_transform)

    return train_transform, test_transform
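To make the expected fields of args explicit, here is a hypothetical call with a SimpleNamespace standing in for the real argparse namespace (field names inferred from the function body; assumes get_transform and the policy classes it uses are importable):

from types import SimpleNamespace

args = SimpleNamespace(
    dataset='c10', size=32, padding=4, autoaugment=True, rcpaste=False,
    mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616),
)
train_transform, test_transform = get_transform(args)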
Example #3
def _data_transforms_cifar10(args):
  CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
  CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

  if args.auto_aug:
    train_transform = transforms.Compose([
      transforms.RandomCrop(32, padding=4),
      transforms.RandomHorizontalFlip(),
      CIFAR10Policy(),
      transforms.ToTensor(),
      transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
  else:
    train_transform = transforms.Compose([
      transforms.RandomCrop(32, padding=4),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
  if args.cutout:
    train_transform.transforms.append(Cutout(args.cutout_length))

  valid_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
  ])
  return train_transform, valid_transform
Example #4
def augment(np_ary):
    transformed_list = []
    img_list = make_img_from_tensor(np_ary)
    policy = CIFAR10Policy()
    policy.draw()
    for img in img_list:
        transformed_list.append(np.array(policy(img)))
    return transformed_list
Example #5
    def CIFAR10_policy(self):
        """Generates `count` CIFAR-10 AutoAugment variants of the image and saves them as JPEG."""
        imag_tag = self.im_file
        policy = CIFAR10Policy()
        image = Image.open(imag_tag)
        tag = re.findall(r'([-\w]+\.(?:jpg|gif|png|JPG|JPEG|jpeg))', imag_tag)
        for t in range(self.count):
            polic = policy(image)
            polic.save(self.data_save_dir + str(t) + tag[0], format="JPEG")
Example #6
def cifar_autoaugment(input_size, scale_size=None, normalize=_IMAGENET_STATS):
    scale_size = scale_size or input_size  # guard: the None default would crash the padding computation
    padding = int((scale_size - input_size) / 2)
    return transforms.Compose([
        transforms.RandomCrop(input_size, padding=padding),
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(fillcolor=(128, 128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(**normalize),
    ])
Example #7
def cifar_autoaugment(input_size, scale_size=None, padding=None, normalize=_IMAGENET_STATS):
    scale_size = scale_size or input_size
    T = transforms.Compose([
        transforms.RandomCrop(scale_size, padding=padding),
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(fillcolor=(128, 128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(**normalize),
    ])
    if input_size != scale_size:
        T.transforms.insert(1, transforms.Resize(input_size))
    return T
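_IMAGENET_STATS is not shown in either of the two snippets above; judging from the transforms.Normalize(**normalize) call, it is presumably a dict of the usual ImageNet statistics, e.g.:

_IMAGENET_STATS = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}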
Example #8
    def __init__(self, args):
#        self.train_size = args.train_size
        self.batch_size = args.batch_size
        self.threads = args.threads

#        mean, std = self._get_statistics()
        cifar10_mean = (0.4914, 0.4822, 0.4465)
        cifar10_std = (0.2471, 0.2435, 0.2616)
#        print("4. Cifar.py ", probabilities)

        train_transform = transforms.Compose([
            torchvision.transforms.RandomCrop(size=(32, 32), padding=4),
            torchvision.transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
#            transforms.Normalize(mean, std),
            transforms.Normalize(cifar10_mean, cifar10_std),
            Cutout()
        ])
        if args.add_augment > 0:
            train_transform.transforms.insert(0, CIFAR10Policy())
            train_transform.transforms.insert(0, torchvision.transforms.RandomRotation(30.0))

        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cifar10_mean, cifar10_std),
#            transforms.Normalize(mean, std)
        ])

        self.train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
        filename = ('data/config/cifar10.%d@%d%s.npy' % (args.seed, args.train_size, args.data_bal) )
        print("Loading data configuration file ", filename)
        train_samples = np.load(filename)
#        print("train_samples ", train_samples)
        self.train_set = Subset(self.train_set, indices=train_samples)
        self.test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
        
        sampler = RandomSampler(self.train_set, replacement=False) #, num_samples=self.train_size)
        batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True)

        self.train = torch.utils.data.DataLoader(self.train_set, batch_sampler=batch_sampler, num_workers=self.threads)
#        self.train = torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, num_workers=self.threads)
        self.test = torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False, num_workers=self.threads)

        self.classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Example #9
def data_transforms(dataset, cutout_length, autoaugment=False):
    dataset = dataset.lower()
    if dataset == 'cifar10':
        MEAN = [0.49139968, 0.48215827, 0.44653124]
        STD = [0.24703233, 0.24348505, 0.26158768]
        transf = [
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
            transforms.RandomHorizontalFlip()
        ]
        if autoaugment:
            transf.append(CIFAR10Policy())
    elif dataset == 'mnist':
        MEAN = [0.13066051707548254]
        STD = [0.30810780244715075]
        transf = [
            transforms.RandomAffine(degrees=15,
                                    translate=(0.1, 0.1),
                                    scale=(0.9, 1.1),
                                    shear=0.1)
        ]
    elif dataset == 'fashionmnist':
        MEAN = [0.28604063146254594]
        STD = [0.35302426207299326]
        transf = [
            transforms.RandomAffine(degrees=15,
                                    translate=(0.1, 0.1),
                                    scale=(0.9, 1.1),
                                    shear=0.1),
            transforms.RandomVerticalFlip()
        ]
    else:
        raise ValueError('not expected dataset = {}'.format(dataset))

    normalize = [transforms.ToTensor(), transforms.Normalize(MEAN, STD)]

    train_transform = transforms.Compose(transf + normalize)
    valid_transform = transforms.Compose(normalize)

    if cutout_length > 0:
        train_transform.transforms.append(Cutout(cutout_length))

    return train_transform, valid_transform
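A hypothetical call for CIFAR-10 with AutoAugment and 16-pixel Cutout (assumes CIFAR10Policy and Cutout are importable alongside this function):

train_transform, valid_transform = data_transforms('cifar10', cutout_length=16, autoaugment=True)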
Example #10
def dataload(r, supnum=4000):
    root = 'data/'
    if not os.path.exists(root):
        os.mkdir(root)
    t = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.49139968, 0.48215841, 0.44653091],
                             std=[0.24703223, 0.24348513, 0.26158784])
    ])
    aug_t = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.49139968, 0.48215841, 0.44653091],
                             std=[0.24703223, 0.24348513, 0.26158784])
    ])
    base_set = torchvision.datasets.CIFAR10(root=root)
    sup_idx, unsup_idx = sup_unsup_proc(base_set.targets, int(supnum / 10), r)
    sup_data = SupImgDataSet(root=root, index=sup_idx, transform=t)
    val_data = torchvision.datasets.CIFAR10(root=root,
                                            train=False,
                                            transform=t)
    unsup_data = UnSupImgDataSet(root=root,
                                 index=unsup_idx,
                                 transform=aug_t,
                                 target_transform=t)
    sup_dataloader = DataLoader(dataset=sup_data,
                                batch_size=64,
                                num_workers=2,
                                shuffle=True)
    val_dataloader = DataLoader(dataset=val_data,
                                batch_size=16,
                                num_workers=2,
                                shuffle=True)
    unsup_dataloader = DataLoader(dataset=unsup_data,
                                  batch_size=128,
                                  num_workers=2,
                                  shuffle=True)
    return sup_dataloader, val_dataloader, unsup_dataloader
Example #11
    def __getitem__(self, i):

        image = cv2.imread(self.images_dir + self.images[i], cv2.IMREAD_COLOR)
        image = cv2.resize(image,
                           self.base_size,
                           interpolation=cv2.INTER_LINEAR)
        if self.applyAutoAug:
            # note: cv2 loads images as BGR, so the PIL-based policy sees swapped channels here;
            # its geometric sub-policies are also applied to the image only, not to the label map,
            # which can break image/label alignment when they fire
            image = Image.fromarray(image.astype('uint8'))
            #policy = ImageNetPolicy()
            policy = CIFAR10Policy()
            #policy = SVHNPolicy()
            image = policy(image)
            image = np.array(image)
            #img = np.asarray(image)
        if self.applyCutout:
            image = Cutout(image)  # original note (translated from Chinese): "writing it this way is wrong" -- Cutout is a transform class, not a plain function
        #size = image.shape
        ##normalize and change it to a tensor
        #image = self.input_transform(image)
        #image = image.transpose((2, 0, 1))

        label = cv2.imread(self.labels_dir + self.labels[i],
                           cv2.IMREAD_GRAYSCALE)
        label = cv2.resize(label,
                           self.base_size,
                           interpolation=cv2.INTER_NEAREST)

        #some operations needed here
        ## depends on the range of label values
        #label = self.convert_label(label)

        image, label = self.gen_sample(image, label, self.multi_scale,
                                       self.flip)
        label = np.expand_dims(label, axis=0)

        return image.copy(), label.copy()
Example #12
# print(model)

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)     # , eta_min=1e-8

print('==> Preparing data..')

transform_train = transforms.Compose(
    [
        transforms.Resize((new_image_size, new_image_size)),
        transforms.RandomCrop(new_image_size, padding=4),  # resolution
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(),
        transforms.ToTensor(),
        Cutout(n_holes=1, length=16),  # (https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py)
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

transform_test = transforms.Compose([
    transforms.Resize((new_image_size, new_image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
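Several of these examples pull Cutout from https://github.com/uoguelph-mlrg/Cutout. For reference, a minimal sketch of that idea (zero out square patches of a CHW tensor), not the repository's exact implementation:

import torch

class Cutout:
    """Zero out `n_holes` square patches of side `length` in a CHW image tensor."""
    def __init__(self, n_holes=1, length=16):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        h, w = img.shape[1], img.shape[2]
        mask = torch.ones(h, w, dtype=img.dtype)
        for _ in range(self.n_holes):
            y = torch.randint(0, h, (1,)).item()
            x = torch.randint(0, w, (1,)).item()
            y1, y2 = max(0, y - self.length // 2), min(h, y + self.length // 2)
            x1, x2 = max(0, x - self.length // 2), min(w, x + self.length // 2)
            mask[y1:y2, x1:x2] = 0
        return img * mask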
Example #13
def auto_aug(tensor):
    policy = CIFAR10Policy()
    return tf.py_func(policy, [tensor], tf.uint8)
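tf.py_func is a TF1 API (only available as tf.compat.v1.py_func in TF 2.x), and the PIL-based CIFAR10Policy expects a PIL image rather than a raw array. A sketch of the same wrapper for TF 2.x with tf.py_function and explicit conversions (auto_aug_tf2 is a hypothetical name; CIFAR10Policy is assumed importable):

import numpy as np
import tensorflow as tf
from PIL import Image

def auto_aug_tf2(tensor):
    def _apply(x):
        img = Image.fromarray(x.numpy().astype(np.uint8))  # uint8 HWC tensor -> PIL
        return np.asarray(CIFAR10Policy()(img), dtype=np.uint8)
    return tf.py_function(_apply, [tensor], tf.uint8)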
Example #14
def main():
    global args, config, last_epoch, best_prec

    # read config from yaml file
    with open(args.work_path + '/config.yaml') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)  # PyYAML >= 5.1 expects an explicit Loader
    # convert to dict
    config = EasyDict(config)
    logger.info(config)

    # define network
    net = get_model(config)
    logger.info(net)
    logger.info(" == total parameters: " + str(count_parameters(net)))

    # CPU or GPU
    device = 'cuda' if config.use_gpu else 'cpu'
    # data parallel for multiple-GPU
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    net.to(device)

    # define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        net.parameters(),
        config.lr_scheduler.base_lr,
        momentum=config.optimize.momentum,
        weight_decay=config.optimize.weight_decay,
        nesterov=config.optimize.nesterov)

    # resume from a checkpoint
    last_epoch = -1
    best_prec = 0
    if args.work_path:
        ckpt_file_name = args.work_path + '/' + config.ckpt_name + '.pth.tar'
        if args.resume:
            best_prec, last_epoch = load_checkpoint(
                ckpt_file_name, net, optimizer=optimizer)

    # load training data, do data augmentation and get data loader
    if config.auto_augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4, fill=128),  # fill parameter needs torchvision installed from source
            transforms.RandomHorizontalFlip(),
            CIFAR10Policy(),
            transforms.ToTensor(),
            Cutout(n_holes=1, length=16),  # (https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py)
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    else:
        transform_train = transforms.Compose(
            data_augmentation(config))

    transform_test = transforms.Compose(
        data_augmentation(config, is_train=False))

    train_loader, test_loader = get_data_loader(
        transform_train, transform_test, config)

    logger.info("            =======  Training  =======\n")
    for epoch in range(last_epoch + 1, config.epochs):
        lr = adjust_learning_rate(optimizer, epoch, config)
        train(train_loader, net, criterion, optimizer, epoch, device)
        if epoch == 0 or (
                epoch + 1) % config.eval_freq == 0 or epoch == config.epochs - 1:
            test(test_loader, net, criterion, optimizer, epoch, device)

    logger.info(
        "======== Training Finished.   best_test_acc: {:.3f}% ========".format(best_prec))
Example #15
    ])
elif args.dataset == 'svhn':
    # the WRN paper does no augmentation on SVHN
    # obviously flipping is a bad idea, and it makes some sense not to
    # crop because there are a lot of distractor digits in the edges of the
    # image
    transform_train = transforms.ToTensor()

if args.autoaugment or args.cutout:
    assert (args.dataset == 'cifar10')
    transform_list = [
        transforms.RandomCrop(32, padding=4, fill=128),
        # fill parameter needs torchvision installed from source
        transforms.RandomHorizontalFlip()]
    if args.autoaugment:
        transform_list.append(CIFAR10Policy())
    transform_list.append(transforms.ToTensor())
    if args.cutout:
        transform_list.append(Cutout(n_holes=1, length=16))

    transform_train = transforms.Compose(transform_list)
    logger.info('Applying aggressive training augmentation: %s'
                % transform_train)

transform_test = transforms.Compose([
    transforms.ToTensor()])
# ------------------------------------------------------------------------------

# ----------------- DATASET WITH AUX PSEUDO-LABELED DATA -----------------------
trainset = SemiSupervisedDataset(base_dataset=args.dataset,
                                 add_svhn_extra=args.svhn_extra,
Example #16
def main():
    print(args)

    if not osp.exists(args.dir):
        os.makedirs(args.dir)

    if args.use_gpu:
        torch.cuda.set_device(args.gpu)
        cudnn.enabled = True
        cudnn.benchmark = True

    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    np.random.seed(args.manualSeed)

    labeled_size = args.label_num + args.val_num

    num_classes = 10
    data_dir = '../cifar10_data/'

    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2470, 0.2435, 0.2616])

    # transform is implemented inside zca dataloader
    dataloader = cifar.CIFAR10
    if args.auto:
        transform_train = transforms.Compose([
            transforms.RandomCrop(
                32, padding=4, fill=128
            ),  # fill parameter needs torchvision installed from source
            transforms.RandomHorizontalFlip(),
            CIFAR10Policy(),
            transforms.ToTensor(),
            Cutout(
                n_holes=1, length=16
            ),  # (https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py)
            normalize
        ])
    else:
        transform_train = transforms.Compose([
            transforms.RandomCrop(
                32, padding=4, fill=128
            ),  # fill parameter needs torchvision installed from source
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])

    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    base_dataset = datasets.CIFAR10(data_dir, train=True, download=True)
    train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(
        base_dataset.targets, int(args.label_num / 10))

    labelset = CIFAR10_labeled(data_dir,
                               train_labeled_idxs,
                               train=True,
                               transform=transform_train)
    labelset2 = CIFAR10_labeled(data_dir,
                                train_labeled_idxs,
                                train=True,
                                transform=transform_test)
    unlabelset = CIFAR10_labeled(data_dir,
                                 train_unlabeled_idxs,
                                 train=True,
                                 transform=transform_train)
    unlabelset2 = CIFAR10_labeled(data_dir,
                                  train_unlabeled_idxs,
                                  train=True,
                                  transform=transform_test)
    validset = CIFAR10_labeled(data_dir,
                               val_idxs,
                               train=True,
                               transform=transform_test)
    testset = CIFAR10_labeled(data_dir, train=False, transform=transform_test)

    label_y = np.array(labelset.targets).astype(np.int32)
    unlabel_y = np.array(unlabelset.targets).astype(np.int32)
    unlabel_num = unlabel_y.shape[0]

    label_loader = torch.utils.data.DataLoader(labelset,
                                               batch_size=args.batch_size,
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               drop_last=True)

    label_loader2 = torch.utils.data.DataLoader(
        labelset2,
        batch_size=args.eval_batch_size,
        num_workers=args.num_workers,
        pin_memory=True)

    unlabel_loader = torch.utils.data.DataLoader(
        unlabelset,
        batch_size=args.eval_batch_size,
        num_workers=args.num_workers,
        pin_memory=True)

    unlabel_loader2 = torch.utils.data.DataLoader(
        unlabelset2,
        batch_size=args.eval_batch_size,
        num_workers=args.num_workers,
        pin_memory=True)

    validloader = torch.utils.data.DataLoader(validset,
                                              batch_size=args.eval_batch_size,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.eval_batch_size,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    #initialize models
    model1 = create_model(args.num_classes, args.model)
    model2 = create_model(args.num_classes, args.model)
    ema_model = create_model(args.num_classes, args.model)

    if args.use_gpu:
        model1 = model1.cuda()
        model2 = model2.cuda()
        ema_model = ema_model.cuda()

    for param in ema_model.parameters():
        param.detach_()

    df = pd.DataFrame()
    stats_path = osp.join(args.dir, 'stats.txt')
    '''if prop > args.scale:
        prop = args.scale'''

    optimizer1 = AdamW(model1.parameters(), lr=args.lr)

    if args.init1 and osp.exists(args.init1):
        model1.load_state_dict(
            torch.load(args.init1, map_location='cuda:{}'.format(args.gpu)))

    ema_optimizer = WeightEMA(model1, ema_model, alpha=args.ema_decay)

    if args.init and osp.exists(args.init):
        model1.load_state_dict(
            torch.load(args.init, map_location='cuda:{}'.format(args.gpu)))

    _, best_acc = evaluate(validloader, ema_model, prefix='val')

    best_ema_path = osp.join(args.dir, 'best_ema.pth')
    best_model1_path = osp.join(args.dir, 'best_model1.pth')
    best_model2_path = osp.join(args.dir, 'best_model2.pth')
    init_path = osp.join(args.dir, 'init_ema.pth')
    init_path1 = osp.join(args.dir, 'init1.pth')
    init_path2 = osp.join(args.dir, 'init2.pth')
    torch.save(ema_model.state_dict(), init_path)
    torch.save(model1.state_dict(), init_path1)
    torch.save(model2.state_dict(), init_path2)
    torch.save(ema_model.state_dict(), best_ema_path)
    torch.save(model1.state_dict(), best_model1_path)
    skip_model2 = False
    end_iter = False

    confident_indices = np.array([], dtype=np.int64)
    all_indices = np.arange(unlabel_num).astype(np.int64)
    #no_help_indices = np.array([]).astype(np.int64)
    pseudo_labels = np.zeros(all_indices.shape, dtype=np.int32)

    steps_per_epoch = len(iter(label_loader))
    max_epoch = args.steps // steps_per_epoch

    logger = logging.getLogger('init')
    file_handler = logging.FileHandler(osp.join(args.dir, 'init.txt'))
    logger.addHandler(file_handler)
    logger.setLevel(logging.INFO)

    for epoch in range(max_epoch * 4 // 5):
        if args.mix:
            train_init_mix(label_loader,
                           model1,
                           optimizer1,
                           ema_optimizer,
                           steps_per_epoch,
                           epoch,
                           logger=logger)
        else:
            train_init(label_loader,
                       model1,
                       optimizer1,
                       ema_optimizer,
                       steps_per_epoch,
                       epoch,
                       logger=logger)

        if epoch % 10 == 0:
            val_loss, val_acc = evaluate(validloader, ema_model, logger,
                                         'valid')
            if val_acc >= best_acc:
                best_acc = val_acc
                evaluate(testloader, ema_model, logger, 'test')
                torch.save(ema_model.state_dict(), init_path)
                torch.save(model1.state_dict(), init_path1)

    adjust_learning_rate_adam(optimizer1, args.lr * 0.2)

    for epoch in range(max_epoch // 5):
        if args.mix:
            train_init_mix(label_loader,
                           model1,
                           optimizer1,
                           ema_optimizer,
                           steps_per_epoch,
                           epoch,
                           logger=logger)
        else:
            train_init(label_loader,
                       model1,
                       optimizer1,
                       ema_optimizer,
                       steps_per_epoch,
                       epoch,
                       logger=logger)

        if epoch % 10 == 0:
            val_loss, val_acc = evaluate(validloader, ema_model, logger,
                                         'valid')
            if val_acc >= best_acc:
                best_acc = val_acc
                evaluate(testloader, ema_model, logger, 'test')
                torch.save(ema_model.state_dict(), init_path)
                torch.save(model1.state_dict(), init_path1)

    ema_model.load_state_dict(torch.load(init_path))
    model1.load_state_dict(torch.load(init_path1))

    logger.info('init train finished')
    evaluate(validloader, ema_model, logger, 'valid')
    evaluate(testloader, ema_model, logger, 'test')

    for i_round in range(args.round):
        mask = np.zeros(all_indices.shape, dtype=bool)
        mask[confident_indices] = True
        other_indices = all_indices[~mask]

        optimizer2 = AdamW(model2.parameters(), lr=args.lr)

        logger = logging.getLogger('model2_round_{}'.format(i_round))
        file_handler = logging.FileHandler(
            osp.join(args.dir, 'model2_round_{}.txt'.format(i_round)))
        logger.addHandler(file_handler)
        logger.setLevel(logging.INFO)

        if args.auto:
            probs = predict_probs(ema_model, unlabel_loader2)
        else:
            probs = np.zeros((unlabel_num, args.num_classes))
            for i in range(args.K):
                probs += predict_probs(ema_model, unlabel_loader)
            probs /= args.K

        pseudo_labels[other_indices] = probs.argmax(axis=1).astype(
            np.int32)[other_indices]
        #pseudo_labels = probs.argmax(axis=1).astype(np.int32)

        df2 = create_basic_stats_dataframe()
        df2['iter'] = i_round
        df2['train_acc'] = accuracy_score(unlabel_y, pseudo_labels)
        df = pd.concat([df, df2], ignore_index=True)  # DataFrame.append was removed in pandas 2.x
        df.to_csv(stats_path, index=False)

        #phase2: train model2
        unlabelset.targets = pseudo_labels.copy()
        trainset = ConcatDataset([labelset, unlabelset])

        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=args.batch_size2,
                                                  num_workers=args.num_workers,
                                                  pin_memory=True,
                                                  shuffle=True)

        model2.load_state_dict(torch.load(init_path2))
        best_val_epoch = 0
        best_model2_acc = 0

        steps_per_epoch = len(iter(trainloader))
        max_epoch2 = args.steps2 // steps_per_epoch

        for epoch in range(max_epoch2):
            train_model2(trainloader, model2, optimizer2, epoch, logger=logger)

            val_loss, val_acc = evaluate(validloader, model2, logger, 'val')

            if val_acc >= best_model2_acc:
                best_model2_acc = val_acc
                best_val_epoch = epoch
                torch.save(model2.state_dict(), best_model2_path)
                evaluate(testloader, model2, logger, 'test')

            if (epoch - best_val_epoch) * steps_per_epoch > args.stop_steps2:
                break

        df.loc[df['iter'] == i_round, 'valid_acc'] = best_model2_acc
        df.loc[df['iter'] == i_round, 'valid_epoch'] = best_val_epoch
        df.to_csv(stats_path, index=False)

        model2.load_state_dict(torch.load(best_model2_path))
        logger.info('model2 train finished')

        evaluate(trainloader, model2, logger, 'train')

        evaluate(validloader, model2, logger, 'val')
        evaluate(label_loader2, model2, logger, 'reward')
        evaluate(testloader, model2, logger, 'test')
        #phase3: get confidence of unlabeled data by labeled data, split confident and unconfident data
        '''if args.auto:
            probs  = predict_probs(model2,unlabel_loader2)
        else:
            probs = np.zeros((unlabel_num,args.num_classes))
            for i in range(args.K):
                probs += predict_probs(model2, unlabel_loader)
            probs /= args.K'''

        probs = predict_probs(model2, unlabel_loader2)
        new_pseudo_labels = probs.argmax(axis=1)

        confidences = probs[all_indices, pseudo_labels]

        if args.schedule == 'exp':
            confident_num = int((len(confident_indices) + args.label_num) *
                                (1 + args.scale)) - args.label_num
        elif args.schedule == 'linear':
            confident_num = len(confident_indices) + int(
                unlabel_num * args.scale)

        old_confident_indices = confident_indices.copy()
        confident_indices = np.array([], dtype=np.int64)

        for j in range(args.num_classes):
            j_cands = (pseudo_labels == j)
            k_size = int(min(confident_num // args.num_classes, j_cands.sum()))
            logger.info('class: {}, confident size: {}'.format(j, k_size))
            if k_size > 0:
                j_idx_top = all_indices[j_cands][
                    confidences[j_cands].argsort()[-k_size:]]
                confident_indices = np.concatenate(
                    (confident_indices, all_indices[j_idx_top]))
        '''new_confident_indices = np.intersect1d(new_confident_indices, np.setdiff1d(new_confident_indices, no_help_indices))
        new_confident_indices = new_confident_indices[(-confidences[new_confident_indices]).argsort()]
        confident_indices = np.concatenate((old_confident_indices, new_confident_indices))'''

        acc = accuracy_score(unlabel_y[confident_indices],
                             pseudo_labels[confident_indices])
        logger.info('confident data num:{}, prop: {:4f}, acc: {:4f}'.format(
            len(confident_indices),
            len(confident_indices) / len(unlabel_y), acc))
        '''if len(old_confident_indices) > 0:
            acc = accuracy_score(unlabel_y[old_confident_indices],pseudo_labels[old_confident_indices])        
            logger.info('old confident data prop: {:4f}, acc: {:4f}'.format(len(old_confident_indices)/len(unlabel_y), acc))

        acc = accuracy_score(unlabel_y[new_confident_indices],pseudo_labels[new_confident_indices])
        logger.info('new confident data prop: {:4f}, acc: {:4f}'.format(len(new_confident_indices)/len(unlabel_y), acc))'''

        #unlabelset.train_labels_ul = pseudo_labels.copy()
        confident_dataset = torch.utils.data.Subset(unlabelset,
                                                    confident_indices)

        #phase4: refine model1 by confident data and reward data
        #train_dataset = torch.utils.data.ConcatDataset([confident_dataset,labelset])

        logger = logging.getLogger('model1_round_{}'.format(i_round))
        file_handler = logging.FileHandler(
            osp.join(args.dir, 'model1_round_{}.txt'.format(i_round)))
        logger.addHandler(file_handler)
        logger.setLevel(logging.INFO)

        best_val_epoch = 0
        evaluate(validloader, ema_model, logger, 'valid')
        evaluate(testloader, ema_model, logger, 'test')

        optimizer1 = AdamW(model1.parameters(), lr=args.lr)

        confident_dataset = torch.utils.data.Subset(unlabelset,
                                                    confident_indices)
        trainloader = torch.utils.data.DataLoader(confident_dataset,
                                                  batch_size=args.batch_size,
                                                  num_workers=args.num_workers,
                                                  shuffle=True,
                                                  drop_last=True)

        #steps_per_epoch = len(iter(trainloader))
        steps_per_epoch = 200
        max_epoch1 = args.steps1 // steps_per_epoch

        for epoch in range(max_epoch1):
            '''current_num = int(cal_consistency_weight( (epoch + 1) * steps_per_epoch, init_ep=0, end_ep=args.stop_steps1//2, init_w=start_num, end_w=end_num))            
            current_confident_indices = confident_indices[:current_num]
            logger.info('current num: {}'.format(current_num))'''
            if args.mix:
                train_model1_mix(label_loader,
                                 trainloader,
                                 model1,
                                 optimizer1,
                                 ema_model,
                                 ema_optimizer,
                                 steps_per_epoch,
                                 epoch,
                                 logger=logger)
            else:
                train_model1(label_loader,
                             trainloader,
                             model1,
                             optimizer1,
                             ema_model,
                             ema_optimizer,
                             steps_per_epoch,
                             epoch,
                             logger=logger)

            val_loss, val_acc = evaluate(validloader, ema_model, logger,
                                         'valid')
            if val_acc >= best_acc:
                best_acc = val_acc
                best_val_epoch = epoch
                evaluate(testloader, ema_model, logger, 'test')
                torch.save(model1.state_dict(), best_model1_path)
                torch.save(ema_model.state_dict(), best_ema_path)

            if (epoch - best_val_epoch) * steps_per_epoch > args.stop_steps1:
                break

        ema_model.load_state_dict(torch.load(best_ema_path))
        model1.load_state_dict(torch.load(best_model1_path))

        logger.info('model1 train finished')
        evaluate(validloader, ema_model, logger, 'valid')
        evaluate(testloader, ema_model, logger, 'test')
        '''no_help_indices = np.concatenate((no_help_indices,confident_indices[current_num:]))
        confident_indices = confident_indices[:current_num]'''

        if len(confident_indices) >= len(all_indices):
            break
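WeightEMA and the train_* helpers are defined elsewhere in this project. As used above (the EMA model's parameters are detached and then tracked from model1), WeightEMA is an exponential-moving-average weight updater in the mean-teacher style; a minimal sketch of that idea under a hypothetical name, not the project's exact class:

import torch

class EMAUpdater:
    def __init__(self, model, ema_model, alpha=0.999):
        self.alpha = alpha
        self.params = list(model.parameters())
        self.ema_params = list(ema_model.parameters())

    @torch.no_grad()
    def step(self):
        # ema = alpha * ema + (1 - alpha) * current
        for p, ema_p in zip(self.params, self.ema_params):
            ema_p.mul_(self.alpha).add_(p, alpha=1.0 - self.alpha)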
Example #17
    log_softmax_outputs = F.log_softmax(outputs/3.0, dim=1)
    softmax_targets = F.softmax(targets/3.0, dim=1)
    return -(log_softmax_outputs * softmax_targets).sum(dim=1).mean()
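# The function above is cut off at its def line; a self-contained sketch of the same
# temperature-softened cross-entropy (temperature 3.0 in the original, F being
# torch.nn.functional as in the lines above) would be:
def soft_cross_entropy(outputs, targets, temperature=3.0):
    log_p = F.log_softmax(outputs / temperature, dim=1)
    q = F.softmax(targets / temperature, dim=1)
    return -(log_p * q).sum(dim=1).mean()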

def clip_grads(params):
    params = list(
        filter(lambda p: p.requires_grad and p.grad is not None, params))
    if len(params) > 0:
        return torch.nn.utils.clip_grad_norm_(params, max_norm=args.clip_grad, norm_type=2)


BATCH_SIZE = 128
LR = 0.1

transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4, fill=128),
                         transforms.RandomHorizontalFlip(), CIFAR10Policy(),
                         transforms.ToTensor(), Cutout(n_holes=1, length=16),
                         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset, testset = None, None
if args.class_num == 100:
    print("dataset: CIFAR100")
    trainset = torchvision.datasets.CIFAR100(
        root='/home/lthpc/datasets/data',
        train=True,
        download=False,
Example #18
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    print("=> creating model '{}'".format(args.arch))
    num_classes = 100 if args.dataset == 'cifar100' else 10
    use_norm = True if args.loss_type == 'LDAM' else False
    model = models.__dict__[args.arch](num_classes=num_classes,
                                       use_norm=use_norm)

    # create two optimizers - one for feature extractor and one for classifier
    feat_params = []
    feat_params_names = []
    cls_params = []
    cls_params_names = []
    learnable_epsilons = torch.nn.Parameter(torch.ones(num_classes))
    for name, params in model.named_parameters():
        if params.requires_grad:
            if "linear" in name:
                cls_params_names += [name]
                cls_params += [params]
            else:
                feat_params_names += [name]
                feat_params += [params]
    print("Create Feat Optimizer")
    print(f"\tRequires Grad:{feat_params_names}")
    feat_optim = torch.optim.SGD(feat_params + [learnable_epsilons],
                                 args.feat_lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
    print("Create Feat Optimizer")
    print(f"\tRequires Grad:{cls_params_names}")
    cls_optim = torch.optim.SGD(cls_params,
                                args.cls_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume or args.evaluation:
        curr_store_name = args.store_name
        if not args.evaluation and args.pretrained:
            curr_store_name = os.path.join(curr_store_name, os.path.pardir)
        filename = '%s/%s/ckpt.best.pth.tar' % (args.root_model,
                                                curr_store_name)
        if os.path.isfile(filename):
            print("=> loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location=f"cuda:{args.gpu}")
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                filename, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(filename))

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()

    cudnn.benchmark = True
    # Data loading code
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(),
        transforms.ToTensor(),
        Cutout(
            n_holes=1, length=16
        ),  # (https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py)
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    if args.dataset == 'cifar10':
        original_train_dataset = IMBALANCECIFAR10(root='./data',
                                                  imb_type=args.imb_type,
                                                  imb_factor=args.imb_factor,
                                                  rand_number=args.rand_number,
                                                  train=True,
                                                  download=True,
                                                  transform=transform_val)
        augmented_train_dataset = IMBALANCECIFAR10(
            root='./data',
            imb_type=args.imb_type,
            imb_factor=args.imb_factor,
            rand_number=args.rand_number,
            train=True,
            download=True,
            transform=transform_train
            if not args.evaluation else transform_val)
        val_dataset = datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform_val)
    elif args.dataset == 'cifar100':
        original_train_dataset = IMBALANCECIFAR100(
            root='./data',
            imb_type=args.imb_type,
            imb_factor=args.imb_factor,
            rand_number=args.rand_number,
            train=True,
            download=True,
            transform=transform_val)
        augmented_train_dataset = IMBALANCECIFAR100(
            root='./data',
            imb_type=args.imb_type,
            imb_factor=args.imb_factor,
            rand_number=args.rand_number,
            train=True,
            download=True,
            transform=transform_train
            if not args.evaluation else transform_val)
        val_dataset = datasets.CIFAR100(root='./data',
                                        train=False,
                                        download=True,
                                        transform=transform_val)
    else:
        warnings.warn('Dataset is not listed')
        return

    cls_num_list = augmented_train_dataset.get_cls_num_list()
    args.cls_num_list = cls_num_list

    train_labels = np.array(augmented_train_dataset.get_targets()).astype(int)
    # calculate balanced weights
    balanced_weights = torch.tensor(class_weight.compute_class_weight(
        'balanced', classes=np.unique(train_labels), y=train_labels),  # classes/y are keyword-only in scikit-learn >= 0.24
                                    dtype=torch.float).cuda(args.gpu)
    lt_weights = torch.tensor(cls_num_list).float() / max(cls_num_list)

    def create_sampler(args_str):
        if args_str is not None and "resample" in args_str:
            sampler_type, n_resample = args_str.split(",")
            return ClassAwareSampler(train_labels,
                                     num_samples_cls=int(n_resample))
        return None

    original_train_loader = torch.utils.data.DataLoader(
        original_train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # feature extractor dataloader
    feat_sampler = create_sampler(args.feat_sampler)
    feat_train_loader = torch.utils.data.DataLoader(
        augmented_train_dataset,
        batch_size=args.batch_size,
        shuffle=(feat_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=feat_sampler)

    if args.evaluation:
        # evaluate on validation set
        # calculate centroids on the train
        _, train_features, train_targets, _ = validate(original_train_loader,
                                                       model,
                                                       0,
                                                       args,
                                                       train_labels,
                                                       flag="train",
                                                       save_out=True)
        # validate
        validate(val_loader,
                 model,
                 0,
                 args,
                 train_labels,
                 flag="val",
                 save_out=True,
                 base_features=train_features,
                 base_targets=train_targets)
        quit()

    # create losses
    def create_loss_list(args_str):
        loss_ls = []
        loss_str_ls = args_str.split(",")
        for loss_str in loss_str_ls:
            c_weights = None
            prefix = ""
            if "_bal" in loss_str:
                c_weights = balanced_weights
                prefix = "Balanced "
                loss_str = loss_str.split("_bal")[0]
            if "_lt" in loss_str:
                c_weights = lt_weights
                prefix = "Longtailed "
                loss_str = loss_str.split("_")[0]
            if loss_str == "ce":
                print(f"{prefix}CE", end=",")
                loss_ls += [
                    nn.CrossEntropyLoss(weight=c_weights).cuda(args.gpu)
                ]
            elif loss_str == "robust_loss":
                print(f"{prefix}Robust Loss", end=",")
                loss_ls += [
                    DROLoss(temperature=args.temperature,
                            base_temperature=args.temperature,
                            class_weights=c_weights,
                            epsilons=learnable_epsilons)
                ]
        print()
        return loss_ls

    feat_losses = create_loss_list(args.feat_loss)
    cls_losses = create_loss_list(args.cls_loss)

    # init log for training
    if not args.evaluation:
        log_training = open(
            os.path.join(args.root_log, args.store_name, 'log_train.csv'), 'w')
        log_testing = open(
            os.path.join(args.root_log, args.store_name, 'log_test.csv'), 'w')
        with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
                  'w') as f:
            f.write(str(args))
        tf_writer = None

    best_acc1 = 0
    best_acc_contrastive = 0
    for epoch in range(args.start_epoch, args.epochs):
        print("=============== Extract Train Centroids ===============")
        _, train_features, train_targets, _ = validate(feat_train_loader,
                                                       model,
                                                       epoch,
                                                       args,
                                                       train_labels,
                                                       log_training,
                                                       tf_writer,
                                                       flag="train",
                                                       verbose=True)

        if epoch < args.epochs - args.balanced_clf_nepochs:
            print("=============== Train Feature Extractor ===============")
            freeze_layers(model, fe_bool=True, cls_bool=False)
            train(feat_train_loader, model, feat_losses, epoch, feat_optim,
                  args, train_features, train_targets)

        else:
            if epoch == args.epochs - args.balanced_clf_nepochs:
                print(
                    "================ Loading Best Feature Extractor ================="
                )
                # load best model
                curr_store_name = args.store_name
                filename = '%s/%s/ckpt.best.pth.tar' % (args.root_model,
                                                        curr_store_name)
                checkpoint = torch.load(
                    filename, map_location=f"cuda:{args.gpu}")['state_dict']
                model.load_state_dict(checkpoint)

            print("=============== Train Classifier ===============")
            freeze_layers(model, fe_bool=False, cls_bool=True)
            train(feat_train_loader, model, cls_losses, epoch, cls_optim, args)

        print("=============== Extract Train Centroids ===============")
        _, train_features, train_targets, _ = validate(original_train_loader,
                                                       model,
                                                       epoch,
                                                       args,
                                                       train_labels,
                                                       log_training,
                                                       tf_writer,
                                                       flag="train",
                                                       verbose=False)

        print("=============== Validate ===============")
        acc1, _, _, contrastive_acc = validate(val_loader,
                                               model,
                                               epoch,
                                               args,
                                               train_labels,
                                               log_testing,
                                               tf_writer,
                                               flag="val",
                                               base_features=train_features,
                                               base_targets=train_targets)
        if epoch < args.epochs - args.balanced_clf_nepochs:
            is_best = contrastive_acc > best_acc_contrastive
            best_acc_contrastive = max(contrastive_acc, best_acc_contrastive)
        else:
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

        print(
            f"Best Contrastive Acc: {best_acc_contrastive}, Best Cls Acc: {best_acc1}"
        )
        log_testing.write(
            f"Best Contrastive Acc: {best_acc_contrastive}, Best Cls Acc: {best_acc1}"
        )
        log_testing.flush()
        save_checkpoint(
            args, {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1
            }, is_best)
Example #19
def main():

    global best_prec1
    best_prec1 = 0

    global val_acc
    val_acc = []

    global class_num

    class_num = 10 if args.dataset == 'cifar10' else 100

    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    if args.augment:
        if args.autoaugment:
            print('Autoaugment')
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4),
                                                  mode='reflect').squeeze()),
                transforms.ToPILImage(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                CIFAR10Policy(),
                transforms.ToTensor(),
                Cutout(n_holes=args.n_holes, length=args.length),
                normalize,
            ])

        elif args.cutout:
            print('Cutout')
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4),
                                                  mode='reflect').squeeze()),
                transforms.ToPILImage(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                Cutout(n_holes=args.n_holes, length=args.length),
                normalize,
            ])

        else:
            print('Standard Augmentation!')
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4),
                                                  mode='reflect').squeeze()),
                transforms.ToPILImage(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {'num_workers': 1, 'pin_memory': True}
    assert (args.dataset == 'cifar10' or args.dataset == 'cifar100')
    train_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data',
                                                train=True,
                                                download=True,
                                                transform=transform_train),
        batch_size=training_configurations[args.model]['batch_size'],
        shuffle=True,
        **kwargs)
    val_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data',
                                                train=False,
                                                transform=transform_test),
        batch_size=training_configurations[args.model]['batch_size'],
        shuffle=True,
        **kwargs)

    # create model
    if args.model == 'resnet':
        model = eval('networks.resnet.resnet' + str(args.layers) +
                     '_cifar')(dropout_rate=args.droprate)
    elif args.model == 'se_resnet':
        model = eval('networks.se_resnet.resnet' + str(args.layers) +
                     '_cifar')(dropout_rate=args.droprate)
    elif args.model == 'wideresnet':
        model = networks.wideresnet.WideResNet(args.layers,
                                               args.dataset == 'cifar10' and 10
                                               or 100,
                                               args.widen_factor,
                                               dropRate=args.droprate)
    elif args.model == 'se_wideresnet':
        model = networks.se_wideresnet.WideResNet(
            args.layers,
            args.dataset == 'cifar10' and 10 or 100,
            args.widen_factor,
            dropRate=args.droprate)

    elif args.model == 'densenet_bc':
        model = networks.densenet_bc.DenseNet(
            growth_rate=args.growth_rate,
            block_config=(int((args.layers - 4) / 6), ) * 3,
            compression=args.compression_rate,
            num_init_features=24,
            bn_size=args.bn_size,
            drop_rate=args.droprate,
            small_inputs=True,
            efficient=False)
    elif args.model == 'shake_pyramidnet':
        model = networks.shake_pyramidnet.PyramidNet(dataset=args.dataset,
                                                     depth=args.layers,
                                                     alpha=args.alpha,
                                                     num_classes=class_num,
                                                     bottleneck=True)

    elif args.model == 'resnext':
        if args.cardinality == 8:
            model = networks.resnext.resnext29_8_64(class_num)
        if args.cardinality == 16:
            model = networks.resnext.resnext29_16_64(class_num)

    elif args.model == 'shake_shake':
        if args.widen_factor == 112:
            model = networks.shake_shake.shake_resnet26_2x112d(class_num)
        if args.widen_factor == 32:
            model = networks.shake_shake.shake_resnet26_2x32d(class_num)
        if args.widen_factor == 96:
            model = networks.shake_shake.shake_resnet26_2x32d(class_num)

    elif args.model == 'shake_shake_x':

        model = networks.shake_shake.shake_resnext29_2x4x64d(class_num)

    if not os.path.isdir(check_point):
        mkdir_p(check_point)

    fc = Full_layer(int(model.feature_num), class_num)

    print('Number of final features: {}'.format(int(model.feature_num)))

    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()]) +
        sum([p.data.nelement() for p in fc.parameters()])))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    isda_criterion = ISDALoss(int(model.feature_num), class_num).cuda()
    ce_criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(
        [{
            'params': model.parameters()
        }, {
            'params': fc.parameters()
        }],
        lr=training_configurations[args.model]['initial_learning_rate'],
        momentum=training_configurations[args.model]['momentum'],
        nesterov=training_configurations[args.model]['nesterov'],
        weight_decay=training_configurations[args.model]['weight_decay'])

    model = torch.nn.DataParallel(model).cuda()
    fc = nn.DataParallel(fc).cuda()

    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint directory found!'
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        fc.load_state_dict(checkpoint['fc'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        isda_criterion = checkpoint['isda_criterion']
        val_acc = checkpoint['val_acc']
        best_prec1 = checkpoint['best_acc']
        np.savetxt(accuracy_file, np.array(val_acc))
    else:
        start_epoch = 0

    for epoch in range(start_epoch,
                       training_configurations[args.model]['epochs']):

        adjust_learning_rate(optimizer, epoch + 1)

        # train for one epoch
        train(train_loader, model, fc, isda_criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, fc, ce_criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'fc': fc.state_dict(),
                'best_acc': best_prec1,
                'optimizer': optimizer.state_dict(),
                'isda_criterion': isda_criterion,
                'val_acc': val_acc,
            },
            is_best,
            checkpoint=check_point)
        print('Best accuracy: ', best_prec1)
        np.savetxt(accuracy_file, np.array(val_acc))

    print('Best accuracy: ', best_prec1)
    print('Average accuracy', sum(val_acc[len(val_acc) - 10:]) / 10)
    # val_acc.append(sum(val_acc[len(val_acc) - 10:]) / 10)
    # np.savetxt(val_acc, np.array(val_acc))
    np.savetxt(accuracy_file, np.array(val_acc))
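The ToTensor -> F.pad(..., mode='reflect') -> ToPILImage -> RandomCrop(32) sequence above reproduces reflect-padded random cropping; current torchvision can express the same step directly, as Example #9 already does:

from torchvision import transforms

reflect_crop = transforms.RandomCrop(32, padding=4, padding_mode='reflect')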
Example #20
def get_iters(dataset='CIFAR10',
              root_path='.',
              data_transforms=None,
              n_labeled=4000,
              valid_size=1000,
              l_batch_size=32,
              ul_batch_size=128,
              test_batch_size=256,
              workers=8,
              pseudo_label=None):

    train_path = '{}/data/{}/train/'.format(root_path, dataset)
    test_path = '{}/data/{}/test/'.format(root_path, dataset)

    if dataset == 'CIFAR10':
        train_dataset = datasets.CIFAR10(train_path,
                                         download=True,
                                         train=True,
                                         transform=None)
        test_dataset = datasets.CIFAR10(test_path,
                                        download=True,
                                        train=False,
                                        transform=None)
    elif dataset == 'CIFAR100':
        train_dataset = datasets.CIFAR100(train_path,
                                          download=True,
                                          train=True,
                                          transform=None)
        test_dataset = datasets.CIFAR100(test_path,
                                         download=True,
                                         train=False,
                                         transform=None)
    else:
        raise ValueError

    if data_transforms is None:
        data_transforms = {
            'train':
            transforms.Compose([
                transforms.ToPILImage(),
                transforms.RandomHorizontalFlip(),
                CIFAR10Policy(),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]),
            'eval':
            transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]),
        }

    # .data / .targets replace the old .train_data / .train_labels / .test_data /
    # .test_labels attributes, which newer torchvision releases no longer expose
    x_train, y_train = train_dataset.data, np.array(train_dataset.targets)
    x_test, y_test = test_dataset.data, np.array(test_dataset.targets)

    randperm = np.random.permutation(len(x_train))
    labeled_idx = randperm[:n_labeled]
    validation_idx = randperm[n_labeled:n_labeled + valid_size]
    unlabeled_idx = randperm[n_labeled + valid_size:]

    x_labeled = x_train[labeled_idx]
    x_validation = x_train[validation_idx]
    x_unlabeled = x_train[unlabeled_idx]

    y_labeled = y_train[labeled_idx]
    y_validation = y_train[validation_idx]
    if pseudo_label is None:
        y_unlabeled = y_train[unlabeled_idx]
    else:
        assert isinstance(pseudo_label, np.ndarray)
        y_unlabeled = pseudo_label

    data_iterators = {
        'labeled':
        iter(
            DataLoader(
                SimpleDataset(x_labeled, y_labeled, data_transforms['train']),
                batch_size=l_batch_size,
                num_workers=workers,
                sampler=InfiniteSampler(len(x_labeled)),
            )),
        'unlabeled':
        iter(
            DataLoader(
                MultiDataset(x_unlabeled, y_unlabeled, data_transforms['eval'],
                             data_transforms['train']),
                batch_size=ul_batch_size,
                num_workers=workers,
                sampler=InfiniteSampler(len(x_unlabeled)),
            )),
        'val':
        iter(
            DataLoader(SimpleDataset(x_validation, y_validation,
                                     data_transforms['eval']),
                       batch_size=len(x_validation),
                       num_workers=workers,
                       shuffle=False)),
        'test':
        iter(
            DataLoader(SimpleDataset(x_test, y_test, data_transforms['eval']),
                       batch_size=test_batch_size,
                       num_workers=workers,
                       shuffle=False))
    }

    return data_iterators
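A hypothetical way to consume the returned iterators in a training loop, assuming SimpleDataset yields (image, label) batches; the 'labeled' and 'unlabeled' streams never exhaust because of InfiniteSampler:

iters = get_iters(dataset='CIFAR10', root_path='.', n_labeled=4000, valid_size=1000)
for step in range(1000):
    x_l, y_l = next(iters['labeled'])  # endless labeled stream
    # ... forward/backward pass on the labeled batch ...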