import math
import os
import sys
import time

import numpy as np
import tensorflow as tf

# Project-local helpers (load_model, NMT, Optimizer, beamsearch, convert_data,
# TextReader, TextIterator, ops, save_model, validate_model, ...) are assumed
# to be imported from elsewhere in this package.


def decode(args):
    option, values = load_model(args.model)
    #option, values = load_average_model(args.model)
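    # let the session allocate GPU memory on demand instead of all at once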
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

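    # each side's vocabulary is a (word -> id, id -> word) pair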
    svocabs, tvocabs = option["vocabulary"]
    svocab, isvocab = svocabs
    tvocab, itvocab = tvocabs

    unk_sym = option["unk"]
    eos_sym = option["eos"]

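    # pretrained source/target word embeddings stored with the model options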
    source_word2vec, target_word2vec = option["word2vecs"]

    count = 0

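    # decoding options forwarded to beamsearch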
    doption = {
        "maxlen": args.maxlen,
        "minlen": args.minlen,
        "beamsize": args.beamsize,
        "normalize": args.normalize
    }

    # create graph
    model = NMT(option["num_layers"], option["num_heads"],
                option["attention_dropout"], option["residual_dropout"],
                option["relu_dropout"],
                option["embedding"], option["hidden"], option["filter"],
                len(isvocab), len(itvocab), source_word2vec, target_word2vec)

    model.option = option

    input_file = open(args.corpus, 'r')
    output_file = open(args.translation, 'w')

    with tf.Session(config=config):
        tf.global_variables_initializer().run()
        set_variables(tf.trainable_variables(), values)

        line = input_file.readline()
        while line:
            data = [line]
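            # map the sentence to an id matrix plus its sequence length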
            seq, _, seq_len = convert_data(data, svocab, unk_sym, eos_sym)
            t1 = time.time()
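            # beam search returns (hypothesis, score) pairs sorted best-first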
            tlist = beamsearch(model, seq, seq_len, **doption)
            t2 = time.time()

            if len(tlist) == 0:
                # no finished hypothesis: emit an empty line so the output
                # file stays aligned with the input
                output_file.write("\n")
                score = -10000.0
            else:
                best, score = tlist[0]
                # drop the trailing end-of-sentence symbol
                output_file.write(" ".join(best[:-1]))
                output_file.write("\n")

            count += 1
            sys.stderr.write(str(count) + " ")
            sys.stderr.write(str(score) + " " + str(t2 - t1) + "\n")
            line = input_file.readline()
    output_file.close()
    input_file.close()


def train(args):
    option = default_option()

    # predefined model names
    pathname, basename = os.path.split(args.model)
    modelname = get_filename(basename)
    autoname = os.path.join(pathname, modelname + ".autosave.pkl")
    bestname = os.path.join(pathname, modelname + ".best.pkl")

    # load models
    if os.path.exists(args.model):
        opt, params = load_model(args.model)
        override(option, opt)
        init = False
    else:
        init = True
        params = None

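    # command-line arguments override any options loaded from the checkpoint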
    override(option, args_to_dict(args))
    print_option(option)

    # load references
    if option["references"]:
        references = load_references(option["references"])
    else:
        references = None

    # input corpus
    batch = option["batch"]
    sortk = option["sort"] or 1
    shuffle = option["seed"] if option["shuffle"] else None
    reader = TextReader(option["corpus"], shuffle)
    processor = [data_length, data_length]
    stream = TextIterator(reader, [batch, batch * sortk], processor,
                          option["limit"], option["sort"])

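    # restore the reader's shuffling state when resuming an interrupted run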
    if shuffle is not None and option["indices"] is not None:
        reader.set_indices(option["indices"])

    if args.reset:
        option["count"] = [0, 0]
        option["epoch"] = 0
        option["cost"] = 0.0

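    # fast-forward past sentences already consumed before the last checkpoint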
    skip_stream(reader, option["count"][1])

    # beamsearch option
    search_opt = {
        "beamsize": option["beamsize"],
        "normalize": option["normalize"],
        "maxlen": option["maxlen"],
        "minlen": option["minlen"]
    }

    # misc
    svocabs, tvocabs = option["vocabulary"]
    svocab, isvocab = svocabs
    tvocab, itvocab = tvocabs
    unk = option["unk"]
    eos = option["eos"]

    source_word2vec, target_word2vec = option["word2vecs"]

    scale = option["scale"]

    # set seed
    np.random.seed(option["seed"])
    tf.set_random_seed(option["seed"])

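    # initialize parameters uniformly in [-scale, scale]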
    initializer = tf.random_uniform_initializer(-scale, scale)
    model = NMT(option["num_layers"],
                option["num_heads"],
                option["attention_dropout"],
                option["residual_dropout"],
                option["relu_dropout"],
                option["embedding"],
                option["hidden"],
                option["filter"],
                len(isvocab),
                len(itvocab),
                source_word2vec,
                target_word2vec,
                initializer=initializer)

    model.option = option

    # create optimizer
    optim = Optimizer(model,
                      algorithm=option["optimizer"],
                      norm=True,
                      constraint=("norm", option["norm"]))

    # create session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config):
        tf.global_variables_initializer().run()

        print "parameters:", count_parameters(tf.trainable_variables())

        if not init:
            set_variables(tf.trainable_variables(), params)

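        # note: train_step_fn below recomputes alpha from the warmup schedule
        # on every step, so the decayed option["alpha"] is only logged here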
        def lr_decay_fn(*args, **kwargs):
            global_step = kwargs["global_step"]
            step = kwargs["step"]
            epoch = kwargs["epoch"]
            option["alpha"] = option["alpha"] * option["decay"]
            msg = "G/E/S: %d/%d/%d  alpha: %f"
            print(msg % (global_step, epoch, step, option["alpha"]))

        def train_step_fn(data, **variables):
            global_step = variables["global_step"]
            step = variables["step"]
            epoch = variables["epoch"]

            xdata, _, xlen = convert_data(data[0], svocab, unk, eos)
            ydata, _, ylen = convert_data(data[1], tvocab, unk, eos)

            t1 = time.time()
            cost, norm = optim.optimize(xdata, xlen, ydata, ylen)
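            # "Noam" schedule from "Attention Is All You Need" (assumes
            # global_step >= 1): d^-0.5 * min(step^-0.5, step * warmup^-1.5)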
            alpha = (1 / float(option["embedding"])**0.5) * min(
                1 / float(global_step)**0.5,
                global_step / float(option["warmup"])**1.5)
            optim.update(alpha=alpha)
            t2 = time.time()
            #cost = cost * len(ylen) / sum(ylen)
            msg = "G/E/S: %d/%d/%d cost: %f norm: %f time: %f"
            print(msg % (global_step, epoch, step, cost, norm, t2 - t1))

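            # convert the per-step loss from nats to bits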
            return cost / math.log(2)

        def sample_fn(*args, **kwargs):
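            # translate one random sentence from the batch for inspection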
            data = args[0]
            batch = len(data[0])
            ind = np.random.randint(0, batch)
            sdata = data[0][ind]
            tdata = data[1][ind]
            xdata, _, xlen = convert_data(data[0], svocab, unk, eos)
            xdata = xdata[ind:ind + 1, :]
            xlen = xlen[ind:ind + 1]
            hls = beamsearch(model, xdata, xlen, **search_opt)
            if not hls:
                # beam search may return no finished hypothesis; skip the sample
                return
            best, score = hls[0]
            print("> " + sdata)
            print("> " + tdata)
            print("> " + " ".join(best[:-1]))

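        # report the locally and globally averaged training cost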
        def cost_summary(*args, **kwargs):
            cost = kwargs["local_cost"]
            global_cost = kwargs["global_cost"]
            step = kwargs["local_step"]
            global_step = kwargs["global_step"]

            ac, gac = cost / step, global_cost / global_step

            print("averaged cost: %f/%f" % (ac, gac))

        def stop_fn(*args, **kwargs):
            if option["maxepoch"] < kwargs["epoch"]:
                raise StopIteration

        def save_fn(*args, **kwargs):
            save_model(model, autoname, reader, option, **kwargs)

        def validate_fn(*args, **kwargs):
            if option["validation"] and references:
                validate_model(model, option["validation"], references,
                               search_opt, bestname, reader, option, **kwargs)

        # learning-rate decay, applied at epoch level
        lr_decay_hook = ops.train_loop.hook(option["stop"], 1, lr_decay_fn)
        # checkpoint every option["freq"] steps and every 2 epochs
        save_hook = ops.train_loop.hook(0, option["freq"], save_fn)
        e_save_hook = ops.train_loop.hook(0, 2, save_fn)
        # show a sample translation every option["sfreq"] steps
        sample_hook = ops.train_loop.hook(0, option["sfreq"], sample_fn)
        # validate every option["vfreq"] steps and at the end of each epoch
        validate_hook = ops.train_loop.hook(0, option["vfreq"], validate_fn)
        e_validate_hook = ops.train_loop.hook(0, 1, validate_fn)
        # summarize the averaged cost once per epoch
        cost_summary_hook = ops.train_loop.hook(0, 1, cost_summary)
        # stop once option["maxepoch"] is exceeded
        stop_hook = ops.train_loop.hook(0, 1, stop_fn)

        global_level_hooks = []
        local_level_hooks = [save_hook, sample_hook, validate_hook]
        epoch_level_hooks = [
            lr_decay_hook, cost_summary_hook, e_save_hook, e_validate_hook,
            stop_hook
        ]

        ops.train_loop.train_loop(stream, train_step_fn, option,
                                  global_level_hooks, local_level_hooks,
                                  epoch_level_hooks)

    stream.close()