示例#1
0
def visdom_server():
    """Start a shared Visdom server once and yield ``(hostname, port)``.

    The server process is stored in the module-level ``vd_server_process`` so
    repeated calls reuse it; per the original comment it is stopped by
    ``visdom_server_stop`` (defined elsewhere).

    Yields:
        tuple: ``(vd_hostname, vd_port)`` of the running server.

    Raises:
        AssertionError: if the server does not accept connections.
    """
    global vd_hostname, vd_port, vd_server_process

    if vd_server_process is None:

        import socket
        import subprocess
        import sys
        import time

        from visdom import Visdom
        from visdom.server import download_scripts

        download_scripts()

        vd_hostname = "localhost"
        vd_port = random.randint(8089, 8887)  # `random` comes from module scope

        try:
            # Probe the port before spawning the server; the client object is
            # not needed afterwards. NOTE(review): presumably expected to fail
            # with ConnectionError when nothing listens yet -- confirm.
            Visdom(server=vd_hostname, port=vd_port, raise_exceptions=True)
        except ConnectionError:
            pass

        # sys.executable: run the server with the same interpreter/venv as the
        # caller instead of whatever "python" resolves to on PATH.
        vd_server_process = subprocess.Popen(
            [sys.executable, "-m", "visdom.server", "--hostname", vd_hostname, "-port", str(vd_port)]
        )

        # Wait until the server accepts TCP connections (at most 30 s) instead
        # of sleeping a fixed 5 s; stop waiting early if the process exits.
        deadline = time.monotonic() + 30.0
        while time.monotonic() < deadline and vd_server_process.poll() is None:
            try:
                with socket.create_connection((vd_hostname, vd_port), timeout=1.0):
                    break
            except OSError:
                time.sleep(0.2)

        vis = Visdom(server=vd_hostname, port=vd_port)
        assert vis.check_connection()
        vis.close()

    yield (vd_hostname, vd_port)
示例#2
0
class BaseVis(object):
    """Thin wrapper that owns one Visdom window.

    Args:
        viz_opts: options stored on the instance (not used by this class itself).
        update_mode: update mode stored on the instance (default ``'append'``).
        env: Visdom environment name; ``None`` selects ``'main'``.
        win: id of an existing window to reuse, or ``None``.
        resume: when true, the first plot should not update with replace.
        port: port of the Visdom server.
    """

    def __init__(self, viz_opts, update_mode='append', env=None, win=None,
                 resume=False, port=8097):
        self.viz_opts = viz_opts
        self.update_mode = update_mode
        self.win = win
        self.viz = Visdom(env='main' if env is None else env, port=port)
        # When resuming, the first plot must not update with 'replace'.
        self.removed = not resume

    def win_exists(self):
        """Return whatever Visdom reports for the existence of ``self.win``."""
        return self.viz.win_exists(self.win)

    def close(self):
        """Close the tracked window (if any) and forget its id."""
        if self.win is None:
            return
        self.viz.close(win=self.win)
        self.win = None

    def register_event_handler(self, handler):
        """Attach ``handler`` to events of the tracked window."""
        self.viz.register_event_handler(handler, self.win)
示例#3
0
class ZcsVisdom:
    """Small helper around a Visdom connection for line plots.

    Args:
        server: address of the Visdom server.
        port: port of the Visdom server.
    """

    def __init__(self, server='10.160.82.54', port=8899):
        self.vis = Visdom(server=server, port=port)
        assert self.vis.check_connection()
        # Maps window name -> next x index used by ``append``.
        self.wins = {}

    def plot(self, data, win):
        """Plot ``data`` against 0..len(data)-1 in window ``win``.

        data: np.array
        win: str
        """
        xs = np.arange(data.shape[0])
        self.vis.line(data, xs, win=win)

    def append(self, data, win, opts):
        """Append one point per line to window ``win``.

        ``data`` is a list whose length equals the number of lines drawn in
        the window. The first call for a window closes any existing window of
        that name and starts its x counter at 0.
        """
        assert isinstance(data, list)
        if win not in self.wins:
            self.vis.close(win)
            self.wins[win] = 0
        step = self.wins[win]
        self.vis.line(Y=[data], X=[step], win=win, update='append', opts=opts)
        self.wins[win] = step + 1
示例#4
0
class VisdomLinePlotter(object):
    """Plots to Visdom"""

    def __init__(self, env_name='main'):
        self.viz = Visdom()
        self.env = env_name
        # Maps var_name -> Visdom window id of its line plot.
        self.plots = {}

    def clear(self):
        """Close all windows of ``self.env`` (window id ``None``)."""
        self.viz.close(None, env=self.env)

    def plot(self, var_name, split_name, title_name, x, y):
        """Add point (x, y) to series ``split_name`` of the ``var_name`` plot.

        The first call for a ``var_name`` creates the window, seeded with the
        point duplicated; later calls append to that window.
        """
        if var_name in self.plots:
            self.viz.line(X=np.array([x]),
                          Y=np.array([y]),
                          env=self.env,
                          win=self.plots[var_name],
                          name=split_name,
                          update='append')
            return

        first_opts = dict(legend=[split_name],
                          title=title_name,
                          xlabel='Epochs',
                          ylabel=var_name)
        self.plots[var_name] = self.viz.line(X=np.array([x, x]),
                                             Y=np.array([y, y]),
                                             env=self.env,
                                             opts=first_opts)
示例#5
0
class VisdomLinePlotter(object):
    """Plots to Visdom"""

    def __init__(self, env_name='main', port=8097):
        # The client is created with its default env; every call below passes
        # env=self.env explicitly so all windows live in ``env_name``.
        self.viz = Visdom(port=port)
        self.env = env_name
        # Maps var_name -> Visdom window id of its line plot.
        self.plots = {}
        self.scores_window = None
        # Window id of the most recently shown image grid, if any.
        self.image_window = None

    def plot(self, var_name, split_name, x, y, x_label='Epochs'):
        """Add point (x, y) to series ``split_name`` of the ``var_name`` plot.

        The first call for ``var_name`` creates the window, seeded with the
        point duplicated; later calls append to it.
        """
        if var_name not in self.plots:
            self.plots[var_name] = self.viz.line(X=np.array([x, x]), Y=np.array([y, y]), env=self.env, opts=dict(
                legend=[split_name],
                title=var_name,
                xlabel=x_label,
                ylabel=var_name
            ))
        else:
            self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env, win=self.plots[var_name], name=split_name, update='append')

    def close_window(self, var_name):
        """Close the window of ``var_name`` and forget it (KeyError if unknown)."""
        # env is required: the window was created in self.env, which may not
        # be the client's default environment.
        self.viz.close(self.plots[var_name], env=self.env)
        del self.plots[var_name]

    def images(self, images):
        """Show ``images`` as a grid, replacing the previously shown grid."""
        if self.image_window is not None:
            self.viz.close(self.image_window, env=self.env)

        # NOTE(review): nrow=3 is passed as an argument while opts carries
        # nrow=2 -- confirm which one is intended.
        self.image_window = self.viz.images(images, nrow=3, env=self.env,
                                            opts=dict(nrow=2, title='Images Batch'))
class VisdomLinePlotter(object):
    """Plots to Visdom"""

    def __init__(self, env_name='main'):
        self.viz = Visdom(port=8090, env=env_name)
        self.env = env_name
        # Maps var_name -> Visdom window id of its line plot.
        self.plots = {}
        self.scores_window = None
        self.image_window = None

    def plot(self, var_name, split_name, x, y, x_label='Epochs'):
        """Add point (x, y) to series ``split_name`` of the ``var_name`` plot.

        A new window (seeded with the point duplicated) is created the first
        time a ``var_name`` is seen; afterwards points are appended.
        """
        if var_name in self.plots:
            self.viz.line(X=np.array([x]),
                          Y=np.array([y]),
                          env=self.env,
                          win=self.plots[var_name],
                          name=split_name,
                          update='append')
            return

        created = self.viz.line(X=np.array([x, x]),
                                Y=np.array([y, y]),
                                env=self.env,
                                opts=dict(legend=[split_name],
                                          title=var_name,
                                          xlabel=x_label,
                                          ylabel=var_name))
        self.plots[var_name] = created

    def close_window(self, var_name):
        """Close the window of ``var_name`` and drop it from ``self.plots``."""
        target = self.plots[var_name]
        self.viz.close(target)
        del self.plots[var_name]

    def plot_voxels(self, name, voxels, title, savePLY=False):
        """Render ``voxels`` as a semi-transparent mesh via ``cubes2mesh``."""
        vertices, faces = cubes2mesh(name, voxels, save=savePLY)
        self.viz.mesh(X=vertices, Y=faces, opts=dict(opacity=0.5, title=title))
示例#7
0
class VisdomLogger:
    """Log scalars, text and tensors to Visdom.

    Args:
        visdom_env: Visdom environment name; ``None`` disables logging.
        log_every: in ``on_batch_end``, log every ``log_every`` iterations;
            ``-1`` or ``0`` disables per-batch logging.
        prefix: string prepended to every window name.
    """

    def __init__(self, visdom_env='main', log_every=10, prefix=''):
        self.vis = None
        self.log_every = log_every
        self.prefix = prefix
        if visdom_env is not None:
            self.vis = Visdom(env=visdom_env)
            # Start from a clean environment.
            self.vis.close()

    def on_batch_end(self, state):
        """Log ``state['metrics']`` every ``log_every`` iterations."""
        iters = state['iters']
        # 0 would raise ZeroDivisionError on the modulo; treat it like -1.
        if self.log_every not in (-1, 0) and iters % self.log_every == 0:
            self.log(iters, state['metrics'])

    def on_epoch_end(self, state):
        """Log ``state['metrics']`` at ``state['iters']``."""
        self.log(state['iters'], state['metrics'])

    def log(self, iters, xs, store_history=()):
        """Send every metric in ``xs`` to Visdom (no-op without a client).

        Args:
            iters: x coordinate used for line plots.
            xs: mapping name -> value. Numbers and 1-element tensors are
                appended to a line plot, strings shown as text, 2-D tensors
                as a heatmap, 3-D tensors as an image, 4-D tensors as images.
            store_history: prefixed window names whose image history is kept.

        Raises:
            AssertionError: for unsupported value types or tensor ranks.
                Raised explicitly so the check also runs under ``python -O``.
        """
        if self.vis is None:
            return

        for name, x in xs.items():
            name = self.prefix + name
            if isinstance(x, (float, int)):
                self._append_point(name, iters, x)
            elif isinstance(x, str):
                self.vis.text(x, win=name, opts=dict(title=name))
            elif isinstance(x, torch.Tensor):
                self._log_tensor(name, iters, x, store_history)
            else:
                raise AssertionError("incorrect type " + x.__class__.__name__)

    def _append_point(self, name, iters, value):
        """Append one (iters, value) point to the line window ``name``."""
        self.vis.line(X=[iters],
                      Y=[value],
                      update='append',
                      win=name,
                      opts=dict(title=name),
                      name=name)

    def _log_tensor(self, name, iters, x, store_history):
        """Dispatch a tensor to a line, heatmap, image or image-grid window."""
        if x.numel() == 1:
            self._append_point(name, iters, x.item())
        elif x.dim() == 2:
            self.vis.heatmap(x, win=name, opts=dict(title=name))
        elif x.dim() == 3:
            self.vis.image(x,
                           win=name,
                           opts=dict(title=name,
                                     store_history=name in store_history))
        elif x.dim() == 4:
            self.vis.images(x,
                            win=name,
                            opts=dict(title=name,
                                      store_history=name in store_history))
        else:
            raise AssertionError("incorrect tensor dim")
示例#8
0
def visdom_plot(total_num_steps, cum_reward):
    """Append a point to the cumulative-reward and average-reward plots.

    Relies on module-level state: ``vis`` (Visdom client, created lazily),
    ``win1``/``win2`` (window handles), ``avg_reward`` and the history lists
    ``X``, ``Y1`` and ``Y2``.
    """
    from visdom import Visdom

    global vis, win1, win2, avg_reward

    if vis is None:
        vis = Visdom()
        assert vis.check_connection()
        # Close all existing plots
        vis.close()

    # Plain per-step average (replaces an earlier exponential smoothing).
    avg_reward = cum_reward / total_num_steps

    X.append(total_num_steps)
    Y1.append(cum_reward)
    Y2.append(avg_reward)

    def _plot_opts(x_title, y_title):
        # Layout shared by both reward plots.
        return dict(xlabel=x_title, ylabel=y_title, ytickmin=0,
                    width=900, height=500)

    # Passing the previous handle back redraws the same window each call.
    win1 = vis.line(X=np.array(X), Y=np.array(Y1),
                    opts=_plot_opts('Total number of steps', 'Cumulative reward'),
                    win=win1)
    win2 = vis.line(X=np.array(X), Y=np.array(Y2),
                    opts=_plot_opts('Total number of episodes', 'Average Reward'),
                    win=win2)
示例#9
0
 def testCallbackUpdateGraph(self):
     """Run the model update while a temporary Visdom server is running.

     The server is always killed, even if the connection check or the update
     fails. ``Popen.__exit__`` waits for the process, so without the kill a
     failure would block forever.
     """
     import sys

     cmd = [sys.executable, '-m', 'visdom.server', '-port', str(self.port)]
     with subprocess.Popen(cmd) as proc:
         try:
             viz = Visdom(server=self.host, port=self.port)
             # Poll for server startup instead of a single fixed sleep.
             for _attempt in range(5):
                 time.sleep(1.0)  # seconds
                 if viz.check_connection():
                     break
             assert viz.check_connection()
             viz.close()
             self.model.update(self.corpus)
         finally:
             proc.kill()
示例#10
0
class VisdomLogger():
    """
    Logger that uses visdom to create learning curves, images and scatters

    Parameters
    ----------
    - server: str, address of the visdom server
    - port: int, port of the visdom server

    If visdom could not be imported (``Visdom is None``) a warning is emitted
    and ``self.viz`` is left as ``None``; the plotting methods then cannot be
    used.
    """
    def __init__(self,
                 server='http://localhost',
                 port=8097):
        self.viz = None
        if Visdom is None:
            warnings.warn("Couldn't import visdom: `pip install visdom`")
        else:
            self.viz = Visdom(server=server, port=port)

    def deleteWindow(self, win):
        """Close window ``win``."""
        self.viz.close(win=win)

    def appendLine(self, name, win, X, Y, xlabel='empty', ylabel='empty'):
        """Append points to series ``name`` of the "Loss" window ``win``.

        Axis labels and the legend are set only when both labels are given.
        """
        opts = dict(title="Loss")
        if xlabel != 'empty' and ylabel != 'empty':
            opts.update(xlabel=xlabel, ylabel=ylabel, showlegend=True)
        self.viz.line(X=X, Y=Y, win=win, name=name, update='append', opts=opts)

    def plotLine(self, name, win, X, Y):
        """Draw (or redraw) series ``name`` in window ``win``."""
        self.viz.line(X=X, Y=Y, win=win, name=name)

    def plotImage(self, image, win, title="Image", caption="Just a Image"):
        """Show a single image in window ``win``."""
        self.viz.image(image,
                       win=win,
                       opts=dict(title=title, caption=caption))

    def plotImages(self, images, win, nrow, caption="Validation Output"):
        """Show a grid of ``images`` with ``nrow`` images per row."""
        self.viz.images(images,
                        win=win,
                        nrow=nrow,
                        opts=dict(caption=caption))

    def plot3dScatter(self, point, win):
        """Scatter ``point`` in window ``win`` (also printed to stdout)."""
        print("Point is", point)
        # NOTE(review): 'update' inside opts is not the line/scatter update
        # mode keyword -- confirm what was intended.
        self.viz.scatter(X=point,
                         win=win,
                         opts=dict(update='update'))
示例#11
0
def create_session(**kwargs):
    """
    Creates a visdom session whose environment starts without windows

    Parameters
    ----------
    kwargs : ...
        forwarded unchanged to ``Visdom``

    Returns
    -------
    object
        a visdom session
    """
    visdom_session = Visdom(**kwargs)
    # A window id of None targets every window of the current environment.
    visdom_session.close(None)
    return visdom_session
示例#12
0
class Visualizer:
    """Helpers that draw lines, text and scatter plots in a Visdom environment.

    The environment's windows are closed when the visualizer is created.
    """
    def __init__(self, env="main"):
        self._viz = Visdom(env=env, use_incoming_socket=False)
        self._viz.close(env=env)

    def plot_line(self, values, steps, name, legend=None):
        """Append one column-stacked point per series to window ``name``."""
        opts = dict(title=name)
        if legend is not None:
            opts['legend'] = legend

        self._viz.line(X=numpy.column_stack(steps),
                       Y=numpy.column_stack(values),
                       win=name,
                       update='append',
                       opts=opts)

    def plot_text(self, text, title, pre=True):
        """Show ``text`` in window ``title``, optionally wrapped in <pre>.

        The window is sized at ~10 px per character of the longest line and
        ~20 px per line (at least 120 px high), capped at 400 px each way.
        """
        lines = text.split("\n")
        width = max(len(line) for line in lines) * 10
        height = max(len(lines) * 20, 120)
        if pre:
            text = "<pre>{}</pre>".format(text)

        self._viz.text(text,
                       win=title,
                       opts=dict(title=title,
                                 width=min(width, 400),
                                 height=min(height, 400)))

    def plot_scatter(self, data, labels, title):
        """Scatter several point sets; set ``i`` (1-based) gets label ``i``."""
        X = numpy.concatenate(data, axis=0)
        Y = numpy.concatenate(
            [numpy.full(len(d), i) for i, d in enumerate(data, 1)], axis=0)
        self._viz.scatter(win=title,
                          X=X,
                          Y=Y,
                          opts=dict(legend=labels,
                                    title=title,
                                    markersize=5,
                                    webgl=True,
                                    width=600,
                                    height=600,
                                    markeropacity=0.5))

    def testCallbackUpdateGraph(self):
        """Run the model update against a temporary Visdom server."""
        import sys

        # Spawn the server *before* the try/finally: if Popen itself fails
        # there is nothing to kill (and ``proc`` would be unbound).
        # Popen has no context manager in 2.7, hence try/finally.
        proc = subprocess.Popen([sys.executable, '-m', 'visdom.server', '-port', str(self.port)])
        try:
            # wait for visdom server startup (any better way?)
            time.sleep(3)

            viz = Visdom(server=self.host, port=self.port)
            assert viz.check_connection()

            # clear screen
            viz.close()

            self.model.update(self.corpus)
        finally:
            proc.kill()
            proc.wait()  # reap the child so no zombie process is left behind
示例#14
0
class VisdomSummary(object):
    """Convenience wrapper writing scalars, bars, images and text to Visdom.

    Args:
        port: Visdom server port, passed straight to ``Visdom``.
        env: Visdom environment, passed straight to ``Visdom``.
    """
    def __init__(self, port=None, env=None):
        self.vis = Visdom(port=port, env=env)
        self.opts = dict()
        # Plotting helpers, created on first use and then reused.
        self._scalar = None
        self._image2d = None

    def scalar(self, win, name, x, y, remove=False):
        """Add (x, y) to series ``name`` of window ``win`` via a shared Scalar helper."""
        # getattr keeps this working for instances built without __init__.
        # A plain ``hasattr(self, '__scalar')`` would never see the
        # name-mangled attribute and would rebuild the helper on every call.
        helper = getattr(self, '_scalar', None)
        if helper is None:
            helper = self._scalar = Scalar()

        opts = dict(title=win)
        helper.update(self.vis,
                      win,
                      name,
                      x,
                      y,
                      opts=opts,
                      remove=remove)

    def bar(self, win, x, rownames=None):
        """Bar plot of ``x``; rows are named '0', '1', ... unless given."""
        if rownames is None:
            rownames = ['{}'.format(i) for i in range(x.size(0))]
        opts = dict(title=win, rownames=rownames)
        self.vis.bar(X=x, win=win, opts=opts)

    def image2d(self, win, name, img, caption=None, nrow=3):
        """Show ``img`` in window ``win`` via a shared Image2D helper."""
        helper = getattr(self, '_image2d', None)
        if helper is None:
            helper = self._image2d = Image2D()

        helper.update(self.vis, win, name, img, caption, nrow)

    def image3d(self, win, name, img):
        raise NotImplementedError

    def text(self, win, text):
        """Show ``text`` in window ``win``."""
        self.vis.text(text, win=win)

    def close(self, win=None):
        """Close window ``win`` (``None`` is forwarded to Visdom as-is)."""
        self.vis.close(win=win)

    def save(self):
        """Ask the server to save the current environment."""
        self.vis.save([self.vis.env])

    def testCallbackUpdateGraph(self):
        """Run the model update against a temporary Visdom server."""
        import sys

        # Spawn the server *before* the try/finally so ``proc`` is always
        # bound when the finally clause runs.
        # Popen has no context manager in 2.7, hence try/finally.
        proc = subprocess.Popen(
            [sys.executable, '-m', 'visdom.server', '-port',
             str(self.port)])
        try:
            # wait for visdom server startup (any better way?)
            time.sleep(3)

            viz = Visdom(server=self.host, port=self.port)
            assert viz.check_connection()

            # clear screen
            viz.close()

            self.model.update(self.corpus)
        finally:
            proc.kill()
            proc.wait()  # reap the child so no zombie process is left behind
示例#16
0
def visdom_plot(total_num_steps, mean_reward):
    """Append a smoothed reward point and redraw the reward curve.

    Uses module-level state: ``vis`` (Visdom client, created lazily), ``win``
    (window handle), ``avg_reward`` (running average) and the history lists
    ``X`` and ``Y``.
    """
    # Lazily import visdom so that people don't need to install visdom
    # if they're not actually using it
    from visdom import Visdom

    global vis, win, avg_reward

    if vis is None:
        vis = Visdom()
        assert vis.check_connection()
        # Close all existing plots
        vis.close()

    # Exponential running average (0.1 weight on the newest value).
    avg_reward = avg_reward * 0.9 + 0.1 * mean_reward

    X.append(total_num_steps)
    Y.append(avg_reward)

    reward_opts = dict(xlabel='Total time steps',
                       ylabel='Reward per episode',
                       ytickmin=0,
                       width=900,
                       height=500)
    # Passing the previous handle back redraws the same window each call.
    win = vis.line(X=np.array(X), Y=np.array(Y), opts=reward_opts, win=win)
示例#17
0
    def testCallbackUpdateGraph(self):
        """Run the model update against a freshly spawned Visdom server.

        The server process is stopped and reaped whether the test passes or
        fails.
        """
        import sys

        # create a new process for visdom.server
        # NOTE(review): the server starts on its default port while the client
        # below connects to FLAGS.port / FLAGS.server -- confirm they match.
        proc = subprocess.Popen([sys.executable, '-m', 'visdom.server'])
        try:
            # wait for visdom server startup
            time.sleep(3)

            # create visdom object
            print(FLAGS.port)
            print(FLAGS.server)
            viz = Visdom(port=FLAGS.port, server=FLAGS.server)

            # check connection
            assert viz.check_connection()

            # clear screen
            viz.close()

            # test callback's update graph function
            try:
                self.model.update(self.corpus)
            except Exception as e:
                print(e)
                self.fail('model.update raised: {!r}'.format(e))
        finally:
            # The server is a long-running process; stop it explicitly
            # instead of waiting for it to exit on its own.
            proc.kill()
            proc.wait()
            print('server killed')
示例#18
0
    k = [0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6]
    Y = np.c_[i, j, k]
    viz.mesh(X=X, Y=Y, opts=dict(opacity=0.5))

    # SVG plotting
    svgstr = """
    <svg height="300" width="300">
      <ellipse cx="80" cy="80" rx="50" ry="30"
       style="fill:red;stroke:purple;stroke-width:2" />
      Sorry, your browser does not support inline SVG.
    </svg>
    """
    viz.svg(svgstr=svgstr, opts=dict(title='Example of SVG Rendering'))

    # close text window:
    viz.close(win=textwindow)

    # assert that the closed window doesn't exist
    assert not viz.win_exists(textwindow), 'Closed window still exists'

    # Arbitrary visdom content
    trace = dict(x=[1, 2, 3],
                 y=[4, 5, 6],
                 mode="markers+lines",
                 type='custom',
                 marker={
                     'color': 'red',
                     'symbol': 104,
                     'size': "10"
                 },
                 text=["one", "two", "three"],
示例#19
0
            for target, actual in zip(target_output, output):
                target_str = embedded_to_string(target)
                actual_str = embedded_to_string(actual)
                print(target_str + ": " + actual_str)
                if target_str == actual_str:
                    num_correct += 1
            print("Accuracy: %f" % (num_correct / output.shape[0]))
            last_100_losses = []

        if summarize and rnn.debug:
            loss = np.mean(last_save_losses)
            last_save_losses = []

            print(random_length)
            if (epoch / summarize_freq) % 3 == 0:
                viz.close(None)
            viz.heatmap(
                mhx['memory'][0],
                opts=dict(xtickstep=10,
                          ytickstep=2,
                          title='Memory, t: ' + str(epoch) + ', num: ' +
                          q_and_a_to_string(input_data[0], output[0]),
                          ylabel='layer * time',
                          xlabel='mem_slot * mem_size'))

            all_read_weights = torch.stack(
                [x['read_weights'][0] for x in all_mems])
            viz.heatmap(
                all_read_weights.squeeze(),
                opts=dict(xtickstep=10,
                          ytickstep=2,
示例#20
0
### TORCH ###

import torch

# Select the first CUDA device when one is available, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
# Float tensor constructor matching the chosen device type.
Tensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

print("Using GPU:" + str(use_cuda))

### VISDOM ###

from visdom import Visdom

# Connect to the Visdom server on port 8097 and call close() with no window
# id, which targets the existing windows of the environment.
viz = Visdom(port=8097)
viz.close()

# Window options for the training-loss plot.
optsLossTrain = {
    'title': 'Loss Train',
    'xlabel': 'Epoch',
    'width': 800,
    'height': 400,
}

# Create the training-loss window seeded with a single (0, 0) point; the
# returned window id is kept, presumably so later code can update it.
winLossTrain = viz.line(Y=Tensor([0]), X=Tensor([0]), opts=optsLossTrain)

optsLossTest = {
    'title': 'Loss Test',
    'xlabel': 'Epoch',
    'width': 800,
    'height': 400,
示例#21
0
文件: demo.py 项目: Garvit244/visdom
# Mesh example: X stacks the x/y/z coordinate lists as columns (8 vertices,
# x and y are defined earlier in the script); Y stacks i/j/k as columns, each
# row presumably one triangle given as three vertex indices.
z = [0, 0, 0, 0, 1, 1, 1, 1]
X = np.c_[x, y, z]
i = [7, 0, 0, 0, 4, 4, 6, 6, 4, 0, 3, 2]
j = [3, 4, 1, 2, 5, 6, 5, 2, 0, 1, 6, 3]
k = [0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6]
Y = np.c_[i, j, k]
viz.mesh(X=X, Y=Y, opts=dict(opacity=0.5))

# SVG plotting: render a raw SVG string (a red ellipse) in its own window.
svgstr = """
<svg height="300" width="300">
  <ellipse cx="80" cy="80" rx="50" ry="30"
   style="fill:red;stroke:purple;stroke-width:2" />
  Sorry, your browser does not support inline SVG.
</svg>
"""
viz.svg(
    svgstr=svgstr,
    opts=dict(title='Example of SVG Rendering')
)

# close text window (textwindow is created earlier in the script):
viz.close(win=textwindow)

# PyTorch tensor: plotting works directly on torch tensors; the example is
# skipped when torch is not installed.
try:
    import torch
    viz.line(Y=torch.Tensor([[0., 0.], [1., 1.]]))
except ImportError:
    print('Skipped PyTorch example')
示例#22
0
class VisdomReporter(Reporter):
    """Reporter that sends scalars, text and images to a Visdom server.

    Args:
        port: port of an already running Visdom server.
        save_dir: forwarded to ``Reporter``.

    The Visdom environment is named after ``self._now`` (set by the base class).
    """
    def __init__(self, port=6006, save_dir=None):
        from visdom import Visdom

        super(VisdomReporter, self).__init__(save_dir)
        self._viz = Visdom(port=port, env=self._now)
        # Names of line windows already created; used to switch to 'append'.
        self._lines = defaultdict()
        assert self._viz.check_connection(), f"""
        Please launch visdom.server before calling VisdomReporter.
        $python -m visdom.server -port {port}
        """

    def add_scalar(self, x, name: str, idx: int, **kwargs):
        """Plot a single scalar ``x`` at step ``idx`` in window ``name``."""
        self.add_scalars({name: x}, name=name, idx=idx, **kwargs)

    def add_scalars(self, x: dict, name, idx: int, **kwargs):
        """Plot every value of ``x`` as its own line at step ``idx``.

        The first call for ``name`` creates the window; later calls append.
        Extra ``kwargs`` are merged into the Visdom ``opts``.
        """
        x = {k: self._to_numpy(v) for k, v in x.items()}
        num_lines = len(x)
        is_new = self._lines.get(name) is None
        self._lines[name] = 1
        for k, v in x.items():
            self._register_data(v, k, idx)
        opts = dict(title=name,
                    legend=list(x.keys()))
        opts.update(**kwargs)
        # np.column_stack needs a real sequence: generators / dict views are
        # deprecated since NumPy 1.16 and rejected by later releases.
        X = np.column_stack([self._to_numpy(idx) for _ in range(num_lines)])
        Y = np.column_stack(list(x.values()))
        self._viz.line(X=X, Y=Y, update=None if is_new else "append", win=name, opts=opts)

    def add_parameters(self, x, name: str, idx: int, **kwargs):
        # todo
        raise NotImplementedError

    def add_text(self, x, name: str, idx: int):
        """Record ``x`` and show it in a new text window."""
        self._register_data(x, name, idx)
        self._viz.text(x)

    def add_image(self, x, name: str, idx: int):
        """Show a single 3-D image ``x`` titled ``name``, captioned with ``idx``."""
        x, dim = self._tensor_type_check(x)
        assert dim == 3
        self._viz.image(self._normalize(x), opts=dict(title=name, caption=str(idx)))

    def add_images(self, x, name: str, idx: int):
        """Show a 4-D batch of images ``x`` titled ``name``, captioned with ``idx``."""
        x, dim = self._tensor_type_check(x)
        assert dim == 4
        self._viz.images(self._normalize(x), opts=dict(title=name, caption=str(idx)))

    def _to_numpy(self, x):
        """Convert a number or tensor to a NumPy array; other values pass through."""
        if isinstance(x, numbers.Number):
            x = np.array([x])
        elif "Tensor" in str(type(x)):
            if hasattr(x, "detach"):
                # detach/cpu so tensors that require grad or live on a GPU
                # can be converted as well.
                x = x.detach().cpu().numpy()
            else:
                x = x.numpy()
        return x

    def __exit__(self, exc_type, exc_val, exc_tb):
        super(VisdomReporter, self).__exit__(exc_type, exc_val, exc_tb)
        self._viz.close()

    @staticmethod
    def _normalize(x):
        # Rescale values into [0, 1]; a constant input maps to all zeros
        # instead of dividing by zero.
        _min, _max = x.min(), x.max()
        if _max == _min:
            return x - _min
        return (x - _min) / (_max - _min)
示例#23
0
class Visualizer:
    """Helpers that draw lines, text, bars, scatters and heatmaps in Visdom.

    Args:
        env: Visdom environment; its windows are closed on creation.
        server, port, base_url: location of the Visdom server.
        http_proxy_host, http_proxy_port: optional HTTP proxy.
    """
    def __init__(self,
                 env="main",
                 server="http://localhost",
                 port=8097,
                 base_url="/",
                 http_proxy_host=None,
                 http_proxy_port=None):
        # base_url was previously accepted but never forwarded to Visdom.
        self._viz = Visdom(env=env,
                           server=server,
                           port=port,
                           base_url=base_url,
                           http_proxy_host=http_proxy_host,
                           http_proxy_port=http_proxy_port,
                           use_incoming_socket=False)
        self._viz.close(env=env)

    def plot_line(self, values, steps, name, legend=None):
        """Append one column-stacked point per series to window ``name``."""
        opts = dict(title=name)
        if legend is not None:
            opts['legend'] = legend

        self._viz.line(X=numpy.column_stack(steps),
                       Y=numpy.column_stack(values),
                       win=name,
                       update='append',
                       opts=opts)

    def plot_text(self, text, title, pre=True):
        """Show ``text`` in window ``title``, optionally wrapped in <pre>.

        The window is sized at ~10 px per character of the longest line and
        ~20 px per line (at least 120 px high), capped at 400 px each way.
        """
        lines = text.split("\n")
        width = max(len(line) for line in lines) * 10
        height = max(len(lines) * 20, 120)
        if pre:
            text = "<pre>{}</pre>".format(text)

        self._viz.text(text,
                       win=title,
                       opts=dict(title=title,
                                 width=min(width, 400),
                                 height=min(height, 400)))

    def plot_bar(self, data, labels, title):
        """Grouped (non-stacked) bar chart of ``data`` in window ``title``."""
        self._viz.bar(win=title,
                      X=data,
                      opts=dict(legend=labels, stacked=False, title=title))

    def plot_scatter(self, data, labels, title):
        """Scatter several point sets; set ``i`` (1-based) gets label ``i``."""
        X = numpy.concatenate(data, axis=0)
        Y = numpy.concatenate(
            [numpy.full(len(d), i) for i, d in enumerate(data, 1)], axis=0)
        self._viz.scatter(win=title,
                          X=X,
                          Y=Y,
                          opts=dict(legend=labels,
                                    title=title,
                                    markersize=5,
                                    webgl=True,
                                    width=400,
                                    height=400,
                                    markeropacity=0.5))

    def plot_heatmap(self, data, labels, title):
        """Heatmap of ``data``; ``labels`` is (row names, column names).

        Column labels are drawn on top and rows run top-to-bottom.
        """
        self._viz.heatmap(
            win=title,
            X=data,
            opts=dict(
                title=title,
                columnnames=labels[1],
                rownames=labels[0],
                width=700,
                height=700,
                layoutopts={
                    'plotly': {
                        'xaxis': {
                            'side': 'top',
                            'tickangle': -60,
                        },
                        'yaxis': {
                            'autorange': "reversed"
                        },
                    }
                }))
示例#24
0
class Custom_Visdom:
    """Visdom dashboard for a training run: losses, predictions and status text.

    Args:
        model_name: used as the Visdom environment name.
        transfer_learning: if true, ``'_tl'`` is appended to the environment name.

    Window ids are created on first use. A missing id (``None``) means "create
    the window"; any other Visdom error now propagates instead of being
    swallowed by a bare ``except``.
    """

    def __init__(self, model_name, transfer_learning):

        self.vis = Visdom()
        if transfer_learning:
            self.vis.env = f'{model_name}_tl'
        else:
            self.vis.env = model_name
        self.vis.close()
        self.mod = sys.modules[__name__]
        self.pred_error = {}
        # Window ids, filled in on first use.
        self.loss_plt = None
        self.pretrain_pred = None
        self.trans_pred = None
        self.error = None
        self.training = None

    def loss_plot(self, checkpoint):
        """Plot train/valid loss per epoch.

        The first call draws the full history; later calls replace the curve
        with the last 50 epochs.
        """
        opts = dict(title='Loss curve',
                    legend=['train loss', 'valid_loss'],
                    showlegend=True)

        length = 50

        if getattr(self, 'loss_plt', None) is None:
            self.loss_plt = self.vis.line(
                X=np.array([checkpoint.epoch_list, checkpoint.epoch_list]).T,
                Y=np.array([
                    checkpoint.train_loss_list_per_epoch,
                    checkpoint.valid_loss_list
                ]).T,
                opts=opts)
        else:
            # [-length:] keeps the whole sequence when it is shorter than length.
            self._update_loss_plot(
                checkpoint.epoch_list[-length:],
                checkpoint.train_loss_list_per_epoch[-length:],
                checkpoint.valid_loss_list[-length:], opts, self.loss_plt)

    def predict_plot(self, y_df, target='pre'):
        """Plot all columns of ``y_df`` and record per-target errors.

        Args:
            y_df: DataFrame-like with a ``'pred'`` column plus target columns;
                the error of each target against ``'pred'`` is stored in
                ``self.pred_error`` (computed by ``mse_AIFrenz``).
            target: ``'pre'`` (pretrain set) or ``'trans'`` (Y18) window.

        Raises:
            ValueError: if ``target`` is neither ``'pre'`` nor ``'trans'``.
        """
        if target not in ['pre', 'trans']:
            raise ValueError('Wrong target')

        opts = dict(legend=y_df.columns.tolist(), showlegend=True)

        for y_target in y_df.columns:
            if y_target != 'pred':
                mse = mse_AIFrenz(y_df[y_target], y_df['pred'])
                self.pred_error[y_target] = mse

        if target == 'pre':
            opts['title'] = 'pretrain set prediction'
            if getattr(self, 'pretrain_pred', None) is None:
                self.pretrain_pred = self._new_predict_curve(y_df, opts)
            else:
                self._update_predict_curve(y_df, opts, self.pretrain_pred)
        else:
            opts['title'] = 'Y18 prediction'
            if getattr(self, 'trans_pred', None) is None:
                self.trans_pred = self._new_predict_curve(y_df, opts)
            else:
                self._update_predict_curve(y_df, opts, self.trans_pred)

    def print_error(self):
        """Show ``self.pred_error`` as an HTML list in a text window."""
        text = '<h4> Error for each target </h4><br>'

        for keys, values in self.pred_error.items():
            text += f'{keys}: {values:7.3f} <br>'

        self._show_text(text, 'error')

    def print_params(self, params):
        """Show hyperparameters in a new text window."""
        text = '<h4> Hyperparameter list </h4><br>'

        for keys, values in params.items():
            text += f'{keys}: {values} <br>'

        self.vis.text(text)

    def print_training(self, EPOCH, epoch, training_time, avg_train_loss,
                       valid_loss, patience, counter):
        """Show the training status (epoch, time, losses, early stopping)."""
        iter_time = time.time() - training_time

        text = ('<h4> Training status </h4><br>'
                f'\r Epoch: {epoch:3d}/{str(EPOCH):3s}<br>'
                f'train time: {int(iter_time//60):2d}m {iter_time%60:5.2f}s<br>'
                f'avg train loss: {avg_train_loss:7.3f}<br>'
                f'valid loss: {valid_loss:7.3f}<br>'
                f'\r EarlyStopping: {">"*counter + "-"*(patience-counter)} |<br>')

        self._show_text(text, 'training')

    def _show_text(self, text, attr):
        """Create (first call) or overwrite the text window stored in ``attr``."""
        win = getattr(self, attr, None)
        if win is None:
            setattr(self, attr, self.vis.text(text))
        else:
            self.vis.text(text, win=win, append=False)

    def _new_predict_curve(self, y_df, opts):
        """Create a prediction window for ``y_df`` and return its id."""
        return self.vis.line(X=np.tile(y_df.index, (y_df.shape[1], 1)).T,
                             Y=y_df.values,
                             opts=opts)

    def _update_predict_curve(self, y_df, opts, win):

        self.vis.line(X=np.tile(y_df.index, (y_df.shape[1], 1)).T,
                      Y=y_df.values,
                      opts=opts,
                      win=win,
                      update='replace')

    def _update_loss_plot(self, epoch_list, train_loss_list_per_epoch,
                          valid_loss_list, opts, win):

        self.vis.line(X=np.array([epoch_list, epoch_list]).T,
                      Y=np.array([train_loss_list_per_epoch,
                                  valid_loss_list]).T,
                      opts=opts,
                      win=win,
                      update='replace')
class Visualizer(object):
    """Visdom client wrapper that also mirrors text messages to a log file.

    Plots and text windows are created in ``config.visdom_env``; ``log`` and
    ``log_process`` additionally write a timestamped line to a log file
    (``config.log_file`` by default).
    """

    def __init__(self, config: Config):
        self.reinit(config)

    def reinit(self, config):
        """Bind to *config*, open a Visdom client and log the configuration.

        ``self.connected`` is always set: False if the client could not be
        created, otherwise the result of ``check_connection()``. A line with
        the configuration is appended to ``config.log_file`` in every case.
        """
        self.config = config
        self.connected = False
        try:
            self.visdom = Visdom(env=config.visdom_env)
            self.connected = self.visdom.check_connection()
            if not self.connected:
                print(
                    "Visdom server hasn't started, please run command 'python -m visdom.server' in terminal."
                )
        except ConnectionError as e:
            # OSError.strerror is None for errors built from a plain message;
            # concatenating None would raise TypeError and hide the real cause.
            warn("Can't open Visdom because " + (e.strerror or str(e)))
        with open(self.config.log_file, 'a') as f:
            info = "[{time}]Initialize Visdom\n".format(
                time=timestr('%m-%d %H:%M:%S'))
            info += str(self.config)
            f.write(info + '\n')

    def save(self, save_path: str = None) -> str:
        """Ask the server to save the current env and return the saved JSON.

        The server's reply is expected to be a JSON list whose first item is
        the saved env name. When it matches ``config.visdom_env``, the file at
        ``config.vis_env_path`` (presumably written by the server — confirm)
        is optionally copied to *save_path*, then parsed with ``json.load``
        and returned (a parsed JSON object, despite the ``str`` annotation).
        Returns None if the env name does not match or any step fails; failures
        are reported through ``warn``.
        """
        retstr = self.visdom.save([
            self.config.visdom_env
        ])  # return current environments name in format of json
        try:
            ret = json.loads(retstr)[0]
            if ret == self.config.visdom_env:
                if isinstance(save_path, str):
                    from shutil import copy
                    copy(self.config.vis_env_path, save_path)
                    print('Visdom Environment has saved into ' + save_path)
                else:
                    print('Visdom Environment has saved into ' +
                          self.config.vis_env_path)
                with open(self.config.vis_env_path, 'r') as fp:
                    env_str = json.load(fp)
                    return env_str
        except Exception as e:
            warn(e)
        return None

    def clear(self):
        """Call ``close()`` on the Visdom client with no window argument."""
        self.visdom.close()

    @staticmethod
    def _to_numpy(value):
        """Convert a tensor / array / scalar / sequence to an ndarray of ndim >= 1."""
        if isinstance(value, t.Tensor):
            value = value.cpu().detach().numpy()
        elif isinstance(value, np.ndarray):
            pass
        else:
            value = np.array(value)
        if value.ndim == 0:
            # Visdom expects at least one dimension.
            value = value[np.newaxis]
        return value

    def plot(self, y, x, line_name, win, legend=None):
        # type:(float,float,str,str,list)->bool
        """Add point(s) ``(x, y)`` to the line *line_name* in window *win*.

        The window is created on the first call and appended to afterwards,
        so calling this repeatedly draws the whole curve. Returns True when
        Visdom reports the expected window id.
        """
        update = None if not self.visdom.win_exists(win) else 'append'
        opts = dict(title=win)
        if legend is not None:
            opts["legend"] = legend
        y = Visualizer._to_numpy(y)
        x = Visualizer._to_numpy(x)
        return win == self.visdom.line(y,
                                       x,
                                       win=win,
                                       env=self.config.visdom_env,
                                       update=update,
                                       name=line_name,
                                       opts=opts)

    def bar(self, y, win, rowindices=None):
        """Draw *y* as a bar chart in window *win*.

        *rowindices* is used as the row names only when it is a list with one
        entry per element of *y*. Returns True when Visdom reports *win*.
        """
        opts = dict(title=win)
        y = Visualizer._to_numpy(y)
        if isinstance(rowindices, list) and len(rowindices) == len(y):
            opts["rownames"] = rowindices
        return win == self.visdom.bar(y,
                                      win=win,
                                      env=self.config.visdom_env,
                                      opts=opts)

    def log(self, msg, name, append=True, log_file=None):
        # type:(Visualizer,str,str,bool,str)->bool
        """Write a timestamped *msg* to text window *name* and to a log file.

        Args:
            msg: message text.
            name: Visdom window id (also used as its title).
            append: append to the window/file (True) or overwrite them (False).
            log_file: file to write to; defaults to ``config.log_file``.

        Returns:
            True when Visdom reports window *name*.
        """
        if log_file is None:
            log_file = self.config.log_file
        info = "[{time}]{msg}".format(time=timestr('%m-%d %H:%M:%S'), msg=msg)
        # A Visdom window can only be appended to once it exists.
        win_append = append and self.visdom.win_exists(name)
        ret = self.visdom.text(info,
                               win=name,
                               env=self.config.visdom_env,
                               opts=dict(title=name),
                               append=win_append)
        # The file mode follows only the caller's request: a window that does
        # not exist yet (or an unreachable server) must not truncate the shared
        # log file.
        mode = 'a+' if append else 'w+'
        with open(log_file, mode) as f:
            f.write(info + '\n')
        return ret == name

    def log_process(self, num, total, msg, name, append=True):
        # type:(Visualizer,int,int,str,str,bool)->bool
        """Log *msg* like ``log`` (always appending to ``config.log_file``) and
        print a console progress bar for ``num / total``."""
        info = "[{time}]{msg}".format(time=timestr('%m-%d %H:%M:%S'), msg=msg)
        append = append and self.visdom.win_exists(name)
        ret = self.visdom.text(info,
                               win=(name),
                               env=self.config.visdom_env,
                               opts=dict(title=name),
                               append=append)
        with open(self.config.log_file, 'a') as f:
            f.write(info + '\n')
        self.process_bar(num, total, msg)
        return ret == name

    def process_bar(self, num, total, msg='', length=50):
        """Write a ``length``-character progress bar for ``num / total`` to stdout.

        The bar is redrawn in place with ``\\r``; a newline is emitted at 100%.
        Returns the written string with ``\\r`` replaced by ``:``.
        """
        # An empty job (total == 0) is shown as complete instead of raising
        # ZeroDivisionError.
        rate = num / total if total else 1.0
        rate_num = int(rate * 100)
        clth = int(rate * length)
        if len(msg) > 0:
            msg += ':'
        if rate_num == 100:
            r = '\r%s[%s%d%%]\n' % (
                msg,
                '*' * length,
                rate_num,
            )
        else:
            r = '\r%s[%s%s%d%%]' % (
                msg,
                '*' * clth,
                '-' * (length - clth),
                rate_num,
            )
        sys.stdout.write(r)
        sys.stdout.flush()
        return r.replace('\r', ':')
示例#26
0
def main():
    """Train a DenseNet on the image folders under ``--dataPath``.

    Parses the command line, seeds the RNGs, optionally connects to a Visdom
    server for live plots, builds the train/test loaders, creates (or resumes)
    the network and runs the train/test loop, saving the model to
    ``<save>/latest.pth`` after every epoch.

    Raises:
        RuntimeError: if Visdom is enabled (the default) and the server
            cannot be reached.
        FileNotFoundError: if ``--resume`` points to a missing file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataPath', type=str, default='data')
    parser.add_argument('--batchSz', type=int, default=64)
    parser.add_argument('--nEpochs', type=int, default=300)
    parser.add_argument('--num_workers',
                        default=8,
                        type=int,
                        help='Number of workers used in dataloading')
    # Previously read as args.classes without being declared (AttributeError).
    parser.add_argument('--classes',
                        type=int,
                        default=10,
                        help='Number of output classes of a newly built net')
    parser.add_argument('--no_cuda', type=str2bool, default=False)
    parser.add_argument('--procName',
                        type=str,
                        default='train',
                        help='process name')
    parser.add_argument('--save',
                        type=str,
                        default='work/densenet.base',
                        help='Location to save checkpoint models')
    parser.add_argument('--resume',
                        '-r',
                        type=str,
                        default=None,
                        help='resume from checkpoint')
    parser.add_argument('--seed', type=int, default=1)
    # store_false: Visdom is on by default; passing --use_visdom disables it.
    parser.add_argument(
        '--use_visdom',
        action='store_false',
        help='Whether or not to use visdom.the default is true')
    parser.add_argument('--viz_ip',
                        type=str,
                        default='http://localhost',
                        help='server ip for visdom')
    parser.add_argument('--viz_port',
                        type=int,
                        default=8098,
                        help='server port for visdom')
    parser.add_argument('--opt',
                        type=str,
                        default='sgd',
                        choices=('sgd', 'adam', 'rmsprop'))
    args = parser.parse_args()

    # Already implies torch.cuda.is_available().
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    setproctitle.setproctitle(args.procName)

    os.makedirs(args.save, exist_ok=True)

    # Defined up front so train()/test() get None instead of a NameError when
    # Visdom is disabled. NOTE(review): confirm train()/test() skip plotting
    # when viz is None.
    viz = None
    epoch_plot = None
    if args.use_visdom:
        try:
            viz = Visdom(server=args.viz_ip, port=args.viz_port)
            if not viz.check_connection():
                raise ConnectionError('cannot reach Visdom at {}:{}'.format(
                    args.viz_ip, args.viz_port))
            viz.close()
        except Exception as err:
            raise RuntimeError('Visdom connect error...') from err
        epoch_plot = create_vis_plot(viz, args.procName)

    normMean = [0.49139968, 0.48215827, 0.44653124]
    normStd = [0.24703233, 0.24348505, 0.26158768]
    normTransform = transforms.Normalize(normMean, normStd)

    trainTransform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normTransform
    ])
    testTransform = transforms.Compose([transforms.ToTensor(), normTransform])
    kwargs = {'pin_memory': True} if args.cuda else {}
    trainLoader = DataLoader(myImageFolder(root=ospj(args.dataPath, 'train'),
                                           label=ospj(args.dataPath,
                                                      '../train.list'),
                                           transform=trainTransform),
                             batch_size=args.batchSz,
                             shuffle=True,
                             num_workers=args.num_workers,
                             **kwargs)
    testLoader = DataLoader(myImageFolder(root=ospj(args.dataPath, 'test'),
                                          label=ospj(args.dataPath,
                                                     '../test.list'),
                                          transform=testTransform),
                            batch_size=args.batchSz,
                            shuffle=False,
                            num_workers=args.num_workers,
                            **kwargs)

    start_epoch = 0

    if args.resume:
        print('===>Resuming from checkpoint:{} ..'.format(args.resume))
        if not os.path.isfile(args.resume):
            raise FileNotFoundError('Error:no checkpoint:%s found' % args.resume)
        # The checkpoint holds the whole pickled module: only load trusted files.
        checkpoint = torch.load(args.resume)
        net = checkpoint['net']
        # The checkpoint is written after epoch checkpoint['epoch'] finished,
        # so training continues with the following epoch.
        start_epoch = checkpoint['epoch'] + 1
        print('===>Resume from checkpoint:{},start epoch:{} ..'.format(
            args.resume, start_epoch))
    else:
        net = densenet.DenseNet(growthRate=12,
                                depth=100,
                                reduction=0.5,
                                bottleneck=True,
                                nClasses=args.classes)

    print('  + Number of params: {}'.format(
        sum([p.data.nelement() for p in net.parameters()])))

    if args.cuda:
        net = net.cuda()
        if _visible_cuda_device_count() > 1:
            net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    if args.opt == 'sgd':
        optimizer = optim.SGD(net.parameters(),
                              lr=1e-1,
                              momentum=0.9,
                              weight_decay=1e-4)
    elif args.opt == 'adam':
        optimizer = optim.Adam(net.parameters(), weight_decay=1e-4)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(net.parameters(), weight_decay=1e-4)

    for epoch in range(start_epoch, args.nEpochs + 1):
        adjust_opt(args.opt, optimizer, epoch)
        train(args, epoch, net, trainLoader, optimizer, epoch_plot, viz)
        test(args, epoch, net, testLoader, optimizer, epoch_plot, viz)
        state = {
            # Store the wrapped module rather than the DataParallel wrapper.
            'net':
            net.module if isinstance(net, torch.nn.DataParallel) else net,
            'epoch': epoch
        }
        torch.save(state, os.path.join(args.save, 'latest.pth'))


def _visible_cuda_device_count():
    """Return how many device ids CUDA_VISIBLE_DEVICES lists (0 when unset)."""
    visible = os.getenv('CUDA_VISIBLE_DEVICES', '')
    return len([dev for dev in visible.split(',') if dev.strip()])
示例#27
0
            self.m.loss = totalvalue(rewards) * processlogprob(probs)
            self.m.optimize()
            if (i % plotinterval == 0):
                viz.line(
                    X=list(range(len(actions))),
                    Y=actions,
                    opts=dict(legend=["a_mu", "a_sigma", "w_mu", "w_sigma"]))

                viz.line(X=list(range(len(states))),
                         Y=states,
                         opts=dict(legend=["v", "x", "y", "h", "d"]))


def testcar():
    """Drive a Car for 100 steps and plot its (x, y) trajectory in Visdom.

    The car starts at (-5, -5) with ``v = 1`` and ``w = 0.5``; after every
    ``car.step(...)`` call its position is recorded.
    """
    import numpy as np

    car = Car()
    car.x = -5
    car.y = -5
    car.v = 1
    car.w = 0.5
    xs = []
    ys = []
    for step in range(100):
        car.step(step * 0.01)
        xs.append(car.x)
        ys.append(car.y)
    # Visdom's quiver() takes 2-D grids of arrow components (it reads
    # X.shape[1]), so it cannot draw a 1-D list of positions. Plot the path as
    # a line instead: x positions on the X axis, y positions on the Y axis.
    viz.line(Y=np.asarray(ys), X=np.asarray(xs))


# Module-level script code: runs as soon as this module is executed/imported.
# NOTE(review): win=None presumably closes every window of the current Visdom
# env (clearing plots from earlier runs) — verify against the Visdom API.
viz.close(win=None)
sim = Simulation()
# NOTE(review): arguments presumably are (number of iterations, plot interval)
# — confirm against Simulation.run.
sim.run(10001, 1000)
示例#28
0
class NetFramework():
    def __init__(self, defaults_path):
        """Parse the command line and build everything needed for training.

        Seeds the RNGs (unless ``--seed=-1``), creates the output folders
        ``../out/<experiment>/{model,images}``, writes the raw arguments to
        ``args.json`` there, and sets up the device, optional Visdom plotters,
        train/test datasets and loaders, model, optimizer, LR schedule, loss,
        extra metrics and a SIGTERM handler that saves the model.

        Args:
            defaults_path: directory holding ``dataconfig_train.json``,
                ``dataconfig_test.json``, ``modelconfig.json``,
                ``loss_definition.json`` and ``metrics.json``.
        """
        parser = argparse.ArgumentParser(
            description='Net framework arguments description')
        parser.add_argument('--experiment',
                            nargs='?',
                            type=str,
                            default='experiment',
                            help='Experiment name')
        parser.add_argument('--model',
                            nargs='?',
                            type=str,
                            default='',
                            help='Architecture to use')
        parser.add_argument('--modelparam',
                            type=str,
                            default='{}',
                            help='Experiment model parameters')
        parser.add_argument('--dataset',
                            nargs='?',
                            type=str,
                            default='',
                            help='Dataset key specified in dataconfig_*.json')
        parser.add_argument('--datasetparam',
                            type=str,
                            default='{}',
                            help='Experiment dataset parameters')
        parser.add_argument('--imsize',
                            nargs='?',
                            type=int,
                            default=200,
                            help='Image resize parameter')

        parser.add_argument('--visdom',
                            action='store_true',
                            help='If included shows visdom visulaization')
        parser.add_argument(
            '--show_rate',
            nargs='?',
            type=int,
            default=4,
            help='Visdom show after num of iterations (used with --visdom)')
        parser.add_argument('--print_rate',
                            nargs='?',
                            type=int,
                            default=4,
                            help='Print after num of iterations')
        parser.add_argument(
            '--save_rate',
            nargs='?',
            type=int,
            default=10,
            help=
            'Save after num of iterations (if --save_rate=0 then no save is done during training)'
        )

        parser.add_argument('--use_cuda',
                            nargs='?',
                            type=int,
                            default=0,
                            help='GPU device (if --use_cuda=-1 then CPU used)')
        parser.add_argument(
            '--parallel',
            action='store_true',
            help='Use multiples GPU (used only if --use_cuda>-1)')
        parser.add_argument('--epochs',
                            nargs='?',
                            type=int,
                            default=1000,
                            help='Number of epochs')
        parser.add_argument('--batch_size',
                            nargs='?',
                            type=int,
                            default=1,
                            help='Minibatch size')
        parser.add_argument('--batch_acc',
                            nargs='?',
                            type=int,
                            default=1,
                            help='Minibatch accumulation')
        parser.add_argument('--train_worker',
                            nargs='?',
                            type=int,
                            default=1,
                            help='Number of training workers')
        parser.add_argument('--test_worker',
                            nargs='?',
                            type=int,
                            default=1,
                            help='Number of testing workers')

        parser.add_argument('--optimizer',
                            nargs='?',
                            type=str,
                            default='RMSprop',
                            help='Optimizer to use')
        parser.add_argument('--optimizerparam',
                            type=str,
                            default='{}',
                            help='Experiment optimizer parameters')
        parser.add_argument('--lrschedule',
                            nargs='?',
                            type=str,
                            default='none',
                            help='LR Schedule to use')
        parser.add_argument('--loss',
                            nargs='?',
                            type=str,
                            default='',
                            help='Loss function to use')
        parser.add_argument('--lossparam',
                            type=str,
                            default='{}',
                            help='Loss function parameters')
        parser.add_argument('--resume',
                            action='store_true',
                            help='Resume training')

        parser.add_argument('--seed',
                            nargs='?',
                            type=int,
                            default=123,
                            help='Random seed (for reproducibility)')

        args = parser.parse_args()

        if args.seed != -1:
            torch.manual_seed(args.seed)
            np.random.seed(args.seed)
            random.seed(args.seed)

        # create outputs folders
        root = '../out'
        experimentpath = os.path.join(root, args.experiment)
        folders = {
            'root_path': root,
            'experiment_path': experimentpath,
            'model_path': os.path.join(experimentpath, 'model'),
            'images_path': os.path.join(experimentpath, 'images')
        }
        # makedirs creates missing parents, so a single pass suffices and real
        # failures (e.g. permissions) surface here instead of being swallowed.
        for path in folders.values():
            os.makedirs(path, exist_ok=True)

        # Record the raw CLI arguments (before folders/JSON fields are added).
        with open(os.path.join(experimentpath, 'args.json'), 'w') as args_file:
            json.dump(vars(args), args_file)
        args.folders = folders

        def _json_arg(text):
            # Parameters may be passed with single quotes from the shell.
            return json.loads(text.replace("'", "\""), cls=Decoder)

        args.lossparam = _json_arg(args.lossparam)
        args.datasetparam = _json_arg(args.datasetparam)
        args.modelparam = _json_arg(args.modelparam)
        args.optimizerparam = _json_arg(args.optimizerparam)

        # Parse use cuda
        self.device, self.use_parallel = parse_cuda(args)
        # --use_cuda=-1 selects the CPU; only pin a device when CUDA exists.
        if args.use_cuda >= 0 and torch.cuda.is_available():
            torch.cuda.set_device(args.use_cuda)

        # Visdom visualization
        self.visdom = args.visdom
        if self.visdom:
            self.vis = Visdom(use_incoming_socket=False)
            self.vis.close(env=args.experiment)
            self.visplotter = gph.VisdomLinePlotter(self.vis,
                                                    env_name=args.experiment)
            self.visheatmap = gph.HeatMapVisdom(self.vis,
                                                env_name=args.experiment)
            self.visimshow = gph.ImageVisdom(self.vis,
                                             env_name=args.experiment)
            self.vistext = gph.TextVisdom(self.vis, env_name=args.experiment)

        # Showing results rate
        self.print_rate = args.print_rate
        self.show_rate = args.show_rate
        self.save_rate = args.save_rate

        self.init_epoch = 0
        self.current_epoch = 0
        self.epochs = args.epochs
        self.folders = args.folders
        self.bestmetric = 0
        self.batch_size = args.batch_size
        self.batch_acc = args.batch_acc

        # Load datasets
        print('Loading dataset: ', args.dataset)
        self.traindataset, self.train_loader, self.dmodule = loaddataset(
            datasetname=args.dataset,
            experimentparam=args.datasetparam,
            batch_size=args.batch_size,
            worker=args.train_worker,
            config_file=os.path.join(defaults_path, 'dataconfig_train.json'))

        self.testdataset, self.test_loader, _ = loaddataset(
            datasetname=args.dataset,
            experimentparam=args.datasetparam,
            batch_size=args.batch_size,
            worker=args.test_worker,
            config_file=os.path.join(defaults_path, 'dataconfig_test.json'))

        # Module providing warp_Variable (moves a sample dict to the device).
        self.warp_var_mod = import_module(self.dmodule + '.dataset')

        # Setup model
        print('Loading model: ', args.model)
        self.net, self.arch, self.mmodule = loadmodel(
            modelname=args.model,
            experimentparams=args.modelparam,
            config_file=os.path.join(defaults_path, 'modelconfig.json'))

        self.net.to(self.device)
        if self.use_parallel:
            self.net = torch.nn.DataParallel(self.net,
                                             device_ids=range(
                                                 torch.cuda.device_count()))
            cudnn.benchmark = True

        # Setup Optimizer
        print('Selecting optimizer: ', args.optimizer)
        self.optimizer = selectoptimizer(args.optimizer, self.net,
                                         args.optimizerparam)

        # Setup Learning Rate Scheduling
        print('LR Schedule: ', args.lrschedule)
        self.scheduler = selectschedule(args.lrschedule, self.optimizer)

        # Setup Loss criterion
        print('Selecting loss function: ', args.loss)
        self.criterion, self.losseval = selectloss(lossname=args.loss,
                                                   parameter=args.lossparam,
                                                   config_file=os.path.join(
                                                       defaults_path,
                                                       'loss_definition.json'))
        self.criterion.to(self.device)
        self.trlossavg = AverageMeter()
        self.vdlossavg = AverageMeter()

        # Others evaluation metrics
        print('Selecting metrics functions:')
        metrics_dict = get_metric_path(
            os.path.join(defaults_path, 'metrics.json'))
        self.metrics = dict()
        self.metrics_eval = dict()
        self.trmetrics_avg = dict()
        self.vdmetrics_avg = dict()

        for key, value in metrics_dict.items():
            self.metrics[key], self.metrics_eval[key] = selectloss(
                lossname=value['metric'],
                parameter=value['param'],
                config_file=os.path.join(defaults_path,
                                         'loss_definition.json'))
            self.metrics[key].to(self.device)
            self.trmetrics_avg[key] = AverageMeter()
            self.vdmetrics_avg[key] = AverageMeter()

        if args.resume:
            self.resume()

        # Python calls signal handlers as handler(signum, frame); savemodel
        # receives them as (modelpath, killsignal).
        signal.signal(signal.SIGTERM, self.savemodel)
        self.args = args

    def do_train(self):
        """Run epochs ``self.init_epoch`` .. ``self.epochs - 1``.

        Every epoch first validates (stepping the scheduler with the validation
        loss and saving ``bestmodel.t7`` whenever the watched metric beats
        ``self.bestmetric``), writes ``epoch<N>model.t7`` every
        ``self.save_rate`` epochs (never when it is 0), then trains one epoch
        and refreshes the sample image. ``lastmodel.t7`` is written at the end.
        """
        for epoch in range(self.init_epoch, self.epochs):
            print('epoch ', epoch)
            self.current_epoch = epoch

            # Validate the weights produced by the previous epoch.
            val_loss, val_metric = self.validation(epoch)
            self.scheduler.step(val_loss, epoch)

            improved = self.bestmetric < val_metric
            if improved:
                print('Validation metric improvement ({:.3f}) in epoch {} \n'
                      .format(val_metric, epoch))
                self.bestmetric = val_metric
                best_path = os.path.join(self.folders['model_path'],
                                         'bestmodel.t7')
                self.savemodel(best_path)

            periodic_save = self.save_rate != 0 and epoch % self.save_rate == 0
            if periodic_save:
                print('Saving checkpoint epoch {}\n'.format(epoch))
                ckpt_path = os.path.join(self.folders['model_path'],
                                         'epoch{}model.t7'.format(epoch))
                self.savemodel(ckpt_path)

            # Forward and backward passes over the training set.
            self.train(epoch)
            self.valid_visualization(epoch, 3)

        # Save the final network.
        last_path = os.path.join(self.folders['model_path'], 'lastmodel.t7')
        self.savemodel(last_path)

    ## Train function
    def train(self, current_epoch):
        """Run one training epoch over ``self.train_loader``.

        Gradients are accumulated over ``self.batch_acc`` mini-batches: the
        optimizer steps (and gradients are zeroed) every ``batch_acc`` batches
        and on the last batch of the epoch. Running loss/metric averages are
        kept in ``self.trlossavg`` / ``self.trmetrics_avg``; progress is
        printed every ``self.print_rate`` batches and, when Visdom is enabled,
        plotted every ``self.show_rate`` batches and at the end of the epoch.

        Args:
            current_epoch: epoch index; used in log lines and as the integer
                part of the fractional x-coordinate sent to Visdom.
        """
        ttime = time.time()
        data_time = AverageMeter()
        batch_time = AverageMeter()

        # NOTE(review): new_local() presumably starts a fresh per-epoch
        # averaging window (AverageMeter is defined elsewhere) — confirm.
        self.trlossavg.new_local()
        for key, value in self.trmetrics_avg.items():
            self.trmetrics_avg[key].new_local()

        self.net.train()

        end = time.time()
        total_train = len(self.train_loader)
        for i, sample in enumerate(self.train_loader):
            data_time.update(time.time() - end)

            # Fractional epoch position, used as the x-coordinate of plots.
            iteration = float(i) / total_train + current_epoch
            sample = self.warp_var_mod.warp_Variable(sample, self.device)
            images = sample['image']

            outputs = self.net(images)
            # NOTE(review): self.losseval is an expression string obtained via
            # selectloss(... loss_definition.json) and evaluated in this frame;
            # it presumably refers to locals such as `outputs` / `sample`, so
            # these local names must not be renamed. eval() of config text is
            # only safe for trusted config files.
            kwarg = eval(self.losseval)
            loss = self.criterion(**kwarg)
            # NOTE(review): the loss is not divided by batch_acc, so the
            # accumulated gradient is a sum over batches — confirm intended.
            loss.backward()
            if (i + 1) % self.batch_acc == 0 or (i + 1) == total_train:
                self.optimizer.step()

            self.trlossavg.update(loss.item(), images.size(0))
            # Metric expressions are eval'd in this frame like the loss one.
            for key, value in self.metrics_eval.items():
                kwarg = eval(self.metrics_eval[key])
                metric = self.metrics[key](**kwarg)
                self.trmetrics_avg[key].update(metric.item(), images.size(0))

            # Zero gradients only after a step so they accumulate in between.
            if (i + 1) % self.batch_acc == 0 or (i + 1) == total_train:
                self.optimizer.zero_grad()

            batch_time.update(time.time() - end)
            end = time.time()

            if (i % self.print_rate) == 0:
                strinfo = '| Train: [{0}][{1}/{2}]\t'
                strinfo += 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                strinfo += 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'

                print(strinfo.format(current_epoch,
                                     i + 1,
                                     total_train,
                                     batch_time=batch_time,
                                     data_time=data_time),
                      end='')

                for key, value in self.trmetrics_avg.items():
                    print('{} {:.3f} ({:.3f})\t'.format(
                        key, value.val, value.avg),
                          end='')

                print('loss {:.3f} ({:.3f})'.format(self.trlossavg.val,
                                                    self.trlossavg.avg))

            # Plot every show_rate batches and on the last batch of the epoch.
            if self.visdom == True and (((i + 1) % self.show_rate) == 0 or
                                        ((i + 1) % total_train) == 0):
                info = {'loss': self.trlossavg}

                for key, value in self.trmetrics_avg.items():
                    info[key] = value

                for tag, value in info.items():
                    self.visplotter.show(tag, 'train', iteration, value.avg)
                    self.visplotter.show(tag, 'train_mean', iteration,
                                         value.total_avg)

        print('|Total time: {:.3f}'.format(time.time() - ttime))

    def validation(self, current_epoch):
        """Evaluate the network on ``self.test_loader`` without gradients.

        Switches the network to inference mode (``train()`` switches it back),
        updates ``self.vdlossavg`` / ``self.vdmetrics_avg``, prints progress
        every ``self.print_rate`` batches and, when Visdom is enabled and this
        is not the run's first epoch (``self.init_epoch``), plots the running
        averages.

        Args:
            current_epoch: epoch index used in log lines; plots are placed at
                ``current_epoch - 1 + fraction`` (this validation runs at the
                start of an epoch, on the previous epoch's weights).

        Returns:
            ``(average loss, average watched metric)`` where the watched metric
            is the first entry of ``self.vdmetrics_avg``, or the loss when no
            extra metrics are configured.
        """
        ttime = time.time()
        data_time = AverageMeter()
        batch_time = AverageMeter()

        self.vdlossavg.new_local()
        for key, value in self.vdmetrics_avg.items():
            self.vdmetrics_avg[key].new_local()

        end = time.time()
        total_valid = len(self.test_loader)
        # Dropout / batch-norm must run in inference mode during validation;
        # without this the net stayed in training mode (the default, and the
        # mode train() leaves it in).
        self.net.eval()
        with torch.no_grad():
            for i, sample in enumerate(self.test_loader):
                data_time.update(time.time() - end)

                iteration = float(i) / total_valid + current_epoch - 1
                sample = self.warp_var_mod.warp_Variable(sample, self.device)
                images = sample['image']

                outputs = self.net(images)
                # NOTE(review): eval'd config expressions presumably reference
                # locals such as `outputs` / `sample`; keep these names.
                kwarg = eval(self.losseval)
                loss = self.criterion(**kwarg)

                self.vdlossavg.update(loss.item(), images.size(0))
                for key, value in self.metrics_eval.items():
                    kwarg = eval(self.metrics_eval[key])
                    metric = self.metrics[key](**kwarg)
                    self.vdmetrics_avg[key].update(metric.item(),
                                                   images.size(0))

                batch_time.update(time.time() - end)
                end = time.time()

                if i % self.print_rate == 0:
                    strinfo = '| Valid: [{0}][{1}/{2}]\t'
                    strinfo += 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    strinfo += 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'

                    print(strinfo.format(
                        current_epoch,
                        i + 1,
                        total_valid,
                        batch_time=batch_time,
                        data_time=data_time,
                    ),
                          end='')

                    for key, value in self.vdmetrics_avg.items():
                        print('{} {:.3f} ({:.3f})\t'.format(
                            key, value.val, value.avg),
                              end='')

                    print('loss {:.3f} ({:.3f})'.format(
                        self.vdlossavg.val, self.vdlossavg.avg))

                # No plot on the first epoch: there are no trained weights yet.
                if self.visdom == True and current_epoch != self.init_epoch and (
                    ((i + 1) % self.show_rate) == 0 or
                    ((i + 1) % total_valid) == 0):
                    info = {'loss': self.vdlossavg}

                    for key, value in self.vdmetrics_avg.items():
                        info[key] = value

                    for tag, value in info.items():
                        self.visplotter.show(tag, 'valid', iteration,
                                             value.avg)
                        self.visplotter.show(tag, 'valid_mean', iteration,
                                             value.total_avg)

        if list(self.vdmetrics_avg.keys()):
            watch_metric = self.vdmetrics_avg[list(
                self.vdmetrics_avg.keys())[0]]
        else:
            watch_metric = self.vdlossavg
        print('|Total time: {:.3f}'.format(time.time() - ttime))

        return self.vdlossavg.avg, watch_metric.avg

    def valid_visualization(self, current_epoch, index=0, save=False):
        """Show test sample ``index`` in the visdom 'Image' window.

        ``current_epoch`` and ``save`` are accepted for interface
        compatibility but are not used by this method.

        Returns:
            int: always 1.
        """
        with torch.no_grad():
            # Work on a shallow copy with an out-of-place unsqueeze so the
            # dataset's own dict/tensor are never modified: an in-place
            # unsqueeze_ on a cached sample adds another leading dim per call.
            sample = dict(self.testdataset[index])
            sample['image'] = sample['image'].unsqueeze(0)

            sample = self.warp_var_mod.warp_Variable(sample, self.device)
            # Drop the batch dim again and bring the data to host memory.
            img = sample['image'][0].cpu().numpy()
            if self.visdom:
                self.visimshow.show('Image', img)

        return 1

    def savemodel(self, modelpath='', killsignal=None):
        """Write a training checkpoint and every metric history to disk.

        Args:
            modelpath: where to ``torch.save`` the checkpoint. If empty, or
                whenever ``killsignal`` is given, the path is rebuilt as
                ``<model_path>/epoch<current_epoch>model.t7``.
            killsignal: if not None, the process exits with status -1 once
                everything is written. NOTE(review): presumably supplied by a
                signal handler; confirm against the caller.

        Besides the checkpoint, each loss/metric averager's ``array`` is
        written with ``np.savetxt`` (comma delimiter, ``%3.6f``) to
        ``<experiment_path>/<train|valid>_<name>.txt``.
        """
        import sys

        if modelpath == '' or killsignal is not None:
            print('Saving checkpoint epoch {}\n'.format(self.current_epoch))
            modelpath = os.path.join(
                self.folders['model_path'],
                'epoch{}model.t7'.format(self.current_epoch))
        # Unwrap the parallel wrapper so the saved keys match what
        # loadmodel() loads into (it unwraps the same way).
        to_save = self.net.module if self.use_parallel else self.net
        state = {
            'epoch': self.current_epoch,
            'arch': self.arch,
            'net': to_save.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'bestmetric': self.bestmetric
        }
        torch.save(state, modelpath)

        histories = {
            'train_loss': self.trlossavg,
            'valid_loss': self.vdlossavg
        }
        histories.update(
            ('train_' + key, value) for key, value in self.trmetrics_avg.items())
        histories.update(
            ('valid_' + key, value) for key, value in self.vdmetrics_avg.items())

        experiment_dir = self.folders['experiment_path']
        for tag, value in histories.items():
            np.savetxt(os.path.join(experiment_dir, tag + '.txt'),
                       np.array(value.array),
                       delimiter=',',
                       fmt='%3.6f')

        if killsignal is not None:
            # sys.exit rather than the site-module exit() helper, which is
            # intended for the interactive shell and is missing under
            # `python -S` or in frozen executables.
            sys.exit(-1)

    def loadmodel(self, modelpath):
        """Restore network, optimizer, training state and metric histories.

        Loads the checkpoint at ``modelpath`` (mapped to CPU) into the
        network/optimizer, restores ``current_epoch``, ``arch`` and
        ``bestmetric``, then replays ``train_*.txt`` and ``valid_*.txt``
        histories from ``self.folders['experiment_path']`` into the matching
        averagers (``*_loss`` into the loss averager, other names into the
        metric averager of the same name).

        Raises:
            FileNotFoundError: if ``modelpath`` is not a file (a subclass of
                ``Exception``, so generic handlers still match).
        """
        if not os.path.isfile(modelpath):
            raise FileNotFoundError('Model not found')

        checkpoint = torch.load(modelpath, map_location='cpu')
        to_load = self.net.module if self.use_parallel else self.net
        to_load.load_state_dict(checkpoint['net'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.current_epoch = checkpoint['epoch']
        self.arch = checkpoint['arch']
        self.bestmetric = checkpoint['bestmetric']

        experiment_dir = self.folders['experiment_path']
        files = sorted(os.listdir(experiment_dir))
        history_targets = (
            ('train_', self.trlossavg, self.trmetrics_avg),
            ('valid_', self.vdlossavg, self.vdmetrics_avg),
        )
        for prefix, lossavg, metrics_avg in history_targets:
            for f in files:
                # Anchor prefix and suffix: a substring test would also pick
                # up e.g. 'valid_train_x.txt' as a train history.
                if not (f.startswith(prefix) and f.endswith('.txt')):
                    continue
                # loadtxt returns a 0-d array for a one-value file; promote it
                # so both the array and its tolist() stay iterable.
                narray = np.atleast_1d(
                    np.loadtxt(os.path.join(experiment_dir, f), delimiter=','))
                metric = f[len(prefix):-len('.txt')]
                if metric == 'loss':
                    lossavg.load(narray, 1)
                if metric in metrics_avg:
                    metrics_avg[metric].load(narray.tolist(), 1)

    def resume(self):
        """Reload the newest ``epoch<N>model.t7`` checkpoint, if there is one.

        Scans ``self.folders['model_path']`` for files named exactly
        ``epoch<N>model.t7`` (N a non-negative integer), loads the one with
        the largest N through ``self.loadmodel`` and sets
        ``self.init_epoch = N + 1``. Does nothing when the folder is missing
        or holds no such file.
        """
        import re

        model_dir = self.folders['model_path']
        if not os.path.isdir(model_dir):
            return
        # Full-name match: stray files such as 'best_epoch_model.t7' or
        # 'epoch3model.t7.bak' must be ignored, not fed to int().
        pattern = re.compile(r'epoch(\d+)model\.t7')
        epochs = [
            int(match.group(1))
            for match in map(pattern.fullmatch, os.listdir(model_dir))
            if match
        ]
        if not epochs:
            return
        last_epoch = max(epochs)
        self.init_epoch = last_epoch + 1
        self.loadmodel(
            os.path.join(model_dir, 'epoch' + str(last_epoch) + 'model.t7'))
        print('Resuming on epoch' + str(last_epoch))
示例#29
0
class VisdomLogger(tu.AutoStateDict):
    """
    Log metrics to Visdom. Scalars and single-element tensors are appended to
    line plots, 2D tensors are shown as heatmaps, 3D tensors as images, 4D
    tensors as image grids, and strings as HTML text.

    Args:
        visdom_env (str): name of the target visdom env; ``None`` disables
            logging (no Visdom client is created).
        log_every (int): batch logging freq. -1 logs on epoch ends only.
        prefix (str): prefix for all metrics name
    """
    def __init__(self, visdom_env='main', log_every=10, prefix=''):
        # NOTE(review): except_names presumably keeps the live client out of
        # the AutoStateDict state; confirm against tu.AutoStateDict.
        super(VisdomLogger, self).__init__(except_names=['vis'])
        self.vis = None
        self.log_every = log_every
        self.prefix = prefix
        if visdom_env is not None:
            self.vis = Visdom(env=visdom_env)
            # close() without a window id clears every window in the env, so
            # a new run starts from an empty dashboard.
            self.vis.close()

    def _should_log(self, iters):
        """Return True when batch number ``iters`` falls on the log period."""
        return self.log_every != -1 and iters % self.log_every == 0

    def on_batch_start(self, state):
        # Expose the decision in the shared state (presumably read by other
        # callbacks that want to skip work on non-logged batches).
        state['visdom_will_log'] = self._should_log(state['iters'])

    @torch.no_grad()
    def on_batch_end(self, state):
        iters = state['iters']
        if self._should_log(iters):
            self.log(iters, state['metrics'])

    def on_epoch_end(self, state):
        self.log(state['iters'], state['metrics'])

    def log(self, iters, xs, store_history=()):
        """Send every entry of ``xs`` to Visdom, one window per prefixed name.

        Args:
            iters: x coordinate used for line plots.
            xs: mapping of metric name to a float/int, str or torch.Tensor.
            store_history: names whose image windows keep their history
                (Visdom ``store_history`` option). An immutable tuple default
                avoids the shared-mutable-default pitfall.

        Raises:
            AssertionError: for a multi-element tensor whose rank is not 2, 3
                or 4, or a value of an unsupported type. Raised explicitly so
                the check also runs under ``python -O``.
        """
        if self.vis is None:
            return

        for name, x in xs.items():
            name = self.prefix + name
            if isinstance(x, (float, int)):
                self._plot_point(iters, x, name)
            elif isinstance(x, str):
                self.vis.text(x, win=name, opts=dict(title=name))
            elif isinstance(x, torch.Tensor):
                self._log_tensor(iters, x, name, store_history)
            else:
                raise AssertionError("incorrect type " + x.__class__.__name__)

    def _plot_point(self, iters, y, name):
        """Append the point (iters, y) to the line plot in window ``name``."""
        self.vis.line(X=[iters],
                      Y=[y],
                      update='append',
                      win=name,
                      opts=dict(title=name),
                      name=name)

    def _log_tensor(self, iters, x, name, store_history):
        """Route a tensor to a line point, heatmap, image or image grid."""
        if x.numel() == 1:
            self._plot_point(iters, x.item(), name)
        elif x.dim() == 2:
            self.vis.heatmap(x, win=name, opts=dict(title=name))
        elif x.dim() in (3, 4):
            show = self.vis.image if x.dim() == 3 else self.vis.images
            show(x,
                 win=name,
                 opts=dict(title=name, store_history=name in store_history))
        else:
            raise AssertionError("incorrect tensor dim")
示例#30
0
        help="num seconds between vis plotting, default just one plot")
    parser.add_argument("--save-loc",
                        type=str,
                        default='None',
                        help="directory to save plots")
    return parser.parse_args()


if __name__ == "__main__":
    import sys

    args = parse_args()
    # A bare Visdom() targets the default port (8097); pass args.port so the
    # connection check and every plot go to the server launched below.
    viz = Visdom(port=args.port)
    if not viz.check_connection():
        # sys.executable starts the server under this interpreter/virtualenv
        # instead of whatever "python" is first on PATH.
        subprocess.Popen(
            [sys.executable, "-m", "visdom.server", "-p",
             str(args.port)])
        # Give the new server time to accept connections (bounded, ~30 s)
        # instead of plotting immediately after Popen.
        for _ in range(30):
            if viz.check_connection():
                break
            time.sleep(1)

    if args.vis_interval is not None:
        # Redraw forever: close the previous windows, then plot afresh.
        # NOTE(review): assumes visdom_plot_all returns the window ids it made.
        vizes = []
        while True:
            for v in vizes:
                viz.close(v)
            try:
                vizes = visdom_plot_all(viz,
                                        args.log_dir,
                                        save_loc=args.save_loc)
            except IOError:
                # Best effort: skip this round and retry after the interval.
                pass
            time.sleep(args.vis_interval)
    else:
        visdom_plot_all(viz, args.log_dir, save_loc=args.save_loc)
示例#31
0
def _build_optimizer(opt_name, params):
    """Create the optimizer selected on the command line.

    Args:
        opt_name: one of ``'sgd'``, ``'adam'`` or ``'rmsprop'``.
        params: iterable of parameters to optimize.

    Raises:
        ValueError: for any other optimizer name.
    """
    if opt_name == 'sgd':
        return optim.SGD(params, lr=1e-1, momentum=0.9, weight_decay=1e-4)
    if opt_name == 'adam':
        return optim.Adam(params, weight_decay=1e-4)
    if opt_name == 'rmsprop':
        return optim.RMSprop(params, weight_decay=1e-4)
    raise ValueError('unsupported optimizer: {!r}'.format(opt_name))


def _cycle_batches(loader):
    """Yield batches from ``loader`` forever, starting a new pass at each end.

    Re-iterating the loader (rather than calling ``next`` on one iterator)
    avoids StopIteration after the first epoch and lets a shuffling loader
    reshuffle on every pass.

    Raises:
        ValueError: if a complete pass over ``loader`` yields nothing (the
            generator would otherwise spin forever).
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError('training loader produced no batches')


def main(args):
    """Fine-tune a pretrained net or resume a checkpoint, then train it.

    Runs iterations ``start_iter .. args.max_iter - 1``, testing every
    ``args.test_interval`` iterations and writing a snapshot every
    ``args.snapshot`` iterations into ``args.save_folder``. Exactly one of
    ``args.ft_net`` / ``args.resume`` must be set.

    Raises:
        ConnectionError: if ``args.use_visdom`` is set and the Visdom server
            cannot be reached or its plots cannot be created.
        ValueError: if ``args.opt`` names an unsupported optimizer.
    """
    os.makedirs(args.save_folder, exist_ok=True)
    setproctitle.setproctitle(args.proc_name)

    # Plot handles default to None so the training loop also runs with
    # visdom disabled. NOTE(review): assumes train()/test() skip plotting
    # when viz is None; confirm.
    viz = None
    train_iter_plot = None
    test_iter_plot = None
    if args.use_visdom:
        try:
            viz = Visdom(server=args.viz_ip, port=args.viz_port)
            # Explicit check: an assert here would vanish under `python -O`.
            if not viz.check_connection():
                raise ConnectionError('visdom check_connection() returned False')
            viz.close()
            train_iter_plot = create_vis_plot(viz,
                                              args.proc_name + ' training')
            test_iter_plot = create_vis_plot(viz, args.proc_name + ' testing')
        except Exception as err:
            # Exception (not BaseException) so Ctrl-C is not relabelled as a
            # connection failure; `from err` keeps the real cause visible.
            raise ConnectionError(
                'fail to connect visdom server:{}, port:{} ...'.format(
                    args.viz_ip, args.viz_port)) from err

    args.cuda = not args.no_cuda and torch.cuda.is_available()

    kwargs = {'pin_memory': True} if args.cuda else {}

    # Per-channel normalization (values look like the CIFAR-10 statistics).
    normMean = [0.49139968, 0.48215827, 0.44653124]
    normStd = [0.24703233, 0.24348505, 0.26158768]
    normTransform = transforms.Normalize(normMean, normStd)

    trainTransform = transforms.Compose([
        transforms.RandomCrop(224, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normTransform
    ])
    testTransform = transforms.Compose([
        transforms.RandomCrop(224, padding=4),
        transforms.ToTensor(),
        normTransform
    ])
    trainLoader = DataLoader(myImageFolder(label_file=args.train_list,
                                           transform=trainTransform),
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             **kwargs)
    testLoader = DataLoader(myImageFolder(label_file=args.test_list,
                                          transform=testTransform),
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.num_workers,
                            **kwargs)

    assert not (
        args.ft_net and args.resume
    ), 'args.resume and args.ft_net can not be used at the same time!'
    assert (args.ft_net
            or args.resume), 'args.ft_net and args.resume must have one!'

    start_iter = 0
    # Accuracy stored in snapshots; seeded from the checkpoint on resume so a
    # snapshot taken before the first test still has a value.
    test_acc = None
    if args.ft_net:
        net = getNetwork(args.ft_net)
        print('===>Initate network from pretrained model:{}'.format(
            args.ft_net))
    if args.resume:
        print('===>Resuming from checkpoint:{} ..'.format(args.resume))
        assert os.path.isfile(
            args.resume), 'Error:no checkpoint:%s found' % args.resume
        # NOTE(review): the checkpoint pickles a whole module; torch >= 2.6
        # defaults torch.load to weights_only=True, which rejects that. Pass
        # weights_only=False there (trusted checkpoints only).
        checkpoint = torch.load(args.resume)
        net = checkpoint['net']
        start_iter = checkpoint['iteration']
        test_acc = checkpoint['acc']
        print('===>Resumed from checkpoint:{},start iteration:{} ..'.format(
            args.resume, start_iter))

    print('  + Number of params: {}'.format(
        sum(p.data.nelement() for p in net.parameters())))

    # Only the parameters of net.fc are handed to the optimizer.
    optimizer = _build_optimizer(args.opt, net.fc.parameters())

    if args.cuda:
        net = net.cuda()
        # device_count() honours CUDA_VISIBLE_DEVICES and counts devices,
        # not characters of the env string (and works when it is unset).
        if torch.cuda.device_count() > 1:
            net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    train_batches = _cycle_batches(trainLoader)
    for iteration in range(start_iter, args.max_iter):
        adjust_opt(args.opt, optimizer, iteration)

        train_batch = next(train_batches)
        train(args, iteration, net, train_batch, optimizer, viz,
              train_iter_plot)

        if iteration % args.test_interval == 0:
            test_acc = test(args, iteration, net, testLoader, viz,
                            test_iter_plot)
        if iteration % args.snapshot == 0:
            state = {
                # Store the bare module, never the DataParallel wrapper.
                'net': (net.module
                        if isinstance(net, torch.nn.DataParallel) else net),
                'iteration': iteration,
                'acc': test_acc
            }
            snapshot = ospj(args.save_folder,
                            args.proc_name + '_{}.pth'.format(iteration))
            torch.save(state, snapshot)
            print('Saving state, iter: {}, to {}'.format(iteration, snapshot))
示例#32
0
class Experiment(object):
    """
    Experiment class: holds a run's name, description, hyperparameters and
    metrics, and mirrors them to the default Visdom environment.
    """

    def __init__(self, name, hparams, desc=None):
        """

        Args:
            name (string): the name of the experiment
            hparams (dict): the hyperparameters used for this experiment;
                any mapping works, only ``.items()`` is used
            desc (string, optional): free-text description, shown in its own
                Visdom text window when given
        """
        self.name = name
        self.desc = desc
        self.hparams = hparams
        # Unknown metric names get a fresh Metric() on first access.
        self.metrics = defaultdict(Metric)

        self.timestamp_start = datetime.now()
        self.timestamp_update = datetime.now()
        self.last_update = time.time()

        self.viz = Visdom()

        # close() without a window id clears every window in the env before
        # this experiment's windows are drawn.
        self.viz.close()
        self.vis_params()

        if desc is not None:
            self.vis_desc()

    def update_plots(self):
        """Ask every registered metric to refresh its plot."""
        for metric in self.metrics.values():
            metric.update_plot()
        # NOTE: save_experiment() is deliberately not called here; it relies
        # on self.db / self.db_record, which this class never sets.

    def vis_params(self):
        """Show the hyperparameters as ``name: value`` lines in a text window.

        Nothing is drawn when there are no hyperparameters (``max`` over an
        empty sequence would raise ValueError).
        """
        lines = ["{}: {}".format(param, value)
                 for param, value in self.hparams.items()]
        if not lines:
            return
        # Rough sizing: ~10 px per character, 20 px per line.
        self.viz.text("<pre>{}</pre>".format("\n".join(lines)),
                      opts=dict(
                          width=max(len(x) for x in lines) * 10,
                          height=len(lines) * 20,
                      ))

    def vis_desc(self):
        """Show ``self.desc`` verbatim (inside <pre>) in a text window."""
        desc_lines = self.desc.split("\n")
        self.viz.text("<pre>{}</pre>".format(self.desc),
                      opts=dict(
                          width=max(len(x) for x in desc_lines) * 8.5,
                          height=len(desc_lines) * 20,
                      ))

    def add_metric(self, metric):
        """
        Add a metric to the experiment
        Args:
            metric (Metric): a metric object; it is stored under
                ``metric.name`` and given this experiment's Visdom client

        Returns:
            None
        """
        # NOTE(review): attribute name 'vic_context' kept as-is; confirm it
        # is what Metric reads (it may be a typo for 'viz_context').
        metric.vic_context = self.viz
        self.metrics[metric.name] = metric

    def get_score(self, metric, tag):
        """Return the latest value recorded under ``tag`` for ``metric``."""
        return self.metrics[metric]._values[tag][-1]

    def save_experiment(self):
        """
        Implement a saving mechanism (in text, csv or a database)

        NOTE(review): self.db and self.db_record are never assigned in this
        class, so this raises AttributeError unless they are set elsewhere.

        Returns:
            None
        """
        self.timestamp_update = datetime.now()
        self.db.update({
            'name': self.name,
            'desc': self.desc,
            'hparams': self.hparams,
            'metrics': self.metrics,
            'timestamp_start': self.timestamp_start,
            'timestamp_update': self.timestamp_update,
            'last_update': self.last_update
        }, doc_ids=[self.db_record])

    def visualize_experiment(self):
        raise NotImplementedError