Code Example #1
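# Assumed context for this excerpt (exact module paths are assumptions):
# csv, numpy as np, torch, torch.nn as nn, torch.nn.functional as F;
# torch.utils.data as data; torchvision.datasets; tqdm; f1_score and
# balanced_accuracy_score from sklearn.metrics; plus the project helpers
# get_aug, get_backbone, get_optimizer, LR_Scheduler, AverageMeter and a
# TensorBoard SummaryWriter bound to `writer`.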
def main(args):
    train_info = []
    best_epoch = np.zeros(5)
    # 5-fold cross-validation at the whole-slide-image (WSI) level: each fold
    # holds out one WSI for validation and trains the classifier on the rest
    for val_folder_index in range(5):
        best_balance_acc = 0
        whole_train_list = ['D8E6', '117E', '676F', 'E2D7', 'BE52']
        # val_WSI_list holds a single WSI id (a string), despite the name
        val_WSI_list = whole_train_list[val_folder_index]
        # build the training list without mutating whole_train_list in place
        train_WSI_list = [wsi for i, wsi in enumerate(whole_train_list)
                          if i != val_folder_index]
        train_directory = '../data/finetune/1percent/'
        valid_directory = '../data/finetune/1percent/'
        dataset = {}
        # one ImageFolder per training WSI; all four are concatenated below
        train_datasets = [
            datasets.ImageFolder(root=train_directory + wsi,
                                 transform=get_aug(train=False,
                                                   train_classifier=True,
                                                   **args.aug_kwargs))
            for wsi in train_WSI_list
        ]
        dataset['valid'] = datasets.ImageFolder(
            root=valid_directory + val_WSI_list,
            transform=get_aug(train=False,
                              train_classifier=False,
                              **args.aug_kwargs))
        dataset['train'] = data.ConcatDataset(train_datasets)

        train_loader = torch.utils.data.DataLoader(
            dataset=dataset['train'],
            batch_size=args.eval.batch_size,
            shuffle=True,
            **args.dataloader_kwargs)
        test_loader = torch.utils.data.DataLoader(
            dataset=dataset['valid'],
            batch_size=args.eval.batch_size,
            shuffle=False,
            **args.dataloader_kwargs)

        model = get_backbone(args.model.backbone)
        classifier = nn.Linear(in_features=model.output_dim,
                               out_features=9,
                               bias=True).to(args.device)

        assert args.eval_from is not None
        save_dict = torch.load(args.eval_from, map_location='cpu')
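        # k[9:] strips the 'backbone.' prefix (len('backbone.') == 9) so the
        # keys match the bare backbone's state_dict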
        msg = model.load_state_dict(
            {
                k[9:]: v
                for k, v in save_dict['state_dict'].items()
                if k.startswith('backbone.')
            },
            strict=True)

        # print(msg)
        model = model.to(args.device)
        model = torch.nn.DataParallel(model)

        classifier = torch.nn.DataParallel(classifier)
        # define optimizer
        optimizer = get_optimizer(
            args.eval.optimizer.name,
            classifier,
            lr=args.eval.base_lr * args.eval.batch_size / 256,
            momentum=args.eval.optimizer.momentum,
            weight_decay=args.eval.optimizer.weight_decay)
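        # linear LR scaling rule: lr = base_lr * batch_size / 256
        # (Goyal et al., 2017)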

        # define lr scheduler
        lr_scheduler = LR_Scheduler(
            optimizer,
            args.eval.warmup_epochs,
            args.eval.warmup_lr * args.eval.batch_size / 256,
            args.eval.num_epochs,
            args.eval.base_lr * args.eval.batch_size / 256,
            args.eval.final_lr * args.eval.batch_size / 256,
            len(train_loader),
        )

        loss_meter = AverageMeter(name='Loss')
        acc_meter = AverageMeter(name='Accuracy')

        # Start training
        global_progress = tqdm(range(0, args.eval.num_epochs),
                               desc=f'Evaluating')
        for epoch in global_progress:
            loss_meter.reset()
            model.eval()
            classifier.train()
            local_progress = tqdm(train_loader,
                                  desc=f'Epoch {epoch}/{args.eval.num_epochs}',
                                  disable=True)

            for idx, (images, labels) in enumerate(local_progress):
                classifier.zero_grad()
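                # the backbone stays frozen (eval mode + no_grad); only the
                # linear classifier receives gradient updates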
                with torch.no_grad():
                    feature = model(images.to(args.device))

                preds = classifier(feature)

                loss = F.cross_entropy(preds, labels.to(args.device))

                loss.backward()
                optimizer.step()
                loss_meter.update(loss.item())
                lr = lr_scheduler.step()
                local_progress.set_postfix({
                    'lr': lr,
                    "loss": loss_meter.val,
                    'loss_avg': loss_meter.avg
                })

            # loss_meter tracks the classifier's training loss for this fold
            writer.add_scalar('Train/Loss', loss_meter.avg, epoch)
            writer.add_scalar('Train/Lr', lr, epoch)
            writer.flush()

            PATH = (f'checkpoint/exp_0228_triple_1percent/{val_WSI_list}/'
                    f'{val_WSI_list}_tunelinear_{epoch}.pth')

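            # torch.save on the module pickles the whole object (including
            # the DataParallel wrapper), not just a state_dict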
            torch.save(classifier, PATH)

            classifier.eval()
            correct, total = 0, 0
            acc_meter.reset()

            pred_label_for_f1 = np.array([])
            true_label_for_f1 = np.array([])
            for idx, (images, labels) in enumerate(test_loader):
                with torch.no_grad():
                    feature = model(images.to(args.device))
                    preds = classifier(feature).argmax(dim=1)
                    correct = (preds == labels.to(args.device)).sum().item()

                    preds_arr = preds.cpu().detach().numpy()
                    labels_arr = labels.cpu().detach().numpy()
                    pred_label_for_f1 = np.concatenate(
                        [pred_label_for_f1, preds_arr])
                    true_label_for_f1 = np.concatenate(
                        [true_label_for_f1, labels_arr])
                    acc_meter.update(correct / preds.shape[0])

            f1 = f1_score(true_label_for_f1,
                          pred_label_for_f1,
                          average='macro')
            balance_acc = balanced_accuracy_score(true_label_for_f1,
                                                  pred_label_for_f1)
            print('Epoch:  ', str(epoch),
                  f'Accuracy = {acc_meter.avg * 100:.2f}')
            print('F1 score =  ', f1, 'balance acc:  ', balance_acc)
            if balance_acc > best_balance_acc:
                best_epoch[val_folder_index] = epoch
                best_balance_acc = balance_acc
            train_info.append([val_WSI_list, epoch, f1, balance_acc])

    with open('checkpoint/exp_0228_triple_1percent/train_info.csv', 'w',
              newline='') as f:
        # one row per (val WSI, epoch, F1 score, balanced accuracy)
        write = csv.writer(f)
        write.writerows(train_info)
    print(best_epoch)
Code Example #2
def main(args, model=None):
    assert args.eval_from is not None or model is not None
    train_set = get_dataset(
        args.dataset,
        args.data_dir,
        transform=get_aug(args.model,
                          args.image_size,
                          train=False,
                          train_classifier=True),
        train=True,
        download=args.download,  # default is False
        debug_subset_size=args.batch_size
        if args.debug else None  # Use a subset of dataset for debugging.
    )
    test_set = get_dataset(
        args.dataset,
        args.data_dir,
        transform=get_aug(args.model,
                          args.image_size,
                          train=False,
                          train_classifier=False),
        train=False,
        download=args.download,  # default is False
        debug_subset_size=args.batch_size if args.debug else None)

    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers,
                                              pin_memory=True,
                                              drop_last=True)

    if args.local_rank >= 0 and not torch.distributed.is_initialized():
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")

    if model is None:
        model = get_backbone(args.backbone).to(args.device)
        save_dict = torch.load(args.eval_from, map_location=args.device)
        model.load_state_dict(
            {
                k[9:]: v
                for k, v in save_dict['state_dict'].items()
                if k.startswith('backbone.')
            },
            strict=True)

    output_dim = model.output_dim
    if args.local_rank >= 0:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    classifier = nn.Linear(in_features=output_dim,
                           out_features=len(train_set.classes),
                           bias=True).to(args.device)
    if args.local_rank >= 0:
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier,
            device_ids=[args.local_rank],
            output_device=args.local_rank)

    # define optimizer
    optimizer = get_optimizer(args.optimizer,
                              classifier,
                              lr=args.base_lr * args.batch_size / 256,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    # TODO: linear lr warm up for byol simclr swav
    # args.warm_up_epochs
    # define lr scheduler
    lr_scheduler = LR_Scheduler(optimizer, args.warmup_epochs,
                                args.warmup_lr * args.batch_size / 256,
                                args.num_epochs,
                                args.base_lr * args.batch_size / 256,
                                args.final_lr * args.batch_size / 256,
                                len(train_loader))

    loss_meter = AverageMeter(name='Loss')
    acc_meter = AverageMeter(name='Accuracy')

    # Start training
    global_progress = tqdm(range(0, args.num_epochs), desc=f'Evaluating')
    for epoch in global_progress:
        loss_meter.reset()
        model.eval()
        classifier.train()
        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.num_epochs}',
                              disable=args.hide_progress)

        for idx, (images, labels) in enumerate(local_progress):

            classifier.zero_grad()
            with torch.no_grad():
                feature = model(images.to(args.device))

            preds = classifier(feature)

            loss = F.cross_entropy(preds, labels.to(args.device))

            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()
            local_progress.set_postfix({
                'lr': lr,
                "loss": loss_meter.val,
                'loss_avg': loss_meter.avg
            })

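        # with --head_tail_accuracy, run validation only on the first and
        # last epochs to save time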
        if args.head_tail_accuracy and epoch != 0 and (epoch +
                                                       1) != args.num_epochs:
            continue

        local_progress = tqdm(test_loader,
                              desc=f'Test {epoch}/{args.num_epochs}',
                              disable=args.hide_progress)
        classifier.eval()
        correct, total = 0, 0
        acc_meter.reset()
        for idx, (images, labels) in enumerate(local_progress):
            with torch.no_grad():
                feature = model(images.to(args.device))
                preds = classifier(feature).argmax(dim=1)
                correct = (preds == labels.to(args.device)).sum().item()
                acc_meter.update(correct / preds.shape[0])
                local_progress.set_postfix({'accuracy': acc_meter.avg})

        global_progress.set_postfix({
            "epoch": epoch,
            'accuracy': acc_meter.avg * 100
        })
Code Example #3
def main(args):

    train_set = get_dataset(
        args.dataset, 
        args.data_dir, 
        transform=get_aug(args.model, args.image_size, train=False, train_classifier=True), 
        train=True, 
        download=args.download, # default is False
        debug_subset_size=args.batch_size if args.debug else None
    )
    test_set = get_dataset(
        args.dataset, 
        args.data_dir, 
        transform=get_aug(args.model, args.image_size, train=False, train_classifier=False), 
        train=False, 
        download=args.download, # default is False
        debug_subset_size=args.batch_size if args.debug else None
    )


    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    model = get_backbone(args.backbone)
    classifier = nn.Linear(in_features=model.output_dim, out_features=10, bias=True).to(args.device)

    assert args.eval_from is not None
    save_dict = torch.load(args.eval_from, map_location='cpu')
    msg = model.load_state_dict({k[9:]:v for k, v in save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True)
    
    # print(msg)
    model = model.to(args.device)
    model = torch.nn.DataParallel(model)

    # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
    classifier = torch.nn.DataParallel(classifier)
    # define optimizer
    optimizer = get_optimizer(
        args.optimizer, classifier, 
        lr=args.base_lr*args.batch_size/256, 
        momentum=args.momentum, 
        weight_decay=args.weight_decay)

    # define lr scheduler
    lr_scheduler = LR_Scheduler(
        optimizer,
        args.warmup_epochs, args.warmup_lr*args.batch_size/256, 
        args.num_epochs, args.base_lr*args.batch_size/256, args.final_lr*args.batch_size/256, 
        len(train_loader),
    )

    loss_meter = AverageMeter(name='Loss')
    acc_meter = AverageMeter(name='Accuracy')

    # Start training
    global_progress = tqdm(range(0, args.num_epochs), desc=f'Evaluating')
    for epoch in global_progress:
        loss_meter.reset()
        model.eval()
        classifier.train()
        local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.num_epochs}', disable=args.hide_progress)
        
        for idx, (images, labels) in enumerate(local_progress):

            classifier.zero_grad()
            with torch.no_grad():
                feature = model(images.to(args.device))

            preds = classifier(feature)

            loss = F.cross_entropy(preds, labels.to(args.device))

            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()
            local_progress.set_postfix({'lr':lr, "loss":loss_meter.val, 'loss_avg':loss_meter.avg})
        

        if args.head_tail_accuracy and epoch != 0 and (epoch + 1) != args.num_epochs:
            continue

        local_progress = tqdm(test_loader, desc=f'Test {epoch}/{args.num_epochs}', disable=args.hide_progress)
        classifier.eval()
        correct, total = 0, 0
        acc_meter.reset()
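        # per-batch accuracies are averaged; since drop_last=True gives every
        # batch the same size, this equals the overall accuracy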
        for idx, (images, labels) in enumerate(local_progress):
            with torch.no_grad():
                feature = model(images.to(args.device))
                preds = classifier(feature).argmax(dim=1)
                correct = (preds == labels.to(args.device)).sum().item()
                acc_meter.update(correct/preds.shape[0])
                local_progress.set_postfix({'accuracy': acc_meter.avg})
        
        global_progress.set_postfix({"epoch":epoch, 'accuracy':acc_meter.avg*100})
Code Example #4
def main(args):

    train_set = get_dataset(
        args.dataset,
        args.data_dir,
        transform=get_aug(args.model,
                          args.image_size,
                          train=False,
                          train_classifier=True),
        train=True,
        download=args.download  # default is False
    )
    test_set = get_dataset(
        args.dataset,
        args.data_dir,
        transform=get_aug(args.model,
                          args.image_size,
                          train=False,
                          train_classifier=False),
        train=False,
        download=args.download  # default is False
    )

    if args.debug:
        args.batch_size = 20
        args.num_epochs = 2
        args.num_workers = 0
        train_set = torch.utils.data.Subset(train_set, range(
            0, args.batch_size))  # take only one batch
        test_set = torch.utils.data.Subset(test_set, range(0, args.batch_size))

    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers,
                                              pin_memory=True,
                                              drop_last=True)

    # define model
    # model = get_model(args.model, args.backbone)
    backbone = get_backbone(args.backbone, castrate=False)
    in_features = backbone.fc.in_features
    backbone.fc = nn.Identity()
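    # with fc replaced by Identity, the backbone outputs pooled features of
    # size in_features instead of class logits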
    model = backbone
    assert args.eval_from is not None
    save_dict = torch.load(args.eval_from, map_location='cpu')
    msg = model.load_state_dict(
        {
            k[9:]: v
            for k, v in save_dict['state_dict'].items()
            if k.startswith('backbone.')
        },
        strict=True)
    print(msg)
    model = model.to(args.device)
    model = torch.nn.DataParallel(model)
    # if torch.cuda.device_count() > 1: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    classifier = nn.Linear(in_features=in_features, out_features=10,
                           bias=True).to(args.device)
    classifier = torch.nn.DataParallel(classifier)

    # define optimizer
    optimizer = get_optimizer(args.optimizer,
                              classifier,
                              lr=args.base_lr * args.batch_size / 256,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    # TODO: linear lr warm up for byol simclr swav
    # args.warm_up_epochs

    # define lr scheduler
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                              args.num_epochs,
                                                              eta_min=0)
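    # unlike the per-iteration LR_Scheduler used in the other variants, this
    # scheduler is stepped once per epoch, after the inner training loop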

    loss_meter = AverageMeter(name='Loss')
    acc_meter = AverageMeter(name='Accuracy')
    # Start training
    for epoch in tqdm(range(0, args.num_epochs), desc=f'Evaluating'):
        loss_meter.reset()
        model.eval()
        classifier.train()
        p_bar = tqdm(train_loader, desc=f'Epoch {epoch}/{args.num_epochs}')

        for idx, (images, labels) in enumerate(p_bar):
            classifier.zero_grad()
            with torch.no_grad():
                feature = model(images.to(args.device))
            preds = classifier(feature)

            loss = F.cross_entropy(preds, labels.to(args.device))
            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            p_bar.set_postfix({
                "loss": loss_meter.val,
                'loss_avg': loss_meter.avg
            })

        lr_scheduler.step()

        p_bar = tqdm(test_loader, desc=f'Test {epoch}/{args.num_epochs}')
        classifier.eval()
        correct, total = 0, 0
        acc_meter.reset()
        for idx, (images, labels) in enumerate(p_bar):
            with torch.no_grad():
                feature = model(images.to(args.device))
                preds = classifier(feature).argmax(dim=1)
                correct = (preds == labels.to(args.device)).sum().item()
                acc_meter.update(correct / preds.shape[0])
                p_bar.set_postfix({'accuracy': acc_meter.avg})
Code Example #5
def main(args):

    train_loader = torch.utils.data.DataLoader(dataset=get_dataset(
        transform=get_aug(train=False,
                          train_classifier=True,
                          **args.aug_kwargs),
        train=True,
        **args.dataset_kwargs),
                                               batch_size=args.eval.batch_size,
                                               shuffle=True,
                                               **args.dataloader_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=get_dataset(
        transform=get_aug(train=False,
                          train_classifier=False,
                          **args.aug_kwargs),
        train=False,
        **args.dataset_kwargs),
                                              batch_size=args.eval.batch_size,
                                              shuffle=False,
                                              **args.dataloader_kwargs)

    model = get_backbone(args.model.backbone)
    classifier = nn.Linear(in_features=model.output_dim,
                           out_features=10,
                           bias=True).to(args.device)

    assert args.eval_from is not None
    save_dict = torch.load(args.eval_from, map_location='cpu')
    msg = model.load_state_dict(
        {
            k[9:]: v
            for k, v in save_dict['state_dict'].items()
            if k.startswith('backbone.')
        },
        strict=True)

    # print(msg)
    model = model.to(args.device)
    model = torch.nn.DataParallel(model)

    # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
    classifier = torch.nn.DataParallel(classifier)
    # define optimizer
    optimizer = get_optimizer(args.eval.optimizer.name,
                              classifier,
                              lr=args.eval.base_lr * args.eval.batch_size /
                              256,
                              momentum=args.eval.optimizer.momentum,
                              weight_decay=args.eval.optimizer.weight_decay)

    # define lr scheduler
    lr_scheduler = LR_Scheduler(
        optimizer,
        args.eval.warmup_epochs,
        args.eval.warmup_lr * args.eval.batch_size / 256,
        args.eval.num_epochs,
        args.eval.base_lr * args.eval.batch_size / 256,
        args.eval.final_lr * args.eval.batch_size / 256,
        len(train_loader),
    )

    loss_meter = AverageMeter(name='Loss')
    acc_meter = AverageMeter(name='Accuracy')

    # Start training
    global_progress = tqdm(range(0, args.eval.num_epochs), desc=f'Evaluating')
    for epoch in global_progress:
        loss_meter.reset()
        model.eval()
        classifier.train()
        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.eval.num_epochs}',
                              disable=True)

        for idx, (images, labels) in enumerate(local_progress):
            # this will take the images and stick them to one another using the batch dimension
            # so it expects [C x H x W] and will turn each into a [1 x C x H x W] and then for N it will
            # concatenate them into a big tensor of [N x C x H x W]
            # (torch.cat over unsqueezed tensors is equivalent to torch.stack)
            if isinstance(images, list):
                images = torch.cat(
                    [image.unsqueeze(dim=0) for image in images], dim=0)

            classifier.zero_grad()
            with torch.no_grad():
                feature = model(images.to(args.device))

            preds = classifier(feature)

            loss = F.cross_entropy(preds, labels.to(args.device))

            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()
            local_progress.set_postfix({
                'lr': lr,
                "loss": loss_meter.val,
                'loss_avg': loss_meter.avg
            })

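    # evaluate the linear probe once, after all training epochs complete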
    classifier.eval()
    correct, total = 0, 0
    acc_meter.reset()
    for idx, (images, labels) in enumerate(test_loader):
        with torch.no_grad():
            feature = model(images.to(args.device))
            preds = classifier(feature).argmax(dim=1)
            correct = (preds == labels.to(args.device)).sum().item()
            acc_meter.update(correct / preds.shape[0])
    print(f'Accuracy = {acc_meter.avg*100:.2f}')
Code Example #6
def main(device, args):
    dataset_kwargs = {
        'dataset': args.dataset,
        'data_dir': args.data_dir,
        'download': args.download,
        'debug_subset_size': args.batch_size if args.debug else None
    }
    dataloader_kwargs = {
        'batch_size': args.batch_size,
        'drop_last': True,
        'pin_memory': True,
        'num_workers': args.num_workers,
    }

    train_loader = torch.utils.data.DataLoader(dataset=get_dataset(
        transform=get_aug(args.model, args.image_size, True),
        train=True,
        **dataset_kwargs),
                                               shuffle=True,
                                               **dataloader_kwargs)
    memory_loader = torch.utils.data.DataLoader(dataset=get_dataset(
        transform=get_aug(args.model,
                          args.image_size,
                          False,
                          train_classifier=False),
        train=True,
        **dataset_kwargs),
                                                shuffle=False,
                                                **dataloader_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=get_dataset(
        transform=get_aug(args.model,
                          args.image_size,
                          False,
                          train_classifier=False),
        train=False,
        **dataset_kwargs),
                                              shuffle=False,
                                              **dataloader_kwargs)

    # define model
    model = get_model(args.model, args.backbone).to(device)
    if args.model == 'simsiam' and args.proj_layers is not None:
        model.projector.set_layers(args.proj_layers)
    model = torch.nn.DataParallel(model)
    if torch.cuda.device_count() > 1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
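    # note: SyncBatchNorm is only supported under DistributedDataParallel;
    # a DataParallel model converted this way will error at forward time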

    # define optimizer
    optimizer = get_optimizer(args.optimizer,
                              model,
                              lr=args.base_lr * args.batch_size / 256,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    lr_scheduler = LR_Scheduler(
        optimizer,
        args.warmup_epochs,
        args.warmup_lr * args.batch_size / 256,
        args.num_epochs,
        args.base_lr * args.batch_size / 256,
        args.final_lr * args.batch_size / 256,
        len(train_loader),
        constant_predictor_lr=True  # see the end of section 4.2 predictor
    )

    loss_meter = AverageMeter(name='Loss')
    plot_logger = PlotLogger(params=['lr', 'loss', 'accuracy'])
    # Start training
    global_progress = tqdm(range(0, args.stop_at_epoch), desc=f'Training')
    for epoch in global_progress:
        loss_meter.reset()
        model.train()

        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.num_epochs}',
                              disable=args.hide_progress)
        for idx, ((images1, images2), labels) in enumerate(local_progress):

            model.zero_grad()
            loss = model.forward(images1.to(device, non_blocking=True),
                                 images2.to(device, non_blocking=True))
            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()

            data_dict = {'lr': lr, "loss": loss_meter.val}
            local_progress.set_postfix(data_dict)
            plot_logger.update(data_dict)
        accuracy = knn_monitor(model.module.backbone,
                               memory_loader,
                               test_loader,
                               device,
                               k=200,
                               hide_progress=args.hide_progress)
        global_progress.set_postfix({
            "epoch": epoch,
            "loss_avg": loss_meter.avg,
            "accuracy": accuracy
        })
        plot_logger.update({'accuracy': accuracy})
        plot_logger.save(os.path.join(args.output_dir, 'logger.svg'))

    # save a final checkpoint once training finishes
    model_path = os.path.join(
        args.output_dir, f'{args.model}-{args.dataset}-epoch{epoch+1}.pth')
    torch.save(
        {
            'epoch': epoch + 1,
            'state_dict': model.module.state_dict(),
            # 'optimizer':optimizer.state_dict(), # will double the checkpoint file size
            'lr_scheduler': lr_scheduler,
            'args': args,
            'loss_meter': loss_meter,
            'plot_logger': plot_logger
        },
        model_path)
    print(f"Model saved to {model_path}")

    if args.eval_after_train is not None:
        args.eval_from = model_path
        arg_list = [
            x.strip().lstrip('--').split()
            for x in args.eval_after_train.split('\n')
        ]
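        # eval() assumes every overriding value is a Python literal;
        # ast.literal_eval would be a safer parser here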
        args.__dict__.update({x[0]: eval(x[1]) for x in arg_list})
        if args.debug:
            args.batch_size = 2
            args.num_epochs = 3

        linear_eval(args)
Code Example #7
def main(args):
    test_directory = '/share/contrastive_learning/data/crop_after_process_doctor/crop_train_for_exp/crop_test_screened'
    test_loader = torch.utils.data.DataLoader(
        dataset=datasets.ImageFolder(root=test_directory,
                                     transform=get_aug(train=False,
                                                       train_classifier=False,
                                                       **args.aug_kwargs)),
        batch_size=args.eval.batch_size,
        shuffle=False,
        **args.dataloader_kwargs)

    # The checkpoint below is a full fine-tuned model object saved with
    # torch.save(model), so it is loaded directly rather than restoring a
    # state_dict into a freshly built backbone.
    MODEL = '/share/contrastive_learning/SimSiam_PatrickHua/SimSiam-main-v2/SimSiam-main/checkpoint/exp_0206_eval/99_all_new/simsiam-TCGA-0126-128by128_tuneall_36.pth'
    model = torch.load(MODEL)
    model.eval()

    acc_meter = AverageMeter(name='Accuracy')

    # Run evaluation on the test set
    acc_meter.reset()

    pred_label_for_f1 = np.array([])
    true_label_for_f1 = np.array([])
    for idx, (images, labels) in enumerate(test_loader):
        with torch.no_grad():
            feature = model(images.to(args.device))
            preds = feature.argmax(dim=1)
            correct = (preds == labels.to(args.device)).sum().item()

            preds_arr = preds.cpu().detach().numpy()
            labels_arr = labels.cpu().detach().numpy()
            pred_label_for_f1 = np.concatenate([pred_label_for_f1, preds_arr])
            true_label_for_f1 = np.concatenate([true_label_for_f1, labels_arr])
            acc_meter.update(correct / preds.shape[0])

    f1 = f1_score(true_label_for_f1, pred_label_for_f1, average='macro')

    # epoch 36 corresponds to the checkpoint file loaded above
    print('Epoch:  ', str(36), f'Accuracy = {acc_meter.avg * 100:.2f}',
          'F1 score =  ', f1)
    cm = confusion_matrix(true_label_for_f1, pred_label_for_f1)
    np.savetxt("foo_36.csv", cm, delimiter=",")
Code Example #8
def main(args):
    train_directory = '/share/contrastive_learning/data/sup_data/data_0124_10000/train_patch'
    train_loader = torch.utils.data.DataLoader(
        dataset=datasets.ImageFolder(root=train_directory,
                                     transform=get_aug(train=False,
                                                       train_classifier=True,
                                                       **args.aug_kwargs)),
        batch_size=args.eval.batch_size,
        shuffle=True,
        **args.dataloader_kwargs)
    test_directory = '/share/contrastive_learning/data/sup_data/data_0124_10000/val_patch'
    test_loader = torch.utils.data.DataLoader(
        dataset=datasets.ImageFolder(root=test_directory,
                                     transform=get_aug(train=False,
                                                       train_classifier=False,
                                                       **args.aug_kwargs)),
        batch_size=args.eval.batch_size,
        shuffle=False,
        **args.dataloader_kwargs)

    model = get_backbone(args.model.backbone)
    classifier = nn.Linear(in_features=model.output_dim,
                           out_features=16,
                           bias=True).to(args.device)

    assert args.eval_from is not None
    save_dict = torch.load(args.eval_from, map_location='cpu')
    msg = model.load_state_dict(
        {
            k[9:]: v
            for k, v in save_dict['state_dict'].items()
            if k.startswith('backbone.')
        },
        strict=True)

    # print(msg)
    model = model.to(args.device)
    model = torch.nn.DataParallel(model)

    # if torch.cuda.device_count() > 1: classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
    classifier = torch.nn.DataParallel(classifier)
    # define optimizer
    optimizer = get_optimizer(args.eval.optimizer.name,
                              classifier,
                              lr=args.eval.base_lr * args.eval.batch_size /
                              256,
                              momentum=args.eval.optimizer.momentum,
                              weight_decay=args.eval.optimizer.weight_decay)

    # define lr scheduler
    lr_scheduler = LR_Scheduler(
        optimizer,
        args.eval.warmup_epochs,
        args.eval.warmup_lr * args.eval.batch_size / 256,
        args.eval.num_epochs,
        args.eval.base_lr * args.eval.batch_size / 256,
        args.eval.final_lr * args.eval.batch_size / 256,
        len(train_loader),
    )

    loss_meter = AverageMeter(name='Loss')
    acc_meter = AverageMeter(name='Accuracy')

    # Start training
    global_progress = tqdm(range(0, args.eval.num_epochs), desc=f'Evaluating')
    for epoch in global_progress:
        loss_meter.reset()
        model.eval()
        classifier.train()
        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.eval.num_epochs}',
                              disable=False)

        for idx, (images, labels) in enumerate(local_progress):
            classifier.zero_grad()
            with torch.no_grad():
                feature = model(images.to(args.device))

            preds = classifier(feature)

            loss = F.cross_entropy(preds, labels.to(args.device))

            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()
            local_progress.set_postfix({
                'lr': lr,
                "loss": loss_meter.val,
                'loss_avg': loss_meter.avg
            })

    classifier.eval()
    correct, total = 0, 0
    acc_meter.reset()
    for idx, (images, labels) in enumerate(test_loader):
        with torch.no_grad():
            feature = model(images.to(args.device))
            preds = classifier(feature).argmax(dim=1)
            correct = (preds == labels.to(args.device)).sum().item()
            acc_meter.update(correct / preds.shape[0])
    print(f'Accuracy = {acc_meter.avg * 100:.2f}')
Code Example #9
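# Assumed context for this excerpt: ArgumentParser from argparse; json and
# torch; DataLoader, RandomSampler, DistributedSampler from torch.utils.data;
# BertConfig/BertForMaskedLM plus AdamW and WarmupLinearSchedule from the
# pytorch_transformers-era API; and the project helpers config, logger,
# seed_everything, CustomTokenizer, PregeneratedDataset, LMAccuracy,
# AverageMeter, CONFIG_NAME. Exact module paths are assumptions.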
def main():

    parser = ArgumentParser()
    parser.add_argument("--file_num", type=int, default=10,
                        help="Number of pregenerate file")
    parser.add_argument("--reduce_memory", action="store_true",
                        help="Store training data as on-disc memmaps to massively reduce memory usage")
    parser.add_argument("--epochs", type=int, default=2,
                        help="Number of epochs to train for")
    parser.add_argument('--num_eval_steps', type=int, default=200)
    parser.add_argument('--num_save_steps', type=int, default=5000)
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--train_batch_size", default=24, type=int,
                        help="Total batch size for training.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--warmup_proportion",default=0.1,type=float,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--learning_rate", default=1e-4, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    args = parser.parse_args()

    pregenerated_data = config['data_dir'] / "corpus/train"
    assert pregenerated_data.is_dir(), \
        "config['data_dir']/corpus/train should point to the folder of files made by prepare_lm_data_mask.py!"

    samples_per_epoch = 0
    for i in range(args.file_num):
        data_file = pregenerated_data / f"file_{i}.json"
        metrics_file = pregenerated_data / f"file_{i}_metrics.json"
        if data_file.is_file() and metrics_file.is_file():
            metrics = json.loads(metrics_file.read_text())
            samples_per_epoch += metrics['num_training_examples']
        else:
            if i == 0:
                exit("No training data was found!")
            print(f"Warning! There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs}).")
            print("This script will loop over the available data, but training diversity may be negatively impacted.")
            break
    logger.info(f"samples_per_epoch: {samples_per_epoch}")

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info( f"device: {device} n_gpu: {n_gpu}, distributed training: {bool(args.local_rank != -1)}, 16-bits training: {args.fp16}")

    if args.gradient_accumulation_steps < 1:
        raise ValueError(f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps}, should be >= 1")
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    seed_everything(args.seed)
    tokenizer = CustomTokenizer(vocab_file=config['bert_vocab_path'])
    total_train_examples = samples_per_epoch * args.epochs

    num_train_optimization_steps = int(
        total_train_examples / args.train_batch_size / args.gradient_accumulation_steps)
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
    args.warmup_steps = int(num_train_optimization_steps * args.warmup_proportion)
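    # example: 1,000,000 total examples with train_batch_size=24 and no
    # gradient accumulation -> 41,666 optimizer steps, of which the first
    # 4,166 (warmup_proportion=0.1) are linear warmup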

    # Prepare model
    bert_config = BertConfig.from_json_file(str(config['bert_config_file']))
    model = BertForMaskedLM(config=bert_config)
    # model = BertForMaskedLM.from_pretrained(config['checkpoint_dir'] / 'checkpoint-580000')
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    global_step = 0
    metric = LMAccuracy()
    tr_acc = AverageMeter()
    tr_loss = AverageMeter()

    train_logs = {}
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {total_train_examples}")
    logger.info(f"  Batch size = {args.train_batch_size}" )
    logger.info(f"  Num steps = {num_train_optimization_steps}" )
    logger.info(f"  warmup_steps = {args.warmup_steps}")

    model.train()
    for epoch in range(args.epochs):
        for idx in range(args.file_num):
            epoch_dataset = PregeneratedDataset(file_id=idx, training_path=pregenerated_data, tokenizer=tokenizer,
                                                reduce_memory=args.reduce_memory)
            if args.local_rank == -1:
                train_sampler = RandomSampler(epoch_dataset)
            else:
                train_sampler = DistributedSampler(epoch_dataset)
            train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(train_dataloader):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids = batch
                outputs = model(input_ids, segment_ids, input_mask, lm_label_ids)
                pred_output = outputs[1]
                loss = outputs[0]
                metric(logits=pred_output.view(-1, bert_config.vocab_size), target=lm_label_ids.view(-1))
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                nb_tr_steps += 1
                tr_acc.update(metric.value(), n=input_ids.size(0))
                tr_loss.update(loss.item(), n=1)

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # update the LR schedule after the optimizer step
                    optimizer.zero_grad()
                    global_step += 1

                if global_step % args.num_eval_steps == 0:
                    train_logs['loss'] = tr_loss.avg
                    train_logs['acc'] = tr_acc.avg
                    show_info = f'\n[Training]:[{epoch}/{args.epochs}]{global_step}/{num_train_optimization_steps} ' + "-".join(
                        [f' {key}: {value:.4f} ' for key, value in train_logs.items()])
                    logger.info(show_info)
                    tr_acc.reset()
                    tr_loss.reset()

                if global_step % args.num_save_steps == 0:
                    if args.local_rank in [-1, 0] and args.num_save_steps > 0:
                        # Save model checkpoint
                        output_dir = config['checkpoint_dir'] / f'checkpoint-{global_step}'
                        if not output_dir.exists():
                            output_dir.mkdir()
                        # save model (unwrap DataParallel/DDP before saving)
                        model_to_save = model.module if hasattr(model, 'module') else model
                        model_to_save.save_pretrained(str(output_dir))
                        torch.save(args, str(output_dir / 'training_args.bin'))
                        logger.info("Saving model checkpoint to %s", output_dir)

                        # save config
                        output_config_file = output_dir / CONFIG_NAME
                        with open(str(output_config_file), 'w') as f:
                            f.write(model_to_save.config.to_json_string())
                        # save vocab
                        tokenizer.save_vocabulary(output_dir)
Code Example #10
def main():
    parser = ArgumentParser()
    parser.add_argument('--data_name', default='albert', type=str)
    parser.add_argument("--file_num",
                        type=int,
                        default=2,
                        help="Number of pregenerate file")
    parser.add_argument(
        "--reduce_memory",
        action="store_true",
        help=
        "Store training data as on-disc memmaps to massively reduce memory usage"
    )
    parser.add_argument("--epochs",
                        type=int,
                        default=4,
                        help="Number of epochs to train for")
    parser.add_argument('--num_eval_steps', type=int, default=20)
    parser.add_argument('--num_save_steps', type=int, default=200)
    parser.add_argument('--share_parameter',
                        default=False,
                        action='store_true')
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--train_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument('--max_grad_norm', default=1.0, type=float)
    parser.add_argument("--learning_rate",
                        default=2e-4,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O2',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    args = parser.parse_args()

    pregenerated_data = config['data_dir'] / "corpus/train"
    assert pregenerated_data.is_dir(), \
        "config['data_dir']/corpus/train should point to the folder of files made by prepare_lm_data_mask.py!"

    samples_per_epoch = 0
    for i in range(args.file_num):
        data_file = pregenerated_data / f"{args.data_name}_file_{i}.json"
        metrics_file = pregenerated_data / f"{args.data_name}_file_{i}_metrics.json"
        if data_file.is_file() and metrics_file.is_file():
            metrics = json.loads(metrics_file.read_text())
            samples_per_epoch += metrics['num_training_examples']
        else:
            if i == 0:
                exit("No training data was found!")
            print(
                f"Warning! There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs})."
            )
            print(
                "This script will loop over the available data, but training diversity may be negatively impacted."
            )
            break
    logger.info(f"samples_per_epoch: {samples_per_epoch}")
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(f"cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        f"device: {device} , distributed training: {bool(args.local_rank != -1)}, 16-bits training: {args.fp16}, share_parameter: {args.share_parameter}"
    )

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps}, should be >= 1"
        )
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
    seed_everything(args.seed)
    tokenizer = BertTokenizer(vocab_file=config['checkpoint_dir'] /
                              'vocab.txt')
    total_train_examples = samples_per_epoch * args.epochs

    num_train_optimization_steps = int(total_train_examples /
                                       args.train_batch_size /
                                       args.gradient_accumulation_steps)
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
        )
    args.warmup_steps = int(num_train_optimization_steps *
                            args.warmup_proportion)

    bert_config = BertConfig.from_json_file(
        str(config['checkpoint_dir'] / 'config.json'))
    bert_config.share_parameter_across_layers = args.share_parameter
    model = BertForPreTraining(config=bert_config)
    # model = BertForMaskedLM.from_pretrained(config['checkpoint_dir'] / 'checkpoint-580000')
    model.to(device)
    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01
        },
        {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    lr_scheduler = WarmupLinearSchedule(optimizer,
                                        warmup_steps=args.warmup_steps,
                                        t_total=num_train_optimization_steps)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
    global_step = 0
    mask_metric = LMAccuracy()
    sop_metric = LMAccuracy()
    tr_mask_acc = AverageMeter()
    tr_sop_acc = AverageMeter()
    tr_loss = AverageMeter()
    tr_mask_loss = AverageMeter()
    tr_sop_loss = AverageMeter()
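    # non-masked positions carry label -1 in lm_label_ids, so ignore_index=-1 drops them from the loss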
    loss_fct = CrossEntropyLoss(ignore_index=-1)

    train_logs = {}
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {total_train_examples}")
    logger.info(f"  Batch size = {args.train_batch_size}")
    logger.info(f"  Num steps = {num_train_optimization_steps}")
    logger.info(f"  warmup_steps = {args.warmup_steps}")
    start_time = time.time()
    seed_everything(args.seed)  # Added here for reproducibility
    for epoch in range(args.epochs):
        for idx in range(args.file_num):
            epoch_dataset = PregeneratedDataset(
                file_id=idx,
                training_path=pregenerated_data,
                tokenizer=tokenizer,
                reduce_memory=args.reduce_memory,
                data_name=args.data_name)
            if args.local_rank == -1:
                train_sampler = RandomSampler(epoch_dataset)
            else:
                train_sampler = DistributedSampler(epoch_dataset)
            train_dataloader = DataLoader(epoch_dataset,
                                          sampler=train_sampler,
                                          batch_size=args.train_batch_size)
            model.train()
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(train_dataloader):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
                outputs = model(input_ids=input_ids,
                                token_type_ids=segment_ids,
                                attention_mask=input_mask)
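                # BertForPreTraining returns the token-level MLM logits and the sentence-pair head logits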
                prediction_scores = outputs[0]
                seq_relationship_score = outputs[1]

                masked_lm_loss = loss_fct(
                    prediction_scores.view(-1, bert_config.vocab_size),
                    lm_label_ids.view(-1))
                next_sentence_loss = loss_fct(
                    seq_relationship_score.view(-1, 2), is_next.view(-1))
                loss = masked_lm_loss + next_sentence_loss

                mask_metric(logits=prediction_scores.view(
                    -1, bert_config.vocab_size),
                            target=lm_label_ids.view(-1))
                sop_metric(logits=seq_relationship_score.view(-1, 2),
                           target=is_next.view(-1))

                if args.n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                nb_tr_steps += 1
                tr_mask_acc.update(mask_metric.value(), n=input_ids.size(0))
                tr_sop_acc.update(sop_metric.value(), n=input_ids.size(0))
                tr_loss.update(loss.item(), n=1)
                tr_mask_loss.update(masked_lm_loss.item(), n=1)
                tr_sop_loss.update(next_sentence_loss.item(), n=1)

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)
                    optimizer.step()
                    lr_scheduler.step()  # step the schedule after the optimizer update
                    optimizer.zero_grad()
                    global_step += 1

                if global_step > 0 and global_step % args.num_eval_steps == 0:
                    now = time.time()
                    eta = now - start_time  # elapsed time since the last log, printed under the "ETA" label below
                    if eta > 3600:
                        eta_format = ('%d:%02d:%02d' %
                                      (eta // 3600,
                                       (eta % 3600) // 60, eta % 60))
                    elif eta > 60:
                        eta_format = '%d:%02d' % (eta // 60, eta % 60)
                    else:
                        eta_format = '%ds' % eta
                    train_logs['loss'] = tr_loss.avg
                    train_logs['mask_acc'] = tr_mask_acc.avg
                    train_logs['sop_acc'] = tr_sop_acc.avg
                    train_logs['mask_loss'] = tr_mask_loss.avg
                    train_logs['sop_loss'] = tr_sop_loss.avg
                    show_info = (
                        f'[Training]:[{epoch}/{args.epochs}]'
                        f'{global_step}/{num_train_optimization_steps} '
                        f'- ETA: {eta_format}' + '-'.join(
                            f' {key}: {value:.4f} '
                            for key, value in train_logs.items()))
                    logger.info(show_info)
                    tr_mask_acc.reset()
                    tr_sop_acc.reset()
                    tr_loss.reset()
                    tr_mask_loss.reset()
                    tr_sop_loss.reset()
                    start_time = now

                if args.num_save_steps > 0 and global_step > 0 \
                        and global_step % args.num_save_steps == 0:
                    if args.local_rank in [-1, 0]:
                        # Save model checkpoint
                        output_dir = config[
                            'checkpoint_dir'] / f'lm-checkpoint-{global_step}'
                        if not output_dir.exists():
                            output_dir.mkdir()
                        # save model
                        model_to_save = model.module if hasattr(
                            model, 'module'
                        ) else model  # Take care of distributed/parallel training
                        model_to_save.save_pretrained(str(output_dir))
                        torch.save(args, str(output_dir / 'training_args.bin'))
                        logger.info("Saving model checkpoint to %s",
                                    output_dir)

                        # save config
                        output_config_file = output_dir / CONFIG_NAME
                        with open(str(output_config_file), 'w') as f:
                            f.write(model_to_save.config.to_json_string())

                        # save vocab
                        tokenizer.save_vocabulary(output_dir)
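
The optimizer setup above excludes bias and LayerNorm parameters from weight decay, a pattern common to BERT training scripts. Below is a minimal, self-contained sketch of the same grouping; the helper name build_param_groups and the toy model are illustrative, not part of the project.

import torch
from torch import nn


def build_param_groups(model: nn.Module, weight_decay: float = 0.01):
    """Split parameters into decayed and non-decayed groups for AdamW."""
    no_decay = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')
    decay, skip = [], []
    for name, param in model.named_parameters():
        (skip if any(nd in name for nd in no_decay) else decay).append(param)
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': skip, 'weight_decay': 0.0},
    ]


# toy usage: parameters whose names contain 'bias' land in the zero-decay group
# (in HF-style BERT the submodules are literally named 'LayerNorm', so theirs do too)
toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
optimizer = torch.optim.AdamW(build_param_groups(toy), lr=2e-4, eps=1e-8)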
Code Example #11
def main(args):

    train_set = get_dataset(
        args.dataset,
        args.data_dir,
        transform=get_aug(args.model, args.image_size, True),
        train=True,
        download=args.download,  # default is False
        debug_subset_size=args.batch_size
        if args.debug else None  # run one batch if debug
    )

    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               drop_last=True)

    # define model
    model = get_model(args.model, args.backbone).to(args.device)
    backbone = model.backbone
    if args.model == 'simsiam' and args.proj_layers is not None:
        model.projector.set_layers(args.proj_layers)

    if args.local_rank >= 0:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # define optimizer
    optimizer = get_optimizer(args.optimizer,
                              model,
                              lr=args.base_lr * args.batch_size / 256,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    lr_scheduler = LR_Scheduler(optimizer, args.warmup_epochs,
                                args.warmup_lr * args.batch_size / 256,
                                args.num_epochs,
                                args.base_lr * args.batch_size / 256,
                                args.final_lr * args.batch_size / 256,
                                len(train_loader))

    loss_meter = AverageMeter(name='Loss')
    plot_logger = PlotLogger(params=['epoch', 'lr', 'loss'])
    os.makedirs(args.output_dir, exist_ok=True)
    # Start training
    global_progress = tqdm(range(0, args.stop_at_epoch), desc='Training')
    for epoch in global_progress:
        loss_meter.reset()
        model.train()

        local_progress = tqdm(train_loader,
                              desc=f'Epoch {epoch}/{args.num_epochs}',
                              disable=args.hide_progress)
        for idx, ((images1, images2), labels) in enumerate(local_progress):

            model.zero_grad()
            loss = model.forward(images1.to(args.device),
                                 images2.to(args.device))
            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()
            local_progress.set_postfix({'lr': lr, "loss": loss_meter.val})
            plot_logger.update({
                'epoch': epoch,
                'lr': lr,
                'loss': loss_meter.val
            })
        global_progress.set_postfix({
            "epoch": epoch,
            "loss_avg": loss_meter.avg
        })
        plot_logger.save(os.path.join(args.output_dir, 'logger.svg'))

    # Save checkpoint
    if args.local_rank <= 0:
        model_path = os.path.join(
            args.output_dir,
            f'{args.model}-{args.dataset}-epoch{args.stop_at_epoch}.pth')
        torch.save(
            {
                'epoch': args.stop_at_epoch,
                'state_dict': model.state_dict(),
                # 'optimizer':optimizer.state_dict(), # will double the checkpoint file size
                'lr_scheduler': lr_scheduler,
                'args': args,
                'loss_meter': loss_meter,
                'plot_logger': plot_logger
            },
            model_path)
        print(f"Model saved to {model_path}")

    if args.eval_after_train is not None:
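        # eval_after_train holds newline-separated "--key value" overrides;
        # strip the dashes and eval() each value before applying them to args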
        arg_list = [
            x.strip().lstrip('--').split()
            for x in args.eval_after_train.split('\n')
        ]
        args.__dict__.update({x[0]: eval(x[1]) for x in arg_list})
        args.distributed_initialized = True
        if args.debug:
            args.batch_size = 2
            args.num_epochs = 3

        linear_eval(args, backbone)
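
LR_Scheduler is project code not shown in these listings. Assuming it implements the usual warmup-then-cosine recipe that the warmup_lr/base_lr/final_lr arguments suggest, the per-step learning rate can be sketched as follows; the function name and signature are hypothetical.

import math


def lr_at_step(step, warmup_steps, total_steps, warmup_lr, base_lr, final_lr):
    # hypothetical sketch: linear warmup from warmup_lr to base_lr,
    # then cosine decay from base_lr down to final_lr
    if step < warmup_steps:
        return warmup_lr + (base_lr - warmup_lr) * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return final_lr + 0.5 * (base_lr - final_lr) * (1 + math.cos(math.pi * progress))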
Code Example #12
File: main.py Project: yyht/SimSiam
def main(args):

    train_set = get_dataset(
        args.dataset, 
        args.data_dir, 
        transform=get_aug(args.model, args.image_size, True), 
        train=True, 
        download=args.download # default is False
    )
    
    if args.debug:
        args.batch_size = 2 
        args.num_epochs = 1 # train only one epoch
        args.num_workers = 0
        train_set = torch.utils.data.Subset(train_set, range(0, args.batch_size)) # take only one batch

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )

    # define model
    model = get_model(args.model, args.backbone).to(args.device)
    model = torch.nn.DataParallel(model)
    # note: per the PyTorch docs, SyncBatchNorm only synchronizes under
    # DistributedDataParallel; wrapped in DataParallel as above it will not sync across GPUs
    if torch.cuda.device_count() > 1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    
    # define optimizer
    optimizer = get_optimizer(
        args.optimizer, model, 
        lr=args.base_lr*args.batch_size/256, 
        momentum=args.momentum, 
        weight_decay=args.weight_decay)

    # TODO: linear lr warm up for byol simclr swav
    # args.warm_up_epochs

    # define lr scheduler
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.num_epochs, eta_min=0)

    loss_meter = AverageMeter(name='Loss')

    # Start training
    for epoch in tqdm(range(0, args.num_epochs), desc='Training'):
        loss_meter.reset()
        model.train()
        p_bar = tqdm(train_loader, desc=f'Epoch {epoch}/{args.num_epochs}')
        for idx, ((images1, images2), labels) in enumerate(p_bar):
            # breakpoint()
            model.zero_grad()
            loss = model.forward(images1.to(args.device), images2.to(args.device))
            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            p_bar.set_postfix({'loss': loss_meter.val, 'loss_avg': loss_meter.avg})

        lr_scheduler.step()


        # Save checkpoint
        os.makedirs(args.output_dir, exist_ok=True)
        model_path = os.path.join(args.output_dir, f'{args.model}-{args.dataset}-epoch{epoch+1}.pth')
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.module.state_dict(),
            # 'optimizer': optimizer.state_dict(),  # would roughly double the checkpoint file size
            'lr_scheduler': lr_scheduler.state_dict(),
            'args': args,
            'loss_meter': loss_meter
        }, model_path)
    print(f"Model saved to {model_path}")