Example #1
    def _preprocess(self):
        encoders = (encoder.Cif(self.head_metas[0], bmin=self.b_min),
                    encoder.Caf(self.head_metas[1], bmin=self.b_min))

        if not self.augmentation:
            return transforms.Compose([
                transforms.NormalizeAnnotations(),
                transforms.RescaleAbsolute(self.square_edge),
                transforms.CenterPad(self.square_edge),
                transforms.EVAL_TRANSFORM,
                transforms.Encoders(encoders),
            ])

        if self.extended_scale:
            rescale_t = transforms.RescaleRelative(
                scale_range=(0.2 * self.rescale_images,
                             2.0 * self.rescale_images),
                power_law=True,
                stretch_range=(0.75, 1.33))
        else:
            rescale_t = transforms.RescaleRelative(
                scale_range=(0.2 * self.rescale_images,
                             1.5 * self.rescale_images),
                power_law=True,
                stretch_range=(0.75, 1.33))

        blur_t = None
        if self.blur:
            blur_t = transforms.RandomApply(transforms.Blur(), self.blur)

        orientation_t = None
        if self.orientation_invariant:
            orientation_t = transforms.RandomApply(transforms.RotateBy90(),
                                                   self.orientation_invariant)

        return transforms.Compose([
            transforms.NormalizeAnnotations(),
            transforms.AnnotationJitter(),
            transforms.RandomApply(
                transforms.HFlip(self.CAR_KEYPOINTS, self.HFLIP), 0.5),
            rescale_t,
            blur_t,
            transforms.Crop(self.square_edge, use_area_of_interest=True),
            transforms.CenterPad(self.square_edge),
            orientation_t,
            transforms.TRAIN_TRANSFORM,
            transforms.Encoders(encoders),
        ])
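
Note how `transforms.RandomApply` gates each optional augmentation: the horizontal flip fires with probability 0.5, while blur and 90-degree rotation fire with the probabilities stored in `self.blur` and `self.orientation_invariant`. A minimal sketch of that wrapper pattern, assuming openpifpaf's `(image, anns, meta)` call convention (the real class lives in `openpifpaf.transforms`):

import random

class RandomApplySketch:
    """Apply `transform` with the given probability, else pass through."""
    def __init__(self, transform, probability):
        self.transform = transform
        self.probability = probability

    def __call__(self, image, anns, meta):
        if random.random() > self.probability:
            return image, anns, meta  # skip the wrapped transform
        return self.transform(image, anns, meta)
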
Example #2
    @classmethod
    def common_eval_preprocess(cls):
        rescale_t = None
        if cls.eval_extended_scale:
            assert cls.eval_long_edge
            rescale_t = transforms.DeterministicEqualChoice([
                transforms.RescaleAbsolute(cls.eval_long_edge),
                transforms.RescaleAbsolute((cls.eval_long_edge - 1) // 2 + 1),
            ], salt=1)
        elif cls.eval_long_edge:
            rescale_t = transforms.RescaleAbsolute(cls.eval_long_edge)

        if cls.batch_size == 1:
            padding_t = transforms.CenterPadTight(16)
        else:
            assert cls.eval_long_edge
            padding_t = transforms.CenterPad(cls.eval_long_edge)

        orientation_t = None
        if cls.eval_orientation_invariant:
            orientation_t = transforms.DeterministicEqualChoice([
                None,
                transforms.RotateBy90(fixed_angle=90),
                transforms.RotateBy90(fixed_angle=180),
                transforms.RotateBy90(fixed_angle=270),
            ], salt=3)

        return [
            transforms.NormalizeAnnotations(),
            rescale_t,
            padding_t,
            orientation_t,
        ]
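
`DeterministicEqualChoice` makes evaluation-time choices reproducible: instead of a random draw, the option is derived from a per-image id plus a fixed salt, so the same image always gets the same rescale or rotation. A sketch of that idea (keying off `meta['image_id']` is an assumption of this sketch):

class DeterministicEqualChoiceSketch:
    """Pick one of `choices` deterministically per image; None passes through."""
    def __init__(self, choices, salt=0):
        self.choices = choices
        self.salt = salt

    def __call__(self, image, anns, meta):
        choice_index = (meta['image_id'] + self.salt) % len(self.choices)
        chosen = self.choices[choice_index]
        if chosen is None:
            return image, anns, meta
        return chosen(image, anns, meta)
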
Example #3
def preprocess_factory_from_args(args):
    collate_fn = datasets.collate_images_anns_meta
    if args.batch_size == 1 and not args.multi_scale:
        preprocess = transforms.Compose([
            transforms.NormalizeAnnotations(),
            transforms.RescaleAbsolute(args.long_edge),
            transforms.EVAL_TRANSFORM,
        ])
    else:
        preprocess = transforms.Compose([
            transforms.NormalizeAnnotations(),
            transforms.RescaleAbsolute(args.long_edge),
            transforms.CenterPad(args.long_edge),
            transforms.EVAL_TRANSFORM,
        ])

    return preprocess, collate_fn
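
The `else` branch pads because batching requires every tensor in a batch to have the same shape; `CenterPad(args.long_edge)` gives all images a common size before they are stacked. A tiny standalone demonstration of why the padding is needed:

import torch

a = torch.zeros(3, 480, 640)
b = torch.zeros(3, 360, 500)
try:
    torch.stack([a, b])  # fails: batch members must share one shape
except RuntimeError as exc:
    print('stack failed:', exc)

# after padding both images to a common square edge, stacking works
a_padded = torch.zeros(3, 640, 640)
b_padded = torch.zeros(3, 640, 640)
batch = torch.stack([a_padded, b_padded])  # shape (2, 3, 640, 640)
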
Example #4
def preprocess_factory(args):
    rescale_t = None
    if args.long_edge:
        rescale_t = transforms.RescaleAbsolute(args.long_edge, fast=args.fast_rescaling)

    pad_t = None
    if args.batch_size > 1:
        assert args.long_edge, '--long-edge must be provided for batch size > 1'
        pad_t = transforms.CenterPad(args.long_edge)
    else:
        pad_t = transforms.CenterPadTight(16)

    return transforms.Compose([
        transforms.NormalizeAnnotations(),
        rescale_t,
        pad_t,
        transforms.EVAL_TRANSFORM,
    ])
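
Applying the composed pipeline follows openpifpaf's transform convention of a callable over `(image, anns, meta)`. A usage sketch, assuming `NormalizeAnnotations` fills in a default `meta` when it receives `None` (file name and edge length are hypothetical):

import PIL.Image
from openpifpaf import transforms  # import path assumed

preprocess = transforms.Compose([
    transforms.NormalizeAnnotations(),
    transforms.RescaleAbsolute(641),   # hypothetical --long-edge value
    transforms.CenterPadTight(16),     # batch size 1, as in the factory above
    transforms.EVAL_TRANSFORM,
])

image = PIL.Image.open('example.jpg').convert('RGB')  # hypothetical file
processed_image, anns, meta = preprocess(image, [], None)
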
Example #5
def test_pad(x=4, y=6):
    image_xy, keypoint_xy = single_pixel_transform(x, y, transforms.CenterPad(17))
    print(image_xy, keypoint_xy)
    assert image_xy == keypoint_xy
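
`single_pixel_transform` is defined elsewhere in the test suite; the idea is to paint exactly one bright pixel at `(x, y)`, attach a keypoint at the same spot, push both through the transform, and report where each lands, so the assertion verifies that `CenterPad` shifts image content and annotations by the same offset. A sketch of such a helper (the real one may differ):

import numpy as np
import PIL.Image
from openpifpaf import transforms  # import path assumed

def single_pixel_transform_sketch(x, y, transform, image_wh=(13, 11)):
    # one white pixel on a black canvas
    canvas = np.zeros((image_wh[1], image_wh[0], 3), dtype=np.uint8)
    canvas[y, x, :] = 255
    image = PIL.Image.fromarray(canvas)

    # one keypoint annotation at the same location
    anns = [{'keypoints': [[x, y, 2.0]], 'bbox': [x, y, 1.0, 1.0]}]

    pipeline = transforms.Compose([transforms.NormalizeAnnotations(), transform])
    image, anns, _ = pipeline(image, anns, None)

    # recover the white pixel and the transformed keypoint
    pixel = np.asarray(image)[:, :, 0]
    pixel_y, pixel_x = np.unravel_index(np.argmax(pixel), pixel.shape)
    keypoint_x, keypoint_y = anns[0]['keypoints'][0, :2]
    return [pixel_x, pixel_y], [int(keypoint_x), int(keypoint_y)]
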
Example #6
def main():
    args = cli()
    if args.our_new_model:
        args.checkpoint = TRAINED_MODEL_PATH
    # load model
    model_cpu, _ = nets.factory_from_args(args)
    model = model_cpu.to(args.device)
    if not args.disable_cuda and torch.cuda.device_count() > 1:
        LOG.info('Using multiple GPUs: %d', torch.cuda.device_count())
        model = torch.nn.DataParallel(model)
        model.head_names = model_cpu.head_names
        model.head_strides = model_cpu.head_strides
    processor = decoder.factory_from_args(args, model, args.device)

    # data
    preprocess = None
    if args.long_edge:
        preprocess = transforms.Compose([
            transforms.NormalizeAnnotations(),
            transforms.RescaleAbsolute(args.long_edge),
            transforms.CenterPad(args.long_edge),
            transforms.EVAL_TRANSFORM,
        ])
    data = datasets.ImageList(args.images, preprocess=preprocess)
    data_loader = torch.utils.data.DataLoader(
        data,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=args.pin_memory,
        num_workers=args.loader_workers,
        collate_fn=datasets.collate_images_anns_meta)

    # visualizers
    keypoint_painter = show.KeypointPainter(
        show_box=args.debug,
        show_joint_scale=args.debug,
    )
    skeleton_painter = show.KeypointPainter(
        color_connections=True,
        markersize=args.line_width - 5,
        linewidth=args.line_width,
        show_box=args.debug,
        show_joint_scale=args.debug,
    )

    for batch_i, (image_tensors_batch, _,
                  meta_batch) in enumerate(data_loader):
        fields_batch = processor.fields(image_tensors_batch)
        pred_batch = processor.annotations_batch(
            fields_batch, debug_images=image_tensors_batch)

        # unbatch
        for pred, meta in zip(pred_batch, meta_batch):
            if args.output_directory is None:
                output_path = meta['file_name']
            else:
                file_name = os.path.basename(meta['file_name'])
                output_path = os.path.join(args.output_directory, file_name)
            LOG.info('batch %d: %s to %s', batch_i, meta['file_name'],
                     output_path)

            # load the original image if necessary
            cpu_image = None
            if args.debug or \
                    'keypoints' in args.output_types or \
                    'skeleton' in args.output_types:
                with open(meta['file_name'], 'rb') as f:
                    cpu_image = PIL.Image.open(f).convert('RGB')

            processor.set_cpu_image(cpu_image, None)
            if preprocess is not None:
                pred = preprocess.annotations_inverse(pred, meta)

            if 'json' in args.output_types:
                with open(output_path + '.pifpaf.json', 'w') as f:
                    json.dump([{
                        'keypoints':
                        np.around(ann.data, 1).reshape(-1).tolist(),
                        'bbox':
                        np.around(bbox_from_keypoints(ann.data), 1).tolist(),
                        'score':
                        round(ann.score(), 3),
                    } for ann in pred], f)

            if 'keypoints' in args.output_types:
                with show.image_canvas(cpu_image,
                                       output_path + '.keypoints.png',
                                       show=args.show,
                                       fig_width=args.figure_width,
                                       dpi_factor=args.dpi_factor) as ax:
                    keypoint_painter.annotations(ax, pred)

            if 'skeleton' in args.output_types:
                with show.image_canvas(cpu_image,
                                       output_path + '.skeleton.png',
                                       show=args.show,
                                       fig_width=args.figure_width,
                                       dpi_factor=args.dpi_factor) as ax:
                    skeleton_painter.annotations(ax, pred)
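
`bbox_from_keypoints`, used when writing the JSON output above, is not shown in this snippet. A plausible implementation computes the axis-aligned box around the keypoints with positive confidence (a sketch; the real helper may differ):

import numpy as np

def bbox_from_keypoints_sketch(kps):
    """Return [x, y, w, h] around keypoints with confidence > 0.

    `kps` is an (n_keypoints, 3) array of (x, y, confidence) rows.
    """
    visible = kps[kps[:, 2] > 0.0]
    if visible.shape[0] == 0:
        return [0.0, 0.0, 0.0, 0.0]
    x_min, y_min = visible[:, 0].min(), visible[:, 1].min()
    x_max, y_max = visible[:, 0].max(), visible[:, 1].max()
    return [float(x_min), float(y_min),
            float(x_max - x_min), float(y_max - y_min)]
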
Example #7
def main():
    args = cli()
    logs.configure(args)
    net_cpu, start_epoch = nets.factory_from_args(args)

    net = net_cpu.to(device=args.device)
    if not args.disable_cuda and torch.cuda.device_count() > 1:
        print('Using multiple GPUs: {}'.format(torch.cuda.device_count()))
        net = torch.nn.DataParallel(net)

    loss = losses.factory_from_args(args)
    target_transforms = encoder.factory(args, net_cpu.head_strides)

    if args.augmentation:
        preprocess_transformations = [
            transforms.NormalizeAnnotations(),
            transforms.AnnotationJitter(),
            transforms.RandomApply(transforms.HFlip(), 0.5),
            transforms.RescaleRelative(scale_range=(0.4 * args.rescale_images,
                                                    2.0 * args.rescale_images),
                                       power_law=True),
            transforms.Crop(args.square_edge),
            transforms.CenterPad(args.square_edge),
        ]
        if args.orientation_invariant:
            preprocess_transformations += [
                transforms.RotateBy90(),
            ]
        preprocess_transformations += [
            transforms.TRAIN_TRANSFORM,
        ]
    else:
        preprocess_transformations = [
            transforms.NormalizeAnnotations(),
            transforms.RescaleAbsolute(args.square_edge),
            transforms.CenterPad(args.square_edge),
            transforms.EVAL_TRANSFORM,
        ]
    preprocess = transforms.Compose(preprocess_transformations)
    train_loader, val_loader, pre_train_loader = datasets.train_factory(
        args, preprocess, target_transforms)

    optimizer = optimize.factory_optimizer(
        args,
        list(net.parameters()) + list(loss.parameters()))
    lr_scheduler = optimize.factory_lrscheduler(args, optimizer,
                                                len(train_loader))
    encoder_visualizer = None
    if args.debug_pif_indices or args.debug_paf_indices:
        encoder_visualizer = encoder.Visualizer(
            args.headnets,
            net_cpu.head_strides,
            pif_indices=args.debug_pif_indices,
            paf_indices=args.debug_paf_indices)

    if args.freeze_base:
        # freeze base net parameters
        frozen_params = set()
        for n, p in net.named_parameters():
            # Freeze only base_net parameters.
            # Parameter names in DataParallel models start with 'module.'.
            if not n.startswith('module.base_net.') and \
               not n.startswith('base_net.'):
                print('not freezing', n)
                continue
            print('freezing', n)
            if p.requires_grad is False:
                continue
            p.requires_grad = False
            frozen_params.add(p)
        print('froze {} parameters'.format(len(frozen_params)))

        # training
        foptimizer = torch.optim.SGD(
            (p for p in net.parameters() if p.requires_grad),
            lr=args.pre_lr,
            momentum=0.9,
            weight_decay=0.0,
            nesterov=True)
        ftrainer = Trainer(net,
                           loss,
                           foptimizer,
                           args.output,
                           device=args.device,
                           fix_batch_norm=True,
                           encoder_visualizer=encoder_visualizer)
        for i in range(-args.freeze_base, 0):
            ftrainer.train(pre_train_loader, i)

        # unfreeze
        for p in frozen_params:
            p.requires_grad = True

    trainer = Trainer(
        net,
        loss,
        optimizer,
        args.output,
        lr_scheduler=lr_scheduler,
        device=args.device,
        fix_batch_norm=not args.update_batchnorm_runningstatistics,
        stride_apply=args.stride_apply,
        ema_decay=args.ema,
        encoder_visualizer=encoder_visualizer,
        train_profile=args.profile,
        model_meta_data={
            'args': vars(args),
            'version': VERSION,
            'hostname': socket.gethostname(),
        },
    )
    trainer.loop(train_loader,
                 val_loader,
                 args.epochs,
                 start_epoch=start_epoch)
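
The `--freeze-base` branch above is the standard PyTorch freezing pattern: set `requires_grad = False` on the base-network parameters, build an optimizer over only the still-trainable ones, pre-train, then flip the flags back. A minimal standalone sketch of that pattern with a toy module:

import torch

# toy stand-ins for base_net and a head
net = torch.nn.Sequential(
    torch.nn.Linear(4, 8),  # pretend base_net
    torch.nn.Linear(8, 2),  # pretend head
)

# freeze the "base", remember what was frozen
frozen = [p for p in net[0].parameters() if p.requires_grad]
for p in frozen:
    p.requires_grad = False

# the optimizer sees only trainable parameters, as in the code above
optimizer = torch.optim.SGD(
    (p for p in net.parameters() if p.requires_grad),
    lr=0.001, momentum=0.9, nesterov=True)

# ... pre-training steps would go here ...

# unfreeze for the main training phase
for p in frozen:
    p.requires_grad = True
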
Example #8
def generate(m, inputs):
    args = cli()
    model, processor = m
    image = inputs["image"]

    # data
    preprocess = None
    if args.long_edge:
        preprocess = transforms.Compose([
            transforms.Normalize(),
            transforms.RescaleAbsolute(args.long_edge),
            transforms.CenterPad(args.long_edge),
        ])
    data = datasets.PilImageList([image], preprocess=preprocess)
    data_loader = torch.utils.data.DataLoader(data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=args.pin_memory,
                                              num_workers=args.loader_workers)

    # visualizers
    keypoint_painter = show.KeypointPainter(show_box=False)
    skeleton_painter = show.KeypointPainter(show_box=False,
                                            color_connections=True,
                                            markersize=1,
                                            linewidth=6)

    image_paths, image_tensors, processed_images_cpu = next(iter(data_loader))
    images = image_tensors.permute(0, 2, 3, 1)

    processed_images = processed_images_cpu.to(args.device, non_blocking=True)
    fields_batch = processor.fields(processed_images)
    pred_batch = processor.annotations_batch(fields_batch,
                                             debug_images=processed_images_cpu)

    # unbatch
    image_path, image, processed_image_cpu, pred = image_paths[0], images[
        0], processed_images_cpu[0], pred_batch[0]

    processor.set_cpu_image(image, processed_image_cpu)
    keypoint_sets, scores = processor.keypoint_sets_from_annotations(pred)

    kp_json = json.dumps([{
        'keypoints': np.around(kps, 1).reshape(-1).tolist(),
        'bbox': bbox_from_keypoints(kps),
    } for kps in keypoint_sets])

    kwargs = {
        'figsize': (args.figure_width,
                    args.figure_width * image.shape[0] / image.shape[1]),
    }
    fig = plt.figure(**kwargs)
    ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
    ax.set_axis_off()
    ax.set_xlim(0, image.shape[1])
    ax.set_ylim(image.shape[0], 0)
    fig.add_axes(ax)
    ax.imshow(image)
    skeleton_painter.keypoints(ax, keypoint_sets, scores=scores)

    fig.canvas.draw()
    w, h = fig.canvas.get_width_height()
    output_image = np.frombuffer(fig.canvas.tostring_rgb(),
                                 dtype='uint8').reshape(h, w, 3)

    return {'keypoints': kp_json, 'image': output_image}
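
Rendering the Matplotlib canvas into a NumPy array, as done at the end of `generate`, is a general pattern. On current Matplotlib, `tostring_rgb` is deprecated in favor of `buffer_rgba`; a standalone sketch of the forward-compatible variant:

import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend; the canvas renders without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot([0, 1], [0, 1])
fig.canvas.draw()

w, h = fig.canvas.get_width_height()
rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).reshape(h, w, 4)
rgb = rgba[:, :, :3]  # drop the alpha channel to match the RGB output above
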