Code example #1
def main(seed):
    with timer('load data'):
        df = pd.read_csv(FOLD_PATH)
        y1 = (df.EncodedPixels_1 != "-1").astype("float32").values.reshape(
            -1, 1)
        y2 = (df.EncodedPixels_2 != "-1").astype("float32").values.reshape(
            -1, 1)
        y3 = (df.EncodedPixels_3 != "-1").astype("float32").values.reshape(
            -1, 1)
        y4 = (df.EncodedPixels_4 != "-1").astype("float32").values.reshape(
            -1, 1)
        y = np.concatenate([y1, y2, y3, y4], axis=1)
        #y = (df.sum_target != 0).astype("float32").values

    with timer('preprocessing'):
        train_df, val_df = df[df.fold_id != FOLD_ID], df[df.fold_id == FOLD_ID]
        y_train, y_val = y[df.fold_id != FOLD_ID], y[df.fold_id == FOLD_ID]

        train_augmentation = Compose([
            Flip(p=0.5),
            OneOf([
                GridDistortion(p=0.5),
                OpticalDistortion(p=0.5, distort_limit=2, shift_limit=0.5)
            ],
                  p=0.5),
            OneOf([
                RandomGamma(gamma_limit=(100, 140), p=0.5),
                RandomBrightnessContrast(p=0.5),
                RandomBrightness(p=0.5),
                RandomContrast(p=0.5)
            ],
                  p=0.5),
            OneOf([
                GaussNoise(p=0.5),
                Cutout(num_holes=10, max_h_size=10, max_w_size=20, p=0.5)
            ],
                  p=0.5),
            ShiftScaleRotate(rotate_limit=20, p=0.5),
        ])
        val_augmentation = None

        train_dataset = SeverCLSDataset(train_df,
                                        IMG_DIR,
                                        IMG_SIZE,
                                        N_CLASSES,
                                        y_train,
                                        id_colname=ID_COLUMNS,
                                        transforms=train_augmentation,
                                        crop_rate=1.0)
        val_dataset = SeverCLSDataset(val_df,
                                      IMG_DIR,
                                      IMG_SIZE,
                                      N_CLASSES,
                                      y_val,
                                      id_colname=ID_COLUMNS,
                                      transforms=val_augmentation)
        #train_sampler = MaskProbSampler(train_df, demand_non_empty_proba=0.6)
        train_loader = DataLoader(train_dataset,
                                  batch_size=BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=8,
                                  pin_memory=True)
        val_loader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=False,
                                num_workers=8,
                                pin_memory=True)

        del train_df, val_df, df, train_dataset, val_dataset
        gc.collect()

    with timer('create model'):
        model = ResNet(num_classes=N_CLASSES,
                       pretrained="imagenet",
                       net_cls=models.resnet50)
        #model = convert_model(model)
        if base_model is not None:
            model.load_state_dict(torch.load(base_model))
        model.to(device)

        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, eps=1e-4)
        if base_model is None:
            scheduler_cosine = CosineAnnealingLR(optimizer,
                                                 T_max=CLR_CYCLE,
                                                 eta_min=3e-5)
            scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=1.1,
                total_epoch=CLR_CYCLE * 2,
                after_scheduler=scheduler_cosine)
        else:
            scheduler = CosineAnnealingLR(optimizer,
                                          T_max=CLR_CYCLE,
                                          eta_min=3e-5)

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)

        if EMA:
            ema_model = copy.deepcopy(model)
            if base_model_ema is not None:
                ema_model.load_state_dict(torch.load(base_model_ema))
            ema_model.to(device)
        else:
            ema_model = None
        model = torch.nn.DataParallel(model)
        if EMA:
            ema_model = torch.nn.DataParallel(ema_model)

    with timer('train'):
        train_losses = []
        valid_losses = []

        best_model_loss = 999
        best_model_ema_loss = 999
        best_model_ep = 0
        ema_decay = 0
        checkpoint = base_ckpt + 1

        for epoch in range(1, EPOCHS + 1):
            seed = seed + epoch
            seed_torch(seed)

            if epoch >= EMA_START:
                ema_decay = 0.99

            LOGGER.info("Starting {} epoch...".format(epoch))
            tr_loss = train_one_epoch(model,
                                      train_loader,
                                      criterion,
                                      optimizer,
                                      device,
                                      ema_model=ema_model,
                                      ema_decay=ema_decay)
            train_losses.append(tr_loss)
            LOGGER.info('Mean train loss: {}'.format(round(tr_loss, 5)))

            valid_loss, y_pred, y_true = validate(model, val_loader, criterion,
                                                  device)
            valid_losses.append(valid_loss)
            LOGGER.info('Mean valid loss: {}'.format(round(valid_loss, 5)))

            if EMA and epoch >= EMA_START:
                ema_valid_loss, y_pred_ema, _ = validate(
                    ema_model, val_loader, criterion, device)
                LOGGER.info('Mean EMA valid loss: {}'.format(
                    round(ema_valid_loss, 5)))

                if ema_valid_loss < best_model_ema_loss:
                    torch.save(
                        ema_model.module.state_dict(),
                        'models/{}_fold{}_ckpt{}_ema.pth'.format(
                            EXP_ID, FOLD_ID, checkpoint))
                    best_model_ema_loss = ema_valid_loss
                    np.save("y_pred_ema_ckpt{}.npy".format(checkpoint),
                            y_pred_ema)

            scheduler.step()

            if valid_loss < best_model_loss:
                torch.save(
                    model.module.state_dict(),
                    'models/{}_fold{}_ckpt{}.pth'.format(
                        EXP_ID, FOLD_ID, checkpoint))
                np.save("y_pred_ckpt{}.npy".format(checkpoint), y_pred)
                best_model_loss = valid_loss
                best_model_ep = epoch
                #np.save("val_pred.npy", val_pred)

            if epoch % (CLR_CYCLE * 2) == CLR_CYCLE * 2 - 1:
                torch.save(
                    model.module.state_dict(),
                    'models/{}_fold{}_latest.pth'.format(EXP_ID, FOLD_ID))
                LOGGER.info('Best valid loss: {} on epoch={}'.format(
                    round(best_model_loss, 5), best_model_ep))
                if EMA:
                    torch.save(
                        ema_model.module.state_dict(),
                        'models/{}_fold{}_latest_ema.pth'.format(
                            EXP_ID, FOLD_ID))
                    LOGGER.info('Best ema valid loss: {}'.format(
                        round(best_model_ema_loss, 5)))
                checkpoint += 1
                best_model_loss = 999
                best_model_ema_loss = 999

            #del val_pred
            gc.collect()

    LOGGER.info('Best valid loss: {} on epoch={}'.format(
        round(best_model_loss, 5), best_model_ep))

    xs = list(range(1, len(train_losses) + 1))
    plt.plot(xs, train_losses, label='Train loss')
    plt.plot(xs, valid_losses, label='Val loss')
    plt.legend()
    plt.xticks(xs)
    plt.xlabel('Epochs')
    plt.savefig("loss.png")
Code example #2
def main(args, ITE=0):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    reinit = True if args.prune_type == "reinit" else False
    if args.save_dir:
        utils.checkdir(
            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{args.save_dir}/"
        )
        utils.checkdir(
            f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/"
        )
        utils.checkdir(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/"
        )
    else:
        utils.checkdir(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/")
        utils.checkdir(
            f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/")
        utils.checkdir(f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/")

    # Data Loader
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])
    if args.dataset == "mnist":
        traindataset = datasets.MNIST('../data',
                                      train=True,
                                      download=True,
                                      transform=transform)
        testdataset = datasets.MNIST('../data',
                                     train=False,
                                     transform=transform)
        from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet

    elif args.dataset == "cifar10":
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        traindataset = datasets.CIFAR10('../data',
                                        train=True,
                                        download=True,
                                        transform=transform_train)
        testdataset = datasets.CIFAR10('../data',
                                       train=False,
                                       transform=transform_test)
        from archs.cifar10 import AlexNet, LeNet5, fc1, vgg, resnet, densenet

    elif args.dataset == "fashionmnist":
        traindataset = datasets.FashionMNIST('../data',
                                             train=True,
                                             download=True,
                                             transform=transform)
        testdataset = datasets.FashionMNIST('../data',
                                            train=False,
                                            transform=transform)
        from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet

    elif args.dataset == "cifar100":
        traindataset = datasets.CIFAR100('../data',
                                         train=True,
                                         download=True,
                                         transform=transform)
        testdataset = datasets.CIFAR100('../data',
                                        train=False,
                                        transform=transform)
        from archs.cifar100 import AlexNet, fc1, LeNet5, vgg, resnet

    # If you want to add an extra dataset, add it here

    else:
        print("\nWrong Dataset choice \n")
        exit()

    if args.dataset == "cifar10":
        #trainsampler = torch.utils.data.RandomSampler(traindataset, replacement=True, num_samples=45000)  # 45K train dataset
        #train_loader = torch.utils.data.DataLoader(traindataset, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False, sampler=trainsampler)
        train_loader = torch.utils.data.DataLoader(traindataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=4)
    else:
        train_loader = torch.utils.data.DataLoader(traindataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=0,
                                                   drop_last=False)
    #train_loader = cycle(train_loader)
    test_loader = torch.utils.data.DataLoader(testdataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=4)

    # Importing Network Architecture

    # Initialize hessian dataloader, default batch_num 1
    for inputs, labels in train_loader:
        hessian_dataloader = (inputs, labels)
        break

    global model
    if args.arch_type == "fc1":
        model = fc1.fc1().to(device)
    elif args.arch_type == "lenet5":
        model = LeNet5.LeNet5().to(device)
    elif args.arch_type == "alexnet":
        model = AlexNet.AlexNet().to(device)
    elif args.arch_type == "vgg16":
        model = vgg.vgg16().to(device)
    elif args.arch_type == "resnet18":
        model = resnet.resnet18().to(device)
    elif args.arch_type == "densenet121":
        model = densenet.densenet121().to(device)
    # If you want to add an extra model, add it here
    else:
        print("\nWrong Model choice\n")
        exit()

    model = nn.DataParallel(model)
    # Weight Initialization
    model.apply(weight_init)

    # Copying and Saving Initial State
    initial_state_dict = copy.deepcopy(model.state_dict())
    if args.save_dir:
        torch.save(
            model.state_dict(),
            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{args.save_dir}/initial_state_dict_{args.prune_type}.pth"
        )
    else:
        torch.save(
            model.state_dict(),
            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/initial_state_dict_{args.prune_type}.pth"
        )

    # global total_params
    total_params = 0
    # Layer Looper
    for name, param in model.named_parameters():
        print(name, param.size())
        total_params += param.numel()

    # Making Initial Mask
    make_mask(model, total_params)

    # Optimizer and Loss
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=1e-4)
    # warm up schedule; scheduler_warmup is chained with scheduler_steplr
    scheduler_steplr = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                            milestones=[0, 15],
                                                            gamma=0.1,
                                                            last_epoch=-1)
    if args.warmup:
        scheduler_warmup = GradualWarmupScheduler(
            optimizer,
            multiplier=1,
            total_epoch=50,
            after_scheduler=scheduler_steplr)  # 20K=(idx)56, 35K=70
    criterion = nn.CrossEntropyLoss()  # default was F.nll_loss; why should train and test use different losses?

    # Pruning
    # NOTE: the first pruning iteration applies no compression
    best_accuracy = 0
    ITERATION = args.prune_iterations
    comp = np.zeros(ITERATION, float)
    bestacc = np.zeros(ITERATION, float)
    step = 0
    all_loss = np.zeros(args.end_iter, float)
    all_accuracy = np.zeros(args.end_iter, float)

    for _ite in range(args.start_iter, ITERATION):
        if not _ite == 0:
            prune_by_percentile(args.prune_percent,
                                resample=resample,
                                reinit=reinit,
                                total_params=total_params,
                                hessian_aware=args.hessian,
                                criterion=criterion,
                                dataloader=hessian_dataloader,
                                cuda=torch.cuda.is_available())
            if reinit:
                model.apply(weight_init)
                #if args.arch_type == "fc1":
                #    model = fc1.fc1().to(device)
                #elif args.arch_type == "lenet5":
                #    model = LeNet5.LeNet5().to(device)
                #elif args.arch_type == "alexnet":
                #    model = AlexNet.AlexNet().to(device)
                #elif args.arch_type == "vgg16":
                #    model = vgg.vgg16().to(device)
                #elif args.arch_type == "resnet18":
                #    model = resnet.resnet18().to(device)
                #elif args.arch_type == "densenet121":
                #    model = densenet.densenet121().to(device)
                #else:
                #    print("\nWrong Model choice\n")
                #    exit()
                step = 0
                for name, param in model.named_parameters():
                    if 'weight' in name:
                        param_frac = param.numel() / total_params
                        if param_frac > 0.01:
                            weight_dev = param.device
                            param.data = torch.from_numpy(
                                param.data.cpu().numpy() *
                                mask[step]).to(weight_dev)
                            step = step + 1
                step = 0
            else:
                original_initialization(mask, initial_state_dict, total_params)
            # optimizer = torch.optim.SGD([{'params': model.parameters(), 'initial_lr': 0.03}], lr=args.lr, momentum=0.9, weight_decay=1e-4)
            # scheduler_steplr = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[0, 14], gamma=0.1, last_epoch=-1)
            # scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=56, after_scheduler=scheduler_steplr)  # 20K=(idx)56, 35K=70
        print(f"\n--- Pruning Level [{ITE}:{_ite}/{ITERATION}]: ---")

        # Optimizer and Loss
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
        # warm up schedule; scheduler_warmup is chained with scheduler_steplr
        scheduler_steplr = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[0, 15], gamma=0.1, last_epoch=-1)
        if args.warmup:
            scheduler_warmup = GradualWarmupScheduler(
                optimizer,
                multiplier=1,
                total_epoch=50,
                after_scheduler=scheduler_steplr)  # 20K=(idx)56, 35K=70

        # Print the table of Nonzeros in each layer
        comp1 = utils.print_nonzeros(model)
        comp[_ite] = comp1
        pbar = tqdm(range(args.end_iter))  # progress bar

        for iter_ in pbar:

            # Frequency for Testing
            if iter_ % args.valid_freq == 0:
                accuracy = test(model, test_loader, criterion)

                # Save Weights for each _ite
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    if args.save_dir:
                        torch.save(
                            model.state_dict(),
                            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{args.save_dir}/{_ite}_model_{args.prune_type}.pth"
                        )
                    else:
                        # torch.save(model,f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}.pth")
                        torch.save(
                            model.state_dict(),
                            f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}.pth"
                        )

            # Training
            loss = train(model, train_loader, optimizer, criterion,
                         total_params)
            all_loss[iter_] = loss
            all_accuracy[iter_] = accuracy

            # warm up
            if args.warmup:
                scheduler_warmup.step()
            _lr = optimizer.param_groups[0]['lr']

            # Save the model during training
            if args.save_freq > 0 and iter_ % args.save_freq == 0:
                if args.save_dir:
                    torch.save(
                        model.state_dict(),
                        f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{args.save_dir}/{_ite}_model_{args.prune_type}_epoch{iter_}.pth"
                    )
                else:
                    torch.save(
                        model.state_dict(),
                        f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}_epoch{iter_}.pth"
                    )

            # Frequency for Printing Accuracy and Loss
            if iter_ % args.print_freq == 0:
                pbar.set_description(
                    f'Train Epoch: {iter_}/{args.end_iter} Loss: {loss:.6f} Accuracy: {accuracy:.2f}% Best Accuracy: {best_accuracy:.2f}% Learning Rate: {_lr:.6f}'
                )

        writer.add_scalar('Accuracy/test', best_accuracy, comp1)
        bestacc[_ite] = best_accuracy

        # Plotting Loss (Training), Accuracy (Testing), Iteration Curve
        # NOTE: loss is computed every iteration, while accuracy is only computed every args.valid_freq iterations, so the stored accuracy stays constant in between.
        # NOTE: the loss is rescaled to [0, 100] below so it can share an axis with accuracy.
        plt.plot(np.arange(1, (args.end_iter) + 1),
                 100 * (all_loss - np.min(all_loss)) /
                 np.ptp(all_loss).astype(float),
                 c="blue",
                 label="Loss")
        plt.plot(np.arange(1, (args.end_iter) + 1),
                 all_accuracy,
                 c="red",
                 label="Accuracy")
        plt.title(
            f"Loss Vs Accuracy Vs Iterations ({args.dataset},{args.arch_type})"
        )
        plt.xlabel("Iterations")
        plt.ylabel("Loss and Accuracy")
        plt.legend()
        plt.grid(color="gray")
        if args.save_dir:
            plt.savefig(
                f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_LossVsAccuracy_{comp1}.png",
                dpi=1200)
        else:
            plt.savefig(
                f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_LossVsAccuracy_{comp1}.png",
                dpi=1200)
        plt.close()

        # Dump Plot values
        if args.save_dir:
            all_loss.dump(
                f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_all_loss_{comp1}.dat"
            )
            all_accuracy.dump(
                f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_all_accuracy_{comp1}.dat"
            )
        else:
            all_loss.dump(
                f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_loss_{comp1}.dat"
            )
            all_accuracy.dump(
                f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_accuracy_{comp1}.dat"
            )

        # Dumping mask
        if args.save_dir:
            with open(
                    f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_mask_{comp1}.pkl",
                    'wb') as fp:
                pickle.dump(mask, fp)
        else:
            with open(
                    f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_mask_{comp1}.pkl",
                    'wb') as fp:
                pickle.dump(mask, fp)

        # Reset accumulators for the next pruning level
        best_accuracy = 0
        all_loss = np.zeros(args.end_iter, float)
        all_accuracy = np.zeros(args.end_iter, float)

    # Dumping Values for Plotting
    if args.save_dir:
        comp.dump(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_compression.dat"
        )
        bestacc.dump(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_bestaccuracy.dat"
        )
    else:
        comp.dump(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_compression.dat"
        )
        bestacc.dump(
            f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_bestaccuracy.dat"
        )
    # Plotting
    a = np.arange(args.prune_iterations)
    plt.plot(a, bestacc, c="blue", label="Winning tickets")
    plt.title(
        f"Test Accuracy vs Unpruned Weights Percentage ({args.dataset},{args.arch_type})"
    )
    plt.xlabel("Unpruned Weights Percentage")
    plt.ylabel("test accuracy")
    plt.xticks(a, comp, rotation="vertical")
    plt.ylim(0, 100)
    plt.legend()
    plt.grid(color="gray")
    if args.save_dir:
        plt.savefig(
            f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.save_dir}/{args.prune_type}_AccuracyVsWeights.png",
            dpi=1200)
    else:
        plt.savefig(
            f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_AccuracyVsWeights.png",
            dpi=1200)
    plt.close()
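
Code example #2 calls make_mask, prune_by_percentile and original_initialization, which are defined elsewhere in that repository. The sketch below is a simplified, hypothetical rendering of the non-Hessian branch (layer-wise magnitude pruning with a global mask list and the global model set in main); it is not the author's actual implementation:

import numpy as np
import torch

mask = []  # one binary mask per weight tensor, kept global as in the excerpt


def make_mask(model, total_params=None):
    # Start from an all-ones mask for every weight tensor.
    global mask
    mask = [np.ones_like(param.data.cpu().numpy())
            for name, param in model.named_parameters() if 'weight' in name]


def prune_by_percentile(percent, **kwargs):
    # Zero the smallest `percent`% (by magnitude) of the surviving weights in each layer;
    # the Hessian-aware variant referenced in the excerpt would rank weights differently.
    # Relies on the module-level `model` set via `global model` in main, as in the excerpt.
    global mask
    step = 0
    for name, param in model.named_parameters():
        if 'weight' in name:
            weights = param.data.cpu().numpy()
            alive = weights[np.nonzero(weights)]
            threshold = np.percentile(np.abs(alive), percent)
            mask[step] = np.where(np.abs(weights) < threshold, 0, mask[step])
            param.data = torch.from_numpy(weights * mask[step]).to(param.device)
            step += 1


def original_initialization(mask_temp, initial_state_dict, total_params=None):
    # Reset surviving weights to their initial values (the lottery-ticket "rewind").
    step = 0
    for name, param in model.named_parameters():
        if 'weight' in name:
            init = initial_state_dict[name].cpu().numpy()
            param.data = torch.from_numpy(mask_temp[step] * init).to(param.device)
            step += 1
        elif 'bias' in name:
            param.data = initial_state_dict[name].clone()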
Code example #3
def main(seed):
    with timer('load data'):
        df = pd.read_csv(FOLD_PATH)

    with timer('preprocessing'):
        train_df, val_df = df[df.fold_id != FOLD_ID], df[df.fold_id == FOLD_ID]

        train_augmentation = Compose([
            Flip(p=0.5),
            OneOf([
                GridDistortion(p=0.5),
                OpticalDistortion(p=0.5, distort_limit=2, shift_limit=0.5)
            ],
                  p=0.5),
            OneOf([
                RandomGamma(gamma_limit=(100, 140), p=0.5),
                RandomBrightnessContrast(p=0.5),
                RandomBrightness(p=0.5),
                RandomContrast(p=0.5)
            ],
                  p=0.5),
            OneOf([
                GaussNoise(p=0.5),
                Cutout(num_holes=10, max_h_size=10, max_w_size=20, p=0.5)
            ],
                  p=0.5),
            ShiftScaleRotate(rotate_limit=20, p=0.5),
        ])
        val_augmentation = None

        train_dataset = SeverDataset(train_df,
                                     IMG_DIR,
                                     IMG_SIZE,
                                     N_CLASSES,
                                     id_colname=ID_COLUMNS,
                                     transforms=train_augmentation,
                                     crop_rate=1.0)
        val_dataset = SeverDataset(val_df,
                                   IMG_DIR,
                                   IMG_SIZE,
                                   N_CLASSES,
                                   id_colname=ID_COLUMNS,
                                   transforms=val_augmentation)
        train_sampler = MaskProbSampler(train_df, demand_non_empty_proba=0.6)
        train_loader = DataLoader(train_dataset,
                                  batch_size=BATCH_SIZE,
                                  sampler=train_sampler,
                                  num_workers=8)
        val_loader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=False,
                                num_workers=8)

        del train_df, val_df, df, train_dataset, val_dataset
        gc.collect()

    with timer('create model'):
        model = smp.Unet('se_resnext50_32x4d',
                         encoder_weights="imagenet",
                         classes=N_CLASSES,
                         encoder_se_module=True,
                         decoder_semodule=True,
                         h_columns=False,
                         skip=True,
                         act="swish")
        model = convert_model(model)
        if base_model is not None:
            model.load_state_dict(torch.load(base_model))
        model.to(device)

        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
        if base_model is None:
            scheduler_cosine = CosineAnnealingLR(optimizer,
                                                 T_max=CLR_CYCLE,
                                                 eta_min=3e-5)
            scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=1.1,
                total_epoch=CLR_CYCLE * 2,
                after_scheduler=scheduler_cosine)
        else:
            scheduler = CosineAnnealingLR(optimizer,
                                          T_max=CLR_CYCLE,
                                          eta_min=3e-5)

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)
        model = torch.nn.DataParallel(model)

    with timer('train'):
        train_losses = []
        valid_losses = []

        best_model_loss = 999
        best_model_ep = 0
        checkpoint = base_ckpt + 1

        for epoch in range(1, EPOCHS + 1):
            seed = seed + epoch
            seed_torch(seed)

            LOGGER.info("Starting {} epoch...".format(epoch))
            tr_loss = train_one_epoch(model,
                                      train_loader,
                                      criterion,
                                      optimizer,
                                      device,
                                      cutmix_prob=0.0)
            train_losses.append(tr_loss)
            LOGGER.info('Mean train loss: {}'.format(round(tr_loss, 5)))

            valid_loss = validate(model, val_loader, criterion, device)
            valid_losses.append(valid_loss)
            LOGGER.info('Mean valid loss: {}'.format(round(valid_loss, 5)))

            scheduler.step()

            if valid_loss < best_model_loss:
                torch.save(
                    model.module.state_dict(),
                    'models/{}_fold{}_ckpt{}.pth'.format(
                        EXP_ID, FOLD_ID, checkpoint))
                best_model_loss = valid_loss
                best_model_ep = epoch
                #np.save("val_pred.npy", val_pred)

            if epoch % (CLR_CYCLE * 2) == CLR_CYCLE * 2 - 1:
                torch.save(
                    model.module.state_dict(),
                    'models/{}_fold{}_latest.pth'.format(EXP_ID, FOLD_ID))
                LOGGER.info('Best valid loss: {} on epoch={}'.format(
                    round(best_model_loss, 5), best_model_ep))
                checkpoint += 1
                best_model_loss = 999

            #del val_pred
            gc.collect()

    LOGGER.info('Best valid loss: {} on epoch={}'.format(
        round(best_model_loss, 5), best_model_ep))

    xs = list(range(1, len(train_losses) + 1))
    plt.plot(xs, train_losses, label='Train loss')
    plt.plot(xs, valid_losses, label='Val loss')
    plt.legend()
    plt.xticks(xs)
    plt.xlabel('Epochs')
    plt.savefig("loss.png")
Code example #4
def main(seed):
    with timer('load data'):
        df = pd.read_csv(FOLD_PATH)
        df.drop("EncodedPixels_2", axis=1, inplace=True)
        df = df.rename(columns={"EncodedPixels_3": "EncodedPixels_2"})
        df = df.rename(columns={"EncodedPixels_4": "EncodedPixels_3"})
        y1 = (df.EncodedPixels_1 != "-1").astype("float32").values.reshape(
            -1, 1)
        y2 = (df.EncodedPixels_2 != "-1").astype("float32").values.reshape(
            -1, 1)
        y3 = (df.EncodedPixels_3 != "-1").astype("float32").values.reshape(
            -1, 1)
        #y4 = (df.EncodedPixels_4 != "-1").astype("float32").values.reshape(-1, 1)
        y = np.concatenate([y1, y2, y3], axis=1)

    with timer('preprocessing'):
        train_df, val_df = df[df.fold_id != FOLD_ID], df[df.fold_id == FOLD_ID]
        y_train, y_val = y[df.fold_id != FOLD_ID], y[df.fold_id == FOLD_ID]

        train_augmentation = Compose([
            Flip(p=0.5),
            OneOf([
                GridDistortion(p=0.5),
                OpticalDistortion(p=0.5, distort_limit=2, shift_limit=0.5)
            ],
                  p=0.5),
            OneOf([
                RandomGamma(gamma_limit=(100, 140), p=0.5),
                RandomBrightnessContrast(p=0.5),
            ],
                  p=0.5),
            OneOf([
                GaussNoise(p=0.5),
            ], p=0.5),
            ShiftScaleRotate(rotate_limit=20, p=0.5),
        ])
        val_augmentation = None

        train_dataset = SeverDataset(train_df,
                                     IMG_DIR,
                                     IMG_SIZE,
                                     N_CLASSES,
                                     id_colname=ID_COLUMNS,
                                     transforms=train_augmentation,
                                     crop_rate=1.0,
                                     class_y=y_train)
        val_dataset = SeverDataset(val_df,
                                   IMG_DIR,
                                   IMG_SIZE,
                                   N_CLASSES,
                                   id_colname=ID_COLUMNS,
                                   transforms=val_augmentation)
        train_sampler = MaskProbSampler(train_df, demand_non_empty_proba=0.6)
        train_loader = DataLoader(train_dataset,
                                  batch_size=BATCH_SIZE,
                                  sampler=train_sampler,
                                  num_workers=8)
        val_loader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=False,
                                num_workers=8)

        del train_df, val_df, df, train_dataset, val_dataset
        gc.collect()

    with timer('create model'):
        model = smp.Unet('resnet34',
                         encoder_weights="imagenet",
                         classes=N_CLASSES,
                         encoder_se_module=True,
                         decoder_semodule=True,
                         h_columns=False,
                         skip=True,
                         act="swish",
                         freeze_bn=True,
                         classification=CLASSIFICATION,
                         attention_type="cbam",
                         center=True,
                         mode="train")
        model = convert_model(model)
        if base_model is not None:
            model.load_state_dict(torch.load(base_model))
        model.to(device)

        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam([
            {
                'params': model.decoder.parameters(),
                'lr': 3e-3
            },
            {
                'params': model.encoder.parameters(),
                'lr': 3e-4
            },
        ])
        if base_model is None:
            scheduler_cosine = CosineAnnealingLR(optimizer,
                                                 T_max=CLR_CYCLE,
                                                 eta_min=3e-5)
            scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=1.1,
                total_epoch=CLR_CYCLE * 2,
                after_scheduler=scheduler_cosine)
        else:
            scheduler = CosineAnnealingLR(optimizer,
                                          T_max=CLR_CYCLE,
                                          eta_min=3e-5)

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)

        if EMA:
            ema_model = copy.deepcopy(model)
            if base_model_ema is not None:
                ema_model.load_state_dict(torch.load(base_model_ema))
            ema_model.to(device)
            ema_model = torch.nn.DataParallel(ema_model)
        else:
            ema_model = None
        model = torch.nn.DataParallel(model)

    with timer('train'):
        train_losses = []
        valid_losses = []

        best_model_loss = 999
        best_model_ema_loss = 999
        best_model_ep = 0
        ema_decay = 0
        checkpoint = base_ckpt + 1

        for epoch in range(1, EPOCHS + 1):
            seed = seed + epoch
            seed_torch(seed)

            if epoch >= EMA_START:
                ema_decay = 0.99

            LOGGER.info("Starting {} epoch...".format(epoch))
            tr_loss = train_one_epoch(model,
                                      train_loader,
                                      criterion,
                                      optimizer,
                                      device,
                                      cutmix_prob=0.0,
                                      classification=CLASSIFICATION,
                                      ema_model=ema_model,
                                      ema_decay=ema_decay)
            train_losses.append(tr_loss)
            LOGGER.info('Mean train loss: {}'.format(round(tr_loss, 5)))

            valid_loss = validate(model,
                                  val_loader,
                                  criterion,
                                  device,
                                  classification=CLASSIFICATION)
            valid_losses.append(valid_loss)
            LOGGER.info('Mean valid loss: {}'.format(round(valid_loss, 5)))

            if EMA and epoch >= EMA_START:
                ema_valid_loss = validate(ema_model,
                                          val_loader,
                                          criterion,
                                          device,
                                          classification=CLASSIFICATION)
                LOGGER.info('Mean EMA valid loss: {}'.format(
                    round(ema_valid_loss, 5)))

                if ema_valid_loss < best_model_ema_loss:
                    torch.save(
                        ema_model.module.state_dict(),
                        'models/{}_fold{}_ckpt{}_ema.pth'.format(
                            EXP_ID, FOLD_ID, checkpoint))
                    best_model_ema_loss = ema_valid_loss

            scheduler.step()

            if valid_loss < best_model_loss:
                torch.save(
                    model.module.state_dict(),
                    'models/{}_fold{}_ckpt{}.pth'.format(
                        EXP_ID, FOLD_ID, checkpoint))
                best_model_loss = valid_loss
                best_model_ep = epoch
                #np.save("val_pred.npy", val_pred)

            if epoch % (CLR_CYCLE * 2) == CLR_CYCLE * 2 - 1:
                torch.save(
                    model.module.state_dict(),
                    'models/{}_fold{}_latest.pth'.format(EXP_ID, FOLD_ID))
                LOGGER.info('Best valid loss: {} on epoch={}'.format(
                    round(best_model_loss, 5), best_model_ep))
                if EMA:
                    torch.save(
                        ema_model.module.state_dict(),
                        'models/{}_fold{}_latest_ema.pth'.format(
                            EXP_ID, FOLD_ID))
                    LOGGER.info('Best ema valid loss: {}'.format(
                        round(best_model_ema_loss, 5)))
                    best_model_ema_loss = 999
                checkpoint += 1
                best_model_loss = 999

            #del val_pred
            gc.collect()

    LOGGER.info('Best valid loss: {} on epoch={}'.format(
        round(best_model_loss, 5), best_model_ep))

    xs = list(range(1, len(train_losses) + 1))
    plt.plot(xs, train_losses, label='Train loss')
    plt.plot(xs, valid_losses, label='Val loss')
    plt.legend()
    plt.xticks(xs)
    plt.xlabel('Epochs')
    plt.savefig("loss.png")
Code example #5
def main(seed):
    with timer('load data'):
        df = pd.read_csv(FOLD_PATH)
        soft_df = pd.read_csv(SOFT_PATH)
        df = df.append(pd.read_csv(PSEUDO_PATH)).reset_index(drop=True)
        soft_df = soft_df.append(
            pd.read_csv(PSEUDO_PATH)).reset_index(drop=True)
        soft_df = df[[ID_COLUMNS]].merge(soft_df, how="left", on=ID_COLUMNS)
        LOGGER.info(df.head())
        LOGGER.info(soft_df.head())
        for c in [
                "EncodedPixels_1", "EncodedPixels_2", "EncodedPixels_3",
                "EncodedPixels_4"
        ]:
            df[c] = df[c].astype(str)
            soft_df[c] = soft_df[c].astype(str)
        df["fold_id"] = df["fold_id"].fillna(FOLD_ID + 1)
        y = (df.sum_target != 0).astype("float32").values
        y += (soft_df.sum_target != 0).astype("float32").values
        y = y / 2

    with timer('preprocessing'):
        train_df, val_df = df[df.fold_id != FOLD_ID], df[df.fold_id == FOLD_ID]
        train_soft_df, val_soft_df = soft_df[df.fold_id != FOLD_ID], soft_df[
            df.fold_id == FOLD_ID]
        y_train, y_val = y[df.fold_id != FOLD_ID], y[df.fold_id == FOLD_ID]

        train_augmentation = Compose([
            Flip(p=0.5),
            OneOf([
                GridDistortion(p=0.5),
                OpticalDistortion(p=0.5, distort_limit=2, shift_limit=0.5)
            ],
                  p=0.5),
            OneOf([
                RandomGamma(gamma_limit=(100, 140), p=0.5),
                RandomBrightnessContrast(p=0.5),
            ],
                  p=0.5),
            OneOf([
                GaussNoise(p=0.5),
            ], p=0.5),
            ShiftScaleRotate(rotate_limit=20, p=0.5),
        ])
        val_augmentation = None

        train_dataset = SeverDataset(train_df,
                                     IMG_DIR,
                                     IMG_SIZE,
                                     N_CLASSES,
                                     id_colname=ID_COLUMNS,
                                     transforms=train_augmentation,
                                     crop_rate=1.0,
                                     class_y=y_train,
                                     soft_df=train_soft_df)
        val_dataset = SeverDataset(val_df,
                                   IMG_DIR,
                                   IMG_SIZE,
                                   N_CLASSES,
                                   id_colname=ID_COLUMNS,
                                   transforms=val_augmentation,
                                   soft_df=val_soft_df)
        train_sampler = MaskProbSampler(train_df, demand_non_empty_proba=0.6)
        train_loader = DataLoader(train_dataset,
                                  batch_size=BATCH_SIZE,
                                  sampler=train_sampler,
                                  num_workers=8)
        val_loader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=False,
                                num_workers=8)

        del train_df, val_df, df, train_dataset, val_dataset
        gc.collect()

    with timer('create model'):
        model = smp_old.Unet('resnet34',
                             encoder_weights="imagenet",
                             classes=N_CLASSES,
                             encoder_se_module=True,
                             decoder_semodule=True,
                             h_columns=False,
                             skip=True,
                             act="swish",
                             freeze_bn=True,
                             classification=CLASSIFICATION)
        model = convert_model(model)
        if base_model is not None:
            model.load_state_dict(torch.load(base_model))
        model.to(device)

        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam([
            {
                'params': model.decoder.parameters(),
                'lr': 3e-3
            },
            {
                'params': model.encoder.parameters(),
                'lr': 3e-4
            },
        ])
        if base_model is None:
            scheduler_cosine = CosineAnnealingLR(optimizer,
                                                 T_max=CLR_CYCLE,
                                                 eta_min=3e-5)
            scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=1.1,
                total_epoch=CLR_CYCLE * 2,
                after_scheduler=scheduler_cosine)
        else:
            scheduler = CosineAnnealingLR(optimizer,
                                          T_max=CLR_CYCLE,
                                          eta_min=3e-5)

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)
        model = torch.nn.DataParallel(model)

    with timer('train'):
        train_losses = []
        valid_losses = []

        best_model_loss = 999
        best_model_ep = 0
        checkpoint = base_ckpt + 1

        for epoch in range(1, EPOCHS + 1):
            seed = seed + epoch
            seed_torch(seed)

            LOGGER.info("Starting {} epoch...".format(epoch))
            tr_loss = train_one_epoch(model,
                                      train_loader,
                                      criterion,
                                      optimizer,
                                      device,
                                      cutmix_prob=0.0,
                                      classification=CLASSIFICATION)
            train_losses.append(tr_loss)
            LOGGER.info('Mean train loss: {}'.format(round(tr_loss, 5)))

            valid_loss, val_score = validate(model,
                                             val_loader,
                                             criterion,
                                             device,
                                             classification=CLASSIFICATION)
            valid_losses.append(valid_loss)
            LOGGER.info('Mean valid loss: {}'.format(round(valid_loss, 5)))
            LOGGER.info('Mean valid score: {}'.format(round(val_score, 5)))

            scheduler.step()

            if valid_loss < best_model_loss:
                torch.save(
                    model.module.state_dict(),
                    'models/{}_fold{}_ckpt{}.pth'.format(
                        EXP_ID, FOLD_ID, checkpoint))
                best_model_loss = valid_loss
                best_model_ep = epoch
                #np.save("val_pred.npy", val_pred)

            if epoch % (CLR_CYCLE * 2) == CLR_CYCLE * 2 - 1:
                torch.save(
                    model.module.state_dict(),
                    'models/{}_fold{}_latest.pth'.format(EXP_ID, FOLD_ID))
                LOGGER.info('Best valid loss: {} on epoch={}'.format(
                    round(best_model_loss, 5), best_model_ep))
                checkpoint += 1
                best_model_loss = 999

            #del val_pred
            gc.collect()

    LOGGER.info('Best valid loss: {} on epoch={}'.format(
        round(best_model_loss, 5), best_model_ep))

    xs = list(range(1, len(train_losses) + 1))
    plt.plot(xs, train_losses, label='Train loss')
    plt.plot(xs, valid_losses, label='Val loss')
    plt.legend()
    plt.xticks(xs)
    plt.xlabel('Epochs')
    plt.savefig("loss.png")
Code example #6
                    './checkpoint/withLars-' + str(hp.batch_size) + '.pth')
            else:
                torch.save(
                    state,
                    './checkpoint/noLars-' + str(hp.batch_size) + '.pth')
            best_acc = acc

    if hp.with_lars:
        print('Resnet50, data=cifar10, With LARS')
    else:
        print('Resnet50, data=cifar10, Without LARS')
    hp.print_hyperparms()
    for epoch in range(0, hp.num_of_epoch):
        print('\nEpoch: %d' % epoch)
        if epoch <= hp.warmup_epoch:  # for readability
            warmup_scheduler.step()
        if epoch > hp.warmup_epoch:  # after warmup, run the decay scheduler starting from the warmed-up learning rate
            poly_decay_scheduler.base_lrs = warmup_scheduler.get_lr()
        for param_group in optimizer.param_groups:
            print('lr: ' + str(param_group['lr']))
        train(epoch)
        test(epoch)

        epochs.append(epoch)
        train_accs.append(100. * train_correct / train_total)
        test_accs.append(100. * test_correct / test_total)

        plt.plot(epochs, train_accs, epochs, test_accs, 'r-')
        state = {'test_acc': test_accs}

        if not os.path.isdir('result_fig'):
Code example #7
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
from archs.mnist import AlexNet, LeNet5, fc1, vgg, resnet
from scheduler import GradualWarmupScheduler

model = fc1.fc1()
optimizer = optim.SGD(params=model.parameters(), lr=0.05)
scheduler_steplr = lr_scheduler.MultiStepLR(
    optimizer, milestones=[5, 10], gamma=0.1)  # means multistep will start with 0
scheduler_warmup = GradualWarmupScheduler(
    optimizer, multiplier=1, total_epoch=10,
    after_scheduler=scheduler_steplr)  # 20K=(idx)56, 35K=70

plt.figure()
x = list(range(20))
y = []

for epoch in range(20):
    scheduler_warmup.step()
    lr = scheduler_warmup.get_lr()[0]
    print(epoch, lr)
    y.append(lr)

plt.plot(x, y)
plt.show()
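
For reference, GradualWarmupScheduler (used in examples #2, #7, #9 and the segmentation scripts) ramps the learning rate linearly over total_epoch epochs and then hands control to after_scheduler. The standalone sketch below reproduces the ramp as implemented in the widely used ildoonet pytorch-gradual-warmup-lr package; treat the exact formula as an assumption and check the version you actually install:

def warmup_lr(base_lr, multiplier, total_epoch, epoch):
    # multiplier == 1: ramp from 0 up to base_lr over total_epoch epochs.
    # multiplier  > 1: ramp from base_lr up to multiplier * base_lr.
    if epoch >= total_epoch:
        return base_lr * multiplier
    if multiplier == 1.0:
        return base_lr * epoch / total_epoch
    return base_lr * ((multiplier - 1.0) * epoch / total_epoch + 1.0)


# e.g. example #9 warms up 8x over 10 epochs before the cosine schedule takes over:
print([round(warmup_lr(0.01, 8, 10, e), 4) for e in range(12)])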
Code example #8
class CTTTrainer(WandBMixin, IOMixin, BaseExperiment):
    WANDB_PROJECT = "ctt"

    def __init__(self):
        super(CTTTrainer, self).__init__()
        self.auto_setup()
        self._build()

    def _build(self):
        self._build_loaders()
        self._build_model()
        self._build_criteria_and_optim()
        self._build_scheduler()

    def _build_model(self):
        self.model: nn.Module = to_device(
            ContactTracingTransformer(**self.get("model/kwargs", {})),
            self.device)

    def _build_loaders(self):
        train_path = self.get("data/paths/train", ensure_exists=True)
        validate_path = self.get("data/paths/validate", ensure_exists=True)
        self.train_loader = get_dataloader(path=train_path,
                                           **self.get("data/loader_kwargs",
                                                      ensure_exists=True))
        self.validate_loader = get_dataloader(path=validate_path,
                                              **self.get("data/loader_kwargs",
                                                         ensure_exists=True))

    def _build_criteria_and_optim(self):
        # noinspection PyArgumentList
        self.loss = WeightedSum.from_config(
            self.get("losses", ensure_exists=True))
        self.optim = torch.optim.Adam(self.model.parameters(),
                                      **self.get("optim/kwargs"))
        self.metrics = Metrics()

    def _build_scheduler(self):
        if self.get("scheduler/use", False):
            self._base_scheduler = CosineAnnealingLR(
                self.optim,
                T_max=self.get("training/num_epochs"),
                **self.get("scheduler/kwargs", {}),
            )
        else:
            self._base_scheduler = None
        # Support for LR warmup
        if self.get("scheduler/warmup", False):
            assert self._base_scheduler is not None
            self.scheduler = GradualWarmupScheduler(
                self.optim,
                multiplier=1,
                total_epoch=5,
                after_scheduler=self._base_scheduler,
            )
        else:
            self.scheduler = self._base_scheduler

    @property
    def device(self):
        return self.get("device", "cpu")

    @register_default_dispatch
    def train(self):
        if self.get("wandb/use", True):
            self.initialize_wandb()
        for epoch in self.progress(range(
                self.get("training/num_epochs", ensure_exists=True)),
                                   tag="epochs"):
            self.log_learning_rates()
            self.train_epoch()
            validation_stats = self.validate_epoch()
            self.checkpoint()
            self.log_progress("epochs", **validation_stats)
            self.step_scheduler(epoch)
            self.next_epoch()

    def train_epoch(self):
        self.clear_moving_averages()
        self.model.train()
        for model_input in self.progress(self.train_loader, tag="train"):
            # Evaluate model
            model_input = to_device(model_input, self.device)
            model_output = Dict(self.model(model_input))
            # Compute loss
            losses = self.loss(model_input, model_output)
            loss = losses.loss
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()
            # Log to wandb (if required)
            self.log_training_losses(losses)
            # Log to pbar
            self.accumulate_in_cache("moving_loss", loss.item(),
                                     momentum_accumulator(0.9))
            self.log_progress(
                "train",
                loss=self.read_from_cache("moving_loss"),
            )
            self.next_step()

    def validate_epoch(self):
        all_losses_and_metrics = defaultdict(list)
        self.metrics.reset()
        self.model.eval()
        for model_input in self.progress(self.validate_loader,
                                         tag="validation"):
            with torch.no_grad():
                model_input = to_device(model_input, self.device)
                model_output = Dict(self.model(model_input))
                losses = self.loss(model_input, model_output)
                self.metrics.update(model_input, model_output)
                all_losses_and_metrics["loss"].append(losses.loss.item())
                for key in losses.unweighted_losses:
                    all_losses_and_metrics[key].append(
                        losses.unweighted_losses[key].item())
        # Compute mean for all losses
        all_losses_and_metrics = Dict(
            {key: np.mean(val)
             for key, val in all_losses_and_metrics.items()})
        all_losses_and_metrics.update(Dict(self.metrics.evaluate()))
        self.log_validation_losses_and_metrics(all_losses_and_metrics)
        # Store the validation loss in cache. This will be used for checkpointing.
        self.write_to_cache("current_validation_loss",
                            all_losses_and_metrics.loss)
        return all_losses_and_metrics

    def log_training_losses(self, losses):
        if not self.get("wandb/use", True):
            return self
        if self.log_wandb_now:
            metrics = Dict({"training_loss": losses.loss})
            metrics.update({
                f"training_{k}": v
                for k, v in losses.unweighted_losses.items()
            })
            self.wandb_log(**metrics)
        return self

    def checkpoint(self, force=False):
        current_validation_loss = self.read_from_cache(
            "current_validation_loss", float("inf"))
        best_validation_loss = self.read_from_cache("best_validation_loss",
                                                    float("inf"))
        if current_validation_loss < best_validation_loss:
            self.write_to_cache("best_validation_loss",
                                current_validation_loss)
            ckpt_path = os.path.join(self.checkpoint_directory, "best.ckpt")
        elif self.get_arg("force_checkpoint", force):
            ckpt_path = os.path.join(self.checkpoint_directory, "best.ckpt")
        else:
            ckpt_path = None
        if ckpt_path is not None:
            info_dict = {
                "model": self.model.state_dict(),
                "optim": self.optim.state_dict(),
            }
            torch.save(info_dict, ckpt_path)
        return self

    def load(self, device=None):
        ckpt_path = os.path.join(self.checkpoint_directory, "best.ckpt")
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError
        info_dict = torch.load(
            ckpt_path,
            map_location=torch.device(
                (self.device if device is None else device)),
        )
        self.model.load_state_dict(info_dict["model"])
        self.optim.load_state_dict(info_dict["optim"])
        return self

    def log_validation_losses_and_metrics(self, losses):
        if not self.get("wandb/use", True):
            return self
        metrics = {f"validation_{k}": v for k, v in losses.items()}
        self.wandb_log(**metrics)
        return self

    def clear_moving_averages(self):
        return self.clear_in_cache("moving_loss")

    def step_scheduler(self, epoch):
        if self.scheduler is not None:
            self.scheduler.step(epoch)
        return self

    def log_learning_rates(self):
        if not self.get("wandb/use", True):
            return self
        lrs = {
            f"lr_{i}": param_group["lr"]
            for i, param_group in enumerate(self.optim.param_groups)
        }
        self.wandb_log(**lrs)
        return self
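
The training loop above smooths the progress-bar loss with momentum_accumulator(0.9), a helper that is not included in this excerpt. The sketch below is a plausible minimal version of such an exponential-moving-average accumulator, written to match how it is called via accumulate_in_cache("moving_loss", loss.item(), momentum_accumulator(0.9)); the name and behavior of the real helper are assumptions, not taken from the original code.

def momentum_accumulator(momentum=0.9):
    # Hypothetical sketch: returns a callable that blends a cached value with
    # a new observation, as expected by accumulate_in_cache(key, value, fn).
    def accumulate(old_value, new_value):
        if old_value is None:
            # No cached value yet: start the moving average at the first observation.
            return new_value
        return momentum * old_value + (1.0 - momentum) * new_value
    return accumulate

# Usage sketch: smooth a noisy loss sequence.
accumulate = momentum_accumulator(0.9)
moving_loss = None
for loss_value in [1.0, 0.8, 1.2, 0.7]:
    moving_loss = accumulate(moving_loss, loss_value)
print(moving_loss)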
Code example #9
    # Assumed context (not shown in this snippet): `import torch`,
    # `import matplotlib.pyplot as plt`, and a GradualWarmupScheduler
    # implementation (e.g. the warmup_scheduler package).
    D_in, H, D_out = 64, 100, 10  # example dimensions; not defined in the original snippet

    # Use the nn package to define our model.
    model = torch.nn.Sequential(
        torch.nn.Linear(D_in, H),
        torch.nn.ReLU(),
        torch.nn.Linear(H, D_out),
    )

    optim = torch.optim.SGD(model.parameters(), lr=0.01)
    max_epoch = 100

    # Cosine annealing is wrapped by a gradual warmup: the LR is ramped up
    # over the first `total_epoch` epochs, then handed to the cosine schedule.
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optim, max_epoch)
    scheduler = GradualWarmupScheduler(optimizer=optim,
                                       multiplier=8,
                                       total_epoch=10,
                                       after_scheduler=scheduler_cosine)

    # Record and plot the learning rate at each epoch.
    x = []
    y = []
    for epoch in range(1, max_epoch):
        scheduler.step(epoch)
        x.append(epoch)
        y.append(optim.param_groups[0]['lr'])
        print(optim.param_groups[0]['lr'])

    plt.scatter(x, y, color='red')
    plt.show()
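
Note that the snippet above only ever calls scheduler.step(epoch) and never optim.step(), so PyTorch 1.1 and later may print a warning about the scheduler being stepped before the optimizer; for a plot of the schedule this is harmless. The variant below is a self-contained sketch that steps the optimizer first and prints the learning rate instead of plotting it. It assumes GradualWarmupScheduler is provided by the warmup_scheduler package; the original file may define the class itself.

import torch
from warmup_scheduler import GradualWarmupScheduler  # assumption: source of the class used above

model = torch.nn.Linear(4, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.01)
# Cosine annealing takes over once the warmup phase (total_epoch epochs) is done.
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=100)
scheduler = GradualWarmupScheduler(optimizer=optim,
                                   multiplier=8,
                                   total_epoch=10,
                                   after_scheduler=scheduler_cosine)

for epoch in range(1, 101):
    # In real training the forward/backward pass would run here.
    optim.step()           # step the optimizer before the scheduler
    scheduler.step(epoch)  # warm up towards 8x the base LR, then cosine decay
    print(epoch, optim.param_groups[0]["lr"])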
Code example #10
        # sample train
        train_dataloader, _ = data_loader(root=DATASET_PATH,
                                          phase='train',
                                          batch_size=batch)
        for iter_, data in enumerate(train_dataloader, 0):
            iter1_, img0, iter2_, img1, label = data
            img0, img1, label = img0.cuda(), img1.cuda(), label.cuda()
            optimizer.zero_grad()
            model.train()

            output1, output2 = model(img0, img1)
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            optimizer.step()
            # cosine scheduler
            scheduler.step()
            if iter_ % print_iter == 0:
                elapsed = datetime.datetime.now() - time_
                expected = elapsed * (num_batches / print_iter)
                _epoch = epoch + ((iter_ + 1) / num_batches)
                print('[{:.3f}/{:d}] loss({}) '
                      'elapsed {} expected per epoch {}'.format(
                          _epoch, num_epochs, loss_contrastive.item(), elapsed,
                          expected))
        # If using StepLR instead, step the scheduler once per epoch here:
        # scheduler.step()
        save_model(str(epoch + 1), model, optimizer)
        if (epoch + 1) % 1 == 0:
            # evaluation
            eval_loss = 0.0
            nb_eval_steps = 0
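
The loop above computes loss_contrastive = criterion(output1, output2, label) on the two embeddings returned by the siamese forward pass, but the criterion itself is not part of this excerpt. The following is a plausible minimal sketch of a standard margin-based contrastive loss with that call signature; the original project's implementation and label convention may differ.

import torch
import torch.nn.functional as F

class ContrastiveLoss(torch.nn.Module):
    # Hypothetical sketch: label == 1 marks dissimilar pairs, label == 0 similar
    # pairs (the convention actually used by the original code is not shown).
    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        label = label.float().view(-1)
        distance = F.pairwise_distance(output1, output2)  # shape: (batch,)
        similar_term = (1 - label) * distance.pow(2)
        dissimilar_term = label * torch.clamp(self.margin - distance, min=0.0).pow(2)
        return (similar_term + dissimilar_term).mean()

# Usage sketch:
# criterion = ContrastiveLoss(margin=2.0)
# loss_contrastive = criterion(output1, output2, label)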
Code example #11
def main():
    with timer('load data'):
        df = pd.read_csv(FOLD_PATH)
        y = (df["sum_target"] != 0).values.astype("float32")

    with timer('preprocessing'):
        train_df, val_df = df[df.fold_id != FOLD_ID], df[df.fold_id == FOLD_ID]
        y_train, y_val = y[df.fold_id != FOLD_ID], y[df.fold_id == FOLD_ID]

        train_augmentation = Compose([
            Flip(p=0.5),
            OneOf(
                [
                    #ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
                    GridDistortion(p=0.5),
                    OpticalDistortion(p=0.5, distort_limit=2, shift_limit=0.5)
                ],
                p=0.5),
            #OneOf([
            #    ShiftScaleRotate(p=0.5),
            ##    RandomRotate90(p=0.5),
            #    Rotate(p=0.5)
            #], p=0.5),
            OneOf([
                Blur(blur_limit=8, p=0.5),
                MotionBlur(blur_limit=8, p=0.5),
                MedianBlur(blur_limit=8, p=0.5),
                GaussianBlur(blur_limit=8, p=0.5)
            ],
                  p=0.5),
            OneOf(
                [
                    #CLAHE(clip_limit=4, tile_grid_size=(4, 4), p=0.5),
                    RandomGamma(gamma_limit=(100, 140), p=0.5),
                    RandomBrightnessContrast(p=0.5),
                    RandomBrightness(p=0.5),
                    RandomContrast(p=0.5)
                ],
                p=0.5),
            OneOf([
                GaussNoise(p=0.5),
                Cutout(num_holes=10, max_h_size=10, max_w_size=20, p=0.5)
            ],
                  p=0.5)
        ])
        val_augmentation = None

        train_dataset = SeverDataset(train_df,
                                     IMG_DIR,
                                     IMG_SIZE,
                                     N_CLASSES,
                                     id_colname=ID_COLUMNS,
                                     transforms=train_augmentation,
                                     class_y=y_train)
        val_dataset = SeverDataset(val_df,
                                   IMG_DIR,
                                   IMG_SIZE,
                                   N_CLASSES,
                                   id_colname=ID_COLUMNS,
                                   transforms=val_augmentation,
                                   class_y=y_val)
        train_loader = DataLoader(train_dataset,
                                  batch_size=BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=2)
        val_loader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=False,
                                num_workers=2)

        del train_df, val_df, df, train_dataset, val_dataset
        gc.collect()

    with timer('create model'):
        model = smp.UnetPP('se_resnext50_32x4d',
                           encoder_weights='imagenet',
                           classes=N_CLASSES,
                           encoder_se_module=True,
                           decoder_semodule=True,
                           h_columns=False,
                           deep_supervision=True,
                           classification=CLASSIFICATION)
        #model.load_state_dict(torch.load(model_path))
        model.to(device)

        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
        scheduler_cosine = CosineAnnealingLR(optimizer,
                                             T_max=CLR_CYCLE,
                                             eta_min=3e-5)
        scheduler = GradualWarmupScheduler(optimizer,
                                           multiplier=1.1,
                                           total_epoch=CLR_CYCLE * 2,
                                           after_scheduler=scheduler_cosine)

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)

    with timer('train'):
        train_losses = []
        valid_losses = []

        best_model_loss = 999
        best_model_ep = 0
        checkpoint = 0

        for epoch in range(1, EPOCHS + 1):
            # At each warm-restart cycle boundary, report the per-class Dice of
            # the best checkpoint from the finished cycle, then reset the
            # tracked best loss for the next cycle.
            if epoch % (CLR_CYCLE * 2) == 0:
                if epoch != 0:
                    y_val = y_val.reshape(-1, N_CLASSES, IMG_SIZE[0],
                                          IMG_SIZE[1])
                    best_pred = best_pred.reshape(-1, N_CLASSES, IMG_SIZE[0],
                                                  IMG_SIZE[1])
                    for i in range(N_CLASSES):
                        th, score, _, _ = search_threshold(
                            y_val[:, i, :, :], best_pred[:, i, :, :])
                        LOGGER.info(
                            'Best loss: {} Best Dice: {} on epoch {} th {} class {}'
                            .format(round(best_model_loss, 5), round(score, 5),
                                    best_model_ep, th, i))
                checkpoint += 1
                best_model_loss = 999

            LOGGER.info("Starting {} epoch...".format(epoch))
            tr_loss = train_one_epoch_dsv(model,
                                          train_loader,
                                          criterion,
                                          optimizer,
                                          device,
                                          classification=CLASSIFICATION)
            train_losses.append(tr_loss)
            LOGGER.info('Mean train loss: {}'.format(round(tr_loss, 5)))

            valid_loss, val_pred, y_val = validate_dsv(model, val_loader,
                                                       criterion, device)
            valid_losses.append(valid_loss)
            LOGGER.info('Mean valid loss: {}'.format(round(valid_loss, 5)))

            scheduler.step()

            if valid_loss < best_model_loss:
                torch.save(
                    model.state_dict(),
                    '{}_fold{}_ckpt{}.pth'.format(EXP_ID, FOLD_ID, checkpoint))
                best_model_loss = valid_loss
                best_model_ep = epoch
                best_pred = val_pred

            del val_pred
            gc.collect()

    with timer('eval'):
        y_val = y_val.reshape(-1, N_CLASSES, IMG_SIZE[0], IMG_SIZE[1])
        best_pred = best_pred.reshape(-1, N_CLASSES, IMG_SIZE[0], IMG_SIZE[1])
        for i in range(N_CLASSES):
            th, score, _, _ = search_threshold(y_val[:, i, :, :],
                                               best_pred[:, i, :, :])
            LOGGER.info(
                'Best loss: {} Best Dice: {} on epoch {} th {} class {}'.
                format(round(best_model_loss, 5), round(score, 5),
                       best_model_ep, th, i))

    xs = list(range(1, len(train_losses) + 1))
    plt.plot(xs, train_losses, label='Train loss')
    plt.plot(xs, valid_losses, label='Val loss')
    plt.legend()
    plt.xticks(xs)
    plt.xlabel('Epochs')
    plt.savefig("loss.png")
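
The evaluation block relies on search_threshold, which is not included in this excerpt. The sketch below is a rough, hypothetical version of such a helper: it scans a grid of binarization thresholds and keeps the one that maximizes the mean Dice score between the predicted probability maps and the ground-truth masks, returning four values to mirror how it is unpacked above. The actual implementation in the original project may differ.

import numpy as np

def dice_score(y_true, y_pred_bin, eps=1e-7):
    # Mean Dice over a batch of binary masks shaped (N, H, W).
    intersection = (y_true * y_pred_bin).sum(axis=(1, 2))
    total = y_true.sum(axis=(1, 2)) + y_pred_bin.sum(axis=(1, 2))
    return float(np.mean((2.0 * intersection + eps) / (total + eps)))

def search_threshold(y_true, y_pred, thresholds=np.arange(0.1, 0.9, 0.05)):
    # Hypothetical sketch: returns (best_threshold, best_score, thresholds, scores),
    # matching the four values unpacked as th, score, _, _ above.
    scores = [dice_score(y_true, (y_pred > th).astype("float32"))
              for th in thresholds]
    best_idx = int(np.argmax(scores))
    return thresholds[best_idx], scores[best_idx], thresholds, scores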