Example 1
def validate(params,
             save=False,
             adversarial=False,
             adversarial_attack=None,
             whitebox=True,
             adversary_model=None,
             adversary_criterion=None):
    '''
    Performs validation of params['model'] using 
    params['val_dataloader']. 
    
    Keyword arguments:
    > params (dict) -- current state parameters
    > save (bool) -- whether to save val accuracies.
        Should only be `True` when called from train loop!
    > adversarial (bool) -- whether to test adversarially.
    > adversarial_attack (string) -- name of adversarial
        attack.
    > whitebox (bool) -- whether to use a whitebox attack.
    > adversary_model (torch.nn.Module) -- pre-trained
        model to generate black-box attacks with.
    > adversary_criterion (torch.nn.[loss]) -- loss func
        for the black-box "imitator" model.
        
    Returns: average top-1 validation accuracy (classifier
        models only).
    '''
    if params['model'] is None:
        print(
            'No model loaded! Type -n to create a new model, or -l to load an existing one from file.\n'
        )
        return

    # Sets up validation statistics to be logged to console (TODO: optionally pipe output to a file).
    extension = 'adversarial' if adversarial else 'non-adversarial'
    print(color.PURPLE + '\n--- BEGIN (' + extension +
          ') VALIDATION PASS ---' + color.END)
    batch_time = train_utils.AverageMeter('Time', ':5.3f')
    losses = train_utils.AverageMeter('Loss', ':.4e')
    if params['is_generator']:
        progress = train_utils.ProgressMeter(len(params['val_dataloader']),
                                             batch_time,
                                             losses,
                                             prefix='Test: ')
    else:
        top1 = train_utils.AverageMeter('Acc@1', ':5.2f')
        progress = train_utils.ProgressMeter(len(params['val_dataloader']),
                                             batch_time,
                                             losses,
                                             top1,
                                             prefix='Test: ')

    # Switch model to evaluate mode; push to GPU
    params['model'].eval()
    setup_cuda(params)

    end = time.time()
    for i, (data, target) in enumerate(params['val_dataloader']):

        # Pushes data to GPU
        data = data.to(params['device'])
        target = target.to(params['device'])

        # Generate adversarial attack (default whitebox mode)
        if adversarial:

            if whitebox:
                data = adversary.attack_batch(data,
                                              target,
                                              params['model'],
                                              params['criterion'],
                                              attack_name=adversarial_attack,
                                              device=params['device'],
                                              epsilon=0.3,
                                              alpha=0.05)
            else:
                data = adversary.attack_batch(data,
                                              target,
                                              adversary_model,
                                              adversary_criterion,
                                              attack_name=adversarial_attack,
                                              device=params['device'],
                                              epsilon=0.3,
                                              alpha=0.05)

        with torch.no_grad():

            # compute output
            output = params['model'](data)
            loss = params['criterion'](output, target)
            losses.update(loss.item(), data.size(0))

            # measure accuracy and record loss
            if not params['is_generator']:
                acc1 = accuracy(output, target)[0]
                top1.update(acc1[0], data.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % params['print_frequency'] == 0:
                # Storing validation losses/accuracies
                if save:
                    params['val_losses'].append(losses.get_avg())
                    if not params['is_generator']:
                        params['val_accuracies'].append(top1.get_avg())
                progress.print(i)

    # Print final accuracy/loss
    progress.print(len(params['val_dataloader']))
    if not params['is_generator']:
        print(color.GREEN + ' * Acc@1 {top1.avg:.3f}'.format(top1=top1) +
              color.END)

    # Update train/val accuracy/loss plots
    if save:
        viz_utils.plot_accuracies(params)
        viz_utils.plot_losses(params)

    print(color.PURPLE + '--- END VALIDATION PASS ---\n' + color.END)

    if not params['is_generator']:
        # Return the epoch-averaged top-1 accuracy rather than only the
        # final batch's accuracy, matching the summary printed above.
        return top1.get_avg()
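
The accuracy helper called in Examples 1 and 3 is not defined in these snippets. Its call pattern (accuracy(output, target)[0] yielding a one-element tensor) matches the standard top-k helper from the PyTorch ImageNet example; a minimal sketch under that assumption:

import torch

def accuracy(output, target, topk=(1,)):
    '''Computes top-k accuracy (in percent) over a batch of logits.'''
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # Indices of the maxk largest logits per sample -> [maxk, batch].
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
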
Example 2
def run_epoch(split,
              model,
              opt,
              train_data,
              val_data=None,
              batch_size=100,
              upto=None,
              epoch_num=None,
              epochs=1,
              verbose=False,
              log_every=10,
              return_losses=False,
              table_bits=None,
              warmups=1000,
              loader=None,
              constant_lr=None,
              use_meters=True,
              summary_writer=None,
              lr_scheduler=None,
              custom_lr_lambda=None,
              label_smoothing=0.0):
    torch.set_grad_enabled(split == 'train')
    model.train() if split == 'train' else model.eval()
    dataset = train_data if split == 'train' else val_data
    losses = []

    if loader is None:
        loader = data.DataLoader(dataset,
                                 batch_size=batch_size,
                                 shuffle=(split == 'train'))

    # How many orderings to run for the same batch?
    nsamples = 1
    if hasattr(model, 'orderings'):
        nsamples = len(model.orderings)
        if verbose:
            print('setting nsamples to', nsamples)

    dur_meter = train_utils.AverageMeter('dur',
                                         lambda v: '{:.0f}s'.format(v),
                                         display_average=False)
    lr_meter = train_utils.AverageMeter('lr', ':.5f', display_average=False)
    tups_meter = train_utils.AverageMeter('tups',
                                          utils.HumanFormat,
                                          display_average=False)
    loss_meter = train_utils.AverageMeter('loss (bits/tup)', ':.2f')
    train_throughput = train_utils.AverageMeter('tups/s',
                                                utils.HumanFormat,
                                                display_average=False)
    batch_time = train_utils.AverageMeter('sgd_ms', ':3.1f')
    data_time = train_utils.AverageMeter('data_ms', ':3.1f')
    progress = train_utils.ProgressMeter(upto, [
        batch_time,
        data_time,
        dur_meter,
        lr_meter,
        tups_meter,
        train_throughput,
        loss_meter,
    ])

    begin_time = t1 = time.time()

    for step, xb in enumerate(loader):
        data_time.update((time.time() - t1) * 1e3)

        if split == 'train':
            if isinstance(dataset, data.IterableDataset):
                # Can't call len(loader).
                global_steps = upto * epoch_num + step + 1
            else:
                global_steps = len(loader) * epoch_num + step + 1

            if constant_lr:
                lr = constant_lr
                for param_group in opt.param_groups:
                    param_group['lr'] = lr
            elif custom_lr_lambda:
                lr_scheduler = None
                lr = custom_lr_lambda(global_steps)
                for param_group in opt.param_groups:
                    param_group['lr'] = lr
            elif lr_scheduler is None:
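                # Inverse-square-root warmup ("Noam") schedule from
                # "Attention Is All You Need": lr scales as
                # embed_size**-0.5 * min(step**-0.5, step * t**-1.5),
                # rising linearly for t warmup steps, then decaying.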
                t = warmups
                if warmups < 1:  # A ratio.
                    t = int(warmups * upto * epochs)

                d_model = model.embed_size
                lr = (d_model**-0.5) * min(
                    (global_steps**-.5), global_steps * (t**-1.5))
                for param_group in opt.param_groups:
                    param_group['lr'] = lr
            else:
                # We'll call lr_scheduler.step() below.
                lr = opt.param_groups[0]['lr']

        if upto and step >= upto:
            break

        if isinstance(xb, list):
            # This happens if using data.TensorDataset.
            assert len(xb) == 1, xb
            xb = xb[0]

        xb = xb.float().to(train_utils.get_device(), non_blocking=True)

        # Forward pass, potentially through several orderings.
        xbhat = None
        model_logits = []
        num_orders_to_forward = 1
        if split == 'test' and nsamples > 1:
            # At test, we want to test the 'true' nll under all orderings.
            num_orders_to_forward = nsamples

        for i in range(num_orders_to_forward):
            if hasattr(model, 'update_masks'):
                # We want to update_masks even for first ever batch.
                model.update_masks()

            model_out = model(xb)
            model_logits.append(model_out)
            if xbhat is None:
                xbhat = torch.zeros_like(model_out)
            xbhat += model_out

        if num_orders_to_forward == 1:
            loss = model.nll(xbhat, xb, label_smoothing=label_smoothing).mean()
        else:
            # Average across orderings & then across minibatch.
            #
            #   p(x) = 1/N sum_i p_i(x)
            #   log(p(x)) = log(1/N) + log(sum_i p_i(x))
            #             = log(1/N) + logsumexp ( log p_i(x) )
            #             = log(1/N) + logsumexp ( - nll_i (x) )
            #
            # Used only at test time.
            logps = []  # [batch size, num orders]
            assert len(model_logits) == num_orders_to_forward, len(
                model_logits)
            for logits in model_logits:
                # Note the minus.
                logps.append(
                    -model.nll(logits, xb, label_smoothing=label_smoothing))
            logps = torch.stack(logps, dim=1)
            logps = logps.logsumexp(dim=1) + torch.log(
                torch.tensor(1.0 / nsamples, device=logps.device))
            loss = (-logps).mean()

        losses.append(loss.detach().item())

        if split == 'train':
            opt.zero_grad()
            loss.backward()
            l2_grad_norm = TotalGradNorm(model.parameters())

            opt.step()
            if lr_scheduler is not None:
                lr_scheduler.step()

            loss_bits = loss.item() / np.log(2)

            # Number of tuples processed in this epoch so far.
            ntuples = (step + 1) * batch_size
            # dur is also needed by the summary_writer logging below, so
            # compute it unconditionally.
            dur = time.time() - begin_time
            if use_meters:
                lr_meter.update(lr)
                tups_meter.update(ntuples)
                loss_meter.update(loss_bits)
                dur_meter.update(dur)
                train_throughput.update(ntuples / dur)

            if summary_writer is not None:
                wandb.log({
                    'train/lr': lr,
                    'train/tups': ntuples,
                    'train/tups_per_sec': ntuples / dur,
                    'train/nll': loss_bits,
                    'train/global_step': global_steps,
                    'train/l2_grad_norm': l2_grad_norm,
                })
                summary_writer.add_scalar('train/lr',
                                          lr,
                                          global_step=global_steps)
                summary_writer.add_scalar('train/tups',
                                          ntuples,
                                          global_step=global_steps)
                summary_writer.add_scalar('train/tups_per_sec',
                                          ntuples / dur,
                                          global_step=global_steps)
                summary_writer.add_scalar('train/nll',
                                          loss_bits,
                                          global_step=global_steps)

            if step % log_every == 0:
                if table_bits:
                    print(
                        'Epoch {} Iter {}, {} entropy gap {:.4f} bits (loss {:.3f}, data {:.3f}) {:.5f} lr, {} tuples seen ({} tup/s)'
                        .format(
                            epoch_num, step, split,
                            loss.item() / np.log(2) - table_bits,
                            loss.item() / np.log(2), table_bits, lr,
                            utils.HumanFormat(ntuples),
                            utils.HumanFormat(ntuples /
                                              (time.time() - begin_time))))
                elif not use_meters:
                    print(
                        'Epoch {} Iter {}, {} loss {:.3f} bits/tuple, {:.5f} lr'
                        .format(epoch_num, step, split,
                                loss.item() / np.log(2), lr))

        if verbose:
            print('%s running average loss: %f' % (split, np.mean(losses)))

        batch_time.update((time.time() - t1) * 1e3)
        t1 = time.time()
        if split == 'train' and step % log_every == 0 and use_meters:
            progress.display(step)

    if return_losses:
        return losses
    return np.mean(losses)
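
TotalGradNorm is called above but not defined in this snippet. A minimal sketch, assuming it returns the total L2 norm of the gradients across all parameters (the value logged as train/l2_grad_norm):

def TotalGradNorm(parameters, norm_type=2.0):
    '''Total gradient norm over all parameters, returned as a float.'''
    total = 0.0
    for p in parameters:
        if p.grad is not None:
            total += p.grad.data.norm(norm_type).item() ** norm_type
    return total ** (1.0 / norm_type)
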
Example 3
def train_one_epoch(epoch, params, classifier_state=None):
    '''
    Trains the model given in params['model'] for a single epoch.

    Keyword arguments:
    > epoch (int) -- current training epoch
    > params (dict) -- current state parameters
    > classifier_state (dict) -- state of a pre-trained classifier,
        used to generate adversarial examples when training a
        generator adversarially

    Returns: average top-1 training accuracy (classifier
        models only).
    '''
    # Tracks per-epoch training statistics (TODO: pipe to file?)
    batch_time = train_utils.AverageMeter('Time', ':.3f')
    data_time = train_utils.AverageMeter('Data', ':.3f')
    losses = train_utils.AverageMeter('Loss', ':.3e')
    if params['is_generator']:
        progress = train_utils.ProgressMeter(
            len(params['train_dataloader']),
            batch_time,
            losses,
            prefix='Epoch: [{}]'.format(epoch))
    else:
        top1 = train_utils.AverageMeter('Acc@1', ':4.2f')
        if params['adversarial_train']:
            top1_adv = train_utils.AverageMeter('Adv@1', ':4.2f')
            progress = train_utils.ProgressMeter(
                len(params['train_dataloader']),
                batch_time,
                losses,
                top1,
                top1_adv,
                prefix='Epoch: [{}]'.format(epoch))
        else:
            progress = train_utils.ProgressMeter(
                len(params['train_dataloader']),
                batch_time,
                losses,
                top1,
                prefix='Epoch: [{}]'.format(epoch))

    # Switch to train mode. Important for dropout and batchnorm.
    params['model'].train()

    end = time.time()
    for i, (data, target) in enumerate(params['train_dataloader']):
        # Measure data loading time
        data_time.update(time.time() - end)

        # Sends input/label tensors to GPU
        data = data.to(params['device'])
        target = target.to(params['device'])

        # Generate and separately perform forward pass on adversarial examples
        if params['adversarial_train']:
            if not params['is_generator']:
                params['model'].eval()
                perturbed_data = adversary.attack_batch(
                    data,
                    target,
                    params['model'],
                    params['criterion'],
                    attack_name='FGSM',
                    device=params['device'])
                perturbed_target = target.clone()
                params['model'].train()
                perturbed_output = params['model'](perturbed_data)
            else:
                # Setup
                perturbed_loss = 0
                classifier_state['model'].eval()

                for epsilon in constants.ADV_VAE_EPSILONS:
                    for attack_name in constants.ADV_VAE_ATTACKS:
                        # Get perturbed batch
                        perturbed_data = adversary.attack_batch(
                            data,
                            target,
                            classifier_state['model'],
                            classifier_state['criterion'],
                            attack_name=attack_name,
                            device=classifier_state['device'],
                            epsilon=epsilon)
                        clean_data = data.clone()
                        perturbed_data, recon, mu, logvar = params['model'](
                            perturbed_data)
                        perturbed_output = clean_data, recon, mu, logvar
                        perturbed_loss = perturbed_loss + params['criterion'](
                            perturbed_output, target)

                # CW batch
                # if i % constants.CW_SPLITS == epoch % constants.CW_SPLITS:
                perturbed_data = adversary.attack_batch(
                    data,
                    target,
                    classifier_state['model'],
                    classifier_state['criterion'],
                    attack_name='CW',
                    device=classifier_state['device'])
                clean_data = data.clone()
                perturbed_data, recon, mu, logvar = params['model'](
                    perturbed_data)
                perturbed_output = clean_data, recon, mu, logvar
                perturbed_loss = perturbed_loss + (
                    params['criterion'](perturbed_output, target) *
                    len(constants.ADV_VAE_EPSILONS))

                # The CW term above was scaled by len(ADV_VAE_EPSILONS) so
                # that it counts as one extra attack run at every epsilon;
                # dividing by len(EPSILONS) * (len(ATTACKS) + 1) therefore
                # averages uniformly over all attack/epsilon terms.
                perturbed_loss = perturbed_loss / (
                    len(constants.ADV_VAE_EPSILONS) *
                    (len(constants.ADV_VAE_ATTACKS) + 1))

        # Compute output
        output = params['model'](data)

        # Adversarial train uses slightly different criterion
        if params['adversarial_train']:
            if params['is_generator']:
                loss = params['alpha'] * params['criterion'](output, target) + \
                    (1 - params['alpha']) * perturbed_loss
            else:
                loss = params['alpha'] * params['criterion'](output, target) + \
                    (1 - params['alpha']) * params['criterion'](perturbed_output, perturbed_target)
        else:
            loss = params['criterion'](output, target)

        # Measure accuracy and record loss
        losses.update(loss.item(), data.size(0))
        if not params['is_generator']:
            acc1 = accuracy(output, target)[0]
            top1.update(acc1[0], data.size(0))

        if params['adversarial_train'] and not params['is_generator']:
            adv_acc1 = accuracy(perturbed_output, perturbed_target)[0]
            top1_adv.update(adv_acc1[0], perturbed_data.size(0))

        # Compute gradient and do SGD step
        params['optimizer'].zero_grad()
        loss.backward()
        params['optimizer'].step()

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Prints/stores training acc/losses
        if i % params['print_frequency'] == 0:
            params['train_losses'].append(losses.get_avg())
            if not params['is_generator']:
                params['train_accuracies'].append(top1.get_avg())
            progress.print(i)

    # Prints final training accuracy per epoch
    progress.print(len(params['train_dataloader']))

    # Update train/val accuracy/loss plots
    viz_utils.plot_accuracies(params)
    viz_utils.plot_losses(params)

    if not params['is_generator']:
        # Return the epoch-averaged top-1 accuracy rather than only the
        # final batch's accuracy.
        return top1.get_avg()
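
adversary.attack_batch is used throughout Examples 1 and 3 but never shown. A minimal FGSM-only sketch consistent with how it is called above; the real helper presumably dispatches on attack_name to the other attacks, and alpha (unused by FGSM) would be the step size for iterative attacks:

import torch

def attack_batch(data, target, model, criterion, attack_name='FGSM',
                 device='cpu', epsilon=0.3, alpha=0.05):
    '''Generates an adversarially perturbed batch (FGSM only in this sketch).'''
    assert attack_name == 'FGSM', 'Only FGSM is sketched here.'
    data = data.clone().detach().to(device).requires_grad_(True)
    target = target.to(device)
    # Gradient of the loss with respect to the inputs.
    loss = criterion(model(data), target)
    model.zero_grad()
    loss.backward()
    # Perturb each pixel by epsilon in the direction of the gradient sign.
    perturbed = data + epsilon * data.grad.sign()
    # Clamp to a valid pixel range (assumes inputs normalized to [0, 1]).
    return perturbed.clamp(0, 1).detach()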