def train_epoch_side_effect(net, train_loader, clean_val_loader, triggered_val_loader, epoch,
                            progress_bar_disable=True, use_amp=False):
    """Stand-in for a per-epoch training step that returns scripted statistics.

    Only the validation loss varies with ``epoch``; all other statistics are
    constant placeholders. ``net``, the three loaders, ``progress_bar_disable``
    and ``use_amp`` are accepted only to match the caller's signature and are
    ignored.

    Validation-loss schedule produced by this function:

    * epochs 0 and 1: ``10.0 - epoch`` -> 10.0, 9.0 (strictly decreasing);
    * epoch >= 2: ``float(epoch)`` -> 2.0, 3.0, 4.0, ...

    So the lowest loss is reached at epoch 2, and from epoch 3 onward the loss
    rises by 1.0 per epoch and never improves again.

    :return: ``(EpochTrainStatistics, EpochValidationStatistics)``; the two
        trailing fields of the validation statistics are left as ``None``.
    """
    # Training statistics are not consulted by the early-stopping logic under
    # test, so constant placeholders suffice.
    train_acc_noop = 1.0
    train_loss_noop = 1.0
    ts = EpochTrainStatistics(train_acc_noop, train_loss_noop)

    val_acc_noop = 1.0
    if epoch < 2:
        # 10.0, 9.0: the loss keeps decreasing, so early stopping cannot fire yet.
        val_loss = 10.0 - epoch
    else:
        # 2.0 at epoch 2 (the minimum of the schedule), then 3.0, 4.0, ...:
        # no further improvement. The original note expected training to stop
        # about 5 epochs after this point; the exact epoch depends on the
        # patience configured by the caller.
        val_loss = float(epoch)
    vs = EpochValidationStatistics(val_acc_noop, val_loss, None, None)
    return ts, vs
def train_epoch_side_effect(net, train_loader, val_loader, epoch, progress_bar_disable=True):
    """Stand-in for a per-epoch training step that returns scripted statistics.

    Only the validation loss varies with ``epoch``; all other statistics are
    constant placeholders. ``net``, the loaders and ``progress_bar_disable``
    are accepted only to match the caller's signature and are ignored.

    Validation-loss schedule produced by this function:

    * epochs 0 and 1: ``10.0 - epoch`` -> 10.0, 9.0 (strictly decreasing);
    * epoch >= 2: ``9.0 - eps`` on every call, i.e. only ``eps`` below the
      epoch-1 value, and then unchanged from one epoch to the next.

    ``eps`` is neither a parameter nor a local: it is looked up in an
    enclosing/module scope at call time and must be defined there.
    NOTE(review): presumably the early-stopping tolerance set by the
    surrounding test; confirm.

    :return: ``(EpochTrainStatistics, EpochValidationStatistics)``.
    """
    # Training statistics are not consulted by the early-stopping logic under
    # test, so constant placeholders suffice.
    train_acc_noop = 1.0
    train_loss_noop = 1.0
    ts = EpochTrainStatistics(train_acc_noop, train_loss_noop)

    val_acc_noop = 1.0
    if epoch < 2:
        # 10.0, 9.0: the loss keeps decreasing, so early stopping cannot fire yet.
        val_loss = 10.0 - epoch
    else:
        # Improve on epoch 1 by only eps, then plateau, so training is expected to stop.
        val_loss = 9.0 - eps
    vs = EpochValidationStatistics(val_acc_noop, val_loss)
    return ts, vs