Example #1
 def __init__(self, dataset_config: Config.Dataset, buffer_size: int = 64):
     super().__init__()
     self.device = torch.device(
         "cuda" if torch.cuda.is_available() else "cpu")
     self.label = dataset_config.classes[-1]
     self.model = ModelFactory.create(
         dataset_config.gan_model,
         n_classes=len(dataset_config.classes)).to(self.device)
     self.buffer_size = buffer_size
     self.pointer = 0
     self.buffer = self.generate_buffer()
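
A hypothetical sketch (not the project's code) of how the buffer and pointer could be consumed: refill via generate_buffer once the pointer exhausts the buffer, so iteration never runs dry.

 def __getitem__(self, index):
     if self.pointer >= self.buffer_size:
         self.buffer = self.generate_buffer()  # assumed to return a tensor
         self.pointer = 0                      # holding buffer_size samples
     sample = self.buffer[self.pointer]
     self.pointer += 1
     return sample, self.label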
Example #2
 def __init__(self, env: MultiEnv, model_factory: ModelFactory, curiosity_factory: CuriosityFactory,
              normalize_state: bool, normalize_reward: bool, reporter: Reporter = NoReporter()) -> None:
     self.env = env
     self.reporter = reporter
     self.state_converter = Converter.for_space(self.env.observation_space)
     self.action_converter = Converter.for_space(self.env.action_space)
     self.model = model_factory.create(self.state_converter, self.action_converter)
     self.curiosity = curiosity_factory.create(self.state_converter, self.action_converter)
     self.reward_normalizer = StandardNormalizer() if normalize_reward else NoNormalizer()
     self.state_normalizer = self.state_converter.state_normalizer() if normalize_state else NoNormalizer()
     self.normalize_state = normalize_state
     self.device: Optional[torch.device] = None
     self.dtype: Optional[torch.dtype] = None
     self.numpy_dtype: Optional[object] = None
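
The device, dtype and numpy_dtype fields start as None, so they are presumably filled in later. A minimal sketch of such a setter, where the method name and the numpy mapping are assumptions:

 def to(self, device: torch.device, dtype: torch.dtype, numpy_dtype: object) -> None:
     # Hypothetical setter: remember the target device/dtypes and move the
     # model so later tensor conversions agree with its parameters.
     self.device = device
     self.dtype = dtype
     self.numpy_dtype = numpy_dtype
     self.model.to(device, dtype=dtype)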
Example #3
def test_with_cover_stego_biased_proportions(prob_cover=0.99, prob_stego=0.01):
    print("[Testing with prob_stego: %.3f" % (prob_stego))
    # Create dataset and dataloader
    dset_test = TimitTestSet(dpath_cover,
                             dpath_stego,
                             seed=SEED,
                             prob_cover=prob_cover,
                             prob_stego=prob_stego,
                             dtype=np.float32)
    n_data = len(dset_test)
    n_train = math.floor(0.8 * n_data)
    ix_end_valid = n_train
    indices = np.arange(n_data)
    sampler_test = SubsetRandomSampler(indices[ix_end_valid:])

    # Create dataloader_test
    dataloader_test = DataLoader(dset_test,
                                 batch_size=BATCH_SIZE,
                                 sampler=sampler_test,
                                 num_workers=N_WORKERS,
                                 pin_memory=True)

    # Create model
    model = ModelFactory.create(config)
    model.to(device, dtype=dtype)
    model = nn.DataParallel(model)

    global fpath_load_ckpt
    if dpath_load_ckpt:
        if not fpath_load_ckpt:
            fpath_load_ckpt = get_ckpt(dpath_load_ckpt, policy=LOAD_POLICY)

        load_model(fpath_load_ckpt, model)
        print("[%s]" % LOAD_POLICY.upper(), fpath_load_ckpt,
              "has been loaded...")
    elif fpath_load_ckpt:
        load_model(fpath_load_ckpt, model)

    loss_ce = nn.CrossEntropyLoss()

    def classification_loss(logits, target_labels):
        return loss_ce(logits, target_labels)

    def classify(model, batch):
        audios, labels = batch
        audios = audios.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = model(audios)
        loss = classification_loss(logits, labels)
        acc = accuracy_top1(logits, labels)
        ps = labels.to(torch.float32).mean().item()

        return (logits, loss.item(), acc, ps)

    def compute_rates(labels_pred, labels_true):
        # Confusion matrix with label order (stego=1, cover=0):
        # rows are true labels, columns are predictions.
        cm = confusion_matrix(labels_true, labels_pred, labels=(1, 0))

        tp = cm[0, 0]  # True stegos  (stego predicted as stego)
        fn = cm[0, 1]  # False covers (stego predicted as cover)
        fp = cm[1, 0]  # False stegos (cover predicted as stego)
        tn = cm[1, 1]  # True covers  (cover predicted as cover)

        p = tp + fn
        n = tn + fp
        tpr = tp / p  # Sensitivity
        fpr = fp / n  # False alarm (1 - specificity)
        tnr = tn / n  # Specificity
        fnr = fn / p  # Miss rate

        return tpr, fpr, tnr, fnr
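
    # Worked example (illustrative numbers only): with 10 stegos and 90
    # covers, tp=8, fn=2, fp=9, tn=81 gives tpr=8/10=0.80, fpr=9/90=0.10,
    # tnr=0.90, fnr=0.20, and error probability PE=0.5*(fpr+fnr)=0.15.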

    # Lists for statistics
    list_stats = []
    list_acc = []
    list_loss = []
    list_prob = []

    # Lists for true labels, scores, predictions
    list_scores = []
    list_labels = []

    num_audios = 0
    model.eval()

    for epoch in tqdm.tqdm(range(N_REPEATS)):
        # Testing model
        sum_acc = 0
        sum_loss = 0
        sum_prob_stego = 0
        list_single_test_preds = []
        list_single_test_labels = []

        for step, batch in enumerate(dataloader_test):
            num_audios += len(batch[0])  # batch is (audios, labels)
            with torch.no_grad():
                # ps denotes prob. of fetching stegos
                logits, loss, acc, ps = classify(model, batch)
                sum_acc += acc
                sum_loss += loss
                sum_prob_stego += ps

                # Compute score for roc_curve
                sm = torch.softmax(logits, dim=1)
                list_scores.append(sm[:, 1].cpu().numpy())

                _, labels = batch
                list_single_test_labels.append(labels.cpu().numpy())

                preds = logits.topk(1).indices.view(-1)  # Predictions
                list_single_test_preds.append(preds.cpu().numpy())

        # end of for

        avg_acc = sum_acc / len(dataloader_test)
        avg_loss = sum_loss / len(dataloader_test)
        avg_prob_stego = sum_prob_stego / len(dataloader_test)

        # Compute the rates
        labels_pred = np.concatenate(list_single_test_preds)
        labels_true = np.concatenate(list_single_test_labels)
        tpr, fpr, tnr, fnr = compute_rates(labels_pred, labels_true)

        fstr = "- Acc:%.4f, Loss:%.6f, Ps:%.4f, " \
               "FA(fpr):%.4f, MD(fnr):%.4f, PE:%.4f"
        print()
        print(fstr % (avg_acc, avg_loss, avg_prob_stego, fpr, fnr, 0.5 *
                      (fpr + fnr)))
        list_acc.append(avg_acc)
        list_loss.append(avg_loss)
        list_prob.append(avg_prob_stego)
        list_labels.append(labels_true)
        list_stats.append({
            "test_avg_acc": avg_acc,
            "test_avg_loss": avg_loss,
            "test_avg_prob_stego": avg_prob_stego,
            "test_avg_prob_cover": 1 - avg_prob_stego,
            "test_tpr": tpr,
            "test_fpr": fpr,
            "test_tnr": tnr,
            "test_fnr": fnr,
        })
    # end of for

    # Compute ROC
    labels_true = np.concatenate(list_labels)
    y_score = np.concatenate(list_scores)
    roc_fpr, roc_tpr, roc_thr = roc_curve(labels_true, y_score)
    roc_auc = roc_auc_score(labels_true, y_score)

    print()
    print("- Avg. acc:", "%.4f ± %.4f" % (np.mean(list_acc), np.std(list_acc)))
    print("- Avg. loss:",
          "%.6f ± %.4f" % (np.mean(list_loss), np.std(list_loss)))
    print("- Avg. prob:",
          "%.4f ± %.4f" % (np.mean(list_prob), np.std(list_prob)))
    print("- Total num. tested audios:", num_audios)
    print()

    df_stats = pd.DataFrame(list_stats)

    dict_stats = {
        "model": MODEL,
        "steganography": STEGANOGRAPHY.lower(),
        "ps": ps,
        "stats": df_stats,
        "roc": {
            "roc_auc": roc_auc,
            "roc_tpr": roc_tpr,
            "roc_fpr": roc_fpr,
            "roc_thr": roc_thr
        }
    }

    return dict_stats
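
A hedged usage sketch, assuming the module-level globals referenced above (dpath_cover, dpath_stego, BATCH_SIZE, N_WORKERS, config, device, dtype, N_REPEATS, ...) are already configured; the ROC plot only uses keys present in the returned dict_stats:

import matplotlib.pyplot as plt

stats = test_with_cover_stego_biased_proportions(prob_cover=0.95, prob_stego=0.05)
print("AUC: %.4f" % stats["roc"]["roc_auc"])
print(stats["stats"].describe())  # per-repeat summary of acc/loss/rates

plt.plot(stats["roc"]["roc_fpr"], stats["roc"]["roc_tpr"])
plt.xlabel("False positive rate (FA)")
plt.ylabel("True positive rate")
plt.title("ROC: %s / %s" % (stats["model"], stats["steganography"]))
plt.show()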
Example #4
                                  pin_memory=True)

    dataloader_valid = DataLoader(dset_valid,
                                  batch_size=BATCH_SIZE,
                                  sampler=sampler_valid,
                                  num_workers=N_WORKERS,
                                  pin_memory=True)

    dataloader_test = DataLoader(dset_test,
                                 batch_size=BATCH_SIZE,
                                 sampler=sampler_test,
                                 num_workers=N_WORKERS,
                                 pin_memory=True)

    # Create model
    model = ModelFactory.create(config)
    model.to(device, dtype=dtype)
    lr = float(LEARNING_RATE)
    optimizer = model.get_optimizer(model, lr)
    lr_scheduler = model.get_lr_scheduler(optimizer)

    if DPATH_LOAD_CKPT:
        if not fpath_load_ckpt:
            fpath_load_ckpt = get_ckpt(DPATH_LOAD_CKPT, LOAD_POLICY)
        load_model(fpath_load_ckpt, model)
        print("[%s]" % LOAD_POLICY.upper(), fpath_load_ckpt, "has been loaded...")
    # end of if
    model = nn.DataParallel(model)

    loss_ce = nn.CrossEntropyLoss()

    def classification_loss(logits, target_labels):
        return loss_ce(logits, target_labels)
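
Example 4 breaks off after the loss helper; a minimal sketch of the training loop its setup (optimizer, lr_scheduler) implies, where dataloader_train, N_EPOCHS, the batch unpacking and the epoch-wise scheduler step are all assumptions:

    for epoch in range(N_EPOCHS):  # N_EPOCHS assumed to be a global
        model.train()
        for audios, labels in dataloader_train:
            audios = audios.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            optimizer.zero_grad()
            loss = classification_loss(model(audios), labels)
            loss.backward()
            optimizer.step()
        lr_scheduler.step()  # assumes an epoch-wise LR schedule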
Example #5
def main(
        config_path: str = 'configs/classification.yaml',
        dataset_name: str = 'svhn',
        imbalance_ratio: int = 1,
        oversampling: str = 'none',  # none, oversampling, gan
        ada: bool = False,  # only for gan training
        seed: int = 1,  # No seed if 0
        wandb_logs: bool = False,
        test: bool = False,
        load_model: bool = False):
    # Ensure output directory exists
    if not os.path.exists(OUTPUT_PATH):
        os.mkdir(OUTPUT_PATH)

    # Set a seed
    if seed:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)

    # Load configuration
    logger.info(f'Loading config at "{config_path}"...')
    config = load_config(config_path, dataset_name, imbalance_ratio,
                         oversampling, ada, load_model)

    if config.trainer.task == 'generation' and test:
        raise ValueError('Cannot test the generation models')

    # Init logging with WandB
    mode = 'offline' if wandb_logs else 'disabled'
    wandb.init(mode=mode,
               dir=OUTPUT_PATH,
               entity=WANDB_TEAM,
               project=PROJECT_NAME,
               group=config.trainer.task,
               config=dataclasses.asdict(config))

    # Load model
    logger.info('Loading model...')
    model = ModelFactory.create(model_config=config.model,
                                n_classes=len(config.dataset.classes))

    # Load dataset
    logger.info('Loading dataset...')
    train_dataset, valid_dataset, test_dataset = DatasetFactory.create(
        dataset_config=config.dataset)

    # Instantiate trainer
    logger.info('Loading trainer...')
    trainer = TrainerFactory.create(trainer_config=config.trainer,
                                    train_dataset=train_dataset,
                                    valid_dataset=valid_dataset,
                                    model=model,
                                    classes=config.dataset.classes)

    if test:
        logger.info('Testing...')
        trainer.test(test_dataset)
    else:
        logger.info('Training...')
        trainer.train()

    # Cleanup
    wandb.finish()

    logger.info('done :)')
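
A hedged invocation sketch; the argument values are illustrative, and whether the project wires main up through a CLI library (e.g. fire.Fire(main)) is an assumption:

if __name__ == '__main__':
    # Illustrative direct call with hypothetical values.
    main(dataset_name='svhn',
         imbalance_ratio=10,
         oversampling='gan',
         seed=42,
         wandb_logs=True)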