Example #1
 def computation_fn():
     graph = mtf.Graph()
     mesh = mtf.Mesh(graph, 'my_mesh')
     mesh_shape = mtf.convert_to_shape('all:2')
     layout = 'none:all'
     mesh_devices = [''] * mesh_shape.size
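     # NOTE: `device_assignment` below is assumed to come from the enclosing
     # TPU test/estimator setup; it is not defined inside this function.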
     mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
         mesh_shape, mtf.convert_to_layout_rules(layout),
         mesh_devices, device_assignment)
     hidden_dim = mtf.Dimension('hidden', 3)
     w = mtf.get_variable(mesh,
                          'w',
                          shape=[hidden_dim],
                          initializer=tf.constant_initializer(
                              [0.1, -0.2, -0.1]))
     x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim],
                      dtype=tf.float32)
     loss = mtf.reduce_mean(mtf.square(x - w))
     var_grads = mtf.gradients(
         [loss], [v.outputs[0] for v in graph.trainable_variables])
     optimizer = mtf_optimize.AdamWeightDecayOptimizer(
         learning_rate=0.2)
     update_ops = optimizer.apply_grads(var_grads,
                                        graph.trainable_variables)
     self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})
     tf_update_ops = [
         self.lowering.lowered_operation(op) for op in update_ops
     ]
     return tf.group(tf_update_ops)
Example #2
    def testRecomputeGrad(self):
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")

        # let's differentiate x^2 + x
        # dy/dx = 2x+1
        def x_squared_plus_x(x):
            return x * x + x

        x = tf.constant([5, 10], dtype=tf.float32)
        dy = tf.constant([2, 3], dtype=tf.float32)
        two = mtf.Dimension("two", 2)
        expected_y = tf.constant([30, 110], dtype=tf.float32)
        expected_dx = tf.constant([22, 63], dtype=tf.float32)
        mtf_x = mtf.import_fully_replicated(mesh, x, shape=mtf.Shape([two]))
        mtf_dy = mtf.import_tf_tensor(mesh, dy, shape=mtf.Shape([two]))
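        # recompute_grad re-runs the forward function during the backward pass
        # instead of storing its intermediate activations (gradient checkpointing).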
        mtf_y = mtf.recompute_grad(x_squared_plus_x, [mtf_x])
        [mtf_dx] = mtf.gradients([mtf_y], [mtf_x], [mtf_dy])
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            shape="processors:2", layout="two:processors", devices=["", ""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_y = lowering.export_to_tf_tensor(mtf_y)
        actual_dx = lowering.export_to_tf_tensor(mtf_dx)
        self.assertAllEqual(self.evaluate(actual_y), self.evaluate(expected_y))
        self.assertAllEqual(self.evaluate(actual_dx),
                            self.evaluate(expected_dx))
Example #3
def model_fn(features, labels, mode, params):
	"""The model_fn argument for creating an Estimator."""
	global_step = tf.train.get_global_step()
	graph = mtf.Graph()
	mesh = mtf.Mesh(graph, "my_mesh")
	logits, loss = model_backbone(features, labels, mesh)
	
	variables = graph._all_variables
	for v in variables:
		logger.debug("[parameter] (name,shape,dtype): ({},{},{})".format(v.name,v.shape,v.dtype.master_dtype))
	mesh_shape = mtf.convert_to_shape(args_opt.mesh_shape)
	# layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])
	# Search for a layout (tensor dimension -> mesh axis assignment) that
	# minimizes the estimated memory, using a naive scheduler.
	estimator = memory_estimator.MemoryEstimator(graph, mesh_shape, [logits, loss])
	optimizer = layout_optimizer.LayoutOptimizer(estimator, scheduler_alg="NAIVE")
	layout_rules = mtf.convert_to_layout_rules(optimizer.solve())

	# layout_rules=[('batch', 'b1')]

	logger.info("[auto mtf search] strategy: {}".format(layout_rules))
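	# Place one mesh processor on each visible GPU.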
	mesh_devices = ["gpu:{}".format(i) for i in range(int(args_opt.num_gpus))]
	mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(mesh_shape, layout_rules, mesh_devices)



	if mode == tf.estimator.ModeKeys.TRAIN:
		var_grads = mtf.gradients(
			[loss], [v.outputs[0] for v in graph.trainable_variables])
		optimizer = mtf.optimize.SgdOptimizer(0.01)
		# optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
		update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

	lowering = mtf.Lowering(graph, {mesh: mesh_impl})
	restore_hook = mtf.MtfRestoreHook(lowering)

	tf_logits = lowering.export_to_tf_tensor(logits)
	if mode != tf.estimator.ModeKeys.PREDICT:
		tf_loss = lowering.export_to_tf_tensor(loss)

	if mode == tf.estimator.ModeKeys.TRAIN:
		tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
		tf_update_ops.append(tf.assign_add(global_step, 1))
		train_op = tf.group(tf_update_ops)

		accuracy = tf.metrics.accuracy(
			labels=labels, predictions=tf.argmax(tf_logits, axis=1))

		# Name tensors to be logged with LoggingTensorHook.
		tf.identity(tf_loss, "cross_entropy")
		tf.identity(accuracy[1], name="train_accuracy")

		logging_hook = tf.train.LoggingTensorHook(every_n_iter=100,tensors={'loss': 'cross_entropy','acc':'train_accuracy'})
		# profiling_hook = tf.estimator.ProfilerHook(save_steps=20, output_dir='./profiling/')
		# restore_hook must come before saver_hook
		return tf.estimator.EstimatorSpec(
			tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
			training_chief_hooks=[restore_hook,logging_hook])
Example #4
def test_entmax():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    length = mtf.Dimension("tensor_length", 8)
    tensor = mtf.range(mesh, length, tf.float32)
    output = entmax(tensor)
    grad = mtf.gradients([output], [tensor])[0]
    sample = sample_categorical(output, length)

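    # An empty mesh shape and layout run everything unsplit on a single device.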
    mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    sample = lowering.export_to_tf_tensor(sample)
    grad = lowering.export_to_tf_tensor(grad)
Example #5
  def testTopK(self):
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    a_dim = mtf.Dimension("a", 6)
    b_dim = mtf.Dimension("b", 2)
    inputs = tf.constant([[1, 10],
                          [2, 9],
                          [3, 8],
                          [4, 7],
                          [5, 6],
                          [6, 5]],
                         dtype=tf.float32)
    k_dim = mtf.Dimension("k", 2)
    d_values = tf.constant([[11, 12], [13, 14]], dtype=tf.float32)
    reduced_dim = a_dim
    expected_values = tf.constant([[6, 5], [10, 9]], dtype=tf.float32)
    expected_indices = tf.constant([[5, 4], [0, 1]])
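    # The gradient of top_k scatters d_values back onto the positions of the
    # selected (top-k) elements of the input; all other positions get zero.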
    expected_d_inputs = tf.constant([[0, 13],
                                     [0, 14],
                                     [0, 0],
                                     [0, 0],
                                     [12, 0],
                                     [11, 0]],
                                    dtype=tf.float32)

    mtf_inputs = mtf.import_fully_replicated(
        mesh, inputs, shape=mtf.Shape([a_dim, b_dim]))
    mtf_d_values = mtf.import_tf_tensor(
        mesh, d_values, shape=mtf.Shape([b_dim, k_dim]))
    mtf_values, mtf_indices = mtf.top_k(mtf_inputs,
                                        reduced_dim=reduced_dim,
                                        k_dim=k_dim,
                                        name="test_nth_smallest")
    [mtf_d_inputs] = mtf.gradients([mtf_values], [mtf_inputs], [mtf_d_values])
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape="rows:2,cols:2", layout="a:rows,b:cols", devices=["", "", "", ""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_values = lowering.export_to_tf_tensor(mtf_values)
    actual_indices = lowering.export_to_tf_tensor(mtf_indices)
    actual_d_inputs = lowering.export_to_tf_tensor(mtf_d_inputs)
    actual_inputs = lowering.export_to_tf_tensor(mtf_inputs)
    self.assertAllEqual(self.evaluate(actual_inputs),
                        self.evaluate(inputs))
    self.assertAllEqual(self.evaluate(actual_values),
                        self.evaluate(expected_values))
    self.assertAllEqual(self.evaluate(actual_indices),
                        self.evaluate(expected_indices))
    self.assertAllEqual(self.evaluate(actual_d_inputs),
                        self.evaluate(expected_d_inputs))
Example #6
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.

    Args:
      features: input features dictionary
      labels: ignored
      mode: a tf.estimator.ModeKeys
      params: something
      config: something

    Returns:
      something
    """
        del labels, config
        global_step = tf.train.get_global_step()
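        # mesh_shape, layout_rules, batch_size, sequence_length and similar
        # names are assumed to be closed over from the enclosing function that
        # builds this model_fn.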
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            physical_shape = list(
                params["context"].device_assignment.topology.mesh_shape)
            logical_to_physical = _logical_to_physical(physical_shape,
                                                       mesh_shape)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape,
                layout_rules,
                mesh_devices,
                ctx.device_assignment,
                logical_to_physical=logical_to_physical)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        def _import_feature(key, allow_missing=False):
            """Import a feature from the features dictionary into a mtf.Tensor.

      Args:
        key: a string
        allow_missing: a boolean

      Returns:
        a mtf.Tensor with dtype int32 and shape [batch_dim, length_dim]
      """
            outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
            batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
            length_dim = mtf.Dimension("length", sequence_length)

            mtf_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])
            if key not in features:
                if allow_missing:
                    return None
                else:
                    raise ValueError("feature not found %s - features %s = " %
                                     (key, features))
            tf.logging.info("Import feature %s: %s" % (key, features[key]))

            x = tf.to_int32(features[key])
            x = tf.reshape(
                x, [outer_batch_size, batch_size // outer_batch_size, -1])

            if not use_tpu:
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=1)
            return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

        if mode == tf.estimator.ModeKeys.PREDICT:
            inputs = _import_feature("inputs")
            inputs = mtf.reshape(
                inputs,
                mtf.Shape([
                    mtf.Dimension("batch", batch_size),
                    mtf.Dimension("length", sequence_length)
                ]))
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Bitransformer):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"outputs": outputs}
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        targets = _import_feature("targets")
        anon_targets = mtf.anonymize(targets)
        if model_type == "lm":
            _, length_dim = targets.shape
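            # For a language model, the inputs are the targets shifted right by
            # one position along the length dimension.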
            inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
        else:
            inputs = _import_feature("inputs")

        if mode == tf.estimator.ModeKeys.EVAL:
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Bitransformer):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            labels = lowering.export_to_tf_tensor(anon_targets)
            restore_hook = mtf.MtfRestoreHook(lowering)

            # metric_names becomes locally scoped if we simply assign
            # ["padded_neg_log_perplexity"] to it conditioned on if it's None.
            local_metric_names = metric_names or ["token_accuracy"]

            def metric_fn(labels, outputs):
                return get_metric_fns(local_metric_names, labels, outputs)

            eval_metrics = (metric_fn, [labels, outputs])
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                # Unfortunately TPUEstimatorSpec requires us to provide a value for
                # loss when in EVAL mode. Since we are sampling or decoding from the
                # model, we don't have a loss to report.
                loss=tf.constant(0.),
                evaluation_hooks=[restore_hook],
                eval_metrics=eval_metrics)

        if isinstance(transformer_model, transformer.Unitransformer):
            position_kwargs = dict(
                sequence_id=_import_feature("targets_segmentation", True),
                position=_import_feature("targets_position", True),
            )
        elif isinstance(transformer_model, transformer.Bitransformer):
            position_kwargs = dict(
                encoder_sequence_id=_import_feature("inputs_segmentation",
                                                    True),
                decoder_sequence_id=_import_feature("targets_segmentation",
                                                    True),
                encoder_position=_import_feature("inputs_position", True),
                decoder_position=_import_feature("targets_position", True),
            )
        else:
            raise ValueError("unrecognized class")

        logits, loss = transformer_model.call_simple(
            inputs=inputs,
            targets=targets,
            compute_loss=True,
            mode=mode,
            variable_dtype=get_variable_dtype(),
            **position_kwargs)

        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            optimizer = mtf.optimize.AdafactorOptimizer(
                learning_rate=learning_rate)
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if not use_tpu:
            tf_loss = tf.Print(tf_loss,
                               [tf_loss, tf.train.get_global_step()],
                               "step, tf_loss")

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=checkpoints_to_keep,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                model_dir,
                save_steps=save_steps,
                saver=saver,
                listeners=[saver_listener])
            gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                model_dir, summarize_config=True)

            if mode == tf.estimator.ModeKeys.TRAIN:
                if use_tpu:
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
Example #7
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator."""

    #tf.logging.info("features = %s labels = %s mode = %s params=%s" %
    #              (features, labels, mode, params))

    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    data = features['data']
    R0 = features['R0'] * 1.
    x0 = features['x0']
    print("\nR0 in the model function : %0.1f\n" % R0)
    fields, metrics, kv = recon_model(mesh, data, R0, x0)
    fieldvar, final_field = fields
    chisq, prior, loss = metrics

    ##
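    # Map the spatial tensor dimensions (nx, ny, ty, tz, ...) onto a 2-D
    # "row" x "col" mesh of processors.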
    layout_rules = [("nx_lr", "row"), ("ny_lr", "col"), ("nx", "row"),
                    ("ny", "col"), ("ty", "row"), ("tz", "col"),
                    ("ty_lr", "row"), ("tz_lr", "col"), ("nx_block", "row"),
                    ("ny_block", "col")]

    mesh_shape = [("row", FLAGS.nx), ("col", FLAGS.ny)]

    mesh_size = FLAGS.nx * FLAGS.ny  # mesh_shape.size
    mesh_devices = [""] * mesh_size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    #    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    #    mesh_size = mesh_shape.size
    #    layout_rules = [("nx_lr", "row"), ("ny_lr", "col"),
    #                    ("nx", "row"), ("ny", "col"),
    #                    ("ty", "row"), ("tz", "col"),
    #                    ("ty_lr", "row"), ("tz_lr", "col"),
    #                    ("nx_block","row"), ("ny_block","col")]
    #    devices = ["/job:localhost/replica:0/task:%d/device:GPU:%d"%(i,j) for i in range(0) for j in range(FLAGS.gpus_per_task)]
    #    devices = [""] * mesh_size
    #    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(mesh_shape, layout_rules, devices)
    print(mesh_impl)

    ##
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])

        #        nyq = np.pi*nc/bs
        #        def _cwise_highpass(kfield, kx, ky, kz):
        #            kx = tf.reshape(kx, [-1, 1, 1])
        #            ky = tf.reshape(ky, [1, -1, 1])
        #            kz = tf.reshape(kz, [1, 1, -1])
        #            kk = (kx / bs * nc)**2 + (ky/ bs * nc)**2 + (kz/ bs * nc)**2
        #            wts = tf.cast(tf.exp(- kk* (R0*bs/nc + 1/nyq)**2), kfield.dtype)
        #            return kfield * (1-wts)
        #
        #        k_dims_pr = [d.shape[0] for d in kv]
        #        k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
        #        cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=tf.complex64)
        #        cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv, output_dtype=tf.complex64)
        #        var_grads = [mesh_utils.c2r3d(cgrads, var_grads[0].shape[-3:], dtype=tf.float32)]
        #        update_ops = [mtf.assign(fieldvar, fieldvar - var_grads[0]*0.2)]

        #optimizer = mtf.optimize.AdafactorOptimizer(1)
        #optimizer = mtf.optimize.SgdOptimizer(0.01)
        #optimizer = mtf.optimize.MomentumOptimizer(0.01, 0.001)
        optimizer = mtf.optimize.AdamWeightDecayOptimizer(features['lr'])
        update_ops = optimizer.apply_grads(var_grads,
                                           graph.trainable_variables)

#

    start = time.time()
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    print("\nTime taken for lowering is : %0.3f" % (time.time() - start))
    restore_hook = mtf.MtfRestoreHook(lowering)
    #
    tf_init = lowering.export_to_tf_tensor(fieldvar)
    tf_data = lowering.export_to_tf_tensor(final_field)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_chisq = lowering.export_to_tf_tensor(chisq)
    tf_prior = lowering.export_to_tf_tensor(prior)

    ##Predict
    if mode == tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)
        tf.summary.scalar("loss", tf_loss)
        tf.summary.scalar("chisq", tf_chisq)
        tf.summary.scalar("prior", tf_prior)
        predictions = {
            "ic": tf_init,
            "data": tf_data,
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                "data": tf.estimator.export.PredictOutput(
                    predictions)  #TODO: is classify a keyword?
            })

    ##Train
    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)
        saver = tf.train.Saver(tf.global_variables(),
                               sharded=True,
                               max_to_keep=1,
                               keep_checkpoint_every_n_hours=2,
                               defer_build=False,
                               save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(fpath,
                                                  save_steps=1000,
                                                  saver=saver,
                                                  listeners=[saver_listener])

        logging_hook = tf.train.LoggingTensorHook(
            {
                "loss": tf_loss,
                "chisq": tf_chisq,
                "prior": tf_prior
            },
            every_n_iter=10)

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "loss")
        tf.identity(tf_prior, "prior")
        tf.identity(tf_chisq, "chisq")

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar("loss", tf_loss)
        tf.summary.scalar("chisq", tf_chisq)
        tf.summary.scalar("prior", tf_prior)

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook, logging_hook])

    ##Eval
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={
                # eval_metric_ops entries must be (value, update_op) pairs.
                "loss": tf.metrics.mean(tf_loss),
                "chisq": tf.metrics.mean(tf_chisq),
                #tf.metrics.accuracy(
                #    labels=labels, predictions=tf.argmax(tf_logits, axis=1)),
            })
Example #8
    def __call__(self, features, labels, mode, params):  # this is the model_fn
        """A model is called by TpuEstimator."""
        del labels
        global_step = tf.train.get_global_step()

        # Graph setup
        graph = mtf.Graph()
        mesh_shape = mtf.convert_to_shape(self.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(self.layout)
        if params["use_tpu"]:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # Worker 0 caches all the TPU binaries.
            replica_cache_size = 300 * 1024 * 1024  # 300M per replica.
            worker0_mem = replica_cache_size * 8 * num_hosts
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh = mtf.Mesh(graph, "my_mesh", var_placer)
            mesh_devices = [""] * mesh_shape.size

            # The fourth argument to SimdMeshImpl is the TPU device assignment.
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

            mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        # RUN Model
        with mtf.utils.outside_all_rewrites():
            logits, loss = self.model(mesh, features, params)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            if self.optimizer == "Adafactor":
                optimizer = mtf.optimize.AdafactorOptimizer()
            else:
                assert self.optimizer == "SGD"
                optimizer = mtf.optimize.SgdOptimizer(
                    learning_rate=self.learning_rate)
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)
        else:
            # for now, we can only export fully-replicated tensors.
            fully_replicated_logits = mtf.anonymize(logits)

        # covert back to tensorflow format
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))
        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)
        else:
            tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

        # create estimator
        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            if mode == tf.estimator.ModeKeys.TRAIN:
                saver = tf.train.Saver(
                    tf.global_variables(),
                    sharded=True,
                    max_to_keep=10,
                    keep_checkpoint_every_n_hours=2,
                    defer_build=False,
                    save_relative_paths=True,
                )
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    self.model_dir,
                    save_steps=1000,
                    saver=saver,
                    listeners=[saver_listener],
                )

                return tpu_estimator.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.TRAIN,
                    loss=tf_loss,
                    train_op=train_op,
                    training_hooks=[restore_hook, saver_hook],
                )
            elif mode == tf.estimator.ModeKeys.EVAL:

                def metric_fn(tf_logits):
                    mean_logits = tf.metrics.mean(tf_logits)
                    return {"mean_logits": mean_logits}

                eval_metrics = (metric_fn, [tf_logits])
                return tpu_estimator.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.EVAL,
                    evaluation_hooks=[restore_hook],
                    loss=tf_loss,
                    eval_metrics=eval_metrics,
                )
            elif mode == tf.estimator.ModeKeys.PREDICT:
                # Export the fully-replicated logits as the prediction output
                # (assumed to be the intended predictions for this model).
                return tpu_estimator.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.PREDICT,
                    prediction_hooks=[restore_hook],
                    predictions={"logits": tf_logits},
                )

    @property
    def dense_initializer(self):
        if self.config.initializer_range:
            return tf.truncated_normal_initializer(
                stddev=self.config.initializer_range)
        else:
            return mtf.layers.VarianceScalingInitializer(scale=0.4)

    @property
    def embedding_initializer(self):
        initializer = self.dense_initializer
        if isinstance(initializer, mtf.layers.DenseInitializer):
            # embedding matrix is also used as classifier weight matrix.
            # scale it appropriately.
            return initializer(reduced_dims=[self.model_dim],
                               new_dims=[self.vocab_dim])
        else:
            return initializer

    @property
    def num_hidden_layers(self):
        return self.config.num_hidden_layers

    def normalize(self, x, reduce_dim):
        return nn.layer_norm(
            x,
            reduce_dim,
            subtract_mean=self.config.use_bias,
            use_bias=self.config.use_bias,
        )

    def model(self, mesh, x, y, params):
        # x :: [batch, io, vocab]

        if params["precision"] == "bfloat16":
            dtype = tf.bfloat16
            # master has type float32, slice and activation have type bfloat16
            variable_dtype = mtf.VariableDType(tf.float32, tf.bfloat16,
                                               tf.bfloat16)
        else:
            dtype = tf.float32
            # master, slice and activation all have type float32
            variable_dtype = mtf.VariableDType(tf.float32, tf.float32,
                                               tf.float32)

        # Build the actual model
        batch_dim = mtf.Dimension("batch", params["batch_size"])
        vocab_dim = mtf.Dimension("vocab", params["vocab_size"])
        io_dim = mtf.Dimension("sequence", params["io"])
        io_chan_dim = mtf.Dimension("io", params["io_channels"])

        # from input to mtf
        x = mtf.import_tf_tensor(mesh, x,
                                 mtf.Shape([batch_dim, io_dim, vocab_dim]))

        # Embeddings
        with tf.variable_scope("toy", default_name="seq2seq"):
            with tf.variable_scope("embeddings"):
                # Perform embedding lookup on the word ids.
                embedding_table = mtf.get_variable(
                    mesh,
                    "word_embeddings",
                    mtf.Shape([vocab_dim, io_chan_dim]),
                    initializer=self.embedding_initializer,
                )

                word_embedding_output = mtf.gather(
                    embedding_table,
                    x,
                    dim=vocab_dim,
                    output_shape=io_chan_dim)

                # Add positional embeddings and token type embeddings, then layer
                # normalize and perform dropout.
                embedding_output = word_embedding_output

                pos_embedding = mtf.get_variable(
                    mesh,
                    "pos_embeddings",
                    mtf.Shape([io_dim, io_chan_dim]),
                    initializer=self.embedding_initializer,
                )
                embedding_output = self.normalize(embedding_output)
                embedding_output = mtf.dropout(
                    embedding_output,
                    keep_prob=1.0 - self.config.layer_output_dropout_prob,
                )

            # shift token by pos embeddings
            x = word_embedding_output + pos_embedding
            x = mtf.cast(x, variable_dtype.activation_dtype)

            h = x
            for lnum in range(1, self.num_hidden_layers + 2):
                if lnum + 1 == self.num_hidden_layers + 2:
                    # output layer
                    dim = io_dim
                elif lnum % 2 == 0:
                    dim = mtf.Dimension("hidden_even", io_chan_dim)
                else:
                    dim = mtf.Dimension("hidden_odd", io_chan_dim)
                # dense is applied on every iteration of the loop
                h = mtf.layers.dense(
                    h,
                    dim,
                    use_bias=False,
                    master_dtype=variable_dtype.master_dtype,
                    slice_dtype=variable_dtype.slice_dtype,
                    name="layer_%d" % lnum,
                )

            prediction = h
            # project back to token dimensions

            # compute the mean square loss between the input and the output
            loss = mtf.reduce_mean(mtf.square(y - prediction))
            return prediction, loss
Example #9
  def _model_fn(input_fea, input_lab):
    """Creates a model, add summary, modes (train or eval), and hooks."""

    # input_fea and input_lab should be a list (laid_out_tensors).
    if not isinstance(input_fea, list):
      input_fea = [input_fea]
    if not isinstance(input_lab, list):
      input_lab = [input_lab]

    def _add_summary(lowering, train_or_eval, tf_loss, scalars, global_step):
      """Add all summaries."""
      for k in scalars.keys():
        if not isinstance(scalars[k], tf.Tensor):
          scalars[k] = tf.cast(
              lowering.export_to_tf_tensor(scalars[k]), tf.float32)

      def _host_loss_summary(global_step, tf_loss, **scalars):
        """Add summary.scalar in host side."""
        gs = tf.cast(global_step, tf.int64)
        sum_loss = contrib_summary.scalar(
            '{}_loss'.format(train_or_eval), tf_loss, step=gs)
        sum_ops = [sum_loss.op]
        for description, tf_metric in scalars.items():
          sum_metric = contrib_summary.scalar(
              '{}_{}'.format(train_or_eval, description), tf_metric, step=gs)
          sum_ops.append(sum_metric)
        with tf.control_dependencies(sum_ops):
          return tf.identity(tf_loss)

      if FLAGS.use_tpu:
        # Cast the global step to tf.int32, since
        # outside_compilation does not support tf.int64.
        tf_loss = tpu.outside_compilation(
            _host_loss_summary,
            tf.cast(global_step, tf.int32),
            tf_loss,
            **scalars)
      else:
        tf_loss = _host_loss_summary(
            tf.cast(global_step, tf.int32),
            tf_loss,
            **scalars)

      return tf_loss

    global_step = tf.train.get_or_create_global_step()
    graph, mesh, mesh_impl = mesh_context.create_graph_mesh_and_mesh_impl()

    with mtf.utils.outside_all_rewrites():
      # Do not tpu_rewrite this part. Inside this unet, if you use TensorFlow
      # ops instead of Mesh TensorFlow ops, it will cause host-to-TPU send/recv.
      preds, loss, scalars, bn_update_ops = (
          unet.unet_with_spatial_partition(
              mesh, mesh_impl, train_or_eval, input_fea, input_lab))

    if train_or_eval == 'train':
      var_grads = mtf.gradients(
          [loss], [v.outputs[0] for v in graph.trainable_variables])

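      # Step schedule: scale the learning rate by FLAGS.lr_drop_rate every
      # FLAGS.lr_drop_steps global steps.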
      lr = FLAGS.lr * tf.pow(
          FLAGS.lr_drop_rate,
          tf.floor(tf.cast(global_step, tf.float32) / FLAGS.lr_drop_steps))
      scalars['learning_rate'] = lr

      optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=lr)
      update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

      # This is where the actual tf graph got built.
      lowering = mtf.Lowering(graph, {mesh: mesh_impl})

      tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
      tf_update_ops.append(tf.assign_add(global_step, 1))
      tf_update_ops.extend(
          [lowering.lowered_operation(op) for op in bn_update_ops])

    else:  # train_or_eval == 'eval':
      preds = [mtf.anonymize(pred) for pred in preds]

      # This is where the actual tf graph got built.
      lowering = mtf.Lowering(graph, {mesh: mesh_impl})

      tf_preds = [tf.cast(
          lowering.export_to_tf_tensor(pred), tf.float32) for pred in preds]

    tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
    if FLAGS.write_summary:
      tf_loss = _add_summary(
          lowering, train_or_eval, tf_loss, scalars, global_step)
    master_to_slice_hook = mtf.MtfRestoreHook(lowering)

    if train_or_eval == 'train':
      with mtf.utils.outside_all_rewrites():
        saver = tf.train.Saver(tf.global_variables(),
                               save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        slice_to_master_hook = tf.train.CheckpointSaverHook(
            FLAGS.checkpoint_dir,
            save_steps=FLAGS.save_checkpoints_steps,
            saver=saver, listeners=[saver_listener])
        captured_hooks.capture([master_to_slice_hook, slice_to_master_hook])
        return tf.group([tf_loss] + tf_update_ops)

    else:  # train_or_eval == 'eval':
      if FLAGS.use_tpu:
        tf_preds.extend([tf_loss, global_step])
        tf_preds_dtypes = [tf_pred.dtype for tf_pred in tf_preds]
        tf_preds_shapes = [tf_pred.shape for tf_pred in tf_preds]
        captured_hooks.capture([master_to_slice_hook, None])
        captured_output_dtypes_shapes.capture(
            [tf_preds_dtypes, tf_preds_shapes])
        return tpu_ops.outfeed_enqueue_tuple(tf_preds)

      else:
        tf_preds.extend([tf_loss, global_step])
        captured_hooks.capture([master_to_slice_hook, None])
        return tf_preds
Example #10
  def estimator_model_fn(cls,
                         hparams,
                         features,
                         labels,
                         mode,
                         config=None,
                         params=None,
                         decode_hparams=None,
                         use_tpu=False):
    hparams = copy.deepcopy(hparams)
    hparams.use_tpu = use_tpu
    # merge decode_hparams into hparams if present
    if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
      for k, v in six.iteritems(decode_hparams.values()):
        if hasattr(hparams, k) and getattr(hparams, k) != v:
          tf.logging.warning("Overriding hparams.%s with %s from decode_hparams"
                             % (k, v))
        setattr(hparams, k, v)

    # Instantiate model
    data_parallelism = None
    if not use_tpu and config:
      data_parallelism = config.data_parallelism
    model = cls(
        hparams,
        mode,
        data_parallelism=data_parallelism,
        decode_hparams=decode_hparams)

    global_step = tf.train.get_global_step()

    mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(hparams.layout)
    if use_tpu:
      ctx = params["context"]
      num_hosts = ctx.num_hosts
      host_placement_fn = ctx.tpu_host_placement_function
      device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
      # TODO(ylc): Better estimation of replica cache size?
      replica_cache_size = 300 * 1000000  # 300M per replica
      # Worker 0 caches all the TPU binaries.
      worker0_mem = replica_cache_size * ctx.num_replicas
      devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
      var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                    devices_memory_usage)
      mesh_devices = [""] * mesh_shape.size
      mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
          mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
    else:
      var_placer = None
      if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
        mesh_devices = [""] * mesh_shape.size
      else:
        assert len(data_parallelism.ps_devices) == mesh_shape.size
        mesh_devices = data_parallelism.ps_devices
      mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
          mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)
    # PREDICT mode
    if mode == tf.estimator.ModeKeys.PREDICT:
      return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)

    logits, loss = model.mtf_model_fn(features, mesh)
    if use_tpu and logits is not None:
      logits = mtf.anonymize(logits)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
      var_grads = mtf.gradients(
          [loss], [v.outputs[0] for v in graph.trainable_variables])
      lr = learning_rate.learning_rate_schedule(hparams)
      tf.summary.scalar("learning_rate", lr)
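      # Import the scalar learning rate into the Mesh TensorFlow graph so the
      # mtf optimizer can consume it.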
      mtf_lr = mtf.import_tf_tensor(
          mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
      optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
      update_ops = []
      for grad, var in zip(var_grads, graph.trainable_variables):
        update_ops.extend(optimizer.apply_grad(grad, var))

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)
    if logits and mode != tf.estimator.ModeKeys.TRAIN:
      tf_logits = lowering.export_to_tf_tensor(logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
      tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
      tf_update_ops.append(tf.assign_add(global_step, 1))
      # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
      train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          hparams.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])

    # EVAL mode
    if mode == tf.estimator.ModeKeys.EVAL:
      tf_logits = lowering.export_to_tf_tensor(logits)
      return model.estimator_spec_eval(features, tf_logits, labels, tf_loss,
                                       restore_hook, use_tpu)

    if use_tpu:
      # TPU host call. Important: need to be called before remove_summaries()
      if hparams.tpu_enable_host_call:
        host_call = t2t_model.create_host_call(hparams.model_dir)
      else:
        host_call = None

      t2t_model.remove_summaries()
      return tpu_estimator.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          host_call=host_call,
          training_hooks=[restore_hook, saver_hook])
    else:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
          training_chief_hooks=[restore_hook, saver_hook])
Example #11
def my_model_fn(features,
                labels,
                mode,
                params=None,
                config=None):
  """Estimator model function.

  Args:
    features: input features dictionary
    labels: ignored
    mode: a tf.estimator.ModeKeys
    params: something
    config: something
  Returns:
    something
  """
  del labels, config
  use_tpu = FLAGS.tpu
  global_step = tf.train.get_global_step()

  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  model = transformer.Bitransformer(
      encoder_layer_stack=layer_stack(include_encdec_attention=False),
      decoder_layer_stack=layer_stack(include_encdec_attention=True),
      encoder_d_model=FLAGS.d_model,
      decoder_d_model=FLAGS.d_model,
      input_vocab_size=transformer_dataset.padded_vocab_size(
          transformer_dataset.inputs_vocab_size(FLAGS.dataset)),
      output_vocab_size=transformer_dataset.padded_vocab_size(
          transformer_dataset.targets_vocab_size(FLAGS.dataset)),
      max_length=FLAGS.max_length,
      shared_embedding=False,
      shared_embedding_and_softmax_weights=True,
      label_smoothing=FLAGS.label_smoothing,
      layout=FLAGS.layout,
      mesh_shape=FLAGS.mesh_shape)

  inputs = import_feature(features, mesh, "inputs")

  # Data-types used for variables and activations
  # See comments in the FLAGS
  master_dtype = tf.as_dtype(FLAGS.master_dtype)
  if FLAGS.slice_dtype:
    slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
  elif not FLAGS.tpu or FLAGS.mode == "train":
    slice_dtype = tf.float32
  else:
    slice_dtype = tf.bfloat16
  if FLAGS.activation_dtype:
    activation_dtype = tf.as_dtype(FLAGS.activation_dtype)
  else:
    activation_dtype = tf.bfloat16 if FLAGS.tpu else tf.float32
  variable_dtype = mtf.VariableDType(master_dtype=master_dtype,
                                     slice_dtype=slice_dtype,
                                     activation_dtype=activation_dtype)

  # PREDICT mode
  if mode == tf.estimator.ModeKeys.PREDICT:
    mtf_samples = model.decode(
        inputs,
        variable_dtype=variable_dtype,
        beam_size=FLAGS.beam_size,
        alpha=FLAGS.alpha,
        temperature=FLAGS.temperature)
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=FLAGS.autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {
        "outputs": outputs
    }
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  targets = import_feature(features, mesh, "targets")
  anon_targets = mtf.anonymize(targets)
  logits, loss = model.call_simple(
      inputs=inputs,
      targets=targets,
      compute_loss=True,
      mode=mode,
      variable_dtype=variable_dtype,
      encoder_sequence_id=import_feature(features, mesh, "inputs_segmentation"),
      decoder_sequence_id=import_feature(
          features, mesh, "targets_segmentation"),
      encoder_position=import_feature(features, mesh, "inputs_position"),
      decoder_position=import_feature(features, mesh, "targets_position")
  )

  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer()
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=FLAGS.autostack)

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if not use_tpu:
    tf_loss = tf.Print(
        tf_loss, [tf_loss, tf.train.get_global_step()], "step, tf_loss")
  if logits and mode != tf.estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        FLAGS.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

    if mode == tf.estimator.ModeKeys.TRAIN:
      if use_tpu:
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[restore_hook, saver_hook])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:
      def padded_neg_log_perplexity(logits, labels):
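        # Label id 0 is padding; give those positions zero weight in the metric.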
        weights = tf.to_float(tf.not_equal(labels, 0))
        xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        return {"neg_log_perplexity": tf.metrics.mean(-xent, weights)}
      labels = lowering.export_to_tf_tensor(anon_targets)
      eval_metrics = (padded_neg_log_perplexity, [tf_logits, labels])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
Example #12
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.
        Args:
          features: dictionary where keys are strings like "inputs" and "targets"
            and the values are the actual values of "inputs". See TPUEstimator's
            docs for more information
          labels: ignored argument
          mode: a tf.estimator.ModeKeys
          params: dictionary containing the key "context"
          config: ignored argument
        Returns:
          a TPUEstimatorSpec
        """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu and "context" in params:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            # deprecated mesh_devices = [""] * mesh_shape.size
            physical_shape = list(
                params["context"].device_assignment.topology.mesh_shape)
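            # Derive a logical-to-physical core mapping appropriate for the
            # TPU topology.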
            logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
                mesh_shape.to_integer_list, physical_shape)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape,
                layout_rules,
                mesh_devices,
                ctx.device_assignment,
                logical_to_physical=logical_to_physical)
        else:
            var_placer = None
            # deprecated mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        mtf_features = {}
        for key, x in features.items():
            outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
            batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
            # Some auxiliary features may have been generated in packing.
            # The names of these new features are of the form
            #   "<original_feature_name>_<suffix>", e.g. "inputs_segmentation".
            #   We look up the lengths based on the original feature name, without
            #   the "_<suffix>".
            feature_length = sequence_length[key.split("_")[0]]
            length_dim = mtf.Dimension("length", feature_length)
            ensemble_dims = ([mtf.Dimension("ensemble", ensemble_inputs)]
                             if ensemble_inputs else [])
            feature_shape = mtf.Shape(ensemble_dims +
                                      [outer_batch_dim, batch_dim, length_dim])
            x = tf.cast(features[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            if not use_tpu:
                tf.logging.info("feature %s : %s" % (key, x))
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=10)
            mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                            x,
                                                            feature_shape,
                                                            name=key)
            if key == "targets" or key == "codeprefixedtargets" or key == "controlcode":
                anon_targets = mtf.anonymize(mtf_features[key])

        if mode == tf.estimator.ModeKeys.PREDICT:

            def _feature_shape(key):
                feature_length = sequence_length[key.split("_")[0]]
                return mtf.Shape([
                    mtf.Dimension("batch", batch_size),
                    mtf.Dimension("length", feature_length)
                ])

            mtf_features = {
                k: mtf.reshape(v, _feature_shape(k))
                for k, v in six.iteritems(mtf_features)
            }
            inputs = mtf_features["inputs"]

            if attribute_embedding:
                attributes = mtf_features["attribute"]
            else:
                attributes = None

            if has_partial_sequences:
                controlcodes = mtf_features["controlcode"]
            else:
                controlcodes = None

            if predict_fn:
                mtf_samples = predict_fn(model=transformer_model,
                                         features=mtf_features,
                                         variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Unitransformer):
                # pad so that there is enough room for the targets
                inputs = mtf.pad(inputs, [0, sequence_length["targets"]],
                                 length_dim.name)
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs,
                    variable_dtype=get_variable_dtype(),
                    remove_partial_sequences=True)
            elif isinstance(transformer_model, Bitransformer_ll):
                mtf_samples = transformer_model.decode(
                    inputs,
                    attributes=attributes,
                    controlcodes=controlcodes,
                    has_partial_sequences=has_partial_sequences,
                    remove_partial_sequences=remove_partial_sequences,
                    variable_dtype=get_variable_dtype())
            elif isinstance(
                    transformer_model,
                (transformer.Bitransformer, transformer.StudentTeacher)):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            inputs = mtf.anonymize(inputs)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            inputs = lowering.export_to_tf_tensor(inputs)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"inputs": inputs, "outputs": outputs}

            # When exporting a model, we need to communicate to TF-Serving that
            # master variables need to be copied to their slave slice variables.
            # Estimator uses a Scaffold's "local_init_op" for this purpose, so we
            # augment the default "local_init_op" here.
            #
            # The "ready_op" is also constructed here to ensure the variables
            # initialized by "local_init_op" are the same ones checked by "ready_op".
            #
            # WARNING: Any variables created outside of this model_fn()
            # (e.g. tpu_estimator/iterations_per_loop) will NOT be initialized nor
            # checked by these ops.
            def scaffold_fn():
                return tf.train.Scaffold(
                    local_init_op=tf.group(
                        tf.train.Scaffold.default_local_init_op(),
                        lowering.copy_masters_to_slices(),
                        name="mtf_local_init_op"),
                    ready_op=tf.concat([
                        tf.report_uninitialized_variables(),
                        resources.report_uninitialized_resources()
                    ],
                                       axis=0,
                                       name="mtf_ready_op"))

            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                scaffold_fn=scaffold_fn,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        assert (mode == tf.estimator.ModeKeys.TRAIN
                or mode == tf.estimator.ModeKeys.EVAL)

        def logits_and_loss(mtf_features):
            """Compute logits and loss.
            Args:
              mtf_features: a dictionary
            Returns:
              logits: a mtf.Tensor
              loss: a mtf.Tensor
            """
            if model_type == "lm":  # TOTRY Adapt that to our case
                if "inputs" in mtf_features:
                    mtf_features = _dynamic_text2self(mtf_features)
                _, _, length_dim = mtf_features["targets"].shape
                inputs = mtf.shift(mtf_features["targets"],
                                   offset=1,
                                   dim=length_dim,
                                   wrap=False)
            else:
                inputs = mtf_features["inputs"]

            if attribute_embedding:
                attributes = mtf_features["attribute"]
            else:
                attributes = None

            if control_codes:
                codeprefixedtargets = mtf_features["codeprefixedtargets"]
            else:
                codeprefixedtargets = None

            if isinstance(transformer_model, transformer.Unitransformer):
                position_kwargs = dict(
                    sequence_id=mtf_features.get("targets_segmentation", None),
                    position=mtf_features.get("targets_position", None),
                )
            elif isinstance(transformer_model, transformer.Bitransformer
                            ) or model_type == "bi_student_teacher":
                if control_codes:
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "codeprefixedtargets_segmentation", None),
                        decoder_subsequence_id=mtf_features.get(
                            "codeprefixedtargets_subsegmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "codeprefixedtargets_position", None),
                    )
                else:
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "targets_segmentation", None),
                        decoder_subsequence_id=mtf_features.get(
                            "targets_subsegmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "targets_position", None),
                    )
            else:
                raise ValueError("unrecognized class")

            if isinstance(transformer_model, Bitransformer_ll):
                if cycle_consistency_loss:
                    logits_ae, l_ae = transformer_model.call_simple(
                        inputs=inputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)

                    if has_partial_sequences:
                        controlcodes = mtf_features["controlcode"]
                    else:
                        controlcodes = None

                    with gin.config_scope('training'):
                        mtf_samples = transformer_model.decode(
                            inputs,
                            attributes=attributes,
                            controlcodes=controlcodes,
                            has_partial_sequences=has_partial_sequences,
                            remove_partial_sequences=remove_partial_sequences,
                            variable_dtype=get_variable_dtype())
                        # mtf_samples = mtf.anonymize(mtf_samples)
                    outputs = mtf_samples

                    logits_cycle, l_cycle = transformer_model.call_simple(
                        inputs=outputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)

                    loss_ae_cycle = lambda_ae * l_ae + lambda_cycle * l_cycle
                    return logits_cycle, loss_ae_cycle
                else:
                    return transformer_model.call_simple(
                        inputs=inputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)
            else:
                return transformer_model.call_simple(
                    inputs=inputs,
                    targets=mtf_features["targets"],
                    compute_loss=True,
                    mode=mode,
                    variable_dtype=get_variable_dtype(),
                    num_microbatches=num_microbatches,
                    **position_kwargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            num_microbatches = serialize_num_microbatches(
                batch_dim, sequence_length, mesh_shape, layout_rules)
            if num_microbatches > 1:

                def serialized_fn(mtf_features):
                    return {
                        "loss":
                        (logits_and_loss(mtf_features)[1] / num_microbatches)
                    }

                var_grads, loss_dict = mtf.serialize_training_step(
                    mtf_features, serialized_fn, batch_dim, num_microbatches)
                loss = loss_dict["loss"]
            else:
                loss = logits_and_loss(mtf_features)[1]
                var_grads = mtf.gradients(
                    [loss], [v.outputs[0] for v in graph.trainable_variables])

            if tpu_summaries:
                mtf.scalar_summary("loss", loss)

            if callable(learning_rate_schedule):
                # the following happens on CPU since TPU can't handle summaries.
                with mtf.utils.outside_all_rewrites():
                    learning_rate = learning_rate_schedule(
                        step=tf.train.get_global_step())
                    tf.summary.scalar("learning_rate", learning_rate)
            else:
                learning_rate = learning_rate_schedule

            if isinstance(variable_filter, str):
                pattern = re.compile(variable_filter)
                variable_filter_fn = lambda v: pattern.search(v.name)
            elif variable_filter is None:
                variable_filter_fn = lambda v: True
            elif callable(variable_filter):
                variable_filter_fn = variable_filter
            else:
                raise ValueError(
                    "variable_filter must be None, a string, or a callable function"
                )
            trainable_vars = [
                v for v in graph.trainable_variables if variable_filter_fn(v)
            ]
            trainable_var_grads = [
                g for g, v in zip(var_grads, graph.trainable_variables)
                if variable_filter_fn(v)
            ]
            if len(trainable_vars) != len(graph.trainable_variables):
                tf.logging.info("Variables being trained:")
                tf.logging.info([v.name for v in trainable_vars])
                tf.logging.info("Variables not being trained:")
                tf.logging.info([
                    v.name for v in graph.trainable_variables
                    if not variable_filter_fn(v)
                ])

            update_ops = optimizer(learning_rate=learning_rate).apply_grads(
                trainable_var_grads, trainable_vars)

            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)

            tf_loss = lowering.export_to_tf_tensor(loss)
            tf_loss = tf.cast(tf_loss, tf.float32)
            if not use_tpu:
                tf_loss = tf.Print(
                    tf_loss, [tf_loss, tf.train.get_global_step()],
                    "step, tf_loss")

            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

            if hasattr(transformer_model, "initialize"):
                with mtf.utils.outside_all_rewrites():
                    transformer_model.initialize()

            if tpu_summaries:
                # has to be outside of "with mtf.utils.outside_all_rewrites()"
                host_call = mtf.utils.create_host_call(model_dir)
                mtf.utils.remove_summaries()
            else:
                host_call = None

            with mtf.utils.outside_all_rewrites():

                if init_checkpoint:
                    ckpt_vars = {
                        v
                        for v, _ in tf.train.list_variables(init_checkpoint)
                    }
                    global_vars = {v.op.name for v in tf.global_variables()}
                    restore_vars = ckpt_vars.intersection(global_vars)
                    tf.logging.info("Initializing variables from %s:",
                                    init_checkpoint)
                    tf.logging.debug("\n".join(sorted(restore_vars)))
                    tf.logging.info("Variables in %s but not in graph:",
                                    init_checkpoint)
                    tf.logging.info("\n".join(sorted(ckpt_vars - global_vars)))
                    tf.logging.info("Variables in graph but not in %s:",
                                    init_checkpoint)
                    tf.logging.info("\n".join(sorted(global_vars - ckpt_vars)))
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  {v: v
                                                   for v in restore_vars})

                # Copy master variables to slices. Must be called first.
                restore_hook = mtf.MtfRestoreHook(lowering)
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=keep_checkpoint_max,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    model_dir,
                    save_steps=save_checkpoints_steps,
                    saver=saver,
                    listeners=[saver_listener])
                gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                    model_dir,
                    summarize_config=True,
                    include_step_in_filename=False)

                if use_tpu:
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        host_call=host_call,
                        training_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
        elif mode == tf.estimator.ModeKeys.EVAL:
            # logits_and_loss references num_microbatches in one of its
            # branches, so make sure it is defined during eval.
            num_microbatches = 1
            logits, loss = logits_and_loss(mtf_features)
            anon_logits = mtf.anonymize(logits)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
            tf_logits = tf.cast(lowering.export_to_tf_tensor(anon_logits),
                                tf.float32)

            def simple_metrics(logits, labels):
                """Simple metrics for teacher-forced eval."""
                weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
                xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=labels, logits=logits)
                predictions = tf.cast(tf.argmax(logits, axis=-1), labels.dtype)
                token_correct = tf.cast(tf.equal(predictions, labels),
                                        tf.float32) * weights
                sequence_correct = tf.to_float(
                    tf.equal(tf.reduce_sum(token_correct, -1),
                             tf.reduce_sum(weights, -1)))
                sequence_weights = tf.to_float(
                    tf.not_equal(tf.reduce_sum(weights, -1), 0))
                return {
                    "neg_log_perplexity":
                    tf.metrics.mean(-xent, weights),
                    "token_accuracy":
                    tf.metrics.mean(token_correct, weights),
                    "sequence_accuracy":
                    tf.metrics.mean(sequence_correct, sequence_weights)
                }

            labels = lowering.export_to_tf_tensor(anon_targets)
            eval_metrics = (simple_metrics, [tf_logits, labels])
            with mtf.utils.outside_all_rewrites():
                restore_hook = mtf.MtfRestoreHook(lowering)
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
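
A model_fn like the one above is normally wrapped in a TPUEstimator before training, evaluation, or prediction. Below is a minimal sketch of that wiring, assuming the TF 1.x `tf.estimator.tpu` APIs; `my_model_fn`, `tpu_name`, `model_dir`, and the batch size are placeholders supplied by the caller, not values from the source.

import tensorflow.compat.v1 as tf

def make_tpu_estimator(my_model_fn, tpu_name, model_dir, batch_size=32):
    # Resolve the TPU cluster and build a TPU-aware run config.
    cluster = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
    run_config = tf.estimator.tpu.RunConfig(
        cluster=cluster,
        model_dir=model_dir,
        tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=100))
    # With use_tpu=True the model_fn receives the TPU "context" via params.
    return tf.estimator.tpu.TPUEstimator(
        model_fn=my_model_fn,
        config=run_config,
        use_tpu=True,
        train_batch_size=batch_size,
        eval_batch_size=batch_size,
        predict_batch_size=batch_size)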
Exemplo n.º 13
0
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.

    Args:
      features: input features dictionary
      labels: ignored
      mode: a tf.estimator.ModeKeys
      params: a parameter dictionary passed by the Estimator (contains the
        TPU "context" when running on TPU)
      config: a tf.estimator.RunConfig (unused)

    Returns:
      a TPUEstimatorSpec or tf.estimator.EstimatorSpec
    """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            physical_shape = list(
                params["context"].device_assignment.topology.mesh_shape)
            logical_to_physical = _logical_to_physical(physical_shape,
                                                       mesh_shape)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape,
                layout_rules,
                mesh_devices,
                ctx.device_assignment,
                logical_to_physical=logical_to_physical)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
        batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
        length_dim = mtf.Dimension("length", sequence_length)
        feature_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])

        mtf_features = {}
        for key, x in features.items():
            x = tf.to_int32(features[key])
            x = tf.reshape(x, [
                outer_batch_size, batch_size // outer_batch_size,
                sequence_length
            ])
            if not use_tpu:
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=1)
            mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                            x,
                                                            feature_shape,
                                                            name=key)

        if mode == tf.estimator.ModeKeys.PREDICT:
            inputs = mtf_features["inputs"]
            inputs = mtf.reshape(
                inputs,
                mtf.Shape([
                    mtf.Dimension("batch", batch_size),
                    mtf.Dimension("length", sequence_length)
                ]))
            if isinstance(transformer_model, transformer.Unitransformer):
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs, variable_dtype=get_variable_dtype())
            elif isinstance(
                    transformer_model,
                (transformer.Bitransformer, transformer.StudentTeacher)):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"outputs": outputs}
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        elif mode == tf.estimator.ModeKeys.EVAL:
            raise NotImplementedError("We don't expect to use mode == eval.")

        else:
            assert mode == tf.estimator.ModeKeys.TRAIN
            num_microbatches = serialize_num_microbatches(
                batch_dim, length_dim, mesh_shape, layout_rules)

            def model_fn(mtf_features):
                """The kind of function we need for mtf.serialize_training_step.

        Args:
          mtf_features: a dictionary
        Returns:
          a dictionary
        """
                targets = mtf_features["targets"]
                if model_type == "lm":
                    _, _, length_dim = targets.shape
                    inputs = mtf.shift(targets,
                                       offset=1,
                                       dim=length_dim,
                                       wrap=False)
                else:
                    inputs = mtf_features["inputs"]

                if isinstance(transformer_model, transformer.Unitransformer):
                    position_kwargs = dict(
                        sequence_id=mtf_features.get("targets_segmentation",
                                                     None),
                        position=mtf_features.get("targets_position", None),
                    )
                elif isinstance(transformer_model, transformer.Bitransformer
                                ) or model_type == "bi_student_teacher":
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "targets_segmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "targets_position", None),
                    )
                else:
                    raise ValueError("unrecognized class")

                logits, loss = transformer_model.call_simple(
                    inputs=inputs,
                    targets=targets,
                    compute_loss=True,
                    mode=mode,
                    variable_dtype=get_variable_dtype(),
                    **position_kwargs)
                if num_microbatches > 1:
                    loss /= float(num_microbatches)
                del logits
                return {"loss": loss}

            if num_microbatches > 1:
                var_grads, loss_dict = mtf.serialize_training_step(
                    mtf_features, model_fn, batch_dim, num_microbatches)
            else:
                loss_dict = model_fn(mtf_features)
                var_grads = mtf.gradients(
                    [loss_dict["loss"]],
                    [v.outputs[0] for v in graph.trainable_variables])

            loss = loss_dict["loss"]

            if callable(learning_rate_schedule):
                # the following happens on CPU since TPU can't handle summaries.
                with mtf.utils.outside_all_rewrites():
                    learning_rate = learning_rate_schedule(
                        step=tf.train.get_global_step())
                    tf.summary.scalar("learning_rate", learning_rate)
            else:
                learning_rate = learning_rate_schedule

            update_ops = optimizer(learning_rate=learning_rate).apply_grads(
                var_grads, graph.trainable_variables)

            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)

            tf_loss = lowering.export_to_tf_tensor(loss)
            tf_loss = tf.to_float(tf_loss)
            if not use_tpu:
                tf_loss = tf.Print(
                    tf_loss, [tf_loss, tf.train.get_global_step()],
                    "step, tf_loss")

            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

            if hasattr(transformer_model, "initialize"):
                with mtf.utils.outside_all_rewrites():
                    transformer_model.initialize()

            with mtf.utils.outside_all_rewrites():
                # Copy master variables to slices. Must be called first.
                restore_hook = mtf.MtfRestoreHook(lowering)
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=keep_checkpoint_max,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    model_dir,
                    save_steps=save_checkpoints_steps,
                    saver=saver,
                    listeners=[saver_listener])
                gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                    model_dir, summarize_config=True)

                if use_tpu:
                    if tpu_summaries:
                        tf.summary.scalar("loss", tf_loss)
                        host_call = mtf.utils.create_host_call(model_dir)
                        mtf.utils.remove_summaries()
                    else:
                        host_call = None
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        host_call=host_call,
                        training_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
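
Both model_fns above accept `learning_rate_schedule` either as a constant or as a callable taking the global step. Below is a minimal sketch of such a callable, using linear warmup followed by inverse-square-root decay; the base rate and warmup length are illustrative assumptions, not values from the source.

import tensorflow.compat.v1 as tf

def learning_rate_schedule(step, base_rate=1e-3, warmup_steps=10000):
    """Linear warmup followed by inverse-square-root decay."""
    step = tf.cast(step, tf.float32)
    warmup = tf.constant(float(warmup_steps))
    # Ramp up linearly until warmup_steps, then decay as 1/sqrt(step).
    return base_rate * tf.minimum(
        step / warmup, tf.sqrt(warmup) * tf.rsqrt(tf.maximum(step, warmup)))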
Exemplo n.º 14
0
    # result = mtf.sigmoid(result)
    # result = -(label*mtf.log(result)+(1-label)*mtf.log(1-result))
    # result = mtf.reduce_sum(result)
    result = mtf.layers.sigmoid_cross_entropy_with_logits(result, label)
    wide_loss = mtf.reduce_sum(result)

    print("========", global_step)
    devices = ["gpu:0"]
    mesh_shape = [("all_processors", 1)]
    layout_rules = [("dim1", "all_processors")]
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, devices)

    var_grads = mtf.gradients(
        [wide_loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.SgdOptimizer(0.01)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
    lowering = mtf.Lowering(graph, {mesh:mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

    estimator = tf.estimator.EstimatorSpec(
        tf.estimator.ModeKeys.TRAIN, loss=wide_loss, train_op=train_op,
        training_chief_hooks=[restore_hook])

    WDlaunch = tf.estimator.Estimator(
        model_fn=estimator,
Exemplo n.º 15
0
def Optimize(graph, loss, lr=0.01):
    grads = mtf.gradients([loss],
                          [v.outputs[0] for v in graph.trainable_variables])
    assert all(g is not None for g in grads), "some gradients are None"
    opt = mtf.optimize.SgdOptimizer(lr)
    return opt.apply_grads(grads, graph.trainable_variables)
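
A usage sketch for the Optimize helper above, assuming a toy quadratic loss on a single-device placement mesh; the dimension name, layout, and initializer values are placeholders, not values from the source.

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
hidden_dim = mtf.Dimension("hidden", 3)
w = mtf.get_variable(mesh, "w", shape=[hidden_dim],
                     initializer=tf.constant_initializer([0.1, -0.2, 0.3]))
target = mtf.constant(mesh, [1.0, 0.0, -1.0], [hidden_dim], dtype=tf.float32)
loss = mtf.reduce_mean(mtf.square(w - target))

# Build SGD update operations with the helper defined above.
update_ops = Optimize(graph, loss, lr=0.1)

# Lower onto a trivial single-device mesh and group the lowered updates.
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    [("all_processors", 1)], [("hidden", "all_processors")], [""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
train_op = tf.group([lowering.lowered_operation(op) for op in update_ops])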
Exemplo n.º 16
0
def recon_prototype(mesh,
                    data,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """
    Prototype of a function computing the LPT displacement.

    Returns output TensorFlow and Mesh TensorFlow tensors.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    print(dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    #graph = mtf.Graph()
    #mesh = mtf.Mesh(graph, "my_mesh")

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y,
                              nc // n_block_z):
        new_size = int(0.5 *
                       min(nc // n_block_x, nc // n_block_y, nc // n_block_z))
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition
    downsampling_factor = 2
    lnc = nc // 2**downsampling_factor

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    # Dimensions of the low resolution grid
    x_dim = mtf.Dimension("nx_lr", lnc)
    y_dim = mtf.Dimension("ny_lr", lnc)
    z_dim = mtf.Dimension("nz_lr", lnc)

    tx_dim = mtf.Dimension("tx_lr", lnc)
    ty_dim = mtf.Dimension("ty_lr", lnc)
    tz_dim = mtf.Dimension("tz_lr", lnc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    k_dims = [tx_dim, ty_dim, tz_dim]

    batch_dim = mtf.Dimension("batch", batch_size)

    klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0]
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    ipklin = iuspline(klin, plin)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False, dtype=npdtype)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype(npdtype),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype(npdtype),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype(npdtype),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc],
                                  symmetric=False,
                                  dtype=npdtype)

    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    # kvec for high resolution blocks
    padded_sx_dim = mtf.Dimension('padded_sx_block',
                                  nc // n_block_x + 2 * halo_size)
    padded_sy_dim = mtf.Dimension('padded_sy_block',
                                  nc // n_block_y + 2 * halo_size)
    padded_sz_dim = mtf.Dimension('padded_sz_block',
                                  nc // n_block_z + 2 * halo_size)

    kvec_hr = flowpm.kernels.fftk([
        nc // n_block_x + 2 * halo_size, nc // n_block_y + 2 * halo_size,
        nc // n_block_z + 2 * halo_size
    ],
                                  symmetric=False,
                                  dtype=npdtype)
    kx_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[0].squeeze().astype(npdtype),
                                 shape=[padded_sx_dim])
    ky_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[1].squeeze().astype(npdtype),
                                 shape=[padded_sy_dim])
    kz_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[2].squeeze().astype(npdtype),
                                 shape=[padded_sz_dim])
    kv_hr = [ky_hr, kz_hr, kx_hr]

    # kvec for prior blocks
    prior_sx_dim = mtf.Dimension('prior_sx_block', nc // n_block_x)
    prior_sy_dim = mtf.Dimension('prior_sy_block', nc // n_block_y)
    prior_sz_dim = mtf.Dimension('prior_sz_block', nc // n_block_z)

    kvec_pr = flowpm.kernels.fftk(
        [nc // n_block_x, nc // n_block_y, nc // n_block_z],
        symmetric=False,
        dtype=npdtype)
    kx_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[0].squeeze().astype(npdtype),
                                 shape=[prior_sx_dim])
    ky_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[1].squeeze().astype(npdtype),
                                 shape=[prior_sy_dim])
    kz_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[2].squeeze().astype(npdtype),
                                 shape=[prior_sz_dim])
    kv_pr = [ky_pr, kz_pr, kx_pr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, x_dim, y_dim, z_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    ## Compute initial conditions, distributed

    fieldvar = mtf.get_variable(mesh, 'linear', hr_shape)
    input_field = tf.placeholder(data.dtype, [
        batch_size, n_block_x, n_block_y, n_block_z, nc // n_block_x,
        nc // n_block_y, nc // n_block_z
    ])
    mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=hr_shape)
    linearop = mtf.assign(fieldvar, mtfinp)
    #
    field = fieldvar
    initc = mtf.slicewise(lambda x: x[:, 0, 0, 0], [field],
                          output_dtype=tf.float32,
                          output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
                          name='my_dumb_reshape',
                          splittable_dims=part_shape[:-1] + hr_shape[:4])

    #
    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

    low = mtf.reshape(low, low.shape[:-1])
    high = mtf.reshape(high, high.shape[:-1])

    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)
    # Hack using a custom reshape because mesh is pretty dumb
    low = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low],
                        output_dtype=initc.dtype,
                        output_shape=lr_shape,
                        name='my_dumb_reshape',
                        splittable_dims=lr_shape[:-1] + hr_shape[:4])

    # Here we can run our nbody
    if FLAGS.nbody:
        state = mtfpm.lpt_init(low,
                               high,
                               0.1,
                               kv_lr,
                               kv_hr,
                               halo_size,
                               hr_shape,
                               lr_shape,
                               part_shape[1:],
                               downsampling_factor=downsampling_factor,
                               antialias=True)

        final_state = mtfpm.nbody(state,
                                  stages,
                                  lr_shape,
                                  hr_shape,
                                  kv_lr,
                                  kv_hr,
                                  halo_size,
                                  downsampling_factor=downsampling_factor)
    else:
        final_state = mtfpm.lpt_init(low,
                                     high,
                                     stages[-1],
                                     kv_lr,
                                     kv_hr,
                                     halo_size,
                                     hr_shape,
                                     lr_shape,
                                     part_shape[1:],
                                     downsampling_factor=downsampling_factor,
                                     antialias=True)

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    k_dims_pr = [d.shape[0] for d in kv_pr]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv_pr,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3

    # Total loss
    diff = (final_field - mtfdata)
    R0 = tf.placeholder(tf.float32, shape=())

    def _cwise_smooth(kfield, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv_pr, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    #return initc, final_field, loss, linearop, input_field
    nyq = np.pi * nc / bs

    def _cwise_highpass(kfield, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype)
        return kfield * (1 - wts)

    var_grads = mtf.gradients([loss], [fieldvar])
    cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype)
    cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv_pr, output_dtype=cdtype)
    var_grads = [
        mesh_utils.c2r3d(cgrads, var_grads[0].shape[-3:], dtype=dtype)
    ]

    lr = tf.placeholder(tf.float32, shape=())
    update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr)

    return initc, final_field, loss, var_grads, update_op, linearop, input_field, lr, R0
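
A sketch of how the handles returned above might be driven, assuming a single-process placement mesh; `data`, `initial_guess`, and the annealing and learning-rate values are placeholders, not values from the source.

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
(initc, final_field, loss, var_grads, update_op, linearop,
 input_field, lr, R0) = recon_prototype(mesh, data)

# Lower the mesh graph onto a trivial single-device mesh and export handles.
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    [("all_processors", 1)], [("nx_block", "all_processors")], [""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_loss = lowering.export_to_tf_tensor(loss)
tf_update = lowering.lowered_operation(update_op)
tf_linear = lowering.lowered_operation(linearop)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Seed the 'linear' field variable with an initial guess.
    sess.run(tf_linear, feed_dict={input_field: initial_guess})
    # Anneal the smoothing scale R0 while taking plain gradient steps.
    for R0_val in [4.0, 2.0, 1.0]:
        for _ in range(100):
            _, current_loss = sess.run(
                [tf_update, tf_loss], feed_dict={lr: 0.01, R0: R0_val})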
Exemplo n.º 17
0
def recon_prototype(mesh,
                    data,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """
    Prototype of a function computing the LPT displacement.

    Returns output TensorFlow and Mesh TensorFlow tensors.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    print("Dtype : ", dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    kny = 1 * np.pi * nc / bs
    R1, R2 = 3., 3 * 1.2
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    #graph = mtf.Graph()
    #mesh = mtf.Mesh(graph, "my_mesh")

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y,
                              nc // n_block_z):
        new_size = int(0.5 *
                       min(nc // n_block_x, nc // n_block_y, nc // n_block_z))
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition

    scalar = mtf.Dimension("scalar", 1)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    #k_dims = [tx_dim, ty_dim, tz_dim]

    batch_dim = mtf.Dimension("batch", batch_size)

    klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0]
    plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1]
    ipklin = iuspline(klin, plin)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    #
    # Begin simulation

    ## Compute initial conditions, distributed
    #initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)

    fieldvar = mtf.get_variable(mesh, 'linear', part_shape)
    input_field = tf.placeholder(data.dtype, [batch_size, nc, nc, nc])
    mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=part_shape)
    linearop = mtf.assign(fieldvar, mtfinp)

    #field = fieldvar
    initc = fieldvar

    print("initc : ", initc)

    # Here we can run our nbody
    if FLAGS.nbody:
        state = mtfpm.lpt_init_single(
            fieldvar,
            a0,
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )
        # Here we can run our nbody
        final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                         kv_lr, halo_size)
    else:
        final_state = mtfpm.lpt_init_single(
            initc,
            stages[-1],
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])
    ##
    x = final_field

    ppars, mpars, kernel = setupfnn()
    pwts, pbias, pmx, psx = ppars
    mwts, mbias, mmx, msx, mmy, msy = mpars
    msy, mmy = msy[0], mmy[0]
    print("mmy : ", mmy)
    size = 3

    k_dims = [d.shape[0] for d in kv]
    k_dims = [k_dims[2], k_dims[0], k_dims[1]]
    tfnc = float_to_mtf(nc * 1., mesh, scalar)
    tfbs = float_to_mtf(bs, mesh, scalar)

    x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype)
    x1d = mtf.add(x1d, -1.)

    x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype)
    x2f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R2, mesh, scalar)] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype)
    x12 = x1 - x2

    width = tf.placeholder(tf.float32, shape=())

    def apply_pwts(x, x1, x2):
        #y = tf.expand_dims(x, axis=-1)

        y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel, [1, 1, 1, 1, 1],
                         'SAME')
        y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')
        y2 = tf.nn.conv3d(tf.expand_dims(x2, axis=-1), kernel, [1, 1, 1, 1, 1],
                          'SAME')
        #y = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x), -1), kernel, [1, 1, 1, 1, 1], 'VALID')
        #y1 = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x1), -1), kernel, [1, 1, 1, 1, 1], 'VALID')
        #y2 = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x12), -1), kernel, [1, 1, 1, 1, 1], 'VALID')

        yy = tf.concat([y, y1, y2], axis=-1)
        yy = yy - pmx
        yy = yy / psx
        yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0])
        yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1])
        yy3 = tf.matmul(yy2, pwts[2]) + pbias[2]
        pmodel = tf.nn.sigmoid(width * yy3)
        return pmodel[..., 0]

    pmodel = mtf.slicewise(
        apply_pwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,  # + [mtf.Dimension('c_dim', 81)],
        name='apply_pwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    def apply_mwts(x, x1, x2):
        #y = tf.expand_dims(x, axis=-1)

        zz = tf.concat([
            tf.expand_dims(x, -1),
            tf.expand_dims(x1, -1),
            tf.expand_dims(x2, -1)
        ],
                       axis=-1)
        zz = zz - mmx
        zz = zz / msx
        zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0])
        zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1])
        zz3 = tf.matmul(zz2, mwts[2]) + mbias[2]
        mmodel = zz3 * msy + mmy
        return mmodel[..., 0]

    mmodel = mtf.slicewise(
        apply_mwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,  # + [mtf.Dimension('c_dim', 81)],
        name='apply_mwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    model = pmodel * mmodel

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    #k_dims = [d.shape[0] for d in kv]
    #k_dims = [k_dims[2], k_dims[0], k_dims[1]]
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3

    # Total loss
    #diff = (model - mtfdata)
    modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype)
    modelsmf = mtf.cwise(cwise_fingauss,
                         [modelf, float_to_mtf(R1, mesh, scalar)] + kv +
                         [tfnc, tfbs],
                         output_dtype=cdtype)
    modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype)
    #dataf = mesh_utils.r2c3d(mtfdata, k_dims, dtype=cdtype)
    #datasmf = mtf.cwise(cwise_fingauss, [dataf, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype)
    #datasm = mesh_utils.c2r3d(datasmf, x1d.shape[-3:], dtype=dtype)

    ##Anneal
    R0 = tf.placeholder(tf.float32, shape=())
    M0 = tf.placeholder(tf.float32, shape=())
    off = tf.placeholder(tf.float32, shape=data.shape)
    istd = tf.placeholder(tf.float32, shape=data.shape)
    mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape)
    mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape)
    diff = mtf.log(modelsm + M0) - mtf.log(mtfdata + M0)
    #diff = diff / 0.25
    #diff = (diff + mtfoff)*mtfistd  # for some reason this variant gives wrong results
    diff = (diff + mtfoff) / 0.25

    def _cwise_smooth(kfield, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    #return initc, final_field, loss, linearop, input_field
    nyq = np.pi * nc / bs

    def _cwise_highpass(kfield, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype)
        return kfield * (1 - wts)

    var_grads = mtf.gradients([loss], [fieldvar])
    cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype)
    cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv, output_dtype=cdtype)
    var_grads = [mesh_utils.c2r3d(cgrads, diff.shape[-3:], dtype=dtype)]

    lr = tf.placeholder(tf.float32, shape=())
    update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr)

    return initc, model, loss, var_grads, update_op, linearop, input_field, lr, R0, M0, width, chisq, prior, off, istd
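
The placeholders in the return tuple (lr, R0, M0, off, istd) are meant to be fed at run time. A minimal annealing loop, assuming the caller also has access to the mtf graph, mesh and a mesh_impl so it can lower the ops, and to numpy arrays off_np/istd_np for the noise model, might look like the sketch below (everything outside the return tuple is illustrative, not part of the source).

# Hypothetical driver: lower once, then feed a decreasing smoothing scale R0
# while repeatedly running the lowered gradient-descent update.
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_loss = lowering.export_to_tf_tensor(loss)
tf_update = lowering.lowered_operation(update_op)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(lowering.copy_masters_to_slices())      # push master values to the slices
    for RR in [4.0, 2.0, 1.0, 0.5]:                  # coarse-to-fine smoothing scales
        for step in range(100):
            _, l = sess.run([tf_update, tf_loss],
                            feed_dict={lr: 0.01, R0: RR, M0: 1e-3,
                                       off: off_np, istd: istd_np})
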
Exemplo n.º 18
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator."""

    #tf.logging.info("features = %s labels = %s mode = %s params=%s" %
    #              (features, labels, mode, params))

    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    fields, metrics, kv = recon_model(mesh, features['datasm'],
                                      features['bparams'],
                                      features['ipkerror'], features['R0'],
                                      features['x0'])
    fieldvar, final, model = fields
    chisq, prior, loss = metrics

    ##
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    mesh_size = mesh_shape.size
    layout_rules = [("nx_lr", "row"), ("ny_lr", "col"), ("nx", "row"),
                    ("ny", "col"), ("ty", "row"), ("tz", "col"),
                    ("ty_lr", "row"), ("tz_lr", "col"), ("nx_block", "row"),
                    ("ny_block", "col")]
    devices = [""] * mesh_size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, devices)
    print(mesh_impl)

    ##
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdamWeightDecayOptimizer(features['lr'])
        update_ops = optimizer.apply_grads(var_grads,
                                           graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)
    #
    tf_init = lowering.export_to_tf_tensor(fieldvar)
    tf_final = lowering.export_to_tf_tensor(final)
    tf_model = lowering.export_to_tf_tensor(model)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_chisq = lowering.export_to_tf_tensor(chisq)
    tf_prior = lowering.export_to_tf_tensor(prior)

    ##Predict
    if mode == tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)
        tf.summary.scalar("loss", tf_loss)
        tf.summary.scalar("chisq", tf_chisq)
        tf.summary.scalar("prior", tf_prior)
        predictions = {
            "ic": tf_init,
            "final": tf_final,
            "model": tf_model,
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                "model": tf.estimator.export.PredictOutput(
                    predictions)  #TODO: is classify a keyword?
            })

    ##Train
    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)
        saver = tf.train.Saver(tf.global_variables(),
                               sharded=True,
                               max_to_keep=1,
                               keep_checkpoint_every_n_hours=2,
                               defer_build=False,
                               save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(fpath,
                                                  save_steps=1000,
                                                  saver=saver,
                                                  listeners=[saver_listener])

        logging_hook = tf.train.LoggingTensorHook(
            {
                "loss": tf_loss,
                "chisq": tf_chisq,
                "prior": tf_prior
            },
            every_n_iter=10)

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "loss")
        tf.identity(tf_prior, "prior")
        tf.identity(tf_chisq, "chisq")

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar("loss", tf_loss)
        tf.summary.scalar("chisq", tf_chisq)
        tf.summary.scalar("prior", tf_prior)

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook, logging_hook])

    ##Eval
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={
                "loss": tf_loss,
                "chisq": tf_chisq,
                #tf.metrics.accuracy(
                #    labels=labels, predictions=tf.argmax(tf_logits, axis=1)),
            })
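
A sketch of how this model_fn could be driven; fpath comes from the surrounding module, while train_input_fn, predict_input_fn and the step count are assumptions. The input_fn has to yield the feature keys read above ('datasm', 'bparams', 'ipkerror', 'R0', 'x0', and 'lr' for training).

# Illustrative wiring only.
run_config = tf.estimator.RunConfig(model_dir=fpath)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
estimator.train(input_fn=train_input_fn, max_steps=10000)

# PREDICT mode yields the reconstructed fields defined above.
for pred in estimator.predict(input_fn=predict_input_fn,
                              yield_single_examples=False):
    ic, final, model = pred["ic"], pred["final"], pred["model"]
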
Exemplo n.º 19
    def testPool(self, pooling_method):
        batch = 2
        depth = 3
        height = 4
        width = 6
        channels = 3
        tf.random.set_random_seed(1234)
        inputs = tf.random_normal([batch, depth, height, width, channels])

        stride_d = 3
        stride_h = 2
        stride_w = 3

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", batch)
        depth_dim = mtf.Dimension("depth", depth)
        height_dim = mtf.Dimension("height", height)
        width_dim = mtf.Dimension("width", width)
        channels_dim = mtf.Dimension("channels", channels)

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape([
                                              batch_dim, depth_dim, height_dim,
                                              width_dim, channels_dim
                                          ]))

        if pooling_method == "MAX_2D":
            mtf_outputs = mtf.layers.max_pool2d(mtf_inputs,
                                                ksize=(stride_h, stride_w))
            inputs = tf.reshape(inputs,
                                [batch * depth, height, width, channels])
            expected_outputs = tf.keras.layers.MaxPooling2D(
                (stride_h, stride_w))(inputs)
            expected_outputs = tf.reshape(expected_outputs, [
                batch, depth,
                int(height / stride_h),
                int(width / stride_w), channels
            ])

        elif pooling_method == "AVG_2D":
            mtf_outputs = mtf.layers.avg_pool2d(mtf_inputs,
                                                ksize=(stride_h, stride_w))
            inputs = tf.reshape(inputs,
                                [batch * depth, height, width, channels])
            expected_outputs = tf.keras.layers.AveragePooling2D(
                (stride_h, stride_w))(inputs)
            expected_outputs = tf.reshape(expected_outputs, [
                batch, depth,
                int(height / stride_h),
                int(width / stride_w), channels
            ])

        elif pooling_method == "MAX_3D":
            mtf_outputs = mtf.layers.max_pool3d(
                mtf_inputs, ksize=[stride_d, stride_h, stride_w])
            expected_outputs = tf.keras.layers.MaxPooling3D(
                [stride_d, stride_h, stride_w])(inputs)

        elif pooling_method == "AVG_3D":
            mtf_outputs = mtf.layers.avg_pool3d(
                mtf_inputs, ksize=[stride_d, stride_h, stride_w])
            expected_outputs = tf.keras.layers.AveragePooling3D(
                [stride_d, stride_h, stride_w])(inputs)

        mtf_gradient = mtf.gradients([mtf_outputs], [mtf_inputs])[0]

        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
        actual_gradient = lowering.export_to_tf_tensor(mtf_gradient)

        tf_group = lowering.copy_masters_to_slices()
        init = tf.global_variables_initializer()
        self.evaluate(init)
        self.evaluate(tf_group)
        actual, expected = self.evaluate([actual_outputs, expected_outputs])
        self.assertAllClose(actual, expected)

        actual = self.evaluate(actual_gradient)
        if pooling_method == "MAX_2D":
            expected_non_zeros = batch * depth * height * width * channels / (
                stride_h * stride_w)
            self.assertEqual(np.count_nonzero(actual), expected_non_zeros)

        elif pooling_method == "AVG_2D":
            expected = np.ones((batch, depth, height, width, channels),
                               dtype=np.float32) / stride_h / stride_w
            self.assertAllClose(actual, expected)

        elif pooling_method == "MAX_3D":
            expected_non_zeros = batch * depth * height * width * channels / (
                stride_d * stride_h * stride_w)
            self.assertEqual(np.count_nonzero(actual), expected_non_zeros)

        elif pooling_method == "AVG_3D":
            expected = np.ones(
                (batch, depth, height, width, channels),
                dtype=np.float32) / stride_d / stride_h / stride_w
            self.assertAllClose(actual, expected)
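
The pooling_method argument suggests this test is run under absl's parameterized framework; a minimal harness (assumed here, not shown in the listing) would look like:

from absl.testing import parameterized
import mesh_tensorflow as mtf
import numpy as np
import tensorflow as tf  # 1.x


class PoolingTest(parameterized.TestCase, tf.test.TestCase):

    @parameterized.parameters("MAX_2D", "AVG_2D", "MAX_3D", "AVG_3D")
    def testPool(self, pooling_method):
        ...  # body as above
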
Exemplo n.º 20
def create_optimizer(loss,
                     init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     max_optimized_variable_size=None,
                     optimizer="adam",
                     clip_gradients=True):
    """Creates an optimizer training op."""
    global_step = tf.train.get_or_create_global_step()
    mesh = loss.mesh

    if init_lr:
        # Implements linear decay of the learning rate.
        learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
        learning_rate = tf.train.polynomial_decay(learning_rate,
                                                  global_step,
                                                  num_train_steps,
                                                  end_learning_rate=0.0,
                                                  power=1.0,
                                                  cycle=False)
        # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
        # learning rate will be `global_step/num_warmup_steps * init_lr`.
        if num_warmup_steps:
            global_steps_int = tf.cast(global_step, tf.int32)
            warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

            global_steps_float = tf.cast(global_steps_int, tf.float32)
            warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

            warmup_percent_done = global_steps_float / warmup_steps_float
            warmup_learning_rate = init_lr * warmup_percent_done

            is_warmup = tf.cast(global_steps_int < warmup_steps_int,
                                tf.float32)
            learning_rate = ((1.0 - is_warmup) * learning_rate +
                             is_warmup * warmup_learning_rate)

        mtf_learning_rate = mtf.import_tf_tensor(mesh, learning_rate, [])
    else:
        if optimizer == "adam":
            raise ValueError("Adam does not have a default learning rate")
        learning_rate = None
        mtf_learning_rate = None

    # It is recommended that you use this optimizer for fine tuning, since this
    # is how the model was trained (note that the Adam m/v variables are NOT
    # loaded from init_checkpoint.)
    if optimizer == "adam":
        optimizer = mtf_optimize.AdamWeightDecayOptimizer(
            learning_rate=mtf_learning_rate,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
    elif optimizer == "adafactor":
        optimizer = mtf_optimize.AdafactorOptimizer(
            learning_rate=learning_rate, min_dim_size_to_factor=32)
    else:
        raise ValueError("unknown optimizer")

    trainable_variables = mesh.graph.trainable_variables
    if max_optimized_variable_size:
        trainable_variables = [
            t for t in trainable_variables
            if t.shape.size <= max_optimized_variable_size
        ]

    var_grads = mtf.gradients([loss],
                              [v.outputs[0] for v in trainable_variables])

    # This is how the model was pre-trained.
    if clip_gradients:
        (var_grads,
         _) = clip_by_global_norm(var_grads,
                                  clip_norm=mtf.constant(mesh,
                                                         1.0,
                                                         dtype=tf.float32))

    update_ops = optimizer.apply_grads(var_grads, trainable_variables)

    return learning_rate, update_ops
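
For reference, the warmup-then-linear-decay schedule built above reduces to the piecewise rule sketched below; the concrete numbers are assumed purely for illustration.

def lr_at(step, init_lr=1e-4, num_train_steps=10000, num_warmup_steps=1000):
    decayed = init_lr * max(0.0, 1.0 - step / num_train_steps)  # polynomial_decay, power=1
    warmup = init_lr * step / num_warmup_steps
    return warmup if step < num_warmup_steps else decayed

print([lr_at(s) for s in (0, 500, 1000, 5000, 10000)])
# ≈ [0.0, 5e-05, 9e-05, 5e-05, 0.0]
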
Exemplo n.º 21
def model_fn(features, labels, mode, params):
  """A model is called by TpuEstimator."""
  del labels
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  if FLAGS.use_tpu:
    ctx = params['context']
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    tf.logging.info('device_list = %s' % device_list,)
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)
  mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

  with mtf.utils.outside_all_rewrites():
    logits, loss = toy_model(features, mesh)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients([loss],
                              [v.outputs[0] for v in graph.trainable_variables])
    if FLAGS.optimizer == 'Adafactor':
      optimizer = mtf.optimize.AdafactorOptimizer()
    else:
      assert FLAGS.optimizer == 'SGD'
      optimizer = mtf.optimize.SgdOptimizer(lr=FLAGS.lr)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
  else:
    # for now, we can only export fully-replicated tensors.
    fully_replicated_logits = mtf.anonymize(logits)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  else:
    tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])

      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(tf_logits):
        mean_logits = tf.metrics.mean(tf_logits)
        return {'mean_logits': mean_logits}

      eval_metrics = (metric_fn, [tf_logits])

      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
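
This model_fn is written for tpu_estimator.TPUEstimator, which supplies params['context'] when use_tpu is set. A plausible driver is sketched below; FLAGS.tpu, FLAGS.iterations_per_loop, FLAGS.batch_size, FLAGS.train_steps and toy_input_fn are assumptions, the tpu_estimator alias follows the usage above, and tpu_config is assumed to come from tf.contrib.tpu.

# Illustrative only.
cluster = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
run_config = tpu_config.RunConfig(
    cluster=cluster,
    model_dir=FLAGS.model_dir,
    tpu_config=tpu_config.TPUConfig(
        iterations_per_loop=FLAGS.iterations_per_loop,
        num_shards=mtf.convert_to_shape(FLAGS.mesh_shape).size))
estimator = tpu_estimator.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    train_batch_size=FLAGS.batch_size,
    eval_batch_size=FLAGS.batch_size,
    use_tpu=FLAGS.use_tpu)
estimator.train(input_fn=toy_input_fn, max_steps=FLAGS.train_steps)
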
Exemplo n.º 22
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.

    Args:
      features: input features dictionary
      labels: ignored
      mode: a tf.estimator.ModeKeys
      params: a dict supplied by the Estimator; when running on TPU it carries
        the TPU "context" object
      config: an optional RunConfig (unused here)

    Returns:
      a tpu_estimator.TPUEstimatorSpec (or a tf.estimator.EstimatorSpec when
      not running on TPU)
    """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        def _import_feature(key):
            """Import a feature from the features dictionary into a mtf.Tensor.

      Args:
        key: a string

      Returns:
        a mtf.Tensor with dtype int32 and shape [batch_dim, length_dim]
      """
            batch_dim = mtf.Dimension("batch", batch_size)
            length_dim = mtf.Dimension("length", length)
            mtf_shape = mtf.Shape([batch_dim, length_dim])
            if key not in features:
                return None
            x = tf.to_int32(features[key])
            if not use_tpu:
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=1)
            return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

        # PREDICT mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            inputs = _import_feature("inputs")
            if text2self:
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs,
                    variable_dtype=variable_dtype,
                    temperature=temperature)
            else:
                mtf_samples = transformer_model.decode(
                    inputs,
                    variable_dtype=variable_dtype,
                    beam_size=beam_size,
                    alpha=alpha,
                    temperature=temperature)
            mtf_samples = mtf.anonymize(mtf_samples)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"outputs": outputs}
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        targets = _import_feature("targets")
        anon_targets = mtf.anonymize(targets)
        if text2self:
            _, length_dim = targets.shape
            inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
        else:
            inputs = _import_feature("inputs")

        if text2self:
            position_kwargs = dict(
                sequence_id=_import_feature("targets_segmentation"),
                position=_import_feature("targets_position"),
            )
        else:
            position_kwargs = dict(
                encoder_sequence_id=_import_feature("inputs_segmentation"),
                decoder_sequence_id=_import_feature("targets_segmentation"),
                encoder_position=_import_feature("inputs_position"),
                decoder_position=_import_feature("targets_position"),
            )

        logits, loss = transformer_model.call_simple(
            inputs=inputs,
            targets=targets,
            compute_loss=True,
            mode=mode,
            variable_dtype=variable_dtype,
            **position_kwargs)

        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            optimizer = mtf.optimize.AdafactorOptimizer()
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if not use_tpu:
            tf_loss = tf.Print(tf_loss,
                               [tf_loss, tf.train.get_global_step()],
                               "step, tf_loss")
        if logits and mode != tf.estimator.ModeKeys.TRAIN:
            tf_logits = lowering.export_to_tf_tensor(logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

            if mode == tf.estimator.ModeKeys.TRAIN:
                if use_tpu:
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_hooks=[restore_hook, saver_hook])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[restore_hook, saver_hook])
            elif mode == tf.estimator.ModeKeys.EVAL:

                def padded_neg_log_perplexity(logits, labels):
                    weights = tf.to_float(tf.not_equal(labels, 0))
                    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                        labels=labels, logits=logits)
                    return {
                        "neg_log_perplexity": tf.metrics.mean(-xent, weights)
                    }

                labels = lowering.export_to_tf_tensor(anon_targets)
                eval_metrics = (padded_neg_log_perplexity, [tf_logits, labels])
                return tpu_estimator.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.EVAL,
                    evaluation_hooks=[restore_hook],
                    loss=tf_loss,
                    eval_metrics=eval_metrics)
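
In the text2self branch above, the decoder inputs are just the targets shifted right by one position. A small illustration (the mesh and example values are assumed) of what mtf.shift with wrap=False produces:

# targets slice : [7, 8, 9, 10]
# inputs slice  : [0, 7, 8, 9]    (position 0 filled with zero, nothing wraps)
batch_dim = mtf.Dimension("batch", 1)
length_dim = mtf.Dimension("length", 4)
targets_demo = mtf.import_tf_tensor(
    mesh, tf.constant([[7, 8, 9, 10]]), mtf.Shape([batch_dim, length_dim]))
inputs_demo = mtf.shift(targets_demo, offset=1, dim=length_dim, wrap=False)
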
Exemplo n.º 23
def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
    """Creates and returns an optimizer training op."""
    global_step = tf.train.get_or_create_global_step()

    # set defaults
    end_step = params.get("lr_decay_end", params["train_steps"])
    lr_decay = params.get("lr_decay", "cosine")
    warmup_steps = params.get("warmup_steps", 3000)
    gradient_clipping = params.get("gradient_clipping", 1.0)
    optimizer_name = params.get("optimizer", "adam")

    learning_rate = tf.constant(value=params["lr"],
                                shape=[],
                                dtype=variable_dtype.slice_dtype)
    clip_value = mtf.constant(mesh,
                              gradient_clipping,
                              dtype=variable_dtype.slice_dtype)

    if inp_var_grads is None:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
    else:
        var_grads = inp_var_grads

    valid_grads_vars = list(
        filter(lambda grad_var: grad_var[0] is not None,
               zip(var_grads, mesh.graph.trainable_variables)))
    valid_vars = [var for grad, var in valid_grads_vars]
    valid_grad = [grad for grad, var in valid_grads_vars]

    tf.logging.info([
        v for v in zip(var_grads,
                       [v.outputs[0] for v in mesh.graph.trainable_variables])
    ])
    # Cast to full precision
    var_grads_fp = [
        mtf.cast(v, variable_dtype.slice_dtype) for v in valid_grad
    ]

    if lr_decay == "linear":
        learning_rate = tf.train.polynomial_decay(
            learning_rate,
            global_step,
            end_step,
            end_learning_rate=params["lr"] *
            0.1,  # Decrease to 10% of initial LR according to GPT-3 paper
            power=1.0,
            cycle=False)
    elif lr_decay == "cosine":
        learning_rate = tf.train.cosine_decay(
            learning_rate,
            global_step,
            end_step,
            alpha=0.1  # Alpha is min lr value as a fraction of init lr.
        )

    if warmup_steps > 0:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(warmup_steps, dtype=tf.int32)

        dtype = variable_dtype.slice_dtype

        global_steps_float = tf.cast(global_steps_int, dtype)
        warmup_steps_float = tf.cast(warmup_steps_int, dtype)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = learning_rate * warmup_percent_done

        is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)

    learning_rate = mtf.import_fully_replicated(mesh,
                                                learning_rate,
                                                mtf.Shape([]),
                                                name="learning_rate")
    scalar_summary("lr", learning_rate)

    if optimizer_name.lower() == "adam":
        optimizer = mtf.optimize.AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=params.get("weight_decay", 0.0),
            beta_1=params.get("beta_1", 0.9),
            beta_2=params.get("beta_2", 0.999),
            epsilon=params.get("epsilon", 1e-6),
            exclude_from_weight_decay=["norm", "bias"])
    elif optimizer_name.lower() == "adafactor":
        optimizer = mtf.optimize.AdafactorOptimizer(
            learning_rate=learning_rate,
            decay_rate=params.get("weight_decay", 0.0),
            beta1=params.get("beta_1", 0.9),
            epsilon1=params.get("epsilon_1", 1e-30),
            epsilon2=params.get("epsilon_2", 1e-3))
    else:
        raise ValueError(f"{optimizer_name} not recognized")

    if gradient_clipping is not None:
        (var_grads_fp, _) = clip_by_global_norm(var_grads_fp,
                                                clip_norm=clip_value)
    update_ops = optimizer.apply_grads(var_grads_fp, valid_vars)
    return learning_rate, update_ops, var_grads_fp
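
The zip/filter step above exists only to drop variables whose gradient came back as None so that apply_grads receives two lists of equal length; the same pattern with toy values:

grads = [0.1, None, 0.3]                 # None: variable not reached by the loss
variables = ["wte", "frozen_bias", "wpe"]
valid = [(g, v) for g, v in zip(grads, variables) if g is not None]
valid_grads = [g for g, _ in valid]      # [0.1, 0.3]
valid_vars = [v for _, v in valid]       # ["wte", "wpe"]
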
Exemplo n.º 24
def model_fn(features, labels, mode, params):
	"""The model_fn argument for creating an Estimator."""
	tf.logging.info("features = %s labels = %s mode = %s params=%s" %
					(features, labels, mode, params))
	global_step = tf.train.get_global_step()
	graph = mtf.Graph()
	mesh = mtf.Mesh(graph, "my_mesh")
	logits, loss = mnist_model(features, labels, mesh)
	
	variables = graph._all_variables
	tf.logging.info("[variable num]: {}".format(len(variables)))
	for v in variables:
		tf.logging.info("[variable] (name, shape): ({},{})".format(v.name,v.shape))

	mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
	layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])
	# layout_rules = {('ResidualBlock-8-filters1', 'b2'), ('ResidualBlock-3-filters3', 'b1'), ('ResidualBlockWithDown-0-filters3', 'b2'), ('ResidualBlockWithDown-1-filters2', 'b2'), ('ResidualBlock-2-filters2', 'b1'), ('ResidualBlockWithDown-2-filters2', 'b1'), ('ResidualBlock-10-filters2', 'b1'), ('ResidualBlockWithDown-3-filters3', 'b1'), ('ResidualBlock-7-filters2', 'b2'), ('ResidualBlock-0-filters2', 'b2'), ('ResidualBlock-1-filters1', 'b2'), ('ResidualBlockWithDown-2-filters1', 'b2'), ('ResidualBlockWithDown-1-filters3', 'b1'), ('ResidualBlockWithDown-0-filters1', 'b2'), ('ResidualBlock-2-filters1', 'b2'), ('ResidualBlock-6-filters3', 'b2'), ('ResidualBlock-4-filters2', 'b1'), ('ResidualBlock-1-filters3', 'b2'), ('ResidualBlock-9-filters2', 'b1'), ('ResidualBlock-11-filters2', 'b2'), ('ResidualBlock-1-filters2', 'b1'), ('ResidualBlock-5-filters1', 'b1'), ('ResidualBlock-11-filters3', 'b1'), ('classesnum', 'b2'), ('ResidualBlock-9-filters3', 'b2'), ('ResidualBlock-2-filters3', 'b2'), ('ResidualBlockWithDown-3-filters2', 'b2'), ('ResidualBlock-6-filters2', 'b1'), ('ResidualBlock-3-filters2', 'b2'), ('backbone-filters', 'b1'), ('ResidualBlock-10-filters1', 'b2'), ('ResidualBlock-8-filters2', 'b1'), ('ResidualBlock-5-filters3', 'b1'), ('ResidualBlock-4-filters1', 'b2'), ('ResidualBlock-7-filters1', 'b1'), ('ResidualBlock-7-filters3', 'b1'), ('ResidualBlock-11-filters1', 'b1'), ('ResidualBlock-10-filters3', 'b2'), ('ResidualBlockWithDown-2-filters3', 'b2'), ('ResidualBlock-5-filters2', 'b2'), ('ResidualBlock-3-filters1', 'b1'), ('ResidualBlock-0-filters1', 'b1'), ('ResidualBlock-0-filters3', 'b1'), ('ResidualBlock-6-filters1', 'b2'), ('ResidualBlock-9-filters1', 'b2'), ('ResidualBlockWithDown-0-filters2', 'b1'), ('ResidualBlockWithDown-3-filters1', 'b1'), ('ResidualBlockWithDown-1-filters1', 'b1')}
	print("[auto mtf search] strategy: {}".format(layout_rules))
	
	mesh_size = mesh_shape.size
	mesh_devices = ["gpu:0", "gpu:1","gpu:2", "gpu:3"]
	mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
		mesh_shape, layout_rules, mesh_devices)

	if mode == tf.estimator.ModeKeys.TRAIN:
		var_grads = mtf.gradients(
			[loss], [v.outputs[0] for v in graph.trainable_variables])
		optimizer = mtf.optimize.AdafactorOptimizer()
		# optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
		update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

	lowering = mtf.Lowering(graph, {mesh: mesh_impl})
	restore_hook = mtf.MtfRestoreHook(lowering)

	tf_logits = lowering.export_to_tf_tensor(logits)
	if mode != tf.estimator.ModeKeys.PREDICT:
		tf_loss = lowering.export_to_tf_tensor(loss)
		tf.summary.scalar("loss", tf_loss)

	if mode == tf.estimator.ModeKeys.TRAIN:
		tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
		tf_update_ops.append(tf.assign_add(global_step, 1))
		train_op = tf.group(tf_update_ops)
		# saver = tf.train.Saver(
		# 	tf.global_variables(),
		# 	sharded=True,
		# 	max_to_keep=10,
		# 	keep_checkpoint_every_n_hours=2,
		# 	defer_build=False, save_relative_paths=True)
		# tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
		# saver_listener = mtf.MtfCheckpointSaverListener(lowering)
		# saver_hook = tf.train.CheckpointSaverHook(
		# 	FLAGS.model_dir,
		# 	save_steps=1000,
		# 	saver=saver,
		# 	listeners=[saver_listener])

		accuracy = tf.metrics.accuracy(
			labels=labels, predictions=tf.argmax(tf_logits, axis=1))

		# Name tensors to be logged with LoggingTensorHook.
		tf.identity(tf_loss, "cross_entropy")
		tf.identity(accuracy[1], name="train_accuracy")

		logging_hook = tf.train.LoggingTensorHook(every_n_iter=100,tensors={'loss': 'cross_entropy','acc':'train_accuracy'})

		# Save accuracy scalar to Tensorboard output.
		# tf.summary.scalar("train_accuracy", accuracy[1])

		# restore_hook must come before saver_hook
		return tf.estimator.EstimatorSpec(
			tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
			training_chief_hooks=[restore_hook,logging_hook])

	if mode == tf.estimator.ModeKeys.PREDICT:
		predictions = {
			"classes": tf.argmax(tf_logits, axis=1),
			"probabilities": tf.nn.softmax(tf_logits),
		}
		return tf.estimator.EstimatorSpec(
			mode=tf.estimator.ModeKeys.PREDICT,
			predictions=predictions,
			prediction_hooks=[restore_hook],
			export_outputs={
				"classify": tf.estimator.export.PredictOutput(predictions)
			})
	if mode == tf.estimator.ModeKeys.EVAL:
		return tf.estimator.EstimatorSpec(
			mode=tf.estimator.ModeKeys.EVAL,
			loss=tf_loss,
			evaluation_hooks=[restore_hook],
			eval_metric_ops={
				"accuracy":
				tf.metrics.accuracy(
					labels=labels, predictions=tf.argmax(tf_logits, axis=1)),
			})
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator."""
    tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                    (features, labels, mode, params))
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    logits, loss = mnist_model(features, labels, mesh)
    mesh_shape = [("gpu_rows", 1), ("gpu_cols", 6)]
    #dependency on ortools, which won't build on power
    #import mesh_tensorflow.auto_mtf
    #layout_rules = mtf.auto_mtf.layout(graph, mesh_shape,[logits,loss])
    layout_rules = [("filters1", "gpu_rows"), ("filters2", "gpu_cols"),
                    ("filters3", "gpu_rows"), ("filters4", "gpu_cols"),
                    ("filters5", "gpu_rows"), ("filters6", "gpu_cols")]
    mesh_devices = ["gpu:%d" % itm for itm in range(6)]
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer()
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        tf_grads = [lowering.export_to_tf_tensor(grad) for grad in var_grads]
        tf_avg_grads = [hvd.allreduce(grad) for grad in tf_grads]
        var_avg_grads = [
            mtf.import_tf_tensor(mesh, tf_avg_grad, shape=grad.shape)
            for tf_avg_grad, grad in zip(tf_avg_grads, var_grads)
        ]
        update_ops = optimizer.apply_grads(var_avg_grads,
                                           graph.trainable_variables)
    with tf.variable_scope('', reuse=tf.AUTO_REUSE) as _:
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)
    tf_logits = lowering.export_to_tf_tensor(logits)
    if mode != tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)
        tf.summary.scalar("loss", tf_loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)
        saver = tf.train.Saver(tf.global_variables(),
                               sharded=True,
                               max_to_keep=10,
                               keep_checkpoint_every_n_hours=2,
                               defer_build=False,
                               save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook("/tmp/%s" % FLAGS.model_dir,
                                                  save_steps=1000,
                                                  saver=saver,
                                                  listeners=[saver_listener])

        accuracy = tf.metrics.accuracy(labels=labels,
                                       predictions=tf.argmax(tf_logits,
                                                             axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar("train_accuracy", accuracy[1])

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "classes": tf.argmax(tf_logits, axis=1),
            "probabilities": tf.nn.softmax(tf_logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                "classify": tf.estimator.export.PredictOutput(predictions)
            })
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={
                "accuracy":
                tf.metrics.accuracy(labels=labels,
                                    predictions=tf.argmax(tf_logits, axis=1)),
            })
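
The hvd.allreduce calls in this variant assume a standard Horovod setup around the Estimator (one process per GPU, initial weights broadcast from rank 0). A sketch of that surrounding boilerplate, with the step count and train_input_fn assumed:

import horovod.tensorflow as hvd

hvd.init()                                        # launched e.g. via horovodrun -np 6
config = tf.ConfigProto()
config.gpu_options.visible_device_list = str(hvd.local_rank())
bcast_hook = hvd.BroadcastGlobalVariablesHook(0)  # sync initial weights from rank 0
run_config = tf.estimator.RunConfig(session_config=config, model_dir=FLAGS.model_dir)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
estimator.train(input_fn=train_input_fn,
                hooks=[bcast_hook],
                max_steps=100000 // hvd.size())
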
Exemplo n.º 26
def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
    """Creates and returns an optimizer training op."""
    global_step = tf.train.get_or_create_global_step()

    learning_rate = tf.constant(value=params["lr"],
                                shape=[],
                                dtype=variable_dtype.slice_dtype)
    clip_value = mtf.constant(mesh,
                              params["gradient_clipping"],
                              dtype=variable_dtype.slice_dtype)

    if inp_var_grads is None:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
    else:
        var_grads = inp_var_grads

    # Cast to full precision
    var_grads_fp = [mtf.cast(v, variable_dtype.slice_dtype) for v in var_grads]

    # decrease LR to final lr (lr*0.1) by this step - defaults to train_steps
    end_step = params.get("lr_decay_end", params["train_steps"])

    if params["lr_decay"] == "linear":
        learning_rate = tf.train.polynomial_decay(
            learning_rate,
            global_step,
            end_step,
            end_learning_rate=params["lr"] *
            0.1,  # Decrease to 10% of initial LR according to GPT-3 paper
            power=1.0,
            cycle=False)
    elif params["lr_decay"] == "cosine":
        learning_rate = tf.train.cosine_decay(
            learning_rate,
            global_step,
            end_step,
            alpha=0.1  # Alpha is min lr value as a fraction of init lr.
        )

    if params["warmup_steps"] > 0:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(params["warmup_steps"], dtype=tf.int32)

        dtype = variable_dtype.slice_dtype

        global_steps_float = tf.cast(global_steps_int, dtype)
        warmup_steps_float = tf.cast(warmup_steps_int, dtype)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = learning_rate * warmup_percent_done

        is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)

    learning_rate = mtf.import_fully_replicated(mesh,
                                                learning_rate,
                                                mtf.Shape([]),
                                                name="learning_rate")
    mtf.scalar_summary("lr", learning_rate)

    if params["opt_name"].lower() == "adam":
        optimizer = AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=params["weight_decay"],
            beta_1=params["beta1"],
            beta_2=params["beta2"],
            epsilon=params["epsilon"],
            exclude_from_weight_decay=["norm", "bias"],
            variable_dtype=variable_dtype)
    else:
        optimizer = mtf.optimize.AdafactorOptimizer(
            learning_rate=params["lr"],
            decay_rate=params["weight_decay"],
            beta1=params["beta1"],
            epsilon1=params["ada_epsilon1"],
            epsilon2=params["ada_epsilon2"])

    if params["use_tpu"]:
        optimizer = tf.tpu.CrossShardOptimizer(optimizer)

    if params["gradient_clipping"] is not None:
        (var_grads_fp, _) = clip_by_global_norm(var_grads_fp,
                                                clip_norm=clip_value)

    update_ops = optimizer.apply_grads(var_grads_fp,
                                       mesh.graph.trainable_variables)
    return learning_rate, update_ops, var_grads_fp
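
For orientation, tf.train.cosine_decay with a minimum fraction alpha follows lr(step) = lr0 * ((1 - alpha) * 0.5 * (1 + cos(pi * step / end_step)) + alpha). A worked check with assumed values:

import math

def cosine_lr(step, lr0=2.5e-4, end_step=10000, alpha=0.1):
    step = min(step, end_step)
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / end_step))
    return lr0 * ((1.0 - alpha) * cosine + alpha)

print(cosine_lr(0), cosine_lr(5000), cosine_lr(10000))
# ≈ 0.00025  0.0001375  2.5e-05   (full lr, 55% of lr, alpha * lr)
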
Exemplo n.º 27
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator."""
    tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                    (features, labels, mode, params))
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    logits, loss = cifar_model(features, labels, mesh)
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    mesh_size = mesh_shape.size
    # To enable manual device placement (e.g.: GPU 1, 2) comment the line below and uncomment the next one
    mesh_devices = [""] * mesh_size
    # mesh_devices = ['GPU:' + str(i) for i in range(mesh_size)]
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    labels = features['label']

    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        # Variables that affect learning rate.
        num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
        decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

        # Decay the learning rate exponentially based on the number of steps.
        lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                        global_step,
                                        decay_steps,
                                        LEARNING_RATE_DECAY_FACTOR,
                                        staircase=True)
        tf.summary.scalar('learning_rate', lr)
        mtf_lr = mtf.import_tf_tensor(
            mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
        optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=mtf_lr)
        update_ops = optimizer.apply_grads(var_grads,
                                           graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_logits = lowering.export_to_tf_tensor(logits)
    if mode != tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        tf.summary.scalar("loss", tf_loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)
        saver = tf.train.Saver(tf.global_variables(),
                               sharded=True,
                               max_to_keep=10,
                               keep_checkpoint_every_n_hours=2,
                               defer_build=False,
                               save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(FLAGS.model_dir,
                                                  save_steps=1000,
                                                  saver=saver,
                                                  listeners=[saver_listener])

        accuracy = tf.metrics.accuracy(labels=labels,
                                       predictions=tf.argmax(tf_logits,
                                                             axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar("train_accuracy", accuracy[1])

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "classes": tf.argmax(tf_logits, axis=1),
            "probabilities": tf.nn.softmax(tf_logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                "classify": tf.estimator.export.PredictOutput(predictions)
            })

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={
                "accuracy":
                tf.metrics.accuracy(labels=labels,
                                    predictions=tf.argmax(tf_logits, axis=1)),
            })
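
The TRAIN branch above uses a staircase exponential schedule; with assumed constants (INITIAL_LEARNING_RATE=0.1, LEARNING_RATE_DECAY_FACTOR=0.1, decay_steps=1000) it behaves as follows:

import math

def staircase_lr(step, init_lr=0.1, decay_steps=1000, factor=0.1):
    # staircase=True: the exponent only changes at whole multiples of decay_steps
    return init_lr * factor ** math.floor(step / decay_steps)

print(staircase_lr(0), staircase_lr(999), staircase_lr(1000), staircase_lr(2500))
# ≈ 0.1  0.1  0.01  0.001
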
Exemplo n.º 28
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator."""
    tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                    (features, labels, mode, params))
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    # create a Mesh named "my_mesh" on the mtf graph
    mesh = mtf.Mesh(graph, "my_mesh")
    logits, loss = mnist_model(features, labels, mesh)
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    mesh_size = mesh_shape.size
    print("mesh_shape.size = ", mesh_shape.size)
    mesh_devices = [""] * mesh_size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer()
        update_ops = optimizer.apply_grads(var_grads,
                                           graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_logits = lowering.export_to_tf_tensor(logits)
    if mode != tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)
        tf.summary.scalar("loss", tf_loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)
        saver = tf.train.Saver(tf.global_variables(),
                               sharded=True,
                               max_to_keep=10,
                               keep_checkpoint_every_n_hours=2,
                               defer_build=False,
                               save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(FLAGS.model_dir,
                                                  save_steps=1000,
                                                  saver=saver,
                                                  listeners=[saver_listener])

        accuracy = tf.metrics.accuracy(labels=labels,
                                       predictions=tf.argmax(tf_logits,
                                                             axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar("train_accuracy", accuracy[1])

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "classes": tf.argmax(tf_logits, axis=1),
            "probabilities": tf.nn.softmax(tf_logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                "classify": tf.estimator.export.PredictOutput(predictions)
            })
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={
                "accuracy":
                tf.metrics.accuracy(labels=labels,
                                    predictions=tf.argmax(tf_logits, axis=1)),
            })
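
FLAGS.mesh_shape and FLAGS.layout are parsed by mtf.convert_to_shape and mtf.convert_to_layout_rules, which accept ';'-separated 'name:value' pairs. The flag definitions this model_fn relies on are presumably declared elsewhere along these lines (the values and dimension names are illustrative only):

from absl import flags

flags.DEFINE_string("mesh_shape", "b1:2;b2:2",
                    "mesh dimensions, e.g. a 2x2 mesh over 4 devices")
flags.DEFINE_string("layout", "hidden1:b1;hidden2:b2",
                    "tensor-dimension to mesh-dimension assignments")
flags.DEFINE_string("model_dir", "/tmp/mnist_model", "checkpoint directory")
FLAGS = flags.FLAGS
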
Exemplo n.º 29
    def _model_fn(input_fea, input_lab):
        """Creates a model, add summary, modes (train or eval), and hooks."""
        def _add_summary(lowering, train_or_eval, tf_loss, scalars,
                         global_step):
            """Add all summaries."""
            for k in scalars.keys():
                if not isinstance(scalars[k], tf.Tensor):
                    scalars[k] = tf.cast(
                        lowering.export_to_tf_tensor(scalars[k]), tf.float32)

            def _host_loss_summary(global_step, tf_loss, **scalars):
                """Add summary.scalar in host side."""
                gs = tf.cast(global_step, tf.int64)
                sum_loss = tf.contrib.summary.scalar(
                    '{}_loss'.format(train_or_eval), tf_loss, step=gs)
                sum_ops = [sum_loss.op]
                for description, tf_metric in scalars.items():
                    sum_metric = tf.contrib.summary.scalar('{}_{}'.format(
                        train_or_eval, description),
                                                           tf_metric,
                                                           step=gs)
                    sum_ops.append(sum_metric)
                with tf.control_dependencies(sum_ops):
                    return tf.identity(tf_loss)

            # Cast the global step to tf.int32, since
            # outside_compilation does not support tf.int64.
            tf_loss = tpu.outside_compilation(_host_loss_summary,
                                              tf.cast(global_step, tf.int32),
                                              tf_loss, **scalars)

            return tf_loss

        global_step = tf.train.get_or_create_global_step()
        graph = mtf.Graph()

        # Worker 0 caches all the TPU binaries.
        replica_cache_size = 300 * 1024 * 1024  # 300M per replica.
        worker0_mem = replica_cache_size * 8 * num_hosts
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)

        tf.logging.info('cpu_devices: {}, devices_mem: {}'.format(
            cpu_devices, devices_memory_usage))
        var_placer = mtf.utils.BalancedVariablePlacer(cpu_devices,
                                                      devices_memory_usage)

        mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

        mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
        layout_rules = unet.get_layout()
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                    None, d_assignment)

        with mtf.utils.outside_all_rewrites():  # Do not tpu_rewrite this part.
            preds, loss, scalars, bn_update_ops = (
                unet.unet_with_spatial_partition(mesh, train_or_eval,
                                                 input_fea, input_lab))

        if train_or_eval == 'train':
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])

            lr = FLAGS.lr * tf.pow(
                FLAGS.lr_drop_rate,
                tf.floor(
                    tf.cast(global_step, tf.float32) / FLAGS.lr_drop_steps))
            scalars['learning_rate'] = lr

            optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=lr)
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)

            # This is where the actual tf graph got built.
            lowering = mtf.Lowering(graph, {mesh: mesh_impl})

            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            tf_update_ops.extend(
                [lowering.lowered_operation(op) for op in bn_update_ops])
            tf_update_ops_group = tf.group(tf_update_ops)

        else:  # train_or_eval == 'eval':
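            # Anonymize predictions so they are unsplit (fully replicated) and
            # safe to export as tf tensors.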
            preds = [mtf.anonymize(pred) for pred in preds]

            # This is where the actual tf graph got built.
            lowering = mtf.Lowering(graph, {mesh: mesh_impl})

            tf_preds = [
                tf.cast(lowering.export_to_tf_tensor(pred), tf.float32)
                for pred in preds
            ]

        tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
        if FLAGS.write_summary:
            tf_loss = _add_summary(lowering, train_or_eval, tf_loss, scalars,
                                   global_step)
        master_to_slice_hook = mtf.MtfRestoreHook(lowering)

        if train_or_eval == 'train':
            with mtf.utils.outside_all_rewrites():
                saver = tf.train.Saver(tf.global_variables(),
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                slice_to_master_hook = tf.train.CheckpointSaverHook(
                    FLAGS.checkpoint_dir,
                    save_steps=FLAGS.save_checkpoints_steps,
                    saver=saver,
                    listeners=[saver_listener])
                captured_hooks.capture(
                    [master_to_slice_hook, slice_to_master_hook])
                return tf_update_ops_group

        else:  # train_or_eval == 'eval':
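            # Collect predictions plus loss and step; they are enqueued on the
            # TPU outfeed below and dequeued host-side using the captured
            # dtypes and shapes.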
            tf_preds.extend([tf_loss, global_step])
            tf_preds_dtypes = [tf_pred.dtype for tf_pred in tf_preds]
            tf_preds_shapes = [tf_pred.shape for tf_pred in tf_preds]
            captured_hooks.capture([master_to_slice_hook, None])
            captured_output_dtypes_shapes.capture(
                [tf_preds_dtypes, tf_preds_shapes])
            return tpu_ops.outfeed_enqueue_tuple(tf_preds)
Exemplo n.º 30
0
    def estimator_model_fn(cls,
                           hparams,
                           features,
                           labels,
                           mode,
                           config=None,
                           params=None,
                           decode_hparams=None):
        hparams = copy.deepcopy(hparams)
        use_tpu = params and params.get("use_tpu", False)
        hparams.use_tpu = use_tpu
        # merge decode_hparams into hparams if present
        if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
            for k, v in six.iteritems(decode_hparams.values()):
                if hasattr(hparams, k) and getattr(hparams, k) != v:
                    tf.logging.warning(
                        "Overriding hparams.%s with %s from decode_hparams" %
                        (k, v))
                setattr(hparams, k, v)

        # Instantiate model
        data_parallelism = None
        if not use_tpu and config:
            data_parallelism = config.data_parallelism
        model = cls(hparams,
                    mode,
                    data_parallelism=data_parallelism,
                    decode_hparams=decode_hparams)

        global_step = tf.train.get_global_step()

        mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(hparams.layout)
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
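            # Empty device strings: on TPU, placement is driven by the SPMD
            # device assignment, not by per-slice device names.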
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
        else:
            var_placer = None
            if len(data_parallelism.ps_devices) == 1:
                mesh_devices = [""] * mesh_shape.size
            else:
                assert len(data_parallelism.ps_devices) == mesh_shape.size
                mesh_devices = data_parallelism.ps_devices
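            # PlacementMeshImpl pins each mesh slice to an explicit device,
            # which is the non-TPU (CPU/GPU) execution path.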
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)
        # PREDICT mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            return model.estimator_spec_predict(features, mesh, mesh_impl,
                                                use_tpu)

        logits, loss = model.mtf_model_fn(features, mesh)
        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            lr = learning_rate.learning_rate_schedule(hparams)
            tf.summary.scalar("learning_rate", lr)
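            # Import the scalar learning rate into the mtf graph as a 0-d tensor.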
            mtf_lr = mtf.import_tf_tensor(
                mesh, tf.convert_to_tensor(lr, dtype=tf.float32),
                mtf.Shape([]))
            optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
            update_ops = []
            for grad, var in zip(var_grads, graph.trainable_variables):
                update_ops.extend(optimizer.apply_grad(grad, var))

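        # Lowering translates the mtf graph into concrete TensorFlow ops.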
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
        if logits is not None and mode != tf.estimator.ModeKeys.TRAIN:
            tf_logits = lowering.export_to_tf_tensor(logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                hparams.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

        # EVAL mode
        if mode == tf.estimator.ModeKeys.EVAL:
            tf_logits = lowering.export_to_tf_tensor(logits)
            return model.estimator_spec_eval(features, tf_logits, labels,
                                             tf_loss, restore_hook, use_tpu)

        if use_tpu:
            # TPU host call. Important: need to be called before remove_summaries()
            if hparams.tpu_enable_host_call:
                host_call = t2t_model.create_host_call(hparams.model_dir)
            else:
                host_call = None

            t2t_model.remove_summaries()
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                host_call=host_call,
                training_hooks=[restore_hook, saver_hook])
        else:
            return tf.estimator.EstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_chief_hooks=[restore_hook, saver_hook])
Exemplo n.º 31
0
def model_fn(features, labels, mode, params):
    """A model is called by TpuEstimator."""
    del labels
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, 'my_mesh')
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, mtf.convert_to_layout_rules(FLAGS.layout), mesh_devices,
        params['context'].device_assignment)
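    # Build the mtf model graph outside of any TPU graph rewrites.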
    with mtf.utils.outside_all_rewrites():
        logits, loss = toy_model(features, mesh)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer()
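        # apply_grad returns a list of assign ops per variable; collect them all.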
        update_ops = []
        for grad, var in zip(var_grads, graph.trainable_variables):
            update_ops.extend(optimizer.apply_grad(grad, var))
    else:
        # for now, we can only export fully-replicated tensors.
        fully_replicated_logits = mtf.anonymize(logits)

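    # Lower the mtf graph into TensorFlow ops for this mesh implementation.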
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = lowering.export_to_tf_tensor(loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
        train_op = tf.group(tf_update_ops)
    else:
        tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(tf_logits):
                mean_logits = tf.metrics.mean(tf_logits)
                return {'mean_logits': mean_logits}

            eval_metrics = (metric_fn, [tf_logits])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)