示例#1
0
 def test_save_load(self):
     """Round-trip a USE-backed model through ``state_dict`` save/load.

     Builds ``nn.Sequential(TextWiser(USE), nn.Linear(512, 1))`` and records
     its output on ``docs``. It then saves the state dict to a temporary file,
     rebuilds the same architecture from scratch and loads the saved state.
     Finally it checks that the predictions match the originals.

     The test is reported as skipped (not passed) when a
     ``ModuleNotFoundError`` is raised, e.g. TensorFlow is not installed.
     """
     try:
         # shut tensorflow up during testing
         os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
         # Create a model with a downstream task
         tw = TextWiser(Embedding.USE(), dtype=torch.float32).fit(docs)
         model = nn.Sequential(tw, nn.Linear(512, 1)).to(device)
         # Get results of the model
         expected = model(docs)
         # Save the model to a temporary file
         with NamedTemporaryFile() as file:
             state_dict = model.state_dict()
             # The key '0._imp.0.use' (presumably the USE encoder itself)
             # must not end up in the saved state.
             self.assertNotIn('0._imp.0.use', state_dict)
             # Write through the open file object, not via file.name
             torch.save(state_dict, file)
             # Get rid of the original model
             del tw
             del model
             # Create the same model; fit() is called without documents here
             tw = TextWiser(Embedding.USE(), dtype=torch.float32)
             tw.fit()
             model = nn.Sequential(tw, nn.Linear(512, 1)).to(device)
             # Rewind so torch.load reads what torch.save just wrote
             file.seek(0)
             model.load_state_dict(torch.load(file, map_location=device))
             # Do predictions with the loaded model
             predicted = model(docs)
             self.assertTrue(torch.allclose(predicted, expected, atol=1e-6))
     except ModuleNotFoundError as err:
         # Report a real skip instead of printing and letting the test pass.
         self.skipTest(f'No Tensorflow found: {err}')
示例#2
0
 def test_pretrained_error(self):
     """Building a USE TextWiser with ``pretrained=None`` must raise ValueError.

     The test is reported as skipped (not passed) when a
     ``ModuleNotFoundError`` is raised, e.g. TensorFlow is not installed.
     """
     try:
         # shut tensorflow up during testing
         os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
         # Not a pretrained model: construction is expected to be rejected
         with self.assertRaises(ValueError):
             TextWiser(Embedding.USE(pretrained=None), dtype=torch.float32)
     except ModuleNotFoundError as err:
         # Report a real skip instead of printing and letting the test pass.
         self.skipTest(f'No Tensorflow found: {err}')
示例#3
0
 def test_use_versions(self):
     """Tests if the previous versions of USE are usable.

     For each listed TF-Hub release of universal-sentence-encoder-large, in
     order, a TextWiser is built and ``fit_transform`` is run on ``docs``.

     The test is reported as skipped (not passed) when a
     ``ModuleNotFoundError`` is raised, e.g. TensorFlow is not installed.
     """
     use_large_urls = (
         'https://tfhub.dev/google/universal-sentence-encoder-large/5',
         'https://tfhub.dev/google/universal-sentence-encoder-large/4',
         'https://tfhub.dev/google/universal-sentence-encoder-large/3',
     )
     try:
         # shut tensorflow up during testing
         os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
         for url in use_large_urls:
             TextWiser(Embedding.USE(pretrained=url),
                       dtype=torch.float32).fit_transform(docs)
     except ModuleNotFoundError as err:
         # Report a real skip instead of printing and letting the test pass.
         self.skipTest(f'No Tensorflow found: {err}')
示例#4
0
 def test_fit_transform(self):
     """Compare USE embeddings against the stored ``use_embeddings.csv``.

     The expected float32 matrix is read with ``np.genfromtxt``. The path is
     resolved by ``self._get_test_path('data', 'use_embeddings.csv')``.
     The same TextWiser instance is then passed to
     ``self._test_fit_transform`` and ``self._test_fit_before_transform``.
     Both helpers are defined outside this view.

     The test is reported as skipped (not passed) when a
     ``ModuleNotFoundError`` is raised, e.g. TensorFlow is not installed.
     """
     try:
         # shut tensorflow up during testing
         os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
         tw = TextWiser(Embedding.USE(), dtype=torch.float32)
         expected_path = self._get_test_path('data', 'use_embeddings.csv')
         expected = torch.from_numpy(
             np.genfromtxt(expected_path, dtype=np.float32))
         self._test_fit_transform(tw, expected)
         self._test_fit_before_transform(tw, expected)
     except ModuleNotFoundError as err:
         # Report a real skip instead of printing and letting the test pass.
         self.skipTest(f'No Tensorflow found: {err}')