Example #1
def build_generator(image_size, name):
    # Wrap an ImageNet-pretrained backbone (chosen by `name`) with global
    # average pooling and a 4096-unit ReLU projection head.
    return tf.keras.Sequential([
        get_backbone(name)(
            input_shape=(image_size, image_size, 3),
            include_top=False,
            weights='imagenet'
        ),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(4096, activation='relu')
    ])
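
A minimal usage sketch for the builder above, assuming `get_backbone` maps a name such as 'ResNet50' to the matching `tf.keras.applications` constructor (that mapping is hypothetical and not part of the snippet):

import tensorflow as tf

def get_backbone(name):
    # Hypothetical lookup; the real mapping is defined elsewhere in the project.
    return {'ResNet50': tf.keras.applications.ResNet50}[name]

generator = build_generator(image_size=224, name='ResNet50')
# ImageNet weights are downloaded on first use.
features = generator(tf.random.uniform((1, 224, 224, 3)))  # shape (1, 4096)
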
Example #2
    def __init__(self, cfg: dict):
        super().__init__()
        self.backbone = models.get_backbone(cfg)
        self.fpn = models.get_fpn(cfg)
        self.agg = models.get_agg(cfg)
        self.rpn = models.get_rpn(cfg)
        self.det_layers = torch.nn.ModuleList()

        det_layer = models.get_det_layer(cfg)
        # One detection layer per FPN output level.
        for level_i in range(len(cfg['model.fpn.out_channels'])):
            self.det_layers.append(det_layer(level_i=level_i, cfg=cfg))

        self.check_gt_assignment = cfg.get('train.check_gt_assignment', False)
        self.bb_format = cfg.get('general.pred_bbox_format')
        self.input_format = cfg['general.input_format']

        self.hid_names = cfg['model.agg.hidden_state_names']
        assert all(s in self.valid_hidden_states for s in self.hid_names)
        self.hid_names = set(self.hid_names)
        self.hidden = None
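
The constructor reads a flat dictionary keyed by dotted paths rather than a nested config. A hypothetical minimal `cfg` covering the keys accessed above (all values are illustrative):

cfg = {
    'model.fpn.out_channels': [256, 256, 256],   # one detection layer per FPN level
    'train.check_gt_assignment': False,          # optional; read via cfg.get
    'general.pred_bbox_format': 'cxcywh',        # optional; read via cfg.get
    'general.input_format': 'RGB_1',
    'model.agg.hidden_state_names': ['h', 'c'],  # must be a subset of valid_hidden_states
}
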
Example #3
def main(args):
    train_info = []
    best_epoch = np.zeros(5)
    for val_folder_index in range(5):
        best_balance_acc = 0
        whole_train_list = ['D8E6', '117E', '676F', 'E2D7', 'BE52']
        val_WSI_list = whole_train_list[val_folder_index]
        # Copy before popping so the original list is not mutated via aliasing.
        train_WSI_list = whole_train_list.copy()
        train_WSI_list.pop(val_folder_index)
        train_directory = '../data/finetune/1percent/'
        # Trailing slash added so the WSI id is not glued onto '1percent'.
        valid_directory = '../data/finetune/1percent/'
        dataset = {}
        # Build one ImageFolder per training WSI, then concatenate them.
        train_transform = get_aug(train=False,
                                  train_classifier=True,
                                  **args.aug_kwargs)
        dataset['train'] = data.ConcatDataset([
            datasets.ImageFolder(root=train_directory + wsi,
                                 transform=train_transform)
            for wsi in train_WSI_list
        ])
        dataset['valid'] = datasets.ImageFolder(
            root=valid_directory + val_WSI_list,
            transform=get_aug(train=False,
                              train_classifier=False,
                              **args.aug_kwargs))

        train_loader = torch.utils.data.DataLoader(
            dataset=dataset['train'],
            batch_size=args.eval.batch_size,
            shuffle=True,
            **args.dataloader_kwargs)
        test_loader = torch.utils.data.DataLoader(
            dataset=dataset['valid'],
            batch_size=args.eval.batch_size,
            shuffle=False,
            **args.dataloader_kwargs)

        model = get_backbone(args.model.backbone)
        classifier = nn.Linear(in_features=model.output_dim,
                               out_features=9,
                               bias=True).to(args.device)

        assert args.eval_from is not None
        save_dict = torch.load(args.eval_from, map_location='cpu')
        msg = model.load_state_dict(
            {
                k[9:]: v
                for k, v in save_dict['state_dict'].items()
                if k.startswith('backbone.')
            },
            strict=True)

        # print(msg)
        model = model.to(args.device)
        model = torch.nn.DataParallel(model)

        classifier = torch.nn.DataParallel(classifier)
        # define optimizer
        optimizer = get_optimizer(
            args.eval.optimizer.name,
            classifier,
            lr=args.eval.base_lr * args.eval.batch_size / 256,
            momentum=args.eval.optimizer.momentum,
            weight_decay=args.eval.optimizer.weight_decay)

        # define lr scheduler
        lr_scheduler = LR_Scheduler(
            optimizer,
            args.eval.warmup_epochs,
            args.eval.warmup_lr * args.eval.batch_size / 256,
            args.eval.num_epochs,
            args.eval.base_lr * args.eval.batch_size / 256,
            args.eval.final_lr * args.eval.batch_size / 256,
            len(train_loader),
        )

        loss_meter = AverageMeter(name='Loss')
        acc_meter = AverageMeter(name='Accuracy')

        # Start training
        global_progress = tqdm(range(0, args.eval.num_epochs),
                               desc=f'Evaluating')
        for epoch in global_progress:
            loss_meter.reset()
            model.eval()
            classifier.train()
            local_progress = tqdm(train_loader,
                                  desc=f'Epoch {epoch}/{args.eval.num_epochs}',
                                  disable=True)

            for idx, (images, labels) in enumerate(local_progress):
                classifier.zero_grad()
                with torch.no_grad():
                    feature = model(images.to(args.device))

                preds = classifier(feature)

                loss = F.cross_entropy(preds, labels.to(args.device))

                loss.backward()
                optimizer.step()
                loss_meter.update(loss.item())
                lr = lr_scheduler.step()
                local_progress.set_postfix({
                    'lr': lr,
                    "loss": loss_meter.val,
                    'loss_avg': loss_meter.avg
                })

            # `writer` is assumed to be a SummaryWriter created at module level.
            writer.add_scalar('Valid/Loss', loss_meter.avg, epoch)
            writer.add_scalar('Valid/Lr', lr, epoch)
            writer.flush()

            PATH = (f'checkpoint/exp_0228_triple_1percent/{val_WSI_list}/'
                    f'{val_WSI_list}_tunelinear_{epoch}.pth')

            torch.save(classifier, PATH)

            classifier.eval()
            correct, total = 0, 0
            acc_meter.reset()

            pred_label_for_f1 = np.array([])
            true_label_for_f1 = np.array([])
            for idx, (images, labels) in enumerate(test_loader):
                with torch.no_grad():
                    feature = model(images.to(args.device))
                    preds = classifier(feature).argmax(dim=1)
                    correct = (preds == labels.to(args.device)).sum().item()

                    preds_arr = preds.cpu().detach().numpy()
                    labels_arr = labels.cpu().detach().numpy()
                    pred_label_for_f1 = np.concatenate(
                        [pred_label_for_f1, preds_arr])
                    true_label_for_f1 = np.concatenate(
                        [true_label_for_f1, labels_arr])
                    acc_meter.update(correct / preds.shape[0])

            f1 = f1_score(true_label_for_f1,
                          pred_label_for_f1,
                          average='macro')
            balance_acc = balanced_accuracy_score(true_label_for_f1,
                                                  pred_label_for_f1)
            print('Epoch:  ', str(epoch),
                  f'Accuracy = {acc_meter.avg * 100:.2f}')
            print('F1 score =  ', f1, 'balance acc:  ', balance_acc)
            if balance_acc > best_balance_acc:
                best_epoch[val_folder_index] = epoch
                best_balance_acc = balance_acc
            train_info.append([val_WSI_list, epoch, f1, balance_acc])

    with open('checkpoint/exp_0228_triple_1percent/train_info.csv', 'w',
              newline='') as f:
        # Write one row per (fold, epoch) with csv.writer.
        write = csv.writer(f)
        write.writerows(train_info)
    print(best_epoch)
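
Several of these examples restore the pretrained backbone by slicing checkpoint keys with `k[9:]`, which silently relies on `'backbone.'` being exactly nine characters long. A less fragile spelling of the same remapping (a sketch, not the project's own helper):

def strip_prefix(state_dict, prefix='backbone.'):
    # Keep only the keys under `prefix` and drop the prefix itself,
    # so the weights load into a bare backbone module.
    return {k[len(prefix):]: v
            for k, v in state_dict.items()
            if k.startswith(prefix)}

# msg = model.load_state_dict(strip_prefix(save_dict['state_dict']), strict=True)
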
Example #4
def main(args):

    train_set = get_dataset(
        args.dataset, 
        args.data_dir, 
        transform=get_aug(args.model, args.image_size, train=False, train_classifier=True), 
        train=True, 
        download=args.download, # default is False
        debug_subset_size=args.batch_size if args.debug else None
    )
    test_set = get_dataset(
        args.dataset, 
        args.data_dir, 
        transform=get_aug(args.model, args.image_size, train=False, train_classifier=False), 
        train=False, 
        download=args.download, # default is False
        debug_subset_size=args.batch_size if args.debug else None
    )


    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    model = get_backbone(args.backbone)
    classifier = nn.Linear(in_features=model.output_dim, out_features=10, bias=True).to(args.device)

    assert args.eval_from is not None
    save_dict = torch.load(args.eval_from, map_location='cpu')
    msg = model.load_state_dict({k[9:]:v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True)
    
    # print(msg)
    model = model.to(args.device)
    model = torch.nn.DataParallel(model)

    # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
    classifier = torch.nn.DataParallel(classifier)
    # define optimizer
    optimizer = get_optimizer(
        args.optimizer, classifier, 
        lr=args.base_lr*args.batch_size/256, 
        momentum=args.momentum, 
        weight_decay=args.weight_decay)

    # define lr scheduler
    lr_scheduler = LR_Scheduler(
        optimizer,
        args.warmup_epochs, args.warmup_lr*args.batch_size/256, 
        args.num_epochs, args.base_lr*args.batch_size/256, args.final_lr*args.batch_size/256, 
        len(train_loader),
    )

    loss_meter = AverageMeter(name='Loss')
    acc_meter = AverageMeter(name='Accuracy')

    # Start training
    global_progress = tqdm(range(0, args.num_epochs), desc=f'Evaluating')
    for epoch in global_progress:
        loss_meter.reset()
        model.eval()
        classifier.train()
        local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.num_epochs}', disable=args.hide_progress)
        
        for idx, (images, labels) in enumerate(local_progress):

            classifier.zero_grad()
            with torch.no_grad():
                feature = model(images.to(args.device))

            preds = classifier(feature)

            loss = F.cross_entropy(preds, labels.to(args.device))

            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()
            local_progress.set_postfix({'lr':lr, "loss":loss_meter.val, 'loss_avg':loss_meter.avg})
        

        if args.head_tail_accuracy and epoch != 0 and (epoch+1) != args.num_epochs: continue

        local_progress=tqdm(test_loader, desc=f'Test {epoch}/{args.num_epochs}', disable=args.hide_progress)
        classifier.eval()
        correct, total = 0, 0
        acc_meter.reset()
        for idx, (images, labels) in enumerate(local_progress):
            with torch.no_grad():
                feature = model(images.to(args.device))
                preds = classifier(feature).argmax(dim=1)
                correct = (preds == labels.to(args.device)).sum().item()
                acc_meter.update(correct/preds.shape[0])
                local_progress.set_postfix({'accuracy': acc_meter.avg})
        
        global_progress.set_postfix({"epoch":epoch, 'accuracy':acc_meter.avg*100})
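
All of these evaluation scripts scale their learning rates with the linear scaling rule, lr = base_lr * batch_size / 256, so a config tuned at the reference batch size of 256 transfers to other batch sizes. A worked example (values are illustrative):

base_lr, batch_size = 30.0, 64
lr = base_lr * batch_size / 256  # 30.0 * 64 / 256 = 7.5
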
Example #5
def main(args, model=None):
    assert args.eval_from is not None or model is not None
    train_set = get_dataset(
        args.dataset,
        args.data_dir,
        transform=get_aug(args.model,
                          args.image_size,
                          train=False,
                          train_classifier=True),
        train=True,
        download=args.download,  # default is False
        debug_subset_size=args.batch_size if args.debug else None  # use a subset of the dataset for debugging
    )
    test_set = get_dataset(
        args.dataset,
        args.data_dir,
        transform=get_aug(args.model,
                          args.image_size,
                          train=False,
                          train_classifier=False),
        train=False,
        download=args.download,  # default is False
        debug_subset_size=args.batch_size if args.debug else None)

    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers,
                                              pin_memory=True,
                                              drop_last=True)

    # The backbone and classifier are constructed below, after the optional
    # checkpoint load, so a `model` passed in as an argument is respected.

    if args.local_rank >= 0 and not torch.distributed.is_initialized():
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")

    if model is None:
        model = get_backbone(args.backbone).to(args.device)
        save_dict = torch.load(args.eval_from, map_location=args.device)
        model.load_state_dict(
            {
                k[9:]: v
                for k, v in save_dict['state_dict'].items()
                if k.startswith('backbone.')
            },
            strict=True)

    output_dim = model.output_dim
    if args.local_rank >= 0:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    classifier = nn.Linear(in_features=output_dim, out_features=10,
                           bias=True).to(args.device)
    if args.local_rank >= 0:
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier,
            device_ids=[args.local_rank],
            output_device=args.local_rank)

    # define optimizer
    optimizer = get_optimizer(args.optimizer,
                              classifier,
                              lr=args.base_lr * args.batch_size / 256,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    # TODO: linear lr warm up for byol simclr swav
    # args.warm_up_epochs
    # define lr scheduler
    lr_scheduler = LR_Scheduler(optimizer, args.warmup_epochs,
                                args.warmup_lr * args.batch_size / 256,
                                args.num_epochs,
                                args.base_lr * args.batch_size / 256,
                                args.final_lr * args.batch_size / 256,
                                len(train_loader))

    loss_meter = AverageMeter(name='Loss')
    acc_meter = AverageMeter(name='Accuracy')

    # Start training
    global_progress = tqdm(range(0, args.num_epochs), desc=f'Evaluating')
    for epoch in global_progress:
        loss_meter.reset()
        model.eval()
        classifier.train()
        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.num_epochs}',
                              disable=args.hide_progress)

        for idx, (images, labels) in enumerate(local_progress):

            classifier.zero_grad()
            with torch.no_grad():
                feature = model(images.to(args.device))

            preds = classifier(feature)

            loss = F.cross_entropy(preds, labels.to(args.device))

            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()
            local_progress.set_postfix({
                'lr': lr,
                "loss": loss_meter.val,
                'loss_avg': loss_meter.avg
            })

        if args.head_tail_accuracy and epoch != 0 and (epoch + 1) != args.num_epochs:
            continue

        local_progress = tqdm(test_loader,
                              desc=f'Test {epoch}/{args.num_epochs}',
                              disable=args.hide_progress)
        classifier.eval()
        correct, total = 0, 0
        acc_meter.reset()
        for idx, (images, labels) in enumerate(local_progress):
            with torch.no_grad():
                feature = model(images.to(args.device))
                preds = classifier(feature).argmax(dim=1)
                correct = (preds == labels.to(args.device)).sum().item()
                acc_meter.update(correct / preds.shape[0])
                local_progress.set_postfix({'accuracy': acc_meter.avg})

        global_progress.set_postfix({
            "epoch": epoch,
            'accuracy': acc_meter.avg * 100
        })
Example #6
def main(args):

    train_loader = torch.utils.data.DataLoader(
        dataset=get_dataset(transform=get_aug(train=False,
                                              train_classifier=True,
                                              **args.aug_kwargs),
                            train=True,
                            **args.dataset_kwargs),
        batch_size=args.eval.batch_size,
        shuffle=True,
        **args.dataloader_kwargs)
    test_loader = torch.utils.data.DataLoader(
        dataset=get_dataset(transform=get_aug(train=False,
                                              train_classifier=False,
                                              **args.aug_kwargs),
                            train=False,
                            **args.dataset_kwargs),
        batch_size=args.eval.batch_size,
        shuffle=False,
        **args.dataloader_kwargs)

    model = get_backbone(args.model.backbone)
    classifier = nn.Linear(in_features=model.output_dim,
                           out_features=10,
                           bias=True).to(args.device)

    assert args.eval_from is not None
    save_dict = torch.load(args.eval_from, map_location='cpu')
    msg = model.load_state_dict(
        {
            k[9:]: v
            for k, v in save_dict['state_dict'].items()
            if k.startswith('backbone.')
        },
        strict=True)

    # print(msg)
    model = model.to(args.device)
    model = torch.nn.DataParallel(model)

    # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
    classifier = torch.nn.DataParallel(classifier)
    # define optimizer
    optimizer = get_optimizer(args.eval.optimizer.name,
                              classifier,
                              lr=args.eval.base_lr * args.eval.batch_size / 256,
                              momentum=args.eval.optimizer.momentum,
                              weight_decay=args.eval.optimizer.weight_decay)

    # define lr scheduler
    lr_scheduler = LR_Scheduler(
        optimizer,
        args.eval.warmup_epochs,
        args.eval.warmup_lr * args.eval.batch_size / 256,
        args.eval.num_epochs,
        args.eval.base_lr * args.eval.batch_size / 256,
        args.eval.final_lr * args.eval.batch_size / 256,
        len(train_loader),
    )

    loss_meter = AverageMeter(name='Loss')
    acc_meter = AverageMeter(name='Accuracy')

    # Start training
    global_progress = tqdm(range(0, args.eval.num_epochs), desc=f'Evaluating')
    for epoch in global_progress:
        loss_meter.reset()
        model.eval()
        classifier.train()
        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.eval.num_epochs}',
                              disable=True)

        for idx, (images, labels) in enumerate(local_progress):
            # Multi-view transforms may return a list of [C x H x W] tensors;
            # unsqueeze each to [1 x C x H x W] and concatenate along the
            # batch dimension into a single [N x C x H x W] tensor.
            if isinstance(images, list):
                images = torch.cat(
                    [image.unsqueeze(dim=0) for image in images], dim=0)

            classifier.zero_grad()
            with torch.no_grad():
                feature = model(images.to(args.device))

            preds = classifier(feature)

            loss = F.cross_entropy(preds, labels.to(args.device))

            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()
            local_progress.set_postfix({
                'lr': lr,
                "loss": loss_meter.val,
                'loss_avg': loss_meter.avg
            })

    classifier.eval()
    correct, total = 0, 0
    acc_meter.reset()
    for idx, (images, labels) in enumerate(test_loader):
        with torch.no_grad():
            feature = model(images.to(args.device))
            preds = classifier(feature).argmax(dim=1)
            correct = (preds == labels.to(args.device)).sum().item()
            acc_meter.update(correct / preds.shape[0])
    print(f'Accuracy = {acc_meter.avg*100:.2f}')
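
The list-handling branch above builds the [N x C x H x W] batch by unsqueezing each view and concatenating; `torch.stack` produces the identical result in one call:

import torch

images = [torch.randn(3, 32, 32) for _ in range(4)]
a = torch.cat([img.unsqueeze(dim=0) for img in images], dim=0)
b = torch.stack(images, dim=0)  # same result, shape (4, 3, 32, 32)
assert torch.equal(a, b)
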
Example #7
def main(args):

    train_set = get_dataset(
        args.dataset,
        args.data_dir,
        transform=get_aug(args.model,
                          args.image_size,
                          train=False,
                          train_classifier=True),
        train=True,
        download=args.download  # default is False
    )
    test_set = get_dataset(
        args.dataset,
        args.data_dir,
        transform=get_aug(args.model,
                          args.image_size,
                          train=False,
                          train_classifier=False),
        train=False,
        download=args.download  # default is False
    )

    if args.debug:
        args.batch_size = 20
        args.num_epochs = 2
        args.num_workers = 0
        train_set = torch.utils.data.Subset(train_set, range(
            0, args.batch_size))  # take only one batch
        test_set = torch.utils.data.Subset(test_set, range(0, args.batch_size))

    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers,
                                              pin_memory=True,
                                              drop_last=True)

    # define model
    # model = get_model(args.model, args.backbone)
    backbone = get_backbone(args.backbone, castrate=False)
    in_features = backbone.fc.in_features
    backbone.fc = nn.Identity()
    model = backbone
    assert args.eval_from is not None
    save_dict = torch.load(args.eval_from, map_location='cpu')
    msg = model.load_state_dict(
        {
            k[9:]: v
            for k, v in save_dict['state_dict'].items()
            if k.startswith('backbone.')
        },
        strict=True)
    print(msg)
    model = model.to(args.device)
    model = torch.nn.DataParallel(model)
    # if torch.cuda.device_count() > 1: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    classifier = nn.Linear(in_features=in_features, out_features=10,
                           bias=True).to(args.device)
    classifier = torch.nn.DataParallel(classifier)

    # define optimizer
    optimizer = get_optimizer(args.optimizer,
                              classifier,
                              lr=args.base_lr * args.batch_size / 256,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    # TODO: linear lr warm up for byol simclr swav
    # args.warm_up_epochs

    # define lr scheduler
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                              args.num_epochs,
                                                              eta_min=0)

    loss_meter = AverageMeter(name='Loss')
    acc_meter = AverageMeter(name='Accuracy')
    # Start training
    for epoch in tqdm(range(0, args.num_epochs), desc=f'Evaluating'):
        loss_meter.reset()
        model.eval()
        classifier.train()
        p_bar = tqdm(train_loader, desc=f'Epoch {epoch}/{args.num_epochs}')

        for idx, (images, labels) in enumerate(p_bar):
            classifier.zero_grad()
            with torch.no_grad():
                feature = model(images.to(args.device))
            preds = classifier(feature)

            loss = F.cross_entropy(preds, labels.to(args.device))
            # loss = model.forward(images1.to(args.device), images2.to(args.device))
            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            p_bar.set_postfix({
                "loss": loss_meter.val,
                'loss_avg': loss_meter.avg
            })

        lr_scheduler.step()

        p_bar = tqdm(test_loader, desc=f'Test {epoch}/{args.num_epochs}')
        classifier.eval()
        correct, total = 0, 0
        acc_meter.reset()
        for idx, (images, labels) in enumerate(p_bar):
            with torch.no_grad():
                feature = model(images.to(args.device))
                preds = classifier(feature).argmax(dim=1)
                correct = (preds == labels.to(args.device)).sum().item()
                acc_meter.update(correct / preds.shape[0])
                p_bar.set_postfix({'accuracy': acc_meter.avg})
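
Unlike the warmup-based LR_Scheduler in the other examples, which is stepped every iteration, the CosineAnnealingLR used here is stepped once per epoch; it follows eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2. A self-contained sketch of the decay:

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=0.1)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10, eta_min=0)
for epoch in range(10):
    # ... train one epoch ...
    sched.step()
    print(epoch, sched.get_last_lr())  # decays from 0.1 toward 0 along a cosine
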
Example #8
def main(device, args):
    train_directory = '../data/train'
    image_name_file = '../data/original.csv'
    val_directory = '../data/train'
    train_loader = torch.utils.data.DataLoader(
        dataset=get_dataset('random', train_directory, image_name_file,
            transform=get_aug(train=True, **args.aug_kwargs),
            train=True,
            **args.dataset_kwargs),
        # dataset=datasets.ImageFolder(root=train_directory, transform=get_aug(train=True, **args.aug_kwargs)),
        shuffle=True,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs
    )

    memory_loader = torch.utils.data.DataLoader(
        dataset=datasets.ImageFolder(root=val_directory, transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs)),
        shuffle=False,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=datasets.ImageFolder(root=val_directory, transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs)),
        shuffle=False,
        batch_size=args.train.batch_size,
        **args.dataloader_kwargs
    )

    # define model
    model = get_model(args.model).to(device)
    model = torch.nn.DataParallel(model)
    scaler = torch.cuda.amp.GradScaler()

    # define optimizer
    optimizer = get_optimizer(
        args.train.optimizer.name, model,
        lr=args.train.base_lr * args.train.batch_size / 256,
        momentum=args.train.optimizer.momentum,
        weight_decay=args.train.optimizer.weight_decay)

    lr_scheduler = LR_Scheduler(
        optimizer,
        args.train.warmup_epochs, args.train.warmup_lr * args.train.batch_size / 256,
        args.train.num_epochs, args.train.base_lr * args.train.batch_size / 256,
        args.train.final_lr * args.train.batch_size / 256,
        len(train_loader),
        constant_predictor_lr=True  # see the end of section 4.2 predictor
    )

    RESUME = False
    start_epoch = 0

    if RESUME:
        model = get_backbone(args.model.backbone)
        classifier = nn.Linear(in_features=model.output_dim, out_features=9, bias=True).to(args.device)

        assert args.eval_from is not None
        save_dict = torch.load(args.eval_from, map_location='cpu')
        msg = model.load_state_dict({k[9:]: v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')},
                                    strict=True)

        path_checkpoint = "./checkpoint/simsiam-TCGA-0218-nearby_0221134812.pth"  # checkpoint path
        checkpoint = torch.load(path_checkpoint)  # load the checkpoint

        model.load_state_dict(checkpoint['net'])  # restore the model's learnable parameters

        optimizer.load_state_dict(checkpoint['optimizer'])  # restore the optimizer state
        start_epoch = checkpoint['epoch']  # resume from the saved epoch

    logger = Logger(tensorboard=args.logger.tensorboard, matplotlib=args.logger.matplotlib, log_dir=args.log_dir)
    accuracy = 0
    # Start training
    global_progress = tqdm(range(start_epoch, args.train.stop_at_epoch), desc=f'Training')
    for epoch in global_progress:
        model.train()

        local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.train.num_epochs}', disable=args.hide_progress)
        for idx, (images1, images2, images3, labels) in enumerate(local_progress):
            model.zero_grad()
            with torch.cuda.amp.autocast():
                data_dict = model.forward(images1.to(device, non_blocking=True), images2.to(device, non_blocking=True),
                                          images3.to(device, non_blocking=True))
                loss = data_dict['loss'].mean()  # ddp
            # loss.backward()
            scaler.scale(loss).backward()
            # optimizer.step()
            scaler.step(optimizer)
            scaler.update()

            lr_scheduler.step()
            data_dict.update({'lr': lr_scheduler.get_lr()})

            local_progress.set_postfix(data_dict)
            logger.update_scalers(data_dict)

        if args.train.knn_monitor and epoch % args.train.knn_interval == 0:
            accuracy = knn_monitor(model.module.backbone, memory_loader, test_loader, device,
                                   k=min(args.train.knn_k, len(memory_loader.dataset)),
                                   hide_progress=args.hide_progress)

        epoch_dict = {"epoch": epoch, "accuracy": accuracy}
        global_progress.set_postfix(epoch_dict)
        logger.update_scalers(epoch_dict)

        # Resume-style checkpoint (matches the RESUME branch above); note it is
        # assembled here but only the {'epoch', 'state_dict'} form is written below.
        checkpoint = {
            "net": model.state_dict(),
            'optimizer': optimizer.state_dict(),
            "epoch": epoch
        }
        if (epoch % args.train.save_interval) == 0:
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.module.state_dict()
            }, './checkpoint/exp_0223_triple_400_proj3/ckpt_best_%s.pth' % (str(epoch)))

    # Save checkpoint
    model_path = os.path.join(args.ckpt_dir,
                              f"{args.name}_{datetime.now().strftime('%m%d%H%M%S')}.pth")  # datetime.now().strftime('%Y%m%d_%H%M%S')
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.module.state_dict()
    }, model_path)
    print(f"Model saved to {model_path}")
    with open(os.path.join(args.log_dir, f"checkpoint_path.txt"), 'w+') as f:
        f.write(f'{model_path}')


    if args.eval is not False:
        args.eval_from = model_path
        linear_eval(args)
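
The checkpoint written here is exactly the format the linear-evaluation examples above consume: the weights sit under 'state_dict', with backbone parameters prefixed 'backbone.'. The round trip, sketched:

# Save (as in this example):
torch.save({'epoch': epoch + 1,
            'state_dict': model.module.state_dict()}, model_path)

# Load (as in the eval examples): keep only the backbone weights.
save_dict = torch.load(model_path, map_location='cpu')
backbone_sd = {k[len('backbone.'):]: v
               for k, v in save_dict['state_dict'].items()
               if k.startswith('backbone.')}
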
Example #9
def main(args):
    # test_directory = '/share/contrastive_learning/data/sup_data/data_0122/val_patch'
    # test_directory = '/share/contrastive_learning/data/crop_after_process_doctor/crop_test_screened-20210207T180715Z-001/crop_test_screened'
    test_directory = '/share/contrastive_learning/data/crop_after_process_doctor/crop_train_for_exp/crop_test_screened'
    test_loader = torch.utils.data.DataLoader(
        # dataset=get_dataset(
        #     transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs),
        #     train=False,
        #     **args.dataset_kwargs
        # ),
        dataset=datasets.ImageFolder(root=test_directory,
                                     transform=get_aug(train=False,
                                                       train_classifier=False,
                                                       **args.aug_kwargs)),
        batch_size=args.eval.batch_size,
        shuffle=False,
        **args.dataloader_kwargs)

    model = get_backbone(args.model.backbone)
    # classifier = nn.Linear(in_features=model.output_dim, out_features=16, bias=True).to(args.device)

    # assert args.eval_from is not None
    # save_dict = torch.load(args.eval_from, map_location='cpu')
    # msg = model.load_state_dict({k[9:]: v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')},
    #                             strict=True)
    # for ep in range(100):
    MODEL = '/share/contrastive_learning/SimSiam_PatrickHua/SimSiam-main-v2/SimSiam-main/checkpoint/exp_0206_eval/99_all_new/simsiam-TCGA-0126-128by128_tuneall_36.pth'

    # Load the model for testing
    # model = get_backbone(args.model.backbone)
    # model = model.cuda()
    # model = torch.nn.DataParallel(model)
    model = torch.load(MODEL)
    model.eval()

    # print(msg)
    # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
    # classifier = torch.nn.DataParallel(classifier)
    # define optimizer (never stepped below; this script only runs inference)
    optimizer = get_optimizer(args.eval.optimizer.name,
                              model,
                              lr=args.eval.base_lr * args.eval.batch_size / 256,
                              momentum=args.eval.optimizer.momentum,
                              weight_decay=args.eval.optimizer.weight_decay)

    acc_meter = AverageMeter(name='Accuracy')

    # Run evaluation
    acc_meter.reset()

    pred_label_for_f1 = np.array([])
    true_label_for_f1 = np.array([])
    for idx, (images, labels) in enumerate(test_loader):
        with torch.no_grad():
            feature = model(images.to(args.device))
            preds = feature.argmax(dim=1)
            correct = (preds == labels.to(args.device)).sum().item()

            preds_arr = preds.cpu().detach().numpy()
            labels_arr = labels.cpu().detach().numpy()
            pred_label_for_f1 = np.concatenate([pred_label_for_f1, preds_arr])
            true_label_for_f1 = np.concatenate([true_label_for_f1, labels_arr])
            acc_meter.update(correct / preds.shape[0])

    f1 = f1_score(true_label_for_f1, pred_label_for_f1, average='macro')
    # precision = precision_score(true_label_for_f1, pred_label_for_f1, average='micro')
    # recall = recall_score(true_label_for_f1, pred_label_for_f1, average='micro')

    print('Epoch:', 36, f'Accuracy = {acc_meter.avg * 100:.2f}',
          f'F1 score = {f1}')
    cm = confusion_matrix(true_label_for_f1, pred_label_for_f1)
    np.savetxt("foo_36.csv", cm, delimiter=",")
Example #10
def main(args):
    train_directory = '/share/contrastive_learning/data/sup_data/data_0124_10000/train_patch'
    train_loader = torch.utils.data.DataLoader(
        # dataset=get_dataset(
        #     transform=get_aug(train=False, train_classifier=True, **args.aug_kwargs),
        #     train=True,
        #     **args.dataset_kwargs
        # ),
        dataset=datasets.ImageFolder(root=train_directory,
                                     transform=get_aug(train=False,
                                                       train_classifier=True,
                                                       **args.aug_kwargs)),
        batch_size=args.eval.batch_size,
        shuffle=True,
        **args.dataloader_kwargs)
    test_directory = '/share/contrastive_learning/data/sup_data/data_0124_10000/val_patch'
    test_loader = torch.utils.data.DataLoader(
        # dataset=get_dataset(
        #     transform=get_aug(train=False, train_classifier=False, **args.aug_kwargs),
        #     train=False,
        #     **args.dataset_kwargs
        # ),
        dataset=datasets.ImageFolder(root=test_directory,
                                     transform=get_aug(train=False,
                                                       train_classifier=False,
                                                       **args.aug_kwargs)),
        batch_size=args.eval.batch_size,
        shuffle=False,
        **args.dataloader_kwargs)

    model = get_backbone(args.model.backbone)
    classifier = nn.Linear(in_features=model.output_dim,
                           out_features=16,
                           bias=True).to(args.device)

    assert args.eval_from is not None
    save_dict = torch.load(args.eval_from, map_location='cpu')
    msg = model.load_state_dict(
        {
            k[9:]: v
            for k, v in save_dict['state_dict'].items()
            if k.startswith('backbone.')
        },
        strict=True)

    # print(msg)
    model = model.to(args.device)
    model = torch.nn.DataParallel(model)

    # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
    classifier = torch.nn.DataParallel(classifier)
    # define optimizer
    optimizer = get_optimizer(args.eval.optimizer.name,
                              classifier,
                              lr=args.eval.base_lr * args.eval.batch_size / 256,
                              momentum=args.eval.optimizer.momentum,
                              weight_decay=args.eval.optimizer.weight_decay)

    # define lr scheduler
    lr_scheduler = LR_Scheduler(
        optimizer,
        args.eval.warmup_epochs,
        args.eval.warmup_lr * args.eval.batch_size / 256,
        args.eval.num_epochs,
        args.eval.base_lr * args.eval.batch_size / 256,
        args.eval.final_lr * args.eval.batch_size / 256,
        len(train_loader),
    )

    loss_meter = AverageMeter(name='Loss')
    acc_meter = AverageMeter(name='Accuracy')

    # Start training
    global_progress = tqdm(range(0, args.eval.num_epochs), desc=f'Evaluating')
    for epoch in global_progress:
        loss_meter.reset()
        model.eval()
        classifier.train()
        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.eval.num_epochs}',
                              disable=False)

        for idx, (images, labels) in enumerate(local_progress):
            classifier.zero_grad()
            with torch.no_grad():
                feature = model(images.to(args.device))

            preds = classifier(feature)

            loss = F.cross_entropy(preds, labels.to(args.device))

            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()
            local_progress.set_postfix({
                'lr': lr,
                "loss": loss_meter.val,
                'loss_avg': loss_meter.avg
            })

    classifier.eval()
    correct, total = 0, 0
    acc_meter.reset()
    for idx, (images, labels) in enumerate(test_loader):
        with torch.no_grad():
            feature = model(images.to(args.device))
            preds = classifier(feature).argmax(dim=1)
            correct = (preds == labels.to(args.device)).sum().item()
            acc_meter.update(correct / preds.shape[0])
    print(f'Accuracy = {acc_meter.avg * 100:.2f}')
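
Because `acc_meter.avg` is a mean of per-batch accuracies, a ragged final batch is weighted the same as a full one, which can bias the reported number slightly. An exact count, sketched with the same names (and using the otherwise vestigial `correct, total` pair):

correct, total = 0, 0
for images, labels in test_loader:
    with torch.no_grad():
        preds = classifier(model(images.to(args.device))).argmax(dim=1)
    correct += (preds == labels.to(args.device)).sum().item()
    total += labels.size(0)
print(f'Accuracy = {100.0 * correct / total:.2f}')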