def set_forward(self, support_images, support_labels, query_images):
    """
    Overwrites method set_forward in AbstractMetaLearner.

    Fits a fresh linear head (and fine-tunes the feature extractor) on the
    support set, scores the query set, then restores the feature extractor
    to the weights it had before this call.

    Args:
        support_images (torch.Tensor): support set images
        support_labels (torch.Tensor): support set labels
        query_images (torch.Tensor): query set images

    Returns:
        torch.Tensor: output of self.linear_model on the query features
            (presumably shape (n_query, N_WAY_EVAL) — confirm against callers).
    """
    support_images = set_device(support_images)
    query_images = set_device(query_images)
    support_labels = set_device(support_labels)

    # Save parameters: a CPU snapshot of the feature extractor, so that the
    # task-specific fine-tuning below does not leak into later tasks.
    feature_parameters = copy.deepcopy(self.feature).cpu().state_dict()

    # Init linear model: maps the CLASSES["train"]-dim output of the
    # feature extractor to N_WAY_EVAL scores.
    self.linear_model = set_device(nn.Linear(CLASSES["train"], N_WAY_EVAL))

    try:
        self.support_based_initializer(support_images, support_labels)
        # Compute the linear model
        self.fine_tune(support_images, support_labels, query_images)
        # Compute score of query
        scores = self.linear_model(self.feature(query_images))
    finally:
        # Refresh parameters even if fine-tuning raised, otherwise the
        # feature extractor would stay in its fine-tuned state.
        self.feature.load_state_dict(feature_parameters)

    return scores
def forward(self, x):
    """
    Batch-normalise ``x`` with the statistics of the current batch.

    When both ``self.weight.fast`` and ``self.bias.fast`` are set, they are
    used as the affine parameters; otherwise ``self.weight`` / ``self.bias``
    are used.

    Args:
        x (torch.Tensor): input whose dimension 1 is the channel dimension.

    Returns:
        torch.Tensor: the normalised tensor returned by F.batch_norm.
    """
    n_channels = x.data.size()[1]
    # Placeholder running statistics: they are allocated fresh on every call
    # and discarded afterwards, so the update performed with training=True
    # has no lasting effect.
    running_mean = set_device(torch.zeros(n_channels))
    running_var = set_device(torch.ones(n_channels))

    has_fast_weights = self.weight.fast is not None and self.bias.fast is not None
    if has_fast_weights:
        gamma, beta = self.weight.fast, self.bias.fast
    else:
        gamma, beta = self.weight, self.bias

    # batch_norm momentum hack: follow hack of Kate Rakelly in pytorch-maml/src/layers.py
    return F.batch_norm(
        x,
        running_mean,
        running_var,
        gamma,
        beta,
        training=True,
        momentum=1,
    )
def propagate(self, laplacian, support_labels):
    """
    Compute label propagation. See eq (4) of LEARNING TO PROPAGATE LABELS:
    TRANSDUCTIVE PROPAGATION NETWORK FOR FEW-SHOT LEARNING.

    Support samples are assumed to occupy the first ``n_support`` rows/columns
    of ``laplacian``, followed by the query samples.

    Args:
        laplacian (torch.Tensor): shape (n_support + n_query, n_support + n_query)
        support_labels (torch.Tensor): artificial support set labels in range
            [0, n_way); each label is used directly as a column index.

    Returns:
        torch.Tensor: shape (n_support + n_query, n_way), propagated label
            scores for every sample (support and query).
    """
    n_way = len(torch.unique(support_labels))
    n_support_query = laplacian.size(0)
    n_support = support_labels.size(0)

    # Label matrix Y: one-hot rows for the support samples; query rows stay
    # all-zero since their labels are what we want to predict.
    one_hot_labels = set_device(torch.zeros(n_support_query, n_way))
    one_hot_labels[torch.arange(n_support), support_labels] = 1.0

    # F* = (I - alpha * S)^{-1} Y. self.eps is added to every entry of the
    # matrix before inversion (presumably for numerical stability).
    propagation = (
        set_device(torch.eye(n_support_query))
        - self.alpha * laplacian
        + self.eps
    ).inverse()

    return torch.matmul(propagation, one_hot_labels)
def forward(self, x, y):
    """
    Compute the entropy-regularised optimal transport (Sinkhorn) between x and y.

    Both marginals are uniform. The log-domain Sinkhorn updates run for at most
    ``self.max_iter`` iterations, stopping early once the mean L1 change of
    ``u`` falls below ``self.thresh``.

    Args:
        x (torch.Tensor): first point cloud, points along dim -2
            (2-D input is treated as a batch of size 1).
        y (torch.Tensor): second point cloud, points along dim -2.

    Returns:
        tuple(torch.Tensor, torch.Tensor, torch.Tensor): the transport cost
            (reduced according to ``self.reduction``), the transport plan pi
            and the (normalised) cost matrix C.
    """
    # Wasserstein ground-cost matrix between the two point clouds.
    C = self._cost_matrix(x, y)
    # Normalise so the cost scale is consistent with the regularisation
    # self.eps. An all-zero cost matrix (e.g. identical inputs) would make
    # this a 0/0 division and fill C with NaNs, so skip it in that case.
    cost_normalization = C.max()
    if cost_normalization != 0:
        C = C / cost_normalization

    x_points = x.shape[-2]
    y_points = y.shape[-2]
    batch_size = 1 if x.dim() == 2 else x.shape[0]

    # both marginals are fixed with equal weights
    mu = set_device(
        torch.empty(batch_size, x_points, dtype=torch.float, requires_grad=False)
        .fill_(1.0 / x_points)
        .squeeze()
    )
    nu = set_device(
        torch.empty(batch_size, y_points, dtype=torch.float, requires_grad=False)
        .fill_(1.0 / y_points)
        .squeeze()
    )

    u = torch.zeros_like(mu)
    v = torch.zeros_like(nu)

    # Sinkhorn iterations (log domain); self.M presumably builds the
    # modified cost from C, u and v — see its definition.
    for _ in range(self.max_iter):
        u_previous = u
        u = (
            self.eps
            * (torch.log(mu + 1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1))
            + u
        )
        v = (
            self.eps
            * (
                torch.log(nu + 1e-8)
                - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)
            )
            + v
        )
        err = (u - u_previous).abs().sum(-1).mean()
        if err.item() < self.thresh:
            break

    # Transport plan pi = diag(a)*K*diag(b)
    pi = torch.exp(self.M(C, u, v))
    # Sinkhorn distance
    cost = torch.sum(pi * C, dim=(-2, -1))

    if self.reduction == "mean":
        cost = cost.mean()
    elif self.reduction == "sum":
        cost = cost.sum()

    return cost, pi, C
def get_model(n_classes: int) -> nn.Module:
    """
    Build the backbone from model_config with a linear classification head.

    A ``fc`` linear layer (final_feat_dim -> n_classes) is appended to the
    backbone's trunk; a cross-entropy loss and an optimizer built from
    erm_training_config.OPTIMIZER are attached to the returned model.

    Args:
        n_classes: number of output classes of the linear head.

    Returns:
        the model, moved with set_device.
    """
    backbone_class = model_config.BACKBONE
    logger.info(f"Initializing {backbone_class.__name__}...")
    model = set_device(backbone_class())

    classification_head = set_device(nn.Linear(model.final_feat_dim, n_classes))
    model.trunk.add_module("fc", classification_head)

    model.loss_fn = nn.CrossEntropyLoss()
    model.optimizer = erm_training_config.OPTIMIZER(model.parameters())
    return model
def forward(self, x):
    """
    Batch-normalise ``x`` with the statistics of the current batch.

    Uses ``self.weight`` and ``self.bias`` as affine parameters and always
    runs in training mode, so batch statistics are used even at eval time.

    Args:
        x (torch.Tensor): input whose dimension 1 is the channel dimension.

    Returns:
        torch.Tensor: the normalised tensor returned by F.batch_norm.
    """
    channels = x.data.size()[1]
    # Placeholders for F.batch_norm, not used since momentum=1: they are
    # allocated per call and thrown away afterwards.
    placeholder_mean = set_device(torch.zeros(channels))
    placeholder_var = set_device(torch.ones(channels))
    normalised = F.batch_norm(
        x,
        placeholder_mean,
        placeholder_var,
        self.weight,
        self.bias,
        training=True,
        momentum=1,
    )
    return normalised
def __init__(
    self,
    model_func,
    transportation=None,
    training_stats=None,
    lr=5.0 * 1e-5,
    epochs=25,
):
    """
    Args:
        model_func: backbone constructor, forwarded to the parent class.
        transportation: optional transportation module, forwarded to the parent.
        training_stats: forwarded to the parent class.
        lr (float): learning rate stored on the instance.
        epochs (int): number of epochs stored on the instance.
    """
    super(TransFineTune, self).__init__(
        model_func, transportation=transportation, training_stats=training_stats
    )

    # Hyper-parameters used in the paper.
    self.lr, self.epochs = lr, epochs

    trunk = self.feature.trunk
    # Use the output of fc (CLASSES["train"] units), not the raw backbone
    # output, followed by a ReLU non-linearity.
    fc_layer = set_device(nn.Linear(self.feature.final_feat_dim, CLASSES["train"]))
    trunk.add_module("fc", fc_layer)
    trunk.add_module("relu", nn.ReLU())

    # Built per task in set_forward.
    self.linear_model = None
def validation(model: nn.Module, data_loader: DataLoader, n_batches: int) -> float:
    """
    Compute the mean per-batch top-1 accuracy of ``model`` on ``data_loader``.

    Args:
        model: classifier returning one score per class along dim 1.
        data_loader: yields (images, labels, _) triplets.
        n_batches: maximum number of batches to evaluate.

    Returns:
        float: mean over batches of the fraction of correctly classified samples.
    """
    val_acc_list = []
    model.eval()
    # Pure inference: disable autograd so no graph / saved activations are
    # kept for a backward pass that never happens.
    with torch.no_grad(), tqdm(
        zip(range(n_batches), data_loader),
        total=n_batches,
        desc="Validation:",
    ) as tqdm_val:
        for _, (images, labels, _) in tqdm_val:
            outputs = model(set_device(images))
            # Index of the highest score for each sample.
            predictions = outputs.topk(1, 1, True, True)[1][:, 0]
            n_correct = (predictions == set_device(labels)).sum()
            val_acc_list.append(float(n_correct) / len(labels))
            tqdm_val.set_postfix(accuracy=np.asarray(val_acc_list).mean())

    return np.asarray(val_acc_list).mean()
def set_forward(self, support_images, support_labels, query_images):
    """
    Overwrites method set_forward in AbstractMetaLearner.

    Extracts support/query features in chunks (to bound memory), optionally
    applies the transportation module, fits a RidgeClassifier(alpha=0.1) on
    the support features and returns its decision function on the query set.

    Args:
        support_images (torch.Tensor): support set images
        support_labels (torch.Tensor): support set labels
        query_images (torch.Tensor): query set images

    Returns:
        torch.Tensor: classifier decision scores for the query images.
    """
    import itertools

    n_chunks = len(support_images) // 32 + 1
    support_chunk = []
    query_chunk = []
    # torch.chunk may return fewer than n_chunks pieces, and not necessarily
    # the same number for support and query. A plain zip would silently drop
    # the extra chunks, so iterate to the longest and, when one side is
    # exhausted, feed the other side's chunk in both slots and keep only the
    # features that are needed.
    for support, query in itertools.zip_longest(
        support_images.chunk(n_chunks), query_images.chunk(n_chunks)
    ):
        support_features, query_features = self.extract_features(
            set_device(support if support is not None else query),
            set_device(query if query is not None else support),
        )
        if support is not None:
            support_chunk.append(support_features.detach().cpu())
        if query is not None:
            query_chunk.append(query_features.detach().cpu())

    z_support = torch.cat(support_chunk, dim=0)
    z_query = torch.cat(query_chunk, dim=0)

    # If a transportation method in the feature space has been defined, use it
    if self.transportation_module:
        z_support, z_query = (
            z.cpu()
            for z in self.transportation_module(set_device(z_support), set_device(z_query))
        )

    linear_classifier = RidgeClassifier(alpha=0.1)
    linear_classifier.fit(z_support.numpy(), support_labels.cpu().numpy())

    scores = torch.tensor(linear_classifier.decision_function(z_query.numpy()))
    return set_device(scores)
def training_epoch(model: nn.Module, data_loader: DataLoader, epoch: int,
                   n_batches: int) -> (nn.Module, float):
    """
    Run one training epoch of at most ``n_batches`` batches.

    Args:
        model: model to train; passed to and returned from ``fit`` each batch.
        data_loader: yields (images, labels, _) triplets.
        epoch: epoch index, shown in the progress bar.
        n_batches: maximum number of batches to consume.

    Returns:
        tuple: the (possibly updated) model and the mean of the batch losses.
    """
    batch_losses = []
    model.train()
    limited_batches = zip(range(n_batches), data_loader)
    with tqdm(limited_batches, total=n_batches, desc=f"Epoch {epoch}") as progress:
        for _, (images, labels, _) in progress:
            model, loss_value = fit(model, set_device(images), set_device(labels))
            batch_losses.append(loss_value)
            progress.set_postfix(loss=np.asarray(batch_losses).mean())

    return model, np.asarray(batch_losses).mean()
def extract_features(self, support_images, query_images):
    """
    Computes the features vectors of the support and query sets.

    Support and query images go through ``self.feature`` in two separate
    forward calls.

    Args:
        support_images (torch.Tensor): shape (n_support_images, **image_dim) input data
        query_images (torch.Tensor): shape (n_query_images, **image_dim) input data

    Returns:
        Tuple(torch.Tensor, torch.Tensor): features vectors of the support and query sets,
            respectively of shapes (n_support_images, features_dim) and (n_query_images, features_dim)
    """
    # Move inputs to the configured device.
    support_images = set_device(support_images)
    query_images = set_device(query_images)

    # Call the module itself rather than .forward(), so that any registered
    # forward hooks run (and for consistency with other self.feature(...) calls).
    z_support = self.feature(support_images)
    z_query = self.feature(query_images)

    return z_support, z_query
def eval_loop(self, test_loader):
    """
    Evaluate the model on every episode of ``test_loader``.

    Args:
        test_loader (DataLoader): loader of a given number of episodes; each
            episode is (support_images, support_labels, query_images,
            query_labels, class_ids, source_domain, target_domain).

    Returns:
        tuple(float, float, pd.DataFrame): resp. average loss and classification accuracy
            (in percent), and advanced evaluation statistics (concatenation of the
            per-episode frames returned by get_task_perf)
    """
    accuracies = []
    losses = []
    per_task_stats = []
    n_tasks = len(test_loader)

    for episode_index, episode in enumerate(test_loader):
        (
            support_images,
            support_labels,
            query_images,
            query_labels,
            class_ids,
            source_domain,
            target_domain,
        ) = episode
        query_labels = set_device(query_labels)

        scores = self.set_forward(support_images, support_labels, query_images).detach()
        episode_loss = self.loss_fn(scores, query_labels).detach().item()

        per_task_stats.append(
            self.get_task_perf(
                episode_index,
                scores.cpu(),
                query_labels.cpu().detach(),
                class_ids,
                source_domain,
                target_domain,
            )
        )
        losses.append(episode_loss)
        accuracies.append(self.evaluate(scores, query_labels) * 100)

    evaluations_stats_df = pd.concat(per_task_stats, ignore_index=True)

    acc_all = np.asarray(accuracies)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    logger.info(
        "%d Test Accuracy = %4.2f%% +- %4.2f%%"
        % (n_tasks, acc_mean, confidence_interval(acc_std, n_tasks))
    )

    return np.asarray(losses).mean(), acc_mean, evaluations_stats_df
def load_model_non_episodic(model: nn.Module, state_dict: OrderedDict,
                            use_fc: bool) -> nn.Module:
    """
    Load a non-episodically trained state dict into ``model.feature``.

    Args:
        model: model whose ``feature`` sub-module receives the weights.
        state_dict: weights to load.
        use_fc: if True, a fresh ``fc`` layer sized for
            CLASSES["train"] + CLASSES["val"] outputs is installed on the
            trunk and the whole state dict is loaded; otherwise every entry
            whose key contains ".fc." is dropped before loading.

    Returns:
        the same model, with weights loaded.
    """
    if use_fc:
        n_outputs = dataset_config.CLASSES["train"] + dataset_config.CLASSES["val"]
        model.feature.trunk.fc = set_device(
            nn.Linear(model.feature.final_feat_dim, n_outputs)
        )
        weights_to_load = state_dict
    else:
        weights_to_load = OrderedDict(
            (key, value) for key, value in state_dict.items() if ".fc." not in key
        )

    model.feature.load_state_dict(weights_to_load)
    return model
def one_hot(labels):
    """
    Args:
        labels (torch.Tensor): 1-dimensional tensor of non-negative integers
            (must be an int64 tensor, as required by ``scatter_``)

    Returns:
        torch.Tensor: 2-dimensional float tensor of shape [len(labels), max(labels) + 1]
            corresponding to the one-hot form of the input tensor
    """
    num_class = int(torch.max(labels)) + 1
    # Allocate on the labels' device so scatter_ never mixes a CPU target
    # with indices living on another device; set_device then moves the result.
    encoded = torch.zeros((len(labels), num_class), device=labels.device)
    encoded.scatter_(1, labels.unsqueeze(1), 1)
    return set_device(encoded)
def load_model(state_path: Path, episodic: bool, use_fc: bool,
               force_ot: bool) -> nn.Module:
    """
    Instantiate model_config.MODEL and load weights from ``state_path``.

    Args:
        state_path: path to a checkpoint written with torch.save. It is
            unpickled, so only load files from a trusted source.
        episodic: if True, load with load_model_episodic, otherwise with
            load_model_non_episodic.
        use_fc: forwarded to load_model_non_episodic (ignored when episodic).
        force_ot: if True, set model_config.TRANSPORTATION_MODULE as the
            model's transportation module before loading.

    Returns:
        the loaded model.
    """
    model = set_device(model_config.MODEL(model_config.BACKBONE))

    if force_ot:
        model.transportation_module = model_config.TRANSPORTATION_MODULE
        logger.info("Forced the Optimal Transport module into the model.")

    # Map to CPU so checkpoints saved from a GPU run also load on CPU-only
    # hosts; load_state_dict copies values into the model's own tensors.
    # NOTE(review): assumes load_model_episodic also loads via load_state_dict — confirm.
    state_dict = torch.load(state_path, map_location="cpu")
    if episodic:
        model = load_model_episodic(model, state_dict)
    else:
        model = load_model_non_episodic(model, state_dict, use_fc)
    logger.info(f"Loaded model from {state_path}")

    return model
def train_loop(self, epoch, train_loader, optimizer):
    """
    Executes one training epoch.

    Each episode is fitted with ``self.fit_on_task``; the running mean loss
    is logged every 100 episodes.

    Args:
        epoch (int): current epoch
        train_loader (DataLoader): loader of a given number of episodes; each
            episode is a 7-tuple whose first four items are support images,
            support labels, query images and query labels
        optimizer (torch.optim.Optimizer): model optimizer

    Returns:
        tuple(float, float): resp. average loss and classification accuracy (in percent)
    """
    print_freq = 100
    episode_losses = []
    episode_accuracies = []

    for episode_index, episode in enumerate(train_loader):
        support_images, support_labels, query_images, query_labels, _, _, _ = episode
        query_labels = set_device(query_labels)

        scores, loss_value = self.fit_on_task(
            support_images, support_labels, query_images, query_labels, optimizer
        )
        episode_losses.append(loss_value)
        episode_accuracies.append(self.evaluate(scores, query_labels) * 100)

        # Log after every print_freq-th episode (1-based count).
        if (episode_index + 1) % print_freq == 0:
            logger.info(
                "Epoch {epoch} | Batch {episode_index}/{n_batches} | Loss {loss}"
                .format(
                    epoch=epoch,
                    episode_index=episode_index + 1,
                    n_batches=len(train_loader),
                    loss=np.asarray(episode_losses).mean(),
                )
            )

    return np.asarray(episode_losses).mean(), np.asarray(episode_accuracies).mean()
def __init__(self, model_func, transportation=None, training_stats=None):
    """
    Args:
        model_func: backbone constructor, forwarded to the parent class.
        transportation: optional transportation module, forwarded to the parent.
        training_stats: forwarded to the parent class.
    """
    super(MatchingNet, self).__init__(
        model_func, transportation=transportation, training_stats=training_stats
    )
    feature_dim = self.feature.final_feat_dim

    self.loss_fn = nn.NLLLoss()
    # Sub-modules are created in this order on purpose: it fixes both the
    # random initialisation sequence and the registration order.
    self.FCE = FullyContextualEmbedding(feature_dim)
    self.support_features_encoder = set_device(
        nn.LSTM(feature_dim, feature_dim, 1, batch_first=True, bidirectional=True)
    )
    self.relu = nn.ReLU()
    # NOTE(review): nn.Softmax() without `dim` relies on PyTorch's deprecated
    # implicit-dim selection; confirm the intended dim from the call sites.
    self.softmax = nn.Softmax()
def train_model():
    """
    Train model_config.MODEL episodically and return it with its best weights.

    Each epoch runs model.train_loop on training episodes, then model.eval_loop
    on validation episodes; losses and accuracies are written to TensorBoard
    in experiment_config.SAVE_DIR. Whenever validation accuracy improves, the
    weights are snapshotted in memory and saved to SAVE_DIR / "best_model.tar".
    If training_config.TEST_SET_VALIDATION_FREQUENCY is truthy, test accuracy
    is also logged every that many epochs.

    Returns:
        the model loaded with the weights of its best validation epoch, or with
        its final weights if no epoch ever improved on the initial -1.0 accuracy.
    """
    import copy

    logger.info("Initializing data loaders...")
    train_loader, _ = get_episodic_loader(
        "train",
        n_way=training_config.N_WAY,
        n_source=training_config.N_SOURCE,
        n_target=training_config.N_TARGET,
        n_episodes=training_config.N_EPISODES,
    )
    val_loader, _ = get_episodic_loader(
        "val",
        n_way=training_config.N_WAY,
        n_source=training_config.N_SOURCE,
        n_target=training_config.N_TARGET,
        n_episodes=training_config.N_VAL_TASKS,
    )
    if training_config.TEST_SET_VALIDATION_FREQUENCY:
        test_loader, _ = get_episodic_loader(
            "test",
            n_way=training_config.N_WAY,
            n_source=training_config.N_SOURCE,
            n_target=training_config.N_TARGET,
            n_episodes=training_config.N_VAL_TASKS,
        )

    logger.info("Initializing model...")
    model = set_device(model_config.MODEL(model_config.BACKBONE))
    optimizer = training_config.OPTIMIZER(model.parameters())

    max_acc = -1.0
    best_model_epoch = -1
    best_model_state = None

    writer = SummaryWriter(log_dir=experiment_config.SAVE_DIR)
    logger.info("Model and data are ready. Starting training...")
    try:
        for epoch in range(training_config.N_EPOCHS):
            # Set model to training mode
            model.train()
            # Execute a training loop of the model
            train_loss, train_acc = model.train_loop(epoch, train_loader, optimizer)
            writer.add_scalar("Train/loss", train_loss, epoch)
            writer.add_scalar("Train/acc", train_acc, epoch)

            # Set model to evaluation mode
            model.eval()
            # Evaluate on validation set
            val_loss, val_acc, _ = model.eval_loop(val_loader)
            writer.add_scalar("Val/loss", val_loss, epoch)
            writer.add_scalar("Val/acc", val_acc, epoch)

            # We make sure the best model is saved on disk, in case the training breaks
            if val_acc > max_acc:
                max_acc = val_acc
                best_model_epoch = epoch
                # state_dict() holds references to the live tensors; without a
                # copy, later epochs would overwrite the "best" weights in memory.
                best_model_state = copy.deepcopy(model.state_dict())
                torch.save(best_model_state, experiment_config.SAVE_DIR / "best_model.tar")

            frequency = training_config.TEST_SET_VALIDATION_FREQUENCY
            if frequency and epoch % frequency == frequency - 1:
                logger.info("Validating on test set...")
                _, test_acc, _ = model.eval_loop(test_loader)
                writer.add_scalar("Test/acc", test_acc, epoch)

        logger.info(f"Training over after {training_config.N_EPOCHS} epochs")

        if best_model_state is None:
            # No epoch ran, or validation accuracy never beat the initial -1.0.
            logger.warning("No best model was recorded; returning the model as is.")
        else:
            logger.info("Retrieving model with best validation accuracy...")
            model.load_state_dict(best_model_state)
            logger.info(f"Retrieved model from epoch {best_model_epoch}")
    finally:
        writer.close()

    return model
def __init__(self, num_features):
    """
    Args:
        num_features (int): feature dimension; the LSTM cell takes inputs of
            size 2 * num_features and has a hidden/cell size of num_features.
    """
    super(FullyContextualEmbedding, self).__init__()
    self.lstm_cell = set_device(nn.LSTMCell(num_features * 2, num_features))
    # NOTE(review): nn.Softmax() without `dim` relies on PyTorch's deprecated
    # implicit-dim selection; confirm the intended dim from the call sites.
    self.softmax = nn.Softmax()
    # Zero initial state of shape (1, num_features). A plain tensor replaces the
    # deprecated autograd.Variable wrapper; both have requires_grad=False.
    self.c_0 = set_device(torch.zeros(1, num_features))