def validate_regression(uncertainty_flag, num_val_tasks=1):
    assert datasource == 'sine_line'

    if uncertainty_flag:
        from scipy.special import erf
        cal_avg = 0
    else:
        from matplotlib import pyplot as plt

    data_generator = DataGenerator(num_samples=num_training_samples_per_class)
    std = data_generator.noise_std

    x0 = torch.linspace(start=-5, end=5, steps=100, device=device).view(-1, 1)
    for _ in range(num_val_tasks):
        # throw a coin to see 0 - 'sine' or 1 - 'line'
        binary_flag = np.random.binomial(n=1, p=p_sine)
        if binary_flag == 0:
            # generate sinusoidal data
            x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(noise_flag=True)
            y0 = amp * torch.sin(x0 + phase)
        else:
            # generate line data
            x_t, y_t, slope, intercept = data_generator.generate_line_data(noise_flag=True)
            y0 = slope * x0 + intercept
        x_t = torch.tensor(x_t, dtype=torch.float, device=device)
        y_t = torch.tensor(y_t, dtype=torch.float, device=device)
        y0 = y0.cpu().numpy().reshape(1, -1)  # row vector

        w_task = adapt_to_task(x=x_t, y=y_t, w0=theta)
        y_pred = predict_label_score(x=x0, w=w_task)
        y_pred = torch.squeeze(y_pred, dim=-1).detach().cpu().numpy()  # convert to numpy array

        if uncertainty_flag:
            cal_temp = (1 + erf((y_pred - y0) / (np.sqrt(2) * std))) / 2
            cal_temp_avg = np.mean(a=cal_temp, axis=1)
            cal_avg = cal_avg + cal_temp_avg
        else:
            plt.figure(figsize=(4, 4))
            plt.scatter(x_t.cpu().numpy(), y_t.cpu().numpy(), marker='^', label='Training data')
            plt.plot(x0.cpu().numpy(), y_pred, linewidth=1, linestyle='-', label='Prediction')
            plt.plot(x0.cpu().numpy(), np.squeeze(y0), linewidth=1, linestyle='--', label='Ground-truth')
            plt.xlabel('x')
            plt.ylabel('y')
            plt.legend()
            plt.tight_layout()
            plt.show()

    if uncertainty_flag:
        print('Average calibration \'score\' = {0}'.format(cal_avg / num_val_tasks))
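# --- Illustrative example (not part of the original code) ---------------------
# A minimal, self-contained sketch of the calibration score used above: for
# Gaussian observation noise with standard deviation `noise_std`, the per-point
# score is the CDF value Phi((y_pred - y_true) / noise_std), averaged over the
# query points of a task. The name `demo_calibration_score` is hypothetical.
import numpy as np
from scipy.special import erf

def demo_calibration_score(y_pred, y_true, noise_std):
    """Average Gaussian-CDF calibration score per task (rows = tasks)."""
    cal = (1 + erf((y_pred - y_true) / (np.sqrt(2) * noise_std))) / 2
    return np.mean(cal, axis=1)

# Usage sketch: with y_pred and y_true of shape (1, 100) and noise_std = 0.1,
# a well-calibrated predictor yields a score close to 0.5.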
def meta_validation(self, datasubset, num_val_tasks, return_uncertainty=False):
    from scipy.special import erf

    x0 = torch.linspace(start=-5, end=5, steps=100, device=self.device).view(-1, 1)  # vector
    quantiles = np.arange(start=0., stop=1.1, step=0.1)
    cal_data = []

    data_generator = DataGenerator(
        num_samples=self.num_training_samples_per_class,
        device=self.device)
    for _ in range(num_val_tasks):
        # generate sinusoidal data
        x_t, y_t, amp, phase, slope = data_generator.generate_sinusoidal_data(noise_flag=True)
        y0 = amp * torch.sin(slope * x0 + phase)
        y0 = y0.view(1, -1).cpu().numpy()  # row vector

        y_preds = torch.stack(self.get_task_prediction(x_t=x_t, y_t=y_t, x_v=x0))  # K x len(x0)
        y_preds_np = torch.squeeze(y_preds, dim=-1).detach().cpu().numpy()
        y_preds_quantile = np.quantile(a=y_preds_np, q=quantiles, axis=0, keepdims=False)

        # ground truth cdf
        std = data_generator.noise_std
        cal_temp = (1 + erf((y_preds_quantile - y0) / (np.sqrt(2) * std))) / 2
        cal_temp_avg = np.mean(a=cal_temp, axis=1)  # average for a task
        cal_data.append(cal_temp_avg)
    return cal_data
def meta_validation(datasubset, num_val_tasks, return_uncertainty=False):
    if datasource == 'sine_line':
        x0 = torch.linspace(start=-5, end=5, steps=100, device=device).view(-1, 1)  # vector

        if num_val_tasks == 0:
            from matplotlib import pyplot as plt
            import matplotlib
            matplotlib.rcParams['xtick.labelsize'] = 16
            matplotlib.rcParams['ytick.labelsize'] = 16
            matplotlib.rcParams['axes.labelsize'] = 18

            num_stds = 2
            data_generator = DataGenerator(
                num_samples=num_training_samples_per_class,
                device=device)
            if datasubset == 'sine':
                x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(noise_flag=True)
                y0 = amp * torch.sin(x0 + phase)
            else:
                x_t, y_t, slope, intercept = data_generator.generate_line_data(noise_flag=True)
                y0 = slope * x0 + intercept

            y_preds = torch.stack(get_task_prediction(x_t=x_t, y_t=y_t, x_v=x0))

            '''LOAD MAML DATA'''
            maml_folder = '{0:s}/MAML_mixed_sine_line'.format(dst_folder_root)
            maml_filename = 'MAML_mixed_{0:d}shot_{1:s}.pt'.format(
                num_training_samples_per_class, '{0:d}')

            # find the latest MAML checkpoint in the folder
            i = 1
            maml_checkpoint_filename = os.path.join(maml_folder, maml_filename.format(i))
            while os.path.exists(maml_checkpoint_filename):
                i = i + 1
                maml_checkpoint_filename = os.path.join(maml_folder, maml_filename.format(i))
            print(maml_checkpoint_filename)
            maml_checkpoint = torch.load(
                os.path.join(maml_folder, maml_filename.format(i - 1)),
                map_location=lambda storage, loc: storage.cuda(gpu_id))
            theta_maml = maml_checkpoint['theta']
            y_pred_maml = get_task_prediction_maml(
                x_t=x_t, y_t=y_t, x_v=x0, meta_params=theta_maml)

            '''PLOT'''
            _, ax = plt.subplots(figsize=(5, 5))
            y_top = torch.squeeze(torch.mean(y_preds, dim=0) + num_stds * torch.std(y_preds, dim=0))
            y_bottom = torch.squeeze(torch.mean(y_preds, dim=0) - num_stds * torch.std(y_preds, dim=0))

            ax.fill_between(
                x=torch.squeeze(x0).cpu().numpy(),
                y1=y_bottom.cpu().detach().numpy(),
                y2=y_top.cpu().detach().numpy(),
                alpha=0.25,
                color='C3',
                zorder=0,
                label='VAMPIRE')
            ax.plot(x0.cpu().numpy(), y0.cpu().numpy(),
                    color='C7', linestyle='-', linewidth=3, zorder=1, label='Ground truth')
            ax.plot(x0.cpu().numpy(), y_pred_maml.cpu().detach().numpy(),
                    color='C2', linestyle='--', linewidth=3, zorder=2, label='MAML')
            ax.scatter(x=x_t.cpu().numpy(), y=y_t.cpu().numpy(),
                       color='C0', marker='^', s=300, zorder=3, label='Data')
            plt.xticks([-5, -2.5, 0, 2.5, 5])
            plt.savefig(fname='img/mixed_sine_temp.svg', format='svg')
            return 0
        else:
            from scipy.special import erf

            quantiles = np.arange(start=0., stop=1.1, step=0.1)
            cal_data = []

            data_generator = DataGenerator(
                num_samples=num_training_samples_per_class,
                device=device)
            for _ in range(num_val_tasks):
                binary_flag = np.random.binomial(n=1, p=p_sine)
                if binary_flag == 0:
                    # generate sinusoidal data
                    x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(noise_flag=True)
                    y0 = amp * torch.sin(x0 + phase)
                else:
                    # generate line data
                    x_t, y_t, slope, intercept = data_generator.generate_line_data(noise_flag=True)
                    y0 = slope * x0 + intercept
                y0 = y0.view(1, -1).cpu().numpy()  # row vector

                y_preds = torch.stack(get_task_prediction(x_t=x_t, y_t=y_t, x_v=x0))  # K x len(x0)
                y_preds_np = torch.squeeze(y_preds, dim=-1).detach().cpu().numpy()
                y_preds_quantile = np.quantile(a=y_preds_np, q=quantiles, axis=0, keepdims=False)

                # ground truth cdf
                std = data_generator.noise_std
                cal_temp = (1 + erf((y_preds_quantile - y0) / (np.sqrt(2) * std))) / 2
                cal_temp_avg = np.mean(a=cal_temp, axis=1)  # average for a task
                cal_data.append(cal_temp_avg)
            return cal_data
    else:
        accuracies = []
        corrects = []
        probability_pred = []

        total_validation_samples = (num_total_samples_per_class
                                    - num_training_samples_per_class) * num_classes_per_task

        if datasubset == 'train':
            all_class_data = all_class_train
            embedding_data = embedding_train
        elif datasubset == 'val':
            all_class_data = all_class_val
            embedding_data = embedding_val
        elif datasubset == 'test':
            all_class_data = all_class_test
            embedding_data = embedding_test
        else:
            sys.exit('Unknown datasubset for validation')

        all_class_names = list(all_class_data.keys())
        all_task_names = itertools.combinations(all_class_names, r=num_classes_per_task)
        if train_flag:
            all_task_names = list(all_task_names)
            random.shuffle(all_task_names)

        task_count = 0
        for class_labels in all_task_names:
            x_t, y_t, x_v, y_v = get_task_image_data(
                all_class_data,
                embedding_data,
                class_labels,
                num_total_samples_per_class,
                num_training_samples_per_class,
                device)

            y_pred_v = get_task_prediction(x_t, y_t, x_v, y_v=None)
            y_pred_v = torch.stack(y_pred_v)
            y_pred_v = sm_loss(y_pred_v)
            y_pred = torch.mean(input=y_pred_v, dim=0, keepdim=False)

            prob_pred, labels_pred = torch.max(input=y_pred, dim=1)
            correct = (labels_pred == y_v)
            corrects.extend(correct.detach().cpu().numpy())

            accuracy = torch.sum(correct, dim=0).item() / total_validation_samples
            accuracies.append(accuracy)
            probability_pred.extend(prob_pred.detach().cpu().numpy())

            task_count += 1
            if not train_flag:
                print(task_count)
            if task_count >= num_val_tasks:
                break
        if not return_uncertainty:
            return accuracies, all_task_names
        else:
            return corrects, probability_pred
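# --- Illustrative example (not part of the original code) ---------------------
# A small sketch of how the per-task accuracies returned by meta_validation can
# be summarized, mirroring the commented-out validation code in the meta_train
# variants below (mean accuracy plus a 95% confidence-interval half-width).
# The name `demo_accuracy_ci` is hypothetical.
import numpy as np

def demo_accuracy_ci(accuracies):
    """Mean accuracy and 95% CI half-width over per-task accuracies."""
    accs = np.asarray(accuracies)
    return accs.mean(), 1.96 * accs.std() / np.sqrt(len(accs))

# Usage sketch: accs, _ = meta_validation(datasubset='val', num_val_tasks=100)
#               mean_acc, ci95 = demo_accuracy_ci(accs)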
def meta_train(train_subset=train_set):
    #region PREPARING DATALOADER
    if datasource == 'sine_line':
        data_generator = DataGenerator(
            num_samples=num_total_samples_per_class,
            device=device)
        # create dummy sampler
        all_class = [0] * 100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    else:
        all_class = all_class_train
        embedding = embedding_train
        sampler = torch.utils.data.sampler.RandomSampler(
            data_source=list(all_class.keys()),
            replacement=False)
        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    #endregion

    print('Start to train...')
    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        meta_loss_saved = []  # meta loss to save
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0  # accumulate the loss of the ensemble networks for the meta update
        num_meta_updates_count = 0
        meta_loss_avg_print = 0  # loss average to print

        meta_loss_avg_save = []  # meta loss to save

        task_count = 0  # counter to decide when a minibatch of tasks is completed, triggering a meta update
        while task_count < num_tasks_per_epoch:
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True)
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class,
                        embedding,
                        class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class,
                        device)

                loss_NLL = get_task_prediction(x_t, y_t, x_v, y_v)
                if torch.isnan(loss_NLL).item():
                    sys.exit('NaN error')

                # accumulate meta loss
                meta_loss = meta_loss + loss_NLL

                task_count = task_count + 1
                if task_count % num_tasks_per_minibatch == 0:
                    meta_loss = meta_loss / num_tasks_per_minibatch

                    # accumulate into different variables for printing purposes
                    meta_loss_avg_print += meta_loss.item()

                    op_theta.zero_grad()
                    meta_loss.backward()
                    op_theta.step()

                    # printing losses
                    num_meta_updates_count += 1
                    if num_meta_updates_count % num_meta_updates_print == 0:
                        meta_loss_avg_save.append(meta_loss_avg_print / num_meta_updates_count)
                        print('{0:d}, {1:2.4f}'.format(task_count, meta_loss_avg_save[-1]))

                        num_meta_updates_count = 0
                        meta_loss_avg_print = 0

                    if task_count % num_tasks_save_loss == 0:
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))
                        meta_loss_avg_save = []

                        # print('Saving loss...')
                        # val_accs, _ = meta_validation(
                        #     datasubset=val_set,
                        #     num_val_tasks=num_val_tasks,
                        #     return_uncertainty=False)
                        # val_acc = np.mean(val_accs)
                        # val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                        # print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        # val_accuracies.append(val_acc)

                        # train_accs, _ = meta_validation(
                        #     datasubset=train_set,
                        #     num_val_tasks=num_val_tasks,
                        #     return_uncertainty=False)
                        # train_acc = np.mean(train_accs)
                        # train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                        # print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        # train_accuracies.append(train_acc)

                    # reset meta loss
                    meta_loss = 0

                if task_count >= num_tasks_per_epoch:
                    break

        if (epoch + 1) % num_epochs_save == 0:
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = '{0:s}_{1:d}way_{2:d}shot_{3:d}.pt'.format(
                datasource,
                num_classes_per_task,
                num_training_samples_per_class,
                epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
        print()
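# --- Illustrative example (not part of the original code) ---------------------
# A minimal sketch of resuming from a checkpoint written by meta_train above.
# The keys ('theta', 'op_theta', 'meta_loss') match the dict saved there; the
# function `demo_resume` and its signature are assumptions for illustration.
import os
import torch

def demo_resume(dst_folder, checkpoint_filename, theta, op_theta, device):
    """Load saved meta-parameters and optimizer state in place."""
    ckpt = torch.load(os.path.join(dst_folder, checkpoint_filename),
                      map_location=device)
    for key in theta:
        # copy weights without breaking the autograd leaves held by op_theta
        theta[key].data.copy_(ckpt['theta'][key].data)
    op_theta.load_state_dict(ckpt['op_theta'])
    return ckpt['meta_loss']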
def meta_train():
    if datasource == 'sine_line':
        data_generator = DataGenerator(
            num_samples=num_total_samples_per_class,
            device=device)
        # create dummy sampler
        all_class = [0] * 100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    else:
        all_class = all_class_train
        # all_class.update(all_class_val)
        embedding = embedding_train
        # embedding.update(embedding_val)
        sampler = torch.utils.data.sampler.RandomSampler(
            data_source=list(all_class.keys()),
            replacement=False)
        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)

    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        meta_loss_saved = []  # meta loss to save
        kl_loss_saved = []
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0  # accumulate the loss of the ensemble networks for the meta update
        num_meta_updates_count = 0
        meta_loss_avg_print = 0  # loss average to print

        kl_loss = 0
        kl_loss_avg_print = 0

        meta_loss_avg_save = []  # meta loss to save
        kl_loss_avg_save = []

        task_count = 0  # counter to decide when a minibatch of tasks is completed, triggering a meta update
        while task_count < num_tasks_per_epoch:
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True)
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class,
                        embedding,
                        class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class,
                        device)

                loss_NLL, KL_loss = get_task_prediction(
                    x_t=x_t, y_t=y_t, x_v=x_v, y_v=y_v, p_dropout=p_base_dropout)

                meta_loss = meta_loss + loss_NLL
                kl_loss = kl_loss + KL_loss

                task_count = task_count + 1
                if torch.isnan(meta_loss).item():
                    sys.exit('nan')

                if task_count % num_tasks_per_minibatch == 0:
                    # average over the number of tasks per minibatch
                    meta_loss = meta_loss / num_tasks_per_minibatch
                    kl_loss = kl_loss / num_tasks_per_minibatch

                    # accumulate for printing purposes
                    meta_loss_avg_print += meta_loss.item()
                    kl_loss_avg_print += kl_loss.item()

                    op_theta.zero_grad()
                    meta_loss.backward(retain_graph=True)
                    # torch.nn.utils.clip_grad_norm_(parameters=theta.values(), max_norm=1)
                    op_theta.step()

                    # printing losses
                    num_meta_updates_count += 1
                    if num_meta_updates_count % num_meta_updates_print == 0:
                        meta_loss_avg_save.append(meta_loss_avg_print / num_meta_updates_count)
                        kl_loss_avg_save.append(kl_loss_avg_print / num_meta_updates_count)
                        print('{0:d}, {1:2.4f}, {2:1.4f}'.format(
                            task_count,
                            meta_loss_avg_save[-1],
                            kl_loss_avg_save[-1]))

                        num_meta_updates_count = 0
                        meta_loss_avg_print = 0
                        kl_loss_avg_print = 0

                    if task_count % num_tasks_save_loss == 0:
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))
                        kl_loss_saved.append(np.mean(kl_loss_avg_save))

                        meta_loss_avg_save = []
                        kl_loss_avg_save = []

                        # if datasource != 'sine_line':
                        #     val_accs = meta_validation(
                        #         datasubset=val_set,
                        #         num_val_tasks=num_val_tasks)
                        #     val_acc = np.mean(val_accs)
                        #     val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                        #     print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        #     val_accuracies.append(val_acc)

                        #     train_accs = meta_validation(
                        #         datasubset=train_set,
                        #         num_val_tasks=num_val_tasks)
                        #     train_acc = np.mean(train_accs)
                        #     train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                        #     print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        #     train_accuracies.append(train_acc)

                    # reset meta loss for the next minibatch of tasks
                    meta_loss = 0
                    kl_loss = 0

                if task_count >= num_tasks_per_epoch:
                    break

        if (epoch + 1) % num_epochs_save == 0:
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'kl_loss': kl_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict(),
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = '{0:s}_{1:d}way_{2:d}shot_{3:d}.pt'.format(
                datasource,
                num_classes_per_task,
                num_training_samples_per_class,
                epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
        # scheduler.step()
        print()
def meta_train():
    if datasource == 'sine_line':
        data_generator = DataGenerator(
            num_samples=num_total_samples_per_class,
            device=device)
        # create dummy sampler
        all_class = [0] * 100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    else:
        all_class = all_class_train
        # all_class.update(all_class_val)
        embedding = embedding_train
        # embedding.update(embedding_val)
        sampler = torch.utils.data.sampler.RandomSampler(
            data_source=list(all_class.keys()),
            replacement=False)
        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)

    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        loss_NLL_saved = []
        kl_loss_saved = []
        d_loss_saved = []
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0  # accumulate the loss of the ensemble networks for the meta update
        num_meta_updates_count = 0

        loss_NLL_v = 0
        loss_NLL_avg_print = 0

        kl_loss = 0
        kl_loss_avg_print = 0

        d_loss = 0
        d_loss_avg_print = 0

        loss_NLL_avg_save = []
        kl_loss_avg_save = []
        d_loss_avg_save = []

        task_count = 0  # counter to decide when a minibatch of tasks is completed, triggering a meta update
        while task_count < num_tasks_per_epoch:
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True)
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class,
                        embedding,
                        class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class,
                        device)

                loss_NLL, KL_loss, discriminator_loss = get_task_prediction(
                    x_t=x_t,
                    y_t=y_t,
                    x_v=x_v,
                    y_v=y_v,
                    p_dropout=p_dropout_base,
                    p_dropout_g=p_dropout_generator,
                    p_dropout_d=p_dropout_discriminator,
                    p_dropout_e=p_dropout_encoder)
                if torch.isnan(loss_NLL).item():
                    sys.exit('nan')

                loss_NLL_v += loss_NLL.item()
                if loss_NLL.item() > 1:
                    # clip the NLL loss value
                    loss_NLL.data = torch.tensor([1.], device=device)
                # if (discriminator_loss.item() > d_loss_const):
                #     discriminator_loss.data = torch.tensor([d_loss_const], device=device)

                kl_loss = kl_loss + KL_loss
                if KL_loss.item() < 0:
                    # keep the KL term non-negative
                    KL_loss.data = torch.tensor([0.], device=device)

                Ri = torch.sqrt((KL_loss + ri_const) / (2 * (total_validation_samples - 1)))
                meta_loss = meta_loss + loss_NLL + Ri
                d_loss = d_loss + discriminator_loss

                task_count = task_count + 1
                if task_count % num_tasks_per_minibatch == 0:
                    # average over the number of tasks per minibatch
                    meta_loss = meta_loss / num_tasks_per_minibatch
                    loss_NLL_v /= num_tasks_per_minibatch
                    kl_loss = kl_loss / num_tasks_per_minibatch
                    d_loss = d_loss / num_tasks_per_minibatch

                    # accumulate for printing purposes
                    loss_NLL_avg_print += loss_NLL_v
                    kl_loss_avg_print += kl_loss.item()
                    d_loss_avg_print += d_loss.item()

                    # adding R0
                    R0 = 0
                    for key in theta.keys():
                        R0 += theta[key].norm(2)
                    R0 = torch.sqrt((L2_regularization * R0 + r0_const)
                                    / (2 * (num_tasks_per_minibatch - 1)))
                    meta_loss += R0

                    # optimize theta
                    op_theta.zero_grad()
                    meta_loss.backward(retain_graph=True)
                    torch.nn.utils.clip_grad_norm_(
                        parameters=theta.values(),
                        max_norm=clip_grad_value)
                    torch.nn.utils.clip_grad_norm_(
                        parameters=w_encoder.values(),
                        max_norm=clip_grad_value)
                    op_theta.step()

                    # optimize the discriminator
                    op_discriminator.zero_grad()
                    d_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        parameters=w_discriminator.values(),
                        max_norm=clip_grad_value)
                    op_discriminator.step()

                    # printing losses
                    num_meta_updates_count += 1
                    if num_meta_updates_count % num_meta_updates_print == 0:
                        loss_NLL_avg_save.append(loss_NLL_avg_print / num_meta_updates_count)
                        kl_loss_avg_save.append(kl_loss_avg_print / num_meta_updates_count)
                        d_loss_avg_save.append(d_loss_avg_print / num_meta_updates_count)
                        print('{0:d}, {1:2.4f}, {2:1.4f}, {3:2.4e}'.format(
                            task_count,
                            loss_NLL_avg_save[-1],
                            kl_loss_avg_save[-1],
                            d_loss_avg_save[-1]))

                        num_meta_updates_count = 0
                        loss_NLL_avg_print = 0
                        kl_loss_avg_print = 0
                        d_loss_avg_print = 0

                    if task_count % num_tasks_save_loss == 0:
                        loss_NLL_saved.append(np.mean(loss_NLL_avg_save))
                        kl_loss_saved.append(np.mean(kl_loss_avg_save))
                        d_loss_saved.append(np.mean(d_loss_avg_save))

                        loss_NLL_avg_save = []
                        kl_loss_avg_save = []
                        d_loss_avg_save = []

                        # if datasource != 'sine_line':
                        #     val_accs = meta_validation(
                        #         datasubset=val_set,
                        #         num_val_tasks=num_val_tasks)
                        #     val_acc = np.mean(val_accs)
                        #     val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                        #     print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        #     val_accuracies.append(val_acc)

                        #     train_accs = meta_validation(
                        #         datasubset=train_set,
                        #         num_val_tasks=num_val_tasks)
                        #     train_acc = np.mean(train_accs)
                        #     train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                        #     print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        #     train_accuracies.append(train_acc)

                    # reset meta loss for the next minibatch of tasks
                    meta_loss = 0
                    kl_loss = 0
                    d_loss = 0
                    loss_NLL_v = 0

                if task_count >= num_tasks_per_epoch:
                    break

        if (epoch + 1) % num_epochs_save == 0:
            checkpoint = {
                'w_discriminator': w_discriminator,
                'theta': theta,
                'w_encoder': w_encoder,
                'w_encoder_2': w_encoder_2,
                'meta_loss': loss_NLL_saved,
                'kl_loss': kl_loss_saved,
                'd_loss': d_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict(),
                'op_discriminator': op_discriminator.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = '{0:s}_{1:d}way_{2:d}shot_{3:d}.pt'.format(
                datasource,
                num_classes_per_task,
                num_training_samples_per_class,
                epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
        # scheduler.step()
        print()
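# --- Illustrative example (not part of the original code) ---------------------
# A numeric sketch of the two regularizers added in the loop above, assuming
# the PAC-Bayes-style form that the constants ri_const / r0_const suggest:
# the per-task term Ri = sqrt((KL + ri_const) / (2 * (m - 1))) and the
# hyper-prior term R0 = sqrt((L2 * sum ||theta_k|| + r0_const) / (2 * (T - 1))).
# The function name and default constants are assumptions for illustration.
import torch

def demo_regularizers(kl_loss, theta, num_validation_samples, num_tasks_per_minibatch,
                      ri_const=1.0, r0_const=1.0, l2_regularization=0.01):
    """Sketch of the Ri / R0 terms; requires m > 1 and T > 1."""
    Ri = torch.sqrt((kl_loss + ri_const) / (2 * (num_validation_samples - 1)))
    norm_sum = sum(p.norm(2) for p in theta.values())
    R0 = torch.sqrt((l2_regularization * norm_sum + r0_const)
                    / (2 * (num_tasks_per_minibatch - 1)))
    return Ri, R0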
def meta_train():
    if datasource == 'sine_line':
        data_generator = DataGenerator(num_samples=num_samples_per_class)
        # create a dummy class list for the data loader
        all_class_train = [0] * 10
    else:
        all_class_train, all_data_train = load_dataset(
            dataset_name=datasource,
            subset=train_set)
        all_class_val, all_data_val = load_dataset(
            dataset_name=datasource,
            subset=val_set)
        all_class_train.update(all_class_val)
        all_data_train.update(all_data_val)

    # initialize data loader
    train_loader = initialize_dataloader(
        all_classes=[class_label for class_label in all_class_train],
        num_classes_per_task=num_classes_per_task)

    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        meta_loss_saved = []  # meta loss to save
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0  # accumulate the loss of the ensemble networks for the meta update
        num_meta_updates_count = 0
        meta_loss_avg_print = 0  # loss average to print

        meta_loss_avg_save = []  # meta loss to save

        task_count = 0  # counter to decide when a minibatch of tasks is completed, triggering a meta update
        while task_count < num_tasks_per_epoch:
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True)
                    x_t = torch.tensor(x_t, dtype=torch.float, device=device)
                    y_t = torch.tensor(y_t, dtype=torch.float, device=device)
                    x_v = torch.tensor(x_v, dtype=torch.float, device=device)
                    y_v = torch.tensor(y_v, dtype=torch.float, device=device)
                else:
                    x_t, y_t, x_v, y_v = get_train_val_task_data(
                        all_classes=all_class_train,
                        all_data=all_data_train,
                        class_labels=class_labels,
                        num_samples_per_class=num_samples_per_class,
                        num_training_samples_per_class=num_training_samples_per_class,
                        device=device)

                w_task = adapt_to_task(x=x_t, y=y_t, w0=theta)
                y_pred = net.forward(x=x_v, w=w_task)
                loss_NLL = loss_fn(input=y_pred, target=y_v)
                if torch.isnan(loss_NLL).item():
                    sys.exit('NaN error')

                # accumulate meta loss
                meta_loss = meta_loss + loss_NLL

                task_count = task_count + 1
                if task_count % num_tasks_per_minibatch == 0:
                    meta_loss = meta_loss / num_tasks_per_minibatch

                    # accumulate into different variables for printing purposes
                    meta_loss_avg_print += meta_loss.item()

                    op_theta.zero_grad()
                    meta_loss.backward()
                    torch.nn.utils.clip_grad_norm_(parameters=theta.values(), max_norm=10)
                    op_theta.step()

                    # printing losses
                    num_meta_updates_count += 1
                    if num_meta_updates_count % num_meta_updates_print == 0:
                        meta_loss_avg_save.append(meta_loss_avg_print / num_meta_updates_count)
                        print('{0:d}, {1:2.4f}'.format(task_count, meta_loss_avg_save[-1]))

                        num_meta_updates_count = 0
                        meta_loss_avg_print = 0

                    if task_count % num_tasks_save_loss == 0:
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))
                        meta_loss_avg_save = []

                        print('Saving loss...')
                        if datasource != 'sine_line':
                            val_accs = validate_classification(
                                all_classes=all_class_val,
                                all_data=all_data_val,
                                num_val_tasks=100,
                                p_rand=0.5,
                                uncertainty=False,
                                csv_flag=False)
                            val_acc = np.mean(val_accs)
                            val_ci95 = 1.96 * np.std(val_accs) / np.sqrt(len(val_accs))
                            print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                            val_accuracies.append(val_acc)

                            train_accs = validate_classification(
                                all_classes=all_class_train,
                                all_data=all_data_train,
                                num_val_tasks=100,
                                p_rand=0.5,
                                uncertainty=False,
                                csv_flag=False)
                            train_acc = np.mean(train_accs)
                            train_ci95 = 1.96 * np.std(train_accs) / np.sqrt(len(train_accs))
                            print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                            train_accuracies.append(train_acc)

                    # reset meta loss
                    meta_loss = 0

                if task_count >= num_tasks_per_epoch:
                    break

        if (epoch + 1) % num_epochs_save == 0:
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = 'Epoch_{0:d}.pt'.format(epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
        scheduler.step()
        print()
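# --- Illustrative example (not part of the original code) ---------------------
# A sketch of the initialize_dataloader helper used above, assuming it mirrors
# the inline loader construction in the other meta_train variants: a
# RandomSampler over class labels, batched into tasks of num_classes_per_task.
# `demo_initialize_dataloader` is a hypothetical name.
import torch

def demo_initialize_dataloader(all_classes, num_classes_per_task):
    """Yield minibatches of class labels, one task's worth per batch."""
    sampler = torch.utils.data.sampler.RandomSampler(
        data_source=all_classes,
        replacement=False)
    return torch.utils.data.DataLoader(
        dataset=all_classes,
        batch_size=num_classes_per_task,
        sampler=sampler,
        drop_last=True)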
def validate_regression(uncertainty_flag, num_val_tasks=1):
    assert datasource == 'sine_line'

    if uncertainty_flag:
        from scipy.special import erf

        quantiles = np.arange(start=0., stop=1.1, step=0.1)
        filename = 'VAMPIRE_calibration_{0:s}_{1:d}shot_{2:d}.csv'.format(
            datasource, num_training_samples_per_class, resume_epoch)
        outfile = open(file=os.path.join('csv', filename), mode='w')
        wr = csv.writer(outfile, quoting=csv.QUOTE_NONE)
    else:
        # visualization
        from matplotlib import pyplot as plt
        num_stds_plot = 2

    data_generator = DataGenerator(num_samples=num_training_samples_per_class)
    std = data_generator.noise_std

    x0 = torch.linspace(start=-5, end=5, steps=100, device=device).view(-1, 1)
    for _ in range(num_val_tasks):
        # throw a coin to see 0 - 'sine' or 1 - 'line'
        binary_flag = np.random.binomial(n=1, p=p_sine)
        if binary_flag == 0:
            # generate sinusoidal data
            x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(noise_flag=True)
            y0 = amp * torch.sin(x0 + phase)
        else:
            # generate line data
            x_t, y_t, slope, intercept = data_generator.generate_line_data(noise_flag=True)
            y0 = slope * x0 + intercept
        x_t = torch.tensor(x_t, dtype=torch.float, device=device)
        y_t = torch.tensor(y_t, dtype=torch.float, device=device)
        y0 = y0.cpu().numpy().reshape(1, -1)  # row vector

        q = adapt_to_task(x=x_t, y=y_t, theta0=theta)
        y_pred = predict(x=x0, q=q, num_models=Lv)
        y_pred = torch.squeeze(y_pred, dim=-1).detach().cpu().numpy()  # Lv x len(x0)

        if uncertainty_flag:
            # each column of y_pred holds Lv samples of the predictive
            # distribution at one x0-value, so take quantiles along axis 0
            y_preds_quantile = np.quantile(a=y_pred, q=quantiles, axis=0, keepdims=False)

            # ground truth cdf
            cal_temp = (1 + erf((y_preds_quantile - y0) / (np.sqrt(2) * std))) / 2
            cal_temp_avg = np.mean(a=cal_temp, axis=1)  # average for a task
            wr.writerow(cal_temp_avg)
        else:
            y_mean = np.mean(a=y_pred, axis=0)
            y_std = np.std(a=y_pred, axis=0)
            y_top = y_mean + num_stds_plot * y_std
            y_bottom = y_mean - num_stds_plot * y_std

            plt.figure(figsize=(4, 4))
            plt.scatter(x_t.cpu().numpy(), y_t.cpu().numpy(), marker='^', label='Training data')
            plt.fill_between(
                x=torch.squeeze(x0).cpu().numpy(),
                y1=y_bottom,
                y2=y_top,
                alpha=0.25,
                zorder=0,
                label='Prediction')
            plt.plot(x0.cpu().numpy(), np.squeeze(y0), linewidth=1, linestyle='--', label='Ground-truth')
            plt.xlabel('x')
            plt.ylabel('y')
            plt.legend()
            plt.tight_layout()
            plt.show()

    if uncertainty_flag:
        outfile.close()
        print('Reliability data is stored at {0:s}'.format(os.path.join('csv', filename)))
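# --- Illustrative example (not part of the original code) ---------------------
# A sketch of post-processing the calibration CSV written by validate_regression
# above: each row is one task's cal_temp_avg over the 11 quantile levels, so
# plotting the column means against the predicted quantiles gives a reliability
# diagram. `demo_reliability_diagram` is a hypothetical name.
import csv
import numpy as np
from matplotlib import pyplot as plt

def demo_reliability_diagram(csv_path):
    quantiles = np.arange(start=0., stop=1.1, step=0.1)
    with open(csv_path) as f:
        rows = np.array([[float(v) for v in row]
                         for row in csv.reader(f) if row])
    plt.plot(quantiles, rows.mean(axis=0), marker='o', label='VAMPIRE')
    plt.plot([0, 1], [0, 1], linestyle='--', label='Perfect calibration')
    plt.xlabel('Predicted quantile')
    plt.ylabel('Empirical CDF')
    plt.legend()
    plt.show()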
def meta_validation(datasubset, num_val_tasks, return_uncertainty=False):
    if datasource == 'sine_line':
        from scipy.special import erf

        x0 = torch.linspace(start=-5, end=5, steps=100, device=device).view(-1, 1)
        cal_avg = 0

        data_generator = DataGenerator(
            num_samples=num_training_samples_per_class,
            device=device)
        for _ in range(num_val_tasks):
            binary_flag = np.random.binomial(n=1, p=p_sine)
            if binary_flag == 0:
                # generate sinusoidal data
                x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(noise_flag=True)
                y0 = amp * torch.sin(x0 + phase)
            else:
                # generate line data
                x_t, y_t, slope, intercept = data_generator.generate_line_data(noise_flag=True)
                y0 = slope * x0 + intercept
            y0 = y0.view(1, -1).cpu().numpy()

            y_preds = get_task_prediction(x_t=x_t, y_t=y_t, x_v=x0)
            y_preds_np = torch.squeeze(y_preds, dim=-1).detach().cpu().numpy()

            # ground truth cdf
            std = data_generator.noise_std
            cal_temp = (1 + erf((y_preds_np - y0) / (np.sqrt(2) * std))) / 2
            cal_temp_avg = np.mean(a=cal_temp, axis=1)
            cal_avg = cal_avg + cal_temp_avg
        cal_avg = cal_avg / num_val_tasks
        return cal_avg
    else:
        accuracies = []
        corrects = []
        probability_pred = []

        total_validation_samples = (num_total_samples_per_class
                                    - num_training_samples_per_class) * num_classes_per_task

        if datasubset == 'train':
            all_class_data = all_class_train
            embedding_data = embedding_train
        elif datasubset == 'val':
            all_class_data = all_class_val
            embedding_data = embedding_val
        elif datasubset == 'test':
            all_class_data = all_class_test
            embedding_data = embedding_test
        else:
            sys.exit('Unknown datasubset for validation')

        all_class_names = list(all_class_data.keys())
        all_task_names = list(itertools.combinations(all_class_names, r=num_classes_per_task))
        if train_flag:
            random.shuffle(all_task_names)

        task_count = 0
        for class_labels in all_task_names:
            x_t, y_t, x_v, y_v = get_task_image_data(
                all_class_data,
                embedding_data,
                class_labels,
                num_total_samples_per_class,
                num_training_samples_per_class,
                device)

            y_pred_v = get_task_prediction(x_t, y_t, x_v, y_v=None)
            y_pred = sm_loss(y_pred_v)

            prob_pred, labels_pred = torch.max(input=y_pred, dim=1)
            correct = (labels_pred == y_v)
            corrects.extend(correct.detach().cpu().numpy())

            accuracy = torch.sum(correct, dim=0).item() / total_validation_samples
            accuracies.append(accuracy)
            probability_pred.extend(prob_pred.detach().cpu().numpy())

            task_count += 1
            if not train_flag:
                print(task_count)
            if task_count >= num_val_tasks:
                break

        if not return_uncertainty:
            return accuracies, all_task_names
        else:
            return corrects, probability_pred
def meta_train(self, train_subset='train'):
    data_generator = DataGenerator(
        num_samples=self.num_total_samples_per_class,
        device=self.device)
    # create dummy sampler
    all_class = [0] * 100
    sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
    train_loader = torch.utils.data.DataLoader(
        dataset=all_class,
        batch_size=self.num_classes_per_task,
        sampler=sampler,
        drop_last=True)

    print('Start to train...')
    for epoch in range(0, self.num_epochs):
        # variables for monitoring
        meta_loss_saved = []
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0  # accumulate the loss of the ensemble networks
        num_meta_updates_count = 0
        meta_loss_avg_print = 0
        # meta_mse_avg_print = 0

        meta_loss_avg_save = []
        # meta_mse_avg_save = []

        task_count = 0
        while task_count < self.num_tasks_per_epoch:
            for class_labels in train_loader:
                # class_labels are dummies here (see the dummy sampler above)
                x_t, y_t, x_v, y_v = get_task_sine_data(
                    data_generator=data_generator,
                    p_sine=self.p_sine,
                    num_training_samples=self.num_training_samples_per_class,
                    noise_flag=True)

                chaser, leader, y_pred = self.get_task_prediction(x_t, y_t, x_v, y_v)
                loss_NLL = self.get_meta_loss(chaser, leader)
                if torch.isnan(loss_NLL).item():
                    sys.exit('NaN error')

                meta_loss = meta_loss + loss_NLL
                # meta_mse = self.loss(y_pred, y_v)

                task_count = task_count + 1
                if task_count % self.num_tasks_per_minibatch == 0:
                    meta_loss = meta_loss / self.num_tasks_per_minibatch
                    # meta_mse = meta_mse/self.num_tasks_per_minibatch

                    # accumulate into different variables for printing purposes
                    meta_loss_avg_print += meta_loss.item()
                    # meta_mse_avg_print += meta_mse.item()

                    self.op_theta.zero_grad()
                    meta_loss.backward()
                    self.op_theta.step()

                    # printing losses
                    num_meta_updates_count += 1
                    if num_meta_updates_count % self.num_meta_updates_print == 0:
                        meta_loss_avg_save.append(meta_loss_avg_print / num_meta_updates_count)
                        # meta_mse_avg_save.append(meta_mse_avg_print/num_meta_updates_count)
                        print('{0:d}, {1:2.4f}'.format(task_count, meta_loss_avg_save[-1]))

                        num_meta_updates_count = 0
                        meta_loss_avg_print = 0
                        # meta_mse_avg_print = 0

                    if task_count % self.num_tasks_save_loss == 0:
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))
                        meta_loss_avg_save = []
                        # meta_mse_avg_save = []

                        # print('Saving loss...')
                        # val_accs, _ = meta_validation(
                        #     datasubset=val_set,
                        #     num_val_tasks=num_val_tasks,
                        #     return_uncertainty=False)
                        # val_acc = np.mean(val_accs)
                        # val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                        # print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        # val_accuracies.append(val_acc)

                        # train_accs, _ = meta_validation(
                        #     datasubset=train_set,
                        #     num_val_tasks=num_val_tasks,
                        #     return_uncertainty=False)
                        # train_acc = np.mean(train_accs)
                        # train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                        # print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        # train_accuracies.append(train_acc)

                    # reset meta loss
                    meta_loss = 0

                if task_count >= self.num_tasks_per_epoch:
                    break

        if (epoch + 1) % self.num_epochs_save == 0:
            checkpoint = {
                'theta': self.theta,
                'meta_loss': meta_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': self.op_theta.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = '{0:s}_{1:d}way_{2:d}shot_{3:d}.pt'.format(
                'sine_line',
                self.num_classes_per_task,
                self.num_training_samples_per_class,
                epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(self.dst_folder, checkpoint_filename))
            print(checkpoint['meta_loss'])
        print()
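# --- Illustrative example (not part of the original code) ---------------------
# A minimal sketch of a chaser-leader meta-loss of the kind get_meta_loss above
# suggests (BMAML-style): the chaser particles are pulled toward the leader
# particles, which are treated as a fixed target via detach(). This exact form
# is an assumption for illustration; the real get_meta_loss may differ.
import torch

def demo_chaser_leader_loss(chaser, leader):
    """Mean squared distance between chaser and (detached) leader particles,
    assuming both are stacked tensors of the same shape."""
    return torch.mean((chaser - leader.detach()) ** 2)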