Пример #1
0
    def __init__(self, cfg):
        """Prepare everything a training run needs from ``cfg``.

        Order of work: validate mandatory config entries, initialise the
        logger in ``cfg.OUTPUT.DIR``, build the train/val loaders, build the
        model and call ``self.load_weights()``, move the model to CUDA, then
        build the optimiser, LR scheduler and loss. With ``cfg.TRAIN.APEX``
        the model and optimiser go through ``apex.amp.initialize``; when more
        than one process is running, the model is wrapped for distributed
        training (apex SyncBN + apex DDP, or ``DistributedDataParallel``).

        Args:
            cfg: project config node; ``MODEL.TYPE``, ``TRAIN.DATASETS`` and
                ``OUTPUT.DIR`` must be truthy.
        """
        self.cfg = cfg
        # Mandatory settings (note: asserts are skipped under ``python -O``).
        assert self.cfg.MODEL.TYPE
        assert self.cfg.TRAIN.DATASETS
        assert self.cfg.OUTPUT.DIR
        Logger.init(cfg.OUTPUT.DIR, comm.get_local_rank())
        Logger.log(cfg)

        # ---- data -------------------------------------------------------
        loaders = build_train_dataloader(cfg)
        self.train_loader, self.val_loader = loaders
        train_size = len(self.train_loader.dataset)
        # The validation loader is optional; report 0 samples when absent.
        val_size = len(self.val_loader.dataset) if self.val_loader else 0
        Logger.log('训练集数据量:{}   验证集数据量:{}'.format(train_size, val_size))

        # ---- model ------------------------------------------------------
        self.model = build_models(cfg)
        # Presumably loads weights into self.model, so it must run after the
        # assignment above and before the CUDA move — confirm in load_weights.
        self.load_weights()
        self.model = self.model.cuda()
        Logger.log(self.model)

        # ---- optimisation -----------------------------------------------
        self.optimer = build_optimer(cfg, self.model)
        self.lr_scheduler = build_lr_scheduler(cfg, self.optimer, train_size)
        self.criterion = build_criterion(cfg)

        # ---- mixed precision / distributed wrapping ---------------------
        use_apex = cfg.TRAIN.APEX
        if use_apex:
            # amp must wrap model/optimiser before any DDP wrapping below.
            self.model, self.optimer = apex.amp.initialize(
                self.model, self.optimer)
        if comm.get_world_size() <= 1:
            return
        if use_apex:
            self.model = apex.parallel.convert_syncbn_model(self.model)
            self.model = apex.parallel.DistributedDataParallel(self.model)
        else:
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False)
def tmp(cfg):
    """Build one L2-normalised mean embedding ("template") per class and save it.

    Images under ``cfg.TEST.METRIC.TEMPLATE_DATAPATH`` are read through
    ``create_infer_dataloader`` and embedded with the model given by
    ``cfg.MODEL.WEIGHTS``. A path ending in ``.pth`` is loaded as a state dict
    into ``build_models(cfg)`` with ``strict=True``; any other path is loaded
    as a TorchScript archive. For each class, the embeddings are averaged and
    L2-normalised. The resulting ``{class_name: tensor}`` dict is saved to
    ``cfg.OUTPUT.DIR/<weights name up to first '.'>_features_template.pth``.

    Args:
        cfg: project config node.

    Returns:
        None. Returns early, after printing a message, when the template data
        path is unset or missing, or when the dataloader yields no batches.
    """
    floder = cfg.TEST.METRIC.TEMPLATE_DATAPATH
    # Validate the input folder before paying for model loading / GPU transfer.
    if not floder or not os.path.exists(floder):
        print("请指定正确的数据集路径:TEMPLATE_DATAPATH")
        return None

    weight_path = cfg.MODEL.WEIGHTS
    if weight_path.endswith('.pth'):
        model = build_models(cfg)
        model.load_state_dict(torch.load(weight_path, map_location='cpu'),
                              strict=True)
    else:
        model = torch.jit.load(weight_path, map_location='cpu')
    model.eval().cuda()

    dataloader = create_infer_dataloader(floder, cfg)

    # Collect per-batch results and concatenate once at the end; calling
    # torch.cat inside the loop re-copies everything on every batch (O(n^2)).
    label_batches = []
    feature_batches = []
    with torch.no_grad():
        for image, target in tqdm(dataloader):
            feature_batches.append(model(image.cuda()))
            label_batches.append(target)

    if not label_batches:
        # Previously crashed with AttributeError on ``None.numpy()``.
        print("No samples found under TEMPLATE_DATAPATH: {}".format(floder))
        return None

    labels = torch.cat(label_batches, dim=0)
    features = torch.cat(feature_batches, dim=0)

    classes = dataloader.dataset.classes
    features_template = {}
    for label in labels.unique().tolist():
        mean_feature = features[labels == label].mean(dim=0)
        # L2-normalise; the clamp avoids division by zero for a zero vector.
        denom = mean_feature.norm(2, 0, True).clamp_min(1e-12)
        features_template[classes[label]] = mean_feature / denom

    torch.save(
        features_template,
        os.path.join(
            cfg.OUTPUT.DIR,
            os.path.basename(cfg.MODEL.WEIGHTS).split('.')[0] +
            '_features_template.pth'))
def cluster_tmp(cfg):
    """Build class templates by clustering each class's embeddings, then save them.

    Every image in the ``ImageFolder`` at ``cfg.TEST.METRIC.TEMPLATE_DATAPATH``
    is transformed with ``TRANSFORMS_REGISTRY.get(cfg.INFERENCE.TRANSFORM)``
    and embedded one at a time. The embeddings of each class are passed to
    ``cluster``, which returns the number of cluster centres and the centre
    features. The saved dict has two entries:

    - ``'labels'``: a list with the class name repeated once per centre.
    - ``'features'``: all centres stacked row-wise.

    It is written to ``cfg.OUTPUT.DIR/<weights name up to first '.'>_features_template.pth``.

    Args:
        cfg: project config node.

    Returns:
        None. Returns early, after printing a message, when the template data
        path is unset or missing.
    """
    floder = cfg.TEST.METRIC.TEMPLATE_DATAPATH
    # Validate the input folder before paying for model loading / GPU transfer.
    if not floder or not os.path.exists(floder):
        print("请指定正确的数据集路径:TEMPLATE_DATAPATH")
        return None

    weight_path = cfg.MODEL.WEIGHTS
    if weight_path.endswith('.pth'):
        model = build_models(cfg)
        model.load_state_dict(torch.load(weight_path, map_location='cpu'),
                              strict=True)
    else:
        model = torch.jit.load(weight_path, map_location='cpu')
    model.eval().cuda()

    transform = TRANSFORMS_REGISTRY.get(cfg.INFERENCE.TRANSFORM)(
        cfg.TRAIN.INPUT_WIDTH, cfg.TRAIN.INPUT_HEIGHT)
    dataset = ImageFolder(floder, transform)

    # class name -> list of per-image embeddings. They are concatenated once
    # per class below, instead of re-concatenating on every image (O(n^2)).
    per_class = {}
    with torch.no_grad():
        for image, target in tqdm(dataset):
            features = model(image.unsqueeze(dim=0).cuda())
            per_class.setdefault(dataset.classes[target], []).append(features)

    labels = []
    center_chunks = []
    for class_name, chunks in per_class.items():
        # Cluster a single class; returns the number of centres and their features.
        cluster_num, centers = cluster(torch.cat(chunks, dim=0))
        labels.extend([class_name] * cluster_num)
        center_chunks.append(centers)

    if center_chunks:
        template_features = torch.cat(center_chunks, dim=0)
    else:
        # Keep the previous empty-result shape: (0, DIM) on the GPU.
        template_features = torch.zeros((0, cfg.MODEL.METRIC.DIM)).cuda()

    features_template = {'labels': labels, 'features': template_features}
    print(features_template['features'].shape)
    torch.save(
        features_template,
        os.path.join(
            cfg.OUTPUT.DIR,
            os.path.basename(cfg.MODEL.WEIGHTS).split('.')[0] +
            '_features_template.pth'))
    def __init__(self, cfg):
        """Load the inference model described by ``cfg`` and prepare the output folder.

        A ``cfg.MODEL.WEIGHTS`` path ending in ``.pth`` is loaded as a state
        dict into ``build_models(cfg)`` with ``strict=True``; any other path is
        loaded as a TorchScript archive. The model is put in eval mode and
        moved to CUDA. If ``cfg.TEST.OUTPATH`` is set and does not exist yet,
        it is created, including any missing parent directories.

        Args:
            cfg: project config node.
        """
        self.cfg = cfg
        weight_path = cfg.MODEL.WEIGHTS
        if weight_path.endswith('.pth'):
            self.model = build_models(cfg)
            self.model.load_state_dict(torch.load(weight_path,
                                                  map_location='cpu'),
                                       strict=True)
        else:
            self.model = torch.jit.load(weight_path, map_location='cpu')
        self.model.eval().cuda()

        out_dir = self.cfg.TEST.OUTPATH
        if out_dir and not os.path.exists(out_dir):
            # makedirs also creates missing parents (os.mkdir would raise).
            # exist_ok avoids a crash if another process creates the folder
            # between the check above and this call.
            os.makedirs(out_dir, exist_ok=True)
def instance_tmp(cfg, threshold=0.99):
    """Build templates from individual images, skipping near-duplicates, and save them.

    Every image in the ``ImageFolder`` at ``cfg.TEST.METRIC.TEMPLATE_DATAPATH``
    is embedded one at a time. An embedding is kept as a new template only if
    its largest dot-product similarity to every template kept so far, across
    all classes, is below ``threshold``. The saved dict has two entries:

    - ``'labels'``: a list of class names, one per kept template.
    - ``'features'``: the kept embeddings stacked row-wise.

    It is written to ``cfg.OUTPUT.DIR/<weights name up to first '.'>_features_template.pth``.

    Args:
        cfg: project config node.
        threshold: similarity at or above which an image counts as a
            duplicate. Defaults to 0.99, the previously hard-coded value. The
            dot product acts as cosine similarity only if the model outputs
            L2-normalised embeddings — NOTE(review): confirm.

    Returns:
        None. Returns early, after printing a message, when the template data
        path is unset or missing.
    """
    floder = cfg.TEST.METRIC.TEMPLATE_DATAPATH
    # Validate the input folder before paying for model loading / GPU transfer.
    if not floder or not os.path.exists(floder):
        print("请指定正确的数据集路径:TEMPLATE_DATAPATH")
        return None

    weight_path = cfg.MODEL.WEIGHTS
    if weight_path.endswith('.pth'):
        model = build_models(cfg)
        model.load_state_dict(torch.load(weight_path, map_location='cpu'),
                              strict=True)
    else:
        model = torch.jit.load(weight_path, map_location='cpu')
    model.eval().cuda()

    transform = TRANSFORMS_REGISTRY.get(cfg.INFERENCE.TRANSFORM)(
        cfg.TRAIN.INPUT_WIDTH, cfg.TRAIN.INPUT_HEIGHT)
    dataset = ImageFolder(floder, transform)

    features_template = {}
    features_template['labels'] = []
    # Row 0 is an all-zero placeholder: it gives the first query a column to
    # compare against (similarity 0), and it is stripped before saving.
    features_template['features'] = torch.zeros(
        (1, cfg.MODEL.METRIC.DIM)).cuda()
    with torch.no_grad():
        for image, target in tqdm(dataset):
            features = model(image.unsqueeze(dim=0).cuda())
            simi = torch.matmul(features,
                                torch.t(features_template['features']))
            # Only the best match matters, so max() replaces a full sort.
            if simi.max().item() < threshold:
                features_template['labels'].append(dataset.classes[target])
                features_template['features'] = torch.cat(
                    (features_template['features'], features), dim=0)

    features_template['features'] = features_template['features'][1:]
    print(features_template['features'].shape)
    torch.save(
        features_template,
        os.path.join(
            cfg.OUTPUT.DIR,
            os.path.basename(cfg.MODEL.WEIGHTS).split('.')[0] +
            '_features_template.pth'))