def _nnet2file(layers, set_layer_num=-1, filename='nnet.out', activation='sigmoid', start_layer=0, withfinal=True, input_factor=0.0, factor=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]):
    logger = logging.getLogger(__name__)
    logger.info("Saving network " + filename)
    n_layers = len(layers)
    nnet_dict = {}
    if set_layer_num == -1:
        set_layer_num = n_layers - 1

    for i in range(start_layer, set_layer_num):
        logger.info("Saving hidden layer " + str(i))
        dict_a = str(i) + ' ' + activation + ' W'
        if i == 0:
            nnet_dict[dict_a] = array_2_string((1.0 - input_factor) * layers[i].params[0].get_value())
        else:
            nnet_dict[dict_a] = array_2_string((1.0 - factor[i - 1]) * layers[i].params[0].get_value())
        dict_a = str(i) + ' ' + activation + ' b'
        nnet_dict[dict_a] = array_2_string(layers[i].params[1].get_value())

        # gradients
        dict_a = str(i) + ' ' + activation + ' dW'
        nnet_dict[dict_a] = array_2_string(layers[i].delta_params[0].get_value())
        dict_a = str(i) + ' ' + activation + ' db'
        nnet_dict[dict_a] = array_2_string(layers[i].delta_params[1].get_value())

        if layers[i].kahan:
            logger.info("Saving hidden kahan")
            dict_a = str(i) + ' ' + activation + ' W_carry'
            nnet_dict[dict_a] = array_2_string(layers[i].params_carry[0].get_value())
            dict_a = str(i) + ' ' + activation + ' b_carry'
            nnet_dict[dict_a] = array_2_string(layers[i].params_carry[1].get_value())
            #dict_a = str(i) + ' ' + activation + ' dW_carry'
            #nnet_dict[dict_a] = array_2_string(layers[i].delta_params_carry[0].get_value())
            #dict_a = str(i) + ' ' + activation + ' db_carry'
            #nnet_dict[dict_a] = array_2_string(layers[i].delta_params_carry[1].get_value())

    if withfinal:
        logger.info("Saving final layer")
        dict_a = 'logreg W'
        nnet_dict[dict_a] = array_2_string((1.0 - factor[-1]) * layers[-1].params[0].get_value())
        dict_a = 'logreg b'
        nnet_dict[dict_a] = array_2_string(layers[-1].params[1].get_value())

        # gradients
        dict_a = 'logreg dW'
        nnet_dict[dict_a] = array_2_string(layers[-1].delta_params[0].get_value())
        dict_a = 'logreg db'
        nnet_dict[dict_a] = array_2_string(layers[-1].delta_params[1].get_value())

        if layers[-1].kahan:
            logger.info("Saving softmax kahan")
            dict_a = 'logreg W_carry'
            nnet_dict[dict_a] = array_2_string(layers[-1].params_carry[0].get_value())
            dict_a = 'logreg b_carry'
            nnet_dict[dict_a] = array_2_string(layers[-1].params_carry[1].get_value())
            #dict_a = 'logreg dW_carry'
            #nnet_dict[dict_a] = array_2_string(layers[-1].delta_params_carry[0].get_value())
            #dict_a = 'logreg db_carry'
            #nnet_dict[dict_a] = array_2_string(layers[-1].delta_params_carry[1].get_value())

    utils.pickle_save(nnet_dict, filename)
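# Hedged usage sketch (not part of the original module): _nnet2file only relies on
# layer objects whose params / delta_params entries expose .get_value() and which
# carry a .kahan flag, mirroring the attribute accesses above. The _DemoShared and
# _DemoLayer stand-ins below are illustrative assumptions, not the real layer classes.
def _demo_nnet2file():
    import numpy

    class _DemoShared(object):
        def __init__(self, value):
            self.value = numpy.asarray(value)

        def get_value(self):
            return self.value

    class _DemoLayer(object):
        def __init__(self, n_in, n_out):
            zeros_W = numpy.zeros((n_in, n_out))
            zeros_b = numpy.zeros(n_out)
            self.params = [_DemoShared(zeros_W), _DemoShared(zeros_b)]
            self.delta_params = [_DemoShared(zeros_W), _DemoShared(zeros_b)]
            self.kahan = False

    # one hidden layer plus a final softmax layer; keys in the pickled dict follow
    # the "<layer index> <activation> <param>" pattern, e.g. "0 sigmoid W", "logreg b"
    _nnet2file([_DemoLayer(4, 8), _DemoLayer(8, 2)], filename='nnet.out')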
def main(args, model=None) -> SummarizationModule:
    print(args)
    Path(args.output_dir).mkdir(exist_ok=True)

    if model is None:
        if "summarization" in args.task:
            ### Define BART model
            # Config from "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json"
            # Vocab modified to 50265 to be consistent with facebook/bart-large default
            config = BartConfig(**json.load(open(args.config_path, "r")))
            config.fp16 = args.fp16

            if args.distill:  # if distilling, start from a finetuned checkpoint
                if Path(args.data_dir).name == "cnn_dm":
                    checkpoint = 'facebook/bart-large-cnn'
                else:
                    checkpoint = 'facebook/bart-large-xsum'
            else:
                checkpoint = 'facebook/bart-large'  # start from the pretrained checkpoint otherwise

            if args.resume_from_checkpoint:
                print("Resuming from checkpoint, make sure checkpoint is finetuned for best results")
                if ".ckpt" in args.resume_from_checkpoint:
                    checkpoint = args.resume_from_checkpoint
                    if args.distill:  # set resume_from_checkpoint to None (state dict is different)
                        args.resume_from_checkpoint = None
                else:
                    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
                    if len(checkpoints) > 0:  # checkpoints are available
                        checkpoint = checkpoints[-1]
                        args.resume_from_checkpoint = checkpoint
                    else:  # no valid checkpoint found
                        args.resume_from_checkpoint = None
                        print("No valid checkpoint to resume from. Using ", checkpoint)

            print("Loading BART model checkpoint using ", checkpoint)
            model = BartForConditionalGeneration.from_pretrained(checkpoint, config=config)

            if args.distill == "sft":
                model = distill_sft(model)

            tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')  # downloads vocab and merges files automatically
            model: SummarizationModule = SummarizationModule(args, model=model, config=config, tokenizer=tokenizer)
        else:
            raise ValueError("Translation not supported at this time")
            # model: SummarizationModule = TranslationModule(args)  # unreachable: translation path is disabled

    dataset = Path(args.data_dir).name
    if (args.logger_name == "default" or args.fast_dev_run
            or str(args.output_dir).startswith("/tmp")
            or str(args.output_dir).startswith("/var")):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)
    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    # if args.early_stopping_patience >= 0:
    #     extra_callbacks = [get_early_stopping_callback(f"val_{model.val_metric}", args.early_stopping_patience)]
    # else:
    #     extra_callbacks = []
    extra_callbacks = [CheckpointEveryNSteps(args.output_dir, args.max_steps - 1)]

    lower_is_better = args.val_metric == "loss"
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric, args.save_top_k, lower_is_better),
        extra_callbacks=extra_callbacks,
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")

    if args.do_predict and not args.do_train:
        # Testing from a checkpoint
        trainer.test(model)
    elif args.do_predict and args.do_train:
        # test() without a model tests using the best checkpoint automatically
        model.hparams.test_checkpoint = ""
        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
        if checkpoints:
            model.hparams.test_checkpoint = checkpoints[-1]
            trainer.resume_from_checkpoint = checkpoints[-1]
        trainer.logger.log_hyperparams(model.hparams)
        trainer.test()
    return model
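# Hedged entry-point sketch for main(): the real argument parser lives elsewhere in
# this repo. pl.Trainer.add_argparse_args and SummarizationModule.add_model_specific_args
# are assumptions about that surrounding code (and about the pytorch-lightning version
# in use); they are not guaranteed by this excerpt.
if __name__ == "__main__":
    import argparse  # assumed to be available; imported at module top in the full script

    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()
    main(args)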
def __init__(self, hparams, **kwargs):
    if hparams.sortish_sampler and hparams.gpus > 1:
        hparams.replace_sampler_ddp = False
    elif hparams.max_tokens_per_batch is not None:
        if hparams.gpus > 1:
            raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training")
        if hparams.sortish_sampler:
            raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously")

    super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
    use_task_specific_params(self.model, "summarization")
    save_git_info(self.hparams.output_dir)
    self.metrics_save_path = Path(self.output_dir) / "metrics.json"
    self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
    pickle_save(self.hparams, self.hparams_save_path)
    self.step_count = 0
    self.metrics = defaultdict(list)

    self.dataset_kwargs: dict = dict(
        data_dir=self.hparams.data_dir,
        max_source_length=self.hparams.max_source_length,
        prefix=self.model.config.prefix or "",
    )
    n_observations_per_split = {
        "train": self.hparams.n_train,
        "val": self.hparams.n_val,
        "test": self.hparams.n_test,
    }
    self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

    self.target_lens = {
        "train": self.hparams.max_target_length,
        "val": self.hparams.val_max_target_length,
        "test": self.hparams.test_max_target_length,
    }
    assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
    assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

    if self.hparams.freeze_embeds:
        self.freeze_embeds()
    if self.hparams.freeze_encoder:
        freeze_params(self.model.get_encoder())
        assert_all_frozen(self.model.get_encoder())

    self.hparams.git_sha = get_git_info()["repo_sha"]
    self.num_workers = hparams.num_workers
    self.sync_dist = hparams.gpus > 1

    self.decoder_start_token_id = None  # default to config
    if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
        self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
        self.model.config.decoder_start_token_id = self.decoder_start_token_id
    self.dataset_class = LegacySeq2SeqDataset

    self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
    assert self.eval_beams >= 0, f"got self.eval_beams={self.eval_beams}. Need an integer >= 0"

    if self.hparams.eval_max_gen_length is not None:
        self.eval_max_length = self.hparams.eval_max_gen_length
    else:
        self.eval_max_length = self.model.config.max_length
    self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric