def rollout(self, net: Model, rollout: int, value_targets: torch.Tensor):
    """Saves statistics after a rollout has been performed for understanding the loss development

    :param Model net: The current net, used for saving values and policies of the first 12 states
    :param int rollout: The rollout number, used to determine whether it is evaluation time, in which case targets are checked
    :param torch.Tensor value_targets: Used for visualizing value change
    """
    # First time
    if self.params is None:
        self.params = net.get_params()

    # Keep track of the entropy of the 12-dimensional policy output
    entropies = [entropy(policy, axis=1) for policy in self.rollout_policy]
    # Currently: Mean over all games in the entire rollout. Maybe we want it more fine-grained later.
    self.policy_entropies.append(np.mean([np.nanmean(e) for e in entropies]))
    self.rollout_policy = list()  # Reset for next rollout

    if rollout in self.evaluations:
        net.eval()

        # Calculating value targets
        targets = value_targets.cpu().numpy().reshape((-1, self.depth))
        self.avg_value_targets.append(targets.mean(axis=0))

        # Calculating model change
        model_change = torch.sqrt((net.get_params() - self.params)**2).mean().cpu()
        model_total_change = torch.sqrt((net.get_params() - self.orig_params)**2).mean().cpu()
        self.params = net.get_params()
        self.param_changes.append(float(model_change))
        self.param_total_changes.append(float(model_total_change))

        # In the beginning: Calculate value given to the first 12 substates
        if rollout <= self.extra_evals:
            self.first_state_values.append(
                net(self.first_states, policy=False, value=True).detach().cpu().numpy()
            )

        net.train()
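
# A minimal, illustrative sketch (hypothetical helper, not used by the code above) of the entropy bookkeeping
# in the `rollout` analysis method: each element of `rollout_policies` is assumed to be an (n_states, 12)
# array of softmax outputs collected during a rollout, and the logged statistic is the mean per-state
# Shannon entropy.
import numpy as np
from scipy.stats import entropy

def mean_rollout_entropy(rollout_policies: list) -> float:
    """Mean per-state entropy of the action distributions collected over one rollout."""
    per_state_entropies = [entropy(policies, axis=1) for policies in rollout_policies]  # One entropy per state
    return float(np.mean([np.nanmean(e) for e in per_state_entropies]))

# For uniform distributions over the 12 actions this gives log(12) ~ 2.48, which an untrained network
# should start near; the value should fall as the policy becomes more confident.
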
def train(self, net: Model) -> (Model, Model):
    """Training loop: generates data, optimizes parameters, evaluates (sometimes) and repeats.

    Trains `net` for `self.rollouts` rollouts, each consisting of `self.rollout_games` games scrambled to a depth
    of `self.rollout_depth`.
    The network is evaluated for each rollout number in `self.evaluation_rollouts` according to `self.evaluator`.
    Stores multiple performance and training results.

    :param Model net: The network to be trained. Must accept input consistent with cube.get_oh_size()
    :return: The network after all evaluations and the network with the best evaluation score (win fraction)
    :rtype: (Model, Model)
    """
    self.tt.reset()
    self.tt.tick()
    self.states_per_rollout = self.rollout_depth * self.rollout_games
    self.log(f"Beginning training. Optimization is performed in batches of {self.batch_size}")
    self.log("\n".join([
        f"Rollouts: {self.rollouts}",
        f"Each consisting of {self.rollout_games} games with a depth of {self.rollout_depth}",
        f"Evaluations: {len(self.evaluation_rollouts)}",
    ]))
    best_solve = 0
    best_net = net.clone()
    self.agent.net = net
    if self.with_analysis:
        self.analysis.orig_params = net.get_params()

    generator_net = net.clone()

    alpha = 1 if self.alpha_update == 1 else 0
    optimizer = self.optim(net.parameters(), lr=self.lr)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, self.gamma)
    self.policy_losses = np.zeros(self.rollouts)
    self.value_losses = np.zeros(self.rollouts)
    self.train_losses = np.empty(self.rollouts)
    self.sol_percents = list()

    for rollout in range(self.rollouts):
        reset_cuda()

        generator_net = self._update_gen_net(generator_net, net) if self.tau != 1 else net

        self.tt.profile("ADI training data")
        training_data, policy_targets, value_targets, loss_weights = self.ADI_traindata(generator_net, alpha)
        self.tt.profile("To cuda")
        training_data = training_data.to(gpu)
        policy_targets = policy_targets.to(gpu)
        value_targets = value_targets.to(gpu)
        loss_weights = loss_weights.to(gpu)
        self.tt.end_profile("To cuda")
        self.tt.end_profile("ADI training data")

        reset_cuda()

        self.tt.profile("Training loop")
        net.train()
        batches = self._get_batches(self.states_per_rollout, self.batch_size)
        for i, batch in enumerate(batches):
            optimizer.zero_grad()
            policy_pred, value_pred = net(training_data[batch], policy=True, value=True)

            # Use loss on both policy and value
            policy_loss = self.policy_criterion(policy_pred, policy_targets[batch]) * loss_weights[batch]
            value_loss = self.value_criterion(value_pred.squeeze(), value_targets[batch]) * loss_weights[batch]
            loss = torch.mean(policy_loss + value_loss)
            loss.backward()
            optimizer.step()

            self.policy_losses[rollout] += policy_loss.detach().cpu().numpy().mean() / len(batches)
            self.value_losses[rollout] += value_loss.detach().cpu().numpy().mean() / len(batches)

            if self.with_analysis:
                # Save policy output to compute entropy; softmax over the 12 action logits so each row is a distribution
                with torch.no_grad():
                    self.analysis.rollout_policy.append(
                        torch.nn.functional.softmax(policy_pred.detach(), dim=1).cpu().numpy()
                    )

        self.train_losses[rollout] = self.policy_losses[rollout] + self.value_losses[rollout]
        self.tt.end_profile("Training loop")

        # Update learning rate and alpha
        if rollout and self.update_interval and rollout % self.update_interval == 0:
            if self.gamma != 1:
                lr_scheduler.step()
                lr = optimizer.param_groups[0]["lr"]
                self.log(f"Updated learning rate from {lr/self.gamma:.2e} to {lr:.2e}")
            if (alpha + self.alpha_update <= 1 or np.isclose(alpha + self.alpha_update, 1)) and self.alpha_update:
                alpha += self.alpha_update
                self.log(f"Updated alpha from {alpha-self.alpha_update:.2f} to {alpha:.2f}")
            elif alpha < 1 and alpha + self.alpha_update > 1 and self.alpha_update:
                self.log(f"Updated alpha from {alpha:.2f} to 1")
                alpha = 1

        if self.log.is_verbose() or rollout in (np.linspace(0, 1, 20) * self.rollouts).astype(int):
            self.log(f"Rollout {rollout} completed with mean loss {self.train_losses[rollout]}")

        if self.with_analysis:
            self.tt.profile("Analysis of rollout")
            self.analysis.rollout(net, rollout, value_targets)
            self.tt.end_profile("Analysis of rollout")

        if rollout in self.evaluation_rollouts:
            net.eval()

            self.agent.net = net
            self.tt.profile(f"Evaluating using agent {self.agent}")
            with unverbose:
                eval_results, _, _ = self.evaluator.eval(self.agent)
            eval_reward = (eval_results != -1).mean()
            self.sol_percents.append(eval_reward)
            self.tt.end_profile(f"Evaluating using agent {self.agent}")

            if eval_reward > best_solve:
                best_solve = eval_reward
                best_net = net.clone()
                self.log(f"Updated best net with solve rate {eval_reward*100:.2f} % at depth {self.evaluator.scrambling_depths}")

    self.log.section("Finished training")
    if len(self.evaluation_rollouts):
        self.log(f"Best net solves {best_solve*100:.2f} % of games at depth {self.evaluator.scrambling_depths}")
    self.log.verbose("Training time distribution")
    self.log.verbose(self.tt)
    total_time = self.tt.tock()
    eval_time = self.tt.profiles[f"Evaluating using agent {self.agent}"].sum() if len(self.evaluation_rollouts) else 0
    train_time = self.tt.profiles["Training loop"].sum()
    adi_time = self.tt.profiles["ADI training data"].sum()
    nstates = self.rollouts * self.rollout_games * self.rollout_depth * cube.action_dim
    states_per_sec = int(nstates / (adi_time + train_time))
    self.log("\n".join([
        f"Total running time: {self.tt.stringify_time(total_time, TimeUnit.second)}",
        f"- Training data for ADI: {self.tt.stringify_time(adi_time, TimeUnit.second)} or {adi_time/total_time*100:.2f} %",
        f"- Training time: {self.tt.stringify_time(train_time, TimeUnit.second)} or {train_time/total_time*100:.2f} %",
        f"- Evaluation time: {self.tt.stringify_time(eval_time, TimeUnit.second)} or {eval_time/total_time*100:.2f} %",
        f"States witnessed incl. substates: {TickTock.thousand_seps(nstates)}",
        f"- Per training second: {TickTock.thousand_seps(states_per_sec)}",
    ]))

    return net, best_net
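
# A minimal sketch of the per-sample weighted loss used in the training loop above. The trainer's actual
# `policy_criterion` and `value_criterion` are configured elsewhere; here they are assumed to be cross-entropy
# and MSE. Both must use reduction="none" so that `loss_weights` (the ADI weighting of each state, e.g. by
# scrambling depth) can scale every sample before the batch is averaged.
import torch

def weighted_adi_loss(policy_pred: torch.Tensor, value_pred: torch.Tensor,
                      policy_targets: torch.Tensor, value_targets: torch.Tensor,
                      loss_weights: torch.Tensor) -> torch.Tensor:
    policy_criterion = torch.nn.CrossEntropyLoss(reduction="none")  # policy_targets: action indices, shape (N,)
    value_criterion = torch.nn.MSELoss(reduction="none")            # value_targets: scalar targets, shape (N,)
    policy_loss = policy_criterion(policy_pred, policy_targets) * loss_weights
    value_loss = value_criterion(value_pred.squeeze(), value_targets) * loss_weights
    return torch.mean(policy_loss + value_loss)
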