Example #1
def MakeProgressiveSamplers(self,
                            model,
                            train_data,
                            do_fanout_scaling=False):
     estimators = []
     dropout = self.dropout or self.per_row_dropout
     for n in self.eval_psamples:
         if self.factorize:
             estimators.append(
                 estimators_lib.FactorizedProgressiveSampling(
                     model,
                     train_data,
                     n,
                     self.join_spec,
                     device=train_utils.get_device(),
                     shortcircuit=dropout,
                     do_fanout_scaling=do_fanout_scaling))
         else:
             estimators.append(
                 estimators_lib.ProgressiveSampling(
                     model,
                     train_data,
                     n,
                     self.join_spec,
                     device=train_utils.get_device(),
                     shortcircuit=dropout,
                     do_fanout_scaling=do_fanout_scaling))
     return estimators
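
The factory above builds one estimator per sample budget in self.eval_psamples, switching between the factorized and plain progressive samplers. As a rough, hypothetical sketch of how such a list might be consumed downstream (the estimator interface is not shown here, so the Query(columns, operators, vals) call is an assumption):

def estimate_with_all(estimators, columns, operators, vals):
    # Hypothetical: run the same predicate through every estimator so the
    # results can be compared across sample budgets.
    results = []
    for est in estimators:
        results.append(est.Query(columns, operators, vals))  # assumed API
    return results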
Example #2
def main():
    args = get_parser().parse_args()
    args = hparams.merge_args_hparams(args)
    train_utils.save_json(os.path.join(args.exp_directory, "config.json"), args.__dict__)
    train_utils.log_training_information(args)

    # Build data loaders, model, and optimizer.
    device = train_utils.get_device()
    train_loader, val_loader = get_data_loaders(args)
    model, criterion = get_model_and_loss(args, device)
    optimizer = get_optimizer(args, model)

    # Build trainer and validator.
    trainer = train_utils.create_supervised_trainer(model,
                                                    optimizer,
                                                    criterion,
                                                    metrics=train_utils.get_metrics(args, criterion),
                                                    device=device,
                                                    grad_clip=args.grad_clip,
                                                    grad_norm=args.grad_norm)
    validator = train_utils.create_supervised_evaluator(model,
                                                        metrics=train_utils.get_metrics(args, criterion),
                                                        device=device)

    # Add handlers.
    action_names = train_utils.load_action_names(args)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, train_utils.val_epoch_completed_logger, validator,
                              val_loader, action_names)
    trainer.add_event_handler(Events.EPOCH_STARTED, train_utils.train_epoch_started_logger, args.max_epochs)
    timer = train_utils.attach_timer(trainer)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, train_utils.train_epoch_completed_logger, timer)
    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                              train_utils.batch_logger,
                              n_batches=len(train_loader),
                              split="Train")
    train_utils.maybe_attach_lr_scheduler_handler(trainer, args, optimizer, len(train_loader))
    logger = train_utils.attach_tensorboard_handler(args, trainer, validator, optimizer, model, train_loader)
    validator.add_event_handler(Events.EPOCH_COMPLETED, train_utils.get_model_checkpoint_handler(args),
                                {"model": model})
    validator.add_event_handler(Events.EPOCH_COMPLETED,
                                train_utils.get_interval_model_checkpoint_handler(args), {"model": model})
    validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        train_utils.get_early_stopping_handler(args, trainer, args.early_stopping_patience))
    validator.add_event_handler(Events.ITERATION_COMPLETED,
                                train_utils.batch_logger,
                                n_batches=len(val_loader),
                                split="Val")

    # Run training.
    trainer.run(train_loader, max_epochs=args.max_epochs)
    logger.close()
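
Example #2 leans on project-specific train_utils wrappers around pytorch-ignite. A minimal, self-contained sketch of the same trainer/validator/event-handler pattern with plain ignite, assuming an ordinary classification setup (the model, loaders, and metric choices below are placeholders):

import torch
import torch.nn as nn
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

def run(model, train_loader, val_loader, max_epochs=10, lr=1e-3, device="cpu"):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
    evaluator = create_supervised_evaluator(
        model, metrics={"acc": Accuracy(), "xent": Loss(criterion)}, device=device)

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(engine):
        # Evaluate after every epoch, mirroring the EPOCH_COMPLETED handlers above.
        evaluator.run(val_loader)
        m = evaluator.state.metrics
        print(f"epoch {engine.state.epoch}: acc={m['acc']:.4f} xent={m['xent']:.4f}")

    trainer.run(train_loader, max_epochs=max_epochs)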
Example #3
def main(dataset,
         augment=False,
         use_scattering=False,
         size=None,
         batch_size=2048,
         mini_batch_size=256,
         sample_batches=False,
         lr=1,
         optim="SGD",
         momentum=0.9,
         nesterov=False,
         noise_multiplier=1,
         max_grad_norm=0.1,
         epochs=100,
         input_norm=None,
         num_groups=None,
         bn_noise_multiplier=None,
         max_epsilon=None,
         logdir=None,
         early_stop=True,
         seed=0):
    torch.manual_seed(seed)
    logger = Logger(logdir)
    device = get_device()

    train_data, test_data = get_data(dataset, augment=augment)

    if use_scattering:
        scattering, K, _ = get_scatter_transform(dataset)
        scattering.to(device)
    else:
        scattering = None
        K = 3 if len(train_data.data.shape) == 4 else 1

    bs = batch_size
    assert bs % mini_batch_size == 0
    n_acc_steps = bs // mini_batch_size

    # Batch accumulation and data augmentation aren't implemented with Poisson sampling
    if sample_batches:
        assert n_acc_steps == 1
        assert not augment

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=mini_batch_size,
                                               shuffle=True,
                                               num_workers=1,
                                               pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=mini_batch_size,
                                              shuffle=False,
                                              num_workers=1,
                                              pin_memory=True)

    rdp_norm = 0
    if input_norm == "BN":
        # compute noisy data statistics or load from disk if pre-computed
        save_dir = f"bn_stats/{dataset}"
        os.makedirs(save_dir, exist_ok=True)
        bn_stats, rdp_norm = scatter_normalization(
            train_loader,
            scattering,
            K,
            device,
            len(train_data),
            len(train_data),
            noise_multiplier=bn_noise_multiplier,
            orders=ORDERS,
            save_dir=save_dir)
        model = CNNS[dataset](K, input_norm="BN", bn_stats=bn_stats, size=size)
    else:
        model = CNNS[dataset](K,
                              input_norm=input_norm,
                              num_groups=num_groups,
                              size=size)

    model.to(device)

    if use_scattering and augment:
        model = nn.Sequential(scattering, model)
        train_loader = torch.utils.data.DataLoader(train_data,
                                                   batch_size=mini_batch_size,
                                                   shuffle=True,
                                                   num_workers=1,
                                                   pin_memory=True,
                                                   drop_last=True)
    else:
        # pre-compute the scattering transform if necessary
        train_loader = get_scattered_loader(train_loader,
                                            scattering,
                                            device,
                                            drop_last=True,
                                            sample_batches=sample_batches)
        test_loader = get_scattered_loader(test_loader, scattering, device)

    print(f"model has {get_num_params(model)} parameters")

    if optim == "SGD":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    momentum=momentum,
                                    nesterov=nesterov)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    privacy_engine = PrivacyEngine(
        model,
        batch_size=bs,
        sample_size=len(train_data),
        alphas=ORDERS,
        noise_multiplier=noise_multiplier,
        max_grad_norm=max_grad_norm,
    )
    privacy_engine.attach(optimizer)

    best_acc = 0
    flat_count = 0

    results = dict(train_zeon=[],
                   train_xent=[],
                   test_zeon=[],
                   test_xent=[],
                   epoch=[])
    for epoch in range(0, epochs):
        print(f"\nEpoch: {epoch}")

        train_loss, train_acc = train(model,
                                      train_loader,
                                      optimizer,
                                      n_acc_steps=n_acc_steps)
        test_loss, test_acc = test(model, test_loader)

        results['train_zeon'].append(train_acc)
        results['train_xent'].append(train_loss)
        results['test_zeon'].append(test_acc)
        results['test_xent'].append(test_loss)
        results['epoch'].append(epoch)

        if noise_multiplier > 0:
            rdp_sgd = get_renyi_divergence(
                privacy_engine.sample_rate,
                privacy_engine.noise_multiplier) * privacy_engine.steps
            epsilon, _ = get_privacy_spent(rdp_norm + rdp_sgd)
            epsilon2, _ = get_privacy_spent(rdp_sgd)
            print(f"ε = {epsilon:.3f} (sgd only: ε = {epsilon2:.3f})")

            if max_epsilon is not None and epsilon >= max_epsilon:
                return
        else:
            epsilon = None

        logger.log_epoch(epoch, train_loss, train_acc, test_loss, test_acc,
                         epsilon)
        logger.log_scalar("epsilon/train", epsilon, epoch)

        # stop if we're not making progress
        if test_acc > best_acc:
            best_acc = test_acc
            flat_count = 0
        else:
            flat_count += 1
            if flat_count >= 20 and early_stop:
                print("plateau...")
                break

    # Write to file.
    record = {
        **results,
        **{
            'best_acc': best_acc,
            'seed': seed,
            'dataset': dataset
        }
    }
    record_path = os.path.join('.', 'record', f'{dataset}-{seed}.json')
    os.makedirs(os.path.dirname(record_path), exist_ok=True)
    with open(record_path, 'w') as f:
        json.dump(record, f, indent=4)
    import logging
    logging.warning(f'Wrote to file: {record_path}')
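
The train(...) helper used above is not shown; its n_acc_steps argument indicates gradient accumulation, so that a logical batch of batch_size is assembled from mini_batch_size chunks. A minimal sketch of what such a loop usually looks like, assuming a plain classifier and cross-entropy loss (the actual helper may differ, in particular in how it cooperates with the attached privacy engine):

import torch
import torch.nn.functional as F

def train_one_epoch(model, loader, optimizer, n_acc_steps=1, device="cpu"):
    model.train()
    optimizer.zero_grad()
    for step, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)
        loss = F.cross_entropy(model(x), y)
        # Scale so the accumulated gradient matches one large-batch update.
        (loss / n_acc_steps).backward()
        if (step + 1) % n_acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()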
Example #4
def run_epoch(split,
              model,
              opt,
              train_data,
              val_data=None,
              batch_size=100,
              upto=None,
              epoch_num=None,
              epochs=1,
              verbose=False,
              log_every=10,
              return_losses=False,
              table_bits=None,
              warmups=1000,
              loader=None,
              constant_lr=None,
              use_meters=True,
              summary_writer=None,
              lr_scheduler=None,
              custom_lr_lambda=None,
              label_smoothing=0.0):
    torch.set_grad_enabled(split == 'train')
    model.train() if split == 'train' else model.eval()
    dataset = train_data if split == 'train' else val_data
    losses = []

    if loader is None:
        loader = data.DataLoader(dataset,
                                 batch_size=batch_size,
                                 shuffle=(split == 'train'))

    # How many orderings to run for the same batch?
    nsamples = 1
    if hasattr(model, 'orderings'):
        nsamples = len(model.orderings)
        if verbose:
            print('setting nsamples to', nsamples)

    dur_meter = train_utils.AverageMeter('dur',
                                         lambda v: '{:.0f}s'.format(v),
                                         display_average=False)
    lr_meter = train_utils.AverageMeter('lr', ':.5f', display_average=False)
    tups_meter = train_utils.AverageMeter('tups',
                                          utils.HumanFormat,
                                          display_average=False)
    loss_meter = train_utils.AverageMeter('loss (bits/tup)', ':.2f')
    train_throughput = train_utils.AverageMeter('tups/s',
                                                utils.HumanFormat,
                                                display_average=False)
    batch_time = train_utils.AverageMeter('sgd_ms', ':3.1f')
    data_time = train_utils.AverageMeter('data_ms', ':3.1f')
    progress = train_utils.ProgressMeter(upto, [
        batch_time,
        data_time,
        dur_meter,
        lr_meter,
        tups_meter,
        train_throughput,
        loss_meter,
    ])

    begin_time = t1 = time.time()

    for step, xb in enumerate(loader):
        data_time.update((time.time() - t1) * 1e3)

        if split == 'train':
            if isinstance(dataset, data.IterableDataset):
                # Can't call len(loader).
                global_steps = upto * epoch_num + step + 1
            else:
                global_steps = len(loader) * epoch_num + step + 1

            if constant_lr:
                lr = constant_lr
                for param_group in opt.param_groups:
                    param_group['lr'] = lr
            elif custom_lr_lambda:
                lr_scheduler = None
                lr = custom_lr_lambda(global_steps)
                for param_group in opt.param_groups:
                    param_group['lr'] = lr
            elif lr_scheduler is None:
                t = warmups
                if warmups < 1:  # A ratio.
                    t = int(warmups * upto * epochs)

                d_model = model.embed_size
                lr = (d_model**-0.5) * min(
                    (global_steps**-.5), global_steps * (t**-1.5))
                for param_group in opt.param_groups:
                    param_group['lr'] = lr
            else:
                # We'll call lr_scheduler.step() below.
                lr = opt.param_groups[0]['lr']

        if upto and step >= upto:
            break

        if isinstance(xb, list):
            # This happens if using data.TensorDataset.
            assert len(xb) == 1, xb
            xb = xb[0]

        xb = xb.float().to(train_utils.get_device(), non_blocking=True)

        # Forward pass, potentially through several orderings.
        xbhat = None
        model_logits = []
        num_orders_to_forward = 1
        if split == 'test' and nsamples > 1:
            # At test, we want to test the 'true' nll under all orderings.
            num_orders_to_forward = nsamples

        for i in range(num_orders_to_forward):
            if hasattr(model, 'update_masks'):
                # We want to update_masks even for first ever batch.
                model.update_masks()

            model_out = model(xb)
            model_logits.append(model_out)
            if xbhat is None:
                xbhat = torch.zeros_like(model_out)
            xbhat += model_out

        if num_orders_to_forward == 1:
            loss = model.nll(xbhat, xb, label_smoothing=label_smoothing).mean()
        else:
            # Average across orderings & then across minibatch.
            #
            #   p(x) = 1/N sum_i p_i(x)
            #   log(p(x)) = log(1/N) + log(sum_i p_i(x))
            #             = log(1/N) + logsumexp ( log p_i(x) )
            #             = log(1/N) + logsumexp ( - nll_i (x) )
            #
            # Used only at test time.
            logps = []  # [batch size, num orders]
            assert len(model_logits) == num_orders_to_forward, len(
                model_logits)
            for logits in model_logits:
                # Note the minus.
                logps.append(
                    -model.nll(logits, xb, label_smoothing=label_smoothing))
            logps = torch.stack(logps, dim=1)
            logps = logps.logsumexp(dim=1) + torch.log(
                torch.tensor(1.0 / nsamples, device=logps.device))
            loss = (-logps).mean()

        losses.append(loss.detach().item())

        if split == 'train':
            opt.zero_grad()
            loss.backward()
            l2_grad_norm = TotalGradNorm(model.parameters())

            opt.step()
            if lr_scheduler is not None:
                lr_scheduler.step()

            loss_bits = loss.item() / np.log(2)

            # Number of tuples processed in this epoch so far.
            ntuples = (step + 1) * batch_size
            # Compute dur unconditionally: the summary_writer branch below
            # also needs it, even when use_meters is False.
            dur = time.time() - begin_time
            if use_meters:
                lr_meter.update(lr)
                tups_meter.update(ntuples)
                loss_meter.update(loss_bits)
                dur_meter.update(dur)
                train_throughput.update(ntuples / dur)

            if summary_writer is not None:
                wandb.log({
                    'train/lr': lr,
                    'train/tups': ntuples,
                    'train/tups_per_sec': ntuples / dur,
                    'train/nll': loss_bits,
                    'train/global_step': global_steps,
                    'train/l2_grad_norm': l2_grad_norm,
                })
                summary_writer.add_scalar('train/lr',
                                          lr,
                                          global_step=global_steps)
                summary_writer.add_scalar('train/tups',
                                          ntuples,
                                          global_step=global_steps)
                summary_writer.add_scalar('train/tups_per_sec',
                                          ntuples / dur,
                                          global_step=global_steps)
                summary_writer.add_scalar('train/nll',
                                          loss_bits,
                                          global_step=global_steps)

            if step % log_every == 0:
                if table_bits:
                    print(
                        'Epoch {} Iter {}, {} entropy gap {:.4f} bits (loss {:.3f}, data {:.3f}) {:.5f} lr, {} tuples seen ({} tup/s)'
                        .format(
                            epoch_num, step, split,
                            loss.item() / np.log(2) - table_bits,
                            loss.item() / np.log(2), table_bits, lr,
                            utils.HumanFormat(ntuples),
                            utils.HumanFormat(ntuples /
                                              (time.time() - begin_time))))
                elif not use_meters:
                    print(
                        'Epoch {} Iter {}, {} loss {:.3f} bits/tuple, {:.5f} lr'
                        .format(epoch_num, step, split,
                                loss.item() / np.log(2), lr))

        if verbose:
            print('%s epoch average loss: %f' % (split, np.mean(losses)))

        batch_time.update((time.time() - t1) * 1e3)
        t1 = time.time()
        if split == 'train' and step % log_every == 0 and use_meters:
            progress.display(step)

    if return_losses:
        return losses
    return np.mean(losses)
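
The multi-ordering branch in run_epoch averages densities rather than log-densities: p(x) = (1/N) * sum_i p_i(x), evaluated stably as log p(x) = logsumexp_i(-nll_i(x)) + log(1/N). A small self-contained check of that identity in plain PyTorch (the per-ordering NLL values are made up):

import torch

# Per-example NLLs under 3 hypothetical orderings: shape [batch, num_orders].
nll = torch.tensor([[1.2, 0.9, 1.5],
                    [2.0, 2.2, 1.8]])
num_orders = nll.shape[1]

# Stable form, as in run_epoch.
logp_stable = (-nll).logsumexp(dim=1) + torch.log(torch.tensor(1.0 / num_orders))

# Naive form: average the probabilities, then take the log.
logp_naive = torch.log(torch.exp(-nll).mean(dim=1))

assert torch.allclose(logp_stable, logp_naive)
loss = (-logp_stable).mean()  # same reduction as the training loop
print(loss)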
Example #5
    def MakeModel(self, table, train_data, table_primary_index=None):
        cols_to_train = table.columns
        if self.factorize:
            cols_to_train = train_data.columns

        fixed_ordering = self.MakeOrdering(table)

        table_num_columns = table_column_types = table_indexes = None
        if isinstance(train_data,
                      (common.SamplerBasedIterDataset,
                       common.FactorizedSampleFromJoinIterDataset)):
            table_num_columns = train_data.table_num_columns
            table_column_types = train_data.combined_columns_types
            table_indexes = train_data.table_indexes
            print('table_num_columns', table_num_columns)
            print('table_column_types', table_column_types)
            print('table_indexes', table_indexes)
            print('table_primary_index', table_primary_index)

        if self.use_transformer:
            args = {
                'num_blocks': 4,
                # Effective Transformer size (a larger d_ff=128 / d_model=32 /
                # num_heads=4 variant can be set via transformer_args below).
                'd_ff': 64,
                'd_model': 16,
                'num_heads': 2,
                'nin': len(cols_to_train),
                'input_bins': [c.distribution_size for c in cols_to_train],
                'use_positional_embs': False,
                'activation': 'gelu',
                'fixed_ordering': self.fixed_ordering,
                'dropout': self.dropout,
                'per_row_dropout': self.per_row_dropout,
                'seed': None,
                'join_args': {
                    'num_joined_tables': len(self.join_tables),
                    'table_dropout': self.table_dropout,
                    'table_num_columns': table_num_columns,
                    'table_column_types': table_column_types,
                    'table_indexes': table_indexes,
                    'table_primary_index': table_primary_index,
                }
            }
            args.update(self.transformer_args)
            model = transformer.Transformer(**args).to(
                train_utils.get_device())
        else:
            model = MakeMade(
                table=table,
                scale=self.fc_hiddens,
                layers=self.layers,
                cols_to_train=cols_to_train,
                seed=self.seed,
                factor_table=train_data if self.factorize else None,
                fixed_ordering=fixed_ordering,
                special_orders=self.special_orders,
                order_content_only=self.order_content_only,
                order_indicators_at_front=self.order_indicators_at_front,
                inv_order=True,
                residual=self.residual,
                direct_io=self.direct_io,
                input_encoding=self.input_encoding,
                output_encoding=self.output_encoding,
                embed_size=self.embed_size,
                dropout=self.dropout,
                per_row_dropout=self.per_row_dropout,
                grouped_dropout=self.grouped_dropout
                if self.factorize else False,
                fixed_dropout_ratio=self.fixed_dropout_ratio,
                input_no_emb_if_leq=self.input_no_emb_if_leq,
                embs_tied=self.embs_tied,
                resmade_drop_prob=self.resmade_drop_prob,
                # DMoL:
                num_dmol=self.num_dmol,
                scale_input=self.scale_input if self.num_dmol else False,
                dmol_cols=self.dmol_cols if self.num_dmol else [],
                # Join specific:
                num_joined_tables=len(self.join_tables),
                table_dropout=self.table_dropout,
                table_num_columns=table_num_columns,
                table_column_types=table_column_types,
                table_indexes=table_indexes,
                table_primary_index=table_primary_index,
            )
        return model
Example #6
def MakeMade(
        table,
        scale,
        layers,
        cols_to_train,
        seed,
        factor_table=None,
        fixed_ordering=None,
        special_orders=0,
        order_content_only=True,
        order_indicators_at_front=True,
        inv_order=True,
        residual=True,
        direct_io=True,
        input_encoding='embed',
        output_encoding='embed',
        embed_size=32,
        dropout=True,
        grouped_dropout=False,
        per_row_dropout=False,
        fixed_dropout_ratio=False,
        input_no_emb_if_leq=False,
        embs_tied=True,
        resmade_drop_prob=0.,
        # Join specific:
        num_joined_tables=None,
        table_dropout=None,
        table_num_columns=None,
        table_column_types=None,
        table_indexes=None,
        table_primary_index=None,
        # DMoL
        num_dmol=0,
        scale_input=False,
        dmol_cols=[]):
    dmol_col_indexes = []
    if dmol_cols:
        for i in range(len(cols_to_train)):
            if cols_to_train[i].name in dmol_cols:
                dmol_col_indexes.append(i)

    model = made.MADE(
        nin=len(cols_to_train),
        hidden_sizes=[scale] *
        layers if layers > 0 else [512, 256, 512, 128, 1024],
        nout=sum([c.DistributionSize() for c in cols_to_train]),
        num_masks=max(1, special_orders),
        natural_ordering=True,
        input_bins=[c.DistributionSize() for c in cols_to_train],
        do_direct_io_connections=direct_io,
        input_encoding=input_encoding,
        output_encoding=output_encoding,
        embed_size=embed_size,
        input_no_emb_if_leq=input_no_emb_if_leq,
        embs_tied=embs_tied,
        residual_connections=residual,
        factor_table=factor_table,
        seed=seed,
        fixed_ordering=fixed_ordering,
        resmade_drop_prob=resmade_drop_prob,

        # Wildcard skipping:
        dropout_p=dropout,
        fixed_dropout_p=fixed_dropout_ratio,
        grouped_dropout=grouped_dropout,
        learnable_unk=True,
        per_row_dropout=per_row_dropout,

        # DMoL
        num_dmol=num_dmol,
        scale_input=scale_input,
        dmol_col_indexes=dmol_col_indexes,

        # Join support.
        num_joined_tables=num_joined_tables,
        table_dropout=table_dropout,
        table_num_columns=table_num_columns,
        table_column_types=table_column_types,
        table_indexes=table_indexes,
        table_primary_index=table_primary_index,
    ).to(train_utils.get_device())

    if special_orders > 0:
        orders = []

        if order_content_only:
            print('Leaving out virtual columns from orderings')
            cols = [c for c in cols_to_train if not c.name.startswith('__')]
            inds_cols = [
                c for c in cols_to_train if c.name.startswith('__in_')
            ]
            num_indicators = len(inds_cols)
            num_content, num_virtual = len(
                cols), len(cols_to_train) - len(cols)

            # Data: { content }, { indicators }, { fanouts }.
            for i in range(special_orders):
                rng = np.random.RandomState(i + 1)
                content = rng.permutation(np.arange(num_content))
                inds = rng.permutation(
                    np.arange(num_content, num_content + num_indicators))
                fanouts = rng.permutation(
                    np.arange(num_content + num_indicators,
                              len(cols_to_train)))

                if order_indicators_at_front:
                    # Model: { indicators }, { content }, { fanouts },
                    # permute each bracket independently.
                    order = np.concatenate(
                        (inds, content, fanouts)).reshape(-1, )
                else:
                    # Model: { content }, { indicators }, { fanouts }.
                    # permute each bracket independently.
                    order = np.concatenate(
                        (content, inds, fanouts)).reshape(-1, )
                assert len(np.unique(order)) == len(cols_to_train), order
                orders.append(order)
        else:
            # Permute content & virtual columns together.
            for i in range(special_orders):
                orders.append(
                    np.random.RandomState(i + 1).permutation(
                        np.arange(len(cols_to_train))))

        if factor_table:
            # Correct for subvar ordering.
            for i in range(special_orders):
                # The sampled order may scatter the subvars of one original
                # column, e.g. [..., 6, ..., 4, ..., 5, ...]; map them back
                # into consecutive, ascending positions: [..., 4, 5, 6, ...].
                order = orders[i]
                for orig_col, sub_cols in factor_table.fact_col_mapping.items(
                ):
                    first_subvar_index = cols_to_train.index(sub_cols[0])
                    print('Before', order)
                    for j in range(1, len(sub_cols)):
                        subvar_index = cols_to_train.index(sub_cols[j])
                        order = np.delete(order,
                                          np.argwhere(order == subvar_index))
                        order = np.insert(
                            order,
                            np.argwhere(order == first_subvar_index)[0][0] + j,
                            subvar_index)
                    orders[i] = order
                    print('After', order)

        print('Special orders', np.array(orders))

        if inv_order:
            for i, order in enumerate(orders):
                orders[i] = np.asarray(utils.InvertOrder(order))
            print('Inverted special orders:', orders)

        model.orderings = orders

    return model
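
utils.InvertOrder is not shown in these examples. For a permutation stored as "position i holds variable order[i]", the inverse ("variable v sits at position inv[v]") is commonly computed with the NumPy idiom below; this is a sketch of the usual trick, not necessarily the project's exact implementation:

import numpy as np

def invert_order(order):
    # Return inv such that inv[order[i]] == i for every i.
    order = np.asarray(order)
    inv = np.empty_like(order)
    inv[order] = np.arange(len(order))
    return inv

order = np.array([2, 0, 3, 1])
print(invert_order(order))           # [1 3 0 2]
print(order[invert_order(order)])    # [0 1 2 3]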
Example #7
def main(tiny_images=None,
         model="cnn",
         augment=False,
         use_scattering=False,
         batch_size=2048,
         mini_batch_size=256,
         lr=1,
         lr_start=None,
         optim="SGD",
         momentum=0.9,
         noise_multiplier=1,
         max_grad_norm=0.1,
         epochs=100,
         bn_noise_multiplier=None,
         max_epsilon=None,
         data_size=550000,
         delta=1e-6,
         logdir=None):
    logger = Logger(logdir)

    device = get_device()

    bs = batch_size
    assert bs % mini_batch_size == 0
    n_acc_steps = bs // mini_batch_size

    train_data, test_data = get_data("cifar10", augment=augment)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=100,
                                               shuffle=False,
                                               num_workers=4,
                                               pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=100,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True)

    if isinstance(tiny_images, torch.utils.data.Dataset):
        train_data_aug = tiny_images
    else:
        print("loading tiny images...")
        train_data_aug, _ = get_data("cifar10_500K",
                                     augment=augment,
                                     aux_data_filename=tiny_images)

    scattering, K, (h, w) = None, None, (None, None)
    pre_scattered = False
    if use_scattering:
        scattering, K, (h, w) = get_scatter_transform("cifar10_500K")
        scattering.to(device)

    # if the whole data fits in memory, pre-compute the scattering
    if use_scattering and data_size <= 50000:
        loader = torch.utils.data.DataLoader(train_data_aug,
                                             batch_size=100,
                                             shuffle=False,
                                             num_workers=4)
        train_data_aug = get_scattered_dataset(loader, scattering, device,
                                               data_size)
        pre_scattered = True

    assert data_size <= len(train_data_aug)
    num_sup = min(data_size, 50000)
    num_batches = int(np.ceil(50000 / mini_batch_size))  # cifar-10 equivalent

    train_batch_sampler = SemiSupervisedSampler(data_size, num_batches,
                                                mini_batch_size)
    train_loader_aug = torch.utils.data.DataLoader(
        train_data_aug,
        batch_sampler=train_batch_sampler,
        num_workers=0 if pre_scattered else 4,
        pin_memory=not pre_scattered)

    rdp_norm = 0
    if model == "cnn":
        if use_scattering:
            save_dir = "bn_stats/cifar10_500K"
            os.makedirs(save_dir, exist_ok=True)
            bn_stats, rdp_norm = scatter_normalization(
                train_loader,
                scattering,
                K,
                device,
                data_size,
                num_sup,
                noise_multiplier=bn_noise_multiplier,
                orders=ORDERS,
                save_dir=save_dir)
            model = CNNS["cifar10"](K, input_norm="BN", bn_stats=bn_stats)
            model = model.to(device)

            if not pre_scattered:
                model = nn.Sequential(scattering, model)
        else:
            model = CNNS["cifar10"](in_channels=3, internal_norm=False)

    elif model == "linear":
        save_dir = "bn_stats/cifar10_500K"
        os.makedirs(save_dir, exist_ok=True)
        bn_stats, rdp_norm = scatter_normalization(
            train_loader,
            scattering,
            K,
            device,
            data_size,
            num_sup,
            noise_multiplier=bn_noise_multiplier,
            orders=ORDERS,
            save_dir=save_dir)
        model = ScatterLinear(K, (h, w), input_norm="BN", bn_stats=bn_stats)
        model = model.to(device)

        if not pre_scattered:
            model = nn.Sequential(scattering, model)
    else:
        raise ValueError(f"Unknown model {model}")
    model.to(device)

    if pre_scattered:
        test_loader = get_scattered_loader(test_loader, scattering, device)

    print(f"model has {get_num_params(model)} parameters")

    if optim == "SGD":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    momentum=momentum)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    privacy_engine = PrivacyEngine(
        model,
        bs,
        data_size,
        alphas=ORDERS,
        noise_multiplier=noise_multiplier,
        max_grad_norm=max_grad_norm,
    )
    privacy_engine.attach(optimizer)

    best_acc = 0
    flat_count = 0

    for epoch in range(0, epochs):

        print(f"\nEpoch: {epoch} ({privacy_engine.steps} steps)")
        train_loss, train_acc = train(model,
                                      train_loader_aug,
                                      optimizer,
                                      n_acc_steps=n_acc_steps)
        test_loss, test_acc = test(model, test_loader)

        if noise_multiplier > 0:
            print(f"sample_rate={privacy_engine.sample_rate}, "
                  f"mul={privacy_engine.noise_multiplier}, "
                  f"steps={privacy_engine.steps}")
            rdp_sgd = get_renyi_divergence(
                privacy_engine.sample_rate,
                privacy_engine.noise_multiplier) * privacy_engine.steps
            epsilon, _ = get_privacy_spent(rdp_norm + rdp_sgd,
                                           target_delta=delta)
            epsilon2, _ = get_privacy_spent(rdp_sgd, target_delta=delta)
            print(f"ε = {epsilon:.3f} (sgd only: ε = {epsilon2:.3f})")

            if max_epsilon is not None and epsilon >= max_epsilon:
                return
        else:
            epsilon = None

        logger.log_epoch(epoch, train_loss, train_acc, test_loss, test_acc,
                         epsilon)
        logger.log_scalar("epsilon/train", epsilon, epoch)
        logger.log_scalar("cifar10k_loss/train", train_loss, epoch)
        logger.log_scalar("cifar10k_acc/train", train_acc, epoch)

        if test_acc > best_acc:
            best_acc = test_acc
            flat_count = 0
        else:
            flat_count += 1
            if flat_count >= 20:
                print("plateau...")
                return
Example #8
def main(feature_path=None,
         batch_size=2048,
         mini_batch_size=256,
         lr=1,
         optim="SGD",
         momentum=0.9,
         nesterov=False,
         noise_multiplier=1,
         max_grad_norm=0.1,
         max_epsilon=None,
         epochs=100,
         logdir=None):

    logger = Logger(logdir)

    device = get_device()

    # get pre-computed features
    x_train = np.load(f"{feature_path}_train.npy")
    x_test = np.load(f"{feature_path}_test.npy")

    train_data, test_data = get_data("cifar10", augment=False)
    y_train = np.asarray(train_data.targets)
    y_test = np.asarray(test_data.targets)

    trainset = torch.utils.data.TensorDataset(torch.from_numpy(x_train),
                                              torch.from_numpy(y_train))
    testset = torch.utils.data.TensorDataset(torch.from_numpy(x_test),
                                             torch.from_numpy(y_test))

    bs = batch_size
    assert bs % mini_batch_size == 0
    n_acc_steps = bs // mini_batch_size
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=mini_batch_size,
                                               shuffle=True,
                                               num_workers=1,
                                               pin_memory=True,
                                               drop_last=True)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=mini_batch_size,
                                              shuffle=False,
                                              num_workers=1,
                                              pin_memory=True)

    n_features = x_train.shape[-1]
    try:
        mean = np.load(f"{feature_path}_mean.npy")
        var = np.load(f"{feature_path}_var.npy")
    except FileNotFoundError:
        mean = np.zeros(n_features, dtype=np.float32)
        var = np.ones(n_features, dtype=np.float32)

    bn_stats = (torch.from_numpy(mean).to(device),
                torch.from_numpy(var).to(device))

    model = nn.Sequential(StandardizeLayer(bn_stats),
                          nn.Linear(n_features, 10)).to(device)

    if optim == "SGD":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    momentum=momentum,
                                    nesterov=nesterov)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    privacy_engine = PrivacyEngine(
        model,
        sample_rate=bs / len(train_data),
        alphas=ORDERS,
        noise_multiplier=noise_multiplier,
        max_grad_norm=max_grad_norm,
    )
    privacy_engine.attach(optimizer)

    for epoch in range(0, epochs):
        print(f"\nEpoch: {epoch}")

        train_loss, train_acc = train(model,
                                      train_loader,
                                      optimizer,
                                      n_acc_steps=n_acc_steps)
        test_loss, test_acc = test(model, test_loader)

        if noise_multiplier > 0:
            rdp_sgd = get_renyi_divergence(
                privacy_engine.sample_rate,
                privacy_engine.noise_multiplier) * privacy_engine.steps
            epsilon, _ = get_privacy_spent(rdp_sgd)
            print(f"ε = {epsilon:.3f}")

            if max_epsilon is not None and epsilon >= max_epsilon:
                return
        else:
            epsilon = None

        logger.log_epoch(epoch, train_loss, train_acc, test_loss, test_acc,
                         epsilon)
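
The accounting above multiplies a per-step Rényi-DP cost by privacy_engine.steps and converts the accumulated RDP curve into an (ε, δ) guarantee. A self-contained numerical sketch of that conversion using the classic bound ε = min_α [ RDP(α) + log(1/δ)/(α−1) ]; for simplicity the per-step RDP here is the full-batch Gaussian value α/(2σ²), whereas get_renyi_divergence presumably implements the tighter subsampled-Gaussian analysis and get_privacy_spent may use a sharper conversion:

import numpy as np

orders = np.array([1.5, 2, 4, 8, 16, 32, 64])   # Renyi orders alpha > 1
sigma = 1.0                                      # noise multiplier
steps = 1000
delta = 1e-5

# Per-step RDP of the (non-subsampled) Gaussian mechanism: alpha / (2 * sigma^2).
rdp_per_step = orders / (2 * sigma ** 2)
rdp_total = rdp_per_step * steps                 # RDP composes additively over steps

# Classic RDP -> (eps, delta) conversion, minimized over the candidate orders.
eps_at_order = rdp_total + np.log(1 / delta) / (orders - 1)
best = int(np.argmin(eps_at_order))
print(f"eps = {eps_at_order[best]:.3f} at alpha = {orders[best]}")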
Example #9
def main(dataset,
         augment=False,
         batch_size=2048,
         mini_batch_size=256,
         sample_batches=False,
         lr=1,
         optim="SGD",
         momentum=0.9,
         nesterov=False,
         noise_multiplier=1,
         max_grad_norm=0.1,
         epochs=100,
         input_norm=None,
         num_groups=None,
         bn_noise_multiplier=None,
         max_epsilon=None,
         logdir=None):
    logger = Logger(logdir)
    device = get_device()

    train_data, test_data = get_data(dataset, augment=augment)
    scattering, K, (h, w) = get_scatter_transform(dataset)
    scattering.to(device)

    bs = batch_size
    assert bs % mini_batch_size == 0
    n_acc_steps = bs // mini_batch_size

    # Batch accumulation and data augmentation aren't implemented with Poisson sampling
    if sample_batches:
        assert n_acc_steps == 1
        assert not augment

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=mini_batch_size,
                                               shuffle=True,
                                               num_workers=1,
                                               pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=mini_batch_size,
                                              shuffle=False,
                                              num_workers=1,
                                              pin_memory=True)

    rdp_norm = 0
    if input_norm == "BN":
        # compute noisy data statistics or load from disk if pre-computed
        save_dir = f"bn_stats/{dataset}"
        os.makedirs(save_dir, exist_ok=True)
        bn_stats, rdp_norm = scatter_normalization(
            train_loader,
            scattering,
            K,
            device,
            len(train_data),
            len(train_data),
            noise_multiplier=bn_noise_multiplier,
            orders=ORDERS,
            save_dir=save_dir)
        model = ScatterLinear(K, (h, w), input_norm="BN", bn_stats=bn_stats)
    else:
        model = ScatterLinear(K, (h, w),
                              input_norm=input_norm,
                              num_groups=num_groups)
    model.to(device)

    trainable_params = sum(p.numel() for p in model.parameters()
                           if p.requires_grad)
    print(f'model: {model}\n')
    print(f'has {trainable_params / 1e6:.4f} million trainable parameters')

    if augment:
        model = nn.Sequential(scattering, model)
        train_loader = torch.utils.data.DataLoader(train_data,
                                                   batch_size=mini_batch_size,
                                                   shuffle=True,
                                                   num_workers=1,
                                                   pin_memory=True,
                                                   drop_last=True)
        preprocessor = None
    else:
        preprocessor = lambda x, y: (scattering(x), y)

    # baseline Logistic Regression without privacy
    if optim == "LR":
        assert not augment
        X_train = []
        y_train = []
        X_test = []
        y_test = []
        for data, target in train_loader:
            with torch.no_grad():
                data = data.to(device)
                X_train.append(data.cpu().numpy().reshape(len(data), -1))
                y_train.extend(target.cpu().numpy())

        for data, target in test_loader:
            with torch.no_grad():
                data = data.to(device)
                X_test.append(data.cpu().numpy().reshape(len(data), -1))
                y_test.extend(target.cpu().numpy())

        import numpy as np
        X_train = np.concatenate(X_train, axis=0)
        X_test = np.concatenate(X_test, axis=0)
        y_train = np.asarray(y_train)
        y_test = np.asarray(y_test)

        print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)

        for idx, C in enumerate([0.01, 0.1, 1.0, 10, 100]):
            clf = LogisticRegression(C=C, fit_intercept=True)
            clf.fit(X_train, y_train)

            train_acc = 100 * clf.score(X_train, y_train)
            test_acc = 100 * clf.score(X_test, y_test)
            print(f"C={C}, "
                  f"Acc train = {train_acc: .2f}, "
                  f"Acc test = {test_acc: .2f}")

            logger.log_epoch(idx, 0, train_acc, 0, test_acc, None)
        return

    print(f"model has {get_num_params(model)} parameters")

    if optim == "SGD":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    momentum=momentum,
                                    nesterov=nesterov)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    privacy_engine = PrivacyEngine(
        model,
        batch_size=bs,
        sample_size=len(train_data),
        alphas=ORDERS,
        noise_multiplier=noise_multiplier,
        max_grad_norm=max_grad_norm,
    )
    privacy_engine.attach(optimizer)

    for epoch in range(0, epochs):
        print(f"\nEpoch: {epoch}")
        train_loss, train_acc = train(model,
                                      train_loader,
                                      optimizer,
                                      n_acc_steps=n_acc_steps,
                                      preprocessor=preprocessor)
        test_loss, test_acc = test(model,
                                   test_loader,
                                   preprocessor=preprocessor)

        if noise_multiplier > 0:
            rdp_sgd = get_renyi_divergence(
                privacy_engine.sample_rate,
                privacy_engine.noise_multiplier) * privacy_engine.steps
            epsilon, _ = get_privacy_spent(rdp_norm + rdp_sgd)
            epsilon2, _ = get_privacy_spent(rdp_sgd)
            print(f"ε = {epsilon:.3f} (sgd only: ε = {epsilon2:.3f})")

            if max_epsilon is not None and epsilon >= max_epsilon:
                return
        else:
            epsilon = None

        logger.log_epoch(epoch, train_loss, train_acc, test_loss, test_acc,
                         epsilon)
        logger.log_scalar("epsilon/train", epsilon, epoch)
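
When augment is disabled, Example #9 hands the train/test helpers a preprocessor lambda so the scattering transform is applied on the fly. A hedged sketch of how such a hook is typically consumed in an evaluation loop (the real helpers are not shown and may differ):

import torch

@torch.no_grad()
def evaluate(model, loader, device="cpu", preprocessor=None):
    model.eval()
    correct = total = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        if preprocessor is not None:
            # e.g. preprocessor = lambda x, y: (scattering(x), y)
            x, y = preprocessor(x, y)
        pred = model(x).argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.numel()
    return correct / max(total, 1)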