Example #1
def test_profile_train_step():
    """Asserts that the saved profile matches expectation for a simple example."""
    environment.init("test_executor")
    prepare_dirs(recreate=True)

    with tf.compat.v1.Session() as sess:
        w = tf.Variable([[5.0, 3.0, 2.9, -4.0, 0.0]])
        v = tf.Variable([[0.21, -2.70, 0.94, 3.82, -3.65],
                         [5.0, 3.0, 2.9, -4.0, 0.0],
                         [1.96, -2.2, 0.42, -1.26, -1.06],
                         [-1.55, 4.56, -4.71, -2.43, 4.55],
                         [-3.11, 3.78, -3.45, 2.18, -4.45]])
        # z is [[27.933998, -29.119999, 33.458, 13.166, -39.524002]]
        z = tf.matmul(w, v)
        sess.run(tf.compat.v1.global_variables_initializer())
        step = 0
        options = tf.compat.v1.RunOptions(
            trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
        run_meta = tf.compat.v1.RunMetadata()
        sess.run(z, options=options, run_metadata=run_meta)
        profile_train_step(step, sess, run_meta)

    expected_memory = textwrap.dedent("""\

        Doc:
        scope: The nodes in the model graph are organized by their names, which is hierarchical like filesystem.
        requested bytes: The memory requested by the operation, accumulatively.

        Profile:
        node name | requested bytes
        _TFProfRoot
          MatMul
          Variable
          Variable_1
        """).splitlines()
    expected_memory.sort()
    expected_timeline = [{
        "args": {
            "name": "MatMul",
            "op": "MatMul"
        },
        "cat": "Op",
        "dur": 267,
        "name": "MatMul",
        "ph": "X",
        "pid": 1,
        "tid": 0,
        "ts": 0
    }, {
        "args": {
            "name": "Variable",
            "op": "Variable"
        },
        "cat": "Op",
        "dur": 4,
        "name": "Variable",
        "ph": "X",
        "pid": 1,
        "tid": 0,
        "ts": 267
    }, {
        "args": {
            "name": "Variable_1",
            "op": "Variable_1"
        },
        "cat": "Op",
        "dur": 20,
        "name": "Variable_1",
        "ph": "X",
        "pid": 1,
        "tid": 0,
        "ts": 271
    }, {
        "args": {
            "name": "_TFProfRoot",
            "op": "_TFProfRoot"
        },
        "cat": "Op",
        "dur": 291,
        "name": "_TFProfRoot",
        "ph": "X",
        "pid": 0,
        "tid": 0,
        "ts": 0
    }, {
        "args": {
            "name": "Scope:0"
        },
        "name": "process_name",
        "ph": "M",
        "pid": 0
    }, {
        "args": {
            "name": "Scope:1"
        },
        "name": "process_name",
        "ph": "M",
        "pid": 1
    }]

    train_memory_path = os.path.join(environment.EXPERIMENT_DIR,
                                     "training_profile_memory")
    with open(train_memory_path) as train_memory_file:
        saved_data = train_memory_file.read().splitlines()
        saved_data.sort()
        for idx, line in enumerate(saved_data):
            assert line.startswith(expected_memory[idx])
    train_timeline_path = os.path.join(environment.EXPERIMENT_DIR,
                                       "training_profile_timeline_step")
    with open("{}_{}".format(train_timeline_path,
                             step)) as train_timeline_file:
        saved_data = json.load(train_timeline_file)["traceEvents"]
        saved_data.sort(key=lambda op: op["name"])
        for op1, op2 in zip(expected_timeline, saved_data):
            assert op1["args"] == op2["args"]
            # Generally, timeline values are different each run, so just check the keys match.
            assert op1.keys() == op2.keys()
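Note: both examples call executor.profile_train_step, whose definition is not shown here. As a rough sketch only, assuming the helper is built on tf.compat.v1.profiler: the output file names and the use of environment.EXPERIMENT_DIR mirror what the test in Example #1 asserts, while the option values and everything else are assumptions, not the project's actual implementation.

import os

import tensorflow as tf

# `environment` is the project module already used in the examples above; it is
# assumed to expose EXPERIMENT_DIR and is not imported here.


def profile_train_step(step, sess, run_meta):
    """Hypothetical sketch: save memory and timeline profiles for one step."""
    # Memory profile grouped by graph scope, ordered by node name so that the
    # text output is stable across runs.
    memory_opts = (
        tf.compat.v1.profiler.ProfileOptionBuilder(
            tf.compat.v1.profiler.ProfileOptionBuilder.time_and_memory())
        .select(["bytes"])
        .order_by("name")
        .with_file_output(os.path.join(environment.EXPERIMENT_DIR,
                                       "training_profile_memory"))
        .build())
    tf.compat.v1.profiler.profile(
        sess.graph, run_meta=run_meta, cmd="scope", options=memory_opts)

    # Chrome-trace timeline for this step (viewable in chrome://tracing).
    timeline_opts = (
        tf.compat.v1.profiler.ProfileOptionBuilder(
            tf.compat.v1.profiler.ProfileOptionBuilder.time_and_memory())
        .with_timeline_output(
            os.path.join(environment.EXPERIMENT_DIR,
                         "training_profile_timeline_step_{}".format(step)))
        .build())
    tf.compat.v1.profiler.profile(
        sess.graph, run_meta=run_meta, cmd="scope", options=timeline_opts)

In both examples the helper is invoked right after a sess.run executed with trace_level=FULL_TRACE, so run_meta carries the per-op step statistics the profiler needs.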
Example #2
def start_training(config, profile_step):
    use_horovod = horovod_util.is_enabled()
    print("use_horovod:", use_horovod)
    if use_horovod:
        hvd = horovod_util.setup()
        rank = hvd.rank()
        local_rank = hvd.local_rank()
    else:
        rank = 0
        local_rank = -1

    ModelClass = config.NETWORK_CLASS
    network_kwargs = {key.lower(): val for key, val in config.NETWORK.items()}

    train_dataset = setup_dataset(config, "train", rank, local_rank)
    print("train dataset num:", train_dataset.num_per_epoch)

    validation_dataset = setup_dataset(config, "validation", rank, local_rank)
    print("validation dataset num:", validation_dataset.num_per_epoch)

    graph = tf.Graph()
    with graph.as_default():
        if config.TASK == Tasks.OBJECT_DETECTION:
            model = ModelClass(
                classes=train_dataset.classes,
                num_max_boxes=train_dataset.num_max_boxes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        else:
            model = ModelClass(
                classes=train_dataset.classes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )

        is_training_placeholder = tf.compat.v1.placeholder(tf.bool, name="is_training_placeholder")

        images_placeholder, labels_placeholder = model.placeholders()

        output = model.inference(images_placeholder, is_training_placeholder)
        loss = model.loss(output, labels_placeholder)
        opt = model.optimizer()
        if use_horovod:
            # add Horovod Distributed Optimizer
            opt = hvd.DistributedOptimizer(opt)
        train_op = model.train(loss, opt)
        metrics_ops_dict, metrics_update_op = model.metrics(output, labels_placeholder)
        # TODO(wakisaka): Deal with many networks.
        model.summary(output, labels_placeholder)

        summary_op = tf.compat.v1.summary.merge_all()
        metrics_summary_op = executor.metrics_summary_op(metrics_ops_dict)

        init_op = tf.compat.v1.global_variables_initializer()
        reset_metrics_op = tf.compat.v1.local_variables_initializer()
        if use_horovod:
            # add Horovod broadcasting variables from rank 0 to all
            bcast_global_variables_op = hvd.broadcast_global_variables(0)

        saver = tf.compat.v1.train.Saver(max_to_keep=config.KEEP_CHECKPOINT_MAX)

        with open(os.path.join(environment.EXPERIMENT_DIR, "pretrain_vars.txt"), 'w') as pretrain_vars_file:
            train_vars = tf.compat.v1.trainable_variables()
            pretrain_vars_file.write("[\n")
            pretrain_vars_file.writelines("    '%s',\n" % var.name for var in train_vars)
            pretrain_vars_file.write("]\n")

        if config.IS_PRETRAIN:
            all_vars = tf.compat.v1.global_variables()
            pretrain_var_list = [
                var for var in all_vars if var.name.startswith(tuple(config.PRETRAIN_VARS))
            ]
            print("pretrain_vars", [
                var.name for var in pretrain_var_list
            ])
            pretrain_saver = tf.compat.v1.train.Saver(pretrain_var_list, name="pretrain_saver")

    if use_horovod:
        # For distributed training
        session_config = tf.compat.v1.ConfigProto(
            gpu_options=tf.compat.v1.GPUOptions(
                allow_growth=True,
                visible_device_list=str(hvd.local_rank())
            )
        )
    else:
        # TODO(wakisaka): For debug.
        # session_config = tf.ConfigProto(
        #     gpu_options=tf.GPUOptions(
        #         allow_growth=True,
        #         per_process_gpu_memory_fraction=0.1
        #     )
        # )
        session_config = tf.compat.v1.ConfigProto()  # tf.ConfigProto(log_device_placement=True)
    # TODO(wakisaka): XLA JIT
    # session_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    sess = tf.compat.v1.Session(graph=graph, config=session_config)
    sess.run([init_op, reset_metrics_op])
    executor.save_pb_file(sess, environment.CHECKPOINTS_DIR)

    if rank == 0:
        train_writer = tf.compat.v1.summary.FileWriter(environment.TENSORBOARD_DIR + "/train", sess.graph)
        val_writer = tf.compat.v1.summary.FileWriter(environment.TENSORBOARD_DIR + "/validation")

        if config.IS_PRETRAIN:
            print("------- Load pretrain data ----------")
            pretrain_saver.restore(sess, os.path.join(config.PRETRAIN_DIR, config.PRETRAIN_FILE))

        # for recovery
        ckpt = tf.train.get_checkpoint_state(environment.CHECKPOINTS_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            print("--------- Restore last checkpoint -------------")
            saver.restore(sess, ckpt.model_checkpoint_path)
            # saver.recover_last_checkpoints(ckpt.model_checkpoint_path)
            last_step = sess.run(model.global_step)
            # TODO(wakisaka): tensorflow v1.3 remain previous event log in tensorboard.
            # https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/python/training/supervisor.py#L1072
            train_writer.add_session_log(SessionLog(status=SessionLog.START), global_step=last_step + 1)
            val_writer.add_session_log(SessionLog(status=SessionLog.START), global_step=last_step + 1)
            print("recovered. last step", last_step)

    if use_horovod:
        # broadcast variables from rank 0 to all other processes
        sess.run(bcast_global_variables_op)

    last_step = sess.run(model.global_step)

    # Calculate max steps. The priority of config.MAX_EPOCHS is higher than config.MAX_STEPS.
    if "MAX_EPOCHS" in config:
        max_steps = int(train_dataset.num_per_epoch / config.BATCH_SIZE * config.MAX_EPOCHS)
        if max_steps < 1:
            print("The max_steps is less than 1, consider reduce BATCH_SIZE. exit.", file=sys.stderr)
            sys.exit(1)
    else:
        max_steps = config.MAX_STEPS
        if max_steps < 1:
            print("The max_steps is less than 1, consider set MAX_STEPS greater than 0. exit.", file=sys.stderr)
            sys.exit(1)

    progbar = Progbar(max_steps)
    if rank == 0:
        progbar.update(last_step)
    for step in range(last_step, max_steps):

        images, labels = train_dataset.feed()

        feed_dict = {
            is_training_placeholder: True,
            images_placeholder: images,
            labels_placeholder: labels,
        }

        # Runtime statistics for develop.
        if step == profile_step:
            options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
            run_meta = tf.compat.v1.RunMetadata()
        else:
            options = None
            run_meta = None

        # Summarise on the first step and every SUMMARISE_STEPS steps: the product
        # is zero when step == 0 or when (step + 1) is a multiple of SUMMARISE_STEPS.
        if step * ((step + 1) % config.SUMMARISE_STEPS) == 0 and rank == 0:
            sess.run(reset_metrics_op)
            _, summary, _ = sess.run(
                [train_op, summary_op, metrics_update_op], feed_dict=feed_dict,
                options=options,
                run_metadata=run_meta,
            )
            # train_writer.add_run_metadata(run_metadata, "step: {}".format(step + 1))
            train_writer.add_summary(summary, step + 1)

            metrics_summary = sess.run(metrics_summary_op)
            train_writer.add_summary(metrics_summary, step + 1)
            train_writer.flush()
        else:
            sess.run([train_op], feed_dict=feed_dict, options=options, run_metadata=run_meta)

        if step == profile_step:
            executor.profile_train_step(step, sess, run_meta)

        to_be_saved = step == 0 or (step + 1) == max_steps or (step + 1) % config.SAVE_CHECKPOINT_STEPS == 0

        if to_be_saved and rank == 0:
            _save_checkpoint(saver, sess, model.global_step)

        if step == 0 or (step + 1) % config.TEST_STEPS == 0:
            # init metrics values
            sess.run(reset_metrics_op)
            test_step_size = int(math.ceil(validation_dataset.num_per_epoch / config.BATCH_SIZE))

            for test_step in range(test_step_size):

                images, labels = validation_dataset.feed()
                feed_dict = {
                    is_training_placeholder: False,
                    images_placeholder: images,
                    labels_placeholder: labels,
                }

                if test_step % config.SUMMARISE_STEPS == 0:
                    summary, _ = sess.run([summary_op, metrics_update_op], feed_dict=feed_dict)
                    if rank == 0:
                        val_writer.add_summary(summary, step + 1)
                        val_writer.flush()
                else:
                    sess.run([metrics_update_op], feed_dict=feed_dict)

            metrics_summary = sess.run(metrics_summary_op)
            if rank == 0:
                val_writer.add_summary(metrics_summary, step + 1)
                val_writer.flush()

        if rank == 0:
            progbar.update(step + 1)
    # training loop end.
    train_dataset.close()
    validation_dataset.close()
    print("Done")