Example #1
def validationphase(f_net, g_net, batch_size, path='/home/***/data1t_ssd/librispeech_3/val'):
    f_net.eval()
    g_net.eval()
    dataLoader = getDataloader(path, length=2, batch_size=batch_size, dataset_length=1000)
    all_cnt = 0
    hit_cnt = 0
    # All ten 2-element subsets of the 5 reference classes: (0,1), (0,2), ..., (3,4).
    combinations2 = torch.tensor(list(itertools.combinations(list(range(5)), 2)))
    # Each row pairs an index into merged2 with a third class index, so every 3-element
    # subset is built from an already-merged pair, e.g. [4, 3] = pair (1, 2) + class 3.
    combinations3 = torch.tensor([[0, 2], [0, 3], [0, 4], [1, 3], [1, 4],
                                  [2, 4], [4, 3], [4, 4], [5, 4], [7, 4]])
    for i, data in enumerate(dataLoader):
        test_imgs = data
        test_imgs = test_imgs.float().to(device)
        real_batch_size = test_imgs.size(0)
        # Each episode holds 30 segments (25 queries + 5 references), each a 199x32 feature map.
        test_imgs = test_imgs.contiguous().view(real_batch_size * 30, 199, 32)

        embeddings = f_net(test_imgs).contiguous().view(real_batch_size, 30, -1)
        for embedding in embeddings:
            centroids = embedding[-5:]
            comb2_a = centroids[combinations2.transpose(-2, -1)[0]]
            comb2_b = centroids[combinations2.transpose(-2, -1)[1]]
            merged2 = g_net(comb2_a, comb2_b)

            comb3_a = merged2[combinations3.transpose(-2, -1)[0]]
            comb3_b = centroids[combinations3.transpose(-2, -1)[1]]
            merged3 = g_net(comb3_a, comb3_b)

            # 25 candidate set embeddings: 5 singletons, 10 merged pairs, 10 merged triples.
            truth = torch.cat([centroids, merged2, merged3])
            preds = embedding[:-5]
            # dists = pairwiseDists(preds, truth)
            dists = torch.matmul(preds, truth.transpose(0, 1))
            _, res = torch.topk(dists, 1)  # (use largest=False if switching to a distance metric)
            label = torch.tensor(list(range(25)), device=device)
            all_cnt += 25
            hit_cnt += torch.sum(label == res.squeeze()).item()
    return hit_cnt/all_cnt
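The hard-coded `combinations3` table used throughout these examples pairs an index into `merged2` (an already-merged 2-element subset) with the index of a third reference class, so that `g_net` can build every 3-element subset of the five classes. A minimal standalone check of that reading (illustration only, not part of the original code):

import itertools

import torch

# Rows of combinations2: all 2-element subsets of the 5 reference classes.
pairs = list(itertools.combinations(range(5), 2))
combinations3 = torch.tensor([[0, 2], [0, 3], [0, 4], [1, 3], [1, 4],
                              [2, 4], [4, 3], [4, 4], [5, 4], [7, 4]])

# Expand each (pair index, third class) row into the 3-element subset it encodes,
# e.g. [4, 3] -> pair (1, 2) plus class 3 -> {1, 2, 3}.
triples = {tuple(sorted(pairs[p] + (c,))) for p, c in combinations3.tolist()}
assert triples == set(itertools.combinations(range(5), 3))  # all 10 triples covered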
Example #2
def validationphase(f_net,
                    g_net,
                    path='/home/***/data1t_ssd/omniglot/validation'):
    f_net.eval()
    g_net.eval()
    dataLoader = getDataloader(path)  #('data/validation', num_samples=200)
    all_cnt = 0
    hit_cnt = 0
    for i, data in enumerate(dataLoader):
        test_imgs, ref_imgs = data
        combinations2 = torch.tensor(
            list(itertools.combinations(list(range(5)), 2)))
        combinations3 = torch.tensor([[0, 2], [0, 3], [0, 4], [1, 3], [1, 4],
                                      [2, 4], [4, 3], [4, 4], [5, 4], [7, 4]])
        # comb_imgs = torch.ones((len(combinations), 64, 64))*255
        # for comb_cnt, combination in enumerate(combinations):
        #     for idx in combination:
        #         comb_imgs[comb_cnt] = torch.min(comb_imgs[comb_cnt], test_imgs[0][idx].float())

        # ref_comb_imgs = torch.ones((len(combinations), 64, 64))*255
        # for comb_cnt, combination in enumerate(combinations):
        #     for idx in combination:
        #         ref_comb_imgs[comb_cnt] = torch.min(ref_comb_imgs[comb_cnt], ref_imgs[0][idx].float())
        # ref_comb_imgs = ref_comb_imgs.float().to(device)/256

        test_imgs = test_imgs[0].float() / 256
        ref_imgs = ref_imgs[0].float() / 256
        test_imgs, ref_imgs = test_imgs.to(device), ref_imgs.to(device)
        # imgs = torch.cat([test_imgs, comb_imgs, ref_imgs, ref_comb_imgs], 0)
        imgs = torch.cat([test_imgs, ref_imgs], 0)
        embeddings = f_net(imgs)

        centroids = embeddings[-5:]
        comb2_a = centroids[combinations2.transpose(-2, -1)[0]]
        comb2_b = centroids[combinations2.transpose(-2, -1)[1]]
        merged2 = g_net(comb2_a, comb2_b)

        comb3_a = merged2[combinations3.transpose(-2, -1)[0]]
        comb3_b = centroids[combinations3.transpose(-2, -1)[1]]
        merged3 = g_net(comb3_a, comb3_b)

        truth = torch.cat([centroids, merged2, merged3])
        # truth = embeddings[-15:]
        preds = embeddings[:-5]
        # Inner product; this is a true cosine similarity only if f_net outputs are L2-normalized.
        cosinesim = torch.matmul(preds, truth.transpose(0, 1))
        _, res = torch.topk(cosinesim, 1)
        label = torch.tensor(list(range(25)), device=device)
        all_cnt += 25
        hit_cnt += torch.sum(label == res.squeeze()).item()
    return hit_cnt / all_cnt
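The variable name `cosinesim` above suggests cosine similarity, but a plain `torch.matmul` on raw embeddings only equals cosine similarity if `f_net` L2-normalizes its outputs. A hedged sketch of an explicitly normalized scorer (the embedding width 64 is an assumption, not taken from the snippets):

import torch
import torch.nn.functional as F

def cosine_scores(preds: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
    """Return a (num_queries, num_candidates) cosine-similarity matrix."""
    preds = F.normalize(preds, dim=-1)   # make each row unit-norm
    truth = F.normalize(truth, dim=-1)
    return preds @ truth.transpose(0, 1)

# Random stand-ins for the 25 query and 25 candidate embeddings:
scores = cosine_scores(torch.randn(25, 64), torch.randn(25, 64))
top1 = scores.topk(1, dim=-1).indices.squeeze(-1)    # plays the role of `res`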
Example #3
def trainphase(f_net, g_net, optimizer, batch_size, path='/home/***/data1t_ssd/librispeech_3/train'):
    f_net.train()
    g_net.train()
    dataloader = getDataloader(path, length=2, batch_size=batch_size, dataset_length=100000)
    t = tqdm(iter(dataloader), leave=False, total=len(dataloader))
    criterion = nn.MarginRankingLoss(0.1)
    combinations2 = torch.tensor(list(itertools.combinations(list(range(5)),2)))
    combinations3 = torch.tensor([[0, 2], [0, 3], [0, 4], [1, 3], [1, 4], [2, 4], [4, 3], [4, 4], [5, 4], [7, 4]])
    for i, data in enumerate(t):
        optimizer.zero_grad()
        test_imgs = data
        test_imgs = test_imgs.float().to(device)
        real_batch_size = test_imgs.size(0)

        test_imgs = test_imgs.contiguous().view(real_batch_size * 30, 199, 32)

        embeddings = f_net(test_imgs).contiguous().view(real_batch_size, 30, -1)

        loss = 0

        for embedding in embeddings:
            centroids = embedding[-5:]
            comb2_a = centroids[combinations2.transpose(-2, -1)[0]]
            comb2_b = centroids[combinations2.transpose(-2, -1)[1]]
            merged2 = g_net(comb2_a, comb2_b)

            comb3_a = merged2[combinations3.transpose(-2, -1)[0]]
            comb3_b = centroids[combinations3.transpose(-2, -1)[1]]
            merged3 = g_net(comb3_a, comb3_b)

            truth = torch.cat([centroids, merged2, merged3])
            preds = embedding[:-5]
            dists = pairwiseDists(preds, truth)
            losses = torch.zeros(25, device=device)
            for j in range(25):
                # Singleton targets get full weight; 2-set and 3-set targets are both down-weighted to 0.5.
                weight = 1 if j < 5 else 0.5
                dist = dists[j]
                losses[j] = weight * criterion(
                    dist[[e for e in range(25) if e != j]], dist[[j] * 24],
                    torch.ones(24, device=device))
            loss += torch.mean(losses)

        loss.backward()
        optimizer.step()
    return loss.item()
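`pairwiseDists` is called in the training phases but never shown. Given how it feeds `nn.MarginRankingLoss` above (distances to the 24 wrong candidates are pushed above the distance to the matching candidate), a plausible stand-in is a plain Euclidean distance matrix; this is an assumption, not the original helper:

import torch

def pairwiseDists(preds: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
    """Distance between every row of `preds` (queries) and every row of `truth` (candidates)."""
    return torch.cdist(preds, truth, p=2)    # shape: (len(preds), len(truth))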
Example #4
def testphase(f_net, g_net, batch_size, path='/home/***/data1t_ssd/librispeech_3/test'):
    f_net.eval()
    g_net.eval()
    dataloader = getDataloader(path, length=2, batch_size=batch_size, dataset_length=1000)
    all_cnt0 = 0
    top3_hit_cnt0 = 0
    top1_hit_cnt0 = 0
    all_cnt1 = 0
    top3_hit_cnt1 = 0
    top1_hit_cnt1 = 0
    all_cnt2 = 0
    top3_hit_cnt2 = 0
    top1_hit_cnt2 = 0
    # setsize = torch.tensor([1 for i in range(5)] + [2 for i in range(5, 15)] + [3 for i in range(15, 25)], device=device)
    combinations2 = torch.tensor(list(itertools.combinations(list(range(5)),2)))
    combinations3 = torch.tensor([[0, 2], [0, 3], [0, 4], [1, 3], [1, 4], [2, 4], [4, 3], [4, 4], [5, 4], [7, 4]])
    t = tqdm(iter(dataloader), leave=False, total=len(dataloader))
    for i, data in enumerate(t):
        test_imgs = data
        test_imgs = test_imgs.float().to(device)
        real_batch_size = test_imgs.size(0)

        # Note: this reshape assumes batch_size == 1 (a single 30-segment episode per batch).
        test_imgs = test_imgs.contiguous().view(30, 199, 32)

        embeddings = f_net(test_imgs).contiguous().view(30, -1)

        centroids = embeddings[-5:]
        comb2_a = centroids[combinations2.transpose(-2, -1)[0]]
        comb2_b = centroids[combinations2.transpose(-2, -1)[1]]
        merged2 = g_net(comb2_a, comb2_b)

        comb3_a = merged2[combinations3.transpose(-2, -1)[0]]
        comb3_b = centroids[combinations3.transpose(-2, -1)[1]]
        merged3 = g_net(comb3_a, comb3_b)

        truth = torch.cat([centroids, merged2, merged3])
        preds = embeddings[:-5]
        dist = torch.matmul(preds, truth.transpose(0, 1))

        _, res = torch.topk(dist, 1)
        # res = setsize[res].squeeze()
        _, res_top3 = torch.topk(dist, 3)
        # res_top3 = setsize[res_top3]
        # res_top3 = res_top3.data.cpu().numpy()
        # Top-3 hits, split by target set size: queries 0-4 match singletons,
        # 5-14 match 2-sets, 15-24 match 3-sets.
        for j in range(5):
            if j in res_top3[j]:
                top3_hit_cnt0 += 1
        for j in range(5, 15):
            if j in res_top3[j]:
                top3_hit_cnt1 += 1
        for j in range(15, 25):
            if j in res_top3[j]:
                top3_hit_cnt2 += 1

        label = torch.tensor(list(range(25)), device=device)
        # label = setsize
        all_cnt0 += 5
        top1_hit_cnt0 += torch.sum(label[:5] == res.squeeze()[:5]).item()
        all_cnt1 += 10
        top1_hit_cnt1 += torch.sum(label[5:15] == res.squeeze()[5:15]).item()
        all_cnt2 += 10
        top1_hit_cnt2 += torch.sum(label[15:] == res.squeeze()[15:]).item()
    return (top1_hit_cnt0 / all_cnt0, top3_hit_cnt0 / all_cnt0,
            top1_hit_cnt1 / all_cnt1, top3_hit_cnt1 / all_cnt1,
            top1_hit_cnt2 / all_cnt2, top3_hit_cnt2 / all_cnt2,
            (top1_hit_cnt0 + top1_hit_cnt1 + top1_hit_cnt2) / (all_cnt0 + all_cnt1 + all_cnt2),
            (top3_hit_cnt0 + top3_hit_cnt1 + top3_hit_cnt2) / (all_cnt0 + all_cnt1 + all_cnt2))
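For readers of the unnamed 8-tuple returned by `testphase`, a hypothetical driver snippet that unpacks it in order (singletons, 2-sets, 3-sets, overall); `f_net` and `g_net` are assumed to exist as above, and `batch_size=1` matches the reshape to 30 segments inside the function:

(top1_single, top3_single,
 top1_pair, top3_pair,
 top1_triple, top3_triple,
 top1_all, top3_all) = testphase(f_net, g_net, batch_size=1)

print(f"singletons  top1/top3: {top1_single:.3f}/{top3_single:.3f}")
print(f"2-sets      top1/top3: {top1_pair:.3f}/{top3_pair:.3f}")
print(f"3-sets      top1/top3: {top1_triple:.3f}/{top3_triple:.3f}")
print(f"overall     top1/top3: {top1_all:.3f}/{top3_all:.3f}")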
Example #5
def trainphase(f_net,
               g_net,
               optimizer,
               path='/home/***/data1t_ssd/omniglot/train'):
    f_net.train()
    g_net.train()
    dataloader = getDataloader(path)  #('data/training_aug', num_samples=10000)
    combinations2 = torch.tensor(
        list(itertools.combinations(list(range(5)), 2)))
    combinations3 = torch.tensor([[0, 2], [0, 3], [0, 4], [1, 3], [1, 4],
                                  [2, 4], [4, 3], [4, 4], [5, 4], [7, 4]])
    t = tqdm(iter(dataloader), leave=False, total=len(dataloader))
    criterion = nn.MarginRankingLoss(0.1)
    for i, data in enumerate(t):
        test_imgs, ref_imgs = data
        # comb_imgs = np.ones((len(combinations), 64, 64))*255
        # for comb_cnt, combination in enumerate(combinations):
        #     for idx in combination:
        #         comb_imgs[comb_cnt] = np.minimum(comb_imgs[comb_cnt], test_imgs.numpy()[0][idx])
        # comb_imgs = torch.from_numpy(comb_imgs)

        # ref_comb_imgs = np.ones((len(combinations), 64, 64))*255
        # for comb_cnt, combination in enumerate(combinations):
        #     for idx in combination:
        #         ref_comb_imgs[comb_cnt] = np.minimum(ref_comb_imgs[comb_cnt], ref_imgs.numpy()[0][idx])
        # ref_comb_imgs = torch.from_numpy(ref_comb_imgs)
        # ref_comb_imgs = ref_comb_imgs.float().to(device)/256

        test_imgs = test_imgs[0].float() / 256
        ref_imgs = ref_imgs[0].float() / 256
        test_imgs, ref_imgs = test_imgs.to(device), ref_imgs.to(device)
        # imgs = torch.cat([test_imgs, comb_imgs, ref_imgs, ref_comb_imgs], 0)
        imgs = torch.cat([test_imgs, ref_imgs], 0)

        embeddings = f_net(imgs)
        optimizer.zero_grad()

        centroids = embeddings[-5:]
        comb2_a = centroids[combinations2.transpose(-2, -1)[0]]
        comb2_b = centroids[combinations2.transpose(-2, -1)[1]]
        merged2 = g_net(comb2_a, comb2_b)

        comb3_a = merged2[combinations3.transpose(-2, -1)[0]]
        comb3_b = centroids[combinations3.transpose(-2, -1)[1]]
        merged3 = g_net(comb3_a, comb3_b)

        truth = torch.cat([centroids, merged2, merged3])
        # truth = embeddings[-15:]
        preds = embeddings[:-5]
        dists = pairwiseDists(preds, truth)
        losses = torch.zeros(25, device=device)
        for j in range(25):
            # Singleton targets get full weight; 2-set and 3-set targets are both down-weighted to 0.5.
            weight = 1 if j < 5 else 0.5
            dist = dists[j]
            losses[j] = weight * criterion(
                dist[[e for e in range(25) if e != j]], dist[[j] * 24],
                torch.ones(24, device=device))
        loss = torch.mean(losses)
        # cosinesim = torch.matmul(preds, truth.transpose(0, 1))
        # loss = criterion(cosinesim, torch.tensor(list(range(15)), device=device))
        loss.backward()
        optimizer.step()
    return loss.item()
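All of the examples treat `g_net(a, b)` as a learned merge that maps two set embeddings to a single embedding of the same size; its architecture is not shown. A minimal stand-in under that assumption (the embedding width `dim=64` is purely illustrative):

import torch
import torch.nn as nn

class MergeNet(nn.Module):
    """Toy g_net: concatenate two embeddings and project back to the embedding space."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return self.mlp(torch.cat([a, b], dim=-1))

g_net = MergeNet(dim=64)
merged = g_net(torch.randn(10, 64), torch.randn(10, 64))   # -> shape (10, 64)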
Example #6
def testphase(f_net, g_net, path='data/test_aug_3'):
    f_net.eval()
    g_net.eval()
    dataloader = getDataloader(path)
    all_cnt0 = 0
    top3_hit_cnt0 = 0
    top1_hit_cnt0 = 0
    all_cnt1 = 0
    top3_hit_cnt1 = 0
    top1_hit_cnt1 = 0
    all_cnt2 = 0
    top3_hit_cnt2 = 0
    top1_hit_cnt2 = 0
    # setsize = torch.tensor([1 for i in range(5)] + [2 for i in range(5, 15)] + [3 for i in range(15, 25)], device=device)
    combinations2 = torch.tensor(
        list(itertools.combinations(list(range(5)), 2)))
    combinations3 = torch.tensor([[0, 2], [0, 3], [0, 4], [1, 3], [1, 4],
                                  [2, 4], [4, 3], [4, 4], [5, 4], [7, 4]])
    for i, data in enumerate(dataloader):
        test_imgs, ref_imgs = data

        test_imgs = test_imgs[0].float() / 256
        ref_imgs = ref_imgs[0].float() / 256
        test_imgs, ref_imgs = test_imgs.to(device), ref_imgs.to(device)
        imgs = torch.cat([test_imgs, ref_imgs], 0)
        embeddings = f_net(imgs)

        centroids = embeddings[-5:]
        comb2_a = centroids[combinations2.transpose(-2, -1)[0]]
        comb2_b = centroids[combinations2.transpose(-2, -1)[1]]
        merged2 = g_net(comb2_a, comb2_b)

        comb3_a = merged2[combinations3.transpose(-2, -1)[0]]
        comb3_b = centroids[combinations3.transpose(-2, -1)[1]]
        merged3 = g_net(comb3_a, comb3_b)

        truth = torch.cat([centroids, merged2, merged3])
        preds = embeddings[:-5]
        dist = torch.matmul(preds, truth.transpose(0, 1))

        _, res = torch.topk(dist, 1)
        # res = setsize[res].squeeze()
        _, res_top3 = torch.topk(dist, 3)
        # res_top3 = setsize[res_top3]
        # res_top3 = res_top3.data.cpu().numpy()
        # Top-3 hits, split by target set size (singletons / 2-sets / 3-sets).
        for j in range(5):
            if j in res_top3[j]:
                top3_hit_cnt0 += 1
        for j in range(5, 15):
            if j in res_top3[j]:
                top3_hit_cnt1 += 1
        for j in range(15, 25):
            if j in res_top3[j]:
                top3_hit_cnt2 += 1

        label = torch.tensor(list(range(25)), device=device)
        # label = setsize
        all_cnt0 += 5
        top1_hit_cnt0 += torch.sum(label[:5] == res.squeeze()[:5]).item()
        all_cnt1 += 10
        top1_hit_cnt1 += torch.sum(label[5:15] == res.squeeze()[5:15]).item()
        all_cnt2 += 10
        top1_hit_cnt2 += torch.sum(label[15:] == res.squeeze()[15:]).item()
    return (top1_hit_cnt0 / all_cnt0, top3_hit_cnt0 / all_cnt0,
            top1_hit_cnt1 / all_cnt1, top3_hit_cnt1 / all_cnt1,
            top1_hit_cnt2 / all_cnt2, top3_hit_cnt2 / all_cnt2,
            (top1_hit_cnt0 + top1_hit_cnt1 + top1_hit_cnt2) /
            (all_cnt0 + all_cnt1 + all_cnt2),
            (top3_hit_cnt0 + top3_hit_cnt1 + top3_hit_cnt2) /
            (all_cnt0 + all_cnt1 + all_cnt2))
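A hypothetical top-level loop tying the Omniglot phases together (not part of the original snippets; `f_net` and `g_net` are assumed to be constructed elsewhere, and the epoch count, learning rate, and checkpoint path are assumptions):

import torch

optimizer = torch.optim.Adam(
    list(f_net.parameters()) + list(g_net.parameters()), lr=1e-3)

best_acc = 0.0
for epoch in range(50):
    train_loss = trainphase(f_net, g_net, optimizer)
    val_acc = validationphase(f_net, g_net)
    print(f"epoch {epoch}: train_loss={train_loss:.4f} val_acc={val_acc:.4f}")
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({'f_net': f_net.state_dict(), 'g_net': g_net.state_dict()},
                   'best_checkpoint.pt')

print(testphase(f_net, g_net))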