Ejemplo n.º 1
0
class FaceGenderAge(object):
    """Keras WideResNet age/gender estimator for a single face crop.

    The pretrained weights are fetched with ``get_file`` into ``./models``
    when the object is constructed.
    """

    WRN_WEIGHTS_PATH = "https://github.com/Tony607/Keras_age_gender/releases/download/V1.0/weights.18-4.06.hdf5"

    def __init__(self, depth=16, width=8, face_size=64):
        """Build the network and load the pretrained weights.

        Args:
            depth: WideResNet depth.
            width: WideResNet widening factor (passed as ``k``).
            face_size: side length, in pixels, faces are resized to before
                inference.
        """
        self.face_size = face_size
        self.model = WideResNet(face_size, depth=depth, k=width)()
        model_dir = os.path.join(os.getcwd(), "models")
        fpath = get_file('weights.18-4.06.hdf5',
                         self.WRN_WEIGHTS_PATH,
                         cache_subdir=model_dir)
        self.model.load_weights(fpath)
        # Private Keras API that pre-builds the predict function. Newer Keras
        # releases no longer provide it, so only call it when it exists.
        make_predict = getattr(self.model, '_make_predict_function', None)
        if make_predict is not None:
            make_predict()

    def predict_agge(self, frame):
        """Return an ``"Age: <int>, Gender: <F|M>"`` label for one face image.

        ``frame`` is resized to ``face_size`` x ``face_size`` and predicted as
        a batch of one. Assumes an H x W x C image array (TODO confirm the
        channel order matches the one used for training).
        """
        inference_frame = cv2.resize(frame, (self.face_size, self.face_size))
        inference_frame = inference_frame[np.newaxis, :, :, :]
        results = self.model.predict(inference_frame)

        predicted_genders = results[0]
        # Weighted sum of ages 0..100 using results[1] (presumably a softmax
        # over age bins), i.e. the expected age for each image in the batch.
        ages = np.arange(0, 101).reshape(101, 1)
        predicted_ages = results[1].dot(ages).flatten()

        # predicted_ages holds one value per image (exactly one here). Index it
        # explicitly: int() on a 1-element array is deprecated since NumPy 1.25.
        label = "Age: {}, Gender: {}".format(
            int(predicted_ages[0]),
            "F" if predicted_genders[0][0] > 0.5 else "M")
        return label

    # Correctly spelled alias; ``predict_agge`` is kept for existing callers.
    predict_age = predict_agge
Ejemplo n.º 2
0
 def load_network(loc):
     """Restore a WideResNet from the checkpoint at ``loc``.

     The conv/block classes are recovered from the checkpoint through
     ``what_conv_block``; depth and width come from the module-level ``args``
     and the head size from ``num_classes``. The network is moved to CUDA.

     Returns:
         ``(net, epoch)`` where ``epoch`` is the checkpoint's ``'epoch'`` entry.
     """
     checkpoint = torch.load(loc)
     epoch = checkpoint['epoch']
     conv_cls, block_cls = what_conv_block(
         checkpoint['conv'], checkpoint['blocktype'], checkpoint['module'])
     net = WideResNet(args.wrn_depth,
                      args.wrn_width,
                      conv_cls,
                      block_cls,
                      num_classes=num_classes,
                      dropRate=0)
     net = net.cuda()
     net.load_state_dict(checkpoint['net'])
     return net, epoch
Ejemplo n.º 3
0
 def __init__(self, depth=16, width=8, face_size=64):
     """Build the Keras WideResNet and load the pretrained age/gender weights.

     Args:
         depth: network depth passed to ``WideResNet``.
         width: widening factor, passed to ``WideResNet`` as ``k``.
         face_size: side length of the square model input.
     """
     self.face_size = face_size
     builder = WideResNet(face_size, depth=depth, k=width)
     self.model = builder()
     # The weights file is fetched (or reused) under ./models.
     weights_file = get_file('weights.18-4.06.hdf5',
                             self.WRN_WEIGHTS_PATH,
                             cache_subdir=os.path.join(os.getcwd(), "models"))
     self.model.load_weights(weights_file)
     # Pre-build the predict function (private Keras API).
     self.model._make_predict_function()
Ejemplo n.º 4
0
    def __init__(self):
        """Pick a device, build the WideResNet and its loss.

        Uses CUDA when available and falls back to the CPU otherwise.
        ``self.optimizer`` starts as ``None``.
        """
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
            print('WARNING: Found no valid GPU device - Running on CPU')

        self.model = WideResNet(DEPTH, cfg.NUM_TRANS)
        # ``.to`` works for both devices; ``.cuda(<cpu device>)`` raises, which
        # made the CPU fallback above unusable.
        self.model.to(self.device)
        self.criterion = torch.nn.CrossEntropyLoss().to(self.device)
        self.optimizer = None
Ejemplo n.º 5
0
def build_network(Conv, Block, network):
    """Instantiate the 10-class backbone named by ``network``.

    Args:
        Conv: convolution class passed to the backbone constructor.
        Block: residual block class (only used for ``'WideResNet'``).
        network: ``'WideResNet'``, ``'WRN_50_2'`` or ``'DARTS'``.

    Returns:
        The constructed network module.

    Raises:
        ValueError: if ``network`` is not one of the supported names
            (previously ``None`` was returned silently).
    """
    if network == 'WideResNet':
        # Depth 28, width 10, no dropout.
        return WideResNet(28, 10, Conv, Block, num_classes=10, dropRate=0)
    if network == 'WRN_50_2':
        return WRN_50_2(Conv)
    if network == 'DARTS':
        return DARTS(Conv, num_classes=10, drop_path_prob=0., auxiliary=False)
    raise ValueError(
        "Unknown network {!r}; expected 'WideResNet', 'WRN_50_2' or 'DARTS'"
        .format(network))
Ejemplo n.º 6
0
    def load_network(loc, masked=False):
        """Rebuild a WideResNet from the checkpoint at ``loc``.

        The conv/block classes come from the checkpoint via
        ``what_conv_block``; depth, width and class count come from the
        enclosing ``args`` / ``num_classes``. The network is moved to CUDA.

        With ``masked=True`` a masked network is built, and the saved tensors
        are copied into its state dict purely by position: the i-th saved
        entry goes into the i-th entry of the new state dict. The key names
        may differ, but the ordering must line up.

        Returns:
            ``(net, epoch)`` where ``epoch`` is the checkpoint's ``'epoch'``.
        """
        checkpoint = torch.load(loc)
        epoch = checkpoint['epoch']
        conv_cls, block_cls = what_conv_block(checkpoint['conv'],
                                              checkpoint['blocktype'],
                                              checkpoint['module'])

        net = WideResNet(args.wrn_depth, args.wrn_width, conv_cls, block_cls,
                         num_classes=num_classes, dropRate=0,
                         masked=masked).cuda()

        if not masked:
            net.load_state_dict(checkpoint['net'])
            return net, epoch

        saved_sd = checkpoint['net']
        saved_keys = list(saved_sd)
        remapped_sd = net.state_dict()
        for position, key in enumerate(list(remapped_sd)):
            remapped_sd[key] = saved_sd[saved_keys[position]]
        net.load_state_dict(remapped_sd)
        return net, epoch
Ejemplo n.º 7
0
    def load_baseline_model(self):
        """Build the baseline model described by ``self.args`` in training mode.

        Only the dataset's input geometry (image size, channels, classes) is
        derived here; no data is loaded. If
        ``self.args.load_baseline_checkpoint`` is set, the model weights are
        restored from that checkpoint's ``'model_state_dict'``.

        When ``self.args.use_weight_decay`` is set, a trainable
        ``model.weight_decay`` tensor initialised to ``INIT_L2`` is attached
        to the model: one entry per parameter element if
        ``self.args.weight_decay_all``, otherwise a single value.

        Returns:
            ``(model, checkpoint)``: the model on ``self.device`` in training
            mode, and the loaded checkpoint dict (or ``None``).

        Raises:
            ValueError: if ``self.args.model`` is not a supported model.
        """
        # An unknown dataset is tolerated here because CNN_MLP does not need
        # the geometry; the other models will fail on the unbound names.
        if self.args.dataset == DATASET_CIFAR_10:
            imsize, in_channel, num_classes = 32, 3, 10
        elif self.args.dataset == DATASET_CIFAR_100:
            imsize, in_channel, num_classes = 32, 3, 100
        elif self.args.dataset == DATASET_MNIST:
            imsize, in_channel, num_classes = 28, 1, 10
        elif self.args.dataset == DATASET_BOSTON:
            imsize, in_channel, num_classes = 13, 1, 1

        # init_l2 = -7  # TODO: Important to make sure this is small enough to be unregularized when starting?
        if self.args.model == MODEL_RESNET18:
            cnn = ResNet18(num_classes=num_classes)
        elif self.args.model == MODEL_WIDERESNET:
            cnn = WideResNet(depth=28,
                             num_classes=num_classes,
                             widen_factor=10,
                             dropRate=0.3)
        elif self.args.model[:3] == MODEL_MLP:
            cnn = Net(self.args.num_layers,
                      0.0,
                      imsize,
                      in_channel,
                      INIT_L2,
                      num_classes=num_classes,
                      do_classification=self.args.do_classification)
        elif self.args.model == MODEL_CNN_MLP:
            cnn = CNN_MLP(learning_rate=0.0001)
        else:
            # Previously this fell through and failed with UnboundLocalError on cnn.
            raise ValueError('Unsupported model: {!r}'.format(self.args.model))

        checkpoint = None
        if self.args.load_baseline_checkpoint:
            checkpoint = torch.load(self.args.load_baseline_checkpoint)
            cnn.load_state_dict(checkpoint['model_state_dict'])

        model = cnn.to(self.device)
        if self.args.use_weight_decay:
            if self.args.weight_decay_all:
                # One decay coefficient per scalar parameter.
                num_p = sum(p.numel() for p in model.parameters())
                initial = torch.FloatTensor(np.ones(num_p) * INIT_L2)
            else:
                # A single shared decay coefficient.
                initial = torch.FloatTensor([INIT_L2])
            # The tensor is already created on self.device, so no further
            # device move is needed.
            model.weight_decay = Variable(initial.to(self.device),
                                          requires_grad=True)
        model.train()
        return model, checkpoint
Ejemplo n.º 8
0
def load_baseline_model(args):
    """Build the data loaders and the baseline CNN described by ``args``.

    :param args: namespace with ``dataset`` ('cifar10', 'cifar100' or
        'mnist'), ``model`` ('resnet18' or 'wideresnet'), ``batch_size``,
        ``data_augmentation`` and ``load_baseline_checkpoint`` (path or falsy).
        For 'mnist', ``datasize``/``valsize``/``testsize`` are overwritten.
    :return: ``(model, train_loader, val_loader, test_loader, checkpoint)``;
        the model is on the GPU in training mode and ``checkpoint`` is the
        loaded checkpoint dict or ``None``.
    :raises ValueError: for an unsupported ``args.dataset`` or ``args.model``.
    """
    if args.dataset == 'cifar10':
        num_classes = 10
        train_loader, val_loader, test_loader = data_loaders.load_cifar10(args.batch_size, val_split=True,
                                                                          augmentation=args.data_augmentation)
    elif args.dataset == 'cifar100':
        num_classes = 100
        train_loader, val_loader, test_loader = data_loaders.load_cifar100(args.batch_size, val_split=True,
                                                                           augmentation=args.data_augmentation)
    elif args.dataset == 'mnist':
        # MNIST has 10 digit classes. Without this, the model constructors
        # below failed on an unbound num_classes.
        num_classes = 10
        # NOTE(review): these hard-coded sizes override any CLI values, which
        # makes the ``datasize == -1`` check below unreachable - confirm intent.
        args.datasize, args.valsize, args.testsize = 100, 100, 100
        num_train = args.datasize
        if args.datasize == -1:
            num_train = 50000

        from data_loaders import load_mnist
        train_loader, val_loader, test_loader = load_mnist(args.batch_size,
                                                           subset=[args.datasize, args.valsize, args.testsize],
                                                           num_train=num_train)
    else:
        raise ValueError('Unsupported dataset: {!r}'.format(args.dataset))

    if args.model == 'resnet18':
        cnn = ResNet18(num_classes=num_classes)
    elif args.model == 'wideresnet':
        cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10, dropRate=0.3)
    else:
        raise ValueError('Unsupported model: {!r}'.format(args.model))

    checkpoint = None
    if args.load_baseline_checkpoint:
        checkpoint = torch.load(args.load_baseline_checkpoint)
        cnn.load_state_dict(checkpoint['model_state_dict'])

    model = cnn.cuda()
    model.train()
    return model, train_loader, val_loader, test_loader, checkpoint
Ejemplo n.º 9
0
 def build_network(Conv, Block):
     """Instantiate the backbone selected by ``args.network``.

     Reads the module-level ``args`` (``network``, ``wrn_depth``,
     ``wrn_width``, ``AT_split``) and ``num_classes``.

     Raises:
         ValueError: if ``args.network`` is not a supported backbone
             (previously ``None`` was returned silently).
     """
     if args.network == 'WideResNet':
         return WideResNet(args.wrn_depth,
                           args.wrn_width,
                           Conv,
                           Block,
                           num_classes=num_classes,
                           dropRate=0,
                           s=args.AT_split)
     if args.network == 'WRN_50_2':
         return WRN_50_2(Conv)
     if args.network == 'MobileNetV2':
         return MobileNetV2(Conv)
     if args.network == 'DARTS':
         return DARTS(Conv, num_classes=num_classes)
     raise ValueError('Unknown network {!r}'.format(args.network))
Ejemplo n.º 10
0
 def build_network(Conv, Block):
     """Instantiate the backbone selected by ``args.network``.

     Reads the module-level ``args`` and ``num_classes``. For WideResNet,
     spectral mode is enabled only when ``args.spectral`` is set and neither
     ``args.rank_scale`` nor ``args.target_ratio`` is.

     Raises:
         ValueError: if ``args.network`` is not a supported backbone
             (previously ``None`` was returned silently).
     """
     if args.network == 'WideResNet':
         return WideResNet(args.wrn_depth,
                           args.wrn_width,
                           Conv,
                           Block,
                           num_classes=num_classes,
                           dropRate=0,
                           s=args.AT_split,
                           spectral=args.spectral
                           and not (args.rank_scale or args.target_ratio))
     if args.network == 'WRN_50_2':
         return WRN_50_2(Conv)
     if args.network == 'MobileNetV2':
         return MobileNetV2(Conv)
     if args.network == 'DARTS':
         return DARTS(Conv, num_classes=num_classes)
     raise ValueError('Unknown network {!r}'.format(args.network))
Ejemplo n.º 11
0
def random_conv_sub(net, param_target):
    """Randomly resample per-module conv options until ``net`` fits a budget.

    A net is accepted when its parameter count (``get_no_params``) is at most
    ``param_target`` and at least 95% of it. While it is outside that window,
    one option from ``get_ordered_conv_options()`` is drawn per module of the
    current net, and a new masked WideResNet is built from those draws.

    Returns:
        The list of conv options used for the last network built.

    NOTE(review): if ``net`` already lies inside the window, the loop never
    runs and ``convs`` is unbound (UnboundLocalError) - confirm the intended
    result for that case.
    """
    conv_options = get_ordered_conv_options()
    tolerance = param_target * 0.05

    def outside_budget(candidate):
        # The lower bound is only checked (and counted) when the upper bound
        # holds, matching a short-circuiting ``or``.
        if get_no_params(candidate) > param_target:
            return True
        return get_no_params(candidate) < param_target - tolerance

    while outside_budget(net):
        convs = [choice(conv_options) for _ in net.modules()]
        net = WideResNet(args.wrn_depth,
                         args.wrn_width,
                         Conv,
                         MaskBlock,
                         num_classes=num_classes,
                         dropRate=0,
                         s=args.AT_split,
                         convs=convs).cuda()
    return convs
Ejemplo n.º 12
0
 def build_network(Conv, Block):
     """Instantiate the backbone selected by ``args.network``.

     Reads the module-level ``args`` (``network``, ``wrn_depth``,
     ``wrn_width``, ``conv``) and ``num_classes``.

     Raises:
         ValueError: if ``args.network`` is unsupported, or if DARTS is
             requested together with ``args.conv == 'Conv'``.
     """
     if args.network == 'WideResNet':
         return WideResNet(args.wrn_depth,
                           args.wrn_width,
                           Conv,
                           Block,
                           num_classes=num_classes,
                           dropRate=0)
     if args.network == 'WRN_50_2':
         return WRN_50_2(Conv)
     if args.network == 'DARTS':
         # Explicit check instead of ``assert`` so it still runs under ``python -O``.
         if args.conv == 'Conv':
             raise ValueError('The base network here used'
                              ' separable convolutions, so you probably did not mean to set this'
                              ' option.')
         return DARTS(Conv,
                      num_classes=num_classes,
                      drop_path_prob=0.,
                      auxiliary=False)
     if args.network == 'MobileNetV2':
         return MobileNetV2(Conv)
     raise ValueError('Unknown network {!r}'.format(args.network))
Ejemplo n.º 13
0
    total_nonzero = 0.0
    # With the valid epsilon, we can set sparsities of the remaning layers.
    for name, mask in masks.items():
        n_param = np.prod(mask.shape)
        if name in dense_layers:
            density_dict[name] = 1.0
        else:
            probability_one = epsilon * raw_probabilities[name]
            density_dict[name] = probability_one
        logging.info(
            f"layer: {name}, shape: {mask.shape}, density: {density_dict[name]}"
        )
        total_nonzero += density_dict[name] * mask.numel()
    logging.info(f"Overall sparsity {total_nonzero/total_params}")
    return density_dict


if __name__ == "__main__":
    # Smoke test: build a small WideResNet and run both ERK density
    # allocators on it at an overall density of 0.2 (tim_ERK/googleAI_ERK are
    # defined elsewhere; presumably they log their per-layer densities).
    from models.wide_resnet import WideResNet

    model = WideResNet(depth=22, widen_factor=2)

    # Attach a default stderr handler to the root logger and lower its level
    # so INFO records are shown.
    logging.basicConfig()
    logging.getLogger().setLevel(logging.INFO)

    tim_ERK(model, density=0.2)

    # Separator between the two allocators' output.
    logging.info("========")

    googleAI_ERK(model, density=0.2)
# Label tables for each classifier (get_labels is defined elsewhere;
# presumably it maps class index -> display name for the named dataset).
emotion_labels = get_labels('fer2013')
gender_labels = get_labels('imdb')
# NOTE(review): race_labels is not used by the visible code (the race model is
# commented out below) - confirm whether it is still needed.
race_labels = get_labels('race')
font = cv2.FONT_HERSHEY_SIMPLEX

# hyper-parameters for bounding boxes shape
# frame_window: presumably the number of recent predictions kept in the
# *_window lists below for the mode calculation.
frame_window = 10
gender_offsets = (30, 60)
emotion_offsets = (20, 40)

# loading models
face_detection = load_detection_model(detection_model_path)
emotion_classifier = load_model(emotion_model_path, compile=False)
# race_classifier = load_model(race_model_path, compile=False)
# WideResNet(...) returns a builder; calling it produces the Keras model
# (input size 64, depth 16, widen factor 8).
age_gender_classifier = WideResNet(64, depth=16, k=8)()
age_gender_classifier.load_weights(age_gender_model_path)

# getting input model shapes for inference
# [1:3] drops the batch axis; assumes channels-last (batch, H, W, C) inputs,
# giving (H, W) - TODO confirm.
emotion_target_size = emotion_classifier.input_shape[1:3]
age_gender_target_size = age_gender_classifier.input_shape[1:3]
# race_target_size = race_classifier.input_shape[1:3]

# starting lists for calculating modes
gender_window = []
emotion_window = []

# starting video streaming
cv2.namedWindow('window_frame')
video_capture = cv2.VideoCapture(video_path)
test_transform = transforms.Compose([transforms.ToTensor(), normalize])

# Dataset: sets num_classes and the train/val/test loaders.
if args.dataset == 'cifar10':
    num_classes = 10
    train_loader, val_loader, test_loader = data_loaders.load_cifar10(
        args.batch_size, val_split=True, augmentation=args.data_augmentation)
elif args.dataset == 'cifar100':
    num_classes = 100
    train_loader, val_loader, test_loader = data_loaders.load_cifar100(
        args.batch_size, val_split=True, augmentation=args.data_augmentation)
else:
    # Fail here with a clear message instead of a NameError on num_classes below.
    raise ValueError('Unsupported dataset: {!r}'.format(args.dataset))

# Model: WideResNet is depth 28, widen factor 10, dropout 0.3.
if args.model == 'resnet18':
    cnn = ResNet18(num_classes=num_classes)
elif args.model == 'wideresnet':
    cnn = WideResNet(depth=28,
                     num_classes=num_classes,
                     widen_factor=10,
                     dropRate=0.3)
else:
    raise ValueError('Unsupported model: {!r}'.format(args.model))

cnn = cnn.cuda()
criterion = nn.CrossEntropyLoss().cuda()

# 'sgdm': SGD with Nesterov momentum 0.9; 'sgd': plain SGD. Both use args.wdecay.
if args.optimizer == 'sgdm':
    cnn_optimizer = torch.optim.SGD(cnn.parameters(),
                                    lr=args.lr,
                                    momentum=0.9,
                                    nesterov=True,
                                    weight_decay=args.wdecay)
elif args.optimizer == 'sgd':
    cnn_optimizer = torch.optim.SGD(cnn.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.wdecay)
# NOTE(review): any other args.optimizer leaves cnn_optimizer undefined -
# confirm whether further optimizers are handled after this point.
Ejemplo n.º 16
0
class NeuralNet:
    """Train / test / score wrapper around a WideResNet transformation classifier.

    The model is built as ``WideResNet(DEPTH, cfg.NUM_TRANS)`` and is called
    with ``[inputs, targets]``. Its first output is the per-class prediction
    used for the loss and accuracy.
    """

    def __init__(self):
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
            print('WARNING: Found no valid GPU device - Running on CPU')

        self.model = WideResNet(DEPTH, cfg.NUM_TRANS)
        # ``.to`` handles both devices; ``.cuda(<cpu device>)`` raises, which
        # made the CPU fallback above unusable.
        self.model.to(self.device)
        self.criterion = torch.nn.CrossEntropyLoss().to(self.device)
        self.optimizer = None

    def _to_device(self, tensor):
        """Move a batch tensor to ``self.device`` (non-blocking when possible)."""
        return tensor.to(self.device, non_blocking=True)

    def test(self, test_gen):
        """Evaluate on ``test_gen`` and return the mean top-1 accuracy (percent).

        Args:
            test_gen: sized iterable of ``(inputs, targets, _)`` batches.
        """
        batch_time = self.AverageMeter('Time', ':6.3f')
        losses = self.AverageMeter('Loss', ':.4e')
        top1 = self.AverageMeter('Acc@1', ':6.2f')
        progress = self.ProgressMeter(len(test_gen),
                                      batch_time,
                                      losses,
                                      top1,
                                      prefix='Test: ')

        # switch to evaluate mode
        self.model.eval()

        with torch.no_grad():
            end = time.time()
            for i, (inputs, target, _) in enumerate(test_gen):
                inputs = self._to_device(inputs)
                target = self._to_device(target)

                # Compute output
                output, _ = self.model([inputs, target])
                loss = self.criterion(output, target)

                # Measure accuracy and record loss
                acc1 = self._accuracy(output, target, topk=(1,))
                losses.update(loss.item(), inputs.size(0))
                top1.update(acc1[0].item(), inputs.size(0))

                # Measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # Print to screen
                if i % 100 == 0:
                    progress.print(i)

            # TODO: this should also be done with the ProgressMeter
            print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))

        return top1.avg

    def train(self,
              train_gen,
              test_gen,
              epochs,
              lr=0.0001,
              lr_plan=None,
              momentum=0.9,
              wd=5e-4):
        """Train with SGD for ``epochs`` epochs, testing after each one.

        Args:
            lr_plan: optional ``{epoch: lr}`` mapping applied at the start of
                the matching epoch.
        """
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=lr,
                                         momentum=momentum,
                                         weight_decay=wd)

        for epoch in range(epochs):
            self._adjust_lr_rate(self.optimizer, epoch, lr_plan)
            print("=> Training (specific label)")
            self._train_step(train_gen, epoch, self.optimizer)
            print("=> Validation (entire dataset)")
            self.test(test_gen)

    def _train_step(self, train_gen, epoch, optimizer):
        """Run one training epoch over ``train_gen`` using ``optimizer``."""
        self.model.train()

        batch_time = self.AverageMeter('Time', ':6.3f')
        data_time = self.AverageMeter('Data', ':6.3f')
        losses = self.AverageMeter('Loss', ':.4e')
        top1 = self.AverageMeter('Acc@1', ':6.2f')
        progress = self.ProgressMeter(len(train_gen),
                                      batch_time,
                                      data_time,
                                      losses,
                                      top1,
                                      prefix="Epoch: [{}]".format(epoch))

        end = time.time()
        for i, (inputs, target, _) in enumerate(train_gen):
            # measure data loading time
            data_time.update(time.time() - end)

            inputs = self._to_device(inputs)
            target = self._to_device(target)

            # Compute output (the second model output is not used here)
            output, _ = self.model([inputs, target])
            loss = self.criterion(output, target)

            # measure accuracy and record loss
            acc1 = self._accuracy(output, target, topk=(1,))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1[0].item(), inputs.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 100 == 0:
                progress.print(i)

    def evaluate(self, eval_gen):
        """Return per-sample scores and true labels for ``eval_gen``.

        Each batch is assumed to hold ``cfg.NUM_TRANS`` consecutive
        transformed copies of every sample: ``target`` is the applied
        transformation index, and ``labels`` is reduced to every
        NUM_TRANS-th entry. A sample's score is the mean, over its
        transformations, of the model output at the applied transformation's
        index.

        Returns:
            ``(scores, labels)`` concatenated over all batches.
        """
        # switch to evaluate mode
        self.model.eval()

        score_func_list = []
        labels_list = []
        with torch.no_grad():
            for inputs, target, labels in eval_gen:
                inputs = self._to_device(inputs)
                # Target - the transformation class
                target = self._to_device(target)
                # The true label, one per original sample
                first_of_group = [cfg.NUM_TRANS * x
                                  for x in range(len(labels) // cfg.NUM_TRANS)]
                labels = self._to_device(labels[first_of_group])

                outputs = self.model([inputs, target])[0]
                # Output at each row's target index; equivalent to multiplying
                # by a one-hot mask and summing over classes.
                target_scores = outputs.gather(1, target.view(-1, 1).long()).squeeze(1)
                target_SM = target_scores.view(-1, cfg.NUM_TRANS).sum(dim=1)

                score_func_list.append(1 / cfg.NUM_TRANS * target_SM)
                labels_list.append(labels)

        return torch.cat(score_func_list), torch.cat(labels_list)

    def _adjust_lr_rate(self, optimizer, epoch, lr_dict):
        """Set every param group's lr to ``lr_dict[epoch]`` if ``epoch`` is a key."""
        if lr_dict is None or epoch not in lr_dict:
            return

        value = lr_dict[epoch]
        print("=> New learning rate set of {}".format(value))
        for param_group in optimizer.param_groups:
            param_group['lr'] = value

    def summary(self, x_size, print_it=True):
        return self.model.summary(x_size, print_it=print_it)

    def print_weights(self):
        self.model.print_weights()

    @staticmethod
    def _accuracy(output, target, topk=(1,)):
        """Compute the top-k accuracy (percent) for each k in ``topk``.

        Args:
            output: scores, assumed shaped (batch, num_classes).
            target: class indices, assumed shaped (batch,).
            topk: an int or an iterable of ints. Previously this argument was
                ignored and only top-1 was computed.

        Returns:
            A list of 1-element tensors, one per k, in the order of ``topk``.
        """
        if isinstance(topk, int):
            topk = (topk,)
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = []
            for k in topk:
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res.append(correct_k.mul_(100.0 / batch_size))
            return res

    class AverageMeter(object):
        """Computes and stores the average and current value"""
        def __init__(self, name, fmt=':f'):
            self.name = name
            self.fmt = fmt
            self.reset()

        def reset(self):
            self.val = 0
            self.avg = 0
            self.sum = 0
            self.count = 0

        def update(self, val, n=1):
            self.val = val
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count

        def __str__(self):
            fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
            return fmtstr.format(**self.__dict__)

    class ProgressMeter(object):
        """Prints ``prefix[batch/total]`` followed by each meter's summary."""
        def __init__(self, num_batches, *meters, prefix=""):
            self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
            self.meters = meters
            self.prefix = prefix

        def print(self, batch):
            entries = [self.prefix + self.batch_fmtstr.format(batch)]
            entries += [str(meter) for meter in self.meters]
            print('\t'.join(entries))

        def _get_batch_fmtstr(self, num_batches):
            # Pad the batch index to the width of the total, e.g. '[{:3d}/250]'.
            num_digits = len(str(num_batches))
            fmt = '{:' + str(num_digits) + 'd}'
            return '[' + fmt + '/' + fmt.format(num_batches) + ']'
                                               num_workers=4)
    test_set = torchvision.datasets.FashionMNIST(root="/mnt/ds/ryo",
                                                 train=False,
                                                 download=True,
                                                 transform=transform_test)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=100,
                                              shuffle=False,
                                              num_workers=4)
    num_class = 10

############## parameter ###################
# Model selection.
if str(args.model) == "resnet18":
    model = ResNet18(num_class).to(device)
elif str(args.model) == "wide-resnet28-10":
    # WRN-28-10; the third positional argument (0) is presumably the dropout
    # rate - TODO confirm against this project's WideResNet signature.
    model = WideResNet(28, 10, 0, in_channels=3, labels=num_class).to(device)
else:
    # Otherwise ``model`` would be undefined when the optimizer is built.
    raise ValueError("Unsupported model: {!r}".format(args.model))

# optimizing
if str(args.optimizer) == "SGD":
    optimizer = optim.SGD(model.parameters(), lr=LEANRATE, momentum=0)
elif str(args.optimizer) == "SAM":
    # SAM wraps the base optimizer class; rho is its neighbourhood radius.
    base_optimizer = torch.optim.SGD
    optimizer = SAM(model.parameters(),
                    base_optimizer,
                    rho=0.05,
                    lr=LEANRATE,
                    momentum=0,
                    weight_decay=0.0005)
# NOTE(review): any other args.optimizer leaves ``optimizer`` undefined.

criterion = nn.CrossEntropyLoss()
Ejemplo n.º 18
0
        net_checkpoint = torch.load(loc)
        start_epoch = net_checkpoint['epoch']
        SavedConv, SavedBlock = what_conv_block(net_checkpoint['conv'],
                net_checkpoint['blocktype'], net_checkpoint['module'])
        net = WideResNet(args.wrn_depth, args.wrn_width, SavedConv, SavedBlock, num_classes=num_classes, dropRate=0).cuda()
        net.load_state_dict(net_checkpoint['net'])
        return net, start_epoch

    if args.mode == 'teacher':

        if args.resume:
            print('Mode Teacher: Loading teacher and continuing training...')
            teach, start_epoch = load_network('checkpoints/%s.t7' % args.teacher_checkpoint)
        else:
            print('Mode Teacher: Making a teacher network from scratch and training it...')
            teach = WideResNet(args.wrn_depth, args.wrn_width, Conv, Block, num_classes=num_classes, dropRate=0).cuda()


        get_no_params(teach)
        optimizer = optim.SGD(teach.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weightDecay)
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=epoch_step, gamma=args.lr_decay_ratio)

        # Decay the learning rate depending on the epoch
        for e in range(0,start_epoch):
            scheduler.step()

        for epoch in tqdm(range(start_epoch, args.epochs)):
            scheduler.step()
            print('Teacher Epoch %d:' % epoch)
            print('Learning rate is %s' % [v['lr'] for v in optimizer.param_groups][0])
            writer.add_scalar('learning_rate', [v['lr'] for v in optimizer.param_groups][0], epoch)
Ejemplo n.º 19
0
            net.load_state_dict(net_checkpoint['net'])
        return net, start_epoch

    if args.mode == 'teacher':

        if args.resume:
            print('Mode Teacher: Loading teacher and continuing training...')
            teach, start_epoch = load_network('checkpoints/%s.t7' %
                                              args.teacher_checkpoint)
        else:
            print(
                'Mode Teacher: Making a teacher network from scratch and training it...'
            )
            teach = WideResNet(args.wrn_depth,
                               args.wrn_width,
                               Conv,
                               Block,
                               num_classes=num_classes,
                               dropRate=0).cuda()

        get_no_params(teach)
        optimizer = optim.SGD(teach.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=args.weightDecay)
        scheduler = lr_scheduler.MultiStepLR(optimizer,
                                             milestones=epoch_step,
                                             gamma=args.lr_decay_ratio)

        # Decay the learning rate depending on the epoch
        for e in range(0, start_epoch):
            scheduler.step()
Ejemplo n.º 20
0
    def detection(self):
        """Swap the apparent gender of every face found in ``self.fname``.

        Pipeline:
        1. Detect faces with dlib.
        2. Label each face 'F' or 'M' with the Keras WideResNet age/gender
           model (``models/weights.hdf5``).
        3. Run the opposite-gender CycleGAN generator: ``netG_A2B.pth`` for
           faces labelled 'M', ``netG_B2A.pth`` for 'F'.
        4. Poisson-blend each generated face back with ``cv2.seamlessClone``.

        The composite image is stored in ``self.out_img``. If no face is
        detected, ``self.out_img`` is left unchanged.
        """
        depth = 16
        width = 8
        img_size = 64
        model = WideResNet(img_size, depth=depth, k=width)()
        model.load_weights(r'models/weights.hdf5')

        detector = dlib.get_frontal_face_detector()

        # np.fromfile + imdecode rather than cv2.imread - presumably so that
        # non-ASCII file paths can be opened.
        image_np = cv2.imdecode(np.fromfile(self.fname, dtype=np.uint8), -1)

        img_h = image_np.shape[0]
        img_w = image_np.shape[1]

        detected = detector(image_np, 1)

        gender_faces = []
        original_faces = []
        photo_position = []

        change_male_to_female_path = r'models/netG_A2B.pth'
        change_female_to_male_path = r'models/netG_B2A.pth'

        # Load the CycleGAN generators (weights mapped to the CPU).
        netG_male2female = Generator(3, 3)
        netG_female2male = Generator(3, 3)

        netG_male2female.load_state_dict(
            torch.load(change_male_to_female_path, map_location='cpu'))
        netG_female2male.load_state_dict(
            torch.load(change_female_to_male_path, map_location='cpu'))

        # Put the generators in inference mode.
        netG_male2female.eval()
        netG_female2male.eval()

        # H x W x C uint8 image -> C x H x W float tensor normalised to [-1, 1].
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        if len(detected) == 0:
            return

        for d in detected:
            # (w, h) start as the detector box size. After squaring they are
            # the crop size in the ORIGINAL image, used at the end to resize
            # the generated face before pasting it back.
            x0, y0, x1, y1 = d.left(), d.top(), d.right(), d.bottom()
            w, h = d.width(), d.height()

            # Enlarge the box by 25% of the width on each side, 45% of the
            # height above and 5% below, clamped to the image.
            x0 = max(int(x0 - 0.25 * w), 0)
            y0 = max(int(y0 - 0.45 * h), 0)
            x1 = min(int(x1 + 0.25 * w), img_w - 1)
            y1 = min(int(y1 + 0.05 * h), img_h - 1)
            w = x1 - x0
            h = y1 - y0
            # Trim the longer side around its centre so the crop is square.
            if w > h:
                x0 = x0 + w // 2 - h // 2
                w = h
                x1 = x0 + w
            else:
                y0 = y0 + h // 2 - w // 2
                h = w
                y1 = y0 + h

            face = image_np[y0:y1, x0:x1, :]
            original_faces.append(cv2.resize(face, (256, 256)))
            gender_faces.append(cv2.resize(face, (img_size, img_size)))
            photo_position.append([y0, y1, x0, x1, w, h])

        results = model.predict(np.array(gender_faces))
        predicted_genders = results[0]
        # A face whose gender output column 0 exceeds 0.5 is labelled 'F'.
        labels = ['F' if predicted_genders[i][0] > 0.5 else 'M'
                  for i in range(len(original_faces))]

        # seamlessClone returns a new image instead of editing ``dst`` in
        # place, so each face is blended into the running result. Blending
        # into ``image_np`` every time kept only the last face.
        result = image_np
        for (y0, y1, x0, x1, w, h), face, gender in zip(photo_position,
                                                        original_faces, labels):
            input_picture = transform(face).unsqueeze(0)

            with torch.no_grad():
                if gender == "M":
                    generated = netG_male2female(input_picture)
                else:
                    generated = netG_female2male(input_picture)

            # The generator output is assumed to lie in [-1, 1], like its
            # input. Map it to [0, 1] exactly once, then to [0, 255]. The old
            # code applied the [-1, 1] -> [0, 1] shift twice, squeezing the
            # face into [127.5, 255].
            out_img = 0.5 * (generated.squeeze(0) + 1.0)
            image_numpy = np.transpose(out_img.float().numpy(), (1, 2, 0)) * 255.0
            image_numpy = image_numpy.clip(0, 255).astype(np.uint8)

            generate_face = cv2.resize(image_numpy, (w, h))

            # All-white mask: replace the whole generated square. A tighter
            # face mask would give a more precise replacement. NumPy shapes
            # are (rows, cols) = (h, w).
            mask = 255 * np.ones((h, w), image_np.dtype)
            # Centre of the target region on the background, in (x, y) order.
            center = (x0 + w // 2, y0 + h // 2)
            result = cv2.seamlessClone(generate_face, result, mask, center,
                                       cv2.NORMAL_CLONE)

        self.out_img = result
def _make_loader(args, listfile, train):
    """Build a DataLoader over the images listed in ``listfile``.

    Both pipelines resize to 256, crop to 224, convert to a tensor and
    normalize with ImageNet mean/std. The training pipeline uses a random crop
    plus random horizontal flip and shuffles; the validation pipeline uses a
    center crop and keeps list order.

    Args:
        args: parsed options; reads ``root_folder``, ``batch_size`` and
            ``num_workers``.
        listfile: list file passed to ``ImageList``.
        train: select the training (augmenting, shuffled) pipeline.

    Returns:
        A ``torch.utils.data.DataLoader``.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    if train:
        crop = [transforms.RandomCrop(224), transforms.RandomHorizontalFlip()]
    else:
        crop = [transforms.CenterCrop(224)]
    pipeline = transforms.Compose([transforms.Resize(256)] + crop + [
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])
    dataset = ImageList(args.root_folder, listfile, transform=pipeline)
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=args.batch_size,
                                       shuffle=train,
                                       pin_memory=False,
                                       num_workers=args.num_workers)


def main(args):
    """Fine-tune a WideResNet (optionally with attention) classifier.

    Each epoch trains once over ``args.train_listfile`` and then evaluates on
    ``args.val_listfile``. Per-epoch metrics are logged in a ``state`` dict
    through ``JsonLogger``. Whenever validation accuracy improves and
    ``args.save`` is set, the model weights are saved to
    ``<exp_dir>/model.pytorch``. Learning rates are divided by 10 after every
    epoch listed in ``args.schedule``.

    Args:
        args: parsed command-line options (argparse-style namespace).
    """
    harakiri = Harakiri()
    harakiri.set_max_plateau(20)
    log = JsonLogger(args.log_path, rand_folder=True)
    log.update(args.__dict__)
    # Copy so per-epoch bookkeeping (lr decay, metrics) does not mutate the
    # caller's ``args`` namespace.
    state = dict(args.__dict__)
    state['exp_dir'] = os.path.dirname(log.path)
    state['start_lr'] = state['lr']
    print(state)

    train_loader = _make_loader(args, args.train_listfile, train=True)
    val_loader = _make_loader(args, args.val_listfile, train=False)

    if args.attention_depth == 0:
        from models.wide_resnet import WideResNet
        model = WideResNet().finetune(args.nlabels)
    else:
        from models.wide_resnet_attention import WideResNetAttention
        model = WideResNetAttention(args.nlabels, args.attention_depth,
                                    args.attention_width, args.has_gates,
                                    args.reg_weight).finetune(args.nlabels)
    # Move to the GPU *before* building the optimizer, as recommended by the
    # torch.optim docs, so the optimizer references the CUDA parameters.
    model = model.cuda()

    # The pretrained backbone is fine-tuned at 1/10 of the classifier's lr.
    optimizer = optim.SGD([
        {'params': model.get_base_params(), 'lr': args.lr * 0.1},
        {'params': model.get_classifier_params()},
    ], lr=args.lr, weight_decay=1e-4, momentum=0.9, nesterov=True)

    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, list(range(args.ngpu))).cuda()

    def train():
        """Run one training epoch and store the mean loss in state['train_loss'].

        The loss is averaged over samples seen in *this* epoch only.
        """
        model.train()
        loss_sum, seen = 0.0, 0
        for data, label in train_loader:
            data = data.cuda(non_blocking=True)
            label = label.cuda()
            optimizer.zero_grad()
            if args.attention_depth > 0:
                # Attention models also return a regularization term.
                # mean() presumably reduces per-replica values when wrapped in
                # DataParallel -- TODO confirm against the model definition.
                output, reg_loss = model(data)
                loss = reg_loss.mean() if args.reg_weight > 0 else 0
            else:
                output = model(data)
                loss = 0
            loss = loss + F.nll_loss(output, label)
            loss.backward()
            optimizer.step()
            batch = data.size(0)
            loss_sum += loss.item() * batch
            seen += batch
        state['train_loss'] = loss_sum / max(seen, 1)

    def val():
        """Evaluate once on the validation set.

        Stores the per-epoch mean NLL loss in state['val_loss'] and the
        fraction of correctly classified samples in state['val_accuracy'].
        """
        model.eval()
        loss_sum, correct, seen = 0.0, 0, 0
        with torch.no_grad():
            for data, label in val_loader:
                data = data.cuda(non_blocking=True)
                label = label.cuda()
                if args.attention_depth > 0:
                    output, _ = model(data)  # regularizer unused at eval time
                else:
                    output = model(data)
                batch = data.size(0)
                loss_sum += F.nll_loss(output, label).item() * batch
                preds = output.max(1)[1]
                correct += (preds == label).sum().item()
                seen += batch
        state['val_loss'] = loss_sum / max(seen, 1)
        state['val_accuracy'] = float(correct) / max(seen, 1)

    best_accuracy = 0
    counter = 0  # epochs since the last validation-accuracy improvement
    for epoch in range(args.epochs):
        train()
        val()
        harakiri.update(epoch, state['val_accuracy'])
        if state['val_accuracy'] > best_accuracy:
            counter = 0
            best_accuracy = state['val_accuracy']
            if args.save:
                torch.save(model.state_dict(),
                           os.path.join(state["exp_dir"], "model.pytorch"))
        else:
            counter += 1
        state['epoch'] = epoch + 1
        log.update(state)
        print(state)
        if (epoch + 1) in args.schedule:
            # Step decay: divide every param group's lr by 10.
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
            state['lr'] *= 0.1