Esempio n. 1
0
def epoch_evaluate(args, model, loader, puncts):
    """Run a full evaluation pass over the loader.

    Returns the batch-averaged loss and the accumulated attachment Metric.
    """
    model.eval()

    cumulative_loss = 0
    metric = Metric()

    for words, feats, arcs, rels in loader():
        # Drop the first token of every sentence, then re-pad on the right
        # so the tensor keeps its width; padded slots fall out of the mask.
        shifted = layers.pad(words[:, 1:],
                             paddings=[0, 0, 1, 0],
                             pad_value=args.pad_index)
        mask = shifted != args.pad_index

        s_arc, s_rel = model(words, feats)
        loss = loss_function(s_arc, s_rel, arcs, rels, mask)
        arc_preds, rel_preds = decode(args, s_arc, s_rel, mask)

        # Unless punctuation is explicitly scored, exclude punctuation
        # positions: a token is kept only when it differs from every id
        # listed in `puncts`.
        if not args.punct:
            words_tiled = layers.expand(layers.unsqueeze(words, -1),
                                        (1, 1, puncts.shape[0]))
            puncts_tiled = layers.expand(layers.reshape(puncts, (1, 1, -1)),
                                         (*words.shape, 1))
            punct_mask = layers.reduce_all(words_tiled != puncts_tiled,
                                           dim=-1)
            mask = layers.logical_and(mask, punct_mask)

        metric(arc_preds, rel_preds, arcs, rels, mask)
        cumulative_loss += loss.numpy().item()

    cumulative_loss /= len(loader)

    return cumulative_loss, metric
Esempio n. 2
0
def epoch_evaluate(args, model, loader, punctuation):
    """Evaluate in one epoch.

    Args:
        args: run configuration; must supply ``pad_index``, ``bos_index``,
            ``eos_index``, ``encoding_model`` and the ``punct`` flag.
        model: the parser model; switched to eval mode here.
        loader: callable yielding batches; also supports ``len()`` so the
            loss can be averaged over batches.
        punctuation: 1-D tensor of punctuation token ids to mask out.

    Returns:
        Tuple of (average loss over batches, accumulated Metric).
    """
    model.eval()
    total_loss, metric = 0, Metric()
    pad_index = args.pad_index
    bos_index = args.bos_index
    eos_index = args.eos_index

    # Iterate batches directly; the previous enumerate() bound an index
    # that was never used.
    for inputs in loader():
        # ERNIE-based encoders take only word ids; others also take feats.
        if args.encoding_model.startswith("ernie"):
            words, connections, deprel = inputs
            connection_prob, deprel_prob, words = model(words)
        else:
            words, feats, connections, deprel = inputs
            connection_prob, deprel_prob, words = model(words, feats)
        # Score only real tokens: exclude padding and the BOS/EOS markers.
        mask = layers.logical_and(
            layers.logical_and(words != pad_index, words != bos_index),
            words != eos_index,
        )
        loss = loss_function(connection_prob, deprel_prob, connections, deprel,
                             mask)
        connection_predict, deprel_predict = decode(args, connection_prob,
                                                    deprel_prob, mask)
        # ignore all punctuation if not specified
        if not args.punct:
            # Keep a position only when its word id differs from every
            # punctuation id (reduce_all over the broadcast comparison).
            punct_mask = layers.reduce_all(
                layers.expand(layers.unsqueeze(words, -1),
                              (1, 1, punctuation.shape[0])) !=
                layers.expand(layers.reshape(punctuation,
                                             (1, 1, -1)), words.shape + [1]),
                dim=-1)

            mask = layers.logical_and(mask, punct_mask)

        metric(connection_predict, deprel_predict, connections, deprel, mask)
        total_loss += loss.numpy().item()

    total_loss /= len(loader)

    return total_loss, metric
Esempio n. 3
0
def train(env):
    """Train the parser: load corpora, build the model, run the epoch loop.

    Evaluates on dev after every epoch, checkpoints the best dev score
    (after a short warm-up of ``patience // 10`` epochs), and stops early
    once ``args.patience`` epochs pass without improvement. When running
    data-parallel, only rank 0 evaluates, saves and logs results.
    """
    args = env.args

    logging.info("loading data.")
    train = Corpus.load(args.train_data_path, env.fields)
    dev = Corpus.load(args.valid_data_path, env.fields)
    test = Corpus.load(args.test_data_path, env.fields)
    logging.info("init dataset.")
    train = TextDataset(train, env.fields, args.buckets)
    dev = TextDataset(dev, env.fields, args.buckets)
    test = TextDataset(test, env.fields, args.buckets)
    logging.info("set the data loaders.")
    train.loader = batchify(train, args.batch_size, args.use_data_parallel,
                            True)
    dev.loader = batchify(dev, args.batch_size)
    test.loader = batchify(test, args.batch_size)

    logging.info(f"{'train:':6} {len(train):5} sentences, "
                 f"{len(train.loader):3} batches, "
                 f"{len(train.buckets)} buckets")
    # Fixed: the dev/test lines previously reported len(train.buckets).
    logging.info(f"{'dev:':6} {len(dev):5} sentences, "
                 f"{len(dev.loader):3} batches, "
                 f"{len(dev.buckets)} buckets")
    logging.info(f"{'test:':6} {len(test):5} sentences, "
                 f"{len(test.loader):3} batches, "
                 f"{len(test.buckets)} buckets")

    logging.info("Create the model")
    model = Model(args, env.WORD.embed)

    # init parallel strategy
    if args.use_data_parallel:
        strategy = dygraph.parallel.prepare_context()
        model = dygraph.parallel.DataParallel(model, strategy)

    # Per-parameter norm clipping on GPU, global-norm clipping otherwise.
    if args.use_cuda:
        grad_clip = fluid.clip.GradientClipByNorm(clip_norm=args.clip)
    else:
        grad_clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=args.clip)
    decay = dygraph.ExponentialDecay(learning_rate=args.lr,
                                     decay_steps=args.decay_steps,
                                     decay_rate=args.decay)
    optimizer = fluid.optimizer.AdamOptimizer(
        learning_rate=decay,
        beta1=args.mu,
        beta2=args.nu,
        epsilon=args.epsilon,
        parameter_list=model.parameters(),
        grad_clip=grad_clip)

    total_time = datetime.timedelta()
    best_e, best_metric = 1, Metric()

    puncts = dygraph.to_variable(env.puncts, zero_copy=False)
    logging.info("start training.")
    for epoch in range(1, args.epochs + 1):
        start = datetime.datetime.now()
        # train one epoch and update the parameter
        logging.info(f"Epoch {epoch} / {args.epochs}:")
        epoch_train(args, model, optimizer, train.loader, epoch)
        if args.local_rank == 0:
            loss, dev_metric = epoch_evaluate(args, model, dev.loader, puncts)
            logging.info(f"{'dev:':6} Loss: {loss:.4f} {dev_metric}")
            loss, test_metric = epoch_evaluate(args, model, test.loader,
                                               puncts)
            logging.info(f"{'test:':6} Loss: {loss:.4f} {test_metric}")

            t = datetime.datetime.now() - start
            # save the model if it is the best so far
            if dev_metric > best_metric and epoch > args.patience // 10:
                best_e, best_metric = epoch, dev_metric
                save(args.model_path, args, model, optimizer)
                logging.info(f"{t}s elapsed (saved)\n")
            else:
                logging.info(f"{t}s elapsed\n")
            total_time += t
            if epoch - best_e >= args.patience:
                break
    if args.local_rank == 0:
        # Reload the best checkpoint and report final test performance.
        model = load(args.model_path, model)
        loss, metric = epoch_evaluate(args, model, test.loader, puncts)
        logging.info(
            f"max score of dev is {best_metric.score:.2%} at epoch {best_e}")
        logging.info(
            f"the score of test at epoch {best_e} is {metric.score:.2%}")
        logging.info(f"average time of each epoch is {total_time / epoch}s")
        logging.info(f"{total_time}s elapsed")
Esempio n. 4
0
def train(env):
    """Train the parser end-to-end.

    Loads train/dev/test corpora, builds the model, configures the
    learning-rate schedule, gradient clipping and optimizer according to
    the chosen encoder (ERNIE vs. LSTM/transformer), then runs the epoch
    loop with best-on-dev checkpointing and early stopping.
    """
    args = env.args

    logging.info("loading data.")
    train = Corpus.load(args.train_data_path, env.fields)
    dev = Corpus.load(args.valid_data_path, env.fields)
    test = Corpus.load(args.test_data_path, env.fields)
    logging.info("init dataset.")
    train = TextDataset(train, env.fields, args.buckets)
    dev = TextDataset(dev, env.fields, args.buckets)
    test = TextDataset(test, env.fields, args.buckets)
    logging.info("set the data loaders.")
    train.loader = batchify(train, args.batch_size, args.use_data_parallel, True)
    dev.loader = batchify(dev, args.batch_size)
    test.loader = batchify(test, args.batch_size)

    logging.info("{:6} {:5} sentences, ".format('train:', len(train)) + "{:3} batches, ".format(len(train.loader)) +
                 "{} buckets".format(len(train.buckets)))
    logging.info("{:6} {:5} sentences, ".format('dev:', len(dev)) + "{:3} batches, ".format(len(dev.loader)) +
                 "{} buckets".format(len(dev.buckets)))
    logging.info("{:6} {:5} sentences, ".format('test:', len(test)) + "{:3} batches, ".format(len(test.loader)) +
                 "{} buckets".format(len(test.buckets)))

    logging.info("Create the model")
    model = Model(args)

    # init parallel strategy
    if args.use_data_parallel:
        dist.init_parallel_env()
        model = paddle.DataParallel(model)

    # Pick the base learning rate by encoder family. NOTE(review): `and`
    # binds tighter than `or`, so this reads as
    # (ernie-* and not ernie-lstm) or transformer — presumably intended,
    # since ernie-lstm below gets the LSTM optimizer setup; confirm.
    if args.encoding_model.startswith(
            "ernie") and args.encoding_model != "ernie-lstm" or args.encoding_model == 'transformer':
        args['lr'] = args.ernie_lr
    else:
        args['lr'] = args.lstm_lr

    # Pure ERNIE encoders use linear warmup/decay; everything else
    # (including ernie-lstm) uses exponential decay.
    if args.encoding_model.startswith("ernie") and args.encoding_model != "ernie-lstm":
        max_steps = 100 * len(train.loader)
        decay = LinearDecay(args.lr, int(args.warmup_proportion * max_steps), max_steps)
        clip = args.ernie_clip
    else:
        decay = dygraph.ExponentialDecay(learning_rate=args.lr, decay_steps=args.decay_steps, decay_rate=args.decay)
        clip = args.clip

    # Per-parameter norm clipping on GPU, global-norm clipping otherwise.
    if args.use_cuda:
        grad_clip = fluid.clip.GradientClipByNorm(clip_norm=clip)
    else:
        grad_clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=clip)

    # AdamW (decoupled weight decay) for pure ERNIE encoders; plain Adam
    # with configured betas/epsilon for the rest.
    if args.encoding_model.startswith("ernie") and args.encoding_model != "ernie-lstm":
        optimizer = AdamW(
            learning_rate=decay,
            parameter_list=model.parameters(),
            weight_decay=args.weight_decay,
            grad_clip=grad_clip,
        )
    else:
        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=decay,
            beta1=args.mu,
            beta2=args.nu,
            epsilon=args.epsilon,
            parameter_list=model.parameters(),
            grad_clip=grad_clip,
        )

    total_time = datetime.timedelta()
    best_e, best_metric = 1, Metric()

    puncts = dygraph.to_variable(env.puncts, zero_copy=False)
    logging.info("start training.")

    for epoch in range(1, args.epochs + 1):
        start = datetime.datetime.now()
        # train one epoch and update the parameter
        logging.info("Epoch {} / {}:".format(epoch, args.epochs))
        epoch_train(args, model, optimizer, train.loader, epoch)
        # Only rank 0 evaluates, saves and logs when data-parallel.
        if args.local_rank == 0:
            loss, dev_metric = epoch_evaluate(args, model, dev.loader, puncts)
            logging.info("{:6} Loss: {:.4f} {}".format('dev:', loss, dev_metric))
            loss, test_metric = epoch_evaluate(args, model, test.loader, puncts)
            logging.info("{:6} Loss: {:.4f} {}".format('test:', loss, test_metric))

            t = datetime.datetime.now() - start
            # save the model if it is the best so far
            # (skips the first patience // 10 warm-up epochs)
            if dev_metric > best_metric and epoch > args.patience // 10:
                best_e, best_metric = epoch, dev_metric
                save(args.model_path, args, model, optimizer)
                logging.info("{}s elapsed (saved)\n".format(t))
            else:
                logging.info("{}s elapsed\n".format(t))
            total_time += t
            # Early stop once no dev improvement for `patience` epochs.
            if epoch - best_e >= args.patience:
                break
    if args.local_rank == 0:
        # Reload the best checkpoint and report final test performance.
        model = load(args.model_path, model)
        loss, metric = epoch_evaluate(args, model, test.loader, puncts)
        logging.info("max score of dev is {:.2%} at epoch {}".format(best_metric.score, best_e))
        logging.info("the score of test at epoch {} is {:.2%}".format(best_e, metric.score))
        logging.info("average time of each epoch is {}s".format(total_time / epoch))
        logging.info("{}s elapsed".format(total_time))