def get_model(ie, args):
    if args.architecture_type == 'segmentation':
        return SegmentationModel(ie, args.model), SegmentationVisualizer(args.colors)
    if args.architecture_type == 'salient_object_detection':
        return SalientObjectDetectionModel(ie, args.model), SaliencyMapVisualizer()
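# A minimal usage sketch for get_model, assuming OpenVINO's IECore (the
# 2021.x inference_engine API) and an argparse namespace carrying the demo's
# CLI options; the model path below is an illustrative placeholder, and
# SegmentationModel/SegmentationVisualizer come from the surrounding module.
import argparse
from openvino.inference_engine import IECore

args = argparse.Namespace(architecture_type='segmentation',
                          model='model.xml',  # hypothetical IR path
                          colors=None)        # optional palette file
ie = IECore()
model, visualizer = get_model(ie, args)  # dispatches on args.architecture_type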
class ImageModel(pl.LightningModule):
    def __init__(self, hparams, teacher_path=''):
        super().__init__()

        # addition: convert dict to namespace when necessary
        # hack:
        if isinstance(hparams, dict):
            import argparse

            args = argparse.Namespace()

            for k, v in hparams.items():
                setattr(args, k, v)

            hparams = args

        self.hparams = hparams
        self.to_heatmap = ToHeatmap(hparams.heatmap_radius)

        if teacher_path:
            # modification: coerce to str for load_from_checkpoint
            self.teacher = MapModel.load_from_checkpoint(str(teacher_path))
            self.teacher.freeze()

        self.net = SegmentationModel(10, 4, hack=hparams.hack, temperature=hparams.temperature)
        self.converter = Converter()
        self.controller = RawController(4)

    def forward(self, img, target):
        target_cam = self.converter.map_to_cam(target)
        target_heatmap_cam = self.to_heatmap(target, img)[:, None]
        out = self.net(torch.cat((img, target_heatmap_cam), 1))

        return out, (target_cam, target_heatmap_cam)

    @torch.no_grad()
    def _get_labels(self, topdown, target):
        out, (target_heatmap, ) = self.teacher.forward(topdown, target, debug=True)
        control = self.teacher.controller(out)

        return out, control, (target_heatmap, )

    def training_step(self, batch, batch_nb):
        img, topdown, points, target, actions, meta = batch

        # Ground truth command.
        lbl_map, ctrl_map, (target_heatmap, ) = self._get_labels(topdown, target)
        lbl_cam = self.converter.map_to_cam((lbl_map + 1) / 2 * 256)
        lbl_cam[..., 0] = (lbl_cam[..., 0] / 256) * 2 - 1
        lbl_cam[..., 1] = (lbl_cam[..., 1] / 144) * 2 - 1

        out, (target_cam, target_heatmap_cam) = self.forward(img, target)

        alpha = torch.rand(out.shape[0], out.shape[1], 1).type_as(out)
        between = alpha * out + (1 - alpha) * lbl_cam
        out_ctrl = self.controller(between)

        point_loss = torch.nn.functional.l1_loss(out, lbl_cam, reduction='none').mean((1, 2))
        ctrl_loss_raw = torch.nn.functional.l1_loss(out_ctrl, ctrl_map, reduction='none')
        ctrl_loss = ctrl_loss_raw.mean(1)
        steer_loss = ctrl_loss_raw[:, 0]
        speed_loss = ctrl_loss_raw[:, 1]

        loss_gt = point_loss + self.hparams.command_coefficient * ctrl_loss
        loss_gt_mean = loss_gt.mean()

        # Random command.
        indices = np.random.choice(RANDOM_POINTS.shape[0], topdown.shape[0])
        target_aug = torch.from_numpy(RANDOM_POINTS[indices]).type_as(img)

        lbl_map_aug, ctrl_map_aug, (target_heatmap_aug, ) = self._get_labels(topdown, target_aug)
        lbl_cam_aug = self.converter.map_to_cam((lbl_map_aug + 1) / 2 * 256)
        lbl_cam_aug[..., 0] = (lbl_cam_aug[..., 0] / 256) * 2 - 1
        lbl_cam_aug[..., 1] = (lbl_cam_aug[..., 1] / 144) * 2 - 1

        out_aug, (target_cam_aug, target_heatmap_cam_aug) = self.forward(img, target_aug)

        alpha = torch.rand(out.shape[0], out.shape[1], 1).type_as(out)
        between_aug = alpha * out_aug + (1 - alpha) * lbl_cam_aug
        out_ctrl_aug = self.controller(between_aug)

        point_loss_aug = torch.nn.functional.l1_loss(out_aug, lbl_cam_aug, reduction='none').mean((1, 2))
        ctrl_loss_aug_raw = torch.nn.functional.l1_loss(out_ctrl_aug, ctrl_map_aug, reduction='none')
        ctrl_loss_aug = ctrl_loss_aug_raw.mean(1)
        steer_loss_aug = ctrl_loss_aug_raw[:, 0]
        speed_loss_aug = ctrl_loss_aug_raw[:, 1]

        loss_aug = point_loss_aug + self.hparams.command_coefficient * ctrl_loss_aug
        loss_aug_mean = loss_aug.mean()

        loss = loss_gt_mean + loss_aug_mean

        metrics = {
            'train_loss': loss.item(),

            'train_point': point_loss.mean().item(),
            'train_ctrl': ctrl_loss.mean().item(),
            'train_steer': steer_loss.mean().item(),
            'train_speed': speed_loss.mean().item(),

            'train_point_aug': point_loss_aug.mean().item(),
            'train_ctrl_aug': ctrl_loss_aug.mean().item(),
            'train_steer_aug': steer_loss_aug.mean().item(),
            'train_speed_aug': speed_loss_aug.mean().item(),
        }

        if batch_nb % 250 == 0:
            metrics['train_image'] = viz(batch, out, out_ctrl, target_cam, lbl_cam,
                                         lbl_map, ctrl_map, point_loss, ctrl_loss)
            metrics['train_image_aug'] = viz(batch, out_aug, out_ctrl_aug, target_cam_aug, lbl_cam_aug,
                                             lbl_map_aug, ctrl_map_aug, point_loss_aug, ctrl_loss_aug)

        self.logger.log_metrics(metrics, self.global_step)

        return {'loss': loss}

    def validation_step(self, batch, batch_nb):
        img, topdown, points, target, actions, meta = batch

        # Ground truth command.
        lbl_map, ctrl_map, (target_heatmap, ) = self._get_labels(topdown, target)
        lbl_cam = self.converter.map_to_cam((lbl_map + 1) / 2 * 256)
        lbl_cam[..., 0] = (lbl_cam[..., 0] / 256) * 2 - 1
        lbl_cam[..., 1] = (lbl_cam[..., 1] / 144) * 2 - 1

        out, (target_cam, target_heatmap_cam) = self.forward(img, target)
        out_ctrl = self.controller(out)
        out_ctrl_gt = self.controller(lbl_cam)

        point_loss = torch.nn.functional.l1_loss(out, lbl_cam, reduction='none').mean((1, 2))
        ctrl_loss_raw = torch.nn.functional.l1_loss(out_ctrl, ctrl_map, reduction='none')
        ctrl_loss = ctrl_loss_raw.mean(1)
        steer_loss = ctrl_loss_raw[:, 0]
        speed_loss = ctrl_loss_raw[:, 1]

        ctrl_loss_gt_raw = torch.nn.functional.l1_loss(out_ctrl_gt, ctrl_map, reduction='none')
        ctrl_loss_gt = ctrl_loss_gt_raw.mean(1)
        steer_loss_gt = ctrl_loss_gt_raw[:, 0]
        speed_loss_gt = ctrl_loss_gt_raw[:, 1]

        loss_gt = point_loss + self.hparams.command_coefficient * ctrl_loss
        loss_gt_mean = loss_gt.mean()

        # Random command.
        indices = np.random.choice(RANDOM_POINTS.shape[0], topdown.shape[0])
        target_aug = torch.from_numpy(RANDOM_POINTS[indices]).type_as(img)

        lbl_map_aug, ctrl_map_aug, (target_heatmap_aug, ) = self._get_labels(topdown, target_aug)
        lbl_cam_aug = self.converter.map_to_cam((lbl_map_aug + 1) / 2 * 256)
        lbl_cam_aug[..., 0] = (lbl_cam_aug[..., 0] / 256) * 2 - 1
        lbl_cam_aug[..., 1] = (lbl_cam_aug[..., 1] / 144) * 2 - 1

        out_aug, (target_cam_aug, target_heatmap_cam_aug) = self.forward(img, target_aug)
        out_ctrl_aug = self.controller(out_aug)
        out_ctrl_gt_aug = self.controller(lbl_cam_aug)

        point_loss_aug = torch.nn.functional.l1_loss(out_aug, lbl_cam_aug, reduction='none').mean((1, 2))
        ctrl_loss_aug_raw = torch.nn.functional.l1_loss(out_ctrl_aug, ctrl_map_aug, reduction='none')
        ctrl_loss_aug = ctrl_loss_aug_raw.mean(1)
        steer_loss_aug = ctrl_loss_aug_raw[:, 0]
        speed_loss_aug = ctrl_loss_aug_raw[:, 1]

        ctrl_loss_gt_aug_raw = torch.nn.functional.l1_loss(out_ctrl_gt_aug, ctrl_map_aug, reduction='none')
        ctrl_loss_gt_aug = ctrl_loss_gt_aug_raw.mean(1)
        steer_loss_gt_aug = ctrl_loss_gt_aug_raw[:, 0]
        speed_loss_gt_aug = ctrl_loss_gt_aug_raw[:, 1]

        loss_gt_aug = point_loss_aug + self.hparams.command_coefficient * ctrl_loss_aug
        loss_gt_aug_mean = loss_gt_aug.mean()

        if batch_nb == 0:
            self.logger.log_metrics({
                'val_image': viz(batch, out, out_ctrl, target_cam, lbl_cam,
                                 lbl_map, ctrl_map, point_loss, ctrl_loss),
                'val_image_aug': viz(batch, out_aug, out_ctrl_aug, target_cam_aug, lbl_cam_aug,
                                     lbl_map_aug, ctrl_map_aug, point_loss_aug, ctrl_loss_aug),
            }, self.global_step)

        return {
            'val_loss': (loss_gt_mean + loss_gt_aug_mean).item(),

            'val_point': point_loss.mean().item(),
            'val_ctrl': ctrl_loss.mean().item(),
            'val_steer': steer_loss.mean().item(),
            'val_speed': speed_loss.mean().item(),
            'val_ctrl_gt': ctrl_loss_gt.mean().item(),
            'val_steer_gt': steer_loss_gt.mean().item(),
            'val_speed_gt': speed_loss_gt.mean().item(),

            'val_point_aug': point_loss_aug.mean().item(),
            'val_ctrl_aug': ctrl_loss_aug.mean().item(),
            'val_steer_aug': steer_loss_aug.mean().item(),
            'val_speed_aug': speed_loss_aug.mean().item(),
            'val_ctrl_gt_aug': ctrl_loss_gt_aug.mean().item(),
            'val_steer_gt_aug': steer_loss_gt_aug.mean().item(),
            'val_speed_gt_aug': speed_loss_gt_aug.mean().item(),
        }

    def validation_epoch_end(self, outputs):
        results = dict()

        for output in outputs:
            for key in output:
                if key not in results:
                    results[key] = list()

                results[key].append(output[key])

        summary = {key: np.mean(val) for key, val in results.items()}
        self.logger.log_metrics(summary, self.global_step)

        return summary

    def configure_optimizers(self):
        optim = torch.optim.Adam(
            list(self.net.parameters()) + list(self.controller.parameters()),
            lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optim, mode='min', factor=0.5, patience=5, min_lr=1e-6, verbose=True)

        return [optim], [scheduler]

    def train_dataloader(self):
        return get_dataset(self.hparams.dataset_dir, True, self.hparams.batch_size,
                           sample_by=self.hparams.sample_by)

    def val_dataloader(self):
        return get_dataset(self.hparams.dataset_dir, False, self.hparams.batch_size,
                           sample_by=self.hparams.sample_by)

    def state_dict(self):
        # Exclude the frozen teacher's weights from checkpoints.
        return {k: v for k, v in super().state_dict().items() if 'teacher' not in k}

    def load_state_dict(self, state_dict):
        # strict=False: the checkpoint intentionally lacks the teacher's weights.
        errors = super().load_state_dict(state_dict, strict=False)

        print(errors)
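# A minimal training sketch for ImageModel, assuming PyTorch Lightning's
# pre-1.x Trainer API (matching the validation_epoch_end hooks above). The
# hyperparameter names are the ones the class reads from `hparams`; the paths
# and values are illustrative placeholders, not the authors' settings.
import argparse
import pytorch_lightning as pl

hparams = argparse.Namespace(
    heatmap_radius=5, hack=False, temperature=10.0,  # model options
    command_coefficient=0.1,                         # loss weighting
    lr=1e-4, weight_decay=5e-5,                      # optimizer
    dataset_dir='/path/to/data', batch_size=32, sample_by='none')

model = ImageModel(hparams, teacher_path='/path/to/map_model.ckpt')
trainer = pl.Trainer(gpus=1, max_epochs=50)
trainer.fit(model)  # uses the train/val dataloaders defined on the model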
class MapModel(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()

        # addition: convert dict to namespace when necessary
        # hack:
        if isinstance(hparams, dict):
            import argparse

            args = argparse.Namespace()

            for k, v in hparams.items():
                setattr(args, k, v)

            hparams = args

        self.hparams = hparams
        self.to_heatmap = ToHeatmap(hparams.heatmap_radius)
        self.net = SegmentationModel(10, 4, hack=hparams.hack, temperature=hparams.temperature)
        self.controller = RawController(4)

    def forward(self, topdown, target, debug=False):
        target_heatmap = self.to_heatmap(target, topdown)[:, None]
        out = self.net(torch.cat((topdown, target_heatmap), 1))

        if not debug:
            return out

        return out, (target_heatmap,)

    def training_step(self, batch, batch_nb):
        img, topdown, points, target, actions, meta = batch

        out, (target_heatmap,) = self.forward(topdown, target, debug=True)

        alpha = torch.rand(out.shape).type_as(out)
        between = alpha * out + (1 - alpha) * points
        out_cmd = self.controller(between)

        loss_point = torch.nn.functional.l1_loss(out, points, reduction='none').mean((1, 2))
        loss_cmd_raw = torch.nn.functional.l1_loss(out_cmd, actions, reduction='none')
        loss_cmd = loss_cmd_raw.mean(1)
        loss = (loss_point + self.hparams.command_coefficient * loss_cmd).mean()

        metrics = {
            'point_loss': loss_point.mean().item(),
            'cmd_loss': loss_cmd.mean().item(),
            'loss_steer': loss_cmd_raw[:, 0].mean().item(),
            'loss_speed': loss_cmd_raw[:, 1].mean().item(),
        }

        if batch_nb % 250 == 0:
            metrics['train_image'] = visualize(batch, out, between, out_cmd,
                                               loss_point, loss_cmd, target_heatmap)

        self.logger.log_metrics(metrics, self.global_step)

        return {'loss': loss}

    def validation_step(self, batch, batch_nb):
        img, topdown, points, target, actions, meta = batch

        out, (target_heatmap,) = self.forward(topdown, target, debug=True)

        alpha = 0.0
        between = alpha * out + (1 - alpha) * points
        out_cmd = self.controller(between)
        out_cmd_pred = self.controller(out)

        loss_point = torch.nn.functional.l1_loss(out, points, reduction='none').mean((1, 2))
        loss_cmd_raw = torch.nn.functional.l1_loss(out_cmd, actions, reduction='none')
        loss_cmd_pred_raw = torch.nn.functional.l1_loss(out_cmd_pred, actions, reduction='none')
        loss_cmd = loss_cmd_raw.mean(1)
        loss = (loss_point + self.hparams.command_coefficient * loss_cmd).mean()

        if batch_nb == 0:
            self.logger.log_metrics({
                'val_image': visualize(batch, out, between, out_cmd,
                                       loss_point, loss_cmd, target_heatmap),
            }, self.global_step)

        return {
            'val_loss': loss.item(),
            'val_point_loss': loss_point.mean().item(),

            'val_cmd_loss': loss_cmd_raw.mean(1).mean().item(),
            'val_steer_loss': loss_cmd_raw[:, 0].mean().item(),
            'val_speed_loss': loss_cmd_raw[:, 1].mean().item(),

            'val_cmd_pred_loss': loss_cmd_pred_raw.mean(1).mean().item(),
            'val_steer_pred_loss': loss_cmd_pred_raw[:, 0].mean().item(),
            'val_speed_pred_loss': loss_cmd_pred_raw[:, 1].mean().item(),
        }

    def validation_epoch_end(self, batch_metrics):
        results = dict()

        for metrics in batch_metrics:
            for key in metrics:
                if key not in results:
                    results[key] = list()

                results[key].append(metrics[key])

        summary = {key: np.mean(val) for key, val in results.items()}
        self.logger.log_metrics(summary, self.global_step)

        return summary

    def configure_optimizers(self):
        optim = torch.optim.Adam(
            list(self.net.parameters()) + list(self.controller.parameters()),
            lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optim, mode='min', factor=0.5, patience=5, min_lr=1e-6, verbose=True)

        return [optim], [scheduler]

    def train_dataloader(self):
        return get_dataset(self.hparams.dataset_dir, True, self.hparams.batch_size,
                           sample_by=self.hparams.sample_by)

    def val_dataloader(self):
        return get_dataset(self.hparams.dataset_dir, False, self.hparams.batch_size,
                           sample_by=self.hparams.sample_by)
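# The `between = alpha * out + (1 - alpha) * points` step above trains the
# controller on random convex combinations of predicted and ground-truth
# waypoints, so at test time it tolerates imperfect predictions. A minimal
# self-contained sketch of the same interpolation with toy tensors (shapes
# are illustrative, not the dataset's):
import torch

out = torch.zeros(8, 4, 2)    # predicted waypoints: batch x steps x (x, y)
points = torch.ones(8, 4, 2)  # ground-truth waypoints
alpha = torch.rand(out.shape).type_as(out)
between = alpha * out + (1 - alpha) * points  # elementwise interpolation
assert between.min() >= 0 and between.max() <= 1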
def main():
    metrics = PerformanceMetrics()
    args = build_argparser().parse_args()

    log.info('Initializing Inference Engine...')
    ie = IECore()

    plugin_config = get_plugin_configs(args.device, args.num_streams, args.num_threads)

    log.info('Loading network...')
    model = SegmentationModel(ie, args.model)
    pipeline = AsyncPipeline(ie, model, plugin_config, device=args.device,
                             max_num_requests=args.num_infer_requests)

    cap = open_images_capture(args.input, args.loop)

    next_frame_id = 0
    next_frame_id_to_show = 0

    log.info('Starting inference...')
    print("To close the application, press 'CTRL+C' here or switch to the output window and press ESC key")

    visualizer = Visualizer(args.colors)
    presenter = None
    video_writer = cv2.VideoWriter()

    while True:
        if pipeline.is_ready():
            # Get a new image/frame.
            start_time = perf_counter()
            frame = cap.read()

            if frame is None:
                if next_frame_id == 0:
                    raise ValueError("Can't read an image from the input")
                break

            if next_frame_id == 0:
                presenter = monitors.Presenter(args.utilization_monitors, 55,
                                               (round(frame.shape[1] / 4), round(frame.shape[0] / 8)))

                if args.output and not video_writer.open(args.output, cv2.VideoWriter_fourcc(*'MJPG'),
                                                         cap.fps(), (frame.shape[1], frame.shape[0])):
                    raise RuntimeError("Can't open video writer")

            # Submit the frame for inference.
            pipeline.submit_data(frame, next_frame_id, {'frame': frame, 'start_time': start_time})
            next_frame_id += 1
        else:
            # Wait for an empty request.
            pipeline.await_any()

        if pipeline.callback_exceptions:
            raise pipeline.callback_exceptions[0]

        # Process all completed requests.
        results = pipeline.get_result(next_frame_id_to_show)

        if results:
            objects, frame_meta = results
            frame = frame_meta['frame']
            start_time = frame_meta['start_time']

            frame = visualizer.overlay_masks(frame, objects)
            presenter.drawGraphs(frame)
            metrics.update(start_time, frame)

            if video_writer.isOpened() and (args.output_limit <= 0 or
                                            next_frame_id_to_show <= args.output_limit - 1):
                video_writer.write(frame)

            if not args.no_show:
                cv2.imshow('Segmentation Results', frame)
                key = cv2.waitKey(1)

                # cv2.waitKey returns an int key code, so compare against
                # ord() values rather than string literals.
                if key in {27, ord('q'), ord('Q')}:
                    break

                presenter.handleKey(key)

            next_frame_id_to_show += 1

    pipeline.await_all()

    # Process the remaining completed requests.
    while pipeline.has_completed_request():
        results = pipeline.get_result(next_frame_id_to_show)

        if results:
            objects, frame_meta = results
            frame = frame_meta['frame']
            start_time = frame_meta['start_time']

            frame = visualizer.overlay_masks(frame, objects)
            presenter.drawGraphs(frame)
            metrics.update(start_time, frame)

            if video_writer.isOpened() and (args.output_limit <= 0 or
                                            next_frame_id_to_show <= args.output_limit - 1):
                video_writer.write(frame)

            if not args.no_show:
                cv2.imshow('Segmentation Results', frame)
                cv2.waitKey(1)

            next_frame_id_to_show += 1
        else:
            break

    metrics.print_total()
    print(presenter.reportMeans())
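# A hypothetical invocation of this demo, assuming the common Open Model Zoo
# flag names exposed by build_argparser (the model and input paths are
# placeholders):
#
#   python3 segmentation_demo.py -m segmentation.xml -i input.mp4 -d CPU
if __name__ == '__main__':
    main()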
class ImageModel(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()

        self.hparams = hparams
        self.net = SegmentationModel(4, 4)
        self.teacher = MapModel.load_from_checkpoint(
            pathlib.Path('/home/bradyzhou/code/carla_random/') / hparams.teacher_path)
        # self.teacher.eval()
        self.converter = Converter()

    def forward(self, x, *args, **kwargs):
        return self.net(x, *args, **kwargs)

    @torch.no_grad()
    def _get_labels(self, batch):
        img, topdown, points, heatmap, heatmap_img, meta = batch
        out = self.teacher.forward(torch.cat([topdown, heatmap], 1))

        return out

    def training_step(self, batch, batch_nb):
        img, topdown, points, heatmap, heatmap_img, meta = batch

        labels_map = self._get_labels(batch)
        labels_cam = self.converter.map_to_cam((labels_map + 1) / 2 * 256)
        labels_cam[..., 0] = (labels_cam[..., 0] / 256) * 2 - 1
        labels_cam[..., 1] = (labels_cam[..., 1] / 144) * 2 - 1

        out = self.forward(torch.cat([img, heatmap_img], 1))

        loss = torch.nn.functional.l1_loss(out, labels_cam, reduction='none').mean((1, 2))
        loss_mean = loss.mean()

        metrics = {'train_loss': loss_mean.item()}

        if batch_nb % 250 == 0:
            metrics['train_image'] = visualize(batch, out, labels_cam, labels_map, loss)

        self.logger.log_metrics(metrics, self.global_step)

        return {'loss': loss_mean}

    def validation_step(self, batch, batch_nb):
        img, topdown, points, heatmap, heatmap_img, meta = batch

        labels_map = self._get_labels(batch)
        labels_cam = self.converter.map_to_cam((labels_map + 1) / 2 * 256)
        labels_cam[..., 0] = (labels_cam[..., 0] / 256) * 2 - 1
        labels_cam[..., 1] = (labels_cam[..., 1] / 144) * 2 - 1

        out = self.forward(torch.cat([img, heatmap_img], 1))

        loss = torch.nn.functional.l1_loss(out, points, reduction='none').mean((1, 2))
        loss_mean = loss.mean()

        if batch_nb == 0:
            self.logger.log_metrics({
                'val_image': visualize(batch, out, labels_cam, labels_map, loss),
            }, self.global_step)

        return {'val_loss': loss_mean.item()}

    def validation_epoch_end(self, outputs):
        results = {'val_loss': list()}

        for output in outputs:
            for key in results:
                results[key].append(output[key])

        summary = {key: np.mean(val) for key, val in results.items()}
        self.logger.log_metrics(summary, self.global_step)

        return summary

    def configure_optimizers(self):
        return torch.optim.Adam(self.net.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)

    def train_dataloader(self):
        return get_dataset(self.hparams.dataset_dir, True, self.hparams.batch_size)

    def val_dataloader(self):
        return get_dataset(self.hparams.dataset_dir, False, self.hparams.batch_size)
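# The `(labels_map + 1) / 2 * 256 -> map_to_cam -> /256, /144` chain used in
# the training and validation steps converts waypoints between normalized
# coordinate frames. A minimal numeric sketch of the two rescaling steps,
# assuming a 256x256 top-down map and a 256x144 camera image; the projected
# camera pixel below is a stand-in, since Converter.map_to_cam is
# project-specific and omitted here:
import torch

pts = torch.tensor([[0.0, 0.5]])         # waypoint in [-1, 1] map coordinates
map_px = (pts + 1) / 2 * 256             # -> [[128.0, 192.0]] map pixels
# cam_px = converter.map_to_cam(map_px)  # project map pixels into the camera
cam_px = torch.tensor([[128.0, 100.0]])  # stand-in for the projected result
cam_norm = torch.stack([(cam_px[..., 0] / 256) * 2 - 1,   # x by image width
                        (cam_px[..., 1] / 144) * 2 - 1],  # y by image height
                       -1)               # back to [-1, 1] camera coordinates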