def run_train_epoch(i_epoch):
    log_interval = 500

    post_model.train()

    for batch_idx, batch_data in enumerate(train_loader):
        correct_count = 0
        sample_count = 0

        # Monte-Carlo iterations:
        n_MC = prm.n_MC
        task_empirical_loss = 0
        task_complexity = 0
        for i_MC in range(n_MC):
            # get batch:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)

            # Calculate empirical loss:
            outputs = post_model(inputs)
            curr_empirical_loss = loss_criterion(outputs, targets)

            curr_empirical_loss, curr_complexity = get_bayes_task_objective(
                prm, prior_model, post_model, n_train_samples,
                curr_empirical_loss, noised_prior=False)

            task_empirical_loss += (1 / n_MC) * curr_empirical_loss
            task_complexity += (1 / n_MC) * curr_complexity

            correct_count += count_correct(outputs, targets)
            sample_count += inputs.size(0)

        # Total objective:
        total_objective = task_empirical_loss + task_complexity

        # Take gradient step with the posterior:
        grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)

        # Print status:
        if batch_idx % log_interval == 0:
            batch_acc = correct_count / sample_count
            print(cmn.status_string(i_epoch, prm.n_meta_test_epochs, batch_idx,
                                    n_batches, batch_acc, total_objective.item()) +
                  ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}'.format(
                      task_empirical_loss.item(), task_complexity.item()))
            data_objective.append(total_objective.item())
            data_accuracy.append(batch_acc)
            data_emp_loss.append(task_empirical_loss.item())
            data_task_comp.append(task_complexity.item())

    return total_objective.item()
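# Note: `count_correct` is a helper assumed by the functions in this section but not
# defined here. A minimal sketch, assuming multi-class logits and integer class targets
# (a hypothetical implementation, not necessarily the one used in the codebase):
def count_correct(outputs, targets):
    ''' Count samples whose argmax prediction matches the target class. '''
    pred = outputs.argmax(dim=1)
    return (pred == targets).sum().item()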
def run_train_epoch(i_epoch):
    # # Adjust randomness (eps_std)
    # if hasattr(prm, 'use_randomness_schedeule') and prm.use_randomness_schedeule:
    #     if i_epoch > prm.randomness_full_epoch:
    #         eps_std = 1.0
    #     elif i_epoch > prm.randomness_init_epoch:
    #         eps_std = (i_epoch - prm.randomness_init_epoch) / (prm.randomness_full_epoch - prm.randomness_init_epoch)
    #     else:
    #         eps_std = 0.0  # turn off randomness
    #     post_model.set_eps_std(eps_std)

    # post_model.set_eps_std(0.00)  # debug

    complexity_term = 0

    post_model.train()

    for batch_idx, batch_data in enumerate(train_loader):

        # Monte-Carlo iterations:
        empirical_loss = 0
        n_MC = prm.n_MC
        for i_MC in range(n_MC):
            # get batch:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)

            # calculate objective:
            outputs = post_model(inputs)
            empirical_loss_c = loss_criterion(outputs, targets)
            empirical_loss += (1 / n_MC) * empirical_loss_c

        # complexity/prior term:
        if prior_model:
            empirical_loss, complexity_term = get_bayes_task_objective(
                prm, prior_model, post_model, n_train_samples, empirical_loss)
        else:
            complexity_term = 0.0

        # Total objective:
        objective = empirical_loss + complexity_term

        # Take gradient step:
        grad_step(objective, optimizer, lr_schedule, prm.lr, i_epoch)

        # Print status:
        log_interval = 500
        if batch_idx % log_interval == 0:
            batch_acc = correct_rate(outputs, targets)
            print(cmn.status_string(i_epoch, prm.num_epochs, batch_idx, n_batches,
                                    batch_acc, objective.item()) +
                  ' Loss: {:.4}\t Comp.: {:.4}'.format(
                      get_value(empirical_loss), get_value(complexity_term)))
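# The commented-out block at the top of the function above sketches an eps_std
# (posterior randomness) schedule. A standalone version of that same schedule, as a
# hypothetical helper (the `randomness_init_epoch` / `randomness_full_epoch` fields
# are taken from the commented code and are assumptions here):
def calc_eps_std(i_epoch, prm):
    ''' Linearly ramp the posterior noise scale from 0 to 1 between two epochs. '''
    if i_epoch > prm.randomness_full_epoch:
        return 1.0
    elif i_epoch > prm.randomness_init_epoch:
        return (i_epoch - prm.randomness_init_epoch) / \
               (prm.randomness_full_epoch - prm.randomness_init_epoch)
    else:
        return 0.0  # turn off randomness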
def get_objective(prior_model, prm, mb_data_loaders, mb_iterators,
                  mb_posteriors_models, loss_criterion, n_train_tasks):
    ''' Calculate objective based on tasks in meta-batch '''
    # note: it is OK if some tasks appear several times in the meta-batch

    n_tasks_in_mb = len(mb_data_loaders)

    sum_empirical_loss = 0
    sum_intra_task_comp = 0
    correct_count = 0
    sample_count = 0
    # set_trace()

    # KLD between hyper-posterior and hyper-prior:
    hyper_kl = (1 / (2 * prm.kappa_prior**2)) * net_norm(prior_model, p=2)  # net_norm is L2-regularization

    # Hyper-prior term:
    meta_complex_term = get_meta_complexity_term(hyper_kl, prm, n_train_tasks)

    sum_w_kld = 0.0
    sum_b_kld = 0.0
    # ----------- loop over tasks in meta-batch -----------------------------------#
    for i_task in range(n_tasks_in_mb):

        n_samples = mb_data_loaders[i_task]['n_train_samples']

        # get sample-batch data from current task to calculate the empirical loss estimate:
        batch_data = data_gen.get_next_batch_cyclic(
            mb_iterators[i_task], mb_data_loaders[i_task]['train'])

        # The posterior model corresponding to the task in the batch:
        post_model = mb_posteriors_models[i_task]
        post_model.train()

        # Monte-Carlo iterations:
        n_MC = prm.n_MC
        task_empirical_loss = 0
        task_complexity = 0
        # ----------- Monte-Carlo loop -----------------------------------#
        for i_MC in range(n_MC):
            # get batch variables:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)

            # Debug:
            # print(targets[0].data[0])  # print first image label
            # import matplotlib.pyplot as plt
            # plt.imshow(inputs[0].cpu().data[0].numpy())  # show first image
            # plt.show()

            # Empirical Loss on current task:
            outputs = post_model(inputs)
            curr_empirical_loss = loss_criterion(outputs, targets)

            correct_count += count_correct(outputs, targets)
            sample_count += inputs.size(0)

            # Intra-task complexity of current task:
            curr_empirical_loss, curr_complexity, task_info = get_bayes_task_objective(
                prm, prior_model, post_model, n_samples, curr_empirical_loss,
                hyper_kl, n_train_tasks=n_train_tasks)
            sum_w_kld += task_info["w_kld"]
            sum_b_kld += task_info["b_kld"]

            task_empirical_loss += (1 / n_MC) * curr_empirical_loss
            task_complexity += (1 / n_MC) * curr_complexity
        # end Monte-Carlo loop

        sum_empirical_loss += task_empirical_loss
        sum_intra_task_comp += task_complexity
    # end loop over tasks in meta-batch

    avg_empirical_loss = (1 / n_tasks_in_mb) * sum_empirical_loss
    avg_intra_task_comp = (1 / n_tasks_in_mb) * sum_intra_task_comp
    avg_w_kld = (1 / n_tasks_in_mb) * sum_w_kld
    avg_b_kld = (1 / n_tasks_in_mb) * sum_b_kld

    # Approximated total objective:
    total_objective = (avg_empirical_loss
                       + prm.task_complex_w * avg_intra_task_comp
                       + prm.meta_complex_w * meta_complex_term)

    info = {'sample_count': get_value(sample_count),
            'correct_count': get_value(correct_count),
            'avg_empirical_loss': get_value(avg_empirical_loss),
            'avg_intra_task_comp': get_value(avg_intra_task_comp),
            'meta_comp': get_value(meta_complex_term),
            'w_kld': avg_w_kld,
            'b_kld': avg_b_kld}
    return total_objective, info
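# A rough sketch of how `get_objective` might be driven for a single meta-training
# step. The surrounding meta-training loop is not shown in this section, so
# `meta_train_step` and `all_optimizer` (an optimizer over the prior and all task
# posteriors) are assumptions for illustration only.
def meta_train_step(prior_model, prm, mb_data_loaders, mb_iterators,
                    mb_posteriors_models, loss_criterion, n_train_tasks,
                    all_optimizer):
    ''' One gradient step on the meta-batch objective (hypothetical driver). '''
    all_optimizer.zero_grad()
    total_objective, info = get_objective(
        prior_model, prm, mb_data_loaders, mb_iterators,
        mb_posteriors_models, loss_criterion, n_train_tasks)
    total_objective.backward()
    all_optimizer.step()
    return total_objective.item(), info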
def run_train_epoch(i_epoch):
    log_interval = 500

    post_model.train()

    train_info = {}
    train_info["task_comp"] = 0.0
    train_info["total_loss"] = 0.0
    cnt = 0
    for batch_idx, batch_data in enumerate(train_loader):
        cnt += 1
        correct_count = 0
        sample_count = 0

        # Monte-Carlo iterations:
        n_MC = prm.n_MC
        task_empirical_loss = 0
        task_complexity = 0
        for i_MC in range(n_MC):
            # get batch:
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)

            # Calculate empirical loss:
            outputs = post_model(inputs)
            curr_empirical_loss = loss_criterion(outputs, targets)

            # hyper_kl = 0 when testing
            curr_empirical_loss, curr_complexity, task_info = get_bayes_task_objective(
                prm, prior_model, post_model, n_train_samples,
                curr_empirical_loss, noised_prior=False)

            task_empirical_loss += (1 / n_MC) * curr_empirical_loss
            task_complexity += (1 / n_MC) * curr_complexity

            correct_count += count_correct(outputs, targets)
            sample_count += inputs.size(0)

        # Total objective:
        total_objective = task_empirical_loss + prm.task_complex_w * task_complexity

        train_info["task_comp"] += task_complexity.item()
        train_info["total_loss"] += total_objective.item()

        # Take gradient step with the posterior:
        grad_step(total_objective, optimizer, lr_schedule, prm.lr, i_epoch)

        # Print status:
        if batch_idx % log_interval == 0:
            batch_acc = correct_count / sample_count
            write_to_log(
                cmn.status_string(i_epoch, prm.n_meta_test_epochs, batch_idx,
                                  n_batches, batch_acc, total_objective.item()) +
                ' Empiric Loss: {:.4}\t Intra-Comp. {:.4}, w_kld {:.4}, b_kld {:.4}'.format(
                    task_empirical_loss.item(), task_complexity.item(),
                    task_info["w_kld"], task_info["b_kld"]), prm)

    train_info["task_comp"] /= cnt
    train_info["total_loss"] /= cnt
    return train_info
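# A minimal sketch of how the meta-test training epochs above might be driven and
# their per-epoch averages collected. The outer loop is assumed, not shown in this
# section; `run_meta_test_training` is a hypothetical name.
def run_meta_test_training():
    ''' Run all meta-test training epochs and collect the per-epoch train_info dicts. '''
    all_epochs_info = []
    for i_epoch in range(prm.n_meta_test_epochs):
        epoch_info = run_train_epoch(i_epoch)
        all_epochs_info.append(epoch_info)
    return all_epochs_info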