Code example #1
File: decoder.py  Project: maoxl9622/LightNER
def decoder_wrapper(
        model_file_path: str = "http://dmserv4.cs.illinois.edu/pner0.th",
        configs: dict = {}):
    """
    Wrapper for different decode functions.

    Parameters
    ----------
    model_file_path: ``str``, optional, (default = "http://dmserv4.cs.illinois.edu/pner0.th").
        Path to the checkpoint to load.
    configs: ``dict``, optional, (default = {}).
        Additional configs.
    """
    pw = wrapper(configs.get("log_path", None))

    logger.info(
        "Loading model from {} (might download from source if not cached).".
        format(model_file_path))
    model_file = wrapper.restore_checkpoint(model_file_path)

    model_type = model_file['config'].get("model_type", 'char-lstm-crf')
    logger.info('Preparing the pre-trained {} model.'.format(model_type))
    model_type_dict = {
            "char-lstm-crf": decoder_wc,
            "char-lstm-two-level": decoder_tl}
    return model_type_dict[model_type](model_file, pw, configs)
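For context, a minimal usage sketch of this entry point. The import path and the decode method on the returned object are assumptions here, not shown in the snippet itself:

from lightner import decoder_wrapper  # import path is an assumption

# Downloads the default checkpoint from the URL above if not cached.
model = decoder_wrapper()
# The returned decoder is assumed to expose a decode method over tokens;
# the exact return format depends on the decoder class.
print(model.decode("Ronaldo won the golden boot .".split()))
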
Code example #2
File: decoder.py  Project: fendaq/LightNER
def decoder_wrapper(model_file_path: str, configs: dict = {}):
    """
    Wrapper for different decode functions.

    Parameters
    ----------
    model_file_path: ``str``, required.
        Path to the checkpoint to load.
    configs: ``dict``, optional, (default = {}).
        Additional configs.
    """

    pw = wrapper(configs.get("log_path", None))
    pw.set_level(configs.get("log_level", 'info'))

    pw.info(
        "Loading model from {} (might download from source if not cached).".
        format(model_file_path))
    model_file = wrapper.restore_checkpoint(model_file_path)

    model_type = configs.get("model_type", 'char-lstm-crf')
    pw.info('Preparing the pre-trained {} model.'.format(model_type))
    model_type_dict = {"char-lstm-crf": decoder_wc}
    return model_type_dict[model_type](model_file, pw, configs)
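Both variants declare `configs: dict = {}`. The dict is only read here, so the shared mutable default is harmless in practice, but the conventional None-sentinel idiom avoids the pitfall entirely; a generic sketch (the body is illustrative only):

def decoder_wrapper(model_file_path: str, configs: dict = None):
    # A None sentinel gives each call its own dict, instead of one default
    # object shared (and possibly mutated) across all calls.
    if configs is None:
        configs = {}
    return configs.get("log_path")  # illustrative body only
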
Code example #3
File: train_lm.py  Project: zpppy/LD-Net
    else:
        soft_max = AdaptiveSoftmax(rnn_layer.output_dim, cut_off)
    lm_model = LM(rnn_layer, soft_max, len(w_map), args.word_dim, args.droprate, label_dim = args.label_dim, add_relu=args.add_relu)
    lm_model.rand_ini()

    pw.info('Building optimizer.')
    optim_map = {'Adam' : optim.Adam, 'Adagrad': optim.Adagrad, 'Adadelta': optim.Adadelta}
    if args.lr > 0:
        optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr)
    else:
        optimizer = optim_map[args.update](lm_model.parameters())

    if args.restore_checkpoint:
        if os.path.isfile(args.restore_checkpoint):
            pw.info("loading checkpoint: '{}'".format(args.restore_checkpoint))
            model_file = wrapper.restore_checkpoint(args.restore_checkpoint)['model']
            lm_model.load_state_dict(model_file, strict=False)
        else:
            pw.info("no checkpoint found at: '{}'".format(args.restore_checkpoint))
    lm_model.to(device)

    pw.info('Saving configs.')
    pw.save_configue(args)

    pw.info('Setting up training environment.')
    best_train_ppl = float('inf')
    cur_lr = args.lr
    batch_index = 0
    epoch_loss = 0
    patience = 0
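The bookkeeping variables above (best_train_ppl, cur_lr, patience) suggest a training loop with learning-rate decay on plateau. A hedged sketch of that pattern, not the project's actual loop; train_one_epoch, args.epoch, args.patience, and args.lr_decay are illustrative assumptions:

import math

for epoch in range(args.epoch):
    # train_one_epoch is a hypothetical helper returning the mean loss
    epoch_loss = train_one_epoch(lm_model, optimizer)
    train_ppl = math.exp(epoch_loss)  # perplexity from mean cross-entropy
    if train_ppl < best_train_ppl:
        best_train_ppl, patience = train_ppl, 0
    else:
        patience += 1
    if patience >= args.patience:     # plateau: decay the learning rate
        cur_lr *= args.lr_decay
        for group in optimizer.param_groups:
            group['lr'] = cur_lr
        patience = 0
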
Code example #4
    pw.info('Building language models and sequence labeling models.')

    rnn_map = {'Basic': BasicRNN, 'DenseNet': DenseRNN, 'LDNet': functools.partial(LDRNN, layer_drop = 0)}
    flm_rnn_layer = rnn_map[args.lm_rnn_layer](args.lm_layer_num, args.lm_rnn_unit, args.lm_word_dim, args.lm_hid_dim, args.lm_droprate)
    blm_rnn_layer = rnn_map[args.lm_rnn_layer](args.lm_layer_num, args.lm_rnn_unit, args.lm_word_dim, args.lm_hid_dim, args.lm_droprate)
    flm_model = LM(flm_rnn_layer, None, len(flm_map), args.lm_word_dim, args.lm_droprate, label_dim = args.lm_label_dim)
    blm_model = LM(blm_rnn_layer, None, len(blm_map), args.lm_word_dim, args.lm_droprate, label_dim = args.lm_label_dim)
    flm_model_seq = SparseSeqLM(flm_model, False, args.lm_droprate, False)
    blm_model_seq = SparseSeqLM(blm_model, True, args.lm_droprate, False)
    SL_map = {'vanilla': Vanilla_SeqLabel, 'lm-aug': SeqLabel}
    seq_model = SL_map[args.seq_model](flm_model_seq, blm_model_seq, len(c_map),
            args.seq_c_dim, args.seq_c_hid, args.seq_c_layer, len(gw_map),
            args.seq_w_dim, args.seq_w_hid, args.seq_w_layer, len(y_map),
            args.seq_droprate, unit=args.seq_rnn_unit)

    pw.info('Loading pre-trained models from {}.'.format(args.load_seq))

    seq_file = wrapper.restore_checkpoint(args.load_seq)['model']
    seq_model.load_state_dict(seq_file)
    seq_model.to(device)
    crit = CRFLoss(y_map)
    decoder = CRFDecode(y_map)
    evaluator = eval_wc(decoder, 'f1')

    pw.info('Constructing dataset.')

    train_dataset, test_dataset, dev_dataset = [
        SeqDataset(tup_data, flm_map['\n'], blm_map['\n'], gw_map['<\n>'],
                   c_map[' '], c_map['\n'], y_map['<s>'], y_map['<eof>'],
                   len(y_map), args.batch_size)
        for tup_data in [train_data, test_data, dev_data]]

    pw.info('Constructing optimizer.')

    param_dict = filter(lambda t: t.requires_grad, seq_model.parameters())
    optim_map = {'Adam' : optim.Adam, 'Adagrad': optim.Adagrad, 'Adadelta': optim.Adadelta, 'SGD': functools.partial(optim.SGD, momentum=0.9)}
    if args.lr > 0:
        optimizer = optim_map[args.update](param_dict, lr=args.lr)
    else:
        optimizer = optim_map[args.update](param_dict)
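(The two optimizer lines above complete the truncated snippet following the same pattern as example #3, using the param_dict filtered just before.) The `functools.partial(optim.SGD, momentum=0.9)` entry is what lets SGD share the same `(params, lr=...)` construction call as the other optimizers, which plain SGD could not, since it has no default momentum baked in. A self-contained illustration:

import functools

import torch
import torch.optim as optim

params = [torch.nn.Parameter(torch.zeros(2, 2))]
optim_map = {'Adam': optim.Adam,
             'SGD': functools.partial(optim.SGD, momentum=0.9)}
# Both entries now construct with the same call shape.
opt = optim_map['SGD'](params, lr=0.015)
print(type(opt).__name__, opt.defaults['momentum'])  # SGD 0.9
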
Code example #5
File: train_seq_elmo.py  Project: xtang27/LD-Net
        torch.cuda.set_device(gpu_index)

    logger.info('Loading data')

    dataset = pickle.load(open(args.corpus, 'rb'))
    name_list = ['flm_map', 'blm_map', 'gw_map', 'c_map', 'y_map', 'emb_array', 'train_data', 'test_data', 'dev_data']
    flm_map, blm_map, gw_map, c_map, y_map, emb_array, train_data, test_data, dev_data = [dataset[tup] for tup in name_list]

    logger.info('Loading language model')

    rnn_map = {'Basic': BasicRNN}
    flm_rnn_layer = rnn_map[args.lm_rnn_layer](args.lm_layer_num, args.lm_rnn_unit, args.lm_word_dim, args.lm_hid_dim, args.lm_droprate)
    blm_rnn_layer = rnn_map[args.lm_rnn_layer](args.lm_layer_num, args.lm_rnn_unit, args.lm_word_dim, args.lm_hid_dim, args.lm_droprate)
    flm_model = LM(flm_rnn_layer, None, len(flm_map), args.lm_word_dim, args.lm_droprate, label_dim = args.lm_label_dim)
    blm_model = LM(blm_rnn_layer, None, len(blm_map), args.lm_word_dim, args.lm_droprate, label_dim = args.lm_label_dim)
    flm_file = wrapper.restore_checkpoint(args.forward_lm)['model']
    flm_model.load_state_dict(flm_file, strict=False)
    blm_file = wrapper.restore_checkpoint(args.backward_lm)['model']
    blm_model.load_state_dict(blm_file, strict=False)
    flm_model_seq = ElmoLM(flm_model, False, args.lm_droprate, True)
    blm_model_seq = ElmoLM(blm_model, True, args.lm_droprate, True)

    logger.info('Building model')

    SL_map = {'vanilla': Vanilla_SeqLabel, 'lm-aug': SeqLabel}
    seq_model = SL_map[args.seq_model](flm_model_seq, blm_model_seq, len(c_map),
            args.seq_c_dim, args.seq_c_hid, args.seq_c_layer, len(gw_map),
            args.seq_w_dim, args.seq_w_hid, args.seq_w_layer, len(y_map),
            args.seq_droprate, unit=args.seq_rnn_unit)
    seq_model.rand_init()
    seq_model.load_pretrained_word_embedding(torch.FloatTensor(emb_array))
    seq_model.to(device)
    crit = CRFLoss(y_map)
    decoder = CRFDecode(y_map)
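Throughout these snippets the language models are restored with `load_state_dict(..., strict=False)`, which loads the keys that match and tolerates the rest, useful when a checkpoint carries parts the current module dropped, such as a pre-training softmax head. A minimal PyTorch demonstration of that behavior:

import torch.nn as nn

src = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))  # net with an extra head
dst = nn.Sequential(nn.Linear(4, 4))                   # same net, head removed
# strict=False loads matching keys and reports mismatches instead of raising.
result = dst.load_state_dict(src.state_dict(), strict=False)
print(result.unexpected_keys)  # ['1.weight', '1.bias']
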