Example no. 1
0
def main(args):
    """Score input sentences under a trained model checkpoint.

    Builds the evaluation graph for ``args.model``, restores the model's
    variables from ``args.checkpoint`` via explicit assign ops, and writes
    one score per input line to ``args.output``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.checkpoint, args.model, params)
    override_parameters(params, args)

    # Build Graph
    with tf.Graph().as_default():
        model = model_cls(params)
        inputs = read_files(args.input)
        features = get_features(inputs, params)
        score_fn = model.get_evaluation_func()
        scores = score_fn(features, params)

        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params)
        )

        # Load checkpoint values for variables belonging to this model only.
        tf.logging.info("Loading %s" % args.checkpoint)
        var_list = tf.train.list_variables(args.checkpoint)
        values = {}
        reader = tf.train.load_checkpoint(args.checkpoint)

        for (name, shape) in var_list:
            if not name.startswith(model_cls.get_name()):
                continue

            tensor = reader.get_tensor(name)
            values[name] = tensor

        ops = set_variables(tf.trainable_variables(), values,
                            model_cls.get_name())
        assign_op = tf.group(*ops)

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)

            # Use a context manager so the output file is closed even if a
            # sess.run call raises (the original leaked the handle on error).
            with tf.gfile.Open(args.output, "w") as fd:
                while not sess.should_stop():
                    results = sess.run(scores)
                    for value in results:
                        fd.write("%f\n" % value)
Example no. 2
0
def main(args):
    """Ensemble decoding entry point.

    Restores one or more trained models from their checkpoints, builds a
    combined inference graph, decodes ``args.input`` and writes the
    hypotheses (with source and score details when ``args.verbose``) to
    ``args.output``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    # Parameter priority (low -> high): default -> saved -> command line.
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints: collect each model's tensors into a dict keyed
        # by variable name.
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                # Only variables belonging to the i-th model's scope.
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                # Skip training bookkeeping variables.
                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_fns = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            # The "_%d" suffix keeps ensemble members' variable scopes apart.
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]
        # Read input file; sorted_keys maps positions back to the original
        # order (inputs presumably sorted by length for batching efficiency
        # -- confirm in dataset.sort_input_file).
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)
        predictions = search.create_inference_graph(model_fns, features,
                                                    params)

        assign_ops = []

        all_var_list = tf.trainable_variables()

        # Pair each graph variable with its checkpoint value.
        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params)
        )

        results = []

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)

            while not sess.should_stop():
                results.append(sess.run(predictions))
                message = "Finished batch %d" % len(results)
                tf.logging.log(tf.logging.INFO, message)

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []

        for result in results:
            outputs.append(result[0].tolist())
            scores.append(result[1].tolist())

        # Flatten per-batch lists into one entry per sentence.
        outputs = list(itertools.chain(*outputs))
        scores = list(itertools.chain(*scores))

        restored_inputs = []
        restored_outputs = []
        restored_scores = []

        # Undo the input sorting so results line up with the input file.
        for index in range(len(sorted_inputs)):
            restored_inputs.append(sorted_inputs[sorted_keys[index]])
            restored_outputs.append(outputs[sorted_keys[index]])
            restored_scores.append(scores[sorted_keys[index]])

        # Write to file
        with open(args.output, "w") as outfile:
            count = 0
            # NOTE: outputs/scores here shadow the flattened lists above;
            # each item holds all beams of a single sentence.
            for outputs, scores in zip(restored_outputs, restored_scores):
                for output, score in zip(outputs, scores):
                    decoded = []
                    for idx in output:
                        # Stop at the end-of-sentence id.
                        if idx == params.mapping["target"][params.eos]:
                            break
                        decoded.append(vocab[idx])

                    decoded = " ".join(decoded)

                    if not args.verbose:
                        # Non-verbose mode writes only the top beam.
                        outfile.write("%s\n" % decoded)
                        break
                    else:
                        pattern = "%d ||| %s ||| %s ||| %f\n"
                        source = restored_inputs[count]
                        values = (count, source, decoded, score)
                        outfile.write(pattern % values)

                count += 1
Example no. 3
0
def main(args):
    """Training entry point.

    Builds the training graph for ``args.model``, optionally accumulating
    gradients over ``params.update_cycle`` steps, and runs a monitored
    training session that checkpoints to ``params.output``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)
    params = default_parameters()
    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    export_params(params.output, "params.json", params)
    export_params(params.output, "%s.json" % args.model,
                  collect_params(params, model_cls.get_parameters()))

    # Build Graph
    with tf.Graph().as_default():
        if not params.record:
            # Build input queue
            if params.use_bert and params.bert_emb_path:
                # Append the pre-computed BERT embedding file to the inputs.
                features = dataset.get_training_input_with_bert(
                    params.input + [params.bert_emb_path], params)
            else:
                features = dataset.get_training_input(params.input, params)
        else:
            # NOTE(review): reads pre-packed TFRecord training data; the
            # original "???" marker suggests this path was unverified.
            features = record.get_input_features(  # ??? 
                os.path.join(params.record, "*train*"), "train", params)

        # Build model
        initializer = get_initializer(params)
        model = model_cls(params)

        # Multi-GPU setting: one loss shard per device, averaged below.
        sharded_losses = parallel.parallel_model(
            model.get_training_func(initializer), features, params.device_list)
        loss = tf.add_n(sharded_losses) / len(sharded_losses)

        # Create global step
        global_step = tf.train.get_or_create_global_step()

        # Print parameters
        all_weights = {v.name: v for v in tf.trainable_variables()}
        total_size = 0

        for v_name in sorted(list(all_weights)):
            v = all_weights[v_name]
            tf.logging.info("%s\tshape    %s", v.name[:-2].ljust(80),
                            str(v.shape).ljust(20))
            v_size = np.prod(np.array(
                v.shape.as_list())).tolist()  # multiply all dimension sizes
            total_size += v_size
        tf.logging.info("Total trainable variables size: %d", total_size)

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        opt = tf.train.AdamOptimizer(learning_rate,
                                     beta1=params.adam_beta1,
                                     beta2=params.adam_beta2,
                                     epsilon=params.adam_epsilon)

        if params.update_cycle == 1:
            # Plain single-step update; zero/collect ops become no-ops.
            train_op = tf.contrib.layers.optimize_loss(
                name="training",
                loss=loss,
                global_step=global_step,
                learning_rate=learning_rate,
                clip_gradients=params.clip_grad_norm or None,
                optimizer=opt,
                colocate_gradients_with_ops=True)
            zero_op = tf.no_op("zero_op")
            collect_op = tf.no_op("collect_op")
        else:
            # Gradient accumulation: gradients are summed into replicated
            # variables over update_cycle steps, then applied once.
            grads_and_vars = opt.compute_gradients(
                loss, colocate_gradients_with_ops=True)
            gradients = [item[0] for item in grads_and_vars]
            variables = [item[1] for item in grads_and_vars]
            variables = utils.replicate_variables(variables)
            zero_op = utils.zero_variables(variables)
            collect_op = utils.collect_gradients(gradients, variables)

            # Average the accumulated gradients.
            scale = 1.0 / params.update_cycle
            gradients, variables = utils.scale_gradients(grads_and_vars, scale)

            # Gradient clipping avoids gradient explosion.
            if isinstance(params.clip_grad_norm or None, float):
                gradients, _ = tf.clip_by_global_norm(gradients,
                                                      params.clip_grad_norm)

            # Update variables
            grads_and_vars = list(zip(gradients, variables))
            with tf.control_dependencies([collect_op]):
                train_op = opt.apply_gradients(grads_and_vars, global_step)

        # Validation
        '''
        if params.validation and params.references[0]:
            files = [params.validation] + list(params.references)
            eval_inputs = files
            eval_input_fn = dataset.get_evaluation_input
        else:
            print("Don't evaluate")
            eval_input_fn = None
        '''
        # Add hooks
        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(
                loss
            ),  # Monitors the loss tensor and stops training if loss is NaN
            tf.train.LoggingTensorHook(
                {
                    "step": global_step,
                    "loss": loss,
                    "chars": tf.shape(features["chars"]),
                    "source": tf.shape(features["source"]),
                    #"bert": tf.shape(features["bert"]),
                    "lr": learning_rate
                },
                every_n_iter=1),
            tf.train.CheckpointSaverHook(
                checkpoint_dir=params.output,
                save_secs=params.save_checkpoint_secs or None,
                save_steps=params.save_checkpoint_steps or None,
                saver=tf.train.Saver(max_to_keep=params.keep_checkpoint_max,
                                     sharded=False))
        ]

        config = session_config(params)
        '''
        if not eval_input_fn is  None:
            train_hooks.append(
                hooks.EvaluationHook(
                    lambda f: search.create_inference_graph(
                        model.get_evaluation_func(), f, params
                    ),
                    lambda: eval_input_fn(eval_inputs, params),
                    lambda x: decode_target_ids(x, params),
                    params.output,
                    config,
                    params.keep_top_checkpoint_max,
                    eval_secs=params.eval_secs,
                    eval_steps=params.eval_steps
                )
            )
        '''

        # save_checkpoint_secs=None defers checkpointing to the saver hook.
        with tf.train.MonitoredTrainingSession(checkpoint_dir=params.output,
                                               hooks=train_hooks,
                                               save_checkpoint_secs=None,
                                               config=config) as sess:
            while not sess.should_stop():
                # Zero accumulators, gather update_cycle - 1 partial
                # gradients, then train_op collects the last and applies.
                utils.session_run(sess, zero_op)
                for i in range(1, params.update_cycle):
                    utils.session_run(sess, collect_op)
                sess.run(train_op)
Example no. 4
0
def main(args):
    """Multi-device ensemble decoding entry point.

    Restores the models from their checkpoints, shards batches across
    ``params.device_list`` via feed placeholders, decodes ``args.input``
    and writes hypotheses (with details when ``args.verbose``) to
    ``args.output``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    # Parameter priority (low -> high): default -> saved -> command line.
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints: each model's tensors keyed by variable name.
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                # Only variables belonging to the i-th model's scope.
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                # Skip training bookkeeping variables.
                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_list = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            # The "_%d" suffix keeps ensemble members' variable scopes apart.
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_list.append(model)

        params = params_list[0]
        # Read input file; sorted_keys maps positions back to the original
        # order (inputs presumably sorted by length -- confirm in
        # dataset.sort_input_file).
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)
        # Create placeholders: one feed target per device.
        placeholders = []

        for i in range(len(params.device_list)):
            placeholders.append({
                "source":
                tf.placeholder(tf.int32, [None, None], "source_%d" % i),
                "source_length":
                tf.placeholder(tf.int32, [None], "source_length_%d" % i)
            })

        # A list of outputs
        if params.generate_samples:
            inference_fn = sampling.create_sampling_graph
        else:
            inference_fn = inference.create_inference_graph

        predictions = parallel.data_parallelism(
            params.device_list, lambda f: inference_fn(model_list, f, params),
            placeholders)

        # Create assign ops
        assign_ops = []
        feed_dict = {}

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            # Checkpoint values are supplied through feed_dict at assign time.
            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i, feed_dict)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        init_op = tf.tables_initializer()
        results = []

        # Freeze the graph; any later op creation would raise.
        tf.get_default_graph().finalize()

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            # Restore variables
            sess.run(assign_op, feed_dict=feed_dict)
            sess.run(init_op)

            # Drain the input queue until it signals exhaustion.
            while True:
                try:
                    feats = sess.run(features)
                    op, feed_dict = shard_features(feats, placeholders,
                                                   predictions)
                    results.append(sess.run(op, feed_dict=feed_dict))
                    message = "Finished batch %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                except tf.errors.OutOfRangeError:
                    break

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []

        for result in results:
            for item in result[0]:
                outputs.append(item.tolist())
            for item in result[1]:
                scores.append(item.tolist())

        # Flatten per-shard lists into one entry per sentence.
        outputs = list(itertools.chain(*outputs))
        scores = list(itertools.chain(*scores))

        restored_inputs = []
        restored_outputs = []
        restored_scores = []

        # Undo the input sorting.
        for index in range(len(sorted_inputs)):
            restored_inputs.append(sorted_inputs[sorted_keys[index]])
            restored_outputs.append(outputs[sorted_keys[index]])
            restored_scores.append(scores[sorted_keys[index]])

        # Write to file
        with open(args.output, "w") as outfile:
            count = 0
            # NOTE: outputs/scores shadow the flattened lists above; each
            # item holds all beams of one sentence.
            for outputs, scores in zip(restored_outputs, restored_scores):
                for output, score in zip(outputs, scores):
                    decoded = []
                    for idx in output:
                        # Stop at the end-of-sentence id.
                        if idx == params.mapping["target"][params.eos]:
                            break
                        decoded.append(vocab[idx])

                    decoded = " ".join(decoded)

                    if not args.verbose:
                        # All beams are written (the early break is disabled).
                        outfile.write("%s\n" % decoded)
                        #break
                    else:
                        pattern = "%d ||| %s ||| %s ||| %f\n"
                        source = restored_inputs[count]
                        values = (count, source, decoded, score)
                        outfile.write(pattern % values)

                count += 1
Example no. 5
0
def main(args):
    """Distributed (multi-GPU) decoding entry point for the PyTorch path.

    Each rank decodes its own shard of the dataset; outputs are padded to a
    common shape, gathered with collectives, and rank 0 writes them to
    ``args.output``.
    """
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_params() for _ in range(len(model_cls_list))]
    # Parameter priority (low -> high): default -> saved -> command line.
    params_list = [
        merge_params(params, model_cls.default_params())
        for params, model_cls in zip(params_list, model_cls_list)]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))]
    params_list = [
        override_params(params_list[i], args)
        for i in range(len(model_cls_list))]

    params = params_list[0]
    # One process per device; NCCL backend for GPU collectives.
    dist.init_process_group("nccl", init_method=args.url,
                            rank=args.local_rank,
                            world_size=len(params.device_list))
    torch.cuda.set_device(params.device_list[args.local_rank])
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

    if args.half:
        torch.set_default_dtype(torch.half)
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    # Create model
    with torch.no_grad():
        model_list = []

        if len(args.input) == 1:
            # Single input file: free decoding.
            mode = "infer"
            if params.from_torchtext:
                dataset = data.get_dataset_torchtext(args.input[0], mode, params)
            else:
                dataset = data.get_dataset(args.input[0], mode, params)
        else:
            # Teacher-forcing
            mode = "eval"
            if params.from_torchtext:
                dataset = data.get_dataset_torchtext(args.input, mode, params)
            else:
                dataset = data.get_dataset(args.input, mode, params)

        iterator = iter(dataset)
        idx = 0
        counter = 0
        # Fixed maximum sequence length of the gather buffers below.
        pad_max = 1024
        top_beams = params.top_beams
        decode_batch_size = params.decode_batch_size

        # count eval dataset: one pass to find size and longest source.
        total_len = 0
        max_length = 0
        for sample in iterator:
            total_len += 1
            length = sample['source'].shape[1]
            if length > max_length:
                max_length = length
        iterator = iter(dataset)

        # Cap the decoding length by the longest observed source sentence.
        for param in params_list:
            if hasattr(param, "max_length"):
                param.max_length = min(param.max_length, max_length)
            else:
                param.max_length = max_length

        for i in range(len(args.models)):
            model = model_cls_list[i](params_list[i]).cuda()

            if args.half:
                model = model.half()

            model.eval()
            model.load_state_dict(
                torch.load(utils.latest_checkpoint(args.checkpoints[i]),
                           map_location="cpu")["model"])

            model_list.append(model)

        # Buffers for synchronization: per-rank batch sizes and padded
        # output sequences gathered across all ranks.
        size = torch.zeros([dist.get_world_size()]).long()
        t_list = [torch.empty([decode_batch_size, top_beams, pad_max]).long()
                  for _ in range(dist.get_world_size())]

        # Only rank 0 writes output and shows progress.
        if dist.get_rank() == 0:
            fd = open(args.output, "wb")
            pbar = tqdm(total=total_len)
            pbar.set_description("Translating to {}".format(args.output))
        else:
            fd = None

        states = [None for _ in model_list]
        # last_features is only needed (and only defined) when a cached
        # transformer participates; see the update_cache calls below.
        if "cachedtransformer" in [model.name for model in model_list]:
            last_features = [None for _ in model_list]
        for model in model_list:
            if model.name == "cachedtransformer":
                model.encoder.cache.set_batch_size(params.decode_batch_size)
                model.decoder.cache.set_batch_size(params.decode_batch_size)

        while True:
            try:
                features = next(iterator)
                features = data.lookup(features, mode, params, from_torchtext=params.from_torchtext)

                if mode == "eval":
                    features = features[0]

                batch_size = features["source"].shape[0]
            except:
                # NOTE(review): bare except -- intended to catch
                # StopIteration when this rank's shard is exhausted, but it
                # also swallows any other error; consider narrowing.
                # The dummy batch keeps the collective ops below in sync
                # until every rank reports batch_size == 0.
                features = {
                    "source": torch.ones([1, 1]).long(),
                    "source_mask": torch.ones([1, 1]).float()
                }

                if mode == "eval":
                    features["target"] = torch.ones([1, 1]).long()
                    features["target_mask"] = torch.ones([1, 1]).float()

                batch_size = 0
            finally:
                for im, model in enumerate(model_list):
                    if model.name == "cachedtransformer":
                        features = update_cache(model, features, states[im], last_features[im], evaluate=True)
                        last_features[im] = features

            counter += 1

            # Decode
            if mode != "eval":
                seqs, _, states = utils.beam_search(model_list, features, params)
            else:
                seqs, _ = utils.argmax_decoding(model_list, features, params)

            # Padding: bring every rank's output to the fixed
            # [decode_batch_size, top_beams, pad_max] gather shape.
            pad_batch = decode_batch_size - seqs.shape[0]
            pad_beams = top_beams - seqs.shape[1]
            pad_length = pad_max - seqs.shape[2]
            seqs = torch.nn.functional.pad(
                seqs, (0, pad_length, 0, pad_beams, 0, pad_batch))

            # Synchronization
            size.zero_()
            size[dist.get_rank()].copy_(torch.tensor(batch_size))
            dist.all_reduce(size)
            dist.all_gather(t_list, seqs)

            # All ranks exhausted: stop.
            if size.sum() == 0:
                break

            # Non-writer ranks skip the output section.
            if dist.get_rank() != 0:
                continue

            for i in range(decode_batch_size):
                for j in range(dist.get_world_size()):
                    for k in range(top_beams):
                        # n = real batch size of rank j; rows >= n are padding.
                        n = size[j]
                        seq = convert_to_string(t_list[j][i][k], params)

                        if i >= n:
                            continue

                        if top_beams == 1:
                            fd.write(seq)
                            fd.write(b"\n")
                        else:
                            # Beam output format: "<idx>\t<beam>\t<tokens>".
                            fd.write(str(idx).encode("utf-8"))
                            fd.write(b"\t")
                            fd.write(str(k).encode("utf-8"))
                            fd.write(b"\t")
                            fd.write(seq)
                            fd.write(b"\n")

                    idx = idx + 1

            if dist.get_rank() == 0:
                pbar.update(1)

        if dist.get_rank() == 0:
            pbar.close()
            fd.close()
Example no. 6
0
def main(args):
    """Compute and dump relevance (LRP) results for translation pairs.

    Restores each model listed in ``args.models`` from its checkpoint,
    builds the relevance graph of the first model over the sentence pairs
    read from ``args.input``/``args.output``, and writes one report file
    per sentence pair into the ``args.relevances`` directory.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs (lrp=True selects model variants exposing relevance).
    model_cls_list = [models.get_model(model, lrp=True)
                      for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    # Parameter priority (low -> high): default -> saved -> command line.
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoint tensors for each model, keyed by variable name.
        for i, checkpoint in enumerate(args.checkpoints):
            print("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                # Keep only variables belonging to the i-th model's scope.
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                # Skip training bookkeeping variables.
                if name.find("losses_avg") >= 0:
                    continue

                values[name] = reader.get_tensor(name)

            model_var_lists.append(values)

        # Build models; the "_%d" suffix keeps variable scopes distinct.
        model_fns = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fns.append(model.get_relevance_func())

        params = params_list[0]
        # Read the source and target sentence files.
        with tf.gfile.Open(args.input) as fd:
            inputs = [line.strip() for line in fd]
        with tf.gfile.Open(args.output) as fd:
            outputs = [line.strip() for line in fd]
        # Build input queue; only the first model's relevance graph is run.
        features = dataset.get_relevance_input(inputs, outputs, params)
        relevances = model_fns[0](features, params)

        assign_ops = []

        all_var_list = tf.trainable_variables()

        # Pair each graph variable with its checkpoint value.
        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        # Single-threaded execution for relevance computation.
        params.add_hparam('intra_op_parallelism_threads', 1)
        params.add_hparam('inter_op_parallelism_threads', 1)
        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params)
        )

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)
            if not os.path.exists(args.relevances):
                os.makedirs(args.relevances)
            count = 0
            while not sess.should_stop():
                # loss is fetched but only sequences/relevances are used.
                src_seq, trg_seq, rlv_info, loss = sess.run(relevances)
                message = "Finished batch %d" % count
                for i in range(src_seq.shape[0]):
                    count += 1
                    src = to_text(params.vocabulary["source"],
                                  params.mapping["source"], src_seq[i], params)
                    trg = to_text(params.vocabulary["target"],
                                  params.mapping["target"], trg_seq[i], params)
                    # Context manager closes each report file (the original
                    # leaked one file handle per sentence).
                    with open(args.relevances + '/' + str(count), 'w') as output:
                        output.write('src: ' + src + '\n')
                        output.write('trg: ' + trg + '\n')
                        output.write('result: ' +
                                     str(rlv_info["result"][i]) + '\n')
                tf.logging.log(tf.logging.INFO, message)
Example no. 7
0
def main(args):
    """Distributed decoding entry point (PyTorch).

    Loads one or more trained models, decodes the input with beam search
    (or argmax decoding when two inputs are given, i.e. teacher-forcing
    evaluation), gathers results across ranks, and writes the final
    outputs on rank 0 in the original input order.
    """
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_params() for _ in range(len(model_cls_list))]
    params_list = [
        merge_params(params, model_cls.default_params())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_params(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    params = params_list[0]

    if args.cpu:
        # CPU mode still initializes the distributed API (world size 1) so
        # the decoding loop below can use one uniform code path.
        dist.init_process_group("gloo",
                                init_method=args.url,
                                rank=args.local_rank,
                                world_size=1)
        torch.set_default_tensor_type(torch.FloatTensor)
    else:
        dist.init_process_group("nccl",
                                init_method=args.url,
                                rank=args.local_rank,
                                world_size=len(params.device_list))
        torch.cuda.set_device(params.device_list[args.local_rank])
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    if args.half:
        torch.set_default_dtype(torch.half)
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    # Create model(s); no gradients are needed for decoding.
    with torch.no_grad():
        model_list = []

        for i in range(len(args.models)):
            if args.cpu:
                model = model_cls_list[i](params_list[i])
            else:
                model = model_cls_list[i](params_list[i]).cuda()

            if args.half:
                model = model.half()

            model.eval()
            model.load_state_dict(
                torch.load(utils.latest_checkpoint(args.checkpoints[i]),
                           map_location="cpu")["model"])

            model_list.append(model)

        if len(args.input) == 1:
            mode = "infer"
            sorted_key, dataset = data.get_dataset(args.input[0], mode, params)
        else:
            # Teacher-forcing
            mode = "eval"
            dataset = data.get_dataset(args.input, mode, params)
            sorted_key = None

        iterator = iter(dataset)
        counter = 0
        pad_max = 1024
        top_beams = params.top_beams
        decode_batch_size = params.decode_batch_size

        # Buffers for synchronization: per-rank batch sizes and padded
        # output sequences of fixed shape so all_gather works.
        size = torch.zeros([dist.get_world_size()]).long()
        t_list = [
            torch.empty([decode_batch_size, top_beams, pad_max]).long()
            for _ in range(dist.get_world_size())
        ]

        all_outputs = []

        while True:
            try:
                features = next(iterator)
                features = data.lookup(features, mode, params, to_cpu=args.cpu)

                if mode == "eval":
                    features = features[0]

                batch_size = features["source"].shape[0]
            except Exception:
                # End of data (StopIteration) on this rank: keep stepping
                # with a dummy batch (batch_size 0) so the collective ops
                # below stay in sync with the other ranks.
                features = {
                    "source": torch.ones([1, 1]).long(),
                    "source_mask": torch.ones([1, 1]).float()
                }

                if mode == "eval":
                    features["target"] = torch.ones([1, 1]).long()
                    features["target_mask"] = torch.ones([1, 1]).float()

                batch_size = 0

            t = time.time()
            counter += 1

            # Decode
            if mode != "eval":
                seqs, _ = utils.beam_search(model_list, features, params)
            else:
                seqs, _ = utils.argmax_decoding(model_list, features, params)

            # Pad every rank's output to the same fixed shape.
            pad_batch = decode_batch_size - seqs.shape[0]
            pad_beams = top_beams - seqs.shape[1]
            pad_length = pad_max - seqs.shape[2]
            seqs = torch.nn.functional.pad(
                seqs, (0, pad_length, 0, pad_beams, 0, pad_batch))

            # Synchronization
            size.zero_()
            size[dist.get_rank()].copy_(torch.tensor(batch_size))

            if args.cpu:
                t_list[dist.get_rank()].copy_(seqs)
            else:
                dist.all_reduce(size)
                dist.all_gather(t_list, seqs)

            # All ranks exhausted their data.
            if size.sum() == 0:
                break

            # Only rank 0 collects and writes outputs.
            if dist.get_rank() != 0:
                continue

            for i in range(decode_batch_size):
                for j in range(dist.get_world_size()):
                    # Rows at or beyond rank j's true batch size are padding;
                    # skip before doing any string conversion work.
                    if i >= size[j]:
                        continue

                    beam_seqs = []

                    for k in range(top_beams):
                        beam_seqs.append(
                            convert_to_string(t_list[j][i][k], params))

                    all_outputs.append(beam_seqs)

            t = time.time() - t
            print("Finished batch: %d (%.3f sec)" % (counter, t))

        if dist.get_rank() == 0:
            # Undo the length-based input sorting (infer mode only).
            restored_outputs = []
            if sorted_key is not None:
                for idx in range(len(all_outputs)):
                    restored_outputs.append(all_outputs[sorted_key[idx]])
            else:
                restored_outputs = all_outputs

            with open(args.output, "wb") as fd:
                if top_beams == 1:
                    for seqs in restored_outputs:
                        fd.write(seqs[0] + b"\n")
                else:
                    # One line per beam: "<sentence_idx>\t<beam_idx>\t<text>"
                    for idx, seqs in enumerate(restored_outputs):
                        for k, seq in enumerate(seqs):
                            fd.write(b"%d\t%d\t" % (idx, k))
                            fd.write(seq + b"\n")
def main(args):
    """Ensemble inference (TF1): restore several checkpoints and decode.

    One sub-graph is built per checkpoint under the scope "<name>_<i>";
    the stored values are assigned into the live graph, the shared
    inference graph is run, and decoded target text is written to
    ``args.output``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints: keep only tensors belonging to each model,
        # skipping the "losses_avg" bookkeeping variables.
        for i, checkpoint in enumerate(args.checkpoints):
            print("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models: one uniquely-scoped instance per checkpoint.
        model_fns = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]

        # Optionally feed precomputed BERT embeddings next to the input.
        if params.use_bert and params.bert_emb_path:
            features = ds.get_inference_input_with_bert(
                params.input + [params.bert_emb_path], params)
        else:
            features = ds.get_inference_input(params.input, params)

        predictions = search.create_inference_graph(model_fns, features,
                                                    params)

        # Map the loaded checkpoint values onto the scoped graph variables.
        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params))

        results = []

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)

            while not sess.should_stop():
                results.append(sess.run(predictions))
                message = "Finished batch %d" % len(results)
                tf.logging.log(tf.logging.INFO, message)
                # NOTE(review): stops after 3 batches -- looks like a debug
                # leftover; confirm before running on full-size inputs.
                if len(results) > 2:
                    break
        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []

        for result in results:
            outputs.append(result.tolist())

        outputs = list(itertools.chain(*outputs))

        # Write to file
        with open(args.output, "w") as outfile:
            for output in outputs:
                decoded = []
                for idx in output:
                    # NOTE(review): EOS truncation is disabled here, so the
                    # decoded line may include tokens past the end symbol.
                    decoded.append(vocab[idx])

                decoded = " ".join(decoded)
                outfile.write("%s\n" % decoded)
Esempio n. 9
0
def main(args):
    """Single-checkpoint inference variant (TF1).

    Unlike the ensemble version, ``args.models`` / ``args.checkpoints``
    are scalars here, hence the single-element list wrapping below.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(args.models)]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints, args.models,
                      params_list[0])  # import params produced by training
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate([args.checkpoints]):
            print("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)  # all variables stored in the checkpoint
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()
                                       ):  # keep only this model's variables
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)  # read the tensor value
                values[name] = tensor

            model_var_lists.append(values)  # model variables (minus "losses_avg") and their values

        # Build models
        model_fns = []

        for i in range(len([args.checkpoints])):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()  # the model's inference function
            model_fns.append(model_fn)

        params = params_list[0]

        # Optionally feed precomputed BERT embeddings next to the input.
        if params.use_bert and params.bert_emb_path:
            features = dataset.get_inference_input_with_bert(
                params.input + [params.bert_emb_path], params)
        else:
            features = dataset.get_inference_input([params.input], params)

        predictions = search.create_inference_graph(model_fns, features,
                                                    params)

        # Map the loaded checkpoint values onto the scoped graph variables.
        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len([args.checkpoints])):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params))

        results = []

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)

            while not sess.should_stop():
                results.extend(sess.run(predictions))
                message = "Finished batch %d" % len(results)
                tf.logging.log(tf.logging.INFO, message)
            # Collect tokens from the input file, skipping bare "O" lines.
            # NOTE(review): `tar` is built but never used below -- the
            # example appears truncated here.
            tar = []
            with open(params.input, "r") as inputs_f:
                for line in inputs_f:
                    if line.strip() == "O":
                        continue
                    else:
                        tar.extend(line.split(" ")[:-1])
def main(args):
    """Multi-device inference (TF1) with optional sampling.

    Shards each input batch across ``params.device_list`` via placeholder
    feeds, decodes (beam search or sampling), restores the original input
    order, and writes decoded text (optionally with scores) to
    ``args.output``. For the t5-variant models it also collects
    relative-position bias tensors for inspection/logging.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints, keeping only each model's own variables and
        # skipping the "losses_avg" bookkeeping variables.
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models, one uniquely-scoped instance per checkpoint.
        model_list = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_list.append(model)

        params = params_list[0]
        # Read input file (sorted by length; keys allow restoring order)
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)
        # Create placeholders: one feed per device
        placeholders = []

        for i in range(len(params.device_list)):
            placeholders.append({
                "source":
                tf.placeholder(tf.int32, [None, None], "source_%d" % i),
                "source_length":
                tf.placeholder(tf.int32, [None], "source_length_%d" % i)
            })

        # A list of outputs; sampling or regular beam-search inference
        if params.generate_samples:
            inference_fn = sampling.create_sampling_graph
        else:
            inference_fn = inference.create_inference_graph

        predictions = parallel.data_parallelism(
            params.device_list, lambda f: inference_fn(model_list, f, params),
            placeholders)

        # Create assign ops mapping checkpoint values onto graph variables
        assign_ops = []
        feed_dict = {}

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i, feed_dict)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        init_op = tf.tables_initializer()
        results = []

        tf.get_default_graph().finalize()

        tf.logging.info(args.models[0])
        # Collect relative-position bias tensors for the t5 model variants
        # (logged here; the dump code below is currently disabled).
        if args.models[0] == 'transformer_raw_t5':
            t5_list = []
            for var in tf.trainable_variables():
                if 'en_t5_bias_mat' in var.name or 'de_self_relative_attention_bias' in var.name:
                    t5_list.append(var)
                    tf.logging.info(var)

            for op in tf.get_default_graph().get_operations():
                if 'encoder_t5_bias' in op.name or 'decoder_t5_bias' in op.name:
                    if 'random' in op.name or 'read' in op.name or 'Assign' in op.name or 'placeholder' in op.name:
                        continue
                    t5_list.append(op.values()[0])
                    tf.logging.info(op.values()[0].name)
        elif args.models[0] == 'transformer_raw_soft_t5':
            soft_t5_bias_list = []
            for op in tf.get_default_graph().get_operations():
                if 'soft_t5_bias' in op.name or 'soft_t5_encoder' in op.name or 'soft_t5_decoder' in op.name:
                    if 'random' in op.name or 'read' in op.name or 'Assign' in op.name or 'placeholder' in op.name or 'decoder' in op.name:
                        continue
                    soft_t5_bias_list.append(op.values()[0])
                    tf.logging.info(op.values()[0].name)

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            # Restore variables
            sess.run(assign_op, feed_dict=feed_dict)
            sess.run(init_op)

            # Decode until the input queue is exhausted.
            while True:
                try:
                    feats = sess.run(features)
                    op, feed_dict = shard_features(feats, placeholders,
                                                   predictions)
                    results.append(sess.run(op, feed_dict=feed_dict))
                    '''
                    if args.models[0] == 'transformer_raw_t5':
                        var_en_bucket=tf.get_default_graph().get_tensor_by_name(t5_list[0].name)
                        var_de_bucket=tf.get_default_graph().get_tensor_by_name(t5_list[1].name)
                        
                        var_en_bias=tf.get_default_graph().get_tensor_by_name(t5_list[2].name)
                        
                        en_bucket,de_bucket,en_t5_bias = sess.run([var_en_bucket,
                                                                   var_de_bucket,
                                                                   var_en_bias],
                                              feed_dict=feed_dict)
                        
                        ret_param = {'en_bucket':en_bucket,'de_bucket':en_bucket,
                                     'en_t5_bias':en_t5_bias}
                        pickle.dump(ret_param,open(args.checkpoints[0]+'/'+'t5_bias.pkl','wb'))
                        tf.logging.info('store the t5 bias')
                    elif args.models[0] == 'transformer_raw_soft_t5':
                        var_en_alpha=tf.get_default_graph().get_tensor_by_name(soft_t5_bias_list[0].name)
                        var_en_beta=tf.get_default_graph().get_tensor_by_name(soft_t5_bias_list[1].name)
                        var_en_t5_bias=tf.get_default_graph().get_tensor_by_name(soft_t5_bias_list[2].name)
                        en_alpha,en_beta,en_t5_bias = sess.run([var_en_alpha,var_en_beta,var_en_t5_bias], feed_dict=feed_dict)
                    
                        ret_param = {'en_t5_bias':en_t5_bias,'en_alpha':en_alpha,
                              'en_beta':en_beta}
                        pickle.dump(ret_param,open(args.checkpoints[0]+'/'+'soft_t5_bias.pkl','wb'))
                        tf.logging.info('store the soft-t5 bias')
                        '''
                    message = "Finished batch %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                except tf.errors.OutOfRangeError:
                    break

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []

        # Each result is a list of per-device shards: (ids, scores).
        for result in results:
            for shard in result:
                for item in shard[0]:
                    outputs.append(item.tolist())
                for item in shard[1]:
                    scores.append(item.tolist())

        restored_inputs = []
        restored_outputs = []
        restored_scores = []

        # Undo the length-based sorting of the input file.
        for index in range(len(sorted_inputs)):
            restored_inputs.append(sorted_inputs[sorted_keys[index]])
            restored_outputs.append(outputs[sorted_keys[index]])
            restored_scores.append(scores[sorted_keys[index]])

        # Write to file
        if sys.version_info.major == 2:
            outfile = open(args.output, "w")
        elif sys.version_info.major == 3:
            outfile = open(args.output, "w", encoding="utf-8")
        else:
            raise ValueError("Unkown python running environment!")

        count = 0
        for outputs, scores in zip(restored_outputs, restored_scores):
            for output, score in zip(outputs, scores):
                decoded = []
                for idx in output:
                    # Truncate at the end-of-sentence symbol.
                    if idx == params.mapping["target"][params.eos]:
                        break
                    decoded.append(vocab[idx])

                decoded = " ".join(decoded)

                if not args.verbose:
                    outfile.write("%s\n" % decoded)
                else:
                    # verbose: "idx ||| source ||| decoded ||| score"
                    pattern = "%d ||| %s ||| %s ||| %f\n"
                    source = restored_inputs[count]
                    values = (count, source, decoded, score)
                    outfile.write(pattern % values)

            count += 1
        outfile.close()
Esempio n. 11
0
            else:
                break


    if isinstance(predictions, (list, tuple)):
        predictions = [item[:n] for item in predictions]

    return predictions, feed_dict


# main start here
# NOTE(review): `if True:` acts as the script entry point; consider
# `if __name__ == "__main__":` so importing this module has no side effects.
if True:
    args = parse_args()
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs (priority: default -> saved checkpoint -> command line)
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
Esempio n. 12
0
def main(args):
    """Training entry point (TF1).

    Builds the input pipeline, multi-GPU training graph, optimizer and
    hooks (checkpointing, logging, optional validation), then runs the
    training loop with gradient accumulation over ``update_cycle`` steps.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    export_params(params.output, "params.json", params)
    export_params(params.output, "%s.json" % args.model,
                  collect_params(params, model_cls.get_parameters()))

    # Build Graph
    with tf.Graph().as_default():
        if not params.record:
            # Build input queue
            features = dataset.get_training_input(params.input, params)
        else:
            # TFRecord input path
            features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)

        # Cache features so one optimizer step can accumulate
        # params.update_cycle batches of gradients.
        features, init_op = cache.cache_features(features, params.update_cycle)

        # Build model
        initializer = get_initializer(params)
        model = model_cls(params)

        # Multi-GPU setting: shard the loss across devices and average.
        sharded_losses = parallel.parallel_model(
            model.get_training_func(initializer), features, params.device_list)
        loss = tf.add_n(sharded_losses) / len(sharded_losses)

        # Create global step
        global_step = tf.train.get_or_create_global_step()

        # Print parameters
        all_weights = {v.name: v for v in tf.trainable_variables()}
        total_size = 0

        for v_name in sorted(list(all_weights)):
            v = all_weights[v_name]
            tf.logging.info("%s\tshape    %s", v.name[:-2].ljust(80),
                            str(v.shape).ljust(20))
            v_size = np.prod(np.array(v.shape.as_list())).tolist()
            total_size += v_size
        tf.logging.info("Total trainable variables size: %d", total_size)

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        if params.optimizer == "Adam":
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=params.adam_beta1,
                                         beta2=params.adam_beta2,
                                         epsilon=params.adam_epsilon)
        elif params.optimizer == "LazyAdam":
            opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                                   beta1=params.adam_beta1,
                                                   beta2=params.adam_beta2,
                                                   epsilon=params.adam_epsilon)
        else:
            raise RuntimeError("Optimizer %s not supported" % params.optimizer)

        loss, ops = optimize.create_train_op(loss, opt, global_step, params)
        restore_op = restore_variables(args.checkpoint)

        # Validation (only when both a validation file and references exist)
        if params.validation and params.references[0]:
            files = [params.validation] + list(params.references)
            eval_inputs = dataset.sort_and_zip_files(files)
            eval_input_fn = dataset.get_evaluation_input
        else:
            eval_input_fn = None

        # Add hooks
        save_vars = tf.trainable_variables() + [global_step]
        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(loss),
            tf.train.LoggingTensorHook({
                "step": global_step,
                "loss": loss,
            },
                                       every_n_iter=1),
            tf.train.CheckpointSaverHook(
                checkpoint_dir=params.output,
                save_secs=params.save_checkpoint_secs or None,
                save_steps=params.save_checkpoint_steps or None,
                saver=tf.train.Saver(
                    var_list=save_vars if params.only_save_trainable else None,
                    max_to_keep=params.keep_checkpoint_max,
                    sharded=False))
        ]

        config = session_config(params)
        # gpu allow growth
        config.gpu_options.allow_growth = True

        if eval_input_fn is not None:
            train_hooks.append(
                hooks.EvaluationHook(
                    lambda f: inference.create_inference_graph(
                        [model.get_inference_func()], f, params),
                    lambda: eval_input_fn(eval_inputs, params),
                    lambda x: decode_target_ids(x, params),
                    params.output,
                    config,
                    params.keep_top_checkpoint_max,
                    eval_secs=params.eval_secs,
                    eval_steps=params.eval_steps))

        # Create session, do not use default CheckpointSaverHook
        with tf.train.MonitoredTrainingSession(checkpoint_dir=params.output,
                                               hooks=train_hooks,
                                               save_checkpoint_secs=None,
                                               config=config) as sess:
            sess.run(restore_op)
            while not sess.should_stop():
                # Bypass hook calls: zero the accumulators, collect
                # update_cycle batches of gradients, scale, then apply
                # one real optimizer step.
                utils.session_run(sess, [init_op, ops["zero_op"]])
                for i in range(params.update_cycle):
                    utils.session_run(sess, ops["collect_op"])
                utils.session_run(sess, ops["scale_op"])
                sess.run(ops["train_op"])
Esempio n. 13
0
def main(args):
    """Debug driver (TF1): run the relevance graph on a few batches.

    Restores the model(s) from checkpoints, runs the relevance function on
    ``num`` training batches and prints source/target text, token indices
    and relevance results to stdout, with total timing at the end.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints, keeping only each model's own variables and
        # skipping the "losses_avg" bookkeeping variables.
        for i, checkpoint in enumerate(args.checkpoints):
            print("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models, one uniquely-scoped instance per checkpoint.
        model_fns = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_relevance_func()
            model_fns.append(model_fn)

        params = params_list[0]
        # Build input queue
        features = dataset.get_training_input(args.input, params)
        relevances = model_fns[0](features, params)

        # Map the loaded checkpoint values onto the scoped graph variables.
        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params))

        results = []
        num = 10  # number of batches to inspect
        count = 0
        hooks = [tf.train.LoggingTensorHook({}, every_n_iter=1)]
        with tf.train.MonitoredSession(session_creator=sess_creator,
                                       hooks=hooks) as sess:
            # Restore variables
            sess.run(assign_op)
            # NOTE(review): the result of this first run is discarded --
            # presumably a warm-up so timing excludes startup cost; confirm.
            src_seq, trg_seq, rlv_info, loss = sess.run(relevances)
            start = time.time()
            while count < num:  #not sess.should_stop():
                src_seq, trg_seq, rlv_info, loss = sess.run(relevances)
                print('--result--')
                print('loss:', loss)
                for i in range(src_seq.shape[0]):
                    src = to_text(params.vocabulary["source"],
                                  params.mapping["source"], src_seq[i], params)
                    trg = to_text(params.vocabulary["target"],
                                  params.mapping["target"], trg_seq[i], params)
                    print('sentence %d' % i)
                    print('src:', src)
                    print('src_idx:', src_seq[i])
                    print('trg:', trg)
                    print('trg_idx:', trg_seq[i])
                    print('result:', rlv_info["result"][i])
                count += 1
            end = time.time()
            print('total time:', end - start)
Esempio n. 14
0
def main(args):
    """Train a single-model classifier with an auxiliary attention loss.

    Builds the TF1 training graph for ``args.model``, restores pre-trained
    variables from ``args.checkpoint``, and runs a monitored training loop
    that accumulates gradients over ``params.update_cycle`` cached
    sub-batches per optimizer step.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)

    params = default_parameters()

    # Merge priority (low -> high): defaults -> saved params -> command line.
    params = merge_parameters(params, model_cls.get_parameters())

    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Persist the effective parameters next to the checkpoints.
    export_params(params.output, "params.json", params)
    export_params(params.output, "%s.json" % args.model,
                  collect_params(params, model_cls.get_parameters()))

    with tf.Graph().as_default():
        features = dataset.get_final_training_input(params.input, params)

        # Cache features so one optimizer step can accumulate gradients
        # over `update_cycle` consecutive input batches.
        update_cycle = params.update_cycle
        features, init_op = cache.cache_features(features, update_cycle)

        initializer = get_initializer(params)
        regularizer = tf.contrib.layers.l1_l2_regularizer(
            scale_l1=params.scale_l1, scale_l2=params.scale_l2)
        model = model_cls(params)
        global_step = tf.train.get_or_create_global_step()

        model_fn = model.get_training_func(initializer, regularizer)
        class_loss, attention_loss = model_fn(features)
        # `lamda` weights the auxiliary attention loss against the
        # classification loss.
        attention_loss = attention_loss * params.lamda
        loss = class_loss + attention_loss + tf.losses.get_regularization_loss(
        )

        # Log every trainable variable and the total parameter count.
        all_weights = {v.name: v for v in tf.trainable_variables()}
        total_size = 0

        for v_name in sorted(list(all_weights)):
            v = all_weights[v_name]
            tf.logging.info("%s\tshape    %s", v.name[:-2].ljust(80),
                            str(v.shape).ljust(20))
            v_size = np.prod(np.array(v.shape.as_list())).tolist()
            total_size += v_size
        tf.logging.info("Total trainable variables size: %d", total_size)

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        if params.optimizer == "Adam":
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=params.adam_beta1,
                                         beta2=params.adam_beta2,
                                         epsilon=params.adam_epsilon)
        elif params.optimizer == "LazyAdam":
            opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                                   beta1=params.adam_beta1,
                                                   beta2=params.adam_beta2,
                                                   epsilon=params.adam_epsilon)
        elif params.optimizer == "SGD":
            opt = tf.train.GradientDescentOptimizer(learning_rate)
        else:
            raise RuntimeError("Optimizer %s not supported" % params.optimizer)

        loss, ops = optimize.create_train_op(loss, opt, global_step, params)
        restore_op = restore_variables(args.checkpoint)

        if params.validation:
            eval_sorted_keys, eval_inputs = dataset.read_eval_input_file(
                params.validation)
            eval_input_fn = dataset.get_predict_input
        else:
            eval_input_fn = None

        # The same saver serves the periodic CheckpointSaverHook and the
        # EvaluationHook (via the SAVERS collection).
        save_vars = tf.trainable_variables() + [global_step]
        saver = tf.train.Saver(
            var_list=save_vars if params.only_save_trainable else None,
            max_to_keep=params.keep_checkpoint_max,
            sharded=False)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)

        # Logged shapes are scaled so they reflect the effective batch of
        # `update_cycle` cached sub-batches.
        multiplier = tf.convert_to_tensor([update_cycle, 1])

        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(loss),
            tf.train.LoggingTensorHook(
                {
                    "step": global_step,
                    "loss": loss,
                    "class_loss": class_loss,
                    "attention_loss": attention_loss,
                    "text": tf.shape(features["text"]) * multiplier,
                    "aspect": tf.shape(features["aspect"]) * multiplier,
                    "polarity": tf.shape(features["polarity"]) * multiplier
                },
                every_n_iter=1),
            tf.train.CheckpointSaverHook(
                checkpoint_dir=params.output,
                save_secs=params.save_checkpoint_secs or None,
                save_steps=params.save_checkpoint_steps or None,
                saver=saver)
        ]

        config = session_config(params)

        if eval_input_fn is not None:
            train_hooks.append(
                hooks.EvaluationHook(
                    lambda f: inference.create_predict_graph([model], f, params
                                                             ),
                    lambda: eval_input_fn(eval_inputs, params),
                    params.output,
                    config,
                    params.keep_top_checkpoint_max,
                    eval_secs=params.eval_secs,
                    eval_steps=params.eval_steps))

        def restore_fn(step_context):
            # Restore pre-trained variables without triggering hooks.
            step_context.session.run(restore_op)

        def step_fn(step_context):
            # Accumulate gradients over `update_cycle` cached batches,
            # then apply them in one hooked train step.
            step_context.session.run([init_op, ops["zero_op"]])
            for i in range(update_cycle - 1):
                step_context.session.run(ops["collect_op"])

            return step_context.run_with_hooks(ops["train_op"])

        with tf.train.MonitoredTrainingSession(checkpoint_dir=params.output,
                                               hooks=train_hooks,
                                               save_checkpoint_secs=None,
                                               config=config) as sess:
            sess.run_step_fn(restore_fn)

            while not sess.should_stop():
                sess.run_step_fn(step_fn)
Esempio n. 15
0
def main(args):
    """Jointly train parsing and AMR objectives with a shared encoder.

    Builds one TF1 graph containing a shared encoder (reused via
    ``tf.AUTO_REUSE`` across both problems) plus a task-specific decoder
    per problem, then alternates optimizer steps between the parsing and
    AMR losses (even steps: parsing; odd steps: AMR).
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)  # a model class
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    export_params(params.output, "params.json", params)
    export_params(params.output, "%s.json" % args.model,
                  collect_params(params, model_cls.get_parameters()))
    # print(params.vocabulary)

    # Build Graph
    with tf.Graph().as_default():
        if not params.record:
            # Build input queue
            parsing_features = dataset.get_training_input(params.parsing_input,
                                                          params,
                                                          problem='parsing')
            amr_features = dataset.get_training_input(params.amr_input,
                                                      params,
                                                      problem='amr')
        else:
            parsing_features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)
            amr_features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)

        # Cache features so each optimizer step can accumulate gradients
        # over `update_cycle` sub-batches (per problem).
        update_cycle = params.update_cycle
        parsing_features, parsing_init_op = cache.cache_features(
            parsing_features, update_cycle)
        amr_features, amr_init_op = cache.cache_features(
            amr_features, update_cycle)

        # Build model
        initializer = get_initializer(params)
        regularizer = tf.contrib.layers.l1_l2_regularizer(
            scale_l1=params.scale_l1, scale_l2=params.scale_l2)
        model = model_cls(params)
        # Create global step
        global_step = tf.train.get_or_create_global_step()

        # Multi-GPU setting
        # parsing_sharded_losses, amr_sharded_losses = parallel.parallel_model(
        #         model.get_training_func(initializer, regularizer),
        #         [parsing_features, amr_features],
        #         params.device_list
        # )
        # parsing_loss, amr_loss = model.get_training_func(initializer, regularizer)([parsing_features, amr_features])
        # with tf.variable_scope("shared_decode_variable") as scope:
        #     parsing_loss = model.get_training_func(initializer, regularizer, problem='parsing')(parsing_features)
        #     scope.reuse_variables()
        #     amr_loss = model.get_training_func(initializer, regularizer, problem='amr')(amr_features)
        #with tf.variable_scope("encoder_shared", reuse=True):
        print(params.layer_postprocess)
        # The encoder scope uses AUTO_REUSE so both problems share the
        # same encoder variables; each problem owns its own decoder scope.
        with tf.variable_scope("encoder_shared",
                               initializer=initializer,
                               regularizer=regularizer,
                               reuse=tf.AUTO_REUSE):
            parsing_encoder_output = model.get_encoder_out(
                parsing_features, "train", params)
            amr_encoder_output = model.get_encoder_out(amr_features, "train",
                                                       params)
        with tf.variable_scope("parsing_decoder",
                               initializer=initializer,
                               regularizer=regularizer):
            parsing_loss = model.get_decoder_out(parsing_features,
                                                 parsing_encoder_output,
                                                 "train",
                                                 params,
                                                 problem="parsing")
        with tf.variable_scope("amr_decoder",
                               initializer=initializer,
                               regularizer=regularizer):
            amr_loss = model.get_decoder_out(amr_features,
                                             amr_encoder_output,
                                             "train",
                                             params,
                                             problem="amr")

        parsing_loss = parsing_loss + tf.losses.get_regularization_loss()

        amr_loss = amr_loss + tf.losses.get_regularization_loss()

        # Print parameters
        all_weights = {v.name: v for v in tf.trainable_variables()}
        total_size = 0

        for v_name in sorted(list(all_weights)):
            v = all_weights[v_name]
            tf.logging.info("%s\tshape    %s", v.name[:-2].ljust(80),
                            str(v.shape).ljust(20))
            v_size = np.prod(np.array(v.shape.as_list())).tolist()
            total_size += v_size
        tf.logging.info("Total trainable variables size: %d", total_size)

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        if params.optimizer == "Adam":
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=params.adam_beta1,
                                         beta2=params.adam_beta2,
                                         epsilon=params.adam_epsilon)
        elif params.optimizer == "LazyAdam":
            opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                                   beta1=params.adam_beta1,
                                                   beta2=params.adam_beta2,
                                                   epsilon=params.adam_epsilon)
        else:
            raise RuntimeError("Optimizer %s not supported" % params.optimizer)

        # One train-op bundle (zero/collect/train) per problem, sharing
        # the same optimizer and global step.
        parsing_loss, parsing_ops = optimize.create_train_op(parsing_loss,
                                                             opt,
                                                             global_step,
                                                             params,
                                                             problem="parsing")
        amr_loss, amr_ops = optimize.create_train_op(amr_loss,
                                                     opt,
                                                     global_step,
                                                     params,
                                                     problem="amr")

        restore_op = restore_variables(args.checkpoint)

        # Validation
        if params.validation and params.references[0]:
            files = [params.validation] + list(params.references)
            eval_inputs = dataset.sort_and_zip_files(files)
            eval_input_fn = dataset.get_evaluation_input
        else:
            eval_input_fn = None

        # Add hooks
        save_vars = tf.trainable_variables() + [global_step]
        saver = tf.train.Saver(
            var_list=save_vars if params.only_save_trainable else None,
            max_to_keep=params.keep_checkpoint_max,
            sharded=False)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)

        # Logged shapes are scaled so they reflect the effective batch of
        # `update_cycle` cached sub-batches.
        multiplier = tf.convert_to_tensor([update_cycle, 1])

        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(parsing_loss),
            tf.train.NanTensorHook(amr_loss),
            tf.train.LoggingTensorHook(
                {
                    "step": global_step,
                    "parsing_loss": parsing_loss,
                    "amr_loss": amr_loss,
                    "parsing_source":
                    tf.shape(parsing_features["source"]) * multiplier,
                    "parsing_target":
                    tf.shape(parsing_features["target"]) * multiplier,
                    "amr_source":
                    tf.shape(amr_features["source"]) * multiplier,
                    "amr_target": tf.shape(amr_features["target"]) * multiplier
                },
                every_n_iter=1),
            tf.train.CheckpointSaverHook(
                checkpoint_dir=params.output,
                save_secs=params.save_checkpoint_secs or None,
                save_steps=params.save_checkpoint_steps or None,
                saver=saver)
        ]

        config = session_config(params)

        if eval_input_fn is not None:
            train_hooks.append(
                hooks.EvaluationHook(
                    lambda f: inference.create_inference_graph([model], f,
                                                               params),
                    lambda: eval_input_fn(eval_inputs, params),
                    lambda x: decode_target_ids(x, params, flag='target'),
                    lambda x: decode_target_ids(x, params, flag='source'),
                    params.output,
                    config,
                    params.keep_top_checkpoint_max,
                    eval_secs=params.eval_secs,
                    eval_steps=params.eval_steps))

        def restore_fn(step_context):
            # Restore pre-trained variables without triggering hooks.
            step_context.session.run(restore_op)

        def parsing_step_fn(step_context):
            # Bypass hook calls
            step_context.session.run([parsing_init_op, parsing_ops["zero_op"]
                                      ])  # if params.cycle==1 do nothing
            for i in range(update_cycle - 1):
                step_context.session.run(parsing_ops["collect_op"])

            return step_context.run_with_hooks(parsing_ops["train_op"])

        def amr_step_fn(step_context):
            # Bypass hook calls
            step_context.session.run([amr_init_op, amr_ops["zero_op"]])
            for i in range(update_cycle - 1):
                step_context.session.run(amr_ops["collect_op"])

            return step_context.run_with_hooks(amr_ops["train_op"])

        def step_fn(step_context):
            # Bypass hook calls
            return step_context.run_with_hooks(parsing_ops["train_op"])

        # Create session, do not use default CheckpointSaverHook
        with tf.train.MonitoredTrainingSession(checkpoint_dir=params.output,
                                               hooks=train_hooks,
                                               save_checkpoint_secs=None,
                                               config=config) as sess:
            # Restore pre-trained variables
            sess.run_step_fn(restore_fn)
            step = 0
            # Alternate tasks: even steps train parsing, odd steps train AMR.
            while not sess.should_stop():
                if step % 2 == 0:
                    sess.run_step_fn(parsing_step_fn)
                else:
                    sess.run_step_fn(amr_step_fn)
                step += 1
Esempio n. 16
0
def main(args):
    """Score inputs with a restored model, optionally emitting per-token
    and per-sentence uncertainty statistics.

    Restores ``args.checkpoint`` into the evaluation graph of
    ``args.model`` and streams batch results to the files named in
    ``args`` (``rv_file`` always; ``mean_file``/``var_file`` and the
    ``sen_*`` files only when ``params.model_uncertainty`` is set).
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.checkpoint, args.model, params)
    override_parameters(params, args)

    # Build Graph
    with tf.Graph().as_default():
        model = model_cls(params)
        inputs = read_files(args.input)
        features = get_features(inputs, params)
        score_fn = model.get_evaluation_func()
        # Multi-GPU setting: one replica of the scoring fn per device.
        scores = parallel.parallel_model(score_fn, features,
                                         params.device_list)

        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params))

        # Load checkpoint values, keeping only this model's variables.
        tf.logging.info("Loading %s" % args.checkpoint)
        var_list = tf.train.list_variables(args.checkpoint)
        values = {}
        reader = tf.train.load_checkpoint(args.checkpoint)

        for (name, shape) in var_list:
            if not name.startswith(model_cls.get_name()):
                continue

            values[name] = reader.get_tensor(name)

        ops = set_variables(tf.trainable_variables(), values,
                            model_cls.get_name())
        assign_op = tf.group(*ops)

        # Create session
        fb = 0  # finished-batch counter, used only for progress logging
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)
            fd = tf.gfile.Open(args.rv_file, "w")
            if params.model_uncertainty:
                fm = tf.gfile.Open(args.mean_file, 'w')
                fv = tf.gfile.Open(args.var_file, 'w')
                fsm = tf.gfile.Open(args.sen_mean, 'w')
                fsv = tf.gfile.Open(args.sen_var, 'w')
                fsr = tf.gfile.Open(args.sen_rv, 'w')

            while not sess.should_stop():
                results = sess.run(scores)
                fb += 1
                message = "Finished batch %d" % fb
                tf.logging.log(tf.logging.INFO, message)
                if params.model_uncertainty:
                    # Collect per-device result dicts, then flatten the
                    # per-device lists into one list per statistic.
                    rv_score = []
                    mean_score = []
                    var_score = []
                    len_score = []
                    sen_mean = []
                    sen_var = []
                    sen_rv = []
                    for result in results:
                        rv_score.append(result["rv"].tolist())
                        mean_score.append(result["mean"].tolist())
                        var_score.append(result["var"].tolist())
                        len_score.append(result["len"].tolist())
                        sen_mean.append(result["sen_mean"].tolist())
                        sen_var.append(result["sen_var"].tolist())
                        sen_rv.append(result["sen_rv"].tolist())
                    rv_score = list(itertools.chain(*rv_score))
                    mean_score = list(itertools.chain(*mean_score))
                    var_score = list(itertools.chain(*var_score))
                    len_score = list(itertools.chain(*len_score))
                    sen_mean = list(itertools.chain(*sen_mean))
                    sen_var = list(itertools.chain(*sen_var))
                    sen_rv = list(itertools.chain(*sen_rv))

                    # Materialize: a lazy map object would still work here
                    # (consumed once), but a list is unambiguous.
                    len_score = list(map(int, len_score))

                    for i, l in enumerate(len_score):
                        # Truncate token-level statistics at the true
                        # sentence length to drop padded positions.
                        m = mean_score[i][:l]
                        v = var_score[i][:l]
                        r = rv_score[i][:l]
                        for (i_m, i_v, i_r) in zip(m, v, r):
                            fd.write('%f ' % i_r)
                            fm.write('%f ' % i_m)
                            fv.write('%f ' % i_v)
                        fd.write("\n")
                        fm.write("\n")
                        fv.write("\n")

                        fsm.write("{}\n".format(sen_mean[i]))
                        fsv.write("{}\n".format(sen_var[i]))
                        fsr.write("{}\n".format(sen_rv[i]))
                else:
                    results = itertools.chain(*results)
                    for value in results:
                        fd.write("%f\n" % value)

            fd.close()
            # BUG FIX: the uncertainty output files were opened but never
            # closed, which can leave buffered output unflushed.
            if params.model_uncertainty:
                fm.close()
                fv.close()
                fsm.close()
                fsv.close()
                fsr.close()
def main(args):
    """Run sequence-labelling inference with a restored model.

    Restores the checkpoint, decodes predictions for ``args.input`` and
    either scores them against ``args.eval_file`` (validation mode) or
    writes the decoded label sequences to ``args.output``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(args.models)]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        # Import the parameter files produced during training.
        import_params(args.checkpoints, args.models,
                      params_list[0])
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate([args.checkpoints]):
            print("Loading %s" % checkpoint)
            # List every variable stored in the checkpoint.
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                # Keep only this model's variables, excluding the
                # "losses_avg" bookkeeping variables.
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                values[name] = reader.get_tensor(name)

            model_var_lists.append(values)

        # Build models
        model_fns = []

        for i in range(len([args.checkpoints])):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]

        if params.use_bert and params.bert_emb_path:
            features = dataset.get_inference_input_with_bert(
                params.input + [params.bert_emb_path], params)
        else:
            features = dataset.get_inference_input([params.input], params)

        predictions = search.create_inference_graph(model_fns, features,
                                                    params)

        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len([args.checkpoints])):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params))

        result_for_score = []
        result_for_write = []

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)
            # Record each sentence's token count so padded positions can
            # be stripped from the predictions later.
            lengths = []
            with open(args.input, "r", encoding="utf8") as f:
                for line in f:
                    # BUG FIX: strip() removes the trailing newline, so
                    # the old comparison against "-DOCSTART-\n" could
                    # never match and document markers were counted.
                    if line.strip() == "-DOCSTART-":
                        continue
                    lengths.append(len(line.strip().split(" ")))
            current_num = 0
            batch = 0
            while not sess.should_stop():
                current_res_arr = sess.run(predictions)
                result_for_write.append(current_res_arr)
                for arr in current_res_arr:
                    result_for_score.extend(list(arr)[:lengths[current_num]])
                    current_num += 1
                batch += 1
                message = "Finished batch %d" % batch
                tf.logging.log(tf.logging.INFO, message)
        if params.is_validation:
            from sklearn.metrics import precision_score, recall_score, f1_score
            import numpy as np
            # Map each label string to its vocabulary index.
            voc_lis = params.vocabulary["target"]
            index = list(np.arange(len(voc_lis)))
            dic = dict(zip(voc_lis, index))

            def map_res(x):
                return dic[x]

            targets_list = []
            with open(args.eval_file, "r") as f:  # read the gold label file
                for line in f:
                    # NOTE(review): lines that are exactly "O" are skipped;
                    # presumably these mirror document markers — confirm.
                    if line.strip() == "O":
                        continue
                    lines = line.strip().split(" ")
                    targets_list.extend(list(map(map_res, lines)))

            result_arr = np.array(result_for_score)
            targets_arr = np.array(targets_list)
            precision_ = precision_score(targets_arr,
                                         result_arr,
                                         average="micro",
                                         labels=[0, 2, 3, 4])
            # BUG FIX: sklearn expects (y_true, y_pred); the arguments
            # were swapped here relative to the precision call above.
            recall_ = recall_score(targets_arr,
                                   result_arr,
                                   average="micro",
                                   labels=[0, 2, 3, 4])
            print("precision_score:{}".format(precision_))
            print("recall_score:{}".format(recall_))
            print("F1_score:{}".format(2 * precision_ * recall_ /
                                       (recall_ + precision_)))
        else:
            # Convert to plain text
            vocab = params.vocabulary["target"]
            outputs = []

            for result in result_for_write:
                outputs.append(result.tolist())

            outputs = list(itertools.chain(*outputs))

            # Write to file
            num = 0
            with open(args.output, "w") as outfile:
                for output in outputs:
                    decoded = []
                    # Keep one position past the sentence length so the
                    # expected EOS symbol can be checked, then drop it.
                    for idx in output[:lengths[num] + 1]:
                        if idx == params.mapping["target"][params.eos]:
                            if idx != output[lengths[num]]:
                                print(
                                    "Warning: incomplete predictions as line{} in src sentence"
                                    .format(num + 1))
                        decoded.append(vocab[idx])
                    decoded = " ".join(decoded[:-1])
                    outfile.write("%s\n" % decoded)
                    num += 1
Esempio n. 18
0
def main(args):
    """Train a PyTorch model, optionally distributed over NCCL.

    Restores the latest (or a user-supplied) checkpoint, then runs the
    training loop with gradient accumulation over
    ``params.update_cycle`` sub-steps, periodic evaluation and
    checkpointing until ``params.train_steps`` is reached.
    """
    model_cls = models.get_model(args.model)

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = default_params()
    params = merge_params(params, model_cls.default_params(args.hparam_set))
    params = import_params(args.output, args.model, params)
    params = override_params(params, args)

    # Initialize distributed utility
    if args.distributed:
        dist.init_process_group("nccl")
        torch.cuda.set_device(args.local_rank)
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        # Single-node launch: rank/world size come from the device list.
        dist.init_process_group("nccl",
                                init_method=args.url,
                                rank=args.local_rank,
                                world_size=len(params.device_list))
        torch.cuda.set_device(params.device_list[args.local_rank])
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # Export parameters (rank 0 only, to avoid concurrent writes)
    if dist.get_rank() == 0:
        export_params(params.output, "params.json", params)
        export_params(params.output, "%s.json" % params.model,
                      collect_params(params, model_cls.default_params()))

    model = model_cls(params).cuda()

    if args.half:
        # Half precision: convert the model and make half the default dtype.
        model = model.half()
        torch.set_default_dtype(torch.half)
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    model.train()

    # Init tensorboard
    summary.init(params.output, params.save_summary)

    schedule = get_learning_rate_schedule(params)
    clipper = get_clipper(params)
    optimizer = get_optimizer(params, schedule, clipper)

    if args.half:
        # Loss scaling guards against half-precision gradient underflow.
        optimizer = optimizers.LossScalingOptimizer(optimizer)

    optimizer = optimizers.MultiStepOptimizer(optimizer, params.update_cycle)

    trainable_flags = print_variables(model, params.pattern,
                                      dist.get_rank() == 0)

    dataset = data.get_dataset(params.input, "train", params)

    if params.validation:
        sorted_key, eval_dataset = data.get_dataset(params.validation, "infer",
                                                    params)
        references = load_references(params.references)
    else:
        sorted_key = None
        eval_dataset = None
        references = None

    # Load checkpoint
    checkpoint = utils.latest_checkpoint(params.output)

    if args.checkpoint is not None:
        # Load pre-trained models
        state = torch.load(args.checkpoint, map_location="cpu")
        model.load_state_dict(state["model"])
        step = params.initial_step
        epoch = 0
        broadcast(model)
    elif checkpoint is not None:
        # Resume training: restore step/epoch and optimizer state too.
        state = torch.load(checkpoint, map_location="cpu")
        step = state["step"]
        epoch = state["epoch"]
        model.load_state_dict(state["model"])

        if "optimizer" in state:
            optimizer.load_state_dict(state["optimizer"])
    else:
        step = 0
        epoch = 0
        broadcast(model)

    def train_fn(inputs):
        # Forward pass; the model returns the training loss.
        features, labels = inputs
        loss = model(features, labels)
        return loss

    counter = 0

    while True:
        for features in dataset:
            # One "step" spans `update_cycle` gradient-accumulation substeps.
            if counter % params.update_cycle == 0:
                step += 1
                utils.set_global_step(step)

            counter += 1
            t = time.time()
            features = data.lookup(features, "train", params)
            loss = train_fn(features)
            gradients = optimizer.compute_gradients(loss,
                                                    list(model.parameters()))
            # Drop gradients for parameters frozen by `params.pattern`.
            grads_and_vars = exclude_variables(
                trainable_flags, zip(gradients,
                                     list(model.named_parameters())))
            optimizer.apply_gradients(grads_and_vars)

            t = time.time() - t

            summary.scalar("loss", loss, step, write_every_n_steps=1)
            summary.scalar("global_step/sec", t, step)

            print("epoch = %d, step = %d, loss = %.3f (%.3f sec)" %
                  (epoch + 1, step, float(loss), t))

            if counter % params.update_cycle == 0:
                if step >= params.train_steps:
                    # Final evaluation and checkpoint before exiting.
                    utils.evaluate(model, sorted_key, eval_dataset,
                                   params.output, references, params)
                    save_checkpoint(step, epoch, model, optimizer, params)

                    if dist.get_rank() == 0:
                        summary.close()

                    return

                if step % params.eval_steps == 0:
                    utils.evaluate(model, sorted_key, eval_dataset,
                                   params.output, references, params)

                if step % params.save_checkpoint_steps == 0:
                    save_checkpoint(step, epoch, model, optimizer, params)

        epoch += 1
Esempio n. 19
0
def main(args):
    """Ensemble prediction over an evaluation file plus a 3-class report.

    Restores each checkpoint's variables into a separately-scoped model
    instance, runs the shared predict graph data-parallel over the
    (length-sorted) input, undoes the sort, then writes per-example
    predictions, attention weights and a confusion matrix with P/R/F1
    scores to ``args.output``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # One model class / parameter set per checkpoint.
    # Parameter priorities (low -> high): default -> saved -> command line.
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    with tf.Graph().as_default():
        model_var_lists = []

        # Read raw tensors from every checkpoint, keeping only the
        # variables that belong to the matching model scope and skipping
        # the "losses_avg" bookkeeping variables.
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build one model per checkpoint; the "_%d" suffix keeps variable
        # scopes distinct when the same architecture is used twice.
        model_list = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_list.append(model)

        # The first parameter set drives decoding and session settings.
        params = params_list[0]
        params.initializer_gain = 1.0

        # sorted_keys maps a position in the sorted stream back to the
        # original input order; used below to restore that order.
        sorted_keys, sorted_inputs = dataset.read_eval_input_file(args.input)

        features = dataset.get_predict_input(sorted_inputs, params)

        # One placeholder dict per device for data-parallel prediction.
        placeholders = []

        for i in range(len(params.device_list)):
            placeholders.append({
                "text":
                tf.placeholder(tf.int32, [None, None], "text_%d" % i),
                "text_length":
                tf.placeholder(tf.int32, [None], "text_length_%d" % i),
                "aspect":
                tf.placeholder(tf.int32, [None, None], "aspect_%d" % i),
                "aspect_length":
                tf.placeholder(tf.int32, [None], "aspect_length_%d" % i),
                "polarity":
                tf.placeholder(tf.int32, [None, None], "polarity_%d" % i)
            })

        predict_fn = inference.create_predict_graph

        predictions = parallel.data_parallelism(
            params.device_list, lambda f: predict_fn(model_list, f, params),
            placeholders)

        # Build assign ops that push the checkpoint values (fed through
        # feed_dict) into the per-model variable scopes.
        assign_ops = []
        feed_dict = {}

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i, feed_dict)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        init_op = tf.tables_initializer()
        results = []

        # Run prediction until the input dataset is exhausted.
        with tf.Session(config=session_config(params)) as sess:
            sess.run(assign_op, feed_dict=feed_dict)
            sess.run(init_op)

            while True:
                try:
                    feats = sess.run(features)
                    op, feed_dict = shard_features(feats, placeholders,
                                                   predictions)
                    results.append(sess.run(op, feed_dict=feed_dict))
                    message = "Finished batch %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                except tf.errors.OutOfRangeError:
                    break

        # Unpack the 4-tuple produced per batch.
        # NOTE(review): assumed layout is (input features, class
        # prediction, class distribution, attention alphas) based on the
        # usage below -- confirm against inference.create_predict_graph.
        input_features = []
        scores1 = []
        scores2 = []
        output_alphas = []
        for result in results:
            for item in result[0]:
                input_features.append(item.tolist())
            for item in result[1]:
                scores1.append(item.tolist())
            for item in result[2]:
                scores2.append(item.tolist())
            for item in result[3]:
                output_alphas.append(item.tolist())

        # Flatten the per-shard lists into one list per output.
        scores1 = list(itertools.chain(*scores1))
        scores2 = list(itertools.chain(*scores2))
        output_alphas = list(itertools.chain(*output_alphas))

        # Undo the length-sorting so outputs align with the input file.
        restored_scores1 = []
        restored_scores2 = []
        restored_output_alphas = []
        restored_inputs_text = []
        restored_inputs_aspect = []
        restored_inputs_score = []

        for index in range(len(sorted_inputs[0])):
            restored_scores1.append(scores1[sorted_keys[index]][0])
            restored_scores2.append(scores2[sorted_keys[index]])
            restored_output_alphas.append(output_alphas[sorted_keys[index]])

            restored_inputs_text.append(sorted_inputs[0][sorted_keys[index]])
            restored_inputs_aspect.append(sorted_inputs[1][sorted_keys[index]])
            restored_inputs_score.append(sorted_inputs[2][sorted_keys[index]])

        # TP/FP/FN accumulators for the three classes encoded as
        # '0' (bad), '1' (mid), '2' (good).
        # NOTE(review): the report header below labels the columns
        # "positive"/"neural"/"negative"; "neural" looks like a typo for
        # "neutral", and the bad/mid/good <-> column mapping should be
        # confirmed against the label encoding.
        class3_bad_TP = 0.0
        class3_bad_FP = 0.0
        class3_bad_FN = 0.0

        class3_mid_TP = 0.0
        class3_mid_FP = 0.0
        class3_mid_FN = 0.0

        class3_good_TP = 0.0
        class3_good_FP = 0.0
        class3_good_FN = 0.0

        with open(args.output, "w") as outfile:

            # score1: predicted class, score2: class distribution,
            # score3: gold label (from the input file).
            for score1, score2, score3, alphas, text, aspect in zip(
                    restored_scores1, restored_scores2, restored_inputs_score,
                    restored_output_alphas, restored_inputs_text,
                    restored_inputs_aspect):
                score1 = str(score1)
                outfile.write("###########################\n")
                pattern = "%s|||%f,%f,%f|||%s\n"
                values = (score1, score2[0], score2[1], score2[2], score3)
                outfile.write(pattern % values)
                outfile.write(aspect + "\n")
                # Per-token attention weights next to each word.
                for (word, alpha) in zip(text.split(), alphas):
                    outfile.write(word + " " + str(alpha) + "\t")
                outfile.write("\n")

                # Accumulate TP/FP/FN per class (string comparison on the
                # class labels '0'/'1'/'2').
                if score1 == '0' and score3 == '0':
                    class3_bad_TP += 1.0
                if score1 == '1' and score3 == '1':
                    class3_mid_TP += 1.0
                if score1 == '2' and score3 == '2':
                    class3_good_TP += 1.0

                if score1 == '0' and score3 != '0':
                    class3_bad_FP += 1.0
                if score1 == '1' and score3 != '1':
                    class3_mid_FP += 1.0
                if score1 == '2' and score3 != '2':
                    class3_good_FP += 1.0

                if score1 != '0' and score3 == '0':
                    class3_bad_FN += 1.0
                if score1 != '1' and score3 == '1':
                    class3_mid_FN += 1.0
                if score1 != '2' and score3 == '2':
                    class3_good_FN += 1.0

            # Final report: confusion counts, then P/R/F1 per class and
            # micro/macro F1.  The +0.000001 terms guard against
            # division by zero for empty classes.
            outfile.write("\n")
            outfile.write("Class 3:\n")
            outfile.write("Confusion Matrix:\n")
            outfile.write("\t" + "{name: >10s}".format(name="positive") +
                          "\t" + "{name: >10s}".format(name="neural") + "\t" +
                          "{name: >10s}".format(name="negative") + "\n")
            outfile.write("TP\t" + int2int(class3_bad_TP) + "\t" +
                          int2int(class3_mid_TP) + "\t" +
                          int2int(class3_good_TP) + "\n")
            outfile.write("FP\t" + int2int(class3_bad_FP) + "\t" +
                          int2int(class3_mid_FP) + "\t" +
                          int2int(class3_good_FP) + "\n")
            outfile.write("FN\t" + int2int(class3_bad_FN) + "\t" +
                          int2int(class3_mid_FN) + "\t" +
                          int2int(class3_good_FN) + "\n")
            outfile.write(
                "P\t" + float2int(class3_bad_TP /
                                  (class3_bad_TP + class3_bad_FP + 0.000001)) +
                "\t" + float2int(class3_mid_TP /
                                 (class3_mid_TP + class3_mid_FP + 0.000001)) +
                "\t" +
                float2int(class3_good_TP /
                          (class3_good_TP + class3_good_FP + 0.000001)) + "\n")
            outfile.write(
                "R\t" + float2int(class3_bad_TP /
                                  (class3_bad_TP + class3_bad_FN + 0.000001)) +
                "\t" + float2int(class3_mid_TP /
                                 (class3_mid_TP + class3_mid_FN + 0.000001)) +
                "\t" +
                float2int(class3_good_TP /
                          (class3_good_TP + class3_good_FN + 0.000001)) + "\n")
            outfile.write("F1\t" +
                          float2int(class3_bad_TP * 2 /
                                    (class3_bad_TP * 2 + class3_bad_FP +
                                     class3_bad_FN + 0.000001)) + "\t" +
                          float2int(class3_mid_TP * 2 /
                                    (class3_mid_TP * 2 + class3_mid_FP +
                                     class3_mid_FN + 0.000001)) + "\t" +
                          float2int(class3_good_TP * 2 /
                                    (class3_good_TP * 2 + class3_good_FP +
                                     class3_good_FN + 0.000001)) + "\n")
            outfile.write("F1-Micro:\t" + float2int(
                (class3_bad_TP + class3_mid_TP + class3_good_TP) * 2 /
                ((class3_bad_TP + class3_mid_TP + class3_good_TP) * 2 +
                 (class3_bad_FP + class3_mid_FP + class3_good_FP) +
                 (class3_bad_FN + class3_mid_FN + class3_good_FN) +
                 0.000001)) + "\n")
            outfile.write("F1-Macro:\t" + float2int(
                (class3_bad_TP * 2 /
                 (class3_bad_TP * 2 + class3_bad_FP + class3_bad_FN +
                  0.000001) + class3_mid_TP * 2 /
                 (class3_mid_TP * 2 + class3_mid_FP + class3_mid_FN +
                  0.000001) + class3_good_TP * 2 /
                 (class3_good_TP * 2 + class3_good_FP + class3_good_FN +
                  0.000001)) / 3.0) + "\n")
def main(args):
    """Training entry point (ABD / right-to-left input variant).

    Builds the training graph with gradient accumulation over
    ``params.update_cycle`` cached batches, optional distributed training
    and optional fp16, restores pre-trained variables from
    ``args.checkpoint``, and trains until ``params.train_steps`` with
    checkpointing and periodic evaluation hooks on rank 0 only.
    """
    if args.distribute:
        distribute.enable_distributed_training()

    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    if not args.distribute or distribute.rank() == 0:
        export_params(params.output, "params.json", params)
        export_params(params.output, "%s.json" % args.model,
                      collect_params(params, model_cls.get_parameters()))

    # The third input file path must mention "r2l" (right-to-left data).
    assert 'r2l' in params.input[2]
    # Build Graph
    use_all_devices(params)
    with tf.Graph().as_default():
        if not params.record:
            # Build input queue
            features = dataset.abd_get_training_input(params.input, params)
        else:
            features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)

        # Cache update_cycle batches so gradients can be accumulated
        # before a single parameter update (see step_fn below).
        update_cycle = params.update_cycle
        features, init_op = cache.cache_features(features, update_cycle)

        # Build model
        initializer = get_initializer(params)
        regularizer = tf.contrib.layers.l1_l2_regularizer(
            scale_l1=params.scale_l1, scale_l2=params.scale_l2)
        model = model_cls(params)
        # Create global step
        global_step = tf.train.get_or_create_global_step()
        dtype = tf.float16 if args.fp16 else None

        if args.distribute:
            # One replica per process; gradients are all-reduced later.
            training_func = model.get_training_func(initializer, regularizer,
                                                    dtype)
            loss = training_func(features)
        else:
            # Multi-GPU setting
            sharded_losses = parallel.parallel_model(
                model.get_training_func(initializer, regularizer, dtype),
                features, params.device_list)
            loss = tf.add_n(sharded_losses) / len(sharded_losses)
            loss = loss + tf.losses.get_regularization_loss()

        # Print parameters
        if not args.distribute or distribute.rank() == 0:
            print_variables()

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        if params.optimizer == "Adam":
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=params.adam_beta1,
                                         beta2=params.adam_beta2,
                                         epsilon=params.adam_epsilon)
        elif params.optimizer == "LazyAdam":
            opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                                   beta1=params.adam_beta1,
                                                   beta2=params.adam_beta2,
                                                   epsilon=params.adam_epsilon)
        else:
            raise RuntimeError("Optimizer %s not supported" % params.optimizer)

        loss, ops = optimize.create_train_op(
            loss, opt, global_step,
            distribute.all_reduce if args.distribute else None, args.fp16,
            params)
        restore_op = restore_variables(args.checkpoint)

        # Validation
        # NOTE(review): a sibling main in this file builds
        # ``[params.validation] + list(params.references)``; here
        # params.validation is concatenated directly, i.e. it must
        # already be a list -- confirm against the argument parser.
        if params.validation and params.references[0]:
            files = params.validation + list(params.references)
            eval_inputs = dataset.sort_and_zip_files(files)
            eval_input_fn = dataset.abd_get_evaluation_input
        else:
            eval_input_fn = None

        # Add hooks
        # Logged batch shapes are scaled by update_cycle so they reflect
        # the effective (accumulated) batch.
        multiplier = tf.convert_to_tensor([update_cycle, 1])

        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(loss),
            tf.train.LoggingTensorHook(
                {
                    "step": global_step,
                    "loss": loss,
                    "source": tf.shape(features["source"]) * multiplier,
                    "target": tf.shape(features["target"]) * multiplier
                },
                every_n_iter=1)
        ]

        if args.distribute:
            train_hooks.append(distribute.get_broadcast_hook())

        config = session_config(params)

        if not args.distribute or distribute.rank() == 0:
            # Add hooks
            save_vars = tf.trainable_variables() + [global_step]
            saver = tf.train.Saver(
                var_list=save_vars if params.only_save_trainable else None,
                max_to_keep=params.keep_checkpoint_max,
                sharded=False)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            train_hooks.append(
                tf.train.CheckpointSaverHook(
                    checkpoint_dir=params.output,
                    save_secs=params.save_checkpoint_secs or None,
                    save_steps=params.save_checkpoint_steps or None,
                    saver=saver))

        if eval_input_fn is not None:
            if not args.distribute or distribute.rank() == 0:
                train_hooks.append(
                    hooks.EvaluationHook(
                        lambda f: inference.create_inference_graph([model], f,
                                                                   params),
                        lambda: eval_input_fn(eval_inputs, params),
                        lambda x: decode_target_ids(x, params),
                        params.output,
                        config,
                        params.keep_top_checkpoint_max,
                        eval_secs=params.eval_secs,
                        eval_steps=params.eval_steps))

        def restore_fn(step_context):
            # Runs once after session creation to load pre-trained weights.
            step_context.session.run(restore_op)

        def step_fn(step_context):
            # Bypass hook calls
            # Zero accumulators, collect gradients from update_cycle - 1
            # cached batches, then apply one hooked parameter update.
            step_context.session.run([init_op, ops["zero_op"]])
            for i in range(update_cycle - 1):
                step_context.session.run(ops["collect_op"])

            return step_context.run_with_hooks(ops["train_op"])

        # Create session, do not use default CheckpointSaverHook
        # Only rank 0 owns the checkpoint directory.
        if not args.distribute or distribute.rank() == 0:
            checkpoint_dir = params.output
        else:
            checkpoint_dir = None

        with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                               hooks=train_hooks,
                                               save_checkpoint_secs=None,
                                               config=config) as sess:
            # Restore pre-trained variables
            sess.run_step_fn(restore_fn)

            while not sess.should_stop():
                sess.run_step_fn(step_fn)
Esempio n. 21
0
def main(args):
    """Restore each checkpoint into its own model instance, then build
    the prediction graphs for both the 'parsing' and 'amr' problems.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    # Resolve one model class per requested model name and assemble the
    # matching parameter set (priorities: default -> saved -> command).
    model_cls_list = []
    params_list = []

    for idx, model_name in enumerate(args.models):
        model_cls = models.get_model(model_name)
        model_params = merge_parameters(default_parameters(),
                                        model_cls.get_parameters())
        model_params = import_params(args.checkpoints[idx], model_name,
                                     model_params)
        model_cls_list.append(model_cls)
        params_list.append(override_parameters(model_params, args))

    # Build Graph
    with tf.Graph().as_default():
        # Pull raw variable values out of every checkpoint, keeping only
        # the variables of the matching model scope and skipping the
        # "losses_avg" bookkeeping tensors.
        model_var_lists = []

        for idx, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            reader = tf.train.load_checkpoint(checkpoint)
            scope_prefix = model_cls_list[idx].get_name()
            values = {}

            for var_name, _ in tf.train.list_variables(checkpoint):
                if not var_name.startswith(scope_prefix):
                    continue

                if "losses_avg" in var_name:
                    continue

                values[var_name] = reader.get_tensor(var_name)

            model_var_lists.append(values)

        # Instantiate one model per checkpoint; the "_%d" suffix keeps
        # variable scopes distinct when an architecture repeats.
        model_list = [
            model_cls_list[idx](params_list[idx],
                                model_cls_list[idx].get_name() + "_%d" % idx)
            for idx in range(len(args.checkpoints))
        ]

        params = params_list[0]
        print(params)

        # Build one graph per problem.
        for problem in ('parsing', 'amr'):
            build_graph(params,
                        args,
                        model_list,
                        model_cls_list,
                        model_var_lists,
                        problem=problem)
Esempio n. 22
0
def main(args):
    """Single-machine training entry point (with optional MRT).

    Builds the multi-GPU training graph, optionally wraps the input with
    Minimum Risk Training (``params.MRT``), and trains until
    ``params.train_steps`` with checkpointing and periodic evaluation
    hooks.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    export_params(params.output, "params.json", params)
    export_params(params.output, "%s.json" % args.model,
                  collect_params(params, model_cls.get_parameters()))

    # Build Graph
    with tf.Graph().as_default():
        if not params.record:
            # Build input queue
            features = dataset.get_training_input(params.input, params)
        else:
            features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)

        # Build model
        initializer = get_initializer(params)
        model = model_cls(params)
        if params.MRT:
            # MRT expands a single example into a candidate set, hence
            # the batch_size == 1 requirement.
            assert params.batch_size == 1
            features = mrt_utils.get_MRT(features, params, model)

        # Multi-GPU setting
        sharded_losses = parallel.parallel_model(
            model.get_training_func(initializer), features, params.device_list)
        loss = tf.add_n(sharded_losses) / len(sharded_losses)

        # Create global step
        global_step = tf.train.get_or_create_global_step()

        # Print parameters
        all_weights = {v.name: v for v in tf.trainable_variables()}
        total_size = 0

        for v_name in sorted(list(all_weights)):
            v = all_weights[v_name]
            # v.name[:-2] strips the trailing ":0" device suffix.
            tf.logging.info("%s\tshape    %s", v.name[:-2].ljust(80),
                            str(v.shape).ljust(20))
            v_size = np.prod(np.array(v.shape.as_list())).tolist()
            total_size += v_size
        tf.logging.info("Total trainable variables size: %d", total_size)

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        opt = tf.train.AdamOptimizer(learning_rate,
                                     beta1=params.adam_beta1,
                                     beta2=params.adam_beta2,
                                     epsilon=params.adam_epsilon)

        train_op = tf.contrib.layers.optimize_loss(
            name="training",
            loss=loss,
            global_step=global_step,
            learning_rate=learning_rate,
            clip_gradients=params.clip_grad_norm or None,
            optimizer=opt,
            colocate_gradients_with_ops=True)

        # Validation
        if params.validation and params.references[0]:
            files = [params.validation] + list(params.references)
            eval_inputs = dataset.sort_and_zip_files(files)
            eval_input_fn = dataset.get_evaluation_input
        else:
            eval_input_fn = None

        # Add hooks
        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(loss),
            tf.train.LoggingTensorHook(
                {
                    "step": global_step,
                    "loss": loss,
                    "source": tf.shape(features["source"]),
                    "target": tf.shape(features["target"])
                },
                every_n_iter=1),
            tf.train.CheckpointSaverHook(
                checkpoint_dir=params.output,
                save_secs=params.save_checkpoint_secs or None,
                save_steps=params.save_checkpoint_steps or None,
                saver=tf.train.Saver(max_to_keep=params.keep_checkpoint_max,
                                     sharded=False))
        ]

        config = session_config(params)

        if eval_input_fn is not None:
            train_hooks.append(
                hooks.EvaluationHook(
                    lambda f: search.create_inference_graph(
                        model.get_evaluation_func(), f, params),
                    lambda: eval_input_fn(eval_inputs, params),
                    lambda x: decode_target_ids(x, params),
                    params.output,
                    config,
                    params.keep_top_checkpoint_max,
                    eval_secs=params.eval_secs,
                    eval_steps=params.eval_steps))

        # Create session, do not use default CheckpointSaverHook
        with tf.train.MonitoredTrainingSession(checkpoint_dir=params.output,
                                               hooks=train_hooks,
                                               save_checkpoint_secs=None,
                                               config=config) as sess:
            while not sess.should_stop():
                sess.run(train_op)
Esempio n. 23
0
def main(args):
    """Distributed scoring entry point.

    Restores the latest checkpoint under ``args.checkpoint`` and scores
    ``args.input`` with the model's ``mode="eval"`` forward pass.  Rank 0
    gathers the scores from every rank and writes them to ``args.output``
    (one score per line at sentence level, per-token scores otherwise).

    All ranks must iterate in lockstep for the collective calls below, so
    a rank whose data iterator is exhausted keeps feeding a dummy batch
    (reported as ``batch_size == 0``); the loop terminates once every
    rank reports an empty batch.

    Fix: the per-batch handler used a bare ``except:``, which also
    swallowed KeyboardInterrupt/SystemExit and turned Ctrl-C into an
    endless stream of dummy batches; narrowed to ``except Exception:``.
    """
    model_cls = models.get_model(args.model)
    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = default_params()
    params = merge_params(params, model_cls.default_params())
    params = import_params(args.checkpoint, args.model, params)
    params = override_params(params, args)

    # One process per device; NCCL backend for GPU collectives.
    dist.init_process_group("nccl",
                            init_method=args.url,
                            rank=args.local_rank,
                            world_size=len(params.device_list))
    torch.cuda.set_device(params.device_list[args.local_rank])
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

    if args.half:
        torch.set_default_dtype(torch.half)
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    def score_fn(inputs, _model, level="sentence"):
        # Per-sentence or per-token scores, depending on `level`.
        _features, _labels = inputs
        _score = _model(_features, _labels, mode="eval", level=level)
        return _score

    # Create model
    with torch.no_grad():
        model = model_cls(params).cuda()

        if args.half:
            model = model.half()

        # Keep dropout active when Monte-Carlo sampling is requested.
        if not params.monte_carlo:
            model.eval()

        model.load_state_dict(
            torch.load(utils.latest_checkpoint(args.checkpoint),
                       map_location="cpu")["model"])
        dataset = data.get_dataset(args.input, "eval", params)
        data_iter = iter(dataset)
        counter = 0
        # Maximum padded sequence length for token-level gathering.
        pad_max = 1024

        # Buffers for synchronization: per-rank real batch sizes and one
        # fixed-shape gather buffer per rank.
        size = torch.zeros([dist.get_world_size()]).long()
        if params.level == "sentence":
            t_list = [
                torch.empty([params.decode_batch_size]).float()
                for _ in range(dist.get_world_size())
            ]
        else:
            t_list = [
                torch.empty([params.decode_batch_size, pad_max]).float()
                for _ in range(dist.get_world_size())
            ]

        # Only rank 0 writes the output file.
        if dist.get_rank() == 0:
            fd = open(args.output, "w")
        else:
            fd = None

        while True:
            try:
                features = next(data_iter)
                features = data.lookup(features, "eval", params)
                batch_size = features[0]["source"].shape[0]
            except Exception:
                # Iterator exhausted (StopIteration) or lookup failed:
                # feed a dummy batch so the collective ops below stay
                # aligned across ranks; batch_size == 0 marks it unreal.
                features = {
                    "source": torch.ones([1, 1]).long(),
                    "source_mask": torch.ones([1, 1]).float(),
                    "target": torch.ones([1, 1]).long(),
                    "target_mask": torch.ones([1, 1]).float()
                }, torch.ones([1, 1]).long()
                batch_size = 0

            t = time.time()
            counter += 1

            scores = score_fn(features, model, params.level)

            # Pad every rank's scores to a fixed shape so all_gather
            # works; value=-1 marks token-level padding (see below).
            if params.level == "sentence":
                pad_batch = params.decode_batch_size - scores.shape[0]
                scores = torch.nn.functional.pad(scores, [0, pad_batch])
            else:
                pad_batch = params.decode_batch_size - scores.shape[0]
                pad_length = pad_max - scores.shape[1]
                scores = torch.nn.functional.pad(scores,
                                                 (0, pad_length, 0, pad_batch),
                                                 value=-1)

            # Synchronization: share real batch sizes, gather all scores.
            size.zero_()
            size[dist.get_rank()].copy_(torch.tensor(batch_size))
            dist.all_reduce(size)
            dist.all_gather(t_list, scores.float())

            # Every rank exhausted -> done.
            if size.sum() == 0:
                break

            if dist.get_rank() != 0:
                continue

            # Rank 0: interleave ranks' batches, skipping padded rows.
            for i in range(params.decode_batch_size):
                for j in range(dist.get_world_size()):
                    n = size[j]
                    score = t_list[j][i]

                    if i >= n:
                        continue

                    if params.level == "sentence":
                        fd.write("{:.4f}\n".format(score))
                    else:
                        s_list = score.tolist()
                        for s in s_list:
                            if s >= 0:
                                fd.write("{:.8f} ".format(s))
                            else:
                                # First negative value is padding: end
                                # of this sentence's token scores.
                                fd.write("\n")
                                break

            t = time.time() - t
            logging.info("Finished batch: %d (%.3f sec)" % (counter, t))

        if dist.get_rank() == 0:
            fd.close()
Esempio n. 24
0
def main(args):
    """Train ``args.model``, optionally with distributed training enabled.

    Builds the input pipeline and the multi-GPU training graph, restores
    pre-trained variables from ``args.checkpoint``, and runs the training
    loop until one of the hooks (e.g. ``StopAtStepHook``) stops the
    session.  Only rank 0 exports parameters, saves checkpoints and runs
    evaluation; other workers train without writing to disk.
    """
    if args.distribute:
        distribute.enable_distributed_training()

    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    # (rank 0 only, so concurrent workers do not clobber the files)
    if distribute.rank() == 0:
        export_params(params.output, "params.json", params)
        export_params(params.output, "%s.json" % args.model,
                      collect_params(params, model_cls.get_parameters()))

    # Build Graph
    with tf.Graph().as_default():
        if not params.record:
            # Build input queue
            features = dataset.get_training_input(params.input, params)
        else:
            # Read pre-processed TFRecord input instead of raw text
            features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)

        # Build model
        initializer = get_initializer(params)
        regularizer = tf.contrib.layers.l1_l2_regularizer(
            scale_l1=params.scale_l1, scale_l2=params.scale_l2)
        model = model_cls(params)
        # Create global step
        global_step = tf.train.get_or_create_global_step()
        # fp16 training when --half is given; None keeps the model default
        dtype = tf.float16 if args.half else None

        # Multi-GPU setting: one loss shard per device, then average
        sharded_losses = parallel.parallel_model(
            model.get_training_func(initializer, regularizer, dtype), features,
            params.device_list)
        loss = tf.add_n(sharded_losses) / len(sharded_losses)
        loss = loss + tf.losses.get_regularization_loss()

        if distribute.rank() == 0:
            print_variables()

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)

        tf.summary.scalar("loss", loss)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        if params.optimizer == "Adam":
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=params.adam_beta1,
                                         beta2=params.adam_beta2,
                                         epsilon=params.adam_epsilon)
        elif params.optimizer == "LazyAdam":
            opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                                   beta1=params.adam_beta1,
                                                   beta2=params.adam_beta2,
                                                   epsilon=params.adam_epsilon)
        else:
            raise RuntimeError("Optimizer %s not supported" % params.optimizer)

        # Accumulate gradients over update_cycle sub-steps before applying
        opt = optimizers.MultiStepOptimizer(opt, params.update_cycle)

        if args.half:
            # Loss scaling guards against fp16 gradient underflow
            opt = optimizers.LossScalingOptimizer(opt, params.loss_scale)

        # Optimization
        grads_and_vars = opt.compute_gradients(
            loss, colocate_gradients_with_ops=True)

        if params.clip_grad_norm:
            grads, var_list = list(zip(*grads_and_vars))
            grads, _ = tf.clip_by_global_norm(grads, params.clip_grad_norm)
            grads_and_vars = zip(grads, var_list)

        train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)

        # Validation (only when both a validation file and references exist)
        if params.validation and params.references[0]:
            files = [params.validation] + list(params.references)
            eval_inputs = dataset.sort_and_zip_files(files)
            eval_input_fn = dataset.get_evaluation_input
        else:
            eval_input_fn = None

        # Hooks
        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(loss),
            tf.train.LoggingTensorHook(
                {
                    "step": global_step,
                    "loss": loss,
                    "source": tf.shape(features["source"]),
                    "target": tf.shape(features["target"])
                },
                every_n_iter=1)
        ]

        # Broadcast initial variables from rank 0 in distributed mode
        broadcast_hook = distribute.get_broadcast_hook()

        if broadcast_hook:
            train_hooks.append(broadcast_hook)

        if distribute.rank() == 0:
            # Add hooks (only the chief saves checkpoints and evaluates)
            save_vars = tf.trainable_variables() + [global_step]
            saver = tf.train.Saver(
                var_list=save_vars if params.only_save_trainable else None,
                max_to_keep=params.keep_checkpoint_max,
                sharded=False)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            # MultiStepHook fires the wrapped hook once per optimizer update,
            # not once per accumulation sub-step
            train_hooks.append(
                hooks.MultiStepHook(tf.train.CheckpointSaverHook(
                    checkpoint_dir=params.output,
                    save_secs=params.save_checkpoint_secs or None,
                    save_steps=params.save_checkpoint_steps or None,
                    saver=saver),
                                    step=params.update_cycle))

            if eval_input_fn is not None:
                train_hooks.append(
                    hooks.MultiStepHook(hooks.EvaluationHook(
                        lambda f: inference.create_inference_graph([model], f,
                                                                   params),
                        lambda: eval_input_fn(eval_inputs, params),
                        lambda x: decode_target_ids(x, params),
                        params.output,
                        session_config(params),
                        device_list=params.device_list,
                        max_to_keep=params.keep_top_checkpoint_max,
                        eval_secs=params.eval_secs,
                        eval_steps=params.eval_steps),
                                        step=params.update_cycle))
            checkpoint_dir = params.output
        else:
            # Non-chief workers must not write checkpoints
            checkpoint_dir = None

        restore_op = restore_variables(args.checkpoint)

        def restore_fn(step_context):
            # Run the restore op inside the already-created monitored session
            step_context.session.run(restore_op)

        # Create session, do not use default CheckpointSaverHook
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=checkpoint_dir,
                hooks=train_hooks,
                save_checkpoint_secs=None,
                config=session_config(params)) as sess:
            # Restore pre-trained variables
            sess.run_step_fn(restore_fn)

            while not sess.should_stop():
                sess.run(train_op)
Esempio n. 25
0
def main(args):
    """Phrase-constrained beam-search decoding.

    Restores ``args.model`` from ``args.checkpoint``, encodes each input
    sentence once, then runs a stack-based beam search that repeatedly
    calls the incremental decoder under phrase-table constraints, and
    writes the best hypothesis per sentence to ``args.output``.

    Fixes over the previous revision (behavior otherwise unchanged):
      * the session-creator construction and the checkpoint-load /
        ``set_variables`` sequence were duplicated verbatim — one copy
        removed;
      * an unused ``tf.gfile.Open(args.output, "w")`` handle (opened on
        the same path as ``fout``) is no longer created;
      * dead ``p_enc`` placeholders removed (the inference functions are
        fed ``placeholder``);
      * ``state_init["decoder"]`` is created once instead of being
        re-created per layer (previously only the last layer survived);
      * the phrase-table file handle is closed, and ``fout`` is closed
        when decoding finishes.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.checkpoint, args.model, params)
    override_parameters(params, args)

    # Build Graph
    with tf.Graph().as_default():
        model = model_cls(params)
        inputs = read_files([args.input])[0]

        # Load phrase table (close the handle instead of leaking it)
        with open(args.phrase, 'r') as f:
            phrase_table = json.load(f)

        # Inverse vocabularies (token -> id)
        ivocab_src = build_ivocab(params.vocabulary["source"])
        ivocab_trg = build_ivocab(params.vocabulary["target"])

        score_fn = model.get_evaluation_func()
        score_cache_fn = model.get_evaluation_cache_func()

        # Feed placeholders for a batch of (source, target) id sequences.
        placeholder = {}
        placeholder["source"] = tf.placeholder(tf.int32, [None, None],
                                               "source")
        placeholder["source_length"] = tf.placeholder(tf.int32, [None],
                                                      "source_length")
        placeholder["target"] = tf.placeholder(tf.int32, [None, None],
                                               "target")
        placeholder["target_length"] = tf.placeholder(tf.int32, [None],
                                                      "target_length")
        scores = score_fn(placeholder, params)

        # Decoder-state placeholders: encoder output plus per-layer
        # key/value caches for incremental decoding.
        state = {
            "encoder":
            tf.placeholder(tf.float32, [None, None, params.hidden_size],
                           "encoder"),
            "decoder": {
                "layer_%d" % i: {
                    "key":
                    tf.placeholder(tf.float32,
                                   [None, None, params.hidden_size],
                                   "decoder_key"),
                    "value":
                    tf.placeholder(tf.float32,
                                   [None, None, params.hidden_size],
                                   "decoder_value")
                }
                for i in range(params.num_decoder_layers)
            }
        }
        scores_cache = score_cache_fn(placeholder, state, params)

        # Encoder / single-step decoder ops for incremental inference.
        enc_fn, dec_fn = model.get_inference_func()
        enc = enc_fn(placeholder, params)
        dec = dec_fn(placeholder, state, params)

        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params))

        # Load checkpoint values belonging to this model.
        tf.logging.info("Loading %s" % args.checkpoint)
        var_list = tf.train.list_variables(args.checkpoint)
        values = {}
        reader = tf.train.load_checkpoint(args.checkpoint)

        for (name, shape) in var_list:
            if not name.startswith(model_cls.get_name()):
                continue

            tensor = reader.get_tensor(name)
            values[name] = tensor

        ops = set_variables(tf.trainable_variables(), values,
                            model_cls.get_name())
        assign_op = tf.group(*ops)

        # Create session and restore variables
        sess = tf.train.MonitoredSession(session_creator=sess_creator)
        sess.run(assign_op)

        fout = open(args.output, 'w')
        count = 0

        for input in inputs:
            count += 1
            print(count)
            start = time.time()
            src = copy.deepcopy(input)
            src = src.decode('utf-8')
            words = src.split(' ')

            # Encode the source sentence once per input.
            f_src = {}
            f_src["source"] = [getid(ivocab_src, input)]
            f_src["source_length"] = [len(f_src["source"][0])]
            feed_src = {
                placeholder["source"]: f_src["source"],
                placeholder["source_length"]: f_src["source_length"]
            }
            encoder_state = sess.run(enc, feed_dict=feed_src)

            # generate a subset of phrase table for current translation
            phrases = subset(phrase_table, words, args.ngram, rbpe=args.rbpe)
            phrases_reverse = reverse_phrase(phrases)
            print('source:', src.encode('utf-8'))
            if args.rbpe:
                # Constraints are matched on the de-BPE'd word sequence
                words = reverse_bpe(src).split(' ')
                print('reverse_bpe:', reverse_bpe(src).encode('utf-8'))
            coverage = [0] * len(words)

            if args.verbose:
                print_phrases(phrases)

            # Zero-length decoder cache for every layer.
            # NOTE(review): this dict appears unused below — kept for
            # compatibility, but consider removing it entirely.
            state_init = {}
            state_init["encoder"] = encoder_state
            state_init["decoder"] = {}
            for i in range(params.num_decoder_layers):
                state_init["decoder"]["layer_%d" % i] = np.zeros(
                    (0, params.hidden_size))
            '''
            stacks:
            1. partial translation
            2. coverage status (set), [coverage, status, align_log_prob]
            status (set): ['normal',  ''] or ['limited', limited word] 
            3. hidden state ({"layer_0": [...], "layer_1": [...]})
            4. [id from last stack, last status id]
            5. score
            '''
            stacks = []
            stack_current = [[
                '', {
                    json.dumps([coverage, ['normal', ''], 0.]): 1
                },
                getstate(encoder_state, params.num_decoder_layers), [0, 0], 0
            ]]
            stacks.append(stack_current)
            finished = []
            length = 0

            time_neural = 0
            time_update = 0
            while True:
                if args.verbose:
                    print('===', length, '===')
                time_a = time.time()
                if len(stack_current) == 0:
                    break
                # Keep only the best beam_size hypotheses by score.
                stack_current = sorted(stack_current,
                                       key=lambda x: x[-1],
                                       reverse=True)
                stack_current = stack_current[:params.beam_size]
                if args.verbose:
                    print_stack(stack_current)
                    print([len(ss[1]) for ss in stack_current])
                time_b = time.time()

                if len(stack_current) == 0:
                    continue

                # Assemble one batched decoder step: target prefixes plus the
                # tiled encoder output and each hypothesis' per-layer cache.
                features = get_feature(stack_current, ivocab_trg)
                features["encoder"] = np.tile(encoder_state["encoder"],
                                              (len(stack_current), 1, 1))
                features["source"] = [getid(ivocab_src, input)
                                      ] * len(stack_current)
                features["source_length"] = [len(features["source"][0])
                                             ] * len(stack_current)
                features["decoder"] = {}
                for i in range(params.num_decoder_layers):
                    features["decoder"]["layer_%d" % i] = merge_tensor(
                        stack_current, i)

                feed_dict = {
                    placeholder['source']: features['source'],
                    placeholder['source_length']: features['source_length'],
                    placeholder['target']: features['target'],
                    placeholder['target_length']: features['target_length']
                }
                dict_tmp = {state["encoder"]: features["encoder"]}
                feed_dict.update(dict_tmp)
                dict_tmp = {
                    state["decoder"]["layer_%d" % i]["key"]:
                    features["decoder"]["layer_%d" % i]["key"]
                    for i in range(params.num_decoder_layers)
                }
                feed_dict.update(dict_tmp)
                dict_tmp = {
                    state["decoder"]["layer_%d" % i]["value"]:
                    features["decoder"]["layer_%d" % i]["value"]
                    for i in range(params.num_decoder_layers)
                }
                feed_dict.update(dict_tmp)

                time_bc = time.time()
                log_probs, new_state = sess.run(dec, feed_dict=feed_dict)
                time_c = time.time()
                time_neural += time_c - time_bc
                new_state = outdims(new_state, params.num_decoder_layers)

                # Expand hypotheses under the phrase-table constraints.
                new_stack, finished = update_stack(stack_current, finished,
                                                   log_probs, new_state, words,
                                                   phrases, phrases_reverse,
                                                   ivocab_trg, params)
                time_d = time.time()
                time_update += time_d - time_c
                finished = sorted(finished, key=lambda x: x[-1], reverse=True)
                finished = finished[:params.beam_size]
                stack_current = new_stack
                stacks.append(stack_current)
                if args.verbose:
                    print_stack_finished(finished)
                length += 1
            print('neural:', time_neural, 's')
            print('total update stack:', time_update, 's')

            if params.cut_ending:
                # Optionally cut a low-scoring tail off the best hypothesis:
                # find the earliest position p (>= 15) from which every
                # remaining average per-token score gain stays above
                # cut_threshold, and truncate there.
                len_max = len(finished[0][0].split(' ')) - 1
                len_now = len_max
                loss = [0] * (len_now + 1)
                pos_cut = len_now + 1
                now = stacks[len_now][finished[0][1]]
                while len_now > 0:
                    loss[len_now] = now[-1]
                    now = stacks[len_now - 1][now[3][0]]
                    len_now -= 1
                for p in range(15, len_max + 1):
                    cut = True
                    sum_ = 0.
                    for tp in range(p, len_max + 1):
                        sum_ = loss[p - 1] - loss[tp]
                        if sum_ / (tp - p + 1) < params.cut_threshold:
                            cut = False
                            break
                    if cut:
                        pos_cut = p
                        break
                result = ' '.join(finished[0][0].split(' ')[:pos_cut - 1])

                print('loss:', loss)
            else:
                result = finished[0][0]

            fout.write(
                (result.replace(' <eos>', '').strip() + '\n').encode('utf-8'))
            print((result.replace(' <eos>', '').strip()).encode('utf-8'))
            print((finished[0][0].replace(' <eos>',
                                          '').strip()).encode('utf-8'))

            end = time.time()
            global time_totalsp
            if args.time:
                print('time total sp:', time_totalsp, 's')
                print('time:', end - start, 'seconds')

            if args.verbose:
                # Trace the best hypothesis back through the stacks and print
                # each step's alignment decision.
                len_now = len(finished[0][0].split(' ')) - 1
                now = stacks[len_now][finished[0][1]]
                nowst = None
                while len_now > 0:
                    print("===", len_now, now[-1], "===")
                    print(now[0].encode('utf-8'), now[1], now[3])
                    find_translate(now, stacks[len_now - 1][now[3][0]], nowst,
                                   now[3][1], words)
                    nowst = now[3][1]
                    now = stacks[len_now - 1][now[3][0]]
                    len_now -= 1

        fout.close()
Esempio n. 26
0
def main(args):
    """Train a coarse-to-fine (4-layer) model with gradient accumulation.

    Builds a training graph whose loss is the MLE loss plus the mean of
    four auxiliary per-layer losses, caches input features across
    ``update_cycle`` sub-steps, and drives the zero/collect/train op cycle
    through the raw underlying session so intermediate accumulation steps
    bypass the session hooks.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    export_params(params.output, "params.json", params)
    export_params(params.output, "%s.json" % args.model,
                  collect_params(params, model_cls.get_parameters()))

    # Build Graph
    with tf.Graph().as_default():
        if not params.record:
            # Build input queue (paired training input + c2f labels)
            features = dataset_c2f_4layers.get_training_input_and_c2f_label(
                params.input, params.c2f_input, params)
        else:
            # Read pre-processed TFRecord input instead of raw text
            features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)

        # Cache features so gradients can be accumulated over
        # update_cycle sub-steps before a single optimizer update
        update_cycle = params.update_cycle
        features, init_op = cache.cache_features(features, update_cycle)

        # Build model
        initializer = get_initializer(params)
        regularizer = tf.contrib.layers.l1_l2_regularizer(
            scale_l1=params.scale_l1, scale_l2=params.scale_l2)
        model = model_cls(params)
        # Create global step
        global_step = tf.train.get_or_create_global_step()

        # Multi-GPU setting: average each of the five loss components
        # across device shards when more than one device is used
        sharded_losses = parallel.parallel_model(
            model.get_training_func(initializer, regularizer), features,
            params.device_list)
        if len(sharded_losses) > 1:
            losses_mle, losses_l1, losses_l2, losses_l3, losses_l4 = sharded_losses
            loss_mle = tf.add_n(losses_mle) / len(losses_mle)
            loss_l1 = tf.add_n(losses_l1) / len(losses_l1)
            loss_l2 = tf.add_n(losses_l2) / len(losses_l2)
            loss_l3 = tf.add_n(losses_l3) / len(losses_l3)
            loss_l4 = tf.add_n(losses_l4) / len(losses_l4)
        else:
            loss_mle, loss_l1, loss_l2, loss_l3, loss_l4 = sharded_losses[0]
        # Total loss: MLE + mean of the four layer losses + regularization
        loss = loss_mle + (loss_l1 + loss_l2 + loss_l3 + loss_l4
                           ) / 4.0 + tf.losses.get_regularization_loss()

        # Print parameters
        all_weights = {v.name: v for v in tf.trainable_variables()}
        total_size = 0

        for v_name in sorted(list(all_weights)):
            v = all_weights[v_name]
            tf.logging.info("%s\tshape    %s", v.name[:-2].ljust(80),
                            str(v.shape).ljust(20))
            v_size = np.prod(np.array(v.shape.as_list())).tolist()
            total_size += v_size
        tf.logging.info("Total trainable variables size: %d", total_size)

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        if params.optimizer == "Adam":
            opt = tf.train.AdamOptimizer(learning_rate,
                                         beta1=params.adam_beta1,
                                         beta2=params.adam_beta2,
                                         epsilon=params.adam_epsilon)
        elif params.optimizer == "LazyAdam":
            opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                                   beta1=params.adam_beta1,
                                                   beta2=params.adam_beta2,
                                                   epsilon=params.adam_epsilon)
        else:
            raise RuntimeError("Optimizer %s not supported" % params.optimizer)

        # ops contains zero_op / collect_op / train_op for accumulation
        loss, ops = optimize.create_train_op(loss, opt, global_step, params)
        restore_op = restore_variables(args.checkpoint)

        #import ipdb; ipdb.set_trace()

        # Validation (only when both a validation file and references exist)
        if params.validation and params.references[0]:
            files = [params.validation] + list(params.references)
            eval_inputs = dataset_c2f_4layers.sort_and_zip_files(files)
            eval_input_fn = dataset_c2f_4layers.get_evaluation_input
        else:
            eval_input_fn = None

        # Add hooks
        save_vars = tf.trainable_variables() + [global_step]
        saver = tf.train.Saver(
            var_list=save_vars if params.only_save_trainable else None,
            max_to_keep=params.keep_checkpoint_max,
            sharded=False)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)

        # Logged shapes are scaled by update_cycle so they reflect the
        # effective (accumulated) batch, not a single sub-step
        multiplier = tf.convert_to_tensor([update_cycle, 1])

        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(loss),
            tf.train.LoggingTensorHook(
                {
                    "step": global_step,
                    "loss_mle": loss_mle,
                    "loss_l1": loss_l1,
                    "loss_l2": loss_l2,
                    "loss_l3": loss_l3,
                    "loss_l4": loss_l4,
                    "source": tf.shape(features["source"]) * multiplier,
                    "target": tf.shape(features["target"]) * multiplier
                },
                every_n_iter=1),
            tf.train.CheckpointSaverHook(
                checkpoint_dir=params.output,
                save_secs=params.save_checkpoint_secs or None,
                save_steps=params.save_checkpoint_steps or None,
                saver=saver)
        ]

        config = session_config(params)

        if eval_input_fn is not None:
            train_hooks.append(
                hooks.EvaluationHook(
                    lambda f: inference.create_inference_graph([model], f,
                                                               params),
                    lambda: eval_input_fn(eval_inputs, params),
                    lambda x: decode_target_ids(x, params),
                    params.output,
                    config,
                    params.keep_top_checkpoint_max,
                    eval_secs=params.eval_secs,
                    eval_steps=params.eval_steps))

        #def restore_fn(step_context):
        #    step_context.session.run(restore_op)

        #def step_fn(step_context):
        #    # Bypass hook calls
        #    step_context.session.run([init_op, ops["zero_op"]])
        #    for i in range(update_cycle - 1):
        #        step_context.session.run(ops["collect_op"])

        #    return step_context.run_with_hooks(ops["train_op"])

        # Create session, do not use default CheckpointSaverHook
        with tf.train.MonitoredTrainingSession(checkpoint_dir=params.output,
                                               hooks=train_hooks,
                                               save_checkpoint_secs=None,
                                               config=config) as sess:
            # Restore pre-trained variables
            # _tf_sess() is used deliberately so these runs bypass the hooks
            # (see the commented-out step_fn alternative above)
            sess._tf_sess().run(restore_op)

            while not sess.should_stop():
                # Zero accumulators, collect update_cycle - 1 sub-steps
                # without hooks, then apply one hooked training step
                sess._tf_sess().run([init_op, ops["zero_op"]])
                for i in range(update_cycle - 1):
                    sess._tf_sess().run(ops["collect_op"])
                sess.run(ops["train_op"])
Esempio n. 27
0
def main(args):
    """PyTorch distributed training entry point.

    Initializes the NCCL process group, builds the model and optimizer
    (with optional fp16 loss scaling and multi-step gradient
    accumulation), resumes from the latest or a given checkpoint, then
    runs the epoch loop with periodic logging, evaluation and
    checkpointing until ``params.train_steps`` is reached.
    """
    model_cls = models.get_model(args.model)

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = default_params()
    params = merge_params(params, model_cls.default_params(args.hparam_set))
    params = import_params(args.output, args.model, params)
    params = override_params(params, args)

    # Initialize distributed utility
    if args.distributed:
        # Launched by an external launcher: env-based rendezvous
        dist.init_process_group("nccl")
        torch.cuda.set_device(args.local_rank)
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        # Single-node multi-process: explicit URL rendezvous over the
        # devices listed in params.device_list
        dist.init_process_group("nccl",
                                init_method=args.url,
                                rank=args.local_rank,
                                world_size=len(params.device_list))
        torch.cuda.set_device(params.device_list[args.local_rank])
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # Export parameters (rank 0 only, to avoid concurrent writes)
    if dist.get_rank() == 0:
        export_params(params.output, "params.json", params)
        export_params(params.output, "%s.json" % params.model,
                      collect_params(params, model_cls.default_params()))

    model = model_cls(params).cuda()

    if args.half:
        # fp16 training: halve the model and make new tensors fp16 too
        model = model.half()
        torch.set_default_dtype(torch.half)
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    model.train()

    # Init tensorboard
    summary.init(params.output, params.save_summary)

    schedule = get_learning_rate_schedule(params)
    clipper = get_clipper(params)

    if params.optimizer.lower() == "adam":
        optimizer = optimizers.AdamOptimizer(learning_rate=schedule,
                                             beta_1=params.adam_beta1,
                                             beta_2=params.adam_beta2,
                                             epsilon=params.adam_epsilon,
                                             clipper=clipper,
                                             summaries=params.save_summary)
    elif params.optimizer.lower() == "adadelta":
        optimizer = optimizers.AdadeltaOptimizer(
            learning_rate=schedule,
            rho=params.adadelta_rho,
            epsilon=params.adadelta_epsilon,
            clipper=clipper,
            summaries=params.save_summary)
    elif params.optimizer.lower() == "sgd":
        optimizer = optimizers.SGDOptimizer(learning_rate=schedule,
                                            clipper=clipper,
                                            summaries=params.save_summary)
    else:
        raise ValueError("Unknown optimizer %s" % params.optimizer)

    if args.half:
        # Loss scaling guards against fp16 gradient underflow
        optimizer = optimizers.LossScalingOptimizer(optimizer)

    # Accumulate gradients over update_cycle sub-steps per update
    optimizer = optimizers.MultiStepOptimizer(optimizer, params.update_cycle)

    if dist.get_rank() == 0:
        print_variables(model)

    if params.from_torchtext:
        dataset = data.get_dataset_torchtext(params.input, "train", params)
    else:
        dataset = data.get_dataset(params.input, "train", params)

    if params.validation:
        if params.from_torchtext:
            eval_dataset = data.get_dataset_torchtext(params.validation,
                                                      "infer", params)
        else:
            eval_dataset = data.get_dataset(params.validation, "infer", params)
        references = load_references(params.references)
    else:
        eval_dataset = None
        references = None

    # Load checkpoint
    checkpoint = utils.latest_checkpoint(params.output)

    if args.checkpoint is not None:
        # Load pre-trained models (weights only, counters reset;
        # broadcast so all ranks start from identical weights)
        state = torch.load(args.checkpoint, map_location="cpu")
        model.load_state_dict(state["model"], strict=False)
        step = params.initial_step
        epoch = 0
        broadcast(model)
    elif checkpoint is not None:
        # Resume training: restore step/epoch and optimizer state too
        state = torch.load(checkpoint, map_location="cpu")
        step = state["step"]
        epoch = state["epoch"]
        model.load_state_dict(state["model"])

        if "optimizer" in state:
            optimizer.load_state_dict(state["optimizer"])
    else:
        # Fresh start
        step = 0
        epoch = 0
        broadcast(model)

    def train_fn(inputs):
        # One forward pass; returns (loss, model state)
        features, labels = inputs
        loss, state = model(features, labels)
        return loss, state

    counter = 0
    state = None
    # NOTE(review): last_feature is only initialized when
    # params.model == "cachedtransformer", but it is read below when
    # model.name == "cachedtransformer" — if those two ever differ this
    # raises NameError; confirm they always match.
    if params.model == "cachedtransformer":
        last_feature = None

    while True:
        start_time = time.time()

        for features in dataset:
            # A "step" is one optimizer update, i.e. update_cycle batches
            if counter % params.update_cycle == 0:
                step += 1
                utils.set_global_step(step)

            counter += 1
            t = time.time()
            features = data.lookup(features,
                                   "train",
                                   params,
                                   from_torchtext=params.from_torchtext)
            if model.name == "cachedtransformer":
                # Feed the previous batch/state into the cache mechanism
                features = utils.update_cache(model, features, state,
                                              last_feature)
                last_feature = features[0]
            loss, state = train_fn(features)
            gradients = optimizer.compute_gradients(loss,
                                                    list(model.parameters()))
            # Drop gradients for variables matching params.pattern
            grads_and_vars = optimizers.exclude_variables(
                params.pattern, zip(gradients, list(model.named_parameters())))
            optimizer.apply_gradients(grads_and_vars)

            t = time.time() - t

            summary.scalar("loss", loss, step, write_every_n_steps=1)
            summary.scalar("global_step/sec", t, step)

            # Logging / eval / checkpointing only on update boundaries
            if counter % params.update_cycle == 0:
                if step > 0 and step % args.log_interval == 0:
                    elapsed = time.time() - start_time
                    print('| epoch {:2d} | step {:8d} | lr {:02.2e} | '
                          'ms/step {:4.0f} | loss {:8.4f} '.format(
                              epoch + 1, step,
                              optimizer._optimizer._learning_rate(step),
                              elapsed * 1000 / args.log_interval, loss.item()))
                    start_time = time.time()

                if step >= params.train_steps:
                    # Final evaluation and checkpoint, then exit
                    utils.evaluate(model, eval_dataset, params.output,
                                   references, params)
                    save_checkpoint(step, epoch, model, optimizer, params)

                    if dist.get_rank() == 0:
                        summary.close()

                    return

                if step % params.eval_steps == 0:
                    utils.evaluate(model, eval_dataset, params.output,
                                   references, params)
                    start_time = time.time()

                if step % params.save_checkpoint_steps == 0:
                    save_checkpoint(step, epoch, model, optimizer, params)
                    start_time = time.time()

        epoch += 1
Esempio n. 28
0
def main(args):
    """Run two-step encoder/decoder inference over an input file.

    Loads one checkpoint per model in ``args.checkpoints`` (only the first
    model is actually used for inference), builds separate encoder and
    decoder graphs, feeds every batch from ``args.input`` through the
    encoder and then the decoder, and writes the rounded decoder outputs
    to ``args.output`` — one line of space-separated values per input
    sentence, in the original input order.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs.  Priority (low -> high): default -> saved -> command line.
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Read variable values from every checkpoint, keeping only tensors
        # that belong to the corresponding model and skipping the
        # "losses_avg" bookkeeping variables.
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                values[name] = reader.get_tensor(name)

            model_var_lists.append(values)

        # Build models.  Each model gets a unique "_<i>" scope suffix so
        # variables restored from different checkpoints do not collide.
        model_list = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_list.append(model)

        params = params_list[0]

        # Placeholders shared by the encoder and decoder graphs.
        placeholder = {}
        placeholder["source"] = tf.placeholder(tf.int32, [None, None],
                                               "source")
        placeholder["source_length"] = tf.placeholder(tf.int32, [None],
                                                      "source_length")

        # Only the first model's inference functions are used.
        enc_fn, dec_fn = model_list[0].get_inference_func()
        enc = enc_fn(placeholder, params)
        state = {}
        state["encoder"] = tf.placeholder(tf.float32,
                                          [None, None, params.hidden_size],
                                          "encoder")
        dec = dec_fn(placeholder, state, params)

        # Read input file (sorted by length for efficient batching) and
        # build the input queue.
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
        features = dataset.get_inference_input(sorted_inputs, params)

        # Create assign ops that copy the checkpoint values into the
        # per-model variable scopes.
        assign_ops = []
        feed_dict = {}

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i, feed_dict)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        init_op = tf.tables_initializer()
        results = []

        tf.get_default_graph().finalize()

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            # Restore variables
            sess.run(assign_op, feed_dict=feed_dict)
            sess.run(init_op)

            total_start = time.time()
            while True:
                start = time.time()
                try:
                    feats = sess.run(features)
                    feed_dict = {
                        placeholder["source"]: feats["source"],
                        placeholder["source_length"]: feats["source_length"]
                    }
                    # Two-step inference: run the encoder first, then feed
                    # its output into the decoder graph.
                    encoder_output = sess.run(enc, feed_dict=feed_dict)
                    encoder_output = encoder_output['encoder']
                    feed_dict_dec = {
                        placeholder["source"]: feats["source"],
                        placeholder["source_length"]: feats["source_length"],
                        state["encoder"]: encoder_output
                    }
                    result = sess.run(dec, feed_dict=feed_dict_dec)
                    results.append(result)
                    message = "Finished batch %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                    end = time.time()
                    print('time:', end - start, 's')
                except tf.errors.OutOfRangeError:
                    break
            total_end = time.time()
            print('total time:', total_end - total_start, 's')

        # Flatten the per-batch decoder outputs into one list of
        # per-sentence value sequences.
        outputs = []

        for result in results:
            for item in result:
                outputs.append(item.tolist())

        # Undo the length-based sorting so outputs line up with the
        # original input order.
        restored_outputs = []

        for index in range(len(sorted_inputs)):
            restored_outputs.append(outputs[sorted_keys[index]])

        # Write to file: one line of space-separated, rounded values per
        # input sentence.
        with open(args.output, "w") as outfile:
            for output_values in restored_outputs:
                for output in output_values:
                    outfile.write(str(round(output, 2)) + ' ')
                outfile.write('\n')
# Esempio n. 29
# 0
def main(args):
    """Train and evaluate a small attention-based classifier on top of
    frozen pre-trained model representations.

    The pre-trained models listed in ``args.models`` are restored from
    ``args.checkpoints`` and kept frozen; only the variables created under
    the "my_metric" scope are trained (Adam, lr=0.001).  The classifier
    predicts two source positions per example (labels come in the 2-column
    ``target`` feature).  Every 100 training batches it evaluates on up to
    20 eval batches and up to 20 test batches, printing accuracies and
    writing the two predicted positions per test example to
    ``args.output``.
    """
    # NOTE(review): despite its name, ``eval_steps`` is used below as the
    # number of outer training sweeps, not an evaluation interval --
    # confirm against the argument parser.
    eval_steps = args.eval_steps
    tf.logging.set_verbosity(tf.logging.DEBUG)
    # Load configs.  Priority (low -> high): default -> saved -> command line.
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Read variable values from every checkpoint, keeping only tensors
        # that belong to the corresponding model and skipping the
        # "losses_avg" bookkeeping variables.
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                values[name] = reader.get_tensor(name)

            model_var_lists.append(values)

        # Build models; only their inference functions are kept (the
        # models themselves stay frozen during training).
        model_fns = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]

        # Input pipelines: training data is re-shuffled/repeated, eval and
        # test are single-pass (flags presumably mean (sorted, repeat) --
        # confirm against dataset.get_inference_input).
        features_train = dataset.get_inference_input(args.input, params, False,
                                                     True)
        features_eval = dataset.get_inference_input(args.eval, params, True,
                                                    False)
        features_test = dataset.get_inference_input(args.test, params, True,
                                                    False)

        # Create placeholders (one set per device).
        placeholders = []

        for i in range(len(params.device_list)):
            placeholders.append({
                "source":
                tf.placeholder(tf.int32, [None, None], "source_%d" % i),
                "source_length":
                tf.placeholder(tf.int32, [None], "source_length_%d" % i),
                "target":
                tf.placeholder(tf.int32, [None, 2], "target_%d" % i)
            })

        # A list of per-device representation outputs.
        predictions = parallel.data_parallelism(
            params.device_list,
            lambda f: inference.create_inference_graph(model_fns, f, params),
            placeholders)

        # Create assign ops that copy the checkpoint values into the
        # per-model variable scopes.
        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        # Labels (two target positions per example) and source lengths.
        tf_y = tf.placeholder(tf.int32, [None, 2])
        tf_x_len = tf.placeholder(tf.int32, [None])

        # Additive mask pushing padded positions to -inf before softmax.
        src_mask = -1e9 * (1.0 - tf.sequence_mask(
            tf_x_len, maxlen=tf.shape(predictions[0])[1], dtype=tf.float32))
        with tf.variable_scope("my_metric"):
            # Project the frozen representations to query/key/value.
            q, k, v = tf.split(nn.linear(predictions[0],
                                         3 * 512,
                                         True,
                                         True,
                                         scope="logit_transform"),
                               [512, 512, 512],
                               axis=-1)
            # First position: attention logits over source positions.
            q = nn.linear(
                tf.nn.tanh(q), 1, True, True,
                scope="logit_transform2")[:, :, 0] + src_mask
            ce1 = nn.smoothed_softmax_cross_entropy_with_logits(
                logits=q,
                labels=tf_y[:, :1],
                smoothing=False,
                normalize=True)
            w1 = tf.nn.softmax(q)[:, None, :]
            # Second position: scaled dot-product between the keys and the
            # first position's context vector.
            k = tf.matmul(k,
                          tf.matmul(w1, v) *
                          (512**-0.5), False, True)[:, :, 0] + src_mask
            ce2 = nn.smoothed_softmax_cross_entropy_with_logits(
                logits=k,
                labels=tf_y[:, 1:],
                smoothing=False,
                normalize=True)
            w2 = tf.nn.softmax(k)[:, None, :]
            weights = tf.concat([w1, w2], axis=1)
        loss = tf.reduce_mean(ce1 + ce2)

        # Predicted positions and streaming accuracy metrics (first
        # position, second position, and both combined).
        o1 = tf.argmax(w1, axis=-1)
        o2 = tf.argmax(w2, axis=-1)
        a1, a1_update = tf.metrics.accuracy(labels=tf.squeeze(tf_y[:, 0]),
                                            predictions=tf.argmax(w1, axis=-1),
                                            name='a1')
        a2, a2_update = tf.metrics.accuracy(labels=tf.squeeze(tf_y[:, 1]),
                                            predictions=tf.argmax(w2, axis=-1),
                                            name='a2')
        accuracy, accuracy_update = tf.metrics.accuracy(
            labels=tf.squeeze(tf_y),
            predictions=tf.argmax(weights, axis=-1),
            name='a_all')

        # The streaming metric counters live in LOCAL_VARIABLES; they are
        # re-initialized before every metric update below.
        running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                         scope="my_metric")
        running_vars_initializer = tf.variables_initializer(
            var_list=running_vars)

        # Only the classifier head is trained; the restored models stay
        # frozen.
        variables_to_train = [
            v for v in tf.trainable_variables()
            if v.name.startswith("my_metric")
        ]

        optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
        train_op = optimizer.minimize(loss, var_list=variables_to_train)

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)
            # Restore variables
            sess.run(assign_op)
            sess.run(tf.tables_initializer())

            current_step = 0

            best_validate_acc = 0
            last_test_acc = 0

            while current_step < eval_steps:
                print('=======current step ' + str(current_step))
                batch_num = 0
                while True:
                    try:
                        feats = sess.run(features_train)
                        op, feed_dict = shard_features(feats, placeholders,
                                                       predictions)
                        # NOTE(review): relies on the feed dict preserving
                        # (source, source_length, target) insertion order --
                        # confirm against shard_features.  list() is required
                        # on Python 3, where dict views are not indexable.
                        y = list(feed_dict.values())[2]
                        x_len = list(feed_dict.values())[1]

                        feed_dict.update({tf_y: y})
                        feed_dict.update({tf_x_len: x_len})

                        los, __, pred = sess.run([loss, train_op, weights],
                                                 feed_dict=feed_dict)
                        print("current_step", current_step, "batch_num",
                              batch_num, "loss", los)

                        batch_num += 1
                        if batch_num % 100 == 0:

                            # Evaluation: up to 20 batches of the eval set.
                            b_total = 0
                            a_total = 0
                            a1_total = 0
                            a2_total = 0
                            validate_acc = 0
                            batch_num_eval = 0

                            while True:
                                try:
                                    feats_eval = sess.run(features_eval)
                                    op, feed_dict_eval = shard_features(
                                        feats_eval, placeholders, predictions)
                                    y = list(feed_dict_eval.values())[2]
                                    x_len = list(feed_dict_eval.values())[1]
                                    feed_dict_eval.update({tf_y: y})
                                    feed_dict_eval.update({tf_x_len: x_len})

                                    sess.run(running_vars_initializer)
                                    acc = 0
                                    sess.run([
                                        a1_update, a2_update, accuracy_update,
                                        weights
                                    ],
                                             feed_dict=feed_dict_eval)
                                    acc1, acc2, acc = sess.run(
                                        [a1, a2, accuracy])
                                    batch_size = len(y)
                                    a1_total += round(batch_size * acc1)
                                    a2_total += round(batch_size * acc2)
                                    a_total += round(batch_size * acc)
                                    b_total += batch_size
                                    batch_num_eval += 1

                                    if batch_num_eval == 20:
                                        break

                                except tf.errors.OutOfRangeError:
                                    print("eval out of range")
                                    break
                            if b_total:
                                validate_acc = a_total / b_total
                                print("eval acc : " + str(validate_acc) +
                                      "( " + str(a1_total / b_total) + ", " +
                                      str(a2_total / b_total) + " )")
                            print("last test acc : " + str(last_test_acc))

                            if validate_acc > best_validate_acc:
                                best_validate_acc = validate_acc

                            # Test: up to 20 batches, writing the two
                            # predicted positions per example to the output
                            # file (overwritten on every test pass).
                            b_total = 0
                            a1_total = 0
                            a2_total = 0
                            a_total = 0
                            batch_num_test = 0
                            with open(args.output, "w") as outfile:
                                while True:
                                    try:
                                        feats_test = sess.run(features_test)
                                        op, feed_dict_test = shard_features(
                                            feats_test, placeholders,
                                            predictions)

                                        y = list(feed_dict_test.values())[2]
                                        x_len = list(
                                            feed_dict_test.values())[1]
                                        feed_dict_test.update({tf_y: y})
                                        feed_dict_test.update(
                                            {tf_x_len: x_len})

                                        sess.run(running_vars_initializer)
                                        acc = 0
                                        __, __, __, out1, out2 = sess.run(
                                            [
                                                a1_update, a2_update,
                                                accuracy_update, o1, o2
                                            ],
                                            feed_dict=feed_dict_test)
                                        acc1, acc2, acc = sess.run(
                                            [a1, a2, accuracy])

                                        batch_size = len(y)
                                        a_total += round(batch_size * acc)
                                        a1_total += round(batch_size * acc1)
                                        a2_total += round(batch_size * acc2)
                                        b_total += batch_size
                                        batch_num_test += 1
                                        for pred1, pred2 in zip(out1, out2):
                                            outfile.write("%s " % pred1[0])
                                            outfile.write("%s\n" % pred2[0])
                                        if batch_num_test == 20:
                                            break
                                    except tf.errors.OutOfRangeError:
                                        print("test out of range")
                                        break
                                if b_total:
                                    last_test_acc = a_total / b_total
                                    print("new test acc : " +
                                          str(last_test_acc) + "( " +
                                          str(a1_total / b_total) + ", " +
                                          str(a2_total / b_total) + " )")

                        if batch_num == 25000:
                            break
                    except tf.errors.OutOfRangeError:
                        print("train out of range")
                        break

                current_step += 1
                print("")
        print("Final test acc " + str(last_test_acc))

        return