Example #1
File: fuel_lstm.py  Project: gunkisu/asr
def build_graph(FLAGS):
    """Define training graph.
  """
    with tf.device(FLAGS.device):
        # Graph input
        x = tf.placeholder(
            tf.float32,
            shape=(None, None,
                   FLAGS.n_input))  # (seq_len, batch_size, n_input)
        x_mask = tf.placeholder(tf.float32,
                                shape=(None, None))  # (seq_len, batch_size)
        state = tf.placeholder(tf.float32, shape=(2, None, FLAGS.n_hidden))
        y = tf.placeholder(tf.int32,
                           shape=(None, None))  # (seq_len, batch_size)

    # Define LSTM module
    _rnn = LSTMModule(FLAGS.n_hidden)
    # Call LSTM module
    h_rnn_3d, last_state = _rnn(x, state)
    # Reshape into [seq_len*batch_size, num_units]
    h_rnn_2d = tf.reshape(h_rnn_3d, [-1, FLAGS.n_hidden])
    # Define output layer
    _output = LinearCell(FLAGS.n_class)
    # Call output layer [seq_len*batch_size, n_class]
    h_logits = _output(h_rnn_2d, 'output')
    # Transform labels into one-hot vectors [seq_len*batch_size, n_class]
    y_1hot = tf.one_hot(tf.reshape(y, [-1]), depth=FLAGS.n_class)
    # Define loss and optimizer [seq_len*batch_size]
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_1hot,
                                                            logits=h_logits)
    # Reshape into [seq_len, batch_size]
    #  cross_entropy = tf.reshape(cross_entropy, [-1, FLAGS.batch_size])
    cost = tf.reduce_sum((cross_entropy * tf.reshape(x_mask, [-1])),
                         reduction_indices=0)
    return Graph(cost, x, x_mask, state, last_state, y)
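The cost above is the standard masked sequence cross-entropy: each timestep's loss is multiplied by the flattened mask so padded positions contribute nothing. A minimal NumPy sketch of the same computation (shapes and values are illustrative, not taken from the project):

import numpy as np

def masked_seq_xent(logits, labels, mask):
    """Masked cross-entropy over flattened (seq_len*batch_size) steps."""
    shifted = logits - logits.max(axis=1, keepdims=True)            # stabilize the softmax
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(labels.shape[0]), labels]            # per-step negative log-likelihood
    return (nll * mask).sum()                                       # padded steps contribute zero

logits = np.random.randn(6, 4)                          # seq_len*batch_size = 6, n_class = 4
labels = np.random.randint(0, 4, size=6)
mask = np.array([1, 1, 1, 1, 0, 0], dtype=np.float32)   # last two steps are padding
print(masked_seq_xent(logits, labels, mask))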
Example #2
def build_graph(args):
  with tf.device(args.device):
    # [batch_size, seq_len, ...]
    seq_x_data = tf.placeholder(dtype=tf.float32, shape=(None, None, args.n_input))
    seq_x_mask = tf.placeholder(dtype=tf.float32, shape=(None, None))
    seq_y_data = tf.placeholder(dtype=tf.int32, shape=(None, None))

    # [2, batch_size, ...]
    init_state = tf.placeholder(tf.float32, shape=(2, None, args.n_hidden))

  with tf.variable_scope('rnn'):
    _rnn = StackLSTMModule(num_units=args.n_hidden, num_layers=args.n_layer)

  with tf.variable_scope('label'):
    _label_logit = LinearCell(num_units=args.n_class)

  n_samples = tf.shape(seq_x_data)[0] 
  seq_len = tf.shape(seq_x_data)[1]

  seq_hid_3d, _ = _rnn(seq_x_data, init_state)
  seq_hid_2d = tf.reshape(seq_hid_3d, [-1, args.n_hidden])

  seq_label_logits = _label_logit(seq_hid_2d, 'label_logit')

  y_1hot = tf.one_hot(tf.reshape(seq_y_data, [-1]), depth=args.n_class)

  ml_cost = tf.nn.softmax_cross_entropy_with_logits(logits=seq_label_logits,  labels=y_1hot)
  ml_cost = tf.reduce_sum(ml_cost*tf.reshape(seq_x_mask, [-1]))

  pred_idx = tf.argmax(seq_label_logits, axis=1)

  train_graph = TrainGraph(ml_cost,
                           seq_x_data,
                           seq_x_mask,
                           seq_y_data,
                           init_state,
                           pred_idx,
                           tf.reshape(seq_label_logits, (n_samples, seq_len, -1)))

  return train_graph
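A hedged usage sketch for the graph above (not from the project): it assumes TrainGraph is a namedtuple whose fields mirror the constructor arguments, and feeds zero-filled dummy batches only to show the expected shapes.

import numpy as np
import tensorflow as tf
from argparse import Namespace

# Illustrative hyperparameters; real values come from the project's argument parser.
args = Namespace(device='/cpu:0', n_input=40, n_hidden=64, n_layer=2, n_class=10)
graph = build_graph(args)
train_op = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(graph.ml_cost)

batch_size, seq_len = 8, 100
feed = {graph.seq_x_data: np.zeros((batch_size, seq_len, args.n_input), np.float32),
        graph.seq_x_mask: np.ones((batch_size, seq_len), np.float32),
        graph.seq_y_data: np.zeros((batch_size, seq_len), np.int32),
        graph.init_state: np.zeros((2, batch_size, args.n_hidden), np.float32)}

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    _, cost = sess.run([train_op, graph.ml_cost], feed_dict=feed)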
Example #3
def build_graph_am(FLAGS):
    """Define training graph for acousic modeling. 
  Parameters used: batch_size, n_hidden, n_input, use_impl_type, 
    n_output_embed, n_class, weight_decay, grad_clipping, 
    start_from_ckpt, opt_file_name, use_slope_annel
  
  """
    tparams = OrderedDict()
    trng = RandomStreams(
        np.random.RandomState(np.random.randint(1024)).randint(
            np.iinfo(np.int32).max))
    print("Building the computational graph")

    # Define bunch of shared variables
    st_slope = sharedX(1., name='binary_sigmoid_gate')
    init_state = np.zeros((3, 2, FLAGS.n_batch, FLAGS.n_hidden),
                          dtype=np.float32)
    init_bound = np.zeros((2, FLAGS.n_batch), dtype=np.float32)
    tstate = sharedX(init_state, name='rnn_state')
    tboundary = sharedX(init_bound, name='rnn_bound')

    # Graph input: n_seq, n_batch, n_feat
    x = tensor.ftensor3('inp')
    x_mask = tensor.fmatrix('inp_mask')
    y = tensor.imatrix('tar')
    y_mask = tensor.fmatrix('tar_mask')

    # Define HM-LSTM module
    _rnn = HMLSTMModule(FLAGS.n_input,
                        FLAGS.n_hidden,
                        prefix='hm_lstm',
                        use_impl_type=FLAGS.use_impl_type)
    tparams = merge_dict(tparams, _rnn._params)

    # Call HM-LSTM module
    (h_rnn_1_3d, c_rnn_1_3d, h_rnn_2_3d, c_rnn_2_3d, h_rnn_3_3d, c_rnn_3_3d,
     z_1_3d, z_2_3d), last_state, last_boundary = \
            _rnn(x, tstate, tboundary)

    # Define output gating layer
    _o_gate = LinearCell([FLAGS.n_hidden] * 3,
                         3,
                         prefix='o_gate',
                         activation=tensor.nnet.sigmoid)
    tparams = merge_dict(tparams, _o_gate._params)

    # Call output gating layer
    h_o_gate = _o_gate([h_rnn_1_3d, h_rnn_2_3d, h_rnn_3_3d])

    # Define output embedding layer
    _o_embed = LinearCell([FLAGS.n_hidden] * 3,
                          FLAGS.n_output_embed,
                          prefix='o_embed',
                          activation=tensor.nnet.relu)
    tparams = merge_dict(tparams, _o_embed._params)

    # Call output embedding layer
    h_o_embed = _o_embed([
        h_rnn_1_3d * h_o_gate[:, :, 0][:, :, None],
        h_rnn_2_3d * h_o_gate[:, :, 1][:, :, None],
        h_rnn_3_3d * h_o_gate[:, :, 2][:, :, None]
    ])
    # Define output layer
    _output = LinearCell(FLAGS.n_output_embed, FLAGS.n_class, prefix='output')
    tparams = merge_dict(tparams, _output._params)
    # Call output layer
    h_logit = _output([h_o_embed])

    logit_shape = h_logit.shape
    logit = h_logit.reshape([logit_shape[0] * logit_shape[1], logit_shape[2]])
    logit = logit - logit.max(axis=1).dimshuffle(0, 'x')
    probs = logit - tensor.log(
        tensor.exp(logit).sum(axis=1).dimshuffle(0, 'x'))

    # Compute the cost
    y_flat = y.flatten()
    y_flat_idx = tensor.arange(y_flat.shape[0]) * FLAGS.n_class + y_flat
    cost = -probs.flatten()[y_flat_idx]
    cost = cost.reshape([y.shape[0], y.shape[1]])
    cost = (cost * y_mask).sum(0)
    cost_len = y_mask.sum(0)

    f_prop_updates = OrderedDict()
    f_prop_updates[tstate] = last_state
    f_prop_updates[tboundary] = last_boundary
    states = [tstate, tboundary]

    # Later use for visualization
    inps = [x, y, y_mask]

    print("Building f_log_prob function")
    f_log_prob = theano.function(inps, [cost, cost_len],
                                 updates=f_prop_updates)
    cost = cost.mean()

    # If the flag is on, apply L2 regularization on weights
    if FLAGS.weight_decay > 0.:
        weights_norm = 0.
        for k, v in tparams.iteritems():
            if '_W' in k:
                weights_norm += (v**2).sum()
        cost += weights_norm * FLAGS.weight_decay
    #print("Computing the gradients")
    grads = tensor.grad(cost, wrt=itemlist(tparams))
    grads = gradient_clipping(grads, tparams, FLAGS.gclip)
    # Compile the optimizer, the actual computational graph
    learning_rate = tensor.scalar(name='learning_rate')
    gshared = [
        theano.shared(p.get_value() * 0., name='%s_grad' % k)
        for k, p in tparams.iteritems()
    ]
    gsup = OrderedDict(izip(gshared, grads))
    print("Building f_prop function")
    f_prop = theano.function(inps, [cost],
                             updates=merge_dict(gsup, f_prop_updates))
    opt_updates, opt_tparams = adam(learning_rate, tparams, gshared)
    if FLAGS.start_from_ckpt and os.path.exists(FLAGS.opt_file_name):
        opt_params = np.load(FLAGS.opt_file_name)
        zipp(opt_params, opt_tparams)
    if FLAGS.use_slope_anneal:
        for kk, pp in opt_updates.items():
            k = str(kk)[-7:]
            if '_W' in k and not ('_v' in k or '_m' in k):
                # _v or _m come from the gradients buffers of adam optimizer
                updated_param = opt_updates[kk][:, -1]
                col_norms = tensor.sqrt(tensor.sqr(updated_param).sum())
                desired_norms = tensor.clip(col_norms, 0, 1.9365)
                ratio = (desired_norms / (1e-7 + col_norms))
                updated_param = tensor.set_subtensor(opt_updates[kk][:, -1],
                                                     updated_param * ratio)
                opt_updates[kk] = updated_param
    print("Building f_update function")
    f_update = theano.function([learning_rate], [],
                               updates=opt_updates,
                               on_unused_input='ignore')
    #print("Building f_debug function")
    f_debug = theano.function(
        [x], [h_rnn_1_3d, h_rnn_2_3d, h_rnn_3_3d, z_1_3d, z_2_3d],
        updates=f_prop_updates,
        on_unused_input='ignore')
    return f_prop, f_update, f_log_prob, f_debug, tparams, opt_tparams, states, \
        st_slope
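The logit/probs block above is a numerically stable log-softmax, and y_flat_idx indexes the flattened log-probability matrix at each timestep's target class. The same trick in plain NumPy (illustrative shapes):

import numpy as np

logit = np.random.randn(5, 3)                 # (seq_len*batch, n_class)
y_flat = np.random.randint(0, 3, size=5)      # flattened targets
n_class = logit.shape[1]

logit = logit - logit.max(axis=1, keepdims=True)                     # subtract row max for stability
log_probs = logit - np.log(np.exp(logit).sum(axis=1, keepdims=True))

# Flat indexing: element (i, y_flat[i]) of log_probs, as in y_flat_idx above.
y_flat_idx = np.arange(y_flat.shape[0]) * n_class + y_flat
nll = -log_probs.flatten()[y_flat_idx]        # per-step negative log-likelihood
print(nll)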
Example #4
def build_graph(args):
    with tf.device(args.device):
        # n_batch, n_seq, n_feat
        seq_x_data = tf.placeholder(tf.float32,
                                    shape=(None, None, args.n_input))
        seq_x_mask = tf.placeholder(tf.float32, shape=(None, None))
        seq_y_data = tf.placeholder(tf.int32, shape=(None, None))
        seq_y_data_for_action = tf.placeholder(tf.int32, shape=(None, None))

        init_state = tuple(
            lstm_state(args.n_hidden, l) for l in range(args.n_layer))

        seq_action = tf.placeholder(tf.float32,
                                    shape=(None, None, args.n_action))
        seq_advantage = tf.placeholder(tf.float32, shape=(None, None))
        seq_action_mask = tf.placeholder(tf.float32, shape=(None, None))

        step_x_data = tf.placeholder(tf.float32,
                                     shape=(None, args.n_input),
                                     name='step_x_data')

        embedding = tf.get_variable("embedding",
                                    [args.n_class, args.n_embedding],
                                    dtype=tf.float32)
        step_y_data_for_action = tf.placeholder(tf.int32,
                                                shape=(None, ),
                                                name='step_y_data_for_action')
        seq_y_input = tf.nn.embedding_lookup(embedding, seq_y_data_for_action)

        sample_y = tf.placeholder(tf.bool, name='sample_y')

    def lstm_cell():
        return tf.contrib.rnn.LSTMCell(num_units=args.n_hidden,
                                       forget_bias=0.0)

    cell = tf.contrib.rnn.MultiRNNCell(
        [lstm_cell() for _ in range(args.n_layer)])

    with tf.variable_scope('label'):
        _label_logit = LinearCell(num_units=args.n_class)

    with tf.variable_scope('action'):
        _action_logit = LinearCell(num_units=args.n_action)

    # sampling graph
    step_h_state, step_last_state = cell(step_x_data,
                                         init_state,
                                         scope='rnn/multi_rnn_cell')

    # no need to do stop_gradient because training is not done for the sampling graph
    step_label_logits = _label_logit(step_h_state, 'label_logit')
    step_label_probs = tf.nn.softmax(logits=step_label_logits,
                                     name='step_label_probs')
    step_y_input_answer = tf.nn.embedding_lookup(embedding,
                                                 step_y_data_for_action)

    step_y_1hot_pred = tf.argmax(step_label_probs, axis=-1)
    step_y_input_pred = tf.nn.embedding_lookup(embedding, step_y_1hot_pred)
    step_y_input = tf.where(sample_y, step_y_input_pred, step_y_input_answer)

    step_action_logits = _action_logit([step_h_state, step_y_input],
                                       'action_logit')
    step_action_probs = tf.nn.softmax(logits=step_action_logits,
                                      name='step_action_probs')
    step_action_samples = tf.multinomial(logits=step_action_logits,
                                         num_samples=1,
                                         name='step_action_samples')
    step_action_entropy = categorical_ent(step_action_probs)

    # training graph
    seq_hid_3d, _ = tf.nn.dynamic_rnn(cell=cell,
                                      inputs=seq_x_data,
                                      initial_state=init_state,
                                      scope='rnn')
    seq_hid_2d = tf.reshape(seq_hid_3d, [-1, args.n_hidden])

    seq_label_logits = _label_logit(seq_hid_2d, 'label_logit')

    y_1hot = tf.one_hot(tf.reshape(seq_y_data, [-1]), depth=args.n_class)

    ml_cost = tf.nn.softmax_cross_entropy_with_logits(logits=seq_label_logits,
                                                      labels=y_1hot)
    ml_cost = tf.reduce_sum(ml_cost * tf.reshape(seq_x_mask, [-1]))

    pred_idx = tf.argmax(seq_label_logits, axis=1)

    seq_hid_3d_rl = seq_hid_3d[:, :-1, :]  # drop the last time step
    seq_hid_2d_rl = tf.reshape(seq_hid_3d_rl, [-1, args.n_hidden])
    seq_hid_2d_rl = tf.stop_gradient(seq_hid_2d_rl)

    seq_y_input_2d = tf.reshape(seq_y_input[:, :-1], [-1, args.n_embedding])
    seq_action_logits = _action_logit([seq_hid_2d_rl, seq_y_input_2d],
                                      'action_logit')
    seq_action_probs = tf.nn.softmax(seq_action_logits)

    action_prob_entropy = categorical_ent(seq_action_probs)
    action_prob_entropy *= tf.reshape(seq_action_mask, [-1])
    action_prob_entropy = tf.reduce_sum(action_prob_entropy) / tf.reduce_sum(
        seq_action_mask)

    rl_cost = tf.reduce_sum(tf.log(seq_action_probs+1e-8) \
        * tf.reshape(seq_action, [-1,args.n_action]), axis=-1)
    rl_cost *= tf.reshape(seq_advantage, [-1])
    rl_cost = -tf.reduce_sum(rl_cost * tf.reshape(seq_action_mask, [-1]))

    train_graph = TrainGraph(ml_cost, rl_cost, seq_x_data, seq_x_mask,
                             seq_y_data, seq_y_data_for_action, init_state,
                             seq_action, seq_advantage, seq_action_mask,
                             pred_idx)

    sample_graph = SampleGraph(step_h_state, step_last_state, step_label_probs,
                               step_action_probs, step_action_samples,
                               step_x_data, step_y_data_for_action, init_state,
                               step_action_entropy, sample_y)

    return train_graph, sample_graph
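rl_cost above is a REINFORCE-style objective: the log-probability of the taken action (picked out by the one-hot seq_action) is weighted by the advantage and the action mask, then negated so that minimizing it increases expected reward. In miniature, with illustrative NumPy values:

import numpy as np

action_probs = np.array([[0.7, 0.3], [0.2, 0.8], [0.5, 0.5]])   # softmax(seq_action_logits)
action_1hot  = np.array([[1., 0.], [0., 1.], [1., 0.]])         # actions actually taken (seq_action)
advantage    = np.array([0.5, -0.2, 1.0])                       # seq_advantage, flattened
action_mask  = np.array([1.0, 1.0, 0.0])                        # seq_action_mask, flattened

log_p_taken = np.sum(np.log(action_probs + 1e-8) * action_1hot, axis=-1)
rl_cost = -np.sum(log_p_taken * advantage * action_mask)        # minimizing this raises expected reward
print(rl_cost)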
Example #5
File: ptb_lstm.py  Project: gunkisu/asr
def build_graph(FLAGS):
    """Define training graph.
  """
    tparams = OrderedDict()
    trng = RandomStreams(
        np.random.RandomState(np.random.randint(1024)).randint(
            np.iinfo(np.int32).max))
    print("Building the computational graph")
    # Define bunch of shared variables
    init_state = np.zeros((3, 2, FLAGS.batch_size, FLAGS.n_hidden),
                          dtype=np.float32)
    tstate = sharedX(init_state, name='rnn_state')
    # Graph input
    inp = tensor.matrix('inp', dtype='int64')
    inp_mask = tensor.matrix('inp_mask', dtype='float32')
    inp.tag.test_value = np.zeros((FLAGS.max_seq_len, FLAGS.batch_size),
                                  dtype='int64')
    inp_mask.tag.test_value = np.ones((FLAGS.max_seq_len, FLAGS.batch_size),
                                      dtype='float32')
    x, y = inp[:-1], inp[1:]
    y_mask = inp_mask[1:]
    # Define input embedding layer
    _i_embed = LinearCell(FLAGS.n_class,
                          FLAGS.n_input_embed,
                          prefix='i_embed',
                          bias=False,
                          input_is_int=True)
    tparams = merge_dict(tparams, _i_embed._params)
    # Call input embedding layer
    h_i_emb_3d = _i_embed(x)
    # Define the first LSTM module
    _rnn_1 = LSTMModule(FLAGS.n_input_embed, FLAGS.n_hidden, prefix='lstm_1')
    tparams = merge_dict(tparams, _rnn_1._params)
    # Call the first LSTM module
    (h_rnn_1_3d, c_rnn_1_3d), last_state_1 = _rnn_1(h_i_emb_3d, tstate[0])
    # Define the second LSTM module
    _rnn_2 = LSTMModule(FLAGS.n_hidden, FLAGS.n_hidden, prefix='lstm_2')
    tparams = merge_dict(tparams, _rnn_2._params)
    # Call the second LSTM module
    (h_rnn_2_3d, c_rnn_2_3d), last_state_2 = _rnn_2(h_rnn_1_3d, tstate[1])
    # Define the third LSTM module
    _rnn_3 = LSTMModule(FLAGS.n_hidden, FLAGS.n_hidden, prefix='lstm_3')
    tparams = merge_dict(tparams, _rnn_3._params)
    # Call the third LSTM module
    (h_rnn_3_3d, c_rnn_3_3d), last_state_3 = _rnn_3(h_rnn_2_3d, tstate[2])
    # Define output gating layer
    _o_gate = LinearCell([FLAGS.n_hidden] * 3,
                         3,
                         prefix='o_gate',
                         activation=tensor.nnet.sigmoid)
    tparams = merge_dict(tparams, _o_gate._params)
    # Call output gating layer
    h_o_gate = _o_gate([h_rnn_1_3d, h_rnn_2_3d, h_rnn_3_3d])
    # Define output embedding layer
    _o_embed = LinearCell([FLAGS.n_hidden] * 3,
                          FLAGS.n_output_embed,
                          prefix='o_embed',
                          activation=tensor.nnet.relu)
    tparams = merge_dict(tparams, _o_embed._params)
    # Call output embedding layer
    h_o_embed = _o_embed([
        h_rnn_1_3d * h_o_gate[:, :, 0][:, :, None],
        h_rnn_2_3d * h_o_gate[:, :, 1][:, :, None],
        h_rnn_3_3d * h_o_gate[:, :, 2][:, :, None]
    ])
    # Define output layer
    _output = LinearCell(FLAGS.n_output_embed, FLAGS.n_class, prefix='output')
    tparams = merge_dict(tparams, _output._params)
    # Call output layer
    h_logit = _output([h_o_embed])
    logit_shape = h_logit.shape
    logit = h_logit.reshape([logit_shape[0] * logit_shape[1], logit_shape[2]])
    logit = logit - logit.max(axis=1).dimshuffle(0, 'x')
    probs = logit - tensor.log(
        tensor.exp(logit).sum(axis=1).dimshuffle(0, 'x'))
    # Compute the cost
    y_flat = y.flatten()
    y_flat_idx = tensor.arange(y_flat.shape[0]) * FLAGS.n_class + y_flat
    cost = -probs.flatten()[y_flat_idx]
    cost = cost.reshape([y.shape[0], y.shape[1]])
    cost = (cost * y_mask).sum(0)
    cost_len = y_mask.sum(0)
    last_state = tensor.stack([last_state_1, last_state_2, last_state_3],
                              axis=0)
    f_prop_updates = OrderedDict()
    f_prop_updates[tstate] = last_state
    states = [tstate]
    # Later use for visualization
    inps = [inp, inp_mask]
    print("Building f_log_prob function")
    f_log_prob = theano.function(inps, [cost, cost_len],
                                 updates=f_prop_updates)
    cost = cost.mean()
    # If the flag is on, apply L2 regularization on weights
    if FLAGS.weight_decay > 0.:
        weights_norm = 0.
        for k, v in tparams.iteritems():
            weights_norm += (v**2).sum()
        cost += weights_norm * FLAGS.weight_decay
    #print("Computing the gradients")
    grads = tensor.grad(cost, wrt=itemlist(tparams))
    grads = gradient_clipping(grads, tparams, 1.)
    # Compile the optimizer, the actual computational graph
    learning_rate = tensor.scalar(name='learning_rate')
    gshared = [
        theano.shared(p.get_value() * 0., name='%s_grad' % k)
        for k, p in tparams.iteritems()
    ]
    gsup = OrderedDict(izip(gshared, grads))
    print("Building f_prop function")
    f_prop = theano.function(inps, [cost],
                             updates=merge_dict(gsup, f_prop_updates))
    opt_updates, opt_tparams = adam(learning_rate, tparams, gshared)
    if FLAGS.start_from_ckpt and os.path.exists(FLAGS.opt_file_name):
        opt_params = np.load(FLAGS.opt_file_name)
        zipp(opt_params, opt_tparams)
    print("Building f_update function")
    f_update = theano.function([learning_rate], [],
                               updates=opt_updates,
                               on_unused_input='ignore')
    #print("Building f_debug function")
    f_debug = theano.function(inps, [h_rnn_1_3d, h_rnn_2_3d, h_rnn_3_3d],
                              updates=f_prop_updates,
                              on_unused_input='ignore')
    return f_prop, f_update, f_log_prob, f_debug, tparams, opt_tparams, states, None
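As in Example #3, the recurrent state is carried across minibatches through the shared variable tstate: every function built with f_prop_updates writes last_state back into tstate, giving truncated BPTT over consecutive batches. A minimal Theano sketch of that update pattern (a running sum stands in for the RNN state):

import numpy as np
import theano
import theano.tensor as T

state = theano.shared(np.zeros((2, 3), dtype='float32'), name='rnn_state')
x = T.fmatrix('x')
new_state = state + x                           # stands in for the RNN's last_state

# The update dict persists new_state for the next call, like f_prop_updates[tstate].
step = theano.function([x], new_state, updates={state: new_state})
step(np.ones((2, 3), dtype='float32'))
print(state.get_value())                        # carries over to the next minibatch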
Example #6
def build_graph(args):
    with tf.device(args.device):
        # n_batch, n_seq, n_feat
        seq_x_data = tf.placeholder(dtype=tf.float32,
                                    shape=(None, None, args.n_input),
                                    name='seq_x_data')
        seq_x_mask = tf.placeholder(dtype=tf.float32,
                                    shape=(None, None),
                                    name='seq_x_mask')
        seq_y_data = tf.placeholder(dtype=tf.int32, shape=(None, None))

        init_state = tuple(
            lstm_state(args.n_hidden, l) for l in range(args.n_layer))

        step_x_data = tf.placeholder(tf.float32,
                                     shape=(None, args.n_input),
                                     name='step_x_data')

    with tf.variable_scope('rnn'):

        def lstm_cell():
            return tf.contrib.rnn.LSTMCell(num_units=args.n_hidden,
                                           forget_bias=0.0)

        cell = tf.contrib.rnn.MultiRNNCell(
            [lstm_cell() for _ in range(args.n_layer)])

    with tf.variable_scope('label'):
        _label_logit = LinearCell(num_units=args.n_class)

    # training graph
    seq_hid_3d, _ = tf.nn.dynamic_rnn(cell=cell,
                                      inputs=seq_x_data,
                                      initial_state=init_state,
                                      scope='rnn')
    seq_hid_2d = tf.reshape(seq_hid_3d, [-1, args.n_hidden])

    seq_label_logits = _label_logit(seq_hid_2d, 'label_logit')

    y_1hot = tf.one_hot(tf.reshape(seq_y_data, [-1]), depth=args.n_class)

    ml_cost = tf.nn.softmax_cross_entropy_with_logits(logits=seq_label_logits,
                                                      labels=y_1hot)
    ml_cost = tf.reduce_sum(ml_cost * tf.reshape(seq_x_mask, [-1]))

    pred_idx = tf.argmax(seq_label_logits, axis=1)

    seq_label_probs = tf.nn.softmax(seq_label_logits, name='seq_label_probs')

    # testing graph
    step_h_state, step_last_state = cell(step_x_data,
                                         init_state,
                                         scope='rnn/multi_rnn_cell')
    step_label_logits = _label_logit(step_h_state, 'label_logit')
    step_label_probs = tf.nn.softmax(logits=step_label_logits,
                                     name='step_label_probs')

    train_graph = TrainGraph(ml_cost, seq_x_data, seq_x_mask, seq_y_data,
                             init_state, pred_idx, seq_label_probs)

    test_graph = TestGraph(step_x_data, init_state, step_last_state,
                           step_label_probs)

    return train_graph, test_graph
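The test graph calls the cell one step at a time, so decoding feeds one frame per step and threads step_last_state back in as init_state. A heavily hedged sketch of that loop (TestGraph field names are assumed to mirror the constructor arguments, `sess` is a session holding trained variables, and zero_lstm_state is a hypothetical helper producing numpy arrays that match the lstm_state placeholders):

# Hedged decoding sketch, not from the project.
train_graph, test_graph = build_graph(args)
state = zero_lstm_state(args.n_layer, batch_size=1, n_hidden=args.n_hidden)  # hypothetical helper
label_probs = []
for frame in utterance:                        # utterance: (seq_len, n_input) numpy array
    feed = {test_graph.step_x_data: frame[None, :],
            test_graph.init_state: state}      # nested tuple of placeholders fed with matching arrays
    probs, state = sess.run([test_graph.step_label_probs,
                             test_graph.step_last_state], feed_dict=feed)
    label_probs.append(probs[0])               # per-frame class posteriors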
Example #7
File: train.py  Project: gunkisu/asr
def build_graph(FLAGS):
    # Define input data
    with tf.device(FLAGS.device):
        # input sequence (seq_len, num_samples, num_input)
        x_data = tf.placeholder(dtype=tf.float32,
                                shape=(None, None, FLAGS.n_input),
                                name='x_data')

        # input mask (seq_len, num_samples)
        x_mask = tf.placeholder(dtype=tf.float32,
                                shape=(None, None),
                                name='x_mask')

        # gt label (seq_len, num_samples)
        y_data = tf.placeholder(dtype=tf.int32,
                                shape=(None, None),
                                name='y_data')

        # init state (but mostly init with 0s)
        init_state = tf.placeholder(dtype=tf.float32,
                                    shape=(None, FLAGS.n_hidden),
                                    name='init_state')

        # init counter (but mostly init with 0s)
        init_cntr = tf.placeholder(dtype=tf.float32,
                                   shape=(None, 1),
                                   name='init_cntr')

    # Get one-hot label
    y_1hot = tf.one_hot(y_data, depth=FLAGS.n_class)

    # Get sequence length and batch size
    seq_len = tf.shape(x_data)[0]
    num_samples = tf.shape(x_data)[1]

    # For each layer
    policy_data_list = []
    prev_hid_data = x_data
    for l in range(FLAGS.n_layer):
        # Set input data (concat input and mask)
        prev_input = tf.concat(
            values=[prev_hid_data,
                    tf.expand_dims(x_mask, axis=-1)], axis=-1)

        # Set skim lstm
        with tf.variable_scope('lstm_{}'.format(l)) as vs:
            skim_lstm = SkimLSTMModule(num_units=FLAGS.n_hidden,
                                       max_skims=FLAGS.n_action,
                                       min_reads=FLAGS.n_read,
                                       forget_bias=FLAGS.forget_bias,
                                       use_input=FLAGS.use_input,
                                       use_skim=FLAGS.use_skim)

            # Run bidir skim lstm
            outputs = skim_lstm(inputs=prev_input,
                                init_state=[init_state, init_cntr],
                                use_bidir=True)

        # Get output
        hid_data, read_mask, act_mask, act_lgp = outputs

        # Split data
        fwd_hid_data, bwd_hid_data = tf.split(value=hid_data,
                                              num_or_size_splits=2,
                                              axis=2)
        fwd_read_mask, bwd_read_mask = tf.split(value=read_mask,
                                                num_or_size_splits=2,
                                                axis=2)
        fwd_act_mask, bwd_act_mask = tf.split(value=act_mask,
                                              num_or_size_splits=2,
                                              axis=2)
        fwd_act_lgp, bwd_act_lgp = tf.split(value=act_lgp,
                                            num_or_size_splits=2,
                                            axis=2)

        # Set summary
        tf.summary.image(
            name='fwd_results_{}'.format(l),
            tensor=tf.concat(values=[
                tf.tile(
                    tf.expand_dims(tf.expand_dims(tf.transpose(x_mask, [1, 0]),
                                                  axis=-1),
                                   axis=1), [1, 20, 1, 1]),
                tf.tile(
                    tf.expand_dims(tf.expand_dims(tf.transpose(
                        tf.to_float(y_data) / tf.to_float(FLAGS.n_class),
                        [1, 0]),
                                                  axis=-1),
                                   axis=1), [1, 20, 1, 1]),
                tf.tile(
                    tf.expand_dims(tf.transpose(fwd_read_mask, [1, 0, 2]),
                                   axis=1), [1, 20, 1, 1]),
                tf.tile(
                    tf.expand_dims(tf.transpose(fwd_act_mask, [1, 0, 2]),
                                   axis=1), [1, 20, 1, 1]),
            ],
                             axis=1))
        tf.summary.image(
            name='bwd_results_{}'.format(l),
            tensor=tf.concat(values=[
                tf.tile(
                    tf.expand_dims(tf.expand_dims(tf.transpose(x_mask, [1, 0]),
                                                  axis=-1),
                                   axis=1), [1, 20, 1, 1]),
                tf.tile(
                    tf.expand_dims(tf.expand_dims(tf.transpose(
                        tf.to_float(y_data) / tf.to_float(FLAGS.n_class),
                        [1, 0]),
                                                  axis=-1),
                                   axis=1), [1, 20, 1, 1]),
                tf.tile(
                    tf.expand_dims(tf.transpose(bwd_read_mask, [1, 0, 2]),
                                   axis=1), [1, 20, 1, 1]),
                tf.tile(
                    tf.expand_dims(tf.transpose(bwd_act_mask, [1, 0, 2]),
                                   axis=1), [1, 20, 1, 1]),
            ],
                             axis=1))

        # Set baseline
        with tf.variable_scope("fwd_baseline_{}".format(l)) as vs:
            fwd_policy_input = tf.reshape(tf.stop_gradient(fwd_hid_data),
                                          [-1, FLAGS.n_hidden])
            fwd_baseline_cell = LinearCell(num_units=1)
            fwd_baseline = fwd_baseline_cell(fwd_policy_input)
            fwd_baseline = tf.reshape(fwd_baseline, [seq_len, num_samples])

        with tf.variable_scope("bwd_baseline_{}".format(l)) as vs:
            bwd_policy_input = tf.reshape(tf.stop_gradient(bwd_hid_data),
                                          [-1, FLAGS.n_hidden])
            bwd_baseline_cell = LinearCell(num_units=1)
            bwd_baseline = bwd_baseline_cell(bwd_policy_input)
            bwd_baseline = tf.reshape(bwd_baseline, [seq_len, num_samples])

        # Set next input
        prev_hid_data = hid_data

        # Save data
        policy_data_list.append([
            tf.squeeze(fwd_read_mask),
            tf.squeeze(fwd_act_mask), fwd_act_lgp, fwd_baseline
        ])
        policy_data_list.append([
            tf.squeeze(bwd_read_mask),
            tf.squeeze(bwd_act_mask), bwd_act_lgp, bwd_baseline
        ])

    # Set output layer
    with tf.variable_scope('output') as vs:
        output_cell = LinearCell(FLAGS.n_class)
        output_logit = output_cell(
            tf.reshape(prev_hid_data, [-1, 2 * FLAGS.n_hidden]))
        output_logit = tf.reshape(output_logit,
                                  (seq_len, num_samples, FLAGS.n_class))

    # Frame-wise cross entropy
    frame_cce = tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.reshape(y_1hot, [-1, FLAGS.n_class]),
        logits=tf.reshape(output_logit, [-1, FLAGS.n_class]))
    frame_cce *= tf.reshape(x_mask, [-1])

    # Frame mean cce
    mean_frame_cce = tf.reduce_sum(frame_cce) / tf.reduce_sum(x_mask)
    tf.summary.scalar(name='frame_cce', tensor=mean_frame_cce)

    # Model cce
    model_cce = tf.reduce_sum(frame_cce) / tf.to_float(num_samples)

    # Frame-wise accuracy
    frame_accr = tf.to_float(
        tf.equal(tf.argmax(output_logit, axis=-1), tf.argmax(
            y_1hot, axis=-1))) * x_mask
    sample_frame_accr = tf.reduce_sum(frame_accr, axis=0) / tf.reduce_sum(
        x_mask, axis=0)
    mean_frame_accr = tf.reduce_sum(frame_accr) / tf.reduce_sum(x_mask)
    tf.summary.scalar(name='frame_accr', tensor=mean_frame_accr)

    # Sample-wise REWARD
    sample_reward = sample_frame_accr

    # Define policy cost for each network
    baseline_cost_list = []
    policy_cost_list = []
    read_ratio_list = []
    for i, policy_data in enumerate(policy_data_list):
        # Get data
        read_mask, act_mask, act_lgp, baseline = policy_data

        # Get read ratio
        read_ratio = tf.reduce_sum(read_mask, axis=0) / tf.reduce_sum(x_mask,
                                                                      axis=0)
        skim_ratio = 1.0 - read_ratio

        # combine reward (frame accuracy and skim ratio)
        original_reward = (sample_reward + skim_ratio * 0.0)

        # revised with baseline
        revised_reward = (tf.expand_dims(original_reward, axis=0) -
                          baseline) * act_mask

        # baseline cost
        baseline_cost = tf.reduce_sum(tf.square(revised_reward))

        # policy cost
        policy_cost = tf.stop_gradient(revised_reward) * tf.reduce_sum(
            act_lgp, axis=-1) * act_mask
        policy_cost = -tf.reduce_sum(policy_cost) / tf.to_float(num_samples)

        # Save values
        baseline_cost_list.append(baseline_cost)
        policy_cost_list.append(policy_cost)
        read_ratio_list.append(tf.reduce_mean(read_ratio, keep_dims=True))

        tf.summary.scalar(name='frame_read_ratio_{}'.format(i),
                          tensor=tf.reduce_mean(read_ratio))

    tf.summary.scalar(name='policy_cost', tensor=tf.add_n(policy_cost_list))
    tf.summary.scalar(name='baseline_cost',
                      tensor=tf.add_n(baseline_cost_list))

    return Graph(x_data=x_data,
                 x_mask=x_mask,
                 y_data=y_data,
                 init_state=init_state,
                 init_cntr=init_cntr,
                 mean_accr=mean_frame_accr,
                 mean_loss=mean_frame_cce,
                 ml_cost=model_cce,
                 rl_cost=tf.add_n(policy_cost_list),
                 bl_cost=tf.add_n(baseline_cost_list),
                 read_ratio_list=tf.concat(read_ratio_list, axis=0))
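Each policy head is paired with a learned baseline: the baseline is regressed toward the reward with a squared-error cost, and the policy gradient uses the baseline-subtracted reward with its gradient stopped, as in the two costs above. In miniature (illustrative NumPy values):

import numpy as np

reward    = np.array([0.8, 0.8, 0.8])      # sample-level reward, broadcast over steps
baseline  = np.array([0.5, 0.7, 0.9])      # predictions from the baseline LinearCell
act_mask  = np.array([1.0, 1.0, 0.0])      # last action step is padding
log_p_act = np.array([-0.2, -1.1, -0.6])   # log-probability of the taken actions

advantage = (reward - baseline) * act_mask
baseline_cost = np.sum(np.square(advantage))      # regress the baseline toward the reward
policy_cost = -np.sum(advantage * log_p_act)      # advantage treated as a constant,
                                                  # as tf.stop_gradient does in the graph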
Example #8
def build_graph(args):
    ##################
    # Input variable #
    ##################
    with tf.device(args.device):
        ##################
        # Sequence-level #
        ##################
        # Input sequence data [batch_size, seq_len, ...]
        seq_x_data = tf.placeholder(dtype=tf.float32,
                                    shape=(None, None, args.n_input),
                                    name='seq_x_data')
        seq_x_mask = tf.placeholder(dtype=tf.float32,
                                    shape=(None, None),
                                    name='seq_x_mask')
        seq_y_data = tf.placeholder(dtype=tf.int32,
                                    shape=(None, None),
                                    name='seq_y_data')

        # Action related data [batch_size, seq_len, ...]
        seq_a_data = tf.placeholder(dtype=tf.float32,
                                    shape=(None, None, args.n_action),
                                    name='seq_a_data')
        seq_a_mask = tf.placeholder(dtype=tf.float32,
                                    shape=(None, None),
                                    name='seq_a_mask')
        seq_advantage = tf.placeholder(dtype=tf.float32,
                                       shape=(None, None),
                                       name='seq_advantage')
        seq_reward = tf.placeholder(dtype=tf.float32,
                                    shape=(None, None),
                                    name='seq_reward')
        ##############
        # Step-level #
        ##############
        # Input step data [batch_size, n_input]
        step_x_data = tf.placeholder(dtype=tf.float32,
                                     shape=(None, args.n_input),
                                     name='step_x_data')

        # Prev action
        prev_a_data = tf.placeholder(dtype=tf.float32,
                                     shape=(None, args.n_action),
                                     name='prev_a_data')

        # Prev state [2, batch_size, n_hidden]
        prev_state = tf.placeholder(dtype=tf.float32,
                                    shape=(2, None, args.n_hidden),
                                    name='prev_states')

        # Flag for sampling
        use_sampling = tf.placeholder(dtype=tf.bool, name='use_sampling')
    ###########
    # Modules #
    ###########
    # Recurrent Module (LSTM)
    with tf.variable_scope('rnn'):
        _rnn = LSTMModule(num_units=args.n_hidden)

    # Labelling Module (FF)
    with tf.variable_scope('label'):
        _label_logit = LinearCell(num_units=args.n_class, activation=None)

    # Actioning Module (FF)
    with tf.variable_scope('action'):
        _action_logit = LinearCell(num_units=args.n_action, activation=None)

    ##################
    # Sampling graph #
    ##################
    # Recurrent update
    step_h_state, step_last_state = _rnn(inputs=tf.concat(
        values=[step_x_data, prev_a_data], axis=-1),
                                         init_state=prev_state,
                                         one_step=True)

    # Label logits/probs
    step_label_logits = _label_logit(inputs=step_h_state, scope='label_logit')
    step_label_probs = tf.nn.softmax(logits=step_label_logits)

    # Action logits
    if args.ref_input:
        step_action_input = [step_h_state, step_x_data]
    else:
        step_action_input = step_h_state
    step_action_logits = _action_logit(inputs=step_action_input,
                                       scope='action_logit')

    # Action probs
    step_action_probs = tf.nn.softmax(logits=step_action_logits)

    # Action sampling
    step_action_samples = tf.cond(
        pred=use_sampling,
        fn1=lambda: tf.multinomial(logits=step_action_logits, num_samples=1),
        fn2=lambda: tf.reshape(tf.argmax(input=step_action_logits, axis=-1),
                               [-1, 1]))

    # Set sampling graph
    sample_graph = SampleGraph(step_x_data, prev_a_data, prev_state,
                               step_h_state, step_last_state, step_label_probs,
                               step_action_probs, step_action_samples,
                               use_sampling)

    ##################
    # Training graph #
    ##################
    # Recurrent update
    init_state = tf.zeros(shape=[2, tf.shape(seq_x_data)[0], args.n_hidden])
    seq_h_state_3d, seq_last_state = _rnn(inputs=tf.concat(values=[
        seq_x_data,
        tf.concat([
            tf.zeros(shape=[tf.shape(seq_x_data)[0], 1, args.n_action]),
            seq_a_data[:, :-1, :]
        ],
                  axis=1)
    ],
                                                           axis=-1),
                                          init_state=init_state,
                                          one_step=False)

    # Label logits/probs
    seq_label_logits = _label_logit(inputs=tf.reshape(seq_h_state_3d,
                                                      [-1, args.n_hidden]),
                                    scope='label_logit')

    # Action logits
    if args.ref_input:
        seq_a_input = [
            tf.reshape(seq_h_state_3d, [-1, args.n_hidden]),
            tf.reshape(seq_x_data, [-1, args.n_input])
        ]
    else:
        seq_a_input = tf.reshape(seq_h_state_3d, [-1, args.n_hidden])
    seq_a_logits = _action_logit(inputs=seq_a_input, scope='action_logit')

    # Action probs
    seq_a_probs = tf.nn.softmax(logits=seq_a_logits)

    # Action entropy
    seq_a_ent = categorical_ent(dist=seq_a_probs) * tf.reshape(
        seq_a_mask, [-1])

    # ML cost (logP(label))
    seq_y_1hot = tf.one_hot(indices=tf.reshape(seq_y_data, [-1]),
                            depth=args.n_class)
    seq_ml_cost = tf.nn.softmax_cross_entropy_with_logits(
        logits=seq_label_logits, labels=seq_y_1hot)
    seq_ml_cost *= tf.reshape(seq_x_mask, [-1])

    # RL cost (logP(action)*reward)
    seq_rl_cost = -tf.log(seq_a_probs + 1e-8) * tf.reshape(
        seq_a_data, [-1, args.n_action])
    seq_rl_cost = tf.reduce_sum(seq_rl_cost, axis=-1)
    seq_rl_cost *= tf.reshape(seq_advantage, [-1]) * tf.reshape(
        seq_a_mask, [-1])

    # RL cost wo/ baseline
    seq_real_rl_cost = tf.reshape(seq_reward, [-1]) * tf.reshape(
        seq_a_mask, [-1])

    # Set training graph
    train_graph = TrainGraph(seq_x_data, seq_x_mask, seq_y_data, seq_a_data,
                             seq_a_mask, seq_advantage, seq_reward,
                             seq_label_logits, seq_ml_cost, seq_rl_cost,
                             seq_real_rl_cost, seq_a_ent)

    return train_graph, sample_graph
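step_action_samples above uses tf.cond to switch at run time between a stochastic draw from the action distribution and a greedy argmax. A self-contained TF 1.x sketch of that switch:

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, 0.1]])            # one step, three actions
use_sampling = tf.placeholder(tf.bool, name='use_sampling')

action = tf.cond(use_sampling,
                 lambda: tf.multinomial(logits=logits, num_samples=1),
                 lambda: tf.reshape(tf.argmax(logits, axis=-1), [-1, 1]))

with tf.Session() as sess:
    print(sess.run(action, feed_dict={use_sampling: False}))  # greedy: always [[0]]
    print(sess.run(action, feed_dict={use_sampling: True}))   # sampled from softmax(logits)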