def common_eval_preprocess(cls):
    # Rescaling: either a deterministic per-image choice between full and
    # roughly half resolution (extended scale) or a single fixed long-edge rescale.
    rescale_t = None
    if cls.eval_extended_scale:
        assert cls.eval_long_edge
        rescale_t = [
            transforms.DeterministicEqualChoice([
                transforms.RescaleAbsolute(cls.eval_long_edge),
                transforms.RescaleAbsolute((cls.eval_long_edge - 1) // 2 + 1),
            ], salt=1)
        ]
    elif cls.eval_long_edge:
        rescale_t = transforms.RescaleAbsolute(cls.eval_long_edge)

    # Padding: tight padding only works for batch size 1; batched evaluation
    # needs a common padded size.
    if cls.batch_size == 1:
        padding_t = transforms.CenterPadTight(16)
    else:
        assert cls.eval_long_edge
        padding_t = transforms.CenterPad(cls.eval_long_edge)

    # Orientation: deterministically assign no rotation or one of three
    # 90-degree rotations per image.
    orientation_t = None
    if cls.eval_orientation_invariant:
        orientation_t = transforms.DeterministicEqualChoice([
            None,
            transforms.RotateBy90(fixed_angle=90),
            transforms.RotateBy90(fixed_angle=180),
            transforms.RotateBy90(fixed_angle=270),
        ], salt=3)

    return [
        transforms.NormalizeAnnotations(),
        rescale_t,
        padding_t,
        orientation_t,
    ]

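# Usage sketch (an assumption, not part of the original source): the list
# returned above may still contain None placeholders and a nested list, so a
# caller would typically flatten it and append the same transforms.EVAL_TRANSFORM
# used in the other pipelines below. EvalDataModule is a hypothetical class
# that provides the cls.* attributes referenced above.
def eval_preprocess_example():
    steps = []
    for t in EvalDataModule.common_eval_preprocess():
        if t is None:
            continue
        steps.extend(t if isinstance(t, list) else [t])
    steps.append(transforms.EVAL_TRANSFORM)
    return transforms.Compose(steps)
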
# The parametrize decorator is added so pytest can supply x; the exact value
# range is an assumption, not part of the original excerpt.
@pytest.mark.parametrize('x', range(10))
def test_rotateby90(x, y=6):
    transform = transforms.Compose([
        transforms.SquarePad(),
        transforms.RotateBy90(),
    ])
    image_xy, keypoint_xy = single_pixel_transform(x, y, transform)
    print(image_xy, keypoint_xy)
    # After the transform, the marked image pixel and the keypoint must coincide.
    assert image_xy == pytest.approx(keypoint_xy)

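# The invariant the test checks, written out with plain numpy as a hedged,
# library-independent sketch (rotate90_point and the helper below are
# illustrative, not openpifpaf API): rotating an image by 90 degrees and
# applying the matching coordinate map to a point must land on the same pixel.
import numpy as np


def rotate90_point(x, y, width):
    # np.rot90 rotates counter-clockwise: pixel (x, y) of a width-wide image
    # moves to (y, width - 1 - x).
    return y, width - 1 - x


def check_rotate90_consistency(x=3, y=6, width=11, height=11):
    image = np.zeros((height, width), dtype=np.uint8)
    image[y, x] = 255
    rotated = np.rot90(image)
    new_x, new_y = rotate90_point(x, y, width)
    assert rotated[new_y, new_x] == 255
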
def _preprocess(self):
    encoders = (encoder.Cif(self.head_metas[0], bmin=self.b_min),
                encoder.Caf(self.head_metas[1], bmin=self.b_min))

    # Without augmentation: deterministic rescale and pad, evaluation-style.
    if not self.augmentation:
        return transforms.Compose([
            transforms.NormalizeAnnotations(),
            transforms.RescaleAbsolute(self.square_edge),
            transforms.CenterPad(self.square_edge),
            transforms.EVAL_TRANSFORM,
            transforms.Encoders(encoders),
        ])

    # Training augmentation: the extended-scale option widens the upper end
    # of the random rescale range.
    if self.extended_scale:
        rescale_t = transforms.RescaleRelative(
            scale_range=(0.2 * self.rescale_images,
                         2.0 * self.rescale_images),
            power_law=True, stretch_range=(0.75, 1.33))
    else:
        rescale_t = transforms.RescaleRelative(
            scale_range=(0.2 * self.rescale_images,
                         1.5 * self.rescale_images),
            power_law=True, stretch_range=(0.75, 1.33))

    blur_t = None
    if self.blur:
        blur_t = transforms.RandomApply(transforms.Blur(), self.blur)

    orientation_t = None
    if self.orientation_invariant:
        orientation_t = transforms.RandomApply(
            transforms.RotateBy90(), self.orientation_invariant)

    return transforms.Compose([
        transforms.NormalizeAnnotations(),
        transforms.AnnotationJitter(),
        transforms.RandomApply(
            transforms.HFlip(self.CAR_KEYPOINTS, self.HFLIP), 0.5),
        rescale_t,
        blur_t,
        transforms.Crop(self.square_edge, use_area_of_interest=True),
        transforms.CenterPad(self.square_edge),
        orientation_t,
        transforms.TRAIN_TRANSFORM,
        transforms.Encoders(encoders),
    ])

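# Hedged illustration of the power_law idea in RescaleRelative (an assumption
# about its semantics, not code from the library): drawing the rescale factor
# uniformly in log2 space makes halving and doubling equally likely, instead
# of biasing toward the upper end of a linearly sampled range.
import numpy as np


def sample_power_law_scale(scale_range=(0.2, 1.5), rng=None):
    rng = rng if rng is not None else np.random.default_rng()
    log2_min, log2_max = np.log2(scale_range[0]), np.log2(scale_range[1])
    return float(2.0 ** rng.uniform(log2_min, log2_max))
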
def main():
    args = cli()
    logs.configure(args)
    net_cpu, start_epoch = nets.factory_from_args(args)

    net = net_cpu.to(device=args.device)
    if not args.disable_cuda and torch.cuda.device_count() > 1:
        print('Using multiple GPUs: {}'.format(torch.cuda.device_count()))
        net = torch.nn.DataParallel(net)

    loss = losses.factory_from_args(args)
    target_transforms = encoder.factory(args, net_cpu.head_strides)

    if args.augmentation:
        preprocess_transformations = [
            transforms.NormalizeAnnotations(),
            transforms.AnnotationJitter(),
            transforms.RandomApply(transforms.HFlip(), 0.5),
            transforms.RescaleRelative(
                scale_range=(0.4 * args.rescale_images,
                             2.0 * args.rescale_images),
                power_law=True),
            transforms.Crop(args.square_edge),
            transforms.CenterPad(args.square_edge),
        ]
        if args.orientation_invariant:
            preprocess_transformations += [
                transforms.RotateBy90(),
            ]
        preprocess_transformations += [
            transforms.TRAIN_TRANSFORM,
        ]
    else:
        preprocess_transformations = [
            transforms.NormalizeAnnotations(),
            transforms.RescaleAbsolute(args.square_edge),
            transforms.CenterPad(args.square_edge),
            transforms.EVAL_TRANSFORM,
        ]
    preprocess = transforms.Compose(preprocess_transformations)
    train_loader, val_loader, pre_train_loader = datasets.train_factory(
        args, preprocess, target_transforms)

    optimizer = optimize.factory_optimizer(
        args, list(net.parameters()) + list(loss.parameters()))
    lr_scheduler = optimize.factory_lrscheduler(
        args, optimizer, len(train_loader))
    encoder_visualizer = None
    if args.debug_pif_indices or args.debug_paf_indices:
        encoder_visualizer = encoder.Visualizer(
            args.headnets, net_cpu.head_strides,
            pif_indices=args.debug_pif_indices,
            paf_indices=args.debug_paf_indices)

    if args.freeze_base:
        # freeze base net parameters
        frozen_params = set()
        for n, p in net.named_parameters():
            # Freeze only base_net parameters.
            # Parameter names in DataParallel models start with 'module.'.
            if not n.startswith('module.base_net.') and \
               not n.startswith('base_net.'):
                print('not freezing', n)
                continue
            print('freezing', n)
            if p.requires_grad is False:
                continue
            p.requires_grad = False
            frozen_params.add(p)
        print('froze {} parameters'.format(len(frozen_params)))

        # training
        foptimizer = torch.optim.SGD(
            (p for p in net.parameters() if p.requires_grad),
            lr=args.pre_lr, momentum=0.9, weight_decay=0.0, nesterov=True)
        ftrainer = Trainer(net, loss, foptimizer, args.output,
                           device=args.device, fix_batch_norm=True,
                           encoder_visualizer=encoder_visualizer)
        for i in range(-args.freeze_base, 0):
            ftrainer.train(pre_train_loader, i)

        # unfreeze
        for p in frozen_params:
            p.requires_grad = True

    trainer = Trainer(
        net, loss, optimizer, args.output,
        lr_scheduler=lr_scheduler,
        device=args.device,
        fix_batch_norm=not args.update_batchnorm_runningstatistics,
        stride_apply=args.stride_apply,
        ema_decay=args.ema,
        encoder_visualizer=encoder_visualizer,
        train_profile=args.profile,
        model_meta_data={
            'args': vars(args),
            'version': VERSION,
            'hostname': socket.gethostname(),
        },
    )
    trainer.loop(train_loader, val_loader, args.epochs, start_epoch=start_epoch)

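# Conventional script entry point plus an example invocation; neither appears
# in the original excerpt. The flag names below are guessed from the args.*
# attributes used in main() and may not match the real CLI exactly:
#
#   python train.py --square-edge=401 --freeze-base=1 --pre-lr=1e-4
#
if __name__ == '__main__':
    main()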