Code Example #1
File: experiment.py  Project: makslevental/ferit_nets
def exp3():
    tuf_table_file_name = 'all_maxs.csv'
    all_alarms = tuf_table_csv_to_df(
        os.path.join(PROJECT_ROOT, "csvs", tuf_table_file_name))
    australia_alarms = tuf_table_csv_to_df(
        os.path.join(PROJECT_ROOT, "csvs", "australia.csv"))

    m1, s1 = torch.serialization.load(os.path.join(PROJECT_ROOT, 'means.pt')), \
             torch.serialization.load(os.path.join(PROJECT_ROOT, 'stds.pt'))

    australia_ad = AlarmDataset(australia_alarms,
                                DATA_ROOT,
                                transform=transforms.Compose(
                                    [Normalize(m1, s1)]))
    australia_adl = DataLoader(australia_ad,
                               BATCH_SIZE,
                               shuffle=False,
                               num_workers=4)

    f = open(os.path.join(LOGS_PATH, "loss.csv"), "w+")

    nets = defaultlist(lambda: torch.nn.DataParallel(GPR_15_300()))
    for overtrain in range(100):
        for i, (strat_splits, _rgn_train, rgn_holdout) in enumerate(
                region_and_stratified(all_alarms,
                                      n_splits_stratified=N_STRAT_SPLITS)):

            net = nets[i]
            # main training loop
            for j, (alarm_strat_train,
                    alarm_strat_holdout) in enumerate(strat_splits):
                # train
                strat_train_ad = AlarmDataset(alarm_strat_train,
                                              DATA_ROOT,
                                              transform=transforms.Compose(
                                                  [Normalize(m1, s1)]))
                strat_train_adl = DataLoader(strat_train_ad,
                                             BATCH_SIZE,
                                             SHUFFLE_DL,
                                             num_workers=4)

                optim = OPTIMIZER(net)
                sched = SCHEDULER(optim, strat_train_adl)
                for k, _ in enumerate(
                        train(net,
                              strat_train_adl,
                              criterion=CRITERION,
                              optimizer=optim,
                              scheduler=sched,
                              epochs=EPOCHS)):
                    _roc, auc, _all_labels, _confs, loss = test(
                        net, australia_adl, CRITERION)

                    f.write(f"{i}, {j}, {k}, {auc}, {loss}\n")
                    f.flush()
                    print(f"done with {i} {j} {k}  auc: {auc} loss: {loss}")
Code Example #2
File: experiment.py  Project: makslevental/ferit_nets
def exp9():
    tuf_table_file_name = 'all_maxs.csv'
    exp_dir = "/home/maksim/dev_projects/ferit_nets/experiments/2019-03-23_train_all_50_epochs_aucloss"
    all_alarms = tuf_table_csv_to_df(
        os.path.join(PROJECT_ROOT, "csvs", tuf_table_file_name))
    m1, s1 = torch.serialization.load(os.path.join(PROJECT_ROOT, 'means.pt')), \
             torch.serialization.load(os.path.join(PROJECT_ROOT, 'stds.pt'))

    # NOTE: the snippet references `nets` without defining it; the assumption here is that
    # they are loaded the same way as in exp8:
    nets = load_nets_dir(os.path.join(exp_dir, "nets"), "GPR_15_300")
    for i, n in enumerate(nets):
        print("loading net ", i)
        n.cuda()
    criterion = AucLoss()

    for i, (_strat_splits, _rgn_train, rgn_holdout) in enumerate(
            region_and_stratified(all_alarms,
                                  n_splits_stratified=N_STRAT_SPLITS)):
        region_holdout_ad = AlarmDataset(rgn_holdout,
                                         DATA_ROOT,
                                         transform=transforms.Compose(
                                             [Normalize(m1, s1)]))
        region_holdout_adl = DataLoader(
            region_holdout_ad,
            BATCH_SIZE,
            shuffle=False,
            num_workers=multiprocessing.cpu_count())
        print(f"testing region {i}")
        try:
            ensemble_test = test_ensemble(nets[i * 10:(i + 1) * 10],
                                          region_holdout_adl, criterion,
                                          lambda cs: gmean(cs, axis=0))
            pickle.dump(
                ensemble_test,
                open(os.path.join(exp_dir, f"ensemble_{i}_test.pkl"), "wb"))
        except Exception as e:
            print(e)
Code Example #3
File: experiment.py  Project: makslevental/ferit_nets
def exp11():
    tuf_table_file_name = 'all_maxs.csv'
    all_alarms = tuf_table_csv_to_df(
        os.path.join(PROJECT_ROOT, "csvs", tuf_table_file_name))
    df = pd.DataFrame()
    for i, (_strat_splits, _rgn_train, rgn_holdout) in enumerate(
            region_and_stratified(all_alarms,
                                  n_splits_stratified=N_STRAT_SPLITS)):
        print(i)
        # DataFrame.append was removed in pandas 2.0; pd.concat([df, rgn_holdout]) is the modern equivalent
        df = df.append(rgn_holdout)

    df.to_csv(os.path.join(PROJECT_ROOT, "csvs", "regions.csv"))
Code Example #4
File: cross_val.py  Project: makslevental/ferit_nets
    def test_create_cross_val_split(self):
        cv = LeaveOneGroupOut()
        alarms = cross_val.tuf_table_csv_to_df(self.csv_fp)
        true_alarms, false_alarms = cross_val.split_t_f_alarms(alarms)
        grouped_t_alarms = cross_val.group_alarms_by(
            true_alarms, group_assign_attrs=['target', 'depth', 'corners'])
        crossval_splits = cross_val.create_cross_val_splits(
            3, cv, grouped_t_alarms.df, grouped_t_alarms.idxs_groupids)
        for nonholdout, holdout in crossval_splits:
            self.assertFalse(set(nonholdout.index) & set(holdout.index))
            self.assertEqual(
                set(nonholdout.index) | set(holdout.index),
                set(range(len(grouped_t_alarms.df))))
Code Example #5
File: cross_val.py  Project: makslevental/ferit_nets
    def test_group_false_alarms(self):
        alarms = cross_val.tuf_table_csv_to_df(self.csv_fp)
        _, false_alarms = cross_val.split_t_f_alarms(alarms)
        grouped_alarms, idxs_gids, df = cross_val.group_false_alarms(
            false_alarms)

        self.assertEqual(len(false_alarms), len(df))
        self.assertEqual(len(false_alarms),
                         sum(map(len, map(attrg('idxs'), grouped_alarms))))

        for f_alarm_grp in grouped_alarms:
            for ix in f_alarm_grp.idxs:
                for iy in f_alarm_grp.idxs:
                    self.assertLessEqual(
                        np.linalg.norm(
                            np.asarray(df.loc[ix]['utm']) -
                            np.asarray(df.loc[iy]['utm'])), 10 * np.sqrt(2))
Code Example #6
File: cross_val.py  Project: makslevental/ferit_nets
    def test_group_alarms_by(self):
        alarms = cross_val.tuf_table_csv_to_df(self.csv_fp)
        with self.assertRaisesRegex(AssertionError, 'groupby'):
            cross_val.group_alarms_by(alarms, ['missing_columns'])
        with self.assertRaisesRegex(AssertionError, 'group assign'):
            cross_val.group_alarms_by(alarms, ['srid', 'target'],
                                      ['missing_group_attr'])

        true_alarms, false_alarms = cross_val.split_t_f_alarms(alarms)
        with self.assertRaisesRegex(AssertionError, 'null values'):
            cross_val.group_alarms_by(false_alarms)

        grouped_alarms, idxs_gids, df = cross_val.group_alarms_by(true_alarms)
        self.assertEqual(len(true_alarms), len(df))
        self.assertEqual(len(true_alarms),
                         sum(map(len, map(attrg('idxs'), grouped_alarms))))

        for groupby_attrs in [['target'], ['target', 'depth', 'corners']]:
            grouped_alarms, idxs_gids, df = cross_val.group_alarms_by(
                true_alarms,
                groupby_attrs=groupby_attrs,
                group_assign_attrs=groupby_attrs)

            self.assertEqual(len(true_alarms), len(df))
            self.assertEqual(len(true_alarms),
                             sum(map(len, map(attrg('idxs'), grouped_alarms))))

            dfidxs_gids = dict(idxs_gids)
            gids_dfidxs = dict(grouped_alarms)
            for grouped_alarm in grouped_alarms:
                self.assertEqual(
                    set(
                        map(tuple,
                            df[groupby_attrs].loc[grouped_alarm.idxs].values)),
                    {grouped_alarm.group_id})
                for idx in grouped_alarm.idxs:
                    self.assertEqual(dfidxs_gids[idx], grouped_alarm.group_id)

            for groupd_alarm_index in idxs_gids:
                self.assertEqual(
                    tuple(
                        df.loc[groupd_alarm_index.idx][groupby_attrs].values),
                    groupd_alarm_index.group_id)
                self.assertTrue(groupd_alarm_index.idx in gids_dfidxs[
                    groupd_alarm_index.group_id])
Code Example #7
File: experiment.py  Project: makslevental/ferit_nets
def exp8(exp_dir, criterion):
    m1, s1 = torch.serialization.load(os.path.join(PROJECT_ROOT, 'means.pt')), \
             torch.serialization.load(os.path.join(PROJECT_ROOT, 'stds.pt'))

    australia_alarms = tuf_table_csv_to_df(
        os.path.join(PROJECT_ROOT, "csvs", "all_australia.csv"))
    australia_ad = AlarmDataset(australia_alarms,
                                DATA_ROOT,
                                transform=transforms.Compose(
                                    [Normalize(m1, s1)]))
    australia_adl = DataLoader(australia_ad,
                               BATCH_SIZE,
                               shuffle=False,
                               num_workers=multiprocessing.cpu_count())
    nets = load_nets_dir(os.path.join(exp_dir, "nets"), "GPR_15_300")
    for i, n in enumerate(nets):
        print("loading net ", i)
        n.cuda()
    ensemble_test = test_ensemble(nets, australia_adl, criterion,
                                  lambda cs: gmean(cs, axis=0))
    pickle.dump(ensemble_test,
                open(os.path.join(exp_dir, "ensemble_aus_test.pkl"), "wb"))
Code Example #8
File: cross_val.py  Project: makslevental/ferit_nets
    def test_parse_csv(self):
        alarms = cross_val.tuf_table_csv_to_df(self.csv_fp)
        self.assertGreater(len(alarms), 0)
        self.assertTrue((alarms.columns == [
            'HIT', 'event_id', 'sample', 'pre_conf', 'conf', 'lane', 'site',
            'srid', 'target', 'depth', 'corners', 'utm'
        ]).all())

        self.assertFalse(alarms[[
            'HIT', 'event_id', 'sample', 'pre_conf', 'conf', 'lane', 'site',
            'srid', 'utm'
        ]].isnull().any().any())

        self.assertTrue(
            all(alarms['corners'].map(
                lambda c: isinstance(c, tuple) or c is None)))
        self.assertTrue(
            (alarms['corners'].isnull() == alarms['HIT'].map(lambda h: h == 0)
             ).all())
        self.assertTrue(
            (alarms['depth'].isnull() == alarms['HIT'].map(lambda h: h == 0)
             ).all())
Code Example #9
File: train.py  Project: makslevental/ferit_nets
            # fragment of the inner training loop: forward pass, backward pass, and optimizer step for one batch
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if np.isnan(loss.item()):
                raise Exception('gradients blew up')
            writer.add_scalar('Train/Loss', loss, i)
            writer.add_scalar('Train/LR', optimizer.param_groups[0]['lr'], i)
        scheduler.step()
        yield net, loss.item()


if __name__ == '__main__':
    tuf_table_file_name = 'small_maxs_table.csv'
    all_alarms = tuf_table_csv_to_df(os.path.join(PROJECT_ROOT, tuf_table_file_name))
    ad = AlarmDataset(all_alarms, DATA_ROOT)
    adl = DataLoader(ad, BATCH_SIZE, SHUFFLE_DL)
    net = torch.nn.DataParallel(GPR_15_300())

    optim = OPTIMIZER(net)
    sched = SCHEDULER(optim, adl)
    net = train(
        net,
        adl,
        criterion=CRITERION,
        optimizer=optim,
        scheduler=sched,
        epochs=EPOCHS
    )
    # print(net.state_dict())
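
As the `yield net, loss.item()` near the top of this snippet shows, `train` is a generator that yields the net and the loss once per epoch, so the bare `net = train(...)` assignment above only builds the generator without running any training. A minimal sketch of actually driving it, mirroring the loops used in exp3 and exp7 and reusing `net`, `adl`, `optim`, and `sched` from the `__main__` block above (the save path is illustrative, not from the project):

# hedged sketch: iterate the generator so every epoch actually runs, keeping the last yielded net
for epoch, (trained_net, epoch_loss) in enumerate(
        train(net,
              adl,
              criterion=CRITERION,
              optimizer=optim,
              scheduler=sched,
              epochs=EPOCHS)):
    print(f"epoch {epoch}  loss: {epoch_loss}")
# illustrative save path (assumption), matching how exp7 saves state_dicts
torch.save(trained_net.state_dict(), os.path.join(PROJECT_ROOT, "trained.net"))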
Code Example #10
File: experiment.py  Project: makslevental/ferit_nets
def exp7():
    tuf_table_file_name = 'all_maxs.csv'
    all_alarms = tuf_table_csv_to_df(
        os.path.join(PROJECT_ROOT, "csvs", tuf_table_file_name))
    m1, s1 = torch.serialization.load(os.path.join(PROJECT_ROOT, 'means.pt')), \
             torch.serialization.load(os.path.join(PROJECT_ROOT, 'stds.pt'))
    nets = defaultlist(
        lambda: defaultlist(lambda: torch.nn.DataParallel(GPR_15_300())))
    log_path = os.path.join(LOGS_PATH, "loss.csv")
    log_csv = open(log_path, "w")
    log_csv.write("region, strat, epoch, strat_auc, strat_loss\n")
    epochs = 50
    criterion = CRITERION
    strat_aucs = [0, 0, 0]
    for i, (strat_splits, _rgn_train, rgn_holdout) in enumerate(
            region_and_stratified(all_alarms,
                                  n_splits_stratified=N_STRAT_SPLITS)):
        print(rgn_holdout[['site', 'lane']].drop_duplicates())
        # main training loop
        for j, (alarm_strat_train,
                alarm_strat_holdout) in enumerate(strat_splits):
            net = nets[i][j]
            strat_train_ad = AlarmDataset(alarm_strat_train,
                                          DATA_ROOT,
                                          transform=transforms.Compose(
                                              [Normalize(m1, s1)]))
            strat_train_adl = DataLoader(strat_train_ad,
                                         BATCH_SIZE,
                                         shuffle=True,
                                         num_workers=16)
            strat_holdout_ad = AlarmDataset(alarm_strat_holdout,
                                            DATA_ROOT,
                                            transform=transforms.Compose(
                                                [Normalize(m1, s1)]))
            strat_holdout_adl = DataLoader(strat_holdout_ad,
                                           BATCH_SIZE,
                                           shuffle=True,
                                           num_workers=4)
            optim = OPTIMIZER(net)
            sched = SCHEDULER(optim, strat_train_adl)
            for k, (_, loss) in enumerate(
                    train(net,
                          strat_train_adl,
                          criterion=criterion,
                          optimizer=optim,
                          scheduler=sched,
                          epochs=epochs)):
                if k % 5 == 0:
                    _roc, strat_auc, _all_labels, strat_confs, strat_loss = test(
                        net, strat_holdout_adl, criterion)
                    log_csv.write(
                        f"{i}, {j}, {k}, {strat_auc}, {strat_loss}\n")
                    log_csv.flush()
                    print(f"{i}, {j}, {k}, {strat_auc}, {strat_loss}\n")
                    strat_aucs[k % 3] = strat_auc
                    # divergence check: a running mean AUC pinned at 0.5 means the net has collapsed, so reinitialize it
                    if np.abs(np.mean(strat_aucs) - 0.5) < 1e-7:
                        print("!!!! diverged")
                        nets[i][j] = torch.nn.DataParallel(GPR_15_300())

    for i, sub_nets in enumerate(nets):
        for j, net in enumerate(sub_nets):
            torch.save(net.state_dict(),
                       os.path.join(NETS_PATH, f"net_{i}_{j}.net"))
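
exp2 and exp8 reload the nets saved this way via the project's `load_nets_dir` helper, whose implementation is not shown in these snippets. As a rough, hypothetical sketch only (assuming a directory of `net_*.net` files holding the DataParallel state_dicts written above, and `GPR_15_300` being the project's model class), a loader might look like:

import glob
import os
import torch

def load_saved_nets(nets_dir):
    # hypothetical stand-in for load_nets_dir: rebuild each DataParallel-wrapped
    # GPR_15_300 and load the state_dict that exp7 saved for it
    nets = []
    for path in sorted(glob.glob(os.path.join(nets_dir, "net_*.net"))):
        net = torch.nn.DataParallel(GPR_15_300())
        net.load_state_dict(torch.load(path, map_location="cpu"))
        nets.append(net)
    return nets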
Code Example #11
File: experiment.py  Project: makslevental/ferit_nets
def exp6():
    tuf_table_file_name = 'big_maxs_table.csv'
    all_alarms = tuf_table_csv_to_df(
        os.path.join(PROJECT_ROOT, "csvs", tuf_table_file_name))
    m1, s1 = torch.serialization.load(os.path.join(PROJECT_ROOT, 'means.pt')), \
             torch.serialization.load(os.path.join(PROJECT_ROOT, 'stds.pt'))

    log_path = os.path.join(LOGS_PATH, "loss.csv")
    log_csv = open(log_path, "w")
    log_csv.write("overtrain_epoch, net, train_epoch, strat_auc, strat_loss\n")

    for i, (strat_splits, _rgn_train, rgn_holdout) in enumerate(
            region_and_stratified(all_alarms,
                                  n_splits_stratified=N_STRAT_SPLITS)):
        break

    net = torch.nn.DataParallel(GPR_15_300())
    strat_aucs = [0, 0, 0]
    strat_train_ad = None
    strat_train_adl = None
    strat_holdout_ad = None
    strat_holdout_adl = None
    for i in range(100):
        # main training loop
        for j, (alarm_strat_train,
                alarm_strat_holdout) in enumerate(strat_splits):
            # train
            strat_train_ad = strat_train_ad if strat_train_ad is not None else AlarmDataset(
                alarm_strat_train,
                DATA_ROOT,
                transform=transforms.Compose([Normalize(m1, s1)]))
            strat_train_adl = strat_train_adl if strat_train_adl is not None else DataLoader(
                strat_train_ad, BATCH_SIZE, shuffle=True, num_workers=4)

            strat_holdout_ad = strat_holdout_ad if strat_holdout_ad is not None else AlarmDataset(
                alarm_strat_holdout,
                DATA_ROOT,
                transform=transforms.Compose([Normalize(m1, s1)]))
            strat_holdout_adl = strat_holdout_adl if strat_holdout_adl is not None else DataLoader(
                strat_holdout_ad, BATCH_SIZE, shuffle=True, num_workers=4)

            optim = OPTIMIZER(net)
            sched = SCHEDULER(optim, strat_train_adl)
            for k, _ in enumerate(
                    train(
                        net,
                        strat_train_adl,
                        # criterion=AucLoss(),
                        criterion=CRITERION,
                        optimizer=optim,
                        scheduler=sched,
                        epochs=EPOCHS)):
                _roc, strat_auc, _all_labels, strat_confs, strat_loss = test(
                    net, strat_holdout_adl, CRITERION)

                log_csv.write(f"{i}, {j}, {k}, {strat_auc}, {strat_loss}\n")
                log_csv.flush()
                print(f"{i}, {j}, {k}, {strat_auc}, {strat_loss}\n")
                strat_aucs[k % 3] = strat_auc
                # divergence check: a running mean AUC pinned at 0.5 means the net has collapsed, so reinitialize it
                if np.abs(np.mean(strat_aucs) - 0.5) < 1e-7:
                    print("!!!! diverged")
                    net = torch.nn.DataParallel(GPR_15_300())
                    break

            break

    log_csv.close()
    # a regex separator needs a raw string (and pandas parses it with the python engine)
    log_df = pd.read_csv(log_path, sep=r",\s+", engine="python")
    plot_auc_loss(log_df, "cross entropy loss")
Code Example #12
File: experiment.py  Project: makslevental/ferit_nets
def experiment1():
    tuf_table_file_name = 'all_maxs.csv'
    all_alarms = tuf_table_csv_to_df(
        os.path.join(PROJECT_ROOT, tuf_table_file_name))

    print(PROJECT_NAME)

    m1, s1 = torch.serialization.load(os.path.join(PROJECT_ROOT, 'means.pt')), \
             torch.serialization.load(os.path.join(PROJECT_ROOT, 'stds.pt'))

    for i, (strat_splits, _rgn_train, rgn_holdout) in enumerate(
            region_and_stratified(all_alarms,
                                  n_splits_stratified=N_STRAT_SPLITS)):
        nets = []

        # main training loop
        for j, (alarm_strat_train,
                alarm_strat_holdout) in enumerate(strat_splits):
            # train
            strat_train_ad = AlarmDataset(alarm_strat_train,
                                          DATA_ROOT,
                                          transform=transforms.Compose(
                                              [Normalize(m1, s1)]))
            strat_train_adl = DataLoader(strat_train_ad,
                                         BATCH_SIZE,
                                         SHUFFLE_DL,
                                         num_workers=4)
            net = torch.nn.DataParallel(GPR_15_300())
            optim = OPTIMIZER(net)
            sched = SCHEDULER(optim, strat_train_adl)
            candidate_net_aucs = []
            for candidate_net in train(net,
                                       strat_train_adl,
                                       criterion=CRITERION,
                                       optimizer=optim,
                                       scheduler=sched,
                                       epochs=EPOCHS):
                # test for early stopping using flattening of auc curve
                holdout_ad = AlarmDataset(alarm_strat_holdout,
                                          DATA_ROOT,
                                          transform=transforms.Compose(
                                              [Normalize(m1, s1)]))
                holdout_adl = DataLoader(holdout_ad,
                                         BATCH_SIZE,
                                         SHUFFLE_DL,
                                         num_workers=4)
                _roc, auc, _all_labels, _confs = test(net, holdout_adl)
                candidate_net_aucs.append(auc)
                aucp, aucpp = deriv(np.arange(len(candidate_net_aucs)),
                                    candidate_net_aucs)
                # flatten check: first derivative >= 0 and second derivative <= 0 over the last two points
                if len(candidate_net_aucs) >= 3 and all(
                        aucp[-2:] >= 0) and all(aucpp[-2:] <= 0):
                    break

            nets.append((auc, candidate_net))
            print(f"done with {i} {j} train")

        # drop worst (by auc) 2 nets
        nets = list(map(itemgetter(1), sorted(nets, key=itemgetter(0))[2:]))
        # majority = len(nets) // 2 + 1
        rgn_holdout_ad = AlarmDataset(rgn_holdout,
                                      DATA_ROOT,
                                      transform=transforms.Compose(
                                          [Normalize(m1, s1)]))

        ####
        # testing on holdout
        ####

        fig_test = None
        # DO NOT SHUFFLE region holdout in order to fuse
        rgn_holdout_adl = DataLoader(rgn_holdout_ad,
                                     BATCH_SIZE,
                                     shuffle=False,
                                     num_workers=4)
        for j, net in enumerate(nets):
            torch.save(net.state_dict(),
                       os.path.join(NETS_PATH, f"net_test_{i}_{j}.net"))
            roc, auc, labels, confs = test(net, rgn_holdout_adl)
            fig_test = plot_roc(roc,
                                f"test {i}",
                                f"{auc}",
                                show=False,
                                fig=fig_test)

            all_aucs = all_aucs + [auc] if j > 0 else [auc]
            all_confs = np.vstack((all_confs, confs)) if j > 0 else confs
            all_labels = np.vstack((all_labels, labels)) if j > 0 else labels
        fig_test.savefig(os.path.join(FIGS_PATH, f'test_{i}.png'))

        # fusion
        fused_confs = gmean(all_confs, axis=0)
        fused_roc = roc_curve(all_labels[0, :], fused_confs)
        if len(set(all_labels[0, :])) > 1:
            fused_auc = roc_auc_score(all_labels[0, :], fused_confs)
        else:
            fused_auc = 'NaN'
        fig_fused = plot_roc(fused_roc, f"fused {i} roc", f"auc {fused_auc}")
        fig_fused.savefig(os.path.join(FIGS_PATH, f'fused_{i}.png'))

        # histograms
        hist_range = (np.min(all_confs), np.max(all_confs))
        hist_fig = plt.figure()
        ax = hist_fig.add_axes([0.1, 0.1, 0.85, 0.8])
        ax.set_title(f"hist {i}")
        for j, confs in enumerate(all_confs):
            ax.hist(confs,
                    100,
                    range=hist_range,
                    density=True,
                    label=all_aucs[j],
                    alpha=0.7)
        hist_fig.legend(loc='upper right')
        hist_fig.savefig(os.path.join(FIGS_PATH, f'hist_{i}.png'))

        plt.close(fig_test)
        plt.close(fig_fused)
        plt.close(hist_fig)
Code Example #13
File: experiment.py  Project: makslevental/ferit_nets
def exp4():
    tuf_table_file_name = 'all_maxs.csv'
    all_alarms = tuf_table_csv_to_df(
        os.path.join(PROJECT_ROOT, "csvs", tuf_table_file_name))
    australia_alarms = tuf_table_csv_to_df(
        os.path.join(PROJECT_ROOT, "csvs", "australia.csv"))

    m1, s1 = torch.serialization.load(os.path.join(PROJECT_ROOT, 'means.pt')), \
             torch.serialization.load(os.path.join(PROJECT_ROOT, 'stds.pt'))

    australia_ad = AlarmDataset(australia_alarms,
                                DATA_ROOT,
                                transform=transforms.Compose(
                                    [Normalize(m1, s1)]))
    australia_adl = DataLoader(australia_ad,
                               BATCH_SIZE,
                               shuffle=False,
                               num_workers=4)

    f = open(os.path.join(LOGS_PATH, "loss.csv"), "w+")

    for i, (strat_splits, _rgn_train, rgn_holdout) in enumerate(
            region_and_stratified(all_alarms,
                                  n_splits_stratified=N_STRAT_SPLITS)):
        rgn_holdout_ad = AlarmDataset(rgn_holdout,
                                      DATA_ROOT,
                                      transform=transforms.Compose(
                                          [Normalize(m1, s1)]))
        rgn_holdout_adl = DataLoader(rgn_holdout_ad,
                                     BATCH_SIZE,
                                     shuffle=False,
                                     num_workers=4)
        break

    aus_aucs, rgn_aucs = [0, 0, 0], [0, 0, 0]
    net = torch.nn.DataParallel(GPR_15_300())
    for overtrain in range(100):
        # main training loop
        for j, (alarm_strat_train,
                alarm_strat_holdout) in enumerate(strat_splits):
            # train
            strat_train_ad = AlarmDataset(alarm_strat_train,
                                          DATA_ROOT,
                                          transform=transforms.Compose(
                                              [Normalize(m1, s1)]))
            strat_train_adl = DataLoader(strat_train_ad,
                                         BATCH_SIZE,
                                         shuffle=True,
                                         num_workers=4)
            strat_holdout_ad = AlarmDataset(alarm_strat_holdout,
                                            DATA_ROOT,
                                            transform=transforms.Compose(
                                                [Normalize(m1, s1)]))
            strat_holdout_adl = DataLoader(strat_holdout_ad,
                                           BATCH_SIZE,
                                           shuffle=True,
                                           num_workers=4)

            optim = OPTIMIZER(net)
            sched = SCHEDULER(optim, strat_train_adl)
            for k, _ in enumerate(
                    train(net,
                          strat_train_adl,
                          criterion=CRITERION,
                          optimizer=optim,
                          scheduler=sched,
                          epochs=EPOCHS)):
                _roc, aus_auc, _all_labels, aus_confs, aus_loss = test(
                    net, australia_adl, CRITERION)
                _roc, rgn_auc, _all_labels, rgn_confs, rgn_loss = test(
                    net, rgn_holdout_adl, CRITERION)
                _roc, strat_auc, _all_labels, strat_confs, strat_loss = test(
                    net, strat_holdout_adl, CRITERION)

                f.write(
                    f"{overtrain}, {i}, {j}, {k}, {aus_auc}, {aus_loss}, {rgn_auc}, {rgn_loss}, {strat_auc}, {strat_loss}\n"
                )
                f.flush()
                print(
                    f"{overtrain}, {i}, {j}, {k}, {aus_auc}, {aus_loss}, {rgn_auc}, {rgn_loss}, {strat_auc}, {strat_loss}\n"
                )
                aus_aucs[k % 3] = aus_auc
                rgn_aucs[k % 3] = rgn_auc
                # divergence check: a mean AUC pinned at 0.5 on either holdout means the net has collapsed
                if (np.abs(np.mean(aus_aucs) - 0.5) < 1e-7
                        or np.abs(np.mean(rgn_aucs) - 0.5) < 1e-8):
                    print(
                        f"!!!!diverged {np.mean(aus_confs)} {np.mean(rgn_confs)}"
                    )
                    net = torch.nn.DataParallel(GPR_15_300())
                    break
Code Example #14
File: experiment.py  Project: makslevental/ferit_nets
def exp2():
    tuf_table_file_name = 'all_maxs.csv'
    all_alarms = tuf_table_csv_to_df(
        os.path.join(PROJECT_ROOT, tuf_table_file_name))

    m1, s1 = torch.serialization.load(os.path.join(PROJECT_ROOT, 'means.pt')), \
             torch.serialization.load(os.path.join(PROJECT_ROOT, 'stds.pt'))
    nets = load_nets_dir(NETS_PATH, "GPR_15_300")

    for i, (strat_splits, _rgn_train, rgn_holdout) in enumerate(
            region_and_stratified(all_alarms,
                                  n_splits_stratified=N_STRAT_SPLITS)):
        rgn_holdout_ad = AlarmDataset(rgn_holdout,
                                      DATA_ROOT,
                                      transform=transforms.Compose(
                                          [Normalize(m1, s1)]))
        fig_test = None
        # DO NOT SHUFFLE region holdout in order to fuse
        rgn_holdout_adl = DataLoader(rgn_holdout_ad,
                                     BATCH_SIZE,
                                     shuffle=False,
                                     num_workers=4)
        for j, net in enumerate(nets):
            # torch.save(net.state_dict(), os.path.join(NETS_PATH, f"net_test_{i}_{j}.net"))
            roc, auc, labels, confs = test(net, rgn_holdout_adl)
            fig_test = plot_roc(roc,
                                f"test {i}",
                                f"{auc}",
                                show=False,
                                fig=fig_test)

            all_aucs = all_aucs + [auc] if j > 0 else [auc]
            all_confs = np.vstack((all_confs, confs)) if j > 0 else confs
            all_labels = np.vstack((all_labels, labels)) if j > 0 else labels
            print(f"done testing {i} {j}")
        fig_test.savefig(os.path.join(FIGS_PATH, f'test_{i}.png'))

        # fusion
        fused_confs = gmean(all_confs, axis=0)
        fused_roc = roc_curve(all_labels[0, :], fused_confs)
        if len(set(all_labels[0, :])) > 1:
            fused_auc = roc_auc_score(all_labels[0, :], fused_confs)
        else:
            fused_auc = 'NaN'
        fig_fused = plot_roc(fused_roc, f"fused {i} roc", f"auc {fused_auc}")
        fig_fused.savefig(os.path.join(FIGS_PATH, f'fused_{i}.png'))

        # histograms
        hist_range = (np.min(all_confs), np.max(all_confs))
        hist_fig = plt.figure()
        ax = hist_fig.add_axes([0.1, 0.1, 0.85, 0.8])
        ax.set_title(f"hist {i}")
        for j, confs in enumerate(all_confs):
            ax.hist(confs,
                    100,
                    range=hist_range,
                    density=True,
                    label=all_aucs[j],
                    alpha=0.7)
        hist_fig.legend(loc='upper right')
        hist_fig.savefig(os.path.join(FIGS_PATH, f'hist_{i}.png'))

        plt.close(fig_test)
        plt.close(fig_fused)
        plt.close(hist_fig)