def create_model(weights):
    """Build a RetinaFace (ResNet-50, test phase) model and load *weights*.

    The checkpoint may come from DataParallel training, so any leading
    'module.' prefix is stripped from parameter names before loading.
    """
    model = RetinaFace(cfg=cfg_re50, phase='test')

    raw_state = load(weights, map_location='cpu')
    cleaned = {name.replace('module.', ''): tensor
               for name, tensor in raw_state.items()}

    model.load_state_dict(cleaned)
    return model
Example #2
0
# Destination for checkpoints/logs.
save_folder = args.save_folder

if args.resume_net is not None:
    print('Loading resume network...')
    state_dict = torch.load(args.resume_net)
    # Strip the 'module.' prefix that DataParallel prepends to parameter names.
    from collections import OrderedDict
    new_state_dict = OrderedDict(
        (k[7:] if k[:7] == 'module.' else k, v)
        for k, v in state_dict.items())
    net.load_state_dict(new_state_dict)

# Wrap in DataParallel only when several GPUs are available and enabled.
if num_gpu > 1 and gpu_train:
    net = torch.nn.DataParallel(net).cuda()
else:
    net = net.cuda()

cudnn.benchmark = True

# SGD with weight decay; loss couples localization, classification and
# landmark terms (0.35 IoU thresholds, 7:1 hard-negative mining ratio).
optimizer = optim.SGD(net.parameters(),
                      lr=initial_lr,
                      momentum=momentum,
                      weight_decay=weight_decay)
criterion = MultiBoxLoss(num_classes, 0.35, True, 0, True, 7, 0.35, False)

priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))
def run(args):
    """Run mobilenet-0.25 RetinaFace over the image list in
    ``args.test_list_dir``, writing one detection .txt per image plus a
    global 'vis_bbox.txt' under ``args.save_folder``.
    """
    # net and load
    cfg = cfg_mnet
    net = RetinaFace(cfg=cfg, phase='test')
    new_state_dict = load_normal(args.trained_model)
    net.load_state_dict(new_state_dict)
    print('Finished loading model!')
    print(net)

    torch.set_grad_enabled(False)

    device = torch.device("cpu" if args.cpu else "cuda")
    net = net.to(device)
    # BUG FIX: the dummy profiling input previously called .cuda()
    # unconditionally, which crashed when --cpu was requested.  (Also renamed
    # from `input`, which shadowed the builtin.)
    dummy = torch.randn(1, 3, 270, 480).to(device)
    flops, params = profile(net, inputs=(dummy, ))
    print('flops:', flops, 'params:', params)

    # testing dataset: a whitespace-separated list of image paths
    with open(args.test_list_dir, 'r') as fr:
        test_dataset = fr.read().split()
    test_dataset.sort()

    _t = {'forward_pass': Timer(), 'misc': Timer()}

    # testing begin
    if not os.path.isdir(args.save_folder):
        os.makedirs(args.save_folder)
    net.eval()
    # BUG FIX: use a context manager so vis_bbox.txt is closed even if an
    # image fails mid-loop (the handle previously leaked on exceptions).
    with open(os.path.join(args.save_folder, 'vis_bbox.txt'), 'w') as f_:
        for i, image_path in enumerate(test_dataset):
            # The path fragment after 'datasets/' serves as the image id.
            img_name = image_path[image_path.find('datasets') + 9:]
            img_raw = cv2.imread(image_path, cv2.IMREAD_COLOR)
            img = np.float32(img_raw)

            # testing scale: shortest side -> 1600px, capped so the longest
            # side never exceeds 2150px
            target_size = 1600
            max_size = 2150
            im_shape = img.shape
            im_size_min = np.min(im_shape[0:2])
            im_size_max = np.max(im_shape[0:2])
            resize = float(target_size) / float(im_size_min)
            # prevent bigger axis from being more than max_size:
            if np.round(resize * im_size_max) > max_size:
                resize = float(max_size) / float(im_size_max)
            if args.origin_size:
                resize = 1

            if resize != 1:
                img = cv2.resize(img,
                                 None,
                                 None,
                                 fx=resize,
                                 fy=resize,
                                 interpolation=cv2.INTER_LINEAR)
            im_height, im_width, _ = img.shape
            # (w, h, w, h) scaling to map normalized boxes back to pixels
            scale = torch.Tensor(
                [img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
            img -= (104, 117, 123)  # BGR mean subtraction
            img = img.transpose(2, 0, 1)
            img = torch.from_numpy(img).unsqueeze(0)
            img = img.to(device)
            scale = scale.to(device)

            _t['forward_pass'].tic()
            loc, conf, landms = net(img)  # forward pass
            _t['forward_pass'].toc()
            _t['misc'].tic()
            # Priors depend on the (possibly resized) input resolution.
            priorbox = PriorBox(cfg, image_size=(im_height, im_width))
            priors = priorbox.forward()
            priors = priors.to(device)
            prior_data = priors.data
            boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance'])
            boxes = boxes * scale / resize
            boxes = boxes.cpu().numpy()
            scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
            landms = decode_landm(landms.data.squeeze(0), prior_data,
                                  cfg['variance'])
            # (w, h) x 5 scaling for the five landmark points
            scale1 = torch.Tensor([
                img.shape[3], img.shape[2], img.shape[3], img.shape[2],
                img.shape[3], img.shape[2], img.shape[3], img.shape[2],
                img.shape[3], img.shape[2]
            ])
            scale1 = scale1.to(device)
            landms = landms * scale1 / resize
            landms = landms.cpu().numpy()

            # ignore low scores
            inds = np.where(scores > args.confidence_threshold)[0]
            boxes = boxes[inds]
            landms = landms[inds]
            scores = scores[inds]

            # keep top-K before NMS
            # BUG FIX: removed a redundant duplicate argsort call.
            order = scores.argsort()[::-1][:args.top_k]
            boxes = boxes[order]
            landms = landms[order]
            scores = scores[order]

            # do NMS
            dets = np.hstack((boxes, scores[:, np.newaxis])).astype(
                np.float32, copy=False)
            keep = py_cpu_nms(dets, args.nms_threshold)
            dets = dets[keep, :]
            landms = landms[keep]

            # keep top-K after NMS
            dets = dets[:args.keep_top_k, :]
            landms = landms[:args.keep_top_k, :]

            dets = np.concatenate((dets, landms), axis=1)
            _t['misc'].toc()

            # Write '<name>.txt': image name, box count, then one
            # 'x y w h confidence' line per detection.
            save_name = os.path.join(args.save_folder, 'txt',
                                     img_name)[:-4] + '.txt'
            dirname = os.path.dirname(save_name)
            if not os.path.isdir(dirname):
                os.makedirs(dirname)
            with open(save_name, "w") as fd:
                bboxs = dets
                file_name = os.path.basename(save_name)[:-4] + "\n"
                bboxs_num = str(len(bboxs)) + "\n"
                fd.write(file_name)
                fd.write(bboxs_num)
                for box in bboxs:
                    x = int(box[0])
                    y = int(box[1])
                    w = int(box[2]) - int(box[0])
                    h = int(box[3]) - int(box[1])
                    confidence = str(box[4])
                    line = str(x) + " " + str(y) + " " + str(w) + " " + str(
                        h) + " " + confidence + " \n"
                    fd.write(line)

            print('im_detect: {:d}/{:d}'
                  ' forward_pass_time: {:.4f}s'
                  ' misc: {:.4f}s'
                  ' img_shape:{:}'.format(i + 1, len(test_dataset),
                                          _t['forward_pass'].average_time,
                                          _t['misc'].average_time, img.shape))

            # save bbox-image
            line_write = save_image(dets,
                                    args.vis_thres,
                                    img_raw,
                                    args.save_folder,
                                    img_name,
                                    save_all=args.save_image_all)
            f_.write(line_write)
            f_.flush()
if args.resume_net is not None:
    print('Loading resume network...')
    state_dict = torch.load(args.resume_net)
    # Drop the 'module.' prefix that DataParallel adds to parameter names.
    from collections import OrderedDict
    new_state_dict = OrderedDict(
        (k[7:] if k[:7] == 'module.' else k, v)
        for k, v in state_dict.items())
    # strict=False tolerates missing/unexpected keys in the checkpoint
    net.load_state_dict(new_state_dict, strict=False)

# Two-GPU DataParallel when enabled; otherwise a plain single-GPU model.
if num_gpu > 1 and gpu_train:
    net = torch.nn.DataParallel(net, device_ids=[0, 1]).cuda()
else:
    net = net.cuda()

cudnn.benchmark = True

# SGD with weight decay; MultiBox loss couples localization, classification
# and landmark terms (0.35 IoU thresholds, 7:1 hard-negative mining ratio).
optimizer = optim.SGD(net.parameters(),
                      lr=initial_lr,
                      momentum=momentum,
                      weight_decay=weight_decay)
criterion = MultiBoxLoss(num_classes, 0.35, True, 0, True, 7, 0.35, False)

priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))
Example #5
0
def main():
    """Train RetinaFace on WIDER FACE using the settings from ``get_args()``.

    Checkpoints periodically into ``args.save_folder`` and writes a final
    '<name>_Final.pth' when all epochs complete.
    """
    args = get_args()
    if not os.path.exists(args.save_folder):
        os.mkdir(args.save_folder)
    if args.network == "mobile0.25":
        cfg = cfg_mnet
    elif args.network == "resnet50":
        cfg = cfg_re50
    else:
        # BUG FIX: cfg previously stayed None for an unknown network name and
        # crashed below with an opaque TypeError; fail fast instead.
        raise ValueError("unknown network: {!r} "
                         "(expected 'mobile0.25' or 'resnet50')".format(
                             args.network))

    rgb_mean = (104, 117, 123)  # bgr order
    num_classes = 2  # face / background
    img_dim = cfg["image_size"]
    num_gpu = cfg["ngpu"]
    batch_size = cfg["batch_size"]
    max_epoch = cfg["epoch"]
    gpu_train = cfg["gpu_train"]

    num_workers = args.num_workers
    momentum = args.momentum
    weight_decay = args.weight_decay
    initial_lr = args.lr
    gamma = args.gamma
    training_dataset = args.training_dataset
    save_folder = args.save_folder

    net = RetinaFace(cfg=cfg)
    print("Printing net...")
    print(net)

    if args.resume_net is not None:
        print("Loading resume network...")
        state_dict = torch.load(args.resume_net)
        # create new OrderedDict that does not contain `module.`
        # (DataParallel prefixes parameter names with 'module.')
        from collections import OrderedDict

        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            head = k[:7]
            if head == "module.":
                name = k[7:]  # remove `module.`
            else:
                name = k
            new_state_dict[name] = v
        net.load_state_dict(new_state_dict)

    if num_gpu > 1 and gpu_train:
        net = torch.nn.DataParallel(net).cuda()
    else:
        net = net.cuda()

    cudnn.benchmark = True

    optimizer = optim.SGD(net.parameters(),
                          lr=initial_lr,
                          momentum=momentum,
                          weight_decay=weight_decay)
    criterion = MultiBoxLoss(num_classes, 0.35, True, 0, True, 7, 0.35, False)

    # Anchors are fixed for the square training resolution.
    priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))
    with torch.no_grad():
        priors = priorbox.forward()
        priors = priors.cuda()

    net.train()
    epoch = 0 + args.resume_epoch
    print("Loading Dataset...")

    dataset = WiderFaceDetection(training_dataset, preproc(img_dim, rgb_mean))

    epoch_size = math.ceil(len(dataset) / batch_size)
    max_iter = max_epoch * epoch_size

    # LR decay milestones, expressed in iterations.
    stepvalues = (cfg["decay1"] * epoch_size, cfg["decay2"] * epoch_size)
    step_index = 0

    if args.resume_epoch > 0:
        start_iter = args.resume_epoch * epoch_size
    else:
        start_iter = 0

    for iteration in range(start_iter, max_iter):
        if iteration % epoch_size == 0:
            # create batch iterator (fresh shuffle each epoch)
            batch_iterator = iter(
                data.DataLoader(dataset,
                                batch_size,
                                shuffle=True,
                                num_workers=num_workers,
                                collate_fn=detection_collate))
            # Periodic checkpoints: every 10 epochs, then every 5 after
            # the first decay milestone.
            # NOTE(review): path concatenation assumes save_folder ends with
            # a path separator — confirm against get_args() default.
            if (epoch % 10 == 0 and epoch > 0) or (epoch % 5 == 0
                                                   and epoch > cfg["decay1"]):
                torch.save(
                    net.state_dict(), save_folder + cfg["name"] + "_epoch_" +
                    str(epoch) + ".pth")
            epoch += 1

        load_t0 = time.time()
        if iteration in stepvalues:
            step_index += 1
        lr = adjust_learning_rate(initial_lr, optimizer, gamma, epoch,
                                  step_index, iteration, epoch_size)

        # load train data
        images, targets = next(batch_iterator)
        images = images.cuda()
        targets = [anno.cuda() for anno in targets]

        # forward
        out = net(images)

        # backprop
        optimizer.zero_grad()
        loss_l, loss_c, loss_landm = criterion(out, priors, targets)
        loss = cfg["loc_weight"] * loss_l + loss_c + loss_landm
        loss.backward()
        optimizer.step()
        load_t1 = time.time()
        batch_time = load_t1 - load_t0
        eta = int(batch_time * (max_iter - iteration))
        print(
            "Epoch:{}/{} || Epochiter: {}/{} || Iter: {}/{} || Loc: {:.4f} Cla: {:.4f} Landm: {:.4f} "
            "|| LR: {:.8f} || Batchtime: {:.4f} s || ETA: {}".format(
                epoch,
                max_epoch,
                (iteration % epoch_size) + 1,
                epoch_size,
                iteration + 1,
                max_iter,
                loss_l.item(),
                loss_c.item(),
                loss_landm.item(),
                lr,
                batch_time,
                str(datetime.timedelta(seconds=eta)),
            ))

    torch.save(net.state_dict(), save_folder + cfg["name"] + "_Final.pth")
Example #6
0
class FaceDetector:
    """Thin inference wrapper around RetinaFace (ResNet-50, test phase).

    Decode parameters (priors and coordinate scales) are cached per input
    resolution so repeated detections at the same size are cheap.
    """

    def __init__(self, device="cuda", confidence_threshold=0.8):
        self.device = device
        self.confidence_threshold = confidence_threshold

        # Note: cfg_re50 is shared module state; disabling 'pretrain'
        # mutates it in place, exactly as the original code did.
        self.cfg = cfg = cfg_re50
        self.variance = cfg["variance"]
        cfg["pretrain"] = False

        self.net = RetinaFace(cfg=cfg, phase="test").to(device).eval()
        self.decode_param_cache = {}

    def load_checkpoint(self, path):
        """Load model weights from *path*."""
        self.net.load_state_dict(torch.load(path))

    def decode_params(self, height, width):
        """Return cached (prior_data, scale, scale1) for a (height, width) input."""
        key = (height, width)
        if key not in self.decode_param_cache:
            priors = PriorBox(self.cfg, image_size=(height, width)).forward()
            self.decode_param_cache[key] = (
                priors.data,
                torch.Tensor([width, height] * 2),   # box scale (x, y, x, y)
                torch.Tensor([width, height] * 5),   # landmark scale (5 points)
            )
        return self.decode_param_cache[key]

    def detect(self, img):
        """Detect faces in a BGR image array; return (boxes, landmarks)
        above the confidence threshold."""
        dev = self.device

        prior_data, scale, scale1 = self.decode_params(*img.shape[:2])

        # REF: test_fddb.py — mean-subtract, HWC -> CHW, add batch dim.
        blob = np.float32(img)
        blob -= (104, 117, 123)
        blob = blob.transpose(2, 0, 1)
        batch = torch.from_numpy(blob).unsqueeze(0)
        batch = batch.to(dev, dtype=torch.float32)

        loc, conf, landms = self.net(batch)
        loc, conf, landms = loc.cpu(), conf.cpu(), landms.cpu()

        # Decode network outputs back to image coordinates.
        boxes = decode(loc.squeeze(0), prior_data, self.variance) * scale
        scores = conf.squeeze(0)[:, 1]
        landmarks = decode_landm(landms.squeeze(0), prior_data,
                                 self.variance) * scale1

        keep = scores > self.confidence_threshold
        return boxes[keep], landmarks[keep]
Example #7
0
def main():
    """Distributed RetinaFace training entry point.

    Uses the module-level `args` and `minmum_loss` globals, optionally
    resumes from `args.resume`, and checkpoints after every epoch on the
    rank-0 process (tracking the minimum training loss).
    """
    global args
    global minmum_loss
    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        # One process per GPU; NCCL backend with env:// rendezvous.
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    args.total_batch_size = args.world_size * args.batch_size

    # build RetinaFace network
    print("Building net...")
    model = RetinaFace(cfg=cfg)
    print("Printing net...")

    # for multi gpu
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model)

    model = model.cuda()
    # optimizer and loss function
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    criterion = MultiBoxLoss(cfg['num_classes'], 0.35, True, 0, True, 7, 0.35,
                             False)

    ## dataset
    print("loading dataset")
    train_dataset = WiderFaceDetection(
        args.training_dataset, preproc(cfg['image_size'], cfg['rgb_mean']))

    train_loader = data.DataLoader(train_dataset,
                                   args.batch_size,
                                   num_workers=args.num_workers,
                                   shuffle=True,
                                   collate_fn=detection_collate,
                                   pin_memory=True)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume,
                map_location=lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            minmum_loss = checkpoint['minmum_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    print('Using the specified args:')
    print(args)
    # load PriorBox
    print("Load priorbox")
    with torch.no_grad():
        priorbox = PriorBox(cfg=cfg,
                            image_size=(cfg['image_size'], cfg['image_size']))
        priors = priorbox.forward()
        priors = priors.cuda()

    print("start training")  # FIX: typo 'traing' in the original message
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train_loss = train(train_loader, model, priors, criterion, optimizer,
                           epoch)
        if args.local_rank == 0:
            is_best = train_loss < minmum_loss
            minmum_loss = min(train_loss, minmum_loss)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    # BUG FIX: resume reads checkpoint['minmum_loss'] above,
                    # but this value was saved under 'best_prec1', so every
                    # --resume raised KeyError.
                    'minmum_loss': minmum_loss,
                    'optimizer': optimizer.state_dict(),
                }, is_best, epoch)
Example #8
0
        print('load landmarks model : {}'.format(ops.landmarks_model))

    #--------------------------------------------------------------------------- 构建人脸检测模型
    cfg = None
    if ops.detect_network == "mobile0.25":
        cfg = cfg_mnet
    elif ops.detect_network == "resnet50":
        cfg = cfg_re50
    # net and model
    detect_model = RetinaFace(cfg=cfg, phase='test')

    detect_model = detect_model.to(device)

    if os.access(ops.detect_model, os.F_OK):  # checkpoint
        chkpt = torch.load(ops.detect_model, map_location=device)
        detect_model.load_state_dict(chkpt)
        print('load detect model : {}'.format(ops.detect_model))

    detect_model.eval()
    if use_cuda:
        cudnn.benchmark = True

    print('loading model done ~')
    #-------------------------------------------------------------------------- run vedio
    video_capture = cv2.VideoCapture(ops.test_path)
    with torch.no_grad():
        idx = 0
        while True:
            ret, img_o = video_capture.read()

            if ret:
def main(args):
    """Train a mobilenet-0.25 RetinaFace on a custom dataset.

    Uses a warmup-cosine LR schedule stepped per iteration and checkpoints
    every ``args.save_fre`` epochs into a timestamped sub-folder of
    ``args.save_folder``.
    """
    # dataset
    rgb_mean = (104, 117, 123)  # bgr order
    dataset = MyDataset(args.txt_path, args.txt_path2,
                        preproc(args.img_size, rgb_mean))
    dataloader = DataLoader(dataset,
                            args.bs,
                            shuffle=True,
                            num_workers=args.num_workers,
                            collate_fn=detection_collate,
                            pin_memory=True)

    # net and load
    net = RetinaFace(cfg=cfg_mnet)
    if args.resume_net is not None:
        print('Loading resume network...')
        # NOTE(review): load_normal presumably normalizes checkpoint keys
        # (e.g. strips a 'module.' prefix) — confirm against its definition.
        state_dict = load_normal(args.resume_net)
        net.load_state_dict(state_dict)
        print('Loading success!')
    net = net.cuda()
    if torch.cuda.device_count() >= 1 and args.multi_gpu:
        net = torch.nn.DataParallel(net)

    # optimizer and loss
    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    # Warmup then cosine decay, stepped once per batch below.
    scheduler = WarmupCosineSchedule(optimizer,
                                     args.warm_epoch, args.max_epoch,
                                     len(dataloader), args.cycles)

    num_classes = 2  # face / background
    criterion = MultiBoxLoss(num_classes, 0.35, True, 0, True, 7, 0.35, False)

    # priorbox: anchors are fixed for the square training resolution
    priorbox = PriorBox(cfg_mnet, image_size=(args.img_size, args.img_size))
    with torch.no_grad():
        priors = priorbox.forward()
        priors = priors.cuda()

    # save folder (timestamped so runs never overwrite each other)
    time_str = datetime.datetime.strftime(datetime.datetime.now(),
                                          '%y-%m-%d-%H-%M-%S')
    args.save_folder = os.path.join(args.save_folder, time_str)
    if not os.path.exists(args.save_folder):
        os.makedirs(args.save_folder)
    logger = logger_init(args.save_folder)
    logger.info(args)

    # train
    for i_epoch in range(args.max_epoch):
        net.train()

        for i_iter, data in enumerate(dataloader):
            load_t0 = time.time()
            images, targets = data[:2]
            images = images.cuda()
            targets = [anno.cuda() for anno in targets]

            # forward
            out = net(images)

            # backward
            optimizer.zero_grad()
            loss_l, loss_c, loss_landm = criterion(out, priors, targets)
            loss = cfg_mnet['loc_weight'] * loss_l + loss_c + loss_landm
            loss.backward()
            optimizer.step()
            scheduler.step()  # per-iteration LR update

            # print info (ETA extrapolated from this batch's wall time)
            load_t1 = time.time()
            batch_time = load_t1 - load_t0
            eta = int(batch_time * (len(dataloader) *
                                    (args.max_epoch - i_epoch) - i_iter))
            logger.info('Epoch:{}/{} || Iter: {}/{} || '
                        'Loc: {:.4f} Cla: {:.4f} Landm: {:.4f} || '
                        'LR: {:.8f} || '
                        'Batchtime: {:.4f} s || '
                        'ETA: {}'.format(
                            i_epoch + 1, args.max_epoch, i_iter + 1,
                            len(dataloader), loss_l.item(), loss_c.item(),
                            loss_landm.item(),
                            optimizer.state_dict()['param_groups'][0]['lr'],
                            batch_time, str(datetime.timedelta(seconds=eta))))

        if (i_epoch + 1) % args.save_fre == 0:
            save_name = 'mobile0.25_' + str(i_epoch + 1) + '.pth'
            torch.save(net.state_dict(),
                       os.path.join(args.save_folder, save_name))
    def Train(self):
        """Run the full training loop using the config and dataset prepared
        by ``setup()``, checkpointing into the configured save folder.
        """
        self.setup()
        cfg = self.system_dict["local"]["cfg"]
        print(cfg)

        rgb_mean = (104, 117, 123)  # bgr order
        num_classes = 2  # face / background
        img_dim = cfg['image_size']
        num_gpu = cfg['ngpu']
        batch_size = cfg['batch_size']
        max_epoch = cfg['epoch']
        gpu_train = cfg['gpu_train']

        num_workers = self.system_dict["params"]["num_workers"]
        momentum = self.system_dict["params"]["momentum"]
        weight_decay = self.system_dict["params"]["weight_decay"]
        initial_lr = self.system_dict["params"]["lr"]
        gamma = self.system_dict["params"]["gamma"]
        save_folder = self.system_dict["params"]["save_folder"]

        print("Loading Network...")
        net = RetinaFace(cfg=cfg)

        if self.system_dict["params"]["resume_net"] is not None:
            print('Loading resume network...')
            state_dict = torch.load(self.system_dict["params"]["resume_net"])
            # create new OrderedDict that does not contain `module.`
            # (DataParallel prefixes parameter names with 'module.')
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                head = k[:7]
                if head == 'module.':
                    name = k[7:]  # remove `module.`
                else:
                    name = k
                new_state_dict[name] = v
            net.load_state_dict(new_state_dict)

        if num_gpu > 1 and gpu_train:
            net = torch.nn.DataParallel(net).cuda()
        else:
            net = net.cuda()
        cudnn.benchmark = True
        print("Done...")

        optimizer = optim.SGD(net.parameters(),
                              lr=initial_lr,
                              momentum=momentum,
                              weight_decay=weight_decay)
        criterion = MultiBoxLoss(num_classes, 0.35, True, 0, True, 7, 0.35,
                                 False)

        # Anchors are fixed for the square training resolution.
        priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))
        with torch.no_grad():
            priors = priorbox.forward()
            priors = priors.cuda()

        net.train()
        epoch = 0 + self.system_dict["params"]["resume_epoch"]
        dataset = self.system_dict["local"]["dataset"]

        epoch_size = math.ceil(len(dataset) / batch_size)
        max_iter = max_epoch * epoch_size

        # LR decay milestones, expressed in iterations.
        stepvalues = (cfg['decay1'] * epoch_size, cfg['decay2'] * epoch_size)
        step_index = 0

        if self.system_dict["params"]["resume_epoch"] > 0:
            start_iter = self.system_dict["params"]["resume_epoch"] * epoch_size
        else:
            start_iter = 0

        for iteration in range(start_iter, max_iter):
            if iteration % epoch_size == 0:
                # create batch iterator (fresh shuffle each epoch) and save
                # an intermediate checkpoint at every epoch boundary
                batch_iterator = iter(
                    data.DataLoader(dataset,
                                    batch_size,
                                    shuffle=True,
                                    num_workers=num_workers,
                                    collate_fn=detection_collate))
                torch.save(
                    net.state_dict(),
                    save_folder + "/" + cfg['name'] + '_intermediate.pth')
                epoch += 1

            load_t0 = time.time()
            if iteration in stepvalues:
                step_index += 1
            lr = self.adjust_learning_rate(optimizer, gamma, epoch, step_index,
                                           iteration, epoch_size, initial_lr)

            # load train data
            images, targets = next(batch_iterator)
            images = images.cuda()
            targets = [anno.cuda() for anno in targets]

            # forward
            out = net(images)

            # backprop
            optimizer.zero_grad()
            loss_l, loss_c, loss_landm = criterion(out, priors, targets)
            loss = cfg['loc_weight'] * loss_l + loss_c + loss_landm
            loss.backward()
            optimizer.step()
            load_t1 = time.time()
            batch_time = load_t1 - load_t0
            eta = int(batch_time * (max_iter - iteration))
            if (iteration % 50 == 0):
                print(
                    'Epoch:{}/{} || Epochiter: {}/{} || Iter: {}/{} || Loc: {:.4f} Cla: {:.4f} Landm: {:.4f} || LR: {:.8f} || Batchtime: {:.4f} s || ETA: {}'
                    .format(epoch, max_epoch, (iteration % epoch_size) + 1,
                            epoch_size, iteration + 1, max_iter, loss_l.item(),
                            loss_c.item(), loss_landm.item(), lr, batch_time,
                            str(datetime.timedelta(seconds=eta))))

        torch.save(net.state_dict(),
                   save_folder + "/" + cfg['name'] + '_Final.pth')

# Image extensions the upload endpoint accepts.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
app = Flask(__name__)
# cfg = cfg_re50
cfg = cfg_mnet
# trained_model = './weights/Resnet50_epoch_95.pth'
# Checkpoint paths for the three served models: RetinaFace face detector,
# ResNet-50 mask classifier, and PFLD landmark network.
retina_trained_model = './weights/mobilenet0.25_epoch_245.pth'
mask_trained_model = './weights/net_21.pth'
pfld_trained_model = './weights/checkpoint_epoch_500.pth.tar'
# net and model
retina_net = RetinaFace(cfg=cfg, phase='test')
mask_net = resnet50(num_classes=2)
pfld_backbone = PFLDInference()
# load pre-trained model
retina_net.load_state_dict(torch.load(retina_trained_model))
mask_net.load_state_dict(torch.load(mask_trained_model))
pfld_backbone.load_state_dict(torch.load(pfld_trained_model)['plfd_backbone'])

# Move all three models to GPU 0 and switch to inference mode.
retina_net = retina_net.cuda(0)
mask_net = mask_net.cuda(0)
pfld_net = pfld_backbone.cuda(0)
retina_net.eval()
mask_net.eval()
pfld_net.eval()

# Detection post-processing defaults used by the request handlers.
resize = 1
top_k = 5000
keep_top_k = 750
nms_threshold = 0.5
def train(opt, train_dict, device, tb_writer=None):
    """Train RetinaFace with the hyper-parameters in *train_dict*,
    optionally under DistributedDataParallel.

    NOTE(review): the original body referenced several undefined names
    (`ckpt`, `weights`, `num_classes`, `load_t0`, `max_iter`, `iteration`,
    `epoch_size`, `max_epoch`, `lr`, `net`, `cfg`) and multiplied the loss
    *tuple* by world_size; those are repaired below while preserving the
    overall flow.
    """
    log_dir = Path(tb_writer.log_dir) if tb_writer else Path(
        train_dict['logdir']) / 'logs'
    wdir = str(log_dir / 'weights') + os.sep
    os.makedirs(wdir, exist_ok=True)
    last = wdir + 'last.pt'
    best = wdir + 'best.pt'
    results_file = 'results.txt'
    # Persist the exact hyper-parameters / CLI options used for this run.
    with open(log_dir / 'hyp.yaml', 'w') as f:
        yaml.dump(train_dict, f, sort_keys=False)
    with open(log_dir / 'opt.yaml', 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)

    # Configure
    cuda = device.type != 'cpu'
    rank = opt.global_rank
    init_seeds(2 + rank)
    train_path = train_dict['train']
    test_path = train_dict['val']
    # Fall back to last.pt unless a pretrain checkpoint actually exists.
    train_dict['weights'] = last if not train_dict['pretrain'] or (
        train_dict['pretrain'] and
        not os.path.exists(train_dict['weights'])) else train_dict['weights']
    model = RetinaFace(train_dict, phase='Train')
    pretrained = False
    if os.path.exists(train_dict['weights']):
        pretrained = True
        # BUG FIX: `logger` is a logging.Logger, not callable.
        logger.info('Loading resume network from ====>{}'.format(
            train_dict['weights']))
        state_dict = torch.load(train_dict['weights'], map_location=device)
        # create new OrderedDict that does not contain `module.`
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict['model'].items():
            # strip the DataParallel 'module.' prefix if present
            name = k[7:] if k[:7] == 'module.' else k
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)

    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in model.named_parameters():
        v.requires_grad = True
        if '.bias' in k:
            pg2.append(v)  # biases
        elif '.weight' in k and '.bn' not in k:
            pg1.append(v)  # apply weight decay
        else:
            pg0.append(v)  # all else

    if train_dict['adam']:
        optimizer = optim.Adam(pg0,
                               lr=train_dict['lr0'],
                               betas=(train_dict['momentum'],
                                      0.999))  # adjust beta1 to momentum
    else:
        optimizer = optim.SGD(pg0,
                              lr=train_dict['lr0'],
                              momentum=train_dict['momentum'],
                              nesterov=True)
    optimizer.add_param_group({
        'params': pg1,
        'weight_decay': train_dict['weight_decay']
    })  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' %
                (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    epochs = train_dict['epoch']
    # Cosine schedule decaying to 20% of the initial LR over `epochs`.
    lf = lambda x: ((
        (1 + math.cos(x * math.pi / epochs)) / 2)**1.0) * 0.8 + 0.2  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    plot_lr_scheduler(optimizer, scheduler, epochs)

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer (BUG FIX: was `ckpt['optimizer']` — an undefined name)
        if state_dict['optimizer'] is not None:
            optimizer.load_state_dict(state_dict['optimizer'])
            best_fitness = state_dict['best_fitness']

        # Results
        if state_dict.get('training_results') is not None:
            with open(results_file, 'w') as file:
                file.write(state_dict['training_results'])  # write results.txt

        # Epochs (BUG FIX: `weights` was undefined in the log call)
        start_epoch = state_dict['epoch'] + 1
        if epochs < start_epoch:
            logger.info(
                '%s has been trained for %g epochs. Fine-tuning for %g additional epochs.'
                % (train_dict['weights'], state_dict['epoch'], epochs))
            epochs += state_dict['epoch']  # finetune additional epochs

        del state_dict  # BUG FIX: `ckpt` never existed; only free state_dict

    if train_dict['sync_bn'] and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        logger.info('Using SyncBatchNorm()')

    # Exponential moving average
    ema = ModelEMA(model) if rank in [-1, 0] else None

    # ddp
    if cuda and rank != -1:
        model = DDP(model,
                    device_ids=[opt.local_rank],
                    output_device=(opt.local_rank))

    # Trainloader
    batch_size = train_dict['batch_size']
    image_size = train_dict['image_size']
    rgb_mean = (104, 117, 123)  # bgr order
    dataset = WiderFaceDetection(train_path, preproc(image_size, rgb_mean))
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             num_workers=8,
                                             sampler=sampler,
                                             pin_memory=True,
                                             collate_fn=detection_collate)

    # BUG FIX: `num_classes` was undefined; take it from train_dict like the
    # other hyper-parameters.
    criterion = MultiBoxLoss(train_dict['num_classes'], 0.35, True, 0, True, 7,
                             0.35, False)
    priorbox = PriorBox(train_dict, image_size=(image_size, image_size))
    with torch.no_grad():
        priors = priorbox.forward()
        priors = priors.cuda()

    # BUG FIX: epoch_size/max_iter were never defined in the original loop.
    epoch_size = len(dataloader)    # iterations per epoch
    max_iter = epochs * epoch_size  # used only for the ETA estimate
    for epoch in range(start_epoch, epochs):
        if rank != -1:
            dataloader.sampler.set_epoch(epoch)
        pbar = enumerate(dataloader)
        if rank in [-1, 0]:
            pbar = tqdm(pbar)  # progress bar
        optimizer.zero_grad()
        for i, (
                images, targets
        ) in pbar:  # batch -------------------------------------------------
            load_t0 = time.time()  # BUG FIX: was never initialised
            # NOTE(review): autocast is used without a GradScaler — confirm
            # the intended mixed-precision setup.
            with amp.autocast(enabled=cuda):
                images = images.cuda()
                targets = [anno.cuda() for anno in targets]
                out = model(images)
                optimizer.zero_grad()
                # BUG FIX: multiplying criterion's returned *tuple* by
                # world_size replicated the tuple instead of scaling; scale
                # the combined scalar loss instead.
                loss_l, loss_c, loss_landm = criterion(out, priors, targets)
                loss = (train_dict['loc_weight'] * loss_l + loss_c +
                        loss_landm) * opt.world_size
                loss.backward()
                optimizer.step()
            batch_time = time.time() - load_t0
            iteration = epoch * epoch_size + i
            eta = int(batch_time * (max_iter - iteration))
            if rank in [-1, 0]:
                lr = optimizer.param_groups[0]['lr']  # BUG FIX: was undefined
                print(
                    'Epoch:{}/{} || Epochiter: {}/{} || Iter: {}/{} || Loc: {:.4f} Cla: {:.4f} Landm: {:.4f} || LR: {:.8f} || Batchtime: {:.4f} s || ETA: {}'
                    .format(epoch, epochs, (iteration % epoch_size) + 1,
                            epoch_size, iteration + 1, max_iter,
                            loss_l.item(), loss_c.item(),
                            loss_landm.item(), lr, batch_time,
                            str(datetime.timedelta(seconds=eta))))
        scheduler.step()  # advance the cosine schedule once per epoch
        if rank in [-1, 0]:
            # BUG FIX: original saved undefined `net` once per *iteration*;
            # checkpoint the model once per epoch instead.
            torch.save(model.state_dict(), last)
def train() -> None:
    """Train RetinaFace with periodic validation and best-loss checkpointing.

    Reads all configuration from module-level globals (``args``, ``cfg``,
    ``training_dataset``, ``img_dim``, ``rgb_mean``, ``batch_size``,
    ``num_workers``, ``max_epoch``, ``initial_lr``, ``weight_decay``,
    ``num_classes``, ``valid_steps``, ``verbose_steps``, ``save_folder``,
    ``num_gpu``, ``gpu_train``, ``logger``) rather than taking parameters.
    Saves a checkpoint to ``save_folder`` whenever the validation loss improves.
    """

    net = RetinaFace(cfg=cfg)
    logger.info("Printing net...")
    logger.info(net)

    # Optionally resume from a checkpoint; keys saved under DataParallel carry
    # a `module.` prefix that must be stripped before loading into a bare model.
    if args.resume_net is not None:
        logger.info('Loading resume network...')
        state_dict = torch.load(args.resume_net)
        # create new OrderedDict that does not contain `module.`
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            head = k[:7]
            if head == 'module.':
                name = k[7:] # remove `module.`
            else:
                name = k
            new_state_dict[name] = v
        net.load_state_dict(new_state_dict)

    # Wrap in DataParallel for multi-GPU training; either way the model ends up on GPU.
    if num_gpu > 1 and gpu_train:
        net = torch.nn.DataParallel(net).cuda()
    else:
        net = net.cuda()

    cudnn.benchmark = True

    # Anchor boxes are fixed for a given image size, so compute them once
    # up front (no gradients needed) and keep them on the GPU.
    priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))
    with torch.no_grad():
        priors = priorbox.forward()
        priors = priors.cuda()

    net.train()
    epoch = 0 + args.resume_epoch
    logger.info('Loading Dataset...')

    # NOTE(review): train and valid splits are both built from `training_dataset`;
    # presumably WiderFaceDetection partitions internally by `mode` — confirm.
    trainset = WiderFaceDetection(training_dataset, preproc=train_preproc(img_dim, rgb_mean), mode='train')
    validset = WiderFaceDetection(training_dataset, preproc=valid_preproc(img_dim, rgb_mean), mode='valid')
    # trainset = WiderFaceDetection(training_dataset, transformers=train_transformers(img_dim), mode='train')
    # validset = WiderFaceDetection(training_dataset, transformers=valid_transformers(img_dim), mode='valid')
    trainloader = data.DataLoader(trainset, batch_size, shuffle=True, num_workers=num_workers, collate_fn=detection_collate)
    validloader = data.DataLoader(validset, batch_size, shuffle=True, num_workers=num_workers, collate_fn=detection_collate)
    logger.info(f'Totally {len(trainset)} training samples and {len(validset)} validating samples.')

    # The main loop below is iteration-based, not epoch-based.
    epoch_size = math.ceil(len(trainset) / batch_size)
    max_iter = max_epoch * epoch_size
    logger.info(f'max_epoch: {max_epoch:d} epoch_size: {epoch_size:d}, max_iter: {max_iter:d}')

    # optimizer = optim.SGD(net.parameters(), lr=initial_lr, momentum=momentum, weight_decay=weight_decay)
    optimizer = optim.Adam(net.parameters(), lr=initial_lr, weight_decay=weight_decay)
    # Linear warmup over the first 10% of iterations, then linear decay to zero.
    scheduler = _utils.get_linear_schedule_with_warmup(optimizer, int(0.1 * max_iter), max_iter)
    criterion = MultiBoxLoss(num_classes, 0.35, True, 0, True, 7, 0.35, False)

    # NOTE(review): stepvalues/step_index belong to the commented-out step-LR
    # schedule below and are currently unused.
    stepvalues = (cfg['decay1'] * epoch_size, cfg['decay2'] * epoch_size)
    step_index = 0

    # Resuming skips ahead by whole epochs' worth of iterations.
    if args.resume_epoch > 0:
        start_iter = args.resume_epoch * epoch_size
    else:
        start_iter = 0

    best_loss_val = float('inf')
    for iteration in range(start_iter, max_iter):
        # Epoch boundary: rebuild the (shuffled) batch iterator and bump the counter.
        if iteration % epoch_size == 0:
            # create batch iterator
            # batch_iterator = iter(tqdm(trainloader, total=len(trainloader)))
            batch_iterator = iter(trainloader)
            # if (epoch % 10 == 0 and epoch > 0) or (epoch % 5 == 0 and epoch > cfg['decay1']):
            #     torch.save(net.state_dict(), save_folder + cfg['name']+ '_epoch_' + str(epoch) + '.pth')
            epoch += 1
            torch.cuda.empty_cache()

        # Periodic validation pass over the whole valid loader; model is put in
        # eval mode and restored to train mode afterwards.
        if (valid_steps > 0) and (iteration > 0) and (iteration % valid_steps == 0):
            net.eval()
            # validation
            loss_l_val = 0.
            loss_c_val = 0.
            loss_landm_val = 0.
            loss_val = 0.
            # for val_no, (images, targets) in tqdm(enumerate(validloader), total=len(validloader)):
            for val_no, (images, targets) in enumerate(validloader):
                # load data
                images = images.cuda()
                targets = [anno.cuda() for anno in targets]
                # forward
                with torch.no_grad():
                    out = net(images)
                    loss_l, loss_c, loss_landm = criterion(out, priors, targets)
                    loss = cfg['loc_weight'] * loss_l + loss_c + loss_landm
                loss_l_val += loss_l.item()
                loss_c_val += loss_c.item()
                loss_landm_val += loss_landm.item()
                loss_val += loss.item()
            # Average each component over the number of validation batches.
            loss_l_val /= len(validloader)
            loss_c_val /= len(validloader)
            loss_landm_val /= len(validloader)
            loss_val /= len(validloader)
            logger.info('[Validating] Epoch:{}/{} || Epochiter: {}/{} || Iter: {}/{} || Total: {:.4f} Loc: {:.4f} Cla: {:.4f} Landm: {:.4f}'
                .format(epoch, max_epoch, (iteration % epoch_size) + 1,
                epoch_size, iteration + 1, max_iter, 
                loss_val, loss_l_val, loss_c_val, loss_landm_val))
            # Checkpoint only when validation loss improves.
            # NOTE(review): under DataParallel this saves keys with a `module.`
            # prefix, which the resume path above then strips — verify round-trip.
            if loss_val < best_loss_val:
                best_loss_val = loss_val
                pth = os.path.join(save_folder, cfg['name'] + '_iter_' + str(iteration) + f'_{loss_val:.4f}_' + '.pth')
                torch.save(net.state_dict(), pth)
                logger.info(f'Best validating loss: {best_loss_val:.4f}, model saved as {pth:s})')
            net.train()

        load_t0 = time.time()
        # if iteration in stepvalues:
        #     step_index += 1
        # lr = adjust_learning_rate(optimizer, gamma, epoch, step_index, iteration, epoch_size)

        # load train data
        images, targets = next(batch_iterator)
        images = images.cuda()
        targets = [anno.cuda() for anno in targets]

        # forward
        out = net(images)

        # backprop
        optimizer.zero_grad()
        loss_l, loss_c, loss_landm = criterion(out, priors, targets)
        # Total loss: weighted localization + classification + landmark terms.
        loss = cfg['loc_weight'] * loss_l + loss_c + loss_landm
        loss.backward()
        optimizer.step()
        scheduler.step()
        load_t1 = time.time()
        batch_time = load_t1 - load_t0
        # ETA extrapolates the current batch time over the remaining iterations.
        eta = int(batch_time * (max_iter - iteration))
        if iteration % verbose_steps == 0:
            logger.info('[Training] Epoch:{}/{} || Epochiter: {}/{} || Iter: {}/{} || Total: {:.4f} Loc: {:.4f} Cla: {:.4f} Landm: {:.4f} || LR: {:.8f} || Batchtime: {:.4f} s || ETA: {}'
                .format(epoch, max_epoch, (iteration % epoch_size) + 1,
                epoch_size, iteration + 1, max_iter, 
                loss.item(), loss_l.item(), loss_c.item(), loss_landm.item(), 
                scheduler.get_last_lr()[-1], batch_time, str(datetime.timedelta(seconds=eta))))
print(net)

if args.resume_net is not None:
    print('Loading resume network...')
    state_dict = torch.load(args.resume_net)
    # Strip the `module.` prefix that DataParallel prepends to parameter
    # names, so the weights load into a plain (unwrapped) model.
    from collections import OrderedDict
    new_state_dict = OrderedDict(
        (k[7:] if k[:7] == 'module.' else k, v)
        for k, v in state_dict.items()
    )
    net.load_state_dict(new_state_dict)

if args.resume_teacher_net is not None:
    print('Loading resume teacher net...')
    state_dict = torch.load(args.resume_teacher_net)
    # Strip the `module.` prefix that DataParallel prepends to parameter
    # names, so the weights load into a plain (unwrapped) model.
    from collections import OrderedDict
    new_state_dict = OrderedDict(
        (k[7:] if k[:7] == 'module.' else k, v)
        for k, v in state_dict.items()
    )
    teacher_net.load_state_dict(new_state_dict)