Example #1
    def model_fn(features, labels, mode, params):

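        # Resolve the training scheme from kwargs: 'grl' flips the generator
        # gradients (gradient-reversal style) and forces a joint train op,
        # while 'minmax' keeps separate generator/discriminator updates.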
        train_op_type = kargs.get('train_op_type', 'joint')
        gen_disc_type = kargs.get('gen_disc_type', 'all_disc')
        print(train_op_type, "===train op type===", gen_disc_type,
              "===generator loss type===")
        if kargs.get('optimization_type', 'grl') == 'grl':
            if_flip_grad = True
            train_op_type = 'joint'
        elif kargs.get('optimization_type', 'grl') == 'minmax':
            if_flip_grad = False
        else:
            if_flip_grad = True
            train_op_type = 'joint'
        generator_fn = generator(
            model_config_dict['generator'],
            num_labels_dict['generator'],
            init_checkpoint_dict['generator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['generator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('generator', ""),
            not_storage_params=not_storage_params_dict.get('generator', []),
            target=target_dict['generator'],
            if_flip_grad=if_flip_grad,
            mask_method="all_mask",
            **kargs)

        tf.logging.info("****** train_op_type:%s *******", train_op_type)
        tf.logging.info("****** optimization_type:%s *******",
                        kargs.get('optimization_type', 'grl'))
        generator_dict = generator_fn(features, labels, mode, params)

        discriminator_fn = discriminator(
            model_config_dict['discriminator'],
            num_labels_dict['discriminator'],
            init_checkpoint_dict['discriminator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['discriminator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('discriminator', ""),
            not_storage_params=not_storage_params_dict.get(
                'discriminator', []),
            target=target_dict['discriminator'],
            **kargs)

        tf.logging.info("****** true sampled_ids of discriminator *******")
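        # "Real" branch: feed the generator's uncorrupted sampled inputs, so
        # input_ids and input_ori_ids are identical.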
        true_discriminator_features = {}
        true_discriminator_features['input_ids'] = generator_dict[
            'sampled_input_ids']
        true_discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        true_discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        true_discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        true_discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        true_discriminator_features['ori_input_ids'] = generator_dict[
            'sampled_input_ids']

        true_discriminator_dict = discriminator_fn(true_discriminator_features,
                                                   labels, mode, params)

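        # "Fake" branch: the generator's sampled (possibly replaced) ids become
        # input_ids, while the original sampled ids are kept as input_ori_ids.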
        fake_discriminator_features = {}
        if kargs.get('minmax_mode', 'corrupted') == 'corrupted':
            tf.logging.info("****** gumbel 3-D sampled_ids *******")
        elif kargs.get('minmax_mode', 'corrupted') == 'masked':
            fake_discriminator_features['ori_sampled_ids'] = generator_dict[
                'output_ids']
            fake_discriminator_features['sampled_binary_mask'] = generator_dict[
                'sampled_binary_mask']
            tf.logging.info("****** conditional sampled_ids *******")
        fake_discriminator_features['input_ids'] = generator_dict[
            'sampled_ids']
        fake_discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        fake_discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        fake_discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        fake_discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        fake_discriminator_features['ori_input_ids'] = generator_dict[
            'sampled_ids']

        fake_discriminator_dict = discriminator_fn(fake_discriminator_features,
                                                   labels, mode, params)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        output_dict = global_gan_loss(model_config_dict['discriminator'],
                                      true_discriminator_dict,
                                      true_discriminator_features,
                                      fake_discriminator_dict,
                                      fake_discriminator_features,
                                      use_tpu=use_tpu)

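        # The 0.0-weighted fake-branch loss keeps that branch attached to the
        # training graph without changing the loss value.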
        discriminator_dict = {}
        discriminator_dict['loss'] = output_dict[
            'loss'] + 0.0 * fake_discriminator_dict['loss']
        discriminator_dict['tvars'] = fake_discriminator_dict['tvars']
        discriminator_dict['per_example_loss'] = fake_discriminator_dict[
            'per_example_loss']
        discriminator_dict['logits'] = fake_discriminator_dict['logits']

        model_io_fn = model_io.ModelIO(model_io_config)

        seq_global_vars = model_io_fn.get_params("cls/seq_global",
                                                 not_storage_params=[])
        discriminator_dict['tvars'].extend(seq_global_vars)

        tvars = []

        loss = discriminator_dict['loss']

        tvars.extend(discriminator_dict['tvars'])

        if kargs.get('joint_train', '1') == '1':
            tf.logging.info(
                "****** joint generator and discriminator training *******")
            tvars.extend(generator_dict['tvars'])
            loss += generator_dict['loss']
        tvars = list(set(tvars))

        # logging_hook = tf.train.LoggingTensorHook({
        # 				"generator_loss" : tf.get_collection('generator_loss'),
        # 				"discriminator_loss":tf.get_collection('discriminator_loss')},
        # 				every_n_iter=1000)

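        # Build one checkpoint-restore spec per sub-model so generator and
        # discriminator can be initialised from different pretrained checkpoints.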
        var_checkpoint_dict_list = []
        for key in init_checkpoint_dict:
            if load_pretrained_dict[key] == "yes":
                if key == 'generator':
                    tmp = {
                        "tvars":
                        generator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['generator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['generator'].get(
                            'restore_var_name', [])
                    }
                    if kargs.get("sharing_mode", "none") != "none":
                        tmp['exclude_scope'] = ''
                    var_checkpoint_dict_list.append(tmp)
                elif key == 'discriminator':
                    tmp = {
                        "tvars":
                        discriminator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['discriminator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['discriminator'].get(
                            'restore_var_name', [])
                    }
                    var_checkpoint_dict_list.append(tmp)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            if not kargs.get('use_tpu', False):
                metric_dict = discriminator_metric_train(
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])
                global_metric_dict = discriminator_metric_global_train(
                    output_dict)

                for key in global_metric_dict:
                    tf.summary.scalar(key, global_metric_dict[key])

                for key in metric_dict:
                    tf.summary.scalar(key, metric_dict[key])
                tf.summary.scalar("generator_loss", generator_dict['loss'])
                tf.summary.scalar("discriminator_loss",
                                  discriminator_dict['loss'])

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0

            model_io_fn.print_params(tvars, string=", trainable params")

            train_op = get_train_op(generator_dict,
                                    discriminator_dict,
                                    optimizer_fn,
                                    opt_config,
                                    model_config_dict['generator'],
                                    model_config_dict['discriminator'],
                                    use_tpu=use_tpu,
                                    train_op_type=train_op_type,
                                    gen_disc_type=gen_disc_type)

            # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            # with tf.control_dependencies(update_ops):
            # 	train_op = optimizer_fn.get_train_op(loss, list(set(tvars)),
            # 					opt_config.init_lr,
            # 					opt_config.num_train_steps,
            # 					use_tpu=use_tpu)

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn
                    # training_hooks=[logging_hook]
                )
            else:
                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            if kargs.get('joint_train', '0') == '1':

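                # Merge generator (masked-LM) eval metrics with discriminator
                # eval metrics into a single dict.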
                def joint_metric(masked_lm_example_loss, masked_lm_log_probs,
                                 masked_lm_ids, masked_lm_weights,
                                 next_sentence_example_loss,
                                 next_sentence_log_probs, next_sentence_labels,
                                 per_example_loss, logits, input_ori_ids,
                                 input_ids, input_mask):
                    generator_metric = generator_metric_fn_eval(
                        masked_lm_example_loss, masked_lm_log_probs,
                        masked_lm_ids, masked_lm_weights,
                        next_sentence_example_loss, next_sentence_log_probs,
                        next_sentence_labels)
                    discriminator_metric = discriminator_metric_eval(
                        per_example_loss, logits, input_ori_ids, input_ids,
                        input_mask)
                    generator_metric.update(discriminator_metric)
                    return generator_metric

                tpu_eval_metrics = (joint_metric, [
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None),
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask']
                ])
                gpu_eval_metrics = joint_metric(
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None),
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])
            else:
                gpu_eval_metrics = discriminator_metric_eval(
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])
                tpu_eval_metrics = (discriminator_metric_eval, [
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask']
                ])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec
        else:
            raise NotImplementedError()
Example #3
    def model_fn(features, labels, mode, params):

        train_op_type = kargs.get('train_op_type', 'joint')
        gen_disc_type = kargs.get('gen_disc_type', 'all_disc')
        mask_method = kargs.get('mask_method', 'only_mask')
        use_tpu = 1 if kargs.get('use_tpu', False) else 0
        print(train_op_type, "===train op type===", gen_disc_type,
              "===generator loss type===")
        if mask_method == 'only_mask':
            tf.logging.info(
                "****** generator token generation mask type:%s with only masked token *******",
                mask_method)
        elif mask_method == 'all_mask':
            tf.logging.info(
                "****** generator token generation mask type:%s with all token *******",
                mask_method)
        else:
            mask_method = 'only_mask'
            tf.logging.info(
                "****** generator token generation mask type:%s with only masked token *******",
                mask_method)

        if kargs.get('optimization_type', 'grl') == 'grl':
            if_flip_grad = True
            train_op_type = 'joint'
        elif kargs.get('optimization_type', 'grl') == 'minmax':
            if_flip_grad = False
        else:
            if_flip_grad = True
            train_op_type = 'joint'
        generator_fn = generator(
            model_config_dict['generator'],
            num_labels_dict['generator'],
            init_checkpoint_dict['generator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['generator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('generator', ""),
            not_storage_params=not_storage_params_dict.get('generator', []),
            target=target_dict['generator'],
            if_flip_grad=if_flip_grad,
            # mask_method='only_mask',
            **kargs)

        tf.logging.info("****** train_op_type:%s *******", train_op_type)
        tf.logging.info("****** optimization_type:%s *******",
                        kargs.get('optimization_type', 'grl'))
        generator_dict = generator_fn(features, labels, mode, params)

        discriminator_fn = discriminator_generator(
            model_config_dict['discriminator'],
            num_labels_dict['discriminator'],
            init_checkpoint_dict['discriminator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['discriminator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('discriminator', ""),
            not_storage_params=not_storage_params_dict.get(
                'discriminator', []),
            target=target_dict['discriminator'],
            loss='cross_entropy',
            **kargs)

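        # Assemble discriminator inputs from the generator outputs; 'masked'
        # mode additionally passes the pre-corruption output ids and the
        # corrupted-position mask.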
        discriminator_features = {}
        # minmax_mode in ['masked', 'corrupted']
        minmax_mode = kargs.get('minmax_mode', 'corrupted')
        tf.logging.info("****** minmax mode for discriminator: %s *******",
                        minmax_mode)
        if minmax_mode == 'corrupted':
            tf.logging.info("****** gumbel 3-D sampled_ids *******")
        elif minmax_mode == 'masked':
            discriminator_features['ori_sampled_ids'] = generator_dict[
                'output_ids']
            discriminator_features['sampled_binary_mask'] = generator_dict[
                'sampled_binary_mask']
            tf.logging.info("****** conditional sampled_ids *******")
        discriminator_features['input_ids'] = generator_dict['sampled_ids']
        discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        discriminator_features['ori_input_ids'] = generator_dict['sampled_ids']

        discriminator_dict = discriminator_fn(discriminator_features, labels,
                                              mode, params)

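        # Derive the discriminator loss, logits and per-example loss directly
        # from the generator and discriminator outputs.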
        [disc_loss, disc_logits, disc_per_example_loss
         ] = optimal_discriminator(model_config_dict['discriminator'],
                                   generator_dict,
                                   features,
                                   discriminator_dict,
                                   discriminator_features,
                                   use_tpu=use_tpu)

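        # Split the per-example loss into terms over positions where the
        # sampled ids equal the original ids and where they differ.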
        [
            equal_per_example_loss, equal_loss_all, equal_loss_self,
            not_equal_per_example_loss, not_equal_loss_all, not_equal_loss_self
        ] = modified_loss(disc_per_example_loss,
                          disc_logits,
                          discriminator_features['input_ori_ids'],
                          discriminator_features['ori_input_ids'],
                          discriminator_features['input_mask'],
                          sampled_binary_mask=discriminator_features.get(
                              'sampled_binary_mask', None),
                          **kargs)
        output_dict = {}
        output_dict['logits'] = disc_logits
        output_dict['per_example_loss'] = disc_per_example_loss
        output_dict['loss'] = disc_loss + 0.0 * discriminator_dict["loss"]
        output_dict["equal_per_example_loss"] = equal_per_example_loss
        output_dict["equal_loss_all"] = equal_loss_all
        output_dict["equal_loss_self"] = equal_loss_self
        output_dict["not_equal_per_example_loss"] = not_equal_per_example_loss
        output_dict["not_equal_loss_all"] = not_equal_loss_all
        output_dict["not_equal_loss_self"] = not_equal_loss_self
        output_dict['tvars'] = discriminator_dict['tvars']

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = []

        loss = kargs.get('dis_loss', 1.0) * output_dict['loss']

        tvars.extend(discriminator_dict['tvars'])

        if kargs.get('joint_train', '1') == '1':
            tf.logging.info(
                "****** joint generator and discriminator training *******")
            tvars.extend(generator_dict['tvars'])
            loss += generator_dict['loss']
        tvars = list(set(tvars))

        var_checkpoint_dict_list = []
        for key in init_checkpoint_dict:
            if load_pretrained_dict[key] == "yes":
                if key == 'generator':
                    tmp = {
                        "tvars":
                        generator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['generator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['generator'].get(
                            'restore_var_name', [])
                    }
                    if kargs.get("sharing_mode", "none") != "none":
                        tmp['exclude_scope'] = ''
                    var_checkpoint_dict_list.append(tmp)
                elif key == 'discriminator':
                    tmp = {
                        "tvars":
                        discriminator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['discriminator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['discriminator'].get(
                            'restore_var_name', [])
                    }
                    var_checkpoint_dict_list.append(tmp)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            if not kargs.get('use_tpu', False):
                metric_dict = discriminator_metric_train(
                    output_dict['per_example_loss'], output_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])

                for key in metric_dict:
                    tf.summary.scalar(key, metric_dict[key])
                tf.summary.scalar("generator_loss", generator_dict['loss'])
                tf.summary.scalar("discriminator_loss",
                                  discriminator_dict['loss'])

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0

            model_io_fn.print_params(tvars, string=", trainable params")

            train_op = get_train_op(generator_dict,
                                    output_dict,
                                    optimizer_fn,
                                    opt_config,
                                    model_config_dict['generator'],
                                    model_config_dict['discriminator'],
                                    use_tpu=use_tpu,
                                    train_op_type=train_op_type,
                                    gen_disc_type=gen_disc_type)

            # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            # with tf.control_dependencies(update_ops):
            # 	train_op = optimizer_fn.get_train_op(loss, list(set(tvars)),
            # 					opt_config.init_lr,
            # 					opt_config.num_train_steps,
            # 					use_tpu=use_tpu)

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn
                    # training_hooks=[logging_hook]
                )
            else:
                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            if kargs.get('joint_train', '0') == '1':

                def joint_metric(masked_lm_example_loss, masked_lm_log_probs,
                                 masked_lm_ids, masked_lm_weights,
                                 next_sentence_example_loss,
                                 next_sentence_log_probs, next_sentence_labels,
                                 per_example_loss, logits, input_ori_ids,
                                 input_ids, input_mask):
                    generator_metric = generator_metric_fn_eval(
                        masked_lm_example_loss, masked_lm_log_probs,
                        masked_lm_ids, masked_lm_weights,
                        next_sentence_example_loss, next_sentence_log_probs,
                        next_sentence_labels)
                    discriminator_metric = discriminator_metric_eval(
                        per_example_loss, logits, input_ori_ids, input_ids,
                        input_mask)
                    generator_metric.update(discriminator_metric)
                    return generator_metric

                tpu_eval_metrics = (joint_metric, [
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None),
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask']
                ])
                gpu_eval_metrics = joint_metric(
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None),
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])
            else:
                gpu_eval_metrics = discriminator_metric_eval(
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])
                tpu_eval_metrics = (discriminator_metric_eval, [
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask']
                ])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec
        else:
            raise NotImplementedError()
Example #4
    def model_fn(features, labels, mode, params):

        generator_fn = generator(
            model_config_dict['generator'],
            num_labels_dict['generator'],
            init_checkpoint_dict['generator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['generator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('generator', ""),
            not_storage_params=not_storage_params_dict.get('generator', []),
            target=target_dict['generator'],
            **kargs)
        generator_dict = generator_fn(features, labels, mode, params)

        discriminator_fn = discriminator(
            model_config_dict['discriminator'],
            num_labels_dict['discriminator'],
            init_checkpoint_dict['discriminator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['discriminator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('discriminator', ""),
            not_storage_params=not_storage_params_dict.get(
                'discriminator', []),
            target=target_dict['discriminator'],
            **kargs)

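        # Discriminator inputs: the generator's sampled (possibly replaced) ids
        # as input_ids, with the original sampled ids kept as input_ori_ids.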
        discriminator_features = {}
        discriminator_features['input_ids'] = generator_dict['sampled_ids']
        discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        discriminator_dict = discriminator_fn(discriminator_features, labels,
                                              mode, params)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = []
        loss = discriminator_dict['loss']
        print(loss)
        tvars.extend(discriminator_dict['tvars'])
        if kargs.get('joint_train', '0') == '1':
            tvars.extend(generator_dict['tvars'])
            loss += generator_dict['loss']

        var_checkpoint_dict_list = []
        for key in init_checkpoint_dict:
            if load_pretrained_dict[key] == "yes":
                if key == 'generator':
                    tmp = {
                        "tvars": generator_dict['tvars'],
                        "init_checkpoint": init_checkpoint_dict['generator'],
                        "exclude_scope": exclude_scope_dict[key]
                    }
                    var_checkpoint_dict_list.append(tmp)
                elif key == 'discriminator':
                    tmp = {
                        "tvars": discriminator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['discriminator'],
                        "exclude_scope": exclude_scope_dict[key]
                    }
                    var_checkpoint_dict_list.append(tmp)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            metric_dict = discriminator_metric_train(
                discriminator_dict['per_example_loss'],
                discriminator_dict['logits'],
                generator_dict['sampled_input_ids'],
                generator_dict['sampled_ids'],
                generator_dict['sampled_input_mask'])

            for key in metric_dict:
                tf.summary.scalar(key, metric_dict[key])

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0

            model_io_fn.print_params(tvars, string=", trainable params")

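            # Run any pending UPDATE_OPS (e.g. batch-norm moving-average
            # updates) before the optimizer step.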
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss,
                    list(set(tvars)),
                    opt_config.init_lr,
                    opt_config.num_train_steps,
                    use_tpu=use_tpu)

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

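            # With joint training enabled, also compute generator masked-LM
            # eval metrics; otherwise report discriminator metrics only.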
            if kargs.get('joint_train', '0') == '1':
                generator_metric = generator_metric_fn_eval(
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None))
            else:
                generator_metric = {}

            discriminator_metric = discriminator_metric_eval(
                discriminator_dict['per_example_loss'],
                discriminator_dict['logits'],
                generator_dict['sampled_input_ids'],
                generator_dict['sampled_ids'],
                generator_dict['sampled_input_mask'])

            metric_dict = discriminator_metric
            if len(generator_metric):
                metric_dict.update(generator_metric)

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=metric_dict,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=metric_dict)

            return estimator_spec
        else:
            raise NotImplementedError()