def save_bert(model, optimizer, args, save_path, save_mode="all", verbose=True):
    """Save a BERT model (and optionally optimizer state) to ``save_path``.

    Args:
        model: model to save; unwrapped from its ``.module`` wrapper
            (e.g. DataParallel) if present.
        optimizer: optimizer whose state is saved in "all"/"tunable" modes;
            may be None, in which case None is stored.
        args: run configuration namespace, stored via ``vars(args)``.
        save_path: destination path passed to ``torch.save``.
        save_mode: one of "all", "tunable", "model_all", "model_tunable".
            "*tunable" saves only tunable params (via get_tunable_state_dict);
            "model_*" modes skip the optimizer entirely.
        verbose: if True, print element counts while saving.

    Raises:
        KeyError: if save_mode is not a recognized mode.
    """
    assert save_mode in ["all", "tunable", "model_all", "model_tunable"]
    save_dict = dict()

    # Save args
    save_dict["args"] = vars(args)

    # Save model — unwrap a (Data)Parallel wrapper if present
    model_to_save = model.module if hasattr(model, 'module') else model

    # Only save the model itself
    if save_mode in ["all", "model_all"]:
        model_state_dict = model_to_save.state_dict()
    elif save_mode in ["tunable", "model_tunable"]:
        model_state_dict = get_tunable_state_dict(model_to_save)
    else:
        raise KeyError(save_mode)
    if verbose:
        print("Saving {} model elems:".format(len(model_state_dict)))
    save_dict["model"] = utils.to_cpu(model_state_dict)

    # Save optimizer
    if save_mode in ["all", "tunable"]:
        optimizer_state_dict = (
            utils.to_cpu(optimizer.state_dict()) if optimizer is not None else None
        )
        # Guard the print: len(None) raised TypeError when optimizer was None.
        if verbose and optimizer_state_dict is not None:
            print("Saving {} optimizer elems:".format(len(optimizer_state_dict)))
        # BUG FIX: the optimizer state was previously computed but never
        # written into save_dict, so checkpoints silently omitted it.
        save_dict["optimizer"] = optimizer_state_dict

    torch.save(save_dict, save_path)
def save_bert(glue_lm_model, optimizer, args, save_path, save_mode="all", verbose=True):
    """Save a combined GLUE+LM model (and optionally optimizer state).

    Args:
        glue_lm_model: wrapper object exposing ``.glue_model`` and
            ``.lm_model`` attributes.
        optimizer: optimizer whose state is saved in "all"/"tunable" modes;
            may be None, in which case None is stored.
        args: run configuration namespace, stored via ``vars(args)``.
        save_path: destination path passed to ``torch.save``.
        save_mode: one of "all", "tunable", "model_all", "model_tunable".
            "*tunable" saves only tunable params; "model_*" modes skip the
            optimizer entirely.
        verbose: if True, print element counts while saving.

    Raises:
        KeyError: if save_mode is not a recognized mode.
    """
    glue_model = glue_lm_model.glue_model
    lm_model = glue_lm_model.lm_model
    assert save_mode in ["all", "tunable", "model_all", "model_tunable"]
    save_dict = dict()

    # Save args
    save_dict["args"] = vars(args)

    # Save model — unwrap a (Data)Parallel wrapper if present
    glue_model_to_save = glue_model.module \
        if hasattr(glue_model, 'module') \
        else glue_model

    # Only save the model itself
    if save_mode in ["all", "model_all"]:
        glue_model_state_dict = glue_model_to_save.state_dict()
        lm_state_dict = get_lm_cls_state_dict(lm_model)
    elif save_mode in ["tunable", "model_tunable"]:
        glue_model_state_dict = get_tunable_state_dict(glue_model_to_save)
        lm_state_dict = get_tunable_state_dict(get_lm_cls_state_dict(lm_model))
    else:
        raise KeyError(save_mode)
    if verbose:
        print("Saving {} glue model elems:".format(len(glue_model_state_dict)))
        print("Saving {} lm model elems:".format(len(lm_state_dict)))
    save_dict["model"] = utils.to_cpu(glue_model_state_dict)
    save_dict["lm_model"] = utils.to_cpu(lm_state_dict)

    # Save optimizer
    if save_mode in ["all", "tunable"]:
        optimizer_state_dict = (
            utils.to_cpu(optimizer.state_dict()) if optimizer is not None else None
        )
        # Guard the print: len(None) raised TypeError when optimizer was None.
        if verbose and optimizer_state_dict is not None:
            print("Saving {} optimizer elems:".format(len(optimizer_state_dict)))
        # BUG FIX: the optimizer state was previously computed but never
        # written into save_dict, so checkpoints silently omitted it.
        save_dict["optimizer"] = optimizer_state_dict

    torch.save(save_dict, save_path)
def save_best_model(self, save_mode="model_all", verbose=True):
    """Save the current best classification+LM model to the output path.

    Writes a checkpoint dict to ``<self.output_path>/all_state.p`` containing
    the classification model, the LM model, and (in "all"/"tunable" modes)
    the optimizer state.

    Args:
        save_mode: one of "all", "tunable", "model_all", "model_tunable".
            "*tunable" modes are not implemented for this saver.
        verbose: if True, print element counts while saving.

    Raises:
        NotImplementedError: for "tunable"/"model_tunable" modes.
        KeyError: if save_mode is not a recognized mode.
    """
    classification_model = self.classification_lm_model.classification_model
    lm_model = self.classification_lm_model.lm_model
    assert save_mode in ["all", "tunable", "model_all", "model_tunable"]
    save_dict = dict()

    # Save model — unwrap a (Data)Parallel wrapper if present
    classification_model_to_save = classification_model.module \
        if hasattr(classification_model, 'module') \
        else classification_model

    # Only save the model itself
    if save_mode in ["all", "model_all"]:
        classification_model_state_dict = classification_model_to_save.state_dict()
        lm_state_dict = lm_model.state_dict()
    elif save_mode in ["tunable", "model_tunable"]:
        raise NotImplementedError
    else:
        raise KeyError(save_mode)
    if verbose:
        print("Saving {} classification model elems:".format(
            len(classification_model_state_dict)))
        print("Saving {} lm model elems:".format(len(lm_state_dict)))
    save_dict["model"] = utils.to_cpu(classification_model_state_dict)
    save_dict["lm_model"] = utils.to_cpu(lm_state_dict)

    # Save optimizer
    if save_mode in ["all", "tunable"]:
        optimizer_state_dict = (
            utils.to_cpu(self.optimizer.state_dict())
            if self.optimizer is not None else None
        )
        # Guard the print: len(None) raised TypeError when optimizer was None.
        if verbose and optimizer_state_dict is not None:
            print("Saving {} optimizer elems:".format(len(optimizer_state_dict)))
        # BUG FIX: the optimizer state was previously computed but never
        # written into save_dict, so checkpoints silently omitted it.
        save_dict["optimizer"] = optimizer_state_dict

    torch.save(save_dict, os.path.join(self.output_path, "all_state.p"))