import contextlib

import tensorflow as tf

# The model builders below live elsewhere in the repo; these module paths are
# assumptions based on the function names.
import wrn
from shake_drop import build_shake_drop_model
from shake_shake import build_shake_shake_model
from wrn import build_wrn_model


def build_model(inputs, num_classes, feature_dim, is_training, update_bn,
                hparams):
    """Constructs the vision model being trained/evaled.

    Args:
      inputs: input features/images being fed to the image model being built.
      num_classes: number of output classes being predicted.
      feature_dim: dimensionality of the model's feature/embedding layer.
      is_training: is the model training or not.
      update_bn: whether the batch-norm moving statistics should be updated.
      hparams: additional hyperparameters associated with the image model.

    Returns:
      The logits of the image model.
    """
    scopes = setup_arg_scopes(is_training)
    # contextlib.nested is Python 2 only; ExitStack is the Python 3 equivalent.
    with contextlib.ExitStack() as stack:
        for scope in scopes:
            stack.enter_context(scope)
        if hparams.model_name == 'pyramid_net':
            logits = build_shake_drop_model(inputs, num_classes, is_training)
        elif hparams.model_name == 'wrn':
            logits = build_wrn_model(inputs, num_classes, feature_dim,
                                     hparams.wrn_size, update_bn)
        elif hparams.model_name == 'shake_shake':
            logits = build_shake_shake_model(inputs, num_classes, hparams,
                                             is_training)
    return logits
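
# `setup_arg_scopes` is called by `build_model` but not defined in this file.
# The following is a minimal sketch, assuming slim-style batch-norm scoping;
# the decay/epsilon values are illustrative assumptions, not the originals.
def setup_arg_scopes(is_training):
    """Sets up argument scopes that configure batch norm for train/eval."""
    batch_norm_params = {
        'is_training': is_training,
        'decay': 0.9,     # assumed moving-average decay
        'epsilon': 1e-3,  # assumed variance epsilon
        'scale': True,
    }
    return [
        tf.contrib.framework.arg_scope(
            [tf.contrib.layers.batch_norm], **batch_norm_params)
    ]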
def model_fn(features, labels, mode, params, config):
    """Estimator model_fn for supervised + UDA-style consistency training."""
    sup_only = params['sup_only']

    if mode == tf.estimator.ModeKeys.EVAL:
        all_data = features
    else:
        sup_x = features['image']
        sup_y = features['label']
        sup_batch_size = sup_x.shape[0]
        unsup = labels['unsup']
        aug = labels['aug']
        unsup_batch_size = unsup.shape[0]
        # Run labeled, unlabeled, and augmented examples through the model in
        # a single forward pass; the batch is sliced apart again below.
        all_data = tf.concat([sup_x, unsup, aug], axis=0)

    # 32 is presumably the WRN width expected by this builder.
    logits = wrn.build_wrn_model(all_data, params['n_classes'], 32)
    predicted_classes = tf.argmax(logits, axis=-1, output_type=tf.int32)
    probs = tf.nn.softmax(logits)

    if mode == tf.estimator.ModeKeys.EVAL:
        sup_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        sup_loss = tf.reduce_mean(sup_loss)
        accuracy = tf.metrics.accuracy(labels, predicted_classes,
                                       name='acc_op')
        metrics = {'accuracy': accuracy}
        tf.summary.scalar('accuracy', accuracy[1])
        return tf.estimator.EstimatorSpec(mode, loss=sup_loss,
                                          eval_metric_ops=metrics)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes,
            'probs': probs,
            'logits': logits,
        }
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    assert mode == tf.estimator.ModeKeys.TRAIN

    # Supervised cross-entropy over the first sup_batch_size examples.
    sup_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=sup_y, logits=logits[:sup_batch_size])
    sup_loss = tf.reduce_mean(sup_loss, name='sup_loss_tensor')
    # A TSA (training-signal annealing) variant of the supervised loss was
    # present in the original and is left disabled:
    # sup_loss, avg_sup_loss, tsa_threshold = anneal_sup_loss(
    #     logits[:sup_batch_size], labels[:sup_batch_size], sup_loss,
    #     tf.train.get_global_step())
    # sup_loss = avg_sup_loss

    if sup_only:
        optimizer = tf.train.AdamOptimizer()
        train_op = optimizer.minimize(sup_loss,
                                      global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=sup_loss,
                                          train_op=train_op)

    # Consistency loss: KL between the model's predictions on the unlabeled
    # examples (gradient stopped) and on their augmented counterparts.
    unsup_loss = kl_divergence(
        tf.stop_gradient(
            logits[sup_batch_size:sup_batch_size + unsup_batch_size]),
        logits[sup_batch_size + unsup_batch_size:])
    unsup_loss = tf.reduce_mean(unsup_loss, name='unsup_loss_tensor')

    total_loss = sup_loss + unsup_loss
    total_loss = decay_weights(total_loss, 5e-4)

    metric_dict = {
        'sup_loss': 'sup_loss_tensor',
        'unsup_loss': 'unsup_loss_tensor',
        # 'tsa_threshold': 'tsa_threshold_tensor'
    }
    logging_hook = tf.train.LoggingTensorHook(tensors=metric_dict,
                                              every_n_iter=100)
    training_hooks = [logging_hook]

    global_step = tf.train.get_global_step()
    # These schedule hyperparameters were undefined in the original; reading
    # them from `params` (with these key names) is an assumption.
    lr = params['lr']
    warmup_steps = params['warmup_steps']
    steps = params['train_steps']
    min_lr_ratio = params['min_lr_ratio']

    # Linear warmup for `warmup_steps`, then cosine decay down to
    # lr * min_lr_ratio over the remaining steps.
    if warmup_steps > 0:
        warmup_lr = tf.to_float(global_step) / tf.to_float(warmup_steps) * lr
    else:
        warmup_lr = 0.0
    decay_lr = tf.train.cosine_decay(lr,
                                     global_step=global_step - warmup_steps,
                                     decay_steps=steps - warmup_steps,
                                     alpha=min_lr_ratio)
    learning_rate = tf.where(global_step < warmup_steps, warmup_lr, decay_lr)

    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                           momentum=0.9,
                                           use_nesterov=True)
    # Run the batch-norm moving-average updates (UPDATE_OPS) alongside the
    # train op; without this dependency the moving statistics never update.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(total_loss,
                                      global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode,
                                      loss=total_loss,
                                      training_hooks=training_hooks,
                                      train_op=train_op)
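
# `kl_divergence` and `decay_weights` are called by `model_fn` but not defined
# in this file. The sketches below are minimal, assumed implementations: a
# per-example KL(p || q) computed from logits (the UDA-style consistency loss)
# and a standard L2 weight-decay term added to the loss.
def kl_divergence(p_logits, q_logits):
    """Per-example KL(p || q), both distributions given as logits."""
    p = tf.nn.softmax(p_logits)
    log_p = tf.nn.log_softmax(p_logits)
    log_q = tf.nn.log_softmax(q_logits)
    return tf.reduce_sum(p * (log_p - log_q), axis=-1)


def decay_weights(loss, weight_decay_rate):
    """Adds an L2 penalty over all trainable variables to `loss`."""
    l2_losses = [tf.nn.l2_loss(v) for v in tf.trainable_variables()]
    return loss + weight_decay_rate * tf.add_n(l2_losses)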