예제 #1
0
def show(net, input_shape=(1, 3, 32, 32), logdir='tensorboard', port=8888):
    '''Print the network architecture in a web browser, using tensorboardX and TensorBoard.

    TensorBoard must be installed to use this tool.
    A random (noise) tensor of ``input_shape`` is fed directly to ``net`` so that
    tensorboardX can trace its structure; the network structure description is
    written to ``logdir``. A TensorBoard server is then launched to read
    ``logdir`` and serve it on ``127.0.0.1:<port>``.

    Notice:

        input shape must be NCHW, following pytorch style.

        ``logdir`` is removed with ``shutil.rmtree`` before the graph is written.

        This program overwrites the system argument list (sys.argv).

    Args:
        net: the ``torch.nn.Module`` to visualize.
        input_shape: shape (tuple or list) of the random probe input.
        logdir: directory the graph summary is written to.
        port: port the TensorBoard web server listens on.

        The arguments are passed through ``sort_args`` together with the types
        ``(torch.nn.Module, (tuple, list), str, int)`` -- presumably so they can
        be given in any order; confirm against ``sort_args``.

    Raises:
        Exception: if tensorboard cannot be imported, or if feeding the probe
            input through ``net`` fails. The original error is chained as the
            cause. Exceptions raised by TensorBoard's ``run_main`` are printed,
            not propagated.
    '''
    try:
        from tensorboard.main import run_main
    except Exception as e:
        # Keep the original error text, append a hint, and chain the cause so
        # the underlying traceback is not lost.
        raise Exception(
            str(e) +
            '\nError importing tensorboard. Maybe your tensorboard is not installed correctly. '
            'Usually, tensorboard should come with tensorflow. '
            'Stand-alone tensorboard packages are not stable enough.'
        ) from e
    import tensorboardX as tb
    import shutil
    import sys

    net, input_shape, logdir, port = sort_args(
        [net, input_shape, logdir, port],
        [torch.nn.Module, (tuple, list), str, int])

    # Start from an empty logdir, presumably so graphs from earlier runs do not show up.
    shutil.rmtree(logdir, ignore_errors=True)
    probe_input = torch.rand(*input_shape)
    writer = tb.SummaryWriter(logdir)
    try:
        writer.add_graph(net, probe_input)
    except Exception as e:
        raise Exception(
            str(e) + '\nYour network has problems dealing with input data. '
            'It is usually due to wrong input shape or problematic network implementation. '
            'Please check your network code for more information.') from e
    finally:
        writer.close()

    # run_main() parses its options from sys.argv, so the command line is replaced here.
    sys.argv = [
        'tensorboard', '--logdir', logdir, '--port',
        str(port), '--host', '127.0.0.1'
    ]
    print('You may have to delete', logdir, 'folder manually.')
    omini_open('http://127.0.0.1:%d' % port)
    try:
        run_main()
    except Exception as e:
        # Best effort: report TensorBoard failures without raising.
        print(e)
예제 #2
0
def main():
  """Report the PID, start the "MemoryTracker" daemon thread, then run TensorBoard.

  The thread runs ``track_memory_forever`` and is handed a semaphore created
  with an initial value of 0. The main thread blocks on that semaphore until
  it is released -- presumably by the tracker once it is set up; verify in
  ``track_memory_forever``. Only after that is TensorBoard imported and its
  ``run_main`` invoked.
  """
  pid_line = "PID: %d\n" % (os.getpid(),)
  sys.stderr.write(pid_line)

  tracker_ready = threading.Semaphore(0)
  tracker = threading.Thread(
    target=track_memory_forever,
    args=(tracker_ready,),
    name="MemoryTracker",
    daemon=True,
  )
  tracker.start()
  # Wait for the readiness signal before starting the server.
  tracker_ready.acquire()

  from tensorboard import main as tb_main

  tb_main.run_main()
예제 #3
0
# -*- coding: utf-8 -*-
import re
import sys

# Drop sys.path[0] (the directory containing this script) before importing
# tensorboard -- presumably so a local module cannot shadow the installed
# ``tensorboard`` package.
sys.path = sys.path[1:]
from tensorboard.main import run_main

if __name__ == '__main__':
    # Strip a launcher suffix ("-script.py", "-script.pyw" or ".exe") from the
    # program name, matching the pattern used elsewhere in this file.
    sys.argv[0] = re.sub(r'(-script\.pyw?|\.exe)?$', '', sys.argv[0])
    sys.exit(run_main())



예제 #4
0
 def run(self):
     """Launch TensorBoard on ``self.logdir`` and exit with its return value.

     ``sys.argv`` is replaced by an empty program name followed by a single
     ``--logdir=<self.logdir>`` option before control passes to ``run_main``.
     """
     import sys
     from tensorboard.main import run_main

     logdir_option = '--logdir=' + self.logdir
     sys.argv = ['', logdir_option]
     exit_status = run_main()
     sys.exit(exit_status)
예제 #5
0
def cli():
    """Command-line entry point: apply ``patch_mkl()`` on Windows, then run ``main.run_main()``."""
    on_windows = 'win32' in sys.platform
    if on_windows:
        patch_mkl()
    main.run_main()
예제 #6
0
def spawn_daemon(env):
    """Run TensorBoard on the ``LOGS_DIR`` folder under ``env.root``.

    ``sys.argv`` is rewritten to ``[program, '--logdir', <logdir>] + sys.argv[3:]``.
    The two original arguments after the program name are dropped -- presumably
    the sub-command that dispatched here; confirm against the caller.
    """
    from tensorboard.main import run_main

    logdir = os.path.join(env.root, LOGS_DIR)
    program = sys.argv[0]
    passthrough = sys.argv[3:]
    sys.argv = [program, '--logdir', logdir] + passthrough
    run_main()
예제 #7
0
 def __init__(self, log_folder: str):
     """Create a summary writer for ``log_folder`` and launch ``tb.run_main()``.

     Args:
         log_folder: directory handed both to ``tf.summary.FileWriter`` and to
             the ``logdir`` entry of ``tf.flags.FLAGS``.
     """
     # Writer kept on the instance; uses the TF1-style ``tf.summary.FileWriter`` API.
     self._writer = tf.summary.FileWriter(log_folder)
     # Point the ``logdir`` flag at the same folder before ``tb`` is launched.
     tf.flags.FLAGS.logdir = log_folder
     # NOTE(review): presumably ``tb`` is ``tensorboard.main``; if so, run_main()
     # serves until shutdown, so this constructor would not return while the
     # server runs. Confirm that is intended.
     tb.run_main()