Code example #1
def main(args):
    if args.distribute:
        distribute.enable_distributed_training()

    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    if distribute.rank() == 0:
        export_params(params.output, "params.json", params)
        export_params(params.output, "%s.json" % args.model,
                      collect_params(params, model_cls.get_parameters()))

    # Build Graph
    with tf.Graph().as_default():
        if not params.record:
            # Build input queue
            features = dataset.get_training_input(params.input, params)
        else:
            features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)

        # Build model
        initializer = get_initializer(params)
        regularizer = tf.contrib.layers.l1_l2_regularizer(
            scale_l1=params.scale_l1, scale_l2=params.scale_l2)
        model = model_cls(params)
        # Create global step
        global_step = tf.train.get_or_create_global_step()
        dtype = tf.float16 if args.half else None

        # Multi-GPU setting
        sharded_losses = parallel.parallel_model(
            model.get_training_func(initializer, regularizer, dtype), features,
            params.device_list)
        loss = tf.add_n(sharded_losses) / len(sharded_losses)
        loss = loss + tf.losses.get_regularization_loss()

        if distribute.rank() == 0:
            print_variables()

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)

        tf.summary.scalar("loss", loss)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        if params.optimizer == "Adam":
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=params.adam_beta1,
                                         beta2=params.adam_beta2,
                                         epsilon=params.adam_epsilon)
        elif params.optimizer == "LazyAdam":
            opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                                   beta1=params.adam_beta1,
                                                   beta2=params.adam_beta2,
                                                   epsilon=params.adam_epsilon)
        else:
            raise RuntimeError("Optimizer %s not supported" % params.optimizer)

        opt = optimizers.MultiStepOptimizer(opt, params.update_cycle)

        if args.half:
            opt = optimizers.LossScalingOptimizer(opt, params.loss_scale)

        # Optimization
        grads_and_vars = opt.compute_gradients(
            loss, colocate_gradients_with_ops=True)

        if params.clip_grad_norm:
            grads, var_list = list(zip(*grads_and_vars))
            grads, _ = tf.clip_by_global_norm(grads, params.clip_grad_norm)
            grads_and_vars = zip(grads, var_list)

        train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)

        # Validation
        if params.validation and params.references[0]:
            files = [params.validation] + list(params.references)
            eval_inputs = dataset.sort_and_zip_files(files)
            eval_input_fn = dataset.get_evaluation_input
        else:
            eval_input_fn = None

        # Hooks
        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(loss),
            tf.train.LoggingTensorHook(
                {
                    "step": global_step,
                    "loss": loss,
                    "source": tf.shape(features["source"]),
                    "target": tf.shape(features["target"])
                },
                every_n_iter=1)
        ]

        broadcast_hook = distribute.get_broadcast_hook()

        if broadcast_hook:
            train_hooks.append(broadcast_hook)

        if distribute.rank() == 0:
            # Add hooks
            save_vars = tf.trainable_variables() + [global_step]
            saver = tf.train.Saver(
                var_list=save_vars if params.only_save_trainable else None,
                max_to_keep=params.keep_checkpoint_max,
                sharded=False)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            train_hooks.append(
                hooks.MultiStepHook(tf.train.CheckpointSaverHook(
                    checkpoint_dir=params.output,
                    save_secs=params.save_checkpoint_secs or None,
                    save_steps=params.save_checkpoint_steps or None,
                    saver=saver),
                                    step=params.update_cycle))

            if eval_input_fn is not None:
                train_hooks.append(
                    hooks.MultiStepHook(hooks.EvaluationHook(
                        lambda f: inference.create_inference_graph([model], f,
                                                                   params),
                        lambda: eval_input_fn(eval_inputs, params),
                        lambda x: decode_target_ids(x, params),
                        params.output,
                        session_config(params),
                        device_list=params.device_list,
                        max_to_keep=params.keep_top_checkpoint_max,
                        eval_secs=params.eval_secs,
                        eval_steps=params.eval_steps),
                                        step=params.update_cycle))
            checkpoint_dir = params.output
        else:
            checkpoint_dir = None

        restore_op = restore_variables(args.checkpoint)

        def restore_fn(step_context):
            step_context.session.run(restore_op)

        # Create session, do not use default CheckpointSaverHook
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=checkpoint_dir,
                hooks=train_hooks,
                save_checkpoint_secs=None,
                config=session_config(params)) as sess:
            # Restore pre-trained variables
            sess.run_step_fn(restore_fn)

            while not sess.should_stop():
                sess.run(train_op)
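
The "default -> saved -> command" comment above is the parameter-resolution order shared by every example on this page: model defaults are overwritten by values saved in params.json, which are in turn overwritten by command-line flags. A minimal sketch of that priority chain, assuming the helpers behave like dict updates applied in increasing priority (the function below is illustrative, not THUMT's actual implementation):

# Minimal sketch of the "default -> saved -> command" priority chain.
def resolve_params(defaults, saved, command_line):
    params = dict(defaults)          # lowest priority: model defaults
    params.update(saved)             # values read back from params.json
    params.update({k: v for k, v in command_line.items() if v is not None})
    return params                    # highest priority: command-line flags

# resolve_params({"batch_size": 4096}, {"batch_size": 2048}, {"batch_size": 1024})
# returns {"batch_size": 1024}
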
Code example #2
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    export_params(params.output, "params.json", params)
    export_params(params.output, "%s.json" % args.model,
                  collect_params(params, model_cls.get_parameters()))

    #import ipdb; ipdb.set_trace()
    # Build Graph
    with tf.Graph().as_default():
        if not params.record:
            # Build input queue
            features = dataset_c2f_4layers.get_training_input_and_c2f_label(
                params.input, params.c2f_input, params)
        else:
            features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)

        update_cycle = params.update_cycle
        features, init_op = cache.cache_features(features, update_cycle)

        # Build model
        initializer = get_initializer(params)
        regularizer = tf.contrib.layers.l1_l2_regularizer(
            scale_l1=params.scale_l1, scale_l2=params.scale_l2)
        model = model_cls(params)
        # Create global step
        global_step = tf.train.get_or_create_global_step()

        # Multi-GPU setting
        sharded_losses = parallel.parallel_model(
            model.get_training_func(initializer, regularizer), features,
            params.device_list)
        if len(sharded_losses) > 1:
            # Assumes each element of sharded_losses is one device's
            # (mle, l1, l2, l3, l4) tuple; transpose into per-loss lists.
            (losses_mle, losses_l1, losses_l2, losses_l3,
             losses_l4) = zip(*sharded_losses)
            loss_mle = tf.add_n(losses_mle) / len(losses_mle)
            loss_l1 = tf.add_n(losses_l1) / len(losses_l1)
            loss_l2 = tf.add_n(losses_l2) / len(losses_l2)
            loss_l3 = tf.add_n(losses_l3) / len(losses_l3)
            loss_l4 = tf.add_n(losses_l4) / len(losses_l4)
        else:
            loss_mle, loss_l1, loss_l2, loss_l3, loss_l4 = sharded_losses[0]
        loss = loss_mle + (loss_l1 + loss_l2 + loss_l3 + loss_l4
                           ) / 4.0 + tf.losses.get_regularization_loss()

        # Print parameters
        all_weights = {v.name: v for v in tf.trainable_variables()}
        total_size = 0

        for v_name in sorted(list(all_weights)):
            v = all_weights[v_name]
            tf.logging.info("%s\tshape    %s", v.name[:-2].ljust(80),
                            str(v.shape).ljust(20))
            v_size = np.prod(np.array(v.shape.as_list())).tolist()
            total_size += v_size
        tf.logging.info("Total trainable variables size: %d", total_size)

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        if params.optimizer == "Adam":
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=params.adam_beta1,
                                         beta2=params.adam_beta2,
                                         epsilon=params.adam_epsilon)
        elif params.optimizer == "LazyAdam":
            opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                                   beta1=params.adam_beta1,
                                                   beta2=params.adam_beta2,
                                                   epsilon=params.adam_epsilon)
        else:
            raise RuntimeError("Optimizer %s not supported" % params.optimizer)

        loss, ops = optimize.create_train_op(loss, opt, global_step, params)
        restore_op = restore_variables(args.checkpoint)

        #import ipdb; ipdb.set_trace()

        # Validation
        if params.validation and params.references[0]:
            files = [params.validation] + list(params.references)
            eval_inputs = dataset_c2f_4layers.sort_and_zip_files(files)
            eval_input_fn = dataset_c2f_4layers.get_evaluation_input
        else:
            eval_input_fn = None

        # Add hooks
        save_vars = tf.trainable_variables() + [global_step]
        saver = tf.train.Saver(
            var_list=save_vars if params.only_save_trainable else None,
            max_to_keep=params.keep_checkpoint_max,
            sharded=False)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)

        multiplier = tf.convert_to_tensor([update_cycle, 1])

        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(loss),
            tf.train.LoggingTensorHook(
                {
                    "step": global_step,
                    "loss_mle": loss_mle,
                    "loss_l1": loss_l1,
                    "loss_l2": loss_l2,
                    "loss_l3": loss_l3,
                    "loss_l4": loss_l4,
                    "source": tf.shape(features["source"]) * multiplier,
                    "target": tf.shape(features["target"]) * multiplier
                },
                every_n_iter=1),
            tf.train.CheckpointSaverHook(
                checkpoint_dir=params.output,
                save_secs=params.save_checkpoint_secs or None,
                save_steps=params.save_checkpoint_steps or None,
                saver=saver)
        ]

        config = session_config(params)

        if eval_input_fn is not None:
            train_hooks.append(
                hooks.EvaluationHook(
                    lambda f: inference.create_inference_graph([model], f,
                                                               params),
                    lambda: eval_input_fn(eval_inputs, params),
                    lambda x: decode_target_ids(x, params),
                    params.output,
                    config,
                    params.keep_top_checkpoint_max,
                    eval_secs=params.eval_secs,
                    eval_steps=params.eval_steps))

        #def restore_fn(step_context):
        #    step_context.session.run(restore_op)

        #def step_fn(step_context):
        #    # Bypass hook calls
        #    step_context.session.run([init_op, ops["zero_op"]])
        #    for i in range(update_cycle - 1):
        #        step_context.session.run(ops["collect_op"])

        #    return step_context.run_with_hooks(ops["train_op"])

        # Create session, do not use default CheckpointSaverHook
        with tf.train.MonitoredTrainingSession(checkpoint_dir=params.output,
                                               hooks=train_hooks,
                                               save_checkpoint_secs=None,
                                               config=config) as sess:
            # Restore pre-trained variables
            sess._tf_sess().run(restore_op)

            while not sess.should_stop():
                sess._tf_sess().run([init_op, ops["zero_op"]])
                for i in range(update_cycle - 1):
                    sess._tf_sess().run(ops["collect_op"])
                sess.run(ops["train_op"])
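
Example #2 drives gradient accumulation by hand: optimize.create_train_op is assumed to return a zero_op that clears the accumulators, a collect_op that adds one mini-batch's gradients, and a train_op that applies the accumulated update. A pure-Python sketch of that cycle (class and method names are illustrative, not THUMT's API):

# Illustrative sketch of accumulating gradients over update_cycle mini-batches.
class GradientAccumulator(object):
    def __init__(self):
        self.buffers = None
        self.count = 0

    def zero(self):                 # roughly what ops["zero_op"] does
        self.buffers, self.count = None, 0

    def collect(self, grads):       # roughly what ops["collect_op"] does
        if self.buffers is None:
            self.buffers = list(grads)
        else:
            self.buffers = [b + g for b, g in zip(self.buffers, grads)]
        self.count += 1

    def averaged(self):             # the gradients ops["train_op"] applies
        return [b / float(self.count) for b in self.buffers]
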
Code example #3
def main(args):
    if args.distribute:
        distribute.enable_distributed_training()

    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    if not args.distribute or distribute.rank() == 0:
        export_params(params.output, "params.json", params)
        export_params(params.output, "%s.json" % args.model,
                      collect_params(params, model_cls.get_parameters()))

    assert 'r2l' in params.input[2]
    # Build Graph
    use_all_devices(params)
    with tf.Graph().as_default():
        if not params.record:
            # Build input queue
            features = dataset.abd_get_training_input(params.input, params)
        else:
            features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)

        update_cycle = params.update_cycle
        features, init_op = cache.cache_features(features, update_cycle)

        # Build model
        initializer = get_initializer(params)
        regularizer = tf.contrib.layers.l1_l2_regularizer(
            scale_l1=params.scale_l1, scale_l2=params.scale_l2)
        model = model_cls(params)
        # Create global step
        global_step = tf.train.get_or_create_global_step()
        dtype = tf.float16 if args.fp16 else None

        if args.distribute:
            training_func = model.get_training_func(initializer, regularizer,
                                                    dtype)
            loss = training_func(features)
        else:
            # Multi-GPU setting
            sharded_losses = parallel.parallel_model(
                model.get_training_func(initializer, regularizer, dtype),
                features, params.device_list)
            loss = tf.add_n(sharded_losses) / len(sharded_losses)
            loss = loss + tf.losses.get_regularization_loss()

        # Print parameters
        if not args.distribute or distribute.rank() == 0:
            print_variables()

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        if params.optimizer == "Adam":
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=params.adam_beta1,
                                         beta2=params.adam_beta2,
                                         epsilon=params.adam_epsilon)
        elif params.optimizer == "LazyAdam":
            opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                                   beta1=params.adam_beta1,
                                                   beta2=params.adam_beta2,
                                                   epsilon=params.adam_epsilon)
        else:
            raise RuntimeError("Optimizer %s not supported" % params.optimizer)

        loss, ops = optimize.create_train_op(
            loss, opt, global_step,
            distribute.all_reduce if args.distribute else None, args.fp16,
            params)
        restore_op = restore_variables(args.checkpoint)

        # Validation
        if params.validation and params.references[0]:
            files = params.validation + list(params.references)
            eval_inputs = dataset.sort_and_zip_files(files)
            eval_input_fn = dataset.abd_get_evaluation_input
        else:
            eval_input_fn = None

        # Add hooks
        multiplier = tf.convert_to_tensor([update_cycle, 1])

        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(loss),
            tf.train.LoggingTensorHook(
                {
                    "step": global_step,
                    "loss": loss,
                    "source": tf.shape(features["source"]) * multiplier,
                    "target": tf.shape(features["target"]) * multiplier
                },
                every_n_iter=1)
        ]

        if args.distribute:
            train_hooks.append(distribute.get_broadcast_hook())

        config = session_config(params)

        if not args.distribute or distribute.rank() == 0:
            # Add hooks
            save_vars = tf.trainable_variables() + [global_step]
            saver = tf.train.Saver(
                var_list=save_vars if params.only_save_trainable else None,
                max_to_keep=params.keep_checkpoint_max,
                sharded=False)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            train_hooks.append(
                tf.train.CheckpointSaverHook(
                    checkpoint_dir=params.output,
                    save_secs=params.save_checkpoint_secs or None,
                    save_steps=params.save_checkpoint_steps or None,
                    saver=saver))

        if eval_input_fn is not None:
            if not args.distribute or distribute.rank() == 0:
                train_hooks.append(
                    hooks.EvaluationHook(
                        lambda f: inference.create_inference_graph([model], f,
                                                                   params),
                        lambda: eval_input_fn(eval_inputs, params),
                        lambda x: decode_target_ids(x, params),
                        params.output,
                        config,
                        params.keep_top_checkpoint_max,
                        eval_secs=params.eval_secs,
                        eval_steps=params.eval_steps))

        def restore_fn(step_context):
            step_context.session.run(restore_op)

        def step_fn(step_context):
            # Bypass hook calls
            step_context.session.run([init_op, ops["zero_op"]])
            for i in range(update_cycle - 1):
                step_context.session.run(ops["collect_op"])

            return step_context.run_with_hooks(ops["train_op"])

        # Create session, do not use default CheckpointSaverHook
        if not args.distribute or distribute.rank() == 0:
            checkpoint_dir = params.output
        else:
            checkpoint_dir = None

        with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                               hooks=train_hooks,
                                               save_checkpoint_secs=None,
                                               config=config) as sess:
            # Restore pre-trained variables
            sess.run_step_fn(restore_fn)

            while not sess.should_stop():
                sess.run_step_fn(step_fn)
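
Example #3 threads args.fp16 into optimize.create_train_op. The usual treatment of such a flag is static loss scaling, so half-precision gradients do not underflow; a minimal sketch under that assumption (the function is illustrative and not necessarily what this fork's create_train_op does):

# Illustrative static loss scaling: scale the loss up before differentiation,
# then scale the resulting gradients back down before they are applied.
def compute_scaled_gradients(optimizer, loss, loss_scale=128.0):
    grads_and_vars = optimizer.compute_gradients(loss * loss_scale)
    return [(None if g is None else g / loss_scale, v)
            for g, v in grads_and_vars]
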
Code example #4
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # model_cls = models.get_model(args.model)
    model_cls = transformer_cache_fixencoder.Transformer
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    export_params(params.output, "params.json", params)
    export_params(params.output, "%s.json" % args.model,
                  collect_params(params, model_cls.get_parameters()))

    # Build Graph
    with tf.Graph().as_default():
        if not params.record:
            # Build input queue
            features = dataset.get_training_input_src_context(
                params.input, params)
        else:
            features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)

        features, init_op = cache.cache_features(features, params.update_cycle)

        # Build model
        initializer = get_initializer(params)
        model = model_cls(params)

        # Multi-GPU setting
        sharded_losses = parallel.parallel_model(
            model.get_training_func(initializer), features, params.device_list)
        loss = tf.add_n(sharded_losses) / len(sharded_losses)

        # Create global step
        global_step = tf.train.get_or_create_global_step()

        # Print parameters
        all_weights = {v.name: v for v in tf.trainable_variables()}
        total_size = 0

        for v_name in sorted(list(all_weights)):
            v = all_weights[v_name]
            tf.logging.info("%s\tshape    %s", v.name[:-2].ljust(80),
                            str(v.shape).ljust(20))
            v_size = np.prod(np.array(v.shape.as_list())).tolist()
            total_size += v_size
        tf.logging.info("Total trainable variables size: %d", total_size)

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        if params.optimizer == "Adam":
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=params.adam_beta1,
                                         beta2=params.adam_beta2,
                                         epsilon=params.adam_epsilon)
        elif params.optimizer == "LazyAdam":
            opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                                   beta1=params.adam_beta1,
                                                   beta2=params.adam_beta2,
                                                   epsilon=params.adam_epsilon)
        else:
            raise RuntimeError("Optimizer %s not supported" % params.optimizer)

        loss, ops = optimize.create_train_op(loss, opt, global_step, params)
        restore_op = restore_variables(args.checkpoint)
        restore_trained_encoder_op = restore_encoder_variables(
            args.thumt_checkpoint)
        # Validation
        if params.validation and params.references[0]:
            files = [params.validation] + list(params.references)
            eval_inputs = dataset.sort_and_zip_files_catch(files)
            eval_input_fn = dataset.get_evaluation_input_catch
        else:
            eval_input_fn = None

        # Add hooks
        save_vars = tf.trainable_variables() + [global_step]
        saver = tf.train.Saver(
            var_list=save_vars if params.only_save_trainable else None,
            max_to_keep=params.keep_checkpoint_max,
            sharded=False)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)

        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(loss),
            tf.train.LoggingTensorHook({
                "step": global_step,
                "loss": loss,
            },
                                       every_n_iter=1),
            tf.train.CheckpointSaverHook(
                checkpoint_dir=params.output,
                save_secs=params.save_checkpoint_secs or None,
                save_steps=params.save_checkpoint_steps or None,
                saver=saver)
        ]

        config = session_config(params)

        if eval_input_fn is not None:
            train_hooks.append(
                hooks.EvaluationHook(
                    lambda f: inference.create_inference_graph(
                        [model.get_inference_func()], f, params),
                    lambda: eval_input_fn(eval_inputs, params),
                    lambda x: decode_target_ids(x, params),
                    params.output,
                    config,
                    params.keep_top_checkpoint_max,
                    eval_secs=params.eval_secs,
                    eval_steps=params.eval_steps))

        def restore_fn(step_context):
            step_context.session.run(restore_op)
            step_context.session.run(restore_trained_encoder_op)

        def step_fn(step_context):
            # Bypass hook calls
            step_context.session.run([init_op, ops["zero_op"]])
            for i in range(params.update_cycle):
                step_context.session.run(ops["collect_op"])
            step_context.session.run(ops["scale_op"])

            # ####################################
            # # print some unchanged variable
            # scale = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
            #                           "transformer/encoder/layer_0/self_attention/layer_norm/scale")
            # # scale = tf.get_variable("transformer/encoder/layer_0/self_attention/layer_norm/scale")
            # scale = step_context.session.run(scale[0])
            #
            # print(scale)
            #
            # ####################################
            # # print some changed variable
            #
            # scale = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
            #                           "transformer/context/head_to_scalar")
            # # scale = tf.get_variable("transformer/encoder/layer_0/self_attention/layer_norm/scale")
            # scale = step_context.session.run(scale[0])
            #
            # print(scale)

            return step_context.run_with_hooks(ops["train_op"])

        # Create session, do not use default CheckpointSaverHook
        with tf.train.MonitoredTrainingSession(checkpoint_dir=params.output,
                                               hooks=train_hooks,
                                               save_checkpoint_secs=None,
                                               config=config) as sess:
            # Restore pre-trained variables
            sess.run_step_fn(restore_fn)

            while not sess.should_stop():
                sess.run_step_fn(step_fn)
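
Example #4 restores a pre-trained encoder with restore_encoder_variables(args.thumt_checkpoint) and then keeps it fixed (transformer_cache_fixencoder). A hypothetical sketch of how such a restore op could be built by matching variable names against a checkpoint; this is not the project's implementation:

import tensorflow as tf

def restore_encoder_variables(checkpoint):
    # Hypothetical helper: assign checkpoint values to all trainable
    # variables whose name contains "encoder".
    if not checkpoint:
        return tf.no_op(name="noop_restore_encoder")
    reader = tf.train.load_checkpoint(checkpoint)
    assign_ops = []
    for var in tf.trainable_variables():
        name = var.name.split(":")[0]
        if "encoder" in name and reader.has_tensor(name):
            assign_ops.append(tf.assign(var, reader.get_tensor(name)))
    return tf.group(*assign_ops)
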
Code example #5
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)  # a model class
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    export_params(params.output, "params.json", params)
    export_params(params.output, "%s.json" % args.model,
                  collect_params(params, model_cls.get_parameters()))
    # print(params.vocabulary)

    # Build Graph
    with tf.Graph().as_default():
        if not params.record:
            # Build input queue
            parsing_features = dataset.get_training_input(params.parsing_input,
                                                          params,
                                                          problem='parsing')
            amr_features = dataset.get_training_input(params.amr_input,
                                                      params,
                                                      problem='amr')
        else:
            parsing_features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)
            amr_features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)

        update_cycle = params.update_cycle
        parsing_features, parsing_init_op = cache.cache_features(
            parsing_features, update_cycle)
        amr_features, amr_init_op = cache.cache_features(
            amr_features, update_cycle)

        # Build model
        initializer = get_initializer(params)
        regularizer = tf.contrib.layers.l1_l2_regularizer(
            scale_l1=params.scale_l1, scale_l2=params.scale_l2)
        model = model_cls(params)
        # Create global step
        global_step = tf.train.get_or_create_global_step()

        # Multi-GPU setting
        # parsing_sharded_losses, amr_sharded_losses = parallel.parallel_model(
        #         model.get_training_func(initializer, regularizer),
        #         [parsing_features, amr_features],
        #         params.device_list
        # )
        # parsing_loss, amr_loss = model.get_training_func(initializer, regularizer)([parsing_features, amr_features])
        # with tf.variable_scope("shared_decode_variable") as scope:
        #     parsing_loss = model.get_training_func(initializer, regularizer, problem='parsing')(parsing_features)
        #     scope.reuse_variables()
        #     amr_loss = model.get_training_func(initializer, regularizer, problem='amr')(amr_features)
        #with tf.variable_scope("encoder_shared", reuse=True):
        print(params.layer_postprocess)
        with tf.variable_scope("encoder_shared",
                               initializer=initializer,
                               regularizer=regularizer,
                               reuse=tf.AUTO_REUSE):
            parsing_encoder_output = model.get_encoder_out(
                parsing_features, "train", params)
            amr_encoder_output = model.get_encoder_out(amr_features, "train",
                                                       params)
        with tf.variable_scope("parsing_decoder",
                               initializer=initializer,
                               regularizer=regularizer):
            parsing_loss = model.get_decoder_out(parsing_features,
                                                 parsing_encoder_output,
                                                 "train",
                                                 params,
                                                 problem="parsing")
        with tf.variable_scope("amr_decoder",
                               initializer=initializer,
                               regularizer=regularizer):
            amr_loss = model.get_decoder_out(amr_features,
                                             amr_encoder_output,
                                             "train",
                                             params,
                                             problem="amr")

        parsing_loss = parsing_loss + tf.losses.get_regularization_loss()

        amr_loss = amr_loss + tf.losses.get_regularization_loss()

        # Print parameters
        all_weights = {v.name: v for v in tf.trainable_variables()}
        total_size = 0

        for v_name in sorted(list(all_weights)):
            v = all_weights[v_name]
            tf.logging.info("%s\tshape    %s", v.name[:-2].ljust(80),
                            str(v.shape).ljust(20))
            v_size = np.prod(np.array(v.shape.as_list())).tolist()
            total_size += v_size
        tf.logging.info("Total trainable variables size: %d", total_size)

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        if params.optimizer == "Adam":
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=params.adam_beta1,
                                         beta2=params.adam_beta2,
                                         epsilon=params.adam_epsilon)
        elif params.optimizer == "LazyAdam":
            opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                                   beta1=params.adam_beta1,
                                                   beta2=params.adam_beta2,
                                                   epsilon=params.adam_epsilon)
        else:
            raise RuntimeError("Optimizer %s not supported" % params.optimizer)

        parsing_loss, parsing_ops = optimize.create_train_op(parsing_loss,
                                                             opt,
                                                             global_step,
                                                             params,
                                                             problem="parsing")
        amr_loss, amr_ops = optimize.create_train_op(amr_loss,
                                                     opt,
                                                     global_step,
                                                     params,
                                                     problem="amr")

        restore_op = restore_variables(args.checkpoint)

        # Validation
        if params.validation and params.references[0]:
            files = [params.validation] + list(params.references)
            eval_inputs = dataset.sort_and_zip_files(files)
            eval_input_fn = dataset.get_evaluation_input
        else:
            eval_input_fn = None

        # Add hooks
        save_vars = tf.trainable_variables() + [global_step]
        saver = tf.train.Saver(
            var_list=save_vars if params.only_save_trainable else None,
            max_to_keep=params.keep_checkpoint_max,
            sharded=False)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)

        multiplier = tf.convert_to_tensor([update_cycle, 1])

        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(parsing_loss),
            tf.train.NanTensorHook(amr_loss),
            tf.train.LoggingTensorHook(
                {
                    "step": global_step,
                    "parsing_loss": parsing_loss,
                    "amr_loss": amr_loss,
                    "parsing_source":
                    tf.shape(parsing_features["source"]) * multiplier,
                    "parsing_target":
                    tf.shape(parsing_features["target"]) * multiplier,
                    "amr_source":
                    tf.shape(amr_features["source"]) * multiplier,
                    "amr_target": tf.shape(amr_features["target"]) * multiplier
                },
                every_n_iter=1),
            tf.train.CheckpointSaverHook(
                checkpoint_dir=params.output,
                save_secs=params.save_checkpoint_secs or None,
                save_steps=params.save_checkpoint_steps or None,
                saver=saver)
        ]

        config = session_config(params)

        if eval_input_fn is not None:
            train_hooks.append(
                hooks.EvaluationHook(
                    lambda f: inference.create_inference_graph([model], f,
                                                               params),
                    lambda: eval_input_fn(eval_inputs, params),
                    lambda x: decode_target_ids(x, params, flag='target'),
                    lambda x: decode_target_ids(x, params, flag='source'),
                    params.output,
                    config,
                    params.keep_top_checkpoint_max,
                    eval_secs=params.eval_secs,
                    eval_steps=params.eval_steps))

        def restore_fn(step_context):
            step_context.session.run(restore_op)

        def parsing_step_fn(step_context):
            # Bypass hook calls
            # Re-initialize the cache and zero the gradient buffers; when
            # update_cycle == 1 the collect loop below does nothing extra.
            step_context.session.run(
                [parsing_init_op, parsing_ops["zero_op"]])
            for i in range(update_cycle - 1):
                step_context.session.run(parsing_ops["collect_op"])

            return step_context.run_with_hooks(parsing_ops["train_op"])

        def amr_step_fn(step_context):
            # Bypass hook calls
            step_context.session.run([amr_init_op, amr_ops["zero_op"]])
            for i in range(update_cycle - 1):
                step_context.session.run(amr_ops["collect_op"])

            return step_context.run_with_hooks(amr_ops["train_op"])

        def step_fn(step_context):
            # Bypass hook calls
            return step_context.run_with_hooks(parsing_ops["train_op"])

        # Create session, do not use default CheckpointSaverHook
        with tf.train.MonitoredTrainingSession(checkpoint_dir=params.output,
                                               hooks=train_hooks,
                                               save_checkpoint_secs=None,
                                               config=config) as sess:
            # Restore pre-trained variables
            sess.run_step_fn(restore_fn)
            step = 0
            while not sess.should_stop():
                if step % 2 == 0:
                    sess.run_step_fn(parsing_step_fn)
                else:
                    sess.run_step_fn(amr_step_fn)
                step += 1
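
Example #5 trains two tasks (constituency parsing and AMR) with a shared encoder and task-specific decoders, alternating one task per step. The sharing comes from the variable scopes: reuse=tf.AUTO_REUSE lets both tasks build the encoder with the same weights, while each decoder lives in its own scope. A minimal sketch of that scoping pattern (layer shapes and names are illustrative):

import tensorflow as tf

def shared_encoder(inputs):
    # Both tasks call this; AUTO_REUSE creates the weights once, then reuses them.
    with tf.variable_scope("encoder_shared", reuse=tf.AUTO_REUSE):
        return tf.layers.dense(inputs, 512, activation=tf.nn.relu, name="proj")

def task_decoder(encoder_output, task):
    # Each task gets its own decoder scope ("parsing_decoder", "amr_decoder", ...).
    with tf.variable_scope("%s_decoder" % task):
        return tf.layers.dense(encoder_output, 512, name="out")
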
Code example #6
File: translator.py  Project: Quincy1994/SSAN
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_fns = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]
        # Read input file
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)
        # Create placeholders
        placeholders = []

        for i in range(len(params.device_list)):
            placeholders.append({
                "source":
                tf.placeholder(tf.int32, [None, None], "source_%d" % i),
                "source_length":
                tf.placeholder(tf.int32, [None], "source_length_%d" % i)
            })

        # A list of outputs
        predictions = parallel.data_parallelism(
            params.device_list,
            lambda f: inference.create_inference_graph(model_fns, f, params),
            placeholders)

        # Create assign ops
        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        results = []

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            # Restore variables
            sess.run(assign_op)
            sess.run(tf.tables_initializer())
            start = time.time()
            while True:
                try:
                    feats = sess.run(features)
                    op, feed_dict = shard_features(feats, placeholders,
                                                   predictions)
                    results.append(sess.run(predictions, feed_dict=feed_dict))
                    #message = "Finished batch %d" % len(results)
                    #tf.logging.log(tf.logging.INFO, message)
                except tf.errors.OutOfRangeError:
                    break
            elapsed = time.time() - start
            tf.logging.log(tf.logging.INFO, "total time: %d" % elapsed)

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []

        for result in results:
            for item in result[0]:
                outputs.append(item.tolist())
            for item in result[1]:
                scores.append(item.tolist())

        outputs = list(itertools.chain(*outputs))
        scores = list(itertools.chain(*scores))

        restored_inputs = []
        restored_outputs = []
        restored_scores = []

        for index in range(len(sorted_inputs)):
            restored_inputs.append(sorted_inputs[sorted_keys[index]])
            restored_outputs.append(outputs[sorted_keys[index]])
            restored_scores.append(scores[sorted_keys[index]])

        # Write to file
        with open(args.output, "w") as outfile:
            count = 0
            for outputs, scores in zip(restored_outputs, restored_scores):
                for output, score in zip(outputs, scores):
                    decoded = []
                    for idx in output:
                        if idx == params.mapping["target"][params.eos]:
                            break
                        decoded.append(vocab[idx])

                    decoded = " ".join(decoded)

                    if not args.verbose:
                        outfile.write("%s\n" % decoded)
                        break
                    else:
                        pattern = "%d ||| %s ||| %s ||| %f\n"
                        source = restored_inputs[count]
                        values = (count, source, decoded, score)
                        outfile.write(pattern % values)

                count += 1
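
Example #6 decodes a sorted copy of the input (so batches contain sentences of similar length) and then maps the outputs back to the original line order through sorted_keys. A small sketch of that sort-then-restore pattern, assuming sort_input_file maps each original line index to its position in the sorted list:

def sort_input_lines(lines):
    # Sort indices by token count (longest first), mirroring the batching trick above.
    order = sorted(range(len(lines)), key=lambda i: len(lines[i].split()),
                   reverse=True)
    sorted_lines = [lines[i] for i in order]
    sorted_keys = {original: position for position, original in enumerate(order)}
    return sorted_keys, sorted_lines

# If outputs[k] is the translation of sorted_lines[k], the original order is
# recovered with: [outputs[sorted_keys[i]] for i in range(len(lines))]
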
Code example #7
File: word_order_MT.py  Project: hfxunlp/WRD
def main(args):
    eval_steps = args.eval_steps
    tf.logging.set_verbosity(tf.logging.DEBUG)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_fns = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]
        # Read input file
        #features = dataset.get_inference_input(args.input, params)
        #features_eval = dataset.get_inference_input(args.eval, params)
        #features_test = dataset.get_inference_input(args.test, params)

        features_train = dataset.get_inference_input(args.input, params, False,
                                                     True)
        features_eval = dataset.get_inference_input(args.eval, params, True,
                                                    False)
        features_test = dataset.get_inference_input(args.test, params, True,
                                                    False)

        # Create placeholders
        placeholders = []

        for i in range(len(params.device_list)):
            placeholders.append({
                "source":
                tf.placeholder(tf.int32, [None, None], "source_%d" % i),
                "source_length":
                tf.placeholder(tf.int32, [None], "source_length_%d" % i),
                "target":
                tf.placeholder(tf.int32, [None, 2], "target_%d" % i)
            })

        # A list of outputs
        predictions = parallel.data_parallelism(
            params.device_list,
            lambda f: inference.create_inference_graph(model_fns, f, params),
            placeholders)

        # Create assign ops
        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        results = []

        tf_x = tf.placeholder(tf.float32, [None, None, 512])
        tf_y = tf.placeholder(tf.int32, [None, 2])
        tf_x_len = tf.placeholder(tf.int32, [None])

        src_mask = -1e9 * (1.0 - tf.sequence_mask(
            tf_x_len, maxlen=tf.shape(predictions[0])[1], dtype=tf.float32))
        with tf.variable_scope("my_metric"):
            #q,k,v = tf.split(linear(tf_x, 3*512, True, True, scope="logit_transform"), [512, 512,512],axis=-1)
            q, k, v = tf.split(nn.linear(predictions[0],
                                         3 * 512,
                                         True,
                                         True,
                                         scope="logit_transform"),
                               [512, 512, 512],
                               axis=-1)
            q = nn.linear(
                tf.nn.tanh(q), 1, True, True,
                scope="logit_transform2")[:, :, 0] + src_mask
            # label smoothing
            ce1 = nn.smoothed_softmax_cross_entropy_with_logits(
                logits=q,
                labels=tf_y[:, :1],
                #smoothing=params.label_smoothing,
                smoothing=False,
                normalize=True)
            w1 = tf.nn.softmax(q)[:, None, :]
            #k = nn.linear(tf.nn.tanh(tf.matmul(w1,v)+k),1,True,True,scope="logit_transform3")[:,:,0]+src_mask
            k = tf.matmul(k,
                          tf.matmul(w1, v) *
                          (512**-0.5), False, True)[:, :, 0] + src_mask
            # label smoothing
            ce2 = nn.smoothed_softmax_cross_entropy_with_logits(
                logits=k,
                labels=tf_y[:, 1:],
                #smoothing=params.label_smoothing,
                smoothing=False,
                normalize=True)
            w2 = tf.nn.softmax(k)[:, None, :]
            weights = tf.concat([w1, w2], axis=1)
        loss = tf.reduce_mean(ce1 + ce2)

        #tf_x = tf.placeholder(tf.float32, [None, 512])
        #tf_y = tf.placeholder(tf.int32, [None])

        #l1 = tf.layers.dense(tf.squeeze(predictions[0], axis=-2), 64, tf.nn.sigmoid)
        #output = tf.layers.dense(l1, int(args.softmax_size))

        #loss = tf.losses.sparse_softmax_cross_entropy(labels=tf_y, logits=output)
        o1 = tf.argmax(w1, axis=-1)
        o2 = tf.argmax(w2, axis=-1)
        a1, a1_update = tf.metrics.accuracy(labels=tf.squeeze(tf_y[:, 0]),
                                            predictions=tf.argmax(w1, axis=-1),
                                            name='a1')
        a2, a2_update = tf.metrics.accuracy(labels=tf.squeeze(tf_y[:, 1]),
                                            predictions=tf.argmax(w2, axis=-1),
                                            name='a2')
        accuracy, accuracy_update = tf.metrics.accuracy(
            labels=tf.squeeze(tf_y),
            predictions=tf.argmax(weights, axis=-1),
            name='a_all')

        running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                         scope="my_metric")
        #running_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="my_metric")
        running_vars_initializer = tf.variables_initializer(
            var_list=running_vars)

        #variables_to_train = tf.trainable_variables()
        #print (len(variables_to_train), (variables_to_train[0]), variables_to_train[1])
        #variables_to_train.remove(variables_to_train[0])
        #variables_to_train.remove(variables_to_train[0])
        #print (len(variables_to_train))
        variables_to_train = [
            v for v in tf.trainable_variables()
            if v.name.startswith("my_metric")
        ]

        optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
        train_op = optimizer.minimize(loss, var_list=variables_to_train)
        #train_op = optimizer.minimize(loss, var_list=running_vars)

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)
            # Restore variables
            sess.run(assign_op)
            sess.run(tf.tables_initializer())

            current_step = 0

            best_validate_acc = 0
            last_test_acc = 0

            train_x_set = []
            train_y_set = []
            valid_x_set = []
            valid_y_set = []
            test_x_set = []
            test_y_set = []
            train_x_len_set = []
            valid_x_len_set = []
            test_x_len_set = []

            while current_step < eval_steps:
                print('=======current step ' + str(current_step))
                batch_num = 0
                while True:
                    try:
                        feats = sess.run(features_train)
                        op, feed_dict = shard_features(feats, placeholders,
                                                       predictions)
                        #x = (np.squeeze(sess.run(predictions, feed_dict=feed_dict), axis = -2))
                        # dict views are not indexable in Python 3; these indices
                        # assume the placeholder order (source, source_length, target)
                        y = list(feed_dict.values())[2]
                        x_len = list(feed_dict.values())[1]

                        feed_dict.update({tf_y: y})
                        feed_dict.update({tf_x_len: x_len})

                        los, __, pred = sess.run([loss, train_op, weights],
                                                 feed_dict=feed_dict)
                        print("current_step", current_step, "batch_num",
                              batch_num, "loss", los)

                        batch_num += 1
                        if batch_num % 100 == 0:

                            # eval
                            b_total = 0
                            a_total = 0
                            a1_total = 0
                            a2_total = 0
                            validate_acc = 0
                            batch_num_eval = 0

                            while True:
                                try:
                                    feats_eval = sess.run(features_eval)
                                    op, feed_dict_eval = shard_features(
                                        feats_eval, placeholders, predictions)
                                    #x = (np.squeeze(sess.run(predictions, feed_dict=feed_dict), axis = -2))
                                    y = list(feed_dict_eval.values())[2]
                                    x_len = list(feed_dict_eval.values())[1]
                                    feed_dict_eval.update({tf_y: y})
                                    feed_dict_eval.update({tf_x_len: x_len})

                                    sess.run(running_vars_initializer)
                                    acc = 0
                                    #acc, pred = sess.run([accuracy, output], feed_dict = {tf_x : x, tf_y : y})
                                    sess.run([
                                        a1_update, a2_update, accuracy_update,
                                        weights
                                    ],
                                             feed_dict=feed_dict_eval)
                                    acc1, acc2, acc = sess.run(
                                        [a1, a2, accuracy])
                                    batch_size = len(y)
                                    #print(acc)
                                    a1_total += round(batch_size * acc1)
                                    a2_total += round(batch_size * acc2)
                                    a_total += round(batch_size * acc)
                                    b_total += batch_size
                                    batch_num_eval += 1

                                    if batch_num_eval == 20:
                                        break

                                except tf.errors.OutOfRangeError:
                                    print("eval out of range")
                                    break
                            if b_total:
                                validate_acc = a_total / b_total
                                print("eval acc : " + str(validate_acc) +
                                      "( " + str(a1_total / b_total) + ", " +
                                      str(a2_total / b_total) + " )")
                            print("last test acc : " + str(last_test_acc))

                            # Keep the best dev accuracy seen so far
                            if validate_acc > best_validate_acc:
                                best_validate_acc = validate_acc

                            # Run the test set and write predictions to args.output
                            b_total = 0
                            a1_total = 0
                            a2_total = 0
                            a_total = 0
                            batch_num_test = 0
                            with open(args.output, "w") as outfile:
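                                # args.output is truncated and rewritten on every test pass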
                                while True:
                                    try:
                                        feats_test = sess.run(features_test)
                                        op, feed_dict_test = shard_features(
                                            feats_test, placeholders,
                                            predictions)

                                        feed_values = list(feed_dict_test.values())
                                        y = feed_values[2]
                                        x_len = feed_values[1]
                                        feed_dict_test.update({tf_y: y})
                                        feed_dict_test.update(
                                            {tf_x_len: x_len})

                                        # Reset the streaming accuracy variables before this batch
                                        sess.run(running_vars_initializer)
                                        __, __, __, out1, out2 = sess.run(
                                            [
                                                a1_update, a2_update,
                                                accuracy_update, o1, o2
                                            ],
                                            feed_dict=feed_dict_test)
                                        acc1, acc2, acc = sess.run(
                                            [a1, a2, accuracy])

                                        batch_size = len(y)
                                        a_total += round(batch_size * acc)
                                        a1_total += round(batch_size * acc1)
                                        a2_total += round(batch_size * acc2)
                                        b_total += batch_size
                                        batch_num_test += 1
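                                        # Write the two predicted labels for each example on one line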
                                        for pred1, pred2 in zip(out1, out2):
                                            outfile.write("%s " % pred1[0])
                                            outfile.write("%s\n" % pred2[0])
                                        if batch_num_test == 20:
                                            break
                                    except tf.errors.OutOfRangeError:
                                        print("test out of range")
                                        break
                                if b_total:
                                    last_test_acc = a_total / b_total
                                    print("new test acc : %f (%f, %f)" %
                                          (last_test_acc, a1_total / b_total,
                                           a2_total / b_total))

                        # Hard stop after 25,000 training batches
                        if batch_num == 25000:
                            break
                    except tf.errors.OutOfRangeError:
                        print("train out of range")
                        break

                current_step += 1
                print("")
        print("Final test acc " + str(last_test_acc))

        return