Example #1
def main(args):

    train_set = get_dataset(
        args.dataset, 
        args.data_dir, 
        transform=get_aug(args.model, args.image_size, train=False, train_classifier=True), 
        train=True, 
        download=args.download, # default is False
        debug_subset_size=args.batch_size if args.debug else None
    )
    test_set = get_dataset(
        args.dataset, 
        args.data_dir, 
        transform=get_aug(args.model, args.image_size, train=False, train_classifier=False), 
        train=False, 
        download=args.download, # default is False
        debug_subset_size=args.batch_size if args.debug else None
    )


    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    model = get_backbone(args.backbone)
    classifier = nn.Linear(in_features=model.output_dim, out_features=10, bias=True).to(args.device)

    assert args.eval_from is not None
    save_dict = torch.load(args.eval_from, map_location='cpu')
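    # checkpoint keys are saved with a 'backbone.' prefix; k[9:] strips those 9
    # characters so the keys match the bare backbone module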
    msg = model.load_state_dict({k[9:]:v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True)
    
    # print(msg)
    model = model.to(args.device)
    model = torch.nn.DataParallel(model)

    # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
    classifier = torch.nn.DataParallel(classifier)
    # define optimizer
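    # learning rates below follow the linear scaling rule: lr = base_lr * batch_size / 256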
    optimizer = get_optimizer(
        args.optimizer, classifier, 
        lr=args.base_lr*args.batch_size/256, 
        momentum=args.momentum, 
        weight_decay=args.weight_decay)

    # define lr scheduler
    lr_scheduler = LR_Scheduler(
        optimizer,
        args.warmup_epochs, args.warmup_lr*args.batch_size/256, 
        args.num_epochs, args.base_lr*args.batch_size/256, args.final_lr*args.batch_size/256, 
        len(train_loader),
    )

    loss_meter = AverageMeter(name='Loss')
    acc_meter = AverageMeter(name='Accuracy')

    # Start training
    global_progress = tqdm(range(0, args.num_epochs), desc=f'Evaluating')
    for epoch in global_progress:
        loss_meter.reset()
        model.eval()
        classifier.train()
        local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.num_epochs}', disable=args.hide_progress)
        
        for idx, (images, labels) in enumerate(local_progress):

            classifier.zero_grad()
            with torch.no_grad():
                feature = model(images.to(args.device))

            preds = classifier(feature)

            loss = F.cross_entropy(preds, labels.to(args.device))

            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()
            local_progress.set_postfix({'lr':lr, "loss":loss_meter.val, 'loss_avg':loss_meter.avg})
        

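        # with head_tail_accuracy set, run the test loop only on the first and last epoch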
        if args.head_tail_accuracy and epoch != 0 and (epoch+1) != args.num_epochs: continue

        local_progress=tqdm(test_loader, desc=f'Test {epoch}/{args.num_epochs}', disable=args.hide_progress)
        classifier.eval()
        correct, total = 0, 0
        acc_meter.reset()
        for idx, (images, labels) in enumerate(local_progress):
            with torch.no_grad():
                feature = model(images.to(args.device))
                preds = classifier(feature).argmax(dim=1)
                correct = (preds == labels.to(args.device)).sum().item()
                acc_meter.update(correct/preds.shape[0])
                local_progress.set_postfix({'accuracy': acc_meter.avg})
        
        global_progress.set_postfix({"epoch":epoch, 'accuracy':acc_meter.avg*100})
Example #2
def main(args):
    train_info = []
    best_epoch = np.zeros(5)
    for val_folder_index in range(5):
        best_balance_acc = 0
        whole_train_list = ['D8E6', '117E', '676F', 'E2D7', 'BE52']
        val_WSI_list = whole_train_list[val_folder_index]
        train_WSI_list = [w for i, w in enumerate(whole_train_list)
                          if i != val_folder_index]
        train_directory = '../data/finetune/1percent/'
        valid_directory = '../data/finetune/1percent/'
        dataset = {}
        train_subsets = [
            datasets.ImageFolder(root=train_directory + wsi,
                                 transform=get_aug(train=False,
                                                   train_classifier=True,
                                                   **args.aug_kwargs))
            for wsi in train_WSI_list
        ]
        dataset['valid'] = datasets.ImageFolder(
            root=valid_directory + val_WSI_list,
            transform=get_aug(train=False,
                              train_classifier=False,
                              **args.aug_kwargs))
        dataset['train'] = data.ConcatDataset(train_subsets)

        train_loader = torch.utils.data.DataLoader(
            dataset=dataset['train'],
            batch_size=args.eval.batch_size,
            shuffle=True,
            **args.dataloader_kwargs)
        test_loader = torch.utils.data.DataLoader(
            dataset=dataset['valid'],
            batch_size=args.eval.batch_size,
            shuffle=False,
            **args.dataloader_kwargs)

        model = get_backbone(args.model.backbone)
        classifier = nn.Linear(in_features=model.output_dim,
                               out_features=9,
                               bias=True).to(args.device)

        assert args.eval_from is not None
        save_dict = torch.load(args.eval_from, map_location='cpu')
        msg = model.load_state_dict(
            {
                k[9:]: v
                for k, v in save_dict['state_dict'].items()
                if k.startswith('backbone.')
            },
            strict=True)

        # print(msg)
        model = model.to(args.device)
        model = torch.nn.DataParallel(model)

        classifier = torch.nn.DataParallel(classifier)
        # define optimizer
        optimizer = get_optimizer(
            args.eval.optimizer.name,
            classifier,
            lr=args.eval.base_lr * args.eval.batch_size / 256,
            momentum=args.eval.optimizer.momentum,
            weight_decay=args.eval.optimizer.weight_decay)

        # define lr scheduler
        lr_scheduler = LR_Scheduler(
            optimizer,
            args.eval.warmup_epochs,
            args.eval.warmup_lr * args.eval.batch_size / 256,
            args.eval.num_epochs,
            args.eval.base_lr * args.eval.batch_size / 256,
            args.eval.final_lr * args.eval.batch_size / 256,
            len(train_loader),
        )

        loss_meter = AverageMeter(name='Loss')
        acc_meter = AverageMeter(name='Accuracy')

        # Start training
        global_progress = tqdm(range(0, args.eval.num_epochs),
                               desc=f'Evaluating')
        for epoch in global_progress:
            loss_meter.reset()
            model.eval()
            classifier.train()
            local_progress = tqdm(train_loader,
                                  desc=f'Epoch {epoch}/{args.eval.num_epochs}',
                                  disable=True)

            for idx, (images, labels) in enumerate(local_progress):
                classifier.zero_grad()
                with torch.no_grad():
                    feature = model(images.to(args.device))

                preds = classifier(feature)

                loss = F.cross_entropy(preds, labels.to(args.device))

                loss.backward()
                optimizer.step()
                loss_meter.update(loss.item())
                lr = lr_scheduler.step()
                local_progress.set_postfix({
                    'lr': lr,
                    "loss": loss_meter.val,
                    'loss_avg': loss_meter.avg
                })

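            # NOTE: 'writer' is not defined in this snippet; it is assumed to be a
            # TensorBoard SummaryWriter (torch.utils.tensorboard.SummaryWriter) created
            # elsewhere, e.g. writer = SummaryWriter(log_dir=...)  # hypothetical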
            writer.add_scalar('Valid/Loss', loss_meter.avg, epoch)
            writer.add_scalar('Valid/Lr', lr, epoch)
            writer.flush()

            PATH = 'checkpoint/exp_0228_triple_1percent/' + val_WSI_list + '/' + val_WSI_list + '_tunelinear_' + str(
                epoch) + '.pth'

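            # torch.save here pickles the whole DataParallel-wrapped classifier module;
            # the target directory must already exist (torch.save does not create it)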
            torch.save(classifier, PATH)

            classifier.eval()
            correct, total = 0, 0
            acc_meter.reset()

            pred_label_for_f1 = np.array([])
            true_label_for_f1 = np.array([])
            for idx, (images, labels) in enumerate(test_loader):
                with torch.no_grad():
                    feature = model(images.to(args.device))
                    preds = classifier(feature).argmax(dim=1)
                    correct = (preds == labels.to(args.device)).sum().item()

                    preds_arr = preds.cpu().detach().numpy()
                    labels_arr = labels.cpu().detach().numpy()
                    pred_label_for_f1 = np.concatenate(
                        [pred_label_for_f1, preds_arr])
                    true_label_for_f1 = np.concatenate(
                        [true_label_for_f1, labels_arr])
                    acc_meter.update(correct / preds.shape[0])

            f1 = f1_score(true_label_for_f1,
                          pred_label_for_f1,
                          average='macro')
            balance_acc = balanced_accuracy_score(true_label_for_f1,
                                                  pred_label_for_f1)
            print('Epoch:  ', str(epoch),
                  f'Accuracy = {acc_meter.avg * 100:.2f}')
            print('F1 score =  ', f1, 'balance acc:  ', balance_acc)
            if balance_acc > best_balance_acc:
                best_epoch[val_folder_index] = epoch
                best_balance_acc = balance_acc
            train_info.append([val_WSI_list, epoch, f1, balance_acc])

    with open('checkpoint/exp_0228_triple_1percent/train_info.csv', 'w') as f:
        # using csv.writer method from CSV package
        write = csv.writer(f)
        write.writerows(train_info)
    print(best_epoch)
Example #3
def main(args, model=None):
    assert args.eval_from is not None or model is not None
    train_set = get_dataset(
        args.dataset,
        args.data_dir,
        transform=get_aug(args.model,
                          args.image_size,
                          train=False,
                          train_classifier=True),
        train=True,
        download=args.download,  # default is False
        debug_subset_size=args.batch_size
        if args.debug else None  # Use a subset of dataset for debugging.
    )
    test_set = get_dataset(
        args.dataset,
        args.data_dir,
        transform=get_aug(args.model,
                          args.image_size,
                          train=False,
                          train_classifier=False),
        train=False,
        download=args.download,  # default is False
        debug_subset_size=args.batch_size if args.debug else None)

    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers,
                                              pin_memory=True,
                                              drop_last=True)

    # The backbone is created (and the eval_from checkpoint loaded) below only when no
    # pretrained model was passed in; rebuilding it unconditionally here would discard
    # the passed-in model and leave the checkpoint unloaded.

    if args.local_rank >= 0 and not torch.distributed.is_initialized():
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")

    if model is None:
        model = get_backbone(args.backbone).to(args.device)
        save_dict = torch.load(args.eval_from, map_location=args.device)
        model.load_state_dict(
            {
                k[9:]: v
                for k, v in save_dict['state_dict'].items()
                if k.startswith('backbone.')
            },
            strict=True)

    output_dim = model.output_dim
    if args.local_rank >= 0:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    classifier = nn.Linear(in_features=output_dim,
                           out_features=len(train_set.classes),
                           bias=True).to(args.device)
    if args.local_rank >= 0:
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier,
            device_ids=[args.local_rank],
            output_device=args.local_rank)

    # define optimizer
    optimizer = get_optimizer(args.optimizer,
                              classifier,
                              lr=args.base_lr * args.batch_size / 256,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    # TODO: linear lr warm up for byol simclr swav
    # args.warm_up_epochs
    # define lr scheduler
    lr_scheduler = LR_Scheduler(optimizer, args.warmup_epochs,
                                args.warmup_lr * args.batch_size / 256,
                                args.num_epochs,
                                args.base_lr * args.batch_size / 256,
                                args.final_lr * args.batch_size / 256,
                                len(train_loader))

    loss_meter = AverageMeter(name='Loss')
    acc_meter = AverageMeter(name='Accuracy')

    # Start training
    global_progress = tqdm(range(0, args.num_epochs), desc=f'Evaluating')
    for epoch in global_progress:
        loss_meter.reset()
        model.eval()
        classifier.train()
        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.num_epochs}',
                              disable=args.hide_progress)

        for idx, (images, labels) in enumerate(local_progress):

            classifier.zero_grad()
            with torch.no_grad():
                feature = model(images.to(args.device))

            preds = classifier(feature)

            loss = F.cross_entropy(preds, labels.to(args.device))

            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()
            local_progress.set_postfix({
                'lr': lr,
                "loss": loss_meter.val,
                'loss_avg': loss_meter.avg
            })

        if args.head_tail_accuracy and epoch != 0 and (epoch +
                                                       1) != args.num_epochs:
            continue

        local_progress = tqdm(test_loader,
                              desc=f'Test {epoch}/{args.num_epochs}',
                              disable=args.hide_progress)
        classifier.eval()
        correct, total = 0, 0
        acc_meter.reset()
        for idx, (images, labels) in enumerate(local_progress):
            with torch.no_grad():
                feature = model(images.to(args.device))
                preds = classifier(feature).argmax(dim=1)
                correct = (preds == labels.to(args.device)).sum().item()
                acc_meter.update(correct / preds.shape[0])
                local_progress.set_postfix({'accuracy': acc_meter.avg})

        global_progress.set_postfix({
            "epoch": epoch,
            'accuracy': acc_meter.avg * 100
        })
Example #4
def main(device, args):
    train_loader = torch.utils.data.DataLoader(
        dataset=get_dataset(transform=get_aug(train=True, **args.aug_kwargs),
                            train=True,
                            **args.dataset_kwargs),
        shuffle=True,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs)
    memory_loader = torch.utils.data.DataLoader(
        dataset=get_dataset(transform=get_aug(train=False,
                                              train_classifier=False,
                                              **args.aug_kwargs),
                            train=True,
                            **args.dataset_kwargs),
        shuffle=False,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=get_dataset(
        transform=get_aug(train=False,
                          train_classifier=False,
                          **args.aug_kwargs),
        train=False,
        **args.dataset_kwargs),
                                              shuffle=False,
                                              batch_size=args.train.batch_size,
                                              **args.dataloader_kwargs)

    # define model
    model = get_model(args.model).to(device)
    model = torch.nn.DataParallel(model)

    # define optimizer
    optimizer = get_optimizer(args.train.optimizer.name,
                              model,
                              lr=args.train.base_lr * args.train.batch_size /
                              256,
                              momentum=args.train.optimizer.momentum,
                              weight_decay=args.train.optimizer.weight_decay)

    lr_scheduler = LR_Scheduler(
        optimizer,
        args.train.warmup_epochs,
        args.train.warmup_lr * args.train.batch_size / 256,
        args.train.num_epochs,
        args.train.base_lr * args.train.batch_size / 256,
        args.train.final_lr * args.train.batch_size / 256,
        len(train_loader),
        constant_predictor_lr=True  # see the end of section 4.2 predictor
    )

    logger = Logger(tensorboard=args.logger.tensorboard,
                    matplotlib=args.logger.matplotlib,
                    log_dir=args.log_dir)
    accuracy = 0
    # Start training

    print("Trying to train model {}".format(model))
    print("Will run up to {} epochs of training".format(
        args.train.stop_at_epoch))

    global_progress = tqdm(range(0, args.train.stop_at_epoch),
                           desc=f'Training')
    for epoch in global_progress:
        model.train()

        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.train.num_epochs}',
                              disable=args.hide_progress)
        for idx, _data in enumerate(local_progress):
            # TODO looks like we might be missing the label?
            ((images1, images2), labels) = _data
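            # each batch provides two augmented views of the same images plus labels;
            # the labels are not used by the self-supervised loss below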

            model.zero_grad()
            data_dict = model.forward(images1.to(device, non_blocking=True),
                                      images2.to(device, non_blocking=True))
            loss = data_dict['loss'].mean()  # ddp
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            data_dict.update({'lr': lr_scheduler.get_lr()})

            local_progress.set_postfix(data_dict)
            logger.update_scalers(data_dict)

        # ignore KNN monitor since it's coded to work ONLY on cuda enabled devices unfortunately
        # check the mnist yaml to see
        if args.train.knn_monitor and epoch % args.train.knn_interval == 0:
            accuracy = knn_monitor(model.module.backbone,
                                   memory_loader,
                                   test_loader,
                                   device,
                                   k=min(args.train.knn_k,
                                         len(memory_loader.dataset)),
                                   hide_progress=args.hide_progress)

        epoch_dict = {"epoch": epoch, "accuracy": accuracy}
        global_progress.set_postfix(epoch_dict)
        logger.update_scalers(epoch_dict)

    # Save checkpoint
    model_path = os.path.join(
        args.ckpt_dir,
        f"{args.name}_{datetime.now().strftime('%m%d%H%M%S')}.pth"
    )  # datetime.now().strftime('%Y%m%d_%H%M%S')
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.module.state_dict()
    }, model_path)
    print(f"Model saved to {model_path}")
    with open(os.path.join(args.log_dir, f"checkpoint_path.txt"), 'w+') as f:
        f.write(f'{model_path}')

    if args.eval is not False:
        args.eval_from = model_path
        linear_eval(args)
Example #5
def main(device, args):
    train_directory = '../data/train'
    image_name_file = '../data/original.csv'
    val_directory = '../data/train'
    train_loader = torch.utils.data.DataLoader(
        dataset=get_dataset('random', train_directory, image_name_file,
            transform=get_aug(train=True, **args.aug_kwargs),
            train=True,
            **args.dataset_kwargs),
        # dataset=datasets.ImageFolder(root=train_directory, transform=get_aug(train=True, **args.aug_kwargs)),
        shuffle=True,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs
    )

    memory_loader = torch.utils.data.DataLoader(
        dataset=datasets.ImageFolder(root=val_directory, transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs)),
        shuffle=False,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=datasets.ImageFolder(root=val_directory, transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs)),
        shuffle=False,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs
    )

    # define model
    model = get_model(args.model).to(device)
    model = torch.nn.DataParallel(model)
    scaler = torch.cuda.amp.GradScaler()

    # define optimizer
    optimizer = get_optimizer(
        args.train.optimizer.name, model,
        lr=args.train.base_lr * args.train.batch_size / 256,
        momentum=args.train.optimizer.momentum,
        weight_decay=args.train.optimizer.weight_decay)

    lr_scheduler = LR_Scheduler(
        optimizer,
        args.train.warmup_epochs, args.train.warmup_lr * args.train.batch_size / 256,
        args.train.num_epochs, args.train.base_lr * args.train.batch_size / 256,
                                  args.train.final_lr * args.train.batch_size / 256,
        len(train_loader),
        constant_predictor_lr=True  # see the end of section 4.2 predictor
    )

    RESUME = False
    start_epoch = 0

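    # NOTE: this resume branch (only taken when RESUME is True) replaces the full
    # SimSiam model built above with a bare backbone plus a 9-class linear head before
    # loading checkpoint['net']; the optimizer created above still references the old
    # model's parameters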
    if RESUME:
        model = get_backbone(args.model.backbone)
        classifier = nn.Linear(in_features=model.output_dim, out_features=9, bias=True).to(args.device)

        assert args.eval_from is not None
        save_dict = torch.load(args.eval_from, map_location='cpu')
        msg = model.load_state_dict({k[9:]: v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')},
                                    strict=True)

        path_checkpoint = "./checkpoint/simsiam-TCGA-0218-nearby_0221134812.pth"  # checkpoint path
        checkpoint = torch.load(path_checkpoint)  # load the checkpoint

        model.load_state_dict(checkpoint['net'])  # load the model's learnable parameters

        optimizer.load_state_dict(checkpoint['optimizer'])  # load the optimizer state
        start_epoch = checkpoint['epoch']  # epoch to resume from

    logger = Logger(tensorboard=args.logger.tensorboard, matplotlib=args.logger.matplotlib, log_dir=args.log_dir)
    accuracy = 0
    # Start training
    global_progress = tqdm(range(start_epoch, args.train.stop_at_epoch), desc=f'Training')
    for epoch in global_progress:
        model.train()

        local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.train.num_epochs}', disable=args.hide_progress)
        for idx, (images1, images2, images3, labels) in enumerate(local_progress):
            model.zero_grad()
            with torch.cuda.amp.autocast():
                data_dict = model.forward(images1.to(device, non_blocking=True), images2.to(device, non_blocking=True),
                                          images3.to(device, non_blocking=True))
                loss = data_dict['loss'].mean()  # ddp
            # loss.backward()
            scaler.scale(loss).backward()
            # optimizer.step()
            scaler.step(optimizer)
            scaler.update()

            lr_scheduler.step()
            data_dict.update({'lr': lr_scheduler.get_lr()})

            local_progress.set_postfix(data_dict)
            logger.update_scalers(data_dict)

        if args.train.knn_monitor and epoch % args.train.knn_interval == 0:
            accuracy = knn_monitor(model.module.backbone, memory_loader, test_loader, device,
                                   k=min(args.train.knn_k, len(memory_loader.dataset)),
                                   hide_progress=args.hide_progress)

        epoch_dict = {"epoch": epoch, "accuracy": accuracy}
        global_progress.set_postfix(epoch_dict)
        logger.update_scalers(epoch_dict)

        checkpoint = {
            "net": model.state_dict(),
            'optimizer': optimizer.state_dict(),
            "epoch": epoch
        }
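        # NOTE: this 'checkpoint' dict (the 'net'/'optimizer'/'epoch' layout that the
        # RESUME branch above expects) is built every epoch but never written to disk;
        # only the 'state_dict'-style files below are saved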
        if (epoch % args.train.save_interval) == 0:
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict()
            }, './checkpoint/exp_0223_triple_400_proj3/ckpt_best_%s.pth' % (str(epoch)))

    # Save checkpoint
    model_path = os.path.join(args.ckpt_dir,
                              f"{args.name}_{datetime.now().strftime('%m%d%H%M%S')}.pth")  # datetime.now().strftime('%Y%m%d_%H%M%S')
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.module.state_dict()
    }, model_path)
    print(f"Model saved to {model_path}")
    with open(os.path.join(args.log_dir, f"checkpoint_path.txt"), 'w+') as f:
        f.write(f'{model_path}')


    if args.eval is not False:
        args.eval_from = model_path
        linear_eval(args)
Example #6
def main(args):

    train_loader = torch.utils.data.DataLoader(dataset=get_dataset(
        transform=get_aug(train=False,
                          train_classifier=True,
                          **args.aug_kwargs),
        train=True,
        **args.dataset_kwargs),
                                               batch_size=args.eval.batch_size,
                                               shuffle=True,
                                               **args.dataloader_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=get_dataset(
        transform=get_aug(train=False,
                          train_classifier=False,
                          **args.aug_kwargs),
        train=False,
        **args.dataset_kwargs),
                                              batch_size=args.eval.batch_size,
                                              shuffle=False,
                                              **args.dataloader_kwargs)

    model = get_backbone(args.model.backbone)
    classifier = nn.Linear(in_features=model.output_dim,
                           out_features=10,
                           bias=True).to(args.device)

    assert args.eval_from is not None
    save_dict = torch.load(args.eval_from, map_location='cpu')
    msg = model.load_state_dict(
        {
            k[9:]: v
            for k, v in save_dict['state_dict'].items()
            if k.startswith('backbone.')
        },
        strict=True)

    # print(msg)
    model = model.to(args.device)
    model = torch.nn.DataParallel(model)

    # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
    classifier = torch.nn.DataParallel(classifier)
    # define optimizer
    optimizer = get_optimizer(args.eval.optimizer.name,
                              classifier,
                              lr=args.eval.base_lr * args.eval.batch_size /
                              256,
                              momentum=args.eval.optimizer.momentum,
                              weight_decay=args.eval.optimizer.weight_decay)

    # define lr scheduler
    lr_scheduler = LR_Scheduler(
        optimizer,
        args.eval.warmup_epochs,
        args.eval.warmup_lr * args.eval.batch_size / 256,
        args.eval.num_epochs,
        args.eval.base_lr * args.eval.batch_size / 256,
        args.eval.final_lr * args.eval.batch_size / 256,
        len(train_loader),
    )

    loss_meter = AverageMeter(name='Loss')
    acc_meter = AverageMeter(name='Accuracy')

    # Start training
    global_progress = tqdm(range(0, args.eval.num_epochs), desc=f'Evaluating')
    for epoch in global_progress:
        loss_meter.reset()
        model.eval()
        classifier.train()
        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.eval.num_epochs}',
                              disable=True)

        for idx, (images, labels) in enumerate(local_progress):
            # If the dataset returns a list of image tensors, stack them along a new
            # leading (batch) dimension: each [C x H x W] becomes [1 x C x H x W], and
            # the N of them are concatenated into a single [N x C x H x W] tensor.
            if isinstance(images, list):
                print(images[1].shape, len(images))
                images = torch.cat(
                    [image.unsqueeze(dim=0) for image in images], dim=0)

            classifier.zero_grad()
            with torch.no_grad():
                feature = model(images.to(args.device))

            preds = classifier(feature)

            loss = F.cross_entropy(preds, labels.to(args.device))

            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()
            local_progress.set_postfix({
                'lr': lr,
                "loss": loss_meter.val,
                'loss_avg': loss_meter.avg
            })

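    # the evaluation below runs once, after the final linear-probe epoch (note the
    # dedent out of the training loop)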
    classifier.eval()
    correct, total = 0, 0
    acc_meter.reset()
    for idx, (images, labels) in enumerate(test_loader):
        with torch.no_grad():
            feature = model(images.to(args.device))
            preds = classifier(feature).argmax(dim=1)
            correct = (preds == labels.to(args.device)).sum().item()
            acc_meter.update(correct / preds.shape[0])
    print(f'Accuracy = {acc_meter.avg*100:.2f}')
Example #7
def main(device, args):
    dataset_kwargs = {
        'dataset': args.dataset,
        'data_dir': args.data_dir,
        'download': args.download,
        'debug_subset_size': args.batch_size if args.debug else None
    }
    dataloader_kwargs = {
        'batch_size': args.batch_size,
        'drop_last': True,
        'pin_memory': True,
        'num_workers': args.num_workers,
    }

    train_loader = torch.utils.data.DataLoader(dataset=get_dataset(
        transform=get_aug(args.model, args.image_size, True),
        train=True,
        **dataset_kwargs),
                                               shuffle=True,
                                               **dataloader_kwargs)
    memory_loader = torch.utils.data.DataLoader(dataset=get_dataset(
        transform=get_aug(args.model,
                          args.image_size,
                          False,
                          train_classifier=False),
        train=True,
        **dataset_kwargs),
                                                shuffle=False,
                                                **dataloader_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=get_dataset(
        transform=get_aug(args.model,
                          args.image_size,
                          False,
                          train_classifier=False),
        train=False,
        **dataset_kwargs),
                                              shuffle=False,
                                              **dataloader_kwargs)

    # define model
    model = get_model(args.model, args.backbone).to(device)
    if args.model == 'simsiam' and args.proj_layers is not None:
        model.projector.set_layers(args.proj_layers)
    model = torch.nn.DataParallel(model)
    if torch.cuda.device_count() > 1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
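        # note: SyncBatchNorm is designed for DistributedDataParallel (one process per
        # GPU); under plain DataParallel it will not synchronize batch-norm statistics
        # across GPUs and may error if no process group has been initialized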

    # define optimizer
    optimizer = get_optimizer(args.optimizer,
                              model,
                              lr=args.base_lr * args.batch_size / 256,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    lr_scheduler = LR_Scheduler(
        optimizer,
        args.warmup_epochs,
        args.warmup_lr * args.batch_size / 256,
        args.num_epochs,
        args.base_lr * args.batch_size / 256,
        args.final_lr * args.batch_size / 256,
        len(train_loader),
        constant_predictor_lr=True  # see the end of section 4.2 predictor
    )

    loss_meter = AverageMeter(name='Loss')
    plot_logger = PlotLogger(params=['lr', 'loss', 'accuracy'])
    # Start training
    global_progress = tqdm(range(0, args.stop_at_epoch), desc=f'Training')
    for epoch in global_progress:
        loss_meter.reset()
        model.train()

        # plot_logger.update({'epoch':epoch, 'accuracy':accuracy})
        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.num_epochs}',
                              disable=args.hide_progress)
        for idx, ((images1, images2), labels) in enumerate(local_progress):

            model.zero_grad()
            loss = model.forward(images1.to(device, non_blocking=True),
                                 images2.to(device, non_blocking=True))
            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()

            data_dict = {'lr': lr, "loss": loss_meter.val}
            local_progress.set_postfix(data_dict)
            plot_logger.update(data_dict)
        accuracy = knn_monitor(model.module.backbone,
                               memory_loader,
                               test_loader,
                               device,
                               k=200,
                               hide_progress=args.hide_progress)
        global_progress.set_postfix({
            "epoch": epoch,
            "loss_avg": loss_meter.avg,
            "accuracy": accuracy
        })
        plot_logger.update({'accuracy': accuracy})
        plot_logger.save(os.path.join(args.output_dir, 'logger.svg'))

        # Save checkpoint

    model_path = os.path.join(
        args.output_dir, f'{args.model}-{args.dataset}-epoch{epoch+1}.pth')
    torch.save(
        {
            'epoch': epoch + 1,
            'state_dict': model.module.state_dict(),
            # 'optimizer':optimizer.state_dict(), # will double the checkpoint file size
            'lr_scheduler': lr_scheduler,
            'args': args,
            'loss_meter': loss_meter,
            'plot_logger': plot_logger
        },
        model_path)
    print(f"Model saved to {model_path}")

    if args.eval_after_train is not None:
        args.eval_from = model_path
        arg_list = [
            x.strip().lstrip('--').split()
            for x in args.eval_after_train.split('\n')
        ]
        args.__dict__.update({x[0]: eval(x[1]) for x in arg_list})
        if args.debug:
            args.batch_size = 2
            args.num_epochs = 3

        linear_eval(args)
Example #8
def main(args):
    train_directory = '/share/contrastive_learning/data/sup_data/data_0124_10000/train_patch'
    train_loader = torch.utils.data.DataLoader(
        # dataset=get_dataset(
        #     transform=get_aug(train=False, train_classifier=True, **args.aug_kwargs),
        #     train=True,
        #     **args.dataset_kwargs
        # ),
        dataset=datasets.ImageFolder(root=train_directory,
                                     transform=get_aug(train=False,
                                                       train_classifier=True,
                                                       **args.aug_kwargs)),
        batch_size=args.eval.batch_size,
        shuffle=True,
        **args.dataloader_kwargs)
    test_directory = '/share/contrastive_learning/data/sup_data/data_0124_10000/val_patch'
    test_loader = torch.utils.data.DataLoader(
        # dataset=get_dataset(
        #     transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs),
        #     train=False,
        #     **args.dataset_kwargs
        # ),
        dataset=datasets.ImageFolder(root=test_directory,
                                     transform=get_aug(train=False,
                                                       train_classifier=False,
                                                       **args.aug_kwargs)),
        batch_size=args.eval.batch_size,
        shuffle=False,
        **args.dataloader_kwargs)

    model = get_backbone(args.model.backbone)
    classifier = nn.Linear(in_features=model.output_dim,
                           out_features=16,
                           bias=True).to(args.device)

    assert args.eval_from is not None
    save_dict = torch.load(args.eval_from, map_location='cpu')
    msg = model.load_state_dict(
        {
            k[9:]: v
            for k, v in save_dict['state_dict'].items()
            if k.startswith('backbone.')
        },
        strict=True)

    # print(msg)
    model = model.to(args.device)
    model = torch.nn.DataParallel(model)

    # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
    classifier = torch.nn.DataParallel(classifier)
    # define optimizer
    optimizer = get_optimizer(args.eval.optimizer.name,
                              classifier,
                              lr=args.eval.base_lr * args.eval.batch_size /
                              256,
                              momentum=args.eval.optimizer.momentum,
                              weight_decay=args.eval.optimizer.weight_decay)

    # define lr scheduler
    lr_scheduler = LR_Scheduler(
        optimizer,
        args.eval.warmup_epochs,
        args.eval.warmup_lr * args.eval.batch_size / 256,
        args.eval.num_epochs,
        args.eval.base_lr * args.eval.batch_size / 256,
        args.eval.final_lr * args.eval.batch_size / 256,
        len(train_loader),
    )

    loss_meter = AverageMeter(name='Loss')
    acc_meter = AverageMeter(name='Accuracy')

    # Start training
    global_progress = tqdm(range(0, args.eval.num_epochs), desc=f'Evaluating')
    for epoch in global_progress:
        loss_meter.reset()
        model.eval()
        classifier.train()
        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.eval.num_epochs}',
                              disable=False)

        for idx, (images, labels) in enumerate(local_progress):
            classifier.zero_grad()
            with torch.no_grad():
                feature = model(images.to(args.device))

            preds = classifier(feature)

            loss = F.cross_entropy(preds, labels.to(args.device))

            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()
            local_progress.set_postfix({
                'lr': lr,
                "loss": loss_meter.val,
                'loss_avg': loss_meter.avg
            })

    classifier.eval()
    correct, total = 0, 0
    acc_meter.reset()
    for idx, (images, labels) in enumerate(test_loader):
        with torch.no_grad():
            feature = model(images.to(args.device))
            preds = classifier(feature).argmax(dim=1)
            correct = (preds == labels.to(args.device)).sum().item()
            acc_meter.update(correct / preds.shape[0])
    print(f'Accuracy = {acc_meter.avg * 100:.2f}')
Example #9
def main(args):

    train_set = get_dataset(
        args.dataset,
        args.data_dir,
        transform=get_aug(args.model, args.image_size, True),
        train=True,
        download=args.download,  # default is False
        debug_subset_size=args.batch_size
        if args.debug else None  # run one batch if debug
    )

    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               drop_last=True)

    # define model
    model = get_model(args.model, args.backbone).to(args.device)
    backbone = model.backbone
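    # keep a handle on the bare backbone before any DDP wrapping so it can be handed
    # directly to linear_eval() at the end instead of being reloaded from a checkpoint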
    if args.model == 'simsiam' and args.proj_layers is not None:
        model.projector.set_layers(args.proj_layers)

    if args.local_rank >= 0:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # define optimizer
    optimizer = get_optimizer(args.optimizer,
                              model,
                              lr=args.base_lr * args.batch_size / 256,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    lr_scheduler = LR_Scheduler(optimizer, args.warmup_epochs,
                                args.warmup_lr * args.batch_size / 256,
                                args.num_epochs,
                                args.base_lr * args.batch_size / 256,
                                args.final_lr * args.batch_size / 256,
                                len(train_loader))

    loss_meter = AverageMeter(name='Loss')
    plot_logger = PlotLogger(params=['epoch', 'lr', 'loss'])
    os.makedirs(args.output_dir, exist_ok=True)
    # Start training
    global_progress = tqdm(range(0, args.stop_at_epoch), desc=f'Training')
    for epoch in global_progress:
        loss_meter.reset()
        model.train()

        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.num_epochs}',
                              disable=args.hide_progress)
        for idx, ((images1, images2), labels) in enumerate(local_progress):

            model.zero_grad()
            loss = model.forward(images1.to(args.device),
                                 images2.to(args.device))
            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()
            local_progress.set_postfix({'lr': lr, "loss": loss_meter.val})
            plot_logger.update({
                'epoch': epoch,
                'lr': lr,
                'loss': loss_meter.val
            })
        global_progress.set_postfix({
            "epoch": epoch,
            "loss_avg": loss_meter.avg
        })
        plot_logger.save(os.path.join(args.output_dir, 'logger.svg'))

    # Save checkpoint
    if args.local_rank <= 0:
        model_path = os.path.join(
            args.output_dir,
            f'{args.model}-{args.dataset}-epoch{args.stop_at_epoch}.pth')
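        # note: if DDP wrapping is active, model.state_dict() keys carry a 'module.'
        # prefix, which the 'backbone.'-prefix filtering used by the linear-eval
        # snippets would not match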
        torch.save(
            {
                'epoch': args.stop_at_epoch,
                'state_dict': model.state_dict(),
                # 'optimizer':optimizer.state_dict(), # will double the checkpoint file size
                'lr_scheduler': lr_scheduler,
                'args': args,
                'loss_meter': loss_meter,
                'plot_logger': plot_logger
            },
            model_path)
        print(f"Model saved to {model_path}")

    if args.eval_after_train is not None:
        arg_list = [
            x.strip().lstrip('--').split()
            for x in args.eval_after_train.split('\n')
        ]
        args.__dict__.update({x[0]: eval(x[1]) for x in arg_list})
        args.distributed_initialized = True
        if args.debug:
            args.batch_size = 2
            args.num_epochs = 3

        linear_eval(args, backbone)
Example #10
def main(gpu, args):
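    # global rank of this process: node index (args.nr) * GPUs per node (args.gpus)
    # plus the local GPU index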
    rank = args.nr * args.gpus + gpu
    dist.init_process_group("nccl", rank=rank, world_size=args.world_size)

    torch.manual_seed(0)
    torch.cuda.set_device(gpu)

    train_dataset = get_dataset(transform=get_aug(train=True,
                                                  **args.aug_kwargs),
                                train=True,
                                **args.dataset_kwargs)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=rank)
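    # shuffling is delegated to the DistributedSampler, so the DataLoader below is
    # created with shuffle=False (a sampler and shuffle=True are mutually exclusive)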

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        shuffle=False,
        batch_size=(args.train.batch_size // args.gpus),
        sampler=train_sampler,
        **args.dataloader_kwargs)

    memory_dataset = get_dataset(transform=get_aug(train=False,
                                                   train_classifier=False,
                                                   **args.aug_kwargs),
                                 train=True,
                                 **args.dataset_kwargs)

    memory_loader = torch.utils.data.DataLoader(
        dataset=memory_dataset,
        shuffle=False,
        batch_size=(args.train.batch_size // args.gpus),
        **args.dataloader_kwargs)

    test_dataset = get_dataset(transform=get_aug(train=False,
                                                train_classifier=False,
                                                **args.aug_kwargs),
                              train=False,
                              **args.dataset_kwargs)

    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        shuffle=False,
        batch_size=(args.train.batch_size // args.gpus),
        **args.dataloader_kwargs)
    print("Batch size:", (args.train.batch_size // args.gpus))
    # define model
    model = get_model(args.model).cuda(gpu)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DDP(model, device_ids=[gpu], find_unused_parameters=True)

    # define optimizer
    optimizer = get_optimizer(args.train.optimizer.name,
                              model,
                              lr=args.train.base_lr * args.train.batch_size /
                              256,
                              momentum=args.train.optimizer.momentum,
                              weight_decay=args.train.optimizer.weight_decay)

    lr_scheduler = LR_Scheduler(
        optimizer,
        args.train.warmup_epochs,
        args.train.warmup_lr * args.train.batch_size / 256,
        args.train.num_epochs,
        args.train.base_lr * args.train.batch_size / 256,
        args.train.final_lr * args.train.batch_size / 256,
        len(train_loader),
        constant_predictor_lr=True  # see the end of section 4.2 predictor
    )
    if gpu == 0:
        logger = Logger(tensorboard=args.logger.tensorboard,
                        matplotlib=args.logger.matplotlib,
                        log_dir=args.log_dir)
    accuracy = 0
    # Start training
    global_progress = tqdm(range(0, args.train.stop_at_epoch),
                           desc=f'Training')
    for epoch in global_progress:
        model.train()

        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.train.num_epochs}',
                              disable=args.hide_progress)
        for idx, ((images1, images2), labels) in enumerate(local_progress):

            model.zero_grad()
            data_dict = model.forward(images1.cuda(non_blocking=True),
                                      images2.cuda(non_blocking=True))
            loss = data_dict['loss']  # ddp
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            data_dict.update({'lr': lr_scheduler.get_lr()})
            local_progress.set_postfix(data_dict)
            if gpu == 0:
                logger.update_scalers(data_dict)

        if args.train.knn_monitor and epoch % args.train.knn_interval == 0 and gpu == 0:
            accuracy = knn_monitor(model.module.backbone,
                                   memory_loader,
                                   test_loader,
                                   gpu,
                                   k=min(args.train.knn_k,
                                         len(memory_loader.dataset)),
                                   hide_progress=args.hide_progress)

        epoch_dict = {"epoch": epoch, "accuracy": accuracy}
        global_progress.set_postfix(epoch_dict)

        if gpu == 0:
            logger.update_scalers(epoch_dict)

        # Save checkpoint
        if gpu == 0 and epoch % args.train.knn_interval == 0:
            model_path = os.path.join(
                args.ckpt_dir, f"{args.name}_{epoch+1}.pth"
            )  # datetime.now().strftime('%Y%m%d_%H%M%S')
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.module.state_dict()
                }, model_path)
            print(f"Model saved to {model_path}")
            with open(os.path.join(args.log_dir, f"checkpoint_path.txt"),
                      'w+') as f:
                f.write(f'{model_path}')

    # if args.eval is not False and gpu == 0:
    #     args.eval_from = model_path
    #     linear_eval(args)

    dist.destroy_process_group()
Example #11
def main(device, args):

    loss1_func = nn.CrossEntropyLoss()
    loss2_func = softmax_kl_loss

    dataset_kwargs = {
        'dataset': args.dataset,
        'data_dir': args.data_dir,
        'download': args.download,
        'debug_subset_size': args.batch_size if args.debug else None
    }
    dataloader_kwargs = {
        'batch_size': args.batch_size,
        'drop_last': True,
        'pin_memory': True,
        'num_workers': args.num_workers,
    }
    dataloader_unlabeled_kwargs = {
        'batch_size': args.batch_size * 5,
        'drop_last': True,
        'pin_memory': True,
        'num_workers': args.num_workers,
    }
    dataset_train = get_dataset(transform=get_aug(args.dataset, True),
                                train=True,
                                **dataset_kwargs)

    if args.iid == 'iid':
        dict_users_labeled, dict_users_unlabeled_server, dict_users_unlabeled = iid(
            dataset_train, args.num_users, args.label_rate)
    else:
        dict_users_labeled, dict_users_unlabeled_server, dict_users_unlabeled = noniid(
            dataset_train, args.num_users, args.label_rate)
    train_loader_unlabeled = {}

    # define model
    model_glob = get_model('global', args.backbone).to(device)
    if torch.cuda.device_count() > 1:
        model_glob = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_glob)

    model_local_idx = set()
    model_local_dict = {}
    accuracy = []
    lr_scheduler = {}

    for iter in range(args.num_epochs):

        model_glob.train()
        optimizer = torch.optim.SGD(model_glob.parameters(),
                                    lr=0.01,
                                    momentum=0.5)

        train_loader_labeled = torch.utils.data.DataLoader(
            dataset=DatasetSplit(dataset_train, dict_users_labeled),
            shuffle=True,
            **dataloader_kwargs)
        train_loader_unlabeled = torch.utils.data.DataLoader(
            dataset=DatasetSplit(dataset_train, dict_users_unlabeled_server),
            shuffle=True,
            **dataloader_unlabeled_kwargs)
        train_loader = zip(train_loader_labeled, train_loader_unlabeled)
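        # zip stops at the shorter of the two loaders, so one server epoch runs for
        # min(len(labeled loader), len(unlabeled loader)) batches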

        for batch_idx, (data_x, data_u) in enumerate(train_loader):
            (images1_l, images2_l), labels = data_x
            (images1_u, images2_u), _ = data_u

            model_glob.zero_grad()
            labels = labels.cuda()

            batch_size = images1_l.shape[0]
            images1 = torch.cat((images1_l, images1_u)).to(args.device)
            images2 = torch.cat((images2_l, images2_u)).to(args.device)

            z1_t, z2_t, z1_s, z2_s = model_glob.forward(
                images1.to(device, non_blocking=True),
                images2.to(device, non_blocking=True))
            model_glob.update_moving_average(iter * 1000 + batch_idx, 20000)

            loss_class = 1 / 2 * loss1_func(z1_t[:batch_size],
                                            labels) + 1 / 2 * loss1_func(
                                                z2_t[:batch_size], labels)

            loss_consist = 1 / 2 * loss2_func(z1_t, z2_s) / len(
                labels) + 1 / 2 * loss2_func(z2_t, z1_s) / len(labels)
            consistency_weight = get_current_consistency_weight(batch_idx)
            loss = loss_class  # + consistency_weight * loss_consist

            loss.backward()
            optimizer.step()

        del train_loader_labeled
        gc.collect()
        torch.cuda.empty_cache()

        if iter % 1 == 0:
            test_loader = torch.utils.data.DataLoader(dataset=get_dataset(
                transform=get_aug(args.dataset, False, train_classifier=False),
                train=False,
                **dataset_kwargs),
                                                      shuffle=False,
                                                      **dataloader_kwargs)
            model_glob.eval()
            acc, loss_train_test_labeled = test_img(model_glob, test_loader,
                                                    args)
            accuracy.append(str(acc))
            del test_loader
            gc.collect()
            torch.cuda.empty_cache()

        w_locals, loss_locals, loss0_locals, loss2_locals = [], [], [], []

        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
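        # sample a fraction args.frac of the clients (at least one) uniformly without
        # replacement for this round's local training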

        for idx in idxs_users:

            loss_local = []
            loss0_local = []
            loss2_local = []

            if idx in model_local_idx:
                model_local = get_model('local', args.backbone).to(device)
                model_local.projector.load_state_dict(model_local_dict[idx][0])
                model_local.target_encoder.load_state_dict(
                    model_local_dict[idx][1])
                #                 model_local.projector.load_state_dict(torch.load('/model/'+'model1' + str(args.dataset) + str(idx)+ '.pkl'))
                #                 model_local.target_encoder.load_state_dict(torch.load('/model/'+'model1' + str(args.dataset) + 'tar'+ str(idx)+ '.pkl'))

                model_local.backbone.load_state_dict(
                    model_glob.backbone.state_dict())
            else:
                model_local = get_model('local', args.backbone).to(device)
                model_local.backbone.load_state_dict(
                    model_glob.backbone.state_dict())
                model_local.target_encoder.load_state_dict(
                    model_local.online_encoder.state_dict())
                model_local_idx = model_local_idx | set([idx])

            train_loader_unlabeled = torch.utils.data.DataLoader(
                dataset=DatasetSplit(dataset_train, dict_users_unlabeled[idx]),
                shuffle=True,
                **dataloader_unlabeled_kwargs)

            # define optimizer
            optimizer = get_optimizer(args.optimizer,
                                      model_local,
                                      lr=args.base_lr * args.batch_size / 256,
                                      momentum=args.momentum,
                                      weight_decay=args.weight_decay)

            lr_scheduler = LR_Scheduler(
                optimizer,
                args.warmup_epochs,
                args.warmup_lr * args.batch_size / 256,
                args.num_epochs,
                args.base_lr * args.batch_size / 256,
                args.final_lr * args.batch_size / 256,
                len(train_loader_unlabeled),
                constant_predictor_lr=
                True  # see the end of section 4.2 predictor
            )

            model_local.train()

            for j in range(args.local_ep):

                for i, ((images1, images2),
                        labels) in enumerate(train_loader_unlabeled):

                    model_local.zero_grad()

                    batch_size = images1.shape[0]

                    loss = model_local.forward(
                        images1.to(device, non_blocking=True),
                        images2.to(device, non_blocking=True))

                    loss.backward()
                    optimizer.step()

                    loss_local.append(int(loss))

                    lr = lr_scheduler.step()

                    model_local.update_moving_average()

            # after this client's local epochs, record its weights and mean loss once
            w_locals.append(copy.deepcopy(model_local.backbone.state_dict()))
            loss_locals.append(sum(loss_local) / len(loss_local))
            model_local_dict[idx] = [
                model_local.projector.state_dict(),
                model_local.target_encoder.state_dict()
            ]
    #             torch.save(model_local.projector.state_dict(), '/model/'+'model1' + str(args.dataset) + str(idx)+ '.pkl')
    #             torch.save(model_local.target_encoder.state_dict(), '/model/'+'model1' + str(args.dataset)+ 'tar' + str(idx)+ '.pkl')

            del model_local
            gc.collect()
            del train_loader_unlabeled
            gc.collect()
            torch.cuda.empty_cache()

        w_glob = FedAvg(w_locals)
        model_glob.backbone.load_state_dict(w_glob)

        loss_avg = sum(loss_locals) / len(loss_locals)

        if iter % 1 == 0:
            print('Round {:3d}, Acc {:.2f}%'.format(iter, acc))