Example 1
def matching_net_predictions(attention: torch.Tensor, n: int, k: int, q: int) -> torch.Tensor:
    """Calculates Matching Network predictions based on equation (1) of the paper.

    The predictions are the weighted sum of the labels of the support set where the
    weights are the "attentions" (i.e. softmax over query-support distances) pointing
    from the query set samples to the support set samples.

    # Arguments
        attention: torch.Tensor containing softmax over query-support distances.
            Should be of shape (q * k, k * n)
        n: Number of support set samples per class, n-shot
        k: Number of classes in the episode, k-way
        q: Number of query samples per class

    # Returns
        y_pred: Predicted class probabilities
    """
    if attention.shape != (q * k, k * n):
        raise ValueError(f'Expecting attention Tensor to have shape (q * k, k * n) = ({q * k}, {k * n})')

    # Create one hot label vector for the support set
    y_onehot = torch.zeros(k * n, k)

    # Unsqueeze to force y to be of shape (k * n, 1) as this
    # is needed for .scatter()
    y = create_nshot_task_label(k, n).unsqueeze(-1)
    y_onehot = y_onehot.scatter(1, y, 1)

    y_pred = torch.mm(attention, y_onehot.cuda().double())

    return y_pred
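A minimal self-contained sketch of equation (1) itself, computed without the helper above (the 2-way/2-shot numbers, the uniform attention weights and the class-by-class support ordering are illustrative assumptions):

import torch

# Toy 2-way, 2-shot task with 3 queries per class (illustrative numbers)
n, k, q = 2, 2, 3
attention = torch.full((q * k, k * n), 1.0 / (k * n))  # uniform softmax weights

# One-hot support labels, assuming the support set is ordered class by class
support_labels = torch.arange(k).repeat_interleave(n)  # tensor([0, 0, 1, 1])
y_onehot = torch.nn.functional.one_hot(support_labels, num_classes=k).double()

# Equation (1): predictions are attention-weighted sums of the support labels
y_pred = attention.double() @ y_onehot
print(y_pred)  # every row is [0.5000, 0.5000] because the attention is uniform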
Example 2
 def prepare_batch_(batch):
     x, y = batch
     x = x.double().cuda()
     # Create dummy 0-(num_classes - 1) label
     y = create_nshot_task_label(k, n).cuda()
     return x, y
Example 3
 def prepare_meta_batch_(batch):
     x, y = batch
     # Reshape to `meta_batch_size` number of tasks. Each task contains
     # n*k support samples to train the fast model on and q*k query samples to
     # evaluate the fast model on and generate meta-gradients
     x = x.reshape(meta_batch_size, n*k + q*k, num_input_channels, x.shape[-2], x.shape[-1])
     # Move to device
     x = x.double().to(device)
     # Create label
     y = create_nshot_task_label(k, q).to(device).repeat(meta_batch_size)
     return x, y
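A quick shape check makes the reshape above concrete; the sizes below are illustrative assumptions, not values from the original experiments:

import torch

meta_batch_size, n, k, q = 4, 1, 5, 3
num_input_channels, height, width = 1, 28, 28

# The loader is assumed to deliver all tasks flattened along the first axis
x = torch.randn(meta_batch_size * (n * k + q * k), num_input_channels, height, width)

x = x.reshape(meta_batch_size, n * k + q * k, num_input_channels, height, width)
print(x.shape)  # torch.Size([4, 20, 1, 28, 28]): support + query samples grouped per task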
Example 4
    def prepare_nshot_task_(
        batch: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Create 0-k label and move to GPU.

        TODO: Move to arbitrary device
        """
        x, y = batch
        # BEFORE: x = x.double().cuda()
        x = x.to(device)  # ADAPTED to run on the configured device
        # Create dummy 0-(num_classes - 1) label
        y = create_nshot_task_label(k, q).to(device)
        return x, y
Example 5
    def _get_maml_graph(
        self, order: int, inner_train_steps: int
    ) -> Tuple[List[torch.autograd.Function], List[Tuple[
            torch.autograd.Function, torch.autograd.Function]]]:
        """Gets the autograd graph for a single iteration of MAML.

        # Arguments:
            order: Whether to use 1st-order MAML (update the meta-learner weights directly with the gradients of
                the query-set loss taken at the updated fast weights) or 2nd-order MAML (differentiate through the
                inner-loop updates, i.e. take gradients of the query-set loss with respect to the original weights).
            inner_train_steps: Number of gradient steps to fit the fast weights during each inner update

        # Returns
            nodes: List of torch.autograd.Functions that are the nodes of the autograd graph
            edges: List of (Function, Function) tuples that are the edges between the nodes of the autograd graph
        """
        x, _ = next(iter(self.dummy_tasks))
        x = x.double().reshape(self.meta_batch_size,
                               self.n * self.k + self.q * self.k, x.shape[-1])
        y = create_nshot_task_label(self.k,
                                    self.q).repeat(self.meta_batch_size)

        loss, y_pred = meta_gradient_step(self.model,
                                          self.opt,
                                          torch.nn.CrossEntropyLoss(),
                                          x,
                                          y,
                                          self.n,
                                          self.k,
                                          self.q,
                                          order=order,
                                          inner_train_steps=inner_train_steps,
                                          inner_lr=0.1,
                                          train=True,
                                          device='cpu')

        nodes, edges = autograd_graph(loss)

        return nodes, edges
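A sketch of how the returned graph might be inspected inside the same test case; the summary below is illustrative and assumes only what the docstring states about nodes and edges:

from collections import Counter

nodes, edges = self._get_maml_graph(order=2, inner_train_steps=1)

# Count which backward ops appear in the graph and report its overall size
op_counts = Counter(type(node).__name__ for node in nodes)
print(op_counts.most_common(5))
print(f'{len(nodes)} nodes, {len(edges)} edges')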
Example 6
def meta_gradient_step(model: Module, optimiser: Optimizer, loss_fn: Callable,
                       x: torch.Tensor, y: torch.Tensor, n_shot: int,
                       k_way: int, q_queries: int, order: int,
                       inner_train_steps: int, inner_lr: float, train: bool,
                       device: Union[str, torch.device]):
    """
    Perform a gradient step on a meta-learner.

    # Arguments
        model: Base model of the meta-learner being trained
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function computed between the predictions and the target labels
        x: Input samples for all few shot tasks
        y: Input labels of all few shot tasks
        n_shot: Number of examples per class in the support set of each task
        k_way: Number of classes in the few shot classification task of each task
        q_queries: Number of examples per class in the query set of each task. The query set is used to calculate
            meta-gradients after the inner-loop update has been applied to the fast weights.
        order: Whether to use 1st-order MAML (update the meta-learner weights directly with the gradients of the
            query-set loss taken at the updated fast weights) or 2nd-order MAML (differentiate through the
            inner-loop updates, i.e. take gradients of the query-set loss with respect to the original weights).
        inner_train_steps: Number of gradient steps to fit the fast weights during each inner update
        inner_lr: Learning rate used to update the fast weights on the inner update
        train: Whether to update the meta-learner weights at the end of the episode.
        device: Device on which to run computation
    """
    data_shape = x.shape[2:]
    create_graph = (order == 2) and train

    task_gradients = []
    task_losses = []
    task_predictions = []
    for meta_batch_x, meta_batch_y in zip(x, y):
        # By construction x is a 5D tensor of shape: (meta_batch_size, n*k + q*k, channels, width, height)
        # Hence when we iterate over the first  dimension we are iterating through the meta batches
        x_task_train = meta_batch_x[:n_shot]
        x_task_val = meta_batch_x[n_shot:]

        y_task_train = meta_batch_y[:n_shot]
        y_task_val = meta_batch_y[n_shot:]

        # Create a fast model using the current meta model weights
        fast_weights = OrderedDict(model.named_parameters())

        # Train the model for `inner_train_steps` iterations
        for inner_batch in range(inner_train_steps):
            # Perform update of model weights
            #y = create_nshot_task_label(k_way, n_shot).to(device)
            logits = model.functional_forward(x_task_train, fast_weights)

            assert logits.shape == y_task_train.shape

            loss = loss_fn(logits, y_task_train)
            gradients = torch.autograd.grad(loss,
                                            fast_weights.values(),
                                            create_graph=create_graph)

            # Update weights manually
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param),
                     grad) in zip(fast_weights.items(), gradients))

        # Do a pass of the model on the validation data from the current task
        #y = create_nshot_task_label(k_way, q_queries).to(device)
        logits = model.functional_forward(x_task_val, fast_weights)
        loss = loss_fn(logits, y_task_val)
        loss.backward(retain_graph=True)

        # Get post-update predictions by rounding each output to the nearest integer label
        y_pred = np.array(
            [round_updown(xi) for xi in logits.detach().cpu().numpy().reshape(logits.shape[0])],
            dtype='long')
        task_predictions.append(torch.from_numpy(y_pred).to(device))

        # Accumulate losses and gradients
        task_losses.append(loss)
        gradients = torch.autograd.grad(loss,
                                        fast_weights.values(),
                                        create_graph=create_graph)
        named_grads = {
            name: g
            for ((name, _), g) in zip(fast_weights.items(), gradients)
        }
        task_gradients.append(named_grads)

    if order == 1:
        if train:
            sum_task_gradients = {
                k:
                torch.stack([grad[k] for grad in task_gradients]).mean(dim=0)
                for k in task_gradients[0].keys()
            }
            hooks = []
            for name, param in model.named_parameters():
                hooks.append(
                    param.register_hook(replace_grad(sum_task_gradients,
                                                     name)))

            model.train()
            optimiser.zero_grad()
            # Dummy pass in order to create `loss` variable
            # Replace dummy gradients with mean task gradients using hooks
            logits = model(
                torch.zeros((k_way, ) + data_shape).to(device,
                                                       dtype=torch.double))
            loss = loss_fn(logits,
                           create_nshot_task_label(k_way, 1).to(device))
            loss.backward()
            optimiser.step()

            for h in hooks:
                h.remove()

        return torch.stack(task_losses).mean(), torch.cat(task_predictions)

    elif order == 2:
        model.train()
        optimiser.zero_grad()
        meta_batch_loss = torch.stack(task_losses).mean()

        if train:
            meta_batch_loss.backward()
            optimiser.step()

        return meta_batch_loss, torch.cat(task_predictions)
    else:
        raise ValueError('Order must be either 1 or 2.')
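To make the inner loop above concrete, here is a self-contained sketch of the fast-weight update on a toy linear model; the model, data and learning rate are illustrative assumptions, and functional_forward below merely stands in for model.functional_forward:

import torch
import torch.nn.functional as F
from collections import OrderedDict

torch.manual_seed(0)
weights = OrderedDict(w=torch.randn(4, 3, requires_grad=True),
                      b=torch.zeros(3, requires_grad=True))

def functional_forward(x, params):
    # Stand-in for model.functional_forward: a forward pass using explicit weights
    return x @ params['w'] + params['b']

x_support = torch.randn(6, 4)
y_support = torch.tensor([0, 0, 1, 1, 2, 2])
inner_lr, create_graph = 0.1, True  # create_graph=True keeps the update differentiable (2nd-order MAML)

fast_weights = OrderedDict(weights)
for _ in range(2):  # two inner-loop steps
    loss = F.cross_entropy(functional_forward(x_support, fast_weights), y_support)
    grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=create_graph)
    fast_weights = OrderedDict(
        (name, param - inner_lr * grad)
        for (name, param), grad in zip(fast_weights.items(), grads))

# The query-set loss evaluated at the fast weights can now be backpropagated to `weights`
query_loss = F.cross_entropy(functional_forward(torch.randn(6, 4), fast_weights), y_support)
query_loss.backward()
print(weights['w'].grad.shape)  # meta-gradient w.r.t. the original weights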
Example 7
def meta_gradient_step(
    model: Module,
    optimiser: Optimizer,
    loss_fn: Callable,
    x: torch.Tensor,
    y: torch.Tensor,
    n_shot: int,
    k_way: int,
    q_queries: int,
    order: int,
    inner_train_steps: int,
    inner_lr: float,
    train: bool,
    device: Union[str, torch.device],
    stnmodel=None,
    stnoptim=None,
    args=None,
):
    """
    Perform a gradient step on a meta-learner.

    # Arguments
        model: Base model of the meta-learner being trained
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function computed between the predictions and the target labels
        x: Input samples for all few shot tasks
        y: Input labels of all few shot tasks
        n_shot: Number of examples per class in the support set of each task
        k_way: Number of classes in the few shot classification task of each task
        q_queries: Number of examples per class in the query set of each task. The query set is used to calculate
            meta-gradients after the inner-loop update has been applied to the fast weights.
        order: Whether to use 1st-order MAML (update the meta-learner weights directly with the gradients of the
            query-set loss taken at the updated fast weights) or 2nd-order MAML (differentiate through the
            inner-loop updates, i.e. take gradients of the query-set loss with respect to the original weights).
        inner_train_steps: Number of gradient steps to fit the fast weights during each inner update
        inner_lr: Learning rate used to update the fast weights on the inner update
        train: Whether to update the meta-learner weights at the end of the episode.
        device: Device on which to run computation
        stnmodel: Optional spatial transformer network (STN) used to transform the task samples before the inner loop
        stnoptim: Optimiser for the STN parameters
        args: Additional options used when an STN is present (e.g. targetonly, stn_reg_coeff)
    """
    if stnmodel:
        if train:
            stnmodel.train()
            stnoptim.zero_grad()
        else:
            stnmodel.eval()
    theta = []
    info = None

    # Check for meta parameters
    data_shape = x.shape[2:]
    create_graph = (order == 2) and train

    task_gradients = []
    task_losses = []
    task_predictions = []
    for meta_batch in x:
        # By construction x is a 5D tensor of shape: (meta_batch_size, n*k + q*k, channels, width, height)
        # Hence when we iterate over the first  dimension we are iterating through the meta batches
        x_task_train = meta_batch[:n_shot * k_way]
        x_task_val = meta_batch[n_shot * k_way:]

        # Modify some examples
        if stnmodel and train:
            x_task_train, theta_train, info_train = stnmodel(x_task_train, 0)
            x_task_val, theta_val, info_val = stnmodel(x_task_val,
                                                       args.targetonly)
            theta.append(theta_train)
            theta.append(theta_val)

        # Create a fast model using the current meta model weights
        fast_weights = OrderedDict(model.named_parameters())

        # Train the model for `inner_train_steps` iterations
        for inner_batch in range(inner_train_steps):
            # Perform update of model weights
            y = create_nshot_task_label(k_way, n_shot).to(device)
            logits = model.functional_forward(x_task_train, fast_weights)
            loss = loss_fn(logits, y)
            gradients = torch.autograd.grad(loss,
                                            fast_weights.values(),
                                            create_graph=create_graph)

            # Update weights manually
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param),
                     grad) in zip(fast_weights.items(), gradients))

        # Do a pass of the model on the validation data from the current task
        y = create_nshot_task_label(k_way, q_queries).to(device)
        logits = model.functional_forward(x_task_val, fast_weights)
        loss = loss_fn(logits, y)
        loss.backward(retain_graph=True)

        # Get post-update accuracies
        y_pred = logits.softmax(dim=1)
        task_predictions.append(y_pred)

        # Accumulate losses and gradients
        task_losses.append(loss)
        gradients = torch.autograd.grad(loss,
                                        fast_weights.values(),
                                        create_graph=create_graph)
        named_grads = {
            name: g
            for ((name, _), g) in zip(fast_weights.items(), gradients)
        }
        task_gradients.append(named_grads)

    # Append all thetas
    if stnmodel:
        theta = torch.cat(theta, 0)

    if order == 1:
        if train:
            sum_task_gradients = {
                k:
                torch.stack([grad[k] for grad in task_gradients]).mean(dim=0)
                for k in task_gradients[0].keys()
            }
            hooks = []
            for name, param in model.named_parameters():
                hooks.append(
                    param.register_hook(replace_grad(sum_task_gradients,
                                                     name)))

            model.train()
            # Dummy pass in order to create `loss` variable
            # Replace dummy gradients with mean task gradients using hooks
            logits = model(
                torch.zeros((k_way, ) + data_shape).to(device,
                                                       dtype=torch.double))
            loss = loss_fn(logits,
                           create_nshot_task_label(k_way, 1).to(device))

            # Update STN here if present
            if stnmodel:
                stnoptim.zero_grad()
                stnloss = -loss + args.stn_reg_coeff * stnidentityloss(theta)
                stnloss.backward(retain_graph=True)
                stnoptim.step()

            # Update parameters
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

            for h in hooks:
                h.remove()

        return torch.stack(task_losses).mean(), torch.cat(
            task_predictions), x.detach()

    elif order == 2:
        model.train()
        optimiser.zero_grad()
        meta_batch_loss = torch.stack(task_losses).mean()

        if train:
            meta_batch_loss.backward()
            optimiser.step()

        return meta_batch_loss, torch.cat(task_predictions)
    else:
        raise ValueError('Order must be either 1 or 2.')
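The STN branch above trains the spatial transformer adversarially: it maximises the meta-loss while an identity regulariser keeps the transformations close to the identity. A minimal sketch of that objective; stn_reg_coeff and identity_penalty below are illustrative, and stnidentityloss in the source is assumed to penalise deviation of the affine parameters theta from the identity transform:

import torch

def identity_penalty(theta: torch.Tensor) -> torch.Tensor:
    # Assumed form of the identity regulariser: distance of each predicted 2x3
    # affine matrix from the identity transform
    identity = torch.tensor([[1., 0., 0.], [0., 1., 0.]], device=theta.device)
    return ((theta - identity) ** 2).mean()

theta = torch.randn(8, 2, 3, requires_grad=True)  # e.g. predicted affine params for 8 samples
meta_loss = torch.tensor(1.7)                     # placeholder for the query-set loss
stn_reg_coeff = 0.1

# Adversarial objective: increase the learner's loss while staying close to the identity transform
stn_loss = -meta_loss + stn_reg_coeff * identity_penalty(theta)
stn_loss.backward()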
Example 8
    def _worker(i, models, x, loss_fn, device):

        for model in models:
            o = torch.optim.SGD(model.parameters(), lr=1)
            o.zero_grad()

        losses_mgpu = []
        preds_mgpu = []
        named_grads = {}
        task_predictions_gathered = []
        try:
            with torch.cuda.device(device):
                for meta_batch in x:
                    task_predictions = []
                    x_task_train = meta_batch[:n_shot * k_way]
                    x_task_val = meta_batch[n_shot * k_way:]

                    preds = []
                    fast_weights_dict = {}
                    for idx, model in enumerate(models):
                        fast_weights = OrderedDict(model.named_parameters())

                        for inner_batch in range(inner_train_steps):
                            y = create_nshot_task_label(k_way,
                                                        n_shot).to(device)
                            logits = model.functional_forward(
                                x_task_train, fast_weights)
                            loss = F.cross_entropy(logits, y)
                            gradients = torch.autograd.grad(
                                loss,
                                fast_weights.values(),
                                create_graph=create_graph)

                            fast_weights = OrderedDict(
                                (name, param - inner_lr * grad)
                                for ((name, param), grad
                                     ) in zip(fast_weights.items(), gradients))

                        y = create_nshot_task_label(k_way,
                                                    q_queries).to(device)
                        model_logits = model.functional_forward(
                            x_task_val, fast_weights)
                        preds.append(model_logits)
                        fast_weights_dict[idx] = fast_weights

                    task_predictions.append(preds)

                    with lock:
                        predictions[i] = task_predictions

                    barrier.wait()

                    with lock:
                        task_predictions = gather_predictions(predictions, i)

                    # TODO: does not work
                    y_pred = pred_fn(task_predictions[0])
                    task_predictions_gathered.append(task_predictions[0])
                    loss = loss_fn(y_pred, y)
                    loss.backward(retain_graph=True)

                    preds_mgpu.append(y_pred)
                    losses_mgpu.append(loss)

                    for idx, model in enumerate(models):
                        gradients = torch.autograd.grad(
                            loss,
                            fast_weights_dict[idx].values(),
                            create_graph=create_graph,
                            retain_graph=True)
                        for p, grad in zip(model.parameters(), gradients):
                            p.grad += grad

                total_models = len(task_predictions[0])

                models_task_losses = []  # [n_models, n_tasks]
                models_task_preds = []  # [n_models, n_tasks, n_classes]
                if i == 0:
                    with torch.no_grad():
                        for model_idx in range(total_models):
                            task_loss = []
                            task_pred = []
                            for task in task_predictions_gathered:
                                loss = loss_fn(
                                    F.log_softmax(task[model_idx], dim=-1),
                                    y).item()
                                task_loss.append(loss)
                                task_pred.append(task[model_idx])
                            models_task_losses.append(task_loss)
                            models_task_preds.append(task_pred)

                if order == 2:
                    raise ValueError('Order must be 1')
                elif order == 1:
                    meta_batch_loss = torch.stack(losses_mgpu).mean()
                    with lock:
                        meta_batch_losses[i] = meta_batch_loss
                        task_predictions_mgpu[i] = torch.cat(preds_mgpu)
                        models_losses[i] = models_task_losses
                        models_predictions[i] = models_task_preds

                else:
                    raise ValueError('Order must be either 1 or 2.')

        except Exception:
            with lock:
                meta_batch_losses[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))
                task_predictions_mgpu[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))
Example 9
    def _worker(i, models, x, device):

        losses_mgpu = []
        preds_mgpu = []
        task_predictions = []

        try:
            with torch.cuda.device(device):
                for meta_batch in x:

                    x_task_train = meta_batch[:n_shot * k_way]
                    x_task_val = meta_batch[n_shot * k_way:]

                    preds = []
                    for model in models:
                        fast_weights = OrderedDict(model.named_parameters())

                        for inner_batch in range(inner_train_steps):
                            y = create_nshot_task_label(k_way,
                                                        n_shot).to(device)
                            logits = model.functional_forward(
                                x_task_train, fast_weights)
                            loss = F.cross_entropy(logits, y)
                            gradients = torch.autograd.grad(
                                loss,
                                fast_weights.values(),
                                create_graph=create_graph)

                            fast_weights = OrderedDict(
                                (name, param - inner_lr * grad)
                                for ((name, param), grad
                                     ) in zip(fast_weights.items(), gradients))

                        y = create_nshot_task_label(k_way,
                                                    q_queries).to(device)
                        model_logits = model.functional_forward(
                            x_task_val, fast_weights)
                        preds.append(model_logits)

                    task_predictions.append(preds)

                with lock:
                    predictions[i] = task_predictions

                barrier.wait()

                with lock:
                    task_predictions = gather_predictions(predictions, i)

                n_models = len(models)
                for task_pred in task_predictions:
                    loss = loss_fn(torch.stack(task_pred,
                                               dim=0).permute(0, 2, 1),
                                   y.unsqueeze(0).repeat(n_models, 1),
                                   reduction='mean').mean()
                    preds_mgpu.append(pred_fn(task_pred, mode=pred_mode))
                    losses_mgpu.append(loss)

                models_task_losses = []  # [n_models, n_tasks]
                models_task_preds = []  # [n_models, n_tasks, n_classes]
                if i == 0:
                    with torch.no_grad():
                        for model_idx in range(n_models):
                            task_loss = []
                            task_pred = []
                            for task in task_predictions:
                                loss = loss_fn(
                                    F.log_softmax(task[model_idx], dim=-1),
                                    y).item()
                                task_loss.append(loss)
                                task_pred.append(task[model_idx])
                            models_task_losses.append(task_loss)
                            models_task_preds.append(task_pred)

                if True:
                    for model in models:
                        model.train()

                    meta_batch_loss = torch.stack(losses_mgpu).mean()
                    if train:
                        meta_batch_loss.backward()

                    with lock:
                        meta_batch_losses[i] = meta_batch_loss
                        task_predictions_mgpu[i] = torch.cat(preds_mgpu)
                        models_losses[i] = models_task_losses
                        models_predictions[i] = models_task_preds
                elif order == 1:
                    pass
                else:
                    raise ValueError('Order must be either 1 or 2.')

        except Exception:
            with lock:
                meta_batch_losses[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))
                task_predictions_mgpu[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))
Example 10
  def test_label(self):
    n = 1
    k = 5
    q = 1

    y = create_nshot_task_label(k, q)
 def prepare_batch_(batch):
     x, y = batch
     x = x.double().cuda()
     # Create dummy 0-(num_classes - 1) label
     y = create_nshot_task_label(k, n).cuda()
     return x, y
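The test above stops after building the label; a sketch of the kind of assertions it could make, assuming create_nshot_task_label is importable from the library and returns q consecutive copies of each of the k class indices (an assumption inferred from how the label is used in the examples above):

import torch
from few_shot.core import create_nshot_task_label  # assumed import path

n, k, q = 1, 5, 1
y = create_nshot_task_label(k, q)

assert y.shape == (q * k,)            # one label per query sample
assert y.tolist() == [0, 1, 2, 3, 4]  # with q == 1, each class index appears exactly once
assert y.dtype == torch.long          # labels are integer class indices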