Exemplo n.º 1
0
def predict(model, test_dataset, test_dataset_vis, output_path):
    """Run segmentation inference over a dataset and write one PNG mask per image.

    Images already present in the global ``IMG2MASK`` cache are skipped and
    their previously computed mask path is reused.  ``extra["mask_path"]`` is
    filled in for every sample, and new results are added to the cache.
    """
    mkdir_p(output_path)
    for idx in tqdm.tqdm(range(len(test_dataset))):
        img = test_dataset[idx]
        img_vis, extra = test_dataset_vis[idx]

        src = Path(extra["image_path"])
        key = str(src)
        # Duplicate image: reuse the mask computed earlier.
        if key in IMG2MASK:
            extra["mask_path"] = str(IMG2MASK[key])
            continue

        mask_path = output_path / f"{src.name.split('.')[0]}.png"

        batch = torch.from_numpy(img).to("cuda").unsqueeze(0)
        with torch.no_grad():
            pred = model.predict(batch)

        # (C, H, W) scores -> per-pixel class ids, cropped back to the
        # visualization image's (un-padded) height/width.
        label_map = np.argmax(pred.squeeze().cpu().numpy().round(),
                              axis=0)[:img_vis.shape[0], :img_vis.shape[1]]
        cv2.imwrite(str(mask_path), label_map.astype(np.uint8))

        extra["mask_path"] = str(mask_path)
        IMG2MASK[key] = str(mask_path)
Exemplo n.º 2
0
def train(config_files, cmd_config):
    """Train a vehicle re-identification model.

    Builds the configuration, datasets/loaders, model, optimizer, losses and
    warmup LR scheduler, then runs the epoch loop with periodic logging,
    evaluation and checkpointing.

    Args:
        config_files: config file paths merged on top of the defaults.
        cmd_config: command-line overrides applied last.
    """
    cfg = make_config()
    cfg = merge_configs(cfg, config_files, cmd_config)

    mkdir_p(cfg.output_dir)
    logzero.logfile(f"{cfg.output_dir}/train.log")
    logzero.loglevel(getattr(logging, cfg.logging.level.upper()))
    logger.info(cfg)
    logger.info(f"worker ip is {get_host_ip()}")

    writer = SummaryWriter(
        comment=f"{cfg.data.name}_{cfg.model.name}__{cfg.data.batch_size}")

    logger.info(f"Loading {cfg.data.name} dataset")

    train_dataset, valid_dataset, meta_dataset = make_basic_dataset(
        cfg.data.pkl_path,
        cfg.data.train_size,
        cfg.data.valid_size,
        cfg.data.pad,
        test_ext=cfg.data.test_ext,
        re_prob=cfg.data.re_prob,
        with_mask=cfg.data.with_mask,
    )
    num_class = meta_dataset.num_train_ids
    # Identity-aware sampler (e.g. PK sampling) chosen by name from `samplers`.
    sampler = getattr(samplers, cfg.data.sampler)(train_dataset.meta_dataset,
                                                  cfg.data.batch_size,
                                                  cfg.data.num_instances)
    train_loader = DataLoader(train_dataset,
                              sampler=sampler,
                              batch_size=cfg.data.batch_size,
                              num_workers=cfg.data.train_num_workers,
                              pin_memory=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=cfg.data.batch_size,
                              num_workers=cfg.data.test_num_workers,
                              pin_memory=True,
                              shuffle=False)
    logger.info(f"Successfully load {cfg.data.name}!")

    logger.info(f"Building {cfg.model.name} model, "
                f"num class is {num_class}")
    model = build_model(cfg, num_class).to(cfg.device)

    logger.info(f"Building {cfg.optim.name} optimizer...")

    optimizer = make_optimizer(cfg.optim.name, model, cfg.optim.base_lr,
                               cfg.optim.weight_decay,
                               cfg.optim.bias_lr_factor, cfg.optim.momentum)

    logger.info(f"Building losses {cfg.loss.losses}")

    # Loss terms are optional; each stays None unless requested in the config.
    triplet_loss = None
    id_loss = None
    center_loss = None
    optimizer_center = None
    tuplet_loss = None
    if 'local-triplet' in cfg.loss.losses:
        pt_loss = ParsingTripletLoss(margin=0.3)
    if 'triplet' in cfg.loss.losses:
        triplet_loss = vr_loss.TripletLoss(margin=cfg.loss.triplet_margin)
    if 'id' in cfg.loss.losses:
        id_loss = vr_loss.CrossEntropyLabelSmooth(num_class,
                                                  cfg.loss.id_epsilon)
    if 'center' in cfg.loss.losses:
        center_loss = vr_loss.CenterLoss(
            num_class, feat_dim=model.in_planes).to(cfg.device)
        # Center loss keeps its own learnable class centers, optimized
        # separately with plain SGD.
        optimizer_center = torch.optim.SGD(center_loss.parameters(),
                                           cfg.loss.center_lr)
    if 'tuplet' in cfg.loss.losses:
        tuplet_loss = vr_loss.TupletLoss(
            cfg.data.num_instances,
            cfg.data.batch_size // cfg.data.num_instances, cfg.loss.tuplet_s,
            cfg.loss.tuplet_beta)

    start_epoch = 1
    if cfg.model.pretrain_choice == "self":
        logger.info(f"Loading checkpoint from {cfg.output_dir}")
        # BUG FIX: the losses list contains 'center' (see the checks above and
        # below), not 'center_loss', so the old `"center_loss" in ...` test
        # never matched and center-loss state was silently not restored.
        if 'center' in cfg.loss.losses:
            start_epoch = load_checkpoint(cfg.output_dir,
                                          cfg.device,
                                          model=model,
                                          optimizer=optimizer,
                                          optimizer_center=optimizer_center,
                                          center_loss=center_loss)
        else:
            start_epoch = load_checkpoint(cfg.output_dir,
                                          cfg.device,
                                          model=model,
                                          optimizer=optimizer)
        logger.info(
            f"Loaded checkpoint successfully! Start epoch is {start_epoch}")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    scheduler = make_warmup_scheduler(optimizer,
                                      cfg.scheduler.milestones,
                                      cfg.scheduler.gamma,
                                      cfg.scheduler.warmup_factor,
                                      cfg.scheduler.warmup_iters,
                                      cfg.scheduler.warmup_method,
                                      last_epoch=start_epoch - 1)

    logger.info("Start training!")
    for epoch in range(start_epoch, cfg.train.epochs + 1):
        t_begin = time.time()
        # Epoch-level warmup scheduler: stepped at the start of each epoch.
        scheduler.step()
        running_loss = 0
        running_acc = 0
        gpu_time = 0
        data_time = 0

        t0 = time.time()
        # `iteration` instead of `iter` to avoid shadowing the builtin.
        for iteration, batch in enumerate(train_loader):
            t1 = time.time()
            data_time += t1 - t0
            global_steps = (epoch - 1) * len(train_loader) + iteration
            model.train()
            optimizer.zero_grad()

            if 'center' in cfg.loss.losses:
                optimizer_center.zero_grad()

            # Move every tensor in the batch dict onto the training device.
            for name, item in batch.items():
                if isinstance(item, torch.Tensor):
                    batch[name] = item.to(cfg.device)

            output = model(**batch)
            global_feat = output["global_feat"]
            global_score = output["cls_score"]
            local_feat = output["local_feat"]
            vis_score = output["vis_score"]

            # Accumulate the enabled loss terms.
            loss = 0
            if "id" in cfg.loss.losses:
                g_xent_loss = id_loss(global_score, batch["id"]).mean()
                loss += g_xent_loss
                logger.debug(f'ID Loss: {g_xent_loss.item()}')
                writer.add_scalar("global_loss/id_loss", g_xent_loss.item(),
                                  global_steps)

            if "triplet" in cfg.loss.losses:
                t_loss, _, _ = triplet_loss(global_feat,
                                            batch["id"],
                                            normalize_feature=False)
                logger.debug(f'Triplet Loss: {t_loss.item()}')
                loss += t_loss
                writer.add_scalar("global_loss/triplet_loss", t_loss.item(),
                                  global_steps)

            if "center" in cfg.loss.losses:
                g_center_loss = center_loss(global_feat, batch["id"])
                logger.debug(g_center_loss.item())
                loss += cfg.loss.center_weight * g_center_loss
                writer.add_scalar("global_loss/center_loss",
                                  g_center_loss.item(), global_steps)

            if "tuplet" in cfg.loss.losses:
                g_tuplet_loss = tuplet_loss(global_feat)
                loss += g_tuplet_loss
                writer.add_scalar("global_loss/tuplet_loss",
                                  g_tuplet_loss.item(), global_steps)

            if "local-triplet" in cfg.loss.losses:
                l_triplet_loss, _, _ = pt_loss(local_feat, vis_score,
                                               batch["id"], True)
                writer.add_scalar("local_loss/triplet_loss",
                                  l_triplet_loss.item(), global_steps)
                loss += l_triplet_loss

            loss.backward()
            optimizer.step()

            # Center loss is optimized separately; undo the center_weight
            # scaling on its gradients before stepping its own optimizer.
            if 'center' in cfg.loss.losses:
                for param in center_loss.parameters():
                    param.grad.data *= (1. / cfg.loss.center_weight)
                optimizer_center.step()

            acc = (global_score.max(1)[1] == batch["id"]).float().mean()

            # Exponential running mean for smoother console logging.
            if iteration == 0:
                running_acc = acc.item()
                running_loss = loss.item()
            else:
                running_acc = 0.98 * running_acc + 0.02 * acc.item()
                running_loss = 0.98 * running_loss + 0.02 * loss.item()

            if iteration % cfg.logging.period == 0:
                logger.info(
                    f"Epoch[{epoch:3d}] Iteration[{iteration:4d}/{len(train_loader):4d}] "
                    f"Loss: {running_loss:.3f}, Acc: {running_acc:.3f}, Base Lr: {scheduler.get_lr()[0]:.2e}"
                )
                if cfg.debug:
                    break
            t0 = time.time()
            gpu_time += t0 - t1
            logger.debug(f"GPU Time: {gpu_time}, Data Time: {data_time}")

        t_end = time.time()

        logger.info(
            f"Epoch {epoch} done. Time per epoch: {t_end - t_begin:.1f}[s] "
            f"Speed:{(t_end - t_begin) / len(train_loader.dataset):.1f}[samples/s] "
        )
        logger.info('-' * 10)

        # Evaluate periodically.  veriwild is skipped here because evaluating
        # during training exhausts GPU memory; vehicleid uses a different test
        # protocol — both are evaluated separately after training.
        if (epoch == 1
                or epoch % cfg.test.period == 0) and cfg.data.name.lower(
                ) != 'veriwild' and cfg.data.name.lower() != 'vehicleid':
            query_length = meta_dataset.num_query_imgs
            if query_length != 0:  # the "Private" dataset has no test set
                eval_(model,
                      device=cfg.device,
                      valid_loader=valid_loader,
                      query_length=query_length,
                      feat_norm=cfg.test.feat_norm,
                      remove_junk=cfg.test.remove_junk,
                      lambda_=cfg.test.lambda_,
                      output_dir=cfg.output_dir)

        # Save a checkpoint (including center-loss state when enabled).
        if epoch % cfg.model.ckpt_period == 0 or epoch == 1:
            logger.info(f"Saving models in epoch {epoch}")
            if 'center' in cfg.loss.losses:
                save_checkpoint(epoch,
                                cfg.output_dir,
                                model=model,
                                optimizer=optimizer,
                                center_loss=center_loss,
                                optimizer_center=optimizer_center)
            else:
                save_checkpoint(epoch,
                                cfg.output_dir,
                                model=model,
                                optimizer=optimizer)
Exemplo n.º 3
0

if __name__ == "__main__":
    # CLI entry point: run the parsing model over every split stored in the
    # ReID pickle, writing one mask per image and recording each generated
    # mask path back into the pickle.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", default="best_model_trainval.pth")
    parser.add_argument("--reid-pkl-path", type=str, required=True)
    parser.add_argument("--output-path", type=str, required=True)
    args = parser.parse_args()

    model = torch.load(args.model_path).cuda()
    model.eval()

    with open(args.reid_pkl_path, "rb") as f:
        metas = pickle.load(f)

    output_root = Path(args.output_path).absolute()
    for phase, phase_metas in metas.items():
        phase_dir = output_root / phase
        mkdir_p(str(phase_dir))
        dataset = VehicleReIDParsingDataset(
            phase_metas,
            augmentation=get_validation_augmentation(),
            preprocessing=get_preprocessing(preprocessing_fn))
        dataset_vis = VehicleReIDParsingDataset(phase_metas, with_extra=True)
        print('Predict mask to {}'.format(phase_dir))
        predict(model, dataset, dataset_vis, phase_dir)

    # Persist the metas (now carrying "mask_path" entries) back in place.
    with open(args.reid_pkl_path, "wb") as f:
        pickle.dump(metas, f)
Exemplo n.º 4
0
    shape = item['resources'][0]['size']
    shape = [shape['height'], shape["width"]]
    polys_list = item['results']['polys']
    polys = [poly['poly'] for poly in polys_list]
    classes = [int(poly['attr']['side']) for poly in polys_list]
    return nori_id, shape, polys, classes


if __name__ == '__main__':
    # CLI entry point: rasterize labeled polygons from a JSON annotation file
    # into per-image uint8 PNG parsing masks.
    parser = argparse.ArgumentParser()
    parser.add_argument("--json-path", default="poly.json")
    parser.add_argument("--output-path", default="veri776_parsing3165")
    args = parser.parse_args()

    with open(args.json_path, "r") as f:
        polygons = json.load(f)
    output_path = args.output_path
    mkdir_p(output_path)

    # Iterate the list directly: the enumerate index was unused, and wrapping
    # enumerate in tqdm hid the total (no proper progress bar).  The per-item
    # debug print was removed because it corrupted the tqdm display.
    for item in tqdm(polygons):
        image_name = item["image_name"]
        shape = item["shape"]
        polys = item["polys"]
        classes = item["classes"]
        mask = poly2mask(polys, classes, shape)
        # "some_dir/img.jpg" -> "img"
        image_name = image_name.split('/')[1].split('.')[0]
        cv2.imwrite('{}/{}.png'.format(output_path, image_name), mask)