Code example #1
def main():
    """Extract GODIN logits/confidences on the in-distribution validation set.

    Relies on module-level ``cfg``/``global_cfg`` dicts and project helpers
    (``getModel``, ``getDataLoader``, ``detectors``, ``valid_epoch_wo_outlier``).
    Saves f/h confidences, targets, and logits as ``.pt`` files under the
    experiment directory.
    """
    # Reproducibility: seed both numpy and torch from the config.
    np.random.seed(cfg['seed'])
    torch.manual_seed(cfg['seed'])

    # Model only — this script evaluates, so no optimizer is built.
    # (Removed unused start_epoch/max_epoch locals from the original.)
    model = getModel(cfg['model'])

    # Restore weights if a checkpoint path was configured.
    if cfg['load_ckpt'] != '':
        checkpoint = torch.load(cfg['load_ckpt'], map_location="cpu")
        model.load_state_dict(checkpoint['model_state'])
        print("load model on '{}' is complete.".format(cfg['load_ckpt']))
    cudnn.benchmark = True

    # Data loader for the in-distribution validation split.
    in_valid_loader = getDataLoader(ds_cfg=cfg['in_dataset'],
                                    dl_cfg=cfg['dataloader'],
                                    split="valid")

    # Result directory: append the first target name when targets are given.
    if 'targets' in cfg['in_dataset']:
        exp_dir = os.path.join(cfg['exp_root'], cfg['exp_dir'], "logits",
                               cfg['in_dataset']['dataset'],
                               cfg['in_dataset']['targets'][0])
    else:
        exp_dir = os.path.join(cfg['exp_root'], cfg['exp_dir'], "logits",
                               cfg['in_dataset']['dataset'])

    # Create the result directory if needed.
    if not os.path.exists(exp_dir):
        os.makedirs(exp_dir)

    # Outlier detector (published to global_cfg for downstream code).
    detector_func = detectors.getDetector(cfg['detector'])
    global_cfg['detector'] = cfg['detector']

    print("=======================IMPORTANT CONFIG=======================")
    print(" Model    : {}\n \
Detector     : {}\n".format(cfg['model']['network_kind'],
                            cfg['detector']['detector']))
    print(
        "========Start logits extraction for GODIN. Result will be saved in {}"
        .format(exp_dir))

    # Single evaluation pass; summary carries confidences/targets/logits.
    valid_summary = valid_epoch_wo_outlier(model, in_valid_loader,
                                           detector_func)
    summary_log = "Acc [{}]\n".format(valid_summary['classifier_acc'])
    print(summary_log)

    # Persist every artifact needed for offline GODIN analysis.
    torch.save(valid_summary['f_confidences'],
               os.path.join(exp_dir, 'f_confidences.pt'))
    torch.save(valid_summary['h_confidences'],
               os.path.join(exp_dir, 'h_confidences.pt'))
    torch.save(valid_summary['targets'], os.path.join(exp_dir, 'targets.pt'))
    torch.save(valid_summary['logits'], os.path.join(exp_dir, 'logits.pt'))
Code example #2
def main():
    """Train a classifier with an auxiliary rotation head (optionally with
    outlier exposure and PGD attacks), checkpointing periodically.

    Relies on module-level ``cfg``/``global_cfg`` and project helpers
    (``getModel``, ``getDataLoader``, ``optim``, ``losses``, ``detectors``,
    ``train_epoch_*``, ``valid_epoch_*``, ``summary_write``).
    """
    global global_cfg
    # Reproducibility: seed both numpy and torch from the config.
    np.random.seed(cfg['seed'])
    torch.manual_seed(cfg['seed'])

    # Model & Optimizer — attach a 4-way rotation-prediction head
    # (one class per 90-degree rotation).
    model = getModel(cfg['model'])
    model.rot_head = nn.Linear(model.nChannels, 4)
    model.rot_head.cuda()
    optimizer = optim.getOptimizer(model, cfg['optim'])
    start_epoch = 1

    # Load model and optimizer; when fine-tuning, keep fresh optimizer
    # state and restart the epoch counter.
    if cfg['load_ckpt'] != '':
        checkpoint = torch.load(cfg['load_ckpt'], map_location="cpu")
        model.load_state_dict(checkpoint['model_state'])
        print("load model on '{}' is complete.".format(cfg['load_ckpt']))
        if not cfg['finetuning']:
            optimizer.load_state_dict(checkpoint['optimizer_state'])
        if 'epoch' in checkpoint.keys() and not cfg['finetuning']:
            start_epoch = checkpoint['epoch']
            print("Restore epoch {}".format(start_epoch))
        else:
            start_epoch = 1
    cudnn.benchmark = True

    # Data loaders for the in-distribution train/valid splits.
    in_train_loader = getDataLoader(ds_cfg=cfg['in_dataset'],
                                    dl_cfg=cfg['dataloader'],
                                    split="train")
    in_valid_loader = getDataLoader(ds_cfg=cfg['in_dataset'],
                                    dl_cfg=cfg['dataloader'],
                                    split="valid")

    # Optional PGD adversarial attack on the rotation task.
    attack_in = None
    if 'PGD' in cfg.keys() and cfg['PGD'] is not None:
        attack_in = RotPGDAttack(model=model, eps=cfg['PGD']['epsilon'],
                                 nb_iter=cfg['PGD']['iters'],
                                 eps_iter=cfg['PGD']['iter_size'],
                                 rand_init=True,
                                 loss_func='CE')

    # Optional out-of-distribution (outlier-exposure) loaders.
    if cfg['out_dataset'] is not None:
        out_train_loader = getDataLoader(ds_cfg=cfg['out_dataset'],
                                         dl_cfg=cfg['dataloader'],
                                         split="train")
        out_valid_loader = getDataLoader(ds_cfg=cfg['out_dataset'],
                                         dl_cfg=cfg['dataloader'],
                                         split="valid")
    else:
        out_train_loader = None
        out_valid_loader = None

    # Result directory; snapshot the config and open tensorboard writers.
    exp_dir = os.path.join(cfg['exp_root'], cfg['exp_dir'])
    if not os.path.exists(exp_dir):
        os.makedirs(exp_dir)
    shutil.copy('./config.py', os.path.join(exp_dir, "config.py"))
    writer_train = SummaryWriter(logdir=os.path.join(exp_dir, 'log', 'train'))
    writer_valid = SummaryWriter(logdir=os.path.join(exp_dir, 'log', 'valid'))

    # Loss function.
    # BUGFIX: the original never defined `loss_func`, so the outlier
    # branches below raised NameError. Build it as the sibling scripts do.
    # (Assumes the project `losses` module is imported at file level —
    # TODO confirm against the full file.)
    loss_func = losses.getLoss(cfg['loss'])
    global_cfg['loss'] = cfg['loss']

    # Outlier detector (published to global_cfg for downstream code).
    detector_func = detectors.getDetector(cfg['detector'])
    global_cfg['detector'] = cfg['detector']

    print("=======================IMPORTANT CONFIG=======================")
    print(" Model    : {}\n \
Loss     : {}\n \
Detector : {}\n \
Optimizer: {}\n".format(cfg['model']['network_kind'], cfg['loss']['loss'], cfg['detector']['detector'], cfg['optim']['optimizer']))
    print("============Start training. Result will be saved in {}".format(exp_dir))

    for cur_epoch in range(start_epoch, cfg['max_epoch'] + 1):
        # Train one epoch, with or without outlier exposure.
        if out_train_loader is not None:
            train_summary = train_epoch_w_outlier(model, optimizer, in_train_loader, out_train_loader, loss_func, detector_func, cur_epoch, cfg['optim'], writer_train)
        else:
            train_summary = train_epoch_wo_outlier(model, optimizer, in_train_loader, attack_in, cur_epoch, cfg['optim'], writer_train)
        summary_write(summary=train_summary, writer=writer_train)
        print("Training result=========Epoch [{}]/[{}]=========\nlr: {} | loss: {} | acc: {}".format(cur_epoch, cfg['max_epoch'], train_summary['lr'], train_summary['avg_loss'], train_summary['classifier_acc']))

        # Periodic validation.
        if cur_epoch % cfg['valid_epoch'] == 0:
            if out_valid_loader is not None:
                valid_summary = valid_epoch_w_outlier(model, in_valid_loader, out_valid_loader, loss_func, detector_func, cur_epoch)
            else:
                valid_summary = valid_epoch_wo_outlier(model, in_valid_loader, cur_epoch)
            summary_write(summary=valid_summary, writer=writer_valid)
            print("Validate result=========Epoch [{}]/[{}]=========\nloss: {} | acc: {}".format(cur_epoch, cfg['max_epoch'], valid_summary['avg_loss'], valid_summary['classifier_acc']))

        # Periodic checkpointing (unwrap DataParallel when multi-GPU).
        if cur_epoch % cfg['ckpt_epoch'] == 0:
            ckpt_dir = os.path.join(cfg['exp_root'], cfg['exp_dir'], "ckpt")
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)
            model_state = model.module.state_dict() if cfg['ngpu'] > 1 else model.state_dict()
            checkpoint = {
                "epoch": cur_epoch,
                "model_state": model_state,
                "optimizer_state": optimizer.state_dict(),
            }
            ckpt_name = "checkpoint_epoch_{}".format(cur_epoch)
            ckpt_path = os.path.join(ckpt_dir, ckpt_name + ".pyth")
            torch.save(checkpoint, ckpt_path)
Code example #3
File: valid_ovnni.py — Project: emckwon/OOD-saige
def main():
    """Validate an OVNNI (one-vs-all + all-vs-all) ensemble, optionally
    against an out-of-distribution set, logging metrics and a confidence
    histogram.

    Relies on module-level ``cfg``/``global_cfg`` and project helpers
    (``getModel``, ``getDataLoader``, ``losses``, ``detectors``,
    ``valid_epoch_*``).
    """
    # Reproducibility: seed both numpy and torch from the config.
    np.random.seed(cfg['seed'])
    torch.manual_seed(cfg['seed'])

    # Model; this is a single-pass validation run.
    model = getModel(cfg['model'])
    start_epoch = 1
    max_epoch = 1

    # One checkpoint for the AVA network plus one per OVA network.
    assert len(cfg['load_ckpt']) == len(cfg['in_dataset']['targets']) + 1
    # Load each sub-network: index 0 is the all-vs-all network, the
    # rest map to one-vs-all networks in order.
    for idx, ckpt in enumerate(cfg['load_ckpt']):
        checkpoint = torch.load(ckpt, map_location="cpu")
        if idx == 0:
            model.ava_network.load_state_dict(checkpoint['model_state'])
        else:
            model.ova_networks[idx - 1].load_state_dict(
                checkpoint['model_state'])
        print("load model on '{}' is complete.".format(ckpt))
    cudnn.benchmark = True

    # Data loader for the in-distribution validation split.
    in_valid_loader = getDataLoader(ds_cfg=cfg['in_dataset'],
                                    dl_cfg=cfg['dataloader'],
                                    split="valid")

    # Optional OOD validation loader; result directory is named after it.
    # (Removed the original's unused `out_train_loader = None`.)
    if cfg['out_dataset'] is not None:
        out_valid_loader = getDataLoader(ds_cfg=cfg['out_dataset'],
                                         dl_cfg=cfg['dataloader'],
                                         split="valid")
        exp_dir = os.path.join(cfg['exp_root'], cfg['exp_dir'], "valid",
                               cfg['out_dataset']['dataset'])
    else:
        out_valid_loader = None
        exp_dir = os.path.join(cfg['exp_root'], cfg['exp_dir'], "valid",
                               "classifier")

    # Result directory; snapshot the config used for this validation.
    if not os.path.exists(exp_dir):
        os.makedirs(exp_dir)
    shutil.copy('./config.py', os.path.join(exp_dir, "val_config.py"))

    # Loss function (published to global_cfg for downstream code).
    loss_func = losses.getLoss(cfg['loss'])
    global_cfg['loss'] = cfg['loss']

    # Outlier detector (published to global_cfg for downstream code).
    detector_func = detectors.getDetector(cfg['detector'])
    global_cfg['detector'] = cfg['detector']
    print("=======================IMPORTANT CONFIG=======================")
    print(" Model    : {}\n \
Loss     : {}\n \
Detector : {}\n".format(cfg['model']['network_kind'], cfg['loss']['loss'],
                        cfg['detector']['detector']))
    print(
        "========Start validation. Result will be saved in {}".format(exp_dir))

    # Context managers guarantee the log files are closed even if a
    # validation epoch raises (the original leaked them on error).
    with open(os.path.join(exp_dir, "validation_log.txt"), "w") as logfile, \
         open(os.path.join(exp_dir, "wrong_predict_log.txt"), "w") as logfile2:
        for cur_epoch in range(start_epoch, max_epoch + 1):
            if out_valid_loader is not None:
                valid_summary = valid_epoch_w_outlier(model, in_valid_loader,
                                                      out_valid_loader, loss_func,
                                                      detector_func, cur_epoch,
                                                      logfile2)
                summary_log = "=============Epoch [{}]/[{}]=============\nloss: {} | acc: {} | acc_w_ood: {}\nAUROC: {} | AUPR: {} | FPR95: {}\nInlier Conf. {} | Outlier Conf. {}\n".format(
                    cur_epoch, max_epoch, valid_summary['avg_loss'],
                    valid_summary['classifier_acc'], valid_summary['acc'],
                    valid_summary['AUROC'], valid_summary['AUPR'],
                    valid_summary['FPR95'], valid_summary['inlier_confidence'],
                    valid_summary['outlier_confidence'])

                # Common confidence range so both histograms share bins.
                ind_max, ind_min = np.max(valid_summary['inliers']), np.min(
                    valid_summary['inliers'])
                ood_max, ood_min = np.max(valid_summary['outliers']), np.min(
                    valid_summary['outliers'])

                ranges = (ind_min if ind_min < ood_min else ood_min,
                          ind_max if ind_max > ood_max else ood_max)

                # Overlayed in/out confidence histograms saved as PNG.
                fig = plt.figure()
                sns.distplot(valid_summary['inliers'].ravel(),
                             hist_kws={'range': ranges},
                             kde=False,
                             bins=50,
                             norm_hist=True,
                             label='In-distribution')
                sns.distplot(valid_summary['outliers'],
                             hist_kws={'range': ranges},
                             kde=False,
                             bins=50,
                             norm_hist=True,
                             label='Out-of-distribution')
                plt.xlabel('Confidence')
                plt.ylabel('Density')
                fig.legend()
                fig.savefig(os.path.join(exp_dir, "confidences.png"))

            else:
                valid_summary = valid_epoch_wo_outlier(model, in_valid_loader,
                                                       loss_func, cur_epoch,
                                                       logfile2)
                summary_log = "=============Epoch [{}]/[{}]=============\nloss: {} | acc: {}\n".format(
                    cur_epoch, max_epoch, valid_summary['avg_loss'],
                    valid_summary['classifier_acc'])

            print(summary_log)
            logfile.write(summary_log)