def train(model, train_dataset, dev_dataset, dev_examples, dev_features,
          tokenizer, args):
    """Run the data-parallel training loop with periodic dev evaluation.

    The model is wrapped in ``D.parallel.DataParallel``. The training dataset
    is repeated, sharded across ``D.parallel.Env().nranks`` workers, shuffled
    with a buffer of 1000 and padded-batched to ``args.bsz``. On the worker
    whose ``dev_id`` is 0 the loss is logged every 10 steps and ``evaluate``
    is called every 100 steps (step 0 included). The loop stops once ``step``
    exceeds ``max_steps``.

    Args:
        model: network called as ``model(token_ids, token_type_ids,
            start_pos=..., end_pos=...)``; the first returned item is the loss.
        train_dataset: dataset exposing ``repeat``/``shard``/``shuffle``/
            ``padded_batch``/``start``.
        dev_dataset, dev_examples, dev_features, tokenizer: forwarded
            unchanged to ``evaluate``.
        args: configuration; reads ``epoch``, ``bsz``, ``lr`` and ``wd``.

    Note:
        ``train_features`` and ``place`` are not parameters; they are looked
        up in the enclosing scope at call time.
    """
    parallel_ctx = D.parallel.prepare_context()
    model = D.parallel.DataParallel(model, parallel_ctx)

    # `train_features` comes from the enclosing scope, not from the arguments.
    max_steps = len(train_features) * args.epoch // args.bsz
    opt = AdamW(
        learning_rate=args.lr,
        parameter_list=model.parameters(),
        weight_decay=args.wd)
    g_clip = F.clip.GradientClipByGlobalNorm(1.0)  # experimental

    # Same pipeline as a single chained expression, one stage per statement.
    train_dataset = train_dataset.repeat()
    train_dataset = train_dataset.shard(D.parallel.Env().nranks,
                                        D.parallel.Env().dev_id)
    train_dataset = train_dataset.shuffle(1000)
    train_dataset = train_dataset.padded_batch(args.bsz)

    log.debug('init training with args: %s' % repr(args))
    for step, batch in enumerate(train_dataset.start(place)):
        _, token_ids, token_type_ids, start_pos, end_pos = batch
        loss, _, __ = model(
            token_ids, token_type_ids, start_pos=start_pos, end_pos=end_pos)

        # Scale the loss, then all-reduce gradients before the update.
        scaled_loss = model.scale_loss(loss)
        scaled_loss.backward()
        model.apply_collective_grads()
        opt.minimize(scaled_loss, grad_clip=g_clip)
        model.clear_gradients()

        # Logging and evaluation happen only on the worker with dev_id 0.
        if D.parallel.Env().dev_id == 0:
            if step % 10 == 0:
                log.debug('[step %d] train loss %.5f lr %.3e' %
                          (step, loss.numpy(), opt.current_step_lr()))
            if step % 100 == 0:
                f1, em = evaluate(model, dev_dataset, dev_examples,
                                  dev_features, tokenizer, args)
                log.debug('[step %d] eval result: f1 %.5f em %.5f' %
                          (step, f1, em))

        if step > max_steps:
            break
# NOTE(review): whitespace-collapsed fragment. It opens in the middle of a call
# (presumably the tail of an `AdamW(learning_rate=LinearDecay(` construction --
# confirm against the full file) and is cut off right after `model.train()`.
#
# Visible behaviour, in statement order:
#   * for every epoch, iterate `train_ds.start(place)`; unpack each batch into
#     (ids, sids, label), run the model, and back-propagate the loss;
#   * every 10 steps log the loss and `opt.current_step_lr()`;
#   * apply `opt.minimize(loss)` and clear the gradients;
#   * every 100 steps switch the model to eval mode under
#     `FD.base._switch_tracer_mode_guard_(is_train=False)`, run the dev set,
#     collect the per-batch `argmax(logits, -1) == label` arrays in `acc`,
#     then switch the model back to train mode.
#
# NOTE(review): the evaluation loop rebinds `step`, `d` and `loss`, i.e. the
# same names the training loop uses.
args.lr, int(args.warmup_proportion * args.max_steps), args.max_steps), parameter_list=model.parameters(), weight_decay=args.wd, grad_clip=g_clip) for epoch in range(args.epoch): for step, d in enumerate( tqdm(train_ds.start(place), desc='training')): ids, sids, label = d loss, _ = model(ids, sids, labels=label) loss.backward() if step % 10 == 0: log.debug('train loss %.5f lr %.3e' % (loss.numpy(), opt.current_step_lr())) opt.minimize(loss) model.clear_gradients() if step % 100 == 0: acc = [] with FD.base._switch_tracer_mode_guard_( is_train=False): model.eval() for step, d in enumerate( tqdm(dev_ds.start(place), desc='evaluating %d' % epoch)): ids, sids, label = d loss, logits = model(ids, sids, labels=label) #print('\n'.join(map(str, logits.numpy().tolist()))) a = L.argmax(logits, -1) == label acc.append(a.numpy()) model.train()
def seq2seq(model, tokenizer, args):
    """Train `model` on (src, tgt) pairs using three attention streams.

    Every step performs three forward passes:

    1. encode the source tokens (bidirectional mask);
    2. encode the (optionally noised) target tokens on top of the source
       cache (causal mask);
    3. feed the attention-token stream on top of both caches and compute the
       loss against the target labels.

    The loss is logged every 10 steps, the weights are saved every 1000 steps
    on the worker with ``dev_id == 0``, the dev set is evaluated at the
    configured interval, and a final evaluation and save run after the loop.

    Args:
        model: network exposing ``word_emb.weight``; calling it returns a
            3-tuple whose third item is a dict with a ``'caches'`` entry and,
            when ``tgt_labels`` is given, whose first item is the loss.
        tokenizer: provides ``vocab``, ``unk_id`` and ``build_for_ernie``.
        args: run configuration. Fields read here: ``attn_token``,
            ``use_random_noice``, ``noise_prob``, ``max_encode_len``,
            ``max_decode_len``, ``tgt_type_id``, ``data_dir``, ``bsz``,
            ``eval_bsz``, ``lr``, ``warmup_proportion``, ``max_steps``,
            ``wd``, ``label_smooth``, ``save_dir``, ``predict_output_dir``,
            ``skip_eval_steps`` and ``eval_steps``.

    Note:
        ``place`` and ``evaluate`` are resolved from the enclosing scope.
    """
    log.info('Training starts with args: %r' % args)
    attn_id = tokenizer.vocab[args.attn_token]

    def gen_mask(batch_ids, mask_type='bidi', query_len=None, pad_value=0):
        """Build a float32 mask by tiling `batch_ids` over `query_len` queries.

        `mask_type` is one of 'bidi', 'causal', 'causal_without_diag', 'diag'
        or 'empty'. All types except 'empty' start from the non-padding
        positions of `batch_ids`; 'empty' yields an all-zero mask.
        """
        if query_len is None:
            query_len = batch_ids.shape[1]
        if mask_type != 'empty':
            mask = (batch_ids != pad_value).astype(np.float32)
            mask = np.tile(np.expand_dims(mask, 1), [1, query_len, 1])
            if mask_type == 'causal':
                assert query_len == batch_ids.shape[1]
                mask = np.tril(mask)
            elif mask_type == 'causal_without_diag':
                assert query_len == batch_ids.shape[1]
                mask = np.tril(mask, -1)
            elif mask_type == 'diag':
                assert query_len == batch_ids.shape[1]
                mask = np.stack([np.diag(np.diag(m)) for m in mask], 0)
        else:
            # mask_type == 'empty': nothing may be attended to.
            mask = np.zeros_like(batch_ids).astype(np.float32)
            mask = np.tile(np.expand_dims(mask, 1), [1, query_len, 1])
        return mask

    def make_some_noice(ids):
        """Overwrite a random `args.noise_prob` share of `ids` IN PLACE.

        Replacement ids are random vocabulary ids when
        `args.use_random_noice` is set, otherwise the '[NOISE]' token id.
        The (mutated) input array is returned.
        """
        if args.use_random_noice:
            noice_ids = np.random.randint(
                1, len(tokenizer.vocab), size=ids.shape)
        else:
            noice_ids = np.ones_like(ids) * tokenizer.vocab['[NOISE]']
        pos, = np.where(np.ones_like(ids))
        np.random.shuffle(pos)
        pos = pos[:int(args.noise_prob * len(pos))]
        ids[pos, ] = noice_ids[pos, ]
        return ids

    def map_fn(example_id, src_ids, tgt_ids):
        """Turn one raw example into the per-stream id arrays."""
        src_ids = src_ids[:args.max_encode_len]
        tgt_ids = tgt_ids[:args.max_decode_len]
        src_ids, src_sids = tokenizer.build_for_ernie(src_ids)
        src_pids = np.arange(len(src_ids))

        tgt_ids, tgt_sids = tokenizer.build_for_ernie(tgt_ids)
        # Target positions continue where the source positions stop.
        tgt_pids = np.arange(len(tgt_ids)) + len(src_ids)
        tgt_sids = np.ones_like(tgt_sids) * args.tgt_type_id

        attn_ids = np.ones_like(tgt_ids) * attn_id
        if args.noise_prob > 0.:
            # Copy first: make_some_noice corrupts tgt_ids in place.
            tgt_labels = deepcopy(tgt_ids)
            tgt_ids = make_some_noice(tgt_ids)
        else:
            tgt_labels = tgt_ids

        return (example_id, src_ids, src_pids, src_sids, tgt_ids, tgt_pids,
                tgt_sids, attn_ids, tgt_labels)

    def after_padding(example_id, src_ids, src_pids, src_sids, tgt_ids,
                      tgt_pids, tgt_sids, attn_ids, tgt_labels):
        '''
        Build the attention masks for a padded batch.

        attention mask:
        ***  src,  tgt, attn
        src  00,   01,   02
        tgt  10,   11,   12
        attn 20,   21,   22

        ***   s1, s2 | t1 t2 t3| attn1 attn2 attn3
        s1    1,  1  | 0, 0, 0,| 0,  0,  0,
        s2    1,  1  | 0, 0, 0,| 0,  0,  0,
        -
        t1    1,  1, | 1, 0, 0,| 0,  0,  0,
        t2    1,  1, | 1, 1, 0,| 0,  0,  0,
        t3    1,  1, | 1, 1, 1,| 0,  0,  0,
        -
        attn1 1,  1, | 0, 0, 0,| 1,  0,  0,
        attn2 1,  1, | 1, 0, 0,| 0,  1,  0,
        attn3 1,  1, | 1, 1, 0,| 0,  0,  1,

        for details, see Fig3. https://arxiv.org/abs/2001.11314

        Only the blocks that are actually fed to the model are built:
        00 for the src pass, 10|11 for the tgt pass and 20|21|22 for the
        attn pass. The all-zero blocks 01, 02 and 12 are never consumed and
        are therefore not materialised.

        Note that the returned tuple orders each stream as (ids, sids, pids),
        whereas the parameters arrive as (ids, pids, sids).
        '''
        src_len = src_ids.shape[1]
        tgt_len = tgt_ids.shape[1]
        mask_00 = gen_mask(src_ids, 'bidi', query_len=src_len)

        mask_10 = gen_mask(src_ids, 'bidi', query_len=tgt_len)
        mask_11 = gen_mask(tgt_ids, 'causal', query_len=tgt_len)

        mask_20 = gen_mask(src_ids, 'bidi', query_len=tgt_len)
        mask_21 = gen_mask(tgt_ids, 'causal_without_diag', query_len=tgt_len)
        mask_22 = gen_mask(attn_ids, 'diag', query_len=tgt_len)

        mask_src_2_src = mask_00
        mask_tgt_2_srctgt = np.concatenate([mask_10, mask_11], 2)
        mask_attn_2_srctgtattn = np.concatenate([mask_20, mask_21, mask_22], 2)

        # Keep only the non-zero label ids, flattened to one dimension.
        tgt_labels = tgt_labels[np.where(tgt_labels != 0)]
        return (example_id, src_ids, src_sids, src_pids, tgt_ids, tgt_sids,
                tgt_pids, attn_ids, mask_src_2_src, mask_tgt_2_srctgt,
                mask_attn_2_srctgtattn, tgt_labels)

    bytes_vocab = {k.encode('utf8'): v for k, v in tokenizer.vocab.items()}
    feature_column = propeller.data.FeatureColumns([
        propeller.data.LabelColumn('id'),
        propeller.data.TextColumn(
            'src', unk_id=tokenizer.unk_id, vocab_dict=bytes_vocab),
        propeller.data.TextColumn(
            'tgt', unk_id=tokenizer.unk_id, vocab_dict=bytes_vocab),
    ])

    train_ds = feature_column.build_dataset('train', data_dir=os.path.join(args.data_dir, 'train'), shuffle=False, repeat=True, use_gz=False) \
        .map(map_fn)

    dev_ds = feature_column.build_dataset('dev', data_dir=os.path.join(args.data_dir, 'dev'), shuffle=False, repeat=False, use_gz=False) \
        .map(map_fn) \
        .padded_batch(args.eval_bsz) \
        .map(after_padding)

    log.debug('shard %d of %d' %
              (D.parallel.Env().dev_id, D.parallel.Env().nranks))
    # Shard before shuffling/batching so each worker sees its own examples.
    train_ds = train_ds.shard(
        D.parallel.Env().nranks,
        D.parallel.Env().dev_id).shuffle(10000).padded_batch(
            args.bsz).map(after_padding)
    dev_ds = dev_ds.shard(D.parallel.Env().nranks, D.parallel.Env().dev_id)

    shapes = [[None, None]] * 7 + [[None, None, None]] * 3 + [[None]]
    types = ['int64'] * 11

    train_ds.data_shapes = shapes
    train_ds.data_types = types
    dev_ds.data_shapes = shapes
    dev_ds.data_types = types

    # Read the vocabulary size before the model is wrapped for data parallel.
    vocab_size, _ = model.word_emb.weight.shape
    ctx = D.parallel.prepare_context()
    model = D.parallel.DataParallel(model, ctx)
    g_clip = F.clip.GradientClipByGlobalNorm(1.0)
    opt = AdamW(
        learning_rate=LinearDecay(
            args.lr,
            int(args.warmup_proportion * args.max_steps), args.max_steps),
        parameter_list=model.parameters(),
        weight_decay=args.wd,
        grad_clip=g_clip)

    for step, data in enumerate(train_ds.start(place)):
        (example_id, src_ids, src_sids, src_pids, tgt_ids, tgt_sids, tgt_pids,
         attn_ids, mask_src_2_src, mask_tgt_2_srctgt, mask_attn_2_srctgtattn,
         tgt_labels) = data

        # Pass 1: source stream.
        _, __, info = model(
            src_ids,
            sent_ids=src_sids,
            pos_ids=src_pids,
            attn_bias=mask_src_2_src,
            encode_only=True)
        cached_k, cached_v = info['caches']

        # Pass 2: target stream on top of the source cache.
        _, __, info = model(
            tgt_ids,
            sent_ids=tgt_sids,
            pos_ids=tgt_pids,
            attn_bias=mask_tgt_2_srctgt,
            past_cache=(cached_k, cached_v),
            encode_only=True)
        cached_k2, cached_v2 = info['caches']
        past_cache_k = [
            L.concat([k, k2], 1) for k, k2 in zip(cached_k, cached_k2)
        ]
        past_cache_v = [
            L.concat([v, v2], 1) for v, v2 in zip(cached_v, cached_v2)
        ]

        if args.label_smooth > 0.:
            tgt_labels = L.label_smooth(
                F.one_hot(tgt_labels, vocab_size), epsilon=args.label_smooth)

        # Pass 3: attention-token stream on top of both caches -> loss.
        loss, _, __ = model(
            attn_ids,
            sent_ids=tgt_sids,
            pos_ids=tgt_pids,
            attn_bias=mask_attn_2_srctgtattn,
            past_cache=(past_cache_k, past_cache_v),
            tgt_labels=tgt_labels,
            tgt_pos=L.where(attn_ids == attn_id))

        scaled_loss = model.scale_loss(loss)
        scaled_loss.backward()
        model.apply_collective_grads()
        opt.minimize(scaled_loss)
        model.clear_gradients()

        if step % 10 == 0:
            loss = loss.numpy()
            ppl = np.exp(loss)
            log.debug('[step %d]train loss %.5f, ppl %.5f, lr %.3e' %
                      (step, loss, ppl, opt.current_step_lr()))
        if args.save_dir is not None and step % 1000 == 0 and D.parallel.Env(
        ).dev_id == 0:
            F.save_dygraph(model.state_dict(), args.save_dir)
        if args.predict_output_dir is not None and step > args.skip_eval_steps and step % args.eval_steps == 0:
            assert os.path.exists(
                args.predict_output_dir
            ), 'predict_output_dir not found: %s' % args.predict_output_dir
            log.debug('doing predict on gpu %d...' % D.parallel.Env().dev_id)
            evaluate(model, dev_ds, step, args)
        if step > args.max_steps:
            break

    # Final evaluation and checkpoint once training has stopped.
    evaluate(model, dev_ds, step, args)

    if args.save_dir is not None:
        F.save_dygraph(model.state_dict(), args.save_dir)
# Single-card classifier fine-tuning: build the clipper and the optimizer,
# then train for `args.epoch` epochs, evaluating on the dev set every 100 steps.
g_clip = F.dygraph_grad_clip.GradClipByGlobalNorm(1.0)  # experimental
opt = AdamW(
    learning_rate=LinearDecay(
        args.lr,
        int(args.warmup_proportion * args.max_steps),
        args.max_steps,
    ),
    parameter_list=model.parameters(),
    weight_decay=args.wd,
)

for epoch in range(args.epoch):
    for step, d in enumerate(tqdm(train_ds.start(place), desc='training')):
        ids, sids, label = d
        loss, _ = model(ids, sids, labels=label)
        loss.backward()
        if step % 10 == 0:
            log.debug('train loss %.5f lr %.3e' %
                      (loss.numpy(), opt.current_step_lr()))
        opt.minimize(loss, grad_clip=g_clip)
        model.clear_gradients()

        # Evaluation is only due on every 100th training step.
        if step % 100 != 0:
            continue

        acc = []
        with FD.base._switch_tracer_mode_guard_(is_train=False):
            model.eval()
            # This loop deliberately keeps the original names, so `step`,
            # `d` and `loss` are rebound by it exactly as before.
            for step, d in enumerate(
                    tqdm(dev_ds.start(), desc='evaluating %d' % epoch)):
                ids, sids, label = d
                loss, logits = model(ids, sids, labels=label)
                a = L.argmax(logits, -1) == label
                acc.append(a.numpy())
            model.train()
        log.debug('acc %.5f' % np.concatenate(acc).mean())
def finetune(
        self,
        train_path,
        dev_path=None,
        save_dir="ernie_gen_result",
        init_ckpt_path=None,
        use_gpu=True,
        max_steps=500,
        batch_size=8,
        max_encode_len=50,
        max_decode_len=50,
        learning_rate=5e-5,
        warmup_proportion=0.1,
        weight_decay=0.1,
        noise_prob=0,
        label_smooth=0,
        beam_width=5,
        length_penalty=1.0,
        log_interval=100,
        save_interval=200,
):
    """
    finetune with the specified dataset.

    Args:
        train_path(str): the train dataset path.
        dev_path(str): the dev dataset path.
        save_dir(str): the model params and dev dataset predict result save path.
        init_ckpt_path(str): incremental training load path.
        use_gpu(bool): use gpu or not.
        max_steps(int): max training steps.
        batch_size(int): the batch size.
        max_encode_len(int): the max encode length.
        max_decode_len(int): the max decode length.
        learning_rate(float): the learning rate.
        warmup_proportion(float): the warmup proportion.
        weight_decay(float): the weight decay magnitude.
        noise_prob(float): the noise probability. see the ernie gen paper for details.
        label_smooth(float): the label smooth magnitude.
        beam_width(int): the beam size during evaluating the dev dataset.
        length_penalty(float): the length penalty during evaluating the dev dataset.
        log_interval(int): the log interval.
        save_interval(int): the save interval. dev set will be evaluated after saving.

    Return:
        result(dict): A Dictionary of shape::
            {
                last_save_path(str): last model save path, or None when
                    nothing was saved (empty ``save_dir`` or no training step).
                last_ppl(float): last model ppl, or None when no training
                    step was executed.
            }
    """
    self.max_encode_len = max_encode_len
    self.max_decode_len = max_decode_len
    self.noise_prob = noise_prob

    place = F.CUDAPlace(0) if use_gpu else F.CPUPlace()

    with F.dygraph.guard(place):
        if init_ckpt_path is not None:
            logger.info('loading checkpoint from %s' % init_ckpt_path)
            sd, _ = D.load_dygraph(init_ckpt_path)
            self.model.set_dict(sd)

        feature_column = propeller.data.FeatureColumns([
            propeller.data.LabelColumn('id'),
            propeller.data.TextColumn(
                'src',
                unk_id=self.tokenizer.unk_id,
                vocab_dict=self.tokenizer.vocab,
                tokenizer=self.tokenizer.tokenize),
            propeller.data.TextColumn(
                'tgt',
                unk_id=self.tokenizer.unk_id,
                vocab_dict=self.tokenizer.vocab,
                tokenizer=self.tokenizer.tokenize),
        ])

        train_ds = feature_column.build_dataset('train', data_file=train_path, shuffle=False, repeat=True, use_gz=False)\
            .map(self._map_fn).shuffle(10000).padded_batch(batch_size).map(self._after_padding)
        train_ds.data_shapes = [[None, None]] * 7 + [[None, None, None]
                                                     ] * 3 + [[None]]
        train_ds.data_types = ['int64'] * 11

        dev_ds = None
        if dev_path:
            dev_ds = feature_column.build_dataset('dev', data_file=dev_path, shuffle=False, repeat=False, use_gz=False) \
                .map(self._map_fn) \
                .padded_batch(1) \
                .map(self._after_padding)
            dev_ds.data_shapes = [[None, None]] * 7 + [[None, None, None]
                                                       ] * 3 + [[None]]
            dev_ds.data_types = ['int64'] * 11

        vocab_size, _ = self.model.word_emb.weight.shape
        g_clip = F.clip.GradientClipByGlobalNorm(1.0)
        opt = AdamW(
            learning_rate=LinearDecay(learning_rate,
                                      int(warmup_proportion * max_steps),
                                      max_steps),
            parameter_list=self.model.parameters(),
            weight_decay=weight_decay,
            grad_clip=g_clip)

        # Loop-invariant: id of the token that marks positions to predict.
        attn_id = self.tokenizer.vocab['[MASK]']

        def _save_and_evaluate(step, ppl):
            """Save the weights, dump dev predictions if a dev set exists,
            and return the checkpoint path (without file extension)."""
            save_name = "step_%s_ppl_%.5f" % (step, ppl)
            path = os.path.join(save_dir, save_name)
            logger.info("save the model in %s" % path)
            F.save_dygraph(self.model.state_dict(), path)

            if dev_path:
                logger.info('evaluating...')
                res = self._evaluate(dev_ds, place, beam_width,
                                     length_penalty)
                output_path = os.path.join(
                    save_dir, "step_%s_ppl_%.5f.txt" % (step, ppl))
                logger.info('save the predict result in %s' % output_path)
                # Explicit encoding: predictions may contain non-ASCII text
                # and the locale default is not guaranteed to handle it.
                with open(output_path, 'w', encoding='utf-8') as fout:
                    fout.write('\n'.join(res))
            return path

        loss = None
        save_path = None
        ppl = None

        if save_dir and not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for step, data in enumerate(train_ds.start(place)):
            (example_id, src_ids, src_sids, src_pids, tgt_ids, tgt_sids,
             tgt_pids, attn_ids, mask_src_2_src, mask_tgt_2_srctgt,
             mask_attn_2_srctgtattn, tgt_labels) = data

            # Pass 1: source stream.
            _, __, info = self.model(
                src_ids,
                sent_ids=src_sids,
                pos_ids=src_pids,
                attn_bias=mask_src_2_src,
                encode_only=True)
            cached_k, cached_v = info['caches']

            # Pass 2: target stream on top of the source cache.
            _, __, info = self.model(
                tgt_ids,
                sent_ids=tgt_sids,
                pos_ids=tgt_pids,
                attn_bias=mask_tgt_2_srctgt,
                past_cache=(cached_k, cached_v),
                encode_only=True)
            cached_k2, cached_v2 = info['caches']
            past_cache_k = [
                L.concat([k, k2], 1) for k, k2 in zip(cached_k, cached_k2)
            ]
            past_cache_v = [
                L.concat([v, v2], 1) for v, v2 in zip(cached_v, cached_v2)
            ]

            if label_smooth > 0.:
                tgt_labels = L.label_smooth(
                    F.one_hot(tgt_labels, vocab_size), epsilon=label_smooth)

            # Pass 3: attention-token stream on top of both caches -> loss.
            loss, _, __ = self.model(
                attn_ids,
                sent_ids=tgt_sids,
                pos_ids=tgt_pids,
                attn_bias=mask_attn_2_srctgtattn,
                past_cache=(past_cache_k, past_cache_v),
                tgt_labels=tgt_labels,
                tgt_pos=L.where(attn_ids == attn_id))

            loss.backward()
            opt.minimize(loss)
            self.model.clear_gradients()

            if step % log_interval == 0:
                loss_np = loss.numpy()
                ppl = np.exp(loss_np)
                logger.info(
                    '[step %d / %d]train loss %.5f, ppl %.5f, elr %.3e' %
                    (step, max_steps, loss_np, ppl, opt.current_step_lr()))

            if save_dir and step % save_interval == 0 and step > 0:
                loss_np = loss.numpy()
                ppl = np.exp(loss_np)
                save_path = _save_and_evaluate(step, ppl)

            if step > max_steps:
                break

        # `loss` stays None only when the dataset produced no batch. Compare
        # with None explicitly: truth-testing the loss tensor would depend on
        # its value rather than on whether training ran.
        if loss is not None:
            loss_np = loss.numpy()
            ppl = np.exp(loss_np)
            logger.info('[final step %d]train loss %.5f, ppl %.5f, elr %.3e' %
                        (step, loss_np, ppl, opt.current_step_lr()))
            if save_dir:
                save_path = _save_and_evaluate(step, ppl)

        result = {
            "last_save_path": ("%s.pdparams" % save_path)
            if save_path is not None else None,
            "last_ppl": ppl[0] if ppl is not None else None,
        }

        return result
# Data-parallel pretraining: one process per device, selected by the dev_id
# reported by D.parallel.Env().
place = F.CUDAPlace(D.parallel.Env().dev_id)
with D.guard(place):
    model = ErnieModelForPretraining.from_pretrained(args.from_pretrained)
    # The optimizer is built from the parameters before the model is wrapped.
    opt = AdamW(
        learning_rate=LinearDecay(args.lr, args.warmup_steps,
                                  args.max_steps),
        parameter_list=model.parameters(),
        weight_decay=0.01)
    ctx = D.parallel.prepare_context()
    model = D.parallel.DataParallel(model, ctx)

    for step, samples in enumerate(tqdm(train_ds.start(place))):
        src_ids, sent_ids, mlm_label, mask_pos, nsp_label = samples
        loss, mlmloss, nsploss = model(
            src_ids,
            sent_ids,
            labels=mlm_label,
            mlm_pos=mask_pos,
            nsp_labels=nsp_label)

        # Scale the loss, then all-reduce gradients before the update.
        scaled_loss = model.scale_loss(loss)
        scaled_loss.backward()
        model.apply_collective_grads()
        opt.minimize(scaled_loss)
        model.clear_gradients()

        if step % 10 == 0:
            log.debug('train loss %.5f scaled loss %.5f' %
                      (loss.numpy(), scaled_loss.numpy()))

        # Checkpoint every 10000 steps, only on device 0 and only when a
        # save directory is configured.
        if step % 10000 != 0:
            continue
        if D.parallel.Env().dev_id == 0 and args.save_dir is not None:
            F.save_dygraph(model.state_dict(), args.save_dir)
args.from_pretrained, num_labels=3, name='') model = FD.parallel.DataParallel(model, ctx) opt = AdamW(learning_rate=LinearDecay( args.lr, int(args.warmup_proportion * args.max_steps), args.max_steps), parameter_list=model.parameters(), weight_decay=args.wd) g_clip = F.dygraph_grad_clip.GradClipByGlobalNorm(1.0) #experimental for step, d in enumerate(tqdm(train_ds.start(place), desc='training')): ids, sids, label = d loss, _ = model(ids, sids, labels=label) scaled_loss = model.scale_loss(loss) scaled_loss.backward() model.apply_collective_grads() opt.minimize(scaled_loss, grad_clip=g_clip) model.clear_gradients() if step % 10 == 0: log.debug('train loss %.5f, lr %.e3' % (loss.numpy(), opt.current_step_lr())) if step % 100 == 0 and FD.parallel.Env().dev_id == 0: acc = [] with FD.base._switch_tracer_mode_guard_(is_train=False): model.eval() for step, d in enumerate( tqdm(dev_ds.start(place), desc='evaluating')): ids, sids, label = d loss, logits = model(ids, sids, labels=label) #print('\n'.join(map(str, logits.numpy().tolist()))) a = L.argmax(logits, -1) == label acc.append(a.numpy())