def test_inception3():
    """Check that 'inception3' is registered and can be fetched, plain and with Horovod."""
    # test for inception3
    # Each variant is requested inside its own fresh default graph, so the
    # second request does not add ops to the graph used by the first.
    with tf.Graph().as_default():
        assert tf_model.exist_model('inception3')
        # Results are unused; the unpacking still checks that get_model
        # returns exactly two items (op, loss).
        _op, _loss = tf_model.get_model('inception3')
    with tf.Graph().as_default():
        assert tf_model.exist_model('inception3', horovod=True)
        _op, _loss = tf_model.get_model('inception3', horovod=True)
def test_alexnet():
    """Check that 'alexnet' is registered and can be fetched, plain and with Horovod."""
    # test for alexnet
    # Each variant is requested inside its own fresh default graph, so the
    # second request does not add ops to the graph used by the first.
    with tf.Graph().as_default():
        assert tf_model.exist_model('alexnet')
        # Results are unused; the unpacking still checks that get_model
        # returns exactly two items (op, loss).
        _op, _loss = tf_model.get_model('alexnet')
    with tf.Graph().as_default():
        assert tf_model.exist_model('alexnet', horovod=True)
        _op, _loss = tf_model.get_model('alexnet', horovod=True)
def test_seq2seq():
    """Request the 'seq2seq' model twice: once plain, once with horovod=True."""
    model_name = 'seq2seq'
    # The first pass uses the default call (no horovod kwarg); the second
    # opts into Horovod. Each pass runs under a brand-new default graph.
    for extra_kwargs in ({}, {'horovod': True}):
        with tf.Graph().as_default():
            assert tf_model.exist_model(model_name, **extra_kwargs)
            _op, _loss = tf_model.get_model(model_name, **extra_kwargs)
def test_resnet50():
    """Request the 'resnet50' model twice: once plain, once with horovod=True."""
    model_name = 'resnet50'
    # The first pass uses the default call (no horovod kwarg); the second
    # opts into Horovod. Each pass runs under a brand-new default graph.
    for extra_kwargs in ({}, {'horovod': True}):
        with tf.Graph().as_default():
            assert tf_model.exist_model(model_name, **extra_kwargs)
            _op, _loss = tf_model.get_model(model_name, **extra_kwargs)
def test_lstm():
    """Request the 'lstm' model twice: once plain, once with horovod=True."""
    model_name = 'lstm'
    # The first pass uses the default call (no horovod kwarg); the second
    # opts into Horovod. Each pass runs under a brand-new default graph.
    for extra_kwargs in ({}, {'horovod': True}):
        with tf.Graph().as_default():
            assert tf_model.exist_model(model_name, **extra_kwargs)
            _op, _loss = tf_model.get_model(model_name, **extra_kwargs)
def test_nasnet():
    """Request the 'nasnet' model twice: once plain, once with horovod=True."""
    model_name = 'nasnet'
    # The first pass uses the default call (no horovod kwarg); the second
    # opts into Horovod. Each pass runs under a brand-new default graph.
    for extra_kwargs in ({}, {'horovod': True}):
        with tf.Graph().as_default():
            assert tf_model.exist_model(model_name, **extra_kwargs)
            _op, _loss = tf_model.get_model(model_name, **extra_kwargs)
args = parser.parse_args()

# if --list, print out the list of models and stop
if args.list is True:
    models = tf_model.get_model_list(horovod=args.horovod)
    print('All of the models with horovod=%s:' % str(args.horovod))
    for model in models:
        print(' %s' % model)
    # The builtin exit() is injected by the `site` module for interactive
    # use and is absent under `python -S`; SystemExit is always available.
    raise SystemExit(0)

# check that every requested model is available
models = []
for i in args.models:
    if i == 'all':
        # 'all' replaces whatever was collected so far with the full list
        models = tf_model.get_model_list(horovod=args.horovod)
    elif tf_model.exist_model(i, horovod=args.horovod):
        models.append(i)
    else:
        raise ValueError('model: %s doesn\'t exist' % i)

# if --graph, generate the pbtxt graph of models; --graph is not combined
# with --tfprof, --timeline, or more than one session, so those are reset
if args.graph is True:
    if args.profile:
        print('the --graph cannot run together with --tfprof')
        args.profile = False
    if args.timeline:
        print('the --graph cannot run together with --timeline')
        args.timeline = False
    if args.session_num != 1:
        print('because --graph setted, session_num auto set to 1')
        # Apply what the message announces, mirroring the --tfprof and
        # --timeline branches above (previously only printed).
        args.session_num = 1
def test_tf_model():
    """Check error handling for a model name that is not registered."""
    unknown_name = 'foobar'
    # The registry must not report an unregistered name as existing...
    assert not tf_model.exist_model(unknown_name)
    # ...and requesting it must raise ValueError.
    with pytest.raises(ValueError):
        tf_model.get_model(unknown_name)