Example no. 1
0
def train(config):
    """Train the model described by *config*, checkpointing every epoch.

    Seeds torch (CPU and CUDA), builds dataloaders / model / scheduled
    optimizer from the config, then runs one training pass and one
    evaluation pass per epoch, logging both through ``Recorder`` and
    stepping the LR scheduler once per epoch.

    Args:
        config: dict with at least "seed", "device", "max_epoch" and
            "save_dir" keys (plus whatever data/model/optimizer read).

    Returns:
        None.
    """
    seed = config["seed"]
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    trainTransform, testTransform = data.get_transforms(config)
    trainLoader = data.get_dataloader(config,
                                      train=True,
                                      transform=trainTransform)
    testLoader = data.get_dataloader(config,
                                     train=False,
                                     transform=testTransform)

    densenet = model.get_model(config).to(config["device"])
    opt, scheduler = optimizer.get_scheduled_optimizer(config, densenet)
    criterion = CrossEntropyLoss()
    recorder = Recorder(config)

    max_epoch = config["max_epoch"]
    for epoch in range(max_epoch):
        print("epoch:{:0>3}".format(epoch))
        train_out = _train(densenet, criterion, trainLoader, config["device"],
                           opt)
        recorder(epoch, train_out=train_out)
        test_out = _test(densenet,
                         criterion,
                         testLoader,
                         config["device"],
                         need_output_y=False)
        recorder(epoch, test_out=test_out)
        scheduler.step()
        # Save only the state dict (not the pickled module) so the
        # companion test() can restore it with load_state_dict().
        torch.save(densenet.state_dict(),
                   os.path.join(config["save_dir"], "latest.pth"))
    print("train finished")
    return None
Example no. 2
0
def setup_data_loaders(args):
    """Create training and validation data loaders from parsed CLI arguments.

    Builds the stage-wise image transforms, wraps the train/val datasets
    around them, and hands both to ``dt.data.get_loaders``.

    Returns:
        Tuple of (train_loaders, val_loader).
    """
    # Stage-wise augmentation pipelines for training, one pipeline for val.
    transform_kwargs = {
        "crop_size": args.crop_size,
        "shorter_side": args.shorter_side,
        "low_scale": args.low_scale,
        "high_scale": args.high_scale,
        "img_mean": args.img_mean,
        "img_std": args.img_std,
        "img_scale": args.img_scale,
        "ignore_label": args.ignore_label,
        "num_stages": args.num_stages,
        "augmentations_type": args.augmentations_type,
        "dataset_type": args.dataset_type,
    }
    train_transforms, val_transforms = get_transforms(**transform_kwargs)

    # Datasets over the configured directories / file lists.
    dataset_kwargs = {
        "train_dir": args.train_dir,
        "val_dir": args.val_dir,
        "train_list_path": args.train_list_path,
        "val_list_path": args.val_list_path,
        "train_transforms": train_transforms,
        "val_transforms": val_transforms,
        "masks_names": ("segm", ),
        "dataset_type": args.dataset_type,
        "stage_names": args.stage_names,
        "train_download": args.train_download,
        "val_download": args.val_download,
    }
    train_sets, val_set = get_datasets(**dataset_kwargs)

    # Finally, wrap the datasets into loaders (one train loader per stage).
    loader_kwargs = {
        "train_batch_size": args.train_batch_size,
        "val_batch_size": args.val_batch_size,
        "train_set": train_sets,
        "val_set": val_set,
        "num_stages": args.num_stages,
    }
    train_loaders, val_loader = dt.data.get_loaders(**loader_kwargs)
    return train_loaders, val_loader
def test_transforms(
        augmentations_type,
        crop_size,
        dataset_type,
        num_stages,
        shorter_side,
        low_scale,
        high_scale,
        img_mean=(0.5, 0.5, 0.5),
        img_std=(0.5, 0.5, 0.5),
        img_scale=1.0 / 255,
        ignore_label=255,
):
    """Check that the generated train/val transforms behave as advertised.

    Verifies per-stage output shapes, tensor types, that no new
    segmentation classes appear (except possibly ``ignore_label``), and
    that the validation transform leaves image size and class set intact.
    """
    train_transforms, val_transforms = get_transforms(
        crop_size=broadcast(crop_size, num_stages),
        shorter_side=broadcast(shorter_side, num_stages),
        low_scale=broadcast(low_scale, num_stages),
        high_scale=broadcast(high_scale, num_stages),
        # Forward the function's own parameters instead of hard-coded
        # copies of their defaults, so callers overriding them (e.g. a
        # non-default ignore_label) are actually exercised.
        img_mean=img_mean,
        img_std=img_std,
        img_scale=img_scale,
        ignore_label=ignore_label,
        num_stages=num_stages,
        augmentations_type=augmentations_type,
        dataset_type=dataset_type,
    )
    assert len(train_transforms) == num_stages
    # One pass per training stage, then a final validation pass.
    for is_val, transform in zip([False] * num_stages + [True],
                                 train_transforms + [val_transforms]):
        image, mask = get_dummy_image_and_mask()
        sample = pack_sample(image=image, mask=mask, dataset_type=dataset_type)
        output = transform(*sample)
        image_output, mask_output = unpack_sample(sample=output,
                                                  dataset_type=dataset_type)
        # Test shape
        if not is_val:
            assert (image_output.shape[-2:] == mask_output.shape[-2:] ==
                    (crop_size, crop_size))
        # Test that the outputs are torch tensors
        assert isinstance(image_output, torch.Tensor)
        assert isinstance(mask_output, torch.Tensor)
        # Test that there are no new segmentation classes, except for probably ignore_label
        uq_classes_before = np.unique(mask)
        uq_classes_after = np.unique(mask_output.numpy())
        assert (len(
            np.setdiff1d(uq_classes_after,
                         uq_classes_before.tolist() + [ignore_label])) == 0)
        if is_val:
            # Test that for validation transformation the output shape has not changed
            assert (image_output.shape[-2:] == image.shape[:2] ==
                    mask_output.shape[-2:] == mask.shape[:2])
            # Test that there were no changes to the classes at all
            assert all(uq_classes_before == uq_classes_after)
Example no. 4
0
def test(config, testLoader=None):
    """Evaluate the latest saved checkpoint on the test set.

    Args:
        config: dict with at least "save_dir" and "device" keys.
        testLoader: optional pre-built dataloader; built from config
            when omitted.

    Returns:
        The output of ``_test`` (with per-sample outputs, since
        ``need_output_y=True``).
    """
    # check if there's weight in work/* (* means the name of setting) before run test.
    if testLoader is None:
        trainTransform, testTransform = data.get_transforms(config)
        testLoader = data.get_dataloader(config,
                                         train=False,
                                         transform=testTransform)

    densenet = model.get_model(config).to(config["device"])
    # load_state_dict expects a deserialized state dict, not a file path;
    # load the checkpoint first (assumes train() saved a state dict).
    state_dict = torch.load(os.path.join(config["save_dir"], "latest.pth"),
                            map_location=config["device"])
    densenet.load_state_dict(state_dict)
    criterion = CrossEntropyLoss()
    # Pass the device positionally, matching the _test call in train().
    test_out = _test(densenet,
                     criterion,
                     testLoader,
                     config["device"],
                     need_output_y=True)
    print("test finished")
    return test_out
def main(
    fold: "int | None" = None,
    show_info: bool = True,
    plot_data: bool = False,
    display_metrics: bool = False,
) -> None:
    """Load the best-ROC checkpoint for *fold* and evaluate it on the test set.

    Args:
        fold: K-fold index used in the checkpoint filename (None is
            interpolated literally into the name as "foldNone").
        show_info: print details of the test data.
        plot_data: plot model outputs over the test data.
        display_metrics: print classification metrics.
    """
    # instantiate the elm_model and load the checkpoint
    elm_model = cnn_feature_model.CNNModel()
    # elm_model = model.StackedELMModel()
    model_name = type(elm_model).__name__
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Checkpoint path mirrors the naming used by the training loop.
    model_ckpt_path = os.path.join(
        config.model_dir,
        f"{model_name}_fold{fold}_best_roc_{config.data_mode}.pth",
    )
    print(f"Using elm_model checkpoint: {model_ckpt_path}")
    # The checkpoint file stores a dict; the weights live under "model".
    elm_model.load_state_dict(
        torch.load(
            model_ckpt_path,
            map_location=device,
        )["model"])
    elm_model = elm_model.to(device)

    # get the test data and dataloader
    f_name = f"test_data_{config.data_mode}.pkl"
    print(f"Using test data file: {f_name}")
    test_transforms = data.get_transforms()
    test_data, test_dataset = get_test_dataset(file_name=f_name,
                                               transforms=test_transforms)
    # NOTE(review): drop_last=True discards the trailing partial batch of
    # *test* data, so metrics skip up to batch_size-1 samples — confirm
    # this is intentional.
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        drop_last=True,
    )

    if show_info:
        show_details(test_data)

    targets, predictions = model_predict(elm_model, device, test_loader)

    if plot_data:
        plot(test_data, elm_model, device)

    if display_metrics:
        show_metrics(model_name, targets, predictions)
Example no. 6
0
    # Hyperparameters / runtime settings for the CIFAR-10 run.
    train_batch_size = 256
    train_dataloader_num_workers = 2
    test_batch_size = 256
    test_dataloader_num_workers = 2
    train_epochs = 10
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    learning_rate = 1e-3
    phase = 'train'
    ## End of Global args ##
    make_folders()

    # NOTE(review): get_device_name() raises if no CUDA device is present,
    # even though `device` falls back to 'cpu' above — confirm CUDA-only use.
    print("Device: {}".format(torch.cuda.get_device_name()))

    # CIFAR-10 train/test splits, both resized to 64x64 by the transforms.
    train_dataset = CIFAR10(root='datasets',
                            train=True,
                            transform=get_transforms(resize_shape=64),
                            download=True)
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      bs=train_batch_size,
                                      num_workers=train_dataloader_num_workers)

    test_dataset = CIFAR10(root='datasets',
                           train=False,
                           transform=get_transforms(resize_shape=64),
                           download=True)
    test_dataloader = get_dataloader(dataset=test_dataset,
                                     bs=test_batch_size,
                                     num_workers=test_dataloader_num_workers)
    num_classes = len(train_dataset.classes)
    # Note: this binds the optimizer *class*, not an instance; it is
    # presumably instantiated later with the model parameters.
    optimizer = optim.Adam
    loss_fn = CrossEntropyLoss()
Example no. 7
0
def main():
    """Train a LightCNN face-recognition model according to CLI arguments.

    Parses args, builds the requested LightCNN variant, sets up a
    per-parameter-group SGD optimizer (larger LR on the last fc layer),
    optionally resumes from a checkpoint, then trains with periodic
    validation and per-epoch checkpointing.
    """
    global args
    args = parser.parse_args()

    os.makedirs(args.save_path, exist_ok=True)
    input_dims = (args.image_height, args.image_width)
    log_path = os.path.join(args.save_path, 'log.txt')
    # Dump the full argument set at the top of the log for reproducibility.
    with open(log_path, 'w+') as f:
        f.write(
            '\n'.join(['%s: %s' % (k, v)
                       for k, v in args.__dict__.items()]) + '\n')

    # create Light CNN for face recognition
    if args.model == 'LightCNN-9':
        model = LightCNN_9Layers(num_classes=args.num_classes,
                                 input_dims=input_dims)
    elif args.model == 'LightCNN-29':
        model = LightCNN_29Layers(num_classes=args.num_classes,
                                  input_dims=input_dims)
    elif args.model == 'LightCNN-29v2':
        model = LightCNN_29Layers_v2(num_classes=args.num_classes,
                                     input_dims=input_dims)
    else:
        # Fail fast: the original printed a message and fell through to a
        # NameError on the undefined `model` a few lines later.
        raise ValueError('Error model type: {}'.format(args.model))

    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()

    print(model)

    # large lr for last fc parameters
    params = []
    for name, value in model.named_parameters():
        if 'bias' in name:
            if 'fc2' in name:
                params += [{
                    'params': value,
                    'lr': 20 * args.lr,
                    'weight_decay': 0
                }]
            else:
                params += [{
                    'params': value,
                    'lr': 2 * args.lr,
                    'weight_decay': 0
                }]
        else:
            if 'fc2' in name:
                params += [{'params': value, 'lr': 10 * args.lr}]
            else:
                params += [{'params': value, 'lr': 1 * args.lr}]

    optimizer = torch.optim.SGD(params,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    #load image
    train_loader = torch.utils.data.DataLoader(get_dataset(
        args.dataset,
        args.root_path,
        args.train_list,
        transform=get_transforms(dataset=args.dataset, phase='train')),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    # NOTE(review): the validation loader also uses phase='train'
    # transforms, so val images get training augmentations — confirm
    # whether get_transforms supports an eval phase here.
    val_loader = torch.utils.data.DataLoader(get_dataset(
        args.dataset,
        args.root_path,
        args.val_list,
        transform=get_transforms(dataset=args.dataset, phase='train')),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function and optimizer
    criterion = nn.CrossEntropyLoss()

    if args.cuda:
        criterion.cuda()

    # validate(val_loader, model, criterion)
    # Initialize prec1 so the checkpoint below does not raise NameError on
    # epochs where validation is skipped (epoch % val_freq != 0).
    prec1 = 0
    with trange(args.start_epoch, args.epochs) as epochs:
        for epoch in epochs:
            epochs.set_description('Epoch %d' % epoch)

            adjust_learning_rate(optimizer, epoch)

            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch, log_path)
            if epoch % args.val_freq == 0:
                # evaluate on validation set
                prec1 = validate(val_loader, model, criterion, log_path)

            # os.path.join instead of string concatenation: the original
            # silently wrote into the parent directory when save_path had
            # no trailing separator.
            save_name = os.path.join(
                args.save_path,
                'lightCNN_' + str(epoch + 1) + '_checkpoint.pth.tar')
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'prec1': prec1,
                }, save_name)
Example no. 8
0
from torchvision.utils import save_image
import torchvision.utils as vutils

from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Timer
from ignite.metrics import RunningAverage

import data
import model
from utils import *
from args import args

# Wall-clock budget for the run, in seconds (12.5 h).
CUTOFF_SECONDS = 45000
seed_everything(seed=args.seed)
# Two transform pipelines, both driven by the configured image size.
transform1, transform2 = data.get_transforms(args.image_size)
train_data = data.DogDataset(img_dir=args.root_images,
                             args=args,
                             transform1=transform1,
                             transform2=transform2)
# Bidirectional breed <-> integer index mappings over the sorted unique labels.
decoded_dog_labels = {
    i: breed
    for i, breed in enumerate(sorted(set(train_data.labels)))
}
encoded_dog_labels = {
    breed: i
    for i, breed in enumerate(sorted(set(train_data.labels)))
}
train_data.labels = [encoded_dog_labels[l] for l in train_data.labels
                     ]  # encode dog labels in the data generator
dataloader = torch.utils.data.DataLoader(train_data,
Example no. 9
0
def train_loop(
    data_obj: data.Data,
    test_datafile_name: str,
    model_class=cnn_feature_model.CNNModel,
    kfold: bool = False,
    fold: Union[int, None] = None,
    desc: bool = True,
):
    """Train *model_class* on data from *data_obj*, saving the best checkpoint.

    Splits the data into train/valid/test (dumping the test split to
    ``test_datafile_name``), builds datasets/loaders, then runs the epoch
    loop with scheduler-specific handling, tracking best ROC-AUC and loss.
    """
    # TODO: Implement K-fold cross-validation
    if kfold and (fold is None):
        raise Exception(
            f"K-fold cross validation is passed but fold index in range [0, {config.folds}) is not specified."
        )
    if (not kfold) and (fold is not None):
        LOGGER.info(
            f"K-fold is set to {kfold} but fold index is passed!"
            " Proceeding without using K-fold."
        )
        fold = None

    # test data file path
    test_data_file = os.path.join(config.data_dir, test_datafile_name)

    LOGGER.info("-" * 60)
    if config.balance_classes:
        LOGGER.info("Training with balanced classes.")
    else:
        LOGGER.info("Training using unbalanced (original) classes.")

    LOGGER.info(f"Test data will be saved to: {test_data_file}")
    LOGGER.info("-" * 30)
    LOGGER.info(f"       Training fold: {fold}       ")
    LOGGER.info("-" * 30)

    # turn off model details for subsequent folds/epochs
    if fold is not None:
        if fold >= 1:
            desc = False

    # create train, valid and test data
    train_data, valid_data, test_data = data_obj.get_data(
        shuffle_sample_indices=True, fold=fold
    )

    # dump test data into to a file
    with open(test_data_file, "wb") as f:
        pickle.dump(
            {
                "signals": test_data[0],
                "labels": test_data[1],
                "sample_indices": test_data[2],
                "window_start": test_data[3],
            },
            f,
        )

    # create image transforms
    # Fix: model_class is already a class, so type(model_class).__name__
    # is the metaclass name ("type") and never matched the list — the
    # transforms branch was effectively dead.
    if model_class.__name__ in ["FeatureModel", "CNNModel"]:
        transforms = None
    else:
        transforms = data.get_transforms()

    # create datasets
    train_dataset = data.ELMDataset(
        *train_data,
        config.signal_window_size,
        config.label_look_ahead,
        stack_elm_events=config.stack_elm_events,
        transform=transforms,
    )

    valid_dataset = data.ELMDataset(
        *valid_data,
        config.signal_window_size,
        config.label_look_ahead,
        stack_elm_events=config.stack_elm_events,
        transform=transforms,
    )

    # training and validation dataloaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
    )

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        drop_last=True,
    )

    # model
    model = model_class()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model_name = type(model).__name__
    LOGGER.info("-" * 50)
    LOGGER.info(f"       Training with model: {model_name}       ")
    LOGGER.info("-" * 50)

    # display model details
    if desc:
        if config.stack_elm_events and model_name == "StackedELMModel":
            input_size = (config.batch_size, 1, config.size, config.size)
        else:
            # (batch, channel, window, 8, 8) dummy input for the summary.
            input_size = (
                config.batch_size,
                1,
                config.signal_window_size,
                8,
                8,
            )
        x = torch.rand(*input_size)
        x = x.to(device)
        cnn_feature_model.model_details(model, x, input_size)

    # optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        amsgrad=False,
    )

    # get the lr scheduler
    scheduler = get_lr_scheduler(
        optimizer, scheduler_name=config.scheduler, dataloader=train_loader
    )

    # loss function
    criterion = nn.BCEWithLogitsLoss(reduction="none")

    # define variables for ROC and loss
    best_score = 0
    best_loss = np.inf

    # instantiate training object
    engine = run.Run(
        model,
        device=device,
        criterion=criterion,
        optimizer=optimizer,
        use_focal_loss=True,
    )

    # iterate through all the epochs
    for epoch in range(config.epochs):
        start_time = time.time()

        if config.scheduler in [
            "CosineAnnealingLR",
            "CyclicLR",
            "CyclicLR2",
            "OneCycleLR",
        ]:
            # train
            avg_loss = engine.train(
                train_loader, epoch, scheduler=scheduler, print_every=5000
            )

            # evaluate
            avg_val_loss, preds, valid_labels = engine.evaluate(
                valid_loader, print_every=2000
            )
            # These schedulers are stepped inside engine.train; recreate a
            # fresh one for the next epoch instead of calling .step() here.
            scheduler = get_lr_scheduler(
                optimizer,
                scheduler_name=config.scheduler,
                dataloader=train_loader,
            )
        else:
            # train
            avg_loss = engine.train(train_loader, epoch, print_every=5000)

            # evaluate
            avg_val_loss, preds, valid_labels = engine.evaluate(
                valid_loader, print_every=2000
            )

            # step the scheduler
            if config.scheduler == "ReduceLROnPlateau":
                scheduler.step(avg_val_loss)
            else:
                scheduler.step()

        # scoring
        roc_score = roc_auc_score(valid_labels, preds)
        elapsed = time.time() - start_time

        LOGGER.info(
            f"Epoch: {epoch + 1}, \tavg train loss: {avg_loss:.4f}, \tavg validation loss: {avg_val_loss:.4f}"
        )
        LOGGER.info(
            f"Epoch: {epoch +1}, \tROC-AUC score: {roc_score:.4f}, \ttime elapsed: {elapsed}"
        )

        # save the model if best ROC is found
        model_save_path = os.path.join(
            config.model_dir,
            f"{model_name}_fold{fold}_best_roc_{config.data_mode}.pth",
        )
        if roc_score > best_score:
            best_score = roc_score
            LOGGER.info(
                f"Epoch: {epoch+1}, \tSave Best Score: {best_score:.4f} Model"
            )
            torch.save(
                {"model": model.state_dict(), "preds": preds},
                model_save_path,
            )

        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            LOGGER.info(
                f"Epoch: {epoch+1}, \tSave Best Loss: {best_loss:.4f} Model"
            )
        LOGGER.info(f"Model saved to: {model_save_path}")