コード例 #1
0
def test_required():
    """A REQUIRED argument absent from the config raises, naming only that key.

    The error message must mention the missing key and must not mention keys
    that are actually present in the configuration.
    """
    config = Configuration({'present': 1})
    with pytest.raises(ConfigurationError) as exc_info:
        config.configure(dict, missing=config.REQUIRED)
    logger.debug(exc_info.exconly())

    message = str(exc_info.value)
    assert "'missing'" in message
    assert "'present'" not in message
コード例 #2
0
def main():
    """Compute per-style note statistics and print them to stdout as JSON.

    Reads a gzipped JSON metadata file mapping segment keys to metadata dicts,
    groups the keys by the ``--style-key`` field, loads each style's
    NoteSequences from the LMDB file ``DB-FILE``, computes statistics with
    ``extract_all_stats`` (optionally configured by ``--config``), and writes a
    compact JSON object ``{stat_name: {style: stat}}`` followed by a newline.
    Styles whose sequences contain fewer than two notes in total are skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('db_path', metavar='DB-FILE')
    parser.add_argument('meta_path', metavar='METADATA-FILE')
    parser.add_argument('--config', metavar='YAML-FILE', default=None)
    parser.add_argument('--max-segments-per-style', type=int, default=None)
    parser.add_argument('--style-key', type=str, default='style')
    args = parser.parse_args()

    if args.config:
        with open(args.config, 'rb') as f:
            _cfg = Configuration.from_yaml(f)
    else:
        _cfg = Configuration({})

    # Fixed seed so the per-style subsampling below is reproducible.
    random.seed(42)

    # JSON is UTF-8 by specification; don't depend on the locale's default
    # text encoding.
    with gzip.open(args.meta_path, 'rt', encoding='utf-8') as f:
        metadata = json.load(f)

    # Group keys by style in one pass over the sorted keys, instead of
    # re-sorting and re-scanning every key once per style. Each per-style list
    # is still in sorted key order, so the seeded shuffles are unchanged.
    keys_by_style = collections.defaultdict(list)
    for key in sorted(metadata):
        keys_by_style[metadata[key][args.style_key]].append(key)
    styles = sorted(keys_by_style)

    def get_sequences(style):
        """Yield tempo-normalized NoteSequences for a random subset of `style`."""
        keys = list(keys_by_style[style])
        random.shuffle(keys)
        # A `None` limit slices to the whole list.
        keys = keys[:args.max_segments_per_style]
        _LOGGER.info(f'Processing style {style} ({len(keys)} segments)...')
        with lmdb.open(args.db_path, subdir=False, readonly=True,
                       lock=False) as db:
            with db.begin(buffers=True) as txn:
                for key in keys:
                    val = txn.get(key.encode())
                    seq = NoteSequence.FromString(val)
                    yield note_sequence_utils.normalize_tempo(seq)

    results = collections.defaultdict(dict)
    for style in styles:
        # Materialize eagerly so the LMDB environment is closed before the
        # statistics are computed.
        sequences = list(get_sequences(style))

        total_notes = sum(len(seq.notes) for seq in sequences)
        if total_notes < 2:
            _LOGGER.info(f'Skipping style {style} with {total_notes} note(s).')
            continue

        stats = _cfg.configure(extract_all_stats, data=sequences)
        for stat_name, stat in stats.items():
            results[stat_name][style] = stat

    # NOTE(review): the `default` hook assumes every non-JSON value is an
    # array-like with `.tolist()` (presumably NumPy arrays) — confirm.
    json.dump(dict(results),
              sys.stdout,
              default=lambda a: a.tolist(),
              separators=(',', ':'))
    sys.stdout.write('\n')
コード例 #3
0
def extract_note_stats(data, *, _cfg):
    """Compute configured histogram statistics over note features of `data`.

    Args:
        data: input passed through to ``note_features.extract_features``.
        _cfg: configuration. ``_cfg['features'][key]`` configures each feature
            listed in ``NOTE_FEATURE_DEFS``; ``_cfg['stats']`` (falling back to
            ``NOTE_STAT_DEFS`` when absent) lists the histograms to compute.

    Returns:
        dict mapping each stat name to its histogram (a NumPy array).
    """
    features = {
        key: _cfg['features'][key].configure(feat_type, **kwargs)
        for key, (feat_type, kwargs) in NOTE_FEATURE_DEFS.items()
    }
    feature_values = note_features.extract_features(data, features)

    @configurable
    def make_hist(name, normed=True, *, _cfg):
        """Build one (optionally normalized) histogram over the configured features.

        Returns:
            ``(name, hist)`` so that the results can be collected into a dict.
        """
        feature_names = [f['name'] for f in _cfg.get('features')]
        # Normalizing an empty histogram divides by zero; silence the warnings
        # here and zero out the resulting NaNs below.
        with np.errstate(divide='ignore', invalid='ignore'):
            # `normed` was removed from np.histogramdd in NumPy 1.24; for
            # histogramdd it was an alias of `density`.
            hist, _ = np.histogramdd(
                sample=[feature_values[feat_name]
                        for feat_name in feature_names],
                bins=[
                    _cfg['features'][i]['bins'].configure(
                        features[feat_name].get_bins)
                    for i, feat_name in enumerate(feature_names)
                ],
                density=normed)
        np.nan_to_num(hist, copy=False)

        return name, hist

    # Create a dictionary mapping stat names to their values
    stats_cfg = _cfg['stats'] if 'stats' in _cfg else Configuration(
        NOTE_STAT_DEFS)
    return dict(stats_cfg.configure_list(make_hist))
コード例 #4
0
def main(args):
    """Train the model or run it on pickled input, depending on `args.action`.

    The model configuration is read from ``args.config`` or, if that is not
    given, from ``<args.logdir>/model.yaml``. For the ``run`` action, inputs
    are unpickled from ``args.input_file`` and the decoded model outputs are
    pickled to ``args.output_file``. Any other action does nothing after setup.
    """
    config_file = args.config if args.config else os.path.join(
        args.logdir, 'model.yaml')
    with open(config_file, 'rb') as config_fp:
        config = Configuration.from_yaml(config_fp)
    logger.debug(config)

    model, trainer, encoding = config.configure(
        _init,
        logdir=args.logdir,
        train_mode=args.action == 'train',
        sampling_seed=getattr(args, 'seed', None),
    )

    if args.action == 'train':
        trainer.train()
        return
    if args.action != 'run':
        return

    trainer.load_variables(checkpoint_file=args.checkpoint)
    # NOTE(review): pickle.load can execute arbitrary code; only feed this
    # trusted input files.
    inputs = pickle.load(args.input_file)
    generator = _make_data_generator(encoding, inputs)
    dataset = make_simple_dataset(
        generator,
        output_types=(tf.int32, tf.int32, tf.int32),
        output_shapes=([None], [None], [None]),
        batch_size=args.batch_size,
    )

    output_ids = model.run(trainer.session, dataset, args.sample,
                           args.softmax_temperature)
    decoded = list(map(encoding.decode, output_ids))
    pickle.dump(decoded, args.output_file)
コード例 #5
0
def test_none_value():
    """A keyword passed to `configure` is not visible via `_cfg.get` on a None config."""
    @configurable
    def read_a(*, _cfg):
        return _cfg.get('a')

    assert Configuration(None).configure(read_a, a=1) is None
コード例 #6
0
def main():
    """Parse the command line, build the Experiment and run the chosen action.

    The experiment is configured from ``<logdir>/model.yaml``. The sub-command
    (``train``, ``run-midi`` or ``run-test``) selects the handler
    (``Experiment.train``, ``Experiment.run_midi`` or ``Experiment.run_test``),
    which is called with the experiment and the parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--logdir',
                        type=str,
                        required=True,
                        help='model directory')
    parser.set_defaults(train_mode=False, sampling_seed=None)
    # Sub-commands are optional by default in Python 3. Without this, a bare
    # invocation would crash on `args.func` with an AttributeError instead of
    # printing a usage error. `dest` lets the error message name the argument.
    subparsers = parser.add_subparsers(title='action', dest='action')
    subparsers.required = True

    def add_run_args(subparser):
        """Add the options shared by the `run-midi` and `run-test` commands."""
        subparser.add_argument('--checkpoint', default=None, type=str)
        subparser.add_argument('--batch-size', default=None, type=int)
        subparser.add_argument('--sample', action='store_true')
        subparser.add_argument('--softmax-temperature', default=1., type=float)
        subparser.add_argument('--seed', type=int, dest='sampling_seed')
        subparser.add_argument(
            '--filters',
            choices=['training', 'program'],
            default='program',
            help='how to filter the input; training: use the same filters as '
            'during training; program: filter by MIDI program')

    subparser = subparsers.add_parser('train')
    subparser.set_defaults(func=Experiment.train, train_mode=True)

    subparser = subparsers.add_parser('run-midi')
    subparser.set_defaults(func=Experiment.run_midi)
    subparser.add_argument('source_file', metavar='INPUTFILE')
    subparser.add_argument('style_file', metavar='STYLEFILE')
    subparser.add_argument('output_file', metavar='OUTPUTFILE')
    add_run_args(subparser)
    subparser.add_argument('-b', '--bars-per-segment', default=8, type=int)

    subparser = subparsers.add_parser('run-test')
    subparser.set_defaults(func=Experiment.run_test)
    subparser.add_argument('source_db', metavar='INPUTDB')
    subparser.add_argument('style_db', metavar='STYLEDB')
    subparser.add_argument('key_pairs', metavar='KEYPAIRS')
    subparser.add_argument('output_db', metavar='OUTPUTDB')
    add_run_args(subparser)

    args = parser.parse_args()

    config_file = os.path.join(args.logdir, 'model.yaml')
    with open(config_file, 'rb') as f:
        config = Configuration.from_yaml(f)
    _LOGGER.debug(config)

    experiment = config.configure(Experiment,
                                  logdir=args.logdir,
                                  train_mode=args.train_mode,
                                  sampling_seed=args.sampling_seed)
    args.func(experiment, args)
コード例 #7
0
def test_configurable_with_kwargs():
    """Extra config keys and extra call kwargs are both routed into `**kwargs`."""
    @configurable
    def collect(a, *, _cfg, **kwargs):
        del _cfg  # not used by this function
        return a, kwargs

    config = Configuration({'a': 1, 'c': 3})
    assert config.configure(collect, b=2) == (1, {'b': 2, 'c': 3})
コード例 #8
0
def test_maybe_configure():
    """`maybe_configure` gives None for a missing key and configures a present one."""
    @configurable
    def probe(*, _cfg):
        missing = _cfg['missing'].maybe_configure(dict)
        present = _cfg['present'].maybe_configure(dict)
        return missing, present

    config = Configuration({'present': {'a': 1}})
    assert config.configure(probe) == (None, {'a': 1})
コード例 #9
0
def test_get_unused_keys():
    """`get_unused_keys` reports exactly the config keys that were never consumed.

    Keys read through `configure`, `configure_list` or `get` count as used.
    Merely indexing a key (``config['unused']``) does not.
    """
    @configurable
    def g(a):
        pass

    @configurable
    def f(a, b, *, _cfg):
        _cfg['g'].configure()
        _cfg['g_list'].configure_list(g)

    config = Configuration({
        'a': 1,
        'b': 2,
        'unused': {'a': 1, 'b': 2},
        'g': {'class': g, 'a': 1, 'unused': 2},
        'g_list': [{'a': 1, 'unused': 2}],
        'x': {'a': 1, 'unused': {'b': 2}},
    })
    config.configure(f)
    config['x']['a'].get()
    config['unused']  # indexing only; the value itself is never read

    with pytest.warns(ConfigurationWarning):
        unused_keys = config.get_unused_keys(warn=True)

    expected_keys = {'unused', 'g.unused', 'g_list[0].unused', 'x.unused'}
    assert set(unused_keys) == expected_keys
コード例 #10
0
def test_configurable_with_params():
    """With `params=[...]`, only the listed parameters are taken from the config.

    ``b`` comes from the call even though the config sets it to ``None``.
    """
    @configurable(params=['a', 'c'])
    def pick(a, b, c, d=4, *, _cfg):
        del _cfg  # not used by this function
        return a, b, c, d

    config = Configuration({'a': 1, 'b': None, 'c': 3, 'd': 4})
    assert config.configure(pick, b=2) == (1, 2, 3, 4)
コード例 #11
0
def init_models():
    """Build and load every model listed in ``app.config['MODELS']``.

    Each model's directory is ``MODEL_ROOT/<logdir>``, where ``logdir``
    defaults to the model's name. Its ``model.yaml`` configures a
    ``roll2seq_style_transfer.Experiment`` inside a dedicated ``tf.Graph``,
    stored in ``model_graphs``. The experiment is stored in ``models`` and its
    variables are loaded with the optional ``load_variables`` kwargs.
    """
    for model_name, model_cfg in app.config['MODELS'].items():
        logdir = os.path.join(app.config['MODEL_ROOT'],
                              model_cfg.get('logdir', model_name))
        config_path = os.path.join(logdir, 'model.yaml')
        with open(config_path, 'rb') as config_file:
            config = Configuration.from_yaml(config_file)

        graph = tf.Graph()
        model_graphs[model_name] = graph
        with graph.as_default():
            experiment = config.configure(
                roll2seq_style_transfer.Experiment,
                logdir=logdir,
                train_mode=False)
            models[model_name] = experiment
            load_kwargs = model_cfg.get('load_variables', {})
            experiment.trainer.load_variables(**load_kwargs)
コード例 #12
0
def test_configure_function():
    """Config values override call kwargs, including inside a nested `configure`.

    Top-level ``a`` and ``c`` from the config win over the call's ``a=0`` and
    ``c=300``. The nested ``obj`` config then overrides ``a`` again and adds
    ``f``.
    """
    @configurable
    def build(a, b, c, d, e=5, *, _cfg):
        return _cfg['obj'].configure(dict, a=a, b=b, c=c, d=d, e=e)

    config = Configuration({
        'a': 10,
        'b': 2,
        'c': 3,
        'obj': {'a': 1, 'f': 6},
    })
    result = config.configure(build, a=0, c=300, d=4)
    assert result == {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6}
コード例 #13
0
def test_bind():
    """`bind` returns a callable whose later kwargs merge with the bound ones.

    Config values still take precedence over bound/call kwargs, both for the
    configurable class and for the nested ``obj`` binding made through
    ``self._cfg``.
    """
    @configurable
    class A:
        def __init__(self, a, b, c, d, e=5):
            assert 'obj' in self._cfg
            make_obj = self._cfg['obj'].bind(dict, a=a, c=c, d=d, e=e)
            self.obj = make_obj(b=b)

    config = Configuration({
        'a': 10,
        'b': 2,
        'c': 3,
        'obj': {'a': 1, 'f': 6},
    })
    make_a = config.bind(A, a=0, c=300)
    instance = make_a(d=4)
    assert instance.obj == {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6}
コード例 #14
0
def main(args):
    """Train the model, or sample from a checkpoint, depending on `args.action`.

    The model configuration is read from ``args.config`` or, if that is not
    given, from ``<args.logdir>/model.yaml``. For ``sample``, the decoded
    samples are pickled to ``args.output_file``. Any other action does nothing
    after setup.
    """
    config_file = args.config if args.config else os.path.join(
        args.logdir, 'model.yaml')
    with open(config_file, 'rb') as config_fp:
        config = Configuration.from_yaml(config_fp)
    logger.debug(config)

    model, trainer, encoding = config.configure(
        _init,
        logdir=args.logdir,
        train_mode=args.action == 'train',
        sampling_seed=getattr(args, 'seed', None),
    )

    if args.action == 'train':
        trainer.train()
        return
    if args.action != 'sample':
        return

    trainer.load_variables(checkpoint_file=args.checkpoint)
    sampled_ids = model.sample(
        session=trainer.session,
        batch_size=args.batch_size,
        softmax_temperature=args.softmax_temperature,
    )
    decoded = list(map(encoding.decode, sampled_ids))
    pickle.dump(decoded, args.output_file)
コード例 #15
0
def test_configure_list():
    """`configure_list` configures each entry, recursing through nested lists.

    Entries that name a ``class`` are built with it. The entry without one is
    built with the function passed to `configure_list`, which here recurses.
    """
    @configurable
    def build_items(*, _cfg):
        return _cfg['items'].configure_list(build_items)

    nested = {
        'items': [
            {'class': dict, 'y': 2},
            {'class': dict, 'z': 3},
        ],
    }
    config = Configuration({
        'items': [
            {'class': dict, 'x': 1},
            nested,
        ],
    })
    assert config.configure(build_items) == [{'x': 1}, [{'y': 2}, {'z': 3}]]
コード例 #16
0
def test_get_unused_keys_with_all_children_used():
    """A parent key is not reported once every one of its children has been read."""
    cfg = Configuration({'a': {'x': 1, 'y': 2}})
    for child in ('x', 'y'):
        cfg['a'][child].get()

    assert cfg.get_unused_keys() == []
コード例 #17
0
def main():
    """Train a MusicPerformer model configured by ``<model-dir>/config.yaml``.

    Optionally logs the run to Neptune (best-effort), seeds NumPy and PyTorch,
    builds the token representation, the training and validation data loaders
    and the model, runs ``train`` with checkpoints written to ``--model-dir``,
    and finally warns about config keys that were never used.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-dir', type=str, required=True)
    parser.add_argument('--name', type=str)
    parser.add_argument('--load-params', type=str)
    parser.add_argument('--load-optim', type=str)
    args = parser.parse_args()

    cfg_path = os.path.join(args.model_dir, 'config.yaml')
    cfg = Configuration.from_yaml_file(cfg_path)

    # Neptune logging is best-effort: on failure, disable it for the rest of
    # the run instead of aborting training.
    global neptune
    if neptune:
        try:
            neptune.init()
            neptune.create_experiment(args.name or args.model_dir,
                                      upload_source_files=[],
                                      params=dict(
                                          flatdict.FlatterDict(cfg.get(),
                                                               delimiter='.')))
        except neptune.exceptions.NeptuneException:
            neptune = None
            traceback.print_exc()

    seed = cfg.get('seed', 0)
    np.random.seed(seed)
    torch.random.manual_seed(seed)

    representation, start_id, end_id = cfg.configure(make_representation)
    print('Vocab size:', len(representation.vocab))

    def encode(music: muspy.Music):
        """Encode `music` to token ids, prefixed with the start token."""
        encoded = representation.encode(music)
        return np.concatenate([[start_id], encoded])

    data_train = muspy.MusicDataset(cfg.get('train_data_path'))
    data_train = cfg['data_augmentation'].configure(AugmentedDataset,
                                                    dataset=data_train,
                                                    seed=seed)
    data_train_pt = data_train.to_pytorch_dataset(factory=encode)

    model = cfg['model'].configure(
        MusicPerformer,
        n_token=len(representation.vocab),
    ).to(DEVICE)

    # Pad with the end token up to the model's maximum length.
    collate_fn = functools.partial(collate_padded,
                                   pad_value=end_id,
                                   max_len=model.max_len)

    # The keyword arguments below are defaults; values under `data_loader` /
    # `val_data_loader` in the config take precedence.
    train_loader = cfg['data_loader'].configure(DataLoader,
                                                dataset=data_train_pt,
                                                collate_fn=collate_fn,
                                                batch_size=1,
                                                shuffle=True,
                                                num_workers=24)

    # Test the configured value itself, not the truthiness of the config
    # node `cfg['val_data_paths']`. If the node were truthy while the key is
    # missing, `.items()` would be called on None.
    val_data_paths = cfg.get('val_data_paths')
    val_loaders = {}
    if val_data_paths:
        val_loaders = {
            name: cfg['val_data_loader'].configure(
                DataLoader,
                dataset=muspy.MusicDataset(path).to_pytorch_dataset(
                    factory=encode),
                collate_fn=collate_fn,
                batch_size=1,
                shuffle=False,
                num_workers=24)
            for name, path in val_data_paths.items()
        }

    cfg['training'].configure(train,
                              model=model,
                              ckpt_dir=args.model_dir,
                              pretrained_param_path=args.load_params,
                              optimizer_path=args.load_optim,
                              train_dloader=train_loader,
                              val_dloaders=val_loaders,
                              pad_index=end_id)

    cfg.get_unused_keys(warn=True)