Example #1
    def _get_loss_train_op(self):
        # Aggregates loss
        loss_rephraser = (tf.get_collection('loss_rephraser_list')[0] +
                          tf.get_collection('loss_rephraser_list')[1] +
                          tf.get_collection('loss_rephraser_list')[2]) / 3.

        self.loss = loss_rephraser

        # Creates optimizers
        self.vars = collect_trainable_variables([
            self.bert_encoder, self.word_embedder, self.downmlp,
            self.rephrase_encoder, self.rephrase_decoder
        ])

        # Train Op
        self.train_op_pre = get_train_op(self.loss,
                                         self.vars,
                                         learning_rate=self.lr,
                                         hparams=self._hparams.opt)

        # Interface tensors
        self.losses = {
            "loss": self.loss,
            "loss_rephraser": loss_rephraser,
        }
        self.metrics = {}
        self.train_ops = {
            "train_op_pre": self.train_op_pre,
        }
        self.samples = {
            "transferred_yy1_gt": self.text_ids_yy1,
            "transferred_yy1_pred": tf.get_collection('yy_pred_list')[0],
            "transferred_yy2_gt": self.text_ids_yy2,
            "transferred_yy2_pred": tf.get_collection('yy_pred_list')[1],
            "transferred_yy3_gt": self.text_ids_yy3,
            "transferred_yy3_pred": tf.get_collection('yy_pred_list')[2],
            "origin_y1": self.text_ids_y1,
            "origin_y2": self.text_ids_y2,
            "origin_y3": self.text_ids_y3,
            "x1x2": self.x1x2,
            "x1xx2": self.x1xx2
        }

        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("loss_rephraser", loss_rephraser)
        self.merged = tf.summary.merge_all()
        self.fetches_train_pre = {
            "loss": self.train_ops["train_op_pre"],
            "loss_rephraser": self.losses["loss_rephraser"],
            "merged": self.merged,
        }
        fetches_eval = {
            "batch_size": get_batch_size(self.x1x2yx1xx2_ids),
            "merged": self.merged,
        }
        fetches_eval.update(self.losses)
        fetches_eval.update(self.metrics)
        fetches_eval.update(self.samples)
        self.fetches_eval = fetches_eval
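The three-way average at the top of this example pulls losses that the rest of the model registered under the 'loss_rephraser_list' graph collection. A minimal, self-contained sketch of that collection pattern, assuming a TensorFlow 1.x graph (the constant losses are purely illustrative):

import tensorflow as tf

# Elsewhere in the model, each rephraser branch would register its loss:
for value in (1.0, 2.0, 6.0):
    tf.add_to_collection('loss_rephraser_list', tf.constant(value))

# The aggregation step then averages whatever was collected; tf.add_n is a
# length-agnostic alternative to indexing the first three entries explicitly.
collected = tf.get_collection('loss_rephraser_list')
loss_rephraser = tf.add_n(collected) / float(len(collected))

with tf.Session() as sess:
    print(sess.run(loss_rephraser))  # 3.0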
Example #2
    def _get_loss_train_op(self):
        # Aggregates losses
        self.loss_g = (self.loss_g_ae +
                       self.lambda_g_hidden * self.loss_g_clas_hidden +
                       self.lambda_g_sentence * self.loss_g_clas_sentence)
        self.loss_d = self.loss_d_clas_hidden + self.loss_d_clas_sentence

        # Creates optimizers
        self.g_vars = collect_trainable_variables([
            self.embedder, self.self_graph_encoder, self.label_connector,
            self.rephrase_encoder, self.rephrase_decoder
        ])
        self.d_vars = collect_trainable_variables([
            self.clas_embedder, self.classifier_hidden,
            self.classifier_sentence
        ])

        self.train_op_g = get_train_op(self.loss_g,
                                       self.g_vars,
                                       hparams=self._hparams.opt)
        self.train_op_g_ae = get_train_op(self.loss_g_ae,
                                          self.g_vars,
                                          hparams=self._hparams.opt)
        self.train_op_d = get_train_op(self.loss_d,
                                       self.d_vars,
                                       hparams=self._hparams.opt)

        # Interface tensors
        self.losses = {
            "loss_g": self.loss_g,
            "loss_d": self.loss_d,
            "loss_g_ae": self.loss_g_ae,
            "loss_g_clas_hidden": self.loss_g_clas_hidden,
            "loss_g_clas_sentence": self.loss_g_clas_sentence,
            "loss_d_clas_hidden": self.loss_d_clas_hidden,
            "loss_d_clas_sentence": self.loss_d_clas_sentence,
        }
        self.metrics = {
            "accu_d_hidden": self.accu_d_hidden,
            "accu_d_sentence": self.accu_d_sentence,
            "accu_g_hidden": self.accu_g_hidden,
            "accu_g_sentence": self.accu_g_sentence,
            "accu_g_gdy_sentence": self.accu_g_gdy_sentence,
        }
        self.train_ops = {
            "train_op_g": self.train_op_g,
            "train_op_g_ae": self.train_op_g_ae,
            "train_op_d": self.train_op_d
        }
        self.samples = {
            "original": self.text_ids[:, 1:],
            "transferred": self.rephrase_outputs_.sample_id
        }

        self.fetches_train_d = {
            "loss_d": self.train_ops["train_op_d"],
            "loss_d_clas_hidden": self.losses["loss_d_clas_hidden"],
            "loss_d_clas_sentence": self.losses["loss_d_clas_sentence"],
            "accu_d_hidden": self.metrics["accu_d_hidden"],
            "accu_d_sentence": self.metrics["accu_d_sentence"],
        }

        tf.summary.scalar("loss_d", self.loss_d)
        tf.summary.scalar("loss_d_clas_hidden", self.loss_d_clas_hidden)
        tf.summary.scalar("loss_d_clas_sentence", self.loss_d_clas_sentence)
        tf.summary.scalar("accu_d_hidden", self.accu_d_hidden)
        tf.summary.scalar("accu_d_sentence", self.accu_d_sentence)
        tf.summary.scalar("loss_g", self.loss_g)
        tf.summary.scalar("loss_g_ae", self.loss_g_ae)
        tf.summary.scalar("loss_g_clas_hidden", self.loss_g_clas_hidden)
        tf.summary.scalar("loss_g_clas_sentence", self.loss_g_clas_sentence)
        tf.summary.scalar("accu_g_hidden", self.accu_g_hidden)
        tf.summary.scalar("accu_g_sentence", self.accu_g_sentence)
        tf.summary.scalar("accu_g_gdy_sentence", self.accu_g_gdy_sentence)
        self.merged = tf.summary.merge_all()
        self.fetches_train_g = {
            "loss_g": self.train_ops["train_op_g"],
            "loss_g_ae": self.losses["loss_g_ae"],
            "loss_g_clas_hidden": self.losses["loss_g_clas_hidden"],
            "loss_g_clas_sentence": self.losses["loss_g_clas_sentence"],
            "accu_g_hidden": self.metrics["accu_g_hidden"],
            "accu_g_sentence": self.metrics["accu_g_sentence"],
            "accu_g_gdy_sentence": self.metrics["accu_g_gdy_sentence"],
            "merged": self.merged,
        }
        fetches_eval = {
            "batch_size": get_batch_size(self.text_ids),
            "merged": self.merged,
        }
        fetches_eval.update(self.losses)
        fetches_eval.update(self.metrics)
        fetches_eval.update(self.samples)
        self.fetches_eval = fetches_eval
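A note on the separate g_vars / d_vars lists above: each get_train_op call receives its own variable list, so a discriminator step leaves generator weights untouched and vice versa. A self-contained TensorFlow 1.x sketch of the same idea with plain optimizers and toy variables (not the model's own; Texar's get_train_op restricts updates to the passed variables in a similar way):

import tensorflow as tf

g_w = tf.Variable(1.0, name="g_w")   # stands in for the generator variables
d_w = tf.Variable(1.0, name="d_w")   # stands in for the discriminator variables
loss_g = tf.square(g_w - d_w)        # toy losses
loss_d = tf.square(d_w + g_w)

# Restricting var_list is what keeps the two updates independent.
train_op_g = tf.train.AdamOptimizer(0.1).minimize(loss_g, var_list=[g_w])
train_op_d = tf.train.AdamOptimizer(0.1).minimize(loss_d, var_list=[d_w])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    d_before = sess.run(d_w)
    sess.run(train_op_g)             # only g_w is updated
    assert sess.run(d_w) == d_before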
Example #3
def build_model(data_batch, data, step):
    batch_size, num_steps = [
        tf.shape(data_batch["x_value_text_ids"])[d] for d in range(2)
    ]
    vocab = data.vocab("y_aux")

    id2str = "<{}>".format
    bos_str, eos_str = map(id2str, (vocab.bos_token_id, vocab.eos_token_id))

    def single_bleu(ref, hypo):
        ref = [id2str(u if u != vocab.unk_token_id else -1) for u in ref]
        hypo = [id2str(u) for u in hypo]

        ref = tx.utils.strip_special_tokens(" ".join(ref),
                                            strip_bos=bos_str,
                                            strip_eos=eos_str)
        hypo = tx.utils.strip_special_tokens(" ".join(hypo), strip_eos=eos_str)

        return 0.01 * tx.evals.sentence_bleu(references=[ref], hypothesis=hypo)

    def batch_bleu(refs, hypos):
        return np.array(
            [single_bleu(ref, hypo) for ref, hypo in zip(refs, hypos)],
            dtype=np.float32,
        )

    def lambda_anneal(step_stage):

        print("==========step_stage is {}".format(step_stage))
        if step_stage <= 1:
            rec_weight = 1
        elif step_stage < 2:
            rec_weight = Config.rec_w - step_stage * 0.1
        else:
            rec_weight = Config.rec_w  # matches the tf.case schedule used below
        return np.array(rec_weight, dtype=np.float32)

    # losses
    losses = {}

    # embedders
    embedders = {
        name: tx.modules.WordEmbedder(vocab_size=data.vocab(name).size,
                                      hparams=hparams)
        for name, hparams in Config.config_model.embedders.items()
    }

    # encoders
    y_encoder = tx.modules.BidirectionalRNNEncoder(
        hparams=Config.config_model.y_encoder)
    x_encoder = tx.modules.BidirectionalRNNEncoder(
        hparams=Config.config_model.x_encoder)

    def concat_encoder_outputs(outputs):
        return tf.concat(outputs, -1)

    def encode(ref_flag):
        y_str = y_strs[ref_flag]
        sent_ids = data_batch["{}_text_ids".format(y_str)]
        sent_embeds = embedders["y_aux"](sent_ids)
        sent_sequence_length = data_batch["{}_length".format(y_str)]
        sent_enc_outputs, _ = y_encoder(sent_embeds,
                                        sequence_length=sent_sequence_length)
        sent_enc_outputs = concat_encoder_outputs(sent_enc_outputs)

        x_str = x_strs[ref_flag]
        sd_ids = {
            field: data_batch["{}_{}_text_ids".format(x_str, field)][:, 1:-1]
            for field in x_fields
        }
        sd_embeds = tf.concat(
            [
                embedders["x_{}".format(field)](sd_ids[field])
                for field in x_fields
            ],
            axis=-1,
        )
        sd_sequence_length = (
            data_batch["{}_{}_length".format(x_str, x_fields[0])] - 2)
        sd_enc_outputs, _ = x_encoder(sd_embeds,
                                      sequence_length=sd_sequence_length)
        sd_enc_outputs = concat_encoder_outputs(sd_enc_outputs)

        return (
            sent_ids,
            sent_embeds,
            sent_enc_outputs,
            sent_sequence_length,
            sd_ids,
            sd_embeds,
            sd_enc_outputs,
            sd_sequence_length,
        )

    encode_results = [encode(ref_flag) for ref_flag in range(2)]
    (
        sent_ids,
        sent_embeds,
        sent_enc_outputs,
        sent_sequence_length,
        sd_ids,
        sd_embeds,
        sd_enc_outputs,
        sd_sequence_length,
    ) = zip(*encode_results)

    # get rnn cell
    rnn_cell = tx.core.layers.get_rnn_cell(Config.config_model.rnn_cell)

    def get_decoder(cell,
                    y__ref_flag,
                    x_ref_flag,
                    tgt_ref_flag,
                    beam_width=None):
        output_layer_params = ({
            "output_layer": tf.identity
        } if Config.copy_flag else {
            "vocab_size": vocab.size
        })

        if Config.attn_flag:  # attention
            if Config.attn_x and Config.attn_y_:
                memory = tf.concat(
                    [
                        sent_enc_outputs[y__ref_flag],
                        sd_enc_outputs[x_ref_flag]
                    ],
                    axis=1,
                )
                memory_sequence_length = None
            elif Config.attn_y_:
                memory = sent_enc_outputs[y__ref_flag]
                memory_sequence_length = sent_sequence_length[y__ref_flag]
            elif Config.attn_x:
                memory = sd_enc_outputs[x_ref_flag]
                memory_sequence_length = sd_sequence_length[x_ref_flag]
            else:
                raise Exception(
                    "Must enable at least one of Config.attn_y_ and "
                    "Config.attn_x.")
            attention_decoder = tx.modules.AttentionRNNDecoder(
                cell=cell,
                memory=memory,
                memory_sequence_length=memory_sequence_length,
                hparams=Config.config_model.attention_decoder,
                **output_layer_params)
            if not Config.copy_flag:
                return attention_decoder
            cell = (attention_decoder.cell if beam_width is None else
                    attention_decoder._get_beam_search_cell(beam_width))

        if Config.copy_flag:  # copynet
            kwargs = {
                "y__ids": sent_ids[y__ref_flag][:, 1:],
                "y__states": sent_enc_outputs[y__ref_flag][:, 1:],
                "y__lengths": sent_sequence_length[y__ref_flag] - 1,
                "x_ids": sd_ids[x_ref_flag]["value"],
                "x_states": sd_enc_outputs[x_ref_flag],
                "x_lengths": sd_sequence_length[x_ref_flag],
            }

            if tgt_ref_flag is not None:
                kwargs.update({
                    "input_ids":
                    data_batch["{}_text_ids".format(
                        y_strs[tgt_ref_flag])][:, :-1]
                })

            memory_prefixes = []

            if Config.copy_y_:
                memory_prefixes.append("y_")

            if Config.copy_x:
                memory_prefixes.append("x")

            if beam_width is not None:
                kwargs = {
                    name: tile_batch(value, beam_width)
                    for name, value in kwargs.items()
                }

            def get_get_copy_scores(memory_ids_states_lengths, output_size):
                memory_copy_states = [
                    tf.layers.dense(
                        memory_states,
                        units=output_size,
                        activation=None,
                        use_bias=False,
                    ) for _, memory_states, _ in memory_ids_states_lengths
                ]

                def get_copy_scores(query, coverities=None):
                    ret = []

                    if Config.copy_y_:
                        memory = memory_copy_states[len(ret)]
                        if coverities is not None:
                            memory = memory + tf.layers.dense(
                                coverities[len(ret)],
                                units=output_size,
                                activation=None,
                                use_bias=False,
                            )
                        memory = tf.nn.tanh(memory)
                        ret_y_ = tf.einsum("bim,bm->bi", memory, query)
                        ret.append(ret_y_)

                    if Config.copy_x:
                        memory = memory_copy_states[len(ret)]
                        if coverities is not None:
                            memory = memory + tf.layers.dense(
                                coverities[len(ret)],
                                units=output_size,
                                activation=None,
                                use_bias=False,
                            )
                        memory = tf.nn.tanh(memory)
                        ret_x = tf.einsum("bim,bm->bi", memory, query)
                        ret.append(ret_x)

                    return ret

                return get_copy_scores

            coverity_dim = (Config.config_model.coverage_state_dim
                            if Config.coverage else None)
            coverity_rnn_cell_hparams = (Config.config_model.coverage_rnn_cell
                                         if Config.coverage else None)
            cell = CopyNetWrapper(
                cell=cell,
                vocab_size=vocab.size,
                memory_ids_states_lengths=[
                    tuple(kwargs["{}_{}".format(prefix, s)]
                          for s in ("ids", "states", "lengths"))
                    for prefix in memory_prefixes
                ],
                input_ids=kwargs["input_ids"]
                if tgt_ref_flag is not None else None,
                get_get_copy_scores=get_get_copy_scores,
                coverity_dim=coverity_dim,
                coverity_rnn_cell_hparams=coverity_rnn_cell_hparams,
                disabled_vocab_size=Config.disabled_vocab_size,
                eps=Config.eps,
            )

        decoder = tx.modules.BasicRNNDecoder(
            cell=cell,
            hparams=Config.config_model.decoder,
            **output_layer_params)
        return decoder

    def get_decoder_and_outputs(cell,
                                y__ref_flag,
                                x_ref_flag,
                                tgt_ref_flag,
                                params,
                                beam_width=None):
        decoder = get_decoder(cell,
                              y__ref_flag,
                              x_ref_flag,
                              tgt_ref_flag,
                              beam_width=beam_width)
        if beam_width is None:
            ret = decoder(**params)
        else:
            ret = tx.modules.beam_search_decode(decoder_or_cell=decoder,
                                                beam_width=beam_width,
                                                **params)
        return (decoder, ) + ret

    get_decoder_and_outputs = tf.make_template("get_decoder_and_outputs",
                                               get_decoder_and_outputs)

    def teacher_forcing(cell, y__ref_flag, x_ref_flag, loss_name):
        tgt_ref_flag = x_ref_flag
        tgt_str = y_strs[tgt_ref_flag]
        sequence_length = data_batch["{}_length".format(tgt_str)] - 1
        decoder, tf_outputs, final_state, _ = get_decoder_and_outputs(
            cell,
            y__ref_flag,
            x_ref_flag,
            tgt_ref_flag,
            {
                "decoding_strategy": "train_greedy",
                "inputs": sent_embeds[tgt_ref_flag],
                "sequence_length": sequence_length,
            },
        )

        tgt_sent_ids = data_batch["{}_text_ids".format(tgt_str)][:, 1:]
        loss = tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=tgt_sent_ids,
            logits=tf_outputs.logits,
            sequence_length=sequence_length,
            average_across_batch=False,
        )
        if (Config.add_bleu_weight and y__ref_flag is not None
                and tgt_ref_flag is not None and y__ref_flag != tgt_ref_flag):
            w = tf.py_func(
                batch_bleu,
                [sent_ids[y__ref_flag], tgt_sent_ids],
                tf.float32,
                stateful=False,
                name="W_BLEU",
            )
            w.set_shape(loss.get_shape())
            loss = w * loss
        loss = tf.reduce_mean(loss, 0)

        if Config.copy_flag and Config.exact_cover_w != 0:
            sum_copy_probs = list(
                map(lambda t: tf.cast(t, tf.float32),
                    final_state.sum_copy_probs))
            memory_lengths = [
                lengths
                for _, _, lengths in decoder.cell.memory_ids_states_lengths
            ]
            exact_coverage_losses = [
                tf.reduce_mean(
                    tf.reduce_sum(
                        tx.utils.mask_sequences(tf.square(sum_copy_prob - 1.0),
                                                memory_length),
                        1,
                    )) for sum_copy_prob, memory_length in zip(
                        sum_copy_probs, memory_lengths)
            ]
            print_xe_loss_op = tf.print(loss_name, "xe loss:", loss)
            with tf.control_dependencies([print_xe_loss_op]):
                for i, exact_coverage_loss in enumerate(exact_coverage_losses):
                    print_op = tf.print(
                        loss_name,
                        "exact coverage loss {:d}:".format(i),
                        exact_coverage_loss,
                    )
                    with tf.control_dependencies([print_op]):
                        loss += Config.exact_cover_w * exact_coverage_loss

        losses[loss_name] = loss

        return decoder, tf_outputs, loss

    def beam_searching(cell, y__ref_flag, x_ref_flag, beam_width):
        start_tokens = (tf.ones_like(data_batch["y_aux_length"]) *
                        vocab.bos_token_id)
        end_token = vocab.eos_token_id

        decoder, bs_outputs, _, _ = get_decoder_and_outputs(
            cell,
            y__ref_flag,
            x_ref_flag,
            None,
            {
                "embedding":
                embedders["y_aux"],
                "start_tokens":
                start_tokens,
                "end_token":
                end_token,
                "max_decoding_length":
                Config.config_train.infer_max_decoding_length,
            },
            beam_width=Config.config_train.infer_beam_width,
        )

        return decoder, bs_outputs

    decoder, tf_outputs, loss = teacher_forcing(rnn_cell, 1, 0, "MLE")
    rec_decoder, _, rec_loss = teacher_forcing(rnn_cell, 1, 1, "REC")
    rec_weight = Config.rec_w

    step_stage = tf.cast(step, tf.float32) / tf.constant(800.0)
    rec_weight = tf.case(
        [
            (
                tf.less_equal(step_stage, tf.constant(1.0)),
                lambda: tf.constant(1.0),
            ),
            (tf.greater(step_stage, tf.constant(2.0)), lambda: Config.rec_w),
        ],
        default=lambda: tf.constant(1.0) - (step_stage - 1) *
        (1 - Config.rec_w),
    )
    joint_loss = (1 - rec_weight) * loss + rec_weight * rec_loss
    losses["joint"] = joint_loss

    tiled_decoder, bs_outputs = beam_searching(
        rnn_cell, 1, 0, Config.config_train.infer_beam_width)

    train_ops = {
        name: get_train_op(losses[name],
                           hparams=Config.config_train.train[name])
        for name in Config.config_train.train
    }

    return train_ops, bs_outputs
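For reference, the tf.case above anneals rec_weight piecewise: 1.0 for the first 800 steps (step_stage <= 1), a linear ramp down between steps 800 and 1600, and Config.rec_w afterwards; joint_loss = (1 - rec_weight) * loss + rec_weight * rec_loss therefore starts out as pure reconstruction. A plain-Python sketch of the same schedule (rec_w = 0.8 is an illustrative value, not taken from the original Config):

def rec_weight_schedule(step, rec_w=0.8, stage_size=800.0):
    step_stage = step / stage_size
    if step_stage <= 1.0:
        return 1.0                                   # reconstruction only
    if step_stage > 2.0:
        return rec_w                                 # final mixing weight
    return 1.0 - (step_stage - 1.0) * (1.0 - rec_w)  # linear ramp in between

for step in (0, 800, 1200, 1600, 2400):
    print(step, rec_weight_schedule(step))
# 0 -> 1.0, 800 -> 1.0, 1200 -> 0.9, 1600 -> 0.8, 2400 -> 0.8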
Example #4
    def _build_model(self, inputs, vocab, gamma, lambda_g):
        """Builds the model.
        """
        embedder = WordEmbedder(
            vocab_size=vocab.size,
            hparams=self._hparams.embedder)
        encoder = UnidirectionalRNNEncoder(hparams=self._hparams.encoder)

        # text_ids for encoder, with BOS token removed
        enc_text_ids = inputs['text_ids'][:, 1:]
        enc_outputs, final_state = encoder(embedder(enc_text_ids),
                                           sequence_length=inputs['length']-1)
        z = final_state[:, self._hparams.dim_c:]

        # Encodes label
        label_connector = MLPTransformConnector(self._hparams.dim_c)

        # Gets the sentence representation: h = (c, z)
        labels0 = tf.to_float(tf.reshape(inputs['labels0'], [-1, 1]))
        labels1 = tf.to_float(tf.reshape(inputs['labels1'], [-1, 1]))
        labels2 = tf.to_float(tf.reshape(inputs['labels2'], [-1, 1]))
        labels3 = tf.to_float(tf.reshape(inputs['labels3'], [-1, 1]))
        labels = tf.concat([labels0, labels1, labels2, labels3], axis = 1)
        print('labels', labels)
        sys.stdout.flush()
        c = label_connector(labels)
        c_ = label_connector(1 - labels)
        h = tf.concat([c, z], 1)
        h_ = tf.concat([c_, z], 1)

        # Teacher-force decoding and the auto-encoding loss for G
        decoder = AttentionRNNDecoder(
            memory=enc_outputs,
            memory_sequence_length=inputs['length']-1,
            cell_input_fn=lambda inputs, attention: inputs,
            vocab_size=vocab.size,
            hparams=self._hparams.decoder)

        connector = MLPTransformConnector(decoder.state_size)

        g_outputs, _, _ = decoder(
            initial_state=connector(h), inputs=inputs['text_ids'],
            embedding=embedder, sequence_length=inputs['length']-1)

        print('labels shape', inputs['text_ids'][:, 1:], 'logits shape', g_outputs.logits)
        print(inputs['length'] - 1)
        loss_g_ae = tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=inputs['text_ids'][:, 1:],
            logits=g_outputs.logits,
            sequence_length=inputs['length']-1,
            average_across_timesteps=True,
            sum_over_timesteps=False)

        # Gumbel-softmax decoding, used in training
        start_tokens = tf.ones_like(inputs['labels0']) * vocab.bos_token_id
        end_token = vocab.eos_token_id
        gumbel_helper = GumbelSoftmaxEmbeddingHelper(
            embedder.embedding, start_tokens, end_token, gamma)

        soft_outputs_, _, soft_length_ = decoder(
            helper=gumbel_helper, initial_state=connector(h_))

        print(g_outputs, soft_outputs_)

        # Greedy decoding, used in eval
        outputs_, _, length_ = decoder(
            decoding_strategy='infer_greedy', initial_state=connector(h_),
            embedding=embedder, start_tokens=start_tokens, end_token=end_token)
        # Creates classifier
        classifier0 = Conv1DClassifier(hparams=self._hparams.classifier)
        classifier1 = Conv1DClassifier(hparams=self._hparams.classifier)
        classifier2 = Conv1DClassifier(hparams=self._hparams.classifier)
        classifier3 = Conv1DClassifier(hparams=self._hparams.classifier)
        clas_embedder = WordEmbedder(vocab_size=vocab.size,
                                     hparams=self._hparams.embedder)

        clas_logits, clas_preds = self._high_level_classifier(
            [classifier0, classifier1, classifier2, classifier3],
            clas_embedder, inputs, vocab, gamma, lambda_g,
            inputs['text_ids'][:, 1:], None, inputs['length'] - 1)
        loss_d_clas = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.to_float(labels), logits=clas_logits)
        loss_d_clas = tf.reduce_mean(loss_d_clas)
        accu_d = tx.evals.accuracy(labels, preds=clas_preds)

        # Classification loss for the generator, based on soft samples
        # soft_logits, soft_preds = classifier(
        #     inputs=clas_embedder(soft_ids=soft_outputs_.sample_id),
        #     sequence_length=soft_length_)
        soft_logits, soft_preds = self._high_level_classifier(
            [classifier0, classifier1, classifier2, classifier3],
            clas_embedder, inputs, vocab, gamma, lambda_g,
            None, soft_outputs_.sample_id, soft_length_)
        print(soft_logits.shape, soft_preds.shape)
        loss_g_clas = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.to_float(1-labels), logits=soft_logits)
        loss_g_clas = tf.reduce_mean(loss_g_clas)

        # Accuracy on soft samples, for training progress monitoring
        accu_g = tx.evals.accuracy(labels=1-labels, preds=soft_preds)

        # Accuracy on greedy-decoded samples, for training progress monitoring
        # _, gdy_preds = classifier(
        #     inputs=clas_embedder(ids=outputs_.sample_id),
        #     sequence_length=length_)
        _, gdy_preds = self._high_level_classifier(
            [classifier0, classifier1, classifier2, classifier3],
            clas_embedder, inputs, vocab, gamma, lambda_g,
            outputs_.sample_id, None, length_)
        print(gdy_preds.shape)
        accu_g_gdy = tx.evals.accuracy(
            labels=1-labels, preds=gdy_preds)

        # Aggregates losses
        loss_g = loss_g_ae + lambda_g * loss_g_clas
        loss_d = loss_d_clas

        # Creates optimizers
        g_vars = collect_trainable_variables(
            [embedder, encoder, label_connector, connector, decoder])
        d_vars = collect_trainable_variables(
            [clas_embedder, classifier0, classifier1, classifier2, classifier3])

        train_op_g = get_train_op(
            loss_g, g_vars, hparams=self._hparams.opt)
        train_op_g_ae = get_train_op(
            loss_g_ae, g_vars, hparams=self._hparams.opt)
        train_op_d = get_train_op(
            loss_d, d_vars, hparams=self._hparams.opt)

        # Interface tensors
        self.predictions = {
            "predictions": clas_preds,
            "ground_truth": labels
        }
        self.losses = {
            "loss_g": loss_g,
            "loss_g_ae": loss_g_ae,
            "loss_g_clas": loss_g_clas,
            "loss_d": loss_d_clas
        }
        self.metrics = {
            "accu_d": accu_d,
            "accu_g": accu_g,
            "accu_g_gdy": accu_g_gdy,
        }
        self.train_ops = {
            "train_op_g": train_op_g,
            "train_op_g_ae": train_op_g_ae,
            "train_op_d": train_op_d
        }
        self.samples = {
            "original": inputs['text_ids'][:, 1:],
            "transferred": outputs_.sample_id
        }

        self.fetches_train_g = {
            "loss_g": self.train_ops["train_op_g"],
            "loss_g_ae": self.losses["loss_g_ae"],
            "loss_g_clas": self.losses["loss_g_clas"],
            "accu_g": self.metrics["accu_g"],
            "accu_g_gdy": self.metrics["accu_g_gdy"],
        }
        self.fetches_train_d = {
            "loss_d": self.train_ops["train_op_d"],
            "accu_d": self.metrics["accu_d"]
        }
        fetches_eval = {"batch_size": get_batch_size(inputs['text_ids'])}
        fetches_eval.update(self.losses)
        fetches_eval.update(self.metrics)
        fetches_eval.update(self.samples)
        fetches_eval.update(self.predictions)
        self.fetches_eval = fetches_eval
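The four binary attribute labels in this example are stacked into one vector; the discriminator loss targets labels, while the generator's classification loss targets the flipped vector 1 - labels. A self-contained TensorFlow 1.x sketch of that loss pattern (the logits and labels are made-up values, not model outputs):

import tensorflow as tf

labels = tf.constant([[1., 0., 1., 0.]])          # one example, 4 attributes
logits = tf.constant([[2.0, -1.5, 0.3, -0.7]])    # toy classifier outputs

# Discriminator objective: match the true attributes.
loss_d = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))
# Generator objective: push the classifier toward the flipped attributes.
loss_g = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=1. - labels, logits=logits))

with tf.Session() as sess:
    print(sess.run([loss_d, loss_g]))   # loss_d is small here, loss_g is large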
Example #5
    def _build_model(self, inputs, vocab, finputs, minputs, gamma):
        """Builds the model.
        """
        self.inputs = inputs
        self.finputs = finputs
        self.minputs = minputs
        self.vocab = vocab

        self.embedder = WordEmbedder(vocab_size=self.vocab.size,
                                     hparams=self._hparams.embedder)
        # maybe later have to try BidirectionalLSTMEncoder
        self.encoder = UnidirectionalRNNEncoder(
            hparams=self._hparams.encoder)  #GRU cell

        # text_ids for encoder, with BOS(begin of sentence) token removed
        self.enc_text_ids = self.inputs['text_ids'][:, 1:]
        self.enc_outputs, self.final_state = self.encoder(
            self.embedder(self.enc_text_ids),
            sequence_length=self.inputs['length'] - 1)

        h = self.final_state

        # Teacher-force decoding and the auto-encoding loss for G
        self.decoder = AttentionRNNDecoder(
            memory=self.enc_outputs,
            memory_sequence_length=self.inputs['length'] - 1,
            cell_input_fn=lambda inputs, attention: inputs,
            #default: lambda inputs, attention: tf.concat([inputs, attention], -1), which concatenates regular RNN cell inputs with attention.
            vocab_size=self.vocab.size,
            hparams=self._hparams.decoder)

        self.connector = MLPTransformConnector(self.decoder.state_size)

        self.g_outputs, _, _ = self.decoder(
            initial_state=self.connector(h),
            inputs=self.inputs['text_ids'],
            embedding=self.embedder,
            sequence_length=self.inputs['length'] - 1)

        self.loss_g_ae = tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=self.inputs['text_ids'][:, 1:],
            logits=self.g_outputs.logits,
            sequence_length=self.inputs['length'] - 1,
            average_across_timesteps=True,
            sum_over_timesteps=False)

        # Greedy decoding, used in eval (and RL training)
        start_tokens = tf.ones_like(
            self.inputs['labels']) * self.vocab.bos_token_id
        end_token = self.vocab.eos_token_id
        self.outputs, _, length = self.decoder(
            # maybe try switching this to "infer_sample" later to see the effect
            decoding_strategy='infer_greedy',
            initial_state=self.connector(h),
            embedding=self.embedder,
            start_tokens=start_tokens,
            end_token=end_token)

        # Creates optimizers
        self.g_vars = collect_trainable_variables(
            [self.embedder, self.encoder, self.connector, self.decoder])
        self.train_op_g_ae = get_train_op(self.loss_g_ae,
                                          self.g_vars,
                                          hparams=self._hparams.opt)

        # Interface tensors
        self.samples = {
            "batch_size": get_batch_size(self.inputs['text_ids']),
            "original": self.inputs['text_ids'][:, 1:],
            "transferred": self.outputs.sample_id  #outputs 是infer_greedy的结果
        }

        ############################ female sentiment regression model
        # Only the convnet is used for now and its effect is unclear; could later
        # try RNN decoding to check the regression accuracy, or combine the two
        # (concat into a single vector).
        self.fconvnet = Conv1DNetwork(
            hparams=self._hparams.convnet
        )  #[batch_size, time_steps, embedding_dim] (default input)
        #convnet = Conv1DNetwork()
        self.freg_embedder = WordEmbedder(
            vocab_size=self.vocab.size, hparams=self._hparams.embedder
        )  #(64, 26, 100) (output shape of clas_embedder(ids=inputs['text_ids'][:, 1:]))
        self.fconv_output = self.fconvnet(inputs=self.freg_embedder(
            ids=self.finputs['text_ids'][:, 1:]))  # (64, 128); TODO: handle finputs later!
        p = {"type": "Dense", "kwargs": {'units': 1}}
        self.fdense_layer = tx.core.layers.get_layer(hparams=p)
        self.freg_output = self.fdense_layer(inputs=self.fconv_output)
        '''
        # alternative to consider:
        self.fenc_text_ids = self.finputs['text_ids'][:, 1:]
        self.fencoder = UnidirectionalRNNEncoder(hparams=self._hparams.encoder) #GRU cell
        self.fenc_outputs, self.ffinal_state = self.fencoder(self.freg_embedder(self.fenc_text_ids),sequence_length=self.finputs['length']-1)
        self.freg_output = self.fdense_layer(inputs = tf.concat([self.fconv_output, self.ffinal_state], -1))
        '''

        self.fprediction = tf.reshape(self.freg_output, [-1])
        self.fground_truth = tf.to_float(self.finputs['labels'])

        self.floss_reg_single = tf.pow(
            self.fprediction - self.fground_truth,
            2)  # per-example loss; can later be used in RL to update a whole batch
        self.floss_reg_batch = tf.reduce_mean(
            self.floss_reg_single)  # mean loss over a batch

        #self.freg_vars = collect_trainable_variables([self.freg_embedder, self.fconvnet, self.fencoder, self.fdense_layer])
        self.freg_vars = collect_trainable_variables(
            [self.freg_embedder, self.fconvnet, self.fdense_layer])
        self.ftrain_op_d = get_train_op(self.floss_reg_batch,
                                        self.freg_vars,
                                        hparams=self._hparams.opt)

        self.freg_sample = {
            "fprediction": self.fprediction,
            "fground_truth": self.fground_truth,
            "fsent": self.finputs['text_ids'][:, 1:]
        }

        ############################ male sentiment regression model
        self.mconvnet = Conv1DNetwork(
            hparams=self._hparams.convnet
        )  #[batch_size, time_steps, embedding_dim] (default input)
        #convnet = Conv1DNetwork()
        self.mreg_embedder = WordEmbedder(
            vocab_size=self.vocab.size, hparams=self._hparams.embedder
        )  #(64, 26, 100) (output shape of clas_embedder(ids=inputs['text_ids'][:, 1:]))
        self.mconv_output = self.mconvnet(inputs=self.mreg_embedder(
            ids=self.minputs['text_ids'][:, 1:]))  #(64, 128)
        p = {"type": "Dense", "kwargs": {'units': 1}}
        self.mdense_layer = tx.core.layers.get_layer(hparams=p)
        self.mreg_output = self.mdense_layer(inputs=self.mconv_output)
        '''
        # alternative to consider:
        self.menc_text_ids = self.minputs['text_ids'][:, 1:]
        self.mencoder = UnidirectionalRNNEncoder(hparams=self._hparams.encoder) #GRU cell
        self.menc_outputs, self.mfinal_state = self.mencoder(self.mreg_embedder(self.menc_text_ids),sequence_length=self.minputs['length']-1)
        self.mreg_output = self.mdense_layer(inputs = tf.concat([self.mconv_output, self.mfinal_state], -1))
        '''

        self.mprediction = tf.reshape(self.mreg_output, [-1])
        self.mground_truth = tf.to_float(self.minputs['labels'])

        self.mloss_reg_single = tf.pow(
            self.mprediction - self.mground_truth,
            2)  # per-example loss; can later be used in RL to update a whole batch
        self.mloss_reg_batch = tf.reduce_mean(
            self.mloss_reg_single)  # mean loss over a batch

        #self.mreg_vars = collect_trainable_variables([self.mreg_embedder, self.mconvnet, self.mencoder, self.mdense_layer])
        self.mreg_vars = collect_trainable_variables(
            [self.mreg_embedder, self.mconvnet, self.mdense_layer])
        self.mtrain_op_d = get_train_op(self.mloss_reg_batch,
                                        self.mreg_vars,
                                        hparams=self._hparams.opt)

        self.mreg_sample = {
            "mprediction": self.mprediction,
            "mground_truth": self.mground_truth,
            "msent": self.minputs['text_ids'][:, 1:]
        }

        ###### get self.pre_dif when doing RL training (for transferred sents)
        ### pass to female regression model
        self.RL_fconv_output = self.fconvnet(inputs=self.freg_embedder(
            ids=self.outputs.sample_id))  # (64, 128); TODO: handle finputs later!
        self.RL_freg_output = self.fdense_layer(inputs=self.RL_fconv_output)
        self.RL_fprediction = tf.reshape(self.RL_freg_output, [-1])
        ### pass to male regression model
        self.RL_mconv_output = self.mconvnet(inputs=self.mreg_embedder(
            ids=self.outputs.sample_id))  # (64, 128); TODO: handle finputs later!
        self.RL_mreg_output = self.mdense_layer(inputs=self.RL_mconv_output)
        self.RL_mprediction = tf.reshape(self.RL_mreg_output, [-1])

        self.pre_dif = tf.abs(self.RL_fprediction - self.RL_mprediction)

        ###### get self.Ypre_dif for original sents
        ### pass to female regression model
        self.YRL_fconv_output = self.fconvnet(inputs=self.freg_embedder(
            ids=self.inputs['text_ids'][:, 1:]))  # (64, 128); TODO: handle finputs later!
        self.YRL_freg_output = self.fdense_layer(inputs=self.YRL_fconv_output)
        self.YRL_fprediction = tf.reshape(self.YRL_freg_output, [-1])
        ### pass to male regression model
        self.YRL_mconv_output = self.mconvnet(inputs=self.mreg_embedder(
            ids=self.inputs['text_ids'][:, 1:]))  # (64, 128); TODO: handle finputs later!
        self.YRL_mreg_output = self.mdense_layer(inputs=self.YRL_mconv_output)
        self.YRL_mprediction = tf.reshape(self.YRL_mreg_output, [-1])

        self.Ypre_dif = tf.abs(self.YRL_fprediction - self.YRL_mprediction)

        ######################## RL training
        '''
        def fil(elem):
            return tf.where(elem > 1.3, tf.minimum(elem,3), 0)
        def fil_pushsmall(elem):
            return tf.add(tf.where(elem <0.5, 1, 0),tf.where(elem>1.5,-0.5*elem,0))
        '''
        '''
        # shrink the prediction gap
        def fil1(elem):
            return tf.where(elem<0.5,1.0,0.0)
        def fil2(elem):
            return tf.where(elem>1.5,-0.5*elem,0.0)
        '''

        # widen the prediction gap
        def fil1(elem):
            return tf.where(elem < 0.5, -0.01, 0.0)

        def fil2(elem):
            return tf.where(elem > 1.3, elem, 0.0)

        # Shape (batch_size, time_steps): the loss of each timestep of each
        # sample in the batch.
        self.beginning_loss_g_RL2 = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=self.outputs.sample_id,
            logits=self.outputs.logits)
        self.middle_loss_g_RL2 = tf.reduce_sum(
            self.beginning_loss_g_RL2, axis=1
        )  # (batch_size,): per-sentence loss (summed over time_steps, not over the batch)

        #trivial "RL" training with all weight set to 1
        #final_loss_g_RL2 = tf.reduce_sum(self.middle_loss_g_RL2)

        #RL training
        self.filtered = tf.add(tf.map_fn(fil1, self.pre_dif),
                               tf.map_fn(fil2, self.pre_dif))
        self.updated_loss_per_sent = tf.multiply(
            self.filtered,
            self.middle_loss_g_RL2)  # haven't set a threshold for the weight update
        self.updated_loss_per_batch = tf.reduce_sum(
            self.updated_loss_per_sent)
        # Open question: ideally each sentence's loss would be updated
        # individually, but train_updated raises an error on a non-scalar loss,
        # so only the sum can be optimized; is that equivalent to updating each
        # sentence's loss?

        self.vars_updated = collect_trainable_variables(
            [self.connector, self.decoder])
        self.train_updated = get_train_op(self.updated_loss_per_batch,
                                          self.vars_updated,
                                          hparams=self._hparams.opt)
        self.train_updated_interface = {
            "pre_dif": self.pre_dif,
            "updated_loss_per_sent": self.updated_loss_per_sent,
            "updated_loss_per_batch": self.updated_loss_per_batch,
        }

        ### Train AE and RL together
        self.loss_AERL = gamma * self.updated_loss_per_batch + self.loss_g_ae
        self.vars_AERL = collect_trainable_variables(
            [self.connector, self.decoder])
        self.train_AERL = get_train_op(self.loss_AERL,
                                       self.vars_AERL,
                                       hparams=self._hparams.opt)
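The fil1/fil2 filters above turn the female-male prediction gap into a per-sentence weight (-0.01 when the gap is below 0.5, the gap itself when it exceeds 1.3, zero otherwise) that rescales each sentence's sequence loss before the final sum. An elementwise TensorFlow 1.x sketch of that weighting with toy values, using vectorized tf.where instead of tf.map_fn for brevity:

import tensorflow as tf

pre_dif = tf.constant([0.2, 1.0, 2.0])     # |female pred - male pred| per sentence
sent_loss = tf.constant([3.0, 3.0, 3.0])   # toy per-sentence cross-entropy

weights = (tf.where(pre_dif < 0.5, -0.01 * tf.ones_like(pre_dif),
                    tf.zeros_like(pre_dif)) +
           tf.where(pre_dif > 1.3, pre_dif, tf.zeros_like(pre_dif)))
weighted_loss = tf.reduce_sum(weights * sent_loss)

with tf.Session() as sess:
    print(sess.run(weights))        # [-0.01  0.  2.]
    print(sess.run(weighted_loss))  # -0.03 + 0.0 + 6.0 = 5.97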
Example #6
    def _get_loss_train_op(self):
        # Aggregates losses
        self.loss_g = (self.loss_g_ae +
                       self.lambda_t_graph * self.loss_g_clas_graph +
                       self.lambda_t_sentence * self.loss_g_clas_sentence)
        # possible ablation: SGT-I, CGT-I, SGT-CGT-I, c-clas-g-only, c-clas-s-only
        if self.ablation == 'c-clas-g-only':
            self.loss_d = self.loss_d_clas_graph
        elif self.ablation == 'c-clas-s-only':
            self.loss_d = self.loss_d_clas_sentence
        else:
            self.loss_d = self.loss_d_clas_graph + self.loss_d_clas_sentence

        # Creates optimizers
        self.g_vars = collect_trainable_variables([
            self.embedder, self.self_graph_encoder, self.label_connector,
            self.cross_graph_encoder, self.rephrase_encoder,
            self.rephrase_decoder
        ])
        self.d_vars = collect_trainable_variables([
            self.clas_embedder, self.classifier_graph, self.classifier_sentence
        ])

        self.train_op_g = get_train_op(self.loss_g,
                                       self.g_vars,
                                       hparams=self._hparams.opt)
        self.train_op_g_ae = get_train_op(self.loss_g_ae,
                                          self.g_vars,
                                          hparams=self._hparams.opt)
        self.train_op_d = get_train_op(self.loss_d,
                                       self.d_vars,
                                       hparams=self._hparams.opt)

        # Interface tensors
        self.losses = {
            "loss_g": self.loss_g,
            "loss_d": self.loss_d,
            "loss_g_ae": self.loss_g_ae,
            "loss_g_clas_graph": self.loss_g_clas_graph,
            "loss_g_clas_sentence": self.loss_g_clas_sentence,
            "loss_d_clas_graph": self.loss_d_clas_graph,
            "loss_d_clas_sentence": self.loss_d_clas_sentence,
        }
        self.metrics = {
            "accu_d_graph": self.accu_d_graph,
            "accu_d_sentence": self.accu_d_sentence,
            "accu_g_graph": self.accu_g_graph,
            "accu_g_sentence": self.accu_g_sentence,
            "accu_g_gdy_sentence": self.accu_g_gdy_sentence
        }
        self.train_ops = {
            "train_op_g": self.train_op_g,
            "train_op_g_ae": self.train_op_g_ae,
            "train_op_d": self.train_op_d
        }
        self.samples = {
            "original": self.text_ids[:, 1:],
            "transferred": self.rephrase_outputs_.sample_id
        }

        self.fetches_train_g = {
            "loss_g": self.train_ops["train_op_g"],
            "loss_g_ae": self.losses["loss_g_ae"],
            "loss_g_clas_graph": self.losses["loss_g_clas_graph"],
            "loss_g_clas_sentence": self.losses["loss_g_clas_sentence"],
            "accu_g_graph": self.metrics["accu_g_graph"],
            "accu_g_sentence": self.metrics["accu_g_sentence"],
            "accu_g_gdy_sentence": self.metrics["accu_g_gdy_sentence"]
            # 'adjs': self.adjs,
            # 'identities': self.identities
        }
        self.fetches_train_d = {
            "loss_d": self.train_ops["train_op_d"],
            "loss_d_clas_graph": self.losses["loss_d_clas_graph"],
            "loss_d_clas_sentence": self.losses["loss_d_clas_sentence"],
            "accu_d_graph": self.metrics["accu_d_graph"],
            "accu_d_sentence": self.metrics["accu_d_sentence"]
            # 'adjs': self.adjs,
            # 'identities': self.identities
        }
        fetches_eval = {"batch_size": get_batch_size(self.text_ids)}
        fetches_eval.update(self.losses)
        fetches_eval.update(self.metrics)
        fetches_eval.update(self.samples)
        self.fetches_eval = fetches_eval
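One detail shared by the fetch dictionaries in these examples: passing a dict to session.run returns a dict with the same keys, so a single call both applies the parameter update (the "loss_g" / "loss_d" keys map to the train ops) and fetches the monitored scalars. A minimal, self-contained TensorFlow 1.x illustration with a toy loss (note that a raw minimize() op evaluates to None, whereas Texar's get_train_op may return a loss-valued tensor; the fetching pattern is the same either way):

import tensorflow as tf

w = tf.Variable(3.0)
loss = tf.square(w - 1.0)
train_op = tf.train.AdamOptimizer(0.1).minimize(loss)

fetches_train = {
    "train": train_op,    # running the dict applies the update
    "loss": loss,         # monitored scalar
    "w": w,
}

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(3):
        vals = sess.run(fetches_train)   # dict in, dict out
        print(step, vals["loss"], vals["w"])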
Example #7
    def _get_loss_train_op(self):
        # Aggregates loss
        loss_rephraser = (tf.get_collection('loss_rephraser_list')[0] +
                          tf.get_collection('loss_rephraser_list')[1] +
                          tf.get_collection('loss_rephraser_list')[2]) / 3.
        w_recon = 1.0
        w_fine = 0.5
        w_xx2 = 0.0
        # + w_xx2 * self.loss_xx2 is currently disabled; check and tune the
        # term weights one order of magnitude at a time.
        self.loss = (loss_rephraser + w_fine * self.loss_fine +
                     w_recon * self.loss_mask_recon)

        # Creates optimizers
        self.vars = collect_trainable_variables([
            self.transformer_encoder, self.word_embedder,
            self.self_graph_encoder, self.downmlp, self.PRelu,
            self.rephrase_encoder, self.rephrase_decoder
        ])

        # Train Op
        self.train_op_pre = get_train_op(self.loss,
                                         self.vars,
                                         hparams=self._hparams.opt)  # learning_rate=self.lr

        # Interface tensors
        self.losses = {
            "loss": self.loss,
            "loss_rephraser": loss_rephraser,
            "loss_fine": self.loss_fine,
            "loss_mask_recon": self.loss_mask_recon,
            "loss_xx2": self.loss_xx2,
        }
        self.metrics = {}
        self.train_ops = {
            "train_op_pre": self.train_op_pre,
        }
        self.samples = {
            "transferred_yy1_gt": self.text_ids_yy1,
            "transferred_yy1_pred": tf.get_collection('yy_pred_list')[0],
            "transferred_yy2_gt": self.text_ids_yy2,
            "transferred_yy2_pred": tf.get_collection('yy_pred_list')[1],
            "transferred_yy3_gt": self.text_ids_yy3,
            "transferred_yy3_pred": tf.get_collection('yy_pred_list')[2],
            "origin_y1": self.text_ids_y1,
            "origin_y2": self.text_ids_y2,
            "origin_y3": self.text_ids_y3,
            "x1x2": self.x1x2,
            "x1xx2": self.x1xx2
        }

        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("loss_rephraser", loss_rephraser)
        tf.summary.scalar("loss_fine", self.loss_fine)
        tf.summary.scalar("loss_mask_recon", self.loss_mask_recon)
        tf.summary.scalar("loss_xx2", self.loss_xx2)
        self.merged = tf.summary.merge_all()
        self.fetches_train_pre = {
            "loss": self.train_ops["train_op_pre"],
            "loss_rephraser": self.losses["loss_rephraser"],
            "loss_fine": self.losses["loss_fine"],
            "loss_mask_recon": self.losses["loss_mask_recon"],
            "loss_xx2": self.losses["loss_xx2"],
            "merged": self.merged,
        }
        fetches_eval = {
            "batch_size": get_batch_size(self.x1x2yx1xx2_ids),
            "merged": self.merged,
        }
        fetches_eval.update(self.losses)
        fetches_eval.update(self.metrics)
        fetches_eval.update(self.samples)
        self.fetches_eval = fetches_eval
Example #8
def build_model(data_batch, data, step):
    batch_size, num_steps = [
        tf.shape(data_batch["x_value_text_ids"])[d] for d in range(2)
    ]
    vocab = data.vocab('y_aux')

    id2str = '<{}>'.format
    bos_str, eos_str = map(id2str, (vocab.bos_token_id, vocab.eos_token_id))

    def single_bleu(ref, hypo):
        ref = [id2str(u if u != vocab.unk_token_id else -1) for u in ref]
        hypo = [id2str(u) for u in hypo]

        ref = tx.utils.strip_special_tokens(' '.join(ref),
                                            strip_bos=bos_str,
                                            strip_eos=eos_str)
        hypo = tx.utils.strip_special_tokens(' '.join(hypo), strip_eos=eos_str)

        return 0.01 * tx.evals.sentence_bleu(references=[ref], hypothesis=hypo)

    # losses
    losses = {}

    # embedders
    embedders = {
        name: tx.modules.WordEmbedder(vocab_size=data.vocab(name).size,
                                      hparams=hparams)
        for name, hparams in config_model.embedders.items()
    }

    # encoders
    y_encoder = tx.modules.TransformerEncoder(hparams=config_model.y_encoder)
    x_encoder = tx.modules.TransformerEncoder(hparams=config_model.x_encoder)

    def concat_encoder_outputs(outputs):
        return tf.concat(outputs, -1)

    def encode(ref_flag):
        y_str = y_strs[ref_flag]
        y_ids = data_batch['{}_text_ids'.format(y_str)]
        y_embeds = embedders['y_aux'](y_ids)
        y_sequence_length = data_batch['{}_length'.format(y_str)]
        y_enc_outputs = y_encoder(y_embeds, sequence_length=y_sequence_length)
        y_enc_outputs = concat_encoder_outputs(y_enc_outputs)

        x_str = x_strs[ref_flag]
        x_ids = {
            field: data_batch['{}_{}_text_ids'.format(x_str, field)][:, 1:-1]
            for field in x_fields
        }
        x_embeds = tf.concat([
            embedders['x_{}'.format(field)](x_ids[field]) for field in x_fields
        ],
                             axis=-1)

        x_sequence_length = data_batch['{}_{}_length'.format(
            x_str, x_fields[0])] - 2
        x_enc_outputs = x_encoder(x_embeds, sequence_length=x_sequence_length)
        x_enc_outputs = concat_encoder_outputs(x_enc_outputs)

        return (y_ids, y_embeds, y_enc_outputs, y_sequence_length,
                x_ids, x_embeds, x_enc_outputs, x_sequence_length)

    encode_results = [encode(ref_flag) for ref_flag in range(2)]
    (y_ids, y_embeds, y_enc_outputs, y_sequence_length,
     x_ids, x_embeds, x_enc_outputs, x_sequence_length) = zip(*encode_results)

    # get rnn cell
    # rnn_cell = tx.core.layers.get_rnn_cell(config_model.rnn_cell)

    def get_decoder(y__ref_flag, x_ref_flag, tgt_ref_flag, beam_width=None):
        output_layer_params = ({'output_layer': tf.identity}
                               if copy_flag else {'vocab_size': vocab.size})

        if attn_flag:  # attention
            memory = tf.concat(
                [y_enc_outputs[y__ref_flag], x_enc_outputs[x_ref_flag]],
                axis=1)
            memory_sequence_length = None
            copy_memory_sequence_length = None

            tgt_embedding = tf.concat([
                tf.zeros(shape=[1, embedders['y_aux'].dim]),
                embedders['y_aux'].embedding[1:, :]
            ],
                                      axis=0)
            decoder = tx.modules.TransformerCopyDecoder(
                embedding=tgt_embedding, hparams=config_model.decoder)

        else:
            # No decoder is built when attn_flag is disabled in this variant.
            raise ValueError("attn_flag must be enabled to build the decoder.")
        return decoder

    def get_decoder_and_outputs(y__ref_flag,
                                x_ref_flag,
                                tgt_ref_flag,
                                params,
                                beam_width=None):
        decoder = get_decoder(y__ref_flag,
                              x_ref_flag,
                              tgt_ref_flag,
                              beam_width=beam_width)
        if beam_width is None:
            ret = decoder(**params)
        else:
            ret = decoder(beam_width=beam_width, **params)
        return decoder, ret

    get_decoder_and_outputs = tf.make_template('get_decoder_and_outputs',
                                               get_decoder_and_outputs)

    gamma = tf.Variable(1, dtype=tf.float32, trainable=True)
    gamma = tf.exp(tf.log(gamma))

    def teacher_forcing(y__ref_flag, x_ref_flag, loss_name):
        tgt_flag = x_ref_flag
        tgt_str = y_strs[tgt_flag]
        memory_sequence_length = tf.add(y_sequence_length[y__ref_flag] - 1,
                                        x_sequence_length[x_ref_flag])
        sequence_length = data_batch['{}_length'.format(tgt_str)] - 1

        memory = tf.concat(
            [y_enc_outputs[y__ref_flag], x_enc_outputs[x_ref_flag]],
            axis=1)  # [64 61 384]

        decoder, rets = get_decoder_and_outputs(
            y__ref_flag,
            x_ref_flag,
            tgt_flag,
            {
                'memory': memory,  #print_mem,
                'memory_sequence_length': memory_sequence_length,
                'copy_memory': x_enc_outputs[x_ref_flag],
                'copy_memory_sequence_length': x_sequence_length[x_ref_flag],
                'source_ids':
                x_ids[x_ref_flag]['value'],  #print_ids,         # source_ids
                'gamma': gamma,
                'decoding_strategy': 'train_greedy',
                'inputs': y_embeds[tgt_flag]
                [:, :-1, :],  # [:, 1:, :], target sentence embeds (ignore <BOS>)
                'alpha': config_model.alpha,
                'sequence_length': sequence_length,
                'mode': tf.estimator.ModeKeys.TRAIN
            })

        tgt_y_ids = data_batch['{}_text_ids'.format(
            tgt_str)][:, 1:]  # ground_truth ids (ignore <BOS>)
        tf_outputs = rets[0]
        gens = rets[2]
        loss = tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=tgt_y_ids,
            logits=tf_outputs.logits,
            sequence_length=data_batch['{}_length'.format(tgt_str)] - 1)
        # average_across_timesteps=True,
        # sum_over_timesteps=False)
        # loss = tf.reduce_mean(loss, 0)

        if copy_flag and FLAGS.exact_cover_w != 0:
            # sum_copy_probs = list(map(lambda t: tf.cast(t, tf.float32), final_state.sum_copy_probs))
            copy_probs = (1 - gens) * rets[1]
            sum_copy_probs = tf.reduce_sum(copy_probs, 1)
            # sum_copy_probs = tf.split(sum_copy_probs, tf.shape(sum_copy_probs)[0], axis=0)#list(map(lambda  prob: tf.cast(prob, tf.float32), tuple(tf.reduce_sum(copy_probs, 1))))  #[batch_size, len_key]
            memory_lengths = x_sequence_length[
                x_ref_flag]  #[len for len in sd_sequence_length[x_ref_flag]]
            exact_coverage_loss = \
                tf.reduce_mean(tf.reduce_sum(
                    tx.utils.mask_sequences(
                        tf.square(sum_copy_probs - 1.), memory_lengths),
                    1))
            print_xe_loss_op = tf.print(loss_name, 'xe loss:', loss)
            with tf.control_dependencies([print_xe_loss_op]):
                print_op = tf.print(loss_name, 'exact coverage loss :',
                                    exact_coverage_loss)
                with tf.control_dependencies([print_op]):
                    loss += FLAGS.exact_cover_w * exact_coverage_loss
        losses[loss_name] = loss

        return decoder, rets, loss, tgt_y_ids

    def beam_searching(y__ref_flag, x_ref_flag, beam_width):
        start_tokens = tf.ones_like(data_batch['y_aux_length']) * \
            vocab.bos_token_id
        end_token = vocab.eos_token_id
        memory_sequence_length = tf.add(y_sequence_length[y__ref_flag] - 1,
                                        x_sequence_length[x_ref_flag])
        sequence_length = data_batch['{}_length'.format(
            y_strs[y__ref_flag])] - 1

        memory = tf.concat(
            [y_enc_outputs[y__ref_flag], x_enc_outputs[x_ref_flag]], axis=1)
        source_ids = tf.concat(
            [y_ids[y__ref_flag], x_ids[x_ref_flag]['value']], axis=1)

        #decoder, (bs_outputs, seq_len)
        decoder, bs_outputs = get_decoder_and_outputs(
            y__ref_flag,
            x_ref_flag,
            None,
            {
                'memory': memory,  #print_mem,
                'memory_sequence_length': memory_sequence_length,
                'copy_memory': x_enc_outputs[x_ref_flag],
                'copy_memory_sequence_length': x_sequence_length[x_ref_flag],
                'gamma': gamma,
                'source_ids': x_ids[x_ref_flag]
                ['value'],  # source_ids,#x_ids[x_ref_flag]['entry'],        #[ batch_size, source_length]
                # 'decoding_strategy': 'infer_sample',  only for random sampling
                'alpha': config_model.alpha,
                'start_tokens': start_tokens,
                'end_token': end_token,
                'max_decoding_length': config_train.infer_max_decoding_length
            },
            beam_width=beam_width)

        return decoder, bs_outputs, sequence_length, start_tokens

    decoder, rets, loss, tgt_y_ids = teacher_forcing(1, 0, 'MLE')
    rec_decoder, _, rec_loss, _ = teacher_forcing(1, 1, 'REC')
    rec_weight = FLAGS.rec_w + FLAGS.rec_w_rate * tf.cast(step, tf.float32)
    step_stage = tf.cast(step, tf.float32) / tf.constant(600.0)
    rec_weight = tf.case(
        [
            (tf.less_equal(step_stage, tf.constant(1.0)), lambda: tf.constant(1.0)),
            (tf.greater(step_stage, tf.constant(2.0)), lambda: FLAGS.rec_w),
        ],
        default=lambda: tf.constant(1.0) - (step_stage - 1) * (1 - FLAGS.rec_w))
    joint_loss = (1 - rec_weight) * loss + rec_weight * rec_loss
    losses['joint'] = joint_loss

    tiled_decoder, bs_outputs, sequence_length, start_tokens = beam_searching(
        1, 0, config_train.infer_beam_width)

    train_ops = {
        name: get_train_op(losses[name], hparams=config_train.train[name])
        for name in config_train.train
    }

    return train_ops, bs_outputs, rets, sequence_length, tgt_y_ids, start_tokens, gamma
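The rec_weight built with tf.case above is a piecewise-linear anneal over step / 600. A minimal pure-Python sketch of the same schedule (the 600-step stage length comes from the code above; the rec_w value is an arbitrary example):

def rec_weight_schedule(step, rec_w, stage_len=600.0):
    # 1.0 for the first stage, linear decay to rec_w over the second stage,
    # rec_w afterwards -- mirrors the tf.case construct above.
    stage = step / stage_len
    if stage <= 1.0:
        return 1.0
    if stage > 2.0:
        return rec_w
    return 1.0 - (stage - 1.0) * (1.0 - rec_w)

# With rec_w = 0.2: steps 0 and 600 give 1.0, 900 gives 0.6, 1200 and beyond give 0.2.
print([round(rec_weight_schedule(s, 0.2), 2) for s in (0, 600, 900, 1200, 1800)])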
Ejemplo n.º 9
0
    def _get_loss_train_op(self):
        # Aggregates losses
        self.loss_g = (self.loss_g_ae
                       + self.lambda_g_graph * self.loss_g_clas_graph
                       + self.lambda_g_sentence * self.loss_g_clas_sentence
                       + self.loss_trans_adj + self.loss_ori_adj)
        self.loss_d = (self.loss_d_clas_graph + self.loss_d_clas_sentence
                       + self.loss_ori_adj)

        # Creates optimizers
        self.g_vars = collect_trainable_variables([
            self.embedder, self.self_graph_encoder, self.label_connector,
            self.cross_graph_encoder, self.rephrase_encoder,
            self.rephrase_decoder
        ])
        self.d_vars = collect_trainable_variables([
            self.clas_embedder, self.classifier_graph,
            self.classifier_sentence, self.adj_embedder, self.adj_encoder,
            self.conv1d_1, self.conv1d_2, self.bn1, self.conv1d_3, self.bn2,
            self.conv1d_4, self.bn3, self.conv1d_5
        ])

        self.train_op_g = get_train_op(self.loss_g,
                                       self.g_vars,
                                       hparams=self._hparams.opt)
        self.train_op_g_ae = get_train_op(self.loss_g_ae,
                                          self.g_vars,
                                          hparams=self._hparams.opt)
        self.train_op_d = get_train_op(self.loss_d,
                                       self.d_vars,
                                       hparams=self._hparams.opt)

        # Interface tensors
        self.losses = {
            "loss_g": self.loss_g,
            "loss_d": self.loss_d,
            "loss_g_ae": self.loss_g_ae,
            "loss_g_clas_graph": self.loss_g_clas_graph,
            "loss_g_clas_sentence": self.loss_g_clas_sentence,
            "loss_d_clas_graph": self.loss_d_clas_graph,
            "loss_d_clas_sentence": self.loss_d_clas_sentence,
            "loss_ori_adj": self.loss_ori_adj,
            "loss_trans_adj": self.loss_trans_adj,
        }
        self.metrics = {
            "accu_d_graph": self.accu_d_graph,
            "accu_d_sentence": self.accu_d_sentence,
            "accu_g_graph": self.accu_g_graph,
            "accu_g_sentence": self.accu_g_sentence,
            "accu_g_gdy_sentence": self.accu_g_gdy_sentence,
            "accu_ori_adj": self.accu_ori_adj,
            "accu_trans_adj": self.accu_trans_adj
        }
        self.train_ops = {
            "train_op_g": self.train_op_g,
            "train_op_g_ae": self.train_op_g_ae,
            "train_op_d": self.train_op_d
        }
        self.samples = {
            "original": self.text_ids[:, 1:],
            "transferred": self.rephrase_outputs_.sample_id
        }

        self.fetches_train_g = {
            "loss_g": self.train_ops["train_op_g"],
            "loss_g_ae": self.losses["loss_g_ae"],
            "loss_g_clas_graph": self.losses["loss_g_clas_graph"],
            "loss_g_clas_sentence": self.losses["loss_g_clas_sentence"],
            "loss_ori_adj": self.losses["loss_ori_adj"],
            "loss_trans_adj": self.losses["loss_trans_adj"],
            "accu_g_graph": self.metrics["accu_g_graph"],
            "accu_g_sentence": self.metrics["accu_g_sentence"],
            "accu_ori_adj": self.metrics["accu_ori_adj"],
            "accu_trans_adj": self.metrics["accu_trans_adj"],
            "adjs_truth": self.adjs[:, 1:, 1:],
            "adjs_preds": self.pred_trans_adjs_binary,
        }
        self.fetches_train_d = {
            "loss_d": self.train_ops["train_op_d"],
            "loss_d_clas_graph": self.losses["loss_d_clas_graph"],
            "loss_d_clas_sentence": self.losses["loss_d_clas_sentence"],
            "loss_ori_adj": self.losses["loss_ori_adj"],
            "accu_d_graph": self.metrics["accu_d_graph"],
            "accu_d_sentence": self.metrics["accu_d_sentence"],
            "accu_ori_adj": self.metrics["accu_ori_adj"],
            "adjs_truth": self.adjs[:, 1:, 1:],
            "adjs_preds": self.pred_ori_adjs_binary,
        }
        fetches_eval = {"batch_size": get_batch_size(self.text_ids)}
        fetches_eval.update(self.losses)
        fetches_eval.update(self.metrics)
        fetches_eval.update(self.samples)
        self.fetches_eval = fetches_eval
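These fetches_train_g / fetches_train_d / fetches_eval dictionaries are meant to be run by an outer loop. A hypothetical sketch of such a loop is below; sess, model, feed_dict, and the strict 1:1 alternation of discriminator and generator updates are assumptions, not part of this listing.

import tensorflow as tf

def train_epoch(sess, model, feed_dict):
    # Alternate one discriminator update and one generator update per step
    # until the input pipeline signals the end of the epoch.
    while True:
        try:
            sess.run(model.fetches_train_d, feed_dict=feed_dict)
            sess.run(model.fetches_train_g, feed_dict=feed_dict)
        except tf.errors.OutOfRangeError:
            break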
Ejemplo n.º 10
0
def build_model(data_batch, data, step):
    batch_size, num_steps = [
        tf.shape(data_batch["x_value_text_ids"])[d] for d in range(2)]
    vocab = data.vocab('y_aux')

    id2str = '<{}>'.format
    bos_str, eos_str = map(id2str, (vocab.bos_token_id, vocab.eos_token_id))

    def single_bleu(ref, hypo):
        ref = [id2str(u if u != vocab.unk_token_id else -1) for u in ref]
        hypo = [id2str(u) for u in hypo]

        ref = tx.utils.strip_special_tokens(
            ' '.join(ref), strip_bos=bos_str, strip_eos=eos_str)
        hypo = tx.utils.strip_special_tokens(
            ' '.join(hypo), strip_eos=eos_str)

        # Scaled down so the sentence-level BLEU can be used directly as a loss weight.
        return 0.01 * tx.evals.sentence_bleu(references=[ref], hypothesis=hypo)

    def batch_bleu(refs, hypos):
        return np.array(
            [single_bleu(ref, hypo) for ref, hypo in zip(refs, hypos)],
            dtype=np.float32)

    def lambda_anneal(step_stage):
        print('==========step_stage is {}'.format(step_stage))
        if step_stage <= 1:
            rec_weight = 1
        elif step_stage < 2:
            rec_weight = FLAGS.rec_w - step_stage * 0.1
        else:
            # After the second stage the weight stays at FLAGS.rec_w.
            rec_weight = FLAGS.rec_w
        return np.array(rec_weight, dtype=np.float32)

    # losses
    losses = {}

    # embedders
    embedders = {
        name: tx.modules.WordEmbedder(
            vocab_size=data.vocab(name).size, hparams=hparams)
        for name, hparams in config_model.embedders.items()}

    # encoders
    y_encoder = tx.modules.BidirectionalRNNEncoder(
        hparams=config_model.y_encoder)
    x_encoder = tx.modules.BidirectionalRNNEncoder(
        hparams=config_model.x_encoder)


    def concat_encoder_outputs(outputs):
        return tf.concat(outputs, -1)


    def encode(ref_flag):
        y_str = y_strs[ref_flag]
        sent_ids = data_batch['{}_text_ids'.format(y_str)]
        sent_embeds = embedders['y_aux'](sent_ids)
        sent_sequence_length = data_batch['{}_length'.format(y_str)]
        sent_enc_outputs, _ = y_encoder(
            sent_embeds, sequence_length=sent_sequence_length)
        sent_enc_outputs = concat_encoder_outputs(sent_enc_outputs)

        x_str = x_strs[ref_flag]
        sd_ids = {
            field: data_batch['{}_{}_text_ids'.format(x_str, field)][:, 1:-1]
            for field in x_fields}
        sd_embeds = tf.concat(
            [embedders['x_{}'.format(field)](sd_ids[field]) for field in x_fields],
            axis=-1)
        sd_sequence_length = data_batch[
                                 '{}_{}_length'.format(x_str, x_fields[0])] - 2
        sd_enc_outputs, _ = x_encoder(
            sd_embeds, sequence_length=sd_sequence_length)
        sd_enc_outputs = concat_encoder_outputs(sd_enc_outputs)

        return sent_ids, sent_embeds, sent_enc_outputs, sent_sequence_length, \
               sd_ids, sd_embeds, sd_enc_outputs, sd_sequence_length


    encode_results = [encode(ref_flag) for ref_flag in range(2)]
    sent_ids, sent_embeds, sent_enc_outputs, sent_sequence_length, \
    sd_ids, sd_embeds, sd_enc_outputs, sd_sequence_length = \
        zip(*encode_results)

    # get rnn cell
    rnn_cell = tx.core.layers.get_rnn_cell(config_model.rnn_cell)


    def get_decoder(cell, y__ref_flag, x_ref_flag, tgt_ref_flag,
                    beam_width=None):
        output_layer_params = \
            {'output_layer': tf.identity} if copy_flag else \
                {'vocab_size': vocab.size}

        if attn_flag: # attention
            if FLAGS.attn_x and FLAGS.attn_y_:
                memory = tf.concat(
                    [sent_enc_outputs[y__ref_flag],
                     sd_enc_outputs[x_ref_flag]],
                    axis=1)
                memory_sequence_length = None
            elif FLAGS.attn_y_:
                memory = sent_enc_outputs[y__ref_flag]
                memory_sequence_length = sent_sequence_length[y__ref_flag]
            elif FLAGS.attn_x:
                memory = sd_enc_outputs[x_ref_flag]
                memory_sequence_length = sd_sequence_length[x_ref_flag]
            else:
                raise ValueError(
                    "Must enable at least one of FLAGS.attn_y_ or FLAGS.attn_x.")
            attention_decoder = tx.modules.AttentionRNNDecoder(
                cell=cell,
                memory=memory,
                memory_sequence_length=memory_sequence_length,
                hparams=config_model.attention_decoder,
                **output_layer_params)
            if not copy_flag:
                return attention_decoder
            cell = attention_decoder.cell if beam_width is None else \
                attention_decoder._get_beam_search_cell(beam_width)

        if copy_flag: # copynet
            kwargs = {
                'y__ids': sent_ids[y__ref_flag][:, 1:],
                'y__states': sent_enc_outputs[y__ref_flag][:, 1:],
                'y__lengths': sent_sequence_length[y__ref_flag] - 1,
                'x_ids': sd_ids[x_ref_flag]['value'],
                'x_states': sd_enc_outputs[x_ref_flag],
                'x_lengths': sd_sequence_length[x_ref_flag],
            }

            if tgt_ref_flag is not None:
                kwargs.update({
                    'input_ids': data_batch[
                                     '{}_text_ids'.format(y_strs[tgt_ref_flag])][:, :-1]})

            memory_prefixes = []

            if FLAGS.copy_y_:
                memory_prefixes.append('y_')

            if FLAGS.copy_x:
                memory_prefixes.append('x')

            if beam_width is not None:
                kwargs = {
                    name: tile_batch(value, beam_width)
                    for name, value in kwargs.items()}

            def get_get_copy_scores(memory_ids_states_lengths, output_size):
                memory_copy_states = [
                    tf.layers.dense(
                        memory_states,
                        units=output_size,
                        activation=None,
                        use_bias=False)
                    for _, memory_states, _ in memory_ids_states_lengths]

                def get_copy_scores(query, coverities=None):
                    ret = []

                    if FLAGS.copy_y_:
                        memory = memory_copy_states[len(ret)]
                        if coverities is not None:
                            memory = memory + tf.layers.dense(
                                coverities[len(ret)],
                                units=output_size,
                                activation=None,
                                use_bias=False)
                        memory = tf.nn.tanh(memory)
                        ret_y_ = tf.einsum("bim,bm->bi", memory, query)
                        ret.append(ret_y_)

                    if FLAGS.copy_x:
                        memory = memory_copy_states[len(ret)]
                        if coverities is not None:
                            memory = memory + tf.layers.dense(
                                coverities[len(ret)],
                                units=output_size,
                                activation=None,
                                use_bias=False)
                        memory = tf.nn.tanh(memory)
                        ret_x = tf.einsum("bim,bm->bi", memory, query)
                        ret.append(ret_x)

                    if FLAGS.sd_path:
                        # Redistribute the x copy scores through the alignment
                        # matrix `match_align` from the enclosing scope; this
                        # branch assumes FLAGS.copy_x is set (it needs ret_x).
                        ret_sd_path = FLAGS.sd_path_multiplicator * \
                                      tf.einsum("bi,bij->bj", ret_x, match_align) \
                                      + FLAGS.sd_path_addend
                        ret.append(ret_sd_path)

                    return ret

                return get_copy_scores

            cell = CopyNetWrapper(
                cell=cell, vocab_size=vocab.size,
                memory_ids_states_lengths=[
                    tuple(kwargs['{}_{}'.format(prefix, s)]
                          for s in ('ids', 'states', 'lengths'))
                    for prefix in memory_prefixes],
                input_ids= \
                    kwargs['input_ids'] if tgt_ref_flag is not None else None,
                get_get_copy_scores=get_get_copy_scores,
                coverity_dim=config_model.coverage_state_dim if FLAGS.coverage else None,
                coverity_rnn_cell_hparams=config_model.coverage_rnn_cell if FLAGS.coverage else None,
                disabled_vocab_size=FLAGS.disabled_vocab_size,
                eps=FLAGS.eps)

        decoder = tx.modules.BasicRNNDecoder(
            cell=cell, hparams=config_model.decoder,
            **output_layer_params)
        return decoder

    def get_decoder_and_outputs(
            cell, y__ref_flag, x_ref_flag, tgt_ref_flag, params,
            beam_width=None):
        decoder = get_decoder(
            cell, y__ref_flag, x_ref_flag, tgt_ref_flag,
            beam_width=beam_width)
        if beam_width is None:
            ret = decoder(**params)
        else:
            ret = tx.modules.beam_search_decode(
                decoder_or_cell=decoder,
                beam_width=beam_width,
                **params)
        return (decoder,) + ret

    get_decoder_and_outputs = tf.make_template(
        'get_decoder_and_outputs', get_decoder_and_outputs)

    def teacher_forcing(cell, y__ref_flag, x_ref_flag, loss_name):
        tgt_ref_flag = x_ref_flag
        tgt_str = y_strs[tgt_ref_flag]
        sequence_length = data_batch['{}_length'.format(tgt_str)] - 1
        decoder, tf_outputs, final_state, _ = get_decoder_and_outputs(
            cell, y__ref_flag, x_ref_flag, tgt_ref_flag,
            {'decoding_strategy': 'train_greedy',
             'inputs': sent_embeds[tgt_ref_flag],
             'sequence_length': sequence_length})

        tgt_sent_ids = data_batch['{}_text_ids'.format(tgt_str)][:, 1:]
        loss = tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=tgt_sent_ids,
            logits=tf_outputs.logits,
            sequence_length=sequence_length,
            average_across_batch=False)
        if FLAGS.add_bleu_weight and y__ref_flag is not None \
                and tgt_ref_flag is not None and y__ref_flag != tgt_ref_flag:
            w = tf.py_func(
                batch_bleu, [sent_ids[y__ref_flag], tgt_sent_ids],
                tf.float32, stateful=False, name='W_BLEU')
            w.set_shape(loss.get_shape())
            loss = w * loss
        loss = tf.reduce_mean(loss, 0)

        if copy_flag and FLAGS.exact_cover_w != 0:
            sum_copy_probs = list(map(lambda t: tf.cast(t, tf.float32), final_state.sum_copy_probs))
            memory_lengths = [lengths for _, _, lengths in decoder.cell.memory_ids_states_lengths]
            exact_coverage_losses = [
                tf.reduce_mean(tf.reduce_sum(
                    tx.utils.mask_sequences(
                        tf.square(sum_copy_prob - 1.), memory_length),
                    1))
                for sum_copy_prob, memory_length in zip(sum_copy_probs, memory_lengths)]
            print_xe_loss_op = tf.print(loss_name, 'xe loss:', loss)
            with tf.control_dependencies([print_xe_loss_op]):
                for i, exact_coverage_loss in enumerate(exact_coverage_losses):
                    print_op = tf.print(loss_name, 'exact coverage loss {:d}:'.format(i), exact_coverage_loss)
                    with tf.control_dependencies([print_op]):
                        # exact_cover_w = FLAGS.exact_cover_w + FLAGS.exact_cover_w * tf.cast(step, tf.float32)
                        loss += FLAGS.exact_cover_w * exact_coverage_loss

        losses[loss_name] = loss

        return decoder, tf_outputs, loss


    def beam_searching(cell, y__ref_flag, x_ref_flag, beam_width):
        start_tokens = tf.ones_like(data_batch['y_aux_length']) * \
                       vocab.bos_token_id
        end_token = vocab.eos_token_id

        decoder, bs_outputs, _, _ = get_decoder_and_outputs(
            cell, y__ref_flag, x_ref_flag, None,
            {'embedding': embedders['y_aux'],
             'start_tokens': start_tokens,
             'end_token': end_token,
             'max_decoding_length': config_train.infer_max_decoding_length},
            beam_width=config_train.infer_beam_width)

        return decoder, bs_outputs


    def build_align():
        ref_str = ref_strs[1]
        sent_str = 'y{}'.format(ref_str)
        sent_texts = data_batch['{}_text'.format(sent_str)][:, 1:-1]
        sent_ids = data_batch['{}_text_ids'.format(sent_str)][:, 1:-1]
        #TODO: Here we simply reuse the embedder constructed above, so it is
        #shared. Construct a new one here if the alignment should be computed
        #on the fly.
        sent_embeds = embedders['y_aux'](sent_ids)
        sent_sequence_length = data_batch['{}_length'.format(sent_str)] - 2
        sent_enc_outputs, _ = y_encoder(
            sent_embeds, sequence_length=sent_sequence_length)
        sent_enc_outputs = concat_encoder_outputs(sent_enc_outputs)

        sd_field = x_fields[0]
        sd_str = 'x{}_{}'.format(ref_str, sd_field)
        sd_texts = data_batch['{}_text'.format(sd_str)][:, :-1]
        sd_ids = data_batch['{}_text_ids'.format(sd_str)]
        tgt_sd_ids = sd_ids[:, 1:]
        sd_ids = sd_ids[:, :-1]
        sd_sequence_length = data_batch['{}_length'.format(sd_str)] - 1
        sd_embedder = embedders['x_'+sd_field]

        rnn_cell = tx.core.layers.get_rnn_cell(config_model.align_rnn_cell)
        attention_decoder = tx.modules.AttentionRNNDecoder(
            cell=rnn_cell,
            memory=sent_enc_outputs,
            memory_sequence_length=sent_sequence_length,
            vocab_size=vocab.size,
            hparams=config_model.align_attention_decoder)

        tf_outputs, _, tf_sequence_length = attention_decoder(
            decoding_strategy='train_greedy',
            inputs=sd_ids,
            embedding=sd_embedder,
            sequence_length=sd_sequence_length)

        loss = tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=tgt_sd_ids,
            logits=tf_outputs.logits,
            sequence_length=sd_sequence_length)

        start_tokens = tf.ones_like(sd_sequence_length) * vocab.bos_token_id
        end_token = vocab.eos_token_id
        bs_outputs, _, _ = tx.modules.beam_search_decode(
            decoder_or_cell=attention_decoder,
            embedding=sd_embedder,
            start_tokens=start_tokens,
            end_token=end_token,
            max_decoding_length=config_train.infer_max_decoding_length,
            beam_width=config_train.infer_beam_width)

        return (sent_texts, sent_sequence_length), (sd_texts, sd_sequence_length), \
               loss, tf_outputs, bs_outputs


    decoder, tf_outputs, loss = teacher_forcing(rnn_cell, 1, 0, 'MLE')
    rec_decoder, _, rec_loss = teacher_forcing(rnn_cell, 1, 1, 'REC')
    step_stage = tf.cast(step, tf.float32) / tf.constant(600.0)
    # Anneal rec_weight as in the previous listing: 1.0, then linear decay to
    # FLAGS.rec_w, then FLAGS.rec_w.
    rec_weight = tf.case(
        [(tf.less_equal(step_stage, tf.constant(1.0)), lambda: tf.constant(1.0)),
         (tf.greater(step_stage, tf.constant(2.0)), lambda: FLAGS.rec_w)],
        default=lambda: tf.constant(1.0) - (step_stage - 1) * (1 - FLAGS.rec_w))
    joint_loss = (1 - rec_weight) * loss + rec_weight * rec_loss
    losses['joint'] = joint_loss

    tiled_decoder, bs_outputs = beam_searching(
        rnn_cell, 1, 0, config_train.infer_beam_width)

    align_sents, align_sds, align_loss, align_tf_outputs, align_bs_outputs = \
        build_align()
    losses['align'] = align_loss

    train_ops = {
        name: get_train_op(losses[name], hparams=config_train.train[name])
        for name in config_train.train}

    return train_ops, bs_outputs, \
           align_sents, align_sds, align_tf_outputs, align_bs_outputs
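The exact-coverage term used in teacher_forcing above penalizes source positions whose accumulated copy probability drifts away from 1. A NumPy sketch of that quantity (names and the toy batch are illustrative only):

import numpy as np

def exact_coverage_penalty(sum_copy_probs, lengths):
    # sum_copy_probs: [batch, src_len], copy mass accumulated over decoding steps.
    batch_size, src_len = sum_copy_probs.shape
    mask = (np.arange(src_len)[None, :] < np.asarray(lengths)[:, None]).astype(np.float32)
    per_position = np.square(sum_copy_probs - 1.0) * mask   # masked squared deviation
    return float(np.mean(np.sum(per_position, axis=1)))     # sum over positions, mean over batch

probs = np.array([[0.9, 1.1, 0.0], [1.0, 0.5, 0.7]], dtype=np.float32)
print(exact_coverage_penalty(probs, lengths=[2, 3]))  # ~0.18 for this toy batch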
Ejemplo n.º 11
0
    def _build_model(self, inputs, vocab, gamma, lambda_g, lambda_z, lambda_z1,
                     lambda_z2, lambda_ae):

        embedder = WordEmbedder(vocab_size=vocab.size,
                                hparams=self._hparams.embedder)

        encoder = UnidirectionalRNNEncoder(hparams=self._hparams.encoder)

        enc_text_ids = inputs['text_ids'][:, 1:]
        enc_outputs, final_state = encoder(embedder(enc_text_ids),
                                           sequence_length=inputs['length'] -
                                           1)

        z = final_state[:, self._hparams.dim_c:]

        # -------------------- CLASSIFIER ---------------------

        n_classes = self._hparams.num_classes
        z_classifier_l1 = MLPTransformConnector(
            256, hparams=self._hparams.z_classifier_l1)
        z_classifier_l2 = MLPTransformConnector(
            64, hparams=self._hparams.z_classifier_l2)
        z_classifier_out = MLPTransformConnector(
            n_classes if n_classes > 2 else 1)

        z_logits = z_classifier_l1(z)
        z_logits = z_classifier_l2(z_logits)
        z_logits = z_classifier_out(z_logits)
        z_pred = tf.greater(z_logits, 0)
        z_logits = tf.reshape(z_logits, [-1])

        z_pred = tf.to_int64(tf.reshape(z_pred, [-1]))

        loss_z_clas = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.to_float(inputs['labels']), logits=z_logits)
        loss_z_clas = tf.reduce_mean(loss_z_clas)

        accu_z_clas = tx.evals.accuracy(labels=inputs['labels'], preds=z_pred)

        # ------------------- END CLASSIFIER ---------------------

        label_connector = MLPTransformConnector(self._hparams.dim_c)

        labels = tf.to_float(tf.reshape(inputs['labels'], [-1, 1]))

        c = label_connector(labels)
        c_ = label_connector(1 - labels)

        h = tf.concat([c, z], 1)
        h_ = tf.concat([c_, z], 1)
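        # h pairs the content code z with the original style code c; h_ pairs
        # z with the flipped style code c_ used for transfer.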

        # Teacher-force decoding and the auto-encoding loss for G

        decoder = AttentionRNNDecoder(
            memory=enc_outputs,
            memory_sequence_length=inputs['length'] - 1,
            cell_input_fn=lambda inputs, attention: inputs,
            vocab_size=vocab.size,
            hparams=self._hparams.decoder)

        connector = MLPTransformConnector(decoder.state_size)

        g_outputs, _, _ = decoder(initial_state=connector(h),
                                  inputs=inputs['text_ids'],
                                  embedding=embedder,
                                  sequence_length=inputs['length'] - 1)

        loss_g_ae = tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=inputs['text_ids'][:, 1:],
            logits=g_outputs.logits,
            sequence_length=inputs['length'] - 1,
            average_across_timesteps=True,
            sum_over_timesteps=False)

        # Gumbel-softmax decoding, used in training

        start_tokens = tf.ones_like(inputs['labels']) * vocab.bos_token_id

        end_token = vocab.eos_token_id

        gumbel_helper = GumbelSoftmaxEmbeddingHelper(embedder.embedding,
                                                     start_tokens, end_token,
                                                     gamma)

        soft_outputs_, _, soft_length_ = decoder(helper=gumbel_helper,
                                                 initial_state=connector(h_))

        soft_outputs, _, soft_length = decoder(helper=gumbel_helper,
                                               initial_state=connector(h))

        # ---------------------------- SHIFTED LOSS -------------------------------------
        _, encoder_final_state_ = encoder(
            embedder(soft_ids=soft_outputs_.sample_id),
            sequence_length=inputs['length'] - 1)
        _, encoder_final_state = encoder(
            embedder(soft_ids=soft_outputs.sample_id),
            sequence_length=inputs['length'] - 1)
        new_z_ = encoder_final_state_[:, self._hparams.dim_c:]
        new_z = encoder_final_state[:, self._hparams.dim_c:]
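        # new_z / new_z_ are the content codes recovered by re-encoding the soft
        # samples; the cosine distances below penalize drift from the input's z.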

        cos_distance_z_ = tf.abs(
            tf.losses.cosine_distance(tf.nn.l2_normalize(z, axis=1),
                                      tf.nn.l2_normalize(new_z_, axis=1),
                                      axis=1))
        cos_distance_z = tf.abs(
            tf.losses.cosine_distance(tf.nn.l2_normalize(z, axis=1),
                                      tf.nn.l2_normalize(new_z, axis=1),
                                      axis=1))
        # ---------------------------- END SHIFTED LOSS -------------------------------------

        # Greedy decoding, used in eval

        outputs_, _, length_ = decoder(decoding_strategy='infer_greedy',
                                       initial_state=connector(h_),
                                       embedding=embedder,
                                       start_tokens=start_tokens,
                                       end_token=end_token)

        # Creates classifier

        classifier = Conv1DClassifier(hparams=self._hparams.classifier)

        clas_embedder = WordEmbedder(vocab_size=vocab.size,
                                     hparams=self._hparams.embedder)

        # Classification loss for the classifier

        clas_logits, clas_preds = classifier(
            inputs=clas_embedder(ids=inputs['text_ids'][:, 1:]),
            sequence_length=inputs['length'] - 1)

        loss_d_clas = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.to_float(inputs['labels']), logits=clas_logits)

        loss_d_clas = tf.reduce_mean(loss_d_clas)

        accu_d = tx.evals.accuracy(labels=inputs['labels'], preds=clas_preds)

        # Classification loss for the generator, based on soft samples

        soft_logits, soft_preds = classifier(
            inputs=clas_embedder(soft_ids=soft_outputs_.sample_id),
            sequence_length=soft_length_)

        loss_g_clas = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.to_float(1 - inputs['labels']), logits=soft_logits)

        loss_g_clas = tf.reduce_mean(loss_g_clas)

        # Accuracy on soft samples, for training progress monitoring

        accu_g = tx.evals.accuracy(labels=1 - inputs['labels'],
                                   preds=soft_preds)

        # Accuracy on greedy-decoded samples, for training progress monitoring

        _, gdy_preds = classifier(inputs=clas_embedder(ids=outputs_.sample_id),
                                  sequence_length=length_)

        accu_g_gdy = tx.evals.accuracy(labels=1 - inputs['labels'],
                                       preds=gdy_preds)

        # Aggregates losses

        loss_g = lambda_ae * loss_g_ae + \
                 lambda_g * loss_g_clas + \
                 lambda_z1 * cos_distance_z + cos_distance_z_ * lambda_z2 \
                 - lambda_z * loss_z_clas
        loss_d = loss_d_clas
        loss_z = loss_z_clas

        # Creates optimizers

        g_vars = collect_trainable_variables(
            [embedder, encoder, label_connector, connector, decoder])
        d_vars = collect_trainable_variables([clas_embedder, classifier])
        z_vars = collect_trainable_variables(
            [z_classifier_l1, z_classifier_l2, z_classifier_out])

        train_op_g = get_train_op(loss_g, g_vars, hparams=self._hparams.opt)
        train_op_g_ae = get_train_op(loss_g_ae,
                                     g_vars,
                                     hparams=self._hparams.opt)
        train_op_d = get_train_op(loss_d, d_vars, hparams=self._hparams.opt)
        train_op_z = get_train_op(loss_z, z_vars, hparams=self._hparams.opt)

        # Interface tensors
        self.losses = {
            "loss_g": loss_g,
            "loss_g_ae": loss_g_ae,
            "loss_g_clas": loss_g_clas,
            "loss_d": loss_d_clas,
            "loss_z_clas": loss_z_clas,
            "loss_cos_": cos_distance_z_,
            "loss_cos": cos_distance_z
        }
        self.metrics = {
            "accu_d": accu_d,
            "accu_g": accu_g,
            "accu_g_gdy": accu_g_gdy,
            "accu_z_clas": accu_z_clas
        }
        self.train_ops = {
            "train_op_g": train_op_g,
            "train_op_g_ae": train_op_g_ae,
            "train_op_d": train_op_d,
            "train_op_z": train_op_z
        }
        self.samples = {
            "original": inputs['text_ids'][:, 1:],
            "transferred": outputs_.sample_id,
            "z_vector": z,
            "labels_source": inputs['labels'],
            "labels_target": 1 - inputs['labels'],
            "labels_predicted": gdy_preds
        }

        self.fetches_train_g = {
            "loss_g": self.train_ops["train_op_g"],
            "loss_g_ae": self.losses["loss_g_ae"],
            "loss_g_clas": self.losses["loss_g_clas"],
            "loss_shifted_ae1": self.losses["loss_cos"],
            "loss_shifted_ae2": self.losses["loss_cos_"],
            "accu_g": self.metrics["accu_g"],
            "accu_g_gdy": self.metrics["accu_g_gdy"],
            "accu_z_clas": self.metrics["accu_z_clas"]
        }

        self.fetches_train_z = {
            "loss_z": self.train_ops["train_op_z"],
            "accu_z": self.metrics["accu_z_clas"]
        }

        self.fetches_train_d = {
            "loss_d": self.train_ops["train_op_d"],
            "accu_d": self.metrics["accu_d"]
        }
        fetches_eval = {"batch_size": get_batch_size(inputs['text_ids'])}
        fetches_eval.update(self.losses)
        fetches_eval.update(self.metrics)
        fetches_eval.update(self.samples)
        self.fetches_eval = fetches_eval
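The shifted loss in the model above keeps the content code of a (soft) generation close to the content code z of its input, measured by cosine distance. A minimal NumPy illustration of that distance (ignoring the reduction details of tf.losses.cosine_distance):

import numpy as np

def cosine_distance(a, b):
    # 1 - cosine similarity per row, averaged over the batch; a, b: [batch, dim].
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return float(np.mean(1.0 - np.sum(a * b, axis=1)))

z     = np.array([[1.0, 0.0], [0.0, 1.0]])
new_z = np.array([[1.0, 0.1], [0.2, 1.0]])
print(cosine_distance(z, new_z))  # small value: the content codes stay close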
Ejemplo n.º 12
0
def build_model(data_batch, data, step):
    batch_size, num_steps = [
        tf.shape(data_batch["x_value_text_ids"])[d] for d in range(2)
    ]
    vocab = data.vocab('y_aux')

    id2str = '<{}>'.format
    bos_str, eos_str = map(id2str, (vocab.bos_token_id, vocab.eos_token_id))

    def single_bleu(ref, hypo):
        ref = [id2str(u if u != vocab.unk_token_id else -1) for u in ref]
        hypo = [id2str(u) for u in hypo]

        ref = tx.utils.strip_special_tokens(' '.join(ref),
                                            strip_bos=bos_str,
                                            strip_eos=eos_str)
        hypo = tx.utils.strip_special_tokens(' '.join(hypo), strip_eos=eos_str)

        # Scaled down so the sentence-level BLEU can serve as a per-sentence loss weight.
        return 0.01 * tx.evals.sentence_bleu(references=[ref], hypothesis=hypo)

    def batch_bleu(refs, hypos):
        return np.array(
            [single_bleu(ref, hypo) for ref, hypo in zip(refs, hypos)],
            dtype=np.float32)

    # losses
    losses = {}

    # embedders
    embedders = {
        name: tx.modules.WordEmbedder(vocab_size=data.vocab(name).size,
                                      hparams=hparams)
        for name, hparams in config_model.embedders.items()
    }

    # encoders
    y_encoder = tx.modules.BidirectionalRNNEncoder(
        hparams=config_model.y_encoder)
    x_encoder = tx.modules.BidirectionalRNNEncoder(
        hparams=config_model.x_encoder)

    def concat_encoder_outputs(outputs):
        return tf.concat(outputs, -1)

    def encode_y(ids, length):
        embeds = embedders['y_aux'](ids)
        enc_outputs, _ = y_encoder(embeds, sequence_length=length)
        enc_outputs = concat_encoder_outputs(enc_outputs)
        return Encoded(ids, embeds, enc_outputs, length)

    def encode_x(ids, length):
        embeds = tf.concat(
            [embedders['x_' + field](ids[field]) for field in x_fields],
            axis=-1)
        enc_outputs, _ = x_encoder(embeds, sequence_length=length)
        enc_outputs = concat_encoder_outputs(enc_outputs)
        return Encoded(ids, embeds, enc_outputs, length)

    y_encoded = [
        encode_y(data_batch['{}_text_ids'.format(ref_str)],
                 data_batch['{}_length'.format(ref_str)]) for ref_str in y_strs
    ]
    x_encoded = [
        encode_x(
            {
                field: data_batch['x{}_{}_text_ids'.format(ref_str,
                                                           field)][:, 1:-1]
                for field in x_fields
            }, data_batch['x{}_{}_length'.format(ref_str, x_fields[0])] - 2)
        for ref_str in ref_strs
    ]

    # get rnn cell
    rnn_cell = tx.core.layers.get_rnn_cell(config_model.rnn_cell)

    def get_decoder(cell, y_, x, tgt_ref_flag, beam_width=None):
        output_layer_params = \
            {'output_layer': tf.identity} if copy_flag else \
            {'vocab_size': vocab.size}

        if attn_flag:  # attention
            if FLAGS.attn_x and FLAGS.attn_y_:
                memory = tf.concat([y_.enc_outputs, x.enc_outputs], axis=1)
                memory_sequence_length = None
            elif FLAGS.attn_y_:
                memory = y_.enc_outputs
                memory_sequence_length = y_.length
            elif FLAGS.attn_x:
                memory = x.enc_outputs
                memory_sequence_length = x.length
            else:
                raise ValueError(
                    "Must enable at least one of FLAGS.attn_y_ or FLAGS.attn_x.")
            attention_decoder = tx.modules.AttentionRNNDecoder(
                cell=cell,
                memory=memory,
                memory_sequence_length=memory_sequence_length,
                hparams=config_model.attention_decoder,
                **output_layer_params)
            if not copy_flag:
                return attention_decoder
            cell = attention_decoder.cell if beam_width is None else \
                   attention_decoder._get_beam_search_cell(beam_width)

        if copy_flag:  # copynet
            kwargs = {
                'y__ids': y_.ids[:, 1:],
                'y__states': y_.enc_outputs[:, 1:],
                'y__lengths': y_.length - 1,
                'x_ids': x.ids['value'],
                'x_states': x.enc_outputs,
                'x_lengths': x.length,
            }

            if tgt_ref_flag is not None:
                kwargs.update({
                    'input_ids':
                    data_batch['{}_text_ids'.format(
                        y_strs[tgt_ref_flag])][:, :-1]
                })

            memory_prefixes = []

            if FLAGS.copy_y_:
                memory_prefixes.append('y_')

            if FLAGS.copy_x:
                memory_prefixes.append('x')

            if beam_width is not None:
                kwargs = {
                    name: tile_batch(value, beam_width)
                    for name, value in kwargs.items()
                }

            def get_get_copy_scores(memory_ids_states_lengths, output_size):
                memory_copy_states = [
                    tf.layers.dense(memory_states,
                                    units=output_size,
                                    activation=None,
                                    use_bias=False)
                    for _, memory_states, _ in memory_ids_states_lengths
                ]

                def get_copy_scores(query, coverities=None):
                    ret = []

                    if FLAGS.copy_y_:
                        memory = memory_copy_states[len(ret)]
                        if coverities is not None:
                            memory = memory + tf.layers.dense(
                                coverities[len(ret)],
                                units=output_size,
                                activation=None,
                                use_bias=False)
                        memory = tf.nn.tanh(memory)
                        ret_y_ = tf.einsum("bim,bm->bi", memory, query)
                        ret.append(ret_y_)

                    if FLAGS.copy_x:
                        memory = memory_copy_states[len(ret)]
                        if coverities is not None:
                            memory = memory + tf.layers.dense(
                                coverities[len(ret)],
                                units=output_size,
                                activation=None,
                                use_bias=False)
                        memory = tf.nn.tanh(memory)
                        ret_x = tf.einsum("bim,bm->bi", memory, query)
                        ret.append(ret_x)

                    return ret

                return get_copy_scores

            cell = CopyNetWrapper(
                cell=cell, vocab_size=vocab.size,
                memory_ids_states_lengths=[
                    tuple(kwargs['{}_{}'.format(prefix, s)]
                          for s in ('ids', 'states', 'lengths'))
                    for prefix in memory_prefixes],
                input_ids=\
                    kwargs['input_ids'] if tgt_ref_flag is not None else None,
                get_get_copy_scores=get_get_copy_scores,
                coverity_dim=config_model.coverity_dim if FLAGS.coverage else None,
                coverity_rnn_cell_hparams=config_model.coverity_rnn_cell if FLAGS.coverage else None,
                disabled_vocab_size=FLAGS.disabled_vocab_size,
                eps=FLAGS.eps)

        decoder = tx.modules.BasicRNNDecoder(cell=cell,
                                             hparams=config_model.decoder,
                                             **output_layer_params)
        return decoder

    def get_decoder_and_outputs(cell,
                                y_,
                                x,
                                tgt_ref_flag,
                                params,
                                beam_width=None):
        decoder = get_decoder(cell, y_, x, tgt_ref_flag, beam_width=beam_width)
        if beam_width is None:
            ret = decoder(**params)
        else:
            ret = tx.modules.beam_search_decode(decoder_or_cell=decoder,
                                                beam_width=beam_width,
                                                **params)
        return (decoder, ) + ret

    get_decoder_and_outputs = tf.make_template('get_decoder_and_outputs',
                                               get_decoder_and_outputs)

    def teacher_forcing(cell,
                        y_,
                        x_ref_flag,
                        loss_name,
                        add_bleu_weight=False):
        x = x_encoded[x_ref_flag]
        tgt_ref_flag = x_ref_flag
        tgt_str = y_strs[tgt_ref_flag]
        sequence_length = data_batch['{}_length'.format(tgt_str)] - 1
        decoder, tf_outputs, final_state, _ = get_decoder_and_outputs(
            cell, y_, x, tgt_ref_flag, {
                'decoding_strategy': 'train_greedy',
                'inputs': y_encoded[tgt_ref_flag].embeds,
                'sequence_length': sequence_length
            })

        tgt_y_ids = data_batch['{}_text_ids'.format(tgt_str)][:, 1:]
        loss = tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=tgt_y_ids,
            logits=tf_outputs.logits,
            sequence_length=sequence_length,
            average_across_batch=False)
        if add_bleu_weight:
            w = tf.py_func(batch_bleu, [y_.ids, tgt_y_ids],
                           tf.float32,
                           stateful=False,
                           name='W_BLEU')
            w.set_shape(loss.get_shape())
            loss = w * loss
        loss = tf.reduce_mean(loss, 0)

        if copy_flag and FLAGS.exact_cover_w != 0:
            sum_copy_probs = list(
                map(lambda t: tf.cast(t, tf.float32),
                    final_state.sum_copy_probs))
            memory_lengths = [
                lengths
                for _, _, lengths in decoder.cell.memory_ids_states_lengths
            ]
            exact_coverage_losses = [
                tf.reduce_mean(
                    tf.reduce_sum(
                        tx.utils.mask_sequences(tf.square(sum_copy_prob - 1.),
                                                memory_length), 1))
                for sum_copy_prob, memory_length in zip(
                    sum_copy_probs, memory_lengths)
            ]
            for i, exact_coverage_loss in enumerate(exact_coverage_losses):
                loss += FLAGS.exact_cover_w * exact_coverage_loss

        losses[loss_name] = loss

        return decoder, tf_outputs, loss

    def beam_searching(cell, y_, x, beam_width):
        start_tokens = tf.ones_like(data_batch['y_aux_length']) * \
            vocab.bos_token_id
        end_token = vocab.eos_token_id

        decoder, bs_outputs, _, bs_length = get_decoder_and_outputs(
            cell,
            y_,
            x,
            None, {
                'embedding': embedders['y_aux'],
                'start_tokens': start_tokens,
                'end_token': end_token,
                'max_decoding_length': config_train.infer_max_decoding_length
            },
            beam_width=beam_width)

        return decoder, bs_outputs, bs_length

    def greedy_decoding(cell, y_, x):
        decoder, bs_outputs, bs_length = beam_searching(cell, y_, x, 1)
        return bs_outputs.predicted_ids[:, :, 0], bs_length[:, 0]

    discriminator = tx.modules.UnidirectionalRNNClassifier(
        hparams=config_model.discriminator)

    def disc(y, x):
        # Classify the decoder outputs y conditioned on the structured-data encoding x.
        return discriminator(tf.concat([x.enc_outputs, y], axis=1))

    joint_loss = 0.
    disc_loss = 0.
    for flag in range(2):
        xe_decoder, xe_outputs, xe_loss = teacher_forcing(
            rnn_cell, y_encoded[1 - flag], flag, 'XE')
        rec_decoder, rec_outputs, rec_loss = teacher_forcing(
            rnn_cell, y_encoded[1 - flag], 1 - flag, 'REC')

        # print('[info]rec_outputs is:{}'.format(rec_outputs.cell_output))

        greedy_ids, greedy_length = greedy_decoding(rnn_cell,
                                                    y_encoded[1 - flag],
                                                    x_encoded[flag])
        greedy_ids = tf.concat([
            data_batch['y_aux_text_ids'][:, :1],
            tf.cast(greedy_ids, tf.int64)
        ],
                               axis=1)
        greedy_length = 1 + greedy_length
        greedy_encoded = encode_y(greedy_ids, greedy_length)
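        # Back-translation-style loss: condition on the re-encoded greedy transfer
        # and train the decoder to reproduce y_strs[1 - flag].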

        bt_decoder, _, bt_loss = teacher_forcing(rnn_cell, greedy_encoded,
                                                 1 - flag, 'BT')


        adv_loss = 2. * disc(rec_outputs.cell_output, x_encoded[1-flag])[0][:, 1] \
                 +      disc(xe_outputs.cell_output, x_encoded[flag])[0][:, 0] \
                 +      disc(rec_outputs.cell_output, x_encoded[flag])[0][:, 0]
        adv_loss = tf.reduce_mean(adv_loss, 0)

        disc_loss = disc_loss + adv_loss
        joint_loss = joint_loss + (rec_loss + FLAGS.bt_w * bt_loss +
                                   FLAGS.adv_w * adv_loss)

    disc_loss = -disc_loss

    losses['joint'] = joint_loss
    losses['disc'] = disc_loss

    tiled_decoder, bs_outputs, _ = beam_searching(
        rnn_cell, y_encoded[1], x_encoded[0], config_train.infer_beam_width)

    disc_variables = discriminator.trainable_variables
    other_variables = list(
        filter(lambda var: var not in disc_variables,
               tf.trainable_variables()))
    variables_of_train_op = {
        'joint': other_variables,
        'disc': disc_variables,
    }

    train_ops = {
        name: get_train_op(losses[name],
                           variables=variables_of_train_op.get(name, None),
                           hparams=config_train.train[name])
        for name in config_train.train
    }

    return train_ops, bs_outputs
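The returned train_ops are keyed by loss name ('joint', 'disc'), each restricted to its own variable list. A hypothetical sketch of how a caller might alternate them (sess, feed_dict, and the 1:1 update ratio are assumptions):

def train_step(sess, train_ops, feed_dict):
    # One discriminator update followed by one update of the rest of the model.
    disc_val = sess.run(train_ops['disc'], feed_dict=feed_dict)
    joint_val = sess.run(train_ops['joint'], feed_dict=feed_dict)
    return joint_val, disc_val

Restricting 'disc' to the discriminator's variables and 'joint' to all remaining trainable variables keeps the two adversarial updates from overwriting each other's parameters.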