Example #1
def infer():
    test_dataloader = create_dataloader(root='.',
                                        batch_size=32,
                                        is_train=False)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    net = FAN()
    net.eval()
    net.to(device)
    net.load_state_dict(torch.load("ckpt_epoch_28", map_location=device))

    running_nme = 0.0

    for sample in test_dataloader:
        images, targets, t_hm, boxes = sample['image'], sample['landmarks'], \
            sample['heatmaps'], sample['bbox']

        images = images.to(device)
        targets = targets.to(device)
        boxes = boxes.to(device)

        # Do not build the autograd graph during the forward pass
        with torch.no_grad():
            preds = net(images)

        # break  # if we just want an image to illustrate

        # Compute batch NME
        running_nme += NME(preds[-1], targets, boxes)

    # Test data NME
    print("Evaluation NME:")
    print(running_nme / len(test_dataloader))
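# `NME` above is used but not defined in this example. A minimal sketch of
# one common definition, normalizing the mean point-to-point error by the
# bounding-box diagonal. (The call above passes the raw heatmap stack
# preds[-1], so the snippet's own NME presumably decodes peak coordinates
# first; this hypothetical sketch assumes already-decoded coordinates.)

def nme_sketch(preds, targets, boxes):
    # preds, targets: (B, 68, 2) landmark coordinates
    # boxes: (B, 4) bounding boxes as (x1, y1, x2, y2)
    dist = torch.norm(preds - targets, dim=2).mean(dim=1)  # (B,)
    diag = torch.sqrt((boxes[:, 2] - boxes[:, 0]) ** 2 +
                      (boxes[:, 3] - boxes[:, 1]) ** 2)    # (B,)
    return (dist / diag).mean().item()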
Example #2
test_loader = DataLoader(test_data, batch_size=1, shuffle=True, num_workers=0)

# val_loader = DataLoader(val_data, batch_size=8, num_workers=0)

# #----------------------------------------------------------------------------------------------

# #------------------------- Creating Model and loading pretrained weights -----------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
network_FAN = FAN(4)
network_FAN = network_FAN.to(device)

network_depth = ResNetDepth()

fan_weights = torch.load('checkpoint_60_dlib.pth.tar', map_location=device)
network_FAN.load_state_dict(fan_weights['model_state_dict'])
depth_weights = torch.load(
    '/home/kak/Documents/DFKI/TestingFAN/checkpoint_resnet_usingbb_3d.pth.tar',
    map_location=device)
#depth_weights = torch.load('/data/skak/project/FaceAlignmentNet/PretrainedModels/depth-2a464da4ea.pth.tar', map_location='cuda')

depth_dict = {
    k.replace('module.', ''): v
    for k, v in depth_weights['model_state_dict'].items()
}
network_depth.load_state_dict(depth_dict)

# for params in network_depth.parameters():
#     params.requires_grad = True

network_depth = network_depth.to(device)  # the depth net was never moved; network_FAN is already on `device`
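# With both networks loaded, a 3D prediction chains them the same way the
# FaceAlignment classes below do: the FAN emits 68 heatmaps, which are
# rendered at input resolution and concatenated with the image for the
# depth net. A minimal sketch, assuming `inp` is a normalized
# (1, 3, 256, 256) batch and `heatmaps` a (1, 68, 256, 256) rendering of
# the 2D predictions:

network_FAN.eval()
network_depth.eval()
with torch.no_grad():
    hm = network_FAN(inp)[-1]                             # last hourglass output
    depth = network_depth(torch.cat((inp, heatmaps), 1))  # per-landmark z-values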
Example #3
class FaceAlignment:
    def __init__(self,
                 landmarks_type,
                 network_size=NetworkSize.LARGE,
                 device='cuda',
                 flip_input=False,
                 verbose=False,
                 model_dir=None):
        self.device = device
        self.flip_input = flip_input
        self.landmarks_type = landmarks_type
        self.verbose = verbose

        network_size = int(network_size)

        if 'cuda' in device:
            torch.backends.cudnn.benchmark = True

        # Initialise the face alignment networks
        self.face_alignment_net = FAN(network_size)
        if landmarks_type == LandmarksType.point_2D:
            model_name = '2DFAN-' + str(network_size)
        else:
            model_name = '3DFAN-' + str(network_size)

        if model_dir is None or not os.path.exists(model_dir):
            raise Exception('Landmarks model directory not found')
        filename = models_urls[model_name]
        model_file = os.path.join(model_dir, filename)
        if not os.path.isfile(model_file):
            raise Exception(
                'Landmarks model file not found: {}'.format(model_file))
        if device == 'cpu':
            fan_weights = torch.load(model_file, map_location='cpu')
        else:
            fan_weights = torch.load(model_file)

        self.face_alignment_net.load_state_dict(fan_weights)

        self.face_alignment_net.to(device)
        self.face_alignment_net.eval()

        # Initialise the depth prediction network
        if landmarks_type == LandmarksType.point_3D:
            self.depth_prediciton_net = ResNetDepth()

            filename = models_urls['depth']
            model_file = os.path.join(model_dir, filename)
            if not os.path.exists(model_file):
                raise Exception('Landmarks depth model file not found')
            depth_weights = torch.load(model_file)

            depth_dict = {
                k.replace('module.', ''): v
                for k, v in depth_weights['state_dict'].items()
            }
            self.depth_prediciton_net.load_state_dict(depth_dict)

            self.depth_prediciton_net.to(device)
            self.depth_prediciton_net.eval()

    def get_landmarks(self, image_or_path, detected_face=None):
        """Predict the landmarks for each face present in the image.

        This function predicts a set of 68 2D or 3D images, one for each
         image present.
        If detect_faces is None the method will also run a face detector.

         Arguments:
            image_or_path {string or numpy.array or torch.tensor} --
            The input image or path to it.

        Keyword Arguments:
            detected_faces {numpy.array} -- bounding box for founded face
            in the image (default: None)
        """
        if isinstance(image_or_path, str):
            image = io.imread(image_or_path)
        else:
            image = image_or_path

        if image.ndim == 2:
            image = color.gray2rgb(image)
        elif image.ndim == 4:
            image = image[..., :3]

        if detected_face is None:
            raise Exception('No faces were received.')

        torch.set_grad_enabled(False)
        x1, y1, x2, y2 = detected_face

        center = torch.FloatTensor(
            [x2 - (x2 - x1) / 2.0, y2 - (y2 - y1) / 2.0])
        center[1] = center[1] - (y2 - y1) * 0.12
        scale = (x2 - x1 + y2 - y1) / 195

        inp = crop(image, center, scale)
        inp = torch.from_numpy(inp.transpose((2, 0, 1))).float()

        inp = inp.to(self.device)
        inp.div_(255.0).unsqueeze_(0)

        out = self.face_alignment_net(inp)[-1].detach()
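        # Optional test-time augmentation: also run the horizontally
        # flipped input and fold the result back (is_label=True re-orders
        # the mirrored landmark channels before summing).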
        if self.flip_input:
            out += flip(self.face_alignment_net(flip(inp))[-1].detach(),
                        is_label=True)
        out = out.cpu()

        pts, pts_img, score = get_preds_fromhm(out, center, scale)

        score = score.view(68)
        pts = pts.view(68, 2) * 4
        pts_img = pts_img.view(68, 2)

        if self.landmarks_type == LandmarksType.point_3D:
            heatmaps = np.zeros((68, 256, 256), dtype=np.float32)
            for j in range(68):
                if pts[j, 0] > 0:
                    heatmaps[j] = draw_gaussian(heatmaps[j], pts[j], 2)
            heatmaps = torch.from_numpy(heatmaps).unsqueeze_(0)

            heatmaps = heatmaps.to(self.device)
            depth_pred = self.depth_prediciton_net(
                torch.cat((inp, heatmaps), 1)).data.cpu().view(68, 1)
            pts_img = torch.cat(
                (pts_img, depth_pred * (1.0 / (256.0 / (200.0 * scale)))), 1)

        landmarks = pts_img.numpy()
        scores = score.numpy()

        return landmarks, scores
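# A minimal usage sketch for the class above; the model directory and the
# bounding box are illustrative assumptions:

fa = FaceAlignment(LandmarksType.point_3D, device='cuda',
                   model_dir='/path/to/models')
landmarks, scores = fa.get_landmarks('face.jpg',
                                     detected_face=(50, 60, 210, 230))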
Example #4
def train(print_every=10):

    max_epoch = 30
    # Learning-rate schedule; note the keys all exceed max_epoch, so it
    # only fires if max_epoch is raised past 50.
    lr = {50: 1e-4, 70: 5e-5, 90: 1e-5, 100: 5e-6}
    batch_size = 32

    # Dataloaders for train and test set
    train_dataloader = create_dataloader(root='.',
                                         batch_size=batch_size,
                                         is_train=True)
    test_dataloader = create_dataloader(root='.',
                                        batch_size=32,
                                        is_train=False)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    num_batches = len(train_dataloader)

    # Network configuration
    net = FAN()
    net = net.to(device)
    net = net.train()
    net.load_state_dict(torch.load("ckpt_epoch_28", map_location=device))

    # Loss function and Optimizer
    criterion = nn.MSELoss()
    optimizer = optim.RMSprop(net.parameters(), lr=2.5e-4)
    best_nme = 99.0

    for epoch in range(max_epoch):
        print("=============Epoch %i================" % (epoch))

        # Adjust Learning Rate based on epoch number
        if epoch in lr:
            adjust_lr(optimizer, lr[epoch])

        # Train
        for i_batch, sample_batch in enumerate(train_dataloader, 0):
            net.zero_grad()  # Zero gradients before each batch
            loss = 0

            inputs, targets = sample_batch['image'], sample_batch['heatmaps']
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = net(inputs)

            for o in outputs:
                loss += criterion(o, targets)
            loss.backward()
            optimizer.step()

            # Print training loss every N batches
            if (i_batch % print_every == 0):
                print("Batch %i/%i;	Loss: %.4f" %
                      (i_batch, num_batches, loss.item()))

        ## Evaluation at the end of epoch
        net = net.eval()
        current_nme = 0

        with torch.no_grad():
            for i_test, test_batch in enumerate(test_dataloader, 0):
                inputs, landmarks, boxes = test_batch['image'], \
                     test_batch['landmarks'], test_batch['bbox']
                inputs = inputs.to(device)
                boxes = boxes.to(device)
                landmarks = landmarks.to(device)
                outputs = net(inputs)
                nme = NME(outputs[-1], landmarks, boxes)
                current_nme += nme

        current_nme /= len(test_dataloader)
        print("Test NME: %.8f" % (current_nme))

        # Save model if it is the best thus far
        if current_nme < best_nme:
            best_nme = current_nme
            torch.save(net.state_dict(), "ckpt_epoch_" + str(epoch))
        net = net.train()
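# `adjust_lr` is used above but not shown. A minimal sketch of the usual
# pattern, setting the new rate on every parameter group:

def adjust_lr(optimizer, new_lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr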
Example #5
class FaceAlignment:
    def __init__(self, device='cuda', flip_input=False, face_detector='sfd', verbose=False):
        self.device = device
        self.flip_input = flip_input
        self.verbose = verbose

        network_size = 4

        if 'cuda' in device:
            torch.backends.cudnn.benchmark = True

        '''
        # Get the face detector
        face_detector_module = __import__('detection.' + face_detector,
                                          globals(), locals(), [face_detector], 0)
        self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
        '''
        self.face_detector = SFDDetector(device=device, verbose=verbose)
        # Initialise the face alignment networks
        self.face_alignment_net = FAN(network_size)

        fan_weights = torch.load('3DFAN.pth.tar',
                                 map_location=lambda storage, loc: storage)
        # map_location keeps the tensors on the CPU; they are moved to
        # `device` below
        self.face_alignment_net.load_state_dict(fan_weights)

        self.face_alignment_net.to(device)
        self.face_alignment_net.eval()

        # Initialise the depth prediction network
        self.depth_prediciton_net = ResNetDepth()

        depth_weights = torch.load('2D-to-3D.pth.tar',
                                   map_location=lambda storage, loc: storage)
        # map_location keeps the tensors on the CPU; they are moved to
        # `device` below
        depth_dict = {
            k.replace('module.', ''): v
            for k, v in depth_weights['state_dict'].items()
        }
        self.depth_prediciton_net.load_state_dict(depth_dict)

        self.depth_prediciton_net.to(device)
        self.depth_prediciton_net.eval()

    def get_landmarks(self, image_or_path):
        # Pass the input through unchanged: wrapping it in torch.tensor()
        # would fail for a path string, and detect_from_image accepts
        # paths, arrays, and tensors alike.
        detected_faces = self.face_detector.detect_from_image(image_or_path)
        return self.get_landmarks_from_image(image_or_path, detected_faces)

    @torch.no_grad()
    def get_landmarks_from_image(self, image_or_path, detected_faces):
        """

        This function predicts a set of 68 3D images, one for each image present.
         Arguments:
            image_or_path {string or numpy.array or torch.tensor} -- The input image or path to it.
            detected_faces {list of numpy.array} -- list of bounding boxes, one for each face found
            in the image
        """
        if isinstance(image_or_path, str):
            try:
                image = io.imread(image_or_path)
            except IOError:
                print("error opening file :: ", image_or_path)
                return None
        elif isinstance(image_or_path, torch.Tensor):
            image = image_or_path.detach().cpu().numpy()
        else:
            image = image_or_path

        if image.ndim == 2:
            image = color.gray2rgb(image)
        elif image.ndim == 4:
            image = image[..., :3]

        if len(detected_faces) == 0:
            print("Warning: No faces were detected.")
            return None

        landmarks = []
        for i, d in enumerate(detected_faces):
            center = torch.FloatTensor(
                [d[2] - (d[2] - d[0]) / 2.0, d[3] - (d[3] - d[1]) / 2.0])
            center[1] = center[1] - (d[3] - d[1]) * 0.12
            scale = (d[2] - d[0] + d[3] - d[1]) / self.face_detector.reference_scale

            inp = crop(image, center, scale)
            inp = torch.from_numpy(inp.transpose((2, 0, 1))).float()

            inp = inp.to(self.device)
            inp.div_(255.0).unsqueeze_(0)

            out = self.face_alignment_net(inp)[-1].detach()
            if self.flip_input:
                out += flip(self.face_alignment_net(flip(inp))
                            [-1].detach(), is_label=True)
            out = out.cpu()

            pts, pts_img = get_preds_fromhm(out, center, scale)
            pts, pts_img = pts.view(68, 2) * 4, pts_img.view(68, 2)

            heatmaps = np.zeros((68, 256, 256), dtype=np.float32)
            for j in range(68):  # j, not i, to avoid shadowing the face index
                if pts[j, 0] > 0:
                    heatmaps[j] = draw_gaussian(heatmaps[j], pts[j], 2)
            heatmaps = torch.from_numpy(heatmaps).unsqueeze_(0)

            heatmaps = heatmaps.to(self.device)
            depth_pred = self.depth_prediciton_net(
                torch.cat((inp, heatmaps), 1)).data.cpu().view(68, 1)
            pts_img = torch.cat(
                (pts_img, depth_pred * (1.0 / (256.0 / (200.0 * scale)))), 1)

            landmarks.append(pts_img.numpy())
        return landmarks

    @staticmethod
    def remove_models():
        base_path = os.path.join(appdata_dir('face_alignment'), "data")
        for data_model in os.listdir(base_path):
            file_path = os.path.join(base_path, data_model)
            try:
                if os.path.isfile(file_path):
                    print('Removing ' + data_model + ' ...')
                    os.unlink(file_path)
            except Exception as e:
                print(e)
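# A minimal usage sketch for this variant, which runs its own SFD detector;
# the image path is an illustrative assumption:

fa = FaceAlignment(device='cuda', flip_input=True)
all_landmarks = fa.get_landmarks(io.imread('group_photo.jpg'))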
Example #6
# test_data = Menpo(root_path='/home/kak/Documents/DFKI/MenpoTracking/Menpo_Challenge/test/')

# val_data, train_data = random_split(train_data,[2954,6000])

# train_loader = DataLoader(train_data, batch_size=1, shuffle=True, num_workers=0)

test_loader = DataLoader(test_data, batch_size=1, num_workers=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model1 = FAN(4)
# fan_weights1 = torch.load('checkpoint_60_dlib.pth.tar', map_location='cuda')
# model1.load_state_dict(fan_weights1['model_state_dict'])
fan_weights1 = torch.load(
    '/home/kak/Documents/DFKI/Resnet/PretrainedModels/3DFAN4-7835d9f11d.pth.tar',
    map_location=device)
model1.load_state_dict(fan_weights1)

model2 = ResNetDepth()
# depth_weights = torch.load('/home/kak/Documents/DFKI/TestingFAN/checkpoint_resnet_usingbb_3d.pth.tar', map_location='cuda')
depth_weights = torch.load(
    '/home/kak/Documents/DFKI/FaceAlignmentNet/PretrainedModels/depth-2a464da4ea.pth.tar',
    map_location=device)

depth_dict = {
    k.replace('module.', ''): v
    for k, v in depth_weights['state_dict'].items()
}
model2.load_state_dict(depth_dict)

# criterion = nn.MSELoss(reduction='mean')
# optimizer = optim.Adam(network_FAN.parameters(), lr=1e-6)
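# The snippet above leaves both models on the CPU and in training mode; a
# minimal evaluation setup, following the other examples, would finish with:

model1 = model1.to(device).eval()
model2 = model2.to(device).eval()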