def train(cfg, local_rank, distributed):
    # Build the dataset only to read the number of classes, then build the model.
    num_classes = COCODataset(cfg.data.train[0], cfg.data.train[1]).num_classes
    model = EfficientDet(num_classes=num_classes, model_name=cfg.model.name)
    inp_size = model.config['inp_size']
    device = torch.device(cfg.device)
    model.to(device)

    optimizer = build_optimizer(model, **optimizer_kwargs(cfg))
    lr_scheduler = build_lr_scheduler(optimizer, **lr_scheduler_kwargs(cfg))

    # Mixed-precision training via apex.amp.
    use_mixed_precision = cfg.dtype == "float16"
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
    model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
            find_unused_parameters=True)

    arguments = {}
    arguments["iteration"] = 0

    # Only the main process writes checkpoints; resuming restores the iteration counter.
    output_dir = cfg.output_dir
    save_to_disk = comm.get_rank() == 0
    checkpointer = Checkpointer(model, optimizer, lr_scheduler, output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(cfg.model.resume)
    arguments.update(extra_checkpoint_data)

    train_dataloader = build_dataloader(cfg, inp_size, is_train=True,
                                        distributed=distributed,
                                        start_iter=arguments["iteration"])

    # Optional validation loader, only built when periodic testing is enabled.
    test_period = cfg.test.test_period
    if test_period > 0:
        val_dataloader = build_dataloader(cfg, inp_size, is_train=False,
                                          distributed=distributed)
    else:
        val_dataloader = None

    checkpoint_period = cfg.solver.checkpoint_period
    log_period = cfg.solver.log_period

    do_train(cfg, model, train_dataloader, val_dataloader, optimizer,
             lr_scheduler, checkpointer, device, checkpoint_period,
             test_period, log_period, arguments)

    return model
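
# Hedged usage sketch (not part of the original script): a minimal launcher for
# train() above. It assumes the same parse_args()/get_default_cfg() helpers that
# the test script uses, an `args.local_rank` argument, the standard `os` and
# `torch.distributed` imports, and the WORLD_SIZE environment variable set by
# the torch.distributed launcher. All of these names are assumptions for
# illustration, not the project's confirmed entry point.
def launch_training():
    args = parse_args()
    cfg = get_default_cfg()
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # Detect multi-GPU launches and initialize the default process group.
    num_gpus = int(os.environ.get("WORLD_SIZE", 1))
    distributed = num_gpus > 1
    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    train(cfg, args.local_rank, distributed)
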
class Detect(object):
    """Load a trained EfficientDet checkpoint and run detection on a single image file."""

    def __init__(self, weights, num_class=21):
        super(Detect, self).__init__()
        self.weights = weights
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else 'cpu')
        self.transform = transforms.Compose([Normalizer(), Resizer()])
        self.model = EfficientDet(num_classes=num_class, is_training=False)
        self.model = self.model.to(self.device)

        if self.weights is not None:
            print('Load pretrained Model')
            # map_location keeps CUDA-trained checkpoints loadable on CPU-only machines.
            state_dict = torch.load(weights, map_location=self.device)
            self.model.load_state_dict(state_dict)

        self.model.eval()

    def process(self, file_name):
        # Read, normalize, and resize the image, then convert HWC -> NCHW.
        img = cv2.imread(file_name)
        img = self.transform(img)
        img = img.to(self.device)
        img = img.unsqueeze(0).permute(0, 3, 1, 2)

        scores, classification, transformed_anchors = self.model(img)
        scores = scores.detach().cpu().numpy()
        # Keep only detections whose confidence exceeds the fixed 0.1 threshold.
        idxs = np.where(scores > 0.1)
        return idxs
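
# Hedged usage sketch (not part of the original script): running the Detect
# wrapper above on one image. The checkpoint filename and the input image are
# placeholders; num_class=21 assumes VOC-style training (20 classes plus
# background), matching the constructor's default.
if __name__ == '__main__':
    detector = Detect(weights='checkpoint_efficientdet.pth', num_class=21)
    keep_idxs = detector.process('demo.jpg')
    print('indices of detections above the 0.1 score threshold:', keep_idxs)
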
def main():
    args = parse_args()

    cfg = get_default_cfg()
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # The dataset is only needed for the class count and the id -> name mapping.
    dataset = COCODataset(cfg.data.test[0], cfg.data.test[1])
    num_classes = dataset.num_classes
    label_map = dataset.labels

    model = EfficientDet(num_classes=num_classes, model_name=cfg.model.name)
    device = torch.device(cfg.device)
    model.to(device)
    model.eval()

    inp_size = model.config['inp_size']
    transforms = build_transforms(False, inp_size=inp_size)

    output_dir = cfg.output_dir
    checkpointer = Checkpointer(model, None, None, output_dir, True)
    checkpointer.load(args.ckpt)

    # Collect image paths: every valid file in a directory, or a single image.
    images = []
    if args.img:
        if osp.isdir(args.img):
            for filename in os.listdir(args.img):
                if is_valid_file(filename):
                    images.append(osp.join(args.img, filename))
        else:
            images = [args.img]

    for img_path in images:
        img = cv2.imread(img_path)
        img = inference(model, img, label_map, score_thr=args.score_thr,
                        transforms=transforms)
        save_path = osp.join(args.save, osp.basename(img_path))
        cv2.imwrite(save_path, img)

    # Optional video input: run inference frame by frame and re-encode the result.
    if args.vid:
        vCap = cv2.VideoCapture(args.vid)
        fps = int(vCap.get(cv2.CAP_PROP_FPS))
        height = int(vCap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        width = int(vCap.get(cv2.CAP_PROP_FRAME_WIDTH))
        size = (width, height)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        save_path = osp.join(args.save, osp.basename(args.vid))
        vWrt = cv2.VideoWriter(save_path, fourcc, fps, size)

        while True:
            flag, frame = vCap.read()
            if not flag:
                break
            frame = inference(model, frame, label_map, score_thr=args.score_thr,
                              transforms=transforms)
            vWrt.write(frame)

        vCap.release()
        vWrt.release()
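
# Hedged sketch (not part of the original script): main() above relies on an
# is_valid_file() helper and an entry-point guard that are not shown in this
# section. A minimal extension-based filter is assumed here for illustration;
# the real project may implement it differently.
def is_valid_file(filename, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
    # Keep only files whose extension looks like an image.
    return filename.lower().endswith(extensions)


if __name__ == '__main__':
    main()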