Esempio n. 1
0
    def __init__(self, params, controller, x_train, x_valid, name='child'):
        """Set up data pipelines and build the child LM graph for a sampled arc.

        Args:
          params: hyper-parameter object. It is passed through
            `_set_default_params` and the result is stored as `self.params`.
            The entry `num_train_steps` is added to it via `add_hparam`.
          controller: object exposing `sample_arc`. That value is split with
            `tf.unstack` and stored as `self.sample_arc`.
          x_train: training data handed to `data_utils.input_producer`
            together with `random_len=True`.
          x_valid: validation data handed to `data_utils.input_producer`.
          name: stored as `self.name`.
        """
        banner = '-' * 80
        print(banner)
        print('Building LM')

        self.params = _set_default_params(params)
        self.controller = controller
        self.sample_arc = tf.unstack(controller.sample_arc)
        self.name = name

        # Training pipeline: the producer yields six values when random_len=True.
        train_pipeline = data_utils.input_producer(
            x_train, params.batch_size, params.bptt_steps, random_len=True)
        (self.x_train, self.y_train, self.num_train_batches,
         self.reset_start_idx, self.should_reset,
         self.base_bptt) = train_pipeline
        # NOTE(review): add_hparam is applied to the caller's `params`; if it
        # rejects an already-existing name (verify), reusing the same params
        # object for a second instance would fail here.
        total_train_steps = self.num_train_batches * params.num_train_epochs
        params.add_hparam('num_train_steps', total_train_steps)

        # Validation pipeline: the producer yields three values here.
        valid_pipeline = data_utils.input_producer(
            x_valid, params.batch_size, params.bptt_steps)
        self.x_valid, self.y_valid, self.num_valid_batches = valid_pipeline

        # Graph construction runs in this fixed order: params, train, valid.
        for build_step in (self._build_params,
                           self._build_train,
                           self._build_valid):
            build_step()
Esempio n. 2
0
  def __init__(self, params, x_train, x_valid, x_test, name='language_model'):
    """Set up train/valid/test data pipelines and build the LM graph.

    Args:
      params: hyper-parameter object. It is passed through
        `_set_default_params` and the result is stored as `self.params`.
        `num_train_steps`, `start_decay_step` and `decay_every_step` are
        added to it via `add_hparam`, each scaled by `num_train_batches`.
      x_train: training data handed to `data_utils.input_producer` together
        with `random_len=True`.
      x_valid: validation data handed to `data_utils.input_producer`.
      x_test: test data handed to `data_utils.input_producer` with the two
        size arguments fixed to 1 and 1.
      name: stored as `self.name`.
    """
    separator = '-' * 80
    print(separator)
    print('Building LM')

    self.params = _set_default_params(params)
    self.name = name

    # Training pipeline: the producer yields six values when random_len=True.
    train_pipeline = data_utils.input_producer(
        x_train, params.batch_size, params.bptt_steps, random_len=True)
    (self.x_train, self.y_train, self.num_train_batches,
     self.reset_start_idx, self.should_reset,
     self.base_bptt) = train_pipeline
    total_train_steps = self.num_train_batches * params.num_train_epochs
    params.add_hparam('num_train_steps', total_train_steps)

    # Validation pipeline: the producer yields three values here.
    valid_pipeline = data_utils.input_producer(
        x_valid, params.batch_size, params.bptt_steps)
    self.x_valid, self.y_valid, self.num_valid_batches = valid_pipeline

    # Test pipeline: same call, but both size arguments are 1.
    test_pipeline = data_utils.input_producer(x_test, 1, 1)
    self.x_test, self.y_test, self.num_test_batches = test_pipeline

    # Convert the epoch-based decay settings into step counts. Each source
    # attribute is read immediately before its step hparam is added, so the
    # read/add interleaving is preserved.
    for step_name, epoch_name in (('start_decay_step', 'start_decay_epoch'),
                                  ('decay_every_step', 'decay_every_epoch')):
      params.add_hparam(
          step_name, getattr(params, epoch_name) * self.num_train_batches)

    # Graph construction runs in this fixed order.
    for build_step in (self._build_params,
                       self._build_train,
                       self._build_valid,
                       self._build_test):
      build_step()