Example #1
def build_graph(reader,
                model,
                train_data_pattern,
                label_loss_fn=losses.CrossEntropyLoss(),
                batch_size=1000,
                base_learning_rate=0.01,
                learning_rate_decay_examples=1000000,
                learning_rate_decay=0.95,
                optimizer_class=tf.compat.v1.train.AdamOptimizer,
                clip_gradient_norm=1.0,
                regularization_penalty=1,
                num_readers=1,
                num_epochs=None):
    """Creates the Tensorflow graph.

  This will only be called once in the life of
  a training model, because after the graph is created the model will be
  restored from a meta graph file rather than being recreated.

  Args:
    reader: The data file reader. It should inherit from BaseReader.
    model: The core model (e.g. logistic or neural net). It should inherit from
      BaseModel.
    train_data_pattern: glob path to the training data files.
    label_loss_fn: What kind of loss to apply to the model. It should inherit
      from BaseLoss.
    batch_size: How many examples to process at a time.
    base_learning_rate: What learning rate to initialize the optimizer with.
    learning_rate_decay_examples: How many examples to process before applying
      one multiplicative step of learning-rate decay.
    learning_rate_decay: Factor by which the learning rate is multiplied at
      each decay step.
    optimizer_class: Which optimization algorithm to use.
    clip_gradient_norm: Magnitude of the gradient to clip to.
    regularization_penalty: How much weight to give the regularization loss
      compared to the label loss.
    num_readers: How many threads to use for I/O operations.
    num_epochs: How many passes to make over the data. 'None' means an unlimited
      number of passes.
  """

    global_step = tf.Variable(0, trainable=False, name="global_step")

    local_device_protos = device_lib.list_local_devices()
    gpus = [x.name for x in local_device_protos if x.device_type == "GPU"]
    gpus = gpus[:FLAGS.num_gpu]
    num_gpus = len(gpus)

    if num_gpus > 0:
        logging.info("Using the following GPUs to train: " + str(gpus))
        num_towers = num_gpus
        device_string = "/gpu:%d"
    else:
        logging.info("No GPUs found. Training on CPU.")
        num_towers = 1
        device_string = "/cpu:%d"

    learning_rate = tf.train.exponential_decay(base_learning_rate,
                                               global_step * batch_size *
                                               num_towers,
                                               learning_rate_decay_examples,
                                               learning_rate_decay,
                                               staircase=True)
    tf.summary.scalar("learning_rate", learning_rate)

    optimizer = optimizer_class(learning_rate)
    input_data_dict = (get_input_data_tensors(reader,
                                              train_data_pattern,
                                              batch_size=batch_size *
                                              num_towers,
                                              num_readers=num_readers,
                                              num_epochs=num_epochs))
    logging.info("input_data_dict: %s", input_data_dict)
    model_input_raw = input_data_dict["video_matrix"]
    labels_batch = input_data_dict["labels"]
    num_frames = input_data_dict["num_frames"]
    print("model_input_shape, ", model_input_raw.shape)
    print("labels_batch, ", labels_batch)

    import csv
    import numpy as np

    # Build a per-class mask: 1.0 for every class listed in
    # segment_label_ids.csv, 0.0 otherwise.
    whitelisted_cls_mask = np.zeros((3862,), dtype=np.float32)
    with open('segment_label_ids.csv') as csv_file:
        for line in csv.reader(csv_file):
            try:
                cls_id = int(line[0])
                whitelisted_cls_mask[cls_id] = 1.
            except ValueError:
                # Simply skip the non-integer (header) line.
                continue

    # Earlier experiments, left disabled: reweighting classes by inverse log
    # class counts from classCount.csv, and filtering each batch down to
    # examples that carry at least one whitelisted label.

    tf.summary.histogram("model/input_raw", model_input_raw)
    feature_dim = len(model_input_raw.get_shape()) - 1
    model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)

    tower_inputs = tf.split(model_input, num_towers)
    tower_labels = tf.split(labels_batch, num_towers)
    tower_num_frames = tf.split(num_frames, num_towers)
    tower_gradients = []
    tower_predictions = []
    tower_label_losses = []
    tower_reg_losses = []

    # Weight whitelisted classes 5x relative to the rest.
    whitelisted_cls_mask = whitelisted_cls_mask * 4 + np.ones(
        (3862,), dtype=np.float32)
    for i in range(num_towers):
        # For some reason these 'with' statements can't be combined onto the same
        # line. They have to be nested.
        with tf.device(device_string % i):
            with (tf.variable_scope(("tower"), reuse=True if i > 0 else
                                    None)):  # reuse=True if i > 0 else None
                with (slim.arg_scope(
                    [slim.model_variable, slim.variable],
                        device="/cpu:0" if num_gpus != 1 else "/gpu:0")):
                    result = model.create_model(tower_inputs[i],
                                                num_frames=tower_num_frames[i],
                                                vocab_size=reader.num_classes,
                                                labels=tower_labels[i])
                    for variable in slim.get_model_variables():
                        tf.summary.histogram(variable.op.name, variable)
                    predictions = result["predictions"]
                    tower_predictions.append(predictions)

                    if "loss" in result.keys():
                        label_loss = result["loss"]
                    else:
                        label_loss = label_loss_fn.calculate_loss(
                            predictions,
                            tower_labels[i],
                            label_weights=whitelisted_cls_mask)
                        if "aux_predictions" in result.keys():
                            for pred in result["aux_predictions"]:
                                label_loss += label_loss_fn.calculate_loss(
                                    pred,
                                    tower_labels[i],
                                    label_weights=whitelisted_cls_mask)
                    if "regularization_loss" in result.keys():
                        reg_loss = result["regularization_loss"]
                    else:
                        reg_loss = tf.constant(0.0)

                    reg_losses = tf.losses.get_regularization_losses()
                    if reg_losses:
                        reg_loss += tf.add_n(reg_losses)

                    tower_reg_losses.append(reg_loss)

                    # Adds update_ops (e.g., moving average updates in batch normalization) as
                    # a dependency to the train_op.
                    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                    if "update_ops" in result.keys():
                        update_ops += result["update_ops"]
                    if update_ops:
                        with tf.control_dependencies(update_ops):
                            barrier = tf.no_op(name="gradient_barrier")
                            with tf.control_dependencies([barrier]):
                                label_loss = tf.identity(label_loss)

                    tower_label_losses.append(label_loss)

                    # Incorporate the L2 weight penalties etc.
                    final_loss = regularization_penalty * reg_loss + label_loss
                    gradients = optimizer.compute_gradients(
                        final_loss, colocate_gradients_with_ops=False)
                    tower_gradients.append(gradients)
    label_loss = tf.reduce_mean(tf.stack(tower_label_losses))
    tf.summary.scalar("label_loss", label_loss)
    if regularization_penalty != 0:
        reg_loss = tf.reduce_mean(tf.stack(tower_reg_losses))
        tf.summary.scalar("reg_loss", reg_loss)
    merged_gradients = utils.combine_gradients(tower_gradients)

    if clip_gradient_norm > 0:
        with tf.name_scope("clip_grads"):
            merged_gradients = utils.clip_gradient_norms(
                merged_gradients, clip_gradient_norm)

    train_op = optimizer.apply_gradients(merged_gradients,
                                         global_step=global_step)

    tf.add_to_collection("global_step", global_step)
    tf.add_to_collection("loss", label_loss)
    tf.add_to_collection("predictions", tf.concat(tower_predictions, 0))
    tf.add_to_collection("input_batch_raw", model_input_raw)
    tf.add_to_collection("input_batch", model_input)
    tf.add_to_collection("num_frames", num_frames)
    tf.add_to_collection("labels", tf.cast(labels_batch, tf.float32))
    tf.add_to_collection("train_op", train_op)
Example #2
def build_graph(reader,
                model,
                train_data_pattern,
                label_loss_fn=losses.CrossEntropyLoss(),
                batch_size=1000,
                base_learning_rate=0.01,
                learning_rate_decay_examples=1000000,
                learning_rate_decay=0.95,
                optimizer_class=tf.train.AdamOptimizer,
                clip_gradient_norm=1.0,
                regularization_penalty=1,
                num_readers=1,
                num_epochs=None):
    """Creates the Tensorflow graph.
  This will only be called once in the life of
  a training model, because after the graph is created the model will be
  restored from a meta graph file rather than being recreated.
  Args:
    reader: The data file reader. It should inherit from BaseReader.
    model: The core model (e.g. logistic or neural net). It should inherit
           from BaseModel.
    train_data_pattern: glob path to the training data files.
    label_loss_fn: What kind of loss to apply to the model. It should inherit
                from BaseLoss.
    batch_size: How many examples to process at a time.
    base_learning_rate: What learning rate to initialize the optimizer with.
    optimizer_class: Which optimization algorithm to use.
    clip_gradient_norm: Magnitude of the gradient to clip to.
    regularization_penalty: How much weight to give the regularization loss
                            compared to the label loss.
    num_readers: How many threads to use for I/O operations.
    num_epochs: How many passes to make over the data. 'None' means an
                unlimited number of passes.
  """

    global_step = tf.Variable(0, trainable=False, name="global_step")

    local_device_protos = device_lib.list_local_devices()
    gpus = [x.name for x in local_device_protos if x.device_type == 'GPU']
    gpus = gpus[:FLAGS.num_gpu]
    num_gpus = len(gpus)

    if num_gpus > 0:
        logging.info("Using the following GPUs to train: " + str(gpus))
        num_towers = num_gpus
        device_string = '/gpu:%d'
    else:
        logging.info("No GPUs found. Training on CPU.")
        num_towers = 1
        device_string = '/cpu:%d'

    learning_rate = tf.train.exponential_decay(base_learning_rate,
                                               global_step * batch_size *
                                               num_towers,
                                               learning_rate_decay_examples,
                                               learning_rate_decay,
                                               staircase=True)
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = optimizer_class(learning_rate)
    unused_video_id, model_input_raw, labels_batch, num_frames = (
        get_input_data_tensors(reader,
                               train_data_pattern,
                               batch_size=batch_size * num_towers,
                               num_readers=num_readers,
                               num_epochs=num_epochs))
    tf.summary.histogram("model/input_raw", model_input_raw)

    feature_dim = len(model_input_raw.get_shape()) - 1

    model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)

    tower_inputs = tf.split(model_input, num_towers)
    tower_labels = tf.split(labels_batch, num_towers)
    tower_num_frames = tf.split(num_frames, num_towers)
    tower_gradients = []
    tower_predictions = []
    tower_label_losses = []
    tower_reg_losses = []
    for i in range(num_towers):
        # For some reason these 'with' statements can't be combined onto the same
        # line. They have to be nested.
        with tf.device(device_string % i):
            with (tf.variable_scope(("tower"), reuse=True if i > 0 else None)):
                with (slim.arg_scope(
                    [slim.model_variable, slim.variable],
                        device="/cpu:0" if num_gpus != 1 else "/gpu:0")):
                    result = model.create_model(tower_inputs[i],
                                                num_frames=tower_num_frames[i],
                                                vocab_size=reader.num_classes,
                                                labels=tower_labels[i])
                    for variable in slim.get_model_variables():
                        tf.summary.histogram(variable.op.name, variable)

                    predictions = result["predictions"]
                    tower_predictions.append(predictions)

                    if "loss" in result.keys():
                        label_loss = result["loss"]
                    else:
                        label_loss = label_loss_fn.calculate_loss(
                            predictions, tower_labels[i])

                    if "regularization_loss" in result.keys():
                        reg_loss = result["regularization_loss"]
                    else:
                        reg_loss = tf.constant(0.0)

                    reg_losses = tf.losses.get_regularization_losses()
                    if reg_losses:
                        reg_loss += tf.add_n(reg_losses)

                    tower_reg_losses.append(reg_loss)

                    # Adds update_ops (e.g., moving average updates in batch normalization) as
                    # a dependency to the train_op.
                    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                    if "update_ops" in result.keys():
                        update_ops += result["update_ops"]
                    if update_ops:
                        with tf.control_dependencies(update_ops):
                            barrier = tf.no_op(name="gradient_barrier")
                            with tf.control_dependencies([barrier]):
                                label_loss = tf.identity(label_loss)

                    tower_label_losses.append(label_loss)

                    # Incorporate the L2 weight penalties etc.
                    final_loss = regularization_penalty * reg_loss + label_loss
                    gradients = optimizer.compute_gradients(
                        final_loss, colocate_gradients_with_ops=False)
                    tower_gradients.append(gradients)
    label_loss = tf.reduce_mean(tf.stack(tower_label_losses))
    tf.summary.scalar("label_loss", label_loss)
    if regularization_penalty != 0:
        reg_loss = tf.reduce_mean(tf.stack(tower_reg_losses))
        tf.summary.scalar("reg_loss", reg_loss)
    merged_gradients = utils.combine_gradients(tower_gradients)

    if clip_gradient_norm > 0:
        with tf.name_scope('clip_grads'):
            merged_gradients = utils.clip_gradient_norms(
                merged_gradients, clip_gradient_norm)

    train_op = optimizer.apply_gradients(merged_gradients,
                                         global_step=global_step)

    tf.add_to_collection("global_step", global_step)
    tf.add_to_collection("loss", label_loss)
    tf.add_to_collection("predictions", tf.concat(tower_predictions, 0))
    tf.add_to_collection("input_batch_raw", model_input_raw)
    tf.add_to_collection("input_batch", model_input)
    tf.add_to_collection("num_frames", num_frames)
    tf.add_to_collection("labels", tf.cast(labels_batch, tf.float32))
    tf.add_to_collection("train_op", train_op)
Example #3
def build_graph(reader,
                model,
                train_data_pattern,
                train_data_pattern2,
                train_data_pattern3,
                eval_data_pattern,
                label_loss_fn=losses.CrossEntropyLoss(),
                batch_size=1000,
                base_learning_rate=0.01,
                learning_rate_decay_examples=1000000,
                learning_rate_decay=0.95,
                optimizer_class=tf.train.AdamOptimizer,
                clip_gradient_norm=1.0,
                regularization_penalty=1,
                num_readers=1,
                num_epochs=None,
                l2_penalty=1e-8,
                gpu_only=1):
    """Creates the Tensorflow graph.

  This will only be called once in the life of
  a training model, because after the graph is created the model will be
  restored from a meta graph file rather than being recreated.

  Args:
    reader: The data file reader. It should inherit from BaseReader.
    model: The core model (e.g. logistic or neural net). It should inherit
           from BaseModel.
    train_data_pattern: glob path to the training data files.
    train_data_pattern2: glob path to a second set of training data files.
    train_data_pattern3: glob path to a third set of training data files.
    eval_data_pattern: glob path to the evaluation data files.
    label_loss_fn: What kind of loss to apply to the model. It should inherit
                from BaseLoss.
    batch_size: How many examples to process at a time.
    base_learning_rate: What learning rate to initialize the optimizer with.
    optimizer_class: Which optimization algorithm to use.
    clip_gradient_norm: Magnitude of the gradient to clip to.
    regularization_penalty: How much weight to give the regularization loss
                            compared to the label loss.
    num_readers: How many threads to use for I/O operations.
    num_epochs: How many passes to make over the data. 'None' means an
                unlimited number of passes.
    l2_penalty: How much weight to give the L2 regularization inside the
                model.
    gpu_only: Unused in this function.
  """
    # data files
    files1 = gfile.Glob(train_data_pattern)
    files2 = gfile.Glob(train_data_pattern2)
    files3 = gfile.Glob(train_data_pattern3)
    files = files1 + files2 + files3
    if not files:
        raise IOError("Unable to find training files. data_pattern='" +
                      data_pattern + "'.")
    logging.info("Total number of training files: %s + %s + %s =  %s.",
                 str(len(files1)), str(len(files2)), str(len(files3)),
                 str(len(files)))

    files4 = gfile.Glob(eval_data_pattern)
    logging.info("Total number of eval files: %s.", str(len(files4)))

    if FLAGS.fold == -1:
        validate_files = files4
        train_files = files
    else:
        validate_files = files[FLAGS.fold::5]
        train_files = [x for x in files if x not in validate_files]

    logging.info("train files: {}, first is: {}.".format(
        len(train_files), train_files[0].split('/')[-1]))
    logging.info("eval files: {}, first is: {}.".format(
        len(validate_files), validate_files[0].split('/')[-1]))

    # Label weights for the loss function (hard-coded for now).
    wgts_np = np.ones(FLAGS.truncated_num_classes)
    over_weight_labels = False
    if over_weight_labels:
        labels_to_overwgt = [
            38, 47, 49, 55, 72, 76, 86, 89, 93, 94, 95, 98, 99, 101, 102, 110,
            111, 113, 114, 115, 120, 121
        ]
        wgts_np[labels_to_overwgt] = 2.0
    wgts_4_lossfn = tf.constant(wgts_np, dtype=tf.float32)

    global_step = tf.Variable(0, trainable=False, name="global_step")
    restart_learning_rate = tf.Variable(base_learning_rate,
                                        trainable=False,
                                        name="restart_learning_rate")

    local_device_protos = device_lib.list_local_devices()
    gpus = [x.name for x in local_device_protos if x.device_type == 'GPU']
    num_gpus = len(gpus)

    if num_gpus > 0:
        logging.info("Using the following GPUs to train: " + str(gpus))
        num_towers = num_gpus
        device_string = '/gpu:%d'
    else:
        logging.info("No GPUs found. Training on CPU.")
        num_towers = 1
        device_string = '/cpu:%d'

    learning_rate = tf.train.exponential_decay(restart_learning_rate,
                                               global_step * batch_size *
                                               num_towers,
                                               learning_rate_decay_examples,
                                               learning_rate_decay,
                                               staircase=True)
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = optimizer_class(learning_rate)
    unused_video_id, model_input_raw, labels_batch, num_frames = (
        get_input_data_tensors(reader,
                               train_files,
                               batch_size=batch_size * num_towers,
                               num_readers=num_readers,
                               num_epochs=num_epochs))
    tf.summary.histogram("model/input_raw", model_input_raw)

    # Model params: per-layer dropout keep probabilities, assuming at most 10
    # layers; the defaults below keep every neuron.
    with tf.variable_scope("tower", reuse=True) as scope:
        layers_keep_probs = tf.Variable(
            [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
            trainable=False,
            name="layers_keep_probs")
    model_input = model_input_raw
    if FLAGS.apply_global_normalization:
        g_mean, g_std = model_utils.load_global_moments()
        g_inv_std = 1.0 / g_std
        global_mean = tf.constant(g_mean, dtype=tf.float32)
        # expand global mean to match new dimension and fill rest with zeros
        new_dim = tf.cast(model_input_raw.shape[1], tf.int32)
        zero_padding = tf.zeros(new_dim - tf.shape(global_mean), tf.float32)
        global_mean_padded = tf.concat([global_mean, zero_padding], 0)
        # expand global inv std to match new dimension and fill rest with ones
        global_inv_std = tf.constant(g_inv_std, dtype=tf.float32)
        one_padding = tf.ones(new_dim - tf.shape(global_inv_std), tf.float32)
        global_inv_std_padded = tf.concat([global_inv_std, one_padding], 0)
        # apply normalizations (can do both) if requested
        # global L2 normalization
        model_input = tf.multiply(tf.subtract(model_input, global_mean_padded),
                                  global_inv_std_padded)
    # regular L2 normalization
    if FLAGS.apply_batch_l2_normalization:
        feature_dim = len(model_input.get_shape()) - 1
        model_input = tf.nn.l2_normalize(model_input, feature_dim)

    tower_inputs = tf.split(model_input, num_towers)
    tower_labels = tf.split(labels_batch, num_towers)
    tower_num_frames = tf.split(num_frames, num_towers)
    tower_gradients = []
    tower_predictions = []
    tower_label_losses = []
    tower_reg_losses = []

    # eval graph - to monitor performance out of sample during training
    e_video_id, e_input_raw, e_labels_batch, e_num_frames = (
        get_input_data_tensors(reader,
                               validate_files,
                               batch_size=batch_size * num_towers,
                               num_readers=num_readers,
                               num_epochs=2 * num_epochs if num_epochs else None))
    e_input = e_input_raw
    if FLAGS.apply_global_normalization:
        e_input = tf.multiply(tf.subtract(e_input, global_mean_padded),
                              global_inv_std_padded)
    if FLAGS.apply_batch_l2_normalization:
        feature_dim = len(model_input.get_shape()) - 1
        e_input = tf.nn.l2_normalize(e_input, feature_dim)

    e_tower_inputs = tf.split(e_input, num_towers)
    e_tower_labels = tf.split(e_labels_batch, num_towers)
    e_tower_num_frames = tf.split(e_num_frames, num_towers)
    e_tower_predictions = []
    e_tower_layers_keep_probs = tf.Variable(
        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        trainable=False,
        name="layers_keep_probs")
    logging.info(e_tower_inputs)
    # end eval
    for i in range(num_towers):
        # For some reason these 'with' statements can't be combined onto the same
        # line. They have to be nested.
        logging.info('For tower: ' + str(i))
        with tf.device(device_string % i):
            with (tf.variable_scope(("tower"), reuse=True if i > 0 else None)):
                with (slim.arg_scope(
                    [slim.model_variable, slim.variable],
                        device="/cpu:0" if num_gpus != 1 else "/gpu:0")):
                    logging.info(layers_keep_probs)
                    result = model.create_model(
                        tower_inputs[i],
                        num_frames=tower_num_frames[i],
                        vocab_size=reader.num_classes,
                        labels=tower_labels[i],
                        layers_keep_probs=layers_keep_probs,
                        l2_penalty=l2_penalty,
                        is_training=True)
                    for variable in slim.get_model_variables():
                        logging.info(variable)
                        tf.summary.histogram(variable.op.name, variable)

                    # create shadow moving average model variables
                    if FLAGS.use_ema:
                        model_vars = slim.get_model_variables()
                        ema = tf.train.ExponentialMovingAverage(
                            decay=1.0 - 1.0 / FLAGS.ema_halflife)
                        ema_op = ema.apply(model_vars)
                        logging.info("model_vars:")
                        logging.info(" || ".join([str(x) for x in model_vars]))
                        ema_vars = [ema.average(x) for x in model_vars]
                        ema_vars_pair_dict = {
                            ema.average_name(x): x.op.name
                            for x in model_vars
                        }
                        logging.info("ema_vars_pair_dict:")
                        for x, y in ema_vars_pair_dict.items():
                            logging.info(x + ': ' + y)
                        for v in ema_vars:
                            tf.summary.histogram(v.op.name, v)
                        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema_op)
                        tf.add_to_collection("ema_op", ema_op)

                    predictions = result["predictions"]
                    tower_predictions.append(predictions)

                    if "loss" in result.keys():
                        label_loss = result["loss"]
                    else:
                        label_loss = label_loss_fn.calculate_loss(
                            predictions, tower_labels[i], FLAGS.loss_epsilon)

                    if "regularization_loss" in result.keys():
                        reg_loss = result["regularization_loss"]
                    else:
                        reg_loss = tf.constant(0.0)

                    reg_losses = tf.losses.get_regularization_losses()
                    if reg_losses:
                        reg_loss += tf.add_n(reg_losses)

                    tower_reg_losses.append(reg_loss)

                    # Adds update_ops (e.g., moving average updates in batch normalization) as
                    # a dependency to the train_op.
                    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                    if "update_ops" in result.keys():
                        update_ops += result["update_ops"]
                    if update_ops:
                        with tf.control_dependencies(update_ops):
                            barrier = tf.no_op(name="gradient_barrier")
                            with tf.control_dependencies([barrier]):
                                label_loss = tf.identity(label_loss)

                    tower_label_losses.append(label_loss)

                    # Incorporate the L2 weight penalties etc.
                    final_loss = regularization_penalty * reg_loss + label_loss
                    gradients = optimizer.compute_gradients(
                        final_loss, colocate_gradients_with_ops=False)
                    tower_gradients.append(gradients)

                    # eval ops
                    logging.info("eval ops")
                    e_result = model.create_model(
                        e_tower_inputs[i],
                        num_frames=e_tower_num_frames[i],
                        vocab_size=reader.num_classes,
                        labels=e_tower_labels[i],
                        layers_keep_probs=e_tower_layers_keep_probs,
                        l2_penalty=l2_penalty,
                        is_training=False)

                    e_predictions = e_result["predictions"]
                    e_tower_predictions.append(e_predictions)
                    # end eval ops

    label_loss = tf.reduce_mean(tf.stack(tower_label_losses))
    tf.summary.scalar("label_loss", label_loss)
    if regularization_penalty != 0:
        reg_loss = tf.reduce_mean(tf.stack(tower_reg_losses))
        tf.summary.scalar("reg_loss", reg_loss)
    merged_gradients = utils.combine_gradients(tower_gradients)

    if clip_gradient_norm > 0:
        with tf.name_scope('clip_grads'):
            merged_gradients = utils.clip_gradient_norms(
                merged_gradients, clip_gradient_norm)

    train_op = optimizer.apply_gradients(merged_gradients,
                                         global_step=global_step)

    tf.add_to_collection("global_step", global_step)
    tf.add_to_collection("restart_learning_rate", restart_learning_rate)
    tf.add_to_collection("layers_keep_probs", layers_keep_probs)
    tf.add_to_collection("loss", label_loss)
    tf.add_to_collection("predictions", tf.concat(tower_predictions, 0))
    tf.add_to_collection("input_batch_raw", model_input_raw)
    tf.add_to_collection("input_batch", model_input)
    tf.add_to_collection("num_frames", num_frames)
    tf.add_to_collection("labels", tf.cast(labels_batch, tf.float32))
    tf.add_to_collection("train_op", train_op)
    #tf.add_to_collection("ema_op", ema_op)

    # add eval graph
    e_label_loss = label_loss_fn.calculate_loss(
        tf.concat(e_tower_predictions, 0), e_labels_batch, FLAGS.loss_epsilon)
    tf.summary.scalar("e_label_loss", e_label_loss)

    tf.add_to_collection("e_predictions", tf.concat(e_tower_predictions, 0))
    tf.add_to_collection("e_labels", tf.cast(e_labels_batch, tf.float32))
    tf.add_to_collection("e_loss", e_label_loss)
Example #4
def model_fn(features, labels, mode, params):

    is_training = mode == learn.ModeKeys.TRAIN
    optimizer_class = find_class_by_name(params.optimizer, [tf.train])
    label_loss_fn = find_class_by_name(params.label_loss, [losses])()
    model = find_class_by_name(params.model,
                               [frame_level_models, video_level_models])()

    global_step = tf.train.get_or_create_global_step()
    learning_rate = tf.train.exponential_decay(
        params.base_learning_rate,
        global_step * params.batch_size * params.num_towers,
        params.learning_rate_decay_examples,
        params.learning_rate_decay,
        staircase=True,
    )

    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = optimizer_class(learning_rate)

    tf.summary.histogram("model/input_raw", features['model_input'])

    feature_dim = len(features['model_input'].get_shape()) - 1

    model_input = tf.nn.l2_normalize(features['model_input'], feature_dim)

    tower_inputs = tf.split(model_input, params.num_towers)

    if mode == learn.ModeKeys.INFER:
        # Quick hack so that this model_fn code, taken from train.py, does not
        # break in inference (serving) mode. Normally model_fn would be
        # written so that the 'labels' arg can be None at inference time, but
        # this code was not. See serving_input_fn() below, where
        # 'labels_batch' is added to the features dict to make this work.
        labels = features['labels_batch']

    tower_labels = tf.split(labels, params.num_towers)

    tower_num_frames = tf.split(features['num_frames'], params.num_towers)
    tower_gradients = []
    tower_predictions = []
    tower_label_losses = []
    tower_reg_losses = []

    for i in range(params.num_towers):
        # For some reason these 'with' statements can't be combined onto the same
        # line. They have to be nested.
        with tf.device(params.device_string % i):
            with (tf.variable_scope(("tower"), reuse=True if i > 0 else None)):
                with (slim.arg_scope([slim.model_variable, slim.variable],
                                     device="/cpu:0"
                                     if params.num_gpus != 1 else "/gpu:0")):
                    result = model.create_model(
                        tower_inputs[i],
                        num_frames=tower_num_frames[i],
                        vocab_size=params.reader.num_classes,
                        labels=tower_labels[i],
                        is_training=is_training)
                    for variable in slim.get_model_variables():
                        tf.summary.histogram(variable.op.name, variable)

                    predictions = result["predictions"]

                    tower_predictions.append(predictions)

                    if "loss" in result.keys():
                        label_loss = result["loss"]
                    else:
                        label_loss = label_loss_fn.calculate_loss(
                            predictions, tower_labels[i])

                    if "regularization_loss" in result.keys():
                        reg_loss = result["regularization_loss"]
                    else:
                        reg_loss = tf.constant(0.0)

                    reg_losses = tf.losses.get_regularization_losses()
                    if reg_losses:
                        reg_loss += tf.add_n(reg_losses)

                    tower_reg_losses.append(reg_loss)

                    # Adds update_ops (e.g., moving average updates in batch normalization) as
                    # a dependency to the train_op.
                    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                    if "update_ops" in result.keys():
                        update_ops += result["update_ops"]
                    if update_ops:
                        with tf.control_dependencies(update_ops):
                            barrier = tf.no_op(name="gradient_barrier")
                            with tf.control_dependencies([barrier]):
                                label_loss = tf.identity(label_loss)

                    tower_label_losses.append(label_loss)

                    final_loss = params.regularization_penalty * reg_loss + label_loss
                    gradients = optimizer.compute_gradients(
                        final_loss, colocate_gradients_with_ops=False)
                    tower_gradients.append(gradients)

    pred_dict = {}
    label_loss = tf.reduce_mean(tf.stack(tower_label_losses))
    predictions = tf.concat(tower_predictions, 0)
    pred_dict['predictions'] = predictions
    tf.summary.scalar("label_loss", label_loss)
    if params.regularization_penalty != 0:
        reg_loss = tf.reduce_mean(tf.stack(tower_reg_losses))
        tf.summary.scalar("reg_loss", reg_loss)

    if is_training:
        merged_gradients = utils.combine_gradients(tower_gradients)
        if params.clip_gradient_norm > 0:
            with tf.name_scope('clip_grads'):
                merged_gradients = utils.clip_gradient_norms(
                    merged_gradients, params.clip_gradient_norm)
        train_op = optimizer.apply_gradients(merged_gradients,
                                             global_step=global_step)
    else:
        train_op = None

    eval_metric_ops = {}
    if mode == learn.ModeKeys.EVAL or is_training:

        eval_metric_ops['hit_at_one'] = metrics.streaming_mean(
            tf.py_func(
                lambda x, y: np.float32(eval_util.calculate_hit_at_one(x, y)),
                [predictions, labels],
                tf.float32,
                stateful=False,
            ))
        eval_metric_ops['perr'] = metrics.streaming_mean(
            tf.py_func(
                lambda x, y: np.float32(
                    eval_util.calculate_precision_at_equal_recall_rate(x, y)),
                [predictions, labels],
                tf.float32,
                stateful=False,
            ))
        eval_metric_ops['gap'] = metrics.streaming_mean(
            tf.py_func(
                lambda x, y: np.float32(eval_util.calculate_gap(x, y)),
                [predictions, labels],
                tf.float32,
                stateful=False,
            ))

    top_predictions, top_indices = tf.nn.top_k(predictions,
                                               _TOP_PREDICTIONS_IN_OUTPUT)

    pred_dict['top_predictions'] = top_predictions
    pred_dict['top_indices'] = top_indices

    # Add eval summaries and update ops for training.
    for key, val in eval_metric_ops.items():
        # Create a summary for each eval metric.
        tf.summary.scalar(key, val[0])
        # Add each metric's update op to the update-ops collection so that it
        # runs on every train_op call.
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, val[1])

    #  tf.add_to_collection("global_step", global_step)
    #  tf.add_to_collection("loss", label_loss)
    tf.add_to_collection("predictions", tf.concat(tower_predictions, 0))
    #  tf.add_to_collection("input_batch_raw", model_input_raw)
    #  tf.add_to_collection("input_batch", model_input)
    #  tf.add_to_collection("num_frames", num_frames)
    tf.add_to_collection("labels", tf.cast(labels, tf.float32))
    #  tf.add_to_collection("train_op", train_op)
    tf.summary.scalar("loss", label_loss)

    export_outputs = {
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
        tf.estimator.export.PredictOutput(pred_dict)
    }

    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=pred_dict,
                                      loss=label_loss,
                                      train_op=train_op,
                                      export_outputs=export_outputs,
                                      eval_metric_ops=eval_metric_ops)
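
The eval metrics in this model_fn wrap plain-NumPy functions from eval_util
in tf.py_func and stream their means with metrics.streaming_mean. For
reference, a minimal NumPy version of hit-at-one, assuming the usual
YouTube-8M definition (the fraction of examples whose single top-scoring
class is among the true labels):

import numpy as np

def hit_at_one(predictions, labels):
    top = np.argmax(predictions, axis=1)
    return np.mean(labels[np.arange(len(top)), top] > 0)

preds = np.array([[0.1, 0.8, 0.1],
                  [0.6, 0.3, 0.1]])
labels = np.array([[0, 1, 0],
                   [0, 1, 1]])
print(hit_at_one(preds, labels))  # 0.5: only the first example hits
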
Example #5
def build_graph(reader,
                generator_model,
                discriminator_model,
                train_data_pattern,
                label_loss_fn=losses.CrossEntropyLoss(),
                batch_size=1000,
                base_learning_rate=0.01,
                learning_rate_decay_examples=1000000,
                learning_rate_decay=0.95,
                optimizer_class=tf.train.AdamOptimizer,
                clip_gradient_norm=1.0,
                regularization_penalty=1,
                num_readers=1,
                num_epochs=None):
    """Creates the Tensorflow graph.

  This will only be called once in the life of
  a training model, because after the graph is created the model will be
  restored from a meta graph file rather than being recreated.

  Args:
    reader: The data file reader. It should inherit from BaseReader.
    generator_model: The core model for generator. It should inherit from
                     BaseModel.
    discriminator_model: The core model for discriminator. It should inherit from
                         BaseModel.
    train_data_pattern: glob path to the training data files.
    label_loss_fn: What kind of loss to apply to the model. It should inherit
                from BaseLoss.
    batch_size: How many examples to process at a time.
    base_learning_rate: What learning rate to initialize the optimizer with.
    optimizer_class: Which optimization algorithm to use.
    clip_gradient_norm: Magnitude of the gradient to clip to.
    regularization_penalty: How much weight to give the regularization loss
                            compared to the label loss.
    num_readers: How many threads to use for I/O operations.
    num_epochs: How many passes to make over the data. 'None' means an
                unlimited number of passes.
  """

    global_step = tf.Variable(0, trainable=False, name="global_step")

    gpus = get_gpus()
    num_gpus = len(gpus)

    if num_gpus > 0:
        logging.info("Using the following GPUs to train: " + str(gpus))
        num_towers = num_gpus
        device_string = '/gpu:%d'
    else:
        logging.info("No GPUs found. Training on CPU.")
        num_towers = 1
        device_string = '/cpu:%d'

    learning_rate = tf.train.exponential_decay(base_learning_rate,
                                               global_step * batch_size *
                                               num_towers,
                                               learning_rate_decay_examples,
                                               learning_rate_decay,
                                               staircase=True)
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = optimizer_class(learning_rate)

    model_input_raw, _ = (get_input_data_tensors(reader,
                                                 train_data_pattern,
                                                 batch_size=batch_size *
                                                 num_towers,
                                                 num_readers=num_readers,
                                                 num_epochs=num_epochs))
    tf.summary.histogram("model/input_raw", model_input_raw)
    model_input = model_input_raw

    noise_input = tf.placeholder(
        tf.float32, shape=[None, random_noise_generator.get_dim()])

    image_width, image_height = reader.get_image_size()

    tower_inputs = tf.split(model_input, num_towers)
    tower_noise_input = tf.split(noise_input, num_towers)
    tower_D_gradients = []
    tower_G_gradients = []
    tower_generated_images = []
    tower_predictions_for_fake = []
    tower_predictions_for_real = []
    tower_D_losses = []
    tower_G_losses = []

    for i in range(num_towers):
        # For some reason these 'with' statements can't be combined onto the same
        # line. They have to be nested.
        with tf.device(device_string % i):
            with (tf.variable_scope(("tower"), reuse=True if i > 0 else None)):
                with (slim.arg_scope(
                    [slim.model_variable, slim.variable],
                        device="/cpu:0" if num_gpus != 1 else "/gpu:0")):
                    generator_model.create_model(image_width * image_height)
                    discriminator_model.create_model(image_width *
                                                     image_height)

                    generated_result = generator_model.run_model(
                        tower_noise_input[i])
                    generated_images = generated_result["output"]

                    generated_images_shaped = tf.reshape(
                        generated_images, [-1, image_height, image_width, 1])
                    tf.summary.image('generated_images',
                                     generated_images_shaped, 10)
                    tower_generated_images.append(generated_images)

                    result_from_fake = discriminator_model.run_model(
                        generated_images)
                    result_from_real = discriminator_model.run_model(
                        tower_inputs[i])
                    for variable in slim.get_model_variables():
                        tf.summary.histogram(variable.op.name, variable)

                    predictions_for_fake = result_from_fake["predictions"]
                    predictions_for_real = result_from_real["predictions"]
                    tower_predictions_for_fake.append(predictions_for_fake)
                    tower_predictions_for_real.append(predictions_for_real)

                    logits_for_fake = result_from_fake["logits"]
                    logits_for_real = result_from_real["logits"]
                    D_loss_fake = label_loss_fn.calculate_loss(
                        logits_for_fake, tf.zeros_like(logits_for_fake))
                    D_loss_real = label_loss_fn.calculate_loss(
                        logits_for_real, tf.ones_like(logits_for_real))
                    D_loss = D_loss_fake + D_loss_real
                    tower_D_losses.append(D_loss)

                    G_loss = label_loss_fn.calculate_loss(
                        logits_for_fake, tf.ones_like(logits_for_fake))
                    tower_G_losses.append(G_loss)

                    D_var = discriminator_model.get_variables()
                    D_gradients = optimizer.compute_gradients(D_loss,
                                                              var_list=D_var)
                    tower_D_gradients.append(D_gradients)

                    G_var = generator_model.get_variables()
                    G_gradients = optimizer.compute_gradients(G_loss,
                                                              var_list=G_var)
                    tower_G_gradients.append(G_gradients)

    D_loss = tf.reduce_mean(tf.stack(tower_D_losses))
    G_loss = tf.reduce_mean(tf.stack(tower_G_losses))
    tf.summary.scalar("D_loss", D_loss)
    tf.summary.scalar("G_loss", G_loss)
    merged_D_gradients = utils.combine_gradients(tower_D_gradients)
    merged_G_gradients = utils.combine_gradients(tower_G_gradients)

    if clip_gradient_norm > 0:
        with tf.name_scope('clip_grads'):
            merged_D_gradients = utils.clip_gradient_norms(
                merged_D_gradients, clip_gradient_norm)
            merged_G_gradients = utils.clip_gradient_norms(
                merged_G_gradients, clip_gradient_norm)

    # Attach global_step to only one apply_gradients call so it is
    # incremented once per combined D/G step.
    D_train_op = optimizer.apply_gradients(merged_D_gradients)
    G_train_op = optimizer.apply_gradients(merged_G_gradients,
                                           global_step=global_step)

    tf.add_to_collection("global_step", global_step)
    tf.add_to_collection("D_loss", D_loss)
    tf.add_to_collection("G_loss", G_loss)
    tf.add_to_collection("p_for_fake", tf.concat(tower_predictions_for_fake,
                                                 0))
    tf.add_to_collection("p_for_data", tf.concat(tower_predictions_for_real,
                                                 0))
    tf.add_to_collection("input_batch_raw", model_input_raw)
    tf.add_to_collection("input_batch", model_input)
    tf.add_to_collection("generated_images",
                         tf.concat(tower_generated_images, 0))
    tf.add_to_collection("D_train_op", D_train_op)
    tf.add_to_collection("G_train_op", G_train_op)
    tf.add_to_collection("noise_input_placeholder", noise_input)
Example #6
def build_graph(reader,
                model,
                train_data_pattern,
                label_loss_fn=losses.CrossEntropyLoss(),
                batch_size=1000,
                base_learning_rate=0.01,
                learning_rate_decay_examples=1000000,
                learning_rate_decay=0.95,
                optimizer_class=tf.train.AdamOptimizer,
                clip_gradient_norm=1.0,
                regularization_penalty=1,
                num_readers=1,
                num_epochs=None):
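    """Creates the TensorFlow graph.

    Arguments mirror Example #1, except that every detected GPU is used
    rather than capping the count at FLAGS.num_gpu.
    """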

    global_step = tf.Variable(0, trainable=False, name="global_step")

    local_device_protos = device_lib.list_local_devices()
    gpus = [x.name for x in local_device_protos if x.device_type == 'GPU']
    num_gpus = len(gpus)

    if num_gpus > 0:
        logging.info("Using the following GPUs to train: " + str(gpus))
        num_towers = num_gpus
        device_string = '/gpu:%d'
    else:
        logging.info("No GPUs found. Training on CPU.")
        num_towers = 1
        device_string = '/cpu:%d'

    learning_rate = tf.train.exponential_decay(base_learning_rate,
                                               global_step * batch_size *
                                               num_towers,
                                               learning_rate_decay_examples,
                                               learning_rate_decay,
                                               staircase=True)
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = optimizer_class(learning_rate)
    unused_video_id, model_input_raw, labels_batch, num_frames = (
        get_input_data_tensors(reader,
                               train_data_pattern,
                               batch_size=batch_size * num_towers,
                               num_readers=num_readers,
                               num_epochs=num_epochs))
    tf.summary.histogram("model/input_raw", model_input_raw)

    feature_dim = len(model_input_raw.get_shape()) - 1

    model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)

    tower_inputs = tf.split(model_input, num_towers)
    tower_labels = tf.split(labels_batch, num_towers)
    tower_num_frames = tf.split(num_frames, num_towers)
    tower_gradients = []
    tower_predictions = []
    tower_label_losses = []
    tower_reg_losses = []
    for i in range(num_towers):
        # For some reason these 'with' statements can't be combined onto the same
        # line. They have to be nested.
        with tf.device(device_string % i):
            with (tf.variable_scope(("tower"), reuse=True if i > 0 else None)):
                with (slim.arg_scope(
                    [slim.model_variable, slim.variable],
                        device="/cpu:0" if num_gpus != 1 else "/gpu:0")):
                    result = model.create_model(tower_inputs[i],
                                                num_frames=tower_num_frames[i],
                                                vocab_size=reader.num_classes,
                                                labels=tower_labels[i])
                    for variable in slim.get_model_variables():
                        tf.summary.histogram(variable.op.name, variable)

                    predictions = result["predictions"]
                    tower_predictions.append(predictions)

                    if "loss" in result.keys():
                        label_loss = result["loss"]
                    else:
                        label_loss = label_loss_fn.calculate_loss(
                            predictions, tower_labels[i])

                    if "regularization_loss" in result.keys():
                        reg_loss = result["regularization_loss"]
                    else:
                        reg_loss = tf.constant(0.0)

                    reg_losses = tf.losses.get_regularization_losses()
                    if reg_losses:
                        reg_loss += tf.add_n(reg_losses)

                    tower_reg_losses.append(reg_loss)

                    # Adds update_ops (e.g., moving average updates in batch normalization) as
                    # a dependency to the train_op.
                    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                    if "update_ops" in result.keys():
                        update_ops += result["update_ops"]
                    if update_ops:
                        with tf.control_dependencies(update_ops):
                            barrier = tf.no_op(name="gradient_barrier")
                            with tf.control_dependencies([barrier]):
                                label_loss = tf.identity(label_loss)

                    tower_label_losses.append(label_loss)

                    # Incorporate the L2 weight penalties etc.
                    final_loss = regularization_penalty * reg_loss + label_loss
                    gradients = optimizer.compute_gradients(
                        final_loss, colocate_gradients_with_ops=False)
                    tower_gradients.append(gradients)
    label_loss = tf.reduce_mean(tf.stack(tower_label_losses))
    tf.summary.scalar("label_loss", label_loss)
    if regularization_penalty != 0:
        reg_loss = tf.reduce_mean(tf.stack(tower_reg_losses))
        tf.summary.scalar("reg_loss", reg_loss)
    merged_gradients = utils.combine_gradients(tower_gradients)

    if clip_gradient_norm > 0:
        with tf.name_scope('clip_grads'):
            merged_gradients = utils.clip_gradient_norms(
                merged_gradients, clip_gradient_norm)

    train_op = optimizer.apply_gradients(merged_gradients,
                                         global_step=global_step)

    tf.add_to_collection("global_step", global_step)
    tf.add_to_collection("loss", label_loss)
    tf.add_to_collection("predictions", tf.concat(tower_predictions, 0))
    tf.add_to_collection("input_batch_raw", model_input_raw)
    tf.add_to_collection("input_batch", model_input)
    tf.add_to_collection("num_frames", num_frames)
    tf.add_to_collection("labels", tf.cast(labels_batch, tf.float32))
    tf.add_to_collection("train_op", train_op)
Example #7
def build_graph(reader,
                model,
                train_data_pattern,
                label_loss_fn=losses.CrossEntropyLoss(),
                batch_size=1000,
                base_learning_rate=0.01,
                learning_rate_decay_examples=1000000,
                learning_rate_decay=0.95,
                optimizer_class=tf.train.AdamOptimizer,
                clip_gradient_norm=1.0,
                regularization_penalty=1,
                num_readers=1,
                num_epochs=None):
  """Creates the Tensorflow graph.

  This will only be called once in the life of
  a training model, because after the graph is created the model will be
  restored from a meta graph file rather than being recreated.

  Args:
    reader: The data file reader. It should inherit from BaseReader.
    model: The core model (e.g. logistic or neural net). It should inherit
           from BaseModel.
    train_data_pattern: glob path to the training data files.
    label_loss_fn: What kind of loss to apply to the model. It should inherit
                from BaseLoss.
    batch_size: How many examples to process at a time.
    base_learning_rate: What learning rate to initialize the optimizer with.
    optimizer_class: Which optimization algorithm to use.
    clip_gradient_norm: Magnitude of the gradient to clip to.
    regularization_penalty: How much weight to give the regularization loss
                            compared to the label loss.
    num_readers: How many threads to use for I/O operations.
    num_epochs: How many passes to make over the data. 'None' means an
                unlimited number of passes.
  """

  global_step = tf.Variable(0, trainable=False, name="global_step")

  local_device_protos = device_lib.list_local_devices()
  gpus = [x.name for x in local_device_protos if x.device_type == 'GPU']
  gpus = gpus[:FLAGS.num_gpu]
  num_gpus = len(gpus)

  if num_gpus > 0:
    logging.info("Using the following GPUs to train: " + str(gpus))
    num_towers = num_gpus
    device_string = '/gpu:%d'
  else:
    logging.info("No GPUs found. Training on CPU.")
    num_towers = 1
    device_string = '/cpu:%d'

  learning_rate = tf.train.exponential_decay(
      base_learning_rate,
      global_step * batch_size * num_towers,
      learning_rate_decay_examples,
      learning_rate_decay,
      staircase=True)
  tf.summary.scalar('learning_rate', learning_rate)

  optimizer = optimizer_class(learning_rate)
  unused_video_id, model_input_raw, labels_batch, num_frames = (
      get_input_data_tensors(
          reader,
          train_data_pattern,
          batch_size=batch_size * num_towers,
          num_readers=num_readers,
          num_epochs=num_epochs))
  tf.summary.histogram("model/input_raw", model_input_raw)

  feature_dim = len(model_input_raw.get_shape()) - 1

  model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)

  tower_inputs = tf.split(model_input, num_towers)
  tower_labels = tf.split(labels_batch, num_towers)
  tower_num_frames = tf.split(num_frames, num_towers)
  tower_gradients = []
  tower_predictions = []
  tower_label_losses = []
  tower_reg_losses = []
  for i in range(num_towers):
    # For some reason these 'with' statements can't be combined onto the same
    # line. They have to be nested.
    with tf.device(device_string % i):
      with (tf.variable_scope(("tower"), reuse=True if i > 0 else None)):
        with (slim.arg_scope([slim.model_variable, slim.variable], device="/cpu:0" if num_gpus!=1 else "/gpu:0")):
          result = model.create_model(
            tower_inputs[i],
            num_frames=tower_num_frames[i],
            vocab_size=reader.num_classes,
            labels=tower_labels[i])
          for variable in slim.get_model_variables():
            tf.summary.histogram(variable.op.name, variable)

          predictions = result["predictions"]
          tower_predictions.append(predictions)

          if "loss" in result.keys():
            label_loss = result["loss"]
          else:
            label_loss = label_loss_fn.calculate_loss(predictions, tower_labels[i])

          if "regularization_loss" in result.keys():
            reg_loss = result["regularization_loss"]
          else:
            reg_loss = tf.constant(0.0)

          reg_losses = tf.losses.get_regularization_losses()
          if reg_losses:
            reg_loss += tf.add_n(reg_losses)

          tower_reg_losses.append(reg_loss)

          # Adds update_ops (e.g., moving average updates in batch normalization) as
          # a dependency to the train_op.
          update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
          if "update_ops" in result.keys():
            update_ops += result["update_ops"]
          if update_ops:
            with tf.control_dependencies(update_ops):
              barrier = tf.no_op(name="gradient_barrier")
              with tf.control_dependencies([barrier]):
                label_loss = tf.identity(label_loss)

          tower_label_losses.append(label_loss)

          # Incorporate the L2 weight penalties etc.
          final_loss = regularization_penalty * reg_loss + label_loss
          gradients = optimizer.compute_gradients(final_loss,
              colocate_gradients_with_ops=False)
          tower_gradients.append(gradients)
  label_loss = tf.reduce_mean(tf.stack(tower_label_losses))
  tf.summary.scalar("label_loss", label_loss)
  if regularization_penalty != 0:
    reg_loss = tf.reduce_mean(tf.stack(tower_reg_losses))
    tf.summary.scalar("reg_loss", reg_loss)
  merged_gradients = utils.combine_gradients(tower_gradients)

  if clip_gradient_norm > 0:
    with tf.name_scope('clip_grads'):
      merged_gradients = utils.clip_gradient_norms(merged_gradients, clip_gradient_norm)

  train_op = optimizer.apply_gradients(merged_gradients, global_step=global_step)

  tf.add_to_collection("global_step", global_step)
  tf.add_to_collection("loss", label_loss)
  tf.add_to_collection("predictions", tf.concat(tower_predictions, 0))
  tf.add_to_collection("input_batch_raw", model_input_raw)
  tf.add_to_collection("input_batch", model_input)
  tf.add_to_collection("num_frames", num_frames)
  tf.add_to_collection("labels", tf.cast(labels_batch, tf.float32))
  tf.add_to_collection("train_op", train_op)