Exemplo n.º 1
0
 def testWarmStartVarBothVarsPartitioned(self):
   """Warm-starts a partitioned variable from a differently named one.

   `_create_prev_run_var` sets up `old_scope/fruit_weights` with a [2, 1]
   partitioner. A new, zero-initialized `new_scope/fruit_weights` with the
   same partitioning is warm-started from `self.get_temp_dir()` by naming
   the old tensor explicitly; its shards, stacked along axis 0, must then
   equal the previous run's stacked shards.
   """
   _, prev_parts = self._create_prev_run_var(
       "old_scope/fruit_weights",
       shape=[4, 1],
       initializer=[[0.5], [1.], [1.5], [2.]],
       partitioner=lambda shape, dtype: [2, 1])
   prev_val = np.concatenate([prev_parts[0], prev_parts[1]], axis=0)
   # New session and new graph.
   with ops.Graph().as_default() as g:
     with self.test_session(graph=g) as sess:
       fruit_weights = variable_scope.get_variable(
           "new_scope/fruit_weights",
           shape=[4, 1],
           initializer=[[0.], [0.], [0.], [0.]],
           partitioner=lambda shape, dtype: [2, 1])
       # assertIsInstance reports the offending type on failure, whereas
       # assertTrue(isinstance(...)) only reports "False is not true".
       self.assertIsInstance(fruit_weights, variables.PartitionedVariable)
       # Old and new names differ, so the source tensor is named explicitly.
       ws_util._warmstart_var(
           fruit_weights,
           self.get_temp_dir(),
           prev_tensor_name="old_scope/fruit_weights")
       sess.run(variables.global_variables_initializer())
       # Bind the shards to a new name rather than rebinding `fruit_weights`
       # from a PartitionedVariable to a list.
       fruit_parts = fruit_weights._get_variable_list()
       new_val = np.concatenate(
           [fruit_parts[0].eval(sess), fruit_parts[1].eval(sess)], axis=0)
       self.assertAllEqual(prev_val, new_val)
  def testWarmStartVarMultipleVarsBothPartitioned(self):
    """Warm-starts two partitioned variables with a single call.

    Both the previous-run and the new `fruit_weights` / `other_weights` use a
    [2, 1] partitioner. After `_warmstart_var` and global initialization, each
    new variable's shards, stacked along axis 0, must equal the stacked
    shards returned for the previous run.
    """

    def halve(shape, dtype):
      del shape, dtype  # Unused; always split into two row shards.
      return [2, 1]

    _, prev_vals = self._create_prev_run_multiple_vars(
        var_names=["fruit_weights", "other_weights"],
        shapes=[[4, 1], [4, 1]],
        initializers=[[[0.5], [1.], [1.5], [2.]], [[.05], [.1], [.15], [.2]]],
        partitioners=[halve, halve])

    graph = ops.Graph()
    with graph.as_default():
      with self.test_session(graph=graph) as sess:
        partitioned = [
            variable_scope.get_variable(
                var_name,
                shape=[4, 1],
                initializer=[[0.], [0.], [0.], [0.]],
                partitioner=halve)
            for var_name in ("fruit_weights", "other_weights")
        ]
        # Hand over a fresh list so `partitioned` stays untouched for reuse.
        ws_util._warmstart_var(list(partitioned), self.get_temp_dir())
        sess.run(variables.global_variables_initializer())
        stacked = []
        for pvar in partitioned:
          shards = pvar._get_variable_list()
          stacked.append(
              np.concatenate([shards[0].eval(sess), shards[1].eval(sess)],
                             axis=0))
        for idx, actual in enumerate(stacked):
          self.assertAllEqual(np.concatenate(prev_vals[idx], axis=0), actual)
Exemplo n.º 3
0
  def testWarmStartVar(self):
    """Warm-starts a single unpartitioned variable.

    A freshly created, zero-initialized `fruit_weights` must hold the value
    returned by `_create_prev_run_var` once `_warmstart_var` has been applied
    and the global initializers have run.
    """
    _, expected = self._create_prev_run_var(
        "fruit_weights", initializer=[[0.5], [1.], [1.5], [2.]])

    graph = ops.Graph()
    with graph.as_default():
      with self.test_session(graph=graph) as sess:
        target = variable_scope.get_variable(
            "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
        ws_util._warmstart_var(target, self.get_temp_dir())
        sess.run(variables.global_variables_initializer())
        actual = target.eval(sess)
        self.assertAllEqual(expected, actual)
Exemplo n.º 4
0
  def testWarmStartVarPrevVarPartitioned(self):
    """Warm-starts an unpartitioned variable from a partitioned predecessor.

    The previous-run `fruit_weights` is built with a [2, 1] partitioner,
    while the new `fruit_weights` is a single zero-initialized variable. Its
    warm-started value must equal the previous shards stacked along axis 0.
    """
    _, prev_shards = self._create_prev_run_var(
        "fruit_weights",
        shape=[4, 1],
        initializer=[[0.5], [1.], [1.5], [2.]],
        partitioner=lambda shape, dtype: [2, 1])
    head, tail = prev_shards[0], prev_shards[1]
    expected = np.concatenate([head, tail], axis=0)

    graph = ops.Graph()
    with graph.as_default():
      with self.test_session(graph=graph) as sess:
        target = variable_scope.get_variable(
            "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
        ws_util._warmstart_var(target, self.get_temp_dir())
        sess.run(variables.global_variables_initializer())
        self.assertAllEqual(expected, target.eval(sess))
Exemplo n.º 5
0
    def testWarmStartVarMultipleVarsMixOfPartitions(self):
        """Warm-starts plain, partitioned and flattened variables together.

        `fruit_weights` is unpartitioned; `other_weights` and
        `veggie_weights` use a [2, 1] partitioner. `veggie_weights` is passed
        to `_warmstart_var` as its list of shards instead of the partitioned
        object, so both forms are exercised by one call.
        """

        def split_rows(shape, dtype):
            del shape, dtype  # Unused.
            return [2, 1]

        # First is not partitioned, but the second two are.
        _, prev_vals = self._create_prev_run_multiple_vars(
            var_names=["fruit_weights", "other_weights", "veggie_weights"],
            shapes=[None, [4, 1], [4, 1]],
            initializers=[
                [[0.5], [1.], [1.5], [2.]],
                [[.05], [.1], [.15], [.2]],
                [[5.], [10.], [15.], [20.]],
            ],
            partitioners=[None, split_rows, split_rows])

        graph = ops.Graph()
        with graph.as_default():
            with self.test_session(graph=graph) as sess:
                plain = variable_scope.get_variable(
                    "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
                other = variable_scope.get_variable(
                    "other_weights",
                    shape=[4, 1],
                    initializer=[[0.], [0.], [0.], [0.]],
                    partitioner=split_rows)
                veggie = variable_scope.get_variable(
                    "veggie_weights",
                    shape=[4, 1],
                    initializer=[[0.], [0.], [0.], [0.]],
                    partitioner=split_rows)
                # Flatten one of the partitioned variables: its shards are
                # passed individually instead of the partitioned object.
                to_warm = [plain, other]
                to_warm.extend(veggie._get_variable_list())
                ws_util._warmstart_var(to_warm, self.get_temp_dir())
                sess.run(variables.global_variables_initializer())
                veggie_parts = veggie._get_variable_list()
                veggie_val = np.concatenate(
                    [veggie_parts[0].eval(sess), veggie_parts[1].eval(sess)],
                    axis=0)
                other_parts = other._get_variable_list()
                other_val = np.concatenate(
                    [other_parts[0].eval(sess), other_parts[1].eval(sess)],
                    axis=0)
                self.assertAllEqual(prev_vals[0], plain.eval(sess))
                self.assertAllEqual(np.concatenate(prev_vals[1], axis=0),
                                    other_val)
                self.assertAllEqual(np.concatenate(prev_vals[2], axis=0),
                                    veggie_val)
  def testWarmStartVarMultipleVars(self):
    """Warm-starts two unpartitioned variables with a single call.

    Both `fruit_weights` and `other_weights` start at zero and must hold the
    corresponding values returned by `_create_prev_run_multiple_vars` after
    `_warmstart_var` and global initialization.
    """
    _, prev_vals = self._create_prev_run_multiple_vars(
        var_names=["fruit_weights", "other_weights"],
        initializers=[[[0.5], [1.], [1.5], [2.]], [[.05], [.1], [.15], [.2]]])

    graph = ops.Graph()
    with graph.as_default():
      with self.test_session(graph=graph) as sess:
        fruit = variable_scope.get_variable(
            "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
        other = variable_scope.get_variable(
            "other_weights", initializer=[[0.], [0.], [0.], [0.]])
        ws_util._warmstart_var([fruit, other], self.get_temp_dir())
        sess.run(variables.global_variables_initializer())
        fruit_val = fruit.eval(sess)
        self.assertAllEqual(prev_vals[0], fruit_val)
        other_val = other.eval(sess)
        self.assertAllEqual(prev_vals[1], other_val)
  def testWarmStartVarMultipleVarsMixOfPartitions(self):
    """Warm-starts plain, partitioned and flattened variables together.

    `fruit_weights` is a plain variable, while `other_weights` and
    `veggie_weights` use a [2, 1] partitioner. `veggie_weights` is handed to
    `_warmstart_var` as its list of shards rather than as the partitioned
    object, so both forms are exercised in the same call.
    """

    def two_rows(shape, dtype):
      del shape, dtype  # Unused.
      return [2, 1]

    # First is not partitioned, but the second two are.
    _, prev_vals = self._create_prev_run_multiple_vars(
        var_names=["fruit_weights", "other_weights", "veggie_weights"],
        shapes=[None, [4, 1], [4, 1]],
        initializers=[
            [[0.5], [1.], [1.5], [2.]],
            [[.05], [.1], [.15], [.2]],
            [[5.], [10.], [15.], [20.]],
        ],
        partitioners=[None, two_rows, two_rows])

    graph = ops.Graph()
    with graph.as_default():
      with self.test_session(graph=graph) as sess:
        fruit = variable_scope.get_variable(
            "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
        other = variable_scope.get_variable(
            "other_weights",
            shape=[4, 1],
            initializer=[[0.], [0.], [0.], [0.]],
            partitioner=two_rows)
        veggie = variable_scope.get_variable(
            "veggie_weights",
            shape=[4, 1],
            initializer=[[0.], [0.], [0.], [0.]],
            partitioner=two_rows)
        # Flatten one of the partitioned variables.
        ws_util._warmstart_var(
            [fruit, other] + veggie._get_variable_list(), self.get_temp_dir())
        sess.run(variables.global_variables_initializer())

        def stacked(shards):
          # Join the two row shards back into a single [4, 1] array.
          return np.concatenate(
              [shards[0].eval(sess), shards[1].eval(sess)], axis=0)

        new_veggie = stacked(veggie._get_variable_list())
        new_other = stacked(other._get_variable_list())
        self.assertAllEqual(prev_vals[0], fruit.eval(sess))
        self.assertAllEqual(np.concatenate(prev_vals[1], axis=0), new_other)
        self.assertAllEqual(np.concatenate(prev_vals[2], axis=0), new_veggie)
Exemplo n.º 8
0
    def testWarmStartVarMultipleVars(self):
        """Warm-starts two unpartitioned variables with a single call.

        Each zero-initialized variable must hold the matching value returned
        by `_create_prev_run_multiple_vars` after `_warmstart_var` and global
        initialization.
        """
        _, prev_vals = self._create_prev_run_multiple_vars(
            var_names=["fruit_weights", "other_weights"],
            initializers=[
                [[0.5], [1.], [1.5], [2.]],
                [[.05], [.1], [.15], [.2]],
            ])

        graph = ops.Graph()
        with graph.as_default():
            with self.test_session(graph=graph) as sess:
                new_vars = [
                    variable_scope.get_variable(
                        name, initializer=[[0.], [0.], [0.], [0.]])
                    for name in ("fruit_weights", "other_weights")
                ]
                # Pass a fresh list so `new_vars` is untouched for the checks.
                ws_util._warmstart_var(list(new_vars), self.get_temp_dir())
                sess.run(variables.global_variables_initializer())
                for idx, new_var in enumerate(new_vars):
                    self.assertAllEqual(prev_vals[idx], new_var.eval(sess))
Exemplo n.º 9
0
    def testWarmStartVarMultipleVarsBothPartitioned(self):
        """Warm-starts two partitioned variables with a single call.

        The previous-run and new `fruit_weights` / `other_weights` all use a
        [2, 1] partitioner. After warm-starting, each new variable's two
        shards, stacked along axis 0, must equal the stacked previous shards.
        """

        def two_shards(shape, dtype):
            del shape, dtype  # Unused.
            return [2, 1]

        _, prev_vals = self._create_prev_run_multiple_vars(
            var_names=["fruit_weights", "other_weights"],
            shapes=[[4, 1], [4, 1]],
            initializers=[
                [[0.5], [1.], [1.5], [2.]],
                [[.05], [.1], [.15], [.2]],
            ],
            partitioners=[two_shards, two_shards])

        graph = ops.Graph()
        with graph.as_default():
            with self.test_session(graph=graph) as sess:
                fruit = variable_scope.get_variable(
                    "fruit_weights",
                    shape=[4, 1],
                    initializer=[[0.], [0.], [0.], [0.]],
                    partitioner=two_shards)
                other = variable_scope.get_variable(
                    "other_weights",
                    shape=[4, 1],
                    initializer=[[0.], [0.], [0.], [0.]],
                    partitioner=two_shards)
                ws_util._warmstart_var([fruit, other], self.get_temp_dir())
                sess.run(variables.global_variables_initializer())
                restored = []
                for pvar in (fruit, other):
                    parts = pvar._get_variable_list()
                    restored.append(
                        np.concatenate(
                            [parts[0].eval(sess), parts[1].eval(sess)],
                            axis=0))
                self.assertAllEqual(np.concatenate(prev_vals[0], axis=0),
                                    restored[0])
                self.assertAllEqual(np.concatenate(prev_vals[1], axis=0),
                                    restored[1])