class CNN(nn.Module):
    """Convolutional discriminator over a pairwise word-rhyming matrix.

    Line-ending words are embedded with a (optionally pretrained) g2p
    encoder, a cosine-similarity matrix is built over those embeddings,
    and a small 2-layer CNN scores the matrix with the probability that
    the endings come from real training data rather than the generator.
    """

    def __init__(self, args, reduced_size=None, info=None):
        """Build the conv/scorer layers and load the g2p encoder.

        Args:
            args: experiment configuration namespace; fields read here are
                disc_type, data_type, use_cuda, g2p_model_name,
                learn_g2p_encoder_from_scratch and trainable_g2p.
            reduced_size: unused; kept for interface compatibility.
            info: dict that must provide 'H', the g2p model hidden size.
                Default changed from a mutable ``{}`` literal to ``None``
                (shared-mutable-default pitfall); behavior is unchanged —
                a missing 'H' still raises KeyError.
        """
        super(CNN, self).__init__()
        info = {} if info is None else info
        # disc_type=DISC_TYPE_MATRIX
        self.disc_type = disc_type = args.disc_type
        # Two small conv stages over the 1-channel rhyming matrix.
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=2, padding=0),
            nn.ReLU())  # 1,4,3,3
        self.layer2 = nn.Sequential(
            nn.Conv2d(4, 8, kernel_size=2),
            nn.ReLU())  # 1,8,2,2
        ## but for 5 lines, it is 1,8,3,3
        # Flattened conv-output size depends on the number of line endings:
        # 4 endings (sonnet) -> 2x2 feature map, 5 endings (limerick) -> 3x3.
        if args.data_type == "sonnet_endings":
            self.scorer = nn.Linear(2 * 2 * 8, 1)
        elif args.data_type == "limerick":
            self.scorer = nn.Linear(3 * 3 * 8, 1)
        self.predictor = nn.Sigmoid()
        self.args = args
        self.use_cuda = args.use_cuda
        ##
        # Grapheme indexer + g2p encoder used to embed line-ending words.
        self.g_indexer = Indexer(args)
        self.g_indexer.load('tmp/tmp_' + args.g2p_model_name + '/solver_g_indexer')
        self.g2pmodel = Model(H=info['H'],
                              args=args,
                              i_size=self.g_indexer.w_cnt,
                              o_size=self.g_indexer.w_cnt,
                              start_idx=self.g_indexer.w2idx[utils.START])
        if not args.learn_g2p_encoder_from_scratch:
            print("=====" * 7, "LOADING g2p ENCODER PRETRAINED")
            model_dir = 'tmp/tmp_' + args.g2p_model_name + '/'
            state_dict_best = torch.load(model_dir + 'model_best')
            self.g2pmodel.load_state_dict(state_dict_best)
        if not args.trainable_g2p:
            # Freezing only makes sense when the encoder was pretrained.
            assert not args.learn_g2p_encoder_from_scratch
            for param in self.g2pmodel.parameters():
                param.requires_grad = False

    def display_params(self):
        """Print every named parameter with its grad flag and size."""
        print("=" * 44)
        # Typo fix: "parametrs" -> "parameters" in the printed header.
        print("[CNN]: model parameters")
        for name, param in self.named_parameters():
            print("name=", name, " || grad:", param.requires_grad, "| size = ", param.size())
        print("=" * 44)

    def _compute_word_reps(self, words_str, deb=False):
        """Encode each word string into a g2p-encoder representation.

        Args:
            words_str: iterable of word strings (line endings).
            deb: when True, print debug info.

        Returns:
            List of encodings, one per input word.
        """
        if deb:
            print("words_str = ", words_str)
        use_eow_marker = self.args.use_eow_in_enc
        # NOTE(review): the assert makes the EOW branch below unreachable in
        # normal runs; it only executes under `python -O` (asserts stripped).
        assert not use_eow_marker, "Not yet tested"
        word_reps = [self.g_indexer.w_to_idx(s1) for s1 in words_str]
        if self.args.use_eow_in_enc:
            x_end = self.g_indexer.w2idx[utils.END]
            word_reps = [x_i + [x_end] for x_i in word_reps]
        word_reps = [self.g2pmodel.encode(w) for w in word_reps]
        return word_reps

    def _compute_pairwise_dot(self, measure_encodings_b):
        """Return the sz x sz cosine-similarity matrix of the encodings.

        Each norm is now computed once per encoding (hoisted out of the
        O(sz^2) pair loop) instead of twice per pair; the resulting values
        are identical to the naive pairwise computation.
        """
        sz = len(measure_encodings_b)
        norms = [torch.sqrt(torch.sum(enc * enc)) for enc in measure_encodings_b]
        ret = []
        for enc_a, norm_a in zip(measure_encodings_b, norms):
            assert norm_a > 0
            for enc_b, norm_b in zip(measure_encodings_b, norms):
                assert norm_b > 0, "t3=" + str(norm_b)
                dot = torch.sum(enc_a * enc_b)
                ret.append(dot / (norm_a * norm_b))
        ret = torch.stack(ret)
        ret = ret.view(sz, sz)
        return ret

    def _score_matrix(self, x, deb=False):
        """Run the CNN scorer on a rhyming matrix.

        Args:
            x: list whose first element is an ms x ms similarity matrix.
            deb: when True, print intermediate tensors.

        Returns:
            Dict with 'prob' (sigmoid of score), 'out' (flattened conv
            features) and 'score' (raw linear score).
        """
        x = x[0].unsqueeze(0).unsqueeze(0)  # -> 1,1,ms,ms
        if deb:
            print("---x.shape = ", x.size())
        out = self.layer1(x)
        if deb:
            print("---out = ", out.size(), out)
        out = self.layer2(out)
        if deb:
            print("---out = ", out.size(), out)
        out = out.view(out.size(0), -1)  # arrange by bsz
        score = self.scorer(out)
        if deb:
            print("---out sum = ", torch.sum(out))
            print("---score = ", score)
        prob = self.predictor(score)
        return {'prob': prob, 'out': out, 'score': score}

    def _compute_rhyming_matrix(self, words_str, deb=False):
        """Encode words and return (cosine matrix, words_str)."""
        word_reps = self._compute_word_reps(words_str)
        rhyming_matrix = self._compute_pairwise_dot(word_reps)
        return rhyming_matrix, words_str

    def _compute_rnn_on_word_reps(self, word_reps):
        """Fold word reps through a linear RNN cell; return final cell state.

        NOTE(review): relies on self.linear_rep_H / self.linear_rep_encoder,
        which are not set in the visible __init__ — presumably assigned
        elsewhere; calling this without them raises AttributeError. Confirm.
        """
        h = torch.zeros(1, self.linear_rep_H), torch.zeros(1, self.linear_rep_H)
        if self.use_cuda:
            h = h[0].cuda(), h[1].cuda()
        for w in word_reps:
            h = self.linear_rep_encoder(w, h)
        out, c = h
        return c

    def _run_discriminator(self, words_str, deb):
        """Score one set of line endings; return scorer dict + matrix info."""
        rhyming_matrix, words_str = self._compute_rhyming_matrix(words_str, deb)
        vals = self._score_matrix([rhyming_matrix])
        vals.update({'rhyming_matrix': rhyming_matrix,
                     'linear_rep': None,
                     'words_str': words_str})
        return vals

    def update_discriminator(self, line_endings_gen, line_endings_train,
                             deb=False, word_idx_to_str_dict=None):
        """Compute the discriminator GAN loss on real vs. generated endings.

        Args:
            line_endings_gen: word-index tensors from the generator.
            line_endings_train: word-index tensors from training data.
            deb: when True, print debug info.
            word_idx_to_str_dict: maps word index -> word string.

        Returns:
            Dict with 'loss' (-log D(real) - log(1 - D(gen))), 'reward'
            (D(gen) prob, or raw score if args.use_score_as_reward) and
            'dump_info' (numpy matrices / probs / word strings for logging).
        """
        eps = 0.0000000001  # keeps log() finite at prob 0 or 1
        ret = {}
        dump_info = {}
        words_str_train = [word_idx_to_str_dict[word_idx.data.cpu().item()]
                           for word_idx in line_endings_train]
        words_str_gen = [word_idx_to_str_dict[word_idx.data.cpu().item()]
                         for word_idx in line_endings_gen]
        disc_real = self._run_discriminator(words_str_train, deb)
        if deb:
            print("rhyming_matrix_trai = ", disc_real['rhyming_matrix'],
                  "|| prob = ", disc_real['prob'])
        if self.args.disc_type == DISC_TYPE_MATRIX:
            dump_info['rhyming_matrix_trai'] = disc_real['rhyming_matrix'].data.cpu().numpy()
        dump_info['real_prob'] = disc_real['prob'].data.cpu().item()
        dump_info['real_words_str'] = disc_real['words_str']
        disc_gen = self._run_discriminator(words_str_gen, deb)
        if deb:
            print("rhyming_matrix_gen = ", disc_gen['rhyming_matrix'],
                  "|| prob = ", disc_gen['prob'])
        if self.args.disc_type == DISC_TYPE_MATRIX:
            dump_info['rhyming_matrix_gen'] = disc_gen['rhyming_matrix'].data.cpu().numpy()
        dump_info['gen_prob'] = disc_gen['prob'].data.cpu().item()
        dump_info['gen_words_str'] = disc_gen['words_str']
        prob_real = disc_real['prob']
        prob_gen = disc_gen['prob']
        # Standard non-saturating discriminator objective.
        loss = -torch.log(prob_real + eps) - torch.log(1.0 - prob_gen + eps)
        reward = prob_gen
        if self.args.use_score_as_reward:
            reward = disc_gen['score']
        ret.update({'loss': loss, 'reward': reward, 'dump_info': dump_info})
        return ret