def epoch_transfer_train(loader,
                         model_source,
                         model_target,
                         opt,
                         attack,
                         device,
                         use_tqdm=True,
                         **kwargs):
    """Train ``model_target`` on adversarial examples crafted against ``model_source``.

    For every batch a perturbation is produced by
    ``attack(model_source, X, y, **kwargs)``. ``model_target`` then takes one
    optimizer step on the cross-entropy loss of ``X + delta``.
    ``model_source`` is put in eval mode and is only used to craft the
    perturbations.

    Args:
        loader: iterable of ``(X, y)`` batches that also supports
            ``len(loader)`` and exposes ``loader.dataset`` (e.g. a
            ``torch.utils.data.DataLoader``).
        model_source: model the attack is run against.
        model_target: model being trained (put in train mode).
        opt: optimizer stepped once per batch — presumably over
            ``model_target``'s parameters; confirm at the call site.
        attack: callable ``attack(model, X, y, **kwargs)`` returning a
            perturbation that is added to ``X``.
        device: device both models and every batch are moved to.
        use_tqdm: show a progress bar that advances once per batch.
        **kwargs: forwarded unchanged to ``attack``.

    Returns:
        ``(error_rate, mean_loss)`` of ``model_target`` on the perturbed
        inputs, both normalised by ``len(loader.dataset)``.
    """
    model_source.eval()
    model_target.train()
    model_source.to(device)
    model_target.to(device)

    # Loss module is stateless; build it once instead of once per batch.
    criterion = nn.CrossEntropyLoss()
    total_loss, total_err = 0., 0.

    pbar = tqdm(total=len(loader)) if use_tqdm else None
    try:
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            delta = attack(model_source, X, y, **kwargs)
            yp_target = model_target(X + delta)
            loss = criterion(yp_target, y)

            opt.zero_grad()
            loss.backward()
            opt.step()

            total_err += (yp_target.max(dim=1)[1] != y).sum().item()
            # Criterion returns the batch mean; re-weight by batch size.
            total_loss += loss.item() * X.shape[0]

            if pbar is not None:
                pbar.update(1)
    finally:
        # Release the bar even if the attack or an optimizer step raises.
        if pbar is not None:
            pbar.close()

    n_samples = len(loader.dataset)
    return total_err / n_samples, total_loss / n_samples
def epoch_ALP(loader,
              model,
              attack,
              alp_weight=0.5,
              opt=None,
              device=None,
              use_tqdm=False,
              n_test=None,
              **kwargs):
    """Adversarial Logit Pairing epoch over the dataset.

    For every batch:
      1. clean logits ``model(X)`` are computed in eval mode without
         gradients;
      2. a perturbation is crafted with ``attack(model, X, y, **kwargs)``
         while the model is in eval mode;
      3. the loss is
         ``CE(model(X + delta), y) + alp_weight * MSE(model(X + delta), clean_logits)``.

    When ``opt`` is given, step 3 runs in train mode and ``opt`` takes one
    step per batch. When ``opt`` is None the epoch is evaluation only: step 3
    runs in eval mode without building a graph, and no parameters change.
    (Before this change, ``opt=None`` always failed on ``opt.zero_grad()``.)

    Args:
        loader: iterable of ``(X, y)`` batches; ``loader.dataset`` is read
            only to size the progress bar when ``use_tqdm`` is set and
            ``n_test`` is None.
        model: model being trained or evaluated.
        attack: callable ``attack(model, X, y, **kwargs)`` returning a
            perturbation added to ``X``.
        alp_weight: weight of the logit-pairing (MSE) term.
        opt: optimizer, or None for evaluation only.
        device: device every batch is moved to.
        use_tqdm: show a per-sample progress bar.
        n_test: stop after at least this many samples (whole batches are
            counted, so the last batch may overshoot).
        **kwargs: forwarded unchanged to ``attack``.

    Returns:
        ``(error_rate, mean_loss)`` on the perturbed inputs over the samples
        processed. If no batch is processed, this raises ZeroDivisionError.
    """
    training = opt is not None
    ce_loss = nn.CrossEntropyLoss()
    mse_loss = nn.MSELoss()
    total_loss, total_err = 0., 0.
    total_n = 0

    model.train(training)

    if use_tqdm:
        pbar = tqdm(total=len(loader.dataset) if n_test is None else n_test)
    else:
        pbar = None
    try:
        for X, y in loader:
            X, y = X.to(device), y.to(device)

            # Reference logits for the pairing term: eval mode, no graph.
            model.eval()
            with torch.no_grad():
                clean_logit = model(X)
            delta = attack(model, X, y, **kwargs)

            # Train mode + autograd only when we are going to step.
            model.train(training)
            with torch.set_grad_enabled(training):
                yp = model(X + delta)
                loss = ce_loss(yp, y) + alp_weight * mse_loss(yp, clean_logit)

            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()

            total_err += (yp.max(dim=1)[1] != y).sum().item()
            total_loss += loss.item() * X.shape[0]
            total_n += X.shape[0]
            if pbar is not None:
                pbar.update(X.shape[0])

            if n_test is not None and total_n >= n_test:
                break
    finally:
        if pbar is not None:
            pbar.close()

    return total_err / total_n, total_loss / total_n
def epoch_transfer_attack(loader,
                          model_source,
                          model_target,
                          attack,
                          device,
                          success_only=False,
                          use_tqdm=True,
                          n_test=None,
                          **kwargs):
    """Measure how adversarial examples crafted on one model transfer to another.

    Perturbations come from ``attack(model_source, X, y, **kwargs)``. Both
    models are then evaluated, in eval mode and without autograd, on
    ``X + delta``.

    Args:
        loader: iterable of ``(X, y)`` batches.
        model_source: model the attack is run against.
        model_target: model the perturbed inputs are transferred to.
        attack: callable ``attack(model, X, y, **kwargs)`` returning a
            perturbation added to ``X``.
        device: device both models and every batch are moved to.
        success_only: not implemented; must be False.
        use_tqdm: show a per-sample progress bar (its total is ``n_test``,
            or unknown when ``n_test`` is None).
        n_test: stop after at least this many samples (whole batches are
            counted, so the last batch may overshoot).
        **kwargs: forwarded unchanged to ``attack``.

    Returns:
        ``(source_err, target_err, 0)``: error rates of the source and target
        models on the perturbed inputs. The third element is a constant
        placeholder, kept for callers that unpack three values.

    Raises:
        NotImplementedError: if ``success_only`` is True.
    """
    # Fail fast rather than after running the (possibly expensive) attack.
    if success_only:
        raise NotImplementedError

    model_source.eval()
    model_target.eval()
    model_source.to(device)
    model_target.to(device)

    source_err, target_err = 0., 0.
    total_n = 0

    pbar = tqdm(total=n_test) if use_tqdm else None
    try:
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            # The attack itself may need gradients, so it stays outside no_grad.
            delta = attack(model_source, X, y, **kwargs)
            X_adv = X + delta

            # Pure evaluation: no autograd graph needed.
            with torch.no_grad():
                yp_target = model_target(X_adv)
                yp_source = model_source(X_adv)

            source_err += (yp_source.max(dim=1)[1] != y).sum().item()
            target_err += (yp_target.max(dim=1)[1] != y).sum().item()
            total_n += X.shape[0]
            if pbar is not None:
                pbar.update(X.shape[0])

            if n_test is not None and total_n >= n_test:
                break
    finally:
        if pbar is not None:
            pbar.close()

    return source_err / total_n, target_err / total_n, 0
def epoch_adversarial(loader,
                      model,
                      attack,
                      opt=None,
                      device=None,
                      use_tqdm=False,
                      n_test=None,
                      **kwargs):
    """Adversarial training/evaluation epoch over the dataset.

    For every batch a perturbation is crafted with
    ``attack(model, X, y, **kwargs)`` while the model is in eval mode. The
    cross-entropy loss of ``model(X + delta)`` is then computed. With an
    optimizer, the forward pass runs in train mode and ``opt`` steps once per
    batch. Without one (``opt is None``), the forward pass runs in eval mode
    with autograd disabled, so no graph is kept alive.

    Args:
        loader: iterable of ``(X, y)`` batches; ``loader.dataset`` is read
            only to size the progress bar when ``use_tqdm`` is set and
            ``n_test`` is None.
        model: model being trained or evaluated.
        attack: callable ``attack(model, X, y, **kwargs)`` returning a
            perturbation added to ``X``.
        opt: optimizer, or None for evaluation only.
        device: device every batch is moved to.
        use_tqdm: show a per-sample progress bar.
        n_test: stop after at least this many samples (whole batches are
            counted, so the last batch may overshoot).
        **kwargs: forwarded unchanged to ``attack``.

    Returns:
        ``(error_rate, mean_loss)`` on the perturbed inputs over the samples
        processed.
    """
    training = opt is not None
    criterion = nn.CrossEntropyLoss()
    total_loss, total_err = 0., 0.
    total_n = 0

    model.train(training)

    if use_tqdm:
        pbar = tqdm(total=len(loader.dataset) if n_test is None else n_test)
    else:
        pbar = None
    try:
        for X, y in loader:
            X, y = X.to(device), y.to(device)

            # Craft the perturbation against the eval-mode model; the attack
            # may need input gradients, so it runs outside any no_grad scope.
            model.eval()
            delta = attack(model, X, y, **kwargs)

            model.train(training)
            with torch.set_grad_enabled(training):
                yp = model(X + delta)
                loss = criterion(yp, y)

            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()

            total_err += (yp.max(dim=1)[1] != y).sum().item()
            total_loss += loss.item() * X.shape[0]
            total_n += X.shape[0]
            if pbar is not None:
                pbar.update(X.shape[0])

            if n_test is not None and total_n >= n_test:
                break
    finally:
        if pbar is not None:
            pbar.close()

    return total_err / total_n, total_loss / total_n
def epoch_surrogate(loader,
                    basemodel,
                    ae,
                    optnet,
                    attack,
                    device,
                    use_tqdm=True,
                    n_test=None,
                    **kwargs):
    """Evaluate ``optnet`` on perturbations crafted against a surrogate model.

    The surrogate is ``nn.Sequential(ae, basemodel)``, i.e. ``basemodel``
    applied to ``ae``'s output. Perturbations come from
    ``attack(surrogate, X, y, **kwargs)``. ``optnet`` is then evaluated on
    ``X + delta`` without autograd. All three models are put in eval mode,
    and no parameters are updated.

    Args:
        loader: iterable of ``(X, y)`` batches; ``loader.dataset`` is read
            only to size the progress bar when ``use_tqdm`` is set and
            ``n_test`` is None.
        basemodel: second stage of the surrogate.
        ae: first stage of the surrogate.
        optnet: model evaluated on the perturbed inputs.
        attack: callable ``attack(model, X, y, **kwargs)`` returning a
            perturbation added to ``X``.
        device: device every batch is moved to (the models are not moved
            here).
        use_tqdm: show a per-sample progress bar.
        n_test: stop after at least this many samples (whole batches are
            counted, so the last batch may overshoot).
        **kwargs: forwarded unchanged to ``attack``.

    Returns:
        ``(error_rate, mean_loss)`` of ``optnet`` on the perturbed inputs
        over the samples processed.
    """
    basemodel.eval()
    ae.eval()
    optnet.eval()

    surrogate_model = nn.Sequential(ae, basemodel)
    surrogate_model.eval()

    criterion = nn.CrossEntropyLoss()
    total_loss, total_err = 0., 0.
    total_n = 0

    if use_tqdm:
        pbar = tqdm(total=len(loader.dataset) if n_test is None else n_test)
    else:
        pbar = None
    try:
        for X, y in loader:
            X, y = X.to(device), y.to(device)

            # The attack may need gradients through the surrogate; keep it
            # outside no_grad.
            delta = attack(surrogate_model, X, y, **kwargs)

            # Evaluation only: no autograd graph needed.
            with torch.no_grad():
                yp = optnet(X + delta)
                loss = criterion(yp, y)

            total_err += (yp.max(dim=1)[1] != y).sum().item()
            total_loss += loss.item() * X.shape[0]
            total_n += X.shape[0]
            if pbar is not None:
                pbar.update(X.shape[0])

            if n_test is not None and total_n >= n_test:
                break
    finally:
        if pbar is not None:
            pbar.close()

    return total_err / total_n, total_loss / total_n