示例#1
0
    def common_eval_preprocess(cls):
        """Assemble the preprocessing steps used for evaluation.

        The steps are selected from settings read off ``cls``:
        ``eval_extended_scale``, ``eval_long_edge``, ``batch_size`` and
        ``eval_orientation_invariant``.

        Returns:
            A list with four entries, in order: annotation normalization,
            the rescale step, the padding step and the orientation step.
            The rescale and orientation entries are ``None`` when the
            corresponding setting is off.

        Raises:
            AssertionError: if ``eval_long_edge`` is falsy while either
                extended-scale evaluation is requested or the batch size
                is not 1.
        """
        extended_scale = cls.eval_extended_scale
        long_edge = cls.eval_long_edge

        # --- rescaling ---
        if extended_scale:
            assert long_edge
            full_size = transforms.RescaleAbsolute(long_edge)
            # Second option is the long edge halved, rounded up.
            reduced_size = transforms.RescaleAbsolute((long_edge - 1) // 2 + 1)
            # NOTE(review): this entry is a one-element list, so the returned
            # list is nested in this branch -- confirm the consumer accepts it.
            rescale_step = [
                transforms.DeterministicEqualChoice(
                    [full_size, reduced_size], salt=1),
            ]
        elif long_edge:
            rescale_step = transforms.RescaleAbsolute(long_edge)
        else:
            rescale_step = None

        # --- padding ---
        if cls.batch_size == 1:
            pad_step = transforms.CenterPadTight(16)
        else:
            # Padding to a fixed size requires a configured long edge.
            assert long_edge
            pad_step = transforms.CenterPad(long_edge)

        # --- orientation ---
        if cls.eval_orientation_invariant:
            rotations = [transforms.RotateBy90(fixed_angle=angle)
                         for angle in (90, 180, 270)]
            # The leading None is one of the equally weighted choices.
            orientation_step = transforms.DeterministicEqualChoice(
                [None] + rotations, salt=3)
        else:
            orientation_step = None

        return [
            transforms.NormalizeAnnotations(),
            rescale_step,
            pad_step,
            orientation_step,
        ]
示例#2
0
def test_rotateby90(x, y=6):
    """Check that SquarePad followed by RotateBy90 keeps a pixel and its
    keypoint at matching coordinates.

    ``x`` and ``y`` are forwarded to ``single_pixel_transform`` together
    with the composed transform; the two coordinate pairs it returns must
    be approximately equal.
    """
    steps = [
        transforms.SquarePad(),
        transforms.RotateBy90(),
    ]
    pipeline = transforms.Compose(steps)

    pixel_position, keypoint_position = single_pixel_transform(x, y, pipeline)
    # Printed so the values appear in the captured output of a failing run.
    print(pixel_position, keypoint_position)
    assert pixel_position == pytest.approx(keypoint_position)
示例#3
0
    def _preprocess(self):
        """Compose the preprocessing pipeline, including the target encoders.

        Without augmentation the pipeline is: normalize annotations, rescale
        and center-pad to ``self.square_edge``, ``EVAL_TRANSFORM``, encoders.

        With augmentation it additionally applies annotation jitter, a
        horizontal flip wrapped in ``RandomApply`` with 0.5, a relative
        rescale, an optional blur, a crop and an optional 90-degree rotation
        before ``TRAIN_TRANSFORM`` and the encoders.

        Returns:
            A ``transforms.Compose`` instance.
        """
        target_encoders = (
            encoder.Cif(self.head_metas[0], bmin=self.b_min),
            encoder.Caf(self.head_metas[1], bmin=self.b_min),
        )

        if not self.augmentation:
            return transforms.Compose([
                transforms.NormalizeAnnotations(),
                transforms.RescaleAbsolute(self.square_edge),
                transforms.CenterPad(self.square_edge),
                transforms.EVAL_TRANSFORM,
                transforms.Encoders(target_encoders),
            ])

        # Extended scale only raises the upper bound of the scale range.
        upper_factor = 2.0 if self.extended_scale else 1.5
        rescale_step = transforms.RescaleRelative(
            scale_range=(0.2 * self.rescale_images,
                         upper_factor * self.rescale_images),
            power_law=True,
            stretch_range=(0.75, 1.33))

        # self.blur and self.orientation_invariant both switch the step on
        # and are passed as the second argument of RandomApply.
        blur_step = (
            transforms.RandomApply(transforms.Blur(), self.blur)
            if self.blur else None
        )
        orientation_step = (
            transforms.RandomApply(transforms.RotateBy90(),
                                   self.orientation_invariant)
            if self.orientation_invariant else None
        )

        flip_step = transforms.HFlip(self.CAR_KEYPOINTS, self.HFLIP)
        return transforms.Compose([
            transforms.NormalizeAnnotations(),
            transforms.AnnotationJitter(),
            transforms.RandomApply(flip_step, 0.5),
            rescale_step,
            blur_step,
            transforms.Crop(self.square_edge, use_area_of_interest=True),
            transforms.CenterPad(self.square_edge),
            orientation_step,
            transforms.TRAIN_TRANSFORM,
            transforms.Encoders(target_encoders),
        ])
示例#4
0
def main():
    """Training entry point.

    Parses the command-line arguments, builds the network, loss, target
    encoders, preprocessing pipeline, data loaders, optimizer and learning
    rate scheduler, optionally runs a pre-training phase with the base
    network frozen, and then runs the main training loop.
    """
    args = cli()
    logs.configure(args)
    net_cpu, start_epoch = nets.factory_from_args(args)

    net = net_cpu.to(device=args.device)
    multi_gpu = not args.disable_cuda and torch.cuda.device_count() > 1
    if multi_gpu:
        print('Using multiple GPUs: {}'.format(torch.cuda.device_count()))
        net = torch.nn.DataParallel(net)

    loss = losses.factory_from_args(args)
    # Head strides are taken from the unwrapped network.
    target_transforms = encoder.factory(args, net_cpu.head_strides)

    # --- preprocessing pipeline ---
    if not args.augmentation:
        steps = [
            transforms.NormalizeAnnotations(),
            transforms.RescaleAbsolute(args.square_edge),
            transforms.CenterPad(args.square_edge),
            transforms.EVAL_TRANSFORM,
        ]
    else:
        steps = [
            transforms.NormalizeAnnotations(),
            transforms.AnnotationJitter(),
            transforms.RandomApply(transforms.HFlip(), 0.5),
            transforms.RescaleRelative(
                scale_range=(0.4 * args.rescale_images,
                             2.0 * args.rescale_images),
                power_law=True),
            transforms.Crop(args.square_edge),
            transforms.CenterPad(args.square_edge),
        ]
        if args.orientation_invariant:
            steps.append(transforms.RotateBy90())
        steps.append(transforms.TRAIN_TRANSFORM)
    preprocess = transforms.Compose(steps)

    train_loader, val_loader, pre_train_loader = datasets.train_factory(
        args, preprocess, target_transforms)

    # The main optimizer covers both network and loss parameters.
    trainable = list(net.parameters()) + list(loss.parameters())
    optimizer = optimize.factory_optimizer(args, trainable)
    lr_scheduler = optimize.factory_lrscheduler(
        args, optimizer, len(train_loader))

    if args.debug_pif_indices or args.debug_paf_indices:
        encoder_visualizer = encoder.Visualizer(
            args.headnets,
            net_cpu.head_strides,
            pif_indices=args.debug_pif_indices,
            paf_indices=args.debug_paf_indices)
    else:
        encoder_visualizer = None

    if args.freeze_base:
        # Parameter names in DataParallel models start with 'module.', so
        # both spellings of the base-net prefix are matched.
        base_prefixes = ('module.base_net.', 'base_net.')

        # Only parameters switched off here are remembered, so parameters
        # that were already frozen stay frozen afterwards.
        frozen_params = set()
        for name, param in net.named_parameters():
            if not name.startswith(base_prefixes):
                print('not freezing', name)
                continue
            print('freezing', name)
            if param.requires_grad is not False:
                param.requires_grad = False
                frozen_params.add(param)
        print('froze {} parameters'.format(len(frozen_params)))

        # Pre-training on the parameters that still require gradients.
        pre_optimizer = torch.optim.SGD(
            (param for param in net.parameters() if param.requires_grad),
            lr=args.pre_lr,
            momentum=0.9,
            weight_decay=0.0,
            nesterov=True)
        pre_trainer = Trainer(
            net,
            loss,
            pre_optimizer,
            args.output,
            device=args.device,
            fix_batch_norm=True,
            encoder_visualizer=encoder_visualizer)
        # Counts up from -freeze_base to -1.
        for index in range(-args.freeze_base, 0):
            pre_trainer.train(pre_train_loader, index)

        # Re-enable gradients for everything frozen above.
        for param in frozen_params:
            param.requires_grad = True

    model_meta_data = {
        'args': vars(args),
        'version': VERSION,
        'hostname': socket.gethostname(),
    }
    trainer = Trainer(
        net,
        loss,
        optimizer,
        args.output,
        lr_scheduler=lr_scheduler,
        device=args.device,
        fix_batch_norm=not args.update_batchnorm_runningstatistics,
        stride_apply=args.stride_apply,
        ema_decay=args.ema,
        encoder_visualizer=encoder_visualizer,
        train_profile=args.profile,
        model_meta_data=model_meta_data,
    )
    trainer.loop(
        train_loader,
        val_loader,
        args.epochs,
        start_epoch=start_epoch)