class GoturnTrain(LightningModule):
    """PyTorch Lightning module for training the GOTURN tracker."""

    def __init__(self, hparams, dbg=False):
        '''
        Pytorch lightning module for training the goturn tracker.
        @hparams: all the argparse arguments for training
        @dbg: boolean for switching on the visualizer
        '''
        logger.info('=' * 15)
        logger.info('GOTURN TRACKER')
        logger.info('=' * 15)

        super(GoturnTrain, self).__init__()
        self.__set_seed(hparams.seed)
        self.hparams = hparams

        logger.info('Setting up the network...')
        # network with pretrained model
        self._model = GoturnNetwork(self.hparams.pretrained_model)
        self._dbg = dbg
        if dbg:
            self._viz = Visualizer(port=8097)

    def __freeze(self):
        """Freeze the model features layer"""
        features_layer = self._model._net
        for param in features_layer.parameters():
            param.requires_grad = False

    def _set_conv_layer(self, conv_layers, param_dict):
        # conv layers are kept frozen: lr = 0, but weight decay still applies to weights
        for layer in conv_layers.modules():
            if type(layer) == torch.nn.modules.conv.Conv2d:
                param_dict.append({'params': layer.weight,
                                   'lr': 0,
                                   'weight_decay': self.hparams.wd})
                param_dict.append({'params': layer.bias,
                                   'lr': 0,
                                   'weight_decay': 0})
        return param_dict

    def __set_lr(self):
        '''set per-layer learning rates: frozen conv layers, boosted regression layers'''
        param_dict = []

        conv_layers = self._model._net_1
        param_dict = self._set_conv_layer(conv_layers, param_dict)

        conv_layers = self._model._net_2
        param_dict = self._set_conv_layer(conv_layers, param_dict)

        regression_layer = self._model._classifier
        for layer in regression_layer.modules():
            if type(layer) == torch.nn.modules.linear.Linear:
                param_dict.append({'params': layer.weight,
                                   'lr': 10 * self.hparams.lr,
                                   'weight_decay': self.hparams.wd})
                param_dict.append({'params': layer.bias,
                                   'lr': 20 * self.hparams.lr,
                                   'weight_decay': 0})
        return param_dict

    def find_lr(self):
        """find a suitable learning rate with an LR range test"""
        model = self._model
        params = self.__set_lr()
        criterion = torch.nn.L1Loss(size_average=False)
        optimizer = CaffeSGD(params,
                             lr=1e-8,
                             momentum=self.hparams.momentum,
                             weight_decay=self.hparams.wd)

        lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
        trainloader = self.train_dataloader()
        lr_finder.range_test(trainloader, start_lr=1e-7, end_lr=1, num_iter=500)
        lr_finder.plot()

    def __set_seed(self, SEED):
        '''set all the seeds for reproducibility'''
        logger.info('Setting seed = {}'.format(SEED))
        torch.manual_seed(SEED)
        np.random.seed(SEED)
        random.seed(SEED)
        cudnn.deterministic = True

    @staticmethod
    def add_model_specific_args(parent_parser):
        '''These are specific parameters for the sample generator'''
        ap = argparse.ArgumentParser(parents=[parent_parser])
        ap.add_argument('--min_scale', type=float, default=-0.4,
                        help='min scale')
        ap.add_argument('--max_scale', type=float, default=0.4,
                        help='max scale')
        ap.add_argument('--lamda_shift', type=float, default=5)
        ap.add_argument('--lamda_scale', type=int, default=15)
        return ap

    def configure_optimizers(self):
        """Configure optimizer and LR scheduler"""
        logger.info('Configuring optimizer: SGD with lr = {}, momentum = {}'.format(
            self.hparams.lr, self.hparams.momentum))
        params = self.__set_lr()
        optimizer = CaffeSGD(params,
                             lr=self.hparams.lr,
                             momentum=self.hparams.momentum,
                             weight_decay=self.hparams.wd)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=self.hparams.lr_step, gamma=self.hparams.gamma)
        return [optimizer], [scheduler]

    @pl.data_loader
    def train_dataloader(self):
        """train dataloader"""
        logger.info('===' * 20)
        logger.info('Loading dataset for training, please wait...')
        logger.info('===' * 20)

        imagenet_path = self.hparams.imagenet_path
        alov_path = self.hparams.alov_path
        mean_file = None

        manager = Manager()
        objGoturn = GoturnDataloader(imagenet_path, alov_path,
                                     mean_file=mean_file,
                                     images_p=manager.list(),
                                     targets_p=manager.list(),
                                     bboxes_p=manager.list(),
                                     val_ratio=0.005,
                                     isTrain=True, dbg=False)
        train_loader = DataLoader(objGoturn,
                                  batch_size=self.hparams.batch_size,
                                  shuffle=True, num_workers=6,
                                  collate_fn=objGoturn.collate)
        return train_loader

    @pl.data_loader
    def val_dataloader(self):
        """validation dataloader"""
        logger.info('===' * 20)
        logger.info('Loading dataset for validation, please wait...')
        logger.info('===' * 20)

        imagenet_path = self.hparams.imagenet_path
        alov_path = self.hparams.alov_path
        mean_file = None

        manager = Manager()
        objGoturn = GoturnDataloader(imagenet_path, alov_path,
                                     mean_file=mean_file,
                                     images_p=manager.list(),
                                     targets_p=manager.list(),
                                     bboxes_p=manager.list(),
                                     val_ratio=0.005,
                                     isTrain=False, dbg=False)
        val_loader = DataLoader(objGoturn,
                                batch_size=self.hparams.batch_size,
                                shuffle=True, num_workers=6,
                                collate_fn=objGoturn.collate)
        return val_loader

    def forward(self, prev, curr):
        """forward pass: predict the bounding box of the target in the current crop"""
        pred_bb = self._model(prev.float(), curr.float())
        return pred_bb

    def vis_images(self, prev, curr, gt_bb, pred_bb, prefix='train'):
        """Plot previous/current crops with ground-truth and predicted boxes on visdom"""

        def unnormalize(image, mean):
            image = np.transpose(image, (1, 2, 0)) + mean
            image = image.astype(np.float32)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            return image

        for i in range(0, prev.shape[0]):
            # _mean = np.load(self.hparams.mean_file)
            _mean = np.array([104, 117, 123])
            prev_img = prev[i].cpu().detach().numpy()
            curr_img = curr[i].cpu().detach().numpy()

            prev_img = unnormalize(prev_img, _mean)
            curr_img = unnormalize(curr_img, _mean)

            gt_bb_i = BoundingBox(*gt_bb[i].cpu().detach().numpy().tolist())
            gt_bb_i.unscale(curr_img)
            curr_img = draw.bbox(curr_img, gt_bb_i, color=(255, 0, 255))

            pred_bb_i = BoundingBox(*pred_bb[i].cpu().detach().numpy().tolist())
            pred_bb_i.unscale(curr_img)
            curr_img = draw.bbox(curr_img, pred_bb_i)

            out = np.concatenate((prev_img[np.newaxis, ...],
                                  curr_img[np.newaxis, ...]), axis=0)
            out = np.transpose(out, [0, 3, 1, 2])

            self._viz.plot_images_np(out, title='sample_{}'.format(i),
                                     env='goturn_{}'.format(prefix))

    def training_step(self, batch, batch_idx):
        """Training step
        @batch: current batch data
        @batch_idx: current batch index
        """
        curr, prev, gt_bb = batch
        pred_bb = self.forward(prev, curr)
        loss = torch.nn.L1Loss(size_average=False)(pred_bb.float(), gt_bb.float())

        if self.trainer.use_dp:
            loss = loss.unsqueeze(0)

        if self._dbg:
            if batch_idx % 1000 == 0:
                d = {'loss': loss.item()}
                iters = (self.trainer.num_training_batches - 1) * self.current_epoch + batch_idx
                self._viz.plot_curves(d, iters, title='Train', ylabel='train_loss')
            if batch_idx % 1000 == 0:
                self.vis_images(prev, curr, gt_bb, pred_bb)

        tqdm_dict = {'batch_loss': loss}
        output = OrderedDict({'loss': loss,
                              'progress_bar': tqdm_dict,
                              'log': tqdm_dict})
        return output

    def validation_step(self, batch, batch_idx):
        """Validation step
        @batch: current batch data
        @batch_idx: current batch index
        """
        curr, prev, gt_bb = batch
        pred_bb = self.forward(prev, curr)
        loss = torch.nn.L1Loss(size_average=False)(pred_bb, gt_bb.float())

        if self.trainer.use_dp:
            loss = loss.unsqueeze(0)

        if self._dbg:
            if batch_idx % 100 == 0:
                d = {'loss': loss.item()}
                iters = (self.trainer.num_val_batches - 1) * self.current_epoch + batch_idx
                self._viz.plot_curves(d, iters, title='Validation', ylabel='val_loss')
            if batch_idx % 1000 == 0:
                self.vis_images(prev, curr, gt_bb, pred_bb, prefix='val')

        tqdm_dict = {'val_loss': loss}
        output = OrderedDict({'val_loss': loss,
                              'progress_bar': tqdm_dict,
                              'log': tqdm_dict})
        return output

    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}
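
# --- Usage sketch (illustrative, not part of the original module) ------------
# A minimal sketch of how GoturnTrain might be wired into a pytorch-lightning
# Trainer, assuming the pre-0.8 Lightning API implied by @pl.data_loader above.
# The flag names and default values below are placeholders covering the
# hyperparameters this class actually reads (lr, momentum, wd, lr_step, gamma,
# batch_size, seed, pretrained_model, imagenet_path, alov_path); they are not
# the repository's exact CLI.

def _example_train():
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--lr', type=float, default=1e-6)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--wd', type=float, default=5e-4)
    parser.add_argument('--lr_step', type=int, default=1)
    parser.add_argument('--gamma', type=float, default=0.1)
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--seed', type=int, default=800)
    parser.add_argument('--pretrained_model', type=str, required=True)
    parser.add_argument('--imagenet_path', type=str, required=True)
    parser.add_argument('--alov_path', type=str, required=True)

    # pull in the sample-generator specific flags defined above
    parser = GoturnTrain.add_model_specific_args(parser)
    hparams = parser.parse_args()

    model = GoturnTrain(hparams, dbg=False)
    trainer = pl.Trainer(gpus=1)
    trainer.fit(model)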
class GoturnTracker:
    """Runs a trained GOTURN model frame-by-frame over a folder of video frames."""

    def __init__(self, args, dbg=False):
        """load the trained model and the video frames"""
        loader = loadfromfolder(args.input)
        self._vid_frames = loader.get_video_frames()

        model_dir = Path(args.model_dir)
        # checkpoint path
        ckpt_dir = model_dir.joinpath('checkpoints')
        ckpt_path = next(ckpt_dir.glob('*.ckpt'))

        model = GoturnTrain.load_from_checkpoint(ckpt_path)
        model.eval()
        model.freeze()

        self._model = model
        if dbg:
            self._viz = Visualizer()
        self._dbg = dbg

    def vis_images(self, prev, curr, gt_bb, pred_bb, prefix='train'):

        def unnormalize(image, mean, std):
            image = np.transpose(image, (1, 2, 0)) * std + mean
            image = image.astype(np.float32)
            return image

        for i in range(0, prev.shape[0]):
            _mean = np.array([104, 117, 123])
            _std = np.ones_like(_mean)

            prev_img = prev[i].cpu().detach().numpy()
            curr_img = curr[i].cpu().detach().numpy()

            prev_img = unnormalize(prev_img, _mean, _std)
            curr_img = unnormalize(curr_img, _mean, _std)

            gt_bb_i = BoundingBox(*gt_bb[i].cpu().detach().numpy().tolist())
            gt_bb_i.unscale(curr_img)
            curr_img = draw.bbox(curr_img, gt_bb_i, color=(255, 255, 255))

            pred_bb_i = BoundingBox(*pred_bb[i].cpu().detach().numpy().tolist())
            pred_bb_i.unscale(curr_img)
            curr_img = draw.bbox(curr_img, pred_bb_i)

            out = np.concatenate((prev_img[np.newaxis, ...],
                                  curr_img[np.newaxis, ...]), axis=0)
            out = np.transpose(out, [0, 3, 1, 2])

            self._viz.plot_images_np(out, title='sample_{}'.format(i),
                                     env='goturn_{}'.format(prefix))

    def _track(self, curr_frame, prev_frame, rect):
        """track the target into the current frame
        @curr_frame: current frame
        @prev_frame: previous frame
        @rect: bounding box of the previous frame
        """
        prev_bbox = rect

        target_pad, _, _, _ = cropPadImage(prev_bbox, prev_frame)
        cur_search_region, search_location, edge_spacing_x, edge_spacing_y = \
            cropPadImage(prev_bbox, curr_frame)

        if self._dbg:
            self._viz.plot_image_opencv(target_pad, 'target')
            self._viz.plot_image_opencv(cur_search_region, 'current')

        target_pad_in = self.preprocess(target_pad, mean=None).unsqueeze(0)
        cur_search_region_in = self.preprocess(cur_search_region,
                                               mean=None).unsqueeze(0)

        pred_bb = self._model.forward(target_pad_in, cur_search_region_in)

        if self._dbg:
            prev_bbox.scale(prev_frame)
            x1, y1, x2, y2 = prev_bbox.x1, prev_bbox.y1, prev_bbox.x2, prev_bbox.y2
            prev_bbox = torch.tensor([x1, y1, x2, y2]).unsqueeze(0)
            target_dbg = target_pad_in.clone()
            cur_search_region_dbg = cur_search_region_in.clone()
            self.vis_images(target_dbg, cur_search_region_dbg, prev_bbox, pred_bb)

        # map the prediction back from search-region coordinates to frame coordinates
        pred_bb = BoundingBox(*pred_bb[0].cpu().detach().numpy().tolist())
        pred_bb.unscale(cur_search_region)
        pred_bb.uncenter(curr_frame, search_location, edge_spacing_x, edge_spacing_y)
        x1, y1, x2, y2 = int(pred_bb.x1), int(pred_bb.y1), int(pred_bb.x2), int(pred_bb.y2)
        pred_bb = BoundingBox(x1, y1, x2, y2)
        return pred_bb

    def preprocess(self, im, mean=None):
        """preprocess image before the forward pass; this is the same
        preprocessing used during training, please refer to the collate
        function in train.py for reference
        @im: input image
        """
        if mean is None:
            mean = np.array([104, 117, 123])
        im = resize(im, (227, 227)) - mean
        im = image_io.image_to_tensor(im)
        return im

    def track(self):
        """Track the target through the loaded video frames"""
        vid_frames = self._vid_frames[0]
        num_frames = len(vid_frames)
        f_path = vid_frames[0]
        frame_0 = image_io.load(f_path)
        prev = np.asarray(frame_0)

        global image
        image = prev
        # wait for the user to draw the initial box with the mouse; 'refPt'
        # holds the two corners recorded by the OpenCV mouse callback
        while True:
            # prev_out = cv2.cvtColor(prev, cv2.COLOR_RGB2BGR)
            prev_out = np.copy(prev)
            cv2.imshow('image', prev_out)
            key = cv2.waitKey(1) & 0xFF
            if key == ord('s'):
                (x1, y1), (x2, y2) = refPt[0], refPt[1]
                bbox_0 = BoundingBox(x1, y1, x2, y2)
                break
            elif key == ord('r'):
                (x1, y1), (x2, y2) = refPt[0], refPt[1]
                bbox_0 = BoundingBox(x1, y1, x2, y2)
                break

        for i in range(1, num_frames):
            f_path = vid_frames[i]
            frame_1 = image_io.load(f_path)
            curr = np.asarray(frame_1)
            bbox_0 = self._track(curr, prev, bbox_0)
            bbox = bbox_0
            prev = curr

            # press 'p' to pause and re-initialize the box with the mouse
            if cv2.waitKey(1) & 0xFF == ord('p'):
                while True:
                    image = curr
                    cv2.imshow("image", curr)
                    key = cv2.waitKey(0) & 0xFF
                    if key == ord("s"):
                        (x1, y1), (x2, y2) = refPt[0], refPt[1]
                        bbox_0 = BoundingBox(x1, y1, x2, y2)
                        break

            curr_dbg = np.copy(curr)
            curr_dbg = cv2.rectangle(curr_dbg,
                                     (int(bbox.x1), int(bbox.y1)),
                                     (int(bbox.x2), int(bbox.y2)),
                                     (255, 255, 0), 2)

            # curr_dbg = cv2.cvtColor(curr_dbg, cv2.COLOR_RGB2BGR)
            cv2.imshow('image', curr_dbg)
            # cv2.imwrite('./output/{:04d}.png'.format(i), curr_dbg)
            cv2.waitKey(20)
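
# --- Usage sketch (illustrative, not part of the original module) ------------
# track() above reads the module-level 'refPt' and 'image' that an OpenCV
# mouse callback is expected to maintain. Below is a minimal callback of that
# shape plus a possible entry point; the callback name, the --input/--model_dir
# flags and the initialization of refPt/image are assumptions based on the
# code above, not the repository's exact implementation.

refPt = []      # two corners of the user-drawn box, read by track()
image = None    # frame currently shown in the 'image' window

def _select_bbox(event, x, y, flags, param):
    """Record the two corners of the box dragged with the left mouse button."""
    global refPt
    if event == cv2.EVENT_LBUTTONDOWN:
        refPt = [(x, y)]
    elif event == cv2.EVENT_LBUTTONUP:
        refPt.append((x, y))
        # draw the selection on the currently displayed frame
        cv2.rectangle(image, refPt[0], refPt[1], (0, 255, 0), 2)
        cv2.imshow('image', image)

def _example_track():
    ap = argparse.ArgumentParser()
    ap.add_argument('--input', required=True, help='folder of video frames')
    ap.add_argument('--model_dir', required=True,
                    help='folder containing a checkpoints/ directory with a .ckpt file')
    args = ap.parse_args()

    cv2.namedWindow('image')
    cv2.setMouseCallback('image', _select_bbox)

    tracker = GoturnTracker(args, dbg=False)
    tracker.track()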
class sample_generator:
    """Generate samples from a single frame"""

    def __init__(self, lamda_shift, lamda_scale, min_scale, max_scale,
                 dbg=False, env='sample_generator'):
        """set parameters"""
        self._lamda_shift = lamda_shift
        self._lamda_scale = lamda_scale
        self._min_scale = min_scale
        self._max_scale = max_scale
        self._kSamplesPerImage = 10  # number of synthetic samples per image

        self._viz = None
        if dbg:
            self._env = env
            self._viz = Visualizer(env=self._env)
        self._dbg = dbg

    def make_true_sample(self):
        """Generate the true target:search_region pair"""
        curr_prior_tight = self.bbox_prev_gt_
        target_pad = self.target_pad_

        # To find out the region in which we need to search in the
        # current frame, we use the previous frame bbox to get the
        # region in which we can make the search
        output = cropPadImage(curr_prior_tight, self.img_curr_,
                              self._dbg, self._viz)
        curr_search_region, curr_search_location, edge_spacing_x, edge_spacing_y = output

        bbox_curr_gt = self.bbox_curr_gt_
        bbox_curr_gt_recentered = BoundingBox(0, 0, 0, 0)
        bbox_curr_gt_recentered = bbox_curr_gt.recenter(
            curr_search_location, edge_spacing_x, edge_spacing_y,
            bbox_curr_gt_recentered)

        if self._dbg:
            env = self._env + '_make_true_sample'
            search_dbg = draw.bbox(self.img_curr_, curr_search_location)
            search_dbg = draw.bbox(search_dbg, bbox_curr_gt, color=(255, 255, 0))
            self._viz.plot_image_opencv(search_dbg, 'search_region', env=env)

            recentered_img = draw.bbox(curr_search_region,
                                       bbox_curr_gt_recentered,
                                       color=(255, 255, 0))
            self._viz.plot_image_opencv(recentered_img,
                                        'cropped_search_region', env=env)
            del recentered_img
            del search_dbg

        bbox_curr_gt_recentered.scale(curr_search_region)

        return curr_search_region, target_pad, bbox_curr_gt_recentered

    def make_training_samples(self, num_samples, images, targets,
                              bbox_gt_scales):
        """
        @num_samples: number of samples
        @images: set of num_samples search regions appended to images list
        @targets: set of num_samples targets appended to targets list
        @bbox_gt_scales: bounding boxes to be regressed (scaled version)
        """
        for i in range(num_samples):
            image_rand_focus, target_pad, bbox_gt_scaled = self.make_training_sample_BBShift()
            images.append(image_rand_focus)
            targets.append(target_pad)
            bbox_gt_scales.append(bbox_gt_scaled)
            if self._dbg:
                self.visualize(image_rand_focus, target_pad, bbox_gt_scaled, i)

        return images, targets, bbox_gt_scales

    def visualize(self, image, target, bbox, idx):
        """
        The sample generator prepares the image and the respective targets
        (with bounding box); this function helps you visualize them.

        The visualization is based on the Visdom server; start it with
        $ python -m visdom.server
        and open http://localhost:8097 in your browser to see the images.
        """
        if image_io._is_pil_image(image):
            image = np.asarray(image)

        if image_io._is_pil_image(target):
            target = np.asarray(target)

        target = cv2.resize(target, (227, 227))
        target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)

        image = cv2.resize(image, (227, 227))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        bbox.unscale(image)
        bbox.x1, bbox.x2, bbox.y1, bbox.y2 = int(bbox.x1), int(bbox.x2), int(bbox.y1), int(bbox.y2)
        image_bb = draw.bbox(image, bbox)

        out = np.concatenate((target[np.newaxis, ...],
                              image_bb[np.newaxis, ...]), axis=0)
        out = np.transpose(out, [0, 3, 1, 2])
        self._viz.plot_images_np(out, title='sample_{}'.format(idx),
                                 env=self._env + '_train')

    def get_default_bb_params(self):
        """default bb parameters"""
        default_params = bbParams(self._lamda_shift, self._lamda_scale,
                                  self._min_scale, self._max_scale)
        return default_params

    def make_training_sample_BBShift_(self, bbParams, dbg=False):
        """generate a training sample based on bbParams"""
        bbox_curr_gt = self.bbox_curr_gt_
        bbox_curr_shift = BoundingBox(0, 0, 0, 0)
        bbox_curr_shift = bbox_curr_gt.shift(
            self.img_curr_, bbParams.lamda_scale, bbParams.lamda_shift,
            bbParams.min_scale, bbParams.max_scale, True, bbox_curr_shift)
        rand_search_region, rand_search_location, edge_spacing_x, edge_spacing_y = cropPadImage(
            bbox_curr_shift, self.img_curr_, dbg=self._dbg, viz=self._viz)

        bbox_curr_gt = self.bbox_curr_gt_
        bbox_gt_recentered = BoundingBox(0, 0, 0, 0)
        bbox_gt_recentered = bbox_curr_gt.recenter(rand_search_location,
                                                   edge_spacing_x,
                                                   edge_spacing_y,
                                                   bbox_gt_recentered)

        if dbg:
            env = self._env + '_make_training_sample_bbshift'
            viz = self._viz
            curr_img_bbox = draw.bbox(self.img_curr_, bbox_curr_gt)
            recentered_img = draw.bbox(rand_search_region, bbox_gt_recentered)

            viz.plot_image_opencv(curr_img_bbox, 'curr shifted bbox', env=env)
            viz.plot_image_opencv(recentered_img, 'recentered shifted bbox',
                                  env=env)

        bbox_gt_recentered.scale(rand_search_region)
        bbox_gt_scaled = bbox_gt_recentered

        return rand_search_region, self.target_pad_, bbox_gt_scaled

    def make_training_sample_BBShift(self):
        """
        bb_params consists of shift, scale, and min/max scale for shifting
        the current bounding box
        """
        default_bb_params = self.get_default_bb_params()
        image_rand_focus, target_pad, bbox_gt_scaled = self.make_training_sample_BBShift_(
            default_bb_params, self._dbg)

        return image_rand_focus, target_pad, bbox_gt_scaled

    def reset(self, bbox_curr, bbox_prev, img_curr, img_prev):
        """This prepares the target image with enough context (search region)
        @bbox_curr: current frame bounding box
        @bbox_prev: previous frame bounding box
        @img_curr: current frame
        @img_prev: previous frame
        """
        target_pad, pad_image_location, _, _ = cropPadImage(bbox_prev,
                                                            img_prev,
                                                            dbg=self._dbg,
                                                            viz=self._viz)
        self.img_curr_ = img_curr
        self.bbox_curr_gt_ = bbox_curr
        self.bbox_prev_gt_ = bbox_prev
        self.target_pad_ = target_pad  # crop of kContextFactor * bbox_prev from the previous frame

        if self._dbg:
            env = self._env + '_targetpad'
            search_dbg = draw.bbox(img_prev, bbox_prev, color=(0, 0, 255))
            search_dbg = draw.bbox(search_dbg, pad_image_location)
            self._viz.plot_image_opencv(search_dbg, 'target_region', env=env)
            self._viz.plot_image_opencv(target_pad, 'cropped_target_region',
                                        env=env)
            del search_dbg
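
# --- Usage sketch (illustrative, not part of the original module) ------------
# How the sample generator is typically driven for one annotated frame pair:
# reset() with the previous/current frames and their ground-truth boxes, then
# take the one true pair plus _kSamplesPerImage shifted/scaled variants.
# The helper name is hypothetical; the lamda/scale values mirror the defaults
# in add_model_specific_args above, and only the sample_generator calls
# correspond to methods defined in this class.

def _example_samples(img_prev, img_curr, bbox_prev, bbox_curr):
    """img_*: frames as numpy/PIL images; bbox_*: ground-truth BoundingBox objects."""
    gen = sample_generator(lamda_shift=5, lamda_scale=15,
                           min_scale=-0.4, max_scale=0.4, dbg=False)
    gen.reset(bbox_curr, bbox_prev, img_curr, img_prev)

    images, targets, bbox_gt_scaled = [], [], []

    # the "true" pair: search region from the current frame, target from the previous
    search_region, target_pad, bbox_gt = gen.make_true_sample()
    images.append(search_region)
    targets.append(target_pad)
    bbox_gt_scaled.append(bbox_gt)

    # synthetic pairs with randomly shifted/scaled search regions
    gen.make_training_samples(gen._kSamplesPerImage, images, targets, bbox_gt_scaled)
    return images, targets, bbox_gt_scaled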