Example #1
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default() as g:
        K.set_learning_phase(True)
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)
        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return None

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################

        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Compute the total loss and the gradients for every clone; these feed
        # the train_tensor and summary_op built below.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        if FLAGS.save_max_keep != 5:
            saver = tf.train.Saver(var_list=tf.global_variables(),
                                   max_to_keep=FLAGS.save_max_keep)
        else:
            saver = None

        # Set train step kwargs
        with tf.name_scope('train_step'):
            train_step_kwargs = {}
            if FLAGS.max_number_of_steps:
                should_stop_op = tf.greater_equal(global_step,
                                                  FLAGS.max_number_of_steps)
            else:
                should_stop_op = tf.constant(False)
            train_step_kwargs['should_stop'] = should_stop_op
            if FLAGS.log_every_n_steps > 0:
                train_step_kwargs['should_log'] = tf.equal(
                    tf.mod(global_step, FLAGS.log_every_n_steps), 0)

            # GDP Taylor Value
            gdp_mask_taylor = []
            var_ops = tf.get_collection('GDP_VAR')
            for var in var_ops:
                weight = var[0]
                bias = None
                if len(var) > 2:
                    bias = var[1]

                # get corresponding gradient ops
                weight_gradient = None
                for go in clones_gradients:
                    if go[1].name == weight.name:
                        weight_gradient = go[0]
                        break

                if bias is not None:
                    bias_gradient = None
                    for go in clones_gradients:
                        if go[1].name == bias.name:
                            bias_gradient = go[0]
                            break

                # calculate the corresponding taylor value
                filter_num = int(weight.shape[-1])
                if bias is not None:
                    taylor_value = tf.abs(
                        tf.reduce_sum(weight *
                                      weight_gradient, axis=[0, 1, 2]) +
                        bias * bias_gradient)
                else:
                    taylor_value = tf.abs(
                        tf.reduce_sum(weight * weight_gradient, axis=[0, 1,
                                                                      2]))

                # pair the layer's mask variable with its Taylor value
                mask = var[-1]

                gdp_mask_taylor.append((mask, taylor_value))

            train_step_kwargs['gdp_mask_taylor'] = gdp_mask_taylor

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            train_step_fn=train_step,
            train_step_kwargs=train_step_kwargs,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            saver=saver,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
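Example #1 passes train_step_fn=train_step and a gdp_mask_taylor list of (mask, taylor_value) pairs to slim.learning.train, but the custom train_step itself is not part of the listing. Below is a minimal sketch of such a step function; only the (sess, train_op, global_step, train_step_kwargs) contract comes from slim.learning.train, and the mask-update rule is an assumption.

import tensorflow as tf

def train_step(sess, train_op, global_step, train_step_kwargs):
    """Hypothetical train step: one optimizer update, then refresh the GDP masks."""
    total_loss, np_global_step = sess.run([train_op, global_step])

    # Evaluate the first-order Taylor importance of every pruned layer. How the
    # scores are turned into a new 0/1 mask is not shown in the original
    # example, so only the evaluation is sketched here.
    for mask_var, taylor_value in train_step_kwargs.get('gdp_mask_taylor', []):
        scores = sess.run(taylor_value)  # per-filter importance, shape [num_filters]
        # e.g. keep the top-k filters and write the resulting mask back:
        # sess.run(mask_var.assign(new_mask))

    if 'should_log' in train_step_kwargs:
        if sess.run(train_step_kwargs['should_log']):
            tf.logging.info('global step %d: loss = %.4f',
                            np_global_step, total_loss)

    should_stop = False
    if 'should_stop' in train_step_kwargs:
        should_stop = sess.run(train_step_kwargs['should_stop'])
    return total_loss, should_stop
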
Example #2

def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
    """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    graph_hook_fn: Optional function that is called after the inference graph is
      built (before optimization). This is helpful to perform additional changes
      to the training graph such as adding FakeQuant ops. The function should
      modify the default graph.

  Raises:
    ValueError: If both num_clones > 1 and train_config.sync_replicas is true.
  """

    detection_model = create_model_fn()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        if num_clones != 1 and train_config.sync_replicas:
            raise ValueError('In Synchronous SGD mode num_clones must be 1. '
                             'Found num_clones: {}'.format(num_clones))
        batch_size = train_config.batch_size // num_clones
        if train_config.sync_replicas:
            batch_size //= train_config.replicas_to_aggregate

        with tf.device(deploy_config.inputs_device()):
            input_queue = create_input_queue(
                batch_size, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        if graph_hook_fn:
            with tf.device(deploy_config.variables_device()):
                graph_hook_fn()

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var, family='LearningRate')

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=worker_replicas)
            sync_optimizer = training_optimizer

        with tf.device(deploy_config.optimizer_device()):
            regularization_losses = (
                None if train_config.add_regularization_loss else [])
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones,
                training_optimizer,
                regularization_losses=regularization_losses)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram('ModelVars/' + model_var.op.name,
                                     model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                  loss_tensor))
        global_summaries.add(
            tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        # Optionally cap per-process GPU memory usage, e.g.:
        # session_config.gpu_options.per_process_gpu_memory_fraction = 0.4

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            if not train_config.fine_tune_checkpoint_type:
                # train_config.from_detection_checkpoint field is deprecated. For
                # backward compatibility, fine_tune_checkpoint_type is set based on
                # from_detection_checkpoint.
                if train_config.from_detection_checkpoint:
                    train_config.fine_tune_checkpoint_type = 'detection'
                else:
                    train_config.fine_tune_checkpoint_type = 'classification'
            var_map = detection_model.restore_map(
                fine_tune_checkpoint_type=train_config.
                fine_tune_checkpoint_type,
                load_all_detection_checkpoint_vars=(
                    train_config.load_all_detection_checkpoint_vars))
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map,
                    train_config.fine_tune_checkpoint,
                    include_global_step=False))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        total_loss = slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
        return total_loss
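For reference, a call to the train() function above could be wired roughly as follows. The builder helpers and the parsed config objects are assumptions about the surrounding object_detection setup and are not reproduced from this example.

import functools

# Hypothetical wiring: model_builder, input_reader_builder and the *_config
# objects are assumed to come from the object_detection config tooling.
create_model_fn = functools.partial(model_builder.build,
                                    model_config=model_config,
                                    is_training=True)
create_tensor_dict_fn = functools.partial(input_reader_builder.build,
                                          input_config)

train(create_tensor_dict_fn=create_tensor_dict_fn,
      create_model_fn=create_model_fn,
      train_config=train_config,
      master='',
      task=0,
      num_clones=1,
      worker_replicas=1,
      clone_on_cpu=False,
      ps_tasks=0,
      worker_job_name='worker',
      is_chief=True,
      train_dir='/tmp/detection_train')
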
Example #3
def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir):
    """Training function for detection models.

    Args:
      create_tensor_dict_fn: a function to create a tensor input dictionary.
      create_model_fn: a function that creates a DetectionModel and generates
                       losses.
      train_config: a train_pb2.TrainConfig protobuf.
      master: BNS name of the TensorFlow master to use.
      task: The task id of this training instance.
      num_clones: The number of clones to run per machine.
      worker_replicas: The number of work replicas to train with.
      clone_on_cpu: True if clones should be forced to run on CPU.
      ps_tasks: Number of parameter server tasks.
      worker_job_name: Name of the worker job.
      is_chief: Whether this replica is the chief replica.
      train_dir: Directory to write checkpoints and training summaries to.
    """

    detection_model = create_model_fn()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            input_queue = create_input_queue(
                train_config.batch_size // num_clones, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer, global_summaries)

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            var_map = detection_model.restore_map(
                from_detection_checkpoint=train_config.
                from_detection_checkpoint)
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, train_config.fine_tune_checkpoint))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        with tf.device(deploy_config.optimizer_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer, regularization_losses=None)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            max_to_keep=50,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
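Both detection examples rely on variables_helper.multiply_gradients_matching_regex and freeze_gradients_matching_regex, which are not reproduced in the listing. The sketch below only illustrates the underlying idea of regex-matching variable names in the (gradient, variable) list; it is not the actual object_detection implementation and assumes dense gradient tensors.

import re

def _matches_any(name, regex_list):
    """True if the variable name matches any of the given patterns."""
    return any(re.match(pattern, name) for pattern in regex_list)

def freeze_gradients_matching_regex_sketch(grads_and_vars, regex_list):
    """Drops (grad, var) pairs whose variable matches, so those layers stay fixed."""
    return [(grad, var) for grad, var in grads_and_vars
            if not _matches_any(var.op.name, regex_list)]

def multiply_gradients_matching_regex_sketch(grads_and_vars, regex_list, multiplier):
    """Scales the gradients of matching variables (e.g. biases) by a constant."""
    return [(grad * multiplier if _matches_any(var.op.name, regex_list) else grad, var)
            for grad, var in grads_and_vars]
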
Example #4

def main(_):
    ###add for pruning
    if FLAGS.model_name == "vgg":
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=0.9)  #add by lzlu
        sessGPU = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    else:
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=0.3)  #add by lzlu
        sessGPU = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    print("FLAGS.model_name:", FLAGS.model_name)
    #config = tf.ConfigProto()
    #config.gpu_options.allow_growth=True
    #sessGPU = tf.Session(config=config)
    #sessGPU = tf.Session(config=tf.ConfigProto(log_device_placement=True))
    #sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
    print("FLAGS.max_number_of_steps:", FLAGS.max_number_of_steps)
    print("FLAGS.learning_rate:", FLAGS.learning_rate)
    print("FLAGS.weight_decay:", FLAGS.weight_decay)
    print("FLAGS.batch_size:", FLAGS.batch_size)
    print("FLAGS.trainable_scopes:", FLAGS.trainable_scopes)
    print("FLAGS.pruning_rates:", FLAGS.pruning_rates)
    print("FLAGS.train_dir:", FLAGS.train_dir)
    print("FLAGS.checkpoint_path:", FLAGS.checkpoint_path)
    print("FLAGS.pruning_gradient_update_ratio:",
          FLAGS.pruning_gradient_update_ratio)
    ###
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)
        print("deploy_config.variables_device():")
        print(deploy_config.variables_device())
        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            with tf.device(deploy_config.inputs_device()):
                images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                tf.losses.softmax_cross_entropy(
                    logits=end_points['AuxLogits'],
                    onehot_labels=labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            tf.losses.softmax_cross_entropy(
                logits=logits,
                onehot_labels=labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))
            ##add for pruning
            summaries.add(
                tf.summary.scalar('pruning_rate/' + variable.op.name,
                                  1 - tf.nn.zero_fraction(variable)))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        print("deploy_config.optimizer_device():")
        print(deploy_config.optimizer_device())
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Compute the total loss and the gradients for every clone; these feed
        # the train_tensor and summary_op built below.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        ###add by lzlu
        variables = tf.model_variables()
        slim.model_analyzer.analyze_vars(variables, print_info=True)
        ##print("variables_to_train:",variables_to_train)
        ##print("clones_gradients_before_pruning:",clones_gradients)
        variables_to_pruning = get_variables_to_pruning()
        pruningMask = get_pruning_mask(variables_to_pruning)
        ##print("pruningMask__grad:",pruningMask)
        ##print("My_variables_to_pruning__grad:",variables_to_pruning)
        clones_gradients = apply_pruning_to_grad(clones_gradients, pruningMask)
        ##print("clones_gradients_after_pruning:",clones_gradients)
        ##print("slim.get_model_variables():",slim.get_model_variables())
        ###

        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)

        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ### add for pruning
        #######################
        # Config mySaver      #
        #######################
        class mySaver(tf.train.Saver):
            def restore(self, sess, save_path):
                ##print("mySaver--restore...!")
                tf.train.Saver.restore(self, sess, save_path)
                variables_to_pruning = get_variables_to_pruning()
                ##print("My_variables_to_pruning__restore:",variables_to_pruning)
                pruningMask = apply_pruning_to_var(variables_to_pruning, sess)
                ##print("mySaver--restore done!")
            def save(self,
                     sess,
                     save_path,
                     global_step=None,
                     latest_filename=None,
                     meta_graph_suffix="meta",
                     write_meta_graph=True,
                     write_state=True):
                ##print("My Saver--save...!")
                tf.train.Saver.save(self, sess, save_path, global_step,
                                    latest_filename, meta_graph_suffix,
                                    write_meta_graph, write_state)
                ##print("My Saver--save done!")

        saver = mySaver(max_to_keep=2)
        ###

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            saver=saver,  #add for pruning
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
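The pruning helpers used above (get_variables_to_pruning, get_pruning_mask, apply_pruning_to_grad, apply_pruning_to_var) are defined elsewhere in that script. The following is only a simplified sketch of magnitude-based masks and masked gradients; the real helpers may differ in how masks are stored and how the per-layer pruning rate is chosen.

import tensorflow as tf

def get_pruning_mask_sketch(variables_to_pruning, pruning_rate=0.5):
    """Builds a 0/1 keep-mask per variable by thresholding absolute weight values."""
    masks = {}
    for var in variables_to_pruning:
        flat = tf.reshape(tf.abs(var), [-1])
        keep = tf.maximum(
            tf.cast(tf.cast(tf.size(flat), tf.float32) * (1.0 - pruning_rate),
                    tf.int32), 1)
        threshold = tf.nn.top_k(flat, k=keep).values[-1]
        masks[var.op.name] = tf.cast(tf.abs(var) >= threshold, var.dtype)
    return masks

def apply_pruning_to_grad_sketch(clones_gradients, masks):
    """Zeroes the gradients of pruned weights so they receive no further updates."""
    pruned = []
    for grad, var in clones_gradients:
        mask = masks.get(var.op.name)
        if grad is not None and mask is not None:
            grad = grad * mask  # assumes a dense gradient tensor
        pruned.append((grad, var))
    return pruned
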
Example #5
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            dataset = data_generator.Dataset(
                dataset_name=FLAGS.dataset,
                split_name=FLAGS.train_split,
                dataset_dir=FLAGS.dataset_dir,
                batch_size=clone_batch_size,
                crop_size=[int(sz) for sz in FLAGS.train_crop_size],
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                model_variant=FLAGS.model_variant,
                num_readers=4,
                is_training=True,
                should_shuffle=True,
                should_repeat=True)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (dataset.get_one_shot_iterator(), {
                common.OUTPUT_TYPE: dataset.num_of_classes
            }, dataset.ignore_label)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in tf.model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, semantic predictions
        if FLAGS.save_summaries_images:
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            # Scale up summary image pixel values for better visualization.
            pixel_scaling = max(1, 255 // dataset.num_of_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling,
                                    tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

            summary_predictions = tf.cast(predictions * pixel_scaling,
                                          tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                 summary_predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy,
                FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps,
                FLAGS.learning_power,
                FLAGS.slow_start_step,
                FLAGS.slow_start_learning_rate,
                decay_steps=FLAGS.decay_steps,
                end_learning_rate=FLAGS.end_learning_rate)

            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

            if FLAGS.optimizer == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                       FLAGS.momentum)
            elif FLAGS.optimizer == 'adam':
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=FLAGS.adam_learning_rate,
                    epsilon=FLAGS.adam_epsilon)
            else:
                raise ValueError('Unknown optimizer')

        if FLAGS.quantize_delay_step >= 0:
            if FLAGS.num_clones > 1:
                raise ValueError(
                    'Quantization doesn\'t support multi-clone yet.')
            contrib_quantize.create_training_graph(
                quant_delay=FLAGS.quantize_delay_step)

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Start the training.
        profile_dir = FLAGS.profile_logdir
        if profile_dir is not None:
            tf.gfile.MakeDirs(profile_dir)

        with contrib_tfprof.ProfileContext(enabled=profile_dir is not None,
                                           profile_dir=profile_dir):
            init_fn = None
            if FLAGS.tf_initial_checkpoint:
                init_fn = train_utils.get_model_init_fn(
                    FLAGS.train_logdir,
                    FLAGS.tf_initial_checkpoint,
                    FLAGS.initialize_last_layer,
                    last_layers,
                    ignore_missing_vars=True)

            slim.learning.train(train_tensor,
                                logdir=FLAGS.train_logdir,
                                log_every_n_steps=FLAGS.log_steps,
                                master=FLAGS.master,
                                number_of_steps=FLAGS.training_number_of_steps,
                                is_chief=(FLAGS.task == 0),
                                session_config=session_config,
                                startup_delay_steps=startup_delay_steps,
                                init_fn=init_fn,
                                summary_op=summary_op,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
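train_utils.get_model_learning_rate in Example #5 encapsulates DeepLab's learning-rate schedules. As a rough illustration only, a 'poly' policy with a slow-start phase could be written along these lines; the real helper supports several policies and extra options and may differ in detail.

import tensorflow as tf

def poly_learning_rate_sketch(base_learning_rate, global_step,
                              training_number_of_steps, learning_power,
                              slow_start_step, slow_start_learning_rate):
    """Polynomial decay, overridden by a small constant rate during warm-up."""
    poly_lr = tf.train.polynomial_decay(base_learning_rate,
                                        global_step,
                                        training_number_of_steps,
                                        end_learning_rate=0.0,
                                        power=learning_power)
    # Use the slow-start rate for the first `slow_start_step` steps.
    return tf.where(tf.less(tf.cast(global_step, tf.int64),
                            tf.cast(slow_start_step, tf.int64)),
                    tf.convert_to_tensor(slow_start_learning_rate, tf.float32),
                    poly_lr)
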
Example #6

def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.compat.v1.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)

            accuracy = slim.metrics.accuracy(
                tf.cast(tf.argmax(input=logits, axis=1), dtype=tf.int32),
                tf.cast(tf.argmax(input=labels, axis=1), dtype=tf.int32))
            tf.compat.v1.add_to_collection('accuracy', accuracy)
            end_points['train_accuracy'] = accuracy
            return end_points

        # Per-batch training accuracy is computed inside clone_fn and exposed via end_points.

        # Gather initial summaries.
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs

        for end_point in end_points:
            if 'accuracy' in end_point:
                continue
            x = end_points[end_point]
            summaries.add(
                tf.compat.v1.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.compat.v1.summary.scalar('sparsity/' + end_point,
                                            tf.nn.zero_fraction(x)))
        train_acc = end_points['train_accuracy']
        summaries.add(
            tf.compat.v1.summary.scalar('train_accuracy', train_acc))

        # Add summaries for losses.
        for loss in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOSSES,
                                                first_clone_scope):
            summaries.add(
                tf.compat.v1.summary.scalar('losses/%s' % loss.op.name, loss))

        # @philkuz
        # Add accuracy summaries
        # TODO add if statement for n iterations
        # images_val, labels_val= tf.train.batch(
        #     [image, label],
        #     batch_size=FLAGS.batch_size,
        #     num_threads=FLAGS.num_preprocessing_threads,
        #     capacity=5 * FLAGS.batch_size)

        # # labels_val = slim.one_hot_encoding(
        # #     labels_val, dataset.num_classes - FLAGS.labels_offset)
        # batch_queue_val = slim.prefetch_queue.prefetch_queue(
        #     [images_val, labels_val], capacity=2 * deploy_config.num_clones)
        # logits, end_points = network_fn(images, reuse=True)
        # # predictions = tf.nn.softmax(logits)
        # predictions = tf.to_int32(tf.argmax(logits,1))

        # logits_val, end_points_val = network_fn(images_val, reuse=True)
        # predictions_val = tf.to_int32(tf.argmax(logits_val,1))

        # labels_val = tf.squeeze(labels_val)
        # labels = tf.squeeze(labels)

        # names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        #       'train/accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        #       'val/accuracy': slim.metrics.streaming_accuracy(predictions_val, labels_val),
        # })
        # for metric_name, metric_value in names_to_values.items():
        #   op = tf.summary.scalar(metric_name, metric_value)
        #   # op = tf.Print(op, [metric_value], metric_name)
        #   summaries.add(op)

        # Add summaries for variables.
        # TODO: find a way to exclude some of these from the TensorBoard scalars view
        for variable in slim.get_model_variables():
            summaries.add(
                tf.compat.v1.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(
                tf.compat.v1.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.compat.v1.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # optimize_clones() sums the per-clone losses and computes gradients for the variables to train.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.compat.v1.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES,
                                        first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.compat.v1.summary.merge(list(summaries),
                                                name='summary_op')

        # @philkuz
        # Set max_number_of_steps from num_epochs when only an epoch count is given
        print('FLAGS.num_epochs', FLAGS.num_epochs)
        if FLAGS.num_epochs is not None and FLAGS.max_number_of_steps is None:
            FLAGS.max_number_of_steps = int(
                FLAGS.num_epochs * dataset.num_samples / FLAGS.batch_size)
            # FLAGS.max_number_of_steps = int(math.round(FLAGS.num_epochs / dataset.num_samples))

        # Set up the log directory
        # @philkuz: the train_dir setup
        if FLAGS.experiment_name is not None:
            experiment_dir = 'bs={},lr={},epochs={}/{}'.format(
                FLAGS.batch_size, FLAGS.learning_rate, FLAGS.num_epochs,
                FLAGS.experiment_name)
            print(experiment_dir)
            FLAGS.train_dir = os.path.join(FLAGS.train_dir, experiment_dir)
            print(FLAGS.train_dir)

        # @philkuz overriding train_step
        def train_step(sess, train_op, global_step, train_step_kwargs):
            """Function that takes a gradient step and specifies whether to stop.
      Args:
        sess: The current session.
        train_op: An `Operation` that evaluates the gradients and returns the
          total loss.
        global_step: A `Tensor` representing the global training step.
        train_step_kwargs: A dictionary of keyword arguments.
      Returns:
        The total loss and a boolean indicating whether or not to stop training.
      Raises:
        ValueError: if 'should_trace' is in `train_step_kwargs` but `logdir` is not.
      """
            start_time = time.time()

            trace_run_options = None
            run_metadata = None
            should_acc = True  # TODO(@philkuz): make this configurable instead of hard-coded
            if 'should_trace' in train_step_kwargs:
                if 'logdir' not in train_step_kwargs:
                    raise ValueError(
                        'logdir must be present in train_step_kwargs when '
                        'should_trace is present')
                if sess.run(train_step_kwargs['should_trace']):
                    trace_run_options = config_pb2.RunOptions(
                        trace_level=config_pb2.RunOptions.FULL_TRACE)
                    run_metadata = config_pb2.RunMetadata()
            if not should_acc:
                total_loss, np_global_step = sess.run(
                    [train_op, global_step],
                    options=trace_run_options,
                    run_metadata=run_metadata)
            else:
                total_loss, acc, np_global_step = sess.run(
                    [train_op, train_acc, global_step],
                    options=trace_run_options,
                    run_metadata=run_metadata)
            time_elapsed = time.time() - start_time

            if run_metadata is not None:
                tl = timeline.Timeline(run_metadata.step_stats)
                trace = tl.generate_chrome_trace_format()
                trace_filename = os.path.join(
                    train_step_kwargs['logdir'],
                    'tf_trace-%d.json' % np_global_step)
                tf.compat.v1.logging.info('Writing trace to %s',
                                          trace_filename)
                file_io.write_string_to_file(trace_filename, trace)
                if 'summary_writer' in train_step_kwargs:
                    train_step_kwargs['summary_writer'].add_run_metadata(
                        run_metadata, 'run_metadata-%d' % np_global_step)

            if 'should_log' in train_step_kwargs:
                if sess.run(train_step_kwargs['should_log']):
                    if not should_acc:
                        tf.compat.v1.logging.info(
                            'global step %d: loss = %.4f (%.3f sec/step)',
                            np_global_step, total_loss, time_elapsed)
                    else:
                        tf.compat.v1.logging.info(
                            'global step %d: loss = %.4f train_acc = %.4f (%.3f sec/step)',
                            np_global_step, total_loss, acc, time_elapsed)

            if 'should_stop' in train_step_kwargs:
                should_stop = sess.run(train_step_kwargs['should_stop'])
            else:
                should_stop = False

            return total_loss, should_stop

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            train_step_fn=train_step,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
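
# A minimal, self-contained sketch (not part of the example above) of the
# num_epochs -> max_number_of_steps conversion used there:
# steps = epochs * num_samples / batch_size. The helper name and the numbers are hypothetical.
def epochs_to_steps(num_epochs, num_samples, batch_size):
    """Number of optimizer steps needed to pass over the dataset num_epochs times."""
    return int(num_epochs * num_samples / batch_size)

# e.g. 10 epochs over 1,281,167 ImageNet images at batch size 32 -> 400,364 steps
assert epochs_to_steps(10, 1281167, 32) == 400364
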
Beispiel #7
0
    def init_batch(self):
        deploy_config = model_deploy.DeploymentConfig()

        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()
            self.global_step = global_step

        tfrecord_list = os.listdir(FLAGS.dataset_dir)
        tfrecord_list = [
            os.path.join(FLAGS.dataset_dir, name) for name in tfrecord_list
            if name.endswith('tfrecords')
        ]
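        # Feed every *.tfrecords file found under dataset_dir to a filename queue that
        # yields each file exactly once (num_epochs=1).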
        file_queue = tf.train.string_input_producer(tfrecord_list,
                                                    num_epochs=1)

        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(file_queue)

        features = tf.parse_single_example(
            serialized_example,
            features={
                'label': tf.FixedLenFeature([], tf.int64),
                'img': tf.FixedLenFeature([], tf.string),
                'img_height': tf.FixedLenFeature([], tf.int64),
                'img_width': tf.FixedLenFeature([], tf.int64),
                'cam': tf.FixedLenFeature([], tf.int64)
            })

        img = tf.decode_raw(features['img'], tf.uint8)
        img_height = tf.cast(features['img_height'], tf.int32)
        img_width = tf.cast(features['img_width'], tf.int32)
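        # Note: the parsed img_height/img_width are not used below; the image is reshaped
        # to the fixed FLAGS.origin_* dimensions instead.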
        img = tf.reshape(
            img,
            tf.stack([
                FLAGS.origin_height, FLAGS.origin_width, FLAGS.origin_channel
            ]))
        img = tf.image.convert_image_dtype(img, dtype=tf.float32)

        label = features['label']
        cam = features['cam']
        images, labels, cams = tf.train.batch([img, label, cam],
                                              batch_size=FLAGS.batch_size,
                                              capacity=3000,
                                              num_threads=4,
                                              allow_smaller_final_batch=True)
        # labels = tf.one_hot(labels, FLAGS.num_classes-FLAGS.labels_offset)

        #self.dataset = dataset
        self.deploy_config = deploy_config
        self.global_step = global_step

        self.images = images
        self.labels = labels
        self.cams = cams
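
# A minimal NumPy sketch (not part of the example above) of what decode_raw + reshape +
# convert_image_dtype do to one record: raw uint8 bytes become an H x W x C array scaled
# to [0, 1]. The sizes below are toy values, not FLAGS.origin_*.
import numpy as np

raw_bytes = (np.arange(4 * 3 * 3) % 256).astype(np.uint8).tobytes()
img = np.frombuffer(raw_bytes, dtype=np.uint8).reshape(4, 3, 3)   # decode_raw + reshape
img = img.astype(np.float32) / 255.0                              # convert_image_dtype(uint8 -> float32)
print(img.shape, img.max())  # (4, 3, 3) 0.137...
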
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=False)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=False)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size
            print(train_image_size)

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        images, labels = batch_queue.dequeue()
        print(images, labels)
        logits, end_points = network_fn(images)

        labels_to_class_names = dataset_utils.read_label_file(
            FLAGS.dataset_dir, filename='labels.txt')
        print(labels_to_class_names)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            images_np, labels_np = sess.run([images, labels])
            print(images_np.shape, labels_np.shape)

            for i in range(10):
                image_np, label_np = sess.run([images, labels])

                # plt.imshow(image_np[0, :, :, :])
                # plt.title('label name:' + str(label_np[0]))
                # plt.show()

                cv2.imshow(
                    'label name:',
                    cv2.cvtColor(image_np[0, :, :, :], cv2.COLOR_RGB2BGR))
                print(labels_to_class_names[np.argmax(label_np[0])])
                cv2.waitKey(0)

            coord.request_stop()
            coord.join(threads)
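
# A minimal sketch (toy mapping, not the dataset's labels.txt) of the lookup done at the end
# of the example above: a one-hot label row is mapped back to a class name via argmax.
import numpy as np

labels_to_class_names = {0: 'cat', 1: 'dog'}   # assumed toy mapping
one_hot_label = np.array([0.0, 1.0])
print(labels_to_class_names[int(np.argmax(one_hot_label))])  # -> dog
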
Beispiel #9
0
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        if FLAGS.quantize_delay >= 0:
            tf.contrib.quantize.create_training_graph(
                quant_delay=FLAGS.quantize_delay)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # optimize_clones() sums the per-clone losses and computes gradients for the variables to train.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################

        config = tf.ConfigProto(inter_op_parallelism_threads=0,
                                intra_op_parallelism_threads=0)
        config.gpu_options.allow_growth = FLAGS.allow_growth

        if FLAGS.swa_weight_dir:
            print('Implement SWA from existed ckpt files.')
            saver = tf.train.Saver(max_to_keep=30)

            ckpt_list = set()
            for fname in os.listdir(FLAGS.swa_weight_dir):
                if 'ckpt' in fname:
                    ckpt_list.add(
                        os.path.join(
                            FLAGS.swa_weight_dir,
                            fname.split('.')[0] + '.' + fname.split('.')[1]))
            ckpt_list = list(ckpt_list)

            weight_all_epoch = []
            weight_names = {}
            for var in tf.trainable_variables():
                weight_names[var.name] = var.name
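            # weight_names maps each variable name to itself so that sess.run(weight_names)
            # returns a {variable_name: value} snapshot of the current weights.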

            for ckpt_file in ckpt_list:
                print('Dealing with:', ckpt_file)
                with tf.Session(config=config) as sess:
                    variables_to_restore = slim.get_model_variables()
                    saver.restore(sess, ckpt_file)

                    weight_current = sess.run(weight_names)
                    weight_all_epoch.append(weight_current)

            with tf.Session(config=config) as sess:
                variables_to_restore = slim.get_variables_to_restore()
                saver.restore(sess, ckpt_list[-1])

                tf.logging.info('Computing SWA...')
                weight_avg = {}
                for weight_name in weight_names:
                    value_sum = 0
                    for weight_dict in weight_all_epoch:
                        value_sum += weight_dict[weight_name]
                    weight_avg[weight_name] = value_sum / len(weight_all_epoch)

                for var in slim.get_variables_to_restore():
                    base_name = var.name.split('/ExponentialMovingAverage')[0] + ':0'
                    if var.name in weight_avg:
                        sess.run(var.assign(weight_avg[var.name]))
                    elif base_name in weight_avg:
                        sess.run(var.assign(weight_avg[base_name]))

                print('Update BN mean/std...')
                bn_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                bn_update_op = tf.group(*bn_update_ops)
                # (disabled) Optionally zero out the BN moving statistics before re-estimating them:
                # for var in slim.get_model_variables():
                #     if 'moving_mean' in var.name or 'moving_variance' in var.name:
                #         sess.run(var.assign(tf.zeros_like(var)))

                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)

                num_step_per_epoch = int(1281167 / FLAGS.batch_size /
                                         FLAGS.num_clones)
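                # 1281167 is the number of images in the ImageNet-1k (ILSVRC2012) training set,
                # hard-coded here to derive the number of steps per epoch.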
                for batch_id in range(num_step_per_epoch * 10):
                    sess.run(bn_update_op)
                    if not batch_id % 100:
                        tf.logging.info('Update BN mean/std: global step: %d' %
                                        batch_id)
                    if not batch_id % int(0.3 * num_step_per_epoch):
                        print('Save swa model at BN step %d.' % batch_id)
                        saver.save(
                            sess, FLAGS.train_dir + '/model_swa.ckpt-' +
                            str(batch_id / num_step_per_epoch))

                coord.request_stop()
                coord.join(threads)

                print('Save final swa model.')
                saver.save(sess, FLAGS.train_dir + '/model_swa.ckpt')

        else:
            init_fn = _get_init_fn(global_step)

            with tf.Session(config=config) as sess:
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                saver = tf.train.Saver(max_to_keep=30)
                summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                                       sess.graph)

                if FLAGS.swa_delay is not None:
                    weight_all_epoch = []

                    weight_names = {}
                    for var in tf.trainable_variables():
                        weight_names[var.name] = var.name

                if init_fn:
                    init_fn(sess)

                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)

                num_step_per_epoch = int(1281167 / FLAGS.batch_size /
                                         FLAGS.num_clones)

                for epoch_id in range(FLAGS.max_epoch):
                    for batch_id in range(num_step_per_epoch):
                        loss, global_step_out = sess.run(
                            [train_tensor, global_step])
                        if not batch_id % 10:
                            tf.logging.info(
                                'Global step: %d, Training loss: %f,' %
                                (global_step_out, loss))
                        if not batch_id % 300:
                            summary_out = sess.run(summary_op)
                            summary_writer.add_summary(summary_out,
                                                       global_step_out)
                        if batch_id and not batch_id % 4000:
                            saver.save(sess,
                                       FLAGS.train_dir + '/model.ckpt',
                                       global_step=global_step_out)

                        if FLAGS.swa_delay is not None:
                            if FLAGS.swa_delay <= epoch_id and not batch_id % int(
                                    FLAGS.num_epoch_per_swa *
                                    num_step_per_epoch - 1) and batch_id:
                                weight_current = sess.run(weight_names)
                                weight_all_epoch.append(weight_current)

                                checkdir(FLAGS.train_dir + '/swa_weight')
                                tf.logging.info(
                                    'Save checkpoint at global step %d' %
                                    global_step_out)
                                saver.save(sess,
                                           FLAGS.train_dir +
                                           '/swa_weight/model.ckpt',
                                           global_step=global_step_out)

                if FLAGS.swa_delay is not None:
                    tf.logging.info('Computing SWA...')
                    weight_avg = {}
                    for weight_name in weight_names:
                        value_sum = 0
                        for weight_dict in weight_all_epoch:
                            value_sum += weight_dict[weight_name]
                        weight_avg[weight_name] = value_sum / len(
                            weight_all_epoch)

                    for var in slim.get_variables_to_restore():
                        base_name = var.name.split('/ExponentialMovingAverage')[0] + ':0'
                        if var.name in weight_avg:
                            sess.run(var.assign(weight_avg[var.name]))
                        elif base_name in weight_avg:
                            sess.run(var.assign(weight_avg[base_name]))

                    print('Update BN mean/std...')
                    bn_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                    bn_update_op = tf.group(*bn_update_ops)

                    for batch_id in range(
                            int(1281167 / FLAGS.batch_size / FLAGS.num_clones)
                            * 10):
                        sess.run(bn_update_op)
                        if not batch_id % 100:
                            tf.logging.info(
                                'Update BN mean/std: global step: %d' %
                                batch_id)
                        if not batch_id % int(0.3 * num_step_per_epoch):
                            print('Save swa model at BN step %d.' % batch_id)
                            saver.save(
                                sess, FLAGS.train_dir + '/model_swa.ckpt-' +
                                str(batch_id / num_step_per_epoch))

                    print('Save final swa model.')
                    saver.save(sess, FLAGS.train_dir + '/model_swa.ckpt')

                coord.request_stop()
                coord.join(threads)
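
# A minimal NumPy sketch (toy values) of the SWA step used in the example above: every
# trainable weight is simply averaged over the checkpoints collected in weight_all_epoch.
import numpy as np

weight_all_epoch = [{'w:0': np.array([1.0, 2.0])}, {'w:0': np.array([3.0, 4.0])}]
weight_avg = {
    name: sum(snap[name] for snap in weight_all_epoch) / len(weight_all_epoch)
    for name in weight_all_epoch[0]
}
print(weight_avg['w:0'])  # -> [2. 3.]
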
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,  # number of Clone objects
            clone_on_cpu=FLAGS.clone_on_cpu,  # boolean: whether to place the Clone objects on the CPU
            replica_id=FLAGS.task,  # ID of the worker or PS process
            num_replicas=FLAGS.worker_replicas,  # number of worker tasks (see Section 5.2)
            num_ps_tasks=FLAGS.num_ps_tasks)  # number of PS tasks

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        # Based on the dataset name dataset_name given by FLAGS (e.g. imagenet),
        # the split name dataset_split_name (e.g. train) and the absolute dataset
        # path dataset_dir, obtain the dataset object from dataset_factory.
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        # Based on the model name model_name given by FLAGS (e.g. alexnet_v2),
        # the number of classes num_classes and the weight decay weight_decay
        # (the coefficient of the L2 regularization term), obtain the model
        # function network_fn from nets_factory.
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        # Determine the name of the preprocessing function.
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        # Based on that name, obtain the image preprocessing function
        # image_preprocessing_fn from preprocessing_factory.
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            # FLAGS.num_readers sets how many threads read the dataset in parallel
            # (4 by default); the data read by the different threads is enqueued
            # into common_queue. Here the capacity of common_queue defaults to 20x
            # the training batch size, and common_queue_min, the minimum number of
            # elements kept in common_queue, defaults to 10x the batch size.
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            # As described in Section 9.2.1, the training image and label tensors can be
            # fetched from the provider object by the keys 'image' and 'label'.
            [image, label] = provider.get(['image', 'label'])
            # In the VGG and ResNet models the background is not treated as one of the
            # dataset's classes, so labels_offset has to be set to 1 when training those models.
            label -= FLAGS.labels_offset
            # Set the input image resolution used during training.
            train_image_size = FLAGS.train_image_size or network_fn.default_image_size
            # Run the training image through the preprocessing function.
            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)
            # Read in parallel with FLAGS.num_preprocessing_threads threads to get the
            # images and labels tensors used in the current iteration.
            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            # Call prefetch_queue to start a QueueRunner that holds data that has already
            # been prepared and is about to be trained on; the prepared data is buffered
            # in the batch_queue queue.
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            # Fetch the training data needed for this iteration, images and labels, from batch_queue.
            images, labels = batch_queue.dequeue()
            # Call network_fn to get logits, the output tensor of the last layer of the CNN,
            # and end_points, the collection of the output tensors of every layer.
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            # In some CNN models (e.g. Inception V3), the outputs of one or more intermediate
            # layers are used for auxiliary classification to reduce vanishing gradients;
            # those output tensors are AuxLogits.
            if 'AuxLogits' in end_points:
                # Include the auxiliary classifier's loss in the model's overall loss;
                # the weights argument is the discount factor applied to that loss when
                # it is added to the total loss.
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            # Compute the loss of the final classification layer.
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            # Return the collection of per-layer output tensors.
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            # If the decay rate FLAGS.moving_average_decay is specified,
            # moving_average_variables are the model variables that get a moving average,
            # and variable_averages is the corresponding moving-average object.
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            # Currently three learning-rate schedules are supported: exponential, fixed and polynomial.
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            # Create the optimizer of the type specified by FLAGS; the seven supported
            # optimizers are adadelta, adagrad, adam, ftrl, momentum, rmsprop and sgd.
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            # As described in Section 5.2, distributed training needs a synchronous optimizer.
            # The open-source train_image_classifier.py does not fully support distributed
            # training yet; this code has to be combined with tf.train.ClusterSpec,
            # tf.train.Server and related APIs to actually train in a distributed setting.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # If FLAGS.moving_average_decay is specified, apply the moving average to the
            # model variables on each update.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        # This part is similar to part of the deploy() method described in Section 9.2.3
        # (for the case where optimizer is not None).
        variables_to_train = _get_variables_to_train()

        # optimize_clones() sums the per-clone losses and computes gradients for the variables to train.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,  # the single-step training op
            logdir=FLAGS.train_dir,  # directory for logs, checkpoints etc. written during training
            master=FLAGS.master,  # address of the master; unused for single-machine training
            is_chief=(FLAGS.task == 0),  # whether this worker is the chief (used in distributed training)
            init_fn=_get_init_fn(),  # model initialization function
            summary_op=summary_op,  # the summary op
            number_of_steps=FLAGS.max_number_of_steps,  # maximum number of training steps
            log_every_n_steps=FLAGS.log_every_n_steps,  # logging interval, in steps
            save_summaries_secs=FLAGS.save_summaries_secs,  # summary-writing interval, in seconds
            save_interval_secs=FLAGS.save_interval_secs,  # checkpoint-saving interval, in seconds
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)  # sync optimizer (None for single-machine training)
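
# A minimal sketch of what the label_smoothing argument used in these examples does to the
# one-hot targets: roughly labels * (1 - eps) + eps / num_classes (toy values below).
import numpy as np

def smooth_labels(one_hot, eps):
    num_classes = one_hot.shape[-1]
    return one_hot * (1.0 - eps) + eps / num_classes

print(smooth_labels(np.array([0., 0., 1., 0.]), 0.1))  # -> [0.025 0.025 0.925 0.025]
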
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,  # 1
            clone_on_cpu=FLAGS.clone_on_cpu,  # False
            replica_id=FLAGS.task,  # 0
            num_replicas=FLAGS.worker_replicas,  # 1
            num_ps_tasks=FLAGS.num_ps_tasks)  # 0

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        num_classes = dataset.num_classes
        # Per-class weights for a weighted softmax loss; set them according to your needs.
        weights = [1.0] * num_classes
        weights = tf.constant(
            np.asarray(weights, dtype=np.float32).reshape(1, num_classes))

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=num_classes,
            weight_decay=FLAGS.weight_decay,
            is_training=True,
            freeze_bn=FLAGS.freeze_bn)
        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image, label = image_preprocessing_fn(image,
                                                  train_image_size,
                                                  train_image_size,
                                                  label=label)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)

            labels_one_hot = slim.one_hot_encoding(labels, num_classes)

            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels_one_hot], capacity=2)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            images, labels_one_hot = batch_queue.dequeue()
            tf.logging.info('images get_shape {}'.format(images.get_shape()))
            logits, end_points, aux_logits = network_fn(images)
            logits = tf.image.resize_bilinear(logits,
                                              size=(train_image_size,
                                                    train_image_size),
                                              align_corners=True)

            labels_one_hot = tf.reshape(labels_one_hot, logits.get_shape())

            # logits = tf.multiply(logits, weights)
            tf.losses.softmax_cross_entropy(
                labels_one_hot, logits, label_smoothing=FLAGS.label_smoothing)

            aux_logits = tf.image.resize_bilinear(aux_logits,
                                                  size=(train_image_size,
                                                        train_image_size),
                                                  align_corners=True)
            # aux_logits = tf.multiply(aux_logits, weights)
            tf.losses.softmax_cross_entropy(
                labels_one_hot,
                aux_logits,
                weights=0.4,
                label_smoothing=FLAGS.label_smoothing)

            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:  # False
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables,
                replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
                total_num_replicas=FLAGS.worker_replicas)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()
        # optimize_clones() sums the per-clone losses and computes gradients for the variables to train.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        session_config = tf.ConfigProto()
        session_config.gpu_options.allow_growth = True

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            session_config=session_config,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,  # master = ''
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
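
# A minimal NumPy sketch (hypothetical helper, toy shapes) of the weighted softmax loss that
# the per-class `weights` vector in the example above is intended for: each example contributes
# with the weight of its true class. The normalization convention here is one common choice,
# not necessarily the one tf.losses uses.
import numpy as np

def weighted_softmax_xent(logits, one_hot, class_weights):
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    per_example = -(one_hot * np.log(probs)).sum(axis=-1)      # plain cross-entropy per example
    example_weights = (one_hot * class_weights).sum(axis=-1)   # weight of each example's true class
    return (per_example * example_weights).sum() / example_weights.sum()

logits = np.array([[2.0, 0.5], [0.2, 1.5]])
one_hot = np.array([[1.0, 0.0], [0.0, 1.0]])
print(weighted_softmax_xent(logits, one_hot, np.array([1.0, 2.0])))
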
Beispiel #12
0
def train():
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.create_global_step()

        #######################
        # Select the networks #
        #######################
        network_fn = {}
        model_names = [net.strip() for net in FLAGS.model_name.split(',')]
        for i in range(FLAGS.num_networks):
            network_fn["{0}".format(i)] = nets_factory.get_network_fn(
                model_names[i],
                num_classes=FLAGS.num_classes,
                weight_decay=FLAGS.weight_decay)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            net_opt = {}
            for i in range(FLAGS.num_networks):
                net_opt["{0}".format(i)] = tf.train.AdamOptimizer(
                    FLAGS.learning_rate,
                    beta1=FLAGS.adam_beta1,
                    beta2=FLAGS.adam_beta2,
                    epsilon=FLAGS.opt_epsilon)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name  # or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        train_image_batch, train_label_batch = utils.get_image_label_batch(
            FLAGS, shuffle=True, name='test2')
        semi_image_batch, semi_label_batch = utils.get_image_label_batch(
            FLAGS, shuffle=True, name='train1_train2')
        test_image_batch, test_label_batch = utils.get_image_label_batch(
            FLAGS, shuffle=False, name='test1')
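        # Each image batch packs the RGB image in channels 0..2 and a mask (apparently in
        # {-1, +1}) in channel 3; the casts below map the mask to {0, 1} and then to a
        # 2-class one-hot tensor.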
        train_x = train_image_batch[:, :, :, 0:3]
        train_y = tf.cast(
            (tf.squeeze(train_image_batch[:, :, :, 3]) + 1.0) * 0.5, tf.int32)
        train_y = tf.one_hot(train_y, depth=2, axis=-1)

        semi_x = semi_image_batch[:, :, :, 0:3]
        semi_y = tf.cast(
            (tf.squeeze(semi_image_batch[:, :, :, 3]) + 1.0) * 0.5, tf.int32)
        semi_y = tf.one_hot(semi_y, depth=2, axis=-1)

        test_x = test_image_batch[:, :, :, 0:3]
        test_y = tf.cast(
            (tf.squeeze(test_image_batch[:, :, :, 3]) + 1.0) * 0.5, tf.int32)
        test_y = tf.one_hot(test_y, depth=2, axis=-1)


        (precision, test_precision, val_precision, net_var_list, net_grads, net_update_ops,
         predictions, test_predictions, val_predictions) = {}, {}, {}, {}, {}, {}, {}, {}, {}
        semi_net_grads = {}

        with tf.name_scope('tower') as scope:
            with tf.variable_scope(tf.get_variable_scope()):
                net_loss, dice_loss, cross_loss, kl, exclusion, net_pred, offset = _tower_loss(
                    network_fn,
                    train_x,
                    train_y,
                    cross=True,
                    reuse=False,
                    is_training=True)
                semi_net_loss, _, _, _, _, _, _ = _tower_loss(network_fn,
                                                              train_x,
                                                              train_y,
                                                              cross=False,
                                                              reuse=True,
                                                              is_training=True)
                test_net_loss, _, _, _, _, test_net_pred, test_offset = _tower_loss(
                    network_fn,
                    test_x,
                    test_y,
                    cross=True,
                    reuse=True,
                    is_training=False)

                truth = tf.argmax(train_y, axis=-1)
                test_truth = tf.argmax(test_y, axis=-1)

                # Reuse variables for the next tower.
                #tf.get_variable_scope().reuse_variables()

                # Retain the summaries from the final tower.
                #summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
                var_list = tf.trainable_variables()

                for i in range(FLAGS.num_networks):
                    predictions["{0}".format(i)] = tf.argmax(
                        net_pred["{0}".format(i)], axis=-1)
                    test_predictions["{0}".format(i)] = tf.argmax(
                        test_net_pred["{0}".format(i)], axis=-1)
                #precision["{0}".format(0)] = tf.reduce_mean(tf.to_float(tf.equal(predictions["{0}".format(0)], truth)))
                #test_precision["{0}".format(0)] = tf.reduce_mean(tf.to_float(tf.equal(test_predictions["{0}".format(0)], test_truth)))
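                # The "precision" values below are soft Dice scores on the argmax
                # masks, 2 * |P * T| / (|P| + |T|); network 0 is scored against the
                # ground-truth mask and network 1 against the inverted mask (1 - truth).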

                precision["{0}".format(0)] = 2 * tf.reduce_sum(
                    tf.to_float(predictions["{0}".format(0)]) *
                    tf.to_float(truth)) / tf.reduce_sum(
                        tf.to_float(predictions["{0}".format(0)] + truth))
                test_precision["{0}".format(0)] = 2 * tf.reduce_sum(
                    tf.to_float(test_predictions["{0}".format(0)]) *
                    tf.to_float(test_truth)) / tf.reduce_sum(
                        tf.to_float(test_predictions["{0}".format(0)] +
                                    test_truth))

                #precision["{0}".format(1)] = tf.reduce_mean(tf.to_float(tf.equal(predictions["{0}".format(1)], 1-truth)))
                # test_precision["{0}".format(1)] = tf.reduce_mean(tf.to_float(tf.equal(test_predictions["{0}".format(1)], 1-test_truth)))

                precision["{0}".format(1)] = 2 * tf.reduce_sum(
                    tf.to_float(predictions["{0}".format(1)]) *
                    tf.to_float(1 - truth)) / tf.reduce_sum(
                        tf.to_float(predictions["{0}".format(1)] + 1 - truth))
                test_precision["{0}".format(1)] = 2 * tf.reduce_sum(
                    tf.to_float(test_predictions["{0}".format(1)]) *
                    tf.to_float(1 - test_truth)) / tf.reduce_sum(
                        tf.to_float(test_predictions["{0}".format(1)] + 1 -
                                    test_truth))

                Gamma = {}
                for i in range(FLAGS.num_networks):
                    # Add a summary to track the training precision.
                    #summaries.append(tf.summary.scalar('precision_%d' % i, precision["{0}".format(i)]))
                    #summaries.append(tf.summary.scalar('test_precision_%d' % i, test_precision["{0}".format(i)]))
                    #summaries.append(tf.summary.scalar('val_precision_%d' % i, test_precision["{0}".format(i)]))

                    net_update_ops["{0}".format(i)] = \
                                tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=('%sdmlnet_%d' % (scope, i)))

                    net_var_list["{0}".format(i)] = \
                                    [var for var in var_list if 'dmlnet_%d' % i in var.name]

                    net_grads["{0}".format(i)] = net_opt["{0}".format(
                        i)].compute_gradients(
                            net_loss["{0}".format(i)],
                            var_list=net_var_list["{0}".format(i)])

                    semi_net_grads["{0}".format(i)] = net_opt["{0}".format(
                        i)].compute_gradients(
                            semi_net_loss["{0}".format(i)],
                            var_list=net_var_list["{0}".format(i)])
                Gamma["{0}".format(0)], Gamma["{0}".format(1)] = {}, {}
                for var in tf.trainable_variables():
                    if 'dmlnet_0' in var.name and 'GGamma' in var.name:
                        Gamma["{0}".format(0)][var.name] = var
                    if 'dmlnet_1' in var.name and 'GGamma' in var.name:
                        Gamma["{0}".format(1)][var.name] = var

        #################################
        # Configure the moving averages #
        #################################

        if FLAGS.moving_average_decay:
            moving_average_variables = {}
            all_moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
            for i in range(FLAGS.num_networks):
                moving_average_variables["{0}".format(i)] = \
                    [var for var in all_moving_average_variables if 'dmlnet_%d' % i in var.name]
                net_update_ops["{0}".format(i)].append(
                    variable_averages.apply(
                        moving_average_variables["{0}".format(i)]))
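        # Note: appending variable_averages.apply(...) above to each network's update
        # ops makes the EMA shadow variables for that network refresh on every
        # training step.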

        # Apply the gradients to adjust the shared variables.
        net_grad_updates, net_train_op, semi_net_grad_updates, semi_net_train_op = {}, {}, {}, {}
        for i in range(FLAGS.num_networks):
            net_grad_updates["{0}".format(i)] = net_opt["{0}".format(
                i)].apply_gradients(net_grads["{0}".format(i)],
                                    global_step=global_step)
            semi_net_grad_updates["{0}".format(i)] = net_opt["{0}".format(
                i)].apply_gradients(semi_net_grads["{0}".format(i)],
                                    global_step=global_step)
            net_update_ops["{0}".format(i)].append(
                net_grad_updates["{0}".format(i)])
            net_update_ops["{0}".format(i)].append(
                semi_net_grad_updates["{0}".format(i)])
            # Group all updates into a single train op.
            net_train_op["{0}".format(i)] = tf.group(
                *net_update_ops["{0}".format(i)])
        '''# Apply the gradients to adjust the shared variables.
        net_train_op, semi_net_train_op = {}, {}
        for i in range(FLAGS.num_networks):
            net_train_op["{0}".format(i)] = net_opt["{0}".format(i)].minimize(net_loss["{0}".format(i)], global_step=global_step, var_list=net_var_list["{0}".format(i)])
            #semi_net_train_op["{0}".format(i)] = semi_net_opt["{0}".format(i)].minimize(semi_net_loss["{0}".format(i)],global_step=global_step, var_list=net_var_list["{0}".format(i)])'''

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        # Build the summary operation from the last tower summaries.
        #summary_op = tf.summary.merge(summaries)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.85),
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        #load_fn = slim.assign_from_checkpoint_fn(os.path.join(FLAGS.checkpoint_dir, 'model.ckpt-10'),tf.global_variables(),ignore_missing_vars=True)
        #load_fn = slim.assign_from_checkpoint_fn('./WCE_densenet4/checkpoint/model.ckpt-70',tf.global_variables(),ignore_missing_vars=False)
        #load_fn(sess)

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        #summary_writer = tf.summary.FileWriter(
        #    os.path.join(FLAGS.log_dir),
        #    graph=sess.graph)

        net_loss_value, test_net_loss_value, precision_value, test_precision_value = {}, {}, {}, {}
        dice_loss_value, cross_loss_value, kl_value, exclusion_value = {}, {}, {}, {}
        parameters = utils.count_trainable_params()
        print("Total training params: %.1fM \r\n" % (parameters / 1e6))

        start_time = time.time()
        counter = 0
        infile = open(os.path.join(FLAGS.log_dir, 'log.txt'), 'w')
        batch_count = np.int32(15552 / FLAGS.batch_size)

        GGamma, GG = {}, {}
        GGamma["{0}".format(0)], GGamma["{0}".format(1)] = {}, {}
        for i in range(FLAGS.num_networks):
            GG["{0}".format(i)] = sess.run([Gamma["{0}".format(i)]])
            for k in GG["{0}".format(i)][0].keys():
                if 'dmlnet_%d' % i in k and 'GGamma' in k:
                    GGamma_key = k.split(':')[0]
                    GGamma_key = '_'.join(GGamma_key.split('/'))
                    GGamma["{0}".format(i)][GGamma_key] = [
                        float(GG["{0}".format(i)][0][k])
                    ]

        for epoch in range(1, 1 + FLAGS.max_number_of_epochs):
            if (epoch) % 40 == 0:
                FLAGS.learning_rate = FLAGS.learning_rate * 0.1
            #if (epoch) % 75 == 0:
            #    FLAGS.learning_rate = FLAGS.learning_rate * 0.1
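            # Caveat: the Adam optimizers above were built with the original value of
            # FLAGS.learning_rate, so reassigning the flag here does not change the
            # learning rate already captured in the graph; a tf.Variable or placeholder
            # learning rate would be needed for this decay schedule to take effect.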

            for batch_idx in range(batch_count):
                counter += 1
                for i in range(FLAGS.num_networks):
                    (_, net_loss_value["{0}".format(i)],
                     dice_loss_value["{0}".format(i)],
                     cross_loss_value["{0}".format(i)],
                     kl_value["{0}".format(i)],
                     exclusion_value["{0}".format(i)],
                     precision_value["{0}".format(i)], offset_map) = sess.run([
                         net_train_op["{0}".format(i)], net_loss["{0}".format(i)],
                         dice_loss["{0}".format(i)], cross_loss["{0}".format(i)],
                         kl["{0}".format(i)], exclusion["{0}".format(i)],
                         precision["{0}".format(i)], offset["{0}".format(i)]
                     ])
                    assert not np.isnan(net_loss_value["{0}".format(
                        i)]), 'Model diverged with loss = NaN'
                    #if epoch >= 20:
                    #    _ = sess.run([sc_update_op["{0}".format(i)]])

                if batch_idx % 500 == 0:
                    for i in range(FLAGS.num_networks):
                        test_net_loss_value["{0}".format(i)], GG["{0}".format(
                            i)], test_precision_value["{0}".format(
                                i)] = sess.run([
                                    test_net_loss["{0}".format(i)],
                                    Gamma["{0}".format(i)],
                                    test_precision["{0}".format(i)]
                                ])

                        for k in GG["{0}".format(i)].keys():
                            if 'dmlnet_%d' % i in k and 'GGamma' in k:
                                GGamma_key = k.split(':')[0]
                                GGamma_key = '_'.join(GGamma_key.split('/'))
                                GGamma["{0}".format(i)][GGamma_key].extend(
                                    [float(GG["{0}".format(i)][k])])

                    #format_str = 'Epoch: [%3d] [%3d/%3d] net0loss = %.4f, net0acc = %.4f, net0testloss = %.4f, net0testacc = %.4f,   net1loss = %.4f, net1acc = %.4f, net1testloss = %.4f, net1testacc = %.4f'
                    #print(format_str % (epoch, batch_idx,batch_count, net_loss_value["{0}".format(0)],
                    #      precision_value["{0}".format(0)],test_net_loss_value["{0}".format(0)],np.float32(test_precision_value["{0}".format(0)]),net_loss_value["{0}".format(1)],precision_value["{0}".format(1)],test_net_loss_value["{0}".format(1)],np.float32(test_precision_value["{0}".format(1)])))
                    #format_str1 = 'Epoch: [%3d] [%3d/%3d] time: %4.3f, dice0 = %.5f, cross0 = %.4f, kl0 = %.4f, exclusion0 = %.4f,     dice1 = %.5f, cross1 = %.4f, kl1 = %.4f, exclusion1 = %.4f'
                    #print(format_str1 % (epoch, batch_idx,batch_count, time.time()-start_time, dice_loss_value["{0}".format(0)],cross_loss_value["{0}".format(0)],np.float32(kl_value["{0}".format(0)]),np.float32(exclusion_value["{0}".format(0)]),dice_loss_value["{0}".format(1)],cross_loss_value["{0}".format(1)],np.float32(kl_value["{0}".format(1)]),np.float32(exclusion_value["{0}".format(1)])))

                    print(offset_map.max())
                    '''infile.write(format_str % (epoch, batch_idx,batch_count, net_loss_value["{0}".format(0)], precision_value["{0}".format(0)],test_net_loss_value["{0}".format(0)],np.float32(test_precision_value["{0}".format(0)]),net_loss_value["{0}".format(1)],precision_value["{0}".format(1)],test_net_loss_value["{0}".format(1)],np.float32(test_precision_value["{0}".format(1)])))
                    infile.write('\n')
                    infile.write(format_str1 % (epoch, batch_idx,batch_count, time.time()-start_time, dice_loss_value["{0}".format(0)],cross_loss_value["{0}".format(0)],np.float32(kl_value["{0}".format(0)]),np.float32(exclusion_value["{0}".format(0)]),dice_loss_value["{0}".format(1)],cross_loss_value["{0}".format(1)],np.float32(kl_value["{0}".format(1)]),np.float32(exclusion_value["{0}".format(1)])))
                    infile.write('\n')'''
                    format_str = 'Epoch: [%3d] [%3d/%3d] time: %4.4f, net0_loss = %.5f, net0_acc = %.4f, net0_test_acc = %.4f   net1_loss = %.5f, net1_acc = %.4f, net1_test_acc = %.4f'
                    print(format_str %
                          (epoch, batch_idx, batch_count, time.time() -
                           start_time, net_loss_value["{0}".format(0)],
                           precision_value["{0}".format(0)],
                           np.float32(test_precision_value["{0}".format(0)]),
                           net_loss_value["{0}".format(1)],
                           precision_value["{0}".format(1)],
                           np.float32(test_precision_value["{0}".format(1)])))

                if batch_idx == 0:
                    testpred0, test_gt, test_X = sess.run([
                        test_predictions["{0}".format(0)], test_truth, test_x
                    ])
                    tot_num_samples = FLAGS.batch_size
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(
                        make_png(test_gt[:manifold_h * manifold_w, :, :]),
                        [manifold_h, manifold_w],
                        os.path.join(FLAGS.saliency_map,
                                     str(epoch) + 'test_gt.jpg'))
                    save_images(
                        test_X[:manifold_h * manifold_w, :, :, :],
                        [manifold_h, manifold_w],
                        os.path.join(FLAGS.saliency_map,
                                     str(epoch) + 'test.jpg'))
                    save_images(
                        make_png(testpred0[:manifold_h * manifold_w, :, :]),
                        [manifold_h, manifold_w],
                        os.path.join(FLAGS.saliency_map,
                                     str(epoch) + 'test_pred0.jpg'))
                    #save_images(make_png(testpred1[:manifold_h * manifold_w, :,:]),[manifold_h, manifold_w],os.path.join(FLAGS.saliency_map, str(epoch)+'test_pred1.jpg'))

                    #for index in range(test_gt.shape[0]):
                    #    scipy.misc.imsave(os.path.join(FLAGS.saliency_map, str(epoch)+'_'+str(index)+'test.jpg'), make_png(test_gt[index,:,:]))
                    #    scipy.misc.imsave(os.path.join(FLAGS.saliency_map, str(epoch)+'_'+str(index)+'test_pred0.jpg'), make_png(testpred["{0}".format(0)][index,:,:]))
                    #    scipy.misc.imsave(os.path.join(FLAGS.saliency_map, str(epoch)+'_'+str(index)+'test_pred1.jpg'), make_png(testpred["{0}".format(1)][index,:,:]))

                #summary_str = sess.run(summary_op)
                #summary_writer.add_summary(summary_str, counter)

            # Save the model checkpoint periodically.
            if epoch % FLAGS.ckpt_steps == 0 or epoch == FLAGS.max_number_of_epochs:
                checkpoint_path = os.path.join(FLAGS.checkpoint_dir,
                                               'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=epoch)

        for i in range(FLAGS.num_networks):
            for k in GG["{0}".format(i)].keys():
                if 'dmlnet_%d' % i in k and 'GGamma' in k:
                    GGamma_key = k.split(':')[0]
                    GGamma_key = '_'.join(GGamma_key.split('/'))
                    gamma_file = open(
                        os.path.join(FLAGS.log_dir, GGamma_key + '.txt'), 'w')
                    for g in GGamma["{0}".format(i)][GGamma_key]:
                        gamma_file.write(str(g) + ' \n')
                    gamma_file.close()
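# The "precision" reported by the training loop above is a Dice overlap rather than a
# classification accuracy. A minimal standalone NumPy sketch of that metric (dice_score
# and the toy masks below are illustrative and not part of the original example):
import numpy as np


def dice_score(pred, truth, eps=1e-7):
    """Dice overlap 2*|P*T| / (|P| + |T|) for binary {0, 1} masks."""
    pred = pred.astype(np.float32)
    truth = truth.astype(np.float32)
    return 2.0 * np.sum(pred * truth) / (np.sum(pred) + np.sum(truth) + eps)


# Toy usage: two 4x4 masks whose foregrounds overlap on half of their pixels.
pred_mask = np.zeros((4, 4), dtype=np.int32)
true_mask = np.zeros((4, 4), dtype=np.int32)
pred_mask[:2, :2] = 1
true_mask[1:3, :2] = 1
print(dice_score(pred_mask, true_mask))  # ~0.5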
def main(_):
    tic = time.time()
    print('tensorflow version:', tf.__version__)
    tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    if FLAGS.include_cross_entropy:
        if FLAGS.logits_loss_lambda > 0:
            print('HG: use entropy loss and logits loss')
        else:
            print('HG: use entropy loss only')
    elif FLAGS.logits_loss_lambda > 0:
        print('HG: use logits loss only')
    else:
        print('HG: use activation loss only')

    # init
    net_name_scope_pruned = FLAGS.net_name_scope_pruned
    net_name_scope_checkpoint = FLAGS.net_name_scope_checkpoint
    indexed_prune_scopes_for_units = valid_indexed_prune_scopes_for_units
    kept_percentages = sorted(map(float, FLAGS.kept_percentages.split(',')))

    num_options = len(kept_percentages)
    num_units = len(indexed_prune_scopes_for_units)
    print('HG: num_options=%d, num_blocks=%d' % (num_options, num_units))
    print('HG: total number of configurations=%d' % (num_options**num_units))

    # Find the configurations to evaluate

    configs = get_sampled_configurations(num_units, num_options,
                                         FLAGS.total_num_configurations)
    num_configurations = len(configs)

    # Get the MPI rank of this process
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    if rank >= num_configurations:
        print("ERROR: rank(%d) > num_configurations(%d)" %
              (rank, num_configurations))
        return
    # rank = 0
    FLAGS.configuration_index = FLAGS.start_configuration_index + rank
    config = configs[FLAGS.configuration_index]
    print('HG: kept_percentages=%s, start_config_index=%d, rank=%d, num_configs=%d,  config_index=%d' \
           %(str(kept_percentages), FLAGS.start_configuration_index, rank, num_configurations,  FLAGS.configuration_index))

    # prepare for training with the specific config
    indexed_prune_scopes, kept_percentage = config_to_indexed_prune_scopes(
        config, indexed_prune_scopes_for_units, kept_percentages)
    prune_info = indexed_prune_scopes_to_prune_info(indexed_prune_scopes,
                                                    kept_percentage)

    # prepare file system
    results_dir = os.path.join(FLAGS.train_dir,
                               'id' + str(FLAGS.configuration_index))
    train_dir = os.path.join(results_dir, 'train')

    if (not FLAGS.continue_training) or (
            not tf.train.latest_checkpoint(train_dir)):
        prune_scopes = indexed_prune_scopes_to_prune_scopes(
            indexed_prune_scopes, net_name_scope_checkpoint)
        shorten_scopes = indexed_prune_scopes_to_shorten_scopes(
            indexed_prune_scopes, net_name_scope_checkpoint)
        variables_init_value = get_init_values_for_pruned_layers(
            prune_scopes, shorten_scopes, kept_percentage)
        reinit_scopes = [
            re.sub(net_name_scope_checkpoint, net_name_scope_pruned, v)
            for v in prune_scopes + shorten_scopes
        ]

        prepare_file_system(train_dir)

    def write_detailed_info(info):
        with open(os.path.join(train_dir, 'train_details.txt'), 'a') as f:
            f.write(info + '\n')

    info = 'train_dir:' + train_dir + '\n'
    info += 'options:' + str(kept_percentages) + '\n'
    info += 'configuration: ' + str(config) + '\n'
    info += 'indexed_prune_scopes: ' + str(indexed_prune_scopes) + '\n'
    info += 'kept_percentage: ' + str(kept_percentage)
    print(info)
    write_detailed_info(info)

    with tf.Graph().as_default():

        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.train_dataset_name,
                                              FLAGS.dataset_dir)
        test_dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                   FLAGS.test_dataset_name,
                                                   FLAGS.dataset_dir)

        batch_queue = train_inputs(dataset, deploy_config, FLAGS)
        test_images, test_labels = test_inputs(test_dataset, deploy_config,
                                               FLAGS)
        images, labels = batch_queue.dequeue()

        ######################
        # Select the network #
        ######################
        # use the original network as teacher
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)

        # use the pruned network as student
        network_fn_pruned = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            prune_info=prune_info,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)
        print('HG: prune_info:')
        pprint(prune_info)

        ####################
        # Define the model #
        ####################
        logits_teacher, end_points_teacher = network_fn(images,
                                                        is_training=False)
        logits_train, end_points = network_fn_pruned(
            images,
            is_training=True,
            is_local_train=False,
            reuse_variables=False,
            scope=net_name_scope_pruned)
        logits_eval, _ = network_fn_pruned(test_images,
                                           is_training=False,
                                           is_local_train=False,
                                           reuse_variables=True,
                                           scope=net_name_scope_pruned)

        correct_prediction = add_correct_prediction(logits_eval, test_labels)

        #############################
        # Specify the loss function #
        #############################
        if FLAGS.include_cross_entropy:
            cross_entropy = add_cross_entropy(logits_train, labels)
            tf.add_to_collection('subgraph_losses', cross_entropy)
        # get regularization loss
        regularization_losses = get_regularization_losses_within_scopes(
            [net_name_scope_pruned + '/'])
        print_list('regularization_losses', regularization_losses)

        # get knowledge distillation loss: l2_loss on logits
        if FLAGS.logits_loss_lambda > 0:
            logits_loss = add_l2_loss(logits_teacher,
                                      logits_train,
                                      weights=FLAGS.logits_loss_lambda)
            print('logits_loss', logits_loss)

        if not FLAGS.include_cross_entropy and FLAGS.logits_loss_lambda == 0:
            # use activation map loss if no logits loss and no cross entropy loss available
            last_indexed_prune_scope = valid_indexed_prune_scopes[-1]
            outputs_scope = get_prune_units_outputs_scope(
                last_indexed_prune_scope, net_name_scope_checkpoint)
            outputs = end_points_teacher[outputs_scope]

            outputs_scope = get_prune_units_outputs_scope(
                last_indexed_prune_scope, net_name_scope_pruned)
            outputs_pruned = end_points[outputs_scope]
            activation_loss = add_l2_loss(outputs, outputs_pruned)

        # total loss and its summary
        total_loss = tf.add_n(tf.get_collection('subgraph_losses'),
                              name='total_loss')
        for l in tf.get_collection('subgraph_losses') + [total_loss]:
            tf.summary.scalar(l.op.name + '/summary', l)
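        # All loss terms above (cross-entropy, regularization, and the distillation /
        # activation losses) are assumed to be registered in the custom
        # 'subgraph_losses' collection by the helpers used here, so tf.add_n over that
        # collection yields the total training objective and each term also gets its
        # own scalar summary.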

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.variables_device()):
            global_step = tf.Variable(0, trainable=False, name='global_step')

        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step, FLAGS)
            optimizer = configure_optimizer(learning_rate, FLAGS)
            tf.summary.scalar('learning_rate', learning_rate)

        #############################
        # Add train operation       #
        #############################
        variables_to_train = get_trainable_variables_within_scopes(
            [net_name_scope_pruned + '/'])
        train_op = add_train_op(optimizer,
                                total_loss,
                                global_step,
                                var_list=variables_to_train)
        print_list("variables_to_train", variables_to_train)

        # Gather update_ops: the updates for the batch_norm variables created by network_fn_pruned.
        update_ops = get_update_ops_within_scopes(
            [net_name_scope_pruned + '/'])
        print_list("update_ops", update_ops)

        # add train_tensor
        update_ops.append(train_op)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')
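        # Running train_tensor therefore (1) triggers the grouped update_op, i.e. the
        # batch-norm updates plus the gradient step, and (2) returns the value of
        # total_loss computed in that same session run.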

        # add summary op
        summary_op = tf.summary.merge_all()

        print("HG: trainable_variables=", len(tf.trainable_variables()))
        print("HG: model_variables=", len(tf.model_variables()))
        print("HG: global_variables=", len(tf.global_variables()))

        sess_config = tf.ConfigProto(intra_op_parallelism_threads=16,
                                     inter_op_parallelism_threads=16)
        with tf.Session(config=sess_config) as sess:
            ###########################
            # Prepare for filewriter. #
            ###########################
            train_writer = tf.summary.FileWriter(train_dir, sess.graph)

            # If restarting the training, or if there is no checkpoint in train_dir
            if (not FLAGS.continue_training) or (
                    not tf.train.latest_checkpoint(train_dir)):
                ###########################################
                # Restore original model variable values. #
                ###########################################
                variables_to_restore = get_model_variables_within_scopes(
                    [net_name_scope_checkpoint + '/'])
                print_list("restore model variables for original",
                           variables_to_restore)
                load_checkpoint(sess,
                                FLAGS.checkpoint_path,
                                var_list=variables_to_restore)

                #########################################
                # Reinit pruned model variables #
                #########################################
                variables_to_reinit = get_model_variables_within_scopes(
                    reinit_scopes)
                print_list("Initialize pruned variables", variables_to_reinit)
                assign_ops = []
                for v in variables_to_reinit:
                    key = re.sub(net_name_scope_pruned,
                                 net_name_scope_checkpoint, v.op.name)
                    if key in variables_init_value:
                        value = variables_init_value.get(key)
                        # print(key, value)
                        assign_ops.append(
                            tf.assign(v,
                                      tf.convert_to_tensor(value),
                                      validate_shape=True))
                        # v.set_shape(value.shape)
                    else:
                        raise ValueError(
                            "Key not in variables_init_value, key=", key)
                assign_op = tf.group(*assign_ops)
                sess.run(assign_op)

                #################################################
                # Restore unchanged model variables. #
                #################################################
                variables_to_restore = {
                    re.sub(net_name_scope_pruned, net_name_scope_checkpoint,
                           v.op.name): v
                    for v in get_model_variables_within_scopes(
                        [net_name_scope_pruned + '/'])
                    if v not in variables_to_reinit
                }
                print_list("restore pruned model variables",
                           variables_to_restore.values())
                load_checkpoint(sess,
                                FLAGS.checkpoint_path,
                                var_list=variables_to_restore)

            else:
                ###########################################
                ## Restore all variables from checkpoint ##
                ###########################################
                variables_to_restore = get_global_variables_within_scopes()
                load_checkpoint(sess, train_dir, var_list=variables_to_restore)

            #################################################
            # Init uninitialized global variables. #
            #################################################
            variables_to_init = get_global_variables_within_scopes(
                sess.run(tf.report_uninitialized_variables()))
            print_list("init unitialized variables", variables_to_init)
            sess.run(tf.variables_initializer(variables_to_init))

            init_global_step_value = sess.run(global_step)
            print('initial global step: ', init_global_step_value)
            if init_global_step_value >= FLAGS.max_number_of_steps:
                print('Exit: init_global_step_value (%d) >= FLAGS.max_number_of_steps (%d)'
                      % (init_global_step_value, FLAGS.max_number_of_steps))
                return

            ###########################
            # Record CPU usage  #
            ###########################
            # mpstat_output_filename = os.path.join(train_dir, "cpu-usage.log")
            # os.system("mpstat -P ALL 1 > " + mpstat_output_filename + " 2>&1 &")

            ###########################
            # Kicks off the training. #
            ###########################
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
            print('HG: # of threads=', len(threads))

            duration = 0
            duration_cnt = 0
            train_time = 0
            train_only_cnt = 0

            print("start to train at:", datetime.now())
            for i in range(init_global_step_value,
                           FLAGS.max_number_of_steps + 1):
                # Run optional metadata or summary collection while running the train tensor
                #if i < FLAGS.max_number_of_steps:
                if i > init_global_step_value:
                    # train while run metadata
                    if i % FLAGS.runmeta_every_n_steps == FLAGS.runmeta_every_n_steps - 1:
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()

                        loss_value = sess.run(train_tensor,
                                              options=run_options,
                                              run_metadata=run_metadata)
                        train_writer.add_run_metadata(run_metadata,
                                                      'step%d-train' % i)

                        # Create the Timeline object, and write it to a json file
                        fetched_timeline = timeline.Timeline(
                            run_metadata.step_stats)
                        chrome_trace = fetched_timeline.generate_chrome_trace_format(
                        )
                        with open(
                                os.path.join(train_dir,
                                             'timeline_' + str(i) + '.json'),
                                'w') as f:
                            f.write(chrome_trace)

                    # train while record summary
                    elif i % FLAGS.summary_every_n_steps == 0:
                        train_summary, loss_value = sess.run(
                            [summary_op, train_tensor])
                        train_writer.add_summary(train_summary, i)

                    # train only
                    else:
                        start_time = time.time()
                        loss_value = sess.run(train_tensor)
                        train_only_cnt += 1
                        train_time += time.time() - start_time
                        duration_cnt += 1
                        duration += time.time() - start_time

                    # log loss information
                    if i % FLAGS.log_every_n_steps == 0 and duration_cnt > 0:
                        log_frequency = duration_cnt
                        examples_per_sec = log_frequency * FLAGS.batch_size / duration
                        sec_per_batch = float(duration / log_frequency)
                        summary = tf.Summary()
                        summary.value.add(tag='examples_per_sec',
                                          simple_value=examples_per_sec)
                        summary.value.add(tag='sec_per_batch',
                                          simple_value=sec_per_batch)
                        train_writer.add_summary(summary, i)
                        format_str = (
                            '%s: step %d, loss = %.3f (%.1f examples/sec; %.3f sec/batch)'
                        )
                        print(format_str % (datetime.now(), i, loss_value,
                                            examples_per_sec, sec_per_batch))
                        duration = 0
                        duration_cnt = 0

                        info = format_str % (datetime.now(), i, loss_value,
                                             examples_per_sec, sec_per_batch)
                        write_detailed_info(info)
                else:
                    # At the initial step, evaluate the summary and total loss without training
                    train_summary, loss_value = sess.run(
                        [summary_op,
                         total_loss])  #loss_value = sess.run(total_loss)
                    train_writer.add_summary(train_summary, i)
                    format_str = ('%s: step %d, loss = %.3f')
                    print(format_str % (datetime.now(), i, loss_value))
                    info = format_str % (datetime.now(), i, loss_value)
                    write_detailed_info(info)

                # record the evaluation accuracy
                is_last_step = (i == FLAGS.max_number_of_steps)
                if i % FLAGS.evaluate_every_n_steps == 0 or is_last_step:
                    #run_meta = (i==FLAGS.evaluate_every_n_steps)
                    test_accuracy, run_metadata = evaluate_accuracy(
                        sess,
                        coord,
                        test_dataset.num_samples,
                        test_images,
                        test_labels,
                        test_images,
                        test_labels,
                        correct_prediction,
                        FLAGS.test_batch_size,
                        run_meta=False)
                    summary = tf.Summary()
                    summary.value.add(tag='accuracy',
                                      simple_value=test_accuracy)
                    train_writer.add_summary(summary, i)
                    #if run_meta:
                    #    eval_writer.add_run_metadata(run_metadata, 'step%d-eval' % i)

                    info = ('%s: step %d, test_accuracy = %.6f') % (
                        datetime.now(), i, test_accuracy)
                    print(info)
                    write_detailed_info(info)

                    ###########################
                    # Save model parameters. #
                    ###########################
                    #saver = tf.train.Saver(var_list=get_model_variables_within_scopes([net_name_scope_pruned+'/']))
                    save_path = saver.save(
                        sess, os.path.join(train_dir, 'model.ckpt-' + str(i)))
                    print("HG: Model saved in file: %s" % save_path)

            coord.request_stop()
            coord.join(threads)
            total_time = time.time() - tic
            train_speed = train_time * 1.0 / train_only_cnt
            train_time = train_speed * (
                FLAGS.max_number_of_steps
            )  # - init_global_step_value) #/train_only_cnt
            info = "HG: training speed(sec/batch): %.6f\n" % (train_speed)
            info += "HG: training time(min): %.1f, total time(min): %.1f" % (
                train_time / 60.0, total_time / 60.0)
            print(info)
            write_detailed_info(info)
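# The tracing branch above uses TensorFlow's RunMetadata / timeline machinery. A minimal
# self-contained sketch of that pattern on a toy graph (TF 1.x assumed; the output file
# name trace.json is arbitrary):
import tensorflow as tf
from tensorflow.python.client import timeline

x = tf.random_normal([256, 256])
y = tf.matmul(x, x)

with tf.Session() as sess:
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    sess.run(y, options=run_options, run_metadata=run_metadata)

    # Convert the collected step stats into a trace viewable at chrome://tracing.
    fetched_timeline = timeline.Timeline(run_metadata.step_stats)
    with open('trace.json', 'w') as f:
        f.write(fetched_timeline.generate_chrome_trace_format())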
def main(_):
    tic = time.time()
    print('tensorflow version:', tf.__version__)
    tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')
    # init
    net_name_scope_pruned = FLAGS.net_name_scope_pruned
    net_name_scope_checkpoint = FLAGS.net_name_scope_checkpoint
    indexed_prune_scopes_for_units = valid_indexed_prune_scopes_for_units
    kept_percentage = FLAGS.kept_percentage

    # set the configuration: should be a 16-length vector
    config = [1.0] * len(indexed_prune_scopes_for_units)
    config[FLAGS.block_id] = kept_percentage
    print("config:", config)

    # prepare for training with the specific config
    indexed_prune_scopes = indexed_prune_scopes_for_units[FLAGS.block_id]
    prune_info = indexed_prune_scopes_to_prune_info(indexed_prune_scopes,
                                                    kept_percentage)
    print("prune_info:", prune_info)

    # prepare file system
    results_dir = os.path.join(
        FLAGS.train_dir,
        'id' + str(FLAGS.block_id))  #+'_'+str(FLAGS.max_number_of_steps))
    train_dir = os.path.join(results_dir, 'kp' + str(kept_percentage))

    prune_scopes = indexed_prune_scopes_to_prune_scopes(
        indexed_prune_scopes, net_name_scope_checkpoint)
    shorten_scopes = indexed_prune_scopes_to_shorten_scopes(
        indexed_prune_scopes, net_name_scope_checkpoint)
    variables_init_value = get_init_values_for_pruned_layers(
        prune_scopes, shorten_scopes, kept_percentage)
    reinit_scopes = [
        re.sub(net_name_scope_checkpoint, net_name_scope_pruned, v)
        for v in prune_scopes + shorten_scopes
    ]

    prepare_file_system(train_dir)

    def write_detailed_info(info):
        with open(os.path.join(train_dir, 'eval_details.txt'), 'a') as f:
            f.write(info + '\n')

    info = 'train_dir:' + train_dir + '\n'
    info += 'block_id:' + str(FLAGS.block_id) + '\n'
    info += 'configuration: ' + str(config) + '\n'
    info += 'indexed_prune_scopes: ' + str(indexed_prune_scopes) + '\n'
    info += 'kept_percentage: ' + str(kept_percentage)
    print(info)
    write_detailed_info(info)

    with tf.Graph().as_default():

        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        ######################
        # Select the dataset #
        ######################
        # dataset = dataset_factory.get_dataset(
        #     FLAGS.dataset_name, FLAGS.train_dataset_name, FLAGS.dataset_dir)
        test_dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                   FLAGS.test_dataset_name,
                                                   FLAGS.dataset_dir)

        # batch_queue = train_inputs(dataset, deploy_config, FLAGS)
        test_images, test_labels = test_inputs(test_dataset, deploy_config,
                                               FLAGS)
        # images, labels = batch_queue.dequeue()

        ######################
        # Select the network #
        ######################

        network_fn_pruned = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            prune_info=prune_info,
            num_classes=(test_dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)
        print('HG: prune_info:')
        pprint(prune_info)

        ####################
        # Define the model #
        ####################
        # logits_train, _ = network_fn_pruned(images, is_training=True, is_local_train=False, reuse_variables=False, scope = net_name_scope_pruned)
        logits_eval, _ = network_fn_pruned(test_images,
                                           is_training=False,
                                           is_local_train=False,
                                           reuse_variables=False,
                                           scope=net_name_scope_pruned)
        correct_prediction = add_correct_prediction(logits_eval, test_labels)

        print("HG: trainable_variables=", len(tf.trainable_variables()))
        print("HG: model_variables=", len(tf.model_variables()))
        print("HG: global_variables=", len(tf.global_variables()))

        sess_config = tf.ConfigProto(intra_op_parallelism_threads=16,
                                     inter_op_parallelism_threads=16)
        with tf.Session(config=sess_config) as sess:
            ###########################
            # Prepare for filewriter. #
            ###########################
            # train_writer = tf.summary.FileWriter(train_dir, sess.graph)

            #########################################
            # Reinit pruned model variables #
            #########################################
            variables_to_reinit = get_model_variables_within_scopes(
                reinit_scopes)
            print_list("Initialize pruned variables", variables_to_reinit)
            assign_ops = []
            for v in variables_to_reinit:
                key = re.sub(net_name_scope_pruned, net_name_scope_checkpoint,
                             v.op.name)
                if key in variables_init_value:
                    value = variables_init_value.get(key)
                    # print(key, value)
                    assign_ops.append(
                        tf.assign(v,
                                  tf.convert_to_tensor(value),
                                  validate_shape=True))
                    # v.set_shape(value.shape)
                else:
                    raise ValueError("Key not in variables_init_value, key=",
                                     key)
            assign_op = tf.group(*assign_ops)
            sess.run(assign_op)

            #################################################
            # Restore unchanged model variables. #
            #################################################
            variables_to_restore = {
                re.sub(net_name_scope_pruned, net_name_scope_checkpoint,
                       v.op.name): v
                for v in get_model_variables_within_scopes()
                if v not in variables_to_reinit
            }
            print_list("restore model variables",
                       variables_to_restore.values())
            load_checkpoint(sess,
                            FLAGS.checkpoint_path,
                            var_list=variables_to_restore)

            ###########################
            # Kicks off the training. #
            ###########################
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            # saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
            print('HG: # of threads=', len(threads))

            eval_time = -1 * time.time()
            test_accuracy, run_metadata = evaluate_accuracy(
                sess,
                coord,
                test_dataset.num_samples,
                test_images,
                test_labels,
                test_images,
                test_labels,
                correct_prediction,
                FLAGS.test_batch_size,
                run_meta=False)
            eval_time += time.time()

            info = ('%s: test_accuracy = %.6f') % (datetime.now(),
                                                   test_accuracy)
            print(info)
            write_detailed_info(info)

            coord.request_stop()
            coord.join(threads)
            total_time = time.time() - tic

            info = "HG: training time(min): %.1f, total time(min): %.1f" % (
                eval_time / 60.0, total_time / 60.0)
            print(info)
            write_detailed_info(info)
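# Both pruning examples above restore checkpoint weights into variables that live under
# a different name scope. The underlying TensorFlow mechanism (presumably what the
# load_checkpoint helper wraps) is a Saver built from a {name_in_checkpoint: variable}
# dict; a hedged sketch with illustrative scope names and a hypothetical checkpoint path:
import re

import tensorflow as tf


def restore_renamed_variables(sess, checkpoint_path, old_scope, new_scope):
    """Restore variables created under new_scope from a checkpoint saved under old_scope."""
    name_map = {
        re.sub('^' + new_scope, old_scope, v.op.name): v
        for v in tf.model_variables()
        if v.op.name.startswith(new_scope)
    }
    tf.train.Saver(var_list=name_map).restore(sess, checkpoint_path)

# Example call (hypothetical path and scope names):
# restore_renamed_variables(sess, '/path/to/model.ckpt',
#                           old_scope='resnet_v2_50', new_scope='resnet_v2_50_pruned')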
def main(unused_argv):
    FLAGS.train_logdir = FLAGS.base_logdir + '/' + FLAGS.task_name
    if FLAGS.restore_name is None:
        FLAGS.restore_logdir = FLAGS.train_logdir
    else:
        FLAGS.restore_logdir = FLAGS.base_logdir + '/' + FLAGS.restore_name

    # Get logging dir ready.
    if not (os.path.isdir(FLAGS.train_logdir)):
        tf.gfile.MakeDirs(FLAGS.train_logdir)
    if len(os.listdir(FLAGS.train_logdir)) != 0:
        if not (FLAGS.if_restore) or (FLAGS.if_restore
                                      and FLAGS.task_name != FLAGS.restore_name
                                      and FLAGS.restore_name is not None):
            if FLAGS.if_debug:
                shutil.rmtree(FLAGS.train_logdir)
                print '==== Log folder %s emptied: ' % FLAGS.train_logdir + 'rm -rf %s/*' % FLAGS.train_logdir
            else:
                if_delete_all = raw_input(
                    '#### The log folder %s exists and non-empty; delete all logs? [y/n] '
                    % FLAGS.train_logdir)
                if if_delete_all == 'y':
                    shutil.rmtree(FLAGS.train_logdir)
                    print '==== Log folder %s emptied: ' % FLAGS.train_logdir + 'rm -rf %s/*' % FLAGS.train_logdir
        else:
            print '==== Log folder exists; not emptying it because we need to restore from it.'
    tf.logging.info('==== Logging in dir:%s; Training on %s set',
                    FLAGS.train_logdir, FLAGS.train_split)

    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.num_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)  # /device:CPU:0

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')
    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    # Get dataset-dependent information.
    dataset = regression_dataset.get_dataset(FLAGS,
                                             FLAGS.dataset,
                                             FLAGS.train_split,
                                             dataset_dir=FLAGS.dataset_dir)
    dataset_val = regression_dataset.get_dataset(FLAGS,
                                                 FLAGS.dataset,
                                                 FLAGS.val_split,
                                                 dataset_dir=FLAGS.dataset_dir)
    print '#### The data has size:', dataset.num_samples, dataset_val.num_samples

    dataset.height = FLAGS.train_crop_size[0]
    dataset.width = FLAGS.train_crop_size[1]
    dataset_val.height = FLAGS.train_crop_size[0]
    dataset_val.width = FLAGS.train_crop_size[1]

    codes = np.load('./deeplab/codes.npy')

    with tf.Graph().as_default() as graph:
        np.set_printoptions(precision=4)
        with tf.device(config.inputs_device()):
            codes_max = np.amax(codes, axis=1).reshape((-1, 1))
            codes_min = np.amin(codes, axis=1).reshape((-1, 1))
            # shape_range = np.hstack((codes_max + (codes_max - codes_min)/(dataset.SHAPE_BINS-1.), codes_min - (codes_max - codes_min)/(dataset.SHAPE_BINS-1.)))
            shape_range = np.hstack((codes_min, codes_max))
            pose_range = dataset.pose_range
            if FLAGS.if_log_depth:
                pose_range[6] = np.log(pose_range[6]).tolist()
            bin_centers_list = [
                np.linspace(r[0], r[1], num=b)
                for r, b in zip(np.vstack((pose_range,
                                           shape_range)), dataset.bin_nums)
            ]
            bin_size_list = [
                (r[1] - r[0]) / (b - 1 if b != 1 else 1)
                for r, b in zip(np.vstack((pose_range,
                                           shape_range)), dataset.bin_nums)
            ]
            bin_bounds_list = [[c_elem - s / 2.
                                for c_elem in c] + [c[-1] + s / 2.]
                               for c, s in zip(bin_centers_list, bin_size_list)
                               ]
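            # Binning summary: each row of vstack((pose_range, shape_range)) is a
            # [min, max] interval per output; bin_centers_list places bin_nums[i]
            # evenly spaced centers in it, bin_size_list is the spacing, and
            # bin_bounds_list shifts each center down by half a bin and appends one
            # upper bound, giving bin_nums[i] + 1 bin edges.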
            assert bin_bounds_list[6][
                0] > 0, 'Need more bins to make the first bound of log depth positive! (Or do we?)'
            for output, pose_range, bin_size, bin_centers, bin_bounds in zip(
                    dataset.output_names, pose_range, bin_size_list,
                    bin_centers_list, bin_bounds_list)[:8]:
                print output + '_poserange_binsize', pose_range, bin_size
                print output + '_bin_centers', bin_centers, len(bin_centers)
                print output + '_bin_bounds', bin_bounds, len(bin_bounds)
            bin_centers_tensors = [
                tf.constant(value=[bin_centers_list[i].tolist()],
                            dtype=tf.float32,
                            shape=[1, dataset.bin_nums[i]],
                            name=name)
                for i, name in enumerate(dataset.output_names)
            ]

            outputs_to_num_classes = {}
            outputs_to_indices = {}
            for output, bin_num, idx in zip(dataset.output_names,
                                            dataset.bin_nums,
                                            range(len(dataset.output_names))):
                if FLAGS.if_discrete_loss:
                    outputs_to_num_classes[output] = bin_num
                else:
                    outputs_to_num_classes[output] = 1
                outputs_to_indices[output] = idx

            model_options = common.ModelOptions(
                outputs_to_num_classes=outputs_to_num_classes,
                crop_size=[dataset.height, dataset.width],
                atrous_rates=FLAGS.atrous_rates,
                output_stride=FLAGS.output_stride)

            samples = input_generator.get(dataset,
                                          model_options,
                                          codes,
                                          clone_batch_size,
                                          dataset_split=FLAGS.train_split,
                                          is_training=True,
                                          model_variant=FLAGS.model_variant)
            inputs_queue = prefetch_queue.prefetch_queue(samples,
                                                         capacity=64 *
                                                         config.num_clones,
                                                         dynamic_pad=True)

            samples_val = input_generator.get(
                dataset_val,
                model_options,
                codes,
                clone_batch_size * 7,
                dataset_split=FLAGS.val_split,
                is_training=False,
                model_variant=FLAGS.model_variant)
            inputs_queue_val = prefetch_queue.prefetch_queue(samples_val,
                                                             capacity=64,
                                                             dynamic_pad=True)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (FLAGS, inputs_queue.dequeue(),
                          outputs_to_num_classes, outputs_to_indices,
                          bin_centers_tensors, bin_centers_list,
                          bin_bounds_list, bin_size_list, dataset, codes,
                          config.inputs_device(), True)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)  # clone_0
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)
            print '+++++++++++', len(
                [v.name for v in tf.trainable_variables()])

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            # optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
            optimizer = tf.train.AdamOptimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            print '-- [TOTAL LOSS]: ', total_loss
            for loss_item in tf.get_collection(tf.GraphKeys.LOSSES,
                                               first_clone_scope):
                print loss_item  # total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss/train', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            last_layers = last_layers + [
                'decoder',
                'decoder_weights',
            ]
            print '////last layers', last_layers

            # Keep trainable variables for last layers ONLY.
            # weight_scopes = [output_name+'_weights' for output_name in dataset.output_names] + ['decoder_weights']
            # grads_and_vars = train_utils.filter_gradients(weight_scopes, grads_and_vars)
            # print '==== variables_to_train: ', [grad_and_var[1].op.name for grad_and_var in grads_and_vars]

            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')
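            # Because of the control dependency above, fetching train_tensor runs the
            # gradient update and every op in update_ops (e.g. the batch-norm
            # moving-average updates) before the loss value is returned.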

        with tf.device('/device:GPU:%d' % (FLAGS.num_clones + 1)):
            if FLAGS.if_val:
                ## Construct the validation graph; takes one GPU.
                image_names, z_logits, outputs_to_weights, seg_one_hots_list, weights_normalized, areas_masked, car_nums, car_nums_list, idx_xys, reg_logits_pose_xy_from_uv, pose_dict_N, prob_logits_pose, rotuvd_dict_N, masks_float, label_uv_flow_map, logits_uv_flow_map = _build_deeplab(
                    FLAGS,
                    inputs_queue_val.dequeue(),
                    outputs_to_num_classes,
                    outputs_to_indices,
                    bin_centers_tensors,
                    bin_centers_list,
                    bin_bounds_list,
                    bin_size_list,
                    dataset_val,
                    codes,
                    config.inputs_device(),
                    is_training=False)
                # pose_dict_N, xyz = _build_deeplab(FLAGS, inputs_queue_val.dequeue(), outputs_to_num_classes, outputs_to_indices, bin_vals, bin_range, dataset_val, codes, is_training=False)
        if FLAGS.num_clones > 1:
            pattern_train = first_clone_scope + '/%s:0'
        else:
            pattern_train = '%s:0'
        pattern_val = 'val-%s:0'
        pattern = pattern_val if FLAGS.if_val else pattern_train

        # Add summaries for images, labels, semantic predictions
        summaries = get_summaries(FLAGS, graph, summaries, dataset, config,
                                  first_clone_scope)

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows ops without a GPU implementation to be placed on the CPU.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        session_config.gpu_options.allow_growth = True

        def train_step_fn(sess, train_op, global_step, train_step_kwargs):
            train_step_fn.step += 1  # or use global_step.eval(session=sess)
            loss = 0

            # calc training losses
            if not (FLAGS.if_pause):
                loss, should_stop = slim.learning.train_step(
                    sess, train_op, global_step, train_step_kwargs)
                print train_step_fn.step, loss
            else:
                sess.run(global_step)

            # # print 'loss: ', loss
            if train_step_fn.step % 20 == 0:
                first_clone_test = graph.get_tensor_by_name(
                    ('%s/%s:0' % (first_clone_scope,
                                  'depth_rescaled_logit_map')).strip('/'))
                first_clone_test2 = graph.get_tensor_by_name(
                    ('%s/%s:0' % (first_clone_scope,
                                  'depth_rescaled_label_map')).strip('/'))
                first_clone_test3 = graph.get_tensor_by_name(
                    ('%s/%s:0' % (first_clone_scope,
                                  'depth_rescaled_cls_error_map')).strip('/'))
                test1, test2, test3 = sess.run(
                    [first_clone_test, first_clone_test2, first_clone_test3])
                # print test
                for test0 in [test1, test2, test3]:
                    test = test0[test0 != 0.]
                    print 'test: ', test.shape, np.max(test), np.min(test), np.mean(test), test.dtype

            # mask_rescaled_float = graph.get_tensor_by_name('%s:0'%'mask_rescaled_float')
            # test_out, test_out2 = sess.run([pose_dict_N, xyz])
            # print test_out
            # print test_out2
            # test_out3 = test_out3[test_out4!=0.]
            # print test_out3
            # print 'outputs_to_weights[z] masked: ', test_out3.shape, np.max(test_out3), np.min(test_out3), np.mean(test_out3), test_out3.dtype
            # print 'mask_rescaled_float: ', test_out4.shape, np.max(test_out4), np.min(test_out4), np.mean(test_out4), test_out4.dtype

            # test_1 = graph.get_tensor_by_name(
            #         ('%s/%s:0' % (first_clone_scope, 'prob_logits_pose')).strip('/'))
            # test_2 = graph.get_tensor_by_name(
            #         ('%s/%s:0' % (first_clone_scope, 'pose_dict_N')).strip('/'))
            # test_out, test_out2 = sess.run([test_1, test_2])
            # print '-- prob_logits_pose: ', test_out.shape, np.max(test_out), np.min(test_out), np.mean(test_out), test_out.dtype
            # print test_out, test_out.shape
            # print '-- pose_dict_N: ', test_out2.shape, np.max(test_out2), np.min(test_out2), np.mean(test_out2), test_out2.dtype
            # print test_out2, test_out2.shape

            should_stop = 0

            if FLAGS.if_val and train_step_fn.step % FLAGS.val_interval_steps == 0:
                # first_clone_test = graph.get_tensor_by_name('val-loss_all:0')
                # test = sess.run(first_clone_test)
                print '-- Validating...' + FLAGS.task_name
                # first_clone_test = graph.get_tensor_by_name(
                #         ('%s/%s:0' % (first_clone_scope, 'z')).strip('/'))
                # # first_clone_test2 = graph.get_tensor_by_name(
                # #         ('%s/%s:0' % (first_clone_scope, 'shape_id_sim_map')).strip('/'))
                # # first_clone_test3 = graph.get_tensor_by_name(
                # #         ('%s/%s:0' % (first_clone_scope, 'not_ignore_mask_in_loss')).strip('/'))

                mask_rescaled_float = graph.get_tensor_by_name(
                    'val-%s:0' % 'mask_rescaled_float')
                trans_sqrt_error = graph.get_tensor_by_name(pattern_val %
                                                            'trans_sqrt_error')
                trans_loss_error = graph.get_tensor_by_name(pattern_val %
                                                            'trans_loss_error')
                trans_diff_metric_abs = graph.get_tensor_by_name(
                    pattern_val % 'trans_diff_metric_abs')
                # label_id_slice = graph.get_tensor_by_name(pattern_val%'label_id_slice')
                # label_log_slice = graph.get_tensor_by_name(pattern_val%'label_log_slice')
                # summary_loss_slice_reg_vector_z = graph.get_tensor_by_name((pattern%'loss_slice_reg_vector_').replace(':0', '')+'z'+':0')
                if FLAGS.if_depth_only:
                    # summary_loss_slice_reg_vector_x = summary_loss_slice_reg_vector_z
                    # summary_loss_slice_reg_vector_y = summary_loss_slice_reg_vector_z
                    label_uv_map = trans_diff_metric_abs
                    logits_uv_map = trans_diff_metric_abs
                else:
                    # summary_loss_slice_reg_vector_x = graph.get_tensor_by_name((pattern%'loss_slice_reg_vector_').replace(':0', '')+'x'+':0')
                    # summary_loss_slice_reg_vector_y = graph.get_tensor_by_name((pattern%'loss_slice_reg_vector_').replace(':0', '')+'y'+':0')
                    label_uv_map = graph.get_tensor_by_name(
                        pattern % 'label_uv_flow_map')
                    logits_uv_map = graph.get_tensor_by_name(
                        pattern % 'logits_uv_flow_map')
                _, test_out, test_out2, test_out3, test_out3_1, test_out3_2, test_out4, test_out5_areas, test_out5, test_out6, test_out7, test_out8, test_out9, test_out10, test_out11, test_out12, test_out13, test_out14, test_out15, trans_sqrt_error, trans_diff_metric_abs, trans_loss_error = sess.run(
                    [
                        summary_op, image_names, z_logits,
                        outputs_to_weights['z_object'],
                        outputs_to_weights['z_log_dense'],
                        outputs_to_weights['z_log_offset'],
                        mask_rescaled_float, areas_masked, weights_normalized,
                        prob_logits_pose, pose_dict_N, car_nums, car_nums_list,
                        idx_xys, rotuvd_dict_N, masks_float,
                        reg_logits_pose_xy_from_uv, label_uv_map,
                        logits_uv_map, trans_sqrt_error, trans_diff_metric_abs,
                        trans_loss_error
                    ])
                # test_out_regx, test_out_regy, test_out_regz, trans_sqrt_error, trans_diff_metric_abs = sess.run([)
                print test_out
                print test_out2.shape
                test_out3 = test_out3[test_out4 != 0.]
                print 'outputs_to_weights[z] masked: ', test_out3.shape, np.max(test_out3), np.min(test_out3), np.mean(test_out3), test_out3.dtype
                test_out3_1 = test_out3_1[test_out4 != 0.]
                print 'outputs_to_weights[z dense] masked: ', test_out3_1.shape, np.max(test_out3_1), np.min(test_out3_1), np.mean(test_out3_1), test_out3_1.dtype
                test_out3_2 = test_out3_2[test_out4 != 0.]
                print 'outputs_to_weights[z offset] masked: ', test_out3_2.shape, np.max(test_out3_2), np.min(test_out3_2), np.mean(test_out3_2), test_out3_2.dtype
                print 'areas: ', test_out5_areas.T, test_out5_areas.shape, np.sum(test_out5_areas)
                print 'masks: ', test_out12.T

                print '-- reg_logits_pose_xy(optionally from_uv): ', test_out13.shape, np.max(test_out13), np.min(test_out13), np.mean(test_out13), test_out13.dtype
                print test_out13, test_out13.shape
                print '-- pose_dict_N: ', test_out7.shape, np.max(test_out7), np.min(test_out7), np.mean(test_out7), test_out7.dtype
                print test_out7, test_out7.shape
                if FLAGS.if_uvflow:
                    print '-- prob_logits_pose: ', test_out6.shape, np.max(test_out6), np.min(test_out6), np.mean(test_out6), test_out6.dtype
                    print test_out6, test_out6.shape
                    print '-- rotuvd_dict_N: ', test_out11.shape, np.max(test_out11), np.min(test_out11), np.mean(test_out11), test_out11.dtype
                    print test_out11, test_out11.shape
                    if not (FLAGS.if_depth_only) and FLAGS.if_uvflow:
                        print '-- label_uv_map: ', test_out14.shape, np.max(test_out14[:, :, :, 0]), np.min(test_out14[:, :, :, 0]), np.max(test_out14[:, :, :, 1]), np.min(test_out14[:, :, :, 1])
                        print '-- logits_uv_map: ', test_out15.shape, np.max(test_out15[:, :, :, 0]), np.min(test_out15[:, :, :, 0]), np.max(test_out15[:, :, :, 1]), np.min(test_out15[:, :, :, 1])
                print '-- car_nums: ', test_out8, test_out9, test_out10.T
                # if not(FLAGS.if_depth_only):
                #     print '-- slice reg x: ', test_out_regx.T, np.max(test_out_regx), test_out_regx.shape
                #     print '-- slice reg y: ', test_out_regy.T, np.max(test_out_regy), test_out_regy.shape
                # print '-- slice reg z: ', test_out_regz.T, np.max(test_out_regz), test_out_regz.shape
                print '-- trans_sqrt_error: ', trans_sqrt_error.T, np.max(trans_sqrt_error), trans_sqrt_error.shape
                print '-- trans_diff_metric_abs: ', np.hstack((trans_diff_metric_abs, test_out5_areas, trans_sqrt_error)), np.max(trans_diff_metric_abs, axis=0), trans_diff_metric_abs.shape

                # print '++++ label_id_slice', label_id_slice.T, np.min(label_id_slice), np.max(label_id_slice)
                # print '++++ label_log_slice', label_log_slice.T, np.min(label_log_slice), np.max(label_log_slice)
                # print '++++ label_explog_slice', np.exp(label_log_slice).T, np.min(np.exp(label_log_slice)), np.max(np.exp(label_log_slice))
                # if np.min(label_id_slice)<0 or np.max(label_id_slice)>=64:
                # programPause = raw_input("!!!!!!!!!!!!!!!!!np.min(label_id_slice)<0 or np.max(label_id_slice)>=64")

                if FLAGS.if_pause:
                    programPause = raw_input(
                        "Press the <ENTER> key to continue...")

                # # Vlen(test_out), test_out[0].shape
                # # print test_out2.shape, test_out2
                # # print test_out3
                # # print test_out, test_out.shape
                # # # test_out = test[:, :, :, 3]
                # # test_out = test_out[test_out3]
                # # # test_out2 = test2[:, :, :, 3]
                # # test_out2 = test_out2[test_out3]
                # # # print test_out
                # # print 'shape_id_map: ', test_out.shape, np.max(test_out), np.min(test_out), np.mean(test_out), np.median(test_out), test_out.dtype
                # # print 'shape_id_sim_map: ', test_out2.shape, np.max(test_out2), np.min(test_out2), np.mean(test_out2), np.median(test_out2), test_out2.dtype
                # # print 'masks sum: ', test_out3.dtype, np.sum(test_out3.astype(float))
                # # assert np.max(test_out) == np.max(test_out2), 'MAtch1!!!'
                # # assert np.min(test_out) == np.min(test_out2), 'MAtch2!!!'

            return [loss, should_stop]

        train_step_fn.step = 0

        # trainables = [v.name for v in tf.trainable_variables()]
        # alls =[v.name for v in tf.all_variables()]
        # print '----- Trainables %d: '%len(trainables), trainables
        # print '----- All %d: '%len(alls), alls
        # print '===== ', len(list(set(trainables) - set(alls)))
        # print '===== ', len(list(set(alls) - set(trainables))), list(set(alls) - set(trainables))
        # print summaries
        print '+++++++++++', len([v.name for v in tf.trainable_variables()])

        if FLAGS.if_print_tensors:
            for op in tf.get_default_graph().get_operations():
                print str(op.name)

        init_assign_op, init_feed_dict = train_utils.model_init(
            FLAGS.restore_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.if_restore,
            FLAGS.initialize_last_layer,
            last_layers,
            # ignore_including=['_weights/BatchNorm', 'decoder_weights'],
            ignore_including=None,
            ignore_missing_vars=True)

        def InitAssignFn(sess):
            sess.run(init_assign_op, init_feed_dict)

        # Start the training.
        slim.learning.train(
            train_tensor,
            train_step_fn=train_step_fn,
            logdir=FLAGS.train_logdir,
            log_every_n_steps=FLAGS.log_steps,
            master=FLAGS.master,
            number_of_steps=FLAGS.training_number_of_steps,
            is_chief=(FLAGS.task == 0),
            session_config=session_config,
            startup_delay_steps=startup_delay_steps,
            init_fn=InitAssignFn if init_assign_op is not None else None,
            # init_fn=train_utils.get_model_init_fn(
            #     FLAGS.restore_logdir,
            #     FLAGS.tf_initial_checkpoint,
            #     FLAGS.if_restore,
            #     FLAGS.initialize_last_layer,
            #     last_layers,
            #     ignore_missing_vars=True),
            summary_op=summary_op,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs
            if not (FLAGS.if_debug) else 300)
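
# The example above customizes slim's training loop through the `train_step_fn`
# hook: the function wraps `slim.learning.train_step`, counts steps, and pulls
# extra tensors out of the session every few iterations. A minimal sketch of that
# pattern (assuming TF 1.x with tf.contrib.slim; `extra_fetches` and
# `some_debug_tensor` are placeholder names, not taken from the example):
import tensorflow as tf

slim = tf.contrib.slim

def make_train_step_fn(extra_fetches, log_every=20):
    """Builds a train_step_fn that logs extra tensors every `log_every` steps."""
    def train_step_fn(sess, train_op, global_step, train_step_kwargs):
        # Run slim's default training step (gradient update + loss fetch).
        loss, should_stop = slim.learning.train_step(
            sess, train_op, global_step, train_step_kwargs)
        train_step_fn.step += 1
        if train_step_fn.step % log_every == 0:
            # Evaluate additional tensors for debugging/monitoring.
            print('step %d, extras: %s' % (train_step_fn.step,
                                           sess.run(extra_fetches)))
        return loss, should_stop
    train_step_fn.step = 0
    return train_step_fn

# Usage sketch:
# slim.learning.train(train_tensor, logdir,
#                     train_step_fn=make_train_step_fn([some_debug_tensor]))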
Example #16
0
def train(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)
    clone_on_cpu = args.gpu_id == ''
    num_clones = len(args.gpu_id.split(','))

    if args.log_root:
        if args.config is None:
            raise RuntimeError('No config json specified.')
        tf.logging.info('using config from {}'.format(args.config))
        with open(args.config, 'rt') as F:
            configs = json.load(F)
        hparams = Namespace(**configs)
        logdir_name = config_str.get_config_time_str(hparams, 'wavenet',
                                                     EXP_TAG)
        logdir = os.path.join(args.log_root, logdir_name)
        os.makedirs(logdir, exist_ok=True)
        shutil.copy(args.config, logdir)
    else:
        logdir = args.logdir
        config_json = glob.glob(os.path.join(logdir, '*.json'))[0]
        tf.logging.info('using config from {}'.format(config_json))
        with open(config_json, 'rt') as F:
            configs = json.load(F)
        hparams = Namespace(**configs)
    tf.logging.info('Saving to {}'.format(logdir))

    wn = wavenet.Wavenet(hparams,
                         os.path.abspath(os.path.expanduser(args.train_path)))

    def _data_dep_init():
        # slim.learning.train runs init_fn before it starts the queue runners,
        # so this function would deadlock if it used the queue-fed `input_dict` (L76) as input;
        # instead it reads an init batch directly and feeds it through placeholders.
        inputs_val = reader.get_init_batch(wn.train_path,
                                           batch_size=args.total_batch_size,
                                           seq_len=wn.wave_length)
        wave_data = inputs_val['wav']
        mel_data = inputs_val['mel']

        _inputs_dict = {
            'wav': tf.placeholder(dtype=tf.float32, shape=wave_data.shape),
            'mel': tf.placeholder(dtype=tf.float32, shape=mel_data.shape)
        }

        encode_dict = wn.encode_signal(_inputs_dict)
        _inputs_dict.update(encode_dict)
        init_ff_dict = wn.feed_forward(_inputs_dict, init=True)

        def callback(session):
            tf.logging.info('Calculate initial statistics.')
            init_out = session.run(init_ff_dict,
                                   feed_dict={
                                       _inputs_dict['wav']: wave_data,
                                       _inputs_dict['mel']: mel_data
                                   })
            init_out_params = init_out['out_params']
            if wn.loss_type == 'mol':
                _, mean, log_scale = np.split(init_out_params, 3, axis=2)
                scale = np.exp(np.maximum(log_scale, -7.0))
                _init_logging(mean, 'mean')
                _init_logging(scale, 'scale')
            elif wn.loss_type == 'gauss':
                mean, log_std = np.split(init_out_params, 2, axis=2)
                std = np.exp(np.maximum(log_std, -7.0))
                _init_logging(mean, 'mean')
                _init_logging(std, 'std')
            tf.logging.info('Done Calculate initial statistics.')

        return callback
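
    # The callback returned here is passed to slim.learning.train as init_fn
    # (data_dep_init_fn below), so the data-dependent initialization runs once in
    # the session before training begins.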

    def _model_fn(_inputs_dict):
        encode_dict = wn.encode_signal(_inputs_dict)
        _inputs_dict.update(encode_dict)
        ff_dict = wn.feed_forward(_inputs_dict)
        ff_dict.update(encode_dict)
        loss_dict = wn.calculate_loss(ff_dict)
        loss = loss_dict['loss']
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        assert total_batch_size % num_clones == 0
        clone_batch_size = int(total_batch_size / num_clones)

        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            num_ps_tasks=0,
            worker_job_name='localhost',
            ps_job_name='localhost')

        with tf.device(deploy_config.inputs_device()):
            inputs_dict = wn.get_batch(clone_batch_size)

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, _model_fn,
                                            [inputs_dict])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        summaries.update(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        with tf.device(deploy_config.variables_device()):
            global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)

        with tf.device(deploy_config.optimizer_device()):
            lr = tf.constant(wn.learning_rate_schedule[0])
            for key, value in wn.learning_rate_schedule.items():
                lr = tf.cond(tf.less(global_step, key), lambda: lr,
                             lambda: tf.constant(value))
            summaries.add(tf.summary.scalar("learning_rate", lr))

            optimizer = tf.train.AdamOptimizer(lr, epsilon=1e-8)
            ema = tf.train.ExponentialMovingAverage(decay=0.9999,
                                                    num_updates=global_step)

            loss, clone_grads_vars = model_deploy.optimize_clones(
                clones, optimizer, var_list=tf.trainable_variables())
            update_ops.append(
                optimizer.apply_gradients(clone_grads_vars,
                                          global_step=global_step))
            update_ops.append(ema.apply(tf.trainable_variables()))
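            # ema.apply() creates shadow copies of the trainable variables and
            # returns an op that updates them; adding it to update_ops keeps the
            # moving averages in step with every gradient update.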

            summaries.add(tf.summary.scalar("train_loss", loss))

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(loss, name='train_op')

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True
        summary_op = tf.summary.merge(list(summaries), name='summary_op')
        data_dep_init_fn = _data_dep_init()

        slim.learning.train(train_tensor,
                            logdir=logdir,
                            number_of_steps=wn.num_iters,
                            summary_op=summary_op,
                            global_step=global_step,
                            log_every_n_steps=100,
                            save_summaries_secs=600,
                            save_interval_secs=3600,
                            session_config=session_config,
                            init_fn=data_dep_init_fn)
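
# The learning-rate logic above chains tf.cond over `wn.learning_rate_schedule`
# (a dict of step -> rate), so the rate drops at fixed global steps. A roughly
# equivalent sketch using tf.train.piecewise_constant (TF 1.x; the boundaries and
# values below are made-up numbers, not taken from the example):
import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
boundaries = [100000, 200000]    # hypothetical step boundaries
values = [1e-3, 1e-4, 1e-5]      # hypothetical rates, one more entry than boundaries
lr = tf.train.piecewise_constant(global_step, boundaries, values)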
Example #17
0
def main(unused_argv):
    # syaru: Sets the threshold for which messages will be logged; this line is needed so the training log is actually printed.
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    # syaru: models/research/slim/deployment/model_deploy.DeploymentConfig(object)
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = int(FLAGS.train_batch_size / config.num_clones)

    # Get dataset-dependent information.
    """
  syaru: deeplab/datasets/segmentation_dataset.get_dataset()
  Gets an instance of slim Dataset.
  Args:
    dataset_name: Dataset name.
    split_name: A train/val Split name.
    dataset_dir: The directory of the dataset sources.
  """
    dataset = segmentation_dataset.get_dataset(FLAGS.dataset,
                                               FLAGS.train_split,
                                               dataset_dir=FLAGS.dataset_dir)

    tf.gfile.MakeDirs(
        FLAGS.train_logdir
    )  # syaru: FLAGS.train_logdir = "pascal_voc_seg/exp/train_on_trainval_set/train"
    tf.logging.info('Training on %s set',
                    FLAGS.train_split)  #        FLAGS.train_split = "trainval"

    with tf.Graph().as_default() as graph:
        with tf.device(
                config.inputs_device()
        ):  # syaru: deeplab/utils/input_generator.get(): This functions gets the dataset split for semantic segmentation.
            samples = input_generator.get(  # Returns: A dictionary of batched Tensors for semantic segmentation.
                dataset,  # Args: dataset: An instance of slim Dataset.
                FLAGS.train_crop_size,  # train_crop_size: if crop_size is set, images larger than it are randomly cropped during training.
                clone_batch_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,  # syaru: min_scale_factor: 'Minimum scale factor for data augmentation.'
                max_scale_factor=FLAGS.max_scale_factor,  # max_scale_factor: 'Maximum scale factor for data augmentation.'
                scale_factor_step_size=FLAGS.scale_factor_step_size,  # scale_factor_step_size: 'Scale factor step size for data augmentation.' (from minimum to maximum)
                dataset_split=FLAGS.train_split,
                is_training=True,
                model_variant=FLAGS.model_variant)
            # syaru: /tensorflow/contrib/slim/python/slim/data/prefetch_queue.py
            inputs_queue = prefetch_queue.prefetch_queue(
                samples,  # tensors: A list or dictionary of `Tensors` to enqueue in the buffer.
                capacity=128 * config.num_clones)  # capacity: An integer. The maximum number of elements in the queue.

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            """
      syaru: 
      models/research/slim/deployment/model_deploy.create_clones():
      The `model_fn(*args, **kwargs)` function is called `config.num_clones` times to create the model clones.
      (and one or several clones are deployed on different GPUs and one or several replicas of such clones.)
      Then it return the scope and device in a namedtuple `Clone(outputs, scope, device)`.

      Args:
      config: A DeploymentConfig object.
      model_fn: A callable. Called as `model_fn(*args, **kwargs)`
      args: Optional list of arguments to pass to `model_fn`.
      kwargs: Optional list of keyword arguments to pass to `model_fn`..
      Returns:
      A list of namedtuples `Clone`.

      Note: it is assumed that any loss created by `model_fn` is collected at
      the tf.GraphKeys.LOSSES collection.

      To recover the losses, summaries or update_ops created by the clone use:
      ```python
      losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope)
      summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, clone.scope)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone.scope)
      ```
      """
            model_fn = _build_deeplab
            model_args = (inputs_queue, {
                common.OUTPUT_TYPE: dataset.num_classes
            }, dataset.ignore_label)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, semantic predictions
        if FLAGS.save_summaries_images:
            summary_image = graph.get_tensor_by_name(  # syaru: get_tensor_by_name(name): returns the tensor with the given name.
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/')
            )  # str.strip(): removes the specified characters at the start/end of the string (default is whitespace).
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            summary_label = tf.cast(
                graph.get_tensor_by_name(
                    ('%s/%s:0' %
                     (first_clone_scope, common.LABEL)).strip('/')), tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            predictions = tf.cast(
                tf.expand_dims(
                    tf.argmax(
                        graph.get_tensor_by_name(  # syaru: tf.argmax(axis=3)
                            ('%s/%s:0' % (first_clone_scope,
                                          common.OUTPUT_TYPE)).strip('/')),
                        3),
                    -1),
                tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                 predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            # syaru: train_utils.get_model_learning_rate():
            #        Computes the model's learning rate for different learning policy("step" and "poly").
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        with tf.device(config.variables_device()):
            # syaru: Compute clone losses and gradients for the given list of `Clones`.
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            """
      syaru: 
      For the task of semantic segmentation, the models are
      usually fine-tuned from the models trained on the task of image
      classification. To fine-tune the models, we usually set larger (e.g.,
      10 times larger) learning rate for the parameters of last layer.
      
      deeplab/model/model.get_extra_layer_scopes():
      Returns: A list of scopes for extra layers. 

      deeplab/utils/train_utils.get_model_gradient_multipliers():
      Returns: The gradient multiplier map with variables as key, and multipliers as value.
      """
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)
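            # grad_mult maps bias and last-layer variables to scalar multipliers
            # (e.g. FLAGS.last_layer_gradient_multiplier for the logits layer);
            # slim.learning.multiply_gradients scales only the matching gradients
            # and leaves the rest unchanged.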

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            # syaru: both tf.identity() and tf.group() turn statements into ops.
            #        (We need `total_loss` (as 'train_op') to be computed only after `optimizer.apply_gradients`,
            #        and tf.control_dependencies() works on tf ops.)
            #        And `update_ops = tf.get_collection(..)` only returns a plain list of ops.
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows ops without a GPU implementation to be placed on the CPU.
        # syaru: set gpu_options
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False,
                                        gpu_options=gpu_options)

        # Start the training.
        # syaru: /tensorflow/contrib/slim/python/slim/learning.py
        # train_utils.get_model_init_fn(): Gets the function initializing model variables from a checkpoint.
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_logdir,
            log_every_n_steps=FLAGS.log_steps,
            master=FLAGS.master,
            number_of_steps=FLAGS.training_number_of_steps,
            is_chief=(FLAGS.task == 0),
            session_config=session_config,
            startup_delay_steps=startup_delay_steps,
            # syaru: `init_fn`: An optional callable to be executed after `init_op` is called.
            #        The callable must accept one argument, the session being initialized.
            init_fn=train_utils.get_model_init_fn(
                FLAGS.train_logdir,
                FLAGS.tf_initial_checkpoint,
                FLAGS.initialize_last_layer,
                last_layers,
                ignore_missing_vars=True),
            summary_op=summary_op,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs)
def train(datasets_dicts,
          epochs,
          val_every,
          iters_cnt,
          validate_with_eval_model,
          pipeline_config,
          num_clones=1,
          save_cback=None):
    logger.info('Start train')
    configs = configs_from_pipeline(pipeline_config)

    model_config = configs['model']
    train_config = configs['train_config']

    create_model_fn = functools.partial(
        model_builder.build,
        model_config=model_config,
        is_training=True)
    detection_model = create_model_fn()

    def get_next(dataset):
        return dataset_util.make_initializable_iterator(
            build_dataset(dataset)).get_next()

    create_tensor_dict_fn = functools.partial(get_next, datasets_dicts['train'])
    create_tensor_dict_fn_val = functools.partial(get_next, datasets_dicts['val'])

    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=4,
            clone_on_cpu=False,
            replica_id=0,
            num_replicas=1,
            num_ps_tasks=0,
            worker_job_name='lonely_worker')

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            coord = coordinator.Coordinator()
            input_queue = create_input_queue(
                train_config.batch_size, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity, data_augmentation_options)

            input_queue_val = create_input_queue(
                train_config.batch_size, create_tensor_dict_fn_val,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity, data_augmentation_options)

        # create validation graph
        create_model_fn_val = functools.partial(
            model_builder.build,
            model_config=model_config,
            is_training=not validate_with_eval_model)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)
            for var in optimizer_summary_vars:
                tf.summary.scalar(var.op.name, var, family='LearningRate')

        train_losses = []
        grads_and_vars = []
        with slim.arg_scope([slim.model_variable, slim.variable], device='/device:CPU:0'):
            for curr_dev_id in range(num_clones):
                with tf.device('/gpu:{}'.format(curr_dev_id)):
                    with tf.name_scope('clone_{}'.format(curr_dev_id)) as scope:
                        with tf.variable_scope(tf.get_variable_scope(),
                                               reuse=True if curr_dev_id > 0 else None):
                            losses = _create_losses_val(input_queue, create_model_fn, train_config)
                            clones_loss = tf.add_n(losses)
                            clones_loss = tf.divide(clones_loss, 1.0 * num_clones)
                            grads = training_optimizer.compute_gradients(clones_loss)
                            train_losses.append(clones_loss)
                            grads_and_vars.append(grads)
                            if curr_dev_id == 0:
                                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        val_total_loss = get_val_loss(num_clones, input_queue_val, create_model_fn_val, train_config)

        with tf.device(deploy_config.optimizer_device()):
            total_loss = tf.add_n(train_losses)
            grads_and_vars = model_deploy._sum_clones_gradients(grads_and_vars)
            total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(grads_and_vars,
                                                              global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops, name='update_barrier')
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        coord.clear_stop()
        sess = tf.Session(config=config)
        saver = tf.train.Saver()

        graph = ops.get_default_graph()
        with graph.as_default():
            with ops.name_scope('init_ops'):
                init_op = variables.global_variables_initializer()
                ready_op = variables.report_uninitialized_variables()
                local_init_op = control_flow_ops.group(
                        variables.local_variables_initializer(),
                        lookup_ops.tables_initializer())

        # graph.finalize()
        sess.run([init_op, ready_op, local_init_op])

        queue_runners = graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
        threads = []
        for qr in queue_runners:
            threads.extend(qr.create_threads(sess, coord=coord, daemon=True, start=True))

        logger.info('Start restore')
        if train_config.fine_tune_checkpoint:
            var_map = detection_model.restore_map(
                            fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
                            load_all_detection_checkpoint_vars=(
                                train_config.load_all_detection_checkpoint_vars))
            available_var_map = (variables_helper.
                                    get_variables_available_in_checkpoint(
                                    var_map, train_config.fine_tune_checkpoint))
            if 'global_step' in available_var_map:
                del available_var_map['global_step']
            init_saver = tf.train.Saver(available_var_map)
            logger.info('Restoring model weights from previous checkpoint.')
            init_saver.restore(sess, train_config.fine_tune_checkpoint)
            logger.info('Model restored.')

        eval_planner = EvalPlanner(epochs, val_every)
        progress = sly.progress_counter_train(epochs, iters_cnt['train'])
        best_val_loss = float('inf')
        epoch_flt = 0

        for epoch in range(epochs):
            logger.info("Before new epoch", extra={'epoch': epoch_flt})
            for train_it in range(iters_cnt['train']):
                total_loss, np_global_step = sess.run([train_tensor, global_step])

                metrics_values_train = {
                    'loss': total_loss,
                }

                progress.iter_done_report()
                epoch_flt = epoch_float(epoch, train_it + 1, iters_cnt['train'])
                sly.report_metrics_training(epoch_flt, metrics_values_train)

                if eval_planner.need_validation(epoch_flt):
                    logger.info("Before validation", extra={'epoch': epoch_flt})

                    overall_val_loss = 0
                    for val_it in range(iters_cnt['val']):
                        overall_val_loss += sess.run(val_total_loss)

                        logger.info("Validation in progress", extra={'epoch': epoch_flt,
                                                                     'val_iter': val_it,
                                                                     'val_iters': iters_cnt['val']})

                    metrics_values_val = {
                        'loss': overall_val_loss / iters_cnt['val'],
                    }
                    sly.report_metrics_validation(epoch_flt, metrics_values_val)
                    logger.info("Validation has been finished", extra={'epoch': epoch_flt})

                    eval_planner.validation_performed()

                    val_loss = metrics_values_val['loss']
                    model_is_best = val_loss < best_val_loss
                    if model_is_best:
                        best_val_loss = val_loss
                        logger.info('It\'s been determined that current model is the best one for a while.')

                    save_cback(saver,
                               sess,
                               model_is_best,
                               opt_data={
                                         'epoch': epoch_flt,
                                         'val_metrics': metrics_values_val,
                               })

            logger.info("Epoch was finished", extra={'epoch': epoch_flt})
        coord.request_stop()
        coord.join(threads)
def main(_):
    tf.logging.set_verbosity(tf.logging.DEBUG)  # set the threshold for which log messages are shown

    with tf.Graph().as_default():
        # Config model_deploy. Keep TF Slim Models structure.
        # Useful if want to need multiple GPUs and/or servers in the future.
        deploy_config = model_deploy.DeploymentConfig()
        # Create global_step.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()
        file_name = os.path.join("../tfrecord", "train.tfrecords")

        def read_and_decode(filename,
                            num_epochs):  # read iris_contact.tfrecords
            filename_queue = tf.train.string_input_producer(
                [filename], num_epochs=num_epochs)
            reader = tf.TFRecordReader()
            print(filename_queue)
            _, serialized_example = reader.read(
                filename_queue)  # return file_name and file
            features = tf.parse_single_example(
                serialized_example,
                features={
                    'image/encoded':
                    tf.FixedLenFeature((), tf.string, default_value=''),
                    # three arguments: shape, type, default_value
                    'image/format':
                    tf.FixedLenFeature((), tf.string, default_value='jpeg'),
                    'image/shape':
                    tf.FixedLenFeature([2], tf.int64),
                    'label':
                    tf.FixedLenFeature((), tf.string, default_value='unknow'),
                    'index':
                    tf.FixedLenFeature([1], tf.int64)
                })  # return image and label

            # img = tf.decode_raw(features['image/encoded'], tf.uint8)
            img = tf.image.decode_jpeg(features['image/encoded'])
            shape = features["image/shape"]
            img = tf.reshape(img, [32, 100, 3])  # reshape image to 32*100*3
            img = tf.cast(img, tf.float32) * (1. / 255) - 0.5  # scale the image tensor to [-0.5, 0.5]
            label = features['label']  # label tensor
            index = features["index"]
            return img, label, shape, index

        def preprocess(image_raw):
            image = tf.image.decode_jpeg(tf.image.encode_jpeg(image_raw))
            return resize_image(image, (100, 32))

        def inputs(batch_size, num_epochs, filename):
            """Reads input data num_epochs times.
            Args:
              train: Selects between the training (True) and validation (False) data.
              batch_size: Number of examples per returned batch.
              num_epochs: Number of times to read the input data, or 0/None to
                 train forever.
            Returns:
              A tuple (images, labels), where:
              * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
                in the range [-0.5, 0.5].
              * labels is an int32 tensor with shape [batch_size] with the true label,
                a number in the range [0, mnist.NUM_CLASSES).
              Note that a tf.train.QueueRunner is added to the graph, which
              must be run using e.g. tf.train.start_queue_runners().
            """
            if not num_epochs: num_epochs = None
            #filename = os.path.join(file_dir)

            with tf.name_scope('input'):
                # Even when reading in multiple threads, share the filename
                # queue.
                image, label, shape, index = read_and_decode(
                    filename, num_epochs)

                #image = preprocess(image)
                # Shuffle the examples and collect them into batch_size batches.
                # (Internally uses a RandomShuffleQueue.)
                # We run this in two threads to avoid being a bottleneck.
                images, shuffle_labels, sshape, sindex = tf.train.shuffle_batch(
                    [image, label, shape, index],
                    batch_size=batch_size,
                    num_threads=2,
                    capacity=1000 + 3 * batch_size,
                    # Ensures a minimum amount of shuffling of examples.
                    min_after_dequeue=100)
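                # capacity must exceed min_after_dequeue; the difference
                # (900 + 3 * batch_size here) bounds how many extra elements the
                # queue may prefetch beyond what is kept around for shuffling.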

                return images, shuffle_labels, sshape, sindex

        with tf.Graph().as_default():
            # Input images and labels.
            starter_learning_rate = 0.1
            learning_rate = tf.train.exponential_decay(starter_learning_rate,
                                                       global_step,
                                                       100000,
                                                       0.96,
                                                       staircase=True)
            images, shuffle_labels, sshape, sindex = inputs(
                filename=file_name,
                batch_size=batch_size,
                num_epochs=num_epochs)

            crnn = model.CRNNNet()
            logits, inputs, seq_len, W, b = crnn.net(images)

            shuffle_labels = ['123456', '123', '12342']
            labels = shuffle_labels

            def sparse_tuple_from(sequences, dtype=np.int32):
                """Create a sparse representention of x.
                Args:
                    sequences: a list of lists of type dtype where each element is a sequence
                Returns:
                    A tuple with (indices, values, shape)
                """
                indices = []
                values = []

                for n, seq in enumerate(sequences):
                    indices.extend(zip([n] * len(seq), range(len(seq))))
                    values.extend(seq)

                indices = np.asarray(indices, dtype=np.int64)
                values = np.asarray(values, dtype=dtype)
                shape = np.asarray(
                    [len(sequences),
                     np.asarray(indices).max(0)[1] + 1],
                    dtype=np.int64)

                return indices, values, shape
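            # For example, sparse_tuple_from([[1, 2], [3]]) returns
            # indices=[[0, 0], [0, 1], [1, 0]], values=[1, 2, 3], shape=[2, 2].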

            sparse_labels = sparse_tuple_from(labels)

            cost = crnn.losses(sparse_labels, logits, seq_len)
            optimizer = tf.train.AdadeltaOptimizer(
                learning_rate=learning_rate).minimize(loss=cost,
                                                      global_step=global_step)

            # Option 2: tf.contrib.ctc.ctc_beam_search_decoder
            # (it's slower but you'll get better results)
            decoded, log_prob = tf.nn.ctc_beam_search_decoder(
                logits, seq_len, merge_repeated=False)

            # Accuracy: label error rate
            acc = tf.reduce_mean(
                tf.edit_distance(tf.cast(decoded[0], tf.int32), sparse_labels))
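            # tf.edit_distance normalizes by the true label length by default, so
            # `acc` is really the mean label error rate (lower is better).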

            sess = tf.Session()
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)  # Start input enqueue threads.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            #sess = tf_debug.LocalCLIDebugWrapperSession(sess)
            try:
                step = 0
                while not coord.should_stop():
                    start_time = time.time()

                    # Run one step of the model.  The return values are
                    # the activations from the `train_op` (which is
                    # discarded) and the `loss` op.  To inspect the values
                    # of your ops or variables, you may include them in
                    # the list passed to sess.run() and the value tensors
                    # will be returned in the tuple from the call.
                    #timages, tsparse_labels, tsshape, tsindex = sess.run([images, sparse_labels, sshape, sindex])

                    val_cost, val_ler, lr, step = sess.run(
                        [cost, acc, learning_rate, global_step])

                    duration = time.time() - start_time

                    print(val_cost)

                    # Print an overview fairly often.
                    if step % 10 == 0:
                        print('Step %d:  (%.3f sec)' % (step, duration))
                    step += 1
            except tf.errors.OutOfRangeError:
                print('Done training for %d epochs, %d steps.' %
                      (num_epochs, step))
            finally:
                # When done, ask the threads to stop.
                coord.request_stop()

                # Wait for threads to finish.
            coord.join(threads)
            sess.close()
Example #20
0
def train():
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.create_global_step()

        ######################
        # Select the network #
        ######################
        network_fn = {}
        model_names = [net.strip() for net in FLAGS.model_name.split(',')]
        for i in range(FLAGS.num_networks):
            network_fn["{0}".format(i)] = nets_factory.get_network_fn(
                model_names[i],
                num_classes=FLAGS.num_classes,
                weight_decay=FLAGS.weight_decay)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name  # or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        test_image_batch, test_label_batch = utils.get_image_label_batch(
            FLAGS, shuffle=False, name='test')

        test_label_batch = slim.one_hot_encoding(test_label_batch,
                                                 FLAGS.num_classes)

        precision, test_precision, test_predictions, net_var_list, net_grads, net_update_ops = {}, {}, {}, {}, {}, {}
        semi_net_grads = {}

        with tf.name_scope('tower') as scope:
            with tf.variable_scope(tf.get_variable_scope()):
                test_net_loss, _, test_net_pred, test_attention0, test_attention1, test_second_input = _tower_loss(
                    network_fn,
                    test_image_batch,
                    test_label_batch,
                    is_cross=True,
                    reuse=False,
                    is_training=False)

                test_truth = tf.argmax(test_label_batch, axis=1)

                # Reuse variables for the next tower.
                #tf.get_variable_scope().reuse_variables()

                # Retain the summaries from the final tower.
                summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
                var_list = tf.trainable_variables()

                for i in range(FLAGS.num_networks):
                    test_predictions["{0}".format(i)] = tf.argmax(
                        test_net_pred["{0}".format(i)], axis=1)
                    test_precision["{0}".format(i)] = tf.reduce_mean(
                        tf.to_float(
                            tf.equal(test_predictions["{0}".format(i)],
                                     test_truth)))
                    #test_predictions["{0}".format(i)] = test_net_pred["{0}".format(i)]
                net_pred = (test_net_pred["{0}".format(0)] +
                            test_net_pred["{0}".format(1)]) / 2.0
                net_pred = tf.argmax(net_pred, axis=1)
                precision_mean = tf.reduce_mean(
                    tf.to_float(tf.equal(net_pred, test_truth)))

                # Add a summary to track the training precision.
                #summaries.append(tf.summary.scalar('precision_%d' % i, precision["{0}".format(i)]))
                #summaries.append(tf.summary.scalar('test_precision_%d' % i, test_precision["{0}".format(i)]))

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        # Build the summary operation from the last tower summaries.
        #summary_op = tf.summary.merge(summaries)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess = tf.Session(config=tf.ConfigProto(
            gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.5),
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init)

        load_fn = slim.assign_from_checkpoint_fn(os.path.join(
            FLAGS.checkpoint_dir, 'model.ckpt-0'),
                                                 tf.global_variables(),
                                                 ignore_missing_vars=True)
        #load_fn = slim.assign_from_checkpoint_fn('./WCE_densenet4/checkpoint/model.ckpt-20',tf.global_variables(),ignore_missing_vars=True)
        load_fn(sess)

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        #summary_writer = tf.summary.FileWriter(
        #    os.path.join(FLAGS.log_dir),
        #    graph=sess.graph)

        net_loss_value, test_precision_value, test_predictions_value, precision_value = {}, {}, {}, {}

        parameters = utils.count_trainable_params()
        print("Total training params: %.1fM \r\n" % (parameters / 1e6))

        start_time = time.time()
        counter = 0
        infile = open(os.path.join(FLAGS.attention_map, 'log.txt'), 'w')
        batch_count = np.int32(FLAGS.dataset_size / FLAGS.batch_size)
        precision_value["{0}".format(0)] = []
        precision_value["{0}".format(1)] = []
        precision_value["{0}".format(2)] = []
        feature0 = []
        feature1 = []
        for batch_idx in range(batch_count):
            #for i in range(FLAGS.num_networks):
            test_predictions_value["{0}".format(0)], test_predictions_value[
                "{0}".format(1)], truth, predictions, test_precision_value[
                    "{0}".format(0)], test_precision_value["{0}".format(
                        1
                    )], prec, test, test_att0, test_att1, test_sec = sess.run([
                        test_predictions["{0}".format(0)],
                        test_predictions["{0}".format(1)], test_truth,
                        net_pred, test_precision["{0}".format(0)],
                        test_precision["{0}".format(1)], precision_mean,
                        test_image_batch, test_attention0, test_attention1,
                        test_second_input
                    ])

            precision_value["{0}".format(0)].append(
                test_precision_value["{0}".format(0)])
            precision_value["{0}".format(1)].append(
                test_precision_value["{0}".format(1)])
            precision_value["{0}".format(2)].append(prec)

            #predictions = test_predictions_value["{0}".format(1)]
            #infile.write(str(np.around(predictions[:,0], decimals=3))+' '+str(np.around(predictions[:,1], decimals=3))+'\n')
            #infile.write(str(np.around(predictions[:,2], decimals=3))+' '+str(truth)+'\n')
            infile.write(
                str(test_predictions_value["{0}".format(0)]) + ' ' +
                str(test_predictions_value["{0}".format(1)]) + '\n')
            infile.write(str(predictions) + ' ' + str(truth) + '\n')
            infile.write(
                str(np.float32(test_precision_value["{0}".format(0)])) + ' ' +
                str(np.float32(test_precision_value["{0}".format(1)])) + ' ' +
                str(np.float32(prec)) + '\n')
            format_str = 'batch_idx: [%3d] [%3d/%3d] time: %4.4f, net0_test_acc = %.4f,      net1_test_acc = %.4f,      net_test_acc = %.4f'
            print(format_str %
                  (batch_idx, batch_idx, batch_count, time.time() - start_time,
                   np.float32(test_precision_value["{0}".format(0)]),
                   np.float32(test_precision_value["{0}".format(1)]),
                   np.float32(prec)))

            #train, att0, att1, sec, semi, semi_att0, semi_att1, semi_sec, test, test_att0, test_att1, test_sec  = sess.run([train_image_batch, attention0, attention1, second_input, semi_image_batch, semi_attention0, semi_attention1, semi_second_input, test_image_batch, test_attention0, test_attention1, test_second_input])
            #train, att0, att1, sec, test, test_att0, test_att1, test_sec  = sess.run([train_image_batch, attention0, attention1, second_input, test_image_batch, test_attention0, test_attention1, test_second_input])

            #test, test_att0, test_att1, test_sec, test_att0_0, test_att0_1, test_att1_0, test_att1_1  = sess.run([test_image_batch, test_attention0, test_attention1, test_second_input, att0_0, att0_1, att1_0, att1_1])
            #feature0.append(netendpoints["{0}".format(0)]['feature'])
            #feature1.append(netendpoints["{0}".format(1)]['feature'])

            for index in range(test.shape[0]):
                test1 = test[index, :, :, :]
                test_att01 = test_att0[index, :, :, :]
                test_att11 = test_att1[index, :, :, :]
                test_sec1 = test_sec[index, :, :, :]
                #test_att0_01 = test_att0_0[index,:,:,:]
                #test_att0_11 = test_att0_1[index,:,:,:]
                #test_att1_01 = test_att1_0[index,:,:,:]
                #test_att1_11 = test_att1_1[index,:,:,:]

                #                 test_att01 = to_heat(test_att01)
                #                 test_att11 = to_heat(test_att11)
                scipy.misc.imsave(
                    os.path.join(
                        FLAGS.attention_map,
                        str(batch_idx) + '_' + str(index) + 'test.jpg'),
                    test1[:, :, :])
                scipy.misc.imsave(
                    os.path.join(
                        FLAGS.attention_map,
                        str(batch_idx) + '_' + str(index) + 'test_att0.jpg'),
                    test_att01[:, :, :])
                scipy.misc.imsave(
                    os.path.join(
                        FLAGS.attention_map,
                        str(batch_idx) + '_' + str(index) + 'test_att1.jpg'),
                    test_att11[:, :, :])
                #scipy.misc.imsave(os.path.join(FLAGS.attention_map, str(batch_idx)+'_'+str(index)+'test_att0_0.jpg'), test_att0_01[:,:,:])
                #scipy.misc.imsave(os.path.join(FLAGS.attention_map, str(batch_idx)+'_'+str(index)+'test_att0_1.jpg'), test_att0_11[:,:,:])
                #scipy.misc.imsave(os.path.join(FLAGS.attention_map, str(batch_idx)+'_'+str(index)+'test_att1_0.jpg'), test_att1_01[:,:,:])
                #scipy.misc.imsave(os.path.join(FLAGS.attention_map, str(batch_idx)+'_'+str(index)+'test_att1_1.jpg'), test_att1_11[:,:,:])
                scipy.misc.imsave(
                    os.path.join(
                        FLAGS.attention_map,
                        str(batch_idx) + '_' + str(index) + 'test_sec.jpg'),
                    test_sec1[:, :, :])

        #scipy.io.savemat(os.path.join(FLAGS.attention_map, 'feature0.mat'), {'feature_map0': feature0})
        #scipy.io.savemat(os.path.join(FLAGS.attention_map, 'feature1.mat'), {'feature_map1': feature1})

        for i in range(FLAGS.num_networks):
            print(np.mean(np.array(precision_value["{0}".format(i)])))
        print(np.mean(np.array(precision_value["{0}".format(2)])))
        infile.write(
            str(np.mean(np.array(precision_value["{0}".format(0)]))) + ' ' +
            str(np.mean(np.array(precision_value["{0}".format(1)]))) + ' ' +
            str(np.mean(np.array(precision_value["{0}".format(2)]))) + '\n')
        infile.close()
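Beispiel #20 reports its parameter count through utils.count_trainable_params(), which is not part of the listing. A plausible sketch of such a helper (an assumption; the real utility may differ) simply sums the element counts of all trainable variables:

import numpy as np
import tensorflow as tf

def count_trainable_params():
    """Return the total number of scalar parameters in tf.trainable_variables()."""
    total = 0
    for var in tf.trainable_variables():
        # Variable shapes are fully defined, so the product of the
        # dimensions is the number of parameters in this variable.
        total += int(np.prod(var.get_shape().as_list()))
    return total

# Usage, matching the print statement in the example:
# print("Total training params: %.1fM" % (count_trainable_params() / 1e6))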
#######################
# Helper methods      #
#######################

#######################
# Main methods        #
#######################
if not tf.gfile.Exists(FLAGS.out_dir):
    tf.gfile.MakeDirs(FLAGS.out_dir)

with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=FLAGS.num_clones,
        clone_on_cpu=FLAGS.clone_on_cpu,
        replica_id=FLAGS.task,
        num_replicas=FLAGS.worker_replicas,
        num_ps_tasks=FLAGS.num_ps_tasks)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
        ds = input_dataset.get_split(split_name=FLAGS.dataset_name,
                                     dataset_dir=FLAGS.dataset_dir,
                                     file_pattern='spitz_train')
        provider = slim.dataset_data_provider.DatasetDataProvider(
            ds,
            num_readers=FLAGS.num_readers,
            common_queue_capacity=20 * FLAGS.batch_size,
            common_queue_min=10 * FLAGS.batch_size)
def train(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    tf.logging.set_verbosity(args.log)
    clone_on_cpu = args.gpu_id == ''
    num_clones = len(args.gpu_id.split(','))

    ###
    # get teacher info.
    ###
    teacher_dir = utils.shell_path(args.teacher_dir)
    assert tf.gfile.IsDirectory(teacher_dir)
    json_in_dir = glob.glob(os.path.join(teacher_dir, '*.json'))
    assert len(json_in_dir) == 1
    te_json = json_in_dir[0]
    te_ckpt = tf.train.latest_checkpoint(teacher_dir)
    assert tf.train.checkpoint_exists(te_ckpt)

    with open(te_json, 'rt') as F:
        configs = json.load(F)
    te_hparams = Namespace(**configs)
    teacher = wavenet.Wavenet(te_hparams)

    ###
    # get student info.
    ###
    if args.config is None:
        raise RuntimeError('No config json specified.')
    with open(args.config, 'rt') as F:
        configs = json.load(F)
    st_hparams = Namespace(**configs)
    pwn = parallel_wavenet.ParallelWavenet(st_hparams, teacher,
                                           args.train_path)

    def _data_dep_init():
        inputs_val = reader.get_init_batch(pwn.train_path,
                                           batch_size=args.total_batch_size,
                                           seq_len=pwn.wave_length)
        mel_data = inputs_val['mel']

        _inputs_dict = {
            'mel': tf.placeholder(dtype=tf.float32, shape=mel_data.shape)
        }

        init_ff_dict = pwn.feed_forward(_inputs_dict, init=True)

        def callback(session):
            tf.logging.info('Running data dependent initialization '
                            'for weight normalization')
            init_out = session.run(init_ff_dict,
                                   feed_dict={_inputs_dict['mel']: mel_data})
            new_x = init_out['x']
            mean = init_out['mean_tot']
            scale = init_out['scale_tot']
            _init_logging(new_x, 'new_x')
            _init_logging(mean, 'mean')
            _init_logging(scale, 'scale')
            tf.logging.info('Done data dependent initialization '
                            'for weight normalization')

        return callback

    def _model_fn(_inputs_dict):
        ff_dict = pwn.feed_forward(_inputs_dict)
        ff_dict.update(_inputs_dict)
        loss_dict = pwn.calculate_loss(ff_dict)
        loss = loss_dict['loss']
        tf.add_to_collection(tf.GraphKeys.LOSSES, loss)

        tf.summary.scalar("kl_loss", loss_dict['kl_loss'])
        tf.summary.scalar("H_Ps", loss_dict['H_Ps'])
        tf.summary.scalar("H_Ps_Pt", loss_dict['H_Ps_Pt'])
        if 'power_loss' in loss_dict:
            tf.summary.scalar('power_loss', loss_dict['power_loss'])

    if args.log_root:
        logdir_name = config_str.get_config_time_str(st_hparams,
                                                     'parallel_wavenet',
                                                     EXP_TAG)
        logdir = os.path.join(args.log_root, logdir_name)
    else:
        logdir = args.logdir
    tf.logging.info('Saving to {}'.format(logdir))

    os.makedirs(logdir, exist_ok=True)
    shutil.copy(args.config, logdir)

    with tf.Graph().as_default():
        total_batch_size = args.total_batch_size
        assert total_batch_size % num_clones == 0
        clone_batch_size = int(total_batch_size / num_clones)

        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            num_ps_tasks=0,
            worker_job_name='localhost',
            ps_job_name='localhost')

        with tf.device(deploy_config.inputs_device()):
            inputs_dict = pwn.get_batch(clone_batch_size)

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, _model_fn,
                                            [inputs_dict])
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        summaries.update(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        with tf.device(deploy_config.variables_device()):
            global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)

        ###
        # variables to train
        ###
        st_vars = [
            var for var in tf.trainable_variables() if 'iaf' in var.name
        ]

        with tf.device(deploy_config.optimizer_device()):
            lr = tf.constant(pwn.learning_rate_schedule[0])
            for key, value in pwn.learning_rate_schedule.items():
                lr = tf.cond(tf.less(global_step, key), lambda: lr,
                             lambda: tf.constant(value))
            summaries.add(tf.summary.scalar("learning_rate", lr))

            optimizer = tf.train.AdamOptimizer(lr, epsilon=1e-8)
            ema = tf.train.ExponentialMovingAverage(decay=0.9999,
                                                    num_updates=global_step)
            loss, clone_grads_vars = model_deploy.optimize_clones(
                clones, optimizer, var_list=st_vars)
            update_ops.append(
                optimizer.apply_gradients(clone_grads_vars,
                                          global_step=global_step))
            update_ops.append(ema.apply(st_vars))

            summaries.add(tf.summary.scalar("train_loss", loss))

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(loss, name='train_op')

        ###
        # restore teacher
        ###
        te_vars = [
            var for var in tf.trainable_variables() if 'iaf' not in var.name
        ]
        # teacher use EMA
        te_vars = {
            '{}/ExponentialMovingAverage'.format(tv.name[:-2]): tv
            for tv in te_vars
        }
        restore_init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
            te_ckpt, te_vars)
        data_dep_init_fn = _data_dep_init()

        def group_init_fn(session):
            restore_init_fn(session)
            data_dep_init_fn(session)

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        slim.learning.train(train_tensor,
                            logdir=logdir,
                            number_of_steps=pwn.num_iters,
                            summary_op=summary_op,
                            global_step=global_step,
                            log_every_n_steps=100,
                            save_summaries_secs=600,
                            save_interval_secs=3600,
                            session_config=session_config,
                            init_fn=group_init_fn)
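The learning-rate block above rebuilds a step schedule with a chain of tf.cond ops over pwn.learning_rate_schedule. Assuming that schedule is a dict mapping step thresholds to learning rates (the numbers below are made up), the same piecewise-constant schedule can usually be written with tf.train.piecewise_constant, up to the behaviour at the exact boundary step:

import tensorflow as tf

# Hypothetical schedule: step threshold -> learning rate
# (threshold 0 supplies the initial rate).
learning_rate_schedule = {0: 2e-4, 100000: 1e-4, 200000: 5e-5}

global_step = tf.train.get_or_create_global_step()
steps = sorted(learning_rate_schedule)
boundaries = steps[1:]                               # thresholds after the initial rate
values = [learning_rate_schedule[s] for s in steps]  # one more value than boundaries
lr = tf.train.piecewise_constant(global_step, boundaries, values)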
Beispiel #23
0
def main(_):
    assert six.PY3
    assert 1 == FLAGS.num_clones

    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = trainset

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = trainset.get_tf_preprocess_image(
            is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            assert FLAGS.train_image_size is not None
            # assert FLAGS.train_image_size == network_fn.default_image_size
            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images = tf.placeholder(tf.float32,
                                    shape=(FLAGS.batch_size, train_image_size,
                                           train_image_size, 3))
            labels = tf.placeholder(tf.float32,
                                    shape=(FLAGS.batch_size,
                                           dataset.num_classes))
            trainset.set_holders(images, labels)
            logits, end_points = network_fn(
                tf.concat([
                    tf.expand_dims(
                        image_preprocessing_fn(images[i], train_image_size,
                                               train_image_size), 0)
                    for i in range(FLAGS.batch_size)
                ], 0))
            logits = tf.squeeze(logits)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn, [None])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            # summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        # for variable in slim.get_model_variables():
        #   summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Optimize over the clones; this returns the total loss and the per-variable gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        session_config = tf.ConfigProto()
        session_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.per_process_gpu_memory_fraction

        learning.train(
            train_tensor,
            session_config=session_config,
            train_step_fn=train_step,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
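Beispiel #23 feeds its placeholders through a custom train_step_fn instead of a queue, but the hook itself is not shown. The sketch below assumes slim's train_step_fn contract (sess, train_op, global_step, train_step_kwargs) and two hypothetical helpers, trainset.get_holders() and trainset.next_batch(), that return the registered placeholders and a matching numpy batch:

import time
import tensorflow as tf

def train_step(sess, train_op, global_step, train_step_kwargs):
    """Custom slim train step that feeds numpy batches into the placeholders."""
    # Hypothetical accessors; the real `trainset` helpers are not listed above.
    images_ph, labels_ph = trainset.get_holders()
    images_np, labels_np = trainset.next_batch(FLAGS.batch_size)

    start_time = time.time()
    total_loss, np_global_step = sess.run(
        [train_op, global_step],
        feed_dict={images_ph: images_np, labels_ph: labels_np})

    if np_global_step % FLAGS.log_every_n_steps == 0:
        tf.logging.info('step %d: loss = %.4f (%.3f sec/step)',
                        np_global_step, total_loss, time.time() - start_time)

    # slim stops training when the second return value is True.
    should_stop = False
    if 'should_stop' in train_step_kwargs:
        should_stop = sess.run(train_step_kwargs['should_stop'])
    return total_loss, should_stop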
Beispiel #24
0
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.DEBUG)
    with tf.Graph().as_default():
        # Config model_deploy. Keep TF Slim Models structure.
        # Useful if you want to use multiple GPUs and/or servers in the future.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=0,
            num_replicas=1,
            num_ps_tasks=0)
        # Create global_step.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        # Select the dataset.
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        # Get the SSD network and its anchors.
        ssd_class = nets_factory.get_network(FLAGS.model_name)
        ssd_params = ssd_class.default_params._replace(
            num_classes=FLAGS.num_classes)
        ssd_net = ssd_class(ssd_params)
        ssd_shape = ssd_net.params.img_shape
        ssd_anchors = ssd_net.anchors(ssd_shape)

        # Select the preprocessing function.
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        tf_utils.print_configuration(FLAGS.__flags, ssd_params,
                                     dataset.data_sources, FLAGS.train_dir)
        # =================================================================== #
        # Create a dataset provider and batches.
        # =================================================================== #
        with tf.device(deploy_config.inputs_device()):
            with tf.name_scope(FLAGS.dataset_name + '_data_provider'):
                provider = slim.dataset_data_provider.DatasetDataProvider(
                    dataset,
                    num_readers=FLAGS.num_readers,
                    common_queue_capacity=20 * FLAGS.batch_size,
                    common_queue_min=10 * FLAGS.batch_size,
                    shuffle=True)
            # Get for SSD network: image, labels, bboxes.
            [image, shape, glabels, gbboxes] = provider.get(
                ['image', 'shape', 'object/label', 'object/bbox'])
            # Pre-processing image, labels and bboxes.
            image, glabels, gbboxes = \
                image_preprocessing_fn(image, glabels, gbboxes, ssd_shape)
            # Encode groundtruth labels and bboxes.
            gclasses, glocalisations, gscores = \
                ssd_net.bboxes_encode(glabels, gbboxes, ssd_anchors)
            batch_shape = [1] + [len(ssd_anchors)] * 3

            # Training batches and queue.
            r = tf.train.batch(tf_utils.reshape_list(
                [image, gclasses, glocalisations, gscores]),
                               batch_size=FLAGS.batch_size,
                               num_threads=FLAGS.num_preprocessing_threads,
                               capacity=5 * FLAGS.batch_size)
            b_image, b_gclasses, b_glocalisations, b_gscores = \
                tf_utils.reshape_list(r, batch_shape)

            # Intermediate queueing: unique batch computation pipeline for all
            # GPUs running the training.
            batch_queue = slim.prefetch_queue.prefetch_queue(
                tf_utils.reshape_list(
                    [b_image, b_gclasses, b_glocalisations, b_gscores]),
                capacity=2 * deploy_config.num_clones)

        # =================================================================== #
        # Define the model running on every GPU.
        # =================================================================== #
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple
            clones of network_fn."""
            # Dequeue batch.
            b_image, b_gclasses, b_glocalisations, b_gscores = \
                tf_utils.reshape_list(batch_queue.dequeue(), batch_shape)

            # Construct SSD network.
            arg_scope = ssd_net.arg_scope(weight_decay=FLAGS.weight_decay)
            with slim.arg_scope(arg_scope):
                predictions, localisations, logits, end_points = \
                    ssd_net.net(b_image, is_training=True)
            # Add loss function.
            ssd_net.losses(logits,
                           localisations,
                           b_gclasses,
                           b_glocalisations,
                           b_gscores,
                           match_threshold=FLAGS.match_threshold,
                           negative_ratio=FLAGS.negative_ratio,
                           alpha=FLAGS.loss_alpha,
                           label_smoothing=FLAGS.label_smoothing)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # =================================================================== #
        # Add summaries from first clone.
        # =================================================================== #
        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))
        # Add summaries for losses and extra losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))
        for loss in tf.get_collection('EXTRA_LOSSES', first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        # =================================================================== #
        # Configure the moving averages.
        # =================================================================== #
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        # =================================================================== #
        # Configure the optimization procedure.
        # =================================================================== #
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf_utils.configure_learning_rate(
                FLAGS, dataset.num_samples, global_step)
            optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = tf_utils.get_variables_to_train(FLAGS)

        # Optimize over the clones; this returns the total loss and the per-variable gradients.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # =================================================================== #
        # Kicks off the training.
        # =================================================================== #
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
        config = tf.ConfigProto(log_device_placement=False,
                                gpu_options=gpu_options)
        saver = tf.train.Saver(max_to_keep=5,
                               keep_checkpoint_every_n_hours=1.0,
                               write_version=2,
                               pad_step_number=False)
        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_dir,
                            master='',
                            is_chief=True,
                            init_fn=tf_utils.get_init_fn(FLAGS),
                            summary_op=summary_op,
                            number_of_steps=FLAGS.max_number_of_steps,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            saver=saver,
                            save_interval_secs=FLAGS.save_interval_secs,
                            session_config=config,
                            sync_optimizer=None)
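Beispiel #24 warm-starts from a pre-trained checkpoint via tf_utils.get_init_fn(FLAGS), which is not reproduced here. A typical slim-style implementation looks roughly like the sketch below, under the assumption that FLAGS provides checkpoint_path, checkpoint_exclude_scopes and ignore_missing_vars, as in the standard slim training scripts:

import tensorflow as tf

slim = tf.contrib.slim

def get_init_fn(flags):
    """Return a callable that restores a pre-trained checkpoint, or None."""
    if flags.checkpoint_path is None:
        return None

    # Variables in excluded scopes (e.g. a freshly sized logits layer when
    # fine-tuning on a new class count) are not restored.
    exclusions = []
    if flags.checkpoint_exclude_scopes:
        exclusions = [s.strip()
                      for s in flags.checkpoint_exclude_scopes.split(',')]

    variables_to_restore = [
        var for var in slim.get_model_variables()
        if not any(var.op.name.startswith(ex) for ex in exclusions)
    ]

    if tf.gfile.IsDirectory(flags.checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path)
    else:
        checkpoint_path = flags.checkpoint_path

    return slim.assign_from_checkpoint_fn(
        checkpoint_path,
        variables_to_restore,
        ignore_missing_vars=flags.ignore_missing_vars)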
Beispiel #25
0
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    # Get dataset-dependent information.
    dataset = segmentation_dataset.get_dataset(FLAGS.dataset,
                                               FLAGS.train_split,
                                               dataset_dir=FLAGS.dataset_dir)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    with tf.Graph().as_default():
        with tf.device(config.inputs_device()):
            samples = input_generator.get(
                dataset,
                FLAGS.train_crop_size,
                clone_batch_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                dataset_split=FLAGS.train_split,
                is_training=True,
                model_variant=FLAGS.model_variant)
            inputs_queue = prefetch_queue.prefetch_queue(samples,
                                                         capacity=128 *
                                                         config.num_clones)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_deeplab
            model_args = (inputs_queue, {
                common.OUTPUT_TYPE: dataset.num_classes
            }, dataset.ignore_label)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes()
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Start the training.
        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_logdir,
                            log_every_n_steps=FLAGS.log_steps,
                            master=FLAGS.master,
                            number_of_steps=FLAGS.training_number_of_steps,
                            is_chief=(FLAGS.task == 0),
                            session_config=session_config,
                            startup_delay_steps=startup_delay_steps,
                            init_fn=train_utils.get_model_init_fn(
                                FLAGS.train_logdir,
                                FLAGS.tf_initial_checkpoint,
                                FLAGS.initialize_last_layer,
                                last_layers,
                                ignore_missing_vars=True),
                            summary_op=summary_op,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs)
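Beispiel #25 boosts the gradients of its last layers through train_utils.get_model_gradient_multipliers, another helper that is not listed. Conceptually it builds a dict from variable name to multiplier, which slim.learning.multiply_gradients then applies; a hedged sketch of that idea follows (the real helper's rules, e.g. an extra factor for biases, may differ):

import tensorflow as tf

slim = tf.contrib.slim

def get_model_gradient_multipliers(last_layer_scopes, last_layer_multiplier):
    """Map variables in the given scopes to a larger gradient multiplier."""
    gradient_multipliers = {}
    for var in slim.get_model_variables():
        if any(scope in var.op.name for scope in last_layer_scopes):
            gradient_multipliers[var.op.name] = last_layer_multiplier
    return gradient_multipliers

# Usage, mirroring the training script above:
#   grad_mult = get_model_gradient_multipliers(last_layers,
#                                              FLAGS.last_layer_gradient_multiplier)
#   grads_and_vars = slim.learning.multiply_gradients(grads_and_vars, grad_mult)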
def main(_):

    # Log
    tf.logging.set_verbosity(tf.logging.INFO)

    graph = tf.Graph()
    with graph.as_default(), tf.device('/cpu:0'):

        ######################
        # Config model_deploy#
        ######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        #########################################
        ###### Required info from the data #######
        #########################################
        num_samples_per_epoch = fileh.root.label_train.shape[0]
        num_batches_per_epoch = int(num_samples_per_epoch / FLAGS.batch_size)

        num_samples_per_epoch_test = fileh.root.label_test.shape[0]
        num_batches_per_epoch_test = int(num_samples_per_epoch_test /
                                         FLAGS.batch_size)

        # Create global_step
        global_step = tf.Variable(0, name='global_step', trainable=False)

        #####################################
        #### Configure the learning rate. ####
        #####################################
        learning_rate = _configure_learning_rate(num_samples_per_epoch,
                                                 global_step)
        opt = _configure_optimizer(learning_rate)

        ######################
        # Select the network #
        ######################

        # Training flag.
        is_training = tf.placeholder(tf.bool)

        # Get the network. The number of subjects is num_subjects.
        model_speech_fn = nets_factory.get_network_fn(
            FLAGS.model_speech,
            num_classes=num_subjects,
            weight_decay=FLAGS.weight_decay,
            is_training=is_training)

        #####################################
        # Select the preprocessing function #
        #####################################

        # TODO: Do some preprocessing if necessary.

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        # with tf.device(deploy_config.inputs_device()):
        """
        Define the place holders and creating the batch tensor.
        """
        speech = tf.placeholder(tf.float32, (20, 80, 40, 1))
        label = tf.placeholder(tf.int32, (1))
        batch_dynamic = tf.placeholder(tf.int32, ())
        margin_imp_tensor = tf.placeholder(tf.float32, ())

        # Create the batch tensors
        batch_speech, batch_labels = tf.train.batch(
            [speech, label],
            batch_size=batch_dynamic,
            num_threads=FLAGS.num_preprocessing_threads,
            capacity=5 * FLAGS.batch_size)

        #############################
        # Specify the loss function #
        #############################
        tower_grads = []
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(FLAGS.num_clones):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('%s_%d' % ('tower', i)) as scope:
                        """
                        Two distance metric are defined:
                           1 - distance_weighted: which is a weighted average of the distance between two structures.
                           2 - distance_l2: which is the regular l2-norm of the two networks outputs.
                        Place holders
                        """

                        ########################################
                        ######## Outputs of two networks #######
                        ########################################

                        # Distribute data among all clones equally.
                        step = int(FLAGS.batch_size / float(FLAGS.num_clones))

                        # Network outputs.
                        logits, end_points_speech = model_speech_fn(
                            batch_speech[i * step:(i + 1) * step])

                        ###################################
                        ########## Loss function ##########
                        ###################################
                        # one_hot labeling
                        label_onehot = tf.one_hot(tf.squeeze(
                            batch_labels[i * step:(i + 1) * step], [1]),
                                                  depth=num_subjects,
                                                  axis=-1)

                        SOFTMAX = tf.nn.softmax_cross_entropy_with_logits(
                            logits=logits, labels=label_onehot)

                        # Define loss
                        with tf.name_scope('loss'):
                            loss = tf.reduce_mean(SOFTMAX)

                        # Accuracy
                        with tf.name_scope('accuracy'):
                            # Evaluate the model
                            correct_pred = tf.equal(tf.argmax(logits, 1),
                                                    tf.argmax(label_onehot, 1))

                            # Accuracy calculation
                            accuracy = tf.reduce_mean(
                                tf.cast(correct_pred, tf.float32))

                        # ##### call the optimizer ######
                        # # TODO: call optimizer object outside of this gpu environment
                        #
                        # Reuse variables for the next tower.
                        tf.get_variable_scope().reuse_variables()

                        # Calculate the gradients for the batch of data on this CIFAR tower.
                        grads = opt.compute_gradients(loss)

                        # Keep track of the gradients across all towers.
                        tower_grads.append(grads)

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        grads = average_gradients(tower_grads)

        # Apply the gradients to adjust the shared variables.
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

        # Track the moving averages of all trainable variables.
        MOVING_AVERAGE_DECAY = 0.9999
        variable_averages = tf.train.ExponentialMovingAverage(
            MOVING_AVERAGE_DECAY, global_step)
        variables_averages_op = variable_averages.apply(
            tf.trainable_variables())

        # Group all updates to into a single train op.
        train_op = tf.group(apply_gradient_op, variables_averages_op)

        #################################################
        ########### Summary Section #####################
        #################################################

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for all end_points.
        for end_point in end_points_speech:
            x = end_points_speech[end_point]
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        # Add to parameters to summaries
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))
        summaries.add(tf.summary.scalar('global_step', global_step))
        summaries.add(tf.summary.scalar('eval/Loss', loss))
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

    ###########################
    ######## Training #########
    ###########################

    with tf.Session(graph=graph,
                    config=tf.ConfigProto(allow_soft_placement=True)) as sess:

        # Initialization of the network.
        variables_to_restore = slim.get_variables_to_restore()
        saver = tf.train.Saver(variables_to_restore, max_to_keep=20)
        coord = tf.train.Coordinator()
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # op to write logs to Tensorboard
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, graph=graph)

        #####################################
        ############## TRAIN ################
        #####################################

        step = 1
        for epoch in range(FLAGS.num_epochs):

            # Loop over all batches
            for batch_num in range(num_batches_per_epoch):
                step += 1
                start_idx = batch_num * FLAGS.batch_size
                end_idx = (batch_num + 1) * FLAGS.batch_size
                speech_train, label_train = fileh.root.utterance_train[
                    start_idx:end_idx, :, :, :], fileh.root.label_train[
                        start_idx:end_idx]

                # This transpose is necessary for 3D convolutional operation which will be performed by TensorFlow.
                speech_train = np.transpose(speech_train[None, :, :, :, :],
                                            axes=(1, 4, 2, 3, 0))

                # shuffling
                index = random.sample(range(speech_train.shape[0]),
                                      speech_train.shape[0])
                speech_train = speech_train[index]
                label_train = label_train[index]

                _, loss_value, train_accuracy, summary, training_step, _ = sess.run(
                    [
                        train_op, loss, accuracy, summary_op, global_step,
                        is_training
                    ],
                    feed_dict={
                        is_training:
                        True,
                        batch_dynamic:
                        label_train.shape[0],
                        margin_imp_tensor:
                        100,
                        batch_speech:
                        speech_train,
                        batch_labels:
                        label_train.reshape([label_train.shape[0], 1])
                    })
                summary_writer.add_summary(summary,
                                           epoch * num_batches_per_epoch + batch_num)

                # # Calculate ROC data
                # # print("label_train, score_dissimilarity", type(label_train), type(score_dissimilarity),label_train[0],score_dissimilarity[0])
                # # sys.exit(1)
                # EER, AUC, AP, tpr, fpr = calculate_roc.calculate_eer_auc_ap(label_train, score_dissimilarity)
                #
                if (batch_num + 1) % 1 == 0:
                    print("Epoch " + str(epoch + 1) + ", Minibatch " + str(
                        batch_num + 1) + " of %d " % num_batches_per_epoch + ", Minibatch Loss= " + \
                          "{:.4f}".format(loss_value) + ", TRAIN ACCURACY= " + "{:.3f}".format(
                        100 * train_accuracy))

            # Save the model
            saver.save(sess, FLAGS.train_dir, global_step=training_step)

            # ###################################################
            # ############## TEST PER EACH EPOCH ################
            # ###################################################

            label_vector = np.zeros(
                (FLAGS.batch_size * num_batches_per_epoch_test, 1))
            test_accuracy_vector = np.zeros((num_batches_per_epoch_test, 1))

            # Loop over all batches
            for i in range(num_batches_per_epoch_test):
                start_idx = i * FLAGS.batch_size
                end_idx = (i + 1) * FLAGS.batch_size
                speech_test, label_test = fileh.root.utterance_test[
                    start_idx:end_idx, :, :, :], fileh.root.label_test[
                        start_idx:end_idx]

                # Get the test batch.
                speech_test = np.transpose(speech_test[None, :, :, :, :],
                                           axes=(1, 4, 2, 3, 0))

                # Evaluation
                loss_value, test_accuracy, _ = sess.run(
                    [loss, accuracy, is_training],
                    feed_dict={
                        is_training: False,
                        batch_dynamic: FLAGS.batch_size,
                        margin_imp_tensor: 50,
                        batch_speech: speech_test,
                        batch_labels: label_test.reshape([FLAGS.batch_size, 1])
                    })
                label_test = label_test.reshape([FLAGS.batch_size, 1])
                label_vector[start_idx:end_idx] = label_test
                test_accuracy_vector[i, :] = test_accuracy

                # ROC

            print("TESTING after finishing training on epoch " +
                  str(epoch + 1))
            # print("TESTING accuracy = ", 100 * np.mean(test_accuracy_vector, axis=0))

            ##############################
            ##### K-split validation #####
            ##############################

            K = 4
            Accuracy = np.zeros((K, 1))
            batch_k_validation = int(test_accuracy_vector.shape[0] / float(K))

            for i in range(K):
                Accuracy[i, :] = 100 * np.mean(
                    test_accuracy_vector[i * batch_k_validation:(i + 1) *
                                         batch_k_validation],
                    axis=0)
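            # For instance, with a hypothetical num_batches_per_epoch_test of 40
            # and K = 4, batch_k_validation is 10, so each of the 4 splits
            # averages the accuracy of 10 consecutive test batches before the
            # mean/std are reported below.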

            # Reporting the K-fold validation
            print("Test Accuracy " + str(epoch + 1) + ", Mean= " + \
                          "{:.4f}".format(np.mean(Accuracy, axis=0)[0]) + ", std= " + "{:.3f}".format(
                        np.std(Accuracy, axis=0)[0]))
def main(_):
    #if not FLAGS.dataset_dir:
    #  raise ValueError('You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(num_clones=NUM_GPUS,
                                                      clone_on_cpu=False,
                                                      replica_id=0,
                                                      num_replicas=1,
                                                      num_ps_tasks=0)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        #dataset = dataset_factory.get_dataset(
        #    FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        #network_fn = nets_factory.get_network_fn(
        #    FLAGS.model_name,
        #    num_classes=NUM_CLASSES,
        #    weight_decay=FLAGS.weight_decay,
        #    is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            train_image_size = FLAGS.train_image_size
            #provider = slim.dataset_data_provider.DatasetDataProvider(
            #    dataset,
            #    num_readers=FLAGS.num_readers,
            #    common_queue_capacity=20 * FLAGS.batch_size,
            #    common_queue_min=10 * FLAGS.batch_size)
            #[image, label] = provider.get(['image', 'label'])
            _images = tf.convert_to_tensor(train_image_paths, dtype=tf.string)
            _labels = tf.convert_to_tensor(train_labels, dtype=tf.int64)
            _labels_a = tf.convert_to_tensor(train_labels_a, dtype=tf.int64)
            input_queue = tf.train.slice_input_producer(
                [_images, _labels, _labels_a], shuffle=True)
            file_path = input_queue[0]
            file_path = tf.Print(file_path, [file_path], "image path:")
            file_content = tf.read_file(file_path)
            image = tf.image.decode_jpeg(file_content, channels=3)
            image = preprocessing(image)
            #image = image_preprocessing_fn(image, train_image_size, train_image_size)
            label = input_queue[1]
            label -= FLAGS.labels_offset
            label_a = input_queue[2]
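            # slice_input_producer builds a queue over (path, label, label_a)
            # triples and, with shuffle=True, reshuffles them every epoch; each
            # dequeue yields a single example that is decoded and preprocessed
            # above before batching.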

            images, labels, labels_a = tf.train.batch(
                [image, label, label_a],
                batch_size=FLAGS.batch_size_in_clone,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=(NUM_GPUS + 2) * FLAGS.batch_size_in_clone)
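            # The queue capacity of (NUM_GPUS + 2) batches presumably keeps one
            # batch in flight per clone plus some slack for the preprocessing
            # threads.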

            #[images], labels = tf.contrib.training.stratified_sample(
            #    [input_queue[0]], input_queue[1], target_probs,
            #    batch_size=FLAGS.batch_size_in_clone,
            #    init_probs = init_probs,
            #    threads_per_queue=FLAGS.num_preprocessing_threads,
            #    queue_capacity=(NUM_GPUS+2)*FLAGS.batch_size_in_clone, prob_dtype=dtypes.float64)
            #labels -= FLAGS.labels_offset
            labels = slim.one_hot_encoding(
                labels, FLAGS.num_classes - FLAGS.labels_offset)
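            # For example (hypothetical values), with num_classes=3 and
            # labels_offset=0, labels [2, 0] become [[0., 0., 1.], [1., 0., 0.]].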
            #images_ = []
            #im_dtype = dtypes.int32
            #for i in xrange(FLAGS.batch_size_in_clone):
            #  image = images[i]
            #  file_content = tf.read_file(image)
            #  image = tf.image.decode_jpeg(file_content, channels=3)
            #  image = image_preprocessing_fn(image, train_image_size, train_image_size)
            #  im_dtype = image.dtype
            #  images_.append(image)
            #images = tf.convert_to_tensor(images_, dtype=im_dtype)

            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels, labels_a],
                capacity=8 * deploy_config.num_clones)

            #_images_val = tf.convert_to_tensor(val_image_paths,dtype=tf.string)
            #_labels_val = tf.convert_to_tensor(val_labels,dtype=tf.int64)
            #input_queue_val = tf.train.slice_input_producer([_images_val, _labels_val], shuffle=False)
            #file_content_val = tf.read_file(input_queue_val[0])
            #image_val = tf.image.decode_jpeg(file_content_val, channels=3)
            #label_val = input_queue_val[1]
            #label_val -= FLAGS.labels_offset

            #image_size = FLAGS.train_image_size or network_fn.default_image_size

            #image_val = image_preprocessing_fn(image_val, image_size, image_size)

            #images_val, labels_val = tf.train.batch(
            #    [image_val, label_val],
            #    batch_size=FLAGS.batch_size_in_clone,
            #    num_threads=FLAGS.num_preprocessing_threads,
            #    capacity=5 * FLAGS.batch_size_in_clone)
            #labels_val = slim.one_hot_encoding(
            #    labels_val, FLAGS.num_classes - FLAGS.labels_offset)
            #batch_queue_val = slim.prefetch_queue.prefetch_queue(
            #    [images_val, labels_val], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels, labels_a = batch_queue.dequeue()
            logits, logits_a, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if not EXCLUDE_AUX and 'AuxLogits' in end_points:
                tf.losses.softmax_cross_entropy(
                    logits=end_points['AuxLogits'],
                    onehot_labels=labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            tf.losses.softmax_cross_entropy(
                logits=logits,
                onehot_labels=labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            tf.losses.sigmoid_cross_entropy(
                logits=logits_a,
                multi_class_labels=labels_a,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
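            # All three losses (the optional auxiliary softmax loss, the main
            # softmax loss on logits, and the sigmoid loss on the attribute
            # logits_a) are added to tf.GraphKeys.LOSSES and summed per clone by
            # optimize_clones() further below.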
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(FLAGS.num_samples,
                                                     global_step, INIT_LR)
            optimizer = _configure_optimizer(learning_rate)
            learning_rate_lower = _configure_learning_rate(
                FLAGS.num_samples, global_step, LOC_LR)
            optimizer_lower = _configure_optimizer(learning_rate_lower)
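            # Two optimizers are built here: one at INIT_LR for the variables
            # from _get_variables_to_train(), and one at LOC_LR for the "lower"
            # variables, so the two groups are trained with different learning
            # rates.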
            #summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        variables_to_train = _get_variables_to_train()
        variables_to_train_lower = _get_variables_to_train_lower()

        total_loss_lower, clones_gradients_lower = model_deploy.optimize_clones(
            clones, optimizer_lower, var_list=variables_to_train_lower)

        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        #print(len(clones_gradients))
        #print(len(variables_to_train))
        #print(len(variables_to_train_lower))
        #assert(len(clones_gradients)==len(variables_to_train)+len(variables_to_train_lower))

        grad_updates = optimizer.apply_gradients(clones_gradients)
        update_ops.append(grad_updates)
        grad_updates_lower = optimizer_lower.apply_gradients(
            clones_gradients_lower, global_step=global_step)
        update_ops.append(grad_updates_lower)
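        # global_step is passed to only one of the two apply_gradients calls, so
        # it is incremented exactly once per training step even though two
        # optimizers are applied.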

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        #slim.learning.train(
        #    train_tensor,
        #    logdir=FLAGS.train_dir,
        #    master=FLAGS.master,
        #    is_chief=(FLAGS.task == 0),
        #    init_fn=_get_init_fn(),
        #    summary_op=summary_op,
        #    number_of_steps=FLAGS.max_number_of_steps,
        #    log_every_n_steps=FLAGS.log_every_n_steps,
        #    save_summaries_secs=FLAGS.save_summaries_secs,
        #    save_interval_secs=FLAGS.save_interval_secs,
        #    sync_optimizer=optimizer if FLAGS.sync_replicas else None)

        do_training(train_tensor,
                    init_fn=_get_init_fn(),
                    summary_op=summary_op,
                    lr=learning_rate)
Beispiel #28
0
def main(_):
    # Set the logging verbosity to INFO.
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Deployment configuration.
        deploy_config = model_deploy.DeploymentConfig(num_clones=num_clones,
                                                      clone_on_cpu=False,
                                                      replica_id=0,
                                                      num_replicas=1,
                                                      num_ps_tasks=0)
        # Global step.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()
        # Read the input data.
        with tf.device(deploy_config.inputs_device()):
            # start = time.time()
            images, labels = reader()
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)
            # readtime = time.time()-start

        # Forward pass.
        def clone_fn(batch_queue):
            images, labels = batch_queue.dequeue()
            logits = models(images)
            slim.losses.softmax_cross_entropy(logits, labels)
            return logits

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        # Forward pass: each clone dequeues one batch and runs the model.
        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        # Gather the update ops (e.g. batch-norm statistics updates) from the first clone into update_ops.
        first_clone_scope = deploy_config.clone_scope(0)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        with tf.device(deploy_config.optimizer_device()):
            # Set the learning rate schedule and the optimizer.
            # learning_rate = tf.train.piecewise_constant(global_step, [200, 500], [0.003, 0.0003, 0.0001])
            # staircase=True: decay the learning rate in discrete jumps every decay_steps; False: decay it at every step.

            learning_rate = tf.train.exponential_decay(
                init_learning_rate,
                global_step,
                decay_steps,
                learning_rate_decay_factor,
                staircase=False,
                name='exponential_decay_learning_rate')
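            # With staircase=False the rate decays continuously as
            #   lr = init_learning_rate *
            #        learning_rate_decay_factor ** (global_step / decay_steps)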
            optimizer = tf.train.GradientDescentOptimizer(learning_rate)
            # optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)

            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        # Compute the total loss and the gradients; this is the first half of minimize().
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=tf.trainable_variables())
        summaries.add(tf.summary.scalar('total_loss', total_loss))
        # Apply the computed gradients to the variables; this is the second half of minimize().
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        # Keep the gradient update in update_ops.
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        # The ops passed to control_dependencies run before anything created
        # inside the with block, i.e. update_op executes before total_loss is read.
        # tf.identity creates a new node instead of returning total_loss directly.
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')
            # tf.logging.info('readtime: %d' % readtime)

        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Specify which variables to restore when loading a pre-trained model.
        # fine_tune_path = 'checkpoints/vgg_16.ckpt'
        # variables_to_restore = slim.get_variables_to_restore(exclude=['global_step','vgg_16/fc8'])
        # init_fn = slim.assign_from_checkpoint_fn(fine_tune_path, variables_to_restore, ignore_missing_vars=True)
        # Configure GPU options.
        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.per_process_gpu_memory_fraction = 0.9

        saver = tf.train.Saver(max_to_keep=5,
                               keep_checkpoint_every_n_hours=1.0,
                               write_version=2,
                               pad_step_number=False)

        slim.learning.train(
            train_tensor,
            logdir=output_path,
            master='',
            is_chief=True,
            # init_fn=init_fn,
            init_fn=None,
            summary_op=summary_op,
            number_of_steps=max_steps,
            log_every_n_steps=1,
            save_summaries_secs=10,
            saver=saver,
            save_interval_secs=150,
            sync_optimizer=None,
            session_config=session_config)
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError('You must supply the path to the tfrecord dataset with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset_model = factory.get_dataset(FLAGS.dataset_name)
        dataset = utils.get_dataset(FLAGS.dataset_dir,
                                    FLAGS.dataset_split_name, dataset_model)
        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)
        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size
            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Compute the total loss and the gradients across all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
def train():
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.split_name,
                                              FLAGS.dataset_dir)

        ###########################
        # Select the CNN network  #
        ###########################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=None,
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step)
            optimizer = configure_optimizer(learning_rate)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            examples_per_shard = 1024
            min_queue_examples = examples_per_shard * FLAGS.input_queue_memory_factor
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=min_queue_examples +
                3 * FLAGS.batch_size,
                common_queue_min=min_queue_examples)
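            # For example (hypothetical value), with input_queue_memory_factor=16
            # this gives min_queue_examples = 16384, keeping a large pool of
            # examples in the common queue for better shuffling.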
            [image, label, text_id,
             text] = provider.get(['image', 'label', 'caption_ids', 'caption'])

            train_image_size = network_fn.default_image_size
            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            # The ops below split the caption into an input sequence and a target
            # sequence, where the target sequence is the input sequence
            # right-shifted by 1. Input and target sequences are batched and
            # padded up to the maximum sequence length in the batch, and a mask
            # is created to distinguish real words from padding words. Note that
            # the target sequence is only needed when performing caption generation.
            seq_length = tf.shape(text_id)[0]
            input_length = tf.expand_dims(tf.subtract(seq_length, 1), 0)
            input_seq = tf.slice(text_id, [0], input_length)
            target_seq = tf.slice(text_id, [1], input_length)
            input_mask = tf.ones(input_length, dtype=tf.int32)
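            # A worked example, assuming caption_ids = [<S>, w1, w2, </S>]:
            #   input_seq  = [<S>, w1, w2]
            #   target_seq = [w1, w2, </S>]
            #   input_mask = [1, 1, 1]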

            images, labels, input_seqs, target_seqs, input_masks, texts, text_ids = tf.train.batch(
                [
                    image, label, input_seq, target_seq, input_mask, text,
                    text_id
                ],
                batch_size=FLAGS.batch_size,
                capacity=2 * FLAGS.num_preprocessing_threads *
                FLAGS.batch_size,
                dynamic_pad=True,
                name="batch_and_pad")

            batch_queue = slim.prefetch_queue.prefetch_queue(
                [
                    images, labels, input_seqs, target_seqs, input_masks,
                    texts, text_ids
                ],
                capacity=16 * deploy_config.num_clones,
                num_threads=FLAGS.num_preprocessing_threads,
                dynamic_pad=True,
                name="perfetch_and_pad")

            images, labels, input_seqs, target_seqs, input_masks, texts, text_ids = batch_queue.dequeue(
            )

        images_splits = tf.split(axis=0,
                                 num_or_size_splits=FLAGS.num_gpus,
                                 value=images)
        labels_splits = tf.split(axis=0,
                                 num_or_size_splits=FLAGS.num_gpus,
                                 value=labels)
        input_seqs_splits = tf.split(axis=0,
                                     num_or_size_splits=FLAGS.num_gpus,
                                     value=input_seqs)
        target_seqs_splits = tf.split(axis=0,
                                      num_or_size_splits=FLAGS.num_gpus,
                                      value=target_seqs)
        input_masks_splits = tf.split(axis=0,
                                      num_or_size_splits=FLAGS.num_gpus,
                                      value=input_masks)
        texts_splits = tf.split(axis=0,
                                num_or_size_splits=FLAGS.num_gpus,
                                value=texts)
        text_ids_splits = tf.split(axis=0,
                                   num_or_size_splits=FLAGS.num_gpus,
                                   value=text_ids)
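        # tf.split with num_or_size_splits=FLAGS.num_gpus requires the batch
        # dimension to be evenly divisible by the number of GPUs; each tower
        # below receives batch_size / num_gpus examples.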

        tower_grads = []
        for k in xrange(FLAGS.num_gpus):
            with tf.device('/gpu:%d' % k):
                with tf.name_scope('tower_%d' % k) as scope:
                    with tf.variable_scope(tf.get_variable_scope()):

                        loss, cmpm_loss, cmpc_loss, i2t_loss, t2i_loss, loss_ms = \
                            _tower_loss(network_fn, images_splits[k], labels_splits[k],
                                        input_seqs_splits[k], input_masks_splits[k])

                        # Reuse variables for the next tower.
                        tf.get_variable_scope().reuse_variables()

                        # Retain the summaries from the final tower.
                        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                      scope)

                        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                       scope=scope)

                        # Variables to train.
                        variables_to_train = get_variables_to_train()
                        grads = optimizer.compute_gradients(
                            loss, var_list=variables_to_train)

                        tower_grads.append(grads)

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        grads = _average_gradients(tower_grads)

        # Add a summary to track the learning rate and precision.
        summaries.append(tf.summary.scalar('learning_rate', learning_rate))

        # Add histograms for gradients and trainable variables.
        for grad, var in grads:
            if grad is not None:
                summaries.append(
                    tf.summary.histogram(var.op.name + '/gradients', grad))

        for var in tf.trainable_variables():
            summaries.append(tf.summary.histogram(var.op.name, var))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Apply the gradients to adjust the shared variables.
        grad_updates = optimizer.apply_gradients(grads,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        # Group all updates into a single train op.
        train_op = tf.group(*update_ops)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)

        # Build the summary operation from the last tower summaries.
        summary_op = tf.summary.merge(summaries)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU implementations.
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
        config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=gpu_options)

        sess = tf.Session(config=config)
        sess.run(init)

        ck_global_step = get_init_fn(sess)
        print_train_info()

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.summary.FileWriter(os.path.join(FLAGS.log_dir),
                                               graph=sess.graph)

        num_steps_per_epoch = int(dataset.num_samples / FLAGS.batch_size)
        max_number_of_steps = FLAGS.num_epochs * num_steps_per_epoch + 100000

        for step in xrange(max_number_of_steps):
            step += int(ck_global_step)
            # check the training data
            # simages, slabels, sinput_seqs, starget_seqs, sinput_masks, stexts, stext_ids = \
            # sess.run([images_splits[0], labels_splits[0], input_seqs_splits[0], target_seqs_splits[0],
            #           input_masks_splits[0], texts_splits[0], text_ids_splits[0]])
            # save_images(simages[:8], [1, 8], './{}/{:05d}.png'.format(FLAGS.train_samples_dir, step))
            # import pdb
            # pdb.set_trace()

            _, total_loss_value, cmpm_loss_value, cmpc_loss_value, i2t_loss_value, t2i_loss_value, loss_ms_value = \
                sess.run([train_op, loss, cmpm_loss, cmpc_loss, i2t_loss, t2i_loss, loss_ms])
            # sess.run(tf.Print(sim_mat_i2t,[sim_mat_i2t],summarize=16))
            # _, total_loss_value, cmpm_loss_value, cmpc_loss_value, i2t_loss_value, t2i_loss_value = \
            #     sess.run([train_op, loss, cmpm_loss, cmpc_loss, i2t_loss, t2i_loss])

            assert not np.isnan(
                cmpm_loss_value), 'Model diverged with cmpm_loss = NaN'
            assert not np.isnan(
                cmpc_loss_value), 'Model diverged with cmpc_loss = NaN'
            assert not np.isnan(
                total_loss_value), 'Model diverged with total_loss = NaN'

            if step % 10 == 0:
                format_str = (
                    '%s: step %d, cmpm_loss = %.2f, cmpc_loss = %.2f, '
                    'i2t_loss = %.2f, t2i_loss = %.2f, loss_ms= %.2f')
                print(format_str % (FLAGS.dataset_name, step, cmpm_loss_value,
                                    cmpc_loss_value, i2t_loss_value,
                                    t2i_loss_value, loss_ms_value))
                # print(sim_t2i_,sub_val_t2i_,neg_val_t2i_)

            if step % 100 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically.
            if step % FLAGS.ckpt_steps == 0 or (step +
                                                1) == max_number_of_steps:
                checkpoint_path = os.path.join(FLAGS.checkpoint_dir,
                                               'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)