import logging
import random
import sys
import time

import mxnet as mx
import numpy as np

import gluonnlp as nlp


def cbow_batch(data):
    """Create a batch for CBOW training objective."""
    centers, word_context, word_context_mask = data
    assert len(centers.shape) == 2
    negatives_shape = (len(centers), args.negative)
    negatives, negatives_mask = remove_accidental_hits(
        negatives_sampler(negatives_shape), centers)
    center_negatives = mx.nd.concat(centers, negatives, dim=1)
    center_negatives_mask = mx.nd.concat(mx.nd.ones_like(centers),
                                         negatives_mask, dim=1)
    labels = mx.nd.concat(mx.nd.ones_like(centers),
                          mx.nd.zeros_like(negatives), dim=1)
    if not args.ngram_buckets:
        return (word_context.as_in_context(context[0]),
                word_context_mask.as_in_context(context[0]),
                center_negatives.as_in_context(context[0]),
                center_negatives_mask.as_in_context(context[0]),
                labels.as_in_context(context[0]))
    else:
        unique, inverse_unique_indices = np.unique(word_context.asnumpy(),
                                                   return_inverse=True)
        inverse_unique_indices = mx.nd.array(inverse_unique_indices,
                                             ctx=context[0])
        subwords, subwords_mask = subword_lookup.get(unique.astype(int))
        return (word_context.as_in_context(context[0]),
                word_context_mask.as_in_context(context[0]),
                center_negatives.as_in_context(context[0]),
                center_negatives_mask.as_in_context(context[0]),
                labels.as_in_context(context[0]),
                mx.nd.array(subwords, ctx=context[0]),
                mx.nd.array(subwords_mask, ctx=context[0]),
                inverse_unique_indices)
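# Note: `cbow_batch` relies on its enclosing scope for `args`, `context`,
# `negatives_sampler` and `subword_lookup`. Per batch it returns the padded
# context words and their mask, the concatenated (center, sampled negatives)
# targets with a matching mask, and 1/0 labels separating true centers from
# negatives; when `args.ngram_buckets` is set it additionally returns the
# subword-index arrays plus the indices that map the de-duplicated context
# words back to their positions in the batch.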
def train(args):
    """Training helper."""
    if args.ngram_buckets:  # Fasttext model
        coded_dataset, negatives_sampler, vocab, subword_function, \
            idx_to_subwordidxs = get_train_data(args)
        embedding = nlp.model.train.FasttextEmbeddingModel(
            token_to_idx=vocab.token_to_idx,
            subword_function=subword_function,
            embedding_size=args.emsize,
            weight_initializer=mx.init.Uniform(scale=1 / args.emsize),
            sparse_grad=not args.no_sparse_grad,
        )
    else:
        coded_dataset, negatives_sampler, vocab = get_train_data(args)
        embedding = nlp.model.train.SimpleEmbeddingModel(
            token_to_idx=vocab.token_to_idx,
            embedding_size=args.emsize,
            weight_initializer=mx.init.Uniform(scale=1 / args.emsize),
            sparse_grad=not args.no_sparse_grad,
        )
    embedding_out = nlp.model.train.SimpleEmbeddingModel(
        token_to_idx=vocab.token_to_idx,
        embedding_size=args.emsize,
        weight_initializer=mx.init.Zero(),
        sparse_grad=not args.no_sparse_grad,
    )
    loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()

    context = get_context(args)
    embedding.initialize(ctx=context)
    embedding_out.initialize(ctx=context)
    if not args.no_hybridize:
        embedding.hybridize(static_alloc=not args.no_static_alloc)
        embedding_out.hybridize(static_alloc=not args.no_static_alloc)

    optimizer_kwargs = dict(learning_rate=args.lr)
    params = list(embedding.embedding.collect_params().values()) + \
        list(embedding_out.collect_params().values())
    trainer = mx.gluon.Trainer(params, args.optimizer, optimizer_kwargs)

    if args.ngram_buckets:  # Fasttext model
        optimizer_subwords_kwargs = dict(learning_rate=args.lr_subwords)
        params_subwords = list(
            embedding.subword_embedding.collect_params().values())
        trainer_subwords = mx.gluon.Trainer(params_subwords,
                                            args.optimizer_subwords,
                                            optimizer_subwords_kwargs)

    num_update = 0
    for epoch in range(args.epochs):
        random.shuffle(coded_dataset)
        context_sampler = nlp.data.ContextSampler(coded=coded_dataset,
                                                  batch_size=args.batch_size,
                                                  window=args.window)
        num_batches = len(context_sampler)

        # Logging variables
        log_wc = 0
        log_start_time = time.time()
        log_avg_loss = 0

        for i, batch in enumerate(context_sampler):
            progress = (epoch * num_batches + i) / (args.epochs * num_batches)
            (center, word_context, word_context_mask) = batch

            negatives_shape = (word_context.shape[0],
                               word_context.shape[1] * args.negative)
            negatives, negatives_mask = remove_accidental_hits(
                negatives_sampler(negatives_shape), word_context,
                word_context_mask)

            if args.ngram_buckets:  # Fasttext model
                if args.model.lower() == 'skipgram':
                    unique, inverse_unique_indices = np.unique(
                        center.asnumpy(), return_inverse=True)
                    unique = mx.nd.array(unique)
                    inverse_unique_indices = mx.nd.array(
                        inverse_unique_indices, ctx=context[0])
                    subwords, subwords_mask = \
                        indices_to_subwordindices_mask(unique,
                                                       idx_to_subwordidxs)
                elif args.model.lower() == 'cbow':
                    unique, inverse_unique_indices = np.unique(
                        word_context.asnumpy(), return_inverse=True)
                    unique = mx.nd.array(unique)
                    inverse_unique_indices = mx.nd.array(
                        inverse_unique_indices, ctx=context[0])
                    subwords, subwords_mask = \
                        indices_to_subwordindices_mask(unique,
                                                       idx_to_subwordidxs)
                else:
                    logging.error('Unsupported model %s.', args.model)
                    sys.exit(1)
            num_update += len(center)

            # To GPU
            center = center.as_in_context(context[0])
            if args.ngram_buckets:  # Fasttext model
                subwords = subwords.as_in_context(context[0])
                subwords_mask = subwords_mask.astype(
                    np.float32).as_in_context(context[0])
            word_context = word_context.as_in_context(context[0])
            word_context_mask = word_context_mask.as_in_context(context[0])
            negatives = negatives.as_in_context(context[0])
            negatives_mask = negatives_mask.as_in_context(context[0])

            with mx.autograd.record():
                # Combine subword level embeddings with word embeddings
                if args.model.lower() == 'skipgram':
                    if args.ngram_buckets:
                        emb_in = embedding(
                            center, subwords, subwordsmask=subwords_mask,
                            words_to_unique_subwords_indices=
                            inverse_unique_indices)
                    else:
                        emb_in = embedding(center)

                    with mx.autograd.pause():
                        word_context_negatives = mx.nd.concat(
                            word_context, negatives, dim=1)
                        word_context_negatives_mask = mx.nd.concat(
                            word_context_mask, negatives_mask, dim=1)

                    emb_out = embedding_out(word_context_negatives,
                                            word_context_negatives_mask)

                    # Compute loss
                    pred = mx.nd.batch_dot(emb_in, emb_out.swapaxes(1, 2))
                    pred = pred.squeeze() * word_context_negatives_mask
                    label = mx.nd.concat(word_context_mask,
                                         mx.nd.zeros_like(negatives), dim=1)

                elif args.model.lower() == 'cbow':
                    word_context = word_context.reshape((-3, 1))
                    word_context_mask = word_context_mask.reshape((-3, 1))
                    if args.ngram_buckets:
                        emb_in = embedding(word_context, subwords,
                                           word_context_mask, subwords_mask,
                                           inverse_unique_indices)
                    else:
                        emb_in = embedding(word_context, word_context_mask)

                    with mx.autograd.pause():
                        center = center.tile(args.window * 2).reshape((-1, 1))
                        negatives = negatives.reshape((-1, args.negative))
                        # Keep the mask aligned with the reshaped negatives
                        negatives_mask = negatives_mask.reshape(
                            (-1, args.negative))
                        center_negatives = mx.nd.concat(
                            center, negatives, dim=1)
                        center_negatives_mask = mx.nd.concat(
                            mx.nd.ones_like(center), negatives_mask, dim=1)

                    emb_out = embedding_out(center_negatives,
                                            center_negatives_mask)

                    # Compute loss
                    pred = mx.nd.batch_dot(emb_in, emb_out.swapaxes(1, 2))
                    pred = pred.squeeze() * word_context_mask
                    label = mx.nd.concat(
                        mx.nd.ones_like(word_context),
                        mx.nd.zeros_like(negatives), dim=1)

                loss = loss_function(pred, label)

            loss.backward()

            if args.optimizer.lower() != 'adagrad':
                trainer.set_learning_rate(
                    max(0.0001, args.lr * (1 - progress)))
            if (args.optimizer_subwords.lower() != 'adagrad'
                    and args.ngram_buckets):
                trainer_subwords.set_learning_rate(
                    max(0.0001, args.lr_subwords * (1 - progress)))

            trainer.step(batch_size=1)
            if args.ngram_buckets:
                trainer_subwords.step(batch_size=1)

            # Logging
            log_wc += loss.shape[0]
            log_avg_loss += loss.mean()
            if (i + 1) % args.log_interval == 0:
                wps = log_wc / (time.time() - log_start_time)
                # Forces waiting for computation by computing loss value
                log_avg_loss = log_avg_loss.asscalar() / args.log_interval
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'.format(
                                 epoch, i + 1, num_batches, log_avg_loss,
                                 wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_wc = 0

            if args.eval_interval and (i + 1) % args.eval_interval == 0:
                with print_time('mx.nd.waitall()'):
                    mx.nd.waitall()
                with print_time('evaluate'):
                    evaluate(args, embedding, vocab, num_update)

    # Evaluate
    with print_time('mx.nd.waitall()'):
        mx.nd.waitall()
    with print_time('evaluate'):
        evaluate(args, embedding, vocab, num_update,
                 eval_analogy=not args.no_eval_analogy)

    # Save params
    with print_time('save parameters'):
        save_params(args, embedding, embedding_out)
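
# A minimal, self-contained sketch (an illustration, not part of the original
# script) of the negative-sampling objective computed inside the training
# loop above: dot products between a center embedding and the concatenated
# (context, negatives) embeddings are scored with sigmoid binary
# cross-entropy against 1/0 labels. The shapes and the helper name
# `_negative_sampling_loss_sketch` are hypothetical.
def _negative_sampling_loss_sketch():
    batch_size, emsize, num_context, num_negatives = 2, 4, 3, 5
    # One embedded center word per row and its candidate output embeddings
    emb_in = mx.nd.random.normal(shape=(batch_size, 1, emsize))
    emb_out = mx.nd.random.normal(
        shape=(batch_size, num_context + num_negatives, emsize))
    # Dot-product scores between the center and each candidate
    pred = mx.nd.batch_dot(emb_in, emb_out.swapaxes(1, 2)).squeeze()
    # 1 for true context words, 0 for sampled negatives
    label = mx.nd.concat(mx.nd.ones((batch_size, num_context)),
                         mx.nd.zeros((batch_size, num_negatives)), dim=1)
    return mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()(pred, label)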