# Shared imports for the meta_train() variants below. Each variant additionally
# relies on module-level globals and helper functions (DataGenerator,
# get_task_sine_line_data, get_task_image_data, get_task_prediction, ...)
# defined elsewhere in its script.
import os
import sys

import numpy as np
import torch


def meta_train(train_subset=train_set):
    #region PREPARING DATALOADER
    if datasource == 'sine_line':
        data_generator = DataGenerator(num_samples=num_total_samples_per_class, device=device)
        # create dummy sampler
        all_class = [0] * 100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    else:
        all_class = all_class_train
        embedding = embedding_train
        sampler = torch.utils.data.sampler.RandomSampler(
            data_source=list(all_class.keys()), replacement=False)
        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    #endregion

    print('Start to train...')
    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        meta_loss_saved = []  # meta losses to save
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0  # accumulate the loss over many ensembling networks to compute the gradient for the meta update
        num_meta_updates_count = 0
        meta_loss_avg_print = 0  # running loss average to print

        meta_loss_avg_save = []  # meta losses to save

        task_count = 0  # counter deciding when a minibatch of tasks is complete, triggering a meta update

        while task_count < num_tasks_per_epoch:
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True)
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class, embedding, class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class, device)

                loss_NLL = get_task_prediction(x_t, y_t, x_v, y_v)
                if torch.isnan(loss_NLL).item():
                    sys.exit('NaN error')

                # accumulate meta loss
                meta_loss = meta_loss + loss_NLL

                task_count = task_count + 1
                if task_count % num_tasks_per_minibatch == 0:
                    meta_loss = meta_loss / num_tasks_per_minibatch

                    # accumulate into different variables for printing purposes
                    meta_loss_avg_print += meta_loss.item()

                    op_theta.zero_grad()
                    meta_loss.backward()
                    op_theta.step()

                    # print losses
                    num_meta_updates_count += 1
                    if num_meta_updates_count % num_meta_updates_print == 0:
                        meta_loss_avg_save.append(meta_loss_avg_print / num_meta_updates_count)
                        print('{0:d}, {1:2.4f}'.format(task_count, meta_loss_avg_save[-1]))
                        num_meta_updates_count = 0
                        meta_loss_avg_print = 0

                    if task_count % num_tasks_save_loss == 0:
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))
                        meta_loss_avg_save = []

                        # print('Saving loss...')
                        # val_accs, _ = meta_validation(
                        #     datasubset=val_set,
                        #     num_val_tasks=num_val_tasks,
                        #     return_uncertainty=False)
                        # val_acc = np.mean(val_accs)
                        # val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                        # print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        # val_accuracies.append(val_acc)
                        # train_accs, _ = meta_validation(
                        #     datasubset=train_set,
                        #     num_val_tasks=num_val_tasks,
                        #     return_uncertainty=False)
                        # train_acc = np.mean(train_accs)
                        # train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                        # print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        # train_accuracies.append(train_acc)

                    # reset meta loss
                    meta_loss = 0

                if task_count >= num_tasks_per_epoch:
                    break

        if (epoch + 1) % num_epochs_save == 0:
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = '{0:s}_{1:d}way_{2:d}shot_{3:d}.pt'.format(
                datasource, num_classes_per_task,
                num_training_samples_per_class, epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
        print()
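# meta_train() above is configured entirely through module-level globals. A minimal
# sketch of that setup, with illustrative values only (the names are the ones the
# function references; every value below is an assumption, not taken from the source):
#
#   datasource = 'sine_line'
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   num_classes_per_task = 5
#   num_training_samples_per_class = 1
#   num_total_samples_per_class = 20
#   resume_epoch, num_epochs, num_epochs_save = 0, 10000, 1000
#   num_tasks_per_epoch, num_tasks_per_minibatch = 10000, 10
#   num_tasks_save_loss, num_meta_updates_print = 1000, 1
#   p_sine = 0.5
#   dst_folder = './checkpoints'
#   theta = ...  # dict of meta-parameter tensors created with requires_grad=True
#   op_theta = torch.optim.Adam(params=theta.values(), lr=1e-3)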
def meta_train():
    if datasource == 'sine_line':
        data_generator = DataGenerator(
            num_samples=num_total_samples_per_class,
            device=device
        )
        # create dummy sampler
        all_class = [0] * 100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True
        )
    else:
        all_class = all_class_train
        # all_class.update(all_class_val)
        embedding = embedding_train
        # embedding.update(embedding_val)
        sampler = torch.utils.data.sampler.RandomSampler(
            data_source=list(all_class.keys()), replacement=False)
        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True
        )

    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        meta_loss_saved = []  # meta losses to save
        kl_loss_saved = []
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0  # accumulate the loss over many ensembling networks to compute the gradient for the meta update
        num_meta_updates_count = 0
        meta_loss_avg_print = 0  # running loss average to print

        kl_loss = 0
        kl_loss_avg_print = 0

        meta_loss_avg_save = []  # meta losses to save
        kl_loss_avg_save = []

        task_count = 0  # counter deciding when a minibatch of tasks is complete, triggering a meta update
        while task_count < num_tasks_per_epoch:
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True
                    )
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class, embedding, class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class, device
                    )

                loss_NLL, KL_loss = get_task_prediction(
                    x_t=x_t, y_t=y_t, x_v=x_v, y_v=y_v,
                    p_dropout=p_base_dropout
                )
                meta_loss = meta_loss + loss_NLL
                kl_loss = kl_loss + KL_loss

                task_count = task_count + 1

                if torch.isnan(meta_loss).item():
                    sys.exit('NaN error')

                if task_count % num_tasks_per_minibatch == 0:
                    # average over the number of tasks per minibatch
                    meta_loss = meta_loss / num_tasks_per_minibatch
                    kl_loss = kl_loss / num_tasks_per_minibatch

                    # accumulate for printing purposes
                    meta_loss_avg_print += meta_loss.item()
                    kl_loss_avg_print += kl_loss.item()

                    op_theta.zero_grad()
                    meta_loss.backward(retain_graph=True)
                    # torch.nn.utils.clip_grad_norm_(parameters=theta.values(), max_norm=1)
                    op_theta.step()

                    # print losses
                    num_meta_updates_count += 1
                    if num_meta_updates_count % num_meta_updates_print == 0:
                        meta_loss_avg_save.append(meta_loss_avg_print / num_meta_updates_count)
                        kl_loss_avg_save.append(kl_loss_avg_print / num_meta_updates_count)
                        print('{0:d}, {1:2.4f}, {2:1.4f}'.format(
                            task_count,
                            meta_loss_avg_save[-1],
                            kl_loss_avg_save[-1]
                        ))
                        num_meta_updates_count = 0
                        meta_loss_avg_print = 0
                        kl_loss_avg_print = 0

                    if task_count % num_tasks_save_loss == 0:
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))
                        kl_loss_saved.append(np.mean(kl_loss_avg_save))
                        meta_loss_avg_save = []
                        kl_loss_avg_save = []

                        # if datasource != 'sine_line':
                        #     val_accs = meta_validation(
                        #         datasubset=val_set,
                        #         num_val_tasks=num_val_tasks)
                        #     val_acc = np.mean(val_accs)
                        #     val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                        #     print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        #     val_accuracies.append(val_acc)
                        #     train_accs = meta_validation(
                        #         datasubset=train_set,
                        #         num_val_tasks=num_val_tasks)
                        #     train_acc = np.mean(train_accs)
                        #     train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                        #     print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        #     train_accuracies.append(train_acc)

                    # reset meta loss for the next minibatch of tasks
                    meta_loss = 0
                    kl_loss = 0

                if task_count >= num_tasks_per_epoch:
                    break

        if (epoch + 1) % num_epochs_save == 0:
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'kl_loss': kl_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict(),
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = '{0:s}_{1:d}way_{2:d}shot_{3:d}.pt'.format(
                datasource, num_classes_per_task,
                num_training_samples_per_class, epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
        # scheduler.step()
        print()
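# This variant expects get_task_prediction() to return a pair: the validation NLL and
# a KL term. A hypothetical stub matching the call site above, for reference only
# (the body is a placeholder, not the actual implementation):
def get_task_prediction_stub(x_t, y_t, x_v, y_v, p_dropout):
    """Adapt to the support set (x_t, y_t) and score the query set (x_v, y_v)."""
    loss_NLL = torch.zeros(1)  # placeholder: negative log-likelihood on (x_v, y_v)
    KL_loss = torch.zeros(1)   # placeholder: KL(posterior || prior) of the task weights
    return loss_NLL, KL_loss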
def meta_train(params, amine=None):
    # Start by unpacking the variables that we need
    datasource = params['datasource']
    num_total_samples_per_class = params['num_total_samples_per_class']
    device = params['device']
    num_classes_per_task = params['num_classes_per_task']
    num_training_samples_per_class = params['num_training_samples_per_class']
    num_tasks_save_loss = params['num_tasks_save_loss']

    # Epoch variables
    num_epochs = params['num_epochs']
    resume_epoch = params['resume_epoch']
    num_tasks_per_epoch = params['num_tasks_per_epoch']

    # Note we have lowercase theta here vs with PLATIPUS
    theta = params['theta']
    op_theta = params['op_theta']

    # How often should we do a printout?
    num_meta_updates_print = 1
    # How often should we save?
    num_epochs_save = 1000

    if datasource == 'sine_line':
        data_generator = DataGenerator(num_samples=num_total_samples_per_class, device=device)

    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        print(f"Starting epoch {epoch}")

        if datasource == 'drp_chem':
            training_batches = params['training_batches']
            if params['cross_validate']:
                b_num = np.random.choice(len(training_batches[amine]))
                batch = training_batches[amine][b_num]
            else:
                b_num = np.random.choice(len(training_batches))
                batch = training_batches[b_num]
            x_train = torch.from_numpy(batch[0]).float().to(params['device'])
            y_train = torch.from_numpy(batch[1]).long().to(params['device'])
            x_val = torch.from_numpy(batch[2]).float().to(params['device'])
            y_val = torch.from_numpy(batch[3]).long().to(params['device'])

        # variables used to store information of each epoch for monitoring purposes
        meta_loss_saved = []  # meta losses to save
        val_accuracies = []
        train_accuracies = []

        task_count = 0  # counter deciding when a minibatch of tasks is complete, triggering a meta update
        meta_loss = 0  # accumulate the loss over many ensembling networks to compute the gradient for the meta update
        num_meta_updates_count = 0
        meta_loss_avg_print = 0  # running loss average to print
        meta_loss_avg_save = []  # meta losses to save

        while task_count < num_tasks_per_epoch:
            if datasource == 'sine_line':
                p_sine = params['p_sine']
                x_t, y_t, x_v, y_v = get_task_sine_line_data(
                    data_generator=data_generator,
                    p_sine=p_sine,
                    num_training_samples=num_training_samples_per_class,
                    noise_flag=True)
            elif datasource == 'drp_chem':
                x_t, y_t, x_v, y_v = (x_train[task_count], y_train[task_count],
                                      x_val[task_count], y_val[task_count])
            else:
                sys.exit('Unknown dataset')

            loss_NLL = get_task_prediction(x_t, y_t, x_v, params, y_v)
            if torch.isnan(loss_NLL).item():
                sys.exit('NaN error')

            # accumulate meta loss
            meta_loss = meta_loss + loss_NLL

            task_count = task_count + 1
            if task_count % num_tasks_per_epoch == 0:
                meta_loss = meta_loss / num_tasks_per_epoch

                # accumulate into different variables for printing purposes
                meta_loss_avg_print += meta_loss.item()

                op_theta.zero_grad()
                meta_loss.backward()
                # Clip gradients to prevent the exploding-gradient problem
                torch.nn.utils.clip_grad_norm_(parameters=theta.values(), max_norm=3)
                op_theta.step()

                # print losses
                num_meta_updates_count += 1
                if num_meta_updates_count % num_meta_updates_print == 0:
                    meta_loss_avg_save.append(meta_loss_avg_print / num_meta_updates_count)
                    print('{0:d}, {1:2.4f}'.format(task_count, meta_loss_avg_save[-1]))
                    num_meta_updates_count = 0
                    meta_loss_avg_print = 0

                if task_count % num_tasks_save_loss == 0:
                    meta_loss_saved.append(np.mean(meta_loss_avg_save))
                    meta_loss_avg_save = []

                    # print('Saving loss...')
                    # if datasource != 'sine_line':
                    #     val_accs, _ = meta_validation(
                    #         datasubset=val_set,
                    #         num_val_tasks=num_val_tasks,
                    #         return_uncertainty=False)
                    #     val_acc = np.mean(val_accs)
                    #     val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                    #     print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                    #     val_accuracies.append(val_acc)
                    #     train_accs, _ = meta_validation(
                    #         datasubset=train_set,
                    #         num_val_tasks=num_val_tasks,
                    #         return_uncertainty=False)
                    #     train_acc = np.mean(train_accs)
                    #     train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                    #     print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                    #     train_accuracies.append(train_acc)

                # reset meta loss
                meta_loss = 0

            if task_count >= num_tasks_per_epoch:
                break

        if (epoch + 1) % num_epochs_save == 0:
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = '{0:s}_{1:d}way_{2:d}shot_{3:d}.pt'.format(
                datasource, num_classes_per_task,
                num_training_samples_per_class, epoch + 1)
            print(checkpoint_filename)
            dst_folder = params['dst_folder']
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
        print()
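# A sketch of the params dict this variant consumes, with placeholder values. The keys
# are exactly the ones meta_train() reads; the values are assumptions for illustration:
#
#   params = {
#       'datasource': 'drp_chem',
#       'num_total_samples_per_class': 40,
#       'device': torch.device('cpu'),
#       'num_classes_per_task': 2,
#       'num_training_samples_per_class': 20,
#       'num_tasks_save_loss': 100,
#       'num_epochs': 3000,
#       'resume_epoch': 0,
#       'num_tasks_per_epoch': 100,
#       'theta': theta,                       # dict of meta-parameter tensors
#       'op_theta': op_theta,                 # optimizer over theta.values()
#       'training_batches': training_batches, # numpy batches, keyed by amine when cross-validating
#       'cross_validate': True,
#       'p_sine': 0.5,                        # only read by the sine_line branch
#       'dst_folder': './checkpoints',
#   }
#   meta_train(params, amine=some_amine_key)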
def meta_train():
    if datasource == 'sine_line':
        data_generator = DataGenerator(num_samples=num_total_samples_per_class, device=device)
        # create dummy sampler
        all_class = [0] * 100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    else:
        all_class = all_class_train
        # all_class.update(all_class_val)
        embedding = embedding_train
        # embedding.update(embedding_val)
        sampler = torch.utils.data.sampler.RandomSampler(
            data_source=list(all_class.keys()), replacement=False)
        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)

    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        loss_NLL_saved = []
        kl_loss_saved = []
        d_loss_saved = []
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0  # accumulate the loss over many ensembling networks to compute the gradient for the meta update
        num_meta_updates_count = 0

        loss_NLL_v = 0
        loss_NLL_avg_print = 0

        kl_loss = 0
        kl_loss_avg_print = 0

        d_loss = 0
        d_loss_avg_print = 0

        loss_NLL_avg_save = []
        kl_loss_avg_save = []
        d_loss_avg_save = []

        task_count = 0  # counter deciding when a minibatch of tasks is complete, triggering a meta update
        while task_count < num_tasks_per_epoch:
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True)
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class, embedding, class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class, device)

                loss_NLL, KL_loss, discriminator_loss = get_task_prediction(
                    x_t=x_t, y_t=y_t, x_v=x_v, y_v=y_v,
                    p_dropout=p_dropout_base,
                    p_dropout_g=p_dropout_generator,
                    p_dropout_d=p_dropout_discriminator,
                    p_dropout_e=p_dropout_encoder)
                if torch.isnan(loss_NLL).item():
                    sys.exit('NaN error')

                loss_NLL_v += loss_NLL.item()
                # clamp the validation loss to stabilize training
                if loss_NLL.item() > 1:
                    loss_NLL.data = torch.tensor([1.], device=device)
                # if (discriminator_loss.item() > d_loss_const):
                #     discriminator_loss.data = torch.tensor([d_loss_const], device=device)

                kl_loss = kl_loss + KL_loss
                if KL_loss.item() < 0:
                    KL_loss.data = torch.tensor([0.], device=device)

                # per-task regularization term Ri
                Ri = torch.sqrt((KL_loss + ri_const) / (2 * (total_validation_samples - 1)))
                meta_loss = meta_loss + loss_NLL + Ri
                d_loss = d_loss + discriminator_loss

                task_count = task_count + 1

                if task_count % num_tasks_per_minibatch == 0:
                    # average over the number of tasks per minibatch
                    meta_loss = meta_loss / num_tasks_per_minibatch
                    loss_NLL_v /= num_tasks_per_minibatch
                    kl_loss = kl_loss / num_tasks_per_minibatch
                    d_loss = d_loss / num_tasks_per_minibatch

                    # accumulate for printing purposes
                    loss_NLL_avg_print += loss_NLL_v
                    kl_loss_avg_print += kl_loss.item()
                    d_loss_avg_print += d_loss.item()

                    # adding R0, the regularization term on the meta parameters
                    R0 = 0
                    for key in theta.keys():
                        R0 += theta[key].norm(2)
                    R0 = torch.sqrt((L2_regularization * R0 + r0_const) / (2 * (num_tasks_per_minibatch - 1)))
                    meta_loss += R0

                    # optimize theta
                    op_theta.zero_grad()
                    meta_loss.backward(retain_graph=True)
                    torch.nn.utils.clip_grad_norm_(
                        parameters=theta.values(), max_norm=clip_grad_value)
                    torch.nn.utils.clip_grad_norm_(
                        parameters=w_encoder.values(), max_norm=clip_grad_value)
                    op_theta.step()

                    # optimize the discriminator
                    op_discriminator.zero_grad()
                    d_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        parameters=w_discriminator.values(), max_norm=clip_grad_value)
                    op_discriminator.step()

                    # print losses
                    num_meta_updates_count += 1
                    if num_meta_updates_count % num_meta_updates_print == 0:
                        loss_NLL_avg_save.append(loss_NLL_avg_print / num_meta_updates_count)
                        kl_loss_avg_save.append(kl_loss_avg_print / num_meta_updates_count)
                        d_loss_avg_save.append(d_loss_avg_print / num_meta_updates_count)
                        print('{0:d}, {1:2.4f}, {2:1.4f}, {3:2.4e}'.format(
                            task_count,
                            loss_NLL_avg_save[-1],
                            kl_loss_avg_save[-1],
                            d_loss_avg_save[-1]))
                        num_meta_updates_count = 0
                        loss_NLL_avg_print = 0
                        kl_loss_avg_print = 0
                        d_loss_avg_print = 0

                    if task_count % num_tasks_save_loss == 0:
                        loss_NLL_saved.append(np.mean(loss_NLL_avg_save))
                        kl_loss_saved.append(np.mean(kl_loss_avg_save))
                        d_loss_saved.append(np.mean(d_loss_avg_save))
                        loss_NLL_avg_save = []
                        kl_loss_avg_save = []
                        d_loss_avg_save = []

                        # if datasource != 'sine_line':
                        #     val_accs = meta_validation(
                        #         datasubset=val_set,
                        #         num_val_tasks=num_val_tasks)
                        #     val_acc = np.mean(val_accs)
                        #     val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                        #     print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        #     val_accuracies.append(val_acc)
                        #     train_accs = meta_validation(
                        #         datasubset=train_set,
                        #         num_val_tasks=num_val_tasks)
                        #     train_acc = np.mean(train_accs)
                        #     train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                        #     print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        #     train_accuracies.append(train_acc)

                    # reset meta loss for the next minibatch of tasks
                    meta_loss = 0
                    kl_loss = 0
                    d_loss = 0
                    loss_NLL_v = 0

                if task_count >= num_tasks_per_epoch:
                    break

        if (epoch + 1) % num_epochs_save == 0:
            checkpoint = {
                'w_discriminator': w_discriminator,
                'theta': theta,
                'w_encoder': w_encoder,
                'w_encoder_2': w_encoder_2,
                'meta_loss': loss_NLL_saved,
                'kl_loss': kl_loss_saved,
                'd_loss': d_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict(),
                'op_discriminator': op_discriminator.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = '{0:s}_{1:d}way_{2:d}shot_{3:d}.pt'.format(
                datasource, num_classes_per_task,
                num_training_samples_per_class, epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
        # scheduler.step()
        print()
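# The two square-root terms in the variant above (Ri per task, R0 per minibatch) have
# the sqrt((KL + c) / (2 * (n - 1))) shape of a McAllester-style PAC-Bayes complexity
# term. Restated from the in-line arithmetic, with n_v = total_validation_samples and
# T = num_tasks_per_minibatch:
#
#   Ri = sqrt( (KL_loss + ri_const) / (2 * (n_v - 1)) )
#   R0 = sqrt( (L2_regularization * sum_k ||theta[k]||_2 + r0_const) / (2 * (T - 1)) )
#
# Ri penalizes each task posterior's divergence from the prior, and R0 penalizes the
# norm of the meta parameters; both shrink as the corresponding sample count grows.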
def meta_train():
    if datasource == 'sine_line':
        data_generator = DataGenerator(num_samples=num_samples_per_class)
        # create dummy sampler
        all_class_train = [0] * 10
    else:
        all_class_train, all_data_train = load_dataset(
            dataset_name=datasource,
            subset=train_set
        )
        all_class_val, all_data_val = load_dataset(
            dataset_name=datasource,
            subset=val_set
        )
        all_class_train.update(all_class_val)
        all_data_train.update(all_data_val)

    # initialize data loader
    train_loader = initialize_dataloader(
        all_classes=[class_label for class_label in all_class_train],
        num_classes_per_task=num_classes_per_task
    )

    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        meta_loss_saved = []  # meta losses to save
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0  # accumulate the loss over many ensembling networks to compute the gradient for the meta update
        num_meta_updates_count = 0
        meta_loss_avg_print = 0  # running loss average to print
        meta_loss_avg_save = []  # meta losses to save

        task_count = 0  # counter deciding when a minibatch of tasks is complete, triggering a meta update

        while task_count < num_tasks_per_epoch:
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True
                    )
                    x_t = torch.tensor(x_t, dtype=torch.float, device=device)
                    y_t = torch.tensor(y_t, dtype=torch.float, device=device)
                    x_v = torch.tensor(x_v, dtype=torch.float, device=device)
                    y_v = torch.tensor(y_v, dtype=torch.float, device=device)
                else:
                    x_t, y_t, x_v, y_v = get_train_val_task_data(
                        all_classes=all_class_train,
                        all_data=all_data_train,
                        class_labels=class_labels,
                        num_samples_per_class=num_samples_per_class,
                        num_training_samples_per_class=num_training_samples_per_class,
                        device=device
                    )

                # adapt the meta parameters to the current task, then evaluate on the validation split
                w_task = adapt_to_task(x=x_t, y=y_t, w0=theta)
                y_pred = net.forward(x=x_v, w=w_task)
                loss_NLL = loss_fn(input=y_pred, target=y_v)
                if torch.isnan(loss_NLL).item():
                    sys.exit('NaN error')

                # accumulate meta loss
                meta_loss = meta_loss + loss_NLL

                task_count = task_count + 1
                if task_count % num_tasks_per_minibatch == 0:
                    meta_loss = meta_loss / num_tasks_per_minibatch

                    # accumulate into different variables for printing purposes
                    meta_loss_avg_print += meta_loss.item()

                    op_theta.zero_grad()
                    meta_loss.backward()
                    torch.nn.utils.clip_grad_norm_(parameters=theta.values(), max_norm=10)
                    op_theta.step()

                    # print losses
                    num_meta_updates_count += 1
                    if num_meta_updates_count % num_meta_updates_print == 0:
                        meta_loss_avg_save.append(meta_loss_avg_print / num_meta_updates_count)
                        print('{0:d}, {1:2.4f}'.format(
                            task_count,
                            meta_loss_avg_save[-1]
                        ))
                        num_meta_updates_count = 0
                        meta_loss_avg_print = 0

                    if task_count % num_tasks_save_loss == 0:
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))
                        meta_loss_avg_save = []

                        print('Saving loss...')
                        if datasource != 'sine_line':
                            val_accs = validate_classification(
                                all_classes=all_class_val,
                                all_data=all_data_val,
                                num_val_tasks=100,
                                p_rand=0.5,
                                uncertainty=False,
                                csv_flag=False
                            )
                            val_acc = np.mean(val_accs)
                            val_ci95 = 1.96 * np.std(val_accs) / np.sqrt(num_val_tasks)
                            print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                            val_accuracies.append(val_acc)

                            train_accs = validate_classification(
                                all_classes=all_class_train,
                                all_data=all_data_train,
                                num_val_tasks=100,
                                p_rand=0.5,
                                uncertainty=False,
                                csv_flag=False
                            )
                            train_acc = np.mean(train_accs)
                            train_ci95 = 1.96 * np.std(train_accs) / np.sqrt(num_val_tasks)
                            print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                            train_accuracies.append(train_acc)

                    # reset meta loss
                    meta_loss = 0

                if task_count >= num_tasks_per_epoch:
                    break

        if (epoch + 1) % num_epochs_save == 0:
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = 'Epoch_{0:d}.pt'.format(epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
        scheduler.step()
        print()
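# A minimal sketch of restoring one of the checkpoints written above. The helper name
# is hypothetical (not from the source); the dict keys match what meta_train() saves:
def load_meta_checkpoint(checkpoint_path, op_theta, device):
    # Load the checkpoint dict saved by torch.save() in meta_train()
    checkpoint = torch.load(checkpoint_path, map_location=device)
    theta = checkpoint['theta']                       # meta-parameter tensors
    op_theta.load_state_dict(checkpoint['op_theta'])  # optimizer state
    meta_loss_saved = checkpoint['meta_loss']         # logged training losses
    return theta, meta_loss_saved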