Example #1
def __init__(self, _model, _attack, _args):
    super(AT, self).__init__(_model, _attack, _args)
    self.inner_loss_fn = get_loss_fn(_args.inner_loss)
    self.outer_loss_fn = get_loss_fn(_args.outer_loss)
    self.init_mode = "pgd"
    if _args.defense == "trades":
        self.init_mode = "trades"
Example #2
def validate(cfg, model_path):
    assert model_path is not None, 'model_path must not be None'
    use_cuda = False
    if cfg.get("cuda", None) is not None:
        if cfg.get("cuda", None) != "all":
            os.environ["CUDA_VISIBLE_DEVICES"] = cfg.get("cuda", None)
        use_cuda = torch.cuda.is_available()

    # Setup Dataloader
    train_loader, val_loader = get_loader(cfg)

    loss_fn = get_loss_fn(cfg)

    # Load Model
    model = get_model(cfg)
    if use_cuda:
        model.cuda()
        loss_fn.cuda()
        checkpoint = torch.load(model_path)
        if torch.cuda.device_count() > 1:  # multi gpus
            model = torch.nn.DataParallel(
                model, device_ids=list(range(torch.cuda.device_count())))
            state = checkpoint["state_dict"]
        else:  # 1 gpu
            state = convert_state_dict(checkpoint["state_dict"])
    else:  # cpu
        checkpoint = torch.load(model_path, map_location='cpu')
        state = convert_state_dict(checkpoint["state_dict"])
    model.load_state_dict(state)

    validate_epoch(val_loader, model, loss_fn, use_cuda)
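
convert_state_dict is not defined in these examples either; from the single-GPU and CPU branches above it apparently strips the "module." prefix that torch.nn.DataParallel adds to checkpoint keys. A minimal sketch under that assumption:

# Hypothetical sketch of convert_state_dict, assuming its job is to strip
# the DataParallel "module." prefix from checkpoint keys.
from collections import OrderedDict

def convert_state_dict(state_dict):
    return OrderedDict(
        (k[len("module."):] if k.startswith("module.") else k, v)
        for k, v in state_dict.items()
    )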
Example #3
def __init__(self, _model, _attack, _args):
    super(MART, self).__init__(_model, _attack, _args)
    self.inner_loss_fn = get_loss_fn("CE")
    self.outer_loss_fn = get_loss_fn("mart_outer")
Example #4
def train(cfg, writer, logger):
    # CUDA_VISIBLE_DEVICES must be set before PyTorch initializes CUDA
    use_cuda = False
    if cfg.get("cuda", None) is not None:
        if cfg.get("cuda", None) != "all":
            os.environ["CUDA_VISIBLE_DEVICES"] = cfg.get("cuda", None)
        use_cuda = torch.cuda.is_available()

    # Setup random seed
    seed = cfg["training"].get("seed", random.randint(1, 10000))
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Setup Dataloader
    train_loader, val_loader = get_loader(cfg)

    # Setup Model
    model = get_model(cfg)
    # writer.add_graph(model, torch.rand([1, 3, 224, 224]))
    if use_cuda and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(
                                          range(torch.cuda.device_count())))

    # Setup optimizer, lr_scheduler and loss function
    optimizer = get_optimizer(model.parameters(), cfg)
    scheduler = get_scheduler(optimizer, cfg)
    loss_fn = get_loss_fn(cfg)

    # Setup Metrics
    epochs = cfg["training"]["epochs"]
    recorder = RecorderMeter(epochs)
    start_epoch = 0

    # save model parameters every <n> epochs
    save_interval = cfg["training"]["save_interval"]

    if use_cuda:
        model.cuda()
        loss_fn.cuda()

    # Resume Trained Model
    resume_path = os.path.join(writer.file_writer.get_logdir(),
                               cfg["training"]["resume"])
    best_path = os.path.join(writer.file_writer.get_logdir(),
                             cfg["training"]["best_model"])

    if cfg["training"]["resume"] is not None:
        if os.path.isfile(resume_path):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    resume_path))
            checkpoint = torch.load(resume_path)
            state = checkpoint["state_dict"]
            if torch.cuda.device_count() <= 1:
                state = convert_state_dict(state)
            model.load_state_dict(state)
            optimizer.load_state_dict(checkpoint["optimizer"])
            scheduler.load_state_dict(checkpoint["scheduler"])
            start_epoch = checkpoint["epoch"]
            recorder = checkpoint['recorder']
            logger.info("Loaded checkpoint '{}' (epoch {})".format(
                resume_path, checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(resume_path))

    epoch_time = AverageMeter()
    for epoch in range(start_epoch, epochs):
        start_time = time.time()
        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg *
                                                            (epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)
        logger.info(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:8.6f}]'.
            format(time_string(), epoch, epochs, need_time, optimizer.
                   param_groups[0]['lr']) +  # or scheduler.get_last_lr() on PyTorch >= 1.4
            ' [Best : Accuracy={:.2f}]'.format(recorder.max_accuracy(False)))
        train_acc, train_los = train_epoch(train_loader, model, loss_fn,
                                           optimizer, use_cuda, logger)
        val_acc, val_los = validate_epoch(val_loader, model, loss_fn, use_cuda,
                                          logger)
        scheduler.step()

        is_best = recorder.update(epoch, train_los, train_acc, val_los,
                                  val_acc)
        if is_best or epoch % save_interval == 0 or epoch == epochs - 1:  # save model (resume model and best model)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'recorder': recorder,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                }, is_best, best_path, resume_path)

            for name, param in model.named_parameters():  # save histogram
                writer.add_histogram(name,
                                     param.clone().cpu().data.numpy(), epoch)

        writer.add_scalar('Train/loss', train_los, epoch)  # save curves
        writer.add_scalar('Train/acc', train_acc, epoch)
        writer.add_scalar('Val/loss', val_los, epoch)
        writer.add_scalar('Val/acc', val_acc, epoch)

        epoch_time.update(time.time() - start_time)

    writer.close()
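
For reference, the keys that train() reads directly from cfg can be collected into a small configuration sketch. The sections consumed by get_loader, get_model, get_optimizer, get_scheduler and get_loss_fn are project specific and are not spelled out here.

# Minimal config sketch covering only the keys train() itself reads;
# the loader/model/optimizer/loss sections are project specific.
cfg = {
    "cuda": "0",  # GPU index as a string, or "all" to keep every visible GPU
    "training": {
        "seed": 42,                      # omitted -> a random seed is drawn
        "epochs": 100,
        "save_interval": 10,             # checkpoint every <n> epochs
        "resume": "checkpoint.pth.tar",  # filename under the writer's logdir
        "best_model": "model_best.pth.tar",
    },
}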
Example #5
def attack_pgd(
    model,
    X,
    y,
    epsilon,
    alpha,
    attack_iters,
    restarts,
    norm,
    early_stop=False,
    loss_fn=None,
    init_mode="pgd",
    args=None,
):
    max_loss = torch.zeros_like(y)
    max_delta = torch.zeros_like(X)
    init_mode = init_mode.lower()
    if loss_fn is not None:
        with torch.no_grad():
            is_model_training = model.training
            model.eval()
            nat_output = model(X).detach()
            if is_model_training:
                model.train()
    else:
        nat_output = None
        loss_fn = get_loss_fn("CE")

    for _ in range(restarts):
        delta = torch.zeros_like(X)
        if init_mode == "pgd":
            if norm == "l_inf":
                delta.uniform_(-epsilon, epsilon)
            elif norm == "l_2":
                delta.normal_()
                d_flat = delta.view(delta.size(0), -1)
                n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
                r = torch.zeros_like(n).uniform_(0, 1)
                delta *= r / (n + 1e-10) * epsilon
        elif init_mode == "trades":
            delta = 0.001 * torch.randn_like(X).detach()
        else:
            raise ValueError(f"unknown init_mode: {init_mode}")
        delta = clamp(delta, -X, 1 - X)
        delta.requires_grad = True

        for _ in range(attack_iters):
            output = model(X + delta)
            if early_stop:
                index = torch.where(output.max(1)[1] == y)[0]
            else:
                index = slice(None, None, None)
            if not isinstance(index, slice) and len(index) == 0:
                break
            loss = loss_fn(output, y, nat_output)
            if args is not None and not args.no_amp:
                # args.scaler.scale(loss).backward()
                with args.amp.scale_loss(loss, args.opt) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            grad = delta.grad.detach()
            d = delta[index, :, :, :]
            g = grad[index, :, :, :]
            x = X[index, :, :, :]
            if norm == "l_inf":
                d = torch.clamp(d + alpha * torch.sign(g),
                                min=-epsilon,
                                max=epsilon)
            elif norm == "l_2":
                g_norm = torch.norm(g.view(g.shape[0], -1),
                                    dim=1).view(-1, 1, 1, 1)
                scaled_g = g / (g_norm + 1e-10)
                d = ((d + scaled_g * alpha).view(d.size(0), -1).renorm(
                    p=2, dim=0, maxnorm=epsilon).view_as(d))
            d = clamp(d, -x, 1 - x)
            delta.data[index, :, :, :] = d
            delta.grad.zero_()
        all_loss = loss_fn(model(X + delta), y, nat_output, reduction="none")
        max_delta[all_loss >= max_loss] = delta.detach()[all_loss >= max_loss]
        max_loss = torch.max(max_loss, all_loss)
    return max_delta
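
A hedged usage sketch of attack_pgd: model, images and labels are assumed to already exist (with images in [0, 1]), and args is any object exposing a no_amp flag, since the inner loop consults it.

# Usage sketch; `model`, `images`, `labels` are assumed to exist and `images`
# to lie in [0, 1]. `args` only needs a `no_amp` flag here (plus `amp`/`opt`
# when mixed precision is enabled).
import torch
from types import SimpleNamespace

delta = attack_pgd(
    model, images, labels,
    epsilon=8 / 255, alpha=2 / 255,
    attack_iters=10, restarts=1,
    norm="l_inf",
    init_mode="pgd",
    args=SimpleNamespace(no_amp=True),
)
adv_images = torch.clamp(images + delta, 0, 1)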
Example #6
def main(_):
    mnist = input_data.read_data_sets(FLAGS.datadir, one_hot=False)
    ims = np.reshape(mnist.train.images, [-1, 28, 28, 1]).astype(np.float32)
    labels = np.reshape(mnist.train.labels, [-1]).astype(np.int64)
    ims, labels = shuffle(ims, labels)
    # TODO. this makes it harder to compare. unless we do multiple runs

    test_ims = np.reshape(mnist.test.images,
                          [-1, 28, 28, 1]).astype(np.float32)
    test_labels = np.reshape(mnist.test.labels, [-1]).astype(np.int64)

    x = tf.placeholder(shape=[FLAGS.batchsize, 28, 28, 1], dtype=tf.float32)
    tf.add_to_collection('inputs', x)
    T = tf.placeholder(shape=[FLAGS.batchsize], dtype=tf.int64)
    tf.add_to_collection('targets', T)

    # set up
    global_step = tf.Variable(0, name='global_step', trainable=False)
    global_step = global_step.assign_add(1)
    main_opt = tf.train.AdamOptimizer(FLAGS.lr)
    e2e_opt = tf.train.AdamOptimizer(FLAGS.valid_lr)
    classifier_opt = tf.train.AdamOptimizer(FLAGS.valid_lr)

    # build the model
    with tf.variable_scope('representation') as scope:
        hidden = encoder(x)  # TODO hidden should be [batch, N] embeddings
        unsupervised_loss = tf.add_n(
            [get_loss_fn(name, hidden) for name in FLAGS.loss_fn.split('-')])

    with tf.variable_scope('classifier') as scope:
        logits = classifier(tf.reduce_mean(hidden, [1, 2]))
        discrim_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                           labels=T))

    with tf.variable_scope('optimisers') as scope:
        pretrain_step = main_opt.minimize(unsupervised_loss,
                                          var_list=tf.get_collection(
                                              tf.GraphKeys.TRAINABLE_VARIABLES,
                                              scope='representation'))
        train_step = classifier_opt.minimize(
            discrim_loss,
            var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                       scope='classifier'))
        e2e_step = e2e_opt.minimize(discrim_loss)

    with tf.name_scope('metrics'):
        preds = tf.argmax(tf.nn.softmax(logits), axis=-1)
        acc = tf.contrib.metrics.streaming_accuracy(
            preds,
            T,
            metrics_collections='METRICS',
            updates_collections='METRIC_UPDATES')

    # summaries
    pretrain_summary = tf.summary.scalar('unsupervised', unsupervised_loss)
    discrim_summary = tf.summary.scalar('supervised', discrim_loss)
    loss_summaries = tf.summary.merge([pretrain_summary, discrim_summary])

    ################################################################################
    """
    Given that we are doing unsupervised pretraining for a discrimination task,
    it makes sense to validate our model on discrimination.
    """
    def freeze_pretrain(sess, writer, step):
        """
        Question: how useful is the pretrained representation for
        discrimination?
        Measure: validation accuracy.
        """
        # TODO. instead pick a subset of classes. or binary 1 class vs rest
        # try 10 labels?!

        ### Vanilla
        train_labels = labels
        valid_labels = test_labels
        train_ims = ims
        valid_ims = test_ims

        ### Train new classifier
        # TODO. what if we also want to validate on other tasks? such as;
        # ability to reconstruct data, MI with data, ???,
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                      scope='classifier')
        sess.run(tf.variables_initializer(variables))

        # a different subset every time we validate?!?
        # idx = np.random.randint(0, len(train_labels), FLAGS.N_labels)
        idx = range(FLAGS.N_labels)

        for e in range(FLAGS.valid_epochs):
            for i, batch_ims, batch_labels in batch(train_ims[idx],
                                                    train_labels[idx],
                                                    FLAGS.batchsize):
                L, _ = sess.run([discrim_loss, train_step], {
                    x: batch_ims,
                    T: batch_labels
                })
            print('\rvalid: train step: {} loss: {:.5f}'.format(e, L), end='')
            add_summary(writer, e + FLAGS.valid_epochs * step // 100,
                        'valid-train/freeze', L)
        validate(sess,
                 writer,
                 step,
                 x,
                 T,
                 valid_ims,
                 valid_labels,
                 FLAGS.batchsize,
                 name='freeze')

    def pretrained_endtoend(sess, writer, saver, step):
        """
        Question(s):
            - how close is the pretrained representation to the
            final learned representation (after training on labels).
            - how good is the pretrained init?
        Measure: iterations to ?, max accuracy?, l2 distance between init and
        final weights?
        """
        ### Vanilla
        train_labels = labels
        valid_labels = test_labels
        train_ims = ims
        valid_ims = test_ims

        # save the variables before we fine tune them on labels
        saver.save(sess, FLAGS.logdir + '/', step)

        idx = range(FLAGS.N_labels)

        # TODO. need to make sure I am not overfitting here!
        # could use early stopping? but need separate valid-valid data
        for e in range(FLAGS.valid_epochs):
            for i, batch_ims, batch_labels in batch(train_ims[idx],
                                                    train_labels[idx],
                                                    FLAGS.batchsize):
                L, _ = sess.run([discrim_loss, e2e_step], {
                    x: batch_ims,
                    T: batch_labels
                })
            print('\rvalid: train step: {} loss: {:.5f}'.format(e, L), end='')
            add_summary(writer, e + FLAGS.valid_epochs * step // 100,
                        'valid-train/e2e', L)
        validate(sess,
                 writer,
                 step,
                 x,
                 T,
                 valid_ims,
                 valid_labels,
                 FLAGS.batchsize,
                 name='e2e')
        # restore original variables to continue pretraining
        ckpt = tf.train.latest_checkpoint(FLAGS.logdir)
        saver.restore(sess, ckpt)

    def semisupervised(sess, writer, step, batch_ims, batch_labels):
        """
        Question:
            - how can extra unlabelled data be used to help a small set of
            labels generalise?
            - how does adding labels into unsupervised training affect the
            representation learnt?
        """
        L, _ = sess.run([discrim_loss, e2e_step], {
            x: batch_ims,
            T: batch_labels
        })

        validate(sess, writer, step, x, T, test_ims, test_labels,
                 FLAGS.batchsize)

    def embed(sess, step):
        """
        Let's have a look at the hidden representations learned by our different
        methods. We need to run our model on a subset of data, collect the
        hidden representations and then save into a fake checkpoint.
        """
        # get embeddings from tensorflow
        X = []
        H = []
        L = []
        for i, batch_ims, batch_labels in batch(ims, labels, FLAGS.batchsize):
            if i >= 10000 // FLAGS.batchsize: break
            # print('\r embed step {}'.format(i), end='', flush=True)
            X.append(batch_ims)
            H.append(sess.run(hidden, feed_dict={x: batch_ims}))
            L.append(batch_labels.reshape(FLAGS.batchsize))

        save_embeddings(os.path.join(FLAGS.logdir, 'embedding' + str(step)),
                        np.vstack(H),
                        np.vstack(L).reshape(10000),
                        images=np.vstack(X))

    with tf.Session() as sess:
        run_options, run_metadata = profile()
        writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph)
        saver = tf.train.Saver(
            tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                              scope='representation'))
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        for e in range(FLAGS.epochs):
            for _, batch_ims, batch_labels in batch(ims, labels,
                                                    FLAGS.batchsize):
                step, L, _ = sess.run(
                    [global_step, unsupervised_loss, pretrain_step], {
                        x: batch_ims,
                        T: batch_labels
                    },
                    options=run_options,
                    run_metadata=run_metadata)
                print('\rtrain step: {} loss: {:.5f}'.format(step, L), end='')

                # semi-supervised learning
                # TODO. want a better way to sample labels
                idx = np.random.randint(0, FLAGS.N_labels, FLAGS.batchsize)
                _ = sess.run(e2e_step, {x: ims[idx], T: labels[idx]})
                # TODO. how does running the update together affect things?
                # TODO. what about elastic weight consolidation? treating
                # semi supervised learning as a type of transfer!?

                if step % 20 == 0:
                    summ = sess.run(loss_summaries, {
                        x: batch_ims,
                        T: batch_labels
                    })
                    writer.add_summary(summ, step)

                if step % 100 == 0:
                    # pretrained_endtoend(sess, writer, saver, step)
                    # freeze_pretrain(sess, writer, step)
                    validate(sess,
                             writer,
                             step,
                             x,
                             T,
                             test_ims,
                             test_labels,
                             FLAGS.batchsize,
                             name='super')

                if step == 30:
                    trace(run_metadata, FLAGS.logdir)

                if step % 500 == 0:
                    var = tf.get_collection('random_vars')
                    sess.run(tf.variables_initializer(var))
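
The training and validation loops above iterate over batch(ims, labels, FLAGS.batchsize), which is not defined in this example; the call sites imply it yields (step, image_batch, label_batch) triples. A minimal sketch under that assumption, without shuffling:

# Hypothetical sketch of the `batch` helper used above; the real one may
# shuffle or handle remainders differently.
def batch(ims, labels, batchsize):
    for i in range(len(ims) // batchsize):
        s = slice(i * batchsize, (i + 1) * batchsize)
        yield i, ims[s], labels[s]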
Example #7
def plugin_estimator_training_loop(generator, discriminator, dataloader,
                                   learning_rate, latent_dim, loss_function,
                                   optim, device, total_iters,
                                   checkpoint_intervals, batchsize, algorithm,
                                   n_discard, save_dir, save_suffix):
    """
    Function to train a generator and discriminator for a GAN using
    a plugin mean estimation algorithm for total_iters with learning_rate
    """
    optimizer_cons = utils.get_optimizer_cons(optim, learning_rate)

    disc_optimizer = optimizer_cons(discriminator.parameters())
    gen_optimizer = optimizer_cons(generator.parameters())

    discriminator_loss, generator_loss = losses.get_loss_fn(loss_function)
    flag = False
    iteration = 0
    while not flag:
        for real_image_batch, _ in dataloader:
            # Update iteration counter
            iteration += 1

            # Discriminator: standard training
            fake_image_batch = generator(
                torch.randn(real_image_batch.shape[0],
                            latent_dim,
                            device=device))
            real_pred = discriminator(real_image_batch.to(device))
            fake_pred = discriminator(fake_image_batch.detach())
            real_loss, fake_loss = discriminator_loss(real_pred, fake_pred)
            disc_loss = torch.mean(real_loss + fake_loss)
            disc_optimizer.zero_grad()
            disc_loss.backward()
            disc_optimizer.step()

            if algorithm.__name__ == 'mean':
                fake_images = generator(
                    torch.randn(batchsize, latent_dim, device=device))
                fake_preds = discriminator(fake_images).squeeze()
                gen_loss = generator_loss(fake_preds).mean()
                gen_optimizer.zero_grad()
                gen_loss.backward()
            else:
                # Generator: proper mean estimation
                # First sample gradients
                sgradients = utils.gradient_sampler(discriminator,
                                                    generator,
                                                    generator_loss,
                                                    noise_batchsize=batchsize,
                                                    latent_dim=latent_dim,
                                                    device=device)
                # Then get the estimate with the mean estimation algorithm
                stoc_grad = algorithm(sgradients,
                                      n_discard=n_discard(iteration))
                # Perform the update of .grad attributes
                with torch.no_grad():
                    utils.update_grad_attributes(
                        generator.parameters(),
                        torch.as_tensor(stoc_grad, device=device))
            # Perform the update
            gen_optimizer.step()

            if iteration in checkpoint_intervals:
                print(f"Completed {iteration}")
                torch.save(
                    generator.state_dict(),
                    f"{save_dir}/generator_{algorithm.__name__}_{iteration}_{save_suffix}.pt"
                )
                torch.save(
                    discriminator.state_dict(),
                    f"{save_dir}/discriminator_{algorithm.__name__}_{iteration}_{save_suffix}.pt"
                )

            if iteration == total_iters:
                flag = True
                break

    return generator, discriminator
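
A hedged call sketch for plugin_estimator_training_loop: the generator, discriminator and dataloader are assumed to exist, "wgan" and "adam" stand in for whatever names losses.get_loss_fn and utils.get_optimizer_cons actually accept, and trimmed_mean is a toy estimator that assumes utils.gradient_sampler returns a 2-D array of per-sample flattened gradients.

# Call sketch with hypothetical objects and names; only the argument layout
# mirrors the function above.
import numpy as np

def trimmed_mean(gradient_samples, n_discard=0):
    # toy estimator: drop the n_discard largest-norm samples, average the rest
    # (assumes gradient_samples has shape [num_samples, num_params])
    norms = np.linalg.norm(gradient_samples, axis=1)
    keep = np.argsort(norms)[: len(gradient_samples) - n_discard]
    return gradient_samples[keep].mean(axis=0)

generator, discriminator = plugin_estimator_training_loop(
    generator, discriminator, dataloader,
    learning_rate=2e-4, latent_dim=100, loss_function="wgan",
    optim="adam", device="cuda", total_iters=50_000,
    checkpoint_intervals={10_000, 25_000, 50_000}, batchsize=64,
    algorithm=trimmed_mean,
    n_discard=lambda it: 2,  # discard schedule as a function of the iteration
    save_dir="./checkpoints", save_suffix="run0",
)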
Example #8
def streaming_approx_training_loop(generator, discriminator, dataloader,
                                   learning_rate, latent_dim, loss_function,
                                   optim, device, total_iters,
                                   checkpoint_intervals, alpha, batchsize,
                                   n_discard, save_dir, save_suffix):
    """
    Function to train a generator and discriminator for a GAN using
    the streaming rank-1 approximation with algorithm for total_iters
    with optimizer optim
    """
    optimizer_cons = utils.get_optimizer_cons(optim, learning_rate)

    disc_optimizer = optimizer_cons(discriminator.parameters())
    gen_optimizer = optimizer_cons(generator.parameters())

    discriminator_loss, generator_loss = losses.get_loss_fn(loss_function)
    flag = False
    iteration = 0
    top_eigvec, top_eigval, running_mean = None, None, None

    while not flag:
        for real_image_batch, _ in dataloader:
            # Update iteration counter
            iteration += 1

            # Discriminator: standard training
            fake_image_batch = generator(
                torch.randn(real_image_batch.shape[0],
                            latent_dim,
                            device=device))
            real_pred = discriminator(real_image_batch.to(device))
            fake_pred = discriminator(fake_image_batch.detach())
            real_loss, fake_loss = discriminator_loss(real_pred, fake_pred)
            disc_loss = torch.mean(real_loss + fake_loss)
            disc_optimizer.zero_grad()
            disc_loss.backward()
            disc_optimizer.step()

            # Generator: proper mean estimation
            # First sample gradients
            sgradients = utils.gradient_sampler(discriminator,
                                                generator,
                                                generator_loss,
                                                noise_batchsize=batchsize,
                                                latent_dim=latent_dim,
                                                device=device)
            # Then get the estimate with the previously computed direction
            stoc_grad, top_eigvec, top_eigval, running_mean = streaming_update_algorithm(
                sgradients,
                n_discard=n_discard(iteration),
                top_v=top_eigvec,
                top_lambda=top_eigval,
                old_mean=running_mean,
                alpha=alpha)
            # Perform the update of .grad attributes
            with torch.no_grad():
                utils.update_grad_attributes(
                    generator.parameters(),
                    torch.as_tensor(stoc_grad, device=device))
            # Perform the update
            gen_optimizer.step()

            if iteration in checkpoint_intervals:
                print(f"Completed {iteration}")
                torch.save(
                    generator.state_dict(),
                    f"{save_dir}/generator_streaming_approx_{iteration}_{save_suffix}.pt"
                )
                torch.save(
                    discriminator.state_dict(),
                    f"{save_dir}/discriminator_streaming_approx_{iteration}_{save_suffix}.pt"
                )

            if iteration == total_iters:
                flag = True
                break

    return generator, discriminator