Code example #1
File: baby_model_xlnet.py Project: li3cmz/SRGT
    def _get_loss_train_op(self):
        # Aggregates loss
        rephraser_losses = tf.get_collection('loss_rephraser_list')
        loss_rephraser = (rephraser_losses[0] + rephraser_losses[1]
                          + rephraser_losses[2]) / 3.

        self.loss = loss_rephraser
        
        # Creates optimizers
        self.vars = collect_trainable_variables([self.xlnet_encoder, self.word_embedder, 
                self.downmlp, self.rephrase_encoder, self.rephrase_decoder])
                
        # Train Op
        self.train_op_pre = get_train_op(
                self.loss, self.vars, learning_rate=self.lr, hparams=self._hparams.opt)
        
        # Interface tensors
        self.losses = {
            "loss": self.loss,
            "loss_rephraser": loss_rephraser,
        }
        self.metrics = {
        }
        self.train_ops = {
            "train_op_pre": self.train_op_pre,
        }
        self.samples = {
            "transferred_yy1_gt": self.text_ids_yy1,
            "transferred_yy1_pred": tf.get_collection('yy_pred_list')[0],
            "transferred_yy2_gt": self.text_ids_yy2,
            "transferred_yy2_pred": tf.get_collection('yy_pred_list')[1],
            "transferred_yy3_gt": self.text_ids_yy3,
            "transferred_yy3_pred": tf.get_collection('yy_pred_list')[2],
            "origin_y1":self.text_ids_y1,
            "origin_y2":self.text_ids_y2,
            "origin_y3":self.text_ids_y3,
            "x1x2":self.x1x2,
            "x1xx2":self.x1xx2
        }

        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("loss_rephraser", loss_rephraser)
        self.merged = tf.summary.merge_all()
        self.fetches_train_pre = {
            "loss": self.train_ops["train_op_pre"],
            "loss_rephraser": self.losses["loss_rephraser"],
            "merged": self.merged,
        }
        fetches_eval = {"batch_size": get_batch_size(self.x1x2yx1xx2_ids),
        "merged": self.merged,
        }
        fetches_eval.update(self.losses)
        fetches_eval.update(self.metrics)
        fetches_eval.update(self.samples)
        self.fetches_eval = fetches_eval
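
A minimal usage sketch, not part of the original file: assuming `model` is an instance of the class above built in a TF 1.x graph, `sess` is a tf.Session with variables initialized, and the input pipeline raises tf.errors.OutOfRangeError at the end of an epoch, the pre-training fetches can be driven like this.

import tensorflow as tf

def run_pretrain_epoch(sess, model, summary_writer=None):
    # One pass over the data using the rephraser pre-training op; the fetch
    # dict also returns the rephraser loss and the merged summaries.
    step = 0
    while True:
        try:
            vals = sess.run(model.fetches_train_pre)
            if summary_writer is not None:
                summary_writer.add_summary(vals["merged"], step)
            if step % 100 == 0:
                print("step {}: loss_rephraser={:.4f}".format(
                    step, vals["loss_rephraser"]))
            step += 1
        except tf.errors.OutOfRangeError:
            break
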
Code example #2
    def _build_model(self, inputs, vocab, gamma, lambda_g):
        """Builds the model.
        """
        embedder = WordEmbedder(vocab_size=vocab.size,
                                hparams=self._hparams.embedder)
        encoder = UnidirectionalRNNEncoder(hparams=self._hparams.encoder)

        # text_ids for encoder, with BOS token removed
        enc_text_ids = inputs['text_ids'][:, 1:]
        enc_outputs, final_state = encoder(embedder(enc_text_ids),
                                           sequence_length=inputs['length'] -
                                           1)
        z = final_state[:, self._hparams.dim_c:]

        # Encodes label
        label_connector = MLPTransformConnector(self._hparams.dim_c)

        # Gets the sentence representation: h = (c, z)
        labels = tf.cast(tf.reshape(inputs['labels'], [-1, 1]), tf.float32)
        c = label_connector(labels)
        c_ = label_connector(1 - labels)
        h = tf.concat([c, z], 1)
        h_ = tf.concat([c_, z], 1)

        # Teacher-force decoding and the auto-encoding loss for G
        decoder = AttentionRNNDecoder(
            memory=enc_outputs,
            memory_sequence_length=inputs['length'] - 1,
            cell_input_fn=lambda inputs, attention: inputs,
            vocab_size=vocab.size,
            hparams=self._hparams.decoder)

        connector = MLPTransformConnector(decoder.state_size)

        g_outputs, _, _ = decoder(initial_state=connector(h),
                                  inputs=inputs['text_ids'],
                                  embedding=embedder,
                                  sequence_length=inputs['length'] - 1)

        loss_g_ae = tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=inputs['text_ids'][:, 1:],
            logits=g_outputs.logits,
            sequence_length=inputs['length'] - 1,
            average_across_timesteps=True,
            sum_over_timesteps=False)

        # Gumbel-softmax decoding, used in training
        start_tokens = tf.ones_like(inputs['labels']) * vocab.bos_token_id
        end_token = vocab.eos_token_id
        gumbel_helper = GumbelSoftmaxEmbeddingHelper(embedder.embedding,
                                                     start_tokens, end_token,
                                                     gamma)
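        # The Gumbel-softmax helper yields soft (differentiable) token
        # distributions instead of hard ids, so the classifier loss on the
        # transferred samples below can back-propagate into the generator.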

        soft_outputs_, _, soft_length_, = decoder(helper=gumbel_helper,
                                                  initial_state=connector(h_))

        # Greedy decoding, used in eval
        outputs_, _, length_ = decoder(decoding_strategy='infer_greedy',
                                       initial_state=connector(h_),
                                       embedding=embedder,
                                       start_tokens=start_tokens,
                                       end_token=end_token)

        # Creates classifier
        classifier = Conv1DClassifier(hparams=self._hparams.classifier)
        clas_embedder = WordEmbedder(vocab_size=vocab.size,
                                     hparams=self._hparams.embedder)

        # Classification loss for the classifier
        clas_logits, clas_preds = classifier(
            inputs=clas_embedder(ids=inputs['text_ids'][:, 1:]),
            sequence_length=inputs['length'] - 1)
        loss_d_clas = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.cast(inputs['labels'], tf.float32), logits=clas_logits)
        loss_d_clas = tf.reduce_mean(loss_d_clas)
        accu_d = tx.evals.accuracy(labels=inputs['labels'], preds=clas_preds)

        # Classification loss for the generator, based on soft samples
        soft_logits, soft_preds = classifier(
            inputs=clas_embedder(soft_ids=soft_outputs_.sample_id),
            sequence_length=soft_length_)
        loss_g_clas = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.cast(1 - inputs['labels'], tf.float32),
            logits=soft_logits)
        loss_g_clas = tf.reduce_mean(loss_g_clas)

        # Accuracy on soft samples, for training progress monitoring
        accu_g = tx.evals.accuracy(labels=1 - inputs['labels'],
                                   preds=soft_preds)

        # Accuracy on greedy-decoded samples, for training progress monitoring
        _, gdy_preds = classifier(inputs=clas_embedder(ids=outputs_.sample_id),
                                  sequence_length=length_)
        accu_g_gdy = tx.evals.accuracy(labels=1 - inputs['labels'],
                                       preds=gdy_preds)

        # Aggregates losses
        loss_g = loss_g_ae + lambda_g * loss_g_clas
        loss_d = loss_d_clas

        # Creates optimizers
        g_vars = collect_trainable_variables(
            [embedder, encoder, label_connector, connector, decoder])
        d_vars = collect_trainable_variables([clas_embedder, classifier])

        train_op_g = get_train_op(loss_g, g_vars, hparams=self._hparams.opt)
        train_op_g_ae = get_train_op(loss_g_ae,
                                     g_vars,
                                     hparams=self._hparams.opt)
        train_op_d = get_train_op(loss_d, d_vars, hparams=self._hparams.opt)

        # Interface tensors
        self.losses = {
            "loss_g": loss_g,
            "loss_g_ae": loss_g_ae,
            "loss_g_clas": loss_g_clas,
            "loss_d": loss_d_clas
        }
        self.metrics = {
            "accu_d": accu_d,
            "accu_g": accu_g,
            "accu_g_gdy": accu_g_gdy,
        }
        self.train_ops = {
            "train_op_g": train_op_g,
            "train_op_g_ae": train_op_g_ae,
            "train_op_d": train_op_d
        }
        self.samples = {
            "original": inputs['text_ids'][:, 1:],
            "transferred": outputs_.sample_id
        }

        self.fetches_train_g = {
            "loss_g": self.train_ops["train_op_g"],
            "loss_g_ae": self.losses["loss_g_ae"],
            "loss_g_clas": self.losses["loss_g_clas"],
            "accu_g": self.metrics["accu_g"],
            "accu_g_gdy": self.metrics["accu_g_gdy"],
        }
        self.fetches_train_d = {
            "loss_d": self.train_ops["train_op_d"],
            "accu_d": self.metrics["accu_d"]
        }
        fetches_eval = {"batch_size": get_batch_size(inputs['text_ids'])}
        fetches_eval.update(self.losses)
        fetches_eval.update(self.metrics)
        fetches_eval.update(self.samples)
        self.fetches_eval = fetches_eval
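
A minimal training-step sketch, not part of the original snippet: it assumes `model` is an instance of the class above, `sess` is a tf.Session, and Texar is imported as `tx`. One common schedule runs a discriminator update followed by a generator update, feeding the Texar global-mode placeholder.

import tensorflow as tf
import texar.tf as tx  # older releases: import texar as tx

def train_step(sess, model):
    # Discriminator step first, then generator step, both in TRAIN mode so
    # dropout and other mode-dependent behavior are enabled.
    feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
    vals_d = sess.run(model.fetches_train_d, feed_dict=feed)
    vals_g = sess.run(model.fetches_train_g, feed_dict=feed)
    return vals_d, vals_g
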
Code example #3
    def _build_model(self, inputs, vocab, gamma, lambda_g, lambda_z, lambda_z1,
                     lambda_z2, lambda_ae):

        embedder = WordEmbedder(vocab_size=vocab.size,
                                hparams=self._hparams.embedder)

        encoder = UnidirectionalRNNEncoder(hparams=self._hparams.encoder)

        enc_text_ids = inputs['text_ids'][:, 1:]
        enc_outputs, final_state = encoder(embedder(enc_text_ids),
                                           sequence_length=inputs['length'] -
                                           1)

        z = final_state[:, self._hparams.dim_c:]

        # -------------------- CLASSIFIER ---------------------

        n_classes = self._hparams.num_classes
        z_classifier_l1 = MLPTransformConnector(
            256, hparams=self._hparams.z_classifier_l1)
        z_classifier_l2 = MLPTransformConnector(
            64, hparams=self._hparams.z_classifier_l2)
        z_classifier_out = MLPTransformConnector(
            n_classes if n_classes > 2 else 1)

        z_logits = z_classifier_l1(z)
        z_logits = z_classifier_l2(z_logits)
        z_logits = z_classifier_out(z_logits)
        z_pred = tf.greater(z_logits, 0)
        z_logits = tf.reshape(z_logits, [-1])

        z_pred = tf.to_int64(tf.reshape(z_pred, [-1]))

        loss_z_clas = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.to_float(inputs['labels']), logits=z_logits)
        loss_z_clas = tf.reduce_mean(loss_z_clas)

        accu_z_clas = tx.evals.accuracy(labels=inputs['labels'], preds=z_pred)

        # -------------------- END CLASSIFIER ---------------------

        label_connector = MLPTransformConnector(self._hparams.dim_c)

        labels = tf.to_float(tf.reshape(inputs['labels'], [-1, 1]))

        c = label_connector(labels)
        c_ = label_connector(1 - labels)

        h = tf.concat([c, z], 1)
        h_ = tf.concat([c_, z], 1)

        # Teacher-force decoding and the auto-encoding loss for G

        decoder = AttentionRNNDecoder(
            memory=enc_outputs,
            memory_sequence_length=inputs['length'] - 1,
            cell_input_fn=lambda inputs, attention: inputs,
            vocab_size=vocab.size,
            hparams=self._hparams.decoder)

        connector = MLPTransformConnector(decoder.state_size)

        g_outputs, _, _ = decoder(initial_state=connector(h),
                                  inputs=inputs['text_ids'],
                                  embedding=embedder,
                                  sequence_length=inputs['length'] - 1)

        loss_g_ae = tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=inputs['text_ids'][:, 1:],
            logits=g_outputs.logits,
            sequence_length=inputs['length'] - 1,
            average_across_timesteps=True,
            sum_over_timesteps=False)

        # Gumbel-softmax decoding, used in training

        start_tokens = tf.ones_like(inputs['labels']) * vocab.bos_token_id

        end_token = vocab.eos_token_id

        gumbel_helper = GumbelSoftmaxEmbeddingHelper(embedder.embedding,
                                                     start_tokens, end_token,
                                                     gamma)

        soft_outputs_, _, soft_length_, = decoder(helper=gumbel_helper,
                                                  initial_state=connector(h_))

        soft_outputs, _, soft_length, = decoder(helper=gumbel_helper,
                                                initial_state=connector(h))

        # ---------------------------- SHIFTED LOSS -------------------------------------
        _, encoder_final_state_ = encoder(
            embedder(soft_ids=soft_outputs_.sample_id),
            sequence_length=inputs['length'] - 1)
        _, encoder_final_state = encoder(
            embedder(soft_ids=soft_outputs.sample_id),
            sequence_length=inputs['length'] - 1)
        new_z_ = encoder_final_state_[:, self._hparams.dim_c:]
        new_z = encoder_final_state[:, self._hparams.dim_c:]

        cos_distance_z_ = tf.abs(
            tf.losses.cosine_distance(tf.nn.l2_normalize(z, axis=1),
                                      tf.nn.l2_normalize(new_z_, axis=1),
                                      axis=1))
        cos_distance_z = tf.abs(
            tf.losses.cosine_distance(tf.nn.l2_normalize(z, axis=1),
                                      tf.nn.l2_normalize(new_z, axis=1),
                                      axis=1))
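        # These cosine distances act as content-preservation terms: the soft
        # samples are re-encoded with the same encoder, and their content
        # vectors are compared against the original z.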
        # ---------------------------- END SHIFTED LOSS ----------------------------------

        # Greedy decoding, used in eval

        outputs_, _, length_ = decoder(decoding_strategy='infer_greedy',
                                       initial_state=connector(h_),
                                       embedding=embedder,
                                       start_tokens=start_tokens,
                                       end_token=end_token)

        # Creates classifier

        classifier = Conv1DClassifier(hparams=self._hparams.classifier)
        discriminator = Conv1DClassifier(hparams=self._hparams.discriminator)

        clas_embedder = WordEmbedder(vocab_size=vocab.size,
                                     hparams=self._hparams.embedder)

        d_embedder = WordEmbedder(vocab_size=vocab.size,
                                  hparams=self._hparams.embedder)

        # Classification loss for the classifier
        true_samples = d_embedder(ids=inputs['text_ids'][:, 1:])
        clas_logits, clas_preds = discriminator(
            inputs=true_samples, sequence_length=inputs['length'] - 1)
        # print(clas_logits.shape)
        clas_logits, d_logits = tf.split(clas_logits, 2, 1)
        clas_logits = tf.squeeze(clas_logits)
        d_logits = tf.squeeze(d_logits)
        if self._hparams.WGAN:
            loss_d_clas = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.to_float(inputs['labels']), logits=clas_logits)
            accu_d_r = tx.evals.accuracy(labels=inputs['labels'],
                                         preds=(clas_logits >= 0.5))

            loss_d_dis = -tf.reduce_mean(d_logits)
            # Classification loss for the generator, based on soft samples
            fake_samples = d_embedder(soft_ids=soft_outputs_.sample_id)
            soft_logits, soft_preds = discriminator(
                inputs=fake_samples, sequence_length=soft_length_)
            clas_logits, d_logits = tf.split(soft_logits, 2, 1)
            clas_logits = tf.squeeze(clas_logits)
            d_logits = tf.squeeze(d_logits)
            loss_d_dis = loss_d_dis + tf.reduce_mean(
                d_logits)  # tf.reduce_mean(loss_g_clas)
            loss_d_clas = loss_d_clas + lambda_g * tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.to_float(1 - inputs['labels']), logits=clas_logits)
            accu_d_f = tx.evals.accuracy(labels=1 - inputs['labels'],
                                         preds=(clas_logits >= 0.5))

            loss_d_clas = tf.reduce_mean(loss_d_clas)
        else:
            loss_d_clas = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.to_float(inputs['labels']), logits=clas_logits)
            loss_d_clas = tf.reduce_mean(loss_d_clas)
            loss_d_dis = tf.constant(0., tf.float32)
            # accu_d_r / accu_d_f are only defined in the WGAN branch above, but
            # the metrics dict below expects them unconditionally; the values
            # here are placeholders (an assumption, not in the original code)
            # so the non-WGAN graph still builds.
            accu_d_r = tx.evals.accuracy(labels=inputs['labels'],
                                         preds=(clas_logits >= 0.))
            accu_d_f = tf.constant(0., tf.float32)

        clas_samples = clas_embedder(ids=inputs['text_ids'][:, 1:])
        clas_logits, clas_preds = classifier(inputs=clas_samples,
                                             sequence_length=inputs['length'] -
                                             1)
        loss_clas = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.to_float(
            inputs['labels']),
                                                            logits=clas_logits)

        loss_clas = tf.reduce_mean(loss_clas)
        accu_clas = tx.evals.accuracy(labels=inputs['labels'],
                                      preds=clas_preds)

        # Classification loss for the generator, based on soft samples
        fake_samples = d_embedder(soft_ids=soft_outputs_.sample_id)
        soft_logits, soft_preds = discriminator(inputs=fake_samples,
                                                sequence_length=soft_length_)

        soft_logits, d_logits = tf.split(soft_logits, 2, 1)
        soft_logits = tf.squeeze(soft_logits)
        d_logits = tf.squeeze(d_logits)
        if self._hparams.WGAN:
            loss_g_clas = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.to_float(1 - inputs['labels']), logits=soft_logits)
            loss_g_clas = tf.reduce_mean(
                loss_g_clas)  # tf.reduce_mean(loss_g_clas)
            normalized_d_logits = d_logits - tf.math.reduce_max(
                d_logits, axis=0, keepdims=True)
            W = tf.cast(
                get_batch_size(inputs['text_ids']),
                tf.float32) * tf.exp(normalized_d_logits) / tf.reduce_sum(
                    tf.exp(normalized_d_logits), 0)
            W = tf.stop_gradient(W)
            loss_g_dis = -tf.reduce_mean(W * d_logits)
        else:
            loss_g_clas = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.to_float(1 - inputs['labels']), logits=soft_logits)

            loss_g_clas = tf.reduce_mean(loss_g_clas)
            loss_g_dis = tf.constant(0., tf.float32)

        if self._hparams.WGAN:
            # WGAN-GP loss
            alpha = tf.random_uniform(
                shape=[get_batch_size(inputs['text_ids']), 1, 1],
                minval=0.,
                maxval=1.)
            true_samples = tf.cond(
                tf.less(tf.shape(fake_samples)[1],
                        tf.shape(true_samples)[1]),
                true_fn=lambda: true_samples,
                false_fn=lambda: dynamic_padding(true_samples, fake_samples))
            fake_samples = tf.cond(
                tf.less(tf.shape(fake_samples)[1],
                        tf.shape(true_samples)[1]),
                true_fn=lambda: dynamic_padding(fake_samples, true_samples),
                false_fn=lambda: fake_samples)
            differences = fake_samples - true_samples  # fake_samples[:,:16] - true_samples[:,:16]
            interpolates = true_samples + (alpha * differences)
            # D(interpolates, is_reuse=True)
            soft_logits, _ = discriminator(inputs=interpolates,
                                           sequence_length=inputs['length'] -
                                           1)
            _, d_logits = tf.split(soft_logits, 2, 1)
            d_logits = tf.squeeze(d_logits)
            gradients = tf.gradients(d_logits, [interpolates])[0]
            slopes = tf.sqrt(
                tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
            gradient_penalty = self._hparams.LAMBDA * tf.reduce_mean(
                (slopes - 1.)**2)
        else:
            gradient_penalty = tf.constant(0., tf.float32)
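        # The WGAN branch above implements the WGAN-GP gradient penalty:
        # interpolate between real and generated (soft) embeddings and penalize
        # LAMBDA * (||grad D||_2 - 1)^2 so the critic stays near 1-Lipschitz.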

        cfake_samples = clas_embedder(soft_ids=soft_outputs_.sample_id)
        # Accuracy on soft samples, for training progress monitoring
        soft_logits, soft_preds = classifier(inputs=cfake_samples,
                                             sequence_length=soft_length_)
        accu_g = tx.evals.accuracy(labels=1 - inputs['labels'],
                                   preds=soft_preds)

        # Accuracy on greedy-decoded samples, for training progress monitoring

        _, gdy_preds = classifier(inputs=clas_embedder(ids=outputs_.sample_id),
                                  sequence_length=length_)

        accu_g_gdy = tx.evals.accuracy(labels=1 - inputs['labels'],
                                       preds=gdy_preds)

        # Aggregates losses

        loss_g = lambda_ae * loss_g_ae + \
                 lambda_g * (self._hparams.ACGAN_SCALE_G*loss_g_clas + loss_g_dis) + \
                 lambda_z1 * cos_distance_z + cos_distance_z_ * lambda_z2 \
                 - lambda_z * loss_z_clas
        loss_d = self._hparams.ACGAN_SCALE_D * loss_d_clas + loss_d_dis + gradient_penalty
        print("\n==========ACSCALE  D:{}, G:{}=========\n".format(
            self._hparams.ACGAN_SCALE_D, self._hparams.ACGAN_SCALE_G))
        loss_z = loss_z_clas

        # Creates optimizers

        g_vars = collect_trainable_variables(
            [embedder, encoder, label_connector, connector, decoder])
        d_vars = collect_trainable_variables([d_embedder, discriminator])
        clas_vars = collect_trainable_variables([clas_embedder, classifier])
        z_vars = collect_trainable_variables(
            [z_classifier_l1, z_classifier_l2, z_classifier_out])

        train_op_g = get_train_op(loss_g, g_vars, hparams=self._hparams.opt_g)
        train_op_g_ae = get_train_op(loss_g_ae,
                                     g_vars,
                                     hparams=self._hparams.opt)
        if self._hparams.WGAN:
            train_op_d = get_train_op(loss_d,
                                      d_vars,
                                      hparams=self._hparams.opt_d)
        else:
            train_op_d = get_train_op(loss_d,
                                      d_vars,
                                      hparams=self._hparams.opt)
        train_op_c = get_train_op(loss_clas,
                                  clas_vars,
                                  hparams=self._hparams.opt)
        train_op_z = get_train_op(loss_z, z_vars, hparams=self._hparams.opt)

        # Interface tensors
        self.losses = {
            "loss_g": loss_g,
            "loss_g_ae": loss_g_ae,
            "loss_g_clas": loss_g_clas,
            "loss_g_dis": loss_g_dis,
            "loss_d": loss_d,
            "loss_d_clas": loss_d_clas,
            "loss_d_dis": loss_d_dis,
            "loss_clas": loss_clas,
            "loss_gp": gradient_penalty,
            "loss_z_clas": loss_z_clas,
            "loss_cos_": cos_distance_z_,
            "loss_cos": cos_distance_z
        }
        self.metrics = {
            "accu_d_r": accu_d_r,
            "accu_d_f": accu_d_f,
            "accu_clas": accu_clas,
            "accu_g": accu_g,
            "accu_g_gdy": accu_g_gdy,
            "accu_z_clas": accu_z_clas
        }
        self.train_ops = {
            "train_op_g": train_op_g,
            "train_op_g_ae": train_op_g_ae,
            "train_op_d": train_op_d,
            "train_op_z": train_op_z,
            "train_op_c": train_op_c
        }
        self.samples = {
            "original": inputs['text_ids'][:, 1:],
            "transferred": outputs_.sample_id,
            "z_vector": z,
            "labels_source": inputs['labels'],
            "labels_target": 1 - inputs['labels'],
            "labels_predicted": gdy_preds
        }

        self.fetches_train_g = {
            "loss_g": self.train_ops["train_op_g"],
            "loss_g_ae": self.losses["loss_g_ae"],
            "loss_g_clas": self.losses["loss_g_clas"],
            "loss_g_dis": self.losses["loss_g_dis"],
            "loss_shifted_ae1": self.losses["loss_cos"],
            "loss_shifted_ae2": self.losses["loss_cos_"],
            "accu_g": self.metrics["accu_g"],
            "accu_g_gdy": self.metrics["accu_g_gdy"],
            "accu_z_clas": self.metrics["accu_z_clas"]
        }

        self.fetches_train_z = {
            "loss_z": self.train_ops["train_op_z"],
            "accu_z": self.metrics["accu_z_clas"]
        }

        self.fetches_train_d = {
            "loss_d": self.train_ops["train_op_d"],
            "loss_d_clas": self.losses["loss_d_clas"],
            "loss_d_dis": self.losses["loss_d_dis"],
            "loss_c": self.train_ops["train_op_c"],
            "loss_gp": self.losses["loss_gp"],
            "accu_d_r": self.metrics["accu_d_r"],
            "accu_d_f": self.metrics["accu_d_f"],
            "accu_clas": self.metrics["accu_clas"]
        }
        fetches_eval = {"batch_size": get_batch_size(inputs['text_ids'])}
        fetches_eval.update(self.losses)
        fetches_eval.update(self.metrics)
        fetches_eval.update(self.samples)
        self.fetches_eval = fetches_eval
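
A minimal sketch of driving the fetch dicts above, not from the original code: it assumes `model` exposes fetches_train_z / fetches_train_d / fetches_train_g as built in this method, `sess` is a tf.Session, and Texar is imported as `tx`. Each component gets one update per step.

import tensorflow as tf
import texar.tf as tx  # older releases: import texar as tx

def train_step(sess, model):
    feed = {tx.global_mode(): tf.estimator.ModeKeys.TRAIN}
    # z-classifier update (its loss enters loss_g with a negative sign, so the
    # generator is pushed to remove label information from z).
    vals_z = sess.run(model.fetches_train_z, feed_dict=feed)
    # Discriminator / auxiliary classifier update (includes the WGAN-GP
    # penalty when the WGAN hyperparameter is enabled).
    vals_d = sess.run(model.fetches_train_d, feed_dict=feed)
    # Generator update: autoencoding, style, and content-preservation terms.
    vals_g = sess.run(model.fetches_train_g, feed_dict=feed)
    return vals_z, vals_d, vals_g
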
Code example #4
    def _build_model(self, inputs, vocab, lambda_ae, gamma, lambda_D):
        """Builds the model.
        """
        embedder = WordEmbedder(
            vocab_size=vocab.size,
            hparams=self._hparams.embedder)
        encoder = UnidirectionalRNNEncoder(hparams=self._hparams.encoder)

        # text_ids for encoder, with BOS removed
        enc_text_ids = inputs['text_ids'][:, 1:]
        enc_outputs, final_state = encoder(embedder(enc_text_ids),
                                           sequence_length=inputs['length'] - 1)
        z = final_state[:, self._hparams.dim_c:]

        # Encodes label
        label_connector = MLPTransformConnector(self._hparams.dim_c)

        # Gets the sentence representation: h = (c, z)
        labels = tf.to_float(tf.reshape(inputs['labels'], [-1, 1]))
        c = label_connector(labels)
        c_ = label_connector(1 - labels)
        h = tf.concat([c, z], 1)
        h_ = tf.concat([c_, z], 1)

        # Teacher-force decoding and the auto-encoding loss for G
        decoder = AttentionRNNDecoder(
            memory=enc_outputs,
            memory_sequence_length=inputs['length'] - 1,
            cell_input_fn=lambda inputs, attention: inputs,
            vocab_size=vocab.size,
            hparams=self._hparams.decoder)

        connector = MLPTransformConnector(decoder.state_size)

        g_outputs, _, _ = decoder(
            initial_state=connector(h), inputs=inputs['text_ids'],
            embedding=embedder, sequence_length=inputs['length'] - 1)


        # sequence_sparse_softmax_cross_entropy is the sequence version of
        # tf.nn.sparse_softmax_cross_entropy_with_logits:
        #    1. apply softmax to the logits to get y_hat_softmax
        #    2. compute the cross entropy: -y * tf.log(y_hat_softmax)
        #    3. sum over the classes of each instance, then average over time steps

        loss_g_ae = tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=inputs['text_ids'][:, 1:],
            logits=g_outputs.logits,
            sequence_length=inputs['length'] - 1,
            average_across_timesteps=True,
            sum_over_timesteps=False)


        # Gumbel-softmax decoding, used in training
        start_tokens = tf.ones_like(inputs['labels']) * vocab.bos_token_id
        end_token = vocab.eos_token_id
        gumbel_helper = GumbelSoftmaxEmbeddingHelper(
            embedder.embedding, start_tokens, end_token, gamma)
        soft_outputs_, _, soft_length_, = decoder(
            helper=gumbel_helper, initial_state=connector(h_))

        # Greedy decoding, used in evaluation
        outputs_, _, length_ = decoder(
            decoding_strategy='infer_greedy', initial_state=connector(h_),
            embedding=embedder, start_tokens=start_tokens, end_token=end_token)

        # Creates the classifier (used here as the discriminator)
        classifier = UnidirectionalRNNClassifier(hparams=self._hparams.classifier)
        clas_embedder = WordEmbedder(vocab_size=vocab.size,
                                     hparams=self._hparams.embedder)

        # Classification loss for the classifier
        clas_logits, clas_preds = classifier(
            inputs=clas_embedder(ids=inputs['text_ids'][:, 1:]),
            sequence_length=inputs['length'] - 1)
        loss_d_clas = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.to_float(inputs['labels']), logits=clas_logits)


        print("========================================================")
        print("Classifier made, number of trainable parameters:")
        print(classifier.trainable_variables)
        print("========================================================")
        prob=tf.nn.sigmoid(clas_logits)
        
        loss_d_clas = tf.reduce_mean(loss_d_clas)
        accu_d = tx.evals.accuracy(labels=inputs['labels'], preds=clas_preds)

        # Classification loss for the generator, based on soft samples
        soft_logits, soft_preds = classifier(
            inputs=clas_embedder(soft_ids=soft_outputs_.sample_id),
            sequence_length=soft_length_)
        loss_g_clas = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.to_float(1 - inputs['labels']), logits=soft_logits)
        loss_g_clas = tf.reduce_mean(loss_g_clas)

        # Accuracy on soft samples, for training progress monitoring
        accu_g = tx.evals.accuracy(labels=1 - inputs['labels'], preds=soft_preds)

        self.my_samples_id = outputs_.sample_id

        # Accuracy on greedy-decoded samples, for training progress monitoring
        _, gdy_preds = classifier(
            inputs=clas_embedder(ids=outputs_.sample_id),
            sequence_length=length_)
        accu_g_gdy = tx.evals.accuracy(
            labels=1 - inputs['labels'], preds=gdy_preds)

        # Aggregates losses
        loss_g = (lambda_ae * loss_g_ae) + (lambda_D * loss_g_clas)
        loss_d = loss_d_clas

        # Summaries for losses
        loss_g_ae_summary = tf.summary.scalar(name='loss_g_ae_summary', tensor=loss_g_ae)
        loss_g_clas_summary = tf.summary.scalar(name='loss_g_clas_summary', tensor=loss_g_clas)

        # Creates optimizers
        g_vars = collect_trainable_variables(
            [embedder, encoder, label_connector, connector, decoder])
        d_vars = collect_trainable_variables([clas_embedder, classifier])

        train_op_g = get_train_op(
            loss_g, g_vars, hparams=self._hparams.opt)
        train_op_g_ae = get_train_op(
            loss_g_ae, g_vars, hparams=self._hparams.opt)
        train_op_d = get_train_op(
            loss_d, d_vars, hparams=self._hparams.opt)

        # Interface tensors
        self.losses = {
            "loss_g": loss_g,
            "loss_g_ae": loss_g_ae,
            "loss_g_clas": loss_g_clas,
            "loss_d": loss_d_clas
        }

        self.metrics = {
            "accu_d": accu_d,
            "accu_g": accu_g,
            "accu_g_gdy": accu_g_gdy
        }

        self.train_ops = {
            "train_op_g": train_op_g,
            "train_op_g_ae": train_op_g_ae,
            "train_op_d": train_op_d
        }

        self.samples = {
            "original": inputs['text_ids'],
            "original_labels": inputs['labels'],
            "transferred": outputs_.sample_id,
            "soft_transferred": soft_outputs_.sample_id
        }

        self.summaries = {
            "loss_g_ae_summary": loss_g_ae_summary,
            "loss_g_clas_summary": loss_g_clas_summary
        }

        self.fetches_train_g = {
            "loss_g": self.train_ops["train_op_g"],
            "loss_g_ae": self.losses["loss_g_ae"],
            "loss_g_clas": self.losses["loss_g_clas"],
            "accu_g": self.metrics["accu_g"],
            "accu_g_gdy": self.metrics["accu_g_gdy"],
            "loss_g_ae_summary": self.summaries["loss_g_ae_summary"],
            "loss_g_clas_summary": self.summaries["loss_g_clas_summary"]
        }

        self.fetches_train_d = {
            "loss_d": self.train_ops["train_op_d"],
            "accu_d": self.metrics["accu_d"],
            "y_prob":prob,
            "y_pred": clas_preds,
            "y_true": inputs['labels'],
            "sentences": inputs['text_ids']
        }

        self.fetches_dev_test_d = {
            "y_prob": prob,
            "y_pred": clas_preds,
            "y_true": inputs['labels'],
            "sentences": inputs['text_ids'],
            "batch_size": get_batch_size(inputs['text_ids']),
            "loss_d": self.losses['loss_d'],
            "accu_d": self.metrics["accu_d"],
        }

        fetches_eval = {"batch_size": get_batch_size(inputs['text_ids'])}
        fetches_eval.update(self.losses)
        fetches_eval.update(self.metrics)
        fetches_eval.update(self.samples)
        self.fetches_eval = fetches_eval
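
A minimal evaluation sketch, not part of the original file: assuming `model` is an instance of the class above, `sess` is a tf.Session, the data iterator has been switched to the dev/test split, and Texar is imported as `tx`, the classifier fetches can be collected like this.

import tensorflow as tf
import texar.tf as tx  # older releases: import texar as tx

def eval_classifier(sess, model):
    # Accumulates classifier probabilities, predictions, and gold labels
    # over the dev/test split until the input pipeline is exhausted.
    probs, preds, trues = [], [], []
    feed = {tx.global_mode(): tf.estimator.ModeKeys.EVAL}
    while True:
        try:
            vals = sess.run(model.fetches_dev_test_d, feed_dict=feed)
            probs.extend(vals["y_prob"].tolist())
            preds.extend(vals["y_pred"].tolist())
            trues.extend(vals["y_true"].tolist())
        except tf.errors.OutOfRangeError:
            break
    return probs, preds, trues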