Example #1
0
def compute_features(model, use_flip, batch_size, workers, data_path):
    """Run `model` over the evaluation set at `data_path` and collect features.

    Returns:
        (embeddings, targets): concatenated model outputs and the list of
        corresponding labels, in loader order.
    """
    preprocess = transforms.Compose([
        ToArray(),
        Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format='CHW'),
    ])
    dataset = EvalDataset(
        os.path.dirname(data_path),
        os.path.basename(data_path),
        preprocess)
    loader = paddle.io.DataLoader(
        dataset,
        batch_size=batch_size, shuffle=False, drop_last=False,
        num_workers=workers)
    batch_time = AverageMeter('Time', ':6.3f')
    progress = ProgressMeter(len(loader), [batch_time])

    feature_chunks = []
    targets = []
    tic = time.time()
    for step, (images, target) in enumerate(loader):
        targets.extend(target)
        # Forward pass in eval mode; no key images, optional flip augmentation.
        feature_chunks.append(
            model(images, im_k=None, use_flip=use_flip, is_train=False))
        batch_time.update(time.time() - tic)
        tic = time.time()
        if step % 10 == 0:
            progress.display(step)
    return paddle.concat(feature_chunks), targets
Example #2
0
def predict_on_model(model, batch_data, device,out_path):
    """Run the model over every batch, log accuracies, and dump predictions to CSV.

    Results are reported via the logger and `save_test_result_to_csv`;
    always returns 0.
    """
    epoch_problem_acc = AverageMeter()
    epoch_binary_acc = AverageMeter()
    total_batches = len(batch_data)
    t0 = time.time()
    for batch_num, batch in enumerate(batch_data):
        contents, question_ans, sample_labels, sample_ids, sample_categorys, sample_logics = batch
        # Move model inputs and targets onto the requested device.
        # contents: batch_size*10*200, question_ans: batch_size*100, sample_labels: batch_size
        contents = contents.to(device)
        question_ans = question_ans.to(device)
        sample_labels = sample_labels.to(device)
        sample_logics = sample_logics.to(device)

        # pred_labels has shape (batch, 2)
        pred_labels = model.forward(contents, question_ans, sample_logics)

        binary_acc = compute_binary_accuracy(pred_labels, sample_labels)
        problem_acc = compute_problems_accuracy(pred_labels, sample_labels, sample_ids)

        # One "problem" groups 5 samples, hence the divided weight.
        epoch_problem_acc.update(problem_acc.item(), int(len(sample_ids) / 5))
        epoch_binary_acc.update(binary_acc.item(), len(sample_ids))

        logger.info('batch=%d/%d, binary_acc=%.4f  problem_acc=%.4f' % (batch_num, total_batches, binary_acc, problem_acc))
        # save result to csv file
        save_test_result_to_csv(sample_ids, pred_labels, sample_labels, sample_categorys, sample_logics, out_path)

    elapsed = time.time() - t0
    logger.info('===== test completed, avg_problem_acc=%.4f, eval_time=%.1f====' % (epoch_problem_acc.avg, elapsed))

    return 0
Example #3
0
def cal_acc(dataloader, model, num_classes, device):
    """Evaluate `model` on `dataloader`.

    Returns:
        (overall_accuracy, per_class_accuracy): the running-average batch
        accuracy and a float32 numpy array of length `num_classes` with
        per-class accuracy (correct / count, epsilon-smoothed).
    """
    accuracy = AverageMeter()
    model.eval()

    cls_count = np.zeros(num_classes, dtype=np.float32)
    cls_correct = np.zeros(num_classes, dtype=np.float32)

    with torch.no_grad():
        for i, (images, labels) in enumerate(dataloader):
            batch_size = images.size(0)
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)

            # Tally ground-truth occurrences per class.
            for gt_label in labels:
                cls_count[int(gt_label.item())] += 1

            # Tally correct predictions per class.
            _, preds = torch.max(outputs, 1)
            for corr_pred in labels[preds == labels.data]:
                cls_correct[int(corr_pred.item())] += 1

            acc = torch.mean((preds == labels.data).float())

            accuracy.update(acc.item(), batch_size)

            print(time.strftime('%m/%d %H:%M:%S', time.localtime()), end='\t')
            print('Test: [{0}/{1}] '
                  'Acc: {acc.val:.3f}({acc.avg:.3f})'.format(i + 1,
                                                             len(dataloader),
                                                             acc=accuracy),
                  flush=True)

    # BUG FIX: previously computed inside the loop, so it was recomputed every
    # batch and -- worse -- undefined (NameError) when the dataloader was
    # empty.  Only the final value is ever used, so compute it once here.
    cls_acc = cls_correct / (cls_count + 1e-8)
    return accuracy.avg, cls_acc
Example #4
0
    def train(epoch):
        """Run one image-VAE training epoch; prints running loss periodically."""
        print('Using KL Lambda: {}'.format(kl_lambda))
        vae.train()
        meter = AverageMeter()

        for step, (batch, _) in enumerate(train_loader):
            batch = Variable(batch)
            if args.cuda:
                batch = batch.cuda()
            optimizer.zero_grad()
            recon, mu, logvar = vae(batch)
            # watch out for logvar -- could explode if learning rate is too high.
            loss = loss_function(mu, logvar, recon_image=recon, image=batch,
                                 kl_lambda=kl_lambda, lambda_xy=1.)
            meter.update(loss.data[0], len(batch))
            loss.backward()
            optimizer.step()
            if step % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, step * len(batch), len(train_loader.dataset),
                    100. * step / len(train_loader), meter.avg))

        print('====> Epoch: {} Average loss: {:.4f}'.format(
            epoch, meter.avg))
Example #5
0
def estimate_epoch_time(models, train_loader, num_batches=10000):
    """Time a (possibly truncated) fake training pass for every model.

    Runs forward + backward + SGD step for each model on up to
    `num_batches` batches and returns the mean batch time scaled by the
    full loader length, i.e. an estimated wall-clock time per epoch.
    """
    opts = [
        torch.optim.SGD(m.parameters(), lr=0.1) for m in models
    ]
    timer = AverageMeter()
    # CUDA availability cannot change mid-run, so check it once.
    use_cuda = torch.cuda.is_available()

    tick = time.time()
    for step, (input, target) in enumerate(train_loader):

        if use_cuda:
            input = input.cuda()
            target = target.cuda()

        # Fake forward and backward pass through every model.
        for opt, net in zip(opts, models):
            if use_cuda:
                net = net.cuda()

            loss = torch.nn.functional.cross_entropy(net(input), target)
            opt.zero_grad()
            loss.backward()
            opt.step()

        timer.update(time.time() - tick)
        tick = time.time()

        if num_batches <= step:
            break

    torch.cuda.empty_cache()

    return timer.avg * len(train_loader)
    def train(epoch):
        """Train the text VAE for one epoch, logging the running average loss."""
        vae.train()
        meter = AverageMeter()

        for step, (_, batch) in enumerate(train_loader):
            batch = Variable(batch)
            if args.cuda:
                batch = batch.cuda()
            optimizer.zero_grad()
            recon_batch, mu, logvar = vae(batch)
            loss = loss_function(mu, logvar, recon_text=recon_batch,
                                 text=batch, kl_lambda=kl_lambda, lambda_yx=1.)
            loss.backward()
            meter.update(loss.data[0], len(batch))
            optimizer.step()
            if step % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, step * len(batch), len(train_loader.dataset),
                    100. * step / len(train_loader), meter.avg))

        print('====> Epoch: {} Average loss: {:.4f}'.format(
            epoch, meter.avg))
    def train(epoch):
        """One training epoch: per-dimension cross entropy with gradient clipping."""
        model.train()
        meter = AverageMeter()

        for step, (batch, _) in enumerate(train_loader):
            batch = Variable(batch)
            # Quantize the inputs into `args.out_dims` integer bins as targets.
            target = Variable((batch.data * (args.out_dims - 1)).long())

            if args.cuda:
                batch = batch.cuda()
                target = target.cuda()

            optimizer.zero_grad()
            loss = cross_entropy_by_dim(model(batch), target)
            meter.update(loss.data[0], len(batch))

            loss.backward()
            # clip gradients to prevent exploding gradients
            torch.nn.utils.clip_grad_norm(model.parameters(), 1.)
            optimizer.step()

            if step % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, step * len(batch), len(train_loader.dataset),
                    100. * step / len(train_loader), meter.avg))

        print('====> Epoch: {}\tLoss: {:.4f}'.format(epoch, meter.avg))
Example #8
0
def eval_test_set(dataset, model, device):
    """Compute and print the top-1 accuracy of `model` over `dataset`."""
    run_val_acc = AverageMeter('val_acc')
    pbar = tqdm(dataset, total=len(dataset))
    with torch.no_grad():
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            preds = model(images)
            top1 = accuracy(preds, labels)
            run_val_acc.update(top1[0].item(), images.size(0))
            pbar.set_description(
                'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                    top1=run_val_acc))
    print(run_val_acc.avg)
    def test():
        """Evaluate the per-dimension cross-entropy loss over the test set."""
        model.eval()
        meter = AverageMeter()

        for batch, _ in test_loader:
            batch = Variable(batch)
            # Same quantized-bin targets as used during training.
            target = Variable((batch.data * (args.out_dims - 1)).long())

            if args.cuda:
                batch = batch.cuda()
                target = target.cuda()

            loss = cross_entropy_by_dim(model(batch), target)
            meter.update(loss.data[0], len(batch))

        print('====> Test Epoch\tLoss: {:.4f}'.format(meter.avg))
        return meter.avg
Example #10
0
    def train(epoch):
        """One epoch of weakly-supervised multimodal VAE training.

        With probability `weak_perc` a batch is shown as a full paired
        (image, text) example (three passes: joint, image-only, text-only);
        otherwise only the two single-modality passes are used.
        """
        random.seed(42)
        np.random.seed(42)  # important to have the same seed
                            # in order to make the same choices for weak supervision
                            # otherwise, we end up showing different examples over epochs
        vae.train()

        joint_loss_meter = AverageMeter()
        image_loss_meter = AverageMeter()
        text_loss_meter = AverageMeter()

        for batch_idx, (image, text) in enumerate(train_loader):
            if cuda:
                image, text = image.cuda(), text.cuda()
            image, text = Variable(image), Variable(text)
            optimizer.zero_grad()

            # depending on this flip, we either show it a full paired example or
            # we show it single modalities (in which we cannot compute the full loss)
            flip = np.random.random()
            if flip < weak_perc:  # here we show a paired example
                # Joint pass: both modalities in, both reconstructions weighted.
                recon_image_1, recon_text_1, mu_1, logvar_1 = vae(image, text)
                loss_1 = loss_function(mu_1, logvar_1, recon_image=recon_image_1, image=image, 
                                       recon_text=recon_text_1, text=text, kl_lambda=kl_lambda,
                                       lambda_xy=1., lambda_yx=1.)
                # Image-conditioned pass.
                recon_image_2, recon_text_2, mu_2, logvar_2 = vae(image=image)
                loss_2 = loss_function(mu_2, logvar_2, recon_image=recon_image_2, image=image, 
                                       recon_text=recon_text_2, text=text, kl_lambda=kl_lambda,
                                       lambda_xy=1., lambda_yx=1.)
                # Text-conditioned pass; the image term is zero-weighted (lambda_xy=0).
                recon_image_3, recon_text_3, mu_3, logvar_3 = vae(text=text)
                loss_3 = loss_function(mu_3, logvar_3, recon_image=recon_image_3, image=image, 
                                       recon_text=recon_text_3, text=text, kl_lambda=kl_lambda,
                                       lambda_xy=0., lambda_yx=1.)

                loss = loss_1 + loss_2 + loss_3
                joint_loss_meter.update(loss_1.data[0], len(image))

            else:  # here we show individual modalities
                recon_image_2, _, mu_2, logvar_2 = vae(image=image)
                loss_2 = loss_function(mu_2, logvar_2, recon_image=recon_image_2, image=image, 
                                       kl_lambda=kl_lambda, lambda_xy=1.)
                _, recon_text_3, mu_3, logvar_3 = vae(text=text)
                loss_3 = loss_function(mu_3, logvar_3, recon_text=recon_text_3, text=text, 
                                       kl_lambda=kl_lambda, lambda_yx=1.)
                loss = loss_2 + loss_3

            # loss_2/loss_3 exist in both branches, so these updates are safe here.
            image_loss_meter.update(loss_2.data[0], len(image))
            text_loss_meter.update(loss_3.data[0], len(text))

            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                print('[Weak {:.0f}%] Train Epoch: {} [{}/{} ({:.0f}%)]\tJoint Loss: {:.6f}\tImage Loss: {:.6f}\tText Loss: {:.6f}'.format(
                    100. * weak_perc, epoch, batch_idx * len(image), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), joint_loss_meter.avg,
                    image_loss_meter.avg, text_loss_meter.avg))

        print('====> [Weak {:.0f}%] Epoch: {} Joint loss: {:.4f}\tImage loss: {:.4f}\tText loss: {:.4f}'.format(
            100. * weak_perc, epoch, joint_loss_meter.avg, image_loss_meter.avg, text_loss_meter.avg))
Example #11
0
    def train(epoch):
        """Run one training epoch of the (reconstruction, z) VAE and log progress."""
        vae.train()
        meter = AverageMeter()

        for step, (batch, _) in enumerate(train_loader):
            if args.cuda:
                batch = batch.cuda()
            batch = Variable(batch)
            optimizer.zero_grad()

            recon, latent = vae(batch)
            loss = loss_function(recon, batch, latent)
            meter.update(loss.data[0], len(batch))

            loss.backward()
            optimizer.step()

            if step % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, step * len(batch), len(train_loader.dataset),
                    100. * step / len(train_loader), meter.avg))

        print('====> Epoch: {}\tLoss: {:.4f}'.format(epoch, meter.avg))
Example #12
0
def val_fm(model, model_dir, criterion,dataset, args):
    """Validate `model` (weights loaded from `model_dir`) on a fraction of `dataset`.

    Evaluates the first `args.val_size` fraction of `dataset`, printing
    periodic progress, and returns (top1.avg, top5.avg, losses.avg).
    """
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to evaluate mode
    trained_model = torch.load(model_dir)
    model.load_state_dict(trained_model)
    model = torch.nn.DataParallel(model).cuda()
    model.eval()
    criterion.eval()

    def part(x):
        # Restrict evaluation to the leading `val_size` fraction.
        return itertools.islice(x, int(len(x) * args.val_size))
    data_loader = part(dataset)
    for i, (input, target, meta, vid, Auxili_info, raw_test) in enumerate(data_loader):
        gc.collect()
        # meta['epoch'] = epoch
        # BUG FIX: `.cuda(async=True)` is a SyntaxError on Python >= 3.7
        # (`async` became a reserved keyword); PyTorch 0.4 renamed the
        # argument to `non_blocking`.
        target = target.long().cuda(non_blocking=True)
        # NOTE(review): `volatile=True` and `loss.data[0]` below are legacy
        # (pre-0.4) PyTorch idioms, kept for consistency with this codebase.
        input_var = torch.autograd.Variable(input.cuda(), volatile=True)
        target_var = torch.autograd.Variable(target.float().cuda(non_blocking=True))
        Auxili_info = torch.autograd.Variable(Auxili_info)
        output = model(input_var, Auxili_info)
        loss = criterion(output, target_var)
        output = torch.nn.Sigmoid()(output)
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0], input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        if i % int(0.1 * args.val_size * len(dataset)) == 0:
            print('Test: [{0}/{1} ({2})]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                i, int(len(dataset) * args.val_size), len(dataset),
                loss=losses,top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg, top5.avg, losses.avg
Example #13
0
            # Unlock block `keys=7` once the epoch passes `threshold`, and
            # freeze its batch-norm statistics while training.
            # NOTE(review): the enclosing function/loop is outside this excerpt.
            exc_key = epoch > threshold
            NetFunction.Open_block(model, keys=7, Exc=exc_key)
            NetFunction.Lock_BN_Dur_train(model, keys=7, Exc=exc_key)
            # Rebuild the optimizer over only the currently trainable params.
            param = filter(lambda p: p.requires_grad, model.parameters())
            optimizer = torch.optim.SGD(param, opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay)
            optimizer.zero_grad()
            Two_stage_learning_rate(opt.lr, opt.lr_decay_step, optimizer, epoch, opt.lr_decay_ratio, threshold)
            print('This batch lr:\t', optimizer.param_groups[0]['lr'])

        # Train on only the first `train_size` fraction of the loader.
        def part(x): return itertools.islice(x, int(len(x)*opt.train_size))
        data_loader = part(train_loader)
        end = time.time()
        for i, (input, target, meta, vid, Auxili_info, raw_test) in enumerate(data_loader):
            # Image._show(Image.fromarray(raw_test[0, :, :, :].numpy().astype(np.uint8)))
            gc.collect()
            data_time.update(time.time() - end)
            meta['epoch'] = epoch
            input_var = torch.autograd.Variable(input.cuda())
            # NOTE(review): `async` is a reserved keyword since Python 3.7 --
            # this line only parses on Python <= 3.6; modern PyTorch spells it
            # `non_blocking=True`.
            target_var = torch.autograd.Variable(target.float().cuda(async=True))
            Auxili_info = torch.autograd.Variable(Auxili_info)
            output = model(input_var, Auxili_info)
            loss = criterion(output, target_var)
            # Sigmoid only for accuracy measurement; the loss used raw logits.
            output = torch.nn.Sigmoid()(output)
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.data[0], input.size(0))
            top1.update(prec1[0], input.size(0))
            top5.update(prec5[0], input.size(0))

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
Example #14
0
    def train(epoch):
        """One epoch of weakly-supervised multimodal VAE training.

        Every batch gets a full paired (image, text) pass; with probability
        `weak_perc_m1` / `weak_perc_m2` an additional image-only / text-only
        conditioned pass is added to the loss.
        """
        random.seed(42)
        np.random.seed(42)  # important to have the same seed
        # in order to make the same choices for weak supervision
        # otherwise, we end up showing different examples over epochs
        vae.train()

        joint_loss_meter = AverageMeter()
        image_loss_meter = AverageMeter()
        text_loss_meter = AverageMeter()

        for batch_idx, (image, text) in enumerate(train_loader):
            if cuda:
                image, text = image.cuda(), text.cuda()
            image, text = Variable(image), Variable(text)
            optimizer.zero_grad()

            # Joint pass: both modalities in, both reconstruction terms weighted.
            recon_image_1, recon_text_1, mu_1, logvar_1 = vae(image, text)
            loss = loss_function(mu_1,
                                 logvar_1,
                                 recon_image=recon_image_1,
                                 image=image,
                                 recon_text=recon_text_1,
                                 text=text,
                                 kl_lambda=kl_lambda,
                                 lambda_xy=1.,
                                 lambda_yx=1.)
            joint_loss_meter.update(loss.data[0], len(image))

            # depending on this flip, we decide whether or not to show a modality
            # versus another one.
            flip = np.random.random()

            if flip < weak_perc_m1:
                # Image-conditioned pass.
                recon_image_2, recon_text_2, mu_2, logvar_2 = vae(image=image)
                loss_2 = loss_function(mu_2,
                                       logvar_2,
                                       recon_image=recon_image_2,
                                       image=image,
                                       recon_text=recon_text_2,
                                       text=text,
                                       kl_lambda=kl_lambda,
                                       lambda_xy=1.,
                                       lambda_yx=1.)
                image_loss_meter.update(loss_2.data[0], len(image))
                loss += loss_2

            flip = np.random.random()
            if flip < weak_perc_m2:
                # Text-conditioned pass; the image term is zero-weighted (lambda_xy=0).
                recon_image_3, recon_text_3, mu_3, logvar_3 = vae(text=text)
                loss_3 = loss_function(mu_3,
                                       logvar_3,
                                       recon_image=recon_image_3,
                                       image=image,
                                       recon_text=recon_text_3,
                                       text=text,
                                       kl_lambda=kl_lambda,
                                       lambda_xy=0.,
                                       lambda_yx=1.)
                text_loss_meter.update(loss_3.data[0], len(text))
                loss += loss_3

            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                print(
                    '[Weak (Image) {:.0f}% | Weak (Text) {:.0f}%] Train Epoch: {} [{}/{} ({:.0f}%)]\tJoint Loss: {:.6f}\tImage Loss: {:.6f}\tText Loss: {:.6f}'
                    .format(100. * weak_perc_m1, 100. * weak_perc_m2, epoch,
                            batch_idx * len(image), len(train_loader.dataset),
                            100. * batch_idx / len(train_loader),
                            joint_loss_meter.avg, image_loss_meter.avg,
                            text_loss_meter.avg))

        print(
            '====> [Weak (Image) {:.0f}% | Weak (Text) {:.0f}%] Epoch: {} Joint loss: {:.4f}\tImage loss: {:.4f}\tText loss: {:.4f}'
            .format(100. * weak_perc_m1, 100. * weak_perc_m2, epoch,
                    joint_loss_meter.avg, image_loss_meter.avg,
                    text_loss_meter.avg))
Example #15
0
# Per-BMI-band error meters.
# BUG FIX: `under_mae` / `under_mape` were used below (the `target <= 18.5`
# branch and the final print) but never created, guaranteeing a NameError.
under_mae = AverageMeter('Under_MAE')
under_mape = AverageMeter('Under_MAPE')
normal_mae = AverageMeter('Normal_MAE')
normal_mape = AverageMeter('Normal_MAPE')
over_mae = AverageMeter('Over_MAE')
over_mape = AverageMeter('Over_MAPE')
obese_mae = AverageMeter('Obese_MAE')
obese_mape = AverageMeter('Obese_MAPE')

with torch.no_grad():
    for img, (sex, targ) in test_loader:
        out = model(img)
        out = out.detach().cpu().numpy()
        target = targ.detach().cpu().numpy()
        mae = mean_absolute_error(target, out)
        mape = mean_absolute_percentage_error(target, out)
        # Route the batch errors into the matching BMI band.
        # NOTE(review): `target` is compared as a scalar, which assumes
        # batch size 1 -- confirm against the test_loader configuration.
        if target <= 18.5:
            under_mae.update(mae)
            under_mape.update(mape)
        elif target > 18.5 and target <= 25:
            normal_mae.update(mae)
            normal_mape.update(mape)
        elif target > 25 and target <= 30:
            over_mae.update(mae)
            over_mape.update(mape)
        elif target > 30:
            obese_mae.update(mae)
            obese_mape.update(mape)

print("UnderMAE:", under_mae.avg, '\tUnderMAPE:', under_mape.avg)
print("NormalMAE:", normal_mae.avg, '\tNormalMAPE:', normal_mape.avg)
print("OverMAE:", over_mae.avg, '\tOverMAPE:', over_mape.avg)
print("ObeseMAE:", obese_mae.avg, '\tObeseMAPE:', obese_mape.avg)
Example #16
0
File: test.py Project: tjusxh/DDSL
def evaluate(args, model, loader, criterion, criterion_smooth, epoch, device):
    """Evaluate the model for one epoch: multi-resolution raster loss,
    smoothness loss, and masked mIoU.

    Also saves up to `args.nsamples` composed prediction images into
    `args.output_dir` and prints a per-class IoU table.
    """
    model.eval()

    rasterlosses = [AverageMeter() for _ in range(len(args.res))]
    smoothloss = AverageMeter()
    losses_sum = AverageMeter()
    mious = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    count = 0

    with torch.no_grad():
        end = time.time()
        for batch_idx, (input, target, label) in enumerate(tqdm(loader)):
            # measure data loading time
            data_time.update(time.time() - end)

            # compute output
            num = input.size(0)
            input, label = input.to(device), label.to(device)
            target = [t.to(device) for t in target]
            output = model(input)  # output shape [N, 128, 2]

            # One criterion per resolution level; the output polygon is
            # subsampled by the level stride before rasterization.
            evals = [
                crit(output[:, ::l], t)
                for crit, l, t in zip(criterion, args.levels, target)
            ]
            loss_vecs = [e[0] for e in evals]
            output_rasters = [e[1] for e in evals]
            # Per-class weighting of the raster losses.
            rasterlosses_ = [(lv * args.weights[label]).sum()
                             for lv in loss_vecs]

            rasterloss = torch.sum(torch.stack(rasterlosses_, dim=0), dim=0)
            output_raster = output_rasters[0]

            # compute smoothness loss
            smoothloss_ = criterion_smooth(output)
            smoothloss_ = (smoothloss_).mean()

            # Combined loss (not backpropagated here; evaluation only).
            loss = rasterloss + args.smooth_loss * smoothloss_

            # measure miou and record loss
            miou, miou_count = masked_miou(output_raster, target[0], label,
                                           args.nclass)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # update statistics
            for i in range(len(args.res)):
                rasterlosses[i].update(rasterlosses_[i].item())
            smoothloss.update(smoothloss_)
            mious.update(miou, miou_count)
            losses_sum.update(rasterloss.item(), num)

            # output visualization for the first `nsamples` batches
            if count < args.nsamples:
                fname = os.path.join(args.output_dir,
                                     "sample_{}.png".format(count))
                compimg = compose_masked_img(output_raster, target[0], input,
                                             args.img_mean, args.img_std)
                # save_image returns None; the previous dead `imgrid =`
                # binding has been dropped.
                vutils.save_image(compimg,
                                  fname,
                                  normalize=False,
                                  scale_each=False)

            count += 1

        # BUG FIX: `mious.avgcavg` was a typo (AttributeError at runtime);
        # the meter exposes `.avg`.  NOTE(review): `.avg` is iterated per
        # class below -- if it is an array, this ':.3f' format may need a
        # scalar reduction such as `.mean()`; confirm against AverageMeter.
        log_text = ('Test Epoch: [{0}]\t'
                    'CompTime {batch_time.sum:.3f} ({batch_time.avg:.3f})\t'
                    'DataTime {data_time.sum:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.avg:.4f}\t'
                    'mIoU {miou:.3f}\t').format(epoch,
                                                batch_time=batch_time,
                                                data_time=data_time,
                                                loss=losses_sum,
                                                miou=mious.avg)
        print(log_text)
        # tabulate mean iou
        print(
            tabulate(dict(zip(args.label_names, [[iou] for iou in mious.avg])),
                     headers="keys"))
def attack(model, model_name, loader, start_eps, end_eps, max_eps, norm,
           logger, verbose, method, **kwargs):
    """Run a Shadow-attack loop against `model`, reusing CROWN-IBP's
    certified-bound machinery to track robustness statistics.

    For every batch, the attacker runs 9 rounds of shifted target labels,
    each with `total_steps` gradient steps that ascend the negated training
    loss.  The bound computation and loss selection inside the inner loop
    are copied from the CROWN-IBP repository (marked below); the attack
    bookkeeping around it is from the Shadow-attack ("breaking") code.

    Args:
        model: network under attack.  Beyond `__call__`, it must provide
            CROWN-IBP style `interval_range` / `backward_range` bound
            methods and be iterable over its layers (used for the per-layer
            norm logging at the end).
        model_name: only used to build `exp_name` for output naming.
        loader: data loader yielding (data, labels); must also expose
            `.std` (per-channel std list) and `.batch_size`.
        start_eps: epsilon for the first batch of the linear schedule.
        end_eps: epsilon for the last batch; values below 1e-6 force
            `method = "natural"`.
        max_eps: maximum epsilon, used to anneal both the CROWN/interval
            bound blend and the robust/natural loss mixture.
        norm: perturbation norm (`np.inf` or `2`).
        logger: object with a `.log(msg)` method.
        verbose: if True, compute bound statistics even when
            `method == "natural"`.
        method: one of "robust", "robust_activity", "natural",
            "robust_natural"; anything else raises ValueError.
        **kwargs: CROWN-IBP options read below: "bounded_input",
            "bound_type", "convex-proj", "final-beta", "runnerup_only",
            "final-kappa", "activity_reg", "l1_reg".

    Returns:
        Tuple `(robust_error_avg, clean_error_avg)`; when
        `method == "natural"` the clean error is returned in both slots.
    """
    torch.manual_seed(6247423)  # fixed seed -> reproducible attack runs
    num_class = 10
    # Running statistics.  Several meters are only populated for specific
    # bound types / methods (e.g. relu_activities only for interval-based
    # bounds, l1_losses only when "l1_reg" is passed).
    losses = AverageMeter()
    l1_losses = AverageMeter()
    errors = AverageMeter()
    robust_errors = AverageMeter()
    regular_ce_losses = AverageMeter()
    robust_ce_losses = AverageMeter()
    relu_activities = AverageMeter()
    bound_bias = AverageMeter()
    bound_diff = AverageMeter()
    unstable_neurons = AverageMeter()
    dead_neurons = AverageMeter()
    alive_neurons = AverageMeter()
    batch_time = AverageMeter()
    # initial
    model.eval()
    # Forwarded to get_args(); its exact meaning lives in that helper —
    # presumably whether RGB channels share the perturbation (TODO confirm).
    duplicate_rgb = True
    # pregenerate the array for specifications, will be used for scatter
    # sa[i] lists the 9 class indices != i, so that margin bounds (computed
    # only against the other classes) can later be scattered back into a
    # full (batch, num_class) tensor, leaving the true-class slot at 0.
    sa = np.zeros((num_class, num_class - 1), dtype=np.int32)
    for i in range(sa.shape[0]):
        for j in range(sa.shape[1]):
            if j < i:
                sa[i][j] = j
            else:
                sa[i][j] = j + 1
    sa = torch.LongTensor(sa)
    total = len(loader.dataset)
    batch_size = loader.batch_size
    print(batch_size)
    # Reshape per-channel std to (1, C, 1, 1) so it broadcasts over NCHW.
    std = torch.tensor(loader.std).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
    total_steps = 300  # attack gradient steps per target-label round

    # One epsilon per batch, ramped linearly from start_eps to end_eps.
    batch_eps = np.linspace(start_eps, end_eps, (total // batch_size) + 1)
    if end_eps < 1e-6:
        logger.log('eps {} close to 0, using natural training'.format(end_eps))
        method = "natural"

    # NOTE(review): exp_name is currently unused — the save_images call that
    # consumed it (near the end of the batch loop) is commented out.
    exp_name = 'outputs/[{}:{}]'.format(get_exp_name(), model_name)
    # real_i = 0
    for i, (init_data, init_labels) in enumerate(loader):
        # labels = torch.zeros_like(init_labels)
        init_data = init_data.cuda()
        tv_eps, tv_lam, reg_lam = get_args(duplicate_rgb=duplicate_rgb)
        attacker = Shadow(init_data, init_labels, tv_lam, reg_lam, tv_eps)
        # Per-sample attack-success accumulator across the 9 label rounds.
        success = np.zeros(len(init_data))
        # saved_advs = torch.zeros_like(init_data).cuda()
        for t_i in range(9):

            # Pick target labels different from the true ones and reset the
            # attacker's perturbation state for this round.
            attacker.iterate_labels_not_equal_to(init_labels)
            attacker.renew_t()
            labels = attacker.labels

            for rep in range(total_steps):
                # Current perturbation; writing through .data applies the
                # (presumed) clamp-to-[0,1] + re-normalization without
                # registering it in the autograd graph — TODO confirm what
                # get_unit01/get_normal do exactly.
                ct = attacker.get_ct()
                data = init_data + ct
                data.data = get_normal(get_unit01(data))

                # ========================== The rest of code is taken from CROWN-IBP REPO
                start = time.time()
                eps = batch_eps[i]
                # Specification matrix c: for each sample, rows e_y - e_j for
                # all j != y -> shape (batch, num_class-1, num_class).
                c = torch.eye(num_class).type_as(data)[labels].unsqueeze(
                    1) - torch.eye(num_class).type_as(data).unsqueeze(0)
                # remove specifications to self
                eye = (~(labels.data.unsqueeze(1)
                         == torch.arange(num_class).type_as(
                             labels.data).unsqueeze(0)))
                c = (c[eye].view(data.size(0), num_class - 1, num_class))
                # scatter matrix to avoid compute margin to self
                sa_labels = sa[labels]
                # storing computed lower bounds after scatter
                lb_s = torch.zeros(data.size(0), num_class)

                # FIXME: Assume data is from range 0 - 1
                if kwargs["bounded_input"]:
                    assert loader.std == [1, 1, 1] or loader.std == [1]
                    # bounded input only makes sense for Linf perturbation
                    assert norm == np.inf
                    data_ub = (data + eps).clamp(max=1.0)
                    data_lb = (data - eps).clamp(min=0.0)
                else:
                    if norm == np.inf:
                        # eps is specified in pixel space; divide by std to
                        # express it in normalized space (note: bounds are
                        # built on CPU here and moved to GPU below).
                        data_ub = data.cpu() + (eps / std)
                        data_lb = data.cpu() - (eps / std)
                    else:
                        # For L2 the bound methods take eps directly; ub/lb
                        # are just the data point itself.
                        data_ub = data_lb = data

                # Mirror the model's device for every tensor the bound code
                # touches.
                if list(model.parameters())[0].is_cuda:
                    data = data.cuda()
                    data_ub = data_ub.cuda()
                    data_lb = data_lb.cuda()
                    labels = labels.cuda()
                    c = c.cuda()
                    sa_labels = sa_labels.cuda()
                    lb_s = lb_s.cuda()
                # convert epsilon to a tensor
                # NOTE(review): eps_tensor is never referenced again in this
                # function — kept as-is from the CROWN-IBP training code.
                eps_tensor = data.new(1)
                eps_tensor[0] = eps

                # omit the regular cross entropy, since we use robust error
                output = model(data)
                regular_ce = torch.nn.CrossEntropyLoss()(output, labels)
                regular_ce_losses.update(regular_ce.cpu().detach().numpy(),
                                         data.size(0))
                # Clean error w.r.t. the (target) labels for this round.
                errors.update(
                    torch.sum(torch.argmax(output, dim=1) != labels).cpu().
                    detach().numpy() / data.size(0), data.size(0))
                # get range statistic

                # Compute certified lower bounds `lb` on the class margins
                # unless we are doing purely natural training (and not
                # verbose).
                if verbose or method != "natural":
                    if kwargs["bound_type"] == "convex-adv":
                        # Wong and Kolter's bound, or equivalently Fast-Lin
                        if kwargs["convex-proj"] is not None:
                            proj = kwargs["convex-proj"]
                            if norm == np.inf:
                                norm_type = "l1_median"
                            elif norm == 2:
                                norm_type = "l2_normal"
                            else:
                                raise (ValueError(
                                    "Unsupported norm {} for convex-adv".
                                    format(norm)))
                        else:
                            proj = None
                            if norm == np.inf:
                                norm_type = "l1"
                            elif norm == 2:
                                norm_type = "l2"
                            else:
                                raise (ValueError(
                                    "Unsupported norm {} for convex-adv".
                                    format(norm)))
                        if loader.std == [1] or loader.std == [1, 1, 1]:
                            convex_eps = eps
                        else:
                            convex_eps = eps / np.mean(loader.std)
                            # for CIFAR we are roughly / 0.2
                            # FIXME this is due to a bug in convex_adversarial, we cannot use per-channel eps
                        if norm == np.inf:
                            # bounded input is only for Linf
                            if kwargs["bounded_input"]:
                                # FIXME the bounded projection in convex_adversarial has a bug, data range must be positive
                                data_l = 0.0
                                data_u = 1.0
                            else:
                                data_l = -np.inf
                                data_u = np.inf
                        else:
                            data_l = data_u = None
                        f = DualNetwork(model,
                                        data,
                                        convex_eps,
                                        proj=proj,
                                        norm_type=norm_type,
                                        bounded_input=kwargs["bounded_input"],
                                        data_l=data_l,
                                        data_u=data_u)
                        lb = f(c)
                    elif kwargs["bound_type"] == "interval":
                        # Pure interval-bound propagation (IBP).
                        ub, lb, relu_activity, unstable, dead, alive = model.interval_range(
                            norm=norm, x_U=data_ub, x_L=data_lb, eps=eps, C=c)
                    elif kwargs["bound_type"] == "crown-interval":
                        # IBP bounds first (ilb), then blended with the
                        # tighter CROWN backward bounds (clb).
                        ub, ilb, relu_activity, unstable, dead, alive = model.interval_range(
                            norm=norm, x_U=data_ub, x_L=data_lb, eps=eps, C=c)
                        crown_final_factor = kwargs['final-beta']
                        # Blend coefficient decays linearly in eps: 1 at
                        # eps=0 down to final-beta at eps=max_eps.
                        factor = (max_eps - eps *
                                  (1.0 - crown_final_factor)) / max_eps
                        if factor < 1e-5:
                            # Blend weight negligible: skip the (expensive)
                            # CROWN backward pass entirely.
                            lb = ilb
                        else:
                            if kwargs["runnerup_only"]:
                                # Build a spec against only the runner-up
                                # class (highest logit after masking the
                                # label) ... NOTE(review): runnerup_c is
                                # computed but backward_range is still
                                # called with C=c here — verify intent.
                                masked_output = output.detach().scatter(
                                    1, labels.unsqueeze(-1), -100)
                                runner_up = masked_output.max(1)[1]
                                runnerup_c = torch.eye(num_class).type_as(
                                    data)[labels]
                                runnerup_c.scatter_(1, runner_up.unsqueeze(-1),
                                                    -1)
                                runnerup_c = runnerup_c.unsqueeze(1).detach()
                                clb, bias = model.backward_range(norm=norm,
                                                                 x_U=data_ub,
                                                                 x_L=data_lb,
                                                                 eps=eps,
                                                                 C=c)
                                clb = clb.expand(clb.size(0), num_class - 1)
                            else:
                                clb, bias = model.backward_range(norm=norm,
                                                                 x_U=data_ub,
                                                                 x_L=data_lb,
                                                                 eps=eps,
                                                                 C=c)
                                bound_bias.update(bias.sum() / data.size(0))
                            # Track how much CROWN tightens over IBP.
                            diff = (clb - ilb).sum().item()
                            bound_diff.update(diff / data.size(0),
                                              data.size(0))
                            lb = clb * factor + ilb * (1 - factor)
                    else:
                        raise RuntimeError("Unknown bound_type " +
                                           kwargs["bound_type"])

                    # Scatter the (num_class-1) margin bounds into a full
                    # (batch, num_class) tensor; the true-class slot stays 0.
                    lb = lb_s.scatter(1, sa_labels, lb)
                    # Robust CE on negated bounds (standard CROWN-IBP loss).
                    robust_ce = torch.nn.CrossEntropyLoss()(-lb, labels)
                    if kwargs["bound_type"] != "convex-adv":
                        # Neuron statistics only exist for interval-based
                        # bound types.
                        relu_activities.update(
                            relu_activity.detach().cpu().item() / data.size(0),
                            data.size(0))
                        unstable_neurons.update(unstable / data.size(0),
                                                data.size(0))
                        dead_neurons.update(dead / data.size(0), data.size(0))
                        alive_neurons.update(alive / data.size(0),
                                             data.size(0))

                # Select the training objective the attacker will invert.
                if method == "robust":
                    loss = robust_ce
                elif method == "robust_activity":
                    loss = robust_ce + kwargs["activity_reg"] * relu_activity
                elif method == "natural":
                    loss = regular_ce
                elif method == "robust_natural":
                    # kappa anneals from 1 (eps=0) to final-kappa
                    # (eps=max_eps): natural loss dominates early.
                    natural_final_factor = kwargs["final-kappa"]
                    kappa = (max_eps - eps *
                             (1.0 - natural_final_factor)) / max_eps
                    loss = (1 - kappa) * robust_ce + kappa * regular_ce
                else:
                    raise ValueError("Unknown method " + method)

                # Optional L1 weight regularization (biases excluded).
                if "l1_reg" in kwargs:
                    reg = kwargs["l1_reg"]
                    l1_loss = 0.0
                    for name, param in model.named_parameters():
                        if 'bias' not in name:
                            l1_loss = l1_loss + (reg *
                                                 torch.sum(torch.abs(param)))
                    loss = loss + l1_loss
                    l1_losses.update(l1_loss.cpu().detach().numpy(),
                                     data.size(0))

                # =========================================== The rest is from breaking paper not from CROWN-IBP Repo
                # The attacker *maximizes* the training loss, so it descends
                # on its negation.
                c_loss = -loss
                attacker.back_prop(c_loss, rep)

                batch_time.update(time.time() - start)
                losses.update(loss.cpu().detach().numpy(), data.size(0))

                # On the final step of the round, record robust stats.
                if (verbose or method != "natural") and rep == total_steps - 1:
                    robust_ce_losses.update(robust_ce.cpu().detach().numpy(),
                                            data.size(0))
                    # NOTE(review): `certified` is True when SOME margin
                    # bound is negative (i.e. the sample is NOT provably
                    # robust), so `success` accumulates samples whose bounds
                    # were all non-negative this round — verify that this
                    # matches the intended success semantics.
                    certified = (lb < 0).any(dim=1).cpu().numpy()
                    success = success + np.ones(len(success)) - certified
                    # saved_advs[certified == False] = data[certified == False].data
            torch.cuda.empty_cache()
            to_print = '{}\t{}\t{}'.format((success > 0).sum(), t_i,
                                           attacker.log)
            print(to_print, flush=True)
            # Rotate to the next target label for the following round.
            attacker.labels = attacker.labels + 1
        # save_images(get_unit01(torch.cat((saved_advs, init_data), dim=-1)), success.astype(np.bool), real_i, exp_name)
        # real_i += len(saved_advs)
        robust_errors.update((success > 0).sum() / len(success), len(success))
        print('====', robust_errors.avg, '===', flush=True)
    # Log per-layer weight norms: max row-wise L1 norm, i.e. the operator
    # infinity-norm of each linear/conv weight (flattened per output unit).
    # NOTE: this rebinds the `norm` parameter, which is fine because its
    # original value is no longer needed past this point.
    for i, l in enumerate(model):
        if isinstance(l, BoundLinear) or isinstance(l, BoundConv2d):
            norm = l.weight.data.detach().view(l.weight.size(0),
                                               -1).abs().sum(1).max().cpu()
            logger.log('layer {} norm {}'.format(i, norm))
    if method == "natural":
        return errors.avg, errors.avg
    else:
        return robust_errors.avg, errors.avg