Example #1
    def build_loss_matrix(self, embs: Tensor, ys: Tensor):
        """Triplet loss for continuous targets: pairs whose target distance
        is at most ``self.threshold`` act as positives, the rest as
        negatives."""
        lpembdist = distances.LpDistance(normalize_embeddings=False,
                                         p=2,
                                         power=1)
        emb_distance_matrix = lpembdist(embs)

        lpydist = distances.LpDistance(normalize_embeddings=False,
                                       p=1,
                                       power=1)
        y_distance_matrix = lpydist(ys)

        positive_embs = emb_distance_matrix.where(
            y_distance_matrix <= self.threshold,
            torch.tensor(0.).to(embs))
        negative_embs = emb_distance_matrix.where(
            y_distance_matrix > self.threshold,
            torch.tensor(0.).to(embs))

        # graph-connected zero so the loss is differentiable even when no
        # triplet contributes
        loss_loop = 0 * torch.tensor([0.], requires_grad=True).to(embs)
        n_positive_triplets = 0
        for i in range(embs.size(0)):
            pos_i = positive_embs[i][positive_embs[i] > 0]
            neg_i = negative_embs[i][negative_embs[i] > 0]
            # each row is (d(a, p), -d(a, n)), so summing gives d_ap - d_an
            pairs = torch.cartesian_prod(pos_i, -neg_i)
            if self.soft:
                triplet_losses_for_anchor_i = torch.nn.functional.softplus(
                    pairs.sum(dim=-1))
                if self.eta is not None:
                    # get the corresponding delta ys
                    pos_y_i = y_distance_matrix[i][positive_embs[i] > 0]
                    neg_y_i = y_distance_matrix[i][negative_embs[i] > 0]
                    pairs_y = torch.cartesian_prod(pos_y_i, neg_y_i)
                    assert pairs.shape == pairs_y.shape, (pairs_y.shape,
                                                          pairs.shape)
                    weight_pos = self.smooth_indicator(
                        self.threshold - pairs_y[:, 0]).div(
                            self.smooth_indicator(self.threshold))
                    weight_neg = self.smooth_indicator(
                        pairs_y[:, 1] - self.threshold).div(
                            self.smooth_indicator(1 - self.threshold))
                    triplet_losses_for_anchor_i = (
                        triplet_losses_for_anchor_i * weight_pos * weight_neg)
            else:
                triplet_losses_for_anchor_i = torch.relu(self.margin +
                                                         pairs.sum(dim=-1))
            n_positive_triplets += (triplet_losses_for_anchor_i > 0).sum()
            loss_loop += triplet_losses_for_anchor_i.sum()
        loss_loop = loss_loop.div(max(1, n_positive_triplets))

        return loss_loop
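
# Usage sketch (an assumption, not part of the original): the host class
# `ThresholdTripletLoss`, its constructor defaults, and the dummy tensors are
# hypothetical; they only show how build_loss_matrix above could be exercised,
# assuming the method is available at module level.
import torch
from pytorch_metric_learning import distances  # referenced inside the method

class ThresholdTripletLoss(torch.nn.Module):
    def __init__(self, threshold=0.5, margin=0.1, soft=False, eta=None):
        super().__init__()
        self.threshold, self.margin = threshold, margin
        self.soft, self.eta = soft, eta

ThresholdTripletLoss.build_loss_matrix = build_loss_matrix  # attach the method

embs = torch.randn(8, 16, requires_grad=True)  # 8 embeddings of dimension 16
ys = torch.rand(8, 1)                          # scalar regression targets
loss = ThresholdTripletLoss().build_loss_matrix(embs, ys)
loss.backward()                                # single-element tensor
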
Example #2
    def build_loss_matrix(self, embs: Tensor, ys: Tensor):
        """Log-ratio loss: penalizes the squared difference between
        log-distance gaps in embedding space and in target space."""
        sqrt_eps = 1e-4 / embs.size(0)  # stabilizes the sqrt below

        lpembdist = distances.LpDistance(normalize_embeddings=False,
                                         p=2,
                                         power=2)
        emb_distance_matrix = torch.sqrt(lpembdist(embs) + sqrt_eps)  # L2 dist

        lpydist = distances.LpDistance(normalize_embeddings=False,
                                       p=1,
                                       power=1)
        y_distance_matrix = lpydist(ys)

        eps = 1e-6  # stabilizes the logs below

        # graph-connected zero so the loss is differentiable even when no
        # valid triplet contributes
        loss_loop = 0 * torch.tensor([0.], requires_grad=True).to(embs)
        n_positive_triplets = 0
        m = embs.size(0) - 1  # number of points other than the anchor

        for ind_a in range(embs.size(0)):
            # indices of every point except the anchor ind_a
            idxs = torch.arange(0, m).to(device=embs.device)
            idxs[ind_a:] += 1

            log_dist = torch.log(emb_distance_matrix[ind_a][idxs] + eps)
            log_y_dist = torch.log(y_distance_matrix[ind_a][idxs] + eps)

            diff_log_dist = log_dist.repeat(m, 1).t() - log_dist.repeat(m, 1)
            diff_log_y_dist = log_y_dist.repeat(m, 1).t() - log_y_dist.repeat(
                m, 1)
            assert diff_log_y_dist.shape == diff_log_dist.shape == (m, m), (
                diff_log_y_dist.shape, diff_log_dist.shape, m)
            valid_aij = diff_log_y_dist < 0  # keep triplets with D(y_a, y_i) < D(y_a, y_j)

            log_ratio_loss = (diff_log_dist -
                              diff_log_y_dist).pow(2)[valid_aij].sum()

            loss_loop += log_ratio_loss
            n_positive_triplets += valid_aij.sum()

        loss_loop = loss_loop.div(max(1, n_positive_triplets))

        return loss_loop
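
# Usage sketch (hypothetical): Example #2 never reads attributes from `self`,
# so, assuming the function is available at module level, it can be called
# directly with a dummy receiver.
import torch
from pytorch_metric_learning import distances  # referenced inside the function

embs = torch.randn(8, 16, requires_grad=True)
ys = torch.rand(8, 1)
loss = build_loss_matrix(None, embs, ys)  # `self` is unused
loss.backward()
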
Example #3
    def build_loss_matrix(self, embs: Tensor, ys: Tensor):
        """Return an element-wise loss matrix over all pairs: pairs with
        target distance above ``self.threshold`` are pushed apart, the rest
        pulled together; ``self.hard`` toggles the unsmoothed variant."""
        lpembdist = distances.LpDistance(normalize_embeddings=False,
                                         p=2,
                                         power=1)
        emb_distance_matrix = lpembdist(embs)

        lpydist = distances.LpDistance(normalize_embeddings=False,
                                       p=1,
                                       power=1)
        y_distance_matrix = lpydist(ys)

        loss = torch.zeros_like(emb_distance_matrix)

        threshold_matrix = torch.full_like(loss, self.threshold)

        high_dy_filter = y_distance_matrix > self.threshold
        aux_max_dz_thr = torch.maximum(emb_distance_matrix, threshold_matrix)
        aux_min_dz_thr = torch.minimum(emb_distance_matrix, threshold_matrix)

        if self.hard:
            # dy - dz
            loss[high_dy_filter] = y_distance_matrix[
                high_dy_filter] - emb_distance_matrix[high_dy_filter]
            # dz
            loss[~high_dy_filter] = emb_distance_matrix[~high_dy_filter]
        else:
            # (2 - min(dz, threshold)) / threshold * (dy - max(dz, threshold))
            loss[high_dy_filter] = (2 - aux_min_dz_thr[high_dy_filter]).div(
                self.threshold) * (y_distance_matrix[high_dy_filter] -
                                   aux_max_dz_thr[high_dy_filter])

            #  max(threshold, dz) / threshold * (min(dz, threshold) - dy)
            loss[~high_dy_filter] = aux_max_dz_thr[~high_dy_filter].div(
                self.threshold) * (aux_min_dz_thr[~high_dy_filter] -
                                   y_distance_matrix[~high_dy_filter])

        loss = torch.relu(loss)
        return loss
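
# Usage sketch (hypothetical host class and reduction): the method reads
# `self.threshold` and `self.hard` and returns a full loss matrix, so a
# reduction is applied before backprop.
import torch
from pytorch_metric_learning import distances  # referenced inside the method

class PairwiseThresholdLoss(torch.nn.Module):
    def __init__(self, threshold=0.5, hard=False):
        super().__init__()
        self.threshold, self.hard = threshold, hard

PairwiseThresholdLoss.build_loss_matrix = build_loss_matrix  # attach the method

embs = torch.randn(8, 16, requires_grad=True)
ys = torch.rand(8, 1)
loss_matrix = PairwiseThresholdLoss().build_loss_matrix(embs, ys)
loss_matrix.mean().backward()  # one possible reduction over the matrix
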
Example #4
def make_local_stationarity_plots(
    centers: Tensor,
    radiuses: list,
    n_samples: int,
    model: Union[EquationGrammarModelTorch, ShapesVAE, TopologyVAE],
    score_function: Callable,
    target: Union[Tensor, np.ndarray],
    save_dir: str,
    dist: str = "l2",
):
    # sample all points at once, then dispatch them per center and radius
    z_dim = centers.shape[-1]
    n_centers = centers.shape[0]
    n_radiuses = len(radiuses)
    n_total_samples = int(n_centers * n_radiuses * n_samples)
    if dist == "sup":
        all_samples = sample_on_hypercube(center=torch.zeros_like(centers[0]),
                                          side=1,
                                          n_samples=n_total_samples)
    elif dist == "l1":
        all_samples = sample_on_hyperdiamond(center=torch.zeros_like(
            centers[0]),
                                             side=1,
                                             n_samples=n_total_samples)
    else:
        all_samples = sample_on_hypersphere(center=torch.zeros_like(
            centers[0]),
                                            radius=1,
                                            n_samples=n_total_samples)
    # reshape to (centers, radiuses, samples, z_dim), scale by each radius,
    # then shift onto each center
    all_samples_resize = all_samples.view(n_centers, n_radiuses, n_samples, -1)
    all_samples_radius = torch.tensor(radiuses).view(
        1, -1, 1, 1).to(centers) * all_samples_resize
    all_samples_center = all_samples_radius + centers.view(n_centers, 1, 1, -1)
    all_samples = all_samples_center.view(-1, z_dim)

    model.eval()
    if isinstance(model, EquationGrammarModelTorch):
        model.vae.to(centers.device)
    else:
        model.to(centers.device)
    res = np.zeros((n_centers, n_radiuses))
    with torch.no_grad():
        if isinstance(model, EquationGrammarModelTorch):
            dec_cntrs = model.decode_from_latent_space(zs=centers,
                                                       n_decode_attempts=100)
            dec_cntrs_scores = score_function(dec_cntrs)
            bs = n_radiuses * n_samples
            dataset = TensorDataset(all_samples)
            dl = DataLoader(dataset, batch_size=bs)
            i = 0
            for (batch, ) in tqdm(dl):
                dec_batch = model.decode_from_latent_space(batch)
                dec_batch_scores = score_function(dec_batch)
                dec_batch_scores_gap = np.abs(dec_cntrs_scores[i] -
                                              dec_batch_scores)
                dec_batch_scores_gap_mean_rad = dec_batch_scores_gap.reshape(
                    n_radiuses, n_samples).mean(-1)
                res[i] = dec_batch_scores_gap_mean_rad
                i += 1
        else:
            dec_cntrs = model.decode_deterministic(centers)
            dec_cntrs_scores = score_function(dec_cntrs, target)
            bs = n_radiuses * n_samples
            dataset = TensorDataset(all_samples)
            dl = DataLoader(dataset, batch_size=bs)
            i = 0
            for (batch, ) in tqdm(dl):
                dec_batch = model.decode_deterministic(batch)
                dec_batch_scores = score_function(dec_batch, target)
                dec_batch_scores_gap = np.abs(dec_cntrs_scores[i] -
                                              dec_batch_scores)
                dec_batch_scores_gap_mean_rad = dec_batch_scores_gap.reshape(
                    n_radiuses, n_samples).mean(-1)
                res[i] = dec_batch_scores_gap_mean_rad
                i += 1

    # loop version is slow but doesn't run out of memory
    # res = np.zeros((n_centers, n_radiuses))
    # for i, c in enumerate(centers):
    #     for j, r in enumerate(radiuses):
    #         res[i][j] = average_distance_to_center_target_value(
    #             center=c,
    #             radius=r,
    #             model=model,
    #             score_function=score_function,
    #             target=target,
    #             samples=all_samples[i][j]
    #         )

    plt.imshow(res)
    plt.title(
        f"Avg. gap to target score per $z$ & per {dist.capitalize()}-dist radius"
    )
    plt.xlabel(f"radius of {dist.capitalize()}-ball around center z")
    plt.ylabel("centers ($z \in$ train set)")
    tickidx = np.arange(0, len(radiuses), max(1, len(radiuses) // 10))
    plt.xticks(tickidx, np.round(np.asarray(radiuses)[tickidx]))
    plt.savefig(os.path.join(save_dir, f"local_stationarity_{dist}.pdf"))
    plt.close()

    plt.plot(res.mean(0))
    plt.fill_between(np.arange(n_radiuses),
                     res.mean(0) + res.std(0),
                     res.mean(0) - res.std(0),
                     alpha=0.2)
    plt.title(
        f"Avg. gap to target score (avg. on {res.shape[0]} latent points) per {dist}-dist radius"
    )
    plt.xlabel(f"radius of {dist.capitalize()}-ball around center z")
    tickidx = np.arange(0, len(radiuses), max(1, len(radiuses) // 10))
    plt.xticks(tickidx, np.round(np.asarray(radiuses)[tickidx]))
    plt.savefig(
        os.path.join(save_dir, f"local_stationarity_average_{dist}.pdf"))
    plt.close()

    # get all pairwise distances in latent space and in score space then sort and plot one versus the other
    lpydist = distances.LpDistance(normalize_embeddings=False, p=1, power=1)
    y_distance_matrix = lpydist(torch.from_numpy(dec_cntrs_scores).to(centers))
    y_dists_tril_idx = torch.tril_indices(y_distance_matrix.shape[0],
                                          y_distance_matrix.shape[1],
                                          offset=-1)
    y_dists_tril = y_distance_matrix[y_dists_tril_idx[0, :],
                                     y_dists_tril_idx[1, :]]
    y_dists_sorted, sorted_idx = torch.sort(y_dists_tril)

    for dist in ['sup', 'l2', 'l1', 'cos']:
        if dist == "l2":
            lpembdist = distances.LpDistance(normalize_embeddings=False,
                                             p=2,
                                             power=1)
        elif dist == "l1":
            lpembdist = distances.LpDistance(normalize_embeddings=False,
                                             p=1,
                                             power=1)
        elif dist == "cos":
            lpembdist = distances.DotProductSimilarity()
        else:
            lpembdist = distances.LpDistance(normalize_embeddings=False,
                                             p=np.inf,
                                             power=1)

        emb_distance_matrix = lpembdist(centers)
        emb_dists_tril_idx = torch.tril_indices(emb_distance_matrix.shape[0],
                                                emb_distance_matrix.shape[1],
                                                offset=-1)
        emb_dists_tril = emb_distance_matrix[emb_dists_tril_idx[0, :],
                                             emb_dists_tril_idx[1, :]]
        emb_dists_sorted = emb_dists_tril[sorted_idx]

        dy, dz = y_dists_sorted.cpu().numpy(), emb_dists_sorted.cpu().numpy()
        plt.scatter(dy, dz, marker="+", alpha=0.25)
        plt.xlabel("absolute difference in score")
        plt.ylabel(f"{dist}-distance in latent space")
        plt.savefig(os.path.join(save_dir, f"y-dist_vs_z_{dist}-dist.pdf"))
        plt.close()

        # bin the sorted pairwise distances into equal-size consecutive chunks
        # (assumes the number of pairs divides evenly)
        y_dists_sorted_cat = y_dists_sorted.view(-1,
                                                 len(y_dists_sorted) //
                                                 100).cpu().numpy()
        emb_dists_sorted_cat = emb_dists_sorted.view(
            -1,
            len(emb_dists_sorted) // 100).cpu().numpy()
        plt.plot(y_dists_sorted_cat.mean(-1), emb_dists_sorted_cat.mean(-1))
        plt.fill_between(
            y_dists_sorted_cat.mean(-1),
            emb_dists_sorted_cat.mean(-1) + emb_dists_sorted_cat.std(-1),
            emb_dists_sorted_cat.mean(-1) - emb_dists_sorted_cat.std(-1),
            alpha=0.2)
        plt.title(
            f"Avg. {dist}-dist in latent space vs. avg. absolute score gap")
        plt.xlabel(f"avg. absolute difference in score")
        plt.ylabel(f"avg. {dist}-distance in latent space")
        plt.savefig(os.path.join(save_dir, f"y-dist_vs_z_{dist}-dist_cat.pdf"))
        plt.close()

    print("Local stationarity plots done.")
Example #5
})
model = model_config.build()
"""
OPTIMIZER
"""
optimizer_config = Config({"type": Adam, "params": {"lr": 0.001}})
# optimizer = optimizer_config.build()
"""
LOSS
"""
loss_config = Config({
    "type": losses.ContrastiveLoss,
    "params": {
        "pos_margin": 0,
        "neg_margin": 1,
        "distance": distances.LpDistance(),
    },
})
loss = loss_config.build()
"""
MASTER CONFIG
"""

config = {
    "dataset": dataset_config,
    "model": model_config,
    "optimizer": optimizer_config,
    "loss": loss_config,
    "data_loader": data_loader_config,
}
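
# `Config` is not shown in this snippet; a minimal stand-in consistent with
# how it is used above ({"type": ..., "params": {...}} plus .build()) might
# be the following, though the real class may differ:
class Config(dict):
    def build(self):
        return self["type"](**self.get("params", {}))
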
"""
    def __init__(self,
                 train_dl,
                 val_dl,
                 unseen_dl,
                 model,
                 optimizer,
                 scheduler,
                 criterion,
                 mining_function,
                 loss,
                 savePath='./models/',
                 device='cuda',
                 BATCH_SIZE=64):
        self.device = device
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.unseen_dl = unseen_dl
        self.BATCH_SIZE = BATCH_SIZE
        self.model = model.to(self.device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.mining_function = mining_function
        self.loss = loss
        self.distance = distances.LpDistance(normalize_embeddings=True,
                                             p=2,
                                             power=1)
        self.reducer = reducers.ThresholdReducer(low=0)
        self.regularizer = regularizers.LpRegularizer(p=2)
        if self.mining_function == 'triplet':
            self.mining_func = miners.TripletMarginMiner(
                margin=0.01,
                distance=self.distance,
                type_of_triplets="semihard")
        elif self.mining_function == 'pair':
            self.mining_func = miners.PairMarginMiner(pos_margin=0,
                                                      neg_margin=0.2)

        if self.loss == 'triplet':
            self.loss_function = losses.TripletMarginLoss(
                margin=0.01, distance=self.distance, reducer=self.reducer)
        elif self.loss == 'contrastive':
            self.loss_function = losses.ContrastiveLoss(pos_margin=0,
                                                        neg_margin=1.5)
        elif self.loss == 'panc':
            self.loss_function = losses.ProxyAnchorLoss(
                9,
                128,
                margin=0.01,
                alpha=5,
                reducer=self.reducer,
                weight_regularizer=self.regularizer)
        elif self.loss == 'pnca':
            self.loss_function = losses.ProxyNCALoss(
                9,
                128,
                softmax_scale=1,
                reducer=self.reducer,
                weight_regularizer=self.regularizer)
        elif self.loss == 'normsoftmax':
            self.loss_function = losses.NormalizedSoftmaxLoss(
                9,
                128,
                temperature=0.05,
                reducer=self.reducer,
                weight_regularizer=self.regularizer)

        if self.loss in ['normsoftmax', 'panc', 'pnca']:
            self.loss_optimizer = optim.SGD(self.loss_function.parameters(),
                                            lr=0.0001,
                                            momentum=0.9)
            self.loss_scheduler = lr_scheduler.ReduceLROnPlateau(
                self.loss_optimizer,
                'min',
                patience=3,
                threshold=0.0001,
                factor=0.1,
                verbose=True)

        self.savePath = savePath + 'efigi{}_{}_128'.format(
            self.mining_function, self.loss)
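
# Sketch of how the miner and loss configured above are typically wired
# together in a pytorch-metric-learning training step (tensors are stand-ins):
import torch
from pytorch_metric_learning import distances, losses, miners, reducers

distance = distances.LpDistance(normalize_embeddings=True, p=2, power=1)
reducer = reducers.ThresholdReducer(low=0)
miner = miners.TripletMarginMiner(margin=0.01, distance=distance,
                                  type_of_triplets="semihard")
loss_fn = losses.TripletMarginLoss(margin=0.01, distance=distance,
                                   reducer=reducer)

embeddings = torch.randn(64, 128)    # BATCH_SIZE x embedding dimension
labels = torch.randint(0, 9, (64,))  # 9 classes, matching the proxy losses
triplets = miner(embeddings, labels)
loss = loss_fn(embeddings, labels, triplets)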
# dataset2 = datasets.MNIST('.', train=False, transform=transform)
#
# train_loader = torch.utils.data.DataLoader(dataset1, batch_size=256, shuffle=True)
# test_loader = torch.utils.data.DataLoader(dataset2, batch_size=256)

output_size = 4
input_size = 768
hidden_size = 200
training_epochs = 30

model = LSTM_model(input_size, output_size, hidden_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 40

### pytorch-metric-learning stuff ###
distance = distances.LpDistance()
reducer = reducers.MeanReducer()
loss_func = losses.ProxyNCALoss(output_size, hidden_size * 2, softmax_scale=1)
mining_func = miners.TripletMarginMiner(margin=0.2,
                                        distance=distance,
                                        type_of_triplets="semihard")
accuracy_calculator = AccuracyCalculator(
    include=("mean_average_precision_at_r", ), k=10)
### pytorch-metric-learning stuff ###
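
# `train` is not defined in this snippet; a conventional definition consistent
# with the call below (an assumption, not the original) might be:
def train(model, loss_func, mining_func, device, train_loader, optimizer,
          epoch):
    model.train()
    for data, labels in train_loader:
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        embeddings = model(data)
        indices_tuple = mining_func(embeddings, labels)
        loss = loss_func(embeddings, labels, indices_tuple)
        loss.backward()
        optimizer.step()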

for epoch in range(1, num_epochs + 1):
    train(model, loss_func, mining_func, device, train_loader, optimizer,
          epoch)
    # test(dataset2, model, accuracy_calculator)

torch.save(model.state_dict(),