def gated_rms_norm(x, eps=None, scope=None):
    """RMS-based Layer normalization layer"""
    if eps is None:
        eps = dtype.epsilon()
    with tf.variable_scope(scope or "rms_norm",
                           dtype=tf.as_dtype(dtype.floatx())):
        layer_size = util.shape_list(x)[-1]

        scale = tf.get_variable("scale", [layer_size],
                                initializer=tf.ones_initializer())
        gate = tf.get_variable("gate", [layer_size], initializer=None)

        ms = tf.reduce_mean(x ** 2, -1, keep_dims=True)

        # adding gating here which slightly improves quality
        return scale * x * tf.rsqrt(ms + eps) * tf.nn.sigmoid(gate * x)
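# A minimal numpy sketch of the gated RMS normalization above, useful for
# checking the math outside a TF graph. The names (scale, gate) mirror the
# TF variables; the concrete values below are illustrative assumptions.
import numpy as np

def gated_rms_norm_np(x, scale, gate, eps=1e-8):
    # root-mean-square statistic over the last axis
    ms = np.mean(x ** 2, axis=-1, keepdims=True)
    # normalize, rescale, and apply the sigmoid input gate
    sigmoid = 1.0 / (1.0 + np.exp(-gate * x))
    return scale * x / np.sqrt(ms + eps) * sigmoid

# usage: a [batch=2, len=3, dim=4] activation with all-ones scale/gate
x = np.random.randn(2, 3, 4).astype(np.float32)
out = gated_rms_norm_np(x, scale=np.ones(4), gate=np.ones(4))
assert out.shape == x.shape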
def layer_norm(x, eps=None, scope=None):
    """Layer normalization layer"""
    if eps is None:
        eps = dtype.epsilon()
    with tf.variable_scope(scope or "layer_norm",
                           dtype=tf.as_dtype(dtype.floatx())):
        layer_size = util.shape_list(x)[-1]

        scale = tf.get_variable("scale", [layer_size],
                                initializer=tf.ones_initializer())
        offset = tf.get_variable("offset", [layer_size],
                                 initializer=tf.zeros_initializer())

        mean = tf.reduce_mean(x, -1, keep_dims=True)
        var = tf.reduce_mean((x - mean) ** 2, -1, keep_dims=True)

        return scale * (x - mean) * tf.rsqrt(var + eps) + offset
def decoding_fn(target, state, time):
    with tf.variable_scope(params.scope_name or "model",
                           reuse=tf.AUTO_REUSE,
                           dtype=tf.as_dtype(dtype.floatx()),
                           custom_getter=dtype.float32_variable_storage_getter):
        if params.search_mode == "cache":
            step_loss, step_logits, step_state, _ = decoder(
                target, state, params)
        else:
            estate = encoder(state, params)
            estate['dev_decode'] = True
            _, step_logits, _, _ = decoder(target, estate, params)
            step_state = state

    return step_logits, step_state
def data_parallelism(device_type, num_devices, fn, *args, **kwargs):
    # Replicate args and kwargs
    if args:
        new_args = [_maybe_repeat(arg, num_devices) for arg in args]
        # Transpose
        new_args = [list(x) for x in zip(*new_args)]
    else:
        new_args = [[] for _ in range(num_devices)]

    new_kwargs = [{} for _ in range(num_devices)]
    for k, v in kwargs.items():
        vals = _maybe_repeat(v, num_devices)
        for i in range(num_devices):
            new_kwargs[i][k] = vals[i]

    fns = _maybe_repeat(fn, num_devices)

    # Now make the parallel call.
    outputs = []
    for i in range(num_devices):
        worker = "/{}:{}".format(device_type, i)
        if device_type == 'cpu':
            _device_setter = local_device_setter(worker_device=worker)
        else:
            _device_setter = local_device_setter(
                ps_device_type='gpu',
                worker_device=worker,
                ps_strategy=tc.training.GreedyLoadBalancingStrategy(
                    num_devices, tc.training.byte_size_load_fn)
            )

        with tf.variable_scope(tf.get_variable_scope(),
                               reuse=bool(i != 0),
                               dtype=tf.as_dtype(dtype.floatx())):
            with tf.name_scope("tower_%d" % i):
                with tf.device(_device_setter):
                    outputs.append(fns[i](*new_args[i], **new_kwargs[i]))

    return _reshape_output(outputs)
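# `_maybe_repeat` is referenced above but not shown. A plausible sketch of
# the helper, assuming tensor2tensor-style semantics: a list of length
# `n` is treated as already per-device and passed through, anything else
# is replicated once per device.
def _maybe_repeat(x, n):
    if isinstance(x, list):
        assert len(x) == n, "per-device list must match the device count"
        return x
    return [x] * n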
def dot_attention(query, memory, mem_mask, hidden_size,
                  ln=False, num_heads=1, cache=None, dropout=None,
                  use_relative_pos=False, max_relative_position=16,
                  out_map=True, scope=None, fuse_mask=None,
                  decode_step=None):
    """
    dot-product attention model
    :param query: [batch_size, query_len, dim]
    :param memory: [batch_size, seq_len, mem_dim] or None
    :param mem_mask: [batch_size, seq_len]
    :param hidden_size: attention space dimension
    :param ln: whether to use layer normalization
    :param num_heads: attention head number
    :param dropout: attention dropout, disabled by default
    :param out_map: whether to apply an additional output mapping
    :param cache: cache-based decoding
    :param fuse_mask: AAN mask during training, and timestep for testing
    :param max_relative_position: maximum position considered for relative embedding
    :param use_relative_pos: whether to use relative position information
    :param decode_step: the time step of current decoding, 0-based
    :param scope: variable scope name
    :return: a value matrix, [batch_size, query_len, mem_dim]
    """
    with tf.variable_scope(scope or "dot_attention",
                           reuse=tf.AUTO_REUSE,
                           dtype=tf.as_dtype(dtype.floatx())):
        if fuse_mask is not None:
            assert memory is not None, 'Fuse mechanism only applies to cross-attention'
        if cache and use_relative_pos:
            assert decode_step is not None, 'decode_step must be provided when using relative position encoding'

        if memory is None:
            # suppose self-attention from queries alone
            h = linear(query, hidden_size * 3, ln=ln, scope="qkv_map")
            q, k, v = tf.split(h, 3, -1)

            if cache is not None:
                k = tf.concat([cache['k'], k], axis=1)
                v = tf.concat([cache['v'], v], axis=1)
                cache = {
                    'k': k,
                    'v': v,
                }
        else:
            q = linear(query, hidden_size, ln=ln, scope="q_map")
            if cache is not None and ('mk' in cache and 'mv' in cache):
                k, v = cache['mk'], cache['mv']
            else:
                k = linear(memory, hidden_size, ln=ln, scope="k_map")
                v = linear(memory, hidden_size, ln=ln, scope="v_map")

            if cache is not None:
                cache['mk'] = k
                cache['mv'] = v

        q = split_heads(q, num_heads)
        k = split_heads(k, num_heads)
        v = split_heads(v, num_heads)

        q *= (hidden_size // num_heads) ** (-0.5)

        q_shp = util.shape_list(q)
        k_shp = util.shape_list(k)
        v_shp = util.shape_list(v)

        q_len = q_shp[2] if decode_step is None else decode_step + 1
        r_lst = None if decode_step is None else 1

        # q * k => attention weights
        if use_relative_pos:
            r = rpr.get_relative_positions_embeddings(
                q_len, k_shp[2], k_shp[3],
                max_relative_position, name="rpr_keys", last=r_lst)
            logits = rpr.relative_attention_inner(q, k, r, transpose=True)
        else:
            logits = tf.matmul(q, k, transpose_b=True)

        if mem_mask is not None:
            logits += mem_mask

        weights = tf.nn.softmax(logits)

        dweights = util.valid_apply_dropout(weights, dropout)

        # weights * v => attention vectors
        if use_relative_pos:
            r = rpr.get_relative_positions_embeddings(
                q_len, k_shp[2], v_shp[3],
                max_relative_position, name="rpr_values", last=r_lst)
            o = rpr.relative_attention_inner(dweights, v, r, transpose=False)
        else:
            o = tf.matmul(dweights, v)

        o = combine_heads(o)

        if fuse_mask is not None:
            # This is for AAN, the important part is sharing v_map
            v_q = linear(query, hidden_size, ln=ln, scope="v_map")

            if cache is not None and 'aan' in cache:
                aan_o = (v_q + cache['aan']) / dtype.tf_to_float(fuse_mask + 1)
            else:
                # Simplified Average Attention Network
                aan_o = tf.matmul(fuse_mask, v_q)

            if cache is not None:
                if 'aan' not in cache:
                    cache['aan'] = v_q
                else:
                    cache['aan'] = v_q + cache['aan']

            # Directly sum both self-attention and cross-attention
            o = o + aan_o

        if out_map:
            o = linear(o, hidden_size, ln=ln, scope="o_map")

        results = {
            'weights': weights,
            'output': o,
            'cache': cache
        }

        return results
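# The AAN branch above averages the (shared) value projections of all
# previous positions. During training this is a single matmul with a
# lower-triangular averaging matrix; a numpy sketch, under the assumption
# that fuse_mask[i, j] = 1 / (i + 1) for j <= i and 0 otherwise:
import numpy as np

def aan_fuse_mask(seq_len):
    tril = np.tril(np.ones((seq_len, seq_len), dtype=np.float32))
    return tril / np.arange(1, seq_len + 1, dtype=np.float32)[:, None]

v_q = np.random.randn(4, 8).astype(np.float32)    # [seq_len, hidden]
aan_o = aan_fuse_mask(4) @ v_q
# row i is the running mean of v_q[0..i], matching the cached decode-time
# update (v_q + cache['aan']) / (decode_step + 1)
assert np.allclose(aan_o[2], v_q[:3].mean(axis=0), atol=1e-6)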
def dot_attention(query, memory, mem_mask, hidden_size,
                  ln=False, num_heads=1, cache=None, dropout=None,
                  out_map=True, scope=None):
    """
    dot-product attention model
    :param query: [batch_size, query_len, dim]
    :param memory: [batch_size, seq_len, mem_dim] or None
    :param mem_mask: [batch_size, seq_len]
    :param hidden_size: attention space dimension
    :param ln: whether to use layer normalization
    :param num_heads: attention head number
    :param dropout: attention dropout, disabled by default
    :param out_map: whether to apply an additional output mapping
    :param cache: cache-based decoding
    :param scope: variable scope name
    :return: a value matrix, [batch_size, query_len, mem_dim]
    """
    with tf.variable_scope(scope or "dot_attention",
                           reuse=tf.AUTO_REUSE,
                           dtype=tf.as_dtype(dtype.floatx())):
        if memory is None:
            # suppose self-attention from queries alone
            h = func.linear(query, hidden_size * 3, ln=ln, scope="qkv_map")
            q, k, v = tf.split(h, 3, -1)

            if cache is not None:
                k = tf.concat([cache['k'], k], axis=1)
                v = tf.concat([cache['v'], v], axis=1)
                cache = {
                    'k': k,
                    'v': v,
                }
        else:
            q = func.linear(query, hidden_size, ln=ln, scope="q_map")
            if cache is not None and ('mk' in cache and 'mv' in cache):
                k, v = cache['mk'], cache['mv']
            else:
                k = func.linear(memory, hidden_size, ln=ln, scope="k_map")
                v = func.linear(memory, hidden_size, ln=ln, scope="v_map")

            if cache is not None:
                cache['mk'] = k
                cache['mv'] = v

        q = func.split_heads(q, num_heads)
        k = func.split_heads(k, num_heads)
        v = func.split_heads(v, num_heads)

        q *= (hidden_size // num_heads) ** (-0.5)

        # q * k => attention weights
        logits = tf.matmul(q, k, transpose_b=True)

        # convert the mask to 0-1 form and multiply it into the logits
        if mem_mask is not None:
            zero_one_mask = tf.to_float(tf.equal(mem_mask, 0.0))
            logits *= zero_one_mask

        # replace softmax with relu
        # weights = tf.nn.softmax(logits)
        weights = tf.nn.relu(logits)

        dweights = util.valid_apply_dropout(weights, dropout)

        # weights * v => attention vectors
        o = tf.matmul(dweights, v)
        o = func.combine_heads(o)

        # perform RMSNorm to stabilize running
        o = gated_rms_norm(o, scope="post")

        if out_map:
            o = func.linear(o, hidden_size, ln=ln, scope="o_map")

        results = {
            'weights': weights,
            'output': o,
            'cache': cache
        }

        return results
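# In this relu variant the usual additive mask (0 for valid positions, a
# large negative value for padding) cannot simply be added to the logits,
# since there is no normalizing exponential; padding must be zeroed out
# multiplicatively instead. A numpy sketch, assuming that mask convention:
import numpy as np

logits = np.array([[1.5, -0.3, 2.0]], dtype=np.float32)
mem_mask = np.array([[0.0, 0.0, -1e9]], dtype=np.float32)  # last position padded

zero_one_mask = (mem_mask == 0.0).astype(np.float32)       # [[1., 1., 0.]]
weights = np.maximum(logits * zero_one_mask, 0.0)          # relu replaces softmax
# weights: [[1.5, 0.0, 0.0]] -- unnormalized, padded position zeroed out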
def train(params):
    # status measure
    if params.recorder.estop or \
            params.recorder.epoch > params.epoches or \
            params.recorder.step > params.max_training_steps:
        tf.logging.info(
            "Stop condition reached, you have finished training your model.")
        return 0.

    # loading dataset
    tf.logging.info("Begin Loading Training and Dev Dataset")
    start_time = time.time()
    train_dataset = Dataset(params.src_train_file, params.tgt_train_file,
                            params.src_vocab, params.tgt_vocab,
                            params.max_len,
                            batch_or_token=params.batch_or_token,
                            data_leak_ratio=params.data_leak_ratio)
    # the target side is not needed for decoding, so the source file and
    # vocabulary are reused on the target side of the dev dataset
    dev_dataset = Dataset(params.src_dev_file, params.src_dev_file,
                          params.src_vocab, params.src_vocab,
                          params.eval_max_len,
                          batch_or_token='batch',
                          data_leak_ratio=params.data_leak_ratio)
    tf.logging.info(
        "End Loading dataset, within {} seconds".format(time.time() - start_time))

    # Build Graph
    with tf.Graph().as_default():
        lr = tf.placeholder(tf.as_dtype(dtype.floatx()), [], "learn_rate")

        # shift automatically sliced multi-gpu process into `zero` manner :)
        features = []
        for fidx in range(max(len(params.gpus), 1)):
            feature = {
                "source": tf.placeholder(tf.int32, [None, None], "source"),
                "target": tf.placeholder(tf.int32, [None, None], "target"),
            }
            features.append(feature)

        # session info
        sess = util.get_session(params.gpus)

        tf.logging.info("Beginning Building Training Graph")
        start_time = time.time()

        # create global step
        global_step = tf.train.get_or_create_global_step()

        # set up optimizer
        optimizer = tf.train.AdamOptimizer(lr,
                                           beta1=params.beta1,
                                           beta2=params.beta2,
                                           epsilon=params.epsilon)

        # get graph
        graph = model.get_model(params.model_name)

        # set up training graph
        loss, gradients = tower_train_graph(features, optimizer, graph, params)

        # apply pseudo cyclic parallel operation
        vle, ops = cycle.create_train_op({"loss": loss}, gradients,
                                         optimizer, global_step, params)

        tf.logging.info(
            "End Building Training Graph, within {} seconds".format(
                time.time() - start_time))

        tf.logging.info("Begin Building Inferring Graph")
        start_time = time.time()

        # set up infer graph
        eval_seqs, eval_scores = tower_infer_graph(features, graph, params)

        tf.logging.info(
            "End Building Inferring Graph, within {} seconds".format(
                time.time() - start_time))

        # initialize the model
        sess.run(tf.global_variables_initializer())

        # log parameters
        util.variable_printer()

        # create saver
        train_saver = saver.Saver(
            checkpoints=params.checkpoints,
            output_dir=params.output_dir,
            best_checkpoints=params.best_checkpoints,
        )

        tf.logging.info("Training")

        cycle_counter = 0
        data_on_gpu = []
        cum_tokens = []

        # restore parameters
        tf.logging.info("Trying to restore pretrained parameters")
        train_saver.restore(sess, path=params.pretrained_model)

        tf.logging.info("Trying to restore existing parameters")
        train_saver.restore(sess)

        # setup learning rate
        params.lrate = params.recorder.lrate
        adapt_lr = lrs.get_lr(params)

        start_time = time.time()
        start_epoch = params.recorder.epoch
        for epoch in range(start_epoch, params.epoches + 1):

            params.recorder.epoch = epoch

            tf.logging.info("Training the model for epoch {}".format(epoch))
            size = params.batch_size if params.batch_or_token == 'batch' \
                else params.token_size

            train_queue = queuer.EnQueuer(
                train_dataset.batcher(size,
                                      buffer_size=params.buffer_size,
                                      shuffle=params.shuffle_batch,
                                      train=True),
                lambda x: x,
                worker_processes_num=params.process_num,
                input_queue_size=params.input_queue_size,
                output_queue_size=params.output_queue_size,
            )

            adapt_lr.before_epoch(eidx=epoch)

            for lidx, data in enumerate(train_queue):

                if params.train_continue:
                    if lidx <= params.recorder.lidx:
                        segments = params.recorder.lidx // 5
                        if params.recorder.lidx < 5 or lidx % segments == 0:
                            tf.logging.info(
                                "{} Passing {}-th index according to record".format(
                                    util.time_str(time.time()), lidx))
                        continue

                params.recorder.lidx = lidx

                data_on_gpu.append(data)
                # when using multiple gpus and the data samples are not yet
                # enough, make sure the data is fully added.
                # The actual batch size: batch_size * num_gpus * update_cycle
                if len(params.gpus) > 0 and len(data_on_gpu) < len(params.gpus):
                    continue

                # increase the counter by 1
                cycle_counter += 1

                if cycle_counter == 1:
                    # calculate adaptive learning rate
                    adapt_lr.step(params.recorder.step)

                    # clear internal states
                    sess.run(ops["zero_op"])

                # data feeding to gpu placeholders
                feed_dicts = {}
                for fidx, shard_data in enumerate(data_on_gpu):
                    # define feed_dict
                    feed_dict = {
                        features[fidx]["source"]: shard_data["src"],
                        features[fidx]["target"]: shard_data["tgt"],
                        lr: adapt_lr.get_lr(),
                    }
                    feed_dicts.update(feed_dict)

                    # collect target tokens
                    cum_tokens.append(np.sum(shard_data['tgt'] > 0))

                # reset data points on gpus
                data_on_gpu = []

                # internal accumulative gradient collection
                if cycle_counter < params.update_cycle:
                    sess.run(ops["collect_op"], feed_dict=feed_dicts)

                # at the final step, update model parameters
                if cycle_counter == params.update_cycle:
                    cycle_counter = 0

                    # directly update parameters, usually this works well
                    if not params.safe_nan:
                        _, loss, gnorm, pnorm, gstep = sess.run(
                            [ops["train_op"], vle["loss"],
                             vle["gradient_norm"], vle["parameter_norm"],
                             global_step],
                            feed_dict=feed_dicts)

                        if np.isnan(loss) or np.isinf(loss) \
                                or np.isnan(gnorm) or np.isinf(gnorm):
                            tf.logging.error(
                                "Nan or Inf raised! Loss {} GNorm {}.".format(
                                    loss, gnorm))
                            params.recorder.estop = True
                            break
                    else:
                        # Notice: applying safe nan helps train the big model,
                        # but sacrifices speed
                        loss, gnorm, pnorm, gstep = sess.run(
                            [vle["loss"], vle["gradient_norm"],
                             vle["parameter_norm"], global_step],
                            feed_dict=feed_dicts)

                        if np.isnan(loss) or np.isinf(loss) \
                                or np.isnan(gnorm) or np.isinf(gnorm) \
                                or gnorm > params.gnorm_upper_bound:
                            tf.logging.error(
                                "Nan or Inf raised, GStep {} is passed! "
                                "Loss {} GNorm {}.".format(gstep, loss, gnorm))
                            continue

                        sess.run(ops["train_op"], feed_dict=feed_dicts)

                    if gstep % params.disp_freq == 0:
                        end_time = time.time()
                        tf.logging.info(
                            "{} Epoch {}, GStep {}~{}, LStep {}~{}, "
                            "Loss {:.3f}, GNorm {:.3f}, PNorm {:.3f}, Lr {:.5f}, "
                            "Src {}, Tgt {}, Tokens {}, UD {:.3f} s".format(
                                util.time_str(end_time), epoch,
                                gstep - params.disp_freq + 1, gstep,
                                lidx - params.disp_freq + 1, lidx,
                                loss, gnorm, pnorm,
                                adapt_lr.get_lr(), data['src'].shape,
                                data['tgt'].shape, np.sum(cum_tokens),
                                end_time - start_time))
                        start_time = time.time()
                        cum_tokens = []

                    # trigger model saver
                    if gstep > 0 and gstep % params.save_freq == 0:
                        train_saver.save(sess, gstep)
                        params.recorder.save_to_json(
                            os.path.join(params.output_dir, "record.json"))

                    # trigger model evaluation
                    if gstep > 0 and gstep % params.eval_freq == 0:
                        if params.ema_decay > 0.:
                            sess.run(ops['ema_backup_op'])
                            sess.run(ops['ema_assign_op'])

                        tf.logging.info("Start Evaluating")
                        eval_start_time = time.time()
                        tranes, scores, indices = evalu.decoding(
                            sess, features, eval_seqs,
                            eval_scores, dev_dataset, params)
                        bleu = evalu.eval_metric(
                            tranes, params.tgt_dev_file, indices=indices)
                        eval_end_time = time.time()
                        tf.logging.info("End Evaluating")

                        if params.ema_decay > 0.:
                            sess.run(ops['ema_restore_op'])

                        tf.logging.info(
                            "{} GStep {}, Scores {}, BLEU {}, Duration {:.3f} s".format(
                                util.time_str(eval_end_time), gstep,
                                np.mean(scores), bleu,
                                eval_end_time - eval_start_time))

                        # save eval translation
                        evalu.dump_tanslation(
                            tranes,
                            os.path.join(params.output_dir,
                                         "eval-{}.trans.txt".format(gstep)),
                            indices=indices)

                        # save parameters
                        train_saver.save(sess, gstep, bleu)

                        # check for early stopping
                        valid_scores = [
                            v[1] for v in params.recorder.valid_script_scores]
                        if len(valid_scores) == 0 or bleu > np.max(valid_scores):
                            params.recorder.bad_counter = 0
                        else:
                            params.recorder.bad_counter += 1

                            if params.recorder.bad_counter > params.estop_patience:
                                params.recorder.estop = True
                                break

                        params.recorder.history_scores.append(
                            (int(gstep), float(np.mean(scores))))
                        params.recorder.valid_script_scores.append(
                            (int(gstep), float(bleu)))
                        params.recorder.save_to_json(
                            os.path.join(params.output_dir, "record.json"))

                        # handle the learning rate decay in a typical manner
                        adapt_lr.after_eval(float(bleu))

                    # trigger temporary sampling
                    if gstep > 0 and gstep % params.sample_freq == 0:
                        tf.logging.info("Start Sampling")
                        decode_seqs, decode_scores = sess.run(
                            [eval_seqs[:1], eval_scores[:1]],
                            feed_dict={features[0]["source"]: data["src"][:5]})
                        tranes, scores = evalu.decode_hypothesis(
                            decode_seqs, decode_scores, params)

                        for sidx in range(min(5, len(scores))):
                            sample_source = evalu.decode_target_token(
                                data['src'][sidx], params.src_vocab)
                            tf.logging.info("{}-th Source: {}".format(
                                sidx, ' '.join(sample_source)))
                            sample_target = evalu.decode_target_token(
                                data['tgt'][sidx], params.tgt_vocab)
                            tf.logging.info("{}-th Target: {}".format(
                                sidx, ' '.join(sample_target)))
                            sample_trans = tranes[sidx]
                            tf.logging.info("{}-th Translation: {}".format(
                                sidx, ' '.join(sample_trans)))

                        tf.logging.info("End Sampling")

                    # trigger stopping
                    if gstep >= params.max_training_steps:
                        # stop running by setting EStop signal
                        params.recorder.estop = True
                        break

                    # should be equal to global_step
                    params.recorder.step = int(gstep)

            if params.recorder.estop:
                tf.logging.info("Early Stopped!")
                break

            # reset to 0
            params.recorder.lidx = -1

            adapt_lr.after_epoch(eidx=epoch)

        # Final Evaluation
        tf.logging.info("Start Final Evaluating")
        if params.ema_decay > 0.:
            sess.run(ops['ema_backup_op'])
            sess.run(ops['ema_assign_op'])

        gstep = int(params.recorder.step + 1)
        eval_start_time = time.time()
        tranes, scores, indices = evalu.decoding(
            sess, features, eval_seqs, eval_scores, dev_dataset, params)
        bleu = evalu.eval_metric(tranes, params.tgt_dev_file, indices=indices)
        eval_end_time = time.time()
        tf.logging.info("End Evaluating")

        if params.ema_decay > 0.:
            sess.run(ops['ema_restore_op'])

        tf.logging.info(
            "{} GStep {}, Scores {}, BLEU {}, Duration {:.3f} s".format(
                util.time_str(eval_end_time), gstep, np.mean(scores), bleu,
                eval_end_time - eval_start_time))

        # save eval translation
        evalu.dump_tanslation(
            tranes,
            os.path.join(params.output_dir, "eval-{}.trans.txt".format(gstep)),
            indices=indices)

    tf.logging.info("Your training is finished :)")

    return train_saver.best_score
def dot_attention(query, memory, mem_mask, hidden_size,
                  ln=False, num_heads=1, cache=None, dropout=None,
                  out_map=True, scope=None, count_mask=None):
    """
    dot-product attention model with l0drop
    :param query: [batch_size, query_len, dim]
    :param memory: [batch_size, seq_len, mem_dim] or None
    :param mem_mask: [batch_size, seq_len]
    :param hidden_size: attention space dimension
    :param ln: whether to use layer normalization
    :param num_heads: attention head number
    :param dropout: attention dropout, disabled by default
    :param out_map: whether to apply an additional output mapping
    :param cache: cache-based decoding
    :param count_mask: counting vector for l0drop
    :param scope: variable scope name
    :return: a value matrix, [batch_size, query_len, mem_dim]
    """
    with tf.variable_scope(scope or "dot_attention",
                           reuse=tf.AUTO_REUSE,
                           dtype=tf.as_dtype(dtype.floatx())):
        if memory is None:
            # suppose self-attention from queries alone
            h = func.linear(query, hidden_size * 3, ln=ln, scope="qkv_map")
            q, k, v = tf.split(h, 3, -1)

            if cache is not None:
                k = tf.concat([cache['k'], k], axis=1)
                v = tf.concat([cache['v'], v], axis=1)
                cache = {
                    'k': k,
                    'v': v,
                }
        else:
            q = func.linear(query, hidden_size, ln=ln, scope="q_map")
            if cache is not None and ('mk' in cache and 'mv' in cache):
                k, v = cache['mk'], cache['mv']
            else:
                k = func.linear(memory, hidden_size, ln=ln, scope="k_map")
                v = func.linear(memory, hidden_size, ln=ln, scope="v_map")

            if cache is not None:
                cache['mk'] = k
                cache['mv'] = v

        q = func.split_heads(q, num_heads)
        k = func.split_heads(k, num_heads)
        v = func.split_heads(v, num_heads)

        q *= (hidden_size // num_heads) ** (-0.5)

        # q * k => attention weights
        logits = tf.matmul(q, k, transpose_b=True)

        if mem_mask is not None:
            logits += mem_mask

        # modifying 'weights = tf.nn.softmax(logits)' to include the
        # counting information.
        # --------
        logits = logits - tf.reduce_max(logits, -1, keepdims=True)
        exp_logits = tf.exp(logits)

        # basically, the count considers how many states are dropped
        # (i.e. gate value 0s)
        if count_mask is not None:
            exp_logits *= count_mask

        exp_sum_logits = tf.reduce_sum(exp_logits, -1, keepdims=True)
        weights = exp_logits / exp_sum_logits
        # --------

        dweights = util.valid_apply_dropout(weights, dropout)

        # weights * v => attention vectors
        o = tf.matmul(dweights, v)
        o = func.combine_heads(o)

        if out_map:
            o = func.linear(o, hidden_size, ln=ln, scope="o_map")

        results = {
            'weights': weights,
            'output': o,
            'cache': cache
        }

        return results
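# The block above re-derives softmax so that each key can carry a
# multiplicity. A numpy sketch: with count_mask = 1 everywhere this is
# exactly softmax; a count of 2 weights a kept state as if two dropped
# states had been merged into it (the l0drop case).
import numpy as np

def counted_softmax(logits, count_mask):
    logits = logits - logits.max(axis=-1, keepdims=True)  # stability shift
    exp_logits = np.exp(logits) * count_mask
    return exp_logits / exp_logits.sum(axis=-1, keepdims=True)

logits = np.array([0.5, 0.5], dtype=np.float32)
print(counted_softmax(logits, np.array([1., 1.])))  # [0.5, 0.5]
print(counted_softmax(logits, np.array([1., 2.])))  # [1/3, 2/3]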
def create_train_op(named_scalars, grads_and_vars, optimizer, global_step, params):
    tf.get_variable_scope().set_dtype(tf.as_dtype(dtype.floatx()))

    gradients = [item[0] for item in grads_and_vars]
    variables = [item[1] for item in grads_and_vars]

    if params.update_cycle == 1:
        zero_variables_op = tf.no_op("zero_variables")
        collect_op = tf.no_op("collect_op")
    else:
        named_vars = {}
        for name in named_scalars:
            named_var = tf.Variable(tf.zeros([], dtype=tf.float32),
                                    name="{}/CTrainOpReplica".format(name),
                                    trainable=False)
            named_vars[name] = named_var
        count_var = tf.Variable(tf.zeros([], dtype=tf.as_dtype(dtype.floatx())),
                                name="count/CTrainOpReplica",
                                trainable=False)
        slot_variables = _replicate_variables(variables, suffix='CTrainOpReplica')
        zero_variables_op = _zero_variables(
            slot_variables + [count_var] + list(named_vars.values()))

        collect_ops = []
        # collect gradients
        collect_grads_op = _collect_gradients(gradients, slot_variables)
        collect_ops.append(collect_grads_op)

        # collect other scalars
        for name in named_scalars:
            scalar = named_scalars[name]
            named_var = named_vars[name]
            collect_op = tf.assign_add(named_var, scalar)
            collect_ops.append(collect_op)
        # collect counting variable
        collect_count_op = tf.assign_add(count_var, 1.0)
        collect_ops.append(collect_count_op)

        collect_op = tf.group(*collect_ops, name="collect_op")
        scale = 1.0 / (tf.cast(count_var, tf.float32) + 1.0)
        gradients = [scale * (g + s)
                     for (g, s) in zip(gradients, slot_variables)]

        for name in named_scalars:
            named_scalars[name] = scale * (
                named_scalars[name] + named_vars[name])

    grand_norm = tf.global_norm(gradients)
    param_norm = tf.global_norm(variables)

    # Gradient clipping
    if isinstance(params.clip_grad_norm or None, float):
        gradients, _ = tf.clip_by_global_norm(gradients,
                                              params.clip_grad_norm,
                                              use_norm=grand_norm)

    # Update variables
    grads_and_vars = list(zip(gradients, variables))
    train_op = optimizer.apply_gradients(grads_and_vars, global_step)

    ops = {
        "zero_op": zero_variables_op,
        "collect_op": collect_op,
        "train_op": train_op
    }

    # apply ema
    if params.ema_decay > 0.:
        tf.logging.info('Using Exp Moving Average to train the model '
                        'with decay {}.'.format(params.ema_decay))
        ema = tf.train.ExponentialMovingAverage(decay=params.ema_decay,
                                                num_updates=global_step)
        ema_op = ema.apply(variables)
        with tf.control_dependencies([ops['train_op']]):
            ops['train_op'] = tf.group(ema_op)
        bck_vars = _replicate_variables(variables, suffix="CTrainOpBackUpReplica")
        ops['ema_backup_op'] = tf.group(*(tf.assign(bck, var.read_value())
                                          for bck, var in zip(bck_vars, variables)))
        ops['ema_restore_op'] = tf.group(*(tf.assign(var, bck.read_value())
                                           for bck, var in zip(bck_vars, variables)))
        ops['ema_assign_op'] = tf.group(*(tf.assign(var, ema.average(var).read_value())
                                          for var in variables))

    ret = named_scalars
    ret.update({
        "gradient_norm": grand_norm,
        "parameter_norm": param_norm,
    })
    return ret, ops
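# How the accumulation above recovers an average: after `update_cycle - 1`
# collect steps, each slot variable holds the sum of the earlier gradients
# and count_var holds `update_cycle - 1`, while the final step contributes
# its own gradient directly, so
#     scale * (g + s) = (g_last + sum(g_earlier)) / update_cycle,
# i.e. the mean gradient over the whole cycle. A numpy sketch:
import numpy as np

grads = [np.array([1.0]), np.array([2.0]), np.array([6.0])]  # update_cycle = 3
slot = grads[0] + grads[1]          # collected over the first two steps
count = 2.0                         # count_var after two collect ops
scale = 1.0 / (count + 1.0)
merged = scale * (grads[2] + slot)  # gradient handed to apply_gradients
assert np.allclose(merged, np.mean(grads, axis=0))  # [3.0]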
def additive_attention(query, memory, mem_mask, hidden_size,
                       ln=False, proj_memory=None, num_heads=1,
                       dropout=None, att_fun="add", scope=None):
    """
    additive attention model
    :param query: [batch_size, dim]
    :param memory: [batch_size, seq_len, mem_dim]
    :param mem_mask: [batch_size, seq_len]
    :param hidden_size: attention space dimension
    :param ln: whether to use layer normalization
    :param proj_memory: this is the mapped memory for saving memory
    :param num_heads: attention head number
    :param dropout: attention dropout, disabled by default
    :param att_fun: attention score function, "add" (default) or dot-product
    :param scope: variable scope name
    :return: a value matrix, [batch_size, mem_dim]
    """
    with tf.variable_scope(scope or "additive_attention",
                           dtype=tf.as_dtype(dtype.floatx())):
        if proj_memory is None:
            proj_memory = linear(memory, hidden_size, ln=ln, scope="feed_memory")
        query = linear(tf.expand_dims(query, 1),
                       hidden_size, ln=ln, scope="feed_query")

        query = split_heads(query, num_heads)
        proj_memory = split_heads(proj_memory, num_heads)

        if att_fun == "add":
            value = tf.tanh(query + proj_memory)
            logits = linear(value, 1, ln=False, scope="feed_logits")
            logits = tf.squeeze(logits, -1)
        else:
            logits = tf.matmul(query, proj_memory, transpose_b=True)
            logits = tf.squeeze(logits, 2)

        logits = util.mask_scale(logits, tf.expand_dims(mem_mask, 1))

        weights = tf.nn.softmax(logits, -1)  # [batch_size, seq_len]
        dweights = util.valid_apply_dropout(weights, dropout)

        memory = split_heads(memory, num_heads)
        value = tf.reduce_sum(tf.expand_dims(dweights, -1) * memory,
                              -2, keepdims=True)
        value = combine_heads(value)
        value = tf.squeeze(value, 1)

        results = {
            'weights': weights,
            'output': value,
            'cache_state': proj_memory
        }

        return results
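# A numpy sketch of the two scoring functions above for a single head:
# "add" applies a feed-forward score tanh(q + m_j) . w, anything else
# falls back to a dot product. The random weights here are illustrative
# stand-ins for the `linear(...)` projections.
import numpy as np

def additive_scores(q, m, w):
    # q: [hidden], m: [seq_len, hidden], w: [hidden] -> [seq_len]
    return np.tanh(q[None, :] + m) @ w

def dot_scores(q, m):
    # q: [hidden], m: [seq_len, hidden] -> [seq_len]
    return m @ q

q = np.random.randn(8).astype(np.float32)
m = np.random.randn(5, 8).astype(np.float32)
w = np.random.randn(8).astype(np.float32)
print(additive_scores(q, m, w).shape, dot_scores(q, m).shape)  # (5,) (5,)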