Example #1
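The snippets below appear to be excerpted from PaddlePaddle ERNIE fine-tuning demos. They rely on aliases and helpers that the excerpts do not show (propeller, env, log, create_if_not_exists and get_warmup_and_linear_decay come from the ERNIE/propeller demo utilities); a minimal sketch of the imports they seem to assume, inferred from the names used:

# Assumed imports (an inference from the code below, not shown in the original excerpts).
import os
import re
from copy import deepcopy

import numpy as np
import paddle as P                  # used as P.optimizer, P.amp, P.io, P.nn, ...
import paddle.nn.functional as F    # used for F.one_hot and F.label_smooth
from visualdl import LogWriter      # presumably VisualDL's LogWriter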
                            weight_decay=args.wd,
                            # decay everything except LayerNorm scale/bias params
                            apply_decay_param_fun=lambda n:
                            not param_name_to_exclue_from_weight_decay.match(n),
                            grad_clip=g_clip)
else:
    lr_scheduler = None
    opt = P.optimizer.Adam(args.lr,
                           parameters=model.parameters(),
                           weight_decay=args.wd,
                           # decay everything except LayerNorm scale/bias params
                           apply_decay_param_fun=lambda n:
                           not param_name_to_exclue_from_weight_decay.match(n),
                           grad_clip=g_clip)

scaler = P.amp.GradScaler(enable=args.use_amp)
step, inter_step = 0, 0
with LogWriter(logdir=str(create_if_not_exists(args.save_dir /
                                               'vdl'))) as log_writer:
    with P.amp.auto_cast(enable=args.use_amp):
        for epoch in range(args.epoch):
            for ids, sids, label in P.io.DataLoader(train_ds,
                                                    places=P.CUDAPlace(0),
                                                    batch_size=None):
                inter_step += 1
                loss, _ = model(ids, sids, labels=label)
                loss /= acc_step
                loss = scaler.scale(loss)
                loss.backward()
                if inter_step % acc_step != 0:
                    continue
                step += 1
                scaler.minimize(opt, loss)
                model.clear_gradients()
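Example #1 folds gradient accumulation into the AMP loop: the loss is divided by acc_step, backward() runs on every micro-batch, and the optimizer only steps when inter_step % acc_step == 0. A minimal sketch of the same pattern without AMP (loader is a hypothetical name; model, opt and acc_step are as above), just to make the control flow explicit:

for micro_step, (ids, sids, label) in enumerate(loader, start=1):
    loss, _ = model(ids, sids, labels=label)
    (loss / acc_step).backward()      # gradients accumulate across micro-batches
    if micro_step % acc_step == 0:    # one optimizer update per acc_step micro-batches
        opt.step()
        opt.clear_grad()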
Example #2
def train(model, train_dataset, dev_dataset, dev_examples, dev_features,
          tokenizer, args):
    model = P.DataParallel(model)

    # NOTE: train_features is not a parameter of this function; it is presumably
    # defined at module scope in the original script.
    max_steps = len(train_features) * args.epoch // args.bsz

    g_clip = P.nn.ClipGradByGlobalNorm(1.0)  #experimental
    lr_scheduler = P.optimizer.lr.LambdaDecay(
        args.lr,
        get_warmup_and_linear_decay(max_steps,
                                    int(args.warmup_proportion * max_steps)))

    opt = P.optimizer.AdamW(lr_scheduler,
                            parameters=model.parameters(),
                            weight_decay=args.wd,
                            grad_clip=g_clip)

    train_dataset = train_dataset \
            .cache_shuffle_shard(env.nranks, env.dev_id, drop_last=True) \
            .padded_batch(args.bsz)

    log.debug('init training with args: %s' % repr(args))
    scaler = P.amp.GradScaler(enable=args.use_amp)
    create_if_not_exists(args.save_dir)

    with P.amp.auto_cast(enable=args.use_amp):
        for step, (_, token_ids, token_type_ids, start_pos,
                   end_pos) in enumerate(
                       P.io.DataLoader(train_dataset,
                                       places=P.CUDAPlace(env.dev_id),
                                       batch_size=None)):
            loss, _, __ = model(token_ids,
                                token_type_ids,
                                start_pos=start_pos,
                                end_pos=end_pos)
            loss = scaler.scale(loss)
            loss.backward()
            scaler.minimize(opt, loss)
            model.clear_gradients()
            lr_scheduler.step()

            if env.dev_id == 0 and step % 10 == 0:
                _lr = lr_scheduler.get_lr()
                if args.use_amp:
                    _l = (loss / scaler._scale).numpy()
                    msg = '[rank-%d][step-%d] train loss %.5f lr %.3e scaling %.3e' % (
                        env.dev_id, step, _l, _lr, scaler._scale.numpy())
                else:
                    _l = loss.numpy()
                    msg = '[rank-%d][step-%d] train loss %.5f lr %.3e' % (
                        env.dev_id, step, _l, _lr)
                log.debug(msg)

            if env.dev_id == 0 and step % 100 == 0:
                f1, em = evaluate(model, dev_dataset, dev_examples,
                                  dev_features, tokenizer, args)
                log.debug('[step %d] eval result: f1 %.5f em %.5f' %
                          (step, f1, em))
            if env.dev_id == 0 and args.save_dir is not None:
                P.save(model.state_dict(), args.save_dir / 'ckpt.bin')
            if step > max_steps:
                break
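The examples pass get_warmup_and_linear_decay(max_steps, warmup_steps) to P.optimizer.lr.LambdaDecay; the helper itself lives in the ERNIE demo utilities and is not shown in these excerpts. A minimal sketch of an equivalent multiplier function, assuming linear warmup followed by linear decay to zero:

def get_warmup_and_linear_decay(max_steps, warmup_steps):
    # LambdaDecay multiplies the base learning rate by this factor at every
    # lr_scheduler.step() call.
    def factor(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)  # linear warmup
        return max(0., (max_steps - step) / max(1, max_steps - warmup_steps))
    return factor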
Example #3
param_name_to_exclue_from_weight_decay = re.compile(
    r'.*layer_norm_scale|.*layer_norm_bias|.*b_0')
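# The pattern above matches LayerNorm scale/bias and bias ("b_0") parameter
# names, i.e. the parameters that should not receive weight decay.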

lr_scheduler = P.optimizer.lr.LambdaDecay(
    args.lr,
    get_warmup_and_linear_decay(args.max_steps,
                                int(args.warmup_proportion * args.max_steps)))
opt = P.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
    weight_decay=args.wd,
    grad_clip=g_clip)
scaler = P.amp.GradScaler(enable=args.use_amp)
step = 0
create_if_not_exists(args.save_dir)

#with LogWriter(logdir=str(create_if_not_exists(args.save_dir / 'vdl-%d' % env.dev_id))) as log_writer:
with P.amp.auto_cast(enable=args.use_amp):
    for ids, sids, label in P.io.DataLoader(
            train_ds, places=P.CUDAPlace(env.dev_id), batch_size=None):
        step += 1
        loss, _ = model(ids, sids, labels=label)
        loss = scaler.scale(loss)
        loss.backward()
        scaler.minimize(opt, loss)
        model.clear_gradients()
        lr_scheduler.step()

        # do logging
        if step % 10 == 0:
            _lr = lr_scheduler.get_lr()
            if args.use_amp:
                _l = (loss / scaler._scale).numpy()
                msg = '[rank-%d][step-%d] train loss %.5f lr %.3e scaling %.3e' % (
                    env.dev_id, step, _l, _lr, scaler._scale.numpy())
            else:
                _l = loss.numpy()
                msg = '[rank-%d][step-%d] train loss %.5f lr %.3e' % (
                    env.dev_id, step, _l, _lr)
            log.debug(msg)
Example #4
def seq2seq(model, tokenizer, args):
    log.info('Training starts with args: %r' % args)
    attn_id = tokenizer.vocab[args.attn_token]

    def gen_mask(batch_ids, mask_type='bidi', query_len=None, pad_value=0):
        if query_len is None:
            query_len = batch_ids.shape[1]
        if mask_type != 'empty':
            mask = (batch_ids != pad_value).astype(np.float32)
            mask = np.tile(np.expand_dims(mask, 1), [1, query_len, 1])
            if mask_type == 'causal':
                assert query_len == batch_ids.shape[1]
                mask = np.tril(mask)
            elif mask_type == 'causal_without_diag':
                assert query_len == batch_ids.shape[1]
                mask = np.tril(mask, -1)
            elif mask_type == 'diag':
                assert query_len == batch_ids.shape[1]
                mask = np.stack([np.diag(np.diag(m)) for m in mask], 0)
        else:  # mask_type == 'empty'
            mask = np.zeros_like(batch_ids).astype(np.float32)
            mask = np.tile(np.expand_dims(mask, 1), [1, query_len, 1])
        return mask
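    # Illustrative example: for batch_ids = np.array([[3, 5, 0]]) and
    # mask_type='causal', gen_mask returns a [1, 3, 3] lower-triangular mask in
    # which the padded third position (pad_value == 0) is zeroed in every row.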

    def make_some_noice(ids):
        if args.use_random_noice:
            noice_ids = np.random.randint(
                1, len(tokenizer.vocab), size=ids.shape)
        else:
            noice_ids = np.ones_like(ids) * tokenizer.vocab['[NOISE]']
        pos, = np.where(np.ones_like(ids))
        np.random.shuffle(pos)
        pos = pos[:int(args.noise_prob * len(pos))]
        ids[pos, ] = noice_ids[pos, ]
        return ids

    def map_fn(example_id, src_ids, tgt_ids):
        src_ids = src_ids[:args.max_encode_len]
        tgt_ids = tgt_ids[:args.max_decode_len]
        src_ids, src_sids = tokenizer.build_for_ernie(src_ids)
        src_pids = np.arange(len(src_ids))

        tgt_ids, tgt_sids = tokenizer.build_for_ernie(tgt_ids)
        tgt_pids = np.arange(len(tgt_ids)) + len(src_ids)  # positions continue after src
        tgt_sids = np.ones_like(tgt_sids) * args.tgt_type_id

        attn_ids = np.ones_like(tgt_ids) * attn_id
        if args.noise_prob > 0.:
            tgt_labels = deepcopy(tgt_ids)
            tgt_ids = make_some_noice(tgt_ids)  #corrupted
        else:
            tgt_labels = tgt_ids

        return (example_id, src_ids, src_pids, src_sids, tgt_ids, tgt_pids,
                tgt_sids, attn_ids, tgt_labels)

    def after_padding(example_id, src_ids, src_pids, src_sids, tgt_ids,
                      tgt_pids, tgt_sids, attn_ids, tgt_labels):
        '''
        attention mask:
        ***  src,  tgt, attn
        src  00,   01,   02
        tgt  10,   11,   12
        attn 20,   21,   22

        ***   s1, s2 | t1 t2 t3| attn1 attn2 attn3
        s1    1,  1  | 0, 0, 0,| 0,    0,    0,
        s2    1,  1  | 0, 0, 0,| 0,    0,    0,
        -
        t1    1,  1, | 1, 0, 0,| 0,    0,    0,
        t2    1,  1, | 1, 1, 0,| 0,    0,    0,
        t3    1,  1, | 1, 1, 1,| 0,    0,    0,
        -
        attn1 1,  1, | 0, 0, 0,| 1,    0,    0,
        attn2 1,  1, | 1, 0, 0,| 0,    1,    0,
        attn3 1,  1, | 1, 1, 0,| 0,    0,    1,

        for details, see Fig. 3 of https://arxiv.org/abs/2001.11314
        '''
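        # Shape sketch (illustrative, batch size B): src_ids is [B, src_len],
        # tgt_ids and attn_ids are [B, tgt_len]; each mask_ij below has shape
        # [B, query_len of block-row i, key_len of block-column j], e.g.
        # mask_11 is [B, tgt_len, tgt_len] and lower-triangular (causal).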

        src_len = src_ids.shape[1]
        tgt_len = tgt_ids.shape[1]
        mask_00 = gen_mask(src_ids, 'bidi', query_len=src_len)
        mask_01 = gen_mask(tgt_ids, 'empty', query_len=src_len)
        mask_02 = gen_mask(attn_ids, 'empty', query_len=src_len)

        mask_10 = gen_mask(src_ids, 'bidi', query_len=tgt_len)
        mask_11 = gen_mask(tgt_ids, 'causal', query_len=tgt_len)
        mask_12 = gen_mask(attn_ids, 'empty', query_len=tgt_len)

        mask_20 = gen_mask(src_ids, 'bidi', query_len=tgt_len)
        mask_21 = gen_mask(tgt_ids, 'causal_without_diag', query_len=tgt_len)
        mask_22 = gen_mask(attn_ids, 'diag', query_len=tgt_len)
        '''
        mask = np.concatenate([
            np.concatenate([mask_00, mask_01, mask_02], 2),
            np.concatenate([mask_10, mask_11, mask_12], 2),
            np.concatenate([mask_20, mask_21, mask_22], 2),
        ], 1)

        ids = np.concatenate([src_ids, tgt_ids, attn_ids], 1)
        pids = np.concatenate([src_pids, tgt_pids, tgt_pids], 1)
        sids = np.concatenate([src_sids, tgt_sids, tgt_sids], 1)

        '''

        mask_src_2_src = mask_00
        mask_tgt_2_srctgt = np.concatenate([mask_10, mask_11], 2)
        mask_attn_2_srctgtattn = np.concatenate([mask_20, mask_21, mask_22], 2)

        tgt_labels = tgt_labels[np.where(tgt_labels != 0)]
        return (example_id, src_ids, src_sids, src_pids, tgt_ids, tgt_sids,
                tgt_pids, attn_ids, mask_src_2_src, mask_tgt_2_srctgt,
                mask_attn_2_srctgtattn, tgt_labels)

    bytes_vocab = {k.encode('utf8'): v for k, v in tokenizer.vocab.items()}
    feature_column = propeller.data.FeatureColumns([
        propeller.data.LabelColumn('id'),
        propeller.data.TextColumn(
            'src', unk_id=tokenizer.unk_id, vocab_dict=bytes_vocab),
        propeller.data.TextColumn(
            'tgt', unk_id=tokenizer.unk_id, vocab_dict=bytes_vocab),
    ])

    train_ds = feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'), shuffle=True, repeat=True, use_gz=False) \
                                   .map(map_fn) \
                                   .padded_batch(args.bsz) \
                                   .map(after_padding)


    dev_ds = feature_column.build_dataset('dev', data_dir=os.path.join(args.data_dir, 'dev'), shuffle=False, repeat=False, use_gz=False) \
                                   .map(map_fn) \
                                   .padded_batch(args.eval_bsz) \
                                   .map(after_padding) \
                                   .shard(env.nranks, env.dev_id)

    vocab_size, _ = model.word_emb.weight.shape
    model = P.DataParallel(model)
    g_clip = P.nn.ClipGradByGlobalNorm(1.0)
    param_name_to_exclue_from_weight_decay = re.compile(
        r'.*layer_norm_scale|.*layer_norm_bias|.*b_0')
    lr_scheduler = P.optimizer.lr.LambdaDecay(
        args.lr,
        get_warmup_and_linear_decay(
            args.max_steps, int(args.warmup_proportion * args.max_steps)))

    opt = P.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        weight_decay=args.wd,
        # decay everything except LayerNorm scale/bias params
        apply_decay_param_fun=lambda n: not param_name_to_exclue_from_weight_decay.match(n),
        grad_clip=g_clip)

    scaler = P.amp.GradScaler(enable=args.use_amp)
    attn_id = tokenizer.vocab[args.attn_token]
    create_if_not_exists(args.save_dir)
    if args.predict_output_dir:
        create_if_not_exists(args.predict_output_dir)

    with P.amp.auto_cast(enable=args.use_amp):
        for step, data in enumerate(
                P.io.DataLoader(
                    train_ds, places=P.CUDAPlace(env.dev_id),
                    batch_size=None)):
            (example_id, src_ids, src_sids, src_pids, tgt_ids, tgt_sids,
             tgt_pids, attn_ids, mask_src_2_src, mask_tgt_2_srctgt,
             mask_attn_2_srctgtattn, tgt_labels) = data

            _, __, info = model(
                src_ids,
                sent_ids=src_sids,
                pos_ids=src_pids,
                attn_bias=mask_src_2_src,
                encode_only=True)
            cached_k, cached_v = info['caches']
            _, __, info = model(
                tgt_ids,
                sent_ids=tgt_sids,
                pos_ids=tgt_pids,
                attn_bias=mask_tgt_2_srctgt,
                past_cache=(cached_k, cached_v),
                encode_only=True)
            cached_k2, cached_v2 = info['caches']
            past_cache_k = [
                P.concat([k, k2], 1) for k, k2 in zip(cached_k, cached_k2)
            ]
            past_cache_v = [
                P.concat([v, v2], 1) for v, v2 in zip(cached_v, cached_v2)
            ]
            tgt_labels = F.one_hot(tgt_labels, vocab_size)
            if args.label_smooth > 0.:
                tgt_labels = F.label_smooth(
                    tgt_labels, epsilon=args.label_smooth)
            loss, _, __ = model(
                attn_ids,
                sent_ids=tgt_sids,
                pos_ids=tgt_pids,
                attn_bias=mask_attn_2_srctgtattn,
                past_cache=(past_cache_k, past_cache_v),
                tgt_labels=tgt_labels,
                tgt_pos=P.nonzero(attn_ids == attn_id))

            loss = scaler.scale(loss)
            loss.backward()
            scaler.minimize(opt, loss)
            model.clear_gradients()
            lr_scheduler.step()

            if step % 10 == 0:
                _lr = lr_scheduler.get_lr()
                if args.use_amp:
                    _l = (loss / scaler._scale).numpy()
                    msg = '[rank-%d][step-%d] train loss %.5f lr %.3e scaling %.3e' % (
                        env.dev_id, step, _l, _lr, scaler._scale.numpy())
                else:
                    _l = loss.numpy()
                    msg = '[rank-%d][step-%d] train loss %.5f lr %.3e' % (
                        env.dev_id, step, _l, _lr)
                log.debug(msg)

            if args.save_dir is not None and step % 1000 == 0 and env.dev_id == 0:
                P.save(model.state_dict(), args.save_dir / 'ckpt.bin')

            if args.predict_output_dir is not None and step > args.skip_eval_steps and step % args.eval_steps == 0:
                assert args.predict_output_dir.exists(), \
                    'predict_output_dir not found: %s' % args.predict_output_dir
                log.debug('doing predict on gpu %d...' % env.dev_id)
                evaluate(model, dev_ds, step, args)
            if step > args.max_steps:
                break
        evaluate(model, dev_ds, step, args)

    if args.save_dir is not None:
        P.save(model.state_dict(), args.save_dir / 'ckpt.bin')