Example no. 1
    def _build_scheduler(self):
        if self.get("scheduler/use", False):
            self._base_scheduler = CosineAnnealingLR(
                self.optim,
                T_max=self.get("training/num_epochs"),
                **self.get("scheduler/kwargs", {}),
            )
        else:
            self._base_scheduler = None
        # Support for LR warmup
        if self.get("scheduler/warmup", False):
            assert self._base_scheduler is not None
            self.scheduler = GradualWarmupScheduler(
                self.optim,
                multiplier=1,
                total_epoch=5,
                after_scheduler=self._base_scheduler,
            )
        else:
            self.scheduler = self._base_scheduler
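        # Note: if neither "scheduler/use" nor "scheduler/warmup" is set in the config,
        # self.scheduler stays None, so callers should guard scheduler.step() accordingly.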
Example no. 2
def main(args, ITE=0):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    reinit = args.prune_type == "reinit"
    if args.save_dir:
        utils.checkdir(
            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{args.save_dir}/"
        )
        utils.checkdir(
            f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/"
        )
        utils.checkdir(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/"
        )
    else:
        utils.checkdir(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/")
        utils.checkdir(
            f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/")
        utils.checkdir(f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/")

    # Data Loader
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])
    if args.dataset == "mnist":
        traindataset = datasets.MNIST('../data',
                                      train=True,
                                      download=True,
                                      transform=transform)
        testdataset = datasets.MNIST('../data',
                                     train=False,
                                     transform=transform)
        from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet

    elif args.dataset == "cifar10":
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        traindataset = datasets.CIFAR10('../data',
                                        train=True,
                                        download=True,
                                        transform=transform_train)
        testdataset = datasets.CIFAR10('../data',
                                       train=False,
                                       transform=transform_test)
        from archs.cifar10 import AlexNet, LeNet5, fc1, vgg, resnet, densenet

    elif args.dataset == "fashionmnist":
        traindataset = datasets.FashionMNIST('../data',
                                             train=True,
                                             download=True,
                                             transform=transform)
        testdataset = datasets.FashionMNIST('../data',
                                            train=False,
                                            transform=transform)
        from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet

    elif args.dataset == "cifar100":
        traindataset = datasets.CIFAR100('../data',
                                         train=True,
                                         download=True,
                                         transform=transform)
        testdataset = datasets.CIFAR100('../data',
                                        train=False,
                                        transform=transform)
        from archs.cifar100 import AlexNet, fc1, LeNet5, vgg, resnet

    # If you want to add extra datasets paste here

    else:
        print("\nWrong Dataset choice \n")
        exit()

    if args.dataset == "cifar10":
        #trainsampler = torch.utils.data.RandomSampler(traindataset, replacement=True, num_samples=45000)  # 45K train dataset
        #train_loader = torch.utils.data.DataLoader(traindataset, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False, sampler=trainsampler)
        train_loader = torch.utils.data.DataLoader(traindataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=4)
    else:
        train_loader = torch.utils.data.DataLoader(traindataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=0,
                                                   drop_last=False)
    #train_loader = cycle(train_loader)
    test_loader = torch.utils.data.DataLoader(testdataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=4)

    # Importing Network Architecture

    # Initialize the Hessian dataloader (default batch_num = 1)
    for inputs, labels in train_loader:
        hessian_dataloader = (inputs, labels)
        break
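    # NOTE: only the first batch is captured above; it is the `dataloader` argument
    # passed to prune_by_percentile() when Hessian-aware pruning is enabled below.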

    global model
    if args.arch_type == "fc1":
        model = fc1.fc1().to(device)
    elif args.arch_type == "lenet5":
        model = LeNet5.LeNet5().to(device)
    elif args.arch_type == "alexnet":
        model = AlexNet.AlexNet().to(device)
    elif args.arch_type == "vgg16":
        model = vgg.vgg16().to(device)
    elif args.arch_type == "resnet18":
        model = resnet.resnet18().to(device)
    elif args.arch_type == "densenet121":
        model = densenet.densenet121().to(device)
    # If you want to add extra model paste here
    else:
        print("\nWrong Model choice\n")
        exit()

    model = nn.DataParallel(model)
    # Weight Initialization
    model.apply(weight_init)

    # Copying and Saving Initial State
    initial_state_dict = copy.deepcopy(model.state_dict())
    if args.save_dir:
        torch.save(
            model.state_dict(),
            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{args.save_dir}/initial_state_dict_{args.prune_type}.pth"
        )
    else:
        torch.save(
            model.state_dict(),
            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/initial_state_dict_{args.prune_type}.pth"
        )

    # global total_params
    total_params = 0
    # Layer Looper
    for name, param in model.named_parameters():
        print(name, param.size())
        total_params += param.numel()

    # Making Initial Mask
    make_mask(model, total_params)
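    # make_mask is assumed to populate the module-level `mask` list that the
    # pruning/reinitialization code below indexes as mask[step].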

    # Optimizer and Loss
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=1e-4)
    # warm up schedule; scheduler_warmup is chained with scheduler_steplr
    scheduler_steplr = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                            milestones=[0, 15],
                                                            gamma=0.1,
                                                            last_epoch=-1)
    if args.warmup:
        scheduler_warmup = GradualWarmupScheduler(
            optimizer,
            multiplier=1,
            total_epoch=50,
            after_scheduler=scheduler_steplr)  # 20K=(idx)56, 35K=70
    criterion = nn.CrossEntropyLoss()  # default was F.nll_loss; why were train and test using different losses?

    # Pruning
    # NOTE First Pruning Iteration is of No Compression
    bestacc = 0.0
    best_accuracy = 0
    ITERATION = args.prune_iterations
    comp = np.zeros(ITERATION, float)
    bestacc = np.zeros(ITERATION, float)
    step = 0
    all_loss = np.zeros(args.end_iter, float)
    all_accuracy = np.zeros(args.end_iter, float)

    for _ite in range(args.start_iter, ITERATION):
        if _ite != 0:
            prune_by_percentile(args.prune_percent,
                                resample=resample,
                                reinit=reinit,
                                total_params=total_params,
                                hessian_aware=args.hessian,
                                criterion=criterion,
                                dataloader=hessian_dataloader,
                                cuda=torch.cuda.is_available())
            if reinit:
                model.apply(weight_init)
                #if args.arch_type == "fc1":
                #    model = fc1.fc1().to(device)
                #elif args.arch_type == "lenet5":
                #    model = LeNet5.LeNet5().to(device)
                #elif args.arch_type == "alexnet":
                #    model = AlexNet.AlexNet().to(device)
                #elif args.arch_type == "vgg16":
                #    model = vgg.vgg16().to(device)
                #elif args.arch_type == "resnet18":
                #    model = resnet.resnet18().to(device)
                #elif args.arch_type == "densenet121":
                #    model = densenet.densenet121().to(device)
                #else:
                #    print("\nWrong Model choice\n")
                #    exit()
                step = 0
                for name, param in model.named_parameters():
                    if 'weight' in name:
                        param_frac = param.numel() / total_params
                        if param_frac > 0.01:
                            weight_dev = param.device
                            param.data = torch.from_numpy(
                                param.data.cpu().numpy() *
                                mask[step]).to(weight_dev)
                            step = step + 1
                step = 0
            else:
                original_initialization(mask, initial_state_dict, total_params)
            # optimizer = torch.optim.SGD([{'params': model.parameters(), 'initial_lr': 0.03}], lr=args.lr, momentum=0.9, weight_decay=1e-4)
            # scheduler_steplr = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[0, 14], gamma=0.1, last_epoch=-1)
            # scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=56, after_scheduler=scheduler_steplr)  # 20K=(idx)56, 35K=70
        print(f"\n--- Pruning Level [{ITE}:{_ite}/{ITERATION}]: ---")

        # Optimizer and Loss
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
        # warm up schedule; scheduler_warmup is chained with scheduler_steplr
        scheduler_steplr = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[0, 15], gamma=0.1, last_epoch=-1)
        if args.warmup:
            scheduler_warmup = GradualWarmupScheduler(
                optimizer,
                multiplier=1,
                total_epoch=50,
                after_scheduler=scheduler_steplr)  # 20K=(idx)56, 35K=70
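        # The optimizer and (optional) warmup/step schedulers are rebuilt at every
        # pruning level, so each level retrains from the same learning-rate schedule.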

        # Print the table of Nonzeros in each layer
        comp1 = utils.print_nonzeros(model)
        comp[_ite] = comp1
        pbar = tqdm(range(args.end_iter))  # process bar

        for iter_ in pbar:

            # Frequency for Testing
            if iter_ % args.valid_freq == 0:
                accuracy = test(model, test_loader, criterion)

                # Save Weights for each _ite
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    if args.save_dir:
                        torch.save(
                            model.state_dict(),
                            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{args.save_dir}/{_ite}_model_{args.prune_type}.pth"
                        )
                    else:
                        # torch.save(model,f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}.pth")
                        torch.save(
                            model.state_dict(),
                            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}.pth"
                        )

            # Training
            loss = train(model, train_loader, optimizer, criterion,
                         total_params)
            all_loss[iter_] = loss
            all_accuracy[iter_] = accuracy

            # warm up
            if args.warmup:
                scheduler_warmup.step()
            _lr = optimizer.param_groups[0]['lr']

            # Save the model during training
            if args.save_freq > 0 and iter_ % args.save_freq == 0:
                if args.save_dir:
                    torch.save(
                        model.state_dict(),
                        f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{args.save_dir}/{_ite}_model_{args.prune_type}_epoch{iter_}.pth"
                    )
                else:
                    torch.save(
                        model.state_dict(),
                        f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}_epoch{iter_}.pth"
                    )

            # Frequency for Printing Accuracy and Loss
            if iter_ % args.print_freq == 0:
                pbar.set_description(
                    f'Train Epoch: {iter_}/{args.end_iter} Loss: {loss:.6f} Accuracy: {accuracy:.2f}% Best Accuracy: {best_accuracy:.2f}% Learning Rate: {_lr:.6f}'
                )

        writer.add_scalar('Accuracy/test', best_accuracy, comp1)
        bestacc[_ite] = best_accuracy

        # Plotting Loss (Training), Accuracy (Testing), Iteration Curve
        # NOTE: Loss is computed every iteration, while accuracy is only computed every
        # {args.valid_freq} iterations, so the saved accuracy stays constant in between.
        # NOTE: The loss curve is rescaled to [0, 100] for ease of plotting.
        plt.plot(np.arange(1, (args.end_iter) + 1),
                 100 * (all_loss - np.min(all_loss)) /
                 np.ptp(all_loss).astype(float),
                 c="blue",
                 label="Loss")
        plt.plot(np.arange(1, (args.end_iter) + 1),
                 all_accuracy,
                 c="red",
                 label="Accuracy")
        plt.title(
            f"Loss Vs Accuracy Vs Iterations ({args.dataset},{args.arch_type})"
        )
        plt.xlabel("Iterations")
        plt.ylabel("Loss and Accuracy")
        plt.legend()
        plt.grid(color="gray")
        if args.save_dir:
            plt.savefig(
                f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_LossVsAccuracy_{comp1}.png",
                dpi=1200)
        else:
            plt.savefig(
                f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_LossVsAccuracy_{comp1}.png",
                dpi=1200)
        plt.close()

        # Dump Plot values
        if args.save_dir:
            all_loss.dump(
                f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_all_loss_{comp1}.dat"
            )
            all_accuracy.dump(
                f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_all_accuracy_{comp1}.dat"
            )
        else:
            all_loss.dump(
                f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_loss_{comp1}.dat"
            )
            all_accuracy.dump(
                f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_accuracy_{comp1}.dat"
            )

        # Dumping mask
        if args.save_dir:
            with open(
                    f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_mask_{comp1}.pkl",
                    'wb') as fp:
                pickle.dump(mask, fp)
        else:
            with open(
                    f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_mask_{comp1}.pkl",
                    'wb') as fp:
                pickle.dump(mask, fp)

        # Making variables into 0
        best_accuracy = 0
        all_loss = np.zeros(args.end_iter, float)
        all_accuracy = np.zeros(args.end_iter, float)

    # Dumping Values for Plotting
    if args.save_dir:
        comp.dump(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_compression.dat"
        )
        bestacc.dump(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_bestaccuracy.dat"
        )
    else:
        comp.dump(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_compression.dat"
        )
        bestacc.dump(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_bestaccuracy.dat"
        )
    # Plotting
    a = np.arange(args.prune_iterations)
    plt.plot(a, bestacc, c="blue", label="Winning tickets")
    plt.title(
        f"Test Accuracy vs Unpruned Weights Percentage ({args.dataset},{args.arch_type})"
    )
    plt.xlabel("Unpruned Weights Percentage")
    plt.ylabel("test accuracy")
    plt.xticks(a, comp, rotation="vertical")
    plt.ylim(0, 100)
    plt.legend()
    plt.grid(color="gray")
    if args.save_dir:
        plt.savefig(
            f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_AccuracyVsWeights.png",
            dpi=1200)
    else:
        plt.savefig(
            f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_AccuracyVsWeights.png",
            dpi=1200)
    plt.close()
Example no. 3
def main(seed):
    with timer('load data'):
        df = pd.read_csv(FOLD_PATH)
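        # FOLD_PATH, FOLD_ID, IMG_DIR, IMG_SIZE, BATCH_SIZE, EPOCHS, CLR_CYCLE, EMA,
        # LOGGER, device, base_model, base_model_ema, base_ckpt, etc. are assumed to be
        # module-level constants/configuration defined elsewhere in the script.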
        y1 = (df.EncodedPixels_1 != "-1").astype("float32").values.reshape(
            -1, 1)
        y2 = (df.EncodedPixels_2 != "-1").astype("float32").values.reshape(
            -1, 1)
        y3 = (df.EncodedPixels_3 != "-1").astype("float32").values.reshape(
            -1, 1)
        y4 = (df.EncodedPixels_4 != "-1").astype("float32").values.reshape(
            -1, 1)
        y = np.concatenate([y1, y2, y3, y4], axis=1)
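        # y is an (N, 4) binary matrix with one column per defect class: 1 where the
        # corresponding EncodedPixels column holds a mask, 0 where it is "-1" (empty).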
        #y = (df.sum_target != 0).astype("float32").values

    with timer('preprocessing'):
        train_df, val_df = df[df.fold_id != FOLD_ID], df[df.fold_id == FOLD_ID]
        y_train, y_val = y[df.fold_id != FOLD_ID], y[df.fold_id == FOLD_ID]

        train_augmentation = Compose([
            Flip(p=0.5),
            OneOf([
                GridDistortion(p=0.5),
                OpticalDistortion(p=0.5, distort_limit=2, shift_limit=0.5)
            ],
                  p=0.5),
            OneOf([
                RandomGamma(gamma_limit=(100, 140), p=0.5),
                RandomBrightnessContrast(p=0.5),
                RandomBrightness(p=0.5),
                RandomContrast(p=0.5)
            ],
                  p=0.5),
            OneOf([
                GaussNoise(p=0.5),
                Cutout(num_holes=10, max_h_size=10, max_w_size=20, p=0.5)
            ],
                  p=0.5),
            ShiftScaleRotate(rotate_limit=20, p=0.5),
        ])
        val_augmentation = None

        train_dataset = SeverCLSDataset(train_df,
                                        IMG_DIR,
                                        IMG_SIZE,
                                        N_CLASSES,
                                        y_train,
                                        id_colname=ID_COLUMNS,
                                        transforms=train_augmentation,
                                        crop_rate=1.0)
        val_dataset = SeverCLSDataset(val_df,
                                      IMG_DIR,
                                      IMG_SIZE,
                                      N_CLASSES,
                                      y_val,
                                      id_colname=ID_COLUMNS,
                                      transforms=val_augmentation)
        #train_sampler = MaskProbSampler(train_df, demand_non_empty_proba=0.6)
        train_loader = DataLoader(train_dataset,
                                  batch_size=BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=8,
                                  pin_memory=True)
        val_loader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=False,
                                num_workers=8,
                                pin_memory=True)

        del train_df, val_df, df, train_dataset, val_dataset
        gc.collect()

    with timer('create model'):
        model = ResNet(num_classes=N_CLASSES,
                       pretrained="imagenet",
                       net_cls=models.resnet50)
        #model = convert_model(model)
        if base_model is not None:
            model.load_state_dict(torch.load(base_model))
        model.to(device)

        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, eps=1e-4)
        if base_model is None:
            scheduler_cosine = CosineAnnealingLR(optimizer,
                                                 T_max=CLR_CYCLE,
                                                 eta_min=3e-5)
            scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=1.1,
                total_epoch=CLR_CYCLE * 2,
                after_scheduler=scheduler_cosine)
        else:
            scheduler = CosineAnnealingLR(optimizer,
                                          T_max=CLR_CYCLE,
                                          eta_min=3e-5)

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)

        if EMA:
            ema_model = copy.deepcopy(model)
            if base_model_ema is not None:
                ema_model.load_state_dict(torch.load(base_model_ema))
            ema_model.to(device)
        else:
            ema_model = None
        model = torch.nn.DataParallel(model)
        if EMA:
            # wrap the EMA copy only when it exists (it is None when EMA is disabled)
            ema_model = torch.nn.DataParallel(ema_model)

    with timer('train'):
        train_losses = []
        valid_losses = []

        best_model_loss = 999
        best_model_ema_loss = 999
        best_model_ep = 0
        ema_decay = 0
        checkpoint = base_ckpt + 1

        for epoch in range(1, EPOCHS + 1):
            seed = seed + epoch
            seed_torch(seed)

            if epoch >= EMA_START:
                ema_decay = 0.99

            LOGGER.info("Starting {} epoch...".format(epoch))
            tr_loss = train_one_epoch(model,
                                      train_loader,
                                      criterion,
                                      optimizer,
                                      device,
                                      ema_model=ema_model,
                                      ema_decay=ema_decay)
            train_losses.append(tr_loss)
            LOGGER.info('Mean train loss: {}'.format(round(tr_loss, 5)))

            valid_loss, y_pred, y_true = validate(model, val_loader, criterion,
                                                  device)
            valid_losses.append(valid_loss)
            LOGGER.info('Mean valid loss: {}'.format(round(valid_loss, 5)))

            if EMA and epoch >= EMA_START:
                ema_valid_loss, y_pred_ema, _ = validate(
                    ema_model, val_loader, criterion, device)
                LOGGER.info('Mean EMA valid loss: {}'.format(
                    round(ema_valid_loss, 5)))

                if ema_valid_loss < best_model_ema_loss:
                    torch.save(
                        ema_model.module.state_dict(),
                        'models/{}_fold{}_ckpt{}_ema.pth'.format(
                            EXP_ID, FOLD_ID, checkpoint))
                    best_model_ema_loss = ema_valid_loss
                    np.save("y_pred_ema_ckpt{}.npy".format(checkpoint),
                            y_pred_ema)

            scheduler.step()

            if valid_loss < best_model_loss:
                torch.save(
                    model.module.state_dict(),
                    'models/{}_fold{}_ckpt{}.pth'.format(
                        EXP_ID, FOLD_ID, checkpoint))
                np.save("y_pred_ckpt{}.npy".format(checkpoint), y_pred)
                best_model_loss = valid_loss
                best_model_ep = epoch
                #np.save("val_pred.npy", val_pred)

            if epoch % (CLR_CYCLE * 2) == CLR_CYCLE * 2 - 1:
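                # End of a warm-restart cycle (every 2 * CLR_CYCLE epochs): snapshot the
                # latest weights, bump the checkpoint index and reset the best-loss trackers.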
                torch.save(
                    model.module.state_dict(),
                    'models/{}_fold{}_latest.pth'.format(EXP_ID, FOLD_ID))
                LOGGER.info('Best valid loss: {} on epoch={}'.format(
                    round(best_model_loss, 5), best_model_ep))
                if EMA:
                    torch.save(
                        ema_model.module.state_dict(),
                        'models/{}_fold{}_latest_ema.pth'.format(
                            EXP_ID, FOLD_ID))
                    LOGGER.info('Best ema valid loss: {}'.format(
                        round(best_model_ema_loss, 5)))
                checkpoint += 1
                best_model_loss = 999
                best_model_ema_loss = 999

            #del val_pred
            gc.collect()

    LOGGER.info('Best valid loss: {} on epoch={}'.format(
        round(best_model_loss, 5), best_model_ep))

    xs = list(range(1, len(train_losses) + 1))
    plt.plot(xs, train_losses, label='Train loss')
    plt.plot(xs, valid_losses, label='Val loss')
    plt.legend()
    plt.xticks(xs)
    plt.xlabel('Epochs')
    plt.savefig("loss.png")
def main(seed):
    with timer('load data'):
        df = pd.read_csv(FOLD_PATH)
        df.drop("EncodedPixels_2", axis=1, inplace=True)
        df = df.rename(columns={"EncodedPixels_3": "EncodedPixels_2"})
        df = df.rename(columns={"EncodedPixels_4": "EncodedPixels_3"})
        y1 = (df.EncodedPixels_1 != "-1").astype("float32").values.reshape(
            -1, 1)
        y2 = (df.EncodedPixels_2 != "-1").astype("float32").values.reshape(
            -1, 1)
        y3 = (df.EncodedPixels_3 != "-1").astype("float32").values.reshape(
            -1, 1)
        #y4 = (df.EncodedPixels_4 != "-1").astype("float32").values.reshape(-1, 1)
        y = np.concatenate([y1, y2, y3], axis=1)
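        # Class 2 is dropped and classes 3/4 are shifted down above, so y has only
        # three columns in this variant.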

    with timer('preprocessing'):
        train_df, val_df = df[df.fold_id != FOLD_ID], df[df.fold_id == FOLD_ID]
        y_train, y_val = y[df.fold_id != FOLD_ID], y[df.fold_id == FOLD_ID]

        train_augmentation = Compose([
            Flip(p=0.5),
            OneOf([
                GridDistortion(p=0.5),
                OpticalDistortion(p=0.5, distort_limit=2, shift_limit=0.5)
            ],
                  p=0.5),
            OneOf([
                RandomGamma(gamma_limit=(100, 140), p=0.5),
                RandomBrightnessContrast(p=0.5),
            ],
                  p=0.5),
            OneOf([
                GaussNoise(p=0.5),
            ], p=0.5),
            ShiftScaleRotate(rotate_limit=20, p=0.5),
        ])
        val_augmentation = None

        train_dataset = SeverDataset(train_df,
                                     IMG_DIR,
                                     IMG_SIZE,
                                     N_CLASSES,
                                     id_colname=ID_COLUMNS,
                                     transforms=train_augmentation,
                                     crop_rate=1.0,
                                     class_y=y_train)
        val_dataset = SeverDataset(val_df,
                                   IMG_DIR,
                                   IMG_SIZE,
                                   N_CLASSES,
                                   id_colname=ID_COLUMNS,
                                   transforms=val_augmentation)
        train_sampler = MaskProbSampler(train_df, demand_non_empty_proba=0.6)
        train_loader = DataLoader(train_dataset,
                                  batch_size=BATCH_SIZE,
                                  sampler=train_sampler,
                                  num_workers=8)
        val_loader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=False,
                                num_workers=8)

        del train_df, val_df, df, train_dataset, val_dataset
        gc.collect()

    with timer('create model'):
        model = smp.Unet('resnet34',
                         encoder_weights="imagenet",
                         classes=N_CLASSES,
                         encoder_se_module=True,
                         decoder_semodule=True,
                         h_columns=False,
                         skip=True,
                         act="swish",
                         freeze_bn=True,
                         classification=CLASSIFICATION,
                         attention_type="cbam",
                         center=True,
                         mode="train")
        model = convert_model(model)
        if base_model is not None:
            model.load_state_dict(torch.load(base_model))
        model.to(device)

        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam([
            {
                'params': model.decoder.parameters(),
                'lr': 3e-3
            },
            {
                'params': model.encoder.parameters(),
                'lr': 3e-4
            },
        ])
        if base_model is None:
            scheduler_cosine = CosineAnnealingLR(optimizer,
                                                 T_max=CLR_CYCLE,
                                                 eta_min=3e-5)
            scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=1.1,
                total_epoch=CLR_CYCLE * 2,
                after_scheduler=scheduler_cosine)
        else:
            scheduler = CosineAnnealingLR(optimizer,
                                          T_max=CLR_CYCLE,
                                          eta_min=3e-5)

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)

        if EMA:
            ema_model = copy.deepcopy(model)
            if base_model_ema is not None:
                ema_model.load_state_dict(torch.load(base_model_ema))
            ema_model.to(device)
            ema_model = torch.nn.DataParallel(ema_model)
        else:
            ema_model = None
        model = torch.nn.DataParallel(model)

    with timer('train'):
        train_losses = []
        valid_losses = []

        best_model_loss = 999
        best_model_ema_loss = 999
        best_model_ep = 0
        ema_decay = 0
        checkpoint = base_ckpt + 1

        for epoch in range(1, EPOCHS + 1):
            seed = seed + epoch
            seed_torch(seed)

            if epoch >= EMA_START:
                ema_decay = 0.99

            LOGGER.info("Starting {} epoch...".format(epoch))
            tr_loss = train_one_epoch(model,
                                      train_loader,
                                      criterion,
                                      optimizer,
                                      device,
                                      cutmix_prob=0.0,
                                      classification=CLASSIFICATION,
                                      ema_model=ema_model,
                                      ema_decay=ema_decay)
            train_losses.append(tr_loss)
            LOGGER.info('Mean train loss: {}'.format(round(tr_loss, 5)))

            valid_loss = validate(model,
                                  val_loader,
                                  criterion,
                                  device,
                                  classification=CLASSIFICATION)
            valid_losses.append(valid_loss)
            LOGGER.info('Mean valid loss: {}'.format(round(valid_loss, 5)))

            if EMA and epoch >= EMA_START:
                ema_valid_loss = validate(ema_model,
                                          val_loader,
                                          criterion,
                                          device,
                                          classification=CLASSIFICATION)
                LOGGER.info('Mean EMA valid loss: {}'.format(
                    round(ema_valid_loss, 5)))

                if ema_valid_loss < best_model_ema_loss:
                    torch.save(
                        ema_model.module.state_dict(),
                        'models/{}_fold{}_ckpt{}_ema.pth'.format(
                            EXP_ID, FOLD_ID, checkpoint))
                    best_model_ema_loss = ema_valid_loss

            scheduler.step()

            if valid_loss < best_model_loss:
                torch.save(
                    model.module.state_dict(),
                    'models/{}_fold{}_ckpt{}.pth'.format(
                        EXP_ID, FOLD_ID, checkpoint))
                best_model_loss = valid_loss
                best_model_ep = epoch
                #np.save("val_pred.npy", val_pred)

            if epoch % (CLR_CYCLE * 2) == CLR_CYCLE * 2 - 1:
                torch.save(
                    model.module.state_dict(),
                    'models/{}_fold{}_latest.pth'.format(EXP_ID, FOLD_ID))
                LOGGER.info('Best valid loss: {} on epoch={}'.format(
                    round(best_model_loss, 5), best_model_ep))
                if EMA:
                    torch.save(
                        ema_model.module.state_dict(),
                        'models/{}_fold{}_latest_ema.pth'.format(
                            EXP_ID, FOLD_ID))
                    LOGGER.info('Best ema valid loss: {}'.format(
                        round(best_model_ema_loss, 5)))
                    best_model_ema_loss = 999
                checkpoint += 1
                best_model_loss = 999

            #del val_pred
            gc.collect()

    LOGGER.info('Best valid loss: {} on epoch={}'.format(
        round(best_model_loss, 5), best_model_ep))

    xs = list(range(1, len(train_losses) + 1))
    plt.plot(xs, train_losses, label='Train loss')
    plt.plot(xs, valid_losses, label='Val loss')
    plt.legend()
    plt.xticks(xs)
    plt.xlabel('Epochs')
    plt.savefig("loss.png")
Example no. 5
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.pth')
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001)
cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 250, eta_min=0, last_epoch=-1)
scheduler = GradualWarmupScheduler(optimizer, multiplier=10, total_epoch=2, after_scheduler=cosine_scheduler)
# optimizer = optim.Adam(net.parameters())
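# With the usual GradualWarmupScheduler implementation, multiplier=10 ramps the LR from
# the base value 0.1 up to 1.0 over the first 2 epochs, after which the 250-epoch
# cosine annealing schedule takes over.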


# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
def main(seed):
    with timer('load data'):
        df = pd.read_csv(FOLD_PATH)

    with timer('preprocessing'):
        train_df, val_df = df[df.fold_id != FOLD_ID], df[df.fold_id == FOLD_ID]

        train_augmentation = Compose([
            Flip(p=0.5),
            OneOf([
                GridDistortion(p=0.5),
                OpticalDistortion(p=0.5, distort_limit=2, shift_limit=0.5)
            ],
                  p=0.5),
            OneOf([
                RandomGamma(gamma_limit=(100, 140), p=0.5),
                RandomBrightnessContrast(p=0.5),
                RandomBrightness(p=0.5),
                RandomContrast(p=0.5)
            ],
                  p=0.5),
            OneOf([
                GaussNoise(p=0.5),
                Cutout(num_holes=10, max_h_size=10, max_w_size=20, p=0.5)
            ],
                  p=0.5),
            ShiftScaleRotate(rotate_limit=20, p=0.5),
        ])
        val_augmentation = None

        train_dataset = SeverDataset(train_df,
                                     IMG_DIR,
                                     IMG_SIZE,
                                     N_CLASSES,
                                     id_colname=ID_COLUMNS,
                                     transforms=train_augmentation,
                                     crop_rate=1.0)
        val_dataset = SeverDataset(val_df,
                                   IMG_DIR,
                                   IMG_SIZE,
                                   N_CLASSES,
                                   id_colname=ID_COLUMNS,
                                   transforms=val_augmentation)
        train_sampler = MaskProbSampler(train_df, demand_non_empty_proba=0.6)
        train_loader = DataLoader(train_dataset,
                                  batch_size=BATCH_SIZE,
                                  sampler=train_sampler,
                                  num_workers=8)
        val_loader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=False,
                                num_workers=8)

        del train_df, val_df, df, train_dataset, val_dataset
        gc.collect()

    with timer('create model'):
        model = smp.Unet('se_resnext50_32x4d',
                         encoder_weights="imagenet",
                         classes=N_CLASSES,
                         encoder_se_module=True,
                         decoder_semodule=True,
                         h_columns=False,
                         skip=True,
                         act="swish")
        model = convert_model(model)
        if base_model is not None:
            model.load_state_dict(torch.load(base_model))
        model.to(device)

        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
        if base_model is None:
            scheduler_cosine = CosineAnnealingLR(optimizer,
                                                 T_max=CLR_CYCLE,
                                                 eta_min=3e-5)
            scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=1.1,
                total_epoch=CLR_CYCLE * 2,
                after_scheduler=scheduler_cosine)
        else:
            scheduler = CosineAnnealingLR(optimizer,
                                          T_max=CLR_CYCLE,
                                          eta_min=3e-5)

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)
        model = torch.nn.DataParallel(model)

    with timer('train'):
        train_losses = []
        valid_losses = []

        best_model_loss = 999
        best_model_ep = 0
        checkpoint = base_ckpt + 1

        for epoch in range(1, EPOCHS + 1):
            seed = seed + epoch
            seed_torch(seed)

            LOGGER.info("Starting {} epoch...".format(epoch))
            tr_loss = train_one_epoch(model,
                                      train_loader,
                                      criterion,
                                      optimizer,
                                      device,
                                      cutmix_prob=0.0)
            train_losses.append(tr_loss)
            LOGGER.info('Mean train loss: {}'.format(round(tr_loss, 5)))

            valid_loss = validate(model, val_loader, criterion, device)
            valid_losses.append(valid_loss)
            LOGGER.info('Mean valid loss: {}'.format(round(valid_loss, 5)))

            scheduler.step()

            if valid_loss < best_model_loss:
                torch.save(
                    model.module.state_dict(),
                    'models/{}_fold{}_ckpt{}.pth'.format(
                        EXP_ID, FOLD_ID, checkpoint))
                best_model_loss = valid_loss
                best_model_ep = epoch
                #np.save("val_pred.npy", val_pred)

            if epoch % (CLR_CYCLE * 2) == CLR_CYCLE * 2 - 1:
                torch.save(
                    model.module.state_dict(),
                    'models/{}_fold{}_latest.pth'.format(EXP_ID, FOLD_ID))
                LOGGER.info('Best valid loss: {} on epoch={}'.format(
                    round(best_model_loss, 5), best_model_ep))
                checkpoint += 1
                best_model_loss = 999

            #del val_pred
            gc.collect()

    LOGGER.info('Best valid loss: {} on epoch={}'.format(
        round(best_model_loss, 5), best_model_ep))

    xs = list(range(1, len(train_losses) + 1))
    plt.plot(xs, train_losses, label='Train loss')
    plt.plot(xs, valid_losses, label='Val loss')
    plt.legend()
    plt.xticks(xs)
    plt.xlabel('Epochs')
    plt.savefig("loss.png")
Example no. 7
    if hp.with_lars:
        #     optimizer = SGD_with_lars(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay, trust_coef=hp.trust_coef)
        optimizer = SGD_with_lars_ver2(net.parameters(),
                                       lr=hp.lr,
                                       momentum=hp.momentum,
                                       weight_decay=hp.weight_decay,
                                       trust_coef=hp.trust_coef)
    else:
        #     optimizer = SGD_without_lars(net.parameters(), lr=hp.lr, momentum=hp.momentum, weight_decay=hp.weight_decay)
        optimizer = optim.SGD(net.parameters(),
                              lr=hp.lr,
                              momentum=hp.momentum,
                              weight_decay=hp.weight_decay)

    warmup_scheduler = GradualWarmupScheduler(optimizer=optimizer,
                                              multiplier=hp.warmup_multiplier,
                                              total_epoch=hp.warmup_epoch)
    poly_decay_scheduler = PolynomialLRDecay(
        optimizer=optimizer,
        max_decay_steps=hp.max_decay_epoch * len(trainloader),
        end_learning_rate=hp.end_learning_rate,
        power=2.0)  # poly(2)
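    # max_decay_steps is scaled by len(trainloader), so the polynomial decay is
    # presumably stepped once per batch, while the warmup scheduler counts epochs.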

    # Training
    def train(epoch):
        global train_total
        global train_correct
        global time_to_train
        net.train()
        train_loss = 0
        correct = 0
Example no. 8
def main(seed):
    with timer('load data'):
        df = pd.read_csv(FOLD_PATH)
        soft_df = pd.read_csv(SOFT_PATH)
        df = df.append(pd.read_csv(PSEUDO_PATH)).reset_index(drop=True)
        soft_df = soft_df.append(
            pd.read_csv(PSEUDO_PATH)).reset_index(drop=True)
        soft_df = df[[ID_COLUMNS]].merge(soft_df, how="left", on=ID_COLUMNS)
        LOGGER.info(df.head())
        LOGGER.info(soft_df.head())
        for c in [
                "EncodedPixels_1", "EncodedPixels_2", "EncodedPixels_3",
                "EncodedPixels_4"
        ]:
            df[c] = df[c].astype(str)
            soft_df[c] = soft_df[c].astype(str)
        df["fold_id"] = df["fold_id"].fillna(FOLD_ID + 1)
        y = (df.sum_target != 0).astype("float32").values
        y += (soft_df.sum_target != 0).astype("float32").values
        y = y / 2
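        # Rows appended from PSEUDO_PATH presumably have no fold_id, so fillna assigns
        # them FOLD_ID + 1 and they only ever land in the training split below.
        # Averaging the hard fold targets with the pseudo-label targets gives
        # image-level soft labels in {0, 0.5, 1}, which are passed to SeverDataset as
        # class_y for the CLASSIFICATION branch.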

    with timer('preprocessing'):
        train_df, val_df = df[df.fold_id != FOLD_ID], df[df.fold_id == FOLD_ID]
        train_soft_df, val_soft_df = soft_df[df.fold_id != FOLD_ID], soft_df[
            df.fold_id == FOLD_ID]
        y_train, y_val = y[df.fold_id != FOLD_ID], y[df.fold_id == FOLD_ID]

        train_augmentation = Compose([
            Flip(p=0.5),
            OneOf([
                GridDistortion(p=0.5),
                OpticalDistortion(p=0.5, distort_limit=2, shift_limit=0.5)
            ],
                  p=0.5),
            OneOf([
                RandomGamma(gamma_limit=(100, 140), p=0.5),
                RandomBrightnessContrast(p=0.5),
            ],
                  p=0.5),
            OneOf([
                GaussNoise(p=0.5),
            ], p=0.5),
            ShiftScaleRotate(rotate_limit=20, p=0.5),
        ])
        val_augmentation = None

        train_dataset = SeverDataset(train_df,
                                     IMG_DIR,
                                     IMG_SIZE,
                                     N_CLASSES,
                                     id_colname=ID_COLUMNS,
                                     transforms=train_augmentation,
                                     crop_rate=1.0,
                                     class_y=y_train,
                                     soft_df=train_soft_df)
        val_dataset = SeverDataset(val_df,
                                   IMG_DIR,
                                   IMG_SIZE,
                                   N_CLASSES,
                                   id_colname=ID_COLUMNS,
                                   transforms=val_augmentation,
                                   soft_df=val_soft_df)
        train_sampler = MaskProbSampler(train_df, demand_non_empty_proba=0.6)
        train_loader = DataLoader(train_dataset,
                                  batch_size=BATCH_SIZE,
                                  sampler=train_sampler,
                                  num_workers=8)
        val_loader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=False,
                                num_workers=8)

        del train_df, val_df, df, train_dataset, val_dataset
        gc.collect()

    with timer('create model'):
        model = smp_old.Unet('resnet34',
                             encoder_weights="imagenet",
                             classes=N_CLASSES,
                             encoder_se_module=True,
                             decoder_semodule=True,
                             h_columns=False,
                             skip=True,
                             act="swish",
                             freeze_bn=True,
                             classification=CLASSIFICATION)
        model = convert_model(model)
        if base_model is not None:
            model.load_state_dict(torch.load(base_model))
        model.to(device)

        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam([
            {
                'params': model.decoder.parameters(),
                'lr': 3e-3
            },
            {
                'params': model.encoder.parameters(),
                'lr': 3e-4
            },
        ])
        if base_model is None:
            scheduler_cosine = CosineAnnealingLR(optimizer,
                                                 T_max=CLR_CYCLE,
                                                 eta_min=3e-5)
            scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=1.1,
                total_epoch=CLR_CYCLE * 2,
                after_scheduler=scheduler_cosine)
        else:
            scheduler = CosineAnnealingLR(optimizer,
                                          T_max=CLR_CYCLE,
                                          eta_min=3e-5)

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)
        model = torch.nn.DataParallel(model)

    with timer('train'):
        train_losses = []
        valid_losses = []

        best_model_loss = 999
        best_model_ep = 0
        checkpoint = base_ckpt + 1

        for epoch in range(1, EPOCHS + 1):
            seed = seed + epoch
            seed_torch(seed)

            LOGGER.info("Starting {} epoch...".format(epoch))
            tr_loss = train_one_epoch(model,
                                      train_loader,
                                      criterion,
                                      optimizer,
                                      device,
                                      cutmix_prob=0.0,
                                      classification=CLASSIFICATION)
            train_losses.append(tr_loss)
            LOGGER.info('Mean train loss: {}'.format(round(tr_loss, 5)))

            valid_loss, val_score = validate(model,
                                             val_loader,
                                             criterion,
                                             device,
                                             classification=CLASSIFICATION)
            valid_losses.append(valid_loss)
            LOGGER.info('Mean valid loss: {}'.format(round(valid_loss, 5)))
            LOGGER.info('Mean valid score: {}'.format(round(val_score, 5)))

            scheduler.step()

            if valid_loss < best_model_loss:
                torch.save(
                    model.module.state_dict(),
                    'models/{}_fold{}_ckpt{}.pth'.format(
                        EXP_ID, FOLD_ID, checkpoint))
                best_model_loss = valid_loss
                best_model_ep = epoch
                #np.save("val_pred.npy", val_pred)

            if epoch % (CLR_CYCLE * 2) == CLR_CYCLE * 2 - 1:
                torch.save(
                    model.module.state_dict(),
                    'models/{}_fold{}_latest.pth'.format(EXP_ID, FOLD_ID))
                LOGGER.info('Best valid loss: {} on epoch={}'.format(
                    round(best_model_loss, 5), best_model_ep))
                checkpoint += 1
                best_model_loss = 999

            #del val_pred
            gc.collect()

    LOGGER.info('Best valid loss: {} on epoch={}'.format(
        round(best_model_loss, 5), best_model_ep))

    xs = list(range(1, len(train_losses) + 1))
    plt.plot(xs, train_losses, label='Train loss')
    plt.plot(xs, valid_losses, label='Val loss')
    plt.legend()
    plt.xticks(xs)
    plt.xlabel('Epochs')
    plt.savefig("loss.png")
Example no. 9
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet
from scheduler import GradualWarmupScheduler

model = fc1.fc1()
optimizer = optim.SGD(params=model.parameters(), lr=0.05)
scheduler_steplr = lr_scheduler.MultiStepLR(
    optimizer, milestones=[5, 10], gamma=0.1)  # means multistep will start with 0
scheduler_warmup = GradualWarmupScheduler(
    optimizer, multiplier=1, total_epoch=10,
    after_scheduler=scheduler_steplr)  # 20K=(idx)56, 35K=70

plt.figure()
x = list(range(20))
y = []

for epoch in range(20):
    scheduler_warmup.step()
    lr = scheduler_warmup.get_lr()
    print(epoch, lr[0])
    y.append(lr[0])

plt.plot(x, y)
plt.show()
Example no. 10
    # Use the nn package to define our model and loss function.
    model = torch.nn.Sequential(
        torch.nn.Linear(D_in, H),
        torch.nn.ReLU(),
        torch.nn.Linear(H, D_out),
    )

    #v = torch.zeros(10)
    optim = torch.optim.SGD(model.parameters(), lr=0.01)
    max_epoch = 100
    #optim = torch.optim.SGD([v], lr=0.01)
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim, max_epoch)
    scheduler = GradualWarmupScheduler(optimizer=optim,
                                       multiplier=8,
                                       total_epoch=10,
                                       after_scheduler=scheduler_cosine)
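    # With the usual GradualWarmupScheduler implementation, multiplier=8 ramps the LR
    # from the base value 0.01 up to 0.08 over the first 10 epochs, then hands off to
    # the cosine annealing schedule.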

    x = []
    y = []
    for epoch in range(1, max_epoch):
        scheduler.step(epoch)
        x.append(epoch)
        y.append(optim.param_groups[0]['lr'])
        print(optim.param_groups[0]['lr'])
        #print(epoch, optim.param_groups[0]['lr'])

    #fig = plt.figure()
    #fig.plot(x,y)
    plt.scatter(x, y, color='red')
    plt.show()
Example no. 11
class CTTTrainer(WandBMixin, IOMixin, BaseExperiment):
    WANDB_PROJECT = "ctt"

    def __init__(self):
        super(CTTTrainer, self).__init__()
        self.auto_setup()
        self._build()

    def _build(self):
        self._build_loaders()
        self._build_model()
        self._build_criteria_and_optim()
        self._build_scheduler()

    def _build_model(self):
        self.model: nn.Module = to_device(
            ContactTracingTransformer(**self.get("model/kwargs", {})),
            self.device)

    def _build_loaders(self):
        train_path = self.get("data/paths/train", ensure_exists=True)
        validate_path = self.get("data/paths/validate", ensure_exists=True)
        self.train_loader = get_dataloader(path=train_path,
                                           **self.get("data/loader_kwargs",
                                                      ensure_exists=True))
        self.validate_loader = get_dataloader(path=validate_path,
                                              **self.get("data/loader_kwargs",
                                                         ensure_exists=True))

    def _build_criteria_and_optim(self):
        # noinspection PyArgumentList
        self.loss = WeightedSum.from_config(
            self.get("losses", ensure_exists=True))
        self.optim = torch.optim.Adam(self.model.parameters(),
                                      **self.get("optim/kwargs"))
        self.metrics = Metrics()

    def _build_scheduler(self):
        if self.get("scheduler/use", False):
            self._base_scheduler = CosineAnnealingLR(
                self.optim,
                T_max=self.get("training/num_epochs"),
                **self.get("scheduler/kwargs", {}),
            )
        else:
            self._base_scheduler = None
        # Support for LR warmup
        if self.get("scheduler/warmup", False):
            assert self._base_scheduler is not None
            self.scheduler = GradualWarmupScheduler(
                self.optim,
                multiplier=1,
                total_epoch=5,
                after_scheduler=self._base_scheduler,
            )
        else:
            self.scheduler = self._base_scheduler

    @property
    def device(self):
        return self.get("device", "cpu")

    @register_default_dispatch
    def train(self):
        if self.get("wandb/use", True):
            self.initialize_wandb()
        num_epochs = self.get("training/num_epochs", ensure_exists=True)
        for epoch in self.progress(range(num_epochs), tag="epochs"):
            self.log_learning_rates()
            self.train_epoch()
            validation_stats = self.validate_epoch()
            self.checkpoint()
            self.log_progress("epochs", **validation_stats)
            self.step_scheduler(epoch)
            self.next_epoch()

    def train_epoch(self):
        self.clear_moving_averages()
        self.model.train()
        for model_input in self.progress(self.train_loader, tag="train"):
            # Evaluate model
            model_input = to_device(model_input, self.device)
            model_output = Dict(self.model(model_input))
            # Compute loss
            losses = self.loss(model_input, model_output)
            loss = losses.loss
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()
            # Log to wandb (if required)
            self.log_training_losses(losses)
            # Log to pbar
            self.accumulate_in_cache("moving_loss", loss.item(),
                                     momentum_accumulator(0.9))
            self.log_progress(
                "train",
                loss=self.read_from_cache("moving_loss"),
            )
            self.next_step()

    def validate_epoch(self):
        all_losses_and_metrics = defaultdict(list)
        self.metrics.reset()
        self.model.eval()
        for model_input in self.progress(self.validate_loader,
                                         tag="validation"):
            with torch.no_grad():
                model_input = to_device(model_input, self.device)
                model_output = Dict(self.model(model_input))
                losses = self.loss(model_input, model_output)
                self.metrics.update(model_input, model_output)
                all_losses_and_metrics["loss"].append(losses.loss.item())
                for key in losses.unweighted_losses:
                    all_losses_and_metrics[key].append(
                        losses.unweighted_losses[key].item())
        # Compute mean for all losses
        all_losses_and_metrics = Dict(
            {key: np.mean(val)
             for key, val in all_losses_and_metrics.items()})
        all_losses_and_metrics.update(Dict(self.metrics.evaluate()))
        self.log_validation_losses_and_metrics(all_losses_and_metrics)
        # Store the validation loss in cache. This will be used for checkpointing.
        self.write_to_cache("current_validation_loss",
                            all_losses_and_metrics.loss)
        return all_losses_and_metrics

    def log_training_losses(self, losses):
        if not self.get("wandb/use", True):
            return self
        if self.log_wandb_now:
            metrics = Dict({"training_loss": losses.loss})
            metrics.update({
                f"training_{k}": v
                for k, v in losses.unweighted_losses.items()
            })
            self.wandb_log(**metrics)
        return self

    def checkpoint(self, force=False):
        current_validation_loss = self.read_from_cache(
            "current_validation_loss", float("inf"))
        best_validation_loss = self.read_from_cache("best_validation_loss",
                                                    float("inf"))
        if current_validation_loss < best_validation_loss:
            self.write_to_cache("best_validation_loss",
                                current_validation_loss)
            ckpt_path = os.path.join(self.checkpoint_directory, "best.ckpt")
        elif self.get_arg("force_checkpoint", force):
            ckpt_path = os.path.join(self.checkpoint_directory, "best.ckpt")
        else:
            ckpt_path = None
        if ckpt_path is not None:
            info_dict = {
                "model": self.model.state_dict(),
                "optim": self.optim.state_dict(),
            }
            torch.save(info_dict, ckpt_path)
        return self

    def load(self, device=None):
        ckpt_path = os.path.join(self.checkpoint_directory, "best.ckpt")
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(ckpt_path)
        info_dict = torch.load(
            ckpt_path,
            map_location=torch.device(
                (self.device if device is None else device)),
        )
        self.model.load_state_dict(info_dict["model"])
        self.optim.load_state_dict(info_dict["optim"])
        return self

    def log_validation_losses_and_metrics(self, losses):
        if not self.get("wandb/use", True):
            return self
        metrics = {f"validation_{k}": v for k, v in losses.items()}
        self.wandb_log(**metrics)
        return self

    def clear_moving_averages(self):
        return self.clear_in_cache("moving_loss")

    def step_scheduler(self, epoch):
        if self.scheduler is not None:
            self.scheduler.step(epoch)
        return self

    def log_learning_rates(self):
        if not self.get("wandb/use", True):
            return self
        lrs = {
            f"lr_{i}": param_group["lr"]
            for i, param_group in enumerate(self.optim.param_groups)
        }
        self.wandb_log(**lrs)
        return self
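As a reference for the pattern that `_build_scheduler` and `step_scheduler` implement above, here is a minimal, self-contained sketch of warmup followed by cosine annealing. It assumes `GradualWarmupScheduler` comes from the pytorch-gradual-warmup-lr package (`warmup_scheduler`); the toy model, `num_epochs`, and the 5-epoch warmup are placeholder values, not taken from any of the snippets.

# Minimal sketch of the warmup-then-cosine pattern used in _build_scheduler above.
# Assumes GradualWarmupScheduler is the one from pytorch-gradual-warmup-lr
# (pip install warmup-scheduler); model and num_epochs are placeholders.
import torch
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from warmup_scheduler import GradualWarmupScheduler

num_epochs = 50
model = nn.Linear(10, 1)                      # stand-in for the real model
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

base_scheduler = CosineAnnealingLR(optim, T_max=num_epochs)
scheduler = GradualWarmupScheduler(
    optim,
    multiplier=1,            # start from the base LR rather than a scaled one
    total_epoch=5,           # warm up over the first 5 epochs
    after_scheduler=base_scheduler,
)

for epoch in range(num_epochs):
    # ... train one epoch, run validation ...
    optim.step()             # the scheduler step should follow optimizer.step()
    scheduler.step(epoch)    # mirrors step_scheduler(epoch) in the trainer above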
Esempio n. 12
0
    num_batches = len(tmp_train_dataloader)

    # set optimizer
    optimizer = optim.Adam(model.parameters(), lr=base_lr, eps=1e-8)

    # set scheduler
    # For StepLR:
    # scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
    # warmup followed by cosine annealing
    t_total = len(tmp_train_dataloader) * num_epochs
    warmup_step = int(0.01 * t_total)
    # decay the LR over the remaining steps
    scheduler_cosine = CosineAnnealingLR(optimizer, T_max=t_total)
    scheduler = GradualWarmupScheduler(optimizer,
                                       multiplier=1,
                                       total_epoch=warmup_step,
                                       after_scheduler=scheduler_cosine)

    criterion = ContrastiveLoss()

    #prediction_file = 'prediction.txt'
    counter = []
    loss_history = []
    iteration_number = 0
    best_f1 = 0.0
    best_loss = 1e10

    # check parameter of model
    print("------------------------------------------------------------")
    total_params = sum(p.numel() for p in model.parameters())
    print("num of parameter : ", total_params)
def main():
    with timer('load data'):
        df = pd.read_csv(FOLD_PATH)
        y = (df["sum_target"] != 0).values.astype("float32")

    with timer('preprocessing'):
        train_df, val_df = df[df.fold_id != FOLD_ID], df[df.fold_id == FOLD_ID]
        y_train, y_val = y[df.fold_id != FOLD_ID], y[df.fold_id == FOLD_ID]

        train_augmentation = Compose([
            Flip(p=0.5),
            OneOf(
                [
                    #ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
                    GridDistortion(p=0.5),
                    OpticalDistortion(p=0.5, distort_limit=2, shift_limit=0.5)
                ],
                p=0.5),
            #OneOf([
            #    ShiftScaleRotate(p=0.5),
            ##    RandomRotate90(p=0.5),
            #    Rotate(p=0.5)
            #], p=0.5),
            OneOf([
                Blur(blur_limit=8, p=0.5),
                MotionBlur(blur_limit=8, p=0.5),
                MedianBlur(blur_limit=8, p=0.5),
                GaussianBlur(blur_limit=8, p=0.5)
            ],
                  p=0.5),
            OneOf(
                [
                    #CLAHE(clip_limit=4, tile_grid_size=(4, 4), p=0.5),
                    RandomGamma(gamma_limit=(100, 140), p=0.5),
                    RandomBrightnessContrast(p=0.5),
                    RandomBrightness(p=0.5),
                    RandomContrast(p=0.5)
                ],
                p=0.5),
            OneOf([
                GaussNoise(p=0.5),
                Cutout(num_holes=10, max_h_size=10, max_w_size=20, p=0.5)
            ],
                  p=0.5)
        ])
        val_augmentation = None

        train_dataset = SeverDataset(train_df,
                                     IMG_DIR,
                                     IMG_SIZE,
                                     N_CLASSES,
                                     id_colname=ID_COLUMNS,
                                     transforms=train_augmentation,
                                     class_y=y_train)
        val_dataset = SeverDataset(val_df,
                                   IMG_DIR,
                                   IMG_SIZE,
                                   N_CLASSES,
                                   id_colname=ID_COLUMNS,
                                   transforms=val_augmentation,
                                   class_y=y_val)
        train_loader = DataLoader(train_dataset,
                                  batch_size=BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=2)
        val_loader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=False,
                                num_workers=2)

        del train_df, val_df, df, train_dataset, val_dataset
        gc.collect()

    with timer('create model'):
        model = smp.UnetPP('se_resnext50_32x4d',
                           encoder_weights='imagenet',
                           classes=N_CLASSES,
                           encoder_se_module=True,
                           decoder_semodule=True,
                           h_columns=False,
                           deep_supervision=True,
                           classification=CLASSIFICATION)
        #model.load_state_dict(torch.load(model_path))
        model.to(device)

        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
        scheduler_cosine = CosineAnnealingLR(optimizer,
                                             T_max=CLR_CYCLE,
                                             eta_min=3e-5)
        scheduler = GradualWarmupScheduler(optimizer,
                                           multiplier=1.1,
                                           total_epoch=CLR_CYCLE * 2,
                                           after_scheduler=scheduler_cosine)

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)

    with timer('train'):
        train_losses = []
        valid_losses = []

        best_model_loss = 999
        best_model_ep = 0
        checkpoint = 0

        for epoch in range(1, EPOCHS + 1):
            if epoch % (CLR_CYCLE * 2) == 0:
                if epoch != 0:
                    y_val = y_val.reshape(-1, N_CLASSES, IMG_SIZE[0],
                                          IMG_SIZE[1])
                    best_pred = best_pred.reshape(-1, N_CLASSES, IMG_SIZE[0],
                                                  IMG_SIZE[1])
                    for i in range(N_CLASSES):
                        th, score, _, _ = search_threshold(
                            y_val[:, i, :, :], best_pred[:, i, :, :])
                        LOGGER.info(
                            'Best loss: {} Best Dice: {} on epoch {} th {} class {}'
                            .format(round(best_model_loss, 5), round(score, 5),
                                    best_model_ep, th, i))
                checkpoint += 1
                best_model_loss = 999

            LOGGER.info("Starting {} epoch...".format(epoch))
            tr_loss = train_one_epoch_dsv(model,
                                          train_loader,
                                          criterion,
                                          optimizer,
                                          device,
                                          classification=CLASSIFICATION)
            train_losses.append(tr_loss)
            LOGGER.info('Mean train loss: {}'.format(round(tr_loss, 5)))

            valid_loss, val_pred, y_val = validate_dsv(model, val_loader,
                                                       criterion, device)
            valid_losses.append(valid_loss)
            LOGGER.info('Mean valid loss: {}'.format(round(valid_loss, 5)))

            scheduler.step()

            if valid_loss < best_model_loss:
                torch.save(
                    model.state_dict(),
                    '{}_fold{}_ckpt{}.pth'.format(EXP_ID, FOLD_ID, checkpoint))
                best_model_loss = valid_loss
                best_model_ep = epoch
                best_pred = val_pred

            del val_pred
            gc.collect()

    with timer('eval'):
        y_val = y_val.reshape(-1, N_CLASSES, IMG_SIZE[0], IMG_SIZE[1])
        best_pred = best_pred.reshape(-1, N_CLASSES, IMG_SIZE[0], IMG_SIZE[1])
        for i in range(N_CLASSES):
            th, score, _, _ = search_threshold(y_val[:, i, :, :],
                                               best_pred[:, i, :, :])
            LOGGER.info(
                'Best loss: {} Best Dice: {} on epoch {} th {} class {}'.
                format(round(best_model_loss, 5), round(score, 5),
                       best_model_ep, th, i))

    xs = list(range(1, len(train_losses) + 1))
    plt.plot(xs, train_losses, label='Train loss')
    plt.plot(xs, valid_losses, label='Val loss')
    plt.legend()
    plt.xticks(xs)
    plt.xlabel('Epochs')
    plt.savefig("loss.png")
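The `search_threshold` call used above is not shown in this snippet. As an illustration only, a per-class threshold search of this kind can be sketched as below; the function name, the threshold grid, and the use of Dice as the criterion are assumptions, not the source implementation.

# Hypothetical stand-in for search_threshold(): sweep a grid of thresholds and
# keep the one with the best Dice score. Shapes assumed: (N, H, W) float arrays
# with ground truth in {0, 1} and predictions in [0, 1].
import numpy as np

def dice_score(y_true, y_pred_bin, eps=1e-7):
    inter = (y_true * y_pred_bin).sum()
    return (2.0 * inter + eps) / (y_true.sum() + y_pred_bin.sum() + eps)

def search_threshold_sketch(y_true, y_pred, grid=np.arange(0.3, 0.71, 0.05)):
    best_th, best_score = 0.5, -1.0
    for th in grid:
        score = dice_score(y_true, (y_pred > th).astype(np.float32))
        if score > best_score:
            best_th, best_score = th, score
    return best_th, best_score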
Esempio n. 14
0
def main(config):
    """ 1. data process """
    if fst.flag_orig_npy == True:
        print('preparation of the numpy')
        if os.path.exists(st.orig_npy_dir) == False:
            os.makedirs(st.orig_npy_dir)
        """ processing """
        if st.list_data_type[st.data_type_num] == 'Density':
            cDL.Prepare_data_GM_AGE_MMSE()
        elif st.list_data_type[st.data_type_num] == 'ADNI_JSY':
            jDL.Prepare_data_1()
        elif st.list_data_type[st.data_type_num] == 'ADNI_Jacob_256':
            jcDL.Prepare_data_GM_WM_CSF()
        elif 'ADNI_Jacob' in st.list_data_type[st.data_type_num]:
            jcDL.Prepare_data_GM()
        elif 'ADNI_AAL_256' in st.list_data_type[st.data_type_num]:
            aDL.Prepare_data_GM()

    if fst.flag_orig_npy_other_dataset == True:
        cDL.Prepare_data_GM_age_others(dataset='ABIDE')
        cDL.Prepare_data_GM_age_others(dataset='ICBM')
        cDL.Prepare_data_GM_age_others(dataset='Cam')
        cDL.Prepare_data_GM_age_others(dataset='IXI')
        cDL.Prepare_data_GM_age_others(dataset='PPMI')
    """ 2. fold index processing """
    if fst.flag_fold_index == True:
        print('preparation of the fold index')
        if os.path.exists(st.fold_index_dir) == False:
            os.makedirs(st.fold_index_dir)
        """ save the fold index """
        ut.preparation_fold_index(config)
    """ fold selection """
    start_fold = st.start_fold
    end_fold = st.end_fold
    """ workbook """
    list_dir_result = []
    list_wb = []
    list_ws = []
    for i in range(len(st.list_standard_eval_dir)):
        list_dir_result.append(st.dir_to_save_1 + st.list_standard_eval_dir[i])
        ut.make_dir(dir=list_dir_result[i], flag_rm=False)
        out = ut.excel_setting(start_fold=start_fold,
                               end_fold=end_fold,
                               result_dir=list_dir_result[i],
                               f_name='results')
        list_wb.append(out[0])
        list_ws.append(out[1])
    """ fold """
    list_eval_metric = st.list_eval_metric
    metric_avg = [[[] for j in range(len(st.list_eval_metric))]
                  for i in range(len(st.list_standard_eval_dir))]
    for fold in range(start_fold, end_fold + 1):
        print("FOLD : {}".format(fold))

        ## TODO : Directory preparation
        print('-' * 10 + 'Directory preparation' + '-' * 10)
        list_dir_save_model = []
        list_dir_save_model_2 = []
        list_dir_confusion = []
        list_dir_age_pred = []
        list_dir_heatmap = []
        for i in range(len(st.list_standard_eval_dir)):
            """ dir to save model """
            list_dir_save_model.append(st.dir_to_save_1 +
                                       st.list_standard_eval_dir[i] +
                                       '/weights/fold_{}'.format(fold))
            ut.make_dir(dir=list_dir_save_model[i], flag_rm=False)

            list_dir_save_model_2.append(st.dir_to_save_1 +
                                         st.list_standard_eval_dir[i] +
                                         '/weights_2/fold_{}'.format(fold))
            ut.make_dir(dir=list_dir_save_model_2[i], flag_rm=False)
            """ dir to save confusion matrix  """
            list_dir_confusion.append(st.dir_to_save_1 +
                                      st.list_standard_eval_dir[i] +
                                      '/confusion')
            ut.make_dir(dir=list_dir_confusion[i], flag_rm=False)
            """ dir to save age pred """
            list_dir_age_pred.append(st.dir_to_save_1 +
                                     st.list_standard_eval_dir[i] +
                                     '/age_pred')
            ut.make_dir(dir=list_dir_age_pred[i], flag_rm=False)

            list_dir_heatmap.append(st.dir_to_save_1 +
                                    st.list_standard_eval_dir[i] + '/heatmap')
            ut.make_dir(dir=list_dir_heatmap[i], flag_rm=False)
        """ dir to save pyplot """
        dir_pyplot = st.dir_to_save_1 + '/pyplot/fold_{}'.format(fold)
        ut.make_dir(dir=dir_pyplot, flag_rm=False)
        """ dir to save MMSE dist """
        dir_MMSE_dist = st.dir_to_save_1 + '/MMSE_dist'
        ut.make_dir(dir=dir_MMSE_dist, flag_rm=False)

        ##TODO : model construction
        print('-' * 10 + 'Model construction' + '-' * 10)
        model_1 = construct_model.construct_model(config, flag_model_num=0)
        model_1 = nn.DataParallel(model_1)
        if fst.flag_classification_fine_tune == True:
            dir_to_load = st.dir_preTrain_1
            dir_load_model = dir_to_load + '/weights/fold_{}'.format(fold)
            model_dir = ut.model_dir_to_load(fold, dir_load_model)

            pretrained_dict = torch.load(model_dir)
            model_dict = model_1.state_dict()
            for k, v in pretrained_dict.items():
                if k in model_dict:
                    print(k)
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items() if k in model_dict
            }
            model_dict.update(pretrained_dict)
            model_1.load_state_dict(model_dict)

        elif fst.flag_classification_using_pretrained == True:
            model_2 = construct_model.construct_model(config, flag_model_num=1)
            model_2 = nn.DataParallel(model_2)
            dir_to_load = st.dir_preTrain_1
            dir_load_model = dir_to_load + '/weights/fold_{}'.format(fold)
            model_dir = ut.model_dir_to_load(fold, dir_load_model)
            pretrained_dict = torch.load(model_dir)
            model_dict = model_2.state_dict()
            for k, v in pretrained_dict.items():
                if k in model_dict:
                    print(k)
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items() if k in model_dict
            }
            model_dict.update(pretrained_dict)
            model_2.load_state_dict(model_dict)
            model_2.eval()
        """ optimizer """
        # optimizer = torch.optim.SGD(model_1.parameters(), lr=config.lr, momentum=0.9, weight_decay=st.weight_decay)

        optimizer = torch.optim.Adam(model_1.parameters(),
                                     lr=config.lr,
                                     betas=(0.9, 0.999),
                                     eps=1e-8,
                                     weight_decay=st.weight_decay)
        # optimizer = AdamP(model_1.parameters(), lr=config.lr, betas=(0.9, 0.999), weight_decay=st.weight_decay)
        # optimizer = RAdam(model_1.parameters(), lr=config.lr, betas=(0.9, 0.999), weight_decay=st.weight_decay)

        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=st.step_size, gamma=st.LR_decay_rate, last_epoch=-1)

        # params_dict = []
        # params_dict.append({'params': model.parameters(), 'lr': config.lr})
        # optimizer = ut.BayesianSGD(params=params_dict)
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=st.step_size, gamma=st.LR_decay_rate, last_epoch=-1)

        # scheduler_expo = torch.optim.lr_scheduler.StepLR(optimizer, step_size=st.step_size, gamma=st.LR_decay_rate, last_epoch=-1)
        # scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=5 , after_scheduler=scheduler_expo)

        scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, st.epoch)
        scheduler = GradualWarmupScheduler(optimizer,
                                           multiplier=1,
                                           total_epoch=5,
                                           after_scheduler=scheduler_cosine)

        # scheduler_cosine_restart = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=50)
        # scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=5, after_scheduler=scheduler_cosine_restart)

        # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=50)

        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=0.2, patience=10)
        """ data loader """
        print('-' * 10 + 'data loader' + '-' * 10)
        train_loader = DL.convert_Dloader_3(fold,
                                            list_class=st.list_class_for_train,
                                            flag_tr_val_te='train',
                                            batch_size=config.batch_size,
                                            num_workers=0,
                                            shuffle=True,
                                            drop_last=True)
        val_loader = DL.convert_Dloader_3(fold,
                                          list_class=st.list_class_for_test,
                                          flag_tr_val_te='val',
                                          batch_size=config.batch_size,
                                          num_workers=0,
                                          shuffle=False,
                                          drop_last=False)
        test_loader = DL.convert_Dloader_3(fold,
                                           list_class=st.list_class_for_test,
                                           flag_tr_val_te='test',
                                           batch_size=config.batch_size,
                                           num_workers=0,
                                           shuffle=False,
                                           drop_last=False)

        dict_data_loader = {
            'train': train_loader,
            'val': val_loader,
            'test': test_loader
        }
        """ normal classification tasks """
        list_test_result = []
        print('-' * 10 + 'start training' + '-' * 10)
        if fst.flag_classification == True or fst.flag_classification_fine_tune == True:
            train.train(config,
                        fold,
                        model_1,
                        dict_data_loader,
                        optimizer,
                        scheduler,
                        list_dir_save_model,
                        dir_pyplot,
                        Validation=True,
                        Test_flag=True)
            for i_tmp in range(len(st.list_standard_eval_dir)):
                dict_test_output = test.test(config, fold, model_1,
                                             dict_data_loader['test'],
                                             list_dir_save_model[i_tmp],
                                             list_dir_confusion[i_tmp])
                list_test_result.append(dict_test_output)
                # if len(st.list_selected_for_train) == 2 and fold == 1 and st.list_standard_eval_dir[i_tmp] == '/val_auc':
                #     generate_heatmap.get_multi_heatmap_2class(config, fold, model, list_dir_save_model[i_tmp], list_dir_heatmap[i_tmp])

        elif fst.flag_classification_using_pretrained == True:
            """ using pretrained patch level model """
            train_using_pretrained.train(config,
                                         fold,
                                         model_1,
                                         model_2,
                                         dict_data_loader,
                                         optimizer,
                                         scheduler,
                                         list_dir_save_model,
                                         dir_pyplot,
                                         Validation=True,
                                         Test_flag=True)
            for i_tmp in range(len(st.list_standard_eval_dir)):
                dict_test_output = test_using_pretrained.test(
                    config, fold, model_1, model_2, dict_data_loader['test'],
                    list_dir_save_model[i_tmp], list_dir_confusion[i_tmp])
                list_test_result.append(dict_test_output)

        elif fst.flag_multi_task == True:
            train_multi_task.train(config,
                                   fold,
                                   model_1,
                                   dict_data_loader,
                                   optimizer,
                                   scheduler,
                                   list_dir_save_model,
                                   dir_pyplot,
                                   Validation=True,
                                   Test_flag=True)
            for i_tmp in range(len(st.list_standard_eval_dir)):
                dict_test_output = test_multi_task.test(
                    config, fold, model_1, dict_data_loader['test'],
                    list_dir_save_model[i_tmp], list_dir_confusion[i_tmp])
                list_test_result.append(dict_test_output)
        """ fill out the results on the excel sheet """
        for i_standard in range(len(st.list_standard_eval_dir)):
            for i in range(len(list_eval_metric)):
                if list_eval_metric[i] in list_test_result[i_standard]:
                    list_ws[i_standard].cell(
                        row=2 + i + st.push_start_row,
                        column=fold + 1,
                        value="%.4f" %
                        (list_test_result[i_standard][list_eval_metric[i]]))
                    metric_avg[i_standard][i].append(
                        list_test_result[i_standard][list_eval_metric[i]])

            for i in range(len(list_eval_metric)):
                if metric_avg[i_standard][i]:
                    avg = round(np.mean(metric_avg[i_standard][i]), 4)
                    std = round(np.std(metric_avg[i_standard][i]), 4)
                    tmp = "%.4f \u00B1 %.4f" % (avg, std)
                    list_ws[i_standard].cell(row=2 + st.push_start_row + i,
                                             column=end_fold + 2,
                                             value=tmp)

            list_wb[i_standard].save(list_dir_result[i_standard] +
                                     "/results.xlsx")

    for i_standard in range(len(st.list_standard_eval_dir)):
        n_row = list_ws[i_standard].max_row
        n_col = list_ws[i_standard].max_column
        for i_row in range(1, n_row + 1):
            for i_col in range(1, n_col + 1):
                ca1 = list_ws[i_standard].cell(row=i_row, column=i_col)
                ca1.alignment = Alignment(horizontal='center',
                                          vertical='center')
        list_wb[i_standard].save(list_dir_result[i_standard] + "/results.xlsx")
        list_wb[i_standard].close()

    print("finished")
Esempio n. 15
0
def main(seed):
    with timer('load data'):
        df = pd.read_csv(FOLD_PATH)
        y1 = (df.EncodedPixels_1 != "-1").astype("float32").values.reshape(
            -1, 1)
        y2 = (df.EncodedPixels_2 != "-1").astype("float32").values.reshape(
            -1, 1)
        y3 = (df.EncodedPixels_3 != "-1").astype("float32").values.reshape(
            -1, 1)
        y4 = (df.EncodedPixels_4 != "-1").astype("float32").values.reshape(
            -1, 1)
        y = np.concatenate([y1, y2, y3, y4], axis=1)

    with timer('preprocessing'):
        train_df, val_df = df[df.fold_id != FOLD_ID], df[df.fold_id == FOLD_ID]
        y_train, y_val = y[df.fold_id != FOLD_ID], y[df.fold_id == FOLD_ID]

        train_augmentation = Compose([
            Flip(p=0.5),
            OneOf([
                GridDistortion(p=0.5),
                OpticalDistortion(p=0.5, distort_limit=2, shift_limit=0.5)
            ],
                  p=0.5),
            OneOf([
                RandomGamma(gamma_limit=(100, 140), p=0.5),
                RandomBrightnessContrast(p=0.5),
                RandomBrightness(p=0.5),
                RandomContrast(p=0.5)
            ],
                  p=0.5),
            OneOf([
                GaussNoise(p=0.5),
                Cutout(num_holes=10, max_h_size=10, max_w_size=20, p=0.5)
            ],
                  p=0.5),
            ShiftScaleRotate(rotate_limit=20, p=0.5),
        ])
        val_augmentation = None

        train_dataset = SeverDataset(train_df,
                                     IMG_DIR,
                                     IMG_SIZE,
                                     N_CLASSES,
                                     id_colname=ID_COLUMNS,
                                     transforms=train_augmentation,
                                     crop_rate=1.0,
                                     class_y=y_train)
        val_dataset = SeverDataset(val_df,
                                   IMG_DIR,
                                   IMG_SIZE,
                                   N_CLASSES,
                                   id_colname=ID_COLUMNS,
                                   transforms=val_augmentation)
        train_sampler = MaskProbSampler(train_df, demand_non_empty_proba=0.6)
        train_loader = DataLoader(train_dataset,
                                  batch_size=BATCH_SIZE,
                                  sampler=train_sampler,
                                  num_workers=8)
        val_loader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=False,
                                num_workers=8)

        del train_df, val_df, df, train_dataset, val_dataset
        gc.collect()

    with timer('create model'):
        model = smp.Unet('se_resnext50_32x4d',
                         encoder_weights="imagenet",
                         classes=N_CLASSES,
                         encoder_se_module=True,
                         decoder_semodule=True,
                         h_columns=False,
                         skip=True,
                         act="swish",
                         freeze_bn=True,
                         classification=CLASSIFICATION,
                         attention_type="cbam",
                         center=True)
        model = convert_model(model)
        if base_model is not None:
            model.load_state_dict(torch.load(base_model))
        model.to(device)

        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam([
            {
                'params': model.decoder.parameters(),
                'lr': 3e-3
            },
            {
                'params': model.encoder.parameters(),
                'lr': 3e-4
            },
        ])
        if base_model is None:
            scheduler_cosine = CosineAnnealingLR(optimizer,
                                                 T_max=CLR_CYCLE,
                                                 eta_min=3e-5)
            scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=1.1,
                total_epoch=CLR_CYCLE * 2,
                after_scheduler=scheduler_cosine)
        else:
            scheduler = CosineAnnealingLR(optimizer,
                                          T_max=CLR_CYCLE,
                                          eta_min=3e-5)

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)

        if EMA:
            ema_model = copy.deepcopy(model)
            if base_model_ema is not None:
                ema_model.load_state_dict(torch.load(base_model_ema))
            ema_model.to(device)
        else:
            ema_model = None
        model = torch.nn.DataParallel(model)
        if ema_model is not None:
            ema_model = torch.nn.DataParallel(ema_model)

    with timer('train'):
        valid_loss = validate(model,
                              val_loader,
                              criterion,
                              device,
                              classification=CLASSIFICATION)
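The snippet is truncated before the training loop, so it never shows how `ema_model` is kept in sync with `model`. A common exponential-moving-average update looks roughly like the sketch below; the decay value and the helper name are typical choices assumed here, not the source's exact code.

# Illustrative EMA weight update for the ema_model created above; this is a
# common pattern, not the source's implementation.
import torch

@torch.no_grad()
def update_ema(model, ema_model, decay=0.999):
    model_params = dict(model.named_parameters())
    for name, ema_param in ema_model.named_parameters():
        ema_param.mul_(decay).add_(model_params[name], alpha=1.0 - decay)
    # Buffers (e.g. BatchNorm running stats) are usually copied directly.
    model_buffers = dict(model.named_buffers())
    for name, ema_buf in ema_model.named_buffers():
        ema_buf.copy_(model_buffers[name])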
Esempio n. 16
0
    def train(self):

        # pass
        df = pd.read_csv(os.path.join(DATA_PATH, DataID, 'train.csv'))
        image_path_list = df['image_path'].values
        label_list = df['label'].values

        # Split into training and validation sets
        all_size = len(image_path_list)
        train_size = int(all_size * 0.9)
        train_image_path_list = image_path_list[:train_size]
        train_label_list = label_list[:train_size]
        val_image_path_list = image_path_list[train_size:]
        val_label_list = label_list[train_size:]
        print(
            'train_size: %d, val_size: %d' % (len(train_image_path_list),
                                              len(val_image_path_list)))
        train_transform, val_transform = self.deal_with_data()
        train_data = ImageData(train_image_path_list, train_label_list,
                               train_transform)
        val_data = ImageData(val_image_path_list, val_label_list,
                             val_transform)
        train_loader = DataLoader(train_data, batch_size=args.BATCH,
                                  num_workers=0, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=args.BATCH,
                                num_workers=0, shuffle=False)
        model = EfficientNet.from_pretrained('efficientnet-b1')
        # Replace the classifier head (efficientnet_pytorch exposes it as `_fc`)
        model._fc = nn.Linear(1280, 2)
        if use_gpu:
            model.to(DEVICE)
        criteration = nn.CrossEntropyLoss()
        criteration.cuda()
        optimizer = torch.optim.SGD(model.parameters(), lr=args.LR,
                                    momentum=0.9, weight_decay=5e-4)

        if args.SCHE == "cos":
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                             T_max=5,
                                                             eta_min=4e-08)
        elif args.SCHE == "red":
             scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode="min", factor=0.1,
               patience=3, verbose=False, threshold=0.0001
            )
        else:
            sys.exit(-1)
        max_correct = 0
        
        #scheduler_steplr = StepLR(optimizer, step_size=10, gamma=0.1)
        scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=5, after_scheduler=scheduler)
        
        for epoch in range(args.EPOCHS):
           
            #scheduler_warmup.step(epoch)
            model.train()
            correct = 0
            # Train losses
            train_losses = []
            for img, label in train_loader:
                img, label = img.to(DEVICE), label.to(DEVICE)
                optimizer.zero_grad()
                output = model(img)
                #loss = criteration(output, label)
                loss = self.label_smoothing(output, label, epsilon=0.1)
                loss.backward()
                optimizer.step()
                # Train Metric
                train_pred = output.detach().cpu().max(1, keepdim=True)[1]
                correct += train_pred.eq(label.detach().cpu().
                                         view_as(train_pred)).sum().item()
                train_losses.append(loss.item())
                del train_pred
                # print("Epoch {}, Loss {:.4f}".format(epoch, loss.item()))
            del img, label

            #  Train loss curve
            train_avg_loss = np.mean(train_losses)
            
            acc = 100 * correct / len(train_image_path_list)
            
            # ReduceLROnPlateau needs the train loss; the cosine schedule does not.
            if args.SCHE == "red":
                scheduler_warmup.step_ReduceLROnPlateau(train_avg_loss)
            else:
                scheduler_warmup.step()


            if epoch % 1 == 0 or epoch == args.EPOCHS - 1:
                correct = 0
                with torch.no_grad():
                    model.eval()
                    # Val losses
                    val_losses = []
                    for val_img, val_label in val_loader:
                        val_img = val_img.to(DEVICE)
                        val_label = val_label.to(DEVICE)
                        val_output = model(val_img)
                        loss = criteration(val_output, val_label)
                        val_pred = val_output.detach().cpu().\
                            max(1, keepdim=True)[1]
                        correct += val_pred.eq(val_label.detach().cpu().
                                               view_as(val_pred)).\
                            sum().item()
                        val_losses.append(loss.item())
                        del val_img, val_label, val_output, val_pred

                #  Val loss curve
                val_avg_loss = np.mean(val_losses)
                
                val_acc = 100 * correct / len(val_image_path_list)
                

                if (correct > max_correct):
                    max_correct = correct
                    torch.save(model, MODEL_PATH + '/' + "best.pth")
                print("Epoch {},  Accuracy {:.0f}%".format(
                    epoch, 100 * correct / len(val_image_path_list)))
                
                
               # LR curve
                
            train_log(train_loss=train_avg_loss, train_acc=acc, val_loss=val_avg_loss,val_acc=val_acc)
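The `self.label_smoothing` method called in the training loop is not shown in this snippet. A minimal sketch with the same call signature (logits, targets, epsilon) follows; it is an assumption about the method's behaviour, written as a plain function, not the source code.

# Hypothetical label-smoothing cross-entropy matching the call
# self.label_smoothing(output, label, epsilon=0.1) used above.
import torch
import torch.nn.functional as F

def label_smoothing(output, target, epsilon=0.1):
    # output: (N, C) logits, target: (N,) class indices.
    log_probs = F.log_softmax(output, dim=-1)
    # Smoothed target: weight (1 - epsilon) on the true class,
    # epsilon spread uniformly over all classes.
    nll = -log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
    smooth = -log_probs.mean(dim=-1)
    return ((1.0 - epsilon) * nll + epsilon * smooth).mean()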