Пример #1
0
def main():
    """Parse command-line options and train a DCGAN on the CelebA dataset.

    The checkpoint and progress directories are passed through
    ``validate_path`` (presumably creating them when missing -- confirm
    against the helper), the dataset is built with the requested batch size
    and the resulting ``Trainer`` runs ``train_loop`` for ``--epochs`` epochs.
    """
    parser = argparse.ArgumentParser(description="Training DCGAN on CelebA dataset")
    parser.add_argument("--checkpoint_dir", type=str, default="./model/checkpoint", help="Path to write checkpoint")
    parser.add_argument("--progress_dir", type=str, default="./data/face_gan", help="Path to write training progress image")
    parser.add_argument("--dataset_dir", type=str, required=True, help="Path to dataset")
    parser.add_argument("--latent_dim", type=int, default=100, help="Latent space dimension")
    parser.add_argument("--test_size", type=int, default=4, help="Square root number of test images to control training progress")
    # The help text used to describe "training steps per epoch", which is what
    # DataSet.build() returns as total_steps, not what this option controls.
    parser.add_argument("--batch_size", type=int, default=100, help="Number of samples per training batch")
    parser.add_argument("--lr", type=float, default=0.0002, help="Learning rate")
    parser.add_argument("--epochs", type=int, default=20, help="Number of epochs for training")

    args = vars(parser.parse_args())

    validate_path(args["checkpoint_dir"])
    validate_path(args["progress_dir"])

    datagen = DataSet(args["dataset_dir"])
    dataset, total_steps = datagen.build(batch_size=args["batch_size"])

    DCGAN = Trainer(progress_dir=args["progress_dir"],
                    checkpoint_dir=args["checkpoint_dir"],
                    z_dim=args["latent_dim"],
                    test_size=args["test_size"],
                    batch_size=args["batch_size"],
                    learning_rate=args["lr"])

    DCGAN.train_loop(dataset=dataset,
                     epochs=args["epochs"],
                     total_steps=total_steps)
Пример #2
0
def save_weights(model, weights_dir, name, verbose=True):
	"""Save ``model.state_dict()`` to ``<weights_dir>/<name>.pt``.

	Args:
		model: object exposing ``state_dict()`` (e.g. a torch module).
		weights_dir: directory that receives the weights file.
		name: file name without the ``.pt`` extension.
		verbose: when true (the default) log the destination before saving.

	Any exception raised by ``torch.save`` propagates to the caller.
	"""
	path = os.path.join(weights_dir, name + '.pt')
	validate_path(path, is_dir=False)

	# ``verbose`` was previously accepted but ignored; honour it now.
	if verbose:
		log(f'saving {name} weights to {path}')

	torch.save(model.state_dict(), path)
Пример #3
0
def main(in_raster=None, neighborhood_size=None, out_raster=None):
    """Compute a ruggedness raster from ``in_raster`` and save it.

    Slope and aspect are converted to x/y/z components, each component is
    summed over a square neighbourhood of ``neighborhood_size`` cells, and
    ruggedness is ``1 - |summed vector| / neighborhood_size**2``.

    Args:
        in_raster: input surface raster.
        neighborhood_size: side of the square neighbourhood in cells
            (converted with ``int``).
        out_raster: output path; its directory is used as the arcpy
            workspace and scratch workspace.

    Errors raised during the computation are reported through
    ``utils.msg(..., mtype='error')`` rather than propagated.
    """
    hood_size = int(neighborhood_size)

    # FIXME: expose this as an option per #18
    out_workspace = os.path.dirname(out_raster)
    utils.workspace_exists(out_workspace)
    # force temporary stats to be computed in our output workspace
    arcpy.env.scratchWorkspace = out_workspace
    arcpy.env.workspace = out_workspace

    # TODO expose as config
    pyramid_orig = arcpy.env.pyramid
    arcpy.env.pyramid = "NONE"
    # TODO: currently set to automatically overwrite, expose this as option
    arcpy.env.overwriteOutput = True

    try:
        # Create Slope and Aspect rasters
        utils.msg("Calculating aspect...")
        out_aspect = Aspect(in_raster)
        utils.msg("Calculating slope...")
        out_slope = Slope(in_raster, "DEGREE")

        # Convert Slope and Aspect rasters to radians
        utils.msg("Converting slope and aspect to radians...")
        slope_rad = out_slope * (math.pi / 180)
        aspect_rad = out_aspect * (math.pi / 180)

        # Calculate x, y, and z rasters. Cells whose aspect equals -1 get a
        # zero x/y contribution (presumably flat cells -- confirm).
        utils.msg("Calculating x, y, and z rasters...")
        xy_raster_calc = Sin(slope_rad)
        z_raster_calc = Cos(slope_rad)
        x_raster_calc = Con(out_aspect == -1, 0, Sin(aspect_rad)) * xy_raster_calc
        y_raster_calc = Con(out_aspect == -1, 0, Cos(aspect_rad)) * xy_raster_calc

        # Calculate sums of x, y, and z rasters for selected neighborhood size
        utils.msg("Calculating sums of x, y, and z rasters in neighborhood...")
        hood = NbrRectangle(hood_size, hood_size, "CELL")
        x_sum_calc = FocalStatistics(x_raster_calc, hood, "SUM", "NODATA")
        y_sum_calc = FocalStatistics(y_raster_calc, hood, "SUM", "NODATA")
        z_sum_calc = FocalStatistics(z_raster_calc, hood, "SUM", "NODATA")

        # Calculate the resultant vector
        utils.msg("Calculating the resultant vector...")
        result_vect = (x_sum_calc**2 + y_sum_calc**2 + z_sum_calc**2)**0.5

        # Re-enable statistics and the caller's pyramid setting before the
        # final raster is produced.
        arcpy.env.rasterStatistics = "STATISTICS"
        arcpy.env.pyramid = pyramid_orig
        # Calculate the Ruggedness raster
        utils.msg("Calculating the final ruggedness raster...")
        ruggedness = 1 - (result_vect / hood_size**2)

        out_raster = utils.validate_path(out_raster)
        utils.msg("Saving ruggedness raster to {}.".format(out_raster))
        ruggedness.save(out_raster)

    except Exception as e:
        utils.msg(e, mtype='error')
    finally:
        # Do not leak the temporary "NONE" pyramid setting when a step fails
        # before the restore above is reached.
        arcpy.env.pyramid = pyramid_orig
Пример #4
0
def init_android_repo():
    """Initialise an Android repository at the path given on the command line.

    The repository name is the lower-cased last component of the validated
    path. Exceptions from any step propagate unchanged to the caller.
    """
    args = parse_arguments()
    repo_path = utils.validate_path(args.path)
    normalized_path = os.path.normpath(repo_path)
    repo_name = os.path.basename(normalized_path).lower()
    populate_repo(repo_path)
    create_catkin_package_files(repo_name, repo_path, args)
    create_gradle_wrapper(repo_path)
Пример #5
0
def draw(data_points,
         max_val,
         min_val,
         colors,
         lines,
         title,
         x_title,
         y_title,
         x_custom="",
         width=640,
         height=480,
         x_interval=20,
         y_interval=20):
    """Plot several series on one axis and save the figure as a PNG.

    Args:
        data_points: sequence of series; the length of the first one defines
            the x range.
        max_val, min_val: bounds used to build the y ticks.
        colors: "|"-separated colour names, one per series.
        lines: "|"-separated legend labels, one per series.
        title: axis title; also used (lower-cased, spaces replaced by "_")
            as the file name, falling back to "plot_figure" when empty.
        x_title, y_title: axis labels.
        x_custom: optional suffix; when set, x tick labels become
            ``str((tick + 2) * 4) + x_custom``.
        width, height: figure size in pixels, divided by 100 for figsize.
        x_interval, y_interval: tick spacing; a falsy value means 1.
    """
    _, axes = plt.subplots(figsize=(width / 100, height / 100))
    axes.set_title(title)

    n_points = len(data_points[0])
    xs = range(n_points)
    x_step = x_interval or 1
    y_step = y_interval or 1
    x_ticks = range(0, n_points, x_step)
    y_ticks = range(min_val, max_val + y_step, y_step)

    # One plotted line and one legend patch per (series, colour, label).
    legend_handles = []
    for series, shade, label in zip(data_points, colors.split("|"), lines.split("|")):
        axes.plot(xs, series, color=shade)
        legend_handles.append(mpatches.Patch(color=shade, label=label))
    plt.legend(handles=legend_handles)

    axes.set_xlabel(x_title)
    axes.set_ylabel(y_title)
    axes.set_xticks(x_ticks)
    axes.set_yticks(y_ticks)
    if x_custom:
        axes.set_xticklabels([str((tick + 2) * 4) + x_custom for tick in x_ticks])

    utils.validate_path("figures")
    filename = title.lower().replace(" ", "_") if title else "plot_figure"
    plt.savefig("figures/%s.png" % filename, format="png", bbox_inches="tight")
Пример #6
0
def setup(envs, opts):
    if opts.action == 'start' or opts.action == 'stop':
        start_stop_server(envs, opts.action)
    elif opts.action == 'setup':
        envs['ppath'] = ospath(os.path.expanduser(opts.ppath).rstrip("/").rstrip("\\"))
        if validate_path(envs) == False:
            print "Not a valid source path"
            return False
        if check_and_copy_tools(envs) == True:
            setup_index_for_project(envs)
def init_package(package_type):
    """Initialise a package repository of ``package_type``.

    The target path comes from the parsed command-line arguments; the
    repository name is the lower-cased last component of the validated path.
    Exceptions from any step propagate unchanged to the caller.
    """
    args = parse_arguments()
    repo_path = utils.validate_path(args.path)
    normalized_path = os.path.normpath(repo_path)
    repo_name = os.path.basename(normalized_path).lower()
    populate_repo(repo_path, package_type)
    create_catkin_package_files(repo_name, repo_path, args)
    create_gradle_wrapper(repo_path)
Пример #8
0
def plot_prediction(file,
                    title="",
                    width=640,
                    height=480,
                    save=0,
                    color="red|green"):
    """Plot predicted vs. actual values read from a two-column CSV file.

    The first line of ``file`` is skipped as a header. Each following line
    holds ``predicted,actual`` integers; blank lines are ignored.

    Args:
        file: path of the CSV file to read.
        title: axis title, also used to derive the saved file name.
        width, height: figure size in pixels, divided by 100 for figsize.
        save: when truthy the figure is written to ``figures/<name>.png``,
            otherwise it is shown interactively.
        color: "|"-separated colours for the predicted and actual series.
    """
    with open(file) as f:
        rows = f.readlines()

    data1 = []
    data2 = []
    max_val = 100  # the y axis always spans at least 0..100
    for row in rows[1:]:
        # A trailing or stray empty line used to crash int('').
        if not row.strip():
            continue
        tmp = row.rstrip("\n").split(",")
        d1 = int(tmp[0])
        d2 = int(tmp[1])
        data1.append(d1)
        data2.append(d2)
        max_val = max(max_val, d1, d2)

    colors = color.split("|")
    # Round the upper bound up to the next multiple of 50.
    max_val = int(math.ceil(max_val * 1.0 / 50) * 50)
    y_sticks = range(0, max_val + 10, 10)
    _, ax = plt.subplots(figsize=(width / 100, height / 100))
    x = range(len(data1))
    ax.set_title(title)
    # The actual series is drawn first, the predicted one after it.
    ax.plot(x, data2, color=colors[1])
    ax.plot(x, data1, color=colors[0])
    ax.set_xlabel("Hourly Timestep on Test Set")
    ax.set_ylabel("PM2.5 AQI")
    ax.set_yticks(y_sticks)
    d1_patch = mpatches.Patch(color=colors[0], label="predicted")
    d2_patch = mpatches.Patch(color=colors[1], label="actual")
    plt.legend(handles=[d2_patch, d1_patch])
    if save:
        utils.validate_path("figures/")
        # An empty title (the default) used to produce "figures/.png"; use
        # the same fallback name as draw().
        filename = title.lower().replace(" ", "_") if title else "plot_figure"
        plt.savefig("figures/%s.png" % filename,
                    format="png",
                    bbox_inches='tight')
    else:
        plt.show()
Пример #9
0
def main(bathy=None, out_raster=None):
    """Compute the slope (in degrees) of ``bathy`` and save it to ``out_raster``.

    Any exception is reported through ``utils.msg(..., mtype='error')``
    instead of being raised.
    """
    try:
        arcpy.env.rasterStatistics = "STATISTICS"
        utils.msg("Calculating the slope...")
        # Third argument is the z-factor passed to Slope.
        slope_raster = Slope(bathy, "DEGREE", 1)
        destination = utils.validate_path(out_raster)
        slope_raster.save(destination)
    except Exception as err:
        utils.msg(err, mtype='error')
Пример #10
0
def setup(envs, opts):
    if opts.action == 'start' or opts.action == 'stop':
        start_stop_server(envs, opts.action)
    elif opts.action == 'setup':
        envs['ppath'] = ospath(
            os.path.expanduser(opts.ppath).rstrip("/").rstrip("\\"))
        if validate_path(envs) == False:
            print "Not a valid source path"
            return False
        if check_and_copy_tools(envs) == True:
            setup_index_for_project(envs)
Пример #11
0
    def __init__(self, tools, toolid, ga, path=None, archive=True):
        """Create an archive-based tool identified by the coordinate ``ga``.

        Args:
            tools: forwarded to ``ArchiveBasedTool.__init__``.
            toolid: forwarded to ``ArchiveBasedTool.__init__``.
            ga: ':'-separated coordinate; when it has at most two parts the
                default type ``zip`` is appended.
            path: either a callable (stored as-is) or a path that is passed
                through ``validate_path``.
            archive: stored on the instance as ``_archive``.
        """
        ArchiveBasedTool.__init__(self, tools, toolid)
        self._archive = archive
        if len(ga.split(':')) <= 2:  # add default type
            # The original concatenated the builtin ``zip`` function here,
            # which raises TypeError; the literal type name was intended.
            ga = ga + ":zip"

        self._ga = ga
        if callable(path):
            self._path = path
        else:
            self._path = validate_path(path)
Пример #12
0
def main():
    """Train a VAE variant on the CelebA dataset.

    ``--model`` selects the plain ``VAE`` trainer or one of the ``VAE_123`` /
    ``VAE_345`` variants, which use ``DfcVaeTrainer`` with the VGG19 layers
    looked up in ``DfcVaeConfig.vgg19_layers``. All other settings come from
    the ``cfg`` / ``VaeConfig`` / ``DfcVaeConfig`` objects.
    """
    parser = argparse.ArgumentParser(
        description="Training VAE on CelebA dataset")
    parser.add_argument("--model",
                        type=str,
                        default="VAE",
                        choices=["VAE", "VAE_123", "VAE_345"],
                        help="Training model")

    args = vars(parser.parse_args())

    datagen = DataSet(cfg.dataset_dir)
    dataset, total_steps = datagen.build(batch_size=cfg.batch_size)

    encoder, decoder, vae_net = build_vae(z_dim=cfg.z_dim)

    # NOTE(review): both trainers previously received
    # ``batch_size=cfg.test_size`` although the dataset above is batched with
    # ``cfg.batch_size``; that looked like a copy-paste slip from the
    # ``test_size`` line and is corrected below -- confirm against the
    # trainers' use of ``batch_size``.
    if args["model"] == "VAE":

        validate_path(VaeConfig.progress_dir)
        validate_path(VaeConfig.checkpoint_dir)

        VAE = VaeTrainer(progress_dir=VaeConfig.progress_dir,
                         checkpoint_dir=VaeConfig.checkpoint_dir,
                         encoder=encoder,
                         decoder=decoder,
                         vae_net=vae_net,
                         reconstruction_weight=VaeConfig.reconstruction_weight,
                         z_dim=cfg.z_dim,
                         test_size=cfg.test_size,
                         batch_size=cfg.batch_size,
                         learning_rate=cfg.lr)
    else:

        validate_path(DfcVaeConfig.progress_dir)
        validate_path(DfcVaeConfig.checkpoint_dir)

        VAE = DfcVaeTrainer(
            progress_dir=DfcVaeConfig.progress_dir,
            checkpoint_dir=DfcVaeConfig.checkpoint_dir,
            encoder=encoder,
            decoder=decoder,
            vae_net=vae_net,
            vgg_layers=DfcVaeConfig.vgg19_layers[args["model"]],
            perceptual_weight=DfcVaeConfig.perceptual_weight,
            z_dim=cfg.z_dim,
            test_size=cfg.test_size,
            batch_size=cfg.batch_size,
            learning_rate=cfg.lr)

    VAE.train_loop(dataset=dataset, epochs=cfg.epochs, total_steps=total_steps)
Пример #13
0
def bpi(bathy=None, inner_radius=None, outer_radius=None,
         out_raster=None, bpi_type=None):
    """Write a Bathymetric Position Index raster for ``bathy``.

    The index is ``Int((bathy - annulus mean) + 0.5)``, where the mean is
    taken over an annulus of ``inner_radius``..``outer_radius`` cells.
    ``bpi_type`` is accepted but not used by this function.
    """
    annulus = NbrAnnulus(inner_radius, outer_radius, "CELL")

    # Mean depth in the annulus around each cell.
    focal_mean = FocalStatistics(bathy, annulus, "MEAN")
    difference = Minus(bathy, focal_mean)
    # 0.5 is added before Int(), presumably to round to the nearest integer.
    bpi_grid = Int(Plus(difference, 0.5))

    destination = utils.validate_path(out_raster)
    arcpy.CopyRaster_management(bpi_grid, destination)
Пример #14
0
def main(bathy=None, out_sin_raster=None, out_cos_raster=None):
    """Save the sine and cosine of the aspect of ``bathy`` as two rasters.

    Aspect is a circular quantity (0 and 359.9 degrees are adjacent), so it
    is decomposed into sin/cos components that preserve distances between
    values. Any exception is reported through ``utils.msg`` with
    ``mtype='error'`` instead of being raised.
    """
    try:
        arcpy.env.rasterStatistics = "STATISTICS"
        utils.msg("Calculating aspect...")
        aspect_deg = Aspect(bathy)

        # Sin and Cos expect radians, so convert the degree raster first.
        degrees_to_radians = math.pi / 180
        aspect_in_radians = aspect_deg * degrees_to_radians

        sin_component = Sin(aspect_in_radians)
        cos_component = Cos(aspect_in_radians)

        sin_destination = utils.validate_path(out_sin_raster)
        cos_destination = utils.validate_path(out_cos_raster)
        sin_component.save(sin_destination)
        cos_component.save(cos_destination)
    except Exception as err:
        utils.msg(err, mtype='error')
Пример #15
0
def stdbpi(bpi_raster=None, out_raster=None):
    """Write a standardized version of ``bpi_raster`` to ``out_raster``.

    The output is ``Int(((bpi - mean) / std) * 100 + 0.5)``, with the mean
    and standard deviation read through ``utils.raster_properties``.
    """
    # Work with the catalog path of the input rather than the object given.
    source_path = arcpy.Describe(bpi_raster).catalogPath

    mean_value = utils.raster_properties(source_path, "MEAN")
    std_value = utils.raster_properties(source_path, "STD")

    arcpy.env.rasterStatistics = "STATISTICS"
    centered = Minus(source_path, mean_value)
    scaled = Times(Divide(centered, std_value), 100)
    standardized = Int(Plus(scaled, 0.5))

    destination = utils.validate_path(out_raster)
    arcpy.CopyRaster_management(standardized, destination)
Пример #16
0
 def __init__(self, name, gav):
     """Create a tool from either a GAV coordinate or a plain path.

     When ``is_gav(gav)`` holds, the value is expanded with ``tool_gav``; a
     six-element result carries the path as its last element, otherwise no
     path is set. When it does not hold, ``gav`` itself is used as the path.
     The path is finally passed through ``utils.validate_path`` together
     with a flag that is True only for the GAV case.
     """
     self._name = name
     if not is_gav(gav):
         self._gav = None
         self._path = gav
         relative = False
     else:
         parsed = tool_gav(gav, True)
         has_path = len(parsed) == 6
         self._gav = parsed[:5] if has_path else parsed
         self._path = parsed[5] if has_path else None
         relative = True
     self._path = utils.validate_path(self._path, relative)
Пример #17
0
def main(bathy=None, inner_radius=None, outer_radius=None,
    out_raster=None, bpi_type='broad', mode='toolbox'):
    """Generate a Bathymetric Position Index raster and save it.

    The index is ``Int((bathy - annulus mean) + 0.5)`` over an annulus of
    ``inner_radius``..``outer_radius`` cells. ``bpi_type`` only affects the
    progress message; ``mode`` is accepted but not used here. Any exception
    is reported through ``utils.msg`` with ``mtype='error'``.
    """
    arcpy.env.rasterStatistics = "STATISTICS"
    try:
        banner = ("Generating the {bpi_type}-scale ".format(bpi_type=bpi_type)
                  + "Bathymetric Position Index (BPI) raster...")
        utils.msg(banner)

        utils.msg("Calculating neighborhood...")
        annulus = NbrAnnulus(inner_radius, outer_radius, "CELL")

        utils.msg("Calculating FocalStatistics for {}...".format(bathy))
        focal_mean = FocalStatistics(bathy, annulus, "MEAN")
        difference = Minus(bathy, focal_mean)
        bpi_grid = Int(Plus(difference, 0.5))

        destination = utils.validate_path(out_raster)
        bpi_grid.save(destination)
        utils.msg("Saved output as {}".format(destination))
    except Exception as err:
        utils.msg(err, mtype='error')
Пример #18
0
def main(bpi_raster=None, out_raster=None):
    """Standardize a BPI raster and save the result to ``out_raster``.

    The output is ``Int(((bpi - mean) / std) * 100 + 0.5)``; mean and
    standard deviation come from ``utils.raster_properties``. Progress is
    logged through ``utils.msg`` and any exception is reported with
    ``mtype='error'`` instead of being raised.
    """
    try:
        intro = ("Calculating properties of the Bathymetric "
                 + "Position Index (BPI) raster...")
        utils.msg(intro)
        utils.msg("raster: {}; output: {}".format(bpi_raster, out_raster))

        mean_value = utils.raster_properties(bpi_raster, "MEAN")
        utils.msg("BPI raster mean: {}.".format(mean_value))
        std_value = utils.raster_properties(bpi_raster, "STD")
        utils.msg("BPI raster standard deviation: {}.".format(std_value))

        utils.msg("Standardizing the Bathymetric Position Index (BPI) raster...")
        arcpy.env.rasterStatistics = "STATISTICS"
        centered = Minus(bpi_raster, mean_value)
        scaled = Times(Divide(centered, std_value), 100)
        standardized = Int(Plus(scaled, 0.5))

        destination = utils.validate_path(out_raster)
        standardized.save(destination)

    except Exception as err:
        utils.msg(err, mtype='error')
Пример #19
0
def main(in_raster=None, out_raster=None, area_raster=None):
    """Compute the surface-area to planar-area ratio of ``in_raster``.

    Eight shifted copies of the input are created, sixteen triangle-edge
    rasters are derived from them with ``compute_edge``, and eight triangle
    areas from those with ``triangle_area``. Their sum is the surface area
    (optionally saved to ``area_raster``); divided by the squared cell size
    it is saved to ``out_raster``.

    Args:
        in_raster: input surface raster.
        out_raster: destination of the ratio raster; its directory is used
            as workspace and receives the temporary rasters.
        area_raster: optional destination for the summed surface area.

    Errors during the computation and during cleanup are reported through
    ``utils.msg(..., mtype='error')`` rather than propagated.
    """
    out_workspace = os.path.dirname(out_raster)
    # make sure workspace exists
    utils.workspace_exists(out_workspace)

    utils.msg("Set scratch workspace to {}...".format(out_workspace))

    # force temporary stats to be computed in our output workspace
    arcpy.env.scratchWorkspace = out_workspace
    arcpy.env.workspace = out_workspace
    pyramid_orig = arcpy.env.pyramid
    arcpy.env.pyramid = "NONE"
    # TODO: currently set to automatically overwrite, expose this as option
    arcpy.env.overwriteOutput = True

    bathy = Raster(in_raster)
    desc = arcpy.Describe(bathy)
    # get the cell size of the input raster; use same calculation as was
    # performed in BTM v1: (mean_x + mean_y) / 2
    cell_size = (desc.meanCellWidth + desc.meanCellHeight) / 2.0
    corner_dist = math.sqrt(2 * cell_size ** 2)
    flat_area = cell_size ** 2
    utils.msg("Cell size: {}\nFlat area: {}".format(cell_size, flat_area))

    # Paths of every intermediate raster; created before the try block so
    # the cleanup step at the end always has a list to walk.
    temp_rasts = []

    try:
        # Create a set of shifted grids, with offsets n from the origin X:
        #
        #          8 | 7 | 6
        #          --|---|---
        #          5 | X | 4
        #          --|---|---
        #          3 | 2 | 1
        positions = [(1, -1), (0, -1), (-1, -1),
                     (1,  0),          (-1,  0),
                     (1,  1), (0,  1), (-1,  1)]

        # Corners are dist * sqrt(2) away (corner_dist); the remaining
        # positions (2, 4, 5, 7) are straight neighbours at cell_size.
        corners = (1, 3, 6, 8)

        shift_rasts = [None] # offset to align numbers
        for (n, pos) in enumerate(positions, start=1):
            utils.msg("Creating Shift Grid {} of 8...".format(n))
            # scale shift grid by cell size (explicit form instead of the
            # Python-2-only ``lambda(n):`` tuple-parameter syntax)
            x_shift = pos[0] * cell_size
            y_shift = pos[1] * cell_size

            # set explicit path on shift rasters, otherwise suffer inexplicable 999999 errors.
            shift_out = os.path.join(out_workspace, "shift_{}.tif".format(n))
            shift_out = utils.validate_path(shift_out)
            temp_rasts.append(shift_out)
            arcpy.Shift_management(bathy, shift_out, x_shift, y_shift)
            shift_rasts.append(arcpy.sa.Raster(shift_out))

        edge_rasts = [None]
        # calculate triangle length grids

        # edges 1-8: pairs of bathy:shift[n]
        for (n, shift) in enumerate(shift_rasts[1:], start=1):
            utils.msg("Calculating Triangle Edge {} of 16...".format(n))
            # adjust for corners being sqrt(2) from center
            if n in corners:
                dist = corner_dist
            else:
                dist = cell_size
            edge_out = os.path.join(out_workspace, "edge_{}.tif".format(n))
            edge_out = utils.validate_path(edge_out)
            temp_rasts.append(edge_out)
            edge = compute_edge(bathy, shift, dist)
            edge.save(edge_out)
            edge_rasts.append(arcpy.sa.Raster(edge_out))

        # edges 9-16: pairs of adjacent shift grids [see layout above]
        # in BTM_v1, these are labeled A-H
        adjacent_shift = [(1, 2), (2, 3), (1, 4), (3, 5),
                          (6, 4), (5, 8), (6, 7), (7, 8)]
        for (n, pair) in enumerate(adjacent_shift, start=9):
            utils.msg("Calculating Triangle Edge {} of 16...".format(n))
            (i, j) = pair # the two shift rasters for this iteration
            edge_out = os.path.join(out_workspace, "edge_{}.tif".format(n))
            edge_out = utils.validate_path(edge_out)
            temp_rasts.append(edge_out)
            edge = compute_edge(shift_rasts[i], shift_rasts[j], cell_size)
            edge.save(edge_out)
            edge_rasts.append(arcpy.sa.Raster(edge_out))

        areas = [] # areas of each triangle
        for (n, pair) in enumerate(adjacent_shift, start=1):
            utils.msg("Calculating Triangle Area {} of 8...".format(n))
            (i, j) = pair # the two shift rasters; n has the third side
            area_out = os.path.join(out_workspace, "area_{}.tif".format(n))
            area_out = utils.validate_path(area_out)
            temp_rasts.append(area_out)

            area = triangle_area(edge_rasts[i], edge_rasts[j], edge_rasts[n+8])
            area.save(area_out)
            areas.append(arcpy.sa.Raster(area_out))

        utils.msg("Summing Triangle Area...")
        # Re-enable the caller's pyramid setting and statistics for the
        # final outputs.
        arcpy.env.pyramid = pyramid_orig
        arcpy.env.rasterStatistics = "STATISTICS"
        total_area = (areas[0] + areas[1] + areas[2] + areas[3] +
                      areas[4] + areas[5] + areas[6] + areas[7])
        if area_raster:
            utils.msg("Saving Surface Area Raster to {}.".format(area_raster))
            total_area.save(area_raster)

        area_ratio = total_area / flat_area

        out_raster = utils.validate_path(out_raster)
        utils.msg("Saving Surface Area to Planar Area ratio to {}.".format(out_raster))
        area_ratio.save(out_raster)

    except Exception as e:
        utils.msg(e, mtype='error')
    finally:
        # Do not leak the temporary "NONE" pyramid setting when a step fails
        # before the restore above is reached.
        arcpy.env.pyramid = pyramid_orig

    try:
        # Delete all intermediate raster data sets
        utils.msg("Deleting intermediate data...")
        for path in temp_rasts:
            arcpy.Delete_management(path)

    except Exception as e:
        utils.msg(e, mtype='error')
Пример #20
0
def save_token_set(idx_to_token, token_dir, pretty=True):
    """Serialize ``idx_to_token`` as JSON inside ``token_dir``.

    The file name comes from the module-level ``token_set_name``. With
    ``pretty`` true the JSON is indented by two spaces, otherwise no indent
    is requested.
    """
    destination = os.path.join(token_dir, token_set_name)
    validate_path(destination, is_dir=False)
    indent = 2 if pretty else None
    with open(destination, 'w') as handle:
        json.dump(idx_to_token, handle, indent=indent)
Пример #21
0
def main(classification_file, bpi_broad_std, bpi_fine_std, slope, bathy,
    out_raster=None, mode='toolbox'):
    """Build a classified benthic terrain raster from a classification file.

    For every class in ``classification_file`` a conditional grid is chained
    over depth, slope, fine-scale BPI and broad-scale BPI bounds with
    ``run_con``. The per-class grids are merged and saved to ``out_raster``;
    the intermediate grids are deleted afterwards.

    Args:
        classification_file: document parsed by ``utils.BtmDocument``.
        bpi_broad_std, bpi_fine_std, slope, bathy: input rasters handed to
            ``run_con`` for the corresponding bounds.
        out_raster: output path; its directory is used as workspace.
        mode: accepted but not used by this function.

    Raises:
        ValueError: re-raised unchanged when raised (as exactly that type)
            during processing. Other errors are reported through
            ``utils.msg(..., mtype='error')``.
    """
    # Created before the try block so the cleanup step below never hits an
    # undefined name when processing fails early.
    grids = []

    try:
        # set up scratch workspace
        # FIXME: see issue #18
        # CON is very very picky. it generates GRID outputs by default, and the
        # resulting names must not exist. for now, push our temp results
        # to the output folder.
        out_workspace = os.path.dirname(out_raster)
        # make sure workspace exists
        utils.workspace_exists(out_workspace)
        arcpy.env.scratchWorkspace = out_workspace
        arcpy.env.workspace = out_workspace

        arcpy.env.overwriteOutput = True
        msg_text = "Generating the classified grid, based on the provided" + \
                " classes in '{classes}'.".format(classes=classification_file)
        utils.msg(msg_text)

        # Read in the BTM Document; the class handles parsing a variety of inputs.
        btm_doc = utils.BtmDocument(classification_file)
        classes = btm_doc.classification()
        utils.msg("Parsing {} document... found {} classes.".format(
            btm_doc.doctype, len(classes)))

        for item in classes:
            cur_class = str(item["Class"])
            cur_name = str(item["Zone"])
            utils.msg("Calculating grid for {}...".format(cur_name))
            # here come the CONs:
            out_con = run_con(item["Depth_LowerBounds"], item["Depth_UpperBounds"],
                              bathy, cur_class)
            out_con2 = run_con(item["Slope_LowerBounds"], item["Slope_UpperBounds"],
                               slope, out_con, cur_class)
            out_con3 = run_con(item["LSB_LowerBounds"], item["LSB_UpperBounds"],
                               bpi_fine_std, out_con2, cur_class)
            out_con4 = run_con(item["SSB_LowerBounds"], item["SSB_UpperBounds"],
                               bpi_broad_std, out_con3, cur_class)

            if isinstance(out_con4, arcpy.sa.Raster):
                rast = utils.save_raster(out_con4, "con_{}.tif".format(cur_name))
                grids.append(rast)
            else:
                # fall-through: no valid values detected for this class.
                warn_msg = "WARNING, no valid locations found for class" + \
                        " {}:\n".format(cur_name)
                classifications = {
                        'depth': (item["Depth_LowerBounds"], item["Depth_UpperBounds"]),
                        'slope': (item["Slope_LowerBounds"], item["Slope_UpperBounds"]),
                        'broad': (item["SSB_LowerBounds"], item["SSB_UpperBounds"]),
                        'fine': (item["LSB_LowerBounds"], item["LSB_UpperBounds"])}
                for (name, vrange) in classifications.items():
                    (vmin, vmax) = vrange
                    # Report the range when either bound is set. The former
                    # ``vmin or vmax is not None`` parsed as
                    # ``vmin or (vmax is not None)`` and dropped a lower
                    # bound of 0 that had no upper bound.
                    if vmin is not None or vmax is not None:
                        if vmin is None:
                            vmin = ""
                        if vmax is None:
                            vmax = ""
                        warn_msg += "  {}: {{{}:{}}}\n".format(name, vmin, vmax)

                utils.msg(textwrap.dedent(warn_msg))

        if len(grids) == 0:
            raise NoValidClasses

        utils.msg("Creating Benthic Terrain Classification Dataset...")
        merge_grid = grids[0]
        for i in range(1, len(grids)):
            utils.msg("{} of {}".format(i, len(grids)-1))
            merge_grid = Con(merge_grid, grids[i], merge_grid, "VALUE = 0")

        arcpy.env.rasterStatistics = "STATISTICS"
        # validate the output raster path
        out_raster = utils.validate_path(out_raster)
        utils.msg("Saving Output to {}".format(out_raster))
        merge_grid.save(out_raster)
        utils.msg("Complete.")

    except NoValidClasses as e:
        utils.msg(e, mtype='error')
    except Exception as e:
        # Only an exact ValueError is passed on to the caller; everything
        # else (including its subclasses) is reported as an error message.
        if type(e) is ValueError:
            raise
        utils.msg(e, mtype='error')

    try:
        utils.msg("Deleting intermediate data...")
        # Delete all intermediate raster data sets
        for grid in grids:
            arcpy.Delete_management(grid.catalogPath)
    except Exception:
        # hack -- swallowing this exception, because sometimes it seems like refs are left around
        # for these files.
        utils.msg("WARNING: failed to delete all intermediate data.")
Пример #22
0
def evaluate_by_districts(url,
                          url2,
                          stride=2,
                          encoder_length=72,
                          decoder_length=24,
                          forecast_factor=0,
                          is_classify=False,
                          confusion_title="",
                          norm=True,
                          is_grid=True,
                          offset=48,
                          agg=True):
    """Evaluate stored predictions against labels, district by district.

    Predictions are loaded from ``url`` and labels from ``url2``. For each
    prediction window the matching label slice starts at
    ``i * stride + encoder_length + offset`` and is ``decoder_length`` long.

    * ``is_classify`` false: the mean absolute error per timestep is
      accumulated (values are scaled by 300 and converted with
      ``Crawling.ConcPM25`` when ``forecast_factor`` is falsy, otherwise
      ``ConcPM10``) and printed, optionally aggregated per day (``agg``).
    * ``is_classify`` true and no ``confusion_title``: accuracy is printed.
    * ``is_classify`` true with ``confusion_title``: the accumulated result
      is saved under ``results/confusion/`` and drawn with
      ``draw_confusion_matrix`` (``norm`` is forwarded).

    ``is_grid`` selects whether each timestep is first aggregated per
    district with ``aggregate_predictions``.
    """
    print(encoder_length, decoder_length, offset)
    # validate_path is used as a truth test here; presumably it reports
    # whether the cached district index exists -- confirm against utils.
    if not utils.validate_path("district_idx.pkl"):
        districts = convert_coordinate_to_idx()
    else:
        districts = utils.load_file("district_idx.pkl")
    data = utils.load_file(url)
    print(np.shape(data))
    if type(data) is list:
        data = np.asarray(data)
    # Collapse the input to (windows, timesteps, features).
    if len(data.shape) == 4:
        lt = data.shape[0] * data.shape[1]
        data = np.reshape(data, (lt, data.shape[-2], data.shape[-1]))
    else:
        lt = data.shape[0]
        data = np.reshape(data, (lt, data.shape[-2], data.shape[-1]))
    st_h = 0
    if agg:
        days = int(math.ceil(data.shape[1] / 24.0))
        if days > 2:
            # Start at the last day when more than two days are present.
            st_h = (days - 1) * 24
    labels = utils.load_file(url2)
    labels = np.asarray(labels)
    if not is_classify:
        loss_mae = [0.0] * data.shape[1]
    elif not confusion_title:
        acc = 0.
    else:
        acc = None
    cr = Crawling()
    for i, d in enumerate(data):
        if not is_grid:
            d = d[st_h:decoder_length]
        else:
            d = d[st_h:decoder_length, :]
        lb_i = i * stride + encoder_length
        lbt = labels[(lb_i + offset):(lb_i + offset + decoder_length), :,
                     forecast_factor]
        if not confusion_title:
            a = 0.
        else:
            a = None
        for t_i, (t, l_t) in enumerate(zip(d, lbt)):
            t_i += st_h
            if is_grid:
                pred_t = aggregate_predictions(districts, t)
                pred_t = np.array(pred_t)
            else:
                pred_t = t
            pred_t = pred_t.flatten()
            if not is_classify:
                if not forecast_factor:
                    mae = mean_absolute_error(
                        [cr.ConcPM25(x * 300) for x in pred_t],
                        [cr.ConcPM25(x * 300) for x in l_t])
                    loss_mae[t_i] += mae
                else:
                    mae = mean_absolute_error(
                        [cr.ConcPM10(x * 300) for x in pred_t],
                        [cr.ConcPM10(x * 300) for x in l_t])
                    loss_mae[t_i] += mae
            elif not confusion_title:
                a += classify_data(pred_t, l_t, forecast_factor, tp="G")
            elif a is None:
                a = classify_data(pred_t, l_t, forecast_factor, True, tp="G")
            else:
                a += classify_data(pred_t, l_t, forecast_factor, True, tp="G")
        if is_classify:
            a = a / decoder_length
            if not confusion_title:
                acc += a
            elif acc is None:
                acc = a
            else:
                acc += a
        utils.update_progress((i + 1.0) / lt)
    if not is_classify:
        # print mae loss score
        # calculate loss for each timestep
        loss_mae = np.array(loss_mae) / lt
        # calculate accumulated loss
        if agg:
            no_h = len(loss_mae)
            if no_h > 24:
                print("hourly errors", loss_mae[:24])
                days = math.ceil(no_h * 1.0 / 24)
                # ``range`` instead of ``xrange``: the latter does not exist
                # on Python 3, and the loop is only a few iterations long.
                for x in range(1, int(days)):
                    ed = (x + 1) * 24
                    if ed > no_h:
                        ed = no_h
                    print("day %i" % (x + 1), np.mean(loss_mae[x * 24:ed]))
            else:
                print(loss_mae)
        else:
            print(loss_mae)
    elif not confusion_title:
        # print classification score
        acc = acc / lt * 100
        print("accuracy %.4f" % acc)
    else:
        name = url.split("/")[-1]
        # print confusion matrix
        utils.save_file("results/confusion/confusion_%s" % name, acc)
        draw_confusion_matrix(acc, confusion_title, norm)