def make_dataloader(data_train, data_val, data_test, args,
                    use_average_length=False, num_shards=0, num_workers=8):
    """Create data loaders for training/validation/test."""
    data_train_lengths = get_data_lengths(data_train)
    data_val_lengths = get_data_lengths(data_val)
    data_test_lengths = get_data_lengths(data_test)
    train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                  btf.Stack(dtype='float32'), btf.Stack(dtype='float32'))
    test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                 btf.Stack(dtype='float32'), btf.Stack(dtype='float32'),
                                 btf.Stack())
    target_val_lengths = list(map(lambda x: x[-1], data_val_lengths))
    target_test_lengths = list(map(lambda x: x[-1], data_test_lengths))
    if args.bucket_scheme == 'constant':
        bucket_scheme = nlp.data.ConstWidthBucket()
    elif args.bucket_scheme == 'linear':
        bucket_scheme = nlp.data.LinearWidthBucket()
    elif args.bucket_scheme == 'exp':
        bucket_scheme = nlp.data.ExpWidthBucket(bucket_len_step=1.2)
    else:
        raise NotImplementedError
    train_batch_sampler = nlp.data.FixedBucketSampler(lengths=data_train_lengths,
                                                      batch_size=args.batch_size,
                                                      num_buckets=args.num_buckets,
                                                      ratio=args.bucket_ratio,
                                                      shuffle=True,
                                                      use_average_length=use_average_length,
                                                      num_shards=num_shards,
                                                      bucket_scheme=bucket_scheme)
    logging.info('Train Batch Sampler:\n%s', train_batch_sampler.stats())
    train_data_loader = nlp.data.ShardedDataLoader(data_train,
                                                   batch_sampler=train_batch_sampler,
                                                   batchify_fn=train_batchify_fn,
                                                   num_workers=num_workers)
    val_batch_sampler = nlp.data.FixedBucketSampler(lengths=target_val_lengths,
                                                    batch_size=args.test_batch_size,
                                                    num_buckets=args.num_buckets,
                                                    ratio=args.bucket_ratio,
                                                    shuffle=False,
                                                    use_average_length=use_average_length,
                                                    bucket_scheme=bucket_scheme)
    logging.info('Valid Batch Sampler:\n%s', val_batch_sampler.stats())
    val_data_loader = gluon.data.DataLoader(data_val,
                                            batch_sampler=val_batch_sampler,
                                            batchify_fn=test_batchify_fn,
                                            num_workers=num_workers)
    test_batch_sampler = nlp.data.FixedBucketSampler(lengths=target_test_lengths,
                                                     batch_size=args.test_batch_size,
                                                     num_buckets=args.num_buckets,
                                                     ratio=args.bucket_ratio,
                                                     shuffle=False,
                                                     use_average_length=use_average_length,
                                                     bucket_scheme=bucket_scheme)
    logging.info('Test Batch Sampler:\n%s', test_batch_sampler.stats())
    test_data_loader = gluon.data.DataLoader(data_test,
                                             batch_sampler=test_batch_sampler,
                                             batchify_fn=test_batchify_fn,
                                             num_workers=num_workers)
    return train_data_loader, val_data_loader, test_data_loader
def get_dataloader(data_set, args, dataset_type,
                   use_average_length=False, num_shards=0, num_workers=8):
    """Create data loaders for training/validation/test."""
    assert dataset_type in ['train', 'val', 'test']

    if args.bucket_scheme == 'constant':
        bucket_scheme = nlp.data.ConstWidthBucket()
    elif args.bucket_scheme == 'linear':
        bucket_scheme = nlp.data.LinearWidthBucket()
    elif args.bucket_scheme == 'exp':
        bucket_scheme = nlp.data.ExpWidthBucket(bucket_len_step=1.2)
    else:
        raise NotImplementedError
    data_lengths = get_data_lengths(data_set)

    if dataset_type == 'train':
        train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                      btf.Stack(dtype='float32'), btf.Stack(dtype='float32'))
    else:
        data_lengths = list(map(lambda x: x[-1], data_lengths))
        test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
                                     btf.Stack(dtype='float32'), btf.Stack(dtype='float32'),
                                     btf.Stack())

    batch_sampler = nlp.data.FixedBucketSampler(lengths=data_lengths,
                                                batch_size=(args.batch_size
                                                            if dataset_type == 'train'
                                                            else args.test_batch_size),
                                                num_buckets=args.num_buckets,
                                                ratio=args.bucket_ratio,
                                                shuffle=(dataset_type == 'train'),
                                                use_average_length=use_average_length,
                                                num_shards=num_shards,
                                                bucket_scheme=bucket_scheme)

    if dataset_type == 'train':
        logging.info('Train Batch Sampler:\n%s', batch_sampler.stats())
        data_loader = nlp.data.ShardedDataLoader(data_set,
                                                 batch_sampler=batch_sampler,
                                                 batchify_fn=train_batchify_fn,
                                                 num_workers=num_workers)
    else:
        if dataset_type == 'val':
            logging.info('Valid Batch Sampler:\n%s', batch_sampler.stats())
        else:
            logging.info('Test Batch Sampler:\n%s', batch_sampler.stats())
        data_loader = gluon.data.DataLoader(data_set,
                                            batch_sampler=batch_sampler,
                                            batchify_fn=test_batchify_fn,
                                            num_workers=num_workers)
    return data_loader
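# For context, a minimal hypothetical call-site sketch: it shows how the unified
# get_dataloader above replaces the three loaders built by make_dataloader.
# The `args` namespace (batch_size, test_batch_size, num_buckets, bucket_ratio,
# bucket_scheme) and the data_* / ctx_l objects are assumed to come from the
# surrounding training script; this is a sketch, not part of the original code.
train_data_loader = get_dataloader(data_train, args, dataset_type='train',
                                   use_average_length=True, num_shards=len(ctx_l))
val_data_loader = get_dataloader(data_val, args, dataset_type='val')
test_data_loader = get_dataloader(data_test, args, dataset_type='test')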
def test_pad():
    padded = batchify.Pad(val=-1)([mx.np.array([]),
                                   mx.np.arange(1)]).asnumpy().flatten().tolist()
    assert padded == [-1.0, 0.0]
    padded = batchify.Pad(val=-1, round_to=2)([mx.np.array([]),
                                               mx.np.arange(1)]).asnumpy().flatten().tolist()
    assert padded == [-1.0, -1.0, 0.0, -1.0]
def test_named_tuple():
    a = ([1, 2, 3, 4], 0)
    b = ([5, 7], 1)
    c = ([1, 2, 3, 4, 5, 6, 7], 0)
    batchify_fn = batchify.NamedTuple([('data', batchify.Pad()),
                                       ('label', batchify.Stack())],
                                      name='SomeName')
    sample = batchify_fn([a, b, c])
    gt_data = batchify.Pad()([a[0], b[0], c[0]])
    gt_label = batchify.Stack()([a[1], b[1], c[1]])
    assert_allclose(sample.data.asnumpy(), gt_data.asnumpy())
    assert_allclose(sample.label.asnumpy(), gt_label.asnumpy())
    assert type(sample).__name__ == 'SomeName'
def __init__(self, tokenizer, doc_stride, max_seq_length, max_query_length):
    """
    Parameters
    ----------
    tokenizer
        The tokenizer
    doc_stride
        The stride to chunk the document
    max_seq_length
        Maximum length of the merged data
    max_query_length
        Maximum query length
    """
    self._tokenizer = tokenizer
    self._doc_stride = doc_stride
    self._max_seq_length = max_seq_length
    self._max_query_length = max_query_length
    vocab = tokenizer.vocab
    self.pad_id = vocab.pad_id
    # For the RoBERTa model, take the special token <s> as [CLS] and </s> as [SEP]
    self.cls_id = vocab.bos_id if 'cls_token' not in vocab.special_token_keys else vocab.cls_id
    self.sep_id = vocab.eos_id if 'sep_token' not in vocab.special_token_keys else vocab.sep_id
    # TODO(sxjscience) Consider combining the NamedTuple and batchify functionality.
    self.ChunkFeature = collections.namedtuple('ChunkFeature',
                                               ['qas_id', 'data', 'valid_length', 'segment_ids',
                                                'masks', 'is_impossible', 'gt_start', 'gt_end',
                                                'context_offset', 'chunk_start', 'chunk_length'])
    self.BatchifyFunction = bf.NamedTuple(self.ChunkFeature,
                                          {'qas_id': bf.List(),
                                           'data': bf.Pad(val=self.pad_id),
                                           'valid_length': bf.Stack(),
                                           'segment_ids': bf.Pad(),
                                           'masks': bf.Pad(val=1),
                                           'is_impossible': bf.Stack(),
                                           'gt_start': bf.Stack(),
                                           'gt_end': bf.Stack(),
                                           'context_offset': bf.Stack(),
                                           'chunk_start': bf.Stack(),
                                           'chunk_length': bf.Stack()})
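# A hedged usage sketch of the batchify function defined above. It assumes
# `feature_transform` is an instance of this class and `chunk_features` is a
# list of its ChunkFeature namedtuples; neither name appears in the original code.
from mxnet.gluon.data import DataLoader, SimpleDataset

dataloader = DataLoader(SimpleDataset(chunk_features), batch_size=32, shuffle=True,
                        batchify_fn=feature_transform.BatchifyFunction)
for batch in dataloader:
    # Each batch is itself a ChunkFeature whose fields are batched:
    # data/segment_ids/masks are padded, the scalar fields are stacked,
    # and qas_id stays a plain Python list.
    print(batch.data.shape, batch.valid_length.shape)
    break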
def test_pad():
    with pytest.warns(UserWarning):
        # UserWarning: Using Pad with NDArrays is discouraged for speed reasons.
        padded = batchify.Pad(pad_val=-1)([mx.nd.array([]),
                                           mx.nd.arange(1)]).asnumpy().flatten().tolist()
    assert padded == [-1.0, 0.0]
    with pytest.warns(UserWarning):
        # UserWarning: Using Pad with NDArrays is discouraged for speed reasons.
        padded = batchify.Pad(pad_val=-1, round_to=2)([mx.nd.array([]),
                                                       mx.nd.arange(1)]).asnumpy().flatten().tolist()
    assert padded == [-1.0, -1.0, 0.0, -1.0]
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Turn Dataset into Dataloader
    # (GPT-2 does not have a padding token; but the padding token shouldn't matter anyways)
    self._batchify_fn = btf.Tuple(
        btf.Stack(dtype='int32'),
        btf.Pad(pad_val=np.iinfo(np.int32).max, dtype='int32'),
        btf.Stack(dtype='int32'))
def __init__(self, *args, **kwargs):
    self._wwm = kwargs.pop('wwm') if 'wwm' in kwargs else False
    super().__init__(*args, **kwargs)
    # Turn Dataset into Dataloader
    self._batchify_fn = btf.Tuple(
        btf.Stack(dtype='int32'),
        btf.Pad(pad_val=np.iinfo(np.int32).max, dtype='int32'),
        btf.Stack(dtype='float32'))
def prepare_data_loader(args, dataset, vocab, test=False):
    """Read data and build data loader."""
    # Preprocess
    dataset = dataset.transform(lambda s1, s2, label: (vocab(s1), vocab(s2), label),
                                lazy=False)

    # Batching
    batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(dtype='int32'))
    data_lengths = [max(len(d[0]), len(d[1])) for d in dataset]
    batch_sampler = nlp.data.FixedBucketSampler(lengths=data_lengths,
                                                batch_size=args.batch_size,
                                                shuffle=(not test))
    data_loader = gluon.data.DataLoader(dataset=dataset,
                                        batch_sampler=batch_sampler,
                                        batchify_fn=batchify_fn)
    return data_loader
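# Hypothetical call sites for the helper above; `args`, `vocab`, and the two
# dataset objects are assumed to be created elsewhere in the script.
train_loader = prepare_data_loader(args, train_dataset, vocab, test=False)
test_loader = prepare_data_loader(args, test_dataset, vocab, test=True)
for s1, s2, label in train_loader:
    # s1 and s2 are zero-padded batches of token ids; label is a stacked int32 array.
    break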
def test_dict():
    a = {'data': [1, 2, 3, 4], 'label': 0}
    b = {'data': [5, 7], 'label': 1}
    c = {'data': [1, 2, 3, 4, 5, 6, 7], 'label': 0}
    with pytest.raises(ValueError):
        wrong_batchify_fn = batchify.Dict([batchify.Pad(pad_val=0), batchify.Stack()])
    with pytest.raises(ValueError):
        wrong_batchify_fn = batchify.NamedTuple(MyNamedTuple, {'a': 1, 'b': 2})
    batchify_fn = batchify.Dict({'data': batchify.Pad(pad_val=0),
                                 'label': batchify.Stack()})
    sample = batchify_fn([a, b, c])
    gt_data = batchify.Pad(pad_val=0)([a['data'], b['data'], c['data']])
    gt_label = batchify.Stack()([a['label'], b['label'], c['label']])
    assert isinstance(sample, dict)
    assert_allclose(sample['data'].asnumpy(), gt_data.asnumpy())
    assert_allclose(sample['label'].asnumpy(), gt_label.asnumpy())
def __init__(self, *args, **kwargs):
    self._wwm = kwargs.pop('wwm') if 'wwm' in kwargs else False
    super().__init__(*args, **kwargs)
    # Turn Dataset into Dataloader
    self._batchify_fn = btf.Tuple(btf.Stack(dtype='int32'),
                                  btf.Pad(pad_val=np.iinfo(np.int32).max, dtype='int32'),
                                  btf.Stack(dtype='float32'),
                                  btf.Stack(dtype='float32'))
    self._trainer = mx.gluon.Trainer(self._model.collect_params(),
                                     'adam',
                                     {'learning_rate': 1e-5, 'epsilon': 1e-9},
                                     update_on_kvstore=False)
    self._loss = mx.gluon.loss.L2Loss()
    self._loss.hybridize(static_alloc=True)
    self._params = [p for p in self._model.collect_params().values()
                    if p.grad_req != 'null']
    self._max_length = 384
def test_named_tuple():
    a = MyNamedTuple([1, 2, 3, 4], 0)
    b = MyNamedTuple([5, 7], 1)
    c = MyNamedTuple([1, 2, 3, 4, 5, 6, 7], 0)
    with pytest.raises(ValueError):
        wrong_batchify_fn = batchify.NamedTuple(MyNamedTuple,
                                                {'data0': batchify.Pad(pad_val=0),
                                                 'label': batchify.Stack()})
    with pytest.raises(ValueError):
        wrong_batchify_fn = batchify.NamedTuple(MyNamedTuple,
                                                [batchify.Pad(pad_val=0),
                                                 batchify.Stack(),
                                                 batchify.Stack()])
    with pytest.raises(ValueError):
        wrong_batchify_fn = batchify.NamedTuple(MyNamedTuple,
                                                (batchify.Pad(pad_val=0), ))
    with pytest.raises(ValueError):
        wrong_batchify_fn = batchify.NamedTuple(MyNamedTuple, [1, 2])
    for batchify_fn in [batchify.NamedTuple(MyNamedTuple,
                                            {'data': batchify.Pad(pad_val=0),
                                             'label': batchify.Stack()}),
                        batchify.NamedTuple(MyNamedTuple,
                                            [batchify.Pad(pad_val=0), batchify.Stack()]),
                        batchify.NamedTuple(MyNamedTuple,
                                            (batchify.Pad(pad_val=0), batchify.Stack()))]:
        sample = batchify_fn([a, b, c])
        gt_data = batchify.Pad(pad_val=0)([a[0], b[0], c[0]])
        gt_label = batchify.Stack()([a[1], b[1], c[1]])
        assert isinstance(sample, MyNamedTuple)
        assert_allclose(sample.data.asnumpy(), gt_data.asnumpy())
        assert_allclose(sample.label.asnumpy(), gt_label.asnumpy())
        with pytest.raises(ValueError):
            batchify_fn([1, 2, 3])
def train(): """Training function.""" trainer = gluon.Trainer(model.collect_params(), args.optimizer, {'learning_rate': args.lr}) train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(dtype='float32'), btf.Stack(dtype='float32')) test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(dtype='float32'), btf.Stack(dtype='float32'), btf.Stack()) if args.bucket_scheme == 'constant': bucket_scheme = ConstWidthBucket() elif args.bucket_scheme == 'linear': bucket_scheme = LinearWidthBucket() elif args.bucket_scheme == 'exp': bucket_scheme = ExpWidthBucket(bucket_len_step=1.2) else: raise NotImplementedError train_batch_sampler = FixedBucketSampler(lengths=data_train_lengths, batch_size=args.batch_size, num_buckets=args.num_buckets, ratio=args.bucket_ratio, shuffle=True, bucket_scheme=bucket_scheme) logging.info('Train Batch Sampler:\n{}'.format( train_batch_sampler.stats())) train_data_loader = DataLoader(data_train, batch_sampler=train_batch_sampler, batchify_fn=train_batchify_fn, num_workers=8) val_batch_sampler = FixedBucketSampler(lengths=data_val_lengths, batch_size=args.test_batch_size, num_buckets=args.num_buckets, ratio=args.bucket_ratio, shuffle=False) logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats())) val_data_loader = DataLoader(data_val, batch_sampler=val_batch_sampler, batchify_fn=test_batchify_fn, num_workers=8) test_batch_sampler = FixedBucketSampler(lengths=data_test_lengths, batch_size=args.test_batch_size, num_buckets=args.num_buckets, ratio=args.bucket_ratio, shuffle=False) logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats())) test_data_loader = DataLoader(data_test, batch_sampler=test_batch_sampler, batchify_fn=test_batchify_fn, num_workers=8) best_valid_bleu = 0.0 for epoch_id in range(args.epochs): log_avg_loss = 0 log_avg_gnorm = 0 log_wc = 0 log_start_time = time.time() for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length)\ in enumerate(train_data_loader): # logging.info(src_seq.context) Context suddenly becomes GPU. 
src_seq = src_seq.as_in_context(ctx) tgt_seq = tgt_seq.as_in_context(ctx) src_valid_length = src_valid_length.as_in_context(ctx) tgt_valid_length = tgt_valid_length.as_in_context(ctx) with mx.autograd.record(): out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1) loss = loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean() loss = loss * (tgt_seq.shape[1] - 1) / (tgt_valid_length - 1).mean() loss.backward() grads = [p.grad(ctx) for p in model.collect_params().values()] gnorm = gluon.utils.clip_global_norm(grads, args.clip) trainer.step(1) src_wc = src_valid_length.sum().asscalar() tgt_wc = (tgt_valid_length - 1).sum().asscalar() step_loss = loss.asscalar() log_avg_loss += step_loss log_avg_gnorm += gnorm log_wc += src_wc + tgt_wc if (batch_id + 1) % args.log_interval == 0: wps = log_wc / (time.time() - log_start_time) logging.info( '[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, gnorm={:.4f}, ' 'throughput={:.2f}K wps, wc={:.2f}K'.format( epoch_id, batch_id + 1, len(train_data_loader), log_avg_loss / args.log_interval, np.exp(log_avg_loss / args.log_interval), log_avg_gnorm / args.log_interval, wps / 1000, log_wc / 1000)) log_start_time = time.time() log_avg_loss = 0 log_avg_gnorm = 0 log_wc = 0 valid_loss, valid_translation_out = evaluate(val_data_loader) valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out) logging.info( '[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}' .format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100)) test_loss, test_translation_out = evaluate(test_data_loader) test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out) logging.info( '[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'. format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100)) write_sentences( valid_translation_out, os.path.join(args.save_dir, 'epoch{:d}_valid_out.txt').format(epoch_id)) write_sentences( test_translation_out, os.path.join(args.save_dir, 'epoch{:d}_test_out.txt').format(epoch_id)) if valid_bleu_score > best_valid_bleu: best_valid_bleu = valid_bleu_score save_path = os.path.join(args.save_dir, 'valid_best.params') logging.info('Save best parameters to {}'.format(save_path)) model.save_params(save_path) if epoch_id + 1 >= (args.epochs * 2) // 3: new_lr = trainer.learning_rate * args.lr_update_factor logging.info('Learning rate change to {}'.format(new_lr)) trainer.set_learning_rate(new_lr) model.load_params(os.path.join(args.save_dir, 'valid_best.params')) valid_loss, valid_translation_out = evaluate(val_data_loader) valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out) logging.info( 'Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'. format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100)) test_loss, test_translation_out = evaluate(test_data_loader) test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out) logging.info( 'Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'. format(test_loss, np.exp(test_loss), test_bleu_score * 100)) write_sentences(valid_translation_out, os.path.join(args.save_dir, 'best_valid_out.txt')) write_sentences(test_translation_out, os.path.join(args.save_dir, 'best_test_out.txt'))
def bin(self, corpus: Corpus, temp: float = 1.0, split_size: int = 2000, ratio: float = 0, num_workers: int = 10) -> List[float]: ctx_cpu = mx.Context('cpu') # Turn corpus into a BERT-ready Dataset dataset = self.corpus_to_dataset(corpus) # Turn Dataset into Dataloader batchify_fn = btf.Tuple( btf.Stack(dtype='int32'), btf.Pad( pad_val=self._vocab.token_to_idx[self._vocab.padding_token], dtype='int32'), btf.Stack(dtype='float32'), btf.Stack(dtype='int32'), btf.Stack(dtype='int32'), btf.Stack(dtype='float32')) # TODO: There is a 'by-design' bug in FixedBucketSampler with num_shards > 0, where it silently reuses the last utterances: # https://github.com/dmlc/gluon-nlp/blame/b1b61d3f90cf795c7b48b6d109db7b7b96fa21ff/src/gluonnlp/data/sampler.py#L398 # batch_sampler = nlp.data.sampler.FixedBucketSampler([sent_tuple[2] for sent_tuple in dataset], batch_size=split_size, ratio=ratio, num_shards=len(self._ctxs), shuffle=False) # Hence, we use num_shards = 0 and do gluon's split_data batch_sampler = nlp.data.sampler.FixedBucketSampler( [sent_tuple[2] for sent_tuple in dataset], batch_size=split_size, ratio=ratio, num_shards=0, shuffle=False) logging.info(batch_sampler.stats()) dataloader = nlp.data.ShardedDataLoader(dataset, pin_memory=True, batch_sampler=batch_sampler, batchify_fn=batchify_fn, num_workers=num_workers, thread_pool=True) ### <DIFFERENT> max_length = 256 # Compute bins # First axis is sentence length bin_counts = np.zeros((max_length, max_length)) bin_counts_per_ctx = [ mx.nd.zeros((max_length, max_length), ctx=ctx) for ctx in self._ctxs ] bin_sums = np.zeros((max_length, max_length)) bin_sums_per_ctx = [ mx.nd.zeros((max_length, max_length), ctx=ctx) for ctx in self._ctxs ] ### </DIFFERENT> # Compute sum (assumes dataset is in order) prev_sent_idx = None true_tok_lens = [] for (curr_sent_idx, _, valid_length, _, _, _) in dataset: if curr_sent_idx != prev_sent_idx: prev_sent_idx = curr_sent_idx true_tok_lens.append(valid_length - 2) sent_count = 0 batch_log_interval = 20 # For now just predicts the first non-cls token for batch_id, batch in enumerate(dataloader): batch_size = 0 # TODO: Write tests about batching over multiple GPUs and getting the same scores # TODO: SEE COMMENT ABOVE REGARDING FIXEDBUCKETSAMPLER batch = zip(*[ mx.gluon.utils.split_data(batch_compo, len(self._ctxs), batch_axis=0, even_split=False) for batch_compo in batch ]) for ctx_idx, (sent_idxs, token_ids, valid_length, masked_positions, token_masked_ids, normalization) in enumerate(batch): ctx = self._ctxs[ctx_idx] batch_size += sent_idxs.shape[0] token_ids = token_ids.as_in_context(ctx) valid_length = valid_length.as_in_context(ctx) segment_ids = mx.nd.zeros(shape=token_ids.shape, ctx=ctx) masked_positions = masked_positions.as_in_context(ctx).reshape( -1, 1) out = self._model(token_ids, segment_ids, valid_length, masked_positions) # Get the probability computed for the correct token split_size = token_ids.shape[0] # out[0] contains the representations # out[1] is what contains the distribution for the masked out = out[1].log_softmax(temperature=temp) ### <DIFFERENT> token_masked_ids = token_masked_ids.as_in_context(ctx).reshape( -1) for i in range(out.shape[0]): num_bins = int(valid_length[i].asscalar()) - 2 bin_counts_per_ctx[ctx_idx][num_bins, masked_positions[i] - 1] += 1 bin_sums_per_ctx[ctx_idx][num_bins, masked_positions[i] - 1] += out[i, 0, token_masked_ids[i]] if token_masked_ids[i].asscalar() == 1012: import pdb pdb.set_trace() ### </DIFFERENT> # Progress sent_count += batch_size if (batch_id + 1) % 
batch_log_interval == 0: logging.info("{} sents of {}, batch {} of {}".format( sent_count, len(dataset), batch_id + 1, len(batch_sampler))) # Accumulate the counts for ctx_idx in range(len(self._ctxs)): bin_counts += bin_counts_per_ctx[ctx_idx].asnumpy() bin_sums += bin_sums_per_ctx[ctx_idx].asnumpy() return bin_counts, bin_sums
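# A minimal standalone illustration of the split_data workaround described in
# the comment above (num_shards=0 in the sampler, then manually splitting each
# batch component across contexts). Two CPU contexts stand in for GPUs and the
# batch contents are made up for the example; this is a sketch only.
import mxnet as mx

ctxs = [mx.cpu(0), mx.cpu(1)]             # stand-ins for self._ctxs
batch = (mx.nd.arange(10).reshape(5, 2),  # e.g. padded token ids
         mx.nd.ones(5))                   # e.g. valid lengths
shards = zip(*[mx.gluon.utils.split_data(component, len(ctxs),
                                         batch_axis=0, even_split=False)
               for component in batch])
for ctx, (token_ids, valid_length) in zip(ctxs, shards):
    token_ids = token_ids.as_in_context(ctx)
    valid_length = valid_length.as_in_context(ctx)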
def train(args): # Load and clean data """ Training function that orchestrates the Classification! """ train_file = args.input test_file = args.validation ngram_range = args.ngrams logging.info('Ngrams range for the training run : %s', ngram_range) logging.info('Loading Training data') train_labels, train_data = read_input_data(train_file) tokens_list = [] for x in train_data: tokens_list.extend(x.split()) cntr = Counter(tokens_list) train_vocab = gluonnlp.Vocab(cntr) logging.info('Vocabulary size: %s', len(train_vocab)) logging.info('Training data converting to sequences...') train_sequences = [train_vocab.to_indices(x.split()) for x in train_data] logging.info('Reading test dataset') test_labels, test_data = read_input_data(test_file) test_sequences = [train_vocab.to_indices(x.split()) for x in test_data] if ngram_range >= 2: logging.info('Adding %s-gram features', ngram_range) # Create set of unique n-gram from the training set. ngram_set = set() for input_list in train_sequences: for i in range(2, ngram_range + 1): set_of_ngram = create_ngram_set(input_list, ngram_value=i) ngram_set.update(set_of_ngram) start_index = len(cntr) token_indices = {v: k + start_index for k, v in enumerate(ngram_set)} train_sequences = add_ngram(train_sequences, token_indices, ngram_range) test_sequences = add_ngram(test_sequences, token_indices, ngram_range) logging.info('Added n-gram features to train and test datasets!! ') logging.info('Encoding labels') label_mapping = get_label_mapping(train_labels) y_train_final = list(map(lambda x: label_mapping[x], train_labels)) y_test_final = list(map(lambda x: label_mapping[x], test_labels)) num_classes = len(np.unique(train_labels)) logging.info('Number of labels: %s', num_classes) logging.info('Initializing network') ctx = get_context(args) logging.info('Running Training on ctx:%s', ctx) embedding_dim = args.emsize net = FastTextClassificationModel( len(train_vocab), embedding_dim, num_classes) net.hybridize() net.collect_params().initialize(mx.init.Xavier(), ctx=ctx) logging.info('Network initialized') softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss() sigmoid_loss_fn = gluon.loss.SigmoidBinaryCrossEntropyLoss() loss_function = softmax_cross_entropy if num_classes == 2: logging.info( 'Changing the loss function to sigmoid since its Binary Classification' ) loss_function = sigmoid_loss_fn logging.info('Loss function for training:%s', loss_function) num_epochs = args.epochs batch_size = args.batch_size logging.info('Starting Training!') learning_rate = args.lr trainer1 = gluon.Trainer(net.embedding.collect_params(), 'adam', {'learning_rate': learning_rate}) trainer2 = gluon.Trainer(net.dense.collect_params(), 'adam', {'learning_rate': learning_rate}) train_batchify_fn = btf.Tuple(btf.Pad(), mx.nd.array) logging.info('Loading the training data to memory and creating sequences!') train_data_iter = mx.gluon.data.DataLoader( mx.gluon.data.ArrayDataset(train_sequences, mx.nd.array(y_train_final)), batch_size=batch_size, shuffle=False, batchify_fn=train_batchify_fn) logging.info('Loading the test data to memory and creating sequences') test_data_iter = mx.gluon.data.DataLoader( mx.gluon.data.ArrayDataset(test_sequences, mx.nd.array(y_test_final)), batch_size=2048, shuffle=False, batchify_fn=train_batchify_fn) num_batches = len(train_data) / batch_size display_batch_cadence = int(math.ceil(num_batches / 10)) logging.info('Training on %s samples and testing on %s samples', len(train_data), len(test_data)) logging.info('Number of batches for each epoch : %s, 
Display cadence: %s', num_batches, display_batch_cadence) for e in range(num_epochs): for batch, (data, label) in enumerate(train_data_iter): #num_batches += 1 data = data.as_in_context(ctx) label = label.as_in_context(ctx) with autograd.record(): output = net(data) loss = loss_function(output, label) loss.backward() trainer1.step(data.shape[0]) trainer2.step(data.shape[0]) if (batch % display_batch_cadence == 0): logging.info('Epoch : %s, Batches complete :%s', e, batch) logging.info('Epoch complete :%s, Computing Accuracy', e) test_accuracy, test_loss = evaluate_accuracy( test_data_iter, net, ctx, loss_function, num_classes) logging.info('Epochs completed : %s Test Accuracy: %s, Test Loss: %s', e, test_accuracy, test_loss) learning_rate = learning_rate * 0.5 trainer1.set_learning_rate(learning_rate) trainer2.set_learning_rate(learning_rate) save_model(net, args.output)
def score(self, corpus: Corpus, temp: float = 1.0, split_size: int = 2000, ratio: float = 0, num_workers: int = 10, per_token: bool = False) -> List[float]: ctx_cpu = mx.Context('cpu') # Turn corpus into a BERT-ready Dataset dataset = self.corpus_to_dataset(corpus) # Turn Dataset into Dataloader batchify_fn = btf.Tuple( btf.Stack(dtype='int32'), btf.Pad( pad_val=self._vocab.token_to_idx[self._vocab.padding_token], dtype='int32'), btf.Stack(dtype='float32'), btf.Stack(dtype='float32'), btf.Stack(dtype='int32'), btf.Stack(dtype='float32')) # TODO: There is a 'by-design' bug in FixedBucketSampler with num_shards > 0, where it silently reuses the last utterances: # https://github.com/dmlc/gluon-nlp/blame/b1b61d3f90cf795c7b48b6d109db7b7b96fa21ff/src/gluonnlp/data/sampler.py#L398 # batch_sampler = nlp.data.sampler.FixedBucketSampler([sent_tuple[2] for sent_tuple in dataset], batch_size=split_size, ratio=ratio, num_shards=len(self._ctxs), shuffle=False) # Hence, we use num_shards = 0 and do gluon's split_data batch_sampler = nlp.data.sampler.FixedBucketSampler( [sent_tuple[2] for sent_tuple in dataset], batch_size=split_size, ratio=ratio, num_shards=0, shuffle=False) logging.info(batch_sampler.stats()) dataloader = nlp.data.ShardedDataLoader(dataset, pin_memory=True, batch_sampler=batch_sampler, batchify_fn=batchify_fn, num_workers=num_workers, thread_pool=True) # Get lengths in tokens (assumes dataset is in order) prev_sent_idx = None true_tok_lens = [] for (curr_sent_idx, _, valid_length, _, _, _) in dataset: if curr_sent_idx != prev_sent_idx: prev_sent_idx = curr_sent_idx true_tok_lens.append(valid_length - 2) # Compute scores (total or per-position) if per_token: scores_per_token = [[None] * (true_tok_len + 2) for true_tok_len in true_tok_lens] else: scores = np.zeros((len(corpus), )) sent_count = 0 batch_log_interval = 20 batch_score_accumulation = 1 batch_sent_idxs_per_ctx = [[] for ctx in self._ctxs] batch_scores_per_ctx = [[] for ctx in self._ctxs] batch_masked_positions_per_ctx = [[] for ctx in self._ctxs] def sum_accumulated_scores(): for ctx_idx in range(len(self._ctxs)): for batch_sent_idxs, batch_scores, batch_masked_positions in zip( batch_sent_idxs_per_ctx[ctx_idx], batch_scores_per_ctx[ctx_idx], batch_masked_positions_per_ctx[ctx_idx]): if per_token: # Slow; only use when necessary for batch_sent_idx, batch_score, batch_masked_position in zip( batch_sent_idxs, batch_scores, batch_masked_positions): scores_per_token[batch_sent_idx.asscalar()][int( batch_masked_position.asscalar( ))] = batch_score.asscalar().item() else: np.add.at(scores, batch_sent_idxs.asnumpy(), batch_scores.asnumpy()) batch_sent_idxs_per_ctx[ctx_idx] = [] batch_scores_per_ctx[ctx_idx] = [] batch_masked_positions_per_ctx[ctx_idx] = [] # For now just predicts the first non-cls token for batch_id, batch in enumerate(dataloader): batch_size = 0 # TODO: Write tests about batching over multiple GPUs and getting the same scores # TODO: SEE COMMENT ABOVE REGARDING FIXEDBUCKETSAMPLER batch = zip(*[ mx.gluon.utils.split_data(batch_compo, len(self._ctxs), batch_axis=0, even_split=False) for batch_compo in batch ]) for ctx_idx, (sent_idxs, token_ids, valid_length, masked_positions, token_masked_ids, normalization) in enumerate(batch): ctx = self._ctxs[ctx_idx] batch_size += sent_idxs.shape[0] token_ids = token_ids.as_in_context(ctx) valid_length = valid_length.as_in_context(ctx) masked_positions = masked_positions.as_in_context(ctx).reshape( -1, 1) if isinstance(self._model, RoBERTaModel): out = self._model(token_ids, 
valid_length, masked_positions) else: segment_ids = mx.nd.zeros(shape=token_ids.shape, ctx=ctx) out = self._model(token_ids, segment_ids, valid_length, masked_positions) # Get the probability computed for the correct token split_size = token_ids.shape[0] # out[0] contains the representations # out[1] is what contains the distribution for the masked # TODO: Manual numerically-stable softmax # https://stackoverflow.com/questions/42599498/numercially-stable-softmax # Because we only need one scalar out = out[1].log_softmax(temperature=temp) # Save the scores at the masked indices batch_sent_idxs_per_ctx[ctx_idx].append(sent_idxs) out = out[list(range(split_size)), [0] * split_size, token_masked_ids.as_in_context(ctx).reshape(-1)] batch_scores_per_ctx[ctx_idx].append(out) batch_masked_positions_per_ctx[ctx_idx].append( masked_positions) # Ideally we'd accumulate the scores when possible, but something like the below won't work # > scores[sent_idxs] += out # See In[21] in https://jakevdp.github.io/PythonDataScienceHandbook/02.07-fancy-indexing.html. # Hence, aggregation is done synchronously, every so often # (though batch_score_accumulation = 1 seems best, since bucketing is effective in reducing GPU disparity) if len(batch_sent_idxs_per_ctx[0]) == batch_score_accumulation: sum_accumulated_scores() # Progress sent_count += batch_size if (batch_id + 1) % batch_log_interval == 0: logging.info("{} sents of {}, batch {} of {}".format( sent_count, len(dataset), batch_id + 1, len(batch_sampler))) # TODO: Test score accumulation # In case there are leftovers sum_accumulated_scores() if per_token: return scores_per_token, true_tok_lens else: return scores.tolist(), true_tok_lens
def get_pretrain_data_text(data, batch_size, shuffle, num_buckets, tokenizer, vocab, max_seq_length, short_seq_prob=0.05, num_parts=1, part_idx=0, num_dataset_workers=1, num_batch_workers=1, circle_length=1, repeat=1, cached_file_path=None): """Get a data iterator from raw text documents. Parameters ---------- batch_size : int The batch size per GPU. shuffle : bool Whether to shuffle the data. num_buckets : int The number of buckets for the FixedBucketSampler for training. vocab : Vocab The vocabulary. tokenizer : HuggingFaceWordPieceTokenizer or SentencepieceTokenizer The tokenizer. max_seq_length : int The hard limit of maximum sequence length of sentence pairs. short_seq_prob : float The probability of sampling sequences shorter than the max_seq_length. num_parts : int The number of partitions for the dataset. part_idx : int The index of the partition to read. num_dataset_workers : int The number of worker processes for dataset construction. num_batch_workers : int The number of worker processes for batch construction. circle_length : int, default is 1 The number of files to be read for a single worker at the same time. When circle_length is larger than 1, we merge circle_length files. repeat : int, default is 1 The number of times that files are repeated. cached_file_path: str, default is None Directory for saving preprocessed features """ num_files = len(glob(data)) logging.info('%d files are found.', num_files) assert num_files >= num_parts, \ 'The number of text files must be no less than the number of ' \ 'workers/partitions (%d). Only %d files at %s are found.' % (num_parts, num_files, data) split_sampler = SplitSampler(num_files, num_parts=num_parts, part_index=part_idx, repeat=repeat) dataset_fn = prepare_pretrain_text_dataset sampler_fn = prepare_pretrain_bucket_sampler dataset_params = { 'tokenizer': tokenizer, 'max_seq_length': max_seq_length, 'short_seq_prob': short_seq_prob, 'cached_file_path': cached_file_path } sampler_params = { 'batch_size': batch_size, 'shuffle': shuffle, 'num_buckets': num_buckets } batchify_fn = bf.Tuple( bf.Pad(val=vocab.pad_id), # input_ids bf.Pad(val=0), # segment_ids bf.Stack(), # valid_lengths ) dataloader = DatasetLoader(data, file_sampler=split_sampler, dataset_fn=dataset_fn, batch_sampler_fn=sampler_fn, dataset_params=dataset_params, batch_sampler_params=sampler_params, batchify_fn=batchify_fn, num_dataset_workers=num_dataset_workers, num_batch_workers=num_batch_workers, pin_memory=False, circle_length=circle_length) return dataloader
def get_pretrain_data_npz(data, batch_size, shuffle, num_buckets, vocab, num_parts=1, part_idx=0, num_dataset_workers=1, num_batch_workers=1, circle_length=1, repeat=1, dataset_cached=False, num_max_dataset_cached=0): """Get a data iterator from pre-processed npz files. Parameters ---------- data: str The path to the dataset directory batch_size : int The batch size per GPU. shuffle : bool Whether to shuffle the data. num_buckets : int The number of buckets for the FixedBucketSampler for training. vocab : Vocab The vocabulary. num_parts : int The number of partitions for the dataset. part_idx : int The index of the partition to read. num_dataset_workers : int The number of worker processes for dataset construction. num_batch_workers : int The number of worker processes for batch contruction. circle_length : int, default is 1 The number of files to be read for a single worker at the same time. When circle_length is larger than 1, we merge circle_length files. repeat : int, default is 1 The number of times that files are repeated. dataset_cached : bool, default is False Whether or not to cache last processed dataset. Each processed dataset can only be cached for once. When there is no new available processed dataset to be fetched, we pop a cached processed dataset. num_max_dataset_cached : int, default is 0 Maximum number of cached datasets. It is valid only if dataset_cached is True """ num_files = len(glob(data)) logging.info('%d files are found.', num_files) assert num_files >= num_parts, \ 'The number of text files must be no less than the number of ' \ 'workers/partitions (%d). Only %d files at %s are found.' % (num_parts, num_files, data) split_sampler = SplitSampler(num_files, num_parts=num_parts, part_index=part_idx, repeat=repeat) dataset_fn = prepare_pretrain_npz_dataset sampler_fn = prepare_pretrain_bucket_sampler dataset_params = {'allow_pickle': True} sampler_params = { 'batch_size': batch_size, 'shuffle': shuffle, 'num_buckets': num_buckets } batchify_fn = bf.Tuple( bf.Pad(val=vocab.pad_id), # input_ids bf.Pad(val=0), # segment_ids bf.Stack(), # valid_lengths ) dataloader = DatasetLoader(data, file_sampler=split_sampler, dataset_fn=dataset_fn, batch_sampler_fn=sampler_fn, dataset_params=dataset_params, batch_sampler_params=sampler_params, batchify_fn=batchify_fn, num_dataset_workers=num_dataset_workers, num_batch_workers=num_batch_workers, pin_memory=False, circle_length=circle_length) return dataloader
def train(args): _, num_parts, rank, local_rank, _, ctx_l = init_comm( args.comm_backend, args.gpus) if args.comm_backend == 'horovod': logging_config( args.save_dir, name=f'train_transformer_rank{rank}_local{local_rank}_{num_parts}', console=(rank == 0)) logging.info(args) else: logging_config(args.save_dir, name='train_transformer', console=True) logging.info(args) use_amp = args.fp16 if use_amp: from mxnet import amp src_tokenizer = create_tokenizer(args.src_tokenizer, args.src_subword_model_path, args.src_vocab_path) tgt_tokenizer = create_tokenizer(args.tgt_tokenizer, args.tgt_subword_model_path, args.tgt_vocab_path) base_tgt_tokenizer = MosesTokenizer(args.tgt_lang) src_vocab = src_tokenizer.vocab tgt_vocab = tgt_tokenizer.vocab train_src_data, train_tgt_data = load_dataset_with_cache( args.train_src_corpus, args.train_tgt_corpus, src_tokenizer, tgt_tokenizer, args.overwrite_cache, local_rank, max_src_length=args.max_src_length, max_tgt_length=args.max_tgt_length, pretokenized=not args.tokenize) dev_src_data, dev_tgt_data = load_dataset_with_cache( args.dev_src_corpus, args.dev_tgt_corpus, src_tokenizer, tgt_tokenizer, args.overwrite_cache, local_rank, pretokenized=not args.tokenize) tgt_detok_sentences = [] tgt_raw_sentences = [] with open(args.dev_tgt_corpus, 'r') as in_f: for line in in_f: tgt_detok_sentences.append( base_tgt_tokenizer.decode( tgt_tokenizer.decode(line.split()).split())) with open(args.dev_tgt_raw_corpus, 'r') as in_f: for line in in_f: tgt_raw_sentences.append(line.strip()) data_train = gluon.data.SimpleDataset([ (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i) for i, (src_tokens, tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data)) ]) val_samples = [ (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i) for i, (src_tokens, tgt_tokens) in enumerate(zip(dev_src_data, dev_tgt_data)) ] if args.comm_backend == 'horovod': slice_begin = rank * (len(val_samples) // num_parts) slice_end = min((rank + 1) * (len(val_samples) // num_parts), len(val_samples)) data_val = gluon.data.SimpleDataset(val_samples[slice_begin:slice_end]) else: data_val = gluon.data.SimpleDataset(val_samples) # Construct the model + loss function if args.cfg.endswith('.yml'): cfg = TransformerModel.get_cfg().clone_merge(args.cfg) else: cfg = TransformerModel.get_cfg(args.cfg) cfg.defrost() cfg.MODEL.src_vocab_size = len(src_vocab) cfg.MODEL.tgt_vocab_size = len(tgt_vocab) cfg.MODEL.layout = 'TN' cfg.freeze() model = TransformerModel.from_cfg(cfg) model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l) model.hybridize() for v in model.collect_params().values(): if v.grad_req != 'null': v.grad_req = 'add' # Do not apply weight decay to all the LayerNorm and bias for _, v in model.collect_params('.*beta|.*gamma|.*bias').items(): v.wd_mult = 0.0 param_dict = deduplicate_param_dict(model.collect_params()) inference_model = TransformerInference(model=model) inference_model.hybridize() if local_rank == 0: logging.info(model) with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f: cfg_f.write(cfg.dump()) label_smooth_loss = LabelSmoothCrossEntropyLoss( num_labels=len(tgt_vocab), alpha=args.label_smooth_alpha, from_logits=False) label_smooth_loss.hybridize() # Construct the beam search sampler scorer = BeamSearchScorer(alpha=args.lp_alpha, K=args.lp_k, from_logits=False) beam_search_sampler = BeamSearchSampler(beam_size=args.beam_size, decoder=inference_model, vocab_size=len(tgt_vocab), eos_id=tgt_vocab.eos_id, scorer=scorer, stochastic=False, 
max_length_a=args.max_length_a, max_length_b=args.max_length_b) logging.info(beam_search_sampler) if args.comm_backend == 'horovod': hvd.broadcast_parameters(param_dict, root_rank=0) # Construct the trainer if args.lr is None: base_lr = 2.0 / math.sqrt(args.num_units) / math.sqrt( args.warmup_steps) else: base_lr = args.lr lr_scheduler = InverseSquareRootScheduler( warmup_steps=args.warmup_steps, base_lr=base_lr, warmup_init_lr=args.warmup_init_lr) optimizer_params = { 'learning_rate': args.lr, 'beta1': 0.9, 'beta2': 0.997, 'epsilon': 1e-9, 'lr_scheduler': lr_scheduler, 'wd': args.wd } user_provided_ptimizer_params = json.loads(args.optimizer_params) optimizer_params.update(user_provided_ptimizer_params) if args.fp16: optimizer_params.update({'multi_precision': True}) if args.comm_backend == 'horovod': trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optimizer_params) else: trainer = gluon.Trainer(param_dict, args.optimizer, optimizer_params, update_on_kvstore=False) # Load Data if args.sampler == 'BoundedBudgetSampler': train_batch_sampler = BoundedBudgetSampler( lengths=[(ele[2], ele[3]) for ele in data_train], max_num_tokens=args.max_num_tokens, max_num_sentences=args.max_num_sentences, shuffle=True, seed=args.seed) elif args.sampler == 'FixedBucketSampler': if args.comm_backend == 'horovod': raise NotImplementedError( 'FixedBucketSampler does not support horovod at present') if args.bucket_scheme == 'constant': bucket_scheme = ConstWidthBucket() elif args.bucket_scheme == 'linear': bucket_scheme = LinearWidthBucket() elif args.bucket_scheme == 'exp': bucket_scheme = ExpWidthBucket(bucket_len_step=1.2) else: raise NotImplementedError # TODO(sxjscience) Support auto-bucket-size tuning train_batch_sampler = FixedBucketSampler(lengths=[ (ele[2], ele[3]) for ele in data_train ], batch_size=args.batch_size, num_buckets=args.num_buckets, ratio=args.bucket_ratio, shuffle=True, use_average_length=True, bucket_scheme=bucket_scheme, seed=args.seed) else: raise NotImplementedError num_updates_per_epoch = int( math.ceil( len(train_batch_sampler) / (num_parts * len(ctx_l) * args.num_accumulated))) # Convert the batch sampler to multiple shards if num_parts > 1: train_batch_sampler = ShardedIterator(train_batch_sampler, num_parts=num_parts, part_index=rank, even_size=True, seed=args.seed + 1000 * rank) logging.info(train_batch_sampler) batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(), bf.Stack()) train_data_loader = gluon.data.DataLoader( data_train, batch_sampler=train_batch_sampler, batchify_fn=batchify_fn, num_workers=0) val_data_loader = gluon.data.DataLoader(data_val, batch_size=args.val_batch_size, batchify_fn=batchify_fn, num_workers=0, shuffle=False) params = [p for p in param_dict.values() if p.grad_req != 'null'] model_averager = AverageSGDTracker(param_dict) log_start_time = time.time() num_params, num_fixed_params = None, None # TODO(sxjscience) Add a log metric class log_avg_loss_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l] # Maintain the denominator of the loss. 
log_avg_loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l] log_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l] log_tgt_wc_l = [mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l] log_avg_grad_norm = 0 log_iter_num = 0 if local_rank == 0: writer = SummaryWriter( logdir=os.path.join(args.save_dir, 'tensorboard')) if use_amp: amp.init_trainer(trainer) train_multi_data_loader = grouper(repeat(train_data_loader), len(ctx_l)) # when args.epochs < 0, the model will keep training if args.epochs < 0: if args.max_update > 0: total_train_iters = args.max_update if args.num_averages > 0: assert args.num_averages <= total_train_iters // args.save_iterval_update avg_start_iter = ( total_train_iters // args.save_iterval_update - args.num_averages) * args.save_iterval_update else: avg_start_iter = -1 else: total_train_iters = np.inf avg_start_iter = -1 else: total_train_iters = args.epochs * num_updates_per_epoch if args.num_averages > 0: assert args.num_averages <= args.epochs avg_start_iter = (args.epochs - args.num_average) * num_updates_per_epoch else: avg_start_iter = -1 # Here, we are manually setting up the scale to 1.0 because # in horovod, the scale can be the number of workers: # See the code here: https://github.com/horovod/horovod/blob/125115583b7029196e2ec530decd4209459d5479/horovod/mxnet/__init__.py#L141 # Since we will need to use the dynamic scaling in amp, we will manually call amp.unscale(). # A scale that is larger than 1.0 can be problematic in this case. trainer._scale = 1.0 if args.max_num_tokens > 0: const_scale = args.max_num_tokens else: const_scale = 100 train_start_time = time.time() for train_iter in range(total_train_iters): model.zero_grad() loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l] for i in range(args.num_accumulated): loss_l = [] sample_data_l = next(train_multi_data_loader) for j, (sample_data, ctx) in enumerate(zip(sample_data_l, ctx_l)): src_token_ids, tgt_token_ids, src_valid_length,\ tgt_valid_length, sample_ids = sample_data src_token_ids = src_token_ids.as_in_ctx(ctx) tgt_token_ids = tgt_token_ids.as_in_ctx(ctx) src_valid_length = src_valid_length.as_in_ctx(ctx) tgt_valid_length = tgt_valid_length.as_in_ctx(ctx) src_wc, tgt_wc, bs = src_valid_length.sum(), \ tgt_valid_length.sum(), src_token_ids.shape[0] log_wc_l[j] += src_wc + tgt_wc log_tgt_wc_l[j] += tgt_wc token_count = (tgt_valid_length - 1).sum() loss_denom_l[j] += token_count / const_scale log_avg_loss_denom_l[j] += token_count / const_scale with mx.autograd.record(): if model.layout == 'NT': tgt_pred = model(src_token_ids, src_valid_length, tgt_token_ids[:, :-1], tgt_valid_length - 1) tgt_labels = tgt_token_ids[:, 1:] loss = label_smooth_loss(tgt_pred, tgt_labels) loss = mx.npx.sequence_mask( loss, sequence_length=tgt_valid_length - 1, use_sequence_length=True, axis=1) loss = loss.sum() / const_scale loss_l.append(loss) elif model.layout == 'TN': tgt_pred = model(src_token_ids.T, src_valid_length, tgt_token_ids.T[:-1, :], tgt_valid_length - 1) tgt_labels = tgt_token_ids.T[1:, :] loss = label_smooth_loss(tgt_pred, tgt_labels) loss = mx.npx.sequence_mask( loss, sequence_length=tgt_valid_length - 1, use_sequence_length=True, axis=0) loss = loss.sum() / const_scale loss_l.append(loss) log_avg_loss_l[j] += loss if use_amp: with mx.autograd.record(): with amp.scale_loss(loss_l, trainer) as amp_loss_l: for loss in amp_loss_l: loss.backward() else: with mx.autograd.record(): for loss in loss_l: loss.backward() # Print the total number of parameters if local_rank == 0 
and num_params is None: num_params, num_fixed_params = count_parameters(param_dict) logging.info( 'Total Number of Parameters (not-fixed/fixed): {}/{}'.format( num_params, num_fixed_params)) # All-Reduce the gradient trainer.allreduce_grads() if args.comm_backend == 'horovod': # All-Reduce the loss denominator assert len(loss_denom_l) == 1 loss_denom = hvd.allreduce(loss_denom_l[0], average=False).asnumpy() else: loss_denom = sum([ele.asnumpy() for ele in loss_denom_l]) if use_amp: # We need to first unscale the gradient and then perform allreduce. grad_scale = trainer.amp_loss_scale * loss_denom else: grad_scale = loss_denom if args.max_grad_norm is not None: total_norm, ratio, is_finite\ = clip_grad_global_norm(params, args.max_grad_norm * grad_scale) total_norm = total_norm / grad_scale else: total_norm = grad_global_norm(params) total_norm = total_norm / grad_scale log_avg_grad_norm += total_norm log_iter_num += 1 trainer.update(loss_denom, ignore_stale_grad=True) if avg_start_iter > 0 and train_iter >= avg_start_iter: model_averager.step() if ((train_iter + 1) % args.log_interval == 0 or train_iter + 1 == total_train_iters): if args.comm_backend == 'horovod': # Use allreduce to get the total number of tokens and loss log_wc = hvd.allreduce(log_wc_l[0], average=False).asnumpy() log_tgt_wc = hvd.allreduce(log_tgt_wc_l[0], average=False).asnumpy() log_avg_loss = hvd.allreduce(log_avg_loss_l[0] / log_avg_loss_denom_l[0], average=True) log_avg_loss = log_avg_loss.asnumpy() else: log_wc = sum([ele.asnumpy() for ele in log_wc_l]) log_tgt_wc = sum([ele.asnumpy() for ele in log_tgt_wc_l]) log_avg_loss =\ sum([log_avg_loss_l[i].asnumpy() / log_avg_loss_denom_l[i].asnumpy() for i in range(len(log_avg_loss_l))]) / len(log_avg_loss_l) log_avg_grad_norm = log_avg_grad_norm / log_iter_num log_end_time = time.time() wps = log_wc / (log_end_time - log_start_time) epoch_id = train_iter // num_updates_per_epoch logging.info( '[Epoch {} Iter {}/{}, Overall {}/{}] loss={:.4f}, ppl={:.4f}, ' 'throughput={:.2f}K wps, total wc={:.2f}K, wpb={:.2f}K,' ' LR={}, gnorm={:.4f}, ETA={:.2f}h'.format( epoch_id, train_iter % num_updates_per_epoch + 1, num_updates_per_epoch, train_iter + 1, total_train_iters, log_avg_loss, np.exp(log_avg_loss), wps / 1000, log_wc / 1000, log_tgt_wc / 1000 / log_iter_num, trainer.learning_rate, log_avg_grad_norm, (log_end_time - train_start_time) / (train_iter + 1) * (total_train_iters - train_iter - 1) / 3600)) if local_rank == 0: writer.add_scalar('throughput_wps', wps, train_iter) writer.add_scalar('train_loss', log_avg_loss, train_iter) writer.add_scalar('lr', trainer.learning_rate, train_iter) writer.add_scalar('grad_norm', log_avg_grad_norm, train_iter) # Reinitialize the log variables log_start_time = time.time() log_avg_loss_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l] log_avg_loss_denom_l = [mx.np.array(0.0, ctx=ctx) for ctx in ctx_l] log_avg_grad_norm = 0 log_iter_num = 0 log_wc_l = [ mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l ] log_tgt_wc_l = [ mx.np.array(0, dtype=np.int64, ctx=ctx) for ctx in ctx_l ] if (args.max_update > 0 and (train_iter + 1) % args.save_interval_update == 0) \ or ((train_iter + 1) % num_updates_per_epoch == 0) \ or train_iter + 1 == total_train_iters: epoch_id = (train_iter + 1) // num_updates_per_epoch if local_rank == 0: if args.max_update <= 0: model.save_parameters(os.path.join( args.save_dir, 'epoch{}.params'.format(epoch_id)), deduplicate=True) else: model.save_parameters(os.path.join( args.save_dir, 
'iter{}.params'.format(train_iter + 1)), deduplicate=True) avg_val_loss, ntokens, pred_sentences, pred_lengths, sentence_ids\ = validation(model, val_data_loader, inference_model, beam_search_sampler, tgt_tokenizer, ctx_l) if args.comm_backend == 'horovod': flatten_pred_sentences = np.concatenate(pred_sentences, axis=0) all_val_loss = hvd.allgather( mx.np.array([avg_val_loss * ntokens], dtype=np.float32, ctx=ctx_l[0])) all_ntokens = hvd.allgather( mx.np.array([ntokens], dtype=np.int64, ctx=ctx_l[0])) flatten_pred_sentences = hvd.allgather( mx.np.array(flatten_pred_sentences, dtype=np.int32, ctx=ctx_l[0])) pred_lengths = hvd.allgather( mx.np.array(pred_lengths, dtype=np.int64, ctx=ctx_l[0])) sentence_ids = hvd.allgather( mx.np.array(sentence_ids, dtype=np.int64, ctx=ctx_l[0])) avg_val_loss = all_val_loss.asnumpy().sum( ) / all_ntokens.asnumpy().sum() flatten_pred_sentences = flatten_pred_sentences.asnumpy() pred_lengths = pred_lengths.asnumpy() sentence_ids = sentence_ids.asnumpy() pred_sentences = [None for _ in range(len(sentence_ids))] ptr = 0 assert sentence_ids.min() == 0 and sentence_ids.max( ) == len(sentence_ids) - 1 for sentence_id, length in zip(sentence_ids, pred_lengths): pred_sentences[sentence_id] = flatten_pred_sentences[ptr:( ptr + length)] ptr += length if local_rank == 0: # Perform detokenization pred_sentences_bpe_decode = [] pred_sentences_raw = [] for sentence in pred_sentences: bpe_decode_sentence = tgt_tokenizer.decode( sentence.tolist()) raw_sentence = base_tgt_tokenizer.decode( bpe_decode_sentence.split()) pred_sentences_bpe_decode.append(bpe_decode_sentence) pred_sentences_raw.append(raw_sentence) detok_sacrebleu_out = sacrebleu.corpus_bleu( sys_stream=pred_sentences_bpe_decode, ref_streams=[tgt_detok_sentences]) raw_sacrebleu_out = sacrebleu.corpus_bleu( sys_stream=pred_sentences_raw, ref_streams=[tgt_raw_sentences]) with open( os.path.join(args.save_dir, f'epoch{epoch_id}_dev_prediction.txt'), 'w') as of: for line in pred_sentences_raw: of.write(line + '\n') logging.info( '[Epoch {}][Iter {}/{}] validation loss/ppl={:.4f}/{:.4f}, ' 'SacreBlEU={}, Detok SacreBLUE={}'.format( epoch_id, train_iter, total_train_iters, avg_val_loss, np.exp(avg_val_loss), raw_sacrebleu_out.score, detok_sacrebleu_out.score)) writer.add_scalar('valid_loss', avg_val_loss, train_iter) writer.add_scalar('valid_bleu', raw_sacrebleu_out.score, train_iter) if args.num_averages > 0: model_averager.copy_back( param_dict) # TODO(sxjscience) Rewrite using update model.save_parameters(os.path.join(args.save_dir, 'average.params'), deduplicate=True)
def get_pretrain_data_text(data, batch_size, shuffle, num_buckets, vocab, tokenizer, max_seq_length, short_seq_prob, masked_lm_prob, max_predictions_per_seq, whole_word_mask, random_next_sentence, num_parts=1, part_idx=0, num_dataset_workers=1, num_batch_workers=1, circle_length=1, repeat=1, dataset_cached=False, num_max_dataset_cached=0): """Get a data iterator from raw text documents. Parameters ---------- batch_size : int The batch size per GPU. shuffle : bool Whether to shuffle the data. num_buckets : int The number of buckets for the FixedBucketSampler for training. vocab : Vocab The vocabulary. tokenizer : BaseTokenizer The tokenizer. max_seq_length : int The hard limit of maximum sequence length of sentence pairs. short_seq_prob : float The probability of sampling sequences shorter than the max_seq_length. masked_lm_prob : float The probability of replacing texts with masks/random words/original words. max_predictions_per_seq : int The hard limit of the number of predictions for masked words whole_word_mask : bool Whether to use whole word masking. num_parts : int The number of partitions for the dataset. part_idx : int The index of the partition to read. num_dataset_workers : int The number of worker processes for dataset construction. num_batch_workers : int The number of worker processes for batch construction. circle_length : int, default is 1 The number of files to be read for a single worker at the same time. When circle_length is larger than 1, we merge circle_length files. repeat : int, default is 1 The number of times that files are repeated. dataset_cached : bool, default is False Whether or not to cache last processed dataset. Each processed dataset can only be cached for once. When there is no new available processed dataset to be fetched, we pop a cached processed dataset. num_max_dataset_cached : int, default is 0 Maximum number of cached datasets. It is valid only if dataset_cached is True """ num_files = len(glob(data)) logging.info('%d files are found.', num_files) assert num_files >= num_parts, \ 'The number of text files must be no less than the number of ' \ 'workers/partitions (%d). Only %d files at %s are found.'%(num_parts, num_files, data) dataset_params = {'tokenizer': tokenizer, 'max_seq_length': max_seq_length, 'short_seq_prob': short_seq_prob, 'masked_lm_prob': masked_lm_prob, 'max_predictions_per_seq': max_predictions_per_seq, 'vocab':vocab, 'whole_word_mask': whole_word_mask, 'random_next_sentence': random_next_sentence} sampler_params = {'batch_size': batch_size, 'shuffle': shuffle, 'num_buckets': num_buckets} dataset_fn = prepare_pretrain_text_dataset sampler_fn = prepare_pretrain_bucket_sampler pad_val = vocab.pad_id batchify_fn = bf.Tuple( bf.Pad(val=pad_val, round_to=8), # input_id bf.Pad(val=pad_val), # masked_id bf.Pad(val=0), # masked_position bf.Pad(val=0), # masked_weight bf.Stack(), # next_sentence_label bf.Pad(val=0, round_to=8), # segment_id bf.Stack()) # valid_lengths split_sampler = SplitSampler(num_files, num_parts=num_parts, part_index=part_idx, repeat=repeat) dataloader = DatasetLoader(data, file_sampler=split_sampler, dataset_fn=dataset_fn, batch_sampler_fn=sampler_fn, dataset_params=dataset_params, batch_sampler_params=sampler_params, batchify_fn=batchify_fn, num_dataset_workers=num_dataset_workers, num_batch_workers=num_batch_workers, pin_memory=False, circle_length=circle_length, dataset_cached=dataset_cached, num_max_dataset_cached=num_max_dataset_cached) return dataloader
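# A small standalone sketch (assuming the numpy-based gluonnlp.data.batchify API
# used above, where Pad takes val/round_to) of what round_to=8 does: the batch is
# padded up to the next multiple of 8 along the padded axis, keeping sequence
# lengths aligned. The numbers below are made up for illustration.
import mxnet as mx
from gluonnlp.data import batchify as bf

pad = bf.Pad(val=0, round_to=8)
batch = pad([mx.np.arange(5), mx.np.arange(11)])
print(batch.shape)  # (2, 16): the longest sequence (length 11) is rounded up to 16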
def test_pad_wrap_batchify(): def _verify_padded_arr(padded_arr, original_arr, pad_axis, pad_val, pad_length, dtype): ndim = original_arr.ndim slices_data = [slice(None) for _ in range(ndim)] slices_data[pad_axis] = slice(original_arr.shape[axis]) assert_allclose(padded_arr[tuple(slices_data)], original_arr) if original_arr.shape[pad_axis] < pad_length: slices_pad_val = [slice(None) for _ in range(ndim)] slices_pad_val[axis] = slice(original_arr.shape[pad_axis], None) pad_val_in_arr = padded_arr[tuple(slices_pad_val)] assert_allclose(pad_val_in_arr, (np.ones_like(pad_val_in_arr) * pad_val).astype(dtype)) batch_size = 6 for ndim in range(1, 3): for axis in range(-ndim, ndim): for length_min, length_max in [(3, 4), (3, 7)]: for pad_val in [-1, 0]: for dtype in [ np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64 ]: # Each instance contains a single array for _dtype in [None, dtype]: shapes = [[2 for _ in range(ndim)] for _ in range(batch_size)] for i in range(len(shapes)): shapes[i][axis] = np.random.randint( length_min, length_max) random_data_npy = [ np.random.normal(0, 1, shape).astype(dtype) for shape in shapes ] batchify_fn = batchify.Pad(axis=axis, pad_val=pad_val, ret_length=True, dtype=_dtype) batch_data, valid_length = batchify_fn( random_data_npy) batch_data_use_mx, valid_length_use_mx = batchify_fn( [ mx.nd.array(ele, dtype=dtype) for ele in random_data_npy ]) assert_allclose(batch_data_use_mx.asnumpy(), batch_data.asnumpy()) assert_allclose(valid_length_use_mx.asnumpy(), valid_length.asnumpy()) assert batch_data.dtype == batch_data_use_mx.dtype == dtype assert valid_length.dtype == valid_length_use_mx.dtype == np.int32 valid_length = valid_length.asnumpy() batch_data = batch_data.asnumpy() for i in range(batch_size): assert (valid_length[i] == shapes[i][axis]) pad_length = max(shape[axis] for shape in shapes) _verify_padded_arr(batch_data[i], random_data_npy[i], axis, pad_val, pad_length, dtype) # Each instance contains 3 arrays, we pad part of them according to index TOTAL_ELE_NUM = 3 for pad_index in [[0], [1], [2], [0, 1], [1, 2], [0, 1, 2]]: shapes = [[[2 for _ in range(ndim)] for _ in range(batch_size)] for _ in range(TOTAL_ELE_NUM)] for j in pad_index: for i in range(batch_size): shapes[j][i][axis] = np.random.randint( length_min, length_max) random_data_npy = [ tuple( np.random.normal(0, 1, shapes[j] [i]).astype(dtype) for j in range(TOTAL_ELE_NUM)) for i in range(batch_size) ] batchify_fn = [] for j in range(TOTAL_ELE_NUM): if j in pad_index: batchify_fn.append( batchify.Pad(axis=axis, pad_val=pad_val, ret_length=True, dtype=_dtype)) else: batchify_fn.append( batchify.Stack(dtype=_dtype)) batchify_fn = batchify.Tuple(batchify_fn) ret_use_npy = batchify_fn(random_data_npy) ret_use_mx = batchify_fn([ tuple( mx.nd.array(ele[i], dtype=dtype) for i in range(TOTAL_ELE_NUM)) for ele in random_data_npy ]) for i in range(TOTAL_ELE_NUM): if i in pad_index: assert ret_use_npy[i][ 0].dtype == ret_use_mx[i][ 0].dtype == dtype assert ret_use_npy[i][ 1].dtype == ret_use_mx[i][ 1].dtype == np.int32 assert_allclose( ret_use_npy[i][0].asnumpy(), ret_use_mx[i][0].asnumpy()) assert_allclose( ret_use_npy[i][1].asnumpy(), ret_use_mx[i][1].asnumpy()) assert (ret_use_npy[i][1].shape == ( batch_size, )) else: assert ret_use_npy[ i].dtype == ret_use_mx[ i].dtype == dtype assert_allclose( ret_use_npy[i].asnumpy(), ret_use_mx[i].asnumpy()) for i in range(batch_size): for j in range(TOTAL_ELE_NUM): if j in pad_index: batch_data, valid_length = ret_use_npy[j][0].asnumpy(), \ 
ret_use_npy[j][1].asnumpy() assert (valid_length[i] == shapes[j][i][axis]) else: batch_data = ret_use_npy[ j].asnumpy() pad_length = max( ele[j].shape[axis] for ele in random_data_npy) _verify_padded_arr( batch_data[i], random_data_npy[i][j], axis, pad_val, pad_length, dtype) for _dtype in [np.float16, np.float32]: shapes = [[2 for _ in range(ndim)] for _ in range(batch_size)] for i in range(len(shapes)): shapes[i][axis] = np.random.randint( length_min, length_max) random_data_npy = [ np.random.normal(0, 1, shape).astype(dtype) for shape in shapes ] batchify_fn = batchify.Pad(axis=axis, pad_val=pad_val, ret_length=True, dtype=_dtype) batch_data, valid_length = batchify_fn( random_data_npy) batch_data_use_mx, valid_length_use_mx = batchify_fn( [ mx.nd.array(ele, dtype=dtype) for ele in random_data_npy ]) assert_allclose(valid_length_use_mx.asnumpy(), valid_length.asnumpy()) assert batch_data.dtype == batch_data_use_mx.dtype == _dtype assert valid_length.dtype == valid_length_use_mx.dtype == np.int32
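At its core, the test above verifies two properties of `batchify.Pad`: the valid region of every padded row matches the original array, and everything beyond the valid length equals `pad_val`. A condensed version of that check, assuming the same `batchify` module and keyword arguments the test itself uses:

import numpy as np
from gluonnlp.data import batchify

arrs = [np.arange(3), np.arange(5)]
batch, valid_len = batchify.Pad(pad_val=-1, ret_length=True)(arrs)
batch, valid_len = batch.asnumpy(), valid_len.asnumpy()

assert valid_len.tolist() == [3, 5]
# Row 0: the first 3 entries are the original data, the remainder is the pad value.
assert np.allclose(batch[0, :3], arrs[0]) and np.all(batch[0, 3:] == -1)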
def test_pad():
    padded = batchify.Pad(pad_val=-1)([mx.nd.array([]),
                                       mx.nd.arange(1)]).asnumpy().flatten().tolist()
    assert padded == [-1.0, 0.0]
def train(): """Training function.""" trainer = gluon.Trainer(model.collect_params(), args.optimizer, { 'learning_rate': args.lr, 'beta2': 0.98, 'epsilon': 1e-9 }) train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(), btf.Stack()) test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(), btf.Stack(), btf.Stack()) target_val_lengths = list(map(lambda x: x[-1], data_val_lengths)) target_test_lengths = list(map(lambda x: x[-1], data_test_lengths)) train_batch_sampler = FixedBucketSampler(lengths=data_train_lengths, batch_size=args.batch_size, num_buckets=args.num_buckets, ratio=args.bucket_ratio, shuffle=True, use_average_length=True) logging.info('Train Batch Sampler:\n{}'.format( train_batch_sampler.stats())) train_data_loader = DataLoader(data_train, batch_sampler=train_batch_sampler, batchify_fn=train_batchify_fn, num_workers=8) val_batch_sampler = FixedBucketSampler(lengths=target_val_lengths, batch_size=args.test_batch_size, num_buckets=args.num_buckets, ratio=args.bucket_ratio, shuffle=False, use_average_length=True) logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats())) val_data_loader = DataLoader(data_val, batch_sampler=val_batch_sampler, batchify_fn=test_batchify_fn, num_workers=8) test_batch_sampler = FixedBucketSampler(lengths=target_test_lengths, batch_size=args.test_batch_size, num_buckets=args.num_buckets, ratio=args.bucket_ratio, shuffle=False, use_average_length=True) logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats())) test_data_loader = DataLoader(data_test, batch_sampler=test_batch_sampler, batchify_fn=test_batchify_fn, num_workers=8) best_valid_bleu = 0.0 step_num = 0 warmup_steps = args.warmup_steps grad_interval = args.num_accumulated model.collect_params().setattr('grad_req', 'add') average_start = (len(train_data_loader) // grad_interval) * (args.epochs - args.average_start) average_param_dict = None model.collect_params().zero_grad() for epoch_id in range(args.epochs): log_avg_loss = 0 log_wc = 0 loss_denom = 0 step_loss = 0 log_start_time = time.time() for batch_id, (src_seq, tgt_seq, src_valid_length, tgt_valid_length) \ in enumerate(train_data_loader): if batch_id % grad_interval == 0: step_num += 1 new_lr = args.lr / math.sqrt(args.num_units) \ * min(1. / math.sqrt(step_num), step_num * warmup_steps ** (-1.5)) trainer.set_learning_rate(new_lr) # logging.info(src_seq.context) Context suddenly becomes GPU. 
src_wc = src_valid_length.sum().asscalar() tgt_wc = tgt_valid_length.sum().asscalar() loss_denom += tgt_wc - tgt_valid_length.shape[0] if src_seq.shape[0] > len(ctx): src_seq_list, tgt_seq_list, src_valid_length_list, tgt_valid_length_list \ = [gluon.utils.split_and_load(seq, ctx, batch_axis=0, even_split=False) for seq in [src_seq, tgt_seq, src_valid_length, tgt_valid_length]] else: src_seq_list = [src_seq.as_in_context(ctx[0])] tgt_seq_list = [tgt_seq.as_in_context(ctx[0])] src_valid_length_list = [ src_valid_length.as_in_context(ctx[0]) ] tgt_valid_length_list = [ tgt_valid_length.as_in_context(ctx[0]) ] Ls = [] with mx.autograd.record(): for src_seq, tgt_seq, src_valid_length, tgt_valid_length \ in zip(src_seq_list, tgt_seq_list, src_valid_length_list, tgt_valid_length_list): out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1) smoothed_label = label_smoothing(tgt_seq[:, 1:]) ls = loss_function(out, smoothed_label, tgt_valid_length - 1).sum() Ls.append((ls * (tgt_seq.shape[1] - 1)) / args.batch_size) for L in Ls: L.backward() if batch_id % grad_interval == grad_interval - 1 or\ batch_id == len(train_data_loader) - 1: if average_param_dict is None: average_param_dict = { k: v.data(ctx[0]).copy() for k, v in model.collect_params().items() } trainer.step(float(loss_denom) / args.batch_size) param_dict = model.collect_params() param_dict.zero_grad() if step_num > average_start: alpha = 1. / max(1, step_num - average_start) for name, average_param in average_param_dict.items(): average_param[:] += alpha * ( param_dict[name].data(ctx[0]) - average_param) step_loss += sum([L.asscalar() for L in Ls]) if batch_id % grad_interval == grad_interval - 1 or\ batch_id == len(train_data_loader) - 1: log_avg_loss += step_loss / loss_denom * args.batch_size loss_denom = 0 step_loss = 0 log_wc += src_wc + tgt_wc if (batch_id + 1) % (args.log_interval * grad_interval) == 0: wps = log_wc / (time.time() - log_start_time) logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, ' 'throughput={:.2f}K wps, wc={:.2f}K'.format( epoch_id, batch_id + 1, len(train_data_loader), log_avg_loss / args.log_interval, np.exp(log_avg_loss / args.log_interval), wps / 1000, log_wc / 1000)) log_start_time = time.time() log_avg_loss = 0 log_wc = 0 mx.nd.waitall() valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0]) valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out, bpe=True, split_compound_word=True) logging.info( '[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}' .format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100)) test_loss, test_translation_out = evaluate(test_data_loader, ctx[0]) test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out, bpe=True, split_compound_word=True) logging.info( '[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'. 
format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100)) write_sentences( valid_translation_out, os.path.join(args.save_dir, 'epoch{:d}_valid_out.txt').format(epoch_id)) write_sentences( test_translation_out, os.path.join(args.save_dir, 'epoch{:d}_test_out.txt').format(epoch_id)) if valid_bleu_score > best_valid_bleu: best_valid_bleu = valid_bleu_score save_path = os.path.join(args.save_dir, 'valid_best.params') logging.info('Save best parameters to {}'.format(save_path)) model.save_params(save_path) save_path = os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id)) model.save_params(save_path) save_path = os.path.join(args.save_dir, 'average.params') mx.nd.save(save_path, average_param_dict) if args.average_checkpoint: for j in range(args.num_averages): params = mx.nd.load( os.path.join(args.save_dir, 'epoch{:d}.params'.format(args.epochs - j - 1))) alpha = 1. / (j + 1) for k, v in model._collect_params_with_prefix().items(): for c in ctx: v.data(c)[:] += alpha * (params[k].as_in_context(c) - v.data(c)) elif args.average_start > 0: for k, v in model.collect_params().items(): v.set_data(average_param_dict[k]) else: model.load_params(os.path.join(args.save_dir, 'valid_best.params'), ctx) valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0]) valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out, bpe=True, split_compound_word=True) logging.info( 'Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'. format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100)) test_loss, test_translation_out = evaluate(test_data_loader, ctx[0]) test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out, bpe=True, split_compound_word=True) logging.info( 'Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'. format(test_loss, np.exp(test_loss), test_bleu_score * 100)) write_sentences(valid_translation_out, os.path.join(args.save_dir, 'best_valid_out.txt')) write_sentences(test_translation_out, os.path.join(args.save_dir, 'best_test_out.txt'))
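The learning-rate update at the top of the inner loop implements the warmup schedule from the original Transformer paper: lr = base_lr / sqrt(num_units) * min(1/sqrt(step), step * warmup_steps^-1.5). A stand-alone sketch of that schedule; the function name and sample values are illustrative and not part of the script:

import math

def transformer_lr(step, base_lr, num_units, warmup_steps):
    """Inverse-square-root decay with linear warmup, as computed inside train()."""
    return base_lr / math.sqrt(num_units) * min(1.0 / math.sqrt(step),
                                                step * warmup_steps ** -1.5)

# The rate grows roughly linearly until `warmup_steps` and then decays as 1/sqrt(step).
for step in (1, 1000, 4000, 16000):
    print(step, transformer_lr(step, base_lr=2.0, num_units=512, warmup_steps=4000))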
def train(): """Training function.""" trainer = gluon.Trainer(model.collect_params(), args.optimizer, { 'learning_rate': args.lr, 'beta2': 0.98, 'epsilon': 1e-9 }) train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(dtype='float32'), btf.Stack(dtype='float32')) test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(), btf.Stack(dtype='float32'), btf.Stack(dtype='float32'), btf.Stack()) target_val_lengths = list(map(lambda x: x[-1], data_val_lengths)) target_test_lengths = list(map(lambda x: x[-1], data_test_lengths)) if args.bucket_scheme == 'constant': bucket_scheme = ConstWidthBucket() elif args.bucket_scheme == 'linear': bucket_scheme = LinearWidthBucket() elif args.bucket_scheme == 'exp': bucket_scheme = ExpWidthBucket(bucket_len_step=1.2) else: raise NotImplementedError train_batch_sampler = FixedBucketSampler(lengths=data_train_lengths, batch_size=args.batch_size, num_buckets=args.num_buckets, ratio=args.bucket_ratio, shuffle=True, use_average_length=True, num_shards=len(ctx), bucket_scheme=bucket_scheme) logging.info('Train Batch Sampler:\n{}'.format( train_batch_sampler.stats())) train_data_loader = ShardedDataLoader(data_train, batch_sampler=train_batch_sampler, batchify_fn=train_batchify_fn, num_workers=8) val_batch_sampler = FixedBucketSampler(lengths=target_val_lengths, batch_size=args.test_batch_size, num_buckets=args.num_buckets, ratio=args.bucket_ratio, shuffle=False, use_average_length=True, bucket_scheme=bucket_scheme) logging.info('Valid Batch Sampler:\n{}'.format(val_batch_sampler.stats())) val_data_loader = DataLoader(data_val, batch_sampler=val_batch_sampler, batchify_fn=test_batchify_fn, num_workers=8) test_batch_sampler = FixedBucketSampler(lengths=target_test_lengths, batch_size=args.test_batch_size, num_buckets=args.num_buckets, ratio=args.bucket_ratio, shuffle=False, use_average_length=True, bucket_scheme=bucket_scheme) logging.info('Test Batch Sampler:\n{}'.format(test_batch_sampler.stats())) test_data_loader = DataLoader(data_test, batch_sampler=test_batch_sampler, batchify_fn=test_batchify_fn, num_workers=8) if args.bleu == 'tweaked': bpe = bool(args.dataset != 'IWSLT2015' and args.dataset != 'TOY') split_compound_word = bpe tokenized = True elif args.bleu == '13a' or args.bleu == 'intl': bpe = False split_compound_word = False tokenized = False else: raise NotImplementedError best_valid_bleu = 0.0 step_num = 0 warmup_steps = args.warmup_steps grad_interval = args.num_accumulated model.collect_params().setattr('grad_req', 'add') average_start = (len(train_data_loader) // grad_interval) * (args.epochs - args.average_start) average_param_dict = None model.collect_params().zero_grad() for epoch_id in range(args.epochs): log_avg_loss = 0 log_wc = 0 loss_denom = 0 step_loss = 0 log_start_time = time.time() for batch_id, seqs \ in enumerate(train_data_loader): if batch_id % grad_interval == 0: step_num += 1 new_lr = args.lr / math.sqrt(args.num_units) \ * min(1. 
/ math.sqrt(step_num), step_num * warmup_steps ** (-1.5)) trainer.set_learning_rate(new_lr) src_wc, tgt_wc, bs = np.sum( [(shard[2].sum(), shard[3].sum(), shard[0].shape[0]) for shard in seqs], axis=0) src_wc = src_wc.asscalar() tgt_wc = tgt_wc.asscalar() loss_denom += tgt_wc - bs seqs = [[seq.as_in_context(context) for seq in shard] for context, shard in zip(ctx, seqs)] Ls = [] with mx.autograd.record(): for src_seq, tgt_seq, src_valid_length, tgt_valid_length in seqs: out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1) smoothed_label = label_smoothing(tgt_seq[:, 1:]) ls = loss_function(out, smoothed_label, tgt_valid_length - 1).sum() Ls.append((ls * (tgt_seq.shape[1] - 1)) / args.batch_size / 100.0) for L in Ls: L.backward() if batch_id % grad_interval == grad_interval - 1 or\ batch_id == len(train_data_loader) - 1: if average_param_dict is None: average_param_dict = { k: v.data(ctx[0]).copy() for k, v in model.collect_params().items() } trainer.step(float(loss_denom) / args.batch_size / 100.0) param_dict = model.collect_params() param_dict.zero_grad() if step_num > average_start: alpha = 1. / max(1, step_num - average_start) for name, average_param in average_param_dict.items(): average_param[:] += alpha * ( param_dict[name].data(ctx[0]) - average_param) step_loss += sum([L.asscalar() for L in Ls]) if batch_id % grad_interval == grad_interval - 1 or\ batch_id == len(train_data_loader) - 1: log_avg_loss += step_loss / loss_denom * args.batch_size * 100.0 loss_denom = 0 step_loss = 0 log_wc += src_wc + tgt_wc if (batch_id + 1) % (args.log_interval * grad_interval) == 0: wps = log_wc / (time.time() - log_start_time) logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, ' 'throughput={:.2f}K wps, wc={:.2f}K'.format( epoch_id, batch_id + 1, len(train_data_loader), log_avg_loss / args.log_interval, np.exp(log_avg_loss / args.log_interval), wps / 1000, log_wc / 1000)) log_start_time = time.time() log_avg_loss = 0 log_wc = 0 mx.nd.waitall() valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0]) valid_bleu_score, _, _, _, _ = compute_bleu( [val_tgt_sentences], valid_translation_out, tokenized=tokenized, tokenizer=args.bleu, split_compound_word=split_compound_word, bpe=bpe) logging.info( '[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}' .format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100)) test_loss, test_translation_out = evaluate(test_data_loader, ctx[0]) test_bleu_score, _, _, _, _ = compute_bleu( [test_tgt_sentences], test_translation_out, tokenized=tokenized, tokenizer=args.bleu, split_compound_word=split_compound_word, bpe=bpe) logging.info( '[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'. 
format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100)) write_sentences( valid_translation_out, os.path.join(args.save_dir, 'epoch{:d}_valid_out.txt').format(epoch_id)) write_sentences( test_translation_out, os.path.join(args.save_dir, 'epoch{:d}_test_out.txt').format(epoch_id)) if valid_bleu_score > best_valid_bleu: best_valid_bleu = valid_bleu_score save_path = os.path.join(args.save_dir, 'valid_best.params') logging.info('Save best parameters to {}'.format(save_path)) model.save_parameters(save_path) save_path = os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id)) model.save_parameters(save_path) save_path = os.path.join(args.save_dir, 'average.params') mx.nd.save(save_path, average_param_dict) if args.average_checkpoint: for j in range(args.num_averages): params = mx.nd.load( os.path.join(args.save_dir, 'epoch{:d}.params'.format(args.epochs - j - 1))) alpha = 1. / (j + 1) for k, v in model._collect_params_with_prefix().items(): for c in ctx: v.data(c)[:] += alpha * (params[k].as_in_context(c) - v.data(c)) save_path = os.path.join( args.save_dir, 'average_checkpoint_{}.params'.format(args.num_averages)) model.save_parameters(save_path) elif args.average_start > 0: for k, v in model.collect_params().items(): v.set_data(average_param_dict[k]) save_path = os.path.join(args.save_dir, 'average.params') model.save_parameters(save_path) else: model.load_parameters(os.path.join(args.save_dir, 'valid_best.params'), ctx) valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0]) valid_bleu_score, _, _, _, _ = compute_bleu( [val_tgt_sentences], valid_translation_out, tokenized=tokenized, tokenizer=args.bleu, bpe=bpe, split_compound_word=split_compound_word) logging.info( 'Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}'. format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100)) test_loss, test_translation_out = evaluate(test_data_loader, ctx[0]) test_bleu_score, _, _, _, _ = compute_bleu( [test_tgt_sentences], test_translation_out, tokenized=tokenized, tokenizer=args.bleu, bpe=bpe, split_compound_word=split_compound_word) logging.info( 'Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}'. format(test_loss, np.exp(test_loss), test_bleu_score * 100)) write_sentences(valid_translation_out, os.path.join(args.save_dir, 'best_valid_out.txt')) write_sentences(test_translation_out, os.path.join(args.save_dir, 'best_test_out.txt'))
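Both train() variants rely on the same gradient-accumulation idiom: parameters are declared with grad_req='add' so backward() accumulates gradients over `grad_interval` batches, trainer.step() is then called once with the accumulated token count as the rescaling factor, and the gradients are zeroed. A stripped-down sketch of that pattern; `model`, `trainer`, `loss_fn` and `batches` are placeholders, not objects defined in this script:

import mxnet as mx

# model, trainer, loss_fn and batches are assumed to exist; this only shows the idiom.
model.collect_params().setattr('grad_req', 'add')   # accumulate instead of overwrite
model.collect_params().zero_grad()

grad_interval = 4    # plays the role of args.num_accumulated
loss_denom = 0
for batch_id, (data, label, n_tokens) in enumerate(batches):
    with mx.autograd.record():
        loss = loss_fn(model(data), label).sum()
    loss.backward()                                  # adds into the existing gradients
    loss_denom += n_tokens
    if (batch_id + 1) % grad_interval == 0:
        trainer.step(loss_denom)                     # normalize by the accumulated count
        model.collect_params().zero_grad()
        loss_denom = 0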
def train(args): store, num_parts, rank, local_rank, is_master_node, ctx_l = init_comm( args.comm_backend, args.gpus) src_tokenizer = create_tokenizer(args.src_tokenizer, args.src_subword_model_path, args.src_vocab_path) tgt_tokenizer = create_tokenizer(args.tgt_tokenizer, args.tgt_subword_model_path, args.tgt_vocab_path) src_vocab = src_tokenizer.vocab tgt_vocab = tgt_tokenizer.vocab train_src_data, train_tgt_data = load_dataset_with_cache( args.train_src_corpus, args.train_tgt_corpus, src_tokenizer, tgt_tokenizer, args.overwrite_cache) dev_src_data, dev_tgt_data = load_dataset_with_cache( args.dev_src_corpus, args.dev_tgt_corpus, src_tokenizer, tgt_tokenizer, args.overwrite_cache) data_train = gluon.data.SimpleDataset([ (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i) for i, (src_tokens, tgt_tokens) in enumerate(zip(train_src_data, train_tgt_data)) ]) data_val = gluon.data.SimpleDataset([ (src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i) for i, (src_tokens, tgt_tokens) in enumerate(zip(dev_src_data, dev_tgt_data)) ]) # Construct the model + loss function if args.cfg.endswith('.yml'): cfg = TransformerModel.get_cfg().clone_merge(args.cfg) else: cfg = TransformerModel.get_cfg(args.cfg) cfg.defrost() cfg.MODEL.src_vocab_size = len(src_vocab) cfg.MODEL.tgt_vocab_size = len(tgt_vocab) if args.fp16: raise NotImplementedError # cfg.MODEL.dtype = 'float16' cfg.freeze() model = TransformerModel.from_cfg(cfg) model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l) model.hybridize() if local_rank == 0: logging.info(model) with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f: cfg_f.write(cfg.dump()) label_smooth_loss = LabelSmoothCrossEntropyLoss( num_labels=len(tgt_vocab), alpha=args.label_smooth_alpha, from_logits=False) label_smooth_loss.hybridize() rescale_loss = 100.0 if args.comm_backend == 'horovod': hvd.broadcast_parameters(model.collect_params(), root_rank=0) # Construct the trainer # TODO(sxjscience) Support AMP if args.lr is None: base_lr = 2.0 / math.sqrt(args.num_units) / math.sqrt( args.warmup_steps) else: base_lr = args.lr lr_scheduler = InverseSquareRootScheduler( warmup_steps=args.warmup_steps, base_lr=base_lr, warmup_init_lr=args.warmup_init_lr) trainer_settings = (model.collect_params(), 'adam', { 'learning_rate': args.lr, 'beta1': 0.9, 'beta2': 0.98, 'epsilon': 1e-9, 'lr_scheduler': lr_scheduler }) if args.comm_backend == 'horovod': trainer = hvd.DistributedTrainer(*trainer_settings) else: trainer = gluon.Trainer(*trainer_settings) # Load Data if args.sampler == 'BoundedBudgetSampler': train_batch_sampler = BoundedBudgetSampler( lengths=[(ele[2], ele[3]) for ele in data_train], max_num_tokens=args.max_num_tokens, max_num_sentences=args.max_num_sentences, seed=args.seed, num_parts=num_parts, part_index=rank) elif args.sampler == 'FixedBucketSampler': if args.comm_backend == 'horovod': raise NotImplementedError( 'FixedBucketSampler does not support horovod at present') if args.bucket_scheme == 'constant': bucket_scheme = ConstWidthBucket() elif args.bucket_scheme == 'linear': bucket_scheme = LinearWidthBucket() elif args.bucket_scheme == 'exp': bucket_scheme = ExpWidthBucket(bucket_len_step=1.2) else: raise NotImplementedError # TODO(sxjscience) Support auto-bucket-size tuning train_batch_sampler = FixedBucketSampler(lengths=[ (ele[2], ele[3]) for ele in data_train ], batch_size=args.batch_size, num_buckets=args.num_buckets, ratio=args.bucket_ratio, shuffle=True, use_average_length=True, bucket_scheme=bucket_scheme, 
seed=args.seed) else: raise NotImplementedError if local_rank == 0: logging.info(train_batch_sampler) batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(), bf.Stack()) train_data_loader = gluon.data.DataLoader( data_train, batch_sampler=train_batch_sampler, batchify_fn=batchify_fn, num_workers=0) val_data_loader = gluon.data.DataLoader(data_val, batch_size=args.val_batch_size, batchify_fn=batchify_fn, num_workers=0, shuffle=False) for v in model.collect_params().values(): if v.grad_req != 'null': v.grad_req = 'add' model.zero_grad() model_averager = AverageSGDTracker(model.collect_params()) log_start_time = time.time() num_params, num_fixed_params = None, None # TODO(sxjscience) Add a log metric class accum_count = 0 loss_denom = 0 n_train_iters = 0 log_wc = 0 log_avg_loss = 0.0 log_loss_denom = 0 epoch_id = 0 while (args.epochs < 0 or epoch_id < args.epochs ): # when args.epochs < 0, the model will keep training n_epoch_train_iters = 0 processed_batch_num = 0 train_multi_data_loader = grouper(train_data_loader, len(ctx_l)) is_last_batch = False sample_data_l = next(train_multi_data_loader) while not is_last_batch: processed_batch_num += len(sample_data_l) loss_l = [] for sample_data, ctx in zip(sample_data_l, ctx_l): if sample_data is None: continue src_token_ids, tgt_token_ids, src_valid_length, tgt_valid_length, sample_ids = sample_data src_wc, tgt_wc, bs = src_valid_length.sum( ), tgt_valid_length.sum(), src_token_ids.shape[0] loss_denom += tgt_wc - bs log_loss_denom += tgt_wc - bs log_wc += src_wc + tgt_wc src_token_ids = src_token_ids.as_in_ctx(ctx) tgt_token_ids = tgt_token_ids.as_in_ctx(ctx) src_valid_length = src_valid_length.as_in_ctx(ctx) tgt_valid_length = tgt_valid_length.as_in_ctx(ctx) with mx.autograd.record(): tgt_pred = model(src_token_ids, src_valid_length, tgt_token_ids[:, :-1], tgt_valid_length - 1) tgt_labels = tgt_token_ids[:, 1:] loss = label_smooth_loss(tgt_pred, tgt_labels) loss = mx.npx.sequence_mask( loss, sequence_length=tgt_valid_length - 1, use_sequence_length=True, axis=1) loss_l.append(loss.sum() / rescale_loss) for l in loss_l: l.backward() accum_count += 1 try: sample_data_l = next(train_multi_data_loader) except StopIteration: is_last_batch = True if local_rank == 0 and num_params is None: num_params, num_fixed_params = count_parameters( model.collect_params()) logging.info( 'Total Number of Parameters (not-fixed/fixed): {}/{}'. 
format(num_params, num_fixed_params)) sum_loss = sum([l.as_in_ctx(mx.cpu()) for l in loss_l]) * rescale_loss log_avg_loss += sum_loss mx.npx.waitall() if accum_count == args.num_accumulated or is_last_batch: # Update the parameters n_train_iters += 1 n_epoch_train_iters += 1 trainer.step(loss_denom.asnumpy() / rescale_loss) accum_count = 0 loss_denom = 0 model.zero_grad() if (args.epochs > 0 and epoch_id >= args.epochs - args.num_averages) or \ (args.max_update > 0 and n_train_iters >= args.max_update - args.num_averages * args.save_interval_update): model_averager.step() if local_rank == 0 and \ (n_epoch_train_iters % args.log_interval == 0 or is_last_batch): log_end_time = time.time() log_wc = log_wc.asnumpy() wps = log_wc / (log_end_time - log_start_time) log_avg_loss = (log_avg_loss / log_loss_denom).asnumpy() logging.info( '[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, ' 'throughput={:.2f}K wps, wc={:.2f}K, LR={}'.format( epoch_id, processed_batch_num * num_parts, len(train_data_loader), log_avg_loss, np.exp(log_avg_loss), wps / 1000, log_wc / 1000, trainer.learning_rate)) log_start_time = time.time() log_avg_loss = 0 log_loss_denom = 0 log_wc = 0 if local_rank == 0 and \ (args.max_update > 0 and n_train_iters % args.save_interval_update == 0): model.save_parameters(os.path.join( args.save_dir, 'update{:d}.params'.format( n_train_iters // args.save_interval_update)), deduplicate=True) if args.max_update > 0 and n_train_iters >= args.max_update: break if local_rank == 0 and args.epochs > 0: model.save_parameters(os.path.join( args.save_dir, 'epoch{:d}.params'.format(epoch_id)), deduplicate=True) avg_valid_loss = validation(model, val_data_loader, ctx_l) logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'.format( epoch_id, avg_valid_loss, np.exp(avg_valid_loss))) if args.max_update > 0 and n_train_iters >= args.max_update: break epoch_id += 1 if args.num_averages > 0: model_averager.copy_back( model.collect_params()) # TODO(sxjscience) Rewrite using update model.save_parameters(os.path.join(args.save_dir, 'average.params'), deduplicate=True)
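The `--average_checkpoint` branch and `AverageSGDTracker` implement the same idea: maintain an incremental mean of recent weights, avg_n = avg_{n-1} + (w_n - avg_{n-1}) / n, and evaluate with the averaged parameters. A minimal NumPy sketch of that running mean; names and toy values are illustrative:

import numpy as np

def running_average(checkpoints):
    """Incrementally average a list of parameter dicts, mirroring the averaging loops above."""
    avg = {k: v.copy() for k, v in checkpoints[0].items()}
    for n, params in enumerate(checkpoints[1:], start=2):
        alpha = 1.0 / n
        for k in avg:
            avg[k] += alpha * (params[k] - avg[k])
    return avg

# Three toy "checkpoints" of one weight; the incremental update reproduces the plain mean.
ckpts = [{'w': np.array([1.0])}, {'w': np.array([2.0])}, {'w': np.array([3.0])}]
assert np.allclose(running_average(ckpts)['w'], np.array([2.0]))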