def build_experiment_name(config):
  template = "{tag}{model_name}_{encoder}-{latent}-{decoder}-{output}-{experiment}{data}"

  if config.decoder.name == "t_emb":
    model_tags = TEmbedding.get_model_tags(config, config.loss)
  elif config.decoder.name == "rnn":
    model_tags = InkSeq2Seq.get_model_tags(config, config.loss)
  else:
    err_unknown_type(config.decoder.name)

  # data = config.data.data_name
  data = ""

  lr = ""
  if config.experiment.learning_rate.name == "transformer":
    lr = "_tr"
  elif config.experiment.learning_rate.name == "exponential":
    lr = "_exp"

  experiment = "B{}_LR{}".format(config.data.batch_size, lr)

  return template.format(
      tag=config.experiment.tag + "_" if config.experiment.tag else "",
      model_name=model_tags["model_name"],
      encoder=model_tags["encoder"],
      latent=model_tags["latent"],
      decoder=model_tags["decoder"],
      output=model_tags["output"],
      experiment=experiment,
      data=data,
  )
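# Illustrative only (all values below are hypothetical, not from the source): the
# template above expands the model tags and training details into a single
# directory-name-friendly string.
def _example_experiment_name():
  template = "{tag}{model_name}_{encoder}-{latent}-{decoder}-{output}-{experiment}{data}"
  return template.format(
      tag="", model_name="TEMB", encoder="biTR_64_6x256-head_4-drop_0.1",
      latent="L8_vae_w0.5", decoder="4x512", output="gmm",
      experiment="B128_LR_tr", data="")
  # -> "TEMB_biTR_64_6x256-head_4-drop_0.1-L8_vae_w0.5-4x512-gmm-B128_LR_tr"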
def get_rnn_layer(cls,
                  type_str,
                  units,
                  return_sequences,
                  return_state,
                  stateful,
                  name,
                  recurrent_dropout=0.0):
  """Generates an RNN layer.

  Args:
    type_str: RNN layer type (C.LSTM or C.GRU).
    units: number of hidden units.
    return_sequences: whether to return the full output sequence.
    return_state: whether to return the final state as well.
    stateful: whether the layer keeps its state across batches.
    name: layer name.
    recurrent_dropout: recurrent dropout rate.

  Returns:
    A tf.keras.layers.LSTM or tf.keras.layers.GRU layer.
  """
  cell_cls = None
  if type_str == C.LSTM:
    cell_cls = tf.keras.layers.LSTM
  elif type_str == C.GRU:
    cell_cls = tf.keras.layers.GRU
  else:
    err_unknown_type(type_str)

  return cell_cls(units=units,
                  return_sequences=return_sequences,
                  return_state=return_state,
                  stateful=stateful,
                  recurrent_dropout=recurrent_dropout,
                  name=name)
def get(cls, config):
  lr_type = config["name"]
  if lr_type == "exponential":
    return ExponentialDecay(**config)
  elif lr_type == "sketch_rnn":
    return SketchRnnDecay(**config)
  elif lr_type == "transformer":
    return TransformerDecay(**config)
  else:
    err_unknown_type(lr_type)
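# Hypothetical usage sketch (not from the source): the factory above dispatches on
# config["name"] and forwards the whole dict as constructor keyword arguments. The same
# pattern with TensorFlow's built-in exponential schedule, which also accepts a "name"
# argument, looks like this.
import tensorflow as tf

def _example_learning_rate():
  schedule_config = {"name": "exponential", "initial_learning_rate": 1e-3,
                     "decay_steps": 1000, "decay_rate": 0.96}
  schedule = tf.keras.optimizers.schedules.ExponentialDecay(**schedule_config)
  return tf.keras.optimizers.Adam(learning_rate=schedule)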
def get_model_tags(cls, config, config_loss):
  """Generates a string summarizing experiment parameters.

  Args:
    config:
    config_loss:

  Returns:
  """
  if config_loss["stroke"]["loss_type"] == C.NLL_NORMAL:
    output = "normal"
  elif config_loss["stroke"]["loss_type"] == C.NLL_BINORMAL:
    output = "binormal"
  elif config_loss["stroke"]["loss_type"] == C.NLL_GMM:
    output = "gmm"
  else:
    output = config_loss["stroke"]["loss_type"]

  latent = "L{}".format(config.embedding.latent_units)
  if config.embedding.use_vae:
    latent += "_vae"
    if isinstance(config_loss.embedding_kld.weight, float):
      latent += "_w" + str(config_loss.embedding_kld.weight)
    else:
      latent += "_aw" + str(config_loss.embedding_kld.weight["values"][1])

  if config.encoder.name == "rnn":
    encoder = "{}_{}x{}".format(config.encoder.cell_type,
                                config.encoder.cell_layers,
                                config.encoder.cell_units)
    if config.encoder.bidirectional_encoder:
      encoder = "bi" + encoder
    if config.encoder.rec_dropout_rate > 0:
      encoder += "_rdrop{}".format(config.encoder.rec_dropout_rate)
  else:
    err_unknown_type(config.encoder["name"])

  decoder = ""
  if config.decoder.repeat_vae_sample:
    decoder += "rep_"
  if config.decoder.dropout_rate > 0:
    decoder += "ddrop_" + str(config.decoder.dropout_rate)
  if config.decoder.dynamic_h0:
    decoder += "dh0_"

  model_name = "Seq2Seq"
  if config.decoder.autoregressive:
    model_name += "_ar"

  return dict(encoder=encoder,
              latent=latent,
              decoder=decoder,
              output=output,
              model_name=model_name)
def get_model_tags(cls, config, config_loss):
  """Generates a string summarizing experiment parameters.

  Args:
    config:
    config_loss:

  Returns:
  """
  if config_loss["stroke"]["loss_type"] == C.NLL_NORMAL:
    output = "normal"
  elif config_loss["stroke"]["loss_type"] == C.NLL_BINORMAL:
    output = "binormal"
  elif config_loss["stroke"]["loss_type"] == C.NLL_GMM:
    output = "gmm"
  else:
    output = config_loss["stroke"]["loss_type"]

  decoder = "{}x{}".format(config.decoder.layers,
                           config.decoder.hidden_units[0])

  latent = "L{}".format(config.embedding.latent_units)
  if config.embedding.use_vae:
    latent += "_vae"
    if isinstance(config_loss.embedding_kld.weight, float):
      latent += "_w" + str(config_loss.embedding_kld.weight)
    else:
      latent += "_aw" + str(config_loss.embedding_kld.weight["values"][1])

  if config.encoder.name == "rnn":
    encoder = "{}_{}x{}".format(config.encoder.cell_type,
                                config.encoder.cell_layers,
                                config.encoder.cell_units)
    if config.encoder.bidirectional_encoder:
      encoder = "bi" + encoder
  elif config.encoder.name == "transformer":
    encoder = "TR_{}_{}x{}-head_{}-drop_{}".format(
        config.encoder.d_model, config.encoder.layers,
        config.encoder.hidden_units, config.encoder.heads,
        config.encoder.dropout_rate)
    if not config.encoder.autoregressive:
      encoder = "bi" + encoder
  else:
    err_unknown_type(config.encoder["name"])

  return dict(encoder=encoder,
              latent=latent,
              decoder=decoder,
              output=output,
              model_name="TEMB")
def build_experiment_name(config): template = "PRED_{tag}{pred_model_name}_{predictive}-{emb_model_name}_{encoder}-{latent}-{decoder}-{output}-{loss}-{experiment}{data}" if config.decoder.name == "t_emb": emb_model_tags = TEmbedding.get_model_tags(config, config.loss.ink) elif config.decoder.name == "rnn": emb_model_tags = InkSeq2Seq.get_model_tags(config, config.loss.ink) else: err_unknown_type(config.decoder.name) if config.predictive_model.name == "rnn": pred_model_tags = RNN.get_model_tags(config.predictive_model) elif config.predictive_model.name == "transformer": pred_model_tags = TransformerAR.get_model_tags(config.predictive_model) else: err_unknown_type(config.predictive_model.name) # data = config.data.data_name data = "" lr = "" if config.experiment.learning_rate.name == "transformer": lr = "_tr" elif config.experiment.learning_rate.name == "exponential": lr = "_exp" experiment = "B{}_LR{}".format(config.data.batch_size, lr) loss = "loss_{}{}{}".format( "P" if config.loss.apply_predicted_ink else "", "E" if config.loss.apply_predicted_embedding else "", "R" if config.loss.apply_reconstructed_ink else "") return template.format( tag=config.experiment.tag + "_" if config.experiment.tag else "", pred_model_name=pred_model_tags["model_name"], predictive=pred_model_tags["model"], emb_model_name=emb_model_tags["model_name"], encoder=emb_model_tags["encoder"], latent=emb_model_tags["latent"], decoder=emb_model_tags["decoder"], output=emb_model_tags["output"], experiment=experiment, loss=loss, data=data, )
def build_embedding_model(config_, run_mode=C.RUN_STATIC):
  """Builds model object."""
  if config_.decoder.name == "t_emb":
    model_ = TEmbedding(config_encoder=config_.encoder,
                        config_embedding=config_.embedding,
                        config_decoder=config_.decoder,
                        config_loss=config_.loss,
                        run_mode=run_mode)
  elif config_.decoder.name == "rnn":
    model_ = InkSeq2Seq(config_encoder=config_.encoder,
                        config_embedding=config_.embedding,
                        config_decoder=config_.decoder,
                        config_loss=config_.loss,
                        run_mode=run_mode)
  else:
    err_unknown_type(config_.decoder.name)
  return model_
def get_cells(cls, type_str, units, layers=1):
  """Creates a list of RNN cells.

  Args:
    type_str: cell type (C.LSTM or C.GRU).
    units: number of hidden units per cell.
    layers: number of cells to create.

  Returns:
    A list of tf.keras.layers.LSTMCell or tf.keras.layers.GRUCell objects.
  """
  cells = []
  for _ in range(layers):
    if type_str == C.LSTM:
      cells.append(tf.keras.layers.LSTMCell(units))
    elif type_str == C.GRU:
      cells.append(tf.keras.layers.GRUCell(units))
    else:
      err_unknown_type(type_str)
  return cells
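# Hypothetical usage sketch (not from the source): the cell list returned by get_cells
# is typically wrapped into a single recurrent layer. The sizes below are placeholders,
# and the LSTM cells stand in for whatever the project's C.LSTM constant selects.
import tensorflow as tf

def _example_stacked_rnn():
  cells = [tf.keras.layers.LSTMCell(64) for _ in range(2)]   # what get_cells builds
  rnn_layer = tf.keras.layers.RNN(tf.keras.layers.StackedRNNCells(cells),
                                  return_sequences=True)
  return rnn_layer(tf.zeros([8, 20, 32]))   # (batch, time, features) -> (8, 20, 64)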
def load_meta_data(cls, meta_data_path):
  """Loads meta-data file given the path.

  It is assumed to be in numpy.

  Args:
    meta_data_path:

  Returns:
    Meta-data dictionary or False if it is not found.
  """
  # if not meta_data_path or not os.path.exists(meta_data_path):
  _, ext = os.path.splitext(meta_data_path)
  if ext == ".json":
    meta_fp = tf.io.gfile.GFile(meta_data_path, "r")
    try:
      meta_fp.size()
      print("Loading statistics " + meta_data_path)
      json_stats = json.load(meta_fp)
      stats_np = dict()
      for key_, value_ in json_stats.items():
        stats_np[key_] = np.array(value_) if isinstance(value_, list) else value_
      return stats_np
    except tf.errors.NotFoundError:
      print("Meta-data not found.")
      return False
  elif ext == ".npy":
    meta_fp = tf.io.gfile.GFile(meta_data_path, "rb")
    try:
      meta_fp.size()
      print("Loading statistics " + meta_data_path)
      return np.load(meta_fp, allow_pickle=True).item()
    except tf.errors.NotFoundError:
      print("Meta-data not found.")
      return False
  else:
    err_unknown_type(ext)
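# Hypothetical round-trip sketch (file name and keys are illustrative): the loader above
# expects either a JSON file whose list values become numpy arrays, or a pickled numpy
# dictionary saved with np.save and restored via .item().
import numpy as np

def _example_meta_data_roundtrip(path="stats.npy"):
  stats = {"mean_channel": np.zeros(3), "std_channel": np.ones(3)}
  np.save(path, stats, allow_pickle=True)
  return np.load(path, allow_pickle=True).item()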
def get(cls, type_str):
  """Creates the activation function, given its type.

  Args:
    type_str:

  Returns:
  """
  # Check if the activation is already callable.
  if callable(type_str):
    return type_str

  # Check if the activation is a built-in or custom function.
  if type_str == C.RELU:
    return tf.nn.relu
  elif type_str == C.ELU:
    return tf.nn.elu
  elif type_str == C.TANH:
    return tf.nn.tanh
  elif type_str == C.SIGMOID:
    return tf.nn.sigmoid
  elif type_str == C.SOFTPLUS:
    return tf.nn.softplus
  elif type_str == C.SOFTMAX:
    return tf.nn.softmax
  elif type_str == C.LRELU:
    return lambda x: tf.nn.leaky_relu(x, alpha=1. / 3.)
  elif type_str == C.CLRELU:
    with tf.compat.v1.name_scope("ClampedLeakyRelu"):
      def clamped_leaky_relu(x):
        return tf.clip_by_value(tf.nn.leaky_relu(x, alpha=1. / 3.), -3.0, 3.0)
      return clamped_leaky_relu
  elif type_str is None:
    return None
  else:
    err_unknown_type(type_str)
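# Minimal sketch (not from the source) of how the resolved activation is consumed: the
# string constants such as C.RELU map to plain TensorFlow callables that can be handed
# to Keras layers directly.
import tensorflow as tf

def _example_activation():
  activation = tf.nn.relu   # what get() returns for C.RELU
  hidden = tf.keras.layers.Dense(128, activation=activation)
  return hidden(tf.zeros([4, 16]))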
def get_model_tags(cls, config, config_loss=None):
  """Generates a string summarizing experiment parameters.

  Args:
    config:
    config_loss:

  Returns:
  """
  if config.predictive_model.get("name", "rnn") == "rnn":
    pred = "{}_{}x{}".format(config.encoder.cell_type,
                             config.predictive_model.cell_layers,
                             config.predictive_model.cell_units)
  elif config.predictive_model.name == "transformer":
    pred = "{}x{}-head_{}-drop_{}".format(
        config.predictive_model.layers, config.predictive_model.hidden_units,
        config.predictive_model.heads, config.predictive_model.dropout_rate)
  else:
    err_unknown_type(config.predictive_model.name)

  return dict(predictive=pred, model_name="PRED")
def get_config(FLAGS, experiment_id=None):
  """Defines the default configuration."""
  experiment_id = FLAGS.experiment_id or experiment_id

  config = Configuration()

  config.experiment = ExperimentConfig(
      comment=FLAGS.comment,
      tag="",  # Set automatically.
      model_dir=None,  # Set automatically.
      eval_dir=None,  # Set automatically.
      id=experiment_id,
      max_epochs=None,
      max_steps=200000,
      log_frequency=100,
      eval_steps=500,
      checkpoint_frequency=500,
      grad_clip_norm=FLAGS.grad_clip_norm if FLAGS.grad_clip_value <= 0 else 0,
      grad_clip_value=FLAGS.grad_clip_value,
      pretrained_emb_id=FLAGS.pretrained_emb_id)
  config.experiment.learning_rate = AttrDict(
      name=FLAGS.learning_rate_type,
      initial_learning_rate=FLAGS.learning_rate,
  )
  if FLAGS.learning_rate_type == "transformer":
    config.experiment.learning_rate.d_model = FLAGS.transformer_dmodel
    config.experiment.learning_rate.warmup_steps = 4000

  config.data = DataConfig(
      data_dir=FLAGS.data_dir,
      data_name=FLAGS.data_name,
      data_tfrecord_fname=C.DATASET_MAP[FLAGS.data_name]["data_tfrecord_fname"],
      data_meta_fname=C.DATASET_MAP[FLAGS.data_name][FLAGS.metadata_type],
      pp_to_origin="position" in FLAGS.metadata_type,
      pp_relative_pos="velocity" in FLAGS.metadata_type,
      normalize=not FLAGS.skip_normalization,
      batch_size=FLAGS.batch_size,
      max_length_threshold=201,
      mask_pen=FLAGS.mask_encoder_pen,
      resampling_factor=FLAGS.resampling_factor,
      t_drop_ratio=FLAGS.t_drop_ratio,
      gt_targets=FLAGS.gt_targets,
      scale_factor=FLAGS.scale_factor,
      affine_prob=FLAGS.affine_prob,
      reverse_prob=FLAGS.reverse_prob,
      n_t_samples=FLAGS.n_t_samples,
      int_t_samples=FLAGS.int_t_samples,
      concat_t_inputs=FLAGS.concat_t_inputs,
      rdp_dataset=FLAGS.rdp_dataset,
      rdp_didi_pp=FLAGS.rdp_didi_pp,
      pos_noise_factor=FLAGS.pos_noise_factor,
  )
  config.gdrive = AttrDict(
      credential=None,  # Set automatically below.
      workbook=None,  # Set your workbook ID (see https://github.com/emreaksan/glogger)
      sheet=FLAGS.data_name,
  )

  # Embedding model.
  if FLAGS.encoder_model == "rnn":
    config.encoder = AttrDict(
        name="rnn",
        cell_units=FLAGS.encoder_rnn_units,
        cell_layers=FLAGS.encoder_rnn_layers,
        cell_type=FLAGS.encoder_cell_type,
        bidirectional_encoder=FLAGS.bidirectional_encoder,
        rec_dropout_rate=FLAGS.encoder_rdropout)
  elif FLAGS.encoder_model == "transformer":
    config.encoder = AttrDict(
        name="transformer",
        layers=FLAGS.transformer_layers,
        heads=FLAGS.transformer_heads,
        d_model=FLAGS.transformer_dmodel,
        hidden_units=FLAGS.transformer_hidden_units,
        dropout_rate=FLAGS.transformer_dropout,
        pos_encoding=config.data.max_length_threshold
        if FLAGS.transformer_pos_encoding else 0,
        scale=FLAGS.transformer_scale,
        autoregressive=not FLAGS.bidirectional_encoder)
  else:
    err_unknown_type(FLAGS.encoder_model)

  config.embedding = AttrDict(
      latent_units=FLAGS.latent_units,
      use_vae=FLAGS.use_vae,
  )

  if FLAGS.decoder_model == "rnn":
    config.decoder = AttrDict(
        name="rnn",
        cell_units=FLAGS.encoder_rnn_units,  # Using the same hyper-parameter as the encoder.
        cell_layers=FLAGS.encoder_rnn_layers,
        cell_type=FLAGS.encoder_cell_type,
        dropout_rate=FLAGS.decoder_dropout,
        dynamic_h0=FLAGS.decoder_dynamic_h0,
        repeat_vae_sample=FLAGS.repeat_vae_sample,
        autoregressive=FLAGS.decoder_autoregressive,
    )
    target_key_pen = "pen"
    target_key_stroke = "stroke"
  elif FLAGS.decoder_model == "t_emb":
    config.decoder = AttrDict(
        name="t_emb",
        layers=FLAGS.decoder_layers,
        hidden_units=FLAGS.decoder_hidden_units,
        activation=FLAGS.decoder_activation,
        dropout_rate=FLAGS.decoder_dropout,
        t_frequency_channels=FLAGS.t_frequency_channels,
        regularizer_weight=FLAGS.reg_dec_weight)
    target_key_pen = C.TARGET_T_PEN
    target_key_stroke = C.TARGET_T_STROKE
  else:
    err_unknown_type(FLAGS.decoder_model)

  # Predictive model.
  if FLAGS.predictive_model == "rnn":
    config.predictive_model = AttrDict(
        name="rnn",
        output_size=config.embedding.latent_units,
        cell_units=FLAGS.predictive_rnn_units,
        cell_layers=FLAGS.predictive_rnn_layers,
        cell_type=FLAGS.predictive_cell_type,
        activation=C.RELU,
        use_start_pos=FLAGS.use_start_pos,
        use_end_pos=FLAGS.use_end_pos,
        stop_predictive_grad=FLAGS.stop_predictive_grad,
        num_predictive_inputs=FLAGS.num_pred_inputs,
        pred_input_type=FLAGS.pred_input_type,
    )
  elif FLAGS.predictive_model == "transformer":
    config.predictive_model = AttrDict(
        name="transformer",
        output_size=config.embedding.latent_units,
        layers=FLAGS.p_transformer_layers,
        heads=FLAGS.p_transformer_heads,
        d_model=FLAGS.p_transformer_dmodel,
        latent_units=FLAGS.latent_units,
        hidden_units=FLAGS.p_transformer_hidden_units,
        dropout_rate=FLAGS.p_transformer_dropout,
        pos_encoding=FLAGS.p_transformer_pos_encoding,
        scale=FLAGS.p_transformer_scale,
        use_start_pos=FLAGS.use_start_pos,
        use_end_pos=FLAGS.use_end_pos,
        stop_predictive_grad=FLAGS.stop_predictive_grad,
        num_predictive_inputs=FLAGS.num_pred_inputs,
        pred_input_type=FLAGS.pred_input_type,
        pooling_layer=FLAGS.pooling_layer,
    )
  else:
    err_unknown_type(FLAGS.predictive_model)

  # Position model: sharing flags with the predictive model.
  if FLAGS.position_model == "rnn":
    config.position_model = AttrDict(
        name="rnn",
        output_size=2,
        cell_units=FLAGS.predictive_rnn_units,
        cell_layers=FLAGS.predictive_rnn_layers,
        cell_type=FLAGS.predictive_cell_type,
        activation=C.RELU,
    )
  elif FLAGS.position_model == "transformer":
    config.position_model = AttrDict(
        name="transformer",
        output_size=2,
        layers=FLAGS.p_transformer_layers,
        heads=FLAGS.p_transformer_heads,
        d_model=FLAGS.p_transformer_dmodel,
        hidden_units=FLAGS.p_transformer_hidden_units,
        dropout_rate=FLAGS.p_transformer_dropout,
        pos_encoding=FLAGS.p_transformer_pos_encoding,
        scale=FLAGS.p_transformer_scale,
    )
  elif FLAGS.position_model is None:
    config.position_model = None
  else:
    err_unknown_type(FLAGS.position_model)

  # Loss
  config.loss = AttrDict()

  stroke_ = LossConfig(
      loss_type=FLAGS.stroke_loss,
      num_components=20,
      target_key=target_key_stroke,
      out_key="stroke_logits",
      reduce_type=C.R_MEAN_STEP)
  pen_ = LossConfig(
      eval_only=FLAGS.disable_pen_loss,
      loss_type=C.NLL_CENT_BINARY,
      target_key=target_key_pen,
      out_key="pen_logits",
      reduce_type=C.R_MEAN_STEP)

  ink_loss = AttrDict(pen=pen_, stroke=stroke_, prefix="reconstruction")

  if FLAGS.use_vae:
    ink_loss.embedding_kld = LossConfig(
        loss_type=FLAGS.kld_type,  # C.KLD_STANDARD or C.KLD_STANDARD_NORM
        target_key=None,
        out_key="embedding",
        reduce_type=C.R_MEAN_STEP)
    ink_loss.embedding_kld.weight = FLAGS.kld_weight
    if FLAGS.kld_increment > 0:
      ink_loss.embedding_kld.weight = dict(
          type="linear_decay",
          values=[FLAGS.kld_start, FLAGS.kld_weight, FLAGS.kld_increment])

  if FLAGS.reg_emb_weight > 0:
    ink_loss.embedding_l2 = LossConfig(
        loss_type=C.SNORM_L2,
        target_key=None,
        out_key="embedding",
        reduce_type=C.R_MEAN_STEP,
        weight=FLAGS.reg_emb_weight)

  embedding_pred = LossConfig(
      loss_type=FLAGS.embedding_loss,
      num_components=FLAGS.embedding_gmm_components,
      target_key="target",
      out_key="prediction",
      reduce_type=C.R_MEAN_STEP)
  position_pred = LossConfig(
      loss_type=FLAGS.embedding_loss,
      num_components=FLAGS.embedding_gmm_components,
      target_key="target",
      out_key="prediction",
      reduce_type=C.R_MEAN_STEP)
  # embedding_pred.weight = FLAGS.kld_weight
  # embedding_pred.weight = dict(
  #     type="linear_decay",
  #     values=[FLAGS.kld_start, FLAGS.kld_weight, FLAGS.kld_increment])

  config.loss = AttrDict(
      ink=ink_loss,
      predicted_embedding=AttrDict(predicted_embedding=embedding_pred),
      predicted_ink=copy.deepcopy(ink_loss))

  if config.position_model is not None:
    config.loss.predicted_pos = AttrDict(predicted_pos=position_pred)

  # No kld loss on the predicted ink.
  if "embedding_kld" in config.loss.predicted_ink:
    del config.loss.predicted_ink["embedding_kld"]

  config.loss.apply_predicted_embedding = FLAGS.loss_predicted_embedding
  config.loss.apply_predicted_ink = FLAGS.loss_predicted_ink
  config.loss.apply_reconstructed_ink = FLAGS.loss_reconstructed_ink

  try:
    data_root = os.environ["COSE_DATA_DIR"]
    log_dir = os.environ["COSE_LOG_DIR"]
    eval_dir = os.environ["COSE_EVAL_DIR"]
    gdrive_key = os.environ["GDRIVE_API_KEY"]
  except KeyError:
    if FLAGS.data_dir is None or FLAGS.eval_dir is None or FLAGS.experiment_dir is None:
      raise Exception("Either environment variables or FLAGS must be set.")
    data_root = FLAGS.data_dir
    log_dir = FLAGS.experiment_dir
    eval_dir = FLAGS.eval_dir
    gdrive_key = FLAGS.gdrive_api_key

  # Check if the experiment directory already exists.
  model_dir_query = glob.glob(os.path.join(log_dir, config.experiment.id + "*"))
  if model_dir_query:
    model_dir = model_dir_query[0]
    __builtin__.print = Print(os.path.join(model_dir, "log.txt"))  # Overload print.
    # Load experiment config.
    config = config.from_json(os.path.join(model_dir, "config.json"))
    config.experiment.model_dir = model_dir
    config.experiment.eval_dir = os.path.join(eval_dir, os.path.basename(model_dir))
    if "predictive_model" not in config:
      raise NotPredictiveModelError
    print("Loading from " + config.experiment.model_dir)
  else:
    config.experiment.tag = build_experiment_name(config)
    model_dir_name = config.experiment.id + "-" + config.experiment.tag
    config.experiment.model_dir = os.path.join(log_dir, model_dir_name)
    config.experiment.eval_dir = os.path.join(eval_dir, model_dir_name)
    os.mkdir(config.experiment.model_dir)  # Create experiment directory.
    __builtin__.print = Print(
        os.path.join(config.experiment.model_dir, "log.txt"))  # Overload print.
    print("Saving to " + config.experiment.model_dir)

  if not isinstance(config.data.data_tfrecord_fname, list):
    config.data.data_tfrecord_fname = [config.data.data_tfrecord_fname]
  data_path = [
      os.path.join(data_root, config.data.data_name, "{}", dp)
      for dp in config.data.data_tfrecord_fname
  ]
  config.data.train_data_path = [dp.format("training") for dp in data_path]
  config.data.valid_data_path = [dp.format("validation") for dp in data_path]
  config.data.test_data_path = [dp.format("test") for dp in data_path]
  config.data.meta_data_path = os.path.join(data_root, config.data.data_name,
                                            config.data.data_meta_fname)

  config.experiment.pretrained_emb_dir = None
  if config.experiment.get("pretrained_emb_id", None) is not None:
    config.experiment.pretrained_emb_dir = glob.glob(
        os.path.join(log_dir, config.experiment.pretrained_emb_id + "-*"))[0]

  config.gdrive.credential = gdrive_key
  if FLAGS.gdrive_api_key is None:
    config.gdrive = None

  config.dump(config.experiment.model_dir)
  return config
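# The block above reads these environment variables first and only falls back to the
# corresponding FLAGS if any of them is missing. The values below are placeholders, not
# from the source.
import os

def _example_environment():
  os.environ.setdefault("COSE_DATA_DIR", "/tmp/cose/data")
  os.environ.setdefault("COSE_LOG_DIR", "/tmp/cose/experiments")
  os.environ.setdefault("COSE_EVAL_DIR", "/tmp/cose/eval")
  os.environ.setdefault("GDRIVE_API_KEY", "")   # optional Google Sheets logging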
def call(self, inputs, training=False, **kwargs):
  """Call method.

  Args:
    inputs (dict): expected to contain inputs for the encoder and decoder,
      sequence length and number of strokes ops.
    training: whether in training mode or not.
    **kwargs:

  Returns:
    [batch_size, seq_len, feature_size]
  """
  input_num_strokes = inputs[C.INP_NUM_STROKE]
  # tf.keras compile, fit, predict, etc. methods cause it to be 2-dimensional.
  if len(inputs[C.INP_NUM_STROKE].shape) == 2:
    input_num_strokes = inputs[C.INP_NUM_STROKE][:, 0]

  out_dict = dict()
  gt_reconstruction = self.embedding_model.call(inputs, training=training)
  out_dict["reconstructed_ink"] = gt_reconstruction
  embedding = gt_reconstruction["embedding"]
  out_dict["embedding_sample"] = out_dict["reconstructed_ink"]["embedding_sample"]

  # Before making a prediction for the next stroke, reshape embeddings so that
  # we have a sequence of stroke embeddings, representing the diagram samples.
  diagram_embedding = self.batch_stroke_to_diagram(embedding, input_num_strokes)

  # Probabilistic inputs if the stroke model is probabilistic: draw inputs for
  # the ink model.
  embedding_sample = self.embedding_model.net_embedding.draw_sample(
      diagram_embedding, greedy=True)

  # Determine ink model inputs and targets.
  n_strokes = tf.shape(input=embedding_sample)[0]
  seq_len = tf.shape(input=embedding_sample)[1]
  batch_idx = tf.range(n_strokes)

  # (1) Predict the last step only.
  if self.input_type == "last_step":
    target_idx = input_num_strokes - 1
    input_range = tf.tile(tf.range(seq_len)[tf.newaxis, :], [n_strokes, 1])
    mask_ = tf.not_equal(input_range,
                         tf.tile(target_idx[:, tf.newaxis], [1, seq_len]))
    gather_input_idx = tf.reshape(tf.compat.v1.where(mask_),
                                  [n_strokes, seq_len - 1, 2])
    pred_input_seq_len = input_num_strokes - 1
    gather_target_idx = tf.stack([batch_idx, target_idx], axis=-1)

  # (2) Leave-one-out with random targets. Use all strokes except the target
  # as input.
  elif self.input_type == "leave_one_out":
    i = tf.random.uniform([n_strokes], minval=0, maxval=1, dtype=tf.float32)
    target_idx = tf.round(i * tf.cast(input_num_strokes - 1, tf.float32))
    target_idx = tf.cast(target_idx, tf.int32)
    input_range = tf.tile(tf.range(seq_len)[tf.newaxis, :], [n_strokes, 1])
    mask_ = tf.not_equal(input_range,
                         tf.tile(target_idx[:, tf.newaxis], [1, seq_len]))
    gather_input_idx = tf.reshape(tf.compat.v1.where(mask_),
                                  [n_strokes, seq_len - 1, 2])
    pred_input_seq_len = input_num_strokes - 1
    gather_target_idx = tf.stack([batch_idx, target_idx], axis=-1)

  elif self.input_type in ["random", "ordered", "hybrid"]:
    min_n_stroke = tf.reduce_min(input_tensor=input_num_strokes)
    max_n_stroke = tf.reduce_max(input_tensor=input_num_strokes)
    input_range_ = tf.tile(tf.range(max_n_stroke)[tf.newaxis, :], [n_strokes, 1])

    def get_random_inp_target_pairs():
      """Get a randomly generated input set and a target."""
      # Randomly set the number of inputs.
      n_inputs_ = tf.random.uniform([1], minval=2, maxval=min_n_stroke,
                                    dtype=tf.int32)[0]
      # Randomly pick a target.
      i_ = tf.random.uniform([n_strokes], minval=0, maxval=1, dtype=tf.float32)
      target_idx_ = tf.round(i_ * tf.cast(input_num_strokes - 1, tf.float32))
      target_idx_ = tf.cast(target_idx_, tf.int32)

      mask_ = tf.not_equal(input_range_,
                           tf.tile(target_idx_[:, tf.newaxis], [1, max_n_stroke]))
      all_input_idx_ = tf.reshape(tf.compat.v1.where(mask_),
                                  [n_strokes, max_n_stroke - 1, 2])
      all_input_idx_ = tf.transpose(a=all_input_idx_[:, :min_n_stroke - 1],
                                    perm=[1, 0, 2])
      all_input_idx_ = tf.random.shuffle(all_input_idx_)
      gather_input_idx_ = tf.cast(
          tf.transpose(a=all_input_idx_[:n_inputs_], perm=[1, 0, 2]), tf.int32)
      return gather_input_idx_, target_idx_, n_inputs_

    def get_ordered_inp_target_pairs(random_target=False):
      """Get a slice (i.e., window) randomly."""
      # Randomly set the number of inputs.
      n_inputs_ = tf.random.uniform([1], minval=2, maxval=min_n_stroke,
                                    dtype=tf.int32)[0]
      # Select start index of the window.
      start_idx = tf.random.uniform([1], minval=0,
                                    maxval=min_n_stroke - n_inputs_,
                                    dtype=tf.int32)[0]
      if not random_target:
        # Target is the next stroke.
        target_idx_ = tf.tile(tf.expand_dims(start_idx + n_inputs_, axis=0),
                              [n_strokes])
      else:
        # Randomly pick a target.
        i_ = tf.random.uniform([n_strokes], minval=0, maxval=1, dtype=tf.float32)
        target_idx_ = tf.round(i_ * tf.cast(input_num_strokes - 1, tf.float32))
        target_idx_ = tf.cast(target_idx_, tf.int32)

      mask_ = tf.not_equal(input_range_,
                           tf.tile(target_idx_[:, tf.newaxis], [1, max_n_stroke]))
      all_input_idx_ = tf.reshape(tf.compat.v1.where(mask_),
                                  [n_strokes, max_n_stroke - 1, 2])[:, :min_n_stroke]
      gather_input_idx_ = tf.cast(
          all_input_idx_[:, start_idx:start_idx + n_inputs_], tf.int32)
      return gather_input_idx_, target_idx_, n_inputs_

    all_n_inputs = []
    all_gather_input_idx = []
    all_target_idx = []
    all_seq_lens = []

    if self.input_type in ["random", "hybrid"]:
      for i in range(self.num_predictive_inputs):
        gather_input_idx, target_idx, n_inputs = get_random_inp_target_pairs()
        all_gather_input_idx.append(gather_input_idx)
        all_target_idx.append(target_idx)
        all_n_inputs.append(n_inputs)
        all_seq_lens.append(tf.ones([n_strokes], dtype=tf.int32) * n_inputs)

    if self.input_type in ["ordered", "hybrid"]:
      for i in range(self.num_predictive_inputs):
        gather_input_idx, target_idx, n_inputs = get_ordered_inp_target_pairs()
        all_gather_input_idx.append(gather_input_idx)
        all_target_idx.append(target_idx)
        all_n_inputs.append(n_inputs)
        all_seq_lens.append(tf.ones([n_strokes], dtype=tf.int32) * n_inputs)

    max_len = tf.reduce_max(input_tensor=all_n_inputs)
    for i in range(len(all_n_inputs)):
      all_gather_input_idx[i] = tf.pad(
          tensor=all_gather_input_idx[i],
          paddings=[[0, 0], [0, max_len - all_n_inputs[i]], [0, 0]])

    gather_input_idx = tf.concat(all_gather_input_idx, axis=0)
    pred_input_seq_len = tf.concat(all_seq_lens, axis=0)
    gather_target_idx = tf.stack(
        [tf.tile(batch_idx, [len(all_n_inputs)]),
         tf.concat(all_target_idx, axis=0)],
        axis=-1)
  else:
    err_unknown_type(self.input_type)

  # if "sigma" in diagram_embedding:
  #   pred_targets = dict()
  #   pred_targets["mu"] = tf.stop_gradient(
  #       tf.gather_nd(diagram_embedding["mu"], gather_target_idx),
  #       name="embedding_mu_target_stop")
  #   pred_targets["sigma"] = tf.stop_gradient(
  #       tf.gather_nd(diagram_embedding["sigma"], gather_target_idx),
  #       name="embedding_sigma_target_stop")
  # else:
  pred_targets = tf.gather_nd(embedding_sample, gather_target_idx)
  pred_targets = tf.stop_gradient(pred_targets, name="embedding_target_stop")

  if self.stop_predictive_grad:
    # Disable gradient flow from the predictive/position models to the
    # embedding model.
    pred_input = tf.stop_gradient(
        tf.gather_nd(embedding_sample, gather_input_idx),
        name="embeddings_pred_input_stop")
  else:
    pred_input = tf.gather_nd(embedding_sample, gather_input_idx)

  start_pos = tf.reshape(inputs["start_coord"], [n_strokes, seq_len, 2])
  start_context_pos = tf.gather_nd(start_pos, gather_input_idx)
  if self.end_positions:
    end_pos = tf.reshape(inputs["end_coord"], [n_strokes, seq_len, 2])
    end_context_pos = tf.gather_nd(end_pos, gather_input_idx)
    context_pos = tf.concat([start_context_pos, end_context_pos], axis=-1)
  else:
    context_pos = start_context_pos
  target_pos = tf.expand_dims(tf.gather_nd(start_pos, gather_target_idx), axis=1)

  predicted_embedding = self.predictive_model(
      inputs=dict(input_seq=pred_input,
                  input_cond=context_pos,
                  target_cond=target_pos,
                  seq_len=pred_input_seq_len),
      training=True)
  pred_emb_sample = self.predictive_model.output_layer.draw_sample(
      predicted_embedding, greedy=True)

  embedding_out_dict = dict(
      prediction=predicted_embedding,
      target=pred_targets,
      input_seq_len=pred_input_seq_len,
      embedding_sample=pred_emb_sample)
  out_dict["embedding"] = embedding_out_dict

  if self.position_model is not None:
    predicted_pos = self.position_model(
        inputs=dict(input_seq=pred_input,
                    input_cond=context_pos,
                    target_cond=None,
                    seq_len=pred_input_seq_len),
        training=True)
    pred_pos_sample = self.position_model.output_layer.draw_sample(
        predicted_pos, greedy=True)

    pos_out_dict = dict(
        prediction=predicted_pos,
        target=target_pos[:, 0],
        input_seq_len=pred_input_seq_len,
        position_sample=pred_pos_sample)
    out_dict["position"] = pos_out_dict

  ### Decode the predicted embedding if not using a sequence decoder as it is
  # too slow.
  # if False:
  #   # Select decoder t inputs corresponding to the predicted stroke.
  #   t_batch_diagram = tf.reshape(decoder_inputs, [n_strokes, seq_len, -1])
  #   t_target = tf.gather_nd(t_batch_diagram, gather_target_idx)
  #   predicted_ink = self.embedding_model.call_decode(
  #       embedding=pred_emb_sample,
  #       decoder_inputs=t_target,
  #       training=training)
  #   predicted_ink["embedding"] = pred_emb_sample
  #
  #   # TODO Need to pass those to select targets.
  #   predicted_ink["shape_0"] = n_strokes
  #   predicted_ink["shape_1"] = seq_len
  #   predicted_ink["gather_target_idx"] = gather_target_idx
  #   out_dict["predicted_ink"] = predicted_ink

  return out_dict
def predict_embedding(self,
                      embeddings,
                      target_idx,
                      seq_lens,
                      input_idx=None,
                      input_type="leave_one_out",
                      start_positions=None):
  if isinstance(embeddings, dict):
    embeddings = self.embedding_model.net_embedding.draw_sample(embeddings)

  # Determine ink model inputs and targets.
  n_strokes = tf.shape(input=embeddings)[0]
  # seq_len = tf.shape(embeddings)[1]
  seq_len = seq_lens[0]
  batch_idx = tf.range(n_strokes)

  if input_idx is None:
    if input_type == "leave_one_out":
      input_range = tf.tile(tf.range(seq_len)[tf.newaxis, :], [n_strokes, 1])
      mask_ = tf.not_equal(input_range,
                           tf.tile(target_idx[:, tf.newaxis], [1, seq_len]))
      gather_input_idx = tf.reshape(tf.compat.v1.where(mask_),
                                    [n_strokes, seq_len - 1, 2])
      pred_input_seq_len = seq_lens - 1
    elif input_type == "last_step" or input_type == "ordered":
      input_range = tf.tile(tf.range(seq_len)[tf.newaxis, :], [n_strokes, 1])
      mask_ = tf.less(input_range,
                      tf.tile(target_idx[:, tf.newaxis], [1, seq_len]))
      gather_input_idx = tf.reshape(
          tf.compat.v1.where(mask_),
          [n_strokes, tf.reduce_max(input_tensor=target_idx), 2])
      pred_input_seq_len = target_idx
    else:
      err_unknown_type(input_type)
  else:
    gather_input_idx = tf.stack([tf.zeros_like(input_idx), input_idx], axis=-1)
    pred_input_seq_len = tf.Variable([tf.shape(input=input_idx)[1]])

  gather_target_idx = tf.stack([batch_idx, target_idx], axis=-1)
  pred_targets = tf.gather_nd(embeddings, gather_target_idx)
  pred_input = tf.gather_nd(embeddings, gather_input_idx)

  start_pos = tf.reshape(start_positions, [n_strokes, seq_len, 2])
  start_context_pos = tf.gather_nd(start_pos, gather_input_idx)
  if self.end_positions:
    end_pos = tf.reshape(start_positions, [n_strokes, seq_len, 2])
    end_context_pos = tf.gather_nd(end_pos, gather_input_idx)
    context_pos = tf.concat([start_context_pos, end_context_pos], axis=-1)
  else:
    context_pos = start_context_pos
  target_pos = tf.expand_dims(tf.gather_nd(start_pos, gather_target_idx), axis=1)

  out_ = self.predictive_model(
      inputs=dict(input_seq=pred_input,
                  input_cond=context_pos,
                  target_cond=target_pos,
                  seq_len=pred_input_seq_len),
      training=False)
  out_["gather_target_idx"] = gather_target_idx
  out_["gather_input_idx"] = gather_input_idx
  out_["embedding_sample"] = self.predictive_model.output_layer.draw_sample(
      out_, greedy=True)
  return out_, pred_targets
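# Self-contained sketch (toy tensors, not from the source) of the "leave one out" gather
# used in call() and predict_embedding() above: the target stroke index is masked out per
# diagram and tf.gather_nd collects the remaining stroke embeddings as context.
import tensorflow as tf

def _example_leave_one_out_gather():
  embeddings = tf.reshape(tf.range(12, dtype=tf.float32), [2, 3, 2])  # (diagrams, strokes, dim)
  target_idx = tf.constant([1, 2])
  seq_len = 3
  input_range = tf.tile(tf.range(seq_len)[tf.newaxis, :], [2, 1])
  mask = tf.not_equal(input_range, tf.tile(target_idx[:, tf.newaxis], [1, seq_len]))
  gather_input_idx = tf.reshape(tf.where(mask), [2, seq_len - 1, 2])
  context = tf.gather_nd(embeddings, gather_input_idx)                 # (2, 2, 2)
  target = tf.gather_nd(embeddings, tf.stack([tf.range(2), target_idx], axis=-1))
  return context, target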
def __init__(self,
             config_encoder,
             config_embedding,
             config_decoder,
             config_loss,
             run_mode=C.RUN_ESTIMATOR,
             **kwargs):
  """Constructor.

  Args:
    config_encoder:
    config_embedding:
    config_decoder:
    config_loss:
    run_mode: eager, static or estimator.
    **kwargs:

  Raises:
    ValueError: if run_mode is eager and tf.executing_eagerly() is False.
    Exception: if # layers > 1 and dynamic_h0 is True.
  """
  super(TEmbedding, self).__init__(config_loss=config_loss,
                                   run_mode=run_mode,
                                   **kwargs)

  self.pen_threshold = 0.3
  self.config_encoder = config_encoder
  self.config_embedding = config_embedding
  self.config_decoder = config_decoder
  self.latent_prefix = ""

  self.regularize_decoder = self.config_decoder.get("regularizer_weight", 0) > 0
  self.kernel_regularizer = None
  if self.regularize_decoder:
    self.kernel_regularizer = tf.keras.regularizers.l2(
        self.config_decoder.get("regularizer_weight", 0))

  self.n_latent_units = self.config_embedding["latent_units"]
  self.use_vae = self.config_embedding["use_vae"]
  self.decoder_drop_rate = self.config_decoder.get("dropout_rate", 0)
  self.t_frequency_channels = self.config_decoder.get("t_frequency_channels", 0)

  # Encoder network.
  self.net_encoder = None
  if self.config_encoder["name"] == "rnn":
    self.net_encoder = RNN(
        self.config_encoder["cell_type"],
        self.config_encoder["cell_units"],
        self.config_encoder["cell_layers"],
        self.config_encoder["bidirectional_encoder"],
        return_sequences=False,
        return_state=False,
        run_mode=run_mode,
        use_cudnn=self.config_encoder["use_cudnn"],
        name="encoder_rnn")
  elif self.config_encoder["name"] == "mlp":
    pass
  elif self.config_encoder["name"] == "cnn":
    pass
  elif self.config_encoder["name"] == "transformer":
    self.net_encoder = Transformer(
        num_layers=self.config_encoder["layers"],
        d_model=self.config_encoder["d_model"],
        num_heads=self.config_encoder["heads"],
        dff=self.config_encoder["hidden_units"],
        rate=self.config_encoder["dropout_rate"],
        scale=self.config_encoder["scale"],
        pos_encoding_len=self.config_encoder["pos_encoding"],
        autoregressive=self.config_encoder["autoregressive"],
        return_sequences=False,
        config_loss=None,
        run_mode=run_mode)
  else:
    err_unknown_type(self.config_encoder["name"])

  # Deterministic or stochastic stroke.
  if self.use_vae:
    self.net_embedding = OutputModelNormal(
        out_units=self.n_latent_units,
        prefix=self.latent_prefix,
        sigma_activation=None,
        logvar=True)
  else:
    self.net_embedding = OutputModelDeterministic(
        out_units=self.n_latent_units, prefix=self.latent_prefix)

  # Decoder network.
  self.net_decoder = tf.keras.Sequential(name="decoder")
  layer_units = self.config_decoder["hidden_units"]
  if len(layer_units) == 1:
    # Repeat the single hidden size for every decoder layer.
    layer_units = layer_units * self.config_decoder["layers"]

  decoder_activation = Activations.get(self.config_decoder["activation"])
  for idx in range(self.config_decoder["layers"]):
    self.net_decoder.add(
        tf.keras.layers.Dense(layer_units[idx],
                              activation=decoder_activation,
                              kernel_regularizer=self.kernel_regularizer,
                              bias_regularizer=self.kernel_regularizer))
    if self.decoder_drop_rate > 0:
      self.net_decoder.add(tf.keras.layers.Dropout(self.decoder_drop_rate))

  # Pen and stroke outputs.
  if config_loss["pen"]["eval_only"]:
    self.decoder_out_pen = None
  else:
    self.decoder_out_pen = tf.keras.layers.Dense(
        1,
        name="out_pen",
        kernel_regularizer=self.kernel_regularizer,
        bias_regularizer=self.kernel_regularizer)

  # Build output model depending on the loss type.
  if self.config_loss["stroke"]["loss_type"] == C.NLL_NORMAL:
    self.decoder_out_stroke = OutputModelNormal(out_units=2,
                                                hidden_units=0,
                                                hidden_layers=0)
  elif self.config_loss["stroke"]["loss_type"] == C.NLL_BINORMAL:
    self.decoder_out_stroke = OutputModelNormal2DDense(
        sigma_activation=tf.keras.activations.exponential)
  elif self.config_loss["stroke"]["loss_type"] == C.NLL_GMM:
    self.decoder_out_stroke = OutputModelGMMDense(
        out_units=2,
        num_components=self.config_loss["stroke"]["num_components"],
        sigma_activation=tf.keras.activations.exponential)
  else:
    self.decoder_out_stroke = OutputModelDeterministic(
        out_units=2,
        hidden_units=0,
        hidden_layers=0,
        kernel_regularizer=self.kernel_regularizer,
        bias_regularizer=self.kernel_regularizer)

  # Variables for static mode. They are assigned in the call method.
  # TODO: We can get rid of them if autoregressive sampling is no longer
  # required in static (graph) mode.
  self.op_encoder_inputs = None
  self.op_input_seq_len = None
  self.op_embedding = None
  self.op_decoder_inputs = None
  self.op_embedding_sample = None
def build_predictive_model(config_, run_mode):
  """Builds model object."""
  # Embedding model.
  if config_.decoder.get("name", "t_emb") == "t_emb":
    embedding_model = TEmbedding(config_encoder=config_.encoder,
                                 config_embedding=config_.embedding,
                                 config_decoder=config_.decoder,
                                 config_loss=config_.loss.ink,
                                 run_mode=run_mode)
  elif config_.decoder.get("name", "t_emb") == "rnn":
    embedding_model = InkSeq2Seq(config_encoder=config_.encoder,
                                 config_embedding=config_.embedding,
                                 config_decoder=config_.decoder,
                                 config_loss=config_.loss.ink,
                                 run_mode=run_mode)
  else:
    err_unknown_type(config_.decoder.name)

  # Predictive model.
  if config_.predictive_model.get("name", "rnn") == "rnn":
    predictive_model = RNNConditional(
        output_size=config_.predictive_model.output_size,
        cell_units=config_.predictive_model.cell_units,
        cell_layers=config_.predictive_model.cell_layers,
        cell_type=config_.predictive_model.cell_type,
        return_sequences=False,
        return_state=False,
        run_mode=run_mode,
        config_loss=config_.loss.predicted_embedding.predicted_embedding)
  elif config_.predictive_model.name == "transformer":
    predictive_model = TransformerSeq2seqConditional(
        output_size=config_.predictive_model.latent_units,
        num_layers=config_.predictive_model.layers,
        d_model=config_.predictive_model.d_model,
        num_heads=config_.predictive_model.heads,
        dff=config_.predictive_model.hidden_units,
        rate=config_.predictive_model.dropout_rate,
        config_loss=config_.loss.predicted_embedding.predicted_embedding,
        pos_encoding_len=100 if config_.predictive_model.pos_encoding else 0,
        scale=config_.predictive_model.scale,
        run_mode=run_mode,
        autoregressive=False,
        pooling=config_.predictive_model.pooling_layer,
    )
  else:
    err_unknown_type(config_.predictive_model.name)

  # Position model (optional): guard against a missing config entry before
  # dispatching on its type.
  position_model = None
  if config_.get("position_model", None) is not None:
    if config_.position_model.name == "rnn":
      position_model = RNNConditional(
          output_size=2,
          cell_units=config_.position_model.cell_units,
          cell_layers=config_.position_model.cell_layers,
          cell_type=config_.position_model.cell_type,
          return_sequences=False,
          return_state=False,
          run_mode=run_mode,
          config_loss=config_.loss.predicted_pos.predicted_pos)
    else:
      position_model = TransformerSeq2seqConditional(
          output_size=2,
          num_layers=config_.position_model.layers,
          d_model=config_.position_model.d_model,
          num_heads=config_.position_model.heads,
          dff=config_.position_model.hidden_units,
          rate=config_.position_model.dropout_rate,
          config_loss=config_.loss.predicted_pos.predicted_pos,
          pos_encoding_len=0,
          scale=config_.position_model.scale,
          run_mode=run_mode,
          autoregressive=False)

  model_ = PredictiveInkModel(
      embedding_model=embedding_model,
      predictive_model=predictive_model,
      position_model=position_model,
      loss_predicted_embedding=config_.loss.apply_predicted_embedding,
      loss_predicted_ink=config_.loss.apply_predicted_ink,
      loss_reconstructed_ink=config_.loss.apply_reconstructed_ink,
      input_type=config_.predictive_model.pred_input_type,
      start_positions=config_.predictive_model.use_start_pos,
      end_positions=config_.predictive_model.get("use_end_pos", False),
      stop_predictive_grad=config_.predictive_model.get("stop_predictive_grad",
                                                        False),
      num_predictive_inputs=config_.predictive_model.get("num_predictive_inputs",
                                                         8),
      config_loss=copy.deepcopy(config_.loss),
      run_mode=run_mode)
  return model_
def build_dataset(config_, run_mode=C.RUN_STATIC, split=C.DATA_TRAIN):
  """Builds dataset object."""
  dataset_ = None
  if split == C.DATA_TRAIN:
    dataset_ = TFRecordBatchDiagram(
        data_path=config_.data.train_data_path,
        meta_data_path=config_.data.meta_data_path,
        batch_size=config_.data.batch_size,
        pp_to_origin=config_.data.pp_to_origin,
        pp_relative_pos=config_.data.pp_relative_pos,
        normalize=config_.data.normalize,
        shuffle=True,
        run_mode=run_mode,
        max_length_threshold=config_.data.max_length_threshold,
        mask_pen=config_.data.mask_pen,
        resampling_factor=config_.data.get("resampling_factor", 0),
        t_drop_ratio=config_.data.get("t_drop_ratio", 0),
        scale_factor=config_.data.get("scale_factor", 0),
        pos_noise_factor=config_.data.get("pos_noise_factor", 0),
        affine_prob=config_.data.get("affine_prob", 0),
        reverse_prob=config_.data.get("reverse_prob", 0),
        gt_targets=config_.data.get("gt_targets", False),
        n_t_targets=config_.data.get("n_t_samples", 1),
        int_t_samples=config_.data.get("int_t_samples", False),
        concat_t_inputs=config_.data.get("concat_t_inputs", False),
        rdp=config_.data.get("rdp_dataset", False),
        rdp_didi_pp=config_.data.get("rdp_didi_pp", False),
    )
  elif split == C.DATA_VALID:
    dataset_ = TFRecordBatchDiagram(
        data_path=config_.data.valid_data_path,
        meta_data_path=config_.data.meta_data_path,
        batch_size=config_.data.batch_size,
        pp_to_origin=config_.data.pp_to_origin,
        pp_relative_pos=config_.data.pp_relative_pos,
        normalize=config_.data.normalize,
        shuffle=False,
        run_mode=run_mode,
        max_length_threshold=config_.data.max_length_threshold,
        mask_pen=config_.data.mask_pen,
        concat_t_inputs=config_.data.get("concat_t_inputs", False),
        rdp=config_.data.get("rdp_dataset", False),
        rdp_didi_pp=config_.data.get("rdp_didi_pp", False),
    )
  elif split == C.DATA_TEST:
    dataset_ = TFRecordSingleDiagram(
        data_path=config_.data.test_data_path,
        meta_data_path=config_.data.meta_data_path,
        batch_size=config_.data.batch_size,
        pp_to_origin=config_.data.pp_to_origin,
        pp_relative_pos=config_.data.pp_relative_pos,
        normalize=config_.data.normalize,
        shuffle=False,
        max_length_threshold=config_.data.max_length_threshold,
        run_mode=run_mode,
        mask_pen=config_.data.mask_pen,
        concat_t_inputs=config_.data.get("concat_t_inputs", False),
        rdp=config_.data.get("rdp_dataset", False),
        rdp_didi_pp=config_.data.get("rdp_didi_pp", False),
    )
  else:
    err_unknown_type(split)
  return dataset_
def loss_fn(self,
            loss_config,
            predictions,
            targets,
            seq_len,
            prefix="",
            run_mode=C.RUN_STATIC,
            training=True):
  fixed_len_seq = loss_config.get("fixed_len_seq", 0)
  if prefix and not prefix.endswith("_"):
    prefix = prefix + "_"

  loss_dict = dict()
  loss_metric_dict = dict()
  total_loss_ops = list()

  for loss_key, loss_term in loss_config.items():
    if not isinstance(loss_term, dict):
      continue

    loss_sequence = None
    loss_type = loss_term["loss_type"]
    loss_reduce = loss_term["reduce"]
    loss_weight = loss_term["weight"]
    if not isinstance(loss_weight, float):
      start, end, increment = loss_term["weight"]["values"]
      loss_weight = end - (end - start) * increment**tf.cast(self.step, tf.float32)
      loss_dict["kldw"] = loss_weight
      if not training:
        loss_weight = 1.0

    loss_objective = not loss_term.get("eval_only", False)
    loss_targets = targets.get(loss_term["target_key"], None)
    loss_predictions = predictions.get(loss_term["out_key"], None)

    seq_mask = None
    if seq_len is not None:
      if fixed_len_seq > 0:
        seq_mask = tf.expand_dims(
            tf.sequence_mask(seq_len, maxlen=fixed_len_seq, dtype=tf.float32), -1)
      else:
        seq_mask = tf.expand_dims(
            tf.sequence_mask(seq_len, dtype=tf.float32), -1)

    # Calculate loss with shape (batch_size, seq_len, 1).
    if loss_type == C.MSE:
      if isinstance(loss_targets, dict):
        loss_targets = loss_targets[C.MU]
      loss_sequence = tf.reduce_sum(
          input_tensor=tf.math.square(loss_targets - loss_predictions[C.MU]),
          axis=-1,
          keepdims=True)
    elif loss_type == C.L1:
      if isinstance(loss_targets, dict):
        loss_targets = loss_targets[C.MU]
      loss_sequence = tf.reduce_sum(
          input_tensor=tf.math.abs(loss_targets - loss_predictions[C.MU]),
          axis=-1,
          keepdims=True)
    elif loss_type == C.NLL_NORMAL:
      loss_sequence = -1 * loss.logli_normal_diagonal(
          loss_targets, loss_predictions[C.MU], loss_predictions[C.SIGMA])
    elif loss_type == C.NLL_BINORMAL:
      loss_sequence = -1 * loss.logli_normal_bivariate(
          loss_targets, loss_predictions[C.MU], loss_predictions[C.SIGMA],
          loss_predictions[C.RHO])
    elif loss_type == C.NLL_GMM:
      loss_sequence = -1 * loss.logli_gmm_logsumexp(
          loss_targets, loss_predictions[C.MU], loss_predictions[C.SIGMA],
          loss_predictions[C.PI])
    elif loss_type == C.NLL_CENT_BINARY:
      loss_sequence = tf.nn.sigmoid_cross_entropy_with_logits(
          loss_targets, loss_predictions)
    elif loss_type == C.KLD_STANDARD:
      loss_sequence = loss.kld_normal_diagonal_standard_prior(
          loss_predictions[C.MU], loss_predictions[C.SIGMA])
    elif loss_type == C.KLD_STANDARD_NORM:
      loss_sequence = loss.kld_normal_diagonal_standard_prior_normalized(
          loss_predictions[C.MU], loss_predictions[C.SIGMA])
    elif loss_type == C.KLD:
      loss_sequence = loss.kld_normal_diagonal(
          loss_predictions[C.MU], loss_predictions[C.SIGMA],
          loss_targets[C.MU], loss_targets[C.SIGMA], reduce_sum=False)
    elif loss_type == C.SNORM_L2:
      loss_sequence = tf.reduce_sum(
          input_tensor=tf.math.square(loss_predictions[C.MU]), axis=-1)
    else:
      err_unknown_type(loss_type)

    if len(loss_sequence.shape) == 3:
      # Mask the padded steps and calculate a scalar value.
      loss_sequence *= seq_mask
      loss_op = None
      if loss_reduce == C.R_MEAN_STEP:
        loss_op = reduce_mean_step(loss_sequence, seq_mask)
        # loss_op = tf.reduce_mean(loss_sequence)
      elif loss_reduce == C.R_MEAN_SEQUENCE:
        loss_op = reduce_mean_sequence(loss_sequence)
      else:
        err_unknown_type(loss_reduce)
    elif len(loss_sequence.shape) == 2:
      if seq_len is not None:
        nonzero_seq_len = tf.cast(
            tf.expand_dims(
                tf.compat.v1.where(seq_len > 0, tf.ones_like(seq_len),
                                   tf.zeros_like(seq_len)), axis=1), tf.float32)
        loss_op = loss_sequence * nonzero_seq_len
        loss_op = tf.reduce_sum(input_tensor=loss_op) / tf.reduce_sum(
            input_tensor=nonzero_seq_len)
      else:
        loss_op = tf.reduce_mean(input_tensor=loss_sequence)
    else:
      loss_op = tf.reduce_mean(input_tensor=loss_sequence)

    if loss_type in [C.KLD_STANDARD, C.KLD_STANDARD_NORM]:
      loss_op = tf.maximum(loss_op, 0.2)

    # The loss term is weighted only for the optimization. In order to enable
    # comparison in Tensorboard plots, we keep the loss unweighted.
    if len(loss_config) > 1:
      loss_dict[prefix + loss_key] = loss_op
    if loss_objective:
      total_loss_ops.append(loss_weight * loss_op)

    # tf.estimator requires the loss in tf.metrics.
    if run_mode == C.RUN_ESTIMATOR and len(loss_config) > 1:
      loss_metric_dict[prefix + loss_key] = tf.keras.metrics.Mean(loss_op)

  # Sum all loss terms to get the objective.
  loss_dict["loss"] = tf.math.add_n(total_loss_ops, name=prefix + "total_loss")
  loss_dict[prefix + "loss"] = loss_dict["loss"]

  if run_mode == C.RUN_ESTIMATOR:
    loss_metric_dict["loss"] = tf.keras.metrics.Mean(loss_dict["loss"])
    loss_metric_dict[prefix + "loss"] = loss_metric_dict["loss"]
    return loss_dict, loss_metric_dict
  else:
    return loss_dict
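# Self-contained sketch (toy values) of the per-step masking used above: losses on padded
# steps are zeroed with a sequence mask and the mean is taken over valid steps only, which
# is what the project's reduce_mean_step helper is assumed to do.
import tensorflow as tf

def _example_masked_step_mean():
  loss_sequence = tf.ones([2, 5, 1])                       # (batch, seq_len, 1)
  seq_len = tf.constant([5, 3])
  seq_mask = tf.expand_dims(tf.sequence_mask(seq_len, maxlen=5, dtype=tf.float32), -1)
  masked = loss_sequence * seq_mask
  return tf.reduce_sum(masked) / tf.reduce_sum(seq_mask)   # == 1.0 here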