Code example #1
 def __init__(self, config, model_obj, task):
     super(SentimentTrainer, self).__init__(config, model_obj)
     self.task = SST("data/sst",
                     pretrained_path=self.config.pretrained_embedding_path,
                     embedding_size=self.config.embedding_dim)
     self.vocab = PretrainedVocab(self.config.data_path,
                                  self.config.pretrained_embedding_path,
                                  self.config.embedding_dim)
     self.pretrained_word_embeddings, self.word2id = self.vocab.get_word_embeddings()
     self.config.input_dim = len(self.word2id)
Code example #2
  def __init__(self, config, model_class):
    self.config = config
    self.sst = SST("data/sst", pretrained_path=self.config.pretrained_embedding_path, embedding_size=self.config.embedding_dim)
    self.vocab = PretrainedVocab(self.config.data_path, self.config.pretrained_embedding_path,
                                 self.config.embedding_dim)
    self.pretrained_word_embeddings, self.word2id = self.vocab.get_word_embeddings()
    self.config.input_dim = len(self.word2id)

    self.sentiment_tree_lstm = model_class(self.config)
Code example #3
  def __init__(self, config, student_model, teacher_model):
    super(SSTDistiller, self).__init__(config, student_model, teacher_model)

    self.sst = SST("data/sst")
    self.config.vocab_size = len(self.sst.vocab)
    self.student.hparams.vocab_size = self.config.vocab_size
    self.teacher.hparams.vocab_size = self.config.vocab_size

    self.vocab = PretrainedVocab(self.config.data_path, self.config.pretrained_embedding_path,
                                 self.config.embedding_dim)
    self.pretrained_word_embeddings, self.word2id = self.vocab.get_word_embeddings()
Code example #4
  def get_data_iterators(self):
    dataset = tf.data.TFRecordDataset(SST.get_tfrecord_path("data/sst", mode="train", add_subtrees=True))
    dataset = dataset.map(SST.parse_full_sst_tree_examples)
    dataset = dataset.padded_batch(self.config.batch_size, padded_shapes=SST.get_padded_shapes(), drop_remainder=True)
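    # Note: shuffle() is applied after padded_batch() here, so it shuffles
    # whole batches rather than individual examples.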
    dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.repeat()
    iterator = dataset.make_initializable_iterator()

    dev_dataset = tf.data.TFRecordDataset(SST.get_tfrecord_path("data/sst", mode="dev", add_subtrees=True))
    dev_dataset = dev_dataset.map(SST.parse_full_sst_tree_examples)
    dev_dataset = dev_dataset.shuffle(buffer_size=1000)
    dev_dataset = dev_dataset.repeat()
    dev_dataset = dev_dataset.padded_batch(1000, padded_shapes=SST.get_padded_shapes(),
                                           drop_remainder=True)
    dev_iterator = dev_dataset.make_initializable_iterator()

    test_dataset = tf.data.TFRecordDataset(SST.get_tfrecord_path("data/sst", mode="test", add_subtrees=True))
    test_dataset = test_dataset.map(SST.parse_full_sst_tree_examples)
    test_dataset = test_dataset.shuffle(buffer_size=1000)
    test_dataset = test_dataset.repeat()
    test_dataset = test_dataset.padded_batch(1000, padded_shapes=SST.get_padded_shapes(),
                                             drop_remainder=True)
    test_iterator = test_dataset.make_initializable_iterator()

    return iterator, dev_iterator, test_iterator
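A minimal usage sketch for the iterators returned above, assuming TF 1.x graph mode and a hypothetical `trainer` object exposing this method (the session setup is illustrative, not part of the original code):

import tensorflow as tf

train_iterator, dev_iterator, test_iterator = trainer.get_data_iterators()
next_batch = train_iterator.get_next()
with tf.Session() as sess:
  # Initializable iterators must be initialized before the first run.
  sess.run([train_iterator.initializer,
            dev_iterator.initializer,
            test_iterator.initializer])
  batch = sess.run(next_batch)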
Code example #5
  def __init__(self, hparams, model_class):
    self.config = hparams
    self.sst = SST("data/sst")

    self.vocab = PretrainedVocab(self.config.data_path, self.config.pretrained_embedding_path,
                                 self.config.embedding_dim)
    self.pretrained_word_embeddings, self.word2id = self.vocab.get_word_embeddings()
    self.config.input_dim = len(self.word2id)
    self.config.vocab_size = len(self.word2id)

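    # Pick the recurrent cell: a bidirectional LSTM if requested, otherwise a
    # unidirectional LSTM.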
    if hparams.bidirectional:
      lstm = BiLSTM
    else:
      lstm = LSTM
    self.sentiment_lstm = model_class(self.config, model=lstm)
Code example #6
        ArithmaticSimpleSameLength21Depth2NormalBiLing(
            os.path.join(
                hparams.data_dir,
                'arithmatic_simple_samelength21_depth2_normal_biling')),
        'arithmatic_simple_samelength201_depth2_normal':
        ArithmaticSimpleSameLength201Depth2Normal(
            os.path.join(hparams.data_dir,
                         'arithmatic_simple_samelength201_depth2_normal')),
        'arithmatic_simple_missinglength21_depth2_normal_biling':
        ArithmaticSimpleMissingLength21Depth2NormalBiLing(
            os.path.join(
                hparams.data_dir,
                'arithmatic_simple_missinglength21_depth2_normal_biling')),
        'sst':
        SST(data_path=os.path.join(hparams.data_dir, "sst/"),
            add_subtrees=False,
            pretrained=True),
        'ptb_lm':
        PTB(os.path.join(hparams.data_dir, 'ptb')),
        'wsj_parse':
        ParseWSJ(os.path.join(hparams.data_dir, 'wsj')),
        'imdb':
        IMDB(data_path=os.path.join(hparams.data_dir, "imdb"),
             pretrained=True),
        'char_trec':
        CharTrec6(os.path.join(hparams.data_dir, "char_trec6"),
                  build_vocab=False),
        'mnist':
        Mnist1D(os.path.join(hparams.data_dir, 'mnist1d')),
    }
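The dict above serves as a task registry keyed by name. A minimal sketch of looking one up, assuming `hparams.task_name` holds one of the keys (attribute names follow the usage in code example #8 below):

task = tasks[hparams.task_name]
hparams.vocab_size = task.vocab_length
hparams.output_dim = len(task.target_vocab)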
Code example #7
class SentimentTrainer(Trainer):
    def __init__(self, config, model_obj, task):
        super(SentimentTrainer, self).__init__(config, model_obj)
        self.task = SST("data/sst",
                        pretrained_path=self.config.pretrained_embedding_path,
                        embedding_size=self.config.embedding_dim)
        self.vocab = PretrainedVocab(self.config.data_path,
                                     self.config.pretrained_embedding_path,
                                     self.config.embedding_dim)
        self.pretrained_word_embeddings, self.word2id = self.vocab.get_word_embeddings()
        self.config.input_dim = len(self.word2id)

    def get_train_data_iterators(self):
        dataset = tf.data.TFRecordDataset(
            self.task.get_tfrecord_path(mode="train"))
        dataset = dataset.map(self.task.parse_seq2seq_examples)
        dataset = dataset.padded_batch(
            self.config.batch_size,
            padded_shapes=self.task.get_padded_shapes())
        dataset = dataset.shuffle(buffer_size=1000)
        dataset = dataset.repeat()
        train_iterator = dataset.make_initializable_iterator()

        dataset = tf.data.TFRecordDataset(
            self.task.get_tfrecord_path(mode="dev"))
        dataset = dataset.map(self.task.parse_seq2seq_examples)
        dataset = dataset.padded_batch(
            self.config.batch_size,
            padded_shapes=self.task.get_padded_shapes())
        dataset = dataset.shuffle(buffer_size=1000)
        dataset = dataset.repeat()
        dev_iterator = dataset.make_initializable_iterator()

        dataset = tf.data.TFRecordDataset(
            self.task.get_tfrecord_path(mode="test"))
        dataset = dataset.map(self.task.parse_seq2seq_examples)
        dataset = dataset.padded_batch(
            self.config.batch_size,
            padded_shapes=self.task.get_padded_shapes())
        dataset = dataset.shuffle(buffer_size=1000)
        dataset = dataset.repeat()
        test_iterator = dataset.make_initializable_iterator()

        return train_iterator, dev_iterator, test_iterator

    def compute_loss(self, logits, targets):
        xentropy, weights = padded_cross_entropy_loss(
            logits, targets, self.config.label_smoothing,
            self.config.vocab_size)
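        # xentropy holds per-position cross-entropy and weights masks out
        # padding, so dividing the sums gives the mean loss over real tokens.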
        loss = tf.reduce_sum(xentropy) / tf.reduce_sum(weights)

        return loss

    def add_metric_summaries(self, logits, labels, family):
        eval_metrics = get_eval_metrics(logits, labels, self.model.hparams)
        for metric in eval_metrics:
            tf.logging.info(metric)
            tf.logging.info(eval_metrics[metric])
            tf.summary.scalar(metric,
                              tf.reduce_mean(eval_metrics[metric]),
                              family=family)

    def get_metric_summaries_as_dic(self, logits, labels):
        metric_summaries = {}
        eval_metrics = get_eval_metrics(logits, labels, self.model.hparams)
        for metric in eval_metrics:
            metric_summaries[metric] = tf.reduce_mean(eval_metrics[metric])

        return metric_summaries

    def build_train_graph(self):
        train_iterator, dev_iterator, test_iterator = self.get_train_data_iterators()

        train_examples = train_iterator.get_next()
        dev_examples = dev_iterator.get_next()
        test_examples = test_iterator.get_next()

        self.model.create_vars(reuse=False)

        train_output_dic = self.model.apply(train_examples, is_train=True)
        dev_output_dic = self.model.apply(dev_examples, is_train=False)
        test_output_dic = self.model.apply(test_examples, is_train=False)

        train_loss = self.compute_loss(train_output_dic['logits'],
                                       train_output_dic['targets'])
        dev_loss = self.compute_loss(dev_output_dic['logits'],
                                     dev_output_dic['targets'])
        test_loss = self.compute_loss(test_output_dic['logits'],
                                      test_output_dic['targets'])

        train_output_dic['loss'] = train_loss
        tf.summary.scalar("loss", train_loss, family="train")
        tf.summary.scalar("loss", dev_loss, family="dev")
        tf.summary.scalar("loss", test_loss, family="test")

        self.add_metric_summaries(train_output_dic['logits'],
                                  train_output_dic['targets'], "train")
        self.add_metric_summaries(dev_output_dic['logits'],
                                  dev_output_dic['targets'], "dev")
        self.add_metric_summaries(test_output_dic['logits'],
                                  test_output_dic['targets'], "test")

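        # Warm up from start_learning_rate toward the base rate over
        # warmup_steps, clipping gradients to clip_gradient_norm.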
        update_op, learning_rate = self.get_train_op(
            train_loss,
            train_output_dic["trainable_vars"],
            start_learning_rate=0.0005,
            base_learning_rate=self.model.hparams.learning_rate,
            warmup_steps=self.model.hparams.learning_rate_warmup_steps,
            clip_gradient_norm=self.model.hparams.clip_grad_norm)
        tf.summary.scalar("learning_rate", learning_rate, family="train")

        scaffold = tf.train.Scaffold(
            local_init_op=tf.group(tf.local_variables_initializer(),
                                   train_iterator.initializer,
                                   dev_iterator.initializer,
                                   test_iterator.initializer),
            init_feed_dict={})

        return update_op, scaffold, train_output_dic, dev_output_dic, test_output_dic
Code example #8
        'identity_binary':
        AlgorithmicIdentityBinary40('data/alg'),
        'addition':
        AlgorithmicAdditionDecimal40('data/alg'),
        'multiplication':
        AlgorithmicMultiplicationDecimal40('data/alg'),
        'sort':
        AlgorithmicSortProblem('data/alg'),
        'reverse':
        AlgorithmicReverseProblem('data/alg'),
        'arithmatic':
        Arithmatic('data/arithmatic'),
        'sst':
        SST(data_path="data/sst/",
            add_subtrees=True,
            pretrained=True,
            pretrained_path="data/sst/filtered_glove.txt",
            embedding_size=300)
    }

    hparams.vocab_size = tasks[hparams.task_name].vocab_length
    hparams.output_dim = len(tasks[hparams.task_name].target_vocab)
    transformer_params = transformer_medium_hparams(
        vocab_size=hparams.vocab_size,
        output_dim=hparams.output_dim,
        input_dim=hparams.input_dim,
        encoder_attention_dir=hparams.encoder_attention_dir)

    lstm_params = lstm_small_hparams(vocab_size=hparams.vocab_size,
                                     output_dim=hparams.output_dim,
                                     input_dim=hparams.input_dim)
Code example #9
        "bilstm": BidiLSTMSeq2Seq,
        "transformer": Transformer,
        "utransformer": UniversalTransformer,
        "enc_transformer": EncodingTransformer,
        "enc_utransformer": EncodingUniversalTransformer
    }

    tasks = {
        'identity': AlgorithmicIdentityDecimal40('data/alg'),
        'identity_binary': AlgorithmicIdentityBinary40('data/alg'),
        'addition': AlgorithmicAdditionDecimal40('data/alg'),
        'multiplication': AlgorithmicMultiplicationDecimal40('data/alg'),
        'sort': AlgorithmicSortProblem('data/alg'),
        'reverse': AlgorithmicReverseProblem('data/alg'),
        'arithmatic': Arithmatic('data/arithmatic'),
        'sst': SST(data_path="data/sst/", add_subtrees=False, pretrained=True),
        'ptb_lm': PTB('data/ptb'),
        'wsj_parse': ParseWSJ('data/wsj'),
        'imdb': IMDB(data_path="data/imdb", pretrained=True)
    }

    hparams.vocab_size = tasks[hparams.task_name].vocab_length
    hparams.output_dim = len(tasks[hparams.task_name].target_vocab)

    transformer_params = TransformerHparam(
        input_dim=hparams.input_dim,
        hidden_dim=hparams.hidden_dim,
        output_dim=hparams.output_dim,
        encoder_depth=hparams.encoder_depth,
        decoder_depth=hparams.decoder_depth,
        number_of_heads=2,