def test_embeddings_trainable_params(self):
    for params, init_case in itertools.product(self.PARAMS, self.cases()):
        emb = Embeddings(**init_case)
        trainable_params = {
            n: p for n, p in emb.named_parameters() if p.requires_grad
        }
        # first check there's nothing unexpectedly not trainable
        for key in emb.state_dict():
            if key not in trainable_params:
                if key.endswith("emb_luts.0.weight") and \
                        init_case["freeze_word_vecs"]:
                    # ok: word embeddings shouldn't be trainable
                    # if word vecs are frozen
                    continue
                if key.endswith(".pe.pe"):
                    # ok: positional encodings shouldn't be trainable
                    assert init_case["position_encoding"]
                    continue
                else:
                    self.fail("Param {:s} is unexpectedly not "
                              "trainable.".format(key))
        # then check nothing unexpectedly trainable
        if init_case["freeze_word_vecs"]:
            self.assertFalse(
                any(trainable_param.endswith("emb_luts.0.weight")
                    for trainable_param in trainable_params),
                "Word embedding is trainable but word vecs are frozen.")
        if init_case["position_encoding"]:
            self.assertFalse(
                any(trainable_p.endswith(".pe.pe")
                    for trainable_p in trainable_params),
                "Positional encoding is trainable.")
def test_embeddings_forward_shape(self):
    for params, init_case in itertools.product(self.PARAMS, self.cases()):
        emb = Embeddings(**init_case)
        dummy_in = self.dummy_inputs(params, init_case)
        res = emb(dummy_in)
        expected_shape = self.expected_shape(params, init_case)
        self.assertEqual(res.shape, expected_shape, init_case.__str__())
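# Note: `self.dummy_inputs` and `self.expected_shape` are helpers defined
# elsewhere in this test class. A minimal sketch of what they could look
# like is given below, assuming OpenNMT-py-style embedding inputs of shape
# (seq_len, batch, n_feats) and hypothetical params/init_case keys
# ("max_seq_len", "batch_size", "word_vocab_size", "word_vec_size"); these
# names are assumptions, not the actual implementation.
#
# def dummy_inputs(self, params, init_case):
#     return torch.randint(
#         0, init_case["word_vocab_size"],
#         (params["max_seq_len"], params["batch_size"], 1))
#
# def expected_shape(self, params, init_case):
#     return (params["max_seq_len"], params["batch_size"],
#             init_case["word_vec_size"])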
def test_embeddings_trainable_params_update(self):
    for params, init_case in itertools.product(self.PARAMS, self.cases()):
        emb = Embeddings(**init_case)
        trainable_params = {
            n: p for n, p in emb.named_parameters() if p.requires_grad
        }
        if len(trainable_params) > 0:
            old_weights = deepcopy(trainable_params)
            dummy_in = self.dummy_inputs(params, init_case)
            res = emb(dummy_in)
            pretend_loss = res.sum()
            pretend_loss.backward()
            dummy_optim = torch.optim.SGD(trainable_params.values(), 1)
            dummy_optim.step()
            for param_name in old_weights.keys():
                self.assertTrue(
                    trainable_params[param_name]
                    .ne(old_weights[param_name]).any(),
                    param_name + " " + init_case.__str__())
def _get_encoder(self, use_embedding_proj):
    emb = Embeddings(DIM, VOCAB_SIZE, 0)
    encoder = TransformerEatEncoder(
        1, DIM, 5, DIM, 0.1, 0.1, emb, 0, use_embedding_proj)
    return encoder
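# Assuming OpenNMT-py's Embeddings signature, the positional args passed to
# Embeddings in _get_encoder are word_vec_size, word_vocab_size, and
# word_padding_idx; if this fork changes that signature, read them accordingly.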