コード例 #1
0
ファイル: test_training.py プロジェクト: larsmennen/keras
def test_check_bad_shape():
    """A target whose shape differs from the declared output shape is rejected.

    The target array has shape (2, 3, 5) while the output shape is
    (2, 3, 6); with categorical crossentropy this must raise a ValueError
    whose message mentions the shape mismatch.
    """
    a = np.random.random((2, 3, 5))
    # `match` is searched against the exception message itself. The old
    # `str(exc)` form stringifies the ExceptionInfo wrapper (its repr in
    # pytest >= 5), not the message.
    with pytest.raises(ValueError, match='targets to have the same shape'):
        _check_loss_and_target_compatibility([a], [K.categorical_crossentropy],
                                             [(2, 3, 6)])
コード例 #2
0
ファイル: test_training.py プロジェクト: larsmennen/keras
def test_check_last_is_one():
    """A categorical-crossentropy target whose last axis has size 1 is rejected.

    The target shape (2, 3, 1) equals the declared output shape, yet a
    ValueError is expected. Presumably a size-1 last axis is flagged
    because it suggests sparse labels rather than one-hot targets.
    """
    a = np.random.random((2, 3, 1))
    # Check the exception message via `match`; `str(exc)` on the
    # ExceptionInfo wrapper is not the message in pytest >= 5.
    with pytest.raises(ValueError, match='You are passing a target array'):
        _check_loss_and_target_compatibility([a], [K.categorical_crossentropy],
                                             [a.shape])
コード例 #3
0
def test_check_not_failing():
    """Compatible target shapes pass the loss/target check without raising.

    Two expected shapes are exercised against the same (2, 1, 3) target:
    the target's exact shape, and one whose middle dimension is `None`.
    """
    target = np.random.random((2, 1, 3))
    accepted_shapes = (target.shape, (2, None, 3))
    for expected_shape in accepted_shapes:
        _check_loss_and_target_compatibility(
            [target], [losses.categorical_crossentropy], [expected_shape])
コード例 #4
0
ファイル: keras_ops.py プロジェクト: windcr/SRGAN
def _standardize_user_data(model,
                           x,
                           y,
                           sample_weight=None,
                           class_weight=None,
                           check_batch_dim=True,
                           batch_size=None):
    if not hasattr(model, 'optimizer'):
        raise Exception('You must compile a model before training/testing.'
                        ' Use `model.compile(optimizer, loss)`.')

    output_shapes = []
    for output_shape, loss_fn in zip(model.internal_output_shapes,
                                     model.loss_functions):
        if loss_fn.__name__ == 'sparse_categorical_crossentropy':
            output_shapes.append(output_shape[:-1] + (1, ))
        elif getattr(losses, loss_fn.__name__, None) is None:
            output_shapes.append(None)
        else:
            output_shapes.append(output_shape)
    x = _standardize_input_data(x,
                                model.input_names,
                                model.internal_input_shapes,
                                exception_prefix='model input')
    y = _standardize_input_data(y,
                                model.output_names,
                                output_shapes,
                                exception_prefix='model target')
    sample_weights = _standardize_sample_weights(sample_weight,
                                                 model.output_names)
    class_weights = _standardize_class_weights(class_weight,
                                               model.output_names)
    sample_weights = [
        _standardize_weights(ref, sw, cw, mode) for (ref, sw, cw, mode) in zip(
            y, sample_weights, class_weights, model.sample_weight_modes)
    ]
    '''
    We only need to comment out check_array_lengeh(x, y, weights) in the next line to
    let the model compile and train.
    '''
    # check_array_lengths(x, y, sample_weights)

    _check_loss_and_target_compatibility(y, model.loss_functions,
                                         model.internal_output_shapes)
    if model.stateful and batch_size:
        if x[0].shape[0] % batch_size != 0:
            raise Exception('In a stateful network, '
                            'you should only pass inputs with '
                            'a number of samples that can be '
                            'divided by the batch size. Found: ' +
                            str(x[0].shape[0]) + ' samples')
    return x, y, sample_weights
コード例 #5
0
ファイル: test_training.py プロジェクト: pkainz/keras
def test_check_bad_shape():
    """A target whose shape differs from the declared output shape is rejected.

    The target array has shape (2, 3, 5) while the output shape is
    (2, 3, 6); with categorical crossentropy this must raise a ValueError
    whose message mentions the shape mismatch.
    """
    a = np.random.random((2, 3, 5))
    # `match` is searched against the exception message itself. The old
    # `str(exc)` form stringifies the ExceptionInfo wrapper (its repr in
    # pytest >= 5), not the message.
    with pytest.raises(ValueError, match='targets to have the same shape'):
        _check_loss_and_target_compatibility([a], [losses.categorical_crossentropy], [(2, 3, 6)])
コード例 #6
0
ファイル: test_training.py プロジェクト: pkainz/keras
def test_check_last_is_one():
    """A categorical-crossentropy target whose last axis has size 1 is rejected.

    The target shape (2, 3, 1) equals the declared output shape, yet a
    ValueError is expected. Presumably a size-1 last axis is flagged
    because it suggests sparse labels rather than one-hot targets.
    """
    a = np.random.random((2, 3, 1))
    # Check the exception message via `match`; `str(exc)` on the
    # ExceptionInfo wrapper is not the message in pytest >= 5.
    with pytest.raises(ValueError, match='You are passing a target array'):
        _check_loss_and_target_compatibility([a], [losses.categorical_crossentropy], [a.shape])
コード例 #7
0
ファイル: test_training.py プロジェクト: pkainz/keras
def test_check_not_failing():
    """The loss/target check accepts an exact and a partially-unknown shape.

    The same (2, 1, 3) target is checked against its own shape and against
    (2, None, 3); neither call is expected to raise.
    """
    sample = np.random.random((2, 1, 3))
    for declared in (sample.shape, (2, None, 3)):
        _check_loss_and_target_compatibility(
            [sample],
            [losses.categorical_crossentropy],
            [declared],
        )