예제 #1
0
def test_inception3():
    # test for vgg16
    with tf.Graph().as_default():
        assert tf_model.exist_model('inception3')
        op, loss = tf_model.get_model('inception3')

    with tf.Graph().as_default():
        assert tf_model.exist_model('inception3', horovod=True)
        op, loss = tf_model.get_model('inception3', horovod=True)
예제 #2
0
def test_alexnet():
    # test for vgg16
    with tf.Graph().as_default():
        assert tf_model.exist_model('alexnet')
        op, loss = tf_model.get_model('alexnet')

    with tf.Graph().as_default():
        assert tf_model.exist_model('alexnet', horovod=True)
        op, loss = tf_model.get_model('alexnet', horovod=True)
예제 #3
0
def test_seq2seq():
    # test for seq2seq
    with tf.Graph().as_default():
        assert tf_model.exist_model('seq2seq')
        op, loss = tf_model.get_model('seq2seq')

    with tf.Graph().as_default():
        assert tf_model.exist_model('seq2seq', horovod=True)
        op, loss = tf_model.get_model('seq2seq', horovod=True)
예제 #4
0
def test_resnet50():
    # test for resnet50
    with tf.Graph().as_default():
        assert tf_model.exist_model('resnet50')
        op, loss = tf_model.get_model('resnet50')

    with tf.Graph().as_default():
        assert tf_model.exist_model('resnet50', horovod=True)
        op, loss = tf_model.get_model('resnet50', horovod=True)
예제 #5
0
def test_lstm():
    # test for lstm
    with tf.Graph().as_default():
        assert tf_model.exist_model('lstm')
        op, loss = tf_model.get_model('lstm')

    with tf.Graph().as_default():
        assert tf_model.exist_model('lstm', horovod=True)
        op, loss = tf_model.get_model('lstm', horovod=True)
예제 #6
0
def test_nasnet():
    # test for nasnet
    with tf.Graph().as_default():
        assert tf_model.exist_model('nasnet')
        op, loss = tf_model.get_model('nasnet')

    with tf.Graph().as_default():
        assert tf_model.exist_model('nasnet', horovod=True)
        op, loss = tf_model.get_model('nasnet', horovod=True)
예제 #7
0
def test_tf_model():
    # test for Error raising
    assert not tf_model.exist_model('foobar')
    with pytest.raises(ValueError):
        tf_model.get_model('foobar')