def soft_pattern_arg_parser():
    """
    Build the argument parser holding the SoftPatternsClassifier CLI options.

    The parser is created with ``add_help=False`` and takes the parsers
    returned by ``lstm_arg_parser()`` and ``mlp_arg_parser()`` as parents, so
    their options are included alongside the ones registered here.

    Returns:
        The configured ``ArgumentParser``.
    """
    inherited = [lstm_arg_parser(), mlp_arg_parser()]
    parser = ArgumentParser(add_help=False, parents=inherited)

    # Model layout.
    parser.add_argument(
        "-u", "--use_rnn",
        action="store_true",
        help="Use an RNN underneath soft-patterns",
    )
    parser.add_argument(
        "-p", "--patterns",
        default="5-50_4-50_3-50_2-50",
        help="Pattern lengths and numbers: an underscore separated list of length-number pairs",
    )

    # Semiring selection; both flags are off unless given on the command line.
    parser.add_argument(
        "--maxplus",
        action="store_true",
        default=False,
        help="Use max-plus semiring instead of plus-times",
    )
    parser.add_argument(
        "--maxtimes",
        action="store_true",
        default=False,
        help="Use max-times semiring instead of plus-times",
    )

    # Float-valued scaling knobs.
    parser.add_argument(
        "--bias_scale_param",
        type=float,
        default=0.1,
        help="Scale bias term by this parameter",
    )
    parser.add_argument(
        "--eps_scale",
        type=float,
        default=None,
        help="Scale epsilon by this parameter",
    )
    parser.add_argument(
        "--self_loop_scale",
        type=float,
        default=None,
        help="Scale self_loop by this parameter",
    )

    # Switches that disable transition types.
    parser.add_argument(
        "--no_eps",
        action="store_true",
        help="Don't use epsilon transitions",
    )
    parser.add_argument(
        "--no_sl",
        action="store_true",
        help="Don't use self loops",
    )

    # The help text embeds the numeric codes of the two sharing modes.
    shared_sl_help = "".join([
        "Share main path and self loop parameters, where self loops are discounted by a self_loop_parameter. ",
        str(SHARED_SL_PARAM_PER_STATE_PER_PATTERN),
        ": one parameter per state per pattern, ",
        str(SHARED_SL_SINGLE_PARAM),
        ": a global parameter.",
    ])
    parser.add_argument(
        "--shared_sl",
        type=int,
        default=0,
        help=shared_sl_help,
    )

    return parser
# NOTE(review): this line is whitespace-collapsed and begins mid-function; the
# enclosing `def` is outside this view, so the original nesting cannot be
# recovered here and the code below is left untouched.
# Visible steps, in order:
#   1. when `model_save_dir` is not None and does not exist yet, create it with
#      `os.makedirs`;
#   2. print "Training with" followed by `model_file_prefix`;
#   3. call `train(...)` with the data sets, the model, `num_classes`,
#      `model_save_dir`, `model_file_prefix`, `dropout`, and the training
#      options read from `args`;
#   4. under the `__main__` guard, build an `argparse.ArgumentParser` (module
#      docstring as description, defaults shown in help) whose parents are the
#      lstm, mlp, training and general parsers, and pass the parsed arguments
#      to `main`.
# TODO confirm whether steps 2-3 sit inside or after the `model_save_dir` check.
if model_save_dir is not None: if not os.path.exists(model_save_dir): os.makedirs(model_save_dir) print("Training with", model_file_prefix) train(train_data, dev_data, model, num_classes, model_save_dir, args.num_iterations, model_file_prefix, args.learning_rate, args.batch_size, args.scheduler, gpu=args.gpu, clip=args.clip, debug=args.debug, dropout=dropout, word_dropout=args.word_dropout, patience=args.patience) if __name__ == '__main__': parser = \ argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter, parents=[lstm_arg_parser(), mlp_arg_parser(), training_arg_parser(), general_arg_parser()]) main(parser.parse_args())