Example #1
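
# Assumed imports for this example; helper functions such as get_train_data,
# get_context, remove_accidental_hits, indices_to_subwordindices_mask,
# evaluate, save_params and print_time are defined elsewhere in the
# surrounding word-embedding training script. cbow_batch additionally relies
# on args, context, negatives_sampler and subword_lookup from its enclosing
# scope.
import logging
import random
import sys
import time

import mxnet as mx
import numpy as np

import gluonnlp as nlp

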
def cbow_batch(data):
    """Create a batch for the CBOW training objective."""
    centers, word_context, word_context_mask = data
    assert len(centers.shape) == 2
    # Sample args.negative negatives per center word and mask out accidental
    # hits, i.e. sampled negatives that coincide with the true center word.
    negatives_shape = (len(centers), args.negative)
    negatives, negatives_mask = remove_accidental_hits(
        negatives_sampler(negatives_shape), centers)
    center_negatives = mx.nd.concat(centers, negatives, dim=1)
    center_negatives_mask = mx.nd.concat(mx.nd.ones_like(centers),
                                         negatives_mask, dim=1)
    # Label 1 for the true center word, 0 for every sampled negative.
    labels = mx.nd.concat(mx.nd.ones_like(centers),
                          mx.nd.zeros_like(negatives), dim=1)
    if not args.ngram_buckets:
        return (word_context.as_in_context(context[0]),
                word_context_mask.as_in_context(context[0]),
                center_negatives.as_in_context(context[0]),
                center_negatives_mask.as_in_context(context[0]),
                labels.as_in_context(context[0]))
    else:
        # fastText: look up subword indices only for the unique context words
        # and keep the inverse mapping to restore the original batch order.
        unique, inverse_unique_indices = np.unique(word_context.asnumpy(),
                                                   return_inverse=True)
        inverse_unique_indices = mx.nd.array(inverse_unique_indices,
                                             ctx=context[0])
        subwords, subwords_mask = subword_lookup.get(unique.astype(int))
        return (word_context.as_in_context(context[0]),
                word_context_mask.as_in_context(context[0]),
                center_negatives.as_in_context(context[0]),
                center_negatives_mask.as_in_context(context[0]),
                labels.as_in_context(context[0]),
                mx.nd.array(subwords, ctx=context[0]),
                mx.nd.array(subwords_mask, ctx=context[0]),
                inverse_unique_indices)


def train(args):
    """Training helper."""
    if args.ngram_buckets:  # fastText model
        coded_dataset, negatives_sampler, vocab, subword_function, \
            idx_to_subwordidxs = get_train_data(args)
        embedding = nlp.model.train.FasttextEmbeddingModel(
            token_to_idx=vocab.token_to_idx,
            subword_function=subword_function,
            embedding_size=args.emsize,
            weight_initializer=mx.init.Uniform(scale=1 / args.emsize),
            sparse_grad=not args.no_sparse_grad,
        )
    else:
        coded_dataset, negatives_sampler, vocab = get_train_data(args)
        embedding = nlp.model.train.SimpleEmbeddingModel(
            token_to_idx=vocab.token_to_idx,
            embedding_size=args.emsize,
            weight_initializer=mx.init.Uniform(scale=1 / args.emsize),
            sparse_grad=not args.no_sparse_grad,
        )
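    # Output-side (context/target) embedding matrix, always word level and
    # initialized to zero.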
    embedding_out = nlp.model.train.SimpleEmbeddingModel(
        token_to_idx=vocab.token_to_idx,
        embedding_size=args.emsize,
        weight_initializer=mx.init.Zero(),
        sparse_grad=not args.no_sparse_grad,
    )
    loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()

    context = get_context(args)
    embedding.initialize(ctx=context)
    embedding_out.initialize(ctx=context)
    if not args.no_hybridize:
        embedding.hybridize(static_alloc=not args.no_static_alloc)
        embedding_out.hybridize(static_alloc=not args.no_static_alloc)

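    # The word-level input and output embeddings share one trainer; the
    # fastText subword parameters get their own trainer below.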
    optimizer_kwargs = dict(learning_rate=args.lr)
    params = list(embedding.embedding.collect_params().values()) + \
        list(embedding_out.collect_params().values())
    trainer = mx.gluon.Trainer(params, args.optimizer, optimizer_kwargs)

    if args.ngram_buckets:  # fastText model
        optimizer_subwords_kwargs = dict(learning_rate=args.lr_subwords)
        params_subwords = list(
            embedding.subword_embedding.collect_params().values())
        trainer_subwords = mx.gluon.Trainer(params_subwords,
                                            args.optimizer_subwords,
                                            optimizer_subwords_kwargs)

    num_update = 0
    for epoch in range(args.epochs):
        random.shuffle(coded_dataset)
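        # ContextSampler yields (center, context, context_mask) batches, with
        # up to `window` context words on each side of every center word.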
        context_sampler = nlp.data.ContextSampler(coded=coded_dataset,
                                                  batch_size=args.batch_size,
                                                  window=args.window)
        num_batches = len(context_sampler)

        # Logging variables
        log_wc = 0
        log_start_time = time.time()
        log_avg_loss = 0

        for i, batch in enumerate(context_sampler):
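            # Fraction of total updates completed so far, used below for the
            # linear learning-rate decay.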
            progress = (epoch * num_batches + i) / (args.epochs * num_batches)
            (center, word_context, word_context_mask) = batch
            negatives_shape = (word_context.shape[0],
                               word_context.shape[1] * args.negative)
            negatives, negatives_mask = remove_accidental_hits(
                negatives_sampler(negatives_shape), word_context,
                word_context_mask)

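            # fastText only: look up subword indices for the unique tokens in
            # the batch and keep the inverse mapping to restore batch order.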
            if args.ngram_buckets:  # fastText model
                if args.model.lower() == 'skipgram':
                    unique, inverse_unique_indices = np.unique(
                        center.asnumpy(), return_inverse=True)
                    unique = mx.nd.array(unique)
                    inverse_unique_indices = mx.nd.array(
                        inverse_unique_indices, ctx=context[0])
                    subwords, subwords_mask = \
                        indices_to_subwordindices_mask(unique, idx_to_subwordidxs)
                elif args.model.lower() == 'cbow':
                    unique, inverse_unique_indices = np.unique(
                        word_context.asnumpy(), return_inverse=True)
                    unique = mx.nd.array(unique)
                    inverse_unique_indices = mx.nd.array(
                        inverse_unique_indices, ctx=context[0])
                    subwords, subwords_mask = \
                        indices_to_subwordindices_mask(unique, idx_to_subwordidxs)
                else:
                    logging.error('Unsupported model %s.', args.model)
                    sys.exit(1)

            num_update += len(center)

            # To GPU
            center = center.as_in_context(context[0])
            if args.ngram_buckets:  # fastText model
                subwords = subwords.as_in_context(context[0])
                subwords_mask = subwords_mask.astype(np.float32).as_in_context(
                    context[0])
            word_context = word_context.as_in_context(context[0])
            word_context_mask = word_context_mask.as_in_context(context[0])
            negatives = negatives.as_in_context(context[0])
            negatives_mask = negatives_mask.as_in_context(context[0])

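            # Forward pass; bookkeeping such as concatenating targets with
            # negatives is excluded from the graph via autograd.pause().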
            with mx.autograd.record():
                # Combine subword level embeddings with word embeddings
                if args.model.lower() == 'skipgram':
                    if args.ngram_buckets:
                        emb_in = embedding(center, subwords,
                                           subwordsmask=subwords_mask,
                                           words_to_unique_subwords_indices=
                                           inverse_unique_indices)
                    else:
                        emb_in = embedding(center)

                    with mx.autograd.pause():
                        word_context_negatives = mx.nd.concat(
                            word_context, negatives, dim=1)
                        word_context_negatives_mask = mx.nd.concat(
                            word_context_mask, negatives_mask, dim=1)

                    emb_out = embedding_out(word_context_negatives,
                                            word_context_negatives_mask)

                    # Compute loss
                    pred = mx.nd.batch_dot(emb_in, emb_out.swapaxes(1, 2))
                    pred = pred.squeeze() * word_context_negatives_mask
                    label = mx.nd.concat(word_context_mask,
                                         mx.nd.zeros_like(negatives), dim=1)

                elif args.model.lower() == 'cbow':
                    word_context = word_context.reshape((-3, 1))
                    word_context_mask = word_context_mask.reshape((-3, 1))
                    if args.ngram_buckets:
                        emb_in = embedding(word_context, subwords,
                                           word_context_mask, subwords_mask,
                                           inverse_unique_indices)
                    else:
                        emb_in = embedding(word_context, word_context_mask)

                    with mx.autograd.pause():
                        center = center.tile(args.window * 2).reshape((-1, 1))
                        negatives = negatives.reshape((-1, args.negative))

                        center_negatives = mx.nd.concat(
                            center, negatives, dim=1)
                        center_negatives_mask = mx.nd.concat(
                            mx.nd.ones_like(center), negatives_mask, dim=1)

                    emb_out = embedding_out(center_negatives,
                                            center_negatives_mask)

                    # Compute loss
                    pred = mx.nd.batch_dot(emb_in, emb_out.swapaxes(1, 2))
                    pred = pred.squeeze() * word_context_mask
                    label = mx.nd.concat(
                        mx.nd.ones_like(word_context),
                        mx.nd.zeros_like(negatives), dim=1)

                loss = loss_function(pred, label)

            loss.backward()

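            # Linearly decay the learning rate towards a small floor (skipped
            # for AdaGrad, which adapts its per-parameter step sizes).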
            if args.optimizer.lower() != 'adagrad':
                trainer.set_learning_rate(
                    max(0.0001, args.lr * (1 - progress)))

            if (args.optimizer_subwords.lower() != 'adagrad'
                    and args.ngram_buckets):
                trainer_subwords.set_learning_rate(
                    max(0.0001, args.lr_subwords * (1 - progress)))

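            # Apply the accumulated gradients; for fastText, the subword
            # table is updated by its own trainer.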
            trainer.step(batch_size=1)
            if args.ngram_buckets:
                trainer_subwords.step(batch_size=1)

            # Logging
            log_wc += loss.shape[0]
            log_avg_loss += loss.mean()
            if (i + 1) % args.log_interval == 0:
                wps = log_wc / (time.time() - log_start_time)
                # Forces waiting for computation by computing loss value
                log_avg_loss = log_avg_loss.asscalar() / args.log_interval
                logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, '
                             'throughput={:.2f}K wps, wc={:.2f}K'.format(
                                 epoch, i + 1, num_batches, log_avg_loss,
                                 wps / 1000, log_wc / 1000))
                log_start_time = time.time()
                log_avg_loss = 0
                log_wc = 0

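            # Optional intermediate evaluation; mx.nd.waitall() blocks until
            # all pending asynchronous computation has finished.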
            if args.eval_interval and (i + 1) % args.eval_interval == 0:
                with print_time('mx.nd.waitall()'):
                    mx.nd.waitall()
                with print_time('evaluate'):
                    evaluate(args, embedding, vocab, num_update)

    # Evaluate
    with print_time('mx.nd.waitall()'):
        mx.nd.waitall()
    with print_time('evaluate'):
        evaluate(args, embedding, vocab, num_update,
                 eval_analogy=not args.no_eval_analogy)

    # Save params
    with print_time('save parameters'):
        save_params(args, embedding, embedding_out)
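
For reference, a minimal sketch of how train might be driven. The attribute names mirror the fields read from args in the code above; the concrete values, and any extra fields expected by get_train_data, get_context, evaluate and save_params (corpus location, GPU selection, output directory, ...), are assumptions that have to match the script's actual argument parser.

from types import SimpleNamespace

# Illustrative values only; they are not the script's real CLI defaults.
args = SimpleNamespace(
    model='skipgram', ngram_buckets=2000000, emsize=300,
    batch_size=1024, window=5, negative=5, epochs=5,
    optimizer='adagrad', lr=0.1,
    optimizer_subwords='adagrad', lr_subwords=0.1,
    no_sparse_grad=False, no_hybridize=False, no_static_alloc=False,
    log_interval=500, eval_interval=0, no_eval_analogy=True)

train(args)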