def test_len_reorder():
    vals_to_reorder = (torch.Tensor([[1, 1, 2, 2], [3, 4, 4, 3], [5, 6, 7, 8]]),
                       torch.randn(3, 4, 3))
    sorted_lens, sorting_inds, vals_after_reorder = reorder_based_off_len(
        input_lens=torch.LongTensor([2, 4, 1]),
        vals_to_reorder=vals_to_reorder)
    assert torch_epsilon_eq(sorted_lens, torch.LongTensor([4, 2, 1]))
    assert torch_epsilon_eq(sorting_inds, torch.LongTensor([1, 0, 2]))
    assert torch_epsilon_eq(
        vals_after_reorder[0],
        torch.Tensor([[3, 4, 4, 3], [1, 1, 2, 2], [5, 6, 7, 8]]))
    put_back_together = undo_len_ordering(sorting_inds, vals_after_reorder)
    assert torch_epsilon_eq(put_back_together[0], vals_to_reorder[0])
    assert torch_epsilon_eq(put_back_together[1], vals_to_reorder[1])

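# For orientation, a minimal sketch of the sort/unsort round trip the test
# above pins down, assuming a descending sort on lengths whose permutation
# is inverted with argsort (hypothetical helpers; the real implementations
# may differ):
def _sketch_reorder_based_off_len(input_lens, vals_to_reorder):
    sorted_lens, sorting_inds = input_lens.sort(descending=True)
    return sorted_lens, sorting_inds, tuple(v[sorting_inds] for v in vals_to_reorder)


def _sketch_undo_len_ordering(sorting_inds, vals):
    unsort_inds = sorting_inds.argsort()  # inverse permutation
    return tuple(v[unsort_inds] for v in vals)
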
def test_latent_store():
    instance = TorchLatentStore(
        type_ind_to_latents=[
            LatentDataWrapper.construct(
                latents=torch.Tensor([[1, 2], [3, 1], [-3, 1], [2, -1]]),
                example_indxs=torch.LongTensor([0, 1, 2, 3]),
                y_inds=torch.LongTensor([2, 2, 2, 2]),
                impl_choices=torch.LongTensor([3, 1, 9, 1]))
        ],
        example_to_depths_to_ind_in_types=None)
    data, similarities = instance.get_n_nearest_latents(torch.Tensor([1, 1]), 0)
    assert torch.all(data.example_indxs == torch.LongTensor([1, 0, 3, 2]))
    assert torch.all(data.impl_choices == torch.LongTensor([1, 3, 1, 9]))
    assert torch_epsilon_eq(similarities, torch.Tensor([4, 3, 1, -2]))

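# The ordering asserted above is consistent with ranking stored latents by
# dot-product similarity to the query, as in this sketch (an observation
# drawn from the test data, not necessarily the store's actual metric):
def _sketch_n_nearest_by_dot(query, latents):
    sims = latents @ query  # here [3, 4, -2, 1]
    sims_sorted, order = sims.sort(descending=True)
    return order, sims_sorted  # order [1, 0, 3, 2], sims [4, 3, 1, -2]
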
def test_get_kernel_around():
    a, b, c = [3., 5], [1., 6], [1., 8]
    assert torch_epsilon_eq(
        get_kernel_around(torch.tensor([[[1., 2], a, b, c, [3, 5]]]),
                          index=2, tokens_before_channels=True),
        torch.tensor([[a, b, c]]))

def test_get_kernel_around_pad():
    a, b, c = [3., 5], [1., 6], [1., 8]
    assert torch_epsilon_eq(
        get_kernel_around(torch.tensor([[c, a, b, c, a]]),
                          index=0, tokens_before_channels=True),
        torch.tensor([[[0, 0], c, a]]))

def test_word_lens_tokens_with_pad2():
    v = get_word_lens_of_moded_tokens(
        [[mtok("my"), mtok("nme"), mtok("'s", False)],
         [mtok("my"), mtok("nme"), mtok("'s", False), mtok(".")],
         [mtok("my"), mtok("nme")]])
    assert torch_epsilon_eq(v, [[1, 2, 2, 0], [1, 2, 2, 1], [1, 1, 0, 0]])

def test_get_input_lens_mask_expanded():
    assert torch_epsilon_eq(
        get_input_lengths_mask_expanded(torch.tensor([1, 3, 2]), 2),
        torch.tensor([
            [[1., 1], [0, 0], [0, 0]],
            [[1, 1], [1, 1], [1, 1]],
            [[1, 1], [1, 1], [0, 0]],
        ]))

def test_get_input_lens_mask():
    assert torch_epsilon_eq(
        get_input_lengths_mask(torch.tensor([1, 3, 2])),
        torch.tensor([
            [True, False, False],
            [True, True, True],
            [True, True, False],
        ]))

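# A compact sketch of the masking behavior the two length-mask tests above
# exercise, assuming a plain positions-versus-lengths comparison
# (hypothetical helper; the expanded variant would just repeat this mask
# along a trailing feature dimension of the given size):
def _sketch_input_lengths_mask(input_lens):
    positions = torch.arange(int(input_lens.max()))
    return positions.unsqueeze(0) < input_lens.unsqueeze(1)
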
def test_stop_word_mask5():
    tokenizer = ModifiedWordPieceTokenizer(["foo", "bar", "baz", "the"])
    mask = get_non_stop_word_mask(*tokenizer.tokenize("foo THE bar"),
                                  stop_words={"the"})
    assert torch_epsilon_eq(mask, [1, 1, 0, 1, 1])

def test_pack_picks2():
    val = pack_picks(
        [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6, 7, 8])],
        [torch.tensor([0, 2]), torch.tensor([]), torch.tensor([1, 2])])
    assert torch_epsilon_eq(val, torch.tensor([1, 3, 7, 8]))

def test_stop_word_mask6():
    tokenizer = ModifiedWordPieceTokenizer(["foo", "bar", "baz", "the"])
    mask = get_non_stop_word_mask(*tokenizer.tokenize("foo the bar"),
                                  stop_words={"the"}, pad_to_len=8)
    assert torch_epsilon_eq(mask, [1, 1, 0, 1, 1, 0, 0, 0])

def test_training(builder_store):
    trainer = builder_store.get_default_trainer()
    target = torch.Tensor([-0.4, 1, 2])
    type_ind = 1  # FT is type_ind 1
    example_id = 1
    dfs_depth = 0
    assert not torch_epsilon_eq(
        builder_store.get_latent_for_example(type_ind, example_id, dfs_depth),
        target, epsilon=1e-3)
    # Repeatedly nudge the stored latent toward the target value.
    for _ in range(100):
        trainer.update_value(type_ind, example_id, dfs_depth, target)
    assert torch_epsilon_eq(
        builder_store.get_latent_for_example(type_ind, example_id, dfs_depth),
        target, epsilon=1e-2)

def test_pack_picks3():
    val = pack_picks(
        [torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
         torch.tensor([[3, 2, 1], [4, 3, 2]]),
         torch.tensor([[1, 8, 3], [4, 9, 6], [7, 0, 9]])],
        [torch.tensor([0, 2]), torch.tensor([]), torch.tensor([1, 2])])
    assert torch_epsilon_eq(
        val, torch.tensor([[1, 2, 3], [7, 8, 9], [4, 9, 6], [7, 0, 9]]))

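# A minimal sketch of the packing the two pack_picks tests describe,
# assuming it concatenates the picked rows of each tensor in order and
# skips empty picks (hypothetical; the real helper may be vectorized
# differently):
def _sketch_pack_picks(vals, picks):
    picked = [v[p.long()] for v, p in zip(vals, picks) if len(p) > 0]
    return torch.cat(picked)
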
def test_word_lens_tokens():
    v = get_word_lens_of_moded_tokens([[
        mtok("my"), mtok("nme"), mtok("'s", False),
        mtok("s"), mtok("o", False), mtok("p", False), mtok(".")
    ]])
    assert torch_epsilon_eq(v, [[1, 2, 2, 3, 3, 3, 1]])

def test_avg_pool4(use_cuda):
    if use_cuda and not torch.cuda.is_available():
        pytest.skip("CUDA not available. Skipping")
    dvc = torch.device("cuda" if use_cuda else "cpu")
    assert torch_epsilon_eq(
        avg_pool(
            torch.tensor([[[0., 3., 3.], [6., 3., 1.], [0., 0., 0.]],
                          [[8., 3., 4.], [4., 3., 6.], [6., 0., 11.]],
                          [[8., 3., 4.], [4., 3., 6.], [0., 0., 0.]]],
                         device=dvc),
            torch.tensor([1, 3, 2], device=dvc)),
        torch.tensor([[0., 3., 3.], [6., 2., 7.], [6., 3., 5.]], device=dvc))

def test_get_valid_for_copy_mask():
    assert torch_epsilon_eq(
        get_valid_for_copy_mask([
            [_ms(SOS), _ms("a"), _ms(parse_constants.SPACE), _ms("b"), _ms(EOS)],
            [_ms(SOS), _ms("a"), _ms(PAD), _ms(PAD), _ms(EOS)],
        ]),
        torch.tensor([[0, 1, 0, 1, 0],
                      [0, 1, 0, 0, 0]]))

def test_torchvectdb():
    builder = TorchVectorDatabaseBuilder(
        key_dimensionality=2,
        value_fields=[VDBIntField('foo'), VDBStringField('bar')])
    builder.add_data(np.array([1., 3]), (2, 'hi'))
    builder.add_data(np.array([1., -3]), (5, 'yo'))
    builder.add_data(np.array([-1., 1]), (7, 'jo'))
    db = builder.produce_result()
    values, sims = db.get_n_nearest(torch.tensor([4., -1]), max_results=2)
    assert values == [(5, 'yo'), (2, 'hi')]
    assert torch_epsilon_eq(sims, [7., 1])
    nearest, sim = db.get_nearest(torch.tensor([4., -1]))
    assert nearest == (5, 'yo')
    assert sim == torch.tensor(7.)

def test_1dsame_check_weight():
    mod = Conv1dSame(1, 1, 3, bias=False)
    mod.set_weight(nn.Parameter(torch.Tensor([[[1., 1., 1.]]])))
    v = mod(torch.tensor([[[1., 2., 3., 4.]]]))
    assert torch_epsilon_eq(v, torch.tensor([[[3., 6, 9, 7]]]))

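# The hand-checked numbers above match an ordinary "same"-padded 1d
# convolution, sketched here with symmetric zero padding of
# (kernel_size - 1) // 2 (an illustration only; Conv1dSame itself
# presumably wraps nn.Conv1d with equivalent padding):
def _sketch_conv1d_same(x, weight):
    pad = (weight.shape[-1] - 1) // 2
    return torch.nn.functional.conv1d(x, weight, padding=pad)
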
def test_latent_store_builder(builder_store):
    assert len(builder_store.type_ind_to_latents) == 2
    assert torch_epsilon_eq(builder_store.type_ind_to_latents[1].example_indxs,
                            [0, 1, 2])
    assert list(builder_store.example_to_depths_to_ind_to_types[1]) == [1]

def test_avg_pool():
    assert torch_epsilon_eq(
        avg_pool(torch.tensor([[[0., 3., 3.], [6., 3., 1.]]])),
        torch.tensor([[3., 3., 2.]]))

def test_sparse_groupby_sum2():
    reduced, group_keys = sparse_groupby_sum(
        torch.Tensor([1, 2, 1, 5]), torch.Tensor([0, 1, 8, 0]),
        sort_out_keys=True)
    assert torch_epsilon_eq(reduced, [6, 2, 1])
    assert torch_epsilon_eq(group_keys, torch.Tensor([0, 1, 8]))

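# A rough sketch of the grouped reduction being tested, assuming it sums
# values sharing a key via torch.unique plus an index_add_ accumulation
# (hypothetical; the real sparse_groupby_sum may expose more options):
def _sketch_sparse_groupby_sum(values, keys):
    group_keys, inverse = torch.unique(keys, sorted=True, return_inverse=True)
    reduced = torch.zeros(len(group_keys)).index_add_(0, inverse, values)
    return reduced, group_keys
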
def test_manual_bincount_weighted():
    assert torch_epsilon_eq(
        manual_bincount(torch.Tensor([0, 1, 1, 0, 2]),
                        torch.Tensor([1, 2, 1, 1, -2])),
        [2, 3, -2])

def test_word_lens_tokens2():
    v = get_word_lens_of_moded_tokens(
        [[mtok("my"), mtok("nme"), mtok("'s", False)]])
    assert torch_epsilon_eq(v, [[1, 2, 2]])

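# A rough sketch of the word-length expansion the three word-lens tests in
# this file describe: every token reports the size (in tokens) of the word
# it belongs to, and rows are right-padded with zeros. The attribute name
# `is_word_start` here is purely hypothetical for whatever flag mtok's
# second argument sets:
def _sketch_word_lens(token_seqs):
    out = []
    for seq in token_seqs:
        words = []
        for tok in seq:
            if getattr(tok, "is_word_start", True) or not words:
                words.append([tok])  # token starts a new word
            else:
                words[-1].append(tok)  # continuation joins the prior word
        out.append([len(word) for word in words for _ in word])
    max_len = max(len(row) for row in out)
    return [row + [0] * (max_len - len(row)) for row in out]
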
def test_avg_pool3():
    assert torch_epsilon_eq(
        avg_pool(
            torch.tensor([[[0., 3., 3.], [6., 3., 1.]],
                          [[8., 3., 4.], [4., 3., 6.]]]),
            torch.tensor([1, 2])),
        torch.tensor([[0., 3., 3.], [6., 3., 5.]]))

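# A small sketch of the length-aware mean the avg_pool tests above pin
# down: padding rows past each example's length are excluded from the
# average (hypothetical; the real avg_pool likely masks and sums in one
# fused pass):
def _sketch_avg_pool(vals, input_lens=None):
    if input_lens is None:
        return vals.mean(dim=1)
    mask = torch.arange(vals.shape[1]).unsqueeze(0) < input_lens.unsqueeze(1)
    summed = (vals * mask.unsqueeze(2)).sum(dim=1)
    return summed / input_lens.unsqueeze(1)
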
def test_manual_bincount():
    assert torch_epsilon_eq(manual_bincount(torch.Tensor([0, 1, 1, 0, 2])),
                            [2, 2, 1])

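# For reference, a minimal sketch of the weighted/unweighted bincount the
# two manual_bincount tests describe, using index_add_ so non-integer
# weights work (hypothetical; shown only to document the expected math):
def _sketch_manual_bincount(inds, weights=None):
    inds = inds.long()
    if weights is None:
        weights = torch.ones(len(inds))
    return torch.zeros(int(inds.max()) + 1).index_add_(0, inds, weights)
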
def test_get_kernel_around0():
    assert torch_epsilon_eq(
        get_kernel_around(torch.tensor([[[1., 5, 7, 4, 6], [3., 6, 7, 2, 6]]]),
                          index=2, tokens_before_channels=False),
        torch.tensor([[[5., 7, 4], [6, 7, 2]]]))

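# A sketch of the window extraction the three get_kernel_around tests pin
# down, assuming a fixed kernel size of 3 and zero padding at the sequence
# edges (hypothetical; the real helper presumably takes the kernel size as
# a parameter):
def _sketch_kernel_around(x, index, tokens_before_channels, kernel_size=3):
    if not tokens_before_channels:
        x = x.transpose(1, 2)  # -> (batch, tokens, channels)
    pad = kernel_size // 2
    padded = torch.nn.functional.pad(x, (0, 0, pad, pad))  # pad the token dim
    window = padded[:, index:index + kernel_size]
    return window if tokens_before_channels else window.transpose(1, 2)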