Example #1
    def set_forward(self, support_images, support_labels, query_images):
        """
        Overwrites method set_forward in AbstractMetaLearner.
        """
        support_images = set_device(support_images)
        query_images = set_device(query_images)
        support_labels = set_device(support_labels)

        # Save parameters
        feature_parameters = copy.deepcopy(self.feature).cpu().state_dict()

        # Init linear model
        self.linear_model = set_device(nn.Linear(CLASSES["train"], N_WAY_EVAL))
        self.support_based_initializer(support_images, support_labels)

        # Compute the linear model
        self.fine_tune(support_images, support_labels, query_images)

        # Compute score of query
        scores = self.linear_model(self.feature(query_images))

        # Restore the original backbone parameters
        self.feature.load_state_dict(feature_parameters)

        return scores
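
set_device is used throughout these examples but is not defined on this page. A minimal sketch consistent with its usage (moving a tensor or module to GPU when available) could look like the following; the repository's actual implementation may differ:

import torch

def set_device(obj):
    # Sketch (assumption): move a tensor or nn.Module to CUDA when available,
    # otherwise leave it on CPU.
    return obj.cuda() if torch.cuda.is_available() else obj
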
Example #2
 def forward(self, x):
     running_mean = set_device(torch.zeros(x.data.size()[1]))
     running_var = set_device(torch.ones(x.data.size()[1]))
     if self.weight.fast is not None and self.bias.fast is not None:
         out = F.batch_norm(
             x,
             running_mean,
             running_var,
             self.weight.fast,
             self.bias.fast,
             training=True,
             momentum=1,
         )
         # momentum=1 hack: follows Kate Rakelly's hack in pytorch-maml/src/layers.py
     else:
         out = F.batch_norm(
             x,
             running_mean,
             running_var,
             self.weight,
             self.bias,
             training=True,
             momentum=1,
         )
     return out
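
For context on the momentum hack above: with momentum=1, F.batch_norm overwrites the running statistics entirely with the current batch statistics, so the zero/one placeholders passed in never influence the output. A quick standalone check:

import torch
import torch.nn.functional as F

x = torch.randn(8, 3)
running_mean, running_var = torch.zeros(3), torch.ones(3)
out = F.batch_norm(x, running_mean, running_var, training=True, momentum=1)
# The running mean was fully replaced by the batch mean:
assert torch.allclose(running_mean, x.mean(dim=0), atol=1e-5)
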
Example #3
    def propagate(self, laplacian, support_labels):
        """
        Compute label propagation.
        See eq (4) of LEARNING TO PROPAGATE LABELS: TRANSDUCTIVE PROPAGATION NETWORK FOR FEW-SHOT LEARNING
        Args:
            laplacian (torch.Tensor): shape (n_support + n_query, n_support + n_query)
            support_labels (torch.Tensor): artificial support set labels in range (0, n_way)
        Returns:
            torch.Tensor: shape(n_support + n_query, n_support + n_query), similarity matrix between samples.
        """

        # Infer the task dimensions
        n_way = len(torch.unique(support_labels))
        n_support_query = laplacian.size(0)
        n_support = support_labels.size(0)
        n_query = n_support_query - n_support

        ## compute support labels as one hot
        one_hot_labels = set_device(torch.zeros(n_support, n_way))

        one_hot_labels[torch.arange(n_support), support_labels] = 1.0

        ## samples to predict have 0 everywhere
        one_hot_labels = torch.cat(
            [one_hot_labels, set_device(torch.zeros(n_query, n_way))]
        )

        # compute label propagation
        propagation = (
            set_device(torch.eye(laplacian.size(0))) - self.alpha * laplacian + self.eps
        ).inverse()

        scores = torch.matmul(propagation, one_hot_labels)

        return scores
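
Eq. (4) of the paper computes the propagated scores in closed form as F* = (I - alpha * S)^(-1) Y, where S is the normalized graph weight matrix (passed here as laplacian) and Y the one-hot seed labels; that is exactly the inverse-matrix product above. A toy standalone illustration (set_device omitted; all values here are made up):

import torch

alpha, eps = 0.99, 1e-6
laplacian = torch.rand(5, 5)          # stand-in for the normalized similarity graph
one_hot_labels = torch.eye(5)[:, :2]  # 2-way seed labels; query rows stay all-zero
propagation = (torch.eye(5) - alpha * laplacian + eps).inverse()
scores = propagation @ one_hot_labels  # shape (5, 2): propagated class scores
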
Example #4
    def forward(self, x, y):
        # The Sinkhorn algorithm takes three inputs: the cost matrix and the two marginals
        C = self._cost_matrix(x, y)  # Wasserstein cost function
        cost_normalization = C.max()
        C = C / cost_normalization  # Normalize the cost matrix to be consistent with reg

        x_points = x.shape[-2]
        y_points = y.shape[-2]
        if x.dim() == 2:
            batch_size = 1
        else:
            batch_size = x.shape[0]

        # both marginals are fixed with equal weights
        mu = set_device(
            torch.empty(batch_size,
                        x_points,
                        dtype=torch.float,
                        requires_grad=False).fill_(1.0 / x_points).squeeze())
        nu = set_device(
            torch.empty(batch_size,
                        y_points,
                        dtype=torch.float,
                        requires_grad=False).fill_(1.0 / y_points).squeeze())

        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)
        # To check if algorithm terminates because of threshold
        # or max iterations reached
        actual_nits = 0

        # Sinkhorn iterations
        for i in range(self.max_iter):
            u1 = u  # useful to check the update
            u = (self.eps * (torch.log(mu + 1e-8) -
                             torch.logsumexp(self.M(C, u, v), dim=-1)) + u)
            v = (self.eps *
                 (torch.log(nu + 1e-8) -
                  torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) +
                 v)
            err = (u - u1).abs().sum(-1).mean()

            actual_nits += 1
            if err.item() < self.thresh:
                break

        U, V = u, v
        # Transport plan pi = diag(a)*K*diag(b)
        pi = torch.exp(self.M(C, U, V))
        # Sinkhorn distance
        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == "mean":
            cost = cost.mean()
        elif self.reduction == "sum":
            cost = cost.sum()

        return cost, pi, C
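
The helpers self.M and self._cost_matrix are not shown on this page. Below are standalone sketches of the standard log-domain formulations they presumably implement; the function names here are ours, not the repository's:

import torch

def sinkhorn_M(C, u, v, eps):
    # "Modified cost" used in log-domain Sinkhorn: M_ij = (-C_ij + u_i + v_j) / eps
    return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / eps

def cost_matrix(x, y, p=2):
    # Pairwise p-norm cost between the two point clouds
    return torch.sum(torch.abs(x.unsqueeze(-2) - y.unsqueeze(-3)) ** p, dim=-1)
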
Example #5
def get_model(n_classes: int) -> nn.Module:
    logger.info(f"Initializing {model_config.BACKBONE.__name__}...")
    model = set_device(model_config.BACKBONE())
    model.trunk.add_module(
        "fc", set_device(nn.Linear(model.final_feat_dim, n_classes)))
    model.loss_fn = nn.CrossEntropyLoss()
    model.optimizer = erm_training_config.OPTIMIZER(model.parameters())
    return model
Example #6
 def forward(self, x):
     # Placeholders for F.batch_norm, not used since momentum=1.
     running_mean = set_device(torch.zeros(x.data.size()[1]))
     running_var = set_device(torch.ones(x.data.size()[1]))
     out = F.batch_norm(
         x,
         running_mean,
         running_var,
         self.weight,
         self.bias,
         training=True,
         momentum=1,
     )
     return out
Example #7
    def __init__(
        self,
        model_func,
        transportation=None,
        training_stats=None,
        lr=5.0 * 1e-5,
        epochs=25,
    ):
        super(TransFineTune, self).__init__(model_func,
                                            transportation=transportation,
                                            training_stats=training_stats)

        # Hyper-parameters used in the paper.
        self.lr = lr
        self.epochs = epochs

        # Use the output of fc, not the backbone output.
        self.feature.trunk.add_module(
            "fc",
            set_device(nn.Linear(self.feature.final_feat_dim,
                                 CLASSES["train"])))

        # Add a non-linearity to the output
        self.feature.trunk.add_module("relu", nn.ReLU())

        self.linear_model = None
Example #8
def validation(model: nn.Module, data_loader: DataLoader,
               n_batches: int) -> float:
    val_acc_list = []
    model.eval()
    with tqdm(
            zip(range(n_batches), data_loader),
            total=n_batches,
            desc="Validation:",
    ) as tqdm_val:
        for _, (images, labels, _) in tqdm_val:
            # Top-1 accuracy: index of the highest logit vs. the ground-truth label
            val_acc_list.append(
                float((model(set_device(images)).data.topk(
                    1, 1, True, True)[1][:, 0] == set_device(labels)).sum()) /
                len(labels))
            tqdm_val.set_postfix(accuracy=np.asarray(val_acc_list).mean())

    return np.asarray(val_acc_list).mean()
Example #9
    def set_forward(self, support_images, support_labels, query_images):
        """
        Overwrites method set_forward in AbstractMetaLearner.
        """

        # Process the images in chunks (at most ~32 support images each) to bound memory usage
        support_query_size = len(support_images)
        n_chunks = support_query_size // 32 + 1

        support_chunk = []
        query_chunk = []

        for support, query in zip(support_images.chunk(n_chunks),
                                  query_images.chunk(n_chunks)):

            support_features, query_features = (
                features.detach().cpu() for features in self.extract_features(
                    set_device(support), set_device(query)))

            support_chunk.append(support_features)
            query_chunk.append(query_features)

        z_support = torch.cat(support_chunk, dim=0)

        del support_chunk

        z_query = torch.cat(query_chunk, dim=0)

        del query_chunk

        # If a transportation method in the feature space has been defined, use it
        if self.transportation_module:
            z_support, z_query = (z.cpu() for z in self.transportation_module(
                set_device(z_support), set_device(z_query)))

        z_support = z_support.numpy()
        z_query = z_query.numpy()
        support_labels = support_labels.cpu().numpy()

        linear_classifier = RidgeClassifier(alpha=0.1)
        linear_classifier.fit(z_support, support_labels)

        scores = torch.tensor(linear_classifier.decision_function(z_query))

        scores = set_device(scores)
        return scores
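
For context, scikit-learn's RidgeClassifier fits a ridge regression per class, and decision_function returns the per-class scores used above. A standalone illustration with made-up features:

import numpy as np
from sklearn.linear_model import RidgeClassifier

rng = np.random.default_rng(0)
z_support = rng.normal(size=(25, 64))        # e.g. 5-way 5-shot support features
support_labels = np.repeat(np.arange(5), 5)  # labels 0..4, five shots each
clf = RidgeClassifier(alpha=0.1).fit(z_support, support_labels)
scores = clf.decision_function(rng.normal(size=(10, 64)))  # shape (10, 5)
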
Example #10
def training_epoch(model: nn.Module, data_loader: DataLoader, epoch: int,
                   n_batches: int) -> (nn.Module, float):
    loss_list = []
    model.train()

    with tqdm(
            zip(range(n_batches), data_loader),
            total=n_batches,
            desc=f"Epoch {epoch}",
    ) as tqdm_train:
        for batch_id, (images, labels, _) in tqdm_train:
            model, loss_value = fit(model, set_device(images),
                                    set_device(labels))

            loss_list.append(loss_value)

            tqdm_train.set_postfix(loss=np.asarray(loss_list).mean())

    return model, np.asarray(loss_list).mean()
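
A minimal driver tying training_epoch and validation together might look like the sketch below; train_loader, val_loader and the argument values are assumptions for illustration, not part of this codebase's shown API:

# Hypothetical usage sketch (loader names and counts are assumptions):
model = get_model(n_classes=64)
for epoch in range(10):
    model, train_loss = training_epoch(model, train_loader, epoch, n_batches=100)
    val_acc = validation(model, val_loader, n_batches=50)
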
Example #11
    def extract_features(self, support_images, query_images):
        """
        Computes the feature vectors of the support and query sets
        Args:
            support_images (torch.Tensor): shape (n_support_images, **image_dim) input data
            query_images (torch.Tensor): shape (n_query_images, **image_dim) input data

        Returns:
            Tuple(torch.Tensor, torch.Tensor): feature vectors of the support and query sets, of respective shapes
            (n_support_images, features_dim) and (n_query_images, features_dim)
        """
        # Set to CUDA if available
        support_images = set_device(support_images)
        query_images = set_device(query_images)

        z_support = self.feature.forward(support_images)
        z_query = self.feature.forward(query_images)

        return z_support, z_query
Example #12
    def eval_loop(self, test_loader):
        """

        Args:
            test_loader (DataLoader): loader of a given number of episodes

        Returns:
            tuple(float, float, pd.DataFrame): resp. average loss and classification accuracy,
                and advanced evaluation statistics
        """
        acc_all = []
        loss_all = []
        evaluation_stats = []

        n_tasks = len(test_loader)
        for episode_index, (
                support_images,
                support_labels,
                query_images,
                query_labels,
                class_ids,
                source_domain,
                target_domain,
        ) in enumerate(test_loader):

            query_labels = set_device(query_labels)
            scores = self.set_forward(support_images, support_labels,
                                      query_images).detach()

            loss_value = self.loss_fn(scores, query_labels).detach().item()

            evaluation_stats.append(
                self.get_task_perf(
                    episode_index,
                    scores.cpu(),
                    query_labels.cpu().detach(),
                    class_ids,
                    source_domain,
                    target_domain,
                ))

            loss_all.append(loss_value)

            acc_all.append(self.evaluate(scores, query_labels) * 100)

        evaluations_stats_df = pd.concat(evaluation_stats, ignore_index=True)
        acc_all = np.asarray(acc_all)
        acc_mean = np.mean(acc_all)
        acc_std = np.std(acc_all)
        logger.info("%d Test Accuracy = %4.2f%% +- %4.2f%%" %
                    (n_tasks, acc_mean, confidence_interval(acc_std, n_tasks)))

        return np.asarray(loss_all).mean(), acc_mean, evaluations_stats_df
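
The confidence_interval helper is not shown on this page; below is a sketch consistent with the usual 95% normal-approximation half-width. This is an assumption, and the repository's definition may differ:

import math

def confidence_interval(std: float, n_tasks: int) -> float:
    # Sketch (assumption): 95% confidence half-width under a normal approximation
    return 1.96 * std / math.sqrt(n_tasks)
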
Example #13
def load_model_non_episodic(model: nn.Module, state_dict: OrderedDict,
                            use_fc: bool) -> nn.Module:
    if use_fc:
        model.feature.trunk.fc = set_device(
            nn.Linear(
                model.feature.final_feat_dim,
                dataset_config.CLASSES["train"] +
                dataset_config.CLASSES["val"],
            ))

    model.feature.load_state_dict(
        state_dict if use_fc else
        OrderedDict([(k, v) for k, v in state_dict.items() if ".fc." not in k]))
    return model
Example #14
def one_hot(labels):
    """

    Args:
        labels (torch.Tensor): 1-dimensional tensor of integers

    Returns:
        torch.Tensor: 2-dimensional tensor of shape (len(labels), max(labels) + 1) corresponding to the one-hot
            form of the input tensor
    """
    num_class = torch.max(labels) + 1
    return set_device(
        torch.zeros((len(labels), num_class)).scatter_(1, labels.unsqueeze(1),
                                                       1))
Example #15
def load_model(state_path: Path, episodic: bool, use_fc: bool,
               force_ot: bool) -> nn.Module:
    model = set_device(model_config.MODEL(model_config.BACKBONE))

    if force_ot:
        model.transportation_module = model_config.TRANSPORTATION_MODULE
        logger.info("Forced the Optimal Transport module into the model.")

    state_dict = torch.load(state_path)
    model = (load_model_episodic(model, state_dict) if episodic else
             load_model_non_episodic(model, state_dict, use_fc))

    logger.info(f"Loaded model from {state_path}")

    return model
Example #16
    def train_loop(self, epoch, train_loader, optimizer):
        """
        Executes one training epoch
        Args:
            epoch (int): current epoch
            train_loader (DataLoader): loader of a given number of episodes
            optimizer (torch.optim.Optimizer): model optimizer

        Returns:
            tuple(float, float): resp. average loss and classification accuracy
        """
        print_freq = 100

        loss_list = []
        acc_list = []
        for episode_index, (
                support_images,
                support_labels,
                query_images,
                query_labels,
                _,
                _,
                _,
        ) in enumerate(train_loader):
            query_labels = set_device(query_labels)
            scores, loss_value = self.fit_on_task(support_images,
                                                  support_labels, query_images,
                                                  query_labels, optimizer)

            loss_list.append(loss_value)

            acc_list.append(self.evaluate(scores, query_labels) * 100)

            if episode_index % print_freq == print_freq - 1:
                logger.info(
                    "Epoch {epoch} | Batch {episode_index}/{n_batches} | Loss {loss}"
                    .format(
                        epoch=epoch,
                        episode_index=episode_index + 1,
                        n_batches=len(train_loader),
                        loss=np.asarray(loss_list).mean(),
                    ))

        return np.asarray(loss_list).mean(), np.asarray(acc_list).mean()
Example #17
    def __init__(self, model_func, transportation=None, training_stats=None):
        super(MatchingNet, self).__init__(
            model_func, transportation=transportation, training_stats=training_stats
        )

        self.loss_fn = nn.NLLLoss()

        self.FCE = FullyContextualEmbedding(self.feature.final_feat_dim)
        self.support_features_encoder = set_device(
            nn.LSTM(
                self.feature.final_feat_dim,
                self.feature.final_feat_dim,
                1,
                batch_first=True,
                bidirectional=True,
            )
        )

        self.relu = nn.ReLU()
        self.softmax = nn.Softmax()
Example #18
def train_model():
    logger.info("Initializing data loaders...")
    train_loader, _ = get_episodic_loader(
        "train",
        n_way=training_config.N_WAY,
        n_source=training_config.N_SOURCE,
        n_target=training_config.N_TARGET,
        n_episodes=training_config.N_EPISODES,
    )
    val_loader, _ = get_episodic_loader(
        "val",
        n_way=training_config.N_WAY,
        n_source=training_config.N_SOURCE,
        n_target=training_config.N_TARGET,
        n_episodes=training_config.N_VAL_TASKS,
    )
    if training_config.TEST_SET_VALIDATION_FREQUENCY:
        test_loader, _ = get_episodic_loader(
            "test",
            n_way=training_config.N_WAY,
            n_source=training_config.N_SOURCE,
            n_target=training_config.N_TARGET,
            n_episodes=training_config.N_VAL_TASKS,
        )

    logger.info("Initializing model...")

    model = set_device(model_config.MODEL(model_config.BACKBONE))
    optimizer = training_config.OPTIMIZER(model.parameters())

    max_acc = -1.0
    best_model_epoch = -1
    best_model_state = None

    writer = SummaryWriter(log_dir=experiment_config.SAVE_DIR)

    logger.info("Model and data are ready. Starting training...")
    for epoch in range(training_config.N_EPOCHS):
        # Set model to training mode
        model.train()
        # Execute a training loop of the model
        train_loss, train_acc = model.train_loop(epoch, train_loader,
                                                 optimizer)
        writer.add_scalar("Train/loss", train_loss, epoch)
        writer.add_scalar("Train/acc", train_acc, epoch)
        # Set model to evaluation mode
        model.eval()
        # Evaluate on validation set
        val_loss, val_acc, _ = model.eval_loop(val_loader)
        writer.add_scalar("Val/loss", val_loss, epoch)
        writer.add_scalar("Val/acc", val_acc, epoch)

        # Save the best model to disk, in case training is interrupted
        if val_acc > max_acc:
            max_acc = val_acc
            best_model_epoch = epoch
            best_model_state = model.state_dict()
            torch.save(best_model_state,
                       experiment_config.SAVE_DIR / "best_model.tar")

        if training_config.TEST_SET_VALIDATION_FREQUENCY:
            if (epoch % training_config.TEST_SET_VALIDATION_FREQUENCY ==
                    training_config.TEST_SET_VALIDATION_FREQUENCY - 1):
                logger.info("Validating on test set...")
                _, test_acc, _ = model.eval_loop(test_loader)
                writer.add_scalar("Test/acc", test_acc, epoch)

    logger.info(f"Training over after {training_config.N_EPOCHS} epochs")
    logger.info("Retrieving model with best validation accuracy...")
    model.load_state_dict(best_model_state)
    logger.info(f"Retrieved model from epoch {best_model_epoch}")

    writer.close()

    return model
Example #19
 def __init__(self, num_features):
     super(FullyContextualEmbedding, self).__init__()
     self.lstm_cell = set_device(nn.LSTMCell(num_features * 2,
                                             num_features))
     self.softmax = nn.Softmax()
     # Note: Variable is a deprecated no-op wrapper in PyTorch >= 0.4; torch.zeros alone would suffice
     self.c_0 = set_device(Variable(torch.zeros(1, num_features)))