コード例 #1
0
ファイル: data.py プロジェクト: kylehkhsu/neural-tangents
    def test_omniglot():
        """Render one batch of validation-split Omniglot tasks in Visdom.

        For each task in the batch, shows the support images (``x_train``)
        and the query images (``x_test``) as image grids, each followed by a
        text pane listing the matching labels, then saves all Visdom envs.
        Expects a Visdom server on port 8000.
        """
        viz = Visdom(port=8000, env='main')
        splits = load_omniglot()

        n_way, n_support, n_query = 3, 5, 7
        batch_size = 2
        # n_task == batch_size, so presumably taskbatch yields a single batch
        # here — TODO confirm against taskbatch.
        batches = taskbatch(omniglot_task,
                            batch_size=batch_size,
                            n_task=batch_size,
                            split_dict=splits['val'],
                            n_way=n_way,
                            n_support=n_support,
                            n_query=n_query)

        # Reorder axes (0, 3, 1, 2) before plotting; presumably the arrays
        # are channels-last and Visdom wants channels-first — confirm.
        to_channels_first = (0, 3, 1, 2)
        for batch in batches:
            for t in range(batch_size):
                support_x, query_x = batch['x_train'][t], batch['x_test'][t]
                support_y, query_y = batch['y_train'][t], batch['y_test'][t]

                viz.images(tensor=np.transpose(support_x, to_channels_first),
                           nrow=n_support)
                viz.text(f'y_train: {support_y}')
                viz.images(tensor=np.transpose(query_x, to_channels_first),
                           nrow=n_query)
                viz.text(f'y_test: {query_y}')

        viz.save(viz.get_env_list())
コード例 #2
0
ファイル: data.py プロジェクト: kylehkhsu/neural-tangents
    def test_circle():
        """Scatter-plot ten freshly generated circle tasks in Visdom.

        Each task gets one window for its train split and one for its test
        split. Point colours come from ``argmax`` over the label rows (which
        are presumably one-hot — confirm) shifted by one, since Visdom's
        ``scatter`` takes class labels in 1..K. All envs are saved at the end.
        """
        viz = Visdom(port=8000, env='circle')

        for task_no in range(10):
            task = circle_task(n_way=3, n_support=5, n_query=7)

            train_labels = np.argmax(task['y_train'], axis=1) + 1
            viz.scatter(X=task['x_train'], Y=train_labels,
                        opts=dict(title=f'task {task_no}: train'))

            test_labels = np.argmax(task['y_test'], axis=1) + 1
            viz.scatter(X=task['x_test'], Y=test_labels,
                        opts=dict(title=f'task {task_no}: test'))

        viz.save(viz.get_env_list())
コード例 #3
0
    # Run args.n_inner_step inner optimisation steps on (x1, y1) for two models
    # side by side: `f` with optimiser state `state`, and `f_lin` with its own
    # state `state_lin` (presumably the linearised model, judging by the
    # `_lin` names — confirm). After every step each model's predictions on
    # xrange_inputs are appended to its own plotter window.
    for i in range(1, args.n_inner_step + 1):
        # `i` starts at 1: it is both the step index handed to
        # inner_opt_update and the "<i>-step" label of the plotted trace.
        p = inner_get_params(state)
        g = grad_loss(p, x1, y1)
        state = inner_opt_update(i, g, state)
        # Re-read the params so the plotted predictions reflect this step.
        p = inner_get_params(state)
        predictions = f(p, xrange_inputs)
        # NOTE(review): `plotter` is a project helper; presumably
        # update='append' adds this step as a new named trace in the
        # 'inference' window — confirm.
        plotter.line(win_name='inference',
                     Y=predictions,
                     X=xrange_inputs,
                     name=f'{i}-step predictions',
                     update='append')

        # Same step for the `_lin` model: same optimiser update function,
        # separate state, separate 'inference_lin' window.
        p_lin = inner_get_params(state_lin)
        g_lin = grad_loss_lin(p_lin, x1, y1)
        state_lin = inner_opt_update(i, g_lin, state_lin)
        p_lin = inner_get_params(state_lin)
        predictions_lin = f_lin(p_lin, xrange_inputs)
        plotter.line(win_name='inference_lin',
                     Y=predictions_lin,
                     X=xrange_inputs,
                     name=f'{i}-step predictions',
                     update='append')

# serialize the run's `log` (built earlier in the script) into <log_dir>/np/.
# `onp` is presumably plain NumPy, in which case np.save appends ".npy" and
# the file lands at <log_dir>/np/log.npy.
np_dir = os.path.join(args.log_dir, 'np')
os.makedirs(np_dir, exist_ok=True)
onp.save(file=os.path.join(np_dir, 'log'), arr=log)

# serialize visdom envs: ask the server to persist every environment it has.
viz.save(viz.get_env_list())
コード例 #4
0
class VisdomPlotter:
    """
    A Visdom based plotter, to plot aggregated metrics.

    Every plot is sent to the Visdom environment ``experiment_env``. If an
    environment with that name already exists on the server, it is deleted
    on construction, so each run starts from an empty environment. Window
    ids are cached per metric name in ``self.plots`` so that later calls
    update the same window.

    How to use:
    ------------
    (1) Start the server with:
            python -m visdom.server
        If nothing answers on ``server``:``port``, the constructor tries to
        launch a server itself with the current interpreter.
    (2) Then, in your browser, you can go to:
            http://localhost:8097
    """

    def __init__(self, experiment_env, server='http://localhost', port=8097):
        """Connect to (or launch) a Visdom server and reset ``experiment_env``.

        Args:
            experiment_env: name of the Visdom environment to plot into.
            server: Visdom server URL.
            port: Visdom server port; also used when launching a server.

        Raises:
            ValueError: if no server answers, even after trying to launch one.
        """
        self.server = server
        self.port = port
        # Handle of a server process launched by this plotter, if any.
        self._server_process = None
        # Connect to Visdom server on server / port
        self.viz = Visdom(server=server, port=port)
        if not self.start_visdom_server():
            raise ValueError('Failed to launch Visdom server at %r:%r' %
                             (server, port))

        if experiment_env in self.viz.get_env_list():
            # Clear previous runs with same id
            self.viz.delete_env(experiment_env)
        self.experiment_env = experiment_env
        self.plots = {}

    def start_visdom_server(self):
        """Make sure a Visdom server is reachable, launching one if needed.

        Returns:
            True if a server answers, either already running (1 s ping) or
            started here (waits up to 35 s). False otherwise, including when
            the server process could not be spawned at all.
        """
        import subprocess  # local: only needed on the launch path

        if self.viz.check_connection(timeout_seconds=1):  # Ping if it's already on..
            return True

        # Use an argument list rather than a shell string. It survives spaces
        # in sys.executable and needs no POSIX shell for `&`. It also passes
        # the configured port: without `-port` the server binds its default
        # port, and the check below can never succeed for any other port.
        # start_new_session (POSIX) detaches the server, as `cmd &` did, so
        # a Ctrl-C aimed at this process does not also kill the server.
        try:
            self._server_process = subprocess.Popen(
                [sys.executable, '-m', 'visdom.server',
                 '-port', str(self.port)],
                stdin=subprocess.DEVNULL,
                start_new_session=True)
        except OSError:
            return False
        return self.viz.check_connection(timeout_seconds=35)

    def plot_single_metric(self, metric, line_id, title, epoch, value):
        """Add the point ``(epoch, value)`` to trace ``line_id`` of ``metric``.

        The first call for a given ``metric`` creates its window (titled
        ``title``, with ``line_id`` as legend entry) and caches the window
        id. Later calls append to that window under the name ``line_id``.
        """
        if metric not in self.plots:
            # Seed the window with the point duplicated; presumably so Visdom
            # gets a well-formed two-sample line from a single observation.
            self.plots[metric] = self.viz.line(X=np.array([epoch, epoch]),
                                               Y=np.array([value, value]),
                                               env=self.experiment_env,
                                               opts=dict(legend=[line_id],
                                                         title=title,
                                                         xlabel='Epochs',
                                                         ylabel=metric))
        else:
            self.viz.line(X=np.array([epoch]),
                          Y=np.array([value]),
                          env=self.experiment_env,
                          win=self.plots[metric],
                          name=line_id,
                          update='append')

    def plot_confusion_matrix(self, metric, matrix, label_classes):
        """Draw ``matrix`` as a heatmap, redrawing ``metric``'s window if known.

        ``label_classes`` names both the rows and the columns.
        """
        opts = dict(columnnames=label_classes, rownames=label_classes)
        # win=None (first call) makes Visdom create a new window.
        win = self.viz.heatmap(X=matrix,
                               env=self.experiment_env,
                               win=self.plots.get(metric),
                               opts=opts)
        # Only the window id from the first call is cached.
        self.plots.setdefault(metric, win)

    def plot_images(self, images_bchw):
        """Show a batch of images in ``experiment_env``.

        Per the parameter name, ``images_bchw`` is expected to be laid out
        (batch, channels, height, width).
        """
        # Send to the experiment env like every other plot, not the default
        # 'main' env, which the constructor never clears.
        self.viz.images(images_bchw, env=self.experiment_env)

    def plot_aggregated_metrics(self, metrics, epoch):
        """Plot every metric recorded in ``metrics`` for ``epoch``.

        ``metrics`` is a project object. From its use here it must provide
        ``metrics`` (iterable of names), ``metric_to_title`` (name -> title),
        item access ``metrics[epoch][name]``, ``label_classes`` and
        ``data_type``.

        - ``'confusion_matrix'`` is drawn as a heatmap.
        - Array values with more than one element get one trace per element,
          named by the matching entry of ``label_classes``.
        - Anything else is a single trace named ``metrics.data_type``.
        """
        for metric in metrics.metrics:
            title = metrics.metric_to_title[metric]
            value = metrics[epoch][metric]

            if metric == 'confusion_matrix':
                self.plot_confusion_matrix(metric, value,
                                           metrics.label_classes)
            elif hasattr(value, 'shape') and value.size > 1:
                # Presumably a per-class metric: one line per class.
                for idx, dim_val in enumerate(value):
                    line_id = metrics.label_classes[idx]
                    self.plot_single_metric(metric, line_id, title, epoch,
                                            dim_val)
            else:
                self.plot_single_metric(metric, metrics.data_type, title,
                                        epoch, value)
コード例 #5
0
                    assert len(value['content']['data']) == 1, "content has multiple items, how to handle?"
                    x = value['content']['data'][0]['x']
                    y = value['content']['data'][0]['y']
                    dfs[value['title']] = pd.DataFrame({'x': x, 'y': y})

    return dfs


def dataframes2csv(dfs):
    """Dump every DataFrame in ``dfs`` to a CSV file named after its key.

    Each file is written relative to the current working directory as
    ``<key>.csv``, without the DataFrame index.
    """
    for name, frame in dfs.items():
        csv_path = name + ".csv"
        frame.to_csv(csv_path, index=False)


if __name__ == "__main__":
    # `exit()` is a convenience injected by the `site` module for the
    # interactive prompt and may be missing (e.g. `python -S`, frozen apps);
    # programs should use sys.exit.
    import sys

    # build the visdom object
    viz = Visdom(server=args.visdom_url, port=args.visdom_port, use_incoming_socket=False)

    # handle the --ls case: print every env whose name contains args.ls as a
    # substring, then stop without exporting anything
    if args.ls is not None:
        env_list = [e for e in viz.get_env_list() if args.ls in e]
        for env in env_list:
            print(env)

        sys.exit(0)

    # grab the dataframes
    dfs = vis2dataframe(viz, args.env_base_name, args.feature_name)

    # write the dataframes to csv
    dataframes2csv(dfs)