def evaluate_batch(self, batch: TorchData, model: nn.Module) -> Dict[str, Any]:
    """Compute the average best IoU per target box for one validation batch.

    For every image the model's predictions are de-duplicated with
    non-maximum suppression (NMS); each ground-truth box is then credited
    with the highest IoU it reaches against any surviving prediction. The
    surviving boxes are also drawn on the image and written to TensorBoard
    under the tag ``step_<self.current_step>``.

    Args:
        batch: An ``(images, targets)`` pair; each target is indexed with
            ``"boxes"``.
        model: Detection model called as ``model(images, targets)``; its
            per-image output is indexed with ``"boxes"`` and ``"scores"``.

    Returns:
        ``{"val_avg_iou": value}`` where ``value`` is the sum of best IoUs
        divided by the number of target boxes, or ``0.0`` when the batch
        contains no target boxes at all.
    """
    images, targets = batch
    # The model receives a deep copy of the targets so the originals used
    # below cannot be modified by the forward pass.
    output = model(list(images), copy.deepcopy(list(targets)))
    sum_iou = 0
    num_boxes = 0

    # Determined looks for TensorBoard event files under /tmp/tensorboard.
    writer = SummaryWriter(log_dir="/tmp/tensorboard")
    try:
        for idx, target in enumerate(targets):
            # Filter out overlapping predictions with NMS (IoU threshold 0.1).
            predicted_boxes = output[idx]["boxes"]
            prediction_scores = output[idx]["scores"]
            keep_indices = torchvision.ops.nms(predicted_boxes, prediction_scores, 0.1)
            predicted_boxes = torch.index_select(predicted_boxes, 0, keep_indices)

            # Tally IoU with respect to the ground-truth boxes. Rows are
            # targets, columns are predictions.
            target_boxes = target["boxes"]
            boxes_iou = torchvision.ops.box_iou(target_boxes, predicted_boxes)
            # With no targets or no predictions the matrix is empty and a
            # row-wise max is undefined; such targets contribute an IoU of 0.
            if boxes_iou.numel() > 0:
                sum_iou += boxes_iou.max(dim=1)[0].sum()
            num_boxes += len(target_boxes)

            # All post-NMS boxes are drawn; no score threshold is applied.
            writer.add_image_with_boxes(
                "step_" + str(self.current_step), images[idx], predicted_boxes)
    finally:
        # Close the writer even if the model or the metric code raises.
        writer.close()

    if num_boxes == 0:
        return {"val_avg_iou": 0.0}
    return {"val_avg_iou": sum_iou / num_boxes}
class PytorchTBWriter(object):
    """Small TensorBoard logger that picks (or creates) its own run directory.

    Construct as ``PytorchTBWriter(args)`` or ``PytorchTBWriter(args, log_dir)``.
    When ``log_dir`` is omitted or ``None``, events are written to
    ``<args.log_dir>/<args.dataset>/<args.checkname>/experiment_<k>`` where
    ``k`` is one more than the highest run id already present (``0`` for the
    first run). Every logging helper flushes right after writing.
    """

    def __init__(self, *inputs):
        """Resolve the log directory, create it, and open the writer.

        Args:
            *inputs: ``(args,)`` or ``(args, log_dir)``. ``args`` must expose
                ``log_dir``, ``dataset`` and ``checkname`` when no explicit
                ``log_dir`` is supplied.
        """
        args = inputs[0]
        # Set the attribute unconditionally so the ``is None`` check below is
        # valid when only ``args`` is passed.
        self.log_dir = inputs[1] if len(inputs) == 2 else None
        if self.log_dir is None:
            directory = os.path.join(args.log_dir, args.dataset, args.checkname)
            runs = glob.glob(os.path.join(directory, 'experiment_*'))
            # Run ids are compared numerically so experiment_10 follows
            # experiment_9.
            run_ids = [int(r.split('_')[-1]) for r in runs]
            run_id = max(run_ids) + 1 if run_ids else 0
            self.log_dir = os.path.join(directory, 'experiment_{}'.format(run_id))
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)

    def add_images_with_bboxes(self, tag, images, bbox, step, labels=None, dataformats='CHW'):
        '''bbox: N x 4 (xmin, ymin, xmax, ymax), absolute values

        ``labels`` is accepted for call-site compatibility but is not
        forwarded to the underlying writer.
        '''
        self.writer.add_image_with_boxes(tag, images, bbox, step, dataformats=dataformats)
        self.writer.flush()

    def list_of_scalars_summary(self, tag_value_pairs, step):
        """Log every ``(tag, value)`` pair at ``step``, then flush once."""
        for tag, value in tag_value_pairs:
            self.writer.add_scalar(tag, value, step)
        self.writer.flush()

    def add_image(self, tag, image, step):
        """Log a single image at ``step`` and flush."""
        self.writer.add_image(tag, image, step)
        self.writer.flush()

    def add_images(self, tag, images, step):
        """Log a batch of images at ``step`` and flush."""
        self.writer.add_images(tag, images, step)
        self.writer.flush()
class SummaryWriter:
    """Gated wrapper around ``TensorboardSummaryWriter``.

    Each ``add_*`` call is forwarded to the wrapped writer only while
    ``self.active`` is truthy. For methods that take ``global_step``, a value
    of ``None`` falls back to the step stored via :meth:`set_global_step`.

    Creating an instance also publishes its bound methods as attributes of
    the defining module: every ``add_<x>`` method is exposed as ``<x>`` and
    every ``set_<x>`` method under its own name.
    """

    def __init__(self, logdir, flush_secs=120):
        self.writer = TensorboardSummaryWriter(
            log_dir=logdir,
            purge_step=None,
            max_queue=10,
            flush_secs=flush_secs,
            filename_suffix='')
        self.global_step = None
        self.active = True

        # Expose add_* (minus the prefix) and set_* as module-level callables
        # bound to this instance.
        module = sys.modules[__name__]
        for attr in dir(SummaryWriter):
            if attr.startswith('add_'):
                setattr(module, attr[4:], getattr(self, attr))
            elif attr.startswith('set_'):
                setattr(module, attr, getattr(self, attr))

    def set_global_step(self, value):
        """Store the step used when an ``add_*`` call passes no step."""
        self.global_step = value

    def set_active(self, value):
        """Enable (truthy) or disable (falsy) forwarding of ``add_*`` calls."""
        self.active = value

    def add_audio(self, tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None):
        if not self.active:
            return
        step = global_step if global_step is not None else self.global_step
        self.writer.add_audio(
            tag, snd_tensor,
            global_step=step, sample_rate=sample_rate, walltime=walltime)

    def add_custom_scalars(self, layout):
        if not self.active:
            return
        self.writer.add_custom_scalars(layout)

    def add_custom_scalars_marginchart(self, tags, category='default', title='untitled'):
        if not self.active:
            return
        self.writer.add_custom_scalars_marginchart(tags, category=category, title=title)

    def add_custom_scalars_multilinechart(self, tags, category='default', title='untitled'):
        if not self.active:
            return
        self.writer.add_custom_scalars_multilinechart(tags, category=category, title=title)

    def add_embedding(self, mat, metadata=None, label_img=None, global_step=None,
                      tag='default', metadata_header=None):
        if not self.active:
            return
        step = global_step if global_step is not None else self.global_step
        self.writer.add_embedding(
            mat,
            metadata=metadata, label_img=label_img, global_step=step,
            tag=tag, metadata_header=metadata_header)

    def add_figure(self, tag, figure, global_step=None, close=True, walltime=None):
        if not self.active:
            return
        step = global_step if global_step is not None else self.global_step
        self.writer.add_figure(
            tag, figure, global_step=step, close=close, walltime=walltime)

    def add_graph(self, model, input_to_model=None, verbose=False):
        if not self.active:
            return
        self.writer.add_graph(model, input_to_model=input_to_model, verbose=verbose)

    def add_histogram(self, tag, values, global_step=None, bins='tensorflow',
                      walltime=None, max_bins=None):
        if not self.active:
            return
        step = global_step if global_step is not None else self.global_step
        self.writer.add_histogram(
            tag, values,
            global_step=step, bins=bins, walltime=walltime, max_bins=max_bins)

    def add_histogram_raw(self, tag, min, max, num, sum, sum_squares,
                          bucket_limits, bucket_counts, global_step=None, walltime=None):
        if not self.active:
            return
        step = global_step if global_step is not None else self.global_step
        self.writer.add_histogram_raw(
            tag,
            min=min, max=max, num=num, sum=sum, sum_squares=sum_squares,
            bucket_limits=bucket_limits, bucket_counts=bucket_counts,
            global_step=step, walltime=walltime)

    def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
        if not self.active:
            return
        step = global_step if global_step is not None else self.global_step
        self.writer.add_image(
            tag, img_tensor,
            global_step=step, walltime=walltime, dataformats=dataformats)

    def add_image_with_boxes(self, tag, img_tensor, box_tensor, global_step=None,
                             walltime=None, rescale=1, dataformats='CHW'):
        if not self.active:
            return
        step = global_step if global_step is not None else self.global_step
        self.writer.add_image_with_boxes(
            tag, img_tensor, box_tensor,
            global_step=step, walltime=walltime, rescale=rescale,
            dataformats=dataformats)

    def add_images(self, tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW'):
        if not self.active:
            return
        step = global_step if global_step is not None else self.global_step
        self.writer.add_images(
            tag, img_tensor,
            global_step=step, walltime=walltime, dataformats=dataformats)

    def add_mesh(self, tag, vertices, colors=None, faces=None, config_dict=None,
                 global_step=None, walltime=None):
        if not self.active:
            return
        step = global_step if global_step is not None else self.global_step
        self.writer.add_mesh(
            tag, vertices,
            colors=colors, faces=faces, config_dict=config_dict,
            global_step=step, walltime=walltime)

    def add_onnx_graph(self, graph):
        if not self.active:
            return
        self.writer.add_onnx_graph(graph)

    def add_pr_curve(self, tag, labels, predictions, global_step=None,
                     num_thresholds=127, weights=None, walltime=None):
        if not self.active:
            return
        step = global_step if global_step is not None else self.global_step
        self.writer.add_pr_curve(
            tag, labels, predictions,
            global_step=step, num_thresholds=num_thresholds,
            weights=weights, walltime=walltime)

    def add_pr_curve_raw(self, tag, true_positive_counts, false_positive_counts,
                         true_negative_counts, false_negative_counts, precision, recall,
                         global_step=None, num_thresholds=127, weights=None, walltime=None):
        if not self.active:
            return
        step = global_step if global_step is not None else self.global_step
        self.writer.add_pr_curve_raw(
            tag,
            true_positive_counts, false_positive_counts,
            true_negative_counts, false_negative_counts,
            precision, recall,
            global_step=step, num_thresholds=num_thresholds,
            weights=weights, walltime=walltime)

    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        if not self.active:
            return
        step = global_step if global_step is not None else self.global_step
        self.writer.add_scalar(tag, scalar_value, global_step=step, walltime=walltime)

    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
        if not self.active:
            return
        step = global_step if global_step is not None else self.global_step
        self.writer.add_scalars(
            main_tag, tag_scalar_dict, global_step=step, walltime=walltime)

    def add_text(self, tag, text_string, global_step=None, walltime=None):
        if not self.active:
            return
        step = global_step if global_step is not None else self.global_step
        self.writer.add_text(tag, text_string, global_step=step, walltime=walltime)

    def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None):
        if not self.active:
            return
        step = global_step if global_step is not None else self.global_step
        self.writer.add_video(
            tag, vid_tensor, global_step=step, fps=fps, walltime=walltime)

    def close(self):
        """Close the wrapped writer (performed regardless of ``active``)."""
        self.writer.close()

    def __enter__(self):
        # Delegates to the wrapped writer, so the ``as`` target of a ``with``
        # statement is whatever the wrapped writer's __enter__ returns.
        return self.writer.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        return self.writer.__exit__(exc_type, exc_val, exc_tb)
def train():
    """Train the detector, log to TensorBoard and keep the best checkpoint.

    All hyper-parameters and paths (``batch_size``, ``num_epochs``,
    ``learning_rate``, ``train_set_ratio``, ``data_path``, ``cls2id``,
    ``num_workers``, ``num_classes``, ``momentum``, ``weight_decay``,
    ``model_save_path``) are read from module-level names.

    Each epoch trains once, evaluates on the held-out split, logs
    mAP/AP50/AP75, accuracy, learning rate and one sample prediction, and
    saves the whole model whenever AP50 or accuracy reaches a new maximum
    (both criteria write to the same file). After the last epoch the
    accuracy, confusion matrix, precision and recall recorded at the
    best-accuracy epoch are printed.
    """
    # train on the GPU or on the CPU, if a GPU is not available
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    time_stamp = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    stamp = "relation" + '_b' + str(batch_size) + '_e' + str(
        num_epochs) + '_lr' + str(learning_rate) + '_tr' + str(train_set_ratio)
    # Single checkpoint file shared by both "best model" criteria.
    save_path = os.path.join(model_save_path, stamp + '_' + time_stamp + '.pth')

    max_AP = 0
    max_acc = 0
    cm_when_max_acc = 0
    # Empty lists (not 0) so the final summary can iterate them even when
    # accuracy never rises above its initial value.
    ps_when_max_acc = []
    rc_when_max_acc = []

    writer = SummaryWriter(comment=stamp)
    try:
        # use our dataset and defined transformations
        dataset = Labelme_Dataset(data_path, cls2id=cls2id,
                                  transforms=get_transform(train=True))
        dataset_test = Labelme_Dataset(data_path, cls2id=cls2id,
                                       transforms=get_transform(train=False))

        # split the dataset in train and test set
        num_train_set = round(train_set_ratio * len(dataset))
        indices = torch.randperm(len(dataset)).tolist()
        dataset = torch.utils.data.Subset(dataset, indices[:num_train_set])
        dataset_test = torch.utils.data.Subset(dataset_test, indices[num_train_set:])

        # define training and validation data loaders
        data_loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=True,
            num_workers=num_workers, collate_fn=utils.collate_fn)
        data_loader_test = torch.utils.data.DataLoader(
            dataset_test, batch_size=batch_size, shuffle=False,
            num_workers=num_workers, collate_fn=utils.collate_fn)

        # get the model using our helper function
        model = get_model(num_classes)
        # move model to the right device
        model.to(device)

        # construct an optimizer
        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(params, lr=learning_rate,
                                    momentum=momentum, weight_decay=weight_decay)
        # and a learning rate scheduler
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', patience=10)

        # Class-id -> display-name lookup; index 0 is the background class.
        preds = ["background"] + list(cls2id.keys())

        for epoch in range(num_epochs):
            # train for one epoch, printing every 10 iterations
            logger = train_one_epoch(model, optimizer, data_loader, device,
                                     epoch, print_freq=10)
            writer.add_scalar('epoch-loss', logger.loss.total, epoch)

            # evaluate on the test dataset
            evaler, pick = evaluate(model, data_loader_test, device=device)
            bbox_stats = evaler.coco_eval['bbox'].stats

            # Show parameters and test results in TensorBoard
            writer.add_scalar('mAP', bbox_stats[0], epoch)
            writer.add_scalar('AP50', bbox_stats[1], epoch)
            writer.add_scalar('AP75', bbox_stats[2], epoch)

            # If the model outputs boxes as (x, y, w, h), one extra line is
            # needed: bboxes[:, 2:] += bboxes[:, :2]
            bboxes = pick[1][0]['boxes']
            labels = pick[1][0]['labels']
            scores = pick[1][0]['scores']
            bboxes_str = [
                "{} {:.2%}".format(preds[int(label)], float(score))
                for label, score in zip(labels, scores)
            ]

            # Output masks to TensorBoard; uncomment if needed
            # masks = pick[1][0]['masks']
            # final_mask = torch.zeros(masks.shape[-2:], dtype=torch.bool)
            # for i in masks > 0.5:
            #     final_mask = torch.bitwise_or(i[0], final_mask)
            # writer.add_image('test_mask', final_mask, epoch, dataformats='HW')
            writer.add_image('real_img', pick[0][0], epoch)
            writer.add_image_with_boxes('test_img', pick[0][0], bboxes, epoch,
                                        labels=bboxes_str)

            # AP50 and accuracy on the test set are the criteria for picking
            # the best model; the most recent best model (by AP50 or by
            # accuracy) is the one kept on disk.
            if max_AP < bbox_stats[1]:
                max_AP = bbox_stats[1]
                torch.save(model, save_path)

            acc, cm, ps, rc = custom_eval_sanan_when_train(
                model, data_loader_test, device=device)
            writer.add_scalar('acc', acc, epoch)
            if max_acc < acc:
                max_acc = acc
                cm_when_max_acc = cm
                ps_when_max_acc = ps
                rc_when_max_acc = rc
                torch.save(model, save_path)

            # update the learning rate
            writer.add_scalar('lr', logger.lr.value, epoch)
            # Lower the learning rate when AP50 has not exceeded its current
            # maximum for `patience` consecutive epochs.
            lr_scheduler.step(bbox_stats[1])
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()

    print("That's it!")
    print('acc:{:.2%}'.format(max_acc))
    print(cm_when_max_acc)
    print('precision:', ['{:.2%}'.format(x) for x in ps_when_max_acc])
    print('recall:', ['{:.2%}'.format(x) for x in rc_when_max_acc])
def train_model(args):
    """Train and validate the ``ObjectDetection`` model with TensorBoard logging.

    The dataset is split 75/25 into training and validation subsets. Each
    epoch runs one pass over the training loader (optimising
    ``losses["class_loss"]``) followed by one gradient-free pass over the
    validation loader, then steps the ``StepLR`` scheduler. Images with
    boxes and scalar metrics are logged at ``args.log_interval``.

    Args:
        args: Namespace providing ``resize``, ``db``, ``images``,
            ``batch_size``, ``num_data_workers``, ``device``,
            ``pos_anchor_iou``, ``neg_anchor_iou``, ``lr``, ``epochs``,
            ``log_interval``, ``nms_iou`` and ``metric_interval``.
    """
    writer = SummaryWriter()
    try:
        transforms = DetectionTransform(output_size=args.resize,
                                        greyscale=True,
                                        normalize=True)
        dataset = FBSDetectionDataset(database_path=args.db,
                                      data_path=args.images,
                                      greyscale=True,
                                      transforms=transforms,
                                      categories_filter={'person': True},
                                      area_filter=[100**2, 500**2])
        dataset.print_categories()

        # Split dataset into train and validation
        train_len = int(0.75 * len(dataset))
        dataset_lens = [train_len, len(dataset) - train_len]
        print("Splitting dataset into pieces: ", dataset_lens)
        datasets = torch.utils.data.random_split(dataset, dataset_lens)
        print(datasets)

        # Setup the data loader objects (collation, batching)
        loader = torch.utils.data.DataLoader(
            collate_fn=collate_detection_samples,
            dataset=datasets[0],
            batch_size=args.batch_size,
            pin_memory=True,
            num_workers=args.num_data_workers)
        validation_loader = torch.utils.data.DataLoader(
            dataset=datasets[1],
            batch_size=args.batch_size,
            pin_memory=True,
            collate_fn=collate_detection_samples,
            num_workers=args.num_data_workers)

        # Select device (cpu/gpu)
        device = torch.device(args.device)

        # Create the model and transfer weights to device
        model = ObjectDetection(input_image_shape=args.resize,
                                pos_threshold=args.pos_anchor_iou,
                                neg_threshold=args.neg_anchor_iou,
                                num_classes=len(dataset.categories),
                                predict_conf_threshold=0.5).to(device)

        # Select optimizer
        optim = torch.optim.SGD(params=model.parameters(), lr=args.lr, momentum=0.5)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optim,
                                                       step_size=2,
                                                       gamma=0.1,
                                                       last_epoch=-1)

        # Outer training loop
        for epoch in range(1, args.epochs + 1):
            # Inner training loop
            print("\n BEGINNING TRAINING STEP EPOCH {}".format(epoch))
            cummulative_loss = 0.0
            start_time = time.time()

            batch: ObjectDetectionBatch
            for idx, batch in enumerate(loader):
                # Reset gradient
                optim.zero_grad()

                # Push the data to the gpu (if necessary)
                batch.to(device)
                batch.debug = idx % args.log_interval == 0

                # Run the model
                losses, model_data = model(batch)
                cummulative_loss += losses["class_loss"].item()

                # Calc gradient and step optimizer.
                losses['class_loss'].backward()
                optim.step()

                # Log Metrics and Visualizations
                if (idx + 1) % args.log_interval == 0:
                    step = (epoch - 1) * len(loader) + idx + 1
                    print(
                        "Ep {} Training Step {} Batch {}/{} Loss : {:.3f}".format(
                            epoch, step, idx, len(loader), cummulative_loss))

                    # Save visualizations and metrics with tensorboard.
                    #
                    # Note: For research, to reproduce graphs you will want
                    # some way to save the collected metrics (e.g. the loss
                    # values) to an array for recreating figures for a paper.
                    # To do so, metrics are often wrapped in a "metering"
                    # class that takes care of logging to tensorboard,
                    # resetting cumulative metrics, saving arrays, etc.
                    #
                    # training_image - the raw training images with box labels
                    # training_image_predicted_anchors - predictions for the
                    #   same image, using basic thresholding (0.7 confidence
                    #   on the logit)
                    # training_image_predicted_post_nms - predictions for the
                    #   same image, filtered at 0.7 confidence followed by
                    #   Non-Max-Suppression
                    # training_image_positive_anchors - shows anchors which
                    #   received a positive label in the labeling step in
                    #   the model
                    sample_image = normalize_tensor(batch.images[0])
                    writer.add_image_with_boxes("training_image",
                                                sample_image,
                                                box_tensor=batch.boxes[0],
                                                global_step=step)
                    writer.add_image_with_boxes(
                        "training_image_predicted_anchors",
                        sample_image,
                        model_data["pos_predicted_anchors"][0],
                        global_step=step)
                    keep_ind = nms(model_data["pos_predicted_anchors"][0],
                                   model_data["pos_predicted_confidence"][0],
                                   iou_threshold=args.nms_iou)
                    writer.add_image_with_boxes(
                        "training_image_predicted_post_nms",
                        sample_image,
                        model_data["pos_predicted_anchors"][0][keep_ind],
                        global_step=step)
                    writer.add_image_with_boxes(
                        "training_image_positive_anchors",
                        sample_image,
                        box_tensor=model_data["pos_labeled_anchors"][0],
                        global_step=step)

                    # Scalars - batch_time, training loss
                    writer.add_scalar(
                        "batch_time",
                        ((time.time() - start_time) / float(args.log_interval)) * 1000.0,
                        global_step=step)
                    writer.add_scalar("training_loss",
                                      losses['class_loss'].item(),
                                      global_step=step)
                    writer.add_scalar(
                        "avg_pos_labeled_anchor_conf",
                        torch.tensor([
                            c.mean() for c in model_data["pos_labeled_confidence"]
                        ]).mean().item(),
                        global_step=step)
                    start_time = time.time()
                    # Flush rather than close: the writer is still needed for
                    # the remaining steps and is closed once at the end.
                    writer.flush()

                # Reset metric meters as necessary
                if idx % args.metric_interval == 0:
                    cummulative_loss = 0.0

            # Inner validation loop
            print("\nBEGINNING VALIDATION STEP {}\n".format(epoch))
            with torch.no_grad():
                batch: ObjectDetectionBatch
                for idx, batch in enumerate(validation_loader):
                    # Push the data to the gpu (if necessary)
                    batch.to(device)
                    batch.debug = idx % args.log_interval == 0

                    # Run the model
                    losses, model_data = model(batch)

                    if idx % args.log_interval == 0:
                        step = (epoch - 1) * len(validation_loader) + idx + 1
                        print("Ep {} Validation Step {} Batch {}/{} Loss : {:.3f}".
                              format(epoch, step, idx, len(validation_loader),
                                     losses["class_loss"].item()))

                        # Log Images
                        sample_image = normalize_tensor(batch.images[0])
                        writer.add_image_with_boxes("validation_images",
                                                    sample_image,
                                                    box_tensor=batch.boxes[0],
                                                    global_step=step)
                        writer.add_image_with_boxes(
                            "validation_img_predicted_anchors",
                            sample_image,
                            model_data["pos_predicted_anchors"][0],
                            global_step=step)
                        keep_ind = nms(model_data["pos_predicted_anchors"][0],
                                       model_data["pos_predicted_confidence"][0],
                                       iou_threshold=0.5)
                        print("Indicies after NMS: ", keep_ind,
                              model_data["pos_predicted_confidence"][0].shape,
                              model_data["pos_predicted_anchors"][0].shape)
                        writer.add_image_with_boxes(
                            "validation_img_predicted_post_nms",
                            sample_image,
                            model_data["pos_predicted_anchors"][0][keep_ind],
                            global_step=step)

                        # Log Scalars
                        writer.add_scalar("validation_loss",
                                          losses['class_loss'].item(),
                                          global_step=step)
                        writer.flush()

            lr_scheduler.step()
            # Read the rate the optimizer will actually use next; calling the
            # scheduler's get_lr() outside of step() is not a reliable way to
            # obtain it.
            print("Stepped learning rate. Rate is now: ",
                  [group["lr"] for group in optim.param_groups])
    finally:
        # Close exactly once, including when training is interrupted.
        writer.close()
class Trainer:
    '''
    Training driver for SPAIR with box drawing, TensorBoard logging and
    checkpointing.

    Keyword arguments form the configuration. If ``model_path`` points at an
    existing checkpoint, its saved config is used as the base (explicit
    keyword arguments override it) and the model weights and epoch counter
    are restored. The instance is a context manager; leaving the ``with``
    block closes the summary writer.
    '''

    def __init__(self, Implement=SPAIR, **config):
        save_cfg = self.checkpoint(config.get("model_path", None))
        if save_cfg:
            self.loadConfig(config, save_cfg["config"])
            self.start_epoch = save_cfg["epoch"]
            self.spair = Implement(**self.config.get('model', {}))
            self.spair.load_state_dict(save_cfg["model"])
        else:
            self.loadConfig(config)
            self.start_epoch = 0
            self.spair = Implement(**self.config.get('model', {}))
        # Defined up front so `save()` works even before `train()` has run.
        self.cur_epoch = self.start_epoch
        self.kl_bulder = KL_Builder(**self.config['KL_Builder'])
        # NOTE(review): `eval` executes the configured optimizer name as
        # Python code. Only safe while the config comes from a trusted
        # source; consider an explicit name -> class lookup.
        self.op = eval(self.optimizer)(self.spair.parameters(),
                                       **self.config.get(self.optimizer, {}))
        self.spair.to(self.device)

    @staticmethod
    def checkpoint(path) -> dict:
        '''
        Load the checkpoint at `path` if it exists.

        return: the loaded object, or None when `path` is empty or the file
        does not exist. In the latter case the parent folder is created so a
        later `save` will not fail.
        '''
        if not path:
            return None
        if os.path.exists(path):
            return torch.load(path)
        folder = os.path.dirname(path)
        # A bare file name has no directory part; makedirs('') would raise.
        if folder:
            os.makedirs(folder, exist_ok=True)
        return None

    def loadConfig(self, config, default: dict = None):
        '''
        Build `self.config` from `default` (copied) updated with `config`,
        then derive the summary writer, device, optimizer name and KL weight.
        '''
        if default:
            self.config = default.copy()
        else:
            self.config = {}
        self.config.update(config)
        self.summary = SummaryWriter(self.config["logdir"],
                                     self.config["logdir"].split('/')[-1])
        self.device = torch.device(self.config.get('device', 'cuda:0'))
        self.optimizer = self.config.get('optimizer', 'Adam')
        self._sup = self.config.get('KL_lambda', 1)

    def save(self):
        '''Write config, current epoch and model weights to `model_path`.'''
        d = {
            "config": self.config,
            "epoch": self.cur_epoch,
            "model": self.spair.state_dict()
        }
        folder = os.path.dirname(self.config["model_path"])
        # Skip directory creation for a bare file name (empty dirname).
        if folder:
            os.makedirs(folder, exist_ok=True)
        torch.save(d, self.config["model_path"])

    def loss(self, rec, target, param_dict: dict, global_step) -> torch.Tensor:
        '''
        Reconstruction loss plus `KL_lambda` times the two KL terms.

        Each of the three components is also logged as a scalar at
        `global_step`.
        '''
        recon_loss = binary_cross_entropy(rec, target, reduction='sum')
        norm_loss = self.kl_bulder.norm_KL(param_dict)
        kin_loss = self.kl_bulder.bin_KL(param_dict['pres'], global_step)
        assert not torch.isnan(recon_loss)
        self.summary.add_scalar('loss/reconstruct', recon_loss, global_step)
        self.summary.add_scalar('loss/normal', norm_loss, global_step)
        self.summary.add_scalar('loss/bernoulli', kin_loss, global_step)
        return recon_loss + self._sup * (norm_loss + kin_loss)

    def __histogram(self,
                    tag: str,
                    value: torch.Tensor,
                    global_step=None,
                    dim=None):
        '''
        Log `value` as a histogram; when `dim` is given, log one histogram
        per size-1 slice along that dimension.
        '''
        with torch.no_grad():
            # `is not None` so that dim=0 is honoured as a real dimension.
            if dim is not None:
                for i, v in enumerate(torch.split(value, 1, dim=dim)):
                    self.summary.add_histogram(
                        "sample/%s/%s_%d" % (tag, tag, i), v, global_step)
            else:
                self.summary.add_histogram("sample/%s" % tag, value,
                                           global_step)

    def rectangle(self, rec, norm_box, pres, global_step):
        '''
        rec: [N, C, H_img, W_img]
        norm_box: [N, H*W, 4], y_center, x_center, height, width
        pres: [N, H*W, 1]

        Logs every reconstructed image with the boxes whose rounded presence
        is 1. The inputs are not modified.
        '''
        with torch.no_grad():
            image_size = self.spair.encoder.image_size
            # Scale a copy so the caller's tensor keeps its normalized values.
            scaled = norm_box.clone()
            scaled[:, :, :2] *= image_size
            scaled[:, :, 2:] *= image_size / 2
            # y_center, x_center, height / 2, width / 2
            scaled = torch.round(scaled)
            corners = torch.stack(
                (
                    scaled[:, :, 1] - scaled[:, :, 3],  # xmin
                    scaled[:, :, 0] - scaled[:, :, 2],  # ymin
                    scaled[:, :, 1] + scaled[:, :, 3],  # xmax
                    scaled[:, :, 0] + scaled[:, :, 2],  # ymax
                ),
                dim=-1)
            present = torch.round(pres).bool().squeeze(-1)
            for i, (img, box, zpres) in enumerate(zip(rec, corners, present)):
                tag = 'detect/rec_%d' % i
                if zpres.any():
                    self.summary.add_image_with_boxes(tag, img, box[zpres],
                                                      global_step)
                else:
                    # Nothing detected: log the bare image instead of failing
                    # on an empty box list.
                    self.summary.add_image(tag, img, global_step)

    def reconstruct(self, X, bg=None):
        '''
        X: [N, H, W, C] (a single [H, W, C] image is also accepted)
        bg: background with the same layout as X; zeros when omitted
        return: the model output, moved back to X's original device
        '''
        self.spair.eval()
        if X.dim() == 3:
            X = X.unsqueeze(0)
        if bg is None:
            bg = torch.zeros_like(X)
        elif bg.dim() == 3:
            bg = bg.unsqueeze(0)
        prev_device = X.device
        with torch.no_grad():
            return self.spair(
                X.to(self.device).permute(0, 3, 1, 2),
                bg.to(self.device).permute(0, 3, 1, 2)).to(prev_device)

    def train(self, X: torch.Tensor, bg=None):
        '''
        X shape: [N, H, W, C];

        Generator: yields the epoch index after each finished epoch so the
        caller can act between epochs. On Ctrl-C the user is asked whether to
        save, and the generator stops. After the final epoch the model is
        saved with `epoch` set to `max_epoch`.
        '''
        # torch.autograd.set_detect_anomaly(True)
        if bg is None:
            bg = torch.zeros_like(X)
        data = TensorDataset(
            X.permute(0, 3, 1, 2).to(self.device),  # [N, 3, H, W]
            bg.permute(0, 3, 1, 2).to(self.device)  # [N, 3, H, W]
        )
        loader = DataLoader(data, **self.config.get("loader", {}))
        max_batch = len(loader)

        def one_epoch(bar):
            # Global step offset of this epoch's first batch.
            citer = max_batch * self.cur_epoch
            for batch, (R, B) in enumerate(loader):
                self.op.zero_grad()
                where, what, depth, pres, pd = self.spair.encoder(R)
                self.__histogram('where', pd['where'].mean, citer + batch, -1)
                self.__histogram('what', what, citer + batch)
                self.__histogram('depth', depth, citer + batch)
                self.__histogram('pres', pres, citer + batch)
                rec = self.spair.decoder(where, what, depth, pres, B)
                loss = self.loss(rec, R, pd, citer + batch)
                loss.backward()
                self.op.step()
                bar.next()

        for self.cur_epoch in range(self.start_epoch,
                                    self.config["max_epoch"]):
            bar = Bar('epoch%3d' % (self.cur_epoch + 1), max=max_batch)
            try:
                self.spair.train()
                one_epoch(bar)
                yield self.cur_epoch  # let the caller act after each epoch
                bar.finish()
            except KeyboardInterrupt:
                bar.finish()
                if 'Y' == input('save model? Y/n ').upper():
                    self.save()
                    print("Saved. Start from epoch%d next time." %
                          (self.cur_epoch + 1))
                return
        self.cur_epoch = self.config["max_epoch"]
        self.save()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.summary.close()
class ObjectDetectionTrainer: """Object detection training manager. """ def __init__(self, max_epoch: int, dataset_train_path: str, dataset_validation_path: str, output_grid_size: int, anchor_boxes: List[BoundingBox], target_device: str, label_map: Dict[str, int], target_image_size: Tuple[int, int] = (320, 320), n_classes: int = 1): """Create a new ObjectDetectionTrainer Args: max_epoch (int): Maximum number of epoch to run the training. dataset_path (str): Path to the folder containing the json files. output_grid_size (int): Number of grids for the prediction. The model will output (output_grid_size, output_grid_size) cells. anchor_boxes (List[BoundingBox]): List of bounding box to be used as anchor. target_device (str): The target device of that Pytorch will use. Example ("cuda:0", "cpu:0") target_image_size (Tuple[int,int], optional): The expected size of the network input (width, height). Defaults to (320,320). """ self.n_classes = n_classes self.max_epoch = max_epoch self.target_image_size = target_image_size self.dataset_train = ObjectDetectionDataset( dataset_train_path, target_image_size, transform=image_bb_transforms, label_map=label_map) self.dataset_validation = ObjectDetectionDataset( dataset_validation_path, target_image_size, transform=image_bb_transforms, label_map=label_map) self.boxes_per_cell = len(anchor_boxes) self.anchor_boxes = anchor_boxes self.output_grid_size = output_grid_size self.model = ObjectDetectionModel(n_classes, self.output_grid_size, self.boxes_per_cell) self.target_device = target_device self.model.to(self.target_device) self.optimizer = Adam(self.model.parameters(), lr=1e-3) self.label_map = label_map # Timestamp to assign unique id to logs now = datetime.now() date_time = now.strftime("%m_%d_%Y_%H_%M_%S") self.run_id = date_time self.writer = SummaryWriter("logs/%s" % (self.run_id)) self.scheduler = ReduceLROnPlateau(self.optimizer, 'min', verbose=True, patience=10) def calculate_iou(self, b1: BoundingBox, b2: BoundingBox) -> 
float: """Calculate intersection over union Args: b1 (BoundingBox): First bounding box b2 (BoundingBox): Second bounding box Returns: float: iou score between the boxes """ x_a = max(b1.min_x, b2.min_x) y_a = max(b1.min_y, b2.min_y) x_b = min(b1.max_x, b2.max_x) y_b = min(b1.max_y, b2.max_y) intersection = max(0, x_b - x_a + 1) * max(0, y_b - y_a + 1) area_b1 = (b1.max_x - b1.min_x + 1) * (b1.max_y - b1.min_y + 1) area_b2 = (b2.max_x - b2.min_x + 1) * (b2.max_y - b2.min_y + 1) iou = intersection / float(area_b1 + area_b2 - intersection) return iou # assume output is (batch_size*n_cells, cell_tensor_size) def parse_the_outputs(self, output: Type[torch.FloatTensor], boxes: List[List[BoundingBox]], classes: List[List[int]]): """Takes the output of the neural network and the target bounding boxes Args: output (torch.FloatTensor): The output from the neural network, reshaped as (batch_size*n_cells, cell_tensor_size) boxes (List[List[BoundingBox]]): List of list of target bounding boxes for each image clasess (List[int]): List of list of target classes for each image Returns: [type]: [description] """ # Each cell has [ self.boxes_per_cell * 5 + n_classes ] entries, extract each relevant part bbox_slices = output[:, :self.boxes_per_cell * 5].reshape(-1, 5) class_slices = output[:, self.boxes_per_cell * 5:] positive_bbox_outputs = [] positive_bbox_targets = [] negative_bbox_outputs = [] positive_classprob_indices = [] positive_classprob_target = [] positive_classprob_output = [] anchor_ids = [] positive_bbox_indices_list = [] # total number of cells in the batch output n_cells = output.shape[0] # gather positive cell indices and corresponding anchor boxes for image_id, bb_list in enumerate(boxes): # Process all bbs in the image for bb_id, bb in enumerate(bb_list): # Center in output grid coordinate center_x_grid = self.output_grid_size * (bb.min_x + bb.max_x) / 2 center_y_grid = self.output_grid_size * (bb.min_y + bb.max_y) / 2 # Get cell index in the batch cell_index = 
math.floor( center_y_grid) * self.output_grid_size + math.floor( center_x_grid) cell_index += image_id * self.model.n_cells # Calc bounding box offset within the grid offset_x = center_x_grid - math.floor(center_x_grid) offset_y = center_y_grid - math.floor(center_y_grid) # Calc bb size width = (bb.max_x - bb.min_x) height = (bb.max_y - bb.min_y) matching_anchors = [] max_anchor_id = -1 anchor_offset_x = (bb.min_x + bb.max_x) / 2 anchor_offset_y = (bb.min_y + bb.max_y) / 2 # Find matching matching boxes for anchor_id, anchor in enumerate(self.anchor_boxes): adjusted_anchor = BoundingBox( min_x=anchor.min_x + anchor_offset_x, min_y=anchor.min_y + anchor_offset_y, max_x=anchor.max_x + anchor_offset_x, max_y=anchor.max_y + anchor_offset_y) iou = self.calculate_iou(bb, adjusted_anchor) if iou > 0.5: matching_anchors.append((iou, anchor_id)) # Find the maximum matching anchor if matching_anchors: max_anchor_id = max(matching_anchors, key=lambda t: t[0])[1] anchor = self.anchor_boxes[max_anchor_id] anchor_w = anchor.max_x - anchor.min_x anchor_h = anchor.max_y - anchor.min_y # calc bbox index bbox_index = cell_index * self.model.boxes_per_cell + max_anchor_id # Calc target tensor target_tensor = torch.FloatTensor( [offset_x, offset_y, width, height, 1.0]) positive_bbox_targets.append(target_tensor) positive_bbox_indices_list.append(bbox_index) # class prob target #class_target = torch.zeros(self.n_classes) #class_target[classes[image_id][bb_id]] = 1.0 positive_classprob_output.append(class_slices[cell_index]) positive_classprob_target.append(classes[image_id][bb_id]) anchor_ids.append(max_anchor_id) # Make a set for quick membership lookup positive_bbox_indices_set = set(positive_bbox_indices_list) # Other than positive indices, combine as negative indices negative_bbox_indices_list = torch.LongTensor([ s for s in range(bbox_slices.shape[0]) if s not in positive_bbox_indices_set ]) # Get positive bbox and negative bbox positive_bbox_outputs = 
def preview_generate(self,
                     input_img: torch.Tensor,
                     output: torch.Tensor,
                     target_box: List[BoundingBox],
                     objectness_threshold: float = 0.1):
    """Detection results preview generator.

    Decodes the network output of a single image into pixel-space boxes so
    they can be drawn next to the ground truth. The network output is only
    read through a detached view and is never written to, so the caller's
    tensor and its autograd history are left untouched.

    Args:
        input_img (torch.Tensor): Input image (C,H,W)
        output (torch.Tensor): Network output (-1,CELL_SIZE)
        target_box (List[BoundingBox]): List of the boxes in the image. Each
            box is read as a sequence whose entries 0 and 2 are scaled by the
            image width and entries 1 and 3 by the image height.
        objectness_threshold (float): A predicted box is kept when the
            sigmoid of its objectness logit is strictly greater than this
            value. Defaults to 0.1.

    Returns:
        dict: ``"image"`` (result of ``np_image_from_tensor``),
        ``"target_bb"``, ``"candidate_bb"`` and ``"cell_bb"``. The box
        entries hold one ``[min_x, min_y, max_x, max_y]`` row per box in
        pixel coordinates, or an empty ``torch.FloatTensor`` when there is
        no box of that kind.
    """
    np_image = np_image_from_tensor(input_img)
    img_height, img_width = np_image.shape[0], np_image.shape[1]
    grid_size = self.model.output_grid_size

    # Detach before slicing: everything below is visualisation only and
    # must neither modify `output` nor record autograd operations on it.
    bb_outputs = output.detach().view(
        grid_size, grid_size, -1)[:, :, :self.boxes_per_cell * 5]
    bb_outputs = bb_outputs.reshape(grid_size, grid_size,
                                    self.model.boxes_per_cell, -1)

    # The last value of every box slot is the objectness logit
    objectness_output = torch.sigmoid(bb_outputs[:, :, :, -1])

    # [row, col, anchor] triples of the boxes that pass the threshold
    candidate_cell_index = torch.nonzero(
        objectness_output > objectness_threshold).cpu().numpy().tolist()

    # Cell sizes in normalized coordinate
    cell_size_x = 1.0 / grid_size
    cell_size_y = 1.0 / grid_size

    candidates_bb = []
    cell_bboxes = []
    for row, col, anchor_idx in candidate_cell_index:
        raw = bb_outputs[row, col, anchor_idx][:4]

        anchor = self.anchor_boxes[anchor_idx]
        anchor_w = anchor.max_x - anchor.min_x
        anchor_h = anchor.max_y - anchor.min_y

        # Centre offset inside the cell and box size relative to the anchor
        offsets = torch.sigmoid(raw[:2])
        sizes = torch.exp(raw[2:4])
        box_w = sizes[0] * anchor_w
        box_h = sizes[1] * anchor_h

        # Origin of the cell in pixel coordinates
        base_x = cell_size_x * col * img_width
        base_y = cell_size_y * row * img_height

        min_x = (base_x + offsets[0] * cell_size_x * img_width -
                 box_w * img_width / 2.0)
        min_y = (base_y + offsets[1] * cell_size_y * img_height -
                 box_h * img_height / 2.0)
        # The far corner is derived from the unclamped near corner
        max_x = min_x + box_w * img_width
        max_y = min_y + box_h * img_height

        candidates_bb.append(
            torch.stack([
                min_x.clamp(0.0, img_width - 1),
                min_y.clamp(0.0, img_height - 1),
                max_x.clamp(0.0, img_width - 1),
                max_y.clamp(0.0, img_height - 1),
            ]))

        # Outline of the grid cell that produced this candidate
        cell_bboxes.append(
            np.array([
                base_x, base_y, base_x + cell_size_x * img_width,
                base_y + cell_size_y * img_height
            ]))

    target_boxes = []
    for target in target_box:
        bb_ref = np.array(list(target))
        # Convert to pixel coordinate
        bb_ref[[0, 2]] *= img_width
        bb_ref[[1, 3]] *= img_height
        target_boxes.append(bb_ref)

    if len(cell_bboxes) > 0:
        cell_bboxes = np.stack(cell_bboxes)
    else:
        cell_bboxes = torch.FloatTensor([])

    if len(target_boxes) > 0:
        target_boxes = np.stack(target_boxes)
    else:
        target_boxes = torch.FloatTensor([])

    if len(candidates_bb) > 0:
        candidates_bb = torch.stack(candidates_bb, dim=0)
    else:
        candidates_bb = torch.FloatTensor([])

    preview_info = {
        "image": np_image,
        "target_bb": target_boxes,
        "candidate_bb": candidates_bb,
        "cell_bb": cell_bboxes
    }
    return preview_info
def run_batch(self, batch_data: dict, epoch: int) -> tuple:
    """Run one batch through the network and compute the detection losses.

    The weights are only updated while ``self.model`` is in training mode.
    When the model has been switched to ``eval()`` the losses are computed
    with gradients disabled and neither ``backward()`` nor the optimizer
    step is executed, so the method can be used for validation without
    fitting the validation data.

    Args:
        batch_data (dict): Batch with the keys ``"images"``, ``"bboxes"``
            and ``"labels"``.
        epoch (int): epoch number (currently not used by the computation)

    Returns:
        tuple: ``(loss_info, preview_info)``. ``loss_info`` is a dict of
        Python floats with the keys ``"objectness"``, ``"bb"``, ``"class"``
        and ``"total"``. ``preview_info`` is whatever
        ``self.preview_generate`` returns for the first image of the batch.
    """
    images = batch_data["images"].float().to(self.target_device)
    boxes = batch_data["bboxes"]
    classes = batch_data["labels"]

    is_training = self.model.training
    if is_training:
        self.optimizer.zero_grad()

    with torch.set_grad_enabled(is_training):
        # run network
        output = self.model(images)

        # Reshape to cell outputs
        cell_outputs = output.view(-1, self.model.cell_tensor_size)

        # Parse the cell outputs
        (positive_bbox, negative_bbox, positive_bbox_targets,
         positive_classprob_output, positive_classprob_target,
         anchor_sizes) = self.parse_the_outputs(cell_outputs, boxes, classes)

        # The last column of every box slice is the objectness logit
        positive_objectness = torch.sigmoid(positive_bbox[:, -1].flatten())
        negative_objectness = torch.sigmoid(negative_bbox[:, -1].flatten())

        objectness_positive_loss = nn.functional.mse_loss(
            input=positive_objectness,
            target=torch.ones_like(positive_objectness),
            reduction="sum")
        objectness_negative_loss = nn.functional.mse_loss(
            input=negative_objectness,
            target=torch.zeros_like(negative_objectness),
            reduction="sum")

        # Apply weight to the losses to handle class imbalance
        objectness_loss = (objectness_positive_loss +
                           0.5 * objectness_negative_loss)

        # offset = sigmoid(raw offset)
        predicted_offsets = torch.sigmoid(positive_bbox[:, :2])

        # width = exp(w) * anchor_w
        # height = exp(h) * anchor_h
        # The anchor array comes from NumPy; cast it to the prediction dtype
        # so the out-of-place product is not promoted to another dtype.
        anchor_scale = torch.from_numpy(anchor_sizes).to(
            device=self.target_device, dtype=positive_bbox.dtype)
        predicted_dims = torch.exp(positive_bbox[:, 2:4]) * anchor_scale

        # Calc bb loss
        bb_loss_offset = nn.functional.mse_loss(predicted_offsets,
                                                positive_bbox_targets[:, :2],
                                                reduction='sum')
        bb_loss_dims = nn.functional.mse_loss(
            torch.sqrt(predicted_dims),
            torch.sqrt(positive_bbox_targets[:, 2:4]),
            reduction='sum')
        bb_loss = bb_loss_offset + bb_loss_dims

        # Calc classes prob loss (softmax over the class dimension)
        classes_loss = nn.functional.nll_loss(
            input=nn.functional.log_softmax(positive_classprob_output, dim=1),
            target=positive_classprob_target,
            reduction='sum')

        # Make bb_loss stronger to counter gradient from negative objectness
        total_loss = 1 * objectness_loss + 5 * bb_loss + classes_loss

    if is_training:
        total_loss.backward()
        self.optimizer.step()

    preview_info = self.preview_generate(images[0], output[0], boxes[0])

    loss_info = {
        "objectness": objectness_loss.detach().item(),
        "bb": bb_loss.detach().item(),
        "class": classes_loss.detach().item(),
        "total": total_loss.detach().item()
    }
    return loss_info, preview_info
def start_training(self):
    """Train for ``self.max_epoch`` epochs with periodic validation.

    Every epoch runs all training batches through ``self.run_batch`` and
    logs the mean losses plus the preview of the last batch to
    ``self.writer``. Every 20th epoch (starting with epoch 0) the model is
    additionally run on 32 randomly sampled validation items, the learning
    rate scheduler is stepped with the mean training loss, the validation
    summaries are logged and a checkpoint is written to
    ``checkpoints/<run_id>/``.
    """
    # Create dataloader to batch the dataset
    dataloader_train = DataLoader(self.dataset_train,
                                  batch_size=32,
                                  shuffle=True,
                                  collate_fn=object_dataset_collate_fn,
                                  num_workers=6,
                                  worker_init_fn=worker_init_fn)

    sampler = torch.utils.data.RandomSampler(self.dataset_validation,
                                             replacement=True,
                                             num_samples=32)
    dataloader_validation = DataLoader(self.dataset_validation,
                                       batch_size=1,
                                       shuffle=False,
                                       collate_fn=object_dataset_collate_fn,
                                       num_workers=1,
                                       worker_init_fn=worker_init_fn,
                                       sampler=sampler)

    loss_keys = ("objectness", "bb", "class", "total")

    def run_epoch(dataloader, ep):
        """Run every batch of `dataloader`; return mean losses and last preview."""
        collected = {key: [] for key in loss_keys}
        preview = None
        for data in dataloader:
            losses, preview = self.run_batch(data, ep)
            for key in loss_keys:
                collected[key].append(losses[key])
        means = {key: np.mean(values) for key, values in collected.items()}
        return means, preview

    def log_summary(prefix, means, preview, ep):
        """Write the scalar losses and preview images under `prefix`."""
        self.writer.add_scalar(prefix + "/mean_objectness_loss",
                               means["objectness"], ep)
        self.writer.add_scalar(prefix + "/mean_bb_loss", means["bb"], ep)
        self.writer.add_scalar(prefix + "/mean_classes_loss", means["class"],
                               ep)
        self.writer.add_scalar(prefix + "/mean_total_loss", means["total"],
                               ep)
        if preview is None:
            # The loader produced no batch, so there is nothing to draw
            return
        for name, key in (("target", "target_bb"), ("cell", "cell_bb"),
                          ("detection", "candidate_bb")):
            self.writer.add_image_with_boxes("{}/{}".format(prefix, name),
                                             preview["image"],
                                             preview[key],
                                             ep,
                                             dataformats='HWC')

    for ep in range(self.max_epoch):
        print("epoch : {}".format(ep))

        # Train batches
        train_means, train_preview = run_epoch(dataloader_train, ep)
        mean_total_loss = train_means["total"]
        log_summary("training", train_means, train_preview, ep)

        # NOTE(review): the nesting of everything below was reconstructed
        # from a flattened source; confirm that the scheduler step and the
        # checkpoint are meant to run only on validation epochs.
        if ep % 20 == 0:
            # Eval mode for validation; always return to train mode
            # NOTE(review): run_batch is the training step; confirm it does
            # not update the weights while the model is in eval mode.
            self.model.eval()
            try:
                val_means, val_preview = run_epoch(dataloader_validation, ep)
            finally:
                self.model.train()

            # Update learning rate scheduler step
            self.scheduler.step(mean_total_loss)

            log_summary("validation", val_means, val_preview, ep)

            checkpoint_dir = "checkpoints/{}".format(self.run_id)
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(
                self.model.state_dict(),
                "{}/{}_{}.pth".format(checkpoint_dir, ep, mean_total_loss))