Ejemplo n.º 1
0
def build_net(steps=steps_inference, mode="inference"):
    """Construct the U-Net and wrap it into a network via the make_net helpers.

    The UNet is configured entirely from module-level settings
    (feature widths, downsampling factors, kernel sizes, padding, ...).
    When ``voxel_size`` equals ``voxel_size_input`` the UNet is wrapped with
    ``make_net``; otherwise ``make_net_upsample`` is used, additionally given
    ``upsample_factor``, ``final_kernel_size`` and ``final_feature_width``.

    Args:
        steps: Forwarded unchanged to ``make_net`` / ``make_net_upsample``
            (defaults to the module-level ``steps_inference``).
        mode: Forwarded unchanged as ``mode=`` to the chosen helper.

    Returns:
        The ``(net, input_shape, output_shape)`` triple produced by the helper.
    """
    model = unet_class.UNet(
        feature_widths_down,
        feature_widths_up,
        downsampling_factors,
        kernel_sizes_down,
        kernel_sizes_up,
        padding=padding,
        constant_upsample=constant_upsample,
        trans_equivariant=trans_equivariant,
        input_voxel_size=voxel_size,
        # NOTE(review): the field of view is set to the voxel size, as in the
        # original code; confirm this is intended.
        input_fov=voxel_size,
    )

    same_resolution = voxel_size == voxel_size_input
    if same_resolution:
        built = make_net(model, labels, steps, loss_name=loss_name, mode=mode)
    else:
        # Input and output resolutions differ, so use the upsampling variant.
        built = make_net_upsample(
            model,
            labels,
            steps,
            upsample_factor,
            final_kernel_size,
            final_feature_width,
            loss_name=loss_name,
            mode=mode,
        )
    net, input_shape, output_shape = built

    message = "Built {0:} with input shape {1:} and output_shape {2:}".format(
        net, input_shape, output_shape
    )
    logging.info(message)

    return net, input_shape, output_shape
def build_net(mode="inference"):
    """Build the U-Net based network for the given mode.

    For any mode other than ``"training"``, if ``<network_name>_io_names.json``
    does not exist yet, a training build is run first (its log message states
    it is there to generate the io names). The default TF graph is then reset
    before building the requested mode.

    Args:
        mode: ``"inference"`` or ``"forward"`` (uses ``padding_inference`` and
            ``add_context_inference``), or ``"training"`` (uses
            ``padding_train`` and ``add_context_training``).

    Returns:
        The ``(net, input_shape, output_shape)`` triple returned by ``make_net``.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    # Validate the mode before doing any work, so an invalid mode does not
    # first trigger a full (recursive) training build and graph reset.
    if mode in ("inference", "forward"):
        padding = padding_inference
        add_context = add_context_inference
    elif mode == "training":
        padding = padding_train
        add_context = add_context_training
    else:
        raise ValueError("Unknown mode: {0:}".format(mode))

    if mode != "training" and not os.path.exists(
            "{0:}_io_names.json".format(network_name)):
        logging.info("Building mode training first to generate io names")
        build_net(mode="training")
        # Discard the training graph so the requested mode builds on a clean graph.
        tf.reset_default_graph()

    unet = unet_class.UNet(
        feature_widths_down,
        feature_widths_up,
        downsampling_factors,
        kernel_sizes_down,
        kernel_sizes_up,
        skip_connections=skip_connections,
        padding=padding,
        constant_upsample=constant_upsample,
        trans_equivariant=trans_equivariant,
        enforce_even_context=enforce_even_context,
        input_voxel_size=voxel_size,
        # NOTE(review): field of view set to the voxel size, as in the original.
        input_fov=voxel_size,
    )
    net, input_shape, output_shape = make_net(network_name,
                                              unet,
                                              n_out,
                                              add_context,
                                              input_name=input_name,
                                              output_names=output_names,
                                              loss_name=loss_name,
                                              mode=mode)

    logging.info(
        "Built {0:} with input shape {1:} and output_shape {2:}".format(
            net, input_shape, output_shape))

    return net, input_shape, output_shape