Example #1
    def __init__(self, pretrain=False):
        super(ResDown, self).__init__()
        self.features = resnet50(layer3=True, layer4=False)
        if pretrain:
            load_pretrain(self.features, 'resnet.model')

        self.downsample = ResDownS(1024, 256)
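load_pretrain appears throughout these examples but its body is never shown. As a point of reference, here is a minimal sketch of what such a helper typically does — an assumption, not this project's actual implementation:

import torch

def load_pretrain_sketch(model, ckpt_path):
    # hypothetical helper: load a checkpoint and copy matching weights
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state = ckpt.get('state_dict', ckpt)  # accept a raw or a wrapped state dict
    # strip 'module.' prefixes left over from DataParallel training
    state = {k.replace('module.', '', 1): v for k, v in state.items()}
    model.load_state_dict(state, strict=False)  # tolerate missing/unexpected keys
    return model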
Example #2
    def __init__(self, pretrain=False):
        super(ResDown, self).__init__()
        self.features = resnet50(layer3=True, layer4=False)
        if pretrain:
            load_pretrain(self.features, 'resnet.model')

        self.downsample = ResDownS(1024, 256)

        self.layers = [self.downsample, self.features.layer2, self.features.layer3]
        self.train_nums = [1, 3]
        self.change_point = [0, 0.5]

        self.unfix(0.0)
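Here unfix(ratio) implements staged unfreezing: with train_nums = [1, 3] and change_point = [0, 0.5], only self.downsample trains during the first half of the schedule, and all three entries of self.layers train afterwards. A minimal sketch of that logic, inferred from the fields above rather than taken from the repository:

def unfix_sketch(self, ratio):
    # ratio = current_epoch / total_epochs
    trainable = 0
    for train_num, point in zip(self.train_nums, self.change_point):
        if ratio >= point:
            trainable = train_num  # the last threshold passed wins
    for i, layer in enumerate(self.layers):
        for p in layer.parameters():
            p.requires_grad = i < trainable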
Example #3
    def __init__(self, config=default_config, context=True):
        super(SharpMask, self).__init__()
        self.context = context  # with context
        self.km, self.ks = config.km, config.ks
        self.skpos = [6, 5, 4, 2]

        deepmask = DeepMask(config)
        deepmask_resume = os.path.join('exps', 'deepmask', 'train', 'checkpoint.pth.tar')
        assert os.path.exists(deepmask_resume), "Please train DeepMask first"
        deepmask = load_pretrain(deepmask, deepmask_resume)
        self.trunk = deepmask.trunk
        self.crop_trick = deepmask.crop_trick
        self.scoreBranch = deepmask.scoreBranch
        self.maskBranch = deepmask.maskBranch
        self.fSz = deepmask.fSz

        self.refs = self.createTopDownRefinement()  # create refinement modules

        nph = sum(p.numel()
                  for h in self.neths for p in h.parameters()) / 1e+06
        npv = sum(p.numel()
                  for h in self.netvs for p in h.parameters()) / 1e+06
        print('| number of parameters net h: {:.3f} M'.format(nph))
        print('| number of parameters net v: {:.3f} M'.format(npv))
        print('| number of parameters total: {:.3f} M'.format(nph + npv))
Example #4
def main():
    global args, logger, v_id
    args = parser.parse_args()
    cfg = load_config(args)

    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info(args)

    # setup model
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(anchors=cfg['anchors'])
    else:
        parser.error('invalid architecture: {}'.format(args.arch))

    if args.resume:
        assert isfile(args.resume), '{} is not a valid file'.format(args.resume)
        model = load_pretrain(model, args.resume)
    model.eval()
    device = torch.device('cuda' if (torch.cuda.is_available() and not args.cpu) else 'cpu')
    model = model.to(device)
    # setup dataset
    dataset = load_dataset(args.dataset)

    # VOS or VOT?
    if args.dataset in ['DAVIS', 'DAVIS2016', 'DAVIS2017', 'ytb_vos'] and args.mask:
        vos_enable = True  # enable Mask output
    else:
        vos_enable = False

    total_lost = 0  # VOT
    iou_lists = []  # VOS
    speed_list = []

    for v_id, video in enumerate(dataset.keys(), start=1):
        if args.video != '' and video != args.video:
            continue

        if vos_enable:
            iou_list, speed = track_vos(model, dataset[video], cfg['hp'] if 'hp' in cfg.keys() else None,
                                 args.mask, args.refine, args.dataset in ['DAVIS2017', 'ytb_vos'], device=device)
            iou_lists.append(iou_list)
        else:
            lost, speed = track_vot(model, dataset[video], cfg['hp'] if 'hp' in cfg.keys() else None,
                             args.mask, args.refine, device=device)
            total_lost += lost
        speed_list.append(speed)

    # report final result
    if vos_enable:
        for thr, iou in zip(thrs, np.mean(np.concatenate(iou_lists), axis=0)):
            logger.info('Segmentation Threshold {:.2f} mIoU: {:.3f}'.format(thr, iou))
    else:
        logger.info('Total Lost: {:d}'.format(total_lost))

    logger.info('Mean Speed: {:.2f} FPS'.format(np.mean(speed_list)))
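Each iou_list holds one IoU value per segmentation threshold in thrs (defined elsewhere in the script), so the report line averages across videos per threshold. A small worked example of that aggregation, with made-up numbers:

import numpy as np

iou_lists = [np.array([[0.70, 0.65]]),  # video 1: IoU at two thresholds
             np.array([[0.80, 0.75]])]  # video 2
print(np.mean(np.concatenate(iou_lists), axis=0))  # -> [0.75 0.7]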
Example #5
def main():
    global args
    args = parser.parse_args()
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Setup Model
    from collections import namedtuple
    Config = namedtuple('Config', ['iSz', 'oSz', 'gSz', 'batch'])
    config = Config(iSz=160, oSz=56, gSz=112, batch=1)  # default for training

    model = (models.__dict__[args.arch](config))
    model = load_pretrain(model, args.resume)
    model = model.eval().to(device)

    scales = [2**i for i in range_end(args.si, args.sf, args.ss)]
    meanstd = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}
    infer = Infer(nps=args.nps,
                  scales=scales,
                  meanstd=meanstd,
                  model=model,
                  device=device)

    print('| start')
    tic = time.time()
    im = np.array(Image.open(args.img).convert('RGB'), dtype=np.float32)
    h, w = im.shape[:2]
    img = np.expand_dims(np.transpose(im, (2, 0, 1)),
                         axis=0).astype(np.float32)
    img = torch.from_numpy(img / 255.).to(device)
    infer.forward(img)
    masks, scores = infer.getTopProps(.2, h, w)
    toc = time.time() - tic
    print('| done in %05.3f s' % toc)

    for i in range(masks.shape[2]):
        res = im[:, :, ::-1].copy().astype(np.uint8)
        res[:, :,
            2] = masks[:, :, i] * 255 + (1 - masks[:, :, i]) * res[:, :, 2]

        mask = masks[:, :, i].astype(np.uint8)
        _, contours, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL,
                                          cv2.CHAIN_APPROX_NONE)
        cnt_area = [cv2.contourArea(cnt) for cnt in contours]
        cnt_max_id = np.argmax(cnt_area)
        contour = contours[cnt_max_id]
        polygons = contour.reshape(-1, 2)

        predict_box = cv2.boundingRect(polygons)
        predict_rbox = cv2.minAreaRect(polygons)
        rbox = cv2.boxPoints(predict_rbox)
        print('Segment Proposal Score: {:.3f}'.format(scores[i]))

        res = cv2.rectangle(
            res, (predict_box[0], predict_box[1]),
            (predict_box[0] + predict_box[2], predict_box[1] + predict_box[3]),
            (0, 255, 0), 3)
        res = cv2.polylines(res, [np.int0(rbox)], True, (0, 255, 255), 3)
        cv2.imshow('Proposal', res)
        cv2.waitKey(0)
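Note that the three-value unpacking of cv2.findContours above matches OpenCV 3.x; OpenCV 4.x returns only (contours, hierarchy). A small compatibility wrapper, offered as a sketch rather than part of the original code:

import cv2

def find_contours_compat(mask):
    # OpenCV 3.x returns (image, contours, hierarchy); 4.x returns (contours, hierarchy)
    ret = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    return ret[-2]  # contours is second-to-last in both versions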
Example #6
    def __init__(self, config_path, model_path):
        args = TrackArgs()
        args.config = config_path
        args.resume = model_path

        cfg = load_config(args)
        if args.arch == 'Custom':
            from custom import Custom
            self.model = Custom(anchors=cfg['anchors'])
        else:
            parser.error('invalid architecture: {}'.format(args.arch))

        if args.resume:
            assert isfile(args.resume), '{} is not a valid file'.format(args.resume)
            self.model = load_pretrain(self.model, args.resume)
        self.model.eval()
        self.device = torch.device('cuda' if (torch.cuda.is_available() and not args.cpu) else 'cpu')
        self.model = self.model.to(self.device)

        ################# Dangerous
        self.p = TrackerConfig()
        self.p.update(cfg['hp'] if 'hp' in cfg.keys() else None, self.model.anchors)
        self.p.renew()

        self.p.scales = self.model.anchors['scales']
        self.p.ratios = self.model.anchors['ratios']
        self.p.anchor_num = self.model.anchor_num
        self.p.anchor = generate_anchor(self.model.anchors, self.p.score_size)

        if self.p.windowing == 'cosine':
            self.window = np.outer(np.hanning(self.p.score_size), np.hanning(self.p.score_size))
        elif self.p.windowing == 'uniform':
            self.window = np.ones((self.p.score_size, self.p.score_size))
        self.window = np.tile(self.window.flatten(), self.p.anchor_num)
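The cosine window built above down-weights response-map scores far from the previous target position: np.outer of two Hanning vectors gives a 2-D taper that peaks at the center, and np.tile repeats it once per anchor. For instance:

import numpy as np

score_size, anchor_num = 5, 2
window = np.outer(np.hanning(score_size), np.hanning(score_size))
print(window.shape, window[2, 2])  # (5, 5) 1.0 at the center
window = np.tile(window.flatten(), anchor_num)
print(window.shape)                # (50,)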
Example #7
def main():
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()  # parse the command-line arguments

    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')  # instantiate a logger
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(
        cfg, indent=4)))  # 转变成json格式的文件,缩进4格

    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(pretrain=True, anchors=cfg['anchors'])
    else:
        exit()
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    model = model.cuda()  # move the model to the GPU
    dist_model = torch.nn.DataParallel(
        model, list(range(torch.cuda.device_count()))).cuda()  # multi-GPU training

    if args.resume and args.start_epoch != 0:  # when resuming, unfreeze the backbone to match the resumed epoch
        model.features.unfix((args.start_epoch - 1) / args.epochs)

    optimizer, lr_scheduler = build_opt_lr(model, cfg, args,
                                           args.start_epoch)  # build the optimizer and LR schedule
    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model, optimizer, args.start_epoch, best_acc, arch = restore_from(
            model, optimizer, args.resume)
        dist_model = torch.nn.DataParallel(
            model, list(range(torch.cuda.device_count()))).cuda()

    logger.info(lr_scheduler)

    logger.info('model prepare done')

    train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch,
          cfg)
Example #8
    def __init__(self, pretrain=False):
        super(ResDown, self).__init__()
        # feature extraction with resnet50
        self.features = resnet50(layer3=True, layer4=False)
        # if a pretrained network exists, load its weights into the feature extractor
        if pretrain:
            load_pretrain(self.features, 'resnet.model')
        # adjust
        self.downsample = ResDownS(1024, 256)
        # network layers
        self.layers = [
            self.downsample, self.features.layer2, self.features.layer3
        ]
        self.train_nums = [1, 3]
        self.change_point = [0, 0.5]

        self.unfix(0.0)
Example #9
    def __init__(self, pretrain=False):
        super(ResDown, self).__init__()
        self.features = resnet50(
            layer3=True, layer4=False)  # use only the first three stages of ResNet; this builds the architecture, weights come later
        if pretrain:
            load_pretrain(self.features,
                          'resnet.model')  # load pretrained weights (the model is modified in place)

        self.downsample = ResDownS(1024, 256)  # adjust layer: channels go from 1024 to 256

        self.layers = [
            self.downsample, self.features.layer2, self.features.layer3
        ]  # the layers whose training is staged
        self.train_nums = [1, 3]
        self.change_point = [
            0, 0.5
        ]  # the first half of the epochs trains only self.downsample; the second half trains every layer in self.layers

        self.unfix(0.0)
Example #10
def main():
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()

    init_log('global', logging.INFO)  # create a logger; logging.INFO is the log level

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')  # fetch the logger initialized above
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)  # returns the parsed configuration object

    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))  # json.dumps() serializes the dict config to a JSON string

    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)  

    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(pretrain=True, anchors=cfg['anchors'])
    else:
        exit()
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    model = model.cuda()
    dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()

    if args.resume and args.start_epoch != 0:
        model.features.unfix((args.start_epoch - 1) / args.epochs)

    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)
    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(args.resume)
        model, optimizer, args.start_epoch, best_acc, arch = restore_from(model, optimizer, args.resume)
        dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()

    logger.info(lr_scheduler)

    logger.info('model prepare done')

    train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch, cfg)
Example #11
def main():
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()
    args = args_process(args)

    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    args.img_size = int(cfg['train_datasets']['search_size'])
    args.nms_threshold = float(cfg['train_datasets']['RPN_NMS'])
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(pretrain=True,
                       opts=args,
                       anchors=train_loader.dataset.anchors)
    else:
        exit()
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)
    else:
        raise Exception("Pretrained weights must be loaded!")

    model = model.cuda()
    dist_model = torch.nn.DataParallel(model,
                                       list(range(
                                           torch.cuda.device_count()))).cuda()

    logger.info('model prepare done')

    logger = logging.getLogger('global')
    val_avg = AverageMeter()

    validation(val_loader, dist_model, cfg, val_avg)
Example #12
def set_model():
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Setup Model
    from collections import namedtuple

    if args.arch == 'DeepMask':
        Config = namedtuple('Config', ['iSz', 'oSz', 'gSz', 'batch'])
        config = Config(iSz=160, oSz=56, gSz=112,
                        batch=1)  # default for training (Deepmask)
    elif args.arch == 'SharpMask':
        Config = namedtuple('Config',
                            ['iSz', 'oSz', 'gSz', 'batch', 'km', 'ks'])
        config = Config(iSz=160, oSz=56, gSz=160, batch=1, km=32, ks=32)

    model = (models.__dict__[args.arch](config))
    model = load_pretrain(model, args.resume)
    model = model.eval().to(device)  # put the model in eval (test) mode

    return model, device
Example #13
    def __init__(self, args):
        super(PatchTrainer, self).__init__()

        # Setup device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        torch.backends.cudnn.benchmark = True

        # Setup tracker cfg
        cfg = load_config(args)
        p = TrackerConfig()
        p.renew()
        self.p = p

        # Setup tracker
        siammask = Tracker(p=p, anchors=cfg['anchors'])
        if args.resume:
            assert isfile(args.resume), 'Please download {} first.'.format(args.resume)
            siammask = load_pretrain(siammask, args.resume)
        siammask.eval().to(self.device)
        self.model = siammask
Example #14
def main():
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()
    args = args_process(args)

    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    args.img_size = int(cfg['train_datasets']['search_size'])
    args.nms_threshold = float(cfg['train_datasets']['RPN_NMS'])
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(pretrain=True,
                       opts=args,
                       anchors=train_loader.dataset.anchors)
    else:
        exit()
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    model = model.cuda()
    dist_model = torch.nn.DataParallel(model,
                                       list(range(
                                           torch.cuda.device_count()))).cuda()

    if args.resume and args.start_epoch != 0:
        model.features.unfix((args.start_epoch - 1) / args.epochs)

    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)
    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model, optimizer, args.start_epoch, best_acc, arch = restore_from(
            model, optimizer, args.resume)
        dist_model = torch.nn.DataParallel(
            model, list(range(torch.cuda.device_count()))).cuda()

    logger.info(lr_scheduler)

    logger.info('model prepare done')

    train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch,
          cfg)
Example #15
def main():
    global args
    args = parser.parse_args()
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Setup Model
    from collections import namedtuple
    Config = namedtuple('Config', ['iSz', 'oSz', 'gSz', 'batch'])
    config = Config(iSz=160, oSz=56, gSz=112, batch=1)  # default for training

    model = (models.__dict__[args.arch](config))
    model = load_pretrain(model, args.resume)
    model = model.eval().to(device)

    scales = [2**i for i in range_end(args.smin, args.smax, args.sstep)]
    meanstd = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}
    infer = Infer(nps=args.nps,
                  scales=scales,
                  meanstd=meanstd,
                  model=model,
                  device=device,
                  timer=False)

    annFile = '{}/annotations/instances_{}2017.json'.format(
        args.datadir, args.split)
    coco = COCO(annFile)
    imgIds = coco.getImgIds()
    imgIds = sorted(imgIds)  # [:500] for fast test
    segm_props = []
    print('| start eval')
    tic = time.time()
    for k, imgId in enumerate(imgIds):
        ann = coco.loadImgs(imgId)[0]
        fileName = ann['file_name']
        pathImg = '{}/{}2017/{}'.format(args.datadir, args.split, fileName)
        im = np.array(Image.open(pathImg).convert('RGB'), dtype=np.float32)
        h, w = im.shape[:2]
        img = np.expand_dims(np.transpose(im, (2, 0, 1)),
                             axis=0).astype(np.float32)
        img = torch.from_numpy(img / 255.).to(device)

        infer.forward(img)
        masks, scores = infer.getTopProps(args.thr, h, w)

        enc = encode(np.asfortranarray(masks))
        for i in range(args.nps):
            enc[i]['counts'] = enc[i]['counts'].decode('utf-8')
            elem = {
                'segmentation': enc[i],
                'image_id': imgId,
                'category_id': 1,
                'score': scores[i]
            }
            segm_props.append(elem)

        # for i in range(args.nps):
        #     res = im[:, :, ::-1].copy().astype(np.uint8)
        #     res[:, :, 2] = masks[:, :, i] * 255 + (1 - masks[:, :, i]) * res[:, :, 2]
        #
        #     mask = masks[i].astype(np.uint8)
        #     _, contour, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        #     polygons = [c.reshape(-1, 2) for c in contour]
        #
        #     predict_box = cv2.boundingRect(polygons[0])
        #     predict_rbox = cv2.minAreaRect(polygons[0])
        #     box = cv2.boxPoints(predict_rbox)
        #     print('Segment Proposal Score: {:.3f}'.format(scores[i]))
        #
        #     res = cv2.rectangle(res, (predict_box[0], predict_box[1]),
        #                         (predict_box[0] + predict_box[2], predict_box[1] + predict_box[3]), (0, 255, 0), 3)
        #     res = cv2.polylines(res, [np.int0(box)], True, (0, 255, 255), 3)
        #     cv2.imshow('Proposal', res)
        #     cv2.waitKey(0)
        # plt.imshow(im); plt.axis('off')
        # coco.showAnns([props[-1]])
        # plt.show()
        if (k + 1) % 10 == 0:
            toc = time.time() - tic
            print('| process %05d in %010.3f s' % (k + 1, toc))

    toc = time.time() - tic
    print('| finish in %010.3f s' % toc)

    pathsv = 'sharpmask/eval_coco' if args.arch == 'SharpMask' else 'deepmask/eval_coco'
    args.rundir = join(args.rundir, pathsv)
    try:
        if not isdir(args.rundir):
            makedirs(args.rundir)
    except OSError as err:
        print(err)

    result_path = join(args.rundir, 'segm_proposals.json')
    with open(result_path, 'w') as outfile:
        json.dump(segm_props, outfile)

    cocoDt = coco.loadRes(result_path)

    print('\n\nBox Proposals Evaluation\n\n')
    annType = ['bbox']  # segm  bbox
    cocoEval = COCOeval(coco, cocoDt)

    max_dets = [10, 100, 1000]
    useSegm = False
    useCats = False

    cocoEval.params.imgIds = imgIds
    cocoEval.params.maxDets = max_dets
    cocoEval.params.useSegm = useSegm
    cocoEval.params.useCats = useCats
    cocoEval.evaluate()
    cocoEval.accumulate()
    cocoEval.summarize()

    print('\n\nSegmentation Proposals Evaluation\n\n')
    annType = ['segm']  # segm  bbox
    cocoEval = COCOeval(coco, cocoDt)

    max_dets = [10, 100, 1000]
    useSegm = True
    useCats = False

    cocoEval.params.imgIds = imgIds
    cocoEval.params.maxDets = max_dets
    cocoEval.params.useSegm = useSegm
    cocoEval.params.useCats = useCats
    cocoEval.evaluate()
    cocoEval.accumulate()
    cocoEval.summarize()
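The encode call above is pycocotools' RLE encoder, which requires Fortran-ordered uint8 arrays of shape (h, w, n). A minimal round-trip check, assuming pycocotools is installed:

import numpy as np
from pycocotools.mask import encode, decode

mask = np.zeros((4, 4, 1), dtype=np.uint8)
mask[1:3, 1:3, 0] = 1
rle = encode(np.asfortranarray(mask))  # list with one RLE dict per mask
assert (decode(rle)[:, :, 0] == mask[:, :, 0]).all()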
Example #16
        plt.axis('off')
        plt.subplots_adjust(.0, .0, 1, 1)
        plt.draw()
        plt.pause(0.5)

    free_image(im)
    free_detections(dets, num)
    return res


if __name__ == "__main__":
    Config = namedtuple('Config', ['iSz', 'oSz', 'gSz'])
    default_config = Config(iSz=160, oSz=56, gSz=160)
    model = DeepMask(default_config)
    model = load_pretrain(model, './pretrained/deepmask/DeepMask.pth.tar')
    model = model.eval().to('cuda')

    # net = load_net(b"./darknet/cfg/yolov3-tiny.cfg", b"yolov3-tiny.weights", 0)
    net = load_net(b"./darknet/cfg/yolov3.cfg", b"yolov3.weights", 0)
    meta = load_meta(b"./darknet/cfg/coco.data")
    image_files = glob.glob('./data/coco/val2017/*.jpg')
    tic = time.time()
    if VISUALIZATION:
        fig, ax = plt.subplots(1)

    for i, image_file in enumerate(image_files):
        r = detect(net, meta, image_file.encode(), mask_model=model)
        print(r)
        # plt.savefig('%05d.jpg' % i)
    toc = time.time() - tic
Example #17
def main():
    global args, logger, v_id
    args = parser.parse_args()
    cfg = load_config(args)

    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info(args)

    # setup model
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(anchors=cfg['anchors'])
    else:
        parser.error('invalid architecture: {}'.format(args.arch))

    if args.resume:
        assert isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model = load_pretrain(model, args.resume)
    model.eval()
    device = torch.device('cuda' if (
        torch.cuda.is_available() and not args.cpu) else 'cpu')
    model = model.to(device)
    # setup dataset
    dataset = load_dataset(args.dataset, args.dir_type)

    # VOS or VOT?
    if args.dataset in ['DAVIS2016', 'DAVIS2017', 'ytb_vos'] and args.mask:
        vos_enable = True  # enable Mask output
    else:
        vos_enable = False

    # total_lost = 0  # VOT
    # iou_lists = []  # VOS
    # speed_list = []

    for v_id, video in enumerate(dataset.keys(), start=1):
        if args.video != '' and video != args.video:
            continue

        if vos_enable:
            iou_list, speed = track_vos(
                model,
                dataset[video],
                cfg['hp'] if 'hp' in cfg.keys() else None,
                args.mask,
                args.refine,
                args.dataset in ['DAVIS2017', 'ytb_vos'],
                device=device)
            # iou_lists.append(iou_list)
        else:
            lost, speed = track_vot(model,
                                    dataset[video],
                                    cfg['hp'] if 'hp' in cfg.keys() else None,
                                    args.mask,
                                    args.refine,
                                    device=device)
            total_lost += lost
            plt.axis('off')
            plt.subplots_adjust(.0, .0, 1, 1)
            plt.draw()
            plt.pause(0.01)

        # if f % 3 == 1:
        #     plt.savefig('%05d.jpg' % f)

    return toc / cv2.getTickFrequency()


Example #18
if __name__ == "__main__":
    Config = namedtuple('Config', ['iSz', 'oSz', 'gSz'])
    default_config = Config(iSz=160, oSz=56, gSz=160)
    mask_net = DeepMask(default_config)
    mask_net = load_pretrain(mask_net,
                             './pretrained/deepmask/DeepMask.pth.tar')
    mask_net = mask_net.eval().cuda()

    # load net
    tracker_net = SiamRPNvot()
    tracker_net.load_state_dict(torch.load('SiamRPNVOT.model'))
    tracker_net.eval().cuda()

    image_files = sorted(glob.glob('./tracker/bag/*.jpg'))
    tic = time.time()
    if VISUALIZATION:
        try:
            fig
        except NameError:
            fig, ax = plt.subplots(1)
Example #19
                    help='ground truth txt file')
parser.add_argument('--cpu', action='store_true', help='cpu mode')
args = parser.parse_args()

if __name__ == '__main__':
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    # Setup Model
    cfg = load_config(args)
    siammask = Custom_(anchors=cfg['anchors'])
    if args.resume:
        assert isfile(args.resume), 'Please download {} first.'.format(
            args.resume)
        siammask = load_pretrain(siammask, args.resume)

    siammask.eval().to(device)

    # Parse Image file
    img_files = sorted(glob.glob(join(args.base_path, '*.jp*')))
    ims = [cv2.imread(imf) for imf in img_files]

    # Select ROI
    cv2.namedWindow("SiamMask", cv2.WND_PROP_FULLSCREEN)
    # cv2.setWindowProperty("SiamMask", cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
    x, y, w, h = 305, 112, 163, 253
    if args.gt_file:
        with open(args.base_path + '/../' + args.gt_file, "r") as f:
            gts = f.readlines()
            split_flag = ',' if ',' in gts[0] else '\t'
Example #20
def main():

    # args.base_path = base_path
    args.resume = "../SiamMask/experiments/siammask_sharp/SiamMask_DAVIS.pth"
    args.config = "../SiamMask/experiments/siammask_sharp/config_davis.json"
    print(join(args.base_path, 'groundtruth_rect.txt'))

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    # Setup Model
    cfg = load_config(args)
    p = TrackerConfig()
    p.renew()
    siammask = Tracker(p=p, anchors=cfg['anchors'])
    if args.resume:
        assert isfile(args.resume), 'Please download {} first.'.format(args.resume)
        siammask = load_pretrain(siammask, args.resume)
    siammask.eval().to(device)

    # Parse Image file
    img_files = sorted(glob.glob(join(join(args.base_path, 'imgs'), '*.jp*')))
    ims = [cv2.imread(imf) for imf in img_files]

    # Select ROI
    cv2.namedWindow("SiamMask", cv2.WND_PROP_FULLSCREEN)
    try:
        init_rect = cv2.selectROI('SiamMask', ims[0], False, False)
        gts = None
        x, y, w, h = init_rect
    except:
        exit()

    file1 = open(join(args.base_path, 'groundtruth_rect.txt'), 'w') 
    file1.write('{0:d},{1:d},{2:d},{3:d}\n'.format(x, y, w, h))

    toc = 0
    for f, im in enumerate(ims):
        tic = cv2.getTickCount()
        if f == 0:  # init
            target_pos = np.array([x + w / 2, y + h / 2])
            target_sz = np.array([w, h])
            state = tracker_init(im, target_pos, target_sz, siammask, device=device)  # init tracker
            state['gts'] = gts
            state['device'] = device
        elif f > 0:  # tracking
            state = tracker_track(state, im, siammask, device=device)  # track
            target_pos, target_sz = state['target_pos'], state['target_sz']
            x, y = (target_pos - target_sz/2).astype(int)
            x2, y2 = (target_pos + target_sz/2).astype(int)
            cv2.rectangle(im, (x, y), (x2, y2), (0, 255, 0), 4)
            cv2.imshow('SiamMask', im)
            key = cv2.waitKey(1)
            if key == ord('q'):
                break
            file1.write('{0:d},{1:d},{2:d},{3:d}\n'.format(x, y, x2-x, y2-y))
        toc += cv2.getTickCount() - tic
    file1.close() 

    toc /= cv2.getTickFrequency()
    fps = f / toc
    print('SiamMask Time: {:02.1f}s Speed: {:3.1f}fps (with visualization!)'.format(toc, fps))
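cv2.getTickCount and cv2.getTickFrequency give wall-clock timing in ticks; dividing accumulated ticks by the frequency yields seconds, so f / toc above is frames per second. The same pattern in isolation:

import cv2

t0 = cv2.getTickCount()
# ... work to be timed ...
elapsed_s = (cv2.getTickCount() - t0) / cv2.getTickFrequency()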
Example #21
def main():
    # parse the command-line arguments
    global args, logger, v_id
    args = parser.parse_args()
    # load the configuration: network structure, hyper-parameters, etc.
    cfg = load_config(args)
    # initialize logging and also write log output to a file on disk
    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)
    # record the relevant configuration in the log
    logger = logging.getLogger('global')
    logger.info(args)

    # setup model
    # build the network architecture
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(anchors=cfg['anchors'])
    else:
        parser.error('invalid architecture: {}'.format(args.arch))
    # load the network weights
    if args.resume:
        assert isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model = load_pretrain(model, args.resume)
    # evaluation mode: disables dropout and uses running batch-norm statistics
    model.eval()
    # hardware: pick the device
    device = torch.device('cuda' if (
        torch.cuda.is_available() and not args.cpu) else 'cpu')
    model = model.to(device)
    # setup dataset
    dataset = load_dataset(args.dataset)

    # these three datasets support masks: VOS or VOT?
    if args.dataset in ['DAVIS2016', 'DAVIS2017', 'ytb_vos'] and args.mask:
        vos_enable = True  # enable Mask output
    else:
        vos_enable = False

    total_lost = 0  # VOT
    iou_lists = []  # VOS
    speed_list = []
    # process each video
    for v_id, video in enumerate(dataset.keys(), start=1):
        if args.video != '' and video != args.video:
            continue
        # VOS: call track_vos
        if vos_enable:
            # multi-object tracking is enabled when the dataset is DAVIS2017 or ytb_vos
            iou_list, speed = track_vos(
                model,
                dataset[video],
                cfg['hp'] if 'hp' in cfg.keys() else None,
                args.mask,
                args.refine,
                args.dataset in ['DAVIS2017', 'ytb_vos'],
                device=device)
            iou_lists.append(iou_list)
        # VOT: call track_vot
        else:
            lost, speed = track_vot(model,
                                    dataset[video],
                                    cfg['hp'] if 'hp' in cfg.keys() else None,
                                    args.mask,
                                    args.refine,
                                    device=device)
            total_lost += lost
        speed_list.append(speed)

    # report final result
    if vos_enable:
        for thr, iou in zip(thrs, np.mean(np.concatenate(iou_lists), axis=0)):
            logger.info('Segmentation Threshold {:.2f} mIoU: {:.3f}'.format(
                thr, iou))
    else:
        logger.info('Total Lost: {:d}'.format(total_lost))

    logger.info('Mean Speed: {:.2f} FPS'.format(np.mean(speed_list)))
Example #22
def main():
    init_log('global', logging.INFO)
    logger = logging.getLogger('global')
    global args, best_recall
    args = parser.parse_args()
    cfg = load_config(args.config)

    if args.dist:
        logger.info('dist:{}'.format(args.dist))
        dist_init(args.port, backend=args.backend)

    # build dataset
    train_loader, val_loader = build_data_loader(args.dataset, cfg)
    # if args.arch == 'resnext_101_64x4d_deform_maskrcnn':
    #     model = resnext_101_64x4d_deform_maskrcnn(cfg = cfg['shared'])
    # elif args.arch == 'FishMask':
    #     model = FishMask(cfg = cfg['shared'])
    # else:
    #     if args.arch.find('fpn'):
    #         arch = args.arch.replace('fpn', '')
    #         model = resnet_fpn.__dict__[arch](pretrained=False, cfg = cfg['shared'])
    #     else:
    model = resnet.__dict__[args.arch](pretrained=False, cfg=cfg['shared'])
    logger.info('build model done')
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(trainable_params,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model, optimizer, args.start_epoch, best_recall, arch = restore_from(
            model, optimizer, args.resume)

    model = model.cuda()
    if args.dist:
        broadcast_params(model)

    logger.info('build dataloader done')
    if args.evaluate:
        rc = validate(val_loader, model, cfg)
        logger.info('recall=%f' % rc)
        return

    # warmup to enlarge lr
    if args.start_epoch == 0 and args.warmup_epochs > 0:
        world_size = 1
        try:
            world_size = dist.get_world_size()
        except Exception as e:
            print(e)
        rate = world_size * args.batch_size
        warmup_iter = args.warmup_epochs * len(train_loader)
        assert (warmup_iter > 1)
        gamma = rate**(1.0 / (warmup_iter - 1))
        lr_scheduler = IterExponentialLR(optimizer, gamma)
        for epoch in range(args.warmup_epochs):
            logger.info('warmup epoch %d' % (epoch))
            train(train_loader,
                  model,
                  lr_scheduler,
                  epoch + 1,
                  cfg,
                  warmup=True)
        # overwrite initial_lr with magnified lr through warmup
        for group in optimizer.param_groups:
            group['initial_lr'] = group['lr']
        logger.info('warmup for %d epochs done, start large batch training' %
                    args.warmup_epochs)

    lr_scheduler = MultiStepLR(optimizer,
                               milestones=args.step_epochs,
                               gamma=0.1,
                               last_epoch=args.start_epoch - 1)
    for epoch in range(args.start_epoch, args.epochs):
        logger.info('step_epochs:{}'.format(args.step_epochs))
        lr_scheduler.step()
        lr = lr_scheduler.get_lr()[0]
        # train for one epoch
        train(train_loader, model, lr_scheduler, epoch + 1, cfg)

        if (epoch + 1) % 5 == 0 or epoch + 1 == args.epochs:
            # evaluate on validation set
            recall = validate(val_loader, model, cfg)
            # remember best prec@1 and save checkpoint
            is_best = recall > best_recall
            best_recall = max(recall, best_recall)
            logger.info('recall %f(%f)' % (recall, best_recall))

        if (not args.dist) or (dist.get_rank() == 0):
            if not os.path.exists(args.save_dir):
                os.makedirs(args.save_dir)
            save_path = os.path.join(args.save_dir,
                                     'checkpoint_e%d.pth' % (epoch + 1))
            torch.save(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.cpu().state_dict(),  # note: .cpu() moves the model off the GPU and it is not moved back here
                    'best_recall': best_recall,
                    'optimizer': optimizer.state_dict(),
                }, save_path)
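In the warmup branch above, gamma is chosen so that multiplying the learning rate by gamma once per iteration scales it by exactly rate over the warmup, i.e. gamma ** (warmup_iter - 1) == rate. A quick check with sample numbers:

rate, warmup_iter = 16, 100  # e.g. 8 GPUs x batch 2, 100 warmup iterations
gamma = rate ** (1.0 / (warmup_iter - 1))
print(gamma ** (warmup_iter - 1))  # ~16.0: the lr grows by 'rate' over warmup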
Example #23
def main():
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()

    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    
    print("Init logger")

    logger = logging.getLogger('global')

    print(44)
    #logger.info("\n" + collect_env_info())
    print(99)
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    print(2)

    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    print(3)

    path = "/usr4/alg504/cliao25/siammask/experiments/siammask_base/snapshot/checkpoint_e{}.pth"

    for epoch in range(1,21):

        if args.arch == 'Custom':
            from custom import Custom
            model = Custom(pretrain=True, anchors=cfg['anchors'])
        else:
            exit()

        print(4)

        if args.pretrained:
            model = load_pretrain(model, args.pretrained)

        model = model.cuda()


        #model.features.unfix((epoch - 1) / 20)
        optimizer, lr_scheduler = build_opt_lr(model, cfg, args, epoch)
        filepath = path.format(epoch)
        assert os.path.isfile(filepath)

        model, _, _, _, _ = restore_from(model, optimizer, filepath)
        #model = load_pretrain(model, filepath)
        model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()

        model.train()
        device = torch.device('cuda')
        model = model.to(device)

        valid(val_loader, model, cfg)

    print("Done")
Example #24
def main():
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()

    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info(args)

    cfg = load_config(args)

    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))
    
    logger.info("\n" + collect_env_info())

    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(pretrain=True, anchors=cfg['anchors'])
    else:
        model = models.__dict__[args.arch](anchors=cfg['anchors'])

    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    model = model.cuda()
    dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()

    if args.resume and args.start_epoch != 0:
        model.features.unfix((args.start_epoch - 1) / args.epochs)

    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)
    logger.info(lr_scheduler)
    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(args.resume)
        model, optimizer, args.start_epoch, best_acc, arch = restore_from(model, optimizer, args.resume)
        dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()
        epoch = args.start_epoch
        if dist_model.module.features.unfix(epoch/args.epochs):
            logger.info('unfix part model.')
            optimizer, lr_scheduler = build_opt_lr(dist_model.module, cfg, args, epoch)
        lr_scheduler.step(epoch)
        cur_lr = lr_scheduler.get_cur_lr()
        logger.info('epoch:{} resume lr {}'.format(epoch, cur_lr))

    logger.info('model prepare done')

    train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch, cfg)
Example #25
def main():
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()
    args = args_process(args)

    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    args.img_size = int(cfg['train_datasets']['search_size'])
    args.nms_threshold = float(cfg['train_datasets']['RPN_NMS'])
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(pretrain=True,
                       opts=args,
                       anchors=train_loader.dataset.anchors)
    else:
        exit()
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    model = model.cuda()
    dist_model = torch.nn.DataParallel(model,
                                       list(range(
                                           torch.cuda.device_count()))).cuda()

    if args.resume and args.start_epoch != 0:
        model.features.unfix((args.start_epoch - 1) / args.epochs)

    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)
    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model, optimizer, args.start_epoch, best_acc, arch = restore_from(
            model, optimizer, args.resume)
        dist_model = torch.nn.DataParallel(
            model, list(range(torch.cuda.device_count()))).cuda()

    logger.info(lr_scheduler)

    logger.info('model prepare done')
    global cur_lr

    if not os.path.exists(args.save_dir):  # makedir/save model
        os.makedirs(args.save_dir)
    num_per_epoch = len(train_loader.dataset) // args.batch
    num_per_epoch_val = len(val_loader.dataset) // args.batch

    for epoch in range(args.start_epoch, args.epochs):
        lr_scheduler.step(epoch)
        cur_lr = lr_scheduler.get_cur_lr()
        logger = logging.getLogger('global')
        train_avg = AverageMeter()
        val_avg = AverageMeter()

        if dist_model.module.features.unfix(epoch / args.epochs):
            logger.info('unfix part model.')
            optimizer, lr_scheduler = build_opt_lr(dist_model.module, cfg,
                                                   args, epoch)

        train(train_loader, dist_model, optimizer, lr_scheduler, epoch, cfg,
              train_avg, num_per_epoch)

        if dist_model.module.features.unfix(epoch / args.epochs):
            logger.info('unfix part model.')
            optimizer, lr_scheduler = build_opt_lr(dist_model.module, cfg,
                                                   args, epoch)

        if (epoch + 1) % args.save_freq == 0:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': dist_model.module.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                    'anchor_cfg': cfg['anchors']
                }, False,
                os.path.join(args.save_dir, 'checkpoint_e%d.pth' % (epoch)),
                os.path.join(args.save_dir, 'best.pth'))

            validation(val_loader, dist_model, epoch, cfg, val_avg,
                       num_per_epoch_val)
Example #26
def process_vedio(vedio_path, initRect):
    """
    视频处理
    :param vedio_path:视频路径
    :param initRect: 跟踪目标的初始位置
    :return:
    """

    # 1. Setup device
    # use the GPU when available, otherwise the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # let cuDNN pick the fastest algorithms
    torch.backends.cudnn.benchmark = True

    # 2. Setup Model
    # 2.1 parse the configuration from the command-line arguments
    cfg = load_config(args)

    # 2.2 Custom is the purpose-built network; otherwise a structure from the model zoo would be used
    from custom import Custom
    siammask = Custom(anchors=cfg['anchors'])
    # 2.3 check that the model weight file exists
    if args.resume:
        assert isfile(args.resume), 'Please download {} first.'.format(
            args.resume)
        siammask = load_pretrain(siammask, args.resume)
    # call model.eval() before inference to put the dropout and batch-normalization
    # layers into evaluation (non-training) mode; .to(device) copies the tensors to
    # the GPU, where subsequent computation runs
    siammask.eval().to(device)

    # position of the tracking target in the first frame
    x, y, w, h = initRect
    print(x)
    VeryBig = 999999999  # used to maximize the video window
    Cap = cv2.VideoCapture(vedio_path)  # open the video stream
    ret, frame = Cap.read()  # read a frame
    ims = [frame]  # wrap the frame in a list, since the original code keeps every frame in a list

    im = frame
    f = 0
    target_pos = np.array([x + w / 2, y + h / 2])
    target_sz = np.array([w, h])
    state = siamese_init(im, target_pos, target_sz, siammask,
                         cfg['hp'])  # init tracker
    middlepath = "../data/middle.mp4"
    outpath = "../data/output.mp4"
    vediowriter = cv2.VideoWriter(middlepath,
                                  cv2.VideoWriter_fourcc('M', 'P', '4', 'V'),
                                  10, (320, 240))
    while True:
        tic = cv2.getTickCount()
        ret, im = Cap.read()  # read frames one at a time
        if not ret:
            break
        state = siamese_track(state, im, mask_enable=True,
                              refine_enable=True)  # track
        location = state['ploygon'].flatten()  # 'ploygon' is the key's actual spelling in the tracker state
        mask = state['mask'] > state['p'].seg_thr
        im[:, :, 2] = (mask > 0) * 255 + (mask == 0) * im[:, :, 2]
        cv2.polylines(im, [np.int0(location).reshape((-1, 1, 2))], True,
                      (0, 255, 0), 3)
        vediowriter.write(im)
        cv2.imshow('SiamMask', im)
        key = cv2.waitKey(1)
        if key > 0:
            break

        f = f + 1
    vediowriter.release()

    return
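The overlay line im[:, :, 2] = (mask > 0) * 255 + (mask == 0) * im[:, :, 2] sets the red channel (OpenCV stores images as BGR) to 255 inside the mask and leaves it untouched outside. The same blend in a tiny standalone form:

import numpy as np

im = np.full((2, 2, 3), 100, dtype=np.uint8)
mask = np.array([[1, 0], [0, 1]], dtype=bool)
im[:, :, 2] = mask * 255 + (~mask) * im[:, :, 2]
print(im[:, :, 2])  # [[255 100] [100 255]]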
Example #27
def main():
    init_log('global', logging.INFO)
    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    params = {'penalty_k': args.penalty_k,
              'window_influence': args.window_influence,
              'lr': args.lr,
              'instance_size': args.search_region}

    num_search = len(params['penalty_k']) * len(params['window_influence']) * \
        len(params['lr']) * len(params['instance_size'])

    print(params)
    print(num_search)

    cfg = load_config(args)
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(anchors=cfg['anchors'])
    else:
        model = models.__dict__[args.arch](anchors=cfg['anchors'])

    if args.resume:
        assert isfile(args.resume), '{} is not a valid file'.format(args.resume)
        model = load_pretrain(model, args.resume)
    model.eval()
    model = model.to(device)

    default_hp = cfg.get('hp', {})

    p = dict()

    p['network'] = model
    p['network_name'] = args.arch+'_'+args.resume.split('/')[-1].split('.')[0]
    p['dataset'] = args.dataset

    global ims, gt, image_files

    dataset_info = load_dataset(args.dataset)
    videos = list(dataset_info.keys())
    np.random.shuffle(videos)

    for video in videos:
        print(video)
        if isfile('finish.flag'):
            return

        p['video'] = video
        ims = None
        image_files = dataset_info[video]['image_files']
        gt = dataset_info[video]['gt']

        np.random.shuffle(params['penalty_k'])
        np.random.shuffle(params['window_influence'])
        np.random.shuffle(params['lr'])
        for penalty_k in params['penalty_k']:
            for window_influence in params['window_influence']:
                for lr in params['lr']:
                    for instance_size in params['instance_size']:
                        p['hp'] = default_hp.copy()
                        p['hp'].update({'penalty_k':penalty_k,
                                'window_influence':window_influence,
                                'lr':lr,
                                'instance_size': instance_size,
                                })
                        tune(p)
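The four nested loops above enumerate the full Cartesian product of hyper-parameter values. An equivalent, flatter form using itertools.product — a stylistic sketch, not a change to the tuner:

from itertools import product

for penalty_k, window_influence, lr, instance_size in product(
        params['penalty_k'], params['window_influence'],
        params['lr'], params['instance_size']):
    p['hp'] = default_hp.copy()
    p['hp'].update({'penalty_k': penalty_k,
                    'window_influence': window_influence,
                    'lr': lr,
                    'instance_size': instance_size})
    tune(p)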
Example #28
def main():
    """
    基础网络的训练
    :return:
    """
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()
    # initialize logging
    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)
    # get the logger
    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)
    # load the configuration
    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()

    # build the dataset
    train_loader, val_loader = build_data_loader(cfg)
    # build the training network
    if args.arch == 'Custom':
        from custom import Custom
        model = Custom(pretrain=True, anchors=cfg['anchors'])
    else:
        exit()
    logger.info(model)
    # load pretrained weights
    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    # GPU version:
    # model = model.cuda()
    # dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()
    # network model (CPU DataParallel wrapper)
    dist_model = torch.nn.DataParallel(model)
    # when resuming, unfreeze the backbone in proportion to the resumed epoch
    if args.resume and args.start_epoch != 0:
        model.features.unfix((args.start_epoch - 1) / args.epochs)
    # build the optimizer and learning-rate schedule
    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)
    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model, optimizer, args.start_epoch, best_acc, arch = restore_from(
            model, optimizer, args.resume)
        # GPU
        # dist_model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))).cuda()
        dist_model = torch.nn.DataParallel(model)

    logger.info(lr_scheduler)

    logger.info('model prepare done')
    # train the model
    train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch,
          cfg)
Example #29
def main():
    global args, best_acc, tb_writer, logger
    args = parser.parse_args()

    init_log('global', logging.INFO)

    if args.log != "":
        add_file_handler('global', args.log, logging.INFO)

    logger = logging.getLogger('global')
    logger.info("\n" + collect_env_info())
    logger.info(args)

    cfg = load_config(args)
    logger.info("config \n{}".format(json.dumps(cfg, indent=4)))

    if args.log_dir:
        tb_writer = SummaryWriter(args.log_dir)
    else:
        tb_writer = Dummy()

    # build dataset
    train_loader, val_loader = build_data_loader(cfg)

    if args.arch == 'Custom':
        model = Custom(anchors=cfg['anchors'])
    elif args.arch == 'Custom_Sky':
        model = Custom_Sky(anchors=cfg['anchors'])
    else:
        exit()
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    # print(summary(model=model, input_size=(3, 511, 511), batch_size=1))
    model = model.cuda()
    dist_model = torch.nn.DataParallel(model,
                                       list(range(
                                           torch.cuda.device_count()))).cuda()

    if args.resume and args.start_epoch != 0:
        model.features.unfix((args.start_epoch - 1) / args.epochs)

    optimizer, lr_scheduler = build_opt_lr(model, cfg, args, args.start_epoch)
    # optionally resume from a checkpoint
    if args.resume:
        print(args.resume)
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model, optimizer, args.start_epoch, best_acc, arch = restore_from(
            model, optimizer, args.resume)
        dist_model = torch.nn.DataParallel(
            model, list(range(torch.cuda.device_count()))).cuda()

    logger.info(lr_scheduler)

    logger.info('model prepare done')

    train(train_loader, dist_model, optimizer, lr_scheduler, args.start_epoch,
          cfg)
Example #30
def main():
    init_log('global', logging.INFO)
    logger = logging.getLogger('global')
    global args, best_recall, best_map
    args = parser.parse_args()
    cfg = load_config(args.config)

    if args.dist:
        logger.info('dist:{}'.format(args.dist))
        dist_init(args.port, backend=args.backend)

    # build dataset
    train_loader, val_loader, target_loader = build_data_loader(
        args.dataset, cfg)

    if args.arch == 'vgg16_FasterRCNN':
        model = vgg16_FasterRCNN(pretrained=False, cfg=cfg['shared'])
    elif args.arch == 'vgg16bn_FasterRCNN':
        model = vgg16bn_FasterRCNN(pretrained=False, cfg=cfg['shared'])
    else:
        logger.info("The arch is not in model zoo")
        exit()
    logger.info('build model done')
    logger.info(model)

    if args.pretrained:
        model = load_pretrain(model, args.pretrained)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params,
                                 args.lr,
                                 betas=(0.9, 0.999),
                                 weight_decay=0.0001)

    #=============build gan part===============
    dis_model, dec_model, dis_model_patch = builder_gan(args)
    trainable_params_dec = [
        p for p in dec_model.parameters() if p.requires_grad
    ]
    trainable_params_dis = [
        p for p in dis_model.parameters() if p.requires_grad
    ]
    trainable_params_dis_patch = [
        p for p in dis_model_patch.parameters() if p.requires_grad
    ]
    dis_optimizer = torch.optim.Adam(trainable_params_dis,
                                     args.lr,
                                     betas=(0.9, 0.999),
                                     weight_decay=0.0001)
    dis_patch_optimizer = torch.optim.Adam(trainable_params_dis_patch,
                                           args.lr,
                                           betas=(0.9, 0.999),
                                           weight_decay=0.0001)
    dec_optimizer = torch.optim.Adam(trainable_params_dec,
                                     args.lr,
                                     betas=(0.9, 0.999),
                                     weight_decay=0.0001)

    # optionally resume from a checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), '{} is not a valid file'.format(
            args.resume)
        model, optimizer, args.start_epoch, best_recall, arch = restore_from(
            model, optimizer, args.resume)

    model = model.cuda()
    dis_model = dis_model.cuda()
    dec_model = dec_model.cuda()
    dis_model_patch = dis_model_patch.cuda()
    if args.dist:
        # broadcast_params([model, dis_model, dec_model])
        broadcast_params(model)
        broadcast_params(dis_model)
        broadcast_params(dec_model)
        broadcast_params(dis_model_patch)

    logger.info('build dataloader done')
    if args.evaluate:
        if args.dist:
            rc = validate(val_loader, model, cfg)
            logger.info('recall=%f' % rc)
            return
        else:
            rc = validate_single(val_loader, model, cfg)
            logger.info('recall=%f' % rc)
            return

    # warmup to enlarge lr
    if args.start_epoch == 0 and args.warmup_epochs > 0:
        world_size = 1
        try:
            world_size = dist.get_world_size()
        except Exception as e:
            print(e)
        rate = world_size * args.batch_size
        warmup_iter = args.warmup_epochs * len(train_loader)
        assert (warmup_iter > 1)
        gamma = rate**(1.0 / (warmup_iter - 1))
        lr_scheduler = IterExponentialLR(optimizer, gamma)
        lr_scheduler_dis = IterExponentialLR(dis_optimizer, gamma)
        lr_scheduler_dis_patch = IterExponentialLR(dis_patch_optimizer, gamma)

        lr_scheduler_dec = IterExponentialLR(dec_optimizer, gamma)
        for epoch in range(args.warmup_epochs):
            logger.info('warmup epoch %d' % (epoch))
            train(train_loader,
                  target_loader,
                  val_loader,
                  model,
                  dec_model,
                  dis_model,
                  dis_model_patch,
                  lr_scheduler,
                  lr_scheduler_dec,
                  lr_scheduler_dis,
                  lr_scheduler_dis_patch,
                  epoch + 1,
                  cfg,
                  warmup=True)
        # overwrite initial_lr with magnified lr through warmup
        for group in optimizer.param_groups + dis_optimizer.param_groups + dec_optimizer.param_groups + dis_patch_optimizer.param_groups:
            group['initial_lr'] = group['lr']
        logger.info('warmup for %d epochs done, start large batch training' %
                    args.warmup_epochs)

    lr_scheduler = MultiStepLR(optimizer,
                               milestones=args.step_epochs,
                               gamma=0.1,
                               last_epoch=args.start_epoch - 1)
    lr_scheduler_dis = MultiStepLR(dis_optimizer,
                                   milestones=args.step_epochs,
                                   gamma=0.1,
                                   last_epoch=args.start_epoch - 1)
    lr_scheduler_dec = MultiStepLR(dec_optimizer,
                                   milestones=args.step_epochs,
                                   gamma=0.1,
                                   last_epoch=args.start_epoch - 1)
    lr_scheduler_dis_patch = MultiStepLR(dis_patch_optimizer,
                                         milestones=args.step_epochs,
                                         gamma=0.1,
                                         last_epoch=args.start_epoch - 1)

    for epoch in range(args.start_epoch, args.epochs):
        logger.info('step_epochs:{}'.format(args.step_epochs))
        lr_scheduler.step()
        lr_scheduler_dis.step()
        lr_scheduler_dec.step()
        lr_scheduler_dis_patch.step()
        lr = lr_scheduler.get_lr()[0]
        # train for one epoch

        train(train_loader, target_loader, val_loader, model, dec_model,
              dis_model, dis_model_patch, lr_scheduler, lr_scheduler_dec,
              lr_scheduler_dis, lr_scheduler_dis_patch, epoch + 1, cfg)

        if (epoch + 1) % args.eval_interval == 0 or epoch + 1 == args.epochs:
            # evaluate on validation set
            recall = validate(val_loader, model, cfg)
            # remember best prec@1 and save checkpoint
            is_best = recall > best_recall
            best_recall = max(recall, best_recall)
            logger.info('recall %f(%f)' % (recall, best_recall))

        if (not args.dist) or (dist.get_rank() == 0):
            if not os.path.exists(args.save_dir):
                os.makedirs(args.save_dir)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.cpu().state_dict(),  # note: .cpu() moves the model off the GPU and it is not moved back here
                    'best_recall': best_recall,
                    'optimizer': optimizer.state_dict(),
                }, False,
                os.path.join(args.save_dir,
                             'checkpoint_e%d.pth' % (epoch + 1)))