Example #1
0
    def lr_step(self):
        """Advance the learning-rate schedule of the target optimizer.

        Only SGD is scheduled: when ``self.args.optimizer`` is anything other
        than ``'SGD'`` the method returns without touching ``self.tgt_opt``.
        Otherwise ``self.tgt_opt`` is replaced by the optimizer returned from
        ``utils.inv_lr_scheduler`` for ``self.current_step``, using
        ``self.args.lr`` as the initial learning rate.
        """
        if self.args.optimizer != 'SGD':
            return
        self.tgt_opt = utils.inv_lr_scheduler(
            self.param_lr_c,
            self.tgt_opt,
            self.current_step,
            init_lr=self.args.lr,
        )
def train_source(config, base_network, classifier_gnn, dset_loaders):
    """Train the encoder, MLP head and GNN head on labelled source data.

    Each iteration draws one source batch and minimises
    ``mlp_loss + lambda_node * gnn_loss + lambda_edge * edge_loss``:
    cross-entropy of the MLP head, cross-entropy of the GNN head, and a binary
    cross-entropy between the GNN edge similarities and the ground-truth edges
    produced by ``classifier_gnn.label2edge`` from the source labels.

    Args:
        config: Experiment configuration dict. Keys read here: ``'optimizer'``
            (with ``'type'``, ``'optim_params'`` and ``'lr_param'``),
            ``'source_iters'``, ``'lambda_node'``, ``'lambda_edge'`` and
            ``'test_interval'``. The whole dict is also passed to
            ``utils.write_logs`` and ``evaluate``.
        base_network: Encoder + MLP head; called on images and expected to
            return ``(features, logits)``. Must provide ``get_parameters()``.
        classifier_gnn: GNN head; called on features and expected to return
            ``(logits, edge_sim)``. Must provide ``label2edge``.
        dset_loaders: Dict of data loaders. ``'source'`` yields dict batches
            with ``'img'`` and ``'target'``; ``'target_test'`` is passed to
            ``evaluate``.

    Returns:
        The tuple ``(base_network, classifier_gnn)`` after training.
    """
    # define loss functions
    criterion_gedge = nn.BCELoss(reduction='mean')
    ce_criterion = nn.CrossEntropyLoss()

    # configure optimizer; the GNN head gets larger lr / weight-decay multipliers
    optimizer_config = config['optimizer']
    parameter_list = base_network.get_parameters() + \
        [{'params': classifier_gnn.parameters(), 'lr_mult': 10, 'decay_mult': 2}]
    optimizer = optimizer_config['type'](parameter_list, **optimizer_config['optim_params'])
    schedule_param = optimizer_config['lr_param']

    num_iters = config['source_iters']
    test_interval = config['test_interval']
    source_loader = dset_loaders['source']
    len_train_source = len(source_loader)

    # start train loop
    base_network.train()
    classifier_gnn.train()
    for i in range(num_iters):
        optimizer = utils.inv_lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()

        # restart the source iterator at the beginning of every epoch
        if i % len_train_source == 0:
            iter_source = iter(source_loader)
        # builtin next(): iterator.next() is a Python 2 alias that recent
        # PyTorch DataLoader iterators no longer provide
        batch_source = next(iter_source)
        inputs_source = batch_source['img'].to(DEVICE)
        labels_source = batch_source['target'].to(DEVICE)

        # make forward pass for encoder and mlp head
        features_source, logits_mlp = base_network(inputs_source)
        mlp_loss = ce_criterion(logits_mlp, labels_source)

        # make forward pass for gnn head
        logits_gnn, edge_sim = classifier_gnn(features_source)
        gnn_loss = ce_criterion(logits_gnn, labels_source)
        # compute edge loss only on the entries selected by edge_mask
        edge_gt, edge_mask = classifier_gnn.label2edge(labels_source.unsqueeze(dim=0))
        edge_loss = criterion_gedge(edge_sim.masked_select(edge_mask),
                                    edge_gt.masked_select(edge_mask))

        # total loss and backpropagation
        loss = mlp_loss + config['lambda_node'] * gnn_loss + config['lambda_edge'] * edge_loss
        loss.backward()
        optimizer.step()

        # printout train loss every 20 iterations and on the last one
        if i % 20 == 0 or i == num_iters - 1:
            log_str = 'Iters:(%4d/%d)\tMLP loss:%.4f\tGNN loss:%.4f\tEdge loss:%.4f' % (
                i, num_iters, mlp_loss.item(), gnn_loss.item(), edge_loss.item())
            utils.write_logs(config, log_str)

        # evaluate network every test_interval
        if i % test_interval == test_interval - 1:
            evaluate(i, config, base_network, classifier_gnn, dset_loaders['target_test'])
            # evaluate() may put the networks in eval mode (its body is not shown
            # here); switch back so BatchNorm/Dropout keep training behaviour
            base_network.train()
            classifier_gnn.train()

    return base_network, classifier_gnn
Example #3
0
def main():
    """Train a fair deep clustering (DFC) model on the MNIST/USPS groups.

    Reads all settings from the module-level ``args``. The steps are:

    - Load pre-trained weights and k-means centroids from ``./save/``.
    - Jointly optimise three terms with Adam:
      - the clustering (KL) loss;
      - the adversarial fairness loss;
      - the partition loss against the frozen per-group DEC models.
    - Print evaluation metrics after the first step and every
      ``args.test_interval`` steps.
    """
    set_seed(args.seed)
    torch.cuda.set_device(args.gpu)

    encoder = Encoder().cuda()
    encoder_group_0 = Encoder().cuda()
    encoder_group_1 = Encoder().cuda()

    dfc = DFC(cluster_number=args.k, hidden_dimension=64).cuda()
    dfc_group_0 = DFC(cluster_number=args.k, hidden_dimension=64).cuda()
    dfc_group_1 = DFC(cluster_number=args.k, hidden_dimension=64).cuda()

    critic = AdversarialNetwork(in_feature=args.k,
                                hidden_size=32,
                                max_iter=args.iters,
                                lr_mult=args.adv_mult).cuda()

    # encoder pre-trained with self-reconstruction
    encoder.load_state_dict(torch.load("./save/encoder_pretrain.pth"))

    # per-group "golden standard" encoder and clustering model trained by DEC
    encoder_group_0.load_state_dict(torch.load("./save/encoder_mnist.pth"))
    encoder_group_1.load_state_dict(torch.load("./save/encoder_usps.pth"))
    dfc_group_0.load_state_dict(torch.load("./save/dec_mnist.pth"))
    dfc_group_1.load_state_dict(torch.load("./save/dec_usps.pth"))

    # load clustering centroids given by k-means into the DFC assignment layer
    centers = np.loadtxt("./save/centers.txt")
    cluster_centers = torch.tensor(centers, dtype=torch.float).cuda()
    with torch.no_grad():
        print("loading clustering centers...")
        dfc.state_dict()['assignment.cluster_centers'].copy_(cluster_centers)

    optimizer = torch.optim.Adam(dfc.get_parameters() +
                                 encoder.get_parameters() +
                                 critic.get_parameters(),
                                 lr=args.lr,
                                 weight_decay=5e-4)
    criterion_c = nn.KLDivLoss(reduction="sum")
    criterion_p = nn.MSELoss(reduction="sum")
    C_LOSS = AverageMeter()
    F_LOSS = AverageMeter()
    P_LOSS = AverageMeter()

    # the golden-standard models are frozen references: eval mode, not optimised
    encoder_group_0.eval()
    encoder_group_1.eval()
    dfc_group_0.eval()
    dfc_group_1.eval()

    data_loader = mnist_usps(args)
    len_image_0 = len(data_loader[0])
    len_image_1 = len(data_loader[1])

    for step in range(args.iters):
        encoder.train()
        dfc.train()
        # restart each group's iterator at the beginning of its own epoch
        if step % len_image_0 == 0:
            iter_image_0 = iter(data_loader[0])
        if step % len_image_1 == 0:
            iter_image_1 = iter(data_loader[1])

        image_0, _ = next(iter_image_0)
        image_1, _ = next(iter_image_1)

        image_0, image_1 = image_0.cuda(), image_1.cuda()
        image = torch.cat((image_0, image_1), dim=0)

        # The golden-standard assignments are only used as detached targets and
        # their models are not in the optimizer, so no autograd graph is needed.
        with torch.no_grad():
            predict_0 = dfc_group_0(encoder_group_0(image_0)[0])
            predict_1 = dfc_group_1(encoder_group_1(image_1)[0])

        z, _, _ = encoder(image)
        output = dfc(z)

        # rows [0, bs) belong to group 0 and [bs, 2*bs) to group 1
        output_0, output_1 = output[0:args.bs, :], output[args.bs:args.bs * 2, :]
        target_0 = target_distribution(output_0).detach()
        target_1 = target_distribution(output_1).detach()

        clustering_loss = 0.5 * criterion_c(output_0.log(), target_0) \
            + 0.5 * criterion_c(output_1.log(), target_1)
        fair_loss = adv_loss(output, critic)
        partition_loss = 0.5 * criterion_p(aff(output_0), aff(predict_0).detach()) \
            + 0.5 * criterion_p(aff(output_1), aff(predict_1).detach())
        total_loss = clustering_loss + args.coeff_fair * fair_loss + args.coeff_par * partition_loss

        optimizer = inv_lr_scheduler(optimizer, args.lr, step, args.iters)
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Feed detached values to the running-average meters: accumulating the
        # live loss tensors would keep every iteration's autograd history alive.
        C_LOSS.update(clustering_loss.detach())
        F_LOSS.update(fair_loss.detach())
        P_LOSS.update(partition_loss.detach())

        if step % args.test_interval == args.test_interval - 1 or step == 0:
            predicted, labels = predict(data_loader, encoder, dfc)
            predicted, labels = predicted.cpu().numpy(), labels.numpy()
            _, accuracy = cluster_accuracy(predicted, labels, 10)
            nmi = normalized_mutual_info_score(labels,
                                               predicted,
                                               average_method="arithmetic")
            # NOTE(review): 60000 is presumably the number of group-0 (MNIST)
            # samples in the predictions — confirm against mnist_usps()
            bal, en_0, en_1 = balance(predicted, 60000)

            print("Step:[{:03d}/{:03d}]  "
                  "Acc:{:2.3f};"
                  "NMI:{:1.3f};"
                  "Bal:{:1.3f};"
                  "En:{:1.3f}/{:1.3f};"
                  "C.loss:{C_Loss.avg:3.2f};"
                  "F.loss:{F_Loss.avg:3.2f};"
                  "P.loss:{P_Loss.avg:3.2f};".format(step + 1,
                                                     args.iters,
                                                     accuracy,
                                                     nmi,
                                                     bal,
                                                     en_0,
                                                     en_1,
                                                     C_Loss=C_LOSS,
                                                     F_Loss=F_LOSS,
                                                     P_Loss=P_LOSS))
Example #4
0
def adapt_target(config, base_network, classifier_gnn, dset_loaders,
                 max_inherit_domain):
    """Adapt the networks to one target domain with CDAN plus GNN losses.

    Each iteration draws one source batch and one batch from
    ``dset_loaders['target_train'][max_inherit_domain]``. It then minimises
    ``lambda_adv * trans_loss + mlp_loss + lambda_node * gnn_loss +
    lambda_edge * edge_loss``:

    - ``trans_loss`` is the CDAN adversarial loss, entropy-weighted for
      ``'CDAN+E'``.
    - The MLP/GNN classification losses use source labels only.
    - The edge loss also uses MLP pseudo-labels for target samples. Targets
      whose top softmax score does not exceed ``config['threshold']`` get
      ``classifier_gnn.mask_val`` instead.

    Args:
        config: Experiment configuration dict. Keys read here include
            ``'method'``, ``'encoder'``, ``'random_dim'``, ``'ndomains'``,
            ``'optimizer'``, ``'adapt_iters'``, ``'threshold'``,
            ``'lambda_adv'``, ``'lambda_node'``, ``'lambda_edge'`` and
            ``'test_interval'``.
        base_network: Encoder + MLP head returning ``(features, logits)``.
        classifier_gnn: GNN head returning ``(logits, edge_sim)``.
        dset_loaders: Dict of loaders with ``'source'``, ``'target_train'``
            (indexable by ``max_inherit_domain``) and ``'target_test'``.
        max_inherit_domain: Key/index of the target-train loader to adapt to.

    Returns:
        The tuple ``(base_network, classifier_gnn)`` after adaptation.

    Raises:
        ValueError: If ``config['method']`` is neither ``'CDAN'`` nor ``'CDAN+E'``.
    """
    method = config['method']
    # Reject an unknown method before building networks and loading data,
    # instead of only after the first forward pass.
    if method not in ('CDAN', 'CDAN+E'):
        raise ValueError('Method cannot be recognized.')

    # define loss functions
    criterion_gedge = nn.BCELoss(reduction='mean')
    ce_criterion = nn.CrossEntropyLoss()

    # add random layer and adversarial network
    class_num = config['encoder']['params']['class_num']
    random_layer = networks.RandomLayer([base_network.output_num(), class_num],
                                        config['random_dim'], DEVICE)
    adv_net = networks.AdversarialNetwork(config['random_dim'],
                                          config['random_dim'],
                                          config['ndomains'])
    random_layer.to(DEVICE)
    adv_net = adv_net.to(DEVICE)

    # configure optimizer; the GNN head gets larger lr / weight-decay multipliers
    optimizer_config = config['optimizer']
    parameter_list = base_network.get_parameters() + adv_net.get_parameters() \
        + [{'params': classifier_gnn.parameters(), 'lr_mult': 10, 'decay_mult': 2}]
    optimizer = optimizer_config['type'](parameter_list,
                                         **optimizer_config['optim_params'])
    schedule_param = optimizer_config['lr_param']

    num_iters = config['adapt_iters']
    test_interval = config['test_interval']
    source_loader = dset_loaders['source']
    target_loader = dset_loaders['target_train'][max_inherit_domain]
    len_train_source = len(source_loader)
    len_train_target = len(target_loader)

    # set nets in train mode
    base_network.train()
    classifier_gnn.train()
    adv_net.train()
    random_layer.train()
    for i in range(num_iters):
        optimizer = utils.inv_lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()

        # restart each iterator at the beginning of its own epoch
        if i % len_train_source == 0:
            iter_source = iter(source_loader)
        if i % len_train_target == 0:
            iter_target = iter(target_loader)
        # builtin next(): iterator.next() is a Python 2 alias that recent
        # PyTorch DataLoader iterators no longer provide
        batch_source = next(iter_source)
        batch_target = next(iter_target)
        inputs_source = batch_source['img'].to(DEVICE)
        inputs_target = batch_target['img'].to(DEVICE)
        labels_source = batch_source['target'].to(DEVICE)
        domain_source = batch_source['domain'].to(DEVICE)
        domain_target = batch_target['domain'].to(DEVICE)
        domain_input = torch.cat([domain_source, domain_target], dim=0)

        # make forward pass for encoder and mlp head (source rows first)
        features_source, logits_mlp_source = base_network(inputs_source)
        features_target, logits_mlp_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        logits_mlp = torch.cat((logits_mlp_source, logits_mlp_target), dim=0)
        softmax_mlp = nn.Softmax(dim=1)(logits_mlp)
        mlp_loss = ce_criterion(logits_mlp_source, labels_source)

        # *** GNN at work ***
        # make forward pass for gnn head; only the source rows are labelled
        logits_gnn, edge_sim = classifier_gnn(features)
        gnn_loss = ce_criterion(logits_gnn[:labels_source.size(0)], labels_source)
        # compute pseudo-labels for affinity matrix by mlp classifier;
        # low-confidence targets are replaced by the GNN's mask value
        out_target_class = torch.softmax(logits_mlp_target, dim=1)
        target_score, target_pseudo_labels = out_target_class.max(1, keepdim=True)
        idx_pseudo = target_score > config['threshold']
        target_pseudo_labels[~idx_pseudo] = classifier_gnn.mask_val
        # combine source labels and target pseudo labels for edge_net
        node_labels = torch.cat(
            (labels_source, target_pseudo_labels.squeeze(dim=1)),
            dim=0).unsqueeze(dim=0)
        # compute source-target mask and ground truth for edge_net
        edge_gt, edge_mask = classifier_gnn.label2edge(node_labels)
        edge_loss = criterion_gedge(edge_sim.masked_select(edge_mask),
                                    edge_gt.masked_select(edge_mask))

        # *** Adversarial net at work ***
        # 'CDAN+E' weights samples by prediction entropy; plain 'CDAN' does not
        if method == 'CDAN+E':
            entropy = transfer_loss.Entropy(softmax_mlp)
            coeff = networks.calc_coeff(i)
        else:
            entropy, coeff = None, None
        trans_loss = transfer_loss.CDAN(config['ndomains'],
                                        [features, softmax_mlp], adv_net,
                                        entropy, coeff, random_layer,
                                        domain_input)

        # total loss and backpropagation
        loss = config['lambda_adv'] * trans_loss + mlp_loss + \
            config['lambda_node'] * gnn_loss + config['lambda_edge'] * edge_loss
        loss.backward()
        optimizer.step()

        # printout (weighted) train losses every 20 iterations and on the last one
        if i % 20 == 0 or i == num_iters - 1:
            log_str = 'Iters:(%4d/%d)\tMLP loss: %.4f\t GNN Loss: %.4f\t Edge Loss: %.4f\t Transfer loss:%.4f' % (
                i, num_iters, mlp_loss.item(),
                config['lambda_node'] * gnn_loss.item(),
                config['lambda_edge'] * edge_loss.item(),
                config['lambda_adv'] * trans_loss.item())
            utils.write_logs(config, log_str)

        # evaluate network every test_interval
        if i % test_interval == test_interval - 1:
            evaluate(i, config, base_network, classifier_gnn,
                     dset_loaders['target_test'])
            # evaluate() may put these networks in eval mode (its body is not
            # shown here); switch back so training continues in train mode
            base_network.train()
            classifier_gnn.train()

    return base_network, classifier_gnn
Example #5
0
def _golden_prediction(encoder_g, dfc_g, images, encoder_type):
    """Return the soft cluster assignments of a golden-standard model pair.

    For ``encoder_type == 'vae'`` the encoder output is indexed with ``[0]``
    to obtain the features; otherwise the encoder output is used as is.
    """
    features = encoder_g(images)
    if encoder_type == 'vae':
        features = features[0]
    return dfc_g(features)


def _detached(value):
    """Return ``value.detach()`` for tensors and ``value`` unchanged otherwise."""
    return value.detach() if torch.is_tensor(value) else value


def train(args,
          dataloader_list,
          encoder,
          encoder_group_0=None,
          encoder_group_1=None,
          dfc_group_0=None,
          dfc_group_1=None,
          device='cpu',
          centers=None,
          get_loss_trade_off=lambda step: (10, 10, 10),
          save_name='DFC'):
    """Train DFC and, with two groups, the fairness critic.

    Automatically saves the trained model(s) when finished.

    If any of the four golden-standard models is missing, the partition loss
    is disabled and training runs in plain DEC mode. With a single dataloader
    the fairness loss is disabled as well.

    Args:
        args: Namespace with the run configuration. Attributes read here:
            seed, half_tensor, cluster_number, dfc_hidden_dim, iters,
            adv_multiplier, encoder_type ('vae' or 'resnet50'), dec_lr, lr,
            dataset, test_interval, log_dir.
        dataloader_list (list): One or two dataloaders, one per sensitive group.
            Each yields ``(images, labels)`` batches.
        encoder: Encoder to train.
        encoder_group_0: Optional pre-trained golden-standard encoder for group 0.
        encoder_group_1: Optional pre-trained golden-standard encoder for group 1.
        dfc_group_0: Optional pre-trained golden-standard clustering model
            applied to ``encoder_group_0``'s features.
        dfc_group_1: Optional pre-trained golden-standard clustering model
            applied to ``encoder_group_1``'s features.
        device: Device configuration.
        centers: Initial cluster centers tensor, if available.
        get_loss_trade_off: Callable ``step -> (fair, partition, clustering)``
            giving the weight of each loss term.
        save_name: Prefix for saved files and logged metric names.

    Returns:
        DFC: The trained DFC model.

    Raises:
        AssertionError: If ``dataloader_list`` does not hold 1 or 2 loaders.
        Exception: If ``args.encoder_type`` is not 'vae' or 'resnet50'.
    """
    set_seed(args.seed)
    if args.half_tensor:
        # NOTE: changes the default tensor type for the whole process
        torch.set_default_tensor_type('torch.HalfTensor')

    dfc = DFC(cluster_number=args.cluster_number,
              hidden_dimension=args.dfc_hidden_dim).to(device)
    wandb.watch(dfc)

    critic = AdversarialNetwork(in_feature=args.cluster_number,
                                hidden_size=32,
                                max_iter=args.iters,
                                lr_mult=args.adv_multiplier).to(device)
    wandb.watch(critic)

    if centers is not None:
        cluster_centers = centers.clone().detach().to(device)
        with torch.no_grad():
            print("loading clustering centers...")
            dfc.state_dict()['assignment.cluster_centers'].copy_(cluster_centers)

    if args.encoder_type == 'vae':
        encoder_param = encoder.get_parameters()
    else:
        encoder_param = [{"params": encoder.parameters(), "lr_mult": 1}]
    # NOTE(review): the optimizer is created with args.dec_lr, but
    # inv_lr_scheduler below is driven by args.lr and runs before every
    # optimizer.step(). If it overwrites the group learning rates, dec_lr never
    # takes effect. Confirm which learning rate is intended.
    optimizer = torch.optim.Adam(dfc.get_parameters() + encoder_param +
                                 critic.get_parameters(),
                                 lr=args.dec_lr,
                                 weight_decay=5e-4)

    criterion_c = nn.KLDivLoss(reduction="sum")
    criterion_p = nn.MSELoss(reduction="sum")
    C_LOSS = AverageMeter()
    F_LOSS = AverageMeter()
    P_LOSS = AverageMeter()

    golden_models = (encoder_group_0, encoder_group_1, dfc_group_0, dfc_group_1)
    # Compare with None explicitly: container modules such as nn.Sequential
    # define __len__, so truthiness is not a reliable "is provided" test.
    partition_loss_enabled = all(model is not None for model in golden_models)
    if partition_loss_enabled:
        for model in golden_models:
            model.eval()
    else:
        print(
            "Missing Golden Standard models, switching to DEC mode instead of DFC."
        )

    print("Start training")
    assert 0 < len(dataloader_list) < 3
    has_group_1 = len(dataloader_list) == 2
    len_image_0 = len(dataloader_list[0])
    len_image_1 = len(dataloader_list[1]) if has_group_1 else None
    for step in range(args.iters):
        encoder.train()
        dfc.train()

        # restart each group's iterator at the beginning of its own epoch
        if step % len_image_0 == 0:
            iter_image_0 = iter(dataloader_list[0])
        if has_group_1 and step % len_image_1 == 0:
            iter_image_1 = iter(dataloader_list[1])

        image_0, _ = next(iter_image_0)
        image_0 = image_0.to(device)
        if has_group_1:
            image_1, _ = next(iter_image_1)
            image_1 = image_1.to(device)
            image = torch.cat((image_0, image_1), dim=0)
        else:
            image_1 = None
            image = image_0

        if args.encoder_type == 'vae':
            z, _, _ = encoder(image)
        elif args.encoder_type == 'resnet50':
            z = encoder(image)
        else:
            raise Exception(
                'Wrong encoder type, how did you get this far in running the code?'
            )
        output = dfc(z)

        # Split the soft assignments P back into the groups at the actual size
        # of group 0's batch, so a short final batch does not mix the groups.
        n_0 = image_0.size(0)
        output_0 = output[:n_0]
        output_1 = output[n_0:] if has_group_1 else None

        # Equation (4) in the paper: auxiliary target distribution Q computed
        # from P; it is a fixed target, hence detached.
        target_0 = target_distribution(output_0).detach()
        target_1 = target_distribution(output_1).detach() if has_group_1 else None

        # Equation (5) in the paper: KL clustering loss between P and Q,
        # averaged over the groups.
        if has_group_1:
            clustering_loss = 0.5 * criterion_c(output_0.log(), target_0) \
                + 0.5 * criterion_c(output_1.log(), target_1)
        else:
            clustering_loss = criterion_c(output_0.log(), target_0)

        # Equation (2) in the paper: adversarial fairness loss on output =
        # D(A(F(X))); only defined when there is a second group.
        if has_group_1:
            fair_loss, critic_acc = adv_loss(output, critic, device=device)
        else:
            fair_loss, critic_acc = 0, 0

        if partition_loss_enabled:
            # Equation (3) in the paper: output_0/output_1 are the soft
            # assignments of the DFC being trained; predict_0/predict_1 are
            # those of the frozen golden-standard models. The loss is high
            # when the two cluster structures (affinities) differ. The golden
            # side only serves as a detached target, so no graph is built.
            with torch.no_grad():
                predict_0 = _golden_prediction(encoder_group_0, dfc_group_0,
                                               image_0, args.encoder_type)
                predict_1 = _golden_prediction(
                    encoder_group_1, dfc_group_1, image_1,
                    args.encoder_type) if has_group_1 else None
            if has_group_1:
                partition_loss = 0.5 * criterion_p(aff(output_0), aff(predict_0).detach()) \
                    + 0.5 * criterion_p(aff(output_1), aff(predict_1).detach())
            else:
                partition_loss = criterion_p(aff(output_0),
                                             aff(predict_0).detach())
        else:
            partition_loss = 0

        loss_trade_off = get_loss_trade_off(step)
        if args.encoder_type == 'resnet50' and args.dataset == 'office_31':  # alpha_s
            loss_trade_off = list(loss_trade_off)
            loss_trade_off[1] = ((512 / 128)**2) * (31 / 10)

        total_loss = loss_trade_off[0] * fair_loss \
            + loss_trade_off[1] * partition_loss \
            + loss_trade_off[2] * clustering_loss

        optimizer = inv_lr_scheduler(optimizer, args.lr, step, args.iters)
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Feed detached values to the running-average meters: accumulating the
        # live loss tensors would keep every iteration's autograd history alive.
        C_LOSS.update(_detached(clustering_loss))
        F_LOSS.update(_detached(fair_loss))
        P_LOSS.update(_detached(partition_loss))

        wandb.log({
            f"{save_name} Train C Loss Avg": C_LOSS.avg,
            f"{save_name} Train F Loss Avg": F_LOSS.avg,
            f"{save_name} Train P Loss Avg": P_LOSS.avg,
            f"{save_name} step": step,
            f"{save_name} Critic ACC": critic_acc
        })
        wandb.log({
            f"{save_name} Train C Loss Cur": C_LOSS.val,
            f"{save_name} Train F Loss Cur": F_LOSS.val,
            f"{save_name} Train P Loss Cur": P_LOSS.val,
            f"{save_name} step": step
        })

        if step % args.test_interval == args.test_interval - 1 or step == 0:

            predicted, labels = predict(dataloader_list,
                                        encoder,
                                        dfc,
                                        device=device,
                                        encoder_type=args.encoder_type)
            predicted, labels = predicted.cpu().numpy(), labels.numpy()
            _, accuracy = cluster_accuracy(predicted, labels,
                                           args.cluster_number)
            nmi = normalized_mutual_info_score(labels,
                                               predicted,
                                               average_method="arithmetic")
            # The second argument to balance() is used as the group-0/group-1
            # split point of the per-sample predictions, so it must be a
            # sample count. len_image_0 is a batch count.
            # NOTE(review): assumes predict() covers group 0's whole dataset
            # first — confirm.
            bal, en_0, en_1 = balance(predicted,
                                      len(dataloader_list[0].dataset),
                                      k=args.cluster_number)

            wandb.log({
                f"{save_name} Train Accuracy": accuracy,
                f"{save_name} Train NMI": nmi,
                f"{save_name} Train Bal": bal,
                f"{save_name} Train Entropy 0": en_0,
                f"{save_name} Train Entropy 1": en_1,
                f"{save_name} step": step
            })

            print("Step:[{:03d}/{:03d}]  "
                  "Acc:{:2.3f};"
                  "NMI:{:1.3f};"
                  "Bal:{:1.3f};"
                  "En:{:1.3f}/{:1.3f};"
                  "Clustering.loss:{C_Loss.avg:3.2f};"
                  "Fairness.loss:{F_Loss.avg:3.2f};"
                  "Partition.loss:{P_Loss.avg:3.2f};".format(step + 1,
                                                             args.iters,
                                                             accuracy,
                                                             nmi,
                                                             bal,
                                                             en_0,
                                                             en_1,
                                                             C_Loss=C_LOSS,
                                                             F_Loss=F_LOSS,
                                                             P_Loss=P_LOSS))

            # log tsne visualisation
            if args.encoder_type == "vae":
                tsne_img = tsne_visualization(dataloader_list,
                                              encoder,
                                              args.cluster_number,
                                              encoder_type=args.encoder_type,
                                              device=device)

                if tsne_img is not None:
                    # NOTE(review): this logs the pyplot module (its current
                    # figure), not tsne_img, which is only used as a success
                    # flag. Confirm this is intended.
                    wandb.log({
                        f"{save_name} TSNE": plt,
                        f"{save_name} step": step
                    })

    torch.save(dfc.state_dict(), f'{args.log_dir}DFC_{save_name}.pth')

    if has_group_1:
        torch.save(critic.state_dict(),
                   f'{args.log_dir}CRITIC_{save_name}.pth')

    return dfc