def handle_init(self, init_cmd):
    """(Re)initialize the model and a fresh Adam optimizer from stored configs.

    Builds a new ``GenericModel`` from ``self.model_cfg``, moves it to
    ``self.device``, and replaces the optimizer, discarding any previous
    training state.

    Args:
        init_cmd: init command message (unused beyond triggering the reset).

    Returns:
        Response with ``success=True`` and a status message.
    """
    response = Response()
    self.model = GenericModel(self.model_cfg).to(self.device)
    self.optimizer = optim.Adam(self.model.parameters(),
                                lr=self.optim_cfg['learning_rate'])
    response.msg = "Model and optimizer reinitialized"
    response.success = True
    return response
def handle_load(self, load_cmd):
    """Restore model and optimizer state from a checkpoint file.

    Args:
        load_cmd: command carrying ``filename``, the checkpoint path.

    Returns:
        Response with ``success=True`` and a status message.
    """
    response = Response()
    # map_location remaps tensors saved on another device (e.g. a GPU
    # checkpoint loaded on a CPU-only host) onto this handler's device
    # instead of raising at load time.
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(load_cmd.filename, map_location=self.device)
    self.model.load_state_dict(checkpoint['model_state_dict'])
    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    response.msg = "Model and optimizer loaded"
    response.success = True
    return response
def handle_save(self, save_cmd):
    """Persist the current model and optimizer state to disk.

    Args:
        save_cmd: command carrying ``filename``, the destination path.

    Returns:
        Response with ``success=True`` and a message naming the saved file.
    """
    state = {
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
    }
    torch.save(state, save_cmd.filename)
    response = Response()
    response.msg = f"Model and optimizer saved at {save_cmd.filename}"
    response.success = True
    return response
def handle_predict(self, predict_cmd):
    """Score each choicepoint with the model in inference mode.

    Puts the model in eval mode and appends the model output for every
    choicepoint to ``response.predictions``.

    Args:
        predict_cmd: command carrying ``choicepoints``, each with a
            ``snapshot`` and ``choices``.

    Returns:
        Response with one prediction per choicepoint and ``success=True``.
    """
    response = Response()
    self.model.eval()
    # One no_grad context around the whole loop: gradients are never
    # needed for inference, and re-entering the context per item is waste.
    with torch.no_grad():
        for choicepoint in predict_cmd.choicepoints:
            log_prob = self.model(choicepoint.snapshot, choicepoint.choices)
            response.predictions.append(log_prob)
    response.success = True
    return response
def handle_valid(self, valid_cmd):
    """Compute the mean loss over a validation set (no parameter updates).

    Args:
        valid_cmd: command carrying ``datapoints`` to evaluate.

    Returns:
        Response with ``loss`` set to the mean loss as a Python float
        (NaN when there are no datapoints) and ``success=True``.
    """
    total_loss = 0.0
    n_steps = 0
    self.model.eval()
    # Whole loop under no_grad: validation never needs gradient tracking.
    with torch.no_grad():
        for datapoint in valid_cmd.datapoints:
            snapshot, choices, choice_id = unpack_datapoint(datapoint)
            choice_id = choice_id.to(self.device)
            logits = self.model(snapshot, choices)
            loss = self.loss(logits, choice_id)
            # .item() converts to a Python float, so response.loss is a
            # plain number rather than a tensor pinned to the device.
            total_loss += loss.item()
            n_steps += 1
    response = Response()
    response.loss = total_loss / n_steps if n_steps > 0 else float('nan')
    response.success = True
    return response
def handle_train(self, train_cmd):
    """Train for ``nEpochs`` passes over the datapoints; report the mean loss.

    Each step: forward pass, loss, zero grads, backward, gradient-norm
    clipping (``optim_cfg['grad_norm_clip']``), optimizer step.

    Args:
        train_cmd: command carrying ``nEpochs`` and ``datapoints``.

    Returns:
        Response with ``loss`` set to the mean per-step loss as a Python
        float (NaN when no steps ran) and ``success=True``.
    """
    total_loss = 0.0
    n_steps = 0
    self.model.train()
    for _ in range(train_cmd.nEpochs):
        for datapoint in train_cmd.datapoints:
            snapshot, choices, choice_id = unpack_datapoint(datapoint)
            choice_id = choice_id.to(self.device)
            log_prob = self.model(snapshot, choices)
            loss = self.loss(log_prob, choice_id)
            self.model.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.optim_cfg['grad_norm_clip'])
            self.optimizer.step()
            # .item() detaches the scalar: accumulating the loss tensor
            # itself would keep every step's autograd graph alive for the
            # whole run — an unbounded memory leak.
            total_loss += loss.item()
            n_steps += 1
    response = Response()
    response.loss = total_loss / n_steps if n_steps > 0 else float('nan')
    response.success = True
    return response