Example #1
def setup():
    if not horovod_installed:
        return False

    global horovod_initialized
    if horovod_initialized:
        return hvd

    hvd.init()
    horovod_initialized = True

    horovod_num_worker = hvd.size()
    horovod_rank = hvd.rank()
    # verify that MPI multi-threading is supported.
    assert hvd.mpi_threads_supported()
    # make sure MPI is not re-initialized.
    import mpi4py.rc
    mpi4py.rc.initialize = False
    # import mpi4py
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    # check that size and rank are synchronized
    assert horovod_num_worker == comm.Get_size()
    assert horovod_rank == comm.Get_rank()
    return hvd
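setup() relies on module-level state that is not shown in the snippet; a minimal sketch of what those globals might look like (the try/except guard is an assumption, only the names come from the example above):

# Assumed module-level globals for setup(); the import guard is illustrative.
try:
    import horovod.tensorflow as hvd
    horovod_installed = True
except ImportError:
    hvd = None
    horovod_installed = False

horovod_initialized = False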
Example #2
def setup_horovod():
    import horovod.tensorflow as hvd

    # Initialize Horovod
    hvd.init()
    # Verify that MPI multi-threading is supported.
    assert hvd.mpi_threads_supported()

    from mpi4py import MPI

    assert hvd.size() == MPI.COMM_WORLD.Get_size()

    is_root = hvd.rank() == 0

    def mpi_average(local_list):
        # _local_list_orig = local_list
        local_list = list(map(float, local_list))
        # print('RANK {} AVERAGING {} -> {}'.format(hvd.rank(), _local_list_orig, local_list))
        sums = MPI.COMM_WORLD.gather(sum(local_list), root=0)
        counts = MPI.COMM_WORLD.gather(len(local_list), root=0)
        sum_counts = sum(counts) if is_root else None
        avg = (sum(sums) / sum_counts) if is_root else None
        return avg, sum_counts

    return hvd, MPI, is_root, mpi_average
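A hedged usage sketch for setup_horovod(): every rank must take part in the gathers inside mpi_average(), and only rank 0 receives a non-None result. The loss values below are made up for the sketch.

# Illustrative call pattern; run under e.g. `horovodrun -np N python script.py`.
hvd, MPI, is_root, mpi_average = setup_horovod()
local_losses = [0.9, 1.1, 1.0]          # per-rank values (illustrative)
avg, count = mpi_average(local_losses)
if is_root:
    print('global average over', count, 'values:', avg)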
Example #3
def init_workers(distributed=False):
    if distributed and not no_horovod:
        hvd.init()
        assert hvd.mpi_threads_supported()
        from mpi4py import MPI
        assert hvd.size() == MPI.COMM_WORLD.Get_size()
        comm = MPI.COMM_WORLD
        print("Rank: {}, Size: {}".format(hvd.rank(), hvd.size()))
        return SimpleNamespace(rank=hvd.rank(), size=hvd.size(),
                                local_rank=hvd.local_rank(),
                                local_size=hvd.local_size(), comm=comm)
    else:
        print("not doing distributed")
        return SimpleNamespace(rank=0, size=1, local_rank=0, local_size=1, comm=None)
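init_workers() assumes a few module-level names (hvd, no_horovod, SimpleNamespace). A minimal sketch of that setup, assuming the usual optional-import pattern (the try/except is an assumption, not taken from the source):

# Assumed module-level setup for init_workers(); only the names come from the example.
from types import SimpleNamespace
try:
    import horovod.tensorflow as hvd
    no_horovod = False
except ImportError:
    no_horovod = True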
Example #4
def start_training(config):
    if config.IS_DISTRIBUTION:
        import horovod.tensorflow as hvd
        # initialize Horovod.
        hvd.init()
        num_worker = hvd.size()
        rank = hvd.rank()
        # verify that MPI multi-threading is supported.
        assert hvd.mpi_threads_supported()
        # make sure MPI is not re-initialized.
        import mpi4py.rc
        mpi4py.rc.initialize = False
        # import mpi4py
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        # check that size and rank are synchronized
        assert num_worker == comm.Get_size()
        assert rank == comm.Get_rank()
    else:
        num_worker = 1
        rank = 0

    ModelClass = config.NETWORK_CLASS
    network_kwargs = dict(
        (key.lower(), val) for key, val in config.NETWORK.items())
    if "train_validation_saving_size".upper() in config.DATASET.keys():
        use_train_validation_saving = config.DATASET.TRAIN_VALIDATION_SAVING_SIZE > 0
    else:
        use_train_validation_saving = False

    if use_train_validation_saving:
        top_train_validation_saving_set_accuracy = 0

    train_dataset = setup_dataset(config, "train", rank)
    print("train dataset num:", train_dataset.num_per_epoch)

    if use_train_validation_saving:
        train_validation_saving_dataset = setup_dataset(
            config, "train_validation_saving", rank)
        print("train_validation_saving dataset num:",
              train_validation_saving_dataset.num_per_epoch)

    validation_dataset = setup_dataset(config, "validation", rank)
    print("validation dataset num:", validation_dataset.num_per_epoch)

    graph = tf.Graph()
    with graph.as_default():
        if ModelClass.__module__.startswith("lmnet.networks.object_detection"):
            model = ModelClass(
                classes=train_dataset.classes,
                num_max_boxes=train_dataset.num_max_boxes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        elif ModelClass.__module__.startswith("lmnet.networks.segmentation"):
            model = ModelClass(
                classes=train_dataset.classes,
                label_colors=train_dataset.label_colors,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )
        else:
            model = ModelClass(
                classes=train_dataset.classes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )

        global_step = tf.Variable(0, name="global_step", trainable=False)
        is_training_placeholder = tf.placeholder(
            tf.bool, name="is_training_placeholder")

        images_placeholder, labels_placeholder = model.placeholderes()

        output = model.inference(images_placeholder, is_training_placeholder)
        if ModelClass.__module__.startswith("lmnet.networks.object_detection"):
            loss = model.loss(output, labels_placeholder,
                              is_training_placeholder)
        else:
            loss = model.loss(output, labels_placeholder)
        opt = model.optimizer(global_step)
        if config.IS_DISTRIBUTION:
            # add Horovod Distributed Optimizer
            opt = hvd.DistributedOptimizer(opt)
        train_op = model.train(loss, opt, global_step)
        metrics_ops_dict, metrics_update_op = model.metrics(
            output, labels_placeholder)
        # TODO(wakisaka): Deal with many networks.
        model.summary(output, labels_placeholder)

        summary_op = tf.summary.merge_all()

        metrics_summary_op, metrics_placeholders = executor.prepare_metrics(
            metrics_ops_dict)

        init_op = tf.global_variables_initializer()
        reset_metrics_op = tf.local_variables_initializer()
        if config.IS_DISTRIBUTION:
            # add Horovod broadcasting variables from rank 0 to all
            bcast_global_variables_op = hvd.broadcast_global_variables(0)

        if use_train_validation_saving:
            saver = tf.train.Saver(max_to_keep=1)
        else:
            saver = tf.train.Saver(max_to_keep=None)

        if config.IS_PRETRAIN:
            all_vars = tf.global_variables()
            pretrain_var_list = [
                var for var in all_vars
                if var.name.startswith(tuple(config.PRETRAIN_VARS))
            ]
            print("pretrain_vars", [var.name for var in pretrain_var_list])
            pretrain_saver = tf.train.Saver(pretrain_var_list,
                                            name="pretrain_saver")

    if config.IS_DISTRIBUTION:
        # For distributed training
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True, visible_device_list=str(hvd.local_rank())))
    else:
        # TODO(wakisaka): For debug.
        # session_config = tf.ConfigProto(
        #     gpu_options=tf.GPUOptions(
        #         allow_growth=True,
        #         per_process_gpu_memory_fraction=0.1
        #     )
        # )
        session_config = tf.ConfigProto(
        )  # tf.ConfigProto(log_device_placement=True)
    # TODO(wakisaka): XLA JIT
    # session_config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    sess = tf.Session(graph=graph, config=session_config)
    sess.run([init_op, reset_metrics_op])

    if rank == 0:
        train_writer = tf.summary.FileWriter(
            environment.TENSORBOARD_DIR + "/train", sess.graph)
        if use_train_validation_saving:
            train_val_saving_writer = tf.summary.FileWriter(
                environment.TENSORBOARD_DIR + "/train_validation_saving")
        val_writer = tf.summary.FileWriter(environment.TENSORBOARD_DIR +
                                           "/validation")

        if config.IS_PRETRAIN:
            print("------- Load pretrain data ----------")
            pretrain_saver.restore(
                sess, os.path.join(config.PRETRAIN_DIR, config.PRETRAIN_FILE))
            sess.run(tf.assign(global_step, 0))

        last_step = 0

        # for recovery
        ckpt = tf.train.get_checkpoint_state(environment.CHECKPOINTS_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            print("--------- Restore last checkpoint -------------")
            saver.restore(sess, ckpt.model_checkpoint_path)
            # saver.recover_last_checkpoints(ckpt.model_checkpoint_path)
            last_step = sess.run(global_step)
            # TODO(wakisaka): tensorflow v1.3 remain previous event log in tensorboard.
            # https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/python/training/supervisor.py#L1072
            train_writer.add_session_log(SessionLog(status=SessionLog.START),
                                         global_step=last_step + 1)
            val_writer.add_session_log(SessionLog(status=SessionLog.START),
                                       global_step=last_step + 1)
            print("recovered. last step", last_step)

    if config.IS_DISTRIBUTION:
        # broadcast variables from rank 0 to all other processes
        sess.run(bcast_global_variables_op)
        # calculate steps per epoch for each node
        train_num_per_epoch = train_dataset.num_per_epoch
        num_per_nodes = (train_num_per_epoch + num_worker - 1) // num_worker
        step_per_epoch = num_per_nodes // config.BATCH_SIZE
        begin_index = (train_num_per_epoch * rank) // num_worker
        end_index = begin_index + num_per_nodes

    last_step = sess.run(global_step)

    # Calculate max steps. The priority of config.MAX_EPOCHS is higher than config.MAX_STEPS.
    if "MAX_EPOCHS" in config:
        max_steps = int(train_dataset.num_per_epoch / config.BATCH_SIZE *
                        config.MAX_EPOCHS)
    else:
        max_steps = config.MAX_STEPS
    print("max_steps: {}".format(max_steps))

    for step in range(last_step, max_steps):
        print("step", step)

        if config.IS_DISTRIBUTION:
            # scatter dataset
            if step % step_per_epoch == 0:
                indices = train_dataset.get_shuffle_index(
                ) if rank == 0 else None
                # broadcast shuffled indices
                indices = comm.bcast(indices, 0)
                feed_indices = indices[begin_index:end_index]
                # update each dataset with its split of the indices
                train_dataset.update_dataset(feed_indices)

        images, labels = train_dataset.feed()

        feed_dict = {
            is_training_placeholder: True,
            images_placeholder: images,
            labels_placeholder: labels,
        }

        if (step == 0 or (step + 1) % config.SUMMARISE_STEPS == 0) and rank == 0:
            # Runtime statistics for develop.
            # run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            # run_metadata = tf.RunMetadata()

            sess.run(reset_metrics_op)
            _, summary, _ = sess.run(
                [train_op, summary_op, metrics_update_op],
                feed_dict=feed_dict,
                # options=run_options,
                # run_metadata=run_metadata,
            )
            # train_writer.add_run_metadata(run_metadata, "step: {}".format(step + 1))
            train_writer.add_summary(summary, step + 1)

            metrics_values = sess.run(list(metrics_ops_dict.values()))
            metrics_feed_dict = {
                placeholder: value
                for placeholder, value in zip(metrics_placeholders,
                                              metrics_values)
            }

            metrics_summary, = sess.run(
                [metrics_summary_op],
                feed_dict=metrics_feed_dict,
            )
            train_writer.add_summary(metrics_summary, step + 1)
        else:
            sess.run([train_op], feed_dict=feed_dict)

        to_be_saved = step == 0 or (
            step + 1) == max_steps or (step + 1) % config.SAVE_STEPS == 0

        if to_be_saved and rank == 0:
            if use_train_validation_saving:

                sess.run(reset_metrics_op)
                train_validation_saving_step_size = int(
                    math.ceil(train_validation_saving_dataset.num_per_epoch /
                              config.BATCH_SIZE))
                print("train_validation_saving_step_size",
                      train_validation_saving_step_size)

                current_train_validation_saving_set_accuracy = 0

                for train_validation_saving_step in range(
                        train_validation_saving_step_size):
                    print("train_validation_saving_step",
                          train_validation_saving_step)

                    images, labels = train_validation_saving_dataset.feed()
                    feed_dict = {
                        is_training_placeholder: False,
                        images_placeholder: images,
                        labels_placeholder: labels,
                    }

                    if train_validation_saving_step % config.SUMMARISE_STEPS == 0:
                        summary, _ = sess.run([summary_op, metrics_update_op],
                                              feed_dict=feed_dict)
                        train_val_saving_writer.add_summary(summary, step + 1)
                    else:
                        sess.run([metrics_update_op], feed_dict=feed_dict)

                metrics_values = sess.run(list(metrics_ops_dict.values()))
                metrics_feed_dict = {
                    placeholder: value
                    for placeholder, value in zip(metrics_placeholders,
                                                  metrics_values)
                }
                metrics_summary, = sess.run(
                    [metrics_summary_op],
                    feed_dict=metrics_feed_dict,
                )
                train_val_saving_writer.add_summary(metrics_summary, step + 1)

                current_train_validation_saving_set_accuracy = sess.run(
                    metrics_ops_dict["accuracy"])

                if current_train_validation_saving_set_accuracy > top_train_validation_saving_set_accuracy:
                    top_train_validation_saving_set_accuracy = current_train_validation_saving_set_accuracy
                    print("New top train_validation_saving accuracy is: ",
                          top_train_validation_saving_set_accuracy)

                    _save_checkpoint(saver, sess, global_step, step)

            else:
                _save_checkpoint(saver, sess, global_step, step)

            if step == 0:
                # create the pb only on the first step.
                minimal_graph = tf.graph_util.convert_variables_to_constants(
                    sess,
                    sess.graph.as_graph_def(add_shapes=True),
                    ["output"],
                )
                pb_name = "minimal_graph_with_shape_{}.pb".format(step + 1)
                pbtxt_name = "minimal_graph_with_shape_{}.pbtxt".format(step +
                                                                        1)
                tf.train.write_graph(minimal_graph,
                                     environment.CHECKPOINTS_DIR,
                                     pb_name,
                                     as_text=False)
                tf.train.write_graph(minimal_graph,
                                     environment.CHECKPOINTS_DIR,
                                     pbtxt_name,
                                     as_text=True)

        if step == 0 or (step + 1) % config.TEST_STEPS == 0:
            # init metrics values
            sess.run(reset_metrics_op)
            test_step_size = int(
                math.ceil(validation_dataset.num_per_epoch /
                          config.BATCH_SIZE))
            print("test_step_size", test_step_size)

            for test_step in range(test_step_size):
                print("test_step", test_step)

                images, labels = validation_dataset.feed()
                feed_dict = {
                    is_training_placeholder: False,
                    images_placeholder: images,
                    labels_placeholder: labels,
                }

                if test_step % config.SUMMARISE_STEPS == 0:
                    summary, _ = sess.run([summary_op, metrics_update_op],
                                          feed_dict=feed_dict)
                    if rank == 0:
                        val_writer.add_summary(summary, step + 1)
                else:
                    sess.run([metrics_update_op], feed_dict=feed_dict)

            metrics_values = sess.run(list(metrics_ops_dict.values()))
            metrics_feed_dict = {
                placeholder: value
                for placeholder, value in zip(metrics_placeholders,
                                              metrics_values)
            }
            metrics_summary, = sess.run(
                [metrics_summary_op],
                feed_dict=metrics_feed_dict,
            )
            if rank == 0:
                val_writer.add_summary(metrics_summary, step + 1)

    # training loop end.
    print("reach max step")
Example #5
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import os
import horovod.tensorflow as hvd

# Initialize Horovod first, then make sure mpi4py does not re-initialize MPI.
# The rc flag must be set before `from mpi4py import MPI` to have any effect.
hvd.init()
assert hvd.mpi_threads_supported()

import mpi4py
mpi4py.rc.initialize = False
from mpi4py import MPI
comm = MPI.COMM_WORLD

assert hvd.size() == comm.Get_size()

rank = comm.Get_rank()

if rank == 0:
    s = 'abcdef'
    data = {'key1': [7, 2.72, 2 + 3j], 'key2': ('abc', 'xyz')}
    print('before broadcasting: process %d has %s' % (rank, data), s)
else:
    s = None
    data = None
    print('before broadcasting: process %d has %s' % (rank, data), s)

data = comm.bcast(data, root=0)
s = comm.bcast(s, root=0)
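After comm.bcast, every rank holds an identical copy of the broadcast objects, so a quick sanity check like the following could be appended to the example (a sketch, not part of the original):

# All ranks should now see the same values that rank 0 started with.
print('after broadcasting: process %d has %s' % (rank, data), s)
assert s == 'abcdef'
assert data['key2'] == ('abc', 'xyz')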
Example #6
import horovod.tensorflow as hvd

# Initialize Horovod
hvd.init()

# Verify that MPI multi-threading is supported
print(hvd.mpi_threads_supported())
assert hvd.mpi_threads_supported()

from mpi4py import MPI

print(hvd.size())
assert hvd.size() == MPI.COMM_WORLD.Get_size()
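A common follow-up to this snippet, also used later on this page, is pinning each process to a single GPU via its local rank before creating the TensorFlow session. A minimal sketch:

# Pin this process to one GPU (one GPU per Horovod process); TF1-style config.
import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = str(hvd.local_rank())
sess = tf.Session(config=config)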
def evaluate(
    *,
    flow_constructor,
    seed,
    restore_checkpoint,
    total_bs,
    iw_samples=1024,  # 4096 is too slow for ImageNet
    dtype=tf.float32,
    dataset='imagenet32',
    samples_filename='samples.png',
    extra_dims=3,
):
    import horovod.tensorflow as hvd

    # Initialize Horovod
    hvd.init()
    # Verify that MPI multi-threading is supported.
    assert hvd.mpi_threads_supported()

    from mpi4py import MPI

    assert hvd.size() == MPI.COMM_WORLD.Get_size()

    is_root = hvd.rank() == 0

    def mpi_average(local_list):
        local_list = list(map(float, local_list))
        sums = MPI.COMM_WORLD.gather(sum(local_list), root=0)
        counts = MPI.COMM_WORLD.gather(len(local_list), root=0)
        sum_counts = sum(counts) if is_root else None
        avg = (sum(sums) / sum_counts) if is_root else None
        return avg, sum_counts

    restore_checkpoint = os.path.expanduser(restore_checkpoint)

    # Seeding and logging setup
    seed_all(hvd.rank() + hvd.size() * seed)
    assert total_bs % hvd.size() == 0
    local_bs = total_bs // hvd.size()
    assert iw_samples % total_bs == 0

    if is_root:
        print('===== EVALUATING {} ({} IW samples) ====='.format(
            restore_checkpoint, iw_samples))

    # Load data
    assert dataset in ['imagenet32', 'imagenet64', 'imagenet64_5bit']

    if is_root:
        print('Loading data')
    MPI.COMM_WORLD.Barrier()
    if dataset == 'imagenet32':
        """The dataset as a npy file on RAM. There are as many copies as number of MPI threads. 
           This isn't effficient and tf.Records would be better to read from disk. 
           This is just done to ensure bits/dim reported are perfect and no data loading bugs creep in.
           However, the dataset is quite small resolution and even 8 MPI threads can work on 40GB RAM."""
        # data_train = np.load('../train_32x32.npy')
        data_val = np.load('../valid_32x32.npy')
        # assert data_train.dtype == 'uint8'
        # assert np.max(data_train) <= 255
        # assert np.min(data_train) >= 0
        assert np.max(data_val) <= 255
        assert np.min(data_val) >= 0
        assert data_val.dtype == 'uint8'
    elif dataset == 'imagenet64':
        """The dataset as a npy file on RAM. There are as many copies as number of MPI threads. 
           This isn't effficient and tf.Records would be better to read from disk. 
           This is just done to ensure bits/dim reported are perfect and no data loading bugs creep in.
           If you don't have enough CPU RAM to run 8 threads, run it with fewer threads and adjust batch-size / model-size tradeoff accordingly."""
        data_train = np.load('../train_64x64.npy')
        data_val = np.load('../valid_64x64.npy')
        assert data_train.dtype == 'uint8'
        assert np.max(data_train) <= 255
        assert np.min(data_train) >= 0
        assert np.max(data_val) <= 255
        assert np.min(data_val) >= 0
    elif dataset == 'imagenet64_5bit':
        """Similar loading as above. Quantized to 5-bit while loading."""
        if is_root:
            data_train = np.load('../train_64x64.npy')
            data_train = np.floor(data_train / 8.)
            data_train = data_train.astype('uint8')
            assert np.max(data_train) <= 31
            assert np.min(data_train) >= 0
            np.save('../train_64x64_5bit.npy', data_train)
            del data_train
            data_val = np.load('../valid_64x64.npy')
            data_val = np.floor(data_val / 8.)
            data_val = data_val.astype('uint8')
            assert np.max(data_val) <= 31
            assert np.min(data_val) >= 0
            np.save('../valid_64x64_5bit.npy', data_val)
            del data_val
        MPI.COMM_WORLD.Barrier()
        data_train = np.load('../train_64x64_5bit.npy')
        data_val = np.load('../valid_64x64_5bit.npy')
    # data_train = data_train.astype(dtype.as_numpy_dtype)
    data_val = data_val.astype(dtype.as_numpy_dtype)
    img_shp = list(data_val.shape[1:])
    if dataset == 'imagenet32':
        assert img_shp == [32, 32, 3]
    else:
        assert img_shp == [64, 64, 3]
    if is_root:
        # print('Training data: {}, Validation data: {}'.format(data_train.shape[0], data_val.shape[0]))
        print('Image shape:', img_shp)
    bpd_scale_factor = 1. / (np.log(2) * np.prod(img_shp))

    # Build graph
    if is_root: print('Building graph')
    dequant_flow, flow, posterior_flow = flow_constructor()
    x_sym = tf.placeholder(dtype, [local_bs] + img_shp)
    # This is a fake training graph. Just used to mimic flow_training, so we can load from the saver
    build_forward(x=x_sym,
                  dequant_flow=dequant_flow,
                  flow=flow,
                  posterior_flow=posterior_flow,
                  flow_kwargs=dict(init=False,
                                   ema=None,
                                   dropout_p=0,
                                   verbose=is_root)
                  # note dropout is 0: it doesn't matter
                  )

    # EMA
    params = tf.trainable_variables()
    if is_root:
        print('Parameters',
              sum(np.prod(p.get_shape().as_list()) for p in params))
    ema = tf.train.ExponentialMovingAverage(
        decay=0.9999999999999)  # ema turned off
    maintain_averages_op = tf.group(ema.apply(params))

    # Validation and sampling (with EMA)
    if is_root: print('===== Validation graph =====')
    val_flow_kwargs = dict(init=False, dropout_p=0, ema=ema, verbose=is_root)
    val_loss_sym, val_logratio_sym, val_dequant_x_sym = build_forward(
        x=x_sym,
        dequant_flow=dequant_flow,
        flow=flow,
        posterior_flow=posterior_flow,
        flow_kwargs=val_flow_kwargs)

    allgathered_val_logratios_sym = hvd.allgather(val_logratio_sym)
    # for debugging invertibility
    # val_inverr_sym = tf.reduce_max(tf.abs(
    #     val_dequant_x_sym - flow.inverse(val_y_sym, dropout_p=0, ema=ema, verbose=is_root)[0]
    # ))

    if is_root: print('===== Sampling graph =====')
    samples_sym, _ = flow.sample(local_bs, flow_kwargs=val_flow_kwargs)
    allgathered_samples_sym = hvd.allgather(tf.to_float(samples_sym))

    assert len(tf.trainable_variables()) == len(params)

    def run_iw_eval(sess):
        if is_root:
            print('Running IW eval with {} samples...'.format(iw_samples))
        # Go through one example at a time
        all_val_losses = []
        for i_example in (trange if is_root else range)(len(data_val)):
            # take this single example and tile it
            batch_x = np.tile(data_val[i_example, None, ...],
                              (local_bs, 1, 1, 1))
            # repeatedly evaluate logd for the IWAE bound
            batch_logratios = np.concatenate([
                sess.run(allgathered_val_logratios_sym, {x_sym: batch_x})
                for _ in range(iw_samples // total_bs)
            ]).astype(np.float64)
            assert batch_logratios.shape == (iw_samples, )
            # log [1/n \sum_i exp(r_i)] = log [exp(-b) 1/n \sum_i exp(r_i + b)] = -b + log [1/n \sum_i exp(r_i + b)]
            shift = batch_logratios.max()
            all_val_losses.append(
                -bpd_scale_factor *
                (shift + np.log(np.mean(np.exp(batch_logratios - shift)))))
            if i_example % 100 == 0 and is_root:
                print(i_example, np.mean(all_val_losses))
        if is_root:
            print(f'Final ({len(data_val)}):', np.mean(all_val_losses))

    def run_sampling_only(sess,
                          *,
                          prefix=dataset,
                          dump_to_tensorboard=True,
                          save_jpg=False):
        samples = sess.run(allgathered_samples_sym)
        if is_root:
            print('samples gathered from the session')
            if dataset == 'imagenet64_5bit':
                """Quantized values. So different kind of sampling needed here."""
                samples = np.floor(np.clip(samples, 0, 31))
                samples = samples * 8
                samples = samples.astype('uint8')
            # np.save('samples_' + prefix + '.npy', samples)
            import cv2
            samples = tile_imgs(
                np.floor(np.clip(samples, 0, 255)).astype('uint8'))
            cv2.imwrite(samples_filename, samples)

    def run_validation(sess):
        data_val_shard = np.array_split(data_val, hvd.size(),
                                        axis=0)[hvd.rank()]
        # val_corr_sym is not defined in this graph, so only the loss is evaluated here.
        shard_losses = np.concatenate([
            sess.run([val_loss_sym], {x_sym: val_batch})
            for val_batch, in iterbatches([data_val_shard],
                                          batch_size=local_bs,
                                          include_final_partial_batch=False)
        ])
        val_loss, total_count = mpi_average(shard_losses)
        if is_root:
            for k, v in [
                ('val_bpd', bpd_scale_factor * val_loss),
                ('num_val_examples', total_count * local_bs),
            ]:
                print(k, v)

    # Run
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(
        hvd.local_rank())  # Pin GPU to local rank (one GPU per process)
    with tf.Session(config=config) as sess:
        if is_root: print('Initializing')
        sess.run(tf.global_variables_initializer())
        # Restore from checkpoint
        if is_root:
            print('Restoring checkpoint:', restore_checkpoint)
            saver = tf.train.Saver()
            saver.restore(sess, restore_checkpoint)
            print('Broadcasting initial parameters')
        sess.run(hvd.broadcast_global_variables(0))
        sess.graph.finalize()

        # if samples_filename:
        # run_sampling_only(sess)

        # Make sure data is the same on all MPI processes
        tmp_inds = [0, 183, 3, 6, 20, 88]
        check_batch = np.ascontiguousarray(data_val[tmp_inds])
        gathered_batches = np.zeros(
            (hvd.size(),
             *check_batch.shape), check_batch.dtype) if is_root else None
        MPI.COMM_WORLD.Gather(check_batch, gathered_batches, root=0)
        if is_root:
            assert all(
                np.allclose(check_batch, b)
                for b in gathered_batches), 'data must be in the same order!'
            print('data ordering ok')

        # Run validation
        run_validation(sess)
        run_iw_eval(sess)
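run_iw_eval() above relies on a numerically stable log-mean-exp (the shift by the maximum log-ratio); a standalone sketch of that identity:

# log(1/n * sum_i exp(r_i)) = shift + log(mean(exp(r_i - shift))), with shift = max_i r_i.
import numpy as np

def log_mean_exp(logratios):
    logratios = np.asarray(logratios, dtype=np.float64)
    shift = logratios.max()
    return shift + np.log(np.mean(np.exp(logratios - shift)))

print(log_mean_exp([1000.0, 1000.0, 1000.0]))   # 1000.0, with no overflow in exp()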
def train(
    *,
    flow_constructor,
    logdir,
    lr_schedule,
    dropout_p,
    seed,
    init_bs,
    total_bs,
    ema_decay,
    steps_per_log,
    max_grad_norm,
    dtype=tf.float32,
    scale_loss=None,
    dataset='imagenet32',
    steps_per_samples=20000,
    steps_per_dump=5000,
    n_epochs=2,
    restore_checkpoint=None,
    dump_samples_to_tensorboard=True,
    save_jpg=True,
):

    import horovod.tensorflow as hvd

    # Initialize Horovod
    hvd.init()
    # Verify that MPI multi-threading is supported.
    assert hvd.mpi_threads_supported()

    from mpi4py import MPI

    assert hvd.size() == MPI.COMM_WORLD.Get_size()

    is_root = hvd.rank() == 0

    def mpi_average(local_list):
        local_list = list(map(float, local_list))
        sums = MPI.COMM_WORLD.gather(sum(local_list), root=0)
        counts = MPI.COMM_WORLD.gather(len(local_list), root=0)
        sum_counts = sum(counts) if is_root else None
        avg = (sum(sums) / sum_counts) if is_root else None
        return avg, sum_counts

    # Seeding and logging setup
    seed_all(hvd.rank() + hvd.size() * seed)
    assert total_bs % hvd.size() == 0
    local_bs = total_bs // hvd.size()

    logger = None
    logdir = '{}_mpi{}_{}'.format(os.path.expanduser(logdir), hvd.size(),
                                  time.time())
    checkpointdir = os.path.join(logdir, 'checkpoints')
    profiledir = os.path.join(logdir, 'profiling')
    if is_root:
        print('Floating point format:', dtype)
        pprint(locals())
        os.makedirs(logdir)
        os.makedirs(checkpointdir)
        os.makedirs(profiledir)
        logger = TensorBoardOutput(logdir)

    # Load data
    assert dataset in ['imagenet32', 'imagenet64', 'imagenet64_5bit']

    if is_root:
        print('Loading data')
    MPI.COMM_WORLD.Barrier()
    if dataset == 'imagenet32':
        """The dataset as a npy file on RAM. There are as many copies as number of MPI threads. 
           This isn't effficient and tf.Records would be better to read from disk. 
           This is just done to ensure bits/dim reported are perfect and no data loading bugs creep in.
           However, the dataset is quite small resolution and even 8 MPI threads can work on 40GB RAM."""
        data_train = np.load('../train_32x32.npy')
        data_val = np.load('../valid_32x32.npy')
        assert data_train.dtype == 'uint8'
        assert np.max(data_train) <= 255
        assert np.min(data_train) >= 0
        assert np.max(data_val) <= 255
        assert np.min(data_val) >= 0
        assert data_val.dtype == 'uint8'
    elif dataset == 'imagenet64':
        """The dataset as a npy file on RAM. There are as many copies as number of MPI threads. 
           This isn't effficient and tf.Records would be better to read from disk. 
           This is just done to ensure bits/dim reported are perfect and no data loading bugs creep in.
           If you don't have enough CPU RAM to run 8 threads, run it with fewer threads and adjust batch-size / model-size tradeoff accordingly."""
        data_train = np.load('../train_64x64.npy')
        data_val = np.load('../valid_64x64.npy')
        assert data_train.dtype == 'uint8'
        assert np.max(data_train) <= 255
        assert np.min(data_train) >= 0
        assert np.max(data_val) <= 255
        assert np.min(data_val) >= 0
    elif dataset == 'imagenet64_5bit':
        """Similar loading as above. Quantized to 5-bit while loading."""
        if is_root:
            data_train = np.load('../train_64x64.npy')
            data_train = np.floor(data_train / 8.)
            data_train = data_train.astype('uint8')
            assert np.max(data_train) <= 31
            assert np.min(data_train) >= 0
            np.save('../train_64x64_5bit.npy', data_train)
            del data_train
            data_val = np.load('../valid_64x64.npy')
            data_val = np.floor(data_val / 8.)
            data_val = data_val.astype('uint8')
            assert np.max(data_val) <= 31
            assert np.min(data_val) >= 0
            np.save('../valid_64x64_5bit.npy', data_val)
            del data_val
        MPI.COMM_WORLD.Barrier()
        data_train = np.load('../train_64x64_5bit.npy')
        data_val = np.load('../valid_64x64_5bit.npy')
    data_train = data_train.astype(dtype.as_numpy_dtype)
    data_val = data_val.astype(dtype.as_numpy_dtype)
    img_shp = list(data_train.shape[1:])
    if dataset == 'imagenet32':
        assert img_shp == [32, 32, 3]
    else:
        assert img_shp == [64, 64, 3]
    if is_root:
        print('Training data: {}, Validation data: {}'.format(
            data_train.shape[0], data_val.shape[0]))
        print('Image shape:', img_shp)
    bpd_scale_factor = 1. / (np.log(2) * np.prod(img_shp))

    # Build graph
    if is_root: print('Building graph')
    dequant_flow, flow, posterior_flow = flow_constructor()
    # Data-dependent init
    if restore_checkpoint is None:
        if is_root: print('===== Init graph =====')
        x_init_sym = tf.placeholder(dtype, [init_bs] + img_shp)
        init_syms, _ = build_forward(x=x_init_sym,
                                     dequant_flow=dequant_flow,
                                     flow=flow,
                                     posterior_flow=posterior_flow,
                                     flow_kwargs=dict(init=True,
                                                      dropout_p=dropout_p,
                                                      verbose=is_root))
    # Training
    if is_root: print('===== Training graph =====')
    x_sym = tf.placeholder(dtype, [local_bs] + img_shp)
    loss_sym, _ = build_forward(x=x_sym,
                                dequant_flow=dequant_flow,
                                flow=flow,
                                posterior_flow=posterior_flow,
                                flow_kwargs=dict(dropout_p=dropout_p,
                                                 verbose=is_root))

    # EMA
    params = tf.trainable_variables()
    if is_root:
        print('Parameters',
              sum(np.prod(p.get_shape().as_list()) for p in params))
    ema = tf.train.ExponentialMovingAverage(decay=ema_decay)
    maintain_averages_op = tf.group(ema.apply(params))
    # Op for setting the ema params to the current non-ema params (for use after data-dependent init)
    name2var = {v.name: v for v in tf.global_variables()}
    copy_params_to_ema = tf.group([
        name2var[p.name.replace(':0', '') +
                 '/ExponentialMovingAverage:0'].assign(p) for p in params
    ])

    # Validation and sampling (with EMA)
    if is_root: print('===== Validation graph =====')
    val_loss_sym, _ = build_forward(x=x_sym,
                                    dequant_flow=dequant_flow,
                                    flow=flow,
                                    posterior_flow=posterior_flow,
                                    flow_kwargs=dict(dropout_p=0,
                                                     ema=ema,
                                                     verbose=is_root))
    # for debugging invertibility
    # val_inverr_sym = tf.reduce_max(tf.abs(
    #     val_dequant_x_sym - flow.inverse(val_y_sym, dropout_p=0, ema=ema, verbose=is_root)[0]
    # ))

    if is_root: print('===== Sampling graph =====')
    samples_sym, _ = flow.sample(local_bs,
                                 flow_kwargs=dict(dropout_p=0.,
                                                  ema=ema,
                                                  verbose=is_root))
    allgathered_samples_sym = hvd.allgather(tf.to_float(samples_sym))

    assert len(tf.trainable_variables()) == len(params)

    def run_sampling(sess,
                     i_step,
                     *,
                     prefix=dataset,
                     dump_to_tensorboard=True,
                     save_jpg=False):
        samples = sess.run(allgathered_samples_sym)
        if is_root:
            print('samples gathered from the session')
            if dataset == 'imagenet64_5bit':
                """Quantized values. So different kind of sampling needed here."""
                samples = np.floor(np.clip(samples, 0, 31))
                samples = samples * 8
                samples = samples.astype('uint8')
            # np.save('samples_' + prefix + '.npy', samples)
            # if save_jpg:
            # samples = tile_imgs(np.floor(np.clip(samples, 0, 255)).astype('uint8'))
            # cv2.imwrite('samples_' + prefix + '_' + str(i_step) + '.jpg', samples)
            if dump_to_tensorboard:
                """You can turn this off if tensorboard crashes for sample dumps. You can view the samples from the npy file anyway"""
                logger.writekvs(
                    [('samples',
                      tile_imgs(np.clip(samples, 0, 255).astype(np.uint8)))],
                    i_step)

    def run_validation(sess, i_step):
        data_val_shard = np.array_split(data_val, hvd.size(),
                                        axis=0)[hvd.rank()]
        shard_losses = np.concatenate([
            sess.run([val_loss_sym], {x_sym: val_batch})
            for val_batch, in iterbatches([data_val_shard],
                                          batch_size=local_bs,
                                          include_final_partial_batch=False)
        ])
        val_loss, total_count = mpi_average(shard_losses)
        if is_root:
            logger.writekvs([('val_bpd', bpd_scale_factor * val_loss),
                             ('num_val_examples', total_count * local_bs)],
                            i_step)

    # Optimization
    lr_sym = tf.placeholder(dtype, [], 'lr')
    optimizer = hvd.DistributedOptimizer(tf.train.AdamOptimizer(lr_sym))
    if scale_loss is None:
        grads_and_vars = optimizer.compute_gradients(loss_sym, var_list=params)
    else:
        grads_and_vars = [(g / scale_loss, v)
                          for (g, v) in optimizer.compute_gradients(
                              loss_sym * scale_loss, var_list=params)]
    if max_grad_norm is not None:
        clipped_grads, grad_norm_sym = tf.clip_by_global_norm(
            [g for (g, _) in grads_and_vars], max_grad_norm)
        grads_and_vars = [
            (cg, v) for (cg, (_, v)) in zip(clipped_grads, grads_and_vars)
        ]
    else:
        grad_norm_sym = tf.constant(0.)
    opt_sym = tf.group(optimizer.apply_gradients(grads_and_vars),
                       maintain_averages_op)

    def loop(sess: tf.Session):
        i_step = 0
        i_step_lr = 0
        if is_root: print('Initializing')
        sess.run(tf.global_variables_initializer())
        # if is_root:
        #     logger.write_graph(sess.graph)

        if restore_checkpoint is not None:
            """If restoring from an existing checkpoint whose path is specified in the launcher"""
            restore_step = int(restore_checkpoint.split('-')[-1])
            if is_root:
                saver = tf.train.Saver()
                print('Restoring checkpoint:', restore_checkpoint)
                print('Restoring from step:', restore_step)
                saver.restore(sess, restore_checkpoint)
                print('Loaded checkpoint')
            else:
                saver = None
            i_step = restore_step
            """You could re-start with the warm-up or start from wherever the checkpoint stopped depending on what is needed.
               If the session had to be stopped due to NaN/Inf, warm-up from a most recent working checkpoint is recommended.
               If it was because of Horovod Crash / Machine Shut down, re-starting from the same LR can be done in which case
               you need to uncomment the blow line. By default, it warms up."""
            i_step_lr = restore_step
        else:
            if is_root: print('Data dependent init')
            sess.run(
                init_syms, {
                    x_init_sym:
                    data_train[np.random.randint(0, data_train.shape[0],
                                                 init_bs)]
                })
            sess.run(copy_params_to_ema)
            saver = tf.train.Saver() if is_root else None
        if is_root: print('Broadcasting initial parameters')
        sess.run(hvd.broadcast_global_variables(0))
        sess.graph.finalize()

        if is_root:
            print('Training')
            print(
                'Parameters(M)',
                sum(np.prod(p.get_shape().as_list())
                    for p in params) / 1024. / 1024.)

        loss_hist = deque(maxlen=steps_per_log)
        """ 2 epochs are sufficient to see good results on Imagenet.
            After 2 epochs, gains are marginal, but important for good bits/dim."""
        for i_epoch in range(n_epochs):
            epoch_start_t = time.time()
            for i_epoch_step, (batch, ) in enumerate(
                    iterbatches(  # non-sharded: each gpu goes through the whole dataset
                        [data_train],
                        batch_size=local_bs,
                        include_final_partial_batch=False,
                    )):
                lr = lr_schedule(i_step_lr)
                loss, _ = sess.run(
                    [loss_sym, opt_sym],
                    {
                        x_sym: batch,
                        lr_sym: lr
                    },
                )
                loss_hist.append(loss)

                if i_epoch == i_epoch_step == 0:
                    epoch_start_t = time.time()

                if i_step % steps_per_log == 0:
                    loss_hist_means = MPI.COMM_WORLD.gather(float(
                        np.mean(loss_hist)),
                                                            root=0)
                    steps_per_sec = (i_epoch_step + 1) / (time.time() -
                                                          epoch_start_t)
                    if is_root:
                        kvs = [
                            ('iter', i_step),
                            ('epoch', i_epoch + i_epoch_step * local_bs /
                             data_train.shape[0]),  # epoch for this gpu
                            ('bpd',
                             float(
                                 np.mean(loss_hist_means) * bpd_scale_factor)),
                            ('lr', float(lr)),
                            ('fps', steps_per_sec * total_bs
                             ),  # fps calculated over all gpus (this epoch)
                            ('sps', steps_per_sec),
                        ]
                        logger.writekvs(kvs, i_step)
                """You could pass the validation for Imagenet because the val set is reasonably big.
                    It is extremely hard to overfit on Imagenet (if you manage to, let us know). 
                    So, skipping the validation throughout the training and validating at the end with the
                    most recent checkpoint would be okay and good for wall clock time.
                    You could also have steps_per_val specified in the launcher pretty high to find a balance."""

                if i_step > 0 and i_step % steps_per_samples == 0 and i_step_lr > 0:
                    run_sampling(
                        sess,
                        i_step=i_step,
                        dump_to_tensorboard=dump_samples_to_tensorboard,
                        save_jpg=save_jpg)
                    print('Run Validation...')
                    run_validation(sess, i_step)

                if i_step % steps_per_dump == 0 and i_step > 0 and i_step_lr > 0:
                    if saver is not None:
                        saver.save(sess,
                                   os.path.join(checkpointdir, 'model'),
                                   global_step=i_step)

                i_step += 1
                i_step_lr += 1
            # End of epoch

    # Train
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(
        hvd.local_rank())  # Pin GPU to local rank (one GPU per process)
    with tf.Session(config=config) as sess:
        loop(sess)
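train() treats lr_schedule as a callable from the (warm-up) step counter to a learning rate. A minimal sketch of such a schedule with linear warm-up; the base rate and warm-up length are assumptions, not values from the source:

# Hypothetical lr_schedule; train() calls it as lr_schedule(i_step_lr).
def lr_schedule(i_step, base_lr=1e-3, warmup_steps=4000):
    if i_step < warmup_steps:
        return base_lr * (i_step + 1) / warmup_steps
    return base_lr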
Example #9
def train(model,
          loss_fn,
          Dataset=None,
          dataset=None,
          valid_dataset=None,
          valid_dataset2=None,
          test_dataset=None,
          evaluate_fn=None,
          inference_fn=None,
          eval_fn=None,
          write_valid=True,
          valid_names=None,
          infer_names=None,
          infer_debug_names=None,
          valid_write_fn=None,
          infer_write_fn=None,
          valid_suffix='.valid',
          infer_suffix='.infer',
          write_streaming=False,
          optimizer=None,
          param_groups=None,
          init_fn=None,
          sep=','):
    use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ

    if Dataset is None:
        assert dataset
    logging.info('Dataset', Dataset, 'dataset', dataset, 'valid_dataset',
                 valid_dataset, 'test_dataset', test_dataset, loss_fn)

    if FLAGS.torch:
        torch.manual_seed(FLAGS.seed or 0)
        if torch.cuda.device_count():
            torch.cuda.manual_seed(FLAGS.seed or 0)
        if use_horovod:
            import horovod.torch as hvd
            hvd.init()
            #print('-----------------', hvd, hvd.size())
            assert hvd.mpi_threads_supported()
            assert hvd.size() == comm.Get_size()
            # hvd.init already done on apps.train.py init
            torch.cuda.set_device(hvd.local_rank())
        # https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
        else:
            if torch.cuda.device_count() > 1:
                model = torch.nn.DataParallel(model)
        model.to(device)

    input_ = FLAGS.train_input
    inputs = gezi.list_files(input_)
    inputs.sort()

    all_inputs = inputs

    #batch_size = FLAGS.batch_size
    batch_size = melt.batch_size()

    num_gpus = melt.num_gpus()

    #batch_size = max(batch_size, 1)
    #batch_size_ = batch_size if not FLAGS.batch_sizes else int(FLAGS.batch_sizes.split(',')[-1])
    batch_size_ = FLAGS.eval_batch_size or batch_size

    if dataset is None:
        if FLAGS.fold is not None:
            inputs = [
                x for x in inputs if not x.endswith('%d.record' % FLAGS.fold)
                and not x.endswith('%d.tfrecord' % FLAGS.fold)
            ]
            # if FLAGS.valid_input:
            #   inputs += [x for x in gezi.list_files(FLAGS.valid_input) if not x.endswith('%d.record' % FLAGS.fold)]
        logging.info('inputs', len(inputs), inputs[:100])
    num_folds = FLAGS.num_folds or len(inputs) + 1

    if dataset is None:
        dataset = Dataset('train')
        assert len(inputs) > 0
        train_dataset = dataset.make_batch(batch_size,
                                           inputs,
                                           simple_parse=FLAGS.simple_parse)
        num_examples = dataset.num_examples_per_epoch('train')
    else:
        assert FLAGS.torch_only, 'only torch_only mode currently supports passing a dataset instance instead of a Dataset class, because len() is not available otherwise'
        train_dataset = dataset
        num_examples = len(train_dataset)

    num_all_examples = num_examples

    if valid_dataset is None:
        valid_inputs = None
        if FLAGS.valid_input:
            valid_inputs = gezi.list_files(FLAGS.valid_input)
        else:
            if FLAGS.fold is not None:
                #valid_inputs = [x for x in all_inputs if x not in inputs]
                if not FLAGS.test_aug:
                    valid_inputs = [
                        x for x in all_inputs
                        if not 'aug' in x and x not in inputs
                    ]
                else:
                    valid_inputs = [
                        x for x in all_inputs if 'aug' in x and x not in inputs
                    ]

        logging.info('valid_inputs', valid_inputs)

    num_valid_examples = None
    if valid_dataset is not None:
        num_valid_examples = len(valid_dataset)
    else:
        if valid_inputs:
            valid_dataset = dataset.make_batch(batch_size_,
                                               valid_inputs,
                                               subset='valid',
                                               hvd_shard=FLAGS.horovod_eval)
            valid_dataset2 = dataset.make_batch(batch_size,
                                                valid_inputs,
                                                subset='valid',
                                                repeat=True,
                                                initializable=False,
                                                hvd_shard=False)
            valid_dataset2_iter = iter(valid_dataset2)
        else:
            valid_dataset = None
            valid_dataset2 = None

    if num_examples:
        if FLAGS.fold is not None:
            num_examples = int(num_examples * (num_folds - 1) / num_folds)
        num_steps_per_epoch = -(-num_examples // batch_size)
    else:
        num_steps_per_epoch = None
    logging.info('num_train_examples:', num_examples)
    if use_horovod and num_examples:
        num_steps_per_epoch = -(-num_examples // (batch_size * hvd.size()))

    if num_valid_examples is None:
        if FLAGS.valid_input:
            num_valid_examples = dataset.num_examples_per_epoch('valid')
            num_valid_steps_per_epoch = -(-num_valid_examples // batch_size_
                                          ) if num_valid_examples else None
        else:
            if FLAGS.fold is not None:
                if num_examples:
                    num_valid_examples = int(num_all_examples *
                                             (1 / num_folds))
                    num_valid_steps_per_epoch = -(-num_valid_examples //
                                                  batch_size_)
                else:
                    num_valid_steps_per_epoch = None
    if use_horovod and FLAGS.horovod_eval and num_valid_examples:
        num_valid_steps_per_epoch = -(-num_valid_examples //
                                      (batch_size_ * hvd.size()))
    logging.info('num_valid_examples:', num_valid_examples)

    if test_dataset is None:
        if FLAGS.test_input:
            test_inputs = gezi.list_files(FLAGS.test_input)
            #test_inputs = [x for x in test_inputs if not 'aug' in x]
            logging.info('test_inputs', test_inputs)
        else:
            test_inputs = None

    num_test_examples = None
    if test_dataset is not None:
        num_test_examples = len(test_dataset)
    else:
        if test_inputs:
            test_dataset = dataset.make_batch(batch_size_,
                                              test_inputs,
                                              subset='test')
            num_test_examples = dataset.num_examples_per_epoch('test')
        else:
            test_dataset = None
    num_test_steps_per_epoch = -(-num_test_examples //
                                 batch_size_) if num_test_examples else None
    if use_horovod and FLAGS.horovod_eval and num_test_examples:
        num_test_steps_per_epoch = -(-num_test_examples //
                                     (batch_size_ * hvd.size()))
    logging.info('num_test_examples:', num_test_examples)

    summary = tf.contrib.summary
    # writer = summary.create_file_writer(FLAGS.log_dir + '/epoch')
    # writer_train = summary.create_file_writer(FLAGS.log_dir + '/train')
    # writer_valid = summary.create_file_writer(FLAGS.log_dir + '/valid')
    writer = summary.create_file_writer(FLAGS.log_dir)
    writer_train = summary.create_file_writer(FLAGS.log_dir)
    writer_valid = summary.create_file_writer(FLAGS.log_dir)
    global_step = tf.train.get_or_create_global_step()
    ## RuntimeError: tf.summary.FileWriter is not compatible with eager execution. Use tf.contrib.summary instead.
    #logger = gezi.SummaryWriter(FLAGS.log_dir)

    learning_rate = tfe.Variable(FLAGS.learning_rate, name="learning_rate")

    tf.add_to_collection('learning_rate', learning_rate)

    learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
    try:
        learning_rate_weights = tf.get_collection('learning_rate_weights')[-1]
    except Exception:
        learning_rate_weights = None

    # ckpt dir saves one model per epoch
    ckpt_dir = os.path.join(FLAGS.model_dir, 'ckpt')
    os.system('mkdir -p %s' % ckpt_dir)
    # HACK: ckpt_dir2 actually saves mini-epochs, e.g. when save_interval_epochs=0.1; this is useful when training on a large dataset
    ckpt_dir2 = os.path.join(FLAGS.model_dir, 'ckpt2')
    os.system('mkdir -p %s' % ckpt_dir2)

    # TODO FIXME: the tf code was changed so it does not, by default, keep only the latest 5 checkpoints
    # refer to https://github.com/tensorflow/tensorflow/issues/22036
    # manager = tf.contrib.checkpoint.CheckpointManager(
    #     checkpoint, directory=ckpt_dir, max_to_keep=5)
    # latest_checkpoint = manager.latest_checkpoint

    latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
    if latest_checkpoint:
        logging.info('Latest checkpoint:', latest_checkpoint)
    else:
        latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir2)
        logging.info('Latest checkpoint:', latest_checkpoint)

    if os.path.exists(FLAGS.model_dir + '.index'):
        latest_checkpoint = FLAGS.model_dir

    if 'test' in FLAGS.work_mode or 'valid' in FLAGS.work_mode:
        #assert not os.path.isdir(FLAGS.model_dir), FLAGS.model_dir
        latest_checkpoint = FLAGS.model_dir
        #assert os.path.exists(latest_checkpoint) and os.path.isfile(latest_checkpoint)

    checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt')
    checkpoint_prefix2 = os.path.join(ckpt_dir2, 'ckpt')

    if not FLAGS.torch:
        try:
            optimizer = optimizer or melt.get_optimizer(
                FLAGS.optimizer)(learning_rate)
        except Exception:
            logging.warning(
                f'Failed to use {FLAGS.optimizer}, using adam instead')
            optimizer = melt.get_optimizer('adam')(learning_rate)

        # TODO...
        if learning_rate_weights is None:
            checkpoint = tf.train.Checkpoint(
                learning_rate=learning_rate,
                learning_rate_weight=learning_rate_weight,
                model=model,
                optimizer=optimizer,
                global_step=global_step)
        else:
            checkpoint = tf.train.Checkpoint(
                learning_rate=learning_rate,
                learning_rate_weight=learning_rate_weight,
                learning_rate_weights=learning_rate_weights,
                model=model,
                optimizer=optimizer,
                global_step=global_step)

        checkpoint.restore(latest_checkpoint)
        checkpoint2 = copy.deepcopy(checkpoint)

        start_epoch = int(
            latest_checkpoint.split('-')
            [-1]) if latest_checkpoint and 'ckpt' in latest_checkpoint else 0
        start_step = 0  # TODO
    else:
        # TODO torch with learning rate adjust
        # https://github.com/horovod/horovod/blob/master/examples/pytorch_mnist.py
        # TODO full support for pytorch does not work yet

        is_dynamic_opt = False
        if optimizer is None:
            import lele
            is_dynamic_opt = True
            if FLAGS.optimizer == 'noam':
                optimizer_ = torch.optim.Adamax(model.parameters(), lr=0)
                if use_horovod:
                    optimizer_ = hvd.DistributedOptimizer(optimizer_)
                optimizer = lele.training.optimizers.NoamOpt(
                    128, 2, 4000, optimizer_)
            elif FLAGS.optimizer == 'bert':
                num_train_steps = int(
                    num_steps_per_epoch *
                    (FLAGS.num_decay_epochs or FLAGS.num_epochs))
                if FLAGS.warmup_steps and use_horovod:
                    FLAGS.warmup_steps = max(
                        int(FLAGS.warmup_steps / hvd.size()), 1)
                num_warmup_steps = FLAGS.warmup_steps or int(
                    num_steps_per_epoch * FLAGS.warmup_epochs) or int(
                        num_train_steps * FLAGS.warmup_proportion)
                logging.info('num_train_steps', num_train_steps,
                             'num_warmup_steps', num_warmup_steps,
                             'warmup_proportion', FLAGS.warmup_proportion)
                optimizer_ = torch.optim.Adamax(model.parameters(), lr=0)
                if use_horovod:
                    optimizer_ = hvd.DistributedOptimizer(optimizer_)
                optimizer = lele.training.optimizers.BertOpt(
                    FLAGS.learning_rate, FLAGS.min_learning_rate,
                    num_train_steps, num_warmup_steps, optimizer_)
            else:
                is_dynamic_opt = False
                optimizer = torch.optim.Adamax(
                    param_groups if param_groups else model.parameters(),
                    lr=FLAGS.learning_rate)
                if use_horovod:
                    optimizer = hvd.DistributedOptimizer(optimizer)

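        # Torch weights/optimizer state live in <checkpoint>.pyt (or model_dir/latest.pyt);
        # restore epoch/step from it unless finetuning, and optionally restore the
        # optimizer state as well.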
        start_epoch = 0
        latest_path = latest_checkpoint + '.pyt' if latest_checkpoint else os.path.join(
            FLAGS.model_dir, 'latest.pyt')
        if not os.path.exists(latest_path):
            latest_path = os.path.join(FLAGS.model_dir, 'latest.pyt')
        if os.path.exists(latest_path):
            logging.info('loading torch model from', latest_path)
            checkpoint = torch.load(latest_path)
            if not FLAGS.torch_finetune:
                start_epoch = checkpoint['epoch']
                step = checkpoint['step']
                global_step.assign(step + 1)
            load_torch_model(model, latest_path)
            if FLAGS.torch_load_optimizer:
                optimizer.load_state_dict(checkpoint['optimizer'])

        # TODO: restored this way, a restart cannot change the learning rate..
        if learning_rate_weights is None:
            checkpoint = tf.train.Checkpoint(
                learning_rate=learning_rate,
                learning_rate_weight=learning_rate_weight,
                global_step=global_step)
        else:
            checkpoint = tf.train.Checkpoint(
                learning_rate=learning_rate,
                learning_rate_weight=learning_rate_weight,
                learning_rate_weights=learning_rate_weights,
                global_step=global_step)

        try:
            checkpoint.restore(latest_checkpoint)
            checkpoint2 = copy.deepcopy(checkpoint)
        except Exception:
            pass

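    # For schedule-driven torch optimizers, sync their internal step counter with the
    # restored global_step so the rate schedule resumes where it left off.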
    if FLAGS.torch and is_dynamic_opt:
        optimizer._step = global_step.numpy()

    #model.load_weights(os.path.join(ckpt_dir, 'ckpt-1'))
    #model.save('./weight3.hd5')
    logging.info('optimizer:', optimizer)

    if FLAGS.torch_lr:
        learning_rate.assign(optimizer.rate(1))
    if FLAGS.torch:
        learning_rate.assign(optimizer.param_groups[0]['lr'])
        logging.info('learning rate loaded from pytorch latest.pyt as',
                     learning_rate.numpy())

    learning_rate.assign(learning_rate * FLAGS.learning_rate_start_factor)
    if learning_rate_weights is not None:
        learning_rate_weights.assign(learning_rate_weights *
                                     FLAGS.learning_rate_start_factor)

    # TODO: fractional epoch counts like 0.1 are not supported yet
    num_epochs = FLAGS.num_epochs if FLAGS.num_epochs != 0 else 1024

    will_valid = (valid_dataset and FLAGS.work_mode != 'test'
                  and 'SHOW' not in os.environ and 'QUICK' not in os.environ)
    if global_step.numpy() == 0:
        will_valid = False

    if gezi.get_env('EVFIRST') == '1':
        will_valid = True

    if gezi.get_env('EVFIRST') == '0':
        will_valid = False

    if will_valid:
        logging.info('----------valid')
        if hasattr(model, 'eval'):
            model.eval()
        names = None
        if evaluate_fn is not None:
            vals, names = evaluate_fn(model, valid_dataset,
                                      tf.train.latest_checkpoint(ckpt_dir),
                                      num_valid_steps_per_epoch)
        elif eval_fn:
            model_path = None if not write_valid else latest_checkpoint
            names = valid_names if valid_names is not None else [
                infer_names[0]
            ] + [x + '_y' for x in infer_names[1:]
                 ] + infer_names[1:] if infer_names else None

            logging.info('model_path:', model_path, 'model_dir:',
                         FLAGS.model_dir)
            vals, names = evaluate(model,
                                   valid_dataset,
                                   eval_fn,
                                   model_path,
                                   names,
                                   valid_write_fn,
                                   write_streaming,
                                   num_valid_steps_per_epoch,
                                   num_valid_examples,
                                   suffix=valid_suffix,
                                   sep=sep)
        if names:
            logging.info2(
                'epoch:%.2f/%d step:%d' %
                (global_step.numpy() / num_steps_per_epoch, num_epochs,
                 global_step.numpy()),
                ['%s:%.4f' % (name, val) for name, val in zip(names, vals)])

        if FLAGS.work_mode == 'valid' or gezi.get_env('METRIC') == '1':
            exit(0)

    if 'test' in FLAGS.work_mode or gezi.get_env(
            'TEST') == '1' or gezi.get_env('INFER') == '1':
        logging.info('--------test/inference')
        if test_dataset:
            if hasattr(model, 'eval'):
                model.eval()
            if inference_fn is None:
                # model_path = FLAGS.model_dir + '.pyt' if not latest_checkpoint else latest_checkpoint
                # logging.info('model_path', model_path)
                assert latest_checkpoint
                inference(model,
                          test_dataset,
                          latest_checkpoint,
                          infer_names,
                          infer_debug_names,
                          infer_write_fn,
                          write_streaming,
                          num_test_steps_per_epoch,
                          num_test_examples,
                          suffix=infer_suffix)
            else:
                inference_fn(model, test_dataset,
                             tf.train.latest_checkpoint(ckpt_dir),
                             num_test_steps_per_epoch)
        exit(0)

    if 'SHOW' in os.environ:
        num_epochs = start_epoch + 1

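    # Minimal shims so the torch path can reuse the tf-style metric calls below:
    # PytObj just exposes .numpy(), and PytMean mimics tfe.metrics.Mean
    # (accumulate on __call__, reset on the first call after result()).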
    class PytObj(object):
        def __init__(self, x):
            self.x = x

        def numpy(self):
            return self.x

    class PytMean(object):
        def __init__(self):
            self._val = 0.
            self.count = 0

            self.is_call = True

        def clear(self):
            self._val = 0
            self.count = 0

        def __call__(self, val):
            if not self.is_call:
                self.clear()
                self.is_call = True
            self._val += val.item()
            self.count += 1

        def result(self):
            if self.is_call:
                self.is_call = False
            if not self.count:
                val = 0
            else:
                val = self._val / self.count
            # TODO: just for compatibility with tf ..
            return PytObj(val)

    Mean = tfe.metrics.Mean if not FLAGS.torch else PytMean

    num_insts = 0

    if FLAGS.learning_rate_decay_factor > 0:
        #assert FLAGS.learning_rate_values is None, 'use exponential_decay or piecewise_constant?'
        # NOTICE: if you finetune or otherwise change batch_size, you'd better set num_steps_per_decay directly,
        # since global_step / decay_steps will no longer map to the correct epoch once num_steps_per_epoch changes;
        # so if you change the batch size you have to reset global_step to a fixed step
        assert FLAGS.num_steps_per_decay or (
            FLAGS.num_epochs_per_decay and num_steps_per_epoch
        ), 'must set num_steps_per_decay or (num_epochs_per_decay and num_steps_per_epoch)'
        decay_steps = FLAGS.num_steps_per_decay or int(
            num_steps_per_epoch * FLAGS.num_epochs_per_decay)
        decay_start_step = FLAGS.decay_start_step or int(
            num_steps_per_epoch * FLAGS.decay_start_epoch)
        # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
        logging.info(
            'learning_rate_decay_factor:{} decay_epochs:{} decay_steps:{} decay_start_epoch:{} decay_start_step:{}'
            .format(FLAGS.learning_rate_decay_factor,
                    FLAGS.num_epochs_per_decay, decay_steps,
                    FLAGS.decay_start_epoch, decay_start_step))

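    # Training loop cadence: log throughput/losses every interval_steps, write scalar
    # summaries every valid_interval_steps, run metric evaluation every
    # metric_eval_interval_steps, save checkpoints every save_interval_steps /
    # save_interval_epochs, and run full validation/inference every
    # valid_interval_epochs / inference_interval_epochs.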
    #-------------------------start training
    if hasattr(model, 'train'):
        model.train()

    timer = gezi.Timer()
    loss_avg = Mean()
    valid_loss_avg = Mean()

    num_epochs = num_epochs if num_epochs else 0
    loops = min(num_epochs, 1) if FLAGS.torch_only else 1
    for epoch in range(start_epoch, start_epoch + loops):
        for i, (x, y) in enumerate(train_dataset):
            # debug only: uncomment to inspect the first few batches and stop;
            # leaving these active would skip the whole training step below
            # print(len(x['index']), len(x['value']), len(x['id']))
            # print(x['index'][0].size(), x['index'][1].size(), y.size())
            # print(x['value'][0].size(), x['value'][1].size(), y.size())
            # print(x['id'][0], x['id'][1], y.size())
            # if i == 3:
            #     exit(0)
            # continue

            if FLAGS.torch:
                x, y = to_torch(x, y)
                if is_dynamic_opt:
                    learning_rate.assign(optimizer.rate())

            def loss_fn_(x, y):
                if not FLAGS.torch and 'training' in inspect.getfullargspec(
                        model.call).args:
                    y_ = model(x, training=True)
                else:
                    y_ = model(x)
                if not FLAGS.torch:
                    return loss_fn(y, y_)
                else:
                    return loss_fn(y_, y)

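            # tf eager path: compute gradients via melt.eager.grad, clip by global norm
            # and apply; torch path: zero_grad / backward / clip_grad_norm_ / step.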
            if not FLAGS.torch:
                loss, grads = melt.eager.grad(model, x, y, loss_fn)
                grads, _ = tf.clip_by_global_norm(grads, FLAGS.clip_gradients)
                #optimizer.apply_gradients(zip(grads, model.variables))
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))
                # https://github.com/horovod/horovod/blob/master/examples/tensorflow_mnist_eager.py
                # Horovod: broadcast initial variable states from rank 0 to all other processes.
                # This is necessary to ensure consistent initialization of all workers when
                # training is started with random weights or restored from a checkpoint.
                # Note: broadcast should be done after the first gradient step to ensure optimizer
                # initialization.
                # TODO check eager mode
                if use_horovod and epoch == start_epoch and i == 0:
                    hvd.broadcast_variables(model.variables, root_rank=0)
                    hvd.broadcast_variables(optimizer.variables(),
                                            root_rank=0)
            else:
                optimizer.zero_grad()
                loss = loss_fn_(x, y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               FLAGS.clip_gradients)
                optimizer.step()

            global_step.assign_add(1)
            loss_avg(loss)

            ## https://discuss.pytorch.org/t/calling-loss-backward-reduce-memory-usage/2735
            # if FLAGS.torch:
            #   del loss

            batch_size_ = (list(x.values())[0].shape[FLAGS.batch_size_dim]
                           if isinstance(x, dict)
                           else x.shape[FLAGS.batch_size_dim])
            num_insts += int(batch_size_)
            if global_step.numpy() % FLAGS.interval_steps == 0:
                #checkpoint.save(checkpoint_prefix)
                elapsed = timer.elapsed()
                steps_per_second = FLAGS.interval_steps / elapsed
                instances_per_second = num_insts / elapsed
                num_insts = 0

                if num_steps_per_epoch is None:
                    epoch_time_info = ''
                else:
                    hours_per_epoch = num_steps_per_epoch / FLAGS.interval_steps * elapsed / 3600
                    epoch_time_info = '1epoch:[{:.2f}h]'.format(
                        hours_per_epoch)

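                # valid_dataset2 is a repeated iterator used for a cheap single-batch
                # validation loss at every logging interval (full validation happens
                # further below at valid_interval_epochs).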
                if valid_dataset2:
                    # try:
                    #   x, y = next(iter(valid_dataset2))
                    # except Exception:
                    #   # TODO/FIXME: how to restart iteration once it stops.. hack for my iterator, see projects/lm/dataset
                    #   x, y = next(iter(valid_dataset2))
                    ## valid_dataset2 is repeated
                    ## NOTICE: next(iter(...)) as below would always return the first batch
                    #x, y = next(iter(valid_dataset2))
                    x, y = next(valid_dataset2_iter)
                    #print(x['id'][0])
                    if FLAGS.torch:
                        x, y = to_torch(x, y)
                    if hasattr(model, 'eval'):
                        model.eval()
                    valid_loss = loss_fn_(x, y)
                    valid_loss = valid_loss.numpy(
                    ) if not FLAGS.torch else valid_loss.item()
                    if hasattr(model, 'train'):
                        model.train()

                    if not use_horovod or hvd.rank() == 0:
                        # 'train_loss:[%.4f]' % loss_avg.result().numpy(),
                        # 'valid_loss:[%.4f]' % valid_loss_avg.result().numpy()
                        logging.info2(
                            'epoch:%.2f/%d' %
                            ((global_step.numpy() / num_steps_per_epoch),
                             num_epochs), 'step:%d' % global_step.numpy(),
                            'elapsed:[%.2f]' % elapsed, 'batch_size:[%d]' %
                            batch_size_, 'gpus:[%d]' % num_gpus,
                            'batches/s:[%.2f]' % steps_per_second,
                            'insts/s:[%d]' % instances_per_second,
                            '%s' % epoch_time_info,
                            'lr:[%.6f]' % learning_rate.numpy(),
                            'train_loss:[%.4f]' % loss_avg.result().numpy(),
                            'valid_loss:[%.4f]' % valid_loss)
                        if global_step.numpy(
                        ) % FLAGS.valid_interval_steps == 0:
                            with writer_valid.as_default(
                            ), summary.always_record_summaries():
                                summary.scalar('loss/valid', valid_loss)
                                writer_valid.flush()
                else:
                    if not use_horovod or hvd.rank() == 0:
                        #'train_loss:[%.4f]' % loss_avg.result().numpy()
                        logging.info2(
                            'epoch:%.2f/%d' %
                            ((epoch + i / num_steps_per_epoch), num_epochs),
                            'step:%d' % global_step.numpy(), 'elapsed:[%.2f]' %
                            elapsed, 'batch_size:[%d]' % batch_size_,
                            'gpus:[%d]' % num_gpus,
                            'batches/s:[%.2f]' % steps_per_second,
                            'insts/s:[%d]' % instances_per_second,
                            '%s' % epoch_time_info,
                            'lr:[%.6f]' % learning_rate.numpy(),
                            'train_loss:[%.4f]' % loss_avg.result().numpy())

                if not use_horovod or hvd.rank() == 0:
                    if global_step.numpy() % FLAGS.valid_interval_steps == 0:
                        with writer_train.as_default(
                        ), summary.always_record_summaries():
                            summary.scalar('loss/train_avg',
                                           loss_avg.result().numpy())
                            summary.scalar('learning_rate',
                                           learning_rate.numpy())
                            summary.scalar('other/batch_size', batch_size_)
                            summary.scalar('other/epoch', melt.epoch())
                            summary.scalar('perf/steps_per_second',
                                           steps_per_second)
                            summary.scalar('perf/instances_per_second',
                                           instances_per_second)
                            writer_train.flush()

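            # Periodic metric evaluation on the validation set; results are logged and
            # written as step_eval/* summaries on rank 0 only. For torch, the (possibly
            # decayed) tf learning rate is pushed back into the optimizer param groups
            # unless FLAGS.torch_lr says the torch side owns the schedule.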
            if valid_dataset and FLAGS.metric_eval_interval_steps and global_step.numpy(
            ) and global_step.numpy() % FLAGS.metric_eval_interval_steps == 0:
                if hasattr(model, 'eval'):
                    model.eval()
                vals, names = None, None
                if evaluate_fn is not None:
                    vals, names = evaluate_fn(model, valid_dataset, None,
                                              num_valid_steps_per_epoch)
                elif eval_fn:
                    names = valid_names if valid_names is not None else [
                        infer_names[0]
                    ] + [x + '_y' for x in infer_names[1:]
                         ] + infer_names[1:] if infer_names else None
                    vals, names = evaluate(model,
                                           valid_dataset,
                                           eval_fn,
                                           None,
                                           names,
                                           valid_write_fn,
                                           write_streaming,
                                           num_valid_steps_per_epoch,
                                           num_valid_examples,
                                           sep=sep)
                if not use_horovod or hvd.rank() == 0:
                    if vals and names:
                        with writer_valid.as_default(
                        ), summary.always_record_summaries():
                            for name, val in zip(names, vals):
                                summary.scalar(f'step_eval/{name}', val)
                            writer_valid.flush()

                if FLAGS.torch:
                    if not FLAGS.torch_lr:
                        # control learning rate by tensorflow learning rate
                        for param_group in optimizer.param_groups:
                            # important learning rate decay
                            param_group['lr'] = learning_rate.numpy()
                if hasattr(model, 'train'):
                    model.train()
                if not use_horovod or hvd.rank() == 0:
                    if names and vals:
                        logging.info2(
                            'epoch:%.2f/%d' %
                            ((global_step.numpy() / num_steps_per_epoch),
                             num_epochs),
                            'valid_step:%d' % global_step.numpy(),
                            'valid_metrics', [
                                '%s:%.5f' % (name, val)
                                for name, val in zip(names, vals)
                            ])

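            # Checkpointing (rank 0 only): the torch state dict goes to latest.pyt every
            # save_interval_steps, and checkpoint2 is saved under ckpt2/ (plus a matching
            # .pyt for torch) every save_interval_epochs fraction of an epoch.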
            if not use_horovod or hvd.rank() == 0:
                # TODO save ok ?
                if global_step.numpy() % FLAGS.save_interval_steps == 0:
                    if FLAGS.torch:
                        state = {
                            'epoch':
                            int(global_step.numpy() / num_steps_per_epoch),
                            'step':
                            global_step.numpy(),
                            'state_dict':
                            model.state_dict() if not hasattr(model, 'module')
                            else model.module.state_dict(),
                            'optimizer':
                            optimizer.state_dict(),
                        }
                        torch.save(state,
                                   os.path.join(FLAGS.model_dir, 'latest.pyt'))

                # TODO/FIXME: why does it fail when both checkpoint2 and checkpoint are used..
                if FLAGS.save_interval_epochs and global_step.numpy() % int(
                        num_steps_per_epoch * FLAGS.save_interval_epochs) == 0:
                    checkpoint2.save(checkpoint_prefix2)
                    if FLAGS.torch:
                        state = {
                            'epoch':
                            int(global_step.numpy() / num_steps_per_epoch),
                            'step':
                            global_step.numpy(),
                            'state_dict':
                            model.state_dict() if not hasattr(model, 'module')
                            else model.module.state_dict(),
                            'optimizer':
                            optimizer.state_dict(),
                        }
                        torch.save(
                            state,
                            tf.train.latest_checkpoint(ckpt_dir2) + '.pyt')

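            # Exponential learning rate decay: once past decay_start_step, multiply the
            # rate by learning_rate_decay_factor every decay_steps steps, floored at
            # min_learning_rate; for torch the new rate is copied into the param groups.
            # E.g. (illustration only) a factor of 0.9 with decay_steps=1000 means
            # lr *= 0.9 every 1000 steps.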
            if FLAGS.learning_rate_decay_factor > 0:
                if global_step.numpy(
                ) >= decay_start_step and global_step.numpy(
                ) % decay_steps == 0:
                    lr = max(
                        learning_rate.numpy() *
                        FLAGS.learning_rate_decay_factor,
                        FLAGS.min_learning_rate)
                    if lr < learning_rate.numpy():
                        learning_rate.assign(lr)
                        if FLAGS.torch:
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = learning_rate.numpy()

            if i == 0:
                try:
                    if not FLAGS.torch:
                        logging.info(model.summary())
                        # #tf.keras.utils.plot_model(model, to_file='/home/gezi/model.png', show_shapes=False, show_layer_names=True, rankdir='TB')
                        # import keras
                        # keras.utils.plot_model(model, to_file='/home/gezi/model.png', show_shapes=False, show_layer_names=True, rankdir='LR', expand_nested=True, dpi=96)
                    else:
                        logging.info(model)
                except Exception:
                    traceback.print_exc()
                    logging.info(
                        'Failed to run model.summary(); maybe a layer is defined in __init__ but not used in call'
                    )
                if 'SHOW' in os.environ:
                    exit(0)

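            # Full validation at every valid_interval_epochs fraction of an epoch, using
            # evaluate_fn if provided, otherwise eval_fn via evaluate(); metrics are
            # logged and written under eval/*, with global_step temporarily remapped
            # (apparently so the summary x-axis counts evaluations rather than steps).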
            if valid_dataset and global_step.numpy() % int(
                    num_steps_per_epoch * FLAGS.valid_interval_epochs) == 0:
                if hasattr(model, 'eval'):
                    model.eval()

                vals, names = None, None
                if evaluate_fn is not None:
                    vals, names = evaluate_fn(
                        model, valid_dataset,
                        tf.train.latest_checkpoint(ckpt_dir),
                        num_valid_steps_per_epoch)
                elif eval_fn:
                    model_path = None if not write_valid else tf.train.latest_checkpoint(
                        ckpt_dir)
                    print('---------metric evaluate step', global_step.numpy(),
                          'model_path:', model_path)
                    names = valid_names if valid_names is not None else [
                        infer_names[0]
                    ] + [x + '_y' for x in infer_names[1:]
                         ] + infer_names[1:] if infer_names else None

                    vals, names = evaluate(model,
                                           valid_dataset,
                                           eval_fn,
                                           model_path,
                                           names,
                                           valid_write_fn,
                                           write_streaming,
                                           num_valid_steps_per_epoch,
                                           num_valid_examples,
                                           suffix=valid_suffix,
                                           sep=sep)

                if not use_horovod or hvd.rank() == 0:
                    if vals and names:
                        logging.info2(
                            'epoch:%.2f/%d' %
                            (global_step.numpy() / num_steps_per_epoch,
                             num_epochs), 'step:%d' % global_step.numpy(),
                            'valid_metrics', [
                                '%s:%.5f' % (name, val)
                                for name, val in zip(names, vals)
                            ])

                if not use_horovod or hvd.rank() == 0:
                    with writer.as_default(), summary.always_record_summaries(
                    ):
                        temp = global_step.value()
                        global_step.assign(
                            int(global_step.numpy() /
                                int(num_steps_per_epoch *
                                    FLAGS.valid_interval_epochs)))
                        if valid_dataset:
                            if hasattr(model, 'eval'):
                                model.eval()
                            if vals and names:
                                for name, val in zip(names, vals):
                                    summary.scalar(f'eval/{name}', val)
                        writer.flush()
                        global_step.assign(temp)

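            # Periodic inference on the test set every inference_interval_epochs,
            # using inference_fn if provided, otherwise the default inference() helper.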
            if test_dataset and global_step.numpy() % int(
                    num_steps_per_epoch *
                    FLAGS.inference_interval_epochs) == 0:
                if hasattr(model, 'eval'):
                    model.eval()
                if inference_fn is None:
                    inference(model,
                              test_dataset,
                              tf.train.latest_checkpoint(ckpt_dir),
                              infer_names,
                              infer_debug_names,
                              infer_write_fn,
                              write_streaming,
                              num_test_steps_per_epoch,
                              num_test_examples,
                              suffix=infer_suffix,
                              sep=sep)
                else:
                    inference_fn(model, test_dataset,
                                 tf.train.latest_checkpoint(ckpt_dir),
                                 num_test_steps_per_epoch)

            if num_epochs and (global_step.numpy() %
                               num_steps_per_epoch) == 0 and int(
                                   global_step.numpy() /
                                   num_steps_per_epoch) == num_epochs:
                logging.info(f'Finished training of {num_epochs} epochs')
                exit(0)