Exemplo n.º 1
0
def train(
        trainloader,
        testloader,
        device,
        seed,
        debias_=True,
        specific=None,
        ratio=0.5,  # bias ratio in dataset
        n_epochs=5,
        model_lr=1e-3,
        n2v_lr=1e-3,
        combined_n2v_lr=1e-3,  # metalearning rate for n2v
        alpha=100,  # for debias,
        beta=0.1,  # for adversarial loss
        out_file=None,
        base_folder="",
        results_folder="",
        experiment="sgd",
        momentum=0,
        module="layer4",
        finetuned=False,
        adversarial=False,
        nonlinear=False,
        subset=False,
        subset_ratio=0.1,
        save_every=False,
        model_momentum=0,
        n2v_momentum=0,
        experimental=False,
        multiple=False,
        debias_multiple=False,
        reset=False,
        reset_counter=1,
        n2v_start=False,
        experiment2=None,
        adaptive_alpha=False,
        n2v_adam=False,
        single=False,
        imagenet=False,
        train_batch_size=64,
        constant_resize=False,
        adaptive_resize=False,
        no_class=False,
        gamma=0,
        partial_projection=False,
        norm='l2',
        constant_alpha=False,
        jump_alpha=False,
        linear_alpha=False,
        mean_debias=False,
        no_limit=False,
        dataset='bam',
        parallel=False,
        gpu_ids=None,
        switch_modes=True):
    """Jointly train a ResNet classifier and a net2vec probe on `module`.

    Every batch runs two phases: (1) a probe update on the probe loss, and
    (2) a classifier update on the downstream loss, optionally augmented by
    either a debias term (differentiated through one unrolled probe step via
    `higher`) or an adversarial term computed from the probe's last two
    outputs.

    Side effects: pickles the `results` dict to the results path at the start
    of every epoch and once more after training, and saves the final probe and
    model state dicts. With `save_every`, intermediate checkpoints are written
    every 10th epoch (only when `seed == 0` and `debias_` is set).

    Args:
        trainloader: Loader yielding `(X, y, genders)` batches. Its `dataset`
            may expose `idx_to_class` and, for `dataset == 'coco'`, must
            expose `getObjectWeights()` / `getGenderWeights()`.
        testloader: Loader handed to the evaluation helpers in `utils`.
        device: Torch device for model, probe and batches.
        seed: Used as a path component for init/checkpoint/result files.
        debias_: Enable the debias regulariser (Part 2a).
        specific: Class name(s) used for path lookup and to pick
            `specific_idx` for the balanced loss.
        ratio: Bias ratio; only used to build file names here.
        n_epochs: Number of training epochs.
        model_lr, model_momentum: SGD settings for the classifier.
        n2v_lr, n2v_momentum: Optimiser settings for the probe.
        combined_n2v_lr: Probe learning rate inside the unrolled step.
        alpha: Initial weight of the debias loss; rescheduled per epoch
            according to `adaptive_alpha` / `constant_alpha` / `jump_alpha` /
            `linear_alpha`.
        beta: Weight of the (negated) adversarial loss.
        base_folder, results_folder: Roots for checkpoints and result pickles.
        experiment, experiment2: Experiment path components; `experiment2`,
            when given, overrides `experiment` for output paths only.
        module: Name of the probed layer.
        finetuned, n2v_start: Select which pre-trained checkpoints to start
            from instead of fresh inits.
        adversarial: Enable the adversarial loss (Part 2b). Must not be
            combined with `debias_` when `finetuned` is set.
        nonlinear: Use the non-linear (MLP) probe variant.
        save_every: Write intermediate checkpoints (see above).
        multiple: Use the 39-attribute probe layout and report ten bias values.
        reset, reset_counter: Re-create the probe every `reset_counter` epochs.
        n2v_adam: Use Adam instead of SGD for the probe optimisers.
        single: Use only probe row -2 as the bias direction instead of the
            difference of rows -2 and -1.
        imagenet, train_batch_size, constant_resize, adaptive_resize: When
            `imagenet` is set, ImageNet loaders are constructed with these
            settings (they are not consumed further in this function).
        gamma: Passed as `t` to `load_models` and to `debias_loss`.
        partial_projection: Forwarded to `load_models`.
        norm, mean_debias: Forwarded to `debias_loss`.
        no_limit: Disable the gradient-norm guard on the debias gradient.
        dataset: `'bam'`, `'coco'`, or anything else (plain path layout).
        switch_modes: Forwarded to the probe forward pass / `ProbedModel`.
        out_file, momentum, subset, subset_ratio, experimental,
        debias_multiple, no_class, parallel, gpu_ids: Accepted for interface
            compatibility; not used by this function (`momentum` is only
            printed).
    """
    print("mu", momentum, "debias", debias_, "alpha", alpha, " | ratio:",
          ratio)

    def get_vg(W):
        # Bias direction taken from the last probe rows.
        if single:
            return W[-2, :]
        else:
            return W[-2, :] - W[-1, :]

    # ---- resolve which checkpoints to initialise from ----
    if dataset == 'bam' or dataset == 'coco':
        model_init_path, n2v_init_path = utils.get_paths(
            base_folder,
            seed,
            specific,
            model_end="resnet_init" + '.pt',
            n2v_end="resnet_n2v_init" + '.pt',
            n2v_module=module,
            experiment=experiment,
            with_n2v=False)
    else:
        model_init_path = os.path.join(base_folder, str(seed), experiment,
                                       'resnet_init.pt')
        n2v_init_path = os.path.join(base_folder, str(seed), experiment,
                                     module, 'resnet_n2v_init.pt')
    if finetuned:
        if dataset == 'bam' or dataset == 'coco':
            model_init_path = utils.get_model_path(
                base_folder,
                seed,
                specific,
                "resnet_" + str(ratio) + ".pt",
                experiment='post_train'
                if not n2v_start else experiment.split('_finetuned')[0])
        else:
            model_init_path = os.path.join(
                base_folder, str(seed), 'post_train' if not n2v_start else
                experiment.split('_finetuned')[0], 'resnet.pt')
        # debias and adversarial are mutually exclusive when finetuning
        assert (debias_ and not adversarial) or (
            adversarial and not debias_) or (not adversarial and not debias_)
        if debias_ and n2v_start:
            ext = "_n2v_" if not nonlinear else "_mlp_"
            if dataset == 'bam' or dataset == 'coco':
                n2v_init_path = utils.get_net2vec_path(
                    base_folder,
                    seed,
                    specific,
                    module,
                    "resnet" + str(ext) + str(ratio) + ".pt",
                    experiment=experiment.split('_finetuned')[0])
            else:
                n2v_init_path = os.path.join(base_folder, str(seed),
                                             experiment.split('_finetuned')[0],
                                             module,
                                             'resnet' + ext[:-1] + '.pt')
        # if we're also doing adversarial, make sure to load the matching n2v as init...
        if adversarial:
            ext = "_n2v_" if not nonlinear else "_mlp_"
            if dataset == 'bam' or dataset == 'coco':
                n2v_init_path = utils.get_net2vec_path(base_folder,
                                                       seed,
                                                       specific,
                                                       module,
                                                       "resnet" + str(ext) +
                                                       str(ratio) + ".pt",
                                                       experiment='post_train')
            else:
                n2v_init_path = os.path.join(base_folder, str(seed),
                                             'post_train', module,
                                             'resnet' + ext[:-1] + '.pt')

    # ---- output sizes of classifier and probe ----
    num_classes = 10
    num_attributes = 12
    if nonlinear:
        num_attributes = 2
    if multiple:
        num_attributes = 10 + 9 + 2 * 10
    if dataset == 'coco':
        num_classes = 79
        num_attributes = 81
    model, net, net_forward, activation_probe = models.load_models(
        device,
        lambda x, y, z: models.resnet_(pretrained=True,
                                       custom_path=x,
                                       device=y,
                                       initialize=z,
                                       num_classes=num_classes,
                                       size=50 if (dataset == 'bam' or dataset
                                                   == 'coco') else 34),
        model_path=model_init_path,
        net2vec_pretrained=True,
        net2vec_path=n2v_init_path,
        module=module,
        num_attributes=num_attributes,
        # we want to make sure to save the inits if not finetuned...
        model_init=True if not finetuned else False,
        n2v_init=True if not (finetuned and
                              (adversarial or
                               (debias_ and n2v_start))) else False,
        loader=trainloader,
        nonlinear=nonlinear,
        # parameters if we want to initially project probes to have a certain amount of bias
        partial_projection=partial_projection,
        t=gamma)
    print(model_init_path, n2v_init_path)
    model_n2v_combined = models.ProbedModel(model,
                                            net,
                                            module,
                                            switch_modes=switch_modes)

    # ---- optimisers ----
    # combined_optim drives the unrolled (differentiable) probe step,
    # n2v_optim the real probe update, model_optim the classifier update.
    if n2v_adam:
        combined_optim = torch.optim.Adam(
            [{
                'params': model_n2v_combined.model.parameters()
            }, {
                'params': model_n2v_combined.net.parameters()
            }],
            lr=n2v_lr)
        # TODO: allow for momentum training as well
        n2v_optim = torch.optim.Adam(net.parameters(), lr=n2v_lr)
    else:
        combined_optim = torch.optim.SGD(
            [{
                'params': model_n2v_combined.model.parameters()
            }, {
                'params': model_n2v_combined.net.parameters(),
                'lr': combined_n2v_lr,
                'momentum': n2v_momentum
            }],
            lr=model_lr,
            momentum=model_momentum)

        # TODO: allow for momentum training as well
        n2v_optim = torch.optim.SGD(net.parameters(),
                                    lr=n2v_lr,
                                    momentum=n2v_momentum)
    model_optim = torch.optim.SGD(model.parameters(),
                                  lr=model_lr,
                                  momentum=model_momentum)

    # ---- result logs (the lists are shared with `results` and pickled) ----
    d_losses = []
    adv_losses = []
    n2v_train_losses = []
    n2v_accs = []
    n2v_val_losses = []
    class_train_losses = []
    class_accs = []
    class_val_losses = []
    alpha_log = []
    magnitudes = []
    magnitudes2 = []
    unreduced = []
    bias_grads = []
    loss_shapes = []
    loss_shapes2 = []

    results = {
        "debias_losses": d_losses,
        "n2v_train_losses": n2v_train_losses,
        "n2v_val_losses": n2v_val_losses,
        "n2v_accs": n2v_accs,
        "class_train_losses": class_train_losses,
        "class_val_losses": class_val_losses,
        "class_accs": class_accs,
        "adv_losses": adv_losses,
        "alphas": alpha_log,
        "magnitudes": magnitudes,
        "magnitudes2": magnitudes2,
        "unreduced": unreduced,
        "bias_grads": bias_grads,
        "loss_shapes": loss_shapes,
        "loss_shapes2": loss_shapes2
    }

    # ---- output paths ----
    if debias_:
        results_end = str(ratio) + "_debias.pck"
    elif adversarial:
        results_end = str(ratio) + "_adv.pck"
        if nonlinear:
            results_end = str(ratio) + "_mlp_adv.pck"
    else:
        results_end = str(ratio) + "_base.pck"

    if dataset == 'bam' or dataset == 'coco':
        results_path = utils.get_net2vec_path(
            results_folder, seed, specific, module, results_end,
            experiment if experiment2 is None else experiment2)
    else:
        results_path = os.path.join(
            results_folder, str(seed),
            experiment if experiment2 is None else experiment2, module,
            results_end)
    if debias_:
        model_end = "resnet_debias_" + str(ratio) + '.pt'
        n2v_end = "resnet_n2v_debias_" + str(ratio) + '.pt'
    elif adversarial:
        if not nonlinear:
            model_end = "resnet_adv_" + str(ratio) + '.pt'
        else:
            model_end = "resnet_adv_nonlinear_" + str(ratio) + '.pt'
        if not nonlinear:
            n2v_end = "resnet_n2v_adv_" + str(ratio) + '.pt'
        else:
            n2v_end = "resnet_mlp_adv_" + str(ratio) + '.pt'
    else:
        model_end = "resnet_base_" + str(ratio) + '.pt'
        n2v_end = "resnet_n2v_base_" + str(ratio) + '.pt'

    # other datasets do not encode the ratio in the file name
    if dataset != 'bam' and dataset != 'coco':
        model_end = model_end.replace('_' + str(ratio), '')
        n2v_end = n2v_end.replace('_' + str(ratio), '')

    if dataset == 'bam' or dataset == 'coco':
        model_path, n2v_path = utils.get_paths(
            base_folder,
            seed,
            specific,
            model_end=model_end,
            n2v_end=n2v_end,
            n2v_module=module,
            experiment=experiment if experiment2 is None else experiment2,
            with_n2v=True,
        )
    else:
        model_path = os.path.join(
            base_folder, str(seed),
            experiment if experiment2 is None else experiment2, module,
            model_end)
        n2v_path = os.path.join(
            base_folder, str(seed),
            experiment if experiment2 is None else experiment2, module,
            n2v_end)

    # Index of the class named in `specific`; defaults to 0 when there is no
    # match or the dataset has no `idx_to_class`. Previously a non-matching
    # key reset the index to 0 on every iteration (so only the last key
    # counted) and the name was unbound without `idx_to_class`.
    specific_idx = 0
    if specific is not None and hasattr(trainloader.dataset, 'idx_to_class'):
        idx_to_class = trainloader.dataset.idx_to_class
        for key in idx_to_class:
            if idx_to_class[key] in specific:
                specific_idx = int(key)
    # the non-linear probe is only trained on the last two labels
    train_labels = None if not nonlinear else [-2, -1]
    d_last = 0
    resize = constant_resize or adaptive_resize
    if imagenet:
        imagenet_trainloaders, _ = dataload.get_imagenet_tz(
            './datasets/imagenet',
            workers=8,
            train_batch_size=train_batch_size // 8,
            resize=resize,
            constant=constant_resize)
        imagenet_trainloader = dataload.process_imagenet_loaders(
            imagenet_trainloaders)

    # everything except the last two parameter tensors of the combined model;
    # these receive the extra debias gradient below
    # (presumably the last two are the probe's weight and bias; verify)
    params = list(model_n2v_combined.parameters())[:-2]
    init_alpha = alpha
    last_e = 0

    # setup training criteria
    if dataset == 'coco':
        object_weights = torch.FloatTensor(
            trainloader.dataset.getObjectWeights())
        gender_weights = torch.FloatTensor(
            trainloader.dataset.getGenderWeights())
        all_weights = torch.cat([object_weights, gender_weights])
        # 'mean' is the current spelling of the deprecated 'elementwise_mean'
        probe_criterion = nn.BCEWithLogitsLoss(weight=all_weights.to(device),
                                               reduction='mean')
        downstream_criterion = nn.BCEWithLogitsLoss(
            weight=object_weights.to(device), reduction='mean')
    else:
        probe_criterion = None
        downstream_criterion = nn.CrossEntropyLoss()

    for e in range(n_epochs):
        # save results every epoch...
        with open(results_path, 'wb') as f:
            print("saving results", e)
            print(results_path)
            pickle.dump(results, f)

        model.eval()

        # ---- evaluation at the start of the epoch ----
        with torch.no_grad():
            n2v_acc, n2v_val_loss = utils.net2vec_accuracy(
                testloader, net_forward, device, train_labels)
            n2v_accs.append(n2v_acc)
            n2v_val_losses.append(n2v_val_loss)

            if dataset != 'coco':
                class_acc, class_val_loss = utils.classification_accuracy(
                    testloader, model, device)
                class_accs.append(class_acc)
                class_val_losses.append(class_val_loss)
            else:
                f1, mAP = utils.detection_results(testloader, model, device)
                print("Epoch", e, "| f1:", f1, "| mAP:", mAP)
                class_accs.append([f1, mAP])

        # ---- measure the current bias of the probe weights ----
        d_initial = 0
        if not adversarial:
            curr_W = net.weight.data.clone()
            if not multiple:
                vg = get_vg(curr_W).reshape(-1, 1)
                d_initial = debias.debias_loss(curr_W[:-2], vg, t=0).item()
                print("Epoch", e, "bias", str(d_initial), " | debias: ",
                      debias_)
            else:
                ds = np.zeros(10)
                for i in range(10):
                    if i == 0:
                        vg = (curr_W[10, :] - curr_W[11, :]).reshape(-1, 1)
                    else:
                        vg = (curr_W[20 + i, :] - curr_W[29 + i, :]).reshape(
                            -1, 1)
                    ds[i] = debias.debias_loss(curr_W[:10], vg, t=0).item()
                print("Epoch", e, "bias", ds, " | debias: ", debias_)
                print("Accuracies:", n2v_acc)
                d_initial = ds[0]
        else:
            print("Epoch", e, "Adversarial", n2v_accs[-1])

        # ---- alpha schedule ----
        # NOTE(review): for e > 0 the ratio d_last / d_initial raises
        # ZeroDivisionError when d_initial is 0 (always the case when
        # `adversarial` is set) -- confirm adaptive_alpha is never combined
        # with adversarial training.
        if adaptive_alpha and (e == 0 or ((d_last / d_initial) >=
                                          (5 / 2**(e - 1)) or
                                          (0.8 < (d_last / d_initial) < 1.2))):
            #alpha = alpha
            old_alpha = alpha
            # we don't want to increase too much if it's already decreasing
            if (e == 0 or (d_last / d_initial) >= (5 / 2**(e - 1))):
                alpha = min(
                    alpha * 2, (15 / (2**e)) / (d_initial + 1e-10)
                )  # numerical stability just in case d_initial gets really low
                #if e > 0 and old_alpha >= alpha:
                #    alpha = old_alpha # don't update if we're decreasing...
                print("Option 1")
            if e > 0 and alpha < old_alpha:
                # we want to increase if plateaud
                alpha = max(old_alpha * 1.5, alpha)
                print("Option 2")
            # don't want to go over 1000...
            if alpha > 1000:
                alpha = 1000
            d_last = d_initial
        elif not adaptive_alpha and not constant_alpha:
            if dataset == 'coco' and jump_alpha:
                if e < 2:
                    alpha = 5e3
                elif e >= 2 and e < 4:
                    alpha = 1e4
                else:
                    alpha = init_alpha
            elif jump_alpha and (e - last_e) > 2:
                if not mean_debias:
                    if alpha < 100:
                        alpha = min(alpha * 2, 100)
                        last_e = e
                    else:
                        # two jumps
                        # if (e-last_e) >= ((n_epochs - last_e) // 2):
                        #     alpha = 1000
                        # else:
                        alpha = 1000
                else:
                    if alpha < 1000:
                        alpha = min(alpha * 2, 1000)
                        last_e = e
                    else:
                        alpha = 10000
            elif linear_alpha and (e - last_e) > 2:
                if alpha < 100:
                    alpha = min(alpha * 2, 100)
                    last_e = e
                else:
                    alpha += (1000 - 100) / (n_epochs - last_e)
            elif not jump_alpha and not linear_alpha:
                if (e + 1) % 3 == 0:
                    # apply alpha schedule?
                    # alpha = min(alpha * 1.2, max(init_alpha,1000))
                    alpha = alpha * 1.5
        alpha_log.append(alpha)
        print("Current Alpha:,", alpha)
        if save_every and e % 10 == 0 and e > 0 and seed == 0 and debias_:
            torch.save(net.state_dict(),
                       n2v_path.split('.pt')[0] + '_' + str(e) + '.pt')
            torch.save(model.state_dict(),
                       model_path.split('.pt')[0] + '_' + str(e) + '.pt')
        if reset and (e + 1) % reset_counter == 0 and e > 0:
            print("resetting")
            # NOTE(review): only `net`/`n2v_optim` are rebuilt here;
            # `model_n2v_combined`, `combined_optim` and `params` still refer
            # to the previous probe, and the new optimiser is SGD even when
            # `n2v_adam` is set -- confirm this is intended.
            net, net_forward, activation_probe = net2vec.create_net2vec(
                model,
                module,
                num_attributes,
                device,
                pretrained=False,
                initialize=True,
                nonlinear=nonlinear)
            n2v_optim = torch.optim.SGD(net.parameters(),
                                        lr=n2v_lr,
                                        momentum=n2v_momentum)

        model.train()
        for X, y, genders in trainloader:
            ids = None
            ##### Part 1: Update the Embeddings #####
            model_optim.zero_grad()
            n2v_optim.zero_grad()
            labels = utils.merge_labels(y, genders, device)
            logits = net_forward(X.to(device), switch_modes=switch_modes)
            # Now actually update net2vec embeddings, making sure to use the same batch
            if train_labels is not None:
                if logits.shape[1] == labels.shape[1]:
                    logits = logits[:, train_labels]
                labels = labels[:, train_labels]
            shapes = []
            shapes2 = []
            if dataset == 'coco':
                prelim_loss = probe_criterion(logits, labels)
            else:
                prelim_loss, ids = utils.balanced_loss(logits,
                                                       labels,
                                                       device,
                                                       0.5,
                                                       ids=ids,
                                                       multiple=multiple,
                                                       specific=specific_idx,
                                                       shapes=shapes)
            #print("prelim_loss:", prelim_loss.item())
            prelim_loss.backward()
            # we don't want to update these parameters, just in case
            model_optim.zero_grad()
            n2v_train_losses.append(prelim_loss.item())
            n2v_optim.step()
            try:
                magnitudes.append(
                    torch.norm(net.weight.data, dim=1).data.cpu().numpy())
            except Exception:
                # best-effort logging only: the probe may not expose a 2-D
                # `weight` (e.g. the non-linear variant); never abort
                # training because of it
                pass

            ##### Part 2: Update Conv parameters for classification #####
            model_optim.zero_grad()
            n2v_optim.zero_grad()
            class_logits = model(X.to(device))
            class_loss = downstream_criterion(class_logits, y.to(device))
            class_train_losses.append(class_loss.item())

            if debias_:
                # log the per-row bias of the current probe weights
                W_curr = net.weight.data
                vg = get_vg(W_curr).reshape(-1, 1)
                unreduced.append(
                    debias.debias_loss(W_curr[:-2], vg, t=0,
                                       unreduced=True).data.cpu().numpy())

            loss = class_loss
            #### Part 2a: Debias Loss
            if debias_:
                model_optim.zero_grad()
                n2v_optim.zero_grad()

                labels = utils.merge_labels(y, genders, device)
                combined_optim.zero_grad()
                # Take one differentiable probe step, then differentiate the
                # debias loss of the *updated* probe weights with respect to
                # the pre-step parameters.
                with higher.innerloop_ctx(model_n2v_combined,
                                          combined_optim) as (fn2v,
                                                              diffopt_n2v):
                    models.update_probe(fn2v)
                    logits = fn2v(X.to(device))
                    if dataset == 'coco':
                        prelim_loss = probe_criterion(logits, labels)
                    else:
                        prelim_loss, ids = utils.balanced_loss(
                            logits,
                            labels,
                            device,
                            0.5,
                            ids=ids,
                            multiple=False,
                            specific=specific_idx,
                            shapes=shapes2)
                    diffopt_n2v.step(prelim_loss)
                    weights = list(fn2v.parameters())[-2]
                    vg = get_vg(weights).reshape(-1, 1)
                    d_loss = debias.debias_loss(weights[:-2],
                                                vg,
                                                t=gamma,
                                                norm=norm,
                                                mean=mean_debias)
                    # only want to save the actual bias...
                    d_losses.append(d_loss.item())
                    grad_of_grads = torch.autograd.grad(
                        alpha * d_loss,
                        list(fn2v.parameters(time=0))[:-2],
                        allow_unused=True)

                    # drop references to the unrolled graph promptly
                    del prelim_loss
                    del logits
                    del vg
                    del fn2v
                    del diffopt_n2v
            #### Part 2b: Adversarial Loss
            if adversarial:
                logits = net_forward(
                    None, forward=True)[:, -2:]  # just use activation probe
                labels = genders.type(torch.FloatTensor).reshape(
                    genders.shape[0], -1).to(device)
                adv_loss, _ = utils.balanced_loss(logits,
                                                  labels,
                                                  device,
                                                  0.5,
                                                  ids=ids,
                                                  stable=True)
                adv_losses.append(adv_loss.item())
                # getting too strong, let it retrain...
                if adv_loss < 2:
                    adv_loss = -beta * adv_loss
                    loss += adv_loss
            loss.backward()
            if debias_:
                # custom backward to include the bias regularization....
                max_norm_grad = -1
                param_idx = -1
                for ii in range(len(grad_of_grads)):
                    if (grad_of_grads[ii] is not None
                            and params[ii].grad is not None
                            and torch.isnan(grad_of_grads[ii]).long().sum() <
                            grad_of_grads[ii].reshape(-1).shape[0]):
                        # just in case some or nan for some reason?
                        not_nan = ~torch.isnan(grad_of_grads[ii])
                        params[ii].grad[not_nan] += grad_of_grads[ii][not_nan]
                        if grad_of_grads[ii][not_nan].norm().item(
                        ) > max_norm_grad:
                            max_norm_grad = grad_of_grads[ii][not_nan].norm(
                            ).item()
                            param_idx = ii
                bias_grads.append((param_idx, max_norm_grad))
                # undo the last step to prevent stability issues: if the
                # largest debias gradient norm exceeds 100, the debias
                # contribution added above is subtracted again
                if not no_limit and max_norm_grad > 100:
                    for ii in range(len(grad_of_grads)):
                        if (grad_of_grads[ii] is not None
                                and params[ii].grad is not None and
                                torch.isnan(grad_of_grads[ii]).long().sum() <
                                grad_of_grads[ii].reshape(-1).shape[0]):
                            # just in case some or nan for some reason?
                            not_nan = ~torch.isnan(grad_of_grads[ii])
                            params[ii].grad[not_nan] -= grad_of_grads[ii][
                                not_nan]
                            # scale accordingly
                            # params[ii].grad[not_nan] += grad_of_grads[ii][not_nan] / max_norm_grad

            loss_shapes.append(shapes)
            loss_shapes2.append(shapes2)
            model_optim.step()
            #magnitudes2.append(
            #    torch.norm(net.weight.data, dim=1).data.cpu().numpy()
            #)

    # final save of the results and the trained weights
    with open(results_path, 'wb') as f:
        # n_epochs - 1 equals the last loop index, and stays defined even
        # when n_epochs == 0 (the loop variable would be unbound then)
        print("saving results", n_epochs - 1)
        print(results_path)
        pickle.dump(results, f)
    torch.save(net.state_dict(), n2v_path)
    torch.save(model.state_dict(), model_path)
Exemplo n.º 2
0
def train_leakage(trainloader,
                  testloader,
                  device,
                  seed,
                  specific=None,
                  p=0.5,
                  n_epochs=5,
                  module='layer4',
                  lr=0.5,
                  base_folder="",
                  out_file=None,
                  experiment1="",
                  experiment2="",
                  model_extra="",
                  n2v_extra="",
                  with_n2v=False,
                  nonlinear=False,
                  model_custom_end='',
                  n2v_custom_end='',
                  multiple=False,
                  dataset='bam',
                  parallel=False,
                  gpu_ids=None):
    """Train a leakage probe on the output logits of a trained model.

    Loads the classifier stored under `experiment1`, attaches a fresh
    two-attribute probe to its `'fc'` module and trains it with
    `net2vec.train_net2vec` using cross-entropy against column 1 of the
    gender labels. The probe is saved below the `leakage` sub-folder of
    `experiment2`/`module`, which is created if missing.

    Args:
        trainloader, testloader: Data loaders forwarded to the probe trainer;
            `trainloader.dataset.idx_to_class` (if present) is used to resolve
            `specific` to an index.
        device: Torch device.
        seed: Path component for model / probe files.
        specific: Class name(s) used for path lookup and `specific_idx`.
        p: Bias ratio; used in file names for `'bam'`/`'coco'` and in the log
            line.
        n_epochs: Number of probe-training epochs.
        module: Path component for the probe location (the probe itself is
            always attached to `'fc'`).
        lr: Probe learning rate.
        base_folder: Root folder for checkpoints.
        out_file: Optional log file opened in append mode; stdout if None.
        experiment1: Experiment folder of the classifier to load.
        experiment2: Experiment folder to store the leakage probe in.
        model_extra, model_custom_end: Infix / suffix of the model file name.
        n2v_extra, n2v_custom_end: Infix / suffix of the probe file name.
        with_n2v: Whether the model file lives in the per-module sub-folder.
        nonlinear: Use the non-linear (MLP) probe variant.
        multiple: Accepted for interface compatibility; unused here.
        dataset: `'bam'`, `'coco'`, or anything else (plain path layout).
        parallel, gpu_ids: Forwarded to `models.load_models`; `gpu_ids=None`
            is treated as an empty list.
    """
    # avoid a shared mutable default argument
    if gpu_ids is None:
        gpu_ids = []
    f = open(out_file, 'a') if out_file is not None else None
    # try/finally guarantees the log file is closed even if training raises
    try:
        print("Training Model Leakage | p =", p, file=f)
        if not nonlinear:
            n2v_extra = "n2v" + str(n2v_extra)
        else:
            n2v_extra = "mlp" + str(n2v_extra)
        if len(model_custom_end) > 0:
            model_custom_end = "_" + model_custom_end
        if len(n2v_custom_end) > 0:
            n2v_custom_end = "_" + n2v_custom_end

        # Index of the class named in `specific`; 0 when there is no match.
        # Previously a non-matching key reset the index to 0 on every
        # iteration, so only the last key of `idx_to_class` counted.
        specific_idx = 0
        if specific is not None and hasattr(trainloader.dataset,
                                            'idx_to_class'):
            idx_to_class = trainloader.dataset.idx_to_class
            for key in idx_to_class:
                if idx_to_class[key] in specific:
                    specific_idx = int(key)

        # ---- locate the trained model and the probe output file ----
        if dataset == 'bam' or dataset == 'coco':
            model_path = utils.get_model_path(
                base_folder,
                seed,
                specific,
                "resnet" + str(model_extra) + "_" + str(p) +
                model_custom_end + ".pt",
                experiment=experiment1,
                with_n2v=with_n2v,
                n2v_module=module)
            n2v_path = utils.get_net2vec_path(
                base_folder,
                seed,
                specific,
                module,
                "leakage/resnet_" + str(n2v_extra) + "_" + str(p) +
                n2v_custom_end + ".pt",
                experiment=experiment2,
            )
        else:
            if with_n2v:
                model_path = os.path.join(
                    base_folder, str(seed), experiment1, module,
                    "resnet" + str(model_extra) + model_custom_end + ".pt")
            else:
                model_path = os.path.join(
                    base_folder, str(seed), experiment1,
                    "resnet" + str(model_extra) + model_custom_end + ".pt")
            n2v_path = os.path.join(
                base_folder, str(seed), experiment2, module,
                'leakage/resnet_' + str(n2v_extra) + n2v_custom_end + ".pt")

        if dataset == 'bam':
            # NOTE(review): with specific=None, folder_name is None and
            # os.path.join raises TypeError -- presumably 'bam' is always
            # run with `specific` set; verify against callers.
            if specific is not None and not isinstance(specific, str):
                folder_name = '.'.join(sorted(specific))
            else:
                folder_name = specific
            leakage_folder = os.path.join(str(base_folder), str(seed),
                                          folder_name, str(experiment2),
                                          str(module), 'leakage')
        else:
            leakage_folder = os.path.join(str(base_folder), str(seed),
                                          str(experiment2), str(module),
                                          'leakage')
        # create missing parents too and tolerate an existing directory
        os.makedirs(leakage_folder, exist_ok=True)

        num_classes = 10
        if dataset == 'coco':
            num_classes = 79
        num_attributes = 2
        model, net, net_forward, activation_probe = models.load_models(
            device,
            None if (dataset == 'coco') and ('adv' in model_extra) else lambda
            x, y, z: models.resnet_(pretrained=True,
                                    custom_path=x,
                                    device=y,
                                    num_classes=num_classes,
                                    initialize=z,
                                    size=50 if (dataset == 'bam' or dataset ==
                                                'coco') else 34),
            model_path=model_path,
            net2vec_pretrained=False,
            module='fc',  # leakage will come from the output logits...
            num_attributes=num_attributes,
            model_init=False,  # don't need to initialize a new one
            n2v_init=True,
            loader=trainloader,
            nonlinear=nonlinear,
            parallel=parallel,
            gpu_ids=gpu_ids)

        def criterion(logits, genders):
            # column 1 of the gender labels is used as the class index;
            # 'mean' is the current spelling of the deprecated
            # 'elementwise_mean'
            return F.cross_entropy(logits,
                                   genders[:, 1].long(),
                                   reduction='mean')

        net2vec.train_net2vec(model,
                              net,
                              net_forward,
                              n_epochs,
                              trainloader,
                              testloader,
                              device,
                              lr=lr,
                              save_path=n2v_path,
                              f=f,
                              train_labels=[-2, -1],
                              balanced=False,
                              criterion=criterion,
                              specific=specific_idx,
                              adam=True,
                              save_best=True,
                              leakage=True)
    finally:
        if f is not None:
            f.close()
Exemplo n.º 3
0
def train_net2vec(trainloader,
                  testloader,
                  device,
                  seed,
                  specific=None,
                  p=0.5,
                  n_epochs=5,
                  module='layer4',
                  lr=0.5,
                  base_folder="",
                  out_file=None,
                  experiment1="",
                  experiment2="",
                  model_extra="",
                  n2v_extra="",
                  with_n2v=False,
                  nonlinear=False,
                  model_custom_end='',
                  n2v_custom_end='',
                  multiple=False,
                  dataset='bam',
                  parallel=False,
                  gpu_ids=[]):
    """Train a net2vec probe on top of an already-saved ResNet checkpoint.

    Resolves the checkpoint path of the trained model and the path the probe
    will be saved to, loads the model together with a freshly initialised
    probe through ``models.load_models`` and hands both to
    ``net2vec.train_net2vec``.

    Args:
        trainloader, testloader: data loaders forwarded to the probe trainer.
            For ``dataset == 'coco'`` the training dataset must provide
            ``getObjectWeights()`` and ``getGenderWeights()``.
        device: device forwarded to model loading and training.
        seed: seed used as a component of the checkpoint paths.
        specific: forwarded to the ``utils`` path helpers (bam / coco only).
        p: ratio embedded in the bam / coco checkpoint file names.
        n_epochs: number of probe training epochs.
        module: name of the layer the probe is attached to.
        lr: probe learning rate.
        base_folder: root folder of all checkpoints.
        out_file: optional log file, opened in append mode; ``None`` logs to
            stdout.
        experiment1: experiment folder of the trained model.
        experiment2: experiment folder the probe is written to.
        model_extra: infix of the model checkpoint name (e.g. ``'_adv'``).
        n2v_extra: infix of the probe checkpoint name; it is prefixed with
            ``'mlp'`` when ``nonlinear`` is set and ``'n2v'`` otherwise.
        with_n2v: selects the model path layout that includes ``module``.
        nonlinear: forwarded to ``models.load_models``; also restricts the
            probe to two attributes and to ``train_labels=[-2, -1]``.
        model_custom_end, n2v_custom_end: optional suffixes appended (after
            an underscore) to the respective checkpoint names.
        multiple: switches to the larger attribute count and is forwarded to
            the probe trainer.
        dataset: ``'bam'``, ``'coco'`` or anything else (plain folder layout
            and ResNet-34 instead of ResNet-50).
        parallel, gpu_ids: forwarded to ``models.load_models``.
    """
    f = open(out_file, 'a') if out_file is not None else None
    # try/finally so the log handle is released even when loading or
    # training raises.
    try:
        print("Training N2V | p =", p, file=f)
        n2v_extra = ("mlp" if nonlinear else "n2v") + str(n2v_extra)
        if model_custom_end:
            model_custom_end = "_" + model_custom_end
        if n2v_custom_end:
            n2v_custom_end = "_" + n2v_custom_end

        if dataset == 'bam' or dataset == 'coco':
            # bam / coco checkpoints carry the ratio in their file name and
            # are located through the shared path helpers.
            model_path = utils.get_model_path(
                base_folder,
                seed,
                specific,
                "resnet" + str(model_extra) + "_" + str(p) +
                model_custom_end + ".pt",
                experiment=experiment1,
                with_n2v=with_n2v,
                n2v_module=module)
            n2v_path = utils.get_net2vec_path(
                base_folder,
                seed,
                specific,
                module,
                "resnet_" + str(n2v_extra) + "_" + str(p) + n2v_custom_end +
                ".pt",
                experiment=experiment2,
            )
        else:
            model_name = ("resnet" + str(model_extra) + model_custom_end +
                          ".pt")
            if with_n2v:
                model_path = os.path.join(base_folder, str(seed), experiment1,
                                          module, model_name)
            else:
                model_path = os.path.join(base_folder, str(seed), experiment1,
                                          model_name)
            n2v_path = os.path.join(
                base_folder, str(seed), experiment2, module,
                'resnet_' + str(n2v_extra) + n2v_custom_end + ".pt")
        print(model_path, n2v_path)

        # Later conditions deliberately override earlier ones.
        num_attributes = 12
        if nonlinear:
            num_attributes = 2
        if multiple:
            num_attributes = 10 + 9 + 2 * 10
        num_classes = 10
        if dataset == 'coco':
            num_classes = 79
            num_attributes = 81

        model, net, net_forward, activation_probe = models.load_models(
            device,
            # load in None means we load in the pretrained weights from Tianlu
            None if (dataset == 'coco') and ('adv' in model_extra) else
            lambda x, y, z: models.resnet_(
                pretrained=True,
                custom_path=x,
                device=y,
                num_classes=num_classes,
                initialize=z,
                size=50 if (dataset == 'bam' or dataset == 'coco') else 34),
            model_path=model_path,
            net2vec_pretrained=False,
            module=module,
            num_attributes=num_attributes,
            model_init=False,  # don't need to initialize a new one
            n2v_init=True,
            loader=trainloader,
            nonlinear=nonlinear,
            parallel=parallel,
            gpu_ids=gpu_ids)

        if dataset == 'coco':
            # Weight the multi-label loss with the dataset-provided object
            # weights followed by the gender weights.
            object_weights = torch.FloatTensor(
                trainloader.dataset.getObjectWeights())
            gender_weights = torch.FloatTensor(
                trainloader.dataset.getGenderWeights())
            all_weights = torch.cat([object_weights, gender_weights])
            criterion = nn.BCEWithLogitsLoss(weight=all_weights.to(device),
                                             reduction='elementwise_mean')
        else:
            # No explicit criterion is passed for the other datasets.
            criterion = None

        net2vec.train_net2vec(
            model,
            net,
            net_forward,
            n_epochs,
            trainloader,
            testloader,
            device,
            lr=lr,
            save_path=n2v_path,
            f=f,
            train_labels=[-2, -1] if nonlinear else None,
            multiple=multiple,
            balanced=dataset != 'coco',
            criterion=criterion,
            adam=False,
            leakage=False)
    finally:
        if f is not None:
            f.close()
Exemplo n.º 4
0
def train_main(trainloader,
               testloader,
               device,
               seed,
               specific=None,
               p=0.5,
               n_epochs=5,
               lr=0.1,
               experiment="",
               out_file=None,
               base_folder="",
               dataset="bam",
               parallel=False,
               gpu_ids=[],
               linear_only=False):
    """Train the downstream ResNet classifier and save its weights.

    The model is created from ``<base_folder>/<seed>/resnet_init.pt`` and
    optimised with SGD (momentum 0.9) under a ``LambdaLR`` schedule. At the
    start of every epoch, before that epoch's training pass, the model is
    evaluated on ``testloader``: accuracy for non-COCO datasets, F1 and mAP
    for COCO. For COCO a checkpoint is also written every time the F1 score
    improves.

    Args:
        trainloader: yields ``(X, y, color)`` batches; the third item is
            ignored. For COCO its dataset must provide ``getObjectWeights()``.
        testloader: loader used for the per-epoch evaluation.
        device: device the batches and loss weights are moved to.
        seed: seed used as a component of the checkpoint paths.
        specific: forwarded to ``utils.get_model_path`` (bam / coco only).
        p: ratio embedded in the log line and the checkpoint names.
        n_epochs: number of training epochs.
        lr: initial SGD learning rate.
        experiment: experiment folder the checkpoints are written to.
        out_file: optional log file, opened in append mode; ``None`` logs to
            stdout.
        base_folder: root folder of all checkpoints.
        dataset: ``'bam'``, ``'coco'`` or anything else (plain folder layout
            and ResNet-34 instead of ResNet-50).
        parallel, gpu_ids: wrap the model in ``nn.DataParallel`` when
            ``parallel`` is set.
        linear_only: forwarded to ``models.resnet_``.
    """
    f = open(out_file, 'a') if out_file is not None else None
    # try/finally so the log handle is released even when training raises.
    try:
        print("Downstream Training | Ratio: " + str(p) + " | lr = " + str(lr),
              file=f)
        is_coco = dataset == 'coco'
        num_classes = 79 if is_coco else 10
        model = models.resnet_(
            pretrained=True,
            custom_path=os.path.join(base_folder, str(seed),
                                     "resnet_init.pt"),
            device=device,
            num_classes=num_classes,
            initialize=True,
            size=50 if (dataset == 'bam' or is_coco) else 34,
            linear_only=linear_only)
        if parallel:
            model = nn.DataParallel(model, device_ids=gpu_ids)
        optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

        def scaler(epoch):
            # COCO schedule: multiply the rate by 0.75 every 10 epochs.
            return 0.75**(epoch // 10)

        # Other datasets decay the rate by 5% per epoch.
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optim,
            lr_lambda=scaler if is_coco else (lambda epoch: 0.95**epoch))
        start = time.time()
        criterion = nn.CrossEntropyLoss()
        best_f1 = 0
        if is_coco:
            object_weights = torch.FloatTensor(
                trainloader.dataset.getObjectWeights())
            criterion = nn.BCEWithLogitsLoss(weight=object_weights.to(device),
                                             reduction='elementwise_mean')

        for e in range(n_epochs):
            # Evaluate first, so the metrics for epoch ``e`` describe the
            # model before this epoch's updates.
            if not is_coco:
                with torch.no_grad():
                    acc = utils.classification_accuracy(
                        testloader, model, device)
                print("Epoch:", e, "| acc:", acc, file=f)
            else:
                with torch.no_grad():
                    f1, mAP = utils.detection_results(
                        testloader, model, device)
                print("Epoch:",
                      e,
                      "| f1:",
                      f1,
                      '| mAP:',
                      mAP,
                      '| lr:',
                      scheduler.get_lr(),
                      file=f)
                if f1 > best_f1:
                    save_file = utils.get_model_path(
                        base_folder,
                        seed,
                        specific,
                        'resnet_best' + str(p) + '_{}_{}.pt'.format(f1, mAP),
                        experiment=experiment)
                    best_f1 = f1
                    torch.save(model.state_dict(), save_file)
            model.train()
            for X, y, _ in trainloader:
                optim.zero_grad()
                loss = criterion(model(X.to(device)), y.to(device))
                loss.backward()
                optim.step()
            scheduler.step()

        end = time.time()
        # Elapsed wall-clock seconds (end minus start, so it is positive).
        print(end - start)

        if dataset == 'bam' or is_coco:
            if is_coco:
                with torch.no_grad():
                    f1, mAP = utils.detection_results(
                        testloader, model, device)
                final_name = ('resnet_' + str(p) +
                              '_{}_{}.pt'.format(f1, mAP))
            else:
                # F1 / mAP are only computed for COCO, so they cannot be part
                # of the BAM file name.
                # NOTE(review): confirm downstream loaders expect
                # 'resnet_<p>.pt' for BAM checkpoints.
                final_name = 'resnet_' + str(p) + '.pt'
            save_file = utils.get_model_path(base_folder,
                                             seed,
                                             specific,
                                             final_name,
                                             experiment=experiment)
        else:
            save_file = os.path.join(base_folder, str(seed), experiment,
                                     'resnet.pt')
        torch.save(model.state_dict(), save_file)
    finally:
        if f is not None:
            f.close()
Exemplo n.º 5
0
def load_models(device,
                base_folder='./models/BAM/',
                specific="bowling_alley",
                seed=0,
                module="layer3",
                experiment="sgd_finetuned",
                ratio="0.5",
                adv=False,
                baseline=False,
                epoch=None,
                post=False,
                multiple=True,
                leakage=False,
                tcav=False,
                force=False,
                dataset='bam',
                args=None,
                ignore_net=False):
    """Load a trained classifier together with its net2vec / leakage probe.

    Builds the checkpoint file names from the training variant (baseline,
    debias or adversarial), ``ratio`` and ``epoch``; constructs the data
    loaders for ``dataset``; optionally trains the post-hoc probe through
    ``post_train`` (``post`` together with ``force``); and finally delegates
    to ``models.load_models``.

    Args:
        device: device forwarded to model loading and probe training.
        base_folder: root folder of all checkpoints.
        specific: forwarded to the path helpers and the BAM data loader.
        seed: seed used for paths and the BAM data loader.
        module: name of the layer the probe is attached to.
        experiment: experiment folder holding model and probe.
        ratio: ratio embedded in bam / coco checkpoint names.
        adv: select the adversarial variant (ignored for file naming when
            ``baseline`` is set).
        baseline: select the baseline variant.
        epoch: optional epoch suffix of the checkpoint names.
        post: use the probe trained after the model ("_after" checkpoints).
        multiple: selects the larger attribute count for the probe.
        leakage: load the MLP leakage probe on the ``fc`` outputs; requires
            ``post``.
        tcav: with ``post`` and without ``leakage``, skip probe training.
        force: with ``post``, train the probe now instead of raising.
        dataset: ``'bam'``, ``'coco'`` or anything else (idenprof loaders).
        args: namespace providing ``custom_end`` and, for COCO, the data
            loader options. It is not modified. May be ``None`` for non-COCO
            datasets.
        ignore_net: forwarded to ``models.load_models``.

    Returns:
        The ``(model, net, net_forward, activation_probe)`` tuple returned by
        ``models.load_models``.

    Raises:
        ValueError: if ``dataset == 'coco'`` and ``args`` is ``None``.
        AssertionError: if ``leakage`` is set without ``post`` or a required
            checkpoint file does not exist.
        Exception: ``'Run trial again'`` when ``post`` is set but the probe
            may not be trained here (``force`` unset and not ``tcav``).
    """
    if leakage:
        assert post
    if epoch is not None:
        epoch = "_" + str(epoch)
    else:
        epoch = ""
    # Work on a local copy of the suffix: rewriting ``args.custom_end`` in
    # place would add another underscore on every call that shares ``args``.
    custom_end = "" if args is None else args.custom_end
    if len(custom_end) > 0:
        custom_end = "_" + str(custom_end)

    if baseline:
        model_end = "resnet_base_" + str(ratio) + epoch + '.pt'
        if not post:
            n2v_end = "resnet_n2v_base_" + str(ratio) + epoch + '.pt'
        else:
            n2v_end = "resnet_n2v_base_after_" + str(ratio) + epoch + '.pt'
    elif not adv:
        model_end = "resnet_debias_" + str(ratio) + epoch + '.pt'
        if not post:
            n2v_end = "resnet_n2v_debias_" + str(ratio) + epoch + '.pt'
        else:
            # Only this variant carries the custom suffix.
            n2v_end = ("resnet_n2v_debias_after_" + str(ratio) + epoch +
                       custom_end + '.pt')
    else:
        # Adversarial checkpoints have no epoch suffix, except for the
        # post-hoc probe.
        model_end = "resnet_adv_" + str(ratio) + '.pt'
        if not post:
            n2v_end = "resnet_n2v_adv_" + str(ratio) + '.pt'
        else:
            n2v_end = "resnet_n2v_adv_after_" + str(ratio) + epoch + '.pt'

    uses_ratio_names = dataset == 'bam' or dataset == 'coco'
    if not uses_ratio_names:
        # Other datasets do not encode the ratio in their file names.
        model_end = model_end.replace('_' + str(ratio), '')
        n2v_end = n2v_end.replace('_' + str(ratio), '')
    # Leakage probes live in a 'leakage/' subfolder and use 'mlp' in place
    # of 'n2v' in the file name.
    if leakage:
        probe_end = 'leakage/' + n2v_end.replace('n2v', 'mlp')
    else:
        probe_end = n2v_end
    if uses_ratio_names:
        model_path, n2v_path = utils.get_paths(
            base_folder,
            seed,
            specific,
            model_end=model_end,
            n2v_end=probe_end,
            n2v_module=module,
            experiment=experiment,
            with_n2v=True,
        )
    else:
        model_path = os.path.join(base_folder, str(seed), experiment, module,
                                  model_end)
        n2v_path = os.path.join(base_folder, str(seed), experiment, module,
                                probe_end)

    if dataset == 'bam':
        # Train split at the requested ratio, test split at ratio 0.5.
        trainloader, _ = dataload.get_data_loader_SceneBAM(
            seed=seed, ratio=float(ratio), specific=specific)
        _, testloader = dataload.get_data_loader_SceneBAM(
            seed=seed, ratio=float(0.5), specific=specific)
    elif dataset == 'coco':
        if args is None:
            raise ValueError("args is required when dataset == 'coco'")
        tmp_args = copy.deepcopy(args)
        tmp_args.ratio = ratio
        # NOTE(review): int() rejects fractional strings such as the default
        # ratio "0.5"; COCO ratios are presumably integers - confirm.
        if int(ratio) > 0:
            tmp_args.balanced = True
        if leakage:
            tmp_args.gender_balanced = True
        trainloader, testloader = coco_dataload.get_data_loader_coco(tmp_args)
    else:
        trainloader, testloader = dataload.get_data_loader_idenProf(
            'idenprof',
            train_shuffle=True,
            train_batch_size=64,
            test_batch_size=64,
            exclusive=True)

    # The adversarial COCO model is loaded with a ``None`` factory below, so
    # its checkpoint files are not checked here.
    if not (dataset == 'coco' and adv):
        assert os.path.exists(model_path), model_path

    if post:
        # since we have to run a separate script, might not have finished...
        model_extra = '_adv' if adv else ('_base' if baseline else '_debias')
        n2v_extra = model_extra + '_after'
        if leakage:
            if not force:
                raise Exception('Run trial again')
            post_train.train_leakage(
                trainloader,
                testloader,
                device,
                seed,
                specific=specific,
                p=ratio,
                n_epochs=20,
                module=module,
                lr=5e-5,  # leakage model uses adam
                out_file=None,
                base_folder=base_folder,
                experiment1=experiment,
                experiment2=experiment,
                model_extra=model_extra,
                n2v_extra=n2v_extra,
                with_n2v=True,
                nonlinear=True,  # MLP leakage model
                model_custom_end='',
                n2v_custom_end='',
                dataset=dataset)
        elif tcav:
            pass
        elif force:
            post_train.train_net2vec(
                trainloader,
                testloader,
                device,
                seed,
                specific=specific,
                p=ratio,
                n_epochs=20,
                module=module,
                lr=.01,
                out_file=None,
                base_folder=base_folder,
                experiment1=experiment,
                experiment2=experiment,
                model_extra=model_extra,
                n2v_extra=n2v_extra,
                with_n2v=True,
                nonlinear=False,  # might want to change this later
                model_custom_end=epoch.replace('_', ''),
                n2v_custom_end=epoch.replace('_', ''),
                multiple=multiple,
                dataset=dataset)
        else:
            raise Exception('Run trial again')
    else:
        # should've been saved during training if not ported from tianlu
        if not (dataset == 'coco' and adv):
            assert os.path.exists(n2v_path)

    num_attributes = 10 + 9 + 20 if multiple else 12
    num_classes = 10
    if dataset == 'coco':
        num_attributes = 81
        num_classes = 79
    model, net, net_forward, activation_probe = models.load_models(
        device,
        None if (dataset == 'coco' and adv) else
        lambda x, y, z: models.resnet_(
            pretrained=True,
            custom_path=x,
            device=y,
            initialize=z,
            num_classes=num_classes,
            size=50 if (dataset == 'bam') or (dataset == 'coco') else 34),
        model_path=model_path,
        net2vec_pretrained=True,
        net2vec_path=n2v_path,
        module='fc' if leakage else module,
        num_attributes=2 if leakage else num_attributes,
        model_init=False,
        n2v_init=False,
        nonlinear=leakage,
        ignore_net=ignore_net)
    print(n2v_path)
    return model, net, net_forward, activation_probe
Exemplo n.º 6
0
def load_models(device,
                base_folder='./models/BAM/',
                specific="bowling_alley",
                seed=0,
                module="layer3",
                experiment="sgd_finetuned",
                ratio="0.5",
                adv=False,
                baseline=False,
                epoch=None,
                post=False,
                multiple=True,
                leakage=False,
                tcav=False,
                force=False,
                dataset='bam'):
    """Load a trained classifier and its probe for BAM or idenprof.

    Derives the checkpoint file names from the training variant
    (``baseline`` wins over ``adv``; neither means "debias"), ``ratio`` and
    ``epoch``, builds the data loaders, optionally trains the post-hoc probe
    via ``post_train`` and then delegates to ``models.load_models``.

    Returns the ``(model, net, net_forward, activation_probe)`` tuple that
    ``models.load_models`` produces. Raises ``Exception('Run trial again')``
    when ``post`` is set but the probe may not be trained here, and
    ``AssertionError`` when a required checkpoint is missing or ``leakage``
    is requested without ``post``.
    """
    if leakage:
        assert post

    ratio_str = str(ratio)
    epoch_tag = "" if epoch is None else "_" + str(epoch)
    is_bam = dataset == 'bam'

    # File-name variant; baseline takes precedence over adv here.
    if baseline:
        variant = "base"
    elif adv:
        variant = "adv"
    else:
        variant = "debias"
    is_adv_variant = variant == "adv"

    # Adversarial checkpoints omit the epoch tag, except for the probe that
    # was trained after the model.
    model_end = ("resnet_" + variant + "_" + ratio_str +
                 ("" if is_adv_variant else epoch_tag) + '.pt')
    n2v_end = ("resnet_n2v_" + variant + ("_after_" if post else "_") +
               ratio_str +
               ("" if (is_adv_variant and not post) else epoch_tag) + '.pt')

    if not is_bam:
        # Only BAM file names carry the ratio.
        ratio_part = '_' + ratio_str
        model_end = model_end.replace(ratio_part, '')
        n2v_end = n2v_end.replace(ratio_part, '')

    # Leakage probes are stored under 'leakage/' with 'mlp' replacing 'n2v'.
    if leakage:
        probe_file = 'leakage/' + n2v_end.replace('n2v', 'mlp')
    else:
        probe_file = n2v_end

    if is_bam:
        model_path, n2v_path = utils.get_paths(
            base_folder,
            seed,
            specific,
            model_end=model_end,
            n2v_end=probe_file,
            n2v_module=module,
            experiment=experiment,
            with_n2v=True,
        )
        # Train split at the requested ratio, test split at ratio 0.5.
        trainloader, _unused_test = dataload.get_data_loader_SceneBAM(
            seed=seed, ratio=float(ratio), specific=specific)
        _unused_train, testloader = dataload.get_data_loader_SceneBAM(
            seed=seed, ratio=float(0.5), specific=specific)
    else:
        checkpoint_dir = (base_folder, str(seed), experiment, module)
        model_path = os.path.join(*(checkpoint_dir + (model_end,)))
        n2v_path = os.path.join(*(checkpoint_dir + (probe_file,)))
        trainloader, testloader = dataload.get_data_loader_idenProf(
            'idenprof',
            train_shuffle=True,
            train_batch_size=64,
            test_batch_size=64,
            exclusive=True)

    assert os.path.exists(model_path), model_path

    if not post:
        # The probe is expected to have been written during training.
        assert os.path.exists(n2v_path)
    elif leakage or not tcav:
        # The post-hoc probe comes from a separate run that may not have
        # finished; only train it here when explicitly forced.
        if not force:
            raise Exception('Run trial again')
        # Note: adv takes precedence over baseline for these infixes.
        if adv:
            model_extra = '_adv'
        elif baseline:
            model_extra = '_base'
        else:
            model_extra = '_debias'
        shared_kwargs = dict(
            specific=specific,
            p=ratio,
            n_epochs=20,
            module=module,
            out_file=None,
            base_folder=base_folder,
            experiment1=experiment,
            experiment2=experiment,
            model_extra=model_extra,
            n2v_extra=model_extra + '_after',
            with_n2v=True,
            dataset=dataset,
        )
        if leakage:
            # MLP leakage probe; its trainer uses Adam, hence the small rate.
            post_train.train_leakage(trainloader,
                                     testloader,
                                     device,
                                     seed,
                                     lr=5e-5,
                                     nonlinear=True,
                                     model_custom_end='',
                                     n2v_custom_end='',
                                     **shared_kwargs)
        else:
            bare_epoch = epoch_tag.replace('_', '')
            # Linear probe for now.
            post_train.train_net2vec(trainloader,
                                     testloader,
                                     device,
                                     seed,
                                     lr=.01,
                                     nonlinear=False,
                                     model_custom_end=bare_epoch,
                                     n2v_custom_end=bare_epoch,
                                     multiple=multiple,
                                     **shared_kwargs)

    resnet_size = 50 if is_bam else 34

    def build_resnet(x, y, z):
        # x -> custom_path, y -> device, z -> initialize
        return models.resnet_(pretrained=True,
                              custom_path=x,
                              device=y,
                              initialize=z,
                              size=resnet_size)

    if leakage:
        probe_module, probe_attributes = 'fc', 2
    else:
        probe_module = module
        probe_attributes = 10 + 9 + 20 if multiple else 12

    model, net, net_forward, activation_probe = models.load_models(
        device,
        build_resnet,
        model_path=model_path,
        net2vec_pretrained=True,
        net2vec_path=n2v_path,
        module=probe_module,
        num_attributes=probe_attributes,
        model_init=False,
        n2v_init=False,
        nonlinear=leakage,
    )
    return model, net, net_forward, activation_probe