    def build_inference(self, handle, use_moving_average=False):
        """Builds an inference pipeline.

        This always uses the whole pipeline.

        Args:
          handle: Handle tensor for the ComputeSession.
          use_moving_average: Whether or not to read from the moving
            average variables instead of the true parameters. Note: it is not
            possible to make gradient updates when this is True.

        Returns:
          handle: Handle after annotation.
        """
        self.read_from_avg = use_moving_average
        network_states = {}

        for comp in self.components:
            network_states[comp.name] = component.NetworkState()
            handle = dragnn_ops.init_component_data(
                handle,
                beam_size=comp.inference_beam_size,
                component=comp.name)
            master_state = component.MasterState(
                handle, dragnn_ops.batch_size(handle, component=comp.name))
            with tf.control_dependencies([handle]):
                handle = comp.build_greedy_inference(master_state,
                                                     network_states)
            handle = dragnn_ops.write_annotations(handle, component=comp.name)

        self.read_from_avg = False
        return handle
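
A minimal, hypothetical usage sketch (none of this is mandated by the method itself): `builder` stands for the object defining build_inference, `master_spec` for its serialized MasterSpec, and `corpus` for a batch of serialized documents; the session-setup calls mirror the test code later in this page.

# Hedged sketch: `builder`, `master_spec`, and `corpus` are assumed to
# exist, following the setup used in the tests further below.
with tf.Graph().as_default():
  handle = dragnn_ops.get_session(
      container='test',
      master_spec=master_spec.SerializeToString(),
      grid_point='')
  handle = dragnn_ops.attach_data_reader(handle, corpus)
  annotated = builder.build_inference(handle)
  with tf.Session() as sess:
    sess.run(annotated)  # runs every component and writes annotations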
Example #2
  def build_inference(self, handle, use_moving_average=False,
                      clear_existing_annotations=False):
    """Builds an inference pipeline.

    This always uses the whole pipeline.

    Args:
      handle: Handle tensor for the ComputeSession.
      use_moving_average: Whether or not to read from the moving
        average variables instead of the true parameters. Note: it is not
        possible to make gradient updates when this is True.
      clear_existing_annotations: Whether or not existing annotations
        should be cleared when processing a new batch.

    Returns:
      handle: Handle after annotation.
    """
    self.read_from_avg = use_moving_average
    network_states = {}

    for comp in self.components:
      network_states[comp.name] = component.NetworkState()
      handle = dragnn_ops.init_component_data(
          handle, component=comp.name,
          clear_existing_annotations=clear_existing_annotations)
      master_state = component.MasterState(handle,
                                           dragnn_ops.batch_size(
                                               handle, component=comp.name))
      with tf.control_dependencies([handle]):
        handle = comp.build_inference(master_state, network_states)
      handle = dragnn_ops.write_annotations(handle, component=comp.name)

    self.read_from_avg = False
    return handle
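
The added flag matters when input documents already carry annotations from an earlier pass. A one-line, hypothetical sketch reusing `builder` and `handle` from the sketch above:

# Hypothetical: drop stale annotations before re-annotating a batch.
handle = builder.build_inference(handle, clear_existing_annotations=True)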
Example #3
  def build_inference(self,
                      handle,
                      use_moving_average=False,
                      build_runtime_graph=False):
    """Builds an inference pipeline.

    This always uses the whole pipeline.

    Args:
      handle: Handle tensor for the ComputeSession.
      use_moving_average: Whether or not to read from the moving
        average variables instead of the true parameters. Note: it is not
        possible to make gradient updates when this is True.
      build_runtime_graph: Whether to build a graph for use by the runtime.

    Returns:
      handle: Handle after annotation.
    """
    self.read_from_avg = use_moving_average
    self.build_runtime_graph = build_runtime_graph
    network_states = {}

    for comp in self.components:
      network_states[comp.name] = component.NetworkState()
      handle = dragnn_ops.init_component_data(
          handle, beam_size=comp.inference_beam_size, component=comp.name)
      if build_runtime_graph:
        batch_size = 1  # runtime uses singleton batches
      else:
        batch_size = dragnn_ops.batch_size(handle, component=comp.name)
      master_state = component.MasterState(handle, batch_size)
      with tf.control_dependencies([handle]):
        handle = comp.build_greedy_inference(master_state, network_states)
      handle = dragnn_ops.write_annotations(handle, component=comp.name)

    self.read_from_avg = False
    self.build_runtime_graph = False
    return handle
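
The use_moving_average flag only toggles self.read_from_avg; the components are presumably responsible for substituting averaged variables for the raw parameters. A generic TensorFlow sketch of that substitution via tf.train.ExponentialMovingAverage, as an illustration of the concept rather than the DRAGNN implementation:

# Illustration only: maintain and read averaged shadow variables,
# the mechanism that read_from_avg conceptually switches on.
var = tf.Variable(1.0, name='weight')
ema = tf.train.ExponentialMovingAverage(decay=0.999)
maintain_avg_op = ema.apply([var])  # run alongside training updates
averaged_var = ema.average(var)     # read this at inference time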
Example #4
  def build_training(self,
                     handle,
                     component_weights=None,
                     unroll_using_oracle=None,
                     max_index=-1):
    """Builds a training pipeline.

    Args:
      handle: Handle tensor for the ComputeSession.
      component_weights: If set, this is a list of relative weights
        each component's cost should get in the pipeline. Defaults to 1.0 for
        each component.
      unroll_using_oracle: If set, this is a list of booleans indicating
        whether or not to use the gold decodings for each component. Defaults
        to True for each component.
      max_index: Training will use only the first max_index components,
        or -1 for all components.

    Returns:
      handle: Handle to the ComputeSession, conditioned on completing the
        training step.
      outputs: a dictionary of useful training tensors.

    Raises:
      IndexError: if max_index is positive but out of bounds.
    """
    self.read_from_avg = False
    if max_index < 0:
      max_index = len(self.components)
    else:
      if not 0 < max_index <= len(self.components):
        raise IndexError('Invalid max_index {} for components {}; handle {}'.
                         format(max_index, self.component_names, handle.name))

    # By default, we train every component supervised.
    if not component_weights:
      component_weights = [1] * max_index
    if not unroll_using_oracle:
      unroll_using_oracle = [True] * max_index

    component_weights = component_weights[:max_index]
    total_weight = float(sum(component_weights))
    component_weights = [w / total_weight for w in component_weights]

    unroll_using_oracle = unroll_using_oracle[:max_index]

    logging.info('Creating training target:')
    logging.info('\tWeights: %s', component_weights)
    logging.info('\tOracle: %s', unroll_using_oracle)

    metrics_list = []
    cost = tf.constant(0.)
    effective_batch = tf.constant(0)

    avg_ops = []
    params_to_train = []

    network_states = {}
    for component_index in range(0, max_index):
      comp = self.components[component_index]
      network_states[comp.name] = component.NetworkState()

      logging.info('Initializing data for component "%s"', comp.name)
      handle = dragnn_ops.init_component_data(handle, component=comp.name,
                                              clear_existing_annotations=False)
      # TODO(googleuser): Phase out component.MasterState.
      master_state = component.MasterState(handle,
                                           dragnn_ops.batch_size(
                                               handle, component=comp.name))
      with tf.control_dependencies([handle, cost]):
        args = (master_state, network_states)
        if unroll_using_oracle[component_index]:
          handle, component_cost, correct, total = comp.build_training(
              *args)
        else:
          handle = comp.build_inference(*args, during_training=True)
          component_cost = tf.constant(0.)
          correct, total = tf.constant(0), tf.constant(0)

        weighted_component_cost = tf.multiply(
            component_cost,
            tf.constant(float(component_weights[component_index])),
            name='weighted_component_cost')

        cost += weighted_component_cost
        effective_batch += total
        metrics_list += [[total], [correct]]

        with tf.control_dependencies([comp.advance_counters(total)]):
          cost = tf.identity(cost)

        # Keep track of which parameters will be trained, and any moving
        # average updates to apply for these parameters.
        params_to_train += comp.network.params
        if self.hyperparams.use_moving_average:
          avg_ops += comp.avg_ops

    # Concatenate evaluation results
    metrics = tf.concat(metrics_list, 0)

    # Now that the cost is computed:
    # 1. compute the gradients,
    # 2. add an optimizer to update the parameters using the gradients,
    # 3. make the ComputeSession handle depend on the optimizer.
    grads_and_vars = self.optimizer.compute_gradients(
        cost, var_list=params_to_train)
    clipped_gradients = [(self._clip_gradients(g), v)
                         for g, v in grads_and_vars]
    minimize_op = self.optimizer.apply_gradients(
        clipped_gradients, global_step=self.master_vars['step'])

    if self.hyperparams.use_moving_average:
      with tf.control_dependencies([minimize_op]):
        minimize_op = tf.group(*avg_ops)

    # Make sure all the side-effectful minimization ops finish before
    # proceeding.
    with tf.control_dependencies([minimize_op]):
      handle = tf.identity(handle)

    # Restore the default so subsequent builds don't read from the moving average.
    self.read_from_avg = False

    # Returns named access to common outputs.
    outputs = {
        'cost': cost,
        'batch': effective_batch,
        'metrics': metrics,
    }
    return handle, outputs
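
Per the docstring, the returned handle is conditioned on completing a training step, so fetching it drives training. A hedged driver sketch; `builder`, the initial `handle`, and the step count are assumptions:

# Hypothetical training loop: each sess.run of the handle performs one
# step, because the handle depends on the minimize op built above.
handle, outputs = builder.build_training(handle)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for step in range(100):
    _, cost = sess.run([handle, outputs['cost']])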
Example #5
  def build_training(self,
                     handle,
                     compute_gradients=True,
                     use_moving_average=False,
                     advance_counters=True,
                     component_weights=None,
                     unroll_using_oracle=None,
                     max_index=-1):
    """Builds a training pipeline.

    Args:
      handle: Handle tensor for the ComputeSession.
      compute_gradients: Whether to generate gradients and an optimizer op.
        When False, build_training will return a 'dry run' training op,
        used normally only for oracle tracing.
      use_moving_average: Whether or not to read from the moving
        average variables instead of the true parameters. Note: it is not
        possible to make gradient updates when this is True.
      advance_counters: Whether or not this loop should increment the
        per-component step counters.
      component_weights: If set, this is a list of relative weights
        each component's cost should get in the pipeline. Defaults to 1.0 for
        each component.
      unroll_using_oracle: If set, this is a list of booleans indicating
        whether or not to use the gold decodings for each component. Defaults
        to True for each component.
      max_index: Training will use only the first max_index components,
        or -1 for all components.

    Returns:
      handle: Handle to the ComputeSession, conditioned on completing the
        training step.
      outputs: a dictionary of useful training tensors.

    Raises:
      IndexError: if max_index is positive but out of bounds.
    """
    check.IsFalse(compute_gradients and use_moving_average,
                  'It is not possible to make gradient updates when reading '
                  'from the moving average variables.')

    self.read_from_avg = use_moving_average
    if max_index < 0:
      max_index = len(self.components)
    else:
      if not 0 < max_index <= len(self.components):
        raise IndexError(
            'Invalid max_index {} for components {}; handle {}'.format(
                max_index, self.component_names, handle.name))

    # By default, we train every component supervised.
    if not component_weights:
      component_weights = [1] * max_index
    if not unroll_using_oracle:
      unroll_using_oracle = [True] * max_index

    if max_index > len(unroll_using_oracle):
      raise IndexError(('Invalid max_index {} for unroll_using_oracle {}; '
                        'handle {}').format(max_index, unroll_using_oracle,
                                            handle.name))

    component_weights = component_weights[:max_index]
    total_weight = float(sum(component_weights))
    component_weights = [w / total_weight for w in component_weights]

    unroll_using_oracle = unroll_using_oracle[:max_index]

    logging.info('Creating training target:')
    logging.info('\tWeights: %s', component_weights)
    logging.info('\tOracle: %s', unroll_using_oracle)

    metrics_list = []
    cost = tf.constant(0.)
    effective_batch = tf.constant(0)

    avg_ops = []
    params_to_train = []

    network_states = {}
    for component_index in range(0, max_index):
      comp = self.components[component_index]
      network_states[comp.name] = component.NetworkState()

      logging.info('Initializing data for component "%s"', comp.name)
      handle = dragnn_ops.init_component_data(
          handle, beam_size=comp.training_beam_size, component=comp.name)
      # TODO(googleuser): Phase out component.MasterState.
      master_state = component.MasterState(handle,
                                           dragnn_ops.batch_size(
                                               handle, component=comp.name))
      with tf.control_dependencies([handle, cost]):
        args = (master_state, network_states)
        if unroll_using_oracle[component_index]:

          handle, component_cost, component_correct, component_total = (
              tf.cond(comp.training_beam_size > 1,
                      lambda: comp.build_structured_training(*args),
                      lambda: comp.build_greedy_training(*args)))

        else:
          handle = comp.build_greedy_inference(*args, during_training=True)
          component_cost = tf.constant(0.)
          component_correct, component_total = tf.constant(0), tf.constant(0)

        weighted_component_cost = tf.multiply(
            component_cost,
            tf.constant(float(component_weights[component_index])),
            name='weighted_component_cost')

        cost += weighted_component_cost
        effective_batch += component_total
        metrics_list += [[component_total], [component_correct]]

        if advance_counters:
          with tf.control_dependencies(
              [comp.advance_counters(component_total)]):
            cost = tf.identity(cost)

        # Keep track of which parameters will be trained, and any moving
        # average updates to apply for these parameters.
        params_to_train += comp.network.params
        if self.hyperparams.use_moving_average:
          avg_ops += comp.avg_ops

    # Concatenate evaluation results
    metrics = tf.concat(metrics_list, 0)

    # If gradient computation is requested, then:
    # 1. compute the gradients,
    # 2. add an optimizer to update the parameters using the gradients,
    # 3. make the ComputeSession handle depend on the optimizer.
    gradient_norm = tf.constant(0.)
    if compute_gradients:
      logging.info('Creating train op with %d variables:\n\t%s',
                   len(params_to_train),
                   '\n\t'.join([x.name for x in params_to_train]))

      grads_and_vars = self.optimizer.compute_gradients(
          cost, var_list=params_to_train)
      clipped_gradients = [
          (self._clip_gradients(g), v) for g, v in grads_and_vars
      ]
      gradient_norm = tf.global_norm(list(zip(*clipped_gradients))[0])

      minimize_op = self.optimizer.apply_gradients(
          clipped_gradients, global_step=self.master_vars['step'])

      if self.hyperparams.use_moving_average:
        with tf.control_dependencies([minimize_op]):
          minimize_op = tf.group(*avg_ops)

      # Make sure all the side-effectful minimization ops finish before
      # proceeding.
      with tf.control_dependencies([minimize_op]):
        handle = tf.identity(handle)

    # Restore the default so subsequent builds don't read from the moving average.
    self.read_from_avg = False

    cost = tf.check_numerics(cost, message='Cost is not finite.')

    # Returns named access to common outputs.
    outputs = {
        'cost': cost,
        'gradient_norm': gradient_norm,
        'batch': effective_batch,
        'metrics': metrics,
    }
    return handle, outputs
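
Note the lambda wrappers in the tf.cond above: in graph-mode TensorFlow both branch functions are called once at graph-construction time, so both the structured and greedy subgraphs are built, and the predicate picks which one executes at run time. A minimal standalone sketch of the same pattern:

# Minimal tf.cond sketch: both lambdas are traced into the graph,
# but only the selected branch runs in the session.
beam_size = tf.constant(4)
result = tf.cond(beam_size > 1,
                 lambda: tf.constant('structured'),
                 lambda: tf.constant('greedy'))
with tf.Session() as sess:
  print(sess.run(result))  # b'structured'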
Example #6
    def build_training(self,
                       handle,
                       compute_gradients=True,
                       use_moving_average=False,
                       advance_counters=True,
                       component_weights=None,
                       unroll_using_oracle=None,
                       max_index=-1):
        """Builds a training pipeline.

    Args:
      handle: Handle tensor for the ComputeSession.
      compute_gradients: Whether to generate gradients and an optimizer op.
        When False, build_training will return a 'dry run' training op,
        used normally only for oracle tracing.
      use_moving_average: Whether or not to read from the moving
        average variables instead of the true parameters. Note: it is not
        possible to make gradient updates when this is True.
      advance_counters: Whether or not this loop should increment the
        per-component step counters.
      component_weights: If set, this is a list of relative weights
        each component's cost should get in the pipeline. Defaults to 1.0 for
        each component.
      unroll_using_oracle: If set, this is a list of booleans indicating
        whether or not to use the gold decodings for each component. Defaults
        to True for each component.
      max_index: Training will use only the first max_index components,
        or -1 for all components.

    Returns:
      handle: to the ComputeSession, conditioned on completing training step.
      outputs: a dictionary of useful training tensors.

    Raises:
      IndexError: if max_index is positive but out of bounds.
    """
        check.IsFalse(
            compute_gradients and use_moving_average,
            'It is not possible to make gradient updates when reading '
            'from the moving average variables.')

        self.read_from_avg = use_moving_average
        if max_index < 0:
            max_index = len(self.components)
        else:
            if not 0 < max_index <= len(self.components):
                raise IndexError(
                    'Invalid max_index {} for components {}; handle {}'.format(
                        max_index, self.component_names, handle.name))

        # By default, we train every component supervised.
        if not component_weights:
            component_weights = [1] * max_index
        if not unroll_using_oracle:
            unroll_using_oracle = [True] * max_index

        if max_index > len(unroll_using_oracle):
            raise IndexError(
                ('Invalid max_index {} for unroll_using_oracle {}; '
                 'handle {}').format(max_index, unroll_using_oracle,
                                     handle.name))

        component_weights = component_weights[:max_index]
        total_weight = float(sum(component_weights))
        component_weights = [w / total_weight for w in component_weights]

        unroll_using_oracle = unroll_using_oracle[:max_index]

        logging.info('Creating training target:')
        logging.info('\tWeights: %s', component_weights)
        logging.info('\tOracle: %s', unroll_using_oracle)

        metrics_list = []
        cost = tf.constant(0.)
        effective_batch = tf.constant(0)

        avg_ops = []
        params_to_train = []

        network_states = {}
        for component_index in range(0, max_index):
            comp = self.components[component_index]
            network_states[comp.name] = component.NetworkState()

            logging.info('Initializing data for component "%s"', comp.name)
            handle = dragnn_ops.init_component_data(
                handle, beam_size=comp.training_beam_size, component=comp.name)
            # TODO(googleuser): Phase out component.MasterState.
            master_state = component.MasterState(
                handle, dragnn_ops.batch_size(handle, component=comp.name))
            with tf.control_dependencies([handle, cost]):
                args = (master_state, network_states)
                if unroll_using_oracle[component_index]:

                    handle, component_cost, component_correct, component_total = (
                        tf.cond(comp.training_beam_size > 1,
                                lambda: comp.build_structured_training(*args),
                                lambda: comp.build_greedy_training(*args)))

                else:
                    handle = comp.build_greedy_inference(*args,
                                                         during_training=True)
                    component_cost = tf.constant(0.)
                    component_correct, component_total = tf.constant(
                        0), tf.constant(0)

                weighted_component_cost = tf.multiply(
                    component_cost,
                    tf.constant(float(component_weights[component_index])),
                    name='weighted_component_cost')

                cost += weighted_component_cost
                effective_batch += component_total
                metrics_list += [[component_total], [component_correct]]

                if advance_counters:
                    with tf.control_dependencies(
                        [comp.advance_counters(component_total)]):
                        cost = tf.identity(cost)

                # Keep track of which parameters will be trained, and any moving
                # average updates to apply for these parameters.
                params_to_train += comp.network.params
                if self.hyperparams.use_moving_average:
                    avg_ops += comp.avg_ops

        # Concatenate evaluation results
        metrics = tf.concat(metrics_list, 0)

        # If gradient computation is requested, then:
        # 1. compute the gradients,
        # 2. add an optimizer to update the parameters using the gradients,
        # 3. make the ComputeSession handle depend on the optimizer.
        gradient_norm = tf.constant(0.)
        if compute_gradients:
            logging.info('Creating train op with %d variables:\n\t%s',
                         len(params_to_train),
                         '\n\t'.join([x.name for x in params_to_train]))

            grads_and_vars = self.optimizer.compute_gradients(
                cost, var_list=params_to_train)
            clipped_gradients = [(self._clip_gradients(g), v)
                                 for g, v in grads_and_vars]
            gradient_norm = tf.global_norm(list(zip(*clipped_gradients))[0])

            minimize_op = self.optimizer.apply_gradients(
                clipped_gradients, global_step=self.master_vars['step'])

            if self.hyperparams.use_moving_average:
                with tf.control_dependencies([minimize_op]):
                    minimize_op = tf.group(*avg_ops)

            # Make sure all the side-effectful minimization ops finish before
            # proceeding.
            with tf.control_dependencies([minimize_op]):
                handle = tf.identity(handle)

        # Restore the default so subsequent builds don't read from the moving average.
        self.read_from_avg = False

        cost = tf.check_numerics(cost, message='Cost is not finite.')

        # Returns named access to common outputs.
        outputs = {
            'cost': cost,
            'gradient_norm': gradient_norm,
            'batch': effective_batch,
            'metrics': metrics,
        }
        return handle, outputs
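
The compute -> clip -> apply sequence above is plain TensorFlow. A self-contained sketch with tf.clip_by_norm standing in for the builder's _clip_gradients helper (the clip threshold is arbitrary):

# Standalone sketch of the gradient pipeline used above.
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
w = tf.Variable([3.0, 4.0])
loss = tf.reduce_sum(tf.square(w))
grads_and_vars = opt.compute_gradients(loss, var_list=[w])
clipped = [(tf.clip_by_norm(g, 1.0), v) for g, v in grads_and_vars]
gradient_norm = tf.global_norm([g for g, _ in clipped])
train_op = opt.apply_gradients(clipped)

Example #7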
  def testBulkFeatureIdExtractorExtractFocusWithOffset(self):
    path = os.path.join(tf.test.get_temp_dir(), 'label-map')
    with open(path, 'w') as label_map_file:
      label_map_file.write('0\n')

    master_spec = spec_pb2.MasterSpec()
    text_format.Parse("""
        component {
          name: "test"
          transition_system {
            registered_name: "shift-only"
          }
          resource {
            name: "label-map"
            part {
              file_pattern: "%s"
              file_format: "text"
            }
          }
          network_unit {
            registered_name: "ExportFixedFeaturesNetwork"
          }
          backend {
            registered_name: "SyntaxNetComponent"
          }
          fixed_feature {
            name: "focus1" embedding_dim: -1 size: 1 fml: "input.focus"
            predicate_map: "none"
          }
          fixed_feature {
            name: "focus2" embedding_dim: -1 size: 1 fml: "input(1).focus"
            predicate_map: "none"
          }
          fixed_feature {
            name: "focus3" embedding_dim: -1 size: 1 fml: "input(2).focus"
            predicate_map: "none"
          }
        }
        """ % path, master_spec)

    with tf.Graph().as_default():
      corpus = _create_fake_corpus()
      corpus = tf.constant(corpus, shape=[len(corpus)])
      handle = dragnn_ops.get_session(
          container='test',
          master_spec=master_spec.SerializeToString(),
          grid_point='')
      handle = dragnn_ops.attach_data_reader(handle, corpus)
      handle = dragnn_ops.init_component_data(
          handle, beam_size=1, component='test')
      batch_size = dragnn_ops.batch_size(handle, component='test')
      master_state = component.MasterState(handle, batch_size)

      extractor = bulk_component.BulkFeatureIdExtractorComponentBuilder(
          self.master, master_spec.component[0])
      network_state = component.NetworkState()
      self.network_states['test'] = network_state
      handle = extractor.build_greedy_inference(master_state,
                                                self.network_states)
      focus1 = network_state.activations['focus1'].bulk_tensor
      focus2 = network_state.activations['focus2'].bulk_tensor
      focus3 = network_state.activations['focus3'].bulk_tensor

      with self.test_session() as sess:
        focus1, focus2, focus3 = sess.run([focus1, focus2, focus3])
        tf.logging.info('focus1=\n%s', focus1)
        tf.logging.info('focus2=\n%s', focus2)
        tf.logging.info('focus3=\n%s', focus3)

        self.assertAllEqual(
            focus1,
            [[0], [-1], [-1], [-1],
             [0], [1], [-1], [-1],
             [0], [1], [2], [-1],
             [0], [1], [2], [3]])

        self.assertAllEqual(
            focus2,
            [[-1], [-1], [-1], [-1],
             [1], [-1], [-1], [-1],
             [1], [2], [-1], [-1],
             [1], [2], [3], [-1]])

        self.assertAllEqual(
            focus3,
            [[-1], [-1], [-1], [-1],
             [-1], [-1], [-1], [-1],
             [2], [-1], [-1], [-1],
             [2], [3], [-1], [-1]])
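
The expected matrices follow a simple rule: for a batch of sentences with lengths 1 through 4, each padded to four steps, input(k).focus at step i is i + k when both tokens exist, and -1 otherwise. A short sketch reproducing them (the helper name is ours, not part of the API):

# Hypothetical helper that regenerates the expected focus matrices.
def expected_focus(k, lengths=(1, 2, 3, 4), max_steps=4):
  rows = []
  for n in lengths:
    for i in range(max_steps):
      rows.append([i + k] if i < n and i + k < n else [-1])
  return rows

assert expected_focus(0) == [
    [0], [-1], [-1], [-1],
    [0], [1], [-1], [-1],
    [0], [1], [2], [-1],
    [0], [1], [2], [3]]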
Example #9
    def build_training(self,
                       handle,
                       component_weights=None,
                       unroll_using_oracle=None,
                       max_index=-1):
        """Builds a training pipeline.

    Args:
      handle: Handle tensor for the ComputeSession.
      component_weights: If set, this is a list of relative weights
        each component's cost should get in the pipeline. Defaults to 1.0 for
        each component.
      unroll_using_oracle: If set, this is a list of booleans indicating
        whether or not to use the gold decodings for each component. Defaults
        to True for each component.
      max_index: Training will use only the first max_index components,
        or -1 for all components.

    Returns:
      handle: to the ComputeSession, conditioned on completing training step.
      outputs: a dictionary of useful training tensors.

    Raises:
      IndexError: if max_index is positive but out of bounds.
    """
        self.read_from_avg = False
        if max_index < 0:
            max_index = len(self.components)
        else:
            if not 0 < max_index <= len(self.components):
                raise IndexError(
                    'Invalid max_index {} for components {}; handle {}'.format(
                        max_index, self.component_names, handle.name))

        # By default, we train every component supervised.
        if not component_weights:
            component_weights = [1] * max_index
        if not unroll_using_oracle:
            unroll_using_oracle = [True] * max_index

        component_weights = component_weights[:max_index]
        total_weight = float(sum(component_weights))
        component_weights = [w / total_weight for w in component_weights]

        unroll_using_oracle = unroll_using_oracle[:max_index]

        logging.info('Creating training target:')
        logging.info('\tWeights: %s', component_weights)
        logging.info('\tOracle: %s', unroll_using_oracle)

        metrics_list = []
        cost = tf.constant(0.)
        effective_batch = tf.constant(0)

        avg_ops = []
        params_to_train = []

        network_states = {}
        for component_index in range(0, max_index):
            comp = self.components[component_index]
            network_states[comp.name] = component.NetworkState()

            logging.info('Initializing data for component "%s"', comp.name)
            handle = dragnn_ops.init_component_data(
                handle, component=comp.name, clear_existing_annotations=False)
            # TODO(googleuser): Phase out component.MasterState.
            master_state = component.MasterState(
                handle, dragnn_ops.batch_size(handle, component=comp.name))
            with tf.control_dependencies([handle, cost]):
                args = (master_state, network_states)
                if unroll_using_oracle[component_index]:
                    handle, component_cost, correct, total = comp.build_training(
                        *args)
                else:
                    handle = comp.build_inference(*args, during_training=True)
                    component_cost = tf.constant(0.)
                    correct, total = tf.constant(0), tf.constant(0)

                weighted_component_cost = tf.multiply(
                    component_cost,
                    tf.constant(float(component_weights[component_index])),
                    name='weighted_component_cost')

                cost += weighted_component_cost
                effective_batch += total
                metrics_list += [[total], [correct]]

                with tf.control_dependencies([comp.advance_counters(total)]):
                    cost = tf.identity(cost)

                # Keep track of which parameters will be trained, and any moving
                # average updates to apply for these parameters.
                params_to_train += comp.network.params
                if self.hyperparams.use_moving_average:
                    avg_ops += comp.avg_ops

        # Concatenate evaluation results
        metrics = tf.concat(metrics_list, 0)

        # Now that the cost is computed:
        # 1. compute the gradients,
        # 2. add an optimizer to update the parameters using the gradients,
        # 3. make the ComputeSession handle depend on the optimizer.
        grads_and_vars = self.optimizer.compute_gradients(
            cost, var_list=params_to_train)
        clipped_gradients = [(self._clip_gradients(g), v)
                             for g, v in grads_and_vars]
        minimize_op = self.optimizer.apply_gradients(
            clipped_gradients, global_step=self.master_vars['step'])

        if self.hyperparams.use_moving_average:
            with tf.control_dependencies([minimize_op]):
                minimize_op = tf.group(*avg_ops)

        # Make sure all the side-effectful minimization ops finish before
        # proceeding.
        with tf.control_dependencies([minimize_op]):
            handle = tf.identity(handle)

        # Restore the default so subsequent builds don't read from the moving average.
        self.read_from_avg = False

        # Returns named access to common outputs.
        outputs = {
            'cost': cost,
            'batch': effective_batch,
            'metrics': metrics,
        }
        return handle, outputs
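
As a worked example of the weight normalization shared by these builders: relative weights are divided by their sum, so unequal inputs become fractions of 1.0.

# Worked example of the component-weight normalization above.
component_weights = [2.0, 1.0, 1.0]
total_weight = float(sum(component_weights))
print([w / total_weight for w in component_weights])  # [0.5, 0.25, 0.25]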