Exemplo n.º 1
0
    def __init__(self, training_set, training_set_header, objective_field,
                 multi_label=False, labels=None, label_separator=None,
                 training_separator=None, multi_label_fields=None,
                 label_aggregates=None, objective=True):
        """Builds a generator from a csv file

           `training_set`: path to the training data file (a StringIO
                           object is wrapped in a UTF8Recoder instead)
           `training_set_header`: boolean, True means that headers are first
                                 row in the file
           `objective_field`: objective field column or field name
           `multi_label`: boolean, True means that field values are split
                          into several labels
           `labels`: labels to be used for the objective field. When None,
                     the labels found in the objective column are used.
           `label_separator`: string that separates the labels in a field.
                              Defaults to the csv delimiter.
           `training_separator`: single character used as column separator.
                                 Defaults to the csv delimiter.
           `multi_label_fields`: columns or names of the multi-label fields
           `label_aggregates`: stored as is (empty list when None)
           `objective`: boolean, True means that the objective field
                        information is computed and stored
        """
        self.training_set = training_set

        # In-memory contents need recoding; for real files the encoding is
        # only set when not running under Python 3.
        if training_set.__class__.__name__ == "StringIO":
            self.encode = None
            self.training_set = UTF8Recoder(training_set, SYSTEM_ENCODING)
        else:
            self.encode = None if PYTHON3 else FILE_ENCODING
        self.training_set_header = training_set_header
        self.training_reader = None
        self.multi_label = multi_label
        self.objective = objective
        if label_aggregates is None:
            label_aggregates = []
        self.label_aggregates = label_aggregates
        self.training_separator = (decode2(training_separator,
                                           encoding="string_escape")
                                   if training_separator is not None
                                   else get_csv_delimiter())
        if len(self.training_separator) > 1:
            # This validates the *training* separator, so say so.
            sys.exit("Only one character can be used as training data"
                     " separator.")
        # opening csv reader
        self.reset()
        self.label_separator = (decode2(label_separator,
                                        encoding="string_escape")
                                if label_separator is not None
                                else get_csv_delimiter())

        # When there is no header row, the first row is data: ask get_next
        # to reset after reading it.
        first_row = self.get_next(reset=not training_set_header)
        self.row_length = len(first_row)

        if training_set_header:
            self.headers = first_row
        else:
            # Synthetic names: field_0, field_1, ...
            self.headers = [("field_%s" % index) for index in
                            range(0, self.row_length)]

        self.multi_label_fields = sorted(self._get_columns(multi_label_fields))
        if objective:
            self.objective_column = self._get_columns([objective_field])[0]
            # The objective column is always handled as a multi-label field
            if self.objective_column not in self.multi_label_fields:
                self.multi_label_fields.append(self.objective_column)
            # Must be set before _get_labels, which may read self.labels
            self.labels = labels
        self.fields_labels = self._get_labels()
        if objective:
            if labels is None:
                self.labels = self.fields_labels[self.objective_column]
            self.objective_name = self.headers[self.objective_column]
Exemplo n.º 2
0
    def _get_field_labels(self, row, labels, field_column, separator):
        """Returns the list of labels in a multi-label field

           `row`: sequence of field values for the current row
           `labels`: container indexed by column that accumulates the labels
                     found so far; it is updated in place and returned
           `field_column`: index of the field in `row` and in `labels`
           `separator`: string used to split the field value into labels
                        (only used when `self.multi_label` is set)

           After the call, `labels[field_column]` is a sorted list with no
           duplicates.
        """
        field_value = row[field_column]
        if self.multi_label:
            # TODO: clean user given missing tokens
            # Build a filtered list rather than deleting by index while
            # looping: deleting shifted the remaining items, which raised
            # IndexError and skipped consecutive empty labels.
            new_labels = [
                label for label in
                (decode2(item).strip() for item in field_value.split(separator))
                if label != ""]
            if new_labels:
                if (self.objective and field_column == self.objective_column
                        and self.labels is not None):
                    # If user gave the subset of labels, use only those
                    new_labels = [label for label in self.labels
                                  if label in new_labels]
                labels[field_column].extend(new_labels)
        else:
            labels[field_column].append(field_value)
        labels[field_column] = sorted(set(labels[field_column]))
        return labels
Exemplo n.º 3
0
    def _get_field_labels(self, row, labels, field_column, separator):
        """Returns the list of labels in a multi-label field

           `row`: sequence of field values for the current row
           `labels`: container indexed by column that accumulates the labels
                     found so far; it is updated in place and returned
           `field_column`: index of the field in `row` and in `labels`
           `separator`: string used to split the field value into labels
                        (only used when `self.multi_label` is set)

           After the call, `labels[field_column]` is a sorted list with no
           duplicates.
        """
        field_value = row[field_column]
        if self.multi_label:
            stripped = [
                decode2(label).strip() for label in field_value.split(separator)
            ]
            # TODO: clean user given missing tokens
            # Drop empty labels with a filter: the previous index-based
            # deletion mutated the list while iterating over its original
            # length, raising IndexError as soon as one label was removed.
            new_labels = [label for label in stripped if label != '']
            if new_labels:
                if (self.objective and field_column == self.objective_column
                        and self.labels is not None):
                    # If user gave the subset of labels, use only those
                    new_labels = [
                        label for label in self.labels if label in new_labels
                    ]
                labels[field_column].extend(new_labels)
        else:
            labels[field_column].append(field_value)
        labels[field_column] = sorted(set(labels[field_column]))
        return labels
Exemplo n.º 4
0
class TstReader(object):
    """Retrieves csv info and builds a input data dict

    """
    def __init__(self,
                 test_set,
                 test_set_header,
                 fields,
                 objective_field,
                 test_separator=None):
        """Builds a generator from a csv file and the fields' model structure

           `test_set`: path to the test data file (a StringIO object is
                       wrapped in a UTF8Recoder instead)
           `test_set_header`: boolean, True means that headers are first row
                              in the file
           `fields`: Fields object with the expected fields structure.
           `objective_field`: field_id of the objective field. Any other
                              value is resolved through `fields.field_id`.
           `test_separator`: single character used as column separator.
                             Defaults to the csv delimiter.

           Exits the program when the objective field cannot be resolved,
           the separator has more than one character or the test set cannot
           be read. Raises Exception when none of the headers in the file
           matches the model fields.
        """
        self.test_set = test_set
        if test_set.__class__.__name__ == "StringIO":
            self.encode = None
            self.test_set = UTF8Recoder(test_set, SYSTEM_ENCODING)
        else:
            self.encode = None if PYTHON3 else FILE_ENCODING
        self.test_set_header = test_set_header
        self.fields = fields
        # The objective may be given by something other than its id
        if (objective_field is not None
                and objective_field not in fields.fields):
            try:
                objective_field = fields.field_id(objective_field)
            except ValueError as exc:
                sys.exit(exc)
        self.objective_field = objective_field
        if test_separator and not PYTHON3:
            test_separator = decode2(test_separator, encoding="string_escape")
        self.test_separator = (test_separator if test_separator is not None
                               else get_csv_delimiter())
        if len(self.test_separator) > 1:
            sys.exit("Only one character can be used as test data separator.")
        try:
            self.test_reader = UnicodeReader(
                self.test_set,
                delimiter=self.test_separator,
                lineterminator="\n").open_reader()
        except IOError:
            sys.exit("Error: cannot read test %s" % test_set)

        self.headers = None
        self.raw_headers = None
        self.exclude = []
        if test_set_header:
            # built-in next() works on both Python 2.6+ and Python 3
            self.headers = next(self.test_reader)
            # validate headers against model fields excluding objective_field,
            # that may be present or not
            if objective_field is not None:
                objective_field = fields.field_column_number(objective_field)
            try:
                fields_names = [
                    fields.fields[fields.field_id(i)]['name']
                    for i in sorted(fields.fields_by_column_number.keys())
                    if objective_field is None or i != objective_field
                ]
            except ValueError as exc:
                sys.exit(exc)
            # Keep the unfiltered headers before removing unknown columns
            self.raw_headers = self.headers[:]

            # Positions of the headers that are not model field names
            self.exclude = [
                i for i in range(len(self.headers))
                if self.headers[i] not in fields_names
            ]

            # Delete from the highest index down so the remaining positions
            # stay valid
            self.exclude.reverse()
            if self.exclude:
                if len(self.headers) > len(self.exclude):
                    for index in self.exclude:
                        del self.headers[index]
                else:
                    # NOTE(review): .encode() yields bytes on Python 3, so
                    # the exception message is a bytes object there.
                    raise Exception(
                        (u"No test field matches the model fields."
                         u"\nThe expected fields are:\n\n%s\n\n"
                         u"while "
                         u"the headers found in the test file are:"
                         u"\n\n%s\n\n"
                         u"Use --no-test-header flag if first li"
                         u"ne should not be interpreted as"
                         u" headers." % (",".join(fields_names), ",".join(
                             self.headers))).encode("utf-8"))
Exemplo n.º 5
0
def set_source_args(args,
                    name=None,
                    multi_label_data=None,
                    data_set_header=None,
                    fields=None):
    """Returns a source arguments dict

       `args`: object holding the user-given options as attributes
       `name`: name for the source. Defaults to `args.name`.
       `multi_label_data`: information stored in the source `user_metadata`
                           when `args.multi_label` is set
       `data_set_header`: value for the source parser `header` option.
                          Not set when None.
       `fields`: previous fields structure. Only when given, the fields
                 attributes, types, imported fields and json arguments are
                 applied.
    """

    if name is None:
        name = args.name
    source_args = set_basic_args(args, name)
    if args.project_id is not None:
        source_args.update({"project": args.project_id})
    # if header is set, use it
    if data_set_header is not None:
        source_args.update({"source_parser": {"header": data_set_header}})
    # If user has given an OS locale, try to add the locale used in bigml.com
    if args.user_locale is not None:
        source_locale = bigml_locale(args.user_locale)
        if source_locale is None:
            log_message("WARNING: %s locale equivalence not found."
                        " Using %s instead.\n" %
                        (args.user_locale, LOCALE_DEFAULT),
                        log_file=None,
                        console=True)
            source_locale = LOCALE_DEFAULT
        # Add to the existing parser options instead of replacing them:
        # resetting "source_parser" here discarded the header set above.
        source_args.setdefault('source_parser', {})
        source_args["source_parser"].update({'locale': source_locale})
    # If user has set a training separator, use it.
    if args.training_separator is not None:
        training_separator = decode2(args.training_separator,
                                     encoding="string_escape")
        # "source_parser" may not exist yet when neither header nor locale
        # were given; create it rather than failing with KeyError.
        source_args.setdefault('source_parser', {})
        source_args["source_parser"].update({'separator': training_separator})
    # If uploading a multi-label file, add the user_metadata info needed to
    # manage the multi-label fields
    if (hasattr(args, 'multi_label') and args.multi_label
            and multi_label_data is not None):
        source_args.update(
            {"user_metadata": {
                "multi_label_data": multi_label_data
            }})

    # to update fields attributes or types you must have a previous fields
    # structure (at update time)
    if fields:
        if args.field_attributes_:
            update_attributes(source_args, {"fields": args.field_attributes_},
                              by_column=True,
                              fields=fields)
        if args.types_:
            update_attributes(source_args, {"fields": args.types_},
                              by_column=True,
                              fields=fields)
        if args.import_fields:
            fields_struct = fields.new_fields_structure(args.import_fields)
            check_fields_struct(fields_struct, "source")
            update_attributes(source_args, fields_struct)
        if 'source' in args.json_args:
            update_json_args(source_args, args.json_args.get('source'), fields)
    return source_args