def validate_classification(
    all_classes,
    all_data,
    num_val_tasks,
    p_rand=None,
    uncertainty=False,
    csv_flag=False
):
    """Evaluate the meta-learned model on held-out tasks.

    If csv_flag is True, per-task results are written to a CSV file and None
    is returned; otherwise a list of per-task accuracies is returned.
    """
    if csv_flag:
        filename = 'MAML_{0:s}_{1:d}way_{2:d}shot_{3:s}_{4:d}.csv'.format(
            datasource,
            num_classes_per_task,
            num_training_samples_per_class,
            'uncertainty' if uncertainty else 'accuracy',
            resume_epoch
        )
        outfile = open(file=os.path.join('csv', filename), mode='w')
        wr = csv.writer(outfile, quoting=csv.QUOTE_NONE)
    else:
        accuracies = []

    total_val_samples_per_task = num_val_samples_per_class * num_classes_per_task

    all_class_names = [class_name for class_name in sorted(all_classes.keys())]
    all_task_names = itertools.combinations(all_class_names, r=num_classes_per_task)

    task_count = 0
    for class_labels in all_task_names:
        if p_rand is not None:
            # sample from a Bernoulli distribution to randomly skip this task
            skip_task = np.random.binomial(n=1, p=p_rand)
            if skip_task == 1:
                continue

        x_t, y_t, x_v, y_v = get_train_val_task_data(
            all_classes=all_classes,
            all_data=all_data,
            class_labels=class_labels,
            num_samples_per_class=num_samples_per_class,
            num_training_samples_per_class=num_training_samples_per_class,
            device=device
        )

        w_task = adapt_to_task(x=x_t, y=y_t, w0=theta)
        y_pred, prob = predict_label_score(x=x_v, w=w_task)
        correct = [1 if y_pred[i] == y_v[i] else 0 for i in range(total_val_samples_per_task)]
        accuracy = np.mean(a=correct, axis=0)

        if csv_flag:
            if not uncertainty:
                outline = [class_label for class_label in class_labels]
                outline.append(accuracy)
                wr.writerow(outline)
            else:
                for correct_, prob_ in zip(correct, prob):
                    outline = [correct_, prob_]
                    wr.writerow(outline)
        else:
            accuracies.append(accuracy)

        task_count = task_count + 1
        if not train_flag:
            # move the cursor up one line to overwrite the previous counter
            sys.stdout.write('\033[F')
            print(task_count)
        if task_count >= num_val_tasks:
            break

    if csv_flag:
        outfile.close()
        return None
    else:
        return accuracies
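# --- Illustrative sketch only ----------------------------------------------
# validate_classification() above relies on predict_label_score(x, w), which
# is defined elsewhere in this file. The function below is NOT that
# implementation; it is a minimal sketch of what such a routine typically
# computes (the predicted label and the softmax score of that label, the
# latter being the kind of value the uncertainty CSV records). It assumes the
# global `net` exposes forward(x, w) returning per-class logits, as used in
# meta_train() below.
def predict_label_score_sketch(x, w):
    """Reference sketch: predicted labels and their softmax scores."""
    logits = net.forward(x=x, w=w)             # (num_samples, num_classes)
    prob_all = torch.softmax(logits, dim=1)    # per-class probabilities
    score, label = torch.max(prob_all, dim=1)  # probability of the predicted class
    return label.detach().cpu().numpy(), score.detach().cpu().numpy()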
def meta_train():
    """Run the MAML meta-training loop and periodically save checkpoints."""
    if datasource == 'sine_line':
        data_generator = DataGenerator(num_samples=num_samples_per_class)
        # dummy class list so the data loader below can still be constructed
        all_class_train = [0] * 10
    else:
        all_class_train, all_data_train = load_dataset(
            dataset_name=datasource,
            subset=train_set
        )
        all_class_val, all_data_val = load_dataset(
            dataset_name=datasource,
            subset=val_set
        )
        all_class_train.update(all_class_val)
        all_data_train.update(all_data_val)

    # initialize data loader
    train_loader = initialize_dataloader(
        all_classes=[class_label for class_label in all_class_train],
        num_classes_per_task=num_classes_per_task
    )

    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        meta_loss_saved = []  # meta loss to save
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0  # accumulated loss over the tasks in a minibatch, used for the meta update
        num_meta_updates_count = 0
        meta_loss_avg_print = 0  # running average of the meta loss for printing

        meta_loss_avg_save = []  # meta loss to save

        task_count = 0  # counter deciding when a minibatch of tasks is completed and a meta update is due

        while task_count < num_tasks_per_epoch:
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True
                    )
                    x_t = torch.tensor(x_t, dtype=torch.float, device=device)
                    y_t = torch.tensor(y_t, dtype=torch.float, device=device)
                    x_v = torch.tensor(x_v, dtype=torch.float, device=device)
                    y_v = torch.tensor(y_v, dtype=torch.float, device=device)
                else:
                    x_t, y_t, x_v, y_v = get_train_val_task_data(
                        all_classes=all_class_train,
                        all_data=all_data_train,
                        class_labels=class_labels,
                        num_samples_per_class=num_samples_per_class,
                        num_training_samples_per_class=num_training_samples_per_class,
                        device=device
                    )

                w_task = adapt_to_task(x=x_t, y=y_t, w0=theta)
                y_pred = net.forward(x=x_v, w=w_task)
                loss_NLL = loss_fn(input=y_pred, target=y_v)

                if torch.isnan(loss_NLL).item():
                    sys.exit('NaN error')

                # accumulate meta loss
                meta_loss = meta_loss + loss_NLL

                task_count = task_count + 1

                if task_count % num_tasks_per_minibatch == 0:
                    meta_loss = meta_loss / num_tasks_per_minibatch

                    # accumulate into different variables for printing purposes
                    meta_loss_avg_print += meta_loss.item()

                    op_theta.zero_grad()
                    meta_loss.backward()
                    torch.nn.utils.clip_grad_norm_(parameters=theta.values(), max_norm=10)
                    op_theta.step()

                    # print the running average of the meta loss
                    num_meta_updates_count += 1
                    if (num_meta_updates_count % num_meta_updates_print == 0):
                        meta_loss_avg_save.append(meta_loss_avg_print / num_meta_updates_count)
                        print('{0:d}, {1:2.4f}'.format(
                            task_count,
                            meta_loss_avg_save[-1]
                        ))
                        num_meta_updates_count = 0
                        meta_loss_avg_print = 0

                    if (task_count % num_tasks_save_loss == 0):
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))
                        meta_loss_avg_save = []

                        print('Saving loss...')
                        if datasource != 'sine_line':
                            val_accs = validate_classification(
                                all_classes=all_class_val,
                                all_data=all_data_val,
                                num_val_tasks=100,
                                p_rand=0.5,
                                uncertainty=False,
                                csv_flag=False
                            )
                            val_acc = np.mean(val_accs)
                            # use the number of tasks actually evaluated for the 95% confidence interval
                            val_ci95 = 1.96 * np.std(val_accs) / np.sqrt(len(val_accs))
                            print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                            val_accuracies.append(val_acc)

                            train_accs = validate_classification(
                                all_classes=all_class_train,
                                all_data=all_data_train,
                                num_val_tasks=100,
                                p_rand=0.5,
                                uncertainty=False,
                                csv_flag=False
                            )
                            train_acc = np.mean(train_accs)
                            train_ci95 = 1.96 * np.std(train_accs) / np.sqrt(len(train_accs))
                            print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                            train_accuracies.append(train_acc)

                    # reset meta loss
                    meta_loss = 0

                if (task_count >= num_tasks_per_epoch):
                    break

        if ((epoch + 1) % num_epochs_save == 0):
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict()
            }
            print('SAVING WEIGHTS...')

            checkpoint_filename = 'Epoch_{0:d}.pt'.format(epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
        scheduler.step()
        print()
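# --- Illustrative sketch only ----------------------------------------------
# Both functions above call adapt_to_task(x, y, w0) to obtain task-specific
# weights. The function below is NOT that implementation; it is a minimal
# sketch of the standard MAML inner loop, written under the assumption that
# `net`, `loss_fn`, `inner_lr` and `num_inner_updates` are module-level
# globals like the other hyper-parameters used above. The real adapt_to_task
# defined elsewhere in this file may differ (e.g. first-order vs. second-order
# gradients, or per-parameter learning rates).
def adapt_to_task_sketch(x, y, w0):
    """Reference sketch: a few gradient steps on the support set, starting from w0."""
    w = {key: w0[key] for key in w0}
    for _ in range(num_inner_updates):
        y_pred = net.forward(x=x, w=w)
        loss = loss_fn(input=y_pred, target=y)
        grads = torch.autograd.grad(
            outputs=loss,
            inputs=list(w.values()),
            create_graph=True  # keep the graph so the meta update can backprop through the inner steps
        )
        w = {key: w[key] - inner_lr * grad for key, grad in zip(w, grads)}
    return w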