# Example #1
# 0
def main(args):
    """Build the training graph and run the training loop.

    Parameters are layered with increasing priority: defaults, the model's
    parameters, parameters previously saved under ``args.output``, and
    finally command-line overrides.

    Only rank 0 exports the parameter files, prints the variables, saves
    checkpoints and runs evaluation. Other ranks only train.

    :param args: Parsed command-line arguments. This function reads
        ``distribute``, ``model``, ``output``, ``half`` and ``checkpoint``;
        further fields may be consumed by ``override_parameters``.
    :raises RuntimeError: If ``params.optimizer`` is neither "Adam" nor
        "LazyAdam".
    """
    if args.distribute:
        distribute.enable_distributed_training()

    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters (rank 0 only, so
    # that several workers do not write the same files).
    if distribute.rank() == 0:
        export_params(params.output, "params.json", params)
        export_params(params.output, "%s.json" % args.model,
                      collect_params(params, model_cls.get_parameters()))

    # Build Graph
    with tf.Graph().as_default():
        if not params.record:
            # Build input queue
            features = dataset.get_training_input(params.input, params)
        else:
            features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)

        # Build model
        initializer = get_initializer(params)
        regularizer = tf.contrib.layers.l1_l2_regularizer(
            scale_l1=params.scale_l1, scale_l2=params.scale_l2)
        model = model_cls(params)
        # Create global step
        global_step = tf.train.get_or_create_global_step()
        # float16 when --half is given; otherwise None is passed through to
        # the model's training function (presumably its default precision).
        dtype = tf.float16 if args.half else None

        # Multi-GPU setting: parallel_model returns one loss per entry of
        # params.device_list; the shard losses are averaged.
        sharded_losses = parallel.parallel_model(
            model.get_training_func(initializer, regularizer, dtype), features,
            params.device_list)
        loss = tf.add_n(sharded_losses) / len(sharded_losses)
        # Add the regularization losses collected in the graph.
        loss = loss + tf.losses.get_regularization_loss()

        if distribute.rank() == 0:
            print_variables()

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)

        tf.summary.scalar("loss", loss)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        if params.optimizer == "Adam":
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=params.adam_beta1,
                                         beta2=params.adam_beta2,
                                         epsilon=params.adam_epsilon)
        elif params.optimizer == "LazyAdam":
            opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                                   beta1=params.adam_beta1,
                                                   beta2=params.adam_beta2,
                                                   epsilon=params.adam_epsilon)
        else:
            raise RuntimeError("Optimizer %s not supported" % params.optimizer)

        # NOTE(review): presumably accumulates gradients over
        # params.update_cycle steps before applying them - confirm in
        # optimizers.MultiStepOptimizer.
        opt = optimizers.MultiStepOptimizer(opt, params.update_cycle)

        if args.half:
            # Loss scaling for float16 training, factor params.loss_scale.
            opt = optimizers.LossScalingOptimizer(opt, params.loss_scale)

        # Optimization
        grads_and_vars = opt.compute_gradients(
            loss, colocate_gradients_with_ops=True)

        if params.clip_grad_norm:
            grads, var_list = list(zip(*grads_and_vars))
            grads, _ = tf.clip_by_global_norm(grads, params.clip_grad_norm)
            # Materialize as a list: zip() is a single-pass iterator in
            # Python 3, and the wrapper optimizers' apply_gradients may
            # iterate grads_and_vars more than once.
            grads_and_vars = list(zip(grads, var_list))

        train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)

        # Validation
        if params.validation and params.references[0]:
            files = [params.validation] + list(params.references)
            eval_inputs = dataset.sort_and_zip_files(files)
            eval_input_fn = dataset.get_evaluation_input
        else:
            eval_input_fn = None

        # Hooks
        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(loss),
            tf.train.LoggingTensorHook(
                {
                    "step": global_step,
                    "loss": loss,
                    "source": tf.shape(features["source"]),
                    "target": tf.shape(features["target"])
                },
                every_n_iter=1)
        ]

        broadcast_hook = distribute.get_broadcast_hook()

        if broadcast_hook:
            train_hooks.append(broadcast_hook)

        if distribute.rank() == 0:
            # Add hooks: an explicit saver (optionally restricted to the
            # trainable variables plus the global step) and, if validation
            # data is configured, an evaluation hook.
            save_vars = tf.trainable_variables() + [global_step]
            saver = tf.train.Saver(
                var_list=save_vars if params.only_save_trainable else None,
                max_to_keep=params.keep_checkpoint_max,
                sharded=False)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            train_hooks.append(
                hooks.MultiStepHook(tf.train.CheckpointSaverHook(
                    checkpoint_dir=params.output,
                    save_secs=params.save_checkpoint_secs or None,
                    save_steps=params.save_checkpoint_steps or None,
                    saver=saver),
                                    step=params.update_cycle))

            if eval_input_fn is not None:
                train_hooks.append(
                    hooks.MultiStepHook(hooks.EvaluationHook(
                        lambda f: inference.create_inference_graph([model], f,
                                                                   params),
                        lambda: eval_input_fn(eval_inputs, params),
                        lambda x: decode_target_ids(x, params),
                        params.output,
                        session_config(params),
                        device_list=params.device_list,
                        max_to_keep=params.keep_top_checkpoint_max,
                        eval_secs=params.eval_secs,
                        eval_steps=params.eval_steps),
                                        step=params.update_cycle))
            checkpoint_dir = params.output
        else:
            # Non-chief ranks get no checkpoint directory, so only rank 0
            # restores from / writes to params.output.
            checkpoint_dir = None

        restore_op = restore_variables(args.checkpoint)

        def restore_fn(step_context):
            step_context.session.run(restore_op)

        # Create session, do not use default CheckpointSaverHook
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=checkpoint_dir,
                hooks=train_hooks,
                save_checkpoint_secs=None,
                config=session_config(params)) as sess:
            # Restore pre-trained variables
            sess.run_step_fn(restore_fn)

            while not sess.should_stop():
                sess.run(train_op)
# Example #2
# 0
def get_training_input(filenames, params):
    """Build the batched, index-encoded training input pipeline.

    :param filenames: A two-element list ``[source_filename,
        target_filename]`` of line-aligned text files whose lines hold
        space-separated tokens.
    :param params: Hyper-parameters
    :returns: A dictionary with the keys "source", "target",
        "source_length" and "target_length", all converted to int32; the
        length tensors are squeezed along axis 1.
    """

    def split_tokens(src_line, tgt_line):
        # Split each line on spaces into a vector of token strings.
        return (tf.string_split([src_line]).values,
                tf.string_split([tgt_line]).values)

    def append_eos(src_tokens, tgt_tokens):
        # Terminate both sides with the end-of-sentence symbol.
        return (tf.concat([src_tokens, [tf.constant(params.eos)]], axis=0),
                tf.concat([tgt_tokens, [tf.constant(params.eos)]], axis=0))

    def to_feature_dict(src_tokens, tgt_tokens):
        return {
            "source": src_tokens,
            "target": tgt_tokens,
            "source_length": tf.shape(src_tokens),
            "target_length": tf.shape(tgt_tokens)
        }

    with tf.device("/cpu:0"):
        pairs = tf.data.Dataset.zip((tf.data.TextLineDataset(filenames[0]),
                                     tf.data.TextLineDataset(filenames[1])))

        if distribute.is_distributed_training_mode():
            pairs = pairs.shard(distribute.size(), distribute.rank())

        pairs = pairs.shuffle(params.buffer_size).repeat()

        n_threads = params.num_threads
        pairs = pairs.map(split_tokens, num_parallel_calls=n_threads)
        pairs = pairs.map(append_eos, num_parallel_calls=n_threads)
        examples = pairs.map(to_feature_dict, num_parallel_calls=n_threads)

        features = examples.make_one_shot_iterator().get_next()

        # Token -> id tables; out-of-vocabulary tokens map to the <unk> id.
        sides = ("source", "target")
        tables = {
            side: tf.contrib.lookup.index_table_from_tensor(
                tf.constant(params.vocabulary[side]),
                default_value=params.mapping[side][params.unk])
            for side in sides
        }
        for side in sides:
            features[side] = tables[side].lookup(features[side])

        features = batch_examples(features,
                                  params.batch_size,
                                  params.max_length,
                                  params.mantissa_bits,
                                  shard_multiplier=len(params.device_list),
                                  length_multiplier=params.length_multiplier,
                                  constant=params.constant_batch_size,
                                  num_threads=params.num_threads)

        for key in ("source", "target", "source_length", "target_length"):
            features[key] = tf.to_int32(features[key])
        for key in ("source_length", "target_length"):
            features[key] = tf.squeeze(features[key], 1)

        return features
def main(args):
    """Build the training graph and run the training loop.

    Parameters are layered with increasing priority: defaults, the model's
    parameters, parameters previously saved under ``args.output``, and
    finally command-line overrides.

    Each call of ``step_fn`` runs ``init_op`` and ``ops["zero_op"]``, then
    ``ops["collect_op"]`` ``update_cycle - 1`` times. After that it runs
    ``ops["train_op"]`` once through the hooks.

    :param args: Parsed command-line arguments. This function reads
        ``distribute``, ``model``, ``output``, ``fp16`` and ``checkpoint``;
        further fields may be consumed by ``override_parameters``.
    :raises ValueError: If there is no third training input, or its name
        does not contain "r2l".
    :raises RuntimeError: If ``params.optimizer`` is neither "Adam" nor
        "LazyAdam".
    """
    if args.distribute:
        distribute.enable_distributed_training()

    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    if not args.distribute or distribute.rank() == 0:
        export_params(params.output, "params.json", params)
        export_params(params.output, "%s.json" % args.model,
                      collect_params(params, model_cls.get_parameters()))

    # The third training input must be the "r2l" file. An explicit check is
    # used instead of `assert`, which is stripped when Python runs with -O.
    if len(params.input) < 3 or "r2l" not in params.input[2]:
        raise ValueError(
            "Expected 'r2l' in the name of the third training input, "
            "got: %s" % (params.input,))

    # Build Graph
    use_all_devices(params)
    with tf.Graph().as_default():
        if not params.record:
            # Build input queue
            features = dataset.abd_get_training_input(params.input, params)
        else:
            features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)

        update_cycle = params.update_cycle
        # NOTE(review): returns the (cached) features plus an init_op that is
        # re-run at the start of every step_fn below - confirm semantics in
        # cache.cache_features.
        features, init_op = cache.cache_features(features, update_cycle)

        # Build model
        initializer = get_initializer(params)
        regularizer = tf.contrib.layers.l1_l2_regularizer(
            scale_l1=params.scale_l1, scale_l2=params.scale_l2)
        model = model_cls(params)
        # Create global step
        global_step = tf.train.get_or_create_global_step()
        dtype = tf.float16 if args.fp16 else None

        if args.distribute:
            training_func = model.get_training_func(initializer, regularizer,
                                                    dtype)
            loss = training_func(features)
        else:
            # Multi-GPU setting: average the per-device losses.
            sharded_losses = parallel.parallel_model(
                model.get_training_func(initializer, regularizer, dtype),
                features, params.device_list)
            loss = tf.add_n(sharded_losses) / len(sharded_losses)

        # Add the regularization losses in both modes so that
        # scale_l1/scale_l2 also take effect under distributed training.
        loss = loss + tf.losses.get_regularization_loss()

        # Print parameters
        if not args.distribute or distribute.rank() == 0:
            print_variables()

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        if params.optimizer == "Adam":
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=params.adam_beta1,
                                         beta2=params.adam_beta2,
                                         epsilon=params.adam_epsilon)
        elif params.optimizer == "LazyAdam":
            opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                                   beta1=params.adam_beta1,
                                                   beta2=params.adam_beta2,
                                                   epsilon=params.adam_epsilon)
        else:
            raise RuntimeError("Optimizer %s not supported" % params.optimizer)

        # `ops` provides "zero_op", "collect_op" and "train_op" (used in
        # step_fn); gradients are all-reduced only in distributed mode.
        loss, ops = optimize.create_train_op(
            loss, opt, global_step,
            distribute.all_reduce if args.distribute else None, args.fp16,
            params)
        restore_op = restore_variables(args.checkpoint)

        # Validation
        if params.validation and params.references[0]:
            files = params.validation + list(params.references)
            eval_inputs = dataset.sort_and_zip_files(files)
            eval_input_fn = dataset.abd_get_evaluation_input
        else:
            eval_input_fn = None

        # Add hooks
        # Scales the first dimension of the logged shapes by update_cycle
        # (presumably to report the effective batch size of one cycle).
        multiplier = tf.convert_to_tensor([update_cycle, 1])

        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(loss),
            tf.train.LoggingTensorHook(
                {
                    "step": global_step,
                    "loss": loss,
                    "source": tf.shape(features["source"]) * multiplier,
                    "target": tf.shape(features["target"]) * multiplier
                },
                every_n_iter=1)
        ]

        if args.distribute:
            train_hooks.append(distribute.get_broadcast_hook())

        config = session_config(params)

        if not args.distribute or distribute.rank() == 0:
            # Add hooks: explicit saver, optionally restricted to the
            # trainable variables plus the global step.
            save_vars = tf.trainable_variables() + [global_step]
            saver = tf.train.Saver(
                var_list=save_vars if params.only_save_trainable else None,
                max_to_keep=params.keep_checkpoint_max,
                sharded=False)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            train_hooks.append(
                tf.train.CheckpointSaverHook(
                    checkpoint_dir=params.output,
                    save_secs=params.save_checkpoint_secs or None,
                    save_steps=params.save_checkpoint_steps or None,
                    saver=saver))

        if eval_input_fn is not None:
            if not args.distribute or distribute.rank() == 0:
                train_hooks.append(
                    hooks.EvaluationHook(
                        lambda f: inference.create_inference_graph([model], f,
                                                                   params),
                        lambda: eval_input_fn(eval_inputs, params),
                        lambda x: decode_target_ids(x, params),
                        params.output,
                        config,
                        params.keep_top_checkpoint_max,
                        eval_secs=params.eval_secs,
                        eval_steps=params.eval_steps))

        def restore_fn(step_context):
            step_context.session.run(restore_op)

        def step_fn(step_context):
            # Bypass hook calls: only the final train_op run below goes
            # through the hooks.
            step_context.session.run([init_op, ops["zero_op"]])
            for _ in range(update_cycle - 1):
                step_context.session.run(ops["collect_op"])

            return step_context.run_with_hooks(ops["train_op"])

        # Create session, do not use default CheckpointSaverHook
        if not args.distribute or distribute.rank() == 0:
            checkpoint_dir = params.output
        else:
            checkpoint_dir = None

        with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                               hooks=train_hooks,
                                               save_checkpoint_secs=None,
                                               config=config) as sess:
            # Restore pre-trained variables
            sess.run_step_fn(restore_fn)

            while not sess.should_stop():
                sess.run_step_fn(step_fn)
# Example #4
# 0
def get_training_input(filenames, params):
    """Get input for training stage.

    :param filenames: A list ``[source_file, target_file, extra_0, ...]`` of
        line-aligned text files. Any files after the first two are extra
        inputs exposed as ``mt_<i>`` and looked up with the *target*
        vocabulary.
    :param params: Hyper-parameters
    :returns: A dictionary of pair <Key, Tensor> with the keys "source",
        "target", "source_length", "target_length" and, for every extra
        file ``i``, "mt_<i>" and "mt_length_<i>", all converted to int32.
        The length tensors are squeezed along axis 1.
    :raises ValueError: If fewer than two files are given.
    """
    if len(filenames) < 2:
        raise ValueError(
            "get_training_input expects at least a source and a target "
            "file, got %d" % len(filenames))

    # Number of extra input files beyond source and target.
    num_mt = len(filenames) - 2

    with tf.device("/cpu:0"):
        datasets = [tf.data.TextLineDataset(filename)
                    for filename in filenames]

        dataset = tf.data.Dataset.zip(tuple(datasets))

        if distribute.is_distributed_training_mode():
            dataset = dataset.shard(distribute.size(), distribute.rank())

        dataset = dataset.shuffle(params.buffer_size)
        dataset = dataset.repeat()

        # Split string. The map functions return explicit tuples so that
        # tf.data treats every file as a separate component. Its handling
        # of a returned Python list is TF-version dependent.
        dataset = dataset.map(
            lambda *x: tuple(tf.string_split([y]).values for y in x),
            num_parallel_calls=params.num_threads)

        # Append <eos> symbol
        dataset = dataset.map(
            lambda *x: tuple(
                tf.concat([y, [tf.constant(params.eos)]], axis=0) for y in x),
            num_parallel_calls=params.num_threads)

        def convert_to_dict(src, tgt, *x):
            res = {}
            res["source"] = src
            res["source_length"] = tf.shape(src)
            res["target"] = tgt
            res["target_length"] = tf.shape(tgt)
            for i, v in enumerate(x):
                res["mt_%d" % i] = v
                res["mt_length_%d" % i] = tf.shape(v)
            return res

        # Convert to dictionary
        dataset = dataset.map(convert_to_dict,
                              num_parallel_calls=params.num_threads)

        # Create iterator
        iterator = dataset.make_one_shot_iterator()
        features = iterator.get_next()

        # Create lookup table
        src_table = tf.contrib.lookup.index_table_from_tensor(
            tf.constant(params.vocabulary["source"]),
            default_value=params.mapping["source"][params.unk])
        tgt_table = tf.contrib.lookup.index_table_from_tensor(
            tf.constant(params.vocabulary["target"]),
            default_value=params.mapping["target"][params.unk])

        # String to index lookup (extra inputs share the target vocabulary)
        features["source"] = src_table.lookup(features["source"])
        features["target"] = tgt_table.lookup(features["target"])
        for i in range(num_mt):
            features["mt_%d" % i] = tgt_table.lookup(features["mt_%d" % i])

        # Batching
        features = batch_examples(features,
                                  params.batch_size,
                                  params.max_length,
                                  params.mantissa_bits,
                                  shard_multiplier=len(params.device_list),
                                  length_multiplier=params.length_multiplier,
                                  constant=params.constant_batch_size,
                                  num_threads=params.num_threads)

        # Convert to int32
        features["source"] = tf.to_int32(features["source"])
        features["target"] = tf.to_int32(features["target"])
        features["source_length"] = tf.to_int32(features["source_length"])
        features["target_length"] = tf.to_int32(features["target_length"])
        features["source_length"] = tf.squeeze(features["source_length"], 1)
        features["target_length"] = tf.squeeze(features["target_length"], 1)
        for i in range(num_mt):
            features["mt_%d" % i] = tf.to_int32(features["mt_%d" % i])
            features["mt_length_%d" % i] = tf.to_int32(
                features["mt_length_%d" % i])
            features["mt_length_%d" % i] = tf.squeeze(
                features["mt_length_%d" % i], 1)

        return features