예제 #1
0
def get_cas_host_type(conn):
    '''
    Build a dotted string describing the CAS server host

    The result has the form ``<os>.<smp|mpp>``.  For an MPP server,
    ``.nohdfs`` is appended unless ``table.querycaslib`` for
    ``CASUSERHDFS`` returns a key containing ``CASUSERHDFS`` with a
    truthy value.

    Parameters
    ----------
    conn : CAS
        CAS connection object

    Raises
    ------
    ValueError
        If the reported OS family does not start with LIN, WIN or OSX

    Returns
    -------
    string

    '''
    with sw.option_context(print_messages=False):
        about = conn.about()

    os_family = about['About']['System']['OS Family']

    # Exactly one node is reported as SMP; everything else is MPP
    if about['server'].loc[0, 'nodes'] == 1:
        server_mode = 'smp'
    else:
        server_mode = 'mpp'

    # Map the OS family prefix onto the short name used in the result
    known_families = (('LIN', 'linux'), ('WIN', 'windows'), ('OSX', 'mac'))
    for prefix, os_name in known_families:
        if os_family.startswith(prefix):
            break
    else:
        raise ValueError('Unknown OS type: ' + os_family)

    # Check to see if the default HDFS caslib for the user exists
    caslib_info = conn.table.querycaslib(caslib='CASUSERHDFS')
    hdfs_found = False
    for key, value in caslib_info.items():
        if 'CASUSERHDFS' in key and value:
            hdfs_found = True

    pieces = [os_name, server_mode]
    if server_mode == 'mpp' and not hdfs_found:
        pieces.append('nohdfs')
    return '.'.join(pieces)
예제 #2
0
    def test_stdout(self):
        ''' Verify that output printed by a CASL program reaches stdout intact '''
        # CASL program that concatenates 'foo' 997 times and prints the result
        casl_program = "str = ''; do i = 1 to 997; str = str || 'foo'; end; print str;"
        self.s.loadactionset('sccasl')

        if swat.TKVersion() == 'vb025':
            self.skipTest("Stdout fix does not exist in this version")

        # Messages must be enabled for the printed text to be captured
        with swat.option_context(print_messages=True), captured_stdout() as stream:
            self.s.runcasl(casl_program)

        expected_text = 'foo' * 997 + '\n'
        self.assertEqual(stream.getvalue(), expected_text)
예제 #3
0
    def test_sasdataframe(self):
        ''' Verify column metadata set on a fetched frame survives an upload '''
        fetched = self.table.fetch(sastypes=False).Fetch

        # Customize the metadata of one column before uploading
        horsepower_info = fetched.colinfo['Horsepower']
        horsepower_info.label = 'How much power?'
        horsepower_info.format = 'INT'
        horsepower_info.width = 11

        handler = swat.datamsghandlers.PandasDataFrame(fetched)
        add_result = self.s.addtable(table='dtypes', **handler.args.addtable)
        uploaded = add_result.casTable

        # Index the column information by column name for the lookup below
        with swat.option_context('cas.dataset.index_name', 'Column'):
            column_info = uploaded.columninfo().ColumnInfo

        uploaded_hp = column_info.loc['Horsepower']

        checks = (
            ('Label', horsepower_info.label),
            ('Format', horsepower_info.format),
            ('FormattedLength', horsepower_info.width),
        )
        for attr_name, expected in checks:
            self.assertEqual(getattr(uploaded_hp, attr_name), expected)
예제 #4
0
def display_object_detections(conn,
                              table,
                              coord_type,
                              max_objects=10,
                              num_plot=10,
                              n_col=2,
                              fig_size=None):
    '''
    Plot images with drawn bounding boxes.

    A sample of the input table is passed to ``image.extractdetectedobjects``
    which writes a temporary CAS table; images are fetched from that table,
    plotted in a grid, and the temporary table is dropped before returning.

    Parameters
    ----------
    conn : CAS
        CAS connection object
    table : string or CASTable
        Specifies the object detection castable to be plotted.
    coord_type : string
        Specifies coordinate type of input table
    max_objects : int, optional
        Specifies the maximum number of bounding boxes to be plotted on an image.
        Default: 10
    num_plot : int, optional
        Specifies the maximum number of images to plot.
        Default: 10
    n_col : int, optional
        Specifies the number of column to plot.
        Default: 2
    fig_size : tuple, optional
        Specifies the size of figure; passed as ``figsize`` to ``plt.figure``.
        When omitted, a size is derived from the grid dimensions.

    '''
    conn.retrieve('loadactionset', _messagelevel='error', actionset='image')

    input_tbl_opts = input_table_check(table)
    input_table = conn.CASTable(**input_tbl_opts)
    img_num = input_table.shape[0]
    # never ask for more images than the table holds
    num_plot = num_plot if num_plot < img_num else img_num
    input_table = input_table.sample(num_plot)
    det_label_image_table = random_name('detLabelImageTable')

    # cap max_objects at the largest object count found in the sample
    num_max_obj = input_table['_nObjects_'].max()
    max_objects = max_objects if num_max_obj > max_objects else num_max_obj

    with sw.option_context(print_messages=False):
        res = conn.image.extractdetectedobjects(casout={
            'name': det_label_image_table,
            'replace': True
        },
                                                coordtype=coord_type,
                                                maxobjects=max_objects,
                                                table=input_table)
        if res.severity > 0:
            for msg in res.messages:
                print(msg)

    # From here on the temporary table may exist on the server, so make sure
    # it is dropped on every exit path (early return or exception).
    try:
        outtable = conn.CASTable(det_label_image_table)
        num_detection = len(outtable)
        if num_detection == 0:
            print(
                'Since there is no image that contains a bounding box, cannot display any image.'
            )
            return
        num_plot = num_plot if num_plot < num_detection else num_detection

        with sw.option_context(print_messages=False):
            prediction_plot = conn.image.fetchImages(
                imageTable={'name': det_label_image_table},
                to=num_plot,
                fetchImagesVars=['_image_', '_path_'])
            # report problems of the fetch itself, not of the earlier action
            if prediction_plot.severity > 0:
                for msg in prediction_plot.messages:
                    print(msg)

        if num_plot > n_col:
            # ceiling division: no empty trailing row when num_plot is a
            # multiple of n_col
            n_row = (num_plot + n_col - 1) // n_col
        else:
            n_row = 1
            n_col = num_plot

        # guard the default figure size computation against zero dimensions
        n_col_m = n_col
        if n_col_m < 1:
            n_col_m += 1

        n_row_m = n_row
        if n_row < 1:
            n_row_m += 1

        if fig_size is None:
            fig_size = (16, 16 // n_col_m * n_row_m)

        fig = plt.figure(figsize=fig_size)

        images = prediction_plot['Images']
        has_path = '_path_' in images.columns

        for i in range(num_plot):
            # subplot positions are 1-based
            fig.add_subplot(n_row, n_col, i + 1)
            plt.imshow(images['Image'][i])
            if has_path:
                plt.title(str(os.path.basename(images['_path_'].loc[i])))
            plt.xticks([])
            plt.yticks([])
        plt.show()
    finally:
        with sw.option_context(print_messages=False):
            conn.table.droptable(det_label_image_table)
예제 #5
0
def create_object_detection_table(conn,
                                  data_path,
                                  coord_type,
                                  output,
                                  local_path=None,
                                  image_size=416):
    r'''
    Create an object detection table

    Images under `data_path` are loaded and resized on the server, the XML
    annotation files are converted locally, the resulting txt files are
    loaded as CAS tables, reshaped into one wide row per image, and finally
    joined with the image table into `output`.

    Parameters
    ----------
    conn : session
        CAS connection object
    data_path : string
        Specifies a location where annotation files and image files are stored.
        Annotation files should be XML file based on Pascal VOC format
        Notice that the path should be accessible by CAS server.
    coord_type : string
        Specifies the type of coordinate to convert into.
        'yolo' specifies x, y, width and height, x, y is the center
        location of the object in the grid cell. x, y, are between 0
        and 1 and are relative to that grid cell. x, y = 0,0 corresponds
        to the top left pixel of the grid cell.
        'coco' specifies xmin, ymin, xmax, ymax that are borders of a
        bounding boxes.
        The values are relative to parameter image_size.
        Valid Values: yolo, coco
    output : string
        Specifies the name of the object detection table.
    local_path : string, optional
        Local_path and data_path point to the same location.
        The parameter local_path will be optional (default=None) if the
        Python client has the same OS as CAS server. Otherwise, the path that
        depends on the Python client OS needs to be specified.
        For example:
        Windows client with linux CAS server:
        data_path=/path/to/data/path
        local_path=\\path\to\data\path
        Linux clients with Windows CAS Server:
        data_path=\\path\to\data\path
        local_path=/path/to/data/path
    image_size : integer, optional
        Specifies the size of images to resize.
        Default: 416

    Raises
    ------
    ValueError
        If local_path is missing while client and server OS types differ,
        if coord_type is not supported, or if no annotation file (or no
        converted txt file) is found under data_path.
    DLPyError
        If adding a caslib for data_path fails.

    Returns
    -------
    A list of variables that are the labels of the object detection table

    '''
    with sw.option_context(print_messages=False):
        server_type = get_cas_host_type(conn).lower()
    local_os_type = platform.system()
    # get_cas_host_type reports OSX hosts with a 'mac' prefix; 'osx' is still
    # accepted so any other producer of that spelling keeps working.
    unix_type = server_type.startswith(('lin', 'osx', 'mac'))
    # check if local and server are same type of OS
    # in different os
    if (unix_type and local_os_type.startswith('Win')
        ) or not (unix_type or local_os_type.startswith('Win')):
        if local_path is None:
            raise ValueError(
                'local_path must be specified when your server is on {} OS and local '
                'python is on {} OS'.format(
                    server_type.split('.')[0].capitalize(), local_os_type))
    else:
        local_path = data_path

    conn.retrieve('loadactionset', _messagelevel='error', actionset='image')
    conn.retrieve('loadactionset',
                  _messagelevel='error',
                  actionset='deepLearn')
    conn.retrieve('loadactionset',
                  _messagelevel='error',
                  actionset='transpose')

    if coord_type.lower() not in ['yolo', 'coco']:
        raise ValueError('coord_type, {}, is not supported'.format(coord_type))

    # label variables, _ : category;
    yolo_var_name = ['_', '_x', '_y', '_width', '_height']
    coco_var_name = ['_', '_xmin', '_ymin', '_xmax', '_ymax']
    if coord_type.lower() == 'yolo':
        var_name = yolo_var_name
    else:
        var_name = coco_var_name

    det_img_table = random_name('DET_IMG')

    with sw.option_context(print_messages=False):
        res = conn.image.loadImages(path=data_path,
                                    recurse=False,
                                    labelLevels=-1,
                                    casout={
                                        'name': det_img_table,
                                        'replace': True
                                    })
        if res.severity > 0:
            for msg in res.messages:
                if not msg.startswith('WARNING'):
                    print(msg)

        res = conn.image.processImages(table={'name': det_img_table},
                                       imagefunctions=[{
                                           'options': {
                                               'functiontype': 'RESIZE',
                                               'height': image_size,
                                               'width': image_size
                                           }
                                       }],
                                       casout={
                                           'name': det_img_table,
                                           'replace': True
                                       })

        if res.severity > 0:
            for msg in res.messages:
                print(msg)
        else:
            print("NOTE: Images are processed.")

    with sw.option_context(print_messages=False):
        caslib = find_caslib(conn, data_path)
        if caslib is None:
            caslib = random_name('Caslib', 6)
            rt = conn.retrieve('addcaslib',
                               _messagelevel='error',
                               name=caslib,
                               path=data_path,
                               activeonadd=False,
                               subdirectories=True,
                               datasource={'srctype': 'path'})
            if rt.severity > 1:
                raise DLPyError(
                    'something went wrong while adding the caslib for the specified path.'
                )

    # find all of annotation files under the directory
    label_files = conn.fileinfo(caslib=caslib,
                                allfiles=True).FileInfo['Name'].values
    label_files = [
        x for x in label_files if x.endswith('.xml') or x.endswith('.json')
    ]
    if len(label_files) == 0:
        raise ValueError('Can not find any annotation file under data_path')

    # parse xml files and create txt files; json files are listed above but
    # are not converted here
    cwd = os.getcwd()
    os.chdir(local_path)
    try:
        for filename in label_files:
            if filename.endswith('.xml'):
                convert_xml_annotation(filename, coord_type, image_size)
    finally:
        # always restore the caller's working directory, even on failure
        os.chdir(cwd)
    label_tbl_name = random_name('obj_det')
    # load all of txt files into cas server
    label_files = conn.fileinfo(caslib=caslib,
                                allfiles=True).FileInfo['Name'].values
    label_files = [x for x in label_files if x.endswith('.txt')]
    if len(label_files) == 0:
        # without this check the table names below would be derived from a
        # stale loop index and refer to tables that were never created
        raise ValueError(
            'Can not find any converted annotation (.txt) file under data_path'
        )
    input_tbl_name = []
    with sw.option_context(print_messages=False):
        for idx, filename in enumerate(label_files):
            tbl_name = '{}_{}'.format(label_tbl_name, idx)
            conn.retrieve('loadtable',
                          caslib=caslib,
                          path=filename,
                          casout=dict(name=tbl_name, replace=True),
                          importOptions=dict(fileType='csv',
                                             getNames=False,
                                             varChars=True,
                                             delimiter=','))
            # tag every row with the file name minus its '.txt' extension
            conn.retrieve('partition',
                          table=dict(name=tbl_name,
                                     compvars=['idjoin'],
                                     comppgm='idjoin="{}";'.format(
                                         filename[:-4])),
                          casout=dict(name=tbl_name, replace=True))
            input_tbl_name.append(tbl_name)

    string_input_tbl_name = ' '.join(input_tbl_name)
    # concatenate all of annotation table together
    fmt_code = '''
                data {0}; 
                set {1}; 
                run;
                '''.format(output, string_input_tbl_name)
    conn.runcode(code=fmt_code, _messagelevel='error')
    # fourth field (position 3) of the first ColumnInfo row, used below as
    # the width of the class-column format -- NOTE(review): confirm which
    # ColumnInfo field this is for the server release in use
    cls_col_format_length = conn.columninfo(output).ColumnInfo.loc[0].iloc[3]

    conn.altertable(name=output,
                    columns=[
                        dict(name='Var1', rename=var_name[0]),
                        dict(name='Var2', rename=var_name[1]),
                        dict(name='Var3', rename=var_name[2]),
                        dict(name='Var4', rename=var_name[3]),
                        dict(name='Var5', rename=var_name[4])
                    ])
    # add sequence id that is used to build column name in transpose process
    sas_code = '''
               data {0};
                  set {0} ;
                  by idjoin;
                  seq_id+1;
                  if first.idjoin then seq_id=0;
                  output;
               run;
               '''.format(output)
    conn.runcode(code=sas_code, _messagelevel='error')
    # convert long table to wide table
    with sw.option_context(print_messages=False):
        for var in var_name:
            conn.transpose(prefix='_Object',
                           suffix=var,
                           id='seq_id',
                           transpose=[var],
                           table=dict(name=output, groupby='idjoin'),
                           casout=dict(name='output{}'.format(var), replace=1))
            conn.altertable(name='output{}'.format(var),
                            columns=[{
                                'name': '_NAME_',
                                'drop': True
                            }])
    # dljoin the five columns
    conn.deeplearn.dljoin(table='output{}'.format(var_name[0]),
                          id='idjoin',
                          annotatedtable='output{}'.format(var_name[1]),
                          casout=dict(name=output, replace=True),
                          _messagelevel='error')

    for var in var_name[2:]:
        conn.deepLearn.dljoin(table=output,
                              id='idjoin',
                              annotatedtable='output{}'.format(var),
                              casout=dict(name=output, replace=True))
    # get number of objects in each image
    code = '''
            data {0};
            set {0};
            array _all _numeric_;
            _nObjects_ = (dim(_all)-cmiss(of _all[*]))/4;
            run;
            '''.format(output)
    conn.runcode(code=code, _messagelevel='error')
    # five label columns per object, derived from the largest column ID
    max_instance = int((max(conn.columninfo(output).ColumnInfo['ID']) - 1) / 5)
    var_order = ['idjoin', '_nObjects_']
    for i in range(max_instance):
        for var in var_name:
            var_order.append('_Object' + str(i) + var)
    # unify the formattedlength of class columns
    format_ = '${}.'.format(cls_col_format_length)
    conn.altertable(name=output,
                    columns=[{
                        'name': '_Object{}_'.format(i),
                        'format': format_
                    } for i in range(max_instance)])
    # parse and create dljoin id column
    label_col_info = conn.columninfo(output).ColumnInfo
    filename_col_length = label_col_info.loc[
        label_col_info['Column'] == 'idjoin', ['FormattedLength']].values[0][0]

    image_sas_code = "length idjoin $ {0};idjoin = inputc(scan(_path_,{1},'./'),'{0}.');".format(
        filename_col_length,
        len(data_path.split('\\')) - 3)
    img_tbl = conn.CASTable(det_img_table,
                            computedvars=['idjoin'],
                            computedvarsprogram=image_sas_code,
                            vars=[{
                                'name': 'idjoin'
                            }, {
                                'name': '_image_'
                            }])
    # join the image table and label table together
    conn.deepLearn.dljoin(table=img_tbl,
                          annotation=output,
                          id='idjoin',
                          casout={
                              'name': output,
                              'replace': True,
                              'replication': 0
                          })

    # drop every intermediate table created above
    with sw.option_context(print_messages=False):
        for name in input_tbl_name:
            conn.table.droptable(name)
        for var in var_name:
            conn.table.droptable('output{}'.format(var))
        conn.table.droptable(det_img_table)

    print("NOTE: Object detection table is successfully created.")
    return var_order[2:]