Пример #1
0
def get_arguments(argv=None):
    """Parse command-line options for COVID_CT / COVIDx training.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for notebooks and tests). ``None`` keeps
            the original behaviour of reading the real command line.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=4, help='batch size for training')
    parser.add_argument('--log_interval', type=int, default=1000, help='steps to print metrics and loss')
    parser.add_argument('--dataset_name', type=str, default="COVID_CT", help='dataset name COVIDx or COVID_CT')
    parser.add_argument('--nEpochs', type=int, default=250, help='total number of epochs')
    parser.add_argument('--device', type=int, default=0, help='gpu device')
    parser.add_argument('--seed', type=int, default=123, help='select seed number for reproducibility')
    parser.add_argument('--classes', type=int, default=3, help='dataset classes')
    # %(default)s keeps the help text in sync with the real default
    # (it previously advertised 1e-3 / 1e-6).
    parser.add_argument('--lr', default=2e-5, type=float,
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--weight_decay', default=1e-7, type=float,
                        help='weight decay (default: %(default)s)')
    parser.add_argument('--cuda', action='store_true', default=False, help='use gpu for speed-up')
    # A store_true flag whose default is already True can never be switched
    # off from the CLI; --no_tensorboard writes False into the same dest.
    parser.add_argument('--tensorboard', action='store_true', default=True,
                        help='use tensorboard for logging and visualization')
    parser.add_argument('--no_tensorboard', dest='tensorboard', action='store_false',
                        help='disable tensorboard logging')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--model', type=str, default='COVIDNet_large',
                        choices=('COVIDNet_small', 'resnet18', 'mobilenet_v2', 'densenet169', 'COVIDNet_large'))
    parser.add_argument('--opt', type=str, default='adam',
                        choices=('sgd', 'adam', 'rmsprop'))
    parser.add_argument('--root_path', type=str, default='./data',
                        help='path to dataset ')
    # NOTE(review): this default is an absolute path at the filesystem root
    # ('/saved/...'); a relative './saved/...' may have been intended — confirm.
    parser.add_argument('--save', type=str, default='/saved/COVIDNet' + util.datestr(),
                        help='path to checkpoint save directory ')
    args = parser.parse_args(argv)
    return args
Пример #2
0
def main():
    """Seed everything, then run the train / validate / checkpoint loop.

    Builds the model, optimizer and data generators via ``initialize(args)``.
    For each epoch in ``1..args.nEpochs`` it calls ``train`` and ``validation``,
    then ``util.save_model``, and steps a ``ReduceLROnPlateau`` scheduler on
    the validation loss. If TensorBoard is enabled, the writer is closed when
    the loop finishes or raises, so buffered events are flushed to disk.
    """
    args = get_arguments()
    seed = args.seed
    torch.manual_seed(seed)
    # Deterministic cuDNN kernels (and no autotuning) for reproducible runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    if args.cuda:
        torch.cuda.manual_seed(seed)
    model, optimizer, training_generator, val_generator, _test_generator = initialize(args)

    print(model)

    # Initial "best" loss sentinel; replaced each epoch by util.save_model's
    # return value.
    best_pred_loss = 1000.0
    # Halve the LR when the monitored validation loss stops improving.
    scheduler = ReduceLROnPlateau(optimizer,
                                  factor=0.5,
                                  patience=2,
                                  min_lr=1e-5,
                                  verbose=True)
    print('Checkpoint folder ', args.save)
    writer = SummaryWriter('./runs/' + util.datestr()) if args.tensorboard else None
    try:
        for epoch in range(1, args.nEpochs + 1):
            train(args, model, training_generator, optimizer, epoch, writer)
            val_metrics, confusion_matrix = validation(args, model, val_generator,
                                                       epoch, writer)

            best_pred_loss = util.save_model(model, optimizer, args, val_metrics,
                                             epoch, best_pred_loss,
                                             confusion_matrix)

            scheduler.step(val_metrics.avg_loss())
    finally:
        # The writer buffers events; close it even if training fails so
        # nothing already logged is lost.
        if writer is not None:
            writer.close()
Пример #3
0
def get_arguments():
    """Declare the COVIDx training/HPO command-line options and parse sys.argv.

    Returns:
        argparse.Namespace with one attribute per option in the table below.
    """
    parser = argparse.ArgumentParser()

    # (flag, add_argument keyword arguments), registered in this exact order.
    option_table = (
        ('--batch_size', dict(type=int, default=12, help='batch size for training')),
        ('--log_interval', dict(type=int, default=200, help='steps to print metrics and loss')),
        ('--cuda', dict(type=int, default=0, help='use gpu support')),
        ('--device', dict(type=int, default=0, help='gpu device')),
        ('--dataset_name', dict(type=str, default='COVIDx', help='dataset name COVIDx')),
        ('--seed', dict(type=int, default=123,
                        help='select seed number for reproducibility')),
        ('--classes', dict(type=int, default=3, help='dataset classes')),
        ('--resume', dict(default='', type=str, metavar='PATH',
                          help='path to latest checkpoint (default: none)')),
        ('--model', dict(type=str, default='COVIDNet_small',
                         choices=('COVIDNet_small', 'resnet18', 'COVIDNet_large'))),
        ('--root_path', dict(type=str, default='./data', help='path to dataset ')),
        ('--save', dict(type=str, default='./saved/COVIDNet' + util.datestr(),
                        help='path to checkpoint save directory ')),
        ('--epochs', dict(type=int, default=1, help="number of training epochs")),
        ('--trials', dict(type=int, default=10, help="number of HPO trials")),
        ('--worker_id', dict(type=int, default=0, help="worker id")),
        ('--ex_rate', dict(type=int, default=100, help="info exchange rate in HPO")),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()
Пример #4
0
def get_arguments(argv=None):
    """Parse command-line options for COVIDNet training on COVIDx.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) reads the real
            command line, as before.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): the help text suggests this flag loads --saved_model as the
    # initial model, which reads opposite to its name — confirm with callers.
    parser.add_argument('--new_training', action='store_true', default=False,
                        help='load saved_model as initial model')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--log_interval', type=int, default=1000)
    parser.add_argument('--dataset_name', type=str, default="COVIDx")
    parser.add_argument('--nEpochs', type=int, default=20)
    parser.add_argument('--device', type=str, default='0')
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--classes', type=int, default=3)
    parser.add_argument('--inChannels', type=int, default=1)
    # %(default)s keeps the help text in sync with the real default
    # (it previously advertised 1e-3 / 1e-6).
    parser.add_argument('--lr', default=2e-5, type=float,
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--weight_decay', default=1e-7, type=float,
                        help='weight decay (default: %(default)s)')
    # store_true with default=True can never be switched off from the CLI;
    # --no_cuda writes False into the same destination.
    parser.add_argument('--cuda', action='store_true', default=True)
    parser.add_argument('--no_cuda', dest='cuda', action='store_false',
                        help='disable gpu usage')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--model', type=str, default='CovidNet_Grad_CAM',
                        choices=('COVIDNet','CovidNet_ResNet50', 'CovidNet_DenseNet', 'CovidNet_Grad_CAM','CovidNet_DE'))
    parser.add_argument('--opt', type=str, default='adam',
                        choices=('sgd', 'adam', 'rmsprop'))
    # NOTE(review): the path defaults below are machine-specific absolute
    # paths (a /content dir and a personal home dir); override them elsewhere.
    parser.add_argument('--dataset', type=str, default='/content/covid-chestxray-dataset/data',
                        help='path to dataset ')
    parser.add_argument('--saved_model', type=str, default='/home/julian/Documents/PythonExperiments/COVIDNet/ModelSavedCoviNet/COVIDNet20200406_0412/COVIDNet_best_checkpoint.pth.tar',
                        help='path to save_model ')
    parser.add_argument('--save', type=str, default='/home/julian/Documents/PythonExperiments/COVIDNet/ModelSavedCoviNet/COVIDNet' + util.datestr(),
                        help='path to checkpoint ')
    args = parser.parse_args(argv)
    return args
Пример #5
0
def main():
    """Run training driven by the YAML trainer config.

    The function proceeds as follows:
    - Loads the ``trainer`` section of ``config/trainer_config.yml``, resolved
      against the current working directory.
    - Seeds through ``reproducibility(config)`` and creates a timestamped
      checkpoint folder path.
    - Logs environment info.
    - Builds data loaders, model and optimizer, optionally loading a
      pretrained checkpoint.
    - Copies the config file into the checkpoint folder and runs
      ``Trainer.train()``.

    Only ``args.tensorboard`` is read from the command line. The TensorBoard
    writer (if any) is closed after training, even on error.
    """
    args = get_arguments()
    now = datetime.datetime.now()
    cwd = os.getcwd()

    # Hook for a `-c <config file>` override; getopts(sys.argv) is currently
    # disabled. A dict, because it is looked up by key. The default is
    # resolved up front so config_file is always bound (the old nested `if`
    # left it undefined when myargs was non-empty but had no 'c').
    myargs = {}  # getopts(sys.argv)
    config_file = myargs.get('c', 'config/trainer_config.yml')

    config = OmegaConf.load(os.path.join(cwd, config_file))['trainer']
    config.cwd = str(cwd)
    reproducibility(config)
    dt_string = now.strftime("%d_%m_%Y_%H.%M.%S")
    cpkt_fol_name = os.path.join(
        config.cwd,
        f'checkpoints/model_{config.model.name}/dataset_{config.dataset.name}/date_{dt_string}'
    )

    log = Logger(path=cpkt_fol_name, name='LOG').get_logger()

    log.info(f"Checkpoint folder {cpkt_fol_name}")
    log.info(f"date and time = {dt_string}")

    log.info(f'pyTorch VERSION:{torch.__version__}')
    # Previously logged a bare "CUDA VERSION" label with no value.
    log.info(f'CUDA VERSION:{torch.version.cuda}')
    log.info(f'CUDNN VERSION:{torch.backends.cudnn.version()}')
    log.info(f'Number CUDA Devices: {torch.cuda.device_count()}')

    if args.tensorboard:
        # Nest the runs directory inside the checkpoint folder. The old
        # `os.path.join(cpkt_fol_name + 'runs/')` concatenated the strings,
        # producing ".../date_<stamp>runs/" next to the folder instead.
        writer = SummaryWriter(os.path.join(cpkt_fol_name, 'runs', util.datestr()))
    else:
        writer = None

    use_cuda = torch.cuda.is_available()
    # NOTE(review): the device follows GPU availability only; config.cuda
    # gates just the DataParallel wrap below — confirm that is intended.
    device = torch.device("cuda:0" if use_cuda else "cpu")
    log.info(f'device: {device}')

    training_generator, val_generator, test_generator, class_dict = select_dataset(
        config)
    n_classes = len(class_dict)
    model = select_model(config, n_classes)

    log.info(f"{model}")

    if config.load:
        # Presumably loads weights into `model` in place; the returned
        # checkpoint path is not used afterwards.
        load_checkpoint(config.pretrained_cpkt,
                        model,
                        strict=True,
                        load_seperate_layers=False)

    if config.cuda and use_cuda and torch.cuda.device_count() > 1:
        log.info(f"Let's use {torch.cuda.device_count()} GPUs!")
        model = torch.nn.DataParallel(model)
    model.to(device)

    optimizer, scheduler = select_optimizer(model, config['model'], None)
    log.info(f'{model}')
    log.info(f"Checkpoint Folder {cpkt_fol_name} ")
    # Keep a copy of the exact config next to the checkpoints it produced.
    shutil.copy(os.path.join(config.cwd, config_file), cpkt_fol_name)

    trainer = Trainer(config,
                      model=model,
                      optimizer=optimizer,
                      data_loader=training_generator,
                      writer=writer,
                      logger=log,
                      valid_data_loader=val_generator,
                      test_data_loader=test_generator,
                      class_dict=class_dict,
                      lr_scheduler=scheduler,
                      checkpoint_dir=cpkt_fol_name)
    try:
        trainer.train()
    finally:
        # Flush buffered TensorBoard events even if training raises;
        # SummaryWriter.close() is a no-op if already closed.
        if writer is not None:
            writer.close()