示例#1
0
def test_inception3():
    """Check that the inception3 model is registered and builds.

    The model is built twice, once with default options and once with
    ``horovod=True``. Each build happens inside its own fresh default
    ``tf.Graph`` so the two variants do not share graph state.
    """
    model_name = 'inception3'
    # Extra keyword arguments per variant: plain first, then Horovod-enabled.
    for variant_kwargs in ({}, {'horovod': True}):
        with tf.Graph().as_default():
            assert tf_model.exist_model(model_name, **variant_kwargs), \
                '%s is not registered (kwargs=%r)' % (model_name, variant_kwargs)
            # The unpacking itself checks that get_model returns an (op, loss)
            # pair; the values are not otherwise used.
            _op, _loss = tf_model.get_model(model_name, **variant_kwargs)
示例#2
0
def test_alexnet():
    """Check that the alexnet model is registered and builds.

    The model is built twice, once with default options and once with
    ``horovod=True``. Each build happens inside its own fresh default
    ``tf.Graph`` so the two variants do not share graph state.
    """
    model_name = 'alexnet'
    # Extra keyword arguments per variant: plain first, then Horovod-enabled.
    for variant_kwargs in ({}, {'horovod': True}):
        with tf.Graph().as_default():
            assert tf_model.exist_model(model_name, **variant_kwargs), \
                '%s is not registered (kwargs=%r)' % (model_name, variant_kwargs)
            # The unpacking itself checks that get_model returns an (op, loss)
            # pair; the values are not otherwise used.
            _op, _loss = tf_model.get_model(model_name, **variant_kwargs)
示例#3
0
def test_seq2seq():
    """Build seq2seq twice: with default options and with horovod=True."""
    name = 'seq2seq'
    for extra in ({}, {'horovod': True}):
        # Each variant gets its own fresh default graph.
        with tf.Graph().as_default():
            assert tf_model.exist_model(name, **extra)
            op, loss = tf_model.get_model(name, **extra)
示例#4
0
def test_resnet50():
    """Ensure resnet50 exists and builds, first plainly and then under Horovod."""

    def _check(**options):
        # A fresh default graph per build keeps the two variants independent.
        with tf.Graph().as_default():
            assert tf_model.exist_model('resnet50', **options)
            _op, _loss = tf_model.get_model('resnet50', **options)

    _check()
    _check(horovod=True)
示例#5
0
def test_lstm():
    """Verify the lstm model is available and constructible, without and with Horovod."""
    model = 'lstm'
    variants = [dict(), dict(horovod=True)]
    for opts in variants:
        graph = tf.Graph()
        with graph.as_default():
            available = tf_model.exist_model(model, **opts)
            assert available
            built_op, built_loss = tf_model.get_model(model, **opts)
示例#6
0
def test_nasnet():
    """The nasnet model must be registered and buildable, plain and with Horovod."""
    # Plain (non-Horovod) build in its own graph.
    plain_graph = tf.Graph()
    with plain_graph.as_default():
        assert tf_model.exist_model('nasnet')
        op, loss = tf_model.get_model('nasnet')

    # Horovod-enabled build in a separate graph.
    hvd_graph = tf.Graph()
    with hvd_graph.as_default():
        assert tf_model.exist_model('nasnet', horovod=True)
        op, loss = tf_model.get_model('nasnet', horovod=True)
示例#7
0
    args = parser.parse_args()

    # --list: print the models tf_model reports for the chosen horovod
    # setting, then exit without doing anything else.
    if args.list is True:
        models = tf_model.get_model_list(horovod=args.horovod)
        print('All of the models with horovod=%s:' % str(args.horovod))
        for model in models:
            print('  %s' % model)
        # NOTE(review): the builtin exit() is injected by the site module and
        # meant for interactive use; sys.exit() is the robust choice in a script.
        exit()

    # Resolve the requested model names: 'all' expands to every registered
    # model; any unknown name raises ValueError.
    # NOTE(review): 'all' replaces whatever was collected before it, while
    # names listed after 'all' are appended again, which can produce
    # duplicates. Those appends also go into the list object returned by
    # get_model_list; if that is a shared internal list, it gets mutated.
    # Confirm the intended semantics.
    models = []
    for i in args.models:
        if i == 'all':
            models = tf_model.get_model_list(horovod=args.horovod)
        elif tf_model.exist_model(i, horovod=args.horovod):
            models.append(i)
        else:
            raise ValueError('model: %s doesn\'t exist' % i)

    # if --graph, generate the pbtxt graph of models

    if args.graph is True:
        # --graph cannot be combined with profiling or timeline tracing, so
        # those options are switched off here with a notice. A session_num
        # other than 1 is also reported; the actual reset presumably happens
        # on the lines following this view.
        if args.profile:
            print('the --graph cannot run together with --tfprof')
            args.profile = False
        if args.timeline:
            print('the --graph cannot run together with --timeline')
            args.timeline = False
        if args.session_num != 1:
            print('because --graph setted, session_num auto set to 1')
示例#8
0
def test_tf_model():
    """An unknown model name is reported as missing and rejected by get_model."""
    bogus_name = 'foobar'
    assert not tf_model.exist_model(bogus_name)
    # Requesting a non-existent model must raise ValueError.
    with pytest.raises(ValueError):
        tf_model.get_model(bogus_name)