Example #1
File: regimens.py  Project: msperber/misc
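This snippet appears to be the constructor of a training regimen in xnmt (by the module names, most likely `SimpleTrainingRegimen.__init__`). In xnmt's YAML-driven serialization, `Ref(...)` resolves an argument from another node of the experiment config and `bare(...)` supplies an uninitialized default sub-component, so every default shown here can be overridden from the config file.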
    def __init__(
        self,
        model: models.ConditionedModel = Ref("model"),
        src_file: Union[None, str, Sequence[str]] = None,
        trg_file: Optional[str] = None,
        dev_every: numbers.Integral = 0,
        dev_zero: bool = False,
        batcher: batchers.Batcher = bare(batchers.SrcBatcher, batch_size=32),
        loss_calculator: loss_calculators.LossCalculator = bare(
            loss_calculators.MLELoss),
        trainer: optimizers.XnmtOptimizer = bare(optimizers.SimpleSGDTrainer,
                                                 e0=0.1),
        run_for_epochs: Optional[numbers.Integral] = None,
        lr_decay: numbers.Real = 1.0,
        lr_decay_times: numbers.Integral = 3,
        patience: numbers.Integral = 1,
        initial_patience: Optional[numbers.Integral] = None,
        dev_tasks: Optional[Sequence[eval_tasks.EvalTask]] = None,
        dev_combinator: Optional[str] = None,
        restart_trainer: bool = False,
        reload_command: Optional[str] = None,
        name: str = "{EXP}",
        sample_train_sents: Optional[numbers.Integral] = None,
        max_num_train_sents: Optional[numbers.Integral] = None,
        max_src_len: Optional[numbers.Integral] = None,
        max_trg_len: Optional[numbers.Integral] = None,
        loss_comb_method: str = Ref("exp_global.loss_comb_method",
                                    default="sum"),
        update_every: numbers.Integral = 1,
        commandline_args: dict = Ref("exp_global.commandline_args", default={})
    ) -> None:

        super().__init__(model=model,
                         src_file=src_file,
                         trg_file=trg_file,
                         dev_every=dev_every,
                         batcher=batcher,
                         loss_calculator=loss_calculator,
                         run_for_epochs=run_for_epochs,
                         lr_decay=lr_decay,
                         lr_decay_times=lr_decay_times,
                         patience=patience,
                         initial_patience=initial_patience,
                         dev_tasks=dev_tasks,
                         dev_combinator=dev_combinator,
                         restart_trainer=restart_trainer,
                         reload_command=reload_command,
                         name=name,
                         sample_train_sents=sample_train_sents,
                         max_num_train_sents=max_num_train_sents,
                         max_src_len=max_src_len,
                         max_trg_len=max_trg_len)
        self.dev_zero = dev_zero
        # Fall back to a default SGD trainer if none was supplied.
        self.trainer = trainer or optimizers.SimpleSGDTrainer(e0=0.1)
        # DyNet profiling level, read from commandline_args when present.
        self.dynet_profiling = commandline_args.get(
            "dynet_profiling", 0) if commandline_args else 0
        self.train_loss_tracker = loss_trackers.TrainLossTracker(self)
        self.loss_comb_method = loss_comb_method
        self.update_every = update_every
        self.num_updates_skipped = 0
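The `update_every` / `num_updates_skipped` pair suggests gradient accumulation: the optimizer steps only once every `update_every` backward passes. Below is a minimal, self-contained sketch of that bookkeeping; the names `AccumulatingUpdater` and `maybe_update` are illustrative placeholders, not part of xnmt's API.

    class AccumulatingUpdater:
        # Sketch of the update_every / num_updates_skipped bookkeeping.
        def __init__(self, update_every: int = 1) -> None:
            self.update_every = update_every
            self.num_updates_skipped = 0

        def maybe_update(self, trainer) -> bool:
            # Call once per backward pass; returns True when the trainer stepped.
            self.num_updates_skipped += 1
            if self.num_updates_skipped == self.update_every:
                trainer.update()  # apply the accumulated gradients
                self.num_updates_skipped = 0
                return True
            return False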
Example #2
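No file line accompanies this example, but the ValueError message below names the class: this is evidently `AutobatchTrainingRegimen.__init__`, presumably from the same regimens.py.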
    def __init__(self,
                 model: models.ConditionedModel = Ref("model"),
                 src_file: Union[None, str, Sequence[str]] = None,
                 trg_file: Optional[str] = None,
                 dev_zero: bool = False,
                 batcher: batchers.Batcher = bare(batchers.SrcBatcher,
                                                  batch_size=32),
                 loss_calculator: loss_calculators.LossCalculator = bare(
                     loss_calculators.MLELoss),
                 trainer: optimizers.XnmtOptimizer = bare(
                     optimizers.SimpleSGDTrainer, e0=0.1),
                 run_for_epochs: Optional[numbers.Integral] = None,
                 lr_decay: numbers.Real = 1.0,
                 lr_decay_times: numbers.Integral = 3,
                 patience: numbers.Integral = 1,
                 initial_patience: Optional[numbers.Integral] = None,
                 dev_tasks: Optional[Sequence[eval_tasks.EvalTask]] = None,
                 dev_combinator: Optional[str] = None,
                 restart_trainer: bool = False,
                 reload_command: Optional[str] = None,
                 name: str = "{EXP}",
                 sample_train_sents: Optional[numbers.Integral] = None,
                 max_num_train_sents: Optional[numbers.Integral] = None,
                 max_src_len: Optional[numbers.Integral] = None,
                 max_trg_len: Optional[numbers.Integral] = None,
                 loss_comb_method: str = Ref("exp_global.loss_comb_method",
                                             default="sum"),
                 update_every: numbers.Integral = 1) -> None:

        super().__init__(model=model,
                         src_file=src_file,
                         trg_file=trg_file,
                         batcher=batcher,
                         loss_calculator=loss_calculator,
                         run_for_epochs=run_for_epochs,
                         lr_decay=lr_decay,
                         lr_decay_times=lr_decay_times,
                         patience=patience,
                         initial_patience=initial_patience,
                         dev_tasks=dev_tasks,
                         dev_combinator=dev_combinator,
                         restart_trainer=restart_trainer,
                         reload_command=reload_command,
                         name=name,
                         sample_train_sents=sample_train_sents,
                         max_num_train_sents=max_num_train_sents,
                         max_src_len=max_src_len,
                         max_trg_len=max_trg_len)
        if batcher.batch_size != 1:
            raise ValueError(
                "AutobatchTrainingRegimen forces the batcher to have batch_size 1. Use update_every to set the actual batch size in this regimen."
            )
        self.dev_zero = dev_zero
        self.trainer = trainer or optimizers.SimpleSGDTrainer(e0=0.1)
        self.loss_comb_method = loss_comb_method
        # With the batcher pinned to batch_size 1, update_every acts as the
        # effective batch size (see the ValueError above).
        self.update_every = update_every
        self.num_updates_skipped = 0
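This variant pins the batcher to batch_size 1 and lets `update_every` play the role of the batch size, which matches DyNet-style automatic batching: per-sentence losses are gathered and applied in a single step. A hypothetical sketch of such a loop, with `compute_loss`, `backward`, and `step` as placeholder callables rather than real xnmt calls:

    def autobatch_loop(batches, update_every, compute_loss, backward, step):
        # Hypothetical: each batch holds exactly one sentence; update_every
        # consecutive losses form one effective batch per optimizer step.
        pending = []
        for batch in batches:
            assert len(batch) == 1, "autobatch regimen expects batch_size 1"
            pending.append(compute_loss(batch))
            if len(pending) == update_every:
                backward(sum(pending))  # one backward pass over the summed loss
                step()                  # single parameter update
                pending.clear()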