Example #1
0
def retrieve_init_savers(hparams):
    """Retrieve a dictionary of all the initial savers for the models.

    Savers are only constructed here; restoring them is left to the caller.

    Args:
      hparams:  MaskGAN hyperparameters.

    Returns:
      A dict mapping saver names to `tf.train.Saver` objects. Depending on
      FLAGS it may contain 'init_saver', 'gen_init_saver',
      'gen_encoder_init_saver', 'gen_decoder_init_saver', 'dis_init_saver',
      'dis_fwd_init_saver', 'dis_bwd_init_saver', 'dis_encoder_init_saver'
      and 'dis_decoder_init_saver'. It is empty when neither
      FLAGS.maskgan_ckpt nor FLAGS.language_model_ckpt_dir is set.

    Raises:
      NotImplementedError: If Generator weights are to be loaded from the
        language model checkpoint (no MaskGAN checkpoint given) and
        FLAGS.generator_model has no variable mapping.
    """
    ## Dictionary of init savers.
    init_savers = {}

    ## Load Generator weights from MaskGAN checkpoint.
    if FLAGS.maskgan_ckpt:
        gen_vars = [
            v for v in tf.trainable_variables() if v.op.name.startswith('gen')
        ]
        init_savers['init_saver'] = tf.train.Saver(var_list=gen_vars)

        ## Load the Discriminator weights from the MaskGAN checkpoint if
        # the weights are compatible.
        if FLAGS.discriminator_model == 'seq2seq_vd':
            dis_variable_maps = variable_mapping.dis_seq2seq_vd(hparams)
            init_savers['dis_init_saver'] = tf.train.Saver(
                var_list=dis_variable_maps)

    ## Load weights from language model checkpoint.
    if FLAGS.language_model_ckpt_dir:
        # Mirror the truthiness test used above: an empty maskgan_ckpt string
        # loads nothing from MaskGAN, so the Generator must still be loaded
        # from the language model in that case.
        if not FLAGS.maskgan_ckpt:
            ## Generator Variables/Savers.
            if FLAGS.generator_model == 'rnn_nas':
                gen_variable_maps = variable_mapping.rnn_nas(hparams,
                                                             model='gen')
                init_savers['gen_init_saver'] = tf.train.Saver(
                    var_list=gen_variable_maps)

            elif FLAGS.generator_model == 'seq2seq_nas':
                # Encoder.
                gen_encoder_variable_maps = (
                    variable_mapping.gen_encoder_seq2seq_nas(hparams))
                gen_encoder_init_saver = tf.train.Saver(
                    var_list=gen_encoder_variable_maps)
                # Decoder.
                gen_decoder_variable_maps = (
                    variable_mapping.gen_decoder_seq2seq_nas(hparams))
                gen_decoder_init_saver = tf.train.Saver(
                    var_list=gen_decoder_variable_maps)
                init_savers['gen_encoder_init_saver'] = gen_encoder_init_saver
                init_savers['gen_decoder_init_saver'] = gen_decoder_init_saver

            # seq2seq_vd derived from the same code base as seq2seq_zaremba.
            elif FLAGS.generator_model in ('seq2seq_zaremba', 'seq2seq_vd'):
                # Encoder.
                gen_encoder_variable_maps = (
                    variable_mapping.gen_encoder_seq2seq(hparams))
                gen_encoder_init_saver = tf.train.Saver(
                    var_list=gen_encoder_variable_maps)
                # Decoder.
                gen_decoder_variable_maps = (
                    variable_mapping.gen_decoder_seq2seq(hparams))
                gen_decoder_init_saver = tf.train.Saver(
                    var_list=gen_decoder_variable_maps)
                init_savers['gen_encoder_init_saver'] = gen_encoder_init_saver
                init_savers['gen_decoder_init_saver'] = gen_decoder_init_saver

            else:
                raise NotImplementedError(
                    'No language model variable mapping for generator_model '
                    '%r.' % (FLAGS.generator_model,))

        ## Discriminator Variables/Savers.
        # Discriminator models not listed below get no saver (no error).
        if FLAGS.discriminator_model == 'rnn_nas':
            dis_variable_maps = variable_mapping.rnn_nas(hparams, model='dis')
            init_savers['dis_init_saver'] = tf.train.Saver(
                var_list=dis_variable_maps)

        # rnn_vd derived from the same code base as rnn_zaremba.
        elif FLAGS.discriminator_model in ('rnn_zaremba', 'rnn_vd'):
            dis_variable_maps = variable_mapping.rnn_zaremba(hparams,
                                                             model='dis')
            init_savers['dis_init_saver'] = tf.train.Saver(
                var_list=dis_variable_maps)

        elif FLAGS.discriminator_model in ('bidirectional_zaremba',
                                           'bidirectional_vd'):
            dis_fwd_variable_maps = variable_mapping.dis_fwd_bidirectional(
                hparams)
            dis_bwd_variable_maps = variable_mapping.dis_bwd_bidirectional(
                hparams)
            # Savers for the forward/backward Discriminator components.
            dis_fwd_init_saver = tf.train.Saver(var_list=dis_fwd_variable_maps)
            dis_bwd_init_saver = tf.train.Saver(var_list=dis_bwd_variable_maps)
            init_savers['dis_fwd_init_saver'] = dis_fwd_init_saver
            init_savers['dis_bwd_init_saver'] = dis_bwd_init_saver

        elif FLAGS.discriminator_model == 'cnn':
            dis_variable_maps = variable_mapping.cnn()
            init_savers['dis_init_saver'] = tf.train.Saver(
                var_list=dis_variable_maps)

        elif FLAGS.discriminator_model == 'seq2seq_vd':
            # NOTE: if FLAGS.maskgan_ckpt is also set, a MaskGAN
            # 'dis_init_saver' was created above as well; which savers get
            # restored is decided by the caller.
            # Encoder.
            dis_encoder_variable_maps = variable_mapping.dis_encoder_seq2seq(
                hparams)
            dis_encoder_init_saver = tf.train.Saver(
                var_list=dis_encoder_variable_maps)
            # Decoder.
            dis_decoder_variable_maps = variable_mapping.dis_decoder_seq2seq(
                hparams)
            dis_decoder_init_saver = tf.train.Saver(
                var_list=dis_decoder_variable_maps)
            init_savers['dis_encoder_init_saver'] = dis_encoder_init_saver
            init_savers['dis_decoder_init_saver'] = dis_decoder_init_saver

    return init_savers
Example #2
0
def retrieve_init_savers(hparams):
  """Retrieve a dictionary of all the initial savers for the models.

  Savers are only constructed here; restoring them is left to the caller.

  Args:
    hparams:  MaskGAN hyperparameters.

  Returns:
    A dict mapping saver names to `tf.train.Saver` objects. Depending on
    FLAGS it may contain 'init_saver', 'gen_init_saver',
    'gen_encoder_init_saver', 'gen_decoder_init_saver', 'dis_init_saver',
    'dis_fwd_init_saver', 'dis_bwd_init_saver', 'dis_encoder_init_saver'
    and 'dis_decoder_init_saver'. It is empty when neither
    FLAGS.maskgan_ckpt nor FLAGS.language_model_ckpt_dir is set.

  Raises:
    NotImplementedError: If Generator weights are to be loaded from the
      language model checkpoint (no MaskGAN checkpoint given) and
      FLAGS.generator_model has no variable mapping.
  """
  ## Dictionary of init savers.
  init_savers = {}

  ## Load Generator weights from MaskGAN checkpoint.
  if FLAGS.maskgan_ckpt:
    gen_vars = [
        v for v in tf.trainable_variables() if v.op.name.startswith('gen')
    ]
    init_savers['init_saver'] = tf.train.Saver(var_list=gen_vars)

    ## Load the Discriminator weights from the MaskGAN checkpoint if
    # the weights are compatible.
    if FLAGS.discriminator_model == 'seq2seq_vd':
      dis_variable_maps = variable_mapping.dis_seq2seq_vd(hparams)
      init_savers['dis_init_saver'] = tf.train.Saver(
          var_list=dis_variable_maps)

  ## Load weights from language model checkpoint.
  if FLAGS.language_model_ckpt_dir:
    # Mirror the truthiness test used above: an empty maskgan_ckpt string
    # loads nothing from MaskGAN, so the Generator must still be loaded
    # from the language model in that case.
    if not FLAGS.maskgan_ckpt:
      ## Generator Variables/Savers.
      if FLAGS.generator_model == 'rnn_nas':
        gen_variable_maps = variable_mapping.rnn_nas(hparams, model='gen')
        init_savers['gen_init_saver'] = tf.train.Saver(
            var_list=gen_variable_maps)

      elif FLAGS.generator_model == 'seq2seq_nas':
        # Encoder.
        gen_encoder_variable_maps = variable_mapping.gen_encoder_seq2seq_nas(
            hparams)
        gen_encoder_init_saver = tf.train.Saver(
            var_list=gen_encoder_variable_maps)
        # Decoder.
        gen_decoder_variable_maps = variable_mapping.gen_decoder_seq2seq_nas(
            hparams)
        gen_decoder_init_saver = tf.train.Saver(
            var_list=gen_decoder_variable_maps)
        init_savers['gen_encoder_init_saver'] = gen_encoder_init_saver
        init_savers['gen_decoder_init_saver'] = gen_decoder_init_saver

      # seq2seq_vd derived from the same code base as seq2seq_zaremba.
      elif FLAGS.generator_model in ('seq2seq_zaremba', 'seq2seq_vd'):
        # Encoder.
        gen_encoder_variable_maps = variable_mapping.gen_encoder_seq2seq(
            hparams)
        gen_encoder_init_saver = tf.train.Saver(
            var_list=gen_encoder_variable_maps)
        # Decoder.
        gen_decoder_variable_maps = variable_mapping.gen_decoder_seq2seq(
            hparams)
        gen_decoder_init_saver = tf.train.Saver(
            var_list=gen_decoder_variable_maps)
        init_savers['gen_encoder_init_saver'] = gen_encoder_init_saver
        init_savers['gen_decoder_init_saver'] = gen_decoder_init_saver

      else:
        raise NotImplementedError(
            'No language model variable mapping for generator_model '
            '%r.' % (FLAGS.generator_model,))

    ## Discriminator Variables/Savers.
    # Discriminator models not listed below get no saver (no error).
    if FLAGS.discriminator_model == 'rnn_nas':
      dis_variable_maps = variable_mapping.rnn_nas(hparams, model='dis')
      init_savers['dis_init_saver'] = tf.train.Saver(
          var_list=dis_variable_maps)

    # rnn_vd derived from the same code base as rnn_zaremba.
    elif FLAGS.discriminator_model in ('rnn_zaremba', 'rnn_vd'):
      dis_variable_maps = variable_mapping.rnn_zaremba(hparams, model='dis')
      init_savers['dis_init_saver'] = tf.train.Saver(
          var_list=dis_variable_maps)

    elif FLAGS.discriminator_model in ('bidirectional_zaremba',
                                       'bidirectional_vd'):
      dis_fwd_variable_maps = variable_mapping.dis_fwd_bidirectional(hparams)
      dis_bwd_variable_maps = variable_mapping.dis_bwd_bidirectional(hparams)
      # Savers for the forward/backward Discriminator components.
      dis_fwd_init_saver = tf.train.Saver(var_list=dis_fwd_variable_maps)
      dis_bwd_init_saver = tf.train.Saver(var_list=dis_bwd_variable_maps)
      init_savers['dis_fwd_init_saver'] = dis_fwd_init_saver
      init_savers['dis_bwd_init_saver'] = dis_bwd_init_saver

    elif FLAGS.discriminator_model == 'cnn':
      dis_variable_maps = variable_mapping.cnn()
      init_savers['dis_init_saver'] = tf.train.Saver(
          var_list=dis_variable_maps)

    elif FLAGS.discriminator_model == 'seq2seq_vd':
      # NOTE: if FLAGS.maskgan_ckpt is also set, a MaskGAN 'dis_init_saver'
      # was created above as well; which savers get restored is decided by
      # the caller.
      # Encoder.
      dis_encoder_variable_maps = variable_mapping.dis_encoder_seq2seq(hparams)
      dis_encoder_init_saver = tf.train.Saver(
          var_list=dis_encoder_variable_maps)
      # Decoder.
      dis_decoder_variable_maps = variable_mapping.dis_decoder_seq2seq(hparams)
      dis_decoder_init_saver = tf.train.Saver(
          var_list=dis_decoder_variable_maps)
      init_savers['dis_encoder_init_saver'] = dis_encoder_init_saver
      init_savers['dis_decoder_init_saver'] = dis_decoder_init_saver

  return init_savers