Example #1
def train(model, train_dataset, dev_dataset, dev_examples, dev_features, tokenizer, args):
    ctx = D.parallel.prepare_context()
    model = D.parallel.DataParallel(model, ctx)

    max_steps = len(train_features) * args.epoch // args.bsz  # `train_features` is resolved from the enclosing scope
    opt = AdamW(learning_rate=args.lr, parameter_list=model.parameters(), weight_decay=args.wd)
    g_clip = F.clip.GradientClipByGlobalNorm(1.0) #experimental

    train_dataset = train_dataset \
            .repeat() \
            .shard(D.parallel.Env().nranks, D.parallel.Env().dev_id) \
            .shuffle(1000) \
            .padded_batch(args.bsz) 

    log.debug('init training with args: %s' % repr(args))
    for step, (_, token_ids, token_type_ids, start_pos, end_pos) in enumerate(train_dataset.start(place)):
        loss, _, __ = model(token_ids, token_type_ids, start_pos=start_pos, end_pos=end_pos)
        scaled_loss = model.scale_loss(loss)
        scaled_loss.backward()
        model.apply_collective_grads()
        opt.minimize(scaled_loss, grad_clip=g_clip)
        model.clear_gradients()
        if D.parallel.Env().dev_id == 0 and step % 10 == 0:
            log.debug('[step %d] train loss %.5f lr %.3e' % (step, loss.numpy(), opt.current_step_lr()))
        if D.parallel.Env().dev_id == 0 and step % 100 == 0:
            f1, em = evaluate(model, dev_dataset, dev_examples, dev_features, tokenizer, args)
            log.debug('[step %d] eval result: f1 %.5f em %.5f' % (step, f1, em))
        if step > max_steps:
            break
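
The `train()` above reads `train_features` and `place` from the enclosing scope and assumes it is already running under a dygraph guard on the rank-local device. A minimal driver sketch, assuming the MRC setup suggested by the `start_pos`/`end_pos` signature (the model class and surrounding names are assumptions, modeled on Example #7 below):

place = F.CUDAPlace(D.parallel.Env().dev_id)
with D.guard(place):
    # Hypothetical: an MRC head; the model class is not shown in this listing.
    model = ErnieModelForQuestionAnswering.from_pretrained(args.from_pretrained)
    train(model, train_dataset, dev_dataset, dev_examples, dev_features, tokenizer, args)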
Example #2
    def __init__(self,
                 args,
                 ernie,
                 tokenizer,
                 device,
                 hidden=256,
                 layer_n=1,
                 lr=2e-5,
                 gama=0.8,
                 betas=(0.9, 0.999),
                 weight_decay=0.01,
                 warmup_steps=10000,
                 g_clip=0.001):

        self.device = device
        self.tokenizer = tokenizer
        self.model = SoftMaskedErnie(ernie, self.tokenizer, hidden, layer_n,
                                     self.device).to(self.device)

        # Note: the optimizer is configured entirely from `args`, so the
        # constructor defaults `lr`, `betas`, and `weight_decay` above go unused here.
        opt = AdamW(learning_rate=LinearDecay(
            args.lr, int(args.warmup_proportion * args.max_steps),
            args.max_steps),
                    parameter_list=self.model.parameters(),
                    weight_decay=args.wd,
                    grad_clip=g_clip)

        self.optim_schedule = ScheduledOptim(opt,
                                             hidden,
                                             n_warmup_steps=warmup_steps)
        self.criterion_c = fluid.dygraph.NLLLoss()
        self.criterion_d = fluid.dygraph.BCELoss()

        self.gama = gama
        self.log_freq = 10
Example #3
            dev_ds = feature_column.build_dataset('dev', data_dir=os.path.join(args.data_dir, 'dev'), shuffle=False, repeat=False, use_gz=False) \
                                           .map(map_fn) \
                                           .padded_batch(args.bsz)

            shapes = ([-1, args.max_seqlen], [-1, args.max_seqlen], [-1])
            types = ('int64', 'int64', 'int64')

            train_ds.data_shapes = shapes
            train_ds.data_types = types
            dev_ds.data_shapes = shapes
            dev_ds.data_types = types

            g_clip = F.clip.GradientClipByGlobalNorm(1.0)  #experimental
            opt = AdamW(learning_rate=LinearDecay(
                args.lr, int(args.warmup_proportion * args.max_steps),
                args.max_steps),
                        parameter_list=model.parameters(),
                        weight_decay=args.wd,
                        grad_clip=g_clip)

            for epoch in range(args.epoch):
                for step, d in enumerate(
                        tqdm(train_ds.start(place), desc='training')):
                    ids, sids, label = d
                    loss, _ = model(ids, sids, labels=label)
                    loss.backward()
                    if step % 10 == 0:
                        log.debug('train loss %.5f lr %.3e' %
                                  (loss.numpy(), opt.current_step_lr()))
                    opt.minimize(loss)
                    model.clear_gradients()
                    if step % 100 == 0:
Example #4
def seq2seq(model, tokenizer, args):
    log.info('Training starts with args: %r' % args)
    attn_id = tokenizer.vocab[args.attn_token]

    def gen_mask(batch_ids, mask_type='bidi', query_len=None, pad_value=0):
        if query_len is None:
            query_len = batch_ids.shape[1]
        if mask_type != 'empty':
            mask = (batch_ids != pad_value).astype(np.float32)
            mask = np.tile(np.expand_dims(mask, 1), [1, query_len, 1])
            if mask_type == 'causal':
                assert query_len == batch_ids.shape[1]
                mask = np.tril(mask)
            elif mask_type == 'causal_without_diag':
                assert query_len == batch_ids.shape[1]
                mask = np.tril(mask, -1)
            elif mask_type == 'diag':
                assert query_len == batch_ids.shape[1]
                mask = np.stack([np.diag(np.diag(m)) for m in mask], 0)
        else:
            mask = np.zeros_like(batch_ids).astype(np.float32)
            mask = np.tile(np.expand_dims(mask, 1), [1, query_len, 1])
        return mask
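
    # Illustration (not in the original): with batch_ids = np.array([[5, 6, 0]])
    # and mask_type='causal', the pad column is zeroed first (each row [1, 1, 0])
    # and np.tril then hides keys after the query position:
    # [[[1, 0, 0],
    #   [1, 1, 0],
    #   [1, 1, 0]]]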

    def make_some_noice(ids):
        if args.use_random_noice:
            noice_ids = np.random.randint(1,
                                          len(tokenizer.vocab),
                                          size=ids.shape)
        else:
            noice_ids = np.ones_like(ids) * tokenizer.vocab['[NOISE]']
        pos, = np.where(np.ones_like(ids))
        np.random.shuffle(pos)
        pos = pos[:int(args.noise_prob * len(pos))]
        ids[pos, ] = noice_ids[pos, ]
        return ids

    def map_fn(example_id, src_ids, tgt_ids):
        src_ids = src_ids[:args.max_encode_len]
        tgt_ids = tgt_ids[:args.max_decode_len]
        src_ids, src_sids = tokenizer.build_for_ernie(src_ids)
        src_pids = np.arange(len(src_ids))

        tgt_ids, tgt_sids = tokenizer.build_for_ernie(tgt_ids)
        tgt_pids = np.arange(len(tgt_ids)) + len(src_ids)  # positions continue after the source segment
        tgt_sids = np.ones_like(tgt_sids) * args.tgt_type_id

        attn_ids = np.ones_like(tgt_ids) * attn_id
        if args.noise_prob > 0.:
            tgt_labels = deepcopy(tgt_ids)
            tgt_ids = make_some_noice(tgt_ids)  #corrupted
        else:
            tgt_labels = tgt_ids

        return (example_id, src_ids, src_pids, src_sids, tgt_ids, tgt_pids,
                tgt_sids, attn_ids, tgt_labels)

    def after_padding(example_id, src_ids, src_pids, src_sids, tgt_ids,
                      tgt_pids, tgt_sids, attn_ids, tgt_labels):
        '''
        attention mask:
        ***  src,  tgt, attn
        src  00,   01,   02
        tgt  10,   11,   12
        attn 20,   21,   22

        ***   s1, s2 | t1 t2 t3| attn1 attn2 attn3
        s1    1,  1  | 0, 0, 0,| 0,    0,    0,
        s2    1,  1  | 0, 0, 0,| 0,    0,    0,
        -
        t1    1,  1, | 1, 0, 0,| 0,    0,    0,
        t2    1,  1, | 1, 1, 0,| 0,    0,    0,
        t3    1,  1, | 1, 1, 1,| 0,    0,    0,
        -
        attn1 1,  1, | 0, 0, 0,| 1,    0,    0,
        attn2 1,  1, | 1, 0, 0,| 0,    1,    0,
        attn3 1,  1, | 1, 1, 0,| 0,    0,    1,

        for details, see Fig3. https://arxiv.org/abs/2001.11314
        '''

        src_len = src_ids.shape[1]
        tgt_len = tgt_ids.shape[1]
        mask_00 = gen_mask(src_ids, 'bidi', query_len=src_len)
        mask_01 = gen_mask(tgt_ids, 'empty', query_len=src_len)
        mask_02 = gen_mask(attn_ids, 'empty', query_len=src_len)

        mask_10 = gen_mask(src_ids, 'bidi', query_len=tgt_len)
        mask_11 = gen_mask(tgt_ids, 'causal', query_len=tgt_len)
        mask_12 = gen_mask(attn_ids, 'empty', query_len=tgt_len)

        mask_20 = gen_mask(src_ids, 'bidi', query_len=tgt_len)
        mask_21 = gen_mask(tgt_ids, 'causal_without_diag', query_len=tgt_len)
        mask_22 = gen_mask(attn_ids, 'diag', query_len=tgt_len)
        '''
        Disabled reference variant kept for documentation: the full
        (src + tgt + attn) x (src + tgt + attn) mask assembled in one piece:
        mask = np.concatenate([
            np.concatenate([mask_00, mask_01, mask_02], 2),
            np.concatenate([mask_10, mask_11, mask_12], 2),
            np.concatenate([mask_20, mask_21, mask_22], 2),
        ], 1)

        ids = np.concatenate([src_ids, tgt_ids, attn_ids], 1)
        pids = np.concatenate([src_pids, tgt_pids, tgt_pids], 1)
        sids = np.concatenate([src_sids, tgt_sids, tgt_sids], 1)

        '''

        mask_src_2_src = mask_00
        mask_tgt_2_srctgt = np.concatenate([mask_10, mask_11], 2)
        mask_attn_2_srctgtattn = np.concatenate([mask_20, mask_21, mask_22], 2)

        tgt_labels = tgt_labels[np.where(tgt_labels != 0)]
        return (example_id, src_ids, src_sids, src_pids, tgt_ids, tgt_sids,
                tgt_pids, attn_ids, mask_src_2_src, mask_tgt_2_srctgt,
                mask_attn_2_srctgtattn, tgt_labels)
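
    # Shape note (illustration, not in the original): with a padded batch of
    # source length S and target length T, mask_src_2_src is [B, S, S],
    # mask_tgt_2_srctgt is [B, T, S+T], and mask_attn_2_srctgtattn is
    # [B, T, S+2T], matching the three query row-blocks in the docstring above.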

    bytes_vocab = {k.encode('utf8'): v for k, v in tokenizer.vocab.items()}
    feature_column = propeller.data.FeatureColumns([
        propeller.data.LabelColumn('id'),
        propeller.data.TextColumn('src',
                                  unk_id=tokenizer.unk_id,
                                  vocab_dict=bytes_vocab),
        propeller.data.TextColumn('tgt',
                                  unk_id=tokenizer.unk_id,
                                  vocab_dict=bytes_vocab),
    ])

    train_ds = feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'), shuffle=False, repeat=True, use_gz=False) \
                                   .map(map_fn)

    dev_ds = feature_column.build_dataset('dev', data_dir=os.path.join(args.data_dir, 'dev'), shuffle=False, repeat=False, use_gz=False) \
                                   .map(map_fn) \
                                   .padded_batch(args.eval_bsz) \
                                   .map(after_padding)

    log.debug('shard %d of %d' %
              (D.parallel.Env().dev_id, D.parallel.Env().nranks))
    train_ds = train_ds.shard(
        D.parallel.Env().nranks,
        D.parallel.Env().dev_id).shuffle(10000).padded_batch(
            args.bsz).map(after_padding)
    dev_ds = dev_ds.shard(D.parallel.Env().nranks, D.parallel.Env().dev_id)

    shapes = [[None, None]] * 7 + [[None, None, None]] * 3 + [[None]]
    types = ['int64'] * 11

    train_ds.data_shapes = shapes
    train_ds.data_types = types
    dev_ds.data_shapes = shapes
    dev_ds.data_types = types

    vocab_size, _ = model.word_emb.weight.shape
    ctx = D.parallel.prepare_context()
    model = D.parallel.DataParallel(model, ctx)
    g_clip = F.clip.GradientClipByGlobalNorm(1.0)
    opt = AdamW(learning_rate=LinearDecay(
        args.lr, int(args.warmup_proportion * args.max_steps), args.max_steps),
                parameter_list=model.parameters(),
                weight_decay=args.wd,
                grad_clip=g_clip)
    attn_id = tokenizer.vocab[args.attn_token]
    for step, data in enumerate(train_ds.start(place)):
        (example_id, src_ids, src_sids, src_pids, tgt_ids, tgt_sids, tgt_pids,
         attn_ids, mask_src_2_src, mask_tgt_2_srctgt, mask_attn_2_srctgtattn,
         tgt_labels) = data

        _, __, info = model(src_ids,
                            sent_ids=src_sids,
                            pos_ids=src_pids,
                            attn_bias=mask_src_2_src,
                            encode_only=True)
        cached_k, cached_v = info['caches']
        _, __, info = model(tgt_ids,
                            sent_ids=tgt_sids,
                            pos_ids=tgt_pids,
                            attn_bias=mask_tgt_2_srctgt,
                            past_cache=(cached_k, cached_v),
                            encode_only=True)
        cached_k2, cached_v2 = info['caches']
        past_cache_k = [
            L.concat([k, k2], 1) for k, k2 in zip(cached_k, cached_k2)
        ]
        past_cache_v = [
            L.concat([v, v2], 1) for v, v2 in zip(cached_v, cached_v2)
        ]
        if args.label_smooth > 0.:
            tgt_labels = L.label_smooth(F.one_hot(tgt_labels, vocab_size),
                                        epsilon=args.label_smooth)
        loss, _, __ = model(attn_ids,
                            sent_ids=tgt_sids,
                            pos_ids=tgt_pids,
                            attn_bias=mask_attn_2_srctgtattn,
                            past_cache=(past_cache_k, past_cache_v),
                            tgt_labels=tgt_labels,
                            tgt_pos=L.where(attn_ids == attn_id))

        scaled_loss = model.scale_loss(loss)
        scaled_loss.backward()
        model.apply_collective_grads()
        opt.minimize(scaled_loss)
        model.clear_gradients()
        if step % 10 == 0:
            loss = loss.numpy()
            ppl = np.exp(loss)
            log.debug('[step %d]train loss %.5f, ppl %.5f, lr %.3e' %
                      (step, loss, ppl, opt.current_step_lr()))
        if args.save_dir is not None and step % 1000 == 0 and D.parallel.Env().dev_id == 0:
            F.save_dygraph(model.state_dict(), args.save_dir)
        if args.predict_output_dir is not None and step > args.skip_eval_steps and step % args.eval_steps == 0:
            assert os.path.exists(
                args.predict_output_dir
            ), 'predict_output_dir not found: %s' % args.predict_output_dir
            log.debug('doing predict on gpu %d...' % D.parallel.Env().dev_id)
            evaluate(model, dev_ds, step, args)
        if step > args.max_steps:
            break
    evaluate(model, dev_ds, step, args)

    if args.save_dir is not None:
        F.save_dygraph(model.state_dict(), args.save_dir)
Example #5
        for step, (ids_student, ids, _, labels) in enumerate(dataset.start()):
            _, logits = model(ids)
            pred = L.argmax(logits, -1)
            all_pred.extend(pred.numpy())
            all_label.extend(labels.numpy())
        f1 = f1_score(all_label, all_pred, average='macro')
        model.train()
        return f1


teacher_model = ErnieModelForSequenceClassification.from_pretrained(
    'ernie-1.0', num_labels=2)
teacher_model.train()
if not os.path.exists('./teacher_model.pdparams'):
    opt = AdamW(learning_rate=LinearDecay(LR, 9600 * EPOCH * 0.1 / BATCH,
                                          9600 * EPOCH / BATCH),
                parameter_list=teacher_model.parameters(),
                weight_decay=0.01)
    g_clip = F.dygraph_grad_clip.GradClipByGlobalNorm(1.0)
    for epoch in range(EPOCH):
        for step, (ids_student, ids, sids,
                   labels) in enumerate(train_ds.start(place)):
            loss, logits = teacher_model(ids, labels=labels)
            loss.backward()
            if step % 10 == 0:
                print('[step %03d] teacher train loss %.5f lr %.3e' %
                      (step, loss.numpy(), opt.current_step_lr()))
            opt.minimize(loss, grad_clip=g_clip)
            teacher_model.clear_gradients()
            if step % 100 == 0:
                f1 = evaluate_teacher(teacher_model, dev_ds)
                print('teacher f1: %.5f' % f1)
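
This example opens mid-function (the loop at the top is evidently the tail of `evaluate_teacher`) and reads module-level names (`train_ds`, `dev_ds`, `place`, `LR`, `EPOCH`, `BATCH`) defined elsewhere in the demo it comes from; the 9600 in the `LinearDecay` arithmetic is evidently the number of training examples per epoch. A minimal sketch of the constants it assumes (values are illustrative only):

# Hypothetical module-level constants assumed by the snippet above;
# the real demo defines its own values.
EPOCH = 3
LR = 5e-5
BATCH = 32
place = F.CUDAPlace(0)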
Example #6
    def finetune(
            self,
            train_path,
            dev_path=None,
            save_dir="ernie_gen_result",
            init_ckpt_path=None,
            use_gpu=True,
            max_steps=500,
            batch_size=8,
            max_encode_len=50,
            max_decode_len=50,
            learning_rate=5e-5,
            warmup_proportion=0.1,
            weight_decay=0.1,
            noise_prob=0,
            label_smooth=0,
            beam_width=5,
            length_penalty=1.0,
            log_interval=100,
            save_interval=200,
    ):
        """
        finetune with the specified dataset.

        Args:
            train_path(str): the train dataset path.
            dev_path(str): the dev dataset path.
            save_dir(str): the model params and dev dataset predict result save path.
            init_ckpt_path(str): incremental training load path.
            use_gpu(bool): use gpu or not.
            max_steps(int): max training steps.
            batch_size(int): the batch size.
            max_encode_len(int): the max encode length.
            max_decode_len(int): the max decode length.
            learning_rate(float): the learning rate.
            warmup_proportion(float): the warmup proportion.
            weight_decay(float): the weight decay magnitude.
            noise_prob(float): the noise probability. See the ERNIE-GEN paper for details.
            label_smooth(float): the label smooth magnitude.
            beam_width(int): the beam size during evaluating the dev dataset.
            length_penalty(float): the length penalty during evaluating the dev dataset.
            log_interval(int): the log interval.
            save_interval(int): the save interval. dev set will be evaluated after saving.

        Return:
            result(dict): A dictionary of the form::
                {
                    last_save_path(str): last model save path.
                    last_ppl(float): last model ppl.
                }
        """
        self.max_encode_len = max_encode_len
        self.max_decode_len = max_decode_len
        self.noise_prob = noise_prob

        place = F.CUDAPlace(0) if use_gpu else F.CPUPlace()

        with F.dygraph.guard(place):
            if init_ckpt_path is not None:
                logger.info('loading checkpoint from %s' % init_ckpt_path)
                sd, _ = D.load_dygraph(init_ckpt_path)
                self.model.set_dict(sd)

            feature_column = propeller.data.FeatureColumns([
                propeller.data.LabelColumn('id'),
                propeller.data.TextColumn(
                    'src',
                    unk_id=self.tokenizer.unk_id,
                    vocab_dict=self.tokenizer.vocab,
                    tokenizer=self.tokenizer.tokenize),
                propeller.data.TextColumn(
                    'tgt',
                    unk_id=self.tokenizer.unk_id,
                    vocab_dict=self.tokenizer.vocab,
                    tokenizer=self.tokenizer.tokenize),
            ])

            train_ds = feature_column.build_dataset('train', data_file=train_path, shuffle=False,
                                                    repeat=True, use_gz=False)\
                .map(self._map_fn).shuffle(10000).padded_batch(batch_size).map(self._after_padding)
            train_ds.data_shapes = [[None, None]] * 7 + [[None, None, None]] * 3 + [[None]]
            train_ds.data_types = ['int64'] * 11

            if dev_path:
                dev_ds = feature_column.build_dataset('dev', data_file=dev_path, shuffle=False,
                                                    repeat=False, use_gz=False) \
                    .map(self._map_fn) \
                    .padded_batch(1) \
                    .map(self._after_padding)
                dev_ds.data_shapes = [[None, None]] * 7 + [[None, None, None]] * 3 + [[None]]
                dev_ds.data_types = ['int64'] * 11

            vocab_size, _ = self.model.word_emb.weight.shape
            g_clip = F.clip.GradientClipByGlobalNorm(1.0)
            opt = AdamW(
                learning_rate=LinearDecay(learning_rate,
                                          int(warmup_proportion * max_steps),
                                          max_steps),
                parameter_list=self.model.parameters(),
                weight_decay=weight_decay,
                grad_clip=g_clip)

            loss = None

            save_path = None
            ppl = None

            if save_dir and not os.path.exists(save_dir):
                os.makedirs(save_dir)
            for step, data in enumerate(train_ds.start(place)):
                (example_id, src_ids, src_sids, src_pids, tgt_ids, tgt_sids,
                 tgt_pids, attn_ids, mask_src_2_src, mask_tgt_2_srctgt,
                 mask_attn_2_srctgtattn, tgt_labels) = data

                _, __, info = self.model(
                    src_ids,
                    sent_ids=src_sids,
                    pos_ids=src_pids,
                    attn_bias=mask_src_2_src,
                    encode_only=True)
                cached_k, cached_v = info['caches']
                _, __, info = self.model(
                    tgt_ids,
                    sent_ids=tgt_sids,
                    pos_ids=tgt_pids,
                    attn_bias=mask_tgt_2_srctgt,
                    past_cache=(cached_k, cached_v),
                    encode_only=True)
                cached_k2, cached_v2 = info['caches']
                past_cache_k = [
                    L.concat([k, k2], 1) for k, k2 in zip(cached_k, cached_k2)
                ]
                past_cache_v = [
                    L.concat([v, v2], 1) for v, v2 in zip(cached_v, cached_v2)
                ]
                if label_smooth > 0.:
                    tgt_labels = L.label_smooth(
                        F.one_hot(tgt_labels, vocab_size), epsilon=label_smooth)
                loss, _, __ = self.model(
                    attn_ids,
                    sent_ids=tgt_sids,
                    pos_ids=tgt_pids,
                    attn_bias=mask_attn_2_srctgtattn,
                    past_cache=(past_cache_k, past_cache_v),
                    tgt_labels=tgt_labels,
                    tgt_pos=L.where(attn_ids == self.tokenizer.vocab['[MASK]']))

                loss.backward()
                opt.minimize(loss)
                self.model.clear_gradients()

                if step % log_interval == 0:
                    loss_np = loss.numpy()
                    ppl = np.exp(loss_np)
                    logger.info(
                        '[step %d / %d]train loss %.5f, ppl %.5f, elr %.3e' %
                        (step, max_steps, loss_np, ppl, opt.current_step_lr()))
                if save_dir and step % save_interval == 0 and step > 0:
                    loss_np = loss.numpy()
                    ppl = np.exp(loss_np)
                    save_name = "step_%s_ppl_%.5f" % (step, ppl)
                    save_path = os.path.join(save_dir, save_name)
                    logger.info("save the model in %s" % save_path)
                    F.save_dygraph(self.model.state_dict(), save_path)

                    if dev_path:
                        logger.info('evaluating...')
                        res = self._evaluate(dev_ds, place, beam_width,
                                             length_penalty)
                        output_path = os.path.join(
                            save_dir, "step_%s_ppl_%.5f.txt" % (step, ppl))
                        logger.info(
                            'save the predict result in %s' % output_path)
                        with open(output_path, 'w') as fout:
                            fout.write(('\n'.join(res)))

                if step > max_steps:
                    break

            if loss is not None:
                loss_np = loss.numpy()
                ppl = np.exp(loss_np)
                logger.info('[final step %d]train loss %.5f, ppl %.5f, elr %.3e'
                            % (step, loss_np, ppl, opt.current_step_lr()))
                if save_dir:
                    save_name = "step_%s_ppl_%.5f" % (step, ppl)
                    save_path = os.path.join(save_dir, save_name)
                    logger.info("save the model in %s" % save_path)
                    F.save_dygraph(self.model.state_dict(), save_path)

                    if dev_path:
                        logger.info('evaluating...')
                        res = self._evaluate(dev_ds, place, beam_width,
                                             length_penalty)
                        output_path = os.path.join(
                            save_dir, "step_%s_ppl_%.5f.txt" % (step, ppl))
                        logger.info(
                            'save the predict result in %s' % output_path)
                        with open(output_path, 'w') as fout:
                            fout.write(('\n'.join(res)))

            result = {
                "last_save_path": "%s.pdparams" % save_path,
                "last_ppl": ppl[0],
            }

            return result
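
Per the docstring above, a call might look like the following sketch; `module` stands for whatever object exposes this `finetune` method (in PaddleHub-style usage, a loaded ernie_gen module), and the file paths are placeholders:

result = module.finetune(
    train_path='train.txt',
    dev_path='dev.txt',
    save_dir='ernie_gen_result',
    max_steps=500,
    batch_size=8)
print(result['last_save_path'], result['last_ppl'])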
Example #7
    seq_shape = [-1, args.max_seqlen]
    ints_shape = [-1]
    shapes = (seq_shape, seq_shape, ints_shape, [-1, 2], ints_shape)
    types = ('int64', 'int64', 'int64', 'int64', 'int64')

    train_ds.data_shapes = shapes
    train_ds.data_types = types

    place = F.CUDAPlace(D.parallel.Env().dev_id)
    with D.guard(place):
        model = ErnieModelForPretraining.from_pretrained(args.from_pretrained)
        opt = AdamW(learning_rate=LinearDecay(args.lr, args.warmup_steps,
                                              args.max_steps),
                    parameter_list=model.parameters(),
                    weight_decay=0.01)

        ctx = D.parallel.prepare_context()
        model = D.parallel.DataParallel(model, ctx)

        for step, samples in enumerate(tqdm(train_ds.start(place))):
            (src_ids, sent_ids, mlm_label, mask_pos, nsp_label) = samples
            loss, mlmloss, nsploss = model(src_ids,
                                           sent_ids,
                                           labels=mlm_label,
                                           mlm_pos=mask_pos,
                                           nsp_labels=nsp_label)
            scaled_loss = model.scale_loss(loss)
            scaled_loss.backward()
            model.apply_collective_grads()
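
The listing truncates Example #7 after the gradient all-reduce; by the pattern of Example #1, the step would presumably finish as in this hedged sketch (not part of the original):

            opt.minimize(scaled_loss)
            model.clear_gradients()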
Example #8

    shapes = ([-1, args.max_seqlen], [-1, args.max_seqlen], [-1])
    types = ('int64', 'int64', 'int64')

    train_ds.data_shapes = shapes
    train_ds.data_types = types
    dev_ds.data_shapes = shapes
    dev_ds.data_types = types

    place = F.CUDAPlace(0)
    with FD.guard(place):
        model = ErnieModelForSequenceClassification.from_pretrained(args.from_pretrained, num_labels=3, name='')

        if args.use_lr_decay:
            opt = AdamW(
                learning_rate=LinearDecay(args.lr, int(args.warmup_proportion * args.max_steps), args.max_steps),
                parameter_list=model.parameters(),
                weight_decay=args.wd)
        else:
            opt = AdamW(args.lr, parameter_list=model.parameters(), weight_decay=args.wd)

        g_clip = F.dygraph_grad_clip.GradClipByGlobalNorm(1.0) #experimental
        for epoch in range(args.epoch):
            for step, d in enumerate(tqdm(train_ds.start(place), desc='training')):
                ids, sids, label = d
                loss, _ = model(ids, sids, labels=label)
                loss.backward()
                if step % 10 == 0:
                    log.debug('train loss %.5f lr %.3e' % (loss.numpy(), opt.current_step_lr()))
                opt.minimize(loss, grad_clip=g_clip)
                model.clear_gradients()
                if step % 100 == 0:
                    acc = []