def _get_loss_train_op(self):
    # Aggregates loss
    loss_rephraser = (tf.get_collection('loss_rephraser_list')[0] +
                      tf.get_collection('loss_rephraser_list')[1] +
                      tf.get_collection('loss_rephraser_list')[2]) / 3.
    self.loss = loss_rephraser

    # Creates optimizers
    self.vars = collect_trainable_variables([
        self.bert_encoder, self.word_embedder, self.downmlp,
        self.rephrase_encoder, self.rephrase_decoder
    ])

    # Train Op
    self.train_op_pre = get_train_op(self.loss, self.vars,
                                     learning_rate=self.lr,
                                     hparams=self._hparams.opt)

    # Interface tensors
    self.losses = {
        "loss": self.loss,
        "loss_rephraser": loss_rephraser,
    }
    self.metrics = {}
    self.train_ops = {
        "train_op_pre": self.train_op_pre,
    }
    self.samples = {
        "transferred_yy1_gt": self.text_ids_yy1,
        "transferred_yy1_pred": tf.get_collection('yy_pred_list')[0],
        "transferred_yy2_gt": self.text_ids_yy2,
        "transferred_yy2_pred": tf.get_collection('yy_pred_list')[1],
        "transferred_yy3_gt": self.text_ids_yy3,
        "transferred_yy3_pred": tf.get_collection('yy_pred_list')[2],
        "origin_y1": self.text_ids_y1,
        "origin_y2": self.text_ids_y2,
        "origin_y3": self.text_ids_y3,
        "x1x2": self.x1x2,
        "x1xx2": self.x1xx2
    }

    tf.summary.scalar("loss", self.loss)
    tf.summary.scalar("loss_rephraser", loss_rephraser)
    self.merged = tf.summary.merge_all()

    self.fetches_train_pre = {
        "loss": self.train_ops["train_op_pre"],
        "loss_rephraser": self.losses["loss_rephraser"],
        "merged": self.merged,
    }

    fetches_eval = {
        "batch_size": get_batch_size(self.x1x2yx1xx2_ids),
        "merged": self.merged,
    }
    fetches_eval.update(self.losses)
    fetches_eval.update(self.metrics)
    fetches_eval.update(self.samples)
    self.fetches_eval = fetches_eval
def _get_loss_train_op(self):
    # Aggregates losses
    self.loss_g = (self.loss_g_ae
                   + self.lambda_g_hidden * self.loss_g_clas_hidden
                   + self.lambda_g_sentence * self.loss_g_clas_sentence)
    self.loss_d = self.loss_d_clas_hidden + self.loss_d_clas_sentence

    # Creates optimizers
    self.g_vars = collect_trainable_variables([
        self.embedder, self.self_graph_encoder, self.label_connector,
        self.rephrase_encoder, self.rephrase_decoder
    ])
    self.d_vars = collect_trainable_variables([
        self.clas_embedder, self.classifier_hidden, self.classifier_sentence
    ])

    self.train_op_g = get_train_op(self.loss_g, self.g_vars,
                                   hparams=self._hparams.opt)
    self.train_op_g_ae = get_train_op(self.loss_g_ae, self.g_vars,
                                      hparams=self._hparams.opt)
    self.train_op_d = get_train_op(self.loss_d, self.d_vars,
                                   hparams=self._hparams.opt)

    # Interface tensors
    self.losses = {
        "loss_g": self.loss_g,
        "loss_d": self.loss_d,
        "loss_g_ae": self.loss_g_ae,
        "loss_g_clas_hidden": self.loss_g_clas_hidden,
        "loss_g_clas_sentence": self.loss_g_clas_sentence,
        "loss_d_clas_hidden": self.loss_d_clas_hidden,
        "loss_d_clas_sentence": self.loss_d_clas_sentence,
    }
    self.metrics = {
        "accu_d_hidden": self.accu_d_hidden,
        "accu_d_sentence": self.accu_d_sentence,
        "accu_g_hidden": self.accu_g_hidden,
        "accu_g_sentence": self.accu_g_sentence,
        "accu_g_gdy_sentence": self.accu_g_gdy_sentence,
    }
    self.train_ops = {
        "train_op_g": self.train_op_g,
        "train_op_g_ae": self.train_op_g_ae,
        "train_op_d": self.train_op_d
    }
    self.samples = {
        "original": self.text_ids[:, 1:],
        "transferred": self.rephrase_outputs_.sample_id
    }

    self.fetches_train_d = {
        "loss_d": self.train_ops["train_op_d"],
        "loss_d_clas_hidden": self.losses["loss_d_clas_hidden"],
        "loss_d_clas_sentence": self.losses["loss_d_clas_sentence"],
        "accu_d_hidden": self.metrics["accu_d_hidden"],
        "accu_d_sentence": self.metrics["accu_d_sentence"],
    }

    tf.summary.scalar("loss_d", self.loss_d)
    tf.summary.scalar("loss_d_clas_hidden", self.loss_d_clas_hidden)
    tf.summary.scalar("loss_d_clas_sentence", self.loss_d_clas_sentence)
    tf.summary.scalar("accu_d_hidden", self.accu_d_hidden)
    tf.summary.scalar("accu_d_sentence", self.accu_d_sentence)
    tf.summary.scalar("loss_g", self.loss_g)
    tf.summary.scalar("loss_g_ae", self.loss_g_ae)
    tf.summary.scalar("loss_g_clas_hidden", self.loss_g_clas_hidden)
    tf.summary.scalar("loss_g_clas_sentence", self.loss_g_clas_sentence)
    tf.summary.scalar("accu_g_hidden", self.accu_g_hidden)
    tf.summary.scalar("accu_g_sentence", self.accu_g_sentence)
    tf.summary.scalar("accu_g_gdy_sentence", self.accu_g_gdy_sentence)
    self.merged = tf.summary.merge_all()

    self.fetches_train_g = {
        "loss_g": self.train_ops["train_op_g"],
        "loss_g_ae": self.losses["loss_g_ae"],
        "loss_g_clas_hidden": self.losses["loss_g_clas_hidden"],
        "loss_g_clas_sentence": self.losses["loss_g_clas_sentence"],
        "accu_g_hidden": self.metrics["accu_g_hidden"],
        "accu_g_sentence": self.metrics["accu_g_sentence"],
        "accu_g_gdy_sentence": self.metrics["accu_g_gdy_sentence"],
        "merged": self.merged,
    }

    fetches_eval = {
        "batch_size": get_batch_size(self.text_ids),
        "merged": self.merged,
    }
    fetches_eval.update(self.losses)
    fetches_eval.update(self.metrics)
    fetches_eval.update(self.samples)
    self.fetches_eval = fetches_eval
def build_model(data_batch, data, step): batch_size, num_steps = [ tf.shape(data_batch["x_value_text_ids"])[d] for d in range(2) ] vocab = data.vocab("y_aux") id2str = "<{}>".format bos_str, eos_str = map(id2str, (vocab.bos_token_id, vocab.eos_token_id)) def single_bleu(ref, hypo): ref = [id2str(u if u != vocab.unk_token_id else -1) for u in ref] hypo = [id2str(u) for u in hypo] ref = tx.utils.strip_special_tokens(" ".join(ref), strip_bos=bos_str, strip_eos=eos_str) hypo = tx.utils.strip_special_tokens(" ".join(hypo), strip_eos=eos_str) return 0.01 * tx.evals.sentence_bleu(references=[ref], hypothesis=hypo) def batch_bleu(refs, hypos): return np.array( [single_bleu(ref, hypo) for ref, hypo in zip(refs, hypos)], dtype=np.float32, ) def lambda_anneal(step_stage): print("==========step_stage is {}".format(step_stage)) if step_stage <= 1: rec_weight = 1 elif step_stage > 1 and step_stage < 2: rec_weight = Config.rec_w - step_stage * 0.1 return np.array(rec_weight, dtype=tf.float32) # losses losses = {} # embedders embedders = { name: tx.modules.WordEmbedder(vocab_size=data.vocab(name).size, hparams=hparams) for name, hparams in Config.config_model.embedders.items() } # encoders y_encoder = tx.modules.BidirectionalRNNEncoder( hparams=Config.config_model.y_encoder) x_encoder = tx.modules.BidirectionalRNNEncoder( hparams=Config.config_model.x_encoder) def concat_encoder_outputs(outputs): return tf.concat(outputs, -1) def encode(ref_flag): y_str = y_strs[ref_flag] sent_ids = data_batch["{}_text_ids".format(y_str)] sent_embeds = embedders["y_aux"](sent_ids) sent_sequence_length = data_batch["{}_length".format(y_str)] sent_enc_outputs, _ = y_encoder(sent_embeds, sequence_length=sent_sequence_length) sent_enc_outputs = concat_encoder_outputs(sent_enc_outputs) x_str = x_strs[ref_flag] sd_ids = { field: data_batch["{}_{}_text_ids".format(x_str, field)][:, 1:-1] for field in x_fields } sd_embeds = tf.concat( [ embedders["x_{}".format(field)](sd_ids[field]) for field in x_fields ], axis=-1, ) sd_sequence_length = ( data_batch["{}_{}_length".format(x_str, x_fields[0])] - 2) sd_enc_outputs, _ = x_encoder(sd_embeds, sequence_length=sd_sequence_length) sd_enc_outputs = concat_encoder_outputs(sd_enc_outputs) return ( sent_ids, sent_embeds, sent_enc_outputs, sent_sequence_length, sd_ids, sd_embeds, sd_enc_outputs, sd_sequence_length, ) encode_results = [encode(ref_str) for ref_str in range(2)] ( sent_ids, sent_embeds, sent_enc_outputs, sent_sequence_length, sd_ids, sd_embeds, sd_enc_outputs, sd_sequence_length, ) = zip(*encode_results) # get rnn cell rnn_cell = tx.core.layers.get_rnn_cell(Config.config_model.rnn_cell) def get_decoder(cell, y__ref_flag, x_ref_flag, tgt_ref_flag, beam_width=None): output_layer_params = ({ "output_layer": tf.identity } if Config.copy_flag else { "vocab_size": vocab.size }) if Config.attn_flag: # attention if Config.attn_x and Config.attn_y_: memory = tf.concat( [ sent_enc_outputs[y__ref_flag], sd_enc_outputs[x_ref_flag] ], axis=1, ) memory_sequence_length = None elif Config.attn_y_: memory = sent_enc_outputs[y__ref_flag] memory_sequence_length = sent_sequence_length[y__ref_flag] elif Config.attn_x: memory = sd_enc_outputs[x_ref_flag] memory_sequence_length = sd_sequence_length[x_ref_flag] else: raise Exception( "Must specify either y__ref_flag or x_ref_flag.") attention_decoder = tx.modules.AttentionRNNDecoder( cell=cell, memory=memory, memory_sequence_length=memory_sequence_length, hparams=Config.config_model.attention_decoder, **output_layer_params) if not Config.copy_flag: 
return attention_decoder cell = (attention_decoder.cell if beam_width is None else attention_decoder._get_beam_search_cell(beam_width)) if Config.copy_flag: # copynet kwargs = { "y__ids": sent_ids[y__ref_flag][:, 1:], "y__states": sent_enc_outputs[y__ref_flag][:, 1:], "y__lengths": sent_sequence_length[y__ref_flag] - 1, "x_ids": sd_ids[x_ref_flag]["value"], "x_states": sd_enc_outputs[x_ref_flag], "x_lengths": sd_sequence_length[x_ref_flag], } if tgt_ref_flag is not None: kwargs.update({ "input_ids": data_batch["{}_text_ids".format( y_strs[tgt_ref_flag])][:, :-1] }) memory_prefixes = [] if Config.copy_y_: memory_prefixes.append("y_") if Config.copy_x: memory_prefixes.append("x") if beam_width is not None: kwargs = { name: tile_batch(value, beam_width) for name, value in kwargs.items() } def get_get_copy_scores(memory_ids_states_lengths, output_size): memory_copy_states = [ tf.layers.dense( memory_states, units=output_size, activation=None, use_bias=False, ) for _, memory_states, _ in memory_ids_states_lengths ] def get_copy_scores(query, coverities=None): ret = [] if Config.copy_y_: memory = memory_copy_states[len(ret)] if coverities is not None: memory = memory + tf.layers.dense( coverities[len(ret)], units=output_size, activation=None, use_bias=False, ) memory = tf.nn.tanh(memory) ret_y_ = tf.einsum("bim,bm->bi", memory, query) ret.append(ret_y_) if Config.copy_x: memory = memory_copy_states[len(ret)] if coverities is not None: memory = memory + tf.layers.dense( coverities[len(ret)], units=output_size, activation=None, use_bias=False, ) memory = tf.nn.tanh(memory) ret_x = tf.einsum("bim,bm->bi", memory, query) ret.append(ret_x) return ret return get_copy_scores covrity_dim = (Config.config_model.coverage_state_dim if Config.coverage else None) coverity_rnn_cell_hparams = (Config.config_model.coverage_rnn_cell if Config.coverage else None) cell = CopyNetWrapper( cell=cell, vocab_size=vocab.size, memory_ids_states_lengths=[ tuple(kwargs["{}_{}".format(prefix, s)] for s in ("ids", "states", "lengths")) for prefix in memory_prefixes ], input_ids=kwargs["input_ids"] if tgt_ref_flag is not None else None, get_get_copy_scores=get_get_copy_scores, coverity_dim=covrity_dim, coverity_rnn_cell_hparams=coverity_rnn_cell_hparams, disabled_vocab_size=Config.disabled_vocab_size, eps=Config.eps, ) decoder = tx.modules.BasicRNNDecoder( cell=cell, hparams=Config.config_model.decoder, **output_layer_params) return decoder def get_decoder_and_outputs(cell, y__ref_flag, x_ref_flag, tgt_ref_flag, params, beam_width=None): decoder = get_decoder(cell, y__ref_flag, x_ref_flag, tgt_ref_flag, beam_width=beam_width) if beam_width is None: ret = decoder(**params) else: ret = tx.modules.beam_search_decode(decoder_or_cell=decoder, beam_width=beam_width, **params) return (decoder, ) + ret get_decoder_and_outputs = tf.make_template("get_decoder_and_outputs", get_decoder_and_outputs) def teacher_forcing(cell, y__ref_flag, x_ref_flag, loss_name): tgt_ref_flag = x_ref_flag tgt_str = y_strs[tgt_ref_flag] sequence_length = data_batch["{}_length".format(tgt_str)] - 1 decoder, tf_outputs, final_state, _ = get_decoder_and_outputs( cell, y__ref_flag, x_ref_flag, tgt_ref_flag, { "decoding_strategy": "train_greedy", "inputs": sent_embeds[tgt_ref_flag], "sequence_length": sequence_length, }, ) tgt_sent_ids = data_batch["{}_text_ids".format(tgt_str)][:, 1:] loss = tx.losses.sequence_sparse_softmax_cross_entropy( labels=tgt_sent_ids, logits=tf_outputs.logits, sequence_length=sequence_length, average_across_batch=False, ) if 
(Config.add_bleu_weight and y__ref_flag is not None and tgt_ref_flag is not None and y__ref_flag != tgt_ref_flag): w = tf.py_func( batch_bleu, [sent_ids[y__ref_flag], tgt_sent_ids], tf.float32, stateful=False, name="W_BLEU", ) w.set_shape(loss.get_shape()) loss = w * loss loss = tf.reduce_mean(loss, 0) if Config.copy_flag and Config.exact_cover_w != 0: sum_copy_probs = list( map(lambda t: tf.cast(t, tf.float32), final_state.sum_copy_probs)) memory_lengths = [ lengths for _, _, lengths in decoder.cell.memory_ids_states_lengths ] exact_coverage_losses = [ tf.reduce_mean( tf.reduce_sum( tx.utils.mask_sequences(tf.square(sum_copy_prob - 1.0), memory_length), 1, )) for sum_copy_prob, memory_length in zip( sum_copy_probs, memory_lengths) ] print_xe_loss_op = tf.print(loss_name, "xe loss:", loss) with tf.control_dependencies([print_xe_loss_op]): for i, exact_coverage_loss in enumerate(exact_coverage_losses): print_op = tf.print( loss_name, "exact coverage loss {:d}:".format(i), exact_coverage_loss, ) with tf.control_dependencies([print_op]): loss += Config.exact_cover_w * exact_coverage_loss losses[loss_name] = loss return decoder, tf_outputs, loss def beam_searching(cell, y__ref_flag, x_ref_flag, beam_width): start_tokens = (tf.ones_like(data_batch["y_aux_length"]) * vocab.bos_token_id) end_token = vocab.eos_token_id decoder, bs_outputs, _, _ = get_decoder_and_outputs( cell, y__ref_flag, x_ref_flag, None, { "embedding": embedders["y_aux"], "start_tokens": start_tokens, "end_token": end_token, "max_decoding_length": Config.config_train.infer_max_decoding_length, }, beam_width=Config.config_train.infer_beam_width, ) return decoder, bs_outputs decoder, tf_outputs, loss = teacher_forcing(rnn_cell, 1, 0, "MLE") rec_decoder, _, rec_loss = teacher_forcing(rnn_cell, 1, 1, "REC") rec_weight = Config.rec_w step_stage = tf.cast(step, tf.float32) / tf.constant(800.0) rec_weight = tf.case( [ ( tf.less_equal(step_stage, tf.constant(1.0)), lambda: tf.constant(1.0), ), (tf.greater(step_stage, tf.constant(2.0)), lambda: Config.rec_w), ], default=lambda: tf.constant(1.0) - (step_stage - 1) * (1 - Config.rec_w), ) joint_loss = (1 - rec_weight) * loss + rec_weight * rec_loss losses["joint"] = joint_loss tiled_decoder, bs_outputs = beam_searching( rnn_cell, 1, 0, Config.config_train.infer_beam_width) train_ops = { name: get_train_op(losses[name], hparams=Config.config_train.train[name]) for name in Config.config_train.train } return train_ops, bs_outputs
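# NOTE (editor's sketch, not used by the model): the piecewise rec_weight
# schedule built with tf.case above, written as a plain-Python function for
# readability. The 800-step stage length and Config.rec_w are the values
# used in the code above; the function itself is only illustrative.
def _rec_weight_schedule(step, rec_w, steps_per_stage=800.0):
    stage = step / steps_per_stage
    if stage <= 1.0:
        return 1.0                                # stage 1: only the REC loss is used
    if stage > 2.0:
        return rec_w                              # final weight
    return 1.0 - (stage - 1.0) * (1.0 - rec_w)    # linear anneal from 1.0 down to rec_w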
def _build_model(self, inputs, vocab, gamma, lambda_g):
    """Builds the model.
    """
    embedder = WordEmbedder(
        vocab_size=vocab.size, hparams=self._hparams.embedder)
    encoder = UnidirectionalRNNEncoder(hparams=self._hparams.encoder)

    # text_ids for encoder, with BOS token removed
    enc_text_ids = inputs['text_ids'][:, 1:]
    enc_outputs, final_state = encoder(embedder(enc_text_ids),
                                       sequence_length=inputs['length'] - 1)
    z = final_state[:, self._hparams.dim_c:]

    # Encodes label
    label_connector = MLPTransformConnector(self._hparams.dim_c)

    # Gets the sentence representation: h = (c, z)
    labels0 = tf.to_float(tf.reshape(inputs['labels0'], [-1, 1]))
    labels1 = tf.to_float(tf.reshape(inputs['labels1'], [-1, 1]))
    labels2 = tf.to_float(tf.reshape(inputs['labels2'], [-1, 1]))
    labels3 = tf.to_float(tf.reshape(inputs['labels3'], [-1, 1]))
    labels = tf.concat([labels0, labels1, labels2, labels3], axis=1)

    c = label_connector(labels)
    c_ = label_connector(1 - labels)
    h = tf.concat([c, z], 1)
    h_ = tf.concat([c_, z], 1)

    # Teacher-force decoding and the auto-encoding loss for G
    decoder = AttentionRNNDecoder(
        memory=enc_outputs,
        memory_sequence_length=inputs['length'] - 1,
        cell_input_fn=lambda inputs, attention: inputs,
        vocab_size=vocab.size,
        hparams=self._hparams.decoder)
    connector = MLPTransformConnector(decoder.state_size)

    g_outputs, _, _ = decoder(
        initial_state=connector(h), inputs=inputs['text_ids'],
        embedding=embedder, sequence_length=inputs['length'] - 1)

    loss_g_ae = tx.losses.sequence_sparse_softmax_cross_entropy(
        labels=inputs['text_ids'][:, 1:],
        logits=g_outputs.logits,
        sequence_length=inputs['length'] - 1,
        average_across_timesteps=True,
        sum_over_timesteps=False)

    # Gumbel-softmax decoding, used in training
    start_tokens = tf.ones_like(inputs['labels0']) * vocab.bos_token_id
    end_token = vocab.eos_token_id
    gumbel_helper = GumbelSoftmaxEmbeddingHelper(
        embedder.embedding, start_tokens, end_token, gamma)

    soft_outputs_, _, soft_length_ = decoder(
        helper=gumbel_helper, initial_state=connector(h_))

    # Greedy decoding, used in eval
    outputs_, _, length_ = decoder(
        decoding_strategy='infer_greedy', initial_state=connector(h_),
        embedding=embedder, start_tokens=start_tokens, end_token=end_token)

    # Creates classifiers
    classifier0 = Conv1DClassifier(hparams=self._hparams.classifier)
    classifier1 = Conv1DClassifier(hparams=self._hparams.classifier)
    classifier2 = Conv1DClassifier(hparams=self._hparams.classifier)
    classifier3 = Conv1DClassifier(hparams=self._hparams.classifier)
    clas_embedder = WordEmbedder(vocab_size=vocab.size,
                                 hparams=self._hparams.embedder)

    # Classification loss for the classifiers
    clas_logits, clas_preds = self._high_level_classifier(
        [classifier0, classifier1, classifier2, classifier3],
        clas_embedder, inputs, vocab, gamma, lambda_g,
        inputs['text_ids'][:, 1:], None, inputs['length'] - 1)
    loss_d_clas = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.to_float(labels), logits=clas_logits)
    loss_d_clas = tf.reduce_mean(loss_d_clas)
    accu_d = tx.evals.accuracy(labels, preds=clas_preds)

    # Classification loss for the generator, based on soft samples
    soft_logits, soft_preds = self._high_level_classifier(
        [classifier0, classifier1, classifier2, classifier3],
        clas_embedder, inputs, vocab, gamma, lambda_g,
        None, soft_outputs_.sample_id, soft_length_)
    loss_g_clas = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.to_float(1 - labels), logits=soft_logits)
    loss_g_clas = tf.reduce_mean(loss_g_clas)

    # Accuracy on soft samples, for training progress monitoring
    accu_g = tx.evals.accuracy(labels=1 - labels, preds=soft_preds)

    # Accuracy on greedy-decoded samples, for training progress monitoring
    _, gdy_preds = self._high_level_classifier(
        [classifier0, classifier1, classifier2, classifier3],
        clas_embedder, inputs, vocab, gamma, lambda_g,
        outputs_.sample_id, None, length_)
    accu_g_gdy = tx.evals.accuracy(labels=1 - labels, preds=gdy_preds)

    # Aggregates losses
    loss_g = loss_g_ae + lambda_g * loss_g_clas
    loss_d = loss_d_clas

    # Creates optimizers
    g_vars = collect_trainable_variables(
        [embedder, encoder, label_connector, connector, decoder])
    d_vars = collect_trainable_variables(
        [clas_embedder, classifier0, classifier1, classifier2, classifier3])

    train_op_g = get_train_op(loss_g, g_vars, hparams=self._hparams.opt)
    train_op_g_ae = get_train_op(loss_g_ae, g_vars, hparams=self._hparams.opt)
    train_op_d = get_train_op(loss_d, d_vars, hparams=self._hparams.opt)

    # Interface tensors
    self.predictions = {
        "predictions": clas_preds,
        "ground_truth": labels
    }
    self.losses = {
        "loss_g": loss_g,
        "loss_g_ae": loss_g_ae,
        "loss_g_clas": loss_g_clas,
        "loss_d": loss_d_clas
    }
    self.metrics = {
        "accu_d": accu_d,
        "accu_g": accu_g,
        "accu_g_gdy": accu_g_gdy,
    }
    self.train_ops = {
        "train_op_g": train_op_g,
        "train_op_g_ae": train_op_g_ae,
        "train_op_d": train_op_d
    }
    self.samples = {
        "original": inputs['text_ids'][:, 1:],
        "transferred": outputs_.sample_id
    }

    self.fetches_train_g = {
        "loss_g": self.train_ops["train_op_g"],
        "loss_g_ae": self.losses["loss_g_ae"],
        "loss_g_clas": self.losses["loss_g_clas"],
        "accu_g": self.metrics["accu_g"],
        "accu_g_gdy": self.metrics["accu_g_gdy"],
    }
    self.fetches_train_d = {
        "loss_d": self.train_ops["train_op_d"],
        "accu_d": self.metrics["accu_d"]
    }

    fetches_eval = {"batch_size": get_batch_size(inputs['text_ids'])}
    fetches_eval.update(self.losses)
    fetches_eval.update(self.metrics)
    fetches_eval.update(self.samples)
    fetches_eval.update(self.predictions)
    self.fetches_eval = fetches_eval
def _build_model(self, inputs, vocab, finputs, minputs, gamma):
    """Builds the model.
    """
    self.inputs = inputs
    self.finputs = finputs
    self.minputs = minputs
    self.vocab = vocab

    self.embedder = WordEmbedder(vocab_size=self.vocab.size,
                                 hparams=self._hparams.embedder)
    # Maybe try BidirectionalLSTMEncoder later.
    self.encoder = UnidirectionalRNNEncoder(
        hparams=self._hparams.encoder)  # GRU cell

    # text_ids for encoder, with BOS (begin-of-sentence) token removed
    self.enc_text_ids = self.inputs['text_ids'][:, 1:]
    self.enc_outputs, self.final_state = self.encoder(
        self.embedder(self.enc_text_ids),
        sequence_length=self.inputs['length'] - 1)
    h = self.final_state

    # Teacher-force decoding and the auto-encoding loss for G
    self.decoder = AttentionRNNDecoder(
        memory=self.enc_outputs,
        memory_sequence_length=self.inputs['length'] - 1,
        # Default is `lambda inputs, attention: tf.concat([inputs, attention], -1)`,
        # which concats the regular RNN cell inputs with the attention.
        cell_input_fn=lambda inputs, attention: inputs,
        vocab_size=self.vocab.size,
        hparams=self._hparams.decoder)
    self.connector = MLPTransformConnector(self.decoder.state_size)

    self.g_outputs, _, _ = self.decoder(
        initial_state=self.connector(h),
        inputs=self.inputs['text_ids'],
        embedding=self.embedder,
        sequence_length=self.inputs['length'] - 1)

    self.loss_g_ae = tx.losses.sequence_sparse_softmax_cross_entropy(
        labels=self.inputs['text_ids'][:, 1:],
        logits=self.g_outputs.logits,
        sequence_length=self.inputs['length'] - 1,
        average_across_timesteps=True,
        sum_over_timesteps=False)

    # Greedy decoding, used in eval (and RL training)
    start_tokens = tf.ones_like(
        self.inputs['labels']) * self.vocab.bos_token_id
    end_token = self.vocab.eos_token_id
    self.outputs, _, length = self.decoder(
        # Could later try 'infer_sample' here and compare the results.
        decoding_strategy='infer_greedy',
        initial_state=self.connector(h),
        embedding=self.embedder,
        start_tokens=start_tokens,
        end_token=end_token)

    # Creates optimizers
    self.g_vars = collect_trainable_variables(
        [self.embedder, self.encoder, self.connector, self.decoder])
    self.train_op_g_ae = get_train_op(self.loss_g_ae, self.g_vars,
                                      hparams=self._hparams.opt)

    # Interface tensors
    self.samples = {
        "batch_size": get_batch_size(self.inputs['text_ids']),
        "original": self.inputs['text_ids'][:, 1:],
        "transferred": self.outputs.sample_id  # results of infer_greedy decoding
    }

    ############################ female sentiment regression model
    # Only a ConvNet is used for now and its effect is unverified; an RNN
    # could be tried later to check regression accuracy, or both could be
    # combined (concatenated into one vector).
    self.fconvnet = Conv1DNetwork(
        hparams=self._hparams.convnet)  # default input: [batch_size, time_steps, embedding_dim]
    self.freg_embedder = WordEmbedder(
        vocab_size=self.vocab.size,
        hparams=self._hparams.embedder)  # (64, 26, 100): output shape of clas_embedder(ids=inputs['text_ids'][:, 1:])
    self.fconv_output = self.fconvnet(
        inputs=self.freg_embedder(
            ids=self.finputs['text_ids'][:, 1:]))  # (64, 128); TODO: prepare finputs

    p = {"type": "Dense", "kwargs": {'units': 1}}
    self.fdense_layer = tx.core.layers.get_layer(hparams=p)
    self.freg_output = self.fdense_layer(inputs=self.fconv_output)

    # Alternative considered: also feed a GRU encoder's final state into the
    # regression head.
    # self.fenc_text_ids = self.finputs['text_ids'][:, 1:]
    # self.fencoder = UnidirectionalRNNEncoder(hparams=self._hparams.encoder)  # GRU cell
    # self.fenc_outputs, self.ffinal_state = self.fencoder(
    #     self.freg_embedder(self.fenc_text_ids),
    #     sequence_length=self.finputs['length'] - 1)
    # self.freg_output = self.fdense_layer(
    #     inputs=tf.concat([self.fconv_output, self.ffinal_state], -1))
    # self.freg_vars = collect_trainable_variables(
    #     [self.freg_embedder, self.fconvnet, self.fencoder, self.fdense_layer])

    self.fprediction = tf.reshape(self.freg_output, [-1])
    self.fground_truth = tf.to_float(self.finputs['labels'])
    # Per-example squared error, so individual samples can be re-weighted
    # later in the RL update.
    self.floss_reg_single = tf.pow(self.fprediction - self.fground_truth, 2)
    # Loss averaged over the batch.
    self.floss_reg_batch = tf.reduce_mean(self.floss_reg_single)

    self.freg_vars = collect_trainable_variables(
        [self.freg_embedder, self.fconvnet, self.fdense_layer])
    self.ftrain_op_d = get_train_op(self.floss_reg_batch, self.freg_vars,
                                    hparams=self._hparams.opt)

    self.freg_sample = {
        "fprediction": self.fprediction,
        "fground_truth": self.fground_truth,
        "fsent": self.finputs['text_ids'][:, 1:]
    }

    ############################ male sentiment regression model
    self.mconvnet = Conv1DNetwork(
        hparams=self._hparams.convnet)  # default input: [batch_size, time_steps, embedding_dim]
    self.mreg_embedder = WordEmbedder(
        vocab_size=self.vocab.size,
        hparams=self._hparams.embedder)  # (64, 26, 100): output shape of clas_embedder(ids=inputs['text_ids'][:, 1:])
    self.mconv_output = self.mconvnet(
        inputs=self.mreg_embedder(
            ids=self.minputs['text_ids'][:, 1:]))  # (64, 128)

    p = {"type": "Dense", "kwargs": {'units': 1}}
    self.mdense_layer = tx.core.layers.get_layer(hparams=p)
    self.mreg_output = self.mdense_layer(inputs=self.mconv_output)

    # Alternative considered: also feed a GRU encoder's final state into the
    # regression head.
    # self.menc_text_ids = self.minputs['text_ids'][:, 1:]
    # self.mencoder = UnidirectionalRNNEncoder(hparams=self._hparams.encoder)  # GRU cell
    # self.menc_outputs, self.mfinal_state = self.mencoder(
    #     self.mreg_embedder(self.menc_text_ids),
    #     sequence_length=self.minputs['length'] - 1)
    # self.mreg_output = self.mdense_layer(
    #     inputs=tf.concat([self.mconv_output, self.mfinal_state], -1))
    # self.mreg_vars = collect_trainable_variables(
    #     [self.mreg_embedder, self.mconvnet, self.mencoder, self.mdense_layer])

    self.mprediction = tf.reshape(self.mreg_output, [-1])
    self.mground_truth = tf.to_float(self.minputs['labels'])
    # Per-example squared error, so individual samples can be re-weighted
    # later in the RL update.
    self.mloss_reg_single = tf.pow(self.mprediction - self.mground_truth, 2)
    # Loss averaged over the batch.
    self.mloss_reg_batch = tf.reduce_mean(self.mloss_reg_single)

    self.mreg_vars = collect_trainable_variables(
        [self.mreg_embedder, self.mconvnet, self.mdense_layer])
    self.mtrain_op_d = get_train_op(self.mloss_reg_batch, self.mreg_vars,
                                    hparams=self._hparams.opt)

    self.mreg_sample = {
        "mprediction": self.mprediction,
        "mground_truth": self.mground_truth,
        "msent": self.minputs['text_ids'][:, 1:]
    }

    ###### get self.pre_dif when doing RL training (for transferred sents)
    ### pass to female regression model
    self.RL_fconv_output = self.fconvnet(
        inputs=self.freg_embedder(ids=self.outputs.sample_id))  # (64, 128)
    self.RL_freg_output = self.fdense_layer(inputs=self.RL_fconv_output)
    self.RL_fprediction = tf.reshape(self.RL_freg_output, [-1])

    ### pass to male regression model
    self.RL_mconv_output = self.mconvnet(
        inputs=self.mreg_embedder(ids=self.outputs.sample_id))  # (64, 128)
    self.RL_mreg_output = self.mdense_layer(inputs=self.RL_mconv_output)
    self.RL_mprediction = tf.reshape(self.RL_mreg_output, [-1])

    self.pre_dif = tf.abs(self.RL_fprediction - self.RL_mprediction)

    ###### get self.Ypre_dif for original sents
    ### pass to female regression model
    self.YRL_fconv_output = self.fconvnet(
        inputs=self.freg_embedder(
            ids=self.inputs['text_ids'][:, 1:]))  # (64, 128)
    self.YRL_freg_output = self.fdense_layer(inputs=self.YRL_fconv_output)
    self.YRL_fprediction = tf.reshape(self.YRL_freg_output, [-1])

    ### pass to male regression model
    self.YRL_mconv_output = self.mconvnet(
        inputs=self.mreg_embedder(
            ids=self.inputs['text_ids'][:, 1:]))  # (64, 128)
    self.YRL_mreg_output = self.mdense_layer(inputs=self.YRL_mconv_output)
    self.YRL_mprediction = tf.reshape(self.YRL_mreg_output, [-1])

    self.Ypre_dif = tf.abs(self.YRL_fprediction - self.YRL_mprediction)

    ######################## RL training
    # Earlier reward filters that were tried:
    # def fil(elem):
    #     return tf.where(elem > 1.3, tf.minimum(elem, 3), 0)
    # def fil_pushsmall(elem):
    #     return tf.add(tf.where(elem < 0.5, 1, 0),
    #                   tf.where(elem > 1.5, -0.5 * elem, 0))
    #
    # Filters that shrink the prediction gap:
    # def fil1(elem):
    #     return tf.where(elem < 0.5, 1.0, 0.0)
    # def fil2(elem):
    #     return tf.where(elem > 1.5, -0.5 * elem, 0.0)

    # Filters that enlarge the prediction gap.
    def fil1(elem):
        return tf.where(elem < 0.5, -0.01, 0.0)

    def fil2(elem):
        return tf.where(elem > 1.3, elem, 0.0)

    # Shape (batch_size, time_steps): the per-timestep loss of every sample
    # in the batch.
    self.beginning_loss_g_RL2 = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=self.outputs.sample_id,
        logits=self.outputs.logits)
    # Shape (batch_size,): per-sentence loss (summed over time steps, not
    # over the batch).
    self.middle_loss_g_RL2 = tf.reduce_sum(self.beginning_loss_g_RL2, axis=1)

    # Trivial "RL" training with all weights set to 1:
    # final_loss_g_RL2 = tf.reduce_sum(self.middle_loss_g_RL2)

    # RL training
    self.filtered = tf.add(tf.map_fn(fil1, self.pre_dif),
                           tf.map_fn(fil2, self.pre_dif))
    # No threshold has been set for the weight update yet.
    self.updated_loss_per_sent = tf.multiply(self.filtered,
                                             self.middle_loss_g_RL2)
    self.updated_loss_per_batch = tf.reduce_sum(self.updated_loss_per_sent)
    # NOTE: ideally each sentence's loss would be updated individually, but
    # train_updated raises an error in that case, so only the summed loss is
    # optimized; is that equivalent to updating each sentence's loss?

    self.vars_updated = collect_trainable_variables(
        [self.connector, self.decoder])
    self.train_updated = get_train_op(self.updated_loss_per_batch,
                                      self.vars_updated,
                                      hparams=self._hparams.opt)

    self.train_updated_interface = {
        "pre_dif": self.pre_dif,
        "updated_loss_per_sent": self.updated_loss_per_sent,
        "updated_loss_per_batch": self.updated_loss_per_batch,
    }

    ### Train AE and RL together
    self.loss_AERL = gamma * self.updated_loss_per_batch + self.loss_g_ae
    self.vars_AERL = collect_trainable_variables(
        [self.connector, self.decoder])
    self.train_AERL = get_train_op(self.loss_AERL, self.vars_AERL,
                                   hparams=self._hparams.opt)
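# NOTE (editor's sketch, not part of the model): the per-sentence RL weight
# implied by fil1/fil2 above, written as a scalar Python function of the
# absolute female/male prediction gap `pre_dif`. The thresholds 0.5 and 1.3
# are the ones hard-coded above.
def _rl_weight(pre_dif):
    if pre_dif < 0.5:
        return -0.01      # nearly equal predictions: small negative weight
    if pre_dif > 1.3:
        return pre_dif    # large gap: weight grows with the gap
    return 0.0            # intermediate gap: the sentence is not updated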
def _get_loss_train_op(self):
    # Aggregates losses
    self.loss_g = (self.loss_g_ae
                   + self.lambda_t_graph * self.loss_g_clas_graph
                   + self.lambda_t_sentence * self.loss_g_clas_sentence)
    # Possible ablations: SGT-I, CGT-I, SGT-CGT-I, c-clas-g-only, c-clas-s-only
    if self.ablation == 'c-clas-g-only':
        self.loss_d = self.loss_d_clas_graph
    elif self.ablation == 'c-clas-s-only':
        self.loss_d = self.loss_d_clas_sentence
    else:
        self.loss_d = self.loss_d_clas_graph + self.loss_d_clas_sentence

    # Creates optimizers
    self.g_vars = collect_trainable_variables([
        self.embedder, self.self_graph_encoder, self.label_connector,
        self.cross_graph_encoder, self.rephrase_encoder, self.rephrase_decoder
    ])
    self.d_vars = collect_trainable_variables([
        self.clas_embedder, self.classifier_graph, self.classifier_sentence
    ])

    self.train_op_g = get_train_op(self.loss_g, self.g_vars,
                                   hparams=self._hparams.opt)
    self.train_op_g_ae = get_train_op(self.loss_g_ae, self.g_vars,
                                      hparams=self._hparams.opt)
    self.train_op_d = get_train_op(self.loss_d, self.d_vars,
                                   hparams=self._hparams.opt)

    # Interface tensors
    self.losses = {
        "loss_g": self.loss_g,
        "loss_d": self.loss_d,
        "loss_g_ae": self.loss_g_ae,
        "loss_g_clas_graph": self.loss_g_clas_graph,
        "loss_g_clas_sentence": self.loss_g_clas_sentence,
        "loss_d_clas_graph": self.loss_d_clas_graph,
        "loss_d_clas_sentence": self.loss_d_clas_sentence,
    }
    self.metrics = {
        "accu_d_graph": self.accu_d_graph,
        "accu_d_sentence": self.accu_d_sentence,
        "accu_g_graph": self.accu_g_graph,
        "accu_g_sentence": self.accu_g_sentence,
        "accu_g_gdy_sentence": self.accu_g_gdy_sentence
    }
    self.train_ops = {
        "train_op_g": self.train_op_g,
        "train_op_g_ae": self.train_op_g_ae,
        "train_op_d": self.train_op_d
    }
    self.samples = {
        "original": self.text_ids[:, 1:],
        "transferred": self.rephrase_outputs_.sample_id
    }

    self.fetches_train_g = {
        "loss_g": self.train_ops["train_op_g"],
        "loss_g_ae": self.losses["loss_g_ae"],
        "loss_g_clas_graph": self.losses["loss_g_clas_graph"],
        "loss_g_clas_sentence": self.losses["loss_g_clas_sentence"],
        "accu_g_graph": self.metrics["accu_g_graph"],
        "accu_g_sentence": self.metrics["accu_g_sentence"],
        "accu_g_gdy_sentence": self.metrics["accu_g_gdy_sentence"]
        # 'adjs': self.adjs,
        # 'identities': self.identities
    }
    self.fetches_train_d = {
        "loss_d": self.train_ops["train_op_d"],
        "loss_d_clas_graph": self.losses["loss_d_clas_graph"],
        "loss_d_clas_sentence": self.losses["loss_d_clas_sentence"],
        "accu_d_graph": self.metrics["accu_d_graph"],
        "accu_d_sentence": self.metrics["accu_d_sentence"]
        # 'adjs': self.adjs,
        # 'identities': self.identities
    }

    fetches_eval = {"batch_size": get_batch_size(self.text_ids)}
    fetches_eval.update(self.losses)
    fetches_eval.update(self.metrics)
    fetches_eval.update(self.samples)
    self.fetches_eval = fetches_eval
def _get_loss_train_op(self):
    # Aggregates loss
    loss_rephraser = (tf.get_collection('loss_rephraser_list')[0] +
                      tf.get_collection('loss_rephraser_list')[1] +
                      tf.get_collection('loss_rephraser_list')[2]) / 3.
    w_recon = 1.0
    w_fine = 0.5
    w_xx2 = 0.0
    # check: adjust the term magnitudes step by step
    self.loss = (loss_rephraser + w_fine * self.loss_fine
                 + w_recon * self.loss_mask_recon)  # + w_xx2 * self.loss_xx2

    # Creates optimizers
    self.vars = collect_trainable_variables([
        self.transformer_encoder, self.word_embedder, self.self_graph_encoder,
        self.downmlp, self.PRelu, self.rephrase_encoder, self.rephrase_decoder
    ])

    # Train Op
    self.train_op_pre = get_train_op(
        self.loss, self.vars, hparams=self._hparams.opt)  # learning_rate=self.lr

    # Interface tensors
    self.losses = {
        "loss": self.loss,
        "loss_rephraser": loss_rephraser,
        "loss_fine": self.loss_fine,
        "loss_mask_recon": self.loss_mask_recon,
        "loss_xx2": self.loss_xx2,
    }
    self.metrics = {}
    self.train_ops = {
        "train_op_pre": self.train_op_pre,
    }
    self.samples = {
        "transferred_yy1_gt": self.text_ids_yy1,
        "transferred_yy1_pred": tf.get_collection('yy_pred_list')[0],
        "transferred_yy2_gt": self.text_ids_yy2,
        "transferred_yy2_pred": tf.get_collection('yy_pred_list')[1],
        "transferred_yy3_gt": self.text_ids_yy3,
        "transferred_yy3_pred": tf.get_collection('yy_pred_list')[2],
        "origin_y1": self.text_ids_y1,
        "origin_y2": self.text_ids_y2,
        "origin_y3": self.text_ids_y3,
        "x1x2": self.x1x2,
        "x1xx2": self.x1xx2
    }

    tf.summary.scalar("loss", self.loss)
    tf.summary.scalar("loss_rephraser", loss_rephraser)
    tf.summary.scalar("loss_fine", self.loss_fine)
    tf.summary.scalar("loss_mask_recon", self.loss_mask_recon)
    tf.summary.scalar("loss_xx2", self.loss_xx2)
    self.merged = tf.summary.merge_all()

    self.fetches_train_pre = {
        "loss": self.train_ops["train_op_pre"],
        "loss_rephraser": self.losses["loss_rephraser"],
        "loss_fine": self.losses["loss_fine"],
        "loss_mask_recon": self.losses["loss_mask_recon"],
        "loss_xx2": self.losses["loss_xx2"],
        "merged": self.merged,
    }

    fetches_eval = {
        "batch_size": get_batch_size(self.x1x2yx1xx2_ids),
        "merged": self.merged,
    }
    fetches_eval.update(self.losses)
    fetches_eval.update(self.metrics)
    fetches_eval.update(self.samples)
    self.fetches_eval = fetches_eval
def build_model(data_batch, data, step): batch_size, num_steps = [ tf.shape(data_batch["x_value_text_ids"])[d] for d in range(2) ] vocab = data.vocab('y_aux') id2str = '<{}>'.format bos_str, eos_str = map(id2str, (vocab.bos_token_id, vocab.eos_token_id)) def single_bleu(ref, hypo): ref = [id2str(u if u != vocab.unk_token_id else -1) for u in ref] hypo = [id2str(u) for u in hypo] ref = tx.utils.strip_special_tokens(' '.join(ref), strip_bos=bos_str, strip_eos=eos_str) hypo = tx.utils.strip_special_tokens(' '.join(hypo), strip_eos=eos_str) return 0.01 * tx.evals.sentence_bleu(references=[ref], hypothesis=hypo) # losses losses = {} # embedders embedders = { name: tx.modules.WordEmbedder(vocab_size=data.vocab(name).size, hparams=hparams) for name, hparams in config_model.embedders.items() } # encoders y_encoder = tx.modules.TransformerEncoder(hparams=config_model.y_encoder) x_encoder = tx.modules.TransformerEncoder(hparams=config_model.x_encoder) def concat_encoder_outputs(outputs): return tf.concat(outputs, -1) def encode(ref_flag): y_str = y_strs[ref_flag] y_ids = data_batch['{}_text_ids'.format(y_str)] y_embeds = embedders['y_aux'](y_ids) y_sequence_length = data_batch['{}_length'.format(y_str)] y_enc_outputs = y_encoder(y_embeds, sequence_length=y_sequence_length) y_enc_outputs = concat_encoder_outputs(y_enc_outputs) x_str = x_strs[ref_flag] x_ids = { field: data_batch['{}_{}_text_ids'.format(x_str, field)][:, 1:-1] for field in x_fields } x_embeds = tf.concat([ embedders['x_{}'.format(field)](x_ids[field]) for field in x_fields ], axis=-1) x_sequence_length = data_batch['{}_{}_length'.format( x_str, x_fields[0])] - 2 x_enc_outputs = x_encoder(x_embeds, sequence_length=x_sequence_length) x_enc_outputs = concat_encoder_outputs(x_enc_outputs) return y_ids, y_embeds, y_enc_outputs, y_sequence_length, \ x_ids, x_embeds, x_enc_outputs, x_sequence_length encode_results = [encode(ref_flag) for ref_flag in range(2)] y_ids, y_embeds, y_enc_outputs, y_sequence_length, \ x_ids, x_embeds, x_enc_outputs, x_sequence_length = \ zip(*encode_results) # get rnn cell # rnn_cell = tx.core.layers.get_rnn_cell(config_model.rnn_cell) def get_decoder(y__ref_flag, x_ref_flag, tgt_ref_flag, beam_width=None): output_layer_params = \ {'output_layer': tf.identity} if copy_flag else {'vocab_size': vocab.size} if attn_flag: # attention memory = tf.concat( [y_enc_outputs[y__ref_flag], x_enc_outputs[x_ref_flag]], axis=1) memory_sequence_length = None copy_memory_sequence_length = None tgt_embedding = tf.concat([ tf.zeros(shape=[1, embedders['y_aux'].dim]), embedders['y_aux'].embedding[1:, :] ], axis=0) decoder = tx.modules.TransformerCopyDecoder( embedding=tgt_embedding, hparams=config_model.decoder) return decoder def get_decoder_and_outputs(y__ref_flag, x_ref_flag, tgt_ref_flag, params, beam_width=None): decoder = get_decoder(y__ref_flag, x_ref_flag, tgt_ref_flag, beam_width=beam_width) if beam_width is None: ret = decoder(**params) else: ret = decoder(beam_width=beam_width, **params) return decoder, ret get_decoder_and_outputs = tf.make_template('get_decoder_and_outputs', get_decoder_and_outputs) gamma = tf.Variable(1, dtype=tf.float32, trainable=True) gamma = tf.exp(tf.log(gamma)) def teacher_forcing(y__ref_flag, x_ref_flag, loss_name): tgt_flag = x_ref_flag tgt_str = y_strs[tgt_flag] memory_sequence_length = tf.add(y_sequence_length[y__ref_flag] - 1, x_sequence_length[x_ref_flag]) sequence_length = data_batch['{}_length'.format(tgt_str)] - 1 memory = tf.concat( [y_enc_outputs[y__ref_flag], x_enc_outputs[x_ref_flag]], 
axis=1) # [64 61 384] decoder, rets = get_decoder_and_outputs( y__ref_flag, x_ref_flag, tgt_flag, { 'memory': memory, #print_mem, 'memory_sequence_length': memory_sequence_length, 'copy_memory': x_enc_outputs[x_ref_flag], 'copy_memory_sequence_length': x_sequence_length[x_ref_flag], 'source_ids': x_ids[x_ref_flag]['value'], #print_ids, # source_ids 'gamma': gamma, 'decoding_strategy': 'train_greedy', 'inputs': y_embeds[tgt_flag] [:, :-1, :], #[:, 1:, :], #target yence embeds (ignore <BOS>) 'alpha': config_model.alpha, 'sequence_length': sequence_length, 'mode': tf.estimator.ModeKeys.TRAIN }) tgt_y_ids = data_batch['{}_text_ids'.format( tgt_str)][:, 1:] # ground_truth ids (ignore <BOS>) tf_outputs = rets[0] gens = rets[2] loss = tx.losses.sequence_sparse_softmax_cross_entropy( labels=tgt_y_ids, logits=tf_outputs.logits, sequence_length=data_batch['{}_length'.format(tgt_str)] - 1) # average_across_timesteps=True, # sum_over_timesteps=False) # loss = tf.reduce_mean(loss, 0) if copy_flag and FLAGS.exact_cover_w != 0: # sum_copy_probs = list(map(lambda t: tf.cast(t, tf.float32), final_state.sum_copy_probs)) copy_probs = (1 - gens) * rets[1] sum_copy_probs = tf.reduce_sum(copy_probs, 1) # sum_copy_probs = tf.split(sum_copy_probs, tf.shape(sum_copy_probs)[0], axis=0)#list(map(lambda prob: tf.cast(prob, tf.float32), tuple(tf.reduce_sum(copy_probs, 1)))) #[batch_size, len_key] memory_lengths = x_sequence_length[ x_ref_flag] #[len for len in sd_sequence_length[x_ref_flag]] exact_coverage_loss = \ tf.reduce_mean(tf.reduce_sum( tx.utils.mask_sequences( tf.square(sum_copy_probs - 1.), memory_lengths), 1)) print_xe_loss_op = tf.print(loss_name, 'xe loss:', loss) with tf.control_dependencies([print_xe_loss_op]): print_op = tf.print(loss_name, 'exact coverage loss :', exact_coverage_loss) with tf.control_dependencies([print_op]): loss += FLAGS.exact_cover_w * exact_coverage_loss losses[loss_name] = loss return decoder, rets, loss, tgt_y_ids def beam_searching(y__ref_flag, x_ref_flag, beam_width): start_tokens = tf.ones_like(data_batch['y_aux_length']) * \ vocab.bos_token_id end_token = vocab.eos_token_id memory_sequence_length = tf.add(y_sequence_length[y__ref_flag] - 1, x_sequence_length[x_ref_flag]) sequence_length = data_batch['{}_length'.format( y_strs[y__ref_flag])] - 1 memory = tf.concat( [y_enc_outputs[y__ref_flag], x_enc_outputs[x_ref_flag]], axis=1) source_ids = tf.concat( [y_ids[y__ref_flag], x_ids[x_ref_flag]['value']], axis=1) #decoder, (bs_outputs, seq_len) decoder, bs_outputs = get_decoder_and_outputs( y__ref_flag, x_ref_flag, None, { 'memory': memory, #print_mem, 'memory_sequence_length': memory_sequence_length, 'copy_memory': x_enc_outputs[x_ref_flag], 'copy_memory_sequence_length': x_sequence_length[x_ref_flag], 'gamma': gamma, 'source_ids': x_ids[x_ref_flag] ['value'], # source_ids,#x_ids[x_ref_flag]['entry'], #[ batch_size, source_length] # 'decoding_strategy': 'infer_sample', only for random sampling 'alpha': config_model.alpha, 'start_tokens': start_tokens, 'end_token': end_token, 'max_decoding_length': config_train.infer_max_decoding_length }, beam_width=beam_width) return decoder, bs_outputs, sequence_length, start_tokens decoder, rets, loss, tgt_y_ids = teacher_forcing(1, 0, 'MLE') rec_decoder, _, rec_loss, _ = teacher_forcing(1, 1, 'REC') rec_weight = FLAGS.rec_w + FLAGS.rec_w_rate * tf.cast(step, tf.float32) step_stage = tf.cast(step, tf.float32) / tf.constant(600.0) rec_weight = tf.case([(tf.less_equal(step_stage, tf.constant(1.0)), lambda: tf.constant(1.0)), \ 
(tf.greater(step_stage, tf.constant(2.0)), lambda: FLAGS.rec_w)], \ default=lambda: tf.constant(1.0) - (step_stage - 1) * (1 - FLAGS.rec_w)) joint_loss = (1 - rec_weight) * loss + rec_weight * rec_loss losses['joint'] = joint_loss tiled_decoder, bs_outputs, sequence_length, start_tokens = beam_searching( 1, 0, config_train.infer_beam_width) train_ops = { name: get_train_op(losses[name], hparams=config_train.train[name]) for name in config_train.train } return train_ops, bs_outputs, rets, sequence_length, tgt_y_ids, start_tokens, gamma
def _get_loss_train_op(self):
    # Aggregates losses
    self.loss_g = (self.loss_g_ae
                   + self.lambda_g_graph * self.loss_g_clas_graph
                   + self.lambda_g_sentence * self.loss_g_clas_sentence
                   + self.loss_trans_adj + self.loss_ori_adj)
    self.loss_d = (self.loss_d_clas_graph + self.loss_d_clas_sentence
                   + self.loss_ori_adj)

    # Creates optimizers
    self.g_vars = collect_trainable_variables([
        self.embedder, self.self_graph_encoder, self.label_connector,
        self.cross_graph_encoder, self.rephrase_encoder, self.rephrase_decoder
    ])
    self.d_vars = collect_trainable_variables([
        self.clas_embedder, self.classifier_graph, self.classifier_sentence,
        self.adj_embedder, self.adj_encoder, self.conv1d_1, self.conv1d_2,
        self.bn1, self.conv1d_3, self.bn2, self.conv1d_4, self.bn3,
        self.conv1d_5
    ])

    self.train_op_g = get_train_op(self.loss_g, self.g_vars,
                                   hparams=self._hparams.opt)
    self.train_op_g_ae = get_train_op(self.loss_g_ae, self.g_vars,
                                      hparams=self._hparams.opt)
    self.train_op_d = get_train_op(self.loss_d, self.d_vars,
                                   hparams=self._hparams.opt)

    # Interface tensors
    self.losses = {
        "loss_g": self.loss_g,
        "loss_d": self.loss_d,
        "loss_g_ae": self.loss_g_ae,
        "loss_g_clas_graph": self.loss_g_clas_graph,
        "loss_g_clas_sentence": self.loss_g_clas_sentence,
        "loss_d_clas_graph": self.loss_d_clas_graph,
        "loss_d_clas_sentence": self.loss_d_clas_sentence,
        "loss_ori_adj": self.loss_ori_adj,
        "loss_trans_adj": self.loss_trans_adj,
    }
    self.metrics = {
        "accu_d_graph": self.accu_d_graph,
        "accu_d_sentence": self.accu_d_sentence,
        "accu_g_graph": self.accu_g_graph,
        "accu_g_sentence": self.accu_g_sentence,
        "accu_g_gdy_sentence": self.accu_g_gdy_sentence,
        "accu_ori_adj": self.accu_ori_adj,
        "accu_trans_adj": self.accu_trans_adj
    }
    self.train_ops = {
        "train_op_g": self.train_op_g,
        "train_op_g_ae": self.train_op_g_ae,
        "train_op_d": self.train_op_d
    }
    self.samples = {
        "original": self.text_ids[:, 1:],
        "transferred": self.rephrase_outputs_.sample_id
    }

    self.fetches_train_g = {
        "loss_g": self.train_ops["train_op_g"],
        "loss_g_ae": self.losses["loss_g_ae"],
        "loss_g_clas_graph": self.losses["loss_g_clas_graph"],
        "loss_g_clas_sentence": self.losses["loss_g_clas_sentence"],
        "loss_ori_adj": self.losses["loss_ori_adj"],
        "loss_trans_adj": self.losses["loss_trans_adj"],
        "accu_g_graph": self.metrics["accu_g_graph"],
        "accu_g_sentence": self.metrics["accu_g_sentence"],
        "accu_ori_adj": self.metrics["accu_ori_adj"],
        "accu_trans_adj": self.metrics["accu_trans_adj"],
        "adjs_truth": self.adjs[:, 1:, 1:],
        "adjs_preds": self.pred_trans_adjs_binary,
    }
    self.fetches_train_d = {
        "loss_d": self.train_ops["train_op_d"],
        "loss_d_clas_graph": self.losses["loss_d_clas_graph"],
        "loss_d_clas_sentence": self.losses["loss_d_clas_sentence"],
        "loss_ori_adj": self.losses["loss_ori_adj"],
        "accu_d_graph": self.metrics["accu_d_graph"],
        "accu_d_sentence": self.metrics["accu_d_sentence"],
        "accu_ori_adj": self.metrics["accu_ori_adj"],
        "adjs_truth": self.adjs[:, 1:, 1:],
        "adjs_preds": self.pred_ori_adjs_binary,
    }

    fetches_eval = {"batch_size": get_batch_size(self.text_ids)}
    fetches_eval.update(self.losses)
    fetches_eval.update(self.metrics)
    fetches_eval.update(self.samples)
    self.fetches_eval = fetches_eval
def build_model(data_batch, data, step): batch_size, num_steps = [ tf.shape(data_batch["x_value_text_ids"])[d] for d in range(2)] vocab = data.vocab('y_aux') id2str = '<{}>'.format bos_str, eos_str = map(id2str, (vocab.bos_token_id, vocab.eos_token_id)) def single_bleu(ref, hypo): ref = [id2str(u if u != vocab.unk_token_id else -1) for u in ref] hypo = [id2str(u) for u in hypo] ref = tx.utils.strip_special_tokens( ' '.join(ref), strip_bos=bos_str, strip_eos=eos_str) hypo = tx.utils.strip_special_tokens( ' '.join(hypo), strip_eos=eos_str) return 0.01 * tx.evals.sentence_bleu(references=[ref], hypothesis=hypo) def batch_bleu(refs, hypos): return np.array( [single_bleu(ref, hypo) for ref, hypo in zip(refs, hypos)], dtype=np.float32) def lambda_anneal(step_stage): print('==========step_stage is {}'.format(step_stage)) if step_stage <= 1: rec_weight = 1 elif step_stage > 1 and step_stage < 2: rec_weight = FLAGS.rec_w - step_stage * 0.1 return np.array(rec_weight, dtype = tf.float32) # losses losses = {} # embedders embedders = { name: tx.modules.WordEmbedder( vocab_size=data.vocab(name).size, hparams=hparams) for name, hparams in config_model.embedders.items()} # encoders y_encoder = tx.modules.BidirectionalRNNEncoder( hparams=config_model.y_encoder) x_encoder = tx.modules.BidirectionalRNNEncoder( hparams=config_model.x_encoder) def concat_encoder_outputs(outputs): return tf.concat(outputs, -1) def encode(ref_flag): y_str = y_strs[ref_flag] sent_ids = data_batch['{}_text_ids'.format(y_str)] sent_embeds = embedders['y_aux'](sent_ids) sent_sequence_length = data_batch['{}_length'.format(y_str)] sent_enc_outputs, _ = y_encoder( sent_embeds, sequence_length=sent_sequence_length) sent_enc_outputs = concat_encoder_outputs(sent_enc_outputs) x_str = x_strs[ref_flag] sd_ids = { field: data_batch['{}_{}_text_ids'.format(x_str, field)][:, 1:-1] for field in x_fields} sd_embeds = tf.concat( [embedders['x_{}'.format(field)](sd_ids[field]) for field in x_fields], axis=-1) sd_sequence_length = data_batch[ '{}_{}_length'.format(x_str, x_fields[0])] - 2 sd_enc_outputs, _ = x_encoder( sd_embeds, sequence_length=sd_sequence_length) sd_enc_outputs = concat_encoder_outputs(sd_enc_outputs) return sent_ids, sent_embeds, sent_enc_outputs, sent_sequence_length, \ sd_ids, sd_embeds, sd_enc_outputs, sd_sequence_length encode_results = [encode(ref_str) for ref_str in range(2)] sent_ids, sent_embeds, sent_enc_outputs, sent_sequence_length, \ sd_ids, sd_embeds, sd_enc_outputs, sd_sequence_length = \ zip(*encode_results) # get rnn cell rnn_cell = tx.core.layers.get_rnn_cell(config_model.rnn_cell) def get_decoder(cell, y__ref_flag, x_ref_flag, tgt_ref_flag, beam_width=None): output_layer_params = \ {'output_layer': tf.identity} if copy_flag else \ {'vocab_size': vocab.size} if attn_flag: # attention if FLAGS.attn_x and FLAGS.attn_y_: memory = tf.concat( [sent_enc_outputs[y__ref_flag], sd_enc_outputs[x_ref_flag]], axis=1) memory_sequence_length = None elif FLAGS.attn_y_: memory = sent_enc_outputs[y__ref_flag] memory_sequence_length = sent_sequence_length[y__ref_flag] elif FLAGS.attn_x: memory = sd_enc_outputs[x_ref_flag] memory_sequence_length = sd_sequence_length[x_ref_flag] else: raise Exception( "Must specify either y__ref_flag or x_ref_flag.") attention_decoder = tx.modules.AttentionRNNDecoder( cell=cell, memory=memory, memory_sequence_length=memory_sequence_length, hparams=config_model.attention_decoder, **output_layer_params) if not copy_flag: return attention_decoder cell = attention_decoder.cell if beam_width is None 
else \ attention_decoder._get_beam_search_cell(beam_width) if copy_flag: # copynet kwargs = { 'y__ids': sent_ids[y__ref_flag][:, 1:], 'y__states': sent_enc_outputs[y__ref_flag][:, 1:], 'y__lengths': sent_sequence_length[y__ref_flag] - 1, 'x_ids': sd_ids[x_ref_flag]['value'], 'x_states': sd_enc_outputs[x_ref_flag], 'x_lengths': sd_sequence_length[x_ref_flag], } if tgt_ref_flag is not None: kwargs.update({ 'input_ids': data_batch[ '{}_text_ids'.format(y_strs[tgt_ref_flag])][:, :-1]}) memory_prefixes = [] if FLAGS.copy_y_: memory_prefixes.append('y_') if FLAGS.copy_x: memory_prefixes.append('x') if beam_width is not None: kwargs = { name: tile_batch(value, beam_width) for name, value in kwargs.items()} def get_get_copy_scores(memory_ids_states_lengths, output_size): memory_copy_states = [ tf.layers.dense( memory_states, units=output_size, activation=None, use_bias=False) for _, memory_states, _ in memory_ids_states_lengths] def get_copy_scores(query, coverities=None): ret = [] if FLAGS.copy_y_: memory = memory_copy_states[len(ret)] if coverities is not None: memory = memory + tf.layers.dense( coverities[len(ret)], units=output_size, activation=None, use_bias=False) memory = tf.nn.tanh(memory) ret_y_ = tf.einsum("bim,bm->bi", memory, query) ret.append(ret_y_) if FLAGS.copy_x: memory = memory_copy_states[len(ret)] if coverities is not None: memory = memory + tf.layers.dense( coverities[len(ret)], units=output_size, activation=None, use_bias=False) memory = tf.nn.tanh(memory) ret_x = tf.einsum("bim,bm->bi", memory, query) ret.append(ret_x) if FLAGS.sd_path: ret_sd_path = FLAGS.sd_path_multiplicator * \ tf.einsum("bi,bij->bj", ret_x, match_align) \ + FLAGS.sd_path_addend ret.append(ret_sd_path) return ret return get_copy_scores cell = CopyNetWrapper( cell=cell, vocab_size=vocab.size, memory_ids_states_lengths=[ tuple(kwargs['{}_{}'.format(prefix, s)] for s in ('ids', 'states', 'lengths')) for prefix in memory_prefixes], input_ids= \ kwargs['input_ids'] if tgt_ref_flag is not None else None, get_get_copy_scores=get_get_copy_scores, coverity_dim=config_model.coverage_state_dim if FLAGS.coverage else None, coverity_rnn_cell_hparams=config_model.coverage_rnn_cell if FLAGS.coverage else None, disabled_vocab_size=FLAGS.disabled_vocab_size, eps=FLAGS.eps) decoder = tx.modules.BasicRNNDecoder( cell=cell, hparams=config_model.decoder, **output_layer_params) return decoder def get_decoder_and_outputs( cell, y__ref_flag, x_ref_flag, tgt_ref_flag, params, beam_width=None): decoder = get_decoder( cell, y__ref_flag, x_ref_flag, tgt_ref_flag, beam_width=beam_width) if beam_width is None: ret = decoder(**params) else: ret = tx.modules.beam_search_decode( decoder_or_cell=decoder, beam_width=beam_width, **params) return (decoder,) + ret get_decoder_and_outputs = tf.make_template( 'get_decoder_and_outputs', get_decoder_and_outputs) def teacher_forcing(cell, y__ref_flag, x_ref_flag, loss_name): tgt_ref_flag = x_ref_flag tgt_str = y_strs[tgt_ref_flag] sequence_length = data_batch['{}_length'.format(tgt_str)] - 1 decoder, tf_outputs, final_state, _ = get_decoder_and_outputs( cell, y__ref_flag, x_ref_flag, tgt_ref_flag, {'decoding_strategy': 'train_greedy', 'inputs': sent_embeds[tgt_ref_flag], 'sequence_length': sequence_length}) tgt_sent_ids = data_batch['{}_text_ids'.format(tgt_str)][:, 1:] loss = tx.losses.sequence_sparse_softmax_cross_entropy( labels=tgt_sent_ids, logits=tf_outputs.logits, sequence_length=sequence_length, average_across_batch=False) if FLAGS.add_bleu_weight and y__ref_flag is not None \ and 
tgt_ref_flag is not None and y__ref_flag != tgt_ref_flag: w = tf.py_func( batch_bleu, [sent_ids[y__ref_flag], tgt_sent_ids], tf.float32, stateful=False, name='W_BLEU') w.set_shape(loss.get_shape()) loss = w * loss loss = tf.reduce_mean(loss, 0) if copy_flag and FLAGS.exact_cover_w != 0: sum_copy_probs = list(map(lambda t: tf.cast(t, tf.float32), final_state.sum_copy_probs)) memory_lengths = [lengths for _, _, lengths in decoder.cell.memory_ids_states_lengths] exact_coverage_losses = [ tf.reduce_mean(tf.reduce_sum( tx.utils.mask_sequences( tf.square(sum_copy_prob - 1.), memory_length), 1)) for sum_copy_prob, memory_length in zip(sum_copy_probs, memory_lengths)] print_xe_loss_op = tf.print(loss_name, 'xe loss:', loss) with tf.control_dependencies([print_xe_loss_op]): for i, exact_coverage_loss in enumerate(exact_coverage_losses): print_op = tf.print(loss_name, 'exact coverage loss {:d}:'.format(i), exact_coverage_loss) with tf.control_dependencies([print_op]): # exact_cover_w = FLAGS.exact_cover_w + FLAGS.exact_cover_w * tf.cast(step, tf.float32) loss += FLAGS.exact_cover_w * exact_coverage_loss losses[loss_name] = loss return decoder, tf_outputs, loss def beam_searching(cell, y__ref_flag, x_ref_flag, beam_width): start_tokens = tf.ones_like(data_batch['y_aux_length']) * \ vocab.bos_token_id end_token = vocab.eos_token_id decoder, bs_outputs, _, _ = get_decoder_and_outputs( cell, y__ref_flag, x_ref_flag, None, {'embedding': embedders['y_aux'], 'start_tokens': start_tokens, 'end_token': end_token, 'max_decoding_length': config_train.infer_max_decoding_length}, beam_width=config_train.infer_beam_width) return decoder, bs_outputs def build_align(): ref_str = ref_strs[1] sent_str = 'y{}'.format(ref_str) sent_texts = data_batch['{}_text'.format(sent_str)][:, 1:-1] sent_ids = data_batch['{}_text_ids'.format(sent_str)][:, 1:-1] #TODO: Here we simply use the embedder previously constructed, #therefore it's shared. We have to construct a new one here if we'd #like to get align on the fly. 
sent_embeds = embedders['y_aux'](sent_ids) sent_sequence_length = data_batch['{}_length'.format(sent_str)] - 2 sent_enc_outputs, _ = y_encoder( sent_embeds, sequence_length=sent_sequence_length) sent_enc_outputs = concat_encoder_outputs(sent_enc_outputs) sd_field = x_fields[0] sd_str = 'x{}_{}'.format(ref_str, sd_field) sd_texts = data_batch['{}_text'.format(sd_str)][:, :-1] sd_ids = data_batch['{}_text_ids'.format(sd_str)] tgt_sd_ids = sd_ids[:, 1:] sd_ids = sd_ids[:, :-1] sd_sequence_length = data_batch['{}_length'.format(sd_str)] - 1 sd_embedder = embedders['x_'+sd_field] rnn_cell = tx.core.layers.get_rnn_cell(config_model.align_rnn_cell) attention_decoder = tx.modules.AttentionRNNDecoder( cell=rnn_cell, memory=sent_enc_outputs, memory_sequence_length=sent_sequence_length, vocab_size=vocab.size, hparams=config_model.align_attention_decoder) tf_outputs, _, tf_sequence_length = attention_decoder( decoding_strategy='train_greedy', inputs=sd_ids, embedding=sd_embedder, sequence_length=sd_sequence_length) loss = tx.losses.sequence_sparse_softmax_cross_entropy( labels=tgt_sd_ids, logits=tf_outputs.logits, sequence_length=sd_sequence_length) start_tokens = tf.ones_like(sd_sequence_length) * vocab.bos_token_id end_token = vocab.eos_token_id bs_outputs, _, _ = tx.modules.beam_search_decode( decoder_or_cell=attention_decoder, embedding=sd_embedder, start_tokens=start_tokens, end_token=end_token, max_decoding_length=config_train.infer_max_decoding_length, beam_width=config_train.infer_beam_width) return (sent_texts, sent_sequence_length), (sd_texts, sd_sequence_length), \ loss, tf_outputs, bs_outputs decoder, tf_outputs, loss = teacher_forcing(rnn_cell, 1, 0, 'MLE') rec_decoder, _, rec_loss = teacher_forcing(rnn_cell, 1, 1, 'REC') rec_weight = FLAGS.rec_w # rec_weight = tf.py_func( # lambda_anneal, [step_stage], # tf.float32, stateful=False, name='lambda_w') # rec_weight = tf.cond(step_stage < 1 ,) #rec_weight = rec_weight[0] #tf.Print('===========rec_w is {}'.format(rec_weight[0])) step_stage = tf.cast(step, tf.float32) / tf.constant(600.0) rec_weight = tf.case([(tf.less_equal(step_stage, tf.constant(1.0)), lambda:tf.constant(1.0)),\ (tf.greater(step_stage, tf.constant(2.0)), lambda:FLAGS.rec_w)],\ default=lambda:tf.constant(1.0) - (step_stage - 1) * (1 - FLAGS.rec_w)) joint_loss = (1 - rec_weight) * loss + rec_weight * rec_loss losses['joint'] = joint_loss tiled_decoder, bs_outputs = beam_searching( rnn_cell, 1, 0, config_train.infer_beam_width) align_sents, align_sds, align_loss, align_tf_outputs, align_bs_outputs = \ build_align() losses['align'] = align_loss train_ops = { name: get_train_op(losses[name], hparams=config_train.train[name]) for name in config_train.train} return train_ops, bs_outputs, \ align_sents, align_sds, align_tf_outputs, align_bs_outputs
def _build_model(self, inputs, vocab, gamma, lambda_g, lambda_z, lambda_z1,
                 lambda_z2, lambda_ae):
    embedder = WordEmbedder(vocab_size=vocab.size,
                            hparams=self._hparams.embedder)
    encoder = UnidirectionalRNNEncoder(hparams=self._hparams.encoder)

    # Encode the input sentence; the final state is split into a style part
    # (the first dim_c units) and a content part z (the rest).
    enc_text_ids = inputs['text_ids'][:, 1:]
    enc_outputs, final_state = encoder(embedder(enc_text_ids),
                                       sequence_length=inputs['length'] - 1)
    z = final_state[:, self._hparams.dim_c:]

    # -------------------- CLASSIFIER ---------------------
    # MLP classifier that predicts the label from the content vector z;
    # its loss enters the generator objective adversarially.
    n_classes = self._hparams.num_classes
    z_classifier_l1 = MLPTransformConnector(
        256, hparams=self._hparams.z_classifier_l1)
    z_classifier_l2 = MLPTransformConnector(
        64, hparams=self._hparams.z_classifier_l2)
    z_classifier_out = MLPTransformConnector(
        n_classes if n_classes > 2 else 1)

    z_logits = z_classifier_l1(z)
    z_logits = z_classifier_l2(z_logits)
    z_logits = z_classifier_out(z_logits)
    z_pred = tf.greater(z_logits, 0)
    z_logits = tf.reshape(z_logits, [-1])
    z_pred = tf.to_int64(tf.reshape(z_pred, [-1]))

    loss_z_clas = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.to_float(inputs['labels']), logits=z_logits)
    loss_z_clas = tf.reduce_mean(loss_z_clas)
    accu_z_clas = tx.evals.accuracy(labels=inputs['labels'], preds=z_pred)
    # ------------------------------------------------------

    label_connector = MLPTransformConnector(self._hparams.dim_c)

    labels = tf.to_float(tf.reshape(inputs['labels'], [-1, 1]))
    c = label_connector(labels)
    c_ = label_connector(1 - labels)
    h = tf.concat([c, z], 1)
    h_ = tf.concat([c_, z], 1)

    # Teacher-force decoding and the auto-encoding loss for G
    decoder = AttentionRNNDecoder(
        memory=enc_outputs,
        memory_sequence_length=inputs['length'] - 1,
        cell_input_fn=lambda inputs, attention: inputs,
        vocab_size=vocab.size,
        hparams=self._hparams.decoder)

    connector = MLPTransformConnector(decoder.state_size)

    g_outputs, _, _ = decoder(initial_state=connector(h),
                              inputs=inputs['text_ids'],
                              embedding=embedder,
                              sequence_length=inputs['length'] - 1)

    loss_g_ae = tx.losses.sequence_sparse_softmax_cross_entropy(
        labels=inputs['text_ids'][:, 1:],
        logits=g_outputs.logits,
        sequence_length=inputs['length'] - 1,
        average_across_timesteps=True,
        sum_over_timesteps=False)

    # Gumbel-softmax decoding, used in training
    start_tokens = tf.ones_like(inputs['labels']) * vocab.bos_token_id
    end_token = vocab.eos_token_id
    gumbel_helper = GumbelSoftmaxEmbeddingHelper(
        embedder.embedding, start_tokens, end_token, gamma)

    soft_outputs_, _, soft_length_ = decoder(helper=gumbel_helper,
                                             initial_state=connector(h_))
    soft_outputs, _, soft_length = decoder(helper=gumbel_helper,
                                           initial_state=connector(h))

    # ---------------------------- SHIFTED LOSS ----------------------------
    # Re-encode the soft samples and require the content vector of the
    # generated sentence to stay close (in cosine distance) to the original z.
    _, encoder_final_state_ = encoder(
        embedder(soft_ids=soft_outputs_.sample_id),
        sequence_length=inputs['length'] - 1)
    _, encoder_final_state = encoder(
        embedder(soft_ids=soft_outputs.sample_id),
        sequence_length=inputs['length'] - 1)
    new_z_ = encoder_final_state_[:, self._hparams.dim_c:]
    new_z = encoder_final_state[:, self._hparams.dim_c:]

    cos_distance_z_ = tf.abs(
        tf.losses.cosine_distance(tf.nn.l2_normalize(z, axis=1),
                                  tf.nn.l2_normalize(new_z_, axis=1),
                                  axis=1))
    cos_distance_z = tf.abs(
        tf.losses.cosine_distance(tf.nn.l2_normalize(z, axis=1),
                                  tf.nn.l2_normalize(new_z, axis=1),
                                  axis=1))
    # -----------------------------------------------------------------------

    # Greedy decoding, used in eval
    outputs_, _, length_ = decoder(decoding_strategy='infer_greedy',
                                   initial_state=connector(h_),
                                   embedding=embedder,
                                   start_tokens=start_tokens,
                                   end_token=end_token)

    # Creates classifier
    classifier = Conv1DClassifier(hparams=self._hparams.classifier)
    clas_embedder = WordEmbedder(vocab_size=vocab.size,
                                 hparams=self._hparams.embedder)

    # Classification loss for the classifier
    clas_logits, clas_preds = classifier(
        inputs=clas_embedder(ids=inputs['text_ids'][:, 1:]),
        sequence_length=inputs['length'] - 1)
    loss_d_clas = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.to_float(inputs['labels']), logits=clas_logits)
    loss_d_clas = tf.reduce_mean(loss_d_clas)
    accu_d = tx.evals.accuracy(labels=inputs['labels'], preds=clas_preds)

    # Classification loss for the generator, based on soft samples
    soft_logits, soft_preds = classifier(
        inputs=clas_embedder(soft_ids=soft_outputs_.sample_id),
        sequence_length=soft_length_)
    loss_g_clas = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.to_float(1 - inputs['labels']), logits=soft_logits)
    loss_g_clas = tf.reduce_mean(loss_g_clas)

    # Accuracy on soft samples, for training progress monitoring
    accu_g = tx.evals.accuracy(labels=1 - inputs['labels'], preds=soft_preds)

    # Accuracy on greedy-decoded samples, for training progress monitoring
    _, gdy_preds = classifier(inputs=clas_embedder(ids=outputs_.sample_id),
                              sequence_length=length_)
    accu_g_gdy = tx.evals.accuracy(labels=1 - inputs['labels'],
                                   preds=gdy_preds)

    # Aggregates losses
    loss_g = lambda_ae * loss_g_ae + \
        lambda_g * loss_g_clas + \
        lambda_z1 * cos_distance_z + lambda_z2 * cos_distance_z_ - \
        lambda_z * loss_z_clas
    loss_d = loss_d_clas
    loss_z = loss_z_clas

    # Creates optimizers
    g_vars = collect_trainable_variables(
        [embedder, encoder, label_connector, connector, decoder])
    d_vars = collect_trainable_variables([clas_embedder, classifier])
    z_vars = collect_trainable_variables(
        [z_classifier_l1, z_classifier_l2, z_classifier_out])

    train_op_g = get_train_op(loss_g, g_vars, hparams=self._hparams.opt)
    train_op_g_ae = get_train_op(loss_g_ae, g_vars, hparams=self._hparams.opt)
    train_op_d = get_train_op(loss_d, d_vars, hparams=self._hparams.opt)
    train_op_z = get_train_op(loss_z, z_vars, hparams=self._hparams.opt)

    # Interface tensors
    self.losses = {
        "loss_g": loss_g,
        "loss_g_ae": loss_g_ae,
        "loss_g_clas": loss_g_clas,
        "loss_d": loss_d_clas,
        "loss_z_clas": loss_z_clas,
        "loss_cos_": cos_distance_z_,
        "loss_cos": cos_distance_z
    }
    self.metrics = {
        "accu_d": accu_d,
        "accu_g": accu_g,
        "accu_g_gdy": accu_g_gdy,
        "accu_z_clas": accu_z_clas
    }
    self.train_ops = {
        "train_op_g": train_op_g,
        "train_op_g_ae": train_op_g_ae,
        "train_op_d": train_op_d,
        "train_op_z": train_op_z
    }
    self.samples = {
        "original": inputs['text_ids'][:, 1:],
        "transferred": outputs_.sample_id,
        "z_vector": z,
        "labels_source": inputs['labels'],
        "labels_target": 1 - inputs['labels'],
        "labels_predicted": gdy_preds
    }

    self.fetches_train_g = {
        "loss_g": self.train_ops["train_op_g"],
        "loss_g_ae": self.losses["loss_g_ae"],
        "loss_g_clas": self.losses["loss_g_clas"],
        "loss_shifted_ae1": self.losses["loss_cos"],
        "loss_shifted_ae2": self.losses["loss_cos_"],
        "accu_g": self.metrics["accu_g"],
        "accu_g_gdy": self.metrics["accu_g_gdy"],
        "accu_z_clas": self.metrics["accu_z_clas"]
    }
    self.fetches_train_z = {
        "loss_z": self.train_ops["train_op_z"],
        "accu_z": self.metrics["accu_z_clas"]
    }
    self.fetches_train_d = {
        "loss_d": self.train_ops["train_op_d"],
        "accu_d": self.metrics["accu_d"]
    }

    fetches_eval = {"batch_size": get_batch_size(inputs['text_ids'])}
    fetches_eval.update(self.losses)
    fetches_eval.update(self.metrics)
    fetches_eval.update(self.samples)
    self.fetches_eval = fetches_eval
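# --- Illustrative sketch (added for clarity; not from the original code) ---
# `_build_model` above only assembles the graph and the fetch dictionaries.
# A typical TF1-style training step would alternate discriminator,
# z-classifier, and generator updates using those fetches. The helper below
# is a minimal sketch under these assumptions: `model` is an instance of the
# class defining `_build_model`, `sess` is a tf.Session with variables
# initialized, and `feed_dict` already maps the data iterator outputs and the
# gamma / lambda placeholders; the function name and the d -> z -> g update
# order are hypothetical, not taken from the original training script.
def run_adversarial_step(sess, model, feed_dict):
    """Runs one discriminator, one z-classifier, and one generator update."""
    vals_d = sess.run(model.fetches_train_d, feed_dict=feed_dict)
    vals_z = sess.run(model.fetches_train_z, feed_dict=feed_dict)
    vals_g = sess.run(model.fetches_train_g, feed_dict=feed_dict)
    return {"d": vals_d, "z": vals_z, "g": vals_g}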
def build_model(data_batch, data, step):
    batch_size, num_steps = [
        tf.shape(data_batch["x_value_text_ids"])[d] for d in range(2)
    ]
    vocab = data.vocab('y_aux')

    id2str = '<{}>'.format
    bos_str, eos_str = map(id2str, (vocab.bos_token_id, vocab.eos_token_id))

    def single_bleu(ref, hypo):
        ref = [id2str(u if u != vocab.unk_token_id else -1) for u in ref]
        hypo = [id2str(u) for u in hypo]

        ref = tx.utils.strip_special_tokens(
            ' '.join(ref), strip_bos=bos_str, strip_eos=eos_str)
        hypo = tx.utils.strip_special_tokens(
            ' '.join(hypo), strip_eos=eos_str)

        return 0.01 * tx.evals.sentence_bleu(references=[ref],
                                             hypothesis=hypo)

    def batch_bleu(refs, hypos):
        return np.array(
            [single_bleu(ref, hypo) for ref, hypo in zip(refs, hypos)],
            dtype=np.float32)

    # losses
    losses = {}

    # embedders
    embedders = {
        name: tx.modules.WordEmbedder(vocab_size=data.vocab(name).size,
                                      hparams=hparams)
        for name, hparams in config_model.embedders.items()
    }

    # encoders
    y_encoder = tx.modules.BidirectionalRNNEncoder(
        hparams=config_model.y_encoder)
    x_encoder = tx.modules.BidirectionalRNNEncoder(
        hparams=config_model.x_encoder)

    def concat_encoder_outputs(outputs):
        return tf.concat(outputs, -1)

    def encode_y(ids, length):
        embeds = embedders['y_aux'](ids)
        enc_outputs, _ = y_encoder(embeds, sequence_length=length)
        enc_outputs = concat_encoder_outputs(enc_outputs)
        return Encoded(ids, embeds, enc_outputs, length)

    def encode_x(ids, length):
        embeds = tf.concat(
            [embedders['x_' + field](ids[field]) for field in x_fields],
            axis=-1)
        enc_outputs, _ = x_encoder(embeds, sequence_length=length)
        enc_outputs = concat_encoder_outputs(enc_outputs)
        return Encoded(ids, embeds, enc_outputs, length)

    y_encoded = [
        encode_y(data_batch['{}_text_ids'.format(ref_str)],
                 data_batch['{}_length'.format(ref_str)])
        for ref_str in y_strs
    ]
    x_encoded = [
        encode_x(
            {
                field: data_batch[
                    'x{}_{}_text_ids'.format(ref_str, field)][:, 1:-1]
                for field in x_fields
            },
            data_batch['x{}_{}_length'.format(ref_str, x_fields[0])] - 2)
        for ref_str in ref_strs
    ]

    # get rnn cell
    rnn_cell = tx.core.layers.get_rnn_cell(config_model.rnn_cell)

    def get_decoder(cell, y_, x, tgt_ref_flag, beam_width=None):
        output_layer_params = \
            {'output_layer': tf.identity} if copy_flag else \
            {'vocab_size': vocab.size}

        if attn_flag:  # attention
            if FLAGS.attn_x and FLAGS.attn_y_:
                memory = tf.concat([y_.enc_outputs, x.enc_outputs], axis=1)
                memory_sequence_length = None
            elif FLAGS.attn_y_:
                memory = y_.enc_outputs
                memory_sequence_length = y_.length
            elif FLAGS.attn_x:
                memory = x.enc_outputs
                memory_sequence_length = x.length
            attention_decoder = tx.modules.AttentionRNNDecoder(
                cell=cell,
                memory=memory,
                memory_sequence_length=memory_sequence_length,
                hparams=config_model.attention_decoder,
                **output_layer_params)
            if not copy_flag:
                return attention_decoder
            cell = attention_decoder.cell if beam_width is None else \
                attention_decoder._get_beam_search_cell(beam_width)

        if copy_flag:  # copynet
            kwargs = {
                'y__ids': y_.ids[:, 1:],
                'y__states': y_.enc_outputs[:, 1:],
                'y__lengths': y_.length - 1,
                'x_ids': x.ids['value'],
                'x_states': x.enc_outputs,
                'x_lengths': x.length,
            }

            if tgt_ref_flag is not None:
                kwargs.update({
                    'input_ids': data_batch['{}_text_ids'.format(
                        y_strs[tgt_ref_flag])][:, :-1]
                })

            memory_prefixes = []
            if FLAGS.copy_y_:
                memory_prefixes.append('y_')
            if FLAGS.copy_x:
                memory_prefixes.append('x')

            if beam_width is not None:
                kwargs = {
                    name: tile_batch(value, beam_width)
                    for name, value in kwargs.items()
                }

            def get_get_copy_scores(memory_ids_states_lengths, output_size):
                memory_copy_states = [
                    tf.layers.dense(memory_states,
                                    units=output_size,
                                    activation=None,
                                    use_bias=False)
                    for _, memory_states, _ in memory_ids_states_lengths
                ]

                def get_copy_scores(query, coverities=None):
                    ret = []

                    if FLAGS.copy_y_:
                        memory = memory_copy_states[len(ret)]
                        if coverities is not None:
                            # Add the coverage projection to the memory states
                            memory = memory + tf.layers.dense(
                                coverities[len(ret)],
                                units=output_size,
                                activation=None,
                                use_bias=False)
                        memory = tf.nn.tanh(memory)
                        ret_y_ = tf.einsum("bim,bm->bi", memory, query)
                        ret.append(ret_y_)

                    if FLAGS.copy_x:
                        memory = memory_copy_states[len(ret)]
                        if coverities is not None:
                            # Add the coverage projection to the memory states
                            memory = memory + tf.layers.dense(
                                coverities[len(ret)],
                                units=output_size,
                                activation=None,
                                use_bias=False)
                        memory = tf.nn.tanh(memory)
                        ret_x = tf.einsum("bim,bm->bi", memory, query)
                        ret.append(ret_x)

                    return ret

                return get_copy_scores

            cell = CopyNetWrapper(
                cell=cell,
                vocab_size=vocab.size,
                memory_ids_states_lengths=[
                    tuple(kwargs['{}_{}'.format(prefix, s)]
                          for s in ('ids', 'states', 'lengths'))
                    for prefix in memory_prefixes],
                input_ids=kwargs['input_ids']
                    if tgt_ref_flag is not None else None,
                get_get_copy_scores=get_get_copy_scores,
                coverity_dim=config_model.coverity_dim
                    if FLAGS.coverage else None,
                coverity_rnn_cell_hparams=config_model.coverity_rnn_cell
                    if FLAGS.coverage else None,
                disabled_vocab_size=FLAGS.disabled_vocab_size,
                eps=FLAGS.eps)

        decoder = tx.modules.BasicRNNDecoder(cell=cell,
                                             hparams=config_model.decoder,
                                             **output_layer_params)
        return decoder

    def get_decoder_and_outputs(cell, y_, x, tgt_ref_flag, params,
                                beam_width=None):
        decoder = get_decoder(cell, y_, x, tgt_ref_flag,
                              beam_width=beam_width)
        if beam_width is None:
            ret = decoder(**params)
        else:
            ret = tx.modules.beam_search_decode(decoder_or_cell=decoder,
                                                beam_width=beam_width,
                                                **params)
        return (decoder,) + ret

    get_decoder_and_outputs = tf.make_template('get_decoder_and_outputs',
                                               get_decoder_and_outputs)

    def teacher_forcing(cell, y_, x_ref_flag, loss_name,
                        add_bleu_weight=False):
        x = x_encoded[x_ref_flag]
        tgt_ref_flag = x_ref_flag
        tgt_str = y_strs[tgt_ref_flag]
        sequence_length = data_batch['{}_length'.format(tgt_str)] - 1

        decoder, tf_outputs, final_state, _ = get_decoder_and_outputs(
            cell, y_, x, tgt_ref_flag, {
                'decoding_strategy': 'train_greedy',
                'inputs': y_encoded[tgt_ref_flag].embeds,
                'sequence_length': sequence_length
            })

        tgt_y_ids = data_batch['{}_text_ids'.format(tgt_str)][:, 1:]
        loss = tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=tgt_y_ids,
            logits=tf_outputs.logits,
            sequence_length=sequence_length,
            average_across_batch=False)
        if add_bleu_weight:
            # Weight each example's loss by the BLEU score between the
            # reference sentence and the target sentence.
            w = tf.py_func(batch_bleu, [y_.ids, tgt_y_ids],
                           tf.float32, stateful=False, name='W_BLEU')
            w.set_shape(loss.get_shape())
            loss = w * loss
        loss = tf.reduce_mean(loss, 0)

        if copy_flag and FLAGS.exact_cover_w != 0:
            # Exact-coverage penalty (see the NumPy sketch after this
            # `build_model` for what it computes).
            sum_copy_probs = list(
                map(lambda t: tf.cast(t, tf.float32),
                    final_state.sum_copy_probs))
            memory_lengths = [
                lengths
                for _, _, lengths in decoder.cell.memory_ids_states_lengths
            ]
            exact_coverage_losses = [
                tf.reduce_mean(
                    tf.reduce_sum(
                        tx.utils.mask_sequences(
                            tf.square(sum_copy_prob - 1.), memory_length),
                        1))
                for sum_copy_prob, memory_length in zip(
                    sum_copy_probs, memory_lengths)
            ]
            for i, exact_coverage_loss in enumerate(exact_coverage_losses):
                loss += FLAGS.exact_cover_w * exact_coverage_loss

        losses[loss_name] = loss
        return decoder, tf_outputs, loss

    def beam_searching(cell, y_, x, beam_width):
        start_tokens = tf.ones_like(data_batch['y_aux_length']) * \
            vocab.bos_token_id
        end_token = vocab.eos_token_id

        decoder, bs_outputs, _, bs_length = get_decoder_and_outputs(
            cell, y_, x, None, {
                'embedding': embedders['y_aux'],
                'start_tokens': start_tokens,
                'end_token': end_token,
                'max_decoding_length': config_train.infer_max_decoding_length
            },
            beam_width=beam_width)

        return decoder, bs_outputs, bs_length

    def greedy_decoding(cell, y_, x):
        decoder, bs_outputs, bs_length = beam_searching(cell, y_, x, 1)
        return bs_outputs.predicted_ids[:, :, 0], bs_length[:, 0]

    discriminator = tx.modules.UnidirectionalRNNClassifier(
        hparams=config_model.discriminator)

    def disc(y, x):
        return discriminator(tf.concat([x.enc_outputs, y], axis=1))

    joint_loss = 0.
    disc_loss = 0.
    for flag in range(2):
        # Cross-entropy loss against the other reference, reconstruction loss
        # against the same reference, and back-translation through greedy
        # decoding, plus an adversarial term from the discriminator.
        xe_decoder, xe_outputs, xe_loss = teacher_forcing(
            rnn_cell, y_encoded[1 - flag], flag, 'XE')
        rec_decoder, rec_outputs, rec_loss = teacher_forcing(
            rnn_cell, y_encoded[1 - flag], 1 - flag, 'REC')
        # print('[info]rec_outputs is:{}'.format(rec_outputs.cell_output))
        greedy_ids, greedy_length = greedy_decoding(
            rnn_cell, y_encoded[1 - flag], x_encoded[flag])
        greedy_ids = tf.concat([
            data_batch['y_aux_text_ids'][:, :1],
            tf.cast(greedy_ids, tf.int64)
        ], axis=1)
        greedy_length = 1 + greedy_length
        greedy_encoded = encode_y(greedy_ids, greedy_length)
        bt_decoder, _, bt_loss = teacher_forcing(
            rnn_cell, greedy_encoded, 1 - flag, 'BT')

        adv_loss = 2. * disc(rec_outputs.cell_output,
                             x_encoded[1 - flag])[0][:, 1] \
            + disc(xe_outputs.cell_output, x_encoded[flag])[0][:, 0] \
            + disc(rec_outputs.cell_output, x_encoded[flag])[0][:, 0]
        adv_loss = tf.reduce_mean(adv_loss, 0)
        disc_loss = disc_loss + adv_loss
        joint_loss = joint_loss + (rec_loss + FLAGS.bt_w * bt_loss +
                                   FLAGS.adv_w * adv_loss)

    disc_loss = -disc_loss
    losses['joint'] = joint_loss
    losses['disc'] = disc_loss

    tiled_decoder, bs_outputs, _ = beam_searching(
        rnn_cell, y_encoded[1], x_encoded[0], config_train.infer_beam_width)

    # The discriminator and the rest of the model are updated by separate
    # train ops on disjoint variable sets.
    disc_variables = discriminator.trainable_variables
    other_variables = list(
        filter(lambda var: var not in disc_variables,
               tf.trainable_variables()))
    variables_of_train_op = {
        'joint': other_variables,
        'disc': disc_variables,
    }

    train_ops = {
        name: get_train_op(losses[name],
                           variables=variables_of_train_op.get(name, None),
                           hparams=config_train.train[name])
        for name in config_train.train
    }

    return train_ops, bs_outputs
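# --- Illustrative sketch (added for clarity; not part of the original graph) ---
# What the exact-coverage penalty in `teacher_forcing` computes, in plain
# NumPy: for each copy memory, `sum_copy_prob` has shape [batch, memory_len]
# and holds the copy probability accumulated over all decoding steps for each
# memory position; positions beyond the true memory length are masked out, and
# the penalty is the batch mean of the summed squared deviation from 1.
# The function name below is illustrative only.
import numpy as np

def exact_coverage_penalty_np(sum_copy_prob, memory_length):
    """NumPy sketch of the masked (sum_copy_prob - 1)^2 coverage penalty."""
    sum_copy_prob = np.asarray(sum_copy_prob, dtype=np.float32)
    batch_size, max_len = sum_copy_prob.shape
    # mask[b, i] is 1 for valid memory positions i < memory_length[b]
    mask = (np.arange(max_len)[None, :] <
            np.asarray(memory_length)[:, None]).astype(np.float32)
    sq_dev = np.square(sum_copy_prob - 1.) * mask
    return float(np.mean(np.sum(sq_dev, axis=1)))

# For example, a position copied with total probability 1.0 contributes no
# penalty, while a valid position that is never copied contributes 1.0.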