Example #1
def main(args):
    print(args)
    assert not args.dpsgd
    torch.backends.cudnn.benchmark = True

    train_data, train_labels = get_data(args)
    model = model_dict[args.experiment](vocab_size=args.max_features).cuda()
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate)
    loss_function = nn.CrossEntropyLoss() if args.experiment != 'logreg' else nn.BCELoss()

    timings = []
    for epoch in range(1, args.epochs + 1):
        start = time.perf_counter()
        dataloader = data.dataloader(train_data, train_labels, args.batch_size)
        for batch_idx, (x, y) in enumerate(dataloader):
            x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
            model.zero_grad()
            outputs = model(x)
            loss = loss_function(outputs, y)
            loss.backward()
            optimizer.step()
        torch.cuda.synchronize()  # wait for queued GPU work before stopping the timer
        duration = time.perf_counter() - start
        print("Time Taken for Epoch: ", duration)
        timings.append(duration)

    if not args.no_save:
        utils.save_runtimes(__file__.split('.')[0], args, timings)
    else:
        print('Not saving!')
    print('Done!')
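A hedged usage sketch for the example above (not part of the repo): the Namespace carries only the flags this main() reads directly; get_data(args) and the repo's full argument parser may require additional flags such as dummy_data or max_len, and model_dict must contain the chosen experiment.

from argparse import Namespace

args = Namespace(dpsgd=False, experiment='logreg', max_features=10_000,
                 learning_rate=0.1, epochs=5, batch_size=64, no_save=True)
main(args)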
Example #2
def main(args):
    print(args)
    assert args.dpsgd
    torch.backends.cudnn.benchmark = True

    mdict = model_dict.copy()
    mdict['lstm'] = LSTMNet

    train_data, train_labels = get_data(args)
    model = mdict[args.experiment](vocab_size=args.max_features,
                                   batch_size=args.batch_size).cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=0)
    loss_function = nn.CrossEntropyLoss() if args.experiment != 'logreg' else nn.BCELoss()

    privacy_engine = PrivacyEngine(
        model,
        batch_size=args.batch_size,
        sample_size=len(train_data),
        alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
        noise_multiplier=args.sigma,
        max_grad_norm=args.max_per_sample_grad_norm,
    )
    privacy_engine.attach(optimizer)

    timings = []
    for epoch in range(1, args.epochs + 1):
        start = time.perf_counter()
        losses = []  # collect per-batch losses for the epoch summary below
        dataloader = data.dataloader(train_data, train_labels, args.batch_size)
        for batch_idx, (x, y) in enumerate(dataloader):
            x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
            model.zero_grad()
            outputs = model(x)
            loss = loss_function(outputs, y)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        torch.cuda.synchronize()
        duration = time.perf_counter() - start
        print("Time Taken for Epoch: ", duration)
        timings.append(duration)

        if args.dpsgd:
            epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(
                args.delta)
            print(
                f"Train Epoch: {epoch} \t"
                # f"Loss: {np.mean(losses):.6f} "
                f"(ε = {epsilon:.2f}, δ = {args.delta}) for α = {best_alpha}")
        else:
            print(f"Train Epoch: {epoch} \t Loss: {np.mean(losses):.6f}")

    if not args.no_save:
        utils.save_runtimes(__file__.split('.')[0], args, timings)
    else:
        print('Not saving!')
    print('Done!')
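The Opacus pattern above (legacy, pre-1.0 API) boils down to wrapping a standard optimizer so that each step clips per-sample gradients and adds Gaussian noise, then querying the accountant for (epsilon, alpha). A minimal sketch reusing only the calls shown in the example; the model, dataset size, and hyperparameter values here are illustrative.

import torch
from torch import nn, optim
from opacus import PrivacyEngine

model = nn.Linear(100, 2).cuda()
optimizer = optim.SGD(model.parameters(), lr=0.1)

privacy_engine = PrivacyEngine(
    model,
    batch_size=64,        # expected logical batch size
    sample_size=50_000,   # number of training examples
    alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
    noise_multiplier=1.1,
    max_grad_norm=1.0,
)
privacy_engine.attach(optimizer)  # hooks per-sample clipping and noise into step()

# a few dummy steps so the accountant has something to report
x, y = torch.randn(64, 100).cuda(), torch.randint(0, 2, (64,)).cuda()
for _ in range(3):
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()

epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(1e-5)
print(f"(ε = {epsilon:.2f}, δ = 1e-5) for α = {best_alpha}")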
Example #3
def main(args):
    print(args)
    assert args.dpsgd
    torch.backends.cudnn.benchmark = True

    train_data, train_labels = get_data(args)
    num_complete_batches, leftover = divmod(len(train_data), args.batch_size)
    num_batches = num_complete_batches + bool(leftover)

    model = model_dict[args.experiment](vocab_size=args.max_features).cuda()
    loss_function = nn.CrossEntropyLoss() if args.experiment != 'logreg' else nn.BCELoss()

    opt = optim.DPSGD(params=model.parameters(),
                      l2_norm_clip=args.l2_norm_clip,
                      noise_multiplier=args.noise_multiplier,
                      minibatch_size=args.batch_size,
                      microbatch_size=1,
                      lr=args.learning_rate)

    timings = []
    for epoch in range(1, args.epochs + 1):
        start = time.perf_counter()
        dataloader = data.dataloader(train_data, train_labels, args.batch_size)
        for batch_idx, (x_mb, y_mb) in enumerate(dataloader):
            x_mb, y_mb = x_mb.cuda(non_blocking=True), y_mb.cuda(
                non_blocking=True)
            for x, y in zip(x_mb, y_mb):
                opt.zero_microbatch_grad()
                out = model(x[None])
                curr_loss = loss_function(out, y[None])
                curr_loss.backward()
                opt.microbatch_step()
            opt.step()
        torch.cuda.synchronize()  # wait for queued GPU work before stopping the timer
        duration = time.perf_counter() - start
        print("Time Taken for Epoch: ", duration)
        timings.append(duration)

    if not args.no_save:
        utils.save_runtimes(__file__.split('.')[0], args, timings)
    else:
        print('Not saving!')
    print('Done!')
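What the microbatch loop above delegates to optim.DPSGD can be written out by hand: clip each example's gradient to l2_norm_clip, accumulate the clipped gradients, add Gaussian noise scaled by noise_multiplier * l2_norm_clip, and apply the averaged result. A plain-PyTorch sketch of one such step, illustrative only and not the pyvacy-style optimizer itself.

import torch

def dp_sgd_step_sketch(model, loss_function, x_mb, y_mb, lr,
                       l2_norm_clip, noise_multiplier):
    # Accumulator for clipped per-example gradients, one tensor per parameter.
    accum = [torch.zeros_like(p) for p in model.parameters()]
    for x, y in zip(x_mb, y_mb):
        model.zero_grad()
        loss_function(model(x[None]), y[None]).backward()
        grads = [p.grad.detach() for p in model.parameters()]
        total_norm = torch.sqrt(sum(g.pow(2).sum() for g in grads))
        scale = min(1.0, l2_norm_clip / (total_norm.item() + 1e-12))
        for a, g in zip(accum, grads):
            a.add_(g, alpha=scale)  # clip, then accumulate
    with torch.no_grad():
        for p, a in zip(model.parameters(), accum):
            noise = torch.randn_like(a) * noise_multiplier * l2_norm_clip
            p.add_((a + noise) / len(x_mb), alpha=-lr)  # averaged noisy update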
Example #4
def main(args):
    print(args)
    tf.disable_eager_execution()
    if args.memory_limit:
        physical_devices = tf.config.list_physical_devices('GPU')
        tf.config.experimental.set_memory_growth(physical_devices[0], True)
        tf.config.experimental.set_virtual_device_configuration(
            physical_devices[0], [
                tf.config.experimental.VirtualDeviceConfiguration(
                    memory_limit=args.memory_limit)
            ])

    assert args.microbatches is None
    args.microbatches = args.batch_size

    data_fn = data.data_fn_dict[args.experiment][int(args.dummy_data)]
    kwargs = {
        'max_features': args.max_features,
        'max_len': args.max_len,
        'format': 'NHWC',
    }
    if args.dummy_data:
        kwargs['num_examples'] = args.batch_size * 2
    (train_data, train_labels), _ = data_fn(**kwargs)
    num_train_eg = train_data.shape[0]

    loss_fn = tf.nn.sparse_softmax_cross_entropy_with_logits
    if args.experiment == 'logreg':
        loss_fn = lambda labels, logits: tf.nn.sigmoid_cross_entropy_with_logits(
            labels=labels, logits=tf.squeeze(logits))
        train_labels = train_labels.astype('float32')

    model = partial(model_dict[args.experiment],
                    features=train_data,
                    max_features=args.max_features,
                    args=args)

    if args.use_xla:
        # Not sure which one of these two works, so I'll just use both
        assert os.environ['TF_XLA_FLAGS'] == '--tf_xla_auto_jit=2'
        session_config = tf.ConfigProto()
        session_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_2
        run_config = tf.estimator.RunConfig(session_config=session_config)
        print('Using XLA!')
    else:
        run_config = None
        print('NOT using XLA!')

    model_obj = tf.estimator.Estimator(
        model_fn=partial(nn_model_fn, model, loss_fn, args),
        config=run_config)
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={'x': train_data},
        y=train_labels,
        batch_size=args.batch_size,
        num_epochs=args.epochs,
        shuffle=True)

    steps_per_epoch = num_train_eg // args.batch_size
    timings = []
    for epoch in range(1, args.epochs + 1):
        start = time.perf_counter()
        model_obj.train(input_fn=train_input_fn, steps=steps_per_epoch)
        duration = time.perf_counter() - start
        print("Time Taken: ", duration)
        timings.append(duration)

        if args.dpsgd:
            # eps = compute_epsilon(epoch, num_train_eg, args)
            # print('For delta=1e-5, the current epsilon is: %.2f' % eps)
            print('Trained with DPSGD optimizer')
        else:
            print('Trained with vanilla non-private SGD optimizer')

    if not args.no_save:
        utils.save_runtimes(__file__.split('.')[0], args, timings)
    else:
        print('Not saving!')
    print('Done!')
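The model_fn that nn_model_fn supplies to the Estimator is not shown here; it presumably follows the standard tensorflow_privacy tutorial pattern of computing a per-example (vector) loss and handing it to a DP optimizer, which clips and noises per-microbatch gradients. A hedged sketch of that pattern, with a hypothetical build_logits helper, using the TF1-style API that matches the disable_eager_execution / ConfigProto usage above.

from tensorflow_privacy.privacy.optimizers.dp_optimizer import (
    DPGradientDescentGaussianOptimizer)

def dp_model_fn_sketch(features, labels, mode, params):
    logits = build_logits(features['x'])  # hypothetical model builder
    vector_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)     # one loss value per example
    scalar_loss = tf.reduce_mean(vector_loss)
    optimizer = DPGradientDescentGaussianOptimizer(
        l2_norm_clip=params['l2_norm_clip'],
        noise_multiplier=params['noise_multiplier'],
        num_microbatches=params['microbatches'],
        learning_rate=params['learning_rate'])
    # DP optimizers take the vector loss so they can clip each microbatch's
    # gradient before adding noise and averaging.
    train_op = optimizer.minimize(loss=vector_loss,
                                  global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=scalar_loss,
                                      train_op=train_op)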
Example #5
def main(args):
    print(args)
    if args.microbatches:
        raise NotImplementedError(
            'Microbatches < batch size not currently supported')
    if args.experiment == 'lstm' and args.no_jit:
        raise ValueError('LSTM with no JIT will fail.')

    data_fn = data.data_fn_dict[args.experiment][int(args.dummy_data)]
    kwargs = {
        'max_features': args.max_features,
        'max_len': args.max_len,
        'format': 'NHWC',
    }
    if args.dummy_data:
        kwargs['num_examples'] = args.batch_size * 2
    (train_data, train_labels), _ = data_fn(**kwargs)
    # train_labels, test_labels = to_categorical(train_labels), to_categorical(
    #     test_labels)

    num_train = train_data.shape[0]
    num_complete_batches, leftover = divmod(num_train, args.batch_size)
    num_batches = num_complete_batches + bool(leftover)
    key = random.PRNGKey(args.seed)

    model = hk.transform(
        partial(model_dict[args.experiment],
                args=args,
                vocab_size=args.max_features,
                seq_len=args.max_len))
    rng = jax.random.PRNGKey(42)
    init_params = model.init(key, train_data[:args.batch_size])
    opt_init, opt_update, get_params = optimizers.sgd(args.learning_rate)
    loss = logistic_loss if args.experiment == 'logreg' else multiclass_loss

    if args.dpsgd:
        train_data, train_labels = train_data[:, None], train_labels[:, None]

    def update(_, i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i,
                          grad(partial(loss, model))(params, batch), opt_state)

    grad_fn = private_grad_no_vmap if args.no_vmap else private_grad

    def private_update(rng, i, opt_state, batch):
        params = get_params(opt_state)
        rng = random.fold_in(rng, i)  # get new key for new random numbers
        return opt_update(
            i,
            grad_fn(model, loss, params, batch, rng, args.l2_norm_clip,
                    args.noise_multiplier, args.batch_size), opt_state)

    opt_state = opt_init(init_params)
    itercount = itertools.count()
    train_fn = private_update if args.dpsgd else update

    if args.no_vmap:
        print('No vmap for dpsgd!')

    if not args.no_jit:
        train_fn = jit(train_fn)
    else:
        print('No jit!')

    dummy = jnp.array(1.)

    timings = []
    for epoch in range(1, args.epochs + 1):
        start = time.perf_counter()
        for i, batch in enumerate(
                data.dataloader(train_data, train_labels, args.batch_size)):
            opt_state = train_fn(
                key,
                next(itercount),
                opt_state,
                batch,
            )
        (dummy * dummy).block_until_ready()  # synchronize
        duration = time.perf_counter() - start
        print("Time Taken: ", duration)
        timings.append(duration)

        if args.dpsgd:
            print('Trained with DP SGD optimizer')
        else:
            print('Trained with vanilla non-private SGD optimizer')

    if not args.no_save:
        append_to_name = ''
        if args.no_jit: append_to_name += '_nojit'
        if args.no_vmap: append_to_name += '_novmap'
        utils.save_runtimes(
            __file__.split('.')[0], args, timings, append_to_name)
    else:
        print('Not saving!')
    print('Done!')
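private_grad in the JAX example above is defined elsewhere in the repo; under the usual DP-SGD recipe it computes per-example gradients (vectorized with vmap unless --no_vmap), clips each to l2_norm_clip, sums them, adds Gaussian noise, and divides by the batch size. A self-contained sketch under those assumptions; the loss signature loss(params, example) and all names below are illustrative, not the repo's helpers.

import jax
import jax.numpy as jnp

def private_grad_sketch(loss, params, batch, rng, l2_norm_clip,
                        noise_multiplier, batch_size):
    def clipped_grad(example):
        grads = jax.grad(loss)(params, example)
        leaves, treedef = jax.tree_util.tree_flatten(grads)
        total_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
        scale = jnp.minimum(1.0, l2_norm_clip / (total_norm + 1e-12))
        return treedef.unflatten([g * scale for g in leaves])

    # Per-example clipped gradients, vectorized over the leading batch axis.
    per_example = jax.vmap(clipped_grad)(batch)
    summed = jax.tree_util.tree_map(lambda g: jnp.sum(g, axis=0), per_example)
    leaves, treedef = jax.tree_util.tree_flatten(summed)
    keys = jax.random.split(rng, len(leaves))
    noised = [g + noise_multiplier * l2_norm_clip * jax.random.normal(k, g.shape)
              for g, k in zip(leaves, keys)]
    # Average the noisy sum over the (expected) batch size.
    return treedef.unflatten([g / batch_size for g in noised])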
Example #6
def main(args):
    print(args)
    if args.memory_limit:
        physical_devices = tf.config.list_physical_devices('GPU')
        tf.config.experimental.set_memory_growth(physical_devices[0], True)
        tf.config.experimental.set_virtual_device_configuration(
            physical_devices[0],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=args.memory_limit)])

    assert args.microbatches is None
    args.microbatches = args.batch_size

    data_fn = data.data_fn_dict[args.experiment][int(args.dummy_data)]
    kwargs = {
        'max_features': args.max_features,
        'max_len': args.max_len,
        'format': 'NHWC',
    }
    if args.dummy_data:
        kwargs['num_examples'] = args.batch_size * 2
    (train_data, train_labels), _ = data_fn(**kwargs)
    train_data, train_labels = tf.constant(train_data), tf.constant(train_labels)
    num_train_eg = train_data.shape[0]

    loss_fn = tf.nn.sparse_softmax_cross_entropy_with_logits
    if args.experiment == 'logreg':
        loss_fn = lambda labels, logits: tf.nn.sigmoid_cross_entropy_with_logits(
            labels=labels, logits=tf.squeeze(logits))
        train_labels = tf.cast(train_labels, tf.float32)

    model_bs = 1 if args.dpsgd else args.batch_size
    model = model_dict[args.experiment](
        train_data,
        max_features=args.max_features,
        # batch_size=model_bs,
        args=args)
    optimizer = tf.keras.optimizers.SGD(learning_rate=args.learning_rate)
    train_fn = private_train_step if args.dpsgd else train_step
    train_fn = partial(train_fn, model, optimizer, loss_fn, args)

    if args.no_vmap:
        print('No vmap for dpsgd!')

    if args.no_jit:
        print('No jit!')
    else:
        train_fn = tf.function(experimental_compile=args.use_xla)(train_fn)

    with tf.device('GPU'):
        dummy = tf.convert_to_tensor(1.)

    timings = []
    for epoch in range(1, args.epochs + 1):
        start = time.perf_counter()
        for i, batch in enumerate(data.dataloader(train_data, train_labels, args.batch_size)):
            train_fn(batch)
        _ = dummy.numpy()
        duration = time.perf_counter() - start
        print("Time Taken: ", duration)
        timings.append(duration)

        if args.dpsgd:
            # eps = compute_eps_poisson(epoch, args.noise_multiplier, num_train_eg, args.batch_size,
            #                           1e-5)
            # mu = compute_mu_poisson(epoch, args.noise_multiplier, num_train_eg, args.batch_size)
            # print('For delta=1e-5, the current epsilon is: %.2f' % eps)
            # print('For delta=1e-5, the current mu is: %.2f' % mu)
            print('Trained with DPSGD optimizer')
        else:
            print('Trained with vanilla non-private SGD optimizer')

    if not args.no_save:
        append_to_name = ''
        if args.no_jit: append_to_name += '_nojit'
        if args.no_vmap: append_to_name += '_novmap'
        utils.save_runtimes(__file__.split('.')[0], args, timings, append_to_name)
    else:
        print('Not saving!')
    print('Done!')
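The non-private train_step that train_fn wraps in the last example is also defined elsewhere; a minimal sketch of that kind of eager/tf.function step, assuming model, optimizer, and loss_fn are constructed as above and the argument order matches the partial(...) call. The private variant would additionally clip per-example gradients and add noise before apply_gradients.

import tensorflow as tf

def train_step_sketch(model, optimizer, loss_fn, args, batch):
    x, y = batch
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = tf.reduce_mean(loss_fn(labels=y, logits=logits))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss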