Example #1
    def on_batch_end(self, batch, logs=None):
        # Broadcast only once, after the first batch, when the model's
        # variables are guaranteed to have been created.
        if self.broadcast_done:
            return

        # With a single process there is nothing to synchronize.
        if bps.size() <= 1:
            return

        with tf.device(self.device):
            if bps._executing_eagerly() and hasattr(self.model, 'variables'):
                # TensorFlow 2.0 or TensorFlow eager
                bps.broadcast_variables(self.model.variables,
                                        root_rank=self.root_rank)
                bps.broadcast_variables(self.model.optimizer.variables(),
                                        root_rank=self.root_rank)
            else:
                # Graph mode: build the broadcast op and run it in the session.
                bcast_op = bps.broadcast_global_variables(self.root_rank)
                self.backend.get_session().run(bcast_op)

        self.broadcast_done = True
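
A minimal usage sketch, assuming this method lives on a Keras callback class (the BroadcastGlobalVariablesCallback name and its arguments are illustrative, modeled on the Horovod-style API that BytePS mirrors; see the constructor sketch after Example #5):

import byteps.tensorflow as bps

bps.init()
# Register the callback so variables are synchronized after the first batch.
callback = BroadcastGlobalVariablesCallback(root_rank=0)
model.fit(x_train, y_train, batch_size=32, callbacks=[callback])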
Example #2
def broadcast_global_variables(backend, root_rank):
    # Broadcast all global variables from root_rank to every other process.
    return _eval(backend, bps.broadcast_global_variables(root_rank))
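
This relies on an _eval helper from the surrounding module, not shown in the excerpt. A plausible sketch of it, assuming Horovod-style backend utilities:

def _eval(backend, op_or_result):
    # In eager mode the broadcast has already executed and its result can be
    # returned directly; in graph mode the op must be run through the session.
    if tf.executing_eagerly():
        return op_or_result
    return backend.get_session().run(op_or_result)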
Example #3
    # `config` and `args` are defined earlier in the full script; the
    # indentation shows this call sits inside an enclosing block there.
    tf.enable_eager_execution(config)

# Set up standard model.
# Check https://github.com/keras-team/keras-applications for all supported models, e.g., ResNet50, VGG16
model = getattr(applications, args.model)(weights=None)

opt = tf.train.GradientDescentOptimizer(0.01)

# BytePS: (optional) compression algorithm.
compression = bps.Compression.fp16 if args.fp16_pushpull else bps.Compression.none

# BytePS: wrap optimizer with DistributedOptimizer.
opt = bps.DistributedOptimizer(opt, compression=compression)

init = tf.global_variables_initializer()
bcast_op = bps.broadcast_global_variables(0)

data = tf.random_uniform([args.batch_size, 224, 224, 3])
target = tf.random_uniform([args.batch_size, 1], minval=0, maxval=999, dtype=tf.int64)


def loss_function():
    logits = model(data, training=True)
    return tf.losses.sparse_softmax_cross_entropy(target, logits)


def log(s, nl=True):
    # Print only from rank 0 to avoid duplicated output across workers.
    if bps.rank() != 0:
        return
    print(s, end='\n' if nl else '')
    sys.stdout.flush()
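
A sketch of the graph-mode training loop these pieces feed into; the session block and the step count are assumptions, not part of the excerpt:

loss = loss_function()
train_op = opt.minimize(loss)

with tf.Session(config=config) as session:
    session.run(init)      # initialize variables on every worker first
    session.run(bcast_op)  # then overwrite them with rank 0's values
    for step in range(10):
        loss_val, _ = session.run([loss, train_op])
        log('step %d: loss = %.4f' % (step, loss_val))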
Example #4
def broadcast_global_variables(backend, root_rank):
    # Graph-mode variant: build the broadcast op and run it in the session.
    bcast_op = bps.broadcast_global_variables(root_rank)
    return backend.get_session().run(bcast_op)
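
A usage sketch, assuming the TF 1.x Keras backend module (which provides get_session) is what gets passed in:

from tensorflow.keras import backend as K

bps.init()
broadcast_global_variables(K, root_rank=0)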
Example #5
    def on_train_begin(self, logs=None):
        # Synchronize the initial variable state from the root rank before
        # training starts; a no-op when running with a single process.
        if bps.size() <= 1:
            return
        with tf.device(self.device):
            bcast_op = bps.broadcast_global_variables(self.root_rank)
            self.backend.get_session().run(bcast_op)
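
A sketch of the enclosing callback's constructor that would supply the attributes used above (class and attribute layout are assumptions modeled on Horovod's equivalent callback, not part of the excerpt):

class BroadcastGlobalVariablesCallback(tf.keras.callbacks.Callback):
    def __init__(self, root_rank, device=''):
        super(BroadcastGlobalVariablesCallback, self).__init__()
        self.backend = tf.keras.backend  # used for get_session() in graph mode
        self.root_rank = root_rank       # rank whose variable values win
        self.device = device             # device context for the broadcast ops
        self.broadcast_done = False      # consumed by on_batch_end in Example #1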