Example 1
	def NONBUSICAL(self):
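		# Add 2010-12-14 to the 'DECEMBER' calendar, schedule a job on
		# 2010-12-13 against a calendar-based schedule, and check that
		# resch() moves it to the calendar day with its time of day intact.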
		
		session = self.sess

		date = datetime(2010, 12, 14)
		new_busiday = calendar(CALID = 'DECEMBER', DATE=date)
		session.add(new_busiday)
		session.flush()

		new_sch = schedule(SCHEDID = '1 NONBUSIDAY DEC', INTERVAL = 1, METHOD = 0, AFTER=0, DELTA=3, WAIT=0, CALID='DECEMBER', CALTYPE=1)
		session.add(new_sch)
		session.flush()
		
		new_action = action(ACTIONID = 'TESTACTIONDEC', USERID = self.USERID)
		session.add(new_action)
		session.flush()
		
		date = datetime(2010, 12, 13, 1, 1, 1)
		new_job = job(ACTIONID = 'TESTACTIONDEC', SCHEDID = '1 NONBUSIDAY DEC', SCHDATE=date, STATUS = '2')
		session.add(new_job)
		session.flush()

		new_job.resch(None)
	
		self.assert_(new_job.SCHDATE.strftime("%Y") == '2010', "NONBUSICAL test Invalid Year")
		self.assert_(new_job.SCHDATE.strftime("%d") == '14', "NONBUSICAL test Invalid Day")
		self.assert_(new_job.SCHDATE.strftime("%m") == '12', "NONBUSICAL test Invalid Month")
		self.assert_(new_job.SCHDATE.strftime("%H") == '01', "NONBUSICAL test Invalid Hour")
		self.assert_(new_job.SCHDATE.strftime("%M") == '01', "NONBUSICAL test Invalid Minute")
		self.assert_(new_job.SCHDATE.strftime("%S") == '01', "NONBUSICAL test Invalid Second")
Example 2
	def testQuickCreate(self):
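		# quick_create_event should persist an event whose columns record
		# the IDs of every related object handed to it.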
		quick_create_event(1,
			event_server = server(SERVERID = "TEST"),
			event_schedule = schedule(SCHEDID = "SCHED1"),
			event_job = job(JOBID = 15),
			event_action = action(ACTIONID = "ACT1"),
			event_user = user(USERID = "Dan"))
		self.assert_(getsession().query(event).filter(event.EVENTNUM == 1).count())
		e = getsession().query(event).filter(event.EVENTNUM == 1).one()
		self.assert_(e.SERVERID == "TEST")
		self.assert_(e.SCHEDID == "SCHED1")
		self.assert_(e.JOBID == 15)
		self.assert_(e.ACTIONID == "ACT1")
		self.assert_(e.USERID == "Dan")
Example 3
	def TwoHourlySchedPast(self):
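		# A job dated far in the past on a 2-hour interval schedule should
		# be rescheduled about two hours from now, landing on an even-hour
		# boundary with minutes and seconds zeroed.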
		new_sch = schedule(SCHEDID = '2 HOUR TEST', INTERVAL = 2, METHOD = 0, AFTER = 0, DELTA = 2)		
		self.sess.add(new_sch)
		self.sess.flush()

		new_action = action(ACTIONID = 'UNITTEST0013', USERID = self.USERID)

		self.sess.add(new_action)
		self.sess.flush()

		new_date = datetime(1969, 6, 20, 8, 00, 00)

		resch_date = datetime.now()

		resch_date = resch_date + timedelta(hours=2)

		new_job = job(ACTIONID = 'UNITTEST0013', SCHEDID = '2 HOUR TEST', STATUS = 0, SCHDATE = new_date)	

		self.sess.add(new_job)
		self.sess.flush()	

		self.assert_(new_job.resch(None) == True, "Error rescheduling job with time and 2-hourly schedule.")
	
		self.assert_(new_job.SCHDATE.strftime("%m") == resch_date.strftime("%m"), "INVALID MONTH FOR HourlySCHEDPAST")
		self.assert_(new_job.SCHDATE.strftime("%d") == resch_date.strftime("%d"), "INVALID DAY FOR  HourlySCHEDPAST")	
		self.assert_(new_job.SCHDATE.strftime("%Y") == resch_date.strftime("%Y"), "INVALID YEAR FOR  HourlSCHEDPAST")
		
		# The schedule fires on even hours, so round the expected hour down
		# to the nearest multiple of 2 instead of enumerating every odd case.
		hour_check = "%02d" % (int(resch_date.strftime("%H")) // 2 * 2)

		self.assert_(new_job.SCHDATE.strftime("%H") == hour_check, "INVALID H TIME FOR  HourlySCHEDPAST")
		self.assert_(new_job.SCHDATE.strftime("%M") == '00', "INVALID M TIME FOR DAILYSCHEDPAST")
		self.assert_(new_job.SCHDATE.strftime("%S") == '00', "INVALID S TIME FOR DAILYSCHEDPAST")
Example 4
	def testDescription(self):
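		# event_text should substitute each %PLACEHOLDER% token with the
		# matching attribute of the object passed for that event role.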
		u = user(USERID = "DAN", EMAIL="*****@*****.**")
		edef = eventdef()
		desc = edef.event_text("%USER% has email %EMAIL%", event_user = u)
		self.assert_(u.USERID in desc)
		self.assert_(u.EMAIL in desc)

		s = server(SERVERID = "testserver", IP = "192.168.1.1", PORT = 2200, OS = 1, TYPE = 1)
		desc = edef.event_text("%SERVER% has ip %IP% port %PORT% os %SERVEROS% ", event_server = s, event_user = u)
		self.assert_(s.SERVERID in desc)
		self.assert_(s.IP in desc)
		self.assert_(str(s.PORT) in desc)
		self.assert_(s.os_to_string() in desc)

		q = job(JOBID = 14, SCHDATE = datetime.now(), BATCH = "batch1", IGNOREDEP = False, REMAININGRUNS = 15, MAXRUNS = 10)
		desc = edef.event_text("%JOB% %SCHDATE% %BATCH% %IGNOREDEPS% %REMAININGRUNS% %MAXRUNS%", event_job = q)
		self.assert_(str(q.JOBID) in desc)
		self.assert_(str(q.SCHDATE) in desc)
		self.assert_(q.BATCH in desc)
		self.assert_(str(q.IGNOREDEP) in desc)
		self.assert_(str(q.REMAININGRUNS) in desc)
		self.assert_(str(q.MAXRUNS) in desc)

		a = action(ACTIONID = "Action12", CMDLINE = "dir", STDIN = True, STARTINDIR = "/etc", STDOUT = False, OS = 1, USERID = "JOHN")
		desc = edef.event_text(" %ACTION% %CMDLINE% %STDIN% %STARTINDIR% %CAPTURESTDOUT% %ACTIONOS% %ACTIONUSER%", event_action = a)
		self.assert_(a.CMDLINE in desc)
		self.assert_(a.ACTIONID in desc)
		self.assert_(str(a.STDIN) in desc)
		self.assert_(a.STARTINDIR in desc)
		self.assert_(str(a.STDOUT) in desc)
		self.assert_(a.os_to_string() in desc)
		self.assert_(a.USERID in desc)

		s = schedule(SCHEDID = "sched1", INTERVAL = 1, METHOD = 2, AFTER = 0, DELTA = 3, CALID = "cal3", CALTYPE = 2)
		desc = edef.event_text("%SCHEDID% %INTERVAL% %METHOD% %RESCHEDULEON% %DELTA% %CALENDAR% %CALTYPE% %ASCHEDULEMAXRUNS% %ABSRESCHEDULE%", event_schedule = s)
		self.assert_(s.SCHEDID in desc)
		self.assert_(str(s.INTERVAL) in desc)
		self.assert_(s.method_str() in desc)
		self.assert_(s.resch_str() in desc)
		self.assert_(s.type_str() in desc)
		self.assert_(s.CALID in desc)
		self.assert_(s.caltype_str() in desc)
		self.assert_(str(s.MAXRUNS) in desc)
		self.assert_(str(s.WHEN) in desc)

		desc = edef.event_text("%CONTROLLER% for %MESSAGE%", event_message="A generic message")
		self.assert_(myname in desc)
		self.assert_("A generic message" in desc)
Example 5
	def DOWNODAYS(self):
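		# With every day-of-week flag disabled there is no valid day to
		# reschedule to, so SCHDATE should stay unchanged.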
		sess = getsession()
	
		nowdate = datetime.now()

		new_action = action(ACTIONID = "DOWNOACTION", CMDLINE = "echo no", USERID = self.USERID)
		sess.add(new_action)
		sess.flush()

		new_sch = schedule(SCHEDID = "NODOW", METHOD = 2, MON = False, TUE = False, WED = False, THU = False, FRI = False, SAT = False, SUN = False)
		sess.add(new_sch)
		sess.flush()

		new_job = job(ACTIONID = "DOWNOACTION", SCHEDID = "NODOW", STATUS = 0, SCHDATE = nowdate)
		sess.add(new_job)
		sess.flush()

		new_job.resch()

		self.assert_(new_job.SCHDATE == nowdate, "SCHDATE changed although no day of week is enabled.")
Example 6
	def test_remainingruns(self):
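		# Five reschedules should exhaust a REMAININGRUNS counter that
		# starts at 5.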
		sess = getsession()

		new_action = action(ACTIONID = "DEETZ", CMDLINE = "echo no", USERID = self.USERID)
		sess.add(new_action)
		sess.flush()

		new_sch = schedule(SCHEDID = 'HOUR WITH REMAINING', INTERVAL = 1, DELTA = 2, REMAININGRUNS = 5)
		sess.add(new_sch)
		sess.flush()

		new_job = job(ACTIONID = "DEETZ", SCHEDID = "HOUR WITH REMAINING", STATUS = 0, SCHDATE = datetime.now(), REMAININGRUNS = 5)
		sess.add(new_job)
		sess.flush()

		for _ in range(5):
			new_job.resch()

		self.assert_(new_job.REMAININGRUNS == 0, "Remaining runs not 0")
Example 7
from schedules import schedule


@property
def schedule(self):
    return schedules.schedule(datetime.timedelta(**{self.period: self.every}))

@classmethod
def from_schedule(cls, dbsession, schedule, period='seconds'):
    every = max(schedule.run_every.total_seconds(), 0)
    try:
        query = dbsession.query(IntervalSchedule)
        query = query.filter_by(every=every, period=period)
        existing = query.one()
        return existing
    except NoResultFound:
        return cls(every=every, period=period)
    except MultipleResultsFound:
        # Duplicate schedule rows: remove them all and start over.
        query = dbsession.query(IntervalSchedule)
        query = query.filter_by(every=every, period=period)
        query.delete()
        dbsession.commit()
        return cls(every=every, period=period)
Example 9
def train_good(model, train_loader_in, train_loader_out, val_loader_in,
               val_loader_out_list, device, expfolder, args):
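    """Train model on in-distribution data while penalizing confidence on
    OOD data with the objective chosen by args.method ('plain', 'OE',
    'CEDA' or 'GOOD'), optionally adding an ACET adversarial term."""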
    train_out_name = train_loader_out.dataset.__repr__().split()[1]
    train_in_name = train_loader_in.dataset.__repr__().split()[1]
    print(model.layers)
    starttime = datetime.datetime.utcnow()

    schedule = schedules.schedule(args)

    print(f'schedule: {schedule}')
    model_folder = names.model_folder_name(expfolder, starttime, args,
                                           schedule)
    for subfolder in ['state_dicts', 'sample_images', 'logs', 'batch_images']:
        os.makedirs(f'{model_folder}/{subfolder}/', exist_ok=True)
    tb_subfolder = f'tb_logs/{args.tb_folder}/{model_folder}'
    os.makedirs(tb_subfolder, exist_ok=True)
    writer = SummaryWriter(tb_subfolder)
    print(f'model folder: {model_folder}')
    print(f'tb_subfolder: {tb_subfolder}')

    trainstart_message = f'Training {model.__name__} for {schedule["epochs"]} epochs of {2*min(len(train_loader_in.dataset), len(train_loader_out.dataset))} samples.'
    print(trainstart_message)

    if schedule['optimizer'] == 'SGDM':
        optimizer = optim.SGD(model.parameters(),
                              lr=schedule['start_lr'],
                              weight_decay=0.05,
                              momentum=0.9,
                              nesterov=True)
    elif schedule['optimizer'] == 'ADAM':
        optimizer = optim.Adam(model.parameters(),
                               lr=schedule['start_lr'],
                               weight_decay=0.005)
    else:
        raise ValueError(
            f'Optimizer {schedule["optimizer"]} not supported. Must be SGDM or ADAM.'
        )
    print(f'Optimizer settings: {optimizer}')
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        schedule['lr_decay_epochs'],
        gamma=schedule['lr_decay_factor'],
        last_epoch=-1)

    num_classes = model.num_classes

    for epoch in range(schedule['epochs']):
        #initialize epoch summary
        in_samples_this_epoch = 0
        ce_losses_in_epoch = []
        log_confs_in_epoch = []
        corrects_in_this_epoch = 0

        out_samples_this_epoch = 0
        above_quantile_losses_out_epoch, below_quantile_losses_out_epoch, full_good_losses_out_epoch, zero_good_losses_out_epoch, good_losses_out_epoch = [], [], [], [], []
        oe_losses_out_epoch, ceda_losses_out_epoch = [], []
        losses_out_epoch, kappa_losses_out_epoch = [], []
        log_confs_out_epoch, ub_log_confs_out_epoch = [], []
        acet_losses_out_epoch = []

        losses_epoch = []

        #hyperparameters for this epoch
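        # kappa (OOD loss weight) and eps (box radius) either switch on at
        # their start epoch or ramp up linearly over *_epoch_ramp epochs.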
        if schedule['kappa_epoch_ramp'] == 0:
            kappa_epoch = schedule['kappa'] * (epoch >=
                                               schedule['out_start_epoch'])
        else:
            kappa_epoch = schedule['kappa'] * min(
                max(epoch - schedule['out_start_epoch'], 0) /
                schedule['kappa_epoch_ramp'], 1)
        if schedule['eps_epoch_ramp'] == 0:
            eps_epoch = schedule['eps'] * (epoch >=
                                           schedule['eps_start_epoch'])
        else:
            eps_epoch = schedule['eps'] * min(
                (max(epoch - schedule['eps_start_epoch'], 0) /
                 schedule['eps_epoch_ramp']), 1)

        #if acet is turned on, it will be used in addition to the args.method
        if args.acet:
            if args.acet == 'ce':
                acet_lossfn = lossfunctions.CrossEntropyLossDistr
            elif args.acet == 'lc':
                acet_lossfn = lossfunctions.LogConf
            pgd = attacks.LinfPGDAttack(epsilon=eps_epoch,
                                        n=schedule['acet_n'],
                                        loss_fn=acet_lossfn,
                                        random_start=False,
                                        device=device)
        do_acet_epoch = args.acet and kappa_epoch > 0

        model.train()
        for batch_number, data in enumerate(
                zip(train_loader_in, train_loader_out), 0):
            img_batch_parts, lbl_batch_parts = [
                d[0].to(device) for d in data
            ], [d[1].to(device) for d in data]
            img_batch_in = img_batch_parts[0].to(device)
            img_batch_out = img_batch_parts[1].to(device)
            lbl_batch_in = lbl_batch_parts[0].to(device)
            lbl_batch_in_1hot = F.one_hot(lbl_batch_in, num_classes).float()
            lbl_batch_out = 1 / num_classes * torch.ones(
                lbl_batch_parts[1].size() + (num_classes, ),
                dtype=lbl_batch_parts[1].dtype).to(
                    device)  #set uniform label as it represents optimum

            batch_size_in = len(img_batch_in)
            batch_size_out = len(img_batch_out)

            in_samples_this_epoch += batch_size_in
            out_samples_this_epoch += batch_size_out

            #save example batch
            if epoch == 0 and batch_number == 0:
                vutils.save_image(img_batch_in,
                                  model_folder + '/batch_images/in_batch0.png')
                vutils.save_image(
                    img_batch_out,
                    model_folder + '/batch_images/out_batch0.png')

            optimizer.zero_grad()  #resets the calculated gradients

            logit_batch_in = model(img_batch_in)

            ce_loss_in = ce_loss(logit_batch_in, lbl_batch_in)
            ce_losses_in_epoch.append(
                ce_loss_in.detach().cpu().numpy())  #tracking
            p_in = logit_batch_in.softmax(dim=-1)  #tracking
            _, predicted_class_in = logit_batch_in.max(dim=-1)  #tracking
            corrects_in_this_epoch += predicted_class_in.eq(
                lbl_batch_in).sum().item()  #tracking

            if do_acet_epoch:
                if eps_epoch > 0:
                    adv_batch_out, _ = pgd.perturbt(img_batch_out,
                                                    lbl_batch_out, model)
                    model.train(
                    )  #to make sure it isn't set to eval after the attack
                else:
                    adv_batch_out = img_batch_out
                logit_adv_batch_out = model(adv_batch_out)
                acet_losses_indiv = acet_lossfn(logit_adv_batch_out,
                                                lbl_batch_out)
                log_conf_adv_out = logit_adv_batch_out.log_softmax(dim=-1).max(
                    dim=-1)[0]
                acet_loss_out = acet_losses_indiv.sum()
                acet_losses_out_epoch.append(
                    acet_loss_out.detach().cpu().numpy())  #tracking

            #calculate losses on the OOD inputs
            if args.method in {'OE', 'CEDA'}:
                logit_batch_out = model(img_batch_out)
                log_conf_out_batch = logit_batch_out.log_softmax(dim=-1).max(
                    dim=-1)[0]

                ceda_loss_out = log_conf_out_batch.sum()
                log_pred_out = logit_batch_out.log_softmax(dim=-1)
                oe_loss_out = -(log_pred_out / num_classes).sum()
                ceda_losses_out_epoch.append(
                    ceda_loss_out.detach().cpu().numpy())  #tracking
                oe_losses_out_epoch.append(
                    oe_loss_out.detach().cpu().numpy())  #tracking
                log_confs_out_epoch.append(
                    log_conf_out_batch.detach().cpu().numpy())

            if args.method == 'GOOD':
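                #GOOD: compute a guaranteed upper bound on the confidence over
                #an eps-box around each OOD input via IBP with elision; the
                #bound is penalized only on the good_quantile fraction of the
                #batch with the lowest bounds, while the rest get the plain
                #logit-difference loss.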
                l_logits_batch_out, u_logits_batch_out, ud_logit_out_batch = model.ibp_elision_forward(
                    img_batch_out - eps_epoch, img_batch_out + eps_epoch,
                    num_classes)
                ub_log_conf_out_batch = ud_logit_out_batch.max(dim=-1)[0].max(
                    dim=-1)[0]
                ub_conf_out_batch = ub_log_conf_out_batch.exp() / num_classes
                logit_batch_out = model(img_batch_out)
                logit_diff = logit_batch_out.max(
                    dim=-1
                )[0] - logit_batch_out.min(
                    dim=-1
                )[0]  #equals ud_logit_out_batch.max(dim=-1)[0].max(dim=-1)[0] for eps=0, but only needs 1 pass.
                l = math.floor(batch_size_out * args.good_quantile)
                h = batch_size_out - l
                above_quantile_indices = ub_log_conf_out_batch.topk(
                    h, largest=True)[
                        1]  #above or exactly at quantile, i.e. 'not below'.
                below_quantile_indices = ub_log_conf_out_batch.topk(
                    l, largest=False)[1]

                above_quantile_loss_out = (
                    (logit_diff[above_quantile_indices])**2 / 2).log1p().sum()
                below_quantile_loss_out = (
                    (ub_log_conf_out_batch[below_quantile_indices])**2 /
                    2).log1p().sum()
                good_loss_out = above_quantile_loss_out + below_quantile_loss_out

                #for tracking only
                zero_good_loss_out = (logit_diff**2 / 2).log1p().sum()
                full_good_loss_out = (ub_log_conf_out_batch**2 /
                                      2).log1p().sum()
                log_conf_out_batch = logit_batch_out.log_softmax(dim=-1).max(
                    dim=-1)[0]
                ceda_loss_out = log_conf_out_batch.sum()
                log_pred_out = logit_batch_out.log_softmax(dim=-1)
                oe_loss_out = -(log_pred_out / num_classes).sum()

                above_quantile_losses_out_epoch.append(
                    above_quantile_loss_out.detach().cpu().numpy())
                below_quantile_losses_out_epoch.append(
                    below_quantile_loss_out.detach().cpu().numpy())
                good_losses_out_epoch.append(
                    good_loss_out.detach().cpu().numpy())

                zero_good_losses_out_epoch.append(
                    zero_good_loss_out.detach().cpu().numpy())
                full_good_losses_out_epoch.append(
                    full_good_loss_out.detach().cpu().numpy())
                ceda_losses_out_epoch.append(
                    ceda_loss_out.detach().cpu().numpy())
                oe_losses_out_epoch.append(oe_loss_out.detach().cpu().numpy())
                log_confs_out_epoch.append(
                    log_conf_out_batch.detach().cpu().numpy())
                ub_log_confs_out_epoch.append(
                    ub_log_conf_out_batch.detach().cpu().numpy())

                #save example out batch splits
                if epoch % 10 == 0 and batch_number == 0:
                    if len(above_quantile_indices) > 0:
                        vutils.save_image(
                            img_batch_out[above_quantile_indices],
                            model_folder +
                            f'/batch_images/{epoch:3d}batch0_above_quantile.png'
                        )
                    if len(below_quantile_indices) > 0:
                        vutils.save_image(
                            img_batch_out[below_quantile_indices],
                            model_folder +
                            f'/batch_images/{epoch:3d}batch0_below_quantile.png'
                        )

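            #combine the in-distribution CE loss with the selected OOD loss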
            if args.method == 'plain' or epoch < schedule['out_start_epoch']:
                loss_batch = ce_loss_in.clone(
                )  #clone so adding acet to it cannot change ce_loss_in
                loss_name = 'in_ce'
                losses_out_epoch.append(0)
                kappa_losses_out_epoch.append(0)
            elif args.method == 'OE':
                loss_batch = ce_loss_in + kappa_epoch * oe_loss_out
                loss_name = f'in_ce+{kappa_epoch}*oe_loss_out'
                losses_out_epoch.append(oe_loss_out.detach().cpu().numpy())
                kappa_losses_out_epoch.append(
                    kappa_epoch * oe_loss_out.detach().cpu().numpy())
            elif args.method == 'CEDA':
                loss_batch = ce_loss_in + kappa_epoch * ceda_loss_out
                loss_name = f'in_ce+{kappa_epoch}*ceda_loss_out'
                losses_out_epoch.append(ceda_loss_out.detach().cpu().numpy())
                kappa_losses_out_epoch.append(
                    kappa_epoch * ceda_loss_out.detach().cpu().numpy())
            elif args.method == 'GOOD':
                loss_batch = ce_loss_in + kappa_epoch * good_loss_out
                loss_name = f'in_ce + {kappa_epoch}*(above_quantile_loss_out + eps{eps_epoch}below_quantile_loss_out)'
                losses_out_epoch.append(good_loss_out.detach().cpu().numpy())
                kappa_losses_out_epoch.append(
                    kappa_epoch * good_loss_out.detach().cpu().numpy())

            #acet is added on top
            if do_acet_epoch:
                loss_batch += kappa_epoch * acet_loss_out
                loss_name += f'+{kappa_epoch}*out_{eps_epoch}_adv_conf'

            losses_epoch.append(loss_batch.detach().cpu().numpy())  #tracking

            loss_batch.backward(
            )  # backpropagation of the loss. between here and optimizer.step() there should be no computations; only for saving the gradients it makes sense to have code between the two commands.
            optimizer.step()  # updates the parameters of the model

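        #per-epoch summaries, normalized per in/out sample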
        ce_loss_in_epoch = np.sum(ce_losses_in_epoch) / in_samples_this_epoch
        accuracy_epoch = corrects_in_this_epoch / in_samples_this_epoch
        log_conf_in_epoch = np.sum(log_confs_in_epoch) / in_samples_this_epoch
        loss_epoch = np.sum(
            losses_epoch) / in_samples_this_epoch  #per in sample!

        loss_out_epoch = np.sum(losses_out_epoch) / out_samples_this_epoch
        kappa_loss_out_epoch = np.sum(
            kappa_losses_out_epoch) / out_samples_this_epoch

        if args.acet and kappa_epoch > 0:
            acet_loss_out_epoch = np.sum(
                acet_losses_out_epoch) / out_samples_this_epoch

        if args.method in {'OE', 'CEDA'}:
            oe_loss_out_epoch = np.sum(
                oe_losses_out_epoch) / out_samples_this_epoch
            ceda_loss_out_epoch = np.sum(
                ceda_losses_out_epoch) / out_samples_this_epoch
            log_conf_out_epoch = np.sum(
                np.concatenate(log_confs_out_epoch)) / out_samples_this_epoch
            median_log_conf_out_epoch = np.median(
                np.concatenate(log_confs_out_epoch)) / out_samples_this_epoch

        if args.method == 'GOOD':
            above_quantile_loss_out_epoch = np.sum(
                above_quantile_losses_out_epoch) / out_samples_this_epoch
            below_quantile_loss_out_epoch = np.sum(
                below_quantile_losses_out_epoch) / out_samples_this_epoch
            full_good_loss_out_epoch = np.sum(
                full_good_losses_out_epoch) / out_samples_this_epoch
            zero_good_loss_out_epoch = np.sum(
                zero_good_losses_out_epoch) / out_samples_this_epoch
            oe_loss_out_epoch = np.sum(
                oe_losses_out_epoch) / out_samples_this_epoch
            ceda_loss_out_epoch = np.sum(
                ceda_losses_out_epoch) / out_samples_this_epoch
            log_conf_out_epoch = np.sum(
                np.concatenate(log_confs_out_epoch)) / out_samples_this_epoch
            ub_log_conf_out_epoch = np.sum(
                np.concatenate(
                    ub_log_confs_out_epoch)) / out_samples_this_epoch
            median_log_conf_out_epoch = np.median(
                np.concatenate(log_confs_out_epoch)) / out_samples_this_epoch
            median_ub_log_conf_out_epoch = np.median(
                np.concatenate(
                    ub_log_confs_out_epoch)) / out_samples_this_epoch
        s_0 = f'Epoch {epoch} (lr={get_lr(optimizer)}) complete with mean training loss {loss_epoch} (ce_loss_in: {ce_loss_in_epoch}, loss_out:{loss_out_epoch}, used loss:{loss_name}).'
        s_1 = 'Time since start of training: {0}.\n'.format(
            datetime.datetime.utcnow() - starttime)
        print(s_0)
        print(s_1)

        writer.add_scalar('TrainIn/loss_total_per_in', loss_epoch, epoch)
        writer.add_scalar('TrainIn/ce_loss_in', ce_loss_in_epoch, epoch)
        writer.add_scalar('TrainIn/accuracy', accuracy_epoch, epoch)

        writer.add_scalar('TrainOut/loss_out', loss_out_epoch, epoch)
        writer.add_scalar('TrainOut/kappa_loss_out', kappa_loss_out_epoch,
                          epoch)
        if args.acet and kappa_epoch > 0:
            writer.add_scalar('TrainOut/acet_loss_out', acet_loss_out_epoch,
                              epoch)
        if args.method in {'OE', 'CEDA'}:
            writer.add_scalar('TrainOut/oe_loss_out', oe_loss_out_epoch, epoch)
            writer.add_scalar('TrainOut/ceda_loss_out', ceda_loss_out_epoch,
                              epoch)
            writer.add_scalar('TrainOut/log_conf_out', log_conf_out_epoch,
                              epoch)
            writer.add_scalar('TrainOut/median_log_conf_out',
                              median_log_conf_out_epoch, epoch)
            writer.add_histogram('Train_log_conf_out',
                                 np.concatenate(log_confs_out_epoch), epoch)
        if args.method == 'GOOD':
            writer.add_scalar('TrainOut/above_quantile_loss_out',
                              above_quantile_loss_out_epoch, epoch)
            writer.add_scalar('TrainOut/below_quantile_loss_out',
                              below_quantile_loss_out_epoch, epoch)
            writer.add_scalar('TrainOut/full_good_loss_out',
                              full_good_loss_out_epoch, epoch)
            writer.add_scalar('TrainOut/zero_good_loss_out',
                              zero_good_loss_out_epoch, epoch)
            writer.add_scalar('TrainOut/oe_loss_out', oe_loss_out_epoch, epoch)
            writer.add_scalar('TrainOut/ceda_loss_out', ceda_loss_out_epoch,
                              epoch)
            writer.add_scalar('TrainOut/log_conf_out', log_conf_out_epoch,
                              epoch)
            writer.add_scalar('TrainOut/ub_log_conf_out',
                              ub_log_conf_out_epoch, epoch)
            writer.add_scalar('TrainOut/median_log_conf_out',
                              median_log_conf_out_epoch, epoch)
            writer.add_scalar('TrainOut/median_ub_log_conf_out',
                              median_ub_log_conf_out_epoch, epoch)
            writer.add_histogram('Train_log_conf_out',
                                 np.concatenate(log_confs_out_epoch), epoch)
            writer.add_histogram('Train_ub_log_conf_out',
                                 np.concatenate(ub_log_confs_out_epoch), epoch)
        writer.add_scalar('TrainHyPa/eps', eps_epoch, epoch)
        writer.add_scalar('TrainHyPa/kappa', kappa_epoch, epoch)
        writer.add_scalar('TrainHyPa/learning_rate', get_lr(optimizer), epoch)

        do_evaluation = True  #the whole evaluation only takes a few seconds.
        if do_evaluation:
            val_th = 0.3  #for evaluating how many samples get conf values > 30%.
            if train_in_name == 'MNIST':
                val_eps = 0.3  #smaller values can be useful if no guarantees are given for this
            elif train_in_name == 'SVHN' or train_in_name == 'CIFAR10':
                val_eps = 0.01
            else:
                raise ValueError(f'No val_eps defined for dataset {train_in_name}.')

            eval_result_dict = evaluation.evaluate_ibp_lc(model,
                                                          val_loader_in,
                                                          val_loader_out_list,
                                                          eps=val_eps,
                                                          conf_th=val_th,
                                                          device=device,
                                                          n_pgd=0,
                                                          n_samples=1000)

            in_accuracy, pred_in_confidences, pred_in_mean_confidence, pred_in_above_th, number_of_in_datapoints = eval_result_dict[
                val_loader_out_list[0].dataset.__repr__().split()[1]][:5]
            writer.add_scalar('Val/in_accuracy', in_accuracy, epoch)
            writer.add_scalar('Val/mean_confidence', pred_in_mean_confidence,
                              epoch)
            writer.add_scalar('Val/confidences_above_{0:.2f}'.format(val_th),
                              pred_in_above_th / number_of_in_datapoints,
                              epoch)
            writer.add_scalar('Val/eps', val_eps, epoch)

            writer.add_histogram('Val/pred_in_confidences',
                                 pred_in_confidences, epoch)

            for val_loader_out in val_loader_out_list:
                out_name = val_loader_out.dataset.__repr__().split()[1]
                in_accuracy, pred_in_confidences, pred_in_mean_confidence, pred_in_above_th, number_of_in_datapoints, pred_out_confidences, pred_out_mean_confidence, pred_out_above_th, number_of_out_datapoints, ub_el_out_confidences, ub_elision_mean_out_confidence, ub_elision_median_out_confidence, ub_elision_out_below_th, auroc_from_predictions, auroc_out_guaranteed_softmax_elision, auroc_from_predictions_conservative, auroc_out_guaranteed_softmax_elision_conservative, pred_adv_out_confidences, adversarial_pred_out_mean_confidence, adversarial_pred_out_median_confidence, adversarial_pred_out_above_th = eval_result_dict[
                    out_name]

                writer.add_scalar('Val{0}/mean_confidence'.format(out_name),
                                  pred_out_mean_confidence, epoch)
                writer.add_scalar('Val{0}/mean_ub_confidence'.format(out_name),
                                  ub_elision_mean_out_confidence, epoch)
                writer.add_scalar(
                    'Val{0}/median_ub_confidence'.format(out_name),
                    ub_elision_median_out_confidence, epoch)
                writer.add_scalar(
                    'Val{0}/confidences_above_{1:.2f}'.format(
                        out_name, val_th),
                    pred_out_above_th / number_of_out_datapoints, epoch)
                writer.add_scalar(
                    'Val{0}/ub_confidences_below_{1:.2f}'.format(
                        out_name, val_th),
                    ub_elision_out_below_th / number_of_out_datapoints, epoch)
                writer.add_scalar('Val{0}/AUC'.format(out_name),
                                  auroc_from_predictions, epoch)
                writer.add_scalar('Val{0}/GAUC'.format(out_name),
                                  auroc_out_guaranteed_softmax_elision, epoch)
                writer.add_scalar('Val{0}/cAUC'.format(out_name),
                                  auroc_from_predictions_conservative, epoch)
                writer.add_scalar(
                    'Val{0}/cGAUC'.format(out_name),
                    auroc_out_guaranteed_softmax_elision_conservative, epoch)

                writer.add_histogram('Val{0}/confidences'.format(out_name),
                                     pred_out_confidences, epoch)
                writer.add_histogram('Val{0}/ub_confidences'.format(out_name),
                                     ub_el_out_confidences, epoch)

        lr_scheduler.step()
        if epoch % 50 == 0 or epoch == 103 or epoch == 105:
            save_filename = model_folder + '/state_dicts/{0:03d}.pt'.format(
                epoch)
            torch.save(model.state_dict(), save_filename)
        torch.cuda.empty_cache()
        del data, img_batch_parts, lbl_batch_parts
        if hasattr(train_loader_out, 'reopen_data_file'):
            #without this, loading 80M Tiny Images makes training much slower from epoch 2 onward.
            train_loader_out.reopen_data_file()
    stoptime = datetime.datetime.utcnow()
    dt = stoptime - starttime
    save_filename = model_folder + '/state_dicts/' + str(epoch) + 'fin.pt'
    torch.save(model.state_dict(), save_filename)
    print('Training finished after {0} seconds'.format(dt))
    writer.close()
    return model_folder