def get_optimizer(params):
  """Get optimizer."""
  learning_rate = learning_rate_schedule(params)
  momentum = params['momentum']
  if params['optimizer'].lower() == 'sgd':
    logging.info('Use SGD optimizer')
    optimizer = tf.keras.optimizers.SGD(learning_rate, momentum=momentum)
  elif params['optimizer'].lower() == 'adam':
    logging.info('Use Adam optimizer')
    optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=momentum)
  else:
    raise ValueError('optimizers should be adam or sgd')

  moving_average_decay = params['moving_average_decay']
  if moving_average_decay:
    # TODO(tanmingxing): potentially add dynamic_decay for new tfa release.
    from tensorflow_addons import optimizers as tfa_optimizers  # pylint: disable=g-import-not-at-top
    optimizer = tfa_optimizers.MovingAverage(
        optimizer, average_decay=moving_average_decay, dynamic_decay=True)
  if params['mixed_precision']:
    optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        optimizer,
        loss_scale=tf.mixed_precision.experimental.DynamicLossScale(
            params['loss_scale']))
  return optimizer
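# Hedged usage sketch, not part of the original training code: the dict below
# lists only the keys get_optimizer reads directly; learning_rate_schedule()
# typically needs additional schedule-specific keys, which are assumed to come
# from the project's real hyperparameter config.
def _example_build_optimizer():
  example_params = {
      'optimizer': 'sgd',
      'momentum': 0.9,
      'moving_average_decay': 0.9998,
      'mixed_precision': False,
      'loss_scale': None,
      'strategy': None,
  }
  return get_optimizer(example_params)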
def get_optimizer(params):
  """Get optimizer."""
  learning_rate = learning_rate_schedule(params)
  momentum = params['momentum']
  if params['optimizer'].lower() == 'sgd':
    logging.info('Use SGD optimizer')
    optimizer = tf.keras.optimizers.SGD(learning_rate, momentum=momentum)
  elif params['optimizer'].lower() == 'adam':
    logging.info('Use Adam optimizer')
    optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=momentum)
  else:
    raise ValueError('optimizers should be adam or sgd')

  moving_average_decay = params['moving_average_decay']
  if moving_average_decay:
    from tensorflow_addons import optimizers as tfa_optimizers  # pylint: disable=g-import-not-at-top
    optimizer = tfa_optimizers.MovingAverage(
        optimizer, average_decay=moving_average_decay, dynamic_decay=True)

  precision = utils.get_precision(params['strategy'], params['mixed_precision'])
  if precision == 'mixed_float16' and params['loss_scale']:
    optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
        optimizer, initial_scale=params['loss_scale'])
  return optimizer
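# Hedged companion note (an assumption about the surrounding setup, not code
# from this file): when get_optimizer wraps the optimizer in a
# LossScaleOptimizer, the model is usually built under the matching global
# Keras dtype policy so layer compute dtypes are float16 while variables stay
# float32.
def _enable_mixed_float16_policy():
  tf.keras.mixed_precision.set_global_policy('mixed_float16')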
def run(input_dataset_class, common_module, keypoint_profiles_module,
        input_example_parser_creator):
  """Runs training pipeline.

  Args:
    input_dataset_class: An input dataset class that matches input table type.
    common_module: A Python module that defines common flags and constants.
    keypoint_profiles_module: A Python module that defines keypoint profiles.
    input_example_parser_creator: A function handle for creating data parser
      function. If None, uses the default parser creator.
  """
  log_dir_path = FLAGS.log_dir_path
  pipeline_utils.create_dir_and_save_flags(
      flags, log_dir_path, 'all_flags.train_with_features.json')

  # Setup summary writer.
  summary_writer = tf.summary.create_file_writer(
      os.path.join(log_dir_path, 'train_logs'), flush_millis=10000)

  # Setup configuration.
  keypoint_profile_2d = keypoint_profiles_module.create_keypoint_profile_or_die(
      FLAGS.keypoint_profile_name_2d)

  # Setup model.
  input_length = math.ceil(FLAGS.num_frames / FLAGS.downsample_rate)
  if FLAGS.input_features_dim > 0:
    feature_dim = FLAGS.input_features_dim
    input_shape = (input_length, feature_dim)
  else:
    feature_dim = None
    input_shape = (input_length, 13 * 2)
  classifier = models.get_temporal_classifier(
      FLAGS.classifier_type,
      input_shape=input_shape,
      num_classes=FLAGS.num_classes)
  ema_classifier = tf.keras.models.clone_model(classifier)
  optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)
  optimizer = tfa_optimizers.MovingAverage(optimizer)
  global_step = optimizer.iterations
  ckpt_manager, _, _ = utils.create_checkpoint(
      log_dir_path,
      optimizer=optimizer,
      model=classifier,
      ema_model=ema_classifier,
      global_step=global_step)

  # Setup the training dataset.
  dataset = pipelines.create_dataset_from_tables(
      FLAGS.input_tables, [int(x) for x in FLAGS.batch_sizes],
      num_instances_per_record=1,
      shuffle=True,
      drop_remainder=True,
      num_epochs=None,
      keypoint_names_2d=keypoint_profile_2d.keypoint_names,
      feature_dim=feature_dim,
      num_classes=FLAGS.num_classes,
      num_frames=FLAGS.num_frames,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size,
      common_module=common_module,
      dataset_class=input_dataset_class,
      input_example_parser_creator=input_example_parser_creator)

  loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

  def train_one_iteration(inputs):
    """Trains the model for one iteration.

    Args:
      inputs: A dictionary for training inputs.

    Returns:
      loss: The training loss for this iteration.
""" if FLAGS.input_features_dim > 0: features = inputs[common_module.KEY_FEATURES] else: features, _ = pipelines.create_model_input( inputs, common_module.MODEL_INPUT_KEYPOINT_TYPE_2D_INPUT, keypoint_profile_2d) features = tf.squeeze(features, axis=1) features = features[:, ::FLAGS.downsample_rate, :] labels = inputs[common_module.KEY_CLASS_TARGETS] labels = tf.squeeze(labels, axis=1) with tf.GradientTape() as tape: outputs = classifier(features, training=True) regularization_loss = sum(classifier.losses) crossentropy_loss = loss_object(labels, outputs) total_loss = crossentropy_loss + regularization_loss trainable_variables = classifier.trainable_variables grads = tape.gradient(total_loss, trainable_variables) optimizer.apply_gradients(zip(grads, trainable_variables)) for grad, trainable_variable in zip(grads, trainable_variables): tf.summary.scalar( 'summarize_grads/' + trainable_variable.name, tf.linalg.norm(grad), step=global_step) return dict( total_loss=total_loss, crossentropy_loss=crossentropy_loss, regularization_loss=regularization_loss) if FLAGS.compile: train_one_iteration = tf.function(train_one_iteration) record_every_n_steps = min(5, FLAGS.num_iterations) save_ckpt_every_n_steps = min(500, FLAGS.num_iterations) with summary_writer.as_default(): with tf.summary.record_if(global_step % record_every_n_steps == 0): start = time.time() for inputs in dataset: if global_step >= FLAGS.num_iterations: break model_losses = train_one_iteration(inputs) duration = time.time() - start start = time.time() for name, loss in model_losses.items(): tf.summary.scalar('train/' + name, loss, step=global_step) tf.summary.scalar('train/learning_rate', optimizer.lr, step=global_step) tf.summary.scalar('train/batch_time', duration, step=global_step) tf.summary.scalar('global_step/sec', 1 / duration, step=global_step) if global_step % record_every_n_steps == 0: logging.info('Iter[{}/{}], {:.6f}s/iter, loss: {:.4f}'.format( global_step.numpy(), FLAGS.num_iterations, duration, model_losses['total_loss'].numpy())) # Save checkpoint. if global_step % save_ckpt_every_n_steps == 0: utils.assign_moving_average_vars(classifier, ema_classifier, optimizer) ckpt_manager.save(checkpoint_number=global_step) logging.info('Checkpoint saved at step %d.', global_step.numpy())