class Detector:
    """Facial-landmark detector wrapping a pretrained ``Slim`` network.

    The network weights are read from
    ``pretrained_weights/slim_160_latest.pth`` and the model is moved to the
    GPU, so constructing a ``Detector`` requires a CUDA device.
    """

    def __init__(self, detection_size=(160, 160)):
        """Load the pretrained model and create the landmark tracker.

        Args:
            detection_size: Size passed to ``cv2.resize`` for every face crop
                before it is fed to the network.
        """
        self.model = Slim()
        # Open the checkpoint in a context manager so the file handle is
        # closed as soon as the weights are read (it used to be leaked).
        with open("pretrained_weights/slim_160_latest.pth", "rb") as weights_file:
            state_dict = torch.load(weights_file, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.model.cuda()
        self.tracker = Tracker()
        self.detection_size = detection_size

    def crop_image(self, orig, bbox):
        """Cut an enlarged face region out of ``orig`` and resize it.

        The box is grown by 25% of its width/height on every side around its
        center, clamped to the image borders, and the resulting crop is
        resized to ``self.detection_size``.

        Args:
            orig: Image array indexed as ``[row, column, channel]``.
            bbox: NumPy array ``[x1, y1, x2, y2]``; it is copied, so the
                caller's array is never modified.

        Returns:
            A tuple ``(crop, [h, w, top, left])`` where ``crop`` is the
            resized crop, ``h``/``w`` are the crop's size *before* resizing
            and ``top``/``left`` are its offset inside ``orig``.
        """
        bbox = bbox.copy()
        image = orig.copy()

        # Fraction of the box size added on each side of the box.
        pad_ratio = 0.25
        bbox_width = bbox[2] - bbox[0]
        bbox_height = bbox[3] - bbox[1]
        face_width = (1 + 2 * pad_ratio) * bbox_width
        face_height = (1 + 2 * pad_ratio) * bbox_height
        center = [(bbox[0] + bbox[2]) // 2, (bbox[1] + bbox[3]) // 2]

        # Clamp the enlarged box so it stays inside the image.
        bbox[0] = max(0, center[0] - face_width // 2)
        bbox[1] = max(0, center[1] - face_height // 2)
        bbox[2] = min(image.shape[1], center[0] + face_width // 2)
        bbox[3] = min(image.shape[0], center[1] + face_height // 2)

        # ``np.int`` was only an alias of the builtin ``int`` and no longer
        # exists in NumPy >= 1.24; the builtin yields the same dtype.
        bbox = bbox.astype(int)

        crop_image = image[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
        h, w, _ = crop_image.shape
        crop_image = cv2.resize(crop_image, self.detection_size)
        return crop_image, ([h, w, bbox[1], bbox[0]])

    def detect(self, img, bbox):
        """Predict landmarks (and a head-pose vector) for one face box.

        Args:
            img: Full image containing the face.
            bbox: NumPy array ``[x1, y1, x2, y2]`` of the face in ``img``.

        Returns:
            A tuple ``(landmark, pose)``: ``landmark`` is the ``(68, 2)``
            array returned by the tracker, in ``img`` coordinates, and
            ``pose`` is the first column of the second value returned by
            ``get_head_pose``.
            NOTE(review): the meaning/order of the pose components depends on
            ``get_head_pose`` - confirm against its definition.
        """
        crop_image, detail = self.crop_image(img, bbox)

        # Normalize pixel values, reorder HWC -> CHW and add a batch axis.
        crop_image = (crop_image - 127.0) / 127.0
        crop_image = np.array([np.transpose(crop_image, (2, 0, 1))])
        crop_image = torch.tensor(crop_image).float().cuda()

        with torch.no_grad():
            start = time.time()
            # Copying the result back to the CPU is inside the timed region.
            raw = self.model(crop_image)[0].cpu().numpy()
            end = time.time()
            print("PyTorch Inference Time: {:.6f}".format(end - start))
            # The first 136 outputs are read as 68 (x, y) pairs.
            landmark = raw[0:136].reshape((-1, 2))

        # Map crop-relative coordinates back into the full image:
        # detail = [crop_h, crop_w, top, left].
        # NOTE(review): this scaling presumes the network emits coordinates
        # normalized to the crop - confirm against the training code.
        landmark[:, 0] = landmark[:, 0] * detail[1] + detail[3]
        landmark[:, 1] = landmark[:, 1] * detail[0] + detail[2]

        landmark = self.tracker.track(img, landmark)
        _, PRY_3d = get_head_pose(landmark, img)
        return landmark, PRY_3d[:, 0]
        "Eval Avg Loss  -- Total: {:.4f} Landmark: {:.4f} Poss: {:.4f} LEye: {:.4f} REye: {:.4f} Mouth: {:.4f}".format(
            avg_total_loss, avg_landmark_loss, avg_loss_pose, avg_leye_loss, avg_reye_loss, avg_mouth_loss))
    torch.save(model.state_dict(), open("weights/slim128_epoch_{}_{:.4f}.pth".format(epoch, avg_landmark_loss), "wb"))


if __name__ == '__main__':
    # Training entry point. Every name bound below is a module-level global;
    # NOTE(review): train() and the other helpers in this file presumably
    # read them by name, so do not rename them without checking those
    # functions.

    # Path of a saved state dict to resume from; None starts from epoch 0.
    checkpoint = None
    # Enable cuDNN's auto-tuner (see the PyTorch cuDNN backend notes).
    torch.backends.cudnn.benchmark = True
    # NOTE(review): the third Landmark argument (True for train, False for
    # val) presumably toggles training-only behavior - confirm in Landmark.
    train_dataset = Landmark("train.json", input_size, True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_dataset = Landmark("val.json", input_size, False)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    model = Slim()
    model.train()
    model.cuda()
    if checkpoint is not None:
        model.load_state_dict(torch.load(checkpoint))
        # The epoch is parsed from the second-to-last "_"-separated field of
        # the checkpoint name, e.g. "slim128_epoch_12_0.0345.pth" -> 12, and
        # training resumes at the following epoch.
        start_epoch = int(checkpoint.split("_")[-2]) + 1
    else:
        start_epoch = 0

    # Loss objects are only created here; they are not used in this block.
    wing_loss_fn = WingLoss()
    mse_loss_fn = torch.nn.MSELoss()
    bce_loss_fn = torch.nn.BCEWithLogitsLoss()

    # The initial lr given to Adam is overwritten at the start of every epoch
    # by decay(epoch) in the loop below.
    optim = torch.optim.Adam(model.parameters(), lr=lr_value_every_epoch[0], weight_decay=5e-4)
    # Epochs run from start_epoch up to 149 inclusive.
    for epoch in range(start_epoch, 150):
        for param_group in optim.param_groups:
            param_group['lr'] = decay(epoch)
        train(epoch)