示例#1
0
        def session_run_job():
            """Run ``y = a * b + c`` once with gated-gRPC debug watches attached.

            ``variable_size``, ``session_run_results`` and
            ``self._debugger_url`` come from the enclosing scope; the fetched
            value of ``y`` is appended to ``session_run_results``.
            """
            with tf.Session() as sess:
                # Variables are created in the order a, b, c, exactly as listed.
                a, b, c = [
                    tf.Variable([fill] * variable_size, name=label)
                    for fill, label in ((10.0, 'a'), (20.0, 'b'), (30.0, 'c'))
                ]
                y = tf.add(tf.multiply(a, b, name="x"), c, name="y")

                sess.run(tf.global_variables_initializer())

                watched = tf.RunOptions()
                tf_debug.watch_graph(
                    watched,
                    sess.graph,
                    debug_ops="DebugIdentity(gated_grpc=True)",
                    debug_urls=self._debugger_url,
                )
                session_run_results.append(sess.run(y, options=watched))
    def _createTestGraphAndRunOptions(self, sess, gated_grpc=True):
        """Build ``z = (a + b) + (c + d)`` plus RunOptions that debug-watch it.

        Args:
          sess: session whose graph is watched.
          gated_grpc: when true, ``(gated_grpc=True)`` is appended to the
            ``DebugIdentity`` op spec.

        Returns:
          ``(z, run_options)``: the output tensor and RunOptions (with
          ``output_partition_graphs`` enabled) that stream to
          ``self.debug_server_url``.
        """
        a, b, c, d = (
            tf.Variable([init], name=label)
            for init, label in ((1.0, 'a'), (2.0, 'b'), (3.0, 'c'), (4.0, 'd'))
        )
        # Arguments evaluate left to right, so x is created before y, then z.
        z = tf.add(tf.add(a, b, name='x'), tf.add(c, d, name='y'), name='z')

        debug_op = 'DebugIdentity' + ('(gated_grpc=True)' if gated_grpc else '')
        run_options = tf.RunOptions(output_partition_graphs=True)
        tf_debug.watch_graph(run_options,
                             sess.graph,
                             debug_ops=debug_op,
                             debug_urls=self.debug_server_url)
        return z, run_options
  def _createTestGraphAndRunOptions(self, sess, gated_grpc=True):
    """Create the four-variable sum graph and matching debug RunOptions.

    Args:
      sess: session whose graph gets the debug watches.
      gated_grpc: if true, the DebugIdentity op spec carries
        ``(gated_grpc=True)``.

    Returns:
      Tuple ``(z, run_options)`` where ``z = (a + b) + (c + d)`` and
      ``run_options`` streams debug data to ``self.debug_server_url``.
    """
    specs = ((1.0, 'a'), (2.0, 'b'), (3.0, 'c'), (4.0, 'd'))
    a, b, c, d = [tf.Variable([value], name=label) for value, label in specs]
    left = tf.add(a, b, name='x')
    right = tf.add(c, d, name='y')
    z = tf.add(left, right, name='z')

    options = tf.RunOptions(output_partition_graphs=True)
    if gated_grpc:
      debug_op = 'DebugIdentity' + '(gated_grpc=True)'
    else:
      debug_op = 'DebugIdentity'
    tf_debug.watch_graph(options,
                         sess.graph,
                         debug_ops=debug_op,
                         debug_urls=self.debug_server_url)
    return z, options
示例#4
0
        def session_run_job():
            """Increment two int32 variables once under gated-gRPC debug watches.

            Whatever ``sess.run`` returns for the grouped increment op is
            appended to ``session_run_results`` from the enclosing scope.
            """
            with tf.Session() as sess:
                a, b = [
                    tf.Variable(start, dtype=tf.int32, name=label)
                    for start, label in ((10, 'a'), (20, 'b'))
                ]
                one = tf.constant(1, dtype=tf.int32, name='d')
                increments = [
                    tf.assign_add(a, one, name='inc_a'),
                    tf.assign_add(b, one, name='inc_b'),
                ]
                inc_ab = tf.group(increments, name="inc_ab")

                sess.run(tf.global_variables_initializer())
                opts = tf.RunOptions()
                tf_debug.watch_graph(
                    opts,
                    sess.graph,
                    debug_ops="DebugIdentity(gated_grpc=True)",
                    debug_urls=self._debugger_url,
                )
                outcome = sess.run(inc_ab, options=opts)
                session_run_results.append(outcome)
示例#5
0
        def session_run_job():
            """Run ``a += b`` ``steps`` times with gated-gRPC debug watches attached.

            ``steps``, ``session_run_results`` and ``self._debugger_url`` come
            from the enclosing scope; each run's result is appended to
            ``session_run_results``.
            """
            with tf.Session() as sess:
                counter = tf.Variable(10, dtype=tf.int32, name='a')
                delta = tf.Variable(1, dtype=tf.int32, name='b')
                bump = tf.assign_add(counter, delta, name='inc_a')

                sess.run(tf.global_variables_initializer())

                debug_options = tf.RunOptions()
                tf_debug.watch_graph(
                    debug_options,
                    sess.graph,
                    debug_ops="DebugIdentity(gated_grpc=True)",
                    debug_urls=self._debugger_url,
                )

                for _step in range(steps):
                    value = sess.run(bump, options=debug_options)
                    session_run_results.append(value)
示例#6
0
    def testMultipleInt32ValuesOverMultipleRunsAreRecorded(self):
        """Three ``x += 2`` runs each leave DebugNumericSummary health pills.

        ``x`` starts at 10 and is incremented by the constant ``x_inc`` (2)
        three times under debug watches; the single debugger events file must
        then hold three pills for ``x_inc:0`` and for ``x:0``.
        """
        with tf.Session() as sess:
            x_init_val = np.array([10], dtype=np.int32)
            x_inc_val = np.array([2], dtype=np.int32)

            x = tf.Variable(
                tf.constant(x_init_val, shape=[1], name="x_init"), name="x")
            inc_x = tf.assign_add(
                x, tf.constant(x_inc_val, name="x_inc"), name="inc_x")

            sess.run(x.initializer)

            watch_options = tf.RunOptions(output_partition_graphs=True)
            tf_debug.watch_graph(
                watch_options,
                sess.graph,
                debug_ops=["DebugNumericSummary"],
                debug_urls=[self._debug_url],
            )

            # Increase three times.
            for _run in range(3):
                sess.run(inc_x, options=watch_options)

        # Debugger data is stored within a special directory within logdir.
        pattern = os.path.join(
            self._logdir,
            constants.DEBUGGER_DATA_DIRECTORY_NAME,
            "events.debugger*",
        )
        event_files = glob.glob(pattern)
        self.assertEqual(1, len(event_files))

        expected_pills = {
            "x_inc:0:DebugNumericSummary": [x_inc_val] * 3,
            "x:0:DebugNumericSummary": [
                x_init_val,
                x_init_val + x_inc_val,
                x_init_val + 2 * x_inc_val,
            ],
        }
        self._check_health_pills_in_events_file(event_files[0], expected_pills)
示例#7
0
  def _poll_server_till_success(self, max_tries, poll_interval_seconds):
    """Retry a debug-watched Session.run until one attempt succeeds.

    Each attempt opens a new session, creates a variable ``a`` initialized to
    ``[42.0]`` and runs its initializer with ``DebugNumericSummary`` watches
    streaming to ``self._debug_url``.  A ``FailedPreconditionError`` counts
    as a failed attempt (presumably the debug server is not reachable yet);
    the method then sleeps ``poll_interval_seconds`` and tries again.

    Returns:
      True on the first successful attempt, False after ``max_tries``
      failed attempts.
    """
    for _attempt in range(max_tries):
      try:
        with tf.Session() as sess:
          initial_value = np.array([42.0])
          a = tf.Variable(
              tf.constant(initial_value, shape=[1], name="a_init"), name="a")

          watch_options = tf.RunOptions(output_partition_graphs=True)
          tf_debug.watch_graph(watch_options,
                               sess.graph,
                               debug_ops=["DebugNumericSummary"],
                               debug_urls=[self._debug_url])

          sess.run(a.initializer, options=watch_options)
      except tf.errors.FailedPreconditionError:
        time.sleep(poll_interval_seconds)
      else:
        return True

    return False
    def _poll_server_till_success(self, max_tries, poll_interval_seconds):
        """Attempt a debug-watched initializer run up to ``max_tries`` times.

        Every attempt runs the initializer of a ``[42.0]`` variable ``a`` in a
        new session with ``DebugNumericSummary`` watches pointed at
        ``self._debug_url``.  ``FailedPreconditionError`` means the attempt
        failed; the method waits ``poll_interval_seconds`` before retrying.

        Returns:
            True as soon as an attempt completes, otherwise False.
        """
        for _ in range(max_tries):
            try:
                with tf.Session() as sess:
                    seed = np.array([42.0])
                    seed_tensor = tf.constant(seed, shape=[1], name="a_init")
                    a = tf.Variable(seed_tensor, name="a")

                    opts = tf.RunOptions(output_partition_graphs=True)
                    tf_debug.watch_graph(
                        opts,
                        sess.graph,
                        debug_ops=["DebugNumericSummary"],
                        debug_urls=[self._debug_url],
                    )

                    sess.run(a.initializer, options=opts)
            except tf.errors.FailedPreconditionError:
                time.sleep(poll_interval_seconds)
                continue
            return True

        return False
    def _createTestGraphAndRunOptions(self, sess, gated_grpc=True):
        """Build ``z = (a + b) + (c + d)`` and RunOptions that watch the graph.

        Args:
            sess: session whose graph is debug-watched.
            gated_grpc: when true, ``(gated_grpc=True)`` is appended to the
                ``DebugIdentity`` op spec.

        Returns:
            ``(z, run_options)``, where ``run_options`` is a
            ``tf.compat.v1.RunOptions`` with partition graphs enabled that
            streams debug data to ``self.debug_server_url``.
        """
        initial_values = ((1.0, "a"), (2.0, "b"), (3.0, "c"), (4.0, "d"))
        a, b, c, d = [
            tf.Variable([value], name=label) for value, label in initial_values
        ]
        partial_ab = tf.add(a, b, name="x")
        partial_cd = tf.add(c, d, name="y")
        total = tf.add(partial_ab, partial_cd, name="z")

        options = tf.compat.v1.RunOptions(output_partition_graphs=True)
        suffix = "(gated_grpc=True)" if gated_grpc else ""
        tf_debug.watch_graph(
            options,
            sess.graph,
            debug_ops="DebugIdentity" + suffix,
            debug_urls=self.debug_server_url,
        )
        return total, options
示例#10
0
  def testMultipleInt32ValuesOverMultipleRunsAreRecorded(self):
    """Three debug-watched ``x += 2`` runs are recorded as health pills.

    Expects one debugger events file containing three DebugNumericSummary
    pills for ``x_inc:0`` and three for ``x:0``.
    """
    with tf.Session() as sess:
      x_init_val = np.array([10], dtype=np.int32)
      x = tf.Variable(tf.constant(x_init_val, shape=[1], name="x_init"),
                      name="x")

      x_inc_val = np.array([2], dtype=np.int32)
      inc_x = tf.assign_add(x, tf.constant(x_inc_val, name="x_inc"),
                            name="inc_x")

      sess.run(x.initializer)

      opts = tf.RunOptions(output_partition_graphs=True)
      tf_debug.watch_graph(opts,
                           sess.graph,
                           debug_ops=["DebugNumericSummary"],
                           debug_urls=[self._debug_url])

      # Increase three times.
      for _run in range(3):
        sess.run(inc_x, options=opts)

    # Debugger data is stored within a special directory within logdir.
    debugger_dir = os.path.join(self._logdir,
                                constants.DEBUGGER_DATA_DIRECTORY_NAME)
    event_files = glob.glob(os.path.join(debugger_dir, "events.debugger*"))
    self.assertEqual(1, len(event_files))

    x_history = [
        x_init_val,
        x_init_val + x_inc_val,
        x_init_val + 2 * x_inc_val,
    ]
    self._check_health_pills_in_events_file(
        event_files[0],
        {
            "x_inc:0:DebugNumericSummary": [x_inc_val] * 3,
            "x:0:DebugNumericSummary": x_history,
        })
示例#11
0
    def testConcurrentNumericsAlertsAreRegisteredCorrectly(self):
        """Numerics alerts from concurrent Session.run threads are all counted.

        ``u = x / y`` contains +Inf (2.0 / 0.0) and ``v = z @ u`` contains NaN
        (0 * Inf) and -Inf (-1 * Inf).  Each of ``num_threads`` threads runs
        ``v`` ``num_runs_per_thread`` times over its own debug URL path, so
        the report must hold a ``u:0`` entry followed by a ``v:0`` entry whose
        event counts equal the total number of runs.
        """
        num_threads = 3
        num_runs_per_thread = 2
        total_num_runs = num_threads * num_runs_per_thread

        # Before any Session runs, the report ought to be empty.
        self.assertEqual([], self._debug_data_server.numerics_alert_report())

        with tf.compat.v1.Session() as sess:
            x_init_val = np.array([[2.0], [-1.0]])
            y_init_val = np.array([[0.0], [-0.25]])
            z_init_val = np.array([[0.0, 3.0], [-1.0, 0.0]])

            x_init = tf.constant(x_init_val, shape=[2, 1], name="x_init")
            x = tf.Variable(x_init, name="x")
            y_init = tf.constant(y_init_val, shape=[2, 1])
            y = tf.Variable(y_init, name="y")
            z_init = tf.constant(z_init_val, shape=[2, 2])
            z = tf.Variable(z_init, name="z")

            u = tf.compat.v1.div(x, y, name="u")  # Produces an Inf.
            v = tf.matmul(z, u, name="v")  # Produces NaN and Inf.

            sess.run(x.initializer)
            sess.run(y.initializer)
            sess.run(z.initializer)

            run_options_list = []
            for i in range(num_threads):
                run_options = tf.compat.v1.RunOptions(
                    output_partition_graphs=True)
                # Use different grpc:// URL paths so that each thread opens a separate
                # gRPC stream to the debug data server, simulating multi-worker setting.
                tf_debug.watch_graph(
                    run_options,
                    sess.graph,
                    debug_ops=["DebugNumericSummary"],
                    debug_urls=[self._debug_url + "/thread%d" % i])
                run_options_list.append(run_options)

            def run_v(thread_id):
                for _ in range(num_runs_per_thread):
                    sess.run(v, options=run_options_list[thread_id])

            run_threads = []
            for thread_id in range(num_threads):
                thread = threading.Thread(
                    target=functools.partial(run_v, thread_id))
                thread.start()
                run_threads.append(thread)

            for thread in run_threads:
                thread.join()

        report = self._debug_data_server.numerics_alert_report()
        self.assertEqual(2, len(report))

        # First alert: the division output u (only +Inf).
        self.assertTrue(report[0].device_name.lower().endswith("cpu:0"))
        self.assertEqual("u:0", report[0].tensor_name)
        self.assertGreater(report[0].first_timestamp, 0)
        self.assertEqual(0, report[0].nan_event_count)
        self.assertEqual(0, report[0].neg_inf_event_count)
        self.assertEqual(total_num_runs, report[0].pos_inf_event_count)

        # Second alert: the matmul output v (NaN and -Inf).
        self.assertTrue(report[1].device_name.lower().endswith("cpu:0"))
        self.assertEqual("v:0", report[1].tensor_name)
        self.assertGreaterEqual(report[1].first_timestamp,
                                report[0].first_timestamp)
        self.assertEqual(total_num_runs, report[1].nan_event_count)
        self.assertEqual(total_num_runs, report[1].neg_inf_event_count)
        self.assertEqual(0, report[1].pos_inf_event_count)
示例#12
0
    def testRunSimpleNetworkoWithInfAndNaNWorks(self):
        """A graph producing Inf and NaN yields health pills and two alerts.

        ``u = x / y`` contains +Inf (2.0 / 0.0); ``v = z @ u`` contains NaN
        (0 * Inf) and -Inf (-1 * Inf).  After one debug-watched run the
        events file must hold a DebugNumericSummary pill per tensor, and the
        numerics-alert report must list ``u:0`` followed by ``v:0``.
        """
        with tf.compat.v1.Session() as sess:
            x_init_val = np.array([[2.0], [-1.0]])
            y_init_val = np.array([[0.0], [-0.25]])
            z_init_val = np.array([[0.0, 3.0], [-1.0, 0.0]])

            x_init = tf.constant(x_init_val, shape=[2, 1], name="x_init")
            x = tf.Variable(x_init, name="x")
            y_init = tf.constant(y_init_val, shape=[2, 1])
            y = tf.Variable(y_init, name="y")
            z_init = tf.constant(z_init_val, shape=[2, 2])
            z = tf.Variable(z_init, name="z")

            u = tf.compat.v1.div(x, y, name="u")  # Produces an Inf.
            v = tf.matmul(z, u, name="v")  # Produces NaN and Inf.

            sess.run(x.initializer)
            sess.run(y.initializer)
            sess.run(z.initializer)

            run_options = tf.compat.v1.RunOptions(output_partition_graphs=True)
            tf_debug.watch_graph(run_options,
                                 sess.graph,
                                 debug_ops=["DebugNumericSummary"],
                                 debug_urls=[self._debug_url])

            result = sess.run(v, options=run_options)
            self.assertTrue(np.isnan(result[0, 0]))
            self.assertEqual(-np.inf, result[1, 0])

        # Debugger data is stored within a special directory within logdir.
        event_files = glob.glob(
            os.path.join(self._logdir, constants.DEBUGGER_DATA_DIRECTORY_NAME,
                         "events.debugger*"))
        self.assertEqual(1, len(event_files))

        self._check_health_pills_in_events_file(
            event_files[0], {
                "x:0:DebugNumericSummary": [x_init_val],
                "y:0:DebugNumericSummary": [y_init_val],
                "z:0:DebugNumericSummary": [z_init_val],
                "u:0:DebugNumericSummary": [x_init_val / y_init_val],
                "v:0:DebugNumericSummary":
                [np.matmul(z_init_val, x_init_val / y_init_val)],
            })

        report = self._debug_data_server.numerics_alert_report()
        self.assertEqual(2, len(report))

        # First alert: the division output u (only +Inf).
        self.assertTrue(report[0].device_name.lower().endswith("cpu:0"))
        self.assertEqual("u:0", report[0].tensor_name)
        self.assertGreater(report[0].first_timestamp, 0)
        self.assertEqual(0, report[0].nan_event_count)
        self.assertEqual(0, report[0].neg_inf_event_count)
        self.assertEqual(1, report[0].pos_inf_event_count)

        # Second alert: the matmul output v (NaN and -Inf).
        self.assertTrue(report[1].device_name.lower().endswith("cpu:0"))
        self.assertEqual("v:0", report[1].tensor_name)
        self.assertGreaterEqual(report[1].first_timestamp,
                                report[0].first_timestamp)
        self.assertEqual(1, report[1].nan_event_count)
        self.assertEqual(1, report[1].neg_inf_event_count)
        self.assertEqual(0, report[1].pos_inf_event_count)
示例#13
0
  def testConcurrentNumericsAlertsAreRegisteredCorrectly(self):
    """Numerics alerts from concurrent Session.run threads are all counted.

    ``u = x / y`` contains +Inf (2.0 / 0.0) and ``v = z @ u`` contains NaN
    (0 * Inf) and -Inf (-1 * Inf).  Each of ``num_threads`` threads runs ``v``
    ``num_runs_per_thread`` times over its own debug URL path, so the report
    must hold a ``u:0`` entry followed by a ``v:0`` entry whose event counts
    equal the total number of runs.
    """
    num_threads = 3
    num_runs_per_thread = 2
    total_num_runs = num_threads * num_runs_per_thread

    # Before any Session runs, the report ought to be empty.
    self.assertEqual([], self._debug_data_server.numerics_alert_report())

    with tf.Session() as sess:
      x_init_val = np.array([[2.0], [-1.0]])
      y_init_val = np.array([[0.0], [-0.25]])
      z_init_val = np.array([[0.0, 3.0], [-1.0, 0.0]])

      x_init = tf.constant(x_init_val, shape=[2, 1], name="x_init")
      x = tf.Variable(x_init, name="x")
      y_init = tf.constant(y_init_val, shape=[2, 1])
      y = tf.Variable(y_init, name="y")
      z_init = tf.constant(z_init_val, shape=[2, 2])
      z = tf.Variable(z_init, name="z")

      u = tf.div(x, y, name="u")  # Produces an Inf.
      v = tf.matmul(z, u, name="v")  # Produces NaN and Inf.

      sess.run(x.initializer)
      sess.run(y.initializer)
      sess.run(z.initializer)

      run_options_list = []
      for i in range(num_threads):
        run_options = tf.RunOptions(output_partition_graphs=True)
        # Use different grpc:// URL paths so that each thread opens a separate
        # gRPC stream to the debug data server, simulating multi-worker setting.
        tf_debug.watch_graph(run_options,
                             sess.graph,
                             debug_ops=["DebugNumericSummary"],
                             debug_urls=[self._debug_url + "/thread%d" % i])
        run_options_list.append(run_options)

      def run_v(thread_id):
        for _ in range(num_runs_per_thread):
          sess.run(v, options=run_options_list[thread_id])

      run_threads = []
      for thread_id in range(num_threads):
        thread = threading.Thread(target=functools.partial(run_v, thread_id))
        thread.start()
        run_threads.append(thread)

      for thread in run_threads:
        thread.join()

    report = self._debug_data_server.numerics_alert_report()
    self.assertEqual(2, len(report))

    # Device names differ across TF versions (".../cpu:0" vs
    # ".../device:CPU:0"), so only the suffix is checked.
    # First alert: the division output u (only +Inf).
    self.assertTrue(report[0].device_name.lower().endswith("cpu:0"))
    self.assertEqual("u:0", report[0].tensor_name)
    self.assertGreater(report[0].first_timestamp, 0)
    self.assertEqual(0, report[0].nan_event_count)
    self.assertEqual(0, report[0].neg_inf_event_count)
    self.assertEqual(total_num_runs, report[0].pos_inf_event_count)

    # Second alert: the matmul output v (NaN and -Inf).
    self.assertTrue(report[1].device_name.lower().endswith("cpu:0"))
    self.assertEqual("v:0", report[1].tensor_name)
    self.assertGreaterEqual(report[1].first_timestamp,
                            report[0].first_timestamp)
    self.assertEqual(total_num_runs, report[1].nan_event_count)
    self.assertEqual(total_num_runs, report[1].neg_inf_event_count)
    self.assertEqual(0, report[1].pos_inf_event_count)
示例#14
0
  def testRunSimpleNetworkoWithInfAndNaNWorks(self):
    """A graph producing Inf and NaN yields health pills and two alerts.

    ``u = x / y`` contains +Inf (2.0 / 0.0); ``v = z @ u`` contains NaN
    (0 * Inf) and -Inf (-1 * Inf).  After one debug-watched run the events
    file must hold a DebugNumericSummary pill per tensor, and the
    numerics-alert report must list ``u:0`` followed by ``v:0``.
    """
    with tf.Session() as sess:
      x_init_val = np.array([[2.0], [-1.0]])
      y_init_val = np.array([[0.0], [-0.25]])
      z_init_val = np.array([[0.0, 3.0], [-1.0, 0.0]])

      x_init = tf.constant(x_init_val, shape=[2, 1], name="x_init")
      x = tf.Variable(x_init, name="x")
      y_init = tf.constant(y_init_val, shape=[2, 1])
      y = tf.Variable(y_init, name="y")
      z_init = tf.constant(z_init_val, shape=[2, 2])
      z = tf.Variable(z_init, name="z")

      u = tf.div(x, y, name="u")  # Produces an Inf.
      v = tf.matmul(z, u, name="v")  # Produces NaN and Inf.

      sess.run(x.initializer)
      sess.run(y.initializer)
      sess.run(z.initializer)

      run_options = tf.RunOptions(output_partition_graphs=True)
      tf_debug.watch_graph(run_options,
                           sess.graph,
                           debug_ops=["DebugNumericSummary"],
                           debug_urls=[self._debug_url])

      result = sess.run(v, options=run_options)
      self.assertTrue(np.isnan(result[0, 0]))
      self.assertEqual(-np.inf, result[1, 0])

    # Debugger data is stored within a special directory within logdir.
    event_files = glob.glob(
        os.path.join(self._logdir, constants.DEBUGGER_DATA_DIRECTORY_NAME,
                     "events.debugger*"))
    self.assertEqual(1, len(event_files))

    self._check_health_pills_in_events_file(event_files[0], {
        "x:0:DebugNumericSummary": [x_init_val],
        "y:0:DebugNumericSummary": [y_init_val],
        "z:0:DebugNumericSummary": [z_init_val],
        "u:0:DebugNumericSummary": [x_init_val / y_init_val],
        "v:0:DebugNumericSummary": [
            np.matmul(z_init_val, x_init_val / y_init_val)
        ],
    })

    report = self._debug_data_server.numerics_alert_report()
    self.assertEqual(2, len(report))

    # Device names differ across TF versions (".../cpu:0" vs
    # ".../device:CPU:0"), so only the suffix is checked.
    # First alert: the division output u (only +Inf).
    self.assertTrue(report[0].device_name.lower().endswith("cpu:0"))
    self.assertEqual("u:0", report[0].tensor_name)
    self.assertGreater(report[0].first_timestamp, 0)
    self.assertEqual(0, report[0].nan_event_count)
    self.assertEqual(0, report[0].neg_inf_event_count)
    self.assertEqual(1, report[0].pos_inf_event_count)

    # Second alert: the matmul output v (NaN and -Inf).
    self.assertTrue(report[1].device_name.lower().endswith("cpu:0"))
    self.assertEqual("v:0", report[1].tensor_name)
    self.assertGreaterEqual(report[1].first_timestamp,
                            report[0].first_timestamp)
    self.assertEqual(1, report[1].nan_event_count)
    self.assertEqual(1, report[1].neg_inf_event_count)
    self.assertEqual(0, report[1].pos_inf_event_count)
示例#15
0
def train_original(model, config, session=None):
    """Train a WGAN with RMSProp and critic weight clipping.

    Every generator update is preceded by ``config.critic_update_ratio``
    critic updates, each followed by clipping the critic variables to
    ``[-config.clip_size, config.clip_size]``.  An epoch ends when
    ``next(dataset)`` raises ``StopIteration``; a checkpoint is saved after
    every epoch.

    Debug hooks (each terminates the process with ``sys.exit(0)``):
      * ``config.execution_graph_dump_to``: write the default graph to that
        directory (relative to the working directory), then exit.
      * ``config.critic_dump_to`` / ``config.generator_dump_to``: run a single
        critic / generator update with ``tf_debug`` watches dumping to that
        directory, then exit.

    Args:
        model: model exposing ``G``, ``c_loss``, ``g_loss``, ``c_vars``,
            ``g_vars``, ``z_in``, ``image_in`` and ``z_size``.
        config: training configuration.
        session: optional ``tf.Session``; a new one is created when omitted.
            It is used as a context manager (``with session:``) for the
            whole training run.
    """
    import os
    import sys

    # define a session if needed.
    session = session or tf.Session()

    # define summaries.
    summary_writer = tf.summary.FileWriter(config.log_dir, session.graph)
    image_summary = tf.summary.image(
        'generated images', model.G, max_outputs=8
    )
    loss_summaries = tf.summary.merge([
        tf.summary.scalar('wasserstein distance', -model.c_loss),
        tf.summary.scalar('generator loss', model.g_loss),
    ])

    # define optimizers.
    C_trainer = tf.train.RMSPropOptimizer(
        learning_rate=config.learning_rate
    )
    G_trainer = tf.train.RMSPropOptimizer(
        learning_rate=config.learning_rate
    )

    # define parameter update tasks
    c_grads = C_trainer.compute_gradients(model.c_loss, var_list=model.c_vars)
    g_grads = G_trainer.compute_gradients(model.g_loss, var_list=model.g_vars)
    update_C = C_trainer.apply_gradients(c_grads)
    update_G = G_trainer.apply_gradients(g_grads)
    clip_C = [
        v.assign(tf.clip_by_value(v, -config.clip_size, config.clip_size))
        for v in model.c_vars
    ]

    if config.execution_graph_dump_to:
        graph_writer = tf.summary.FileWriter(
            os.path.join(os.getcwd(), config.execution_graph_dump_to),
            tf.get_default_graph(),
        )
        # FileWriter writes asynchronously; close it so the graph event is
        # flushed to disk before the process exits.
        graph_writer.close()
        summary_writer.close()
        sys.exit(0)

    def _debug_run_options(dump_to):
        # RunOptions that debug-watch the session graph and dump to a
        # file:// URL under the working directory.  tf_debug is imported
        # here so it is only loaded when a dump is actually requested.
        from tensorflow.python import debug as tf_debug
        run_options = tf.RunOptions()
        tf_debug.watch_graph(
            run_options,
            session.graph,
            debug_urls=['file://' + os.path.join(os.getcwd(), dump_to)]
        )
        return run_options

    try:
        # main training session context
        with session:
            if config.resume:
                epoch_start = utils.load_checkpoint(session, model, config) + 1
            else:
                epoch_start = 1
                session.run(tf.global_variables_initializer())

            step_counter = 0
            for epoch in range(epoch_start, config.epochs + 1):
                dataset = DATASETS[config.dataset](config.batch_size)
                dataset_length = DATASET_LENGTH_GETTERS[config.dataset]()
                # One progress tick per batch drawn from the dataset.
                dataset_stream = tqdm(total=dataset_length // config.batch_size)

                try:
                    # while theta has not converged do
                    while True:
                        # for t=0,...,n_critic do
                        for _ in range(config.critic_update_ratio):
                            # Sample {x^(i)}[i=1,m]~Pr a batch from the real data.
                            xs = next(dataset)
                            dataset_stream.update()
                            # Sample {z^(i)}[i=1,m]~p(z) a batch of prior samples.
                            zs = _sample_z(config.batch_size, model.z_size)
                            critic_feed = {model.z_in: zs, model.image_in: xs}

                            if config.critic_dump_to:
                                session.run(
                                    [update_C, model.c_loss],
                                    feed_dict=critic_feed,
                                    options=_debug_run_options(config.critic_dump_to)
                                )
                                sys.exit(0)

                            # g_w <- grad_w[mean(f_w(x^(i))) - mean(f_w(g_theta(z^(i))))]
                            # w <- w + alpha * RMSProp(w, g_w)
                            _, c_loss = session.run(
                                [update_C, model.c_loss],
                                feed_dict=critic_feed
                            )
                            # w <- clip(w, -c, c)
                            session.run(clip_C)

                        # end for
                        # Sample {z^(i)}[i=1,m]~p(z) a batch of prior samples.
                        zs = _sample_z(config.batch_size, model.z_size)

                        if config.generator_dump_to:
                            session.run(
                                [update_G, model.g_loss],
                                feed_dict={model.z_in: zs},
                                options=_debug_run_options(config.generator_dump_to)
                            )
                            sys.exit(0)

                        # g_theta <- -grad_theta[mean(f_w(g_theta(z^(i))))]
                        # theta <- theta - alpha * RMSProp(theta, g_theta)
                        _, g_loss = session.run(
                            [update_G, model.g_loss],
                            feed_dict={model.z_in: zs}
                        )

                        # display current training process status
                        step_counter += 1
                        trained = (config.batch_size * step_counter) % dataset_length
                        dataset_stream.set_description((
                            'epoch: {epoch}/{epochs} | '
                            'progress: [{trained}/{total}] ({progress:.0f}%) | '
                            'g loss: {g_loss:.3f} | '
                            'w distance: {w_dist:.3f}'
                        ).format(
                            epoch=epoch,
                            epochs=config.epochs,
                            trained=trained,
                            total=dataset_length,
                            # Parenthesised so the modulo is applied before
                            # scaling to a percentage.
                            progress=100. * trained / dataset_length,
                            g_loss=g_loss,
                            w_dist=-c_loss,
                        ))
                        # log the generated samples
                        if step_counter % config.image_log_interval == 0:
                            zs = _sample_z(config.sample_size, model.z_size)
                            summary_writer.add_summary(session.run(
                                image_summary, feed_dict={
                                    model.z_in: zs
                                }
                            ), step_counter)

                        # log the losses
                        if step_counter % config.loss_log_interval == 0:
                            zs = _sample_z(config.batch_size, model.z_size)
                            summary_writer.add_summary(session.run(
                                loss_summaries, feed_dict={
                                    model.z_in: zs,
                                    model.image_in: xs
                                }
                            ), step_counter)
                    # end while
                except StopIteration:
                    # next(dataset) ran out of batches: this epoch is over.
                    # Only this is caught so real errors and sys.exit() from
                    # the dump hooks still propagate.
                    pass
                finally:
                    dataset_stream.close()

                # save the model at the every end of the epochs.
                utils.save_checkpoint(session, model, epoch, config)
    finally:
        summary_writer.close()
示例#16
0
def train(model, config, session=None):
    """Train a WGAN with Adam and critic weight clipping.

    For every batch the critic is updated ``critic_update_ratio`` times (30
    during the first 25 iterations and every 500th iteration, otherwise
    ``config.critic_update_ratio``), each update followed by clipping the
    critic variables to ``[-config.clip_size, config.clip_size]``; then the
    generator is updated once.  A checkpoint is saved after every epoch.

    Debug hooks (each terminates the process with ``sys.exit(0)``):
      * ``config.execution_graph_dump_to``: write the default graph to that
        directory (relative to the working directory), then exit.
      * ``config.critic_dump_to`` / ``config.generator_dump_to``: run a single
        critic / generator update with ``tf_debug`` watches dumping to that
        directory, then exit.

    Args:
        model: model exposing ``G``, ``c_loss``, ``g_loss``, ``c_vars``,
            ``g_vars``, ``z_in``, ``image_in`` and ``z_size``.
        config: training configuration.
        session: optional ``tf.Session``; a new one is created when omitted.
            It is used as a context manager (``with session:``) for the
            whole training run.
    """
    import os
    import sys

    # define a session if needed.
    session = session or tf.Session()

    # define summaries.
    summary_writer = tf.summary.FileWriter(config.log_dir, session.graph)
    image_summary = tf.summary.image(
        'generated images', model.G, max_outputs=8
    )
    loss_summaries = tf.summary.merge([
        tf.summary.scalar('wasserstein distance', -model.c_loss),
        tf.summary.scalar('generator loss', model.g_loss),
    ])

    # define optimizers.
    C_trainer = tf.train.AdamOptimizer(
        learning_rate=config.learning_rate,
        beta1=config.beta1
    )
    G_trainer = tf.train.AdamOptimizer(
        learning_rate=config.learning_rate,
        beta1=config.beta1,
    )

    # define parameter update tasks
    c_grads = C_trainer.compute_gradients(model.c_loss, var_list=model.c_vars)
    g_grads = G_trainer.compute_gradients(model.g_loss, var_list=model.g_vars)
    update_C = C_trainer.apply_gradients(c_grads)
    update_G = G_trainer.apply_gradients(g_grads)
    clip_C = [
        v.assign(tf.clip_by_value(v, -config.clip_size, config.clip_size))
        for v in model.c_vars
    ]

    if config.execution_graph_dump_to:
        graph_writer = tf.summary.FileWriter(
            os.path.join(os.getcwd(), config.execution_graph_dump_to),
            tf.get_default_graph(),
        )
        # FileWriter writes asynchronously; close it so the graph event is
        # flushed to disk before the process exits.
        graph_writer.close()
        summary_writer.close()
        sys.exit(0)

    def _debug_run_options(dump_to):
        # RunOptions that debug-watch the session graph and dump to a
        # file:// URL under the working directory.  tf_debug is imported
        # here so it is only loaded when a dump is actually requested.
        from tensorflow.python import debug as tf_debug
        run_options = tf.RunOptions()
        tf_debug.watch_graph(
            run_options,
            session.graph,
            debug_urls=['file://' + os.path.join(os.getcwd(), dump_to)]
        )
        return run_options

    try:
        # main training session context
        with session:
            if config.resume:
                epoch_start = utils.load_checkpoint(session, model, config) + 1
            else:
                epoch_start = 1
                session.run(tf.global_variables_initializer())

            for epoch in range(epoch_start, config.epochs + 1):
                dataset = DATASETS[config.dataset](config.batch_size)
                dataset_length = DATASET_LENGTH_GETTERS[config.dataset]()
                batches_per_epoch = dataset_length // config.batch_size
                dataset_stream = tqdm(enumerate(dataset, 1), total=batches_per_epoch)

                for batch_index, xs in dataset_stream:
                    # global iteration count across epochs.
                    iteration = (epoch - 1) * batches_per_epoch + batch_index

                    # place more weight on the critic at the beginning of the
                    # training (and periodically afterwards).  Keyed on the
                    # global iteration so the warm-up is not repeated at the
                    # start of every epoch.
                    critic_update_ratio = (
                        30 if (iteration < 25 or iteration % 500 == 0) else
                        config.critic_update_ratio
                    )

                    # train the critic against the current generator and the data.
                    for _ in range(critic_update_ratio):
                        zs = _sample_z(config.batch_size, model.z_size)
                        critic_feed = {model.z_in: zs, model.image_in: xs}

                        if config.critic_dump_to:
                            session.run(
                                [update_C, model.c_loss],
                                feed_dict=critic_feed,
                                options=_debug_run_options(config.critic_dump_to)
                            )
                            sys.exit(0)

                        _, c_loss = session.run(
                            [update_C, model.c_loss],
                            feed_dict=critic_feed
                        )
                        session.run(clip_C)

                    # train the generator against the current critic.
                    zs = _sample_z(config.batch_size, model.z_size)

                    if config.generator_dump_to:
                        session.run(
                            [update_G, model.g_loss],
                            feed_dict={model.z_in: zs},
                            options=_debug_run_options(config.generator_dump_to)
                        )
                        sys.exit(0)

                    _, g_loss = session.run(
                        [update_G, model.g_loss],
                        feed_dict={model.z_in: zs}
                    )

                    # display current training process status
                    dataset_stream.set_description((
                        'epoch: {epoch}/{epochs} | '
                        'progress: [{trained}/{total}] ({progress:.0f}%) | '
                        'g loss: {g_loss:.3f} | '
                        'w distance: {w_dist:.3f}'
                    ).format(
                        epoch=epoch,
                        epochs=config.epochs,
                        trained=batch_index * config.batch_size,
                        total=dataset_length,
                        progress=(
                            100.
                            * batch_index
                            * config.batch_size
                            / dataset_length
                        ),
                        g_loss=g_loss,
                        w_dist=-c_loss,
                    ))

                    # log the generated samples
                    if iteration % config.image_log_interval == 0:
                        zs = _sample_z(config.sample_size, model.z_size)
                        summary_writer.add_summary(session.run(
                            image_summary, feed_dict={
                                model.z_in: zs
                            }
                        ), iteration)

                    # log the losses
                    if iteration % config.loss_log_interval == 0:
                        zs = _sample_z(config.batch_size, model.z_size)
                        summary_writer.add_summary(session.run(
                            loss_summaries, feed_dict={
                                model.z_in: zs,
                                model.image_in: xs
                            }
                        ), iteration)

                # save the model at the every end of the epochs.
                utils.save_checkpoint(session, model, epoch, config)
    finally:
        summary_writer.close()