Пример #1
0
    def save(self, filename=None):
        """Save the list content to a file.

        Parameters
        ----------
        filename : str, optional
            File path. When given, it replaces ``self.filename`` and the file
            format is re-detected and validated.
            Default value filename given to class constructor

        Raises
        ------
        IOError:
            Filename is empty or file has unknown file format

        Returns
        -------
        self

        """

        if filename:
            self.filename = filename
            self.detect_file_format()
            self.validate_format()

        if self.filename is None or self.filename == '':
            message = '{name}: Filename is empty [{filename}]'.format(
                name=self.__class__.__name__, filename=self.filename)

            # Not inside an exception handler: error() instead of exception(),
            # which would append an empty "NoneType: None" traceback.
            self.logger.error(message)
            raise IOError(message)

        try:
            from dcase_util.files import Serializer
            if self.format == FileFormat.TXT:
                # One item per line, each converted with str()
                with open(self.filename, "w") as text_file:
                    for line in self:
                        text_file.write(str(line) + '\n')

            elif self.format == FileFormat.CPICKLE:
                Serializer.save_cpickle(filename=self.filename, data=self)

            else:
                message = '{name}: Unknown format [{format}]'.format(
                    name=self.__class__.__name__, format=self.filename)
                self.logger.error(message)
                raise IOError(message)

        except KeyboardInterrupt:
            # Delete the file, since most likely it was not saved fully. The
            # interrupt can arrive before the file exists; removing a missing
            # file would raise and hide the KeyboardInterrupt.
            if os.path.isfile(self.filename):
                os.remove(self.filename)
            raise

        # Check if after save function is defined, call if found
        if hasattr(self, '_after_save'):
            self._after_save()

        return self
Пример #2
0
    def load(self, filename=None, headers=None):
        """Load list content from a file, replacing the current content.

        Parameters
        ----------
        filename : str, optional
            File path. When given, it replaces ``self.filename`` and the file
            format is re-detected and validated.
            Default value filename given to class constructor

        headers : list of str, optional
            List of column names.
            NOTE(review): not used by this method — confirm whether it is kept
            only for interface compatibility.

        Raises
        ------
        IOError:
            File does not exists or has unknown file format

        Returns
        -------
        self

        """

        if filename:
            self.filename = filename
            self.detect_file_format()
            self.validate_format()

        if not self.exists():
            message = '{name}: File does not exists [{file}]'.format(
                name=self.__class__.__name__, file=self.filename)
            # Not inside an exception handler: error() instead of exception()
            self.logger.error(message)
            raise IOError(message)

        from dcase_util.files import Serializer

        if self.format == FileFormat.TXT:
            with open(self.filename, 'r') as f:
                # One item per line, with the line break characters removed
                lines = [
                    line.replace('\r\n', '').replace('\n', '') for line in f
                ]
            list.__init__(self, lines)

        elif self.format == FileFormat.CPICKLE:
            list.__init__(self,
                          Serializer.load_cpickle(filename=self.filename))

        else:
            message = '{name}: Unknown format [{format}]'.format(
                name=self.__class__.__name__, format=self.filename)
            self.logger.error(message)
            raise IOError(message)

        # Check if after load function is defined, call if found
        if hasattr(self, '_after_load'):
            self._after_load()

        return self
Пример #3
0
    def load(self, filename=None):
        """Load object attributes from a file.

        The loaded content is merged into the instance ``__dict__``.

        Parameters
        ----------
        filename : str, optional
            File path. When given, it replaces ``self.filename`` and the file
            format is re-detected and validated.
            Default value filename given to class constructor

        Raises
        ------
        ImportError:
            Error if file format specific module cannot be imported

        IOError:
            File does not exists or has unknown file format

        Returns
        -------
        self

        """

        if filename:
            self.filename = filename
            self.detect_file_format()
            self.validate_format()

        if not self.exists():
            message = '{name}: File does not exists [{file}]'.format(
                name=self.__class__.__name__, file=self.filename)
            # Not inside an exception handler: error() instead of exception()
            self.logger.error(message)
            raise IOError(message)

        from dcase_util.files import Serializer
        if self.format == FileFormat.CPICKLE:
            # NOTE(review): cpickle loading presumably unpickles the file;
            # never load files from untrusted sources here.
            self.__dict__.update(
                Serializer.load_cpickle(filename=self.filename))

        else:
            message = '{name}: Unknown format [{format}]'.format(
                name=self.__class__.__name__, format=self.filename)
            self.logger.error(message)
            raise IOError(message)

        # Check if after load function is defined, call if found
        if hasattr(self, '_after_load'):
            self._after_load()

        return self
Пример #4
0
    def save(self, filename=None):
        """Save dictionary content to a file.

        The saved data is ``dict(self)`` updated with the object state from
        ``__getstate__`` (when present). It is then optionally transformed by
        ``_before_save`` (when defined).

        Parameters
        ----------
        filename : str, optional
            File path. When given, it replaces ``self.filename`` and the file
            format is re-detected and validated.
            Default value filename given to class constructor

        Raises
        ------
        ImportError:
            Error if file format specific module cannot be imported

        IOError:
            Filename is empty or file has unknown file format

        Returns
        -------
        self

        """

        if filename:
            self.filename = filename
            self.detect_file_format()
            self.validate_format()

        if self.filename is None or self.filename == '':
            message = '{name}: Filename is empty [{filename}]'.format(
                name=self.__class__.__name__, filename=self.filename)

            # Not inside an exception handler: error() instead of exception()
            self.logger.error(message)
            raise IOError(message)

        try:
            from dcase_util.files import Serializer
            data = dict(self)
            if hasattr(self, '__getstate__'):
                state = self.__getstate__()
                # Since Python 3.11 every object has a default __getstate__,
                # which returns None when the instance __dict__ is empty;
                # dict(None) would raise TypeError, so empty states are skipped.
                if state:
                    data.update(dict(state))

            # Check if before save function is defined, call if found
            if hasattr(self, '_before_save'):
                data = self._before_save(data)

            if self.format == FileFormat.YAML:
                Serializer.save_yaml(filename=self.filename,
                                     data=self.get_dump_content(data=data))

            elif self.format == FileFormat.CPICKLE:
                Serializer.save_cpickle(filename=self.filename, data=data)

            elif self.format == FileFormat.MARSHAL:
                Serializer.save_marshal(filename=self.filename, data=data)

            elif self.format == FileFormat.MSGPACK:
                Serializer.save_msgpack(filename=self.filename, data=data)

            elif self.format == FileFormat.JSON:
                Serializer.save_json(filename=self.filename, data=data)

            elif self.format == FileFormat.TXT:
                # Values are written as-is, without adding line breaks
                with open(self.filename, "w") as text_file:
                    for line_id in self:
                        text_file.write(self[line_id])

            else:
                message = '{name}: Unknown format [{format}]'.format(
                    name=self.__class__.__name__, format=self.filename)
                self.logger.error(message)
                raise IOError(message)

        except KeyboardInterrupt:
            # Delete the file, since most likely it was not saved fully. The
            # interrupt can arrive before the file exists; removing a missing
            # file would raise and hide the KeyboardInterrupt.
            if os.path.isfile(self.filename):
                os.remove(self.filename)
            raise

        # Check if after save function is defined, call if found
        if hasattr(self, '_after_save'):
            self._after_save()

        return self
Пример #5
0
    def load(self, filename=None):
        """Load dictionary content from a file, replacing the current content.

        Parameters
        ----------
        filename : str, optional
            File path. When given, it replaces ``self.filename`` and the file
            format is re-detected and validated.
            Default value filename given to class constructor

        Raises
        ------
        ImportError:
            Error if file format specific module cannot be imported

        IOError:
            File does not exists or has unknown file format

        Returns
        -------
        self

        """

        if filename:
            self.filename = filename
            self.detect_file_format()
            self.validate_format()

        if not self.exists():
            message = '{name}: File does not exists [{file}]'.format(
                name=self.__class__.__name__, file=self.filename)
            # Not inside an exception handler: error() instead of exception()
            self.logger.error(message)
            raise IOError(message)

        from dcase_util.files import Serializer

        if self.format == FileFormat.YAML or self.format == FileFormat.META:
            data = Serializer.load_yaml(filename=self.filename)

        elif self.format == FileFormat.CPICKLE:
            data = Serializer.load_cpickle(filename=self.filename)

        elif self.format == FileFormat.MARSHAL:
            data = Serializer.load_marshal(filename=self.filename)

        elif self.format == FileFormat.MSGPACK:
            data = Serializer.load_msgpack(filename=self.filename)

        elif self.format == FileFormat.JSON:
            data = Serializer.load_json(filename=self.filename)

        elif self.format == FileFormat.TXT:
            with open(self.filename, 'r') as f:
                # Line index -> line text (line breaks are kept)
                data = dict(enumerate(f.readlines()))

        elif self.format == FileFormat.CSV:
            data = {}
            delimiter = self.delimiter()
            with open(self.filename, 'r') as f:
                csv_reader = csv.reader(f, delimiter=delimiter)
                # Only two-column rows are used: key, value
                for row in csv_reader:
                    if len(row) == 2:
                        data[row[0]] = row[1]

        else:
            message = '{name}: Unknown format [{format}]'.format(
                name=self.__class__.__name__, format=self.filename)
            self.logger.error(message)
            raise IOError(message)

        # Replace the content only after the file was read successfully, so an
        # unsupported format or a failing reader leaves the container intact.
        dict.clear(self)
        dict.update(self, data)

        # Check if after load function is defined, call if found
        if hasattr(self, '_after_load'):
            self._after_load()

        return self
Пример #6
0
    def save(self,
             filename=None,
             fields=None,
             csv_header=True,
             file_format=None,
             delimiter=','):
        """Save file

        Parameters
        ----------
        filename : str, optional
            File path
            Default value filename given to class constructor

        fields : list of str
            Fields in correct order, if none given all field in alphabetical order will be outputted.
            Used only for CSV formatted files.

        csv_header : bool
            In case of CSV formatted file, first line will contain field names. Names are taken from fields parameter.
            Default value True

        file_format : FileFormat, optional
            Forced file format, use this when there is a miss-match between file extension and file format.

        delimiter : str
            Delimiter to be used when saving CSV data
            Default value ','

        Raises
        ------
        IOError:
            Filename is empty or file has unknown file format

        Returns
        -------
        self

        """
        import sys

        if filename:
            self.filename = filename
            if not file_format:
                self.detect_file_format()
                self.validate_format()

        if file_format and FileFormat.validate_label(label=file_format):
            self.format = file_format

        if self.filename is None or self.filename == '':
            message = '{name}: Filename is empty [{filename}]'.format(
                name=self.__class__.__name__, filename=self.filename)

            # Not inside an exception handler: error() instead of exception()
            self.logger.error(message)
            raise IOError(message)

        try:
            from dcase_util.files import Serializer

            if self.format == FileFormat.YAML:
                data = [
                    self.get_dump_content(data=item)
                    for item in copy.deepcopy(list(self))
                ]
                # Written once after all items are converted (also writes an
                # empty list instead of skipping the file).
                Serializer.save_yaml(filename=self.filename, data=data)

            elif self.format == FileFormat.CSV:
                if fields is None:
                    fields = set()
                    for item in self:
                        fields.update(list(item.keys()))
                    fields = sorted(list(fields))

                # Make sure writing is using correct line endings to avoid
                # extra empty lines (csv module needs newline='' on Python 3)
                if sys.version_info[0] == 2:
                    csv_file = open(self.filename, 'wb')
                else:
                    csv_file = open(self.filename, 'w', newline='')

                with csv_file:
                    csv_writer = csv.writer(csv_file, delimiter=delimiter)
                    if csv_header:
                        csv_writer.writerow(fields)

                    for item in self:
                        csv_writer.writerow([item[field] for field in fields])

            elif self.format == FileFormat.CPICKLE:
                Serializer.save_cpickle(filename=self.filename, data=self)

            else:
                message = '{name}: Unknown format [{format}]'.format(
                    name=self.__class__.__name__, format=self.filename)
                self.logger.error(message)
                raise IOError(message)

        except KeyboardInterrupt:
            # Delete the file, since most likely it was not saved fully. The
            # interrupt can arrive before the file exists; removing a missing
            # file would raise and hide the KeyboardInterrupt.
            if os.path.isfile(self.filename):
                os.remove(self.filename)
            raise

        # Check if after save function is defined, call if found
        if hasattr(self, '_after_save'):
            self._after_save()

        return self
Пример #7
0
    def save(self,
             filename=None,
             fields=None,
             csv_header=True,
             file_format=None,
             delimiter='\t',
             **kwargs):
        """Save content to csv file

        Parameters
        ----------
        filename : str
            Filename. If none given, one given for class constructor is used.
            Default value None

        fields : list of str
            Fields in correct order, if none given all field in alphabetical order will be outputted.
            Used only for CSV formatted files.
            Default value None

        csv_header : bool
            In case of CSV formatted file, first line will contain field names. Names are taken from fields parameter.
            Default value True

        file_format : FileFormat, optional
            Forced file format, use this when there is a miss-match between file extension and file format.
            Default value None

        delimiter : str
            Delimiter to be used when saving data.
            Default value '\t'

        **kwargs
            Accepted but not used by this method.

        Raises
        ------
        IOError:
            File has unknown file format

        Returns
        -------
        self

        """

        if filename:
            self.filename = filename
            if not file_format:
                self.detect_file_format()
                self.validate_format()

        if file_format and FileFormat.validate_label(label=file_format):
            self.format = file_format

        if self.format in [FileFormat.TXT]:
            # Make sure writing is using correct line endings to avoid extra empty lines
            if sys.version_info[0] == 2:
                f = open(self.filename, 'wbt')
            else:
                f = open(self.filename, 'wt', newline='')

            with f:
                writer = csv.writer(f, delimiter=delimiter)
                for item in self:
                    writer.writerow(item.get_list())

        elif self.format == FileFormat.CSV:
            if fields is None:
                fields = set()
                for item in self:
                    fields.update(list(item.keys()))

                fields = sorted(list(fields))

            # Make sure writing is using correct line endings to avoid extra empty lines
            if sys.version_info[0] == 2:
                csv_file = open(self.filename, 'wb')
            else:
                csv_file = open(self.filename, 'w', newline='')

            with csv_file:
                csv_writer = csv.writer(csv_file, delimiter=delimiter)
                if csv_header:
                    csv_writer.writerow(fields)

                for item in self:
                    item_values = []
                    for field in fields:
                        value = item[field]
                        # List values are stored as one ';'-terminated cell
                        if isinstance(value, list):
                            value = ";".join(value) + ";"

                        item_values.append(value)

                    csv_writer.writerow(item_values)

        elif self.format == FileFormat.CPICKLE:
            from dcase_util.files import Serializer
            Serializer.save_cpickle(filename=self.filename, data=self)

        else:
            message = '{name}: Unknown format [{format}]'.format(
                name=self.__class__.__name__, format=self.filename)
            # Not inside an exception handler: error() instead of exception()
            self.logger.error(message)
            raise IOError(message)

        return self
Пример #8
0
    def load(self,
             filename=None,
             fields=None,
             csv_header=True,
             file_format=None,
             delimiter=None,
             convert_numeric_fields=True):
        """Load file

        Parameters
        ----------
        filename : str, optional
            File path
            Default value filename given to class constructor

        fields : list of str, optional
            List of column names

        csv_header : bool, optional
            Read field names from first line (header). Used only for CSV formatted files.
            Default value True

        file_format : FileFormat, optional
            Forced file format, use this when there is a miss-match between file extension and file format.

        delimiter : str, optional
            Forced data delimiter for csv format. If None given, automatic delimiter sniffer used. Use this when sniffer does not work.

        convert_numeric_fields : bool, optional
            Convert int and float fields to correct type.
            Default value True

        Raises
        ------
        IOError:
            File does not exists or has unknown file format

        ValueError:
            Neither fields nor csv_header set for CSV formatted file.

        ImportError:
            YAML file content is not a list.

        Returns
        -------
        self

        """

        if filename:
            self.filename = filename
            if not file_format:
                self.detect_file_format()
                self.validate_format()

        if file_format and FileFormat.validate_label(label=file_format):
            self.format = file_format

        if not self.exists():
            message = '{name}: File does not exists [{file}]'.format(
                name=self.__class__.__name__, file=self.filename)
            # Not inside an exception handler: error() instead of exception()
            self.logger.error(message)
            raise IOError(message)

        from dcase_util.files import Serializer

        if self.format == FileFormat.CSV:
            # Column names must come from the fields parameter or the header
            # line; csv_header=False without fields would leave them unknown.
            if fields is None and not csv_header:
                message = '{name}: Parameters fields or csv_header has to be set for CSV files.'.format(
                    name=self.__class__.__name__)
                self.logger.error(message)
                raise ValueError(message)

            data = []

            if not delimiter:
                delimiter = self.delimiter()

            with open(self.filename, 'r') as f:
                csv_reader = csv.reader(f, delimiter=delimiter)
                if csv_header:
                    # Default avoids StopIteration on an empty file
                    csv_fields = next(csv_reader, None)
                    if fields is None:
                        fields = csv_fields

                for row in csv_reader:
                    if convert_numeric_fields:
                        for cell_id, cell_data in enumerate(row):
                            if is_int(cell_data):
                                row[cell_id] = int(cell_data)

                            elif is_float(cell_data):
                                row[cell_id] = float(cell_data)

                    data.append(dict(zip(fields, row)))

            list.__init__(self, data)

        elif self.format == FileFormat.YAML:
            data = Serializer.load_yaml(filename=self.filename)
            if not isinstance(data, list):
                message = '{name}: YAML data is not in list format.'.format(
                    name=self.__class__.__name__)
                self.logger.error(message)
                raise ImportError(message)

            list.__init__(self, data)

        elif self.format == FileFormat.CPICKLE:
            list.__init__(self,
                          Serializer.load_cpickle(filename=self.filename))

        else:
            message = '{name}: Unknown format [{format}]'.format(
                name=self.__class__.__name__, format=self.filename)
            self.logger.error(message)
            raise IOError(message)

        # Check if after load function is defined, call if found
        if hasattr(self, '_after_load'):
            self._after_load()

        return self
Пример #9
0
    def load(self,
             filename=None,
             fields=None,
             csv_header=True,
             file_format=None,
             delimiter=None,
             decimal='point'):
        """Load probability list from file

        Preferred delimiter is tab, however, other delimiters are supported automatically
        (they are sniffed automatically).

        Supported input formats (TXT):
            - [file(string)][label(string)][probability(float)]
            - [file(string)][label(string)][probability(float)][index(float)]

        Parameters
        ----------
        filename : str
            Path to the probability list in text format (csv). If none given, one given for class constructor is used.
            Default value None

        fields : list of str, optional
            List of column names. Used only for CSV formatted files.
            Default value None

        csv_header : bool, optional
            Read field names from first line (header). Used only for CSV formatted files.
            Default value True

        file_format : FileFormat, optional
            Forced file format, use this when there is a miss-match between file extension and file format.
            Default value None

        delimiter : str, optional
            Forced data delimiter for CSV formatted files. If None given, automatic delimiter sniffer used.
            Use this when sniffer does not work. TXT formatted files always use the sniffer.
            Default value None

        decimal : str
            Decimal 'point' or 'comma'. With 'comma', ',' is excluded from delimiter sniffing
            and decimal commas in numeric cells are read as decimal points.
            Default value 'point'

        Raises
        ------
        IOError:
            File not found, unknown TXT row format, or unsupported file format

        ValueError:
            Neither fields nor csv_header set for CSV formatted file

        Returns
        -------
        self

        """

        if filename:
            self.filename = filename
            if not file_format:
                self.detect_file_format()
                self.validate_format()

        if file_format and FileFormat.validate_label(label=file_format):
            self.format = file_format

        if not self.exists():
            message = '{name}: File not found [{file}]'.format(
                name=self.__class__.__name__, file=self.filename)
            # Not inside an exception handler: error() instead of exception()
            self.logger.error(message)
            raise IOError(message)

        if self.format in [FileFormat.TXT]:
            # Delimiter is always sniffed for TXT files; with decimal comma,
            # ',' cannot be the delimiter.
            if decimal == 'comma':
                delimiter = self.delimiter(exclude_delimiters=[','])

            else:
                delimiter = self.delimiter()

            file_types = [FieldValidator.AUDIOFILE, FieldValidator.DATAFILE]
            label_types = [
                FieldValidator.STRING,
                FieldValidator.ALPHA1,
                FieldValidator.ALPHA2
            ]
            # Accepted row layouts: [file label probability] and
            # [file label probability index]
            three_column_formats = [
                [file_type, label_type, FieldValidator.NUMBER]
                for file_type in file_types
                for label_type in label_types
            ]
            four_column_formats = [
                row_layout + [FieldValidator.NUMBER]
                for row_layout in three_column_formats
            ]
            # Field types whose values get surrounding whitespace stripped
            text_types = [
                FieldValidator.AUDIOFILE,
                FieldValidator.DATAFILE,
                FieldValidator.STRING,
                FieldValidator.ALPHA1,
                FieldValidator.ALPHA2,
                FieldValidator.LIST
            ]

            data = []
            field_validator = FieldValidator()
            with io.open(self.filename, 'rt') as f:
                for row in csv.reader(f, delimiter=delimiter):
                    if not row:
                        continue

                    row_format = [
                        field_validator.process(item) for item in row
                    ]

                    for item_id, item_format in enumerate(row_format):
                        if item_format == FieldValidator.NUMBER:
                            # Translate decimal comma into decimal point
                            row[item_id] = float(row[item_id].replace(',', '.'))

                        elif item_format in text_types:
                            row[item_id] = row[item_id].strip()

                    if row_format in three_column_formats:
                        # Format: [file label probability]
                        data.append(
                            self.item_class({
                                'filename': row[0],
                                'label': row[1],
                                'probability': row[2],
                            }))

                    elif row_format in four_column_formats:
                        # Format: [file label probability index]
                        data.append(
                            self.item_class({
                                'filename': row[0],
                                'label': row[1],
                                'probability': row[2],
                                'index': row[3]
                            }))

                    else:
                        message = '{name}: Unknown row format [{row}] [{row_format}]'.format(
                            name=self.__class__.__name__,
                            row=row,
                            row_format=row_format)
                        self.logger.error(message)
                        raise IOError(message)

            self.update(data=data)

        elif self.format == FileFormat.CSV:
            # Column names must come from the fields parameter or the header
            # line; csv_header=False without fields would leave them unknown.
            if fields is None and not csv_header:
                message = '{name}: Parameters fields or csv_header has to be set for CSV files.'.format(
                    name=self.__class__.__name__)
                self.logger.error(message)
                raise ValueError(message)

            if not delimiter:
                if decimal == 'comma':
                    delimiter = self.delimiter(exclude_delimiters=[','])

                else:
                    delimiter = self.delimiter()

            data = []
            with open(self.filename, 'r') as f:
                csv_reader = csv.reader(f, delimiter=delimiter)
                if csv_header:
                    # Default avoids StopIteration on an empty file
                    csv_fields = next(csv_reader, None)
                    if fields is None:
                        fields = csv_fields

                for row in csv_reader:
                    if not row:
                        continue

                    for cell_id, cell_data in enumerate(row):
                        if decimal == 'comma':
                            # Translate decimal comma into decimal point only
                            # for cells that then parse as numbers; text cells
                            # (e.g. filenames, labels) are left as they are.
                            candidate = cell_data.replace(',', '.')
                            if is_float(candidate):
                                cell_data = candidate

                        if is_int(cell_data):
                            row[cell_id] = int(cell_data)

                        elif is_float(cell_data):
                            row[cell_id] = float(cell_data)

                    data.append(dict(zip(fields, row)))

            self.update(data=data)

        elif self.format == FileFormat.CPICKLE:
            from dcase_util.files import Serializer
            self.update(data=Serializer.load_cpickle(
                filename=self.filename))

        else:
            message = '{name}: Unknown format [{format}]'.format(
                name=self.__class__.__name__, format=self.filename)
            self.logger.error(message)
            raise IOError(message)

        return self
Пример #10
0
    def save(self, data, filename=None):
        """Save given data to a file

        Parameters
        ----------
        data
            Data to be saved. For TXT format: either a mapping whose values
            are written in iteration order, or a list/tuple of strings.

        filename : str, optional
            File path. When given, it replaces ``self.filename`` and the file
            format is re-detected and validated.
            Default value filename given to class constructor

        Raises
        ------
        ImportError:
            Error if file format specific module cannot be imported

        IOError:
            Filename is empty or file has unknown file format

        Returns
        -------
        self

        """

        if filename:
            self.filename = filename
            self.detect_file_format()
            self.validate_format()

        if self.filename is None or self.filename == '':
            message = '{name}: Filename is empty [{filename}]'.format(
                name=self.__class__.__name__, filename=self.filename)

            # Not inside an exception handler: error() instead of exception()
            self.logger.error(message)
            raise IOError(message)

        try:
            from dcase_util.files import Serializer

            if self.format == FileFormat.YAML:
                Serializer.save_yaml(filename=self.filename, data=data)

            elif self.format == FileFormat.CPICKLE:
                Serializer.save_cpickle(filename=self.filename, data=data)

            elif self.format == FileFormat.MARSHAL:
                Serializer.save_marshal(filename=self.filename, data=data)

            elif self.format == FileFormat.MSGPACK:
                Serializer.save_msgpack(filename=self.filename, data=data)

            elif self.format == FileFormat.JSON:
                Serializer.save_json(filename=self.filename, data=data)

            elif self.format == FileFormat.TXT:
                # Sequences are written item by item; mappings write the value
                # of each key in iteration order. No line breaks are added.
                if isinstance(data, (list, tuple)):
                    lines = data
                else:
                    lines = (data[line_id] for line_id in data)

                with open(self.filename, "w") as text_file:
                    for line in lines:
                        text_file.write(line)

            else:
                message = '{name}: Unknown format [{format}]'.format(
                    name=self.__class__.__name__, format=self.filename)

                self.logger.error(message)
                raise IOError(message)

        except KeyboardInterrupt:
            # Delete the file, since most likely it was not saved fully. The
            # interrupt can arrive before the file exists; removing a missing
            # file would raise and hide the KeyboardInterrupt.
            if os.path.isfile(self.filename):
                os.remove(self.filename)
            raise

        return self
Пример #11
0
def test_Serializer():
    """Round-trip a small dict through every Serializer backend.

    For each backend the dict is saved to a fresh temporary file, loaded
    back, compared with the original, and the temporary file is removed.
    """
    data = {
        'field1': 10,
        'field2': 100,
        'field3': 1000,
    }

    s = Serializer()

    # (file suffix, save method, load method) for each backend under test.
    backends = [
        ('.yaml', s.save_yaml, s.load_yaml),
        ('.cpickle', s.save_cpickle, s.load_cpickle),
        ('.json', s.save_json, s.load_json),
        ('.msgpack', s.save_msgpack, s.load_msgpack),
        ('.marshal', s.save_marshal, s.load_marshal),
    ]

    for suffix, save, load in backends:
        # Only a unique path is needed. Close the handle immediately so it is
        # not leaked and the serializer can reopen the file by name (an open
        # NamedTemporaryFile cannot be reopened by name on Windows). The
        # system temp directory is used rather than a hard-coded '/tmp'.
        tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
        tmp.close()
        try:
            save(filename=tmp.name, data=data)
            nose.tools.eq_(data, load(filename=tmp.name))
        finally:
            os.unlink(tmp.name)
Пример #12
0
    def save(self, filename=None):
        """Save file

        Writes the container's key/value pairs to ``self.filename``.
        CSV/TXT files get one comma-delimited ``key,value`` row per item,
        with the ``'filename'`` key skipped. CPICKLE files receive the
        container converted to a plain dict. After a successful write, the
        ``_after_save`` hook is called if the instance defines it.

        Parameters
        ----------
        filename : str, optional
            File path
            Default value filename given to class constructor

        Raises
        ------
        ImportError:
            Error if file format specific module cannot be imported
        IOError:
            Filename is empty or file has unknown file format

        Returns
        -------
        self

        """

        if filename:
            # New target path: re-detect and validate its format.
            self.filename = filename
            self.detect_file_format()
            self.validate_format()

        if self.filename is None or self.filename == '':
            message = '{name}: Filename is empty [{filename}]'.format(
                name=self.__class__.__name__, filename=self.filename)

            # Not inside an except block, so log as a plain error;
            # logger.exception here would attach an empty traceback.
            self.logger.error(message)
            raise IOError(message)

        try:
            from dcase_util.files import Serializer

            if self.format in (FileFormat.CSV, FileFormat.TXT):
                # Rows are always written comma-delimited.
                delimiter = ','
                with open(self.filename, 'w') as csv_file:
                    csv_writer = csv.writer(csv_file, delimiter=delimiter)
                    for key, value in iteritems(self):
                        if key not in ['filename']:
                            csv_writer.writerow((key, value))

            elif self.format == FileFormat.CPICKLE:
                Serializer.save_cpickle(filename=self.filename,
                                        data=dict(self))

            else:
                # Report the offending format itself, not the filename.
                message = '{name}: Unknown format [{format}]'.format(
                    name=self.__class__.__name__, format=self.format)
                self.logger.error(message)
                raise IOError(message)

        except KeyboardInterrupt:
            # Delete the file, since most likely it was not saved fully.
            # The interrupt may arrive before the file was created, so guard
            # the removal to avoid an OSError masking the KeyboardInterrupt.
            if os.path.isfile(self.filename):
                os.remove(self.filename)
            raise

        # Check if after save function is defined, call if found
        if hasattr(self, '_after_save'):
            self._after_save()

        return self
Пример #13
0
    def load(self, filename=None):
        """Load file

        Empties the container, then fills it from ``self.filename``.
        CSV/TXT files are parsed with ``csv.reader`` using the delimiter
        returned by ``self.delimiter()``, and only rows with exactly two
        columns are kept as ``key -> value`` pairs. CPICKLE files are loaded
        via ``Serializer.load_cpickle`` and merged in. After a successful
        load, the ``_after_load`` hook is called if the instance defines it.

        Parameters
        ----------
        filename : str, optional
            File path
            Default value filename given to class constructor

        Raises
        ------
        ImportError:
            Error if file format specific module cannot be imported
        IOError:
            File does not exists or has unknown file format

        Returns
        -------
        self

        """
        import sys

        if filename:
            # New source path: re-detect and validate its format.
            self.filename = filename
            self.detect_file_format()
            self.validate_format()

        # The container is emptied before the existence check, so a failed
        # load leaves it empty.
        dict.clear(self)
        if self.exists():
            from dcase_util.files import Serializer

            if self.format in (FileFormat.TXT, FileFormat.CSV):
                # Python 3 text mode already uses universal newlines and the
                # 'U' flag raises ValueError since 3.11; Python 2 still needs
                # 'U' to get universal newlines.
                open_mode = 'rtU' if sys.version_info[0] < 3 else 'rt'
                map_data = {}
                with open(self.filename, open_mode) as f:
                    for row in csv.reader(f, delimiter=self.delimiter()):
                        # Skip rows that are not exactly (key, value).
                        if len(row) == 2:
                            map_data[row[0]] = row[1]

                dict.update(self, map_data)

            elif self.format == FileFormat.CPICKLE:
                dict.update(self,
                            Serializer.load_cpickle(filename=self.filename))

            else:
                # Report the offending format itself, not the filename. Not
                # inside an except block, so log as a plain error.
                message = '{name}: Unknown format [{format}]'.format(
                    name=self.__class__.__name__, format=self.format)
                self.logger.error(message)
                raise IOError(message)

        else:
            message = '{name}: File does not exists [{file}]'.format(
                name=self.__class__.__name__, file=self.filename)
            self.logger.error(message)
            raise IOError(message)

        # Check if after load function is defined, call if found
        if hasattr(self, '_after_load'):
            self._after_load()

        return self