def main():
    """Train a segmentation model described by a YAML config file.

    Reads the config path from ``--train_cfg``, loads it with ``load_yaml``,
    seeds RNGs with ``config['SEED']``, splits ``config['DATA_TRAIN']`` into
    train/valid ids, and builds datasets, dataloaders, model, criterion,
    optimizer and scheduler through ``getattribute`` using the config's named
    sections. Everything is then passed to ``Learning.train``.
    """
    parser = argparse.ArgumentParser(description='Semantic Segmentation')
    parser.add_argument('--train_cfg',
                        type=str,
                        default='./configs/train_config.yaml',
                        help='train config path')
    args = parser.parse_args()
    # Path() already normalises away trailing separators. The previous
    # ``.strip("/")`` also removed a *leading* slash, which silently turned
    # absolute paths such as /data/cfg.yaml into relative ones.
    config_path = Path(args.train_cfg)
    config = load_yaml(config_path)
    init_seed(config['SEED'])

    df, train_ids, valid_ids = split_dataset(config['DATA_TRAIN'])
    train_dataset = getattribute(config=config,
                                 name_package='TRAIN_DATASET',
                                 df=df,
                                 img_ids=train_ids)
    valid_dataset = getattribute(config=config,
                                 name_package='VALID_DATASET',
                                 df=df,
                                 img_ids=valid_ids)
    train_dataloader = getattribute(config=config,
                                    name_package='TRAIN_DATALOADER',
                                    dataset=train_dataset)
    valid_dataloader = getattribute(config=config,
                                    name_package='VALID_DATALOADER',
                                    dataset=valid_dataset)
    model = getattribute(config=config, name_package='MODEL')
    criterion = getattribute(config=config, name_package='CRITERION')
    optimizer = getattribute(config=config,
                             name_package='OPTIMIZER',
                             params=model.parameters())
    scheduler = getattribute(config=config,
                             name_package='SCHEDULER',
                             optimizer=optimizer)
    device = config['DEVICE']
    metric_ftns = [accuracy_dice_score]
    num_epoch = config['NUM_EPOCH']
    gradient_clipping = config['GRADIENT_CLIPPING']
    gradient_accumulation_steps = config['GRADIENT_ACCUMULATION_STEPS']
    early_stopping = config['EARLY_STOPPING']
    validation_frequency = config['VALIDATION_FREQUENCY']
    saved_period = config['SAVED_PERIOD']
    # Checkpoints are grouped in a sub-directory named after the model class.
    checkpoint_dir = Path(config['CHECKPOINT_DIR'], type(model).__name__)
    checkpoint_dir.mkdir(exist_ok=True, parents=True)
    resume_path = config['RESUME_PATH']
    learning = Learning(model=model,
                        optimizer=optimizer,
                        criterion=criterion,
                        device=device,
                        metric_ftns=metric_ftns,
                        num_epoch=num_epoch,
                        scheduler=scheduler,
                        grad_clipping=gradient_clipping,
                        grad_accumulation_steps=gradient_accumulation_steps,
                        early_stopping=early_stopping,
                        validation_frequency=validation_frequency,
                        save_period=saved_period,
                        checkpoint_dir=checkpoint_dir,
                        resume_path=resume_path)
    learning.train(tqdm(train_dataloader), tqdm(valid_dataloader))
def main():
    """Train a model on the ``vinDataset`` train/valid splits from a YAML config.

    Reads the config path from ``--train_cfg``, loads it with ``load_yaml``,
    seeds RNGs with ``config['SEED']``, builds one ``vinDataset`` and one
    ``torch.utils.data.DataLoader`` per phase ('train', 'valid'), builds the
    model, criterion, optimizer and scheduler through ``getattribute``, and
    runs ``Learning.train``.
    """
    parser = argparse.ArgumentParser(description='Semantic Segmentation')
    parser.add_argument('--train_cfg',
                        type=str,
                        default='./configs/train.yaml',
                        help='train config path')
    args = parser.parse_args()
    # Path() already drops trailing separators. ``.strip("/")`` would also
    # remove a leading slash and turn absolute paths into relative ones.
    config_path = Path(args.train_cfg)
    config = load_yaml(config_path)
    init_seed(config['SEED'])

    phases = ['train', 'valid']
    image_datasets = {
        phase: vinDataset(root_dir=config['ROOT_DIR'],
                          file_name=config['FILE_NAME'],
                          num_triplet=config['NUM_TRIPLET'],
                          phase=phase)
        for phase in phases
    }
    dataloaders = {
        phase: torch.utils.data.DataLoader(
            image_datasets[phase],
            batch_size=config['BATCH_SIZE'],
            # Only the training split benefits from shuffling. Shuffling the
            # validation split changes nothing useful but still draws from
            # the torch RNG every epoch.
            shuffle=(phase == 'train'),
            num_workers=4,
            pin_memory=True)
        for phase in phases
    }

    model = getattribute(config=config, name_package='MODEL')
    criterion = getattribute(config=config, name_package='CRITERION')
    metric_ftns = [accuracy_score]
    optimizer = getattribute(config=config,
                             name_package='OPTIMIZER',
                             params=model.parameters())
    scheduler = getattribute(config=config,
                             name_package='SCHEDULER',
                             optimizer=optimizer)
    device = config['DEVICE']
    num_epoch = config['NUM_EPOCH']
    gradient_clipping = config['GRADIENT_CLIPPING']
    gradient_accumulation_steps = config['GRADIENT_ACCUMULATION_STEPS']
    early_stopping = config['EARLY_STOPPING']
    validation_frequency = config['VALIDATION_FREQUENCY']
    saved_period = config['SAVED_PERIOD']
    # Checkpoints are grouped in a sub-directory named after the model class.
    checkpoint_dir = Path(config['CHECKPOINT_DIR'], type(model).__name__)
    checkpoint_dir.mkdir(exist_ok=True, parents=True)
    resume_path = config['RESUME_PATH']
    learning = Learning(model=model,
                        criterion=criterion,
                        metric_ftns=metric_ftns,
                        optimizer=optimizer,
                        device=device,
                        num_epoch=num_epoch,
                        scheduler=scheduler,
                        grad_clipping=gradient_clipping,
                        grad_accumulation_steps=gradient_accumulation_steps,
                        early_stopping=early_stopping,
                        validation_frequency=validation_frequency,
                        save_period=saved_period,
                        checkpoint_dir=checkpoint_dir,
                        resume_path=resume_path)

    learning.train(tqdm(dataloaders['train']), tqdm(dataloaders['valid']))
def main():
    """Train a VOC detection model from a YAML config plus command-line overrides.

    ``config_parser`` receives the argparse parser (``--train_cfg``,
    ``-d/--device``, ``-r/--resume``) together with two extra override
    options. Judging by their ``target`` strings, ``-lr`` and ``-bs``
    presumably overwrite nested config entries; confirm against
    ``config_parser``. Only a training dataloader is built, so
    ``Learning.train`` is called without a validation loader.
    """
    cli = argparse.ArgumentParser(description='Pytorch parser')
    cli.add_argument('--train_cfg', type=str,
                     default='./configs/efficientdet-d0.yaml',
                     help='train config path')
    cli.add_argument('-d', '--device', type=str, default=None,
                     help='indices of GPUs to enable (default: all)')
    cli.add_argument('-r', '--resume', type=str, default=None,
                     help='path to latest checkpoint (default: None)')

    OverrideOption = collections.namedtuple('CustomArgs', 'flags type target')
    lr_override = OverrideOption(
        flags=['-lr', '--learning_rate'],
        type=float,
        target='OPTIMIZER,ARGS,lr')
    batch_size_override = OverrideOption(
        flags=['-bs', '--batch_size'],
        type=int,
        target='TRAIN_DATALOADER,ARGS,batch_size;VALID_DATALOADER,ARGS,batch_size')
    config = config_parser(cli, [lr_override, batch_size_override])
    init_seed(config['SEED'])

    augmentation = SSDAugmentation(voc['min_dim'], MEANS)
    voc_train = VOCDetection(root=VOC_ROOT, transform=augmentation)
    train_loader = getattribute(config=config,
                                name_package='TRAIN_DATALOADER',
                                dataset=voc_train,
                                collate_fn=detection_collate)
    # NOTE: validation is not wired up yet. No VOC validation dataset is
    # built here, so no VALID_DATALOADER is created either.

    model = getattribute(config=config, name_package='MODEL')
    criterion = getattribute(config=config, name_package='CRITERION')
    optimizer = getattribute(config=config, name_package='OPTIMIZER',
                             params=model.parameters())
    scheduler = getattribute(config=config, name_package='SCHEDULER',
                             optimizer=optimizer)

    # Scalar training options read straight from the config.
    learning_options = {
        'device': config['DEVICE'],
        'num_epoch': config['NUM_EPOCH'],
        'grad_clipping': config['GRADIENT_CLIPPING'],
        'grad_accumulation_steps': config['GRADIENT_ACCUMULATION_STEPS'],
        'early_stopping': config['EARLY_STOPPING'],
        'validation_frequency': config['VALIDATION_FREQUENCY'],
        'tensorboard': config['TENSORBOARD'],
    }

    # Checkpoints are grouped in a sub-directory named after the model class.
    checkpoint_dir = Path(config['CHECKPOINT_DIR'], type(model).__name__)
    checkpoint_dir.mkdir(exist_ok=True, parents=True)

    trainer = Learning(model=model,
                       criterion=criterion,
                       optimizer=optimizer,
                       scheduler=scheduler,
                       metric_ftns=[],  # no metrics are tracked for detection here
                       checkpoint_dir=checkpoint_dir,
                       resume_path=config['RESUME_PATH'],
                       **learning_options)
    trainer.train(tqdm(train_loader))