# Example 1
def run(options):
    """Prepare an experiment folder and launch model training.

    Configures the device and random seed, guards against overwriting an
    existing experiment folder, prompts for a missing experiment comment,
    snapshots the options and source code, then trains the model.
    """
    if options.device is not None:
        pt_training.set_device(options.device)
    device = get_device()

    pt_training.set_random_seed(options.seed)

    exp_folder = join(config.EXP_FOLDER, options.folder)

    # Refuse to clobber an existing experiment unless overwriting is allowed.
    if not options.do_overwrite and abort_due_to_overwrite_safety(exp_folder):
        print('Process aborted.')
        return

    check_mkdir(exp_folder)

    # Force the user to record why this experiment was run.
    if options.comment is None:
        print(
            'You forgot to add a comment to your experiment. Please add something!'
        )
        options.comment = input('Comment: ')

    save_options(join(exp_folder, 'opt.json'), options)
    copy_source(exp_folder, do_not_copy_folders=config.DO_NOT_COPY)
    _train_model(options, exp_folder, device)
def save_np(data, labels, path):
    """Persist a dataset to ``<path>.npy`` and its labels to ``<path>_labels.npy``.

    Parent directories are created as needed (via ``check_mkdir``).

    Args:
        data: array-like samples passed straight to ``np.save``.
        labels: array-like labels passed straight to ``np.save``.
        path: output path prefix (without the ``.npy`` extension).
    """
    _save_array(data, f'{path}.npy')
    _save_array(labels, f'{path}_labels.npy')


def _save_array(array, filename):
    """Save a single array to *filename*, creating parent dirs first."""
    check_mkdir(filename)
    np.save(filename, array)
    # Bug fix: the original f-string had no placeholder, so the actual
    # output path was never printed.
    print(f'saving to {filename} finished')
# Example 3
def create_splits(num_splits=5, split_perc=.8):
    """Create random per-class train/test splits and write them to JSON.

    For every split, ``split_perc`` of each class's images is sampled
    without replacement for training; the rest go to test. Split *i* is
    written to ``<config.NAT_IM_BASE>/splits/<i>.json`` with the layout
    ``{'train': [...], 'test': [...]}``.

    Args:
        num_splits: number of independent splits to generate.
        split_perc: fraction of each class used for training (0..1).
    """
    images_ = _get_images_sorted_by_class()

    for i in range(num_splits):
        train_images = []
        test_images = []

        # for each class get split_perc % for training
        for tmp in images_.values():
            n_train = int(split_perc * len(tmp))

            # grab images without replacement
            tmp_train = random.sample(tmp, n_train)

            # Bug fix: the old `list(set(tmp) - set(tmp_train))` produced a
            # hash-order-dependent test list that differed between runs even
            # with a fixed random seed; preserve the original class order.
            train_set = set(tmp_train)
            tmp_test = [im for im in tmp if im not in train_set]

            train_images += tmp_train
            test_images += tmp_test

        out_path = join(config.NAT_IM_BASE, 'splits', f'{i}.json')
        check_mkdir(out_path)
        with open(out_path, 'w') as file:
            json.dump({'train': train_images, 'test': test_images}, file)
# Example 4
def _train_model(options, folder, device):
    """Assemble model, optimizer, scheduler and data loaders, then train.

    Checkpoints go to ``<folder>/model.pt`` and a JSON training log to
    ``<folder>/log.json``. The many ``options.*`` attributes are assumed
    to come from the experiment's parsed command-line options.
    """
    model_out = join(folder, 'model.pt')
    log_file = join(folder, 'log.json')
    check_mkdir(log_file)

    model = get_model(options.model_type, options.in_dim, options.out_dim,
                      device)

    # Optionally warm-start from an existing checkpoint. strict=False lets
    # partially matching state dicts load (e.g. a different head).
    if options.model_path is not None:
        model_sd = torch.load(options.model_path).state_dict()
        model.load_state_dict(model_sd, strict=False)

    write_model_specs(folder, model)

    optimizer = pt_training.get_optimizer(options.opt,
                                          model.parameters(),
                                          options.lr,
                                          momentum=options.mom,
                                          weight_decay=options.wd)

    # A scheduler is only created when LR steps (or fixed steps) are
    # configured; otherwise training runs at a constant learning rate.
    if options.lr_steps > 0 or options.fixed_steps is not None:
        scheduler = pt_training.get_scheduler(options.lr_steps,
                                              options.epochs,
                                              optimizer,
                                              fixed_steps=options.fixed_steps)
    else:
        print('no scheduling used')
        scheduler = None

    data_loaders = get_data_loaders(options.dataset,
                                    options.bs,
                                    split_index=options.split_index)

    if options.early_stopping:
        # sv_int == -1 presumably means "save only the best model";
        # periodic checkpointing would conflict with early stopping.
        # TODO confirm against pt_training.Training's save_steps semantics.
        assert options.sv_int == -1
        early_stopping = pt_training.EarlyStopping(options.es_metric,
                                                   get_max=True,
                                                   epoch_thres=options.epochs)
    else:
        early_stopping = None

    train_interface = get_interface(
        'classification',
        model,
        device,
        Printer(config.PRINT_FREQUENCY, log_file),
    )

    training = pt_training.Training(optimizer,
                                    data_loaders['train'],
                                    train_interface,
                                    scheduler=scheduler,
                                    printer=train_interface.printer,
                                    save_path=model_out,
                                    save_steps=options.sv_int,
                                    val_data_loader=data_loaders['val'],
                                    early_stopping=early_stopping,
                                    save_state_dict=True,
                                    test_data_loader=data_loaders['test'])

    # Run the training loop for the configured number of epochs.
    training(options.epochs)