# Standard-library / framework imports; the project helpers used below
# (setup, build_detection_model, DetectronCheckpointer, run_test, train,
# comm, MetricLogger, reduce_loss_dict, evaluate_kitti_mAP) are assumed to
# be imported from the surrounding SMOKE codebase.
import datetime
import json
import logging
import os
import time

import torch


def main(args):
    cfg = setup(args)

    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    # Evaluation-only mode: load a checkpoint and run the test pipeline.
    if args.eval_only:
        checkpointer = DetectronCheckpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
        ckpt = cfg.MODEL.WEIGHT if args.ckpt is None else args.ckpt
        _ = checkpointer.load(ckpt, use_latest=args.ckpt is None)
        return run_test(cfg, model)

    # Wrap the model for multi-GPU training when launched with more than one process.
    distributed = comm.get_world_size() > 1
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[comm.get_local_rank()],
            broadcast_buffers=False,
            find_unused_parameters=True,
        )

    train(cfg, model, device, distributed)
# Variant of main() from the checkpoint-evaluation script: it sweeps every
# saved checkpoint in OUTPUT_DIR and records the KITTI validation mAP.
def main(args):
    cfg = setup(args)
    checkpoints_path = cfg.OUTPUT_DIR

    # Collect the iteration numbers of all saved checkpoints,
    # skipping non-checkpoint files and the final snapshot.
    val_mAP = []
    iteration_list = []
    for model_name in os.listdir(checkpoints_path):
        if "pth" not in model_name or "final" in model_name:
            continue
        iteration = int(model_name.split(".")[0].split("_")[1])
        iteration_list.append(iteration)
    iteration_list = sorted(iteration_list)

    # Evaluate each checkpoint in chronological order.
    for iteration in iteration_list:
        model = build_detection_model(cfg)
        device = torch.device(cfg.MODEL.DEVICE)
        model.to(device)

        checkpointer = DetectronCheckpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
        model_name = "model_{:07d}.pth".format(iteration)
        ckpt = os.path.join(checkpoints_path, model_name)
        _ = checkpointer.load(ckpt, use_latest=False)

        run_test(cfg, model)

        gt_label_path = "datasets/kitti/training/label_2/"
        pred_label_path = os.path.join(cfg.OUTPUT_DIR, "inference", "kitti_train", "data")
        result_dict = evaluate_kitti_mAP(
            gt_label_path, pred_label_path, ["Car", "Pedestrian", "Cyclist"]
        )
        if result_dict is not None:
            # Track 3D mAP at the "moderate" difficulty level and persist
            # both the running curve and the full per-checkpoint report.
            mAP_3d_moderate = result_dict["mAP3d"][1]
            val_mAP.append(mAP_3d_moderate)
            with open(os.path.join(cfg.OUTPUT_DIR, "val_mAP.json"), "w") as file_object:
                json.dump(val_mAP, file_object)
            result_name = "epoch_result_{:07d}_{}.txt".format(
                iteration, round(mAP_3d_moderate, 2)
            )
            with open(os.path.join(cfg.OUTPUT_DIR, result_name), "w") as f:
                f.write(result_dict["result"])
            print(result_dict["result"])
def do_train(
        cfg,
        distributed,
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
):
    logger = logging.getLogger("smoke.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = cfg.SOLVER.MAX_ITERATION
    start_iter = arguments["iteration"]
    model.train()

    start_training_time = time.time()
    end = time.time()

    # Note: these config values override the checkpoint_period argument passed in above.
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    evaluate_period = cfg.SOLVER.EVALUATE_PERIOD
    is_val = cfg.SOLVER.IS_VAL

    for data, iteration in zip(data_loader, range(start_iter, max_iter)):
        data_time = time.time() - end
        iteration += 1
        arguments["iteration"] = iteration

        images = data["images"].to(device)
        targets = [target.to(device) for target in data["targets"]]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # Reduce losses over all GPUs for logging purposes.
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        scheduler.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.8f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )

        # Save a checkpoint periodically and at every LR decay step.
        if iteration % checkpoint_period == 0 or iteration in cfg.SOLVER.STEPS:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        # Periodic validation; run_test switches the model to eval mode,
        # so restore train mode afterwards.
        if iteration % evaluate_period == 0 and is_val:
            run_test(cfg, model)
            model.train()
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / max_iter))
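# A minimal sketch of the train() wrapper invoked from main() above, showing
# how do_train() might be wired up. The helpers make_optimizer,
# make_lr_scheduler, and make_data_loader, and the optimizer/scheduler
# arguments to DetectronCheckpointer, are assumptions about the surrounding
# codebase, not confirmed APIs.
def train(cfg, model, device, distributed):
    optimizer = make_optimizer(cfg, model)         # assumed solver helper
    scheduler = make_lr_scheduler(cfg, optimizer)  # assumed solver helper

    arguments = {"iteration": 0}
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer=optimizer, scheduler=scheduler, save_dir=cfg.OUTPUT_DIR
    )
    # Resume from cfg.MODEL.WEIGHT if present; restores the iteration counter.
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(cfg, is_train=True)  # assumed data helper

    do_train(
        cfg,
        distributed,
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        cfg.SOLVER.CHECKPOINT_PERIOD,
        arguments,
    )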