def run_epoch_fast(loader,
                   model,
                   criterion,
                   optimizer,
                   device_type='cuda',
                   epoch=0,
                   n_epochs=0,
                   train=True,
                   log_every_step=True):
    time_meter = Meter(name='Time', cum=True)
    loss_meter = Meter(name='Loss', cum=False)
    error_meter = Meter(name='Error', cum=False)

    if train:
        model.train()
        print('Training')
    else:
        model.eval()
        print('Evaluating')

    end = time.time()
    for i, (input, target) in enumerate(loader):
        if train:
            model.zero_grad()
            optimizer.zero_grad()

            # Forward pass
            input = move_to_device(input, device_type, False)
            target = move_to_device(target, device_type, False)
            output = model(input)
            loss = criterion(output, target)

            # Backward pass (skip the step when the loss is exactly zero)
            if loss.item() > 0:
                loss.backward()
                optimizer.step()
            # Track the number of training iterations on the optimizer
            optimizer.n_iters = getattr(optimizer, 'n_iters', 0) + 1

        else:
            with torch.no_grad():
                # Forward pass
                input = move_to_device(input, device_type, False)
                target = move_to_device(target, device_type, False)
                output = model(input)
                loss = criterion(output, target)

        # Accounting
        _, predictions = torch.topk(output, 1)
        error = 1 - torch.eq(torch.squeeze(predictions), target).float().mean()
        batch_time = time.time() - end
        end = time.time()

        # Log errors
        time_meter.update(batch_time)
        loss_meter.update(loss.item())
        error_meter.update(error.item())
        if log_every_step:
            # Report the learning rate of the last parameter group
            for param_group in optimizer.param_groups:
                lr_value = param_group['lr']
            print('  '.join([
                '%s: (Epoch %d of %d) [%04d/%04d]' %
                ('Train' if train else 'Eval', epoch, n_epochs, i + 1,
                 len(loader)),
                str(time_meter),
                str(loss_meter),
                str(error_meter),
                '%.4f' % lr_value
            ]))

    if not log_every_step:
        print('  '.join([
            '%s: (Epoch %d of %d)' %
            ('Train' if train else 'Eval', epoch, n_epochs),
            str(time_meter),
            str(loss_meter),
            str(error_meter),
        ]))

    return time_meter.value(), loss_meter.value(), error_meter.value()
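

# NOTE: `Meter` is defined elsewhere in this repo. The sketch below shows the
# interface the epoch runners above assume -- update() accepts a scalar,
# value() returns the aggregate (a sum when cum=True, a running mean
# otherwise), and str() yields the text printed in the per-step logs. It is
# an illustrative reconstruction under those assumptions, not the project's
# actual class.
class _MeterSketch:
    def __init__(self, name, cum=False):
        self.name = name
        self.cum = cum          # cumulative sum vs. running mean
        self.total = 0.0
        self.count = 0
        self.last = 0.0

    def update(self, val):
        self.last = float(val)
        self.total += self.last
        self.count += 1

    def value(self):
        return self.total if self.cum else self.total / max(self.count, 1)

    def __str__(self):
        return '%s %.4f (%.4f)' % (self.name, self.last, self.value())

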
def run_epoch_perturb(loader,
                      model,
                      perturbation,
                      onehot_criterion,
                      soft_criterion,
                      optimizer,
                      device_type='cuda',
                      epoch=0,
                      n_epochs=0,
                      train=True,
                      class_num=100,
                      perturb_degree=5,
                      sample_num=20,
                      alpha=0.5,
                      pt_input_ratio=0.25,
                      log_every_step=True):
    time_meter = Meter(name='Time', cum=True)
    loss_meter = Meter(name='Loss', cum=False)
    error_meter = Meter(name='Error', cum=False)

    if train:
        model.train()
        print('Training')

    else:
        model.eval()
        print('Evaluating')

    # Reusable noise buffer; the shape assumes 3x32x32 (CIFAR-sized) inputs
    pt_random = torch.randn([sample_num, 3, 32, 32])
    end = time.time()
    for i, (input, target) in enumerate(loader):
        if train:
            model.zero_grad()
            optimizer.zero_grad()

            # Forward pass: draw `sample_num` random directions and normalize
            # each to unit L2 norm (note that `perturb_degree` is not applied
            # on this vectorized path)
            input_size = input.size()
            pt_batch = int(input_size[0] * pt_input_ratio)
            torch.nn.init.normal_(pt_random)
            pt_random.requires_grad = False
            pt_flat = torch.flatten(pt_random, start_dim=1)
            p_n = torch.norm(
                pt_flat, p=2, dim=1,
                keepdim=True).unsqueeze(1).unsqueeze(1).expand_as(pt_random)
            pt = pt_random.div_(p_n)
            # Broadcast each direction over the first `pt_batch` inputs and
            # flatten into one (pt_batch * sample_num) mini-batch
            pt_ = pt.unsqueeze(0).expand(pt_batch, sample_num, input_size[1],
                                         input_size[2], input_size[3])
            pt_input = input[:pt_batch].unsqueeze(1).expand_as(pt_) + pt_
            pt_input = torch.reshape(pt_input,
                                     (pt_batch * sample_num, input_size[1],
                                      input_size[2], input_size[3]))
            input = move_to_device(input, device_type, False)
            target = move_to_device(target, device_type, False)
            pt_input = move_to_device(pt_input, device_type, False)
            pt_target = target[:pt_batch].unsqueeze(1)
            # Count how many perturbed copies of each input keep the true label
            p_logits = model(pt_input)
            p_outputs = torch.argmax(p_logits, dim=1)
            p_outputs = torch.reshape(p_outputs, (pt_batch, sample_num))
            pt_output_sum = torch.sum(torch.eq(
                p_outputs, pt_target.expand_as(p_outputs)).float(),
                                      dim=1,
                                      keepdim=True)
            # (An earlier per-sample loop drew one perturbation at a time,
            # scaled it by `perturb_degree`, and accumulated the agreement
            # counts; the vectorized version above supersedes it. Open TODO
            # from that draft: inputs are already normalized, so clamping to
            # [0, 1] would require unnormalizing first.)
            # Fraction of perturbed copies that kept the true label; used as
            # the per-example confidence when smoothing the target
            pt_output_mean = torch.div(pt_output_sum, sample_num)
            pt_target = smooth_label(pt_target, pt_output_mean, class_num)

            output = model(input)
            onehot_loss = onehot_criterion(output, target)
            pt_target = pt_target.detach()
            perturb_loss = soft_criterion(output[:pt_batch], pt_target)
            # Blend the hard-label loss with the perturbation-smoothed loss
            loss = alpha * onehot_loss + (
                1 - alpha) * pt_input_ratio * perturb_loss

            # Backward pass
            loss.backward()
            optimizer.step()
            # Track the number of training iterations on the optimizer
            optimizer.n_iters = getattr(optimizer, 'n_iters', 0) + 1

        else:
            with torch.no_grad():
                # Forward pass
                input = move_to_device(input, device_type, False)
                target = move_to_device(target, device_type, False)
                output = model(input)
                loss = onehot_criterion(output, target)

        # Accounting
        _, predictions = torch.topk(output, 1)
        error = 1 - torch.eq(torch.squeeze(predictions), target).float().mean()
        batch_time = time.time() - end
        end = time.time()

        # Log errors
        time_meter.update(batch_time)
        loss_meter.update(loss.item())
        error_meter.update(error.item())
        if log_every_step:
            # Report the learning rate of the last parameter group
            for param_group in optimizer.param_groups:
                lr_value = param_group['lr']
            print('  '.join([
                '%s: (Epoch %d of %d) [%04d/%04d]' %
                ('Train' if train else 'Eval', epoch, n_epochs, i + 1,
                 len(loader)),
                str(time_meter),
                str(loss_meter),
                str(error_meter),
                '%.4f' % lr_value
            ]))

    # Debug summary from the last training batch; these locals only exist
    # when train=True (printing them unconditionally raised NameError in eval)
    if train:
        print('pt_output_mean')
        print(pt_output_mean)
        print('onehot_loss: ' + str(onehot_loss.item()))
        print('perturb_loss: ' + str(perturb_loss.item()))
    if not log_every_step:
        print('  '.join([
            '%s: (Epoch %d of %d)' %
            ('Train' if train else 'Eval', epoch, n_epochs),
            str(time_meter),
            str(loss_meter),
            str(error_meter),
        ]))

    return time_meter.value(), loss_meter.value(), error_meter.value()
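

# `move_to_device` and `smooth_label` are imported from elsewhere in this
# repo; the sketches below reconstruct plausible implementations from the
# call sites above. They are assumptions, not the project's actual code.

# `move_to_device(x, device_type, False)`: judging by its usage, it moves a
# tensor to the given device, with the third argument plausibly selecting
# non-blocking transfer.
def _move_to_device_sketch(tensor, device_type='cuda', non_blocking=False):
    return tensor.to(device_type, non_blocking=non_blocking)


# `smooth_label(target, confidence, class_num)`: called with integer targets
# of shape [pt_batch, 1], per-example agreement rates in [0, 1], and the
# class count, and its result feeds `soft_criterion`. One plausible reading
# is per-example label smoothing: the true class receives the agreement rate
# as its probability and the remaining mass is spread uniformly.
def _smooth_label_sketch(target, confidence, class_num):
    off_mass = (1.0 - confidence) / (class_num - 1)      # [pt_batch, 1]
    soft = off_mass.expand(-1, class_num).clone()        # [pt_batch, C]
    soft.scatter_(1, target, confidence)                 # place true-class mass
    return soft


# Hypothetical driver (loader names and SGD settings are illustrative only):
#
#     criterion = torch.nn.CrossEntropyLoss()
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
#     for epoch in range(1, n_epochs + 1):
#         run_epoch_fast(train_loader, model, criterion, optimizer,
#                        epoch=epoch, n_epochs=n_epochs, train=True)
#         run_epoch_fast(valid_loader, model, criterion, optimizer,
#                        epoch=epoch, n_epochs=n_epochs, train=False)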