def epoch_evaluate(args, model, loader, puncts):
    """Evaluate the model over every batch in ``loader``.

    Returns a tuple ``(average_loss, metric)`` where the loss is averaged
    over the number of batches.
    """
    model.eval()
    metric = Metric()
    loss_sum = 0
    for words, feats, arcs, rels in loader():
        # Drop the first token of each sentence, then left-pad with the pad
        # index so positions line up again; the re-padded first slot is
        # excluded by the mask below.
        shifted_words = layers.pad(words[:, 1:],
                                   paddings=[0, 0, 1, 0],
                                   pad_value=args.pad_index)
        mask = shifted_words != args.pad_index
        arc_scores, rel_scores = model(words, feats)
        batch_loss = loss_function(arc_scores, rel_scores, arcs, rels, mask)
        arc_preds, rel_preds = decode(args, arc_scores, rel_scores, mask)
        if not args.punct:
            # Exclude punctuation: a position survives only if its word id
            # differs from every id in `puncts`.
            word_cube = layers.expand(layers.unsqueeze(words, -1),
                                      (1, 1, puncts.shape[0]))
            punct_cube = layers.expand(layers.reshape(puncts, (1, 1, -1)),
                                       (*words.shape, 1))
            non_punct = layers.reduce_all(word_cube != punct_cube, dim=-1)
            mask = layers.logical_and(mask, non_punct)
        metric(arc_preds, rel_preds, arcs, rels, mask)
        loss_sum += batch_loss.numpy().item()
    return loss_sum / len(loader), metric
def epoch_evaluate(args, model, loader, punctuation):
    """Evaluate the model for one epoch.

    Iterates every batch produced by ``loader``, accumulates the loss and
    feeds predictions into a ``Metric``. Returns ``(average_loss, metric)``
    with the loss averaged over the number of batches.

    ``punctuation`` is a tensor of punctuation word ids; when
    ``args.punct`` is false, positions holding any of those ids are
    excluded from the metric.
    """
    model.eval()
    total_loss, metric = 0, Metric()
    pad_index = args.pad_index
    bos_index = args.bos_index
    eos_index = args.eos_index
    # NOTE: the previous `enumerate(loader(), start=1)` bound a batch index
    # that was never used; iterate the loader directly instead.
    for inputs in loader():
        # ERNIE-based encoders take no hand-crafted feats; others do.
        if args.encoding_model.startswith("ernie"):
            words, connections, deprel = inputs
            connection_prob, deprel_prob, words = model(words)
        else:
            words, feats, connections, deprel = inputs
            connection_prob, deprel_prob, words = model(words, feats)
        # Score only real tokens: mask out pad, BOS and EOS positions.
        mask = layers.logical_and(
            layers.logical_and(words != pad_index, words != bos_index),
            words != eos_index,
        )
        loss = loss_function(connection_prob, deprel_prob, connections, deprel, mask)
        connection_predict, deprel_predict = decode(args, connection_prob, deprel_prob, mask)
        # ignore all punctuation if not specified
        if not args.punct:
            punct_mask = layers.reduce_all(
                layers.expand(layers.unsqueeze(words, -1),
                              (1, 1, punctuation.shape[0])) !=
                layers.expand(layers.reshape(punctuation, (1, 1, -1)),
                              words.shape + [1]),
                dim=-1)
            mask = layers.logical_and(mask, punct_mask)
        metric(connection_predict, deprel_predict, connections, deprel, mask)
        total_loss += loss.numpy().item()
    total_loss /= len(loader)
    return total_loss, metric
def train(env):
    """Train the parser end to end.

    Loads the train/dev/test corpora from ``env``, builds bucketed datasets
    and loaders, trains for up to ``args.epochs`` epochs with early stopping
    on ``args.patience``, and saves the model whenever the dev metric
    improves. Final test evaluation is run on the best checkpoint.
    """
    args = env.args
    logging.info("loading data.")
    train = Corpus.load(args.train_data_path, env.fields)
    dev = Corpus.load(args.valid_data_path, env.fields)
    test = Corpus.load(args.test_data_path, env.fields)
    logging.info("init dataset.")
    train = TextDataset(train, env.fields, args.buckets)
    dev = TextDataset(dev, env.fields, args.buckets)
    test = TextDataset(test, env.fields, args.buckets)
    logging.info("set the data loaders.")
    train.loader = batchify(train, args.batch_size, args.use_data_parallel, True)
    dev.loader = batchify(dev, args.batch_size)
    test.loader = batchify(test, args.batch_size)
    logging.info(f"{'train:':6} {len(train):5} sentences, "
                 f"{len(train.loader):3} batches, "
                 f"{len(train.buckets)} buckets")
    # BUG FIX: the dev/test lines previously logged len(train.buckets)
    # instead of their own bucket counts.
    logging.info(f"{'dev:':6} {len(dev):5} sentences, "
                 f"{len(dev.loader):3} batches, "
                 f"{len(dev.buckets)} buckets")
    logging.info(f"{'test:':6} {len(test):5} sentences, "
                 f"{len(test.loader):3} batches, "
                 f"{len(test.buckets)} buckets")

    logging.info("Create the model")
    model = Model(args, env.WORD.embed)
    # init parallel strategy
    if args.use_data_parallel:
        strategy = dygraph.parallel.prepare_context()
        model = dygraph.parallel.DataParallel(model, strategy)
    if args.use_cuda:
        grad_clip = fluid.clip.GradientClipByNorm(clip_norm=args.clip)
    else:
        grad_clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=args.clip)
    decay = dygraph.ExponentialDecay(learning_rate=args.lr,
                                     decay_steps=args.decay_steps,
                                     decay_rate=args.decay)
    optimizer = fluid.optimizer.AdamOptimizer(
        learning_rate=decay,
        beta1=args.mu,
        beta2=args.nu,
        epsilon=args.epsilon,
        parameter_list=model.parameters(),
        grad_clip=grad_clip)

    total_time = datetime.timedelta()
    best_e, best_metric = 1, Metric()
    puncts = dygraph.to_variable(env.puncts, zero_copy=False)
    logging.info("start training.")
    for epoch in range(1, args.epochs + 1):
        start = datetime.datetime.now()
        # train one epoch and update the parameter
        logging.info(f"Epoch {epoch} / {args.epochs}:")
        epoch_train(args, model, optimizer, train.loader, epoch)
        # NOTE(review): evaluation, checkpointing and the patience check are
        # nested under rank 0 because `dev_metric` only exists there —
        # confirm this matches the intended multi-process behavior.
        if args.local_rank == 0:
            loss, dev_metric = epoch_evaluate(args, model, dev.loader, puncts)
            logging.info(f"{'dev:':6} Loss: {loss:.4f} {dev_metric}")
            loss, test_metric = epoch_evaluate(args, model, test.loader, puncts)
            logging.info(f"{'test:':6} Loss: {loss:.4f} {test_metric}")
            t = datetime.datetime.now() - start
            # save the model if it is the best so far
            if dev_metric > best_metric and epoch > args.patience // 10:
                best_e, best_metric = epoch, dev_metric
                save(args.model_path, args, model, optimizer)
                logging.info(f"{t}s elapsed (saved)\n")
            else:
                logging.info(f"{t}s elapsed\n")
            total_time += t
            # early stopping: no dev improvement for `patience` epochs
            if epoch - best_e >= args.patience:
                break
    if args.local_rank == 0:
        # reload the best checkpoint and report final scores
        model = load(args.model_path, model)
        loss, metric = epoch_evaluate(args, model, test.loader, puncts)
        logging.info(
            f"max score of dev is {best_metric.score:.2%} at epoch {best_e}")
        logging.info(
            f"the score of test at epoch {best_e} is {metric.score:.2%}")
        logging.info(f"average time of each epoch is {total_time / epoch}s")
        logging.info(f"{total_time}s elapsed")
def train(env):
    """Train the parser end to end (ERNIE-aware variant).

    Loads the corpora from ``env``, builds bucketed datasets and loaders,
    picks the learning-rate schedule, gradient clip and optimizer according
    to ``args.encoding_model``, then trains with early stopping on
    ``args.patience`` and evaluates the best checkpoint on the test set.
    """
    args = env.args
    logging.info("loading data.")
    train = Corpus.load(args.train_data_path, env.fields)
    dev = Corpus.load(args.valid_data_path, env.fields)
    test = Corpus.load(args.test_data_path, env.fields)
    logging.info("init dataset.")
    train = TextDataset(train, env.fields, args.buckets)
    dev = TextDataset(dev, env.fields, args.buckets)
    test = TextDataset(test, env.fields, args.buckets)
    logging.info("set the data loaders.")
    train.loader = batchify(train, args.batch_size, args.use_data_parallel, True)
    dev.loader = batchify(dev, args.batch_size)
    test.loader = batchify(test, args.batch_size)
    logging.info("{:6} {:5} sentences, ".format('train:', len(train)) +
                 "{:3} batches, ".format(len(train.loader)) +
                 "{} buckets".format(len(train.buckets)))
    logging.info("{:6} {:5} sentences, ".format('dev:', len(dev)) +
                 "{:3} batches, ".format(len(dev.loader)) +
                 "{} buckets".format(len(dev.buckets)))
    logging.info("{:6} {:5} sentences, ".format('test:', len(test)) +
                 "{:3} batches, ".format(len(test.loader)) +
                 "{} buckets".format(len(test.buckets)))

    logging.info("Create the model")
    model = Model(args)
    # init parallel strategy
    if args.use_data_parallel:
        dist.init_parallel_env()
        model = paddle.DataParallel(model)

    # Hoisted: this predicate was previously repeated three times inline.
    # True for ERNIE encoders except the "ernie-lstm" hybrid.
    use_ernie = (args.encoding_model.startswith("ernie")
                 and args.encoding_model != "ernie-lstm")

    # Transformer and pure-ERNIE encoders share the (smaller) ernie_lr.
    if use_ernie or args.encoding_model == 'transformer':
        args['lr'] = args.ernie_lr
    else:
        args['lr'] = args.lstm_lr

    if use_ernie:
        # Linear warmup + decay sized to 100 nominal epochs of batches.
        max_steps = 100 * len(train.loader)
        decay = LinearDecay(args.lr,
                            int(args.warmup_proportion * max_steps),
                            max_steps)
        clip = args.ernie_clip
    else:
        decay = dygraph.ExponentialDecay(learning_rate=args.lr,
                                         decay_steps=args.decay_steps,
                                         decay_rate=args.decay)
        clip = args.clip

    if args.use_cuda:
        grad_clip = fluid.clip.GradientClipByNorm(clip_norm=clip)
    else:
        grad_clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=clip)

    if use_ernie:
        optimizer = AdamW(
            learning_rate=decay,
            parameter_list=model.parameters(),
            weight_decay=args.weight_decay,
            grad_clip=grad_clip,
        )
    else:
        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=decay,
            beta1=args.mu,
            beta2=args.nu,
            epsilon=args.epsilon,
            parameter_list=model.parameters(),
            grad_clip=grad_clip,
        )

    total_time = datetime.timedelta()
    best_e, best_metric = 1, Metric()
    puncts = dygraph.to_variable(env.puncts, zero_copy=False)
    logging.info("start training.")
    for epoch in range(1, args.epochs + 1):
        start = datetime.datetime.now()
        # train one epoch and update the parameter
        logging.info("Epoch {} / {}:".format(epoch, args.epochs))
        epoch_train(args, model, optimizer, train.loader, epoch)
        # NOTE(review): evaluation, checkpointing and the patience check are
        # nested under rank 0 because `dev_metric` only exists there —
        # confirm this matches the intended multi-process behavior.
        if args.local_rank == 0:
            loss, dev_metric = epoch_evaluate(args, model, dev.loader, puncts)
            logging.info("{:6} Loss: {:.4f} {}".format('dev:', loss, dev_metric))
            loss, test_metric = epoch_evaluate(args, model, test.loader, puncts)
            logging.info("{:6} Loss: {:.4f} {}".format('test:', loss, test_metric))
            t = datetime.datetime.now() - start
            # save the model if it is the best so far
            if dev_metric > best_metric and epoch > args.patience // 10:
                best_e, best_metric = epoch, dev_metric
                save(args.model_path, args, model, optimizer)
                logging.info("{}s elapsed (saved)\n".format(t))
            else:
                logging.info("{}s elapsed\n".format(t))
            total_time += t
            # early stopping: no dev improvement for `patience` epochs
            if epoch - best_e >= args.patience:
                break
    if args.local_rank == 0:
        # reload the best checkpoint and report final scores
        model = load(args.model_path, model)
        loss, metric = epoch_evaluate(args, model, test.loader, puncts)
        logging.info("max score of dev is {:.2%} at epoch {}".format(
            best_metric.score, best_e))
        logging.info("the score of test at epoch {} is {:.2%}".format(
            best_e, metric.score))
        logging.info("average time of each epoch is {}s".format(total_time / epoch))
        logging.info("{}s elapsed".format(total_time))