Example #1
File: handler.py Project: dselsam/oracle
def handle_init(self, init_cmd):
    response = Response()
    response.success = False
    # Rebuild the model from config and move it onto the target device.
    self.model = GenericModel(self.model_cfg).to(self.device)
    # A fresh Adam optimizer starts with empty moment estimates.
    self.optimizer = optim.Adam(self.model.parameters(), lr=self.optim_cfg['learning_rate'])
    response.msg = "Model and optimizer reinitialized"
    response.success = True
    return response
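handle_init rebuilds both the model and the optimizer from stored config, which also discards any accumulated Adam moment estimates. Below is a minimal standalone sketch of the same config-driven construction, with a plain nn.Linear and a literal dict standing in for the project's GenericModel, model_cfg, and optim_cfg:

import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optim_cfg = {'learning_rate': 1e-3}  # stand-in for the handler's optim_cfg

model = nn.Linear(4, 2).to(device)   # stand-in for GenericModel(model_cfg)
optimizer = optim.Adam(model.parameters(), lr=optim_cfg['learning_rate'])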
Example #2
File: handler.py Project: dselsam/oracle
def handle_load(self, load_cmd):
    response = Response()
    # map_location remaps saved tensors onto the handler's device, so a
    # checkpoint written on one device can be restored on another.
    checkpoint = torch.load(load_cmd.filename, map_location=self.device)
    self.model.load_state_dict(checkpoint['model_state_dict'])
    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    response.msg = "Model and optimizer loaded"
    response.success = True
    return response
Example #3
File: handler.py Project: dselsam/oracle
def handle_save(self, save_cmd):
    response = Response()
    # Bundle both state dicts into one checkpoint so training can resume
    # with the optimizer's moment estimates intact.
    checkpoint = {
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict()
    }
    torch.save(checkpoint, save_cmd.filename)
    response.msg = f"Model and optimizer saved at {save_cmd.filename}"
    response.success = True
    return response
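Examples #2 and #3 round-trip a single checkpoint dict keyed by 'model_state_dict' and 'optimizer_state_dict'. A minimal sketch of the same round trip outside the handler, assuming a stand-in nn.Linear in place of GenericModel:

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Save both state dicts under the same keys the handlers use.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict()
}, 'checkpoint.pt')

# Restore them into freshly constructed objects.
restored = torch.load('checkpoint.pt')
model.load_state_dict(restored['model_state_dict'])
optimizer.load_state_dict(restored['optimizer_state_dict'])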
Example #4
File: handler.py Project: dselsam/oracle
def handle_predict(self, predict_cmd):
    response = Response()
    response.success = False
    self.model.eval()  # inference behavior for dropout/batch-norm layers
    with torch.no_grad():  # one no-grad scope around the whole loop
        for choicepoint in predict_cmd.choicepoints:
            log_prob = self.model(choicepoint.snapshot, choicepoint.choices)
            response.predictions.append(log_prob)
    response.success = True
    return response
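handle_predict pairs model.eval() with torch.no_grad(): eval() switches layers such as dropout and batch norm to inference behavior, while no_grad() skips building the autograd graph. A minimal sketch of the pattern, with a stand-in linear classifier (the log_softmax here is an assumption, standing in for whatever GenericModel returns):

import torch
import torch.nn as nn

model = nn.Linear(4, 3)
model.eval()                  # inference behavior for dropout/batch norm
with torch.no_grad():         # no gradients tracked during prediction
    log_prob = torch.log_softmax(model(torch.randn(1, 4)), dim=-1)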
Example #5
File: handler.py Project: dselsam/oracle
def handle_valid(self, valid_cmd):
    total_loss = 0.0
    n_steps = 0
    self.model.eval()
    for datapoint in valid_cmd.datapoints:
        snapshot, choices, choice_id = unpack_datapoint(datapoint)
        choice_id = choice_id.to(self.device)
        with torch.no_grad():
            logits = self.model(snapshot, choices)
            loss = self.loss(logits, choice_id)
        total_loss += loss.item()  # accumulate a Python float, not a tensor
        n_steps += 1
    response = Response()
    response.loss = total_loss / n_steps if n_steps > 0 else float('nan')
    response.success = True
    return response
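handle_valid averages the per-datapoint loss and guards the division against an empty dataset. The same accumulation in standalone form, assuming a stand-in classifier and synthetic (input, label) pairs in place of the project's datapoints:

import torch
import torch.nn as nn

model = nn.Linear(4, 3)
loss_fn = nn.CrossEntropyLoss()
data = [(torch.randn(1, 4), torch.tensor([0])) for _ in range(5)]

model.eval()
total_loss, n_steps = 0.0, 0
with torch.no_grad():
    for x, y in data:
        total_loss += loss_fn(model(x), y).item()
        n_steps += 1
mean_loss = total_loss / n_steps if n_steps > 0 else float('nan')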
Example #6
File: handler.py Project: dselsam/oracle
def handle_train(self, train_cmd):
    total_loss = 0.0
    n_steps = 0
    self.model.train()
    for _ in range(train_cmd.nEpochs):
        for datapoint in train_cmd.datapoints:
            snapshot, choices, choice_id = unpack_datapoint(datapoint)
            choice_id = choice_id.to(self.device)
            log_prob = self.model(snapshot, choices)
            loss = self.loss(log_prob, choice_id)

            self.model.zero_grad()
            loss.backward()
            # Clip after backward() and before step(), per the standard order.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.optim_cfg['grad_norm_clip'])
            self.optimizer.step()
            total_loss += loss.item()  # .item() frees the autograd graph each step
            n_steps += 1
    response = Response()
    response.loss = total_loss / n_steps if n_steps > 0 else float('nan')
    response.success = True
    return response
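The inner step of handle_train follows the standard PyTorch order: zero the gradients, backward, clip, step. A self-contained sketch of that step, with a stand-in regression model and synthetic data in place of the project's GenericModel and datapoints:

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
x, y = torch.randn(8, 4), torch.randn(8, 1)

model.train()
for _ in range(3):  # epochs
    model.zero_grad()                # clear gradients from the previous step
    loss = loss_fn(model(x), y)
    loss.backward()                  # populate .grad on each parameter
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()                 # apply the clipped update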