Example #1
0
  def create_variables_and_ops(self, table, variable_name, num_hosts,
                               table_config, table_variables,
                               load_parameters_ops, retrieve_parameters_ops):
    """Creates Adam slot variables and per-host load/retrieve ops for `table`.

    Two zero-initialized partitioned slot-variable sets, `m` and `v`, are
    created and recorded in `self._table_to_m_variables_dict` and
    `self._table_to_v_variables_dict` under the key `table`. Then, for each
    host shard, a load op and a retrieve op are built colocated with that
    shard's table variable and appended to the caller-supplied lists.

    Args:
      table: Table name; used as the dict key and as `table_name` for the
        TPU load/retrieve ops.
      variable_name: Name prefix; slot variables are named
        `<variable_name>/Adam/m` and `<variable_name>/Adam/v`.
      num_hosts: Number of partitions, also passed as `num_shards`.
      table_config: Object read for `vocabulary_size` and `dimension`.
      table_variables: Per-host table variables, iterated alongside the slots.
      load_parameters_ops: List that receives one load op per host (mutated
        in place).
      retrieve_parameters_ops: List that receives one grouped retrieve op
        per host (mutated in place).
    """
    optimizer_name = 'Adam'

    def make_slot(slot_name):
      # One zero-initialized, num_hosts-partitioned slot variable set.
      slot_initializer = init_ops.zeros_initializer()
      return _create_partitioned_variables(
          name='%s/%s/%s' % (variable_name, optimizer_name, slot_name),
          num_hosts=num_hosts,
          vocabulary_size=table_config.vocabulary_size,
          embedding_dimension=table_config.dimension,
          collections=[ops.GraphKeys.GLOBAL_VARIABLES],
          initializer=slot_initializer)

    m_variables = make_slot('m')
    v_variables = make_slot('v')
    self._table_to_m_variables_dict[table] = m_variables
    self._table_to_v_variables_dict[table] = v_variables

    # zip stops at the shortest of range(num_hosts) and the three variable
    # lists, so at most num_hosts shards are processed.
    per_host = zip(range(num_hosts), table_variables, m_variables, v_variables)
    for shard_id, table_var, m_var, v_var in per_host:
      # Place both ops next to this host's table variable.
      with ops.colocate_with(table_var):
        load_op = tpu_ops.load_tpu_embedding_adam_parameters(
            parameters=table_var,
            momenta=m_var,
            velocities=v_var,
            table_name=table,
            num_shards=num_hosts,
            shard_id=shard_id)
        table_value, m_value, v_value = (
            tpu_ops.retrieve_tpu_embedding_adam_parameters(
                table_name=table, num_shards=num_hosts, shard_id=shard_id))
        # Copy the retrieved values back into the host-side variables.
        retrieve_op = control_flow_ops.group(
            state_ops.assign(table_var, table_value),
            state_ops.assign(m_var, m_value),
            state_ops.assign(v_var, v_value))
      load_parameters_ops.append(load_op)
      retrieve_parameters_ops.append(retrieve_op)
Example #2
0
  def create_variables_and_ops(self, table, variable_name, num_hosts,
                               table_config, table_variables,
                               load_parameters_ops, retrieve_parameters_ops):
    """Creates Adam slot variables and per-host load/retrieve ops for `table`.

    Two zero-initialized partitioned slot-variable sets, `m` and `v`, are
    created and recorded in `self._table_to_m_variables_dict` and
    `self._table_to_v_variables_dict` under the key `table`. Then, for each
    host shard, a load op and a retrieve op are built colocated with that
    shard's table variable and appended to the caller-supplied lists.

    Args:
      table: Table name; used as the dict key and as `table_name` for the
        TPU load/retrieve ops.
      variable_name: Name prefix; slot variables are named
        `<variable_name>/Adam/m` and `<variable_name>/Adam/v`.
      num_hosts: Number of partitions, also passed as `num_shards`.
      table_config: Object read for `vocabulary_size` and `dimension`.
      table_variables: Per-host table variables, iterated alongside the slots.
      load_parameters_ops: List that receives one load op per host (mutated
        in place).
      retrieve_parameters_ops: List that receives one grouped retrieve op
        per host (mutated in place).
    """
    optimizer_name = 'Adam'

    def make_slot(slot_name):
      # One zero-initialized, num_hosts-partitioned slot variable set.
      slot_initializer = init_ops.zeros_initializer()
      return _create_partitioned_variables(
          name='%s/%s/%s' % (variable_name, optimizer_name, slot_name),
          num_hosts=num_hosts,
          vocabulary_size=table_config.vocabulary_size,
          embedding_dimension=table_config.dimension,
          collections=[ops.GraphKeys.GLOBAL_VARIABLES],
          initializer=slot_initializer)

    m_variables = make_slot('m')
    v_variables = make_slot('v')
    self._table_to_m_variables_dict[table] = m_variables
    self._table_to_v_variables_dict[table] = v_variables

    # zip stops at the shortest of range(num_hosts) and the three variable
    # lists, so at most num_hosts shards are processed.
    per_host = zip(range(num_hosts), table_variables, m_variables, v_variables)
    for shard_id, table_var, m_var, v_var in per_host:
      # Place both ops next to this host's table variable.
      with ops.colocate_with(table_var):
        load_op = tpu_ops.load_tpu_embedding_adam_parameters(
            parameters=table_var,
            momenta=m_var,
            velocities=v_var,
            table_name=table,
            num_shards=num_hosts,
            shard_id=shard_id)
        table_value, m_value, v_value = (
            tpu_ops.retrieve_tpu_embedding_adam_parameters(
                table_name=table, num_shards=num_hosts, shard_id=shard_id))
        # Copy the retrieved values back into the host-side variables.
        retrieve_op = control_flow_ops.group(
            state_ops.assign(table_var, table_value),
            state_ops.assign(m_var, m_value),
            state_ops.assign(v_var, v_value))
      load_parameters_ops.append(load_op)
      retrieve_parameters_ops.append(retrieve_op)
        def retrieve_ops_fn():
            """Returns the retrieve ops for Adam embedding tables.

            Uses `table`, `num_hosts`, `table_variables`, `m_variables` and
            `v_variables` from the enclosing scope.

            Returns:
              A list of ops to retrieve embedding and slot variables from TPU
              to CPU, one grouped op per host shard.
            """
            ops_per_host = []
            # zip stops at the shortest input, so at most num_hosts shards.
            shards = zip(range(num_hosts), table_variables, m_variables,
                         v_variables)
            for shard_id, table_var, m_var, v_var in shards:
                # Build the retrieve and assign ops next to the table variable.
                with ops.colocate_with(table_var):
                    table_value, m_value, v_value = (
                        tpu_ops.retrieve_tpu_embedding_adam_parameters(
                            table_name=table,
                            num_shards=num_hosts,
                            shard_id=shard_id))
                    assign_ops = [
                        state_ops.assign(table_var, table_value),
                        state_ops.assign(m_var, m_value),
                        state_ops.assign(v_var, v_value),
                    ]
                    grouped_op = control_flow_ops.group(*assign_ops)
                ops_per_host.append(grouped_op)
            return ops_per_host