Example #1
import torch
import torch.nn as nn

# Project-level dependencies assumed by the original listing (not reproduced here):
# Indexer, Model, utils (START/END symbols) and the DISC_TYPE_MATRIX constant.


class CNN(nn.Module):
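    """CNN-based discriminator over a pairwise rhyming matrix.

    Line-ending words are encoded with a (optionally pretrained, optionally frozen)
    grapheme-to-phoneme encoder; their pairwise cosine similarities form a small
    square matrix that two conv layers score as real vs. generated.
    """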
    def __init__(self, args, reduced_size=None, info=None):
        super(CNN, self).__init__()
        info = info if info is not None else {}  # avoid a shared mutable default argument
        # disc_type=DISC_TYPE_MATRIX
        self.disc_type = disc_type = args.disc_type
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=2, padding=0),
            nn.ReLU())
        # output: 1,4,3,3 for a 4x4 input matrix (sonnet_endings)
        self.layer2 = nn.Sequential(
            nn.Conv2d(4, 8, kernel_size=2),
            nn.ReLU())
        # output: 1,8,2,2 for a 4x4 input matrix (sonnet_endings)
        ## but for 5 lines (limerick, 5x5 input), it is 1,8,3,3
        if args.data_type == "sonnet_endings":
            self.scorer = nn.Linear(2 * 2 * 8, 1)
        elif args.data_type == "limerick":
            self.scorer = nn.Linear(3 * 3 * 8, 1)
        else:
            raise ValueError("Unsupported data_type: " + str(args.data_type))
        self.predictor = nn.Sigmoid()
        self.args = args
        self.use_cuda = args.use_cuda

        ## grapheme-to-phoneme encoder used to embed the line-ending words
        self.g_indexer = Indexer(args)
        self.g_indexer.load('tmp/tmp_' + args.g2p_model_name + '/solver_g_indexer')
        self.g2pmodel = Model(H=info['H'], args=args, i_size=self.g_indexer.w_cnt, o_size=self.g_indexer.w_cnt,
                              start_idx=self.g_indexer.w2idx[utils.START])
        if not args.learn_g2p_encoder_from_scratch:
            print("=====" * 7, "LOADING g2p ENCODER PRETRAINED")
            model_dir = 'tmp/tmp_' + args.g2p_model_name + '/'
            state_dict_best = torch.load(model_dir + 'model_best')
            self.g2pmodel.load_state_dict(state_dict_best)
        if not args.trainable_g2p:
            assert not args.learn_g2p_encoder_from_scratch
            for param in self.g2pmodel.parameters():
                param.requires_grad = False

    def display_params(self):
        print("=" * 44)
        print("[CNN]: model parametrs")
        for name, param in self.named_parameters():
            print("name=", name, " || grad:", param.requires_grad, "| size = ", param.size())
        print("=" * 44)

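    # Map each line-ending word (as a grapheme index sequence) to an encoding
    # produced by the g2p encoder.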
    def _compute_word_reps(self, words_str, deb=False):
        if deb:
            print("words_str = ", words_str)
        use_eow_marker = self.args.use_eow_in_enc
        assert not use_eow_marker, "Not yet tested"
        word_reps = [self.g_indexer.w_to_idx(s1) for s1 in words_str]
        if self.args.use_eow_in_enc:
            x_end = self.g_indexer.w2idx[utils.END]
            word_reps = [x_i + [x_end] for x_i in word_reps]
        word_reps = [self.g2pmodel.encode(w) for w in word_reps]
        return word_reps

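    # Despite the name, this computes normalised (cosine) similarities between all
    # pairs of encodings and returns them as an (n, n) matrix.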
    def _compute_pairwise_dot(self, measure_encodings_b):
        ret = []
        sz = len(measure_encodings_b)
        for measure_encodings_b_t in measure_encodings_b:
            for measure_encodings_b_t2 in measure_encodings_b:
                t1 = torch.sum(measure_encodings_b_t * measure_encodings_b_t2)
                t2 = torch.sqrt(torch.sum(measure_encodings_b_t * measure_encodings_b_t))
                t3 = torch.sqrt(torch.sum(measure_encodings_b_t2 * measure_encodings_b_t2))
                assert t2 > 0
                assert t3 > 0, "t3=" + str(t3)
                ret.append(t1 / (t2 * t3))
        ret = torch.stack(ret)
        ret = ret.view(sz, sz)
        return ret

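    # Score a single rhyming matrix: reshape to (1, 1, ms, ms), apply the two conv
    # layers, flatten, and map to a real/generated probability with the sigmoid.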
    def _score_matrix(self, x, deb=False):
        x = x[0].unsqueeze(0).unsqueeze(0)  # -> 1,1,ms,ms
        if deb:
            print("---x.shape = ", x.size())
        out = self.layer1(x)
        if deb:
            print("---out = ", out.size(), out)
        out = self.layer2(out)
        if deb:
            print("---out = ", out.size(), out)
        out = out.view(out.size(0), -1)  # arrange by bsz
        score = self.scorer(out)
        if deb:
            print("---out sum = ", torch.sum(out))
            print("---score = ", score)
        prob = self.predictor(score)
        return {'prob': prob, 'out': out, 'score': score}

    def _compute_rhyming_matrix(self, words_str, deb=False):
        word_reps = self._compute_word_reps(words_str)
        rhyming_matrix = self._compute_pairwise_dot(word_reps)
        return rhyming_matrix, words_str

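    # Note: this helper assumes self.linear_rep_encoder (an LSTMCell-style module)
    # and self.linear_rep_H are defined elsewhere; they are not set in __init__ above,
    # so this path is unused by the matrix-based discriminator shown here.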
    def _compute_rnn_on_word_reps(self, word_reps):
        h = torch.zeros(1, self.linear_rep_H), torch.zeros(1, self.linear_rep_H)
        if self.use_cuda:
            h = h[0].cuda(), h[1].cuda()
        for w in word_reps:
            h = self.linear_rep_encoder(w, h)
        out, c = h
        return c

    def _run_discriminator(self, words_str, deb):
        rhyming_matrix, words_str = self._compute_rhyming_matrix(words_str, deb)
        vals = self._score_matrix([rhyming_matrix])
        vals.update({'rhyming_matrix': rhyming_matrix, 'linear_rep': None, 'words_str': words_str})
        return vals

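    # Standard binary discriminator update: maximise log D(real) + log(1 - D(generated)),
    # i.e. minimise loss = -log(p_real + eps) - log(1 - p_gen + eps). The generator's
    # reward is the discriminator's probability (or raw score) on the generated endings.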
    def update_discriminator(self, line_endings_gen, line_endings_train, deb=False, word_idx_to_str_dict=None):
        eps = 1e-10
        ret = {}
        dump_info = {}
        words_str_train = [word_idx_to_str_dict[word_idx.data.cpu().item()] for word_idx in line_endings_train]
        words_str_gen = [word_idx_to_str_dict[word_idx.data.cpu().item()] for word_idx in line_endings_gen]
        disc_real = self._run_discriminator(words_str_train, deb)
        if deb:
            print("rhyming_matrix_trai = ", disc_real['rhyming_matrix'], "|| prob = ", disc_real['prob'])
            if self.args.disc_type == DISC_TYPE_MATRIX:
                dump_info['rhyming_matrix_train'] = disc_real['rhyming_matrix'].data.cpu().numpy()
            dump_info['real_prob'] = disc_real['prob'].data.cpu().item()
            dump_info['real_words_str'] = disc_real['words_str']
        disc_gen = self._run_discriminator(words_str_gen, deb)
        if deb:
            print("rhyming_matrix_gen = ", disc_gen['rhyming_matrix'], "|| prob = ", disc_gen['prob'])
            if self.args.disc_type == DISC_TYPE_MATRIX:
                dump_info['rhyming_matrix_gen'] = disc_gen['rhyming_matrix'].data.cpu().numpy()
            dump_info['gen_prob'] = disc_gen['prob'].data.cpu().item()
            dump_info['gen_words_str'] = disc_gen['words_str']
        prob_real = disc_real['prob']
        prob_gen = disc_gen['prob']
        loss = -torch.log(prob_real + eps) - torch.log(1.0 - prob_gen + eps)
        reward = prob_gen
        if self.args.use_score_as_reward:
            reward = disc_gen['score']
        ret.update({'loss': loss, 'reward': reward, 'dump_info': dump_info})
        return ret
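

# A minimal standalone sketch (editor's addition, not part of the original model):
# it checks that the loops in _compute_pairwise_dot are equivalent to a batched
# cosine-similarity computation over a stack of word encodings. The encodings here
# are random stand-ins for the g2p encoder outputs.
if __name__ == "__main__":
    torch.manual_seed(0)
    encodings = [torch.randn(16) for _ in range(4)]  # e.g. 4 line endings, 16-dim reps

    # Loop version, mirroring _compute_pairwise_dot
    loop_sims = torch.stack([
        torch.sum(a * b) / (torch.sqrt(torch.sum(a * a)) * torch.sqrt(torch.sum(b * b)))
        for a in encodings for b in encodings
    ]).view(4, 4)

    # Vectorised equivalent
    stacked = torch.stack(encodings)            # (4, 16)
    norms = stacked.norm(dim=1, keepdim=True)   # (4, 1)
    batched_sims = (stacked @ stacked.t()) / (norms * norms.t())

    assert torch.allclose(loop_sims, batched_sims, atol=1e-6)
    print(batched_sims)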