Example 1
 def testGetVariablesDontReturnsTransients(self):
   with self.test_session():
     with variable_scope.variable_scope('A'):
       variables_lib2.local_variable(0)
     with variable_scope.variable_scope('B'):
       variables_lib2.local_variable(0)
     self.assertEquals([], variables_lib2.get_variables('A'))
     self.assertEquals([], variables_lib2.get_variables('B'))
Example 2
 def testGetVariablesSuffix(self):
   with self.test_session():
     with variable_scope.variable_scope('A'):
       a = variables_lib2.variable('a', [5])
     with variable_scope.variable_scope('A'):
       b = variables_lib2.variable('b', [5])
     self.assertEquals([a], variables_lib2.get_variables(suffix='a'))
     self.assertEquals([b], variables_lib2.get_variables(suffix='b'))
Example 3
 def testGetVariablesReturns(self):
   with self.test_session():
     with variable_scope.variable_scope('A'):
       a = variables_lib2.model_variable('a', [5])
     with variable_scope.variable_scope('B'):
       b = variables_lib2.model_variable('a', [5])
     self.assertEquals([a], variables_lib2.get_variables('A'))
     self.assertEquals([b], variables_lib2.get_variables('B'))
Example 4
 def testGetVariablesWithScope(self):
   with self.test_session():
     with variable_scope.variable_scope('A') as var_scope:
       a = variables_lib2.variable('a', [5])
       b = variables_lib2.variable('b', [5])
     self.assertSetEqual(
         set([a, b]), set(variables_lib2.get_variables(var_scope)))
Example 5
    def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
        # First, train only the weights of the model.
        with ops.Graph().as_default():
            random_seed.set_random_seed(0)
            total_loss = self.ModelLoss()
            optimizer = gradient_descent.GradientDescentOptimizer(
                learning_rate=1.0)
            weights, biases = variables_lib.get_variables()

            train_op = training.create_train_op(total_loss, optimizer)
            train_weights = training.create_train_op(
                total_loss, optimizer, variables_to_train=[weights])
            train_biases = training.create_train_op(
                total_loss, optimizer, variables_to_train=[biases])

            with self.test_session() as session:
                # Initialize the variables.
                session.run(variables_lib2.global_variables_initializer())

                # Get the initial weights and biases values.
                weights_values, biases_values = session.run([weights, biases])
                self.assertGreater(np.linalg.norm(weights_values), 0)
                self.assertAlmostEqual(np.linalg.norm(biases_values), 0)

                # Update weights and biases.
                loss = session.run(train_op)
                self.assertGreater(loss, .5)
                new_weights, new_biases = session.run([weights, biases])

                # Check that the weights and biases have been updated.
                self.assertGreater(
                    np.linalg.norm(weights_values - new_weights), 0)
                self.assertGreater(np.linalg.norm(biases_values - new_biases),
                                   0)

                weights_values, biases_values = new_weights, new_biases

                # Update only weights.
                loss = session.run(train_weights)
                self.assertGreater(loss, .5)
                new_weights, new_biases = session.run([weights, biases])

                # Check that the weights have been updated, but biases have not.
                self.assertGreater(
                    np.linalg.norm(weights_values - new_weights), 0)
                self.assertAlmostEqual(
                    np.linalg.norm(biases_values - new_biases), 0)
                weights_values = new_weights

                # Update only biases.
                loss = session.run(train_biases)
                self.assertGreater(loss, .5)
                new_weights, new_biases = session.run([weights, biases])

                # Check that the biases have been updated, but weights have not.
                self.assertAlmostEqual(
                    np.linalg.norm(weights_values - new_weights), 0)
                self.assertGreater(np.linalg.norm(biases_values - new_biases),
                                   0)
Example 6
 def testReuseVariable(self):
   with self.test_session():
     with variable_scope.variable_scope('A'):
       a = variables_lib2.variable('a', [])
     with variable_scope.variable_scope('A', reuse=True):
       b = variables_lib2.variable('a', [])
     self.assertEquals(a, b)
     self.assertListEqual([a], variables_lib2.get_variables())
Example 7
 def testWrongIncludeGetVariablesToRestore(self):
   with self.test_session():
     with variable_scope.variable_scope('A'):
       a = variables_lib2.variable('a', [5])
     with variable_scope.variable_scope('B'):
       b = variables_lib2.variable('a', [5])
     self.assertEquals([a, b], variables_lib2.get_variables())
     self.assertEquals([], variables_lib2.get_variables_to_restore(['a']))
Example 8
 def testExcludeGetMixedVariablesToRestore(self):
   with self.test_session():
     with variable_scope.variable_scope('A'):
       a = variables_lib2.variable('a', [5])
       b = variables_lib2.variable('b', [5])
     with variable_scope.variable_scope('B'):
       c = variables_lib2.variable('c', [5])
       d = variables_lib2.variable('d', [5])
     self.assertEquals([a, b, c, d], variables_lib2.get_variables())
     self.assertEquals(
         [b, d],
         variables_lib2.get_variables_to_restore(exclude=['A/a', 'B/c']))
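The list returned by get_variables_to_restore is typically handed straight to a Saver. A minimal sketch, reusing the variables_lib2 alias from these examples; the tf import and the checkpoint path are assumptions for illustration:

  restore_vars = variables_lib2.get_variables_to_restore(exclude=['A/a'])
  saver = tf.train.Saver(restore_vars)
  with tf.Session() as sess:
    saver.restore(sess, '/tmp/model.ckpt')  # hypothetical checkpoint path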
Example 9
  def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
    # First, train only the weights of the model.
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      weights, biases = variables_lib.get_variables()

      train_op = training.create_train_op(total_loss, optimizer)
      train_weights = training.create_train_op(
          total_loss, optimizer, variables_to_train=[weights])
      train_biases = training.create_train_op(
          total_loss, optimizer, variables_to_train=[biases])

      with session_lib.Session() as sess:
        # Initialize the variables.
        sess.run(variables_lib2.global_variables_initializer())

        # Get the initial weights and biases values.
        weights_values, biases_values = sess.run([weights, biases])
        self.assertGreater(np.linalg.norm(weights_values), 0)
        self.assertAlmostEqual(np.linalg.norm(biases_values), 0)

        # Update weights and biases.
        loss = sess.run(train_op)
        self.assertGreater(loss, .5)
        new_weights, new_biases = sess.run([weights, biases])

        # Check that the weights and biases have been updated.
        self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
        self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)

        weights_values, biases_values = new_weights, new_biases

        # Update only weights.
        loss = sess.run(train_weights)
        self.assertGreater(loss, .5)
        new_weights, new_biases = sess.run([weights, biases])

        # Check that the weights have been updated, but biases have not.
        self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
        self.assertAlmostEqual(np.linalg.norm(biases_values - new_biases), 0)
        weights_values = new_weights

        # Update only biases.
        loss = sess.run(train_biases)
        self.assertGreater(loss, .5)
        new_weights, new_biases = sess.run([weights, biases])

        # Check that the biases have been updated, but weights have not.
        self.assertAlmostEqual(np.linalg.norm(weights_values - new_weights), 0)
        self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)
Example 10
  def test_variable_reuse(self):
    """Test that variable scopes work and inference runs on a real-ish case."""
    tensor1_ref = array_ops.zeros([6, 5, 7, 3, 3])
    tensor1_examples = array_ops.zeros([4, 5, 7, 3, 3])
    tensor2_ref = array_ops.zeros([4, 2, 3])
    tensor2_examples = array_ops.zeros([2, 2, 3])

    with variable_scope.variable_scope('dummy_scope', reuse=True):
      with self.assertRaisesRegexp(
          ValueError, 'does not exist, or was not created with '
          'tf.get_variable()'):
        virtual_batchnorm.VBN(tensor1_ref)

    vbn1 = virtual_batchnorm.VBN(tensor1_ref, name='vbn1')
    vbn2 = virtual_batchnorm.VBN(tensor2_ref, name='vbn2')

    # Fetch reference and examples after virtual batch normalization. Also
    # fetch in variable reuse case.
    to_fetch = []

    to_fetch.append(vbn1.reference_batch_normalization())
    to_fetch.append(vbn2.reference_batch_normalization())
    to_fetch.append(vbn1(tensor1_examples))
    to_fetch.append(vbn2(tensor2_examples))

    variable_scope.get_variable_scope().reuse_variables()

    to_fetch.append(vbn1.reference_batch_normalization())
    to_fetch.append(vbn2.reference_batch_normalization())
    to_fetch.append(vbn1(tensor1_examples))
    to_fetch.append(vbn2(tensor2_examples))

    self.assertEqual(4, len(contrib_variables_lib.get_variables()))

    with self.session(use_gpu=True) as sess:
      variables_lib.global_variables_initializer().run()
      sess.run(to_fetch)
Example 11
    def test_variable_reuse(self):
        """Test that variable scopes work and inference runs on a real-ish case."""
        tensor1_ref = array_ops.zeros([6, 5, 7, 3, 3])
        tensor1_examples = array_ops.zeros([4, 5, 7, 3, 3])
        tensor2_ref = array_ops.zeros([4, 2, 3])
        tensor2_examples = array_ops.zeros([2, 2, 3])

        with variable_scope.variable_scope('dummy_scope', reuse=True):
            with self.assertRaisesRegexp(
                    ValueError, 'does not exist, or was not created with '
                    'tf.get_variable()'):
                virtual_batchnorm.VBN(tensor1_ref)

        vbn1 = virtual_batchnorm.VBN(tensor1_ref, name='vbn1')
        vbn2 = virtual_batchnorm.VBN(tensor2_ref, name='vbn2')

        # Fetch reference and examples after virtual batch normalization. Also
        # fetch in variable reuse case.
        to_fetch = []

        to_fetch.append(vbn1.reference_batch_normalization())
        to_fetch.append(vbn2.reference_batch_normalization())
        to_fetch.append(vbn1(tensor1_examples))
        to_fetch.append(vbn2(tensor2_examples))

        variable_scope.get_variable_scope().reuse_variables()

        to_fetch.append(vbn1.reference_batch_normalization())
        to_fetch.append(vbn2.reference_batch_normalization())
        to_fetch.append(vbn1(tensor1_examples))
        to_fetch.append(vbn2(tensor2_examples))

        self.assertEqual(4, len(contrib_variables_lib.get_variables()))

        with self.test_session(use_gpu=True) as sess:
            variables_lib.global_variables_initializer().run()
            sess.run(to_fetch)
Example 12
    def build_dqn(self):
        self.w = {}
        self.t_w = {}

        #initializer = tf.contrib.layers.xavier_initializer()
        initializer = tf.truncated_normal_initializer(0, 0.02)
        activation_fn = tf.nn.relu

        # training network
        with tf.variable_scope('prediction'):
            if self.cnn_format == 'NHWC':
                self.s_t = tf.placeholder('float32', [
                    None, self.screen_width, self.screen_height,
                    self.history_length
                ],
                                          name='s_t')
            elif self.cnn_format == 'NCHW':
                self.s_t = tf.placeholder('float32', [
                    None, self.history_length, self.screen_width,
                    self.screen_height
                ],
                                          name='s_t')

            self.l1, self.w['l1_w'], self.w['l1_b'] = conv2d(self.s_t / 255.,
                                                             16, [8, 8],
                                                             [4, 4],
                                                             initializer,
                                                             activation_fn,
                                                             self.cnn_format,
                                                             name='l1')
            self.l2, self.w['l2_w'], self.w['l2_b'] = conv2d(self.l1,
                                                             32, [4, 4],
                                                             [2, 2],
                                                             initializer,
                                                             activation_fn,
                                                             self.cnn_format,
                                                             name='l2')

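            # Flatten the conv features so they can feed the fully connected layers.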
            shape = self.l2.get_shape().as_list()
            self.l2_flat = tf.reshape(
                self.l2,
                [-1, functools.reduce(lambda x, y: x * y, shape[1:])])

            if self.dueling:
                self.value_hid, self.w['l3_val_w'], self.w['l3_val_b'] = \
                    linear(self.l2_flat, 256, activation_fn=activation_fn, name='value_hid')

                self.adv_hid, self.w['l3_adv_w'], self.w['l3_adv_b'] = \
                    linear(self.l2_flat, 256, activation_fn=activation_fn, name='adv_hid')

                self.value, self.w['val_w_out'], self.w['val_w_b'] = \
                  linear(self.value_hid, 1, name='value_out')

                self.advantage, self.w['adv_w_out'], self.w['adv_w_b'] = \
                  linear(self.adv_hid, self.env.action_size, name='adv_out')

                # Average Dueling
                self.q = self.value + (self.advantage - tf.reduce_mean(
                    self.advantage, reduction_indices=1, keep_dims=True))
            else:
                self.l3, self.w['l3_w'], self.w['l3_b'] = linear(
                    self.l2_flat, 256, activation_fn=activation_fn, name='l3')
                self.q, self.w['q_w'], self.w['q_b'] = linear(
                    self.l3, self.env.action_size, name='q')

            self.q_action = tf.argmax(self.q, dimension=1)

        # target network
        with tf.variable_scope('target'):

            if self.cnn_format == 'NHWC':
                self.target_s_t = tf.placeholder('float32', [
                    None, self.screen_width, self.screen_height,
                    self.history_length
                ],
                                                 name='target_s_t')
            else:
                self.target_s_t = tf.placeholder('float32', [
                    None, self.history_length, self.screen_width,
                    self.screen_height
                ],
                                                 name='target_s_t')

            self.target_l1, self.t_w['l1_w'], self.t_w['l1_b'] = conv2d(
                self.target_s_t / 255.,
                16, [8, 8], [4, 4],
                initializer,
                activation_fn,
                self.cnn_format,
                name='target_l1')
            self.target_l2, self.t_w['l2_w'], self.t_w['l2_b'] = conv2d(
                self.target_l1,
                32, [4, 4], [2, 2],
                initializer,
                activation_fn,
                self.cnn_format,
                name='target_l2')

            shape = self.target_l2.get_shape().as_list()
            self.target_l2_flat = tf.reshape(
                self.target_l2,
                [-1, functools.reduce(lambda x, y: x * y, shape[1:])])

            if self.dueling:
                self.t_value_hid, self.t_w['l3_val_w'], self.t_w['l3_val_b'] = \
                    linear(self.target_l2_flat, 256, activation_fn=activation_fn, name='target_value_hid')

                self.t_adv_hid, self.t_w['l3_adv_w'], self.t_w['l3_adv_b'] = \
                    linear(self.target_l2_flat, 256, activation_fn=activation_fn, name='target_adv_hid')

                self.t_value, self.t_w['val_w_out'], self.t_w['val_w_b'] = \
                  linear(self.t_value_hid, 1, name='target_value_out')

                self.t_advantage, self.t_w['adv_w_out'], self.t_w['adv_w_b'] = \
                  linear(self.t_adv_hid, self.env.action_size, name='target_adv_out')

                # Average Dueling
                self.target_q = self.t_value + (
                    self.t_advantage - tf.reduce_mean(
                        self.t_advantage, reduction_indices=1, keep_dims=True))
            else:
                self.target_l3, self.t_w['l3_w'], self.t_w['l3_b'] = \
                    linear(self.target_l2_flat, 256, activation_fn=activation_fn, name='target_l3')
                self.target_q, self.t_w['q_w'], self.t_w['q_b'] = \
                    linear(self.target_l3, self.env.action_size, name='target_q')

            self.target_q_idx = tf.placeholder('int32', [None, None],
                                               'outputs_idx')
            self.target_q_with_idx = tf.gather_nd(self.target_q,
                                                  self.target_q_idx)

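            # Move the target-network variables out of GLOBAL_VARIABLES and into
            # LOCAL_VARIABLES so the default saver and global initializer skip them.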
            global_collection = tf.get_collection_ref(
                tf.GraphKeys.GLOBAL_VARIABLES)
            for var in variables.get_variables(scope="target"):
                tf.add_to_collection(tf.GraphKeys.LOCAL_VARIABLES, var)
                global_collection.remove(var)

        with tf.variable_scope('pred_to_target'):
            self.t_w_input = {}
            self.t_w_assign_op = {}

            for name in self.w.keys():
                self.t_w_assign_op[name] = self.t_w[name].assign(self.w[name])

        # optimizer
        with tf.variable_scope('optimizer'):
            self.target_q_t = tf.placeholder('float32', [None],
                                             name='target_q_t')
            self.action = tf.placeholder('int64', [None], name='action')

            action_one_hot = tf.one_hot(self.action,
                                        self.env.action_size,
                                        1.0,
                                        0.0,
                                        name='action_one_hot')
            q_acted = tf.reduce_sum(self.q * action_one_hot,
                                    reduction_indices=1,
                                    name='q_acted')

            self.delta = self.target_q_t - q_acted
            self.loss = tf.reduce_mean(tf.square(self.delta), name='loss')

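            # Clip each gradient to an L2 norm of at most 40 before applying the update.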
            new_grads_and_vars = []
            grads_and_vars = self.optimizer.compute_gradients(
                self.loss, list(self.w.values()))
            for grad, var in tuple(grads_and_vars):
                new_grads_and_vars.append((tf.clip_by_norm(grad, 40), var))

            self.optim = self.optimizer.apply_gradients(new_grads_and_vars)

            global_collection = tf.get_collection_ref(
                tf.GraphKeys.GLOBAL_VARIABLES)
            for var in variables.get_variables(scope="optimizer"):
                tf.add_to_collection(tf.GraphKeys.LOCAL_VARIABLES, var)
                global_collection.remove(var)

        with tf.variable_scope('summary'):
            scalar_summary_tags = [
                'average.reward', 'average.loss', 'average.q',
                'episode.max reward', 'episode.min reward',
                'episode.avg reward', 'episode.num of game',
                'training.learning_rate'
            ]

            self.summary_placeholders = {}
            self.summary_ops = {}

            for tag in scalar_summary_tags:
                self.summary_placeholders[tag] = tf.placeholder(
                    'float32', None, name=tag.replace(' ', '_'))
                self.summary_ops[tag] = tf.summary.scalar(
                    "%s-%s/%s" % (self.env_name, self.env_type, tag),
                    self.summary_placeholders[tag])

            self.summary_op = tf.summary.merge(list(self.summary_ops.values()),
                                               name='total_summary')

            histogram_summary_tags = ['episode.rewards', 'episode.actions']

            for tag in histogram_summary_tags:
                self.summary_placeholders[tag] = tf.placeholder(
                    'float32', None, name=tag.replace(' ', '_'))
                self.summary_ops[tag] = tf.summary.histogram(
                    tag, self.summary_placeholders[tag])
Example 13
    def build_a3c(self):
        self.w = {}
        self.t_w = {}

        initializer = tf.truncated_normal_initializer(0, 0.02)
        activation_fn = tf.nn.relu
        DQN_type = 'nature'
        data_format = self.cnn_format
        beta = 0.1

        if data_format == 'NHWC':
            self.s_t = tf.placeholder('float32', [
                None, self.screen_width, self.screen_height,
                self.history_length
            ],
                                      name='s_t')
        elif data_format == 'NCHW':
            self.s_t = tf.placeholder('float32', [
                None, self.history_length, self.screen_width,
                self.screen_height
            ],
                                      name='s_t')

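        # Build the network on the GPU for NCHW data and on the CPU for NHWC data.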
        if data_format == 'NCHW':
            device = '/gpu:0'
        elif data_format == 'NHWC':
            device = '/cpu:0'
        else:
            raise ValueError('Unknown data_format: %s' % data_format)

        def flat(layer):
            shape = layer.get_shape().as_list()
            return tf.reshape(
                layer,
                [-1, functools.reduce(lambda x, y: x * y, shape[1:])])

        if DQN_type.lower() == 'nature':
            with tf.variable_scope('Nature_DQN'), tf.device(device):
                self.l0 = tf.div(self.s_t, 255.)
                self.l1, self.w['l1_w'], self.w['l1_b'] = conv2d(
                    self.l0,
                    32, [8, 8], [4, 4],
                    initializer,
                    activation_fn,
                    data_format,
                    name='l1_conv')
                self.l2, self.w['l2_w'], self.w['l2_b'] = conv2d(
                    self.l1,
                    64, [4, 4], [2, 2],
                    initializer,
                    activation_fn,
                    data_format,
                    name='l2_conv')
                self.l3, self.w['l3_w'], self.w['l3_b'] = conv2d(
                    self.l2,
                    64, [3, 3], [1, 1],
                    initializer,
                    activation_fn,
                    data_format,
                    name='l3_conv')

                self.l3_flat = flat(self.l3)

                self.l4, self.w['l4_w'], self.w['l4_b'] = \
                    linear(self.l3_flat, 512, activation_fn=activation_fn, name='l4_linear')
        elif DQN_type.lower() == 'nips':
            with tf.variable_scope('Nips_DQN'), tf.device(device):
                self.l0 = tf.div(self.s_t, 255.)
                self.l1, self.w['l1_w'], self.w['l1_b'] = conv2d(
                    self.l0,
                    16, [8, 8], [4, 4],
                    initializer,
                    activation_fn,
                    data_format,
                    name='l1_conv')
                self.l2, self.w['l2_w'], self.w['l2_b'] = conv2d(
                    self.l1,
                    32, [4, 4], [2, 2],
                    initializer,
                    activation_fn,
                    data_format,
                    name='l2_conv')

                self.l2_flat = flat(self.l2)

                self.l4, self.w['l4_w'], self.w['l4_b'] = \
                    linear(self.l2_flat, 256, activation_fn=activation_fn, name='l4_linear')
        else:
            raise ValueError('Wrong DQN type: %s' % DQN_type)

        def reshape_w(w):
            shape = w.get_shape().as_list()
            return tf.transpose(tf.reshape(w, shape[:2] + [1, -1]),
                                [3, 0, 1, 2])

        # Policy head.
        with tf.variable_scope('policy'):
            # 512 -> action_size
            self.policy_logits, self.w['p_w'], self.w['p_b'] = linear(
                self.l4, self.env.action_size, name='linear')

            with tf.variable_scope('policy'):
                self.policy = tf.nn.softmax(self.policy_logits, name='pi')
            with tf.variable_scope('log_policy'):
                self.log_policy = tf.log(self.policy)
            with tf.variable_scope('policy_entropy'):
                self.policy_entropy = -tf.reduce_sum(
                    self.policy * self.log_policy, 1)

            # with tf.variable_scope('pred_action'):
            # self.sampled_action = tf.multinomial(self.policy_logits, 1)
            # self.sampled_action = batch_sample(self.policy)
            # sampled_action_one_hot = tf.one_hot(self.sampled_action, self.env.action_size, 1., 0.)
            # with tf.variable_scope('log_policy_of_action'):
            # self.log_policy_of_sampled_action = tf.reduce_sum(self.log_policy * sampled_action_one_hot, 1)

        # Value head.
        with tf.variable_scope('value'):
            # 512 -> 1
            self.value, self.w['q_w'], self.w['q_b'] = linear(self.l4,
                                                              1,
                                                              name='linear')

        with tf.variable_scope('optimizer'):
            self.R = tf.placeholder('float32', [None], name='target_reward')
            self.action = tf.placeholder('int64', [None], name='action')

            # self.true_log_policy = tf.placeholder('float32', [None], name='true_action')
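            # Note: sparse_softmax_cross_entropy_with_logits returns the negative
            # log-probability of the chosen action, despite the variable name below.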
            self.true_log_policy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=self.action,
                logits=self.policy_logits,
                name='true_action')

            # TODO: equation on paper and codes of other implementations are different
            with tf.variable_scope('policy_loss'):
                self.policy_loss = -(self.true_log_policy \
                    * (self.R - self.value)) - beta * self.policy_entropy

            with tf.variable_scope('value_loss'):
                self.value_loss = tf.pow(self.R - self.value, 2) / 2

            with tf.variable_scope('total_loss'):
                self.loss = tf.reduce_mean(self.policy_loss + self.value_loss)

            new_grads_and_vars = []
            grads_and_vars = self.optimizer.compute_gradients(
                self.loss, list(self.w.values()))
            for grad, var in tuple(grads_and_vars):
                new_grads_and_vars.append((tf.clip_by_norm(grad, 40), var))

            self.optim = self.optimizer.apply_gradients(new_grads_and_vars)

            global_collection = tf.get_collection_ref(
                tf.GraphKeys.GLOBAL_VARIABLES)
            for var in variables.get_variables(scope="optimizer"):
                tf.add_to_collection(tf.GraphKeys.LOCAL_VARIABLES, var)
                global_collection.remove(var)

        # if global_network != None:
        if False:
            with tf.variable_scope('copy_from_target'):
                copy_ops = []

                for name in self.w.keys():
                    copy_op = self.w[name].assign(global_network.w[name])
                    copy_ops.append(copy_op)

                self.global_copy_op = tf.group(*copy_ops,
                                               name='global_copy_op')
Example 14
 def add_fc_weights_summary(name, path):
     biases = variables.get_variables(path + '/biases')
     weights = variables.get_variables(path + '/weights')
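     # get_variables returns a list of the variables found under the given scope.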
     biases = tf.expand_dims(biases, 1)
     tf.summary.image(name, [tf.transpose(tf.concat([weights, biases], 1))])
Example 15
def main(unused_argv=None):
    with tf.Graph().as_default():
        # Force all input processing onto CPU in order to reserve the GPU for the
        # forward inference and back-propagation.
        device = '/cpu:0' if not FLAGS.ps_tasks else '/job:worker/cpu:0'
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               worker_device=device)):
            inputs, _ = image_utils.imagenet_inputs(FLAGS.batch_size,
                                                    FLAGS.image_size)
            # Load style images and select one at random (for each graph execution, a
            # new random selection occurs)

            _, style_labels, \
                style_gram_matrices = image_utils.style_image_inputs(
                    os.path.expanduser(FLAGS.style_dataset_file),
                    batch_size=FLAGS.batch_size,
                    image_size=FLAGS.image_size,
                    square_crop=True,
                    shuffle=True)

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Process style and weight flags
            num_styles = FLAGS.num_styles
            if FLAGS.style_coefficients is None:
                style_coefficients = [1.0 for _ in range(num_styles)]
            else:
                style_coefficients = ast.literal_eval(FLAGS.style_coefficients)
            if len(style_coefficients) != num_styles:
                raise ValueError(
                    'number of style coefficients differs from number of styles'
                )
            content_weights = ast.literal_eval(FLAGS.content_weights)
            style_weights = ast.literal_eval(FLAGS.style_weights)

            # Rescale style weights dynamically based on the current style image
            style_coefficient = tf.gather(tf.constant(style_coefficients),
                                          style_labels)
            style_weights = dict((key, style_coefficient * value)
                                 for key, value in style_weights.items())

            # Define the model
            stylized_inputs = model.transform(inputs,
                                              alpha=FLAGS.alpha,
                                              normalizer_params={
                                                  'labels': style_labels,
                                                  'num_categories': num_styles,
                                                  'center': True,
                                                  'scale': True
                                              })

            # Compute losses.
            total_loss, loss_dict = learning.total_loss(
                inputs, stylized_inputs, style_gram_matrices, content_weights,
                style_weights)
            for key, value in loss_dict.items():
                tf.summary.scalar(key, value)

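            # Split the transformer variables: only the InstanceNorm (style) parameters
            # are trained below; the remaining transformer weights are restored from the
            # checkpoint.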
            instance_norm_vars = [
                var for var in slim.get_variables('transformer')
                if 'InstanceNorm' in var.name
            ]
            other_vars = [
                var for var in slim.get_variables('transformer')
                if 'InstanceNorm' not in var.name
            ]

            # Function to restore VGG16 parameters.
            init_fn_vgg = slim.assign_from_checkpoint_fn(
                vgg.checkpoint_file(), slim.get_variables('vgg_16'))

            checkpoint = os.path.expanduser(FLAGS.checkpoint)
            if tf.gfile.IsDirectory(checkpoint):
                checkpoint = tf.train.latest_checkpoint(checkpoint)
                tf.logging.info(
                    'loading latest checkpoint file: {}'.format(checkpoint))

            # Function to restore N-styles parameters.
            vars = slim.get_variables(
                'transformer') if FLAGS.restore_all_weights else other_vars
            init_fn_n_styles = slim.assign_from_checkpoint_fn(checkpoint, vars)

            def init_fn(session):
                init_fn_vgg(session)
                init_fn_n_styles(session)

            # Set up training.
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            train_op = slim.learning.create_train_op(
                total_loss,
                optimizer,
                clip_gradient_norm=FLAGS.clip_gradient_norm,
                variables_to_train=instance_norm_vars,
                summarize_gradients=False)

            savertransformer = tf.train.Saver(
                variables.get_variables("transformer"),
                save_relative_paths=True)

            # Run training.
            slim.learning.train(train_op=train_op,
                                logdir=os.path.expanduser(FLAGS.train_dir),
                                log_every_n_steps=FLAGS.log_every_n_steps,
                                master=FLAGS.master,
                                is_chief=FLAGS.task == 0,
                                number_of_steps=FLAGS.train_steps,
                                init_fn=init_fn,
                                saver=savertransformer,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)