Ejemplo n.º 1
0
def model_fn(features, labels, mode, params):
    """The model_fn to be used with TPUEstimator.

    Normalizes the input images, builds the network selected by
    `FLAGS.model_name`, and returns the spec for the requested mode:
    predictions for PREDICT, loss plus train op for TRAIN, and loss plus
    accuracy/AUC metrics for EVAL.

    Args:
      features: `Tensor` of batched images, or a `dict` holding that tensor
          under the key `'feature'`.
      labels: `Tensor` of one hot labels for the data samples
      mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
      params: `dict` of parameters passed to the model from the TPUEstimator,
          `params['batch_size']` is always provided and should be used as the
          effective batch size. This function also reads
          `params['use_bfloat16']` and, when training,
          `params['steps_per_epoch']`.

    Returns:
      A `TPUEstimatorSpec` for the model in TRAIN and EVAL modes, or a plain
      `tf.estimator.EstimatorSpec` in PREDICT mode.
    """
    # Accept either a bare image tensor or a dict carrying it under 'feature'.
    if isinstance(features, dict):
        features = features['feature']

    # In most cases, the default data format NCHW instead of NHWC should be
    # used for a significant performance boost on GPU. NHWC should be used
    # only if the network needs to be run on CPU since the pooling operations
    # are only supported on NHWC. TPU uses XLA compiler to figure out best layout.
    # stats_shape is the shape given to the per-channel mean/stddev constants
    # in normalize_features below, so that they line up with the channel axis.
    if FLAGS.data_format == 'channels_first':
        assert not FLAGS.transpose_input  # channels_first only for GPU
        features = tf.transpose(features, [0, 3, 1, 2])
        stats_shape = [3, 1, 1]
    else:
        stats_shape = [1, 1, 3]

    if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    has_moving_average_decay = (FLAGS.moving_average_decay > 0)
    # This is essential, if using a keras-derived model.
    tf.keras.backend.set_learning_phase(is_training)
    logging.info('Using open-source implementation.')
    # Collect the flags that are set; the resulting dict is handed to the
    # model builder as `override_params`.
    override_params = {}
    if FLAGS.batch_norm_momentum is not None:
        override_params['batch_norm_momentum'] = FLAGS.batch_norm_momentum
    if FLAGS.batch_norm_epsilon is not None:
        override_params['batch_norm_epsilon'] = FLAGS.batch_norm_epsilon
    if FLAGS.dropout_rate is not None:
        override_params['dropout_rate'] = FLAGS.dropout_rate
    if FLAGS.survival_prob is not None:
        override_params['survival_prob'] = FLAGS.survival_prob
    if FLAGS.data_format:
        override_params['data_format'] = FLAGS.data_format
    if FLAGS.num_label_classes:
        override_params['num_classes'] = FLAGS.num_label_classes
    if FLAGS.depth_coefficient:
        override_params['depth_coefficient'] = FLAGS.depth_coefficient
    if FLAGS.width_coefficient:
        override_params['width_coefficient'] = FLAGS.width_coefficient

    def normalize_features(features, mean_rgb, stddev_rgb):
        """Normalize the image given the means and stddevs.

        Subtracts `mean_rgb` and divides by `stddev_rgb`, both turned into
        constants of shape `stats_shape` (taken from the enclosing scope) and
        of the same dtype as `features`.
        """
        features -= tf.constant(mean_rgb,
                                shape=stats_shape,
                                dtype=features.dtype)
        features /= tf.constant(stddev_rgb,
                                shape=stats_shape,
                                dtype=features.dtype)
        return features

    def build_model():
        """Build model using the model_name given through the command line.

        Returns only the logits; the second value returned by the model
        builder is discarded.
        """
        model_builder = model_builder_factory.get_model_builder(
            FLAGS.model_name)
        normalized_features = normalize_features(features,
                                                 model_builder.MEAN_RGB,
                                                 model_builder.STDDEV_RGB)
        logits, _ = model_builder.build_model(normalized_features,
                                              model_name=FLAGS.model_name,
                                              training=is_training,
                                              override_params=override_params,
                                              model_dir=FLAGS.model_dir)
        return logits

    # In bfloat16 mode the model is built inside the TPU bfloat16 scope and the
    # logits are cast back to float32 before being used below.
    if params['use_bfloat16']:
        with tf.tpu.bfloat16_scope():
            logits = tf.cast(build_model(), tf.float32)
    else:
        logits = build_model()

    # PREDICT returns early: no loss, optimizer or metrics are built.
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # If necessary, in the model_fn, use params['batch_size'] instead of the
    # batch size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    # NOTE: this uses tf.losses.softmax_cross_entropy (which takes the
    # label_smoothing argument), not tf.nn.softmax_cross_entropy_with_logits.
    cross_entropy = tf.losses.softmax_cross_entropy(
        onehot_labels=labels,
        logits=logits,
        label_smoothing=FLAGS.label_smoothing)

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = cross_entropy + FLAGS.weight_decay * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    global_step = tf.train.get_global_step()
    # `ema` and `ema_vars` only exist when moving-average decay is enabled;
    # every later use is guarded by the same `has_moving_average_decay` check.
    if has_moving_average_decay:
        ema = tf.train.ExponentialMovingAverage(
            decay=FLAGS.moving_average_decay, num_updates=global_step)
        ema_vars = utils.get_ema_vars()

    host_call = None
    restore_vars_dict = None
    if is_training:
        # Compute the current epoch and associated learning rate from global_step.
        current_epoch = (tf.cast(global_step, tf.float32) /
                         params['steps_per_epoch'])

        # Scale the base learning rate linearly with the training batch size,
        # relative to a batch size of 256.
        scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)
        logging.info('base_learning_rate = %f', FLAGS.base_learning_rate)
        learning_rate = utils.build_learning_rate(scaled_lr, global_step,
                                                  params['steps_per_epoch'])
        optimizer = utils.build_optimizer(learning_rate, optimizer_name="sgd")
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)

        # Chain the moving-average update after the optimizer step, so that
        # running `train_op` performs both.
        if has_moving_average_decay:
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

        if not FLAGS.skip_host_call:

            def host_call_fn(gs, lr, ce):
                """Training host call. Creates scalar summaries for training metrics.

                This function is executed on the CPU and should not directly reference
                any Tensors in the rest of the `model_fn`. To pass Tensors from the
                model to this function, provide them as part of the `host_call`. See
                https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
                for more information.

                Arguments should match the list of `Tensor` objects passed as the second
                element in the tuple passed to `host_call`.

                Args:
                  gs: `Tensor` with shape `[batch]` for the global_step.
                  lr: `Tensor` with shape `[batch]` for the learning_rate.
                  ce: `Tensor` with shape `[batch]` for the current_epoch.

                Returns:
                  List of summary ops to run on the CPU host.
                """
                # Only the first element of each batched tensor is recorded.
                gs = gs[0]
                # Host call fns are executed FLAGS.iterations_per_loop times after one
                # TPU loop is finished, setting max_queue value to the same as number of
                # iterations will make the summary writer only flush the data to storage
                # once per loop.
                with tf2.summary.create_file_writer(
                        FLAGS.model_dir,
                        max_queue=FLAGS.iterations_per_loop).as_default():
                    with tf2.summary.record_if(True):
                        tf2.summary.scalar('learning_rate', lr[0], step=gs)
                        tf2.summary.scalar('current_epoch', ce[0], step=gs)

                        return tf.summary.all_v2_summary_ops()

            # To log the current learning rate and epoch for Tensorboard, the
            # summary op needs to be run on the host CPU via host_call. host_call
            # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            gs_t = tf.reshape(global_step, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])

            host_call = (host_call_fn, [gs_t, lr_t, ce_t])

    else:
        train_op = None
        if has_moving_average_decay:
            # Load moving average variables for eval.
            restore_vars_dict = ema.variables_to_restore(ema_vars)

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Evaluation metric function. Evaluates accuracy and AUC.

            This function is executed on the CPU and should not directly reference
            any Tensors in the rest of the `model_fn`. To pass Tensors from the model
            to the `metric_fn`, provide as part of the `eval_metrics`. See
            https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
            for more information.

            Arguments should match the list of `Tensor` objects passed as the second
            element in the tuple passed to `eval_metrics`.

            Args:
              labels: `Tensor` with shape `[batch, num_classes]`.
              logits: `Tensor` with shape `[batch, num_classes]`.

            Returns:
              A dict of the metrics to return from evaluation: `'accuracy'` is
              the result of `tf.metrics.accuracy` and `'auc'` is the Keras
              `AUC` metric object.
            """

            # TODO: change the metric here.
            labels = tf.argmax(labels, axis=1)
            predictions = tf.argmax(logits, axis=1)
            accuracy = tf.metrics.accuracy(labels, predictions)
            # NOTE(review): AUC is updated with argmax class indices for both
            # labels and predictions, not with probabilities -- confirm this is
            # intended, in particular when there are more than two classes.
            auc = tf2.keras.metrics.AUC(name='auc')  # AUC
            auc.update_state(labels, predictions)
            return {'accuracy': accuracy, 'auc': auc}

        eval_metrics = (metric_fn, [labels, logits])

    # Log the total number of elements across all trainable variables.
    num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
    logging.info('number of trainable parameters: %d', num_params)

    def _scaffold_fn():
        """Return a Scaffold whose Saver is built from `restore_vars_dict`."""
        saver = tf.train.Saver(restore_vars_dict)
        return tf.train.Scaffold(saver=saver)

    if has_moving_average_decay and not is_training:
        # Only apply scaffold for eval jobs.
        scaffold_fn = _scaffold_fn
    else:
        scaffold_fn = None

    return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                             loss=loss,
                                             train_op=train_op,
                                             host_call=host_call,
                                             eval_metrics=eval_metrics,
                                             scaffold_fn=scaffold_fn)
Ejemplo n.º 2
0
    def build(self, sampling):
        """Build the character-level multi-layer LSTM graph.

        Creates the input, target and keep-probability placeholders, the
        stacked LSTM cells with output dropout, the dense output layer, the
        softmax cross-entropy cost, and an Adam training op with gradients
        clipped by global norm. Apart from `self.initial_state` and
        `self.final_state`, the created tensors are not stored on the
        instance; they carry explicit graph names (`tf_x`, `tf_y`,
        `tf_keepprob`, `probabilities`, `cost`, `train_op`).

        Args:
          sampling: if truthy, build the graph with a batch size of 1 and a
              single step; otherwise use `self.batch_size` and
              `self.num_steps`.
        """
        if sampling:
            batch_size, num_steps = 1, 1
        else:
            batch_size = self.batch_size
            num_steps = self.num_steps

        tf_x = tf.placeholder(tf.int32,
                              shape=[batch_size, num_steps],
                              name='tf_x')
        tf_y = tf.placeholder(tf.int32,
                              shape=[batch_size, num_steps],
                              name='tf_y')

        tf_keepprob = tf.placeholder(tf.float32, name='tf_keepprob')

        # Apply one-hot encoding.
        x_onehot = tf.one_hot(tf_x, depth=self.num_classes)
        y_onehot = tf.one_hot(tf_y, depth=self.num_classes)

        # Build the multi-layer RNN cell.
        cells = tf.nn.rnn_cell.MultiRNNCell([
            tf.nn.rnn_cell.DropoutWrapper(
                tf.nn.rnn_cell.BasicLSTMCell(self.lstm_size),
                output_keep_prob=tf_keepprob) for _ in range(self.num_layers)
        ])

        # Define the initial state.
        self.initial_state = cells.zero_state(batch_size, tf.float32)

        # Run each sequence step through the RNN.
        lstm_outputs, self.final_state = tf.nn.dynamic_rnn(
            cells, x_onehot, initial_state=self.initial_state)

        print(' << lstm_outputs >> ', lstm_outputs)

        # Reshape to a 2-D tensor with one row per (sequence, step) pair.
        seq_output_reshaped = tf.reshape(lstm_outputs,
                                         shape=[-1, self.lstm_size],
                                         name='seq_output_reshaped')

        # Compute the net input (logits).
        logits = tf.layers.dense(inputs=seq_output_reshaped,
                                 units=self.num_classes,
                                 activation=None,
                                 name='logits')

        # Probabilities for the next batch of characters. The tensor is not
        # used below; it is created under the name 'probabilities'.
        proba = tf.nn.softmax(logits, name='probabilities')

        # Define the cost function.
        y_reshaped = tf.reshape(y_onehot,
                                shape=[-1, self.num_classes],
                                name='y_reshaped')

        cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
            logits=logits, labels=y_reshaped),
                              name='cost')

        # Clip gradients by global norm to avoid exploding gradients.
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
                                          self.grad_clip)

        # Define the optimizer. The op is created under the name 'train_op'.
        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        train_op = optimizer.apply_gradients(zip(grads, tvars),
                                             name='train_op')
Ejemplo n.º 3
0
    def _model(self, images, is_training, reuse=False):
        """Compute the logits given the images.

        Builds the stem convolution followed by `self.num_layers + 2` cells.
        Layers whose index is in `self.pool_layers` double the filter count
        and use the reduction architecture (`self.reduce_arc2`); all other
        layers use the normal architecture (`self.normal_arc2`). An auxiliary
        classification head is attached at the layers listed in
        `self.aux_head_indices` when `self.use_aux_heads` is set and the
        graph is built for training. The network ends with ReLU, global
        average pooling, optional dropout and a fully connected layer.

        Args:
          images: input image tensor, passed to the stem convolution with
              `data_format=self.data_format`.
          is_training: whether to build the training graph. Forced to `True`
              while `self.fixed_arc` is `None`.
          reuse: forwarded to `tf.variable_scope` for variable reuse.

        Returns:
          The logits tensor produced by the final `fc` layer (10 outputs).

        Raises:
          ValueError: if `self.data_format` is neither "NHWC" nor "NCHW".

        Side effects:
          Sets `self.num_aux_vars` (0 when no auxiliary head is built) and,
          when an auxiliary head is built, `self.aux_logits`.
        """

        if self.fixed_arc is None:
            is_training = True

        with tf.variable_scope(self.name, reuse=reuse):
            # the first two inputs
            with tf.variable_scope("stem_conv"):
                w = create_weight("w", [3, 3, 3, self.out_filters * 3])
                x = tf.nn.conv2d(images,
                                 w, [1, 1, 1, 1],
                                 "SAME",
                                 data_format=self.data_format)
                x = batch_norm(x, is_training, data_format=self.data_format)
            # The original check compared against the misspelled "NHCW", so
            # the valid "NHWC" layout was always rejected here.
            if self.data_format not in ("NHWC", "NCHW"):
                raise ValueError("Unknown data_format '{0}'".format(
                    self.data_format))
            layers = [x, x]

            # Initialised once, before the loop: resetting it on every
            # iteration would discard the count unless the auxiliary head
            # happened to sit on the very last layer.
            self.num_aux_vars = 0

            # building layers in the micro space
            out_filters = self.out_filters
            for layer_id in range(self.num_layers + 2):
                with tf.variable_scope("layer_{0}".format(layer_id)):
                    if layer_id not in self.pool_layers:
                        if self.fixed_arc is None:
                            x = self._enas_layer(layer_id, layers,
                                                 self.normal_arc2, out_filters)
                        else:
                            x = self._fixed_layer(
                                layer_id,
                                layers,
                                self.normal_arc2,
                                out_filters,
                                1,
                                is_training,
                                normal_or_reduction_cell="normal")
                    else:
                        # Reduction layer: double the number of filters.
                        out_filters *= 2
                        if self.fixed_arc is None:
                            x = self._factorized_reduction(
                                x, out_filters, 2, is_training)
                            layers = [layers[-1], x]
                            x = self._enas_layer(layer_id, layers,
                                                 self.reduce_arc2, out_filters)
                        else:
                            x = self._fixed_layer(
                                layer_id,
                                layers,
                                self.reduce_arc2,
                                out_filters,
                                2,
                                is_training,
                                normal_or_reduction_cell="reduction")
                    print("Layer {0:>2d}: {1}".format(layer_id, x))
                    # Keep only the two most recent outputs as inputs for the
                    # next layer.
                    layers = [layers[-1], x]

                # auxiliary heads
                if (self.use_aux_heads and layer_id in self.aux_head_indices
                        and is_training):
                    print("Using aux_head at layer {0}".format(layer_id))
                    with tf.variable_scope("aux_head"):
                        aux_logits = tf.nn.relu(x)
                        aux_logits = tf.layers.average_pooling2d(
                            aux_logits, [5, 5], [3, 3],
                            "VALID",
                            data_format=self.actual_data_format)
                        with tf.variable_scope("proj"):
                            inp_c = self._get_C(aux_logits)
                            w = create_weight("w", [1, 1, inp_c, 128])
                            aux_logits = tf.nn.conv2d(
                                aux_logits,
                                w, [1, 1, 1, 1],
                                "SAME",
                                data_format=self.data_format)
                            aux_logits = batch_norm(
                                aux_logits,
                                is_training=True,
                                data_format=self.data_format)
                            aux_logits = tf.nn.relu(aux_logits)

                        with tf.variable_scope("avg_pool"):
                            inp_c = self._get_C(aux_logits)
                            hw = self._get_HW(aux_logits)
                            w = create_weight("w", [hw, hw, inp_c, 768])
                            aux_logits = tf.nn.conv2d(
                                aux_logits,
                                w, [1, 1, 1, 1],
                                "SAME",
                                data_format=self.data_format)
                            aux_logits = batch_norm(
                                aux_logits,
                                is_training=True,
                                data_format=self.data_format)
                            aux_logits = tf.nn.relu(aux_logits)

                        with tf.variable_scope("fc"):
                            aux_logits = global_avg_pool(
                                aux_logits, data_format=self.data_format)
                            inp_c = aux_logits.get_shape()[1].value
                            w = create_weight("w", [inp_c, 10])
                            aux_logits = tf.matmul(aux_logits, w)
                            self.aux_logits = aux_logits

                    # Count the trainable variables of this model that belong
                    # to an auxiliary head.
                    aux_head_variables = [
                        var for var in tf.trainable_variables()
                        if (var.name.startswith(self.name)
                            and "aux_head" in var.name)
                    ]
                    self.num_aux_vars = count_model_params(aux_head_variables)
                    print("Aux head uses {0} params".format(self.num_aux_vars))

            x = tf.nn.relu(x)
            x = global_avg_pool(x, data_format=self.data_format)
            # Dropout is only applied in training and when keep_prob < 1.
            if is_training and self.keep_prob is not None and self.keep_prob < 1.0:
                x = tf.nn.dropout(x, self.keep_prob)
            with tf.variable_scope("fc"):
                inp_c = self._get_C(x)
                w = create_weight("w", [inp_c, 10])
                x = tf.matmul(x, w)
        return x
Ejemplo n.º 4
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """Model function handed to TPUEstimator for a classification task.

        Reads the input tensors from `features`, builds the model through
        `create_model`, optionally initialises variables from
        `init_checkpoint`, and returns the spec matching `mode`. Names that
        are not defined here (`albert_config`, `num_labels`, `task_name`,
        `init_checkpoint`, `use_tpu`, `learning_rate`, ...) are expected to
        come from the enclosing scope.

        Args:
          features: dict with the keys `input_ids`, `input_mask`,
              `segment_ids`, `label_ids` and, optionally, `is_real_example`.
          labels: unused.
          mode: one of the `tf.estimator.ModeKeys` values.
          params: unused.

        Returns:
          A `contrib_tpu.TPUEstimatorSpec` for training, evaluation or
          prediction.
        """

        tf.logging.info("*** Features ***")
        for feature_name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s", feature_name,
                            features[feature_name].shape)

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]

        # Every example counts as real (weight 1.0) unless the input supplies
        # an explicit `is_real_example` feature.
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        training = mode == tf.estimator.ModeKeys.TRAIN

        model_outputs = create_model(albert_config, training, input_ids,
                                     input_mask, segment_ids, label_ids,
                                     num_labels, use_one_hot_embeddings,
                                     task_name, hub_module)
        (total_loss, per_example_loss, probabilities, logits,
         predictions) = model_outputs

        trainable_vars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            checkpoint_info = modeling.get_assignment_map_from_checkpoint(
                trainable_vars, init_checkpoint)
            assignment_map, initialized_variable_names = checkpoint_info
            if not use_tpu:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
            else:
                # On TPU the checkpoint initialisation is deferred to the
                # scaffold function.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold

        tf.logging.info("**** Trainable Variables ****")
        for variable in trainable_vars:
            from_checkpoint = variable.name in initialized_variable_names
            suffix = ", *INIT_FROM_CKPT*" if from_checkpoint else ""
            tf.logging.info("  name = %s, shape = %s%s", variable.name,
                            variable.shape, suffix)

        # --- Training -------------------------------------------------------
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps, use_tpu,
                                                     optimizer)
            return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                                loss=total_loss,
                                                train_op=train_op,
                                                scaffold_fn=scaffold_fn)

        # --- Prediction (any mode that is neither TRAIN nor EVAL) -----------
        if mode != tf.estimator.ModeKeys.EVAL:
            prediction_outputs = {
                "probabilities": probabilities,
                "predictions": predictions,
            }
            return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                                predictions=prediction_outputs,
                                                scaffold_fn=scaffold_fn)

        # --- Evaluation: the metric function depends on the task ------------
        if task_name == "sts-b":

            def metric_fn(per_example_loss, label_ids, logits,
                          is_real_example):
                """Compute Pearson correlation, MSE and loss for STS-B."""
                # Expose the raw predictions and labels as metrics.
                all_logits = contrib_metrics.streaming_concat(logits)
                all_labels = contrib_metrics.streaming_concat(label_ids)

                pearson = contrib_metrics.streaming_pearson_correlation(
                    logits, label_ids, weights=is_real_example)

                mse = tf.metrics.mean_squared_error(label_ids,
                                                    logits,
                                                    weights=is_real_example)

                mean_loss = tf.metrics.mean(values=per_example_loss,
                                            weights=is_real_example)

                return {
                    "pred": all_logits,
                    "label_ids": all_labels,
                    "pearson": pearson,
                    "MSE": mse,
                    "eval_loss": mean_loss,
                }
        elif task_name == "cola":

            def metric_fn(per_example_loss, label_ids, logits,
                          is_real_example):
                """Compute Matthew's correlation, accuracy and loss for COLA."""
                predicted = tf.argmax(logits, axis=-1, output_type=tf.int32)
                confusion_args = dict(labels=label_ids,
                                      predictions=predicted,
                                      weights=is_real_example)
                # https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
                tp, tp_op = tf.metrics.true_positives(**confusion_args)
                tn, tn_op = tf.metrics.true_negatives(**confusion_args)
                fp, fp_op = tf.metrics.false_positives(**confusion_args)
                fn, fn_op = tf.metrics.false_negatives(**confusion_args)

                # Matthew's correlation; div_no_nan yields 0 when the
                # denominator is 0.
                numerator = tp * tn - fp * fn
                denominator = tf.pow(
                    (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn), 0.5)
                mcc = tf.div_no_nan(numerator, denominator)

                accuracy = tf.metrics.accuracy(**confusion_args)

                mean_loss = tf.metrics.mean(values=per_example_loss,
                                            weights=is_real_example)

                return {
                    "matthew_corr":
                    (mcc, tf.group(tp_op, tn_op, fp_op, fn_op)),
                    "eval_accuracy": accuracy,
                    "eval_loss": mean_loss,
                }
        else:

            def metric_fn(per_example_loss, label_ids, logits,
                          is_real_example):
                """Compute accuracy and mean loss for the other tasks."""
                predicted = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(labels=label_ids,
                                               predictions=predicted,
                                               weights=is_real_example)
                mean_loss = tf.metrics.mean(values=per_example_loss,
                                            weights=is_real_example)
                return {
                    "eval_accuracy": accuracy,
                    "eval_loss": mean_loss,
                }

        metric_inputs = [per_example_loss, label_ids, logits, is_real_example]
        return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                            loss=total_loss,
                                            eval_metrics=(metric_fn,
                                                          metric_inputs),
                                            scaffold_fn=scaffold_fn)
Ejemplo n.º 5
0
def _pad_to_canvas(patch, top, left, height, width, canvas_h, canvas_w):
    """Zero-pad `patch` so that it sits at (top, left) on a larger canvas.

    Args:
      patch: 4-D tensor whose two middle (spatial) axes measure
        `height` x `width`.
      top: row offset of the patch inside the canvas.
      left: column offset of the patch inside the canvas.
      height: spatial height of `patch`.
      width: spatial width of `patch`.
      canvas_h: spatial height of the returned tensor.
      canvas_w: spatial width of the returned tensor.

    Returns:
      `patch` padded with zeros on the two spatial axes; the first and last
      axes are left untouched.
    """
    return tf.pad(patch, [[0, 0],
                          [top, canvas_h - height - top],
                          [left, canvas_w - width - left],
                          [0, 0]], "CONSTANT")


def _isolate_region(canvas, top, left, height, width, canvas_h, canvas_w):
    """Return a canvas-sized tensor that keeps one rectangle and zeroes the rest.

    Args:
      canvas: 4-D tensor of spatial size `canvas_h` x `canvas_w`; only batch
        entry 0 and channel 0 are read (slice size is [1, height, width, 1]).
      top: first row of the rectangle to keep.
      left: first column of the rectangle to keep.
      height: number of rows to keep.
      width: number of columns to keep.
      canvas_h: spatial height of `canvas`.
      canvas_w: spatial width of `canvas`.

    Returns:
      The sliced rectangle padded back to the canvas size with zeros.
    """
    region = tf.slice(canvas, [0, top, left, 0], [1, height, width, 1])
    return _pad_to_canvas(region, top, left, height, width, canvas_h,
                          canvas_w)


def main():
    """Build the graph and run it in test mode or in training mode.

    Steps, in order:
      1. Pin CUDA to device "0", seed the TF / NumPy / `random` generators
         (drawing a seed when `a.seed` is None) and create `a.output_dir`.
      2. In test mode, require `a.checkpoint`, reload a fixed set of options
         from its `options.json` and disable flipping.
      3. Dump the effective options to `<output_dir>/options.json`.
      4. Build the model. In test mode the reflect-padded image is processed
         as half-overlapping CROP_SIZE patches that share one set of
         variables, and the summed patch outputs are re-weighted per region.
      5. Create PNG-encoding fetches and summaries, then either write the
         test images or run the training loop with periodic summaries,
         traces, progress prints, image dumps and checkpoints.

    Raises:
      Exception: if test mode is requested without `a.checkpoint`.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    if a.seed is None:
        a.seed = random.randint(0, 2**31 - 1)

    tf.set_random_seed(a.seed)
    np.random.seed(a.seed)
    random.seed(a.seed)

    if not os.path.exists(a.output_dir):
        os.makedirs(a.output_dir)

    # None of the options reloaded below touches `a.mode`, so the flag can be
    # computed once.
    is_test = a.mode == "test"

    if is_test:
        if a.checkpoint is None:
            raise Exception("checkpoint required for test mode")

        # load some options from the checkpoint
        restorable = {"which_direction", "ngf", "ndf", "lab_colorization"}
        with open(os.path.join(a.checkpoint, "options.json")) as f:
            for key, val in json.loads(f.read()).items():
                if key in restorable:
                    print("loaded", key, "=", val)
                    setattr(a, key, val)
        # disable these features in test mode
        a.flip = False

    for k, v in a._get_kwargs():
        print(k, "=", v)

    with open(os.path.join(a.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(a), sort_keys=True, indent=4))

    examples = load_examples()
    print("examples count = %d" % examples.count)

    # inputs and targets are [batch_size, height, width, channels]
    if is_test:
        patch_h_cnt, padding_h = find_patch_and_padding(
            IMAGE_HEIGHT, CROP_SIZE)
        patch_w_cnt, padding_w = find_patch_and_padding(IMAGE_WIDTH, CROP_SIZE)

        border = [[0, 0], [padding_h, padding_h], [padding_w, padding_w],
                  [0, 0]]
        inputs_pad = tf.pad(examples.inputs, border, "REFLECT")
        targets_pad = tf.pad(examples.targets, border, "REFLECT")

        padded_h = IMAGE_HEIGHT + 2 * padding_h
        padded_w = IMAGE_WIDTH + 2 * padding_w
        summed = tf.zeros([1, padded_h, padded_w, 1], dtype=tf.float32)

        # Run the model on every patch (stride CROP_SIZE / 2) and add each
        # result onto the canvas at the patch position. The first call creates
        # the variables; every later call re-enters the scope with reuse=True.
        patch_shape = [1, CROP_SIZE, CROP_SIZE, 1]
        scope_reuse = None
        for row in range(patch_h_cnt):
            for col in range(patch_w_cnt):
                row_index = int(row * CROP_SIZE / 2)
                col_index = int(col * CROP_SIZE / 2)
                begin = [0, row_index, col_index, 0]
                with tf.variable_scope("create_model", reuse=scope_reuse):
                    model = create_model(
                        tf.slice(inputs_pad, begin, patch_shape),
                        tf.slice(targets_pad, begin, patch_shape))
                scope_reuse = True
                summed = summed + _pad_to_canvas(
                    model.outputs, row_index, col_index, CROP_SIZE, CROP_SIZE,
                    padded_h, padded_w)

        # Split the canvas into corners, edges and interior so each part can be
        # divided by a different factor (1, 2 and 4 respectively).
        # NOTE(review): this presumes the half-overlapping patches tile the
        # padded canvas exactly, which depends on find_patch_and_padding -
        # confirm.
        crop_half = int(CROP_SIZE / 2)
        bottom = padded_h - crop_half
        right = padded_w - crop_half
        inner_h = padded_h - 2 * crop_half
        inner_w = padded_w - 2 * crop_half

        def region(top, left, height, width):
            """Canvas-sized copy of `summed` with only one rectangle kept."""
            return _isolate_region(summed, top, left, height, width, padded_h,
                                   padded_w)

        # corners
        o_11 = region(0, 0, crop_half, crop_half)
        o_12 = region(0, right, crop_half, crop_half)
        o_13 = region(bottom, 0, crop_half, crop_half)
        o_14 = region(bottom, right, crop_half, crop_half)
        # edges: top, left, bottom, right
        o_21 = region(0, crop_half, crop_half, inner_w)
        o_22 = region(crop_half, 0, inner_h, crop_half)
        o_23 = region(bottom, crop_half, crop_half, inner_w)
        o_24 = region(crop_half, right, inner_h, crop_half)
        # interior
        o_4 = region(crop_half, crop_half, inner_h, inner_w)

        outputs = o_11 + o_12 + o_13 + o_14 + (o_21 + o_22 + o_23 +
                                               o_24) / 2 + o_4 / 4
        # Drop the reflect padding again.
        outputs = tf.slice(outputs, [0, padding_h, padding_w, 0],
                           [1, IMAGE_HEIGHT, IMAGE_WIDTH, 1])
        outputs = deprocess(outputs)
    else:
        with tf.variable_scope("create_model"):
            model = create_model(examples.inputs, examples.targets)
        outputs = deprocess(model.outputs)

    inputs = deprocess(examples.inputs)
    targets = deprocess(examples.targets)

    def convert(image):
        """Convert an image tensor to uint8 with saturation."""
        return tf.image.convert_image_dtype(image,
                                            dtype=tf.uint8,
                                            saturate=True)

    # reverse any processing on images so they can be written to disk or displayed to user
    with tf.name_scope("convert_inputs"):
        converted_inputs = convert(inputs)

    with tf.name_scope("convert_targets"):
        converted_targets = convert(targets)

    with tf.name_scope("convert_outputs"):
        converted_outputs = convert(outputs)

    with tf.name_scope("encode_images"):
        display_fetches = {
            "paths":
            examples.paths,
            "inputs":
            tf.map_fn(tf.image.encode_png,
                      converted_inputs,
                      dtype=tf.string,
                      name="input_pngs"),
            "targets":
            tf.map_fn(tf.image.encode_png,
                      converted_targets,
                      dtype=tf.string,
                      name="target_pngs"),
            "outputs":
            tf.map_fn(tf.image.encode_png,
                      converted_outputs,
                      dtype=tf.string,
                      name="output_pngs"),
        }

    # summaries
    with tf.name_scope("inputs_summary"):
        tf.summary.image("inputs", converted_inputs)

    with tf.name_scope("targets_summary"):
        tf.summary.image("targets", converted_targets)

    with tf.name_scope("outputs_summary"):
        tf.summary.image("outputs", converted_outputs)

    # In test mode `model` is the model built for the last patch.
    tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)

    for var in tf.trainable_variables():
        tf.summary.histogram(var.op.name + "/values", var)
    if a.mode == "train":
        for grad, var in model.gen_grads_and_vars:
            tf.summary.histogram(var.op.name + "/gradients", grad)

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    saver = tf.train.Saver(max_to_keep=1)

    # The Supervisor is created with saver=None; checkpoints are written
    # explicitly through `saver` below.
    logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None
    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None)
    with sv.managed_session() as sess:
        print("parameter_count =", sess.run(parameter_count))

        if a.checkpoint is not None:
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
            saver.restore(sess, checkpoint)

        # a.max_steps, when given, takes precedence over a.max_epochs.
        max_steps = 2**32
        if a.max_epochs is not None:
            max_steps = examples.steps_per_epoch * a.max_epochs
        if a.max_steps is not None:
            max_steps = a.max_steps

        if is_test:
            # testing
            # at most, process the test data once
            start = time.time()
            max_steps = min(examples.steps_per_epoch, max_steps)
            for step in range(max_steps):
                results = sess.run(display_fetches)
                filesets = save_images_test(results, step)
                for fileset in filesets:
                    print("evaluated image", fileset["name"])
            # Nothing to average over when no step was run.
            if max_steps > 0:
                print("rate", (time.time() - start) / max_steps)
        else:
            # training
            start = time.time()

            def is_due(freq, step):
                """True every `freq` steps and on the final step (freq > 0)."""
                return freq > 0 and ((step + 1) % freq == 0
                                     or step == max_steps - 1)

            for step in range(max_steps):
                trace_now = is_due(a.trace_freq, step)
                progress_now = is_due(a.progress_freq, step)
                summary_now = is_due(a.summary_freq, step)
                display_now = is_due(a.display_freq, step)

                run_options = None
                run_metadata = None
                if trace_now:
                    run_options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()

                fetches = {
                    "train": model.train,
                    "global_step": sv.global_step,
                }

                if progress_now:
                    fetches["gen_loss_L1"] = model.gen_loss_L1

                if summary_now:
                    fetches["summary"] = sv.summary_op

                if display_now:
                    fetches["display"] = display_fetches

                results = sess.run(fetches,
                                   options=run_options,
                                   run_metadata=run_metadata)

                if summary_now:
                    print("recording summary")
                    sv.summary_writer.add_summary(results["summary"],
                                                  results["global_step"])

                if display_now:
                    print("saving display images")
                    filesets = save_images(results["display"],
                                           step=results["global_step"])
                    append_index(filesets, step=True)

                if trace_now:
                    print("recording trace")
                    sv.summary_writer.add_run_metadata(
                        run_metadata, "step_%d" % results["global_step"])

                if progress_now:
                    # global_step will have the correct step count if we resume from a checkpoint
                    train_epoch = math.ceil(results["global_step"] /
                                            examples.steps_per_epoch)
                    train_step = (results["global_step"] -
                                  1) % examples.steps_per_epoch + 1
                    rate = (step + 1) * a.batch_size / (time.time() - start)
                    remaining = (max_steps - step) * a.batch_size / rate
                    print(
                        "progress  epoch %d  step %d  image/sec %0.1f  remaining %dm"
                        % (train_epoch, train_step, rate, remaining / 60))
                    print("gen_loss_L1", results["gen_loss_L1"])

                if is_due(a.save_freq, step):
                    print("saving model")
                    saver.save(sess,
                               os.path.join(a.output_dir, "model"),
                               global_step=sv.global_step)

                if sv.should_stop():
                    break
Ejemplo n.º 6
0
    def rllim_train(self):
        """Training instance-wise weight estimator.

        Builds two REINFORCE-style losses over the outputs of
        `self.inst_weight_evaluator` (a plain one and one with an extra
        `self.hyper_lambda` term plus a minimum-selection penalty), then runs
        `self.outer_iterations` optimisation steps: the plain loss for the
        first 500 iterations, the penalised loss afterwards. For every probe
        sample in a step, a 0/1 selection over the training batch is sampled,
        `self.interp_model` is fitted with it as third `fit` argument, and the
        reward is the absolute error of that model minus the absolute error of
        `self.baseline_model` on the probe sample.

        The trained variables are saved to `self.checkpoint_file_name`; the
        session is closed before returning.

        Raises:
          ValueError: if `self.batch_size_inner` is smaller than 1.
        """
        if self.batch_size_inner < 1:
            raise ValueError("batch_size_inner must be at least 1.")

        # Generates selected probabilities
        est_data_value = self.inst_weight_evaluator(self.x_input,
                                                    self.xt_input)

        # Generates a set of selected probabilities (one per probe row)
        est_data_value_set = [
            self.inst_weight_evaluator(self.x_input, self.xt_input[i, :])
            for i in range(self.batch_size_inner)
        ]

        # Loss for the REINFORCE algorithm
        epsilon = 1e-8  # added inside the logs so their argument is never 0
        eta = 1e3  # multiplier to the regularizer
        thresh = 0.01  # threshold for the minimum selection

        # Both losses share the per-probe log-likelihood, so it is built once.
        dve_loss_curr = None  # 1. Without lambda penalty
        dve_loss_curr_hat = None  # 2. With lambda penalty
        for ktt in range(self.batch_size_inner):
            est_ktt = est_data_value_set[ktt]
            selected = self.s_input[:, ktt]
            prob = tf.reduce_mean(
                selected * tf.log(est_ktt + epsilon) +
                (1 - selected) * tf.log(1 - est_ktt + epsilon))

            reinforce_term = -self.reward_input[ktt] * prob
            mean_value = tf.reduce_mean(est_ktt)
            lambda_term = self.hyper_lambda * mean_value * prob
            floor_term = eta * tf.maximum(thresh - mean_value, 0)

            if dve_loss_curr is None:
                dve_loss_curr = reinforce_term
                dve_loss_curr_hat = reinforce_term - lambda_term + floor_term
            else:
                dve_loss_curr = dve_loss_curr + reinforce_term
                dve_loss_curr_hat = (dve_loss_curr_hat + reinforce_term -
                                     lambda_term + floor_term)

        dve_loss = dve_loss_curr / self.batch_size_inner
        dve_loss_hat = dve_loss_curr_hat / self.batch_size_inner

        # Variables
        dve_vars = [
            v for v in tf.trainable_variables()
            if v.name.startswith('data_value_estimator')
        ]

        # Optimization step
        dve_solver = tf.train.AdamOptimizer(0.0001).minimize(dve_loss,
                                                             var_list=dve_vars)

        dve_solver_hat = tf.train.AdamOptimizer(0.0001).minimize(
            dve_loss_hat, var_list=dve_vars)

        # Number of outer iterations trained without the lambda penalty.
        warmup_iterations = 500

        # Main session; the context manager closes it on exit, including when
        # an iteration raises.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            # Saves model at the end
            saver = tf.train.Saver()

            # Outer iterations
            for itt in tqdm.tqdm(range(self.outer_iterations)):

                # Batch selection
                batch_idx = np.random.permutation(len(
                    self.x_train[:, 0]))[:self.batch_size]

                x_batch = self.x_train[batch_idx, :]
                y_batch = self.y_train[batch_idx]

                val_batch_idx = np.random.permutation(len(
                    self.x_probe[:, 0]))[:self.batch_size_inner]

                xt_batch = self.x_probe[val_batch_idx, :]
                yt_batch = self.y_probe[val_batch_idx]

                # Initialization
                reward_curr = np.zeros([self.batch_size_inner, 1])
                sel_prob_curr = np.zeros(
                    [self.batch_size, self.batch_size_inner])

                # Inner iterations
                for ktt in range(self.batch_size_inner):
                    # Current probe sample as a single-row matrix.
                    probe_row = np.reshape(xt_batch[ktt, :],
                                           [1, self.data_dim])

                    # Generates selection probability
                    est_dv_curr = sess.run(est_data_value,
                                           feed_dict={
                                               self.x_input: x_batch,
                                               self.xt_input: probe_row
                                           })

                    # Samples based on the selection probability
                    sel_prob_curr[:, ktt] = np.random.binomial(
                        1, est_dv_curr, est_dv_curr.shape)[:, 0]

                    # Exception (when nothing was selected): resample with
                    # probability 0.5 for every training sample.
                    if np.sum(sel_prob_curr[:, ktt]) == 0:
                        est_dv_curr = 0.5 * np.ones(np.shape(est_dv_curr))
                        sel_prob_curr[:, ktt] = np.random.binomial(
                            1, est_dv_curr, est_dv_curr.shape)[:, 0]

                    # Trains instance-wise locally interpretable model
                    self.interp_model.fit(x_batch, y_batch,
                                          sel_prob_curr[:, ktt])

                    # Fidelity (absolute error) of the interpretable model
                    yt_batch_hat_new = self.interp_model.predict(probe_row)
                    new_abs_err = np.abs(yt_batch_hat_new - yt_batch[ktt])

                    # Fidelity (absolute error) of the baseline model
                    yt_batch_hat_ori = self.baseline_model.predict(probe_row)
                    ori_abs_err = np.abs(yt_batch_hat_ori - yt_batch[ktt])

                    # Computes reward
                    reward_curr[ktt] = new_abs_err - ori_abs_err

                # Trains the weight estimator: plain loss during warm-up,
                # lambda-penalised loss afterwards.
                solver = (dve_solver
                          if itt < warmup_iterations else dve_solver_hat)
                sess.run(solver,
                         feed_dict={
                             self.x_input: x_batch,
                             self.xt_input: xt_batch,
                             self.s_input: sel_prob_curr,
                             self.reward_input: reward_curr
                         })

            # Saves model
            saver.save(sess, self.checkpoint_file_name)
    def __init__(self,
                 session,
                 state_spec,
                 action_spec,
                 hidden_layers,
                 learning_rate,
                 learning_rate_action,
                 learning_rate_ga,
                 batch_size,
                 action_maximization_iterations,
                 name,
                 l2_loss_flag=False,
                 simple_lambda_flag=True,
                 solver=None,
                 sufficient_ascent_flag=False,
                 initial_lambda=10.0,
                 lambda_max=5e3):
        """Creates CAQL networks.

        Stores the configuration, creates the input placeholders, builds the
        Q-function, lambda-function and action-function networks together
        with their losses and Adam update ops, and finally creates the
        dual-solver tensors.

        Args:
          session: TF session.
          state_spec: tf_agents.specs.array_spec.ArraySpec. Specification for
            state.
          action_spec: tf_agents.specs.array_spec.ArraySpec. Specification for
            action.
          hidden_layers: list of integers. Number of hidden units for each
            hidden layer.
          learning_rate: float on Q function learning rate.
          learning_rate_action: float on action function learning rate.
          learning_rate_ga: float. Learning rate for gradient ascent optimizer.
          batch_size: int on batch size for training.
          action_maximization_iterations: int on CEM/gradient ascent
            iterations.
          name: string on name of network.
          l2_loss_flag: bool on using l2 loss.
          simple_lambda_flag: bool on using lambda hinge loss.
          solver: string on inner max optimizer. Supported optimizers are
            "gradient_ascent", "cross_entropy", "ails", "mip".
          sufficient_ascent_flag: bool on using sufficient ascent.
          initial_lambda: float on initial lambda (only for
            simple_lambda_flag).
          lambda_max: float on lambda upper-bound.

        Raises:
          ValueError: if `solver` is "ails" or "mip".
        """
        self._session = session
        self.state_spec = state_spec
        self.action_spec = action_spec
        self.state_dim = state_spec.shape[0]
        self.action_dim = action_spec.shape[0]
        self.action_max = action_spec.maximum
        self.action_min = action_spec.minimum
        self.hidden_layers = hidden_layers
        self.learning_rate = learning_rate
        self.learning_rate_action = learning_rate_action
        self.learning_rate_ga = learning_rate_ga
        self.batch_size = batch_size
        self.action_maximization_iterations = action_maximization_iterations

        self.name = name
        self.lambda_max = lambda_max
        if solver in ("ails", "mip"):
            raise ValueError("AILS and MIP solvers are not supported yet.")

        def scoped(suffix):
            """Prefix `suffix` with this network's name."""
            return "{}_{}".format(self.name, suffix)

        def float_placeholder(ph_name, shape):
            """Create a float32 placeholder."""
            return tf.placeholder(dtype=tf.float32, name=ph_name, shape=shape)

        # define placeholders
        state_shape = (None, self.state_dim)
        column_shape = (None, 1)
        self._state_tensor = float_placeholder("state_tensor", state_shape)
        self._state_deviation_tensor = float_placeholder(
            "state_deviation_tensor", state_shape)
        self._action_tensor = float_placeholder("action_tensor",
                                                (None, self.action_dim))
        self._next_state_tensor = float_placeholder("next_state_tensor",
                                                    state_shape)
        self._reward_tensor = float_placeholder("reward_tensor", column_shape)
        self._done_tensor = tf.placeholder(dtype=tf.bool,
                                           name="done_tensor",
                                           shape=column_shape)
        self._discount_factor = float_placeholder("discounting_factor", ())
        self._maxq_label = float_placeholder("maxq_label", column_shape)

        def discounted_backup():
            """reward + (1 - done) * discount * max-Q label, as fresh ops."""
            not_done = 1.0 - tf.to_float(self._done_tensor)
            return (self._reward_tensor +
                    not_done * self._discount_factor * self._maxq_label)

        self._backup_tensor = discounted_backup()

        self._true_label = float_placeholder("true_label", column_shape)

        self.q_function_network = self._build_q_function_net(
            self._state_tensor, self._action_tensor)

        # Q plus the inner product of dQ/dstate with the state deviation,
        # taken per row.
        q_grad_wrt_state = tf.gradients(self.q_function_network,
                                        self._state_tensor)[0]
        deviation_term = tf.einsum("ij,ij->i", q_grad_wrt_state,
                                   self._state_deviation_tensor)
        self.state_perturbed_q_function_network = (
            self.q_function_network + tf.expand_dims(deviation_term, axis=-1))

        # The backup expression is rebuilt here rather than reusing
        # self._backup_tensor.
        self._td_rmse = tf.sqrt(
            tf.losses.mean_squared_error(discounted_backup(),
                                         self.q_function_network))

        if simple_lambda_flag:
            # A single trainable scalar, clipped to [0, lambda_max] and tiled
            # to one row per batch element.
            with tf.variable_scope(scoped("lambda_function")):
                lambda_var = tf.Variable(initial_value=initial_lambda,
                                         trainable=True,
                                         name="lambda_var")
                lambda_proj = tf.minimum(lambda_max,
                                         tf.maximum(0.0, lambda_var),
                                         name="lambda_proj")
                self.lambda_function_network = tf.tile(
                    tf.reshape(lambda_proj, (-1, 1)), (self.batch_size, 1))
        else:
            self.lambda_function_network = self._build_lambda_function_net(
                self._state_tensor, self._action_tensor)

        # define loss
        if l2_loss_flag:
            self._q_function_loss = tf.losses.mean_squared_error(
                self._true_label, self.q_function_network)
        else:
            # Q plus lambda times the positive part of (label - Q).
            self._q_function_loss = tf.reduce_mean(
                self.q_function_network + self.lambda_function_network *
                tf.maximum(0.0, self._true_label - self.q_function_network))

        self._lambda_function_loss = tf.reduce_mean(
            -self.lambda_function_network *
            (self._true_label - self.q_function_network))

        # Action network to learn argmax of Q
        self._best_q_label = float_placeholder("best_q_label", column_shape)

        # create network placeholders
        self._create_network_var_ph()

        self.action_function_network = self._build_action_function_net(
            self._state_tensor)
        self.dummy_q_function_network = self._build_q_function_net(
            self._state_tensor, self.action_function_network)

        self._action_function_loss = tf.losses.mean_squared_error(
            self._best_q_label, self.dummy_q_function_network)

        # optimizer
        # NOTE: Increment this by one by including it only in main_q trainer.
        global_step = tf.Variable(0,
                                  name="{}_global_step".format(self.name),
                                  trainable=False)

        def adam_update(loss, scope_suffix, **minimize_kwargs):
            """Adam step on `loss` over the trainable variables of one scope."""
            optimizer = tf.train.AdamOptimizer(
                learning_rate=self.learning_rate)
            return optimizer.minimize(
                loss,
                var_list=tf.trainable_variables(scoped(scope_suffix)),
                **minimize_kwargs)

        # NOTE(review): all three optimizers are built with
        # self.learning_rate; self.learning_rate_action is stored above but
        # not used here - confirm this is intended.
        with tf.variable_scope(scoped("optimizer")):
            self._action_function_optimizer = adam_update(
                self._action_function_loss, "action_function")
            self._q_function_optimizer = adam_update(self._q_function_loss,
                                                     "q_function",
                                                     global_step=global_step)
            self._lambda_function_optimizer = adam_update(
                self._lambda_function_loss, "lambda_function")

        # Tensors for dual solvers
        self._create_dual_maxq_label_tensor()
        self._create_dual_active_constraint_condition_tensor()

        self.solver = solver
        self.sufficient_ascent_flag = sufficient_ascent_flag
Ejemplo n.º 8
0
    def _init_graph(self):
        self.graph = tf.Graph()
        with self.graph.as_default():

            tf.set_random_seed(self.random_seed)

            # placeholder for single-value field.
            self.feat_index = tf.placeholder(tf.int32,
                                             shape=[None, None],
                                             name="feat_index")  # None * M-1
            self.feat_value = tf.placeholder(tf.float32,
                                             shape=[None, None],
                                             name="feat_value")  # None * M-1

            # placeholder for multi-value field. (movielens dataset genre field)
            self.genre_index = tf.placeholder(tf.int32,
                                              shape=[None, None],
                                              name="genre_index")  # None * 6
            self.genre_value = tf.placeholder(tf.float32,
                                              shape=[None, None],
                                              name="genre_value")  # None * 6

            self.label = tf.placeholder(tf.float32,
                                        shape=[None, 1],
                                        name="label")  # None * 1

            # In our implementation, the shape of dropout_keep_prob is [3], used in 3 different places.
            self.dropout_keep_prob = tf.placeholder(tf.float32,
                                                    shape=[None],
                                                    name="dropout_keep_prob")
            self.train_phase = tf.placeholder(tf.bool, name="train_phase")

            self.weights = self._initialize_weights()

            # model
            self.embeddings = tf.nn.embedding_lookup(
                self.weights["feature_embeddings"],
                self.feat_index)  # None * M-1 * d
            feat_value = tf.reshape(self.feat_value,
                                    shape=[-1, self.field_size - 1, 1])
            self.embeddings = tf.multiply(self.embeddings,
                                          feat_value)  # None * M-1 * d

            # for multi-value field
            self.embeddings_m = tf.nn.embedding_lookup(
                self.weights["feature_embeddings"],
                self.genre_index)  # None * 6 * d
            genre_value = tf.reshape(self.genre_value, shape=[-1, 6, 1])
            self.embeddings_m = tf.multiply(self.embeddings_m, genre_value)
            self.embeddings_m = tf.reduce_sum(self.embeddings_m,
                                              axis=1)  # None * d
            self.embeddings_m = tf.div(self.embeddings_m,
                                       tf.reduce_sum(
                                           self.genre_value,
                                           axis=1,
                                           keep_dims=True))  # None * d

            #concatenate single-value field with multi-value field
            self.embeddings = tf.concat(
                [self.embeddings,
                 tf.expand_dims(self.embeddings_m, 1)], 1)  # None * M * d
            self.embeddings = tf.nn.dropout(
                self.embeddings, self.dropout_keep_prob[1])  # None * M * d

            # joint training with feedforward nn
            if self.deep_layers != None:
                self.y_dense = tf.reshape(
                    self.embeddings,
                    shape=[-1, self.field_size * self.embedding_size])
                for i in range(0, len(self.deep_layers)):
                    self.y_dense = tf.add(
                        tf.matmul(self.y_dense, self.weights["layer_%d" % i]),
                        self.weights["bias_%d" % i])  # None * layer[i]
                    if self.batch_norm:
                        self.y_dense = self.batch_norm_layer(
                            self.y_dense,
                            train_phase=self.train_phase,
                            scope_bn="bn_%d" % i)
                    self.y_dense = tf.nn.relu(self.y_dense)
                    self.y_dense = tf.nn.dropout(self.y_dense,
                                                 self.dropout_keep_prob[2])
                self.y_dense = tf.add(tf.matmul(
                    self.y_dense, self.weights["prediction_dense"]),
                                      self.weights["prediction_bias_dense"],
                                      name='logits_dense')  # None * 1

            # ---------- main part of AutoInt-------------------
            self.y_deep = self.embeddings  # None * M * d
            for i in range(self.blocks):
                self.y_deep = multihead_attention(
                    queries=self.y_deep,
                    keys=self.y_deep,
                    values=self.y_deep,
                    num_units=self.block_shape[i],
                    num_heads=self.heads,
                    dropout_keep_prob=self.dropout_keep_prob[0],
                    is_training=self.train_phase,
                    has_residual=self.has_residual)

            self.flat = tf.reshape(
                self.y_deep, shape=[-1, self.output_size * self.field_size])

            self.out = tf.add(tf.matmul(self.flat, self.weights["prediction"]),
                              self.weights["prediction_bias"],
                              name='logits')  # None * 1

            if self.deep_layers != None:
                self.out += self.y_dense

            # ---------- Compute the loss ----------
            # loss
            if self.loss_type == "logloss":
                self.out = tf.nn.sigmoid(self.out, name='pred')
                self.loss = tf.losses.log_loss(self.label, self.out)
            elif self.loss_type == "mse":
                self.loss = tf.nn.l2_loss(tf.subtract(self.label, self.out))

            # l2 regularization on weights
            if self.l2_reg > 0:
                if self.deep_layers != None:
                    for i in range(len(self.deep_layers)):
                        self.loss += tf.contrib.layers.l2_regularizer(
                            self.l2_reg)(self.weights["layer_%d" % i])

            self.global_step = tf.Variable(0,
                                           name="global_step",
                                           trainable=False)
            self.var1 = [
                v for v in tf.trainable_variables()
                if v.name != 'feature_bias:0'
            ]
            self.var2 = [tf.trainable_variables()[1]
                         ]  # self.var2 = [feature_bias]

            if self.optimizer_type == "adam":
                self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate,
                                                    beta1=0.9, beta2=0.999, epsilon=1e-8).\
                                                    minimize(self.loss, global_step=self.global_step)
            elif self.optimizer_type == "adagrad":
                self.optimizer = tf.train.AdagradOptimizer(learning_rate=self.learning_rate,
                                                           initial_accumulator_value=1e-8).\
                                                           minimize(self.loss, global_step=self.global_step)
            elif self.optimizer_type == "gd":
                self.optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.learning_rate).\
                                                                   minimize(self.loss, global_step=self.global_step)
            elif self.optimizer_type == "momentum":
                self.optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate, momentum=0.95).\
                                                            minimize(self.loss, global_step=self.global_step)

            # init
            self.saver = tf.train.Saver(max_to_keep=5)
            init = tf.global_variables_initializer()
            self.sess = self._init_session()
            self.sess.run(init)
            self.count_param()
Ejemplo n.º 9
0
import os
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.core.protobuf import saver_pb2
import driving_data
import model

# Training-setup portion of the script: builds the loss, the optimizer step,
# the summary ops and the checkpoint saver on top of the graph defined in the
# imported `model` module.

# Not referenced in this part of the script — presumably the checkpoint
# directory used further down; verify against the rest of the file.
LOGDIR = './save'

sess = tf.InteractiveSession()

# Multiplier applied to the summed L2 penalty in the loss below.
L2NormConst = 0.001

train_vars = tf.trainable_variables()

# Loss = mean squared difference between model.y_ and model.y, plus
# L2NormConst times the sum of tf.nn.l2_loss over every trainable variable.
loss = tf.reduce_mean(tf.square(tf.subtract(model.y_, model.y))) + tf.add_n([tf.nn.l2_loss(v) for v in train_vars]) * L2NormConst
# Adam with a fixed learning rate of 1e-4.
train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
# The initializer is built after minimize(), so it also covers any variables
# the optimizer created.
sess.run(tf.global_variables_initializer())

# Scalar summary that tracks the loss tensor.
tf.summary.scalar("loss", loss)
# Merge all summaries registered so far into a single op.
merged_summary_op = tf.summary.merge_all()

# Saver explicitly requesting the V2 checkpoint format.
saver = tf.train.Saver(write_version = saver_pb2.SaverDef.V2)

# Writer for TensorBoard logs; the current default graph is handed to it.
logs_path = './logs'
summary_writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())

# Not referenced in this part of the script — presumably the number of
# training epochs for the loop that follows; verify.
epochs = 30
Ejemplo n.º 10
0
        def compute_gradients(self,
                              loss,
                              var_list,
                              gate_gradients=GATE_OP,
                              aggregation_method=None,
                              colocate_gradients_with_ops=False,
                              grad_loss=None,
                              gradient_tape=None):
            """Compute clipped, noised and averaged gradients (DP-SGD).

            Overrides the base-class method. The loss is reshaped into
            microbatches, the gradients of each microbatch are clipped to a
            global L2 norm of `self._l2_norm_clip`, the clipped gradients are
            summed over microbatches, Gaussian noise is added and the result
            is divided by the number of microbatches.

            Args:
              loss: tensor of losses; its first dimension is taken as the
                batch size. A callable is rejected.
              var_list: variables to differentiate with respect to. When None,
                the trainable variables plus the contents of the
                TRAINABLE_RESOURCE_VARIABLES collection are used.
              gate_gradients: forwarded to the base-class method.
              aggregation_method: forwarded to the base-class method.
              colocate_gradients_with_ops: forwarded to the base-class method.
              grad_loss: forwarded to the base-class method.
              gradient_tape: must be falsy.

            Returns:
              A list of (gradient, variable) pairs.

            Raises:
              NotImplementedError: if `loss` is callable.
              ValueError: if a gradient tape is passed.
            """
            if callable(loss):
                # A callable loss means TF is running in Eager mode.
                raise NotImplementedError(
                    'Vectorized optimizer unavailable for TF2.')
            # Graph mode from here on: a gradient tape must not be supplied.
            if gradient_tape:
                raise ValueError(
                    'When in graph mode, a tape should not be passed.')

            num_records = tf.shape(input=loss)[0]
            if self._num_microbatches is None:
                # Default to one record per microbatch.
                self._num_microbatches = num_records

            # Note: it would be closer to the correct i.i.d. sampling of
            # records if each microbatch were sampled from the appropriate
            # binomial distribution, although that still would not be quite
            # correct because it would sample from the dataset without
            # replacement.
            losses_per_microbatch = tf.reshape(
                loss, [self._num_microbatches, -1])

            if var_list is None:
                var_list = (tf.trainable_variables() + tf.get_collection(
                    tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))

            def clip_one_microbatch(one_microbatch_loss):
                """Return one microbatch's gradients, clipped by global norm."""
                mean_loss = tf.reduce_mean(input_tensor=one_microbatch_loss)
                grads_and_vars = super(
                    DPOptimizerClass, self).compute_gradients(
                        mean_loss, var_list, gate_gradients,
                        aggregation_method, colocate_gradients_with_ops,
                        grad_loss)
                raw_grads, _ = zip(*grads_and_vars)
                # Variables without a gradient contribute zeros.
                dense_grads = [
                    tf.zeros_like(var) if grad is None else grad
                    for grad, var in zip(raw_grads, var_list)
                ]
                # TF primitives are used instead of the built-in
                # tf.clip_by_global_norm() so that the computation can be
                # vectorized across microbatches.
                flat_grads = tf.nest.flatten(dense_grads)
                squared_norms = [
                    tf.reduce_sum(input_tensor=tf.square(grad))
                    for grad in flat_grads
                ]
                global_norm = tf.sqrt(tf.add_n(squared_norms))
                # Never scale up: the divisor is at least 1.
                divisor = tf.maximum(global_norm / self._l2_norm_clip, 1.)
                rescaled = [grad / divisor for grad in flat_grads]
                return tf.nest.pack_sequence_as(dense_grads, rescaled)

            per_microbatch_grads = tf.vectorized_map(clip_one_microbatch,
                                                     losses_per_microbatch)

            def sum_noise_and_average(stacked):
                """Sum over microbatches, add Gaussian noise, then average."""
                total = tf.reduce_sum(input_tensor=stacked, axis=0)
                noise_stddev = self._l2_norm_clip * self._noise_multiplier
                noise = tf.random.normal(tf.shape(input=total),
                                         stddev=noise_stddev)
                return (total + noise) / tf.cast(self._num_microbatches,
                                                 tf.float32)

            noised_grads = tf.nest.map_structure(sum_noise_and_average,
                                                 per_microbatch_grads)

            return list(zip(noised_grads, var_list))
    def test_calculate_branching_model_parameters_transformer(
            self, get_config, expected_hidden_depths):
        """Check the predicted parameter count against the graph actually built.

        Args:
          get_config: callable returning the 13-element branching configuration
            unpacked below.
          expected_hidden_depths: value the predicted hidden depths must equal.
        """
        tf.reset_default_graph()

        (num_cells, left_inputs, left_layers, left_output_dims, right_inputs,
         right_layers, right_output_dims, combiner_functions,
         final_combiner_function, dummy_activations, dummy_norms,
         layer_registry, is_decoder) = get_config()

        # Keyword arguments taken by both the analytic predictor and the
        # graph builder.
        branch_spec = dict(
            left_inputs=left_inputs,
            left_layers=left_layers,
            left_output_dims=left_output_dims,
            right_inputs=right_inputs,
            right_layers=right_layers,
            right_output_dims=right_output_dims,
            combiner_functions=combiner_functions,
            final_combiner_function=final_combiner_function,
            layer_registry=layer_registry,
            num_cells=num_cells)

        # Get predicted number of parameters.
        prediction = translation_nas_net.calculate_branching_model_parameters(
            encoding_depth=_EMBEDDING_DEPTH,
            encoder_depth=_EMBEDDING_DEPTH,
            **branch_spec)
        predicted_num_params, output_size, hidden_depths, _ = prediction

        # Create model graph.
        input_tensor = tf.zeros([32, _INPUT_LENGTH, _EMBEDDING_DEPTH])
        hparams = transformer.transformer_small()

        if is_decoder:
            nonpadding, mask_future = None, True
            decoder_self_attention_bias = (
                common_attention.attention_bias_lower_triangle(_INPUT_LENGTH))
            encoder_cell_outputs = [input_tensor] * 6
        else:
            nonpadding, mask_future = tf.ones([32, _INPUT_LENGTH]), False
            decoder_self_attention_bias = None
            encoder_cell_outputs = None

        translation_nas_net.apply_nas_layers(
            input_tensor=input_tensor,
            left_activations=dummy_activations,
            left_norms=dummy_norms,
            right_activations=dummy_activations,
            right_norms=dummy_norms,
            nonpadding=nonpadding,
            mask_future=mask_future,
            hparams=hparams,
            var_scope="test",
            encoder_decoder_attention_bias=None,
            encoder_cell_outputs=encoder_cell_outputs,
            decoder_self_attention_bias=decoder_self_attention_bias,
            final_layer_norm=False,
            **branch_spec)

        # Count graph variables: total number of elements over all trainable
        # variables.
        empirical_num_params = sum(
            _list_product(variable.shape.as_list())
            for variable in tf.trainable_variables())

        # Compare.
        self.assertEqual(empirical_num_params, predicted_num_params)
        self.assertEqual(output_size, _EMBEDDING_DEPTH)
        self.assertEqual(hidden_depths, expected_hidden_depths)
Ejemplo n.º 12
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """Build the `TPUEstimatorSpec` for TRAIN, EVAL or PREDICT.

        Args:
          features: dict of tensors; `features["input_ids"]` is fed to the
            model and echoed back as `labels` in the predictions.
          labels: unused.
          mode: a `tf.estimator.ModeKeys` value.
          params: estimator params; `params['model_dir']` is read when
            training on TPU.

        Returns:
          A `tf.contrib.tpu.TPUEstimatorSpec` matching `mode`.
        """
        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            feature_shape = features[name].shape
            tf.logging.info("  name = %s, shape = %s" % (name, feature_shape))

        input_ids = features["input_ids"]
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        model = GroverModel(config=config,
                            is_training=is_training,
                            input_ids=input_ids,
                            pad_token_id=config.pad_token_id,
                            chop_off_last_token=True)
        total_loss = model.lm_loss()

        if not is_training:
            train_op, train_metrics = None, {}
            tvars = tf.trainable_variables()
        else:
            train_op, train_metrics = optimization_adafactor.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps,
                use_tpu)
            # While training, every global variable is considered for
            # checkpoint initialization and logging, not only trainable ones.
            tvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            checkpoint_match = get_assignment_map_from_checkpoint(
                tvars, init_checkpoint)
            assignment_map, initialized_variable_names = checkpoint_match
            if not use_tpu:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
            else:
                # On TPU the restore is deferred into the scaffold function.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            restored = var.name in initialized_variable_names
            init_string = ", *INIT_FROM_CKPT*" if restored else ""
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        if mode == tf.estimator.ModeKeys.TRAIN:
            # The two training specs differ only in how metrics are reported:
            # a host call on TPU, a logging hook otherwise.
            if use_tpu:
                reporting = {
                    'host_call': construct_scalar_host_call(
                        metric_dict=train_metrics,
                        model_dir=params['model_dir'],
                        prefix='training/'),
                }
            else:
                loss_hook = tf.train.LoggingTensorHook(
                    {'loss': tf.metrics.mean(total_loss)[1]},
                    every_n_iter=100)
                reporting = {'training_hooks': [loss_hook]}
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn,
                **reporting)

        if mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(total_loss):
                return {"eval_loss": tf.metrics.mean(values=total_loss)}

            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=(metric_fn, [total_loss]),
                scaffold_fn=scaffold_fn)

        # Any other mode: build predictions.
        # Log-probability assigned to each ground-truth target token.
        gathered_gt = tf.batch_gather(model.log_probs,
                                      model.target_ids[:, :, None])
        gt_logprobs = tf.squeeze(gathered_gt, axis=2)

        # Probability mass of all tokens ranked above the ground truth, i.e.
        # the top-p value that top-p sampling would need to reach it.
        better_than_gt = model.log_probs > gt_logprobs[:, :, None]
        mass_above_gt = (tf.cast(better_than_gt, tf.float32) *
                         tf.exp(model.log_probs))
        top_p_required = tf.reduce_sum(mass_above_gt, axis=2)

        if use_tpu:
            # No top-p sampling here, since it seems to be too slow on TPUs.
            flat_samples = tf.random.categorical(logits=model.logits_flat,
                                                 num_samples=1)
        else:
            flat_samples = _top_p_sample(model.logits_flat,
                                         num_samples=1,
                                         p=0.99)['sample']
        predictions = tf.reshape(flat_samples,
                                 get_shape_list(model.target_ids))

        gathered_pred = tf.batch_gather(model.log_probs,
                                        predictions[:, :, None])
        pred_logprobs = tf.squeeze(gathered_pred, axis=2)

        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions={
                'gt_logprobs': gt_logprobs,
                'top_p_required': top_p_required,
                'predictions': predictions,
                'pred_logprobs': pred_logprobs,
                'labels': input_ids,
            },
            scaffold_fn=scaffold_fn)
Ejemplo n.º 13
0
    def step_fn(self, params, model):
        """Build one training step for the model and its example scorer.

        Dequeues a (train_images, train_labels, valid_images, valid_labels)
        tuple from the infeed and then builds, in this order: a gradient
        update of the variables under MODEL_SCOPE on a score-weighted
        cross-entropy, an EMA update, a gradient pass on the validation
        batch, and a gradient-descent update of the variables under
        SCORE_SCOPE.

        Args:
          params: hyper-parameter object. This method reads `train_dtypes`,
            `train_shapes`, `num_classes`, `num_replicas`, `label_smoothing`,
            `weight_decay`, `grad_bound`, `scorer_lr` and
            `scorer_wait_steps` from it and also passes it to `common_utils`
            helpers.
          model: callable invoked as `model(images, training=...)`, once with
            `return_scores=True`; its `name` attribute is used to build the
            scope string given to `common_utils.setup_ema`.

        Returns:
          The `tf.cond` op that enqueues the logged scalars on the outfeed
          when `common_utils.should_log(params)` is true and is a no-op
          otherwise.

        Side effects:
          Sets `self.step_info` to a dict mapping every logged key to
          `[tf.float32, [1]]`.
        """
        (train_images, train_labels, valid_images,
         valid_labels) = tf.raw_ops.InfeedDequeueTuple(
             dtypes=params.train_dtypes, shapes=params.train_shapes)

        # int32 labels are expanded to float32 one-hot vectors; labels of any
        # other dtype are used as they are.
        if train_labels.dtype == tf.int32:
            train_labels = tf.one_hot(train_labels,
                                      depth=params.num_classes,
                                      dtype=tf.float32)
        if valid_labels.dtype == tf.int32:
            valid_labels = tf.one_hot(valid_labels,
                                      depth=params.num_classes,
                                      dtype=tf.float32)
        global_step = tf.train.get_or_create_global_step()

        num_replicas = tf.cast(params.num_replicas, tf.float32)

        with tf.variable_scope(MODEL_SCOPE):
            train_logits = model(train_images, training=True)

        with tf.variable_scope(SCORE_SCOPE):
            score_logits = model(train_images,
                                 training=False,
                                 return_scores=True)
            # Shift the logits by a gradient-free scalar (cross-replica sum of
            # all score logits divided by the replica count) before
            # exponentiating. The same shift appears in numerator and
            # denominator below, so it cancels in score_probs — presumably
            # there for numerical stability.
            score_m = tf.tpu.cross_replica_sum(tf.reduce_sum(score_logits))
            score_m = tf.stop_gradient(score_m) / float(params.num_replicas)
            score_e = tf.exp(score_logits - score_m)
            # Normalizer summed over every example on every replica.
            score_z = tf.tpu.cross_replica_sum(tf.reduce_sum(score_e))
            score_probs = score_e / score_z

        # train the main model
        # Unreduced losses are weighted by the scorer's probabilities;
        # stop_gradient keeps this loss from updating the scorer.
        # NOTE(review): assumes score_probs broadcasts against the unreduced
        # loss — confirm the shape returned with return_scores=True.
        cross_entropy = tf.losses.softmax_cross_entropy(
            onehot_labels=train_labels,
            logits=train_logits,
            label_smoothing=params.label_smoothing,
            reduction=tf.losses.Reduction.NONE)
        cross_entropy = tf.reduce_sum(cross_entropy *
                                      tf.stop_gradient(score_probs))

        l2_reg_rate = tf.cast(params.weight_decay / params.num_replicas,
                              tf.float32)
        # Presumably the L2 loss over variables whose names do not contain
        # SCORE_SCOPE — verify against common_utils.get_l2_loss.
        weight_dec = common_utils.get_l2_loss(excluded_keywords=[SCORE_SCOPE])
        total_loss = cross_entropy + weight_dec * l2_reg_rate

        # Only trainable variables whose name contains MODEL_SCOPE are updated
        # by the main optimizer.
        model_variables = [
            v for v in tf.trainable_variables() if MODEL_SCOPE in v.name
        ]
        train_gradients = tf.gradients(total_loss, model_variables)
        train_gradients = [
            tf.tpu.cross_replica_sum(g) for g in train_gradients
        ]
        # From here on train_gradients holds the clipped gradients.
        train_gradients, grad_norm = tf.clip_by_global_norm(
            train_gradients, params.grad_bound)

        learning_rate, optimizer = common_utils.get_optimizer(params)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        # Skip the update entirely when the gradient norm is not finite.
        train_op = tf.cond(
            tf.math.is_finite(grad_norm), lambda: optimizer.
            apply_gradients(zip(train_gradients, model_variables),
                            global_step=global_step), tf.no_op)
        # The EMA op is built under a dependency on the collected update ops
        # and the train op.
        with tf.control_dependencies(update_ops + [train_op]):
            ema_train_op = common_utils.setup_ema(
                params, f'{MODEL_SCOPE}/{model.name}')

        # Everything below is built under a dependency on the EMA op, so the
        # validation pass runs after the model update.
        with tf.control_dependencies([ema_train_op]):
            with tf.variable_scope(MODEL_SCOPE, reuse=True):
                valid_logits = model(valid_images, training=False)
                valid_cross_entropy = tf.losses.softmax_cross_entropy(
                    onehot_labels=valid_labels,
                    logits=valid_logits,
                    reduction=tf.losses.Reduction.MEAN) / float(
                        params.num_replicas)
                valid_gradients = tf.gradients(valid_cross_entropy,
                                               model_variables)
                valid_gradients = [
                    tf.tpu.cross_replica_sum(g) for g in valid_gradients
                ]

            # Inner product between the (clipped) training gradients and the
            # validation gradients, summed over all model variables.
            dot_product = tf.add_n([
                tf.reduce_sum(g_t * g_v)
                for g_t, g_v in zip(train_gradients, valid_gradients)
            ])
            dot_product = tf.stop_gradient(dot_product)
            dot_product_avg = tf.get_variable(name='dot_product_avg',
                                              shape=[],
                                              trainable=False)
            # Moving average: avg <- avg - 0.01 * (avg - dot_product).
            dot_product_update = tf.assign_sub(
                dot_product_avg, 0.01 * (dot_product_avg - dot_product))
            # Re-centre the dot product on its moving average; the control
            # dependency forces the average to be updated first.
            with tf.control_dependencies([dot_product_update]):
                dot_product = tf.identity(dot_product - dot_product_avg)

        # trains the scorer.
        # Entropy of the score distribution, summed across replicas and
        # divided by the static first dimension of valid_images.
        score_entropy = tf.reduce_sum(-score_probs * tf.math.log(score_probs))
        score_entropy = tf.tpu.cross_replica_sum(score_entropy) / float(
            valid_images.shape[0].value)
        score_variables = [
            v for v in tf.trainable_variables() if SCORE_SCOPE in v.name
        ]
        # dot_product went through stop_gradient above, so here it only
        # scales the gradient of the entropy term.
        score_gradients = tf.gradients(dot_product * score_entropy,
                                       score_variables)
        score_gradients = [
            tf.tpu.cross_replica_sum(g) for g in score_gradients
        ]
        score_optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=params.scorer_lr, use_locking=True)
        # The scorer is left untouched while global_step is below
        # params.scorer_wait_steps.
        score_train_op = tf.cond(
            global_step < params.scorer_wait_steps, tf.no_op,
            lambda: score_optimizer.apply_gradients(
                zip(score_gradients, score_variables)))

        # Logged tensors are built under a dependency on the scorer update.
        with tf.control_dependencies([score_train_op]):
            logs = collections.OrderedDict()
            logs['global_step'] = tf.cast(global_step, tf.float32)

            # NOTE(review): several entries below are divided by num_replicas
            # while others are not — presumably the consumer sums the logged
            # values across replicas; confirm against the outfeed reader.
            logs['model/total'] = total_loss
            logs['model/weight_decay'] = weight_dec / num_replicas
            logs['model/cross_entropy'] = cross_entropy
            logs['model/lr'] = tf.identity(learning_rate) / num_replicas
            logs['model/grad_norm'] = grad_norm / num_replicas

            logs['score/dot_product'] = dot_product / num_replicas
            logs['score/dot_product_avg'] = dot_product_avg / num_replicas
            logs['score/entropy'] = score_entropy
            logs['score/p_min'] = tf.reduce_min(score_probs) / num_replicas
            logs['score/p_max'] = tf.reduce_max(score_probs) / num_replicas

            # Each logged value gets a leading axis, matching the [1] shape
            # recorded in step_info.
            tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
            self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
            outfeed_enqueue_op = tf.cond(
                common_utils.should_log(params),
                lambda: tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors),
                tf.no_op)
        return outfeed_enqueue_op
Ejemplo n.º 14
0
  def build_train_graph(self,
                        inputs,
                        min_depth,
                        max_depth,
                        num_mpi_planes,
                        learning_rate=0.0002,
                        beta1=0.9,
                        vgg_model_file=None,
                        global_step=0):
    """Construct the training computation graph.

    Builds, in order: a random choice of image-size divisor and plane count,
    the resized inputs, MPI inference, rendering of the target view, the VGG
    losses for the initial and refined renderings, an Adam train op, and the
    TensorBoard summaries.

    Args:
      inputs: dictionary of tensors (see 'input_data' below) needed for
        training. Note that inputs["intrinsics"] is overwritten in place with
        the rescaled intrinsics.
      min_depth: minimum depth for the PSV and MPI planes
      max_depth: maximum depth for the PSV and MPI planes
      num_mpi_planes: number of MPI planes to infer. The value passed in is
        not used: it is replaced by the randomly chosen plane count in the
        "input_size_randomization" scope.
      learning_rate: learning rate
      beta1: hyperparameter for Adam
      vgg_model_file: path to vgg weights (needed when vgg loss is used)
      global_step: current optimization step. Not referenced anywhere in this
        method.
    Returns:
      A one-element list containing the train_op to be used for training.
    """
    print("starting to build graph")
    with tf.name_scope("input_size_randomization"):
      # Each row is [size divisor, number of MPI planes]. Shuffling the rows
      # and taking the first one picks a pair at random.
      dim_choices = tf.constant([[1, 16], [2, 32], [4, 32], [4, 64], [4, 128],
                                 [8, 32], [8, 64], [8, 128]],
                                dtype=tf.int32)
      rand_dim = tf.random_shuffle(dim_choices)[0, :]
      # Height and width share the same divisor (both read element 0).
      height_div = rand_dim[0]
      width_div = rand_dim[0]
      num_mpi_planes = rand_dim[1]
      tf.summary.scalar("num_mpi_planes", num_mpi_planes)

    with tf.name_scope("setup"):
      mpi_planes = self.inv_depths(min_depth, max_depth, num_mpi_planes)

    with tf.name_scope("input_data"):
      raw_tgt_image = inputs["tgt_image"]
      raw_ref_image = inputs["ref_image"]
      raw_src_images = inputs["src_images"]

      # Static height/width of the source images, reduced by the random
      # divisor chosen above.
      _, img_height, img_width, _ = raw_src_images.get_shape().as_list(
      )
      img_height = img_height // height_div
      img_width = img_width // width_div

      # Convert all images to float32, then resize them to the reduced size.
      raw_tgt_image = tf.image.convert_image_dtype(
          raw_tgt_image, dtype=tf.float32)
      raw_ref_image = tf.image.convert_image_dtype(
          raw_ref_image, dtype=tf.float32)
      raw_src_images = tf.image.convert_image_dtype(
          raw_src_images, dtype=tf.float32)
      raw_tgt_image = tf.image.resize_area(raw_tgt_image,
                                           [img_height, img_width])
      raw_ref_image = tf.image.resize_area(raw_ref_image,
                                           [img_height, img_width])
      raw_src_images = tf.image.resize_area(raw_src_images,
                                            [img_height, img_width])

      tgt_pose = inputs["tgt_pose"]
      ref_pose = inputs["ref_pose"]
      src_poses = inputs["src_poses"]
      intrinsics = inputs["intrinsics"]

      # Scale intrinsics based on size randomization: row 0 is divided by the
      # width divisor, row 1 by the height divisor, row 2 is left unchanged.
      intrinsics = tf.concat([
          intrinsics[:, 0:1, :] / tf.to_float(width_div),
          intrinsics[:, 1:2, :] / tf.to_float(height_div), intrinsics[:, 2:3, :]
      ],
                             axis=1)
      # Written back into the caller's dict.
      inputs["intrinsics"] = intrinsics

      # Number of source views, read from the static shape of src_poses.
      _, num_source, _, _ = src_poses.get_shape().as_list()

    with tf.name_scope("inference"):
      print("setting up MPI inference")
      # From here on num_mpi_planes is the dynamic length of mpi_planes.
      num_mpi_planes = tf.shape(mpi_planes)[0]
      pred = self.infer_mpi(raw_src_images, raw_ref_image, ref_pose, src_poses,
                            intrinsics, num_mpi_planes,
                            mpi_planes)
      rgba_layers = pred["rgba_layers"]
      rgba_layers_refine = pred["rgba_layers_refine"]
      stuff_behind = pred["stuff_behind"]
      refine_input_mpi = pred["refine_input_mpi"]
      psv = pred["psv"]

    with tf.name_scope("synthesis"):
      print("setting up rendering")
      # Pose used for rendering: tgt_pose multiplied by the inverse of
      # ref_pose.
      rel_pose = tf.matmul(tgt_pose, tf.matrix_inverse(ref_pose))
      output_image, output_layers = self.mpi_render_view(
          rgba_layers, rel_pose, mpi_planes, intrinsics)
      # Last channel of the rendered layers.
      output_alpha = output_layers[Ellipsis, -1]
      output_image_refine, _ = self.mpi_render_view(
          rgba_layers_refine, rel_pose, mpi_planes, intrinsics)

    with tf.name_scope("loss"):
      print("computing losses")
      # Mask loss for pixels outside reference frustum: the mask is 0 where
      # the channel sum of output_layers is exactly zero for at least one
      # index along axis 3, and 1 elsewhere.
      loss_mask = tf.where(
          tf.equal(
              tf.reduce_min(
                  tf.abs(tf.reduce_sum(output_layers, axis=-1)),
                  axis=3,
                  keep_dims=True), 0.0),
          tf.zeros_like(output_alpha[:, :, :, 0:1]),
          tf.ones_like(output_alpha[:, :, :, 0:1]))
      loss_mask = tf.stop_gradient(loss_mask)
      tf.summary.image("loss_mask", loss_mask)

      # Helper functions for loss
      def compute_error(real, fake, mask):
        # Mean of the masked absolute difference.
        return tf.reduce_mean(mask * tf.abs(fake - real))

      # Normalized VGG loss (from
      # https://github.com/CQFIO/PhotographicImageSynthesis)

      # Average-pools the mask by a factor of ds so it can be applied to the
      # deeper VGG feature maps.
      downsample = lambda tensor, ds: tf.nn.avg_pool(tensor, [1, ds, ds, 1],
                                                     [1, ds, ds, 1], "SAME")

      def vgg_loss(raw_tgt_image, output_image, loss_mask):
        """Compute VGG loss.

        Returns the summed, per-layer-weighted masked errors together with
        the two VGG feature dictionaries (target, rendered).
        """

        # The target is scaled by 255 directly, whereas the rendered image is
        # first mapped through (x + 1) / 2 — presumably because it lies in
        # [-1, 1]; confirm against mpi_render_view.
        vgg_real = build_vgg19(raw_tgt_image * 255.0, vgg_model_file)
        rescaled_output_image = (output_image + 1.)/2. * 255.0
        vgg_fake = build_vgg19(
            rescaled_output_image, vgg_model_file, reuse=True)
        p0 = compute_error(vgg_real["input"], vgg_fake["input"], loss_mask)
        p1 = compute_error(vgg_real["conv1_2"],
                           vgg_fake["conv1_2"],
                           loss_mask)/2.6
        p2 = compute_error(vgg_real["conv2_2"],
                           vgg_fake["conv2_2"],
                           downsample(loss_mask, 2))/4.8
        p3 = compute_error(vgg_real["conv3_2"],
                           vgg_fake["conv3_2"],
                           downsample(loss_mask, 4))/3.7
        p4 = compute_error(vgg_real["conv4_2"],
                           vgg_fake["conv4_2"],
                           downsample(loss_mask, 8))/5.6
        p5 = compute_error(vgg_real["conv5_2"],
                           vgg_fake["conv5_2"],
                           downsample(loss_mask, 16))*10/1.5
        total_loss = p0+p1+p2+p3+p4+p5
        return total_loss, vgg_real, vgg_fake

      # The total loss is the VGG loss of the initial rendering plus the VGG
      # loss of the refined rendering.
      vgg_loss_initial, _, _ = vgg_loss(raw_tgt_image, output_image, loss_mask)
      tf.summary.scalar("vgg_loss_initial", vgg_loss_initial)
      total_loss = vgg_loss_initial

      vgg_loss_refine, _, _ = vgg_loss(raw_tgt_image, output_image_refine,
                                       loss_mask)
      tf.summary.scalar("vgg_loss_refine", vgg_loss_refine)
      total_loss += vgg_loss_refine

    with tf.name_scope("train_op"):
      print("setting up train op")
      train_vars = [var for var in tf.trainable_variables()]
      optim = tf.train.AdamOptimizer(learning_rate, beta1)
      grads_and_vars = optim.compute_gradients(total_loss, var_list=train_vars)
      # Wrapped in a list; no global_step is passed to apply_gradients.
      train_op = [optim.apply_gradients(grads_and_vars)]

    # Summaries
    tf.summary.scalar("total_loss", total_loss)
    # Source images: each source view occupies 3 consecutive channels of
    # raw_src_images.
    for i in range(num_source):
      src_image = raw_src_images[:, :, :, i*3:(i+1)*3]
      tf.summary.image("src_image_%d" % i, src_image)
    # Output image
    tf.summary.image("output_image", self.deprocess_image(output_image))
    # Refined output image
    tf.summary.image("output_image_refine",
                     self.deprocess_image(output_image_refine))
    # Target image
    tf.summary.image("tgt_image", raw_tgt_image)
    # Ref image
    tf.summary.image("ref_image", raw_ref_image)
    # Predicted color and alpha layers, and PSV
    num_summ = 16  # Number of plane summaries to show in tensorboard
    for i in range(num_summ):
      # Plane index for the i-th summary, spread evenly over the planes.
      ind = tf.to_int32(i * num_mpi_planes/num_summ)
      rgb = rgba_layers[:, :, :, ind, :3]
      alpha = rgba_layers[:, :, :, ind, -1:]
      ref_plane = psv[:, :, :, ind, 3:6]
      source_plane = psv[:, :, :, ind, :3]
      output_rgb = output_layers[:, :, :, ind, :3]
      tf.summary.image("rgb_layer_%d" % i, self.deprocess_image(rgb))
      tf.summary.image("alpha_layer_%d" % i, alpha)
      tf.summary.image("rgba_layer_%d" % i, self.deprocess_image(rgb * alpha))
      tf.summary.image("psv_avg_%d" % i,
                       (self.deprocess_image(0.5*ref_plane + 0.5*source_plane)))
      tf.summary.image("output_rgb_%d" % i,
                       self.deprocess_image(output_rgb))
      tf.summary.image("psv_ref_%d" % i, self.deprocess_image(ref_plane))
      tf.summary.image("psv_source_%d" % i, self.deprocess_image(source_plane))

    # Cumulative rendered images and refined MPI
    for i in range(num_summ):
      ind = tf.to_int32(i * num_mpi_planes/num_summ)
      rgb = rgba_layers_refine[:, :, :, ind, :3]
      alpha = rgba_layers_refine[:, :, :, ind, 3:]
      render = stuff_behind[:, :, :, ind, :3]
      input_colors = refine_input_mpi[:, :, :, ind, :3]
      tf.summary.image("rgb_layer_refine_%d" % i, self.deprocess_image(rgb))
      tf.summary.image("alpha_layer_refine_%d" % i, alpha)
      tf.summary.image("rgba_layer_refine_%d" % i,
                       self.deprocess_image(rgb * alpha))
      tf.summary.image("cumulative_render_%d" % i, self.deprocess_image(render))
      tf.summary.image("input_colors_refine_%d" % i,
                       self.deprocess_image(input_colors))

    return train_op
Ejemplo n.º 15
0
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """Builds the masked-LM graph and wraps it in a `TPUEstimatorSpec`.

    `bert_config`, `init_checkpoint`, `learning_rate`, `num_train_steps`,
    `num_warmup_steps`, `use_tpu` and `use_one_hot_embeddings` are free
    variables resolved outside this function.

    Args:
      features: dict of input tensors. The keys read here are `input_ids`,
        `input_mask`, `segment_ids`, `masked_lm_positions`, `masked_lm_ids`
        and `masked_lm_weights`.
      labels: unused.
      mode: a `tf_estimator.ModeKeys` value; only TRAIN and EVAL are handled.
      params: unused.

    Returns:
      A `tf.contrib.tpu.TPUEstimatorSpec` for the requested mode.

    Raises:
      ValueError: if `mode` is neither TRAIN nor EVAL.
    """
    tf.logging.info("*** Features ***")
    for feature_name in sorted(features):
      tf.logging.info("  name = %s, shape = %s", feature_name,
                      features[feature_name].shape)

    # Pull the six inputs out of the feature dict in one pass.
    (input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids,
     masked_lm_weights) = [
         features[key] for key in (
             "input_ids", "input_mask", "segment_ids", "masked_lm_positions",
             "masked_lm_ids", "masked_lm_weights")
     ]

    training = mode == tf_estimator.ModeKeys.TRAIN

    bert = modeling.BertModel(
        config=bert_config,
        is_training=training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

    lm_outputs = get_masked_lm_output(
        bert_config, bert.get_sequence_output(), bert.get_embedding_table(),
        masked_lm_positions, masked_lm_ids, masked_lm_weights)
    masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs = lm_outputs

    # The masked-LM loss is the only loss term in this variant.
    total_loss = masked_lm_loss

    trainable_vars = tf.trainable_variables()

    scaffold_fn = None
    initialized_variable_names = {}
    if init_checkpoint:
      assignment_map, initialized_variable_names = (
          modeling.get_assignment_map_from_checkpoint(trainable_vars,
                                                      init_checkpoint))

      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      # On TPU the restore is deferred to the scaffold function; otherwise it
      # is registered immediately.
      if use_tpu:
        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for variable in trainable_vars:
      suffix = (", *INIT_FROM_CKPT*"
                if variable.name in initialized_variable_names else "")
      tf.logging.info("  name = %s, shape = %s%s", variable.name,
                      variable.shape, suffix)

    if mode == tf_estimator.ModeKeys.TRAIN:
      train_op = optimization.create_optimizer(total_loss, learning_rate,
                                               num_train_steps,
                                               num_warmup_steps, use_tpu)
      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn)

    if mode != tf_estimator.ModeKeys.EVAL:
      raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))

    def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                  masked_lm_weights):
      """Returns weighted masked-LM accuracy and mean loss metrics."""
      # Flatten everything so each masked position is one row.
      flat_log_probs = tf.reshape(masked_lm_log_probs,
                                  [-1, masked_lm_log_probs.shape[-1]])
      predicted_ids = tf.argmax(
          flat_log_probs, axis=-1, output_type=tf.int32)
      flat_example_loss = tf.reshape(masked_lm_example_loss, [-1])
      flat_ids = tf.reshape(masked_lm_ids, [-1])
      flat_weights = tf.reshape(masked_lm_weights, [-1])
      return {
          "masked_lm_accuracy":
              tf.metrics.accuracy(
                  labels=flat_ids,
                  predictions=predicted_ids,
                  weights=flat_weights),
          "masked_lm_loss":
              tf.metrics.mean(
                  values=flat_example_loss, weights=flat_weights),
      }

    metric_inputs = [
        masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
        masked_lm_weights
    ]
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        eval_metrics=(metric_fn, metric_inputs),
        scaffold_fn=scaffold_fn)
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """Builds the classification graph and returns a `TPUEstimatorSpec`.

        `bert_config`, `num_labels`, `init_checkpoint`, `learning_rate`,
        `num_train_steps`, `num_warmup_steps`, `use_tpu` and
        `use_one_hot_embeddings` are free variables resolved outside this
        function.

        Args:
          features: dict of input tensors. `input_ids`, `input_mask` and
            `segment_ids` are always read; `label_ids` (and, if present,
            `is_real_example`) are read only when `FLAGS.export_dir` is None.
          labels: unused.
          mode: a `tf.estimator.ModeKeys` value. TRAIN and EVAL get dedicated
            specs; any other mode yields a prediction spec.
          params: unused.

        Returns:
          A `tf.estimator.tpu.TPUEstimatorSpec` for the requested mode.
        """
        tf.logging.info("*** Features ***")
        for feature_name in sorted(features):
            tf.logging.info("  name = %s, shape = %s", feature_name,
                            features[feature_name].shape)

        input_ids, input_mask, segment_ids = [
            features[key]
            for key in ("input_ids", "input_mask", "segment_ids")
        ]

        # Labels are skipped entirely when an export directory is configured.
        label_ids = None
        is_real_example = None
        if FLAGS.export_dir is None:
            label_ids = features["label_ids"]
            if "is_real_example" not in features:
                # No explicit marker: weight every example equally.
                is_real_example = tf.ones(tf.shape(label_ids),
                                          dtype=tf.float32)
            else:
                is_real_example = tf.cast(features["is_real_example"],
                                          dtype=tf.float32)

        training = mode == tf.estimator.ModeKeys.TRAIN

        model_outputs = create_model(bert_config, training, input_ids,
                                     input_mask, segment_ids, label_ids,
                                     num_labels, use_one_hot_embeddings)
        total_loss, per_example_loss, logits, probabilities = model_outputs

        trainable_vars = tf.trainable_variables()
        scaffold_fn = None
        initialized_variable_names = {}
        if init_checkpoint:
            assignment_map, initialized_variable_names = (
                modeling.get_assignment_map_from_checkpoint(
                    trainable_vars, init_checkpoint))

            def tpu_scaffold():
                tf.train.init_from_checkpoint(init_checkpoint,
                                              assignment_map)
                return tf.train.Scaffold()

            # On TPU the restore is deferred to the scaffold function;
            # otherwise it is registered immediately.
            if use_tpu:
                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for variable in trainable_vars:
            suffix = (", *INIT_FROM_CKPT*"
                      if variable.name in initialized_variable_names else "")
            tf.logging.info("  name = %s, shape = %s%s", variable.name,
                            variable.shape, suffix)

        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps, use_tpu)
            return tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)

        if mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, label_ids, logits,
                          is_real_example):
                """Returns weighted accuracy and mean loss metrics."""
                predicted = tf.argmax(logits, axis=-1, output_type=tf.int32)
                return {
                    "eval_accuracy":
                        tf.metrics.accuracy(labels=label_ids,
                                            predictions=predicted,
                                            weights=is_real_example),
                    "eval_loss":
                        tf.metrics.mean(values=per_example_loss,
                                        weights=is_real_example),
                }

            metric_inputs = [
                per_example_loss, label_ids, logits, is_real_example
            ]
            return tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=(metric_fn, metric_inputs),
                scaffold_fn=scaffold_fn)

        # Any remaining mode: expose the class probabilities under a fixed
        # op name.
        probabilities = tf.identity(probabilities, name="probabilities")
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions={"probabilities": probabilities},
            scaffold_fn=scaffold_fn)
Ejemplo n.º 17
0
    def model_fn(features, labels, mode, params):
        """Builds the RACE graph and returns an evaluation or training spec.

        Args:
          features: dict of input tensors. `input_ids`, `segment_ids`,
            `input_mask` and `label_ids` are always read; `is_real_example`
            and `is_high_example` are read in EVAL mode.
          labels: discarded; labels come from `features["label_ids"]`.
          mode: a `tf.estimator.ModeKeys` value. EVAL returns an evaluation
            spec; every other mode falls through to the training spec.
          params: discarded.

        Returns:
          A `TPUEstimatorSpec` when `FLAGS.use_tpu` is set, otherwise a plain
          `EstimatorSpec`.
        """
        del labels, params

        #### Build model
        net_config = (
            modeling.ModelConfig.init_from_json(FLAGS.model_config)
            if FLAGS.model_config else modeling.ModelConfig.init_from_flags())
        # Persist the resolved configuration next to the checkpoints.
        net_config.to_json(os.path.join(FLAGS.model_dir, "net_config.json"))
        model = modeling.FunnelTFM(net_config)

        #### Training or Evaluation
        training = mode == tf.estimator.ModeKeys.TRAIN

        @model_utils.bf16_decorator
        def race_loss_func(features, model):
            """Returns the per-example RACE loss and the logits."""
            token_ids = features["input_ids"]
            segments = features["segment_ids"]
            mask = features["input_mask"]
            targets = features["label_ids"]

            with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
                example_loss, example_logits = model.get_race_loss(
                    targets,
                    token_ids,
                    training,
                    seg_id=segments,
                    input_mask=mask,
                    use_tpu=FLAGS.use_tpu,
                    use_bfloat16=FLAGS.use_bfloat16)

            return example_loss, example_logits

        per_example_loss, logits = race_loss_func(features, model)
        total_loss = tf.reduce_mean(per_example_loss)

        #### Check model parameters
        trainable = tf.trainable_variables()
        num_params = sum(np.prod(v.shape) for v in trainable)
        tf.logging.info("#params: %d", num_params)
        if FLAGS.verbose:
            # Left-align names in a column as wide as the longest name.
            name_width = max(len(v.name) for v in trainable)
            format_str = "{{:<{0}s}}\t{{}}".format(name_width)
            for v in trainable:
                tf.logging.info(format_str.format(v.name, v.get_shape()))

        #### Load pretrained models
        scaffold_fn = model_utils.custom_initialization(FLAGS.init_global_vars)

        #### Evaluation mode
        if mode == tf.estimator.ModeKeys.EVAL:
            assert FLAGS.num_hosts == 1

            def metric_fn(per_example_loss, label_ids, logits, is_real_example,
                          is_high_example):
                """Returns overall, high-subset and mid-subset accuracy plus loss."""
                predicted = tf.argmax(logits, axis=-1, output_type=tf.int32)

                accuracy = tf.metrics.accuracy(
                    labels=label_ids,
                    predictions=predicted,
                    weights=is_real_example)
                # Restrict to examples flagged as "high".
                accuracy_high = tf.metrics.accuracy(
                    labels=label_ids,
                    predictions=predicted,
                    weights=is_real_example * is_high_example)
                # Complement of the "high" flag.
                accuracy_mid = tf.metrics.accuracy(
                    labels=label_ids,
                    predictions=predicted,
                    weights=is_real_example * (1.0 - is_high_example))
                mean_loss = tf.metrics.mean(values=per_example_loss,
                                            weights=is_real_example)

                return {
                    "eval_accuracy": accuracy,
                    "eval_accuracy_high": accuracy_high,
                    "eval_accuracy_mid": accuracy_mid,
                    "eval_loss": mean_loss
                }

            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
            is_high_example = tf.cast(features["is_high_example"],
                                      dtype=tf.float32)

            #### Constructing the evaluation spec
            label_ids = tf.reshape(features["label_ids"], [-1])
            metric_args = [
                per_example_loss, label_ids, logits, is_real_example,
                is_high_example
            ]

            if FLAGS.use_tpu:
                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    eval_metrics=(metric_fn, metric_args),
                    scaffold_fn=scaffold_fn)
            return tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metric_ops=metric_fn(*metric_args))

        #### Get train op
        train_op, _ = optimization.get_train_op(total_loss)

        #### Constructing the training spec
        if not FLAGS.use_tpu:
            return tf.estimator.EstimatorSpec(mode=mode,
                                              loss=total_loss,
                                              train_op=train_op)

        # No host call is attached on TPU.
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            host_call=None,
            scaffold_fn=scaffold_fn)
Ejemplo n.º 18
0
  def build_train_graph(self,
                        inputs,
                        min_depth,
                        max_depth,
                        cube_res,
                        theta_res,
                        phi_res,
                        r_res,
                        scale_factors,
                        num_mpi_planes,
                        learning_rate=0.0001,
                        vgg_model_weights=None,
                        global_step=0,
                        depth_clip=20.0):
    """Construct the training computation graph.

    Builds, in order: MPI inference, target-view rendering, lighting-volume
    prediction and environment-map rendering, the loss terms, the two
    optimizers, and the TensorBoard summaries.

    Args:
      inputs: dictionary of tensors needed for training. The keys read are
        'tgt_image', 'ref_image', 'src_images', 'env_image', 'ref_depth',
        'tgt_pose', 'ref_pose', 'src_poses', 'env_pose' and 'intrinsics'.
      min_depth: minimum depth for the PSV and MPI planes
      max_depth: maximum depth for the PSV and MPI planes
      cube_res: per-side cube resolution
      theta_res: environment map width
      phi_res: environment map height
      r_res: number of radii to use when sampling spheres for rendering
      scale_factors: downsampling factors of cubes relative to the coarsest
      num_mpi_planes: number of MPI planes to infer
      learning_rate: learning rate, shared by both Adam optimizers
      vgg_model_weights: vgg weights, passed to `nets.build_vgg19` for the
        VGG losses
      global_step: training iteration; compared against the fixed thresholds
        240000 and 690000 to switch the envmap and adversarial terms on
      depth_clip: maximum depth for coarsest resampled volumes

    Returns:
      A list of two ops: the gradient update for all trainable variables whose
      name does not contain 'discriminator', followed by the update of the
      variables whose name does.
    """
    with tf.name_scope('setup'):
      # Both plane sets come from the same call with the same arguments.
      psv_planes = pj.inv_depths(min_depth, max_depth, num_mpi_planes)
      mpi_planes = pj.inv_depths(min_depth, max_depth, num_mpi_planes)

    with tf.name_scope('input_data'):

      tgt_image = inputs['tgt_image']
      ref_image = inputs['ref_image']
      src_images = inputs['src_images']
      env_image = inputs['env_image']

      ref_depth = inputs['ref_depth']

      tgt_pose = inputs['tgt_pose']
      ref_pose = inputs['ref_pose']
      src_poses = inputs['src_poses']
      env_pose = inputs['env_pose']

      intrinsics = inputs['intrinsics']

      # The 4-way unpack requires src_poses to have static rank 4; its last
      # dimension is used as the number of source views.
      _, _, _, num_source = src_poses.get_shape().as_list()

    with tf.name_scope('inference'):
      # Rebinds the num_mpi_planes argument to a tensor (size of the first
      # axis of mpi_planes); only the summary loop at the end reads it.
      num_mpi_planes = tf.shape(mpi_planes)[0]
      pred = self.infer_mpi(src_images, ref_image, ref_pose, src_poses,
                            intrinsics, psv_planes)
      rgba_layers = pred['rgba_layers']
      psv = pred['psv']

    with tf.name_scope('synthesis'):
      # Render the predicted MPI into the target view; the accumulated alpha
      # is used below to build the loss mask.
      output_image, output_alpha_acc, _ = self.mpi_render_view(
          rgba_layers, ref_pose, tgt_pose, mpi_planes, intrinsics)
    with tf.name_scope('environment_rendering'):
      # Reference MPI built from the reference image and its depth, rendered
      # the same way as the prediction (only used for a summary image).
      mpi_gt = self.img2mpi(ref_image, ref_depth, mpi_planes)
      output_image_gt, _, _ = self.mpi_render_view(mpi_gt, ref_pose, tgt_pose,
                                                   mpi_planes, intrinsics)

      # Lighting volumes from the reference MPI; only the volumes are kept.
      lightvols_gt, _, _, _, _ = self.predict_lighting_vol(
          mpi_gt,
          mpi_planes,
          intrinsics,
          cube_res,
          scale_factors,
          depth_clip=depth_clip)

      # Lighting volumes from the predicted MPI. The centers, side lengths,
      # relative shapes and nesting indices returned here are reused for all
      # three envmap renders below, including the reference one.
      lightvols, lightvol_centers, \
      lightvol_side_lengths, \
      cube_rel_shapes, \
      cube_nest_inds = self.predict_lighting_vol(rgba_layers, mpi_planes,
                                                 intrinsics, cube_res,
                                                 scale_factors,
                                                 depth_clip=depth_clip)

      # Network-processed version of the predicted volumes.
      lightvols_out = nets.cube_net_multires(lightvols, cube_rel_shapes,
                                             cube_nest_inds)

      # Environment map from the reference volumes.
      gt_envmap, gt_shells = self.render_envmap(lightvols_gt, lightvol_centers,
                                                lightvol_side_lengths,
                                                cube_rel_shapes, cube_nest_inds,
                                                ref_pose, env_pose, theta_res,
                                                phi_res, r_res)

      # Environment map from the predicted volumes before the cube network.
      prenet_envmap, prenet_shells = self.render_envmap(
          lightvols, lightvol_centers, lightvol_side_lengths, cube_rel_shapes,
          cube_nest_inds, ref_pose, env_pose, theta_res, phi_res, r_res)

      # Environment map from the cube-network output; this is the one the
      # envmap and adversarial losses are computed on.
      output_envmap, output_shells = self.render_envmap(
          lightvols_out, lightvol_centers, lightvol_side_lengths,
          cube_rel_shapes, cube_nest_inds, ref_pose, env_pose, theta_res,
          phi_res, r_res)

    with tf.name_scope('loss'):
      # mask loss for pixels outside reference frustum: 0 where the
      # accumulated alpha is exactly 0, 1 elsewhere (single channel).
      loss_mask = tf.where(
          tf.equal(output_alpha_acc[Ellipsis, tf.newaxis], 0.0),
          tf.zeros_like(output_image[:, :, :, 0:1]),
          tf.ones_like(output_image[:, :, :, 0:1]))
      # The mask is treated as a constant during backprop.
      loss_mask = tf.stop_gradient(loss_mask)
      tf.summary.image('loss_mask', loss_mask)

      # helper functions for loss
      def compute_error(real, fake, mask):
        """Masked mean absolute difference between `fake` and `real`."""
        # Broadcast the mask to the full shape of `real` so the denominator
        # counts every masked element.
        mask = tf.ones_like(real) * mask
        # The small constant guards against an all-zero mask.
        return tf.reduce_sum(mask * tf.abs(fake - real)) / (
            tf.reduce_sum(mask) + 1.0e-8)

      # Normalized VGG loss
      def downsample(tensor, ds):
        """Average-pools `tensor` by a factor of `ds` with 'SAME' padding."""
        return tf.nn.avg_pool(tensor, [1, ds, ds, 1], [1, ds, ds, 1], 'SAME')

      def vgg_loss(tgt_image, output_image, loss_mask, vgg_weights):
        """VGG activation loss definition.

        Sums the masked error on the input and on five VGG19 activations,
        each scaled by a fixed constant, with the mask downsampled by the
        factor given for each layer.
        """

        # Both images are scaled by 255 before being fed to VGG.
        vgg_real = nets.build_vgg19(tgt_image * 255.0, vgg_weights)
        rescaled_output_image = output_image * 255.0
        vgg_fake = nets.build_vgg19(rescaled_output_image, vgg_weights)
        p0 = compute_error(vgg_real['input'], vgg_fake['input'], loss_mask)
        p1 = compute_error(vgg_real['conv1_2'], vgg_fake['conv1_2'],
                           loss_mask) / 2.6
        p2 = compute_error(vgg_real['conv2_2'], vgg_fake['conv2_2'],
                           downsample(loss_mask, 2)) / 4.8
        p3 = compute_error(vgg_real['conv3_2'], vgg_fake['conv3_2'],
                           downsample(loss_mask, 4)) / 3.7
        p4 = compute_error(vgg_real['conv4_2'], vgg_fake['conv4_2'],
                           downsample(loss_mask, 8)) / 5.6
        p5 = compute_error(vgg_real['conv5_2'], vgg_fake['conv5_2'],
                           downsample(loss_mask, 16)) * 10 / 1.5
        total_loss = p0 + p1 + p2 + p3 + p4 + p5
        return total_loss

      # rendered image loss
      render_loss = vgg_loss(tgt_image, output_image, loss_mask,
                             vgg_model_weights) / 100.0
      total_loss = render_loss

      # rendered envmap loss, computed on the first three channels of the
      # output envmap with an all-ones mask
      envmap_loss = vgg_loss(env_image, output_envmap[Ellipsis, :3],
                             tf.ones_like(env_image[Ellipsis, 0:1]),
                             vgg_model_weights) / 100.0

      # set envmap loss to 0 when only training mpi network (see paper),
      # i.e. until global_step exceeds 240000
      envmap_loss = tf.where(tf.greater(global_step, 240000), envmap_loss, 0.0)

      total_loss += envmap_loss

      # adversarial loss for envmap; both calls use the same scope name
      real_logit = nets.discriminator(env_image, scope='discriminator')
      fake_logit = nets.discriminator(
          output_envmap[Ellipsis, :3], scope='discriminator')
      # presumably the discriminator returns one sequence per scale, with the
      # final logits as the last element (hence the [i][-1]) - TODO confirm
      # against nets.discriminator.
      adv_loss_list = []
      for i in range(len(fake_logit)):
        # Generator term: negative mean of the fake logits, weighted by 0.1.
        adv_loss_list.append(0.1 * -1.0 * tf.reduce_mean(fake_logit[i][-1]))
      adv_loss = tf.reduce_mean(adv_loss_list)
      real_loss_list = []
      fake_loss_list = []
      for i in range(len(fake_logit)):
        # Discriminator terms: -mean(min(real - 1, 0)) for real inputs and
        # -mean(min(-fake - 1, 0)) for generated ones.
        real_loss_list.append(
            -1.0 * tf.reduce_mean(tf.minimum(real_logit[i][-1] - 1, 0.0)))
        fake_loss_list.append(
            -1.0 *
            tf.reduce_mean(tf.minimum(-1.0 * fake_logit[i][-1] - 1, 0.0)))
      real_loss = tf.reduce_mean(real_loss_list)
      fake_loss = tf.reduce_mean(fake_loss_list)
      disc_loss = real_loss + fake_loss

      # set adv/disc losses to 0 until end of training, i.e. until
      # global_step exceeds 690000
      adv_loss = tf.where(tf.greater(global_step, 690000), adv_loss, 0.0)
      disc_loss = tf.where(tf.greater(global_step, 690000), disc_loss, 0.0)

      # loss_disc and loss_adv log the gated values; the real/fake components
      # are logged ungated.
      tf.summary.scalar('loss_disc', disc_loss)
      tf.summary.scalar('loss_disc_real', real_loss)
      tf.summary.scalar('loss_disc_fake', fake_loss)
      tf.summary.scalar('loss_adv', adv_loss)

      total_loss += adv_loss

    with tf.name_scope('train_op'):
      # Everything except the discriminator is trained on total_loss.
      train_variables = [
          var for var in tf.trainable_variables()
          if 'discriminator' not in var.name
      ]
      optim = tf.train.AdamOptimizer(learning_rate, epsilon=1e-4)
      grads_and_variables = optim.compute_gradients(
          total_loss, var_list=train_variables)
      grads = [gv[0] for gv in grads_and_variables]
      variables = [gv[1] for gv in grads_and_variables]

      def denan(x):
        """Replaces NaN entries of `x` with zeros."""
        return tf.where(tf.is_nan(x), tf.zeros_like(x), x)

      # NOTE(review): denan is applied to every gradient; a None gradient
      # (a variable not reached by total_loss) would presumably make
      # tf.is_nan fail - confirm all train_variables feed the loss.
      grads_clipped = [denan(g) for g in grads]
      # Clip the NaN-free gradients to a global norm of 100.
      grads_clipped, _ = tf.clip_by_global_norm(grads_clipped, 100.0)
      train_op = [optim.apply_gradients(zip(grads_clipped, variables))]
      # The first summary uses the raw gradients (before NaN replacement and
      # clipping); the second uses the clipped ones.
      tf.summary.scalar('gradient global norm', tf.linalg.global_norm(grads))
      tf.summary.scalar('clipped gradient global norm',
                        tf.linalg.global_norm(grads_clipped))

      # The discriminator gets its own Adam optimizer (beta1=0.0) and is
      # trained on disc_loss only.
      d_variables = [
          var for var in tf.trainable_variables() if 'discriminator' in var.name
      ]
      optim_d = tf.train.AdamOptimizer(learning_rate, beta1=0.0)
      train_op.append(optim_d.minimize(disc_loss, var_list=d_variables))

    # Envmap summaries: the full map, its last channel, and one composited
    # image per shell level, for each of the three envmap variants.
    with tf.name_scope('envmap_gt'):
      tf.summary.image('envmap', gt_envmap)
      tf.summary.image('envmap_alpha', gt_envmap[Ellipsis, -1:])
      for i in range(len(gt_shells)):
        i_envmap = pj.over_composite(gt_shells[i])
        tf.summary.image('envmap_level_' + str(i), i_envmap)
    with tf.name_scope('envmap_prenet'):
      tf.summary.image('envmap', prenet_envmap)
      tf.summary.image('envmap_alpha', prenet_envmap[Ellipsis, -1:])
      for i in range(len(prenet_shells)):
        i_envmap = pj.over_composite(prenet_shells[i])
        tf.summary.image('envmap_level_' + str(i), i_envmap)
    with tf.name_scope('envmap_output'):
      tf.summary.image('envmap', output_envmap)
      tf.summary.image('envmap_alpha', output_envmap[Ellipsis, -1:])
      for i in range(len(output_shells)):
        i_envmap = pj.over_composite(output_shells[i])
        tf.summary.image('envmap_level_' + str(i), i_envmap)

    # Scalar summaries; loss_envmap is the gated value.
    tf.summary.scalar('loss_total', total_loss)
    tf.summary.scalar('loss_render', render_loss)
    tf.summary.scalar('loss_envmap', envmap_loss)
    tf.summary.scalar('min_depth', min_depth)
    tf.summary.scalar('max_depth', max_depth)

    with tf.name_scope('level_stats'):
      # One side length and one center entry (index [0, -1]) per volume level.
      for i in range(len(lightvols)):
        tf.summary.scalar('cube_side_length_' + str(i),
                          lightvol_side_lengths[i])
        tf.summary.scalar('cube_center_' + str(i), lightvol_centers[i][0, -1])

    # Source images: src_images is sliced into groups of 3 channels, one
    # group per source view.
    for i in range(num_source):
      src_image = src_images[:, :, :, i * 3:(i + 1) * 3]
      tf.summary.image('image_src_%d' % i, src_image)
    # Output image
    tf.summary.image('image_output', output_image)
    tf.summary.image('image_output_Gt', output_image_gt)
    # Target image
    tf.summary.image('image_tgt', tgt_image)
    tf.summary.image('envmap_tgt', env_image)
    # Ref image
    tf.summary.image('image_ref', ref_image)
    # Predicted color and alpha layers, and PSV
    num_summ = 8  # number of plane summaries to show in tensorboard
    for i in range(num_summ):
      # Evenly spaced plane index; num_mpi_planes is the tensor bound in the
      # 'inference' scope above.
      ind = tf.to_int32(i * num_mpi_planes / num_summ)
      rgb = rgba_layers[:, :, :, ind, :3]
      alpha = rgba_layers[:, :, :, ind, -1:]
      # The first three PSV channels are treated as the reference view and
      # the next three as a source view.
      ref_plane = psv[:, :, :, ind, :3]
      source_plane = psv[:, :, :, ind, 3:6]
      tf.summary.image('layer_rgb_%d' % i, rgb)
      tf.summary.image('layer_alpha_%d' % i, alpha)
      tf.summary.image('layer_rgba_%d' % i, rgba_layers[:, :, :, ind, :])
      tf.summary.image('psv_avg_%d' % i, 0.5 * ref_plane + 0.5 * source_plane)
      tf.summary.image('psv_ref_%d' % i, ref_plane)
      tf.summary.image('psv_source_%d' % i, source_plane)

    return train_op
Ejemplo n.º 19
0
  def model_fn(features, labels, mode, params=None):
    """Build model and optimizer."""
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Check training mode.
    if FLAGS.train_mode == 'pretrain':
      num_transforms = 2
      if FLAGS.fine_tune_after_block > -1:
        raise ValueError('Does not support layer freezing during pretraining,'
                         'should set fine_tune_after_block<=-1 for safety.')
    elif FLAGS.train_mode == 'finetune':
      num_transforms = 1
    else:
      raise ValueError('Unknown train_mode {}'.format(FLAGS.train_mode))

    # Split channels, and optionally apply extra batched augmentation.
    features_list = tf.split(
        features, num_or_size_splits=num_transforms, axis=-1)
    if FLAGS.use_blur and is_training and FLAGS.train_mode == 'pretrain':
      features_list = data_util.batch_random_blur(
          features_list, FLAGS.image_size, FLAGS.image_size)
    features = tf.concat(features_list, 0)  # (num_transforms * bsz, h, w, c)

    # Base network forward pass.
    with tf.variable_scope('base_model'):
      if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block >= 4:
        # Finetune just supervised (linear) head will not update BN stats.
        model_train_mode = False
      else:
        # Pretrain or finetuen anything else will update BN stats.
        model_train_mode = is_training
      hiddens = model(features, is_training=model_train_mode)

    # Add head and loss.
    if FLAGS.train_mode == 'pretrain':
      tpu_context = params['context'] if 'context' in params else None
      hiddens_proj = model_util.projection_head(hiddens, is_training)
      contrast_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
          hiddens_proj,
          hidden_norm=FLAGS.hidden_norm,
          temperature=FLAGS.temperature,
          tpu_context=tpu_context if is_training else None,
          loss_type=FLAGS.loss_type,
          flags=FLAGS)
      logits_sup = tf.zeros([params['batch_size'], num_classes])
      gradients_penalty = FLAGS.gradient_penalty_weight * obj_lib.add_gradients_penalty(features, model, model_train_mode)
    else:
      contrast_loss = tf.zeros([])
      logits_con = tf.zeros([params['batch_size'], 10])
      labels_con = tf.zeros([params['batch_size'], 10])
      hiddens = model_util.projection_head(hiddens, is_training)
      logits_sup = model_util.supervised_head(
          hiddens, num_classes, is_training)
      obj_lib.add_supervised_loss(
          labels=labels['labels'],
          logits=logits_sup,
          weights=labels['mask'])

    # Add weight decay to loss, for non-LARS optimizers.
    model_util.add_weight_decay(adjust_per_optimizer=True)
    loss = tf.losses.get_total_loss()

    if FLAGS.train_mode == 'pretrain':
      variables_to_train = tf.trainable_variables()
    else:
      collection_prefix = 'trainable_variables_inblock_'
      variables_to_train = []
      for j in range(FLAGS.fine_tune_after_block + 1, 6):
        variables_to_train += tf.get_collection(collection_prefix + str(j))
      assert variables_to_train, 'variables_to_train shouldn\'t be empty!'

    tf.logging.info('===============Variables to train (begin)===============')
    tf.logging.info(variables_to_train)
    tf.logging.info('================Variables to train (end)================')

    learning_rate = model_util.learning_rate_schedule(
        FLAGS.learning_rate, num_train_examples)

    if is_training:
      if FLAGS.train_summary_steps > 0:
        # Compute stats for the summary.
        prob_con = tf.nn.softmax(logits_con)
        entropy_con = - tf.reduce_mean(
            tf.reduce_sum(prob_con * tf.math.log(prob_con + 1e-8), -1))

        summary_writer = tf2.summary.create_file_writer(FLAGS.model_dir)
        with tf.control_dependencies([summary_writer.init()]):
          with summary_writer.as_default():
            should_record = tf.math.equal(
                tf.math.floormod(tf.train.get_global_step(),
                                 FLAGS.train_summary_steps), 0)
            with tf2.summary.record_if(should_record):
              contrast_acc = tf.equal(
                  tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1))
              contrast_acc = tf.reduce_mean(tf.cast(contrast_acc, tf.float32))
              label_acc = tf.equal(
                  tf.argmax(labels['labels'], 1), tf.argmax(logits_sup, axis=1))
              label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))
              tf2.summary.scalar(
                  'train_contrast_loss',
                  contrast_loss,
                  step=tf.train.get_global_step())
              tf2.summary.scalar(
                  'train_contrast_acc',
                  contrast_acc,
                  step=tf.train.get_global_step())
              tf2.summary.scalar(
                  'train_label_accuracy',
                  label_acc,
                  step=tf.train.get_global_step())
              tf2.summary.scalar(
                  'contrast_entropy',
                  entropy_con,
                  step=tf.train.get_global_step())
              tf2.summary.scalar(
                  'learning_rate', learning_rate,
                  step=tf.train.get_global_step())

      optimizer = model_util.get_optimizer(learning_rate)
      control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      if FLAGS.train_summary_steps > 0:
        control_deps.extend(tf.summary.all_v2_summary_ops())
      with tf.control_dependencies(control_deps):
        train_op = optimizer.minimize(
            loss, global_step=tf.train.get_or_create_global_step(),
            var_list=variables_to_train)

      if FLAGS.checkpoint:
        def scaffold_fn():
          """Scaffold function to restore non-logits vars from checkpoint."""
          tf.train.init_from_checkpoint(
              FLAGS.checkpoint,
              {v.op.name: v.op.name
               for v in tf.global_variables(FLAGS.variable_schema)})

          if FLAGS.zero_init_logits_layer:
            # Init op that initializes output layer parameters to zeros.
            output_layer_parameters = [
                var for var in tf.trainable_variables() if var.name.startswith(
                    'head_supervised')]
            tf.logging.info('Initializing output layer parameters %s to zero',
                            [x.op.name for x in output_layer_parameters])
            with tf.control_dependencies([tf.global_variables_initializer()]):
              init_op = tf.group([
                  tf.assign(x, tf.zeros_like(x))
                  for x in output_layer_parameters])
            return tf.train.Scaffold(init_op=init_op)
          else:
            return tf.train.Scaffold()
      else:
        scaffold_fn = None

      return tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode, train_op=train_op, loss=loss, scaffold_fn=scaffold_fn)
    else:

      def metric_fn(logits_sup, labels_sup, logits_con, labels_con, mask,
                    **kws):
        """Inner metric function."""
        metrics = {k: tf.metrics.mean(v, weights=mask)
                   for k, v in kws.items()}
        metrics['label_top_1_accuracy'] = tf.metrics.accuracy(
            tf.argmax(labels_sup, 1), tf.argmax(logits_sup, axis=1),
            weights=mask)
        metrics['label_top_5_accuracy'] = tf.metrics.recall_at_k(
            tf.argmax(labels_sup, 1), logits_sup, k=5, weights=mask)
        metrics['contrastive_top_1_accuracy'] = tf.metrics.accuracy(
            tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1),
            weights=mask)
        metrics['contrastive_top_5_accuracy'] = tf.metrics.recall_at_k(
            tf.argmax(labels_con, 1), logits_con, k=5, weights=mask)
        return metrics

      metrics = {
          'logits_sup': logits_sup,
          'labels_sup': labels['labels'],
          'logits_con': logits_con,
          'labels_con': labels_con,
          'mask': labels['mask'],
          'contrast_loss': tf.fill((params['batch_size'],), contrast_loss),
          'regularization_loss': tf.fill((params['batch_size'],),
                                         tf.losses.get_regularization_loss()),
      }

      return tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metrics=(metric_fn, metrics),
          scaffold_fn=None)
Ejemplo n.º 20
0
    def auxiliary_loss_fn(state):
        """Builds the auxiliary loss tensor for the task named in `state`.

        The `kwargs` of the enclosing scope are walked in order. Every
        recognized penalty whose weight is greater than zero contributes one
        term to the loss; entries that only parameterize another penalty are
        skipped here and looked up by the penalty that needs them.

        Args:
          state: (dict) PathNet state as a dict containing a 'task_id' entry
            with scalar task id.

        Raises:
          Exception: if some parameter names are not recognized.

        Returns:
          The total auxiliary loss for the given task and at the given
          timestep.
        """
        task = state['task_id']
        total = tf.constant(0.0)

        # Settings that belong to another penalty; they are not penalties
        # themselves and are fetched directly from `kwargs` when required.
        companion_keys = ('budget', 'num_total_components',
                          'entropy_penalty_alpha', 'num_total_steps')

        def expected_connections():
            # Sum over routers of the expected connection count for the task.
            return tf.reduce_sum([
                router.get_expected_number_of_connections_for_task(task)
                for router in routers
            ])

        for name, weight in kwargs.items():
            if name in companion_keys:
                continue

            if name == 'disconnect_penalty':
                if weight > 0.0:
                    total += weight * tf.reduce_sum([
                        router.get_probability_of_having_no_connections(task)
                        for router in routers
                    ])
                continue

            if name == 'connection_penalty':
                if weight > 0.0:
                    total += weight * expected_connections()
                continue

            if name == 'budget_penalty':
                if weight > 0.0:
                    budget = kwargs['budget']
                    component_count = kwargs['num_total_components']
                    used_fraction = expected_connections() / component_count
                    # Only the part of the fraction above the budget counts.
                    total += weight * tf.math.maximum(
                        tf.constant(0.0), used_fraction - budget)
                continue

            if name == 'entropy_penalty':
                if weight > 0.0:
                    alpha = kwargs['entropy_penalty_alpha']
                    step_horizon = kwargs['num_total_steps']
                    step = tf.train.get_or_create_global_step()

                    # The weight is scaled by (step / horizon) ** alpha.
                    scheduled_weight = tf.dtypes.cast(
                        weight * tf.math.pow(step / step_horizon, alpha),
                        dtype=tf.float32)
                    total += scheduled_weight * tf.reduce_sum(
                        [router.get_entropy(task) for router in routers])
                continue

            if name == 'l2_penalty':
                if weight > 0.0:
                    # Penalize all trainable variables apart from biases and
                    # allocation logits.
                    l2_terms = [
                        tf.nn.l2_loss(var)
                        for var in tf.trainable_variables()
                        if not ('bias' in var.name
                                or 'router_dist' in var.name)
                    ]
                    total += tf.add_n(l2_terms) * weight
                continue

            raise Exception(
                'Unrecognized parameter for auxiliary losses: %s' % name)

        return total
Ejemplo n.º 21
0
def create_optimizer(loss,
                     init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     use_tpu,
                     optimizer="adamw",
                     poly_power=1.0,
                     start_warmup_step=0,
                     gradient_accumulation_steps=1,
                     grad_clipping=None):
  """Creates an optimizer training op.

  Args:
    loss: tensor that is differentiated with respect to every trainable
      variable.
    init_lr: learning rate at the start of the polynomial decay; also the
      value the warmup ramps towards.
    num_train_steps: number of steps over which the rate decays to 0.0.
    num_warmup_steps: length of the linear warmup; a falsy value disables it.
    use_tpu: if true, wraps the optimizer in `tf.tpu.CrossShardOptimizer`.
    optimizer: name of the optimizer; only "adamw" is supported.
    poly_power: power of the polynomial decay (1.0 is a linear decay).
    start_warmup_step: global step at which the warmup ramp begins.
    gradient_accumulation_steps: when greater than 1, the optimizer is wrapped
      in `GradientAccumulationOptimizer` with this many steps.
    grad_clipping: forwarded unchanged to `AdamWeightDecayOptimizer` and
      `GradientAccumulationOptimizer`.

  Returns:
    The op returned by the optimizer's `apply_gradients`.

  Raises:
    ValueError: if `optimizer` is not a supported optimizer name.
  """
  # Reject an unsupported optimizer before any graph state (global step,
  # learning-rate ops) is created, and with a single readable message.
  if optimizer != "adamw":
    raise ValueError("Not supported optimizer: %s" % (optimizer,))

  global_step = tf.train.get_or_create_global_step()

  learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)

  # Implements polynomial decay of the learning rate (linear when
  # `poly_power` is 1.0).
  learning_rate = tf.train.polynomial_decay(
      learning_rate,
      global_step,
      num_train_steps,
      end_learning_rate=0.0,
      power=poly_power,
      cycle=False)

  # Implements linear warmup. I.e., if global_step - start_warmup_step <
  # num_warmup_steps, the learning rate will be
  # `(global_step - start_warmup_step)/num_warmup_steps * init_lr`.
  if num_warmup_steps:
    tf.logging.info("++++++ warmup starts at step %s, for %s steps ++++++",
                    start_warmup_step, num_warmup_steps)
    global_steps_int = tf.cast(global_step, tf.int32)
    start_warm_int = tf.constant(start_warmup_step, dtype=tf.int32)
    global_steps_int = global_steps_int - start_warm_int
    warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

    global_steps_float = tf.cast(global_steps_int, tf.float32)
    warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

    warmup_percent_done = global_steps_float / warmup_steps_float
    warmup_learning_rate = init_lr * warmup_percent_done

    # `is_warmup` is 1.0 during warmup and 0.0 afterwards, so the expression
    # below selects exactly one of the two schedules.
    is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
    learning_rate = (
        (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)

  # It is OK that you use this optimizer for finetuning, since this
  # is how the model was trained (note that the Adam m/v variables are NOT
  # loaded from init_checkpoint.)
  # It is OK to use AdamW in the finetuning even the model is trained by LAMB.
  # As reported in the BERT public GitHub, the learning rate for SQuAD 1.1
  # finetune is 3e-5, 4e-5 or 5e-5. For LAMB, the users can use 3e-4, 4e-4, or
  # 5e-4 for a batch size of 64 in the finetune.
  tf.logging.info("using adamw")
  opt = AdamWeightDecayOptimizer(
      learning_rate=learning_rate,
      weight_decay_rate=0.01,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-6,
      exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
      grad_clipping=grad_clipping)

  # This is empirically better than adding the optimizer after the `use_tpu` if.
  if gradient_accumulation_steps > 1:
    opt = GradientAccumulationOptimizer(
        opt,
        steps=gradient_accumulation_steps,
        grad_clipping=grad_clipping)

  if use_tpu:
    opt = tf.tpu.CrossShardOptimizer(opt)

  tvars = tf.trainable_variables()
  grads = tf.gradients(loss, tvars)

  # This is how the model was pre-trained.
  (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)

  train_op = opt.apply_gradients(
      zip(grads, tvars), global_step=global_step)

  return train_op
Ejemplo n.º 22
0
  def model_fn(features, labels, mode, params=None):
    """Constructs the object detection model.

    Predictions are built for every mode. On top of that, TRAIN adds the
    losses and a train op, EVAL adds the losses, postprocessed detections and
    metric ops, and PREDICT adds postprocessed detections and export outputs.

    NOTE(review): `use_tpu`, `train_config`, `eval_config`,
    `eval_input_config`, `configs`, `hparams`, `detection_model_fn` and
    `postprocess_on_cpu` are not defined inside this function; they are
    presumably captured from the enclosing scope -- confirm there.

    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
        otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator.

    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
        configurations. When `use_tpu` is set and the mode is not EVAL, a
        `TPUEstimatorSpec` is returned instead.
    """
    params = params or {}
    total_loss, train_op, detections, export_outputs = None, None, None, None
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Make sure to set the Keras learning phase. True during training,
    # False for inference.
    tf.keras.backend.set_learning_phase(is_training)
    # Set policy for mixed-precision training with Keras-based models.
    if use_tpu and train_config.use_bfloat16:
      from tensorflow.python.keras.engine import base_layer_utils  # pylint: disable=g-import-not-at-top
      # Enable v2 behavior, as `mixed_bfloat16` is only supported in TF 2.0.
      base_layer_utils.enable_v2_dtype_behavior()
      tf2.keras.mixed_precision.set_global_policy('mixed_bfloat16')
    detection_model = detection_model_fn(
        is_training=is_training, add_summaries=(not use_tpu))
    scaffold_fn = None

    # Split the batched groundtruth; whether it is unpadded depends on the
    # mode (config flag for TRAIN, static box shape and `use_tpu` for EVAL).
    if mode == tf.estimator.ModeKeys.TRAIN:
      labels = unstack_batch(
          labels,
          unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
    elif mode == tf.estimator.ModeKeys.EVAL:
      # For evaling on train data, it is necessary to check whether groundtruth
      # must be unpadded.
      boxes_shape = (
          labels[fields.InputDataFields.groundtruth_boxes].get_shape()
          .as_list())
      unpad_groundtruth_tensors = boxes_shape[1] is not None and not use_tpu
      labels = unstack_batch(
          labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      provide_groundtruth(detection_model, labels)

    preprocessed_images = features[fields.InputDataFields.image]

    side_inputs = detection_model.get_side_inputs(features)

    # With bfloat16 on TPU the forward pass runs inside the bfloat16 scope and
    # its outputs are converted back to float32 right away.
    if use_tpu and train_config.use_bfloat16:
      with tf.tpu.bfloat16_scope():
        prediction_dict = detection_model.predict(
            preprocessed_images,
            features[fields.InputDataFields.true_image_shape], **side_inputs)
        prediction_dict = ops.bfloat16_to_float32_nested(prediction_dict)
    else:
      prediction_dict = detection_model.predict(
          preprocessed_images,
          features[fields.InputDataFields.true_image_shape], **side_inputs)

    # Takes a single tuple argument so the same callable can be handed to
    # `tf.tpu.outside_compilation` below or called directly.
    def postprocess_wrapper(args):
      return detection_model.postprocess(args[0], args[1])

    # Detections are only produced for EVAL and PREDICT; in TRAIN they stay
    # None.
    if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
      if use_tpu and postprocess_on_cpu:
        detections = tf.tpu.outside_compilation(
            postprocess_wrapper,
            (prediction_dict,
             features[fields.InputDataFields.true_image_shape]))
      else:
        detections = postprocess_wrapper((
            prediction_dict,
            features[fields.InputDataFields.true_image_shape]))

    # Warm-start from the fine-tune checkpoint, only in TRAIN and only when
    # `hparams.load_pretrained` is set.
    if mode == tf.estimator.ModeKeys.TRAIN:
      load_pretrained = hparams.load_pretrained if hparams else False
      if train_config.fine_tune_checkpoint and load_pretrained:
        if not train_config.fine_tune_checkpoint_type:
          # train_config.from_detection_checkpoint field is deprecated. For
          # backward compatibility, set train_config.fine_tune_checkpoint_type
          # based on train_config.from_detection_checkpoint.
          if train_config.from_detection_checkpoint:
            train_config.fine_tune_checkpoint_type = 'detection'
          else:
            train_config.fine_tune_checkpoint_type = 'classification'
        asg_map = detection_model.restore_map(
            fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
            load_all_detection_checkpoint_vars=(
                train_config.load_all_detection_checkpoint_vars))
        available_var_map = (
            variables_helper.get_variables_available_in_checkpoint(
                asg_map,
                train_config.fine_tune_checkpoint,
                include_global_step=False))
        # On TPU the initialization is deferred into a scaffold function that
        # is passed to the TPUEstimatorSpec; otherwise it is applied here.
        if use_tpu:

          def tpu_scaffold():
            tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                          available_var_map)
            return tf.train.Scaffold()

          scaffold_fn = tpu_scaffold
        else:
          tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                        available_var_map)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      if (mode == tf.estimator.ModeKeys.EVAL and
          eval_config.use_dummy_loss_in_eval):
        # Constant placeholder loss; the model's loss is not computed at all.
        total_loss = tf.constant(1.0)
        losses_dict = {'Loss/total_loss': total_loss}
      else:
        losses_dict = detection_model.loss(
            prediction_dict, features[fields.InputDataFields.true_image_shape])
        losses = [loss_tensor for loss_tensor in losses_dict.values()]
        if train_config.add_regularization_loss:
          regularization_losses = detection_model.regularization_losses()
          if use_tpu and train_config.use_bfloat16:
            regularization_losses = ops.bfloat16_to_float32_nested(
                regularization_losses)
          if regularization_losses:
            regularization_loss = tf.add_n(
                regularization_losses, name='regularization_loss')
            losses.append(regularization_loss)
            losses_dict['Loss/regularization_loss'] = regularization_loss
        total_loss = tf.add_n(losses, name='total_loss')
        losses_dict['Loss/total_loss'] = total_loss

      if 'graph_rewriter_config' in configs:
        graph_rewriter_fn = graph_rewriter_builder.build(
            configs['graph_rewriter_config'], is_training=is_training)
        graph_rewriter_fn()

      # TODO(rathodv): Stop creating optimizer summary vars in EVAL mode once we
      # can write learning rate summaries on TPU without host calls.
      global_step = tf.train.get_or_create_global_step()
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)

    if mode == tf.estimator.ModeKeys.TRAIN:
      if use_tpu:
        training_optimizer = tf.tpu.CrossShardOptimizer(training_optimizer)

      # Optionally freeze some layers by setting their gradients to be zero.
      trainable_variables = None
      include_variables = (
          train_config.update_trainable_variables
          if train_config.update_trainable_variables else None)
      exclude_variables = (
          train_config.freeze_variables
          if train_config.freeze_variables else None)
      trainable_variables = slim.filter_variables(
          tf.trainable_variables(),
          include_patterns=include_variables,
          exclude_patterns=exclude_variables)

      # Gradient clipping is only requested when the configured norm is
      # positive.
      clip_gradients_value = None
      if train_config.gradient_clipping_by_norm > 0:
        clip_gradients_value = train_config.gradient_clipping_by_norm

      if not use_tpu:
        for var in optimizer_summary_vars:
          tf.summary.scalar(var.op.name, var)
      # An empty list is passed on TPU and None otherwise.
      # NOTE(review): None presumably selects optimize_loss's default
      # summaries -- confirm against slim.optimizers.optimize_loss.
      summaries = [] if use_tpu else None
      if train_config.summarize_gradients:
        summaries = ['gradients', 'gradient_norm', 'global_gradient_norm']
      train_op = slim.optimizers.optimize_loss(
          loss=total_loss,
          global_step=global_step,
          learning_rate=None,
          clip_gradients=clip_gradients_value,
          optimizer=training_optimizer,
          update_ops=detection_model.updates(),
          variables=trainable_variables,
          summaries=summaries,
          name='')  # Preventing scope prefix on all variables.

    if mode == tf.estimator.ModeKeys.PREDICT:
      exported_output = exporter_lib.add_output_tensor_nodes(detections)
      export_outputs = {
          tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
              tf.estimator.export.PredictOutput(exported_output)
      }

    eval_metric_ops = None
    scaffold = None
    if mode == tf.estimator.ModeKeys.EVAL:
      # The model is treated as class-agnostic when postprocessing produced no
      # per-detection class field.
      class_agnostic = (
          fields.DetectionResultFields.detection_classes not in detections)
      groundtruth = _prepare_groundtruth_for_eval(
          detection_model, class_agnostic,
          eval_input_config.max_number_of_boxes)
      use_original_images = fields.InputDataFields.original_image in features
      if use_original_images:
        eval_images = features[fields.InputDataFields.original_image]
        # Keep only the first three columns of each true_image_shape row.
        true_image_shapes = tf.slice(
            features[fields.InputDataFields.true_image_shape], [0, 0], [-1, 3])
        original_image_spatial_shapes = features[fields.InputDataFields
                                                 .original_image_spatial_shape]
      else:
        eval_images = features[fields.InputDataFields.image]
        true_image_shapes = None
        original_image_spatial_shapes = None

      eval_dict = eval_util.result_dict_for_batched_example(
          eval_images,
          features[inputs.HASH_KEY],
          detections,
          groundtruth,
          class_agnostic=class_agnostic,
          scale_to_absolute=True,
          original_image_spatial_shapes=original_image_spatial_shapes,
          true_image_shapes=true_image_shapes)

      if fields.InputDataFields.image_additional_channels in features:
        eval_dict[fields.InputDataFields.image_additional_channels] = features[
            fields.InputDataFields.image_additional_channels]

      if class_agnostic:
        category_index = label_map_util.create_class_agnostic_category_index()
      else:
        category_index = label_map_util.create_category_index_from_labelmap(
            eval_input_config.label_map_path)
      # Visualization metric ops are only built off-TPU and when the original
      # images are available in `features`.
      vis_metric_ops = None
      if not use_tpu and use_original_images:
        keypoint_edges = [
            (kp.start, kp.end) for kp in eval_config.keypoint_edge]

        eval_metric_op_vis = vis_utils.VisualizeSingleFrameDetections(
            category_index,
            max_examples_to_draw=eval_config.num_visualizations,
            max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
            min_score_thresh=eval_config.min_score_threshold,
            use_normalized_coordinates=False,
            keypoint_edges=keypoint_edges or None)
        vis_metric_ops = eval_metric_op_vis.get_estimator_eval_metric_ops(
            eval_dict)

      # Eval metrics on a single example.
      eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
          eval_config, list(category_index.values()), eval_dict)
      for loss_key, loss_tensor in iter(losses_dict.items()):
        eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
      # Report each optimizer summary variable as a metric whose update op
      # does nothing.
      for var in optimizer_summary_vars:
        eval_metric_ops[var.op.name] = (var, tf.no_op())
      if vis_metric_ops is not None:
        eval_metric_ops.update(vis_metric_ops)
      # Force every metric key to be a plain string.
      eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

      if eval_config.use_moving_averages:
        # The averaging object is only used to obtain `variables_to_restore()`
        # for the saver; no averaging op is applied here.
        variable_averages = tf.train.ExponentialMovingAverage(0.0)
        variables_to_restore = variable_averages.variables_to_restore()
        keep_checkpoint_every_n_hours = (
            train_config.keep_checkpoint_every_n_hours)
        saver = tf.train.Saver(
            variables_to_restore,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
        scaffold = tf.train.Scaffold(saver=saver)

    # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
    if use_tpu and mode != tf.estimator.ModeKeys.EVAL:
      return tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          scaffold_fn=scaffold_fn,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metrics=eval_metric_ops,
          export_outputs=export_outputs)
    else:
      # No scaffold was built above (no moving-average saver), so create a
      # sharded saver and register it in the SAVERS collection.
      if scaffold is None:
        keep_checkpoint_every_n_hours = (
            train_config.keep_checkpoint_every_n_hours)
        saver = tf.train.Saver(
            sharded=True,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        scaffold = tf.train.Scaffold(saver=saver)
      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metric_ops=eval_metric_ops,
          export_outputs=export_outputs,
          scaffold=scaffold)
Ejemplo n.º 23
0
def create_optimizer(loss,
                     init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     use_tpu,
                     trainable_variable_scope=""):
    """Creates an optimizer training op.

    The learning rate decays linearly from `init_lr` to 0.0 over
    `num_train_steps`. If `num_warmup_steps` is truthy, the rate is instead
    `global_step / num_warmup_steps * init_lr` while the global step is below
    `num_warmup_steps`.

    Args:
      loss: tensor that is differentiated with respect to the selected
        trainable variables.
      init_lr: initial (peak) learning rate.
      num_train_steps: number of steps over which the rate decays.
      num_warmup_steps: length of the linear warmup; falsy disables it.
      use_tpu: if true, wraps the optimizer in `tf.tpu.CrossShardOptimizer`.
      trainable_variable_scope: scope passed to `tf.trainable_variables` to
        select which variables are trained.

    Returns:
      An op grouping the gradient update with a global step increment.
    """
    step = tf.train.get_or_create_global_step()

    base_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)

    # Linear decay: a polynomial decay with power 1.0 and no cycling.
    rate = tf.train.polynomial_decay(
        base_rate,
        step,
        num_train_steps,
        end_learning_rate=0.0,
        power=1.0,
        cycle=False)

    if num_warmup_steps:
        step_i32 = tf.cast(step, tf.int32)
        warmup_i32 = tf.constant(num_warmup_steps, dtype=tf.int32)

        step_f32 = tf.cast(step_i32, tf.float32)
        warmup_f32 = tf.cast(warmup_i32, tf.float32)

        warmup_rate = init_lr * (step_f32 / warmup_f32)

        # 1.0 while warming up, 0.0 afterwards; blends the two schedules so
        # that exactly one of them is active at any step.
        in_warmup = tf.cast(step_i32 < warmup_i32, tf.float32)
        decayed_share = (1.0 - in_warmup) * rate
        warmup_share = in_warmup * warmup_rate
        rate = decayed_share + warmup_share

    # It is recommended that you use this optimizer for fine tuning, since this
    # is how the model was trained (note that the Adam m/v variables are NOT
    # loaded from init_checkpoint.)
    adam = AdamWeightDecayOptimizer(
        learning_rate=rate,
        weight_decay_rate=0.01,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
    update_rule = tf.tpu.CrossShardOptimizer(adam) if use_tpu else adam

    variables = tf.trainable_variables(trainable_variable_scope)
    raw_grads = tf.gradients(loss, variables)

    # This is how the model was pre-trained.
    clipped_grads, _ = tf.clip_by_global_norm(raw_grads, clip_norm=1.0)

    apply_op = update_rule.apply_gradients(
        zip(clipped_grads, variables), global_step=step)

    # The global step is advanced explicitly here.
    # NOTE(review): presumably `AdamWeightDecayOptimizer.apply_gradients` does
    # not increment the step itself; if it does, the step advances twice per
    # update -- confirm against its implementation.
    bump_step = step.assign(step + 1)
    return tf.group(apply_op, [bump_step])
Ejemplo n.º 24
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Builds a BERT encoder over the flattened input sequences, attaches the
        bilinear head from `model_builder.create_model_bilin`, optionally
        initializes variables from `init_checkpoint`, and returns the
        `TPUEstimatorSpec` matching `mode` (train op for TRAIN, accuracy and
        mean loss metrics for EVAL, probabilities otherwise).

        Args:
          features: dict of input tensors; reads "input_ids", "input_mask",
            "segment_ids" and "label".
          labels: unused; the labels are taken from `features["label"]`.
          mode: one of the `tf.estimator.ModeKeys`.
          params: unused.

        Returns:
          A `TPUEstimatorSpec` for the requested mode.
        """
        tf.logging.info("*** Features ***")
        for feature_name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (feature_name, features[feature_name].shape))

        def as_sequences(key):
            # Flatten the feature into rows of FLAGS.max_seq_length tokens.
            return tf.reshape(features[key], [-1, FLAGS.max_seq_length])

        input_ids = as_sequences("input_ids")
        input_mask = as_sequences("input_mask")
        segment_ids = as_sequences("segment_ids")
        label_ids = features["label"]

        model = modeling.BertModel(
            config=bert_config,
            is_training=(mode == tf.estimator.ModeKeys.TRAIN),
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)

        cpc_loss, _, logits, probabilities = (
            model_builder.create_model_bilin(model, label_ids, num_choices))
        total_loss = cpc_loss

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            assignment_map, initialized_variable_names = (
                modeling.get_assignment_map_from_checkpoint(
                    tvars, init_checkpoint))

            if not use_tpu:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
            else:
                # On TPU the checkpoint initialization is deferred into the
                # scaffold function handed to the spec.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            from_ckpt = var.name in initialized_variable_names
            init_string = ", *INIT_FROM_CKPT*" if from_ckpt else ""
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps,
                use_tpu)
            return contrib_tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)

        if mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(cpc_loss, label_ids, logits):
                """Returns the accuracy and mean-loss metrics."""
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                return {
                    "eval_accuracy": tf.metrics.accuracy(
                        labels=label_ids, predictions=predictions),
                    "eval_cpc_loss": tf.metrics.mean(values=cpc_loss),
                }

            return contrib_tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=(metric_fn, [cpc_loss, label_ids, logits]),
                scaffold_fn=scaffold_fn)

        # Any other mode (prediction): expose the class probabilities only.
        return contrib_tpu.TPUEstimatorSpec(
            mode=mode,
            predictions={"probabilities": probabilities},
            scaffold_fn=scaffold_fn)
Ejemplo n.º 25
0
# Restores the pretrained shadow-removal network from its checkpoint and
# writes it to `shadow_removal.pb` as a frozen GraphDef (variables converted
# to constants).
import tensorflow.compat.v1 as tf
import numpy as np
import matplotlib.pyplot as plt
from networks import build_aggasatt_joint

vgg19_path = 'imagenet-vgg-verydeep-19.mat'
pretrain_model_path = 'srdplus-pretrained/'
sample_path = 'ghost-free-shadow-removal/Samples'

tf.disable_eager_execution()

# Build the inference graph for a single 256x256 RGB image.
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    # Named `input_image` so the builtin `input` is not shadowed.
    input_image = tf.placeholder(tf.float32, shape=[1, 256, 256, 3])
    shadow_free_image = build_aggasatt_joint(input_image, 64, vgg19_path)

# `get_checkpoint_state` returns None when the directory holds no valid
# checkpoint state; fail with a clear message instead of an AttributeError.
idtd_ckpt = tf.train.get_checkpoint_state(pretrain_model_path)
if idtd_ckpt is None or not idtd_ckpt.model_checkpoint_path:
    raise FileNotFoundError(
        'no checkpoint found in ' + pretrain_model_path)

# The context manager closes the session once the graph has been frozen.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver_restore = tf.train.Saver(tf.trainable_variables())
    saver_restore.restore(sess, idtd_ckpt.model_checkpoint_path)
    print('loaded ' + idtd_ckpt.model_checkpoint_path)

    # Keep only what is needed to compute the two output nodes.
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['g_conv_img/BiasAdd', 'g_conv_mask/BiasAdd'])

# Save the frozen graph
with open('shadow_removal.pb', 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())
Ejemplo n.º 26
0
def _model_fn(features, labels, mode, params, model, variable_filter_fn=None):
  """Model definition entry.

  Builds the detection network, the detection + L2 regularization loss, the
  train op (TRAIN mode only), the evaluation metrics (EVAL mode only) and an
  optional scaffold for checkpoint restoring, then packs them into an
  estimator spec.

  Note that `params` is mutated: `params['is_training_bn']` is always set, and
  in EVAL mode `params['nms_configs']['max_nms_inputs']` is overwritten.

  Args:
    features: the input image tensor with shape [batch_size, height, width, 3].
      The height and width are fixed and equal.
    labels: the input labels in a dictionary. The labels include class targets
      and box targets which are dense label maps. The labels are generated from
      get_input_fn function in data/dataloader.py. In EVAL mode the keys
      'source_ids', 'groundtruth_data' and 'image_scales' are also read.
    mode: the mode of TPUEstimator including TRAIN and EVAL.
    params: the dictionary defines hyperparameters of model. The default
      settings are in default_hparams function in this file.
    model: the model outputs class logits and box regression outputs. Only
      used when `params['use_keras_model']` is falsy.
    variable_filter_fn: the filter function that takes trainable_variables and
      returns the variable list after applying the filter rule.

  Returns:
    A `tf.estimator.tpu.TPUEstimatorSpec` when `params['strategy']` is 'tpu',
    otherwise a `tf.estimator.EstimatorSpec`.

  Raises:
    RuntimeError: if both ckpt and backbone_ckpt are set (checked in TRAIN
      mode only).
    ValueError: if `params['optimizer']` is neither 'sgd' nor 'adam' (checked
      in TRAIN mode only).
  """
  is_tpu = params['strategy'] == 'tpu'
  if params['img_summary_steps']:
    utils.image('input_image', features, is_tpu)
  training_hooks = []
  # Batch norm runs in training mode only while training.
  params['is_training_bn'] = (mode == tf.estimator.ModeKeys.TRAIN)

  if params['use_keras_model']:

    def model_fn(inputs):
      """Runs the Keras EfficientDetNet and keys its outputs by level."""
      model = efficientdet_keras.EfficientDetNet(
          config=hparams_config.Config(params))
      cls_out_list, box_out_list = model(inputs, params['is_training_bn'])
      cls_outputs, box_outputs = {}, {}
      # The Keras model returns lists; index 0 corresponds to min_level.
      for i in range(params['min_level'], params['max_level'] + 1):
        cls_outputs[i] = cls_out_list[i - params['min_level']]
        box_outputs[i] = box_out_list[i - params['min_level']]
      return cls_outputs, box_outputs
  else:
    # Use the caller-supplied model builder, bound to this config.
    model_fn = functools.partial(model, config=hparams_config.Config(params))

  precision = utils.get_precision(params['strategy'], params['mixed_precision'])
  cls_outputs, box_outputs = utils.build_model_with_precision(
      precision, model_fn, features, params['is_training_bn'])

  # Cast every level's outputs to float32 before computing the losses.
  levels = cls_outputs.keys()
  for level in levels:
    cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
    box_outputs[level] = tf.cast(box_outputs[level], tf.float32)

  # Set up training loss and learning rate.
  update_learning_rate_schedule_parameters(params)
  global_step = tf.train.get_or_create_global_step()
  learning_rate = learning_rate_schedule(params, global_step)

  # cls_loss and box_loss are for logging. only total_loss is optimized.
  det_loss, cls_loss, box_loss = detection_loss(
      cls_outputs, box_outputs, labels, params)
  reg_l2loss = reg_l2_loss(params['weight_decay'])
  total_loss = det_loss + reg_l2loss

  if mode == tf.estimator.ModeKeys.TRAIN:
    # Record training scalars through the project's summary helper.
    utils.scalar('lrn_rate', learning_rate, is_tpu)
    utils.scalar('trainloss/cls_loss', cls_loss, is_tpu)
    utils.scalar('trainloss/box_loss', box_loss, is_tpu)
    utils.scalar('trainloss/det_loss', det_loss, is_tpu)
    utils.scalar('trainloss/reg_l2_loss', reg_l2loss, is_tpu)
    utils.scalar('trainloss/loss', total_loss, is_tpu)
    train_epochs = tf.cast(global_step, tf.float32) / params['steps_per_epoch']
    utils.scalar('train_epochs', train_epochs, is_tpu)

  # `ema` and `ema_vars` exist only when moving_average_decay is truthy; every
  # later use is guarded by the same condition.
  moving_average_decay = params['moving_average_decay']
  if moving_average_decay:
    ema = tf.train.ExponentialMovingAverage(
        decay=moving_average_decay, num_updates=global_step)
    ema_vars = utils.get_ema_vars()

  if mode == tf.estimator.ModeKeys.TRAIN:
    if params['optimizer'].lower() == 'sgd':
      optimizer = tf.train.MomentumOptimizer(
          learning_rate, momentum=params['momentum'])
    elif params['optimizer'].lower() == 'adam':
      optimizer = tf.train.AdamOptimizer(learning_rate)
    else:
      raise ValueError('optimizers should be adam or sgd')

    # On TPU wrap the optimizer for cross-shard aggregation; otherwise enable
    # the mixed-precision graph rewrite if requested.
    if is_tpu:
      optimizer = tf.tpu.CrossShardOptimizer(optimizer)
    elif params['mixed_precision']:
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)

    # Batch norm requires update_ops to be added as a train_op dependency.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    var_list = tf.trainable_variables()
    if variable_filter_fn:
      var_list = variable_filter_fn(var_list)

    if params.get('clip_gradients_norm', None):
      logging.info('clip gradients norm by %f', params['clip_gradients_norm'])
      grads_and_vars = optimizer.compute_gradients(total_loss, var_list)
      with tf.name_scope('clip'):
        grads = [gv[0] for gv in grads_and_vars]
        tvars = [gv[1] for gv in grads_and_vars]
        # First clip each variable's norm, then clip global norm.
        clip_norm = abs(params['clip_gradients_norm'])
        # Variables without a gradient (None) are passed through untouched.
        clipped_grads = [
            tf.clip_by_norm(g, clip_norm) if g is not None else None
            for g in grads
        ]
        clipped_grads, _ = tf.clip_by_global_norm(clipped_grads, clip_norm)
        utils.scalar('gradient_norm', tf.linalg.global_norm(clipped_grads),
                     is_tpu)
        grads_and_vars = list(zip(clipped_grads, tvars))

      with tf.control_dependencies(update_ops):
        train_op = optimizer.apply_gradients(grads_and_vars, global_step)
    else:
      with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(
            total_loss, global_step, var_list=var_list)

    if moving_average_decay:
      # The EMA update becomes the train op and runs after the optimizer step.
      with tf.control_dependencies([train_op]):
        train_op = ema.apply(ema_vars)

  else:
    train_op = None

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:

    def metric_fn(**kwargs):
      """Returns a dictionary that has the evaluation metrics."""
      if params['nms_configs'].get('pyfunc', True):
        # Default path: run per-class NMS in NumPy, one image at a time.
        detections_bs = []
        for index in range(kwargs['boxes'].shape[0]):
          nms_configs = params['nms_configs']
          detections = tf.numpy_function(
              functools.partial(nms_np.per_class_nms, nms_configs=nms_configs),
              [
                  kwargs['boxes'][index],
                  kwargs['scores'][index],
                  kwargs['classes'][index],
                  tf.slice(kwargs['image_ids'], [index], [1]),
                  tf.slice(kwargs['image_scales'], [index], [1]),
                  params['num_classes'],
                  nms_configs['max_output_size'],
              ], tf.float32)
          detections_bs.append(detections)
        detections_bs = postprocess.transform_detections(
            tf.stack(detections_bs))
      else:
        # These two branches should be equivalent, but currently they are not.
        # TODO(tanmingxing): enable the non_pyfun path after bug fix.
        nms_boxes, nms_scores, nms_classes, _ = postprocess.per_class_nms(
            params, kwargs['boxes'], kwargs['scores'], kwargs['classes'],
            kwargs['image_scales'])
        img_ids = tf.cast(
            tf.expand_dims(kwargs['image_ids'], -1), nms_scores.dtype)
        # Seven columns per detection. NOTE(review): presumably
        # [image_id, x, y, width, height, score, class], assuming nms_boxes is
        # [ymin, xmin, ymax, xmax] -- confirm against postprocess.per_class_nms.
        detections_bs = [
            img_ids * tf.ones_like(nms_scores),
            nms_boxes[:, :, 1],
            nms_boxes[:, :, 0],
            nms_boxes[:, :, 3] - nms_boxes[:, :, 1],
            nms_boxes[:, :, 2] - nms_boxes[:, :, 0],
            nms_scores,
            nms_classes,
        ]
        detections_bs = tf.stack(detections_bs, axis=-1, name='detnections')

      if params.get('testdev_dir', None):
        # Test-dev evaluation: a placeholder zero tensor is passed in place of
        # the groundtruth.
        logging.info('Eval testdev_dir %s', params['testdev_dir'])
        eval_metric = coco_metric.EvaluationMetric(
            testdev_dir=params['testdev_dir'])
        coco_metrics = eval_metric.estimator_metric_fn(detections_bs,
                                                       tf.zeros([1]))
      else:
        logging.info('Eval val with groudtruths %s.', params['val_json_file'])
        eval_metric = coco_metric.EvaluationMetric(
            filename=params['val_json_file'], label_map=params['label_map'])
        coco_metrics = eval_metric.estimator_metric_fn(
            detections_bs, kwargs['groundtruth_data'])

      # Add metrics to output.
      cls_loss = tf.metrics.mean(kwargs['cls_loss_repeat'])
      box_loss = tf.metrics.mean(kwargs['box_loss_repeat'])
      output_metrics = {
          'cls_loss': cls_loss,
          'box_loss': box_loss,
      }
      output_metrics.update(coco_metrics)
      return output_metrics

    # Repeat each loss batch_size times and reshape to [batch_size, 1].
    # NOTE(review): presumably because tensors handed to eval_metrics need a
    # leading batch dimension -- confirm.
    cls_loss_repeat = tf.reshape(
        tf.tile(tf.expand_dims(cls_loss, 0), [
            params['batch_size'],
        ]), [params['batch_size'], 1])
    box_loss_repeat = tf.reshape(
        tf.tile(tf.expand_dims(box_loss, 0), [
            params['batch_size'],
        ]), [params['batch_size'], 1])

    cls_outputs = postprocess.to_list(cls_outputs)
    box_outputs = postprocess.to_list(box_outputs)
    # Mutates the caller's params: overrides the pre-NMS input limit.
    params['nms_configs']['max_nms_inputs'] = anchors.MAX_DETECTION_POINTS
    boxes, scores, classes = postprocess.pre_nms(params, cls_outputs,
                                                 box_outputs)
    # Keys must match what metric_fn reads from **kwargs.
    metric_fn_inputs = {
        'cls_loss_repeat': cls_loss_repeat,
        'box_loss_repeat': box_loss_repeat,
        'image_ids': labels['source_ids'],
        'groundtruth_data': labels['groundtruth_data'],
        'image_scales': labels['image_scales'],
        'boxes': boxes,
        'scores': scores,
        'classes': classes,
    }
    eval_metrics = (metric_fn, metric_fn_inputs)

  # A full-model checkpoint takes precedence over a backbone checkpoint here;
  # having both set is rejected below (in TRAIN mode).
  checkpoint = params.get('ckpt') or params.get('backbone_ckpt')

  if checkpoint and mode == tf.estimator.ModeKeys.TRAIN:
    # Initialize the model from an EfficientDet or backbone checkpoint.
    if params.get('ckpt') and params.get('backbone_ckpt'):
      raise RuntimeError(
          '--backbone_ckpt and --checkpoint are mutually exclusive')

    if params.get('backbone_ckpt'):
      var_scope = params['backbone_name'] + '/'
      if params['ckpt_var_scope'] is None:
        # Use backbone name as default checkpoint scope.
        ckpt_scope = params['backbone_name'] + '/'
      else:
        ckpt_scope = params['ckpt_var_scope'] + '/'
    else:
      # Load every var in the given checkpoint
      var_scope = ckpt_scope = '/'

    def scaffold_fn():
      """Loads pretrained model through scaffold function."""
      logging.info('restore variables from %s', checkpoint)

      var_map = utils.get_ckpt_var_map(
          ckpt_path=checkpoint,
          ckpt_scope=ckpt_scope,
          var_scope=var_scope,
          skip_mismatch=params['skip_mismatch'])

      tf.train.init_from_checkpoint(checkpoint, var_map)
      return tf.train.Scaffold()
  elif mode == tf.estimator.ModeKeys.EVAL and moving_average_decay:

    def scaffold_fn():
      """Load moving average variables for eval."""
      logging.info('Load EMA vars with ema_decay=%f', moving_average_decay)
      restore_vars_dict = ema.variables_to_restore(ema_vars)
      saver = tf.train.Saver(restore_vars_dict)
      return tf.train.Scaffold(saver=saver)
  else:
    scaffold_fn = None

  if is_tpu:
    # TPU path: metrics and scaffold are passed as callables for the
    # TPUEstimator to invoke.
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        eval_metrics=eval_metrics,
        host_call=utils.get_tpu_host_call(global_step, params),
        scaffold_fn=scaffold_fn,
        training_hooks=training_hooks)
  else:
    # Profile every 1K steps.
    if params.get('profile', False):
      profile_hook = tf.estimator.ProfilerHook(
          save_steps=1000, output_dir=params['model_dir'], show_memory=True)
      training_hooks.append(profile_hook)

      # Report memory allocation if OOM; it will slow down the running.
      # This hook is only installed together with the profiler hook.
      class OomReportingHook(tf.estimator.SessionRunHook):

        def before_run(self, run_context):
          return tf.estimator.SessionRunArgs(
              fetches=[],
              options=tf.RunOptions(report_tensor_allocations_upon_oom=True))

      training_hooks.append(OomReportingHook())

    # Log the step and losses; the interval defaults to 100 iterations when
    # 'iterations_per_loop' is not set.
    logging_hook = tf.estimator.LoggingTensorHook(
        {
            'step': global_step,
            'det_loss': det_loss,
            'cls_loss': cls_loss,
            'box_loss': box_loss,
        },
        every_n_iter=params.get('iterations_per_loop', 100),
    )
    training_hooks.append(logging_hook)

    # Non-TPU path: EstimatorSpec wants materialized metric ops and a Scaffold
    # object, so call metric_fn and scaffold_fn right here.
    eval_metric_ops = (
        eval_metrics[0](**eval_metrics[1]) if eval_metrics else None)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        scaffold=scaffold_fn() if scaffold_fn else None,
        training_hooks=training_hooks)
Ejemplo n.º 27
0
def resnet_model_fn(features, labels, mode, params):
    """The model_fn for ResNet to be used with TPUEstimator.

    Args:
      features: `Tensor` of batched images. If transpose_input is enabled, it
        is transposed to device layout and reshaped to 1D tensor. May also be
        a dict, in which case the images are read from `features['feature']`.
      labels: `Tensor` of labels for the data samples
      mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
      params: `dict` of parameters passed to the model from the TPUEstimator,
        `params['batch_size']` is always provided and should be used as the
        effective batch size.

    Returns:
      A `TPUEstimatorSpec` for the model in TRAIN and EVAL modes, and a plain
      `tf.estimator.EstimatorSpec` in PREDICT mode.

    Raises:
      ValueError: if `params['dropblock_groups']` contains a group outside
        1..4, if `params['precision']` is neither 'bfloat16' nor 'float32',
        if `params['enable_lars']` is set, or if `FLAGS.optimizer` is not a
        supported optimizer.
    """
    if isinstance(features, dict):
        features = features['feature']

    # In most cases, the default data format NCHW instead of NHWC should be
    # used for a significant performance boost on GPU/TPU. NHWC should be used
    # only if the network needs to be run on CPU since the pooling operations
    # are only supported on NHWC.
    if params['data_format'] == 'channels_first':
        assert not params['transpose_input']  # channels_first only for GPU
        features = tf.transpose(features, [0, 3, 1, 2])
        # After the transpose the channel axis precedes height and width, so
        # the per-channel statistics must broadcast as [C, 1, 1].
        stats_shape = [3, 1, 1]
    else:
        stats_shape = [1, 1, 3]

    if params['transpose_input'] and mode != tf.estimator.ModeKeys.PREDICT:
        # Recover the square image size from the flattened HWCN tensor length.
        image_size = tf.sqrt(tf.shape(features)[0] / (3 * tf.shape(labels)[0]))
        features = tf.reshape(features, [image_size, image_size, 3, -1])
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    # Normalize the image to zero mean and unit variance.
    features -= tf.constant(MEAN_RGB, shape=stats_shape, dtype=features.dtype)
    features /= tf.constant(
        STDDEV_RGB, shape=stats_shape, dtype=features.dtype)

    # DropBlock keep_prob for the 4 block groups of ResNet architecture.
    # None means applying no DropBlock at the corresponding block group.
    dropblock_keep_probs = [None] * 4
    if params['dropblock_groups']:
        # Scheduled keep_prob for DropBlock: decreases linearly from 1 towards
        # params['dropblock_keep_prob'] as training progresses.
        train_steps = tf.cast(params['train_steps'], tf.float32)
        current_step = tf.cast(tf.train.get_global_step(), tf.float32)
        current_ratio = current_step / train_steps
        dropblock_keep_prob = (1 - current_ratio *
                               (1 - params['dropblock_keep_prob']))

        # Computes DropBlock keep_prob for different block groups of ResNet.
        dropblock_groups = [
            int(x) for x in params['dropblock_groups'].split(',')
        ]
        for block_group in dropblock_groups:
            if block_group < 1 or block_group > 4:
                raise ValueError(
                    'dropblock_groups should be a comma separated list of integers '
                    'between 1 and 4 (dropblock_groups: {}).'.format(
                        params['dropblock_groups']))
            # Earlier groups drop less: the drop rate shrinks by 4x per group.
            dropblock_keep_probs[block_group - 1] = 1 - (
                (1 - dropblock_keep_prob) / 4.0**(4 - block_group))

    # This nested function allows us to avoid duplicating the logic which
    # builds the network, for different values of --precision.
    def build_network():
        network = resnet_model.resnet_v1(
            resnet_depth=params['resnet_depth'],
            num_classes=params['num_label_classes'],
            dropblock_size=params['dropblock_size'],
            dropblock_keep_probs=dropblock_keep_probs,
            data_format=params['data_format'])
        return network(inputs=features,
                       is_training=(mode == tf.estimator.ModeKeys.TRAIN))

    if params['precision'] == 'bfloat16':
        with contrib_tpu.bfloat16_scope():
            logits = build_network()
        logits = tf.cast(logits, tf.float32)
    elif params['precision'] == 'float32':
        logits = build_network()
    else:
        # Fail here rather than with a NameError on `logits` further down.
        raise ValueError(
            'Unsupported precision: {} (expected bfloat16 or float32).'.format(
                params['precision']))

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # If necessary, in the model_fn, use params['batch_size'] instead the batch
    # size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    one_hot_labels = tf.one_hot(labels, params['num_label_classes'])
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=one_hot_labels,
        label_smoothing=params['label_smoothing'])

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = cross_entropy + params['weight_decay'] * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    host_call = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Compute the current epoch and associated learning rate from global_step.
        global_step = tf.train.get_global_step()
        steps_per_epoch = params['num_train_images'] / params[
            'train_batch_size']
        current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
        # LARS is a large batch optimizer. LARS enables higher accuracy at batch 16K
        # and larger batch sizes.
        if params['enable_lars']:
            # LARS is deliberately rejected here: the optimizer is built and
            # then the function aborts unconditionally.
            learning_rate = 0.0
            optimizer = lars_util.init_lars_optimizer(current_epoch, params)
            raise ValueError(
                'LARS unexpected in the context of IGT experiments.')
        else:
            learning_rate = linear_learning_rate_schedule(params, global_step)

            if FLAGS.optimizer == 'momentum':
                tf.logging.info('Using MomentumOptimizer ({}).'.format(
                    params['momentum']))
                optimizer = tf.train.MomentumOptimizer(
                    learning_rate=learning_rate,
                    momentum=params['momentum'],
                    use_nesterov=False)

            elif FLAGS.optimizer == 'adam':
                tf.logging.info('Using AdamOptimizer')
                optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

            elif FLAGS.optimizer == 'eigt':
                tf.logging.info('Using ExpIgtOptimizer {} tail: {}'.format(
                    FLAGS.igt_optimizer, FLAGS.tail_fraction))
                optimizer = exp_igt_optimizer.ExpIgtOptimizer(
                    learning_rate,
                    tail_fraction=FLAGS.tail_fraction,
                    optimizer=FLAGS.igt_optimizer)

            else:
                raise ValueError('{} is not a supported optimizer'.format(
                    FLAGS.optimizer))

        if params['use_tpu']:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)

        if not params['skip_host_call']:

            def host_call_fn(gs, loss, lr, ce):
                """Training host call.

                Creates scalar summaries for training metrics.

                This function is executed on the CPU and should not directly
                reference any Tensors in the rest of the `model_fn`. To pass
                Tensors from the model to the `metric_fn`, provide as part of
                the `host_call`. See
                https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
                for more information.

                Arguments should match the list of `Tensor` objects passed as
                the second element in the tuple passed to `host_call`.

                Args:
                  gs: `Tensor` with shape `[batch]` for the global_step
                  loss: `Tensor` with shape `[batch]` for the training loss.
                  lr: `Tensor` with shape `[batch]` for the learning_rate.
                  ce: `Tensor` with shape `[batch]` for the current_epoch.

                Returns:
                  List of summary ops to run on the CPU host.
                """
                gs = gs[0]
                # Host call fns are executed params['iterations_per_loop'] times after
                # one TPU loop is finished, setting max_queue value to the same as
                # number of iterations will make the summary writer only flush the data
                # to storage once per loop.
                with summary.create_file_writer(
                        get_model_dir(params),
                        max_queue=params['iterations_per_loop']).as_default():
                    with summary.always_record_summaries():
                        summary.scalar('loss', loss[0], step=gs)
                        summary.scalar('learning_rate', lr[0], step=gs)
                        summary.scalar('current_epoch', ce[0], step=gs)

                        return summary.all_summary_ops()

            # To log the loss, current learning rate, and epoch for Tensorboard, the
            # summary op needs to be run on the host CPU via host_call. host_call
            # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            gs_t = tf.reshape(global_step, [1])
            loss_t = tf.reshape(loss, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])

            host_call = (host_call_fn, [gs_t, loss_t, lr_t, ce_t])

    else:
        train_op = None

    eval_metrics = None
    scaffold_fn = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Evaluation metric function.

            Evaluates accuracy.

            This function is executed on the CPU and should not directly
            reference any Tensors in the rest of the `model_fn`. To pass
            Tensors from the model to the `metric_fn`, provide as part of the
            `eval_metrics`. See
            https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
            for more information.

            Arguments should match the list of `Tensor` objects passed as the
            second element in the tuple passed to `eval_metrics`.

            Args:
              labels: `Tensor` with shape `[batch]`.
              logits: `Tensor` with shape `[batch, num_classes]`.

            Returns:
              A dict of the metrics to return from evaluation.
            """
            predictions = tf.argmax(logits, axis=1)
            top_1_accuracy = tf.metrics.accuracy(labels, predictions)
            in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
            top_5_accuracy = tf.metrics.mean(in_top_5)

            return {
                'top_1_accuracy': top_1_accuracy,
                'top_5_accuracy': top_5_accuracy,
            }

        eval_metrics = (metric_fn, [labels, logits])

        if FLAGS.mode == 'eval_igt' and FLAGS.igt_eval_mode == 'true':
            tf.logging.info('Using true param loading saver.')

            def scaffold_fn_true_params():
                """Returns a scaffold that loads the true values into vars."""
                # Trainable variables are restored from the checkpoint entry
                # '<name>/true_param'; all other variables from '<name>'.
                var_mapping = {}
                trainable_vars = set(tf.trainable_variables())
                for var in tf.global_variables():
                    if var in trainable_vars:
                        var_mapping[var.op.name + '/true_param'] = var
                    else:
                        var_mapping[var.op.name] = var

                tf.logging.info('Mapping: {}'.format(var_mapping))
                saver = tf.train.Saver(var_list=var_mapping, sharded=True)
                return tf.train.Scaffold(saver=saver)

            scaffold_fn = scaffold_fn_true_params

    return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                        loss=loss,
                                        train_op=train_op,
                                        host_call=host_call,
                                        eval_metrics=eval_metrics,
                                        scaffold_fn=scaffold_fn)
Ejemplo n.º 28
0
def compute_gradients(total_loss):
    """Separate the function of gradient computation.

    Builds the learning-rate schedule (linear warmup followed by polynomial
    or cosine decay), creates the AdamWeightDecayOptimizer (wrapped in a
    CrossShardOptimizer on TPU), computes the gradients of `total_loss` with
    respect to all trainable variables, optionally clips them by global norm
    and optionally applies layer-wise learning-rate decay by scaling them.

    Args:
      total_loss: scalar loss tensor to differentiate.

    Returns:
      A tuple `(optimizer, grad_and_vars, global_step, monitor_dict)` where
      `grad_and_vars` is a list of `(gradient, variable)` pairs (the gradient
      may be None for variables not connected to the loss) and `monitor_dict`
      holds the 'local_gnorm' and 'learning_rate' tensors.

    Raises:
      ValueError: if `FLAGS.decay_method` is neither "poly" nor "cos", or if
        weight decay is requested for multi-GPU (non-TPU) training.
    """
    monitor_dict = {}
    print(FLAGS.weight_decay, "==weight_decay==")
    print(FLAGS.lr_layer_decay_rate, "==lr_layer_decay_rate==")
    print(FLAGS.use_wd_exclusion, "==use_wd_exclusion==")
    print(FLAGS.adam_correction, "==adam_correction==")

    ##### Configure optimizer
    global_step = tf.train.get_or_create_global_step()

    # Warmup the learning rate linearly
    if FLAGS.warmup_steps > 0:
        progress = (tf.cast(global_step, tf.float32) /
                    tf.cast(FLAGS.warmup_steps, tf.float32))
    else:
        progress = 1.0
    # Interpolates from min_lr_ratio (progress 0) up to 1 (progress 1).
    curr_ratio = progress + (1.0 - progress) * FLAGS.min_lr_ratio
    warmup_lr = curr_ratio * FLAGS.learning_rate

    # Decay the learning rate
    if FLAGS.decay_method == "poly":
        decay_lr = tf.train.polynomial_decay(
            FLAGS.learning_rate,
            global_step=global_step - FLAGS.warmup_steps,
            decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
            end_learning_rate=FLAGS.learning_rate * FLAGS.min_lr_ratio)
    elif FLAGS.decay_method == "cos":
        decay_lr = tf.train.cosine_decay(
            FLAGS.learning_rate,
            global_step=global_step - FLAGS.warmup_steps,
            decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
            alpha=FLAGS.min_lr_ratio)
    else:
        raise ValueError(FLAGS.decay_method)

    # Use the warmup rate until warmup_steps is reached, the decayed one after.
    learning_rate = tf.where(global_step < FLAGS.warmup_steps, warmup_lr,
                             decay_lr)

    if (FLAGS.weight_decay > 0 and not FLAGS.use_tpu
            and FLAGS.num_core_per_host > 1):
        raise ValueError("Do not support `weight_decay > 0` with multi-gpu "
                         "training so far.")

    if FLAGS.use_wd_exclusion:
        exclude_from_weight_decay = ["LayerNorm", "layer_norm", "bias"]
    else:
        exclude_from_weight_decay = []

    print(exclude_from_weight_decay, "==exclude_from_weight_decay==")

    optimizer = AdamWeightDecayOptimizer(
        learning_rate=learning_rate,
        beta_1=FLAGS.adam_beta1,
        beta_2=FLAGS.adam_beta2,
        epsilon=FLAGS.adam_epsilon,
        bias_correction=FLAGS.adam_correction,
        exclude_from_weight_decay=exclude_from_weight_decay,
        weight_decay_rate=FLAGS.weight_decay)

    if FLAGS.use_tpu:
        # With per-core clipping the gradients are clipped locally below, so
        # the cross-shard optimizer is not given a clip value.
        if FLAGS.per_core_clip:
            optimizer = tpu_optimizer.CrossShardOptimizer(
                optimizer, skip_nan_grad=FLAGS.skip_nan_grad)
        else:
            optimizer = tpu_optimizer.CrossShardOptimizer(
                optimizer, skip_nan_grad=FLAGS.skip_nan_grad, clip=FLAGS.clip)

    ##### Compute gradient
    variables = tf.trainable_variables()
    # Entries are None for variables that are not connected to the loss.
    gradients = tf.gradients(total_loss, variables)

    if FLAGS.clip > 0 and FLAGS.per_core_clip:
        tf.logging.info("Clip local gradient with norm %.3f.", FLAGS.clip)
        clipped, local_gnorm = tf.clip_by_global_norm(gradients, FLAGS.clip)
    else:
        tf.logging.info("Do not clip local gradient.")
        clipped = list(gradients)
        local_gnorm = tf.linalg.global_norm(gradients)

    # layer-wise learning rate decay
    if FLAGS.lr_layer_decay_rate != 1.0:

        def _get_layer_id(name):
            """Maps a variable name to its layer depth, or None if unknown."""
            if "model/input" in name:
                return 0
            m = re.search(r"model/(encoder|decoder)/layer_(\d+?)/", name)
            if not m:
                return None
            return int(m.group(2)) + 1

        # Resolve every variable's layer once; the total depth is one more
        # than the deepest layer id seen.
        layer_ids = [_get_layer_id(var.name) for var in variables]
        known_ids = [lid for lid in layer_ids if lid is not None]
        n_layer = max(known_ids) + 1 if known_ids else 0

        for i, (var, layer_id) in enumerate(zip(variables, layer_ids)):
            if layer_id is None:
                tf.logging.info("Grad of %s is not decayed.", var.name)
                continue

            grad = clipped[i]
            if grad is None:
                # No gradient flows to this variable; scaling None would raise.
                tf.logging.info("Grad of %s is None, skip layer decay.",
                                var.name)
                continue

            # The top layer keeps rate 1; each layer below is scaled down by
            # another factor of lr_layer_decay_rate.
            abs_rate = FLAGS.lr_layer_decay_rate**(n_layer - 1 - layer_id)
            tf.logging.info("Apply mult %.4f to the grad of %s", abs_rate,
                            var.name)
            if isinstance(grad, tf.IndexedSlices):
                # Sparse gradients: scale the values, keep indices and shape.
                clipped[i] = tf.IndexedSlices(grad.values * abs_rate,
                                              grad.indices,
                                              grad.dense_shape)
            else:
                clipped[i] = grad * abs_rate

    grad_and_vars = list(zip(clipped, variables))

    monitor_dict["local_gnorm"] = local_gnorm
    monitor_dict["learning_rate"] = learning_rate

    return optimizer, grad_and_vars, global_step, monitor_dict
Ejemplo n.º 29
0
def manually_compute_losses(numpy_inputs, inputs_placeholder, loss, num_workers,
                            params):
  """Compute by hand the losses every tf_cnn_benchmarks worker should report.

  The function acts as a small reference simulation of tf_cnn_benchmarks. The
  caller builds a model that reads images from `inputs_placeholder` (a
  tf.placeholder) and produces `loss`; this function then replays training
  for `num_workers` simulated workers and records what each one would report.

  This function, and every op handed to it, must run under a
  tf.device('cpu:0') context manager.

  With more than one worker only the SGD optimizer is supported.

  Args:
    numpy_inputs: Numpy array holding the input images.
    inputs_placeholder: tf.placeholder tensor that the images are fed into.
    loss: Scalar tensor with the model loss, computed from the images fed
      through inputs_placeholder.
    num_workers: Number of workers to simulate.
    params: Params tuple. It need not describe the distributed cluster (for
      example --num_workers), because num_workers is given separately.

  Returns:
    A list of lists of losses, where return_value[i][j] is the loss of the
    ith worker after the jth step.
  """
  per_worker_batch = params.batch_size * params.num_gpus
  assert numpy_inputs.shape[0] % (num_workers * per_worker_batch) == 0

  # Weight-decay term over all trainable variables, added to the model loss.
  weight_penalties = [tf.nn.l2_loss(var) for var in tf.trainable_variables()]
  l2_loss = tf.add_n(weight_penalties)
  total_loss = loss + params.weight_decay * l2_loss
  if params.loss_type_to_report == 'base_loss':
    reported_loss = loss
  else:
    reported_loss = total_loss

  # Some variable-update modes sum the per-GPU gradients rather than average
  # them, which makes the gradients effectively params.num_gpus times larger.
  # TODO(b/62722498): Make all variable updates consistent.
  sums_gpu_gradients = params.variable_update in ('replicated',
                                                  'distributed_all_reduce')
  gradient_multiplier = params.num_gpus if sums_gpu_gradients else 1

  opt = benchmark_cnn.get_optimizer(params, params.init_learning_rate)
  gradient_scale = tf.constant(gradient_multiplier, dtype=tf.float32)
  grad_vars = opt.compute_gradients(total_loss, grad_loss=gradient_scale)
  grads = [grad for grad, _ in grad_vars]

  # Gradients are applied through placeholders: every worker's gradients are
  # computed first and only afterwards fed back in and applied one by one.
  placeholder_grad_vars = [(tf.placeholder(grad.dtype, grad.shape), var)
                           for grad, var in grad_vars]
  placeholder_grads = [holder for holder, _ in placeholder_grad_vars]
  apply_grads_op = opt.apply_gradients(placeholder_grad_vars)

  batch_iterators = []
  for worker in range(num_workers):
    batch_iterators.append(
        _worker_batches_in_numpy_array(numpy_inputs, per_worker_batch,
                                       shift_ratio=worker / num_workers))

  # A GPU count of 0 keeps the session from grabbing all GPU memory.
  # Unfortunately it still takes up about ~1GB for some reason.
  cpu_only_config = tf.ConfigProto(device_count={'GPU': 0})
  with tf.Session(config=cpu_only_config) as sess:
    sess.run(tf.global_variables_initializer())
    losses = [[] for _ in range(num_workers)]
    for _ in range(params.num_batches):
      # Phase 1: every worker evaluates loss and gradients on the same
      # (not yet updated) variables.
      step_grads = []
      for worker_losses, worker_batches in zip(losses, batch_iterators):
        images = next(worker_batches)
        images = images / 127.5 - 1
        worker_loss, worker_grads = sess.run(
            (reported_loss, grads), {inputs_placeholder: images})
        worker_losses.append(worker_loss)
        step_grads.append(worker_grads)
      # Phase 2: apply the collected gradients, one worker after another.
      # TODO(reedwm): With multiple workers, applying the gradients
      # sequentially per worker is not equivalent to what tf_cnn_benchmarks
      # does when the optimizer is not SGD. Therefore, this does not work
      # when num_workers > 1 and params.optimizer != 'sgd'.
      for worker_grads in step_grads:
        sess.run(apply_grads_op, dict(zip(placeholder_grads, worker_grads)))
  return losses
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator.

        Builds the membership classifier and returns a `TPUEstimatorSpec` for
        the requested mode. Several names used here (`bert_config`,
        `init_checkpoint`, `learning_rate`, `use_tpu`,
        `use_one_hot_embeddings`, `membership_features_str`) are free
        variables, presumably supplied by the enclosing builder function.

        Args:
          features: dict of input tensors. The keys "unique_ids",
            "input_ids", "input_mask", "segment_ids" and "label_ids" are read
            in every mode.
          labels: unused.
          mode: one of the `ModeKeys` values TRAIN, EVAL or PREDICT.
          params: unused.

        Returns:
          A `TPUEstimatorSpec` carrying the train op (TRAIN), the evaluation
          metrics (EVAL) or the predictions (PREDICT).

        Raises:
          ValueError: if `mode` is not TRAIN, EVAL or PREDICT.
        """

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s", name,
                            features[name].shape)

        unique_ids = features["unique_ids"]
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        # NOTE: label_ids is looked up unconditionally, so PREDICT inputs
        # must provide it as well even though only TRAIN and EVAL use it.
        label_ids = features["label_ids"]

        # Obtaining the membership variables is important since only those
        # weights are modified during the optimization process.
        membership_logits, membership_vars = create_model(
            bert_config=bert_config,
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings,
            membership_features_str=membership_features_str)

        membership_probs = tf.nn.softmax(membership_logits, axis=-1)
        membership_log_probs = tf.nn.log_softmax(membership_logits, axis=-1)

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                # On TPU the checkpoint restore is deferred into a scaffold
                # function that is handed to the estimator spec.
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf_estimator.ModeKeys.TRAIN:

            # Cross-entropy between the one-hot labels (two classes) and the
            # predicted log-probabilities, averaged over the batch.
            one_hot_positions = tf.one_hot(label_ids,
                                           depth=2,
                                           dtype=tf.float32)
            loss = -tf.reduce_mean(
                tf.reduce_sum(one_hot_positions * membership_log_probs,
                              axis=-1))

            global_step = tf.train.get_or_create_global_step()

            optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
            if use_tpu:
                optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

            # Only the membership variables are updated by the optimizer.
            train_op = optimizer.minimize(loss=loss,
                                          global_step=global_step,
                                          var_list=membership_vars)

            output_spec = contrib_tpu.TPUEstimatorSpec(mode=mode,
                                                       loss=loss,
                                                       train_op=train_op,
                                                       scaffold_fn=scaffold_fn)

        elif mode == tf_estimator.ModeKeys.EVAL:

            one_hot_positions = tf.one_hot(label_ids,
                                           depth=2,
                                           dtype=tf.float32)
            per_example_loss = -1 * tf.reduce_sum(
                one_hot_positions * membership_log_probs, axis=-1)
            total_loss = tf.reduce_mean(per_example_loss)

            def metric_fn(per_example_loss, label_ids, membership_logits):
                """Return mean loss and accuracy metrics for evaluation."""
                predictions = tf.argmax(membership_logits,
                                        axis=-1,
                                        output_type=tf.int32)
                loss = tf.metrics.mean(values=per_example_loss)
                accuracy = tf.metrics.accuracy(labels=label_ids,
                                               predictions=predictions)
                return {"eval_accuracy": accuracy, "eval_loss": loss}

            # Metric function plus the tensors passed to it as arguments.
            eval_metrics = (metric_fn,
                            [per_example_loss, label_ids, membership_logits])

            output_spec = contrib_tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)

        elif mode == tf_estimator.ModeKeys.PREDICT:
            predictions = {
                "unique_ids": unique_ids,
                "membership_probs": membership_probs
            }
            output_spec = contrib_tpu.TPUEstimatorSpec(mode=mode,
                                                       predictions=predictions,
                                                       scaffold_fn=scaffold_fn)

        else:
            # EVAL is handled above, so the message must list it too.
            raise ValueError(
                "Only TRAIN, EVAL and PREDICT modes are supported: %s" %
                (mode))

        return output_spec