def step(self, batch: Any, batch_idx: BatchIndex):
    """
    This method is called for each batch
    """
    self.model.train(self.mode.is_train)

    # Get data and target labels
    data, target = batch[0].to(self.model.device), batch[1].to(self.model.device)

    if self.mode.is_train:
        tracker.add_global_step(data.shape[0] * data.shape[1])

    # Run the model
    output = self.model(data)

    # Calculate loss
    loss = self.loss_func(output, target)
    # Calculate accuracy
    self.accuracy(output, target)

    # Log the loss
    tracker.add("loss.", loss)

    # If we are in training mode, calculate the gradients
    if self.mode.is_train:
        loss.backward()
        self.optimizer.step()
        if batch_idx.is_last:
            tracker.add('model', self.model)
        self.optimizer.zero_grad()

    tracker.save()
def track_disk():
    res = psutil.disk_usage(lab.get_path())
    tracker.add({
        'disk.free': res.free,
        'disk.total': res.total,
        'disk.used': res.used,
    })
def process(self, batch: any, state: any):
    device = self.discriminator.device
    data, target = batch
    data, target = data.to(device), target.to(device)

    # Train the generator
    with monit.section("generator"):
        latent = torch.normal(0, 1, (data.shape[0], 100), device=device)
        if MODE_STATE.is_train:
            self.generator_optimizer.zero_grad()
        logits = self.discriminator(self.generator(latent))
        loss = self.generator_loss(logits)
        tracker.add("loss.generator.", loss)
        if MODE_STATE.is_train:
            loss.backward()
            self.generator_optimizer.step()

    # Train the discriminator
    with monit.section("discriminator"):
        latent = torch.normal(0, 1, (data.shape[0], 100), device=device)
        if MODE_STATE.is_train:
            self.discriminator_optimizer.zero_grad()
        logits_false = self.discriminator(self.generator(latent).detach())
        logits_true = self.discriminator(data)
        loss = self.discriminator_loss(logits_true, logits_false)
        tracker.add("loss.discriminator.", loss)
        if MODE_STATE.is_train:
            loss.backward()
            self.discriminator_optimizer.step()

    return {}, None
def track_memory():
    res = psutil.virtual_memory()
    tracker.add({
        'memory.total': res.total,
        'memory.used': res.used,
        'memory.available': res.available,
    })
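# A minimal sketch of how `track_disk` and `track_memory` above could be wired
# into a periodic monitoring loop. The 30-second interval, the `track_cpu`
# helper, and starting the run with `experiment.start()` are assumptions for
# illustration, not part of the snippets themselves.
import time

import psutil
from labml import experiment, lab, tracker


def track_cpu():
    # Hypothetical companion metric: overall CPU utilization
    tracker.add({'cpu.percent': psutil.cpu_percent()})


def monitor(interval: float = 30.):
    experiment.start()
    while True:
        track_disk()
        track_memory()
        track_cpu()
        # Flush the collected stats and advance the step
        tracker.save()
        tracker.add_global_step()
        time.sleep(interval)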
def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
                           gen_x: torch.Tensor, gen_y: torch.Tensor,
                           true_labels: torch.Tensor, false_labels: torch.Tensor):
    """
    ### Optimize the discriminators with GAN loss
    """
    # GAN Loss
    #
    # \begin{align}
    # \bigg(D_Y\Big(y ^ {(i)}\Big) - 1\bigg) ^ 2
    # + D_Y\Big(G\Big(x ^ {(i)}\Big)\Big) ^ 2 + \\
    # \bigg(D_X\Big(x ^ {(i)}\Big) - 1\bigg) ^ 2
    # + D_X\Big(F\Big(y ^ {(i)}\Big)\Big) ^ 2
    # \end{align}
    loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
                          self.gan_loss(self.discriminator_x(gen_x), false_labels) +
                          self.gan_loss(self.discriminator_y(data_y), true_labels) +
                          self.gan_loss(self.discriminator_y(gen_y), false_labels))

    # Take a step in the optimizer
    self.discriminator_optimizer.zero_grad()
    loss_discriminator.backward()
    self.discriminator_optimizer.step()

    # Log losses
    tracker.add({'loss.discriminator': loss_discriminator})
def train(model, optimizer, train_loader, device, train_log_interval):
    """This is the training code"""
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = F.cross_entropy(output, target)
        optimizer.zero_grad()
        loss.backward()
        if batch_idx == 0:
            tracker.add('model', model)
        optimizer.step()

        # **✨ Increment the global step**
        tracker.add_global_step()
        # **✨ Store stats in the tracker**
        tracker.save({'loss.train': loss})

        if batch_idx % train_log_interval == 0:
            # **✨ Save added stats**
            tracker.save()
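# A minimal sketch of how the `train` function above could be driven. The
# MNIST data, the tiny linear model, and the hyper-parameter values are
# assumptions for illustration, not taken from the snippet itself.
import torch
import torch.nn as nn
import torch.nn.functional as F
from labml import experiment, tracker
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader = DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=64, shuffle=True)
    # A deliberately tiny model, just enough to exercise the loop
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
    optimizer = torch.optim.Adam(model.parameters())

    experiment.start()
    for _ in range(5):
        train(model, optimizer, train_loader, device, train_log_interval=10)


if __name__ == '__main__':
    main()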
def store_optimizer_indicators(optimizer: 'Optimizer', *,
                               models: Optional[Dict[str, torch.nn.Module]] = None,
                               optimizer_name: str = "optimizer"):
    if models is None:
        models = {}

    # Map each parameter to a readable name
    names = {}
    for model_name, model in models.items():
        for name, p in model.named_parameters():
            names[p] = f'{model_name}.{name}'

    unknown = 0
    for group in optimizer.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue
            state = optimizer.state[p]
            if len(state) == 0:
                continue
            name = names.get(p, None)
            if name is None:
                name = f'unknown.{unknown}'
                unknown += 1
            for k, v in state.items():
                if isinstance(v, (float, int)):
                    tracker.add(f'optim.{optimizer_name}.{name}.{k}', v)
                if isinstance(v, torch.Tensor):
                    store_l1_l2(f'optim.{optimizer_name}.{name}.{k}', v)
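# `store_l1_l2` is used above but not shown. A plausible sketch follows; the
# exact helper in the source may track different statistics, so treat the
# derived indicator names here as assumptions.
import torch
from labml import tracker


def store_l1_l2(name: str, tensor: torch.Tensor):
    # Mean, mean absolute value (L1) and root-mean-square (L2) of the tensor
    tracker.add(f'{name}.mean', tensor.mean())
    tracker.add(f'{name}.l1', tensor.abs().mean())
    tracker.add(f'{name}.l2', (tensor ** 2).mean().sqrt())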
def step(self, batch: Any, batch_idx: BatchIndex):
    self.model.train(self.mode.is_train)
    data, target = batch[0].to(self.device), batch[1].to(self.device)

    if self.mode.is_train:
        tracker.add_global_step(len(data))

    is_log_activations = batch_idx.is_interval(self.log_activations_batches)
    with monit.section("model"):
        with self.mode.update(is_log_activations=is_log_activations):
            output = self.model(data)

    loss = self.loss_func(output, target)
    tracker.add("loss.", loss)

    if self.mode.is_train:
        with monit.section('backward'):
            loss.backward()

        if batch_idx.is_interval(self.update_batches):
            with monit.section('optimize'):
                self.optimizer.step()
            if batch_idx.is_interval(self.log_params_updates):
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        if batch_idx.is_interval(self.log_save_batches):
            tracker.save()
def _calc_loss(self, samples: Dict[str, torch.Tensor], clip_range: float) -> torch.Tensor:
    """## PPO Loss"""
    # Sampled observations are fed into the model to get $\pi_\theta(a_t|s_t)$ and $V^{\pi_\theta}(s_t)$
    pi_val, value = self.model(samples['obs'], False)
    pi = Categorical(logits=pi_val)

    # #### Policy
    log_pi = pi.log_prob(samples['actions'])
    # Ratio $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{OLD}}(a_t|s_t)}$;
    # *this is different from rewards* $r_t$.
    ratio = torch.exp(log_pi - samples['log_pis'])
    # The ratio is clipped to be close to 1.
    # Using the normalized advantage
    # $\bar{A_t} = \frac{\hat{A_t} - \mu(\hat{A_t})}{\sigma(\hat{A_t})}$
    # introduces a bias to the policy gradient estimator,
    # but it reduces variance a lot.
    clipped_ratio = ratio.clamp(min=1.0 - clip_range, max=1.0 + clip_range)
    # Advantages are normalized
    policy_reward = torch.min(ratio * samples['advantages'],
                              clipped_ratio * samples['advantages'])
    policy_reward = policy_reward.mean()

    # #### Entropy Bonus
    entropy_bonus = pi.entropy()
    entropy_bonus = entropy_bonus.mean()

    # Add regularization to the logits to revive previously-seen impossible moves
    max_logit = pi_val.max().item()
    prob_reg = -(torch.nan_to_num(pi_val - max_logit, neginf=0) ** 2).mean() * self.cur_prob_reg_weight

    # #### Value
    # Clipping makes sure the value function $V_\theta$ doesn't deviate
    # significantly from $V_{\theta_{OLD}}$.
    clipped_value = samples['values'][:, :2] + (value[:, :2] - samples['values'][:, :2]).clamp(
        min=-clip_range, max=clip_range)
    vf_loss = torch.max((value[:, :2] - samples['returns']) ** 2,
                        (clipped_value - samples['returns']) ** 2)
    vf_loss[:, 1] *= self.cur_target_prob_weight
    vf_loss_score = 0.5 * vf_loss[:, 0].mean()
    vf_loss = 0.5 * vf_loss.sum(-1).mean()

    # We want to maximize $\mathcal{L}^{CLIP+VF+EB}(\theta)$,
    # so we take the negative of it as the loss
    loss = -(policy_reward - self.c.vf_weight * vf_loss +
             self.cur_entropy_weight * (entropy_bonus + prob_reg))

    # For monitoring
    approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()
    clip_fraction = (abs(ratio - 1.0) > clip_range).to(torch.float).mean()

    tracker.add({
        'policy_reward': policy_reward,
        'vf_loss': vf_loss ** 0.5,
        'vf_loss_1': vf_loss_score ** 0.5,
        'entropy_bonus': entropy_bonus,
        'kl_div': approx_kl_divergence,
        'clip_fraction': clip_fraction
    })

    return loss
def run(self, i):
    # Get model output
    self.p, logits, (self.hn, self.cn) = self.model(self.x[i], self.hn, self.cn)

    # Flatten outputs
    logits = logits.view(-1, self.p.shape[-1])
    yi = self.y[i].reshape(-1)

    # Calculate loss
    loss = self.loss_func(logits, yi)

    # Store the states
    self.hn = self.hn.detach()
    self.cn = self.cn.detach()

    if self.is_train:
        # Take a training step
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        tracker.add("train.loss", loss.cpu().data.item())
    else:
        tracker.add("valid.loss", loss.cpu().data.item())
def sample(self, exploration_coefficient: float):
    """### Sample data"""

    # This doesn't need gradients
    with torch.no_grad():
        # Sample `worker_steps`
        for t in range(self.worker_steps):
            # Get Q-values for the current observation
            q_value = self.model(obs_to_torch(self.obs))
            # Sample actions
            actions = self._sample_action(q_value, exploration_coefficient)

            # Run sampled actions on each worker
            for w, worker in enumerate(self.workers):
                worker.child.send(("step", actions[w]))

            # Collect information from each worker
            for w, worker in enumerate(self.workers):
                # Get results after executing the actions
                next_obs, reward, done, info = worker.child.recv()

                # Add transition to replay buffer
                self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)

                # Collect episode info, which is available if an episode finished;
                # this includes total reward and length of the episode -
                # look at `Game` to see how it works.
                if info:
                    tracker.add('reward', info['reward'])
                    tracker.add('length', info['length'])

                # Update current observation
                self.obs[w] = next_obs
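# `obs_to_torch` is used above but not shown. A plausible sketch, assuming
# observations arrive as a numpy array of pixel frames in [0, 255]; the
# scaling and the module-level `device` are assumptions.
import numpy as np
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    # Convert to a float tensor on the training device and scale to [0, 1]
    return torch.tensor(obs, dtype=torch.float32, device=device) / 255.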
def iterate(self):
    device = get_device(self.model)
    correct_sum = 0
    total_samples = 0

    for i, (data, target) in monit.enum(self.name, self.data_loader):
        data, target = data.to(device), target.to(device)

        if self.optimizer is not None:
            self.optimizer.zero_grad()

        output = self.model(data)
        loss = self.loss_func(output, target)
        correct_sum += self.accuracy_func(output, target)
        total_samples += len(target)

        tracker.add(".loss", loss)

        if self.optimizer is not None:
            loss.backward()
            self.optimizer.step()

        if self.is_increment_global_step:
            tracker.add_global_step(len(target))

        if self.log_interval is not None and (i + 1) % self.log_interval == 0:
            tracker.save()

    tracker.add(".accuracy", correct_sum / total_samples)
def step(self, batch: Any, batch_idx: BatchIndex):
    self.model.train(self.mode.is_train)
    data, target = batch['data'].to(self.device), batch['target'].to(self.device)

    # Normalize the targets
    target = (target - self.model.y_mean) / self.model.y_std

    if self.mode.is_train:
        tracker.add_global_step(len(data))

    output = self.model(data)
    loss = self.loss_func(output, target)
    tracker.add("loss.", loss)

    if self.mode.is_train:
        loss.backward()
        if batch_idx.is_last:
            tracker.add('model', self.model)
        self.optimizer.step()
        self.optimizer.zero_grad()

    if not self.mode.is_train:
        # De-normalize the outputs before collecting them
        self.output_collector(output * self.model.y_std + self.model.y_mean)

    tracker.save()
def run(self):
    pytorch_utils.add_model_indicators(self.policy)

    for epoch, (game, arrange) in enumerate(self.games):
        board = Board(arrange)  # TODO change this
        state = board.get_current_board()

        for iteration in count():
            logger.log('epoch : {}, iteration : {}'.format(epoch, iteration), Color.cyan)

            action = self.get_action(state)
            next_state, reward, done = self.step(board, action.item())
            if done:
                next_state = None

            self.memory.push(state, action, next_state, reward)
            state = next_state

            self.train()

            if done:
                tracker.add(iterations=iteration)
                tracker.save()
                break

        if epoch % self.target_update == 0:
            self.target.load_state_dict(self.policy.state_dict())

        if self.is_log_parameters:
            pytorch_utils.store_model_indicators(self.policy)
def step(self, batch: any, batch_idx: BatchIndex):
    data, target = batch[0].to(self.device), batch[1].to(self.device)

    if self.mode.is_train:
        tracker.add_global_step(target.shape[0] * target.shape[1])

    with self.mode.update(is_log_activations=batch_idx.is_last):
        state = self.state.get()
        output, new_state = self.model(data, state)
        state = self.state_updater(state, new_state)
        self.state.set(state)

    loss = self.loss_func(output, target)
    tracker.add("loss.", loss)

    self.accuracy(output, target)
    self.accuracy.track()

    if self.mode.is_train:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
        self.optimizer.step()
        if batch_idx.is_last:
            tracker.add('model', self.model)
        self.optimizer.zero_grad()

    tracker.save()
def forward(self, evidence: torch.Tensor, target: torch.Tensor):
    # Number of classes
    n_classes = evidence.shape[-1]
    # Predictions that correctly match the target (greedy sampling based on highest probability)
    match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))
    # Track accuracy
    tracker.add('accuracy.', match.sum() / match.shape[0])

    # $\textcolor{orange}{\alpha_k} = e_k + 1$
    alpha = evidence + 1.
    # $S = \sum_{k=1}^K \textcolor{orange}{\alpha_k}$
    strength = alpha.sum(dim=-1)
    # $\hat{p}_k = \frac{\textcolor{orange}{\alpha_k}}{S}$
    expected_probability = alpha / strength[:, None]
    # Expected probability of the selected (greedy, highest probability) class
    expected_probability, _ = expected_probability.max(dim=-1)

    # Uncertainty mass $u = \frac{K}{S}$
    uncertainty_mass = n_classes / strength

    # Track $u$ for correct predictions
    tracker.add('u.succ.', uncertainty_mass.masked_select(match))
    # Track $u$ for incorrect predictions
    tracker.add('u.fail.', uncertainty_mass.masked_select(~match))
    # Track $\hat{p}_k$ for correct predictions
    tracker.add('prob.succ.', expected_probability.masked_select(match))
    # Track $\hat{p}_k$ for incorrect predictions
    tracker.add('prob.fail.', expected_probability.masked_select(~match))
def process(self, batch: any, state: any):
    """
    This method is called for each batch
    """
    # Get data and target labels
    data, target = batch[0].to(self.model.device), batch[1].to(self.model.device)

    # Statistics for logging and for updating the global step.
    # The number of samples is the number of tokens per sequence times the batch size.
    stats = {'samples': data.shape[0] * data.shape[1]}

    # Run the model
    output = self.model(data)

    # Calculate loss
    loss = self.loss_func(output, target)
    # Calculate accuracy
    stats['correct'] = self.accuracy_func(output, target)

    # Log the loss
    tracker.add("loss.", loss)

    # If we are in training mode, calculate the gradients
    if MODE_STATE.is_train:
        loss.backward()

    # Return stats (and state, if this were a recurrent net)
    return stats, None
def solve(self):
    for t in monit.loop(self.epochs):
        if not self.is_online_update:
            for I in self.info_sets.values():
                I.clear()

        for i in range(self.n_players):
            self.cfr(self.create_new_history(), cast(Player, i),
                     [1 for _ in range(self.n_players)])

        if not self.is_online_update:
            self.update()

        with monit.section("Track"):
            for I in self.info_sets.values():
                for a in I.actions():
                    tracker.add({
                        f'strategy.{I.key}.{a}': I.strategy[a],
                        f'average_strategy.{I.key}.{a}': I.average_strategy[a],
                        f'regret.{I.key}.{a}': I.regret[a],
                        f'current_regret.{I.key}.{a}': I.current_regret[a]
                    })

        if t % self.track_frequency == 0:
            tracker.save()
            logger.log()

        if (t + 1) % self.save_frequency == 0:
            experiment.save_checkpoint()

    logger.inspect(self.info_sets)
def train(self):
    # Deal the first `n_cards` of each hand
    start = torch.zeros((self.batch_size, self.n_cards), dtype=torch.long, device=self.device)
    deal(start)

    # Repeat each partial hand `samples_size` times and deal out the remaining cards
    rep = start.view(-1, 1, self.n_cards)
    rep = rep.repeat(1, self.samples_size, 1)
    cards = start.new_zeros(self.batch_size, self.samples_size, 9)
    cards[:, :, :self.n_cards] = rep
    cards = cards.view(-1, 9)
    deal(cards, self.n_cards)

    # Score both players' 7-card hands
    score0 = score(cards[:, :7]).view(self.batch_size, -1)
    score1 = score(cards[:, -7:]).view(self.batch_size, -1)

    # Win/draw/loss probabilities, estimated from the samples
    labels = cards.new_zeros((self.batch_size, 3), dtype=torch.float)
    labels[:, 0] = (score0 > score1).to(torch.float).mean(-1)
    labels[:, 1] = (score0 == score1).to(torch.float).mean(-1)
    labels[:, 2] = (score0 < score1).to(torch.float).mean(-1)

    pred = torch.log_softmax(self.model(start), dim=-1)
    loss = self.loss_func(pred, labels)
    tracker.add('train.loss', loss)

    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
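# `deal` and `score` are used above but not shown. The toy stand-ins below
# only illustrate the expected shapes: `deal` fills card slots in-place from
# column `start` onward, and `score` ranks a 7-card hand. Real versions would
# deal without replacement and implement actual hand ranking.
import torch


def deal(cards: torch.Tensor, start: int = 0):
    # Fill columns `start:` with random card indices (with replacement - toy only)
    cards[:, start:] = torch.randint(0, 52, cards[:, start:].shape,
                                     device=cards.device)


def score(hands: torch.Tensor) -> torch.Tensor:
    # Toy score: sum of card indices; stands in for a hand-ranking function
    return hands.sum(-1)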
def _calc_loss(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    ### Calculate total loss
    """
    # $R_t$ returns sampled from $\pi_{\theta_{OLD}}$
    sampled_return = samples['values'] + samples['advantages']

    # $\bar{A_t} = \frac{\hat{A_t} - \mu(\hat{A_t})}{\sigma(\hat{A_t})}$,
    # where $\hat{A_t}$ is advantages sampled from $\pi_{\theta_{OLD}}$.
    # Refer to sampling function in [Main class](#main) below
    # for the calculation of $\hat{A}_t$.
    sampled_normalized_advantage = self._normalize(samples['advantages'])

    # Sampled observations are fed into the model to get $\pi_\theta(a_t|s_t)$ and $V^{\pi_\theta}(s_t)$;
    # we are treating observations as state
    pi, value = self.model(samples['obs'])

    # $-\log \pi_\theta (a_t|s_t)$, $a_t$ are actions sampled from $\pi_{\theta_{OLD}}$
    log_pi = pi.log_prob(samples['actions'])

    # Calculate policy loss
    policy_loss = self.ppo_loss(log_pi, samples['log_pis'],
                                sampled_normalized_advantage, self.clip_range())

    # Calculate Entropy Bonus
    #
    # $\mathcal{L}^{EB}(\theta) =
    # \mathbb{E}\Bigl[ S\bigl[\pi_\theta\bigr] (s_t) \Bigr]$
    entropy_bonus = pi.entropy()
    entropy_bonus = entropy_bonus.mean()

    # Calculate value function loss
    value_loss = self.value_loss(value, samples['values'], sampled_return, self.clip_range())

    # $\mathcal{L}^{CLIP+VF+EB} (\theta) =
    # \mathcal{L}^{CLIP} (\theta) +
    # c_1 \mathcal{L}^{VF} (\theta) - c_2 \mathcal{L}^{EB}(\theta)$
    loss = (policy_loss + self.value_loss_coef() * value_loss
            - self.entropy_bonus_coef() * entropy_bonus)

    # For monitoring
    approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()

    # Add to tracker
    tracker.add({
        'policy_reward': -policy_loss,
        'value_loss': value_loss,
        'entropy_bonus': entropy_bonus,
        'kl_div': approx_kl_divergence,
        'clip_fraction': self.ppo_loss.clip_fraction
    })

    return loss
def setup_and_add():
    for t in range(10):
        tracker.set_scalar(f"loss1.{t}", is_print=t == 0)

    experiment.start()
    for i in monit.loop(1000):
        for t in range(10):
            tracker.add({f'loss1.{t}': i})
        tracker.save()
def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor,
                        true_labels: torch.Tensor):
    """
    ### Optimize the generators with identity, GAN and cycle losses
    """
    # Change to training mode
    self.generator_xy.train()
    self.generator_yx.train()

    # Identity loss
    # $$\lVert F(x^{(i)}) - x^{(i)} \rVert_1 +
    # \lVert G(y^{(i)}) - y^{(i)} \rVert_1$$
    loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
                     self.identity_loss(self.generator_xy(data_y), data_y))

    # Generate images $G(x)$ and $F(y)$
    gen_y = self.generator_xy(data_x)
    gen_x = self.generator_yx(data_y)

    # GAN loss
    # $$\bigg(D_Y\Big(G\Big(x^{(i)}\Big)\Big) - 1\bigg)^2
    # + \bigg(D_X\Big(F\Big(y^{(i)}\Big)\Big) - 1\bigg)^2$$
    loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
                self.gan_loss(self.discriminator_x(gen_x), true_labels))

    # Cycle loss
    # $$
    # \lVert F(G(x^{(i)})) - x^{(i)} \rVert_1 +
    # \lVert G(F(y^{(i)})) - y^{(i)} \rVert_1
    # $$
    loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
                  self.cycle_loss(self.generator_xy(gen_x), data_y))

    # Total loss
    loss_generator = (loss_gan +
                      self.cyclic_loss_coefficient * loss_cycle +
                      self.identity_loss_coefficient * loss_identity)

    # Take a step in the optimizer
    self.generator_optimizer.zero_grad()
    loss_generator.backward()
    self.generator_optimizer.step()

    # Log losses
    tracker.add({
        'loss.generator': loss_generator,
        'loss.generator.cycle': loss_cycle,
        'loss.generator.gan': loss_gan,
        'loss.generator.identity': loss_identity
    })

    # Return generated images
    return gen_x, gen_y
def __call__(self):
    if not self._change_tracked or self._last_tracked_step != tracker.get_global_step():
        if self._key is None:
            warnings.warn('Register dynamic schedules with `experiment.configs` '
                          'to update them live from the app')
        else:
            tracker.add(f'hp.{self._key}', self._value)
            self._change_tracked = True
            self._last_tracked_step = tracker.get_global_step()

    return self._value
def on_train_batch_end(self, batch, logs=None):
    if logs is None:
        logs = {}

    tracker.add_global_step()

    if 'size' in logs:
        del logs['size']
    if 'batch' in logs:
        del logs['batch']
    tracker.add(logs)

    if batch % self.save_batch_frequency == 0:
        tracker.save()
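# The hook above matches the Keras `Callback.on_train_batch_end` signature. A
# sketch of attaching it to a model follows; the `TrackerCallback` class, the
# toy data, and the `save_batch_frequency` value are assumptions.
import numpy as np
import tensorflow as tf
from labml import experiment, tracker


class TrackerCallback(tf.keras.callbacks.Callback):
    save_batch_frequency = 10
    # Reuse the hook defined above as the batch-end callback
    on_train_batch_end = on_train_batch_end


x = np.random.rand(512, 4).astype('float32')
y = np.random.rand(512, 1).astype('float32')
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer='adam', loss='mse')

experiment.start()
model.fit(x, y, epochs=2, callbacks=[TrackerCallback()])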
def process(self, batch: any, state: any):
    device = self.discriminator.device
    data, target = batch
    data, target = data.to(device), target.to(device)

    # Train the discriminator
    with monit.section("discriminator"):
        for _ in range(self.discriminator_k):
            latent = torch.randn(data.shape[0], 100, device=device)
            if MODE_STATE.is_train:
                self.discriminator_optimizer.zero_grad()
            logits_true = self.discriminator(data)
            logits_false = self.discriminator(self.generator(latent).detach())
            loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
            loss = loss_true + loss_false

            # Log stuff
            tracker.add("loss.discriminator.true.", loss_true)
            tracker.add("loss.discriminator.false.", loss_false)
            tracker.add("loss.discriminator.", loss)

            # Train
            if MODE_STATE.is_train:
                loss.backward()
                if MODE_STATE.is_log_parameters:
                    pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
                self.discriminator_optimizer.step()

    # Train the generator
    with monit.section("generator"):
        latent = torch.randn(data.shape[0], 100, device=device)
        if MODE_STATE.is_train:
            self.generator_optimizer.zero_grad()
        generated_images = self.generator(latent)
        logits = self.discriminator(generated_images)
        loss = self.generator_loss(logits)

        # Log stuff
        tracker.add('generated', generated_images[0:5])
        tracker.add("loss.generator.", loss)

        # Train
        if MODE_STATE.is_train:
            loss.backward()
            if MODE_STATE.is_log_parameters:
                pytorch_utils.store_model_indicators(self.generator, 'generator')
            self.generator_optimizer.step()

    return {'samples': len(data)}, None
def __call__(self, info_sets: Dict[str, InfoSet]):
    """
    Track the data from all information sets
    """
    for I in info_sets.values():
        avg_strategy = I.get_average_strategy()
        for a in I.actions():
            tracker.add({
                f'strategy.{I.key}.{a}': I.strategy[a],
                f'average_strategy.{I.key}.{a}': avg_strategy[a],
                f'regret.{I.key}.{a}': I.regret[a],
            })
def train(self):
    """
    ### Train the model
    """
    # Loop for the given number of epochs
    for _ in monit.loop(self.epochs):
        # Iterate over the minibatches
        for i, batch in monit.enum('Train', self.dataloader):
            # Move data to the device
            data, target = batch[0].to(self.device), batch[1].to(self.device)

            # Set tracker step, as the number of characters trained on
            tracker.add_global_step(data.shape[0] * data.shape[1])

            # Set model state to training
            self.model.train()
            # Evaluate the model
            output = self.model(data)
            # Calculate loss
            loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
            # Log the loss
            tracker.add("loss.train", loss)

            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients
            if (i + 1) % 100 == 0:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

            # Generate a sample
            if (i + 1) % 100 == 0:
                self.model.eval()
                with torch.no_grad():
                    self.sample()

            # Save the tracked metrics
            if (i + 1) % 10 == 0:
                tracker.save()

        # Save the model
        experiment.save_checkpoint()
def main():
    # Reset the global step because we incremented it in the previous loop
    tracker.set_global_step(0)

    for i in range(1, 401):
        tracker.add_global_step()
        loss = train()
        tracker.add(loss=loss)
        if i % 10 == 0:
            tracker.save()
        if i % 100 == 0:
            logger.log()
        time.sleep(0.02)
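# `train` is called above but not shown. A hypothetical stand-in that just
# decays a fake loss with the global step, purely so the loop runs end to end:
from labml import tracker


def train() -> float:
    return 0.999 ** tracker.get_global_step()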
def add_save():
    arr = torch.zeros((1000, 1000))

    experiment.start()
    for i in monit.loop(N):
        for t in range(10):
            arr += 1
        for t in range(10):
            if i == 0:
                tracker.set_scalar(f"loss1.{t}", is_print=t == 0)
        for t in range(10):
            tracker.add({f'loss1.{t}': i})
        tracker.save()
def step(self, batch: any, batch_idx: BatchIndex):
    """
    ### Training or validation step
    """
    # Set training/eval mode
    self.model.train(self.mode.is_train)

    # Move data to the device
    data, target = batch[0].to(self.device), batch[1].to(self.device)

    # Update global step (number of tokens processed) when in training mode
    if self.mode.is_train:
        tracker.add_global_step(data.shape[0] * data.shape[1])

    # Whether to capture model outputs
    with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
        # Get model outputs.
        # It's returning a tuple for states when using RNNs.
        # This is not implemented yet. 😜
        output, *_ = self.model(data)

    # Calculate and log loss
    loss = self.loss_func(output, target)
    tracker.add("loss.", loss)

    # Calculate and log accuracy
    self.accuracy(output, target)
    self.accuracy.track()

    self.other_metrics(output, target)

    # Train the model
    if self.mode.is_train:
        # Calculate gradients
        loss.backward()
        # Clip gradients
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
        # Take optimizer step
        self.optimizer.step()
        # Log the model parameters and gradients on last batch of every epoch
        if batch_idx.is_last and self.is_log_model_params_grads:
            tracker.add('model', self.model)
        # Clear the gradients
        self.optimizer.zero_grad()

    # Save the tracked metrics
    tracker.save()