def meta_validation(datasubset, num_val_tasks, return_uncertainty=False):
    if datasource == 'sine_line':
        x0 = torch.linspace(start=-5, end=5, steps=100, device=device).view(-1, 1)  # vector

        if num_val_tasks == 0:
            from matplotlib import pyplot as plt
            import matplotlib

            matplotlib.rcParams['xtick.labelsize'] = 16
            matplotlib.rcParams['ytick.labelsize'] = 16
            matplotlib.rcParams['axes.labelsize'] = 18

            num_stds = 2
            data_generator = DataGenerator(
                num_samples=num_training_samples_per_class,
                device=device)
            if datasubset == 'sine':
                x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(noise_flag=True)
                y0 = amp * torch.sin(x0 + phase)
            else:
                x_t, y_t, slope, intercept = data_generator.generate_line_data(noise_flag=True)
                y0 = slope * x0 + intercept

            y_preds = get_task_prediction(x_t=x_t, y_t=y_t, x_v=x0)

            '''LOAD MAML DATA'''
            maml_folder = '{0:s}/MAML_mixed_sine_line'.format(dst_folder_root)
            maml_filename = 'MAML_mixed_{0:d}shot_{1:s}.pt'.format(
                num_training_samples_per_class, '{0:d}')

            i = 1
            maml_checkpoint_filename = os.path.join(maml_folder, maml_filename.format(i))
            while os.path.exists(maml_checkpoint_filename):
                i = i + 1
                maml_checkpoint_filename = os.path.join(maml_folder, maml_filename.format(i))
            print(maml_checkpoint_filename)
            maml_checkpoint = torch.load(
                os.path.join(maml_folder, maml_filename.format(i - 1)),
                map_location=lambda storage, loc: storage.cuda(gpu_id))
            theta_maml = maml_checkpoint['theta']
            y_pred_maml = get_task_prediction_maml(
                x_t=x_t, y_t=y_t, x_v=x0, meta_params=theta_maml)

            '''PLOT'''
            _, ax = plt.subplots(figsize=(5, 5))
            y_top = torch.squeeze(
                torch.mean(y_preds, dim=0) + num_stds * torch.std(y_preds, dim=0))
            y_bottom = torch.squeeze(
                torch.mean(y_preds, dim=0) - num_stds * torch.std(y_preds, dim=0))

            ax.fill_between(
                x=torch.squeeze(x0).cpu().numpy(),
                y1=y_bottom.cpu().detach().numpy(),
                y2=y_top.cpu().detach().numpy(),
                alpha=0.25,
                color='C3',
                zorder=0,
                label='VAMPIRE')
            ax.plot(x0.cpu().numpy(), y0.cpu().numpy(),
                    color='C7', linestyle='-', linewidth=3, zorder=1, label='Ground truth')
            ax.plot(x0.cpu().numpy(), y_pred_maml.cpu().detach().numpy(),
                    color='C2', linestyle='--', linewidth=3, zorder=2, label='MAML')
            ax.scatter(x=x_t.cpu().numpy(), y=y_t.cpu().numpy(),
                       color='C0', marker='^', s=300, zorder=3, label='Data')
            plt.xticks([-5, -2.5, 0, 2.5, 5])
            plt.savefig(fname='img/mixed_sine_temp.svg', format='svg')
            return 0
        else:
            from scipy.special import erf

            quantiles = np.arange(start=0., stop=1.1, step=0.1)
            cal_data = []

            data_generator = DataGenerator(
                num_samples=num_training_samples_per_class,
                device=device)
            for _ in range(num_val_tasks):
                binary_flag = np.random.binomial(n=1, p=p_sine)
                if binary_flag == 0:
                    # generate sinusoidal data
                    x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(noise_flag=True)
                    y0 = amp * torch.sin(x0 + phase)
                else:
                    # generate line data
                    x_t, y_t, slope, intercept = data_generator.generate_line_data(noise_flag=True)
                    y0 = slope * x0 + intercept
                y0 = y0.view(1, -1).cpu().numpy()  # row vector

                y_preds = torch.stack(
                    get_task_prediction(x_t=x_t, y_t=y_t, x_v=x0))  # K x len(x0)
                y_preds_np = torch.squeeze(y_preds, dim=-1).detach().cpu().numpy()
                y_preds_quantile = np.quantile(a=y_preds_np, q=quantiles, axis=0, keepdims=False)

                # ground truth cdf
                std = data_generator.noise_std
                cal_temp = (1 + erf((y_preds_quantile - y0) / (np.sqrt(2) * std))) / 2
                cal_temp_avg = np.mean(a=cal_temp, axis=1)  # average for a task
                cal_data.append(cal_temp_avg)
            return cal_data
    else:
        accuracies = []
        corrects = []
        probability_pred = []

        total_validation_samples = (num_total_samples_per_class - num_training_samples_per_class) * num_classes_per_task

        if datasubset == 'train':
            all_class_data = all_class_train
            embedding_data = embedding_train
        elif datasubset == 'val':
            all_class_data = all_class_val
            embedding_data = embedding_val
        elif datasubset == 'test':
            all_class_data = all_class_test
            embedding_data = embedding_test
        else:
            sys.exit('Unknown datasubset for validation')

        all_class_names = list(all_class_data.keys())
        all_task_names = itertools.combinations(all_class_names, r=num_classes_per_task)

        if train_flag:
            all_task_names = list(all_task_names)
            random.shuffle(all_task_names)

        task_count = 0
        for class_labels in all_task_names:
            x_t, y_t, x_v, y_v = get_task_image_data(
                all_class_data,
                embedding_data,
                class_labels,
                num_total_samples_per_class,
                num_training_samples_per_class,
                device)

            y_pred_v = get_task_prediction(x_t, y_t, x_v, y_v=None)
            y_pred_v = torch.stack(y_pred_v)
            y_pred_v = sm_loss(y_pred_v)
            y_pred = torch.mean(input=y_pred_v, dim=0, keepdim=False)

            prob_pred, labels_pred = torch.max(input=y_pred, dim=1)
            correct = (labels_pred == y_v)
            corrects.extend(correct.detach().cpu().numpy())

            accuracy = torch.sum(correct, dim=0).item() / total_validation_samples
            accuracies.append(accuracy)

            probability_pred.extend(prob_pred.detach().cpu().numpy())

            task_count += 1
            if not train_flag:
                print(task_count)
            if task_count >= num_val_tasks:
                break

        if not return_uncertainty:
            return accuracies, all_task_names
        else:
            return corrects, probability_pred
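# Illustrative helper (a sketch, not part of the original pipeline): one way to summarize the
# per-task calibration data returned by meta_validation above. It assumes `cal_data` is a list
# with one array per task, each holding the ground-truth CDF evaluated at the predicted
# quantile levels 0.0, 0.1, ..., 1.0 used in the function.
def summarize_calibration(cal_data, quantiles=None):
    """Return (quantile levels, mean observed frequency, mean absolute calibration error)."""
    if quantiles is None:
        quantiles = np.arange(start=0., stop=1.1, step=0.1)
    observed = np.mean(np.stack(cal_data, axis=0), axis=0)  # average over tasks
    calibration_error = np.mean(np.abs(observed - quantiles))  # deviation from the diagonal
    return quantiles, observed, calibration_error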
def meta_train(train_subset=train_set):
    #region PREPARING DATALOADER
    if datasource == 'sine_line':
        data_generator = DataGenerator(
            num_samples=num_total_samples_per_class, device=device)
        # create dummy sampler
        all_class = [0] * 100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    else:
        all_class = all_class_train
        embedding = embedding_train
        sampler = torch.utils.data.sampler.RandomSampler(
            data_source=list(all_class.keys()), replacement=False)
        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    #endregion

    print('Start to train...')
    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        meta_loss_saved = []  # meta loss to save
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0  # accumulate the loss of many ensembling networks to descend the gradient for the meta update
        num_meta_updates_count = 0
        meta_loss_avg_print = 0  # compute loss average to print

        meta_loss_avg_save = []  # meta loss to save

        task_count = 0  # a counter to decide when a minibatch of tasks is completed to perform a meta update

        while task_count < num_tasks_per_epoch:
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True)
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class,
                        embedding,
                        class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class,
                        device)

                loss_NLL = get_task_prediction(x_t, y_t, x_v, y_v)
                if torch.isnan(loss_NLL).item():
                    sys.exit('NaN error')

                # accumulate meta loss
                meta_loss = meta_loss + loss_NLL

                task_count = task_count + 1

                if task_count % num_tasks_per_minibatch == 0:
                    meta_loss = meta_loss / num_tasks_per_minibatch

                    # accumulate into different variables for printing purposes
                    meta_loss_avg_print += meta_loss.item()

                    op_theta.zero_grad()
                    meta_loss.backward()
                    op_theta.step()

                    # Printing losses
                    num_meta_updates_count += 1
                    if num_meta_updates_count % num_meta_updates_print == 0:
                        meta_loss_avg_save.append(meta_loss_avg_print / num_meta_updates_count)
                        print('{0:d}, {1:2.4f}'.format(task_count, meta_loss_avg_save[-1]))

                        num_meta_updates_count = 0
                        meta_loss_avg_print = 0

                    if task_count % num_tasks_save_loss == 0:
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))
                        meta_loss_avg_save = []

                        # print('Saving loss...')
                        # val_accs, _ = meta_validation(
                        #     datasubset=val_set,
                        #     num_val_tasks=num_val_tasks,
                        #     return_uncertainty=False)
                        # val_acc = np.mean(val_accs)
                        # val_ci95 = 1.96 * np.std(val_accs) / np.sqrt(num_val_tasks)
                        # print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        # val_accuracies.append(val_acc)

                        # train_accs, _ = meta_validation(
                        #     datasubset=train_set,
                        #     num_val_tasks=num_val_tasks,
                        #     return_uncertainty=False)
                        # train_acc = np.mean(train_accs)
                        # train_ci95 = 1.96 * np.std(train_accs) / np.sqrt(num_val_tasks)
                        # print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        # train_accuracies.append(train_acc)

                    # reset meta loss
                    meta_loss = 0

                if task_count >= num_tasks_per_epoch:
                    break

        if (epoch + 1) % num_epochs_save == 0:
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = '{0:s}_{1:d}way_{2:d}shot_{3:d}.pt'.format(
                datasource,
                num_classes_per_task,
                num_training_samples_per_class,
                epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
        print()
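# Minimal, generic sketch of a MAML-style inner loop (an assumption, NOT the repository's actual
# get_task_prediction): adapt a copy of the meta-parameters on the support set with a few SGD
# steps, then return the query-set loss used for the meta update in meta_train above.
# `forward_fn`, `loss_fn`, and `inner_lr` are hypothetical stand-ins for whatever functional
# forward pass, loss, and step size the surrounding code defines.
def maml_inner_loop_sketch(theta, x_t, y_t, x_v, y_v, forward_fn, loss_fn,
                           inner_lr=0.01, num_inner_steps=1):
    adapted = {key: value.clone() for key, value in theta.items()}
    for _ in range(num_inner_steps):
        support_loss = loss_fn(forward_fn(x_t, adapted), y_t)
        grads = torch.autograd.grad(support_loss, list(adapted.values()), create_graph=True)
        adapted = {key: value - inner_lr * grad
                   for (key, value), grad in zip(adapted.items(), grads)}
    # query loss; create_graph=True keeps it differentiable w.r.t. the meta-parameters theta
    return loss_fn(forward_fn(x_v, adapted), y_v)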
def meta_train():
    if datasource == 'sine_line':
        data_generator = DataGenerator(
            num_samples=num_total_samples_per_class, device=device)
        # create dummy sampler
        all_class = [0] * 100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    else:
        all_class = all_class_train
        # all_class.update(all_class_val)
        embedding = embedding_train
        # embedding.update(embedding_val)
        sampler = torch.utils.data.sampler.RandomSampler(
            data_source=list(all_class.keys()), replacement=False)
        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)

    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        loss_NLL_saved = []
        kl_loss_saved = []
        d_loss_saved = []
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0  # accumulate the loss of many ensembling networks to descend the gradient for the meta update
        num_meta_updates_count = 0

        loss_NLL_v = 0
        loss_NLL_avg_print = 0
        kl_loss = 0
        kl_loss_avg_print = 0
        d_loss = 0
        d_loss_avg_print = 0

        loss_NLL_avg_save = []
        kl_loss_avg_save = []
        d_loss_avg_save = []

        task_count = 0  # a counter to decide when a minibatch of tasks is completed to perform a meta update
        while task_count < num_tasks_per_epoch:
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True)
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class,
                        embedding,
                        class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class,
                        device)

                loss_NLL, KL_loss, discriminator_loss = get_task_prediction(
                    x_t=x_t,
                    y_t=y_t,
                    x_v=x_v,
                    y_v=y_v,
                    p_dropout=p_dropout_base,
                    p_dropout_g=p_dropout_generator,
                    p_dropout_d=p_dropout_discriminator,
                    p_dropout_e=p_dropout_encoder)
                if torch.isnan(loss_NLL).item():
                    sys.exit('nan')

                loss_NLL_v += loss_NLL.item()
                if loss_NLL.item() > 1:
                    loss_NLL.data = torch.tensor([1.], device=device)
                # if discriminator_loss.item() > d_loss_const:
                #     discriminator_loss.data = torch.tensor([d_loss_const], device=device)

                kl_loss = kl_loss + KL_loss
                if KL_loss.item() < 0:
                    KL_loss.data = torch.tensor([0.], device=device)
                Ri = torch.sqrt((KL_loss + ri_const) / (2 * (total_validation_samples - 1)))
                meta_loss = meta_loss + loss_NLL + Ri
                d_loss = d_loss + discriminator_loss

                task_count = task_count + 1

                if task_count % num_tasks_per_minibatch == 0:
                    # average over the number of tasks per minibatch
                    meta_loss = meta_loss / num_tasks_per_minibatch
                    loss_NLL_v /= num_tasks_per_minibatch
                    kl_loss = kl_loss / num_tasks_per_minibatch
                    d_loss = d_loss / num_tasks_per_minibatch

                    # accumulate for printing purposes
                    loss_NLL_avg_print += loss_NLL_v
                    kl_loss_avg_print += kl_loss.item()
                    d_loss_avg_print += d_loss.item()

                    # adding R0
                    R0 = 0
                    for key in theta.keys():
                        R0 += theta[key].norm(2)
                    R0 = torch.sqrt((L2_regularization * R0 + r0_const) / (2 * (num_tasks_per_minibatch - 1)))
                    meta_loss += R0

                    # optimize theta
                    op_theta.zero_grad()
                    meta_loss.backward(retain_graph=True)
                    torch.nn.utils.clip_grad_norm_(
                        parameters=theta.values(), max_norm=clip_grad_value)
                    torch.nn.utils.clip_grad_norm_(
                        parameters=w_encoder.values(), max_norm=clip_grad_value)
                    op_theta.step()

                    # optimize the discriminator
                    op_discriminator.zero_grad()
                    d_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        parameters=w_discriminator.values(), max_norm=clip_grad_value)
                    op_discriminator.step()

                    # Printing losses
                    num_meta_updates_count += 1
                    if num_meta_updates_count % num_meta_updates_print == 0:
                        loss_NLL_avg_save.append(loss_NLL_avg_print / num_meta_updates_count)
                        kl_loss_avg_save.append(kl_loss_avg_print / num_meta_updates_count)
                        d_loss_avg_save.append(d_loss_avg_print / num_meta_updates_count)
                        print('{0:d}, {1:2.4f}, {2:1.4f}, {3:2.4e}'.format(
                            task_count,
                            loss_NLL_avg_save[-1],
                            kl_loss_avg_save[-1],
                            d_loss_avg_save[-1]))

                        num_meta_updates_count = 0
                        loss_NLL_avg_print = 0
                        kl_loss_avg_print = 0
                        d_loss_avg_print = 0

                    if task_count % num_tasks_save_loss == 0:
                        loss_NLL_saved.append(np.mean(loss_NLL_avg_save))
                        kl_loss_saved.append(np.mean(kl_loss_avg_save))
                        d_loss_saved.append(np.mean(d_loss_avg_save))

                        loss_NLL_avg_save = []
                        kl_loss_avg_save = []
                        d_loss_avg_save = []

                        # if datasource != 'sine_line':
                        #     val_accs = meta_validation(
                        #         datasubset=val_set,
                        #         num_val_tasks=num_val_tasks)
                        #     val_acc = np.mean(val_accs)
                        #     val_ci95 = 1.96 * np.std(val_accs) / np.sqrt(num_val_tasks)
                        #     print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        #     val_accuracies.append(val_acc)

                        #     train_accs = meta_validation(
                        #         datasubset=train_set,
                        #         num_val_tasks=num_val_tasks)
                        #     train_acc = np.mean(train_accs)
                        #     train_ci95 = 1.96 * np.std(train_accs) / np.sqrt(num_val_tasks)
                        #     print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        #     train_accuracies.append(train_acc)

                    # reset meta loss for the next minibatch of tasks
                    meta_loss = 0
                    kl_loss = 0
                    d_loss = 0
                    loss_NLL_v = 0

                if task_count >= num_tasks_per_epoch:
                    break

        if (epoch + 1) % num_epochs_save == 0:
            checkpoint = {
                'w_discriminator': w_discriminator,
                'theta': theta,
                'w_encoder': w_encoder,
                'w_encoder_2': w_encoder_2,
                'meta_loss': loss_NLL_saved,
                'kl_loss': kl_loss_saved,
                'd_loss': d_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict(),
                'op_discriminator': op_discriminator.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = '{0:s}_{1:d}way_{2:d}shot_{3:d}.pt'.format(
                datasource,
                num_classes_per_task,
                num_training_samples_per_class,
                epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
        # scheduler.step()
        print()
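# Illustrative helper (an assumption, not from the original source): the two square-root terms
# added to the meta loss above (Ri from the per-task KL, R0 from the norm of theta) have the
# same sqrt((value + const) / (2 * (m - 1))) form, so this sketch isolates that computation.
# `value`, `const`, and `num_samples` mirror KL_loss / ri_const / total_validation_samples, or
# the theta norm / r0_const / num_tasks_per_minibatch, respectively.
def sqrt_penalty(value, const, num_samples):
    return torch.sqrt((value + const) / (2 * (num_samples - 1)))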
def meta_train():
    if datasource == 'sine_line':
        data_generator = DataGenerator(
            num_samples=num_total_samples_per_class,
            device=device)
        # create dummy sampler
        all_class = [0] * 100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    else:
        all_class = all_class_train
        # all_class.update(all_class_val)
        embedding = embedding_train
        # embedding.update(embedding_val)
        sampler = torch.utils.data.sampler.RandomSampler(
            data_source=list(all_class.keys()), replacement=False)
        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)

    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        meta_loss_saved = []  # meta loss to save
        kl_loss_saved = []
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0  # accumulate the loss of many ensembling networks to descend the gradient for the meta update
        num_meta_updates_count = 0
        meta_loss_avg_print = 0  # compute loss average to print
        kl_loss = 0
        kl_loss_avg_print = 0

        meta_loss_avg_save = []  # meta loss to save
        kl_loss_avg_save = []

        task_count = 0  # a counter to decide when a minibatch of tasks is completed to perform a meta update
        while task_count < num_tasks_per_epoch:
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True)
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class,
                        embedding,
                        class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class,
                        device)

                loss_NLL, KL_loss = get_task_prediction(
                    x_t=x_t, y_t=y_t, x_v=x_v, y_v=y_v, p_dropout=p_base_dropout)

                meta_loss = meta_loss + loss_NLL
                kl_loss = kl_loss + KL_loss

                task_count = task_count + 1

                if torch.isnan(meta_loss).item():
                    sys.exit('nan')

                if task_count % num_tasks_per_minibatch == 0:
                    # average over the number of tasks per minibatch
                    meta_loss = meta_loss / num_tasks_per_minibatch
                    kl_loss = kl_loss / num_tasks_per_minibatch

                    # accumulate for printing purposes
                    meta_loss_avg_print += meta_loss.item()
                    kl_loss_avg_print += kl_loss.item()

                    op_theta.zero_grad()
                    meta_loss.backward(retain_graph=True)
                    # torch.nn.utils.clip_grad_norm_(parameters=theta.values(), max_norm=1)
                    op_theta.step()

                    # Printing losses
                    num_meta_updates_count += 1
                    if num_meta_updates_count % num_meta_updates_print == 0:
                        meta_loss_avg_save.append(meta_loss_avg_print / num_meta_updates_count)
                        kl_loss_avg_save.append(kl_loss_avg_print / num_meta_updates_count)
                        print('{0:d}, {1:2.4f}, {2:1.4f}'.format(
                            task_count,
                            meta_loss_avg_save[-1],
                            kl_loss_avg_save[-1]))

                        num_meta_updates_count = 0
                        meta_loss_avg_print = 0
                        kl_loss_avg_print = 0

                    if task_count % num_tasks_save_loss == 0:
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))
                        kl_loss_saved.append(np.mean(kl_loss_avg_save))

                        meta_loss_avg_save = []
                        kl_loss_avg_save = []

                        # if datasource != 'sine_line':
                        #     val_accs = meta_validation(
                        #         datasubset=val_set,
                        #         num_val_tasks=num_val_tasks)
                        #     val_acc = np.mean(val_accs)
                        #     val_ci95 = 1.96 * np.std(val_accs) / np.sqrt(num_val_tasks)
                        #     print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        #     val_accuracies.append(val_acc)

                        #     train_accs = meta_validation(
                        #         datasubset=train_set,
                        #         num_val_tasks=num_val_tasks)
                        #     train_acc = np.mean(train_accs)
                        #     train_ci95 = 1.96 * np.std(train_accs) / np.sqrt(num_val_tasks)
                        #     print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        #     train_accuracies.append(train_acc)

                    # reset meta loss for the next minibatch of tasks
                    meta_loss = 0
                    kl_loss = 0

                if task_count >= num_tasks_per_epoch:
                    break

        if (epoch + 1) % num_epochs_save == 0:
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'kl_loss': kl_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = '{0:s}_{1:d}way_{2:d}shot_{3:d}.pt'.format(
                datasource,
                num_classes_per_task,
                num_training_samples_per_class,
                epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
        # scheduler.step()
        print()
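# Illustrative helper (an assumption, not part of the original script): locate the most recent
# checkpoint written by the meta_train variants above, which all save files named with an
# epoch-indexed pattern such as '{0:s}_{1:d}way_{2:d}shot_{3:d}.pt'. It assumes checkpoints are
# saved for consecutive epochs (num_epochs_save == 1); adapt the step otherwise.
def find_latest_checkpoint(folder, filename_pattern):
    """Return (latest_epoch, path), or (0, None) if no checkpoint exists.
    `filename_pattern` must contain a single '{0:d}' placeholder for the epoch number."""
    epoch = 1
    while os.path.exists(os.path.join(folder, filename_pattern.format(epoch))):
        epoch += 1
    if epoch == 1:
        return 0, None
    return epoch - 1, os.path.join(folder, filename_pattern.format(epoch - 1))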
def meta_validation(datasubset, num_val_tasks, return_uncertainty=False):
    if datasource == 'sine_line':
        from scipy.special import erf

        x0 = torch.linspace(start=-5, end=5, steps=100, device=device).view(-1, 1)

        cal_avg = 0

        data_generator = DataGenerator(
            num_samples=num_training_samples_per_class, device=device)
        for _ in range(num_val_tasks):
            binary_flag = np.random.binomial(n=1, p=p_sine)
            if binary_flag == 0:
                # generate sinusoidal data
                x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(noise_flag=True)
                y0 = amp * torch.sin(x0 + phase)
            else:
                # generate line data
                x_t, y_t, slope, intercept = data_generator.generate_line_data(noise_flag=True)
                y0 = slope * x0 + intercept
            y0 = y0.view(1, -1).cpu().numpy()

            y_preds = get_task_prediction(x_t=x_t, y_t=y_t, x_v=x0)
            y_preds_np = torch.squeeze(y_preds, dim=-1).detach().cpu().numpy()

            # ground truth cdf
            std = data_generator.noise_std
            cal_temp = (1 + erf((y_preds_np - y0) / (np.sqrt(2) * std))) / 2
            cal_temp_avg = np.mean(a=cal_temp, axis=1)
            cal_avg = cal_avg + cal_temp_avg
        cal_avg = cal_avg / num_val_tasks
        return cal_avg
    else:
        accuracies = []
        corrects = []
        probability_pred = []

        total_validation_samples = (num_total_samples_per_class - num_training_samples_per_class) * num_classes_per_task

        if datasubset == 'train':
            all_class_data = all_class_train
            embedding_data = embedding_train
        elif datasubset == 'val':
            all_class_data = all_class_val
            embedding_data = embedding_val
        elif datasubset == 'test':
            all_class_data = all_class_test
            embedding_data = embedding_test
        else:
            sys.exit('Unknown datasubset for validation')

        all_class_names = list(all_class_data.keys())
        all_task_names = list(itertools.combinations(all_class_names, r=num_classes_per_task))

        if train_flag:
            random.shuffle(all_task_names)

        task_count = 0
        for class_labels in all_task_names:
            x_t, y_t, x_v, y_v = get_task_image_data(
                all_class_data,
                embedding_data,
                class_labels,
                num_total_samples_per_class,
                num_training_samples_per_class,
                device)

            y_pred_v = get_task_prediction(x_t, y_t, x_v, y_v=None)
            y_pred = sm_loss(y_pred_v)

            prob_pred, labels_pred = torch.max(input=y_pred, dim=1)
            correct = (labels_pred == y_v)
            corrects.extend(correct.detach().cpu().numpy())

            accuracy = torch.sum(correct, dim=0).item() / total_validation_samples
            accuracies.append(accuracy)

            probability_pred.extend(prob_pred.detach().cpu().numpy())

            task_count += 1
            if not train_flag:
                print(task_count)
            if task_count >= num_val_tasks:
                break

        if not return_uncertainty:
            return accuracies, all_task_names
        else:
            return corrects, probability_pred
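# Illustrative sketch (not part of the original file): one common way to use the
# (corrects, probability_pred) pair returned by meta_validation when return_uncertainty=True is
# to compute an expected calibration error over equal-width confidence bins. Only numpy is used.
def expected_calibration_error(corrects, probability_pred, num_bins=10):
    corrects = np.asarray(corrects, dtype=np.float64)
    confidences = np.asarray(probability_pred, dtype=np.float64)
    bin_edges = np.linspace(0., 1., num=num_bins + 1)
    ece = 0.
    for lower, upper in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (confidences > lower) & (confidences <= upper)
        if np.any(in_bin):
            accuracy_in_bin = np.mean(corrects[in_bin])
            confidence_in_bin = np.mean(confidences[in_bin])
            # weight each bin's |accuracy - confidence| gap by the fraction of samples it holds
            ece += np.mean(in_bin) * np.abs(accuracy_in_bin - confidence_in_bin)
    return ece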