Пример #1
0
class NetworkVisualisation(object):
    """Interactive matplotlib view of a small network's layer outputs.

    The figure shows a 2D and a 3D plot for each of the two layers, next to
    sliders and check boxes that edit the entries of ``W1``/``b1``/``W2``/``b2``
    in ``self.network`` in place and redraw the affected plots.

    ``self.default_network`` keeps a copy of the parameters as they were
    right after loading/training; the red marker line of each slider is
    pointed at that value so the user can see how far a weight was moved.
    """

    def __init__(self,
                 units,
                 data_points,
                 min_range,
                 max_range,
                 quality,
                 dataset,
                 saves_path=None,
                 seed=1):
        """Build the dataset, obtain a network and create the figure.

        Args:
            units: Forwarded to the load/train helpers as the network size.
            data_points: Number of points requested from the dataset helper.
            min_range: Lower bound forwarded to the dataset/data-space helpers.
            max_range: Upper bound forwarded to the dataset/data-space helpers.
            quality: Resolution forwarded to ``setup_data_space`` and stored
                as ``self.precision``.
            dataset: ``Dataset.CIRCLE`` or ``Dataset.SPIRAL``.
            saves_path: Optional location of saved networks. When falsy, the
                network is always trained and nothing is loaded or saved.
            seed: Seed passed to ``np.random.seed``.

        Raises:
            ValueError: If ``dataset`` is neither supported dataset type.
        """
        np.random.seed(seed)

        self.precision = quality
        self.min_range = min_range
        self.max_range = max_range
        self.dataset_type = dataset
        if dataset is Dataset.CIRCLE:
            self.dataset = get_circle_dataset(points=data_points,
                                              min_range=min_range,
                                              max_range=max_range,
                                              radius=0.8)
        elif dataset is Dataset.SPIRAL:
            self.dataset = get_spiral_dataset(data_points, classes=2)
        else:
            raise ValueError("Invalid dataset type")
        # The last column holds the label, everything before it the features.
        self.data_points = self.dataset[:, :-1]
        self.data_labels = self.dataset[:, -1]
        self.data_space, self.dim_data = setup_data_space(
            min_range, max_range, quality)

        # Network Creation
        if saves_path and check_saved_network(units, dataset, saves_path):
            self.network = load_network(units, dataset, saves_path)
        else:
            if dataset is Dataset.CIRCLE:
                self.network = train_network_sigmoid(self.dataset,
                                                     units=units,
                                                     learning_rate=5e-3,
                                                     window_size=1000)
            else:
                # ``dataset`` was validated above, so this is Dataset.SPIRAL.
                self.network = train_network_softmax(self.dataset,
                                                     units=units,
                                                     learning_rate=1,
                                                     window_size=1000)
            # Only persist when a save location was given, mirroring the
            # guard used for loading above.
            # NOTE(review): assumes save_network needs a real path - confirm.
            if saves_path:
                save_network(self.network, dataset, saves_path)
        # Snapshot of the untouched parameters (copied so that slider edits
        # to self.network do not leak into it).
        self.default_network = {
            name: layer.copy() for name, layer in self.network.items()
        }

        # GUI Visualisation
        self.perceptron1 = 0
        self.is_relu = True
        self.perceptron2 = 0
        self.connection = 0
        self.is_pre_add = False
        self.is_sig = True
        # Column indices of W1 that are currently switched on.
        self.all_p1_enabled = set(range(self.network["W1"].shape[1]))
        # While True, the update_* entry points return without redrawing.
        self.ignore_update = False

        fig = plt.figure(figsize=(13, 6.5))
        self.plot_network(fig)
        self.plot_controls(fig)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _forward(self):
        """Run ``forward`` on the data grid for the currently enabled units."""
        return forward(self.data_space, self.network, self.dataset_type,
                       self.all_p1_enabled, self.precision)

    def _enabled_index(self):
        """Return the rank of ``perceptron1`` among the sorted enabled units."""
        return sorted(self.all_p1_enabled).index(self.perceptron1)

    def _scatter_dataset(self, ax):
        """Overlay the dataset on ``ax``: label 1 in green, label 0 in red."""
        outer_points = self.data_points[self.data_labels == 1]
        inner_points = self.data_points[self.data_labels == 0]
        for points, colour in ((outer_points, "g"), (inner_points, "r")):
            ax.scatter(points[:, 0], points[:, 1], s=3, c=colour, alpha=0.5)

    def _add_slider(self, fig, grid, cell, span, label, valmin, valmax,
                    valinit, valstep, callback):
        """Create a slider in the given grid cell and hook up ``callback``."""
        slider_ax = plot_to_grid(fig, grid, cell, span)
        slider = Slider(slider_ax,
                        label,
                        valmin=valmin,
                        valmax=valmax,
                        valinit=valinit,
                        valstep=valstep)
        slider.on_changed(callback)
        return slider

    @staticmethod
    def _padded_range(values, padding):
        """Return ``(low, high)`` around ``values`` widened by half its span
        plus ``padding`` on each side."""
        low = values.min()
        high = values.max()
        margin = (high - low) / 2 + padding
        return low - margin, high + margin

    @staticmethod
    def _mark_default(slider, value):
        """Move the slider's marker line to ``value``.

        The marker is a two-point vertical line, so both x coordinates are
        supplied; recent matplotlib versions no longer accept a bare scalar
        in ``Line2D.set_xdata``.
        """
        slider.vline.set_xdata([value, value])

    # ------------------------------------------------------------------
    # Figure construction
    # ------------------------------------------------------------------
    def plot_network(self, fig):
        """Create the 2D/3D plots of both layers on ``fig``."""
        _, out2, _, out4 = forward(self.data_space,
                                   self.network,
                                   self.dataset_type,
                                   precision=self.precision)

        self.layer1_plot = Plot(fig, (4, 4), (0, 0), (1, 3), out2[:, :, 0],
                                self.min_range, self.max_range)
        self._scatter_dataset(self.layer1_plot.ax)

        self.layer1_3d_plot = Plot3D(fig, (4, 4), (1, 0), (1, 3),
                                     self.precision, out2)

        self.layer2_plot = Plot(fig, (4, 4), (2, 0), (1, 3), out4[:, :, 0],
                                self.min_range, self.max_range)
        self._scatter_dataset(self.layer2_plot.ax)

        self.layer2_3d_plot = Plot3D(fig, (4, 4), (3, 0), (1, 3),
                                     self.precision, out4)

    def plot_controls(self, fig):
        """Create the sliders and check boxes that edit the network."""
        step_size = 0.01
        padding = 5
        w1 = self.network["W1"]
        b1 = self.network["b1"]
        w2 = self.network["W2"]
        b2 = self.network["b2"]

        # Plot 1 controls
        w1x_low, w1x_high = self._padded_range(w1[0], padding)
        w1y_low, w1y_high = self._padded_range(w1[1], padding)
        w1b_low, w1b_high = self._padded_range(b1, padding)

        self.p1x_slid = self._add_slider(fig, (2, 16), (0, 12), (1, 1),
                                         'P1 x', w1x_low, w1x_high,
                                         w1[0, 0], step_size,
                                         self.p1x_changed)
        self.p1y_slid = self._add_slider(fig, (2, 16), (0, 13), (1, 1),
                                         'P1 y', w1y_low, w1y_high,
                                         w1[1, 0], step_size,
                                         self.p1y_changed)
        self.p1b_slid = self._add_slider(fig, (24, 16), (0, 14), (7, 1),
                                         'P1 b', w1b_low, w1b_high,
                                         b1[0, 0], step_size,
                                         self.p1b_changed)
        # Selector for which first-layer perceptron the sliders act on.
        self.p1_slid = self._add_slider(fig, (24, 16), (0, 15), (7, 1),
                                        'P1', 0, w1.shape[1] - 1,
                                        self.perceptron1, 1,
                                        self.p1_changed)

        p1_opt_ax = plot_to_grid(fig, (24, 16), (8, 14), (3, 2))
        self.p1_opt_buttons = CheckButtons(p1_opt_ax, ["ReLU?", "Enabled?"],
                                           [self.is_relu, True])
        self.p1_opt_buttons.on_clicked(self.p1_options_update)

        # Plot 2 Controls
        w2_low, w2_high = self._padded_range(w2, padding)

        # The bias slider is centred on the current bias value.
        w2b_abs = np.abs(b2[0, 0]) + padding
        w2b_min = b2[0, 0] - w2b_abs
        w2b_max = b2[0, 0] + w2b_abs

        self.p2_dim_val_slid = self._add_slider(fig, (2, 16), (1, 12),
                                                (1, 1), 'p2 w', w2_low,
                                                w2_high, w2[0, 0],
                                                step_size,
                                                self.p2_weight_changed)
        # Selector for which row of W2 the 'p2 w' slider edits.
        self.p2_connection_dim_slid = self._add_slider(
            fig, (2, 16), (1, 13), (1, 1), 'p2 c', 0, w2.shape[0] - 1, 0, 1,
            self.p2_connection_dim_changed)
        self.p2b_slid = self._add_slider(fig, (24, 16), (13, 14), (7, 1),
                                         'p2 b', w2b_min, w2b_max,
                                         b2[0, 0], step_size,
                                         self.p2b_changed)

        p2_opt_ax = plot_to_grid(fig, (24, 16), (21, 14), (4, 2))
        self.p2_opt_buttons = CheckButtons(p2_opt_ax,
                                           ["Pre-add?", "Transform?"],
                                           [self.is_pre_add, self.is_sig])
        self.p2_opt_buttons.on_clicked(self.p2_options_update)

    # ------------------------------------------------------------------
    # Widget callbacks
    # ------------------------------------------------------------------
    def p1_changed(self, val):
        """Select another first-layer perceptron and refresh its widgets."""
        self.perceptron1 = int(val)
        # Re-syncing the widgets fires their callbacks; suppress the redraws
        # those would cause and redraw once afterwards. ``finally`` makes
        # sure redraws are not left disabled if a widget call fails.
        self.ignore_update = True
        try:
            self.update_widgets()
        finally:
            self.ignore_update = False

        self.update_just_plot1()

    def p1x_changed(self, val):
        """Store ``val`` in row 0 of ``W1`` for the selected perceptron."""
        self.network["W1"][0, self.perceptron1] = val
        self.update_visuals()

    def p1y_changed(self, val):
        """Store ``val`` in row 1 of ``W1`` for the selected perceptron."""
        self.network["W1"][1, self.perceptron1] = val
        self.update_visuals()

    def p1b_changed(self, val):
        """Store ``val`` as the ``b1`` entry of the selected perceptron."""
        self.network["b1"][0, self.perceptron1] = val
        self.update_visuals()

    def p1_options_update(self, label):
        """React to a click on the "ReLU?" / "Enabled?" check boxes."""
        if label == "ReLU?":
            self.is_relu = not self.is_relu
            self.update_just_plot1()
        elif label == "Enabled?":
            is_enabled = self.p1_opt_buttons.get_status()[1]
            currently_enabled = self.perceptron1 in self.all_p1_enabled
            if is_enabled and not currently_enabled:
                self.all_p1_enabled.add(self.perceptron1)
            elif not is_enabled and currently_enabled:
                # The rank must be looked up before the unit leaves the set.
                self.layer1_3d_plot.remove_plot(self._enabled_index())
                self.all_p1_enabled.remove(self.perceptron1)

            self.update_visuals()

    def p2_weight_changed(self, val):
        """Store ``val`` in ``W2`` for the selected connection."""
        self.network["W2"][self.connection, 0] = val

        self.update_just_plot2()

    def p2_connection_dim_changed(self, val):
        """Select another ``W2`` row and show its value on the weight slider."""
        self.connection = int(val)
        # set_val fires p2_weight_changed; no redraw is wanted for that.
        self.ignore_update = True
        try:
            self.p2_dim_val_slid.set_val(
                self.network["W2"][self.connection, 0])
            self._mark_default(
                self.p2_dim_val_slid,
                self.default_network["W2"][self.connection, 0])
        finally:
            self.ignore_update = False

    def p2b_changed(self, val):
        """Store ``val`` as the first ``b2`` entry."""
        self.network["b2"][0, 0] = val
        self.update_just_plot2()

    def p2_options_update(self, label):
        """React to a click on the "Pre-add?" / "Transform?" check boxes."""
        if label == "Transform?":
            self.is_sig = not self.is_sig
        elif label == "Pre-add?":
            self.is_pre_add = not self.is_pre_add

        self.update_just_plot2()

    def show(self):
        """Display the figure (delegates to ``plt.show``)."""
        plt.show()

    # ------------------------------------------------------------------
    # Plot updates
    # ------------------------------------------------------------------
    def update_plot1(self, out1, out2):
        """Refresh the 2D layer-1 plot, hiding it if the unit is disabled."""
        if self.perceptron1 not in self.all_p1_enabled:
            self.layer1_plot.set_visible(False)
            return

        self.layer1_plot.set_visible(True)
        # NOTE(review): indexing by rank among enabled units presumes the
        # outputs only carry channels for enabled perceptrons - confirm
        # against ``forward``.
        layer1_out = self._enabled_index()
        source = out2 if self.is_relu else out1
        self.layer1_plot.update(source[:, :, layer1_out])

    def update_3d_plot1(self, out1, out2):
        """Refresh the 3D layer-1 plot when the selected unit is enabled."""
        if self.perceptron1 in self.all_p1_enabled:
            self.layer1_3d_plot.update_all(out2 if self.is_relu else out1)

    def update_plot2(self, out2, out3, out4):
        """Refresh the 2D layer-2 plot according to the current options."""
        if self.is_pre_add:
            layer2_data = scale_out2(out2, self.network["W2"],
                                     self.all_p1_enabled, self.perceptron2,
                                     self.precision)
            layer2_data = np.sum(layer2_data, axis=2)
        elif not self.is_sig:
            layer2_data = out3[:, :, 0]
        else:
            layer2_data = out4[:, :, 0]
        self.layer2_plot.update(layer2_data)

    def update_3d_plot2(self, out2, out3, out4):
        """Refresh the 3D layer-2 plot according to the current options."""
        if self.is_pre_add:
            layer2_data = scale_out2(out2, self.network["W2"],
                                     self.all_p1_enabled, self.perceptron2,
                                     self.precision)
        elif not self.is_sig:
            layer2_data = out3
        else:
            layer2_data = out4
        self.layer2_3d_plot.update_all(layer2_data)

    def update_visuals(self):
        """Recompute the network outputs and redraw every plot."""
        if self.ignore_update:
            return
        out1, out2, out3, out4 = self._forward()
        self.update_plot1_visuals(out1, out2)
        self.update_plot2_visuals(out2, out3, out4)
        plt.draw()

    def update_just_plot1(self):
        """Recompute the network outputs and redraw the layer-1 plots."""
        if self.ignore_update:
            return
        out1, out2, _, _ = self._forward()
        self.update_plot1_visuals(out1, out2)
        plt.draw()

    def update_plot1_visuals(self, out1, out2):
        """Update both the 2D and the 3D layer-1 plot."""
        self.update_plot1(out1, out2)
        self.update_3d_plot1(out1, out2)

    def update_just_plot2(self):
        """Recompute the network outputs and redraw the layer-2 plots."""
        if self.ignore_update:
            return
        _, out2, out3, out4 = self._forward()
        self.update_plot2_visuals(out2, out3, out4)
        plt.draw()

    def update_plot2_visuals(self, out2, out3, out4):
        """Update both the 2D and the 3D layer-2 plot."""
        self.update_plot2(out2, out3, out4)
        self.update_3d_plot2(out2, out3, out4)

    def update_widgets(self):
        """Point the layer-1 widgets at the currently selected perceptron."""
        unit = self.perceptron1
        self.p1b_slid.set_val(self.network["b1"][0, unit])
        self.p1x_slid.set_val(self.network["W1"][0, unit])
        self.p1y_slid.set_val(self.network["W1"][1, unit])

        self._mark_default(self.p1b_slid, self.default_network["b1"][0, unit])
        self._mark_default(self.p1x_slid, self.default_network["W1"][0, unit])
        self._mark_default(self.p1y_slid, self.default_network["W1"][1, unit])

        # Toggle the "Enabled?" box only when it disagrees with the model.
        box_checked = bool(self.p1_opt_buttons.get_status()[1])
        if (unit in self.all_p1_enabled) != box_checked:
            self.p1_opt_buttons.set_active(1)