def main():
    """Train the shift-code segmentation model and validate it periodically.

    All settings are read from the module-level ``cfg``. The function builds
    the train/val data loaders, creates the network selected by ``cfg.arch``,
    optionally loads pretrained weights, then runs ``cfg.num_epochs`` epochs.
    Validation and checkpointing happen every ``cfg.val_interval`` epochs
    (and additionally at epoch 3).

    Raises:
        NotImplementedError: if ``cfg.arch`` names neither a 'hourglass' nor
            a 'resdcn' architecture.
    """
    saver = create_saver(cfg.local_rank, save_dir=cfg.ckpt_dir)
    logger = create_logger(cfg.local_rank, save_dir=cfg.log_dir)
    summary_writer = create_summary(cfg.local_rank, log_dir=cfg.log_dir)
    print = logger.info  # every print() below goes through the logger
    print(cfg)

    torch.manual_seed(319)
    torch.backends.cudnn.benchmark = True  # disable this if OOM at beginning of training

    num_gpus = torch.cuda.device_count()
    if cfg.dist:
        cfg.device = torch.device('cuda:%d' % cfg.local_rank)
        torch.cuda.set_device(cfg.local_rank)
        # world_size is the local GPU count and rank is the local rank.
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=num_gpus,
                                rank=cfg.local_rank)
    else:
        cfg.device = torch.device('cuda:%d' % cfg.device_id)

    print('Setting up data...')
    dictionary = np.load(cfg.dictionary_file)
    Dataset = COCOSEGMSHIFT if cfg.dataset == 'coco' else PascalVOC
    train_dataset = Dataset(cfg.data_dir,
                            cfg.dictionary_file,
                            'train',
                            split_ratio=cfg.split_ratio,
                            img_size=cfg.img_size)
    # The sampler is always built (set_epoch is called on it every epoch) but
    # it is only handed to the loader in distributed mode.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=num_gpus, rank=cfg.local_rank)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.batch_size // num_gpus if cfg.dist else cfg.batch_size,
        shuffle=not cfg.dist,
        num_workers=cfg.num_workers,
        pin_memory=False,
        drop_last=True,
        sampler=train_sampler if cfg.dist else None)

    Dataset_eval = COCO_eval_segm_shift if cfg.dataset == 'coco' else PascalVOC_eval
    val_dataset = Dataset_eval(cfg.data_dir,
                               cfg.dictionary_file,
                               'val',
                               test_scales=[1.],
                               test_flip=False)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=1,
                                             pin_memory=False,
                                             collate_fn=val_dataset.collate_fn)

    print('Creating model...')
    if 'hourglass' in cfg.arch:
        model = get_hourglass[cfg.arch]
    elif 'resdcn' in cfg.arch:
        model = get_pose_net(num_layers=int(cfg.arch.split('_')[-1]),
                             num_classes=train_dataset.num_classes)
    else:
        raise NotImplementedError

    if cfg.dist:
        # model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = model.to(cfg.device)
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[
                cfg.local_rank,
            ], output_device=cfg.local_rank)
    else:
        # NOTE(review): device_ids uses cfg.local_rank while cfg.device was
        # built from cfg.device_id above -- confirm both always agree.
        model = nn.DataParallel(model, device_ids=[
            cfg.local_rank,
        ]).to(cfg.device)

    if cfg.pretrain_checkpoint is not None and os.path.isfile(
            cfg.pretrain_checkpoint):
        print('Load pretrain model from ' + cfg.pretrain_checkpoint)
        model = load_model(model, cfg.pretrain_checkpoint, cfg.device_id)
        torch.cuda.empty_cache()

    optimizer = torch.optim.Adam(model.parameters(), cfg.lr)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        cfg.lr_step,
                                                        gamma=0.1)

    def train(epoch):
        """Run one training epoch, logging losses every cfg.log_interval batches."""
        print('\n Epoch: %d' % epoch)
        model.train()
        tic = time.perf_counter()
        for batch_idx, batch in enumerate(train_loader):
            # Move every entry except the 'meta' record to the training device.
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=cfg.device,
                                           non_blocking=True)

            outputs = model(batch['image'])
            # Regroup the per-output tuples into one sequence per head.
            hmap, regs, w_h_, codes_ = zip(*outputs)

            regs = [
                _tranpose_and_gather_feature(r, batch['inds']) for r in regs
            ]
            w_h_ = [
                _tranpose_and_gather_feature(r, batch['inds']) for r in w_h_
            ]
            codes_ = [
                _tranpose_and_gather_feature(r, batch['inds']) for r in codes_
            ]

            hmap_loss = _neg_loss(hmap, batch['hmap'])
            reg_loss = _reg_loss(regs, batch['regs'], batch['ind_masks'])
            w_h_loss = _reg_loss(w_h_, batch['w_h_'], batch['ind_masks'])
            codes_loss = norm_reg_loss(codes_, batch['codes'],
                                       batch['ind_masks'])
            # codes_loss = mse_reg_loss(codes_, batch['codes'], batch['ind_masks'])
            loss = hmap_loss + 1 * reg_loss + 0.1 * w_h_loss + cfg.code_loss_weight * codes_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % cfg.log_interval == 0:
                duration = time.perf_counter() - tic
                tic = time.perf_counter()
                print(
                    '[%d/%d-%d/%d] ' %
                    (epoch, cfg.num_epochs, batch_idx, len(train_loader)) +
                    ' hmap_loss= %.3f reg_loss= %.3f w_h_loss= %.3f  code_loss= %.3f'
                    % (hmap_loss.item(), reg_loss.item(), w_h_loss.item(),
                       codes_loss.item()) + ' (%d samples/sec)' %
                    (cfg.batch_size * cfg.log_interval / duration))

                step = len(train_loader) * epoch + batch_idx
                summary_writer.add_scalar('hmap_loss', hmap_loss.item(), step)
                summary_writer.add_scalar('reg_loss', reg_loss.item(), step)
                summary_writer.add_scalar('w_h_loss', w_h_loss.item(), step)
                summary_writer.add_scalar('code_loss', codes_loss.item(), step)
        return

    def val_map(epoch):
        """Run the model over the validation set and log the evaluation result."""
        print('\n Val@Epoch: %d' % epoch)
        model.eval()
        torch.cuda.empty_cache()
        max_per_image = 100

        # Columns [0, num_coords) are coordinate pairs; column num_coords
        # holds the score and the last column the class index.
        num_pairs = cfg.n_vertices + 2
        num_coords = num_pairs * 2
        # The dictionary never changes during validation, so convert and
        # upload it once instead of once per image and scale.
        dictionary_tensor = torch.from_numpy(
            dictionary.astype(np.float32)).to(cfg.device)

        results = {}
        input_scales = {}
        speed_list = []
        with torch.no_grad():
            for inputs in val_loader:
                # batch_size is 1, so the collated batch has a single entry.
                img_id, inputs = inputs[0]
                start_image_time = time.time()
                segmentations = []
                for scale in inputs:
                    sample = inputs[scale]
                    sample['image'] = sample['image'].to(cfg.device)
                    if scale == 1. and img_id not in input_scales:
                        # keep track of the input image sizes
                        _, _, input_h, input_w = sample['image'].shape
                        input_scales[img_id] = {'h': input_h, 'w': input_w}

                    output = model(sample['image'])[-1]

                    segms = ctsegm_shift_decode(*output,
                                                dictionary_tensor,
                                                K=cfg.test_topk)
                    segms = segms.detach().cpu().numpy().reshape(
                        1, -1, segms.shape[2])[0]

                    # Map every (x, y) pair from feature-map coordinates back
                    # through transform_preds: the n_vertices pairs followed
                    # by the two extra pairs (presumably box corners --
                    # confirm against ctsegm_shift_decode).
                    fmap_size = (sample['fmap_w'], sample['fmap_h'])
                    for j in range(num_pairs):
                        cols = slice(2 * j, 2 * j + 2)
                        segms[:, cols] = transform_preds(
                            segms[:, cols], sample['center'], sample['scale'],
                            fmap_size)

                    # Split predictions by class; result keys are 1-based.
                    top_preds = {}
                    clses = segms[:, -1]
                    for j in range(val_dataset.num_classes):
                        inds = (clses == j)
                        top_preds[j + 1] = segms[inds, :num_coords +
                                                 1].astype(np.float32)
                        top_preds[j + 1][:, :num_coords] /= scale

                    segmentations.append(top_preds)

                end_image_time = time.time()
                segms_and_scores = {
                    j: np.concatenate([d[j] for d in segmentations], axis=0)
                    for j in range(1, val_dataset.num_classes + 1)
                }
                scores = np.hstack([
                    segms_and_scores[j][:, num_coords]
                    for j in range(1, val_dataset.num_classes + 1)
                ])
                # Keep only the max_per_image highest-scoring predictions.
                if len(scores) > max_per_image:
                    kth = len(scores) - max_per_image
                    thresh = np.partition(scores, kth)[kth]
                    for j in range(1, val_dataset.num_classes + 1):
                        keep_inds = (segms_and_scores[j][:, num_coords] >=
                                     thresh)
                        segms_and_scores[j] = segms_and_scores[j][keep_inds]

                results[img_id] = segms_and_scores
                speed_list.append(end_image_time - start_image_time)

        eval_results = val_dataset.run_eval(results,
                                            input_scales,
                                            save_dir=cfg.ckpt_dir)
        print(eval_results)
        summary_writer.add_scalar('val_mAP/mAP', eval_results[0], epoch)
        print('Average speed on val set:{:.2f}'.format(1. /
                                                       np.mean(speed_list)))

    print('Starting training...')
    for epoch in range(1, cfg.num_epochs + 1):
        start = time.time()
        train_sampler.set_epoch(epoch)
        train(epoch)
        if (cfg.val_interval > 0
                and epoch % cfg.val_interval == 0) or epoch == 3:
            val_map(epoch)
            print(saver.save(model.module.state_dict(), 'checkpoint'))
        # Called once per epoch after the optimizer updates. The explicit
        # epoch argument is deprecated; epochs run 1..N in order, so the
        # scheduler's own counter yields the same schedule.
        lr_scheduler.step()

        epoch_time = (time.time() - start) / 3600. / 24.
        print('ETA:{:.2f} Days'.format((cfg.num_epochs - epoch) * epoch_time))

    summary_writer.close()
# Example #2
# 0
def main():
    """Train the box-detection model and validate it periodically.

    All settings are read from the module-level ``cfg``. The function builds
    the train/val data loaders, creates the network selected by ``cfg.arch``,
    optionally loads pretrained weights, then runs ``cfg.num_epochs`` epochs,
    saving a checkpoint after every epoch and validating every
    ``cfg.val_interval`` epochs.

    Raises:
        NotImplementedError: if ``cfg.arch`` names neither a 'hourglass' nor
            a 'resdcn' architecture.
    """
    saver = create_saver(cfg.local_rank, save_dir=cfg.ckpt_dir)
    logger = create_logger(cfg.local_rank, save_dir=cfg.log_dir)
    summary_writer = create_summary(cfg.local_rank, log_dir=cfg.log_dir)
    print = logger.info  # every print() below goes through the logger
    print(cfg)

    torch.manual_seed(317)
    # disable this if OOM at beginning of training
    torch.backends.cudnn.benchmark = True

    num_gpus = torch.cuda.device_count()
    if cfg.dist:
        cfg.device = torch.device('cuda:%d' % cfg.local_rank)
        torch.cuda.set_device(cfg.local_rank)
        # world_size is the local GPU count and rank is the local rank.
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=num_gpus,
                                rank=cfg.local_rank)
    else:
        cfg.device = torch.device('cuda')

    print('Setting up data...')
    Dataset = COCO if cfg.dataset == 'coco' else PascalVOC

    train_dataset = Dataset(cfg.data_dir,
                            'train',
                            split_ratio=cfg.split_ratio,
                            img_size=cfg.img_size)

    # The sampler is always built (set_epoch is called on it every epoch) but
    # it is only handed to the loader in distributed mode.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=num_gpus, rank=cfg.local_rank)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.batch_size // num_gpus if cfg.dist else cfg.batch_size,
        shuffle=not cfg.dist,
        num_workers=cfg.num_workers,
        pin_memory=True,
        drop_last=True,
        sampler=train_sampler if cfg.dist else None)

    Dataset_eval = COCO_eval if cfg.dataset == 'coco' else PascalVOC_eval

    val_dataset = Dataset_eval(cfg.data_dir,
                               'test',
                               test_scales=[1.],
                               test_flip=False)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=1,
                                             pin_memory=True,
                                             collate_fn=val_dataset.collate_fn)

    print('Creating model...')
    if 'hourglass' in cfg.arch:
        model = get_hourglass[cfg.arch]
    elif 'resdcn' in cfg.arch:
        model = get_pose_net(num_layers=int(cfg.arch.split('_')[-1]),
                             num_classes=train_dataset.num_classes)
    else:
        raise NotImplementedError

    if cfg.dist:
        # model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = model.to(cfg.device)
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[
                cfg.local_rank,
            ], output_device=cfg.local_rank)
    else:
        model = nn.DataParallel(model).to(cfg.device)

    # os.path.isfile raises TypeError on None, so check the value is set
    # first; an empty or missing path simply skips loading.
    if cfg.pretrain_dir and os.path.isfile(cfg.pretrain_dir):
        model = load_model(model, cfg.pretrain_dir)

    optimizer = torch.optim.Adam(model.parameters(), cfg.lr)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        cfg.lr_step,
                                                        gamma=0.1)

    def train(epoch):
        """Run one training epoch, logging losses every cfg.log_interval batches."""
        print('\n Epoch: %d' % epoch)
        model.train()
        tic = time.perf_counter()
        for batch_idx, batch in enumerate(train_loader):
            # Move every entry except the 'meta' record to the training device.
            for k in batch:
                if k != 'meta':
                    batch[k] = batch[k].to(device=cfg.device,
                                           non_blocking=True)

            outputs = model(batch['image'])

            # Unpack the three outputs: heat map, reg and wh.
            hmap, regs, w_h_ = zip(*outputs)

            regs = [
                _tranpose_and_gather_feature(r, batch['inds']) for r in regs
            ]
            w_h_ = [
                _tranpose_and_gather_feature(r, batch['inds']) for r in w_h_
            ]

            # Compute each loss separately.
            hmap_loss = _neg_loss(hmap, batch['hmap'])
            reg_loss = _reg_loss(regs, batch['regs'], batch['ind_masks'])
            w_h_loss = _reg_loss(w_h_, batch['w_h_'], batch['ind_masks'])

            # Weight the individual losses to obtain the final loss.
            loss = hmap_loss + 1 * reg_loss + 0.1 * w_h_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % cfg.log_interval == 0:
                duration = time.perf_counter() - tic
                tic = time.perf_counter()
                print('[%d/%d-%d/%d] ' %
                      (epoch, cfg.num_epochs, batch_idx, len(train_loader)) +
                      ' hmap_loss= %.5f reg_loss= %.5f w_h_loss= %.5f' %
                      (hmap_loss.item(), reg_loss.item(), w_h_loss.item()) +
                      ' (%d samples/sec)' %
                      (cfg.batch_size * cfg.log_interval / duration))

                step = len(train_loader) * epoch + batch_idx
                summary_writer.add_scalar('hmap_loss', hmap_loss.item(), step)
                summary_writer.add_scalar('reg_loss', reg_loss.item(), step)
                summary_writer.add_scalar('w_h_loss', w_h_loss.item(), step)
        return

    def val_map(epoch):
        """Run the model over the validation set and log the evaluation result."""
        print('\n Val@Epoch: %d' % epoch)
        model.eval()
        torch.cuda.empty_cache()
        max_per_image = 100

        results = {}
        with torch.no_grad():
            for inputs in val_loader:
                # batch_size is 1, so the collated batch has a single entry.
                img_id, inputs = inputs[0]

                detections = []
                for scale in inputs:
                    sample = inputs[scale]
                    sample['image'] = sample['image'].to(cfg.device)
                    output = model(sample['image'])[-1]

                    dets = ctdet_decode(*output, K=cfg.test_topk)
                    dets = dets.detach().cpu().numpy().reshape(
                        1, -1, dets.shape[2])[0]

                    # Columns 0:2 and 2:4 are each passed through
                    # transform_preds with the sample's center/scale.
                    fmap_size = (sample['fmap_w'], sample['fmap_h'])
                    for lo in (0, 2):
                        dets[:, lo:lo + 2] = transform_preds(
                            dets[:, lo:lo + 2], sample['center'],
                            sample['scale'], fmap_size)

                    # Split detections by class; result keys are 1-based.
                    top_preds = {}
                    clses = dets[:, -1]
                    for j in range(val_dataset.num_classes):
                        inds = (clses == j)
                        top_preds[j + 1] = dets[inds, :5].astype(np.float32)
                        top_preds[j + 1][:, :4] /= scale

                    detections.append(top_preds)

                bbox_and_scores = {
                    j: np.concatenate([d[j] for d in detections], axis=0)
                    for j in range(1, val_dataset.num_classes + 1)
                }
                scores = np.hstack([
                    bbox_and_scores[j][:, 4]
                    for j in range(1, val_dataset.num_classes + 1)
                ])
                # Keep only the max_per_image highest-scoring detections.
                if len(scores) > max_per_image:
                    kth = len(scores) - max_per_image
                    thresh = np.partition(scores, kth)[kth]
                    for j in range(1, val_dataset.num_classes + 1):
                        keep_inds = (bbox_and_scores[j][:, 4] >= thresh)
                        bbox_and_scores[j] = bbox_and_scores[j][keep_inds]

                results[img_id] = bbox_and_scores

        eval_results = val_dataset.run_eval(results, save_dir=cfg.ckpt_dir)
        print(eval_results)
        summary_writer.add_scalar('val_mAP/mAP', eval_results[0], epoch)

    print('Starting training...')
    for epoch in range(1, cfg.num_epochs + 1):
        train_sampler.set_epoch(epoch)
        train(epoch)
        if cfg.val_interval > 0 and epoch % cfg.val_interval == 0:
            val_map(epoch)
        print(saver.save(model.module.state_dict(), 'checkpoint'))
        # Called once per epoch after the optimizer updates. The explicit
        # epoch argument is deprecated; epochs run 1..N in order, so the
        # scheduler's own counter yields the same schedule.
        lr_scheduler.step()

    summary_writer.close()
def main():
    """Evaluate pretrained weights on the 'val' split and log the result.

    Settings come from the module-level ``cfg``. Each image is run at every
    configured test scale; predictions are grouped per class (1-based keys),
    merged across scales, trimmed to the 100 best scores and handed to the
    dataset's ``run_eval``.

    Raises:
        NotImplementedError: if ``cfg.arch`` names neither a 'hourglass' nor
            a 'resdcn' architecture.
    """
    log = create_logger(save_dir=cfg.log_dir).info
    log(cfg)

    cfg.device = torch.device('cuda')
    torch.backends.cudnn.benchmark = False

    top_k_per_image = 100

    eval_cls = COCO_eval if cfg.dataset == 'coco' else PascalVOC_eval
    eval_set = eval_cls(cfg.data_dir,
                        split='val',
                        img_size=cfg.img_size,
                        test_scales=cfg.test_scales,
                        test_flip=cfg.test_flip)
    loader = torch.utils.data.DataLoader(eval_set,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=1,
                                         pin_memory=False,
                                         collate_fn=eval_set.collate_fn)

    log('Creating model...')
    if 'hourglass' in cfg.arch:
        net = get_hourglass[cfg.arch]
    elif 'resdcn' in cfg.arch:
        depth = int(cfg.arch.split('_')[-1])
        net = get_pose_net(num_layers=depth,
                           num_classes=eval_set.num_classes)
    else:
        raise NotImplementedError

    net = load_model(net, cfg.pretrain_dir)
    net = net.to(cfg.device)
    net.eval()

    all_results = {}
    with torch.no_grad():
        for batch in loader:
            # batch_size is 1, so the collated batch has a single entry.
            img_id, per_scale = batch[0]

            preds_by_scale = []
            for scale in per_scale:
                sample = per_scale[scale]
                sample['image'] = sample['image'].to(cfg.device)

                heads = model_out = net(sample['image'])[-1]
                decoded = ctdet_decode(*heads, K=cfg.test_topk)
                boxes = decoded.detach().cpu().numpy().reshape(
                    1, -1, decoded.shape[2])[0]

                # Columns 0:2 and 2:4 are each passed through
                # transform_preds with the sample's center/scale.
                fmap_size = (sample['fmap_w'], sample['fmap_h'])
                for lo in (0, 2):
                    boxes[:, lo:lo + 2] = transform_preds(
                        boxes[:, lo:lo + 2], sample['center'],
                        sample['scale'], fmap_size)

                # Split by class index (last column); keys are 1-based.
                labels = boxes[:, -1]
                by_class = {}
                for c in range(eval_set.num_classes):
                    picked = boxes[labels == c, :5].astype(np.float32)
                    picked[:, :4] /= scale
                    by_class[c + 1] = picked

                preds_by_scale.append(by_class)

            merged = {}
            for c in range(1, eval_set.num_classes + 1):
                merged[c] = np.concatenate([p[c] for p in preds_by_scale],
                                           axis=0)
                if len(eval_set.test_scales) > 1:
                    # Return value is ignored; presumably soft_nms updates
                    # the array in place -- confirm.
                    soft_nms(merged[c], Nt=0.5, method=2)
            all_scores = np.hstack([
                merged[c][:, 4] for c in range(1, eval_set.num_classes + 1)
            ])

            # Keep only the top_k_per_image highest-scoring detections.
            if len(all_scores) > top_k_per_image:
                kth = len(all_scores) - top_k_per_image
                cutoff = np.partition(all_scores, kth)[kth]
                for c in range(1, eval_set.num_classes + 1):
                    merged[c] = merged[c][merged[c][:, 4] >= cutoff]

            all_results[img_id] = merged

    metrics = eval_set.run_eval(all_results, cfg.ckpt_dir)
    log(metrics)
# Example #4
# 0
def main():
    """Run detection on a single image and save/show the annotated result.

    Settings come from the module-level ``cfg``: the image is read from
    ``cfg.img_dir``, weights are loaded from ``cfg.ckpt_dir``, and the image
    is evaluated at every scale in ``cfg.test_scales``. Boxes scoring above
    0.3 are drawn and the figure is written to 'data/demo_results.png'.

    Raises:
        OSError: if the image at ``cfg.img_dir`` cannot be read.
        NotImplementedError: if ``cfg.arch`` names neither a 'hourglass' nor
            a 'resdcn' architecture.
    """
    cfg.device = torch.device('cuda')
    torch.backends.cudnn.benchmark = False

    max_per_image = 100
    # One class count for the whole function (80 for COCO, otherwise 20).
    num_classes = 80 if cfg.dataset == 'coco' else 20

    image = cv2.imread(cfg.img_dir)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError('Could not read image: %s' % cfg.img_dir)
    # orig_image = image
    height, width = image.shape[0:2]
    padding = 127 if 'hourglass' in cfg.arch else 31
    imgs = {}
    for scale in cfg.test_scales:
        new_height = int(height * scale)
        new_width = int(width * scale)

        if cfg.img_size > 0:
            # Fixed square network input.
            img_height, img_width = cfg.img_size, cfg.img_size
            center = np.array([new_width / 2., new_height / 2.],
                              dtype=np.float32)
            scaled_size = max(height, width) * 1.0
            scaled_size = np.array([scaled_size, scaled_size],
                                   dtype=np.float32)
        else:
            # OR-ing with the padding mask and adding 1 rounds each side up
            # to a multiple of (padding + 1).
            img_height = (new_height | padding) + 1
            img_width = (new_width | padding) + 1
            center = np.array([new_width // 2, new_height // 2],
                              dtype=np.float32)
            scaled_size = np.array([img_width, img_height], dtype=np.float32)

        img = cv2.resize(image, (new_width, new_height))
        trans_img = get_affine_transform(center, scaled_size, 0,
                                         [img_width, img_height])
        img = cv2.warpAffine(img, trans_img, (img_width, img_height))

        # Normalize with the per-dataset mean/std constants.
        img = img.astype(np.float32) / 255.
        img -= np.array(COCO_MEAN if cfg.dataset == 'coco' else VOC_MEAN,
                        dtype=np.float32)[None, None, :]
        img /= np.array(COCO_STD if cfg.dataset == 'coco' else VOC_STD,
                        dtype=np.float32)[None, None, :]
        img = img.transpose(2, 0,
                            1)[None, :, :, :]  # from [H, W, C] to [1, C, H, W]

        if cfg.test_flip:
            # Append a horizontally flipped copy along the batch axis.
            img = np.concatenate((img, img[:, :, :, ::-1].copy()), axis=0)

        imgs[scale] = {
            'image': torch.from_numpy(img).float(),
            'center': np.array(center),
            'scale': np.array(scaled_size),
            'fmap_h': np.array(img_height // 4),
            'fmap_w': np.array(img_width // 4)
        }

    print('Creating model...')
    if 'hourglass' in cfg.arch:
        model = get_hourglass[cfg.arch]
    elif 'resdcn' in cfg.arch:
        model = get_pose_net(num_layers=int(cfg.arch.split('_')[-1]),
                             num_classes=num_classes)
    else:
        raise NotImplementedError

    model = load_model(model, cfg.ckpt_dir)
    model = model.to(cfg.device)
    model.eval()

    with torch.no_grad():
        detections = []
        for scale in imgs:
            sample = imgs[scale]
            sample['image'] = sample['image'].to(cfg.device)

            output = model(sample['image'])[-1]
            dets = ctdet_decode(*output, K=cfg.test_topk)
            dets = dets.detach().cpu().numpy().reshape(1, -1, dets.shape[2])[0]

            # Columns 0:2 and 2:4 are each passed through transform_preds
            # with the sample's center/scale.
            fmap_size = (sample['fmap_w'], sample['fmap_h'])
            for lo in (0, 2):
                dets[:, lo:lo + 2] = transform_preds(
                    dets[:, lo:lo + 2], sample['center'], sample['scale'],
                    fmap_size)

            # Split by class index (last column); keys are 1-based. The
            # class count now follows the dataset instead of a fixed 80.
            top_preds = {}
            labels = dets[:, -1]
            for j in range(num_classes):
                inds = (labels == j)
                top_preds[j + 1] = dets[inds, :5].astype(np.float32)
                top_preds[j + 1][:, :4] /= scale

            detections.append(top_preds)

        bbox_and_scores = {}
        for j in range(1, num_classes + 1):
            bbox_and_scores[j] = np.concatenate([d[j] for d in detections],
                                                axis=0)
            if len(cfg.test_scales) > 1:
                # Return value is ignored; presumably soft_nms updates the
                # array in place -- confirm.
                soft_nms(bbox_and_scores[j], Nt=0.5, method=2)
        scores = np.hstack([
            bbox_and_scores[j][:, 4] for j in range(1, num_classes + 1)
        ])

        # Keep only the max_per_image highest-scoring detections.
        if len(scores) > max_per_image:
            kth = len(scores) - max_per_image
            thresh = np.partition(scores, kth)[kth]
            for j in range(1, num_classes + 1):
                keep_inds = (bbox_and_scores[j][:, 4] >= thresh)
                bbox_and_scores[j] = bbox_and_scores[j][keep_inds]

        # plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        # plt.show()
        fig = plt.figure(0)
        colors = COCO_COLORS if cfg.dataset == 'coco' else VOC_COLORS
        names = COCO_NAMES if cfg.dataset == 'coco' else VOC_NAMES
        plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        for lab in bbox_and_scores:
            for box in bbox_and_scores[lab]:
                x1, y1, x2, y2, score = box
                if score > 0.3:
                    plt.gca().add_patch(
                        Rectangle((x1, y1),
                                  x2 - x1,
                                  y2 - y1,
                                  linewidth=2,
                                  edgecolor=colors[lab],
                                  facecolor='none'))
                    plt.text(x1 + 3,
                             y1 + 3,
                             names[lab] + '%.2f' % score,
                             bbox=dict(facecolor=colors[lab], alpha=0.5),
                             fontsize=7,
                             color='k')

        fig.patch.set_visible(False)
        plt.axis('off')
        plt.savefig('data/demo_results.png', dpi=300, transparent=True)
        plt.show()
# Example #5
# 0
def main():
  saver = create_saver(cfg.local_rank, save_dir=cfg.ckpt_dir)
  logger = create_logger(cfg.local_rank, save_dir=cfg.log_dir)
  summary_writer = create_summary(cfg.local_rank, log_dir=cfg.log_dir)
  print = logger.info
  print(cfg)

  torch.manual_seed(317)
  torch.backends.cudnn.benchmark = True  # disable this if OOM at beginning of training

  num_gpus = torch.cuda.device_count()
  if cfg.dist:
    cfg.device = torch.device('cuda:%d' % cfg.local_rank)
    torch.cuda.set_device(cfg.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=num_gpus, rank=cfg.local_rank)
  else:
    cfg.device = torch.device('cuda')

  print('Setting up data...')
  Dataset = COCO if cfg.dataset == 'coco' else PascalVOC
  train_dataset = Dataset(cfg.data_dir, 'train', split_ratio=cfg.split_ratio, img_size=cfg.img_size)
  train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
                                                                  num_replicas=num_gpus,
                                                                  rank=cfg.local_rank)
  train_loader = torch.utils.data.DataLoader(train_dataset,
                                             batch_size=cfg.batch_size // num_gpus
                                             if cfg.dist else cfg.batch_size,
                                             shuffle=not cfg.dist,
                                             num_workers=cfg.num_workers,
                                             pin_memory=True,
                                             drop_last=True,
                                             sampler=train_sampler if cfg.dist else None)

  Dataset_eval = COCO_eval if cfg.dataset == 'coco' else PascalVOC_eval
  val_dataset = Dataset_eval(cfg.data_dir, 'val', test_scales=[1.], test_flip=False)
  val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1,
                                           shuffle=False, num_workers=1, pin_memory=True,
                                           collate_fn=val_dataset.collate_fn)

  print('Creating model...')
  if 'hourglass' in cfg.arch:
    model = get_hourglass[cfg.arch]
  elif 'resdcn' in cfg.arch:
    model = get_pose_net(num_layers=int(cfg.arch.split('_')[-1]), num_classes=train_dataset.num_classes)
  else:
    raise NotImplementedError

  if cfg.dist:
    # model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(cfg.device)
    model = nn.parallel.DistributedDataParallel(model,
                                                device_ids=[cfg.local_rank, ],
                                                output_device=cfg.local_rank)
  else:
    model = nn.DataParallel(model).to(cfg.device)

  if os.path.isfile(cfg.pretrain_dir):
    model = load_model(model, cfg.pretrain_dir)

  optimizer = torch.optim.Adam(model.parameters(), cfg.lr)
  lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, cfg.lr_step, gamma=0.1)

  def train(epoch):
    print('\n Epoch: %d' % epoch)
    model.train()
    tic = time.perf_counter()
    for batch_idx, batch in enumerate(train_loader):
      for k in batch:
        if k != 'meta':
          batch[k] = batch[k].to(device=cfg.device, non_blocking=True)