Ejemplo n.º 1
0
    def test_patience(self):
        """Check that `patience` postpones stopping on a metric that never changes.

        Both trackers see the constant metric 1. on every step. The tracker
        with patience=0 is expected to request a stop on step 1, and the one
        with patience=6 on step 7.
        """
        impatient = early_stopping.EarlyStopping(min_delta=0, patience=0)
        patient = early_stopping.EarlyStopping(min_delta=0, patience=6)

        def last_step(tracker):
            """Feed 1. repeatedly; return the step on which the loop ended."""
            for step in range(10):
                _, tracker = tracker.update(1.)
                if tracker.should_stop:
                    break
            return step

        self.assertEqual(last_step(impatient), 1)
        self.assertEqual(last_step(patient), 7)
Ejemplo n.º 2
0
  def test_delta(self):
    """Check that min_delta filters out improvements that are too small.

    Three trackers are exercised:
      * min_delta=0, patience=0 on a metric falling by 1e-4 per step is
        expected to end on the final step (99).
      * min_delta=1e-3, patience=0 on the same metric is expected to stop
        on step 1.
      * min_delta=1e-3, patience=1 on a hand-picked decreasing sequence is
        expected to report 4 improvements and stop on step 6.
    """
    es = early_stopping.EarlyStopping(min_delta=0, patience=0)
    delta_es = early_stopping.EarlyStopping(min_delta=1e-3, patience=0)
    delta_patient_es = early_stopping.EarlyStopping(min_delta=1e-3,
                                                    patience=1)

    def last_step_on_slow_decline(tracker):
      """Lower the metric by 1e-4 per step; return the step where the loop ended."""
      value = 1.
      for step in range(100):
        # Repeated subtraction (not a closed form) keeps the exact float values.
        value -= 1e-4
        _, tracker = tracker.update(value)
        if tracker.should_stop:
          break
      return step

    self.assertEqual(last_step_on_slow_decline(es), 99)
    self.assertEqual(last_step_on_slow_decline(delta_es), 1)

    metrics = [0.01, 0.005, 0.0033, 0.0025, 0.002,
               0.0017, 0.0014, 0.0012, 0.0011, 0.001]
    improvements = 0
    for step, value in enumerate(metrics):
      did_improve, delta_patient_es = delta_patient_es.update(value)
      if did_improve:
        improvements += 1
      if delta_patient_es.should_stop:
        break

    self.assertEqual(improvements, 4)  # steps 0, 1, 2, 4
    self.assertEqual(step, 6)
Ejemplo n.º 3
0
    def test_update(self):
        """Check update() bookkeeping on a constant metric, and reuse after reset().

        With patience=0 and a metric that never changes, each run is expected
        to see exactly one improving update and exactly one non-improving
        update, stopping on step 1. The scenario runs twice with reset() in
        between, so a failure on the second pass means reset() did not
        restore the tracker to a reusable state.
        """
        es = early_stopping.EarlyStopping(min_delta=0, patience=0)

        for _ in range(2):
            # Count improving and non-improving updates separately so each
            # counter's name matches what it actually counts.
            improved_steps = 0
            stalled_steps = 0
            for step in range(10):
                did_improve, es = es.update(1.)
                if did_improve:
                    improved_steps += 1
                else:
                    stalled_steps += 1
                if es.should_stop:
                    break

            self.assertEqual(improved_steps, 1)
            self.assertEqual(stalled_steps, 1)
            self.assertEqual(step, 1)

            es = es.reset()  # ensure object is reusable if reset.