def main(device, args):

    loss1_func = nn.CrossEntropyLoss()
    loss2_func = softmax_kl_loss

    dataset_kwargs = {
        'dataset': args.dataset,
        'data_dir': args.data_dir,
        'download': args.download,
        'debug_subset_size': args.batch_size if args.debug else None
    }
    dataloader_kwargs = {
        'batch_size': args.batch_size,
        'drop_last': True,
        'pin_memory': True,
        'num_workers': args.num_workers,
    }
    dataloader_unlabeled_kwargs = {
        'batch_size': args.batch_size * 5,
        'drop_last': True,
        'pin_memory': True,
        'num_workers': args.num_workers,
    }
    dataset_train = get_dataset(transform=get_aug_fedmatch(args.dataset, True),
                                train=True,
                                **dataset_kwargs)

    if args.iid == 'iid':
        dict_users_labeled, dict_users_unlabeled = iid(dataset_train,
                                                       args.num_users,
                                                       args.label_rate)
    else:
        dict_users_labeled, dict_users_unlabeled = noniid(
            dataset_train, args.num_users, args.label_rate)

    train_loader_unlabeled = {}

    # define model
    model_glob = get_model('fedfixmatch', args.backbone).to(device)
    if torch.cuda.device_count() > 1:
        model_glob = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_glob)
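
    # One federated round per iteration of the loop below: (1) train the global
    # model on the pooled labeled data, (2) evaluate every 5 rounds, (3) let
    # each sampled client fine-tune a copy on its unlabeled data with
    # confidence-masked pseudo-labels, (4) aggregate the weights with FedAvg.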

    for iter in range(args.num_epochs):

        model_glob.train()
        optimizer = torch.optim.SGD(model_glob.parameters(),
                                    lr=0.01,
                                    momentum=0.5)
        class_criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=-1)

        train_loader_labeled = torch.utils.data.DataLoader(
            dataset=DatasetSplit(dataset_train, dict_users_labeled),
            shuffle=True,
            **dataloader_kwargs)

        for batch_idx, ((img, img_ema),
                        label) in enumerate(train_loader_labeled):
            input_var = img.to(device)
            ema_input_var = img_ema.to(device)
            target_var = label.to(device)
            minibatch_size = len(target_var)
            labeled_minibatch_size = target_var.data.ne(-1).sum()
            ema_model_out = model_glob(ema_input_var)
            model_out = model_glob(input_var)
            if isinstance(model_out, Variable):
                logit1 = model_out
                ema_logit = ema_model_out
            else:
                assert len(model_out) == 2
                assert len(ema_model_out) == 2
                logit1, logit2 = model_out
                ema_logit, _ = ema_model_out

            ema_logit = ema_logit.detach()
            class_logit, cons_logit = logit1, logit1
            class_loss = class_criterion(class_logit,
                                         target_var) / minibatch_size
            ema_class_loss = class_criterion(ema_logit,
                                             target_var) / minibatch_size
            # FixMatch-style pseudo-labelling: confident predictions from the
            # detached branch provide targets for the other augmentation.
            pseudo_label1 = torch.softmax(ema_logit, dim=-1)
            max_probs, targets_u = torch.max(pseudo_label1, dim=-1)
            mask = max_probs.ge(args.threshold_pl).float()
            Lu = (F.cross_entropy(class_logit, targets_u, reduction='none') *
                  mask).mean()
            loss = class_loss + Lu
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


        del train_loader_labeled
        gc.collect()
        torch.cuda.empty_cache()

        if iter % 5 == 0:
            test_loader = torch.utils.data.DataLoader(dataset=get_dataset(
                transform=get_aug(args.dataset, False, train_classifier=False),
                train=False,
                **dataset_kwargs),
                                                      shuffle=False,
                                                      **dataloader_kwargs)
            model_glob.eval()
            accuracy, loss_train_test_labeled = test_img(
                model_glob, test_loader, args)
            del test_loader
            gc.collect()
            torch.cuda.empty_cache()

        w_locals, loss_locals, loss0_locals, loss2_locals = [], [], [], []

        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
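
        # local pseudo-label training on each sampled client's unlabeled split;
        # the resulting weights are collected in w_locals for aggregation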

        for idx in idxs_users:

            loss_local = []
            loss0_local = []
            loss2_local = []

            model_local = copy.deepcopy(model_glob).to(args.device)

            train_loader_unlabeled = torch.utils.data.DataLoader(
                dataset=DatasetSplit(dataset_train, dict_users_unlabeled[idx]),
                shuffle=True,
                **dataloader_unlabeled_kwargs)

            # each client optimises its own copy of the global model
            optimizer = torch.optim.SGD(model_local.parameters(),
                                        lr=0.01,
                                        momentum=0.5)

            model_local.train()

            for i, ((img, img_ema),
                    label) in enumerate(train_loader_unlabeled):

                input_var = img.to(device)
                ema_input_var = img_ema.to(device)
                target_var = label.to(device)
                minibatch_size = len(target_var)
                labeled_minibatch_size = target_var.data.ne(-1).sum()
                ema_model_out = model_local(ema_input_var)
                model_out = model_local(input_var)
                if isinstance(model_out, Variable):
                    logit1 = model_out
                    ema_logit = ema_model_out
                else:
                    assert len(model_out) == 2
                    assert len(ema_model_out) == 2
                    logit1, logit2 = model_out
                    ema_logit, _ = ema_model_out

                ema_logit = ema_logit.detach()
                class_logit, cons_logit = logit1, logit1
                # FixMatch-style pseudo-labelling: confident predictions from
                # the detached branch provide targets for the other
                # augmentation.
                pseudo_label1 = torch.softmax(ema_logit, dim=-1)
                max_probs, targets_u = torch.max(pseudo_label1, dim=-1)
                mask = max_probs.ge(args.threshold_pl).float()
                Lu = (F.cross_entropy(class_logit, targets_u,
                                      reduction='none') * mask).mean()
                loss = Lu
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            w_locals.append(copy.deepcopy(model_local.state_dict()))
            #             loss_locals.append(sum(loss_local) / len(loss_local) )

            del model_local
            gc.collect()
            del train_loader_unlabeled
            gc.collect()
            torch.cuda.empty_cache()

        w_glob = FedAvg(w_locals)
        model_glob.load_state_dict(w_glob)

        #         loss_avg = sum(loss_locals) / len(loss_locals)

        if iter % 5 == 0:
            print('Round {:3d}, Acc {:.3f}'.format(iter, accuracy))
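

# For reference, a minimal sketch of the state_dict averaging the imported
# FedAvg helper is assumed to perform; the name _fedavg_sketch is hypothetical
# and the actual implementation may differ.
def _fedavg_sketch(w_locals):
    """Element-wise average of the sampled clients' state_dicts."""
    w_avg = copy.deepcopy(w_locals[0])
    for key in w_avg.keys():
        for w in w_locals[1:]:
            w_avg[key] = w_avg[key] + w[key]
        w_avg[key] = torch.div(w_avg[key], len(w_locals))
    return w_avg

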
def main(device, args):

    loss1_func = nn.CrossEntropyLoss()
    loss2_func = softmax_kl_loss

    dataset_kwargs = {
        'dataset': args.dataset,
        'data_dir': args.data_dir,
        'download': args.download,
        'debug_subset_size': args.batch_size if args.debug else None
    }
    dataloader_kwargs = {
        'batch_size': args.batch_size,
        'drop_last': True,
        'pin_memory': True,
        'num_workers': args.num_workers,
    }
    dataloader_unlabeled_kwargs = {
        'batch_size': args.batch_size * 5,
        'drop_last': True,
        'pin_memory': True,
        'num_workers': args.num_workers,
    }
    dataset_train = get_dataset(transform=get_aug_fedmatch(args.dataset, True),
                                train=True,
                                **dataset_kwargs)

    if args.iid == 'iid':
        dict_users_labeled, dict_users_unlabeled = iid(dataset_train,
                                                       args.num_users,
                                                       args.label_rate)
    else:
        dict_users_labeled, dict_users_unlabeled = noniid(
            dataset_train, args.num_users, args.label_rate)
    train_loader_unlabeled = {}

    # define model
    model_glob = get_model('fedfixmatch', args.backbone).to(device)
    if torch.cuda.device_count() > 1:
        model_glob = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_glob)

    model_local_idx = set()

    user_epoch = {}
    lr_scheduler = {}
    accuracy = []
    class_criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=-1)
    if args.dataset == 'cifar' and args.iid != 'noniid_tradition':
        consistency_criterion = softmax_kl_loss
    else:
        consistency_criterion = softmax_mse_loss
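
    # One federated round per iteration of the loop below: supervised plus
    # consistency training of the global model on the pooled labeled data,
    # evaluation, then local client training with a consistency loss and a
    # proximal term anchoring each client to the current global weights,
    # followed by FedAvg aggregation.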

    for iter in range(args.num_epochs):

        model_glob.train()
        optimizer = torch.optim.SGD(model_glob.parameters(),
                                    lr=0.01,
                                    momentum=0.5)

        train_loader_labeled = torch.utils.data.DataLoader(
            dataset=DatasetSplit(dataset_train, dict_users_labeled),
            shuffle=True,
            **dataloader_kwargs)

        for batch_idx, ((img, img_ema),
                        label) in enumerate(train_loader_labeled):

            img, img_ema, label = img.to(args.device), img_ema.to(
                args.device), label.to(args.device)
            input_var = img
            target_var = label
            minibatch_size = len(target_var)
            labeled_minibatch_size = target_var.data.ne(-1).sum()
            # no gradients are needed for the consistency-target branch
            with torch.no_grad():
                ema_model_out = model_glob(img_ema)
            model_out = model_glob(input_var)
            if isinstance(model_out, Variable):
                logit1 = model_out
                ema_logit = ema_model_out
            else:
                assert len(model_out) == 2
                assert len(ema_model_out) == 2
                logit1, logit2 = model_out
                ema_logit, _ = ema_model_out
            ema_logit = ema_logit.detach()
            class_logit, cons_logit = logit1, logit1
            classification_weight = 1
            class_loss = classification_weight * class_criterion(
                class_logit, target_var) / minibatch_size
            ema_class_loss = class_criterion(ema_logit,
                                             target_var) / minibatch_size
            consistency_weight = get_current_consistency_weight(iter)
            consistency_loss = consistency_weight * consistency_criterion(
                cons_logit, ema_logit) / minibatch_size
            loss = class_loss + consistency_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        del train_loader_labeled
        gc.collect()
        torch.cuda.empty_cache()

        if iter % 1 == 0:
            test_loader = torch.utils.data.DataLoader(dataset=get_dataset(
                transform=get_aug(args.dataset, False, train_classifier=False),
                train=False,
                **dataset_kwargs),
                                                      shuffle=False,
                                                      **dataloader_kwargs)
            model_glob.eval()
            acc, loss_train_test_labeled = test_img(model_glob, test_loader,
                                                    args)
            accuracy.append(str(acc))
            del test_loader
            gc.collect()
            torch.cuda.empty_cache()

        w_locals, loss_locals, loss0_locals, loss2_locals = [], [], [], []

        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
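
        # local consistency + proximal training on each sampled client's
        # unlabeled split; weights are collected in w_locals for aggregation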

        for idx in idxs_users:
            if idx in user_epoch.keys():
                user_epoch[idx] += 1
            else:
                user_epoch[idx] = 1

            loss_local = []
            loss0_local = []
            loss2_local = []

            model_local = copy.deepcopy(model_glob).to(args.device)

            train_loader_unlabeled = torch.utils.data.DataLoader(
                dataset=DatasetSplit(dataset_train, dict_users_unlabeled[idx]),
                shuffle=True,
                **dataloader_unlabeled_kwargs)

            optimizer = torch.optim.SGD(model_local.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay,
                                        nesterov=False)

            model_local.train()

            for i, ((img, img_ema),
                    label) in enumerate(train_loader_unlabeled):

                img, img_ema, label = img.to(args.device), img_ema.to(
                    args.device), label.to(args.device)
                adjust_learning_rate(optimizer, user_epoch[idx], i,
                                     len(train_loader_unlabeled), args)
                input_var = img
                target_var = label
                minibatch_size = len(target_var)
                labeled_minibatch_size = target_var.data.ne(-1).sum()
                # no gradients are needed for the consistency-target branch
                with torch.no_grad():
                    ema_model_out = model_local(img_ema)
                model_out = model_local(input_var)
                if isinstance(model_out, Variable):
                    logit1 = model_out
                    ema_logit = ema_model_out
                else:
                    assert len(model_out) == 2
                    assert len(ema_model_out) == 2
                    logit1, logit2 = model_out
                    ema_logit, _ = ema_model_out
                ema_logit = ema_logit.detach()
                class_logit, cons_logit = logit1, logit1

                consistency_weight = get_current_consistency_weight(
                    user_epoch[idx])
                consistency_loss = consistency_weight * consistency_criterion(
                    cons_logit, ema_logit) / minibatch_size

                # proximal term (as in FedProx): penalise drift of the local
                # weights away from the current global weights
                Lprox = 1 / 2 * dist(model_local.state_dict(),
                                     model_glob.state_dict())

                loss = consistency_loss + Lprox
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            w_locals.append(copy.deepcopy(model_local.state_dict()))

            del model_local
            gc.collect()
            del train_loader_unlabeled
            gc.collect()
            torch.cuda.empty_cache()

        w_glob = FedAvg(w_locals)
        model_glob.load_state_dict(w_glob)

        #         loss_avg = sum(loss_locals) / len(loss_locals)

        if iter % 1 == 0:
            print('Round {:3d}, Acc {:.2f}%'.format(iter, acc))
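

# For reference, minimal sketches of two helpers used above and defined
# elsewhere in the repo. The names, the ramp-up length, and the weighting below
# are assumptions; the actual get_current_consistency_weight and dist may
# differ.
def _consistency_weight_sketch(epoch, max_weight=1.0, rampup_length=30):
    """Sigmoid ramp-up of the consistency weight, as in Mean Teacher."""
    if rampup_length == 0:
        return max_weight
    phase = 1.0 - min(max(epoch, 0), rampup_length) / rampup_length
    return max_weight * float(np.exp(-5.0 * phase * phase))


def _state_dict_distance_sketch(w_local, w_glob):
    """Squared L2 distance between two state_dicts (the proximal penalty)."""
    total = 0.0
    for key in w_glob.keys():
        total = total + torch.sum(
            (w_local[key].float() - w_glob[key].float())**2)
    return total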