Example #1
0
def keypoint_model_fn(features, labels, mode, params):
    """tf.estimator model_fn for a stacked-hourglass keypoint heatmap model.

    Builds the hourglass backbone under ``params['model_scope']``, computes a
    normalized-error metric, an MSE loss with intermediate supervision over
    every stack output plus L2 weight decay, and (in TRAIN mode) a momentum
    optimizer with a warmup + piecewise-constant learning-rate schedule.

    Args:
        features: batch of input images.
        labels: dict with 'targets' (ground-truth heatmaps), 'shape',
            'classid', 'key_v' (per-keypoint visibility), 'isvalid'
            (annotation validity) and 'norm_value' (metric normalizer).
        mode: a ``tf.estimator.ModeKeys`` value.
        params: hyper-parameter dict (model_scope, num_stacks, ...).

    Returns:
        tf.estimator.EstimatorSpec for the requested mode.
    """
    targets = labels['targets']
    shape = labels['shape']        # unused here; kept for input-fn compatibility
    classid = labels['classid']    # unused here; kept for input-fn compatibility
    key_v = labels['key_v']
    isvalid = labels['isvalid']
    norm_value = labels['norm_value']

    cur_batch_size = tf.shape(features)[0]

    # AUTO_REUSE shares variables if this fn is built more than once.
    with tf.variable_scope(params['model_scope'], default_name=None, values=[features], reuse=tf.AUTO_REUSE):
        # One heatmap set per hourglass stack; scopes containing 'all' use
        # the generic joint count keyed by '*' in config.class_num_joints.
        pred_outputs = hg.create_model(features, params['num_stacks'], params['feats_channals'],
                            config.class_num_joints[(params['model_scope'] if 'all' not in params['model_scope'] else '*')], params['num_modules'],
                            (mode == tf.estimator.ModeKeys.TRAIN), params['data_format'])

    # Normalize everything downstream to NCHW (channels_first).
    if params['data_format'] == 'channels_last':
        pred_outputs = [tf.transpose(pred_outputs[ind], [0, 3, 1, 2], name='outputs_trans_{}'.format(ind)) for ind in list(range(len(pred_outputs)))]

    score_map = pred_outputs[-1]  # last stack output is the final prediction

    pred_x, pred_y = get_keypoint(features, targets, score_map, params['heatmap_size'], params['train_image_size'], params['train_image_size'], (params['model_scope'] if 'all' not in params['model_scope'] else '*'), clip_at_zero=True, data_format=params['data_format'])

    # this is important!!!
    # Targets arrive scaled down; multiply back to the network's output
    # range before computing metrics and losses.
    targets = 255. * targets
    #with tf.control_dependencies([pred_x, pred_y]):
    ne_mertric = mertric.normalized_error(targets, score_map, norm_value, key_v, isvalid,
                             cur_batch_size,
                             config.class_num_joints[(params['model_scope'] if 'all' not in params['model_scope'] else '*')],
                             params['heatmap_size'],
                             params['train_image_size'])

    # last_pred_mse = tf.metrics.mean_squared_error(score_map, targets,
    #                             weights=1.0 / tf.cast(cur_batch_size, tf.float32),
    #                             name='last_pred_mse')

    # Drop heatmaps of invisible/invalid keypoints from the loss.
    # NOTE(review): assumes key_v/isvalid are [batch, num_joints] so the
    # boolean mask collapses the first two dims of the NCHW heatmaps —
    # TODO confirm against the input fn.
    all_visible = tf.logical_and(key_v>0, isvalid>0)
    targets = tf.boolean_mask(targets, all_visible)
    pred_outputs = [tf.boolean_mask(pred_outputs[ind], all_visible, name='boolean_mask_{}'.format(ind)) for ind in list(range(len(pred_outputs)))]

    # Per-row summed squared error, streamed through the MAE metric against
    # zeros (i.e. a running mean of sq_diff).
    sq_diff = tf.reduce_sum(tf.squared_difference(targets, pred_outputs[-1]), axis=-1)
    last_pred_mse = tf.metrics.mean_absolute_error(sq_diff, tf.zeros_like(sq_diff), name='last_pred_mse')

    metrics = {'normalized_error': ne_mertric, 'last_pred_mse':last_pred_mse}
    predictions = {'normalized_error': ne_mertric[1]}
    ne_mertric = tf.identity(ne_mertric[1], name='ne_mertric')  # named for logging hooks


    # Intermediate supervision: one MSE term per hourglass stack output.
    mse_loss_list = []
    for pred_ind in list(range(len(pred_outputs))):
        mse_loss_list.append(tf.losses.mean_squared_error(targets, pred_outputs[pred_ind],
                            weights=1.0 / tf.cast(cur_batch_size, tf.float32),
                            scope='loss_{}'.format(pred_ind),
                            loss_collection=None,#tf.GraphKeys.LOSSES,
                            reduction=tf.losses.Reduction.MEAN))# SUM, SUM_OVER_BATCH_SIZE, default mean by all elements

    mse_loss = tf.multiply(params['mse_weight'], tf.add_n(mse_loss_list), name='mse_loss')
    tf.summary.scalar('mse', mse_loss)
    tf.losses.add_loss(mse_loss)

    # bce_loss_list = []
    # for pred_ind in list(range(len(pred_outputs))):
    #     bce_loss_list.append(tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pred_outputs[pred_ind], labels=targets, name='loss_{}'.format(pred_ind)), name='loss_mean_{}'.format(pred_ind)))

    # mse_loss = tf.multiply(params['mse_weight'] / params['num_stacks'], tf.add_n(bce_loss_list), name='mse_loss')
    # tf.summary.scalar('mse', mse_loss)
    # tf.losses.add_loss(mse_loss)

    # Add weight decay to the loss. We exclude the batch norm variables because
    # doing so leads to a small improvement in accuracy.
    loss = mse_loss + params['weight_decay'] * tf.add_n(
                              [tf.nn.l2_loss(v) for v in tf.trainable_variables()
                               if 'batch_normalization' not in v.name])
    total_loss = tf.identity(loss, name='total_loss')  # named for logging hooks
    tf.summary.scalar('loss', total_loss)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, predictions=predictions, eval_metric_ops=metrics)

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()

        # Warmup LR first, then the base LR scaled by each decay factor at
        # the configured epoch boundaries; never decay below end_learning_rate.
        lr_values = [params['warmup_learning_rate']] + [params['learning_rate'] * decay for decay in params['lr_decay_factors']]
        learning_rate = tf.train.piecewise_constant(tf.cast(global_step, tf.int32),
                                                    [params['warmup_steps']] + [int(float(ep)*params['steps_per_epoch']) for ep in params['decay_boundaries']],
                                                    lr_values)
        truncated_learning_rate = tf.maximum(learning_rate, tf.constant(params['end_learning_rate'], dtype=learning_rate.dtype), name='learning_rate')
        tf.summary.scalar('lr', truncated_learning_rate)

        optimizer = tf.train.MomentumOptimizer(learning_rate=truncated_learning_rate,
                                                momentum=params['momentum'])

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)
    else:
        train_op = None

    # Scaffold init_fn restores pretrained weights on a cold start.
    return tf.estimator.EstimatorSpec(
                          mode=mode,
                          predictions=predictions,
                          loss=loss,
                          train_op=train_op,
                          eval_metric_ops=metrics,
                          scaffold=tf.train.Scaffold(init_fn=train_helper.get_init_fn_for_scaffold(FLAGS)))
def keypoint_model_fn(features, labels, mode, params):
    """tf.estimator model_fn: keypoint heatmaps with blurred-target pyramid
    and optional online hard keypoint mining (OHKM).

    Differences from the plain hourglass variant: intermediate outputs are
    supervised with progressively Gaussian-blurred targets (coarse-to-fine),
    invisible keypoints are zeroed out instead of removed (shapes stay
    static), and when ``params['use_ohkm']`` is set the final output trains
    only on the hardest half of the keypoints per sample. The optimizer is
    wrapped in a TowerOptimizer for multi-GPU replication.

    Args:
        features: batch of input images.
        labels: dict with 'targets', 'shape', 'classid', 'key_v',
            'isvalid' and 'norm_value' (see the hourglass variant).
        mode: a ``tf.estimator.ModeKeys`` value.
        params: hyper-parameter dict.

    Returns:
        tf.estimator.EstimatorSpec for the requested mode.
    """
    targets = labels['targets']
    shape = labels['shape']        # unused here; kept for input-fn compatibility
    classid = labels['classid']    # unused here; kept for input-fn compatibility
    key_v = labels['key_v']
    isvalid = labels['isvalid']
    norm_value = labels['norm_value']

    cur_batch_size = tf.shape(features)[0]
    #features= tf.ones_like(features)

    with tf.variable_scope(params['model_scope'], default_name=None, values=[features], reuse=tf.AUTO_REUSE):
        pred_outputs = backbone_(features, config.class_num_joints[(params['model_scope'] if 'all' not in params['model_scope'] else '*')], params['heatmap_size'], (mode == tf.estimator.ModeKeys.TRAIN), params['data_format'], net_depth=params['net_depth'])

    # Normalize everything downstream to NCHW (channels_first).
    if params['data_format'] == 'channels_last':
        pred_outputs = [tf.transpose(pred_outputs[ind], [0, 3, 1, 2], name='outputs_trans_{}'.format(ind)) for ind in list(range(len(pred_outputs)))]

    score_map = pred_outputs[-1]  # last output is the final prediction

    pred_x, pred_y = get_keypoint(features, targets, score_map, params['heatmap_size'], params['train_image_size'], params['train_image_size'], (params['model_scope'] if 'all' not in params['model_scope'] else '*'), clip_at_zero=True, data_format=params['data_format'])

    # this is important!!!
    # Scale targets back up to the network's output range.
    targets = 255. * targets
    # Per-output blur sigmas, coarse (early outputs) to fine; None means the
    # final output is supervised with the unblurred target.
    blur_list = [1., 1.37, 1.73, 2.4, None]#[1., 1.5, 2., 3., None]
    #blur_list = [None, None, None, None, None]

    targets_list = []
    for sigma in blur_list:
        if sigma is None:
            targets_list.append(targets)
        else:
            # always channels first for targets
            targets_list.append(gaussian_blur(targets, config.class_num_joints[(params['model_scope'] if 'all' not in params['model_scope'] else '*')], sigma, params['data_format'], 'blur_{}'.format(sigma)))

    # print(key_v)
    #targets = tf.reshape(255.*tf.one_hot(tf.ones_like(key_v,tf.int64)*(params['heatmap_size']*params['heatmap_size']//2+params['heatmap_size']), params['heatmap_size']*params['heatmap_size']), [cur_batch_size,-1,params['heatmap_size'],params['heatmap_size']])
    #norm_value = tf.ones_like(norm_value)
    # score_map = tf.reshape(tf.one_hot(tf.ones_like(key_v,tf.int64)*(31*64+31), params['heatmap_size']*params['heatmap_size']), [cur_batch_size,-1,params['heatmap_size'],params['heatmap_size']])

    #with tf.control_dependencies([pred_x, pred_y]):
    ne_mertric = mertric.normalized_error(targets, score_map, norm_value, key_v, isvalid,
                             cur_batch_size,
                             config.class_num_joints[(params['model_scope'] if 'all' not in params['model_scope'] else '*')],
                             params['heatmap_size'],
                             params['train_image_size'])

    # last_pred_mse = tf.metrics.mean_squared_error(score_map, targets,
    #                             weights=1.0 / tf.cast(cur_batch_size, tf.float32),
    #                             name='last_pred_mse')
    # filter all invisible keypoint maybe better for this task
    # all_visible = tf.logical_and(key_v>0, isvalid>0)
    # targets_list = [tf.boolean_mask(targets_list[ind], all_visible) for ind in list(range(len(targets_list)))]
    # pred_outputs = [tf.boolean_mask(pred_outputs[ind], all_visible, name='boolean_mask_{}'.format(ind)) for ind in list(range(len(pred_outputs)))]
    # Zero out invisible/invalid keypoints by broadcasting a
    # [batch, joints, 1, 1] float mask over the H/W heatmap dims; keeps
    # tensor shapes static (unlike boolean_mask above).
    all_visible = tf.expand_dims(tf.expand_dims(tf.cast(tf.logical_and(key_v>0, isvalid>0), tf.float32), axis=-1), axis=-1)
    targets_list = [targets_list[ind] * all_visible for ind in list(range(len(targets_list)))]
    pred_outputs = [pred_outputs[ind] * all_visible for ind in list(range(len(pred_outputs)))]

    # NOTE(review): sq_diff compares the *unmasked* `targets` against the
    # masked final output; the masked targets_list[-1] may be intended —
    # confirm before changing (metric only, loss is unaffected).
    sq_diff = tf.reduce_sum(tf.squared_difference(targets, pred_outputs[-1]), axis=-1)
    last_pred_mse = tf.metrics.mean_absolute_error(sq_diff, tf.zeros_like(sq_diff), name='last_pred_mse')

    metrics = {'normalized_error': ne_mertric, 'last_pred_mse':last_pred_mse}
    predictions = {'normalized_error': ne_mertric[1]}
    ne_mertric = tf.identity(ne_mertric[1], name='ne_mertric')  # named for logging hooks

    base_learning_rate = params['learning_rate']
    mse_loss_list = []
    if params['use_ohkm']:
        base_learning_rate = 1. * base_learning_rate  # no-op scale, kept as a tuning knob
        # Intermediate outputs: half-weighted MSE against blurred targets.
        for pred_ind in list(range(len(pred_outputs) - 1)):
            mse_loss_list.append(0.5 * tf.losses.mean_squared_error(targets_list[pred_ind], pred_outputs[pred_ind],
                                weights=1.0 / tf.cast(cur_batch_size, tf.float32),
                                scope='loss_{}'.format(pred_ind),
                                loss_collection=None,#tf.GraphKeys.LOSSES,
                                # mean all elements of all pixels in all batch
                                reduction=tf.losses.Reduction.MEAN))# SUM, SUM_OVER_BATCH_SIZE, default mean by all elements

        # Per-(sample, joint) mean loss of the final output, used to rank
        # keypoints by difficulty for OHKM.
        temp_loss = tf.reduce_mean(tf.reshape(tf.losses.mean_squared_error(targets_list[-1], pred_outputs[-1], weights=1.0, loss_collection=None, reduction=tf.losses.Reduction.NONE), [cur_batch_size, config.class_num_joints[(params['model_scope'] if 'all' not in params['model_scope'] else '*')], -1]), axis=-1)

        # Keep the hardest half of the joints for each sample.
        num_topk = config.class_num_joints[(params['model_scope'] if 'all' not in params['model_scope'] else '*')] // 2
        gather_col = tf.nn.top_k(temp_loss, k=num_topk, sorted=True)[1]
        gather_row = tf.reshape(tf.tile(tf.reshape(tf.range(cur_batch_size), [-1, 1]), [1, num_topk]), [-1, 1])
        # stop_gradient: selection indices must not backprop into the ranking.
        gather_indcies = tf.stop_gradient(tf.stack([gather_row, tf.reshape(gather_col, [-1, 1])], axis=-1))

        select_targets = tf.gather_nd(targets_list[-1], gather_indcies)
        select_heatmap = tf.gather_nd(pred_outputs[-1], gather_indcies)

        # Final-output loss over the selected hard keypoints only.
        mse_loss_list.append(tf.losses.mean_squared_error(select_targets, select_heatmap,
                                weights=1.0 / tf.cast(cur_batch_size, tf.float32),
                                scope='loss_{}'.format(len(pred_outputs) - 1),
                                loss_collection=None,#tf.GraphKeys.LOSSES,
                                # mean all elements of all pixels in all batch
                                reduction=tf.losses.Reduction.MEAN))
    else:
        # Plain intermediate supervision over all outputs, no mining.
        for pred_ind in list(range(len(pred_outputs))):
            mse_loss_list.append(tf.losses.mean_squared_error(targets_list[pred_ind], pred_outputs[pred_ind],
                                weights=1.0 / tf.cast(cur_batch_size, tf.float32),
                                scope='loss_{}'.format(pred_ind),
                                loss_collection=None,#tf.GraphKeys.LOSSES,
                                # mean all elements of all pixels in all batch
                                reduction=tf.losses.Reduction.MEAN))# SUM, SUM_OVER_BATCH_SIZE, default mean by all elements

    mse_loss = tf.multiply(params['mse_weight'], tf.add_n(mse_loss_list), name='mse_loss')
    tf.summary.scalar('mse', mse_loss)
    tf.losses.add_loss(mse_loss)

    # bce_loss_list = []
    # for pred_ind in list(range(len(pred_outputs))):
    #     bce_loss_list.append(tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pred_outputs[pred_ind], labels=targets_list[pred_ind]/255., name='loss_{}'.format(pred_ind)), name='loss_mean_{}'.format(pred_ind)))

    # mse_loss = tf.multiply(params['mse_weight'] / params['num_stacks'], tf.add_n(bce_loss_list), name='mse_loss')
    # tf.summary.scalar('mse', mse_loss)
    # tf.losses.add_loss(mse_loss)

    # Add weight decay to the loss. We exclude the batch norm variables because
    # doing so leads to a small improvement in accuracy.
    loss = mse_loss + params['weight_decay'] * tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'batch_normalization' not in v.name])
    total_loss = tf.identity(loss, name='total_loss')  # named for logging hooks
    tf.summary.scalar('loss', total_loss)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, predictions=predictions, eval_metric_ops=metrics)

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()

        # Warmup LR first, then the (possibly OHKM-scaled) base LR times each
        # decay factor at the configured epoch boundaries.
        lr_values = [params['warmup_learning_rate']] + [base_learning_rate * decay for decay in params['lr_decay_factors']]
        learning_rate = tf.train.piecewise_constant(tf.cast(global_step, tf.int32),
                                                    [params['warmup_steps']] + [int(float(ep)*params['steps_per_epoch']) for ep in params['decay_boundaries']],
                                                    lr_values)
        truncated_learning_rate = tf.maximum(learning_rate, tf.constant(params['end_learning_rate'], dtype=learning_rate.dtype), name='learning_rate')
        tf.summary.scalar('lr', truncated_learning_rate)

        optimizer = tf.train.MomentumOptimizer(learning_rate=truncated_learning_rate,
                                                momentum=params['momentum'])

        # Aggregate gradients across GPU towers (replicate_model_fn).
        optimizer = tf_replicate_model_fn.TowerOptimizer(optimizer)

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)
    else:
        train_op = None

    # Scaffold init_fn restores pretrained weights (with scope remapping /
    # exclusions) on a cold start.
    return tf.estimator.EstimatorSpec(
                          mode=mode,
                          predictions=predictions,
                          loss=loss,
                          train_op=train_op,
                          eval_metric_ops=metrics,
                          scaffold=tf.train.Scaffold(init_fn=train_helper.get_init_fn_for_scaffold_(params['checkpoint_path'], params['model_dir'], params['checkpoint_exclude_scopes'], params['model_scope'], params['checkpoint_model_scope'], params['ignore_missing_vars'])))
Example #3
0
def keypoint_model_fn(features, labels, mode, params):
    """tf.estimator model_fn: keypoint heatmaps trained with Stochastic
    Weight Averaging (SWA).

    Same loss structure as the OHKM variant (blurred-target pyramid,
    visibility masking, optional hard-keypoint mining). Training differs:
    the learning rate is annealed linearly within each epoch between
    ``high_learning_rate`` and ``low_learning_rate``, and near the end of
    every epoch the SWA moving average of the trainable variables is
    updated. In ``dummy_train`` mode the optimizer runs with lr=0 and
    momentum=0 while initializing from the SWA-averaged weights —
    presumably to refresh batch-norm statistics under the averaged
    weights; TODO confirm against train_helper.

    Args:
        features: batch of input images.
        labels: dict with 'targets', 'shape', 'classid', 'key_v',
            'isvalid' and 'norm_value'.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: hyper-parameter dict.

    Returns:
        tf.estimator.EstimatorSpec for the requested mode.
    """
    targets = labels['targets']
    shape = labels['shape']        # unused here; kept for input-fn compatibility
    classid = labels['classid']    # unused here; kept for input-fn compatibility
    key_v = labels['key_v']
    isvalid = labels['isvalid']
    norm_value = labels['norm_value']

    cur_batch_size = tf.shape(features)[0]
    #features= tf.ones_like(features)

    with tf.variable_scope(params['model_scope'],
                           default_name=None,
                           values=[features],
                           reuse=tf.AUTO_REUSE):
        pred_outputs = backbone_(
            features,
            config.class_num_joints[(params['model_scope'] if 'all'
                                     not in params['model_scope'] else '*')],
            params['heatmap_size'], (mode == tf.estimator.ModeKeys.TRAIN),
            params['data_format'])

    #print(pred_outputs)

    # Normalize everything downstream to NCHW (channels_first).
    if params['data_format'] == 'channels_last':
        pred_outputs = [
            tf.transpose(pred_outputs[ind], [0, 3, 1, 2],
                         name='outputs_trans_{}'.format(ind))
            for ind in list(range(len(pred_outputs)))
        ]

    score_map = pred_outputs[-1]  # last output is the final prediction

    pred_x, pred_y = get_keypoint(
        features,
        targets,
        score_map,
        params['heatmap_size'],
        params['train_image_size'],
        params['train_image_size'],
        (params['model_scope'] if 'all' not in params['model_scope'] else '*'),
        clip_at_zero=True,
        data_format=params['data_format'])

    # this is important!!!
    # Scale targets back up to the network's output range.
    targets = 255. * targets
    # Per-output blur sigmas, coarse to fine; None = unblurred final target.
    blur_list = [1., 1.37, 1.73, 2.4, None]  #[1., 1.5, 2., 3., None]
    #blur_list = [None, None, None, None, None]

    targets_list = []
    for sigma in blur_list:
        if sigma is None:
            targets_list.append(targets)
        else:
            # always channels first for targets
            targets_list.append(
                gaussian_blur(
                    targets,
                    config.class_num_joints[(params['model_scope'] if 'all'
                                             not in params['model_scope'] else
                                             '*')], sigma,
                    params['data_format'], 'blur_{}'.format(sigma)))

    #with tf.control_dependencies([pred_x, pred_y]):
    ne_mertric = mertric.normalized_error(
        targets, score_map, norm_value, key_v, isvalid, cur_batch_size,
        config.class_num_joints[(params['model_scope'] if 'all'
                                 not in params['model_scope'] else '*')],
        params['heatmap_size'], params['train_image_size'])

    # last_pred_mse = tf.metrics.mean_squared_error(score_map, targets,
    #                             weights=1.0 / tf.cast(cur_batch_size, tf.float32),
    #                             name='last_pred_mse')
    # filter all invisible keypoint maybe better for this task
    # all_visible = tf.logical_and(key_v>0, isvalid>0)
    # targets_list = [tf.boolean_mask(targets_list[ind], all_visible) for ind in list(range(len(targets_list)))]
    # pred_outputs = [tf.boolean_mask(pred_outputs[ind], all_visible, name='boolean_mask_{}'.format(ind)) for ind in list(range(len(pred_outputs)))]
    # Zero out invisible/invalid keypoints via a broadcastable
    # [batch, joints, 1, 1] float mask; keeps tensor shapes static.
    all_visible = tf.expand_dims(tf.expand_dims(tf.cast(
        tf.logical_and(key_v > 0, isvalid > 0), tf.float32),
                                                axis=-1),
                                 axis=-1)
    targets_list = [
        targets_list[ind] * all_visible
        for ind in list(range(len(targets_list)))
    ]
    pred_outputs = [
        pred_outputs[ind] * all_visible
        for ind in list(range(len(pred_outputs)))
    ]

    # NOTE(review): compares the *unmasked* `targets` against the masked
    # final output; the masked targets_list[-1] may be intended — confirm
    # before changing (metric only, loss is unaffected).
    sq_diff = tf.reduce_sum(tf.squared_difference(targets, pred_outputs[-1]),
                            axis=-1)
    last_pred_mse = tf.metrics.mean_absolute_error(sq_diff,
                                                   tf.zeros_like(sq_diff),
                                                   name='last_pred_mse')

    metrics = {'normalized_error': ne_mertric, 'last_pred_mse': last_pred_mse}
    predictions = {'normalized_error': ne_mertric[1]}
    ne_mertric = tf.identity(ne_mertric[1], name='ne_mertric')  # named for logging hooks

    mse_loss_list = []
    if params['use_ohkm']:
        # Intermediate outputs: half-weighted MSE against blurred targets.
        for pred_ind in list(range(len(pred_outputs) - 1)):
            mse_loss_list.append(
                0.5 * tf.losses.mean_squared_error(
                    targets_list[pred_ind],
                    pred_outputs[pred_ind],
                    weights=1.0 / tf.cast(cur_batch_size, tf.float32),
                    scope='loss_{}'.format(pred_ind),
                    loss_collection=None,  #tf.GraphKeys.LOSSES,
                    # mean all elements of all pixels in all batch
                    reduction=tf.losses.Reduction.MEAN)
            )  # SUM, SUM_OVER_BATCH_SIZE, default mean by all elements

        # Per-(sample, joint) mean loss of the final output, used to rank
        # keypoints by difficulty for OHKM.
        temp_loss = tf.reduce_mean(tf.reshape(
            tf.losses.mean_squared_error(targets_list[-1],
                                         pred_outputs[-1],
                                         weights=1.0,
                                         loss_collection=None,
                                         reduction=tf.losses.Reduction.NONE),
            [
                cur_batch_size, config.class_num_joints[(
                    params['model_scope']
                    if 'all' not in params['model_scope'] else '*')], -1
            ]),
                                   axis=-1)

        # Keep the hardest half of the joints for each sample.
        num_topk = config.class_num_joints[
            (params['model_scope']
             if 'all' not in params['model_scope'] else '*')] // 2
        gather_col = tf.nn.top_k(temp_loss, k=num_topk, sorted=True)[1]
        gather_row = tf.reshape(
            tf.tile(tf.reshape(tf.range(cur_batch_size), [-1, 1]),
                    [1, num_topk]), [-1, 1])
        # stop_gradient: selection indices must not backprop into the ranking.
        gather_indcies = tf.stop_gradient(
            tf.stack([gather_row, tf.reshape(gather_col, [-1, 1])], axis=-1))

        select_targets = tf.gather_nd(targets_list[-1], gather_indcies)
        select_heatmap = tf.gather_nd(pred_outputs[-1], gather_indcies)

        # Final-output loss over the selected hard keypoints only.
        mse_loss_list.append(
            tf.losses.mean_squared_error(
                select_targets,
                select_heatmap,
                weights=1.0 / tf.cast(cur_batch_size, tf.float32),
                scope='loss_{}'.format(len(pred_outputs) - 1),
                loss_collection=None,  #tf.GraphKeys.LOSSES,
                # mean all elements of all pixels in all batch
                reduction=tf.losses.Reduction.MEAN))
    else:
        # Plain intermediate supervision over all outputs, no mining.
        for pred_ind in list(range(len(pred_outputs))):
            mse_loss_list.append(
                tf.losses.mean_squared_error(
                    targets_list[pred_ind],
                    pred_outputs[pred_ind],
                    weights=1.0 / tf.cast(cur_batch_size, tf.float32),
                    scope='loss_{}'.format(pred_ind),
                    loss_collection=None,  #tf.GraphKeys.LOSSES,
                    # mean all elements of all pixels in all batch
                    reduction=tf.losses.Reduction.MEAN)
            )  # SUM, SUM_OVER_BATCH_SIZE, default mean by all elements

    mse_loss = tf.multiply(params['mse_weight'],
                           tf.add_n(mse_loss_list),
                           name='mse_loss')
    tf.summary.scalar('mse', mse_loss)
    tf.losses.add_loss(mse_loss)

    # Add weight decay to the loss. We exclude the batch norm variables because
    # doing so leads to a small improvement in accuracy.
    loss = mse_loss + params['weight_decay'] * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])
    total_loss = tf.identity(loss, name='total_loss')  # named for logging hooks
    tf.summary.scalar('loss', total_loss)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          predictions=predictions,
                                          eval_metric_ops=metrics)

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()

        if not params['dummy_train']:
            # Position within the current epoch, mapped to (0, 1].
            step_remainder = tf.floormod(global_step - 1,
                                         params['steps_per_epoch'])
            range_scale = tf.to_float(step_remainder + 1) / tf.to_float(
                params['steps_per_epoch'])
            # Cyclic SWA schedule: anneal linearly from high_learning_rate
            # down to low_learning_rate within each epoch.
            learning_rate = tf.add(
                (1 - range_scale) * params['high_learning_rate'],
                range_scale * params['low_learning_rate'],
                name='learning_rate')
            tf.summary.scalar('lr', learning_rate)

            # Fires once near the end of each epoch (remainder == steps-2).
            should_update = tf.equal(step_remainder,
                                     params['steps_per_epoch'] - 2)
            optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                                   momentum=params['momentum'])

            # Batch norm requires update_ops to be added as a train_op dependency.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                opt_op = optimizer.minimize(loss, global_step)

            variables_to_train = []
            for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
                variables_to_train.append(var)

            # Create an ExponentialMovingAverage object
            # SWA average indexed by the number of completed epochs.
            ema = swa_moving_average.SWAMovingAverage(
                tf.floordiv(global_step, params['steps_per_epoch']))
            # After each optimizer step, fold the weights into the SWA
            # average only on the per-epoch trigger step.
            with tf.control_dependencies([opt_op]):
                train_op = tf.cond(should_update,
                                   lambda: ema.apply(variables_to_train),
                                   lambda: tf.no_op())

            _init_fn = train_helper.get_raw_init_fn_for_scaffold(
                params['checkpoint_path'], params['model_dir'])
        else:
            # dummy_train: zero LR / zero momentum, so the "training" steps
            # only execute forward passes and BN update_ops while weights
            # are initialized from the SWA average below.
            learning_rate = tf.constant(0., name='learning_rate')
            tf.summary.scalar('lr', learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                                   momentum=0.)

            variables_to_train = []
            for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
                variables_to_train.append(var)
            ema = swa_moving_average.SWAMovingAverage(
                tf.floordiv(global_step, params['steps_per_epoch']))
            # Batch norm requires update_ops to be added as a train_op dependency.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer.minimize(loss, global_step)
            # Restore the SWA-averaged values into the trainable variables.
            _init_fn = train_helper.swa_get_init_fn_for_scaffold(
                params['checkpoint_path'], params['model_dir'],
                variables_to_train, ema)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=predictions,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=metrics,
                                      scaffold=tf.train.Scaffold(
                                          init_fn=_init_fn, saver=None))