Example #1
import contextlib

import tensorflow as tf

# `wrn` is the local module providing build_wrn_model (assumed import path).
import wrn

def build_model(
    inputs,
    num_classes,
    feature_dim,
    is_training,
    update_bn,
    hparams,
):
    """Constructs the vision model being trained/evaled.

  Args:
    inputs: input features/images being fed to the image model build built.
    num_classes: number of output classes being predicted.
    is_training: is the model training or not.
    hparams: additional hyperparameters associated with the image model.

  Returns:
    The logits of the image model.
  """
    scopes = setup_arg_scopes(is_training)
    with contextlib.nested(*scopes):
        if hparams.model_name == "pyramid_net":
            logits = build_shake_drop_model(inputs, num_classes, is_training)
        elif hparams.model_name == "wrn":
            logits = build_wrn_model(inputs, num_classes, feature_dim,
                                     hparams.wrn_size, update_bn)
        elif hparams.model_name == "shake_shake":
            logits = build_shake_shake_model(inputs, num_classes, hparams,
                                             is_training)
        else:
            raise ValueError("Unknown model_name: {}".format(hparams.model_name))
    return logits

def model_fn(features, labels, mode, params, config):
    # print("============calling model_fn================")
    sup_only = params['sup_only']
    if mode == tf.estimator.ModeKeys.EVAL:
        all_data = features
    else:
        sup_x = features['image']
        sup_y = features['label']
        sup_batch_size = sup_x.shape[0]
        unsup = labels['unsup']
        aug = labels['aug']
        unsup_batch_size = unsup.shape[0]
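        # Concatenate supervised, unlabeled, and augmented examples so the
        # network runs a single forward pass; the slices are separated again
        # by batch size when the losses are computed.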
        all_data = tf.concat([sup_x, unsup, aug], axis=0)

    logits = wrn.build_wrn_model(all_data, params['n_classes'], 32)
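    # Predicted class ids and per-class probabilities (used for eval/predict).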
    predicted_classes = tf.argmax(logits, axis=-1, output_type=tf.int32)
    probs = tf.nn.softmax(logits)

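    # Evaluation: cross-entropy on the labeled eval batch plus an accuracy metric.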
    if mode == tf.estimator.ModeKeys.EVAL:
        sup_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        sup_loss = tf.reduce_mean(sup_loss)
        accuracy = tf.metrics.accuracy(labels,
                                       predicted_classes,
                                       name='acc_op')
        metrics = {'accuracy': accuracy}
        tf.summary.scalar('accuracy', accuracy[1])
        return tf.estimator.EstimatorSpec(mode,
                                          loss=sup_loss,
                                          eval_metric_ops=metrics)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes,
            'probs': probs,
            'logits': logits
        }
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    assert mode == tf.estimator.ModeKeys.TRAIN
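    # Supervised cross-entropy loss on the labeled slice of the batch.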
    sup_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=sup_y, logits=logits[:sup_batch_size])
    sup_loss = tf.reduce_mean(sup_loss, name='sup_loss_tensor')
    # sup_loss, avg_sup_loss, tsa_threshold = anneal_sup_loss(
    #     logits[:sup_batch_size],
    #     labels[:sup_batch_size],
    #     sup_loss,
    #     tf.train.get_global_step()
    # )
    # sup_loss = avg_sup_loss
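    # Supervised-only baseline: skip the consistency loss and train with Adam.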
    if sup_only:
        optimizer = tf.train.AdamOptimizer()
        train_op = optimizer.minimize(sup_loss,
                                      global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode,
                                          loss=sup_loss,
                                          train_op=train_op)
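    # Unsupervised consistency loss: KL divergence between the (stop-gradient)
    # predictions on the original unlabeled images and the predictions on their
    # augmented copies.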
    unsup_loss = kl_divergence(
        tf.stop_gradient(logits[sup_batch_size:sup_batch_size +
                                unsup_batch_size]),
        logits[sup_batch_size + unsup_batch_size:])
    unsup_loss = tf.reduce_mean(unsup_loss, name='unsup_loss_tensor')
    total_loss = sup_loss + unsup_loss
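    # decay_weights is assumed to add L2 weight decay (coefficient 5e-4) to the loss.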
    total_loss = decay_weights(total_loss, 5e-4)

    metric_dict = {
        'sup_loss': 'sup_loss_tensor',
        'unsup_loss': 'unsup_loss_tensor',
        # 'tsa_threshold': 'tsa_threshold_tensor'
    }
    logging_hook = tf.train.LoggingTensorHook(tensors=metric_dict,
                                              every_n_iter=100)
    training_hooks = [logging_hook]

    global_step = tf.train.get_global_step()
    # Assumed `params` keys; lr/warmup_steps/steps/min_lr_ratio are otherwise undefined.
    lr = params['learning_rate']
    warmup_steps = params['warmup_steps']
    steps = params['train_steps']
    min_lr_ratio = params['min_lr_ratio']
    if warmup_steps > 0:
        warmup_lr = tf.to_float(global_step) / tf.to_float(warmup_steps) * lr
    else:
        warmup_lr = 0.0

    # decay the learning rate using the cosine schedule
    decay_lr = tf.train.cosine_decay(lr,
                                     global_step=global_step - warmup_steps,
                                     decay_steps=steps - warmup_steps,
                                     alpha=min_lr_ratio)

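    # Linear warmup for the first warmup_steps steps, then switch to cosine decay.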
    learning_rate = tf.where(global_step < warmup_steps, warmup_lr, decay_lr)
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                           momentum=0.9,
                                           use_nesterov=True)
    # grads_and_vars = optimizer.compute_gradients(total_loss)
    # gradients, variables = zip(*grads_and_vars)
    # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # with tf.control_dependencies(update_ops):
    #     train_op = optimizer.apply_gradients(
    #         zip(gradients, variables), global_step=tf.train.get_global_step())
    train_op = optimizer.minimize(total_loss,
                                  global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(mode,
                                      loss=total_loss,
                                      training_hooks=training_hooks,
                                      train_op=train_op)
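
For reference, here is a minimal wiring sketch (not part of the original listing) showing how this model_fn could be plugged into a tf.estimator.Estimator. Only 'sup_only' and 'n_classes' are read by model_fn exactly as written above; the learning-rate schedule keys, the model_dir, the hyperparameter values, and train_input_fn are placeholder assumptions.

# Hypothetical wiring; hyperparameter values and the input_fn are placeholders.
estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir='/tmp/uda_wrn',  # placeholder checkpoint directory
    params={
        'sup_only': False,        # train with the unsupervised consistency loss
        'n_classes': 10,
        'learning_rate': 0.03,    # assumed key, see the schedule section above
        'warmup_steps': 1000,
        'train_steps': 100000,
        'min_lr_ratio': 0.004,
    })
estimator.train(input_fn=train_input_fn, max_steps=100000)  # train_input_fn: user-supplied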