Exemplo n.º 1
0
class OptimizerOptions(OptionsBase):
    """Select and configure the optimizer used for training.

    ``type`` chooses which nested option group (``adam_options`` or
    ``adamw_options``) builds the optimizer in :meth:`get`. A positive
    ``look_ahead_k`` additionally wraps the result in ``Lookahead``.
    """
    type: str = argfield(default="adam", choices=["adam", "adamw"])
    adam_options: AdamOptions = argfield(default_factory=AdamOptions)
    adamw_options: AdamWOptions = argfield(default_factory=AdamWOptions)
    # Lookahead wrapping is only applied when this is > 0 (see ``get``).
    look_ahead_k: int = 0
    # Passed to Lookahead as ``alpha``. The default is fractional, so the
    # field is a float (it was annotated ``int``, contradicting 0.5).
    look_ahead_alpha: float = 0.5

    def get(self, trainable_params):
        """Build the optimizer for ``trainable_params``.

        Args:
            trainable_params: forwarded unchanged to the selected option
                group's ``get_optimizer``.

        Returns:
            The optimizer built by the selected option group. It is wrapped in
            ``Lookahead`` when ``look_ahead_k > 0``.

        Raises:
            ValueError: if ``type`` is not a supported optimizer name.
        """
        if self.type == "adam":
            ret = self.adam_options.get_optimizer(trainable_params)
        elif self.type == "adamw":
            ret = self.adamw_options.get_optimizer(trainable_params)
        else:
            # ValueError subclasses Exception, so existing handlers still work.
            raise ValueError(f"Optimizer {self.type} is not supported yet")

        if self.look_ahead_k > 0:
            # Imported here, so it is only loaded when lookahead is enabled.
            from coli.torch_extra.lookahead import Lookahead
            ret = Lookahead(ret, k=self.look_ahead_k, alpha=self.look_ahead_alpha)
        return ret

    @property
    def learning_rate(self):
        """Learning rate of the selected option group.

        Any ``type`` other than ``"adam"`` reads ``adamw_options.lr``.
        """
        if self.type == "adam":
            return self.adam_options.lr
        else:
            return self.adamw_options.lr
Exemplo n.º 2
0
 class Options(PyTorchParserBase.Options):
     """Parser options: embedding file, GPU flag and hyper-parameters."""
     embed_file: Optional[str] = argfield(
         None, predict_time=True, predict_default=None)
     gpu: bool = argfield(
         False, predict_time=True, predict_default=False)
     # The lambda defers the ``SimpleParser`` lookup until a default is
     # actually built, so the name need not exist when this class body runs.
     hparams: Any = argfield(
         default_factory=lambda: SimpleParser.HParams())
Exemplo n.º 3
0
 class Options(OptionsBase):
     """Options for a BERT-style contextual embedding (per the field names)."""
     # No default: must be supplied. ExistFile presumably validates the path.
     bert_model: ExistFile = argfield(predict_time=True)
     student_model: Optional[str] = argfield(
         default=None, predict_time=True)
     lower: bool = False
     # None leaves the output dimension as-is (assumed; confirm with consumer).
     project_to: Optional[int] = None
     feature_dropout: float = 0.0
     finetune_last_n: int = 0
     pooling_method: Optional[str] = "last"
Exemplo n.º 4
0
 class HParams(HParamsBase):
     """Hyper-parameters: optimizer, learning schedule and contextual embedding."""
     # Each nested option group gets a fresh instance via default_factory.
     optimizer: OptimizerOptions = argfield(default_factory=OptimizerOptions)
     learning: AdvancedLearningOptions = field(
         default_factory=AdvancedLearningOptions,
     )
     pretrained_contextual: ExternalContextualEmbedding.Options = field(
         default_factory=ExternalContextualEmbedding.Options,
     )
Exemplo n.º 5
0
 class Options(OptionsBase):
     """Options for a pretrained embedding loaded from ``path``."""
     # Required (no default).
     path: str = argfield(predict_time=True)
     # Training / normalization switches.
     requires_grad: bool = False
     do_layer_norm: bool = False
     # Dropout rates.
     dropout: float = 0.5
     feature_dropout: float = 0.0
     keep_sentence_boundaries: bool = False
     # Optional projection size; None means no value was configured.
     project_to: Optional[int] = None
Exemplo n.º 6
0
 class Options(BranchSelect.Options):
     """Options for the contextual encoder.

     ``type`` picks one of ``contextual_units`` (default ``"lstm"``). Each
     candidate has its own ``*_options`` group, each defaulting to a fresh
     instance of its annotated options type.
     """
     type: "contextual unit" = argfield("lstm", choices=contextual_units)
     lstm_options: LSTMLayer.Options = field(default_factory=LSTMLayer.Options)
     # NOTE(review): reuses LSTMLayer.Options; confirm no dedicated TF-LSTM
     # options type is intended.
     tflstm_options: LSTMLayer.Options = field(default_factory=LSTMLayer.Options)
     # The default factory now matches the annotation (it previously built
     # LSTMLayer.Options, a copy-paste mismatch).
     gru_options: GRULayer.Options = field(default_factory=GRULayer.Options)
     transformer_options: TransformerEncoder.Options = field(default_factory=TransformerEncoder.Options)
     allen_lstm_options: AllenNLPLSTMLayer.Options = field(default_factory=AllenNLPLSTMLayer.Options)
     conv_options: ConvEncoder.Options = field(default_factory=ConvEncoder.Options)
Exemplo n.º 7
0
class HParamsBase(OptionsBase):
    """Base training / inference hyper-parameters.

    The string annotations appear to act as help text (e.g. "random seed").
    Where such a field is declared through ``argfield``, ``type=int`` is
    passed explicitly.
    """

    # --- steps and batch sizes ---
    train_iters: "Count of training step" = 50000
    train_batch_size: "Batch size when training (words)" = 5000
    test_batch_size: "Batch size when inference (words)" = argfield(
        default=5000,
        type=int,
        predict_time=True,
    )
    max_sentence_batch_size: "Max sentence count in a step" = argfield(
        default=16384,
        type=int,
        predict_time=True,
    )

    # --- logging / validation cadence ---
    print_every: "Print result every n step" = 5
    evaluate_every: "Validate result every n step" = 500

    # --- bucketing ---
    num_buckets: "bucket count" = 100
    num_valid_bkts: "validation bucket count" = argfield(
        default=40,
        type=int,
        predict_time=True,
    )

    seed: "random seed" = 42
    bucket_type: "bucket_type" = argfield(
        default="length_group",
        choices=bucket_types,
        predict_time=True,
    )
                                          predict_time=True)
Exemplo n.º 8
0
 class Options(OptionsBase):
     """Input embedding options (word, POS tag and character embeddings)."""
     # Dimensions; per the annotations, 0 disables the postag/char inputs.
     dim_word: "word embedding dim" = 100
     dim_postag: "postag embedding dim. 0 for not using postag" = 100
     dim_char_input: "character embedding input dim" = 100
     dim_char: "character embedding dim. 0 for not using character" = 100
     # Dropout rates.
     word_dropout: "word embedding dropout" = 0.4
     postag_dropout: "postag embedding dropout" = 0.2
     character_embedding: CharacterEmbedding.Options = field(
         default_factory=CharacterEmbedding.Options,
     )
     input_layer_norm: "Use layer norm on input embeddings" = True
     # How the embeddings are combined: "add" or "concat".
     mode: str = argfield("concat", choices=["add", "concat"])
     replace_unk_with_chars: bool = False
Exemplo n.º 9
0
 class Options(OptionsBase):
     """Transformer encoder options."""
     # Architecture sizes.
     num_layers: int = 8
     num_heads: int = 2
     d_kv: int = 32
     d_ff: int = 1024
     d_positional: Optional[int] = argfield(
         None,
         help="Use partitioned transformer if it is not None",
     )
     # Dropout rates.
     relu_dropout: float = 0.1
     residual_dropout: float = 0.1
     attention_dropout: float = 0.1
     timing_dropout: float = 0.0
     # Position ("timing") signal settings.
     timing_method: str = "embedding"
     timing_layer_norm: bool = False
     max_sent_len: int = 512
     leaky_relu_slope: float = 0.0
Exemplo n.º 10
0
class TaggerHParams(SimpleParser.HParams):
    """Hyper-parameters for the tagger, extending ``SimpleParser.HParams``."""

    # Overrides of the inherited training schedule.
    train_iters: "Count of training step" = 10000
    evaluate_every: int = 100

    # MLP head: one entry per hidden layer (nargs="+" accepts several).
    dims_hidden: "dims of hidden layers" = argfield(
        default_factory=lambda: [100],
        nargs="+",
    )
    mlp_dropout: float = 0.2

    # NOTE: if both chars and crf, only 1.6x slower on GPU
    use_crf: "if crf, training is 1.7x slower on CPU" = True
    word_threshold: int = 30

    sentence_embedding: SentenceEmbeddings.Options = field(
        default_factory=SentenceEmbeddings.Options,
    )
    contextual: ContextualUnits.Options = field(
        default_factory=ContextualUnits.Options,
    )

    @classmethod
    def get_default(cls):
        """Return an instance with the tagger-specific overrides applied.

        Starting from the field defaults, it sets up a single-layer LSTM with
        ``recurrent_keep_prob = 1.0`` and sets ``dim_postag = 0``.
        """
        hparams = cls()
        lstm = hparams.contextual.lstm_options
        lstm.num_layers = 1
        lstm.recurrent_keep_prob = 1.0
        hparams.sentence_embedding.dim_postag = 0
        return hparams
Exemplo n.º 11
0
    class Options(OptionsBase):
        """Task-level options.

        Covers data paths, model save/load, DyNet and ELMo ("bilm") settings,
        and prediction-only flags. Per-field meaning is given by each ``help``
        string.
        """
        title: str = argfield("default", help="Name of this task")
        train: str = argfield(metavar="FILE", help="Path of training set")
        dev: List[str] = argfield(metavar="FILE",
                                  nargs="+",
                                  help="Path of development set")
        max_save: int = argfield(100,
                                 help="keep only best n model when training")
        epochs: int = argfield(30, help="Training epochs")
        debug_cache: bool = argfield(False,
                                     help="Use cache file for quick debugging")

        # both train and predict
        output: str = argfield(predict_time=True,
                               predict_default=REQUIRED,
                               help="Output path")
        test: Optional[str] = argfield(default=None,
                                       metavar="FILE",
                                       predict_time=True,
                                       predict_default=REQUIRED,
                                       help="Path of test set")
        model: str = argfield(default="model.",
                              help="Load/Save model file",
                              metavar="FILE",
                              predict_time=True,
                              predict_default=REQUIRED)
        dynet_seed: int = argfield(42, predict_time=True)
        dynet_autobatch: int = argfield(0, predict_time=True)
        dynet_mem: int = argfield(0, predict_time=True)
        dynet_gpus: int = argfield(0, predict_time=True)
        dynet_l2: float = argfield(0.0, predict_time=True)
        weight_decay: float = argfield(0.0, predict_time=True)
        output_scores: bool = argfield(False, predict_time=True)
        data_format: str = argfield("default",
                                    predict_time=True,
                                    help="format of input data")
        # TODO(review): data_format is not restricted to known formats. The
        # earlier argparse definition (kept below) limited it to
        # cls.get_data_formats() with cls.default_data_format_name as default.
        # group.add_argument("--data-format", dest="data_format",
        #                    choices=cls.get_data_formats(),
        #                    default=cls.default_data_format_name)
        # Optional[str]: the default is None, matching bilm_path below.
        bilm_cache: Optional[str] = argfield(None,
                                             metavar="FILE",
                                             predict_time=True,
                                             help="path of elmo cache file")
        bilm_use_cache_only: bool = argfield(
            False,
            predict_time=True,
            help="use elmo in cache file only, do not generate new elmo")
        bilm_path: Optional[str] = argfield(None,
                                            metavar="FILE",
                                            predict_time=True,
                                            help="path of elmo model")
        bilm_stateless: bool = argfield(False,
                                        predict_time=True,
                                        help="only use stateless elmo")
        bilm_gpu: str = argfield("",
                                 predict_time=True,
                                 help="run elmo on these gpu")
        use_exception_handler: bool = argfield(
            False,
            predict_default=False,
            predict_time=True,
            help="useful tools for quick debugging when encountering an error")

        # predict only
        # NOTE: the field name ``eval`` shadows the builtin inside this class
        # body; it is kept because it is part of the options interface.
        eval: bool = argfield(predict_default=False,
                              train_time=False,
                              predict_time=True)
        input_format: str = argfield(
            choices=[
                "standard", "tokenlist", "space", "english", "english-line"
            ],
            # The adjacent string literals are concatenated. A trailing space
            # after "space;" keeps the next sentence from running into it.
            help=
            'Input format. (default)"standard": use the same format of treebank;\n'
            'tokenlist: like [[(sent_1_word1, sent_1_pos1), ...], [...]];\n'
            'space: sentence is separated by newlines, and words are separated by space; '
            'no POSTag info will be used. \n'
            'english: raw english sentence that will be processed by NLTK tokenizer, '
            'no POSTag info will be used.',
            predict_time=True,
            train_time=False,
            predict_default="standard")
Exemplo n.º 12
0
 class Options(BranchSelect.Options):
     """Options for an external contextual embedding.

     ``type`` is one of the keys of ``external_contextual_embeddings`` (default
     ``"none"``). Each plugin has its own options group.
     """
     type: "Embedding Type" = argfield(
         "none",
         choices=list(external_contextual_embeddings.keys()),
     )
     elmo_options: ELMoPlugin.Options = argfield(
         default_factory=ELMoPlugin.Options)
     bert_options: BERTPlugin.Options = argfield(
         default_factory=BERTPlugin.Options)
     xlnet_options: XLNetPlugin.Options = argfield(
         default_factory=XLNetPlugin.Options)
Exemplo n.º 13
0
 class Options(BranchSelect.Options):
     """Character embedding options; ``type`` is one of ``char_embeddings``."""
     type: "Character Embedding Type" = argfield(
         "rnn", choices=char_embeddings)
     # Per-implementation option groups, each a fresh default instance.
     rnn_options: CharLSTMLayer.Options = field(
         default_factory=CharLSTMLayer.Options)
     cnn_options: CharCNNLayer.Options = field(
         default_factory=CharCNNLayer.Options)
     max_char: int = 20