def decoder_wrapper(model_file_path: str = "http://dmserv4.cs.illinois.edu/pner0.th", configs: dict = None):
    """
    Wrapper for different decode functions.

    Parameters
    ----------
    model_file_path: ``str``, optional, (default = "http://dmserv4.cs.illinois.edu/pner0.th").
        Path to the checkpoint to load.
    configs: ``dict``, optional, (default = ``None``).
        Additional configs.
    """
    configs = configs if configs is not None else {}  # avoid a shared mutable default argument
    pw = wrapper(configs.get("log_path", None))
    logger.info("Loading model from {} (might download from source if not cached).".format(model_file_path))
    model_file = wrapper.restore_checkpoint(model_file_path)
    # Pick the decode function based on the model type stored in the checkpoint.
    model_type = model_file['config'].get("model_type", 'char-lstm-crf')
    logger.info('Preparing the pre-trained {} model.'.format(model_type))
    model_type_dict = {"char-lstm-crf": decoder_wc, "char-lstm-two-level": decoder_tl}
    return model_type_dict[model_type](model_file, pw, configs)
def decoder_wrapper(model_file_path: str, configs: dict = None):
    """
    Wrapper for different decode functions.

    Parameters
    ----------
    model_file_path: ``str``, required.
        Path to the checkpoint to load.
    configs: ``dict``, optional, (default = ``None``).
        Additional configs.
    """
    configs = configs if configs is not None else {}  # avoid a shared mutable default argument
    pw = wrapper(configs.get("log_path", None))
    pw.set_level(configs.get("log_level", 'info'))
    pw.info("Loading model from {} (might download from source if not cached).".format(model_file_path))
    model_file = wrapper.restore_checkpoint(model_file_path)
    # Here the model type comes from configs rather than from the checkpoint itself.
    model_type = configs.get("model_type", 'char-lstm-crf')
    pw.info('Preparing the pre-trained {} model.'.format(model_type))
    model_type_dict = {"char-lstm-crf": decoder_wc}
    return model_type_dict[model_type](model_file, pw, configs)
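# --- Illustrative usage sketch (not in the original source) ------------------
# How decoder_wrapper might be called; the `decode` method on the returned
# object is an assumption for illustration -- the actual interface is whatever
# decoder_wc / decoder_tl return above.
decoder = decoder_wrapper("http://dmserv4.cs.illinois.edu/pner0.th",
                          configs={"log_path": None})
print(decoder.decode("Barack Obama visited Urbana ."))  # hypothetical call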
else:
    soft_max = AdaptiveSoftmax(rnn_layer.output_dim, cut_off)

lm_model = LM(rnn_layer, soft_max, len(w_map), args.word_dim, args.droprate,
              label_dim=args.label_dim, add_relu=args.add_relu)
lm_model.rand_ini()

pw.info('Building optimizer.')
optim_map = {'Adam': optim.Adam, 'Adagrad': optim.Adagrad, 'Adadelta': optim.Adadelta}
if args.lr > 0:
    optimizer = optim_map[args.update](lm_model.parameters(), lr=args.lr)
else:
    optimizer = optim_map[args.update](lm_model.parameters())  # library default lr

if args.restore_checkpoint:
    if os.path.isfile(args.restore_checkpoint):
        pw.info("Loading checkpoint: '{}'".format(args.restore_checkpoint))
        model_file = wrapper.restore_checkpoint(args.restore_checkpoint)['model']
        lm_model.load_state_dict(model_file, False)  # strict=False: tolerate missing/unexpected keys
    else:
        pw.info("No checkpoint found at: '{}'".format(args.restore_checkpoint))
lm_model.to(device)

pw.info('Saving configs.')
pw.save_configue(args)

pw.info('Setting up the training environment.')
best_train_ppl = float('inf')
cur_lr = args.lr
batch_index = 0
epoch_loss = 0
patience = 0
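# --- Illustrative sketch (not in the original source) ------------------------
# The optimizer-dispatch pattern above in a self-contained form; the Linear
# module and the 'Adam'/0.001 values stand in for lm_model, args.update, and
# args.lr. When lr <= 0, the optimizer falls back to its library default.
import torch
import torch.optim as optim

toy_model = torch.nn.Linear(4, 2)
toy_optim_map = {'Adam': optim.Adam, 'Adagrad': optim.Adagrad, 'Adadelta': optim.Adadelta}
update, lr = 'Adam', 0.001
if lr > 0:
    toy_optimizer = toy_optim_map[update](toy_model.parameters(), lr=lr)
else:
    toy_optimizer = toy_optim_map[update](toy_model.parameters())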
pw.info('Building language models and sequence labeling models.')
rnn_map = {'Basic': BasicRNN, 'DenseNet': DenseRNN, 'LDNet': functools.partial(LDRNN, layer_drop=0)}
flm_rnn_layer = rnn_map[args.lm_rnn_layer](args.lm_layer_num, args.lm_rnn_unit, args.lm_word_dim,
                                           args.lm_hid_dim, args.lm_droprate)
blm_rnn_layer = rnn_map[args.lm_rnn_layer](args.lm_layer_num, args.lm_rnn_unit, args.lm_word_dim,
                                           args.lm_hid_dim, args.lm_droprate)
flm_model = LM(flm_rnn_layer, None, len(flm_map), args.lm_word_dim, args.lm_droprate, label_dim=args.lm_label_dim)
blm_model = LM(blm_rnn_layer, None, len(blm_map), args.lm_word_dim, args.lm_droprate, label_dim=args.lm_label_dim)
flm_model_seq = SparseSeqLM(flm_model, False, args.lm_droprate, False)  # forward LM
blm_model_seq = SparseSeqLM(blm_model, True, args.lm_droprate, False)   # backward LM

SL_map = {'vanilla': Vanilla_SeqLabel, 'lm-aug': SeqLabel}
seq_model = SL_map[args.seq_model](flm_model_seq, blm_model_seq, len(c_map), args.seq_c_dim,
                                   args.seq_c_hid, args.seq_c_layer, len(gw_map), args.seq_w_dim,
                                   args.seq_w_hid, args.seq_w_layer, len(y_map), args.seq_droprate,
                                   unit=args.seq_rnn_unit)

pw.info('Loading pre-trained models from {}.'.format(args.load_seq))
seq_file = wrapper.restore_checkpoint(args.load_seq)['model']
seq_model.load_state_dict(seq_file)
seq_model.to(device)

crit = CRFLoss(y_map)
decoder = CRFDecode(y_map)
evaluator = eval_wc(decoder, 'f1')

pw.info('Constructing dataset.')
train_dataset, test_dataset, dev_dataset = [
    SeqDataset(tup_data, flm_map['\n'], blm_map['\n'], gw_map['<\n>'], c_map[' '], c_map['\n'],
               y_map['<s>'], y_map['<eof>'], len(y_map), args.batch_size)
    for tup_data in [train_data, test_data, dev_data]]

pw.info('Constructing optimizer.')
param_dict = filter(lambda t: t.requires_grad, seq_model.parameters())  # only trainable params
optim_map = {'Adam': optim.Adam, 'Adagrad': optim.Adagrad, 'Adadelta': optim.Adadelta,
             'SGD': functools.partial(optim.SGD, momentum=0.9)}
if args.lr > 0:
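# --- Illustrative sketch (not in the original source) ------------------------
# Why functools.partial appears in rnn_map and optim_map above: it pins extra
# keyword arguments (layer_drop=0, momentum=0.9) so every entry in the map can
# be constructed with the same positional call. ToyCell is a stand-in class.
import functools

class ToyCell:
    def __init__(self, layer_num, unit, layer_drop=0.5):
        self.layer_num, self.unit, self.layer_drop = layer_num, unit, layer_drop

toy_rnn_map = {'Basic': ToyCell, 'LDNet': functools.partial(ToyCell, layer_drop=0)}
cell = toy_rnn_map['LDNet'](2, 'lstm')  # same call shape as toy_rnn_map['Basic'](2, 'lstm')
assert cell.layer_drop == 0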
torch.cuda.set_device(gpu_index)

logger.info('Loading data')
dataset = pickle.load(open(args.corpus, 'rb'))
name_list = ['flm_map', 'blm_map', 'gw_map', 'c_map', 'y_map', 'emb_array',
             'train_data', 'test_data', 'dev_data']
flm_map, blm_map, gw_map, c_map, y_map, emb_array, train_data, test_data, dev_data = \
    [dataset[tup] for tup in name_list]

logger.info('Loading language model')
rnn_map = {'Basic': BasicRNN}
flm_rnn_layer = rnn_map[args.lm_rnn_layer](args.lm_layer_num, args.lm_rnn_unit, args.lm_word_dim,
                                           args.lm_hid_dim, args.lm_droprate)
blm_rnn_layer = rnn_map[args.lm_rnn_layer](args.lm_layer_num, args.lm_rnn_unit, args.lm_word_dim,
                                           args.lm_hid_dim, args.lm_droprate)
flm_model = LM(flm_rnn_layer, None, len(flm_map), args.lm_word_dim, args.lm_droprate, label_dim=args.lm_label_dim)
blm_model = LM(blm_rnn_layer, None, len(blm_map), args.lm_word_dim, args.lm_droprate, label_dim=args.lm_label_dim)
flm_file = wrapper.restore_checkpoint(args.forward_lm)['model']
flm_model.load_state_dict(flm_file, False)  # strict=False: tolerate missing/unexpected keys
blm_file = wrapper.restore_checkpoint(args.backward_lm)['model']
blm_model.load_state_dict(blm_file, False)
flm_model_seq = ElmoLM(flm_model, False, args.lm_droprate, True)  # forward LM
blm_model_seq = ElmoLM(blm_model, True, args.lm_droprate, True)   # backward LM

logger.info('Building model')
SL_map = {'vanilla': Vanilla_SeqLabel, 'lm-aug': SeqLabel}
seq_model = SL_map[args.seq_model](flm_model_seq, blm_model_seq, len(c_map), args.seq_c_dim,
                                   args.seq_c_hid, args.seq_c_layer, len(gw_map), args.seq_w_dim,
                                   args.seq_w_hid, args.seq_w_layer, len(y_map), args.seq_droprate,
                                   unit=args.seq_rnn_unit)
seq_model.rand_init()
seq_model.load_pretrained_word_embedding(torch.FloatTensor(emb_array))
seq_model.to(device)

crit = CRFLoss(y_map)
decoder = CRFDecode(y_map)
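# --- Illustrative sketch (not in the original source) ------------------------
# The load_state_dict(..., False) calls above pass strict=False, which lets a
# partial checkpoint load without raising on missing or unexpected keys. A
# self-contained demonstration with a stand-in module:
import torch

toy_net = torch.nn.Linear(4, 2)
partial_state = {'weight': torch.zeros(2, 4)}  # no 'bias' entry
result = toy_net.load_state_dict(partial_state, strict=False)
print(result.missing_keys)  # ['bias'] -- tolerated instead of raising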