Code example #1
def test_basic(self):
    """
    Basic test to check that the calculation is sensible and conforms to the formula.
    """
    test_tensor = torch.Tensor([[0.5, 0.5], [0.0, 1.0]])
    output = target_distribution(test_tensor)
    # assertAlmostEqual compares scalars, not tuples, so check element-wise
    for got, want in zip(output[0].tolist(), (0.75, 0.25)):
        self.assertAlmostEqual(got, want)
    for got, want in zip(output[1].tolist(), (0.0, 1.0)):
        self.assertAlmostEqual(got, want)
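
Note: every example on this page exercises the same target_distribution helper. For reference, a minimal sketch consistent with the test above, following the DEC auxiliary-distribution formula p_ij ∝ q_ij^2 / f_j (f_j being the soft cluster frequency); each project's own version may differ, e.g. the Keras examples below apply it to NumPy arrays rather than tensors.

import torch

def target_distribution(batch: torch.Tensor) -> torch.Tensor:
    # square the soft assignments and divide by the soft cluster frequency f_j = sum_i q_ij
    weight = (batch ** 2) / torch.sum(batch, dim=0)
    # renormalise each row so it is again a probability distribution
    return (weight.t() / torch.sum(weight, dim=1)).t()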
Code example #2
def activeUnsL(model, node, label, features, adj_lists, num_features, num_hidden, num_cls, filetime, labels, xi=1e-6, eps=2.5, num_iters=10):
    # encode the nodes, then run DEC-style clustering on the encoded space
    # (xi, eps and num_iters hint at a perturbation step that this body never uses)
    
    encSpc = model(node, actE=True)
    
    dec_ae = DEC_AE(50, 100, num_hidden)
    dec = DEC(num_cls, 50, dec_ae)
    kmeans = KMeans(n_clusters=dec.cluster_number, n_init=20)
    # form initial cluster centres
    dec.pretrain(encSpc.data)
    features = dec.ae.encoder(encSpc).detach()
    predicted = kmeans.fit_predict(features)
    predicted_previous = torch.tensor(np.copy(predicted), dtype=torch.long)
    _, accuracy, _, _ = cluster_accuracy(label, predicted)
    print("ACCU", accuracy)
    cluster_centers = torch.tensor(kmeans.cluster_centers_, dtype=torch.float)
    dec.assignment.cluster_centers = torch.nn.Parameter(cluster_centers)
    loss_function = nn.KLDivLoss(reduction='sum')  # size_average=False is deprecated
    delta_label = None
    optimizer = torch.optim.SGD(dec.parameters(), lr=0.01, momentum=0.9)
    for epoch in range(250):
        dec.train()
        output = dec(encSpc)
        target = target_distribution(output).detach()
        loss = loss_function(output.log(), target) / output.shape[0]
        
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
        features = dec.ae.encoder(encSpc).detach()
        predicted = output.argmax(dim=1)
        delta_label = float((predicted != predicted_previous).float().sum().item()) / predicted_previous.shape[0]

        predicted_previous = predicted
        _, accuracy, _, _ = cluster_accuracy(label, predicted.cpu().numpy())  # same argument order as above
        
        if epoch % 50 == 49:
            # confusion-style count matrix between predicted clusters and true labels
            pred_np = predicted.cpu().numpy()
            label_np = np.asarray(label)
            count_matrix = np.zeros((num_cls, num_cls), dtype=np.int64)
            for i in range(len(pred_np)):
                count_matrix[pred_np[i], label_np[i]] += 1
            for i in range(num_cls):
                print(count_matrix[i])
            summary(node, labels, pred_np, num_cls, filetime, outlog=False, output=None)

            print(loss.item())
            print("ACCU", accuracy)
Code example #3
def Train(epochs, model, data, adj, label, lr, pre_model_save_path,
          final_model_save_path, n_clusters, original_acc, gamma_value,
          lambda_value, device):
    optimizer = Adam(model.parameters(), lr=lr)
    model.load_state_dict(torch.load(pre_model_save_path, map_location='cpu'))
    with torch.no_grad():
        x_hat, z_hat, adj_hat, z_ae, z_igae, _, _, _, z_tilde = model(
            data, adj)
    kmeans = KMeans(n_clusters=n_clusters, n_init=20)
    cluster_id = kmeans.fit_predict(z_tilde.data.cpu().numpy())
    model.cluster_layer.data = torch.tensor(kmeans.cluster_centers_).to(device)
    eva(label, cluster_id, 'Initialization')

    for epoch in range(epochs):
        # if opt.args.name in use_adjust_lr:
        #     adjust_learning_rate(optimizer, epoch)
        x_hat, z_hat, adj_hat, z_ae, z_igae, q, q1, q2, z_tilde = model(
            data, adj)

        tmp_q = q.data
        p = target_distribution(tmp_q)

        loss_ae = F.mse_loss(x_hat, data)
        loss_w = F.mse_loss(z_hat, torch.spmm(adj, data))
        loss_a = F.mse_loss(adj_hat, adj.to_dense())
        loss_igae = loss_w + gamma_value * loss_a
        loss_kl = F.kl_div((q.log() + q1.log() + q2.log()) / 3,
                           p,
                           reduction='batchmean')
        loss = loss_ae + loss_igae + lambda_value * loss_kl
        print('{} loss: {}'.format(epoch, loss))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        kmeans = KMeans(n_clusters=n_clusters,
                        n_init=20).fit(z_tilde.data.cpu().numpy())

        acc, nmi, ari, f1 = eva(label, kmeans.labels_, epoch)
        acc_result.append(acc)  # module-level result lists, defined outside this snippet
        nmi_result.append(nmi)
        ari_result.append(ari)
        f1_result.append(f1)

        if acc > original_acc:
            original_acc = acc
            torch.save(model.state_dict(), final_model_save_path)
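
Note: the q, q1 and q2 tensors above are soft cluster assignments produced inside the model's forward pass. For reference, a sketch of how DEC-style models typically compute such a q with a Student's t kernel over distances to the cluster centers (an illustration, not this project's code):

import torch

def soft_assignment(z: torch.Tensor, centers: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # squared Euclidean distance between each embedding and each cluster center
    dist_sq = torch.sum((z.unsqueeze(1) - centers) ** 2, dim=2)
    # Student's t kernel with alpha degrees of freedom
    q = (1.0 + dist_sq / alpha) ** (-(alpha + 1.0) / 2.0)
    # normalise over clusters so each row sums to 1
    return q / q.sum(dim=1, keepdim=True)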
Code example #4
File: main.py  Project: topteulen/DeepFairClustering
def main():
    set_seed(args.seed)
    torch.cuda.set_device(args.gpu)

    encoder = Encoder().cuda()
    encoder_group_0 = Encoder().cuda()
    encoder_group_1 = Encoder().cuda()

    dfc = DFC(cluster_number=args.k, hidden_dimension=64).cuda()
    dfc_group_0 = DFC(cluster_number=args.k, hidden_dimension=64).cuda()
    dfc_group_1 = DFC(cluster_number=args.k, hidden_dimension=64).cuda()

    critic = AdversarialNetwork(in_feature=args.k,
                                hidden_size=32,
                                max_iter=args.iters,
                                lr_mult=args.adv_mult).cuda()

    # encoder pre-trained with self-reconstruction
    encoder.load_state_dict(torch.load("./save/encoder_pretrain.pth"))

    # encoder and clustering model trained by DEC
    encoder_group_0.load_state_dict(torch.load("./save/encoder_mnist.pth"))
    encoder_group_1.load_state_dict(torch.load("./save/encoder_usps.pth"))
    dfc_group_0.load_state_dict(torch.load("./save/dec_mnist.pth"))
    dfc_group_1.load_state_dict(torch.load("./save/dec_usps.pth"))

    # load clustering centroids given by k-means
    centers = np.loadtxt("./save/centers.txt")
    cluster_centers = torch.tensor(centers,
                                   dtype=torch.float,
                                   requires_grad=True).cuda()
    with torch.no_grad():
        print("loading clustering centers...")
        dfc.state_dict()['assignment.cluster_centers'].copy_(cluster_centers)

    optimizer = torch.optim.Adam(dfc.get_parameters() +
                                 encoder.get_parameters() +
                                 critic.get_parameters(),
                                 lr=args.lr,
                                 weight_decay=5e-4)
    criterion_c = nn.KLDivLoss(reduction="sum")
    criterion_p = nn.MSELoss(reduction="sum")
    C_LOSS = AverageMeter()
    F_LOSS = AverageMeter()
    P_LOSS = AverageMeter()

    encoder_group_0.eval(), encoder_group_1.eval()
    dfc_group_0.eval(), dfc_group_1.eval()

    data_loader = mnist_usps(args)
    len_image_0 = len(data_loader[0])
    len_image_1 = len(data_loader[1])

    for step in range(args.iters):
        encoder.train()
        dfc.train()
        if step % len_image_0 == 0:
            iter_image_0 = iter(data_loader[0])
        if step % len_image_1 == 0:
            iter_image_1 = iter(data_loader[1])

        image_0, _ = next(iter_image_0)
        image_1, _ = next(iter_image_1)

        image_0, image_1 = image_0.cuda(), image_1.cuda()
        image = torch.cat((image_0, image_1), dim=0)

        predict_0, predict_1 = dfc_group_0(
            encoder_group_0(image_0)[0]), dfc_group_1(
                encoder_group_1(image_1)[0])

        z, _, _ = encoder(image)
        output = dfc(z)

        output_0, output_1 = output[0:args.bs, :], output[args.bs:args.bs * 2, :]
        target_0, target_1 = target_distribution(
            output_0).detach(), target_distribution(output_1).detach()

        clustering_loss = 0.5 * criterion_c(output_0.log(), target_0) \
                          + 0.5 * criterion_c(output_1.log(), target_1)
        fair_loss = adv_loss(output, critic)
        partition_loss = 0.5 * criterion_p(aff(output_0), aff(predict_0).detach()) \
                         + 0.5 * criterion_p(aff(output_1), aff(predict_1).detach())
        total_loss = clustering_loss + args.coeff_fair * fair_loss + args.coeff_par * partition_loss

        optimizer = inv_lr_scheduler(optimizer, args.lr, step, args.iters)
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        C_LOSS.update(clustering_loss.item())  # .item() keeps the meter from holding the autograd graph
        F_LOSS.update(fair_loss.item())
        P_LOSS.update(partition_loss.item())

        if step % args.test_interval == args.test_interval - 1 or step == 0:
            predicted, labels = predict(data_loader, encoder, dfc)
            predicted, labels = predicted.cpu().numpy(), labels.numpy()
            _, accuracy = cluster_accuracy(predicted, labels, 10)
            nmi = normalized_mutual_info_score(labels,
                                               predicted,
                                               average_method="arithmetic")
            bal, en_0, en_1 = balance(predicted, 60000)

            print("Step:[{:03d}/{:03d}]  "
                  "Acc:{:2.3f};"
                  "NMI:{:1.3f};"
                  "Bal:{:1.3f};"
                  "En:{:1.3f}/{:1.3f};"
                  "C.loss:{C_Loss.avg:3.2f};"
                  "F.loss:{F_Loss.avg:3.2f};"
                  "P.loss:{P_Loss.avg:3.2f};".format(step + 1,
                                                     args.iters,
                                                     accuracy,
                                                     nmi,
                                                     bal,
                                                     en_0,
                                                     en_1,
                                                     C_Loss=C_LOSS,
                                                     F_Loss=F_LOSS,
                                                     P_Loss=P_LOSS))

    return
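
Note: inv_lr_scheduler is called every step but not defined here. A sketch of the inverse-decay schedule commonly paired with this kind of adversarial training; the gamma and power defaults are assumptions, not values from this project:

def inv_lr_scheduler(optimizer, lr, iter_num, max_iter, gamma=10.0, power=0.75):
    # decay the learning rate as lr * (1 + gamma * p)^(-power), p = progress in [0, 1]
    p = iter_num / max_iter
    decayed = lr * (1.0 + gamma * p) ** (-power)
    for param_group in optimizer.param_groups:
        # honour per-group multipliers such as the critic's lr_mult
        param_group['lr'] = decayed * param_group.get('lr_mult', 1.0)
    return optimizer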
Code example #5
            # logging file
            if not os.path.exists(args.save_dir):
                os.makedirs(args.save_dir)
            logfile = open(args.save_dir + '/dec_log_{}.csv'.format(i), 'w')
            logwriter = csv.DictWriter(
                logfile, fieldnames=['iter', 'acc', 'nmi', 'ari', 'L'])
            logwriter.writeheader()

            loss = 0
            idx = 0
            t0 = time()
            for ite in range(int(args.maxiter)):
                if ite % args.update_interval == 0:
                    q = model.predict_generator(AE_generator, verbose=1)
                    p = target_distribution(
                        q)  # update the auxiliary target distribution p
                    print(p.shape)
                    # evaluate the clustering performance
                    y_pred = q.argmax(1)
                    if y_true is not None:
                        acc = np.round(cluster_acc(y_true, y_pred), 5)
                        nmi = np.round(
                            metrics.normalized_mutual_info_score(
                                y_true, y_pred), 5)
                        ari = np.round(
                            metrics.adjusted_rand_score(y_true, y_pred), 5)
                        loss = np.round(loss, 5)
                        logwriter.writerow(
                            dict(iter=ite, acc=acc, nmi=nmi, ari=ari, L=loss))
                        print(
                            'Iter-%d: ACC= %.4f, NMI= %.4f, ARI= %.4f;  L= %.5f'
                            % (ite, acc, nmi, ari, loss))
Code example #6
    def clustering(self,
                   x,
                   y=None,
                   update_interval=100,
                   maxiter=200,
                   save_dir='./results/dec'):

        print('Updating auxiliary distribution after %d iterations' %
              update_interval)
        save_interval = 10  # save a checkpoint every 10 iterations
        print('Saving models after %d iterations' % save_interval)

        # initialize cluster centers using k-means
        print('Initializing cluster centers with k-means.')
        k_means = KMeans(n_clusters=self.n_clusters,
                         n_init=20,
                         random_state=42)
        y_pred = k_means.fit_predict(self.encoder.predict(x))
        self.model.get_layer(name='clustering').set_weights(
            [k_means.cluster_centers_])

        # logging file
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        log_file = open(save_dir + '/dec_log.csv', 'w')
        log_writer = csv.DictWriter(
            log_file, fieldnames=['iter', 'acc', 'nmi', 'ari', 'L'])
        log_writer.writeheader()

        loss = 0
        for ite in range(int(maxiter)):
            if ite % update_interval == 0:
                q = self.model.predict(x, verbose=0)
                p = target_distribution(
                    q)  # update the auxiliary target distribution p

                # evaluate the clustering performance
                y_pred = q.argmax(1)

                acc, nmi, ari = get_acc_nmi_ari(y, y_pred)
                loss = np.round(loss, 5)
                log_dict = dict(iter=ite, acc=acc, nmi=nmi, ari=ari, L=loss)
                log_writer.writerow(log_dict)
                print('Iter', ite, ': Acc', acc, ', nmi', nmi, ', ari', ari,
                      '; loss=', loss)

            # training on whole data
            loss = self.model.train_on_batch(x=x, y=p)

            # save intermediate model
            if ite % save_interval == 0:
                # save DEC model checkpoints
                print('saving model to:',
                      save_dir + '/DEC_model_' + str(ite) + '.h5')
                self.model.save(save_dir + '/DEC_model_' + str(ite) + '.h5')


        # save the trained model
        log_file.close()
        print('saving model to:', save_dir + '/DEC_model_final.h5')
        self.model.save(save_dir + '/DEC_model_final.h5')

        return y_pred
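
Note: the layer fetched with get_layer(name='clustering') is where the k-means centers are installed. A sketch of such a layer as it appears in common Keras DEC ports (an assumption about this project's implementation):

import tensorflow as tf
from tensorflow.keras import layers

class ClusteringLayer(layers.Layer):
    """Maps embeddings to Student's t soft assignments over trainable cluster centers."""

    def __init__(self, n_clusters, alpha=1.0, **kwargs):
        super().__init__(**kwargs)
        self.n_clusters = n_clusters
        self.alpha = alpha

    def build(self, input_shape):
        # one trainable center per cluster; set_weights([centers]) overwrites these
        self.clusters = self.add_weight(shape=(self.n_clusters, int(input_shape[1])),
                                        initializer='glorot_uniform', name='clusters')

    def call(self, inputs):
        dist_sq = tf.reduce_sum(tf.square(tf.expand_dims(inputs, 1) - self.clusters), axis=2)
        q = (1.0 + dist_sq / self.alpha) ** (-(self.alpha + 1.0) / 2.0)
        return q / tf.reduce_sum(q, axis=1, keepdims=True)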
Code example #7
def train(args, dataloader_list, encoder, device='cpu', centers=None, save_name='DEC'):
    """
        Trains DFC and optionally the critic,
        automatically saves when finished training

    Args:
        args: Namespace object which contains config set from argument parser
              {
                lr,
                seed,
                iters,
                log_dir,
                test_interval,
                adv_multiplier,
                dfc_hidden_dim
              }
        dataloader_list (list): this list may consist of only 1 dataloader or multiple
        encoder: Encoder to use
        encoder_group_0: Optional pre-trained golden standard model
        encoder_group_1: Optional pre-trained golden standard model
        dfc_group_0: Optional cluster centers file obtained with encoder_group_0
        dfc_group_1: Optional cluster centers file obtained with encoder_group_1
        device: Device configuration
        centers: Initial centers clusters if available
        get_loss_trade_off: Proportional importance of individual loss functions
        save_name: Prefix for save files

    Returns:
        DFC: A trained DFC model

    """
    # """
    # Function for training and testing a VAE model.
    # Inputs:
    #     args -
    # """

    set_seed(args.seed)

    if args.half_tensor:
        torch.set_default_tensor_type('torch.HalfTensor')

    dec = DFC(cluster_number=args.cluster_number, hidden_dimension=args.dfc_hidden_dim).to(device)
    wandb.watch(dec)

    if centers is not None:
        cluster_centers = centers.clone().detach().requires_grad_(True).to(device)
        with torch.no_grad():
            print("loading clustering centers...")
            dec.state_dict()['assignment.cluster_centers'].copy_(cluster_centers)
    # depending on the encoder we get the params diff so we have to use this if
    encoder_param = encoder.get_parameters() if args.encoder_type == 'vae' else [
        {"params": get_update_param(encoder), "lr_mult": 1}]
    optimizer = torch.optim.Adam(dec.get_parameters() + encoder_param, lr=args.dec_lr)

    # following the original DEC code: sum-reduced KL, normalised by batch size below
    criterion_c = nn.KLDivLoss(reduction='sum')  # size_average=False is deprecated

    C_LOSS = AverageMeter()

    print("Start training")
    assert 0 < len(dataloader_list) < 3

    concat_dataset = torch.utils.data.ConcatDataset([x.dataset for x in dataloader_list])
    training_dataloader = torch.utils.data.DataLoader(
        dataset=concat_dataset,
        batch_size=args.dec_batch_size,
        shuffle=True,
        num_workers=4
    )

    for step in range(args.dec_iters):
        encoder.train()
        dec.train()

        if step % len(training_dataloader) == 0:
            iterator = iter(training_dataloader)

        image, _ = next(iterator)
        image = image.to(device)
        if args.encoder_type == 'vae':
            z, _, _ = encoder(image)
        elif args.encoder_type == 'resnet50':
            z = encoder(image)
        else:
            raise Exception('Wrong encoder type, how did you get this far in running the code?')
        output = dec(z)

        target = target_distribution(output).detach()

        clustering_loss = criterion_c(output.log(), target) / output.shape[0]

        optimizer.zero_grad()
        clustering_loss.backward()
        optimizer.step()

        C_LOSS.update(clustering_loss.item())  # .item() keeps the meter from holding the autograd graph

        wandb.log({f"{save_name} Train C Loss Avg": C_LOSS.avg, f"{save_name} step": step})
        wandb.log({f"{save_name} Train C Loss Cur": C_LOSS.val, f"{save_name} step": step})

        if step % args.test_interval == args.test_interval - 1 or step == 0:
            predicted, labels = predict(dataloader_list, encoder, dec, device=device, encoder_type=args.encoder_type)
            predicted, labels = predicted.cpu().numpy(), labels.numpy()
            _, accuracy = cluster_accuracy(predicted, labels, args.cluster_number)
            nmi = normalized_mutual_info_score(labels, predicted, average_method="arithmetic")
            bal, en_0, en_1 = balance(predicted, len(dataloader_list[0]), k=args.cluster_number)

            wandb.log(
                {f"{save_name} Train Accuracy": accuracy, f"{save_name} Train NMI": nmi, f"{save_name} Train Bal": bal,
                 f"{save_name} Train Entropy 0": en_0,
                 f"{save_name} Train Entropy 1": en_1, f"{save_name} step": step})

            print("Step:[{:03d}/{:03d}]  "
                  "Acc:{:2.3f};"
                  "NMI:{:1.3f};"
                  "Bal:{:1.3f};"
                  "En:{:1.3f}/{:1.3f};"
                  "Clustering.loss:{C_Loss.avg:3.2f};".format(step + 1, args.dec_iters, accuracy, nmi, bal, en_0,
                                                              en_1, C_Loss=C_LOSS))

            # log tsne visualisation
            if args.encoder_type == "vae":
                tsne_img = tsne_visualization(dataloader_list, encoder, args.cluster_number,
                                              encoder_type=args.encoder_type,
                                              device=device)
                if tsne_img is not None:
                    wandb.log({f"{save_name} TSNE": plt, f"{save_name} step": step})

    torch.save(dec.state_dict(), f'{args.log_dir}DFC_{save_name}.pth')

    return dec
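
Note: C_LOSS above (and the meters in the surrounding examples) is an AverageMeter, a small bookkeeping helper not shown in the snippets. A minimal sketch of the usual implementation assumed here:

class AverageMeter:
    """Tracks the latest value and the running average of a metric."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count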
Code example #8
def train(args,
          dataloader_list,
          encoder,
          encoder_group_0=None,
          encoder_group_1=None,
          dfc_group_0=None,
          dfc_group_1=None,
          device='cpu',
          centers=None,
          get_loss_trade_off=lambda step: (10, 10, 10),
          save_name='DFC'):
    """Trains DFC and optionally the critic,

    automatically saves when finished training

    Args:
        args: Namespace object which contains config set from argument parser
              {
                lr,
                seed,
                iters,
                log_dir,
                test_interval,
                adv_multiplier,
                dfc_hidden_dim
              }
        dataloader_list (list): this list may consist of only 1 dataloader or multiple
        encoder: Encoder to use
        encoder_group_0: Optional pre-trained golden standard model
        encoder_group_1: Optional pre-trained golden standard model
        dfc_group_0: Optional cluster centers file obtained with encoder_group_0
        dfc_group_1: Optional cluster centers file obtained with encoder_group_1
        device: Device configuration
        centers: Initial centers clusters if available
        get_loss_trade_off: Proportional importance of individual loss functions
        save_name: Prefix for save files
    Returns:
        DFC: A trained DFC model
    """

    set_seed(args.seed)
    if args.half_tensor:
        torch.set_default_tensor_type('torch.HalfTensor')

    dfc = DFC(cluster_number=args.cluster_number,
              hidden_dimension=args.dfc_hidden_dim).to(device)
    wandb.watch(dfc)

    critic = AdversarialNetwork(in_feature=args.cluster_number,
                                hidden_size=32,
                                max_iter=args.iters,
                                lr_mult=args.adv_multiplier).to(device)
    wandb.watch(critic)

    if centers is not None:
        cluster_centers = centers.clone().detach().requires_grad_(True).to(
            device)
        with torch.no_grad():
            print("loading clustering centers...")
            dfc.state_dict()['assignment.cluster_centers'].copy_(
                cluster_centers)

    encoder_param = encoder.get_parameters(
    ) if args.encoder_type == 'vae' else [{
        "params": encoder.parameters(),
        "lr_mult": 1
    }]
    optimizer = torch.optim.Adam(dfc.get_parameters() + encoder_param +
                                 critic.get_parameters(),
                                 lr=args.dec_lr,
                                 weight_decay=5e-4)

    criterion_c = nn.KLDivLoss(reduction="sum")
    criterion_p = nn.MSELoss(reduction="sum")
    C_LOSS = AverageMeter()
    F_LOSS = AverageMeter()
    P_LOSS = AverageMeter()

    partition_loss_enabled = True
    if not encoder_group_0 or not encoder_group_1 or not dfc_group_0 or not dfc_group_1:
        print(
            "Missing Golden Standard models, switching to DEC mode instead of DFC."
        )
        partition_loss_enabled = False

    if partition_loss_enabled:
        encoder_group_0.eval(), encoder_group_1.eval()
        dfc_group_0.eval(), dfc_group_1.eval()

    print("Start training")
    assert 0 < len(dataloader_list) < 3
    len_image_0 = len(dataloader_list[0])
    len_image_1 = len(
        dataloader_list[1]) if len(dataloader_list) == 2 else None
    for step in range(args.iters):
        encoder.train()
        dfc.train()

        if step % len_image_0 == 0:
            iter_image_0 = iter(dataloader_list[0])
        if len_image_1 and step % len_image_1 == 0:
            iter_image_1 = iter(dataloader_list[1])

        image_0, _ = next(iter_image_0)
        image_0 = image_0.to(device)
        if len_image_1 is not None:
            image_1, _ = next(iter_image_1)
            image_1 = image_1.to(device)
            image = torch.cat((image_0, image_1), dim=0)
        else:
            image_1 = None
            image = image_0

        if args.encoder_type == 'vae':
            z, _, _ = encoder(image)
        elif args.encoder_type == 'resnet50':
            z = encoder(image)

        else:
            raise Exception(
                'Wrong encoder type, how did you get this far in running the code?'
            )
        output = dfc(z)

        # the golden-standard encoders/heads only exist in full DFC mode,
        # so guard the calls instead of invoking None models in DEC mode
        if partition_loss_enabled:
            features_enc_0 = encoder_group_0(
                image_0)[0] if args.encoder_type == 'vae' else encoder_group_0(
                    image_0)
            predict_0 = dfc_group_0(features_enc_0)
            if image_1 is not None:
                features_enc_1 = encoder_group_1(
                    image_1)[0] if args.encoder_type == 'vae' else encoder_group_1(
                        image_1)
                predict_1 = dfc_group_1(features_enc_1)
            else:
                predict_1 = None
        else:
            predict_0 = predict_1 = None

        output_0 = output[0:args.bs, :]
        output_1 = output[args.bs:args.bs * 2, :] if image_1 is not None else None
        target_0 = target_distribution(output_0).detach()
        target_1 = target_distribution(
            output_1).detach() if output_1 is not None else None

        # Equation (5) in the paper
        # output_0 and output_1 are the probability distributions P of samples being assigned to the k clusters
        # target_0 and target_1 are the auxiliary distributions Q calculated from P, Equation (4) in the paper
        if output_1 is not None:
            clustering_loss = 0.5 * criterion_c(output_0.log(), target_0) \
                              + 0.5 * criterion_c(output_1.log(), target_1)
        else:
            clustering_loss = criterion_c(output_0.log(), target_0)

        # Equation (2) in the paper
        # output = D(A(F(X)))
        # the critic models the distribution of the categorical sensitive subgroup variable G
        if len(dataloader_list) > 1:
            fair_loss, critic_acc = adv_loss(output, critic, device=device)
        else:
            fair_loss, critic_acc = 0, 0

        if partition_loss_enabled:
            # Equation (3) in the paper
            # predict_0 and predict_1 are the soft assignments of the pretrained golden-standard models,
            # output_0 and output_1 those of the DFC being trained.
            # The loss is high if the two (and thus the cluster structures) differ.
            if predict_1 is not None:
                partition_loss = 0.5 * criterion_p(aff(output_0), aff(predict_0).detach()) \
                                 + 0.5 * criterion_p(aff(output_1), aff(predict_1).detach())
            else:
                partition_loss = criterion_p(aff(output_0),
                                             aff(predict_0).detach())
        else:
            partition_loss = 0

        loss_trade_off = get_loss_trade_off(step)
        if args.encoder_type == 'resnet50' and args.dataset == 'office_31':  # alpha_s
            loss_trade_off = list(loss_trade_off)
            loss_trade_off[1] = ((512 / 128)**2) * (31 / 10)

        total_loss = loss_trade_off[0] * fair_loss + loss_trade_off[
            1] * partition_loss + loss_trade_off[2] * clustering_loss

        optimizer = inv_lr_scheduler(optimizer, args.lr, step, args.iters)
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        C_LOSS.update(float(clustering_loss))  # float() detaches from the autograd graph
        F_LOSS.update(float(fair_loss))  # fair/partition losses may be plain ints in DEC mode
        P_LOSS.update(float(partition_loss))

        wandb.log({
            f"{save_name} Train C Loss Avg": C_LOSS.avg,
            f"{save_name} Train F Loss Avg": F_LOSS.avg,
            f"{save_name} Train P Loss Avg": P_LOSS.avg,
            f"{save_name} step": step,
            f"{save_name} Critic ACC": critic_acc
        })
        wandb.log({
            f"{save_name} Train C Loss Cur": C_LOSS.val,
            f"{save_name} Train F Loss Cur": F_LOSS.val,
            f"{save_name} Train P Loss Cur": P_LOSS.val,
            f"{save_name} step": step
        })

        if step % args.test_interval == args.test_interval - 1 or step == 0:

            predicted, labels = predict(dataloader_list,
                                        encoder,
                                        dfc,
                                        device=device,
                                        encoder_type=args.encoder_type)
            predicted, labels = predicted.cpu().numpy(), labels.numpy()
            _, accuracy = cluster_accuracy(predicted, labels,
                                           args.cluster_number)
            nmi = normalized_mutual_info_score(labels,
                                               predicted,
                                               average_method="arithmetic")
            bal, en_0, en_1 = balance(predicted,
                                      len_image_0,
                                      k=args.cluster_number)

            wandb.log({
                f"{save_name} Train Accuracy": accuracy,
                f"{save_name} Train NMI": nmi,
                f"{save_name} Train Bal": bal,
                f"{save_name} Train Entropy 0": en_0,
                f"{save_name} Train Entropy 1": en_1,
                f"{save_name} step": step
            })

            print("Step:[{:03d}/{:03d}]  "
                  "Acc:{:2.3f};"
                  "NMI:{:1.3f};"
                  "Bal:{:1.3f};"
                  "En:{:1.3f}/{:1.3f};"
                  "Clustering.loss:{C_Loss.avg:3.2f};"
                  "Fairness.loss:{F_Loss.avg:3.2f};"
                  "Partition.loss:{P_Loss.avg:3.2f};".format(step + 1,
                                                             args.iters,
                                                             accuracy,
                                                             nmi,
                                                             bal,
                                                             en_0,
                                                             en_1,
                                                             C_Loss=C_LOSS,
                                                             F_Loss=F_LOSS,
                                                             P_Loss=P_LOSS))

            # log tsne visualisation
            if args.encoder_type == "vae":
                tsne_img = tsne_visualization(dataloader_list,
                                              encoder,
                                              args.cluster_number,
                                              encoder_type=args.encoder_type,
                                              device=device)

                if tsne_img is not None:
                    wandb.log({
                        f"{save_name} TSNE": plt,
                        f"{save_name} step": step
                    })

    torch.save(dfc.state_dict(), f'{args.log_dir}DFC_{save_name}.pth')

    if len(dataloader_list) > 1:
        torch.save(critic.state_dict(),
                   f'{args.log_dir}CRITIC_{save_name}.pth')

    return dfc
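
Note: aff() in the partition loss is not defined in this snippet. A plausible sketch, assuming it builds a pairwise affinity matrix so the MSE compares cluster structures rather than individual assignments; the project's own aff() may normalise differently:

import torch

def aff(x: torch.Tensor) -> torch.Tensor:
    # pairwise inner-product affinity between samples' soft assignment vectors
    return torch.mm(x, x.t())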