Exemplo n.º 1
0
    def train(
        self,
        iter_unit,
        num_iter,
        batch_size,
        warmup_steps=50,
        weight_decay=1e-4,
        lr_init=0.1,
        lr_warmup_epochs=5,
        momentum=0.9,
        log_every_n_steps=1,
        loss_scale=256,
        label_smoothing=0.0,
        use_cosine_lr=False,
        use_static_loss_scaling=False,
        is_benchmark=False
    ):
        """Train the image classifier using the estimator from ``self._get_estimator``.

        Input comes from DALI (when ``use_dali``, ``data_idx_dir`` and ``data_dir``
        are all set), from TFRecords (when ``data_dir`` is set), or from the
        synthetic input function otherwise (only allowed with ``is_benchmark``).

        Args:
            iter_unit: ``"epoch"`` or ``"batch"``; the unit in which ``num_iter``
                is expressed.
            num_iter: Number of iterations to train, in units of ``iter_unit``.
                Without ``data_dir`` it is used directly as the step count.
            batch_size: Per-GPU batch size; the global batch size is this times
                the number of Horovod ranks (1 without Horovod).
            warmup_steps: Forwarded to ``hooks.BenchmarkLoggingHook`` in
                benchmark mode.
            weight_decay, lr_init, lr_warmup_epochs, momentum, loss_scale,
            label_smoothing, use_cosine_lr: Forwarded to the estimator as
                run params.
            log_every_n_steps: Logging frequency passed to the logging hook.
            use_static_loss_scaling: Under AMP/FP16, selects static loss scaling
                (sets ``TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING=0``) instead
                of automatic scaling. Forced to False for FP32 training.
                Forwarded as ``apply_loss_scaling``.
            is_benchmark: Log with the benchmark hook and allow running without
                ``data_dir`` (synthetic data).

        Raises:
            ValueError: If ``iter_unit`` is unknown, or ``data_dir`` is missing
                outside benchmark mode.
        """
        if iter_unit not in ["epoch", "batch"]:
            raise ValueError('`iter_unit` value is unknown: %s (allowed: ["epoch", "batch"])' % iter_unit)

        if self.run_hparams.data_dir is None and not is_benchmark:
            raise ValueError('`data_dir` must be specified for training!')

        if self.run_hparams.use_tf_amp or self.run_hparams.dtype == tf.float16:
            if use_static_loss_scaling:
                os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "0"
            else:
                LOGGER.log("TF Loss Auto Scaling is activated")
                os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "1"
        else:
            use_static_loss_scaling = False  # Make sure it hasn't been set to True on FP32 training

        num_gpus = 1 if not hvd_utils.is_using_hvd() else hvd.size()
        global_batch_size = batch_size * num_gpus

        if self.run_hparams.data_dir is not None:
            filenames, num_samples, num_steps, num_epochs, num_decay_steps = runner_utils.parse_tfrecords_dataset(
                data_dir=self.run_hparams.data_dir,
                mode="train",
                iter_unit=iter_unit,
                num_iter=num_iter,
                global_batch_size=global_batch_size,
            )

            steps_per_epoch = num_steps / num_epochs

        else:
            # Synthetic-data benchmark: a single "epoch" of num_iter steps.
            num_epochs = 1
            num_steps = num_iter
            steps_per_epoch = num_steps
            num_decay_steps = num_steps
            num_samples = num_steps * batch_size

        if self.run_hparams.data_idx_dir is not None:
            idx_filenames = runner_utils.parse_dali_idx_dataset(
                data_idx_dir=self.run_hparams.data_idx_dir,
                mode="train"
            )

        training_hooks = []

        if hvd.rank() == 0:
            LOGGER.log('Starting Model Training...')
            LOGGER.log("Training Epochs", num_epochs)
            LOGGER.log("Total Steps", num_steps)
            LOGGER.log("Steps per Epoch", steps_per_epoch)
            LOGGER.log("Decay Steps", num_decay_steps)
            LOGGER.log("Weight Decay Factor", weight_decay)
            LOGGER.log("Init Learning Rate", lr_init)
            LOGGER.log("Momentum", momentum)
            LOGGER.log("Num GPUs", num_gpus)
            LOGGER.log("Per-GPU Batch Size", batch_size)

            # Only rank 0 writes the JSON training/benchmark log.
            if is_benchmark:

                benchmark_logging_hook = hooks.BenchmarkLoggingHook(
                    log_file_path=os.path.join(self.run_hparams.log_dir, "training_benchmark.json"),
                    global_batch_size=global_batch_size,
                    log_every=log_every_n_steps,
                    warmup_steps=warmup_steps
                )

                training_hooks.append(benchmark_logging_hook)

            else:

                training_logging_hook = hooks.TrainingLoggingHook(
                    log_file_path=os.path.join(self.run_hparams.log_dir, "training.json"),
                    global_batch_size=global_batch_size,
                    num_steps=num_steps,
                    num_samples=num_samples,
                    num_epochs=num_epochs,
                    log_every=log_every_n_steps
                )

                training_hooks.append(training_logging_hook)

        if hvd_utils.is_using_hvd():
            # Sync initial weights from rank 0 to all other ranks.
            bcast_hook = hvd.BroadcastGlobalVariablesHook(0)
            training_hooks.append(bcast_hook)

        training_hooks.append(hooks.PrefillStagingAreasHook())

        # NVTX
        nvtx_callback = NVTXHook(skip_n_steps=1, name='Train')
        training_hooks.append(nvtx_callback)

        estimator_params = {
            'batch_size': batch_size,
            'steps_per_epoch': steps_per_epoch,
            'num_gpus': num_gpus,
            'momentum': momentum,
            'lr_init': lr_init,
            'lr_warmup_epochs': lr_warmup_epochs,
            'weight_decay': weight_decay,
            'loss_scale': loss_scale,
            'apply_loss_scaling': use_static_loss_scaling,
            'label_smoothing': label_smoothing,
            'num_decay_steps': num_decay_steps,
            'use_cosine_lr': use_cosine_lr
        }

        image_classifier = self._get_estimator(
            mode='train',
            run_params=estimator_params,
            use_xla=self.run_hparams.use_xla,
            use_dali=self.run_hparams.use_dali,
            gpu_memory_fraction=self.run_hparams.gpu_memory_fraction
        )

        def training_data_fn():
            """Select and build the input pipeline for training."""
            deterministic = self.run_hparams.seed is not None

            # DALI needs both the TFRecord filenames (bound only when data_dir
            # is set) and the index files; without data_dir (benchmark mode)
            # fall through to synthetic data instead of hitting an unbound name.
            if (self.run_hparams.use_dali
                    and self.run_hparams.data_idx_dir is not None
                    and self.run_hparams.data_dir is not None):
                if hvd.rank() == 0:
                    LOGGER.log("Using DALI input... ")

                return data_utils.get_dali_input_fn(
                    filenames=filenames,
                    idx_filenames=idx_filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=True,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=deterministic
                )

            elif self.run_hparams.data_dir is not None:

                return data_utils.get_tfrecords_input_fn(
                    filenames=filenames,
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    training=True,
                    distort_color=self.run_hparams.distort_colors,
                    num_threads=self.run_hparams.num_preprocessing_threads,
                    deterministic=deterministic
                )

            else:
                if hvd.rank() == 0:
                    LOGGER.log("Using Synthetic Data ...")
                return data_utils.get_synth_input_fn(
                    batch_size=batch_size,
                    height=self.run_hparams.height,
                    width=self.run_hparams.width,
                    num_channels=self.run_hparams.n_channels,
                    data_format=self.run_hparams.input_format,
                    num_classes=self.run_hparams.n_classes,
                    dtype=self.run_hparams.dtype,
                )

        try:
            image_classifier.train(
                input_fn=training_data_fn,
                steps=num_steps,
                hooks=training_hooks,
            )
        except KeyboardInterrupt:
            print("Keyboard interrupt")

        if hvd.rank() == 0:
            LOGGER.log('Ending Model Training ...')
Exemplo n.º 2
0
features_plh_2 = tf.compat.v1.placeholder('float', [None, 8])
labels_plh = tf.compat.v1.placeholder('float', [None, 1])


# NOTE(review): `features_plh_1` is expected to be defined earlier in this
# script (that line is not part of this fragment).
logits, probs = DenseBinaryClassificationNet(inputs=(features_plh_1, features_plh_2))

loss = tf.math.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels_plh))
# tf.compat.v1.metrics.accuracy returns an (accuracy, update_op) pair. Taking
# reduce_mean over that pair mixed a possibly stale read with the updated value
# in unspecified order. The update op's value is the running accuracy after
# including the current batch, so fetching it both updates and reports it.
_, acc = tf.compat.v1.metrics.accuracy(labels=labels_plh, predictions=tf.round(tf.nn.sigmoid(logits)))
# `optimizer` is the training op returned by minimize(), run once per batch.
optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9, use_nesterov=True).minimize(loss)


# Initialize variables. local variables are needed to be initialized for tf.metrics.*
init_g = tf.compat.v1.global_variables_initializer()
init_l = tf.compat.v1.local_variables_initializer()

nvtx_callback = NVTXHook(skip_n_steps=1, name='Train')

# Start training
with tf.compat.v1.train.MonitoredSession(hooks=[nvtx_callback]) as sess:
    sess.run([init_g, init_l])

    # Run graph
    for epoch in range(NUM_EPOCHS):
        for (x1, x2), y in batch_generator(features, labels, batch_size=128):
            optimizer_, loss_, acc_ = sess.run(
                [optimizer, loss, acc],
                feed_dict={
                    features_plh_1: x1,
                    features_plh_2: x2,
                    labels_plh: y
                }
def main(device, input_path_train, input_path_validation, dummy_data,
         downsampling_fact, downsampling_mode, channels, data_format, label_id,
         weights, image_dir, checkpoint_dir, trn_sz, val_sz, loss_type, model,
         decoder, fs_type, optimizer, batch, batchnorm, num_epochs, dtype,
         disable_checkpoints, disable_imsave, tracing, trace_dir,
         output_sampling, scale_factor, intra_threads, inter_threads):
    """Build the DeepLab segmentation training graph and run profiled steps.

    The function (optionally) initializes Horovod, builds train/validation
    datasets (synthetic or HDF5-backed), optionally downsamples the inputs,
    creates the model, loss, optimizer and IoU metric, and then inside a
    ``tf.train.MonitoredTrainingSession`` runs 5 warmup training steps
    followed by a single profiled step.

    NOTE(review): the profiled step fetches only the loss tensor (``tmpl``),
    not ``train_op``, so it does not apply a gradient update.

    Relies on module-level names defined elsewhere in the file, including
    ``horovod`` (used as a boolean flag), ``image_height_orig`` /
    ``image_width_orig``, and ``have_pycuda`` / ``pyc``.

    Raises:
        ValueError: If ``downsampling_mode`` or ``loss_type`` is unsupported.
    """
    # init horovod
    comm_rank = 0
    comm_local_rank = 0
    comm_size = 1
    comm_local_size = 1
    if horovod:
        hvd.init()
        comm_rank = hvd.rank()
        comm_local_rank = hvd.local_rank()
        comm_size = hvd.size()
        # not all horovod versions have local_size() implemented; fall back to 1
        try:
            comm_local_size = hvd.local_size()
        except Exception:
            comm_local_size = 1
        if comm_rank == 0:
            print("Using distributed computation with Horovod: {} total ranks".
                  format(comm_size))

    # downsampling? recompute image dims
    image_height = image_height_orig // downsampling_fact
    image_width = image_width_orig // downsampling_fact

    # parameters
    per_rank_output = False

    # session config: one visible GPU per local rank
    sess_config = tf.ConfigProto(
        inter_op_parallelism_threads=inter_threads,  #6
        intra_op_parallelism_threads=intra_threads,  #1
        log_device_placement=False,
        allow_soft_placement=True)
    sess_config.gpu_options.visible_device_list = str(comm_local_rank)
    sess_config.gpu_options.force_gpu_compatible = True

    # get data
    training_graph = tf.Graph()
    if comm_rank == 0:
        print("Loading data...")
    train_files = load_data(input_path_train, True, trn_sz, horovod)
    valid_files = load_data(input_path_validation, False, val_sz, horovod)

    # print some stats
    if comm_rank == 0:
        print("Num workers: {}".format(comm_size))
        print("Local batch size: {}".format(batch))
        if dtype == tf.float32:
            print("Precision: {}".format("FP32"))
        else:
            print("Precision: {}".format("FP16"))
        print("Decoder: {}".format(decoder))
        print("Batch normalization: {}".format(batchnorm))
        print("Channels: {}".format(channels))
        print("Loss type: {}".format(loss_type))
        print("Loss weights: {}".format(weights))
        print("Loss scale factor: {}".format(scale_factor))
        print("Output sampling target: {}".format(output_sampling))
        # print optimizer parameters
        for k, v in optimizer.items():
            print("Solver Parameters: {k}: {v}".format(k=k, v=v))
        print("Num training samples: {}".format(train_files.shape[0]))
        print("Num validation samples: {}".format(valid_files.shape[0]))
        if dummy_data:
            print("Using synthetic dummy data")
        print("Disable checkpoints: {}".format(disable_checkpoints))
        print("Disable image save: {}".format(disable_imsave))

    # compute epochs and stuff: samples are split across local ranks on a
    # node-local filesystem, across all ranks otherwise
    if fs_type == "local":
        num_samples = train_files.shape[0] // comm_local_size
    else:
        num_samples = train_files.shape[0] // comm_size
    print("num_samples: {} batch: {}".format(num_samples, batch))
    num_steps_per_epoch = num_samples // batch
    num_steps = num_epochs * num_steps_per_epoch
    if comm_rank == 0:
        print("Number of steps per epoch: {}".format(num_steps_per_epoch))
        print("Number of steps in total: {}".format(num_steps))
    if per_rank_output:
        print("Rank {} does {} steps per epoch".format(comm_rank,
                                                       num_steps_per_epoch))

    with training_graph.as_default():

        if dummy_data:
            dummy_data_args = dict(batchsize=batch,
                                   data_format=data_format,
                                   dtype=dtype)
            trn_dataset = create_dummy_dataset(n_samples=trn_sz,
                                               num_epochs=num_epochs,
                                               **dummy_data_args)
            val_dataset = create_dummy_dataset(n_samples=val_sz,
                                               num_epochs=1,
                                               **dummy_data_args)
        else:
            # create readers
            trn_reader = h5_input_reader(input_path_train,
                                         channels,
                                         weights,
                                         dtype,
                                         normalization_file="stats.h5",
                                         update_on_read=False,
                                         data_format=data_format,
                                         label_id=label_id,
                                         sample_target=output_sampling)
            val_reader = h5_input_reader(input_path_validation,
                                         channels,
                                         weights,
                                         dtype,
                                         normalization_file="stats.h5",
                                         update_on_read=False,
                                         data_format=data_format,
                                         label_id=label_id)
            # create datasets (sharded by local rank or by global rank)
            if fs_type == "local":
                trn_dataset = create_dataset(trn_reader,
                                             train_files,
                                             batch,
                                             num_epochs,
                                             comm_local_size,
                                             comm_local_rank,
                                             dtype,
                                             shuffle=True)
                val_dataset = create_dataset(val_reader,
                                             valid_files,
                                             batch,
                                             1,
                                             comm_local_size,
                                             comm_local_rank,
                                             dtype,
                                             shuffle=False)
            else:
                trn_dataset = create_dataset(trn_reader,
                                             train_files,
                                             batch,
                                             num_epochs,
                                             comm_size,
                                             comm_rank,
                                             dtype,
                                             shuffle=True)
                val_dataset = create_dataset(val_reader,
                                             valid_files,
                                             batch,
                                             1,
                                             comm_size,
                                             comm_rank,
                                             dtype,
                                             shuffle=False)

        # create a feedable iterator; elements are (data, label, weight, filename)
        handle = tf.placeholder(tf.string,
                                shape=[],
                                name="iterator-placeholder")
        iterator = tf.data.Iterator.from_string_handle(
            handle, (dtype, tf.int32, dtype, tf.string),
            ((batch, len(channels), image_height_orig,
              image_width_orig) if data_format == "channels_first" else
             (batch, image_height_orig, image_width_orig, len(channels)),
             (batch, image_height_orig, image_width_orig),
             (batch, image_height_orig, image_width_orig), (batch)))
        next_elem = iterator.get_next()

        # if downsampling, do some preprocessing
        if downsampling_fact != 1:
            if downsampling_mode == "scale":
                # do downsampling: average-pool data and weights; for labels pick
                # one random pixel out of each downsampling_fact^2 patch
                rand_select = tf.cast(tf.one_hot(tf.random_uniform(
                    (batch, image_height, image_width),
                    minval=0,
                    maxval=downsampling_fact * downsampling_fact,
                    dtype=tf.int32),
                                                 depth=downsampling_fact *
                                                 downsampling_fact,
                                                 axis=-1),
                                      dtype=tf.int32)
                next_elem = (tf.layers.average_pooling2d(next_elem[0], downsampling_fact, downsampling_fact, 'valid', data_format), \
                             tf.reduce_max(tf.multiply(tf.image.extract_image_patches(tf.expand_dims(next_elem[1], axis=-1), \
                                                                                 [1, downsampling_fact, downsampling_fact, 1], \
                                                                                 [1, downsampling_fact, downsampling_fact, 1], \
                                                                                 [1,1,1,1], 'VALID'), rand_select), axis=-1), \
                             tf.squeeze(tf.layers.average_pooling2d(tf.expand_dims(next_elem[2], axis=-1), downsampling_fact, downsampling_fact, 'valid', "channels_last"), axis=-1), \
                             next_elem[3])
            elif downsampling_mode == "center-crop":
                # some parameters
                length = 1. / float(downsampling_fact)
                offset = length / 2.
                boxes = [[offset, offset, offset + length, offset + length]
                         ] * batch
                box_ind = list(range(0, batch))
                crop_size = [image_height, image_width]

                # be careful with data order: crop_and_resize works on
                # channels-last tensors. next_elem is a tuple, so transpose via
                # a local instead of assigning to next_elem[0].
                data = next_elem[0]
                if data_format == "channels_first":
                    data = tf.transpose(data, perm=[0, 2, 3, 1])

                # crop
                data = tf.image.crop_and_resize(data, boxes, box_ind, crop_size, method='bilinear', extrapolation_value=0, name="data_cropping")

                # restore the original data order
                if data_format == "channels_first":
                    data = tf.transpose(data, perm=[0, 3, 1, 2])

                next_elem = (data, \
                             ensure_type(tf.squeeze(tf.image.crop_and_resize(tf.expand_dims(next_elem[1],axis=-1), boxes, box_ind, crop_size, method='nearest', extrapolation_value=0, name="label_cropping"), axis=-1), tf.int32), \
                             tf.squeeze(tf.image.crop_and_resize(tf.expand_dims(next_elem[2],axis=-1), boxes, box_ind, crop_size, method='bilinear', extrapolation_value=0, name="weight_cropping"), axis=-1), \
                             next_elem[3])

            else:
                raise ValueError(
                    "Error, downsampling mode {} not supported. Supported are [center-crop, scale]"
                    .format(downsampling_mode))

        # create init handles
        # trn
        trn_iterator = trn_dataset.make_initializable_iterator()
        trn_handle_string = trn_iterator.string_handle()
        trn_init_op = iterator.make_initializer(trn_dataset)
        # val
        val_iterator = val_dataset.make_initializable_iterator()
        val_handle_string = val_iterator.string_handle()
        val_init_op = iterator.make_initializer(val_dataset)

        # set up model (note: rebinds the `model` argument, which named the
        # base architecture, to the model callable)
        model = deeplab_v3_plus_generator(num_classes=3,
                                          output_stride=8,
                                          base_architecture=model,
                                          decoder=decoder,
                                          batchnorm=batchnorm,
                                          pre_trained_model=None,
                                          batch_norm_decay=None,
                                          data_format=data_format)

        logit, prediction = model(next_elem[0], True, dtype)

        # set up loss
        loss = None

        # cast the logits to fp32
        logit = ensure_type(logit, tf.float32)
        if loss_type == "weighted":
            # cast weights to FP32
            w_cast = ensure_type(next_elem[2], tf.float32)
            loss = tf.losses.sparse_softmax_cross_entropy(
                labels=next_elem[1],
                logits=logit,
                weights=w_cast,
                reduction=tf.losses.Reduction.SUM)
            if scale_factor != 1.0:
                loss *= scale_factor

        elif loss_type == "weighted_mean":
            # cast weights to FP32
            w_cast = ensure_type(next_elem[2], tf.float32)
            loss = tf.losses.sparse_softmax_cross_entropy(
                labels=next_elem[1],
                logits=logit,
                weights=w_cast,
                reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)
            if scale_factor != 1.0:
                loss *= scale_factor

        elif loss_type == "focal":
            # one-hot-encode
            labels_one_hot = tf.contrib.layers.one_hot_encoding(
                next_elem[1], 3)
            # cast to FP32
            labels_one_hot = ensure_type(labels_one_hot, tf.float32)
            loss = focal_loss(onehot_labels=labels_one_hot,
                              logits=logit,
                              alpha=1.,
                              gamma=2.)

        else:
            raise ValueError("Error, loss type {} not supported.".format(
                loss_type))

        # determine flops (per rank, then scaled to all ranks)
        flops = graph_flops.graph_flops(
            format="NHWC" if data_format == "channels_last" else "NCHW",
            verbose=False,
            batch=batch,
            sess_config=sess_config)
        flops *= comm_size
        if comm_rank == 0:
            print('training flops: {:.3f} TF/step'.format(flops * 1e-12))

        # number of trainable parameters
        if comm_rank == 0:
            num_params = get_number_of_trainable_parameters()
            print('number of trainable parameters: {} ({} MB)'.format(
                num_params,
                num_params * (4 if dtype == tf.float32 else 2) * (2**-20)))

        if horovod:
            loss_avg = hvd.allreduce(ensure_type(loss, tf.float32))
        else:
            loss_avg = tf.identity(loss)
        tmpl = (loss if per_rank_output else loss_avg)

        # set up global step - keep on CPU
        with tf.device('/device:CPU:0'):
            global_step = tf.train.get_or_create_global_step()

        # set up optimizer
        if optimizer['opt_type'].startswith("LARC"):
            if comm_rank == 0:
                print("Enabling LARC")
            train_op, lr = get_larc_optimizer(optimizer, loss, global_step,
                                              num_steps_per_epoch, horovod)
        else:
            train_op, lr = get_optimizer(optimizer, loss, global_step,
                                         num_steps_per_epoch, horovod)

        # set up streaming metrics
        iou_op, iou_update_op = tf.metrics.mean_iou(labels=next_elem[1],
                                                    predictions=tf.argmax(
                                                        prediction, axis=3),
                                                    num_classes=3,
                                                    weights=None,
                                                    metrics_collections=None,
                                                    updates_collections=None,
                                                    name="iou_score")
        iou_reset_op = tf.variables_initializer([
            i for i in tf.local_variables() if i.name.startswith('iou_score/')
        ])

        if horovod:
            iou_avg = hvd.allreduce(iou_op)
        else:
            iou_avg = tf.identity(iou_op)

        # NOTE(review): these memory-stat ops are built but never fetched below.
        if "gpu" in device.lower():
            with tf.device(device):
                mem_usage_ops = [
                    tf.contrib.memory_stats.MaxBytesInUse(),
                    tf.contrib.memory_stats.BytesLimit()
                ]
        # hooks
        # these hooks are essential. regularize the step hook by adding one additional step at the end
        hooks = [tf.train.StopAtStepHook(last_step=num_steps + 1)]
        nvtx_callback = NVTXHook(skip_n_steps=0, name='TTTTTrain')
        hooks.append(nvtx_callback)
        # bcast init for bcasting the model after start
        if horovod:
            init_bcast = hvd.broadcast_global_variables(0)
        # initializers:
        init_op = tf.global_variables_initializer()
        init_local_op = tf.local_variables_initializer()

        # checkpointing (rank 0 only)
        if comm_rank == 0:
            checkpoint_save_freq = 5 * num_steps_per_epoch
            checkpoint_saver = tf.train.Saver(max_to_keep=1000)
            if (not disable_checkpoints):
                hooks.append(
                    tf.train.CheckpointSaverHook(
                        checkpoint_dir=checkpoint_dir,
                        save_steps=checkpoint_save_freq,
                        saver=checkpoint_saver))
            # create image dir if not exists
            if not os.path.isdir(image_dir):
                os.makedirs(image_dir)

        # tracing
        if tracing is not None:
            import tracehook
            tracing_hook = tracehook.TraceHook(steps_to_trace=tracing,
                                               cache_traces=True,
                                               trace_dir=trace_dir)
            hooks.append(tracing_hook)
            print("############ tracing enabled")

        # start session
        with tf.train.MonitoredTrainingSession(config=sess_config,
                                               hooks=hooks) as sess:
            # initialize
            sess.run([init_op, init_local_op])
            # restore from checkpoint:
            if comm_rank == 0 and not disable_checkpoints:
                load_model(sess, checkpoint_saver, checkpoint_dir)
            # broadcast loaded model variables
            if horovod:
                sess.run(init_bcast)
            # create iterator handles
            trn_handle, val_handle = sess.run(
                [trn_handle_string, val_handle_string])
            # init iterators
            sess.run(trn_init_op, feed_dict={handle: trn_handle})
            sess.run(val_init_op, feed_dict={handle: val_handle})

            # figure out what step we're on (it won't be 0 if we are
            #  restoring from a checkpoint) so we can count from there
            train_steps = sess.run([global_step])[0]

            # warmup loops
            print("### Warmup for 5 steps")
            start_time = time.time()
            for _ in range(5):
                print('warmup train_steps is {}'.format(train_steps))
                if train_steps == 5:
                    print(train_steps)
                _ = sess.run([train_op], feed_dict={handle: trn_handle})

                if train_steps == 5:
                    print(train_steps)

                train_steps += 1

            end_time = time.time()
            print("### Warmup time: {:0.2f}".format(end_time - start_time))

            # Start profiling: the CUDA profiler is toggled around step 5
            # when pycuda is available.
            print('Begin training loop')
            for _ in range(1):
                try:
                    print('train_steps is {}'.format(train_steps))
                    if train_steps == 5:
                        if have_pycuda:
                            pyc.driver.start_profiler()
                        print(train_steps)
                    _ = sess.run([tmpl], feed_dict={handle: trn_handle})
                    if train_steps == 5:
                        if have_pycuda:
                            pyc.driver.stop_profiler()
                        print(train_steps)
                    train_steps += 1
                except tf.errors.OutOfRangeError:
                    break

        # write any cached traces to disk
        if tracing is not None:
            tracing_hook.write_traces()

    print('All done')
Exemplo n.º 4
0
    def train(self,
              outputs_path_rooth,
              display_step=500,
              restore=False,
              model_to_load_path=None,
              stop_after_first_epoch=True,
              **kwargs):
        """Run the training loop inside an NVTX-instrumented MonitoredSession.

        The ``predictions/train``, ``predictions/valid``, ``logs`` and
        ``model_ckpts_temp/`` directories under *outputs_path_rooth* are deleted
        and recreated on every call.

        Args:
            outputs_path_rooth: Root directory for predictions, logs and
                checkpoints. Created if missing.
            display_step: Stored on ``self.display_step``.
            restore: If true, restore weights through ``self.net.restore`` from
                the checkpoint found in *model_to_load_path* (if any).
            model_to_load_path: Directory passed to
                ``tf.train.get_checkpoint_state`` when *restore* is true.
            stop_after_first_epoch: When true (the original behaviour of this
                profiling variant), print the first epoch's timing and then end
                the process with ``sys.exit()``. When false, save a checkpoint,
                collect summaries and run validation after every epoch, and
                continue for ``self.data_provider.epochs`` epochs.
            **kwargs: Accepted and ignored.
        """
        import shutil
        import sys

        def reset_dir(path):
            """Delete *path* if it exists (like ``rm -r``) and recreate it as a directory."""
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            elif os.path.lexists(path):
                os.remove(path)
            os.makedirs(path)

        if not os.path.isdir(outputs_path_rooth):
            os.makedirs(outputs_path_rooth)

        prediction_path = "predictions/"
        self.prediction_path_train = os.path.join(outputs_path_rooth,
                                                  prediction_path, "train")
        reset_dir(self.prediction_path_train)
        self.prediction_path_valid = os.path.join(outputs_path_rooth,
                                                  prediction_path, "valid")
        reset_dir(self.prediction_path_valid)

        logs_path = os.path.join(outputs_path_rooth, "logs")
        reset_dir(logs_path)

        model_path = os.path.join(outputs_path_rooth, "model_ckpts_temp/")
        reset_dir(model_path)
        self.model_path = model_path

        self._initialize()
        self.display_step = display_step

        total_epochs = self.data_provider.epochs
        print("Starting optimization:{}-total epochs".format(total_epochs))

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.loop = 0

        initialize_variables = tf.global_variables_initializer()
        self._get_summaries()
        self.saver = tf.train.Saver(max_to_keep=10)

        nvtx_callback = NVTXHook(skip_n_steps=1, name='Train')
        # The session creator carries `config`; previously it was built but
        # never handed to the session, so allow_growth had no effect.
        session_creator = tf.train.ChiefSessionCreator(config=config)
        with tf.train.MonitoredSession(session_creator=session_creator,
                                       hooks=[nvtx_callback]) as sess:
            self.summary_writer = tf.summary.FileWriter(logs_path,
                                                        graph=sess.graph)

            sess.run(initialize_variables)

            sess.run(self.init_iter)

            if restore:
                ckpt = tf.train.get_checkpoint_state(model_to_load_path)
                if ckpt and ckpt.model_checkpoint_path:
                    self.net.restore(ckpt.model_checkpoint_path, sess)

            print("Start optimization")
            global_epoch = 0
            global_iter = 0

            for _ in range(total_epochs):
                self.loss_train_left = []
                self.loss_train_right = []
                self.f1_train = []
                epoch_start = time.time()
                i_c = 0
                times = []
                # Iterate until the input iterator is exhausted.
                while True:
                    iter_start = time.time()
                    try:
                        global_iter = self.train_funct(sess, global_iter,
                                                       global_epoch + 1)
                    except tf.errors.OutOfRangeError:
                        break
                    times.append(time.time() - iter_start)
                    # report the running mean iteration time every 100 iterations
                    if i_c % 100 == 0:
                        print(np.mean(times))
                    i_c += 1

                t_epoch = time.time() - epoch_start
                print("epoch run in {:.2f} ({:.5f}s/iter)".format(
                    t_epoch, t_epoch / max(i_c, 1)))
                if stop_after_first_epoch:
                    sys.exit()
                global_epoch += 1
                self.net.save(self.saver, sess, model_path, global_epoch)
                self.get_train_summaries_and_reinitialize(global_epoch)
                self.run_validation(sess, global_epoch)
                sess.run(self.init_iter)

        print("\n\nOptimization Finished!\n\n")