コード例 #1
0
def test_model_construction(config_simple):
    """
    Check that a network and its criterion can be built from a config.

    Parameters
    ----------
    config_simple: dict
        Fixture-provided configuration; only its ``'model'`` section is used.
    """
    model_cfg = config_simple['model']
    network_cls, criterion_cls = factories.construct(model_cfg['name'])
    network = network_cls(model_cfg)
    # The criterion is instantiated only to verify that construction works.
    _ = criterion_cls(model_cfg)

    # Switch the network to eval mode and back to train mode; the test only
    # requires that both calls succeed.
    network.eval()
    network.train()
コード例 #2
0
def test_model_full(config_full):
    """
    Test whether a model can be trained end to end.

    Runs the parsers and the trainval machinery (``prepare`` followed by
    ``train_loop``). The test is skipped when the network or the criterion
    does not expose an ``INPUT_SCHEMA``. It is also skipped when ``prepare``
    raises ``FileNotFoundError`` (no data file available).

    Parameters
    ----------
    config_full: dict
        Generated by a fixture, dummy config to allow networks to run.
        It is mostly empty; we rely on the networks' default config.
        It is modified in place: the dataset schema and the
        ``network_input`` / ``loss_input`` keys are (re)written.
    """
    config = config_full
    model, criterion = factories.construct(config['model']['name'])
    net = model(config['model'])
    loss = criterion(config['model'])

    if not hasattr(net, "INPUT_SCHEMA"):
        pytest.skip('No test defined for network of %s' %
                    config['model']['name'])

    if not hasattr(loss, "INPUT_SCHEMA"):
        pytest.skip('No test defined for criterion of %s' %
                    config['model']['name'])

    # Setup configuration to have all necessary I/O keys: one dataset schema
    # entry per required input, network inputs first, then loss inputs.
    schema = {}
    config['iotool']['dataset']['schema'] = schema
    config['model']['network_input'] = []
    config['model']['loss_input'] = []
    num_net_inputs = len(net.INPUT_SCHEMA)
    for i, entry in enumerate(net.INPUT_SCHEMA + loss.INPUT_SCHEMA):
        # Entries are indexed as (parser name, parser return types, ...).
        parser_name, parser_return_types = entry[0], entry[1]
        # NOTE(review): ``branch`` is defined outside this block; it is used
        # as branch[parser_name][return_type] -> iterable of extra schema
        # items appended after the parser name. Confirm against its definition.
        parser_args = [parser_name]
        for return_type in parser_return_types:
            parser_args.extend(branch[parser_name][return_type])
        schema[i] = parser_args
        # Indices below the network's input count feed the network; the
        # remaining ones feed the loss.
        input_key = 'network_input' if i < num_net_inputs else 'loss_input'
        config['model'][input_key].append(i)

    process_config(config)
    # Try opening LArCV data file; skip (rather than fail) when it is absent.
    try:
        handlers = prepare(config)
    except FileNotFoundError:
        pytest.skip('File not found to test the loader.')

    train_loop(config, handlers)
コード例 #3
0
def test_model_forward(config_simple, N, num_voxels_low, num_voxels_high):
    """
    Test a network forward pass and a loss forward/backward pass.

    Uses only randomly generated numpy input arrays; parsers are not
    involved here. The test is skipped when the network, or later the
    criterion, does not expose an ``INPUT_SCHEMA``.

    Parameters
    ----------
    config_simple: dict
        Generated by a fixture, dummy config to allow networks to run.
        It is mostly empty; we rely on the networks' default config.
    N: int
        Spatial size.
    num_voxels_low: int
        Lower boundary for generating the (random) number of voxels.
    num_voxels_high: int
        Upper boundary for generating the (random) number of voxels.
    """
    config = config_simple
    model, criterion = factories.construct(config['model']['name'])
    net = model(config['model'])
    loss = criterion(config['model'])

    if not hasattr(net, "INPUT_SCHEMA"):
        pytest.skip('No test defined for network of %s' % config['model']['name'])

    net_input, voxels = generate_data(N, net.INPUT_SCHEMA,
                                      num_voxels_low=num_voxels_low,
                                      num_voxels_high=num_voxels_high)
    output = net.forward(net_input)

    # Checked only after the network forward pass, so the network is still
    # exercised when its criterion has no schema.
    if not hasattr(loss, "INPUT_SCHEMA"):
        pytest.skip('No test defined for criterion of %s' % config['model']['name'])

    # Reuse the voxels generated for the network so the loss inputs are
    # consistent with the network inputs.
    loss_input = generate_data(N, loss.INPUT_SCHEMA,
                               num_voxels_low=num_voxels_low,
                               num_voxels_high=num_voxels_high,
                               voxels=voxels,
                               loss=True)[0]
    res = loss.forward(output, *loss_input)

    # Backpropagate to check that gradients can be computed from the loss.
    res['loss'].backward()
コード例 #4
0
def config(request):
    """
    Fixture building a minimal configuration dictionary for one model name.

    Parameters
    ----------
    request
        Fixture request object (presumably pytest's); ``request.param`` is
        used as the model name.

    Returns
    -------
    dict
        Configuration with ``'iotool'``, ``'training'`` and ``'model'``
        sections. Chain models get one empty entry per module listed in
        ``MODULES``; other models get a single entry keyed by their name.
    """
    name = request.param
    model_cls, _ = factories.construct(name)

    modules = ({module: {} for module in model_cls.MODULES}
               if 'chain' in name else {name: {}})

    model_config = {
        'name': name,
        'modules': modules,
        'network_input': ['input_data', 'segment_label'],
        'loss_input': ['segment_label'],
    }
    return {
        'iotool': {'batch_size': 1, 'minibatch_size': 1},
        'training': {},
        'model': model_config,
    }