Example #1
    def __init__(self):
        self.category = [
            'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
            'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
            'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
        ]
        self.img_root_dir = os.path.join(cfg.TRAIN.DATASETS_DIR,
                                         cfg.TRAIN.DATASETS[0], 'JPEGImages')
        self.img_list = os.path.join(cfg.TRAIN.DATASETS_DIR,
                                     cfg.TRAIN.DATASETS[0], 'ImageSets',
                                     'Main', cfg.TRAIN.DATASETS[1])
        self.annotations_dir = os.path.join(cfg.TRAIN.DATASETS_DIR,
                                            cfg.TRAIN.DATASETS[0],
                                            'Annotations')
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        self.train = True
        self.input_size = cfg.TRAIN.SCALES

        self.fnames = []
        self.boxes = []
        self.labels = []
        self.num_samples = 0
        self.get_img_annotations()

        self.encoder = DataEncoder(self.input_size)
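
A minimal usage sketch (an assumption: it wires this dataset into a standard DataLoader via the collate_fn defined in the later examples; batch_size and num_workers here are hypothetical):

import torch

dataset = VOCdataset()
loader = torch.utils.data.DataLoader(dataset,
                                     batch_size=8,
                                     shuffle=True,
                                     num_workers=4,
                                     collate_fn=dataset.collate_fn)
for imgs, loc_targets, cls_targets in loader:
    # imgs: (B, 3, H, W); loc/cls targets are stacked per batch by collate_fn
    break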
Example #2
def main():
    if not cfg.TRAIN.MULTI_GPU:
        torch.cuda.set_device(cfg.TRAIN.GPU_ID[0])

    i_tb = 0
    # loading data
    src_loader, tgt_loader, restore_transform = load_dataset()
    data_encoder = DataEncoder()

    
    ext_model = None
    dc_model = None
    obc_model = None
    # initialize models
    if cfg.TRAIN.COM_EXP == 5: # Full model
        ext_model, dc_model, obc_model, cur_epoch = init_model(cfg.TRAIN.NET)
    elif cfg.TRAIN.COM_EXP == 6: # FCN + SSD + OBC
        ext_model, _, obc_model, cur_epoch = init_model(cfg.TRAIN.NET)
    elif cfg.TRAIN.COM_EXP == 4: # FCN + SSD + DC
        ext_model, dc_model, _, cur_epoch = init_model(cfg.TRAIN.NET)
    elif cfg.TRAIN.COM_EXP == 3: # FCN + SSD
        ext_model, _, _, cur_epoch = init_model(cfg.TRAIN.NET)
    
    # set criterion and optimizer for training
    if ext_model is not None:
        weight = torch.ones(cfg.DATA.NUM_CLASSES)
        weight[cfg.DATA.NUM_CLASSES - 1] = 0
        spvsd_cri = CrossEntropyLoss2d(cfg.TRAIN.LABEL_WEIGHT).cuda()  # supervised loss for the FCN-8s
        unspvsd_cri = CrossEntropyLoss2d(cfg.TRAIN.LABEL_WEIGHT).cuda()  # unsupervised loss for the FCN-8s
        det_cri = MultiBoxLoss()
        # ext_opt is set in train_net.py because the SSD learning rate schedule is stepwise

    if dc_model is not None:
        dc_cri = CrossEntropyLoss2d().cuda()    
        dc_invs_cri = CrossEntropyLoss2d().cuda()
        dc_opt = optim.Adam(dc_model.parameters(), lr=cfg.TRAIN.DC_LR, betas=(0.5, 0.999))        

    if obc_model is not None:
        obc_cri = CrossEntropyLoss().cuda()
        obc_invs_cri = CrossEntropyLoss().cuda()
        obc_opt = optim.Adam(obc_model.parameters(), lr=cfg.TRAIN.OBC_LR, betas=(0.5, 0.999))

    if cfg.TRAIN.COM_EXP == 6:
        train_adversarial(cur_epoch, i_tb, data_encoder, src_loader, tgt_loader, restore_transform, 
                        ext_model, spvsd_cri, unspvsd_cri, det_cri, 
                        obc_model=obc_model, obc_cri=obc_cri, obc_invs_cri=obc_invs_cri, obc_opt=obc_opt)
        

    if cfg.TRAIN.COM_EXP == 5:
        train_adversarial(cur_epoch, i_tb, data_encoder, src_loader, tgt_loader, restore_transform, 
                        ext_model, spvsd_cri, unspvsd_cri, det_cri, 
                        dc_model=dc_model, dc_cri=dc_cri, dc_invs_cri=dc_invs_cri, dc_opt=dc_opt, 
                        obc_model=obc_model, obc_cri=obc_cri, obc_invs_cri=obc_invs_cri, obc_opt=obc_opt)
    if cfg.TRAIN.COM_EXP == 4:
        train_adversarial(cur_epoch, i_tb, data_encoder, src_loader, tgt_loader, restore_transform, 
                        ext_model, spvsd_cri, unspvsd_cri, det_cri, 
                        dc_model=dc_model, dc_cri=dc_cri, dc_invs_cri=dc_invs_cri, dc_opt=dc_opt)
    if cfg.TRAIN.COM_EXP == 3:
        train_adversarial(cur_epoch, i_tb, data_encoder, src_loader, tgt_loader, restore_transform, 
                        ext_model, spvsd_cri, unspvsd_cri, det_cri)
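
CrossEntropyLoss2d is not defined in this snippet. A minimal sketch consistent with how it is used above (optional class weights, dense label maps for the FCN-8s outputs), not necessarily the author's implementation:

import torch.nn as nn
import torch.nn.functional as F

class CrossEntropyLoss2d(nn.Module):
    # Cross-entropy over (N, C, H, W) logits and (N, H, W) integer label maps.
    def __init__(self, weight=None):
        super(CrossEntropyLoss2d, self).__init__()
        self.nll_loss = nn.NLLLoss(weight)

    def forward(self, inputs, targets):
        return self.nll_loss(F.log_softmax(inputs, dim=1), targets)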
Example #3
def draw_bbx(loc_preds, conf_preds, img, restore):

    data_encoder = DataEncoder()
    bbx, _, _ = data_encoder.decode(
        loc_preds.data.squeeze(0),
        F.softmax(conf_preds.squeeze(0), dim=1).data)  # softmax over the class dimension

    map_size = cfg.TRAIN.IMG_SIZE
    bbx = bbx * map_size[0]
    img = restore(img)
    imgDraw = ImageDraw.Draw(img)
    for i_bbx in range(bbx.shape[0]):
        imgDraw.rectangle(bbx[i_bbx, :].tolist(), outline="red")
    del imgDraw

    return img
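
A hedged usage sketch (the names below are hypothetical placeholders; `restore` is assumed to undo the normalization and return a PIL image, as with the restore_transform returned by load_dataset() above):

loc_preds, conf_preds = model(img_tensor.unsqueeze(0))
vis = draw_bbx(loc_preds, conf_preds, img_tensor, restore_transform)
vis.save('pred_vis.jpg')  # PIL image with red prediction boxes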
Example #4
    def __init__(self):
        self.category = ['nopon']
        self.img_root_dir = '/home/mia_dev/xeroblade2/dataset/train/voc/img/'
        self.img_list = '/home/mia_dev/xeroblade2/dataset/train/voc/train.txt'
        self.annotations_dir = '/home/mia_dev/xeroblade2/dataset/train/voc/xmlannotation/'
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        self.train = True
        self.input_size = cfg.TRAIN.SCALES

        self.fnames = []
        self.boxes = []
        self.labels = []
        self.num_samples = 0
        self.get_img_annotations()

        self.encoder = DataEncoder(self.input_size)
Example #5
class VOCdataset(torch.utils.data.Dataset):
    def __init__(self):
        self.category = ['nopon']
        self.img_root_dir = '/home/mia_dev/xeroblade2/dataset/train/voc/img/'
        self.img_list = '/home/mia_dev/xeroblade2/dataset/train/voc/train.txt'
        self.annotations_dir = '/home/mia_dev/xeroblade2/dataset/train/voc/xmlannotation/'
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        self.train = True
        self.input_size = cfg.TRAIN.SCALES

        self.fnames = []
        self.boxes = []
        self.labels = []
        self.num_samples = 0
        self.get_img_annotations()

        self.encoder = DataEncoder(self.input_size)

    def get_img_annotations(self):
        with open(self.img_list) as f:
            lines = f.readlines()
        self.num_samples = len(lines)

        for line in lines:
            stem = line.strip().split('/')[-1].split('.jpg')[0]
            self.fnames.append(stem + '.jpg')
            box = []
            label = []
            ann = os.path.join(self.annotations_dir, stem + '.xml')
            rec = parse_rec(ann)
            for r in rec:
                box.append(r['bbox'])
                label.append(self.category.index(r['name']))
            self.boxes.append(torch.Tensor(box))
            self.labels.append(torch.LongTensor(label))

    def __getitem__(self, idx):
        '''Load an image and its annotations.

        Args:
          idx: (int) image index.

        Returns:
          img: (tensor) transformed image.
          boxes: (tensor) bounding boxes, sized [#obj, 4].
          labels: (tensor) class labels, sized [#obj,].
        '''
        # Load image and boxes.
        fname = self.fnames[idx]
        img = Image.open(os.path.join(self.img_root_dir, fname))
        if img.mode != 'RGB':
            img = img.convert('RGB')

        boxes = self.boxes[idx].clone()
        labels = self.labels[idx]
        size = self.input_size

        # Data augmentation.
        if self.train:
            img, boxes = random_flip(img, boxes)
            #img, boxes = random_crop(img, boxes)
            img, boxes = resize(img, boxes, size)
        else:
            img, boxes = resize(img, boxes, size)
            #img, boxes = center_crop(img, boxes, size)

        img = self.transform(img)
        return img, boxes, labels

    def collate_fn(self, batch):
        '''Pad images and encode targets.

        Since images may differ in size, pad them all to the same size.

        Args:
          batch: (list) of (img, boxes, labels) tuples.

        Returns:
          padded images, stacked loc_targets, stacked cls_targets.
        '''
        imgs = [x[0] for x in batch]
        boxes = [x[1] for x in batch]
        labels = [x[2] for x in batch]

        h = self.input_size[1]
        w = self.input_size[0]
        num_imgs = len(imgs)
        inputs = torch.zeros(num_imgs, 3, h, w)

        loc_targets = []
        cls_targets = []
        for i in range(num_imgs):
            inputs[i] = imgs[i]
            loc_target, cls_target = self.encoder.encode(boxes[i],
                                                         labels[i],
                                                         input_size=(w, h))
            loc_targets.append(loc_target)
            cls_targets.append(cls_target)
        return inputs, torch.stack(loc_targets), torch.stack(cls_targets)

    def __len__(self):
        return self.num_samples
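
parse_rec is not shown in these examples. A minimal sketch of the PASCAL VOC XML parser that get_img_annotations() assumes (the real helper may carry extra fields such as 'difficult'):

import xml.etree.ElementTree as ET

def parse_rec(filename):
    objects = []
    for obj in ET.parse(filename).findall('object'):
        bbox = obj.find('bndbox')
        objects.append({
            'name': obj.find('name').text,
            'bbox': [int(bbox.find('xmin').text),
                     int(bbox.find('ymin').text),
                     int(bbox.find('xmax').text),
                     int(bbox.find('ymax').text)],
        })
    return objects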
Example #6
def test_model():
    """Model testing loop."""
    logger = logging.getLogger(__name__)

    model = create(cfg.MODEL.TYPE, cfg.MODEL.CONV_BODY, cfg.MODEL.NUM_CLASSES)
    checkpoint = torch.load(os.path.join('checkpoint', cfg.TEST.WEIGHTS))
    model.load_state_dict(checkpoint['net'])

    if not torch.cuda.is_available():
        logger.info('CUDA not found')
        sys.exit(1)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()
    model.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    img_dir = os.path.join(cfg.TEST.DATASETS_DIR, cfg.TEST.DATASETS[0],
                           'JPEGImages')
    img_list = os.path.join(cfg.TEST.DATASETS_DIR, cfg.TEST.DATASETS[0],
                            'ImageSets', 'Main', cfg.TEST.DATASETS[1])

    with open(img_list, 'r') as lst:
        img_list = lst.readlines()
    img_nums = len(img_list)

    test_scales = cfg.TEST.SCALES
    dic = {}
    for i in range(20):
        dic[str(i)] = []

    for im in range(img_nums):
        if im % 100 == 0:
            logger.info('{} imgs were processed, total {}'.format(
                im, img_nums))
        img = Image.open(os.path.join(img_dir, img_list[im].strip() + '.jpg'))
        img_size = img.size
        img = img.resize(test_scales)

        # For visualization
        filename = os.path.join("visual_results",
                                img_list[im].strip() + '.jpg')

        x = transform(img)
        x = x.unsqueeze(0)
        with torch.no_grad():  # Variable is deprecated; run inference without autograd
            loc_preds, cls_preds = model(x)

        loc_preds = loc_preds.data.squeeze().type(torch.FloatTensor)
        cls_preds = cls_preds.data.squeeze().type(torch.FloatTensor)

        encoder = DataEncoder(test_scales)
        boxes, labels, sco, is_found = encoder.decode(loc_preds, cls_preds,
                                                      test_scales)
        if is_found:
            img, boxes = resize(img, boxes, img_size)

            boxes = boxes.ceil()
            xmin = boxes[:, 0].clamp(min=1)
            ymin = boxes[:, 1].clamp(min=1)
            xmax = boxes[:, 2].clamp(max=img_size[0] - 1)
            ymax = boxes[:, 3].clamp(max=img_size[1] - 1)

            nums = len(boxes)
            for i in range(nums):
                dic[str(labels[i].item())].append([
                    img_list[im].strip(), sco[i].item(), xmin[i].item(),
                    ymin[i].item(), xmax[i].item(), ymax[i].item()
                ])

            draw = ImageDraw.Draw(img)
            font = ImageFont.truetype(
                '/usr/share/fonts/truetype/ancient-scripts/Symbola_hint.ttf',
                20)

            for count, box in enumerate(boxes):
                draw.rectangle(list(box), outline='red', width=5)
                draw.text((box[0], box[1] - 20),
                          str(category[labels[count]]),
                          font=font,
                          fill=(255, 0, 0, 255))
            img.save(filename)

    for key in dic.keys():
        logger.info('category id: {}, category name: {}'.format(
            key, category[int(key)]))
        file_name = cfg.TEST.OUTPUT_DIR + 'comp4_det_test_' + category[int(
            key)] + '.txt'
        with open(file_name, 'w') as comp4:
            nums = len(dic[key])
            for i in range(nums):
                img, cls_preds, xmin, ymin, xmax, ymax = dic[key][i]
                if cls_preds > 0.5:
                    cls_preds = '%.6f' % cls_preds
                    loc_preds = '%.6f %.6f %.6f %.6f' % (xmin, ymin, xmax,
                                                         ymax)
                    rlt = '{} {} {}\n'.format(img, cls_preds, loc_preds)
                    comp4.write(rlt)
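
resize is another helper these examples assume. A minimal sketch (an assumption: it scales a PIL image to size = (w, h) and rescales absolute xyxy boxes by the same factors; the real helper may differ, e.g. in interpolation mode):

import torch

def resize(img, boxes, size):
    w, h = img.size
    sw = float(size[0]) / w
    sh = float(size[1]) / h
    img = img.resize(size)
    if boxes.numel() > 0:
        boxes = boxes * torch.Tensor([sw, sh, sw, sh])
    return img, boxes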
Example #7
class VOCdataset(torch.utils.data.Dataset):
    def __init__(self):
        self.category = [
            'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
            'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
            'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
        ]
        self.img_root_dir = os.path.join(cfg.TRAIN.DATASETS_DIR,
                                         cfg.TRAIN.DATASETS[0], 'JPEGImages')
        self.img_list = os.path.join(cfg.TRAIN.DATASETS_DIR,
                                     cfg.TRAIN.DATASETS[0], 'ImageSets',
                                     'Main', cfg.TRAIN.DATASETS[1])
        self.annotations_dir = os.path.join(cfg.TRAIN.DATASETS_DIR,
                                            cfg.TRAIN.DATASETS[0],
                                            'Annotations')
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        self.train = True
        self.input_size = cfg.TRAIN.SCALES

        self.fnames = []
        self.boxes = []
        self.labels = []
        self.num_samples = 0
        self.get_img_annotations()

        self.encoder = DataEncoder(self.input_size)

    def get_img_annotations(self):
        with open(self.img_list) as f:
            lines = f.readlines()
        self.num_samples = len(lines)

        for line in lines:
            stem = line.strip()
            self.fnames.append(stem + '.jpg')
            box = []
            label = []
            ann = os.path.join(self.annotations_dir, stem + '.xml')
            rec = parse_rec(ann)
            for r in rec:
                box.append(r['bbox'])
                label.append(self.category.index(r['name']))
            self.boxes.append(torch.Tensor(box))
            self.labels.append(torch.LongTensor(label))

    def __getitem__(self, idx):
        '''Load an image and its annotations.

        Args:
          idx: (int) image index.

        Returns:
          img: (tensor) transformed image.
          boxes: (tensor) bounding boxes, sized [#obj, 4].
          labels: (tensor) class labels, sized [#obj,].
        '''
        # Load image and boxes.
        fname = self.fnames[idx]
        img = Image.open(os.path.join(self.img_root_dir, fname))
        if img.mode != 'RGB':
            img = img.convert('RGB')

        boxes = self.boxes[idx].clone()
        labels = self.labels[idx]
        size = self.input_size

        # Data augmentation.
        if self.train:
            img, boxes = random_flip(img, boxes)
            img, boxes = random_crop(img, boxes)
            img, boxes = resize(img, boxes, size)
        else:
            img, boxes = resize(img, boxes, size)
            img, boxes = center_crop(img, boxes, size)

        img = self.transform(img)
        return img, boxes, labels

    def collate_fn(self, batch):
        '''Pad images and encode targets.

        Since images may differ in size, pad them all to the same size.

        Args:
          batch: (list) of (img, boxes, labels) tuples.

        Returns:
          padded images, stacked loc_targets, stacked cls_targets.
        '''
        imgs = [x[0] for x in batch]
        boxes = [x[1] for x in batch]
        labels = [x[2] for x in batch]

        h = self.input_size[1]
        w = self.input_size[0]
        num_imgs = len(imgs)
        inputs = torch.zeros(num_imgs, 3, h, w)

        loc_targets = []
        cls_targets = []
        for i in range(num_imgs):
            inputs[i] = imgs[i]
            loc_target, cls_target = self.encoder.encode(boxes[i],
                                                         labels[i],
                                                         input_size=(w, h))
            loc_targets.append(loc_target)
            cls_targets.append(cls_target)
        return inputs, torch.stack(loc_targets), torch.stack(cls_targets)

    def __len__(self):
        return self.num_samples
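
random_flip is the remaining augmentation helper used above. A minimal sketch (an assumption: it mirrors the PIL image horizontally with probability 0.5 and reflects the xyxy boxes; random_crop and center_crop would follow the same pattern):

import random
from PIL import Image

def random_flip(img, boxes):
    if random.random() < 0.5:
        w = img.width
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
        if boxes.numel() > 0:
            boxes = boxes.clone()
            xmin = w - boxes[:, 2]
            xmax = w - boxes[:, 0]
            boxes[:, 0] = xmin
            boxes[:, 2] = xmax
    return img, boxes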
Example #8
def test_model():
    """Model testing loop."""
    logger = logging.getLogger(__name__)
    colors = np.random.randint(0, 255, size=(1, 3), dtype="uint8")

    model = create(cfg.MODEL.TYPE, cfg.MODEL.CONV_BODY, cfg.MODEL.NUM_CLASSES)
    checkpoint = torch.load(os.path.join('checkpoint', cfg.TEST.WEIGHTS))
    model.load_state_dict(checkpoint['net'])


    if not torch.cuda.is_available():
        logger.info('CUDA not found')
        sys.exit(1)

    #model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    model.cuda()
    #model.cpu()
    model.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    img_dir = os.path.join("/home/mia_dev/xeroblade2/dataset/train/img/")
    g = os.walk(r"/home/mia_dev/xeroblade2/dataset/test/2021-03-22_21-49-34")
    img_list = []
    for path, dir_list, file_list in g:
        for file_name in file_list:
            img_list.append(file_name)
    img_nums = len(img_list)

    test_scales = cfg.TEST.SCALES
    dic = {}
    for i in range(20):
        dic[str(i)] = []

    for im in range(img_nums):
        if im % 100 == 0:
            logger.info('{} imgs were processed, total {}'.format(im, img_nums))
        img = Image.open(os.path.join("/home/mia_dev/xeroblade2/dataset/test/2021-03-22_21-49-34/", img_list[im].strip()))
        img_size = img.size
        img = img.resize(test_scales)

        x = transform(img)
        x = x.cuda()
        x = x.unsqueeze(0)
        with torch.no_grad():  # Variable is deprecated; run inference without autograd
            loc_preds, cls_preds = model(x)

        loc_preds = loc_preds.data.squeeze().type(torch.FloatTensor)
        cls_preds = cls_preds.data.squeeze().type(torch.FloatTensor)

        encoder = DataEncoder(test_scales)
        boxes, labels, sco, is_found = encoder.decode(loc_preds, cls_preds, test_scales)
        if is_found:
            img, boxes = resize(img, boxes, img_size)

            boxes = boxes.ceil()
            xmin = boxes[:, 0].clamp(min=1)
            ymin = boxes[:, 1].clamp(min=1)
            xmax = boxes[:, 2].clamp(max=img_size[0] - 1)
            ymax = boxes[:, 3].clamp(max=img_size[1] - 1)

            nums = len(boxes)
            for i in range(nums):
                dic[str(labels[i].item())].append([
                    img_list[im].strip(), sco[i].item(), xmin[i].item(),
                    ymin[i].item(), xmax[i].item(), ymax[i].item()
                ])

    temp = ''
    for key in dic.keys():
        # result files (comp4_det_test_*.txt) are not written in this variant
        nums = len(dic[key])
        for i in range(nums):
            img, cls_preds, xmin, ymin, xmax, ymax = dic[key][i]

            # reload the source image only when it changes
            if temp != img:
                temp = img
                imgs = cv2.imread(
                    "/home/mia_dev/xeroblade2/dataset/test/2021-03-22_21-49-34/"
                    + img)
            if cls_preds > 0:
                cls_preds = '%.6f' % cls_preds
                loc_preds = '%.6f %.6f %.6f %.6f' % (xmin, ymin, xmax, ymax)
                rlt = '{} {} {}\n'.format(img, cls_preds, loc_preds)

                box_w = int(xmax - xmin)
                box_h = int(ymax - ymin)
                color = [int(c) for c in colors[0]]
                x1 = int(xmin)
                x2 = int(xmax)
                y1 = int(ymin)
                y2 = int(ymax)

                imgs = cv2.rectangle(imgs, (x1, y1 + box_h), (x2, y1), color, 2)
                cv2.putText(imgs, 'nopon', (x1, y1),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
                cv2.putText(imgs, str("%.2f" % float(cls_preds)),
                            (x2, y2 - box_h), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                            color, 2)
                cv2.imwrite(
                    "/home/mia_dev/xeroblade2/dataset/result/retinanet/"
                    + img.split('/')[-1], imgs)
Example #9
def test_model():
    """Model testing loop."""
    logger = logging.getLogger(__name__)

    model = create(cfg.MODEL.TYPE, cfg.MODEL.CONV_BODY, cfg.MODEL.NUM_CLASSES)
    checkpoint = torch.load(os.path.join('checkpoint', cfg.TEST.WEIGHTS))
    model.load_state_dict(checkpoint['net'])

    if not torch.cuda.is_available():
        logger.info('CUDA not found')
        sys.exit(1)

    print("Total cuda devices = ", torch.cuda.device_count())
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()
    model.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    print("DIRS = ", cfg.TEST.DATASETS_DIR)
    print("Another DIRS = ", cfg.TEST.DATASETS[0])
    img_dir = os.path.join(cfg.TEST.DATASETS_DIR, cfg.TEST.DATASETS[0],
                           'JPEGImages')
    img_list = os.path.join(cfg.TEST.DATASETS_DIR, cfg.TEST.DATASETS[0],
                            'ImageSets', 'Main', cfg.TEST.DATASETS[1])

    with open(img_list, 'r') as lst:
        img_list = lst.readlines()
    img_nums = len(img_list)

    test_scales = cfg.TEST.SCALES
    dic = {}
    for i in range(20):
        dic[str(i)] = []

    for im in range(img_nums):
        if im % 100 == 0:
            logger.info('{} imgs were processed, total {}'.format(
                im, img_nums))
        img = Image.open(os.path.join(img_dir, img_list[im].strip() + '.jpg'))
        img_size = img.size
        img = img.resize(test_scales)

        x = transform(img)
        x = x.unsqueeze(0)
        with torch.no_grad():  # Variable is deprecated; run inference without autograd
            loc_preds, cls_preds = model(x)

        loc_preds = loc_preds.data.squeeze().type(torch.FloatTensor)
        cls_preds = cls_preds.data.squeeze().type(torch.FloatTensor)

        encoder = DataEncoder(test_scales)
        boxes, labels, sco, is_found = encoder.decode(loc_preds, cls_preds,
                                                      test_scales)
        if is_found:
            img, boxes = resize(img, boxes, img_size)

            boxes = boxes.ceil()
            xmin = boxes[:, 0].clamp(min=1)
            ymin = boxes[:, 1].clamp(min=1)
            xmax = boxes[:, 2].clamp(max=img_size[0] - 1)
            ymax = boxes[:, 3].clamp(max=img_size[1] - 1)

            nums = len(boxes)
            for i in range(nums):
                dic[str(labels[i].item())].append([
                    img_list[im].strip(), sco[i].item(), xmin[i].item(),
                    ymin[i].item(), xmax[i].item(), ymax[i].item()
                ])

    for key in dic.keys():
        logger.info('category id: {}, category name: {}'.format(
            key, category[int(key)]))
        file_name = cfg.TEST.OUTPUT_DIR + 'comp4_det_test_' + category[int(
            key)] + '.txt'
        with open(file_name, 'w') as comp4:
            nums = len(dic[key])
            for i in range(nums):
                img, cls_preds, xmin, ymin, xmax, ymax = dic[key][i]
                if cls_preds > 0.5:
                    cls_preds = '%.6f' % cls_preds
                    loc_preds = '%.6f %.6f %.6f %.6f' % (xmin, ymin, xmax,
                                                         ymax)
                    rlt = '{} {} {}\n'.format(img, cls_preds, loc_preds)
                    comp4.write(rlt)
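
The comp4_det_test_*.txt files written above hold one detection per line: 'image_id score xmin ymin xmax ymax'. A hedged reader sketch (read_comp4 is a hypothetical helper, not part of the original code):

def read_comp4(path):
    dets = []
    with open(path) as f:
        for line in f:
            fields = line.split()
            dets.append((fields[0], float(fields[1]),
                         [float(v) for v in fields[2:6]]))
    return dets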
Example #10
def test_model():
    """Model testing loop."""
    logger = logging.getLogger(__name__)
    colors = np.random.randint(0, 255, size=(1, 3), dtype="uint8")

    model = create(cfg.MODEL.TYPE, cfg.MODEL.CONV_BODY, cfg.MODEL.NUM_CLASSES)
    checkpoint = torch.load(os.path.join('checkpoint', cfg.TEST.WEIGHTS))
    model.load_state_dict(checkpoint['net'])

    if not torch.cuda.is_available():
        logger.info('CUDA not found')
        sys.exit(1)

    #model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    model.cuda()
    #model.cpu()
    model.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    img_dir = os.path.join("/home/mia_dev/xeroblade2/dataset/train/img/")
    g = os.walk(r"/home/mia_dev/xeroblade2/dataset/test/2021-03-22_21-49-34")
    img_list = []
    for path, dir_list, file_list in g:
        for file_name in file_list:
            img_list.append(file_name)
    img_nums = len(img_list)

    test_scales = cfg.TEST.SCALES
    dic = {}
    for i in range(20):
        dic[str(i)] = []

    frame_array = []

    cap = cv2.VideoCapture(
        "/home/mia_dev/xeroblade2/dataset/test/2021-03-22_21-47-43.mp4")
    test_scales = (1280, 960)
    while cap.isOpened():
        ret, img = cap.read()
        if not ret:
            break
        img_size = img.shape  # note: (height, width, channels)

        img = cv2.resize(img, test_scales)

        RGBimg = changeBGR2RGB(img)
        x = transform(RGBimg)
        x = x.cuda()
        x = x.unsqueeze(0)
        with torch.no_grad():  # Variable is deprecated; run inference without autograd
            loc_preds, cls_preds = model(x)

        loc_preds = loc_preds.data.squeeze().type(torch.FloatTensor)
        cls_preds = cls_preds.data.squeeze().type(torch.FloatTensor)

        encoder = DataEncoder(test_scales)
        boxes, labels, sco, is_found = encoder.decode(loc_preds, cls_preds,
                                                      test_scales)
        if is_found:
            #img, boxes = resize(img, boxes, img_size)

            boxes = boxes.ceil()
            xmin = boxes[:, 0].clamp(min=1)
            ymin = boxes[:, 1].clamp(min=1)
            xmax = boxes[:, 2].clamp(max=img_size[0] - 1)
            ymax = boxes[:, 3].clamp(max=img_size[1] - 1)

            color = [int(c) for c in colors[0]]

            nums = len(boxes)
            for i in range(nums):
                if float(sco[i]) > 0.5:
                    box_w = int(xmax[i] - xmin[i])
                    box_h = int(ymax[i] - ymin[i])
                    x1 = int(xmin[i])
                    x2 = int(xmax[i])
                    y1 = int(ymin[i])
                    y2 = int(ymax[i])
                    img = cv2.rectangle(img, (x1, y1 + box_h), (x2, y1), color,
                                        2)
                    cv2.putText(img, 'nopon', (x1, y1),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
                    cv2.putText(img, str("%.2f" % float(sco[i])),
                                (x2, y2 - box_h), cv2.FONT_HERSHEY_SIMPLEX,
                                0.5, color, 2)
        vidframe = changeRGB2BGR(RGBimg)
        frame_array.append(vidframe)

    pathOut = '/home/mia_dev/xeroblade2/RetinaNet-Pytorch-master/2021-03-22_21-47-43.mp4'
    fps = 60
    out = cv2.VideoWriter(pathOut, cv2.VideoWriter_fourcc(*'DIVX'), fps,
                          test_scales)
    print(len(frame_array))
    for frame in frame_array:
        out.write(frame)  # write each frame to the output video
    out.release()
    cap.release()
    cv2.destroyAllWindows()
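
changeBGR2RGB and changeRGB2BGR are not defined in this snippet. Plausible minimal implementations (an assumption) are plain OpenCV colour-space conversions; note that cv2.cvtColor returns a new array, so with these exact implementations the rectangles drawn on the BGR frame above would not appear in changeRGB2BGR(RGBimg), and the actual helpers may instead operate in place.

import cv2

def changeBGR2RGB(img):
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def changeRGB2BGR(img):
    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)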
    """