def check_amp_convert_bucketing_module():
    model = train_model(context=mx.current_context())
    result_model = amp.convert_bucketing_module(model)

    # Generate random validation sentences and bucket them.
    val_sent = []
    batch_size = 128
    invalid_label = -1
    num_sentence = 1000
    buckets = [5, 10, 20, 30, 40]
    len_vocab = 50

    for _ in range(num_sentence):
        len_sentence = randint(6, max(buckets) - 1)  # leave the last two buckets empty
        val_sentence = []
        for _ in range(len_sentence):
            val_sentence.append(randint(1, len_vocab))
        val_sent.append(val_sentence)

    data_val = mx.rnn.BucketSentenceIter(val_sent, batch_size, buckets=buckets,
                                         invalid_label=invalid_label)

    # Score the AMP-converted module on the validation iterator.
    result_model.bind(data_val.provide_data, data_val.provide_label, for_training=False)
    result_model.score(data_val, mx.gluon.metric.Perplexity(invalid_label),
                       batch_end_callback=mx.callback.Speedometer(batch_size, 1))

    # AMP conversion with cast_optional_params set to True.
    # Flaky test when cast_optional_params is set to True: https://github.com/apache/incubator-mxnet/issues/16030
    '''
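    # NOTE: the body of this commented-out block was elided in this section.
    # A sketch of what the cast_optional_params=True variant presumably did,
    # inferred from the comments above -- an assumption, not the original code:
    result_model = amp.convert_bucketing_module(model, cast_optional_params=True)
    result_model.bind(data_val.provide_data, data_val.provide_label, for_training=False)
    result_model.score(data_val, mx.gluon.metric.Perplexity(invalid_label),
                       batch_end_callback=mx.callback.Speedometer(batch_size, 1))
    '''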
def test_bucketing_save_load(tmpdir):
    # Save the current setting so it can be restored at the end of the test.
    previous_update_on_kvstore = os.getenv('MXNET_UPDATE_ON_KVSTORE', "1")
    os.putenv('MXNET_UPDATE_ON_KVSTORE', '1')

    def dict_equ(a, b):
        assert set(a) == set(b)
        for k in a:
            assert (a[k].asnumpy() == b[k].asnumpy()).all()

    len_vocab = 50
    num_embed = 25
    num_epochs = 5
    batch_size = 128
    num_layers = 2
    num_hidden = 25
    buckets = [5, 10, 20, 30, 40]
    invalid_label = -1
    num_sentence = 1000

    stack = mx.rnn.SequentialRNNCell()
    for i in range(num_layers):
        stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_' % i))

    def sym_gen(seq_len):
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('softmax_label')
        embed = mx.sym.Embedding(data=data, input_dim=len_vocab,
                                 output_dim=num_embed, name='embed')

        stack.reset()
        outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True)

        pred = mx.sym.Reshape(outputs, shape=(-1, num_hidden))
        pred = mx.sym.FullyConnected(data=pred, num_hidden=len_vocab, name='pred')

        label = mx.sym.Reshape(label, shape=(-1,))
        loss = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

        return loss, ('data',), ('softmax_label',)

    # Train a model, checkpoint it, then reload it into a fresh BucketingModule.
    path = str(tmpdir.join('test'))
    model = train_model(context=mx.current_context())
    model.save_checkpoint(path, 0)
    data_train, data_val = prepare_bucketing_data(buckets, len_vocab, batch_size,
                                                  invalid_label, num_sentence)
    mod2 = mx.mod.BucketingModule.load(path, 0, sym_gen=sym_gen,
                                       default_bucket_key=data_train.default_bucket_key)
    mod2.bind(data_shapes=data_train.provide_data,
              label_shapes=data_train.provide_label)

    # Parameters are shared across bucket modules, so comparing the default
    # bucket's parameters verifies the reloaded module matches the original.
    for bucket_key in model._buckets.keys():
        dict_equ(model._buckets[model._default_bucket_key].get_params()[0],
                 mod2._buckets[mod2._default_bucket_key].get_params()[0])

    # The reloaded module must also be trainable.
    mod2.fit(
        train_data=data_train,
        eval_data=data_val,
        eval_metric=mx.gluon.metric.Perplexity(invalid_label),  # Use Perplexity for multiclass classification.
        kvstore='device',
        optimizer='sgd',
        optimizer_params={'learning_rate': 0.01,
                          'momentum': 0,
                          'wd': 0.00001},
        initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
        num_epoch=num_epochs,
        batch_end_callback=mx.callback.Speedometer(batch_size, 50))

    os.putenv('MXNET_UPDATE_ON_KVSTORE', previous_update_on_kvstore)
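
# Both tests above call `prepare_bucketing_data` and `train_model` helpers that
# are defined elsewhere in this file (not shown in this section). A minimal
# sketch of what `prepare_bucketing_data` presumably does, mirroring the inline
# data generation in check_amp_convert_bucketing_module -- an assumption, not
# the canonical implementation. It relies on the same module-level
# `import mxnet as mx` and `from random import randint` that the tests use.
def prepare_bucketing_data(buckets, len_vocab, batch_size, invalid_label, num_sentence):
    train_sent = []
    val_sent = []
    for _ in range(num_sentence):
        len_sentence = randint(6, max(buckets) - 1)  # leave the last two buckets empty
        train_sent.append([randint(1, len_vocab) for _ in range(len_sentence)])
        val_sent.append([randint(1, len_vocab) for _ in range(len_sentence)])
    data_train = mx.rnn.BucketSentenceIter(train_sent, batch_size, buckets=buckets,
                                           invalid_label=invalid_label)
    data_val = mx.rnn.BucketSentenceIter(val_sent, batch_size, buckets=buckets,
                                         invalid_label=invalid_label)
    return data_train, data_val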