Example #1
def build_model(cfg, args):

    # -----------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------
    model = build_ssd_model(cfg)
    return model
Example #2
def run_demo(cfg, weights_file, iou_threshold, score_threshold, images_dir, output_dir, dataset_type):
    if dataset_type == "voc":
        class_names = VOCDataset.class_names
    elif dataset_type == 'coco':
        class_names = COCODataset.class_names
    else:
        raise NotImplementedError('Dataset type "{}" is not supported.'.format(dataset_type))

    device = torch.device(cfg.MODEL.DEVICE)
    model = build_ssd_model(cfg, is_test=True)
    model.load(weights_file)
    print('Loaded weights from {}.'.format(weights_file))
    model = model.to(device)
    predictor = Predictor(cfg=cfg,
                          model=model,
                          iou_threshold=iou_threshold,
                          score_threshold=score_threshold,
                          device=device)
    cpu_device = torch.device("cpu")

    image_paths = glob.glob(os.path.join(images_dir, '*.jpg'))

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for image_path in tqdm(image_paths):
        image = Image.open(image_path).convert("RGB")
        image = np.array(image)
        output = predictor.predict(image)
        boxes, labels, scores = [o.to(cpu_device).numpy() for o in output]
        drawn_image = draw_bounding_boxes(image, boxes, labels, scores, class_names).astype(np.uint8)
        image_name = os.path.basename(image_path)
        Image.fromarray(drawn_image).save(os.path.join(output_dir, image_name))
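A minimal sketch of how a demo entry point like Example #2 might be wired up from the command line; the flag names, defaults, and the overall layout are assumptions, not taken from the original repository.

# Hypothetical CLI wrapper around run_demo; all flag names are assumptions.
import argparse

from ssd.config import cfg


def main():
    parser = argparse.ArgumentParser(description="SSD demo on a folder of images")
    parser.add_argument("--config-file", required=True)
    parser.add_argument("--weights", required=True)
    parser.add_argument("--iou-threshold", type=float, default=0.5)
    parser.add_argument("--score-threshold", type=float, default=0.5)
    parser.add_argument("--images-dir", default="demo")
    parser.add_argument("--output-dir", default="demo/result")
    parser.add_argument("--dataset-type", default="voc", choices=["voc", "coco"])
    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.freeze()
    run_demo(cfg, args.weights, args.iou_threshold, args.score_threshold,
             args.images_dir, args.output_dir, args.dataset_type)


if __name__ == "__main__":
    main()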
Example #3
def train(cfg, args):
    #logging.basicConfig(filename='./output/LOG/'+__name__+'.log',format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', level = logging.INFO,filemode='a',datefmt='%Y-%m-%d%I:%M:%S %p')
    logging.basicConfig(
        filename='./output/LOG/' + __name__ + '.log',
        format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]',
        level=logging.INFO)
    logger = logging.getLogger('SSD.trainer')
    # -----------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------
    ssd_model = build_ssd_model(cfg)
    ssd_model.init_from_base_net(args.vgg)
    ssd_model = torch.nn.DataParallel(ssd_model,
                                      device_ids=range(
                                          torch.cuda.device_count()))
    device = torch.device(cfg.MODEL.DEVICE)
    print(ssd_model)
    logger.info(ssd_model)
    model = torchvision.models.AlexNet(num_classes=10)
    logger.info(model)
    writer = tensorboardX.SummaryWriter(log_dir="./output/model_graph/",
                                        comment="myresnet")
    #dummy_input = torch.autograd.Variable(torch.rand(1, 3, 227, 227))
    dummy_input = torch.autograd.Variable(torch.rand(1, 3, 300, 300))
    #writer.add_graph(model=ssd_model, input_to_model=(dummy_input, ))
    model_onnx_path = 'torch_model.onnx'
    output = torch_onnx.export(ssd_model,
                               dummy_input,
                               model_onnx_path,
                               verbose=False)

    #ssd_model.to(device)
    print('----------------')
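After an export like the one in Example #3, the resulting file can be sanity-checked with the onnx package; a small sketch, assuming onnx is installed and torch_model.onnx is the path used above.

# Sanity-check the exported graph; assumes the onnx package is installed.
import onnx

onnx_model = onnx.load("torch_model.onnx")
onnx.checker.check_model(onnx_model)  # raises if the graph is malformed
print(onnx.helper.printable_graph(onnx_model.graph))  # human-readable dump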
Example #4
def evaluation(cfg, weights_file, output_dir, distributed):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    device = torch.device(cfg.MODEL.DEVICE)
    model = build_ssd_model(cfg)
    model.load(weights_file)
    logger = logging.getLogger("SSD.inference")
    logger.info('Loaded weights from {}.'.format(weights_file))
    model.to(device)
    do_evaluation(cfg, model, output_dir, distributed)
Example #5
def convert2scriptmodule(cfg, args):

    ssd_model = build_ssd_model(cfg)

    print(ssd_model)

    # A random tensor (rather than uninitialized memory) is a safer trace input.
    dummy_input = torch.randn(1, 3, cfg.INPUT.IMAGE_SIZE, cfg.INPUT.IMAGE_SIZE)
    model_path = args.model_path  # e.g. 'ssd300_vgg_final.pth'
    ssd_model.load_state_dict(torch.load(model_path))
    save(ssd_model, dummy_input, args.model_out)  # e.g. 'script.pt'
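The save helper used above is not shown in this snippet; one plausible implementation, assuming it traces the model into a TorchScript module:

# Hypothetical save() helper: trace with a dummy input, write to disk.
import torch


def save(model, example_input, out_path):
    model.eval()
    traced = torch.jit.trace(model, example_input)
    traced.save(out_path)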
Example #6
def build_model(cfg, args):

    cfg.merge_from_file("configs/ssd512_voc0712.yaml")
    #cfg.merge_from_list(args.opts)
    cfg.freeze()
    # -----------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------
    model = build_ssd_model(cfg)
    return model
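The commented-out merge_from_list call hints at yacs's second override path: a flat key/value list, typically fed from command-line opts. A small sketch (the override values are illustrative):

# Illustrative yacs overrides; the keys must already exist in the config.
cfg.merge_from_file("configs/ssd512_voc0712.yaml")
cfg.merge_from_list(["MODEL.DEVICE", "cpu", "SOLVER.BATCH_SIZE", 8])
cfg.freeze()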
Example #7
def setup_self_ade(cfg, args):
    logger = logging.getLogger("self_ade.setup")
    logger.info("Starting self_ade setup")

    # build model from config
    model = build_ssd_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    train_transform = TrainAugmentation(cfg.INPUT.IMAGE_SIZE,
                                        cfg.INPUT.PIXEL_MEAN,
                                        cfg.INPUT.PIXEL_STD)

    target_transform = MatchPrior(
        PriorBox(cfg)(), cfg.MODEL.CENTER_VARIANCE, cfg.MODEL.SIZE_VARIANCE,
        cfg.MODEL.THRESHOLD)

    test_dataset = build_dataset(dataset_list=cfg.DATASETS.TEST,
                                 is_test=True)[0]
    self_ade_dataset = build_dataset(dataset_list=cfg.DATASETS.TEST,
                                     transform=train_transform,
                                     target_transform=target_transform)
    ss_dataset = SelfSupervisedDataset(self_ade_dataset, cfg)

    test_sampler = SequentialSampler(test_dataset)
    os_sampler = OneSampleBatchSampler(test_sampler, cfg.SOLVER.BATCH_SIZE,
                                       args.self_ade_iterations)

    self_ade_dataloader = DataLoader(ss_dataset,
                                     batch_sampler=os_sampler,
                                     num_workers=args.num_workers)

    effective_lr = args.learning_rate * args.self_ade_weight

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=effective_lr,
                                momentum=cfg.SOLVER.MOMENTUM,
                                weight_decay=cfg.SOLVER.WEIGHT_DECAY)

    # Initialize mixed-precision training
    use_mixed_precision = cfg.USE_AMP
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=amp_opt_level)

    execute_self_ade(cfg, args, test_dataset, self_ade_dataloader, model,
                     optimizer)
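With apex amp.initialize as above, the backward pass inside execute_self_ade has to go through amp.scale_loss; a minimal sketch of that training-step pattern, with criterion, images, and targets standing in for the repo's actual step:

# Standard apex amp backward pattern after amp.initialize(model, optimizer);
# criterion, images and targets are placeholders, not names from the repo.
from apex import amp


def self_ade_step(model, optimizer, criterion, images, targets):
    optimizer.zero_grad()
    loss = criterion(model(images), targets)
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
    return loss.item()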
Example #8
def run_demo(cfg, weights_file, iou_threshold, score_threshold, video_dir, output_dir, dataset_type):
    if dataset_type == "voc":
        class_names = VOCDataset.class_names
    elif dataset_type == 'coco':
        class_names = COCODataset.class_names
    elif dataset_type == 'cla':
        class_names = CLADataset.class_names
    else:
        raise NotImplementedError('Dataset type "{}" is not supported.'.format(dataset_type))

    device = torch.device(cfg.MODEL.DEVICE)
    model = build_ssd_model(cfg)
    model.load(weights_file)
    print('Loaded weights from {}.'.format(weights_file))
    model = model.to(device)
    predictor = Predictor(cfg=cfg,
                          model=model,
                          iou_threshold=iou_threshold,
                          score_threshold=score_threshold,
                          device=device)
    cpu_device = torch.device("cpu")
    stream = cv2.VideoCapture(video_dir)
    # Output video size, matching the input video
    shape = (int(stream.get(cv2.CAP_PROP_FRAME_WIDTH)), int(stream.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    # Frame rate of the output video
    _fps = stream.get(cv2.CAP_PROP_FPS)
    # Video codec for the output file
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    # Create the output directory if it does not exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # Create the video writer; the output keeps the input's basename with an .avi extension
    output_name = os.path.basename(video_dir)
    output_name = output_name.split('.')[0] + ".avi"
    writer = cv2.VideoWriter(os.path.join(output_dir, output_name), fourcc, _fps, shape)

    while True:
        ret, image = stream.read()
        if not ret:
            break
        output = predictor.predict(image)
        boxes, labels, scores = [o.to(cpu_device).numpy() for o in output]
        drawn_image = draw_bounding_boxes(image, boxes, labels, scores, class_names).astype(np.uint8)
        writer.write(drawn_image)
    writer.release()
    stream.release()
Example #9
def load_model(config_file='SSD/configs/ssd300_voc0712.yaml',
               iou_threshold=0.5,
               score_threshold=0.5):
    cfg = get_configuration(config_file)
    class_names = VOCDataset.class_names
    global device
    device = torch.device('cpu')
    model = build_ssd_model(cfg)
    logger.info('Loading model from S3')
    obj = s3.get_object(Bucket=MODEL_BUCKET, Key=MODEL_KEY)
    bytestream = io.BytesIO(obj['Body'].read())
    model.load(bytestream)
    logger.info(f'Loaded weights from {MODEL_KEY}')
    model = model.to(device)
    return Predictor(cfg=cfg,
                     model=model,
                     iou_threshold=iou_threshold,
                     score_threshold=score_threshold,
                     device=device)
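Example #9 relies on module-level globals (s3, MODEL_BUCKET, MODEL_KEY, logger) that are not shown; a plausible setup, assuming the bucket and key come from environment variables:

# Hypothetical module-level setup for Example #9; the env var names are assumed.
import logging
import os

import boto3

logger = logging.getLogger(__name__)
s3 = boto3.client('s3')
MODEL_BUCKET = os.environ['MODEL_BUCKET']
MODEL_KEY = os.environ['MODEL_KEY']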
Example #10
def evaluation(cfg, args, weights_file, output_dir, distributed):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    device = torch.device(cfg.MODEL.DEVICE)
    model = build_ssd_model(cfg)
    model.load(open(weights_file, 'rb'))
    logger = logging.getLogger("SSD.inference")
    logger.info('Loaded weights from {}.'.format(weights_file))
    model.to(device)

    if args.eval_mode == "test":
        do_evaluation(cfg, model, output_dir, distributed)
    else:
        dataset_metrics = do_evaluation(cfg, model, cfg.OUTPUT_DIR, distributed, datasets_dict=_create_val_datasets(args, cfg, logger))
        count = len(dataset_metrics)
        map_sum = 0
        for k, v in dataset_metrics.items():
            #logger.info("mAP on {}: {:.3f}".format(k, v.info["mAP"]))
            map_sum += v.info["mAP"]

        avg_map = map_sum / count
        print("'Model': '{}', 'Avg_mAP': {}".format(weights_file, avg_map))
Example #11
def run_demo(cfg, checkpoint_file, iou_threshold, score_threshold, images_dir,
             output_dir):
    device = torch.device(cfg.MODEL.DEVICE)
    model = build_ssd_model(cfg)
    checkpoint = torch.load(checkpoint_file)
    model.load_state_dict(checkpoint['state_dict'])
    print('Loaded weights from {}.'.format(checkpoint_file))
    model = model.to(device)
    model.eval()
    predictor = Predictor(cfg=cfg,
                          model=model,
                          iou_threshold=iou_threshold,
                          score_threshold=score_threshold,
                          device=device)
    cpu_device = torch.device("cpu")

    image_paths = glob.glob(os.path.join(images_dir, '*.jpg'))

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    add_count = 0
    for image_path in tqdm(image_paths):
        image = Image.open(image_path).convert("RGB")
        image = np.array(image)
        # image_mirror = image[:, ::-1]
        output = predictor.predict(image)
        boxes, scores, seg_map = [o.to(cpu_device).numpy() for o in output]
        seg_map = cv2.resize(seg_map, (512, 512)) * 255
        seg_map = seg_map.astype(np.uint8)
        # seg_map = cv2.applyColorMap(seg_map, cv2.COLORMAP_JET)
        seg_map = cv2.resize(seg_map, (1280, 720),
                             interpolation=cv2.INTER_CUBIC)
        drawn_image = draw_bounding_boxes(image, boxes).astype(np.uint8)
        image_name = os.path.basename(image_path)
        txt_path = os.path.join(output_dir, 'txtes')
        if not os.path.exists(txt_path):
            os.makedirs(txt_path)
        txt_path = os.path.join(txt_path,
                                'res_' + image_name.replace('jpg', 'txt'))
        #multi-output merge
        merge_output = False
        if merge_output:
            ret, binary = cv2.threshold(seg_map, 75, 255, cv2.THRESH_BINARY)
            # cv2.imshow('binary:',binary)
            # cv2.waitKey()

            contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE,
                                                   cv2.CHAIN_APPROX_SIMPLE)

            w, h = np.shape(binary)  # note: w is the row count and h the column count here
            for contour in contours:
                # Get the minimum-area bounding rectangle
                rect = cv2.minAreaRect(contour)

                # Center coordinates
                x, y = rect[0]
                # cv2.circle(img, (int(x), int(y)), 3, (0, 255, 0), 5)

                # Width and height; always width >= height
                width, height = rect[1]
                if width < 10 or height < 10:
                    continue

                # Angle in [-90, 0)
                angle = rect[2]
                box = cv2.boxPoints(rect)
                box = np.int0(box)
                box[:, 0] = np.clip(box[:, 0], 0, h)
                box[:, 1] = np.clip(box[:, 1], 0, w)

                poly1 = Polygon(box).convex_hull
                intersect = False
                for item in boxes:
                    print('item:', item)
                    poly2 = Polygon(item.reshape(4, 2)).convex_hull
                    if poly1.intersects(poly2):  # the two quadrilaterals intersect
                        intersect = True
                        break
                if not intersect:
                    print('boxes.shape:', np.shape(boxes))
                    box = box.reshape((1, 8))
                    print('box.shape:', np.shape(box))
                    num, _ = np.shape(boxes)
                    if num == 0:
                        print('num == 0')
                        boxes = box
                    else:
                        boxes = np.concatenate((boxes, box))
                    print('boxes.shape:', np.shape(boxes))
                    print('add one box')
                    add_count += 1
                    # cv2.line(image, (box[0][0], box[0][1]), (box[0][2], box[0][3]), (0, 0, 255), thickness=4)
                    # cv2.line(image,(box[0][2], box[0][3]), (box[0][4], box[0][5]), (0, 0, 255), thickness=4)
                    # cv2.line(image,(box[0][4], box[0][5]), (box[0][6], box[0][7]), (0, 0, 255), thickness=4)
                    # cv2.line(image, (box[0][6], box[0][7]), (box[0][0], box[0][1]), (0, 0, 255), thickness=4)
                    # cv2.imshow('img',image)
                    # cv2.waitKey()

        # print('txt_path:',txt_path)
        with open(txt_path, 'w+') as f:
            for box in boxes:
                box_temp = np.reshape(box, (4, 2))
                box = order_points_quadrangle(box_temp)
                box = np.reshape(box, (1, 8)).squeeze(0)
                is_valid = validate_clockwise_points(box)
                if not is_valid:
                    continue
                # print('box:',box)
                line = ''
                for item in box:
                    if item < 0:
                        item = 0
                    line += str(int(item)) + ','
                line = line[:-1] + '\n'
                f.write(line)
        path = os.path.join(output_dir, image_name)
        print('path:', path)
        Image.fromarray(drawn_image).save(path)
        path = os.path.join(
            output_dir,
            image_name.split('.')[0] + '_segmap.' + image_name.split('.')[1])
        # print(path)
        # Save the score map
        cv2.imwrite(path, seg_map)
    print('add count:', add_count)
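order_points_quadrangle and validate_clockwise_points come from the surrounding repo and are not shown. A plausible sketch of the former, ordering four points clockwise from the top-left (the repo's actual version may differ):

# Hypothetical order_points_quadrangle: order 4 points clockwise from top-left.
import numpy as np


def order_points_quadrangle(pts):
    rect = np.zeros((4, 2), dtype=pts.dtype)
    s = pts.sum(axis=1)               # x + y
    rect[0] = pts[np.argmin(s)]       # top-left has the smallest sum
    rect[2] = pts[np.argmax(s)]       # bottom-right has the largest sum
    d = np.diff(pts, axis=1).ravel()  # y - x
    rect[1] = pts[np.argmin(d)]       # top-right has the smallest difference
    rect[3] = pts[np.argmax(d)]       # bottom-left has the largest difference
    return rect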
Example #12
def train(cfg, args):
    logger = logging.getLogger('SSD.trainer')
    # -----------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------
    model = build_ssd_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if args.resume:
        logger.info("Resume from the model {}".format(args.resume))
        model.load(args.resume)
    else:
        logger.info("Init from base net {}".format(args.vgg))
        model.init_from_base_net(args.vgg)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
    # -----------------------------------------------------------------------------
    # Optimizer
    # -----------------------------------------------------------------------------
    lr = cfg.SOLVER.LR * args.num_gpus  # scale by num gpus
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=cfg.SOLVER.MOMENTUM,
                                weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    # -----------------------------------------------------------------------------
    # Criterion
    # -----------------------------------------------------------------------------
    criterion = MultiBoxLoss(iou_threshold=cfg.MODEL.THRESHOLD,
                             neg_pos_ratio=cfg.MODEL.NEG_POS_RATIO)

    # -----------------------------------------------------------------------------
    # Scheduler
    # -----------------------------------------------------------------------------
    milestones = [step // args.num_gpus for step in cfg.SOLVER.LR_STEPS]
    scheduler = WarmupMultiStepLR(optimizer=optimizer,
                                  milestones=milestones,
                                  gamma=cfg.SOLVER.GAMMA,
                                  warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
                                  warmup_iters=cfg.SOLVER.WARMUP_ITERS)

    # -----------------------------------------------------------------------------
    # Dataset
    # -----------------------------------------------------------------------------
    train_transform = TrainAugmentation(cfg.INPUT.IMAGE_SIZE,
                                        cfg.INPUT.PIXEL_MEAN)
    target_transform = MatchPrior(
        PriorBox(cfg)(), cfg.MODEL.CENTER_VARIANCE, cfg.MODEL.SIZE_VARIANCE,
        cfg.MODEL.THRESHOLD)
    train_dataset = build_dataset(dataset_list=cfg.DATASETS.TRAIN,
                                  transform=train_transform,
                                  target_transform=target_transform)
    logger.info("Train dataset size: {}".format(len(train_dataset)))
    if args.distributed:
        sampler = torch.utils.data.DistributedSampler(train_dataset)
    else:
        sampler = torch.utils.data.RandomSampler(train_dataset)
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler=sampler, batch_size=cfg.SOLVER.BATCH_SIZE, drop_last=False)
    batch_sampler = samplers.IterationBasedBatchSampler(
        batch_sampler, num_iterations=cfg.SOLVER.MAX_ITER // args.num_gpus)
    train_loader = DataLoader(train_dataset,
                              num_workers=4,
                              batch_sampler=batch_sampler)

    return do_train(cfg, model, train_loader, optimizer, scheduler, criterion,
                    device, args)
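WarmupMultiStepLR is imported from the repo and not shown; a minimal sketch of what such a scheduler typically computes (linear warmup, then step decay at the milestones), modeled on the widely used maskrcnn-benchmark variant rather than this repo's exact code:

# Sketch of a warmup multi-step LR scheduler; not the repo's exact code.
from bisect import bisect_right

import torch


class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, milestones, gamma=0.1,
                 warmup_factor=1.0 / 3, warmup_iters=500, last_epoch=-1):
        self.milestones = list(milestones)
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        warmup = 1.0
        if self.last_epoch < self.warmup_iters:
            # Linear ramp from warmup_factor up to 1 over warmup_iters steps.
            alpha = self.last_epoch / self.warmup_iters
            warmup = self.warmup_factor * (1 - alpha) + alpha
        return [base_lr * warmup *
                self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs]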
Example #13
def get_flops(net, input_size):
    # Reconstructed header: the call site below passes (model, (H, W)), so the
    # dummy input is built here from the requested spatial size (an assumption).
    input = torch.randn(1, 3, input_size[0], input_size[1]).cuda()
    net = add_flops_counting_methods(net)
    net = net.cuda().eval()
    net.start_flops_count()

    _ = net(input)

    return net.compute_average_flops_cost() / 1e9 / 2


# example
if __name__ == '__main__':

    from ssd.modeling.vgg_ssd import build_ssd_model
    from ssd.config import cfg

    cfg.merge_from_file("configs/ssd512_voc0712.yaml")

    cfg.freeze()
    model = build_ssd_model(cfg)
    input_size = (1024, 1024)
    #ssd_net = model.eval()
    ssd_net = model.cuda()


    total_flops = get_flops(ssd_net, input_size)

    # For the default vgg16 model, this should output 31.386288 G FLOPS
    print("The Model's Total FLOPS is : {:.6f} G FLOPS".format(total_flops))
Example #14
def train(cfg, args):
    logger = logging.getLogger('SSD.trainer')
    # -----------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------
    model = build_ssd_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    if args.resume:
        logger.info("Resume from the model {}".format(args.resume))
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        iteration = checkpoint['iteration']
        print('iteration:', iteration)
    elif args.vgg:
        iteration = 0
        logger.info("Init from backbone net {}".format(args.vgg))
        model.init_from_base_net(args.vgg)
    else:
        iteration = 0
        logger.info("all init from kaiming init")
    # -----------------------------------------------------------------------------
    # Optimizer
    # -----------------------------------------------------------------------------
    lr = cfg.SOLVER.LR * args.num_gpus  # scale by num gpus
    #optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    print('cfg.SOLVER.WEIGHT_DECAY:', cfg.SOLVER.WEIGHT_DECAY)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=cfg.SOLVER.WEIGHT_DECAY)

    # -----------------------------------------------------------------------------
    # Scheduler
    # -----------------------------------------------------------------------------
    milestones = [step // args.num_gpus for step in cfg.SOLVER.LR_STEPS]
    scheduler = WarmupMultiStepLR(optimizer=optimizer,
                                  milestones=milestones,
                                  gamma=cfg.SOLVER.GAMMA,
                                  warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
                                  warmup_iters=cfg.SOLVER.WARMUP_ITERS)

    # -----------------------------------------------------------------------------
    # Dataset
    # -----------------------------------------------------------------------------
    # Apply data augmentation to the raw images
    train_transform = TrainAugmentation(cfg.INPUT.IMAGE_SIZE,
                                        cfg.INPUT.PIXEL_MEAN)
    target_transform = MatchPrior(
        PriorBox(cfg)(), cfg.MODEL.CENTER_VARIANCE, cfg.MODEL.SIZE_VARIANCE,
        cfg.MODEL.IOU_THRESHOLD, cfg.MODEL.PRIORS.DISTANCE_THRESHOLD)
    train_dataset = build_dataset(dataset_list=cfg.DATASETS.TRAIN,
                                  transform=train_transform,
                                  target_transform=target_transform,
                                  args=args)
    logger.info("Train dataset size: {}".format(len(train_dataset)))
    sampler = torch.utils.data.RandomSampler(train_dataset)
    # sampler = torch.utils.data.SequentialSampler(train_dataset)
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler=sampler, batch_size=cfg.SOLVER.BATCH_SIZE, drop_last=False)
    batch_sampler = samplers.IterationBasedBatchSampler(
        batch_sampler, num_iterations=cfg.SOLVER.MAX_ITER // args.num_gpus)
    train_loader = DataLoader(train_dataset,
                              num_workers=4,
                              batch_sampler=batch_sampler,
                              pin_memory=True)

    return do_train(cfg, model, train_loader, optimizer, scheduler, device,
                    args, iteration)
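The resume branch in Example #14 expects checkpoints saved as a dict with 'state_dict' and 'iteration' keys; a minimal sketch of the matching save call inside the training loop (the filename pattern is illustrative):

# Hypothetical checkpoint save matching Example #14's resume logic.
torch.save(
    {'state_dict': model.state_dict(), 'iteration': iteration},
    os.path.join(cfg.OUTPUT_DIR, 'ssd_{:06d}.pth'.format(iteration)),
)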
Example #15
def train(cfg, args):
    logger = logging.getLogger('SSD.trainer')
    # -----------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------
    model = build_ssd_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    # -----------------------------------------------------------------------------
    # Optimizer
    # -----------------------------------------------------------------------------
    lr = cfg.SOLVER.LR * args.num_gpus  # scale by num gpus
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=cfg.SOLVER.MOMENTUM,
                                weight_decay=cfg.SOLVER.WEIGHT_DECAY)

    # -----------------------------------------------------------------------------
    # Scheduler
    # -----------------------------------------------------------------------------
    milestones = [step // args.num_gpus for step in cfg.SOLVER.LR_STEPS]
    scheduler = WarmupMultiStepLR(optimizer=optimizer,
                                  milestones=milestones,
                                  gamma=cfg.SOLVER.GAMMA,
                                  warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
                                  warmup_iters=cfg.SOLVER.WARMUP_ITERS)

    # -----------------------------------------------------------------------------
    # Load weights or restore checkpoint
    # -----------------------------------------------------------------------------
    if args.resume:
        logger.info("Resume from the model {}".format(args.resume))
        restore_training_checkpoint(logger,
                                    model,
                                    args.resume,
                                    optimizer=optimizer,
                                    scheduler=scheduler)
    else:
        logger.info("Init from base net {}".format(args.vgg))
        model.init_from_base_net(args.vgg)

    # Initialize mixed-precision training
    use_mixed_precision = cfg.USE_AMP
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=amp_opt_level)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # -----------------------------------------------------------------------------
    # Dataset
    # -----------------------------------------------------------------------------
    train_transform = TrainAugmentation(cfg.INPUT.IMAGE_SIZE,
                                        cfg.INPUT.PIXEL_MEAN,
                                        cfg.INPUT.PIXEL_STD)
    target_transform = MatchPrior(
        PriorBox(cfg)(), cfg.MODEL.CENTER_VARIANCE, cfg.MODEL.SIZE_VARIANCE,
        cfg.MODEL.THRESHOLD)

    if cfg.DATASETS.DG:
        if args.eval_mode == "val":
            dslist, val_set_dict = _create_dg_datasets(args, cfg, logger,
                                                       target_transform,
                                                       train_transform)
        else:
            dslist = _create_dg_datasets(args, cfg, logger, target_transform,
                                         train_transform)

        logger.info("Sizes of sources datasets:")
        for k, v in dslist.items():
            logger.info("{} size: {}".format(k, len(v)))

        dataloaders = []
        for name, train_dataset in dslist.items():
            sampler = torch.utils.data.RandomSampler(train_dataset)
            batch_sampler = torch.utils.data.sampler.BatchSampler(
                sampler=sampler,
                batch_size=cfg.SOLVER.BATCH_SIZE,
                drop_last=True)

            batch_sampler = samplers.IterationBasedBatchSampler(
                batch_sampler, num_iterations=cfg.SOLVER.MAX_ITER)

            if cfg.MODEL.SELF_SUPERVISED:
                ss_dataset = SelfSupervisedDataset(train_dataset, cfg)
                train_loader = DataLoader(ss_dataset,
                                          num_workers=args.num_workers,
                                          batch_sampler=batch_sampler,
                                          pin_memory=True)
            else:
                train_loader = DataLoader(train_dataset,
                                          num_workers=args.num_workers,
                                          batch_sampler=batch_sampler,
                                          pin_memory=True)
            dataloaders.append(train_loader)

        if args.eval_mode == "val":
            if args.return_best:
                return do_train(cfg, model, dataloaders, optimizer, scheduler,
                                device, args, val_set_dict)
            else:
                return do_train(cfg, model, dataloaders, optimizer, scheduler,
                                device, args)
        else:
            return do_train(cfg, model, dataloaders, optimizer, scheduler,
                            device, args)

    # No DG:
    if args.eval_mode == "val":
        train_dataset, val_dataset = build_dataset(
            dataset_list=cfg.DATASETS.TRAIN,
            transform=train_transform,
            target_transform=target_transform,
            split=True)
    else:
        train_dataset = build_dataset(dataset_list=cfg.DATASETS.TRAIN,
                                      transform=train_transform,
                                      target_transform=target_transform)
    logger.info("Train dataset size: {}".format(len(train_dataset)))
    if args.distributed:
        sampler = torch.utils.data.DistributedSampler(train_dataset)
    else:
        sampler = torch.utils.data.RandomSampler(train_dataset)
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler=sampler, batch_size=cfg.SOLVER.BATCH_SIZE, drop_last=False)
    batch_sampler = samplers.IterationBasedBatchSampler(
        batch_sampler, num_iterations=cfg.SOLVER.MAX_ITER // args.num_gpus)

    if cfg.MODEL.SELF_SUPERVISED:
        ss_dataset = SelfSupervisedDataset(train_dataset, cfg)
        train_loader = DataLoader(ss_dataset,
                                  num_workers=args.num_workers,
                                  batch_sampler=batch_sampler,
                                  pin_memory=True)
    else:
        train_loader = DataLoader(train_dataset,
                                  num_workers=args.num_workers,
                                  batch_sampler=batch_sampler,
                                  pin_memory=True)

    if args.eval_mode == "val":
        return do_train(cfg, model, train_loader, optimizer, scheduler, device,
                        args, {"validation_split": val_dataset})
    else:
        return do_train(cfg, model, train_loader, optimizer, scheduler, device,
                        args)
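In the DG branch above, do_train receives a list of dataloaders, one per source domain, each already unbounded via IterationBasedBatchSampler. One plausible way to consume them is one batch from every domain per step:

# Hypothetical round-robin over the per-domain loaders built above.
def iterate_domains(dataloaders):
    # zip draws one batch from each loader per optimization step
    for step, batches in enumerate(zip(*dataloaders)):
        yield step, batches  # len(batches) == number of source domains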
Example #16
def run_demo(cfg, checkpoint_file, iou_threshold, score_threshold, images_dir,
             output_dir):
    basename = os.path.basename(checkpoint_file).split('.')[0]
    # The epoch number is embedded after a fixed 21-character filename prefix.
    epoch = int(basename[21:])
    if epoch < min_epoch:
        return
    device = torch.device(cfg.MODEL.DEVICE)
    model = build_ssd_model(cfg)
    checkpoint = torch.load(checkpoint_file)
    model.load_state_dict(checkpoint['state_dict'])
    print('Loaded weights from {}.'.format(checkpoint_file))
    model = model.to(device)
    model.eval()
    predictor = Predictor(cfg=cfg,
                          model=model,
                          iou_threshold=iou_threshold,
                          score_threshold=score_threshold,
                          device=device)
    cpu_device = torch.device("cpu")

    image_paths = glob.glob(os.path.join(images_dir, '*.jpg'))

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for image_path in tqdm(image_paths):
        image = Image.open(image_path).convert("RGB")
        image = np.array(image)
        # image = image[:, ::-1]
        #image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        output = predictor.predict(image)
        boxes, scores, score_map = [o.to(cpu_device).numpy() for o in output]
        # score_map = cv2.resize(score_map, (512, 512)) * 255
        # score_map = score_map.astype(np.uint8)
        # score_map = cv2.applyColorMap(score_map, cv2.COLORMAP_JET)
        # score_map = cv2.resize(score_map,(1280,720),interpolation=cv2.INTER_CUBIC)
        drawn_image = draw_bounding_boxes(image, boxes).astype(np.uint8)
        image_name = os.path.basename(image_path)
        txt_path = os.path.join(output_dir, 'txtes')
        if not os.path.exists(txt_path):
            os.makedirs(txt_path)
        txt_path = os.path.join(txt_path,
                                'res_' + image_name.replace('jpg', 'txt'))
        # print('txt_path:',txt_path)
        with open(txt_path, 'w+') as f:
            for box in boxes:
                box_temp = np.reshape(box, (4, 2))
                box = order_points_quadrangle(box_temp)

                box = np.reshape(box, (1, 8)).squeeze(0)
                is_valid = validate_clockwise_points(box)
                if not is_valid:
                    continue
                # print('box:',box)
                line = ''
                for item in box:
                    if item < 0:
                        item = 0
                    line += str(int(item)) + ','
                line = line[:-1] + '\n'
                f.write(line)

        # path = os.path.join(output_dir, image_name)
        # Image.fromarray(drawn_image).save(path)
        # path = os.path.join(output_dir, image_name.split('.')[0]+'_score_map.'+image_name.split('.')[1])
        # print(path)
        # cv2.imwrite(path,score_map)
    submit_path = MyZip(
        '/home/binchengxiong/ssd_fcn_multitask_text_detection_pytorch1.0/demo/result223'
        + str(min_epoch) + '/txtes', epoch)
    hmean_this_epoch = compute_hmean(submit_path, epoch)
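MyZip and compute_hmean are helpers from the surrounding evaluation code and are not shown; MyZip plausibly packs the per-image txt files into a submission zip, roughly like this sketch:

# Hypothetical MyZip: bundle detection txt files into an ICDAR-style submission.
import os
import zipfile


def MyZip(txt_dir, epoch):
    submit_path = os.path.join(os.path.dirname(txt_dir),
                               'submit_{}.zip'.format(epoch))
    with zipfile.ZipFile(submit_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        for name in os.listdir(txt_dir):
            if name.endswith('.txt'):
                zf.write(os.path.join(txt_dir, name), arcname=name)
    return submit_path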