コード例 #1
0
    def test_check_target_array_2d_multiclass_good(self):
        """Ensures correct output from check_target_array.

        Here the input is a 2-D multiclass array, which is exactly what the
        checker is told to expect, so no error should be raised.
        """

        expected_num_classes = WIND_CLASS_MATRIX.shape[1]
        dl_utils.check_target_array(
            target_array=WIND_CLASS_MATRIX, num_dimensions=2,
            num_classes=expected_num_classes)
コード例 #2
0
    def test_check_target_array_2d_binary_good(self):
        """Ensures correct output from check_target_array.

        Here the input is a 2-D array with 2 classes, matching what the
        checker is told to expect, so no error should be raised.
        """

        dl_utils.check_target_array(
            num_classes=2,
            num_dimensions=2,
            target_array=TORNADO_CLASS_MATRIX)
コード例 #3
0
    def test_check_target_array_2d_bad_dim(self):
        """Ensures correct output from check_target_array.

        Here the input is 2-D while a 1-D array is requested, so the checker
        must raise a TypeError.
        """

        self.assertRaises(
            TypeError, dl_utils.check_target_array,
            target_array=TORNADO_CLASS_MATRIX, num_dimensions=1,
            num_classes=2)
コード例 #4
0
    def test_check_target_array_2d_multiclass_bad_class_num(self):
        """Ensures correct output from check_target_array.

        Here the input array holds 6 classes while only 2 are requested, so
        the checker must raise a TypeError.
        """

        self.assertRaises(
            TypeError, dl_utils.check_target_array,
            target_array=WIND_CLASS_MATRIX, num_dimensions=2, num_classes=2)
コード例 #5
0
    def test_check_target_array_1d_binary_bad_class_num(self):
        """Ensures correct output from check_target_array.

        Here 6 classes are requested but the input array contains only 2.
        This cannot be flagged as an error: the higher classes may simply be
        absent from the sample, so no exception is expected.
        """

        num_wind_classes = WIND_CLASS_MATRIX.shape[1]
        dl_utils.check_target_array(
            target_array=TORNADO_CLASSES_1D, num_dimensions=1,
            num_classes=num_wind_classes)
コード例 #6
0
ファイル: cnn.py プロジェクト: theweathermanda/GewitterGefahr
def write_features(netcdf_file_name,
                   feature_matrix,
                   target_values,
                   num_classes,
                   append_to_file=False):
    """Writes features (activations of intermediate layer) to NetCDF file.

    E = number of examples (storm objects)
    K = number of classes

    :param netcdf_file_name: Path to output file.
    :param feature_matrix: numpy array of features.  Must have >= 2 dimensions,
        where the first dimension (length E) represents examples and the last
        dimension represents channels (transformed input variables).  Any
        dimensions in between are written as spatial dimensions.
    :param target_values: length-E numpy array of target values.  Must all be
        integers in 0...(K - 1).
    :param num_classes: Number of classes (K).  Stored as a global attribute
        only when a new file is created.
    :param append_to_file: Boolean flag.  If True, will append to existing file
        (which must already contain the feature and target variables).  If
        False, will create new file.
    """

    error_checking.assert_is_boolean(append_to_file)
    error_checking.assert_is_numpy_array(feature_matrix)
    num_storm_objects = feature_matrix.shape[0]

    dl_utils.check_target_array(target_array=target_values,
                                num_dimensions=1,
                                num_classes=num_classes)
    error_checking.assert_is_numpy_array(target_values,
                                         exact_dimensions=numpy.array(
                                             [num_storm_objects]))

    if append_to_file:
        error_checking.assert_is_string(netcdf_file_name)
        netcdf_dataset = netCDF4.Dataset(netcdf_file_name,
                                         'a',
                                         format='NETCDF3_64BIT_OFFSET')
    else:
        file_system_utils.mkdir_recursive_if_necessary(
            file_name=netcdf_file_name)
        netcdf_dataset = netCDF4.Dataset(netcdf_file_name,
                                         'w',
                                         format='NETCDF3_64BIT_OFFSET')

    # Always close the file, even if a write below raises, so the handle is
    # not leaked.
    try:
        if append_to_file:
            # Number of records already stored along the (unlimited)
            # storm-object dimension; new examples go right after them.
            # Reading the variable's shape avoids loading every existing
            # target value into memory just to count them.
            prev_num_storm_objects = (
                netcdf_dataset.variables[TARGET_VALUES_KEY].shape[0])
            new_record_indices = slice(
                prev_num_storm_objects,
                prev_num_storm_objects + num_storm_objects)

            netcdf_dataset.variables[FEATURE_MATRIX_KEY][
                new_record_indices, ...] = feature_matrix
            netcdf_dataset.variables[TARGET_VALUES_KEY][
                new_record_indices] = target_values

        else:
            netcdf_dataset.setncattr(NUM_CLASSES_KEY, num_classes)

            # Unlimited (None) so that later calls can append storm objects.
            netcdf_dataset.createDimension(STORM_OBJECT_DIMENSION_KEY, None)

            # Channels are the LAST axis of feature_matrix (see docstring);
            # spatial axes, if any, sit between the example and channel axes.
            netcdf_dataset.createDimension(FEATURE_DIMENSION_KEY,
                                           feature_matrix.shape[-1])

            num_spatial_dimensions = len(feature_matrix.shape) - 2
            tuple_of_dimension_keys = (STORM_OBJECT_DIMENSION_KEY, )

            for i in range(num_spatial_dimensions):
                netcdf_dataset.createDimension(SPATIAL_DIMENSION_KEYS[i],
                                               feature_matrix.shape[i + 1])
                tuple_of_dimension_keys += (SPATIAL_DIMENSION_KEYS[i], )

            tuple_of_dimension_keys += (FEATURE_DIMENSION_KEY, )
            netcdf_dataset.createVariable(FEATURE_MATRIX_KEY,
                                          datatype=numpy.float32,
                                          dimensions=tuple_of_dimension_keys)
            netcdf_dataset.variables[FEATURE_MATRIX_KEY][:] = feature_matrix

            netcdf_dataset.createVariable(
                TARGET_VALUES_KEY, datatype=numpy.int32,
                dimensions=(STORM_OBJECT_DIMENSION_KEY, ))
            netcdf_dataset.variables[TARGET_VALUES_KEY][:] = target_values
    finally:
        netcdf_dataset.close()