Пример #1
0
 def __init__(self, graph,
              save_steps=None,
              save_secs=None,
              output_dir="", suffix=""):
   """Build a hook that captures profiling snapshots on a timer.

   The constructor only prepares state: a trace-file name template, a summary
   writer, a step/seconds timer, a profiler bound to `graph`, and a profile
   option builder configured for a memory ("bytes") report.

   Args:
     graph: graph handed to `model_analyzer.Profiler`.
     save_steps: `int`, take a snapshot every N steps. Exactly one of
         `save_secs` and `save_steps` should be set.
     save_secs: `int` or `float`, take a snapshot every N seconds.
     output_dir: `string`, directory joined with the trace file name and
         passed to `SummaryWriterCache.get`. Defaults to "", which makes the
         trace path relative.
     suffix: stored as `self._suffix`; how it is consumed is not visible in
         this constructor.
   """
   self._suffix = suffix
   # The two "{}" placeholders are left unformatted here; they are filled in
   # by code outside this constructor.
   trace_name_template = "profile-{}-{}.txt"
   self._output_file = os.path.join(output_dir, trace_name_template)
   self._file_writer = SummaryWriterCache.get(output_dir)
   self._timer = tf.train.SecondOrStepTimer(
       every_secs=save_secs, every_steps=save_steps)
   self._profiler = model_analyzer.Profiler(graph=graph)
   # Report memory usage, sorted by bytes. To rank by time instead, select
   # ['micros', 'occurrence'] and order by 'micros'.
   self._profile_op_builder = (
       option_builder.ProfileOptionBuilder()
       .select(['bytes'])
       .order_by('bytes')
       .with_max_depth(10))  # can be any large number
Пример #2
0
 def profiler(self):
      """Return the profiler object, creating it on first use.

      The check and the creation both happen while holding `self._lock`. A new
      `model_analyzer.Profiler` is built for the current default graph only
      when `self._profiler` is falsy.
      """
      with self._lock:
          cached = self._profiler
          if cached:
              return cached
          self._profiler = model_analyzer.Profiler(
              ops.get_default_graph())
          return self._profiler
Пример #3
0
 def profiler(self):
      """Return the profiler object, or None when profiling is disabled.

      When `self._enabled` is truthy, a `model_analyzer.Profiler` for the
      current default graph is created the first time it is needed and then
      reused.
      """
      if self._enabled:
          if not self._profiler:
              self._profiler = model_analyzer.Profiler(
                  ops.get_default_graph())
          return self._profiler
      return None
Пример #4
0
 def __enter__(self):
      """Enter the context, preparing TensorFlow tracing when profiling is on.

      When `self.profile` is truthy, this stores on the instance a
      `model_analyzer.Profiler` for `self.sess.graph`, FULL_TRACE run options
      and an empty `RunMetadata`. TensorFlow older than 1.15 uses the
      top-level `tf.RunOptions` / `tf.RunMetadata`; newer versions use the
      `tf.compat.v1` equivalents.

      Returns:
        This object, so it can be bound by `with ... as`.
      """
      if self.profile:
          self.profiler = model_analyzer.Profiler(graph=self.sess.graph)
          # Compare versions numerically. As plain strings "1.9.0" sorts
          # after "1.15.0", which would wrongly select the compat.v1 branch.
          major_minor = tuple(
              int(part) for part in tf.__version__.split(".")[:2])
          if major_minor < (1, 15):
              self.run_options = tf.RunOptions(
                  trace_level=tf.RunOptions.FULL_TRACE)
              self.run_metadata = tf.RunMetadata()
          else:
              self.run_options = tf.compat.v1.RunOptions(
                  trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
              self.run_metadata = tf.compat.v1.RunMetadata()
      return self
Пример #5
0
def read_data(cycle_length, num_splits, data_dir_index, batch_size,
              num_threads):
    """Time repeated fetches from the parallel input pipeline.

    Builds an iterator with `parallel_read_data`, runs its initializer with
    full tracing (recorded as profiler step 0), then fetches
    `2 * num_splits` batches. Each fetch's wall-clock time is printed and
    appended to
    `./parallel_results/<cycle_length>-<num_splits>-<batch_size>-<data_dir_index>.txt`.

    Args:
        cycle_length: forwarded to `parallel_read_data`; also part of the
            results file name.
        num_splits: number of fetches per outer pass; also divides
            `batch_size` to get the per-split batch size.
        data_dir_index: selects the input directory `./TFRecords<index>/`.
        batch_size: total batch size forwarded to `parallel_read_data`.
        num_threads: forwarded to `parallel_read_data`.
    """
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    loop = 2
    data_dir = './TFRecords{}/'.format(data_dir_index)
    # Floor division keeps the per-split size an integer on Python 3 as well
    # (the original `/` relied on Python 2 integer division).
    batch_size_per_split = batch_size // num_splits

    sess = tf.Session()
    try:
        test_profiler = model_analyzer.Profiler(graph=sess.graph)
        with tf.name_scope('batch_processing'):
            ds_iterator = parallel_read_data(
                data_dir=data_dir,
                batch_size=batch_size,
                batch_size_per_split=batch_size_per_split,
                num_splits=num_splits,
                cycle_length=cycle_length,
                num_threads=num_threads)

            sess.run(ds_iterator.initializer,
                     options=run_options,
                     run_metadata=run_metadata)
            test_profiler.add_step(step=0, run_meta=run_metadata)

            # Create the fetch op once. Calling get_next() inside the loop
            # would add new nodes to the graph on every iteration.
            images = ds_iterator.get_next()

            results_path = './parallel_results/{}-{}-{}-{}.txt'.format(
                cycle_length, num_splits, batch_size, data_dir_index)
            with open(results_path, 'w') as f_out:
                for d in range(loop):
                    for k in range(num_splits):
                        start = time.time()
                        temp = sess.run(images,
                                        options=run_options,
                                        run_metadata=run_metadata)
                        end = time.time()
                        # Per-fetch profiler steps are intentionally not
                        # recorded; only the initializer run (step 0) is.
                        print(len(temp))
                        print('time is: {}'.format(end - start))
                        print('End {}-{}'.format(d, k))
                        f_out.write(
                            '{}-{} {}'.format(d, k, end - start) + '\n')
    finally:
        # Release the session's resources even if a fetch fails.
        sess.close()
 def __enter__(self):
      """Start the profiler backend named by `self.profiler`.

      * "nvprof": loads `libcudart.so` through ctypes and calls
        `cudaProfilerStart()`.
      * "pyprof": creates and enables a `cProfile.Profile`.
      * "native": creates a `model_analyzer.Profiler` for `self._sess.graph`
        plus FULL_TRACE run options and an empty run-metadata object.

      Any other value starts nothing.

      Returns:
        This object, so it can be bound by `with ... as`.
      """
      backend = self.profiler
      if backend == "native":
          self._profiler_handle = model_analyzer.Profiler(
              graph=self._sess.graph)
          trace_level = tf.compat.v1.RunOptions.FULL_TRACE
          self.run_options = tf.compat.v1.RunOptions(trace_level=trace_level)
          self.run_metadata = tf.compat.v1.RunMetadata()
      elif backend == "pyprof":
          import cProfile
          self._profiler_handle = cProfile.Profile()
          self._profiler_handle.enable()
      elif backend == "nvprof":
          import ctypes
          self._cudart = ctypes.CDLL('libcudart.so')
          self._cudart.cudaProfilerStart()
      return self
Пример #7
0
 def __enter__(self):
      """Start the profiler backend named by `self.profiler`.

      * "pyprof": creates and enables a `cProfile.Profile`.
      * "none": starts nothing.
      * any other value: creates a `model_analyzer.Profiler` for
        `self.sess.graph` plus FULL_TRACE run options and an empty run
        metadata object. TensorFlow older than 1.15 uses the top-level
        `tf.RunOptions` / `tf.RunMetadata`; newer versions use the
        `tf.compat.v1` equivalents.

      Returns:
        This object, so it can be bound by `with ... as`.
      """
      if self.profiler == "pyprof":
          import cProfile
          self.profiler_handle = cProfile.Profile()
          self.profiler_handle.enable()
      elif self.profiler != "none":
          self.profiler_handle = model_analyzer.Profiler(
              graph=self.sess.graph)
          # Compare versions numerically. As plain strings "1.9.0" sorts
          # after "1.15.0", which would wrongly select the compat.v1 branch.
          major_minor = tuple(
              int(part) for part in tf.__version__.split(".")[:2])
          if major_minor < (1, 15):
              self.run_options = tf.RunOptions(
                  trace_level=tf.RunOptions.FULL_TRACE)
              self.run_metadata = tf.RunMetadata()
          else:
              self.run_options = tf.compat.v1.RunOptions(
                  trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
              self.run_metadata = tf.compat.v1.RunMetadata()
      return self
Пример #8
0
  def testMultipleProfilePerStep(self):
    """Check that two run-metadata records added under one step are merged.

    The training run's metadata is added as step 0 first, and the initializer
    node is expected to show zero time and memory. After the variable
    initialization metadata is also added as step 0, the same node must
    report positive values.
    """
    ops.reset_default_graph()
    selected_columns = ['micros', 'bytes', 'peak_bytes',
                        'residual_bytes', 'output_bytes']
    opts = (builder(builder.trainable_variables_parameter())
            .with_empty_output()
            .with_accounted_types(['.*'])
            .select(selected_columns)
            .build())

    r = lib.BuildSmallModel()
    sess = session.Session()
    profiler = model_analyzer.Profiler(sess.graph)

    # Both session runs are traced with the same FULL_TRACE options.
    full_trace = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    init_node = 'DW/Initializer/random_normal/RandomStandardNormal'
    checked_fields = ('exec_micros', 'requested_bytes', 'peak_bytes',
                      'residual_bytes')

    init_var_run_meta = config_pb2.RunMetadata()
    sess.run(variables.global_variables_initializer(),
             options=full_trace,
             run_metadata=init_var_run_meta)

    train_run_meta = config_pb2.RunMetadata()
    sess.run(r, options=full_trace, run_metadata=train_run_meta)

    # Without the var initialization run_meta, the profiler has no
    # information about variable initialization.
    profiler.add_step(0, train_run_meta)
    n1 = lib.SearchTFProfNode(profiler.profile_name_scope(opts), init_node)
    for field in checked_fields:
      self.assertEqual(getattr(n1, field), 0)

    # After adding the var initialization run_meta under the same step.
    profiler.add_step(0, init_var_run_meta)
    n2 = lib.SearchTFProfNode(profiler.profile_name_scope(opts), init_node)
    for field in checked_fields:
      self.assertGreater(getattr(n2, field), 0)
Пример #9
0
 def build_networks(self):
      """Build the YOLO_tiny graph and set up a session with profiling state.

      Creates the input placeholder, the conv/pool stack (layers 1-15) and
      three fully connected layers (16, 17, 19), storing each on `self`. It
      then opens a session, initializes all variables, and creates a saver, a
      summary writer for './cachelogs', a `model_analyzer.Profiler`,
      FULL_TRACE run options and an empty `RunMetadata`.
      """
      if self.disp_console:
          print("Building YOLO_tiny graph...")
      self.x = tf.placeholder('float32', [None, 448, 448, 3])
      self.conv_1 = self.conv_layer(1, self.x, 16, 3, 1)
      self.pool_2 = self.pooling_layer(2, self.conv_1, 2, 2)
      self.conv_3 = self.conv_layer(3, self.pool_2, 32, 3, 1)
      self.pool_4 = self.pooling_layer(4, self.conv_3, 2, 2)
      self.conv_5 = self.conv_layer(5, self.pool_4, 64, 3, 1)
      self.pool_6 = self.pooling_layer(6, self.conv_5, 2, 2)
      self.conv_7 = self.conv_layer(7, self.pool_6, 128, 3, 1)
      self.pool_8 = self.pooling_layer(8, self.conv_7, 2, 2)
      self.conv_9 = self.conv_layer(9, self.pool_8, 256, 3, 1)
      self.pool_10 = self.pooling_layer(10, self.conv_9, 2, 2)
      self.conv_11 = self.conv_layer(11, self.pool_10, 512, 3, 1)
      self.pool_12 = self.pooling_layer(12, self.conv_11, 2, 2)
      self.conv_13 = self.conv_layer(13, self.pool_12, 1024, 3, 1)
      self.conv_14 = self.conv_layer(14, self.conv_13, 1024, 3, 1)
      self.conv_15 = self.conv_layer(15, self.conv_14, 1024, 3, 1)
      self.fc_16 = self.fc_layer(16,
                                 self.conv_15,
                                 256,
                                 flat=True,
                                 linear=False)
      self.fc_17 = self.fc_layer(17,
                                 self.fc_16,
                                 4096,
                                 flat=False,
                                 linear=False)
      # Layer 18 (dropout) is skipped.
      self.fc_19 = self.fc_layer(19,
                                 self.fc_17,
                                 1470,
                                 flat=False,
                                 linear=True)
      self.sess = tf.Session()
      # tf.initialize_all_variables() is deprecated; this is its replacement.
      self.sess.run(tf.global_variables_initializer())
      self.saver = tf.train.Saver()
      self.writer = tf.summary.FileWriter('./cachelogs',
                                          tf.get_default_graph())
      self.profiler = model_analyzer.Profiler(graph=self.sess.graph)
      self.run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
      self.run_metadata = tf.RunMetadata()
      # Restoring weights from self.weights_file is disabled here, so the
      # variables keep their initializer values.
      if self.disp_console:
          print("Loading complete!" + '\n')
Пример #10
0
    def testEager(self):
        """Profile a small model run eagerly and check the reported ops.

        Run metadata is collected while `lib.BuildSmallModel()` executes in
        eager mode and added to a graph-less profiler as step 0. Both the
        text report written to a temp file and the serialized profile proto
        must mention 'Conv2D' and 'VarHandleOp'.
        """
        expected_ops = ('Conv2D', 'VarHandleOp')
        ops.reset_default_graph()
        with context.eager_mode():
            dump_path = os.path.join(test.get_temp_dir(), 'dump')
            file_opts = builder(
                builder.time_and_memory()).with_file_output(dump_path).build()
            context.enable_run_metadata()
            lib.BuildSmallModel()

            eager_profiler = model_analyzer.Profiler()
            eager_profiler.add_step(0, context.export_run_metadata())
            context.disable_run_metadata()
            eager_profiler.profile_operations(file_opts)
            with gfile.Open(dump_path, 'r') as dump_file:
                report = dump_file.read()
                for op_name in expected_ops:
                    self.assertTrue(op_name in report)

            # The file is opened for binary writing, but nothing is written
            # to it in this block.
            with gfile.Open('/tmp/eager_profile', 'wb'):
                parsed = tfprof_log_pb2.ProfileProto()
                parsed.ParseFromString(eager_profiler.serialize_to_string())
                parsed_text = '%s' % parsed
                for op_name in expected_ops:
                    self.assertTrue(op_name in parsed_text)
Пример #11
0
def main(argv=None):  # pylint: disable=unused-argument
    """Evaluate one checkpoint, or a strided series of them, with optional profiling.

    The network and data loader are chosen from the module-level `args`.
    Without `args.batch_eval`, only `args.ckpt` is restored and evaluated.
    With it, checkpoints found in the run folder are evaluated from the
    newest backwards, keeping those at least `args.step` apart and not below
    `max(args.min_ckpt, last evaluated + 1)`; each result is appended to the
    evaluation log. When `args.use_profile` is set, a name-scope memory
    profile is produced at the end.

    Args:
        argv: unused.

    Raises:
        ValueError: if `args.trunk` or `args.dataset` is not a supported value.
    """
    assert args.ckpt > 0 or args.batch_eval
    assert args.detect or args.segment, "Either detect or segment should be True"

    # An unknown value used to leave `net` / `loader` unbound and fail later
    # with a NameError; fail early with a clear message instead.
    trunks = {
        'resnet50': (ResNet, 50),
        'resnet101': (ResNet, 101),
        'vgg16': (VGG, 16),
    }
    if args.trunk not in trunks:
        raise ValueError('Unsupported trunk: %r' % (args.trunk,))
    net_cls, depth = trunks[args.trunk]
    net = net_cls(config=net_config, depth=depth, training=False)

    if args.dataset in ('voc07', 'voc07+12'):
        loader = VOCLoader('07', 'test')
    elif args.dataset == 'voc12':
        loader = VOCLoader('12', 'val', segmentation=args.segment)
    elif args.dataset == 'coco':
        loader = COCOLoader(args.split)
    else:
        raise ValueError('Unsupported dataset: %r' % (args.dataset,))

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False,
                                          gpu_options=tf.GPUOptions(allow_growth=True,
                                                                    per_process_gpu_memory_fraction=0.2))) as sess:
        if args.use_profile:
            profiler = model_analyzer.Profiler(graph=sess.graph)
            detector = Detector(sess, net, loader, net_config,
                                no_gt=args.no_seg_gt, profiler=profiler)
        else:
            detector = Detector(sess, net, loader, net_config,
                                no_gt=args.no_seg_gt)

        if args.dataset == 'coco':
            tester = COCOEval(detector, loader)
        else:
            tester = Evaluation(detector, loader, iou_thresh=args.voc_iou_thresh)
        if not args.batch_eval:
            detector.restore_from_ckpt(args.ckpt)
            tester.evaluate_network(args.ckpt)
        else:
            log.info('Evaluating %s' % args.run_name)
            ckpts_folder = CKPT_ROOT + args.run_name + '/'
            out_file = ckpts_folder + evaluation_logfile

            max_checked = get_last_eval(out_file)
            log.debug("Maximum checked ckpt is %i" % max_checked)
            with open(out_file, 'a') as f:
                start = max(args.min_ckpt, max_checked+1)
                ckpt_files = glob(ckpts_folder + '*.data*')
                folder_has_nums = np.array(
                    [filename2num(path) for path in ckpt_files], dtype='int')
                nums_available = sorted(folder_has_nums[folder_has_nums >= start])
                if nums_available:
                    # Walk from the newest checkpoint backwards, keeping
                    # those at least `args.step` away from the last kept one.
                    nums_to_eval = [nums_available[-1]]
                    for n in reversed(nums_available):
                        if nums_to_eval[-1] - n >= args.step:
                            nums_to_eval.append(n)
                    nums_to_eval.reverse()
                else:
                    # Previously this case raised IndexError.
                    log.info('No checkpoints numbered >= %i in %s'
                             % (start, ckpts_folder))
                    nums_to_eval = []

                for ckpt in nums_to_eval:
                    log.info("Evaluation of ckpt %i" % ckpt)
                    tester.reset()
                    detector.restore_from_ckpt(ckpt)
                    res = tester.evaluate_network(ckpt)
                    f.write(res)
                    f.flush()

        if args.use_profile:
            # Memory ('bytes') per name scope: depth limit 4, at least 2e6
            # bytes, restricted to profiler step 2.
            profile_scope_builder = option_builder.ProfileOptionBuilder()
            profile_scope_builder.with_max_depth(4)
            profile_scope_builder.with_min_memory(int(2e6))
            profile_scope_builder.with_step(2)
            profile_scope_builder.select(['bytes'])
            detector.profiler.profile_name_scope(profile_scope_builder.build())
Пример #12
0
    x = tf.placeholder(tf.float32, shape=[1, image_size, image_size, 3])
    with slim.arg_scope(inception.inception_v3_arg_scope()):
        logits, end_points = inception.inception_v3(x,
                                                    num_classes=1001,
                                                    is_training=False)
    probabilities = tf.nn.softmax(logits)

    init_fn = slim.assign_from_checkpoint_fn(
        checkpoints, slim.get_model_variables('InceptionV3'))

    results = {}
    with tf.Session() as sess:
        init_fn(sess)

        #profiler
        inception_profiler = model_analyzer.Profiler(graph=sess.graph)
        options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()

        profile_scope_opt_builder = option_builder.ProfileOptionBuilder(
            option_builder.ProfileOptionBuilder.float_operation())
        inception_profiler.profile_name_scope(
            profile_scope_opt_builder.build())

        #https://upload.wikimedia.org/wikipedia/commons/d/d9/First_Student_IC_school_bus_202076.jpg
        for f in (glob.glob("First_Student_IC_school_bus_202076.jpg")):
            img = image2placeholder(f, image_size)
            probabilities = sess.run(
                probabilities,
                feed_dict={x: img},
                options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
Пример #13
0
def train():
    """Train CIFAR-10 in a monitored session, then print tfprof reports.

    Every `FLAGS.log_frequency`-th step is run with full tracing and added to
    the profiler. Once the session asks to stop, six reports are produced:
    graph view with a timeline file, parameters per name scope, operations
    ranked by time, operations ranked by memory, a Python code view, and the
    profiler's advice.
    """
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()

        # Keep the input pipeline on CPU:0 so its ops do not sometimes land
        # on the GPU and slow things down.
        with tf.device('/cpu:0'):
            images, labels = cifar10.distorted_inputs()

        # Inference graph -> loss -> op that trains on one batch and updates
        # the model parameters.
        logits = cifar10.inference(images)
        loss = cifar10.loss(logits, labels)
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Request the loss value alongside each run.
                return tf.train.SessionRunArgs(loss)

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency != 0:
                    return
                now = time.time()
                elapsed = now - self._start_time
                self._start_time = now

                loss_value = run_values.results
                examples_per_sec = (
                    FLAGS.log_frequency * FLAGS.batch_size / elapsed)
                sec_per_batch = float(elapsed / FLAGS.log_frequency)

                format_str = ('%s: step %d, loss = %.2f '
                              '(%.1f examples/sec; %.3f sec/batch)')
                print(format_str % (datetime.now(), self._step, loss_value,
                                    examples_per_sec, sec_per_batch))

        session_hooks = [
            tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
            tf.train.NanTensorHook(loss),
            _LoggerHook(),
        ]
        session_config = tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement,
            allow_soft_placement=True)

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                save_checkpoint_secs=10,  # Save checkpoint by interval
                save_summaries_steps=10,  # Save summary by interval
                hooks=session_hooks,
                config=session_config) as mon_sess:

            # Profiler plus the full-trace options used on sampled steps.
            cifar_profiler = model_analyzer.Profiler(graph=mon_sess.graph)
            run_options = tf.RunOptions(
                trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()

            step = 0
            while not mon_sess.should_stop():
                traced = step % FLAGS.log_frequency == 0
                if not traced:
                    mon_sess.run(train_op)
                else:
                    mon_sess.run(train_op,
                                 options=run_options,
                                 run_metadata=run_metadata)
                    cifar_profiler.add_step(step=step, run_meta=run_metadata)
                step += 1

            # Profiler section
            #   1. Each graph node's execution time and consumed memory
            #   2. Each layer's parameters, model size and distribution
            #   3. Most time-consuming operations
            #   4. Most memory-consuming operations
            #   5. Python code performance, line by line
            #   6. Optimization advice
            builder_cls = option_builder.ProfileOptionBuilder

            # 1. Graph view of time and memory, also written as a timeline
            # under <parent of this file's directory>/logs/cifar10_profiler/.
            # NOTE(review): only steps where step % FLAGS.log_frequency == 0
            # were traced above; confirm the midpoint step is one of them.
            project_root = os.path.dirname(
                os.path.dirname(os.path.abspath(__file__)))
            timeline_path = os.path.join(
                project_root, 'logs/cifar10_profiler/cifar10_profiler.json')
            graph_opts = (
                builder_cls(builder_cls.time_and_memory())
                .with_timeline_output(timeline_file=timeline_path)
                .with_step((FLAGS.max_steps - 1) // 2)
                .build())
            cifar_profiler.profile_graph(graph_opts)

            # 2. Parameters per name scope, nesting depth limited to 4,
            # sorted by parameter count.
            scope_opts = (
                builder_cls(builder_cls.trainable_variables_parameter())
                .with_max_depth(4)
                .select(['params'])
                .order_by('params')
                .build())
            cifar_profiler.profile_name_scope(scope_opts)

            # 3. Operations ranked by execution time, with occurrence
            # counts; max depth is set to 4.
            time_opts = (
                builder_cls()
                .select(['micros', 'occurrence'])
                .order_by('micros')
                .with_max_depth(4)
                .build())
            cifar_profiler.profile_operations(time_opts)

            # 4. Operations ranked by consumed memory, with occurrence
            # counts; max depth is set to 4.
            memory_opts = (
                builder_cls()
                .select(['bytes', 'occurrence'])
                .order_by('bytes')
                .with_max_depth(4)
                .build())
            cifar_profiler.profile_operations(memory_opts)

            # 5. Python code view for nodes matching the cifar10 regex,
            # keeping entries with a minimum execution time of 10 micros.
            code_opts = (
                builder_cls()
                .with_max_depth(1000)
                .with_node_names(show_name_regexes=[r'cifar10[\s\S]*'])
                .with_min_execution_time(min_micros=10)
                .select(['micros'])
                .order_by('micros')
                .build())
            cifar_profiler.profile_python(code_opts)

            # 6. Optimization advice.
            cifar_profiler.advise(options=model_analyzer.ALL_ADVICE)
Пример #14
0
    ckpt = tf.train.latest_checkpoint(hp.logdir)
    if ckpt is None:
        logging.warning("No checkpoint is found")
        exit(1)
    else:
        saver.restore(sess, ckpt)

    logging.info("# test evaluation")
    sess.run(eval_init_op)
    # It means 跑op=y_hat,输入是num_eval_batches,只截取前num_eval_samples个结果
    logging.info("# get hypotheses")
    if hp.use_profile:
        logging.info("# init profile")
        run_metadata = tf.RunMetadata()
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        mnist_profiler = model_analyzer.Profiler(graph=sess.graph)
        ts = time.time()
        hypotheses = get_hypotheses(num_eval_batches, num_eval_samples, sess, y_hat, m.idx2token,
                                    use_profile=True,
                                    options=run_options, run_metadata=run_metadata, profiler=mnist_profiler)
        logging.info("eval: takes %s" % (time.time() - ts))
        # 统计内容为每个graph node的运行时间和占用内存
        profile_graph_opts_builder = option_builder.ProfileOptionBuilder(
            option_builder.ProfileOptionBuilder.time_and_memory())

        # 输出方式为timeline
        profile_graph_opts_builder.with_timeline_output(timeline_file='/tmp/mnist_profiler.json')
        # 定义显示sess.Run() 第0步的统计数据
        profile_graph_opts_builder.with_step(0)
        profile_graph_opts_builder.with_step(1)
        # 显示视图为graph view
Пример #15
0
def main(_):
    """Run two traced MNIST training steps and print tfprof reports.

    Builds the deep-net graph, loss, Adam training op and accuracy op, writes
    the graph to a temporary directory, then runs two training steps with
    full tracing. After each step the run metadata is added to a
    `model_analyzer.Profiler`, an operations report sorted by time is
    printed, and a graph-view report with timeline output is produced.

    Args:
        _: unused positional argument.
    """
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir)

    # Create the model
    x = tf.placeholder(tf.float32, [None, 784])

    # Define loss and optimizer
    y_ = tf.placeholder(tf.int64, [None])

    # Build the graph for the deep net
    y_conv, keep_prob = deepnn(x)

    with tf.name_scope('loss'):
        cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_,
                                                               logits=y_conv)
    cross_entropy = tf.reduce_mean(cross_entropy)

    with tf.name_scope('adam_optimizer'):
        train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(y_conv, 1), y_)
        correct_prediction = tf.cast(correct_prediction, tf.float32)
    accuracy = tf.reduce_mean(correct_prediction)

    # Write the graph definition to a fresh temporary directory.
    graph_location = tempfile.mkdtemp()
    print('Saving graph to: %s' % graph_location)
    train_writer = tf.summary.FileWriter(graph_location)
    train_writer.add_graph(tf.get_default_graph())

    # Profiler modules are imported locally, just before they are used.
    from tensorflow.python.profiler import model_analyzer
    from tensorflow.python.profiler import option_builder
    with tf.Session(config=get_sess_config()) as sess:
        sess.run(tf.global_variables_initializer())
        profiler = model_analyzer.Profiler(sess.graph)
        # Only two steps are run here; a 20000-step loop is left disabled.
        #for i in range(20000):
        for i in range(2):
            batch = mnist.train.next_batch(21000)
            # The string below is disabled code (periodic training-accuracy
            # logging); as a bare string statement it has no effect.
            '''if i % 100 == 0:
        train_accuracy = accuracy.eval(feed_dict={
            x: batch[0], y_: batch[1], keep_prob: 1.0})
        print('step %d, training accuracy %g' % (i, train_accuracy))
      '''
            # Untraced alternative, disabled in favour of the traced run:
            #train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
            # Fresh metadata per step; the run is traced at FULL_TRACE level.
            run_metadata = tf.RunMetadata()
            sess.run(
                train_step,
                feed_dict={
                    x: batch[0],
                    y_: batch[1],
                    keep_prob: 0.5
                },
                options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                run_metadata=run_metadata)
            profiler.add_step(i, run_metadata)

            # Profile the timing of the model operations: time, memory and
            # occurrence columns, sorted by time.
            opts = (tf.profiler.ProfileOptionBuilder(
                option_builder.ProfileOptionBuilder.time_and_memory()).select(
                    ['micros', 'bytes',
                     'occurrence']).order_by('micros').build())
            profiler.profile_operations(options=opts)
            # The string below is disabled code (Python code view with
            # timeline output); as a bare string statement it has no effect.
            '''
      opts = (option_builder.ProfileOptionBuilder(
        option_builder.ProfileOptionBuilder.time_and_memory())
        .with_step(i)
        .with_timeline_output("./timeline_output/code_step").build())
      profiler.profile_python(options=opts)
      '''

            # Graph view for the current step, with timeline output using
            # the "./timeline_output/step" path.
            opts = (option_builder.ProfileOptionBuilder(
                option_builder.ProfileOptionBuilder.time_and_memory()).
                    with_step(i).with_timeline_output(
                        "./timeline_output/step").build())
            profiler.profile_graph(options=opts)

        # Test-accuracy evaluation is disabled:
        #print('test accuracy %g' % accuracy.eval(feed_dict={
        #    x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
    # Disabled: print to stdout an analysis of the memory usage and the
    # timing information broken down by Python code.
    # ProfileOptionBuilder = tf.profiler.ProfileOptionBuilder
    # opts = ProfileOptionBuilder(ProfileOptionBuilder.time_and_memory()
    #    ).with_node_names(show_name_regexes=['*']).build()
    #).with_node_names(show_name_regexes=['.*my_code.py.*']).build()
    '''tf.profiler.profile(
      tf.get_default_graph(),
      run_meta=run_metadata,
      cmd='code',
     options=opts)
  '''
    '''
Пример #16
0
def test_lanenet_for_eval(image_path, weights_path):
    """Run LaneNet on one image, profile the inference and show the results.

    The image is resized to 512x256 and normalised, then pushed through the
    network twice: once untraced and once with ``FULL_TRACE`` enabled so the
    TensorFlow profiler can write a per-operation report to
    ``./op_profiler.txt`` and a per-Python-line report to
    ``./code_profiler.txt``. The post-processed mask, the source image, the
    embedding image and the binary mask are displayed with matplotlib, and
    three PNG files are written to the current working directory.

    :param image_path: path of the test image
    :param weights_path: path of the trained model checkpoint to restore
    :return: None
    """
    assert ops.exists(image_path), '{:s} not exist'.format(image_path)

    log.info('Start reading image and preprocessing')
    t_start = time.time()
    # cv2.IMREAD_COLOR loads a colour image; any transparency is neglected.
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    # Keep a reference to the unmodified image for visualisation.
    image_vis = image

    # INTER_LINEAR: bilinear interpolation.
    image = cv2.resize(image, (512, 256), interpolation=cv2.INTER_LINEAR)

    # Normalise the pixel values (the array shape is left unchanged).
    image = image / 127.5 - 1.0
    log.info('Image load complete, cost time: {:.5f}s'.format(time.time() - t_start))

    # The placeholder only reserves the input slot while the graph is built;
    # the actual image is supplied through feed_dict when the session runs.
    # Shape is NHWC: [batch, height, width, channels].
    input_tensor = tf.placeholder(dtype=tf.float32, shape=[1, 256, 512, 3], name='input_tensor')

    net = lanenet.LaneNet(phase='test', net_flag='vgg')

    binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='lanenet_model')

    postprocessor = lanenet_postprocess.LaneNetPostProcessor()

    # Used below to load the pre-trained model parameters.
    saver = tf.train.Saver()

    # Set session configuration
    sess_config = tf.ConfigProto()
    # Limit the fraction of GPU memory this process may use.
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
    # Allocate GPU memory on demand.
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    # Allocator type: BFC = best fit with coalescing.
    sess_config.gpu_options.allocator_type = 'BFC'

    # The context manager installs the session as default and guarantees it
    # is closed even when restoring, inference or plotting raises.
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess=sess, save_path=weights_path)

        fetches = [binary_seg_ret, instance_seg_ret]
        feed = {input_tensor: [image]}

        # Untimed first run; its outputs are discarded (presumably a warm-up
        # so one-off start-up cost does not distort the timed run below).
        sess.run(fetches, feed_dict=feed)

        profiler = model_analyzer.Profiler(graph=sess.graph)
        run_metadata = tf.RunMetadata()

        t_start = time.time()

        binary_seg_image, instance_seg_image = sess.run(
            fetches,
            feed_dict=feed,
            options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
            run_metadata=run_metadata
        )

        t_cost = time.time() - t_start
        log.info('Single imgae inference cost time: {:.5f}s'.format(t_cost))

        profiler.add_step(step=1, run_meta=run_metadata)

        # Per-operation-type report, slowest first.
        profile_op_builder = option_builder.ProfileOptionBuilder()
        profile_op_builder.select(['micros', 'occurrence'])
        profile_op_builder.order_by('micros')
        profile_op_builder.with_max_depth(5)
        profile_op_builder.with_file_output(outfile="./op_profiler.txt")
        # profiler.profile_graph(profile_op_builder.build())
        profiler.profile_operations(profile_op_builder.build())

        # Per-Python-line report restricted to cnn_basenet.py.
        profile_code_builder = option_builder.ProfileOptionBuilder()
        profile_code_builder.with_max_depth(1000)
        profile_code_builder.with_node_names(show_name_regexes=['cnn_basenet.py.*'])
        profile_code_builder.with_min_execution_time(min_micros=10)
        profile_code_builder.select(['micros'])
        # 'min_micros' is a filter option, not a sort key; sort by the
        # selected 'micros' column so the slowest lines come first.
        profile_code_builder.order_by('micros')
        profile_code_builder.with_file_output(outfile="./code_profiler.txt")
        profiler.profile_python(profile_code_builder.build())

        profiler.advise(options=model_analyzer.ALL_ADVICE)

        postprocess_result = postprocessor.postprocess(
            binary_seg_result=binary_seg_image[0],
            instance_seg_result=instance_seg_image[0],
            source_image=image_vis
        )

        # Alternative post-processing entry point, kept for reference:
        # postprocess_result = postprocessor.postprocess_for_test(
        #     binary_seg_result=binary_seg_image[0],
        #     instance_seg_result=instance_seg_image[0],
        #     source_image=image_vis
        # )

        mask_image = postprocess_result['mask_image']

        # Rescale every embedding channel in place
        # (original note: CFG.TRAIN.EMBEDDING_FEATS_DIMS = 4).
        # Original author's note: equivalent to
        # cv2.normalize(instance_seg_image[0][:, :, i], None, 0, 255,
        #               norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8UC3),
        # i.e. it normalises the colour matrix into the 0-255 range.
        for i in range(CFG.TRAIN.EMBEDDING_FEATS_DIMS):
            instance_seg_image[0][:, :, i] = minmax_scale(instance_seg_image[0][:, :, i])
        embedding_image = np.array(instance_seg_image[0], np.uint8)

        # for op in tf.get_default_graph().get_operations():
        #     print(str(op.name))

        # print([n.name for n in tf.get_default_graph().as_graph_def().node])

        plt.figure('mask_image')
        # plt.imshow(mask_image[:, :, (2, 1, 0)])
        plt.imshow(mask_image)
        plt.figure('src_image')
        # Reverse the channel order (BGR -> RGB) for matplotlib.
        plt.imshow(image_vis[:, :, (2, 1, 0)])
        plt.figure('instance_image')
        plt.imshow(embedding_image[:, :, (2, 1, 0)])
        plt.figure('binary_image')
        plt.imshow(binary_seg_image[0] * 255, cmap='gray')
        # plt.figure("result")
        # plt.imshow(postprocess_result['source_image'])
        plt.show()

        cv2.imwrite('instance_mask_image.png', mask_image)
        cv2.imwrite('source_image.png', postprocess_result['source_image'])
        cv2.imwrite('binary_mask_image.png', binary_seg_image[0] * 255)

    return
Пример #17
0
    def train(self, config):
        """Train the GAN and profile the D/G updates of training step 10.

        Builds Adam optimizers for the discriminator and the generator,
        restores a checkpoint when one can be loaded, then iterates over
        epochs and mini-batches. Every step runs one discriminator update
        followed by two generator updates. When the step counter equals 10,
        the discriminator update and the first generator update are executed
        with ``FULL_TRACE``: a per-op time/memory profile is printed and a
        Chrome-trace timeline JSON file is written to the working directory.
        Samples and checkpoints are produced every ``config.sample_freq`` and
        ``config.ckpt_freq`` steps respectively.

        Args:
            config: run configuration. This method reads ``learning_rate``,
                ``beta1``, ``G_img_sum``, ``z_dist``, ``dataset``, ``epoch``,
                ``train_size``, ``batch_size``, ``data_dir``, ``sample_freq``,
                ``sample_dir``, ``ckpt_freq`` and ``checkpoint_dir``.

        Raises:
            SystemExit: at the start of an epoch once the step counter
                exceeds 15.
        """
        d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
                  .minimize(self.d_loss, var_list=self.d_vars)
        g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
                  .minimize(self.g_loss, var_list=self.g_vars)
        try:
            tf.global_variables_initializer().run()
        except AttributeError:
            # TensorFlow builds without the newer initializer name.
            tf.initialize_all_variables().run()

        if config.G_img_sum:
            self.g_sum = merge_summary([
                self.z_sum, self.d__sum, self.G_sum, self.d_loss_fake_sum,
                self.g_loss_sum
            ])
        else:
            self.g_sum = merge_summary([
                self.z_sum, self.d__sum, self.d_loss_fake_sum, self.g_loss_sum
            ])
        self.d_sum = merge_summary(
            [self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
        self.writer = SummaryWriter(os.path.join(self.out_dir, "logs"),
                                    self.sess.graph)

        is_mnist = config.dataset == 'mnist'

        # Timeline file-name templates, filled with
        # (step, UNIFIED_MEMORY_SET, batch size). Kept exactly as before so
        # existing tooling that looks for these files keeps working.
        if is_mnist:
            d_trace = "timeline.mnist.gpu.D-network-update.step-%d.umem-%s.batchsize-%d.json"
            g_trace = "timeline.mnist.gpu.G-network-update.step-%d.umem-%s.batchsize-%d.json"
        else:
            d_trace = "timeline.celeba.gpu.D-network-update.step-%d.umem-%s.batchsize-%d.json"
            g_trace = "timeline.gpu.G-network-update.step-%d.umem-%s.batchsize-%d.json"

        def run_update(optim, summary_op, feed, step, trace=None):
            """Run one optimizer step and record its summary.

            When ``trace`` is a ``(network_label, file_template)`` pair the
            step is run with FULL_TRACE, the per-op time/memory profile is
            printed and the Chrome-trace timeline is written to the file
            named by the template.
            """
            if trace is None:
                _, summary_str = self.sess.run([optim, summary_op],
                                               feed_dict=feed)
                self.writer.add_summary(summary_str, step)
                return

            label, file_template = trace
            run_metadata = tf.RunMetadata()
            _, summary_str = self.sess.run(
                [optim, summary_op],
                feed_dict=feed,
                options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                run_metadata=run_metadata)
            self.writer.add_summary(summary_str, step)

            # Print to stdout an analysis of the memory usage and the timing
            # information broken down by operation types.
            print("Update %s network at step=%d" % (label, step))
            tf.profiler.profile(
                tf.get_default_graph(),
                run_meta=run_metadata,
                cmd='op',
                options=tf.profiler.ProfileOptionBuilder.time_and_memory())

            profile_result = file_template % (step, UNIFIED_MEMORY_SET,
                                              self.batch_size)
            print("profile_result=", profile_result)

            # Create the Timeline object, and write it to a json
            tl = timeline.Timeline(run_metadata.step_stats)
            ctf = tl.generate_chrome_trace_format()
            with open(profile_result, 'w') as tlf:
                tlf.write(ctf)

        sample_z = gen_random(config.z_dist,
                              size=(self.sample_num, self.z_dim))

        if is_mnist:
            sample_inputs = self.data_X[0:self.sample_num]
            sample_labels = self.data_y[0:self.sample_num]
        else:
            sample_files = self.data[0:self.sample_num]
            sample = [
                get_image(sample_file,
                          input_height=self.input_height,
                          input_width=self.input_width,
                          resize_height=self.output_height,
                          resize_width=self.output_width,
                          crop=self.crop,
                          grayscale=self.grayscale)
                for sample_file in sample_files
            ]
            if self.grayscale:
                # Append a trailing channel axis for single-channel images.
                sample_inputs = np.array(sample).astype(np.float32)[:, :, :,
                                                                    None]
            else:
                sample_inputs = np.array(sample).astype(np.float32)

        counter = 1
        start_time = time.time()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        for epoch in xrange(config.epoch):

            # Profiling runs are cut short: stop the process once more than
            # 15 steps have been taken. `exit()` is the interactive-shell
            # helper from `site`; raising SystemExit is the program-safe form.
            if counter > 15:
                raise SystemExit

            if is_mnist:
                batch_idxs = min(len(self.data_X),
                                 config.train_size) // config.batch_size
            else:
                self.data = glob(
                    os.path.join(config.data_dir, config.dataset,
                                 self.input_fname_pattern))
                np.random.shuffle(self.data)
                batch_idxs = min(len(self.data),
                                 config.train_size) // config.batch_size

            for idx in xrange(0, int(batch_idxs)):
                lo = idx * config.batch_size
                hi = (idx + 1) * config.batch_size
                if is_mnist:
                    batch_images = self.data_X[lo:hi]
                    batch_labels = self.data_y[lo:hi]
                else:
                    batch_files = self.data[lo:hi]
                    batch = [
                        get_image(batch_file,
                                  input_height=self.input_height,
                                  input_width=self.input_width,
                                  resize_height=self.output_height,
                                  resize_width=self.output_width,
                                  crop=self.crop,
                                  grayscale=self.grayscale)
                        for batch_file in batch_files
                    ]
                    if self.grayscale:
                        batch_images = np.array(batch).astype(
                            np.float32)[:, :, :, None]
                    else:
                        batch_images = np.array(batch).astype(np.float32)

                batch_z = gen_random(config.z_dist, size=[config.batch_size, self.z_dim]) \
                      .astype(np.float32)

                # The MNIST model is label-conditioned, so every feed also
                # carries the labels; other datasets feed images/noise only.
                if is_mnist:
                    d_feed = {
                        self.inputs: batch_images,
                        self.z: batch_z,
                        self.y: batch_labels,
                    }
                    g_feed = {self.z: batch_z, self.y: batch_labels}
                    real_feed = {
                        self.inputs: batch_images,
                        self.y: batch_labels
                    }
                else:
                    d_feed = {self.inputs: batch_images, self.z: batch_z}
                    g_feed = {self.z: batch_z}
                    real_feed = {self.inputs: batch_images}

                # Only step 10 is traced.
                profile_step = counter == 10

                # Update D network
                run_update(d_optim, self.d_sum, d_feed, counter,
                           trace=("D", d_trace) if profile_step else None)

                # Update G network
                run_update(g_optim, self.g_sum, g_feed, counter,
                           trace=("G", g_trace) if profile_step else None)

                # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
                run_update(g_optim, self.g_sum, g_feed, counter)

                errD_fake = self.d_loss_fake.eval(g_feed)
                errD_real = self.d_loss_real.eval(real_feed)
                errG = self.g_loss.eval(g_feed)

                print("[%8d Epoch:[%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                  % (counter, epoch, config.epoch, idx, batch_idxs,
                    time.time() - start_time, errD_fake+errD_real, errG))

                if np.mod(counter, config.sample_freq) == 0:
                    if is_mnist:
                        samples, d_loss, g_loss = self.sess.run(
                            [self.sampler, self.d_loss, self.g_loss],
                            feed_dict={
                                self.z: sample_z,
                                self.inputs: sample_inputs,
                                self.y: sample_labels,
                            })
                        save_images(
                            samples, image_manifold_size(samples.shape[0]),
                            './{}/train_{:08d}.png'.format(
                                config.sample_dir, counter))
                        print("[Sample] d_loss: %.8f, g_loss: %.8f" %
                              (d_loss, g_loss))
                    else:
                        # Sampling is best-effort: a failure here must not
                        # abort training, but Ctrl-C / SystemExit still
                        # propagate.
                        try:
                            samples, d_loss, g_loss = self.sess.run(
                                [self.sampler, self.d_loss, self.g_loss],
                                feed_dict={
                                    self.z: sample_z,
                                    self.inputs: sample_inputs,
                                },
                            )
                            save_images(
                                samples, image_manifold_size(samples.shape[0]),
                                './{}/train_{:08d}.png'.format(
                                    config.sample_dir, counter))
                            print("[Sample] d_loss: %.8f, g_loss: %.8f" %
                                  (d_loss, g_loss))
                        except Exception:
                            print("one pic error!...")

                if np.mod(counter, config.ckpt_freq) == 0:
                    self.save(config.checkpoint_dir, counter)

                counter += 1
Пример #18
0
def train(log_dir, args, hparams):
    voicefilter_audio = Audio(hparams)

    save_dir = os.path.join(log_dir, 'extract_pretrained')
    plot_dir = os.path.join(log_dir, 'plots')
    wav_dir = os.path.join(log_dir, 'wavs')
    spec_dir = os.path.join(log_dir, 'spec-spectrograms')
    eval_dir = os.path.join(log_dir, 'eval-dir')
    #eval_plot_dir = os.path.join(eval_dir, 'plots')
    eval_wav_dir = os.path.join(eval_dir, 'wavs')
    tensorboard_dir = os.path.join(log_dir, 'extractron_events')
    meta_folder = os.path.join(log_dir, 'metas')

    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)
    os.makedirs(wav_dir, exist_ok=True)
    os.makedirs(spec_dir, exist_ok=True)
    os.makedirs(eval_dir, exist_ok=True)
    #os.makedirs(eval_plot_dir, exist_ok=True)
    os.makedirs(eval_wav_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)
    os.makedirs(meta_folder, exist_ok=True)

    checkpoint_path = os.path.join(save_dir, 'extractron_model.ckpt')
    checkpoint_path2 = os.path.join(save_dir, 'super_extractron_model.ckpt')
    #input_paths = [os.path.join(args.base_dir, args.extractron_input)]
    #if args.extractron_inputs:
    #    input_paths = [os.path.join(args.base_dir, arg_input_path)
    #                   for arg_input_path in args.extractron_inputs]
    #if args.extractron_input_glob:
    #    input_paths = glob.glob(args.extractron_input_glob)

    log('Checkpoint path: {}'.format(checkpoint_path))
    log('Using model: {}'.format(args.model))
    log(hparams_debug_string())

    # Start by setting a seed for repeatability
    tf.set_random_seed(hparams.extractron_random_seed)

    # Set up data feeder
    with tf.variable_scope('datafeeder'):
        feeder = Feeder(hparams)
        feeder.setup_dataset(args.dataset, args.eval_dataset)

        class DotDict(dict):
            """
            a dictionary that supports dot notation
            as well as dictionary access notation
            usage: d = DotDict() or d = DotDict({'val1':'first'})
            set attributes: d.val2 = 'second' or d['val2'] = 'second'
            get attributes: d.val2 or d['val2']
            """
            __getattr__ = dict.__getitem__
            __setattr__ = dict.__setitem__
            __delattr__ = dict.__delitem__

            def __init__(self, dct):
                for key, value in dct.items():
                    if hasattr(value, 'keys'):
                        value = DotDict(value)
                    self[key] = value

        dictkeys = [
            'target_linear', 'mixed_linear', 'target_mel', 'mixed_mel',
            'spkid_embeddings'
        ]
        eval_dictkeys = [
            'eval_target_linear', 'eval_mixed_linear', 'eval_target_phase',
            'eval_mixed_phase', 'eval_target_mel', 'eval_mixed_mel',
            'eval_spkid_embeddings'
        ]
        feeder_dict = DotDict(dict(zip(dictkeys, feeder.next)))
        feeder_dict.update(DotDict(dict(zip(eval_dictkeys, feeder.eval_next))))

    # Set up model:
    global_step = tf.Variable(0, name='global_step', trainable=False)
    model, stats = model_train_mode(args, feeder_dict, hparams, global_step)
    eval_model = model_test_mode(args, feeder_dict, hparams, global_step)

    # Book keeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=5)
    saver2 = tf.train.Saver(max_to_keep=15)

    log('Extractron training set to a maximum of {} steps'.format(
        args.extractron_train_steps))

    # Memory allocation on the GPU as needed
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    #config.log_device_placement = True
    config.allow_soft_placement = True

    # Train
    with tf.Session(config=config) as sess:
        try:
            #summary_writer = tf.summary.FileWriter(tensorboard_dir, sess.graph)
            xsummary_writer = SummaryWriter(tensorboard_dir)

            sess.run(tf.global_variables_initializer())

            # saved model restoring
            if args.restore:
                # Restore saved model if the user requested it, default = True
                try:
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)

                    if (checkpoint_state
                            and checkpoint_state.model_checkpoint_path):
                        log('Loading checkpoint {}'.format(
                            checkpoint_state.model_checkpoint_path),
                            slack=True)
                        saver.restore(sess,
                                      checkpoint_state.model_checkpoint_path)

                    else:
                        log('No model to load at {}'.format(save_dir),
                            slack=True)
                        saver.save(sess,
                                   checkpoint_path,
                                   global_step=global_step)

                except tf.errors.OutOfRangeError as e:
                    log('Cannot restore checkpoint: {}'.format(e), slack=True)
            else:
                log('Starting new training!', slack=True)
                saver.save(sess, checkpoint_path, global_step=global_step)

            if hparams.tfprof or hparams.timeline:
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()
                if hparams.timeline:
                    from tensorflow.python.client import timeline
                if hparams.tfprof:
                    from tensorflow.python.profiler import model_analyzer, option_builder
                    my_profiler = model_analyzer.Profiler(graph=sess.graph)
                    profile_op_builder = option_builder.ProfileOptionBuilder()
                    profile_op_builder.select(['micros', 'occurrence'])
                    profile_op_builder.order_by('micros')
                    #profile_op_builder.select(['device', 'bytes', 'peak_bytes'])
                    #profile_op_builder.order_by('bytes')
                    profile_op_builder.with_max_depth(
                        20)  # can be any large number
                    profile_op_builder.with_file_output('profile.log')
                    profile_op = profile_op_builder.build()

            # Training loop
            while step < args.extractron_train_steps:
                start_time = time.time()
                # from tensorflow.python import debug as tf_debug
                # sess=tf_debug.LocalCLIDebugWrapperSession(sess)
                if hparams.tfprof or hparams.timeline:
                    step, loss, opt = sess.run(
                        [global_step, model.loss, model.optimize],
                        options=run_options,
                        run_metadata=run_metadata)
                    if hparams.timeline:
                        fetched_timeline = timeline.Timeline(
                            run_metadata.step_stats)
                        chrome_trace = fetched_timeline.generate_chrome_trace_format(
                            show_dataflow=True, show_memory=True)
                        with open('timeline_01.json', 'w') as f:
                            f.write(chrome_trace)
                    if hparams.tfprof:
                        my_profiler.add_step(step=int(step),
                                             run_meta=run_metadata)
                        my_profiler.profile_name_scope(profile_op)
                else:
                    step, loss, opt = sess.run(
                        [global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = \
                'Step {:7d} [{:.3f} sec/step, {:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format(
                    step, time.time() - start_time, time_window.average, loss, loss_window.average)

                log(message,
                    end='\r',
                    slack=(step % args.checkpoint_interval == 0))

                # Originally assume 100 means loss exploded, now change to 1000 due to waveglow settings
                if loss > 100 or np.isnan(loss):
                    log('Loss exploded to {:.5f} at step {}'.format(
                        loss, step))
                    raise Exception('Loss exploded')

                if step % args.summary_interval == 0:
                    log('\nWriting summary at step {}'.format(step))
                    add_train_summary(xsummary_writer, step, loss)
                    #summary_writer.add_summary(sess.run(stats), step)
                    #summary_writer.flush()

                if step % args.gc_interval == 0:
                    log('\nGarbage collect: {}\n'.format(gc.collect()))

                if step % args.eval_interval == 0:
                    # Run eval and save eval stats
                    log('\nRunning evaluation at step {}'.format(step))

                    #1. avg loss, before, after, predicted mag, mixed phase, mixed_mag, target phase, target_mag
                    #2. 3 wavs
                    #3. 3 mag specs
                    #4. sdr

                    eval_losses = []
                    before_losses = []
                    after_losses = []
                    linear_losses = []

                    for i in tqdm(range(args.test_steps)):
                        try:
                            eloss, before_loss, after_loss, linear_loss, \
                            mixed_phase, mixed_mel, mixed_linear, \
                            target_phase, target_mel, target_linear, \
                            predicted_linear = sess.run([
                                eval_model.tower_loss[0], eval_model.tower_before_loss[0], eval_model.tower_after_loss[0], eval_model.tower_linear_loss[0],
                                eval_model.tower_mixed_phase[0][0], eval_model.tower_mixed_mel[0][0],
                                eval_model.tower_mixed_linear[0][0],
                                eval_model.tower_target_phase[0][0], eval_model.tower_target_mel[0][0],
                                eval_model.tower_target_linear[0][0],
                                eval_model.tower_linear_outputs[0][0]
                            ])
                            eval_losses.append(eloss)
                            before_losses.append(before_loss)
                            after_losses.append(after_loss)
                            linear_losses.append(linear_loss)
                            #if i==0:
                            #    tmp_phase=mixed_phase
                            #    tmp_spec=mixed_spec
                        except tf.errors.OutOfRangeError:
                            log('\n test dataset out of range')
                            pass

                    eval_loss = sum(eval_losses) / len(eval_losses)
                    before_loss = sum(before_losses) / len(before_losses)
                    after_loss = sum(after_losses) / len(after_losses)
                    linear_loss = sum(linear_losses) / len(linear_losses)

                    #mixed_wav = voicefilter_audio.spec2wav(tmp_spec, tmp_phase)
                    mixed_wav = voicefilter_audio.spec2wav(
                        mixed_linear, mixed_phase)
                    target_wav = voicefilter_audio.spec2wav(
                        target_linear, target_phase)
                    predicted_wav = voicefilter_audio.spec2wav(
                        predicted_linear, mixed_phase)
                    librosa.output.write_wav(
                        os.path.join(eval_wav_dir,
                                     'step-{}-eval-mixed.wav'.format(step)),
                        mixed_wav, hparams.sample_rate)
                    librosa.output.write_wav(
                        os.path.join(eval_wav_dir,
                                     'step-{}-eval-target.wav'.format(step)),
                        target_wav, hparams.sample_rate)
                    librosa.output.write_wav(
                        os.path.join(
                            eval_wav_dir,
                            'step-{}-eval-predicted.wav'.format(step)),
                        predicted_wav, hparams.sample_rate)
                    #audio.save_wav(mixed_wav, os.path.join(
                    #    eval_wav_dir, 'step-{}-eval-mixed.wav'.format(step)), sr=hparams.sample_rate)
                    #audio.save_wav(target_wav, os.path.join(
                    #    eval_wav_dir, 'step-{}-eval-target.wav'.format(step)), sr=hparams.sample_rate)
                    #audio.save_wav(predicted_wav, os.path.join(
                    #    eval_wav_dir, 'step-{}-eval-predicted.wav'.format(step)), sr=hparams.sample_rate)

                    mixed_linear_img = plot_spectrogram_to_numpy(
                        mixed_linear.T)
                    target_linear_img = plot_spectrogram_to_numpy(
                        target_linear.T)
                    predicted_linear_img = plot_spectrogram_to_numpy(
                        predicted_linear.T)

                    #plot.plot_spectrogram(predicted_spec,
                    #        os.path.join(eval_plot_dir, 'step-{}-eval-spectrogram.png'.format(step)),
                    #        title='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(), step, eval_loss),
                    #        target_spectrogram=target_spec)

                    log('Eval loss for global step {}: {:.3f}'.format(
                        step, eval_loss))
                    log('Writing eval summary!')

                    add_eval_summary(xsummary_writer, step, before_loss,
                                     after_loss, linear_loss, eval_loss,
                                     hparams.sample_rate, mixed_wav,
                                     target_wav, predicted_wav,
                                     mixed_linear_img, target_linear_img,
                                     predicted_linear_img)

                if step % args.super_checkpoint_interval == 0 or step == args.extractron_train_steps:
                    # Save model and current global step
                    saver2.save(sess,
                                checkpoint_path2,
                                global_step=global_step)

                if step % args.checkpoint_interval == 0 or step == args.extractron_train_steps:
                    # Save model and current global step
                    saver.save(sess, checkpoint_path, global_step=global_step)

                    #log('\nSaving alignment, Mel-Spectrograms and griffin-lim inverted waveform..')

                    #input_seq, mel_prediction, alignment, target, target_length = sess.run([
                    #    model.tower_inputs[0][0],
                    #    model.tower_mel_outputs[0][0],
                    #    model.tower_alignments[0][0],
                    #    model.tower_mel_targets[0][0],
                    #    model.tower_targets_lengths[0][0],
                    #])

                    ## save predicted mel spectrogram to disk (debug)
                    #mel_filename = 'mel-prediction-step-{}.npy'.format(step)
                    #np.save(os.path.join(mel_dir, mel_filename),
                    #        mel_prediction.T, allow_pickle=False)

                    ## save griffin lim inverted wav for debug (mel -> wav)
                    #wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
                    #audio.save_wav(wav, os.path.join(
                    #    wav_dir, 'step-{}-wave-from-mel.wav'.format(step)), sr=hparams.sample_rate)

                    ## save alignment plot to disk (control purposes)
                    #plot.plot_alignment(alignment, os.path.join(plot_dir, 'step-{}-align.png'.format(step)),
                    #                    title='{}, {}, step={}, loss={:.5f}'.format(
                    #                        args.model, time_string(), step, loss),
                    #                    max_len=target_length // hparams.outputs_per_step)
                    ## save real and predicted mel-spectrogram plot to disk (control purposes)
                    #plot.plot_spectrogram(mel_prediction, os.path.join(plot_dir, 'step-{}-mel-spectrogram.png'.format(step)),
                    #                      title='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(), step, loss), target_spectrogram=target,
                    #                      max_len=target_length)
                    #log('Input at step {}: {}'.format(
                    #    step, sequence_to_text(input_seq)))

            log('Extractron training complete after {} global steps!'.format(
                args.extractron_train_steps),
                slack=True)
            return save_dir

        except Exception as e:
            log('Exiting due to exception: {}'.format(e), slack=True)
            traceback.print_exc()
Пример #19
0
def run_model(model,
              horovod=False,
              gpu_num=1,
              output=None,
              steptime=False,
              profile=False,
              timeline=False,
              loss=False,
              session=1,
              step=1,
              batchsize=None,
              graph=False):
    """Run `model` for `session` sessions of `step` steps and dump results.

    The op (and loss tensor) to run come from `tf_model.get_model`. Each step
    is timed with `time.time()`; depending on the flags, step times, losses,
    partition graphs, a tfprof profile and/or a Chrome timeline are written
    into `output`.

    Args:
        model: model identifier, forwarded to `tf_model.get_model` and used as
            the prefix of every output file name.
        horovod: when True the session only sees the GPU at
            `hvd.local_rank()`, and output names carry `hvd.rank()`.
        gpu_num: number of GPUs made visible ('0,1,...') when `horovod` is
            not True; also embedded in output file names.
        output: directory for result files. Nothing is written when None.
        steptime: when True, write the per-step times as a CSV.
        profile: when set, trace every step (FULL_TRACE) and feed the traces
            to a tfprof `Profiler`; the profile is written when this is True.
        timeline: when set, trace every step; when True, the metadata of the
            last executed step is written in Chrome trace format.
        loss: when True, also fetch the loss tensor, print it per step and
            write the losses as a CSV.
        session: number of sessions to create, one after another.
        step: number of steps to run in each session.
        batchsize: forwarded to `tf_model.get_model`; None is rendered as
            '_bsDefault' in file names.
        graph: when set, request partition graphs from the runtime and save
            them. Cannot be combined with `timeline` or `profile`.

    Raises:
        ValueError: if `graph` is combined with `timeline` or `profile`.
    """
    # Graph dumping and tracing need different RunOptions, so they are
    # mutually exclusive.
    if graph and (timeline or profile):
        raise ValueError("cannot dump graph togother with timeline or tfprof")

    with tf.Graph().as_default():

        times_list = []
        losses_list = []
        op, _loss = tf_model.get_model(model, batchsize, horovod=horovod)
        # Build the initializer once instead of adding a new init op to the
        # graph for every session.
        init_op = tf.global_variables_initializer()

        # Select the GPUs visible to the sessions.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = False
        if horovod is True:
            config.gpu_options.visible_device_list = str(hvd.local_rank())
        else:
            # Build the list '0,1,2,...'.
            gpus = ','.join(map(str, range(gpu_num)))
            print('DEBUG: gpus=%s' % gpus)
            config.gpu_options.visible_device_list = gpus

        # Bound up front so the output section below is safe even when the
        # session loop does not execute (session == 0).
        profiler = None
        run_metadata = None

        for i in range(session):
            times = []
            losses = []

            opts = None
            run_metadata = None
            if graph:
                # Graph-dump mode.
                opts = tf.RunOptions(output_partition_graphs=True)
                run_metadata = tf.RunMetadata()
            elif profile or timeline:
                # Tracing mode, shared by the tfprof profile and the timeline.
                opts = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()

            # The context manager closes the session when its steps are done;
            # previously every session stayed open until process exit.
            with tf.Session(config=config) as sess:
                sess.run(init_op)

                if profile and not graph:
                    profiler = model_analyzer.Profiler(sess.graph)

                for n in range(step):
                    start_time = time.time()

                    if loss is True:
                        res = sess.run([op, _loss],
                                       options=opts,
                                       run_metadata=run_metadata)
                        losses.append(res[1])
                    else:
                        res = sess.run(op,
                                       options=opts,
                                       run_metadata=run_metadata)

                    train_time = time.time() - start_time
                    times.append(train_time)

                    # Report step time (and loss) as the run progresses.
                    if loss is True:
                        print('Sess%d/%d Step%d/%d: time=%.2fms loss=%.2f' %
                              (i + 1, session, n + 1, step,
                               train_time * 1000, res[1]))
                    else:
                        print('Sess%d/%d Step%d/%d: time=%.2fms' %
                              (i + 1, session, n + 1, step,
                               train_time * 1000))

                    if profile and not graph:
                        # Register the trace under the current step index.
                        # The original passed `step` (the total step count),
                        # which filed every trace under the same step id.
                        profiler.add_step(step=n, run_meta=run_metadata)

            times_list.append(times)
            losses_list.append(losses)

        if output is not None:
            # exist_ok tolerates another process creating the directory
            # between the existence check and the creation.
            if not os.path.exists(output):
                os.makedirs(output, exist_ok=True)

            file_loss = '_lossOn' if loss else ''
            file_trace = '_traceOn' if profile or timeline else ''
            file_horovod = '_hvdRank%d' % hvd.rank() if horovod else ''
            file_batchsize = ('_bs%d' % batchsize
                              if batchsize is not None else '_bsDefault')
            file_gpunum = '_gpunum%d' % gpu_num

            # Common prefix of the csv / profile / timeline file names.
            base_name = '%s%s%s%s%s%s' % (model, file_batchsize, file_loss,
                                          file_trace, file_horovod,
                                          file_gpunum)

            if steptime is True:
                output_csv(base_name + '_steptime.csv', times_list,
                           path=output, scale=1000)

            if loss is True:
                output_csv(base_name + '_loss.csv', losses_list, path=output)

            if graph and run_metadata is not None:
                # Save each partition of the graph with _output_shapes attr.
                if horovod:
                    # One sub-directory per horovod rank.
                    graph_dir = os.path.join(
                        output, '%s%s%s%s_partitionGraph' %
                        (model, file_batchsize, file_loss, file_gpunum),
                        str(hvd.rank()))
                    if not os.path.exists(graph_dir):
                        os.makedirs(graph_dir, exist_ok=True)
                    save_partition_graph_shapes(run_metadata, graph_dir,
                                                'graph')
                else:
                    save_partition_graph_shapes(
                        run_metadata, output, '%s%s%s%s%s_partitionGraph' %
                        (model, file_batchsize, file_loss, file_horovod,
                         file_gpunum))

            if profile is True and profiler is not None:
                filepath = os.path.join(output, base_name + '.profile')
                generate_tfprof_profile(profiler, filepath)

            if timeline is True and run_metadata is not None:
                filepath = os.path.join(output, base_name + '.timeline')
                # run_metadata holds the trace of the last executed step.
                tl = _timeline.Timeline(run_metadata.step_stats)
                ctf = tl.generate_chrome_trace_format()
                with open(filepath, 'w') as f:
                    f.write(ctf)
Пример #20
0
  def testProfileBasic(self):
    """Compares `Profiler` views against one-shot `model_analyzer.profile`.

    A single traced step is added to a `Profiler`. For each of the graph,
    scope, code and op views, the text the profiler writes to the dump file
    must equal what `model_analyzer.profile` writes for the same command and
    options. Finally, a fresh 'scope' dump must differ from the 'op' dump.
    """
    ops.reset_default_graph()
    dump_path = os.path.join(test.get_temp_dir(), 'dump')

    selected_fields = ['params', 'float_ops', 'micros', 'bytes',
                       'device', 'op_types', 'occurrence']
    opt_builder = builder(builder.trainable_variables_parameter())
    opt_builder = opt_builder.with_file_output(dump_path)
    opt_builder = opt_builder.with_accounted_types(['.*'])
    opts = opt_builder.select(selected_fields).build()

    def read_dump():
      # Every profile call below writes its report to `dump_path`.
      with gfile.Open(dump_path, 'r') as dump_file:
        return dump_file.read()

    sess = session.Session()
    full_model = lib.BuildFullModel()
    sess.run(variables.global_variables_initializer())

    # Run the model once with full tracing to collect the run metadata.
    run_meta = config_pb2.RunMetadata()
    trace_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    sess.run(full_model, options=trace_options, run_metadata=run_meta)

    profiler = model_analyzer.Profiler(sess.graph)
    profiler.add_step(1, run_meta)

    # Each Profiler view paired with the equivalent one-shot command.
    view_checks = (
        (profiler.profile_graph, 'graph'),
        (profiler.profile_name_scope, 'scope'),
        (profiler.profile_python, 'code'),
        (profiler.profile_operations, 'op'),
    )
    for run_view, cmd in view_checks:
      run_view(opts)
      profiler_str = read_dump()

      model_analyzer.profile(
          sess.graph, cmd=cmd, run_meta=run_meta, options=opts)
      pma_str = read_dump()
      self.assertEqual(pma_str, profiler_str)

    # `profiler_str` still holds the 'op' view here, so a 'scope' dump must
    # not match it.
    model_analyzer.profile(
        sess.graph, cmd='scope', run_meta=run_meta, options=opts)
    pma_str = read_dump()
    self.assertNotEqual(pma_str, profiler_str)
Пример #21
0
def train():
    """Train the CIFAR-10 model and record a tfprof trace of every step.

    Builds the input pipeline, loss and SGD train op, then runs up to
    MAX_STEP steps with full tracing, adding each step's run metadata to a
    tfprof `Profiler`. Progress is printed every 30 steps, summaries are
    written every 100 steps and a checkpoint is saved every 1000 steps and on
    the last step. After training, the graph view of step 70 is written as a
    timeline to '/tmp/mnist_profiler.json'.

    Raises:
        AssertionError: if the loss becomes NaN.
    """
    # Global step counter; incremented by optimizer.minimize below.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    # CIFAR-10 data directory.
    data_dir = '../cifar-10-batches-bin/'
    # Directory for training logs and checkpoints.
    train_dir = './logs/'
    # Load images and labels.
    images, labels = inputs(data_dir, BATCH_SIZE)

    # Compute the loss.
    loss = losses(inference(images), labels)
    # Plain SGD with a constant learning rate.
    optimizer = tf.train.GradientDescentOptimizer(LEARNING_RATE)
    train_op = optimizer.minimize(loss, global_step=global_step)
    # Saver over all global variables (tf.all_variables is deprecated).
    saver = tf.train.Saver(tf.global_variables())
    # Merged summary op; None when the graph defines no summaries.
    summary_op = tf.summary.merge_all()
    # Initializer for all variables (tf.initialize_all_variables is
    # deprecated).
    init = tf.global_variables_initializer()

    os.environ['CUDA_VISIBLE_DEVICES'] = str(0)

    config = tf.ConfigProto()
    # Allocate GPU memory on demand.
    config.gpu_options.allow_growth = True

    # Collect a full trace for every sess.run call in the training loop.
    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()

    cluster = tf.train.ClusterSpec({
        'node1': [
            '192.168.136.101:2222'
        ],
        'node2': [
            '192.168.136.102:2222'
        ]
    })
    # NOTE(review): `server` and `session` are created but never used below;
    # training runs in the local InteractiveSession. Kept because creating
    # them has side effects - confirm whether they are still needed.
    server = tf.train.Server(cluster, job_name='node1', task_index=0)
    session = tf.Session(target='grpc://192.168.136.102:2222', config=config)
    # To cap GPU memory at 20%:
    # config.gpu_options.per_process_gpu_memory_fraction = 0.2
    sess = tf.InteractiveSession(config=config)
    try:
        # Run the variable initializer.
        sess.run(init)
        profiler = model_analyzer.Profiler(graph=sess.graph)

        # Coordinator for the input queue threads.
        coord = tf.train.Coordinator()
        # Start the queue runners.
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # Event writer for summaries in train_dir.
        summary_writer = tf.summary.FileWriter(train_dir, sess.graph)

        try:
            # Training loop.
            for step in range(MAX_STEP):
                if coord.should_stop():
                    break
                start_time = time.time()
                # Run one training step and fetch the loss, with tracing.
                _, loss_value = sess.run([train_op, loss],
                                         options=options,
                                         run_metadata=run_metadata)
                profiler.add_step(step=step, run_meta=run_metadata)
                duration = time.time() - start_time
                # Divergence check. Raised explicitly (same exception type as
                # the former `assert`) so it also fires under `python -O`.
                if np.isnan(loss_value):
                    raise AssertionError('Model diverged with loss = NaN')
                if step % 30 == 0:
                    # Progress report.
                    num_examples_per_step = BATCH_SIZE
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)
                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.datetime.now(), step,
                                        loss_value, examples_per_sec,
                                        sec_per_batch))

                if step % 100 == 0 and summary_op is not None:
                    # Evaluate the merged summaries and write them.
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step)

                if step % 1000 == 0 or (step + 1) == MAX_STEP:
                    # Save the current model weights to train_dir, tagged
                    # with the current step.
                    checkpoint_path = os.path.join(train_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)
        finally:
            # Flush pending events and stop the input threads even when the
            # loop raised; the writer was previously never closed.
            summary_writer.close()
            coord.request_stop()
            coord.join(threads)

        # Report run time and memory use of each graph node.
        profile_graph_opts_builder = option_builder.ProfileOptionBuilder(
            option_builder.ProfileOptionBuilder.time_and_memory())

        # Emit the result as a timeline; the output directory must exist.
        profile_graph_opts_builder.with_timeline_output(
            timeline_file='/tmp/mnist_profiler.json')
        # Show the statistics of step 70.
        # NOTE(review): presumably requires that step 70 was recorded above,
        # i.e. MAX_STEP > 70 - confirm.
        profile_graph_opts_builder.with_step(70)

        # Render the graph view.
        profiler.profile_graph(profile_graph_opts_builder.build())
    finally:
        sess.close()
Пример #22
0
def main(argv):
    argparser = argparse.ArgumentParser(
        'NTP 2.0', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # data
    # WARNING: for countries, it's not necessary to enter the dev/test set as the evaluation does so
    # TODO: fix this behavior - all datasets should have the same behavior
    argparser.add_argument('--train', action='store', type=str)
    argparser.add_argument('--dev', action='store', type=str, default=None)
    argparser.add_argument('--test', action='store', type=str, default=None)

    argparser.add_argument('--clauses',
                           '-c',
                           action='store',
                           type=str,
                           default=None)
    argparser.add_argument('--mentions',
                           action='store',
                           type=str,
                           default=None)
    argparser.add_argument('--mentions-min',
                           action='store',
                           type=int,
                           default=1)

    # model params
    argparser.add_argument('--embedding-size',
                           '-k',
                           action='store',
                           type=int,
                           default=100)
    argparser.add_argument('--batch-size',
                           '-b',
                           action='store',
                           type=int,
                           default=10)
    # k-max for the new variable
    argparser.add_argument('--k-max',
                           '-m',
                           action='store',
                           type=int,
                           default=None)
    argparser.add_argument('--max-depth',
                           '-M',
                           action='store',
                           type=int,
                           default=1)

    # training params
    argparser.add_argument('--epochs',
                           '-e',
                           action='store',
                           type=int,
                           default=100)
    argparser.add_argument('--learning-rate',
                           '-l',
                           action='store',
                           type=float,
                           default=0.001)
    argparser.add_argument('--clip', action='store', type=float, default=1.0)
    argparser.add_argument('--l2', action='store', type=float, default=0.01)

    argparser.add_argument('--kernel',
                           action='store',
                           type=str,
                           default='rbf',
                           choices=['linear', 'rbf'])

    argparser.add_argument('--auxiliary-loss-weight',
                           '--auxiliary-weight',
                           '--aux-weight',
                           action='store',
                           type=float,
                           default=None)
    argparser.add_argument('--auxiliary-loss-model',
                           '--auxiliary-model',
                           '--aux-model',
                           action='store',
                           type=str,
                           default='complex')
    argparser.add_argument('--auxiliary-epochs',
                           '--aux-epochs',
                           action='store',
                           type=int,
                           default=0)

    argparser.add_argument('--corrupted-pairs',
                           '--corruptions',
                           '-C',
                           action='store',
                           type=int,
                           default=1)
    argparser.add_argument('--all', '-a', action='store_true')

    argparser.add_argument('--retrieve-k-facts',
                           '-F',
                           action='store',
                           type=int,
                           default=None)
    argparser.add_argument('--retrieve-k-rules',
                           '-R',
                           action='store',
                           type=int,
                           default=None)

    argparser.add_argument(
        '--index-type',
        '-i',
        action='store',
        type=str,
        default='nmslib',
        choices=['nmslib', 'faiss', 'faiss-cpu', 'random', 'exact'])

    argparser.add_argument('--index-refresh-rate',
                           '-I',
                           action='store',
                           type=int,
                           default=100)

    argparser.add_argument('--nms-m', action='store', type=int, default=15)
    argparser.add_argument('--nms-efc', action='store', type=int, default=100)
    argparser.add_argument('--nms-efs', action='store', type=int, default=100)

    argparser.add_argument('--evaluation-mode',
                           '-E',
                           action='store',
                           type=str,
                           default='ranking',
                           choices=['ranking', 'countries', 'ntn', 'none'])
    argparser.add_argument('--exact-knn-evaluation',
                           action='store',
                           type=str,
                           default=None,
                           choices=[None, 'faiss', 'exact'])

    argparser.add_argument('--loss-aggregator',
                           action='store',
                           type=str,
                           default='sum',
                           choices=['sum', 'mean'])

    argparser.add_argument('--decode', '-D', action='store_true')
    argparser.add_argument('--seed', action='store', type=int, default=0)

    argparser.add_argument('--keep-prob',
                           action='store',
                           type=float,
                           default=1.0)
    argparser.add_argument('--initializer',
                           action='store',
                           type=str,
                           default='uniform',
                           choices=['uniform', 'xavier'])

    argparser.add_argument('--mixed-losses', action='store_true')
    argparser.add_argument('--mixed-losses-aggregator',
                           action='store',
                           type=str,
                           default='mean',
                           choices=['mean', 'sum'])

    argparser.add_argument(
        '--rule-embeddings-type',
        '--rule-type',
        '-X',
        action='store',
        type=str,
        default='standard',
        choices=['standard', 'attention', 'sparse-attention'])

    argparser.add_argument('--unification-type',
                           '-U',
                           action='store',
                           type=str,
                           default='classic',
                           choices=['classic', 'joint'])

    argparser.add_argument('--unification-aggregation-type',
                           action='store',
                           type=str,
                           default='min',
                           choices=['min', 'mul', 'minmul'])

    argparser.add_argument('--epoch-based-batches', action='store_true')

    argparser.add_argument('--evaluate-per-epoch', action='store_true')

    argparser.add_argument('--no-ntp0', action='store_true')

    # checkpointing and regular model saving / loading - if checkpoint-path is not None - do checkpointing
    argparser.add_argument('--dump-path', type=str, default=None)
    argparser.add_argument('--checkpoint', action='store_true')
    argparser.add_argument('--checkpoint-frequency', type=int, default=1000)
    argparser.add_argument('--save', action='store_true')
    argparser.add_argument('--load', action='store_true')

    argparser.add_argument('--explanation',
                           '--explain',
                           action='store',
                           type=str,
                           default=None,
                           choices=['train', 'dev', 'test'])

    argparser.add_argument('--profile', action='store_true')
    argparser.add_argument('--tf-profiler', action='store_true')
    argparser.add_argument('--tensorboard', action='store_true')
    argparser.add_argument('--multimax', action='store_true')

    argparser.add_argument('--dev-only', action='store_true')

    argparser.add_argument('--only-rules-epochs',
                           action='store',
                           type=int,
                           default=0)
    argparser.add_argument('--test-batch-size',
                           action='store',
                           type=int,
                           default=None)

    argparser.add_argument('--input-type',
                           action='store',
                           type=str,
                           default='standard',
                           choices=['standard', 'reciprocal'])

    argparser.add_argument('--use-concrete', action='store_true')

    args = argparser.parse_args(argv)

    checkpoint = args.checkpoint
    dump_path = args.dump_path
    save = args.save
    load = args.load

    is_explanation = args.explanation

    nb_epochs = args.epochs
    nb_aux_epochs = args.auxiliary_epochs

    arguments_filename = None
    checkpoint_path = None
    if load:
        logger.info("Loading arguments from the loaded model...")
        arguments_filename = os.path.join(dump_path, 'arguments.json')
        checkpoint_path = os.path.join(dump_path, 'final_model/')
        # load a model, if there's one to load
    elif checkpoint and not check_checkpoint_finished(
            os.path.join(dump_path, 'checkpoints/')):
        checkpoint_path = os.path.join(dump_path, 'checkpoints/')
        logger.info("Loading arguments from an unfinished checkpoint...")
        arguments_filename = os.path.join(dump_path, 'arguments.json')

    loading_type = None

    if arguments_filename is not None and os.path.exists(arguments_filename):
        with open(arguments_filename, 'r') as f:
            json_arguments = json.load(f)
        args = argparse.Namespace(**json_arguments)
        if load:
            loading_type = 'model'
        elif checkpoint and not check_checkpoint_finished(
                os.path.join(dump_path, 'checkpoints/')):
            loading_type = 'checkpoint'

        # Load arguments from json

        # args = argparse.Namespace(**json_arguments)

        # args = vars(args)
        # for k, v in json_arguments.items():
        #     if k in args and args[k] != v:
        #         logger.info("\t{}={} (overriding loaded model's value of {})".format(k, args[k], v))
        #     if k not in args:
        #         args[k] = v
        #         logger.info("\t{}={} (overriding loaded model's value of {})".format(k, args[k], v))

    import pprint
    pprint.pprint(vars(args))

    train_path = args.train
    dev_path = args.dev
    test_path = args.test

    clauses_path = args.clauses
    mentions_path = args.mentions
    mentions_min = args.mentions_min

    input_type = args.input_type

    entity_embedding_size = predicate_embedding_size = args.embedding_size
    symbol_embedding_size = args.embedding_size

    batch_size = args.batch_size
    seed = args.seed

    learning_rate = args.learning_rate
    clip_value = args.clip
    l2_weight = args.l2
    kernel_name = args.kernel

    aux_loss_weight = 1.0
    if 'auxiliary_loss_weight' in args:
        aux_loss_weight = args.auxiliary_loss_weight

    aux_loss_model = args.auxiliary_loss_model

    nb_corrupted_pairs = args.corrupted_pairs
    is_all = args.all

    index_type = args.index_type
    index_refresh_rate = args.index_refresh_rate

    retrieve_k_facts = args.retrieve_k_facts
    retrieve_k_rules = args.retrieve_k_rules

    nms_m = args.nms_m
    nms_efc = args.nms_efc
    nms_efs = args.nms_efs

    k_max = args.k_max
    max_depth = args.max_depth

    evaluation_mode = args.evaluation_mode
    exact_knn_evaluation = args.exact_knn_evaluation

    loss_aggregator = args.loss_aggregator

    has_decode = args.decode

    keep_prob = 1.0
    if 'keep_prob' in args:
        keep_prob = args.keep_prob
    initializer_name = args.initializer

    mixed_losses = args.mixed_losses
    mixed_losses_aggregator_type = args.mixed_losses_aggregator

    rule_embeddings_type = args.rule_embeddings_type

    unification_type = args.unification_type
    unification_aggregation_type = args.unification_aggregation_type

    is_no_ntp0 = args.no_ntp0
    checkpoint_frequency = args.checkpoint_frequency

    profile = args.profile
    tf_profiler = args.tf_profiler
    tensorboard = args.tensorboard

    multimax = args.multimax
    dev_only = args.dev_only

    n_only_rules_epochs = args.only_rules_epochs

    test_batch_size = args.test_batch_size

    if test_batch_size is None:
        test_batch_size = batch_size * (1 + nb_corrupted_pairs * 2 *
                                        (2 if is_all else 1))

    # fire up eager
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    tf.enable_eager_execution(config=config)

    # set the seeds
    tf.set_random_seed(seed)
    np.random.seed(seed)
    random_state = np.random.RandomState(seed)

    epoch_based_batches = args.epoch_based_batches
    evaluate_per_epoch = args.evaluate_per_epoch
    use_concrete = args.use_concrete

    import multiprocessing

    nms_index_params = {
        'method': 'hnsw',
        'space': 'l2',
        'num_threads': multiprocessing.cpu_count(),
        'm': nms_m,
        'efc': nms_efc,
        'efs': nms_efs
    }

    faiss_index_params = {}
    faiss_index_params_cpu = {}
    try:
        import faiss
        faiss_index_params = {
            'resource':
            faiss.StandardGpuResources() if index_type in {'faiss'} else None
        }
        if faiss_index_params['resource'] is not None:
            faiss_index_params['resource'].noTempMemory()
        faiss_index_params_cpu = {'cpu': True}
    except ImportError:
        pass

    random_index_params = {
        'random_state': random_state,
    }

    index_type_to_params = {
        'nmslib': nms_index_params,
        'faiss-cpu': faiss_index_params_cpu,
        'faiss': faiss_index_params,
        'random': random_index_params,
        'exact': {},
    }

    kernel = gntp.kernels.get_kernel_by_name(kernel_name)

    clauses = []
    if clauses_path:
        with open(clauses_path, 'r') as f:
            clauses += [
                gntp.parse_clause(line.strip()) for line in f.readlines()
            ]

    mention_counts = gntp.read_mentions(mentions_path) if mentions_path else []
    mentions = [(s, pattern, o) for s, pattern, o, c in mention_counts
                if c >= mentions_min]

    data = Data(train_path=train_path,
                dev_path=dev_path,
                test_path=test_path,
                clauses=clauses,
                evaluation_mode=evaluation_mode,
                mentions=mentions,
                input_type=input_type)

    index_store = gntp.lookup.LookupIndexStore(
        index_type=index_type, index_params=index_type_to_params[index_type])

    aux_model = gntp.models.get_model_by_name(aux_loss_model)

    model = gntp.models.NTP(kernel=kernel,
                            max_depth=max_depth,
                            k_max=k_max,
                            retrieve_k_facts=retrieve_k_facts,
                            retrieve_k_rules=retrieve_k_rules,
                            index_refresh_rate=index_refresh_rate,
                            index_store=index_store,
                            unification_type=unification_type)

    neural_kb = NeuralKB(data=data,
                         entity_embedding_size=entity_embedding_size,
                         predicate_embedding_size=predicate_embedding_size,
                         symbol_embedding_size=symbol_embedding_size,
                         model_type='ntp',
                         initializer_name=initializer_name,
                         rule_embeddings_type=rule_embeddings_type,
                         use_concrete=use_concrete)

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

    if loading_type == 'checkpoint':
        logger.info(
            "********** Resuming from an unfinished checkpoint **********")
        # dirty hack, but this initializes optimizer's slots, so the loader can populate them
        optimizer._create_slots(neural_kb.variables)
        checkpoint_load(checkpoint_path, neural_kb, optimizer)

    elif loading_type == 'model':
        load_path = os.path.join(dump_path, 'final_model/')
        checkpoint_load(load_path, neural_kb, optimizer)

    # the batcher will always be run with the starting random_state...
    batcher = Batcher(data,
                      batch_size,
                      nb_epochs,
                      random_state,
                      nb_corrupted_pairs,
                      is_all,
                      nb_aux_epochs,
                      epoch_based_batches=epoch_based_batches)

    batches_per_epoch = batcher.nb_batches / nb_epochs if nb_epochs > 0 else 0

    # ...and after that, if there's a random state to load, load it :)
    if loading_type is not None:
        checkpoint_rs = load_random_state(checkpoint_path)
        random_state.set_state(checkpoint_rs.get_state())

    batch_times = []
    logger.info('Starting training (for {} batches)..'.format(
        len(batcher.batches)))

    if tf.train.get_or_create_global_step().numpy() > 0:
        logger.info(
            '...checkpoint restoration - resuming from batch no {}'.format(
                tf.train.get_or_create_global_step().numpy() + 1))

    if tensorboard:
        # TODO add changeable params too
        if not os.path.exists(dump_path):
            os.makedirs(dump_path)
        else:
            # this should never happen
            pass

        writer = tf.contrib.summary.create_file_writer(dump_path)
        writer.set_as_default()

    per_epoch_losses = []

    if tf_profiler:
        profiler = model_analyzer.Profiler()

    start_training_time = time.time()

    n_epochs_finished = 0

    if profile:
        manager = multiprocessing.Manager()
        gpu_memory_profiler_return = manager.list()

        def gpu_memory_profiler():
            """Poll ``nvidia-smi`` forever and record the used GPU memory.

            This function never returns: roughly every 0.1 s it queries
            ``nvidia-smi --query-gpu=memory.used --format=csv`` and appends one
            sample to ``gpu_memory_profiler_return``. When a single GPU can be
            identified (an integer ``CUDA_VISIBLE_DEVICES``, or only one GPU
            row in the output) the sample is the first whitespace-separated
            token of that GPU's row parsed as ``int``; otherwise the sample is
            the list of raw per-GPU rows.

            Raises:
                subprocess.CalledProcessError: if ``nvidia-smi`` exits with a
                    non-zero status.
            """
            import subprocess
            import os
            env = os.environ.copy()
            which_gpu = -1
            visible_devices = env.get('CUDA_VISIBLE_DEVICES')
            if visible_devices is not None:
                try:
                    which_gpu = int(visible_devices)
                except ValueError:
                    # Not a single integer index (e.g. "0,1" or empty):
                    # fall back to recording the rows of all GPUs.
                    pass
            # NOTE(review): LD_LIBRARY_PATH is removed only for the nvidia-smi
            # call; the reason is not visible here - confirm. pop() is used
            # so that an unset variable does not raise KeyError and kill the
            # profiler.
            env.pop('LD_LIBRARY_PATH', None)
            cmd = ["nvidia-smi", "--query-gpu=memory.used", "--format=csv"]
            while True:
                time.sleep(0.1)
                output = subprocess.check_output(cmd, env=env)
                output = output.decode('utf-8')
                output = output.split('\n')
                if len(output) == 3:  # there's only one gpu
                    which_gpu = 0
                # drop the CSV header line and the trailing empty string
                output = output[1:-1]
                if which_gpu > -1:
                    gpu_memory_profiler_return.append(
                        int(output[which_gpu].split()[0]))
                else:
                    gpu_memory_profiler_return.append(output)

        gpu_memory_job = multiprocessing.Process(target=gpu_memory_profiler)
        gpu_memory_job.start()

    is_epoch_end = False
    with context.eager_mode():

        for batch_no, (batch_start, batch_end) in enumerate(batcher.batches):

            if tf_profiler:
                opts = (option_builder.ProfileOptionBuilder(
                    option_builder.ProfileOptionBuilder.
                    trainable_variables_parameter()).with_max_depth(
                        100000).with_step(batch_no).with_timeline_output(
                            'eager_profile').with_accounted_types(['.*'
                                                                   ]).build())

                context.enable_run_metadata()

            # print(sum(random_state.get_state()[1]))

            # TODO fix this - this was here due to checkpointing but causes the first batch to be skipped
            # and will likely cause the test to fail?
            # if tf.train.get_or_create_global_step().numpy() + 1 > batch_no:
            #     continue
            if is_explanation is not None:  # or load_model:
                logger.info("EXPLANATION MODE ON - turning training off!")
                break

            start_time = time.time()

            is_epoch_start = is_epoch_end
            is_epoch_end = (batch_no + 1) - int(
                (batch_no + 1) / batches_per_epoch) * batches_per_epoch < 1

            Xi_batch, Xp_batch, Xs_batch, Xo_batch, target_inputs = batcher.get_batch(
                batch_no, batch_start, batch_end)

            Xi_batch = tf.convert_to_tensor(Xi_batch, dtype=tf.int32)

            # goals should be [GE, GE, GE]
            with tf.GradientTape() as tape:

                if n_only_rules_epochs > n_epochs_finished:
                    is_rules_only = True
                else:
                    is_rules_only = False

                neural_kb.create_neural_kb(is_epoch_start, training=True)

                p_emb = tf.nn.embedding_lookup(neural_kb.relation_embeddings,
                                               Xp_batch)
                s_emb = tf.nn.embedding_lookup(neural_kb.entity_embeddings,
                                               Xs_batch)
                o_emb = tf.nn.embedding_lookup(neural_kb.entity_embeddings,
                                               Xo_batch)

                if keep_prob != 1.0:
                    p_emb = tf.nn.dropout(p_emb, keep_prob)
                    s_emb = tf.nn.dropout(s_emb, keep_prob)
                    o_emb = tf.nn.dropout(o_emb, keep_prob)

                if batcher.is_pretraining:
                    # PRE-TRAINING
                    aux_scores = aux_model.predict(p_emb, s_emb, o_emb)
                    loss = aux_model.loss(target_inputs,
                                          aux_scores,
                                          aggregator=loss_aggregator)
                else:

                    goal_scores, other = model.predict(
                        p_emb,
                        s_emb,
                        o_emb,
                        neural_facts_kb=neural_kb.neural_facts_kb,
                        neural_rules_kb=neural_kb.neural_rules_kb,
                        mask_indices=Xi_batch,
                        is_training=True,
                        target_inputs=target_inputs,
                        mixed_losses=mixed_losses,
                        aggregator_type=mixed_losses_aggregator_type,
                        no_ntp0=is_no_ntp0,
                        support_explanations=is_explanation is not None,
                        unification_score_aggregation=
                        unification_aggregation_type,
                        multimax=multimax,
                        tensorboard=tensorboard)

                    proof_states, new_target_inputs = other

                    if multimax:
                        target_inputs = new_target_inputs

                    model_loss = model.loss(target_inputs,
                                            goal_scores,
                                            aggregator=loss_aggregator)
                    loss = model_loss

                    if aux_loss_weight is not None and aux_loss_weight > 0.0:
                        aux_scores = aux_model.predict(p_emb, s_emb, o_emb)
                        loss_aux = aux_loss_weight * aux_model.loss(
                            target_inputs,
                            aux_scores,
                            aggregator=loss_aggregator)
                        loss += loss_aux

                if l2_weight:
                    loss_l2_weight = l2_weight * tf.add_n(
                        [tf.nn.l2_loss(var) for var in neural_kb.variables])
                    if loss_aggregator == 'mean':
                        num_of_vars = tf.reduce_sum([
                            tf.reduce_prod(var.shape)
                            for var in neural_kb.variables
                        ])
                        loss_l2_weight /= tf.cast(num_of_vars, tf.float32)
                    loss += loss_l2_weight

            # if not is_epoch_end:
            per_epoch_losses.append(loss.numpy())

            logger.info('Loss @ batch {} on {}: {}'.format(
                batch_no, batcher.nb_batches, loss))

            model_variables = neural_kb.get_trainable_variables(
                is_rules_only=is_rules_only)
            gradients = tape.gradient(loss, model_variables)
            grads_and_vars = [(tf.clip_by_value(grad, -clip_value,
                                                clip_value), var)
                              for grad, var in zip(gradients, model_variables)]

            optimizer.apply_gradients(
                grads_and_vars=grads_and_vars,
                global_step=tf.train.get_or_create_global_step())

            if tensorboard:
                with tf.contrib.summary.always_record_summaries():
                    tf.contrib.summary.scalar('loss_total', loss)
                    tf.contrib.summary.scalar('loss_ntp_model', model_loss)
                    if aux_loss_weight is not None and aux_loss_weight > 0.0:
                        tf.contrib.summary.scalar('loss_aux_model', loss_aux)
                    if l2_weight != 0.0:
                        tf.contrib.summary.scalar('loss_l2_weight',
                                                  loss_l2_weight)
                    tf.contrib.summary.histogram('embeddings_relation',
                                                 neural_kb.relation_embeddings)
                    tf.contrib.summary.histogram('embeddings_entity',
                                                 neural_kb.entity_embeddings)

                with tf.contrib.summary.always_record_summaries():
                    for grad, var in grads_and_vars:
                        tf.contrib.summary.scalar(
                            'gradient_sparsity_{}'.format(
                                var.name.replace(':', '__')),
                            tf.nn.zero_fraction(grad))
                        # if batch_end % data.nb_examples == 0 or batch_end % data.nb_examples == 1:
                        #     pdb.set_trace()
                        gradient_norm = tf.sqrt(tf.reduce_sum(tf.pow(grad, 2)))
                        tf.contrib.summary.scalar(
                            'gradient_norm_{}'.format(
                                var.name.replace(':', '__')), gradient_norm)
                        tf.contrib.summary.histogram(
                            'gradient_{}'.format(var.name.replace(':', '__')),
                            grad)
                        tf.contrib.summary.histogram(
                            'variable_{}'.format(var.name.replace(':', '__')),
                            var)
                        # gradient_values = tf.reduce_sum(tf.abs(grad))
                        # tf.contrib.summary.scalar('gradient_values/{}'.format(var.name.replace(':', '__')),
                        #                           gradient_values)

                    # grads = [g for g, _ in grads_and_vars]
                    # flattened_grads = tf.concat([tf.reshape(t, [-1]) for t in grads], axis=0)
                    # flattened_vars = tf.concat([tf.reshape(t, [-1]) for t in neural_kb.variables], axis=0)
                    # tf.contrib.summary.histogram('values_grad', flattened_grads)
                    # tf.contrib.summary.histogram('values_var', flattened_vars)
            if tensorboard:
                with tf.contrib.summary.always_record_summaries():
                    tf.contrib.summary.scalar('time_per_batch',
                                              time.time() - start_time)
            if tensorboard and is_epoch_end:
                with tf.contrib.summary.always_record_summaries():
                    tb_pel = sum(per_epoch_losses)
                    if loss_aggregator == 'mean':
                        tb_pel /= len(per_epoch_losses)
                    tf.contrib.summary.scalar('per_epoch_loss', tb_pel)

            if is_epoch_end:
                n_epochs_finished += 1
                per_epoch_losses = []

            # post-epoch whatever...
            if evaluate_per_epoch and is_epoch_end:
                index_type = 'faiss' if exact_knn_evaluation is None else exact_knn_evaluation
                tmp_exact_knn_eval = exact_knn_evaluation
                if exact_knn_evaluation is None and index_type == 'faiss':
                    tmp_exact_knn_eval = 'faiss'
                do_eval(evaluation_mode,
                        model,
                        neural_kb,
                        data,
                        batcher,
                        batch_size,
                        index_type_to_params,
                        is_no_ntp0,
                        is_explanation,
                        dev_only=True,
                        tensorboard=tensorboard,
                        verbose=True,
                        exact_knn_evaluation=tmp_exact_knn_eval,
                        test_batch_size=test_batch_size)

            # # checkpoint saving
            if checkpoint_path is not None and (batch_no +
                                                1) % checkpoint_frequency == 0:
                checkpoint_store(checkpoint_path, neural_kb, optimizer,
                                 random_state, args)

            if profile:
                if batch_no != 0:  # skip the first one as it's significantly longer (warmup?)
                    batch_times.append(time.time() - start_time)
                if batch_no == 10:
                    break

            if tf_profiler:
                profiler.add_step(batch_no, context.export_run_metadata())
                context.disable_run_metadata()
                # profiler.profile_operations(opts)
                profiler.profile_graph(options=opts)

    end_time = time.time()

    if tf_profiler:
        profiler.advise(options=model_analyzer.ALL_ADVICE)

    if profile:
        gpu_memory_job.terminate()
        if len(gpu_memory_profiler_return) == 0:
            gpu_memory_profiler_return = [0]
        nb_negatives = nb_corrupted_pairs * 2 * (2 if is_all else 1)
        nb_triple_variants = 1 + nb_negatives
        examples_per_batch = nb_triple_variants * batch_size
        print('Examples per batch: {}'.format(examples_per_batch))
        print('Batch times: {}'.format(batch_times))
        time_per_batch = np.average(batch_times)
        print('Average time per batch: {}'.format(time_per_batch))
        print('examples per second: {}'.format(examples_per_batch /
                                               time_per_batch))
    else:
        if is_explanation is None:
            logger.info('Training took {} seconds'.format(end_time -
                                                          start_training_time))

    # last checkpoint save
    if checkpoint_path is not None:
        checkpoint_store(checkpoint_path, neural_kb, optimizer, random_state,
                         args)

    # and save the model, if requested (it's better practice to keep
    # checkpoint_path different from save_path, as one can save checkpoints on scratch and models permanently)
    if save:
        save_path = os.path.join(dump_path, 'final_model/')
        checkpoint_store(save_path, neural_kb, optimizer, random_state, args)

    # TODO prettify profiling
    if profile:
        return max(gpu_memory_profiler_return)

    logger.info('Starting evaluation ..')

    neural_kb.create_neural_kb()

    idx_to_relation = {
        idx: relation
        for relation, idx in data.relation_to_idx.items()
    }

    if has_decode:
        for neural_rule in neural_kb.neural_rules_kb:
            gntp.decode(neural_rule,
                        neural_kb.relation_embeddings,
                        idx_to_relation,
                        kernel=kernel)

    # explanations for the train set, just temporarily
    if is_explanation is not None:

        from gntp.util import make_batches

        which_triples = []
        if is_explanation == 'train':
            which_triples = data.train_triples
        elif is_explanation == 'dev':
            which_triples = data.dev_triples
        elif is_explanation == 'test':
            which_triples = data.test_triples

        _triples = [(data.entity_to_idx[s], data.predicate_to_idx[p],
                     data.entity_to_idx[o]) for s, p, o in which_triples]

        batches = make_batches(len(_triples), batch_size)

        explanations_filename = 'explanations-{}-{}.txt'.format(
            checkpoint_path.replace('/', '_'), is_explanation)
        with open(explanations_filename, 'w') as fw:
            for neural_rule in neural_kb.neural_rules_kb:
                decoded_rules = gntp.decode(neural_rule,
                                            neural_kb.relation_embeddings,
                                            idx_to_relation,
                                            kernel=kernel)

                for decoded_rule in decoded_rules:
                    fw.write(decoded_rule + '\n')

            fw.write('--' * 50 + '\n')

            for start, end in batches:
                batch = np.array(_triples[start:end])
                Xs_batch, Xp_batch, Xo_batch = batch[:, 0], batch[:,
                                                                  1], batch[:,
                                                                            2]

                _p_emb = tf.nn.embedding_lookup(neural_kb.relation_embeddings,
                                                Xp_batch)
                _s_emb = tf.nn.embedding_lookup(neural_kb.entity_embeddings,
                                                Xs_batch)
                _o_emb = tf.nn.embedding_lookup(neural_kb.entity_embeddings,
                                                Xo_batch)

                _res, (proof_states, _) = model.predict(
                    _p_emb,
                    _s_emb,
                    _o_emb,
                    neural_facts_kb=neural_kb.neural_facts_kb,
                    neural_rules_kb=neural_kb.neural_rules_kb,
                    is_training=False,
                    no_ntp0=is_no_ntp0,
                    support_explanations=is_explanation is not None)

                # path_indices = decode_per_path_type_proof_states_indices(proof_states)
                path_indices = decode_proof_states_indices(proof_states,
                                                           top_k=3)
                decoded_paths = decode_paths(path_indices, neural_kb)

                _ps, _ss, _os = Xp_batch.tolist(), Xs_batch.tolist(
                ), Xo_batch.tolist()
                __triples = [(data.idx_to_entity[s], data.idx_to_predicate[p],
                              data.idx_to_entity[o])
                             for s, p, o in zip(_ss, _ps, _os)]

                _scores = _res.numpy().tolist()

                for i, (_triple, _score, decoded_path) in enumerate(
                        zip(__triples, _scores, decoded_paths)):
                    _s, _p, _o = _triple
                    _triple_str = '{}({}, {})'.format(_p, _s, _o)

                    # print(_triple_str, _score, decoded_path)
                    fw.write("{}\t{}\t{}\n".format(_triple_str, _score,
                                                   decoded_path))
        logging.info('DONE with explanation...quitting.')
        sys.exit(0)

    eval_start = time.time()
    do_eval(evaluation_mode,
            model,
            neural_kb,
            data,
            batcher,
            batch_size,
            index_type_to_params,
            is_no_ntp0,
            is_explanation,
            dev_only=dev_only,
            exact_knn_evaluation=exact_knn_evaluation,
            test_batch_size=test_batch_size)
    logging.info('Evaluation took {} seconds'.format(time.time() - eval_start))
Пример #23
0
  def testMultiStepProfile(self):
    """Checks that name-scope profiles grow as traced steps are added.

    Builds the splittable model, takes one profile before any step is
    registered and one after each of three fully traced runs (r1, r2, r3),
    then verifies which nodes are visible in each profile and that the
    advisor returns the expected checkers.
    """
    ops.reset_default_graph()
    opts = builder.time_and_memory(min_bytes=0)

    def find(profile, node_name):
      # Thin alias to keep the assertions below short.
      return lib.SearchTFProfNode(profile, node_name)

    with session.Session() as sess:
      r1, r2, r3 = lib.BuildSplittableModel()
      sess.run(variables.global_variables_initializer())

      profiler = model_analyzer.Profiler(sess.graph)

      def run_traced_step(fetch, step):
        # Run `fetch` with full tracing, register the collected metadata
        # under `step`, and return a fresh name-scope profile.
        meta = config_pb2.RunMetadata()
        sess.run(fetch,
                 options=config_pb2.RunOptions(
                     trace_level=config_pb2.RunOptions.FULL_TRACE),
                 run_metadata=meta)
        profiler.add_step(step, meta)
        return profiler.profile_name_scope(opts)

      def check_nodes(profile, present, absent):
        # Nodes in `present` must be found in `profile`, nodes in `absent`
        # must not.
        for node_name in present:
          self.assertNotEqual(find(profile, node_name), None)
        for node_name in absent:
          self.assertEqual(find(profile, node_name), None)

      # Profile taken before any step has been added.
      pb0 = profiler.profile_name_scope(opts)

      pb1 = run_traced_step(r1, 1)
      check_nodes(pb1, present=('DW',), absent=('DW2', 'add'))

      pb2 = run_traced_step(r2, 2)
      check_nodes(pb2, present=('DW', 'DW2'), absent=('add',))

      pb3 = run_traced_step(r3, 3)
      check_nodes(pb3, present=('DW', 'DW2', 'add'), absent=())

      # Each node is missing from the profile taken before its step and has
      # a positive execution time in the profile taken right after it.
      for earlier, later, node_name in ((pb0, pb1, 'Conv2D'),
                                        (pb1, pb2, 'Conv2D_1'),
                                        (pb2, pb3, 'add')):
        self.assertEqual(find(earlier, node_name), None)
        self.assertGreater(find(later, node_name).exec_micros, 0)

      advice_pb = profiler.advise(model_analyzer.ALL_ADVICE)
      for checker_name in ('AcceleratorUtilizationChecker',
                           'ExpensiveOperationChecker',
                           'OperationChecker'):
        self.assertTrue(checker_name in advice_pb.checkers)

      accelerator_checker = advice_pb.checkers['AcceleratorUtilizationChecker']
      if test.is_gpu_available():
        self.assertGreater(len(accelerator_checker.reports), 0)
      else:
        self.assertEqual(len(accelerator_checker.reports), 0)

      expensive_checker = advice_pb.checkers['ExpensiveOperationChecker']
      self.assertGreater(len(expensive_checker.reports), 0)
Пример #24
0
    def train(sess):
        """Train for ``max_epoch`` epochs, then evaluate on the test data.

        Runs ``init``, then for every epoch updates the learning rate, feeds
        the batches from ``reader.get_data_iter`` through ``train_op`` while
        carrying the hidden/cell state across batches, prints the running
        perplexity and evaluates on ``valid_data``. When ``args.profile`` is
        set, every step is run with full tracing and added to a tfprof
        profiler, each epoch stops at batch id 10, and the per-operation
        profile is printed at the end.

        Args:
            sess: the session used to run the graph.
        """
        sess.run(init)

        if args.profile:
            profiler_step = 0
            profiler = model_analyzer.Profiler(graph=sess.graph)
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()

        total_time = 0.0
        epoch_times = []

        for epoch_id in xrange(max_epoch):
            batch_times = []
            epoch_start_time = time.time()
            train_data_iter = reader.get_data_iter(train_data, batch_size,
                                                   num_steps)

            # assign lr, update the learning rate
            new_lr_1 = base_learning_rate * (
                lr_decay ** max(epoch_id + 1 - epoch_start_decay, 0.0))
            sess.run(lr_update, {new_lr: new_lr_1})

            total_loss = 0.0
            iters = 0
            batch_len = len(train_data) // batch_size
            epoch_size = (batch_len - 1) // num_steps

            if args.profile:
                log_fre = 1
            else:
                # Clamp to 1: for epochs shorter than 10 batches the integer
                # division yields 0 and `batch_id % log_fre` would raise
                # ZeroDivisionError.
                log_fre = max(1, epoch_size // 10)

            init_h = np.zeros((num_layers, batch_size, hidden_size),
                              dtype='float32')
            init_c = np.zeros((num_layers, batch_size, hidden_size),
                              dtype='float32')

            for batch_id, batch in enumerate(train_data_iter):
                x, y = batch
                feed_dict = {
                    feeding_list[0]: x,
                    feeding_list[1]: y,
                    feeding_list[2]: init_h,
                    feeding_list[3]: init_c,
                }

                batch_start_time = time.time()
                if args.profile:
                    output = sess.run([cost, final_h, final_c, train_op],
                                      feed_dict,
                                      options=run_options,
                                      run_metadata=run_metadata)
                    profiler.add_step(step=profiler_step,
                                      run_meta=run_metadata)
                    profiler_step = profiler_step + 1
                    if batch_id >= 10:
                        # The last profiled batch is executed but is neither
                        # timed nor added to the loss statistics.
                        break
                else:
                    output = sess.run([cost, final_h, final_c, train_op],
                                      feed_dict)
                batch_time = time.time() - batch_start_time
                batch_times.append(batch_time)

                # Carry the recurrent state over to the next batch.
                train_cost = output[0]
                init_h = output[1]
                init_c = output[2]

                total_loss += train_cost
                iters += num_steps
                if batch_id > 0 and batch_id % log_fre == 0:
                    ppl = np.exp(total_loss / iters)
                    print("-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f, lr: %.5f" % (epoch_id, batch_id, batch_time, ppl, new_lr_1))

            # Guard against an epoch in which no batch was accumulated.
            ppl = np.exp(total_loss / iters) if iters else float('nan')
            epoch_time = time.time() - epoch_start_time
            epoch_times.append(epoch_time)
            total_time += epoch_time

            # Throughput over the batches that were actually timed; the
            # original `batch_id + 1` over-counted by one in profile mode and
            # was undefined when the iterator was empty.
            timed_seconds = sum(batch_times)
            steps_per_sec = (len(batch_times) / timed_seconds
                             if timed_seconds > 0 else 0.0)

            print("\nTrain epoch:[%d]; epoch Time: %.5f s; ppl: %.5f; avg_time: %.5f steps/s\n"
                  % (epoch_id, epoch_time, ppl, steps_per_sec))

            valid_ppl, _ = eval(sess, valid_data)
            print("Valid ppl: %.5f" % valid_ppl)

        test_ppl, test_time = eval(sess, test_data)
        print("Test Time (total): %.5f, ppl: %.5f" % (test_time, test_ppl))

        if args.profile:
            # Report per-operation time and occurrence counts, slowest first.
            profile_op_opt_builder = option_builder.ProfileOptionBuilder()
            profile_op_opt_builder.select(['micros', 'occurrence'])
            profile_op_opt_builder.order_by('micros')
            profile_op_opt_builder.with_max_depth(50)
            profiler.profile_operations(profile_op_opt_builder.build())
Пример #25
0
    def run(self,
            use_gpu,
            feed=None,
            repeat=1,
            log_level=0,
            check_output=False,
            profile=False):
        """Run ``self.fetch_list`` ``repeat`` times and report the timings.

        Args:
            use_gpu: passed to ``self._init_session``; also selects the
                "GPU"/"CPU" device label in the printed statistics.
            feed: feed dict for ``sess.run``; when ``None`` it is obtained
                from ``self._feed_random_data()``.
            repeat: number of times the fetch list is executed.
            log_level: forwarded to ``utils.print_stat``.
            check_output: when true, the outputs of every run are collected
                in a local list (which this method does not read afterwards).
            profile: when true, every run is fully traced, added to a tfprof
                profiler and dumped as a Chrome trace to
                ``<self.name>_tf.timeline``; a per-operation profile is
                printed after the last run.

        Returns:
            The outputs of the last ``sess.run`` call, or ``None`` when the
            loop did not execute.
        """
        sess = self._init_session(use_gpu)

        if profile:
            profiler = model_analyzer.Profiler(graph=sess.graph)
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
        else:
            profiler = None
            run_options = None
            run_metadata = None
        self.timeline_dict = None

        if feed is None:
            feed = self._feed_random_data()

        runtimes = []
        fetches = []
        outputs = None
        for i in range(repeat):
            begin = time.time()
            outputs = sess.run(fetches=self.fetch_list,
                               feed_dict=feed,
                               options=run_options,
                               run_metadata=run_metadata)
            end = time.time()
            runtimes.append(end - begin)

            if profile:
                # Update profiler
                profiler.add_step(step=i, run_meta=run_metadata)
                # For timeline: the file is rewritten on every iteration, so
                # it ends up holding the trace of the last run. The context
                # manager closes (and flushes) the handle each time instead
                # of leaking one open file per iteration.
                tl = timeline.Timeline(run_metadata.step_stats)
                chrome_trace = tl.generate_chrome_trace_format()
                with open(self.name + '_tf.timeline', 'w') as trace_file:
                    trace_file.write(chrome_trace)

            if check_output:
                fetches.append(outputs)

        if profile:
            # Generate profiling result: per-operation time and occurrence
            # counts, ordered by time.
            profile_op_builder = option_builder.ProfileOptionBuilder()
            profile_op_builder.select(['micros', 'occurrence'])
            profile_op_builder.order_by('micros')
            profile_op_builder.with_max_depth(10)
            profiler.profile_operations(profile_op_builder.build())

        stats = {
            "framework": "tensorflow",
            "version": tf.__version__,
            "name": self.name,
            "total": runtimes
        }
        stats["device"] = "GPU" if use_gpu else "CPU"
        utils.print_stat(stats, log_level=log_level)
        return outputs
Пример #26
0
def train(args, data, show_loss, show_topk):
    """Train a KGCN model and print tfprof parameter and op-time reports.

    Every epoch shuffles the training set in place, runs each full minibatch
    with full tracing enabled, evaluates CTR metrics on the train/eval/test
    splits and prints them. After the last epoch two profiler reports are
    produced: trainable-parameter counts per name scope and execution time
    per op type.

    Args:
        args: Settings object; this function reads ``dataset``, ``n_epochs``
            and ``batch_size`` from it and forwards it to ``KGCN``.
        data: Indexable holding, in order: n_user, n_item, n_entity,
            n_relation, train_data, eval_data, test_data, adj_entity,
            adj_relation.
        show_loss: Not used by the current implementation.
        show_topk: Forwarded to ``topk_settings``.
    """
    n_user, n_item, n_entity, n_relation = data[0], data[1], data[2], data[3]
    train_data, eval_data, test_data = data[4], data[5], data[6]
    adj_entity, adj_relation = data[7], data[8]

    model = KGCN(args, n_user, n_entity, n_relation, adj_entity, adj_relation)

    # Top-K evaluation settings (computed here but not consumed below).
    user_list, train_record, test_record, item_set, k_list = topk_settings(
        show_topk, train_data, test_data, n_item)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Collect per-run statistics so tfprof can report on them afterwards.
        profiler = model_analyzer.Profiler(graph=sess.graph)
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
        # TensorBoard event writer.
        writer = tf.summary.FileWriter('../data/' + args.dataset + '/logs',
                                       tf.get_default_graph())
        # Guarantee the event file is flushed and closed even if training,
        # evaluation or profiling raises.
        try:
            for step in range(args.n_epochs):
                t = time.time()
                np.random.shuffle(train_data)

                # Only full minibatches are used; a trailing incomplete
                # batch (size < batch_size) is skipped.
                batch_starts = range(
                    0, train_data.shape[0] - args.batch_size + 1,
                    args.batch_size)
                for i, start in enumerate(batch_starts):
                    model.train(
                        sess,
                        get_feed_dict(model, train_data, start,
                                      start + args.batch_size),
                        run_options, run_metadata)
                    # The epoch index is used as the profiler step, so all
                    # batches of one epoch are recorded under the same step.
                    profiler.add_step(step=step, run_meta=run_metadata)
                    if i == 0:
                        # Log run metadata once per epoch (first batch only).
                        writer.add_run_metadata(run_metadata,
                                                'step %d' % step)

                # CTR evaluation
                train_auc, train_f1 = ctr_eval(sess, model, train_data,
                                               args.batch_size)
                eval_auc, eval_f1 = ctr_eval(sess, model, eval_data,
                                             args.batch_size)
                test_auc, test_f1 = ctr_eval(sess, model, test_data,
                                             args.batch_size)

                # NOTE: measured after the evaluations above, so the printed
                # "training time" includes the CTR evaluation time as well.
                train_time = time.time() - t

                print(
                    'epoch %d   training time: %.5f    train auc: %.4f  f1: %.4f    eval auc: %.4f  f1: %.4f    test auc: %.4f  f1: %.4f'
                    % (step, train_time, train_auc, train_f1, eval_auc,
                       eval_f1, test_auc, test_f1))

            # Report 1: trainable-variable parameter counts, shown in the
            # name-scope view and ordered by parameter count.
            profile_scope_opt_builder = option_builder.ProfileOptionBuilder(
                option_builder.ProfileOptionBuilder
                .trainable_variables_parameter())
            profile_scope_opt_builder.select(['params'])
            profile_scope_opt_builder.order_by('params')
            profiler.profile_name_scope(profile_scope_opt_builder.build())

            # Report 2: most time-consuming op types, shown in the op view.
            # Fields: op execution time (summed over every node that uses
            # the op) and the number of nodes using the op.
            profile_op_opt_builder = option_builder.ProfileOptionBuilder()
            profile_op_opt_builder.select(['micros', 'occurrence'])
            profile_op_opt_builder.order_by('micros')
            # Limit the report with max depth 6.
            profile_op_opt_builder.with_max_depth(6)
            profiler.profile_operations(profile_op_opt_builder.build())
        finally:
            writer.close()
Пример #27
0
def main(_):
    """Run the model over the ``.jpg`` files in ``./demo`` and profile it.

    Builds an evaluation graph, restores the checkpoint returned by
    ``get_checkpoint()``, runs every ``.jpg`` file found in ``./demo`` with
    full tracing, writes the merged image summary for each run to
    ``./demo/test_out/`` and finally asks tfprof for a time-and-memory
    timeline of step 3, written under ``/tmp/profiler.json``.

    Args:
        _: Unused positional argument.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    with tf.Graph().as_default():
        out_shape = [FLAGS.train_image_size] * 2

        image_input = tf.placeholder(tf.uint8, shape=(None, None, 3))
        shape_input = tf.placeholder(tf.int32, shape=(2,))

        features = cls_preprocessing.preprocess_for_eval(
            image_input, out_shape, data_format=FLAGS.data_format,
            output_rgb=False)
        features = tf.expand_dims(features, axis=0)

        with tf.variable_scope(FLAGS.model_scope, default_name=None,
                               values=[features], reuse=tf.AUTO_REUSE):
            model = cls_reg_net.CLS_REG_Model(
                FLAGS.resnet_size, FLAGS.resnet_version,
                FLAGS.attention_block, FLAGS.location_feature_stage,
                FLAGS.data_format)

            results = model(features, training=False)
            # The model yields (logits, loc, location) only when the
            # location feature stage is enabled; otherwise just logits.
            # Build the fetch list to match, so a disabled stage does not
            # reference the unbound names `loc` / `location`.
            if FLAGS.location_feature_stage:
                logits, loc, location = results
                fetches = [logits, loc, location]
            else:
                fetches = [results]

        tf.summary.image('origin_pic', tf.transpose(features, [0, 2, 3, 1]))
        merged = tf.summary.merge_all()
        # The merged summary is always the last fetch.
        fetches.append(merged)

        saver = tf.train.Saver()
        with tf.Session() as sess:
            # Create the profiler object.
            my_profiler = model_analyzer.Profiler(graph=sess.graph)
            # Create the run-metadata object and request full tracing.
            run_metadata = tf.RunMetadata()
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            init = tf.global_variables_initializer()
            sess.run(init)
            saver.restore(sess, get_checkpoint())

            # Summary writer; closed in `finally` so events are flushed even
            # when a run fails part-way through the directory.
            writer = tf.summary.FileWriter("./demo/test_out/", sess.graph)
            try:
                i = 0
                for picname in os.listdir('./demo'):
                    if picname.split('.')[-1] != 'jpg':
                        # Report and skip anything that is not a .jpg file.
                        print(picname)
                        continue
                    np_image = imread(os.path.join('./demo', picname))

                    print(type(np_image), np_image.shape)
                    fetched = sess.run(
                        fetches,
                        feed_dict={image_input: np_image,
                                   shape_input: np_image.shape[:-1]},
                        options=run_options, run_metadata=run_metadata)
                    my_profiler.add_step(step=i, run_meta=run_metadata)

                    writer.add_summary(fetched[-1], i)
                    i += 1
            finally:
                writer.close()

            # Profile the run time and memory use of every graph node.
            profile_graph_opts_builder = option_builder.ProfileOptionBuilder(
                option_builder.ProfileOptionBuilder.time_and_memory())

            # Emit the result as a timeline file.
            profile_graph_opts_builder.with_timeline_output(
                timeline_file='/tmp/profiler.json')
            # Report the statistics recorded for step 3, i.e. the fourth
            # image processed above.
            # NOTE(review): needs at least four .jpg files in ./demo for
            # that step to exist; confirm this is the intended step.
            profile_graph_opts_builder.with_step(3)

            # Use the graph view.
            my_profiler.profile_graph(profile_graph_opts_builder.build())
def main(_):
    """Train the MNIST deep net with `_swap_to_host` attrs and profile it.

    Builds the network from ``deepnn``, tags several gradient ops with a
    ``_swap_to_host`` attribute, then runs ``FLAGS.iteration_count`` training
    steps with full tracing. Every step is added to a merged Chrome trace and
    to a tfprof profiler, which also writes a per-step timeline. Finally the
    graph definition and the merged trace are written to disk.

    Args:
        _: Unused positional argument.
    """
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir)

    # Small auxiliary sub-graph; `train_relu1` is not run below, but building
    # it adds its nodes and variables to the default graph.
    a = bias_variable([3, 3])
    b = tf.constant(0.2, shape=[3, 3])
    c = tf.constant(10.0, shape=[3, 3])
    d = a + b
    e = tf.multiply(d, c)
    relu1 = tf.nn.relu(e, name='relu1')
    train_relu1 = tf.train.AdamOptimizer(1e-4).minimize(relu1)

    # Model input and integer class labels.
    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.int64, [None])

    # Build the graph for the deep net
    y_conv, keep_prob = deepnn(x)

    with tf.name_scope('loss'):
        cross_entropy = tf.losses.sparse_softmax_cross_entropy(
            labels=y_, logits=y_conv)
    cross_entropy = tf.reduce_mean(cross_entropy)

    with tf.name_scope('adam_optimizer'):
        train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(y_conv, 1), y_)
        correct_prediction = tf.cast(correct_prediction, tf.float32)
    accuracy = tf.reduce_mean(correct_prediction)

    from tensorflow.python.profiler import model_analyzer
    from tensorflow.python.profiler import option_builder
    with tf.Session(config=get_sess_config()) as sess:

        many_runs_timeline = TimeLiner()

        # (operation name, `_swap_to_host` attribute value), applied in order.
        # NOTE(review): presumably the integers select which tensors of the
        # op are swapped to host memory -- confirm against the TF build used.
        swap_to_host_attrs = (
            ('adam_optimizer/gradients/pool1/MaxPool_grad/MaxPoolGrad',
             attr_value_pb2.AttrValue(
                 list=attr_value_pb2.AttrValue.ListValue(i=[0, 1]))),
            ('adam_optimizer/gradients/conv1/Relu_grad/ReluGrad',
             attr_value_pb2.AttrValue(i=1)),
            ('adam_optimizer/gradients/pool2/MaxPool_grad/MaxPoolGrad',
             attr_value_pb2.AttrValue(
                 list=attr_value_pb2.AttrValue.ListValue(i=[0, 1]))),
            ('adam_optimizer/gradients/conv2/Relu_grad/ReluGrad',
             attr_value_pb2.AttrValue(i=1)),
            ('adam_optimizer/gradients/conv2/Conv2D_grad/Conv2DBackpropInput',
             attr_value_pb2.AttrValue(i=2)),
        )
        for op_name, attr_value in swap_to_host_attrs:
            target_op = sess.graph.get_operation_by_name(op_name)
            target_op._set_attr('_swap_to_host', attr_value)

        sess.run(tf.global_variables_initializer())
        profiler = model_analyzer.Profiler(sess.graph)
        # Every training step is run with full tracing.
        full_trace = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        for i in range(FLAGS.iteration_count):
            batch = mnist.train.next_batch(FLAGS.batch_size)
            run_metadata = tf.RunMetadata()
            step_feed = {x: batch[0], y_: batch[1], keep_prob: 0.5}
            sess.run(train_step,
                     feed_dict=step_feed,
                     options=full_trace,
                     run_metadata=run_metadata)

            # Fold this step's Chrome trace into the merged timeline.
            step_timeline = timeline.Timeline(
                step_stats=run_metadata.step_stats)
            chrome_trace = step_timeline.generate_chrome_trace_format(
                show_dataflow=True, show_memory=True)
            many_runs_timeline.update_timeline(chrome_trace)

            profiler.add_step(i, run_metadata)

            # Ask tfprof for a time-and-memory timeline of the current step.
            timeline_builder = option_builder.ProfileOptionBuilder(
                option_builder.ProfileOptionBuilder.time_and_memory())
            timeline_builder.with_step(i)
            timeline_builder.with_timeline_output(
                "./timeline_output/step_" + FLAGS.mem_opt +
                str(FLAGS.batch_size) + str(FLAGS.iteration_count))
            profiler.profile_graph(options=timeline_builder.build())
    chrome_trace_filename = str(FLAGS.batch_size) + str(FLAGS.mem_opt) + "new"
    graph_location = str(FLAGS.batch_size) + str(
        FLAGS.mem_opt) + "_swap_test.pbtxt"
    print('Saving graph to: %s' % graph_location)
    tf.train.write_graph(sess.graph_def, '.', graph_location, as_text=True)
    many_runs_timeline.save(chrome_trace_filename + '.ctf.json')