Example #1
def dataloaders():
    # Data loading code
    # 1. all images are already aligned to 218x178;

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = CelebA(
        args.root,
        'train600.txt',
        transforms.Compose([
            transforms.RandomRotation(30),
            transforms.RandomHorizontalFlip(),
            #         transforms.Resize((178,178)),
            #         transforms.CenterCrop((178,178)),
            transforms.RandomResizedCrop(178, scale=(0.8, 1.0)),
            # should not crop away much information for multi-label classification
            transforms.ToTensor(),
            normalize,
        ]))

    val_dataset = CelebA(
        args.root,
        'val200.txt',
        transforms.Compose([
            transforms.Resize((178, 178)),
            #         transforms.CenterCrop((178,178)),
            transforms.ToTensor(),
            normalize,
        ]))

    test_dataset = CelebA(
        args.root,
        'test200.txt',
        transforms.Compose([
            transforms.Resize((178, 178)),
            #         transforms.CenterCrop((178,178)),
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True)
    return train_loader, val_loader, test_loader
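
The normalization above uses the standard ImageNet statistics. As a small aside (not part of the original example), those statistics can be inverted for visualization with a second Normalize whose parameters are derived from the first:

import torch
from torchvision import transforms

mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=mean, std=std)
# Inverse transform: y = (x - m) / s  =>  x = y * s + m
unnormalize = transforms.Normalize(mean=[-m / s for m, s in zip(mean, std)],
                                   std=[1 / s for s in std])

img = torch.rand(3, 218, 178)            # stand-in for a CelebA image tensor
restored = unnormalize(normalize(img))   # recovers img up to float error
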
Example #2
def get_ds_train(self):
    if self.ds_train is None:
        self.celeba = CelebA(self.path_img, self.path_ann, self.path_bbox)
        self.persons = self.celeba.persons
        self.ds_train = BufferDS(self.buffer_size, self.celeba,
                                 self.batch_size)
    return self.ds_train
Example #3
def dataloaders():
    # Data loading code
    # 1. all images are already aligned to 218x178;

    normalize0 = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])

    train_dataset = CelebA(
        '/home/MSAI/cgong002/acv_project_celeba/',
        'train_attr_list.txt',
        transforms.Compose([
            transforms.RandomRotation(30),
            transforms.RandomHorizontalFlip(),
            #         transforms.Resize((178,178)),
            #         transforms.CenterCrop((178,178)),
            transforms.RandomResizedCrop(178, scale=(0.8, 1.0)),
            # should not crop away much information for multi-label classification
            transforms.ToTensor(),
            # normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=512,
                                               shuffle=True,
                                               num_workers=7,
                                               pin_memory=True)
    return train_loader
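
With normalization disabled as above (only ToTensor), a loader like this can be used to estimate per-channel statistics, which is presumably how values such as the normalize_new mean/std appearing in a later example were obtained. A minimal sketch (the helper name is an assumption):

import torch

def estimate_mean_std(loader):
    # Accumulate per-channel sums of pixels and squared pixels over the dataset.
    n_pixels = 0
    channel_sum = torch.zeros(3)
    channel_sq_sum = torch.zeros(3)
    for images, _ in loader:
        n_pixels += images.numel() / images.size(1)
        channel_sum += images.sum(dim=[0, 2, 3])
        channel_sq_sum += (images ** 2).sum(dim=[0, 2, 3])
    mean = channel_sum / n_pixels
    std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()
    return mean, std
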
Example #4
    def __init__(self, split, align=False, partition='all'):

        if partition != 'all':
            name = 'person_' + partition + '_' + split
        else:
            name = 'person_' + split

        if align and (partition == 'all' or partition == 'face'):
            name += '_align'

        Imdb.__init__(self, name)

        # Load two children dataset wrappers
        self._face = CelebA(split, align=align)
        self._clothes = DeepFashion(split)
        # The class list is a combination of face and clothing attributes
        self._classes = self._face.classes + self._clothes.classes
        self._face_class_idx = range(self._face.num_classes)
        self._clothes_class_idx = range(
            self._face.num_classes,
            self._face.num_classes + self._clothes.num_classes)

        # load data path
        self._data_path = os.path.join(self.data_path, 'imdb_PersonAttributes')
        # load the image lists and attributes.
        self._load_dataset(split, align, partition)
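
Illustration only (the attribute counts below are placeholders, not taken from the snippet): the two index ranges built above make it easy to split a combined prediction vector back into its face and clothing parts.

import numpy as np

num_face, num_clothes = 40, 26                      # placeholder counts
face_idx = list(range(num_face))
clothes_idx = list(range(num_face, num_face + num_clothes))

combined = np.random.rand(num_face + num_clothes)   # stand-in for model output
face_scores = combined[face_idx]
clothes_scores = combined[clothes_idx]
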
Example #5
def main():
    global args, best_prec1
    args = parser.parse_args()

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)
    # Use CUDA
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    use_cuda = torch.cuda.is_available()

    # Random seed
    if args.manual_seed is None:
        args.manual_seed = random.randint(1, 10000)
    random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    if use_cuda:
        torch.cuda.manual_seed_all(args.manual_seed)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    elif args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
                    baseWidth=args.base_width,
                    cardinality=args.cardinality,
                )
    elif args.arch.startswith('shufflenet'):
        model = models.__dict__[args.arch](
                    groups=args.groups
                )
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    title = 'CelebA-' + args.arch
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            args.checkpoint = os.path.dirname(args.resume)
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])


    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    color_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
    train_dataset = CelebA(
        args.data,
        'train_40_att_list.txt',
        transforms.Compose([
            transforms.RandomResizedCrop(size=(218, 178)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(kernel_size=int(0.1 * 178)),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        CelebA(args.data, 'val_40_att_list.txt', transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.test_batch, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        CelebA(args.data, 'test_40_att_list.txt', transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.test_batch, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(test_loader, model, criterion)
        return

    if args.private_test:
        private_loader = torch.utils.data.DataLoader(
            CelebA(args.data, 'testset', transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ]),test=True),
            batch_size=args.test_batch, shuffle=False,
            num_workers=args.workers, pin_memory=True)
        test(private_loader, model, criterion)
        return

    # visualization
    writer = SummaryWriter(os.path.join(args.checkpoint, 'logs'))

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        lr = adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, lr))

        # train for one epoch
        train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        val_loss, prec1 = validate(val_loader, model, criterion)

        # append logger file
        logger.append([lr, train_loss, val_loss, train_acc, prec1])

        # tensorboardX
        writer.add_scalar('learning rate', lr, epoch + 1)
        writer.add_scalars('loss', {'train loss': train_loss, 'validation loss': val_loss}, epoch + 1)
        writer.add_scalars('accuracy', {'train accuracy': train_acc, 'validation accuracy': prec1}, epoch + 1)
        #for name, param in model.named_parameters():
        #    writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch + 1)


        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }, is_best, checkpoint=args.checkpoint)

    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint, 'log.eps'))
    writer.close()

    print('Best accuracy:')
    print(best_prec1)
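
adjust_learning_rate is called above but not shown; a common step-decay implementation it may correspond to (an assumption, not confirmed by the snippet) is:

def adjust_learning_rate(optimizer, epoch, base_lr=0.1, step=30, gamma=0.1):
    # Decay the learning rate by `gamma` every `step` epochs and apply it.
    lr = base_lr * (gamma ** (epoch // step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
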
Example #6
        for f in f_list[-7:]:
            f = os.path.join(path, f)
            transfer_paths.append(f)
        gif_images = []
        for p in transfer_paths:
            gif_images.append(imageio.imread(p))
        gif_path = os.path.join(path, "transfer.gif")
        imageio.mimsave(gif_path, gif_images, duration=0.1)

    def get_middle_vectors(self, src, dst):
        num = 5
        delta = 1 / (num + 1)
        alpha = delta
        result = []
        for i in range(num):
            vec = src * (1 - alpha) + dst * alpha
            result.append(vec)
            alpha += delta
        return result


if __name__ == '__main__':
    path_img = '../samples/celeba/Img/img_align_celeba.zip'
    path_ann = '../samples/celeba/Anno/identity_CelebA.txt'
    path_bbox = '../samples/celeba/Anno/list_bbox_celeba.txt'
    path_sample = '../samples/photos/'
    ca = CelebA(path_img, path_ann, path_bbox)
    # ca.pick_2_samples()
    cfg = MyConfig()
    cfg.from_cmd()
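
get_middle_vectors above is plain linear interpolation between two latent vectors; an equivalent standalone sketch (illustration only) is:

import numpy as np

def middle_vectors(src, dst, num=5):
    # `num` evenly spaced points strictly between src and dst.
    alphas = np.linspace(0.0, 1.0, num + 2)[1:-1]
    return [src * (1 - a) + dst * a for a in alphas]
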
Example #7
                    count = count + 1
                    iterr = count * show_every
                    # Show example output for the generator
                    images_grid = show_generator_output(
                        sess, 25, inp_z, data_shape[2], data_img_mode)
                    dst = os.path.join("output", str(epoch_i),
                                       str(iterr) + ".png")
                    pyplot.imsave(dst, images_grid)

                # saving the model
                if epoch_i % 10 == 0:
                    if not os.path.exists('./model/'):
                        os.makedirs('./model')
                    saver.save(sess, './model/' + str(epoch_i))


# Get the data in a readable format
dataset = CelebA()

# Tensorflow
with tf.Graph().as_default():
    train(cfg.NB_EPOCHS, cfg.BATCH_SIZE, cfg.SIZE_G_INPUT, cfg.LEARNING_RATE,
          cfg.BETA1, dataset.get_batches, dataset.shape, dataset.image_mode)

for f in glob("output/**/*.png"):
    image = cv2.imread(f)
    # cv2.imshow('my_image', image)
    # cv2.waitKey(0)
    # cv2.destroyAllWindows()
    large = cv2.resize(image, (0, 0), fx=3, fy=3)
    cv2.imwrite(f, large)
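
cv2.resize defaults to bilinear interpolation (cv2.INTER_LINEAR); for a GAN sample grid, nearest-neighbour keeps the generated pixels blocky but sharp. A standalone variant of the enlargement step (the path is a placeholder following the layout used above):

import cv2

image = cv2.imread('output/0/10.png')    # placeholder path
large = cv2.resize(image, (0, 0), fx=3, fy=3, interpolation=cv2.INTER_NEAREST)
cv2.imwrite('output/0/10_large.png', large)
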
Example #8
        split, align=True, partition='face'))

# PersonAttributes dataset (clothes partition)
for split in ['train', 'val', 'trainval', 'test']:
    name = 'person_clothes_{}'.format(split)
    __sets[name] = (
        lambda split=split: PersonAttributes(split, partition='clothes'))

# setup DeepFashion dataset
for split in ['train', 'val', 'test', 'trainval']:
    name = 'deepfashion_{}'.format(split)
    __sets[name] = (lambda split=split: DeepFashion(split))
# setup CelebA dataset
for split in ['train', 'val', 'test', 'trainval']:
    name = 'celeba_{}'.format(split)
    __sets[name] = (lambda split=split: CelebA(split))

# setup CelebA (aligned) dataset
for split in ['train', 'val', 'test', 'trainval']:
    name = 'celeba_{}_align'.format(split)
    __sets[name] = (lambda split=split: CelebA(split, align=True))

# setup CelebA+Webcam dataset
for split in ['train', 'val']:
    name = 'celeba_plus_webcam_cls_{}'.format(split)
    __sets[name] = (lambda split=split: CelebA_Plus_Webcam_Cls(split))

# setup IBMattributes dataset
for split in ['train', 'val']:
    name = 'IBMattributes_{}'.format(split)
    __sets[name] = (lambda split=split: IBMAttributes(split))
Example #9
def dataloaders():
    # Data loading code
    # 1. all images are already aligned to 218x178;

    normalize_old = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    # newly computed after data augmentation
    normalize_new = transforms.Normalize([0.4807, 0.3973, 0.3534],
                                         [0.2838, 0.2535, 0.2443])
    normalize = normalize_new

    train_dataset = CelebA(
        args.root,
        'train_attr_list.txt',
        transforms.Compose([
            transforms.RandomRotation(30),
            transforms.RandomHorizontalFlip(),
            #         transforms.Resize((178,178)),
            #         transforms.CenterCrop((178,178)),
            transforms.RandomResizedCrop(178, scale=(0.8, 1.0)),
            # should not crop away much information for multi-label classification
            transforms.ToTensor(),
            normalize,
        ]))

    val_dataset = CelebA(
        args.root,
        'val_attr_list.txt',
        transforms.Compose([
            transforms.Resize((178, 178)),
            #         transforms.CenterCrop((178,178)),
            transforms.ToTensor(),
            normalize,
        ]))

    test_dataset = CelebA(
        args.root,
        'test_attr_list.txt',
        transforms.Compose([
            transforms.Resize((178, 178)),
            #         transforms.CenterCrop((178,178)),
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=7,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=7,
                                             pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=7,
                                              pin_memory=True)
    return train_loader, val_loader, test_loader
Example #10
def main(c: dict):
    global args, best_prec1

    # if using tune, config will overwrite the args
    # TODO: write in a generic way
    # configable_list = ["lr", "arch", "lr_decay", "step", "loss"]
    # for c_name in configable_list:
    #     if c_name in c:
    #         args.

    if "lr" in c:
        args.lr = c["lr"]
    if "arch" in c:
        args.arch = c["arch"]
    if "lr_decay" in c:
        args.lr_decay = c["lr_decay"]
    if "step" in c:
        args.step = c["step"]
    if "loss" in c:
        args.loss = c["loss"]

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)
    # Use CUDA
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    use_cuda = torch.cuda.is_available()

    # Random seed
    if args.manual_seed is None:
        args.manual_seed = random.randint(1, 10000)
    random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    if use_cuda:
        torch.cuda.manual_seed_all(args.manual_seed)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    # elif args.arch.startswith('resnext'):
    #     model = models.__dict__[args.arch](
    #                 baseWidth=args.base_width,
    #                 cardinality=args.cardinality,
    #             )
    elif args.arch.startswith('shufflenet'):
        model = models.__dict__[args.arch](groups=args.groups)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.to(device)  #.cuda()
        else:
            model = torch.nn.DataParallel(model).to(device)  #.cuda()
    else:
        model.to(device)  #.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    if args.loss == "ce":
        criterion = nn.CrossEntropyLoss().to(device)  #.cuda()
    elif args.loss == "focalloss":
        criterion = FocalLoss(device)
    else:
        print("ERROR: ------ Unkown loss !!! ------")

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    title = 'CelebA-' + args.arch
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            args.checkpoint = os.path.dirname(args.resume)
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                            title=title,
                            resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.',
            'Valid Acc.'
        ])

    # Comment to avoid the cudnn cloud pickle error
    # cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = CelebA(
        args.data, 'train_40_att_list.txt',
        transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(CelebA(
        args.data, 'val_40_att_list.txt',
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    test_loader = torch.utils.data.DataLoader(CelebA(
        args.data, args.e_file,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
                                              batch_size=args.test_batch,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    if args.evaluate:
        val_loss, prec1, top1_avg_att = validate(test_loader, model, criterion)
        print(top1_avg_att)
        return

    # visualization
    if not args.tune:
        writer = SummaryWriter(os.path.join(args.checkpoint, 'logs'))

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        lr = adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, lr))

        # train for one epoch
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch)

        # evaluate on validation set
        val_loss, prec1, top1_avg_att = validate(val_loader, model, criterion)

        # append logger file
        logger.append([lr, train_loss, val_loss, train_acc, prec1])

        acc_att_dict = {}
        for i, acc in enumerate(top1_avg_att):
            acc_att_dict[f"att-{i}"] = acc
        # tensorboardX
        if not args.tune:
            writer.add_scalar('learning rate', lr, epoch + 1)
            writer.add_scalars('loss', {
                'train loss': train_loss,
                'validation loss': val_loss
            }, epoch + 1)
            writer.add_scalars('accuracy', {
                'train accuracy': train_acc,
                'validation accuracy': prec1
            }, epoch + 1)
            writer.add_scalars('att-accuracy', acc_att_dict, epoch + 1)
        #for name, param in model.named_parameters():
        #    writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch + 1)

        if args.tune:
            tune.report(train_loss=train_loss,
                        train_acc=train_acc,
                        val_loss=val_loss,
                        prec1=prec1,
                        att_acc=acc_att_dict,
                        lr=lr)

        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            checkpoint=args.checkpoint)

    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint, 'log.eps'))
    if not args.tune:
        writer.close()

    print('Best accuracy:')
    print(best_prec1)
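
Since main(c) reads hyper-parameters from the config dict and reports metrics via tune.report, it can be driven by Ray Tune's legacy tune.run API. A hedged sketch with placeholder search values:

from ray import tune

config = {
    "lr": tune.grid_search([0.01, 0.05, 0.1]),
    "loss": tune.choice(["ce", "focalloss"]),
}
analysis = tune.run(main, config=config, resources_per_trial={"gpu": 1})
print(analysis.get_best_config(metric="prec1", mode="max"))
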
Example #11
def main():
    global args, best_prec1
    args = parser.parse_args()

    # Use CUDA
    # os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    use_cuda = torch.cuda.is_available()
    # Random seed
    if args.manual_seed is None:
        args.manual_seed = random.randint(1, 10000)
    random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    if use_cuda:
        torch.cuda.manual_seed_all(args.manual_seed)

    # create model
    if args.resume == "" and args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    elif args.arch.startswith("resnext"):
        model = models.__dict__[args.arch](
            baseWidth=args.base_width, cardinality=args.cardinality,
        )
    elif args.arch.startswith("shufflenet"):
        model = models.__dict__[args.arch](groups=args.groups)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=False)

    if args.ft:
        for param in model.parameters():
            param.requires_grad = False
        classifier_numbers = model.num_attributes
        # newly constructed classifiers have requires_grad=True
        for i in range(classifier_numbers):
            setattr(
                model,
                "classifier" + str(i).zfill(2),
                nn.Sequential(fc_block(512, 256), nn.Linear(256, 1)),
            )

        args.sampler = "balance"

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
        )

    if not args.distributed:
        if args.arch.startswith("alexnet") or args.arch.startswith("vgg"):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # optionally resume from a checkpoint
    title = "CelebA-" + args.arch
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(  # statistics from CelebA TrainSet
        mean=[0.5084, 0.4287, 0.3879], std=[0.2656, 0.2451, 0.2419]
    )
    print("=> using {} sampler to load data.".format(args.sampler))

    train_dataset = CelebA(
        args.data,
        "train_attr_list.txt",
        transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(size=(256, 256), scale=(0.5, 1.0)),
                transforms.ToTensor(),
                normalize,
                transforms.RandomErasing(),
            ]
        ),
        sampler=args.sampler,
    )

    train_sample_prob = train_dataset._class_sample_prob()

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        if args.rs:
            train_sampler = WeightedRandomSampler(
                1 / train_sample_prob, len(train_dataset)
            )
        else:
            train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
    )

    val_loader = torch.utils.data.DataLoader(
        CelebA(
            args.data,
            "val_attr_list.txt",
            transforms.Compose(
                [transforms.Resize(size=(256, 256)), transforms.ToTensor(), normalize,]
            ),
        ),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    test_loader = torch.utils.data.DataLoader(
        CelebA(
            args.data,
            "test_attr_list.txt",
            transforms.Compose(
                [transforms.Resize(size=(256, 256)), transforms.ToTensor(), normalize,]
            ),
        ),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    lfw_test_loader = torch.utils.data.DataLoader(
        LFW(
            args.data_lfw, transforms.Compose([transforms.ToTensor(), normalize,]),
        ),  # celebA mean variance
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    # define loss function (criterion) and optimizer
    if args.lw:  # loss weight
        print("=> loading CE loss_weight")
        criterion = nn.BCEWithLogitsLoss(
            reduction="mean", weight=1 / torch.sqrt(train_sample_prob)
        ).cuda()
    else:
        # criterion = nn.CrossEntropyLoss().cuda()
        criterion = nn.BCEWithLogitsLoss(reduction="mean").cuda()

    if args.focal:
        print("=> using focal loss")
        criterion = FocalLoss(criterion, balance_param=5)

    print("=> using wd {}".format(args.weight_decay))
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )
            args.checkpoint = os.path.dirname(args.resume)
            logger = Logger(
                os.path.join(args.checkpoint, "log.txt"), title=title, resume=True
            )
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(os.path.join(args.checkpoint, "log.txt"), title=title)
        logger.set_names(
            [
                "Learning Rate",
                "Train Loss",
                "Valid Loss",
                "Train Acc.",
                "Valid Acc.",
                "LFW Loss.",
                "LFW Acc.",
            ]
        )

    if args.evaluate:  # TODO
        validate(test_loader, model, criterion)
        # stat(train_loader)
        return
    if args.validate:  # TODO
        validate(val_loader, model, criterion)
        # stat(train_loader)
        return

    if args.evaluate_lfw:
        validate(val_loader, model, criterion)
        validate(lfw_test_loader, model, criterion)
        return

    # visualization
    writer = SummaryWriter(os.path.join(args.checkpoint, "logs"))

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        lr = adjust_learning_rate(optimizer, epoch)

        print("\nEpoch: [%d | %d] LR: %f" % (epoch + 1, args.epochs, lr))

        # train for one epoch
        train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        val_loss, prec1 = validate(val_loader, model, criterion)

        # evaluate on lfw
        lfw_loss, lfw_prec1 = validate(lfw_test_loader, model, criterion)

        # append logger file
        logger.append([lr, train_loss, val_loss, train_acc, prec1, lfw_loss, lfw_prec1])

        # tensorboardX
        writer.add_scalar("learning rate", lr, epoch + 1)
        writer.add_scalars(
            "loss",
            {
                "train loss": train_loss,
                "validation loss": val_loss,
                "lfw loss": lfw_loss,
            },
            epoch + 1,
        )
        writer.add_scalars(
            "accuracy",
            {
                "train accuracy": train_acc,
                "validation accuracy": prec1,
                "lfw accuracy": lfw_prec1,
            },
            epoch + 1,
        )
        # for name, param in model.named_parameters():
        #    writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch + 1)

        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": args.arch,
                "state_dict": model.state_dict(),
                "best_prec1": best_prec1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            checkpoint=args.checkpoint,
        )

    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint, "log.eps"))
    writer.close()

    print("Best accuracy:")
    print(best_prec1)
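
fc_block is used above when rebuilding the per-attribute classifier heads but is not shown; a common definition it could correspond to (an assumption, not confirmed by the source) is a Linear layer followed by batch norm and ReLU:

import torch.nn as nn

def fc_block(in_features, out_features):
    # Assumed helper: fully connected layer + BatchNorm1d + ReLU.
    return nn.Sequential(
        nn.Linear(in_features, out_features),
        nn.BatchNorm1d(out_features),
        nn.ReLU(inplace=True),
    )
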