Example #1
0
    def check_amp_convert_bucketing_module():
        """Convert a trained bucketing module with AMP and score it on random data.

        Trains a model via ``train_model``, converts it with
        ``amp.convert_bucketing_module``, builds a randomly generated
        validation set, binds the converted module for inference only and
        runs ``score`` with a perplexity metric. No score value is asserted;
        the check is that conversion, binding and scoring complete.
        """
        model = train_model(context=mx.current_context())
        result_model = amp.convert_bucketing_module(model)

        batch_size = 128
        invalid_label = -1
        num_sentence = 1000
        buckets = [5, 10, 20, 30, 40]
        len_vocab = 50

        # Every sentence has at least 6 tokens, so the smallest bucket (5)
        # receives no data; the upper bound stays below max(buckets), so every
        # sentence still fits the largest bucket (40).
        # NOTE(review): bounds are inclusive only if this is random.randint
        # (numpy's randint excludes the upper bound) -- confirm the import.
        val_sent = [
            [randint(1, len_vocab) for _ in range(randint(6, max(buckets) - 1))]
            for _ in range(num_sentence)
        ]

        data_val = mx.rnn.BucketSentenceIter(val_sent,
                                             batch_size,
                                             buckets=buckets,
                                             invalid_label=invalid_label)
        # The converted module is only evaluated here, so bind it for inference.
        result_model.bind(data_val.provide_data,
                          data_val.provide_label,
                          for_training=False)
        result_model.score(data_val,
                           mx.gluon.metric.Perplexity(invalid_label),
                           batch_end_callback=mx.callback.Speedometer(
                               batch_size, 1))

        # AMP conversion with cast_optional_params set to True.
        # Flaky when cast_optional_params is True: https://github.com/apache/incubator-mxnet/issues/16030
        '''
def test_bucketing_save_load(tmpdir):
    """Check that a checkpointed bucketing model can be reloaded and trained further.

    Trains a model with ``train_model`` and saves it as a checkpoint under
    ``tmpdir``. It then reloads the checkpoint with
    ``mx.mod.BucketingModule.load`` and asserts that the reloaded default-bucket
    argument parameters equal the originals. Finally it fits the reloaded
    module for a few epochs.

    ``MXNET_UPDATE_ON_KVSTORE`` is forced to ``'1'`` for the duration of the
    test and restored afterwards (removed again if it was unset), even when
    the test fails.
    """
    env_key = 'MXNET_UPDATE_ON_KVSTORE'
    # None when the variable is unset, so the finally block can remove it
    # instead of leaving a value behind.
    previous_update_on_kvstore = os.environ.get(env_key)
    # Assigning to os.environ also calls putenv, so both os.getenv lookups and
    # the process environment see the value; os.putenv alone would leave
    # os.environ stale.
    os.environ[env_key] = '1'
    try:
        def dict_equ(a, b):
            """Assert two name -> NDArray dicts have the same keys and equal values."""
            assert set(a) == set(b)
            for k in a:
                assert (a[k].asnumpy() == b[k].asnumpy()).all()

        len_vocab = 50
        num_embed = 25
        num_epochs = 5
        batch_size = 128
        num_layers = 2
        num_hidden = 25
        buckets = [5, 10, 20, 30, 40]
        invalid_label = -1
        num_sentence = 1000

        stack = mx.rnn.SequentialRNNCell()
        for i in range(num_layers):
            stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_' % i))

        def sym_gen(seq_len):
            """Build the unrolled LSTM + softmax symbol for ``seq_len`` time steps.

            Returns ``(loss_symbol, data_names, label_names)`` as consumed by
            ``BucketingModule``.
            """
            data = mx.sym.Variable('data')
            label = mx.sym.Variable('softmax_label')
            embed = mx.sym.Embedding(data=data, input_dim=len_vocab,
                                     output_dim=num_embed, name='embed')
            # The cell stack is shared across buckets; reset it before each unroll.
            stack.reset()
            outputs, _ = stack.unroll(seq_len, inputs=embed, merge_outputs=True)

            pred = mx.sym.Reshape(outputs, shape=(-1, num_hidden))
            pred = mx.sym.FullyConnected(data=pred, num_hidden=len_vocab, name='pred')

            label = mx.sym.Reshape(label, shape=(-1,))
            loss = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

            return loss, ('data',), ('softmax_label',)

        path = str(tmpdir.join('test'))

        model = train_model(context=mx.current_context())
        model.save_checkpoint(path, 0)
        data_train, data_val = prepare_bucketing_data(buckets, len_vocab, batch_size, invalid_label, num_sentence)
        mod2 = mx.mod.BucketingModule.load(path, 0, sym_gen=sym_gen,
                                           default_bucket_key=data_train.default_bucket_key)

        mod2.bind(data_shapes=data_train.provide_data,
                  label_shapes=data_train.provide_label)

        # The previous version looped over model's bucket keys without using
        # them and repeated this exact default-bucket comparison each time;
        # one comparison asserts the same thing.
        dict_equ(model._buckets[model._default_bucket_key].get_params()[0],
                 mod2._buckets[mod2._default_bucket_key].get_params()[0])
        mod2.fit(
            train_data=data_train,
            eval_data=data_val,
            eval_metric=mx.gluon.metric.Perplexity(invalid_label), # Use Perplexity for multiclass classification.
            kvstore='device',
            optimizer='sgd',
            optimizer_params={'learning_rate': 0.01,
                              'momentum': 0,
                              'wd': 0.00001},
            initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
            num_epoch=num_epochs,
            batch_end_callback=mx.callback.Speedometer(batch_size, 50))
    finally:
        # Restore the environment exactly as it was, whether or not the test passed.
        if previous_update_on_kvstore is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous_update_on_kvstore