예제 #1
0
def configure_split_dict(split,
                         data,
                         split_name,
                         verbose,
                         grouper,
                         batch_size,
                         config,
                         get_train=False,
                         get_eval=False):
    """Assemble the per-split bundle of dataset, loaders, metadata and loggers.

    Args:
        split: Split key; stored under 'split' and used as the prefix of the
            two CSV log file names.
        data: Dataset for this split, or None to skip the dataset and loaders.
        split_name: Stored under 'name'.
        verbose: Stored under 'verbose'.
        grouper: Forwarded to the loader factories.
        batch_size: Forwarded to the loader factories.
        config: Object providing train_loader, eval_loader,
            uniform_over_groups, distinct_groups, n_groups_per_batch,
            loader_kwargs, log_dir, mode and use_wandb.
        get_train: When true (and data is given), build 'train_loader'.
        get_eval: When true (and data is given), build 'eval_loader'.

    Returns:
        A defaultdict(dict) holding 'split', 'name', 'verbose', 'eval_logger'
        and 'algo_logger', plus 'dataset' and the requested loaders when
        data is not None.
    """

    def _open_logger(csv_name):
        # Both loggers share the directory, file mode and wandb switch.
        return BatchLogger(os.path.join(config.log_dir, csv_name),
                           mode=config.mode,
                           use_wandb=config.use_wandb)

    bundle = defaultdict(dict)

    # Dataset and loaders exist only when the split actually has data.
    if data is not None:
        bundle['dataset'] = data
        if get_train:
            train_loader = get_train_loader(
                loader=config.train_loader,
                dataset=data,
                batch_size=batch_size,
                uniform_over_groups=config.uniform_over_groups,
                grouper=grouper,
                distinct_groups=config.distinct_groups,
                n_groups_per_batch=config.n_groups_per_batch,
                **config.loader_kwargs)
            bundle['train_loader'] = train_loader
        if get_eval:
            eval_loader = get_eval_loader(loader=config.eval_loader,
                                          dataset=data,
                                          grouper=grouper,
                                          batch_size=batch_size,
                                          **config.loader_kwargs)
            bundle['eval_loader'] = eval_loader

    # Descriptive fields.
    bundle.update(split=split, name=split_name, verbose=verbose)

    # One CSV logger for evaluation results, one for algorithm statistics.
    bundle['eval_logger'] = _open_logger(f'{split}_eval.csv')
    bundle['algo_logger'] = _open_logger(f'{split}_algo.csv')

    return bundle
예제 #2
0
def main():
    """Parse the command line, build the datasets, then train or evaluate.

    Hyperparameter defaults are set in default_hyperparams.py and applied
    through ``populate_defaults`` after argument parsing.

    Steps, in order:
      1. Parse CLI arguments and pass them through ``populate_defaults``.
      2. Select the torch device, open ``log.txt`` in ``config.log_dir`` and
         record the config.
      3. Load the full dataset and, for every split, build the subset, its
         loader and its two CSV loggers.
      4. Initialize the algorithm, then either train (optionally resuming
         from a checkpoint) or, with ``--eval_only``, load a checkpoint and
         evaluate it.
      5. Close all loggers.
    """
    parser = argparse.ArgumentParser()

    # Required arguments
    parser.add_argument('-d',
                        '--dataset',
                        choices=wilds.supported_datasets,
                        required=True)
    parser.add_argument('--algorithm',
                        required=True,
                        choices=supported.algorithms)
    parser.add_argument(
        '--root_dir',
        required=True,
        help=
        'The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).'
    )

    # Dataset
    parser.add_argument(
        '--split_scheme',
        help=
        'Identifies how the train/val/test split is constructed. Choices are dataset-specific.'
    )
    parser.add_argument('--dataset_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument(
        '--download',
        default=False,
        type=parse_bool,
        const=True,
        nargs='?',
        help=
        'If true, tries to downloads the dataset if it does not exist in root_dir.'
    )
    parser.add_argument(
        '--frac',
        type=float,
        default=1.0,
        help=
        'Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes. Note that this also scales the test set down, so the reported numbers are not comparable with the full test set.'
    )
    parser.add_argument('--version', default=None, type=str)

    # Loaders
    parser.add_argument('--loader_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--train_loader', choices=['standard', 'group'])
    parser.add_argument('--uniform_over_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--distinct_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--n_groups_per_batch', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--eval_loader',
                        choices=['standard'],
                        default='standard')

    # Model
    parser.add_argument('--model', choices=supported.models)
    parser.add_argument(
        '--model_kwargs',
        nargs='*',
        action=ParseKwargs,
        default={},
        help=
        'keyword arguments for model initialization passed as key1=value1 key2=value2'
    )

    # Transforms
    parser.add_argument('--train_transform', choices=supported.transforms)
    parser.add_argument('--eval_transform', choices=supported.transforms)
    parser.add_argument(
        '--target_resolution',
        nargs='+',
        type=int,
        help=
        'The input resolution that images will be resized to before being passed into the model. For example, use --target_resolution 224 224 for a standard ResNet.'
    )
    parser.add_argument('--resize_scale', type=float)
    parser.add_argument('--max_token_length', type=int)

    # Objective
    parser.add_argument('--loss_function', choices=supported.losses)

    # Algorithm
    parser.add_argument('--groupby_fields', nargs='+')
    parser.add_argument('--group_dro_step_size', type=float)
    parser.add_argument('--coral_penalty_weight', type=float)
    parser.add_argument('--dann_lambda', type=float)
    parser.add_argument('--dann_domain_layers', type=int,
                        default=1)  # hidden layers
    parser.add_argument('--dann_label_layers', type=int,
                        default=1)  # hidden layers
    parser.add_argument('--domain_loss_function', choices=supported.losses)
    parser.add_argument('--irm_lambda', type=float)
    parser.add_argument('--irm_penalty_anneal_iters', type=int)
    parser.add_argument('--algo_log_metric')

    # Model selection
    parser.add_argument('--val_metric')
    parser.add_argument('--val_metric_decreasing',
                        type=parse_bool,
                        const=True,
                        nargs='?')

    # Optimization
    parser.add_argument('--n_epochs', type=int)
    parser.add_argument('--optimizer', choices=supported.optimizers)
    parser.add_argument('--lr', type=float)
    parser.add_argument('--weight_decay', type=float)
    parser.add_argument('--max_grad_norm', type=float)
    parser.add_argument('--optimizer_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})

    # Scheduler
    parser.add_argument('--scheduler', choices=supported.schedulers)
    parser.add_argument('--scheduler_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--scheduler_metric_split',
                        choices=['train', 'val'],
                        default='val')
    parser.add_argument('--scheduler_metric_name')

    # Evaluation
    parser.add_argument('--process_outputs_function',
                        choices=supported.process_outputs_functions)
    parser.add_argument('--evaluate_all_splits',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--eval_splits', nargs='+', default=[])
    parser.add_argument('--eval_only',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument(
        '--eval_epoch',
        default=None,
        type=int,
        help=
        'If eval_only is set, then eval_epoch allows you to specify evaluating at a particular epoch. By default, it evaluates the best epoch by validation performance.'
    )

    # Misc
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--log_every', default=50, type=int)
    parser.add_argument('--save_step', type=int)
    parser.add_argument('--save_best',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--save_last',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--save_pred',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--no_group_logging',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--use_wandb',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--progress_bar',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--resume',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)

    config = parser.parse_args()
    config = populate_defaults(config)

    # Replace the integer GPU index by a torch.device, falling back to CPU
    # when CUDA is unavailable.
    config.device = torch.device("cuda:" + str(
        config.device)) if torch.cuda.is_available() else torch.device("cpu")

    ## Initialize logs
    # Existing logs are appended to when resuming or evaluating; in every
    # other case they are overwritten.
    if os.path.exists(config.log_dir) and config.resume:
        resume = True
        mode = 'a'
    elif os.path.exists(config.log_dir) and config.eval_only:
        resume = False
        mode = 'a'
    else:
        resume = False
        mode = 'w'

    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode)

    # Record config
    log_config(config, logger)

    # Set random seed
    set_seed(config.seed)

    # Data
    full_dataset = wilds.get_dataset(dataset=config.dataset,
                                     version=config.version,
                                     root_dir=config.root_dir,
                                     download=config.download,
                                     split_scheme=config.split_scheme,
                                     **config.dataset_kwargs)

    # To implement data augmentation (i.e., have different transforms
    # at training time vs. test time), modify these two lines:
    train_transform = initialize_transform(
        transform_name=config.train_transform,
        config=config,
        dataset=full_dataset)
    eval_transform = initialize_transform(transform_name=config.eval_transform,
                                          config=config,
                                          dataset=full_dataset)

    train_grouper = CombinatorialGrouper(dataset=full_dataset,
                                         groupby_fields=config.groupby_fields)

    # Initialize Weights & Biases once for the whole run. This used to sit
    # inside the per-split loop below and was therefore repeated per split.
    if config.use_wandb:
        initialize_wandb(config)

    datasets = defaultdict(dict)
    for split in full_dataset.split_dict.keys():
        # Only 'train' uses the training transform; 'train' and 'val' are the
        # verbose splits.
        if split == 'train':
            transform = train_transform
            verbose = True
        elif split == 'val':
            transform = eval_transform
            verbose = True
        else:
            transform = eval_transform
            verbose = False
        # Get subset
        datasets[split]['dataset'] = full_dataset.get_subset(
            split, frac=config.frac, transform=transform)

        if split == 'train':
            datasets[split]['loader'] = get_train_loader(
                loader=config.train_loader,
                dataset=datasets[split]['dataset'],
                batch_size=config.batch_size,
                uniform_over_groups=config.uniform_over_groups,
                grouper=train_grouper,
                distinct_groups=config.distinct_groups,
                n_groups_per_batch=config.n_groups_per_batch,
                **config.loader_kwargs)
        else:
            datasets[split]['loader'] = get_eval_loader(
                loader=config.eval_loader,
                dataset=datasets[split]['dataset'],
                grouper=train_grouper,
                batch_size=config.batch_size,
                **config.loader_kwargs)

        # Set fields
        datasets[split]['split'] = split
        datasets[split]['name'] = full_dataset.split_names[split]
        datasets[split]['verbose'] = verbose

        # Loggers (wandb reporting is limited to the verbose splits)
        datasets[split]['eval_logger'] = BatchLogger(
            os.path.join(config.log_dir, f'{split}_eval.csv'),
            mode=mode,
            use_wandb=(config.use_wandb and verbose))
        datasets[split]['algo_logger'] = BatchLogger(
            os.path.join(config.log_dir, f'{split}_algo.csv'),
            mode=mode,
            use_wandb=(config.use_wandb and verbose))

    # Logging dataset info
    # Show class breakdown if feasible
    if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size == 1 and full_dataset.n_classes <= 10:
        log_grouper = CombinatorialGrouper(dataset=full_dataset,
                                           groupby_fields=['y'])
    elif config.no_group_logging:
        log_grouper = None
    else:
        log_grouper = train_grouper
    log_group_data(datasets, log_grouper, logger)

    ## Initialize algorithm
    algorithm = initialize_algorithm(config=config,
                                     datasets=datasets,
                                     train_grouper=train_grouper)

    model_prefix = get_model_prefix(datasets['train'], config)
    if not config.eval_only:
        # Start from scratch unless a checkpoint is restored below.
        epoch_offset = 0
        best_val_metric = None

        ## Load saved results if resuming
        if resume:
            save_path = model_prefix + 'epoch:last_model.pth'
            if not os.path.exists(save_path):
                # No 'last' checkpoint: fall back to the highest numbered
                # per-epoch checkpoint found in the log directory.
                epochs = []
                for file in os.listdir(config.log_dir):
                    if not file.endswith('.pth'):
                        continue
                    pieces = file.split('epoch:')
                    if len(pieces) < 2:
                        # Unrelated .pth file without an epoch tag.
                        continue
                    try:
                        epochs.append(int(pieces[1].split('_')[0]))
                    except ValueError:
                        # e.g. 'epoch:best_model.pth' has no epoch number.
                        continue
                if epochs:
                    latest_epoch = max(epochs)
                    save_path = model_prefix + f'epoch:{latest_epoch}_model.pth'
            try:
                prev_epoch, best_val_metric = load(algorithm, save_path)
            except FileNotFoundError:
                # Deliberate best effort: with nothing to resume from,
                # training simply starts at epoch 0.
                pass
            else:
                epoch_offset = prev_epoch + 1
                logger.write(
                    f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}'
                )

        train(algorithm=algorithm,
              datasets=datasets,
              general_logger=logger,
              config=config,
              epoch_offset=epoch_offset,
              best_val_metric=best_val_metric)
    else:
        # Evaluate the requested epoch, or the 'best' checkpoint by default.
        if config.eval_epoch is None:
            eval_model_path = model_prefix + 'epoch:best_model.pth'
        else:
            eval_model_path = model_prefix + f'epoch:{config.eval_epoch}_model.pth'
        best_epoch, best_val_metric = load(algorithm, eval_model_path)
        if config.eval_epoch is None:
            epoch = best_epoch
        else:
            epoch = config.eval_epoch
        evaluate(algorithm=algorithm,
                 datasets=datasets,
                 epoch=epoch,
                 general_logger=logger,
                 config=config)

    logger.close()
    for split in datasets:
        datasets[split]['eval_logger'].close()
        datasets[split]['algo_logger'].close()
예제 #3
0
def main():
    """Evaluate a saved checkpoint and pickle the derived analysis metrics.

    Hyperparameter defaults are set in default_hyperparams.py and applied
    through ``populate_defaults`` after argument parsing.

    Steps, in order:
      1. Parse CLI arguments and pass them through ``populate_defaults``.
      2. Select the torch device, open ``log.txt`` in ``config.log_dir`` and
         record the config.
      3. Load the full dataset and, for every split, build the subset, its
         loader and its two CSV loggers.
      4. Initialize the algorithm and load either ``best_model.pth`` or
         ``<eval_epoch>_model.pth`` from ``config.log_dir``.
      5. Run ``evaluate`` and ``all_logistics``, derive the G*/I* entries and
         write them to ``tests_epoch_<epoch>.pkl`` in ``config.log_dir``.
      6. Close all loggers.
    """
    parser = argparse.ArgumentParser()

    # Required arguments
    parser.add_argument('-d',
                        '--dataset',
                        choices=supported.datasets,
                        required=True)
    parser.add_argument('--algorithm',
                        required=True,
                        choices=supported.algorithms)
    parser.add_argument(
        '--root_dir',
        required=True,
        help=
        'The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).'
    )
    # Parsed as int so that a malformed value is rejected at parse time
    # instead of at the int() conversion further down.
    parser.add_argument('--analyze_sample', type=int, default=1)

    # Dataset
    parser.add_argument(
        '--split_scheme',
        help=
        'Identifies how the train/val/test split is constructed. Choices are dataset-specific.'
    )
    parser.add_argument('--dataset_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument(
        '--download',
        default=False,
        type=parse_bool,
        const=True,
        nargs='?',
        help=
        'If true, tries to downloads the dataset if it does not exist in root_dir.'
    )
    parser.add_argument(
        '--frac',
        type=float,
        default=1.0,
        help=
        'Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.'
    )

    # Loaders
    parser.add_argument('--loader_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--train_loader', choices=['standard', 'group'])
    parser.add_argument('--uniform_over_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--distinct_groups',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--n_groups_per_batch', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--eval_loader',
                        choices=['standard'],
                        default='standard')

    # Model
    parser.add_argument('--model', choices=supported.models)
    parser.add_argument(
        '--model_kwargs',
        nargs='*',
        action=ParseKwargs,
        default={},
        help=
        'keyword arguments for model initialization passed as key1=value1 key2=value2'
    )

    # Transforms
    parser.add_argument('--train_transform', choices=supported.transforms)
    parser.add_argument('--eval_transform', choices=supported.transforms)
    parser.add_argument(
        '--target_resolution',
        nargs='+',
        type=int,
        help=
        'target resolution. for example --target_resolution 224 224 for standard resnet.'
    )
    parser.add_argument('--resize_scale', type=float)
    parser.add_argument('--max_token_length', type=int)

    # Objective
    parser.add_argument('--loss_function', choices=supported.losses)

    # Algorithm
    parser.add_argument('--groupby_fields', nargs='+')
    parser.add_argument('--group_dro_step_size', type=float)
    parser.add_argument('--coral_penalty_weight', type=float)
    parser.add_argument('--irm_lambda', type=float)
    parser.add_argument('--irm_penalty_anneal_iters', type=int)
    parser.add_argument('--algo_log_metric')
    parser.add_argument('--hsic_beta', type=float)
    parser.add_argument('--grad_penalty_lamb', type=float)
    parser.add_argument(
        '--params_regex',
        type=str,
        help='Regular expression specifying which gradients to penalize.')
    parser.add_argument('--label_cond',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--dann_lamb', type=float)
    parser.add_argument('--dann_dc_name', type=str)

    # Model selection
    parser.add_argument('--val_metric')
    parser.add_argument('--val_metric_decreasing',
                        type=parse_bool,
                        const=True,
                        nargs='?')

    # Optimization
    parser.add_argument('--n_epochs', type=int)
    parser.add_argument('--optimizer', choices=supported.optimizers)
    parser.add_argument('--lr', type=float)
    parser.add_argument('--weight_decay', type=float)
    parser.add_argument('--max_grad_norm', type=float)
    parser.add_argument('--optimizer_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})

    # Scheduler
    parser.add_argument('--scheduler', choices=supported.schedulers)
    parser.add_argument('--scheduler_kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    parser.add_argument('--scheduler_metric_split',
                        choices=['train', 'val'],
                        default='val')
    parser.add_argument('--scheduler_metric_name')

    # Evaluation
    parser.add_argument('--evaluate_all_splits',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--eval_splits', nargs='+', default=[])
    parser.add_argument('--eval_only',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--eval_epoch', default=None, type=int)
    parser.add_argument('--save_z',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)

    # Misc
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--log_every', default=50, type=int)
    parser.add_argument('--save_step', type=int)
    parser.add_argument('--save_best',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--save_last',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=True)
    parser.add_argument('--no_group_logging',
                        type=parse_bool,
                        const=True,
                        nargs='?')
    parser.add_argument('--use_wandb',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--progress_bar',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)
    parser.add_argument('--resume',
                        type=parse_bool,
                        const=True,
                        nargs='?',
                        default=False)

    config = parser.parse_args()
    config = populate_defaults(config)

    # Replace the integer GPU index by a torch.device, falling back to CPU
    # when CUDA is unavailable.
    config.device = torch.device("cuda:" + str(
        config.device)) if torch.cuda.is_available() else torch.device("cpu")

    ## Initialize logs
    # Existing logs are appended to when resuming or evaluating; in every
    # other case they are overwritten.
    if os.path.exists(config.log_dir) and (config.resume or config.eval_only):
        mode = 'a'
    else:
        mode = 'w'

    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode)

    # Record config
    log_config(config, logger)

    # Set random seed
    set_seed(config.seed)

    # Data
    full_dataset = supported.datasets[config.dataset](
        root_dir=config.root_dir,
        download=config.download,
        split_scheme=config.split_scheme,
        **config.dataset_kwargs)

    # To implement data augmentation (i.e., have different transforms
    # at training time vs. test time), modify these two lines:
    train_transform = initialize_transform(
        transform_name=config.train_transform,
        config=config,
        dataset=full_dataset)
    eval_transform = initialize_transform(transform_name=config.eval_transform,
                                          config=config,
                                          dataset=full_dataset)

    train_grouper = CombinatorialGrouper(dataset=full_dataset,
                                         groupby_fields=config.groupby_fields)

    # Initialize Weights & Biases once for the whole run. This used to sit
    # inside the per-split loop below and was therefore repeated per split.
    if config.use_wandb:
        initialize_wandb(config)

    datasets = defaultdict(dict)
    for split in full_dataset.split_dict.keys():
        # Only 'train' uses the training transform; 'train' and 'val' are the
        # verbose splits.
        if split == 'train':
            transform = train_transform
            verbose = True
        elif split == 'val':
            transform = eval_transform
            verbose = True
        else:
            transform = eval_transform
            verbose = False
        # Get subset
        datasets[split]['dataset'] = full_dataset.get_subset(
            split, frac=config.frac, transform=transform)

        if split == 'train':
            datasets[split]['loader'] = get_train_loader(
                loader=config.train_loader,
                dataset=datasets[split]['dataset'],
                batch_size=config.batch_size,
                uniform_over_groups=config.uniform_over_groups,
                grouper=train_grouper,
                distinct_groups=config.distinct_groups,
                n_groups_per_batch=config.n_groups_per_batch,
                **config.loader_kwargs)
        else:
            datasets[split]['loader'] = get_eval_loader(
                loader=config.eval_loader,
                dataset=datasets[split]['dataset'],
                grouper=train_grouper,
                batch_size=config.batch_size,
                **config.loader_kwargs)

        # Set fields
        datasets[split]['split'] = split
        datasets[split]['name'] = full_dataset.split_names[split]
        datasets[split]['verbose'] = verbose

        # Loggers (wandb reporting is limited to the verbose splits)
        datasets[split]['eval_logger'] = BatchLogger(
            os.path.join(config.log_dir, f'{split}_eval.csv'),
            mode=mode,
            use_wandb=(config.use_wandb and verbose))
        datasets[split]['algo_logger'] = BatchLogger(
            os.path.join(config.log_dir, f'{split}_algo.csv'),
            mode=mode,
            use_wandb=(config.use_wandb and verbose))

    # Logging dataset info
    if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size == 1:
        log_grouper = CombinatorialGrouper(dataset=full_dataset,
                                           groupby_fields=['y'])
    elif config.no_group_logging:
        log_grouper = None
    else:
        log_grouper = train_grouper
    log_group_data(datasets, log_grouper, logger)

    ## Initialize algorithm
    algorithm = initialize_algorithm(config=config,
                                     datasets=datasets,
                                     train_grouper=train_grouper)

    # Load the requested epoch, or the 'best' checkpoint by default.
    if config.eval_epoch is None:
        eval_model_path = os.path.join(config.log_dir, 'best_model.pth')
    else:
        eval_model_path = os.path.join(config.log_dir,
                                       f'{config.eval_epoch}_model.pth')
    best_epoch, _ = load(algorithm, eval_model_path)
    if config.eval_epoch is None:
        epoch = best_epoch
    else:
        epoch = config.eval_epoch

    results, z_splits, y_splits, c_splits = evaluate(algorithm=algorithm,
                                                     datasets=datasets,
                                                     epoch=epoch,
                                                     general_logger=logger,
                                                     config=config)

    include_test = config.evaluate_all_splits or 'test' in config.eval_splits

    logistics = all_logistics(z_splits,
                              c_splits,
                              y_splits,
                              epoch=epoch,
                              sample=int(config.analyze_sample),
                              include_test=include_test)

    # NOTE(review): the lookups below assume `results` has 'id_val' and 'val'
    # (and 'test' when include_test) entries with an 'acc_avg' metric --
    # confirm for the datasets this script is run on.
    logistics['G0'] = results['id_val']['acc_avg']
    logistics['G1'] = logistics['val_on_val']
    logistics['G2'] = logistics['trainval_on_val']
    logistics['G3'] = results['val']['acc_avg']

    logistics['I0'] = logistics['c_train']
    logistics['I1'] = logistics['c_val']
    # I2 is the mean over the per-class values.
    per_class = torch.tensor(list(logistics['c_perclass'].values()))
    logistics['I2'] = torch.mean(per_class).item()

    if include_test:
        logistics['G1_test'] = logistics['test_on_test']
        logistics['G2_test'] = logistics['traintest_on_test']
        logistics['G3_test'] = results['test']['acc_avg']

        logistics['I1_test'] = logistics['c_test']
        per_class = torch.tensor(list(logistics['c_perclass_test'].values()))
        logistics['I2_test'] = torch.mean(per_class).item()

    # Persist all collected metrics next to the logs.
    with open(os.path.join(config.log_dir, f'tests_epoch_{epoch}.pkl'),
              "wb") as f:
        pickle.dump(logistics, f)

    logger.close()
    for split in datasets:
        datasets[split]['eval_logger'].close()
        datasets[split]['algo_logger'].close()