def longest_path(self):
    """Return the longest chain of nodes through the model's dot graph.

    The wrapped model is rendered with ``model_to_dot``, converted into a
    networkx graph via ``nx_pydot.from_pydot`` and handed to
    ``dag_longest_path``.

    Returns:
        The node list produced by ``networkx.algorithms.dag.dag_longest_path``.

    Raises:
        RuntimeError: if ``self._model`` is ``None`` (no model created yet).
    """
    keras_model = self._model
    if keras_model is None:
        raise RuntimeError(
            "Can't compute longest path of model without creating a model."
        )
    dot_graph = model_to_dot(keras_model)
    graph = nx.drawing.nx_pydot.from_pydot(dot_graph)
    return nx.algorithms.dag.dag_longest_path(graph)
Esempio n. 2
0
    def test_model(self):
        """st.write on a Keras model emits its dot source via graphviz_chart."""
        keras_model = Sequential()
        for layer in (
            Conv2D(10, (5, 5), input_shape=(28, 28, 1), activation="relu"),
            MaxPooling2D(pool_size=(2, 2)),
            Flatten(),
            Dense(8, activation="relu"),
        ):
            keras_model.add(layer)

        with patch("streamlit.graphviz_chart") as chart_mock:
            st.write(keras_model)

            # The chart must receive exactly the dot string for this model.
            expected = vis_utils.model_to_dot(keras_model).to_string()
            chart_mock.assert_called_once_with(expected)
Esempio n. 3
0
def plot_model_recursive(model, title=None, exclude_models_by_name=None):
    """Plot ``model`` and every nested sub-model as matplotlib figures.

    Each model is rendered with ``model_to_dot(..., show_shapes=True)`` and
    rasterised to PNG by Graphviz ``dot`` in memory (nothing is written to
    disk). The result is shown in its own non-blocking figure. Layers that
    are themselves ``Model`` instances are then plotted recursively.

    Args:
        model: The model to plot.
        title: Figure title; defaults to ``'Model <name>'``. Nested models
            always use the default title.
        exclude_models_by_name: Optional collection of model names to skip.
            If a list is given, it is extended in place with the name of every
            model plotted. This is how the recursion avoids plotting a
            sub-model twice. A single string, or any other iterable (tuple,
            set, ...), is copied into a new list instead of being ignored.

    Returns:
        None.
    """
    # Normalise the exclusion collection. A caller's list is still shared and
    # extended in place; other inputs used to be silently dropped.
    if exclude_models_by_name is None:
        exclude_models_by_name = []
    elif isinstance(exclude_models_by_name, str):
        exclude_models_by_name = [exclude_models_by_name]
    elif not isinstance(exclude_models_by_name, list):
        exclude_models_by_name = list(exclude_models_by_name)

    if model.name in exclude_models_by_name:
        return
    # Remember this model so later recursion branches skip it.
    exclude_models_by_name.append(model.name)

    if title is None:
        title = 'Model %s' % model.name

    # Render pydot by calling dot; no file is saved to disk.
    png_bytes = model_to_dot(model, show_shapes=True).create_png(prog='dot')

    # Treat the dot output bytes as an in-memory image file.
    img = mpimg.imread(BytesIO(png_bytes))

    # Set the actual size of the image on the plot. The figure must use the
    # same dpi that figsize is computed with. shape[:2] also handles
    # single-channel (2-D) images, where a 3-way unpack would fail.
    dpi = 80
    height, width = img.shape[:2]
    figsize = width / float(dpi), height / float(dpi)
    plt.figure(figsize=figsize, dpi=dpi)

    # Plot the image.
    plt.imshow(img)
    plt.axis('off')
    plt.title(title)
    plt.show(block=False)

    # Recurse into nested models, sharing the exclusion list.
    for layer in model.layers:
        if isinstance(layer, Model):
            plot_model_recursive(layer,
                                 exclude_models_by_name=exclude_models_by_name)
Esempio n. 4
0
def write(*args, **kwargs):
    """Write arguments to the app.

    This is the swiss-army knife of Streamlit commands. It does different
    things depending on what you throw at it.

    Unlike other Streamlit commands, write() has some unique properties:

        1. You can pass in multiple arguments, all of which will be written.
        2. Its behavior depends on the input types as follows.
        3. It returns None, so its "slot" in the App cannot be reused.

    Parameters
    ----------
    *args : any
        One or many objects to print to the App.

        Arguments are handled as follows:

            - write(string)     : Prints the formatted Markdown string.
            - write(data_frame) : Displays the DataFrame as a table.
            - write(error)      : Prints an exception specially.
            - write(func)       : Displays information about a function.
            - write(module)     : Displays information about the module.
            - write(dict)       : Displays dict in an interactive widget.
            - write(list)       : Displays list in an interactive widget.
            - write(namedtuple) : Displays the namedtuple's fields as JSON.
            - write(obj)        : The default is to print str(obj).
            - write(mpl_fig)    : Displays a Matplotlib figure.
            - write(altair)     : Displays an Altair chart.
            - write(keras)      : Displays a Keras model.
            - write(graphviz)   : Displays a Graphviz graph.
            - write(plotly_fig) : Displays a Plotly figure.
            - write(bokeh_fig)  : Displays a Bokeh figure.

    unsafe_allow_html : bool
        This is a keyword-only argument that defaults to False.

        By default, any HTML tags found in strings will be escaped and
        therefore treated as pure text. This behavior may be turned off by
        setting this argument to True.

        That said, *we strongly advise against it*. It is hard to write secure
        HTML, so by using this argument you may be compromising your users'
        security. For more information, see:

        https://github.com/streamlit/streamlit/issues/152

        *Also note that `unsafe_allow_html` is a temporary measure and may be
        removed from Streamlit at any time.*

        If you decide to turn on HTML anyway, we ask you to please tell us your
        exact use case here:

        https://discuss.streamlit.io/t/96

        This will help us come up with safe APIs that allow you to do what you
        want.

    Example
    -------

    Its simplest use case is to draw Markdown-formatted text, whenever the
    input is a string:

    >>> st.write('Hello, *World!*')

    .. output::
       https://share.streamlit.io/0.25.0-2JkNY/index.html?id=DUJaq97ZQGiVAFi6YvnihF
       height: 50px

    As mentioned earlier, `st.write()` also accepts other data formats, such as
    numbers, data frames, styled data frames, and assorted objects:

    >>> st.write(1234)
    >>> st.write(pd.DataFrame({
    ...     'first column': [1, 2, 3, 4],
    ...     'second column': [10, 20, 30, 40],
    ... }))

    .. output::
       https://share.streamlit.io/0.25.0-2JkNY/index.html?id=FCp9AMJHwHRsWSiqMgUZGD
       height: 250px

    Finally, you can pass in multiple arguments to do things like:

    >>> st.write('1 + 1 = ', 2)
    >>> st.write('Below is a DataFrame:', data_frame, 'Above is a dataframe.')

    .. output::
       https://share.streamlit.io/0.25.0-2JkNY/index.html?id=DHkcU72sxYcGarkFbf4kK1
       height: 300px

    Oh, one more thing: `st.write` accepts chart objects too! For example:

    >>> import pandas as pd
    >>> import numpy as np
    >>> import altair as alt
    >>>
    >>> df = pd.DataFrame(
    ...     np.random.randn(200, 3),
    ...     columns=['a', 'b', 'c'])
    ...
    >>> c = alt.Chart(df).mark_circle().encode(
    ...     x='a', y='b', size='c', color='c')
    >>>
    >>> st.write(c)

    .. output::
       https://share.streamlit.io/0.25.0-2JkNY/index.html?id=8jmmXR8iKoZGV4kXaKGYV5
       height: 200px

    """
    # Python2 doesn't support this syntax
    #   def write(*args, unsafe_allow_html=False)
    # so we do this instead:
    unsafe_allow_html = kwargs.get('unsafe_allow_html', False)

    try:
        # Consecutive strings (and fallback str(obj) renderings) accumulate
        # here and are emitted together as one markdown element.
        string_buffer = []

        def flush_buffer():
            # Emit pending text before any non-text element, so output order
            # follows argument order.
            if string_buffer:
                markdown(" ".join(string_buffer),  # noqa: F821
                         unsafe_allow_html=unsafe_allow_html)
                string_buffer[:] = []

        for arg in args:
            # Order matters! The first matching branch handles the argument.
            if isinstance(arg, string_types):  # noqa: F821
                string_buffer.append(arg)
            elif type(arg).__name__ in _DATAFRAME_LIKE_TYPES:
                flush_buffer()
                # Data with more than two dimensions cannot be shown as a
                # table, so it is written as text instead.
                if len(_np.shape(arg)) > 2:
                    text(arg)
                else:
                    dataframe(arg)  # noqa: F821
            elif isinstance(arg, Exception):
                flush_buffer()
                exception(arg)  # noqa: F821
            elif isinstance(arg, _HELP_TYPES):
                flush_buffer()
                help(arg)
            elif _util.is_altair_chart(arg):
                flush_buffer()
                altair_chart(arg)
            elif _util.is_type(arg, "matplotlib.figure.Figure"):
                flush_buffer()
                pyplot(arg)
            elif _util.is_plotly_chart(arg):
                flush_buffer()
                plotly_chart(arg)
            elif _util.is_type(arg, "bokeh.plotting.figure.Figure"):
                flush_buffer()
                bokeh_chart(arg)
            elif _util.is_graphviz_chart(arg):
                flush_buffer()
                graphviz_chart(arg)
            elif _util.is_keras_model(arg):
                # Imported lazily so tensorflow is only needed for Keras input.
                from tensorflow.python.keras.utils import vis_utils

                flush_buffer()
                dot = vis_utils.model_to_dot(arg)
                graphviz_chart(dot.to_string())
            elif type(arg) in dict_types or isinstance(arg, list):  # noqa: F821
                flush_buffer()
                json(arg)
            elif _util.is_namedtuple(arg):
                flush_buffer()
                json(_json.dumps(arg._asdict()))
            else:
                # Fallback: show str(arg) as inline code, escaping backticks.
                string_buffer.append("`%s`" % str(arg).replace("`", "\\`"))

        flush_buffer()

    except Exception:
        # Any failure is rendered in the app rather than propagated.
        _, exc, exc_tb = _sys.exc_info()
        exception(exc, exc_tb)  # noqa: F821
Esempio n. 5
0
def write(*args):
    """Write arguments to the report.

    This is the swiss-army knife of Streamlit commands. It does different
    things depending on what you throw at it.

    Unlike other Streamlit commands, write() has some unique properties:

        1. You can pass in multiple arguments, all of which will be written.
        2. Its behavior depends on the input types as follows.
        3. It returns None, so its "slot" in the report cannot be reused.

    Parameters
    ----------
    *args : any
        One or many objects to print to the Report.

    Arguments are handled as follows:

        - write(string)     : Prints the formatted Markdown string.
        - write(data_frame) : Displays the DataFrame as a table.
        - write(error)      : Prints an exception specially.
        - write(func)       : Displays information about a function.
        - write(module)     : Displays information about the module.
        - write(dict)       : Displays dict in an interactive widget.
        - write(obj)        : The default is to print str(obj).
        - write(mpl_fig)    : Displays a Matplotlib figure.
        - write(altair)     : Displays an Altair chart.
        - write(keras)      : Displays a Keras model.
        - write(graphviz)   : Displays a Graphviz graph.
        - write(plotly_fig) : Displays a Plotly figure.
        - write(bokeh_fig)  : Displays a Bokeh figure.

    Example
    -------

    Its simplest use case is to draw Markdown-formatted text, whenever the
    input is a string:

    >>> st.write('Hello, *World!*')

    .. output::
       https://share.streamlit.io/0.25.0-2JkNY/index.html?id=DUJaq97ZQGiVAFi6YvnihF
       height: 50px

    As mentioned earlier, `st.write()` also accepts other data formats, such as
    numbers, data frames, styled data frames, and assorted objects:

    >>> st.write(1234)
    >>> st.write(pd.DataFrame({
    ...     'first column': [1, 2, 3, 4],
    ...     'second column': [10, 20, 30, 40],
    ... }))

    .. output::
       https://share.streamlit.io/0.25.0-2JkNY/index.html?id=FCp9AMJHwHRsWSiqMgUZGD
       height: 250px

    Finally, you can pass in multiple arguments to do things like:

    >>> st.write('1 + 1 = ', 2)
    >>> st.write('Below is a DataFrame:', data_frame, 'Above is a dataframe.')

    .. output::
       https://share.streamlit.io/0.25.0-2JkNY/index.html?id=DHkcU72sxYcGarkFbf4kK1
       height: 300px

    Oh, one more thing: `st.write` accepts chart objects too! For example:

    >>> import pandas as pd
    >>> import numpy as np
    >>> import altair as alt
    >>>
    >>> df = pd.DataFrame(
    ...     np.random.randn(200, 3),
    ...     columns=['a', 'b', 'c'])
    ...
    >>> c = alt.Chart(df).mark_circle().encode(
    ...     x='a', y='b', size='c', color='c')
    >>>
    >>> st.write(c)

    .. output::
       https://share.streamlit.io/0.25.0-2JkNY/index.html?id=8jmmXR8iKoZGV4kXaKGYV5
       height: 200px

    """
    try:
        # Consecutive strings (and fallback str(obj) renderings) accumulate
        # here and are emitted together as one markdown element.
        string_buffer = []

        def flush_buffer():
            # Emit pending text before any non-text element, so output order
            # follows argument order.
            if string_buffer:
                markdown(' '.join(string_buffer))  # noqa: F821
                string_buffer[:] = []

        for arg in args:
            # Order matters! The first matching branch handles the argument.
            if isinstance(arg, string_types):  # noqa: F821
                string_buffer.append(arg)
            elif type(arg).__name__ in _DATAFRAME_LIKE_TYPES:
                flush_buffer()
                dataframe(arg)  # noqa: F821
            elif isinstance(arg, Exception):
                flush_buffer()
                exception(arg)  # noqa: F821
            elif isinstance(arg, _HELP_TYPES):
                flush_buffer()
                help(arg)
            elif _util.is_altair_chart(arg):
                flush_buffer()
                altair_chart(arg)
            elif _util.is_type(arg, 'matplotlib.figure.Figure'):
                flush_buffer()
                pyplot(arg)
            elif _util.is_plotly_chart(arg):
                flush_buffer()
                plotly_chart(arg)
            elif _util.is_type(arg, 'bokeh.plotting.figure.Figure'):
                flush_buffer()
                bokeh_chart(arg)
            elif _util.is_graphviz_chart(arg):
                flush_buffer()
                graphviz_chart(arg)
            # Uses `_util` like every other branch. The bare `util` name made
            # every argument that reached this check (ints, dicts, plain
            # objects) fail with NameError.
            elif _util.is_keras_model(arg):
                # Imported lazily so tensorflow is only needed for Keras input.
                from tensorflow.python.keras.utils import vis_utils
                flush_buffer()
                dot = vis_utils.model_to_dot(arg)
                graphviz_chart(dot.to_string())
            elif type(arg) in dict_types:  # noqa: F821
                flush_buffer()
                json(arg)
            else:
                # Fallback: show str(arg) as inline code, escaping backticks.
                string_buffer.append('`%s`' % str(arg).replace('`', '\\`'))

        flush_buffer()

    except Exception:
        # Any failure is rendered in the report rather than propagated.
        _, exc, exc_tb = _sys.exc_info()
        exception(exc, exc_tb)  # noqa: F821
Esempio n. 6
0
    def write(self, *args, **kwargs):
        """Write arguments to the app.

        This is the Swiss Army knife of Streamlit commands: it does different
        things depending on what you throw at it. Unlike other Streamlit commands,
        write() has some unique properties:

        1. You can pass in multiple arguments, all of which will be written.
        2. Its behavior depends on the input types as follows.
        3. It returns None, so its "slot" in the App cannot be reused.

        Parameters
        ----------
        *args : any
            One or many objects to print to the App.

            Arguments are handled as follows:

            - write(string)     : Prints the formatted Markdown string, with
                support for LaTeX expression and emoji shortcodes.
                See docs for st.markdown for more.
            - write(data_frame) : Displays the DataFrame as a table.
            - write(error)      : Prints an exception specially.
            - write(func)       : Displays information about a function.
            - write(module)     : Displays information about the module.
            - write(dict)       : Displays dict in an interactive widget.
            - write(list)       : Displays list in an interactive widget.
            - write(namedtuple) : Displays the namedtuple's fields as JSON.
            - write(obj)        : The default is to print str(obj).
            - write(mpl_fig)    : Displays a Matplotlib figure.
            - write(altair)     : Displays an Altair chart.
            - write(keras)      : Displays a Keras model.
            - write(graphviz)   : Displays a Graphviz graph.
            - write(plotly_fig) : Displays a Plotly figure.
            - write(bokeh_fig)  : Displays a Bokeh figure.
            - write(sympy_expr) : Prints SymPy expression using LaTeX.
            - write(pydeck_obj) : Displays a pydeck chart.

        unsafe_allow_html : bool
            This is a keyword-only argument that defaults to False.

            By default, any HTML tags found in strings will be escaped and
            therefore treated as pure text. This behavior may be turned off by
            setting this argument to True.

            That said, *we strongly advise against it*. It is hard to write secure
            HTML, so by using this argument you may be compromising your users'
            security. For more information, see:

            https://github.com/streamlit/streamlit/issues/152

            **Also note that `unsafe_allow_html` is a temporary measure and may be
            removed from Streamlit at any time.**

            If you decide to turn on HTML anyway, we ask you to please tell us your
            exact use case here:
            https://discuss.streamlit.io/t/96 .

            This will help us come up with safe APIs that allow you to do what you
            want.

        Raises
        ------
        StreamlitAPIException
            If more than one argument is passed when this element is not at the
            top level (i.e. when replacing a single element).

        Example
        -------

        Its basic use case is to draw Markdown-formatted text, whenever the
        input is a string:

        >>> st.write('Hello, *World!* :sunglasses:')

        ..  output::
            https://static.streamlit.io/0.50.2-ZWk9/index.html?id=Pn5sjhgNs4a8ZbiUoSTRxE
            height: 50px

        As mentioned earlier, `st.write()` also accepts other data formats, such as
        numbers, data frames, styled data frames, and assorted objects:

        >>> st.write(1234)
        >>> st.write(pd.DataFrame({
        ...     'first column': [1, 2, 3, 4],
        ...     'second column': [10, 20, 30, 40],
        ... }))

        ..  output::
            https://static.streamlit.io/0.25.0-2JkNY/index.html?id=FCp9AMJHwHRsWSiqMgUZGD
            height: 250px

        Finally, you can pass in multiple arguments to do things like:

        >>> st.write('1 + 1 = ', 2)
        >>> st.write('Below is a DataFrame:', data_frame, 'Above is a dataframe.')

        ..  output::
            https://static.streamlit.io/0.25.0-2JkNY/index.html?id=DHkcU72sxYcGarkFbf4kK1
            height: 300px

        Oh, one more thing: `st.write` accepts chart objects too! For example:

        >>> import pandas as pd
        >>> import numpy as np
        >>> import altair as alt
        >>>
        >>> df = pd.DataFrame(
        ...     np.random.randn(200, 3),
        ...     columns=['a', 'b', 'c'])
        ...
        >>> c = alt.Chart(df).mark_circle().encode(
        ...     x='a', y='b', size='c', color='c', tooltip=['a', 'b', 'c'])
        >>>
        >>> st.write(c)

        ..  output::
            https://static.streamlit.io/0.25.0-2JkNY/index.html?id=8jmmXR8iKoZGV4kXaKGYV5
            height: 200px

        """
        # Consecutive strings (and fallback str(obj) renderings) accumulate
        # here and are emitted together as one markdown element.
        string_buffer = []  # type: List[str]
        unsafe_allow_html = kwargs.get("unsafe_allow_html", False)

        # This bans some valid cases like: e = st.empty(); e.write("a", "b").
        # BUT: 1) such cases are rare, 2) this rule is easy to understand,
        # and 3) this rule should be removed once we have st.container()
        if not self.dg._is_top_level and len(args) > 1:
            raise StreamlitAPIException(
                "Cannot replace a single element with multiple elements.\n\n"
                "The `write()` method only supports multiple elements when "
                "inserting elements rather than replacing. That is, only "
                "when called as `st.write()` or `st.sidebar.write()`.")

        def flush_buffer():
            # Emit pending text before any non-text element, so output order
            # follows argument order.
            if string_buffer:
                self.dg.markdown(
                    " ".join(string_buffer),
                    unsafe_allow_html=unsafe_allow_html,
                )
                string_buffer[:] = []

        for arg in args:
            # Order matters! The first matching branch handles the argument.
            if isinstance(arg, str):
                string_buffer.append(arg)
            elif type_util.is_dataframe_like(arg):
                flush_buffer()
                # Data with more than two dimensions cannot be shown as a
                # table, so it is written as text instead.
                if len(np.shape(arg)) > 2:
                    self.dg.text(arg)
                else:
                    self.dg.dataframe(arg)
            elif isinstance(arg, Exception):
                flush_buffer()
                self.dg.exception(arg)
            elif isinstance(arg, HELP_TYPES):
                flush_buffer()
                self.dg.help(arg)
            elif type_util.is_altair_chart(arg):
                flush_buffer()
                self.dg.altair_chart(arg)
            elif type_util.is_type(arg, "matplotlib.figure.Figure"):
                flush_buffer()
                self.dg.pyplot(arg)
            elif type_util.is_plotly_chart(arg):
                flush_buffer()
                self.dg.plotly_chart(arg)
            elif type_util.is_type(arg, "bokeh.plotting.figure.Figure"):
                flush_buffer()
                self.dg.bokeh_chart(arg)
            elif type_util.is_graphviz_chart(arg):
                flush_buffer()
                self.dg.graphviz_chart(arg)
            # NOTE(review): the helper name is spelled "expession" here; confirm
            # against type_util before renaming either side.
            elif type_util.is_sympy_expession(arg):
                flush_buffer()
                self.dg.latex(arg)
            elif type_util.is_keras_model(arg):
                # Imported lazily so tensorflow is only needed for Keras input.
                from tensorflow.python.keras.utils import vis_utils

                flush_buffer()
                dot = vis_utils.model_to_dot(arg)
                self.dg.graphviz_chart(dot.to_string())
            elif isinstance(arg, (dict, list)):
                flush_buffer()
                self.dg.json(arg)
            elif type_util.is_namedtuple(arg):
                flush_buffer()
                self.dg.json(json.dumps(arg._asdict()))
            elif type_util.is_pydeck(arg):
                flush_buffer()
                self.dg.pydeck_chart(arg)
            else:
                # Fallback: show str(arg) as inline code, escaping backticks.
                string_buffer.append("`%s`" % str(arg).replace("`", "\\`"))

        flush_buffer()
                       kernel_regularizer=tf.keras.regularizers.l2(1e-4),
                       model_name="fashion_mnist_test")

# Training configuration for keras fit (optimizer, loss, metrics, callbacks).
# `learning_rate` is the supported keyword; the old `lr` alias is deprecated
# and rejected by the newer tf.keras optimizers.
OPTIMIZER = tf.keras.optimizers.SGD(learning_rate=MAX_LR, momentum=0.9, nesterov=True)

LOSS = tf.keras.losses.categorical_crossentropy
METRICS = [
    tf.keras.metrics.categorical_accuracy,
]

CALLBACKS = []

# get visualisation
# NOTE(review): in a notebook only the last expression of a cell is rendered,
# so this SVG is not displayed unless it ends its cell (or is wrapped in
# IPython.display.display).
SVG(
    model_to_dot(keras_model, show_shapes=True).create(prog='dot',
                                                       format='svg'))

# compile
keras_model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS)

# fit
history = keras_model.fit(train_dataset_fn(),
                          epochs=NUM_EPOCHS,
                          verbose=2,
                          callbacks=CALLBACKS,
                          validation_data=test_dataset_fn(),
                          initial_epoch=0,
                          steps_per_epoch=NUM_TRAIN_STEPS,
                          validation_steps=NUM_TEST_STEPS)

# Plot training & validation accuracy values
Esempio n. 8
0
    # preprocessing sentences into sentence vectors
    sentence = Input(shape=(T, embedding_size), name='Sentences') # batch, 50, 300
    sentence_vec = Bidirectional(CuDNNGRU(units=n_a, return_sequences=False), name='Sentence_Vectors')(sentence) # batch, 300
    # dot
    #product = Dot(axes=-1, normalize=False, name='Matrix')([word_vec, sentence_vec])
    product = tf.matmul(word_vec, sentence_vec, transpose_b = True, name = 'Matrix')
    key_matrix = K.transpose(product)
    model = Model(inputs= sentence, outputs=key_matrix)
    return model

# create a model
Factorize = model(embedding_size, n_a)

# get a summary of the model
Factorize.summary()
# rankdir must be a valid Graphviz direction ('TB', 'LR', 'BT', 'RL'); the
# previous 'HB' is not one. 'TB' gives the vertical layout.
# NOTE(review): this SVG is not the last expression of its cell, so a notebook
# will not render it; the plot_model call below writes model.png instead.
SVG(model_to_dot(Factorize, show_shapes=True, show_layer_names=True, rankdir='TB').create(prog='dot', format='svg'))
plot_model(Factorize, to_file='model.png', show_shapes=True, show_layer_names=True)



#%% Run
# NOTE(review): newer Keras releases reject `lr` and require `learning_rate`;
# switch when upgrading (older standalone Keras only accepts `lr`).
Factorize.compile(loss='mean_squared_error', optimizer = Adam(lr=0.001), metrics=['mean_squared_error', 'mean_absolute_error'])
hist = Factorize.fit(X_train_padded, key_matrix_train, batch_size=256, epochs=500, verbose=1, validation_data=(X_test_padded, key_matrix_test))



#%% Loss plot
%matplotlib inline
import matplotlib.pyplot as plt

# accuracy
Esempio n. 9
0
                                                              np.newaxis],
                                                    to_categorical(y_dataset),
                                                    test_size=0.2,
                                                    random_state=42)

hid_dim = 10

# Classify by connecting a Dense layer on top of a SimpleRNN.
model = Sequential()

# input_shape = (sequence length T, dimension of x_t);
# output shape = (units (= hid_dim),)
model.add(SimpleRNN(hid_dim, input_shape=x_train.shape[1:]))
model.add(Dense(y_train.shape[1], activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(
    x_train, y_train,
    epochs=50, batch_size=100, verbose=2,
    validation_split=0.2,
)

# Evaluate on the held-out test split.
score = model.evaluate(x_test, y_test, verbose=0)
print('test_loss:', score[0])
print('test_acc:', score[1])

SVG(model_to_dot(model).create(prog='dot', format='svg'))
Esempio n. 10
0
#Model

# Named `inputs` rather than `input` so the builtin input() is not shadowed.
inputs = Input(shape=(224, 224, 3))

x = conv_batchnorm_relu(inputs, filters=64, kernel_size=7, strides=2)
x = MaxPool2D(pool_size=3, strides=2)(x)
x = resnet_block(x, filters=64, reps=3, strides=1)
x = resnet_block(x, filters=128, reps=4, strides=2)
x = resnet_block(x, filters=256, reps=6, strides=2)
x = resnet_block(x, filters=512, reps=3, strides=2)
x = GlobalAvgPool2D()(x)

output = Dense(1000, activation='softmax')(x)

model = Model(inputs=inputs, outputs=output)
model.summary()

from tensorflow.python.keras.utils.vis_utils import model_to_dot
from IPython.display import SVG
import pydot
import graphviz

SVG(
    model_to_dot(model,
                 show_shapes=True,
                 show_layer_names=True,
                 rankdir='TB',
                 expand_nested=False,
                 dpi=60,
                 subgraph=False).create(prog='dot', format='svg'))
Esempio n. 11
0
# `loss` is the loss function:
#   for continuous targets, mean squared error is common: loss='mean_squared_error'
#   for discrete targets, cross-entropy is common:
#     binary cross-entropy:      loss='binary_crossentropy'
#     categorical cross-entropy: loss='categorical_crossentropy'
model.compile(
    loss='categorical_crossentropy',
    optimizer='sgd',
    metrics=['accuracy'],
)

# Finally, feed the data in with fit.
model.fit(
    x_train, y_train,
    batch_size=1000, epochs=10, verbose=1,
    validation_data=(x_test, y_test),
)

# Evaluate the model.
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

# To predict on new data:
# classes = model.predict(x_test, batch_size=128)
# print(classes)

# Visualize the model (prints the pydot graph object; presumably its dot
# source via str() — confirm against the installed pydot).
from IPython.display import SVG
from tensorflow.python.keras.utils.vis_utils import model_to_dot
print(model_to_dot(model, show_shapes=True))