def __init__(self, config, model_class):
    self.config = config
    self.sst = SST("data/sst",
                   pretrained_path=self.config.pretrained_embedding_path,
                   embedding_size=self.config.embedding_dim)
    self.vocab = PretrainedVocab(self.config.data_path,
                                 self.config.pretrained_embedding_path,
                                 self.config.embedding_dim)
    self.pretrained_word_embeddings, self.word2id = self.vocab.get_word_embeddings()
    # Input dimension follows the pretrained vocabulary size.
    self.config.input_dim = len(self.word2id)
    self.sentiment_tree_lstm = model_class(self.config)
def __init__(self, config, student_model, teacher_model):
    super(SSTDistiller, self).__init__(config, student_model, teacher_model)
    self.sst = SST("data/sst")
    self.config.vocab_size = len(self.sst.vocab)
    # Student and teacher must share the task vocabulary.
    self.student.hparams.vocab_size = self.config.vocab_size
    self.teacher.hparams.vocab_size = self.config.vocab_size
    self.vocab = PretrainedVocab(self.config.data_path,
                                 self.config.pretrained_embedding_path,
                                 self.config.embedding_dim)
    self.pretrained_word_embeddings, self.word2id = self.vocab.get_word_embeddings()
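# The distiller above pairs a student and teacher over a shared SST vocabulary.
# A minimal sketch of the soft-label distillation objective such a class
# typically optimizes; NOT taken from the source, and the function name and
# temperature parameter are assumptions for illustration.
import tensorflow as tf

def soft_distillation_loss(teacher_logits, student_logits, temperature=1.0):
    """Cross-entropy of student predictions against temperature-scaled teacher probabilities."""
    teacher_probs = tf.nn.softmax(teacher_logits / temperature)
    xent = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=tf.stop_gradient(teacher_probs),  # teacher receives no gradient
        logits=student_logits / temperature)
    # The temperature**2 factor keeps gradient scale comparable across temperatures.
    return tf.reduce_mean(xent) * temperature ** 2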
def get_data_iterators(self):
    # Train pipeline over parsed SST trees (including subtrees). Note that
    # shuffle runs after padded_batch here, so it shuffles whole batches.
    dataset = tf.data.TFRecordDataset(
        SST.get_tfrecord_path("data/sst", mode="train", add_subtrees=True))
    dataset = dataset.map(SST.parse_full_sst_tree_examples)
    dataset = dataset.padded_batch(self.config.batch_size,
                                   padded_shapes=SST.get_padded_shapes(),
                                   drop_remainder=True)
    dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.repeat()
    iterator = dataset.make_initializable_iterator()

    # Dev pipeline: shuffled at the example level, large fixed-size batches.
    dev_dataset = tf.data.TFRecordDataset(
        SST.get_tfrecord_path("data/sst", mode="dev", add_subtrees=True))
    dev_dataset = dev_dataset.map(SST.parse_full_sst_tree_examples)
    dev_dataset = dev_dataset.shuffle(buffer_size=1000)
    dev_dataset = dev_dataset.repeat()
    dev_dataset = dev_dataset.padded_batch(1000,
                                           padded_shapes=SST.get_padded_shapes(),
                                           drop_remainder=True)
    dev_iterator = dev_dataset.make_initializable_iterator()

    # Test pipeline mirrors dev.
    test_dataset = tf.data.TFRecordDataset(
        SST.get_tfrecord_path("data/sst", mode="test", add_subtrees=True))
    test_dataset = test_dataset.map(SST.parse_full_sst_tree_examples)
    test_dataset = test_dataset.shuffle(buffer_size=1000)
    test_dataset = test_dataset.repeat()
    test_dataset = test_dataset.padded_batch(1000,
                                             padded_shapes=SST.get_padded_shapes(),
                                             drop_remainder=True)
    test_iterator = test_dataset.make_initializable_iterator()

    return iterator, dev_iterator, test_iterator
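# The function above returns TF 1.x initializable iterators, so each one must
# be initialized inside a session before get_next() yields batches. A minimal
# consumption sketch, assuming a `trainer` object exposing the method above;
# the session setup is illustrative, not from the source.
import tensorflow as tf

train_it, dev_it, test_it = trainer.get_data_iterators()
next_train_batch = train_it.get_next()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Initializable iterators need explicit initialization in TF 1.x.
    sess.run([train_it.initializer, dev_it.initializer, test_it.initializer])
    batch = sess.run(next_train_batch)  # one padded training batch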
def __init__(self, hparams, model_class):
    self.config = hparams
    self.sst = SST("data/sst")
    self.vocab = PretrainedVocab(self.config.data_path,
                                 self.config.pretrained_embedding_path,
                                 self.config.embedding_dim)
    self.pretrained_word_embeddings, self.word2id = self.vocab.get_word_embeddings()
    self.config.input_dim = len(self.word2id)
    self.config.vocab_size = len(self.word2id)
    # Select the recurrent cell family based on the bidirectional flag.
    lstm = BiLSTM if hparams.bidirectional else LSTM
    self.sentiment_lstm = model_class(self.config, model=lstm)
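# get_word_embeddings() above hands back a pretrained embedding matrix plus the
# word2id mapping. A common TF 1.x pattern for loading such a matrix into an
# embedding variable; the variable name and the trainable flag are assumptions.
import tensorflow as tf

def build_embedding_layer(pretrained_word_embeddings, trainable=False):
    """Embedding variable initialized from a pretrained [vocab_size, dim] matrix."""
    vocab_size, embedding_dim = pretrained_word_embeddings.shape
    embeddings = tf.get_variable(
        "word_embeddings",
        shape=[vocab_size, embedding_dim],
        initializer=tf.constant_initializer(pretrained_word_embeddings),
        trainable=trainable)  # freeze or fine-tune the pretrained vectors
    return embeddings

# Lookup turns integer token ids into dense vectors:
#   embedded = tf.nn.embedding_lookup(embeddings, input_ids)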
    'arithmatic_simple_samelength21_depth2_normal_biling':
        ArithmaticSimpleSameLength21Depth2NormalBiLing(
            os.path.join(hparams.data_dir,
                         'arithmatic_simple_samelength21_depth2_normal_biling')),
    'arithmatic_simple_samelength201_depth2_normal':
        ArithmaticSimpleSameLength201Depth2Normal(
            os.path.join(hparams.data_dir,
                         'arithmatic_simple_samelength201_depth2_normal')),
    'arithmatic_simple_missinglength21_depth2_normal_biling':
        ArithmaticSimpleMissingLength21Depth2NormalBiLing(
            os.path.join(hparams.data_dir,
                         'arithmatic_simple_missinglength21_depth2_normal_biling')),
    'sst': SST(data_path=os.path.join(hparams.data_dir, "sst/"),
               add_subtrees=False, pretrained=True),
    'ptb_lm': PTB(os.path.join(hparams.data_dir, 'ptb')),
    'wsj_parse': ParseWSJ(os.path.join(hparams.data_dir, 'wsj')),
    'imdb': IMDB(data_path=os.path.join(hparams.data_dir, "imdb"), pretrained=True),
    'char_trec': CharTrec6(os.path.join(hparams.data_dir, "char_trec6"),
                           build_vocab=False),
    'mnist': Mnist1D(os.path.join(hparams.data_dir, 'mnist1d')),
}
class SentimentTrainer(Trainer):

    def __init__(self, config, model_obj, task):
        super(SentimentTrainer, self).__init__(config, model_obj)
        self.task = SST("data/sst",
                        pretrained_path=self.config.pretrained_embedding_path,
                        embedding_size=self.config.embedding_dim)
        self.vocab = PretrainedVocab(self.config.data_path,
                                     self.config.pretrained_embedding_path,
                                     self.config.embedding_dim)
        self.pretrained_word_embeddings, self.word2id = \
            self.vocab.get_word_embeddings()
        self.config.input_dim = len(self.word2id)

    def get_train_data_iterators(self):
        # Identical train/dev/test input pipelines over the task's TFRecords.
        # Note that shuffle follows padded_batch, so it shuffles whole batches.
        dataset = tf.data.TFRecordDataset(self.task.get_tfrecord_path(mode="train"))
        dataset = dataset.map(self.task.parse_seq2seq_examples)
        dataset = dataset.padded_batch(self.config.batch_size,
                                       padded_shapes=self.task.get_padded_shapes())
        dataset = dataset.shuffle(buffer_size=1000)
        dataset = dataset.repeat()
        train_iterator = dataset.make_initializable_iterator()

        dataset = tf.data.TFRecordDataset(self.task.get_tfrecord_path(mode="dev"))
        dataset = dataset.map(self.task.parse_seq2seq_examples)
        dataset = dataset.padded_batch(self.config.batch_size,
                                       padded_shapes=self.task.get_padded_shapes())
        dataset = dataset.shuffle(buffer_size=1000)
        dataset = dataset.repeat()
        dev_iterator = dataset.make_initializable_iterator()

        dataset = tf.data.TFRecordDataset(self.task.get_tfrecord_path(mode="test"))
        dataset = dataset.map(self.task.parse_seq2seq_examples)
        dataset = dataset.padded_batch(self.config.batch_size,
                                       padded_shapes=self.task.get_padded_shapes())
        dataset = dataset.shuffle(buffer_size=1000)
        dataset = dataset.repeat()
        test_iterator = dataset.make_initializable_iterator()

        return train_iterator, dev_iterator, test_iterator

    def compute_loss(self, logits, targets):
        # Label-smoothed cross-entropy, averaged over non-padding positions.
        xentropy, weights = padded_cross_entropy_loss(
            logits, targets, self.config.label_smoothing, self.config.vocab_size)
        loss = tf.reduce_sum(xentropy) / tf.reduce_sum(weights)
        return loss

    def add_metric_summaries(self, logits, labels, family):
        eval_metrics = get_eval_metrics(logits, labels, self.model.hparams)
        for metric in eval_metrics:
            tf.logging.info(metric)
            tf.logging.info(eval_metrics[metric])
            tf.summary.scalar(metric, tf.reduce_mean(eval_metrics[metric]),
                              family=family)

    def get_metric_summaries_as_dic(self, logits, labels):
        metric_summaries = {}
        eval_metrics = get_eval_metrics(logits, labels, self.model.hparams)
        for metric in eval_metrics:
            metric_summaries[metric] = tf.reduce_mean(eval_metrics[metric])
        return metric_summaries

    def build_train_graph(self):
        train_iterator, dev_iterator, test_iterator = \
            self.get_train_data_iterators()
        train_examples = train_iterator.get_next()
        dev_examples = dev_iterator.get_next()
        test_examples = test_iterator.get_next()

        # One set of variables, applied to all three splits.
        self.model.create_vars(reuse=False)
        train_output_dic = self.model.apply(train_examples, is_train=True)
        dev_output_dic = self.model.apply(dev_examples, is_train=False)
        test_output_dic = self.model.apply(test_examples, is_train=False)

        train_loss = self.compute_loss(train_output_dic['logits'],
                                       train_output_dic['targets'])
        dev_loss = self.compute_loss(dev_output_dic['logits'],
                                     dev_output_dic['targets'])
        test_loss = self.compute_loss(test_output_dic['logits'],
                                      test_output_dic['targets'])
        train_output_dic['loss'] = train_loss

        tf.summary.scalar("loss", train_loss, family="train")
        tf.summary.scalar("loss", dev_loss, family="dev")
        tf.summary.scalar("loss", test_loss, family="test")
        self.add_metric_summaries(train_output_dic['logits'],
                                  train_output_dic['targets'], "train")
        self.add_metric_summaries(dev_output_dic['logits'],
                                  dev_output_dic['targets'], "dev")
        self.add_metric_summaries(test_output_dic['logits'],
                                  test_output_dic['targets'], "test")

        update_op, learning_rate = self.get_train_op(
            train_loss,
            train_output_dic["trainable_vars"],
            start_learning_rate=0.0005,
            base_learning_rate=self.model.hparams.learning_rate,
            warmup_steps=self.model.hparams.learning_rate_warmup_steps,
            clip_gradient_norm=self.model.hparams.clip_grad_norm)
        tf.summary.scalar("learning_rate", learning_rate, family="train")

        # The scaffold initializes local variables and all three iterators.
        scaffold = tf.train.Scaffold(
            local_init_op=tf.group(tf.local_variables_initializer(),
                                   train_iterator.initializer,
                                   dev_iterator.initializer,
                                   test_iterator.initializer),
            init_feed_dict={})
        return update_op, scaffold, train_output_dic, dev_output_dic, test_output_dic
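# build_train_graph() returns a tf.train.Scaffold whose local_init_op covers
# the three iterator initializers, which is the shape MonitoredTrainingSession
# expects. A minimal driver loop under that assumption; config/model_obj/task
# are assumed given, and the checkpoint directory and step count are
# illustrative, not from the source. This also assumes get_train_op creates a
# global step, which the checkpoint hooks require.
import tensorflow as tf

trainer = SentimentTrainer(config, model_obj, task)
update_op, scaffold, train_out, dev_out, test_out = trainer.build_train_graph()

with tf.train.MonitoredTrainingSession(checkpoint_dir="checkpoints/sst",
                                       scaffold=scaffold) as sess:
    for _ in range(10000):  # illustrative number of optimizer steps
        sess.run(update_op)  # summaries and checkpointing run via session hooks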
    'identity_binary': AlgorithmicIdentityBinary40('data/alg'),
    'addition': AlgorithmicAdditionDecimal40('data/alg'),
    'multiplication': AlgorithmicMultiplicationDecimal40('data/alg'),
    'sort': AlgorithmicSortProblem('data/alg'),
    'reverse': AlgorithmicReverseProblem('data/alg'),
    'arithmatic': Arithmatic('data/arithmatic'),
    'sst': SST(data_path="data/sst/",
               add_subtrees=True,
               pretrained=True,
               pretrained_path="data/sst/filtered_glove.txt",
               embedding_size=300)
}

# Derive model dimensions from the selected task.
hparams.vocab_size = tasks[hparams.task_name].vocab_length
hparams.output_dim = len(tasks[hparams.task_name].target_vocab)

transformer_params = transformer_medium_hparams(
    vocab_size=hparams.vocab_size,
    output_dim=hparams.output_dim,
    input_dim=hparams.input_dim,
    encoder_attention_dir=hparams.encoder_attention_dir)
lstm_params = lstm_small_hparams(vocab_size=hparams.vocab_size,
                                 output_dim=hparams.output_dim,
                                 input_dim=hparams.input_dim)
"bilstm": BidiLSTMSeq2Seq, "transformer": Transformer, "utransformer": UniversalTransformer, "enc_transformer": EncodingTransformer, "enc_utransformer": EncodingUniversalTransformer } tasks = { 'identity': AlgorithmicIdentityDecimal40('data/alg'), 'identity_binary': AlgorithmicIdentityBinary40('data/alg'), 'addition': AlgorithmicAdditionDecimal40('data/alg'), 'multiplication': AlgorithmicMultiplicationDecimal40('data/alg'), 'sort': AlgorithmicSortProblem('data/alg'), 'reverse': AlgorithmicReverseProblem('data/alg'), 'arithmatic': Arithmatic('data/arithmatic'), 'sst': SST(data_path="data/sst/", add_subtrees=False, pretrained=True), 'ptb_lm': PTB('data/ptb'), 'wsj_parse': ParseWSJ('data/wsj'), 'imdb': IMDB(data_path="data/imdb", pretrained=True) } hparams.vocab_size = tasks[hparams.task_name].vocab_length hparams.output_dim = len(tasks[hparams.task_name].target_vocab) transformer_params = TransformerHparam( input_dim=hparams.input_dim, hidden_dim=hparams.hidden_dim, output_dim=hparams.output_dim, encoder_depth=hparams.encoder_depth, decoder_depth=hparams.decoder_depth, number_of_heads=2,