Code Example #1
def _nnet2file(layers, set_layer_num=-1, filename='nnet.out', activation='sigmoid', start_layer=0, withfinal=True, input_factor=0.0, factor=(0.0,) * 8):
    """Serialize network parameters into a pickled dict written to `filename`.

    Hidden layers `start_layer` .. `set_layer_num - 1` are stored as W/b entries
    (weights scaled by 1.0 - input_factor for the first layer and 1.0 - factor[i-1]
    otherwise), together with their gradient accumulators dW/db and, when present,
    Kahan-summation carry terms. If `withfinal` is set, the last layer is stored
    under the 'logreg' keys, scaled by 1.0 - factor[-1].
    """
    logger = logging.getLogger(__name__)
    logger.info("Saving network " + filename)

    n_layers = len(layers)
    nnet_dict = {}
    if set_layer_num == -1:
        set_layer_num = n_layers - 1

    for i in range(start_layer, set_layer_num):
        logger.info("Saving hidden layer "+str(i))
        dict_a = str(i) + ' ' + activation + ' W'
        if i == 0:
            nnet_dict[dict_a] = array_2_string((1.0 - input_factor) * layers[i].params[0].get_value())
        else:
            nnet_dict[dict_a] = array_2_string((1.0 - factor[i-1]) * layers[i].params[0].get_value())
        dict_a = str(i) + ' ' + activation + ' b'
        nnet_dict[dict_a] = array_2_string(layers[i].params[1].get_value())

        # gradients
        dict_a = str(i) + ' ' + activation + ' dW'
        nnet_dict[dict_a] = array_2_string(layers[i].delta_params[0].get_value())
        dict_a = str(i) + ' ' + activation + ' db'
        nnet_dict[dict_a] = array_2_string(layers[i].delta_params[1].get_value())
    
        if layers[i].kahan:
            logger.info("Loading hidden kahan")
            dict_a = str(i) + ' ' + activation + ' W_carry'
            nnet_dict[dict_a] = array_2_string(layers[i].params_carry[0].get_value())
            dict_a = str(i) + ' ' + activation + ' b_carry'
            nnet_dict[dict_a] = array_2_string(layers[i].params_carry[1].get_value())
            #dict_a = str(i) + ' ' + activation + ' dW_carry'
            #nnet_dict[dict_a] = array_2_string(layers[i].delta_params_carry[0].get_value())
            #dict_a = str(i) + ' ' + activation + ' db_carry'
            #nnet_dict[dict_a] = array_2_string(layers[i].delta_params_carry[1].get_value())

    if withfinal: 
        logger.info("Saving final layer ")
        
        dict_a = 'logreg W'
        nnet_dict[dict_a] = array_2_string((1.0 - factor[-1]) * layers[-1].params[0].get_value())
        dict_a = 'logreg b'
        nnet_dict[dict_a] = array_2_string(layers[-1].params[1].get_value())

        #gradients
        dict_a = 'logreg dW'
        nnet_dict[dict_a] = array_2_string(layers[-1].delta_params[0].get_value())
        dict_a = 'logreg db'
        nnet_dict[dict_a] = array_2_string(layers[-1].delta_params[1].get_value())

        if layers[-1].kahan:
            logger.info("Loading softmax kahan")
            dict_a = 'logreg W_carry'
            nnet_dict[dict_a] = array_2_string(layers[-1].params_carry[0].get_value())
            dict_a = 'logreg b_carry'
            nnet_dict[dict_a] = array_2_string(layers[-1].params_carry[1].get_value())
            #dict_a = 'logreg dW_carry'
            #nnet_dict[dict_a] = array_2_string(layers[-1].delta_params_carry[0].get_value())
            #dict_a = 'logreg db_carry'
            #nnet_dict[dict_a] = array_2_string(layers[-1].delta_params_carry[1].get_value())

    utils.pickle_save(nnet_dict, filename)   
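
A minimal, hypothetical usage sketch: the trained-network object `dnn`, the logging setup, and the output path below are assumptions for illustration; only `_nnet2file` itself comes from the example above, and it expects every layer to expose `params`, `delta_params` and, optionally, `kahan` / `params_carry`.

import logging

logging.basicConfig(level=logging.INFO)

# `dnn.layers` is assumed to be the list of layer objects built elsewhere in the
# training script; the call writes their weights, biases and gradient accumulators
# to a single pickled file.
_nnet2file(dnn.layers, filename='exp/dnn_final.nnet', activation='sigmoid', withfinal=True)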
Code Example #2
def main(args, model=None) -> SummarizationModule:
    print(args)
    Path(args.output_dir).mkdir(exist_ok=True)

    if model is None:
        if "summarization" in args.task:
            ### Define BART model
            # Config from "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json"
            # Vocab size set to 50265 to be consistent with the facebook/bart-large default
            with open(args.config_path, "r") as f:
                config = BartConfig(**json.load(f))
            config.fp16 = args.fp16

            if args.distill:  # if distilling, start from finetuned checkpoint
                if Path(args.data_dir).name == "cnn_dm":
                    checkpoint = 'facebook/bart-large-cnn'
                else:
                    checkpoint = 'facebook/bart-large-xsum'
            else:
                checkpoint = 'facebook/bart-large'  #Start from pretrained checkpoint otherwise

            if args.resume_from_checkpoint:
                print(
                    "Resuming from checkpoint, make sure checkpoint is finetuned for best results"
                )
                if ".ckpt" in args.resume_from_checkpoint:
                    checkpoint = args.resume_from_checkpoint
                    if args.distill:  # set resume from checkpoint to None (state dict is different)
                        args.resume_from_checkpoint = None
                else:
                    checkpoints = list(
                        sorted(
                            glob.glob(os.path.join(args.output_dir, "*.ckpt"),
                                      recursive=True)))
                    if len(checkpoints) > 0:  # checkpoints found: resume from the most recent one
                        checkpoint = checkpoints[-1]
                        args.resume_from_checkpoint = checkpoint
                    else:
                        args.resume_from_checkpoint = None
                        print("No valid checkpoint to resume from. Using ",
                              checkpoint)

            print("Loading BART model checkpoint using ", checkpoint)
            model = BartForConditionalGeneration.from_pretrained(checkpoint,
                                                                 config=config)

            if args.distill == "sft":
                model = distill_sft(model)

            tokenizer = BartTokenizer.from_pretrained(
                'facebook/bart-large'
            )  # Downloads vocab and merges file automatically
            model: SummarizationModule = SummarizationModule(
                args, model=model, config=config, tokenizer=tokenizer)
        else:
            raise ValueError("Translation not supported at this time")
            model: SummarizationModule = TranslationModule(args)
    dataset = Path(args.data_dir).name
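    # Choose a PyTorch Lightning logger based on args.logger_name; the wandb
    # loggers below are imported lazily so wandb is only needed when requested.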
    if (args.logger_name == "default" or args.fast_dev_run
            or str(args.output_dir).startswith("/tmp")
            or str(args.output_dir).startswith("/var")):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name,
                             project=f"hf_{dataset}")

    # if args.early_stopping_patience >= 0:
    #     extra_callbacks = [get_early_stopping_callback(f"val_{model.val_metric}", args.early_stopping_patience)]
    # else:
    #     extra_callbacks = []
    extra_callbacks = [
        CheckpointEveryNSteps(args.output_dir, args.max_steps - 1)
    ]

    lower_is_better = args.val_metric == "loss"

    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir,
                                                    model.val_metric,
                                                    args.save_top_k,
                                                    lower_is_better),
        extra_callbacks=extra_callbacks,
        logger=logger,
    )

    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if args.do_predict and not args.do_train:
        # Testing from a checkpoint
        trainer.test(model)
    elif args.do_predict and args.do_train:
        # test() without a model tests using the best checkpoint automatically
        model.hparams.test_checkpoint = ""
        checkpoints = list(
            sorted(
                glob.glob(os.path.join(args.output_dir, "*.ckpt"),
                          recursive=True)))
        if checkpoints:
            model.hparams.test_checkpoint = checkpoints[-1]
            trainer.resume_from_checkpoint = checkpoints[-1]
        trainer.logger.log_hyperparams(model.hparams)
        trainer.test()
    return model
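
A hedged sketch of how main() is typically driven from the command line. The argparse wiring, the `add_model_specific_args` helper and the exact flag names are assumptions in the style of the pytorch-lightning example scripts, not something shown in the example above:

import argparse
import os
import pytorch_lightning as pl

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Trainer flags (gpus, max_steps, ...) plus model flags such as --data_dir,
    # --output_dir, --task, --config_path, --distill and --logger_name.
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())  # assumed helper
    args = parser.parse_args()
    main(args)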
Code Example #3
    def __init__(self, hparams, **kwargs):
        if hparams.sortish_sampler and hparams.gpus > 1:
            hparams.replace_sampler_ddp = False
        elif hparams.max_tokens_per_batch is not None:
            if hparams.gpus > 1:
                raise NotImplementedError(
                    "Dynamic Batch size does not work for multi-gpu training")
            if hparams.sortish_sampler:
                raise ValueError(
                    "--sortish_sampler and --max_tokens_per_batch may not be used simultaneously"
                )

        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {
            k: v if v >= 0 else None
            for k, v in n_observations_per_split.items()
        }

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens[
            "val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens[
            "test"], f"target_lens: {self.target_lens}"
        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.sync_dist = hparams.gpus > 1
        self.decoder_start_token_id = None  # default to config
        if self.model.config.decoder_start_token_id is None and isinstance(
                self.tokenizer, MBartTokenizer):
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[
                hparams.tgt_lang]
            self.model.config.decoder_start_token_id = self.decoder_start_token_id
        self.dataset_class = LegacySeq2SeqDataset
        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
        assert self.eval_beams >= 0, f"got self.eval_beams={self.eval_beams}. Need an integer >= 0"
        if self.hparams.eval_max_gen_length is not None:
            self.eval_max_length = self.hparams.eval_max_gen_length
        else:
            self.eval_max_length = self.model.config.max_length
        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
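
For orientation, a hedged sketch of the hparams fields this constructor reads: the Namespace below only mirrors the attribute accesses above with illustrative values, it is not a documented interface, and the base class will expect further fields of its own.

from argparse import Namespace

hparams = Namespace(
    sortish_sampler=False, gpus=1, max_tokens_per_batch=None,
    output_dir="out", data_dir="cnn_dm", num_workers=4,
    max_source_length=1024, max_target_length=56,
    val_max_target_length=142, test_max_target_length=142,
    n_train=-1, n_val=-1, n_test=-1,
    freeze_embeds=False, freeze_encoder=False,
    eval_beams=None, eval_max_gen_length=None, val_metric=None,
    tgt_lang=None,
)
model = SummarizationModule(hparams)  # illustrative only; real runs build hparams via argparse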