def handle_valid(self, valid_cmd):
    """Evaluate the model on ``valid_cmd.datapoints`` without updating weights.

    Args:
        valid_cmd: command object exposing a ``datapoints`` iterable; each
            datapoint unpacks (via ``unpack_datapoint``) into
            ``(snapshot, choices, choice_id)``.

    Returns:
        ``Response`` with ``loss`` set to the mean per-datapoint loss as a
        plain Python float (``nan`` when there are no datapoints) and
        ``success`` set to ``True``.
    """
    total_loss = 0.0
    n_steps = 0
    self.model.eval()  # switch dropout/batch-norm layers to inference mode
    for datapoint in valid_cmd.datapoints:
        snapshot, choices, choice_id = unpack_datapoint(datapoint)
        choice_id = choice_id.to(self.device)
        with torch.no_grad():  # no autograd bookkeeping needed for validation
            logits = self.model(snapshot, choices)
            loss = self.loss(logits, choice_id)
        # BUG FIX: accumulate a Python float via .item(); the original added
        # the tensor itself, making response.loss a torch.Tensor rather than
        # a number.
        total_loss += loss.item()
        n_steps += 1
    response = Response()
    response.loss = total_loss / n_steps if n_steps > 0 else float('nan')
    response.success = True
    return response
def handle_train(self, train_cmd):
    """Train the model for ``train_cmd.nEpochs`` epochs over its datapoints.

    Each step runs forward, backward, gradient clipping (bounded by
    ``self.optim_cfg['grad_norm_clip']``), and an optimizer update.

    Args:
        train_cmd: command object exposing ``nEpochs`` and a ``datapoints``
            iterable; each datapoint unpacks (via ``unpack_datapoint``) into
            ``(snapshot, choices, choice_id)``.

    Returns:
        ``Response`` with ``loss`` set to the mean per-step training loss as a
        plain Python float (``nan`` when no steps ran) and ``success`` set to
        ``True``.
    """
    total_loss = 0.0
    n_steps = 0
    self.model.train()  # enable dropout/batch-norm training behavior
    for _ in range(train_cmd.nEpochs):
        for datapoint in train_cmd.datapoints:
            snapshot, choices, choice_id = unpack_datapoint(datapoint)
            choice_id = choice_id.to(self.device)
            log_prob = self.model(snapshot, choices)
            loss = self.loss(log_prob, choice_id)
            self.model.zero_grad()
            loss.backward()
            # Clip the global gradient norm before stepping to guard against
            # exploding gradients.
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.optim_cfg['grad_norm_clip']
            )
            self.optimizer.step()
            # BUG FIX: accumulate via .item(). The original summed loss
            # tensors, which keeps every step's autograd graph alive for the
            # whole run (unbounded memory growth) and makes response.loss a
            # torch.Tensor instead of a float.
            total_loss += loss.item()
            n_steps += 1
    response = Response()
    response.loss = total_loss / n_steps if n_steps > 0 else float('nan')
    response.success = True
    return response