def add_arguments_conformer_encoder_common(group):
    """Define common arguments for conformer encoder.

    The RNN-encoder common options and the conformer common options are
    attached first. Both helpers are defined elsewhere, and each helper's
    return value replaces ``group`` before the next step. The
    transformer-specific options listed below are then registered in a fixed
    order.

    Args:
        group: Object exposing ``add_argument`` (e.g. an argparse argument
            group). It is threaded through the two helpers above.

    Returns:
        The group returned by the helper chain, with the extra options added.
    """
    group = add_arguments_rnn_encoder_common(group)
    group = add_arguments_conformer_common(group)

    # (flag, add_argument keyword arguments) pairs, registered in this order.
    option_table = (
        (
            "--transformer-adim",
            dict(
                default=320,
                type=int,
                help="Number of attention transformation dimensions",
            ),
        ),
        (
            "--transformer-aheads",
            dict(
                default=4,
                type=int,
                help="Number of heads for multi head attention",
            ),
        ),
        (
            # None is the default; per the help text, --dropout-rate applies then.
            "--transformer-attn-dropout-rate",
            dict(
                default=None,
                type=float,
                help="dropout in transformer attention. use --dropout-rate if None is set",
            ),
        ),
        (
            "--transformer-input-layer",
            dict(
                type=str,
                default="conv2d",
                choices=["conv2d", "linear", "embed"],
                help="transformer input layer type",
            ),
        ),
        (
            "--transformer-encoder-selfattn-layer-type",
            dict(
                type=str,
                default="selfattn",
                choices=[
                    "selfattn",
                    "rel_selfattn",
                    "lightconv",
                    "lightconv2d",
                    "dynamicconv",
                    "dynamicconv2d",
                    "light-dynamicconv2d",
                ],
                help="transformer encoder self-attention layer type",
            ),
        ),
        (
            "--transformer-lr",
            dict(
                default=10.0,
                type=float,
                help="Initial value of learning rate",
            ),
        ),
        (
            "--transformer-warmup-steps",
            dict(
                default=25000,
                type=int,
                help="optimizer warmup steps",
            ),
        ),
    )
    for flag, settings in option_table:
        group.add_argument(flag, **settings)
    return group
def add_maskctc_arguments(parser):
    """Add arguments for maskctc model.

    Creates a "maskctc specific setting" argument group on ``parser`` holding
    the ``--maskctc-use-conformer-encoder`` flag. The flag is converted with
    ``strtobool`` and defaults to ``False``. The group is then passed to
    ``add_arguments_conformer_common``, which is defined elsewhere.

    Args:
        parser: Object exposing ``add_argument_group`` (e.g. an
            ``argparse.ArgumentParser``).

    Returns:
        The same ``parser`` object, so calls can be chained.
    """
    maskctc_group = parser.add_argument_group("maskctc specific setting")
    maskctc_group.add_argument(
        "--maskctc-use-conformer-encoder",
        type=strtobool,
        default=False,
    )
    # The helper's return value is not needed: only ``parser`` is returned.
    add_arguments_conformer_common(maskctc_group)
    return parser
def add_conformer_arguments(parser):
    """Add arguments for conformer model.

    Creates a "conformer model specific setting" argument group on
    ``parser`` and passes it through ``add_arguments_conformer_common``
    (defined elsewhere). On the group that helper returns, it then adds
    ``--transformer-encoder-selfattn-layer-type``, restricted to
    ``selfattn`` / ``rel_selfattn`` with ``rel_selfattn`` as the default.

    NOTE(review): another function named ``add_conformer_arguments`` is
    defined immediately after this one. If both live in the same scope, that
    later definition replaces this one and this body is never called.
    Confirm which variant is intended and remove the other.

    Args:
        parser: Object exposing ``add_argument_group`` (e.g. an
            ``argparse.ArgumentParser``).

    Returns:
        The same ``parser`` object.
    """
    conformer_group = add_arguments_conformer_common(
        parser.add_argument_group("conformer model specific setting")
    )
    selfattn_choices = [
        "selfattn",
        "rel_selfattn",
    ]
    conformer_group.add_argument(
        "--transformer-encoder-selfattn-layer-type",
        type=str,
        default="rel_selfattn",
        choices=selfattn_choices,
        help="transformer encoder self-attention layer type",
    )
    return parser
def add_conformer_arguments(parser):
    """Add arguments for conformer model.

    Creates a "conformer model specific setting" argument group on
    ``parser`` and hands it to ``add_arguments_conformer_common`` (defined
    elsewhere). No other options are added here.

    NOTE(review): a function with this same name is defined immediately
    before this one. When both share a scope, this later definition is the
    one that ends up bound to the name. Confirm that dropping the earlier
    variant's ``--transformer-encoder-selfattn-layer-type`` option is
    intended.

    Args:
        parser: Object exposing ``add_argument_group`` (e.g. an
            ``argparse.ArgumentParser``).

    Returns:
        The same ``parser`` object.
    """
    add_arguments_conformer_common(
        parser.add_argument_group("conformer model specific setting")
    )
    return parser