def load(self, path): """Load from disk Parameters ---------- path : str path to the directory which typically contains a config.pkl file and a model.bin file Returns ------- DepParser parser itself """ config = _Config.load(os.path.join(path, 'config.pkl')) config.save_dir = path # redirect root path to what user specified self._vocab = vocab = ParserVocabulary.load(config.save_vocab_path) with mx.Context(mxnet_prefer_gpu()): self._parser = BiaffineParser( vocab, config.word_dims, config.tag_dims, config.dropout_emb, config.lstm_layers, config.lstm_hiddens, config.dropout_lstm_input, config.dropout_lstm_hidden, config.mlp_arc_size, config.mlp_rel_size, config.dropout_mlp, config.debug) self._parser.load(config.save_model_path) return self
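# A minimal usage sketch for load(); the directory path below is a hypothetical
# example and must contain the files written by train() (config.pkl, the saved
# vocabulary and the model weights). load() returns the parser itself, so
# construction and loading can be chained:
#
#     parser = DepParser().load('models/biaffine-en')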
def train(self, train_file, dev_file, test_file, save_dir, pretrained_embeddings=None,
          min_occur_count=2, lstm_layers=3, word_dims=100, tag_dims=100, dropout_emb=0.33,
          lstm_hiddens=400, dropout_lstm_input=0.33, dropout_lstm_hidden=0.33,
          mlp_arc_size=500, mlp_rel_size=100, dropout_mlp=0.33, learning_rate=2e-3,
          decay=.75, decay_steps=5000, beta_1=.9, beta_2=.9, epsilon=1e-12,
          num_buckets_train=40, num_buckets_valid=10, num_buckets_test=10,
          train_iters=50000, train_batch_size=5000, test_batch_size=5000,
          validate_every=100, save_after=5000, debug=False):
    """Train a deep biaffine dependency parser.

    Parameters
    ----------
    train_file : str
        path to training set
    dev_file : str
        path to dev set
    test_file : str
        path to test set
    save_dir : str
        directory for saving the model and related meta-data
    pretrained_embeddings : tuple
        (embedding_name, source), used for gluonnlp.embedding.create(embedding_name, source)
    min_occur_count : int
        threshold for rare words; words occurring fewer times are replaced with UNK
    lstm_layers : int
        number of LSTM layers
    word_dims : int
        dimension of the word embeddings
    tag_dims : int
        dimension of the tag embeddings
    dropout_emb : float
        word dropout
    lstm_hiddens : int
        size of the LSTM hidden states
    dropout_lstm_input : float
        dropout on x in the variational RNN
    dropout_lstm_hidden : float
        dropout on h in the variational RNN
    mlp_arc_size : int
        output size of the MLP for arc feature extraction
    mlp_rel_size : int
        output size of the MLP for rel feature extraction
    dropout_mlp : float
        dropout on the outputs of the MLPs
    learning_rate : float
        learning rate
    decay : float
        see ExponentialScheduler
    decay_steps : int
        see ExponentialScheduler
    beta_1 : float
        first-moment decay rate of the Adam optimizer
    beta_2 : float
        second-moment decay rate of the Adam optimizer
    epsilon : float
        numerical-stability constant of the Adam optimizer
    num_buckets_train : int
        number of buckets for the training set
    num_buckets_valid : int
        number of buckets for the dev set
    num_buckets_test : int
        number of buckets for the test set
    train_iters : int
        number of training iterations
    train_batch_size : int
        training batch size
    test_batch_size : int
        test batch size
    validate_every : int
        validate on the dev set every this many batches
    save_after : int
        skip saving the model during the first save_after iterations
    debug : bool
        debug mode

    Returns
    -------
    DepParser
        parser itself
    """
    logger = init_logger(save_dir)
    config = _Config(train_file, dev_file, test_file, save_dir, pretrained_embeddings,
                     min_occur_count, lstm_layers, word_dims, tag_dims, dropout_emb,
                     lstm_hiddens, dropout_lstm_input, dropout_lstm_hidden, mlp_arc_size,
                     mlp_rel_size, dropout_mlp, learning_rate, decay, decay_steps,
                     beta_1, beta_2, epsilon, num_buckets_train, num_buckets_valid,
                     num_buckets_test, train_iters, train_batch_size, debug)
    config.save()
    self._vocab = vocab = ParserVocabulary(train_file, pretrained_embeddings, min_occur_count)
    vocab.save(config.save_vocab_path)
    vocab.log_info(logger)
    with mx.Context(mxnet_prefer_gpu()):
        self._parser = parser = BiaffineParser(vocab, word_dims, tag_dims, dropout_emb,
                                               lstm_layers, lstm_hiddens, dropout_lstm_input,
                                               dropout_lstm_hidden, mlp_arc_size, mlp_rel_size,
                                               dropout_mlp, debug)
        parser.initialize()
        scheduler = ExponentialScheduler(learning_rate, decay, decay_steps)
        optimizer = mx.optimizer.Adam(learning_rate, beta_1, beta_2, epsilon,
                                      lr_scheduler=scheduler)
        trainer = gluon.Trainer(parser.collect_params(), optimizer=optimizer)
        data_loader = DataLoader(train_file, num_buckets_train, vocab)
        global_step = 0
        best_UAS = 0.
        batch_id = 0
        epoch = 1
        total_epoch = math.ceil(train_iters / validate_every)
        logger.info('Epoch %d out of %d', epoch, total_epoch)
        bar = Progbar(target=min(validate_every, data_loader.samples))
        while global_step < train_iters:
            for words, tags, arcs, rels in data_loader.get_batches(batch_size=train_batch_size,
                                                                   shuffle=True):
                with autograd.record():
                    arc_accuracy, _, _, loss = parser.forward(words, tags, arcs, rels)
                    loss_value = loss.asscalar()
                loss.backward()
                trainer.step(train_batch_size)
                batch_id += 1
                try:
                    bar.update(batch_id, exact=[('UAS', arc_accuracy, 2),
                                                ('loss', loss_value)])
                except OverflowError:
                    pass  # the loss can occasionally be 0 or infinity, which crashes the bar
                global_step += 1
                if global_step % validate_every == 0:
                    bar = Progbar(target=min(validate_every, train_iters - global_step))
                    batch_id = 0
                    UAS, LAS, speed = evaluate_official_script(parser, vocab, num_buckets_valid,
                                                               test_batch_size, dev_file,
                                                               os.path.join(save_dir, 'valid_tmp'))
                    logger.info('Dev: UAS %.2f%% LAS %.2f%% %d sents/s', UAS, LAS, speed)
                    epoch += 1
                    if global_step < train_iters:
                        logger.info('Epoch %d out of %d', epoch, total_epoch)
                    if global_step > save_after and UAS > best_UAS:
                        logger.info('- new best score!')
                        best_UAS = UAS
                        parser.save(config.save_model_path)
        # Ensure a model reaches disk even when validate_every is so large that
        # no checkpoint was saved inside the loop
        if not os.path.isfile(config.save_model_path) or best_UAS != UAS:
            parser.save(config.save_model_path)
    return self
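# A minimal end-to-end sketch of train() followed by load(). This assumes these
# methods live on a DepParser class (per the docstrings above) and that the
# CoNLL-format data paths below exist; all file and directory names here are
# hypothetical examples. The pretrained_embeddings tuple follows the documented
# gluonnlp.embedding.create(embedding_name, source) convention.
if __name__ == '__main__':
    parser = DepParser()
    # Train with mostly default hyper-parameters from the signature above
    parser.train(train_file='data/ptb/train.conllx',
                 dev_file='data/ptb/dev.conllx',
                 test_file='data/ptb/test.conllx',
                 save_dir='models/biaffine-en',
                 pretrained_embeddings=('glove', 'glove.6B.100d'))
    # Reload the best checkpoint written during training
    parser = DepParser().load('models/biaffine-en')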