Example #1
def build_experiment_name(config):
    template = "{tag}{model_name}_{encoder}-{latent}-{decoder}-{output}-{experiment}{data}"

    if config.decoder.name == "t_emb":
        model_tags = TEmbedding.get_model_tags(config, config.loss)
    elif config.decoder.name == "rnn":
        model_tags = InkSeq2Seq.get_model_tags(config, config.loss)
    else:
        err_unknown_type(config.decoder.name)

    # data = config.data.data_name
    data = ""
    lr = ""
    if config.experiment.learning_rate.name == "transformer":
        lr = "_tr"
    elif config.experiment.learning_rate.name == "exponential":
        lr = "_exp"
    experiment = "B{}_LR{}".format(config.data.batch_size, lr)

    return template.format(
        tag=config.experiment.tag + "_" if config.experiment.tag else "",
        model_name=model_tags["model_name"],
        encoder=model_tags["encoder"],
        latent=model_tags["latent"],
        decoder=model_tags["decoder"],
        output=model_tags["output"],
        experiment=experiment,
        data=data,
    )
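
For reference, a minimal sketch of the resulting name with hypothetical tag values (the real values come from the get_model_tags methods shown below); optional parts such as tag and data simply collapse to empty strings:

template = "{tag}{model_name}_{encoder}-{latent}-{decoder}-{output}-{experiment}{data}"
name = template.format(tag="", model_name="TEMB", encoder="biLSTM_1x512",
                       latent="L8_vae_w0.5", decoder="4x512", output="gmm",
                       experiment="B128_LR_tr", data="")
print(name)  # TEMB_biLSTM_1x512-L8_vae_w0.5-4x512-gmm-B128_LR_tr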
Example #2
    def get_rnn_layer(cls,
                      type_str,
                      units,
                      return_sequences,
                      return_state,
                      stateful,
                      name,
                      recurrent_dropout=0.0):
        """Generates an RNN layer.

    Args:
      type_str:
      units:
      return_sequences:
      return_state:
      stateful:
      name:
      recurrent_dropout:

    Returns:
    """
        cell_cls = None
        if type_str == C.LSTM:
            cell_cls = tf.keras.layers.LSTM
        elif type_str == C.GRU:
            cell_cls = tf.keras.layers.GRU
        else:
            err_unknown_type(type_str)

        return cell_cls(units=units,
                        return_sequences=return_sequences,
                        return_state=return_state,
                        stateful=stateful,
                        recurrent_dropout=recurrent_dropout,
                        name=name)
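
A minimal usage sketch of the returned layer (the constants and shapes here are assumed, not from the source):

import tensorflow as tf

# Equivalent to get_rnn_layer(C.LSTM, 64, True, False, False, "encoder_rnn").
layer = tf.keras.layers.LSTM(units=64, return_sequences=True,
                             return_state=False, stateful=False,
                             recurrent_dropout=0.0, name="encoder_rnn")
out = layer(tf.zeros([2, 10, 8]))  # [batch, time, features] -> [2, 10, 64]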
Example #3
    def get(cls, config):
        lr_type = config["name"]
        if lr_type == "exponential":
            return ExponentialDecay(**config)
        elif lr_type == "sketch_rnn":
            return SketchRnnDecay(**config)
        elif lr_type == "transformer":
            return TransformerDecay(**config)
        else:
            err_unknown_type(lr_type)
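
Note that the entire config dict, including the "name" key, is forwarded via **config, so each schedule constructor has to tolerate a name argument. A minimal sketch of the pattern with a hypothetical stand-in class:

class ExponentialDecay:  # hypothetical stand-in for the real schedule
    def __init__(self, initial_learning_rate, decay_rate=0.96, name=None,
                 **kwargs):
        self.initial_learning_rate = initial_learning_rate
        self.decay_rate = decay_rate

config = {"name": "exponential", "initial_learning_rate": 1e-3}
schedule = ExponentialDecay(**config)  # "name" is absorbed harmlessly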
Example #4
  def get_model_tags(cls, config, config_loss):
    """Generates a string summarizing experiment parameters.

    Args:
      config:
      config_loss

    Returns:
    """
    if config_loss["stroke"]["loss_type"] == C.NLL_NORMAL:
      output = "normal"
    elif config_loss["stroke"]["loss_type"] == C.NLL_BINORMAL:
      output = "binormal"
    elif config_loss["stroke"]["loss_type"] == C.NLL_GMM:
      output = "gmm"
    else:
      output = config_loss["stroke"]["loss_type"]
  
    latent = "L{}".format(config.embedding.latent_units)
    if config.embedding.use_vae:
      latent += "_vae"
      if isinstance(config_loss.embedding_kld.weight, float):
        latent += "_w" + str(config_loss.embedding_kld.weight)
      else:
        latent += "_aw" + str(config_loss.embedding_kld.weight["values"][1])
  
    if config.encoder.name == "rnn":
      encoder = "{}_{}x{}".format(config.encoder.cell_type,
                                  config.encoder.cell_layers,
                                  config.encoder.cell_units)
      if config.encoder.bidirectional_encoder:
        encoder = "bi" + encoder

      if config.encoder.rec_dropout_rate > 0:
        encoder += "_rdrop{}".format(config.encoder.rec_dropout_rate)
    else:
      err_unknown_type(config.encoder["name"])

    decoder = ""
    if config.decoder.repeat_vae_sample:
      decoder += "rep_"
    if config.decoder.dropout_rate > 0:
      decoder += "ddrop_" + str(config.decoder.dropout_rate)
    if config.decoder.dynamic_h0:
      decoder += "dh0_"

    model_name = "Seq2Seq"
    if config.decoder.autoregressive:
      model_name += "_ar"
  
    return dict(encoder=encoder, latent=latent, decoder=decoder, output=output,
                model_name=model_name)
Example #5
    def get_model_tags(cls, config, config_loss):
        """Generates a string summarizing experiment parameters.

    Args:
      config:
      config_loss

    Returns:
    """
        if config_loss["stroke"]["loss_type"] == C.NLL_NORMAL:
            output = "normal"
        elif config_loss["stroke"]["loss_type"] == C.NLL_BINORMAL:
            output = "binormal"
        elif config_loss["stroke"]["loss_type"] == C.NLL_GMM:
            output = "gmm"
        else:
            output = config_loss["stroke"]["loss_type"]

        decoder = "{}x{}".format(config.decoder.layers,
                                 config.decoder.hidden_units[0])

        latent = "L{}".format(config.embedding.latent_units)
        if config.embedding.use_vae:
            latent += "_vae"
            if isinstance(config_loss.embedding_kld.weight, float):
                latent += "_w" + str(config_loss.embedding_kld.weight)
            else:
                latent += "_aw" + str(
                    config_loss.embedding_kld.weight["values"][1])

        if config.encoder.name == "rnn":
            encoder = "{}_{}x{}".format(config.encoder.cell_type,
                                        config.encoder.cell_layers,
                                        config.encoder.cell_units)
            if config.encoder.bidirectional_encoder:
                encoder = "bi" + encoder
        elif config.encoder.name == "transformer":
            encoder = "TR_{}_{}x{}-head_{}-drop_{}".format(
                config.encoder.d_model, config.encoder.layers,
                config.encoder.hidden_units, config.encoder.heads,
                config.encoder.dropout_rate)
            if not config.encoder.autoregressive:
                encoder = "bi" + encoder
        else:
            err_unknown_type(config.encoder["name"])

        return dict(encoder=encoder,
                    latent=latent,
                    decoder=decoder,
                    output=output,
                    model_name="TEMB")
Example #6
def build_experiment_name(config):
    template = "PRED_{tag}{pred_model_name}_{predictive}-{emb_model_name}_{encoder}-{latent}-{decoder}-{output}-{loss}-{experiment}{data}"

    if config.decoder.name == "t_emb":
        emb_model_tags = TEmbedding.get_model_tags(config, config.loss.ink)
    elif config.decoder.name == "rnn":
        emb_model_tags = InkSeq2Seq.get_model_tags(config, config.loss.ink)
    else:
        err_unknown_type(config.decoder.name)

    if config.predictive_model.name == "rnn":
        pred_model_tags = RNN.get_model_tags(config.predictive_model)
    elif config.predictive_model.name == "transformer":
        pred_model_tags = TransformerAR.get_model_tags(config.predictive_model)
    else:
        err_unknown_type(config.predictive_model.name)

    # data = config.data.data_name
    data = ""
    lr = ""
    if config.experiment.learning_rate.name == "transformer":
        lr = "_tr"
    elif config.experiment.learning_rate.name == "exponential":
        lr = "_exp"
    experiment = "B{}_LR{}".format(config.data.batch_size, lr)

    loss = "loss_{}{}{}".format(
        "P" if config.loss.apply_predicted_ink else "",
        "E" if config.loss.apply_predicted_embedding else "",
        "R" if config.loss.apply_reconstructed_ink else "")

    return template.format(
        tag=config.experiment.tag + "_" if config.experiment.tag else "",
        pred_model_name=pred_model_tags["model_name"],
        predictive=pred_model_tags["model"],
        emb_model_name=emb_model_tags["model_name"],
        encoder=emb_model_tags["encoder"],
        latent=emb_model_tags["latent"],
        decoder=emb_model_tags["decoder"],
        output=emb_model_tags["output"],
        experiment=experiment,
        loss=loss,
        data=data,
    )
Example #7
def build_embedding_model(config_, run_mode=C.RUN_STATIC):
    """Builds model object."""

    if config_.decoder.name == "t_emb":
        model_ = TEmbedding(config_encoder=config_.encoder,
                            config_embedding=config_.embedding,
                            config_decoder=config_.decoder,
                            config_loss=config_.loss,
                            run_mode=run_mode)
    elif config_.decoder.name == "rnn":
        model_ = InkSeq2Seq(config_encoder=config_.encoder,
                            config_embedding=config_.embedding,
                            config_decoder=config_.decoder,
                            config_loss=config_.loss,
                            run_mode=run_mode)
    else:
        err_unknown_type(config_.decoder.name)

    return model_
Example #8
    def get_cells(cls, type_str, units, layers=1):
        """Creates a cell.

    Args:
      type_str:
      units:
      layers:

    Returns:
    """
        cells = []

        for _ in range(layers):
            if type_str == C.LSTM:
                cells.append(tf.keras.layers.LSTMCell(units))
            elif type_str == C.GRU:
                cells.append(tf.keras.layers.GRUCell(units))
            else:
                err_unknown_type(type_str)
        return cells
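
The returned cell list is typically wrapped into a multi-layer recurrent layer; a sketch with assumed sizes:

import tensorflow as tf

cells = [tf.keras.layers.LSTMCell(32) for _ in range(2)]  # as get_cells builds
rnn = tf.keras.layers.RNN(tf.keras.layers.StackedRNNCells(cells))
out = rnn(tf.zeros([4, 7, 16]))  # -> [4, 32], the last output of the top cell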
Example #9
    def load_meta_data(cls, meta_data_path):
        """Loads meta-data file given the path.

    It is assumed to be in numpy.

    Args:
        meta_data_path:

    Returns:
        Meta-data dictionary or False if it is not found.
    """
        # if not meta_data_path or not os.path.exists(meta_data_path):

        _, ext = os.path.splitext(meta_data_path)
        if ext == ".json":
            meta_fp = tf.io.gfile.GFile(meta_data_path, "r")
            try:
                meta_fp.size()
                print("Loading statistics " + meta_data_path)
                json_stats = json.load(meta_fp)
                stats_np = dict()
                for key_, value_ in json_stats.items():
                    stats_np[key_] = (np.array(value_)
                                      if isinstance(value_, list) else value_)
                return stats_np
            except tf.errors.NotFoundError:
                print("Meta-data not found.")
                return False

        elif ext == ".npy":
            meta_fp = tf.io.gfile.GFile(meta_data_path, "rb")
            try:
                meta_fp.size()
                print("Loading statistics " + meta_data_path)
                return np.load(meta_fp, allow_pickle=True).item()
            except tf.errors.NotFoundError:
                print("Meta-data not found.")
                return False
        else:
            err_unknown_type(ext)
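
The JSON branch assumes numpy arrays were serialized as plain lists; a sketch of the implied round trip (the statistics keys are hypothetical):

import json
import numpy as np

stats = {"mean": np.zeros(3), "n_samples": 100}
blob = json.dumps({k: v.tolist() if isinstance(v, np.ndarray) else v
                   for k, v in stats.items()})
loaded = {k: np.array(v) if isinstance(v, list) else v
          for k, v in json.loads(blob).items()}  # mirrors the loader above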
Example #10
    def get(cls, type_str):
        """Creates the activation function, given its type.

    Args:
      type_str:

    Returns:
    """
        # Check if the activation is already callable.
        if callable(type_str):
            return type_str

        # Check if the activation is a built-in or custom function.
        if type_str == C.RELU:
            return tf.nn.relu
        elif type_str == C.ELU:
            return tf.nn.elu
        elif type_str == C.TANH:
            return tf.nn.tanh
        elif type_str == C.SIGMOID:
            return tf.nn.sigmoid
        elif type_str == C.SOFTPLUS:
            return tf.nn.softplus
        elif type_str == C.SOFTMAX:
            return tf.nn.softmax
        elif type_str == C.LRELU:
            return lambda x: tf.nn.leaky_relu(x, alpha=1. / 3.)
        elif type_str == C.CLRELU:
            with tf.compat.v1.name_scope("ClampedLeakyRelu"):

                def clamped_leaky_relu(x):
                    return tf.clip_by_value(tf.nn.leaky_relu(x, alpha=1. / 3.),
                                            -3.0, 3.0)

                return clamped_leaky_relu
        elif type_str is None:
            return None
        else:
            err_unknown_type(type_str)
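
For intuition, the clamped variant is the identity on [0, 3], has slope 1/3 below zero, and saturates at +/-3:

import tensorflow as tf

x = tf.constant([-12.0, -1.5, 0.0, 2.0, 9.0])
y = tf.clip_by_value(tf.nn.leaky_relu(x, alpha=1. / 3.), -3.0, 3.0)
# y == [-3.0, -0.5, 0.0, 2.0, 3.0]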
Example #11
  def get_model_tags(cls, config, config_loss=None):
    """Generates a string summarizing experiment parameters.

    Args:
      config:
      config_loss:

    Returns:
    """
  
    if config.predictive_model.get("name", "rnn") == "rnn":
      pred = "{}_{}x{}".format(config.encoder.cell_type,
                               config.predictive_model.cell_layers,
                               config.predictive_model.cell_units)
    elif config.predictive_model.name == "transformer":
      pred = "{}x{}-head_{}-drop_{}".format(
          config.predictive_model.layers,
          config.predictive_model.hidden_units,
          config.predictive_model.heads,
          config.predictive_model.dropout_rate)
    else:
      err_unknown_type(config.predictive_model.name)
  
    return dict(predictive=pred, model_name="PRED")
Example #12
def get_config(FLAGS, experiment_id=None):
    """Defines the default configuration."""
    experiment_id = FLAGS.experiment_id or experiment_id

    config = Configuration()
    config.experiment = ExperimentConfig(
        comment=FLAGS.comment,
        tag="",  # Set automatically.
        model_dir=None,  # Set automatically.
        eval_dir=None,  # Set automatically.
        id=experiment_id,
        max_epochs=None,
        max_steps=200000,
        log_frequency=100,
        eval_steps=500,
        checkpoint_frequency=500,
        grad_clip_norm=(FLAGS.grad_clip_norm
                        if FLAGS.grad_clip_value <= 0 else 0),
        grad_clip_value=FLAGS.grad_clip_value,
        pretrained_emb_id=FLAGS.pretrained_emb_id)
    config.experiment.learning_rate = AttrDict(
        name=FLAGS.learning_rate_type,
        initial_learning_rate=FLAGS.learning_rate,
    )
    if FLAGS.learning_rate_type == "transformer":
        config.experiment.learning_rate.d_model = FLAGS.transformer_dmodel
        config.experiment.learning_rate.warmup_steps = 4000

    config.data = DataConfig(
        data_dir=FLAGS.data_dir,
        data_name=FLAGS.data_name,
        data_tfrecord_fname=C.DATASET_MAP[FLAGS.data_name]["data_tfrecord_fname"],
        data_meta_fname=C.DATASET_MAP[FLAGS.data_name][FLAGS.metadata_type],
        pp_to_origin="position" in FLAGS.metadata_type,
        pp_relative_pos="velocity" in FLAGS.metadata_type,
        normalize=not FLAGS.skip_normalization,
        batch_size=FLAGS.batch_size,
        max_length_threshold=201,
        mask_pen=FLAGS.mask_encoder_pen,
        resampling_factor=FLAGS.resampling_factor,
        t_drop_ratio=FLAGS.t_drop_ratio,
        gt_targets=FLAGS.gt_targets,
        scale_factor=FLAGS.scale_factor,
        affine_prob=FLAGS.affine_prob,
        reverse_prob=FLAGS.reverse_prob,
        n_t_samples=FLAGS.n_t_samples,
        int_t_samples=FLAGS.int_t_samples,
        concat_t_inputs=FLAGS.concat_t_inputs,
        rdp_dataset=FLAGS.rdp_dataset,
        rdp_didi_pp=FLAGS.rdp_didi_pp,
        pos_noise_factor=FLAGS.pos_noise_factor,
    )
    config.gdrive = AttrDict(
        credential=None,  # Set automatically below.
        workbook=None,  # Set your workbook ID (see https://github.com/emreaksan/glogger)
        sheet=FLAGS.data_name,
    )

    # Embedding model.
    if FLAGS.encoder_model == "rnn":
        config.encoder = AttrDict(
            name="rnn",
            cell_units=FLAGS.encoder_rnn_units,
            cell_layers=FLAGS.encoder_rnn_layers,
            cell_type=FLAGS.encoder_cell_type,
            bidirectional_encoder=FLAGS.bidirectional_encoder,
            rec_dropout_rate=FLAGS.encoder_rdropout)
    elif FLAGS.encoder_model == "transformer":
        config.encoder = AttrDict(
            name="transformer",
            layers=FLAGS.transformer_layers,
            heads=FLAGS.transformer_heads,
            d_model=FLAGS.transformer_dmodel,
            hidden_units=FLAGS.transformer_hidden_units,
            dropout_rate=FLAGS.transformer_dropout,
            pos_encoding=(config.data.max_length_threshold
                          if FLAGS.transformer_pos_encoding else 0),
            scale=FLAGS.transformer_scale,
            autoregressive=not FLAGS.bidirectional_encoder)
    else:
        err_unknown_type(FLAGS.encoder_model)

    config.embedding = AttrDict(
        latent_units=FLAGS.latent_units,
        use_vae=FLAGS.use_vae,
    )

    if FLAGS.decoder_model == "rnn":
        config.decoder = AttrDict(
            name="rnn",
            cell_units=FLAGS.encoder_rnn_units,  # Same hyper-parameter as the encoder.
            cell_layers=FLAGS.encoder_rnn_layers,
            cell_type=FLAGS.encoder_cell_type,
            dropout_rate=FLAGS.decoder_dropout,
            dynamic_h0=FLAGS.decoder_dynamic_h0,
            repeat_vae_sample=FLAGS.repeat_vae_sample,
            autoregressive=FLAGS.decoder_autoregressive,
        )
        target_key_pen = "pen"
        target_key_stroke = "stroke"
    elif FLAGS.decoder_model == "t_emb":
        config.decoder = AttrDict(
            name="t_emb",
            layers=FLAGS.decoder_layers,
            hidden_units=FLAGS.decoder_hidden_units,
            activation=FLAGS.decoder_activation,
            dropout_rate=FLAGS.decoder_dropout,
            t_frequency_channels=FLAGS.t_frequency_channels,
            regularizer_weight=FLAGS.reg_dec_weight)
        target_key_pen = C.TARGET_T_PEN
        target_key_stroke = C.TARGET_T_STROKE
    else:
        err_unknown_type(FLAGS.decoder_model)

    # Predictive model.
    if FLAGS.predictive_model == "rnn":
        config.predictive_model = AttrDict(
            name="rnn",
            output_size=config.embedding.latent_units,
            cell_units=FLAGS.predictive_rnn_units,
            cell_layers=FLAGS.predictive_rnn_layers,
            cell_type=FLAGS.predictive_cell_type,
            activation=C.RELU,
            use_start_pos=FLAGS.use_start_pos,
            use_end_pos=FLAGS.use_end_pos,
            stop_predictive_grad=FLAGS.stop_predictive_grad,
            num_predictive_inputs=FLAGS.num_pred_inputs,
            pred_input_type=FLAGS.pred_input_type,
        )
    elif FLAGS.predictive_model == "transformer":
        config.predictive_model = AttrDict(
            name="transformer",
            output_size=config.embedding.latent_units,
            layers=FLAGS.p_transformer_layers,
            heads=FLAGS.p_transformer_heads,
            d_model=FLAGS.p_transformer_dmodel,
            latent_units=FLAGS.latent_units,
            hidden_units=FLAGS.p_transformer_hidden_units,
            dropout_rate=FLAGS.p_transformer_dropout,
            pos_encoding=FLAGS.p_transformer_pos_encoding,
            scale=FLAGS.p_transformer_scale,
            use_start_pos=FLAGS.use_start_pos,
            use_end_pos=FLAGS.use_end_pos,
            stop_predictive_grad=FLAGS.stop_predictive_grad,
            num_predictive_inputs=FLAGS.num_pred_inputs,
            pred_input_type=FLAGS.pred_input_type,
            pooling_layer=FLAGS.pooling_layer,
        )
    else:
        err_unknown_type(FLAGS.predictive_model)

    # Sharing flags with the predictive model.
    if FLAGS.position_model == "rnn":
        config.position_model = AttrDict(
            name="rnn",
            output_size=2,
            cell_units=FLAGS.predictive_rnn_units,
            cell_layers=FLAGS.predictive_rnn_layers,
            cell_type=FLAGS.predictive_cell_type,
            activation=C.RELU,
        )
    elif FLAGS.position_model == "transformer":
        config.position_model = AttrDict(
            name="transformer",
            output_size=2,
            layers=FLAGS.p_transformer_layers,
            heads=FLAGS.p_transformer_heads,
            d_model=FLAGS.p_transformer_dmodel,
            hidden_units=FLAGS.p_transformer_hidden_units,
            dropout_rate=FLAGS.p_transformer_dropout,
            pos_encoding=FLAGS.p_transformer_pos_encoding,
            scale=FLAGS.p_transformer_scale,
        )
    elif FLAGS.position_model is None:
        config.position_model = None
    else:
        err_unknown_type(FLAGS.position_model)

    # Loss
    config.loss = AttrDict()

    stroke_ = LossConfig(loss_type=FLAGS.stroke_loss,
                         num_components=20,
                         target_key=target_key_stroke,
                         out_key="stroke_logits",
                         reduce_type=C.R_MEAN_STEP)
    pen_ = LossConfig(eval_only=FLAGS.disable_pen_loss,
                      loss_type=C.NLL_CENT_BINARY,
                      target_key=target_key_pen,
                      out_key="pen_logits",
                      reduce_type=C.R_MEAN_STEP)

    ink_loss = AttrDict(pen=pen_, stroke=stroke_, prefix="reconstruction")

    if FLAGS.use_vae:
        ink_loss.embedding_kld = LossConfig(
            loss_type=FLAGS.kld_type,  # C.KLD_STANDARD or C.KLD_STANDARD_NORM
            target_key=None,
            out_key="embedding",
            reduce_type=C.R_MEAN_STEP)

        ink_loss.embedding_kld.weight = FLAGS.kld_weight
        if FLAGS.kld_increment > 0:
            ink_loss.embedding_kld.weight = dict(
                type="linear_decay",
                values=[FLAGS.kld_start, FLAGS.kld_weight, FLAGS.kld_increment])

    if FLAGS.reg_emb_weight > 0:
        ink_loss.embedding_l2 = LossConfig(loss_type=C.SNORM_L2,
                                           target_key=None,
                                           out_key="embedding",
                                           reduce_type=C.R_MEAN_STEP,
                                           weight=FLAGS.reg_emb_weight)

    embedding_pred = LossConfig(loss_type=FLAGS.embedding_loss,
                                num_components=FLAGS.embedding_gmm_components,
                                target_key="target",
                                out_key="prediction",
                                reduce_type=C.R_MEAN_STEP)

    position_pred = LossConfig(loss_type=FLAGS.embedding_loss,
                               num_components=FLAGS.embedding_gmm_components,
                               target_key="target",
                               out_key="prediction",
                               reduce_type=C.R_MEAN_STEP)

    # embedding_pred.weight = FLAGS.kld_weight
    # embedding_pred.weight = dict(
    #     type="linear_decay",
    #     values=[FLAGS.kld_start, FLAGS.kld_weight, FLAGS.kld_increment])

    config.loss = AttrDict(
        ink=ink_loss,
        predicted_embedding=AttrDict(predicted_embedding=embedding_pred),
        predicted_ink=copy.deepcopy(ink_loss))

    if config.position_model is not None:
        config.loss.predicted_pos = AttrDict(predicted_pos=position_pred)

    # No kld loss on the predicted ink.
    if "embedding_kld" in config.loss.predicted_ink:
        del config.loss.predicted_ink["embedding_kld"]

    config.loss.apply_predicted_embedding = FLAGS.loss_predicted_embedding
    config.loss.apply_predicted_ink = FLAGS.loss_predicted_ink
    config.loss.apply_reconstructed_ink = FLAGS.loss_reconstructed_ink

    try:
        data_root = os.environ["COSE_DATA_DIR"]
        log_dir = os.environ["COSE_LOG_DIR"]
        eval_dir = os.environ["COSE_EVAL_DIR"]
        gdrive_key = os.environ["GDRIVE_API_KEY"]
    except KeyError:
        if (FLAGS.data_dir is None or FLAGS.eval_dir is None
                or FLAGS.experiment_dir is None):
            raise Exception(
                "Either environment variables or FLAGS must be set.")
        data_root = FLAGS.data_dir
        log_dir = FLAGS.experiment_dir
        eval_dir = FLAGS.eval_dir
        gdrive_key = FLAGS.gdrive_api_key

    # Check if the experiment directory already exists.
    model_dir_query = glob.glob(
        os.path.join(log_dir, config.experiment.id + "*"))
    if model_dir_query:
        model_dir = model_dir_query[0]
        __builtin__.print = Print(
            os.path.join(model_dir, "log.txt"))  # Overload print.
        # Load experiment config.
        config = config.from_json(os.path.join(model_dir, "config.json"))
        config.experiment.model_dir = model_dir
        config.experiment.eval_dir = os.path.join(eval_dir,
                                                  os.path.basename(model_dir))
        if "predictive_model" not in config:
            raise NotPredictiveModelError
        print("Loading from " + config.experiment.model_dir)
    else:
        config.experiment.tag = build_experiment_name(config)
        model_dir_name = config.experiment.id + "-" + config.experiment.tag
        config.experiment.model_dir = os.path.join(log_dir, model_dir_name)
        config.experiment.eval_dir = os.path.join(eval_dir, model_dir_name)
        os.mkdir(config.experiment.model_dir)  # Create experiment directory
        __builtin__.print = Print(
            os.path.join(config.experiment.model_dir,
                         "log.txt"))  # Overload print.
        print("Saving to " + config.experiment.model_dir)

    if not isinstance(config.data.data_tfrecord_fname, list):
        config.data.data_tfrecord_fname = [config.data.data_tfrecord_fname]
    data_path = [
        os.path.join(data_root, config.data.data_name, "{}", dp)
        for dp in config.data.data_tfrecord_fname
    ]
    config.data.train_data_path = [dp.format("training") for dp in data_path]
    config.data.valid_data_path = [dp.format("validation") for dp in data_path]
    config.data.test_data_path = [dp.format("test") for dp in data_path]
    config.data.meta_data_path = os.path.join(data_root, config.data.data_name,
                                              config.data.data_meta_fname)

    config.experiment.pretrained_emb_dir = None
    if config.experiment.get("pretrained_emb_id", None) is not None:
        config.experiment.pretrained_emb_dir = glob.glob(
            os.path.join(log_dir,
                         config.experiment.pretrained_emb_id + "-*"))[0]

    config.gdrive.credential = gdrive_key
    if FLAGS.gdrive_api_key is None:
        config.gdrive = None

    config.dump(config.experiment.model_dir)
    return config
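
The linear_decay weight dict above is decoded in the loss function (see Example #18) as values = [start, end, increment], with the effective weight at a given step being end - (end - start) * increment**step. A sketch with hypothetical schedule values:

start, end, increment = 0.01, 0.5, 0.99995
for step in (0, 10000, 100000):
    weight = end - (end - start) * increment**step
    print(step, round(weight, 4))  # anneals from 0.01 toward 0.5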
Example #13
  def call(self, inputs, training=False, **kwargs):
    """Call method.

    Args:
      inputs (dict): expected to contain inputs for the encoder and decoder,
        sequence length and number of strokes ops.
      training: whether in training mode or not.
      **kwargs:

    Returns:
      [batch_size, seq_len, feature_size]
    """
    input_num_strokes = inputs[C.INP_NUM_STROKE]
    # tf.keras compile, fit, predict, etc. methods cause it to be 2-dimensional.
    if len(inputs[C.INP_NUM_STROKE].shape) == 2:
      input_num_strokes = inputs[C.INP_NUM_STROKE][:, 0]

    out_dict = dict()
    gt_reconstruction = self.embedding_model.call(inputs, training=training)
    out_dict["reconstructed_ink"] = gt_reconstruction
    embedding = gt_reconstruction["embedding"]
    out_dict["embedding_sample"] = out_dict["reconstructed_ink"]["embedding_sample"]
    
    # Before making a prediction for the next stroke, reshape
    # embeddings so that we have a sequence of stroke embeddings,
    # representing the diagram samples.
    diagram_embedding = self.batch_stroke_to_diagram(embedding,
                                                     input_num_strokes)
    
    # Draw embedding samples to serve as inputs for the ink model; they are
    # probabilistic if the stroke model is probabilistic.
    embedding_sample = self.embedding_model.net_embedding.draw_sample(
        diagram_embedding, greedy=True)
    
    # Determine ink model inputs and targets.
    n_strokes = tf.shape(input=embedding_sample)[0]
    seq_len = tf.shape(input=embedding_sample)[1]
    batch_idx = tf.range(n_strokes)
    # (1) Predict the last step only.
    if self.input_type == "last_step":
      target_idx = input_num_strokes-1

      input_range = tf.tile(tf.range(seq_len)[tf.newaxis, :], [n_strokes, 1])
      mask_ = tf.not_equal(input_range, tf.tile(target_idx[:, tf.newaxis], [1, seq_len]))
      gather_input_idx = tf.reshape(tf.compat.v1.where(mask_), [n_strokes, seq_len - 1, 2])
      pred_input_seq_len = input_num_strokes - 1

      gather_target_idx = tf.stack([
          batch_idx,
          target_idx
          ], axis=-1)
    
    # (2) Leave-one-out with random targets. Use all strokes except the
    # target as input.
    elif self.input_type == "leave_one_out":
      i = tf.random.uniform([n_strokes], minval=0, maxval=1, dtype=tf.float32)
      target_idx = tf.round(i*tf.cast(input_num_strokes - 1, tf.float32))
      target_idx = tf.cast(target_idx, tf.int32)

      input_range = tf.tile(tf.range(seq_len)[tf.newaxis, :], [n_strokes, 1])
      mask_ = tf.not_equal(input_range, tf.tile(target_idx[:, tf.newaxis], [1, seq_len]))
      gather_input_idx = tf.reshape(tf.compat.v1.where(mask_), [n_strokes, seq_len - 1, 2])
      pred_input_seq_len = input_num_strokes - 1

      gather_target_idx = tf.stack([
          batch_idx,
          target_idx
          ], axis=-1)
      
    elif self.input_type in ["random", "ordered", "hybrid"]:
      min_n_stroke = tf.reduce_min(input_tensor=input_num_strokes)
      max_n_stroke = tf.reduce_max(input_tensor=input_num_strokes)
      input_range_ = tf.tile(tf.range(max_n_stroke)[tf.newaxis, :], [n_strokes, 1])
      
      def get_random_inp_target_pairs():
        """Get a randomly generated input set and a target."""
        # Randomly set the number of inputs.
        n_inputs_ = tf.random.uniform([1], minval=2, maxval=min_n_stroke, dtype=tf.int32)[0]
        # Randomly pick a target.
        i_ = tf.random.uniform([n_strokes], minval=0, maxval=1, dtype=tf.float32)
        target_idx_ = tf.round(i_*tf.cast(input_num_strokes - 1, tf.float32))
        target_idx_ = tf.cast(target_idx_, tf.int32)

        mask_ = tf.not_equal(input_range_, tf.tile(target_idx_[:, tf.newaxis], [1, max_n_stroke]))
        all_input_idx_ = tf.reshape(tf.compat.v1.where(mask_), [n_strokes, max_n_stroke - 1, 2])
        all_input_idx_ = tf.transpose(a=all_input_idx_[:, :min_n_stroke - 1], perm=[1, 0, 2])
        all_input_idx_ = tf.random.shuffle(all_input_idx_)
        gather_input_idx_ = tf.cast(tf.transpose(a=all_input_idx_[:n_inputs_], perm=[1,0,2]), tf.int32)
        return gather_input_idx_, target_idx_, n_inputs_
      
      def get_ordered_inp_target_pairs(random_target=False):
        """Get a slice (i.e., window) randomly."""
        # Randomly set the number of inputs.
        n_inputs_ = tf.random.uniform([1], minval=2, maxval=min_n_stroke, dtype=tf.int32)[0]
        # Select start index of the window.
        start_idx = tf.random.uniform([1], minval=0, maxval=min_n_stroke - n_inputs_, dtype=tf.int32)[0]
        
        if not random_target:
          # Target is the next stroke after the window.
          target_idx_ = tf.tile(tf.expand_dims(start_idx+n_inputs_, axis=0), [n_strokes])
        else:
          # Randomly pick a target.
          i_ = tf.random.uniform([n_strokes], minval=0, maxval=1, dtype=tf.float32)
          target_idx_ = tf.round(i_*tf.cast(input_num_strokes - 1, tf.float32))
          target_idx_ = tf.cast(target_idx_, tf.int32)

        mask_ = tf.not_equal(input_range_, tf.tile(target_idx_[:, tf.newaxis], [1, max_n_stroke]))
        all_input_idx_ = tf.reshape(tf.compat.v1.where(mask_), [n_strokes, max_n_stroke - 1, 2])[:, :min_n_stroke]
        gather_input_idx_ = tf.cast(all_input_idx_[:, start_idx:start_idx+n_inputs_], tf.int32)
        return gather_input_idx_, target_idx_, n_inputs_
      
      all_n_inputs = []
      all_gather_input_idx = []
      all_target_idx = []
      all_seq_lens = []
      
      if self.input_type in ["random", "hybrid"]:
        for i in range(self.num_predictive_inputs):
          gather_input_idx, target_idx, n_inputs = get_random_inp_target_pairs()
          all_gather_input_idx.append(gather_input_idx)
          all_target_idx.append(target_idx)
          all_n_inputs.append(n_inputs)
          all_seq_lens.append(tf.ones([n_strokes], dtype=tf.int32)*n_inputs)

      if self.input_type in ["ordered", "hybrid"]:
        for i in range(self.num_predictive_inputs):
          gather_input_idx, target_idx, n_inputs = get_ordered_inp_target_pairs()
          all_gather_input_idx.append(gather_input_idx)
          all_target_idx.append(target_idx)
          all_n_inputs.append(n_inputs)
          all_seq_lens.append(tf.ones([n_strokes], dtype=tf.int32)*n_inputs)
        
      max_len = tf.reduce_max(input_tensor=all_n_inputs)
      for i in range(len(all_n_inputs)):
        all_gather_input_idx[i] = tf.pad(tensor=all_gather_input_idx[i], paddings=[[0, 0], [0, max_len - all_n_inputs[i]], [0, 0]])
        
      gather_input_idx = tf.concat(all_gather_input_idx, axis=0)
      pred_input_seq_len = tf.concat(all_seq_lens, axis=0)

      gather_target_idx = tf.stack([
          tf.tile(batch_idx, [len(all_n_inputs)]),
          tf.concat(all_target_idx, axis=0)
          ], axis=-1)

    else:
      err_unknown_type(self.input_type)
    
    # if "sigma" in diagram_embedding:
    #   pred_targets = dict()
    #   pred_targets["mu"] = tf.stop_gradient(tf.gather_nd(diagram_embedding["mu"], gather_target_idx), name="embedding_mu_target_stop")
    #   pred_targets["sigma"] = tf.stop_gradient(tf.gather_nd(diagram_embedding["sigma"], gather_target_idx), name="embedding_sigma_target_stop")
    # else:
    pred_targets = tf.gather_nd(embedding_sample, gather_target_idx)
    pred_targets = tf.stop_gradient(pred_targets, name="embedding_target_stop")
    
    if self.stop_predictive_grad:
      # Disable gradient flow from the predictive/position models to the
      # embedding model.
      pred_input = tf.stop_gradient(tf.gather_nd(embedding_sample, gather_input_idx), name="embeddings_pred_input_stop")
    else:
      pred_input = tf.gather_nd(embedding_sample, gather_input_idx)

    start_pos = tf.reshape(inputs["start_coord"], [n_strokes, seq_len, 2])
    start_context_pos = tf.gather_nd(start_pos, gather_input_idx)
    if self.end_positions:
      end_pos = tf.reshape(inputs["end_coord"], [n_strokes, seq_len, 2])
      end_context_pos = tf.gather_nd(end_pos, gather_input_idx)
      context_pos = tf.concat([start_context_pos, end_context_pos], axis=-1)
    else:
      context_pos = start_context_pos
      
    target_pos = tf.expand_dims(tf.gather_nd(start_pos, gather_target_idx), axis=1)
    predicted_embedding = self.predictive_model(inputs=dict(input_seq=pred_input,
                                                            input_cond=context_pos,
                                                            target_cond=target_pos,
                                                            seq_len=pred_input_seq_len),
                                                training=True)

    pred_emb_sample = self.predictive_model.output_layer.draw_sample(predicted_embedding, greedy=True)
    embedding_out_dict = dict(
        prediction=predicted_embedding,
        target=pred_targets,
        input_seq_len=pred_input_seq_len,
        embedding_sample=pred_emb_sample)
    out_dict["embedding"] = embedding_out_dict

    if self.position_model is not None:
      predicted_pos = self.position_model(inputs=dict(input_seq=pred_input,
                                                      input_cond=context_pos,
                                                      target_cond=None,
                                                      seq_len=pred_input_seq_len),
                                          training=True)
      pred_pos_sample = self.position_model.output_layer.draw_sample(predicted_pos, greedy=True)
      pos_out_dict = dict(
          prediction=predicted_pos,
          target=target_pos[:, 0],
          input_seq_len=pred_input_seq_len,
          position_sample=pred_pos_sample)
      out_dict["position"] = pos_out_dict
    
    # Decoding the predicted embedding is disabled here: with a sequence
    # decoder it is too slow.
    # if False:
    #   # Select decoder t inputs corresponding to the predicted stroke.
    #   t_batch_diagram = tf.reshape(decoder_inputs, [n_strokes, seq_len, -1])
    #   t_target = tf.gather_nd(t_batch_diagram, gather_target_idx)
    #   predicted_ink = self.embedding_model.call_decode(embedding=pred_emb_sample,
    #                                                    decoder_inputs=t_target,
    #                                                    training=training)
    #   predicted_ink["embedding"] = pred_emb_sample
    #
    #   # TODO Need to pass those to select targets.
    #   predicted_ink["shape_0"] = n_strokes
    #   predicted_ink["shape_1"] = seq_len
    #   predicted_ink["gather_target_idx"] = gather_target_idx
    #   out_dict["predicted_ink"] = predicted_ink
      
    return out_dict
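
The masking-and-gather pattern used for the input/target splits above is easiest to see on a toy batch; a sketch assuming 2 diagrams with 4 strokes each:

import tensorflow as tf

n_strokes, seq_len = 2, 4
target_idx = tf.constant([1, 3])  # one target stroke per diagram
input_range = tf.tile(tf.range(seq_len)[tf.newaxis, :], [n_strokes, 1])
mask_ = tf.not_equal(input_range, target_idx[:, tf.newaxis])
gather_input_idx = tf.reshape(tf.where(mask_), [n_strokes, seq_len - 1, 2])
values = tf.reshape(tf.range(n_strokes * seq_len), [n_strokes, seq_len])
print(tf.gather_nd(values, gather_input_idx))  # [[0, 2, 3], [4, 5, 6]]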
Example #14
  def predict_embedding(self, embeddings, target_idx, seq_lens, input_idx=None, input_type="leave_one_out", start_positions=None):
    if isinstance(embeddings, dict):
      embeddings = self.embedding_model.net_embedding.draw_sample(embeddings)
      
    # Determine ink model inputs and targets.
    n_strokes = tf.shape(input=embeddings)[0]
    # seq_len = tf.shape(embeddings)[1]
    seq_len = seq_lens[0]
    batch_idx = tf.range(n_strokes)
    
    if input_idx is None:
      if input_type == "leave_one_out":
        input_range = tf.tile(tf.range(seq_len)[tf.newaxis, :], [n_strokes, 1])
        mask_ = tf.not_equal(input_range,
                             tf.tile(target_idx[:, tf.newaxis], [1, seq_len]))
        gather_input_idx = tf.reshape(tf.compat.v1.where(mask_), [n_strokes, seq_len - 1, 2])
        pred_input_seq_len = seq_lens - 1
      elif input_type == "last_step" or input_type == "ordered":
        input_range = tf.tile(tf.range(seq_len)[tf.newaxis, :], [n_strokes, 1])
        mask_ = tf.less(input_range,
                        tf.tile(target_idx[:, tf.newaxis], [1, seq_len]))
        gather_input_idx = tf.reshape(tf.compat.v1.where(mask_),
                                      [n_strokes, tf.reduce_max(input_tensor=target_idx), 2])
        pred_input_seq_len = target_idx
      else:
        err_unknown_type(input_type)
    else:
      gather_input_idx = tf.stack([
          tf.zeros_like(input_idx),
          input_idx
          ], axis=-1)
      pred_input_seq_len = tf.Variable([tf.shape(input=input_idx)[1]])

    gather_target_idx = tf.stack([
        batch_idx,
        target_idx
        ], axis=-1)
  
    pred_targets = tf.gather_nd(embeddings, gather_target_idx)
    pred_input = tf.gather_nd(embeddings, gather_input_idx)

    start_pos = tf.reshape(start_positions, [n_strokes, seq_len, 2])
    start_context_pos = tf.gather_nd(start_pos, gather_input_idx)
    if self.end_positions:
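      # Note: only start_positions is available here, so it is reused for
      # the end-coordinate context as well.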
      end_pos = tf.reshape(start_positions, [n_strokes, seq_len, 2])
      end_context_pos = tf.gather_nd(end_pos, gather_input_idx)
      context_pos = tf.concat([start_context_pos, end_context_pos], axis=-1)
    else:
      context_pos = start_context_pos
    
    target_pos = tf.expand_dims(tf.gather_nd(start_pos, gather_target_idx), axis=1)
    out_ = self.predictive_model(inputs=dict(input_seq=pred_input,
                                             input_cond=context_pos,
                                             target_cond=target_pos,
                                             seq_len=pred_input_seq_len),
                                 training=False)
    
    out_["gather_target_idx"] = gather_target_idx
    out_["gather_input_idx"] = gather_input_idx
    out_["embedding_sample"] = self.predictive_model.output_layer.draw_sample(out_, greedy=True)
    return out_, pred_targets
Example #15
    def __init__(self,
                 config_encoder,
                 config_embedding,
                 config_decoder,
                 config_loss,
                 run_mode=C.RUN_ESTIMATOR,
                 **kwargs):
        """Constructor.

    Args:
      config_encoder:
      config_embedding:
      config_decoder:
      config_loss:
      run_mode: eager, static or estimator.
      **kwargs:

    Raises:
      ValueError: if run_mode is eager and tf.executing_eagerly() is False.
      Exception: if # layers > 1 and dynamic_h0 is True.
    """
        super(TEmbedding, self).__init__(config_loss=config_loss,
                                         run_mode=run_mode,
                                         **kwargs)

        self.pen_threshold = 0.3
        self.config_encoder = config_encoder
        self.config_embedding = config_embedding
        self.config_decoder = config_decoder
        self.latent_prefix = ""

        self.regularize_decoder = self.config_decoder.get(
            "regularizer_weight", 0) > 0
        self.kernel_regularizer = None
        if self.regularize_decoder:
            self.kernel_regularizer = tf.keras.regularizers.l2(
                self.config_decoder.get("regularizer_weight", 0))

        self.n_latent_units = self.config_embedding["latent_units"]
        self.use_vae = self.config_embedding["use_vae"]
        self.decoder_drop_rate = self.config_decoder.get("dropout_rate", 0)
        self.t_frequency_channels = self.config_decoder.get(
            "t_frequency_channels", 0)

        # Encoder network
        self.net_encoder = None
        if self.config_encoder["name"] == "rnn":
            self.net_encoder = RNN(
                self.config_encoder["cell_type"],
                self.config_encoder["cell_units"],
                self.config_encoder["cell_layers"],
                self.config_encoder["bidirectional_encoder"],
                return_sequences=False,
                return_state=False,
                run_mode=run_mode,
                use_cudnn=self.config_encoder["use_cudnn"],
                name="encoder_rnn")
        elif self.config_encoder["name"] == "mlp":
            pass
        elif self.config_encoder["name"] == "cnn":
            pass
        elif self.config_encoder["name"] == "transformer":
            self.net_encoder = Transformer(
                num_layers=self.config_encoder["layers"],
                d_model=self.config_encoder["d_model"],
                num_heads=self.config_encoder["heads"],
                dff=self.config_encoder["hidden_units"],
                rate=self.config_encoder["dropout_rate"],
                scale=self.config_encoder["scale"],
                pos_encoding_len=self.config_encoder["pos_encoding"],
                autoregressive=self.config_encoder["autoregressive"],
                return_sequences=False,
                config_loss=None,
                run_mode=run_mode)
        else:
            err_unknown_type(self.config_encoder["name"])

        # Deterministic or stochastic stroke.
        if self.use_vae:
            self.net_embedding = OutputModelNormal(
                out_units=self.n_latent_units,
                prefix=self.latent_prefix,
                sigma_activation=None,
                logvar=True)
        else:
            self.net_embedding = OutputModelDeterministic(
                out_units=self.n_latent_units, prefix=self.latent_prefix)

        # Decoder network:
        self.net_decoder = tf.keras.Sequential(name="decoder")

        layer_units = self.config_decoder["hidden_units"]
        if len(layer_units) == 1:
            layer_units = layer_units * self.config_decoder["layers"]

        decoder_activation = Activations.get(self.config_decoder["activation"])

        for idx in range(self.config_decoder["layers"]):
            self.net_decoder.add(
                tf.keras.layers.Dense(
                    layer_units[idx],
                    activation=decoder_activation,
                    kernel_regularizer=self.kernel_regularizer,
                    bias_regularizer=self.kernel_regularizer))
            if self.decoder_drop_rate > 0:
                self.net_decoder.add(
                    tf.keras.layers.Dropout(self.decoder_drop_rate))

        # Pen and stroke outputs.
        if config_loss["pen"]["eval_only"]:
            self.decoder_out_pen = None
        else:
            self.decoder_out_pen = tf.keras.layers.Dense(
                1,
                name="out_pen",
                kernel_regularizer=self.kernel_regularizer,
                bias_regularizer=self.kernel_regularizer)

        # Build output model depending on the loss type.
        if self.config_loss["stroke"]["loss_type"] == C.NLL_NORMAL:
            self.decoder_out_stroke = OutputModelNormal(out_units=2,
                                                        hidden_units=0,
                                                        hidden_layers=0)
        elif self.config_loss["stroke"]["loss_type"] == C.NLL_BINORMAL:
            self.decoder_out_stroke = OutputModelNormal2DDense(
                sigma_activation=tf.keras.activations.exponential)
        elif self.config_loss["stroke"]["loss_type"] == C.NLL_GMM:
            self.decoder_out_stroke = OutputModelGMMDense(
                out_units=2,
                num_components=self.config_loss["stroke"]["num_components"],
                sigma_activation=tf.keras.activations.exponential)
        else:
            self.decoder_out_stroke = OutputModelDeterministic(
                out_units=2,
                hidden_units=0,
                hidden_layers=0,
                kernel_regularizer=self.kernel_regularizer,
                bias_regularizer=self.kernel_regularizer)

        # Variables for static mode. They are assigned in call method.
        # TODO We can get rid of them if autoregressive sampling is no
        #  longer required in static (graph) mode.
        self.op_encoder_inputs = None
        self.op_input_seq_len = None
        self.op_embedding = None
        self.op_decoder_inputs = None
        self.op_embedding_sample = None
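
A one-line illustration of the layer-width replication above, with assumed config values:

hidden_units, layers = [512], 4
layer_units = hidden_units * layers if len(hidden_units) == 1 else hidden_units
# -> [512, 512, 512, 512]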
Example #16
def build_predictive_model(config_, run_mode):
    """Builds model object."""

    # Embedding model.
    if config_.decoder.get("name", "t_emb") == "t_emb":
        embedding_model = TEmbedding(config_encoder=config_.encoder,
                                     config_embedding=config_.embedding,
                                     config_decoder=config_.decoder,
                                     config_loss=config_.loss.ink,
                                     run_mode=run_mode)
    elif config_.decoder.get("name", "t_emb") == "rnn":
        embedding_model = InkSeq2Seq(config_encoder=config_.encoder,
                                     config_embedding=config_.embedding,
                                     config_decoder=config_.decoder,
                                     config_loss=config_.loss.ink,
                                     run_mode=run_mode)
    else:
        err_unknown_type(config_.decoder.name)

    if config_.predictive_model.get("name", "rnn") == "rnn":

        predictive_model = RNNConditional(
            output_size=config_.predictive_model.output_size,
            cell_units=config_.predictive_model.cell_units,
            cell_layers=config_.predictive_model.cell_layers,
            cell_type=config_.predictive_model.cell_type,
            return_sequences=False,
            return_state=False,
            run_mode=run_mode,
            config_loss=config_.loss.predicted_embedding.predicted_embedding)

    elif config_.predictive_model.name == "transformer":
        predictive_model = TransformerSeq2seqConditional(
            output_size=config_.predictive_model.latent_units,
            num_layers=config_.predictive_model.layers,
            d_model=config_.predictive_model.d_model,
            num_heads=config_.predictive_model.heads,
            dff=config_.predictive_model.hidden_units,
            rate=config_.predictive_model.dropout_rate,
            config_loss=config_.loss.predicted_embedding.predicted_embedding,
            pos_encoding_len=(100 if config_.predictive_model.pos_encoding
                              else 0),
            scale=config_.predictive_model.scale,
            run_mode=run_mode,
            autoregressive=False,
            pooling=config_.predictive_model.pooling_layer,
        )
    else:
        err_unknown_type(config_.predictive_model.name)

    position_model = None
    if (config_.get("position_model", None) is not None
            and config_.position_model.name == "rnn"):
        position_model = RNNConditional(
            output_size=2,
            cell_units=config_.position_model.cell_units,
            cell_layers=config_.position_model.cell_layers,
            cell_type=config_.position_model.cell_type,
            return_sequences=False,
            return_state=False,
            run_mode=run_mode,
            config_loss=config_.loss.predicted_pos.predicted_pos)

    elif config_.get("position_model", None) is not None:
        position_model = TransformerSeq2seqConditional(
            output_size=2,
            num_layers=config_.position_model.layers,
            d_model=config_.position_model.d_model,
            num_heads=config_.position_model.heads,
            dff=config_.position_model.hidden_units,
            rate=config_.position_model.dropout_rate,
            config_loss=config_.loss.predicted_pos.predicted_pos,
            pos_encoding_len=0,
            scale=config_.position_model.scale,
            run_mode=run_mode,
            autoregressive=False)

    model_ = PredictiveInkModel(
        embedding_model=embedding_model,
        predictive_model=predictive_model,
        position_model=position_model,
        loss_predicted_embedding=config_.loss.apply_predicted_embedding,
        loss_predicted_ink=config_.loss.apply_predicted_ink,
        loss_reconstructed_ink=config_.loss.apply_reconstructed_ink,
        input_type=config_.predictive_model.pred_input_type,
        start_positions=config_.predictive_model.use_start_pos,
        end_positions=config_.predictive_model.get("use_end_pos", False),
        stop_predictive_grad=config_.predictive_model.get(
            "stop_predictive_grad", False),
        num_predictive_inputs=config_.predictive_model.get(
            "num_predictive_inputs", 8),
        config_loss=copy.deepcopy(config_.loss),
        run_mode=run_mode)
    return model_
Example #17
def build_dataset(config_, run_mode=C.RUN_STATIC, split=C.DATA_TRAIN):
    """Builds dataset object."""

    dataset_ = None
    if split == C.DATA_TRAIN:
        dataset_ = TFRecordBatchDiagram(
            data_path=config_.data.train_data_path,
            meta_data_path=config_.data.meta_data_path,
            batch_size=config_.data.batch_size,
            pp_to_origin=config_.data.pp_to_origin,
            pp_relative_pos=config_.data.pp_relative_pos,
            normalize=config_.data.normalize,
            shuffle=True,
            run_mode=run_mode,
            max_length_threshold=config_.data.max_length_threshold,
            mask_pen=config_.data.mask_pen,
            resampling_factor=config_.data.get("resampling_factor", 0),
            t_drop_ratio=config_.data.get("t_drop_ratio", 0),
            scale_factor=config_.data.get("scale_factor", 0),
            pos_noise_factor=config_.data.get("pos_noise_factor", 0),
            affine_prob=config_.data.get("affine_prob", 0),
            reverse_prob=config_.data.get("reverse_prob", 0),
            gt_targets=config_.data.get("gt_targets", False),
            n_t_targets=config_.data.get("n_t_samples", 1),
            int_t_samples=config_.data.get("int_t_samples", False),
            concat_t_inputs=config_.data.get("concat_t_inputs", False),
            rdp=config_.data.get("rdp_dataset", False),
            rdp_didi_pp=config_.data.get("rdp_didi_pp", False),
        )
    elif split == C.DATA_VALID:
        dataset_ = TFRecordBatchDiagram(
            data_path=config_.data.valid_data_path,
            meta_data_path=config_.data.meta_data_path,
            batch_size=config_.data.batch_size,
            pp_to_origin=config_.data.pp_to_origin,
            pp_relative_pos=config_.data.pp_relative_pos,
            normalize=config_.data.normalize,
            shuffle=False,
            run_mode=run_mode,
            max_length_threshold=config_.data.max_length_threshold,
            mask_pen=config_.data.mask_pen,
            concat_t_inputs=config_.data.get("concat_t_inputs", False),
            rdp=config_.data.get("rdp_dataset", False),
            rdp_didi_pp=config_.data.get("rdp_didi_pp", False),
        )
    elif split == C.DATA_TEST:
        dataset_ = TFRecordSingleDiagram(
            data_path=config_.data.test_data_path,
            meta_data_path=config_.data.meta_data_path,
            batch_size=config_.data.batch_size,
            pp_to_origin=config_.data.pp_to_origin,
            pp_relative_pos=config_.data.pp_relative_pos,
            normalize=config_.data.normalize,
            shuffle=False,
            max_length_threshold=config_.data.max_length_threshold,
            run_mode=run_mode,
            mask_pen=config_.data.mask_pen,
            concat_t_inputs=config_.data.get("concat_t_inputs", False),
            rdp=config_.data.get("rdp_dataset", False),
            rdp_didi_pp=config_.data.get("rdp_didi_pp", False),
        )
    else:
        err_unknown_type(split)

    return dataset_
Example #18
  def loss_fn(self,
              loss_config,
              predictions,
              targets,
              seq_len,
              prefix="",
              run_mode=C.RUN_STATIC,
              training=True):

    fixed_len_seq = loss_config.get("fixed_len_seq", 0)

    if prefix and not prefix.endswith("_"):
      prefix = prefix + "_"

    loss_dict = dict()
    loss_metric_dict = dict()
    total_loss_ops = list()

    for loss_key, loss_term in loss_config.items():
      if not isinstance(loss_term, dict):
        continue
      loss_sequence = None
      loss_type = loss_term["loss_type"]
      loss_reduce = loss_term["reduce"]
      loss_weight = loss_term["weight"]

      if not isinstance(loss_weight, float):
        start, end, increment = loss_term["weight"]["values"]
        loss_weight = end - (end - start) * increment**tf.cast(self.step, tf.float32)
        loss_dict["kldw"] = loss_weight

      if not training:
        loss_weight = 1.0

      loss_objective = not loss_term.get("eval_only", False)
      loss_targets = targets.get(loss_term["target_key"], None)
      loss_predictions = predictions.get(loss_term["out_key"], None)

      seq_mask = None
      if seq_len is not None:
        if fixed_len_seq > 0:
          seq_mask = tf.expand_dims(
              tf.sequence_mask(seq_len, maxlen=fixed_len_seq, dtype=tf.float32),
              -1)
        else:
          seq_mask = tf.expand_dims(
              tf.sequence_mask(seq_len, dtype=tf.float32), -1)
      # Calculate loss with shape (batch_size, seq_len, 1)
      if loss_type == C.MSE:
        if isinstance(loss_targets, dict):
          loss_targets = loss_targets[C.MU]
        loss_sequence = tf.reduce_sum(input_tensor=tf.math.square(loss_targets - loss_predictions[C.MU]), axis=-1, keepdims=True)
      elif loss_type == C.L1:
        if isinstance(loss_targets, dict):
          loss_targets = loss_targets[C.MU]
        loss_sequence = tf.reduce_sum(input_tensor=tf.math.abs(loss_targets - loss_predictions[C.MU]), axis=-1, keepdims=True)
      elif loss_type == C.NLL_NORMAL:
        loss_sequence = -1*loss.logli_normal_diagonal(
            loss_targets, loss_predictions[C.MU], loss_predictions[C.SIGMA])
      elif loss_type == C.NLL_BINORMAL:
        loss_sequence = -1*loss.logli_normal_bivariate(
            loss_targets, loss_predictions[C.MU], loss_predictions[C.SIGMA],
            loss_predictions[C.RHO])
      elif loss_type == C.NLL_GMM:
        loss_sequence = -1*loss.logli_gmm_logsumexp(
            loss_targets, loss_predictions[C.MU], loss_predictions[C.SIGMA],
            loss_predictions[C.PI])
      elif loss_type == C.NLL_CENT_BINARY:
        loss_sequence = tf.nn.sigmoid_cross_entropy_with_logits(
            loss_targets,
            loss_predictions)
      elif loss_type == C.KLD_STANDARD:
        loss_sequence = loss.kld_normal_diagonal_standard_prior(
            loss_predictions[C.MU], loss_predictions[C.SIGMA])
      elif loss_type == C.KLD_STANDARD_NORM:
        loss_sequence = loss.kld_normal_diagonal_standard_prior_normalized(
          loss_predictions[C.MU], loss_predictions[C.SIGMA])
      elif loss_type == C.KLD:
        loss_sequence = loss.kld_normal_diagonal(
            loss_predictions[C.MU],
            loss_predictions[C.SIGMA],
            loss_targets[C.MU],
            loss_targets[C.SIGMA],
            reduce_sum=False)
      elif loss_type == C.SNORM_L2:
        loss_sequence = tf.reduce_sum(input_tensor=tf.math.square(loss_predictions[C.MU]), axis=-1)
      else:
        err_unknown_type(loss_type)

      if len(loss_sequence.shape) == 3:
        # Mask the padded steps and calculate a scalar value.
        loss_sequence *= seq_mask
        loss_op = None
        if loss_reduce == C.R_MEAN_STEP:
          loss_op = reduce_mean_step(loss_sequence, seq_mask)
          # loss_op = tf.reduce_mean(loss_sequence)
        elif loss_reduce == C.R_MEAN_SEQUENCE:
          loss_op = reduce_mean_sequence(loss_sequence)
        else:
          err_unknown_type(loss_reduce)
      
      elif len(loss_sequence.shape) == 2:
        if seq_len is not None:
          nonzero_seq_len = tf.cast(
              tf.expand_dims(
                  tf.compat.v1.where(seq_len > 0, tf.ones_like(seq_len),
                                     tf.zeros_like(seq_len)), axis=1),
              tf.float32)
          loss_op = loss_sequence * nonzero_seq_len
          loss_op = (tf.reduce_sum(input_tensor=loss_op) /
                     tf.reduce_sum(input_tensor=nonzero_seq_len))
        else:
          loss_op = tf.reduce_mean(input_tensor=loss_sequence)
      else:
        loss_op = tf.reduce_mean(input_tensor=loss_sequence)

      if loss_type in [C.KLD_STANDARD, C.KLD_STANDARD_NORM]:
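        # Free-bits style floor: keeps the KLD penalty from collapsing to zero.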
        loss_op = tf.maximum(loss_op, 0.2)

      # The loss term is weighted only for the optimization. In order to enable
      # comparison in Tensorboard plots, we keep the loss unweighted.
      if len(loss_config) > 1:
        loss_dict[prefix + loss_key] = loss_op
      if loss_objective:
        total_loss_ops.append(loss_weight * loss_op)

      # tf.estimator requires the loss in tf.metrics.
      if run_mode == C.RUN_ESTIMATOR and len(loss_config) > 1:
        loss_metric_dict[prefix + loss_key] = tf.keras.metrics.Mean(loss_op)

    # Sum all loss terms to get the objective.
    loss_dict["loss"] = tf.math.add_n(
        total_loss_ops, name=prefix + "total_loss")
    loss_dict[prefix + "loss"] = loss_dict["loss"]

    if run_mode == C.RUN_ESTIMATOR:
      loss_metric_dict["loss"] = tf.keras.metrics.Mean(loss_dict["loss"])
      loss_metric_dict[prefix + "loss"] = loss_metric_dict["loss"]
      return loss_dict, loss_metric_dict
    else:
      return loss_dict
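
A sketch of the masked per-step mean implied by C.R_MEAN_STEP (reduce_mean_step itself is defined elsewhere; this is an assumed equivalent): padded steps are zeroed out by the mask and excluded from the average.

import tensorflow as tf

loss_sequence = tf.ones([2, 5, 1])  # hypothetical per-step losses
seq_mask = tf.expand_dims(
    tf.sequence_mask([3, 5], maxlen=5, dtype=tf.float32), -1)
mean_step = tf.reduce_sum(loss_sequence * seq_mask) / tf.reduce_sum(seq_mask)
# -> 1.0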