def configure(self,
                session_config=None,
                cluster_spec=None,
                task_type=None,
                task_id=None):
    del task_type, task_id

    if session_config:
      session_config.isolate_session_state = True

    if cluster_spec:
      self._initialize_multi_worker(self._num_gpus, cluster_spec)

    if self._cross_tower_ops is None:
      if self._cluster_spec:
        # It currently cannot detect the topology of remote workers. So we
        # hard-code the multi-worker all-reduce algorithm for now.
        if len(self._workers) == 1:
          # The default is "nccl".
          self._cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps()
        else:
          # The default is hierarchical reduce and broadcast.
          self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce(
              self._workers, self._num_gpus)
      else:
        self._cross_tower_ops = cross_tower_ops_lib.choose_the_best(
            self._devices, session_config=session_config)
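For reference, the same selection can be exercised directly; a minimal sketch, assuming the TF 1.8-era contrib layout, where choose_the_best probes the local device links and falls back to a conservative default when it cannot:

from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib

# Single-worker case: mirrors the local path of configure() above.
devices = ["/gpu:0", "/gpu:1"]
cross_tower_ops = cross_tower_ops_lib.choose_the_best(devices)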
Example #2
def main():
    cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps('nccl')
    distribution = tf.contrib.distribute.MirroredStrategy(
        num_gpus=2, cross_tower_ops=cross_tower_ops)
    config = tf.estimator.RunConfig(train_distribute=distribution)
    multi_task_config = Bunch(json.load(tf.gfile.Open(
        FLAGS.multi_task_config)))

    def input_fn():
        features = tf.data.Dataset.from_tensors([[1.]]).repeat(10)
        labels = tf.data.Dataset.from_tensors([1.]).repeat(10)
        return tf.data.Dataset.zip((features, labels))

    # Renamed to avoid shadowing the definition below; `data_generator` is
    # assumed to be defined elsewhere in this project.
    def input_fn_from_generator():
        features = tf.data.Dataset.from_generator(data_generator, tf.float32,
                                                  tf.TensorShape([None, 1]))
        labels = tf.data.Dataset.from_generator(data_generator, tf.float32,
                                                tf.TensorShape([None, 1]))
        dataset = tf.data.Dataset.zip((features, labels)).repeat(10)
        return dataset

    def input_fn_generator():
        dataset = train_eval_input_fn(FLAGS, multi_task_config, "train", 0)
        return dataset

    model_fn = build_model_fn_optimizer()

    est = tf.estimator.Estimator(model_fn=model_fn, config=config)
    print("==begin to train==")
    est.train(input_fn=input_fn_generator, max_steps=1000)
Example #3
def get_run_config():
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False,
                                    intra_op_parallelism_threads=64,
                                    inter_op_parallelism_threads=64,
                                    gpu_options=tf.GPUOptions(
                                        allow_growth=True,
                                        force_gpu_compatible=True,
                                        per_process_gpu_memory_fraction=1.0))
    #session_config.graph_options.optimizer_options.opt_level = -1
    from tensorflow.core.protobuf import rewriter_config_pb2
    session_config.graph_options.rewrite_options.constant_folding = (
        rewriter_config_pb2.RewriterConfig.OFF)

    if FLAGS.distribution_strategy == "ExascaleStrategy":
        tf.logging.info(
            "*****************Using ExascaleStrategy*********************")
        import pai
        worker_hosts = FLAGS.worker_hosts.split(',')
        if len(worker_hosts) > 1:
            pai.distribute.set_tf_config(FLAGS.job_name, FLAGS.task_index,
                                         worker_hosts)
        strategy = pai.distribute.ExascaleStrategy(
            optimize_clip_by_global_norm=True)
    elif FLAGS.distribution_strategy == "MirroredStrategy":
        tf.logging.info(
            "*****************Using MirroredStrategy*********************")
        from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
        cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps('nccl')
        strategy = tf.contrib.distribute.MirroredStrategy(
            num_gpus=FLAGS.num_core_per_host,
            cross_tower_ops=cross_tower_ops)
    elif FLAGS.distribution_strategy == "None":
        strategy = None
    else:
        raise ValueError(
            "distribution_strategy must be one of "
            "ExascaleStrategy | MirroredStrategy | None")

    #model_dir set in tf.estimator.Estimator
    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.model_dir,
        session_config=session_config,
        keep_checkpoint_max=FLAGS.max_save,
        save_checkpoints_secs=None,
        save_checkpoints_steps=FLAGS.save_steps,
        #log_step_count_steps=50,
        train_distribute=strategy,
    )
    return run_config
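A hypothetical caller would wire the returned config into an Estimator as below; model_fn, input_fn and FLAGS.train_steps stand in for this project's real definitions:

run_config = get_run_config()
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
estimator.train(input_fn=input_fn, max_steps=FLAGS.train_steps)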
Example #4
class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase):
  # TODO(yuefengz): decouple the num_gpus check from distribution in
  # combinations module so that we can pass in devices instead of a distribution
  # strategy.
  reduction_to_one_combinations = combinations.combine(
      cross_tower_ops=[
          combinations.NamedObject(
              "DefaultReductionToOneDeviceCrossTowerOps",
              cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
          combinations.NamedObject(
              "ReductionToCPUDeviceCrossTowerOps",
              cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
                  reduce_to_device=_cpu_device)),
          combinations.NamedObject(
              "AccumulateNCrossTowerOp",
              cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
                  accumulation_fn=math_ops.accumulate_n)),
      ],
      distribution=[
          combinations.one_device_strategy,
          combinations.mirrored_strategy_with_gpu_and_cpu,
          combinations.mirrored_strategy_with_two_gpus
      ],
      mode=["graph", "eager"])
  allreduce_combinations = combinations.combine(
      cross_tower_ops=[
          combinations.NamedObject(
              "AllReduce",
              cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)),
          combinations.NamedObject(
              "HierarchicalCopy",
              cross_tower_ops_lib.AllReduceCrossTowerOps(
                  "hierarchical_copy", 8, 0, 0)),
          combinations.NamedObject(
              "AllReduceNoGradientRepacking",
              cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)),
          combinations.NamedObject(
              "HierarchicalCopyAggregateSmallTensors",
              cross_tower_ops_lib.AllReduceCrossTowerOps(
                  "hierarchical_copy", 0, 100, 10))
      ],
      distribution=[combinations.mirrored_strategy_with_two_gpus],
      mode=["graph", "eager"])

  @combinations.generate(reduction_to_one_combinations + allreduce_combinations)
  def testReductionAndBroadcast(self, cross_tower_ops, distribution):
    with distribution.scope():
      self._testReductionAndBroadcast(cross_tower_ops, distribution)

  def testChooseAlgorithm(self):
    device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
                    [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
    result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
    self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
    self.assertEqual(result._all_reduce_alg, "hierarchical_copy")
    self.assertEqual(result._num_packs, 8)

    # if there are only 4 devices
    device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]]
    result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
    self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
    self.assertEqual(result._all_reduce_alg, "nccl")
    self.assertEqual(result._num_packs, 1)

    # if devices links contain each device itself
    device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6],
                    [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7],
                    [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]]
    result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
    self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
    self.assertEqual(result._all_reduce_alg, "hierarchical_copy")
    self.assertEqual(result._num_packs, 8)

    # if not dgx1-like links
    device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7],
                    [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]]
    result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
    self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
    self.assertEqual(result._all_reduce_alg, "nccl")
    self.assertEqual(result._num_packs, 1)
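
  # The rule the four cases above pin down, as a standalone simplification
  # (an assumption for illustration, not the library's actual
  # _choose_all_reduce_algorithm): row i of device_links lists the peers of
  # device i; strip self-links, then require a per-device match with the
  # canonical DGX-1 wiring, so permuting rows (the last case) breaks it.
  #
  #   DGX1_LINKS = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
  #                 [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
  #   links = [sorted(set(row) - {i}) for i, row in enumerate(device_links)]
  #   if links == [sorted(row) for row in DGX1_LINKS]:
  #     return "hierarchical_copy", 8   # eight GPUs, DGX-1-like
  #   return "nccl", 1                  # anything else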

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      required_gpus=1))
  def testSimpleReduceWithIndexedSlices(self):
    devices = ["/cpu:0", "/gpu:0"]
    t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0])
    t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1])
    per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})
    result = cross_tower_ops_lib._simple_reduce(per_device, devices[0],
                                                math_ops.add_n, "sum")

    # Test that the result is semantically equal to both the concatenated
    # IndexedSlices with and without duplicate indices.
    total_with_dups = _make_indexed_slices(
        [[1., 2.], [3., 4.], [5., 6.]], [1, 1, 3], [5, 2], devices[0])
    total_without_dups = _make_indexed_slices(
        [[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0])
    self._assert_indexed_slices_equal(total_with_dups, result)
    self._assert_indexed_slices_equal(total_without_dups, result)
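
  # Why both expected forms compare equal: converting an IndexedSlices to a
  # dense tensor sums rows that share an index, so indices [1, 1, 3] with
  # values [[1, 2], [3, 4], [5, 6]] densify exactly like indices [1, 3] with
  # values [[4, 6], [5, 6]]. A standalone illustration (TF 1.x APIs):
  #
  #   dups = ops.IndexedSlices(
  #       constant_op.constant([[1., 2.], [3., 4.], [5., 6.]]),
  #       constant_op.constant([1, 1, 3]), constant_op.constant([5, 2]))
  #   dense = ops.convert_to_tensor(dups)
  #   # -> [[0, 0], [4, 6], [0, 0], [5, 6], [0, 0]]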

  @combinations.generate(combinations.combine(
      cross_tower_ops_instance=[
          combinations.NamedObject(
              "ReductionToOneDeviceCrossTowerOps",
              cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
          combinations.NamedObject(
              "AllReduceCrossTowerOps",
              cross_tower_ops_lib.AllReduceCrossTowerOps())
      ],
      method_string=["sum", "mean"],
      batch_reduce=[True, False],
      mode=["graph", "eager"],
      required_gpus=1))
  def testIndexedSlicesAllReduce(self, cross_tower_ops_instance,
                                 method_string, batch_reduce):
    devices = ["/cpu:0", "/gpu:0"]
    dense_shape = [5, 2]
    t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0])
    t1 = _make_indexed_slices(
        [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1])
    per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})

    if batch_reduce:
      result = cross_tower_ops_instance.batch_reduce(method_string,
                                                     [(per_device, devices)])
    else:
      result = cross_tower_ops_instance.reduce(method_string, per_device,
                                               devices)

    total_indices_with_dups = [1, 1, 3]
    total_indices_without_dups = [1, 3]

    if method_string == "sum":
      total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]]
      total_values_without_dups = [[4., 6.], [5., 6.]]
    else:
      assert method_string == "mean"
      total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]]
      total_values_without_dups = [[2., 3.], [2.5, 3.]]

    total_mirrored_with_dups = _make_mirrored_indexed_slices(
        devices, total_values_with_dups, total_indices_with_dups, dense_shape)
    total_mirrored_without_dups = _make_mirrored_indexed_slices(
        devices, total_values_without_dups, total_indices_without_dups,
        dense_shape)

    # Test that the result is semantically equal to both the concatenated
    # IndexedSlices, as well as when the duplicate indices are summed up.
    if batch_reduce:
      total_mirrored_with_dups = [total_mirrored_with_dups]
      total_mirrored_without_dups = [total_mirrored_without_dups]

    self._assert_values_equal(total_mirrored_with_dups, result)
    self._assert_values_equal(total_mirrored_without_dups, result)
Example #5
def main(_):

    print(FLAGS)
    print(tf.__version__, "==tensorflow version==")

    init_checkpoint = os.path.join(FLAGS.buckets, FLAGS.init_checkpoint)

    train_file = []
    for file in FLAGS.train_file.split(","):
        train_file_path = os.path.join(FLAGS.buckets, file)
        train_file.append(train_file_path)
    # train_file = os.path.join(FLAGS.buckets, FLAGS.train_file)
    # dev_file = os.path.join(FLAGS.buckets, FLAGS.dev_file)

    dev_file = []
    for file in FLAGS.dev_file.split(","):
        dev_file_path = os.path.join(FLAGS.buckets, file)
        dev_file.append(dev_file_path)
    checkpoint_dir = os.path.join(FLAGS.buckets, FLAGS.model_output)

    print(init_checkpoint, train_file, dev_file, checkpoint_dir,
          FLAGS.distribution_strategy)

    if FLAGS.distribution_strategy == "MirroredStrategy":
        cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps(
            "nccl", 10, 0, 0)
        distribution = tf.contrib.distribute.MirroredStrategy(
            num_gpus=FLAGS.num_gpus, cross_tower_ops=cross_tower_ops)
        worker_count = FLAGS.num_gpus
    elif FLAGS.distribution_strategy == "CollectiveAllReduceStrategy":
        print("==disbale evaluator==")

        cluster, task_type, task_index = (
            make_distributed_info_without_evaluator())
        print("==cluster==", cluster, "==task_type==", task_type,
              "==task_index==", task_index)
        dump_into_tf_config(cluster, task_type, task_index)

        print(os.environ['TF_CONFIG'], "===tf config===")

        print("==apply collective all reduce strategy==", FLAGS.autoStrategy)
        if FLAGS.autoStrategy == 'true':
            distribution = None
        else:
            distribution = tf.contrib.distribute.CollectiveAllReduceStrategy(
                num_gpus_per_worker=FLAGS.num_gpus,
                cross_tower_ops_type='horovod',
                all_dense=True)

        worker_count = (len(cluster.get('worker', [])) + len(cluster['chief']))
        if task_type == 'chief':
            is_chief = 1
            task_index = 0
        else:
            is_chief = 0
            task_index = FLAGS.task_index
        print(worker_count, task_type, task_index, FLAGS.task_index)
    else:
        cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps(
            "nccl", 10, 0, 0)
        distribution = tf.contrib.distribute.MirroredStrategy(
            num_gpus=FLAGS.num_gpus, cross_tower_ops=cross_tower_ops)
        worker_count = FLAGS.num_gpus

    sess_config = tf.ConfigProto(allow_soft_placement=True,
                                 log_device_placement=True)

    run_config = tf.estimator.RunConfig(
        keep_checkpoint_max=10,
        # model_dir=checkpoint_dir,
        train_distribute=distribution,  # tf 1.8
        # distribute=distribution,     # tf 1.4
        session_config=sess_config,
        save_checkpoints_secs=None,
        save_checkpoints_steps=None,
        log_step_count_steps=100)
    # disable_evaluation=True)  # 1.12

    if FLAGS.distribution_strategy == "MirroredStrategy":
        task_index = run_config.task_id
        is_chief = run_config.is_chief

    print("==worker_count==", worker_count, "==local_rank==", task_index,
          "==is is_chief==", is_chief, "==numbers of gpus==", FLAGS.num_gpus)
    cluster = ""
    target = ""

    print(FLAGS)

    if FLAGS.mode == "single_task":
        train_eval_api = train_eval
    elif FLAGS.mode == "multi_task":
        train_eval_api = multitask_train_eval
    elif FLAGS.mode == 'distillation':
        train_eval_api = distillation_train_eval
    elif FLAGS.mode == "electra":
        train_eval_api = pretrain_train_eval

    if FLAGS.mode == "electra":
        train_eval_api.monitored_estimator(
            FLAGS=FLAGS,
            worker_count=worker_count,
            task_index=task_index,
            cluster=cluster,
            is_chief=is_chief,
            init_checkpoint=init_checkpoint,
            train_file=train_file,
            dev_file=dev_file,
            checkpoint_dir=checkpoint_dir,
            run_config=run_config,
            distribution_strategy=FLAGS.distribution_strategy,
            profiler=FLAGS.profiler,
            parse_type=FLAGS.parse_type,
            rule_model=FLAGS.rule_model,
            train_op=FLAGS.train_op,
            running_type=FLAGS.running_type,
            decay=FLAGS.decay,
            warmup=FLAGS.warmup,
            input_target=FLAGS.input_target,
            distillation=FLAGS.distillation,
            temperature=FLAGS.temperature,
            distillation_ratio=FLAGS.distillation_ratio,
            electra_mode=FLAGS.electra_mode,
            sharing_mode=FLAGS.sharing_mode,
            attention_type=FLAGS.attention_type,
            ues_token_type=FLAGS.ues_token_type,
            gumbel_anneal=FLAGS.gumbel_anneal,
            annealed_mask_prob=FLAGS.annealed_mask_prob,
            joint_train=FLAGS.joint_train,
            optimization_type=FLAGS.optimization_type,
            seq_type=FLAGS.seq_type,
            mask_type=FLAGS.mask_type)
    else:
        train_eval_api.monitored_estimator(
            FLAGS=FLAGS,
            worker_count=worker_count,
            task_index=task_index,
            cluster=cluster,
            is_chief=is_chief,
            target=target,
            init_checkpoint=init_checkpoint,
            train_file=train_file,
            dev_file=dev_file,
            checkpoint_dir=checkpoint_dir,
            run_config=run_config,
            distribution_strategy=FLAGS.distribution_strategy,
            profiler=FLAGS.profiler,
            parse_type=FLAGS.parse_type,
            rule_model=FLAGS.rule_model,
            train_op=FLAGS.train_op,
            running_type=FLAGS.running_type,
            decay=FLAGS.decay,
            warmup=FLAGS.warmup,
            input_target=FLAGS.input_target,
            distillation=FLAGS.distillation,
            temperature=FLAGS.temperature,
            distillation_ratio=FLAGS.distillation_ratio,
            attention_type=FLAGS.attention_type,
            ues_token_type=FLAGS.ues_token_type,
            seq_type=FLAGS.seq_type,
            mask_type=FLAGS.mask_type)
Example #6
def main(_):

	print(FLAGS)
	print(tf.__version__, "==tensorflow version==")

	os.environ['NCCL_LL_THRESHOLD'] = "0"

	init_checkpoint = os.path.join(FLAGS.buckets, FLAGS.init_checkpoint)
	train_file = []
	for file in FLAGS.train_file.split(","):
		train_file_path = os.path.join(FLAGS.buckets, file)
		train_file.append(train_file_path)
	# train_file = os.path.join(FLAGS.buckets, FLAGS.train_file)
	# dev_file = os.path.join(FLAGS.buckets, FLAGS.dev_file)

	dev_file = []
	for file in FLAGS.dev_file.split(","):
		dev_file_path = os.path.join(FLAGS.buckets, file)
		dev_file.append(dev_file_path)
	checkpoint_dir = os.path.join(FLAGS.buckets, FLAGS.model_output)

	print(init_checkpoint, train_file, dev_file, checkpoint_dir)

	if FLAGS.distribution_strategy == "MirroredStrategy":
		cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 10, 0, 0)
		distribution = tf.contrib.distribute.MirroredStrategy(num_gpus=FLAGS.num_gpus, 
												cross_tower_ops=cross_tower_ops)
		worker_count = FLAGS.num_gpus
	else:
		cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 10, 0, 0)
		distribution = tf.contrib.distribute.MirroredStrategy(num_gpus=FLAGS.num_gpus, 
												cross_tower_ops=cross_tower_ops)
		worker_count = FLAGS.num_gpus

	sess_config = tf.ConfigProto(allow_soft_placement=True,
									log_device_placement=True)

	run_config = tf.estimator.RunConfig(
					  keep_checkpoint_max=10,
					  # model_dir=checkpoint_dir,
					  train_distribute=distribution, # tf 1.8
					  # distribute=distribution,     # tf 1.4
					  session_config=sess_config,
					  save_checkpoints_secs=None,
					  save_checkpoints_steps=None,
					  log_step_count_steps=100)

	task_index = run_config.task_id
	is_chief = run_config.is_chief

	print("==worker_count==", worker_count, "==local_rank==", task_index, "==is is_chief==", is_chief)
	cluster = ""
	target = ""

	print(FLAGS)

	if FLAGS.mode == "single_task":
		train_eval_api = train_eval
	elif FLAGS.mode == "multi_task":
		train_eval_api = multitask_train_eval
	elif FLAGS.mode == 'distillation':
		train_eval_api = distillation_train_eval
	elif FLAGS.mode == "electra":
		train_eval_api = pretrain_train_eval

	if FLAGS.mode == "electra":
		train_eval_api.monitored_estimator(
			FLAGS=FLAGS,
			worker_count=worker_count, 
			task_index=task_index, 
			cluster=cluster, 
			is_chief=is_chief, 
			init_checkpoint=init_checkpoint,
			train_file=train_file,
			dev_file=dev_file,
			checkpoint_dir=checkpoint_dir,
			run_config=run_config,
			distribution_strategy=FLAGS.distribution_strategy,
			profiler=FLAGS.profiler,
			parse_type=FLAGS.parse_type,
			rule_model=FLAGS.rule_model,
			train_op=FLAGS.train_op,
			running_type=FLAGS.running_type,
			decay=FLAGS.decay,
			warmup=FLAGS.warmup,
			input_target=FLAGS.input_target,
			distillation=FLAGS.distillation,
			temperature=FLAGS.temperature,
			distillation_ratio=FLAGS.distillation_ratio)
	else:
		train_eval_api.monitored_estimator(
			FLAGS=FLAGS,
			worker_count=worker_count, 
			task_index=task_index, 
			cluster=cluster, 
			is_chief=is_chief, 
			target=target,
			init_checkpoint=init_checkpoint,
			train_file=train_file,
			dev_file=dev_file,
			checkpoint_dir=checkpoint_dir,
			run_config=run_config,
			distribution_strategy=FLAGS.distribution_strategy,
			profiler=FLAGS.profiler,
			parse_type=FLAGS.parse_type,
			rule_model=FLAGS.rule_model,
			train_op=FLAGS.train_op,
			running_type=FLAGS.running_type,
			decay=FLAGS.decay,
			warmup=FLAGS.warmup,
			input_target=FLAGS.input_target,
			distillation=FLAGS.distillation,
			temperature=FLAGS.temperature,
			distillation_ratio=FLAGS.distillation_ratio)
Example #7
def main(unused_argv):
    for key in FLAGS.__dict__['__flags'].keys():
        if key in ('h', 'help'):
            continue
        print("%s=%s" % (key, str(FLAGS.__dict__['__flags'][key])))
    print("tensorflow version:", tf.__version__)
    if FLAGS.volumes:
        volumes = FLAGS.volumes.split(",")
        train_files = volumes[0] + "/*.tfr"
        if FLAGS.tf_random_seed != 0:  # for experiment
            train_files = volumes[0] + "/*_0000*.tfr"
        eval_files = volumes[1] + "/*.tfr" if len(volumes) > 1 else train_files
    else:
        train_files = FLAGS.train_data
        eval_files = FLAGS.eval_data
    print("train_data:", train_files)
    print("eval_data:", eval_files)
    model_params = {
        'learning_rate': FLAGS.learning_rate,
        'vocab_size': FLAGS.word_vocab_size,
        'cate_vocab_size': FLAGS.cate_vocab_size,
        'embedding_size': FLAGS.word_embedding_size,
        'doc_embedding_size': FLAGS.cate_embedding_size,
        'num_negative_samples': FLAGS.num_negative_samples,
        'embedding_merge': 'avg'
    }

    session_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_device_placement,
        intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads,
        inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
        gpu_options=tf.GPUOptions(allow_growth=True,
                                  force_gpu_compatible=True))
    if FLAGS.evaluate:
        cluster = {'chief': ['localhost:2221'], 'worker': ['localhost:2222']}
        os.environ['TF_CONFIG'] = json.dumps({
            'cluster': cluster,
            'task': {
                'type': 'evaluator',
                'index': 0
            }
        })
        config = tf.estimator.RunConfig(session_config=session_config)
        model = Doc2Vec(params=model_params,
                        optimizer=FLAGS.optimizer,
                        model_dir=FLAGS.checkpointDir,
                        config=config)
        evaluation_listener(model, train_files, eval_files)
        return
    else:
        cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps(
            'nccl', 16, 0, 0)
        distribution = tf.contrib.distribute.MirroredStrategy(
            num_gpus=FLAGS.num_gpus, cross_tower_ops=cross_tower_ops)
        config = tf.estimator.RunConfig(
            distribute=distribution,
            save_checkpoints_secs=FLAGS.save_checkpoints_secs,
            keep_checkpoint_max=FLAGS.keep_checkpoint_max,
            session_config=session_config,
            log_step_count_steps=FLAGS.log_step_count_steps,
            save_summary_steps=FLAGS.log_step_count_steps)
    if FLAGS.tf_random_seed != 0:
        # RunConfig.replace returns a new config; reassign it.
        config = config.replace(tf_random_seed=FLAGS.tf_random_seed)

    model = Doc2Vec(params=model_params,
                    optimizer=FLAGS.optimizer,
                    model_dir=FLAGS.checkpointDir,
                    config=config)

    train_hooks = [
        tf.train.ProfilerHook(save_secs=FLAGS.save_checkpoints_secs,
                              output_dir=FLAGS.buckets)
    ] if FLAGS.profile else []

    logging.info("before train")
    model.train(lambda: input_fn(train_files, True),
                max_steps=FLAGS.train_steps,
                hooks=train_hooks)
    logging.info("finish main")
Example #8
def main(_):

    print(FLAGS)
    print(tf.__version__, "==tensorflow version==")

    # Set NCCL_LL_THRESHOLD to 0 to avoid hangs in collective all-reduce.
    os.environ['NCCL_LL_THRESHOLD'] = '0'
    # os.environ['TF_ENABLE_WHILE_V2'] = '1'
    # os.environ['TF_ENABLE_COND_V2'] = '1'

    tf.enable_resource_variables()

    init_checkpoint = os.path.join(FLAGS.buckets, FLAGS.init_checkpoint)
    train_file = []
    for file in FLAGS.train_file.split(","):
        train_file_path = os.path.join(FLAGS.buckets, file)
        train_file.append(train_file_path)
    # train_file = os.path.join(FLAGS.buckets, FLAGS.train_file)
    # dev_file = os.path.join(FLAGS.buckets, FLAGS.dev_file)

    dev_file = []
    for file in FLAGS.dev_file.split(","):
        dev_file_path = os.path.join(FLAGS.buckets, file)
        dev_file.append(dev_file_path)
    checkpoint_dir = os.path.join(FLAGS.buckets, FLAGS.model_output)

    print(init_checkpoint, train_file, dev_file, checkpoint_dir)

    if FLAGS.distribution_strategy == "MirroredStrategy":
        cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps(
            "nccl", 10, 0, 0)
        distribution = tf.contrib.distribute.MirroredStrategy(
            num_gpus=FLAGS.num_gpus, cross_tower_ops=cross_tower_ops)
        worker_count = FLAGS.num_gpus
    else:
        cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps(
            "nccl", 10, 0, 0)
        distribution = tf.contrib.distribute.MirroredStrategy(
            num_gpus=FLAGS.num_gpus, cross_tower_ops=cross_tower_ops)
        worker_count = FLAGS.num_gpus

    sess_config = tf.ConfigProto(allow_soft_placement=True,
                                 log_device_placement=True)

    run_config = tf.estimator.RunConfig(
        keep_checkpoint_max=10,
        # model_dir=checkpoint_dir,
        train_distribute=distribution,  # tf 1.8
        # distribute=distribution,     # tf 1.4
        session_config=sess_config,
        save_checkpoints_secs=None,
        save_checkpoints_steps=None,
        log_step_count_steps=100)

    task_index = run_config.task_id
    is_chief = run_config.is_chief

    print("==worker_count==", worker_count, "==local_rank==", task_index,
          "==is is_chief==", is_chief)
    cluster = ""
    target = ""

    print(FLAGS)

    if FLAGS.mode == "single_task":
        train_eval_api = train_eval
    elif FLAGS.mode == "multi_task":
        train_eval_api = multitask_train_eval
    elif FLAGS.mode == 'distillation':
        train_eval_api = distillation_train_eval
    elif FLAGS.mode == "electra":
        train_eval_api = pretrain_train_eval

    if FLAGS.mode == "electra":
        train_eval_api.monitored_estimator(
            FLAGS=FLAGS,
            worker_count=worker_count,
            task_index=task_index,
            cluster=cluster,
            is_chief=is_chief,
            init_checkpoint=init_checkpoint,
            train_file=train_file,
            dev_file=dev_file,
            checkpoint_dir=checkpoint_dir,
            run_config=run_config,
            distribution_strategy=FLAGS.distribution_strategy,
            profiler=FLAGS.profiler,
            parse_type=FLAGS.parse_type,
            rule_model=FLAGS.rule_model,
            train_op=FLAGS.train_op,
            running_type=FLAGS.running_type,
            decay=FLAGS.decay,
            warmup=FLAGS.warmup,
            input_target=FLAGS.input_target,
            distillation=FLAGS.distillation,
            temperature=FLAGS.temperature,
            distillation_ratio=FLAGS.distillation_ratio,
            electra_mode=FLAGS.electra_mode,
            sharing_mode=FLAGS.sharing_mode,
            attention_type=FLAGS.attention_type,
            ues_token_type=FLAGS.ues_token_type,
            gumbel_anneal=FLAGS.gumbel_anneal,
            annealed_mask_prob=FLAGS.annealed_mask_prob,
            joint_train=FLAGS.joint_train,
            optimization_type=FLAGS.optimization_type,
            gen_disc_type=FLAGS.gen_disc_type,
            train_op_type=FLAGS.train_op_type,
            mask_method=FLAGS.mask_method,
            minmax_mode=FLAGS.minmax_mode,
            seq_type=FLAGS.seq_type,
            mask_type=FLAGS.mask_type)
        # use_tpu=FLAGS.use_tpu)
    else:
        train_eval_api.monitored_estimator(
            FLAGS=FLAGS,
            worker_count=worker_count,
            task_index=task_index,
            cluster=cluster,
            is_chief=is_chief,
            target=target,
            init_checkpoint=init_checkpoint,
            train_file=train_file,
            dev_file=dev_file,
            checkpoint_dir=checkpoint_dir,
            run_config=run_config,
            distribution_strategy=FLAGS.distribution_strategy,
            profiler=FLAGS.profiler,
            parse_type=FLAGS.parse_type,
            rule_model=FLAGS.rule_model,
            train_op=FLAGS.train_op,
            running_type=FLAGS.running_type,
            decay=FLAGS.decay,
            warmup=FLAGS.warmup,
            input_target=FLAGS.input_target,
            distillation=FLAGS.distillation,
            temperature=FLAGS.temperature,
            distillation_ratio=FLAGS.distillation_ratio,
            attention_type=FLAGS.attention_type,
            ues_token_type=FLAGS.ues_token_type,
            seq_type=FLAGS.seq_type,
            mask_type=FLAGS.mask_type)
Example #9
def main(_):

    print(FLAGS)
    print(tf.__version__, "==tensorflow version==")

    init_checkpoint = os.path.join(FLAGS.buckets, FLAGS.init_checkpoint)
    train_file = os.path.join(FLAGS.buckets, FLAGS.train_file)
    dev_file = os.path.join(FLAGS.buckets, FLAGS.dev_file)
    checkpoint_dir = os.path.join(FLAGS.buckets, FLAGS.model_output)

    print(init_checkpoint, train_file, dev_file, checkpoint_dir)

    if FLAGS.distribution_strategy == "MirroredStrategy":
        cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps(
            "nccl", 10, 0, 0)
        distribution = tf.contrib.distribute.MirroredStrategy(
            num_gpus=FLAGS.num_gpus, cross_tower_ops=cross_tower_ops)
        worker_count = FLAGS.num_gpus
    elif FLAGS.distribution_strategy == "CollectiveAllReduceStrategy":
        print("==apply collective all reduce strategy==")
        cluster, task_type, task_index = (
            make_distributed_info_without_evaluator())
        print("==cluster==", cluster, "==task_type==", task_type,
              "==task_index==", task_index)
        dump_into_tf_config(cluster, task_type, task_index)
        distribution = tf.contrib.distribute.CollectiveAllReduceStrategy(
            num_gpus_per_worker=1, cross_tower_ops_type='horovod')
        worker_count = len(cluster['chief']) + len(cluster['worker'])
    else:
        cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps(
            "nccl", 10, 0, 0)
        distribution = tf.contrib.distribute.MirroredStrategy(
            num_gpus=FLAGS.num_gpus, cross_tower_ops=cross_tower_ops)
        worker_count = FLAGS.num_gpus

    sess_config = tf.ConfigProto(allow_soft_placement=True,
                                 log_device_placement=True)

    run_config = tf.estimator.RunConfig(
        keep_checkpoint_max=10,
        # model_dir=checkpoint_dir,
        train_distribute=distribution,  # tf 1.8
        # distribute=distribution,     # tf 1.4
        session_config=sess_config,
        save_checkpoints_secs=None,
        save_checkpoints_steps=None,
        log_step_count_steps=100)
    # disable_evaluation=True) # tf180

    task_index = run_config.task_id
    is_chief = run_config.is_chief

    print("==worker_count==", worker_count, "==local_rank==", task_index,
          "==is is_chief==", is_chief)
    cluster = ""
    target = ""

    train_eval_fn(FLAGS=FLAGS,
                  worker_count=worker_count,
                  task_index=task_index,
                  cluster=cluster,
                  is_chief=is_chief,
                  target=target,
                  init_checkpoint=init_checkpoint,
                  train_file=train_file,
                  dev_file=dev_file,
                  is_debug=FLAGS.is_debug,
                  checkpoint_dir=checkpoint_dir,
                  run_config=run_config,
                  distribution_strategy=FLAGS.distribution_strategy,
                  profiler=FLAGS.profiler,
                  parse_type=FLAGS.parse_type,
                  rule_model=FLAGS.rule_model,
                  train_op=FLAGS.train_op,
                  running_type=FLAGS.running_type,
                  decay=FLAGS.decay,
                  warmup=FLAGS.warmup,
                  input_target=FLAGS.input_target,
                  distillation=FLAGS.distillation,
                  temperature=FLAGS.temperature,
                  distillation_ratio=FLAGS.distillation_ratio)
Example #10
    def __init__(self, **kwargs):

        if self.config.mode in ('train', 'train_and_evaluate',
                                'train_and_evaluate_on_the_fly',
                                'train_on_the_fly'):

            tf.logging.info("***********Running in {} mode***********".format(
                self.config.mode))

            if self.config.enable_xla is True:
                tf.logging.info("***********Enable Tao***********")
                os.environ['BRIDGE_ENABLE_TAO'] = 'True'
                os.environ["TAO_ENABLE_CHECK"] = "false"
                os.environ["TAO_COMPILATION_MODE_ASYNC"] = "false"
                os.environ["DISABLE_DEADNESS_ANALYSIS"] = "true"
            else:
                tf.logging.info("***********Disable Tao***********")

            if self.config.enable_auto_mixed_precision is True:
                tf.logging.info(
                    "***********Enable AUTO_MIXED_PRECISION***********")
                os.environ['TF_AUTO_MIXED_PRECISION'] = 'True'
                os.environ['lossScaling'] = 'auto'
            else:
                tf.logging.info(
                    "***********Disable AUTO_MIXED_PRECISION***********")

            NCCL_MAX_NRINGS = "4"
            NCCL_MIN_NRINGS = "4"
            TF_JIT_PROFILING = 'False'
            PAI_ENABLE_HLO_DUMPER = 'False'
            os.environ['PAI_ENABLE_HLO_DUMPER'] = PAI_ENABLE_HLO_DUMPER
            os.environ['TF_JIT_PROFILING'] = TF_JIT_PROFILING
            os.environ["NCCL_MAX_NRINGS"] = NCCL_MAX_NRINGS
            os.environ["NCCL_MIN_NRINGS"] = NCCL_MIN_NRINGS
            os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
            tf.logging.info("***********NCCL_MAX_NRINGS {}***********".format(
                NCCL_MAX_NRINGS))
            tf.logging.info("***********NCCL_MIN_NRINGS {}***********".format(
                NCCL_MIN_NRINGS))
            tf.logging.info("***********TF_JIT_PROFILING {}***********".format(
                TF_JIT_PROFILING))
            tf.logging.info(
                "***********PAI_ENABLE_HLO_DUMPER {}***********".format(
                    PAI_ENABLE_HLO_DUMPER))

            self.strategy = None
            if self.config.num_gpus >= 1 and self.config.num_workers >= 1 and \
                    (self.config.distribution_strategy == "ExascaleStrategy" or
                     self.config.distribution_strategy == "CollectiveAllReduceStrategy"):

                if FLAGS.usePAI:
                    import pai
                    worker_hosts = self.config.worker_hosts.split(',')
                    tf.logging.info(
                        "***********Job Name is {}***********".format(
                            self.config.job_name))
                    tf.logging.info(
                        "***********Task Index is {}***********".format(
                            self.config.task_index))
                    tf.logging.info(
                        "***********Worker Hosts is {}***********".format(
                            self.config.worker_hosts))
                    pai.distribute.set_tf_config(
                        self.config.job_name,
                        self.config.task_index,
                        worker_hosts,
                        has_evaluator=self.config.
                        pull_evaluation_in_multiworkers_training)

                if self.config.distribution_strategy == "ExascaleStrategy":
                    tf.logging.info(
                        "*****************Using ExascaleStrategy*********************"
                    )
                    if FLAGS.usePAI:
                        self.strategy = pai.distribute.ExascaleStrategy(
                            num_gpus=self.config.num_gpus,
                            num_micro_batches=self.config.
                            num_accumulated_batches,
                            max_splits=1,
                            enable_sparse_allreduce=False)
                    else:
                        raise ValueError("Please set usePAI is True")

                elif self.config.distribution_strategy == "CollectiveAllReduceStrategy":
                    tf.logging.info(
                        "*****************Using CollectiveAllReduceStrategy*********************"
                    )
                    if FLAGS.usePAI:
                        self.strategy = tf.contrib.distribute.CollectiveAllReduceStrategy(
                            num_gpus_per_worker=self.config.num_gpus,
                            cross_tower_ops_type='default',
                            all_dense=True,
                            iter_size=self.config.num_accumulated_batches)
                    else:
                        self.strategy = tf.contrib.distribute.CollectiveAllReduceStrategy(
                            num_gpus_per_worker=self.config.num_gpus)

                if self.config.pull_evaluation_in_multiworkers_training is True:
                    real_num_workers = self.config.num_workers - 1
                else:
                    real_num_workers = self.config.num_workers

                global_batch_size = self.config.train_batch_size * self.config.num_gpus * real_num_workers


            elif self.config.num_gpus > 1 and self.config.num_workers == 1 and \
                    self.config.distribution_strategy == "MirroredStrategy":
                tf.logging.info(
                    "*****************Using MirroredStrategy*********************"
                )
                if FLAGS.usePAI:
                    from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
                    cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps(
                        'nccl')
                    self.strategy = tf.contrib.distribute.MirroredStrategy(
                        num_gpus=self.config.num_gpus,
                        cross_tower_ops=cross_tower_ops,
                        all_dense=True,
                        iter_size=self.config.num_accumulated_batches)
                else:
                    self.strategy = tf.contrib.distribute.MirroredStrategy(
                        num_gpus=self.config.num_gpus)

                global_batch_size = self.config.train_batch_size * self.config.num_gpus * self.config.num_accumulated_batches

            elif self.config.num_gpus >= 1 and self.config.num_workers >= 1 and \
                    self.config.distribution_strategy == "WhaleStrategy":

                if FLAGS.usePAI:
                    import pai
                    worker_hosts = self.config.worker_hosts.split(',')
                    tf.logging.info(
                        "***********Job Name is {}***********".format(
                            self.config.job_name))
                    tf.logging.info(
                        "***********Task Index is {}***********".format(
                            self.config.task_index))
                    tf.logging.info(
                        "***********Worker Hosts is {}***********".format(
                            self.config.worker_hosts))
                    pai.distribute.set_tf_config(
                        self.config.job_name,
                        self.config.task_index,
                        worker_hosts,
                        has_evaluator=self.config.
                        pull_evaluation_in_multiworkers_training)

                tf.logging.info(
                    "*****************Using WhaleStrategy*********************"
                )
                os.environ["WHALE_COMMUNICATION_SPARSE_AS_DENSE"] = "True"
                os.environ["WHALE_COMMUNICATION_NUM_COMMUNICATORS"] = "2"
                os.environ["WHALE_COMMUNICATION_NUM_SPLITS"] = "8"
                global_batch_size = self.config.train_batch_size * self.config.num_accumulated_batches * self.config.num_model_replica

            elif self.config.num_gpus == 1 and self.config.num_workers == 1:
                global_batch_size = self.config.train_batch_size * self.config.num_accumulated_batches
                tf.logging.info(
                    "***********Single worker, Single gpu, Don't use distribution strategy***********"
                )

            elif self.config.num_gpus == 0 and self.config.num_workers == 1:
                global_batch_size = self.config.train_batch_size * self.config.num_accumulated_batches
                tf.logging.info(
                    "***********Single worker, Running on CPU***********")

            else:
                raise ValueError(
                    "In train mode, please set num_workers, num_gpus and distribution_strategy correctly:\n"
                    "num_workers>=1, num_gpus>=1, distribution_strategy=WhaleStrategy|ExascaleStrategy|CollectiveAllReduceStrategy\n"
                    "num_workers==1, num_gpus>1, distribution_strategy=MirroredStrategy\n"
                    "num_workers==1, num_gpus<=1, distribution_strategy=None")

            # Validate optional keyword arguments.
            if "num_train_examples" not in kwargs:
                raise ValueError('Please pass num_train_examples')

            self.num_train_examples = kwargs['num_train_examples']

            # if save steps is None, save per epoch
            if self.config.save_steps is None:
                self.save_steps = int(self.num_train_examples /
                                      global_batch_size)
            else:
                self.save_steps = self.config.save_steps

            self.train_steps = int(
                self.num_train_examples * self.config.num_epochs /
                global_batch_size) + 1

            self.throttle_secs = self.config.throttle_secs
            self.model_dir = self.config.model_dir
            tf.logging.info("model_dir: {}".format(self.config.model_dir))
            tf.logging.info("num workers: {}".format(self.config.num_workers))
            tf.logging.info("num gpus: {}".format(self.config.num_gpus))
            tf.logging.info("learning rate: {}".format(
                self.config.learning_rate))
            tf.logging.info("train batch size: {}".format(
                self.config.train_batch_size))
            tf.logging.info("global batch size: {}".format(global_batch_size))
            tf.logging.info("num accumulated batches: {}".format(
                self.config.num_accumulated_batches))
            tf.logging.info("num model replica: {}".format(
                self.config.num_model_replica))
            tf.logging.info("num train examples per epoch: {}".format(
                self.num_train_examples))
            tf.logging.info("num epochs: {}".format(self.config.num_epochs))
            tf.logging.info("train steps: {}".format(self.train_steps))
            tf.logging.info("save steps: {}".format(self.save_steps))
            tf.logging.info("throttle secs: {}".format(self.throttle_secs))
            tf.logging.info("keep checkpoint max: {}".format(
                self.config.keep_checkpoint_max))
            tf.logging.info("warmup ratio: {}".format(
                self.config.warmup_ratio))
            tf.logging.info("gradient clip: {}".format(
                self.config.gradient_clip))
            tf.logging.info("log step count steps: {}".format(
                self.config.log_step_count_steps))

            if self.config.distribution_strategy != "WhaleStrategy":
                self.estimator = tf.estimator.Estimator(
                    model_fn=self._build_model_fn(),
                    model_dir=self.config.model_dir,
                    config=self._get_run_train_config(config=self.config))
            else:
                tf.logging.info("***********Using Whale Estimator***********")
                try:
                    from easytransfer.engines.whale_estimator import WhaleEstimator
                    import whale as wh
                    wh.init()
                    self.estimator = WhaleEstimator(
                        model_fn=self._build_model_fn(),
                        model_dir=self.config.model_dir,
                        num_model_replica=self.config.num_model_replica,
                        num_accumulated_batches=self.config.
                        num_accumulated_batches)
                except Exception as e:
                    raise NotImplementedError(
                        "Failed to initialize WhaleEstimator: {}".format(e))

            if self.config.mode == 'train_and_evaluate' or self.config.mode == 'train_and_evaluate_on_the_fly':
                self.num_eval_steps = self.config.num_eval_steps
                tf.logging.info("num eval steps: {}".format(
                    self.num_eval_steps))

        elif self.config.mode == 'evaluate' or self.config.mode == 'evaluate_on_the_fly':
            self.num_eval_steps = self.config.num_eval_steps
            tf.logging.info("num eval steps: {}".format(self.num_eval_steps))
            tf.logging.info("***********Running in {} mode***********".format(
                self.config.mode))
            self.estimator = tf.estimator.Estimator(
                model_fn=self._build_model_fn(),
                config=self._get_run_predict_config())

        elif self.config.mode == 'predict' or self.config.mode == 'predict_on_the_fly':
            tf.logging.info("***********Running in {} mode***********".format(
                self.config.mode))
            self.estimator = tf.estimator.Estimator(
                model_fn=self._build_model_fn(),
                config=self._get_run_predict_config())

        elif self.config.mode == 'export':
            tf.logging.info("***********Running in {} mode***********".format(
                self.config.mode))
            self.estimator = tf.estimator.Estimator(
                model_fn=self._build_model_fn(),
                config=self._get_run_predict_config())

        elif self.config.mode == 'preprocess':
            tf.logging.info("***********Running in {} mode***********".format(
                self.config.mode))
            self.estimator = tf.estimator.Estimator(
                model_fn=self._build_model_fn(),
                config=tf.estimator.RunConfig())

            self.first_sequence = self.config.first_sequence
            self.second_sequence = self.config.second_sequence
            self.label_enumerate_values = self.config.label_enumerate_values
            self.label_name = self.config.label_name
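
For concreteness, a worked instance of the global batch size arithmetic in the collective branch above, with hypothetical values (the evaluator is pulled out of the worker pool):

train_batch_size = 32
num_gpus = 4
num_workers = 3                      # one host acts as the evaluator
real_num_workers = num_workers - 1   # pull_evaluation_in_multiworkers_training
global_batch_size = train_batch_size * num_gpus * real_num_workers
assert global_batch_size == 256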
Example #11
def main():
    # Parse arguments and print them
    args = parse_args()
    print("\nMain arguments:")
    for k, v in args.__dict__.items():
        print("{}={}".format(k, v))

    # Config
    config = parse_config('MiniBERT')
    config["init_checkpoint"] = (
        args.buckets + args.init_ckt_dir +
        "/model.ckpt-{}".format(args.init_ckt_step))

    # Check if the model already exists
    model_save_dir = args.buckets + args.checkpoint_dir
    if tf.gfile.Exists(model_save_dir + "/checkpoint"):
        raise ValueError(
            "Model %s already exists; please delete it and retry" %
            model_save_dir)

    helper.dump_args(model_save_dir, args)

    transformer_model = model.TextTransformerNet(
        bert_config=config,
        model_configs=model.TextTransformerNet.ModelConfigs(
            dropout_rate=args.dropout_rate,
            num_vocabulary=args.num_vocabulary,
            feed_forward_in_dim=args.feed_forward_in_dim,
            model_dim=args.model_dim,
            num_blocks=args.num_blocks,
            num_heads=args.num_heads,
            enable_date_time_emb=args.enable_date_time_emb,
            word_emb_dim=args.word_emb_dim,
            date_span=args.date_span),
        train_configs=model.TrainConfigs(learning_rate=args.learning_rate,
                                         batch_size=args.batch_size,
                                         dropout_rate=args.dropout_rate),
        predict_configs=None,
        run_configs=model.RunConfigs(log_every=50))
    # checkpoint_path = None
    # if args.step > 0:
    #     checkpoint_path = model_save_dir + "/model.ckpt-{}".format(args.step)
    # warm_start_settings = tf.estimator.WarmStartSettings(checkpoint_path,
    #                                                      vars_to_warm_start='(.*Embedding|Conv-[1-4]|MlpLayer-1)')
    cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps('nccl')
    distribution = tf.contrib.distribute.MirroredStrategy(
        num_gpus=4, cross_tower_ops=cross_tower_ops, all_dense=False)

    estimator = tf.estimator.Estimator(
        model_fn=transformer_model.model_fn,
        model_dir=model_save_dir,
        config=tf.estimator.RunConfig(session_config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(allow_growth=False),
            allow_soft_placement=True),
                                      save_checkpoints_steps=args.snapshot,
                                      keep_checkpoint_max=20,
                                      train_distribute=distribution))
    print("Start training......")
    # tf.estimator has no module-level train(); call train() on the Estimator.
    estimator.train(
        input_fn=loader.OdpsDataLoader(
            table_name=args.tables,
            mode=tf.estimator.ModeKeys.TRAIN,
            hist_length=args.max_length,
            target_length=args.target_length,
            batch_size=args.batch_size).input_fn,
        max_steps=args.max_steps)
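
For reference, the commented-out warm-start snippet above would be wired in roughly as below; the checkpoint step is hypothetical, and WarmStartSettings takes the checkpoint to initialize from plus a regex over variable names:

warm_start = tf.estimator.WarmStartSettings(
    ckpt_to_initialize_from=model_save_dir + "/model.ckpt-1000",  # hypothetical step
    vars_to_warm_start='(.*Embedding|Conv-[1-4]|MlpLayer-1)')
estimator = tf.estimator.Estimator(
    model_fn=transformer_model.model_fn,
    model_dir=model_save_dir,
    config=run_config,  # the same RunConfig as constructed above
    warm_start_from=warm_start)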
Example #12
def main(_):

    print(FLAGS)
    print(tf.__version__, "==tensorflow version==")

    # Make every host train; evaluation is not run during training.
    worker_hosts = FLAGS.worker_hosts.split(",")
    worker_count = len(worker_hosts)
    print("==number of workers==", worker_count)

    if len(worker_hosts) > 1:
        cluster = {"chief": [worker_hosts[0]], "worker": worker_hosts[1:]}
    else:
        cluster = {"chief": [worker_hosts[0]]}

    if FLAGS.task_index == 0:
        task_type = 'chief'
        task_index = 0
        os.environ['TF_CONFIG'] = json.dumps({
            'cluster': cluster,
            'task': {
                'type': "chief",
                'index': 0
            }
        })
    else:
        task_type = 'worker'
        task_index = FLAGS.task_index - 1
        os.environ['TF_CONFIG'] = json.dumps({
            'cluster': cluster,
            'task': {
                'type': task_type,
                'index': FLAGS.task_index - 1
            }
        })
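    # Note: the chief occupies global index 0, so the remaining workers are
    # re-indexed from zero by subtracting 1 from FLAGS.task_index above.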

    init_checkpoint = os.path.join(FLAGS.buckets, FLAGS.init_checkpoint)
    train_file = os.path.join(FLAGS.buckets, FLAGS.train_file)
    dev_file = os.path.join(FLAGS.buckets, FLAGS.dev_file)
    checkpoint_dir = os.path.join(FLAGS.buckets, FLAGS.model_output)

    print(init_checkpoint, train_file, dev_file, checkpoint_dir)

    if FLAGS.distribution_strategy == "MirroredStrategy":
        cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps(
            "nccl", 10, 0, 0)
        distribution = tf.contrib.distribute.MirroredStrategy(
            num_gpus=FLAGS.num_gpus, cross_tower_ops=cross_tower_ops)
        worker_count = FLAGS.num_gpus
    elif FLAGS.distribution_strategy == "CollectiveAllReduceStrategy":
        distribution = tf.contrib.distribute.CollectiveAllReduceStrategy(
            num_gpus_per_worker=1,
            cross_tower_ops_type=FLAGS.get("cross_tower_ops_type", "paisoar"))
        worker_count = len(worker_hosts)
    else:
        cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps(
            "nccl", 10, 0, 0)
        distribution = tf.contrib.distribute.MirroredStrategy(
            num_gpus=FLAGS.num_gpus, cross_tower_ops=cross_tower_ops)

    sess_config = tf.ConfigProto(allow_soft_placement=True,
                                 log_device_placement=True)

    run_config = tf.estimator.RunConfig(keep_checkpoint_max=5,
                                        model_dir=checkpoint_dir,
                                        distribute=distribution,
                                        session_config=sess_config,
                                        save_checkpoints_secs=None,
                                        save_checkpoints_steps=None,
                                        log_step_count_steps=100)

    task_index = run_config.task_id
    is_chief = run_config.is_chief

    print("==worker_count==", worker_count, "==local_rank==", task_index,
          "==is is_chief==", is_chief)
    target = ""

    train_eval.monitored_estimator(
        FLAGS=FLAGS,
        worker_count=worker_count,
        task_index=task_index,
        cluster=cluster,
        is_chief=is_chief,
        target=target,
        init_checkpoint=init_checkpoint,
        train_file=train_file,
        dev_file=dev_file,
        checkpoint_dir=checkpoint_dir,
        run_config=run_config,
        distribution_strategy=FLAGS.distribution_strategy,
        profiler=FLAGS.profiler,
        parse_type=FLAGS.parse_type,
        rule_model=FLAGS.rule_model,
        train_op=FLAGS.train_op,
        running_type=FLAGS.running_type,
        decay=FLAGS.decay,
        warmup=FLAGS.warmup,
        input_target=FLAGS.input_target,
        distillation=FLAGS.distillation,
        temperature=FLAGS.temperature,
        distillation_ratio=FLAGS.distillation_ratio)
Example #13
class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
    def _assert_indexed_slices_equal(self, left, right):
        self.assertIsInstance(left, ops.IndexedSlices)
        self.assertIsInstance(right, ops.IndexedSlices)
        self.assertEqual(device_util.resolve(left.device),
                         device_util.resolve(right.device))
        self.assertAllEqual(self.evaluate(ops.convert_to_tensor(left)),
                            self.evaluate(ops.convert_to_tensor(right)))

    def _assert_values_equal(self, left, right):
        if isinstance(left, list):
            for l, r in zip(left, right):
                self._assert_values_equal(l, r)
        else:
            self.assertEqual(type(left), type(right))
            self.assertEqual(left.devices, right.devices)
            if isinstance(list(left._index.values())[0], ops.IndexedSlices):
                for (d, v) in left._index.items():
                    self._assert_indexed_slices_equal(v, right._index[d])
            elif context.executing_eagerly():
                self.assertEqual([v.numpy() for v in left._index.values()],
                                 list(right._index.values()))
            else:
                with self.test_session() as sess:
                    self.assertEqual(sess.run(list(left._index.values())),
                                     list(right._index.values()))

    # TODO(yuefengz): decouple the num_gpus check from distribution in
    # combinations module so that we can pass in devices instead of a distribution
    # strategy.
    reduction_to_one_combinations = combinations.combine(
        cross_tower_ops=[
            combinations.NamedObject(
                "DefaultReductionToOneDeviceCrossTowerOps",
                cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
            combinations.NamedObject(
                "ReductionToCPUDeviceCrossTowerOps",
                cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
                    reduce_to_device=_cpu_device)),
            combinations.NamedObject(
                "AccumulateNCrossTowerOp",
                cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
                    accumulation_fn=math_ops.accumulate_n)),
        ],
        distribution=[
            combinations.one_device_strategy,
            combinations.mirrored_strategy_with_gpu_and_cpu,
            combinations.mirrored_strategy_with_two_gpus
        ],
        mode=["graph", "eager"])
    allreduce_combinations = combinations.combine(
        cross_tower_ops=[
            combinations.NamedObject(
                "AllReduce",
                cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)),
            combinations.NamedObject(
                "HierarchicalCopy",
                cross_tower_ops_lib.AllReduceCrossTowerOps(
                    "hierarchical_copy", 8, 0, 0)),
            combinations.NamedObject(
                "AllReduceNoGradientRepacking",
                cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)),
            combinations.NamedObject(
                "HierarchicalCopyAggregateSmallTensors",
                cross_tower_ops_lib.AllReduceCrossTowerOps(
                    "hierarchical_copy", 0, 100, 10))
        ],
        distribution=[combinations.mirrored_strategy_with_two_gpus],
        mode=["graph", "eager"])

    @combinations.generate(reduction_to_one_combinations +
                           allreduce_combinations)
    def testReductionAndBroadcast(self, cross_tower_ops, distribution):
        devices = distribution.worker_devices

        values = [constant_op.constant(float(d)) for d in range(len(devices))]
        per_device = _make_per_device(values, devices)
        mean = (len(devices) - 1.) / 2.

        values_2 = [constant_op.constant(d + 1.0) for d in range(len(devices))]
        per_device_2 = _make_per_device(values_2, devices)
        mean_2 = mean + 1.

        destination_mirrored = _fake_mirrored(1., devices)
        destination_different = _fake_mirrored(1., _cpu_device)
        destination_str = _cpu_device
        destination_list = devices

        all_destinations = [
            None, destination_mirrored, destination_different, destination_str,
            destination_list
        ]
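        # A `None` destination falls back to the devices of the per-device
        # value being reduced, which is why the expectations below use the
        # `destinations or per_device` pattern.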

        # test reduce()
        for destinations in all_destinations:
            self._assert_values_equal(
                cross_tower_ops.reduce("mean",
                                       per_device,
                                       destinations=destinations),
                _fake_mirrored(mean, destinations or per_device))
            self._assert_values_equal(
                cross_tower_ops.reduce("mean",
                                       per_device_2,
                                       destinations=destinations),
                _fake_mirrored(mean_2, destinations or per_device))
            self._assert_values_equal(
                cross_tower_ops.reduce("sum",
                                       per_device,
                                       destinations=destinations),
                _fake_mirrored(mean * len(devices), destinations
                               or per_device))
            self._assert_values_equal(
                cross_tower_ops.reduce("sum",
                                       per_device_2,
                                       destinations=destinations),
                _fake_mirrored(mean_2 * len(devices), destinations
                               or per_device))

        # test batch_reduce()
        for d1, d2 in itertools.product(all_destinations, all_destinations):
            self._assert_values_equal(
                cross_tower_ops.batch_reduce("mean", [(per_device, d1),
                                                      (per_device_2, d2)]),
                [
                    _fake_mirrored(mean, d1 or per_device),
                    _fake_mirrored(mean_2, d2 or per_device_2)
                ])
            self._assert_values_equal(
                cross_tower_ops.batch_reduce("sum", [(per_device, d1),
                                                     (per_device_2, d2)]),
                [
                    _fake_mirrored(mean * len(devices), d1 or per_device),
                    _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)
                ])

        # test broadcast()
        for destinations in all_destinations:
            if destinations is None:
                continue
            else:
                self._assert_values_equal(
                    cross_tower_ops.broadcast(constant_op.constant(1.),
                                              destinations),
                    _fake_mirrored(1., destinations))

    def testChooseAlgorithm(self):
        device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
                        [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
        result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
        self.assertIsInstance(result,
                              cross_tower_ops_lib.AllReduceCrossTowerOps)
        self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
        self.assertEqual(result.num_packs, 8)
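        # This pins down the chooser's behavior on a DGX-1-like 8-GPU NVLink
        # topology: hierarchical_copy, with num_packs equal to the device
        # count.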

        # if there are only 4 devices
        device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]]
        result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
        self.assertIsInstance(result,
                              cross_tower_ops_lib.AllReduceCrossTowerOps)
        self.assertEqual(result.all_reduce_alg, "nccl")
        self.assertEqual(result.num_packs, 1)

        # if each device's links include the device itself
        device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6],
                        [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7],
                        [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]]
        result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
        self.assertIsInstance(result,
                              cross_tower_ops_lib.AllReduceCrossTowerOps)
        self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
        self.assertEqual(result.num_packs, 8)

        # if not dgx1-like links
        device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7],
                        [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]]
        result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
        self.assertIsInstance(result,
                              cross_tower_ops_lib.AllReduceCrossTowerOps)
        self.assertEqual(result.all_reduce_alg, "nccl")
        self.assertEqual(result.num_packs, 1)

    @combinations.generate(
        combinations.combine(mode=["graph", "eager"], required_gpus=1))
    def testSimpleReduceWithIndexedSlices(self):
        devices = ["/cpu:0", "/gpu:0"]
        t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0])
        t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2],
                                  devices[1])
        per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})
        result = cross_tower_ops_lib._simple_reduce(per_device, devices[0],
                                                    math_ops.add_n, "sum")

        # Test that the result is semantically equal to both the concatenated
        # IndexedSlices with and without duplicate indices.
        total_with_dups = _make_indexed_slices([[1., 2.], [3., 4.], [5., 6.]],
                                               [1, 1, 3], [5, 2], devices[0])
        total_without_dups = _make_indexed_slices([[4., 6.], [5., 6.]], [1, 3],
                                                  [5, 2], devices[0])
        self._assert_indexed_slices_equal(total_with_dups, result)
        self._assert_indexed_slices_equal(total_without_dups, result)
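        # Both assertions can hold simultaneously because the helper compares
        # densified tensors, and converting an IndexedSlices to a dense tensor
        # sums the values at duplicate indices.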

    @combinations.generate(
        combinations.combine(cross_tower_ops_instance=[
            combinations.NamedObject(
                "ReductionToOneDeviceCrossTowerOps",
                cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
            combinations.NamedObject(
                "AllReduceCrossTowerOps",
                cross_tower_ops_lib.AllReduceCrossTowerOps())
        ],
                             method_string=["sum", "mean"],
                             batch_reduce=[True, False],
                             mode=["graph", "eager"],
                             required_gpus=1))
    def testIndexedSlicesAllReduce(self, cross_tower_ops_instance,
                                   method_string, batch_reduce):
        devices = ["/cpu:0", "/gpu:0"]
        dense_shape = [5, 2]
        t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0])
        t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], dense_shape,
                                  devices[1])
        per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})

        if batch_reduce:
            result = cross_tower_ops_instance.batch_reduce(
                method_string, [(per_device, devices)])
        else:
            result = cross_tower_ops_instance.reduce(method_string, per_device,
                                                     devices)
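        # reduce() takes a single per-device value, while batch_reduce() takes
        # a list of (per-device value, destinations) pairs and returns a list
        # of results; that is why the expected values are wrapped in lists
        # below when batch_reduce is used.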

        total_indices_with_dups = [1, 1, 3]
        total_indices_without_dups = [1, 3]

        if method_string == "sum":
            total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]]
            total_values_without_dups = [[4., 6.], [5., 6.]]
        else:
            assert method_string == "mean"
            total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]]
            total_values_without_dups = [[2., 3.], [2.5, 3.]]

        total_mirrored_with_dups = _make_mirrored_indexed_slices(
            devices, total_values_with_dups, total_indices_with_dups,
            dense_shape)
        total_mirrored_without_dups = _make_mirrored_indexed_slices(
            devices, total_values_without_dups, total_indices_without_dups,
            dense_shape)

        # Test that the result is semantically equal to both the concatenated
        # IndexedSlices, as well as when the duplicate indices are summed up.
        if batch_reduce:
            total_mirrored_with_dups = [total_mirrored_with_dups]
            total_mirrored_without_dups = [total_mirrored_without_dups]

        self._assert_values_equal(total_mirrored_with_dups, result)
        self._assert_values_equal(total_mirrored_without_dups, result)
Example #14
def main(_):
    for key in FLAGS.__dict__['__flags'].keys():
        if key in ('h', 'help'):
            continue
        print("%s=%s" % (key, str(FLAGS.__dict__['__flags'][key])))
    print("tensorflow version:", tf.__version__)
    if FLAGS.volumes:
        volumes = FLAGS.volumes.split(",")
        train_files = volumes[0] + "/*.tfr"
        eval_files = volumes[1] + "/*.tfr" if len(volumes) > 1 else train_files
    else:
        train_files = FLAGS.train_data
        eval_files = FLAGS.eval_data
    print("train_data:", train_files)
    print("eval_data:", eval_files)
    if "TF_CONFIG" in os.environ:
        print("TF_CONFIG", json.loads(os.environ["TF_CONFIG"]))
    model_params = {
        'batch_size': FLAGS.batch_size,
        'learning_rate': FLAGS.learning_rate,
        'dropout_rate': FLAGS.dropout_rate,
        'word_cnn_filter_sizes': list(map(int, FLAGS.word_cnn_filter_sizes.split(','))),
        'word_cnn_num_filters': FLAGS.word_cnn_num_filters,
        'char_cnn_filter_sizes': list(map(int, FLAGS.char_cnn_filter_sizes.split(','))),
        'char_cnn_num_filters': FLAGS.char_cnn_num_filters,
        'word_vocab_size': FLAGS.word_vocab_size,
        'char_vocab_size': FLAGS.char_vocab_size,
        'tag_vocab_size': FLAGS.tag_vocab_size,
        'word_embedding_size': FLAGS.word_embedding_size,
        'char_embedding_size': FLAGS.char_embedding_size,
        'tag_embedding_size': FLAGS.tag_embedding_size,
        'margin': FLAGS.margin,
        'negative_margin': FLAGS.negative_margin,
        'smooth': FLAGS.smooth,
        'l2_scale': FLAGS.l2_regularizer_scale,
        't': FLAGS.t, 'negative_t': FLAGS.negative_t,
        'use_lower_loss': FLAGS.use_lower_loss,
        'hidden_units': list(map(int, FLAGS.hidden_units.split(','))),
        'activations': FLAGS.activations.split(','),
        'use_batch_norm': FLAGS.use_batch_norm,
        'use_feature': FLAGS.use_feature,
        'num_negative_samples': FLAGS.num_negative_samples,
        'warm_start': FLAGS.warm_start,
        'init_checkpoint': FLAGS.init_checkpoint
    }
    if FLAGS.use_feature:
        feature_columns = create_feature_columns()
        num_cols = len(feature_columns)
        anchor_feature_columns = [feature_columns[i] for i in range(0, num_cols, 3)]
        positive_feature_columns = [feature_columns[i] for i in range(1, num_cols, 3)]
        negative_feature_columns = [feature_columns[i] for i in range(2, num_cols, 3)]
        print("anchor_feature_columns =", anchor_feature_columns)
        print("higher_feature_columns =", positive_feature_columns)
        print("lower_feature_columns =", negative_feature_columns)
        model_params.update({
            'anchor_feature_columns': anchor_feature_columns,
            'higher_feature_columns': positive_feature_columns,
            'lower_feature_columns': negative_feature_columns
        })
    session_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_device_placement,
        intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads,
        inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
        gpu_options=tf.GPUOptions(allow_growth=True, force_gpu_compatible=True)
    )
    if FLAGS.evaluate:
        cluster = {'chief': ['localhost:2221'], 'worker': ['localhost:2222']}
        os.environ['TF_CONFIG'] = json.dumps({'cluster': cluster, 'task': {'type': 'evaluator', 'index': 0}})
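        # Presumably, pointing TF_CONFIG at an 'evaluator' task makes
        # tf.estimator treat this process as a dedicated evaluator, so this
        # branch only evaluates checkpoints produced by the training job.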
        config = tf.estimator.RunConfig(session_config=session_config)
        model = SemanticModel(params=model_params, optimizer=FLAGS.optimizer, model_dir=FLAGS.checkpointDir, config=config)
        evaluation_listener(model, train_files, eval_files)
        return
    else:
        cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps('nccl', 16, 0, 0)
        distribution = tf.contrib.distribute.MirroredStrategy(num_gpus=FLAGS.num_gpus, cross_tower_ops=cross_tower_ops)
        config = tf.estimator.RunConfig(
            train_distribute=distribution,
            save_checkpoints_secs=FLAGS.save_checkpoints_secs,
            keep_checkpoint_max=FLAGS.keep_checkpoint_max, session_config=session_config,
            log_step_count_steps=FLAGS.log_step_count_steps,
            save_summary_steps=FLAGS.save_summary_steps)
    if FLAGS.tf_random_seed != 0:
        # RunConfig.replace returns a new RunConfig instead of mutating in
        # place, so the result must be reassigned for the seed to take effect.
        config = config.replace(tf_random_seed=FLAGS.tf_random_seed)

    # if FLAGS.init_embedding and FLAGS.tables:
    #     tables = FLAGS.tables.split(',')
    #     word_dict = tables[0]
    #     word_embed = tables[1]
    #     word_initializer = build_embedding_initializer(word_dict, word_embed, FLAGS.word_vocab_size, FLAGS.word_embedding_size)
    #     model_params['word_initializer'] = word_initializer
    #     if len(tables) > 2:
    #         char_dict = tables[2]
    #         char_embed = tables[3]
    #         char_initializer = build_embedding_initializer(char_dict, char_embed, FLAGS.char_vocab_size, FLAGS.char_embedding_size)
    #         model_params['char_initializer'] = char_initializer

    model = SemanticModel(params=model_params, optimizer=FLAGS.optimizer, model_dir=FLAGS.checkpointDir, config=config)

    train_hooks = [tf.train.ProfilerHook(save_secs=FLAGS.save_checkpoints_secs, output_dir=FLAGS.buckets)] if FLAGS.profile else []
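    # When profiling is enabled, tf.train.ProfilerHook periodically writes
    # Chrome-trace timeline files (viewable in chrome://tracing) to the given
    # output_dir.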

    logging.info("before train")
    model.train(lambda: input_fn(train_files, True), max_steps=FLAGS.train_steps, hooks=train_hooks)
    logging.info("after train")

    if config.is_chief:
        logging.info("exporting model ...")
        serving_input_receiver_fn = get_serving_input_fn()
        model.export_savedmodel(FLAGS.outputs, serving_input_receiver_fn)
        logging.info("print model variables ...")
        model_statistics(model)
    logging.info("finish main")
class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
    def _assert_value_equal(self, left, right):
        if isinstance(left, list):
            for l, r in zip(left, right):
                self._assert_value_equal(l, r)
        else:
            self.assertEqual(type(left), type(right))
            self.assertEqual(left.devices, right.devices)
            if context.executing_eagerly():
                self.assertEqual([v.numpy() for v in left._index.values()],
                                 list(right._index.values()))
            else:
                with self.test_session() as sess:
                    self.assertEqual(sess.run(list(left._index.values())),
                                     list(right._index.values()))

    # TODO (yuefengz): decouple the num_gpus check from distribution in
    # combinations module so that we can pass in devices instead of a
    # distribution strategy. id:1093
    # https://github.com/imdone/tensorflow/issues/1094
    reduction_to_one_combinations = combinations.combine(
        cross_tower_ops=[
            combinations.NamedObject(
                "DefaultReductionToOneDeviceCrossTowerOps",
                cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
            combinations.NamedObject(
                "ReductionToCPUDeviceCrossTowerOps",
                cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
                    reduce_to_device=_cpu_device)),
            combinations.NamedObject(
                "AccumulateNCrossTowerOp",
                cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
                    accumulation_fn=math_ops.accumulate_n)),
        ],
        distribution=[
            combinations.one_device_strategy,
            combinations.mirrored_strategy_with_gpu_and_cpu,
            combinations.mirrored_strategy_with_two_gpus
        ],
        mode=["graph", "eager"])
    allreduce_combinations = combinations.combine(
        cross_tower_ops=[
            combinations.NamedObject(
                "AllReduce",
                cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)),
            combinations.NamedObject(
                "HierarchicalCopy",
                cross_tower_ops_lib.AllReduceCrossTowerOps(
                    "hierarchical_copy", 8, 0, 0)),
            combinations.NamedObject(
                "AllReduceNoGradientRepacking",
                cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)),
            combinations.NamedObject(
                "HierarchicalCopyAggregateSmallTensors",
                cross_tower_ops_lib.AllReduceCrossTowerOps(
                    "hierarchical_copy", 0, 100, 10))
        ],
        distribution=[combinations.mirrored_strategy_with_two_gpus],
        mode=["graph", "eager"])

    @combinations.generate(reduction_to_one_combinations +
                           allreduce_combinations)
    def testReductionAndBroadcast(self, cross_tower_ops, distribution):
        devices = distribution.worker_devices

        values = [constant_op.constant(float(d)) for d in range(len(devices))]
        per_device = _make_per_device(values, devices)
        mean = (len(devices) - 1.) / 2.

        values_2 = [constant_op.constant(d + 1.0) for d in range(len(devices))]
        per_device_2 = _make_per_device(values_2, devices)
        mean_2 = mean + 1.

        destination_mirrored = _fake_mirrored(1., devices)
        destination_different = _fake_mirrored(1., _cpu_device)
        destination_str = _cpu_device
        destination_list = devices

        all_destinations = [
            None, destination_mirrored, destination_different, destination_str,
            destination_list
        ]

        # test reduce()
        for destinations in all_destinations:
            self._assert_value_equal(
                cross_tower_ops.reduce("mean",
                                       per_device,
                                       destinations=destinations),
                _fake_mirrored(mean, destinations or per_device))
            self._assert_value_equal(
                cross_tower_ops.reduce("mean",
                                       per_device_2,
                                       destinations=destinations),
                _fake_mirrored(mean_2, destinations or per_device))
            self._assert_value_equal(
                cross_tower_ops.reduce("sum",
                                       per_device,
                                       destinations=destinations),
                _fake_mirrored(mean * len(devices), destinations
                               or per_device))
            self._assert_value_equal(
                cross_tower_ops.reduce("sum",
                                       per_device_2,
                                       destinations=destinations),
                _fake_mirrored(mean_2 * len(devices), destinations
                               or per_device))

        # test batch_reduce()
        for d1, d2 in itertools.product(all_destinations, all_destinations):
            self._assert_value_equal(
                cross_tower_ops.batch_reduce("mean", [(per_device, d1),
                                                      (per_device_2, d2)]),
                [
                    _fake_mirrored(mean, d1 or per_device),
                    _fake_mirrored(mean_2, d2 or per_device_2)
                ])
            self._assert_value_equal(
                cross_tower_ops.batch_reduce("sum", [(per_device, d1),
                                                     (per_device_2, d2)]),
                [
                    _fake_mirrored(mean * len(devices), d1 or per_device),
                    _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)
                ])

        # test broadcast()
        for destinations in all_destinations:
            if destinations is None:
                continue
            else:
                self._assert_value_equal(
                    cross_tower_ops.broadcast(constant_op.constant(1.),
                                              destinations),
                    _fake_mirrored(1., destinations))

    def testChooseAlgorithm(self):
        device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
                        [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
        result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
        self.assertIsInstance(result,
                              cross_tower_ops_lib.AllReduceCrossTowerOps)
        self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
        self.assertEqual(result.num_packs, 8)

        # if there are only 4 devices
        device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]]
        result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
        self.assertIsInstance(result,
                              cross_tower_ops_lib.AllReduceCrossTowerOps)
        self.assertEqual(result.all_reduce_alg, "nccl")
        self.assertEqual(result.num_packs, 1)

        # if each device's links include the device itself
        device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6],
                        [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7],
                        [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]]
        result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
        self.assertIsInstance(result,
                              cross_tower_ops_lib.AllReduceCrossTowerOps)
        self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
        self.assertEqual(result.num_packs, 8)

        # if not dgx1-like links
        device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7],
                        [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]]
        result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
        self.assertIsInstance(result,
                              cross_tower_ops_lib.AllReduceCrossTowerOps)
        self.assertEqual(result.all_reduce_alg, "nccl")
        self.assertEqual(result.num_packs, 1)