Example #1
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = IMBALANCECIFAR10(root='./data',
                                     imb_type=imb_type,
                                     imb_factor=imb_factor,
                                     rand_number=rand_number,
                                     train=True,
                                     download=False,
                                     transform=transform_train)
    val_dataset = datasets.CIFAR10(root='./data',
                                   train=False,
                                   download=False,
                                   transform=transform_val)

    # data loader
    num_workers = 4
    batch_size = 128
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=(train_sampler is None),
Example #2
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    print("=> creating model '{}'".format(args.arch))
    num_classes = 100 if args.dataset == 'cifar100' else 10
    use_norm = True if args.loss_type == 'LDAM' else False

    # DataParallel will divide and allocate batch_size to all available GPUs
    model = models.__dict__[args.arch](
        num_classes=num_classes,
        use_norm=use_norm,
        head_tail_ratio=args.head_tail_ratio,
        transfer_strength=args.transfer_strength,
        phase_train=True,
        epoch_thresh=args.epoch_thresh)

    model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    # Data loading code

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    if args.dataset == 'cifar10':
        train_dataset = IMBALANCECIFAR10(root='./data',
                                         imb_type=args.imb_type,
                                         imb_factor=args.imb_factor,
                                         rand_number=args.rand_number,
                                         train=True,
                                         download=True,
                                         transform=transform_train)
        val_dataset = datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform_val)
    elif args.dataset == 'cifar100':
        train_dataset = IMBALANCECIFAR100(root='./data',
                                          imb_type=args.imb_type,
                                          imb_factor=args.imb_factor,
                                          rand_number=args.rand_number,
                                          train=True,
                                          download=True,
                                          transform=transform_train)
        val_dataset = datasets.CIFAR100(root='./data',
                                        train=False,
                                        download=True,
                                        transform=transform_val)
    else:
        warnings.warn('Dataset is not listed')
        return
    cls_num_list = train_dataset.get_cls_num_list()
    print('cls num list:')
    print(cls_num_list)
    args.cls_num_list = cls_num_list

    train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # init log for training
    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log_train.csv'), 'w')
    log_testing = open(
        os.path.join(args.root_log, args.store_name, 'log_test.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        if args.train_rule == 'None':
            train_sampler = None
            per_cls_weights = None
        elif args.train_rule == 'Resample':
            train_sampler = ImbalancedDatasetSampler(train_dataset)
            per_cls_weights = None
        elif args.train_rule == 'Reweight':
            train_sampler = None
            beta = 0.9999
            effective_num = 1.0 - np.power(beta, cls_num_list)
            per_cls_weights = (1.0 - beta) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
        elif args.train_rule == 'DRW':
            train_sampler = None
            idx = min(epoch // 160, 1)  # beta 0 (no re-weighting) before epoch 160, then 0.9999
            betas = [0, 0.9999]
            effective_num = 1.0 - np.power(betas[idx], cls_num_list)
            per_cls_weights = (1.0 - betas[idx]) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
        else:
            warnings.warn('Sample rule is not listed')

        if args.loss_type == 'CE':
            criterion = nn.CrossEntropyLoss(weight=per_cls_weights).cuda(
                args.gpu)
        elif args.loss_type == 'LDAM':
            criterion = LDAMLoss(cls_num_list=cls_num_list,
                                 max_m=0.5,
                                 s=30,
                                 weight=per_cls_weights).cuda(args.gpu)
        elif args.loss_type == 'Focal':
            criterion = FocalLoss(weight=per_cls_weights,
                                  gamma=1).cuda(args.gpu)
        else:
            warnings.warn('Loss type is not listed')
            return

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args,
              log_training, tf_writer)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, epoch, args, log_testing,
                        tf_writer)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        tf_writer.add_scalar('acc/test_top1_best', best_acc1, epoch)
        output_best = 'Best Prec@1: %.3f\n' % (best_acc1)
        print(output_best)
        log_testing.write(output_best + '\n')
        log_testing.flush()

        # remove RSG module, since RSG is not used during testing.
        new_state_dict = OrderedDict()
        for k in model.state_dict().keys():
            name = k[7:]  # remove `module.`
            if 'RSG' in k:
                continue
            new_state_dict[name] = model.state_dict()[k]

        save_checkpoint(
            args, {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': new_state_dict,
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
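Note: the 'Reweight' and 'DRW' branches above both use the class-balanced "effective number" weighting (Cui et al., 2019); DRW simply keeps beta at 0 (uniform weights) for the first 160 epochs and switches to 0.9999 afterwards. A minimal, self-contained sketch of that computation, assuming cls_num_list holds the per-class training-sample counts as in the examples:

import numpy as np
import torch


def effective_number_weights(cls_num_list, beta=0.9999):
    # effective number of samples per class: E_n = (1 - beta^n) / (1 - beta)
    effective_num = 1.0 - np.power(beta, cls_num_list)
    per_cls_weights = (1.0 - beta) / np.array(effective_num)
    # rescale so the weights sum to the number of classes
    per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
    return torch.FloatTensor(per_cls_weights)

# e.g. pass the result as the `weight` argument of nn.CrossEntropyLoss or LDAMLoss,
# exactly as per_cls_weights is used in the training loop above.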
Example #3
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    print("=> creating model '{}'".format(args.arch))
    num_classes = 100 if args.dataset == 'cifar100' else 10
    use_norm = True if args.loss_type == 'LDAM' else False
    if args.dataset == 'cifar10':
        args.train_rule = 'DRW'
    else:
        args.train_rule = 'Reweight'
    teacher_model = teacher_models.__dict__[args.arch](num_classes=num_classes,
                                                       use_norm=use_norm)
    student_model = student_models.__dict__[args.arch](num_classes=num_classes,
                                                       use_norm=use_norm)
    teacher_model = load_network(teacher_model, args)

    args.num_classes = num_classes

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        teacher_model = teacher_model.to(args.gpu)
        student_model = student_model.to(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        teacher_model = torch.nn.DataParallel(teacher_model).cuda()
        student_model = torch.nn.DataParallel(student_model).cuda()

    optimizer = torch.optim.SGD(student_model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cuda:0')
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            student_model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    if args.dataset == 'cifar10':
        train_dataset = IMBALANCECIFAR10(root='./data',
                                         imb_type=args.imb_type,
                                         imb_factor=args.imb_factor,
                                         rand_number=args.rand_number,
                                         train=True,
                                         download=True,
                                         transform=transform_train)
        val_dataset = datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform_val)
    elif args.dataset == 'cifar100':
        train_dataset = IMBALANCECIFAR100(root='./data',
                                          imb_type=args.imb_type,
                                          imb_factor=args.imb_factor,
                                          rand_number=args.rand_number,
                                          train=True,
                                          download=True,
                                          transform=transform_train)
        val_dataset = datasets.CIFAR100(root='./data',
                                        train=False,
                                        download=True,
                                        transform=transform_val)
    else:
        warnings.warn('Dataset is not listed')
        return
    cls_num_list = train_dataset.get_cls_num_list()
    print('cls num list:')
    print(cls_num_list)
    args.cls_num_list = cls_num_list

    train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # init log for training
    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log_train.csv'), 'w')
    log_testing = open(
        os.path.join(args.root_log, args.store_name, 'log_test.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        if args.train_rule == 'None':
            train_sampler = None
            per_cls_weights = None
        elif args.train_rule == 'Resample':
            train_sampler = ImbalancedDatasetSampler(train_dataset)
            per_cls_weights = None
        elif args.train_rule == 'Reweight':
            train_sampler = None
            beta = 0.9999
            effective_num = 1.0 - np.power(beta, cls_num_list)
            per_cls_weights = (1.0 - beta) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
        elif args.train_rule == 'DRW':
            train_sampler = None
            idx = min(epoch // 160, 1)  # beta 0 (no re-weighting) before epoch 160, then 0.9999
            betas = [0, 0.9999]
            effective_num = 1.0 - np.power(betas[idx], cls_num_list)
            per_cls_weights = (1.0 - betas[idx]) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
        else:
            warnings.warn('Sample rule is not listed')

        if args.loss_type == 'CE':
            criterion = nn.CrossEntropyLoss(weight=per_cls_weights).cuda(
                args.gpu)
        elif args.loss_type == 'KD':
            criterion = KDLoss(cls_num_list=cls_num_list,
                               T=args.T,
                               weight=per_cls_weights).cuda(args.gpu)
        elif args.loss_type == 'BKD':
            criterion = BKDLoss(cls_num_list=cls_num_list,
                                T=args.T,
                                weight=per_cls_weights).cuda(args.gpu)
        elif args.loss_type == 'LDAM':
            criterion = LDAMLoss(cls_num_list=cls_num_list,
                                 max_m=0.5,
                                 s=30,
                                 weight=per_cls_weights).cuda(args.gpu)
        elif args.loss_type == 'Focal':
            criterion = FocalLoss(weight=per_cls_weights,
                                  gamma=1).cuda(args.gpu)
        else:
            warnings.warn('Loss type is not listed')
            return

        # train for one epoch
        train(train_loader, teacher_model, student_model, criterion, optimizer,
              epoch, args, log_training, tf_writer)

        # evaluate on validation set
        acc1 = validate(val_loader, teacher_model, student_model, criterion,
                        epoch, args, log_testing, tf_writer)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        tf_writer.add_scalar('acc/test_top1_best', best_acc1, epoch)
        output_best = 'Best Prec@1: %.3f\n' % (best_acc1)
        print(output_best)
        log_testing.write(output_best + '\n')
        log_testing.flush()

        save_checkpoint(
            args, {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': student_model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
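This variant distils a fixed teacher model into a student model. The project's KDLoss and BKDLoss take cls_num_list, a temperature T, and per-class weights; their exact definitions are not shown here. As a rough, hypothetical sketch of the standard temperature-scaled distillation term such a loss is built around (the class name and the alpha mixing weight are illustrative, not the repo's API):

import torch.nn as nn
import torch.nn.functional as F


class PlainKDLoss(nn.Module):
    """Hinton-style KD: cross-entropy on hard labels plus temperature-scaled KL to the teacher."""

    def __init__(self, T=2.0, alpha=0.5, weight=None):
        super().__init__()
        self.T, self.alpha, self.weight = T, alpha, weight

    def forward(self, student_logits, teacher_logits, target):
        ce = F.cross_entropy(student_logits, target, weight=self.weight)
        kd = F.kl_div(F.log_softmax(student_logits / self.T, dim=1),
                      F.softmax(teacher_logits / self.T, dim=1),
                      reduction='batchmean') * (self.T ** 2)
        return (1.0 - self.alpha) * ce + self.alpha * kd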
Example #4
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    print("=> creating model '{}'".format(args.arch))
    num_classes = 100 if args.dataset == 'cifar100' else 10
    use_norm = True if args.loss_type == 'LDAM' else False
    model = models.__dict__[args.arch](num_classes=num_classes,
                                       use_norm=use_norm)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)  # move the model to the selected GPU
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()  # DataParallel splits each batch across all available GPUs

    # optimizer: SGD with momentum (default 0.9) and weight decay (default 1e-4)
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):  # check that the checkpoint path is an existing file
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location='cuda:0')  # load a saved checkpoint object
            args.start_epoch = checkpoint['epoch']  # default: 200
            best_acc1 = checkpoint['best_acc1']  # best_acc1 = 0
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(
                checkpoint['state_dict'])  # load_state_dict restores only the model parameters
            optimizer.load_state_dict(checkpoint['optimizer'])  # restore the optimizer state
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True  # let cuDNN's auto-tuner pick the fastest algorithms for the current input configuration
    # Data loading code
    # data preprocessing and loading
    transform_train = transforms.Compose([  # chain several transforms together
        transforms.RandomCrop(32, padding=4),  # random 32x32 crop after 4-pixel padding
        transforms.RandomHorizontalFlip(),  # random horizontal flip with probability 0.5
        transforms.ToTensor(),
        # converts a [0, 255] PIL Image or (H, W, C) numpy.ndarray into a (C, H, W) FloatTensor in [0.0, 1.0]
        transforms.Normalize(
            (0.4914, 0.4822, 0.4465),
            (0.2023, 0.1994, 0.2010)),  # normalize: (x - mean) / std
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),  # convert to a FloatTensor in [0.0, 1.0]
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    if args.dataset == 'cifar10':
        train_dataset = IMBALANCECIFAR10(root='./data',
                                         imb_type=args.imb_type,
                                         imb_factor=args.imb_factor,
                                         rand_number=args.rand_number,
                                         train=True,
                                         download=True,
                                         transform=transform_train)
        val_dataset = datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform_val)
    elif args.dataset == 'cifar100':
        train_dataset = IMBALANCECIFAR100(root='./data',
                                          imb_type=args.imb_type,
                                          imb_factor=args.imb_factor,
                                          rand_number=args.rand_number,
                                          train=True,
                                          download=True,
                                          transform=transform_train)
        val_dataset = datasets.CIFAR100(root='./data',
                                        train=False,
                                        download=True,
                                        transform=transform_val)
    else:
        warnings.warn('Dataset is not listed')
        return
    cls_num_list = train_dataset.get_cls_num_list()
    print('cls num list:')
    print(cls_num_list)
    args.cls_num_list = cls_num_list

    train_sampler = None
    # dataloader
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # init log for training
    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log_train.csv'),
        'w')  # os.path.join joins the directory and file name into one path
    log_testing = open(
        os.path.join(args.root_log, args.store_name, 'log_test.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    # training loop
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch,
                             args)  # adjust_learning_rate is a user-defined LR schedule

        if args.train_rule == 'None':
            train_sampler = None
            per_cls_weights = None
        elif args.train_rule == 'Resample':
            train_sampler = ImbalancedDatasetSampler(train_dataset)
            per_cls_weights = None
        elif args.train_rule == 'Reweight':
            train_sampler = None
            beta = 0.9999
            effective_num = 1.0 - np.power(beta, cls_num_list)
            per_cls_weights = (1.0 - beta) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
        elif args.train_rule == 'DRW':
            train_sampler = None
            idx = min(epoch // 160, 1)  # beta 0 (no re-weighting) before epoch 160, then 0.9999
            betas = [0, 0.9999]
            effective_num = 1.0 - np.power(betas[idx], cls_num_list)
            per_cls_weights = (1.0 - betas[idx]) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
        else:
            warnings.warn('Sample rule is not listed')

        if args.loss_type == 'CE':
            criterion = nn.CrossEntropyLoss(weight=per_cls_weights).cuda(
                args.gpu)
        elif args.loss_type == 'LDAM':
            criterion = LDAMLoss(cls_num_list=cls_num_list,
                                 max_m=0.5,
                                 s=30,
                                 weight=per_cls_weights).cuda(args.gpu)
        elif args.loss_type == 'Focal':
            criterion = FocalLoss(weight=per_cls_weights,
                                  gamma=1).cuda(args.gpu)
        else:
            warnings.warn('Loss type is not listed')
            return

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args,
              log_training, tf_writer)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, epoch, args, log_testing,
                        tf_writer)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        tf_writer.add_scalar(
            'acc/test_top1_best', best_acc1, epoch
        )  # add_scalar logs the scalar to the TensorBoard event file for visualization
        output_best = 'Best Prec@1: %.3f\n' % (best_acc1)
        print(output_best)
        log_testing.write(output_best + '\n')
        log_testing.flush()

        save_checkpoint(
            args, {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)  # save checkpoint
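LDAMLoss above is imported from the project's losses module. For reference, a minimal sketch of the label-distribution-aware margin loss it implements (following the LDAM-DRW paper): each class j gets a margin proportional to n_j^(-1/4), which is subtracted from the ground-truth logit before a scaled cross-entropy. The class below is an illustrative re-implementation under those assumptions, not the repo's code:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class LDAMLossSketch(nn.Module):
    def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30):
        super().__init__()
        # per-class margins m_j = C / n_j^(1/4), scaled so the largest margin equals max_m
        m_list = 1.0 / np.sqrt(np.sqrt(np.asarray(cls_num_list, dtype=np.float64)))
        m_list = m_list * (max_m / np.max(m_list))
        self.register_buffer('m_list', torch.tensor(m_list, dtype=torch.float32))
        self.s = s
        self.weight = weight

    def forward(self, logits, target):
        batch_m = self.m_list[target].unsqueeze(1)               # margin of each sample's true class
        index = F.one_hot(target, logits.size(1)).bool()
        logits_m = torch.where(index, logits - batch_m, logits)  # subtract margin from the true-class logit only
        return F.cross_entropy(self.s * logits_m, target, weight=self.weight)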
Example #5
File: main.py  Project: dvirsamuel/DRO-LT
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    print("=> creating model '{}'".format(args.arch))
    num_classes = 100 if args.dataset == 'cifar100' else 10
    use_norm = True if args.loss_type == 'LDAM' else False
    model = models.__dict__[args.arch](num_classes=num_classes,
                                       use_norm=use_norm)

    # create two optimizers - one for feature extractor and one for classifier
    feat_params = []
    feat_params_names = []
    cls_params = []
    cls_params_names = []
    learnable_epsilons = torch.nn.Parameter(torch.ones(num_classes))
    for name, params in model.named_parameters():
        if params.requires_grad:
            if "linear" in name:
                cls_params_names += [name]
                cls_params += [params]
            else:
                feat_params_names += [name]
                feat_params += [params]
    print("Create Feat Optimizer")
    print(f"\tRequires Grad:{feat_params_names}")
    feat_optim = torch.optim.SGD(feat_params + [learnable_epsilons],
                                 args.feat_lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
    print("Create Feat Optimizer")
    print(f"\tRequires Grad:{cls_params_names}")
    cls_optim = torch.optim.SGD(cls_params,
                                args.cls_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume or args.evaluation:
        curr_store_name = args.store_name
        if not args.evaluation and args.pretrained:
            curr_store_name = os.path.join(curr_store_name, os.path.pardir)
        filename = '%s/%s/ckpt.best.pth.tar' % (args.root_model,
                                                curr_store_name)
        if os.path.isfile(filename):
            print("=> loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location=f"cuda:{args.gpu}")
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                filename, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(filename))

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()

    cudnn.benchmark = True
    # Data loading code
    transform_train = transforms.Compose([
        transforms.RandomCrop(
            32, padding=4
        ),  # fill parameter needs torchvision installed from source
        transforms.RandomHorizontalFlip(),
        CIFAR10Policy(),
        transforms.ToTensor(),
        Cutout(
            n_holes=1, length=16
        ),  # (https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py)
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    if args.dataset == 'cifar10':
        original_train_dataset = IMBALANCECIFAR10(root='./data',
                                                  imb_type=args.imb_type,
                                                  imb_factor=args.imb_factor,
                                                  rand_number=args.rand_number,
                                                  train=True,
                                                  download=True,
                                                  transform=transform_val)
        augmented_train_dataset = IMBALANCECIFAR10(
            root='./data',
            imb_type=args.imb_type,
            imb_factor=args.imb_factor,
            rand_number=args.rand_number,
            train=True,
            download=True,
            transform=transform_train
            if not args.evaluation else transform_val)
        val_dataset = datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform_val)
    elif args.dataset == 'cifar100':
        original_train_dataset = IMBALANCECIFAR100(
            root='./data',
            imb_type=args.imb_type,
            imb_factor=args.imb_factor,
            rand_number=args.rand_number,
            train=True,
            download=True,
            transform=transform_val)
        augmented_train_dataset = IMBALANCECIFAR100(
            root='./data',
            imb_type=args.imb_type,
            imb_factor=args.imb_factor,
            rand_number=args.rand_number,
            train=True,
            download=True,
            transform=transform_train
            if not args.evaluation else transform_val)
        val_dataset = datasets.CIFAR100(root='./data',
                                        train=False,
                                        download=True,
                                        transform=transform_val)
    else:
        warnings.warn('Dataset is not listed')
        return

    cls_num_list = augmented_train_dataset.get_cls_num_list()
    args.cls_num_list = cls_num_list

    train_labels = np.array(augmented_train_dataset.get_targets()).astype(int)
    # calculate balanced weights
    balanced_weights = torch.tensor(class_weight.compute_class_weight(
        'balanced', classes=np.unique(train_labels), y=train_labels),
                                    dtype=torch.float).cuda(args.gpu)
    lt_weights = torch.tensor(cls_num_list).float() / max(cls_num_list)

    def create_sampler(args_str):
        if args_str is not None and "resample" in args_str:
            sampler_type, n_resample = args_str.split(",")
            return ClassAwareSampler(train_labels,
                                     num_samples_cls=int(n_resample))
        return None

    original_train_loader = torch.utils.data.DataLoader(
        original_train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # feature extractor dataloader
    feat_sampler = create_sampler(args.feat_sampler)
    feat_train_loader = torch.utils.data.DataLoader(
        augmented_train_dataset,
        batch_size=args.batch_size,
        shuffle=(feat_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=feat_sampler)

    if args.evaluation:
        # evaluate on validation set
        # calculate centroids on the train
        _, train_features, train_targets, _ = validate(original_train_loader,
                                                       model,
                                                       0,
                                                       args,
                                                       train_labels,
                                                       flag="train",
                                                       save_out=True)
        # validate
        validate(val_loader,
                 model,
                 0,
                 args,
                 train_labels,
                 flag="val",
                 save_out=True,
                 base_features=train_features,
                 base_targets=train_targets)
        quit()

    # create losses
    def create_loss_list(args_str):
        loss_ls = []
        loss_str_ls = args_str.split(",")
        for loss_str in loss_str_ls:
            c_weights = None
            prefix = ""
            if "_bal" in loss_str:
                c_weights = balanced_weights
                prefix = "Balanced "
                loss_str = loss_str.split("_bal")[0]
            if "_lt" in loss_str:
                c_weights = lt_weights
                prefix = "Longtailed "
                loss_str = loss_str.split("_")[0]
            if loss_str == "ce":
                print(f"{prefix}CE", end=",")
                loss_ls += [
                    nn.CrossEntropyLoss(weight=c_weights).cuda(args.gpu)
                ]
            elif loss_str == "robust_loss":
                print(f"{prefix}Robust Loss", end=",")
                loss_ls += [
                    DROLoss(temperature=args.temperature,
                            base_temperature=args.temperature,
                            class_weights=c_weights,
                            epsilons=learnable_epsilons)
                ]
        print()
        return loss_ls

    feat_losses = create_loss_list(args.feat_loss)
    cls_losses = create_loss_list(args.cls_loss)

    # init log for training
    if not args.evaluation:
        log_training = open(
            os.path.join(args.root_log, args.store_name, 'log_train.csv'), 'w')
        log_testing = open(
            os.path.join(args.root_log, args.store_name, 'log_test.csv'), 'w')
        with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
                  'w') as f:
            f.write(str(args))
        tf_writer = None

    best_acc1 = 0
    best_acc_contrastive = 0
    for epoch in range(args.start_epoch, args.epochs):
        print("=============== Extract Train Centroids ===============")
        _, train_features, train_targets, _ = validate(feat_train_loader,
                                                       model,
                                                       epoch,
                                                       args,
                                                       train_labels,
                                                       log_training,
                                                       tf_writer,
                                                       flag="train",
                                                       verbose=True)

        if epoch < args.epochs - args.balanced_clf_nepochs:
            print("=============== Train Feature Extractor ===============")
            freeze_layers(model, fe_bool=True, cls_bool=False)
            train(feat_train_loader, model, feat_losses, epoch, feat_optim,
                  args, train_features, train_targets)

        else:
            if epoch == args.epochs - args.balanced_clf_nepochs:
                print(
                    "================ Loading Best Feature Extractor ================="
                )
                # load best model
                curr_store_name = args.store_name
                filename = '%s/%s/ckpt.best.pth.tar' % (args.root_model,
                                                        curr_store_name)
                checkpoint = torch.load(
                    filename, map_location=f"cuda:{args.gpu}")['state_dict']
                model.load_state_dict(checkpoint)

            print("=============== Train Classifier ===============")
            freeze_layers(model, fe_bool=False, cls_bool=True)
            train(feat_train_loader, model, cls_losses, epoch, cls_optim, args)

        print("=============== Extract Train Centroids ===============")
        _, train_features, train_targets, _ = validate(original_train_loader,
                                                       model,
                                                       epoch,
                                                       args,
                                                       train_labels,
                                                       log_training,
                                                       tf_writer,
                                                       flag="train",
                                                       verbose=False)

        print("=============== Validate ===============")
        acc1, _, _, contrastive_acc = validate(val_loader,
                                               model,
                                               epoch,
                                               args,
                                               train_labels,
                                               log_testing,
                                               tf_writer,
                                               flag="val",
                                               base_features=train_features,
                                               base_targets=train_targets)
        if epoch < args.epochs - args.balanced_clf_nepochs:
            is_best = contrastive_acc > best_acc_contrastive
            best_acc_contrastive = max(contrastive_acc, best_acc_contrastive)
        else:
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

        print(
            f"Best Contrastive Acc: {best_acc_contrastive}, Best Cls Acc: {best_acc1}"
        )
        log_testing.write(
            f"Best Contrastive Acc: {best_acc_contrastive}, Best Cls Acc: {best_acc1}"
        )
        log_testing.flush()
        save_checkpoint(
            args, {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1
            }, is_best)
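The loop above alternates between training the feature extractor and the classifier head by toggling gradients with freeze_layers, which is defined elsewhere in the DRO-LT project. A plausible minimal sketch, assuming (as in the optimizer split earlier) that the classifier parameters are those whose names contain 'linear':

def freeze_layers(model, fe_bool, cls_bool):
    # hypothetical helper: fe_bool enables gradients for the feature extractor,
    # cls_bool for the classifier head ("linear" layers); everything else is frozen
    for name, param in model.named_parameters():
        is_classifier = 'linear' in name
        param.requires_grad = cls_bool if is_classifier else fe_bool
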
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    print("=> creating model '{}'".format(args.arch))
    num_classes = 100 if args.dataset == 'cifar100' else 10
    use_norm = True if args.loss_type == 'LDAM' else False
    model = models.__dict__[args.arch](num_classes=num_classes,
                                       use_norm=use_norm)
    non_gpu_model = models.__dict__[args.arch](num_classes=num_classes,
                                               use_norm=use_norm)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cuda:0')
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    if args.dataset == 'cifar10':
        train_dataset = IMBALANCECIFAR10(root='./data',
                                         imb_type=args.imb_type,
                                         imb_factor=args.imb_factor,
                                         rand_number=args.rand_number,
                                         train=True,
                                         download=True,
                                         transform=transform_train)
        val_dataset = datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform_val)
    elif args.dataset == 'cifar100':
        train_dataset = IMBALANCECIFAR100(root='./data',
                                          imb_type=args.imb_type,
                                          imb_factor=args.imb_factor,
                                          rand_number=args.rand_number,
                                          train=True,
                                          download=True,
                                          transform=transform_train)
        val_dataset = datasets.CIFAR100(root='./data',
                                        train=False,
                                        download=True,
                                        transform=transform_val)
    else:
        warnings.warn('Dataset is not listed')
        return
    cls_num_list = train_dataset.get_cls_num_list()
    print('cls num list:')
    print(cls_num_list)
    args.cls_num_list = cls_num_list

    train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # init log for training
    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log_train.csv'), 'w')
    log_testing = open(
        os.path.join(args.root_log, args.store_name, 'log_test.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'),
              'w') as f:
        f.write(str(args))
    tf_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        if args.train_rule == 'None':
            train_sampler = None
            per_cls_weights = None
        elif args.train_rule == 'Resample':
            train_sampler = ImbalancedDatasetSampler(train_dataset)
            per_cls_weights = None
        elif args.train_rule == 'Reweight':
            train_sampler = None
            beta = 0.9999
            effective_num = 1.0 - np.power(beta, cls_num_list)
            per_cls_weights = (1.0 - beta) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
        elif args.train_rule == 'DRW':
            train_sampler = None
            idx = min(epoch // 160, 1)
            betas = [0, 0.9999]
            effective_num = 1.0 - np.power(betas[idx], cls_num_list)
            per_cls_weights = (1.0 - betas[idx]) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(args.gpu)
        elif args.train_rule == 'SWITCHING_DRW_AUTO_CLUSTER':
            train_sampler = None
            cutoff_epoch = args.cutoff_epoch
            idx = min(epoch // cutoff_epoch, 1)
            betas = [0, 0.9999]
            if epoch >= cutoff_epoch + 10 and (epoch - cutoff_epoch) % 20 == 0:
                max_real_ratio_number_of_labels = 20
                # todo: transform data batch by batch, then concatenate...
                temp_batch_size = int(train_dataset.data.shape[0])
                temp_train_loader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=int(temp_batch_size),
                    shuffle=(train_sampler is None),
                    num_workers=args.workers,
                    pin_memory=True,
                    sampler=train_sampler)

                transformed_data = None
                transformed_labels = None
                for i, (xs, labels) in enumerate(train_loader):
                    transformed_batch_data = model.forward_partial(
                        xs.cuda(), num_layers=6)
                    transformed_batch_data = transformed_batch_data.cpu(
                    ).detach()
                    transformed_batch_data = transformed_batch_data.numpy()
                    transformed_batch_data = np.reshape(
                        transformed_batch_data,
                        (transformed_batch_data.shape[0], -1))
                    labels = np.array(labels)[:, np.newaxis]
                    if transformed_data is None:
                        transformed_data = transformed_batch_data
                        transformed_labels = labels
                    else:
                        transformed_data = np.vstack(
                            (transformed_data, transformed_batch_data))
                        # print(labels.shape)
                        # print(transformed_labels.shape)
                        transformed_labels = np.vstack(
                            (transformed_labels, labels))

                xmean_model = xmeans(data=transformed_data,
                                     kmax=num_classes *
                                     max_real_ratio_number_of_labels)
                xmean_model.process()
                # Extract clustering results: clusters and their centers
                clusters = xmean_model.get_clusters()
                centers = xmean_model.get_centers()
                new_labels = []
                xs = transformed_data
                centers = np.array(centers)
                print("number of clusters: ", len(centers))
                squared_norm_dist = np.sum((xs - centers[:, np.newaxis])**2,
                                           axis=2)
                data_centers = np.argmin(squared_norm_dist, axis=0)
                data_centers = np.expand_dims(data_centers, axis=1)

                new_labels = []
                for i in range(len(transformed_labels)):
                    new_labels.append(data_centers[i][0] +
                                      transformed_labels[i][0] * len(centers))

                new_label_counts = {}
                for label in new_labels:
                    if label in new_label_counts.keys():
                        new_label_counts[label] += 1
                    else:
                        new_label_counts[label] = 1

                # print(new_label_counts)

                per_cls_weights = []
                for i in range(len(cls_num_list)):
                    temp = []
                    for j in range(len(centers)):
                        new_label = j + i * len(centers)
                        if new_label in new_label_counts:
                            temp.append(new_label_counts[new_label])
                    effective_num_temp = 1.0 - np.power(betas[idx], temp)
                    per_cls_weights_temp = (
                        1.0 - betas[idx]) / np.array(effective_num_temp)
                    per_cls_weights.append(
                        np.average(per_cls_weights_temp, weights=temp))
                per_cls_weights = per_cls_weights / np.sum(
                    per_cls_weights) * len(cls_num_list)
                per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(
                    args.gpu)
            elif epoch < cutoff_epoch or (epoch - cutoff_epoch) % 20 == 10:
                effective_num = 1.0 - np.power(betas[idx], cls_num_list)
                per_cls_weights = (1.0 - betas[idx]) / np.array(effective_num)
                per_cls_weights = per_cls_weights / np.sum(
                    per_cls_weights) * len(cls_num_list)
                per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(
                    args.gpu)

        elif args.train_rule == 'DRW_AUTO_CLUSTER':
            train_sampler = None
            cutoff_epoch = args.cutoff_epoch
            idx = min(epoch // cutoff_epoch, 1)
            betas = [0, 0.9999]
            if epoch >= cutoff_epoch and (epoch - cutoff_epoch) % 10 == 0:
                max_real_ratio_number_of_labels = 20
                # todo: transform data batch by batch, then concatenate...
                temp_batch_size = int(train_dataset.data.shape[0])
                temp_train_loader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=int(temp_batch_size),
                    shuffle=(train_sampler is None),
                    num_workers=args.workers,
                    pin_memory=True,
                    sampler=train_sampler)

                transformed_data = None
                transformed_labels = None
                for i, (xs, labels) in enumerate(train_loader):
                    transformed_batch_data = model.forward_partial(
                        xs.cuda(), num_layers=6)
                    transformed_batch_data = transformed_batch_data.cpu(
                    ).detach()
                    transformed_batch_data = transformed_batch_data.numpy()
                    transformed_batch_data = np.reshape(
                        transformed_batch_data,
                        (transformed_batch_data.shape[0], -1))
                    labels = np.array(labels)[:, np.newaxis]
                    if transformed_data is None:
                        transformed_data = transformed_batch_data
                        transformed_labels = labels
                    else:
                        transformed_data = np.vstack(
                            (transformed_data, transformed_batch_data))
                        # print(labels.shape)
                        # print(transformed_labels.shape)
                        transformed_labels = np.vstack(
                            (transformed_labels, labels))

                initial_centers = [
                    np.zeros((transformed_data.shape[1], ))
                    for i in range(num_classes)
                ]
                center_counts = [0 for i in range(num_classes)]
                for i in range(transformed_data.shape[0]):
                    temp_idx = transformed_labels[i][0]
                    initial_centers[temp_idx] = initial_centers[
                        temp_idx] + transformed_data[i, :]
                    center_counts[temp_idx] = center_counts[temp_idx] + 1

                for i in range(num_classes):
                    initial_centers[i] = initial_centers[i] / center_counts[i]

                xmean_model = xmeans(data=transformed_data, initial_centers=initial_centers, \
                                     kmax=num_classes * max_real_ratio_number_of_labels)
                xmean_model.process()
                # Extract clustering results: clusters and their centers
                clusters = xmean_model.get_clusters()
                centers = xmean_model.get_centers()
                new_labels = []
                xs = transformed_data
                centers = np.array(centers)
                print("number of clusters: ", len(centers))
                squared_norm_dist = np.sum((xs - centers[:, np.newaxis])**2,
                                           axis=2)
                data_centers = np.argmin(squared_norm_dist, axis=0)
                data_centers = np.expand_dims(data_centers, axis=1)

                new_labels = []
                for i in range(len(transformed_labels)):
                    new_labels.append(data_centers[i][0] +
                                      transformed_labels[i][0] * len(centers))

                new_label_counts = {}
                for label in new_labels:
                    if label in new_label_counts.keys():
                        new_label_counts[label] += 1
                    else:
                        new_label_counts[label] = 1

                # print(new_label_counts)

                per_cls_weights = []
                for i in range(len(cls_num_list)):
                    temp = []
                    for j in range(len(centers)):
                        new_label = j + i * len(centers)
                        if new_label in new_label_counts:
                            temp.append(new_label_counts[new_label])
                    effective_num_temp = 1.0 - np.power(betas[idx], temp)
                    per_cls_weights_temp = (
                        1.0 - betas[idx]) / np.array(effective_num_temp)
                    per_cls_weights.append(
                        np.average(per_cls_weights_temp, weights=temp))
                per_cls_weights = per_cls_weights / np.sum(
                    per_cls_weights) * len(cls_num_list)
                per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(
                    args.gpu)
            elif epoch < cutoff_epoch:
                effective_num = 1.0 - np.power(betas[idx], cls_num_list)
                per_cls_weights = (1.0 - betas[idx]) / np.array(effective_num)
                per_cls_weights = per_cls_weights / np.sum(
                    per_cls_weights) * len(cls_num_list)
                per_cls_weights = torch.FloatTensor(per_cls_weights).cuda(
                    args.gpu)

        else:
            warnings.warn('Sample rule is not listed')

        if args.loss_type == 'CE':
            criterion = nn.CrossEntropyLoss(weight=per_cls_weights).cuda(
                args.gpu)
        elif args.loss_type == 'LDAM':
            criterion = LDAMLoss(cls_num_list=cls_num_list,
                                 max_m=0.5,
                                 s=30,
                                 weight=per_cls_weights).cuda(args.gpu)
            #temp = [cls_num_list[i] * per_cls_weights[i].item() for i in range(len(cls_num_list))]
            #criterion = LDAMLoss(cls_num_list=temp, max_m=0.5, s=30, weight=per_cls_weights).cuda(args.gpu)
        elif args.loss_type == 'Focal':
            criterion = FocalLoss(weight=per_cls_weights,
                                  gamma=1).cuda(args.gpu)
        else:
            warnings.warn('Loss type is not listed')
            return

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args,
              log_training, tf_writer)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, epoch, args, log_testing,
                        tf_writer)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        tf_writer.add_scalar('acc/test_top1_best', best_acc1, epoch)
        output_best = 'Best Prec@1: %.3f\n' % (best_acc1)
        print(output_best)
        log_testing.write(output_best + '\n')
        log_testing.flush()

        save_checkpoint(
            args, {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)