Example #1
def build_model(cfg):

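    # Attach builder hooks to the config so FCOS(cfg) can construct its
    # backbone and anchor generator from them.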
    cfg.build_backbone = build_backbone
    cfg.build_anchor_generator = build_anchor_generator

    model = FCOS(cfg)
    logger = logging.getLogger(__name__)
    logger.info("Model:\n{}".format(model))
    return model
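A hedged usage sketch: build_model only needs a config object it can attach the two builder hooks to; everything else FCOS reads from cfg is project-specific. The stand-in config below is an assumption for illustration, not part of the example above.

import logging
from types import SimpleNamespace

logging.basicConfig(level=logging.INFO)  # make logger.info("Model:...") visible

# Stand-in config: a real run would load the project's own config object,
# which must carry whatever fields FCOS actually expects.
cfg = SimpleNamespace()
model = build_model(cfg)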
Example #2
def test(epochs_tested):
    is_train = False
    transforms = transform.build_transforms(is_train=is_train)
    coco_dataset = dataset.COCODataset(is_train=is_train, transforms=transforms)
    dataloader = build_dataloader(coco_dataset, sampler=None, is_train=is_train)

    assert isinstance(epochs_tested, (list, set)), "during test, epochs_tested must be a list or set!"
    model = FCOS(is_train=is_train)

    for epoch in epochs_tested:
        utils.load_model(model, epoch)
        model.cuda()
        model.eval()

        final_results = []

        with torch.no_grad():
            for data in tqdm(dataloader):
                img = data["images"]
                ori_img_shape = data["ori_img_shape"]
                fin_img_shape = data["fin_img_shape"]
                index = data["indexs"]

                img = img.cuda()
                ori_img_shape = ori_img_shape.cuda()
                fin_img_shape = fin_img_shape.cuda()

                cls_pred, reg_pred, label_pred = model([img, ori_img_shape, fin_img_shape])

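                # Assumes the test dataloader yields batches of size 1:
                # keep the first (and only) image's predictions and index.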
                cls_pred = cls_pred[0].cpu()
                reg_pred = reg_pred[0].cpu()
                label_pred = label_pred[0].cpu()
                index = index[0]

                img_info = dataloader.dataset.img_infos[index]
                imgid = img_info["id"]

                reg_pred = utils.xyxy2xywh(reg_pred)

                label_pred = label_pred.tolist()
                cls_pred = cls_pred.tolist()

                final_results.extend(
                    [
                        {
                            "image_id": imgid,
                            "category_id": dataloader.dataset.label2catid[label_pred[k]],
                            "bbox": reg_pred[k].tolist(),
                            "score": cls_pred[k],
                        }
                        for k in range(len(reg_pred))
                    ]
                )

        output_path = os.path.join(cfg.output_path, f"fcos_{epoch}.json")
        utils.evaluate_coco(dataloader.dataset.coco, final_results, output_path, "bbox")
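One step above worth spelling out: utils.xyxy2xywh converts corner-format boxes to the [x, y, width, height] layout that COCO's evaluation API expects for "bbox" results. A minimal self-contained sketch of that conversion, assuming the project's helper behaves this way (it may differ in details such as inclusive pixel widths):

import torch

def xyxy_to_xywh(boxes):
    # boxes: (N, 4) tensor of [x1, y1, x2, y2] corners
    x1, y1, x2, y2 = boxes.unbind(dim=-1)
    return torch.stack([x1, y1, x2 - x1, y2 - y1], dim=-1)

print(xyxy_to_xywh(torch.tensor([[10.0, 20.0, 50.0, 80.0]])))
# tensor([[10., 20., 40., 60.]])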
Example #3
if __name__ == "__main__":
    # Reconstructed setup: the original header and opening lines of this
    # example were missing; the RetinaNet block is assumed to mirror the
    # FCOS block below.
    from retinanet import RetinaNet
    net = RetinaNet(resnet_type="resnet50")
    image_h, image_w = 600, 600
    # torch.autograd.Variable has been a deprecated no-op since PyTorch 0.4;
    # pass the tensor directly.
    cls_heads, reg_heads, batch_anchors = net(
        torch.randn(3, 3, image_h, image_w))
    annotations = torch.FloatTensor([[[113, 120, 183, 255, 5],
                                      [13, 45, 175, 210, 2]],
                                     [[11, 18, 223, 225, 1],
                                      [-1, -1, -1, -1, -1]],
                                     [[-1, -1, -1, -1, -1],
                                      [-1, -1, -1, -1, -1]]])
    decode = RetinaDecoder(image_w, image_h)
    batch_scores, batch_classes, batch_pred_bboxes = decode(
        cls_heads, reg_heads, batch_anchors)
    print("1111", batch_scores.shape, batch_classes.shape,
          batch_pred_bboxes.shape)

    from fcos import FCOS
    net = FCOS(resnet_type="resnet50")
    image_h, image_w = 600, 600
    cls_heads, reg_heads, center_heads, batch_positions = net(
        torch.randn(3, 3, image_h, image_w))
    annotations = torch.FloatTensor([[[113, 120, 183, 255, 5],
                                      [13, 45, 175, 210, 2]],
                                     [[11, 18, 223, 225, 1],
                                      [-1, -1, -1, -1, -1]],
                                     [[-1, -1, -1, -1, -1],
                                      [-1, -1, -1, -1, -1]]])
    decode = FCOSDecoder(image_w, image_h)
    batch_scores2, batch_classes2, batch_pred_bboxes2 = decode(
        cls_heads, reg_heads, center_heads, batch_positions)
    print("2222", batch_scores2.shape, batch_classes2.shape,
          batch_pred_bboxes2.shape)
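Both decoders return per-image tensors padded to a fixed maximum number of detections, with empty slots marked by -1 (the same padding convention the annotation tensors above use). A hedged sketch of filtering one image's outputs under that assumption; the helper and threshold are illustrative:

import torch

def keep_valid(scores, classes, bboxes, min_score=0.05):
    # scores/classes: (max_dets,), bboxes: (max_dets, 4)
    mask = scores > min_score  # drops the -1 padding and low-confidence boxes
    return scores[mask], classes[mask], bboxes[mask]

scores = torch.tensor([0.9, 0.4, -1.0])
classes = torch.tensor([2.0, 5.0, -1.0])
bboxes = torch.tensor([[10., 10., 50., 50.],
                       [5., 5., 20., 20.],
                       [-1., -1., -1., -1.]])
print(keep_valid(scores, classes, bboxes))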
Example #4
def train(is_dist, start_epoch, local_rank):
    transforms = transform.build_transforms()
    coco_dataset = dataset.COCODataset(is_train=True, transforms=transforms)
    if is_dist:
        sampler = distributedGroupSampler(coco_dataset)
    else:
        sampler = groupSampler(coco_dataset)
    dataloader = build_dataloader(coco_dataset, sampler)

    batch_time_meter = utils.AverageMeter()
    cls_loss_meter = utils.AverageMeter()
    reg_loss_meter = utils.AverageMeter()
    cen_loss_meter = utils.AverageMeter()
    losses_meter = utils.AverageMeter()

    model = FCOS(is_train=True)
    if start_epoch == 1:
        model.resNet.load_pretrained(pretrained_path[cfg.resnet_depth])
    else:
        utils.load_model(model, start_epoch - 1)
    model = model.cuda()

    if is_dist:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[
                local_rank,
            ],
            output_device=local_rank,
            broadcast_buffers=False)
    optimizer = solver.build_optimizer(model)
    scheduler = solver.scheduler(optimizer)

    model.train()
    logs = []

    for epoch in range(start_epoch, cfg.max_epochs + 1):
        if is_dist:
            dataloader.sampler.set_epoch(epoch - 1)
        scheduler.lr_decay(epoch)

        end_time = time.time()
        for iteration, datas in enumerate(dataloader, 1):
            scheduler.constant_warmup(epoch, iteration - 1)
            images = datas["images"]
            bboxes = datas["bboxes"]
            labels = datas["labels"]

            images = images.cuda()
            bboxes = [bbox.cuda() for bbox in bboxes]
            labels = [label.cuda() for label in labels]

            loss_dict = model([images, bboxes, labels])
            cls_loss = loss_dict["cls_loss"]
            reg_loss = loss_dict["reg_loss"]
            cen_loss = loss_dict["cen_loss"]

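            # FCOS total loss: unweighted sum of the classification,
            # regression, and center-ness head losses.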
            losses = cls_loss + reg_loss + cen_loss
            optimizer.zero_grad()
            losses.backward()
            grad_norm = utils.clip_grads(model.parameters())
            optimizer.step()

            batch_time_meter.update(time.time() - end_time)
            end_time = time.time()

            cls_loss_meter.update(cls_loss.item())
            reg_loss_meter.update(reg_loss.item())
            cen_loss_meter.update(cen_loss.item())
            losses_meter.update(losses.item())

            if iteration % 50 == 0:
                if local_rank == 0:
                    res = "\t".join([
                        "Epoch: [%d/%d]" % (epoch, cfg.max_epochs),
                        "Iter: [%d/%d]" % (iteration, len(dataloader)),
                        "Time: %.3f (%.3f)" %
                        (batch_time_meter.val, batch_time_meter.avg),
                        "Cls_loss: %.4f (%.4f)" %
                        (cls_loss_meter.val, cls_loss_meter.avg),
                        "Reg_loss: %.4f (%.4f)" %
                        (reg_loss_meter.val, reg_loss_meter.avg),
                        "Cen_loss: %.4f (%.4f)" %
                        (cen_loss_meter.val, cen_loss_meter.avg),
                        "Loss: %.4f (%.4f)" %
                        (losses_meter.val, losses_meter.avg),
                        "lr: %.6f" % (optimizer.param_groups[0]["lr"]),
                        "grad_norm: %.4f" % (grad_norm.item(), )
                    ])
                    print(res)
                    logs.append(res)
                batch_time_meter.reset()
                cls_loss_meter.reset()
                reg_loss_meter.reset()
                cen_loss_meter.reset()
                losses_meter.reset()

        if local_rank == 0:
            utils.save_model(model, epoch)
        if is_dist:
            utils.synchronize()

    if local_rank == 0:
        with open("logs.txt", "w") as f:
            for i in logs:
                f.write(i + "\n")