def test(model, path):
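    """Evaluate `model` on the dataset listed at `path`.

    Returns (classification accuracy, total loss, cls loss, offset loss,
    landmark loss), each averaged over the dataset.
    """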

    batch_size = 64
    dataset = ListDataset(path)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True)
    dataset_sizes = len(dataset)

    model.eval()
    loss_cls = nn.CrossEntropyLoss()
    loss_offset = nn.MSELoss()
    loss_landmark = nn.MSELoss()

    running_correct = 0
    running_gt = 0
    running_loss, running_loss_cls, running_loss_offset, running_loss_landmark = 0.0, 0.0, 0.0, 0.0

    for i_batch, sample_batched in enumerate(dataloader):

        printProgressBar(i_batch + 1,
                         dataset_sizes // batch_size + 1,
                         prefix='Progress:',
                         suffix='Complete',
                         length=50)

        input_images, gt_label, gt_offset, landmark_offset = sample_batched[
            'input_img'], sample_batched['label'], sample_batched[
                'bbox_target'], sample_batched['landmark']
        input_images = input_images.to(device)
        gt_label = gt_label.to(device)
        gt_offset = gt_offset.type(torch.FloatTensor).to(device)
        landmark_offset = landmark_offset.type(torch.FloatTensor).to(device)

        with torch.set_grad_enabled(False):
            pred_landmark, pred_offsets, pred_label = model(input_images)

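            # calculate the cls loss
            # keep elements with label >= 0; only labels 0 and 1 affect the classification loss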
            mask_cls = torch.ge(gt_label, 0)
            valid_gt_label = gt_label[mask_cls]
            valid_pred_label = pred_label[mask_cls]

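            # calculate the box loss
            # keep elements with label != 0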
            mask_offset = torch.ne(gt_label, 0)
            valid_gt_offset = gt_offset[mask_offset]
            valid_pred_offset = pred_offsets[mask_offset]

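            # calculate the landmark loss
            # keep elements with label == -2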
            mask_lm = torch.eq(gt_label, -2)
            valid_landmark_offset = landmark_offset[mask_lm]
            valid_pred_landmark = pred_landmark[mask_lm]

            loss = torch.tensor(0.0).to(device)
            cls_loss, offset_loss, landmark_loss = 0.0, 0.0, 0.0
            eval_correct = 0.0
            num_gt = len(valid_gt_label)

            if len(valid_gt_label) != 0:
                loss += 0.02 * loss_cls(valid_pred_label, valid_gt_label)
                cls_loss = loss_cls(valid_pred_label, valid_gt_label).item()
                pred = torch.max(valid_pred_label, 1)[1]
                eval_correct = (pred == valid_gt_label).sum().item()

            if len(valid_gt_offset) != 0:
                loss += 0.6 * loss_offset(valid_pred_offset, valid_gt_offset)
                offset_loss = loss_offset(valid_pred_offset,
                                          valid_gt_offset).item()

            if len(valid_landmark_offset) != 0:
                loss += 3 * loss_landmark(valid_pred_landmark,
                                          valid_landmark_offset)
                landmark_loss = loss_landmark(valid_pred_landmark,
                                              valid_landmark_offset).item()

            # statistics
            running_loss += loss.item() * batch_size
            running_loss_cls += cls_loss * batch_size
            running_loss_offset += offset_loss * batch_size
            running_loss_landmark += landmark_loss * batch_size
            running_correct += eval_correct
            running_gt += num_gt

    epoch_loss = running_loss / dataset_sizes
    epoch_loss_cls = running_loss_cls / dataset_sizes
    epoch_loss_offset = running_loss_offset / dataset_sizes
    epoch_loss_landmark = running_loss_landmark / dataset_sizes
    epoch_accuracy = running_correct / (running_gt + 1e-16)

    return epoch_accuracy, epoch_loss, epoch_loss_cls, epoch_loss_offset, epoch_loss_landmark


def prune_model(model, prunner, path):
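    """Run one gradient pass over the dataset at `path` so that `prunner`
    can accumulate and normalize filter ranks, then return the resulting
    pruning plan (the filters selected for removal)."""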

    batch_size = 64
    dataset = ListDataset(path)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True)
    dataset_sizes = len(dataset)

    model.train()
    loss_cls = nn.CrossEntropyLoss()
    loss_offset = nn.MSELoss()
    loss_landmark = nn.MSELoss()

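    # clear any filter ranks accumulated during a previous pruning pass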
    prunner.reset()

    for i_batch, sample_batched in enumerate(dataloader):

        printProgressBar(i_batch + 1,
                         dataset_sizes // batch_size + 1,
                         prefix='Progress:',
                         suffix='Complete',
                         length=50)

        input_images, gt_label, gt_offset, landmark_offset = sample_batched[
            'input_img'], sample_batched['label'], sample_batched[
                'bbox_target'], sample_batched['landmark']
        input_images = input_images.to(device)
        gt_label = gt_label.to(device)
        gt_offset = gt_offset.type(torch.FloatTensor).to(device)
        landmark_offset = landmark_offset.type(torch.FloatTensor).to(device)

        # zero the parameter gradients
        model.zero_grad()

        with torch.set_grad_enabled(True):
            pred_landmark, pred_offsets, pred_label = prunner.forward(
                input_images)

            # calculate the cls loss
            # keep elements with label >= 0; only labels 0 and 1 affect the classification loss
            mask_cls = torch.ge(gt_label, 0)
            valid_gt_label = gt_label[mask_cls]
            valid_pred_label = pred_label[mask_cls]

            # calculate the box loss
            # keep elements with label != 0
            mask_offset = torch.ne(gt_label, 0)
            valid_gt_offset = gt_offset[mask_offset]
            valid_pred_offset = pred_offsets[mask_offset]

            # calculate the landmark loss
            # keep elements with label == -2
            mask_lm = torch.eq(gt_label, -2)
            valid_landmark_offset = landmark_offset[mask_lm]
            valid_pred_landmark = pred_landmark[mask_lm]

            loss = torch.tensor(0.0).to(device)

            if len(valid_gt_label) != 0:
                loss += 0.02 * loss_cls(valid_pred_label, valid_gt_label)

            if len(valid_gt_offset) != 0:
                loss += 0.6 * loss_offset(valid_pred_offset, valid_gt_offset)

            if len(valid_landmark_offset) != 0:
                loss += 3 * loss_landmark(valid_pred_landmark,
                                          valid_landmark_offset)

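            # backward pass only: the gradients feed the prunner's filter
            # ranking; no optimizer step is taken in this pass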
            loss.backward()

    prunner.normalize_ranks_per_layer()
    filters_to_prune = prunner.get_prunning_plan(args.filter_size)

    return filters_to_prune


def train(model, path, epoch=10):
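    """Fine-tune `model` on the dataset at `path` for `epoch` epochs with Adam,
    minimizing the weighted sum of the cls, box offset and landmark losses."""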

    batch_size = 32
    dataset = ListDataset(path)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True)
    dataset_sizes = len(dataset)

    model.train()
    loss_cls = nn.CrossEntropyLoss()
    loss_offset = nn.MSELoss()
    loss_landmark = nn.MSELoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    num_epochs = epoch
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))

        running_loss, running_loss_cls, running_loss_offset, running_loss_landmark = 0.0, 0.0, 0.0, 0.0
        running_correct = 0.0
        running_gt = 0.0

        for i_batch, sample_batched in enumerate(dataloader):

            printProgressBar(i_batch + 1,
                             dataset_sizes // batch_size + 1,
                             prefix='Progress:',
                             suffix='Complete',
                             length=50)

            input_images, gt_label, gt_offset, landmark_offset = sample_batched[
                'input_img'], sample_batched['label'], sample_batched[
                    'bbox_target'], sample_batched['landmark']
            input_images = input_images.to(device)
            gt_label = gt_label.to(device)
            gt_offset = gt_offset.type(torch.FloatTensor).to(device)
            landmark_offset = landmark_offset.type(
                torch.FloatTensor).to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            with torch.set_grad_enabled(True):
                pred_landmark, pred_offsets, pred_label = model(input_images)

                # calculate the cls loss
                # keep elements with label >= 0; only labels 0 and 1 affect the classification loss
                mask_cls = torch.ge(gt_label, 0)
                valid_gt_label = gt_label[mask_cls]
                valid_pred_label = pred_label[mask_cls]

                # calculate the box loss
                # keep elements with label != 0
                mask_offset = torch.ne(gt_label, 0)
                valid_gt_offset = gt_offset[mask_offset]
                valid_pred_offset = pred_offsets[mask_offset]

                # calculate the landmark loss
                # keep elements with label == -2
                mask_lm = torch.eq(gt_label, -2)
                valid_landmark_offset = landmark_offset[mask_lm]
                valid_pred_landmark = pred_landmark[mask_lm]

                loss = torch.tensor(0.0).to(device)
                cls_loss, offset_loss, landmark_loss = 0.0, 0.0, 0.0
                eval_correct = 0.0
                num_gt = len(valid_gt_label)

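                # weighted multi-task loss: 0.02 * cls + 0.6 * box offset + 3 * landmark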
                if len(valid_gt_label) != 0:
                    loss += 0.02 * loss_cls(valid_pred_label, valid_gt_label)
                    cls_loss = loss_cls(valid_pred_label,
                                        valid_gt_label).item()
                    pred = torch.max(valid_pred_label, 1)[1]
                    eval_correct = (pred == valid_gt_label).sum().item()

                if len(valid_gt_offset) != 0:
                    loss += 0.6 * loss_offset(valid_pred_offset,
                                              valid_gt_offset)
                    offset_loss = loss_offset(valid_pred_offset,
                                              valid_gt_offset).item()

                if len(valid_landmark_offset) != 0:
                    loss += 3 * loss_landmark(valid_pred_landmark,
                                              valid_landmark_offset)
                    landmark_loss = loss_landmark(valid_pred_landmark,
                                                  valid_landmark_offset).item()

                loss.backward()
                optimizer.step()

                # statistics
                running_loss += loss.item() * batch_size
                running_loss_cls += cls_loss * batch_size
                running_loss_offset += offset_loss * batch_size
                running_loss_landmark += landmark_loss * batch_size
                running_correct += eval_correct
                running_gt += num_gt

        epoch_loss = running_loss / dataset_sizes
        epoch_loss_cls = running_loss_cls / dataset_sizes
        epoch_loss_offset = running_loss_offset / dataset_sizes
        epoch_loss_landmark = running_loss_landmark / dataset_sizes
        epoch_accuracy = running_correct / (running_gt + 1e-16)

        print(
            'accuracy: {:.4f} loss: {:.4f} cls Loss: {:.4f} offset Loss: {:.4f} landmark Loss: {:.4f}'
            .format(epoch_accuracy, epoch_loss, epoch_loss_cls,
                    epoch_loss_offset, epoch_loss_landmark))