def extract_video_features(self, data_loader, print_freq=1, metric=None):

        batch_time = AverageMeter()
        data_time = AverageMeter()
        f1 = torch.zeros(len(data_loader), 256)
        f2 = torch.zeros(len(data_loader), 256)
        f3 = torch.zeros(len(data_loader), 256)
        Pids, Camids = [], []
        use_gpu = torch.cuda.is_available()
        end = time.time()
        for batch_idx, (imgs, pids, camids) in enumerate(data_loader):
            data_time.update(time.time() - end)
            if use_gpu:
                imgs = Variable(imgs.cuda(), volatile=True)
            b, s, c, h, w = imgs.size()

            imgs = imgs.view(b * s, c, h, w)

            feat1, feat2, feat3 = self.cnnmodel(imgs)

            pool = 'avg'
            feat1 = feat1.view(b, s, -1)
            feat2 = feat2.view(b, s, -1)
            feat3 = feat3.view(b, s, -1)

            if pool == 'avg':
                feat1 = torch.mean(feat1, 1)
                feat2 = torch.mean(feat2, 1)
                feat3 = torch.mean(feat3, 1)

            else:
                feat1, _ = torch.max(feat1, 1)
                feat2, _ = torch.max(feat2, 1)
                feat3, _ = torch.max(feat3, 1)

            feat1 = feat1.data.cpu()
            feat2 = feat2.data.cpu()
            feat3 = feat3.data.cpu()
            # assumes one tracklet per batch, so each row holds one feature
            f1[batch_idx, :] = feat1
            f2[batch_idx, :] = feat2
            f3[batch_idx, :] = feat3
            Pids.extend(pids)
            Camids.extend(camids)
            batch_time.update(time.time() - end)
            end = time.time()

            if (batch_idx + 1) % print_freq == 0:
                print('Extract Features: [{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'num_frame {}\t'.format(batch_idx + 1, len(data_loader),
                                              batch_time.val, batch_time.avg,
                                              data_time.val, data_time.avg, s))

        Pids = np.asarray(Pids)
        Camids = np.asarray(Camids)

        print(
            "Extracted features for data set, obtained {}-by-{} matrix".format(
                f1.size(0), f1.size(1)))

        return f1, f2, f3, Pids, Camids
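
These snippets all rely on an AverageMeter helper that is never shown. A minimal sketch, assuming only the val/avg/update interface the call sites above use:

class AverageMeter(object):
    """Tracks the most recent value and a running average."""

    def __init__(self):
        self.val = 0.0    # last value passed to update()
        self.avg = 0.0    # running mean over all updates
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count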
Example 2
def extract_n_save(model,
                   data_loader,
                   args,
                   root,
                   num_cams,
                   is_detection=True,
                   use_fname=True,
                   gt_type='reid'):
    model.eval()
    print_freq = 1000
    batch_time = AverageMeter()
    data_time = AverageMeter()

    if_created = [0 for _ in range(num_cams)]
    lines = [[] for _ in range(num_cams)]

    end = time.time()
    for i, (imgs, fnames, pids, cams) in enumerate(data_loader):
        data_time.update(time.time() - end)
        cams += 1
        outputs = extract_cnn_feature(model, imgs)
        for fname, output, pid, cam in zip(fnames, outputs, pids, cams):
            if is_detection:
                pattern = re.compile(r'c(\d+)_f(\d+)')
                cam, frame = map(int, pattern.search(fname).groups())
                # f_names[cam - 1].append(fname)
                # features[cam - 1].append(output.numpy())
                line = np.concatenate(
                    [np.array([cam, 0, frame]),
                     output.numpy()])
            else:
                if use_fname:
                    pattern = re.compile(r'(\d+)_c(\d+)_f(\d+)')
                    pid, cam, frame = map(int, pattern.search(fname).groups())
                else:
                    cam, pid = cam.numpy(), pid.numpy()
                    frame = -1 * np.ones_like(pid)
                # line = output.numpy()
                line = np.concatenate(
                    [np.array([cam, pid, frame]),
                     output.numpy()])
            lines[cam - 1].append(line)
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % print_freq == 0:
            print('Extract Features: [{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                                  batch_time.val,
                                                  batch_time.avg,
                                                  data_time.val,
                                                  data_time.avg))

            if_created = save_file(lines, args, root, if_created)

            lines = [[] for _ in range(num_cams)]

    save_file(lines, args, root, if_created)
    return
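
For reference, a quick check of the 'c<cam>_f<frame>' naming convention the detection branch above assumes (the file name is hypothetical):

import re

# Hypothetical detection file name following the 'c<cam>_f<frame>' pattern.
fname = 'c3_f000471.jpg'
cam, frame = map(int, re.compile(r'c(\d+)_f(\d+)').search(fname).groups())
print(cam, frame)  # -> 3 471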
Example 3
    def extract_features(model, data_loader, eval_only, print_freq=100):
        model.eval()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        features = []
        labels = []
        cameras = []

        end = time.time()
        for i, (imgs, fnames, pids, cids) in enumerate(data_loader):
            data_time.update(time.time() - end)

            outputs = extract_cnn_feature(model, imgs, eval_only)
            for fname, output, pid, cid in zip(fnames, outputs, pids, cids):
                features.append(output)
                labels.append(int(pid.numpy()))
                cameras.append(int(cid.numpy()))

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:
                print('Extract Features: [{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      .format(i + 1, len(data_loader),
                              batch_time.val, batch_time.avg,
                              data_time.val, data_time.avg))

        output_features = torch.stack(features, 0)

        return output_features, labels, cameras
Example 4
def extract_features(model, data_loader, print_freq=50, metric=None):
    model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    features = OrderedDict()
    labels = OrderedDict()

    end = time.time()
    for i, (imgs, fnames, pids, _) in enumerate(data_loader):
        data_time.update(time.time() - end)
        outputs = extract_cnn_feature(model, imgs)
        for fname, output, pid in zip(fnames, outputs, pids):
            features[fname] = output
            labels[fname] = pid

        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % print_freq == 0:
            print('Extract Features: [{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                                  batch_time.val,
                                                  batch_time.avg,
                                                  data_time.val,
                                                  data_time.avg))

    return features, labels
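
The features come back as an OrderedDict keyed by file name. A small sketch of turning them into a matrix of pairwise squared distances, assuming every feature is a 1-D tensor of the same length:

import torch

# Stack the OrderedDict values into an N-by-D matrix and compute pairwise
# squared Euclidean distances between all extracted features.
def pairwise_sq_dist(features):
    x = torch.stack(list(features.values()), 0)
    return torch.cdist(x, x).pow(2)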
Example 5
    def extractfeature(self, data_loader):

        # progress logging
        print_freq = 10
        batch_time = AverageMeter()
        data_time = AverageMeter()
        end = time.time()

        queryfeat1 = 0
        queryfeat2 = 0
        queryfeat3 = 0
        preimgs = 0

        for i, (imgs, fnames, pids, _) in enumerate(data_loader):
            data_time.update(time.time() - end)
            imgs = Variable(imgs, volatile=True)

            if i == 0:
                query_feat1, query_feat2, query_feat3 = self.cnnmodel(imgs)
                queryfeat1 = query_feat1
                queryfeat2 = query_feat2
                queryfeat3 = query_feat3
                preimgs = imgs

            elif imgs.size(0) < data_loader.batch_size:

                flaw_batchsize = imgs.size(0)
                cat_batchsize = data_loader.batch_size - flaw_batchsize
                imgs = torch.cat((imgs, preimgs[0:cat_batchsize]), 0)
                query_feat1, query_feat2, query_feat3 = self.cnnmodel(imgs)

                query_feat1 = query_feat1[0:flaw_batchsize]
                query_feat2 = query_feat2[0:flaw_batchsize]
                query_feat3 = query_feat3[0:flaw_batchsize]
                queryfeat1 = torch.cat((queryfeat1, query_feat1), 0)
                queryfeat2 = torch.cat((queryfeat2, query_feat2), 0)
                queryfeat3 = torch.cat((queryfeat3, query_feat3), 0)
            else:
                query_feat1, query_feat2, query_feat3 = self.cnnmodel(imgs)
                queryfeat1 = torch.cat((queryfeat1, query_feat1), 0)
                queryfeat2 = torch.cat((queryfeat2, query_feat2), 0)
                queryfeat3 = torch.cat((queryfeat3, query_feat3), 0)

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:
                print('Extract Features: [{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                                      batch_time.val,
                                                      batch_time.avg,
                                                      data_time.val,
                                                      data_time.avg))

        return queryfeat1, queryfeat2, queryfeat3
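
The elif branch above pads a final short batch with images from the first batch, then truncates the model output back to the real batch size. The same trick as a standalone helper, a sketch assuming a model that requires fixed-size batches and returns a single tensor:

import torch

# Pad-and-truncate trick for models that only accept full batches: top up
# the short batch with previously seen images, run the model, then keep
# only the rows that belong to the real inputs.
def forward_full_batch(model, imgs, pad_imgs, batch_size):
    n = imgs.size(0)
    if n < batch_size:
        imgs = torch.cat((imgs, pad_imgs[:batch_size - n]), 0)
    return model(imgs)[:n]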
Example 6
    def evaluate(self, queryloader, galleryloader, query, gallery):

        query_features = self.extractfeature(queryloader)
        batch_time = AverageMeter()
        data_time = AverageMeter()
        end = time.time()
        print_freq = 50
        distmat = 0

        self.cnnmodel.eval()
        self.classifier.eval()

        for i, (imgs, _, pids, _) in enumerate(galleryloader):
            data_time.update(time.time() - end)
            imgs = Variable(imgs, volatile=True)

            if i == 0:
                gallery_feat = self.cnnmodel(imgs)
                preimgs = imgs
            elif imgs.size(0) < galleryloader.batch_size:
                flaw_batchsize = imgs.size(0)
                cat_batchsize = galleryloader.batch_size - flaw_batchsize
                imgs = torch.cat((imgs, preimgs[0:cat_batchsize]), 0)
                gallery_feat = self.cnnmodel(imgs)
                gallery_feat = gallery_feat[0:flaw_batchsize]
            else:
                gallery_feat = self.cnnmodel(imgs)

            batch_cls_encode = self.classifier(query_features, gallery_feat)
            batch_cls_size = batch_cls_encode.size()
            batch_cls_encode = batch_cls_encode.view(-1, 2)
            batch_cls_encode = F.softmax(batch_cls_encode, dim=1)
            batch_cls_encode = batch_cls_encode.view(batch_cls_size[0],
                                                     batch_cls_size[1], 2)
            batch_encode = batch_cls_encode[:, :, 0]

            if i == 0:
                distmat = batch_encode.data
            else:
                distmat = torch.cat((distmat, batch_encode.data), 1)

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:
                print('Extract Features: [{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'.format(
                          i + 1, len(galleryloader), batch_time.val,
                          batch_time.avg, data_time.val, data_time.avg))

        return evaluate_all(distmat, query=query, gallery=gallery)
Example 7
    def sim_computation(self, galleryloader, query_features):

        batch_time = AverageMeter()
        data_time = AverageMeter()
        end = time.time()
        print_freq = 50
        simmat = 0

        for i, (imgs, _, pids, _) in enumerate(galleryloader):
            data_time.update(time.time() - end)
            imgs = Variable(imgs, volatile=True)

            if i == 0:
                gallery_feat = self.cnnmodel(imgs)
                preimgs = imgs

            elif imgs.size(0) < galleryloader.batch_size:

                flaw_batchsize = imgs.size(0)
                cat_batchsize = galleryloader.batch_size - flaw_batchsize
                imgs = torch.cat((imgs, preimgs[0:cat_batchsize]), 0)
                gallery_feat = self.cnnmodel(imgs)
                gallery_feat = gallery_feat[0:flaw_batchsize]

            else:
                gallery_feat = self.cnnmodel(imgs)

            batch_cls_encode = self.classifier(query_features, gallery_feat)
            batch_cls_size = batch_cls_encode.size()
            batch_cls_encode = batch_cls_encode.view(-1, 2)
            batch_cls_encode = F.softmax(batch_cls_encode, dim=1)
            batch_cls_encode = batch_cls_encode.view(batch_cls_size[0],
                                                     batch_cls_size[1], 2)
            batch_similarity = batch_cls_encode[:, :, 1]

            if i == 0:
                simmat = batch_similarity
            else:
                simmat = torch.cat((simmat, batch_similarity), 1)

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:
                print('Extract Features: [{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'.format(
                          i + 1, len(galleryloader), batch_time.val,
                          batch_time.avg, data_time.val, data_time.avg))

        return simmat
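
Note how this relates to evaluate in Example 6: after the two-way softmax the two channels sum to one, so the distance matrix there (channel 0) and the similarity matrix here (channel 1) are complementary. A quick check:

import torch
import torch.nn.functional as F

# Channels of a 2-way softmax sum to one, so simmat == 1 - distmat.
logits = torch.randn(4, 6, 2)
p = F.softmax(logits.view(-1, 2), dim=1).view(4, 6, 2)
assert torch.allclose(p[:, :, 0] + p[:, :, 1], torch.ones(4, 6))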
Example 8
def extract_features(model,
                     data_loader,
                     print_freq=1,
                     save_name='feature.mat'):

    batch_time = AverageMeter()
    data_time = AverageMeter()

    ids = []
    cams = []
    features = []
    query_files = []
    end = time.time()
    for i, (imgs, fnames) in enumerate(data_loader):
        data_time.update(time.time() - end)

        outputs = extract_cnn_feature(model, imgs)
        #for test time augmentation
        #bs, ncrops, c, h, w = imgs.size()
        #outputs = extract_cnn_feature(model, imgs.view(-1,c,h,w))
        #outputs = outputs.view(bs,ncrops,-1).mean(1)
        for fname, output in zip(fnames, outputs):
            if fname[0] == '-':
                ids.append(-1)
                cams.append(int(fname[4]))
            else:
                ids.append(int(fname[:4]))
                cams.append(int(fname[6]))
            features.append(output.numpy())
            query_files.append(fname)
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % print_freq == 0:
            print('Extract Features: [{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                                  batch_time.val,
                                                  batch_time.avg,
                                                  data_time.val,
                                                  data_time.avg))

    return features, ids, cams, query_files
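
The fixed-index parsing above assumes Market-1501-style file names, where junk/distractor images start with '-1' and the camera id sits at a fixed offset. A quick illustration with hypothetical names:

# Hypothetical Market-1501-style names matching the fixed-index parsing above.
for fname in ['0002_c1s1_000451_03.jpg', '-1_c3s2_000000_00.jpg']:
    if fname[0] == '-':
        pid, cam = -1, int(fname[4])
    else:
        pid, cam = int(fname[:4]), int(fname[6])
    print(fname, pid, cam)  # -> pid 2 cam 1, then pid -1 cam 3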
Example 9
def extract_features(model, data_loader, print_freq=1):
    model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    features = OrderedDict()
    labels = OrderedDict()
    print('extract feature')
    end = time.time()
    for i, data in enumerate(data_loader):
        imgs, npys, fnames, pids = data.get('img'), data.get('npy'), data.get(
            'fname'), data.get('pid')
        data_time.update(time.time() - end)
        outputs = extract_cnn_feature(model, [imgs, npys])
        for fname, output, pid in zip(fnames, outputs, pids):
            features[fname] = output
            labels[fname] = pid

        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % print_freq == 0:
            print(
                'Extract Features: [{}/{}]\t'
                'Time {:.3f} ({:.3f})\t'
                'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                                batch_time.val, batch_time.avg,
                                                data_time.val, data_time.avg),
                imgs.shape)

    print(
        'Extract Features: [{}/{}]\t'
        'Time {:.3f} ({:.3f})\t'
        'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                        batch_time.val, batch_time.avg,
                                        data_time.val, data_time.avg),
        imgs.shape)
    print(
        f'{len(features)} features, each of len {next(iter(features.values())).shape[0]}'
    )
    return features, labels
Example 10
def extract_embeddings(
    model,
    data_loader,
    print_freq=10,
):
    model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    embeddings = []
    print('extract embedding')
    end = time.time()
    for i, inputs in enumerate(data_loader):
        data_time.update(time.time() - end)
        outputs = extract_cnn_embeddings(model, inputs)
        # print(outputs.shape)
        embeddings.append(outputs)
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % print_freq == 0:
            print('Extract Embedding: [{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                                  batch_time.val,
                                                  batch_time.avg,
                                                  data_time.val,
                                                  data_time.avg))

    print('Extract embedding: [{}/{}]\t'
          'Time {:.3f} ({:.3f})\t'
          'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                          batch_time.val, batch_time.avg,
                                          data_time.val, data_time.avg))
    res = torch.cat(embeddings)
    print(res.shape)
    return res
Example 11
def inference(model, query_loader, gallery_loader, use_gpu):
    batch_time = AverageMeter()

    model.eval()

    with torch.no_grad():
        qf = []
        for batch_idx, (imgs, _) in enumerate(query_loader):
            if use_gpu:
                imgs = imgs.cuda()
            end = time.time()
            features = extract_cnn_feature(model, imgs)
            batch_time.update(time.time() - end)

            features = features.data.cpu()
            qf.extend(list(features))

        gf, g_paths = [], []
        for batch_idx, (imgs, path) in enumerate(gallery_loader):
            if use_gpu:
                imgs = imgs.cuda()

            end = time.time()
            features = extract_cnn_feature(model, imgs)

            batch_time.update(time.time() - end)

            features = features.data.cpu()
            gf.extend(list(features))
            g_paths.extend(list(path))

    print('=> BatchTime(s): {:.3f}'.format(batch_time.avg))

    x = torch.stack(qf, 0)
    y = torch.stack(gf, 0)
    m, n = x.size(0), y.size(0)
    x = x.view(m, -1)
    y = y.view(n, -1)
    # squared Euclidean distances: ||x||^2 + ||y||^2 - 2 * x @ y.T
    dist = torch.pow(x, 2).sum(1).unsqueeze(1).expand(m, n) + \
           torch.pow(y, 2).sum(1).unsqueeze(1).expand(n, m).t()
    dist.addmm_(x, y.t(), beta=1, alpha=-2)

    return dist
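
A quick sanity check that the expand/addmm_ pattern above really computes squared Euclidean distances (compare against torch.cdist):

import torch

# dist = ||x||^2 + ||y||^2 - 2 * x @ y.T, i.e. squared Euclidean distance.
x, y = torch.randn(5, 8), torch.randn(7, 8)
dist = torch.pow(x, 2).sum(1).unsqueeze(1).expand(5, 7) + \
       torch.pow(y, 2).sum(1).unsqueeze(1).expand(7, 5).t()
dist.addmm_(x, y.t(), beta=1, alpha=-2)
assert torch.allclose(dist, torch.cdist(x, y).pow(2), atol=1e-5)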
Example 12
def trainMeta(meta_train_loader, meta_test_loader, net, noise, epoch, optimizer,
              centroids, metaCentroids, normalize):
    global args
    noise.requires_grad = True
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    mean = torch.Tensor(normalize.mean).view(1, 3, 1, 1).cuda()
    std = torch.Tensor(normalize.std).view(1, 3, 1, 1).cuda()

    net.eval()

    end = time.time()
    optimizer.zero_grad()
    optimizer.rescale()
    for i, ((input, _, pid, _), (metaTest, _, _, _)) in enumerate(zip(meta_train_loader, meta_test_loader)):
        # measure data loading time.
        data_time.update(time.time() - end)
        net.zero_grad()  # 'model' was undefined here; the network is 'net'
        input = input.cuda()
        metaTest = metaTest.cuda()

        # one step update
        with torch.no_grad():
            normInput = (input - mean) / std
            feature, realPred = net(normInput)
            scores = centroids.mm(F.normalize(feature.t(), p=2, dim=0))
            # scores = centroids.mm(feature.t())
            realLab = scores.max(0, keepdim=True)[1]
            _, ranks = torch.sort(scores, dim=0, descending=True)
            pos_i = ranks[0, :]
            neg_i = ranks[-1, :]
        neg_feature = centroids[neg_i, :]  # centroids--512*2048
        pos_feature = centroids[pos_i, :]

        current_noise = noise
        current_noise = F.interpolate(
            current_noise.unsqueeze(0),
            mode=MODE, size=tuple(input.shape[-2:]), align_corners=True,
        ).squeeze()
        perturbed_input = torch.clamp(input + current_noise, 0, 1)
        perturbed_input_norm = (perturbed_input - mean) / std
        perturbed_feature = net(perturbed_input_norm)[0]

        optimizer.zero_grad()

        pair_loss = 10 * F.triplet_margin_loss(perturbed_feature, neg_feature, pos_feature, 0.5)

        # clsScore = centroids.mm(perturbed_feature.t()).t()
        # oneHotReal = torch.zeros(clsScore.shape).cuda()
        # oneHotReal.scatter_(1, predLab.view(-1, 1), float(1))
        # oneHotReal = F.normalize(1 - oneHotReal, p=1, dim=1)
        # label_loss = -(F.log_softmax(clsScore, 1) * oneHotReal).sum(1).mean()

        fakePred = centroids.mm(perturbed_feature.t()).t()

        oneHotReal = torch.zeros(scores.t().shape).cuda()
        oneHotReal.scatter_(1, realLab.view(-1, 1), float(1))
        label_loss = F.relu(
            (fakePred * oneHotReal).sum(1).mean()
            - (fakePred * (1 - oneHotReal)).max(1)[0].mean()
        )

        pair_loss = pair_loss.view(1)

        loss = pair_loss + label_loss

        # maml one step
        grad = torch.autograd.grad(loss, noise, create_graph=True)[0]
        noiseOneStep = keepGradUpdate(noise, optimizer, grad, MAX_EPS)

        # maml test
        newNoise = F.interpolate(
            noiseOneStep.unsqueeze(0), mode=MODE,
            size=tuple(metaTest.shape[-2:]), align_corners=True,
        ).squeeze()

        with torch.no_grad():
            normMte = (metaTest - mean) / std
            mteFeat = net(normMte)[0]
            scores = metaCentroids.mm(F.normalize(mteFeat.t(), p=2, dim=0))
            # scores = metaCentroids.mm(mteFeat.detach().t())
            metaLab = scores.max(0, keepdim=True)[1]
            _, ranks = torch.sort(scores, dim=0, descending=True)
            pos_i = ranks[0, :]
            neg_i = ranks[-1, :]
        neg_mte_feat = metaCentroids[neg_i, :]  # centroids--512*2048
        pos_mte_feat = metaCentroids[pos_i, :]

        perMteInput = torch.clamp(metaTest + newNoise, 0, 1)
        normPerMteInput = (perMteInput - mean) / std
        normMteFeat = net(normPerMteInput)[0]

        lossMeta = 10 * F.triplet_margin_loss(
            normMteFeat, neg_mte_feat, pos_mte_feat, 0.5
        )

        fakePredMeta = metaCentroids.mm(normMteFeat.t()).t()

        oneHotRealMeta = torch.zeros(scores.t().shape).cuda()
        oneHotRealMeta.scatter_(1, metaLab.view(-1, 1), float(1))
        labelLossMeta = F.relu(
            (fakePredMeta * oneHotRealMeta).sum(1).mean()
            - (fakePredMeta * (1 - oneHotRealMeta)).max(1)[0].mean()
        )

        finalLoss = lossMeta + labelLossMeta + pair_loss + label_loss

        finalLoss.backward()

        losses.update(pair_loss.item())
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print(
                ">> Train: [{0}][{1}/{2}]\t"
                "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                "LossMeta {lossMeta:.4f}\t"
                "Noise l2: {noise:.4f}".format(
                    epoch + 1,
                    i, len(meta_train_loader),
                    batch_time=batch_time,
                    data_time=data_time,
                    loss=losses, lossMeta=lossMeta.item(),
                    noise=noise.norm(),
                )
            )

    noise.requires_grad = False
    print(f"Train {epoch}: Loss: {losses.avg}")
    return losses.avg, noise
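
keepGradUpdate is not shown in any of these snippets. A minimal sketch of what the MAML-style one-step update presumably does, assuming a single learning rate lr and an L-infinity budget max_eps (both hypothetical; the real helper takes the optimizer itself):

import torch

# Hypothetical stand-in for keepGradUpdate: one differentiable gradient
# step on the noise, projected back into the L-infinity ball of radius
# max_eps. create_graph=True in the caller keeps this step differentiable.
def keep_grad_update(noise, lr, grad, max_eps):
    return torch.clamp(noise - lr * grad, -max_eps, max_eps)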
Example 13
def run():
    np.random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    cudnn.benchmark = True
    data_dir = opt.data_dir

    # Redirect print to both console and log file
    #if not opt.evaluate:
    #    sys.stdout = Logger(osp.join(opt.logs_dir, 'log_l2_per.txt'))
    # Create data loaders
    def readlist(path):
        lines = []
        with open(path, 'r') as f:
            data = f.readlines()

        #pdb.set_trace()
        for line in data:
            name, pid, cam = line.split()
            lines.append((name, int(pid), int(cam)))
        return lines

    # Load data list for wuzhen
    if osp.exists(osp.join(data_dir, 'train.txt')):
        train_list = readlist(osp.join(data_dir, 'train.txt'))
    else:
        print("The training list doesn't exist")

    if osp.exists(osp.join(data_dir, 'val.txt')):
        val_list = readlist(osp.join(data_dir, 'val.txt'))
    else:
        print("The validation list doesn't exist")

    if osp.exists(osp.join(data_dir, 'query.txt')):
        query_list = readlist(osp.join(data_dir, 'query.txt'))
    else:
        print("The query.txt doesn't exist")

    if osp.exists(osp.join(data_dir, 'gallery.txt')):
        gallery_list = readlist(osp.join(data_dir, 'gallery.txt'))
    else:
        print("The gallery.txt doesn't exist")

    if opt.height is None or opt.width is None:
        opt.height, opt.width = (144, 56) if opt.arch == 'inception' else \
                                  (256, 128)

    train_loader, val_loader, test_loader = \
        get_data(opt.split, data_dir, opt.height,
                 opt.width, opt.batchSize, opt.workers,
                 opt.combine_trainval, train_list, val_list, query_list, gallery_list)
    # Create model
    # ori 14514; clear 12654,  16645
    densenet = densenet121(num_classes=20330, num_features=256)
    start_epoch = best_top1 = 0
    if opt.resume:
        #checkpoint = load_checkpoint(opt.resume)
        #densenet.load_state_dict(checkpoint['state_dict'])
        densenet.load_state_dict(torch.load(opt.resume))
        start_epoch = opt.resume_epoch
        print("=> Finetune Start epoch {} ".format(start_epoch))
    if opt.pretrained_model:
        print('Start load params...')
        load_params(densenet, opt.pretrained_model)
    # Load from checkpoint
    #densenet = nn.DataParallel(densenet).cuda()
    metric = DistanceMetric(algorithm=opt.dist_metric)
    print('densenet')
    show_info(densenet, with_arch=True, with_grad=False)
    netG = netg()
    print('netG')
    show_info(netG, with_arch=True, with_grad=False)
    netG.apply(weights_init)
    if opt.netG != '':
        netG.load_state_dict(torch.load(opt.netG))
        #load_params(netG,opt.netG)
    if opt.cuda:
        netG = netG.cuda()
        densenet = densenet.cuda()
    perceptionloss = perception_loss(cuda=opt.cuda)
    l2loss = l2_loss(cuda=opt.cuda)
    #    discriloss=discri_loss(cuda = opt.cuda,batchsize = opt.batchSize,height = \
    #                           opt.height,width = opt.width,lr = opt.lr,step_size = \
    #                           opt.step_size,decay_step = opt.decay_step )
    # Evaluator
    evaluator = Evaluator(densenet)
    #    if opt.evaluate:
    metric.train(densenet, train_loader)
    print("Validation:")
    evaluator.evaluate(val_loader, val_list, val_list, metric)
    print("Test:")
    evaluator.evaluate(test_loader, query_list, gallery_list, metric)
    #    return
    # Criterion
    #    criterion = nn.CrossEntropyLoss(ignore_index=-100).cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    # Optimizer
    param_groups = []
    mult_lr(densenet, param_groups)
    optimizer = optim.SGD(param_groups,
                          lr=opt.lr,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)
    #    optimizer = optim.Adam(param_groups, lr=opt.lr, betas=(opt.beta1, 0.9))

    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr,
                            betas=(opt.beta1, 0.9))

    # Start training
    for epoch in range(start_epoch, opt.epochs):
        adjust_lr(optimizer, epoch)
        adjust_lr(optimizerG, epoch)
        #discriloss.adjust_lr(epoch)
        losses = AverageMeter()
        precisions = AverageMeter()
        densenet.train()
        for i, data in enumerate(train_loader):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            real_cpu, _, pids, _ = data
            if opt.cuda:
                real_cpu = real_cpu.cuda()
                targets = Variable(pids.cuda())
            # 'input' was never allocated in this snippet; use the batch directly
            inputv = Variable(real_cpu)
            outputs, output_dense, _ = densenet(inputv)
            fake = netG(output_dense)
            fake = fake * 3
            #discriloss(fake = fake, inputv = inputv, i = i)
            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            if i % opt.CRITIC_ITERS == 0:
                netG.zero_grad()
                optimizer.zero_grad()
                #loss_discri = discriloss.gloss(fake = fake)
                loss_l2 = l2loss(fake=fake, inputv=inputv)
                loss_perception = perceptionloss(fake=fake, inputv=inputv)
                loss_classify = criterion(outputs, targets)
                prec, = accuracy(outputs.data, targets.data)
                prec = prec[0]
                losses.update(loss_classify.item(), targets.size(0))
                precisions.update(prec, targets.size(0))
                loss = loss_classify + 0 * loss_l2 + 0 * loss_perception
                #                loss = loss_discri
                loss.backward()
                optimizerG.step()
                optimizer.step()
                # log only on iterations where the losses were computed
                print('[%d/%d][%d/%d] Loss_l2: %.4f Loss_perception: %.4f ' %
                      (epoch, opt.epochs, i, len(train_loader),
                       loss_l2.item(), loss_perception.item()))
                print('Loss {}({})\t'
                      'Prec {}({})\t'.format(losses.val, losses.avg,
                                             precisions.val, precisions.avg))
            if i % 100 == 0:
                vutils.save_image(real_cpu,
                                  '%s/real_samples.png' % opt.outf,
                                  normalize=True)
                outputs, output_dense, _ = densenet(x=inputv)
                fake = netG(output_dense)
                fake = fake * 3
                vutils.save_image(fake.data,
                                  '%s/fake_samples_epoch_%03d.png' %
                                  (opt.outf, epoch),
                                  normalize=True)
        show_info(densenet, with_arch=False, with_grad=True)
        show_info(netG, with_arch=False, with_grad=True)
        if epoch % 5 == 0:
            torch.save(densenet.state_dict(),
                       '%s/densenet_epoch_%d.pth' % (opt.outf, epoch))
            torch.save(netG.state_dict(),
                       '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
        if epoch < opt.start_save:
            continue
        top1 = evaluator.evaluate(val_loader, val_list, val_list)

        is_best = top1 > best_top1
        best_top1 = max(top1, best_top1)
        save_checkpoint(
            {
                'state_dict': densenet.state_dict(),
                'epoch': epoch + 1,
                'best_top1': best_top1,
            },
            is_best,
            fpath=osp.join(opt.logs_dir, 'checkpoint.pth.tar'))

        print('\n * Finished epoch {:3d}  top1: {:5.1%}  best: {:5.1%}{}\n'.
              format(epoch, top1, best_top1, ' *' if is_best else ''))
        if (epoch + 1) % 5 == 0:
            print('Test model: \n')
            evaluator.evaluate(test_loader, query_list, gallery_list)
            model_name = 'epoch_' + str(epoch) + '.pth.tar'
            torch.save({'state_dict': densenet.state_dict()},
                       osp.join(opt.logs_dir, model_name))
    # Final test
    print('Test with best model:')
    checkpoint = load_checkpoint(osp.join(opt.logs_dir, 'model_best.pth.tar'))
    densenet.load_state_dict(checkpoint['state_dict'])
    print('best epoch: ', checkpoint['epoch'])
    metric.train(densenet, train_loader)
    evaluator.evaluate(test_loader, query_list, gallery_list, metric)
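
For reference, readlist above expects whitespace-separated '<name> <pid> <cam>' records; a hypothetical line and its parsed tuple:

# Hypothetical record for readlist: '<name> <pid> <cam>' per line.
line = '0001_c1_f0001.jpg 1 1'
name, pid, cam = line.split()
print((name, int(pid), int(cam)))  # -> ('0001_c1_f0001.jpg', 1, 1)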
Example 14
def trainMeta(meta_train_loader, meta_test_loader, net, epoch, normalize,
              perturbation):
    global args
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    mean = torch.Tensor(normalize.mean).view(1, 3, 1, 1).cuda()
    std = torch.Tensor(normalize.std).view(1, 3, 1, 1).cuda()

    net.eval()
    end = time.time()
    perturbation.zero_grad()
    optimizer.zero_grad()

    for i, (input, _, pids, _) in enumerate(meta_train_loader):
        metaTest, _, mtepids, _ = meta_test_loader.next()

        data_time.update(time.time() - end)
        net.zero_grad()  # 'model' was undefined here; the network is 'net'
        input = input.cuda()
        metaTest = metaTest.cuda()

        # one step update
        with torch.no_grad():
            norm_output = (input - mean) / std
            feature = net(norm_output)[0]

        current_noise = perturbation
        perturbed_input = current_noise(input)
        perturbed_input_clamp = torch.clamp(perturbed_input, 0, 1)
        perturbed_input_norm = (perturbed_input_clamp - mean) / std
        perturbed_feature = net(perturbed_input_norm)[0]

        optimizer.zero_grad()

        loss = TripletLoss()(feature, pids.cuda(), perturbed_feature)

        # maml one step
        noise = perturbation.parameters()

        grad = torch.autograd.grad(loss, noise, create_graph=True)

        noiseOneStep = keepGradUpdate(perturbation, optimizer, grad)

        perturbation_new = noiseOneStep

        #maml test
        with torch.no_grad():
            normMte = (metaTest - mean) / std
            mteFeat = net(normMte)[0]

        perMteInput = perturbation_new(metaTest)
        perMteInput = torch.clamp(perMteInput, 0, 1)
        normPerMteInput = (perMteInput - mean) / std
        normMteFeat = net(normPerMteInput)[0]

        mteloss = TripletLoss()(mteFeat, mtepids.cuda(), normMteFeat)

        finalLoss = loss + mteloss
        finalLoss.backward()

        losses.update(loss.item())
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print(">> Train: [{0}][{1}/{2}]\t"
                  "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                  "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                  "Loss {loss.val:.4f} ({loss.avg:.4f})".format(
                      epoch + 1,
                      i,
                      len(meta_train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses))
    print(f"Train {epoch}: Loss: {losses.avg}")
    for p in perturbation.parameters():
        p.requires_grad = False
    return losses.avg, perturbation
Example 15
    def single_train(self, model, criterion, optimizer, trial):
        model.train()
        criterion.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        precisions = AverageMeter()

        iters = 0
        start_time = time.time()
        end = time.time()

        for ep in range(self.num_epochs):
            for i, inputs in enumerate(self.data_loader):
                data_time.update(time.time() - end)

                iters += 1
                inputs, targets = self._parse_data(inputs)
                loss, acc = self._forward(model, criterion, inputs, targets)

                losses.update(loss.item(), targets.size(0))
                precisions.update(acc, targets.size(0))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                batch_time.update(time.time() - end)
                end = time.time()

                print('Trial {}: epoch [{}][{}/{}]. '
                      'Time: {:.3f} ({:.3f}). '
                      'Data: {:.3f} ({:.3f}). '
                      'Metric: {:.4f} ({:.4f}). '
                      'Loss: {:.3f} ({:.3f}). '
                      'Prec: {:.2%} ({:.2%}).'.format(
                          trial + 1, ep, i + 1,
                          min(self.max_steps, len(self.data_loader)),
                          batch_time.val, batch_time.avg, data_time.val,
                          data_time.avg, precisions.val / losses.val,
                          precisions.avg / losses.avg, losses.val, losses.avg,
                          precisions.val, precisions.avg),
                      end='\r',
                      file=sys.stdout.console)

                if iters == self.max_steps - 1:
                    break

            if iters == self.max_steps - 1:
                break

        loss = losses.avg
        acc = precisions.avg

        print(
            '* Trial %d. Metric: %.4f. Loss: %.3f. Acc: %.2f%%. Training time: %.0f seconds.                                                     \n'
            %
            (trial + 1, acc / loss, loss, acc * 100, time.time() - start_time))
        return loss, acc
Example 16
def train(args, model, train_loader, start_epoch):
    """Train classifier for source domain."""
    ####################
    # 1. setup network #
    ####################

    base_param_ids = set(map(id, model.module.base.parameters()))
    new_params = [p for p in model.parameters() if id(p) not in base_param_ids]

    param_groups = [{
        'params': model.module.base.parameters(),
        'lr_mult': 0.1
    }, {
        'params': new_params,
        'lr_mult': 1.0
    }]

    optimizer = optim.Adam(param_groups, lr=args.lr)

    # Criterion
    criterion = CapsuleLoss()
    criterion2 = nn.CrossEntropyLoss().cuda()
    criterion3 = nn.CrossEntropyLoss().cuda()

    # Schedule learning rate
    def adjust_lr(epoch):
        lr = args.lr * (0.1**(epoch // args.step_size))

        for g in optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    ####################
    # 2. train network #
    ####################

    for epoch in range(start_epoch, args.epochs):
        adjust_lr(epoch)

        print_freq = 1
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        precisions_id = AverageMeter()
        precisions_id2 = AverageMeter()
        precisions_id3 = AverageMeter()

        end = time.time()

        for i, inputs in enumerate(train_loader):
            data_time.update(time.time() - end)

            imgs, pids, imgname = inputs

            inputs = Variable(imgs.cuda())
            labels = torch.eye(args.true_class).index_select(dim=0, index=pids)
            labels = Variable(labels.cuda())
            targets = Variable(pids.cuda())
            results, y, y2 = model(inputs)

            loss1 = criterion(imgs, labels, results)
            loss2 = criterion2(y, targets)
            loss3 = criterion3(y2, targets)

            prec, = accuracy_capsule(results.data, targets.data,
                                     args.true_class)
            prec = prec[0]

            prec2, = accuracy(y.data, targets.data)
            prec2 = prec2[0]

            prec3, = accuracy(y2.data, targets.data)
            prec3 = prec3[0]

            loss = loss1 + 0.5 * loss2 + 0.5 * loss3

            # update the re-id model
            losses.update(loss.data.item(), targets.size(0))

            precisions_id.update(prec, targets.size(0))
            precisions_id2.update(prec2, targets.size(0))
            precisions_id3.update(prec3, targets.size(0))

            optimizer.zero_grad()

            loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'Loss {:.3f} ({:.3f})\t'
                      'Prec_capslue {:.2%} ({:.2%})\t'
                      'Prec_ID2 {:.2%} ({:.2%})\t'
                      'Prec_ID3 {:.2%} ({:.2%})\t'.format(
                          epoch, i + 1, len(train_loader), batch_time.val,
                          batch_time.avg, data_time.val, data_time.avg,
                          losses.val, losses.avg, precisions_id.val,
                          precisions_id.avg, precisions_id2.val,
                          precisions_id2.avg, precisions_id3.val,
                          precisions_id3.avg))

        # save model
        if (epoch + 1) % 5 == 0:
            save_checkpoint(
                {
                    'state_dict': model.state_dict(),
                    'epoch': epoch + 1,
                },
                fpath=osp.join(args.logs_dir,
                               'checkpoint' + str(epoch + 1) + '.pth.tar'))

        print('\n * Finished epoch {:3d} \n'.format(epoch))

    return model
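
One small note on the loss setup above: the torch.eye(...).index_select(...) line is just one-hot encoding of the target ids. A quick equivalence check using F.one_hot:

import torch
import torch.nn.functional as F

# The eye/index_select construction above is plain one-hot encoding.
pids = torch.tensor([0, 2, 1])
a = torch.eye(4).index_select(dim=0, index=pids)
b = F.one_hot(pids, num_classes=4).float()
assert torch.equal(a, b)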
Example 17
    def compute_distmat(self, queryloader, galleryloader):
        self.cnnmodel.eval()
        self.classifier.eval()

        queryfeat1, queryfeat2, queryfeat3 = self.extractfeature(queryloader)
        batch_time = AverageMeter()
        data_time = AverageMeter()
        end = time.time()
        print_freq = 50
        distmat = 0

        for i, (imgs, _, pids, _) in enumerate(galleryloader):
            data_time.update(time.time() - end)
            imgs = Variable(imgs, volatile=True)

            if i == 0:
                gallery_feat1, gallery_feat2, gallery_feat3 = self.cnnmodel(
                    imgs)
                preimgs = imgs
            elif imgs.size(0) < galleryloader.batch_size:
                flaw_batchsize = imgs.size(0)
                cat_batchsize = galleryloader.batch_size - flaw_batchsize
                imgs = torch.cat((imgs, preimgs[0:cat_batchsize]), 0)
                gallery_feat1, gallery_feat2, gallery_feat3 = self.cnnmodel(
                    imgs)
                gallery_feat1 = gallery_feat1[0:flaw_batchsize]
                gallery_feat2 = gallery_feat2[0:flaw_batchsize]
                gallery_feat3 = gallery_feat3[0:flaw_batchsize]
            else:
                gallery_feat1, gallery_feat2, gallery_feat3 = self.cnnmodel(
                    imgs)

            batch_cls_encode1, batch_cls_encode2, batch_cls_encode3 = self.classifier(
                queryfeat1, gallery_feat1, queryfeat2, gallery_feat2,
                queryfeat3, gallery_feat3)

            batch_cls_size1 = batch_cls_encode1.size()
            batch_cls_encode1 = batch_cls_encode1.view(-1, 2)
            batch_cls_encode1 = F.softmax(batch_cls_encode1, 1)
            batch_cls_encode1 = batch_cls_encode1.view(batch_cls_size1[0],
                                                       batch_cls_size1[1], 2)
            batch_cls_encode1 = batch_cls_encode1[:, :, 0]

            batch_cls_size2 = batch_cls_encode2.size()
            batch_cls_encode2 = batch_cls_encode2.view(-1, 2)
            batch_cls_encode2 = F.softmax(batch_cls_encode2, 1)
            batch_cls_encode2 = batch_cls_encode2.view(batch_cls_size2[0],
                                                       batch_cls_size2[1], 2)
            batch_cls_encode2 = batch_cls_encode2[:, :, 0]

            batch_cls_size3 = batch_cls_encode3.size()
            batch_cls_encode3 = batch_cls_encode3.view(-1, 2)
            batch_cls_encode3 = F.softmax(batch_cls_encode3, 1)
            batch_cls_encode3 = batch_cls_encode3.view(batch_cls_size3[0],
                                                       batch_cls_size3[1], 2)
            batch_cls_encode3 = batch_cls_encode3[:, :, 0]

            batch_cls_encode = batch_cls_encode1 * self.alphas[
                0] + batch_cls_encode2 * self.alphas[
                    1] + batch_cls_encode3 * self.alphas[2]
            if i == 0:
                distmat = batch_cls_encode.data
            else:
                distmat = torch.cat((distmat, batch_cls_encode.data), 1)

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:
                print('Extract Features: [{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'.format(
                          i + 1, len(galleryloader), batch_time.val,
                          batch_time.avg, data_time.val, data_time.avg))
        return distmat
Example 18
if len(gpus) < 2:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

# If multi gpus
if len(gpus) > 1:
    model = torch.nn.DataParallel(model, range(len(args.gpus))).cuda()

if args.pretrained_weights_dir:
    model = torch.load(args.pretrained_weights_dir)

else:
    model = torch.load(os.path.join(exp_dir, 'model.pth'))

model.eval()
batch_time = AverageMeter()
data_time = AverageMeter()

features = OrderedDict()
labels = OrderedDict()

end = time.time()
print('Extracting features... This may take a while...')
with torch.no_grad():
    for i, (imgs, fnames, pids, _) in enumerate(test_loader):
        data_time.update(time.time() - end)

        imgs_flip = torch.flip(imgs, [3])
        final_feat_list, _, _, _, _, = model(Variable(imgs).cuda())
        final_feat_list_flip, _, _, _, _ = model(Variable(imgs_flip).cuda())
Example 19

def main(args):

    args = parser.parse_args()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.benchmark = True

    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'Part_log.txt'))

    # Create data loaders
    assert args.num_instances > 1, "num_instances should be greater than 1"
    assert args.batch_size % args.num_instances == 0, \
        'num_instances should divide batch_size'
    if args.height is None or args.width is None:
        args.height, args.width = (144, 56) if args.arch == 'inception' else \
                                  (256, 128)
    dataset, num_classes, train_loader, val_loader, test_loader = \
        get_data(args.dataset, args.split, args.data_dir, args.height,
                 args.width, args.batch_size, args.num_instances, args.workers,
                 args.combine_trainval)

    # Create model
    # Hacking here to let the classifier be the last feature embedding layer
    # Net structure: avgpool -> FC(1024) -> FC(args.features)

    model = models.create(args.arch,
                          num_features=512,
                          pretrained=True,
                          dropout=args.dropout,
                          num_classes=args.features,
                          embedding=False)

    # Load from checkpoint
    start_epoch = best_top1 = 0
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        #start_epoch = checkpoint['epoch']
        start_epoch = 0
        best_top1 = checkpoint['best_top1']
        print("=> Start epoch {}  best top1 {:.1%}".format(
            start_epoch, best_top1))
    model = nn.DataParallel(model)
    #model = nn.DataParallel(model).cpu()
    if args.cuda:
        model.cuda()
    # Distance metric
    metric = DistanceMetric(algorithm=args.dist_metric)

    # Evaluator
    evaluator = Evaluator(model)
    if args.evaluate:
        metric.train(model, train_loader)
        print("Validation:")
        evaluator.evaluate(val_loader, dataset.val, dataset.val, metric)
        print("Test:")
        evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric)
        return

    # Criterion
    # criterion = TripletLoss(margin=args.margin).cpu()
    criterion = TripletLoss(margin=args.margin)
    if args.cuda:
        criterion.cuda()
    #
    # Optimizer

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    '''
    optimizer = torch.optim.Adam([{'params': model.module.w1.parameters(), 'lr': 1e-6, 'weight_decay': 5e-4},
                                  {'params': model.module.w2.parameters(), 'lr': 1e-6, 'weight_decay': 5e-4},
                                  {'params': model.module.w3.parameters(), 'lr': 1e-6, 'weight_decay': 5e-4},
                                  {'params': model.module.w4.parameters(), 'lr': 1e-6, 'weight_decay': 5e-4},
                                  {'params': model.module.w5.parameters(), 'lr': 1e-6, 'weight_decay': 5e-4}], lr=args.lr,
                                 weight_decay=args.weight_decay)'''

    # Trainer
    trainer = Trainer(model, criterion)

    # Schedule learning rate
    def adjust_lr(epoch):
        lr = args.lr if epoch <= 100 else \
            args.lr * (0.001 ** ((epoch - 100) / 50.0))
        for g in optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    # Start training
    accs_market = AverageMeter()
    accs_cuhk03 = AverageMeter()
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, train_loader, optimizer)
        if epoch < args.start_save:
            continue
        top1, cuhk03_top1, market_top1 = evaluator.evaluate(
            val_loader, dataset.val, dataset.val)
        accs_market.update(market_top1, args.batch_size * 40)
        accs_cuhk03.update(cuhk03_top1, args.batch_size * 40)

        plotter.plot('acc', 'test-multishot', epoch, market_top1)
        plotter.plot('acc', 'test-singleshot', epoch, cuhk03_top1)

        is_best = top1 > best_top1
        best_top1 = max(top1, best_top1)
        save_checkpoint(
            {
                'state_dict': model.module.state_dict(),
                'epoch': epoch + 1,
                'best_top1': best_top1,
            },
            is_best,
            fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))

        print('\n * Finished epoch {:3d}  top1: {:5.1%}  best: {:5.1%}{}\n'.
              format(epoch, top1, best_top1, ' *' if is_best else ''))

    # Final test
    print('Test with best model:')
    checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
    model.module.load_state_dict(checkpoint['state_dict'])
    metric.train(model, train_loader)
    evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric)
Example 20
    def train(self, epoch, data_loader, optimizer):
        self.model.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        precisions = AverageMeter()
        precisions1 = AverageMeter()
        precisions2 = AverageMeter()

        end = time.time()
        for i, inputs in enumerate(data_loader):
            data_time.update(time.time() - end)
            inputs, targets = self._parse_data(inputs)
            loss, prec_oim, loss_score, prec_finalscore = self._forward(
                inputs, targets, i)

            losses.update(loss.data.item(), targets.size(0))
            precisions.update(prec_oim, targets.size(0))
            precisions1.update(loss_score.data.item(), targets.size(0))
            precisions2.update(prec_finalscore, targets.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()
            print_freq = 50
            if (i + 1) % print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Loss {:.3f} ({:.3f})\t'
                      'prec_oim {:.2%} ({:.2%})\t'
                      'prec_score {:.2%} ({:.2%})\t'
                      'prec_finalscore(total) {:.2%} ({:.2%})\t'.format(
                          epoch, i + 1, len(data_loader), losses.val,
                          losses.avg, precisions.val, precisions.avg,
                          precisions1.val, precisions1.avg, precisions2.val,
                          precisions2.avg))
Example 21
# If multi gpus
if len(gpus) > 1:
    model = torch.nn.DataParallel(model, range(len(args.gpus))).cuda()

# Training
for epoch in range(1, args.epochs+1):
    
    adjust_lr_staircase(
        optimizer.param_groups,
        [args.base_lr, args.lr],
        epoch,
        decay_schedule,
        args.staircase_decay_multiply_factor)
    
    model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    precisions = AverageMeter()
    
    end = time.time()
    for i, inputs in enumerate(train_loader):
        data_time.update(time.time() - end)

        (imgs, _, labels, _) = inputs
        inputs = Variable(imgs).float().cuda()
        labels = Variable(labels).cuda()


        optimizer.zero_grad()
        final_feat_list, logits_local_rest_list, logits_local_list, logits_rest_list, logits_global_list = model(inputs)
Example 22
def train(train_loader, net, noise, epoch, optimizer, centroids, normalize):
    global args
    noise.requires_grad = True
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    mean = torch.Tensor(normalize.mean).view(1, 3, 1, 1).cuda()
    std = torch.Tensor(normalize.std).view(1, 3, 1, 1).cuda()

    net.eval()

    end = time.time()
    optimizer.zero_grad()
    optimizer.rescale()
    for i, (input, _, _, _) in enumerate(train_loader):
        # measure data loading time.
        data_time.update(time.time() - end)
        net.zero_grad()  # 'model' was undefined here; the network is 'net'
        input = input.cuda()
        with torch.no_grad():
            norm_output = (input - mean) / std
            feature = net(norm_output)[0]

            scores = centroids.mm(F.normalize(feature.t(), p=2, dim=0))
            realLab = scores.max(0, keepdim=True)[1]
            _, ranks = torch.sort(scores, dim=0, descending=True)

            pos_i = ranks[0, :]
            neg_i = ranks[-1, :]
        neg_feature = centroids[neg_i, :].view(-1, 2048)  # centroids--512*2048
        pos_feature = centroids[pos_i, :].view(-1, 2048)

        current_noise = noise
        current_noise = F.interpolate(
            current_noise.unsqueeze(0),
            mode=MODE,
            size=tuple(input.shape[-2:]),
            align_corners=True,
        ).squeeze()
        perturbed_input = torch.clamp(input + current_noise, 0, 1)
        perturbed_input_norm = (perturbed_input - mean) / std
        perturbed_feature = net(perturbed_input_norm)[0]

        optimizer.zero_grad()

        pair_loss = 10 * F.triplet_margin_loss(perturbed_feature, neg_feature,
                                               pos_feature, 0.5)

        fakePred = centroids.mm(perturbed_feature.t()).t()

        oneHotReal = torch.zeros(scores.t().shape).cuda()
        oneHotReal.scatter_(1, realLab.view(-1, 1), float(1))
        label_loss = F.relu((fakePred * oneHotReal).sum(1).mean() -
                            (fakePred * (1 - oneHotReal)).max(1)[0].mean())

        loss = pair_loss + label_loss

        loss.backward()

        losses.update(loss.item())
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print(">> Train: [{0}][{1}/{2}]\t"
                  "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                  "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                  "PairLoss {loss:.4f}\t"
                  "LabelLoss {lossLab:.4f}\t"
                  "Noise l2: {noise:.4f}".format(
                      epoch + 1,
                      i,
                      len(train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=pair_loss.item(),
                      lossLab=label_loss.item(),
                      noise=noise.norm().item(),
                  ))

    noise.requires_grad = False
    print(f"Train {epoch}: Loss: {losses.avg}")
    return losses.avg, noise
Example #23
    def train(self, epoch, mt_train_loader, mt_test_loader, optimizer, noise_model, args):
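        """One meta-learning step per batch on clean + adversarial images.

        Meta-train on ``mt_train_loader``; an inner-updated clone of the model
        is then evaluated on a meta-test batch, and both losses are summed.
        """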
        self.model.train()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        precisions = AverageMeter()
        end = time.time()
        for i, inputs in enumerate(mt_train_loader):
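            # mt_test_loader is assumed to be a cycling iterator wrapper
            # exposing .next() (not a bare DataLoader).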
            meta_input = mt_test_loader.next()
            img, _, pid, _ = inputs
            metaTest, _, meta_pid, _ = meta_input

            adv_inputs_total, adv_labels, adv_labels_total, coupled_inputs = [], [], [], []
            adv_inputs_total_meta, adv_labels_meta, adv_labels_total_meta, coupled_inputs_meta = [], [], [], []

            ### generate perturbed images during meta-train ###
            adv_data = create_attack_exp(inputs, noise_model)
            adv_inputs, adv_labels, adv_idxs, og_adv_inputs = adv_data
            adv_inputs_total.append(adv_inputs)
            adv_labels_total.append(adv_labels)
            coupled_inputs.append(og_adv_inputs)

            inputs = torch.cat([img.cuda()] + [_.data for _ in adv_inputs_total], dim=0)
            labels = torch.cat([pid.cuda()] + [_.data for _ in adv_labels_total], dim=0)

            inputs, pid = Variable(inputs), Variable(labels)
            final_input = inputs.cuda()
            targets = pid.cuda()

            ### generate perturbed images during meta-test ###
            adv_data = create_attack_exp(meta_input, noise_model)
            adv_inputs, adv_labels, adv_idxs, og_adv_inputs = adv_data
            adv_inputs_total_meta.append(adv_inputs)
            adv_labels_total_meta.append(adv_labels)
            coupled_inputs_meta.append(og_adv_inputs)

            meta_input = torch.cat([metaTest.cuda()] + [_.data for _ in adv_inputs_total_meta], dim=0)
            meta_pid = torch.cat([meta_pid.cuda()] + [_.data for _ in adv_labels_total_meta], dim=0)

            meta_input = meta_input.cuda()
            meta_pid = meta_pid.cuda()

            data_time.update(time.time() - end)

            ### meta-train ###
            cur_model = self.model
            output = cur_model(final_input)
            loss, prec1 = self._memory(output, targets, epoch)
            self.model.zero_grad()

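            # create_graph=True keeps the inner-gradient graph so the meta-test
            # loss can backpropagate through the simulated update.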
            grads = torch.autograd.grad(loss, (self.model.module.params()), create_graph=True)
            lr = optimizer.param_groups[0]["lr"]
            lr_base = optimizer.param_groups[1]["lr"]

            ### meta-test ###
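            # Clone the current model and apply one inner update from ``grads``;
            # ``class_meta`` is assumed to be a module-level constant.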
            newMeta = models.create('resMeta', num_classes=class_meta)
            newMeta.copyModel(self.model.module)
            newMeta.update_params(lr_inner=lr, lr_base=lr_base, source_params=grads, solver='adam')
            del grads
            newMeta = nn.DataParallel(newMeta).to(self.device)

            meta_out = newMeta(meta_input)
            metaloss, prec2 = self._memory(meta_out, meta_pid, epoch)

            loss_final = metaloss + loss

            optimizer.zero_grad()
            loss_final.backward()
            optimizer.step()

            losses.update(loss.item(), targets.size(0))
            precisions.update(prec1, targets.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % self.print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'Loss {:.3f} ({:.3f})\t'
                      'Prec {:.2%} ({:.2%})\t'
                      .format(epoch, i + 1, len(mt_train_loader),
                              batch_time.val, batch_time.avg,
                              data_time.val, data_time.avg,
                              losses.val, losses.avg,
                              precisions.val, precisions.avg))
Example #24
def extract_features(model,
                     data_loader,
                     is_flip=False,
                     print_freq=1,
                     metric=None):
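    """Extract per-image CNN features; with ``is_flip``, average each image's
    features with those of its horizontally flipped copy (flip TTA)."""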
    model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    features = OrderedDict()

    end = time.time()
    if is_flip:
        print('flip')
        for i, (imgs, flip_imgs, fnames) in enumerate(data_loader):
            data_time.update(time.time() - end)

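            # flip test-time augmentation: average original and flipped features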
            outputs = extract_cnn_feature(model, imgs)
            flip_outputs = extract_cnn_feature(model, flip_imgs)
            final_outputs = (outputs + flip_outputs) / 2
            for fname, output in zip(fnames, final_outputs):
                features[fname] = output.numpy()

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:
                print('Extract Features: [{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                                      batch_time.val,
                                                      batch_time.avg,
                                                      data_time.val,
                                                      data_time.avg))
    else:
        print('no flip')
        for i, (imgs, fnames) in enumerate(data_loader):
            data_time.update(time.time() - end)

            outputs = extract_cnn_feature(model, imgs)
            for fname, output in zip(fnames, outputs):
                features[fname] = output.numpy()

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:
                print('Extract Features: [{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                                      batch_time.val,
                                                      batch_time.avg,
                                                      data_time.val,
                                                      data_time.avg))
    return features
Example #25
    def compute_distmat(self, queryloader, galleryloader):
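        """Score all query-gallery pairs with the three-branch classifier,
        processing the gallery in chunks of ``step_size`` to bound GPU use."""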
        self.cnnmodel.eval()
        self.classifier.eval()

        queryfeat1, queryfeat2, queryfeat3, q_ids, q_cams = self.extract_video_features(
            queryloader)
        galleryfeat1, galleryfeat2, galleryfeat3, g_ids, g_cams = self.extract_video_features(
            galleryloader)

        batch_time = AverageMeter()
        data_time = AverageMeter()
        end = time.time()
        print_freq = 50
        distmat = 0
        step_size = 32

        N = math.ceil(galleryfeat1.size(0) / step_size)
        for i in range(N):
            idx_bg = i * step_size
            idx_end = min((i + 1) * step_size, galleryfeat1.size(0))
            gallery_feat1 = galleryfeat1[idx_bg:idx_end]
            gallery_feat2 = galleryfeat2[idx_bg:idx_end]
            gallery_feat3 = galleryfeat3[idx_bg:idx_end]

            queryfeat1 = queryfeat1.cuda()
            gallery_feat1 = gallery_feat1.cuda()
            queryfeat2 = queryfeat2.cuda()
            gallery_feat2 = gallery_feat2.cuda()
            queryfeat3 = queryfeat3.cuda()
            gallery_feat3 = gallery_feat3.cuda()
            batch_cls_encode1, batch_cls_encode2, batch_cls_encode3 = self.classifier(
                queryfeat1, gallery_feat1, queryfeat2, gallery_feat2,
                queryfeat3, gallery_feat3)
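            # each encode is (num_query, chunk, 2) pair logits: softmax over the
            # last dim and keep channel 0 as the pair score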
            batch_cls_size1 = batch_cls_encode1.size()
            batch_cls_encode1 = batch_cls_encode1.view(-1, 2)
            batch_cls_encode1 = F.softmax(batch_cls_encode1, 1)
            batch_cls_encode1 = batch_cls_encode1.view(batch_cls_size1[0],
                                                       batch_cls_size1[1], 2)
            batch_cls_encode1 = batch_cls_encode1[:, :, 0]

            batch_cls_size2 = batch_cls_encode2.size()
            batch_cls_encode2 = batch_cls_encode2.view(-1, 2)
            batch_cls_encode2 = F.softmax(batch_cls_encode2, 1)
            batch_cls_encode2 = batch_cls_encode2.view(batch_cls_size2[0],
                                                       batch_cls_size2[1], 2)
            batch_cls_encode2 = batch_cls_encode2[:, :, 0]

            batch_cls_size3 = batch_cls_encode3.size()
            batch_cls_encode3 = batch_cls_encode3.view(-1, 2)
            batch_cls_encode3 = F.softmax(batch_cls_encode3, 1)
            batch_cls_encode3 = batch_cls_encode3.view(batch_cls_size3[0],
                                                       batch_cls_size3[1], 2)
            batch_cls_encode3 = batch_cls_encode3[:, :, 0]

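            # fuse the three branch scores with the weights in self.alphas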
            batch_cls_encode = batch_cls_encode1 * self.alphas[
                0] + batch_cls_encode2 * self.alphas[
                    1] + batch_cls_encode3 * self.alphas[2]
            if i == 0:
                distmat = batch_cls_encode.data
            else:
                distmat = torch.cat((distmat, batch_cls_encode.data), 1)

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:
                print('Compute distmat: [{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'.format(
                          i + 1, N, batch_time.val,
                          batch_time.avg, data_time.val, data_time.avg))

        return distmat, q_ids, q_cams, g_ids, g_cams
Example #26
    def train(self, epoch, data_loader, optimizer):
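        """Plain one-epoch training loop: parse batch, forward, backward, step."""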
        self.model.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        precisions = AverageMeter()

        end = time.time()
        for i, inputs in enumerate(data_loader):
            data_time.update(time.time() - end)

            inputs, targets = self._parse_data(inputs)
            loss, prec1 = self._forward(inputs, targets, epoch)
            losses.update(loss.item(), targets.size(0))
            precisions.update(prec1, targets.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % self.print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'Loss {:.3f} ({:.3f})\t'
                      'Prec {:.2%} ({:.2%})\t'.format(
                          epoch, i + 1, len(data_loader), batch_time.val,
                          batch_time.avg, data_time.val, data_time.avg,
                          losses.val, losses.avg, precisions.val,
                          precisions.avg))