예제 #1
0
 def test_cyclical_rate_constant(self):
     """A length-one period must start at rate 1.0 and end on its only step."""
     # Arguments: period length 1, first step (index 1).
     rate, end_flag = sgmcmc.cyclical_rate(1, 1)
     self.assertAlmostEqual(
         rate, 1.0, delta=1.0e-6,
         msg='Rate for period_length=1 is not 1.0')
     # With a single step per period, that step is also the period's last.
     self.assertEqual(
         bool(end_flag), True,
         msg='Period length of one but not an end-step.')
  def on_train_batch_end(self, batch, logs=None):
    """Advance the cyclical schedule after a batch and maybe take a sample.

    Presumably a tf.keras Callback hook (confirm against the enclosing class).
    ``self.batch_count`` is incremented on every batch, including batches
    before the sampling phase, so the cycle position when sampling begins
    depends on how many batches came before it.

    Before the sampling phase (``epoch_count + 1 < sampling_start_epoch``)
    nothing else happens. Once sampling has started:
      * the cyclical rate for the current batch count is written into the
        optimizer's ``timestep_factor``, but only when the optimizer is an
        ``sgmcmc.SGMCMCOptimizer``;
      * on the last step of a cycle, the model weights are offered to the
        ensemble via ``append_maybe``.

    Args:
      batch: index of the batch; only compared with 0 so that the
        start-of-sampling message prints on the first batch of the first
        sampling epoch.
      logs: unused here.
    """
    self.batch_count += 1
    current_epoch = self.epoch_count + 1  # assumes epoch_count is 0-based
    if current_epoch < self.sampling_start_epoch:
      return
    if current_epoch == self.sampling_start_epoch and batch == 0:
      print('# Starting sampling phase')

    rate, cycle_done = sgmcmc.cyclical_rate(
        self.cycle_period_length, self.batch_count,
        schedule=self.schedule, min_value=self.min_value)

    optimizer = self.model.optimizer
    if isinstance(optimizer, sgmcmc.SGMCMCOptimizer):
      tf.keras.backend.set_value(optimizer.timestep_factor, rate)

    if not cycle_done:
      return
    print('# Taking ensemble member sample')
    # The bound method itself (not its result) is handed over; presumably
    # append_maybe calls it only if it decides to keep the member -- verify.
    self.ensemble.append_maybe(self.model.get_weights)
예제 #3
0
    def test_cyclical_rate(self, period_length, schedule, min_value):
        """Check the shape of one full cycle of ``sgmcmc.cyclical_rate``.

        For the given parameterization this asserts that:
          * step 1 has rate 1.0 and is not flagged as end of period;
          * the rate never increases (beyond a 1e-6 tolerance) between
            consecutive steps, up to and including step ``period_length``;
          * no step before ``period_length`` is flagged as end of period;
          * every step from 2 through ``period_length`` respects
            ``min_value``;
          * step ``period_length`` is flagged as end of period;
          * step 0 is rejected with ``ValueError``.

        Args:
          period_length: number of steps per cycle. Must be >= 2 for this
            test, since step 1 is asserted not to be an end step while step
            ``period_length`` is asserted to be one.
          schedule: schedule name forwarded to ``cyclical_rate``.
          min_value: lower bound forwarded to ``cyclical_rate``.
        """
        rate_begin, end_of_period = sgmcmc.cyclical_rate(period_length,
                                                         1,
                                                         schedule=schedule,
                                                         min_value=min_value)
        self.assertAlmostEqual(
            rate_begin,
            1.0,
            delta=1.0e-6,
            msg='Cyclical learning rate at beginning of period '
            'not equals to one.')
        self.assertEqual(bool(end_of_period),
                         False,
                         msg='Beginning marked as end.')

        # Walk every consecutive pair (step - 1, step) in the period. The
        # range includes period_length itself so the last step of the cycle
        # is also checked for non-increase and for min_value; stopping at
        # period_length - 1 would skip it (and for period_length == 2 would
        # skip the loop entirely).
        for step in range(2, period_length + 1):
            rate_prev, end_of_period = sgmcmc.cyclical_rate(
                period_length, step - 1, schedule=schedule,
                min_value=min_value)
            rate_cur, _ = sgmcmc.cyclical_rate(period_length,
                                               step,
                                               schedule=schedule,
                                               min_value=min_value)
            self.assertLess(
                rate_cur,
                rate_prev + 1.0e-6,
                msg='Cyclical rate increasing from %.5f to %.5f in '
                'period %d to %d' % (rate_prev, rate_cur, step - 1, step))
            # step - 1 ranges over 1 .. period_length - 1: never an end step.
            self.assertFalse(
                bool(end_of_period),
                msg='End of period in the middle of period at index %d '
                'with rate %.5f' % (step - 1, rate_prev))
            self.assertGreaterEqual(rate_cur,
                                    min_value,
                                    msg='Minimum value of %.5f not obeyed by '
                                    'rate %.5f' % (min_value, rate_cur))

        _, end_of_period = sgmcmc.cyclical_rate(period_length,
                                                period_length,
                                                schedule=schedule,
                                                min_value=min_value)
        self.assertTrue(bool(end_of_period), msg='End of period not detected')

        # Steps are 1-based; index 0 must be rejected.
        with self.assertRaises(ValueError):
            sgmcmc.cyclical_rate(period_length, 0)