def meta_step(prm, model, mb_data_loaders, mb_iterators, loss_criterion):
    """One MAML meta-training step over a meta-batch of tasks.

    For each task: run `prm.n_meta_train_grad_steps` inner gradient steps on
    task-specific "fast weights", then evaluate the adapted weights on a fresh
    batch and accumulate the (differentiable) validation loss as the
    meta-objective.

    Args:
        prm: hyper-parameter namespace (uses n_meta_train_grad_steps, alpha,
             and optionally MAML_Use_Test_Data).
        model: network supporting model(inputs) and model(inputs, weights_dict).
        mb_data_loaders: list of per-task dicts with 'train'/'test' loaders.
        mb_iterators: list of per-task cyclic iterators over the train loaders.
        loss_criterion: loss function (outputs, targets) -> scalar tensor.

    Returns:
        (total_objective, info) where total_objective is the summed
        post-adaptation loss over tasks (kept as a graph tensor so the caller
        can backprop through the inner updates), and info holds
        'sample_count' / 'correct_count' for the validation batches.
    """
    total_objective = 0
    correct_count = 0
    sample_count = 0
    n_tasks_in_mb = len(mb_data_loaders)

    # ----------- loop over tasks in meta-batch ---------------------------#
    for i_task in range(n_tasks_in_mb):
        # Start adaptation from the current meta-parameters.
        fast_weights = OrderedDict(
            (name, param) for (name, param) in model.named_parameters())

        # ----------- inner gradient-steps loop ---------------------------#
        for i_step in range(prm.n_meta_train_grad_steps):
            # get batch variables:
            batch_data = data_gen.get_next_batch_cyclic(
                mb_iterators[i_task], mb_data_loaders[i_task]['train'])
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)

            if i_step == 0:
                # First step: fast_weights still equal the model's own
                # parameters, so the plain forward pass is equivalent.
                outputs = model(inputs)
            else:
                outputs = model(inputs, fast_weights)
            # Empirical loss on the current task:
            task_loss = loss_criterion(outputs, targets)
            # create_graph=True keeps the inner update differentiable so the
            # meta-gradient can flow through it (second-order MAML).
            grads = torch.autograd.grad(task_loss,
                                        fast_weights.values(),
                                        create_graph=True)
            # SGD step on the fast weights (prm.alpha = inner learning rate):
            fast_weights = OrderedDict(
                (name, param - prm.alpha * grad)
                for ((name, param), grad) in zip(fast_weights.items(), grads))
        #  end grad steps loop

        # Sample new (validation) data batch for this task:
        if hasattr(prm, 'MAML_Use_Test_Data') and prm.MAML_Use_Test_Data:
            batch_data = data_gen.get_next_batch_cyclic(
                mb_iterators[i_task], mb_data_loaders[i_task]['test'])
        else:
            batch_data = data_gen.get_next_batch_cyclic(
                mb_iterators[i_task], mb_data_loaders[i_task]['train'])
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)

        # Evaluate adapted weights; this loss carries the meta-gradient.
        outputs = model(inputs, fast_weights)
        total_objective += loss_criterion(outputs, targets)
        correct_count += count_correct(outputs, targets)
        sample_count += inputs.size(0)
    # end loop over tasks in meta-batch

    info = {'sample_count': sample_count, 'correct_count': correct_count}
    return total_objective, info
def run_test_max_posterior(model, test_loader, loss_criterion, prm):
    """Run evaluation with the deterministic (max-posterior) network.

    Weight noise is temporarily disabled (eps_std = 0) so the mean posterior
    weights are used for every forward pass.
    """
    n_test_samples = len(test_loader.dataset)
    model.eval()
    n_correct = 0
    test_loss = 0
    for batch_data in test_loader:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm,
                                                  is_test=True)
        # Switch off stochastic weights for a max-posterior forward pass,
        # then restore the previous noise level.
        saved_eps = model.set_eps_std(0.0)
        outputs = model(inputs)
        model.set_eps_std(saved_eps)
        n_correct += count_correct(outputs, targets)
        test_loss += loss_criterion(outputs, targets)  # sum the mean loss in batch
    test_loss /= n_test_samples
    test_acc = n_correct / n_test_samples
    return {
        'test_acc': test_acc,
        'n_correct': n_correct,
        'test_type': 'max_posterior',
        'n_test_samples': n_test_samples,
        'test_loss': get_value(test_loss)
    }
def run_eval_expected(model, loader, prm):
    """Estimate the expected loss/accuracy by Monte-Carlo averaging.

    Runs `prm.n_MC_eval` stochastic forward passes per batch and averages the
    loss and accuracy over all runs and samples.

    Returns:
        dict with 'acc', 'n_correct', 'n_samples', 'avg_loss'.
    """
    n_samples = len(loader.dataset)
    loss_criterion = get_loss_func(prm)
    model.eval()
    avg_loss = 0.0
    n_correct = 0
    n_MC = prm.n_MC_eval  # number of monte-carlo runs for expected loss estimation
    for batch_data in loader:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        # monte-carlo runs (each forward pass samples new network weights):
        for i_MC in range(n_MC):
            outputs = model(inputs)
            avg_loss += loss_criterion(
                outputs, targets).item()  # sum the loss contributed from batch
            n_correct += count_correct(outputs, targets)
    # Normalize by total number of (run, sample) evaluations:
    avg_loss /= (n_MC * n_samples)
    acc = n_correct / (n_MC * n_samples)
    info = {
        'acc': acc,
        'n_correct': n_correct,
        'n_samples': n_samples,
        'avg_loss': avg_loss
    }
    return info
def run_eval_max_posterior(model, loader, prm):
    """Estimate the loss/accuracy using the mean (max-posterior) network weights.

    Weight noise is disabled (eps_std = 0) for each forward pass and restored
    afterwards, so evaluation is deterministic.

    Returns:
        dict with 'acc', 'n_correct', 'n_samples', 'avg_loss'.
    """
    n_samples = len(loader.dataset)
    loss_criterion = get_loss_func(prm)
    model.eval()
    avg_loss = 0
    n_correct = 0
    for batch_data in loader:
        # extract batch data:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        old_eps_std = model.set_eps_std(0.0)  # test with max-posterior
        outputs = model(inputs)
        model.set_eps_std(old_eps_std)  # return model to normal behaviour
        avg_loss += loss_criterion(
            outputs, targets).item()  # sum the loss contributed from batch
        n_correct += count_correct(outputs, targets)
    avg_loss /= n_samples
    acc = n_correct / n_samples
    info = {
        'acc': acc,
        'n_correct': n_correct,
        'n_samples': n_samples,
        'avg_loss': avg_loss
    }
    return info
def run_eval_majority_vote(model, loader, prm, n_votes=5):
    """Estimate loss/accuracy by majority vote over several draws from the
    network's weight distribution.

    Each of `n_votes` stochastic forward passes casts one vote (its argmax
    class) per sample; the final prediction is the most-voted class.

    Returns:
        dict with 'acc', 'n_correct', 'n_samples', 'avg_loss'.
    """
    loss_criterion = get_loss_func(prm)
    n_samples = len(loader.dataset)
    model.eval()
    # Number of classes is a dataset property — look it up once, not per batch.
    n_labels = data_gen.get_info(prm)['n_classes']
    avg_loss = 0
    n_correct = 0
    for batch_data in loader:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        batch_size = inputs.shape[0]
        votes = torch.zeros((batch_size, n_labels), device=prm.device)
        loss_from_batch = 0.0
        for i_vote in range(n_votes):
            outputs = model(inputs)
            loss_from_batch += loss_criterion(outputs, targets).item()
            # get the index of the max output (predicted class) per sample:
            pred = outputs.data.max(1, keepdim=True)[1]
            for i_sample in range(batch_size):
                pred_val = pred[i_sample].cpu().numpy()[0]
                votes[i_sample, pred_val] += 1
        avg_loss += loss_from_batch / n_votes  # sum the loss contributed from batch
        # find argmax class for each sample (majority of votes):
        majority_pred = votes.max(1, keepdim=True)[1]
        n_correct += majority_pred.eq(
            targets.data.view_as(majority_pred)).cpu().sum()
    avg_loss /= n_samples
    acc = n_correct / n_samples
    info = {
        'acc': acc,
        'n_correct': n_correct,
        'n_samples': n_samples,
        'avg_loss': avg_loss
    }
    return info
def run_eval_avg_vote(model, loader, prm, n_votes=5):
    """Estimate loss/accuracy by averaging the output scores over several
    draws from the network's weight distribution.

    Returns:
        dict with 'acc', 'n_correct', 'n_samples', 'avg_loss'.
    """
    loss_criterion = get_loss_func(prm)
    n_samples = len(loader.dataset)
    model.eval()
    # Number of classes is a dataset property — look it up once, not per batch.
    n_labels = data_gen.get_info(prm)['n_classes']
    avg_loss = 0
    n_correct = 0
    for batch_data in loader:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        # BUG FIX: use the actual batch size. The old
        # `min(prm.test_batch_size, n_samples)` is wrong for a final partial
        # batch and makes `votes += outputs.data` shape-mismatch.
        batch_size = inputs.shape[0]
        votes = torch.zeros((batch_size, n_labels), device=prm.device)
        loss_from_batch = 0.0
        for i_vote in range(n_votes):
            outputs = model(inputs)
            loss_from_batch += loss_criterion(outputs, targets).item()
            votes += outputs.data  # accumulate raw scores
        # argmax of the accumulated (average) scores:
        majority_pred = votes.max(1, keepdim=True)[1]
        n_correct += majority_pred.eq(
            targets.data.view_as(majority_pred)).cpu().sum()
        avg_loss += loss_from_batch / n_votes  # sum the loss contributed from batch
    avg_loss /= n_samples
    acc = n_correct / n_samples
    info = {
        'acc': acc,
        'n_correct': n_correct,
        'n_samples': n_samples,
        'avg_loss': avg_loss
    }
    return info
def run_train_epoch(i_epoch):
    """Train the posterior model for one epoch, logging objective terms.

    Uses module-level state: post_model, train_loader, prm, prior_model,
    loss_criterion, optimizer, lr_schedule, n_train_samples, n_batches and
    the data_* logging lists.
    """
    log_interval = 500
    post_model.train()
    for batch_idx, batch_data in enumerate(train_loader):
        correct_count = 0
        sample_count = 0
        # Accumulate the Monte-Carlo average of empirical loss and complexity:
        n_MC = prm.n_MC
        task_empirical_loss = 0
        task_complexity = 0
        for i_MC in range(n_MC):
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)
            outputs = post_model(inputs)
            curr_empirical_loss = loss_criterion(outputs, targets)
            curr_empirical_loss, curr_complexity = get_bayes_task_objective(
                prm,
                prior_model,
                post_model,
                n_train_samples,
                curr_empirical_loss,
                noised_prior=False)
            task_empirical_loss += (1 / n_MC) * curr_empirical_loss
            task_complexity += (1 / n_MC) * curr_complexity
            correct_count += count_correct(outputs, targets)
            sample_count += inputs.size(0)
        # Total objective = empirical loss + intra-task complexity:
        total_objective = task_empirical_loss + task_complexity
        # Gradient step on the posterior parameters:
        grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)
        # Periodic status print + logging:
        if batch_idx % log_interval == 0:
            batch_acc = correct_count / sample_count
            print(
                cmn.status_string(i_epoch, prm.n_meta_test_epochs, batch_idx,
                                  n_batches, batch_acc,
                                  total_objective.item()) +
                ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}'.format(
                    task_empirical_loss.item(), task_complexity.item()))
            data_objective.append(total_objective.item())
            data_accuracy.append(batch_acc)
            data_emp_loss.append(task_empirical_loss.item())
            data_task_comp.append(task_complexity.item())
    return total_objective.item()
def run_train_epoch(i_epoch):
    """Train the posterior model for one epoch (Bayesian objective).

    Uses module-level state: post_model, train_loader, prm, prior_model,
    loss_criterion, optimizer, lr_schedule, n_train_samples, n_batches.
    """
    complexity_term = 0
    post_model.train()
    for batch_idx, batch_data in enumerate(train_loader):
        # Monte-Carlo average of the empirical loss:
        empirical_loss = 0
        n_MC = prm.n_MC
        for i_MC in range(n_MC):
            # get batch:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)
            # calculate objective:
            outputs = post_model(inputs)
            empirical_loss_c = loss_criterion(outputs, targets)
            empirical_loss += (1 / n_MC) * empirical_loss_c
        # complexity/prior term (only if a prior model is given):
        if prior_model:
            empirical_loss, complexity_term = get_bayes_task_objective(
                prm, prior_model, post_model, n_train_samples, empirical_loss)
        else:
            complexity_term = 0.0
        # Total objective:
        objective = empirical_loss + complexity_term
        # Take gradient step:
        grad_step(objective, optimizer, lr_schedule, prm.lr, i_epoch)
        # Print status:
        log_interval = 500
        if batch_idx % log_interval == 0:
            batch_acc = correct_rate(outputs, targets)
            # BUG FIX: `objective.data[0]` is deprecated (pre-0.4 PyTorch) and
            # raises on 0-dim tensors; use get_value() like the rest of the file.
            print(
                cmn.status_string(i_epoch, prm.num_epochs, batch_idx,
                                  n_batches, batch_acc, get_value(objective)) +
                ' Loss: {:.4}\t Comp.: {:.4}'.format(
                    get_value(empirical_loss), get_value(complexity_term)))
def run_train_epoch(i_epoch, log_mat):
    """Train the posterior for one epoch and optionally log figure data.

    Uses module-level state: post_model, train_loader, prm, prior_model,
    loss_criterion, optimizer, lr_schedule, n_train_samples, n_batches,
    figure_flag, data_loader.
    """
    post_model.train()
    for batch_idx, batch_data in enumerate(train_loader):
        # Batch data:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        batch_size = inputs.shape[0]
        # Monte-Carlo estimate of the per-sample empirical loss:
        n_MC = prm.n_MC
        avg_empiric_loss = torch.zeros(1, device=prm.device)
        for i_MC in range(n_MC):
            outputs = post_model(inputs)
            avg_empiric_loss_curr = (1 / batch_size) * loss_criterion(
                outputs, targets)
            avg_empiric_loss += (1 / n_MC) * avg_empiric_loss_curr
        # Complexity/prior term (zero when no prior model is given):
        if prior_model:
            complexity_term = get_task_complexity(prm, prior_model,
                                                  post_model, n_train_samples,
                                                  avg_empiric_loss)
        else:
            complexity_term = torch.zeros(1, device=prm.device)
        # Total objective and gradient step:
        objective = avg_empiric_loss + complexity_term
        grad_step(objective, optimizer, lr_schedule, prm.lr, i_epoch)
        # Periodic status print:
        log_interval = 1000
        if batch_idx % log_interval == 0:
            batch_acc = correct_rate(outputs, targets)
            print(
                cmn.status_string(i_epoch, prm.num_epochs, batch_idx,
                                  n_batches, batch_acc, objective.item()) +
                ' Loss: {:.4}\t Comp.: {:.4}'.format(
                    avg_empiric_loss.item(), complexity_term.item()))
    # End batch loop
    # save results for epochs-figure:
    if figure_flag and (i_epoch % prm.log_figure['interval_epochs'] == 0):
        save_result_for_figure(post_model, prior_model, data_loader, prm,
                               log_mat, i_epoch)
def get_risk(prior_model, prm, mb_data_loaders, mb_iterators, loss_criterion,
             n_train_tasks):
    """Calculate the average empirical loss of the prior model over the tasks
    in a meta-batch.

    Note: it is OK if some tasks appear several times in the meta-batch.

    Returns:
        avg_empirical_loss: Monte-Carlo estimate averaged over tasks (tensor,
        still attached to the graph so the caller can backprop).
    """
    n_tasks_in_mb = len(mb_data_loaders)
    sum_empirical_loss = 0
    correct_count = 0
    sample_count = 0
    prior_model.train()  # invariant over tasks — set once
    # ----------- loop over tasks in meta-batch ---------------------------#
    for i_task in range(n_tasks_in_mb):
        # get sample-batch data from current task to calculate the empirical
        # loss estimate:
        batch_data = data_gen.get_next_batch_cyclic(
            mb_iterators[i_task], mb_data_loaders[i_task]['train'])
        # Monte-Carlo iterations:
        n_MC = prm.n_MC
        task_empirical_loss = 0
        # ----------- Monte-Carlo loop -----------------------------------#
        for i_MC in range(n_MC):
            # get batch variables:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)
            # Empirical loss on current task:
            outputs = prior_model(inputs)
            curr_empirical_loss = loss_criterion(outputs, targets)
            correct_count += count_correct(outputs, targets)
            sample_count += inputs.size(0)
            task_empirical_loss += (1 / n_MC) * curr_empirical_loss
        # end Monte-Carlo loop
        sum_empirical_loss += task_empirical_loss
    # end loop over tasks in meta-batch
    avg_empirical_loss = (1 / n_tasks_in_mb) * sum_empirical_loss
    return avg_empirical_loss
def run_test(model, test_loader, loss_criterion, prm):
    """Evaluate a (standard, deterministic) model on the test set.

    Returns the test accuracy; also prints loss and error rate.
    """
    model.eval()
    n_correct = 0
    test_loss = 0
    for batch_data in test_loader:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        batch_size = inputs.shape[0]
        outputs = model(inputs)
        # Accumulate the per-sample mean loss of this batch:
        test_loss += (1 / batch_size) * loss_criterion(outputs,
                                                       targets).item()
        n_correct += count_correct(outputs, targets)
    n_test_samples = len(test_loader.dataset)
    n_test_batches = len(test_loader)
    test_loss = test_loss / n_test_batches
    test_acc = n_correct / n_test_samples
    print('\n Standard learning: test loss: {:.4}, test err: {:.3} ( {}/{})\n'.
          format(test_loss, 1 - test_acc, n_correct, n_test_samples))
    return test_acc
def run_test(model, test_loader):
    """Evaluate the model on the test set.

    Uses module-level prm, loss_criterion, data_gen, count_correct.
    Returns the test accuracy; also prints loss and accuracy.
    """
    model.eval()
    test_loss = 0
    n_correct = 0
    for batch_data in test_loader:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm,
                                                  is_test=True)
        outputs = model(inputs)
        test_loss += loss_criterion(outputs,
                                    targets)  # sum the mean loss in batch
        n_correct += count_correct(outputs, targets)
    n_test_samples = len(test_loader.dataset)
    n_test_batches = len(test_loader)
    # BUG FIX: `test_loss.data[0]` is deprecated (pre-0.4 PyTorch) and raises
    # on 0-dim tensors; use get_value() like the rest of the file.
    test_loss = get_value(test_loss) / n_test_batches
    test_acc = n_correct / n_test_samples
    print('\nTest set: Average loss: {:.4}, Accuracy: {:.3} ( {}/{})\n'.
          format(test_loss, test_acc, n_correct, n_test_samples))
    return test_acc
def run_test_majority_vote(model, test_loader, loss_criterion, prm, n_votes=9):
    """Evaluate by majority vote over several draws from the stochastic model.

    Returns:
        dict with 'test_acc', 'n_correct', 'test_type', 'n_test_samples',
        'test_loss'.
    """
    # BUG FIX: this assignment was commented out while `n_test_samples` is
    # used below (NameError at the normalization step).
    n_test_samples = len(test_loader.dataset)
    model.eval()
    test_loss = 0
    n_correct = 0
    for batch_data in test_loader:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm,
                                                  is_test=True)
        batch_size = inputs.shape[0]
        info = data_gen.get_info(prm)
        n_labels = info['n_classes']
        votes = cmn.zeros_gpu((batch_size, n_labels))
        for i_vote in range(n_votes):
            outputs = model(inputs)
            test_loss += loss_criterion(outputs, targets)
            # get the index of the max output (predicted class) per sample:
            pred = outputs.data.max(1, keepdim=True)[1]
            for i_sample in range(batch_size):
                pred_val = pred[i_sample].cpu().numpy()[0]
                votes[i_sample, pred_val] += 1
        # find argmax class for each sample (majority of votes):
        majority_pred = votes.max(1, keepdim=True)[1]
        n_correct += majority_pred.eq(
            targets.data.view_as(majority_pred)).cpu().sum()
    test_loss /= n_test_samples
    test_acc = n_correct / n_test_samples
    info = {
        'test_acc': test_acc,
        'n_correct': n_correct,
        'test_type': 'majority_vote',
        'n_test_samples': n_test_samples,
        'test_loss': get_value(test_loss)
    }
    return info
def run_train_epoch(i_epoch):
    """Standard (non-Bayesian) training loop for a single epoch.

    Uses module-level state: model, train_loader, prm, loss_criterion,
    optimizer, lr_schedule, n_batches.
    """
    log_interval = 500
    model.train()
    for batch_idx, batch_data in enumerate(train_loader):
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        # Forward pass and loss:
        outputs = model(inputs)
        loss = loss_criterion(outputs, targets)
        # Backward pass and parameter update:
        grad_step(loss, optimizer, lr_schedule, prm.lr, i_epoch)
        # Periodic status print:
        if batch_idx % log_interval == 0:
            batch_acc = correct_rate(outputs, targets)
            print(
                cmn.status_string(i_epoch, prm.num_epochs, batch_idx,
                                  n_batches, batch_acc, get_value(loss)))
def run_meta_test_learning(task_model, train_loader):
    """Adapt a task model with a fixed number of gradient steps (meta-test).

    Uses module-level state: prm, data_gen, loss_criterion, task_optimizer.
    Returns the (adapted, in-place) task model.
    """
    task_model.train()
    train_loader_iter = iter(train_loader)
    # Fixed-budget adaptation loop:
    for i_grad_step in range(prm.n_meta_test_grad_steps):
        batch_data = data_gen.get_next_batch_cyclic(train_loader_iter,
                                                    train_loader)
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        # Empirical loss on the adaptation batch:
        outputs = task_model(inputs)
        task_objective = loss_criterion(outputs, targets)
        # One gradient step on the task weights:
        grad_step(task_objective, task_optimizer)
    # end gradient step loop
    return task_model
def run_train_epoch(i_epoch):
    """Train the prior model for one epoch on the prior data loader.

    Uses module-level state: prior_model, data_prior_loader, prm,
    loss_criterion, all_optimizer, lr_schedule.
    """
    prior_model.train()
    for batch_idx, batch_data in enumerate(data_prior_loader):
        correct_count = 0
        sample_count = 0
        # Monte-Carlo average of the empirical loss:
        n_MC = prm.n_MC
        task_empirical_loss = 0
        for i_MC in range(n_MC):
            # get batch:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)
            # Calculate empirical loss:
            outputs = prior_model(inputs)
            curr_empirical_loss = loss_criterion(outputs, targets)
            task_empirical_loss += (1 / n_MC) * curr_empirical_loss
            correct_count += count_correct(outputs, targets)
            sample_count += inputs.size(0)
        # Total objective (no complexity term here):
        total_objective = task_empirical_loss
        # Take gradient step with the posterior:
        grad_step(total_objective, all_optimizer, lr_schedule, prm.lr, i_epoch)
        log_interval = 20
        if batch_idx % log_interval == 0:
            batch_acc = correct_count / sample_count
            # FIX: format the scalar value, not the Tensor object, so the
            # '{:.3f}' spec is robust across PyTorch versions.
            print(
                'number meta batch:{} \t avg_empiric_loss:{:.3f} \t batch accuracy:{:.3f}'
                .format(batch_idx, total_objective.item(), batch_acc))
def run_test_avg_vote(model, test_loader, loss_criterion, prm, n_votes=5):
    """Evaluate by averaging output scores over several stochastic draws.

    Returns:
        dict with 'test_acc', 'n_correct', 'test_type', 'n_test_samples',
        'test_loss'.
    """
    n_test_samples = len(test_loader.dataset)
    model.eval()
    test_loss = 0
    n_correct = 0
    for batch_data in test_loader:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm,
                                                  is_test=True)
        # BUG FIX: use the actual batch size. The old
        # `min(prm.test_batch_size, n_test_samples)` is wrong for a final
        # partial batch and makes `votes += outputs.data` shape-mismatch.
        batch_size = inputs.shape[0]
        info = data_gen.get_info(prm)
        n_labels = info['n_classes']
        votes = cmn.zeros_gpu((batch_size, n_labels))
        for i_vote in range(n_votes):
            outputs = model(inputs)
            test_loss += loss_criterion(outputs, targets)
            votes += outputs.data  # accumulate raw scores
        # argmax of the accumulated (average) scores:
        majority_pred = votes.max(1, keepdim=True)[1]
        n_correct += majority_pred.eq(
            targets.data.view_as(majority_pred)).cpu().sum()
    test_loss /= n_test_samples
    test_acc = n_correct / n_test_samples
    info = {
        'test_acc': test_acc,
        'n_correct': n_correct,
        'test_type': 'AvgVote',
        'n_test_samples': n_test_samples,
        'test_loss': get_value(test_loss)
    }
    return info
def run_train_epoch(i_epoch):
    """Train the posterior for one epoch and return averaged logging info.

    Uses module-level state: post_model, train_loader, prm, prior_model,
    loss_criterion, optimizer, lr_schedule, n_train_samples, n_batches.

    Returns:
        train_info: dict with batch-averaged 'task_comp' and 'total_loss'.
    """
    log_interval = 500
    post_model.train()
    train_info = {}
    train_info["task_comp"] = 0.0
    train_info["total_loss"] = 0.0
    cnt = 0
    for batch_idx, batch_data in enumerate(train_loader):
        cnt += 1
        correct_count = 0
        sample_count = 0
        # Monte-Carlo average of empirical loss and complexity:
        n_MC = prm.n_MC
        task_empirical_loss = 0
        task_complexity = 0
        for i_MC in range(n_MC):
            # get batch:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)
            # Calculate empirical loss:
            outputs = post_model(inputs)
            curr_empirical_loss = loss_criterion(outputs, targets)
            # hyper_kl = 0 when testing
            curr_empirical_loss, curr_complexity, task_info = get_bayes_task_objective(
                prm,
                prior_model,
                post_model,
                n_train_samples,
                curr_empirical_loss,
                noised_prior=False)
            task_empirical_loss += (1 / n_MC) * curr_empirical_loss
            task_complexity += (1 / n_MC) * curr_complexity
            correct_count += count_correct(outputs, targets)
            sample_count += inputs.size(0)
        # Total objective (weighted complexity):
        total_objective = task_empirical_loss + prm.task_complex_w * task_complexity
        # BUG FIX: `.data[0]` is deprecated (pre-0.4 PyTorch) and raises on
        # 0-dim tensors; use `.item()` instead, here and in the log below.
        train_info["task_comp"] += task_complexity.item()
        train_info["total_loss"] += total_objective.item()
        # Take gradient step with the posterior:
        grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)
        # Print status:
        if batch_idx % log_interval == 0:
            batch_acc = correct_count / sample_count
            write_to_log(
                cmn.status_string(i_epoch, prm.n_meta_test_epochs, batch_idx,
                                  n_batches, batch_acc,
                                  total_objective.item()) +
                ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}, w_kld {:.4}, b_kld {:.4}'
                .format(task_empirical_loss.item(), task_complexity.item(),
                        task_info["w_kld"], task_info["b_kld"]), prm)
    train_info["task_comp"] /= cnt
    train_info["total_loss"] /= cnt
    return train_info
def get_objective(prior_model, prm, mb_data_loaders, mb_iterators,
                  mb_posteriors_models, loss_criterion, n_train_tasks):
    """Calculate the meta-learning objective based on the tasks in a meta-batch.

    Note: it is OK if some tasks appear several times in the meta-batch.

    Returns:
        (total_objective, info): the PAC-Bayes style objective
        (empirical loss + weighted intra-task complexity + weighted
        meta-complexity) and a dict of logging scalars.
    """
    n_tasks_in_mb = len(mb_data_loaders)
    sum_empirical_loss = 0
    sum_intra_task_comp = 0
    correct_count = 0
    sample_count = 0

    # KLD between hyper-posterior and hyper-prior (L2 regularization on the
    # prior model's parameters):
    hyper_kl = (1 / (2 * prm.kappa_prior**2)) * net_norm(prior_model, p=2)
    # Hyper-prior term:
    meta_complex_term = get_meta_complexity_term(hyper_kl, prm, n_train_tasks)
    sum_w_kld = 0.0
    sum_b_kld = 0.0
    # ----------- loop over tasks in meta-batch ---------------------------#
    for i_task in range(n_tasks_in_mb):
        n_samples = mb_data_loaders[i_task]['n_train_samples']
        # get sample-batch data from current task to calculate the empirical
        # loss estimate:
        batch_data = data_gen.get_next_batch_cyclic(
            mb_iterators[i_task], mb_data_loaders[i_task]['train'])
        # The posterior model corresponding to the task in the batch:
        post_model = mb_posteriors_models[i_task]
        post_model.train()
        # Monte-Carlo iterations:
        n_MC = prm.n_MC
        task_empirical_loss = 0
        task_complexity = 0
        # ----------- Monte-Carlo loop ------------------------------------#
        for i_MC in range(n_MC):
            # get batch variables:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)
            # Empirical loss on current task:
            outputs = post_model(inputs)
            curr_empirical_loss = loss_criterion(outputs, targets)
            correct_count += count_correct(outputs, targets)
            sample_count += inputs.size(0)
            # Intra-task complexity of current task:
            curr_empirical_loss, curr_complexity, task_info = get_bayes_task_objective(
                prm,
                prior_model,
                post_model,
                n_samples,
                curr_empirical_loss,
                hyper_kl,
                n_train_tasks=n_train_tasks)
            sum_w_kld += task_info["w_kld"]
            sum_b_kld += task_info["b_kld"]
            task_empirical_loss += (1 / n_MC) * curr_empirical_loss
            task_complexity += (1 / n_MC) * curr_complexity
        # end Monte-Carlo loop
        sum_empirical_loss += task_empirical_loss
        sum_intra_task_comp += task_complexity
    # end loop over tasks in meta-batch
    avg_empirical_loss = (1 / n_tasks_in_mb) * sum_empirical_loss
    avg_intra_task_comp = (1 / n_tasks_in_mb) * sum_intra_task_comp
    # BUG FIX: these were '+=' on names that were never initialized, which
    # raises NameError; plain assignment is intended.
    avg_w_kld = (1 / n_tasks_in_mb) * sum_w_kld
    avg_b_kld = (1 / n_tasks_in_mb) * sum_b_kld
    # Approximated total objective:
    total_objective = avg_empirical_loss + prm.task_complex_w * avg_intra_task_comp + prm.meta_complex_w * meta_complex_term
    info = {
        'sample_count': get_value(sample_count),
        'correct_count': get_value(correct_count),
        'avg_empirical_loss': get_value(avg_empirical_loss),
        'avg_intra_task_comp': get_value(avg_intra_task_comp),
        'meta_comp': get_value(meta_complex_term),
        'w_kld': avg_w_kld,
        'b_kld': avg_b_kld
    }
    return total_objective, info
def get_objective(prior_model, prm, mb_data_loaders, mb_iterators, mb_posteriors_models, loss_criterion, n_train_tasks): ''' Calculate objective based on tasks in meta-batch ''' # note: it is OK if some tasks appear several times in the meta-batch n_tasks_in_mb = len(mb_data_loaders) correct_count = 0 sample_count = 0 # Hyper-prior term: hyper_dvrg = get_hyper_divergnce(prm, prior_model) meta_complex_term = get_meta_complexity_term(hyper_dvrg, prm, n_train_tasks) avg_empiric_loss_per_task = torch.zeros(n_tasks_in_mb, device=prm.device) complexity_per_task = torch.zeros(n_tasks_in_mb, device=prm.device) n_samples_per_task = torch.zeros( n_tasks_in_mb, device=prm.device ) # how many sampels there are total in each task (not just in a batch) # ----------- loop over tasks in meta-batch -----------------------------------# for i_task in range(n_tasks_in_mb): n_samples = mb_data_loaders[i_task]['n_train_samples'] n_samples_per_task[i_task] = n_samples # get sample-batch data from current task to calculate the empirical loss estimate: batch_data = data_gen.get_next_batch_cyclic( mb_iterators[i_task], mb_data_loaders[i_task]['train']) # get batch variables: inputs, targets = data_gen.get_batch_vars(batch_data, prm) batch_size = inputs.shape[0] # The posterior model corresponding to the task in the batch: post_model = mb_posteriors_models[i_task] post_model.train() # Monte-Carlo iterations: n_MC = prm.n_MC avg_empiric_loss = 0.0 complexity = 0.0 # Monte-Carlo loop for i_MC in range(n_MC): # Debug # print(targets[0].data[0]) # print first image label # import matplotlib.pyplot as plt # plt.imshow(inputs[0].cpu().data[0].numpy()) # show first image # plt.show() # Empirical Loss on current task: outputs = post_model(inputs) avg_empiric_loss_curr = (1 / batch_size) * loss_criterion( outputs, targets) correct_count += count_correct(outputs, targets) # for print sample_count += inputs.size(0) # Intra-task complexity of current task: # curr_complexity = get_task_complexity(prm, 
prior_model, post_model, # n_samples, avg_empiric_loss_curr, hyper_dvrg, n_train_tasks=n_train_tasks, noised_prior=True) avg_empiric_loss += (1 / n_MC) * avg_empiric_loss_curr # complexity += (1 / n_MC) * curr_complexity # end Monte-Carlo loop complexity = get_task_complexity(prm, prior_model, post_model, n_samples, avg_empiric_loss, hyper_dvrg, n_train_tasks=n_train_tasks, noised_prior=True) avg_empiric_loss_per_task[i_task] = avg_empiric_loss complexity_per_task[i_task] = complexity # end loop over tasks in meta-batch # Approximated total objective: if prm.complexity_type == 'Variational_Bayes': # note that avg_empiric_loss_per_task is estimated by an average over batch samples, # but its weight in the objective should be considered by how many samples there are total in the task total_objective =\ (avg_empiric_loss_per_task * n_samples_per_task + complexity_per_task).mean() * n_train_tasks + meta_complex_term # total_objective = ( avg_empiric_loss_per_task * n_samples_per_task + complexity_per_task).mean() + meta_complex_term else: total_objective =\ avg_empiric_loss_per_task.mean() + complexity_per_task.mean() + meta_complex_term info = { 'sample_count': sample_count, 'correct_count': correct_count, 'avg_empirical_loss': avg_empiric_loss_per_task.mean().item(), 'avg_intra_task_comp': complexity_per_task.mean().item(), 'meta_comp': meta_complex_term.item() } return total_objective, info
def run_train_epoch(i_epoch):
    """Train the posterior model for one epoch (MC-averaged loss, complexity
    computed once per batch).

    Uses module-level state: post_model, train_loader, prm, prior_model,
    loss_criterion, optimizer, lr_schedule, n_train_samples, n_batches.
    """
    log_interval = 500
    post_model.train()
    for batch_idx, batch_data in enumerate(train_loader):
        # Batch data:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        batch_size = inputs.shape[0]
        correct_count = 0
        sample_count = 0
        # Monte-Carlo average of the per-sample empirical loss:
        n_MC = prm.n_MC
        avg_empiric_loss = 0
        complexity_term = 0
        for i_MC in range(n_MC):
            outputs = post_model(inputs)
            avg_empiric_loss_curr = (1 / batch_size) * loss_criterion(
                outputs, targets)
            avg_empiric_loss += (1 / n_MC) * avg_empiric_loss_curr
            correct_count += count_correct(outputs, targets)
            sample_count += inputs.size(0)
        # end monte-carlo loop
        # Complexity term, computed once from the MC-averaged loss:
        complexity_term = get_task_complexity(prm, prior_model, post_model,
                                              n_train_samples,
                                              avg_empiric_loss)
        # Approximated total objective (for current batch):
        if prm.complexity_type == 'Variational_Bayes':
            # note that avg_empiric_loss_per_task is estimated by an average
            # over batch samples, but its weight in the objective should be
            # considered by how many samples there are total in the task
            total_objective = avg_empiric_loss * (
                n_train_samples) + complexity_term
        else:
            total_objective = avg_empiric_loss + complexity_term
        # Gradient step on the posterior parameters:
        grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)
        # Periodic status print:
        if batch_idx % log_interval == 0:
            batch_acc = correct_count / sample_count
            print(
                cmn.status_string(i_epoch, prm.n_meta_test_epochs, batch_idx,
                                  n_batches, batch_acc,
                                  total_objective.item()) +
                ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}'.format(
                    avg_empiric_loss.item(), complexity_term.item()))