Example 1
def gated_rms_norm(x, eps=None, scope=None):
    """RMS-based Layer normalization layer"""
    if eps is None:
        eps = dtype.epsilon()
    with tf.variable_scope(scope or "rms_norm",
                           dtype=tf.as_dtype(dtype.floatx())):
        layer_size = util.shape_list(x)[-1]

        scale = tf.get_variable("scale", [layer_size], initializer=tf.ones_initializer())
        gate = tf.get_variable("gate", [layer_size], initializer=None)

        ms = tf.reduce_mean(x ** 2, -1, keepdims=True)

        # adding gating here which slightly improves quality
        return scale * x * tf.rsqrt(ms + eps) * tf.nn.sigmoid(gate * x)
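
A minimal NumPy sketch of the same computation, for reference only (toy shapes; `scale` and `gate` stand in for the learned variables):

import numpy as np

def gated_rms_norm_np(x, scale, gate, eps=1e-8):
    # RMS statistic over the last axis, without mean subtraction
    ms = np.mean(x ** 2, axis=-1, keepdims=True)
    # RMSNorm output, modulated by a sigmoid gate computed from the raw input
    return scale * x / np.sqrt(ms + eps) * (1.0 / (1.0 + np.exp(-gate * x)))

x = np.random.randn(2, 4, 8)
y = gated_rms_norm_np(x, scale=np.ones(8), gate=np.zeros(8))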
Example 2
def layer_norm(x, eps=None, scope=None):
    """Layer normalization layer"""
    if eps is None:
        eps = dtype.epsilon()
    with tf.variable_scope(scope or "layer_norm",
                           dtype=tf.as_dtype(dtype.floatx())):
        layer_size = util.shape_list(x)[-1]

        scale = tf.get_variable("scale", [layer_size], initializer=tf.ones_initializer())
        offset = tf.get_variable("offset", [layer_size], initializer=tf.zeros_initializer())

        mean = tf.reduce_mean(x, -1, keepdims=True)
        var = tf.reduce_mean((x - mean) ** 2, -1, keepdims=True)

        return scale * (x - mean) * tf.rsqrt(var + eps) + offset
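
For comparison, a minimal NumPy sketch of the standard layer normalization above (toy shapes; `scale` and `offset` stand in for the learned variables):

import numpy as np

def layer_norm_np(x, scale, offset, eps=1e-8):
    # normalize each feature vector to zero mean and unit variance
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.mean((x - mean) ** 2, axis=-1, keepdims=True)
    return scale * (x - mean) / np.sqrt(var + eps) + offset

x = np.random.randn(2, 4, 8)
y = layer_norm_np(x, scale=np.ones(8), offset=np.zeros(8))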
Example 3
    def decoding_fn(target, state, time):
        with tf.variable_scope(
                params.scope_name or "model",
                reuse=tf.AUTO_REUSE,
                dtype=tf.as_dtype(dtype.floatx()),
                custom_getter=dtype.float32_variable_storage_getter):
            if params.search_mode == "cache":
                step_loss, step_logits, step_state, _ = decoder(
                    target, state, params)
            else:
                estate = encoder(state, params)
                estate['dev_decode'] = True
                _, step_logits, _, _ = decoder(target, estate, params)
                step_state = state

            return step_logits, step_state
Example 4
def data_parallelism(device_type, num_devices, fn, *args, **kwargs):
    """Run `fn` on `num_devices` devices, sharding `args`/`kwargs` across towers."""
    # Replicate args and kwargs
    if args:
        new_args = [_maybe_repeat(arg, num_devices) for arg in args]
        # Transpose
        new_args = [list(x) for x in zip(*new_args)]
    else:
        new_args = [[] for _ in range(num_devices)]

    new_kwargs = [{} for _ in range(num_devices)]

    for k, v in kwargs.items():
        vals = _maybe_repeat(v, num_devices)

        for i in range(num_devices):
            new_kwargs[i][k] = vals[i]

    fns = _maybe_repeat(fn, num_devices)

    # Now make the parallel call.
    outputs = []
    for i in range(num_devices):
        worker = "/{}:{}".format(device_type, i)
        if device_type == 'cpu':
            _device_setter = local_device_setter(worker_device=worker)
        else:
            _device_setter = local_device_setter(
                ps_device_type='gpu',
                worker_device=worker,
                ps_strategy=tc.training.GreedyLoadBalancingStrategy(
                    num_devices, tc.training.byte_size_load_fn)
            )

        with tf.variable_scope(tf.get_variable_scope(), reuse=bool(i != 0),
                               dtype=tf.as_dtype(dtype.floatx())):
            with tf.name_scope("tower_%d" % i):
                with tf.device(_device_setter):
                    outputs.append(fns[i](*new_args[i], **new_kwargs[i]))

    return _reshape_output(outputs)
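
The core of the sharding logic is how `args` and `kwargs` are replicated per device and then regrouped into per-tower argument lists. A small pure-Python sketch, using a hypothetical `_maybe_repeat_sketch` that mirrors the behaviour assumed above (a list of length `num_devices` is treated as already sharded, everything else is broadcast):

def _maybe_repeat_sketch(x, n):
    # hypothetical stand-in for _maybe_repeat
    return list(x) if isinstance(x, list) and len(x) == n else [x] * n

def shard_args(num_devices, *args):
    new_args = [_maybe_repeat_sketch(a, num_devices) for a in args]
    # transpose: one argument list per device
    return [list(x) for x in zip(*new_args)]

# e.g. two devices, a pre-sharded batch list plus a shared config dict
print(shard_args(2, ["batch0", "batch1"], {"lr": 0.1}))
# -> [['batch0', {'lr': 0.1}], ['batch1', {'lr': 0.1}]]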
Example 5
def dot_attention(query, memory, mem_mask, hidden_size,
                  ln=False, num_heads=1, cache=None, dropout=None,
                  use_relative_pos=False, max_relative_position=16,
                  out_map=True, scope=None, fuse_mask=None,
                  decode_step=None):
    """
    dotted attention model
    :param query: [batch_size, qey_len, dim]
    :param memory: [batch_size, seq_len, mem_dim] or None
    :param mem_mask: [batch_size, seq_len]
    :param hidden_size: attention space dimension
    :param ln: whether use layer normalization
    :param num_heads: attention head number
    :param dropout: attention dropout, default disable
    :param out_map: output additional mapping
    :param cache: cache-based decoding
    :param fuse_mask: aan mask during training, and timestep for testing
    :param max_relative_position: maximum position considered for relative embedding
    :param use_relative_pos: whether use relative position information
    :param decode_step: the time step of current decoding, 0-based
    :param scope:
    :return: a value matrix, [batch_size, qey_len, mem_dim]
    """
    with tf.variable_scope(scope or "dot_attention", reuse=tf.AUTO_REUSE,
                           dtype=tf.as_dtype(dtype.floatx())):
        if fuse_mask is not None:
            assert memory is not None, 'The fuse mechanism only applies to cross-attention'
        if cache and use_relative_pos:
            assert decode_step is not None, 'decode_step must be provided when using relative position encoding'

        if memory is None:
            # self-attention: compute q, k, v from the queries alone
            h = linear(query, hidden_size * 3, ln=ln, scope="qkv_map")
            q, k, v = tf.split(h, 3, -1)

            if cache is not None:
                k = tf.concat([cache['k'], k], axis=1)
                v = tf.concat([cache['v'], v], axis=1)
                cache = {
                    'k': k,
                    'v': v,
                }
        else:
            q = linear(query, hidden_size, ln=ln, scope="q_map")
            if cache is not None and ('mk' in cache and 'mv' in cache):
                k, v = cache['mk'], cache['mv']
            else:
                k = linear(memory, hidden_size, ln=ln, scope="k_map")
                v = linear(memory, hidden_size, ln=ln, scope="v_map")

            if cache is not None:
                cache['mk'] = k
                cache['mv'] = v

        q = split_heads(q, num_heads)
        k = split_heads(k, num_heads)
        v = split_heads(v, num_heads)

        q *= (hidden_size // num_heads) ** (-0.5)

        q_shp = util.shape_list(q)
        k_shp = util.shape_list(k)
        v_shp = util.shape_list(v)

        q_len = q_shp[2] if decode_step is None else decode_step + 1
        r_lst = None if decode_step is None else 1

        # q * k => attention weights
        if use_relative_pos:
            r = rpr.get_relative_positions_embeddings(
                q_len, k_shp[2], k_shp[3],
                max_relative_position, name="rpr_keys", last=r_lst)
            logits = rpr.relative_attention_inner(q, k, r, transpose=True)
        else:
            logits = tf.matmul(q, k, transpose_b=True)

        if mem_mask is not None:
            logits += mem_mask

        weights = tf.nn.softmax(logits)

        dweights = util.valid_apply_dropout(weights, dropout)

        # weights * v => attention vectors
        if use_relative_pos:
            r = rpr.get_relative_positions_embeddings(
                q_len, k_shp[2], v_shp[3],
                max_relative_position, name="rpr_values", last=r_lst)
            o = rpr.relative_attention_inner(dweights, v, r, transpose=False)
        else:
            o = tf.matmul(dweights, v)

        o = combine_heads(o)

        if fuse_mask is not None:
            # This is for AAN (Average Attention Network); the important part is sharing v_map
            v_q = linear(query, hidden_size, ln=ln, scope="v_map")

            if cache is not None and 'aan' in cache:
                aan_o = (v_q + cache['aan']) / dtype.tf_to_float(fuse_mask + 1)
            else:
                # Simplified Average Attention Network
                aan_o = tf.matmul(fuse_mask, v_q)

            if cache is not None:
                if 'aan' not in cache:
                    cache['aan'] = v_q
                else:
                    cache['aan'] = v_q + cache['aan']

            # Directly sum both self-attention and cross attention
            o = o + aan_o

        if out_map:
            o = linear(o, hidden_size, ln=ln, scope="o_map")

        results = {
            'weights': weights,
            'output': o,
            'cache': cache
        }

        return results
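
For reference, a self-contained NumPy sketch of the standard (non-relative) single-head path above, using the additive mask convention (0 for valid positions, a large negative value for padded ones):

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

def dot_attention_np(q, k, v, mem_mask=None):
    d = q.shape[-1]
    logits = (q / np.sqrt(d)) @ np.swapaxes(k, -1, -2)  # [batch, q_len, k_len]
    if mem_mask is not None:
        logits = logits + mem_mask                      # additive mask
    weights = softmax(logits)
    return weights @ v                                  # [batch, q_len, dim]

q = np.random.randn(1, 3, 8); k = np.random.randn(1, 5, 8); v = np.random.randn(1, 5, 8)
mask = np.zeros((1, 1, 5)); mask[..., 3:] = -1e9        # mask the last two memory positions
out = dot_attention_np(q, k, v, mask)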
Example 6
def dot_attention(query, memory, mem_mask, hidden_size,
                  ln=False, num_heads=1, cache=None, dropout=None,
                  out_map=True, scope=None):
    """
    dotted attention model
    :param query: [batch_size, qey_len, dim]
    :param memory: [batch_size, seq_len, mem_dim] or None
    :param mem_mask: [batch_size, seq_len]
    :param hidden_size: attention space dimension
    :param ln: whether use layer normalization
    :param num_heads: attention head number
    :param dropout: attention dropout, default disable
    :param out_map: output additional mapping
    :param cache: cache-based decoding
    :param scope:
    :return: a value matrix, [batch_size, qey_len, mem_dim]
    """
    with tf.variable_scope(scope or "dot_attention", reuse=tf.AUTO_REUSE,
                           dtype=tf.as_dtype(dtype.floatx())):
        if memory is None:
            # self-attention: compute q, k, v from the queries alone
            h = func.linear(query, hidden_size * 3, ln=ln, scope="qkv_map")
            q, k, v = tf.split(h, 3, -1)

            if cache is not None:
                k = tf.concat([cache['k'], k], axis=1)
                v = tf.concat([cache['v'], v], axis=1)
                cache = {
                    'k': k,
                    'v': v,
                }
        else:
            q = func.linear(query, hidden_size, ln=ln, scope="q_map")
            if cache is not None and ('mk' in cache and 'mv' in cache):
                k, v = cache['mk'], cache['mv']
            else:
                k = func.linear(memory, hidden_size, ln=ln, scope="k_map")
                v = func.linear(memory, hidden_size, ln=ln, scope="v_map")

            if cache is not None:
                cache['mk'] = k
                cache['mv'] = v

        q = func.split_heads(q, num_heads)
        k = func.split_heads(k, num_heads)
        v = func.split_heads(v, num_heads)

        q *= (hidden_size // num_heads) ** (-0.5)

        # q * k => attention weights
        logits = tf.matmul(q, k, transpose_b=True)

        # convert the additive mask to 0-1 form and multiply it with the logits
        if mem_mask is not None:
            zero_one_mask = tf.to_float(tf.equal(mem_mask, 0.0))
            logits *= zero_one_mask

        # replace softmax with relu
        # weights = tf.nn.softmax(logits)
        weights = tf.nn.relu(logits)

        dweights = util.valid_apply_dropout(weights, dropout)

        # weights * v => attention vectors
        o = tf.matmul(dweights, v)
        o = func.combine_heads(o)

        # apply gated RMSNorm to the output to stabilize training
        o = gated_rms_norm(o, scope="post")

        if out_map:
            o = func.linear(o, hidden_size, ln=ln, scope="o_map")

        results = {
            'weights': weights,
            'output': o,
            'cache': cache
        }

        return results
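
This variant replaces softmax with ReLU and a multiplicative 0-1 mask, giving unnormalized, sparse attention weights; the gated RMSNorm on the output then keeps the activations in a stable range. A minimal NumPy sketch of the weight computation under that reading (toy shapes; the additive mask is converted to 0-1 as in the code):

import numpy as np

def relu_attention_weights_np(q, k, mem_mask=None):
    d = q.shape[-1]
    logits = (q / np.sqrt(d)) @ np.swapaxes(k, -1, -2)
    if mem_mask is not None:
        zero_one = (mem_mask == 0.0).astype(logits.dtype)  # 1 for valid, 0 for masked
        logits = logits * zero_one
    return np.maximum(logits, 0.0)                          # relu in place of softmax

q = np.random.randn(1, 3, 8); k = np.random.randn(1, 5, 8)
mask = np.zeros((1, 1, 5)); mask[..., 3:] = -1e9
w = relu_attention_weights_np(q, k, mask)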
Example 7
def train(params):
    # status measure
    if params.recorder.estop or \
            params.recorder.epoch > params.epoches or \
            params.recorder.step > params.max_training_steps:
        tf.logging.info(
            "Stop condition reached, you have finished training your model.")
        return 0.

    # loading dataset
    tf.logging.info("Begin Loading Training and Dev Dataset")
    start_time = time.time()
    train_dataset = Dataset(params.src_train_file,
                            params.tgt_train_file,
                            params.src_vocab,
                            params.tgt_vocab,
                            params.max_len,
                            batch_or_token=params.batch_or_token,
                            data_leak_ratio=params.data_leak_ratio)
    dev_dataset = Dataset(params.src_dev_file,
                          params.src_dev_file,
                          params.src_vocab,
                          params.src_vocab,
                          params.eval_max_len,
                          batch_or_token='batch',
                          data_leak_ratio=params.data_leak_ratio)
    tf.logging.info(
        "End Loading dataset, within {} seconds".format(time.time() -
                                                        start_time))

    # Build Graph
    with tf.Graph().as_default():
        lr = tf.placeholder(tf.as_dtype(dtype.floatx()), [], "learn_rate")

        # instead of automatic multi-gpu slicing, create one set of placeholders per GPU and feed shards manually (the `zero` way)
        features = []
        for fidx in range(max(len(params.gpus), 1)):
            feature = {
                "source": tf.placeholder(tf.int32, [None, None], "source"),
                "target": tf.placeholder(tf.int32, [None, None], "target"),
            }
            features.append(feature)

        # session info
        sess = util.get_session(params.gpus)

        tf.logging.info("Begining Building Training Graph")
        start_time = time.time()

        # create global step
        global_step = tf.train.get_or_create_global_step()

        # set up optimizer
        optimizer = tf.train.AdamOptimizer(lr,
                                           beta1=params.beta1,
                                           beta2=params.beta2,
                                           epsilon=params.epsilon)

        # get graph
        graph = model.get_model(params.model_name)

        # set up training graph
        loss, gradients = tower_train_graph(features, optimizer, graph, params)

        # apply pseudo cyclic parallel operation
        vle, ops = cycle.create_train_op({"loss": loss}, gradients, optimizer,
                                         global_step, params)

        tf.logging.info(
            "End Building Training Graph, within {} seconds".format(
                time.time() - start_time))

        tf.logging.info("Begin Building Inferring Graph")
        start_time = time.time()

        # set up infer graph
        eval_seqs, eval_scores = tower_infer_graph(features, graph, params)

        tf.logging.info(
            "End Building Inferring Graph, within {} seconds".format(
                time.time() - start_time))

        # initialize the model
        sess.run(tf.global_variables_initializer())

        # log parameters
        util.variable_printer()

        # create saver
        train_saver = saver.Saver(
            checkpoints=params.checkpoints,
            output_dir=params.output_dir,
            best_checkpoints=params.best_checkpoints,
        )

        tf.logging.info("Training")
        cycle_counter = 0
        data_on_gpu = []
        cum_tokens = []

        # restore parameters
        tf.logging.info("Trying restore pretrained parameters")
        train_saver.restore(sess, path=params.pretrained_model)

        tf.logging.info("Trying restore existing parameters")
        train_saver.restore(sess)

        # setup learning rate
        params.lrate = params.recorder.lrate
        adapt_lr = lrs.get_lr(params)

        start_time = time.time()
        start_epoch = params.recorder.epoch
        for epoch in range(start_epoch, params.epoches + 1):

            params.recorder.epoch = epoch

            tf.logging.info("Training the model for epoch {}".format(epoch))
            size = params.batch_size if params.batch_or_token == 'batch' \
                else params.token_size

            train_queue = queuer.EnQueuer(
                train_dataset.batcher(size,
                                      buffer_size=params.buffer_size,
                                      shuffle=params.shuffle_batch,
                                      train=True),
                lambda x: x,
                worker_processes_num=params.process_num,
                input_queue_size=params.input_queue_size,
                output_queue_size=params.output_queue_size,
            )

            adapt_lr.before_epoch(eidx=epoch)

            for lidx, data in enumerate(train_queue):

                if params.train_continue:
                    if lidx <= params.recorder.lidx:
                        segments = params.recorder.lidx // 5
                        if params.recorder.lidx < 5 or lidx % segments == 0:
                            tf.logging.info(
                                "{} Passing {}-th index according to record".
                                format(util.time_str(time.time()), lidx))

                        continue

                params.recorder.lidx = lidx

                data_on_gpu.append(data)
                # when using multiple GPUs, wait until one data shard has been
                # collected for each GPU before running the step
                # The actual batch size: batch_size * num_gpus * update_cycle
                if len(params.gpus) > 0 and len(data_on_gpu) < len(
                        params.gpus):
                    continue

                # increase the counter by 1
                cycle_counter += 1

                if cycle_counter == 1:
                    # calculate adaptive learning rate
                    adapt_lr.step(params.recorder.step)

                    # clear internal states
                    sess.run(ops["zero_op"])

                # data feeding to gpu placeholders
                feed_dicts = {}
                for fidx, shard_data in enumerate(data_on_gpu):
                    # define feed_dict
                    feed_dict = {
                        features[fidx]["source"]: shard_data["src"],
                        features[fidx]["target"]: shard_data["tgt"],
                        lr: adapt_lr.get_lr(),
                    }
                    feed_dicts.update(feed_dict)

                    # collect target tokens
                    cum_tokens.append(np.sum(shard_data['tgt'] > 0))

                # reset data points on gpus
                data_on_gpu = []

                # internal accumulative gradient collection
                if cycle_counter < params.update_cycle:
                    sess.run(ops["collect_op"], feed_dict=feed_dicts)

                # at the final step, update model parameters
                if cycle_counter == params.update_cycle:
                    cycle_counter = 0

                    # directly update parameters, usually this works well
                    if not params.safe_nan:
                        _, loss, gnorm, pnorm, gstep = sess.run(
                            [
                                ops["train_op"], vle["loss"],
                                vle["gradient_norm"], vle["parameter_norm"],
                                global_step
                            ],
                            feed_dict=feed_dicts)

                        if np.isnan(loss) or np.isinf(loss) or np.isnan(
                                gnorm) or np.isinf(gnorm):
                            tf.logging.error(
                                "Nan or Inf raised! Loss {} GNorm {}.".format(
                                    loss, gnorm))
                            params.recorder.estop = True
                            break
                    else:
                        # Note: the safe-NaN path helps when training big models, but sacrifices speed
                        loss, gnorm, pnorm, gstep = sess.run(
                            [
                                vle["loss"], vle["gradient_norm"],
                                vle["parameter_norm"], global_step
                            ],
                            feed_dict=feed_dicts)

                        if np.isnan(loss) or np.isinf(loss) or np.isnan(gnorm) or np.isinf(gnorm) \
                                or gnorm > params.gnorm_upper_bound:
                            tf.logging.error(
                                "Nan or Inf raised, GStep {} is passed! Loss {} GNorm {}."
                                .format(gstep, loss, gnorm))
                            continue

                        sess.run(ops["train_op"], feed_dict=feed_dicts)

                    if gstep % params.disp_freq == 0:
                        end_time = time.time()
                        tf.logging.info(
                            "{} Epoch {}, GStep {}~{}, LStep {}~{}, "
                            "Loss {:.3f}, GNorm {:.3f}, PNorm {:.3f}, Lr {:.5f}, "
                            "Src {}, Tgt {}, Tokens {}, UD {:.3f} s".format(
                                util.time_str(end_time), epoch,
                                gstep - params.disp_freq + 1, gstep,
                                lidx - params.disp_freq + 1, lidx, loss, gnorm,
                                pnorm, adapt_lr.get_lr(),
                                data['src'].shape, data['tgt'].shape,
                                np.sum(cum_tokens), end_time - start_time))
                        start_time = time.time()
                        cum_tokens = []

                    # trigger model saver
                    if gstep > 0 and gstep % params.save_freq == 0:
                        train_saver.save(sess, gstep)
                        params.recorder.save_to_json(
                            os.path.join(params.output_dir, "record.json"))

                    # trigger model evaluation
                    if gstep > 0 and gstep % params.eval_freq == 0:
                        if params.ema_decay > 0.:
                            sess.run(ops['ema_backup_op'])
                            sess.run(ops['ema_assign_op'])

                        tf.logging.info("Start Evaluating")
                        eval_start_time = time.time()
                        tranes, scores, indices = evalu.decoding(
                            sess, features, eval_seqs, eval_scores,
                            dev_dataset, params)
                        bleu = evalu.eval_metric(tranes,
                                                 params.tgt_dev_file,
                                                 indices=indices)
                        eval_end_time = time.time()
                        tf.logging.info("End Evaluating")

                        if params.ema_decay > 0.:
                            sess.run(ops['ema_restore_op'])

                        tf.logging.info(
                            "{} GStep {}, Scores {}, BLEU {}, Duration {:.3f} s"
                            .format(util.time_str(eval_end_time), gstep,
                                    np.mean(scores), bleu,
                                    eval_end_time - eval_start_time))

                        # save eval translation
                        evalu.dump_tanslation(
                            tranes,
                            os.path.join(params.output_dir,
                                         "eval-{}.trans.txt".format(gstep)),
                            indices=indices)

                        # save parameters
                        train_saver.save(sess, gstep, bleu)

                        # check for early stopping
                        valid_scores = [
                            v[1] for v in params.recorder.valid_script_scores
                        ]
                        if len(valid_scores) == 0 or bleu > np.max(valid_scores):
                            params.recorder.bad_counter = 0
                        else:
                            params.recorder.bad_counter += 1

                            if params.recorder.bad_counter > params.estop_patience:
                                params.recorder.estop = True
                                break

                        params.recorder.history_scores.append(
                            (int(gstep), float(np.mean(scores))))
                        params.recorder.valid_script_scores.append(
                            (int(gstep), float(bleu)))
                        params.recorder.save_to_json(
                            os.path.join(params.output_dir, "record.json"))

                        # handle the learning rate decay in a typical manner
                        adapt_lr.after_eval(float(bleu))

                    # trigger temporary sampling
                    if gstep > 0 and gstep % params.sample_freq == 0:
                        tf.logging.info("Start Sampling")
                        decode_seqs, decode_scores = sess.run(
                            [eval_seqs[:1], eval_scores[:1]],
                            feed_dict={features[0]["source"]: data["src"][:5]})
                        tranes, scores = evalu.decode_hypothesis(
                            decode_seqs, decode_scores, params)

                        for sidx in range(min(5, len(scores))):
                            sample_source = evalu.decode_target_token(
                                data['src'][sidx], params.src_vocab)
                            tf.logging.info("{}-th Source: {}".format(
                                sidx, ' '.join(sample_source)))
                            sample_target = evalu.decode_target_token(
                                data['tgt'][sidx], params.tgt_vocab)
                            tf.logging.info("{}-th Target: {}".format(
                                sidx, ' '.join(sample_target)))
                            sample_trans = tranes[sidx]
                            tf.logging.info("{}-th Translation: {}".format(
                                sidx, ' '.join(sample_trans)))

                        tf.logging.info("End Sampling")

                    # trigger stopping
                    if gstep >= params.max_training_steps:
                        # stop running by setting EStop signal
                        params.recorder.estop = True
                        break

                    # should be equal to global_step
                    params.recorder.step = int(gstep)

            if params.recorder.estop:
                tf.logging.info("Early Stopped!")
                break

            # reset the batch index for the next epoch
            params.recorder.lidx = -1

            adapt_lr.after_epoch(eidx=epoch)

    # Final Evaluation
    tf.logging.info("Start Final Evaluating")
    if params.ema_decay > 0.:
        sess.run(ops['ema_backup_op'])
        sess.run(ops['ema_assign_op'])

    gstep = int(params.recorder.step + 1)
    eval_start_time = time.time()
    tranes, scores, indices = evalu.decoding(sess, features, eval_seqs,
                                             eval_scores, dev_dataset, params)
    bleu = evalu.eval_metric(tranes, params.tgt_dev_file, indices=indices)
    eval_end_time = time.time()
    tf.logging.info("End Evaluating")

    if params.ema_decay > 0.:
        sess.run(ops['ema_restore_op'])

    tf.logging.info(
        "{} GStep {}, Scores {}, BLEU {}, Duration {:.3f} s".format(
            util.time_str(eval_end_time), gstep, np.mean(scores), bleu,
            eval_end_time - eval_start_time))

    # save eval translation
    evalu.dump_tanslation(tranes,
                          os.path.join(params.output_dir,
                                       "eval-{}.trans.txt".format(gstep)),
                          indices=indices)

    tf.logging.info("Your training is finished :)")

    return train_saver.best_score
Example 8
def dot_attention(query,
                  memory,
                  mem_mask,
                  hidden_size,
                  ln=False,
                  num_heads=1,
                  cache=None,
                  dropout=None,
                  out_map=True,
                  scope=None,
                  count_mask=None):
    """
    dotted attention model with l0drop
    :param query: [batch_size, qey_len, dim]
    :param memory: [batch_size, seq_len, mem_dim] or None
    :param mem_mask: [batch_size, seq_len]
    :param hidden_size: attention space dimension
    :param ln: whether use layer normalization
    :param num_heads: attention head number
    :param dropout: attention dropout, default disable
    :param out_map: output additional mapping
    :param cache: cache-based decoding
    :param count_mask: counting vector for l0drop
    :param scope:
    :return: a value matrix, [batch_size, qey_len, mem_dim]
    """
    with tf.variable_scope(scope or "dot_attention",
                           reuse=tf.AUTO_REUSE,
                           dtype=tf.as_dtype(dtype.floatx())):
        if memory is None:
            # self-attention: compute q, k, v from the queries alone
            h = func.linear(query, hidden_size * 3, ln=ln, scope="qkv_map")
            q, k, v = tf.split(h, 3, -1)

            if cache is not None:
                k = tf.concat([cache['k'], k], axis=1)
                v = tf.concat([cache['v'], v], axis=1)
                cache = {
                    'k': k,
                    'v': v,
                }
        else:
            q = func.linear(query, hidden_size, ln=ln, scope="q_map")
            if cache is not None and ('mk' in cache and 'mv' in cache):
                k, v = cache['mk'], cache['mv']
            else:
                k = func.linear(memory, hidden_size, ln=ln, scope="k_map")
                v = func.linear(memory, hidden_size, ln=ln, scope="v_map")

            if cache is not None:
                cache['mk'] = k
                cache['mv'] = v

        q = func.split_heads(q, num_heads)
        k = func.split_heads(k, num_heads)
        v = func.split_heads(v, num_heads)

        q *= (hidden_size // num_heads)**(-0.5)

        # q * k => attention weights
        logits = tf.matmul(q, k, transpose_b=True)

        if mem_mask is not None:
            logits += mem_mask

        # modifying 'weights = tf.nn.softmax(logits)' to include the counting information.
        # --------
        logits = logits - tf.reduce_max(logits, -1, keepdims=True)
        exp_logits = tf.exp(logits)

        # the count mask accounts for how many states were dropped (i.e. gate values of 0)
        if count_mask is not None:
            exp_logits *= count_mask

        exp_sum_logits = tf.reduce_sum(exp_logits, -1, keepdims=True)
        weights = exp_logits / exp_sum_logits
        # --------

        dweights = util.valid_apply_dropout(weights, dropout)

        # weights * v => attention vectors
        o = tf.matmul(dweights, v)
        o = func.combine_heads(o)

        if out_map:
            o = func.linear(o, hidden_size, ln=ln, scope="o_map")

        results = {'weights': weights, 'output': o, 'cache': cache}

        return results
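
The block that replaces the plain softmax can be read as a softmax whose exponentiated terms are re-weighted by `count_mask` before normalization; per the comment above, the mask encodes how many states were dropped by l0drop. A minimal NumPy sketch of just that step, under that reading:

import numpy as np

def counted_softmax_np(logits, count_mask=None):
    logits = logits - np.max(logits, axis=-1, keepdims=True)   # numerical stability
    exp_logits = np.exp(logits)
    if count_mask is not None:
        exp_logits = exp_logits * count_mask                   # re-weight before normalizing
    return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)

logits = np.random.randn(1, 3, 5)
counts = np.ones((1, 1, 5)); counts[..., 2] = 3.0              # hypothetical: one kept state stands for 3
w = counted_softmax_np(logits, counts)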
Example 9
def create_train_op(named_scalars, grads_and_vars, optimizer, global_step, params):
    """Build zero/collect/train ops for gradient accumulation over `update_cycle` steps, plus optional EMA."""
    tf.get_variable_scope().set_dtype(tf.as_dtype(dtype.floatx()))

    gradients = [item[0] for item in grads_and_vars]
    variables = [item[1] for item in grads_and_vars]

    if params.update_cycle == 1:
        zero_variables_op = tf.no_op("zero_variables")
        collect_op = tf.no_op("collect_op")
    else:
        named_vars = {}
        for name in named_scalars:
            named_var = tf.Variable(tf.zeros([], dtype=tf.float32),
                                    name="{}/CTrainOpReplica".format(name),
                                    trainable=False)
            named_vars[name] = named_var
        count_var = tf.Variable(tf.zeros([], dtype=tf.as_dtype(dtype.floatx())),
                                name="count/CTrainOpReplica",
                                trainable=False)
        slot_variables = _replicate_variables(variables, suffix='CTrainOpReplica')
        zero_variables_op = _zero_variables(
            slot_variables + [count_var] + list(named_vars.values()))

        collect_ops = []
        # collect gradients
        collect_grads_op = _collect_gradients(gradients, slot_variables)
        collect_ops.append(collect_grads_op)

        # collect other scalars
        for name in named_scalars:
            scalar = named_scalars[name]
            named_var = named_vars[name]
            collect_op = tf.assign_add(named_var, scalar)
            collect_ops.append(collect_op)
        # collect counting variable
        collect_count_op = tf.assign_add(count_var, 1.0)
        collect_ops.append(collect_count_op)

        collect_op = tf.group(*collect_ops, name="collect_op")
        scale = 1.0 / (tf.cast(count_var, tf.float32) + 1.0)
        gradients = [scale * (g + s)
                     for (g, s) in zip(gradients, slot_variables)]

        for name in named_scalars:
            named_scalars[name] = scale * (
                    named_scalars[name] + named_vars[name])

    grand_norm = tf.global_norm(gradients)
    param_norm = tf.global_norm(variables)

    # Gradient clipping
    if isinstance(params.clip_grad_norm or None, float):
        gradients, _ = tf.clip_by_global_norm(gradients,
                                              params.clip_grad_norm,
                                              use_norm=grand_norm)

    # Update variables
    grads_and_vars = list(zip(gradients, variables))
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    ops = {
        "zero_op": zero_variables_op,
        "collect_op": collect_op,
        "train_op": train_op
    }

    # apply ema
    if params.ema_decay > 0.:
        tf.logging.info('Using Exp Moving Average to train the model with decay {}.'.format(params.ema_decay))
        ema = tf.train.ExponentialMovingAverage(decay=params.ema_decay, num_updates=global_step)
        ema_op = ema.apply(variables)
        with tf.control_dependencies([ops['train_op']]):
            ops['train_op'] = tf.group(ema_op)
        bck_vars = _replicate_variables(variables, suffix="CTrainOpBackUpReplica")

        ops['ema_backup_op'] = tf.group(*(tf.assign(bck, var.read_value())
                                        for bck, var in zip(bck_vars, variables)))
        ops['ema_restore_op'] = tf.group(*(tf.assign(var, bck.read_value())
                                         for bck, var in zip(bck_vars, variables)))
        ops['ema_assign_op'] = tf.group(*(tf.assign(var, ema.average(var).read_value())
                                        for var in variables))

    ret = named_scalars
    ret.update({
        "gradient_norm": grand_norm,
        "parameter_norm": param_norm,
    })

    return ret, ops
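
The zero/collect/train ops implement gradient accumulation over `update_cycle` micro-batches: the slot variables are zeroed, per-step gradients (and scalars such as the loss) are added into the slots, and the final step averages slot plus current gradient before applying the update. A minimal pure-Python sketch of the same bookkeeping (hypothetical names, no TensorFlow):

def accumulate_and_apply(micro_batch_grads, apply_fn):
    # micro_batch_grads: one list of gradients per micro-batch
    slots = [0.0 for _ in micro_batch_grads[0]]                    # zero_op
    for grads in micro_batch_grads[:-1]:                           # collect_op on all but the last step
        slots = [s + g for s, g in zip(slots, grads)]
    count = len(micro_batch_grads) - 1
    scale = 1.0 / (count + 1.0)
    averaged = [scale * (g + s)                                    # train_op: slot + current gradient, averaged
                for g, s in zip(micro_batch_grads[-1], slots)]
    apply_fn(averaged)

accumulate_and_apply([[1.0, 2.0], [3.0, 4.0]], apply_fn=print)     # prints [2.0, 3.0]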
Example 10
def additive_attention(query,
                       memory,
                       mem_mask,
                       hidden_size,
                       ln=False,
                       proj_memory=None,
                       num_heads=1,
                       dropout=None,
                       att_fun="add",
                       scope=None):
    """
    additive attention model
    :param query: [batch_size, dim]
    :param memory: [batch_size, seq_len, mem_dim]
    :param mem_mask: [batch_size, seq_len]
    :param hidden_size: attention space dimension
    :param ln: whether use layer normalization
    :param proj_memory: this is the mapped memory for saving memory
    :param num_heads: attention head number
    :param dropout: attention dropout, default disable
    :param scope:
    :return: a value matrix, [batch_size, mem_dim]
    """
    with tf.variable_scope(scope or "additive_attention",
                           dtype=tf.as_dtype(dtype.floatx())):
        if proj_memory is None:
            proj_memory = linear(memory,
                                 hidden_size,
                                 ln=ln,
                                 scope="feed_memory")

        query = linear(tf.expand_dims(query, 1),
                       hidden_size,
                       ln=ln,
                       scope="feed_query")

        query = split_heads(query, num_heads)
        proj_memory = split_heads(proj_memory, num_heads)

        if att_fun == "add":
            value = tf.tanh(query + proj_memory)

            logits = linear(value, 1, ln=False, scope="feed_logits")
            logits = tf.squeeze(logits, -1)
        else:
            logits = tf.matmul(query, proj_memory, transpose_b=True)
            logits = tf.squeeze(logits, 2)

        logits = util.mask_scale(logits, tf.expand_dims(mem_mask, 1))

        weights = tf.nn.softmax(logits, -1)  # [batch_size, num_heads, seq_len]

        dweights = util.valid_apply_dropout(weights, dropout)

        memory = split_heads(memory, num_heads)
        value = tf.reduce_sum(tf.expand_dims(dweights, -1) * memory,
                              -2,
                              keepdims=True)

        value = combine_heads(value)
        value = tf.squeeze(value, 1)

        results = {
            'weights': weights,
            'output': value,
            'cache_state': proj_memory
        }

        return results
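
A self-contained NumPy sketch of the single-head additive ("add") scoring path above, with hypothetical projection matrices `Wq`, `Wm` and scoring vector `v_a` standing in for the `linear` layers:

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

def additive_attention_np(query, memory, mem_mask, Wq, Wm, v_a):
    q = query @ Wq                                   # [batch, hidden]
    m = memory @ Wm                                  # [batch, seq_len, hidden]
    scores = np.tanh(q[:, None, :] + m) @ v_a        # [batch, seq_len]
    scores = np.where(mem_mask > 0, scores, -1e9)    # mask padded positions
    weights = softmax(scores)
    return np.einsum('bs,bsd->bd', weights, memory)  # weighted sum of the memory

batch, seq_len, dim, hidden = 2, 5, 8, 16
query = np.random.randn(batch, dim); memory = np.random.randn(batch, seq_len, dim)
mask = np.ones((batch, seq_len)); mask[:, -1] = 0
Wq = np.random.randn(dim, hidden); Wm = np.random.randn(dim, hidden); v_a = np.random.randn(hidden)
out = additive_attention_np(query, memory, mask, Wq, Wm, v_a)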