def test_nchw_backwardweights_filter3x5(request, tensile_client_dir, tmp_path):
    """Exercise ``ConvolutionBackwardWeights`` with NCHW layout and a 3x5 filter.

    An empty problem-type dict is handed to ``Convolution``; the test then
    checks selected entries of that dict and forwards everything to
    ``YamlBuilder.run_tensile_client``.

    Args:
        request: Forwarded unchanged to ``run_tensile_client`` (presumably
            the pytest ``request`` fixture -- confirm against conftest).
        tensile_client_dir: Forwarded unchanged to ``run_tensile_client``.
        tmp_path: Forwarded unchanged to ``run_tensile_client``.
    """
    problem_type = {}  # starts empty; must hold the keys checked below
    conv_config = {
        'TensorAFormat': 'NCHW',
        'Filter': '3x5',
    }
    conv = Convolution(problem_type,
                       'ConvolutionBackwardWeights',
                       config=conv_config)
    log.debug(conv.printUsage(problem_type))

    # Compared in this order, so the first mismatch is the one reported.
    # NOTE(review): the SetConstStrideA == [[3, 1]] check is disabled for
    # this case; the reason is not recorded -- confirm before re-enabling.
    expected_fields = (
        ('NumIndicesC', 4),
        ('IndexAssignmentsA', [5, 0, 1, 2, 4]),
        ('IndexAssignmentsB', [5, 3, 4]),
        ('SetConstStrideB', []),
    )
    for field, expected in expected_fields:
        assert problem_type[field] == expected

    YamlBuilder.run_tensile_client(request, conv, problem_type,
                                   tensile_client_dir, tmp_path)
def test_nchw_backwardweights_defaults(request, tensile_client_dir, tmp_path):
    """Exercise ``ConvolutionBackwardWeights`` with NCHW layout and 14x14 spatial.

    An empty problem-type dict is handed to ``Convolution``; the test then
    checks selected entries of that dict and forwards everything to
    ``YamlBuilder.run_tensile_client``.

    Args:
        request: Forwarded unchanged to ``run_tensile_client`` (presumably
            the pytest ``request`` fixture -- confirm against conftest).
        tensile_client_dir: Forwarded unchanged to ``run_tensile_client``.
        tmp_path: Forwarded unchanged to ``run_tensile_client``.
    """
    problem_type = {}  # starts empty; must hold the keys checked below
    conv_config = {
        'TensorAFormat': 'NCHW',
        'Spatial': '14x14',
    }
    conv = Convolution(problem_type,
                       'ConvolutionBackwardWeights',
                       config=conv_config)
    log.debug(conv.printUsage(problem_type))

    # Compared in this order, so the first mismatch is the one reported.
    expected_fields = (
        ('NumIndicesC', 2),
        ('IndexAssignmentsA', [3, 0, 2]),
        ('IndexAssignmentsB', [3, 1, 2]),
        ('SetConstStrideA', [[3, 1]]),
        ('SetConstStrideB', []),
    )
    for field, expected in expected_fields:
        assert problem_type[field] == expected

    YamlBuilder.run_tensile_client(request, conv, problem_type,
                                   tensile_client_dir, tmp_path)
# Example no. 3
# 0
def test_yaml(request, tensile_client_dir, tmp_path):
    """Exercise ``ConvolutionForward`` with NCHW layout and run the client.

    An empty problem-type dict is handed to ``Convolution``; the test then
    checks selected entries of that dict and forwards everything to
    ``YamlBuilder.run_tensile_client``.

    Args:
        request: Forwarded unchanged to ``run_tensile_client`` (presumably
            the pytest ``request`` fixture -- confirm against conftest).
        tensile_client_dir: Forwarded unchanged to ``run_tensile_client``.
        tmp_path: Forwarded unchanged to ``run_tensile_client``.
    """
    problem_type = {}  # starts empty; must hold the keys checked below
    conv_config = {
        'TensorAFormat': 'NCHW',
    }
    conv = Convolution(problem_type,
                       'ConvolutionForward',
                       config=conv_config)
    log.debug(conv.printUsage(problem_type))

    # Compared in this order, so the first mismatch is the one reported.
    # UseInitialStrides is compared by equality with False, not by identity.
    expected_fields = (
        ('NumIndicesC', 3),
        ('IndexAssignmentsA', [0, 3, 2]),
        ('IndexAssignmentsB', [3, 1, 2]),
        ('SetConstStrideA', [[0, 1]]),
        ('SetConstStrideB', [[2, 0]]),
        ('UseInitialStrides', False),
    )
    for field, expected in expected_fields:
        assert problem_type[field] == expected

    YamlBuilder.run_tensile_client(request, conv, problem_type,
                                   tensile_client_dir, tmp_path)