def test_c_size(self):
    """A 42-element 'c' dimension over 2 processes yields local shapes 21 and 21."""
    global_dim = dict(dist_type='c',
                      size=42,
                      proc_grid_size=2,
                      proc_grid_rank=0,
                      start=0)
    distribution = Distribution.from_global_dim_data(self.context,
                                                     (global_dim,))
    per_rank = distribution.get_dim_data_per_rank()
    local_shapes = metadata_utils.shapes_from_dim_data_per_rank(per_rank)
    self.assertEqual(local_shapes, [(21,), (21,)])
# Example #2
0
def create_distribution_plot_and_documentation(context, params):
    """Create an array distribution plot and the related .rst documentation.

    Builds a distributed array described by ``params`` and fills it with
    0.0, 1.0, 2.0, ... in row-major order. For 1-D and 2-D arrays it then
    plots the global and local layouts. Finally it prints the
    documentation via ``print_array_documentation``.

    Parameters
    ----------
    context : Context
        Used to build the distribution and array and to run a function
        on the array's targets.
    params : dict
        Required keys:
          ``'title'``  -- distribution title used in plot and doc text.
          ``'labels'`` -- distribution labels; all are joined into the
                          description, and entries 0 and 1 label the
                          plot axes.
          ``'shape'``  -- global array shape.
        Optional keys:
          ``'grid_shape'`` -- process grid shape (used only with ``dist``).
          ``'text'``       -- extra paragraph appended to the description.
          ``'dist'`` / ``'dimdata'`` -- one of them is required; ``dist``
                          wins if both are given.
          ``'filename'``   -- global plot filename; the local plot uses the
                          same name with ``_local`` inserted before the
                          extension.
          ``'skip'``       -- if true, return immediately.

    Raises
    ------
    ValueError
        If neither ``dist`` nor ``dimdata`` is given, or if ``shape`` is
        0-dimensional.
    """
    import itertools

    def shape_text(shape):
        """Return the shape as text, e.g. '4 X 8'; 1-D shapes show as '1 X N'."""
        # Always want to display at least N X M.
        if len(shape) == 1:
            shape = (1, shape[0])
        return ' X '.join('%d' % (extent,) for extent in shape)

    title = params['title']
    labels = params['labels']
    shape = params['shape']
    grid_shape = params.get('grid_shape', None)
    text = params.get('text', None)
    dist = params.get('dist', None)
    dimdata = params.get('dimdata', None)
    filename = params.get('filename', None)
    skip = params.get('skip', False)

    if skip:
        return

    # Create array, either from dist or dimdata.
    if dist is not None:
        distribution = Distribution(context, shape, dist=dist,
                                    grid_shape=grid_shape)
    elif dimdata is not None:
        distribution = Distribution.from_global_dim_data(context, dimdata)
    else:
        raise ValueError('Must provide either dist or dimdata.')
    array = context.empty(distribution)

    # Fill the array with 0.0, 1.0, 2.0, ... in row-major order, so each
    # element's value identifies its global position. Element-wise
    # assignment is slow, but acceptable here.
    ndim = len(shape)
    if ndim == 0:
        raise ValueError('Array must be at least 1 dimensional.')
    indices = itertools.product(*[range(extent) for extent in shape])
    for value, index in enumerate(indices):
        # Index 1-D arrays with a plain integer rather than a 1-tuple.
        array[index[0] if ndim == 1 else index] = float(value)

    # Plot title and axis labels.
    plot_title = title + ' ' + shape_text(shape) + '\n'
    if ndim == 1:
        # Add more space for the cramped 1-D plot.
        plot_title += '\n'
    xlabel = 'Axis 1, %s' % (labels[1],)
    ylabel = 'Axis 0, %s' % (labels[0],)

    # Documentation title and text description.
    doc_title = title
    dist_text = ' X '.join("'%s'" % (label,) for label in labels)
    # Choose 'a' vs 'an' from the title's first letter (empty title -> 'a').
    article = 'an' if title and title[0] in 'aeiouAEIOU' else 'a'
    doc_text = ('A (%s) array, with %s %s (%s) distribution over a (%s) '
                'process grid.' % (shape_text(shape), article, title,
                                   dist_text, shape_text(array.grid_shape)))
    if text is not None:
        doc_text = doc_text + "\n\n" + text

    # Filenames for array plots.
    global_plot_filename = filename
    local_plot_filename = None
    if global_plot_filename is not None:
        root, ext = os.path.splitext(global_plot_filename)
        local_plot_filename = root + '_local' + ext

    if ndim in (1, 2):
        # Process grid coordinates are needed only for the local array
        # plots, so fetch them only when plotting. This duplicates work in
        # print_array_documentation().
        def _get_process_coords(local_arr):
            return local_arr.cart_coords

        process_coords = context.apply(_get_process_coords, (array.key,),
                                       targets=array.targets)
        plotting.plot_array_distribution(
            array,
            process_coords,
            title=plot_title,
            xlabel=xlabel,
            ylabel=ylabel,
            legend=True,
            global_plot_filename=global_plot_filename,
            local_plot_filename=local_plot_filename)
    else:
        # Not plottable, avoid writing links to missing plots.
        global_plot_filename = None
        local_plot_filename = None

    # Print documentation.
    print_array_documentation(context,
                              array,
                              title=doc_title,
                              text=doc_text,
                              global_plot_filename=global_plot_filename,
                              local_plot_filename=local_plot_filename)
def create_distribution_plot_and_documentation(context, params):
    """Create an array distribution plot and the related .rst documentation.

    Builds a distributed array described by ``params`` and fills it with
    0.0, 1.0, 2.0, ... in row-major order. For 1-D and 2-D arrays it then
    plots the global and local layouts. Finally it prints the
    documentation via ``print_array_documentation``.

    Parameters
    ----------
    context : Context
        Used to build the distribution and array and to run a function
        on the array's targets.
    params : dict
        Required keys:
          ``'title'``  -- distribution title used in plot and doc text.
          ``'labels'`` -- distribution labels; all are joined into the
                          description, and entries 0 and 1 label the
                          plot axes.
          ``'shape'``  -- global array shape.
        Optional keys:
          ``'grid_shape'`` -- process grid shape (used only with ``dist``).
          ``'text'``       -- extra paragraph appended to the description.
          ``'dist'`` / ``'dimdata'`` -- one of them is required; ``dist``
                          wins if both are given.
          ``'filename'``   -- global plot filename; the local plot uses the
                          same name with ``_local`` inserted before the
                          extension.
          ``'skip'``       -- if true, return immediately.

    Raises
    ------
    ValueError
        If neither ``dist`` nor ``dimdata`` is given, or if ``shape`` is
        0-dimensional.
    """
    import itertools

    def shape_text(shape):
        """Return the shape as text, e.g. '4 X 8'; 1-D shapes show as '1 X N'."""
        # Always want to display at least N X M.
        if len(shape) == 1:
            shape = (1, shape[0])
        return ' X '.join('%d' % (extent,) for extent in shape)

    title = params['title']
    labels = params['labels']
    shape = params['shape']
    grid_shape = params.get('grid_shape', None)
    text = params.get('text', None)
    dist = params.get('dist', None)
    dimdata = params.get('dimdata', None)
    filename = params.get('filename', None)
    skip = params.get('skip', False)

    if skip:
        return

    # Create array, either from dist or dimdata.
    if dist is not None:
        distribution = Distribution(context, shape, dist=dist,
                                    grid_shape=grid_shape)
    elif dimdata is not None:
        distribution = Distribution.from_global_dim_data(context, dimdata)
    else:
        raise ValueError('Must provide either dist or dimdata.')
    array = context.empty(distribution)

    # Fill the array with 0.0, 1.0, 2.0, ... in row-major order, so each
    # element's value identifies its global position. Element-wise
    # assignment is slow, but acceptable here.
    ndim = len(shape)
    if ndim == 0:
        raise ValueError('Array must be at least 1 dimensional.')
    indices = itertools.product(*[range(extent) for extent in shape])
    for value, index in enumerate(indices):
        # Index 1-D arrays with a plain integer rather than a 1-tuple.
        array[index[0] if ndim == 1 else index] = float(value)

    # Plot title and axis labels.
    plot_title = title + ' ' + shape_text(shape) + '\n'
    if ndim == 1:
        # Add more space for the cramped 1-D plot.
        plot_title += '\n'
    xlabel = 'Axis 1, %s' % (labels[1],)
    ylabel = 'Axis 0, %s' % (labels[0],)

    # Documentation title and text description.
    doc_title = title
    dist_text = ' X '.join("'%s'" % (label,) for label in labels)
    # Choose 'a' vs 'an' from the title's first letter (empty title -> 'a').
    article = 'an' if title and title[0] in 'aeiouAEIOU' else 'a'
    doc_text = ('A (%s) array, with %s %s (%s) distribution over a (%s) '
                'process grid.' % (shape_text(shape), article, title,
                                   dist_text, shape_text(array.grid_shape)))
    if text is not None:
        doc_text = doc_text + "\n\n" + text

    # Filenames for array plots.
    global_plot_filename = filename
    local_plot_filename = None
    if global_plot_filename is not None:
        root, ext = os.path.splitext(global_plot_filename)
        local_plot_filename = root + '_local' + ext

    if ndim in (1, 2):
        # Process grid coordinates are needed only for the local array
        # plots, so fetch them only when plotting. This duplicates work in
        # print_array_documentation().
        def _get_process_coords(local_arr):
            return local_arr.cart_coords

        process_coords = context.apply(_get_process_coords, (array.key,),
                                       targets=array.targets)
        plotting.plot_array_distribution(
            array,
            process_coords,
            title=plot_title,
            xlabel=xlabel,
            ylabel=ylabel,
            legend=True,
            global_plot_filename=global_plot_filename,
            local_plot_filename=local_plot_filename)
    else:
        # Not plottable, avoid writing links to missing plots.
        global_plot_filename = None
        local_plot_filename = None

    # Print documentation.
    print_array_documentation(
        context,
        array,
        title=doc_title,
        text=doc_text,
        global_plot_filename=global_plot_filename,
        local_plot_filename=local_plot_filename)