    def fit(self, support_set: torch.Tensor):
        # Only the number of classes in the support set is remembered.
        self.n_classes = support_set.size(0)

    def transform(self, x: torch.Tensor):
        # Accept a single query as well as a batch of queries.
        if len(x.size()) == 3:
            x = torch.unsqueeze(x, 0)
        self.n_query = x.size(0)
        # Random per-class probabilities, normalized with softmax.
        prob = F.softmax(torch.rand(self.n_query, self.n_classes), dim=1)
        if prob.size(0) == 1:
            prob = torch.squeeze(prob, 0)
        return prob


if __name__ == '__main__':
    best_model = RandomClassifier()

    session_info = {
        "task": "few-shot learning",
        "model": "RANDOM",
    }

    session = Session()
    session.build(name="RANDOM", comment=r"RANDOM classifier", **session_info)
    torch.save(best_model,
               os.path.join(session.data['output_dir'], "trained_model_state_dict.tar"))
    session.save_info()
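
# Illustrative only: a minimal sketch of how the fit/transform interface above could
# be exercised on dummy tensors. The tensor shapes used here (a 5-way support set and
# 15 query images of size 3x84x84) are assumptions for demonstration, not values taken
# from the original pipeline. The function is not called anywhere.
def _demo_random_classifier():
    clf = RandomClassifier()
    support = torch.rand(5, 1, 3, 84, 84)   # assumed layout: one entry per class
    queries = torch.rand(15, 3, 84, 84)     # assumed batch of query images
    clf.fit(support)                        # records n_classes = support.size(0)
    probs = clf.transform(queries)          # uniform-random class probabilities, shape (15, 5)
    print(probs.size())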
def train_protonetmhlnbs(base_subdataset: LabeledSubdataset, val_subdataset: LabeledSubdataset, n_shot: int,
                         n_way: int, n_iterations: int, batch_size: int, eval_period: int, val_batch_size: int,
                         image_size: int, balanced_batches: bool, train_n_way=15, backbone_name='resnet12-np-o',
                         device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), **kwargs):
    session_info = {
        "task": "few-shot learning",
        "model": "ProtoNetMHLNBS",
        "feature_extractor": backbone_name,
        "n_iterations": n_iterations,
        "eval_period": eval_period,
        "batch_size": batch_size,
        "val_batch_size": val_batch_size,
        "n_shot": n_shot,
        "n_way": n_way,
        "train_n_way": train_n_way,
        "optimizer": 'adam',
        "image_size": image_size,
        "balanced_batches": balanced_batches,
    }
    session_info.update(kwargs)

    # Build the feature extractor and the ProtoNet model with Mahalanobis distance.
    backbone = FEATURE_EXTRACTORS[backbone_name]()
    model = ProtoNet_MHLNBS(backbone=backbone).to(device)
    optimizer = OPTIMIZERS['adam'](model=model)

    # Episode samplers for training and validation.
    base_sampler = FSLEpisodeSampler(subdataset=base_subdataset, n_way=train_n_way, n_shot=n_shot,
                                     batch_size=batch_size, balanced=balanced_batches)
    val_sampler = FSLEpisodeSampler(subdataset=val_subdataset, n_way=n_way, n_shot=n_shot,
                                    batch_size=val_batch_size, balanced=balanced_batches)

    # Live plots for losses and accuracies.
    loss_plotter = PlotterWindow(interval=1000)
    accuracy_plotter = PlotterWindow(interval=1000)

    loss_plotter.new_line('Loss')
    loss_plotter.new_line('Loss Instance')
    accuracy_plotter.new_line('Train Accuracy')
    accuracy_plotter.new_line('Validation Accuracy')

    losses = []
    losses_i = []
    acc_train = []
    acc_val = []
    val_iters = []

    best_accuracy = 0
    best_iteration = -1

    print("Training started for parameters:")
    print(session_info)
    print()

    start_time = time.time()

    for iteration in range(n_iterations):
        model.train()

        # Sample a training episode and move the query batch to the device.
        support_set, batch = base_sampler.sample()
        query_set, query_labels = batch
        query_set = query_set.to(device)
        query_labels = query_labels.to(device)

        optimizer.zero_grad()
        output, loss, loss_i = model.forward_with_loss(support_set, query_set, query_labels)
        loss.backward()
        optimizer.step()

        labels_pred = output.argmax(dim=1)
        labels = query_labels
        cur_accuracy = accuracy(labels=labels, labels_pred=labels_pred)

        loss_plotter.add_point('Loss', iteration, loss.item())
        loss_plotter.add_point('Loss Instance', iteration, loss_i.item())
        accuracy_plotter.add_point('Train Accuracy', iteration, cur_accuracy)

        losses.append(loss.item())
        losses_i.append(loss_i.item())
        acc_train.append(cur_accuracy)

        # Periodic evaluation on validation episodes.
        if iteration % eval_period == 0 or iteration == n_iterations - 1:
            val_start_time = time.time()
            val_accuracy = evaluate_solution_episodes(model, val_sampler)
            accuracy_plotter.add_point('Validation Accuracy', iteration, val_accuracy)
            acc_val.append(val_accuracy)
            val_iters.append(iteration + 1)

            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                best_iteration = iteration
                print("Best evaluation result yet!")

            cur_time = time.time()
            val_time = cur_time - val_start_time
            time_used = cur_time - start_time
            time_per_iteration = time_used / (iteration + 1)

            print()
            print("[%d/%d] = %.2f%%\t\tLoss: %.4f" % (
                iteration + 1, n_iterations, (iteration + 1) / n_iterations * 100, loss.item()))
            print("Current validation time: %s" % pretty_time(val_time))
            print('Average iteration time: %s\tEstimated execution time: %s' % (
                pretty_time(time_per_iteration),
                pretty_time(time_per_iteration * (n_iterations - iteration - 1)),
            ))
            print()

    cur_time = time.time()
    training_time = cur_time - start_time
    print("Training finished. Total execution time: %s" % pretty_time(training_time))
    print("Best accuracy is: %.3f" % best_accuracy)
    print("Best iteration is: [%d/%d]" % (best_iteration + 1, n_iterations))
    print()

    session_info['accuracy'] = best_accuracy
    session_info['best_iteration'] = best_iteration
    session_info['execution_time'] = training_time

    # Record the run and save the trained model.
    session = Session()
    session.build(name="ProtoNetMHLNBS", comment=r"ProtoNet with Mahalanobis distance Few-Shot Learning",
                  **session_info)
    torch.save(model, os.path.join(session.data['output_dir'], "trained_model_state_dict.tar"))

    # Save loss and accuracy curves.
    iters = list(range(1, n_iterations + 1))

    plt.figure(figsize=(20, 20))
    plt.plot(iters, losses, label="Loss")
    plt.plot(iters, losses_i, label="Loss Instance")
    plt.legend()
    plt.savefig(os.path.join(session.data['output_dir'], "loss_plot.png"))

    plt.figure(figsize=(20, 20))
    plt.plot(iters, acc_train, label="Train Accuracy")
    plt.plot(val_iters, acc_val, label="Test Accuracy")
    plt.legend()
    plt.savefig(os.path.join(session.data['output_dir'], "acc_plot.png"))

    session.save_info()
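
# Illustrative only: one way the trainer above might be invoked. The subdataset objects
# and all hyperparameter values below are placeholders for illustration, not settings
# from the original experiments.
#
#     base_subdataset, val_subdataset = ...  # LabeledSubdataset splits prepared elsewhere
#     train_protonetmhlnbs(base_subdataset=base_subdataset, val_subdataset=val_subdataset,
#                          n_shot=5, n_way=5, n_iterations=30000, batch_size=32,
#                          eval_period=1000, val_batch_size=16, image_size=84,
#                          balanced_batches=True)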
def train_mctdfmn(base_subdataset: LabeledSubdataset, val_subdataset: LabeledSubdataset, n_shot: int, n_way: int,
                  n_iterations: int, batch_size: int, eval_period: int, val_batch_size: int, dataset_classes: int,
                  image_size: int, balanced_batches: bool, pretrained_model: MCTDFMN = None, train_n_way=15,
                  backbone_name='resnet12-np', lr=0.1, train_ts_steps=1, test_ts_steps=10,
                  all_global_prototypes=True, no_scaling=False, pca=False, extend_input=False,
                  device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), **kwargs):
    session_info = {
        "task": "few-shot learning",
        "model": "MCT_DFMN",
        "feature_extractor": backbone_name,
        "n_iterations": n_iterations,
        "eval_period": eval_period,
        "batch_size": batch_size,
        "val_batch_size": val_batch_size,
        "n_shot": n_shot,
        "n_way": n_way,
        "train_n_way": train_n_way,
        "train_ts_steps": train_ts_steps,
        "test_ts_steps": test_ts_steps,
        "optimizer": 'sgd',
        "all_global_prototypes": all_global_prototypes,
        "image_size": image_size,
        "balanced_batches": balanced_batches,
        "pretrained_model": pretrained_model is not None,
        "no_scaling": no_scaling,
        "pca": pca,
        "extend_input": extend_input,
    }
    session_info.update(kwargs)

    # Build the model, either from scratch or from a pretrained checkpoint.
    backbone = FEATURE_EXTRACTORS[backbone_name]()
    if pretrained_model is None:
        model = MCTDFMN(backbone=backbone, test_transduction_steps=test_ts_steps,
                        train_transduction_steps=train_ts_steps, train_classes=dataset_classes,
                        all_global_prototypes=all_global_prototypes, scaling=not no_scaling, pca=pca,
                        extend_input=extend_input).to(device)
    else:
        model = copy.deepcopy(pretrained_model)

    # SGD with Nesterov momentum and a LambdaLR learning-rate schedule.
    optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, nesterov=True, weight_decay=0.0005, momentum=0.9)
    scheduler = LambdaLR(optimizer, lr_lambda=lr_schedule)

    # Training episodes keep global class labels; validation uses plain episodes.
    base_sampler = FSLEpisodeSamplerGlobalLabels(subdataset=base_subdataset, n_way=train_n_way, n_shot=n_shot,
                                                 batch_size=batch_size, balanced=balanced_batches)
    val_sampler = FSLEpisodeSampler(subdataset=val_subdataset, n_way=n_way, n_shot=n_shot,
                                    batch_size=val_batch_size, balanced=balanced_batches)

    loss_plotter = PlotterWindow(interval=1000)
    accuracy_plotter = PlotterWindow(interval=1000)

    loss_plotter.new_line('Loss')
    loss_plotter.new_line('Dense Loss')
    loss_plotter.new_line('Instance Loss')
    accuracy_plotter.new_line('Train Accuracy')
    accuracy_plotter.new_line('Validation Accuracy')

    losses = []
    losses_d = []
    losses_i = []
    acc_train = []
    acc_val = []
    val_iters = []

    best_model = copy.deepcopy(model)
    best_accuracy = 0
    best_iteration = -1

    print("Training started for parameters:")
    print(session_info)
    print()

    start_time = time.time()

    for iteration in range(n_iterations):
        model.train()

        support_set, batch, global_classes_mapping = base_sampler.sample()
        query_set, query_labels = batch
        query_set = query_set.to(device)
        query_labels = query_labels.to(device)

        optimizer.zero_grad()
        output, loss, loss_i, loss_d = model.forward_with_loss(support_set, query_set, query_labels,
                                                               global_classes_mapping)
        loss.backward()
        optimizer.step()
        scheduler.step()

        labels_pred = output.argmax(dim=1)
        labels = query_labels
        cur_accuracy = accuracy(labels=labels, labels_pred=labels_pred)

        loss_plotter.add_point('Loss', iteration, loss.item())
        loss_plotter.add_point('Dense Loss', iteration, loss_d.item())
        loss_plotter.add_point('Instance Loss', iteration, 0.2 * loss_i.item())
        accuracy_plotter.add_point('Train Accuracy', iteration, cur_accuracy)

        losses.append(loss.item())
        losses_i.append(loss_i.item())
        losses_d.append(loss_d.item())
        acc_train.append(cur_accuracy)

        # Periodic evaluation on validation episodes.
        if iteration % eval_period == 0 or iteration == n_iterations - 1:
            val_start_time = time.time()
            val_accuracy = evaluate_solution_episodes(model, val_sampler)
            accuracy_plotter.add_point('Validation Accuracy', iteration, val_accuracy)
            acc_val.append(val_accuracy)
            val_iters.append(iteration + 1)

            # Keep a snapshot of the best model seen so far.
            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                best_iteration = iteration
                best_model = copy.deepcopy(model)
                print("Best evaluation result yet!")

            cur_time = time.time()
            val_time = cur_time - val_start_time
            time_used = cur_time - start_time
            time_per_iteration = time_used / (iteration + 1)

            print()
            print("[%d/%d] = %.2f%%\t\tLoss: %.4f" % (
                iteration + 1, n_iterations, (iteration + 1) / n_iterations * 100, loss.item()))
            print("Current validation time: %s" % pretty_time(val_time))
            print('Average iteration time: %s\tEstimated execution time: %s' % (
                pretty_time(time_per_iteration),
                pretty_time(time_per_iteration * (n_iterations - iteration - 1)),
            ))
            print()

    cur_time = time.time()
    training_time = cur_time - start_time
    print("Training finished. Total execution time: %s" % pretty_time(training_time))
    print("Best accuracy is: %.3f" % best_accuracy)
    print("Best iteration is: [%d/%d]" % (best_iteration + 1, n_iterations))
    print()

    session_info['accuracy'] = best_accuracy
    session_info['best_iteration'] = best_iteration
    session_info['execution_time'] = training_time

    # Record the run and save the best model snapshot.
    session = Session()
    session.build(name="FSL_MCTDFMN",
                  comment=r"Few-Shot Learning solution based on https://arxiv.org/abs/2002.12017",
                  **session_info)
    torch.save(best_model,
               os.path.join(session.data['output_dir'], "trained_model_state_dict.tar"))

    # Save loss and accuracy curves.
    iters = list(range(1, n_iterations + 1))

    plt.figure(figsize=(20, 20))
    plt.plot(iters, losses, label="Loss")
    plt.plot(iters, losses_d, label="Dense Loss")
    plt.plot(iters, losses_i, label="Instance Loss")
    plt.legend()
    plt.savefig(os.path.join(session.data['output_dir'], "loss_plot.png"))

    plt.figure(figsize=(20, 20))
    plt.plot(iters, acc_train, label="Train Accuracy")
    plt.plot(val_iters, acc_val, label="Test Accuracy")
    plt.legend()
    plt.savefig(os.path.join(session.data['output_dir'], "acc_plot.png"))

    session.save_info()

    return best_model
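
# Illustrative only: a placeholder invocation of train_mctdfmn. The subdataset objects
# and hyperparameter values are assumptions for illustration; the call returns the
# snapshot of the model that scored best on validation, which can then be reused, e.g.
# passed back in as pretrained_model for further fine-tuning.
#
#     base_subdataset, val_subdataset = ...  # LabeledSubdataset splits prepared elsewhere
#     best = train_mctdfmn(base_subdataset=base_subdataset, val_subdataset=val_subdataset,
#                          n_shot=5, n_way=5, n_iterations=50000, batch_size=32,
#                          eval_period=1000, val_batch_size=16, dataset_classes=64,
#                          image_size=84, balanced_batches=True)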