Example #1
def train_model(FLAGS):
    # data flow generator
    train_sequence, validation_sequence = data_flow(FLAGS.data_local, FLAGS.batch_size,
                                                    FLAGS.num_classes, FLAGS.input_size)

    optimizer = adam(lr=FLAGS.learning_rate)  #, clipnorm=0.001)
    reduce_on_plateau = ReduceLROnPlateau(monitor="val_acc", mode="max", factor=0.1, patience=10, verbose=1)
    objective = 'categorical_crossentropy'
    metrics = ['accuracy']
    model = model_fn(FLAGS, objective, optimizer, metrics)
    # if FLAGS.restore_model_path != '' and file.exists(FLAGS.restore_model_path):
    #     if FLAGS.restore_model_path.startswith('s3://'):
    #         restore_model_name = FLAGS.restore_model_path.rsplit('/', 1)[1]
    #         file.copy(FLAGS.restore_model_path, '/cache/tmp/' + restore_model_name)
    #         model.load_weights('/cache/tmp/' + restore_model_name)
    #         os.remove('/cache/tmp/' + restore_model_name)
    #     else:
    #         model.load_weights(FLAGS.restore_model_path)
    if not os.path.exists(FLAGS.train_local):
        os.makedirs(FLAGS.train_local)
    tensorBoard = TensorBoard(log_dir=FLAGS.train_local)
    history = LossHistory(FLAGS)
    model.fit_generator(
        train_sequence,
        steps_per_epoch=len(train_sequence),
        epochs=FLAGS.max_epochs,
        verbose=2,
        callbacks=[history, tensorBoard, reduce_on_plateau],
        validation_data=validation_sequence,
        max_queue_size=10,
        workers=int(multiprocessing.cpu_count() * 0.7),
        use_multiprocessing=True,
        shuffle=True
    )

    print('training done!')

    if FLAGS.deploy_script_path != '':
        from save_model import save_pb_model
        save_pb_model(FLAGS, model)

    if FLAGS.test_data_url != '':
        print('test dataset predicting...')
        from eval import load_test_data
        img_names, test_data, test_labels = load_test_data(FLAGS)
        predictions = model.predict(test_data, verbose=0)

        right_count = 0
        for index, pred in enumerate(predictions):
            predict_label = np.argmax(pred, axis=0)
            test_label = test_labels[index]
            if predict_label == test_label:
                right_count += 1
        accuracy = right_count / len(img_names)
        print('accuracy: %0.4f' % accuracy)
        metric_file_name = os.path.join(FLAGS.train_local, 'metric.json')
        metric_file_content = '{"total_metric": {"total_metric_values": {"accuracy": %0.4f}}}' % accuracy
        with open(metric_file_name, "w") as f:
            f.write(metric_file_content + '\n')
    print('end')
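
The per-sample accuracy loop above can be collapsed into a single vectorized comparison. A minimal sketch, assuming predictions has shape (n_samples, num_classes) and test_labels holds integer class ids:

import numpy as np

# Vectorized equivalent of the loop above: argmax over the class axis
# for all samples at once, then compare against the ground-truth labels.
predict_labels = np.argmax(predictions, axis=1)
accuracy = np.mean(predict_labels == np.asarray(test_labels))
print('accuracy: %0.4f' % accuracy)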
Example #2
def train_model(FLAGS):
    # data flow generator
    train_sequence, validation_sequence = data_flow(FLAGS.data_local,
                                                    FLAGS.batch_size,
                                                    FLAGS.num_classes,
                                                    FLAGS.input_size)

    optimizer = adam(lr=FLAGS.learning_rate, clipnorm=0.001)
    objective = 'binary_crossentropy'
    metrics = ['accuracy']
    model = model_fn(FLAGS, objective, optimizer, metrics)
    if FLAGS.restore_model_path != '' and file.exists(
            FLAGS.restore_model_path):
        if FLAGS.restore_model_path.startswith('s3://'):
            restore_model_name = FLAGS.restore_model_path.rsplit('/', 1)[1]
            file.copy(FLAGS.restore_model_path,
                      '/cache/tmp/' + restore_model_name)
            model.load_weights('/cache/tmp/' + restore_model_name)
            os.remove('/cache/tmp/' + restore_model_name)
        else:
            model.load_weights(FLAGS.restore_model_path)
    if not os.path.exists(FLAGS.train_local):
        os.makedirs(FLAGS.train_local)
    tensorBoard = TensorBoard(log_dir=FLAGS.train_local)
    history = LossHistory(FLAGS)
    # STEP_SIZE_TRAIN = train_sequence.n
    # STEP_SIZE_VALID = validation_sequence.n
    model.fit_generator(
        train_sequence,
        steps_per_epoch=len(train_sequence),
        epochs=FLAGS.max_epochs,
        verbose=1,
        callbacks=[history, tensorBoard],
        # validation_steps=STEP_SIZE_VALID,
        validation_data=validation_sequence,
        max_queue_size=10,
        workers=int(multiprocessing.cpu_count() * 0.7),
        use_multiprocessing=True,
        shuffle=True)
    # count=train_sequence.get_count()
    # for n in range(FLAGS.max_epochs):
    #     for i in range(count):
    #         batch_train_x,batch_train_y = train_sequence.next_batch()
    #         batch_val_x, batch_val_y=validation_sequence.next_batch()
    #         model.fit(x=batch_train_x,
    #               y=batch_train_y,
    #             verbose=1,
    #             callbacks=[history, tensorBoard],
    #             # validation_steps=STEP_SIZE_VALID,
    #             validation_data=(batch_val_x, batch_val_y),
    #             shuffle=True)
    print('training done!')

    if FLAGS.deploy_script_path != '':
        from save_model import save_pb_model
        save_pb_model(FLAGS, model)

    if FLAGS.test_data_url != '':
        print('test dataset predicting...')
        from eval import load_test_data
        img_names, test_data, test_labels = load_test_data(FLAGS)
        predictions = model.predict(test_data, verbose=0)

        right_count = 0
        for index, pred in enumerate(predictions):
            predict_label = np.argmax(pred, axis=0)
            test_label = test_labels[index]
            if predict_label == test_label:
                right_count += 1
        accuracy = right_count / len(img_names)
        print('accuracy: %0.4f' % accuracy)
        metric_file_name = os.path.join(FLAGS.train_local, 'metric.json')
        metric_file_content = '{"total_metric": {"total_metric_values": {"accuracy": %0.4f}}}' % accuracy
        with open(metric_file_name, "w") as f:
            f.write(metric_file_content + '\n')
    print('end')
Example #3
def train_model(FLAGS):
    preprocess_input = efn.preprocess_input

    train_sequence, validation_sequence = data_flow(FLAGS.data_local,
                                                    FLAGS.batch_size,
                                                    FLAGS.num_classes,
                                                    FLAGS.input_size,
                                                    preprocess_input)

    optimizer = Adam(lr=FLAGS.learning_rate)

    objective = 'categorical_crossentropy'
    metrics = ['accuracy']

    model = model_fn(FLAGS, objective, optimizer, metrics)

    if FLAGS.restore_model_path != '' and os.path.exists(
            FLAGS.restore_model_path):
        model.load_weights(FLAGS.restore_model_path)
        print("LOAD OK!!!")

    if not os.path.exists(FLAGS.save_model_local):
        os.makedirs(FLAGS.save_model_local)

    log_local = './log_file/'

    tensorBoard = TensorBoard(log_dir=log_local)
    # reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, mode='auto')

    sample_count = len(train_sequence) * FLAGS.batch_size
    epochs = FLAGS.max_epochs
    warmup_epoch = 5
    batch_size = FLAGS.batch_size
    learning_rate_base = FLAGS.learning_rate
    total_steps = int(epochs * sample_count / batch_size)
    warmup_steps = int(warmup_epoch * sample_count / batch_size)

    warm_up_lr = WarmUpCosineDecayScheduler(
        learning_rate_base=learning_rate_base,
        total_steps=total_steps,
        warmup_learning_rate=0,
        warmup_steps=warmup_steps,
        hold_base_rate_steps=0,
    )

    cbk = Mycbk(FLAGS)
    model.fit_generator(train_sequence,
                        steps_per_epoch=len(train_sequence),
                        epochs=FLAGS.max_epochs,
                        verbose=1,
                        callbacks=[cbk, tensorBoard, warm_up_lr],
                        validation_data=validation_sequence,
                        max_queue_size=10,
                        workers=int(multiprocessing.cpu_count() * 0.7),
                        use_multiprocessing=True,
                        shuffle=True)

    print('training done!')

    from save_model import save_pb_model
    save_pb_model(FLAGS, model)

    if FLAGS.test_data_local != '':
        print('test dataset predicting...')
        from eval import load_test_data
        img_names, test_data, test_labels = load_test_data(FLAGS)
        test_data = preprocess_input(test_data)
        predictions = model.predict(test_data, verbose=0)

        right_count = 0
        for index, pred in enumerate(predictions):
            predict_label = np.argmax(pred, axis=0)
            test_label = test_labels[index]
            if predict_label == test_label:
                right_count += 1
        accuracy = right_count / len(img_names)
        print('accuracy: %0.4f' % accuracy)
        metric_file_name = os.path.join(FLAGS.save_model_local, 'metric.json')
        metric_file_content = '{"total_metric": {"total_metric_values": {"accuracy": %0.4f}}}' % accuracy
        with open(metric_file_name, "w") as f:
            f.write(metric_file_content + '\n')
    print('end')
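
WarmUpCosineDecayScheduler is not a built-in Keras callback; it is usually a custom Callback adapted from the common cosine-decay-with-warmup recipe. A minimal sketch of what the version used above presumably does (the constructor arguments match the call sites in these examples; the actual implementation may differ):

import numpy as np
from keras import backend as K
from keras.callbacks import Callback

class WarmUpCosineDecayScheduler(Callback):
    """Linear warmup followed by cosine decay, updated once per batch."""

    def __init__(self, learning_rate_base, total_steps,
                 warmup_learning_rate=0.0, warmup_steps=0,
                 hold_base_rate_steps=0):
        super(WarmUpCosineDecayScheduler, self).__init__()
        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
        self.hold_base_rate_steps = hold_base_rate_steps
        self.global_step = 0

    def on_batch_begin(self, batch, logs=None):
        self.global_step += 1
        if self.global_step < self.warmup_steps:
            # Linear ramp from warmup_learning_rate up to learning_rate_base.
            slope = (self.learning_rate_base -
                     self.warmup_learning_rate) / self.warmup_steps
            lr = self.warmup_learning_rate + slope * self.global_step
        elif self.global_step < self.warmup_steps + self.hold_base_rate_steps:
            # Optionally hold the base rate for a while after warmup.
            lr = self.learning_rate_base
        else:
            # Cosine decay from learning_rate_base down to 0 over the
            # remaining steps.
            decay_steps = float(self.total_steps - self.warmup_steps -
                                self.hold_base_rate_steps)
            progress = (self.global_step - self.warmup_steps -
                        self.hold_base_rate_steps) / decay_steps
            lr = 0.5 * self.learning_rate_base * (1 + np.cos(np.pi * progress))
        K.set_value(self.model.optimizer.lr, max(lr, 0.0))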
Example #4
def train_model(FLAGS):
    start_time = datetime.now()
    # data flow generator
    train_data_dir_list = list(FLAGS.data_local.split(','))
    train_sequence, validation_sequence = data_flow(train_data_dir_list,
                                                    FLAGS.batch_size,
                                                    FLAGS.num_classes,
                                                    FLAGS.input_size)

    # dir_list = list(FLAGS.data_local.split(','))
    # train_generator, validation_generator = get_tran_val(dir_list[0], dir_list[1], FLAGS.input_size, FLAGS.batch_size)

    reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                  factor=0.2,
                                  patience=1,
                                  mode='auto',
                                  min_lr=1e-16)

    optimizer = adam(lr=FLAGS.learning_rate, clipnorm=0.0005)
    objective = 'categorical_crossentropy'
    metrics = ['accuracy']
    model = multimodel(FLAGS, objective, optimizer, metrics)

    if not os.path.exists(FLAGS.train_local):
        os.makedirs(FLAGS.train_local)
    tensorBoard = TensorBoard(log_dir=FLAGS.train_local)
    history = LossHistory(FLAGS, train_data_dir_list, model)

    model.fit_generator(train_sequence,
                        steps_per_epoch=len(train_sequence),
                        epochs=FLAGS.max_epochs,
                        verbose=1,
                        callbacks=[
                            history, tensorBoard, reduce_lr,
                            EarlyStopping(monitor='val_acc',
                                          patience=3,
                                          restore_best_weights=True)
                        ],
                        validation_data=validation_sequence,
                        max_queue_size=10,
                        workers=int(multiprocessing.cpu_count() * 0.7),
                        use_multiprocessing=True,
                        shuffle=True)
    # model.fit_generator(
    #     train_sequence,
    #     steps_per_epoch=(12621 // FLAGS.batch_size) + 1,
    #     epochs=FLAGS.max_epochs,
    #     verbose=1,
    #     callbacks=[history, reduce_lr, tensorBoard],
    #     validation_data=validation_sequence,
    #     validation_steps=(3156 // FLAGS.batch_size) + 1,
    # )

    print('training done!')
    if FLAGS.deploy_script_path != '':
        save_pb_model(FLAGS, model)

    end_time = datetime.now()
    cost_seconds = int((end_time - start_time).total_seconds())
    print('=' * 70)
    print('Cost time: {}:{}:{}\n'.format(cost_seconds // 3600,
                                         (cost_seconds % 3600) // 60,
                                         cost_seconds % 60))

    with open(os.path.join(FLAGS.train_local, 'train_details.txt'), 'w') as f:
        rank = sorted(train_details.items(),
                      key=lambda x: int(x[0].split('_')[1]),
                      reverse=False)
        f.write('epoch order\n')
        for item in rank:
            f.write('{}: acc: {:.7f}  val_acc: {:.7f}\n'.format(
                item[0], item[1][0], item[1][1]))

        f.write('=' * 70 + '\n')
        f.write('val_acc order\n')
        rank = sorted(train_details.items(),
                      key=lambda x: x[1][1],
                      reverse=True)
        for item in rank:
            f.write('{}: acc: {:.7f}  val_acc: {:.7f}\n'.format(
                item[0], item[1][0], item[1][1]))

    test_model(FLAGS, model)
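
train_details is not defined in this snippet; judging by the sort keys, it is presumably a module-level dict populated by the LossHistory callback. A hypothetical example of the shape the sorting code expects:

# Hypothetical structure implied by the sort keys above:
# keys named 'epoch_<n>', values as (acc, val_acc) pairs.
train_details = {
    'epoch_1': (0.8123456, 0.7901234),
    'epoch_2': (0.8456789, 0.8012345),
    'epoch_10': (0.9012345, 0.8234567),
}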
Example #5
def train_model(FLAGS):
    # data flow generator
    train_sequence, validation_sequence = data_flow(FLAGS.data_local,
                                                    FLAGS.batch_size,
                                                    FLAGS.num_classes,
                                                    FLAGS.input_size)

    optimizer = adam(lr=FLAGS.learning_rate)  #, clipnorm=0.001)
    reduce_on_plateau = ReduceLROnPlateau(monitor="val_acc",
                                          mode="max",
                                          factor=0.1,
                                          patience=10,
                                          verbose=1)
    objective = 'categorical_crossentropy'
    metrics = ['accuracy']
    model = model_fn(FLAGS, objective, optimizer, metrics)
    if not os.path.exists(FLAGS.train_local):
        os.makedirs(FLAGS.train_local)
    tensorBoard = TensorBoard(log_dir=FLAGS.train_local)
    history = LossHistory(FLAGS)
    model.fit_generator(train_sequence,
                        steps_per_epoch=len(train_sequence),
                        epochs=FLAGS.max_epochs,
                        verbose=2,
                        callbacks=[history, tensorBoard, reduce_on_plateau],
                        validation_data=validation_sequence,
                        max_queue_size=10,
                        workers=int(multiprocessing.cpu_count() * 0.7),
                        use_multiprocessing=True,
                        shuffle=True)

    print('training done!')

    model.load_weights(history.best_file)

    if FLAGS.deploy_script_path != '':
        from save_model import save_pb_model
        save_pb_model(FLAGS, model)

    labels = []
    logits = []

    for i in range(len(validation_sequence)):
        test_data, test_label = validation_sequence[i]
        predictions = model.predict(test_data, verbose=0)
        labels.extend(np.argmax(test_label, axis=1))
        logits.extend(np.argmax(predictions, axis=1))

    labels = np.array(labels)
    logits = np.array(logits)

    accuracy = np.sum((labels - logits) == 0) / labels.size
    print('accuracy: %0.4f' % accuracy)

    result = []

    for i in range(FLAGS.num_classes):
        result.append(
            np.sum(((labels - logits) == 0) * (labels == i)) /
            np.sum(labels == i))

    with open('result.json', 'w') as fp:
        json.dump([result, labels.tolist(), logits.tolist()], fp)
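
The per-class loop above computes per-class recall (correct predictions divided by the number of true samples of each class). A minimal vectorized sketch of the same computation via a confusion matrix, assuming integer labels in [0, FLAGS.num_classes):

import numpy as np

# Build a confusion matrix: rows are true labels, columns are predictions.
num_classes = FLAGS.num_classes
confusion = np.zeros((num_classes, num_classes), dtype=np.int64)
np.add.at(confusion, (labels, logits), 1)

# Per-class recall: diagonal (correct) over row sums (true count per class).
per_class_recall = confusion.diagonal() / confusion.sum(axis=1)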
Example #6
def train_model(FLAGS):
    # data flow generator
    train_sequence, validation_sequence = data_flow(FLAGS.data_local,
                                                    FLAGS.batch_size,
                                                    FLAGS.num_classes,
                                                    FLAGS.input_size)

    # optimizer = adam(lr=FLAGS.learning_rate, clipnorm=0.001)
    optimizer = Nadam(lr=FLAGS.learning_rate,
                      beta_1=0.9,
                      beta_2=0.999,
                      epsilon=1e-08,
                      schedule_decay=0.004)
    # optimizer = SGD(lr=FLAGS.learning_rate, momentum=0.9)
    objective = 'categorical_crossentropy'
    metrics = ['accuracy']
    #model = model_fn(FLAGS, objective, optimizer, metrics)
    #model = model_fn_SE_ResNet50(FLAGS, objective, optimizer, metrics)
    model = model_fn_Xception(FLAGS, objective, optimizer, metrics)

    if FLAGS.restore_model_path != '' and os.path.exists(
            FLAGS.restore_model_path):
        if FLAGS.restore_model_path.startswith('s3://'):
            restore_model_name = FLAGS.restore_model_path.rsplit('/', 1)[1]
            shutil.copyfile(FLAGS.restore_model_path,
                            '/cache/tmp/' + restore_model_name)
            model.load_weights('/cache/tmp/' + restore_model_name)
            os.remove('/cache/tmp/' + restore_model_name)
        else:
            model.load_weights(FLAGS.restore_model_path)
        print("LOAD OK!!!")
    if not os.path.exists(FLAGS.train_local):
        os.makedirs(FLAGS.train_local)

    log_local = '../log_file/'
    tensorBoard = TensorBoard(log_dir=log_local)
    # reduce_lr = ks.callbacks.ReduceLROnPlateau(monitor='val_acc', factor=0.5, verbose=1, patience=1,
    #                                            min_lr=1e-7)
    # cosine-annealing learning rate schedule with warmup
    sample_count = len(train_sequence) * FLAGS.batch_size
    epochs = FLAGS.max_epochs
    warmup_epoch = 5
    batch_size = FLAGS.batch_size
    learning_rate_base = FLAGS.learning_rate
    total_steps = int(epochs * sample_count / batch_size)
    warmup_steps = int(warmup_epoch * sample_count / batch_size)

    warm_up_lr = WarmUpCosineDecayScheduler(
        learning_rate_base=learning_rate_base,
        total_steps=total_steps,
        warmup_learning_rate=0,
        warmup_steps=warmup_steps,
        hold_base_rate_steps=0,
    )
    history = LossHistory(FLAGS)
    model.fit_generator(train_sequence,
                        steps_per_epoch=len(train_sequence),
                        epochs=FLAGS.max_epochs,
                        verbose=1,
                        callbacks=[history, tensorBoard, warm_up_lr],
                        validation_data=validation_sequence,
                        max_queue_size=10,
                        workers=int(multiprocessing.cpu_count() * 0.7),
                        use_multiprocessing=True,
                        shuffle=True)

    print('training done!')

    if FLAGS.deploy_script_path != '':
        from save_model import save_pb_model
        save_pb_model(FLAGS, model)

    if FLAGS.test_data_url != '':
        print('test dataset predicting...')
        from eval import load_test_data
        img_names, test_data, test_labels = load_test_data(FLAGS)
        test_data = preprocess_input(test_data)
        predictions = model.predict(test_data, verbose=0)

        right_count = 0
        for index, pred in enumerate(predictions):
            predict_label = np.argmax(pred, axis=0)
            test_label = test_labels[index]
            if predict_label == test_label:
                right_count += 1
        accuracy = right_count / len(img_names)
        print('accuracy: %0.4f' % accuracy)
        metric_file_name = os.path.join(FLAGS.train_local, 'metric.json')
        metric_file_content = '{"total_metric": {"total_metric_values": {"accuracy": %0.4f}}}' % accuracy
        with open(metric_file_name, "w") as f:
            f.write(metric_file_content + '\n')
    print('end')
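
Several of these examples build metric.json by hand with a %-formatted string; serializing through the json module avoids escaping mistakes and is a drop-in replacement. A minimal sketch, using the same FLAGS.train_local and accuracy as above:

import json
import os

metric_file_name = os.path.join(FLAGS.train_local, 'metric.json')
metric = {'total_metric': {'total_metric_values': {'accuracy': round(float(accuracy), 4)}}}
with open(metric_file_name, 'w') as f:
    json.dump(metric, f)
    f.write('\n')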
Example #7
def train_model(FLAGS):
    # data flow generator
    train_sequence, validation_sequence = data_flow(FLAGS.data_local,
                                                    FLAGS.batch_size,
                                                    FLAGS.num_classes,
                                                    FLAGS.input_size)

    optimizer = adam(lr=FLAGS.learning_rate, decay=1e-6, clipnorm=0.001)
    objective = 'categorical_crossentropy'
    metrics = ['accuracy']
    model = model_fn(FLAGS, objective, optimizer, metrics)
    if FLAGS.restore_model_path != '' and mox.file.exists(
            FLAGS.restore_model_path):
        if FLAGS.restore_model_path.startswith('s3://'):
            restore_model_name = FLAGS.restore_model_path.rsplit('/', 1)[1]
            mox.file.copy(FLAGS.restore_model_path,
                          '/cache/tmp/' + restore_model_name)
            model.load_weights('/cache/tmp/' + restore_model_name)
            os.remove('/cache/tmp/' + restore_model_name)
        else:
            model.load_weights(FLAGS.restore_model_path)
        print('restore parameters from %s success' % FLAGS.restore_model_path)

    if not os.path.exists(FLAGS.train_local):
        os.makedirs(FLAGS.train_local)
    tensorboard = TensorBoard(log_dir=FLAGS.train_local,
                              batch_size=FLAGS.batch_size)
    early_stopping = EarlyStopping(monitor='val_loss', patience=4, verbose=2)
    history = LossHistory(FLAGS)
    model.fit_generator(train_sequence,
                        steps_per_epoch=len(train_sequence),
                        epochs=FLAGS.max_epochs,
                        verbose=1,
                        callbacks=[history, tensorboard, early_stopping],
                        validation_data=validation_sequence,
                        max_queue_size=10,
                        workers=int(multiprocessing.cpu_count() * 0.7),
                        use_multiprocessing=True,
                        shuffle=True)

    print('training done!')

    # Copy the training logs to OBS so the run can be monitored with the
    # TensorBoard built into ModelArts training jobs
    if FLAGS.train_url.startswith('s3://'):
        files = mox.file.list_directory(FLAGS.train_local)
        for file_name in files:
            if file_name.startswith('events'):
                mox.file.copy(os.path.join(FLAGS.train_local, file_name),
                              os.path.join(FLAGS.train_url, file_name))
        print('save events log file to OBS path: ', FLAGS.train_url)

    pb_save_dir_local = ''
    if FLAGS.deploy_script_path != '':
        from save_model import save_pb_model
        # By default the latest model is saved as a pb model; you can run
        # python run.py --mode=save_pb ... to convert a specified h5 model to pb
        pb_save_dir_local = save_pb_model(FLAGS, model)

    if FLAGS.deploy_script_path != '' and FLAGS.test_data_url != '':
        print('test dataset predicting...')
        from inference import infer_on_dataset
        accuracy, result_file_path = infer_on_dataset(
            FLAGS.test_data_local, FLAGS.test_data_local,
            os.path.join(pb_save_dir_local, 'model'))
        if accuracy is not None:
            metric_file_name = os.path.join(FLAGS.train_url, 'metric.json')
            metric_file_content = '{"total_metric": {"total_metric_values": {"accuracy": %0.4f}}}' % accuracy
            with mox.file.File(metric_file_name, "w") as f:
                f.write(metric_file_content + '\n')
            if FLAGS.train_url.startswith('s3://'):
                result_file_path_obs = os.path.join(
                    FLAGS.train_url, 'model',
                    os.path.basename(result_file_path))
                mox.file.copy(result_file_path, result_file_path_obs)
                print('accuracy result file has been copied to %s' %
                      result_file_path_obs)
        else:
            print('accuracy is None')
    print('end')
Example #8
def train_model(FLAGS):
    # data flow generator
    train_sequence, validation_sequence = data_flow(
        FLAGS.data_local, FLAGS.batch_size, FLAGS.num_classes,
        FLAGS.input_size)  # images are resized to a 224*224 square

    optimizer = Adam(lr=FLAGS.learning_rate, epsilon=10e-8)
    objective = 'categorical_crossentropy'
    metrics = ['accuracy']
    model = model_fn(FLAGS, objective, optimizer, metrics)
    if FLAGS.restore_model_path != '' and mox.file.exists(
            FLAGS.restore_model_path):
        if FLAGS.restore_model_path.startswith('s3://'):
            restore_model_name = FLAGS.restore_model_path.rsplit('/', 1)[1]
            mox.file.copy(FLAGS.restore_model_path,
                          '/cache/tmp/' + restore_model_name)
            model.load_weights('/cache/tmp/' + restore_model_name)
            os.remove('/cache/tmp/' + restore_model_name)
        else:
            model.load_weights(FLAGS.restore_model_path)
        print('restore parameters from %s success' % FLAGS.restore_model_path)

    if not os.path.exists(FLAGS.train_local):
        os.makedirs(FLAGS.train_local)
    tensorboard = TensorBoard(log_dir=FLAGS.train_local,
                              batch_size=FLAGS.batch_size)
    history = LossHistory(FLAGS)
    model.fit_generator(
        train_sequence,  # generator / Sequence yielding the training batches
        steps_per_epoch=len(train_sequence),  # integer; an epoch ends once the generator has been drawn this many times
        epochs=FLAGS.max_epochs,  # integer; number of passes over the data
        verbose=1,  # logging: 0 = silent, 1 = progress bar, 2 = one line per epoch
        callbacks=[history, tensorboard],  # callback instances applied during training
        validation_data=validation_sequence,  # generator / Sequence for the validation set
        max_queue_size=10,  # maximum size of the generator queue
        workers=int(multiprocessing.cpu_count() * 0.7),  # maximum number of worker processes
        use_multiprocessing=True,  # use process-based threading
        shuffle=True  # shuffle the order of the batches at the start of each epoch
    )

    print('training done!')

    # Copy the training logs to OBS so the run can be monitored with the
    # TensorBoard built into ModelArts training jobs
    if FLAGS.train_url.startswith('s3://'):
        files = mox.file.list_directory(FLAGS.train_local)
        for file_name in files:
            if file_name.startswith('events'):
                mox.file.copy(os.path.join(FLAGS.train_local, file_name),
                              os.path.join(FLAGS.train_url, file_name))
        print('save events log file to OBS path: ', FLAGS.train_url)

    pb_save_dir_local = ''
    if FLAGS.deploy_script_path != '':
        from save_model import save_pb_model
        # By default the latest model is saved as a pb model; you can run
        # python run.py --mode=save_pb ... to convert a specified h5 model to pb
        pb_save_dir_local = save_pb_model(FLAGS, model)

    if FLAGS.deploy_script_path != '' and FLAGS.test_data_url != '':
        print('test dataset predicting...')
        from inference import infer_on_dataset
        accuracy, result_file_path = infer_on_dataset(
            FLAGS.test_data_local, FLAGS.test_data_local,
            os.path.join(pb_save_dir_local, 'model'))
        if accuracy is not None:
            metric_file_name = os.path.join(FLAGS.train_url, 'metric.json')
            metric_file_content = '{"total_metric": {"total_metric_values": {"accuracy": %0.4f}}}' % accuracy
            with mox.file.File(metric_file_name, "w") as f:
                f.write(metric_file_content + '\n')
            if FLAGS.train_url.startswith('s3://'):
                result_file_path_obs = os.path.join(
                    FLAGS.train_url, 'model',
                    os.path.basename(result_file_path))
                mox.file.copy(result_file_path, result_file_path_obs)
                print('accuracy result file has been copied to %s' %
                      result_file_path_obs)
        else:
            print('accuracy is None')
    print('end')
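
All of these examples use fit_generator, which later TensorFlow/Keras releases deprecate in favor of Model.fit, which accepts Sequence objects directly. Assuming the same objects as in Example #8, an equivalent call would be:

# Equivalent to the fit_generator calls above on TF 2.x Keras, where
# Model.fit handles Sequence inputs and multiprocessing itself.
model.fit(train_sequence,
          steps_per_epoch=len(train_sequence),
          epochs=FLAGS.max_epochs,
          verbose=1,
          callbacks=[history, tensorboard],
          validation_data=validation_sequence,
          max_queue_size=10,
          workers=int(multiprocessing.cpu_count() * 0.7),
          use_multiprocessing=True,
          shuffle=True)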