model = Sequential() model.add( Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape)) model.add(Conv2D(64, (3, 3), activation='relu')) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Dropout(0.25)) model.add(Flatten()) model.add(Dense(128, activation='relu')) model.add(Dropout(0.5)) model.add(Dense(num_classes, activation='softmax')) # BytePS: adjust learning rate based on number of GPUs. opt = keras.optimizers.Adadelta(1.0 * bps.size()) # BytePS: add BytePS Distributed Optimizer. opt = bps.DistributedOptimizer(opt) model.compile(loss=keras.losses.categorical_crossentropy, optimizer=opt, metrics=['accuracy']) callbacks = [ # BytePS: broadcast initial variable states from rank 0 to all other processes. # This is necessary to ensure consistent initialization of all workers when # training is started with random weights or restored from a checkpoint. bps.callbacks.BroadcastGlobalVariablesCallback(0), ] # BytePS: save checkpoints only on worker 0 to prevent other workers from corrupting them. if bps.rank() == 0: callbacks.append(
# Excerpt summary: sets an L2 kernel regularizer (weight `args.wd`) and the
# BatchNormalization momentum/epsilon values in a layer config, rebuilds the
# model from `model_config`, wraps an SGD optimizer (learning rate scaled by
# bps.size()) in bps.DistributedOptimizer, compiles the model, and begins a
# callbacks list.
# NOTE(review): the statements below are collapsed onto one physical line, so
# everything after the first `#` is read as a comment -- restore the line
# breaks before running.
# NOTE(review): `layer` and `layer_config` are not bound anywhere in this
# excerpt; presumably an enclosing loop defines them -- confirm, then re-indent
# the leading statements under it.
# NOTE(review): `type(layer) == ...` is an exact-type comparison; isinstance()
# would also match subclasses -- confirm which is intended.
# NOTE(review): the `callbacks = [` list is not closed within this excerpt.
regularizer = keras.regularizers.l2(args.wd) layer_config['config']['kernel_regularizer'] = \ {'class_name': regularizer.__class__.__name__, 'config': regularizer.get_config()} if type(layer) == keras.layers.BatchNormalization: layer_config['config']['momentum'] = 0.9 layer_config['config']['epsilon'] = 1e-5 model = keras.models.Model.from_config(model_config) # BytePS: adjust learning rate based on number of GPUs. opt = keras.optimizers.SGD(lr=args.base_lr * bps.size(), momentum=args.momentum) # BytePS: add BytePS Distributed Optimizer. opt = bps.DistributedOptimizer(opt, compression=compression) model.compile(loss=keras.losses.categorical_crossentropy, optimizer=opt, metrics=['accuracy', 'top_k_categorical_accuracy']) callbacks = [ # BytePS: broadcast initial variable states from rank 0 to all other processes. # This is necessary to ensure consistent initialization of all workers when # training is started with random weights or restored from a checkpoint. bps.callbacks.BroadcastGlobalVariablesCallback(0), # BytePS: average metrics among workers at the end of every epoch. # # Note: This callback must be in the list before the ReduceLROnPlateau, # TensorBoard, or other metrics-based callbacks.