예제 #1
0
def eval(config_files, cmd_config):
    cfg = make_config()
    cfg = merge_configs(cfg, config_files, cmd_config)

    os.makedirs(cfg.output_dir, exist_ok=True)

    model = build_model(cfg, 1).to(cfg.device)
    # start_epoch = load_checkpoint(cfg.output_dir, device=cfg.device, epoch=cfg.test.epoch, exclude="classifier", model=model)
    state_dict = torch.load(cfg.test.model_path, map_location=cfg.device)

    # Remove the classifier
    remove_keys = []
    # import ipdb; ipdb.set_trace()
    for key, value in state_dict.items():
        if 'classifier' in key:
            remove_keys.append(key)
    for key in remove_keys:
        del state_dict[key]

    model.load_state_dict(state_dict, strict=False)

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    logger.info(f"Load model {cfg.test.model_path}")
    train_dataset, valid_dataset, meta_dataset = make_basic_dataset(
        cfg.data.pkl_path,
        cfg.data.train_size,
        cfg.data.valid_size,
        cfg.data.pad,
        test_ext=cfg.data.test_ext,
        re_prob=cfg.data.re_prob,
        with_mask=cfg.data.with_mask,
    )
    valid_loader = DataLoader(valid_dataset,
                              batch_size=cfg.data.batch_size,
                              num_workers=cfg.data.test_num_workers,
                              pin_memory=True,
                              shuffle=False)

    query_length = meta_dataset.num_query_imgs

    if cfg.data.name.lower() == "vehicleid":
        eval_vehicle_id_(model, valid_loader, query_length, cfg)
    else:
        eval_(model,
              cfg.test.device,
              valid_loader,
              query_length,
              feat_norm=cfg.test.feat_norm,
              remove_junk=cfg.test.remove_junk,
              max_rank=cfg.test.max_rank,
              output_dir=cfg.output_dir,
              lambda_=cfg.test.lambda_,
              rerank=cfg.test.rerank,
              split=cfg.test.split,
              output_html_path=cfg.test.output_html_path)
예제 #2
0
def train(config_files, cmd_config):
    """
    Training models.
    """
    cfg = make_config()
    cfg = merge_configs(cfg, config_files, cmd_config)

    mkdir_p(cfg.output_dir)
    logzero.logfile(f"{cfg.output_dir}/train.log")
    logzero.loglevel(getattr(logging, cfg.logging.level.upper()))
    logger.info(cfg)
    logger.info(f"worker ip is {get_host_ip()}")

    writter = SummaryWriter(
        comment=f"{cfg.data.name}_{cfg.model.name}__{cfg.data.batch_size}")

    logger.info(f"Loading {cfg.data.name} dataset")

    train_dataset, valid_dataset, meta_dataset = make_basic_dataset(
        cfg.data.pkl_path,
        cfg.data.train_size,
        cfg.data.valid_size,
        cfg.data.pad,
        test_ext=cfg.data.test_ext,
        re_prob=cfg.data.re_prob,
        with_mask=cfg.data.with_mask,
    )
    num_class = meta_dataset.num_train_ids
    sampler = getattr(samplers, cfg.data.sampler)(train_dataset.meta_dataset,
                                                  cfg.data.batch_size,
                                                  cfg.data.num_instances)
    train_loader = DataLoader(train_dataset,
                              sampler=sampler,
                              batch_size=cfg.data.batch_size,
                              num_workers=cfg.data.train_num_workers,
                              pin_memory=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=cfg.data.batch_size,
                              num_workers=cfg.data.test_num_workers,
                              pin_memory=True,
                              shuffle=False)
    logger.info(f"Successfully load {cfg.data.name}!")

    logger.info(f"Building {cfg.model.name} model, "
                f"num class is {num_class}")
    model = build_model(cfg, num_class).to(cfg.device)

    logger.info(f"Building {cfg.optim.name} optimizer...")

    optimizer = make_optimizer(cfg.optim.name, model, cfg.optim.base_lr,
                               cfg.optim.weight_decay,
                               cfg.optim.bias_lr_factor, cfg.optim.momentum)

    logger.info(f"Building losses {cfg.loss.losses}")

    triplet_loss = None
    id_loss = None
    center_loss = None
    optimizer_center = None
    tuplet_loss = None
    if 'local-triplet' in cfg.loss.losses:
        pt_loss = ParsingTripletLoss(margin=0.3)
    if 'triplet' in cfg.loss.losses:
        triplet_loss = vr_loss.TripletLoss(margin=cfg.loss.triplet_margin)
    if 'id' in cfg.loss.losses:
        id_loss = vr_loss.CrossEntropyLabelSmooth(num_class,
                                                  cfg.loss.id_epsilon)
        # id_loss = vr_losses.CrossEntropyLabelSmooth(num_class, cfg.loss.id_epsilon, keep_dim=False)
    if 'center' in cfg.loss.losses:
        center_loss = vr_loss.CenterLoss(
            num_class, feat_dim=model.in_planes).to(cfg.device)
        optimizer_center = torch.optim.SGD(center_loss.parameters(),
                                           cfg.loss.center_lr)
    if 'tuplet' in cfg.loss.losses:
        tuplet_loss = vr_loss.TupletLoss(
            cfg.data.num_instances,
            cfg.data.batch_size // cfg.data.num_instances, cfg.loss.tuplet_s,
            cfg.loss.tuplet_beta)

    start_epoch = 1
    if cfg.model.pretrain_choice == "self":
        logger.info(f"Loading checkpoint from {cfg.output_dir}")
        if "center_loss" in cfg.loss.losses:
            start_epoch = load_checkpoint(cfg.output_dir,
                                          cfg.device,
                                          model=model,
                                          optimizer=optimizer,
                                          optimizer_center=optimizer_center,
                                          center_loss=center_loss)
        else:
            start_epoch = load_checkpoint(cfg.output_dir,
                                          cfg.device,
                                          model=model,
                                          optimizer=optimizer)
        logger.info(
            f"Loaded checkpoint successfully! Start epoch is {start_epoch}")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    scheduler = make_warmup_scheduler(optimizer,
                                      cfg.scheduler.milestones,
                                      cfg.scheduler.gamma,
                                      cfg.scheduler.warmup_factor,
                                      cfg.scheduler.warmup_iters,
                                      cfg.scheduler.warmup_method,
                                      last_epoch=start_epoch - 1)

    logger.info("Start training!")
    for epoch in range(start_epoch, cfg.train.epochs + 1):
        t_begin = time.time()
        scheduler.step()
        running_loss = 0
        running_acc = 0
        gpu_time = 0
        data_time = 0

        t0 = time.time()
        for iter, batch in enumerate(train_loader):
            t1 = time.time()
            data_time += t1 - t0
            global_steps = (epoch - 1) * len(train_loader) + iter
            model.train()
            optimizer.zero_grad()

            if 'center' in cfg.loss.losses:
                optimizer_center.zero_grad()

            for name, item in batch.items():
                if isinstance(item, torch.Tensor):
                    batch[name] = item.to(cfg.device)

            output = model(**batch)
            global_feat = output["global_feat"]
            global_score = output["cls_score"]
            local_feat = output["local_feat"]
            vis_score = output["vis_score"]

            # losses
            loss = 0
            if "id" in cfg.loss.losses:
                g_xent_loss = id_loss(global_score, batch["id"]).mean()
                loss += g_xent_loss
                logger.debug(f'ID Loss: {g_xent_loss.item()}')
                writter.add_scalar("global_loss/id_loss", g_xent_loss.item(),
                                   global_steps)

            if "triplet" in cfg.loss.losses:
                t_loss, _, _ = triplet_loss(global_feat,
                                            batch["id"],
                                            normalize_feature=False)
                logger.debug(f'Triplet Loss: {t_loss.item()}')
                loss += t_loss
                writter.add_scalar("global_loss/triplet_loss", t_loss.item(),
                                   global_steps)

            if "center" in cfg.loss.losses:
                g_center_loss = center_loss(global_feat, batch["id"])
                logger.debug(g_center_loss.item())
                loss += cfg.loss.center_weight * g_center_loss
                writter.add_scalar("global_loss/center_loss",
                                   g_center_loss.item(), global_steps)

            if "tuplet" in cfg.loss.losses:
                g_tuplet_loss = tuplet_loss(global_feat)
                loss += g_tuplet_loss
                writter.add_scalar("global_loss/tuplet_loss",
                                   g_tuplet_loss.item(), global_steps)

            if "local-triplet" in cfg.loss.losses:
                l_triplet_loss, _, _ = pt_loss(local_feat, vis_score,
                                               batch["id"], True)
                writter.add_scalar("local_loss/triplet_loss",
                                   l_triplet_loss.item(), global_steps)
                loss += l_triplet_loss

            loss.backward()
            optimizer.step()

            # centerloss单独优化
            if 'center' in cfg.loss.losses:
                for param in center_loss.parameters():
                    param.grad.data *= (1. / cfg.loss.center_weight)
                optimizer_center.step()

            acc = (global_score.max(1)[1] == batch["id"]).float().mean()

            # running mean
            if iter == 0:
                running_acc = acc.item()
                running_loss = loss.item()
            else:
                running_acc = 0.98 * running_acc + 0.02 * acc.item()
                running_loss = 0.98 * running_loss + 0.02 * loss.item()

            if iter % cfg.logging.period == 0:
                logger.info(
                    f"Epoch[{epoch:3d}] Iteration[{iter:4d}/{len(train_loader):4d}] "
                    f"Loss: {running_loss:.3f}, Acc: {running_acc:.3f}, Base Lr: {scheduler.get_lr()[0]:.2e}"
                )
                if cfg.debug:
                    break
            t0 = time.time()
            gpu_time += t0 - t1
            logger.debug(f"GPU Time: {gpu_time}, Data Time: {data_time}")

        t_end = time.time()

        logger.info(
            f"Epoch {epoch} done. Time per epoch: {t_end - t_begin:.1f}[s] "
            f"Speed:{(t_end - t_begin) / len(train_loader.dataset):.1f}[samples/s] "
        )
        logger.info('-' * 10)

        # 测试模型, veriwild在训练时测试会导致显存溢出,训练后单独测试。 vehicleid使用不同的测试策略,也训练后单独测试
        if (epoch == 1
                or epoch % cfg.test.period == 0) and cfg.data.name.lower(
                ) != 'veriwild' and cfg.data.name.lower() != 'vehicleid':
            query_length = meta_dataset.num_query_imgs
            if query_length != 0:  # Private没有测试集
                eval_(model,
                      device=cfg.device,
                      valid_loader=valid_loader,
                      query_length=query_length,
                      feat_norm=cfg.test.feat_norm,
                      remove_junk=cfg.test.remove_junk,
                      lambda_=cfg.test.lambda_,
                      output_dir=cfg.output_dir)

        # save checkpoint
        if epoch % cfg.model.ckpt_period == 0 or epoch == 1:
            logger.info(f"Saving models in epoch {epoch}")
            if 'center' in cfg.loss.losses:
                save_checkpoint(epoch,
                                cfg.output_dir,
                                model=model,
                                optimizer=optimizer,
                                center_loss=center_loss,
                                optimizer_center=optimizer_center)
            else:
                save_checkpoint(epoch,
                                cfg.output_dir,
                                model=model,
                                optimizer=optimizer)