def train_epoch_side_effect(net,
                            train_loader,
                            clean_val_loader,
                            triggered_val_loader,
                            epoch,
                            progress_bar_disable=True,
                            use_amp=False):
    """Fake one training epoch with a scripted validation-loss curve.

    Intended for exercising early-stopping logic: only ``epoch`` influences
    the result. ``net``, ``train_loader``, ``clean_val_loader``,
    ``triggered_val_loader``, ``progress_bar_disable`` and ``use_amp`` are
    accepted so the signature matches the training function this stands in
    for (presumably wired in as a mock ``side_effect`` -- confirm with the
    caller), but they are ignored.

    Validation-loss schedule:
      * epoch 0 -> 10.0, epoch 1 -> 9.0: strictly decreasing, so early
        stopping cannot trigger during these two epochs;
      * epoch >= 2 -> float(epoch), i.e. 2.0, 3.0, 4.0, ...: the best loss is
        reached at epoch 2 and every later epoch is worse, so early stopping
        is expected to fire once the caller's patience runs out. (The original
        comment anticipated stopping after about 5 epochs; that depends on
        the early-stopping settings chosen by the caller.)

    Returns:
        A ``(EpochTrainStatistics, EpochValidationStatistics)`` pair. The
        accuracy and training-loss values are the constant 1.0, and the two
        trailing validation-statistics arguments are passed as ``None``.
    """
    # Accuracy and training loss are not consulted by the early-stopping
    # code, so constants are sufficient.
    train_acc_noop = 1.0
    train_loss_noop = 1.0
    ts = EpochTrainStatistics(train_acc_noop, train_loss_noop)
    val_acc_noop = 1.0
    if epoch < 2:
        # Epochs 0 and 1: the loss keeps decreasing, which prevents early
        # stopping from activating.
        val_loss = 10.0 - epoch
    else:
        # Epoch 2 onward: the loss rises by 1.0 each epoch, so there is no
        # further improvement after epoch 2.
        val_loss = float(epoch)
    vs = EpochValidationStatistics(val_acc_noop, val_loss, None, None)
    return ts, vs
# Example 2 (score shown in the source listing: 0)
 def train_epoch_side_effect(net,
                             train_loader,
                             val_loader,
                             epoch,
                             progress_bar_disable=True):
     """Fake one training epoch whose validation loss plateaus after epoch 1.

     Intended for exercising early-stopping logic: only ``epoch`` influences
     the result. ``net``, ``train_loader``, ``val_loader`` and
     ``progress_bar_disable`` are accepted so the signature matches the
     training function this stands in for (presumably wired in as a mock
     ``side_effect`` -- confirm with the caller), but they are ignored.

     Validation-loss schedule:
       * epoch 0 -> 10.0, epoch 1 -> 9.0: strictly decreasing, so early
         stopping cannot trigger during these two epochs;
       * epoch >= 2 -> 9.0 - eps, constant: an improvement of exactly ``eps``
         over epoch 1, then no change at all.

     ``eps`` is not defined in this function. It is looked up in an
     enclosing or module scope. NOTE(review): presumably it equals the
     early-stopping tolerance, so the ``eps`` improvement is too small to
     count and training stops. Confirm against the caller.

     Returns:
         A ``(EpochTrainStatistics, EpochValidationStatistics)`` pair. The
         accuracy and training-loss values are the constant 1.0.
     """
     # Accuracy and training loss are not consulted by the early-stopping
     # code, so constants are sufficient.
     train_acc_noop = 1.0
     train_loss_noop = 1.0
     ts = EpochTrainStatistics(train_acc_noop, train_loss_noop)
     val_acc_noop = 1.0
     if epoch < 2:
         # Epochs 0 and 1: the loss keeps decreasing, which prevents early
         # stopping from activating.
         val_loss = 10.0 - epoch
     else:
         # Epoch 2 onward: improve on epoch 1's 9.0 by only eps, then stay flat.
         val_loss = 9.0 - eps
     vs = EpochValidationStatistics(val_acc_noop, val_loss)
     return ts, vs