Example #1
def get_compiled_model():
    model = models_factory.get_uncompiled_model(args.network, class_num=get_class_num(), trainable=True, weights=None)
    model.summary()
    
    LR_POLICY = tf_parameter_mgr.getLearningRate()
    OPTIMIZER = tf_parameter_mgr.getOptimizer(LR_POLICY)
    model.compile(optimizer=OPTIMIZER,
                  loss='sparse_categorical_crossentropy',
                  metrics=["accuracy"])
    return model

#strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
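# If multi-worker training were enabled, the strategy above would be created
# before the model and the model built inside its scope, roughly (a sketch,
# assuming TF_CONFIG is set in each worker's environment):
#   strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
#   with strategy.scope():
#       model = get_compiled_model()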

if __name__ == '__main__':
    args = parse_args()

    train_data = data_factory.get_dataset_from_tfrecords(args.network, tf_parameter_mgr.getTrainData(), tf_parameter_mgr.getTrainBatchSize())
    test_data = data_factory.get_dataset_from_tfrecords(args.network, tf_parameter_mgr.getTestData(), tf_parameter_mgr.getTrainBatchSize())

    # Output total iterations info for deep learning insights
    epochs = tf_parameter_mgr.getMaxSteps()
    print("Total iterations: %s" % (len(list(train_data.as_numpy_iterator())) * epochs))
    
    #with strategy.scope():
    model = get_compiled_model()

    weight_file = get_init_weight_file()
    if weight_file:
        print('loading weights')
        model.load_weights(weight_file)
        
    history = model.fit(train_data, epochs=epochs, callbacks=[TFKerasMonitorCallback(test_data)])
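    # `history.history` maps each compiled metric to a per-epoch list, e.g.
    # (a sketch):
    #   print(history.history['accuracy'])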

Example #2

tf.app.flags.DEFINE_integer(
    'task_id', 0, 'Task ID of the worker/replica running the training.')

tf.app.flags.DEFINE_string(
    'train_dir', '/tmp/mnist_train', """Directory where to write event logs """
    """and checkpoint.""")

tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")

import mnist_input
import tf_parameter_mgr
import monitor_cb

max_steps = tf_parameter_mgr.getMaxSteps()
test_interval = tf_parameter_mgr.getTestInterval()
batch_size = tf_parameter_mgr.getTrainBatchSize()


def setup_distribute():
    global FLAGS
    worker_hosts = []
    ps_hosts = []
    spec = {}
    if FLAGS.worker_hosts is not None and FLAGS.worker_hosts != '':
        worker_hosts = FLAGS.worker_hosts.split(',')
        spec.update({'worker': worker_hosts})

    if FLAGS.ps_hosts is not None and FLAGS.ps_hosts != '':
        ps_hosts = FLAGS.ps_hosts.split(',')
        spec.update({'ps': ps_hosts})
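    # The listing is truncated here; a typical TF1 continuation would turn
    # `spec` into a ClusterSpec and start a Server, e.g. (a sketch;
    # FLAGS.job_name is assumed to be defined elsewhere in the example):
    #   cluster = tf.train.ClusterSpec(spec)
    #   server = tf.train.Server(cluster, job_name=FLAGS.job_name,
    #                            task_index=FLAGS.task_id)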
Example #3
def train_one_step(model, optimizer, loss_fn, images, labels):
    with tf.GradientTape() as tape:
        logits = model(images)
        loss = loss_fn(labels, logits)

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    return loss, logits, grads


def test_one_step(model, loss_fn, images, labels):
    logits = model(images)
    loss = loss_fn(labels, logits)

    return loss, logits
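
# A minimal driver for the two step functions above (a sketch; model,
# optimizer, loss_fn and train_data would be created as in the __main__
# block below):
#   for epoch in range(epochs):
#       for images, labels in train_data:
#           loss, logits, grads = train_one_step(
#               model, optimizer, loss_fn, images, labels)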


BATCH_SIZE = tf_parameter_mgr.getTrainBatchSize()

if __name__ == '__main__':
    args = parse_args()

    train_data = data_factory.get_dataset_from_tfrecords(
        args.network, tf_parameter_mgr.getTrainData(),
        tf_parameter_mgr.getTrainBatchSize())
    test_data = data_factory.get_dataset_from_tfrecords(
        args.network, tf_parameter_mgr.getTestData(),
        tf_parameter_mgr.getTrainBatchSize())

    # Output total iterations info for deep learning insights
    epochs = tf_parameter_mgr.getMaxSteps()
    print("Total iterations: %s" %
          (len(list(train_data.as_numpy_iterator())) * epochs))