def test_mx_pt_eq_ssru(hidden, inference_only, seq_len, batch):
    """Check that the MXNet SSRU block and its PyTorch port produce equal outputs."""
    pytest.importorskip("mxnet")
    from mxnet import np
    import sockeye.layers

    block_mx = sockeye.layers.SSRU(hidden, inference_only)
    block_mx.initialize()
    block_pt = sockeye.layers_pt.PyTorchSSRU(hidden, inference_only)
    block_pt.weights_from_mxnet_block(block_mx)

    inputs_mx = np.random.uniform(0, 1, (seq_len, batch, hidden))
    previous_states_mx = np.zeros((1, batch, hidden))
    inputs_pt = pt.as_tensor(inputs_mx.asnumpy())
    previous_states_pt = pt.as_tensor(previous_states_mx.asnumpy())

    out_mx, states_mx = block_mx(inputs_mx, previous_states_mx)
    out_pt, states_pt = block_pt(inputs_pt, previous_states_pt)

    assert np.allclose(out_mx.asnumpy(), out_pt.detach().numpy())
    assert np.allclose(states_mx.asnumpy(), states_pt.detach().numpy())
def test_mx_pt_eq_multi_head_self_attention(seq_len, batch_size, hidden, heads):
    """MXNet MultiHeadSelfAttention and its PyTorch port must agree on outputs and states."""
    pytest.importorskip("mxnet")
    from mxnet import np
    import sockeye.layers

    data_mx = np.random.uniform(0, 1, (seq_len, batch_size, hidden))
    data_pt = pt.as_tensor(data_mx.asnumpy())

    attn_mx = sockeye.layers.MultiHeadSelfAttention(hidden, heads, hidden, dropout=0.0)
    attn_mx.initialize()
    attn_pt = sockeye.layers_pt.PyTorchMultiHeadSelfAttention(hidden, heads, hidden, dropout=0.0)
    attn_pt.eval()
    attn_pt.weights_from_mxnet_block(attn_mx)

    out_mx, states_mx = attn_mx(data_mx, None, None, None)
    out_pt, states_pt = attn_pt(data_pt, previous_states=None, inputs_length_mask=None, bias=None)

    assert np.allclose(out_mx.asnumpy(), out_pt.detach().numpy(), atol=1e-06)
    assert np.allclose(states_mx.asnumpy(), states_pt.detach().numpy(), atol=1e-06)
def test_candidate_scorer():
    """CandidateScorer.unnormalize must invert the scorer for array and scalar inputs."""
    pytest.importorskip("mxnet")
    from mxnet import np
    import sockeye.beam_search

    scorer = sockeye.beam_search.CandidateScorer(length_penalty_alpha=1.0,
                                                 length_penalty_beta=0.0,
                                                 brevity_penalty_weight=0.1)
    scorer.initialize()
    scorer.hybridize(static_alloc=True)

    cases = [
        # np.array input
        (np.random.uniform(0, 1, (5,)), np.array([1, 2, 3, 4, 5]), np.array([2, 3, 4, 5, 6])),
        # int/float input
        (5.6, 3, 4),
    ]
    for raw_scores, lengths, reference_lengths in cases:
        scores = scorer(raw_scores, lengths, reference_lengths)
        unnormalized = scorer.unnormalize(scores, lengths, reference_lengths)
        assert np.allclose(unnormalized, raw_scores)
def test_mx_pt_eq_positional_embeddings(data_len, num_embed, scale_up_input, scale_down_positions, steps):
    """MXNet PositionalEmbeddings and the PyTorch port must produce equal outputs.

    Fix: the final ``np.allclose`` result was previously discarded, so the test
    could never fail; it is now asserted.
    """
    pytest.importorskip("mxnet")
    from mxnet import np
    import sockeye.layers
    max_seq_len = 10
    data_mx = np.random.uniform(0, 1, (2, data_len, num_embed))
    data_pt = pt.as_tensor(data_mx.asnumpy())
    if steps is None:
        steps_mx, steps_pt = None, None
    else:
        # steps are given per batch element: shape (batch, 1)
        steps_mx = np.array(steps).reshape((-1, 1))
        steps_pt = pt.as_tensor(steps).unsqueeze(1)
    b_mx = sockeye.layers.PositionalEmbeddings(weight_type='fixed',
                                               num_embed=num_embed,
                                               max_seq_len=max_seq_len,
                                               scale_up_input=scale_up_input,
                                               scale_down_positions=scale_down_positions,
                                               weight_init=None)
    b_mx.initialize()
    r_mx = b_mx(data_mx, steps_mx).asnumpy()
    b_pt = sockeye.layers_pt.PyTorchPositionalEmbeddings(weight_type='fixed',
                                                         num_embed=num_embed,
                                                         max_seq_len=max_seq_len,
                                                         scale_up_input=scale_up_input,
                                                         scale_down_positions=scale_down_positions)
    b_pt.weights_from_mxnet_block(b_mx)
    r_pt = b_pt(data_pt, steps_pt).detach().numpy()
    # BUGFIX: result of the comparison was silently discarded before
    assert np.allclose(r_mx, r_pt)
def _data_batches_equal(db1, db2) -> bool:
    """Return True iff two data batches hold numerically identical contents."""
    from mxnet import np
    # array-valued fields: compare numerically (short-circuits on first mismatch)
    arrays_close = all(
        np.allclose(a, b)
        for a, b in ((db1.source, db2.source),
                     (db1.source_length, db2.source_length),
                     (db1.target, db2.target),
                     (db1.target_length, db2.target_length)))
    # scalar / structural fields: plain equality
    return (arrays_close
            and db1.labels.keys() == db2.labels.keys()
            and db1.samples == db2.samples
            and db1.tokens == db2.tokens)
def test_length_penalty_default():
    """With alpha=1.0, beta=0.0 the length penalty is the identity on lengths."""
    pytest.importorskip("mxnet")
    from mxnet import np
    import sockeye.beam_search

    lengths = np.array([[1], [2], [3]])
    expected = np.array([[1.0], [2.], [3.]])
    penalty = sockeye.beam_search.LengthPenalty(1.0, 0.0)
    # check both the eager and the hybridized execution paths
    assert np.allclose(penalty(lengths), expected)
    penalty.hybridize()
    assert np.allclose(penalty(lengths), expected)
def test_t5_model(cfg_key, activation, ctx):
    """Build a small T5 model; check NT vs TN layout and partial-target consistency.

    NOTE(review): relies on module-level `np` and `T5Model` names imported
    elsewhere in this file.
    """
    with ctx:
        cfg = T5Model.get_cfg(cfg_key)
        cfg.defrost()
        # shrink the architecture so the test runs quickly
        cfg.MODEL.vocab_size = 256
        cfg.MODEL.d_model = 128
        cfg.MODEL.d_ff = 512
        cfg.MODEL.num_layers = 2
        cfg.MODEL.num_heads = 4
        cfg.MODEL.activation = activation
        cfg.MODEL.layout = 'NT'
        cfg.freeze()
        # clone the config with the time-major ('TN') layout
        cfg_tn = cfg.clone()
        cfg_tn.defrost()
        cfg_tn.MODEL.layout = 'TN'
        cfg_tn.freeze()
        # test TN and NT consistency
        t5_model = T5Model.from_cfg(cfg)
        t5_model.initialize()
        t5_model.hybridize()
        # TN model shares the NT model's parameters
        t5_model_tn = T5Model.from_cfg(cfg_tn)
        t5_model_tn.share_parameters(t5_model.collect_params())
        t5_model_tn.hybridize()

        batch_size = 8
        src_length = 32
        tgt_length = 18
        src_data = np.random.randint(0, 255, (batch_size, src_length))
        src_valid_length = np.random.randint(src_length // 2, src_length, (batch_size, ))
        tgt_data = np.random.randint(0, 255, (batch_size, tgt_length))
        tgt_valid_length = np.random.randint(tgt_length // 4, tgt_length, (batch_size, ))
        out = t5_model(src_data, src_valid_length, tgt_data, tgt_valid_length)
        # TN layout consumes transposed (time-major) data; outputs must match after axis swap
        out_tn = t5_model_tn(src_data.T, src_valid_length, tgt_data.T, tgt_valid_length)
        assert np.allclose(np.swapaxes(out, 0, 1), out_tn, 1E-5, 1E-5)
        # test consistency with various target valid length
        for shift in range(1, np.min(tgt_valid_length).item()):
            for partial_out in [
                    t5_model(src_data, src_valid_length, tgt_data[:, :-shift], tgt_valid_length - shift),
                    t5_model(src_data, src_valid_length, tgt_data, tgt_valid_length - shift)
            ]:
                for i in range(batch_size):
                    # compare only positions within the (shifted) valid length
                    vl = tgt_valid_length[i].item() - shift
                    assert np.allclose(partial_out[i, :vl], out[i, :vl], 1E-5, 1E-5)
def test_brevity_penalty():
    """Brevity penalty is weight * (1 - ref/hyp) for too-short hypotheses, else 0."""
    pytest.importorskip("mxnet")
    from mxnet import np
    import sockeye.beam_search

    hyp_lengths = np.array([[1], [2], [3]])
    ref_lengths = np.array([[7], [2], [91]])
    expected = np.array([[3.5 * (1 - 7 / 1)], [0.0], [3.5 * (1 - 91 / 3)]])
    penalty = sockeye.beam_search.BrevityPenalty(3.5)
    # check both the eager and the hybridized execution paths
    assert np.allclose(penalty(hyp_lengths, ref_lengths), expected)
    penalty.hybridize()
    assert np.allclose(penalty(hyp_lengths, ref_lengths), expected)
def test_length_penalty():
    """Length penalty follows ((beta + len) ** alpha) / ((beta + 1) ** alpha)."""
    pytest.importorskip("mxnet")
    from mxnet import np
    import sockeye.beam_search

    lengths = np.array([[1], [2], [3]])
    expected = np.array([[6**0.2 / 6**0.2], [7**0.2 / 6**0.2], [8**0.2 / 6**0.2]])
    penalty = sockeye.beam_search.LengthPenalty(.2, 5.0)
    # check both the eager and the hybridized execution paths
    assert np.allclose(penalty(lengths), expected)
    penalty.hybridize()
    assert np.allclose(penalty(lengths), expected)
def test_mx_pt_eq_length_ratio():
    """MXNet LengthRatio and its PyTorch port must produce equal predictions."""
    pytest.importorskip("mxnet")
    from mxnet import np
    import sockeye.layers

    hidden_size, seq_len, batch_size = 32, 10, 8
    num_layers = 1  # more layers seems to be numerically unstable

    encoded_mx = np.random.uniform(0, 1, (batch_size, seq_len, hidden_size))
    encoded_pt = pt.as_tensor(encoded_mx.asnumpy())
    lengths_mx = np.random.randint(1, seq_len, (batch_size, ), dtype='int32')
    lengths_pt = pt.as_tensor(lengths_mx.asnumpy())

    ratio_mx = sockeye.layers.LengthRatio(hidden_size=hidden_size, num_layers=num_layers)
    ratio_mx.initialize()
    out_mx = ratio_mx(encoded_mx, lengths_mx).asnumpy()

    ratio_pt = sockeye.layers_pt.PyTorchLengthRatio(hidden_size=hidden_size, num_layers=num_layers)
    ratio_pt.weights_from_mxnet_block(ratio_mx)
    out_pt = ratio_pt(encoded_pt, lengths_pt).detach().numpy()

    assert np.allclose(out_mx, out_pt)
def test_mx_pt_eq_multi_head_attention(qlen, kvlen, batch_size, hidden, heads):
    """MXNet MultiHeadAttention and its PyTorch port must produce equal outputs.

    Fix: removed leftover debug ``print`` statements that dumped weight rows to
    stdout on every parametrized run.
    """
    pytest.importorskip("mxnet")
    from mxnet import np
    import sockeye.layers
    queries_mx = np.random.uniform(0, 1, (qlen, batch_size, hidden))
    queries_pt = pt.as_tensor(queries_mx.asnumpy())
    memory_mx = np.random.uniform(0, 1, (kvlen, batch_size, hidden))
    memory_pt = pt.as_tensor(memory_mx.asnumpy())
    b_mx = sockeye.layers.MultiHeadAttention(hidden, heads, hidden, dropout=0.0)
    b_mx.initialize()
    r_mx = b_mx(queries_mx, memory_mx, None, None, None)
    b_pt = sockeye.layers_pt.PyTorchMultiHeadAttention(hidden, heads, hidden,
                                                       dropout=0.0,
                                                       depth_key_value=hidden)
    b_pt.weights_from_mxnet_block(b_mx)
    r_pt = b_pt(queries_pt, memory_pt, mask=None, projected_memory_kv=None)
    assert np.allclose(r_mx.asnumpy(), r_pt.detach().numpy(), atol=1e-06)
def test_mx_pt_eq_embedding(vocab_size, num_embed, factor_configs, sparse):
    """MXNet Embedding and its PyTorch port must produce identical embeddings."""
    pytest.importorskip("mxnet")
    import sockeye.encoder
    from mxnet import np

    config = sockeye.encoder.EmbeddingConfig(vocab_size=vocab_size,
                                             num_embed=num_embed,
                                             dropout=0,
                                             factor_configs=factor_configs,
                                             allow_sparse_grad=sparse)
    emb_mx = sockeye.encoder.Embedding(config, None, C.DTYPE_FP32)
    emb_mx.initialize()
    emb_pt = sockeye.encoder_pt.PyTorchEmbedding(config, None)
    emb_pt.weights_from_mxnet_block(emb_mx)

    batch, seq_len = 4, 10
    if factor_configs is not None:
        num_factors = len(factor_configs) + 1
    else:
        num_factors = 1
    # data_mx does not take into account different vocab sizes for factors
    data_mx = np.random.randint(0, vocab_size, (batch, seq_len, num_factors))
    data_pt = pt.as_tensor(data_mx.asnumpy())
    vl_mx = np.ones((1, ))  # not used
    vl_pt = pt.as_tensor(vl_mx.asnumpy())

    out_mx, _ = emb_mx(data_mx, vl_mx)
    out_pt = emb_pt(data_pt)
    assert np.allclose(out_mx.asnumpy(), out_pt.detach().numpy())
def test_conversion(args, hf_model, gluon_model): logging.info('testing conversion...') # create dummy input batch_size = 6 src_length = 128 tgt_length = 8 vocab_size = hf_model.shared.weight.shape[0] src_data = np.random.randint(1, vocab_size, (batch_size, src_length)) src_valid_length = np.random.randint(src_length // 2, src_length, (batch_size, )) tgt_data = np.random.randint(1, vocab_size, (batch_size, tgt_length)) tgt_valid_length = np.random.randint(tgt_length // 2, tgt_length, (batch_size, )) enc_attn_mask = npx.arange_like(src_data, axis=-1) < src_valid_length.reshape(-1, 1) dec_attn_mask = npx.arange_like(tgt_data, axis=-1) < tgt_valid_length.reshape(-1, 1) # test T5Model forward pass hf_model.eval() # disable dropout hf_out = hf_model( input_ids=torch.from_numpy(src_data.asnumpy()), attention_mask=torch.from_numpy(enc_attn_mask.asnumpy()), decoder_input_ids=torch.from_numpy(tgt_data.asnumpy()), decoder_attention_mask=torch.from_numpy( dec_attn_mask.asnumpy()))['last_hidden_state'].detach().numpy() gl_out = gluon_model(src_data, src_valid_length, tgt_data, tgt_valid_length) for i in range(batch_size): assert np.allclose(hf_out[i, :tgt_valid_length[i].item(), :], gl_out[i, :tgt_valid_length[i].item(), :], 1E-3, 1E-3) logging.info('pass')
def test_brevity_penalty_default():
    """With weight 0.0 the brevity penalty is always zero (any expected shape)."""
    pytest.importorskip("mxnet")
    from mxnet import np
    import sockeye.beam_search

    hyp_lengths = np.array([[1], [2], [3]])
    ref_lengths = np.array([[2], [3], [2]])
    penalty = sockeye.beam_search.BrevityPenalty(0.0)
    expected_2d = np.array([[0.0], [0.0], [0.0]])
    expected_1d = np.array([0.0, 0.0, 0.0])
    # eager path: zero regardless of the expected array's shape
    assert np.allclose(penalty(hyp_lengths, ref_lengths), expected_2d)
    assert np.allclose(penalty(hyp_lengths, ref_lengths), expected_1d)
    # hybridized path
    penalty.hybridize()
    assert np.allclose(penalty(hyp_lengths, ref_lengths), expected_2d)
def test_mx_pt_eq_transformer_encoder():
    """MXNet and PyTorch transformer encoders must produce (nearly) equal outputs."""
    pytest.importorskip("mxnet")
    import sockeye.transformer
    import sockeye.encoder
    import mxnet as mx
    from mxnet import np
    # fix seeds in both frameworks for deterministic initialization
    pt.manual_seed(13)
    mx.random.seed(13)
    config = sockeye.transformer.TransformerConfig(
        model_size=128,
        attention_heads=8,
        feed_forward_num_hidden=256,
        act_type='relu',
        num_layers=12,
        dropout_attention=0,
        dropout_act=0,
        dropout_prepost=0,
        positional_embedding_type=C.LEARNED_POSITIONAL_EMBEDDING,
        preprocess_sequence='n',
        postprocess_sequence='r',
        max_seq_len_source=50,
        max_seq_len_target=60,
        use_lhuc=False)
    encoder_mx = sockeye.encoder.get_transformer_encoder(config, dtype=C.DTYPE_FP32)
    encoder_mx.initialize()
    encoder_pt = sockeye.encoder_pt.pytorch_get_transformer_encoder(config)
    # copy the (initialized) mxnet weights into the pytorch encoder
    encoder_pt.weights_from_mxnet_block(encoder_mx)
    batch = 12
    seq_len = 45
    data_mx = np.random.uniform(0, 1, (batch, seq_len, config.model_size))
    data_pt = pt.as_tensor(data_mx.asnumpy())
    lengths_mx = np.random.randint(1, seq_len, (batch, ))
    lengths_pt = pt.as_tensor(lengths_mx.asnumpy())
    r1_mx, r2_mx = encoder_mx(data_mx, lengths_mx)
    r1_pt, r2_pt = encoder_pt(data_pt, lengths_pt)
    r1_mx, r2_mx = r1_mx.asnumpy(), r2_mx.asnumpy()
    r1_pt, r2_pt = r1_pt.detach().numpy(), r2_pt.detach().numpy()
    print("Max deviation:", onp.abs(r1_mx - r1_pt).max())
    # deep (12-layer) stack accumulates float error; compare with a loose atol
    assert np.allclose(r1_mx, r1_pt, atol=1e-04)
    assert np.allclose(r2_mx, r2_pt, atol=1e-04)
def test_mx_pt_eq_cross_entropy_loss(logits, labels, weight, alpha):
    """Cross-entropy loss value, sample count, and input gradients must match MX vs PT."""
    mxnet = pytest.importorskip("mxnet")
    from mxnet import np
    import sockeye.loss
    logits_mx = np.array(logits)
    logits_mx.attach_grad()  # request gradients w.r.t. the logits
    labels_mx = np.array(labels)
    logits_pt = pt.tensor(logits, requires_grad=True, dtype=pt.float32)
    labels_pt = pt.tensor(labels)
    num_labels = logits_mx.shape[-1]
    loss_mx = sockeye.loss.CrossEntropyLossWithoutSoftmaxOutput(
        ignore_label=C.PAD_ID,
        label_smoothing=alpha,
        num_labels=num_labels,
        weight=weight)
    loss_mx.initialize()
    loss_pt = sockeye.loss_pt.PyTorchCrossEntropyLoss(ignore_label=C.PAD_ID,
                                                      label_smoothing=alpha,
                                                      weight=weight)
    # the loss blocks take dicts of outputs/labels; extra entries are ignored
    with mxnet.autograd.record():
        loss_value_mx, loss_samples_mx = loss_mx(
            {
                C.LOGITS_NAME: logits_mx,
                'other_stuff': None
            }, {
                C.TARGET_LABEL_NAME: labels_mx,
                'other_stuff': None
            })
    loss_value_mx.backward()
    loss_value_pt, loss_samples_pt = loss_pt(
        {
            C.LOGITS_NAME: logits_pt,
            'other_stuff': None
        }, {
            C.TARGET_LABEL_NAME: labels_pt,
            'other_stuff': None
        })
    loss_value_pt.backward()
    assert np.allclose(loss_value_mx.asnumpy(), loss_value_pt.detach().numpy())
    assert loss_samples_mx.item() == loss_samples_pt.detach().numpy()
    # gradients w.r.t. the logits must match as well
    assert np.allclose(logits_mx.grad.asnumpy(), logits_pt.grad.numpy())
def test_mx_pt_eq_interleaved_matmul_encdec_qk(qlen, kvlen, batch_size):
    """npx.interleaved_matmul_encdec_qk and its PyTorch reimplementation must match."""
    pytest.importorskip("mxnet")
    from mxnet import np, npx
    import sockeye.layers

    hidden, heads = 32, 4
    queries_mx = np.random.uniform(0, 1, (qlen, batch_size, hidden))
    keyvalues_mx = np.random.uniform(0, 1, (kvlen, batch_size, hidden * 2))
    queries_pt = pt.as_tensor(queries_mx.asnumpy())
    keyvalues_pt = pt.as_tensor(keyvalues_mx.asnumpy())

    # sanity: the tensor round-trip preserved the values
    assert np.allclose(queries_pt.numpy(), queries_mx.asnumpy())
    assert np.allclose(keyvalues_pt.numpy(), keyvalues_mx.asnumpy())

    expected = npx.interleaved_matmul_encdec_qk(queries_mx, keyvalues_mx, heads=heads).asnumpy()
    actual = sockeye.layers_pt.pytorch_interleaved_matmul_encdec_qk(
        queries_pt, keyvalues_pt, heads=heads).detach().numpy()
    assert np.allclose(expected, actual)
def test_topk_func(batch_size, beam_size, target_vocab_size):
    """TopK must match the numpy reference, both eager and hybridized."""
    pytest.importorskip("mxnet")
    from mxnet import np
    import sockeye.beam_search

    # Random model scores. Shape: (batch_size * beam_size, target_vocab_size)
    scores = np.random.uniform(0, 1, (batch_size * beam_size, target_vocab_size))
    # offset for batch sizes > 1
    offset = np.repeat(np.arange(0, batch_size * beam_size, beam_size, dtype='int32'), beam_size)

    np_hyp, np_word, np_values = numpy_topk(scores.asnumpy(), k=beam_size, offset=offset)

    topk = sockeye.beam_search.TopK(k=beam_size)
    topk.initialize()
    # check the imperative path first, then the hybridized path
    for hybridize in (False, True):
        if hybridize:
            topk.hybridize()
        mx_hyp, mx_word, mx_values = topk(scores, offset)
        assert np.allclose(mx_hyp, np_hyp)
        assert np.allclose(mx_word, np_word)
        assert np.allclose(mx_values, np_values)
def test_mx_pt_eq_output_layer():
    """MXNet OutputLayer and its PyTorch port must agree, with and without vocab slicing.

    Fix: the last assertion used ``onp.allclose`` while every other comparison in
    this function uses ``np.allclose`` on the same kind of numpy arrays; unified
    for consistency.
    """
    pytest.importorskip("mxnet")
    from mxnet import np
    import sockeye.layers
    num_hidden = 32
    vocab_size = 64
    data_mx = np.random.uniform(0, 1, (2, 10, num_hidden))
    data_pt = pt.as_tensor(data_mx.asnumpy())
    vocab_slice_ids_mx = np.array([4, 7, 23])
    vocab_slice_ids_pt = pt.tensor([4, 7, 23])
    b_mx = sockeye.layers.OutputLayer(num_hidden, vocab_size)
    b_mx.initialize()
    b_pt = sockeye.layers_pt.PyTorchOutputLayer(num_hidden, vocab_size)
    b_pt.weights_from_mxnet_block(b_mx)
    assert b_pt.weight.size() == (vocab_size, num_hidden)
    # full-vocabulary projection
    out_mx = b_mx(data_mx, None)
    assert out_mx.shape == (2, 10, vocab_size)
    out_pt = b_pt(data_pt, None)
    assert out_pt.shape == (2, 10, vocab_size)
    assert np.allclose(out_mx.asnumpy(), out_pt.detach().numpy(), atol=1e-06)
    # manually slicing the full output must match on both sides
    reduced_out_mx = out_mx.take(vocab_slice_ids_mx, axis=-1).asnumpy()
    reduced_out_pt = pt.index_select(out_pt, 2, vocab_slice_ids_pt).detach().numpy()
    assert np.allclose(reduced_out_mx, reduced_out_pt, atol=1e-06)
    # restricted projection via vocab_slice_ids
    out_restricted_mx = b_mx(data_mx, vocab_slice_ids_mx).asnumpy()
    out_restricted_pt = b_pt(data_pt, vocab_slice_ids_pt).detach().numpy()
    assert out_restricted_mx.shape == (2, 10, len(vocab_slice_ids_mx))
    assert out_restricted_pt.shape == (2, 10, len(vocab_slice_ids_pt))
    # consistency fix: was onp.allclose, unlike every other assert here
    assert np.allclose(out_restricted_mx, out_restricted_pt, atol=1e-06)
def test_mx_pt_eq_interleaved_matmul_encdec_valatt(qlen, kvlen, batch_size):
    """npx.interleaved_matmul_encdec_valatt and its PyTorch port must match."""
    pytest.importorskip("mxnet")
    from mxnet import np, npx
    import sockeye.layers

    hidden, heads = 32, 4
    keyvalues_mx = np.random.uniform(0, 1, (kvlen, batch_size, hidden * 2))
    keyvalues_pt = pt.as_tensor(keyvalues_mx.asnumpy())
    # attention weights: one (qlen, kvlen) matrix per batch element and head
    attention_mx = np.random.uniform(0, 1, (batch_size * heads, qlen, kvlen))
    attention_pt = pt.as_tensor(attention_mx.asnumpy())

    expected = npx.interleaved_matmul_encdec_valatt(keyvalues_mx, attention_mx, heads=heads).asnumpy()
    actual = sockeye.layers_pt.pytorch_interleaved_matmul_encdec_valatt(
        keyvalues_pt, attention_pt, heads=heads).numpy()
    assert np.allclose(expected, actual)
def test_create_target_and_shifted_label_sequences():
    """Target/label shifting drops one position and keeps per-sequence lengths."""
    pytest.importorskip('mxnet')
    from sockeye import data_io
    from mxnet import np

    sequences = np.array([[C.BOS_ID, 4, 17, 35, 12, C.EOS_ID, C.PAD_ID, C.PAD_ID],
                          [C.BOS_ID, 15, 23, 23, 77, 55, 22, C.EOS_ID],
                          [C.BOS_ID, 4, C.EOS_ID, C.PAD_ID, C.PAD_ID, C.PAD_ID, C.PAD_ID, C.PAD_ID]])
    sequences = np.expand_dims(sequences, axis=2)
    expected_lengths = np.array([5, 7, 2])

    target, label = data_io.create_target_and_shifted_label_sequences(sequences)

    # same batch size; one time step shorter than the input
    assert target.shape[0] == label.shape[0] == sequences.shape[0]
    assert target.shape[1] == label.shape[1] == sequences.shape[1] - 1
    actual_lengths = (target != C.PAD_ID).sum(axis=1).squeeze()
    assert np.allclose(actual_lengths, expected_lengths)
def test_mx_pt_eq_weight_normalization():
    """MXNet WeightNormalization and its PyTorch port must produce equal results."""
    pytest.importorskip("mxnet")
    from mxnet import np
    import sockeye.layers

    num_hidden = 3
    raw_weight_mx = np.random.uniform(0, 1, size=(num_hidden, 4))
    raw_weight_pt = pt.as_tensor(raw_weight_mx.asnumpy())

    norm_mx = sockeye.layers.WeightNormalization(num_hidden=num_hidden)
    norm_mx.initialize()
    norm_pt = sockeye.layers_pt.PyTorchWeightNormalization(num_hidden=num_hidden)

    expected = norm_mx(raw_weight_mx).asnumpy()
    actual = norm_pt(raw_weight_pt).detach().numpy()
    assert np.allclose(expected, actual)
def test_mx_pt_eq_dot_attention_cell(qlen, kvlen, batch_size, hidden, heads):
    """DotAttentionCell must match between MXNet and PyTorch for both mask types."""
    pytest.importorskip("mxnet")
    from mxnet import np
    import sockeye.layers
    import sockeye.transformer
    import sockeye.transformer_pt
    q_mx = np.random.uniform(0, 1, (qlen, batch_size, hidden))
    # keys and values packed along the last dimension (hidden * 2)
    kv_mx = np.random.uniform(0, 1, (kvlen, batch_size, hidden * 2))
    q_pt = pt.as_tensor(q_mx.asnumpy())
    kv_pt = pt.as_tensor(kv_mx.asnumpy())
    if qlen == kvlen:  # self-attention case
        bias_mx = sockeye.transformer.AutoRegressiveBias()
        bias_mx.initialize()
        autoregr_mx = bias_mx(q_mx.transpose(1, 0, 2))
        mx_args = (None, autoregr_mx)  # no source mask, autoregr mask
        mask_pt = sockeye.transformer_pt.AutoRegressiveMask()
        att_mask_pt = mask_pt(q_pt.permute(1, 0, 2))
    else:  # cross-attention
        lengths_mx = np.random.randint(
            1,
            kvlen,
            (batch_size, ),
        )
        valid_lengths_mx = sockeye.layers.prepare_source_valid_lengths(
            lengths_mx, q_mx.transpose(1, 0, 2), heads)
        mx_args = (valid_lengths_mx, None)  # source mask, no autoregr mask
        lengths_pt = pt.tensor(lengths_mx.asnumpy())
        att_mask_pt = sockeye.layers_pt.prepare_source_length_mask(
            lengths_pt, heads, kvlen)
    b_mx = sockeye.layers.DotAttentionCell(dropout=0.0)
    b_mx.initialize()
    b_pt = sockeye.layers_pt.PyTorchDotAttentionCell(dropout=0.0, heads=heads)
    r_mx = b_mx(q_mx, kv_mx, heads, *mx_args).asnumpy()
    r_pt = b_pt(q_pt, kv_pt, mask=att_mask_pt).detach().numpy()
    assert np.allclose(r_mx, r_pt, atol=1e-06)
def test_mx_pt_eq_multi_head_attention_base(qlen, kvlen, batch_size, hidden, heads):
    """_attend of MultiHeadAttentionBase must match between MXNet and PyTorch."""
    pytest.importorskip("mxnet")
    from mxnet import np
    import sockeye.layers

    queries_mx = np.random.uniform(0, 1, (qlen, batch_size, hidden))
    keyvalues_mx = np.random.uniform(0, 1, (kvlen, batch_size, hidden * 2))
    queries_pt = pt.as_tensor(queries_mx.asnumpy())
    keyvalues_pt = pt.as_tensor(keyvalues_mx.asnumpy())

    attn_mx = sockeye.layers.MultiHeadAttentionBase(hidden, heads, hidden)
    attn_mx.initialize()
    attn_pt = sockeye.layers_pt.PyTorchMultiHeadAttentionBase(hidden, heads, hidden)
    # use mxnet parameter initializations for pytorch block
    attn_pt.ff_out.weight.data[:] = pt.as_tensor(attn_mx.ff_out.weight.data().asnumpy())

    expected = attn_mx._attend(queries_mx, keyvalues_mx, None, None).asnumpy()
    actual = attn_pt._attend(queries_pt, keyvalues_pt, mask=None).detach().numpy()
    assert np.allclose(expected, actual, atol=1e-06)
def test_mx_pt_eq_transformer_decoder(inference_only):
    """MXNet and PyTorch transformer decoders must agree on outputs and states.

    In ``inference_only`` mode the decoder consumes one step at a time, so a
    second decode step is run to exercise the cached-state path.
    """
    pytest.importorskip("mxnet")
    import sockeye.transformer
    import sockeye.decoder
    import mxnet as mx
    from mxnet import np
    # fix seeds in both frameworks for deterministic initialization
    pt.manual_seed(13)
    mx.random.seed(13)
    config_mx = sockeye.transformer.TransformerConfig(model_size=128,
                                                      attention_heads=8,
                                                      feed_forward_num_hidden=256,
                                                      act_type='relu',
                                                      num_layers=12,
                                                      dropout_attention=0,
                                                      dropout_act=0,
                                                      dropout_prepost=0,
                                                      positional_embedding_type=C.FIXED_POSITIONAL_EMBEDDING,
                                                      preprocess_sequence='n',
                                                      postprocess_sequence='r',
                                                      max_seq_len_source=50,
                                                      max_seq_len_target=60,
                                                      depth_key_value=128,
                                                      use_lhuc=False)
    # identical config for the pytorch decoder
    config_pt = sockeye.transformer_pt.TransformerConfig(model_size=128,
                                                         attention_heads=8,
                                                         feed_forward_num_hidden=256,
                                                         act_type='relu',
                                                         num_layers=12,
                                                         dropout_attention=0,
                                                         dropout_act=0,
                                                         dropout_prepost=0,
                                                         positional_embedding_type=C.FIXED_POSITIONAL_EMBEDDING,
                                                         preprocess_sequence='n',
                                                         postprocess_sequence='r',
                                                         max_seq_len_source=50,
                                                         max_seq_len_target=60,
                                                         depth_key_value=128,
                                                         use_lhuc=False)
    batch = 12
    encoder_seq_len = 45
    # inference decodes a single step at a time
    decoder_seq_len = 39 if not inference_only else 1
    encoder_outputs_mx = np.random.uniform(0, 1, (batch, encoder_seq_len, config_mx.model_size))
    encoder_outputs_pt = pt.tensor(encoder_outputs_mx.asnumpy())
    encoder_valid_length_mx = np.random.randint(1, encoder_seq_len, (batch,))
    encoder_valid_length_pt = pt.tensor(encoder_valid_length_mx.asnumpy())
    inputs_mx = np.random.uniform(0, 1, (batch, decoder_seq_len, config_mx.model_size))
    inputs_pt = pt.tensor(inputs_mx.asnumpy())
    # mx
    decoder_mx = sockeye.decoder.get_decoder(config_mx, inference_only=inference_only, dtype=C.DTYPE_FP32)
    decoder_mx.initialize()
    init_states_mx = decoder_mx.init_state_from_encoder(encoder_outputs_mx, encoder_valid_length_mx)
    output_mx, new_states_mx = decoder_mx(inputs_mx, init_states_mx)
    if inference_only:
        # do a second decoder step
        output_mx, new_states_mx = decoder_mx(output_mx, new_states_mx)
    # pt
    decoder_pt = sockeye.decoder_pt.pytorch_get_decoder(config_pt, inference_only=inference_only)
    decoder_pt.weights_from_mxnet_block(decoder_mx)
    decoder_pt.eval()
    init_states_pt = decoder_pt.init_state_from_encoder(encoder_outputs_pt, encoder_valid_length_pt)
    output_pt, new_states_pt = decoder_pt(inputs_pt, init_states_pt)
    if inference_only:
        # do a second decoder step
        output_pt, new_states_pt = decoder_pt(output_pt, new_states_pt)
    assert decoder_mx.state_structure() == decoder_pt.state_structure()
    assert decoder_mx.get_num_hidden() == decoder_pt.get_num_hidden()
    assert len(init_states_mx) == len(init_states_pt)
    for s_mx, s_pt, structure in zip(init_states_mx, init_states_pt, decoder_mx.state_structure()):
        if structure != C.MASK_STATE:  # MASK state is new in Pytorch and not equivalent
            assert np.allclose(s_mx.asnumpy(), s_pt.detach().numpy(), atol=1e-05)
    output_mx = output_mx.asnumpy()
    output_pt = output_pt.detach().numpy()
    print("Max deviation:", onp.abs(output_mx - output_pt).max())
    assert np.allclose(output_mx, output_pt, atol=1e-05)
    assert len(new_states_mx) == len(new_states_pt)
    for i, (s_mx, s_pt, structure) in enumerate(zip(new_states_mx, new_states_pt, decoder_mx.state_structure())):
        if structure != C.MASK_STATE:  # MASK state is new in Pytorch and not equivalent
            assert np.allclose(s_mx.asnumpy(), s_pt.detach().numpy(), atol=1e-05)
def test_mx_pt_eq_sockeye_model():
    """End-to-end equivalence of the MXNet SockeyeModel and its PyTorch port.

    Covers forward(), encode(), encode_and_initialize(), decode_step(), and
    parameter save/load round-tripping.
    """
    pytest.importorskip('mxnet')
    from mxnet import np
    import sockeye.transformer
    import sockeye.encoder
    import sockeye.model
    # model setup
    source_vocab_size = target_vocab_size = 32000
    num_embed_source = num_embed_target = model_size = 512
    max_seq_len_source = max_seq_len_target = 100
    num_source_factors = 1
    num_target_factors = 1
    num_layers = 4
    weight_tying = False
    batch_size = 4
    topk_size = 200
    config_encoder = sockeye.transformer.TransformerConfig(
        model_size=model_size,
        attention_heads=8,
        feed_forward_num_hidden=256,
        act_type='relu',
        num_layers=num_layers,
        dropout_attention=0,
        dropout_act=0,
        dropout_prepost=0,
        positional_embedding_type=C.FIXED_POSITIONAL_EMBEDDING,
        preprocess_sequence='n',
        postprocess_sequence='r',
        max_seq_len_source=max_seq_len_source,
        max_seq_len_target=max_seq_len_target,
        use_lhuc=False)
    # identical encoder config for the pytorch side
    config_encoder_pt = sockeye.transformer_pt.TransformerConfig(
        model_size=model_size,
        attention_heads=8,
        feed_forward_num_hidden=256,
        act_type='relu',
        num_layers=num_layers,
        dropout_attention=0,
        dropout_act=0,
        dropout_prepost=0,
        positional_embedding_type=C.FIXED_POSITIONAL_EMBEDDING,
        preprocess_sequence='n',
        postprocess_sequence='r',
        max_seq_len_source=max_seq_len_source,
        max_seq_len_target=max_seq_len_target,
        use_lhuc=False)
    config_decoder = sockeye.transformer.TransformerConfig(
        model_size=model_size,
        attention_heads=8,
        feed_forward_num_hidden=256,
        act_type='relu',
        num_layers=num_layers,
        dropout_attention=0,
        dropout_act=0,
        dropout_prepost=0,
        positional_embedding_type=C.FIXED_POSITIONAL_EMBEDDING,
        preprocess_sequence='n',
        postprocess_sequence='r',
        max_seq_len_source=max_seq_len_source,
        max_seq_len_target=max_seq_len_target,
        depth_key_value=model_size,
        use_lhuc=False)
    # identical decoder config for the pytorch side
    config_decoder_pt = sockeye.transformer_pt.TransformerConfig(
        model_size=model_size,
        attention_heads=8,
        feed_forward_num_hidden=256,
        act_type='relu',
        num_layers=num_layers,
        dropout_attention=0,
        dropout_act=0,
        dropout_prepost=0,
        positional_embedding_type=C.FIXED_POSITIONAL_EMBEDDING,
        preprocess_sequence='n',
        postprocess_sequence='r',
        max_seq_len_source=max_seq_len_source,
        max_seq_len_target=max_seq_len_target,
        depth_key_value=model_size,
        use_lhuc=False)
    config_embed_source = sockeye.encoder.EmbeddingConfig(
        vocab_size=source_vocab_size,
        num_embed=num_embed_source,
        dropout=0,
        factor_configs=None,
        allow_sparse_grad=False)
    config_embed_target = sockeye.encoder.EmbeddingConfig(
        vocab_size=target_vocab_size,
        num_embed=num_embed_target,
        dropout=0,
        factor_configs=None,
        allow_sparse_grad=False)
    # dummy data statistics/config; only vocab sizes and lengths matter here
    data_statistics = sockeye.data_io_pt.DataStatistics(
        num_sents=0,
        num_discarded=0,
        num_tokens_source=0,
        num_tokens_target=0,
        num_unks_source=0,
        num_unks_target=0,
        max_observed_len_source=100,
        max_observed_len_target=100,
        size_vocab_source=source_vocab_size,
        size_vocab_target=target_vocab_size,
        length_ratio_mean=1.0,
        length_ratio_std=0.001,
        buckets=[],
        num_sents_per_bucket=[],
        average_len_target_per_bucket=[],
        length_ratio_stats_per_bucket=None)
    data_config = sockeye.data_io_pt.DataConfig(
        data_statistics=data_statistics,
        max_seq_len_source=max_seq_len_source,
        max_seq_len_target=max_seq_len_target,
        num_source_factors=num_source_factors,
        num_target_factors=num_target_factors)
    config_length_task = None
    model_config = sockeye.model.ModelConfig(
        config_data=data_config,
        vocab_source_size=source_vocab_size,
        vocab_target_size=target_vocab_size,
        config_embed_source=config_embed_source,
        config_embed_target=config_embed_target,
        config_encoder=config_encoder,
        config_decoder=config_decoder,
        config_length_task=config_length_task,
        weight_tying_type=C.WEIGHT_TYING_NONE,
        lhuc=False,
        dtype=C.DTYPE_FP32)
    model_config_pt = sockeye.model.ModelConfig(
        config_data=data_config,
        vocab_source_size=source_vocab_size,
        vocab_target_size=target_vocab_size,
        config_embed_source=config_embed_source,
        config_embed_target=config_embed_target,
        config_encoder=config_encoder_pt,
        config_decoder=config_decoder_pt,
        config_length_task=config_length_task,
        weight_tying_type=C.WEIGHT_TYING_NONE,
        lhuc=False,
        dtype=C.DTYPE_FP32)
    # inputs
    source_inputs_mx = np.random.randint(
        0, max_seq_len_source,
        (batch_size, max_seq_len_source, num_source_factors))
    source_input_lengths_mx = np.random.randint(0, max_seq_len_source, (batch_size, ))
    target_inputs_mx = np.random.randint(
        0, max_seq_len_target,
        (batch_size, max_seq_len_target, num_source_factors))
    target_input_lengths_mx = np.random.randint(0, max_seq_len_target, (batch_size, ))
    source_inputs_pt = pt.tensor(source_inputs_mx.asnumpy())
    source_input_lengths_pt = pt.tensor(source_input_lengths_mx.asnumpy())
    target_inputs_pt = pt.tensor(target_inputs_mx.asnumpy())
    target_input_lengths_pt = pt.tensor(target_input_lengths_mx.asnumpy())
    # single-step decode inputs and a restricted output vocabulary
    step_inputs_mx = np.random.randint(0, target_vocab_size, (batch_size, num_target_factors))
    vocab_slice_ids_mx = np.random.randint(0, target_vocab_size, (topk_size, ))
    step_inputs_pt = pt.tensor(step_inputs_mx.asnumpy())
    vocab_slice_ids_pt = pt.tensor(vocab_slice_ids_mx.asnumpy())
    b_mx = sockeye.model.SockeyeModel(model_config,
                                      inference_only=False,
                                      mc_dropout=False,
                                      forward_pass_cache_size=0)
    b_mx.initialize()
    b_pt = sockeye.model_pt.PyTorchSockeyeModel(model_config_pt,
                                                inference_only=False,
                                                mc_dropout=False,
                                                forward_pass_cache_size=0)
    assert b_mx.state_structure() == b_pt.state_structure()
    # test forward()
    # first run mx block to complete deferred initialization
    forward_dict_mx = b_mx(source_inputs_mx, source_input_lengths_mx,
                           target_inputs_mx, target_input_lengths_mx)
    # get weights from mx into pt
    b_pt.weights_from_mxnet_block(b_mx)
    forward_dict_pt = b_pt(source_inputs_pt, source_input_lengths_pt,
                           target_inputs_pt, target_input_lengths_pt)
    assert forward_dict_mx.keys() == forward_dict_pt.keys()
    logits_mx = forward_dict_mx[C.LOGITS_NAME].asnumpy()
    logits_pt = forward_dict_pt[C.LOGITS_NAME].detach().numpy()
    assert np.allclose(logits_mx, logits_pt, atol=1e-05)
    # test encode()
    source_encoded_mx, source_encoded_length_mx = b_mx.encode(
        source_inputs_mx, source_input_lengths_mx)
    source_encoded_pt, source_encoded_length_pt = b_pt.encode(
        source_inputs_pt, source_input_lengths_pt)
    assert np.allclose(source_encoded_mx.asnumpy(),
                       source_encoded_pt.detach().numpy(), atol=1e-05)
    assert np.allclose(source_encoded_length_mx.asnumpy(),
                       source_encoded_length_pt.detach().numpy(), atol=1e-05)
    # test encode_and_initialize()
    init_states_mx, pred_out_length_mx = b_mx.encode_and_initialize(
        source_inputs_mx, source_input_lengths_mx, constant_length_ratio=0.0)
    init_states_pt, pred_out_length_pt = b_pt.encode_and_initialize(
        source_inputs_pt, source_input_lengths_pt, constant_length_ratio=0.0)
    if config_length_task is None:
        # without a length task, the predicted output length is all zeros
        assert np.allclose(pred_out_length_mx.asnumpy(),
                           np.zeros_like(source_input_lengths_mx).asnumpy())
        assert np.allclose(
            pred_out_length_pt.detach().numpy(),
            pt.zeros_like(source_input_lengths_pt).detach().numpy())
    else:
        assert pred_out_length_mx.asnumpy() == pred_out_length_pt.detach().numpy()
    assert len(init_states_mx) == len(init_states_pt)
    state_structure = b_pt.decoder.state_structure()
    for s_mx, s_pt, structure in zip(init_states_mx, init_states_pt, state_structure):
        if structure != C.MASK_STATE:  # MASK state is new in Pytorch and not equivalent
            assert np.allclose(s_mx.asnumpy(), s_pt.detach().numpy(), atol=1e-05)
    # test decode_step()
    b_pt.eval()
    states_mx = init_states_mx
    states_pt = init_states_pt
    step_output_mx, states_mx, factor_outputs_mx = b_mx.decode_step(
        step_inputs_mx, states_mx, vocab_slice_ids=vocab_slice_ids_mx)
    step_output_pt, states_pt, factor_outputs_pt = b_pt.decode_step(
        step_inputs_pt, states_pt, vocab_slice_ids=vocab_slice_ids_pt)
    assert np.allclose(step_output_mx.asnumpy(),
                       step_output_pt.detach().numpy(), atol=1e-05)
    assert step_output_mx.asnumpy().shape == step_output_pt.detach().numpy(
    ).shape == (batch_size, topk_size)
    assert len(factor_outputs_mx) == len(factor_outputs_pt)
    # TODO assert factor outputs equality
    assert len(states_mx) == len(states_pt)
    for s_mx, s_pt, structure in zip(states_mx, states_pt, state_structure):
        if structure != C.MASK_STATE:  # MASK state is new in Pytorch and not equivalent
            assert np.allclose(s_mx.asnumpy(), s_pt.detach().numpy(), atol=1e-05)
    from pprint import pprint
    pprint(b_mx.collect_params())
    for param_tensor in b_pt.state_dict():
        print(param_tensor, "\t", b_pt.state_dict()[param_tensor].size())
    # save & load parameters
    with TemporaryDirectory() as work_dir:
        fname = os.path.join(work_dir, 'params.pt')
        b_pt.save_parameters(fname)
        b_pt.load_parameters(fname)
    # logits must be unchanged after the save/load round-trip
    forward_dict_pt = b_pt(source_inputs_pt, source_input_lengths_pt,
                           target_inputs_pt, target_input_lengths_pt)
    assert forward_dict_mx.keys() == forward_dict_pt.keys()
    logits_mx = forward_dict_mx[C.LOGITS_NAME].asnumpy()
    logits_pt = forward_dict_pt[C.LOGITS_NAME].detach().numpy()
    assert np.allclose(logits_mx, logits_pt, atol=1e-05)