Example #1
 def test_embeddings_trainable_params(self):
     for params, init_case in itertools.product(self.PARAMS, self.cases()):
         emb = Embeddings(**init_case)
         trainable_params = {
             n: p
             for n, p in emb.named_parameters() if p.requires_grad
         }
         # first check there's nothing unexpectedly not trainable
         for key in emb.state_dict():
             if key not in trainable_params:
                 if key.endswith("emb_luts.0.weight") and \
                         init_case["freeze_word_vecs"]:
                     # ok: word embeddings shouldn't be trainable
                     # if word vecs are frozen
                     continue
                 if key.endswith(".pe.pe"):
                     # ok: positional encodings shouldn't be trainable
                     assert init_case["position_encoding"]
                     continue
                 else:
                     self.fail("Param {:s} is unexpectedly not "
                               "trainable.".format(key))
         # then check nothing unexpectedly trainable
         if init_case["freeze_word_vecs"]:
             self.assertFalse(
                 any(
                     trainable_param.endswith("emb_luts.0.weight")
                     for trainable_param in trainable_params),
                 "Word embedding is trainable but word vecs are freezed.")
         if init_case["position_encoding"]:
             self.assertFalse(
                 any(
                     trainable_p.endswith(".pe.pe")
                     for trainable_p in trainable_params),
                 "Positional encoding is trainable.")
Example #2
 def test_embeddings_trainable_params(self):
     for params, init_case in itertools.product(self.PARAMS,
                                                self.cases()):
         emb = Embeddings(**init_case)
         trainable_params = {n: p for n, p in emb.named_parameters()
                             if p.requires_grad}
         # first check there's nothing unexpectedly not trainable
         for key in emb.state_dict():
             if key not in trainable_params:
                 if key.endswith("emb_luts.0.weight") and \
                         init_case["fix_word_vecs"]:
                     # ok: word embeddings shouldn't be trainable
                     # if word vecs are fixed
                     continue
                 if key.endswith(".pe.pe"):
                     # ok: positional encodings shouldn't be trainable
                     assert init_case["position_encoding"]
                     continue
                 else:
                     self.fail("Param {:s} is unexpectedly not "
                               "trainable.".format(key))
         # then check nothing unexpectedly trainable
         if init_case["fix_word_vecs"]:
             self.assertFalse(
                 any(trainable_param.endswith("emb_luts.0.weight")
                     for trainable_param in trainable_params),
                 "Word embedding is trainable but word vecs are fixed.")
         if init_case["position_encoding"]:
             self.assertFalse(
                 any(trainable_p.endswith(".pe.pe")
                     for trainable_p in trainable_params),
                 "Positional encoding is trainable.")
Example #3
 def test_embeddings_forward_shape(self):
     for params, init_case in itertools.product(self.PARAMS, self.cases()):
         emb = Embeddings(**init_case)
         dummy_in = self.dummy_inputs(params, init_case)
         res = emb(dummy_in)
         expected_shape = self.expected_shape(params, init_case)
         self.assertEqual(res.shape, expected_shape, init_case.__str__())
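Example #3 also depends on unshown helpers, dummy_inputs and expected_shape. A hypothetical pair for the harness sketched after Example #1, assuming the older OpenNMT-py convention of a (seq_len, batch, n_feats) LongTensor input and a (seq_len, batch, word_vec_size) output; newer releases are batch-first, so this layout is an assumption:

    import torch

    # Methods for the hypothetical harness class sketched after Example #1.
    def dummy_inputs(self, params, init_case):
        # Random token ids; the trailing axis holds the (single) feature.
        return torch.randint(
            0, init_case["word_vocab_size"],
            (params["seq_len"], params["batch_size"], 1))

    def expected_shape(self, params, init_case):
        # The forward pass should only append the embedding dimension.
        return (params["seq_len"], params["batch_size"],
                init_case["word_vec_size"])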
Example #4
 def test_embeddings_trainable_params_update(self):
     for params, init_case in itertools.product(self.PARAMS, self.cases()):
         emb = Embeddings(**init_case)
         trainable_params = {n: p for n, p in emb.named_parameters()
                             if p.requires_grad}
         if len(trainable_params) > 0:
             old_weights = deepcopy(trainable_params)
             dummy_in = self.dummy_inputs(params, init_case)
             res = emb(dummy_in)
             pretend_loss = res.sum()
             pretend_loss.backward()
             dummy_optim = torch.optim.SGD(trainable_params.values(), 1)
             dummy_optim.step()
             for param_name in old_weights.keys():
                 self.assertTrue(
                     trainable_params[param_name]
                     .ne(old_weights[param_name]).any(),
                     param_name + " " + init_case.__str__())
Example #5
 def _get_encoder(self, use_embedding_proj):
     emb = Embeddings(DIM, VOCAB_SIZE, 0)
     encoder = TransformerEatEncoder(1, DIM, 5, DIM, 0.1, 0.1, emb, 0,
                                     use_embedding_proj)
     return encoder
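DIM, VOCAB_SIZE, and TransformerEatEncoder come from the surrounding test module and are not shown in this listing. A hypothetical companion check, assuming the encoder is an ordinary nn.Module and that use_embedding_proj=True does not remove parameters; both assumptions go beyond what Example #5 shows:

    DIM = 8          # placeholder values; the real test module defines its own
    VOCAB_SIZE = 20

    def test_embedding_proj_does_not_shrink_params(self):
        # Compare parameter counts with and without the embedding projection.
        n_plain = sum(p.numel() for p in self._get_encoder(False).parameters())
        n_proj = sum(p.numel() for p in self._get_encoder(True).parameters())
        self.assertGreaterEqual(n_proj, n_plain)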