# Example No. 1
# 0
import contextlib


@contextlib.contextmanager
def time_range(message, color_id=None, argb_color=None):
    """A context manager to describe the enclosed block as a nested range

    >>> from cupy import prof
    >>> with prof.time_range('some range in green', color_id=0):
    ...    # do something you want to measure
    ...    pass

    Args:
        message: Name of a range.
        color_id: range color ID. When neither ``color_id`` nor
            ``argb_color`` is given, ``-1`` is passed to ``RangePush``.
        argb_color: range color in ARGB (e.g. 0xFF00FF00 for green)

    Raises:
        ValueError: if both ``color_id`` and ``argb_color`` are given. The
            check runs when the ``with`` block is entered.

    .. seealso:: :func:`cupy.cuda.nvtx.RangePush`
        :func:`cupy.cuda.nvtx.RangePop`
    """
    if color_id is not None and argb_color is not None:
        raise ValueError('Only either color_id or argb_color can be specified')

    if argb_color is not None:
        nvtx.RangePushC(message, argb_color)
    else:
        if color_id is None:
            # NOTE(review): -1 presumably selects the default color -- confirm
            color_id = -1
        nvtx.RangePush(message, color_id)
    try:
        yield
    finally:
        # Pop even when the enclosed block raises, so pushes and pops stay
        # balanced.
        nvtx.RangePop()
    def read(self, datafile, labelfile):
        """Read one input sample and its label map from HDF5 files.

        The ``climate/data`` dataset of ``datafile`` is sliced to
        ``self.channels`` on its last axis and cast to float32. Each channel
        is then min/max normalized with ``self.minvals`` / ``self.maxvals``.
        Finally the result is transposed so the channel axis comes first.
        When ``self.update_on_read`` is set, the stored min/max values are
        first widened with rows 0 and 1 of ``climate/data_stats``. The
        ``climate/labels`` dataset of ``labelfile`` is returned as int32.

        Args:
            datafile: Data file name, relative to ``self.path``.
            labelfile: Label file name, relative to ``self.path``.

        Returns:
            Tuple ``(data, label)``.
        """
        nvtx.RangePush("Read Data", 2)
        try:
            #data (h5py "core" driver: file is read into memory; with
            #backing_store=False nothing is written back)
            with h5.File(self.path + '/' + datafile,
                         "r",
                         driver="core",
                         backing_store=False) as f:
                #widen the stored per-channel min/max with this file's stats
                if self.update_on_read:
                    self.minvals = np.minimum(
                        self.minvals,
                        f['climate']['data_stats'][0, self.channels])
                    self.maxvals = np.maximum(
                        self.maxvals,
                        f['climate']['data_stats'][1, self.channels])
                #select the requested channels (last axis)
                data = f['climate']['data'][:, :, self.channels].astype(
                    np.float32)
                #per-channel min/max normalization
                for c in range(len(self.channels)):
                    data[:, :, c] = (data[:, :, c] - self.minvals[c]) / (
                        self.maxvals[c] - self.minvals[c])
                #move channels first: the model consumes NCHW batches
                data = np.transpose(data, [2, 0, 1])

            #label
            with h5.File(self.path + '/' + labelfile,
                         "r",
                         driver="core",
                         backing_store=False) as f:
                label = f['climate']['labels'][...].astype(np.int32)
        finally:
            #close the NVTX range even if reading fails, so profiler
            #ranges stay balanced
            nvtx.RangePop()  # Read Data

        return data, label
def _save_eval_sample(image_dir, epoch, eval_steps, comm_rank, predictions,
                      labels):
    """Write the first sample of one validation batch to ``image_dir``.

    Two files are written, ``test_pred_epoch<E>_estep<S>_rank<R>`` and
    ``test_label_epoch<E>_estep<S>_rank<R>``. They are PNGs written via
    ``imsave`` when the module-level ``use_scipy`` flag is set, and ``.npy``
    files written via ``np.save`` otherwise. The prediction is reduced to
    class ids with an argmax over axis 2 of the first sample. Both arrays
    are multiplied by 100 (presumably so class ids show up as distinct
    gray levels).
    """
    tag = str(epoch) + '_estep' + str(eval_steps) + '_rank' + str(comm_rank)
    pred_ids = np.argmax(predictions[0, ...], axis=2) * 100
    label_ids = labels[0, ...] * 100
    if use_scipy:
        imsave(image_dir + '/test_pred_epoch' + tag + '.png', pred_ids)
        imsave(image_dir + '/test_label_epoch' + tag + '.png', label_ids)
    else:
        np.save(image_dir + '/test_pred_epoch' + tag + '.npy', pred_ids)
        np.save(image_dir + '/test_label_epoch' + tag + '.npy', label_ids)


def main(blocks, weights, image_dir, checkpoint_dir, trn_sz):
    """Train the Tiramisu segmentation model and evaluate it every epoch.

    Initializes Horovod when the module-level ``horovod`` flag is set. It
    then builds the tf.data pipeline and the model, and runs ``num_epochs``
    epochs of RMSProp training in a MonitoredTrainingSession. After each
    epoch it prints training loss and IoU, then runs a full validation
    pass. That pass prints loss/IoU and dumps the first sample of every
    validation batch to ``image_dir``. NVTX ranges mark the major phases
    for profiling.

    Args:
        blocks: Forwarded to ``create_tiramisu`` as ``nb_layers_per_block``.
        weights: Forwarded to ``create_tiramisu`` as ``loss_weights``.
        image_dir: Directory validation outputs are written to. It is NOT
            created here.
        checkpoint_dir: Currently unused (checkpointing is disabled).
        trn_sz: Forwarded to ``load_data``.
    """
    #init horovod
    nvtx.RangePush("init horovod", 1)
    comm_rank = 0
    comm_local_rank = 0
    comm_size = 1
    if horovod:
        hvd.init()
        comm_rank = hvd.rank()
        comm_local_rank = hvd.local_rank()
        comm_size = hvd.size()
        if comm_rank == 0:
            print("Using distributed computation with Horovod: {} total ranks".
                  format(comm_size))
    nvtx.RangePop()  # init horovod

    #parameters
    batch = 1
    channels = [0, 1, 2, 10]
    num_epochs = 3
    dtype = tf.float16

    #session config
    sess_config = tf.ConfigProto(
        inter_op_parallelism_threads=2,  #1
        intra_op_parallelism_threads=33,  #6
        log_device_placement=False,
        allow_soft_placement=True)
    #each local rank only sees its own GPU
    sess_config.gpu_options.visible_device_list = str(comm_local_rank)

    #get data
    training_graph = tf.Graph()
    if comm_rank == 0:
        print("Loading data...")
    path, trn_data, trn_labels, val_data, val_labels, tst_data, tst_labels = load_data(
        trn_sz)
    if comm_rank == 0:
        print("Shape of trn_data is {}".format(trn_data.shape[0]))
        print("done.")

    with training_graph.as_default():
        nvtx.RangePush("TF Init", 3)
        #create datasets (file names are fed in through these placeholders)
        datafiles = tf.placeholder(tf.string, shape=[None])
        labelfiles = tf.placeholder(tf.string, shape=[None])
        trn_reader = h5_input_reader(path, channels, update_on_read=True)
        trn_dataset = create_dataset(trn_reader, datafiles, labelfiles, batch,
                                     num_epochs, comm_size, comm_rank, True)
        val_reader = h5_input_reader(path, channels, update_on_read=False)
        val_dataset = create_dataset(val_reader, datafiles, labelfiles, batch,
                                     1, comm_size, comm_rank)

        #create a feedable iterator shared by the training and validation sets
        handle = tf.placeholder(tf.string,
                                shape=[],
                                name="iterator-placeholder")
        iterator = tf.data.Iterator.from_string_handle(
            handle, (tf.float32, tf.int32),
            ((batch, len(channels), image_height, image_width),
             (batch, image_height, image_width)))
        next_elem = iterator.get_next()

        #create init handles
        #trn
        trn_iterator = trn_dataset.make_initializable_iterator()
        trn_handle_string = trn_iterator.string_handle()
        trn_init_op = iterator.make_initializer(trn_dataset)
        #val
        val_iterator = val_dataset.make_initializable_iterator()
        val_handle_string = val_iterator.string_handle()
        val_init_op = iterator.make_initializer(val_dataset)

        #set up model
        logit, prediction = create_tiramisu(3,
                                            next_elem[0],
                                            image_height,
                                            image_width,
                                            len(channels),
                                            loss_weights=weights,
                                            nb_layers_per_block=blocks,
                                            p=0.2,
                                            wd=1e-4,
                                            dtype=dtype)
        loss = tf.losses.sparse_softmax_cross_entropy(labels=next_elem[1],
                                                      logits=logit)
        global_step = tf.train.get_or_create_global_step()
        #set up optimizer
        opt = tf.train.RMSPropOptimizer(learning_rate=1e-3)
        if horovod:
            opt = hvd.DistributedOptimizer(opt)
        train_op = opt.minimize(loss, global_step=global_step)

        #set up streaming metrics: mean_iou takes (labels, predictions) as
        #class ids, so compare the raw labels with the argmax of the
        #per-pixel class scores
        iou_op, iou_update_op = tf.metrics.mean_iou(
            labels=next_elem[1],
            predictions=tf.argmax(prediction, axis=3),
            num_classes=3,
            weights=None,
            metrics_collections=None,
            updates_collections=None,
            name="iou_score")
        #zeroes the IoU accumulators so training and evaluation IoU are
        #not mixed
        iou_reset_op = tf.variables_initializer([
            v for v in tf.local_variables() if v.name.startswith('iou_score/')
        ])

        #compute epochs and stuff:
        num_samples = trn_data.shape[0] // comm_size
        num_steps_per_epoch = num_samples // batch
        num_steps = num_epochs * num_steps_per_epoch

        #hooks
        #these hooks are essential. regularize the step hook by adding one additional step at the end
        hooks = [tf.train.StopAtStepHook(last_step=num_steps + 1)]
        if horovod:
            hooks.append(hvd.BroadcastGlobalVariablesHook(0))
        #initializers:
        init_op = tf.global_variables_initializer()
        init_local_op = tf.local_variables_initializer()

        #TODO: checkpointing (a CheckpointSaverHook on rank 0) is disabled;
        #checkpoint_dir is therefore unused.

        #start session
        with tf.train.MonitoredTrainingSession(config=sess_config,
                                               hooks=hooks) as sess:
            #initialize
            sess.run([init_op, init_local_op])
            #create iterator handles
            trn_handle, val_handle = sess.run(
                [trn_handle_string, val_handle_string])
            #init iterators
            sess.run(trn_init_op,
                     feed_dict={
                         handle: trn_handle,
                         datafiles: trn_data,
                         labelfiles: trn_labels
                     })
            sess.run(val_init_op,
                     feed_dict={
                         handle: val_handle,
                         datafiles: val_data,
                         labelfiles: val_labels
                     })

            nvtx.RangePop()  # TF Init

            #do the training
            epoch = 1
            step = 1
            train_loss = 0.
            nvtx.RangePush("Training Loop", 4)
            nvtx.RangePush("Epoch", epoch)
            start_time = time.time()
            while not sess.should_stop():
                #one training step; the Step range is closed on every exit
                nvtx.RangePush("Step", step)
                try:
                    _, _, train_steps, tmp_loss = sess.run(
                        [train_op, iou_update_op, global_step, loss],
                        feed_dict={handle: trn_handle})
                except tf.errors.OutOfRangeError:
                    #training dataset exhausted
                    break
                finally:
                    nvtx.RangePop()  # Step
                train_steps_in_epoch = train_steps % num_steps_per_epoch
                train_loss += tmp_loss
                step += 1

                if train_steps_in_epoch > 0:
                    #print step report
                    print(
                        "REPORT: rank {}, training loss for step {} (of {}) is {}"
                        .format(comm_rank, train_steps, num_steps,
                                train_loss / train_steps_in_epoch))
                    continue

                #end of a training epoch: print epoch report
                end_time = time.time()
                train_loss /= num_steps_per_epoch
                print(
                    "COMPLETED: rank {}, training loss for epoch {} (of {}) is {}, epoch duration {} s"
                    .format(comm_rank, epoch, num_epochs, train_loss,
                            end_time - start_time))
                nvtx.RangePush("IOU", 6)
                iou_score = sess.run(iou_op)
                nvtx.RangePop()  # IOU
                print(
                    "COMPLETED: rank {}, training IoU for epoch {} (of {}) is {}, epoch duration {} s"
                    .format(comm_rank, epoch, num_epochs, iou_score,
                            end_time - start_time))
                start_time = time.time()

                #evaluation loop
                eval_loss = 0.
                eval_steps = 0
                #the validation reader must normalize like the training one
                val_reader.minvals = trn_reader.minvals
                val_reader.maxvals = trn_reader.maxvals
                #evaluation IoU starts from a clean confusion matrix
                sess.run(iou_reset_op)
                nvtx.RangePush("Eval Loop", 7)
                while True:
                    try:
                        _, tmp_loss, val_model_predictions, val_model_labels = sess.run(
                            [iou_update_op, loss, prediction, next_elem[1]],
                            feed_dict={handle: val_handle})
                    except tf.errors.OutOfRangeError:
                        #validation dataset exhausted
                        break
                    _save_eval_sample(image_dir, epoch, eval_steps, comm_rank,
                                      val_model_predictions, val_model_labels)
                    eval_loss += tmp_loss
                    eval_steps += 1
                eval_steps = max(eval_steps, 1)
                eval_loss /= eval_steps
                #epoch has not been incremented yet, so report it as-is
                print(
                    "COMPLETED: rank {}, evaluation loss for epoch {} (of {}) is {}"
                    .format(comm_rank, epoch, num_epochs, eval_loss))
                iou_score = sess.run(iou_op)
                print(
                    "COMPLETED: rank {}, evaluation IoU for epoch {} (of {}) is {}"
                    .format(comm_rank, epoch, num_epochs, iou_score))
                #rewind the validation iterator for the next epoch
                sess.run(val_init_op,
                         feed_dict={
                             handle: val_handle,
                             datafiles: val_data,
                             labelfiles: val_labels
                         })
                nvtx.RangePop()  # Eval Loop

                #reset counters; training IoU of the next epoch starts fresh
                sess.run(iou_reset_op)
                epoch += 1
                train_loss = 0.
                step = 1

                nvtx.RangePop()  # Epoch
                nvtx.RangePush("Epoch", epoch)

            nvtx.RangePop()  # Epoch
            nvtx.RangePop()  # Training Loop
# Example No. 4
# 0
 def __enter__(self):
     """Push the NVTX range described by this object and return ``self``.

     An ARGB color, when set, is pushed with ``RangePushC``. Otherwise the
     range is pushed with ``RangePush`` and ``self.color_id``.
     """
     argb = self.argb_color
     if argb is None:
         nvtx.RangePush(self.message, self.color_id)
     else:
         nvtx.RangePushC(self.message, argb)
     return self
# Example No. 5
# 0
 def test_RangePush(self):
     """Pushing a colored range and popping it right away must not raise."""
     label, color = "test:RangePush", 1
     nvtx.RangePush(label, color)
     nvtx.RangePop()
# Example No. 6
# 0
def main(input_path, blocks, weights, image_dir, checkpoint_dir, trn_sz,
         learning_rate, loss_type, fs_type, opt_type, batch, batchnorm,
         num_epochs, dtype, chkpt, filter_sz, growth, disable_training,
         enable_tf_timeline):
    options = None
    run_metadata = None
    many_runs_timeline = None

    timeline_trace_fp = open("timeline_trace.pickle", "wb")

    options, run_metadata, many_runs_timeline, min_timeline_step, max_timeline_step = \
        init_timeline_configs(enable_tf_timeline, tf.RunOptions.FULL_TRACE, -1, -1)

    global_time_logger = logger(-1, "Global Total Time", -1, True)
    global_time_logger.start_timer()

    #init horovod

    initialization_timer_logger = logger(-1, "Initialize Horovod", -1, True)
    initialization_timer_logger.start_timer()

    nvtx.RangePush("init horovod", 1)
    comm_rank = 0
    comm_local_rank = 0
    comm_size = 1
    comm_local_size = 1
    if horovod:
        hvd.init()
        comm_rank = hvd.rank()
        comm_local_rank = hvd.local_rank()
        comm_size = hvd.size()
        #not all horovod versions have that implemented
        try:
            comm_local_size = hvd.local_size()
        except:
            comm_local_size = 1
        if comm_rank == 0:
            print("Using distributed computation with Horovod: {} total ranks".
                  format(comm_size, comm_rank))
    nvtx.RangePop()  # init horovod

    initialization_timer_logger.set_rank(int(comm_rank))
    initialization_timer_logger.end_timer()

    global_time_logger.set_rank(int(comm_rank))

    #parameters
    channels = [0, 1, 2, 10]
    per_rank_output = False
    loss_print_interval = 1

    #session config

    initialization_timer_logger.start_timer(comm_rank, "Configure Session")

    sess_config = tf.ConfigProto(
        inter_op_parallelism_threads=6,  #1
        intra_op_parallelism_threads=1,  #6
        log_device_placement=False,
        allow_soft_placement=True)
    sess_config.gpu_options.visible_device_list = str(comm_local_rank)

    initialization_timer_logger.end_timer()

    #get data

    initialization_timer_logger.start_timer(comm_rank, "Get Data")

    training_graph = tf.Graph()
    if comm_rank == 0:
        print("Loading data...")
    trn_data, val_data, tst_data = load_data(input_path, trn_sz, comm_rank)
    if comm_rank == 0:
        print("Shape of trn_data is {}".format(trn_data.shape[0]))
        print("done.")

    initialization_timer_logger.end_timer()

    #print some stats
    if comm_rank == 0:
        print("Learning Rate: {}".format(learning_rate))
        print("Num workers: {}".format(comm_size))
        print("Local batch size: {}".format(batch))
        if dtype == tf.float32:
            print("Precision: {}".format("FP32"))
        else:
            print("Precision: {}".format("FP16"))
        print("Batch normalization: {}".format(batchnorm))
        print("Blocks: {}".format(blocks))
        print("Growth rate: {}".format(growth))
        print("Filter size: {}".format(filter_sz))
        print("Channels: {}".format(channels))
        print("Loss type: {}".format(loss_type))
        print("Loss weights: {}".format(weights))
        print("Optimizer type: {}".format(opt_type))
        print("Num training samples: {}".format(trn_data.shape[0]))
        print("Num validation samples: {}".format(val_data.shape[0]))

    io_training_time_logger = logger(comm_rank, "IO and Training", -1, True)
    io_training_time_logger.start_timer()

    with training_graph.as_default():
        nvtx.RangePush("TF Init", 3)
        #create readers
        trn_reader = h5_input_reader(input_path,
                                     channels,
                                     weights,
                                     dtype,
                                     normalization_file="stats.h5",
                                     update_on_read=False,
                                     comm_rank=comm_rank)
        val_reader = h5_input_reader(input_path,
                                     channels,
                                     weights,
                                     dtype,
                                     normalization_file="stats.h5",
                                     update_on_read=False,
                                     comm_rank=comm_rank)
        #create datasets
        if fs_type == "local":
            trn_dataset = create_dataset(trn_reader,
                                         trn_data,
                                         batch,
                                         num_epochs,
                                         comm_local_size,
                                         comm_local_rank,
                                         dtype,
                                         shuffle=True)
            val_dataset = create_dataset(val_reader,
                                         val_data,
                                         batch,
                                         1,
                                         comm_local_size,
                                         comm_local_rank,
                                         dtype,
                                         shuffle=False)
        else:
            trn_dataset = create_dataset(trn_reader,
                                         trn_data,
                                         batch,
                                         num_epochs,
                                         comm_size,
                                         comm_rank,
                                         dtype,
                                         shuffle=True)
            val_dataset = create_dataset(val_reader,
                                         val_data,
                                         batch,
                                         1,
                                         comm_size,
                                         comm_rank,
                                         dtype,
                                         shuffle=False)

        #create iterators
        handle = tf.placeholder(tf.string,
                                shape=[],
                                name="iterator-placeholder")
        iterator = tf.data.Iterator.from_string_handle(
            handle, (dtype, tf.int32, dtype),
            ((batch, len(channels), image_height, image_width),
             (batch, image_height, image_width),
             (batch, image_height, image_width)))
        next_elem = iterator.get_next()

        #create init handles
        #trn
        trn_iterator = trn_dataset.make_initializable_iterator()
        trn_handle_string = trn_iterator.string_handle()
        trn_init_op = iterator.make_initializer(trn_dataset)
        #val
        val_iterator = val_dataset.make_initializable_iterator()
        val_handle_string = val_iterator.string_handle()
        val_init_op = iterator.make_initializer(val_dataset)

        #set up model
        logit, prediction = create_tiramisu(3,
                                            next_elem[0],
                                            image_height,
                                            image_width,
                                            len(channels),
                                            loss_weights=weights,
                                            nb_layers_per_block=blocks,
                                            p=0.2,
                                            wd=1e-4,
                                            dtype=dtype,
                                            batchnorm=batchnorm,
                                            growth_rate=growth,
                                            filter_sz=filter_sz,
                                            comm_rank=comm_rank)

        #set up loss
        labels_one_hot = tf.cast(tf.contrib.layers.one_hot_encoding(
            next_elem[1], 3),
                                 dtype=dtype)
        loss = None
        if loss_type == "weighted":
            loss = tf.losses.softmax_cross_entropy(
                onehot_labels=labels_one_hot,
                logits=logit,
                weights=next_elem[2])
        elif loss_type == "focal":
            loss = focal_loss(onehot_labels=labels_one_hot,
                              logits=logit,
                              alpha=1.,
                              gamma=2.)
        else:
            raise ValueError("Error, loss type {} not supported.",
                             format(loss_type))
        if horovod:
            loss_avg = hvd.allreduce(tf.cast(loss, tf.float32))
        else:
            loss_avg = tf.identity(loss)

        #set up global step
        global_step = tf.train.get_or_create_global_step()

        #set up optimizer
        if opt_type.startswith("LARC"):
            if comm_rank == 0:
                print("Enabling LARC")
            train_op = get_larc_optimizer(opt_type.split("-")[1],
                                          loss,
                                          global_step,
                                          learning_rate,
                                          LARC_mode="clip",
                                          LARC_eta=0.002,
                                          LARC_epsilon=1. / 16000.)
        else:
            train_op = get_optimizer(opt_type, loss, global_step,
                                     learning_rate)
        #set up streaming metrics
        iou_op, iou_update_op = tf.metrics.mean_iou(labels=next_elem[1],
                                                    predictions=tf.argmax(
                                                        prediction, axis=3),
                                                    num_classes=3,
                                                    weights=None,
                                                    metrics_collections=None,
                                                    updates_collections=None,
                                                    name="iou_score")
        iou_reset_op = tf.variables_initializer([
            i for i in tf.local_variables() if i.name.startswith('iou_score/')
        ])

        if horovod:
            iou_avg = hvd.allreduce(iou_op)
        else:
            iou_avg = tf.identity(iou_op)

        #compute epochs and stuff:
        if fs_type == "local":
            num_samples = trn_data.shape[0] // comm_local_size
        else:
            num_samples = trn_data.shape[0] // comm_size
        #num_steps_per_epoch = num_samples // batch
        num_steps_per_epoch = 10
        num_steps = num_epochs * num_steps_per_epoch
        if per_rank_output:
            print("Rank {} does {} steps per epoch".format(
                comm_rank, num_steps_per_epoch))

        #hooks
        #these hooks are essential. regularize the step hook by adding one additional step at the end
        hooks = [tf.train.StopAtStepHook(last_step=num_steps + 1)]
        #bcast init for bcasting the model after start
        init_bcast = hvd.broadcast_global_variables(0)
        #initializers:
        init_op = tf.global_variables_initializer()
        init_local_op = tf.local_variables_initializer()

        #checkpointing
        if comm_rank == 0:
            checkpoint_save_freq = num_steps_per_epoch * 2
            checkpoint_saver = tf.train.Saver(max_to_keep=1000)
            listener = checkpoint_listener(comm_rank, True)
            hooks.append(
                tf.train.CheckpointSaverHook(checkpoint_dir=checkpoint_dir,
                                             save_steps=checkpoint_save_freq,
                                             saver=checkpoint_saver,
                                             listeners=[listener]))
            #create image dir if not exists
            if not os.path.isdir(image_dir):
                os.makedirs(image_dir)

        ##DEBUG
        ##summary
        #if comm_rank == 0:
        #    print("write graph for debugging")
        #    tf.summary.scalar("loss",loss)
        #    summary_op = tf.summary.merge_all()
        #    #hooks.append(tf.train.SummarySaverHook(save_steps=num_steps_per_epoch, summary_writer=summary_writer, summary_op=summary_op))
        #    with tf.Session(config=sess_config) as sess:
        #        sess.run([init_op, init_local_op])
        #        #create iterator handles
        #        trn_handle = sess.run(trn_handle_string)
        #        #init iterators
        #        sess.run(trn_init_op, feed_dict={handle: trn_handle, datafiles: trn_data, labelfiles: trn_labels})
        #        #summary:
        #        sess.run(summary_op, feed_dict={handle: trn_handle})
        #        #summary file writer
        #        summary_writer = tf.summary.FileWriter('./logs', sess.graph)
        ##DEBUG

        #start session
        with tf.train.MonitoredTrainingSession(config=sess_config,
                                               hooks=hooks) as sess:
            #initialize
            sess.run([init_op, init_local_op])

            #restore from checkpoint:
            if comm_rank == 0:
                load_model(sess, checkpoint_saver, checkpoint_dir, comm_rank)
            #broadcast loaded model variables
            sess.run(init_bcast)

            #create iterator handles
            trn_handle, val_handle = sess.run(
                [trn_handle_string, val_handle_string],
                options=options,
                run_metadata=run_metadata)

            update_timeline_in_range(enable_tf_timeline, run_metadata,
                                     many_runs_timeline,
                                     "create_iterator_handle.json")

            #init iterators
            sess.run(trn_init_op,
                     feed_dict={handle: trn_handle},
                     options=options,
                     run_metadata=run_metadata)

            update_timeline_in_range(enable_tf_timeline, run_metadata,
                                     many_runs_timeline,
                                     "init_train_iterator_handle.json")

            sess.run(val_init_op,
                     feed_dict={handle: val_handle},
                     options=options,
                     run_metadata=run_metadata)

            update_timeline_in_range(enable_tf_timeline, run_metadata,
                                     many_runs_timeline,
                                     "init_val_iterator_handle.json")

            nvtx.RangePop()  # TF Init

            # do the training
            epoch = 1
            step = 1
            train_loss = 0.
            nvtx.RangePush("Training Loop", 4)
            nvtx.RangePush("Epoch", epoch)
            start_time = time.time()

            training_loop_timer_logger = logger(comm_rank, "Training Loop", -1,
                                                True)
            training_loop_timer_logger.start_timer()

            train_steps = 0
            while not (sess.should_stop()):
                #training loop
                try:
                    training_iteration_time_logger = logger(
                        comm_rank, "Training Iteration", epoch, True)
                    training_iteration_time_logger.start_timer()

                    nvtx.RangePush("Step", step)

                    if disable_training:
                        train_steps = sess.run([global_step],
                                               feed_dict={handle: trn_handle},
                                               options=options,
                                               run_metadata=run_metadata)

                        update_timeline_in_range(
                            enable_tf_timeline, run_metadata,
                            many_runs_timeline, train_steps[0],
                            "train_" + str(global_step) + ".json",
                            min_timeline_step, max_timeline_step)

                        train_steps_in_epoch = train_steps[
                            0] % num_steps_per_epoch

                        # do the validation phase
                        if train_steps_in_epoch == 0:
                            eval_steps = 0
                            while True:
                                try:
                                    sess.run([next_elem[1]],
                                             feed_dict={handle: val_handle},
                                             options=options,
                                             run_metadata=run_metadata)

                                    update_timeline_in_range(
                                        enable_tf_timeline, run_metadata,
                                        many_runs_timeline,
                                        "val_dict" + str(eval_steps) + ".json")

                                    eval_steps += 1
                                except tf.errors.OutOfRangeError:
                                    sess.run(val_init_op,
                                             feed_dict={handle: val_handle},
                                             options=options,
                                             run_metadata=run_metadata)

                                    update_timeline_in_range(
                                        enable_tf_timeline, run_metadata,
                                        many_runs_timeline, "val_dict_out_" +
                                        str(eval_steps) + ".json")

                                    break

                    else:
                        # construct feed dict
                        _, train_steps, tmp_loss = sess.run(
                            [
                                train_op, global_step,
                                (loss if per_rank_output else loss_avg)
                            ],
                            feed_dict={handle: trn_handle},
                            options=options,
                            run_metadata=run_metadata)

                        update_timeline_in_range(
                            enable_tf_timeline, run_metadata,
                            many_runs_timeline, train_steps,
                            "val_" + str(global_step) + ".json",
                            min_timeline_step, max_timeline_step)

                        if comm_rank == 0:
                            step_trace_fp = open(
                                "train_step_trace_" + str(global_step) +
                                ".pickle", "wb")
                            pickle.dump(run_metadata, step_trace_fp)

                        train_steps_in_epoch = train_steps % num_steps_per_epoch
                        train_loss += tmp_loss
                        nvtx.RangePop()  # Step
                        step += 1

                        #print step report
                        eff_steps = train_steps_in_epoch if (
                            train_steps_in_epoch > 0) else num_steps_per_epoch
                        if (train_steps % loss_print_interval) == 0:
                            if per_rank_output:
                                print(
                                    "REPORT: rank {}, training loss for step {} (of {}) is {}, time {}"
                                    .format(comm_rank, train_steps, num_steps,
                                            train_loss / eff_steps,
                                            time.time() - start_time))
                            else:
                                if comm_rank == 0:
                                    print(
                                        "REPORT: training loss for step {} (of {}) is {}, time {}"
                                        .format(train_steps, num_steps,
                                                train_loss / eff_steps,
                                                time.time() - start_time))

                        #do the validation phase
                        if train_steps_in_epoch == 0:
                            end_time = time.time()
                            #print epoch report
                            train_loss /= num_steps_per_epoch
                            if per_rank_output:
                                print(
                                    "COMPLETED: rank {}, training loss for epoch {} (of {}) is {}, time {} s"
                                    .format(comm_rank, epoch, num_epochs,
                                            train_loss,
                                            time.time() - start_time))
                            else:
                                if comm_rank == 0:
                                    print(
                                        "COMPLETED: training loss for epoch {} (of {}) is {}, time {} s"
                                        .format(epoch, num_epochs, train_loss,
                                                time.time() - start_time))

                            #evaluation loop
                            eval_loss = 0.
                            eval_steps = 0
                            nvtx.RangePush("Eval Loop", 7)
                            timeline_help_count = 0
                            while True:
                                try:
                                    #construct feed dict
                                    _, tmp_loss, val_model_predictions, val_model_labels = sess.run(
                                        [
                                            iou_update_op,
                                            (loss
                                             if per_rank_output else loss_avg),
                                            prediction, next_elem[1]
                                        ],
                                        feed_dict={handle: val_handle},
                                        options=options,
                                        run_metadata=run_metadata)

                                    update_timeline_in_range(
                                        enable_tf_timeline, run_metadata,
                                        many_runs_timeline,
                                        timeline_help_count,
                                        "train_" + str(global_step) + ".json",
                                        min_timeline_step, max_timeline_step)

                                    if comm_rank == 0:
                                        step_trace_fp = open(
                                            "validation_step_trace_" +
                                            str(global_step) + ".pickle", "wb")
                                        pickle.dump(run_metadata,
                                                    step_trace_fp)

                                    timeline_help_count += 1

                                    #print some images
                                    if comm_rank == 0:
                                        if have_imsave:
                                            imsave(
                                                image_dir +
                                                '/test_pred_epoch' +
                                                str(epoch) + '_estep' +
                                                str(eval_steps) + '_rank' +
                                                str(comm_rank) + '.png',
                                                np.argmax(
                                                    val_model_predictions[0,
                                                                          ...],
                                                    axis=2) * 100)
                                            imsave(
                                                image_dir +
                                                '/test_label_epoch' +
                                                str(epoch) + '_estep' +
                                                str(eval_steps) + '_rank' +
                                                str(comm_rank) + '.png',
                                                val_model_labels[0, ...] * 100)
                                            imsave(
                                                image_dir +
                                                '/test_combined_epoch' +
                                                str(epoch) + '_estep' +
                                                str(eval_steps) + '_rank' +
                                                str(comm_rank) + '.png',
                                                colormap[
                                                    val_model_labels[0, ...],
                                                    np.argmax(
                                                        val_model_predictions[
                                                            0, ...],
                                                        axis=2)])
                                        else:
                                            np.save(
                                                image_dir +
                                                '/test_pred_epoch' +
                                                str(epoch) + '_estep' +
                                                str(eval_steps) + '_rank' +
                                                str(comm_rank) + '.npy',
                                                np.argmax(
                                                    val_model_predictions[0,
                                                                          ...],
                                                    axis=2) * 100)
                                            np.save(
                                                image_dir +
                                                '/test_label_epoch' +
                                                str(epoch) + '_estep' +
                                                str(eval_steps) + '_rank' +
                                                str(comm_rank) + '.npy',
                                                val_model_labels[0, ...] * 100)

                                    eval_loss += tmp_loss
                                    eval_steps += 1
                                except tf.errors.OutOfRangeError:
                                    eval_steps = np.max([eval_steps, 1])
                                    eval_loss /= eval_steps
                                    if per_rank_output:
                                        print(
                                            "COMPLETED: rank {}, evaluation loss for epoch {} (of {}) is {}"
                                            .format(comm_rank, epoch,
                                                    num_epochs, eval_loss))
                                    else:
                                        if comm_rank == 0:
                                            print(
                                                "COMPLETED: evaluation loss for epoch {} (of {}) is {}"
                                                .format(
                                                    epoch, num_epochs,
                                                    eval_loss))
                                    if per_rank_output:
                                        iou_score = sess.run(iou_op)

                                        print(
                                            "COMPLETED: rank {}, evaluation IoU for epoch {} (of {}) is {}"
                                            .format(comm_rank, epoch,
                                                    num_epochs, iou_score))
                                    else:
                                        iou_score = sess.run(iou_avg)

                                        if comm_rank == 0:
                                            print(
                                                "COMPLETED: evaluation IoU for epoch {} (of {}) is {}"
                                                .format(
                                                    epoch, num_epochs,
                                                    iou_score))
                                    sess.run(iou_reset_op)

                                    sess.run(val_init_op,
                                             feed_dict={handle: val_handle},
                                             options=options,
                                             run_metadata=run_metadata)

                                    update_timeline_in_range(
                                        enable_tf_timeline, run_metadata,
                                        many_runs_timeline,
                                        "train_" + str(global_step) + ".json")

                                    if comm_rank == 0:
                                        step_trace_fp = open(
                                            "validation_step_trace_out.pickle",
                                            "wb")
                                        pickle.dump(run_metadata,
                                                    step_trace_fp)

                                    break
                            nvtx.RangePop()  # Eval Loop

                    if enable_tf_timeline:
                        many_runs_timeline.save('Timeliner_output.json')

                    # reset counters
                    epoch += 1
                    train_loss = 0.
                    step = 0

                    nvtx.RangePop()  # Epoch
                    nvtx.RangePush("Epoch", epoch)

                    training_iteration_time_logger.end_timer()

                except tf.errors.OutOfRangeError:
                    break

            nvtx.RangePop()  # Epoch
            nvtx.RangePop()  # Training Loop

            training_loop_timer_logger.end_timer()

    if enable_tf_timeline:
        many_runs_timeline.save('Timeliner_output.json')

    io_training_time_logger.end_timer()
    global_time_logger.end_timer()