Example #1
    def encode_coordinates_fn(self, net):
        """
        Adds one-hot encoding of coordinates to different views in the networks.
        For each "pixel" of a feature map it adds a one hot encoded x and y
        coordinates.
        :param net: a tensor of shape=[batch_size, height, width, num_features]
        :return: a tensor with the same height and width, but altered feature_size.
        """

        mparams = self._mparams['encode_coordinates_fn']
        if mparams.enabled:
            batch_size, h, w, _ = net.shape.as_list()

            # Create two coordinate matrices, each of shape (h, w).
            x, y = tf.meshgrid(tf.range(w), tf.range(h))
            w_loc = tf_slim.one_hot_encoding(
                x, num_classes=w)  # shape of (h, w, w)
            h_loc = tf_slim.one_hot_encoding(
                y, num_classes=h)  # shape of (h, w, h)
            loc = tf.concat([h_loc, w_loc], axis=2)  # shape of (h, w, h + w)
            loc = tf.tile(
                tf.expand_dims(loc, 0),
                [batch_size, 1, 1, 1])  # shape of (batch_size, h, w, h + w)

            return tf.concat(
                [net, loc],
                3)  # shape of (batch_size, h, w, h + w + num_features)
        else:
            return net
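
A quick sanity check of the shapes above (a minimal sketch, assuming TensorFlow 2.x, with tf.one_hot standing in for tf_slim.one_hot_encoding):

import tensorflow as tf

# Dummy feature map: batch of 2, height 3, width 5, 8 features.
net = tf.zeros([2, 3, 5, 8])
batch_size, h, w, _ = net.shape.as_list()

x, y = tf.meshgrid(tf.range(w), tf.range(h))  # each of shape (h, w)
w_loc = tf.one_hot(x, depth=w)                # (h, w, w)
h_loc = tf.one_hot(y, depth=h)                # (h, w, h)
loc = tf.tile(tf.expand_dims(tf.concat([h_loc, w_loc], axis=2), 0),
              [batch_size, 1, 1, 1])
out = tf.concat([net, loc], 3)
print(out.shape)  # (2, 3, 5, 16) == (batch, h, w, h + w + num_features)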
Example #2
    def char_prediction(self, chars_logit):
        """
        Returns confidence scores (softmax values) for the predicted characters.

        :param chars_logit: character logits, a tensor with shape
            [batch_size, seq_length, num_char_classes].
        :return:
            A tuple (ids, log_prob, scores), where:
            ids - predicted characters, an int32 tensor with shape
                [batch_size, seq_length];
            log_prob - log probabilities of all characters, a float tensor with
                shape [batch_size, seq_length, num_char_classes];
            scores - confidence scores for the predicted characters, a float
                tensor with shape [batch_size, seq_length].
        """

        log_prob = logits_to_log_prob(chars_logit)
        ids = tf.cast(tf.argmax(log_prob, axis=2),
                      name='predicted_chars',
                      dtype=tf.int32)

        mask = tf.cast(
            tf_slim.one_hot_encoding(ids, self._params.num_char_classes),
            tf.bool)
        all_scores = tf.nn.softmax(chars_logit)
        selected_scores = tf.boolean_mask(all_scores, mask, name='char_scores')
        scores = tf.reshape(selected_scores,
                            shape=(-1, self._params.seq_length))

        return ids, log_prob, scores
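
Because the mask selects, at each time step, exactly the class chosen by argmax, scores is simply the maximum softmax probability per step. A minimal sketch of that equivalence (assuming TensorFlow 2.x):

import tensorflow as tf

chars_logit = tf.random.normal([4, 7, 10])  # (batch, seq_length, num_classes)
all_scores = tf.nn.softmax(chars_logit)
ids = tf.argmax(chars_logit, axis=2)  # argmax of logits == argmax of log-probs

mask = tf.cast(tf.one_hot(ids, 10), tf.bool)
scores_masked = tf.reshape(tf.boolean_mask(all_scores, mask), (-1, 7))
scores_max = tf.reduce_max(all_scores, axis=2)
tf.debugging.assert_near(scores_masked, scores_max)  # both routes agree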
Example #3
def imagenet_input(is_training):
    """Data reader for imagenet.

  Reads in imagenet data and performs pre-processing on the images.

  Args:
     is_training: bool specifying if train or validation dataset is needed.
  Returns:
     A batch of images and labels.
  """
    if is_training:
        dataset = dataset_factory.get_dataset('imagenet', 'train',
                                              FLAGS.dataset_dir)
    else:
        dataset = dataset_factory.get_dataset('imagenet', 'validation',
                                              FLAGS.dataset_dir)

    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=is_training,
        common_queue_capacity=2 * FLAGS.batch_size,
        common_queue_min=FLAGS.batch_size)
    [image, label] = provider.get(['image', 'label'])

    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        'mobilenet_v1', is_training=is_training)

    image = image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)

    images, labels = tf.train.batch([image, label],
                                    batch_size=FLAGS.batch_size,
                                    num_threads=4,
                                    capacity=5 * FLAGS.batch_size)
    labels = slim.one_hot_encoding(labels, FLAGS.num_classes)
    return images, labels
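
The queue-based slim pipeline above is deprecated; in TF 2.x the same batch-then-one-hot step looks roughly like this with tf.data (a sketch that assumes dataset already yields decoded (image, label) pairs):

import tensorflow as tf

def to_batches(dataset, batch_size, num_classes, is_training):
    if is_training:
        dataset = dataset.shuffle(buffer_size=2 * batch_size)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # slim.one_hot_encoding(labels, num_classes) ~ tf.one_hot
    return dataset.map(
        lambda images, labels: (images, tf.one_hot(labels, num_classes)),
        num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)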
Example #4
def get_data(dataset,
             batch_size,
             augment=False,
             central_crop_size=None,
             shuffle_config=None,
             shuffle=True):
    """Wraps calls to DatasetDataProviders and shuffle_batch.
  For more details about supported Dataset objects refer to datasets/fsns.py.
  Args:
    dataset: a slim.data.dataset.Dataset object.
    batch_size: number of samples per batch.
    augment: optional, if True does random image distortion.
    central_crop_size: A CharLogit tuple (crop_width, crop_height).
    shuffle_config: A namedtuple ShuffleBatchConfig.
    shuffle: if True use data shuffling.
  Returns:
  """
    if not shuffle_config:
        shuffle_config = DEFAULT_SHUFFLE_CONFIG

    provider = tf_slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=shuffle,
        common_queue_capacity=2 * batch_size,
        common_queue_min=batch_size)
    image_orig, label = provider.get(['image', 'label'])

    image = preprocess_image(image_orig,
                             augment,
                             central_crop_size,
                             num_towers=dataset.num_of_views)
    label_one_hot = tf_slim.one_hot_encoding(label, dataset.num_char_classes)
    # print(image.get_shape())
    # print(image_orig.get_shape())
    # print(label[0].get_shape())
    # print(label_one_hot.get_shape())
    """
    dataset = tf.data.Dataset.from_tensor_slices((image, image_orig, label, label_one_hot))
    dataset = dataset.shuffle(buffer_size=shuffle_config.min_after_dequeue, reshuffle_each_iteration=True).batch(batch_size=batch_size)

    images = tf.constant(list(dataset.map(lambda x_img, x_img_orig, y_label, y_label_one_hot: x_img)))
    images_orig = tf.constant(list(dataset.map(lambda x_img, x_img_orig, y_label, y_label_one_hot: x_img_orig)))
    labels = tf.constant(list(dataset.map(lambda x_img, x_img_orig, y_label, y_label_one_hot: y_label)))
    labels_one_hot = tf.constant(list(dataset.map(lambda x_img, x_img_orig, y_label, y_label_one_hot: y_label_one_hot)))
    """

    images, images_orig, labels, labels_one_hot = (
        tf.compat.v1.train.shuffle_batch(
            [image, image_orig, label, label_one_hot],
            batch_size=batch_size,
            num_threads=shuffle_config.num_batching_threads,
            capacity=shuffle_config.queue_capacity,
            min_after_dequeue=shuffle_config.min_after_dequeue))

    return InputEndpoints(images=images,
                          images_orig=images_orig,
                          labels=labels,
                          labels_one_hot=labels_one_hot)
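
ShuffleBatchConfig and InputEndpoints are namedtuples defined elsewhere in the attention_ocr code base; roughly, they look like the sketch below (the exact default values are an assumption, not verified against the upstream source):

import collections

ShuffleBatchConfig = collections.namedtuple(
    'ShuffleBatchConfig',
    ['num_batching_threads', 'queue_capacity', 'min_after_dequeue'])
DEFAULT_SHUFFLE_CONFIG = ShuffleBatchConfig(
    num_batching_threads=8, queue_capacity=3000, min_after_dequeue=1000)

InputEndpoints = collections.namedtuple(
    'InputEndpoints', ['images', 'images_orig', 'labels', 'labels_one_hot'])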
Example #5
    def char_one_hot(self, logit):
        """Creates a one-hot encoding for a logit of a character.
        Args:
          logit: A tensor with shape [batch_size, num_char_classes].
        Returns:
          A tensor with shape [batch_size, num_char_classes].
        """
        prediction = tf.argmax(logit, axis=1)
        return tf_slim.one_hot_encoding(prediction, self._params.num_char_classes)
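
Its behavior on a concrete input (a minimal sketch, with tf.one_hot standing in for tf_slim.one_hot_encoding):

import tensorflow as tf

logit = tf.constant([[0.1, 2.0, -1.0],
                     [3.0, 0.0, 0.5]])  # (batch_size=2, num_char_classes=3)
prediction = tf.argmax(logit, axis=1)   # [1, 0]
print(tf.one_hot(prediction, 3).numpy())
# [[0. 1. 0.]
#  [1. 0. 0.]]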
Example #6
    def build_model(self):
        tf.reset_default_graph()
        self.losses = []
        self.vars = []
        self.avg_gradient = []
        self.apply_grad = []
        self.instances = []
        self.gradients = []

        class setter():
            def __init__(self, assignment, devices):
                self.assignment = assignment
                self.last_device = devices[0]

            def choose(self, op):
                scope = tf.get_variable_scope().name
                for key in self.assignment:
                    if key in scope:
                        self.last_device = self.assignment[key]
                        return self.assignment[key]
                #print(self.assignment)
                print(scope, op.name, self.last_device)
                return self.last_device

        def device_setter(assignment, devices):
            _setter = setter(assignment, devices)
            return _setter.choose

        losses = []
        outputs = []

        tf.get_variable_scope()._reuse = tf.AUTO_REUSE
        for i in range(1):
            loss, output, scopes = self.model_fn(None, self.model_name)
            losses.append(loss)
            outputs.append(output[-1])
        self.scopes = scopes
        new_loss = tf.add_n(losses)
        new_loss = tf.reduce_mean(new_loss, name="final_loss")
        #self.train_op = tf.train.AdamOptimizer(learning_rate=0.2, beta1=0.9, beta2=0.98, epsilon=1e-9).minimize(new_loss)
        self.train_op = tf.train.GradientDescentOptimizer(
            learning_rate=0.01).minimize(new_loss,
                                         colocate_gradients_with_ops=True)
        init = tf.global_variables_initializer()

        g = tf.get_default_graph().as_graph_def(add_shapes=True)
        import tge
        strategy = {node.name: [1, 1, 1, 1, 1] for node in g.node}

        g = (
            tge.TGE(g, devices).custom(strategy)
            # .replace_placeholder(BATCHSIZE)
            .use_collective()
            # .verbose()
            .compile().get_result())

        with open("vgg_tge_modified.pbtxt", "w") as fo:
            fo.write(pbtf.MessageToString(g))

        tf.reset_default_graph()
        gdef = graph_pb2.GraphDef()
        with open("vgg_tge_modified.pbtxt", "r") as f:
            txt = f.read()
        pbtf.Parse(txt, gdef)

        tf.import_graph_def(gdef)
        graph = tf.get_default_graph()

        dataset = dataset_factory.get_dataset("imagenet", "train",
                                              "/data/slim_imagenet")

        preprocessing_name = "vgg_19"
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=4,
            common_queue_capacity=20 * batch_size,
            common_queue_min=10 * batch_size)
        [image, label] = provider.get(['image', 'label'])

        train_image_size = 224

        image = image_preprocessing_fn(image, train_image_size,
                                       train_image_size)
        print("image shape:", image.shape)
        print("label shape:", label.shape)
        images, labels = tf.train.batch([image, label],
                                        batch_size=batch_size,
                                        num_threads=4,
                                        capacity=5 * batch_size)
        labels = slim.one_hot_encoding(labels, dataset.num_classes)
        batch_queue = slim.prefetch_queue.prefetch_queue([images, labels],
                                                         capacity=2 *
                                                         micro_batch_num)

        x_tensor = graph.get_tensor_by_name("import/Placeholder/replica_0:0")
        y_tensor = graph.get_tensor_by_name("import/Placeholder_1/replica_0:0")
        x, y = batch_queue.dequeue()
        replace_input(graph, x, x_tensor.name)
        replace_input(graph, y, y_tensor.name)

        opt = graph.get_operation_by_name("import/GradientDescent/replica_0")
        loss = tf.reduce_mean(tf.add_n(get_tensors(graph, "final_loss")))
        init = graph.get_operation_by_name("import/init/replica_0")

        config = tf.ConfigProto()
        config.allow_soft_placement = True
        sess = tf.Session(config=config)
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(10000000):
            _, cal_loss = sess.run([opt, loss])
            if i % 10 == 0:
                print("Step:{},Loss:{}".format(i, cal_loss))
Example #7
    def build_model(self):
        tf.reset_default_graph()
        self.losses = []
        self.vars = []
        self.avg_gradient = []
        self.apply_grad = []
        self.instances = []
        self.gradients = []

        gpu_num = 4

        recorded_accuracy5 = []
        global_start_time = time.time()
        with open("vgg_dp3_time_record.txt", "w") as f:
            f.write("global start time: {}\n".format(global_start_time))
        times = []

        class setter():
            def __init__(self, assignment, devices):
                self.assignment = assignment
                self.last_device = devices[0]

            def choose(self, op):
                scope = tf.get_variable_scope().name
                for key in self.assignment:
                    if key in scope:
                        self.last_device = self.assignment[key]
                        return self.assignment[key]
                #print(self.assignment)
                print(scope, op.name, self.last_device)
                return self.last_device

        def device_setter(assignment, devices):
            _setter = setter(assignment, devices)
            return _setter.choose

        losses = []
        outputs = []

        with tf.variable_scope("input", reuse=tf.AUTO_REUSE):

            dataset = dataset_factory.get_dataset("imagenet", "train",
                                                  "/data/slim_imagenet")

            preprocessing_name = "vgg_19"
            image_preprocessing_fn = preprocessing_factory.get_preprocessing(
                preprocessing_name, is_training=True)

            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=4,
                common_queue_capacity=20 * batch_size,
                common_queue_min=10 * batch_size)
            [image, label] = provider.get(['image', 'label'])

            train_image_size = 224

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)
            print("image shape:", image.shape)
            print("label shape:", label.shape)
            images, labels = tf.train.batch([image, label],
                                            batch_size=batch_size,
                                            num_threads=4,
                                            capacity=5 * batch_size)
            labels = slim.one_hot_encoding(labels, dataset.num_classes)
            batch_queue = slim.prefetch_queue.prefetch_queue([images, labels],
                                                             capacity=2 *
                                                             gpu_num)

        tf.get_variable_scope()._reuse = tf.AUTO_REUSE
        for i in range(gpu_num):
            with tf.device("gpu:{}".format(i)):
                loss, output, scopes = self.model_fn(batch_queue,
                                                     self.model_name)
                losses.append(loss)
                outputs.append(output[-1])
        self.scopes = scopes
        with tf.device("gpu:2"):
            new_loss = tf.add_n(losses, name="final_loss") / gpu_num
            new_loss = tf.reduce_mean(new_loss)
            new_outputs = tf.add_n(outputs)
        #self.train_op = tf.train.AdamOptimizer(learning_rate=0.2, beta1=0.9, beta2=0.98, epsilon=1e-9).minimize(new_loss)
        #self.train_op = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(new_loss,colocate_gradients_with_ops=True)
        self.train_op = tf.train.MomentumOptimizer(
            learning_rate=0.01,
            momentum=0.9).minimize(new_loss, colocate_gradients_with_ops=True)

        graph = tf.get_default_graph()
        accurate_num = get_tensors(graph, "top_accuracy")
        print("accurate_num:", accurate_num)
        #accurate_num = tf.reduce_sum(tf.add_n(accurate_num))
        accurate_num = tf.reduce_sum(accurate_num[0])

        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.allow_soft_placement = True
        sess = tf.Session(config=config)
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        start_time = time.time()

        for i in range(10000000):
            _, loss, accuracy_num = sess.run(
                [self.train_op, new_loss, accurate_num])
            #top5accuracy = accuracy_num / (gpu_num * batch_size)
            top5accuracy = accuracy_num / (batch_size)

            if i % 10 == 0:
                end_time = time.time()
                print(
                    "Step:{},Loss:{},top5 accuracy:{},per_step_time:{}".format(
                        i, loss, top5accuracy, (end_time - start_time) / 10))
                start_time = time.time()

            gap = top5accuracy * 100 // 5 * 5
            if gap not in recorded_accuracy5:
                global_end_time = time.time()
                recorded_accuracy5.append(gap)
                print(
                    "achieving {}% for the first time, concrete top5 accuracy: {}%. time slot: {}, duration: {}s\n"
                    .format(gap, top5accuracy * 100, global_end_time,
                            global_end_time - global_start_time),
                    flush=True)
                with open("vgg_dp3_time_record.txt", "a+") as f:
                    f.write(
                        "achieving {}% for the first time, concrete top5 accuracy: {}%. time slot: {}, duration: {}s\n"
                        .format(gap, top5accuracy * 100, global_end_time,
                                global_end_time - global_start_time))
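
The gap expression buckets the top-5 accuracy into 5-percentage-point steps, so each threshold is logged only the first time it is crossed. For example:

for acc in (0.031, 0.07, 0.12, 0.129):
    print(acc, acc * 100 // 5 * 5)  # 0.0, 5.0, 10.0, 10.0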
Example #8
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.compat.v1.GraphKeys.LOSSES,
                                      first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #if FLAGS.quantize_delay >= 0:
        #  tf.contrib.quantize.create_training_graph(
        #      quant_delay=FLAGS.quantize_delay)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Optimize the clones and return the total loss and the gradients to apply.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.compat.v1.GraphKeys.SUMMARIES,
                              first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)
Example #9
def main(_):
    #tf.disable_v2_behavior() ###
    tf.compat.v1.disable_eager_execution()
    tf.compat.v1.enable_resource_variables()

    # Enable habana bf16 conversion pass
    if FLAGS.dtype == 'bf16':
        os.environ['TF_BF16_CONVERSION'] = flags.FLAGS.bf16_config_path
        FLAGS.precision = 'bf16'
    else:
        os.environ['TF_BF16_CONVERSION'] = "0"

    if FLAGS.use_horovod:
        hvd_init()

    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=True,
            use_grayscale=FLAGS.use_grayscale)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)
            return end_points

        # Gather initial summaries.

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs

        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #if FLAGS.quantize_delay >= 0:
        #  quantize.create_training_graph(quant_delay=FLAGS.quantize_delay) #for debugging!!

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Optimize the clones and return the total loss and the gradients to apply.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        if horovod_enabled():
            hvd.broadcast_global_variables(0)
        ###########################
        # Kicks off the training. #
        ###########################
        with dump_callback():
            with logger.benchmark_context(FLAGS):
                eps1 = ExamplesPerSecondKerasHook(FLAGS.log_every_n_steps,
                                                  output_dir=FLAGS.train_dir,
                                                  batch_size=FLAGS.batch_size)

                write_hparams_v1(
                    eps1.writer, {
                        'batch_size': FLAGS.batch_size,
                        **{x: getattr(FLAGS, x)
                           for x in FLAGS}
                    })

                train_step_kwargs = {}
                if FLAGS.max_number_of_steps:
                    should_stop_op = math_ops.greater_equal(
                        global_step, FLAGS.max_number_of_steps)
                else:
                    should_stop_op = constant_op.constant(False)
                train_step_kwargs['should_stop'] = should_stop_op
                if FLAGS.log_every_n_steps > 0:
                    train_step_kwargs['should_log'] = math_ops.equal(
                        math_ops.mod(global_step, FLAGS.log_every_n_steps), 0)

                eps1.on_train_begin()
                train_step_kwargs['EPS'] = eps1

                slim.learning.train(
                    train_tensor,
                    logdir=FLAGS.train_dir,
                    train_step_fn=train_step1,
                    train_step_kwargs=train_step_kwargs,
                    master=FLAGS.master,
                    is_chief=(FLAGS.task == 0),
                    init_fn=_get_init_fn(),
                    summary_op=summary_op,
                    summary_writer=None,
                    number_of_steps=FLAGS.max_number_of_steps,
                    log_every_n_steps=FLAGS.log_every_n_steps,
                    save_summaries_secs=FLAGS.save_summaries_secs,
                    save_interval_secs=FLAGS.save_interval_secs,
                    sync_optimizer=optimizer if FLAGS.sync_replicas else None)
Example #10
def main(_):
  tf.disable_eager_execution()

  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    #########################
    # Configure the network #
    #########################
    inception_params = network_params.InceptionV3FCNParams(
        receptive_field_size=FLAGS.receptive_field_size,
        prelogit_dropout_keep_prob=0.8,
        depth_multiplier=0.1,
        min_depth=16,
        inception_fcn_stride=0,
    )
    conv_params = network_params.ConvScopeParams(
        dropout=False,
        dropout_keep_prob=0.8,
        batch_norm=True,
        batch_norm_decay=0.99,
        l2_weight_decay=4e-05,
    )
    network_fn = inception_v3_fcn.get_inception_v3_fcn_network_fn(
        inception_params,
        conv_params,
        num_classes=dataset.num_classes,
        is_training=True,
    )

    #####################################
    # Select the preprocessing function #
    #####################################
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        'inception_v3', is_training=True)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=DATASET_READERS,
        common_queue_capacity=20 * FLAGS.batch_size,
        common_queue_min=10 * FLAGS.batch_size)
    [image, label] = provider.get(['image', 'label'])
    train_image_size = FLAGS.receptive_field_size
    image = image_preprocessing_fn(image, train_image_size, train_image_size)
    images, labels = tf.train.batch([image, label],
                                    batch_size=FLAGS.batch_size,
                                    num_threads=PREPROCESSING_THREADS,
                                    capacity=5 * FLAGS.batch_size)
    labels = slim.one_hot_encoding(labels, dataset.num_classes)

    ####################
    # Define the model #
    ####################
    logits, _ = network_fn(images)

    slim.losses.softmax_cross_entropy(logits, labels)
    total_loss = slim.losses.get_total_loss()
    tf.summary.scalar('losses/Total_Loss', total_loss)

    optimizer = tf.train.RMSPropOptimizer(0.01)

    train_op = slim.learning.create_train_op(
        total_loss,
        optimizer,
        variables_to_train=_get_variables_to_train())

    ###########################
    # Kicks off the training. #
    ###########################
    slim.learning.train(
        train_op,
        logdir=FLAGS.train_dir,
        init_fn=_get_init_fn(),
        number_of_steps=FLAGS.max_number_of_steps,
        log_every_n_steps=FLAGS.log_every_n_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        session_config=tf.ConfigProto(allow_soft_placement=True))
Example #11
    def activate_unit(self, path, graph_def):
        #setup_workers(workers, "grpc+verbs")
        tf.reset_default_graph()

        #server = tf.distribute.Server(cluster, job_name='worker', task_index=0, protocol="grpc+verbs",
         #                                  config=config)
        target = None

        tf.import_graph_def(graph_def)
        print("import success")
        graph = tf.get_default_graph()
        init0 = graph.get_operation_by_name("import/init/replica_0")
        print("11111111111111111111111")

        dataset = dataset_factory.get_dataset(
            "imagenet", "train", "/data/slim_imagenet")

        preprocessing_name = "vgg_19"
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name,
            is_training=True)

        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=4,
            common_queue_capacity=20 * batch_size * micro_batch_num,
            common_queue_min=10 * batch_size * micro_batch_num)
        [image, label] = provider.get(['image', 'label'])

        train_image_size = 224


        image = image_preprocessing_fn(image, train_image_size, train_image_size)
        print("image shape:", image.shape)
        print("label shape:", label.shape)
        images, labels = tf.train.batch(
            [image, label],
            batch_size=batch_size * micro_batch_num,
            num_threads=4,
            capacity=5 * batch_size * micro_batch_num)
        labels = slim.one_hot_encoding(
            labels, dataset.num_classes)
        batch_queue = slim.prefetch_queue.prefetch_queue(
            [images, labels], capacity=2 * micro_batch_num)

        input_dict = None
        '''
        placeholders = [node.outputs[0] for node in graph.get_operations() if node.node_def.op == 'Placeholder']
        shapes = [(p.shape.as_list()) for p in placeholders ]
        for shape in shapes:
            shape[0]=batch_size
        input_dict = { p: np.random.rand(*shapes[i]) for i,p in enumerate(placeholders) }
        '''
        #prepare input

        xs = ["import/input/Placeholder/replica_0:0"]
        ys = ["import/input/Placeholder_1/replica_0:0"]
        for i in range(1, micro_batch_num):
            xs.append("import/input_{}/Placeholder/replica_0:0".format(i))
            ys.append("import/input_{}/Placeholder_1/replica_0:0".format(i))
        x, y = batch_queue.dequeue()
        for i in range(len(xs)):
            replace_input(graph, x[i * batch_size:(i + 1) * batch_size], xs[i])
            replace_input(graph, y[i * batch_size:(i + 1) * batch_size], ys[i])
        losses = get_tensors(graph, "final_loss")
        losses = tf.reduce_mean(tf.add_n(losses) / len(losses))
        accurate_num = get_tensors(graph, "top_accuracy")
        print("accurate_num:", accurate_num)
        total_batch_size = batch_size * micro_batch_num
        size_for_each = total_batch_size / len(accurate_num)
        num_to_calculate = int(64 / size_for_each)
        accurate_num = tf.reduce_sum(tf.add_n(accurate_num[:num_to_calculate]))

        config = tf.ConfigProto()
        config.allow_soft_placement = True
        sess = tf.Session(target, config=config)  # , config=tf.ConfigProto(allow_soft_placement=False))
        print("222222222222222222222222")
        print("333333333333333333333")
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        opt = []
        for sink in self.sinks:
            op = graph.get_operation_by_name('import/' + sink + "/replica_0")
            opt.append(op)
        # opt = [graph.get_operation_by_name('import/' + x) for x in self.sinks]
        print("444444444444444444444")
        recorded_accuracy5 = []
        global_start_time = time.time()
        with open("time_record.txt", "w") as f:
            f.write("global start time: {}\n".format(global_start_time))
        times = []

        sess.run(init0)
        #sess.run(init1)

        start_time = time.time()
        for j in range(100000000000000):
            ret = sess.run(opt + [losses, accurate_num], feed_dict=input_dict)
            loss = ret[-2]
            top5accuracy_num = ret[-1]
            top5accuracy = top5accuracy_num / 64
            if j % 10 == 0:
                end_time = time.time()
                print("Step:{},Loss:{},top5 accuracy:{},per_step_time:{}".format(
                    j, loss, top5accuracy, (end_time - start_time) / 10))
                start_time = time.time()
            gap = top5accuracy * 100 // 5 * 5
            if gap not in recorded_accuracy5:
                global_end_time = time.time()
                recorded_accuracy5.append(gap)
                print("achieveing {}% at the first time, concreate top5 accuracy: {}%. time slot: {}, duration: {}s\n".format(gap,top5accuracy*100,global_end_time,global_end_time-global_start_time),flush=True)
                with open("time_record.txt","a+") as f:
                    f.write("achieveing {}% at the first time, concreate top5 accuracy: {}%. time slot: {}, duration: {}s\n".format(gap,top5accuracy*100,global_end_time,global_end_time-global_start_time))




        avg_time = sum(times)/len(times)
        print(path,times,"average time:", avg_time)
        print(" ")
Example #12
    def build_model(self):
        tf.reset_default_graph()
        self.losses = []
        self.vars = []
        self.avg_gradient = []
        self.apply_grad = []
        self.instances = []
        self.gradients = []

        class setter():
            def __init__(self, assignment, devices):
                self.assignment = assignment
                self.last_device = devices[0]

            def choose(self, op):
                scope = tf.get_variable_scope().name
                for key in self.assignment:
                    if key in scope:
                        self.last_device = self.assignment[key]
                        return self.assignment[key]
                #print(self.assignment)
                print(scope, op.name, self.last_device)
                return self.last_device

        def device_setter(assignment, devices):
            _setter = setter(assignment, devices)
            return _setter.choose

        losses = []
        outputs = []

        with tf.variable_scope("input", reuse=tf.AUTO_REUSE):

            dataset = dataset_factory.get_dataset("imagenet", "train",
                                                  "/data/slim_imagenet")

            preprocessing_name = "vgg_19"
            image_preprocessing_fn = preprocessing_factory.get_preprocessing(
                preprocessing_name, is_training=True)

            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=4,
                common_queue_capacity=20 * batch_size,
                common_queue_min=10 * batch_size)
            [image, label] = provider.get(['image', 'label'])

            train_image_size = 224

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)
            print("image shape:", image.shape)
            print("label shape:", label.shape)
            images, labels = tf.train.batch([image, label],
                                            batch_size=batch_size,
                                            num_threads=4,
                                            capacity=5 * batch_size)
            labels = slim.one_hot_encoding(labels, dataset.num_classes)
            batch_queue = slim.prefetch_queue.prefetch_queue([images, labels],
                                                             capacity=2 *
                                                             micro_batch_num)

        tf.get_variable_scope()._reuse = tf.AUTO_REUSE
        for i in range(1):
            with tf.device("gpu:{}".format(i)):
                loss, output, scopes = self.model_fn(batch_queue,
                                                     self.model_name)
                losses.append(loss)
                outputs.append(output[-1])
        self.scopes = scopes
        with tf.device("gpu:0"):
            new_loss = tf.add_n(losses, name="final_loss")
            new_loss = tf.reduce_mean(new_loss)
            new_outputs = tf.add_n(outputs)
            #self.train_op = tf.train.AdamOptimizer(learning_rate=0.2, beta1=0.9, beta2=0.98, epsilon=1e-9).minimize(new_loss)
            self.train_op = tf.train.GradientDescentOptimizer(
                learning_rate=0.01).minimize(new_loss,
                                             colocate_gradients_with_ops=True)

        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.allow_soft_placement = True
        sess = tf.Session(config=config)
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(10000000):
            _, loss = sess.run([self.train_op, new_loss])
            if i % 10 == 0:
                print("Step:{},Loss:{}".format(i, loss))
Example #13
def main(model_root, datasets_dir, model_name):
    # tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    # Training-related parameter setup
    with tf.Graph().as_default():
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=False,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=num_ps_tasks)

        global_step = slim.create_global_step()

        train_dir = os.path.join(model_root, model_name)
        dataset = convert_data.get_datasets('train', dataset_dir=datasets_dir)

        network_fn = net_select.get_network_fn(model_name,
                                               num_classes=dataset.num_classes,
                                               weight_decay=weight_decay,
                                               is_training=True)

        image_preprocessing_fn = preprocessing_select.get_preprocessing(
            model_name, is_training=True)

        print("the data_sources:", dataset.data_sources)

        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=num_readers,
                common_queue_capacity=20 * batch_size,
                common_queue_min=10 * batch_size)
            [image, label] = provider.get(['image', 'label'])

            train_image_size = network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.compat.v1.train.batch(
                [image, label],
                batch_size=batch_size,
                num_threads=num_preprocessing_threads,
                capacity=5 * batch_size)
            labels = slim.one_hot_encoding(labels, dataset.num_classes)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)

        def calculate_pooling_center_loss(features, label, alfa, nrof_classes,
                                          weights, name):
            features = tf.reshape(features, [features.shape[0], -1])
            label = tf.argmax(label, 1)

            nrof_features = features.get_shape()[1]
            centers = tf.compat.v1.get_variable(
                name, [nrof_classes, nrof_features],
                dtype=tf.float32,
                initializer=tf.constant_initializer(0),
                trainable=False)
            label = tf.reshape(label, [-1])
            centers_batch = tf.gather(centers, label)
            centers_batch = tf.nn.l2_normalize(centers_batch, axis=-1)

            diff = (1 - alfa) * (centers_batch - features)
            centers = tf.compat.v1.scatter_sub(centers, label, diff)

            with tf.control_dependencies([centers]):
                distance = tf.square(features - centers_batch)
                distance = tf.reduce_sum(distance, axis=-1)
                center_loss = tf.reduce_mean(distance)

            center_loss = tf.identity(center_loss * weights,
                                      name=name + '_loss')
            return center_loss

        def attention_crop(attention_maps):
            '''
            Uses the attention maps for data augmentation; this is the Crop Mask
            from the paper.
            :param attention_maps: obtained by reducing the dimensionality of
                the feature maps.
            :return: a float32 array of crop bounding boxes, one per sample.
            '''
            batch_size, height, width, num_parts = attention_maps.shape
            bboxes = []
            for i in range(batch_size):
                attention_map = attention_maps[i]
                part_weights = attention_map.mean(axis=0).mean(axis=0)
                part_weights = np.sqrt(part_weights)
                part_weights = part_weights / np.sum(part_weights)
                selected_index = np.random.choice(np.arange(0, num_parts),
                                                  1,
                                                  p=part_weights)[0]

                mask = attention_map[:, :, selected_index]

                threshold = random.uniform(0.4, 0.6)
                itemindex = np.where(mask >= mask.max() * threshold)

                ymin = itemindex[0].min() / height - 0.1
                ymax = itemindex[0].max() / height + 0.1
                xmin = itemindex[1].min() / width - 0.1
                xmax = itemindex[1].max() / width + 0.1

                bbox = np.asarray([ymin, xmin, ymax, xmax], dtype=np.float32)
                bboxes.append(bbox)
            bboxes = np.asarray(bboxes, np.float32)
            return bboxes

        def attention_drop(attention_maps):
            '''
            This is the attention-drop step; it forces the model to attend to
            other parts of the object (different attention maps may focus on
            the same part).
            :param attention_maps: obtained by reducing the dimensionality of
                the feature maps.
            :return: a float32 array of soft drop masks, one per sample.
            '''
            batch_size, height, width, num_parts = attention_maps.shape
            masks = []
            for i in range(batch_size):
                attention_map = attention_maps[i]
                part_weights = attention_map.mean(axis=0).mean(axis=0)
                part_weights = np.sqrt(part_weights)
                if (np.sum(part_weights) != 0):
                    part_weights = part_weights / np.sum(part_weights)
                selected_index = np.random.choice(np.arange(0, num_parts),
                                                  1,
                                                  p=part_weights)[0]
                mask = attention_map[:, :, selected_index:selected_index + 1]

                # soft mask
                threshold = random.uniform(0.2, 0.5)
                mask = (mask < threshold * mask.max()).astype(np.float32)
                masks.append(mask)
            masks = np.asarray(masks, dtype=np.float32)
            return masks

        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits_1, end_points_1 = network_fn(images)

            attention_maps = end_points_1['attention_maps']
            attention_maps = tf.image.resize(
                attention_maps, [train_image_size, train_image_size],
                method=tf.image.ResizeMethod.BILINEAR)

            # attention crop
            bboxes = tf.compat.v1.py_func(attention_crop, [attention_maps],
                                          [tf.float32])
            bboxes = tf.reshape(bboxes, [batch_size, 4])
            box_ind = tf.range(batch_size, dtype=tf.int32)
            images_crop = tf.image.crop_and_resize(
                images,
                bboxes,
                box_ind,
                crop_size=[train_image_size, train_image_size])

            # attention drop
            masks = tf.compat.v1.py_func(attention_drop, [attention_maps],
                                         [tf.float32])
            masks = tf.reshape(
                masks, [batch_size, train_image_size, train_image_size, 1])
            images_drop = images * masks

            logits_2, end_points_2 = network_fn(images_crop, reuse=True)
            logits_3, end_points_3 = network_fn(images_drop, reuse=True)

            slim.losses.softmax_cross_entropy(logits_1,
                                              labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_1')
            slim.losses.softmax_cross_entropy(logits_2,
                                              labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_2')
            slim.losses.softmax_cross_entropy(logits_3,
                                              labels,
                                              weights=1 / 3.0,
                                              scope='cross_entropy_3')
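
            # Each branch (raw, attention-cropped, attention-dropped) gets
            # weight 1/3, so the summed cross-entropy stays on the same scale
            # as a single-branch loss.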

            embeddings = end_points_1['embeddings']
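            # calculate_pooling_center_loss (defined elsewhere in this script)
            # appears to implement a WS-DAN-style center loss, pulling each
            # embedding toward a running per-class feature center updated with
            # momentum `alfa`.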
            center_loss = calculate_pooling_center_loss(
                features=embeddings,
                label=labels,
                alfa=0.95,
                nrof_classes=dataset.num_classes,
                weights=1.0,
                name='center_loss')
            slim.losses.add_loss(center_loss)

            return end_points_1

        # Gather initial summaries.
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(
                tf.compat.v1.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.compat.v1.summary.scalar('sparsity/' + end_point,
                                            tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOSSES,
                                                first_clone_scope):
            summaries.add(
                tf.compat.v1.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(
                tf.compat.v1.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step)
            optimizer = configure_optimizer(learning_rate)
            summaries.add(
                tf.compat.v1.summary.scalar('learning_rate', learning_rate))

        if moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = get_variables_to_train(trainable_scopes)

        # Compute the total loss and the gradients over all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.compat.v1.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')
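        # Evaluating `train_tensor` triggers every op in `update_ops` first
        # (gradient application, batch-norm statistics, EMA updates) and only
        # then yields the current total loss.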

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES,
                                        first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.compat.v1.summary.merge(list(summaries),
                                                name='summary_op')

        config = tf.compat.v1.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)
        config.gpu_options.allow_growth = True
        config.gpu_options.visible_device_list = "0"
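        # Pin the session to GPU 0 and let memory grow on demand instead of
        # reserving the whole device up front.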

        save_model_path = os.path.join(checkpoint_path, model_name,
                                       "%s.ckpt" % model_name)
        tf.compat.v1.logging.info('save_model_path: %s', save_model_path)

        # NOTE: disable_eager_execution() only takes effect if called before
        # any ops are created; ideally it belongs at program start-up.
        tf.compat.v1.disable_eager_execution()

        # Train the model.
        slim.learning.train(
            train_op=train_tensor,
            logdir=train_dir,
            is_chief=(task == 0),
            init_fn=_get_init_fn(save_model_path, train_dir=train_dir),
            summary_op=summary_op,
            number_of_steps=max_number_of_steps,
            log_every_n_steps=log_every_n_steps,
            save_summaries_secs=save_summaries_secs,
            save_interval_secs=save_interval_secs,
            session_config=config)


def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        with tf.device(deploy_config.inputs_device()):
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=20 * FLAGS.batch_size,
                common_queue_min=10 * FLAGS.batch_size)
            [image, label] = provider.get(['image', 'label'])
            label -= FLAGS.labels_offset

            train_image_size = FLAGS.train_image_size or network_fn.default_image_size

            image = image_preprocessing_fn(image, train_image_size,
                                           train_image_size)

            images, labels = tf.compat.v1.train.batch(
                [image, label],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.num_preprocessing_threads,
                capacity=5 * FLAGS.batch_size)
            labels = slim.one_hot_encoding(
                labels, dataset.num_classes - FLAGS.labels_offset)
            batch_queue = slim.prefetch_queue.prefetch_queue(
                [images, labels], capacity=2 * deploy_config.num_clones)
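
            # Under TF2's tf.data, the queue pipeline above corresponds
            # roughly to the following sketch (`raw_dataset` of (image, label)
            # pairs is an assumption; this is not a drop-in replacement):
            #   ds = (raw_dataset
            #         .shuffle(10 * FLAGS.batch_size)
            #         .map(lambda im, lb: (image_preprocessing_fn(
            #             im, train_image_size, train_image_size), lb))
            #         .batch(FLAGS.batch_size)
            #         .map(lambda im, lb: (im, tf.one_hot(
            #             lb, dataset.num_classes - FLAGS.labels_offset)))
            #         .prefetch(2))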

        ####################
        # Define the model #
        ####################
        def clone_fn(batch_queue):
            """Allows data parallelism by creating multiple clones of network_fn."""
            images, labels = batch_queue.dequeue()
            logits, end_points = network_fn(images)

            #############################
            # Specify the loss function #
            #############################
            if 'AuxLogits' in end_points:
                slim.losses.softmax_cross_entropy(
                    end_points['AuxLogits'],
                    labels,
                    label_smoothing=FLAGS.label_smoothing,
                    weights=0.4,
                    scope='aux_loss')
            slim.losses.softmax_cross_entropy(
                logits,
                labels,
                label_smoothing=FLAGS.label_smoothing,
                weights=1.0)

            accuracy = slim.metrics.accuracy(
                tf.cast(tf.argmax(input=logits, axis=1), dtype=tf.int32),
                tf.cast(tf.argmax(input=labels, axis=1), dtype=tf.int32))
            tf.compat.v1.add_to_collection('accuracy', accuracy)
            end_points['train_accuracy'] = accuracy
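            # slim.metrics.accuracy is the matching-argmax fraction; an
            # equivalent expression (illustrative only):
            #   tf.reduce_mean(tf.cast(tf.equal(
            #       tf.argmax(logits, 1), tf.argmax(labels, 1)), tf.float32))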
            return end_points

        # Gather initial summaries.
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS, first_clone_scope)

        # Add summaries for end_points.
        end_points = clones[0].outputs

        for end_point in end_points:
            if 'accuracy' in end_point:
                continue
            x = end_points[end_point]
            summaries.add(
                tf.compat.v1.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.compat.v1.summary.scalar('sparsity/' + end_point,
                                            tf.nn.zero_fraction(x)))
        train_acc = end_points['train_accuracy']
        summaries.add(
            tf.compat.v1.summary.scalar('train_accuracy', train_acc))

        # Add summaries for losses.
        for loss in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOSSES,
                                                first_clone_scope):
            summaries.add(
                tf.compat.v1.summary.scalar('losses/%s' % loss.op.name, loss))

        # @philkuz
        # TODO: add validation-accuracy summaries every n iterations.

        # Add summaries for variables.
        # TODO: prune some of these variable summaries from TensorBoard.
        for variable in slim.get_model_variables():
            summaries.add(
                tf.compat.v1.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = _configure_learning_rate(dataset.num_samples,
                                                     global_step)
            optimizer = _configure_optimizer(learning_rate)
            summaries.add(
                tf.compat.v1.summary.scalar('learning_rate', learning_rate))

        if FLAGS.sync_replicas:
            # If sync_replicas is enabled, the averaging will be done in the chief
            # queue runner.
            optimizer = tf.compat.v1.train.SyncReplicasOptimizer(
                opt=optimizer,
                replicas_to_aggregate=FLAGS.replicas_to_aggregate,
                total_num_replicas=FLAGS.worker_replicas,
                variable_averages=variable_averages,
                variables_to_average=moving_average_variables)
        elif FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        # Compute the total loss and the gradients over all clones.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.compat.v1.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES,
                                        first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.compat.v1.summary.merge(list(summaries),
                                                name='summary_op')

        # @philkuz
        # Derive max_number_of_steps from num_epochs when only epochs are set.
        tf.compat.v1.logging.info('FLAGS.num_epochs: %s', FLAGS.num_epochs)
        if FLAGS.num_epochs is not None and FLAGS.max_number_of_steps is None:
            FLAGS.max_number_of_steps = int(
                FLAGS.num_epochs * dataset.num_samples / FLAGS.batch_size)
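            # e.g. num_epochs=10, num_samples=50000, batch_size=32 (all
            # assumed values): int(10 * 50000 / 32) = 15625 steps.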

        # @philkuz: set up the train_dir / experiment logdir.
        if FLAGS.experiment_name is not None:
            experiment_dir = 'bs={},lr={},epochs={}/{}'.format(
                FLAGS.batch_size, FLAGS.learning_rate, FLAGS.num_epochs,
                FLAGS.experiment_name)
            FLAGS.train_dir = os.path.join(FLAGS.train_dir, experiment_dir)
            tf.compat.v1.logging.info('train_dir: %s', FLAGS.train_dir)

        # @philkuz overriding train_step
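        # slim.learning.train invokes this hook once per step in place of the
        # default slim.learning.train_step, letting the extra `train_acc`
        # fetch ride along with the usual loss/global-step fetches.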
        def train_step(sess, train_op, global_step, train_step_kwargs):
            """Function that takes a gradient step and specifies whether to stop.
      Args:
        sess: The current session.
        train_op: An `Operation` that evaluates the gradients and returns the
          total loss.
        global_step: A `Tensor` representing the global training step.
        train_step_kwargs: A dictionary of keyword arguments.
      Returns:
        The total loss and a boolean indicating whether or not to stop training.
      Raises:
        ValueError: if 'should_trace' is in `train_step_kwargs` but `logdir` is not.
      """
            start_time = time.time()

            trace_run_options = None
            run_metadata = None
            should_acc = True  # TODO(@philkuz): make this configurable.
            if 'should_trace' in train_step_kwargs:
                if 'logdir' not in train_step_kwargs:
                    raise ValueError(
                        'logdir must be present in train_step_kwargs when '
                        'should_trace is present')
                if sess.run(train_step_kwargs['should_trace']):
                    trace_run_options = config_pb2.RunOptions(
                        trace_level=config_pb2.RunOptions.FULL_TRACE)
                    run_metadata = config_pb2.RunMetadata()
            if not should_acc:
                total_loss, np_global_step = sess.run(
                    [train_op, global_step],
                    options=trace_run_options,
                    run_metadata=run_metadata)
            else:
                total_loss, acc, np_global_step = sess.run(
                    [train_op, train_acc, global_step],
                    options=trace_run_options,
                    run_metadata=run_metadata)
            time_elapsed = time.time() - start_time

            if run_metadata is not None:
                tl = timeline.Timeline(run_metadata.step_stats)
                trace = tl.generate_chrome_trace_format()
                trace_filename = os.path.join(
                    train_step_kwargs['logdir'],
                    'tf_trace-%d.json' % np_global_step)
                tf.compat.v1.logging.info('Writing trace to %s',
                                          trace_filename)
                file_io.write_string_to_file(trace_filename, trace)
                if 'summary_writer' in train_step_kwargs:
                    train_step_kwargs['summary_writer'].add_run_metadata(
                        run_metadata, 'run_metadata-%d' % np_global_step)
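                # The emitted JSON trace can be opened in chrome://tracing
                # (or Perfetto) to inspect per-op timing for this step.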

            if 'should_log' in train_step_kwargs:
                if sess.run(train_step_kwargs['should_log']):
                    if not should_acc:
                        tf.compat.v1.logging.info(
                            'global step %d: loss = %.4f (%.3f sec/step)',
                            np_global_step, total_loss, time_elapsed)
                    else:
                        tf.compat.v1.logging.info(
                            'global step %d: loss = %.4f train_acc = %.4f (%.3f sec/step)',
                            np_global_step, total_loss, acc, time_elapsed)

            if 'should_stop' in train_step_kwargs:
                should_stop = sess.run(train_step_kwargs['should_stop'])
            else:
                should_stop = False

            return total_loss, should_stop

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,
            train_step_fn=train_step,
            master=FLAGS.master,
            is_chief=(FLAGS.task == 0),
            init_fn=_get_init_fn(),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            sync_optimizer=optimizer if FLAGS.sync_replicas else None)