コード例 #1
0
        def load_ops_fn():
            """Returns the load ops for AdaGrad embedding tables.

            One `load_tpu_embedding_adagrad_parameters` op is created for each
            (host id, table variable, accumulator variable) triple produced by
            zipping `range(num_hosts)` with `table_variables` and
            `accumulator_variables`. Those names, together with `table`, are
            not defined here; they are read from the enclosing scope.

            Returns:
              A list of ops to load embedding and slot variables from CPU to
              TPU.
            """
            shards = zip(
                range(num_hosts), table_variables, accumulator_variables)
            shard_load_ops = []
            for shard_index, params, accumulators in shards:
                # Build the op inside the colocation scope of the table
                # variable it reads; the list append happens after the scope
                # is exited.
                with ops.colocate_with(params):
                    shard_op = tpu_ops.load_tpu_embedding_adagrad_parameters(
                        parameters=params,
                        accumulators=accumulators,
                        table_name=table,
                        num_shards=num_hosts,
                        shard_id=shard_index)
                shard_load_ops.append(shard_op)
            return shard_load_ops
コード例 #2
0
    def load_ops_fn():
      """Returns the load ops for AdaGrad embedding tables.

      One `load_tpu_embedding_adagrad_parameters` op is created for each
      (host id, table variable, accumulator variable) triple produced by
      zipping `range(num_hosts)` with `table_variables` and
      `accumulator_variables`. Those names, together with `table`, are not
      defined here; they are read from the enclosing scope.

      Returns:
        A list of ops to load embedding and slot variables from CPU to TPU.
      """
      shards = zip(range(num_hosts), table_variables, accumulator_variables)
      shard_load_ops = []
      for shard_index, params, accumulators in shards:
        # Build the op inside the colocation scope of the table variable it
        # reads; the list append happens after the scope is exited.
        with ops.colocate_with(params):
          shard_op = tpu_ops.load_tpu_embedding_adagrad_parameters(
              parameters=params,
              accumulators=accumulators,
              table_name=table,
              num_shards=num_hosts,
              shard_id=shard_index)
        shard_load_ops.append(shard_op)
      return shard_load_ops