Example #1
0
File: handler.py  Project: dselsam/oracle
 def handle_valid(self, valid_cmd):
     """Evaluate the model on ``valid_cmd.datapoints`` and return the mean loss.

     Runs the model in eval mode with gradients disabled, averages the
     per-datapoint loss, and packs the result into a ``Response``.

     Args:
         valid_cmd: command object with a ``datapoints`` iterable; each
             datapoint is decoded by ``unpack_datapoint``.

     Returns:
         Response: ``loss`` is the mean loss as a plain Python float
         (``nan`` when there are no datapoints); ``success`` is True.
     """
     total_loss = 0.0
     n_steps = 0
     self.model.eval()
     for datapoint in valid_cmd.datapoints:
         snapshot, choices, choice_id = unpack_datapoint(datapoint)
         choice_id = choice_id.to(self.device)
         with torch.no_grad():
             logits = self.model(snapshot, choices)
             loss = self.loss(logits, choice_id)
         # .item() converts the 0-d loss tensor to a Python float so that
         # response.loss is a float in every case (matching the nan branch)
         # rather than a torch.Tensor.
         total_loss += loss.item()
         n_steps += 1
     response = Response()
     response.loss = total_loss / n_steps if n_steps > 0 else float('nan')
     response.success = True
     return response
Example #2
0
File: handler.py  Project: dselsam/oracle
    def handle_train(self, train_cmd):
        """Train the model on ``train_cmd.datapoints`` for ``nEpochs`` epochs.

        For every datapoint: forward pass, loss, backward pass, gradient-norm
        clipping, and one optimizer step. Returns the mean loss over all steps.

        Args:
            train_cmd: command object with ``nEpochs`` (int) and a
                ``datapoints`` iterable decoded by ``unpack_datapoint``.

        Returns:
            Response: ``loss`` is the mean training loss as a plain Python
            float (``nan`` when no steps ran); ``success`` is True.
        """
        total_loss = 0.0
        n_steps = 0
        self.model.train()
        for _ in range(train_cmd.nEpochs):
            for datapoint in train_cmd.datapoints:
                snapshot, choices, choice_id = unpack_datapoint(datapoint)
                choice_id = choice_id.to(self.device)
                log_prob = self.model(snapshot, choices)
                loss = self.loss(log_prob, choice_id)

                self.model.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.optim_cfg['grad_norm_clip'])
                self.optimizer.step()
                # Accumulate a Python float, not the tensor: `total_loss += loss`
                # would keep every step's autograd graph reachable for the whole
                # run, leaking memory that grows with the number of steps.
                total_loss += loss.item()
                n_steps += 1
        response = Response()
        response.loss = total_loss / n_steps if n_steps > 0 else float('nan')
        response.success = True
        return response