def run_train_epoch(i_epoch):
    """Train ``post_model`` for one pass over ``train_loader``.

    For every batch the empirical loss is averaged over ``prm.n_MC``
    forward passes of ``post_model`` (Monte-Carlo iterations). When
    ``prior_model`` is truthy, ``get_bayes_task_objective`` supplies both the
    (possibly updated) empirical loss and a complexity term; otherwise the
    complexity term is ``0.0``. The sum of the two is handed to ``grad_step``.
    A status line is printed every 500 batches.

    All collaborators (``prm``, ``post_model``, ``prior_model``,
    ``train_loader``, ``data_gen``, ``loss_criterion``, ``optimizer``,
    ``lr_schedule``, ``n_train_samples``, ``n_batches``, ``grad_step``,
    ``correct_rate``, ``get_value``, ``cmn``) are defined outside this
    function.

    Args:
        i_epoch: Index of the current epoch; forwarded to ``grad_step`` and
            shown in the status line.

    Returns:
        None.
    """
    # NOTE: a randomness (eps_std) schedule based on
    # prm.use_randomness_schedeule used to be applied here; it is disabled.

    log_interval = 500

    post_model.train()

    for batch_idx, batch_data in enumerate(train_loader):

        # Monte-Carlo iterations:
        empirical_loss = 0
        n_MC = prm.n_MC
        for i_MC in range(n_MC):
            # get batch:
            # NOTE(review): the same batch_data is converted on every
            # Monte-Carlo iteration; if get_batch_vars is deterministic this
            # call could be hoisted out of the loop - confirm before changing.
            inputs, targets = data_gen.get_batch_vars(batch_data, prm)

            # calculate objective:
            outputs = post_model(inputs)
            empirical_loss_c = loss_criterion(outputs, targets)
            empirical_loss += (1 / n_MC) * empirical_loss_c

        # complexity/prior term:
        if prior_model:
            empirical_loss, complexity_term = get_bayes_task_objective(
                prm, prior_model, post_model, n_train_samples, empirical_loss)
        else:
            complexity_term = 0.0

        # Total objective:
        objective = empirical_loss + complexity_term

        # Take gradient step:
        grad_step(objective, optimizer, lr_schedule, prm.lr, i_epoch)

        # Print status:
        if batch_idx % log_interval == 0:
            # `outputs` is the result of the last Monte-Carlo iteration.
            batch_acc = correct_rate(outputs, targets)
            # Use get_value() for the objective as well: indexing a 0-dim
            # tensor with `.data[0]` fails on current PyTorch releases, and
            # the other scalars in this message already go through get_value.
            print(
                cmn.status_string(i_epoch, prm.num_epochs, batch_idx,
                                  n_batches, batch_acc, get_value(objective))
                + ' Loss: {:.4}\t Comp.: {:.4}'.format(
                    get_value(empirical_loss), get_value(complexity_term)))
def run_train_epoch(i_epoch, log_mat):
    """Run one training epoch of ``post_model`` and optionally log a figure point.

    Per batch: the batch-averaged loss is accumulated over ``prm.n_MC``
    forward passes, a complexity term is added (from ``get_task_complexity``
    when ``prior_model`` is truthy, otherwise a zero tensor), and the total is
    passed to ``grad_step``. A status line is printed every 1000 batches.

    After the batch loop, if ``figure_flag`` is set and ``i_epoch`` is a
    multiple of ``prm.log_figure['interval_epochs']``,
    ``save_result_for_figure`` is called.

    ``prm``, ``post_model``, ``prior_model``, ``train_loader``,
    ``data_loader``, ``figure_flag`` and the helper functions used here are
    defined outside this function.

    Args:
        i_epoch: Index of the current epoch.
        log_mat: Passed through unchanged to ``save_result_for_figure``.

    Returns:
        None.
    """
    post_model.train()

    for batch_idx, batch_data in enumerate(train_loader):
        # Fetch the batch once; it is reused by every Monte-Carlo pass.
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        batch_size = inputs.shape[0]

        # Accumulate the Monte-Carlo average of the per-sample loss in place.
        n_MC = prm.n_MC
        avg_empiric_loss = torch.zeros(1, device=prm.device)
        for _ in range(n_MC):
            outputs = post_model(inputs)
            pass_loss = (1 / batch_size) * loss_criterion(outputs, targets)
            avg_empiric_loss += (1 / n_MC) * pass_loss

        # Complexity/prior term (zero when no prior model is given).
        complexity_term = (
            get_task_complexity(prm, prior_model, post_model,
                                n_train_samples, avg_empiric_loss)
            if prior_model
            else torch.zeros(1, device=prm.device)
        )

        # Total objective and gradient step.
        objective = avg_empiric_loss + complexity_term
        grad_step(objective, optimizer, lr_schedule, prm.lr, i_epoch)

        # Periodic status print.
        log_interval = 1000
        if batch_idx % log_interval != 0:
            continue
        batch_acc = correct_rate(outputs, targets)
        status = cmn.status_string(i_epoch, prm.num_epochs, batch_idx,
                                   n_batches, batch_acc, objective.item())
        details = ' Loss: {:.4}\t Comp.: {:.4}'.format(
            avg_empiric_loss.item(), complexity_term.item())
        print(status + details)

    # Save results for the epochs-figure, only on the configured epochs.
    if not figure_flag:
        return
    if i_epoch % prm.log_figure['interval_epochs'] != 0:
        return
    save_result_for_figure(post_model, prior_model, data_loader, prm,
                           log_mat, i_epoch)
def run_train_epoch(i_epoch):
    """Train ``model`` for one pass over ``train_loader``.

    For each batch: convert it with ``data_gen.get_batch_vars``, compute the
    loss with ``loss_criterion``, and apply ``grad_step``. A status line is
    printed every 500 batches.

    ``model``, ``train_loader``, ``prm``, ``optimizer``, ``lr_schedule``,
    ``n_batches`` and the helper functions used here are defined outside this
    function.

    Args:
        i_epoch: Index of the current epoch; forwarded to ``grad_step`` and
            shown in the status line.

    Returns:
        None.
    """
    print_every = 500

    model.train()

    for batch_idx, batch_data in enumerate(train_loader):
        # Forward pass on the current batch.
        inputs, targets = data_gen.get_batch_vars(batch_data, prm)
        outputs = model(inputs)
        loss = loss_criterion(outputs, targets)

        # Gradient step.
        grad_step(loss, optimizer, lr_schedule, prm.lr, i_epoch)

        # Only report on every `print_every`-th batch.
        if batch_idx % print_every != 0:
            continue
        batch_acc = correct_rate(outputs, targets)
        status = cmn.status_string(i_epoch, prm.num_epochs, batch_idx,
                                   n_batches, batch_acc, get_value(loss))
        print(status)