def main(args):
    if args.distribute:
        distribute.enable_distributed_training()

    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    if distribute.rank() == 0:
        export_params(params.output, "params.json", params)
        export_params(params.output, "%s.json" % args.model,
                      collect_params(params, model_cls.get_parameters()))

    # Build Graph
    with tf.Graph().as_default():
        if not params.record:
            # Build input queue
            features = dataset.get_training_input(params.input, params)
        else:
            features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)

        # Build model
        initializer = get_initializer(params)
        regularizer = tf.contrib.layers.l1_l2_regularizer(
            scale_l1=params.scale_l1, scale_l2=params.scale_l2)
        model = model_cls(params)
        # Create global step
        global_step = tf.train.get_or_create_global_step()
        dtype = tf.float16 if args.half else None

        # Multi-GPU setting
        sharded_losses = parallel.parallel_model(
            model.get_training_func(initializer, regularizer, dtype),
            features, params.device_list)
        loss = tf.add_n(sharded_losses) / len(sharded_losses)
        loss = loss + tf.losses.get_regularization_loss()

        if distribute.rank() == 0:
            print_variables()

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)

        tf.summary.scalar("loss", loss)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        if params.optimizer == "Adam":
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=params.adam_beta1,
                                         beta2=params.adam_beta2,
                                         epsilon=params.adam_epsilon)
        elif params.optimizer == "LazyAdam":
            opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                                   beta1=params.adam_beta1,
                                                   beta2=params.adam_beta2,
                                                   epsilon=params.adam_epsilon)
        else:
            raise RuntimeError("Optimizer %s not supported" % params.optimizer)

        opt = optimizers.MultiStepOptimizer(opt, params.update_cycle)

        if args.half:
            opt = optimizers.LossScalingOptimizer(opt, params.loss_scale)

        # Optimization
        grads_and_vars = opt.compute_gradients(
            loss, colocate_gradients_with_ops=True)

        if params.clip_grad_norm:
            grads, var_list = list(zip(*grads_and_vars))
            grads, _ = tf.clip_by_global_norm(grads, params.clip_grad_norm)
            grads_and_vars = zip(grads, var_list)

        train_op = opt.apply_gradients(grads_and_vars,
                                       global_step=global_step)

        # Validation
        if params.validation and params.references[0]:
            files = [params.validation] + list(params.references)
            eval_inputs = dataset.sort_and_zip_files(files)
            eval_input_fn = dataset.get_evaluation_input
        else:
            eval_input_fn = None

        # Hooks
        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(loss),
            tf.train.LoggingTensorHook(
                {
                    "step": global_step,
                    "loss": loss,
                    "source": tf.shape(features["source"]),
                    "target": tf.shape(features["target"])
                },
                every_n_iter=1)
        ]

        broadcast_hook = distribute.get_broadcast_hook()

        if broadcast_hook:
            train_hooks.append(broadcast_hook)

        if distribute.rank() == 0:
            # Add hooks
            save_vars = tf.trainable_variables() + [global_step]
            saver = tf.train.Saver(
                var_list=save_vars if params.only_save_trainable else None,
                max_to_keep=params.keep_checkpoint_max,
                sharded=False)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            train_hooks.append(
                hooks.MultiStepHook(
                    tf.train.CheckpointSaverHook(
                        checkpoint_dir=params.output,
                        save_secs=params.save_checkpoint_secs or None,
                        save_steps=params.save_checkpoint_steps or None,
                        saver=saver),
                    step=params.update_cycle))

            if eval_input_fn is not None:
                train_hooks.append(
                    hooks.MultiStepHook(
                        hooks.EvaluationHook(
                            lambda f: inference.create_inference_graph(
                                [model], f, params),
                            lambda: eval_input_fn(eval_inputs, params),
                            lambda x: decode_target_ids(x, params),
                            params.output,
                            session_config(params),
                            device_list=params.device_list,
                            max_to_keep=params.keep_top_checkpoint_max,
                            eval_secs=params.eval_secs,
                            eval_steps=params.eval_steps),
                        step=params.update_cycle))

            checkpoint_dir = params.output
        else:
            checkpoint_dir = None

        restore_op = restore_variables(args.checkpoint)

        def restore_fn(step_context):
            step_context.session.run(restore_op)

        # Create session, do not use default CheckpointSaverHook
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=checkpoint_dir,
                hooks=train_hooks,
                save_checkpoint_secs=None,
                config=session_config(params)) as sess:
            # Restore pre-trained variables
            sess.run_step_fn(restore_fn)

            while not sess.should_stop():
                sess.run(train_op)
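# --- Illustrative sketch (not part of the original code) ---------------------
# hooks.MultiStepHook is project-internal. Judging only from its use above
# (wrapping CheckpointSaverHook and EvaluationHook with step=params.update_cycle),
# it appears to throttle an inner hook so it fires once per update cycle. The
# class below is a hypothetical stand-in showing how such a delegating
# SessionRunHook can be written; the real implementation may differ.
import tensorflow as tf


class EveryNRunsHook(tf.train.SessionRunHook):
    """Delegate to `inner` only on every `n`-th session.run call."""

    def __init__(self, inner, n):
        self._inner = inner
        self._n = max(int(n), 1)
        self._count = 0
        self._active = False

    def begin(self):
        self._inner.begin()

    def after_create_session(self, session, coord):
        self._inner.after_create_session(session, coord)

    def before_run(self, run_context):
        self._count += 1
        self._active = (self._count % self._n == 0)
        if self._active:
            return self._inner.before_run(run_context)
        return None

    def after_run(self, run_context, run_values):
        # Only forward after_run when the inner hook's before_run was called,
        # so run_values matches the fetches that hook requested.
        if self._active:
            self._inner.after_run(run_context, run_values)

    def end(self, session):
        self._inner.end(session)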
def get_training_input(filenames, params):
    """Get input for the training stage.

    :param filenames: A list containing [source_filenames, target_filenames]
    :param params: Hyper-parameters

    :returns: A dictionary of <key, tensor> pairs
    """
    with tf.device("/cpu:0"):
        src_dataset = tf.data.TextLineDataset(filenames[0])
        tgt_dataset = tf.data.TextLineDataset(filenames[1])

        dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))

        if distribute.is_distributed_training_mode():
            dataset = dataset.shard(distribute.size(), distribute.rank())

        dataset = dataset.shuffle(params.buffer_size)
        dataset = dataset.repeat()

        # Split string
        dataset = dataset.map(
            lambda src, tgt: (
                tf.string_split([src]).values,
                tf.string_split([tgt]).values),
            num_parallel_calls=params.num_threads)

        # Append <eos> symbol
        dataset = dataset.map(
            lambda src, tgt: (
                tf.concat([src, [tf.constant(params.eos)]], axis=0),
                tf.concat([tgt, [tf.constant(params.eos)]], axis=0)),
            num_parallel_calls=params.num_threads)

        # Convert to dictionary
        dataset = dataset.map(
            lambda src, tgt: {
                "source": src,
                "target": tgt,
                "source_length": tf.shape(src),
                "target_length": tf.shape(tgt)
            },
            num_parallel_calls=params.num_threads)

        # Create iterator
        iterator = dataset.make_one_shot_iterator()
        features = iterator.get_next()

        # Create lookup table
        src_table = tf.contrib.lookup.index_table_from_tensor(
            tf.constant(params.vocabulary["source"]),
            default_value=params.mapping["source"][params.unk])
        tgt_table = tf.contrib.lookup.index_table_from_tensor(
            tf.constant(params.vocabulary["target"]),
            default_value=params.mapping["target"][params.unk])

        # String to index lookup
        features["source"] = src_table.lookup(features["source"])
        features["target"] = tgt_table.lookup(features["target"])

        # Batching
        features = batch_examples(features, params.batch_size,
                                  params.max_length, params.mantissa_bits,
                                  shard_multiplier=len(params.device_list),
                                  length_multiplier=params.length_multiplier,
                                  constant=params.constant_batch_size,
                                  num_threads=params.num_threads)

        # Convert to int32
        features["source"] = tf.to_int32(features["source"])
        features["target"] = tf.to_int32(features["target"])
        features["source_length"] = tf.to_int32(features["source_length"])
        features["target_length"] = tf.to_int32(features["target_length"])
        features["source_length"] = tf.squeeze(features["source_length"], 1)
        features["target_length"] = tf.squeeze(features["target_length"], 1)

        return features
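# --- Illustrative sketch (not part of the original code) ---------------------
# A minimal, self-contained TF 1.x example of the per-line preprocessing used
# in get_training_input above: split on whitespace, append <eos>, then map
# tokens to ids with a lookup table. The toy vocabulary and sentences are
# hypothetical; the real pipeline reads files and batches by length.
import tensorflow as tf


def _preprocessing_sketch():
    vocab = ["<eos>", "<unk>", "hello", "world"]
    lines = tf.data.Dataset.from_tensor_slices(
        tf.constant(["hello world", "world"]))

    # Split each line into tokens and append the <eos> symbol
    lines = lines.map(lambda line: tf.string_split([line]).values)
    lines = lines.map(
        lambda toks: tf.concat([toks, [tf.constant("<eos>")]], axis=0))

    iterator = lines.make_initializable_iterator()
    tokens = iterator.get_next()

    # Out-of-vocabulary tokens fall back to the <unk> id (index 1)
    table = tf.contrib.lookup.index_table_from_tensor(
        tf.constant(vocab), default_value=1)
    ids = tf.to_int32(table.lookup(tokens))

    with tf.Session() as sess:
        sess.run([tf.tables_initializer(), iterator.initializer])
        print(sess.run(ids))  # [2 3 0] for "hello world" + <eos>
        print(sess.run(ids))  # [3 0] for "world" + <eos>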
def main(args):
    if args.distribute:
        distribute.enable_distributed_training()

    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    if not args.distribute or distribute.rank() == 0:
        export_params(params.output, "params.json", params)
        export_params(params.output, "%s.json" % args.model,
                      collect_params(params, model_cls.get_parameters()))

    assert 'r2l' in params.input[2]

    # Build Graph
    use_all_devices(params)
    with tf.Graph().as_default():
        if not params.record:
            # Build input queue
            features = dataset.abd_get_training_input(params.input, params)
        else:
            features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)

        update_cycle = params.update_cycle
        features, init_op = cache.cache_features(features, update_cycle)

        # Build model
        initializer = get_initializer(params)
        regularizer = tf.contrib.layers.l1_l2_regularizer(
            scale_l1=params.scale_l1, scale_l2=params.scale_l2)
        model = model_cls(params)
        # Create global step
        global_step = tf.train.get_or_create_global_step()
        dtype = tf.float16 if args.fp16 else None

        if args.distribute:
            training_func = model.get_training_func(initializer, regularizer,
                                                    dtype)
            loss = training_func(features)
        else:
            # Multi-GPU setting
            sharded_losses = parallel.parallel_model(
                model.get_training_func(initializer, regularizer, dtype),
                features, params.device_list)
            loss = tf.add_n(sharded_losses) / len(sharded_losses)
            loss = loss + tf.losses.get_regularization_loss()

        # Print parameters
        if not args.distribute or distribute.rank() == 0:
            print_variables()

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        if params.optimizer == "Adam":
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=params.adam_beta1,
                                         beta2=params.adam_beta2,
                                         epsilon=params.adam_epsilon)
        elif params.optimizer == "LazyAdam":
            opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                                   beta1=params.adam_beta1,
                                                   beta2=params.adam_beta2,
                                                   epsilon=params.adam_epsilon)
        else:
            raise RuntimeError("Optimizer %s not supported" % params.optimizer)

        loss, ops = optimize.create_train_op(
            loss, opt, global_step,
            distribute.all_reduce if args.distribute else None,
            args.fp16, params)
        restore_op = restore_variables(args.checkpoint)

        # Validation
        if params.validation and params.references[0]:
            files = params.validation + list(params.references)
            eval_inputs = dataset.sort_and_zip_files(files)
            eval_input_fn = dataset.abd_get_evaluation_input
        else:
            eval_input_fn = None

        # Add hooks
        multiplier = tf.convert_to_tensor([update_cycle, 1])

        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(loss),
            tf.train.LoggingTensorHook(
                {
                    "step": global_step,
                    "loss": loss,
                    "source": tf.shape(features["source"]) * multiplier,
                    "target": tf.shape(features["target"]) * multiplier
                },
                every_n_iter=1)
        ]

        if args.distribute:
            train_hooks.append(distribute.get_broadcast_hook())

        config = session_config(params)

        if not args.distribute or distribute.rank() == 0:
            # Add hooks
            save_vars = tf.trainable_variables() + [global_step]
            saver = tf.train.Saver(
                var_list=save_vars if params.only_save_trainable else None,
                max_to_keep=params.keep_checkpoint_max,
                sharded=False)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            train_hooks.append(
                tf.train.CheckpointSaverHook(
                    checkpoint_dir=params.output,
                    save_secs=params.save_checkpoint_secs or None,
                    save_steps=params.save_checkpoint_steps or None,
                    saver=saver))

        if eval_input_fn is not None:
            if not args.distribute or distribute.rank() == 0:
                train_hooks.append(
                    hooks.EvaluationHook(
                        lambda f: inference.create_inference_graph(
                            [model], f, params),
                        lambda: eval_input_fn(eval_inputs, params),
                        lambda x: decode_target_ids(x, params),
                        params.output,
                        config,
                        params.keep_top_checkpoint_max,
                        eval_secs=params.eval_secs,
                        eval_steps=params.eval_steps))

        def restore_fn(step_context):
            step_context.session.run(restore_op)

        def step_fn(step_context):
            # Bypass hook calls
            step_context.session.run([init_op, ops["zero_op"]])
            for i in range(update_cycle - 1):
                step_context.session.run(ops["collect_op"])
            return step_context.run_with_hooks(ops["train_op"])

        # Create session, do not use default CheckpointSaverHook
        if not args.distribute or distribute.rank() == 0:
            checkpoint_dir = params.output
        else:
            checkpoint_dir = None

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=checkpoint_dir,
                hooks=train_hooks,
                save_checkpoint_secs=None,
                config=config) as sess:
            # Restore pre-trained variables
            sess.run_step_fn(restore_fn)

            while not sess.should_stop():
                sess.run_step_fn(step_fn)
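# --- Illustrative sketch (not part of the original code) ---------------------
# step_fn above drives three ops produced by the project's
# optimize.create_train_op: ops["zero_op"], ops["collect_op"] and
# ops["train_op"]. The minimal TF 1.x sketch below shows one common way to
# build such gradient-accumulation ops; it is an assumption inferred from how
# step_fn uses them, not the project's actual implementation. The toy model
# (w, x, y) is hypothetical, and every trainable variable is assumed to
# receive a gradient.
import tensorflow as tf


def _accumulation_sketch():
    x = tf.placeholder(tf.float32, [None, 4])
    y = tf.placeholder(tf.float32, [None, 1])
    w = tf.get_variable("w", [4, 1])
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))

    global_step = tf.train.get_or_create_global_step()
    opt = tf.train.AdamOptimizer(1e-3)

    tvars = tf.trainable_variables()
    grads = tf.gradients(loss, tvars)
    # One non-trainable accumulator per trainable variable
    accum = [tf.Variable(tf.zeros_like(v.initialized_value()),
                         trainable=False) for v in tvars]

    # zero_op: reset the accumulators at the start of an update cycle
    zero_op = tf.group(*[a.assign(tf.zeros_like(a)) for a in accum])
    # collect_op: add the gradients of one micro-batch
    collect_op = tf.group(*[a.assign_add(g) for a, g in zip(accum, grads)])
    # train_op: collect the last micro-batch, then apply the summed gradients
    # (dividing by update_cycle here would average instead of summing)
    with tf.control_dependencies([collect_op]):
        train_op = opt.apply_gradients(list(zip(accum, tvars)),
                                       global_step=global_step)

    # step_fn above would then run zero_op once, collect_op
    # update_cycle - 1 times, and finally train_op with hooks.
    return {"zero_op": zero_op, "collect_op": collect_op,
            "train_op": train_op}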
def get_training_input(filenames, params):
    """Get input for the training stage.

    :param filenames: A list containing
        [source_filenames, target_filenames, mt_filenames...]
    :param params: Hyper-parameters

    :returns: A dictionary of <key, tensor> pairs
    """
    with tf.device("/cpu:0"):
        datasets = []

        for filename in filenames:
            datasets.append(tf.data.TextLineDataset(filename))

        dataset = tf.data.Dataset.zip(tuple(datasets))

        if distribute.is_distributed_training_mode():
            dataset = dataset.shard(distribute.size(), distribute.rank())

        dataset = dataset.shuffle(params.buffer_size)
        dataset = dataset.repeat()

        # Split string
        dataset = dataset.map(
            lambda *x: [tf.string_split([y]).values for y in x],
            num_parallel_calls=params.num_threads)

        # Append <eos> symbol
        dataset = dataset.map(
            lambda *x: [tf.concat([y, [tf.constant(params.eos)]], axis=0)
                        for y in x],
            num_parallel_calls=params.num_threads)

        def convert_to_dict(src, tgt, *x):
            res = {}
            res["source"] = src
            res["source_length"] = tf.shape(src)
            res["target"] = tgt
            res["target_length"] = tf.shape(tgt)

            for i, v in enumerate(x):
                res["mt_%d" % i] = v
                res["mt_length_%d" % i] = tf.shape(v)

            return res

        # Convert to dictionary
        dataset = dataset.map(convert_to_dict,
                              num_parallel_calls=params.num_threads)

        # Create iterator
        iterator = dataset.make_one_shot_iterator()
        features = iterator.get_next()

        # Create lookup table
        src_table = tf.contrib.lookup.index_table_from_tensor(
            tf.constant(params.vocabulary["source"]),
            default_value=params.mapping["source"][params.unk])
        tgt_table = tf.contrib.lookup.index_table_from_tensor(
            tf.constant(params.vocabulary["target"]),
            default_value=params.mapping["target"][params.unk])

        # String to index lookup
        features["source"] = src_table.lookup(features["source"])
        features["target"] = tgt_table.lookup(features["target"])

        for i in range(len(filenames) - 2):
            features["mt_%d" % i] = tgt_table.lookup(features["mt_%d" % i])

        # Batching
        features = batch_examples(features, params.batch_size,
                                  params.max_length, params.mantissa_bits,
                                  shard_multiplier=len(params.device_list),
                                  length_multiplier=params.length_multiplier,
                                  constant=params.constant_batch_size,
                                  num_threads=params.num_threads)

        # Convert to int32
        features["source"] = tf.to_int32(features["source"])
        features["target"] = tf.to_int32(features["target"])
        features["source_length"] = tf.to_int32(features["source_length"])
        features["target_length"] = tf.to_int32(features["target_length"])
        features["source_length"] = tf.squeeze(features["source_length"], 1)
        features["target_length"] = tf.squeeze(features["target_length"], 1)

        for i in range(len(filenames) - 2):
            features["mt_%d" % i] = tf.to_int32(features["mt_%d" % i])
            features["mt_length_%d" % i] = tf.to_int32(
                features["mt_length_%d" % i])
            features["mt_length_%d" % i] = tf.squeeze(
                features["mt_length_%d" % i], 1)

        return features
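# --- Illustrative sketch (not part of the original code) ---------------------
# batch_examples is project-internal (token-based bucketed batching driven by
# batch_size, max_length and mantissa_bits). As a much simpler stand-in, the
# TF 1.x sketch below pads variable-length features to a common length within
# each batch via Dataset.padded_batch. It assumes a hypothetical dataset of
# dicts shaped like the ones built above, with int64 token ids and int32
# length vectors; it is not the project's batching scheme.
import tensorflow as tf


def _padded_batch_sketch(dataset, batch_size=32, pad_id=0):
    # dataset is assumed to yield dicts such as
    # {"source": [len_s], "target": [len_t],
    #  "source_length": [1], "target_length": [1]}
    padded_shapes = {
        "source": tf.TensorShape([None]),
        "target": tf.TensorShape([None]),
        "source_length": tf.TensorShape([1]),
        "target_length": tf.TensorShape([1]),
    }
    padding_values = {
        "source": tf.constant(pad_id, dtype=tf.int64),
        "target": tf.constant(pad_id, dtype=tf.int64),
        "source_length": tf.constant(0, dtype=tf.int32),
        "target_length": tf.constant(0, dtype=tf.int32),
    }
    # Shorter sequences in a batch are padded with pad_id up to the length of
    # the longest sequence in that batch.
    return dataset.padded_batch(batch_size,
                                padded_shapes=padded_shapes,
                                padding_values=padding_values)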