Exemplo n.º 1
0
def main_worker(gpu, ngpus_per_node, args):
    """Train and validate a classifier on imbalanced CIFAR-10/100.

    Builds the model named by ``args.arch`` (wrapped in ``DataParallel``),
    an SGD optimizer and the train/validation loaders, then runs epochs
    ``args.start_epoch`` .. ``args.epochs - 1``.  After every epoch the
    model is validated, the module-level ``best_acc1`` is updated and a
    checkpoint (without the ``RSG`` entries of the state dict) is saved.

    Args:
        gpu: GPU index; stored on ``args.gpu`` and used to place the class
            weights and the criterion.
        ngpus_per_node: Not used by this function; kept for the caller's
            signature.
        args: Options namespace.  ``args.gpu`` and ``args.cls_num_list``
            are set as a side effect.

    Returns:
        None.  Returns early after a warning when ``args.dataset`` or
        ``args.loss_type`` is not one of the listed values.
    """
    # Local import: ExitStack guarantees the log files and the summary
    # writer are closed on every exit path (normal end, early return, error).
    from contextlib import ExitStack

    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    print("=> creating model '{}'".format(args.arch))
    num_classes = 100 if args.dataset == 'cifar100' else 10
    use_norm = args.loss_type == 'LDAM'

    # DataParallel will divide and allocate batch_size to all available GPUs
    model = models.__dict__[args.arch](
        num_classes=num_classes,
        use_norm=use_norm,
        head_tail_ratio=args.head_tail_ratio,
        transfer_strength=args.transfer_strength,
        phase_train=True,
        epoch_thresh=args.epoch_thresh)
    model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_val = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    if args.dataset == 'cifar10':
        train_dataset = IMBALANCECIFAR10(root='./data',
                                         imb_type=args.imb_type,
                                         imb_factor=args.imb_factor,
                                         rand_number=args.rand_number,
                                         train=True,
                                         download=True,
                                         transform=transform_train)
        val_dataset = datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform_val)
    elif args.dataset == 'cifar100':
        train_dataset = IMBALANCECIFAR100(root='./data',
                                          imb_type=args.imb_type,
                                          imb_factor=args.imb_factor,
                                          rand_number=args.rand_number,
                                          train=True,
                                          download=True,
                                          transform=transform_train)
        val_dataset = datasets.CIFAR100(root='./data',
                                        train=False,
                                        download=True,
                                        transform=transform_val)
    else:
        warnings.warn('Dataset is not listed')
        return

    cls_num_list = train_dataset.get_cls_num_list()
    print('cls num list:')
    print(cls_num_list)
    args.cls_num_list = cls_num_list

    # An unlisted rule used to leave ``per_cls_weights`` unbound and crash
    # with a NameError; warn once and train unweighted instead.
    if args.train_rule not in ('None', 'Resample', 'Reweight', 'DRW'):
        warnings.warn('Sample rule is not listed')

    # The sampler must exist *before* the DataLoader is built; creating it
    # inside the epoch loop (as before) never influenced the loader.
    train_sampler = None
    if args.train_rule == 'Resample':
        train_sampler = ImbalancedDatasetSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    def effective_number_weights(beta):
        """Return per-class weights (1 - beta) / (1 - beta ** n_c), rescaled
        so they sum to the number of classes, as a CUDA float tensor."""
        effective_num = 1.0 - np.power(beta, cls_num_list)
        weights = (1.0 - beta) / np.array(effective_num)
        weights = weights / np.sum(weights) * len(cls_num_list)
        return torch.FloatTensor(weights).cuda(args.gpu)

    log_dir = os.path.join(args.root_log, args.store_name)
    with ExitStack() as stack:
        # init log for training
        log_training = stack.enter_context(
            open(os.path.join(log_dir, 'log_train.csv'), 'w'))
        log_testing = stack.enter_context(
            open(os.path.join(log_dir, 'log_test.csv'), 'w'))
        with open(os.path.join(log_dir, 'args.txt'), 'w') as f:
            f.write(str(args))
        tf_writer = SummaryWriter(log_dir=log_dir)
        stack.callback(tf_writer.close)

        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args)

            if args.train_rule == 'Reweight':
                per_cls_weights = effective_number_weights(0.9999)
            elif args.train_rule == 'DRW':
                # Deferred re-weighting: uniform weights (beta = 0) for the
                # first 160 epochs, effective-number weights afterwards.
                # The index is clamped so epochs >= 320 keep the last beta
                # instead of raising IndexError.
                betas = [0, 0.9999]
                idx = min(epoch // 160, len(betas) - 1)
                per_cls_weights = effective_number_weights(betas[idx])
            else:
                # 'None', 'Resample' and unlisted rules: unweighted loss.
                per_cls_weights = None

            if args.loss_type == 'CE':
                criterion = nn.CrossEntropyLoss(weight=per_cls_weights).cuda(
                    args.gpu)
            elif args.loss_type == 'LDAM':
                criterion = LDAMLoss(cls_num_list=cls_num_list,
                                     max_m=0.5,
                                     s=30,
                                     weight=per_cls_weights).cuda(args.gpu)
            elif args.loss_type == 'Focal':
                criterion = FocalLoss(weight=per_cls_weights,
                                      gamma=1).cuda(args.gpu)
            else:
                warnings.warn('Loss type is not listed')
                return

            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch, args,
                  log_training, tf_writer)

            # evaluate on validation set
            acc1 = validate(val_loader, model, criterion, epoch, args,
                            log_testing, tf_writer)

            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

            tf_writer.add_scalar('acc/test_top1_best', best_acc1, epoch)
            output_best = 'Best Prec@1: %.3f\n' % (best_acc1)
            print(output_best)
            log_testing.write(output_best + '\n')
            log_testing.flush()

            # remove RSG module, since RSG is not used during testing.
            # The state dict is materialised once rather than once per key;
            # the slice drops the `module.` prefix added by DataParallel.
            full_state = model.state_dict()
            new_state_dict = OrderedDict(
                (key[len('module.'):], value)
                for key, value in full_state.items()
                if 'RSG' not in key)

            save_checkpoint(
                args, {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': new_state_dict,
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
Exemplo n.º 2
0
def main_worker(gpu, ngpus_per_node, args):
    """Train and validate a classifier on imbalanced CIFAR-10/100.

    Builds the model named by ``args.arch``, an SGD optimizer, optionally
    restores both from ``args.resume``, selects per-class loss weights
    from ``args.train_rule`` and a training criterion from
    ``args.loss_type``, then runs epochs ``args.start_epoch`` ..
    ``args.epochs - 1``.  Validation always uses a plain
    ``nn.CrossEntropyLoss``.  After every epoch the module-level
    ``best_acc1`` is updated and a checkpoint is saved.

    Args:
        gpu: GPU index or ``None``.  With an index the model is placed on
            that device; with ``None`` it is wrapped in ``DataParallel``.
        ngpus_per_node: Not used by this function; kept for the caller's
            signature.
        args: Options namespace.  ``args.gpu``, ``args.cls_num_list`` and
            (when resuming) ``args.start_epoch`` are set as side effects.

    Returns:
        None.  Returns early after a warning when ``args.dataset`` or
        ``args.loss_type`` is not one of the listed values.
    """
    # Local import: ExitStack guarantees the log files and the summary
    # writer are closed on every exit path (normal end, early return, error).
    from contextlib import ExitStack

    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    print("=> creating model '{}'".format(args.arch))
    num_classes = 100 if args.dataset == 'cifar100' else 10
    use_norm = args.loss_type == 'LDAM'
    model = models.__dict__[args.arch](num_classes=num_classes,
                                       use_norm=use_norm)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cuda:0')
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_val = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    if args.dataset == 'cifar10':
        train_dataset = IMBALANCECIFAR10(root='./data/CIFAR10',
                                         imb_type=args.imb_type,
                                         imb_factor=args.imb_factor,
                                         rand_number=args.rand_number,
                                         train=True,
                                         download=True,
                                         transform=transform_train)
        val_dataset = datasets.CIFAR10(root='./data/CIFAR10',
                                       train=False,
                                       download=True,
                                       transform=transform_val)
    elif args.dataset == 'cifar100':
        train_dataset = IMBALANCECIFAR100(root='./data/CIFAR100',
                                          imb_type=args.imb_type,
                                          imb_factor=args.imb_factor,
                                          rand_number=args.rand_number,
                                          train=True,
                                          download=True,
                                          transform=transform_train)
        val_dataset = datasets.CIFAR100(root='./data/CIFAR100',
                                        train=False,
                                        download=True,
                                        transform=transform_val)
    else:
        warnings.warn('Dataset is not listed')
        return

    cls_num_list = train_dataset.get_cls_num_list()
    print('cls num list:')
    print(cls_num_list)
    args.cls_num_list = cls_num_list

    # No train rule in this function uses a sampler, so the loader always
    # shuffles (equivalent to the former ``sampler=None`` setup).
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # Lazy constructors keyed by ``args.loss_type``.  Each lambda is only
    # evaluated for the selected entry, so ``per_cls_weights`` (bound
    # below) and the loss classes are looked up at call time.
    loss_builders = {
        'CE': lambda: CEloss(weight=per_cls_weights,
                             num_classes=num_classes),
        'LDAM': lambda: LDAMLoss(cls_num_list=cls_num_list,
                                 max_m=0.5,
                                 s=30,
                                 weight=per_cls_weights),
        'Focal': lambda: FocalLoss(weight=per_cls_weights, gamma=1),
        'Seesaw': lambda: SeesawLoss(num_classes=num_classes),
        'GradSeesawLoss': lambda: GradSeesawLoss(num_classes=num_classes),
        'SoftSeesaw': lambda: SoftSeesawLoss(num_classes=num_classes,
                                             beta=args.beta),
        'SoftGradeSeesawLoss': lambda: SoftGradeSeesawLoss(
            num_classes=num_classes),
        'Seesaw_prior': lambda: SeesawLoss_prior(cls_num_list=cls_num_list),
        'GradSeesawLoss_prior': lambda: GradSeesawLoss_prior(
            cls_num_list=cls_num_list),
        'GHMc': lambda: GHMcLoss(bins=10, momentum=0.75, use_sigmoid=True),
        'SoftmaxGHMc': lambda: SoftmaxGHMc(bins=30, momentum=0.75),
        'SoftmaxGHMcV2': lambda: SoftmaxGHMcV2(bins=30, momentum=0.75),
        'SoftmaxGHMcV3': lambda: SoftmaxGHMcV3(bins=30, momentum=0.75),
        'SeesawGHMc': lambda: SeesawGHMc(bins=30, momentum=0.75),
        'EQLv2': lambda: EQLv2(num_classes=num_classes),
        'EQL': lambda: EQLloss(cls_num_list=cls_num_list),
        'GHMSeesaw': lambda: GHMSeesaw(num_classes=num_classes),
        'GHMSeesawV2': lambda: GHMSeesawV2(num_classes=num_classes,
                                           beta=args.beta),
    }

    log_dir = os.path.join(args.root_log, args.store_name)
    with ExitStack() as stack:
        # init log for training
        log_training = stack.enter_context(
            open(os.path.join(log_dir, 'log_train.csv'), 'w'))
        log_testing = stack.enter_context(
            open(os.path.join(log_dir, 'log_test.csv'), 'w'))
        with open(os.path.join(log_dir, 'args.txt'), 'w') as f:
            f.write(str(args))
        tf_writer = SummaryWriter(log_dir=log_dir)
        stack.callback(tf_writer.close)

        # TAG Init train rule
        # ``raw_weights`` stays None for the unweighted cases.  An unlisted
        # rule used to leave ``per_cls_weights`` unbound (NameError for the
        # CE/LDAM/Focal losses); it now warns and falls back to unweighted.
        raw_weights = None
        if args.train_rule == 'None':
            pass
        elif args.train_rule == 'EffectiveNumber':
            beta = 0.9999
            effective_num = 1.0 - np.power(beta, cls_num_list)
            raw_weights = (1.0 - beta) / np.array(effective_num)
            raw_weights = raw_weights / np.sum(raw_weights) * len(
                cls_num_list)
        elif args.train_rule == 'ClassBlance':
            # inverse class frequency, rescaled to mean 1
            raw_weights = 1.0 / np.array(cls_num_list)
            raw_weights = raw_weights / np.mean(raw_weights)
        elif args.train_rule == 'ClassBlanceV2':
            # inverse fourth root of class frequency, rescaled to mean 1
            raw_weights = 1.0 / np.power(np.array(cls_num_list), 0.25)
            raw_weights = raw_weights / np.mean(raw_weights)
        else:
            warnings.warn('Sample rule is not listed')

        if raw_weights is None:
            per_cls_weights = None
        else:
            per_cls_weights = torch.FloatTensor(raw_weights).cuda(args.gpu)

        # TAG Init loss
        build_criterion = loss_builders.get(args.loss_type)
        if build_criterion is None:
            warnings.warn('Loss type is not listed')
            return
        criterion = build_criterion().cuda(args.gpu)

        valid_criterion = nn.CrossEntropyLoss().cuda(args.gpu)
        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args)

            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch, args,
                  log_training, tf_writer)

            # evaluate on validation set
            acc1 = validate(val_loader, model, valid_criterion, epoch, args,
                            log_testing, tf_writer)

            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

            tf_writer.add_scalar('acc/test_top1_best', best_acc1, epoch)
            output_best = 'Best Prec@1: %.3f\n' % (best_acc1)
            print(output_best)
            log_testing.write(output_best + '\n')
            log_testing.flush()

            save_checkpoint(
                args, {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
Exemplo n.º 3
0
def main_worker(gpu, args):
    """Train, or only evaluate, a classifier on imbalanced CIFAR/SVHN.

    Builds the model named by ``args.arch`` and an SGD optimizer, creates
    the train/validation loaders for ``args.dataset`` and then either

    * evaluates the checkpoint ``args.resume`` and returns
      (``args.evaluate``), or
    * optionally resumes from ``args.resume`` and/or loads pre-trained
      weights from ``args.pretrained_model``, then trains for epochs
      ``args.start_epoch`` .. ``args.epochs - 1``, validating, updating
      the module-level ``best_acc1`` and saving a checkpoint each epoch.

    Args:
        gpu: GPU index or ``None``.  With an index the model is placed on
            that device; with ``None`` it is wrapped in ``DataParallel``.
        args: Options namespace.  ``args.gpu``, ``args.cls_num_list`` and
            (when resuming) ``args.start_epoch`` are set as side effects.

    Returns:
        None.  Returns early after evaluation, or after a warning when
        ``args.loss_type`` is not one of the listed values.

    Raises:
        NotImplementedError: ``args.dataset`` is not supported.
        ValueError: ``args.resume`` is set for training but is not a file.
        AssertionError: ``args.evaluate`` without ``args.resume``, or
            unexpected missing keys when loading MoCo weights.
    """
    # Local import: ExitStack guarantees the log files and the summary
    # writer are closed on every exit path (normal end, early return, error).
    from contextlib import ExitStack

    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print(f"Use GPU: {args.gpu} for training")

    print(f"===> Creating model '{args.arch}'")
    if args.dataset == 'cifar100':
        num_classes = 100
    elif args.dataset in {'cifar10', 'svhn'}:
        num_classes = 10
    else:
        raise NotImplementedError
    use_norm = args.loss_type == 'LDAM'
    model = models.__dict__[args.arch](num_classes=num_classes,
                                       use_norm=use_norm)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Device used as ``map_location`` for every checkpoint load.  The old
    # f'cuda:{args.gpu}' became the invalid string 'cuda:None' in the
    # DataParallel case; fall back to the current CUDA device there.
    load_device = torch.device(
        'cuda' if args.gpu is None else f'cuda:{args.gpu}')

    is_cifar = args.dataset.startswith('cifar')
    mean = [0.4914, 0.4822, 0.4465] if is_cifar else [.5, .5, .5]
    std = [0.2023, 0.1994, 0.2010] if is_cifar else [.5, .5, .5]
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Keyword arguments shared by all dataset constructors; only the
    # train/test selector differs (``train=`` for CIFAR, ``split=`` for SVHN).
    train_kwargs = dict(root=args.data_path,
                        imb_type=args.imb_type,
                        imb_factor=args.imb_factor,
                        rand_number=args.rand_number,
                        download=True,
                        transform=transform_train)
    val_kwargs = dict(root=args.data_path,
                      download=True,
                      transform=transform_val)

    if args.dataset == 'cifar10':
        train_dataset = ImbalanceCIFAR10(train=True, **train_kwargs)
        val_dataset = datasets.CIFAR10(train=False, **val_kwargs)
    elif args.dataset == 'cifar100':
        train_dataset = ImbalanceCIFAR100(train=True, **train_kwargs)
        val_dataset = datasets.CIFAR100(train=False, **val_kwargs)
    elif args.dataset == 'svhn':
        train_dataset = ImbalanceSVHN(split='train', **train_kwargs)
        val_dataset = datasets.SVHN(split='test', **val_kwargs)
    else:
        raise NotImplementedError(f"Dataset {args.dataset} is not supported!")

    # Loader construction is identical for every dataset, so do it once.
    train_sampler = None
    if args.train_rule == 'Resample':
        train_sampler = ImbalancedDatasetSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # evaluate only
    if args.evaluate:
        assert args.resume, 'Specify a trained model using [args.resume]'
        checkpoint = torch.load(args.resume, map_location=load_device)
        model.load_state_dict(checkpoint['state_dict'])
        print(f"===> Checkpoint '{args.resume}' loaded, testing...")
        validate(val_loader, model, nn.CrossEntropyLoss(), 0, args)
        return

    if args.resume:
        if not os.path.isfile(args.resume):
            raise ValueError(f"No checkpoint found at '{args.resume}'")
        print(f"===> Loading checkpoint '{args.resume}'")
        checkpoint = torch.load(args.resume, map_location=load_device)
        args.start_epoch = checkpoint['epoch']
        best_acc1 = checkpoint['best_acc1']
        if args.gpu is not None:
            best_acc1 = best_acc1.to(args.gpu)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print(
            f"===> Loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})"
        )

    # load self-supervised pre-trained model
    if args.pretrained_model:
        checkpoint = torch.load(args.pretrained_model,
                                map_location=load_device)
        if 'moco_ckpt' not in args.pretrained_model:
            from collections import OrderedDict

            # keep everything except the classifier head
            new_state_dict = OrderedDict(
                (k, v) for k, v in checkpoint['state_dict'].items()
                if 'linear' not in k and 'fc' not in k)
            model.load_state_dict(new_state_dict, strict=False)
            print(
                f'===> Pretrained weights found in total: [{len(new_state_dict)}]'
            )
        else:
            # rename moco pre-trained keys
            state_dict = checkpoint['state_dict']
            for k in list(state_dict.keys()):
                # retain only encoder_q up to before the embedding layer
                if k.startswith('module.encoder_q'
                                ) and not k.startswith('module.encoder_q.fc'):
                    # remove prefix
                    state_dict[k[len("module.encoder_q."):]] = state_dict[k]
                # delete renamed or unused k
                del state_dict[k]
            msg = model.load_state_dict(state_dict, strict=False)
            if use_norm:
                assert set(msg.missing_keys) == {"fc.weight"}
            else:
                assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        print(f'===> Pre-trained model loaded: {args.pretrained_model}')

    cudnn.benchmark = True

    # Every dataset that reaches this point is cifar10, cifar100 or svhn
    # (anything else raised above), so the class counts are always available.
    cls_num_list = train_dataset.get_cls_num_list()
    print('cls num list:')
    print(cls_num_list)
    args.cls_num_list = cls_num_list

    def effective_number_weights(beta):
        """Return per-class weights (1 - beta) / (1 - beta ** n_c), rescaled
        so they sum to the number of classes, as a CUDA float tensor."""
        effective_num = 1.0 - np.power(beta, cls_num_list)
        weights = (1.0 - beta) / np.array(effective_num)
        weights = weights / np.sum(weights) * len(cls_num_list)
        return torch.FloatTensor(weights).cuda(args.gpu)

    log_dir = os.path.join(args.root_log, args.store_name)
    with ExitStack() as stack:
        # init log for training
        log_training = stack.enter_context(
            open(os.path.join(log_dir, 'log_train.csv'), 'w'))
        log_testing = stack.enter_context(
            open(os.path.join(log_dir, 'log_test.csv'), 'w'))
        with open(os.path.join(log_dir, 'args.txt'), 'w') as f:
            f.write(str(args))
        tf_writer = SummaryWriter(log_dir=log_dir)
        stack.callback(tf_writer.close)

        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args)

            if args.train_rule == 'Reweight':
                per_cls_weights = effective_number_weights(0.9999)
            elif args.train_rule == 'DRW':
                # Deferred re-weighting: uniform weights (beta = 0) for the
                # first 160 epochs, effective-number weights afterwards.
                # The index is clamped so epochs >= 320 keep the last beta
                # instead of raising IndexError.
                betas = [0, 0.9999]
                idx = min(epoch // 160, len(betas) - 1)
                per_cls_weights = effective_number_weights(betas[idx])
            else:
                per_cls_weights = None

            if args.loss_type == 'CE':
                criterion = nn.CrossEntropyLoss(weight=per_cls_weights).cuda(
                    args.gpu)
            elif args.loss_type == 'LDAM':
                criterion = LDAMLoss(cls_num_list=cls_num_list,
                                     max_m=0.5,
                                     s=30,
                                     weight=per_cls_weights).cuda(args.gpu)
            elif args.loss_type == 'Focal':
                criterion = FocalLoss(weight=per_cls_weights,
                                      gamma=1).cuda(args.gpu)
            else:
                warnings.warn('Loss type is not listed')
                return

            train(train_loader, model, criterion, optimizer, epoch, args,
                  log_training, tf_writer)
            acc1 = validate(val_loader, model, criterion, epoch, args,
                            log_testing, tf_writer)

            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

            tf_writer.add_scalar('acc/test_top1_best', best_acc1, epoch)
            output_best = 'Best Prec@1: %.3f\n' % best_acc1
            print(output_best)
            log_testing.write(output_best + '\n')
            log_testing.flush()

            save_checkpoint(
                args, {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
Exemplo n.º 4
0
def main_worker(gpu, ngpus_per_node, args):
    """Train and evaluate a CIFAR-10/100 classifier on an imbalanced split.

    Builds the model, optimizer, datasets and loaders, optionally resumes
    from ``args.resume``, and then, for every epoch, derives per-class loss
    weights from ``args.train_rule``, builds the criterion named by
    ``args.loss_type``, trains, validates, logs the best top-1 accuracy and
    calls ``save_checkpoint``.

    Args:
        gpu: Device index to place the model on, or ``None`` to wrap the
            model in ``torch.nn.DataParallel``.
        ngpus_per_node: Not used by this function.
        args: Run configuration. Mutated in place: ``gpu`` and
            ``cls_num_list`` are written, and ``start_epoch`` is overwritten
            when a checkpoint is loaded.

    Returns:
        None. Returns early, after a warning, when the dataset, the training
        rule or the loss type is not one of the supported names.
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    print("=> creating model '{}'".format(args.arch))
    num_classes = 100 if args.dataset == 'cifar100' else 10
    use_norm = args.loss_type == 'LDAM'
    model = models.__dict__[args.arch](num_classes=num_classes,
                                       use_norm=use_norm)

    if args.gpu is not None:
        # Pin this process to one device and move the model onto it.
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()

    # SGD with momentum and weight decay taken from the run configuration.
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cuda:0')
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            # best_acc1 may be from a checkpoint from a different GPU; it
            # can also be a plain number, which has no .to() method.
            if args.gpu is not None and torch.is_tensor(best_acc1):
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Let cuDNN auto-tune its algorithms for the current configuration.
    cudnn.benchmark = True

    # Data loading code: augmentation for training, plain normalisation for
    # validation.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    if args.dataset == 'cifar10':
        train_dataset = IMBALANCECIFAR10(root='./data',
                                         imb_type=args.imb_type,
                                         imb_factor=args.imb_factor,
                                         rand_number=args.rand_number,
                                         train=True,
                                         download=True,
                                         transform=transform_train)
        val_dataset = datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform_val)
    elif args.dataset == 'cifar100':
        train_dataset = IMBALANCECIFAR100(root='./data',
                                          imb_type=args.imb_type,
                                          imb_factor=args.imb_factor,
                                          rand_number=args.rand_number,
                                          train=True,
                                          download=True,
                                          transform=transform_train)
        val_dataset = datasets.CIFAR100(root='./data',
                                        train=False,
                                        download=True,
                                        transform=transform_val)
    else:
        warnings.warn('Dataset is not listed')
        return
    cls_num_list = train_dataset.get_cls_num_list()
    print('cls num list:')
    print(cls_num_list)
    args.cls_num_list = cls_num_list

    # The sampler must exist before the loader is built; assigning it later
    # (as was done inside the epoch loop) never reaches the DataLoader, so
    # the 'Resample' rule silently had no effect.
    if args.train_rule == 'Resample':
        train_sampler = ImbalancedDatasetSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # init log for training; the context manager closes both files on every
    # exit path, including the early returns below.
    log_dir = os.path.join(args.root_log, args.store_name)
    with open(os.path.join(log_dir, 'log_train.csv'), 'w') as log_training, \
            open(os.path.join(log_dir, 'log_test.csv'), 'w') as log_testing:
        with open(os.path.join(log_dir, 'args.txt'), 'w') as f:
            f.write(str(args))
        tf_writer = SummaryWriter(log_dir=log_dir)
        try:
            for epoch in range(args.start_epoch, args.epochs):
                adjust_learning_rate(optimizer, epoch, args)

                if args.train_rule in ('None', 'Resample'):
                    per_cls_weights = None
                elif args.train_rule in ('Reweight', 'DRW'):
                    if args.train_rule == 'Reweight':
                        beta = 0.9999
                    else:
                        # Deferred re-weighting: beta 0 before epoch 160,
                        # 0.9999 afterwards. The index is clamped so runs
                        # of 320+ epochs do not index past the list.
                        betas = [0, 0.9999]
                        beta = betas[min(epoch // 160, 1)]
                    effective_num = 1.0 - np.power(beta, cls_num_list)
                    per_cls_weights = (1.0 - beta) / np.array(effective_num)
                    # Rescale so the weights sum to the number of classes.
                    per_cls_weights = (per_cls_weights /
                                       np.sum(per_cls_weights) *
                                       len(cls_num_list))
                    per_cls_weights = torch.FloatTensor(
                        per_cls_weights).cuda(args.gpu)
                else:
                    # Stop here, like the dataset and loss checks do,
                    # instead of failing on an undefined per_cls_weights.
                    warnings.warn('Sample rule is not listed')
                    return

                if args.loss_type == 'CE':
                    criterion = nn.CrossEntropyLoss(
                        weight=per_cls_weights).cuda(args.gpu)
                elif args.loss_type == 'LDAM':
                    criterion = LDAMLoss(cls_num_list=cls_num_list,
                                         max_m=0.5,
                                         s=30,
                                         weight=per_cls_weights).cuda(
                                             args.gpu)
                elif args.loss_type == 'Focal':
                    criterion = FocalLoss(weight=per_cls_weights,
                                          gamma=1).cuda(args.gpu)
                else:
                    warnings.warn('Loss type is not listed')
                    return

                # train for one epoch
                train(train_loader, model, criterion, optimizer, epoch, args,
                      log_training, tf_writer)

                # evaluate on validation set
                acc1 = validate(val_loader, model, criterion, epoch, args,
                                log_testing, tf_writer)

                # remember best acc@1 and save checkpoint
                is_best = acc1 > best_acc1
                best_acc1 = max(acc1, best_acc1)

                # Record the running best so it can be plotted later.
                tf_writer.add_scalar('acc/test_top1_best', best_acc1, epoch)
                output_best = 'Best Prec@1: %.3f\n' % (best_acc1)
                print(output_best)
                log_testing.write(output_best + '\n')
                log_testing.flush()

                save_checkpoint(
                    args, {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_acc1': best_acc1,
                        'optimizer': optimizer.state_dict(),
                    }, is_best)
        finally:
            tf_writer.close()
Exemplo n.º 5
0
def main_worker(gpu, ngpus_per_node, args):
    """Train and evaluate a 365-class model on the ``Places`` dataset.

    Builds the train/val datasets, selects the most populated classes as
    "head" classes and passes them to the model, then trains with
    ``LDAMLoss``. From epoch 10 on, the training loader draws samples
    through ``ImbalancedDatasetSampler``. After every epoch the best top-1
    accuracy is logged and a checkpoint is saved with all state-dict entries
    whose key contains ``'RSG'`` removed.

    Args:
        gpu: Not used by this function.
        ngpus_per_node: Not used by this function.
        args: Run configuration. ``args.cls_num_list`` is written with a
            copy of the per-class sample counts.

    Returns:
        None.
    """
    global best_acc1

    # create model
    print("=> creating model '{}'".format(args.arch))
    num_classes = 365

    # Data loading code

    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    transform_val = transforms.Compose([
        # Resize is the replacement for the deprecated transforms.Scale.
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    train_dataset = Places(args.image_dir, transform_train, 'train')
    val_dataset = Places(args.image_dir, transform_val, 'val')

    cls_num_list = train_dataset.get_cls_num_list()
    print('cls num list:')
    print(cls_num_list)

    # Keep an untouched copy: the head-class selection below overwrites
    # entries of cls_num_list.
    args.cls_num_list = cls_num_list.copy()

    # Repeatedly take the index of the largest remaining count and zero it
    # so the next pass finds the next-largest class.
    head_lists = []
    taken = 0
    for _ in range(int(num_classes * args.head_tail_ratio)):
        largest = cls_num_list.index(max(cls_num_list))
        head_lists.append(largest)
        cls_num_list[largest] = taken

    model = models.__dict__[args.arch](num_classes=num_classes,
                                       head_lists=head_lists,
                                       phase_train=True,
                                       epoch_thresh=args.epoch_thresh)

    model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # The class weights depend only on the class counts, so compute them
    # once instead of once per epoch.
    effective_num = 1.0 - np.power(0, args.cls_num_list)
    per_cls_weights = (1.0) / np.array(effective_num)
    per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
        args.cls_num_list)
    per_cls_weights = torch.FloatTensor(per_cls_weights).cuda()

    # init log for training; the context manager closes both files even if
    # an epoch raises.
    log_dir = os.path.join(args.root_log, args.store_name)
    with open(os.path.join(log_dir, 'log_train.csv'), 'w') as log_training, \
            open(os.path.join(log_dir, 'log_test.csv'), 'w') as log_testing:
        with open(os.path.join(log_dir, 'args.txt'), 'w') as f:
            f.write(str(args))
        tf_writer = SummaryWriter(log_dir=log_dir)
        try:
            for epoch in range(args.start_epoch, args.epochs):
                adjust_learning_rate(optimizer, epoch, args)

                # Switch to the balanced sampler once, at epoch 10. Using
                # ">=" also covers runs whose start_epoch is past 10.
                if train_sampler is None and epoch >= 10:
                    train_sampler = ImbalancedDatasetSampler(
                        train_dataset, label_count=args.cls_num_list)
                    train_loader = torch.utils.data.DataLoader(
                        train_dataset,
                        batch_size=args.batch_size,
                        shuffle=(train_sampler is None),
                        num_workers=args.workers,
                        pin_memory=True,
                        sampler=train_sampler,
                        drop_last=True)

                criterion = LDAMLoss(cls_num_list=args.cls_num_list,
                                     max_m=0.2,
                                     s=50,
                                     weight=per_cls_weights).cuda()

                # train for one epoch
                train(train_loader, model, criterion, optimizer, epoch, args,
                      log_training, tf_writer)

                # evaluate on validation set
                acc1 = validate(val_loader, model, criterion, epoch, args,
                                log_testing, tf_writer)

                # remember best acc@1 and save checkpoint
                is_best = acc1 > best_acc1
                best_acc1 = max(acc1, best_acc1)

                tf_writer.add_scalar('acc/test_top1_best', best_acc1, epoch)
                output_best = 'Best Prec@1: %.3f\n' % (best_acc1)
                print(output_best)
                log_testing.write(output_best + '\n')
                log_testing.flush()

                # remove RSG module, since RSG is not used during testing.
                # state_dict() is materialised once rather than per key.
                full_state = model.state_dict()
                new_state_dict = OrderedDict(
                    (key[7:], value)  # [7:] removes the `module.` prefix
                    for key, value in full_state.items()
                    if 'RSG' not in key)

                save_checkpoint(
                    args, {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': new_state_dict,
                        'best_acc1': best_acc1,
                        'optimizer': optimizer.state_dict(),
                    }, is_best)
        finally:
            tf_writer.close()
def _class_balanced_weights(cls_num_list, beta):
    """Return effective-number class weights scaled to sum to the class count.

    A class with ``n`` samples gets ``(1 - beta) / (1 - beta ** n)``; the
    vector is then rescaled so that it sums to ``len(cls_num_list)``.
    """
    effective_num = 1.0 - np.power(beta, cls_num_list)
    weights = (1.0 - beta) / np.array(effective_num)
    return weights / np.sum(weights) * len(cls_num_list)


def _extract_partial_features(model, loader, num_layers=6):
    """Run ``forward_partial`` over ``loader`` and flatten the outputs.

    Returns:
        ``(features, labels)``: NumPy arrays with one row per sample, in the
        order the loader produced them; ``labels`` has a single column.
    """
    # DataParallel only dispatches forward(); other methods are on .module.
    net = model.module if isinstance(model, torch.nn.DataParallel) else model
    feature_chunks = []
    label_chunks = []
    # Only the activations are needed, so no autograd graph is built.
    with torch.no_grad():
        for xs, labels in loader:
            feats = net.forward_partial(xs.cuda(), num_layers=num_layers)
            feats = feats.detach().cpu().numpy()
            feature_chunks.append(np.reshape(feats, (feats.shape[0], -1)))
            label_chunks.append(np.array(labels)[:, np.newaxis])
    # Concatenate once at the end instead of re-copying on every batch.
    return np.vstack(feature_chunks), np.vstack(label_chunks)


def _class_mean_centers(features, labels, num_classes):
    """Return the mean feature vector (float64) of every class label."""
    flat_labels = labels[:, 0]
    return [
        features[flat_labels == cls].mean(axis=0, dtype=np.float64)
        for cls in range(num_classes)
    ]


def _cluster_balanced_weights(features, labels, centers, num_cls, beta):
    """Derive class weights from (class, nearest-center) sub-group sizes.

    Every sample is assigned to its nearest center; the pair (class label,
    center index) defines a sub-group. Each sub-group gets an
    effective-number weight, a class weight is the size-weighted average of
    its sub-group weights, and the result is rescaled to sum to ``num_cls``.
    """
    centers = np.array(centers)
    n_centers = len(centers)
    # One center at a time keeps peak memory at N*D instead of K*N*D.
    sq_dist = np.stack(
        [np.sum((features - center)**2, axis=1) for center in centers])
    nearest = np.argmin(sq_dist, axis=0)

    sub_labels = nearest + labels[:, 0] * n_centers
    ids, sizes = np.unique(sub_labels, return_counts=True)
    sub_counts = dict(zip(ids.tolist(), sizes.tolist()))

    weights = []
    for cls in range(num_cls):
        first = cls * n_centers
        members = [
            sub_counts[sub] for sub in range(first, first + n_centers)
            if sub in sub_counts
        ]
        effective = 1.0 - np.power(beta, members)
        per_cluster = (1.0 - beta) / np.array(effective)
        weights.append(np.average(per_cluster, weights=members))
    weights = np.array(weights)
    return weights / np.sum(weights) * num_cls


def _auto_cluster_weights(model, loader, num_classes, num_cls, beta,
                          seed_with_class_means,
                          max_real_ratio_number_of_labels=20):
    """Cluster partial-forward features with x-means and weight the classes.

    When ``seed_with_class_means`` is true the clustering starts from the
    per-class feature means; otherwise ``xmeans`` picks its own start.
    """
    features, labels = _extract_partial_features(model, loader, num_layers=6)
    kmax = num_classes * max_real_ratio_number_of_labels
    if seed_with_class_means:
        xmean_model = xmeans(data=features,
                             initial_centers=_class_mean_centers(
                                 features, labels, num_classes),
                             kmax=kmax)
    else:
        xmean_model = xmeans(data=features, kmax=kmax)
    xmean_model.process()
    centers = xmean_model.get_centers()
    print("number of clusters: ", len(centers))
    return _cluster_balanced_weights(features, labels, centers, num_cls,
                                     beta)


def main_worker(gpu, ngpus_per_node, args):
    """Train and evaluate a CIFAR-10/100 classifier on an imbalanced split.

    Same pipeline as the plain re-weighting trainer (model, optimizer,
    optional resume, loaders, per-epoch criterion, train/validate,
    checkpoint), with two extra training rules that recompute the class
    weights from an x-means clustering of intermediate features:

    * ``'SWITCHING_DRW_AUTO_CLUSTER'``: before ``args.cutoff_epoch`` uses
      beta = 0; afterwards alternates every 10 epochs between plain
      class-balanced weights and cluster-derived weights.
    * ``'DRW_AUTO_CLUSTER'``: before ``args.cutoff_epoch`` uses beta = 0;
      from the cutoff on, recomputes cluster-derived weights every 10
      epochs (seeded with per-class feature means).

    In both rules, epochs that match neither refresh condition keep the
    weights of the previous epoch.

    Args:
        gpu: Device index to place the model on, or ``None`` to wrap the
            model in ``torch.nn.DataParallel``.
        ngpus_per_node: Not used by this function.
        args: Run configuration. Mutated in place: ``gpu`` and
            ``cls_num_list`` are written, and ``start_epoch`` is overwritten
            when a checkpoint is loaded.

    Returns:
        None. Returns early, after a warning, when the dataset, the training
        rule or the loss type is not one of the supported names.
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    print("=> creating model '{}'".format(args.arch))
    num_classes = 100 if args.dataset == 'cifar100' else 10
    use_norm = args.loss_type == 'LDAM'
    model = models.__dict__[args.arch](num_classes=num_classes,
                                       use_norm=use_norm)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cuda:0')
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            # best_acc1 may be from a checkpoint from a different GPU; it
            # can also be a plain number, which has no .to() method.
            if args.gpu is not None and torch.is_tensor(best_acc1):
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    if args.dataset == 'cifar10':
        train_dataset = IMBALANCECIFAR10(root='./data',
                                         imb_type=args.imb_type,
                                         imb_factor=args.imb_factor,
                                         rand_number=args.rand_number,
                                         train=True,
                                         download=True,
                                         transform=transform_train)
        val_dataset = datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform_val)
    elif args.dataset == 'cifar100':
        train_dataset = IMBALANCECIFAR100(root='./data',
                                          imb_type=args.imb_type,
                                          imb_factor=args.imb_factor,
                                          rand_number=args.rand_number,
                                          train=True,
                                          download=True,
                                          transform=transform_train)
        val_dataset = datasets.CIFAR100(root='./data',
                                        train=False,
                                        download=True,
                                        transform=transform_val)
    else:
        warnings.warn('Dataset is not listed')
        return
    cls_num_list = train_dataset.get_cls_num_list()
    print('cls num list:')
    print(cls_num_list)
    args.cls_num_list = cls_num_list

    # The sampler must exist before the loader is built; assigning it inside
    # the epoch loop never reaches the DataLoader, so the 'Resample' rule
    # silently had no effect.
    if args.train_rule == 'Resample':
        train_sampler = ImbalancedDatasetSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    def to_device(weights):
        # Class weights are handed to the criterion as a float CUDA tensor.
        return torch.FloatTensor(weights).cuda(args.gpu)

    betas = [0, 0.9999]
    num_cls = len(cls_num_list)
    # The cluster rules carry weights over between refresh epochs; start
    # from None so a run that begins on a carry-over epoch (for example
    # after a resume) is unweighted rather than failing on an undefined name.
    per_cls_weights = None

    # init log for training; the context manager closes both files on every
    # exit path, including the early returns below.
    log_dir = os.path.join(args.root_log, args.store_name)
    with open(os.path.join(log_dir, 'log_train.csv'), 'w') as log_training, \
            open(os.path.join(log_dir, 'log_test.csv'), 'w') as log_testing:
        with open(os.path.join(log_dir, 'args.txt'), 'w') as f:
            f.write(str(args))
        tf_writer = SummaryWriter(log_dir=log_dir)
        try:
            for epoch in range(args.start_epoch, args.epochs):
                adjust_learning_rate(optimizer, epoch, args)

                rule = args.train_rule
                if rule in ('None', 'Resample'):
                    per_cls_weights = None
                elif rule == 'Reweight':
                    per_cls_weights = to_device(
                        _class_balanced_weights(cls_num_list, 0.9999))
                elif rule == 'DRW':
                    beta = betas[min(epoch // 160, 1)]
                    per_cls_weights = to_device(
                        _class_balanced_weights(cls_num_list, beta))
                elif rule == 'SWITCHING_DRW_AUTO_CLUSTER':
                    cutoff_epoch = args.cutoff_epoch
                    # Same as min(epoch // cutoff_epoch, 1) for a positive
                    # cutoff, but does not divide by zero when it is 0.
                    beta = betas[int(epoch >= cutoff_epoch)]
                    phase = (epoch - cutoff_epoch) % 20
                    if epoch >= cutoff_epoch + 10 and phase == 0:
                        per_cls_weights = to_device(
                            _auto_cluster_weights(
                                model, train_loader, num_classes, num_cls,
                                beta, seed_with_class_means=False))
                    elif epoch < cutoff_epoch or phase == 10:
                        per_cls_weights = to_device(
                            _class_balanced_weights(cls_num_list, beta))
                    # otherwise keep the previous epoch's weights
                elif rule == 'DRW_AUTO_CLUSTER':
                    cutoff_epoch = args.cutoff_epoch
                    beta = betas[int(epoch >= cutoff_epoch)]
                    if (epoch >= cutoff_epoch
                            and (epoch - cutoff_epoch) % 10 == 0):
                        per_cls_weights = to_device(
                            _auto_cluster_weights(
                                model, train_loader, num_classes, num_cls,
                                beta, seed_with_class_means=True))
                    elif epoch < cutoff_epoch:
                        per_cls_weights = to_device(
                            _class_balanced_weights(cls_num_list, beta))
                    # otherwise keep the previous epoch's weights
                else:
                    # Stop here, like the dataset and loss checks do,
                    # instead of failing on an undefined per_cls_weights.
                    warnings.warn('Sample rule is not listed')
                    return

                if args.loss_type == 'CE':
                    criterion = nn.CrossEntropyLoss(
                        weight=per_cls_weights).cuda(args.gpu)
                elif args.loss_type == 'LDAM':
                    criterion = LDAMLoss(cls_num_list=cls_num_list,
                                         max_m=0.5,
                                         s=30,
                                         weight=per_cls_weights).cuda(
                                             args.gpu)
                elif args.loss_type == 'Focal':
                    criterion = FocalLoss(weight=per_cls_weights,
                                          gamma=1).cuda(args.gpu)
                else:
                    warnings.warn('Loss type is not listed')
                    return

                # train for one epoch
                train(train_loader, model, criterion, optimizer, epoch, args,
                      log_training, tf_writer)

                # evaluate on validation set
                acc1 = validate(val_loader, model, criterion, epoch, args,
                                log_testing, tf_writer)

                # remember best acc@1 and save checkpoint
                is_best = acc1 > best_acc1
                best_acc1 = max(acc1, best_acc1)

                tf_writer.add_scalar('acc/test_top1_best', best_acc1, epoch)
                output_best = 'Best Prec@1: %.3f\n' % (best_acc1)
                print(output_best)
                log_testing.write(output_best + '\n')
                log_testing.flush()

                save_checkpoint(
                    args, {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_acc1': best_acc1,
                        'optimizer': optimizer.state_dict(),
                    }, is_best)
        finally:
            tf_writer.close()