Example #1
    def createGraph(self):

        self.input = tf.placeholder(tf.int32, [self.batch_size, self.seq_len],
                                    name='inputs')
        self.targs = tf.placeholder(tf.int32, [self.batch_size, self.seq_len],
                                    name='targets')
        onehot = tf.one_hot(self.input, self.vocab_size, name='input_oh')

        inputs = tf.split(onehot, self.seq_len, 1)
        inputs = [tf.squeeze(i, [1]) for i in inputs]
        targets = tf.split(self.targs, self.seq_len, 1)

        with tf.variable_scope("posRNN"):

            cells = [GRUCell(self.num_hidden) for _ in range(self.num_layers)]

            stacked = MultiRNNCell(cells, state_is_tuple=True)
            self.zero_state = stacked.zero_state(self.batch_size, tf.float32)

            outputs, self.last_state = seq2seq.rnn_decoder(
                inputs, self.zero_state, stacked)

            w = tf.get_variable(
                "w", [self.num_hidden, self.vocab_size],
                tf.float32,
                initializer=tf.random_normal_initializer(stddev=0.02))
            b = tf.get_variable("b", [self.vocab_size],
                                initializer=tf.constant_initializer(0.0))
            logits = [tf.matmul(o, w) + b for o in outputs]

            const_weights = [
                tf.ones([self.batch_size]) for _ in range(self.seq_len)
            ]
            self.loss = seq2seq.sequence_loss(logits, targets, const_weights)

            self.opt = tf.train.AdamOptimizer(0.001,
                                              beta1=0.5).minimize(self.loss)

        with tf.variable_scope("posRNN", reuse=True):

            batch_size = 1
            self.s_inputs = tf.placeholder(tf.int32, [batch_size],
                                           name='s_inputs')
            s_onehot = tf.one_hot(self.s_inputs,
                                  self.vocab_size,
                                  name='s_input_oh')

            self.s_zero_state = stacked.zero_state(batch_size, tf.float32)
            s_outputs, self.s_last_state = seq2seq.rnn_decoder(
                [s_onehot], self.s_zero_state, stacked)
            s_outputs = tf.reshape(s_outputs, [1, self.num_hidden])
            self.s_probs = tf.nn.softmax(tf.matmul(s_outputs, w) + b)
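
# --- Hypothetical sampling driver for the graph above (not in the original
# snippet). Assumes a trained model instance `m`, an open tf.Session `sess`,
# and integer character ids; TF 1.x allows feeding the GRU zero-state tensors
# directly, so no extra placeholders are needed.
import numpy as np

def sample(m, sess, seed_id, num_steps):
    state = sess.run(m.s_zero_state)
    char_id = np.array([seed_id])                      # shape [1], like s_inputs
    generated = [seed_id]
    for _ in range(num_steps):
        feed = {m.s_inputs: char_id}
        for state_tensor, value in zip(m.s_zero_state, state):
            feed[state_tensor] = value                 # carry the GRU state forward
        probs, state = sess.run([m.s_probs, m.s_last_state], feed)
        char_id = np.array([np.random.choice(probs.shape[1], p=probs[0])])
        generated.append(int(char_id[0]))
    return generated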
Example #2
class RNNpropModel(nn_opt.BasicNNOptModel):
    def _build_pre(self):
        self.dimA = 20
        self.cellA = MultiRNNCell(
            [LSTMCell(self.dimA) for _ in range(2)])  # one cell object per layer
        self.b1 = 0.95
        self.b2 = 0.95
        self.lr = 0.1
        self.eps = 1e-8

    def _build_input(self):
        self.x = self.ph([None])
        self.m = self.ph([None])
        self.v = self.ph([None])
        self.b1t = self.ph([])
        self.b2t = self.ph([])
        self.sid = self.ph([])
        self.cellA_state = tuple(
            (self.ph([None, size.c]), self.ph([None, size.h]))
            for size in self.cellA.state_size)
        self.input_state = [
            self.sid, self.b1t, self.b2t, self.x, self.m, self.v,
            self.cellA_state
        ]

    def _build_initial(self):
        x = self.x
        m = tf.zeros(shape=tf.shape(x))
        v = tf.zeros(shape=tf.shape(x))
        b1t = tf.ones([])
        b2t = tf.ones([])
        cellA_state = self.cellA.zero_state(tf.size(x), tf.float32)
        self.initial_state = [tf.zeros([]), b1t, b2t, x, m, v, cellA_state]

    # return state, fx
    def _iter(self, f, i, state):
        sid, b1t, b2t, x, m, v, cellA_state = state

        fx, grad = self._get_fx(f, i, x)
        grad = tf.stop_gradient(grad)

        m = self.b1 * m + (1 - self.b1) * grad
        v = self.b2 * v + (1 - self.b2) * (grad**2)

        b1t *= self.b1
        b2t *= self.b2

        sv = tf.sqrt(v / (1 - b2t)) + self.eps

        last = tf.stack([grad / sv, (m / (1 - b1t)) / sv], 1)
        last = tf.nn.elu(self.fc(last, 20))

        with tf.variable_scope("cellA"):
            lastA, cellA_state = self.cellA(last, cellA_state)
        with tf.variable_scope("fc_A"):
            a = self.fc(lastA, 1)[:, 0]

        a = tf.tanh(a) * self.lr
        x -= a

        return [sid + 1, b1t, b2t, x, m, v, cellA_state], fx
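
# --- Standalone NumPy sketch (illustrative, not from the original code) of the
# two RNN input features built in _iter above: the gradient and the momentum,
# each normalized Adam-style by the bias-corrected second moment.
import numpy as np

def rnnprop_features(grad, m, v, b1t, b2t, b1=0.95, b2=0.95, eps=1e-8):
    m = b1 * m + (1 - b1) * grad           # first-moment estimate
    v = b2 * v + (1 - b2) * grad ** 2      # second-moment estimate
    b1t, b2t = b1t * b1, b2t * b2          # running bias-correction products
    sv = np.sqrt(v / (1 - b2t)) + eps
    features = np.stack([grad / sv, (m / (1 - b1t)) / sv], axis=1)
    return features, (m, v, b1t, b2t)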
Example #3
    def impress(self, state_code, pre_impress_states):
        # LSTM, 3 layers
        self.impress_lay_num = 3
        with tf.variable_scope('impress', reuse=tf.AUTO_REUSE):
            def loop_fn(time, cell_output, cell_state, loop_state):
                if cell_output is None:  # time == 0
                    # initialization
                    input = state_code
                    state = state_
                    emit_output = None
                    loop_state = None
                else:
                    input = cell_output
                    emit_output = cell_output
                    state = cell_state
                    loop_state = None
                    
                elements_finished = (time >= 1)
                return (elements_finished, input, state, emit_output, loop_state)

            multirnn_cell = MultiRNNCell(
                [LSTMCell(self.impress_dim) for _ in range(self.impress_lay_num)],
                state_is_tuple=True)
            
            if pre_impress_states is None:
                state_ = multirnn_cell.zero_state(self.batch_size, tf.float32)
            else:
                state_ = pre_impress_states
    
            emit_ta, states, final_loop_state = tf.nn.raw_rnn(multirnn_cell, loop_fn)
            # transpose to put the batch dimension first
            state_impress_code = tf.transpose(emit_ta.stack(), [1, 0, 2])[0]
            
            return state_impress_code, final_loop_state
Example #4
    def _create_decoder_cell(self):
        enc_outputs, enc_states, enc_seq_len = self.enc_outputs, self.enc_states, self.enc_seq_len
        batch_size = self.batch_size * self.cfg.beam_size if self.use_beam_search else self.batch_size
        with tf.variable_scope("attention"):
            if self.cfg.attention == "luong":  # Luong attention mechanism
                attention_mechanism = LuongAttention(
                    num_units=self.cfg.num_units,
                    memory=enc_outputs,
                    memory_sequence_length=enc_seq_len)
            else:  # default using Bahdanau attention mechanism
                attention_mechanism = BahdanauAttention(
                    num_units=self.cfg.num_units,
                    memory=enc_outputs,
                    memory_sequence_length=enc_seq_len)

        # cell input function to keep input/output dimensions the same
        # reference: https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/AttentionWrapper
        def cell_input_fn(inputs, attention):
            if not self.cfg.use_attention_input_feeding:
                return inputs
            input_project = tf.layers.Dense(self.cfg.num_units,
                                            dtype=tf.float32,
                                            name='attn_input_feeding')
            return input_project(tf.concat([inputs, attention], axis=-1))

        if self.cfg.top_attention:  # apply attention mechanism only on the top decoder layer
            cells = [
                self._create_rnn_cell() for _ in range(self.cfg.num_layers)
            ]
            cells[-1] = AttentionWrapper(
                cells[-1],
                attention_mechanism=attention_mechanism,
                name="Attention_Wrapper",
                attention_layer_size=self.cfg.num_units,
                initial_cell_state=enc_states[-1],
                cell_input_fn=cell_input_fn)
            initial_state = [state for state in enc_states]
            initial_state[-1] = cells[-1].zero_state(batch_size=batch_size,
                                                     dtype=tf.float32)
            dec_init_states = tuple(initial_state)
            cells = MultiRNNCell(cells)
        else:
            cells = MultiRNNCell(
                [self._create_rnn_cell() for _ in range(self.cfg.num_layers)])
            cells = AttentionWrapper(cells,
                                     attention_mechanism=attention_mechanism,
                                     name="Attention_Wrapper",
                                     attention_layer_size=self.cfg.num_units,
                                     initial_cell_state=enc_states,
                                     cell_input_fn=cell_input_fn)
            dec_init_states = cells.zero_state(
                batch_size=batch_size,
                dtype=tf.float32).clone(cell_state=enc_states)
        return cells, dec_init_states
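
# --- Sketch (not part of the original snippet) of how the returned cell and
# initial state are typically consumed for training-time decoding with the
# TF 1.x tf.contrib.seq2seq API; `dec_emb_inputs`, `dec_seq_len`, `vocab_size`,
# and `max_dec_len` are assumed to be defined elsewhere in the model.
helper = tf.contrib.seq2seq.TrainingHelper(inputs=dec_emb_inputs,
                                           sequence_length=dec_seq_len)
decoder = tf.contrib.seq2seq.BasicDecoder(cell=cells,
                                          helper=helper,
                                          initial_state=dec_init_states,
                                          output_layer=tf.layers.Dense(vocab_size))
dec_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder,
                                                      maximum_iterations=max_dec_len)
logits = dec_outputs.rnn_output  # [batch_size, decode_steps, vocab_size]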
Example #5
class LSTMOptModel(nn_opt.BasicNNOptModel):
    def lstm_cell(self):
        return LSTMCell(num_units=self.dimH)

    def _build_pre(self, size):
        self.dimH = size
        self.num_of_layers = 2
        self.cellH = MultiRNNCell(
            [self.lstm_cell() for _ in range(self.num_of_layers)])
        self.lr = 0.1

    def _build_input(self):
        self.x = self.ph([None])
        self.cellH_state = tuple(
            (self.ph([None, size.c]), self.ph([None, size.h]))
            for size in self.cellH.state_size)
        self.input_state = [self.x, self.cellH_state]

    def _build_initial(self):
        x = self.x  # weights of optimizee
        cellH_state = self.cellH.zero_state(tf.size(x), tf.float32)
        self.initial_state = [x, cellH_state]

    # return state, fx
    def iter(self, f, i, state):
        x, cellH_state = state

        fx, grad = self._get_fx(f, i, x)
        self.optimizee_grad.append(grad)
        grad = tf.stop_gradient(grad)

        last = self._deepmind_log_encode(grad)

        with tf.variable_scope("cellH"):
            last, cellH_state = self.cellH(last, cellH_state)

        with tf.variable_scope("fc"):
            last = self.fc(last, 1)

        delta_x = last[:, 0] * self.lr

        x += delta_x
        return [x, cellH_state], fx
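
# --- `_deepmind_log_encode` is not shown in this snippet. A plausible
# implementation (an assumption, following the log-and-sign gradient
# preprocessing of Andrychowicz et al., "Learning to learn by gradient
# descent by gradient descent", Appendix A):
import math

def deepmind_log_encode(grad, p=10.0):
    threshold = math.exp(-p)
    is_large = tf.abs(grad) >= threshold
    log_part = tf.where(is_large,
                        tf.log(tf.abs(grad) + 1e-20) / p,
                        -tf.ones_like(grad))
    sign_part = tf.where(is_large, tf.sign(grad), math.exp(p) * grad)
    return tf.stack([log_part, sign_part], axis=1)  # [n, 2] input features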
Example #6
class LSTMOptModel(nn_opt.BasicNNOptModel):
    def _build_pre(self):
        self.dimH = 20
        self.cellH = MultiRNNCell(
            [LSTMCell(self.dimH) for _ in range(2)])  # one cell object per layer
        self.lr = 0.1

    def _build_input(self):
        self.x = self.ph([None])
        self.cellH_state = tuple(
            (self.ph([None, size.c]), self.ph([None, size.h]))
            for size in self.cellH.state_size)
        self.input_state = [self.x, self.cellH_state]

    def _build_initial(self):
        x = self.x
        cellH_state = self.cellH.zero_state(tf.size(x), tf.float32)
        self.initial_state = [x, cellH_state]

    # return state, fx
    def _iter(self, f, i, state):
        x, cellH_state = state

        fx, grad = self._get_fx(f, i, x)
        grad = tf.stop_gradient(grad)

        last = self._deepmind_log_encode(grad)

        with tf.variable_scope("cellH"):
            last, cellH_state = self.cellH(last, cellH_state)

        with tf.variable_scope("fc"):
            last = self.fc(last, 1)

        delta_x = last[:, 0] * self.lr

        x += delta_x
        return [x, cellH_state], fx
Example #7
# targets is a list of length sequence_length
# each element of targets is a 1D vector of length batch_size

# ------------------
# YOUR COMPUTATION GRAPH HERE

# create two recurrent cells (using the custom mygru cell)
gru0 = mygru(state_dim)
gru1 = mygru(state_dim)

# use it to create a MultiRNNCell
my_rnn = MultiRNNCell([gru0, gru1], state_is_tuple=True)

# use it to create an initial_state
# note that initial_state will be a *tuple* of state tensors (one per layer)!
initial_state = my_rnn.zero_state(batch_size, tf.float32)

# call seq2seq.rnn_decoder
with tf.variable_scope("encoder") as scope:
    outputs, final_state = tf.contrib.legacy_seq2seq.rnn_decoder(
        inputs, initial_state, my_rnn)

W = tf.Variable(tf.random_normal([state_dim, vocab_size], stddev=0.02))
b = tf.Variable(tf.random_normal([vocab_size], stddev=0.01))

# transform the list of state outputs to a list of logits.
# use a linear transformation.
logits = [tf.matmul(output, W) + b for output in outputs]  # b broadcasts over the batch

# call seq2seq.sequence_loss
loss_w = [1.0 for i in range(sequence_length)]
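
# The snippet stops before the loss; mirroring the other examples in this
# collection, a plausible continuation is:
loss = tf.contrib.legacy_seq2seq.sequence_loss(logits, targets, loss_w)
optim = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)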
Example #8
            initializer=tf.contrib.layers.variance_scaling_initializer())

        # GRU gates are sigmoids; tf.sign would saturate and zero out gradients
        r_t = tf.sigmoid(tf.matmul(inputs, W_xr) + tf.matmul(state, W_hr) + b_r)
        z_t = tf.sigmoid(tf.matmul(inputs, W_xz) + tf.matmul(state, W_hz) + b_z)
        h_hat_t = tf.tanh(
            tf.matmul(inputs, W_xh) + tf.matmul(r_t * state, W_hh) + b_h)

        h_t = z_t * state + (1 - z_t) * h_hat_t

        return h_t, h_t


multi_cell = MultiRNNCell(
    [BasicLSTMCell(state_dim) for i in range(num_layers)])
#multi_cell = MultiRNNCell([mygru(state_dim) for i in range(num_layers)])
initial_state = multi_cell.zero_state(batch_size, dtype=tf.float32)

# call seq2seq.rnn_decoder
outputs, final_state = rnn_decoder(inputs, initial_state, multi_cell)

# transform the list of state outputs to a list of logits.
# use a linear transformation.
weights = tf.get_variable(
    name="W",
    shape=[state_dim, vocab_size],
    initializer=tf.contrib.layers.variance_scaling_initializer())
bias = tf.get_variable(
    name="b",
    shape=[vocab_size],
    initializer=tf.contrib.layers.variance_scaling_initializer())
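
# The snippet ends after the output projection; a plausible continuation,
# following the pattern of the other examples in this collection (assumes
# `targets`, `batch_size`, and `sequence_length` are defined earlier):
logits = [tf.matmul(output, weights) + bias for output in outputs]
loss_weights = [tf.ones([batch_size]) for _ in range(sequence_length)]
loss = tf.contrib.legacy_seq2seq.sequence_loss(logits, targets, loss_weights)
optim = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)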
Example #9
inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
targets = tf.split(targ_ph, sequence_length, axis=1)

# at this point, inputs is a list of length sequence_length
# each element of inputs is [batch_size,vocab_size]

# targets is a list of length sequence_length
# each element of targets is a 1D vector of length batch_size

with tf.name_scope("rnn_base") as scope:
    # lstm1 = BasicLSTMCell(state_dim)
    # lstm2 = BasicLSTMCell(state_dim)
    lstm1 = mygru(state_dim)
    lstm2 = mygru(state_dim)
    rnn = MultiRNNCell([lstm1, lstm2])
    initial_state = rnn.zero_state(batch_size, tf.float32)

with tf.name_scope("decoder") as scope:
    outputs, final_state = rnn_decoder(inputs, initial_state, rnn)
    W = tf.Variable(tf.random_normal([state_dim, vocab_size], stddev=0.02))
    b = tf.Variable(tf.random_normal([vocab_size], stddev=0.01))
    logits = [tf.matmul(output, W) + b for output in outputs]  # b broadcasts over the batch
    loss_w = [1.0 for i in range(sequence_length)]
    loss = sequence_loss(logits=logits, targets=targets, weights=loss_w)
    optim = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

with tf.name_scope("sampler") as scope:
    s_in_ph = tf.placeholder(tf.int32, [1], name='s_inputs')
    s_in_onehot = tf.one_hot(s_in_ph, vocab_size, name="s_input_onehot")
    s_inputs = tf.split(s_in_onehot, 1, axis=1)
    s_initial_state = rnn.zero_state(1, tf.float32)
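    # A plausible completion of the sampler (mirrors Example #1). It assumes the
    # cells and W/b defined above are reused for this single-step graph; older
    # cell implementations may additionally need a reuse-enabled variable scope.
    s_outputs, s_final_state = rnn_decoder(s_inputs, s_initial_state, rnn)
    s_probs = tf.nn.softmax(
        tf.matmul(tf.reshape(s_outputs, [1, state_dim]), W) + b)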
Example #10
def construct_model(images,
                    actions=None,
                    states=None,
                    iter_num=-1.0,
                    k=-1,
                    use_state=True,
                    num_masks=10,
                    stp=False,
                    cdna=True,
                    dna=False,
                    context_frames=2,
                    pix_distributions=None,
                    conf=None):
    """Build convolutional lstm video predictor using STP, CDNA, or DNA.

    Args:
      images: tensor of ground truth image sequences
      actions: tensor of action sequences
      states: tensor of ground truth state sequences
      iter_num: tensor of the current training iteration (for sched. sampling)
      k: constant used for scheduled sampling. -1 to feed in own prediction.
      use_state: True to include state and action in prediction
      num_masks: the number of different pixel motion predictions (and
                 the number of masks for each of those predictions)
      stp: True to use Spatial Transformer Predictor (STP)
      cdna: True to use Convolutional Dynamic Neural Advection (CDNA)
      dna: True to use Dynamic Neural Advection (DNA)
      context_frames: number of ground truth frames to pass in before
                      feeding in own predictions
      pix_distributions: the initial one-hot distribution for designated pixels
    Returns:
      gen_images: predicted future image frames
      gen_states: predicted future states

    Raises:
      ValueError: if more than one network option specified or more than 1 mask
      specified for DNA model.
    """

    DNA_KERN_SIZE = conf.get('dna_size', 5)

    print('constructing network with fewer layers...')

    if stp + cdna + dna != 1:
        raise ValueError('More than one, or no network option specified.')
    batch_size, img_height, img_width, color_channels = images[0].get_shape(
    )[0:4]
    batch_size = int(batch_size)
    lstm_func = basic_conv_lstm_cell

    # Generated robot states and images.
    gen_states, gen_images, gen_masks, inf_low_state, pred_low_state = [], [], [], [], []
    current_state = states[0]
    gen_pix_distrib = []

    summaries = []

    if k == -1:
        feedself = True
    else:
        # Scheduled sampling:
        # Calculate number of ground-truth frames to pass in.
        num_ground_truth = tf.to_int32(
            tf.round(
                tf.to_float(batch_size) * (k / (k + tf.exp(iter_num / k)))))
        feedself = False

    # LSTM state sizes and states.
    lstm_size = np.int32(np.array([16, 32, 64, 100, 10]))
    lstm_state1, lstm_state2, lstm_state3 = None, None, None

    single_lstm1 = BasicLSTMCell(lstm_size[3], state_is_tuple=True)
    single_lstm2 = BasicLSTMCell(lstm_size[4], state_is_tuple=True)
    low_dim_lstm = MultiRNNCell([single_lstm1, single_lstm2],
                                state_is_tuple=True)

    low_dim_lstm_state = low_dim_lstm.zero_state(batch_size, tf.float32)

    dim_low_state = int(lstm_size[-1])

    t = -1
    for image, action in zip(images[:-1], actions[:-1]):
        t += 1
        print('building timestep', t)
        # Reuse variables after the first timestep.
        reuse = bool(gen_images)

        done_warm_start = len(gen_images) > context_frames - 1
        with slim.arg_scope([
                lstm_func, slim.layers.conv2d, slim.layers.fully_connected,
                tf_layers.layer_norm, slim.layers.conv2d_transpose
        ],
                            reuse=reuse):

            if feedself and done_warm_start:
                # Feed in generated image.
                prev_image = gen_images[-1]
                if pix_distributions is not None:
                    prev_pix_distrib = gen_pix_distrib[-1]
            elif done_warm_start:
                # Scheduled sampling
                prev_image = scheduled_sample(image, gen_images[-1],
                                              batch_size, num_ground_truth)
            else:
                # Always feed in ground_truth
                prev_image = image
                if pix_distributions is not None:
                    prev_pix_distrib = pix_distributions[t]
                    prev_pix_distrib = tf.expand_dims(prev_pix_distrib, -1)

            # Predicted state is always fed back in
            state_action = tf.concat(1, [action, current_state])  # 6x

            enc0 = slim.layers.conv2d(  #32x32x32
                prev_image,
                32,
                kernel_size=[5, 5],
                stride=2,
                scope='scale1_conv1',
                normalizer_fn=tf_layers.layer_norm,
                normalizer_params={'scope': 'layer_norm1'})

            hidden1, lstm_state1 = lstm_func(  #32x32
                enc0, lstm_state1, lstm_size[0], scope='state1')
            hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2')

            enc1 = slim.layers.conv2d(  #16x16
                hidden1,
                hidden1.get_shape()[3], [3, 3],
                stride=2,
                scope='conv2')

            hidden2, lstm_state2 = lstm_func(  #16x16x32
                enc1, lstm_state2, lstm_size[1], scope='state3')
            hidden2 = tf_layers.layer_norm(hidden2, scope='layer_norm4')

            enc2 = slim.layers.conv2d(  #8x8x32
                hidden2,
                hidden2.get_shape()[3], [3, 3],
                stride=2,
                scope='conv3')

            # Pass in state and action.
            smear = tf.reshape(
                state_action,
                [batch_size, 1, 1,
                 int(state_action.get_shape()[1])])
            smear = tf.tile(  #8x8x6
                smear,
                [1, int(enc2.get_shape()[1]),
                 int(enc2.get_shape()[2]), 1])
            if use_state:
                enc2 = tf.concat(3, [enc2, smear])
            enc3 = slim.layers.conv2d(  #8x8x32
                enc2,
                hidden2.get_shape()[3], [1, 1],
                stride=1,
                scope='conv4')

            hidden3, lstm_state3 = lstm_func(  #8x8x64
                enc3, lstm_state3, lstm_size[2], scope='state5')  # last 8x8
            hidden3 = tf_layers.layer_norm(hidden3, scope='layer_norm6')

            enc3 = slim.layers.conv2d(  # 8x8x32
                hidden3, 16, [1, 1], stride=1, scope='conv5')

            enc3_flat = tf.reshape(enc3, [batch_size, -1])

            if 'use_low_dim_lstm' in conf:
                with tf.variable_scope('low_dim_lstm', reuse=reuse):
                    hidden4, low_dim_lstm_state = low_dim_lstm(
                        enc3_flat, low_dim_lstm_state)
                low_dim_state = hidden4
            else:
                enc_fully1 = slim.layers.fully_connected(enc3_flat,
                                                         400,
                                                         scope='enc_fully1')

                enc_fully2 = slim.layers.fully_connected(enc_fully1,
                                                         100,
                                                         scope='enc_fully2')

                low_dim_state = enc_fully2

            # inferred low dimensional state:
            inf_low_state.append(low_dim_state)

            pred_low_state.append(project_fwd_lowdim(low_dim_state))

            smear = tf.reshape(low_dim_state,
                               [batch_size, 1, 1, dim_low_state])
            smear = tf.tile(  # 8x8xdim_hidden_state
                smear,
                [1, int(enc2.get_shape()[1]),
                 int(enc2.get_shape()[2]), 1])

            enc4 = slim.layers.conv2d_transpose(  #16x16x32
                smear,
                hidden3.get_shape()[3],
                3,
                stride=2,
                scope='convt1')

            enc5 = slim.layers.conv2d_transpose(  #32x32x32
                enc4,
                enc0.get_shape()[3],
                3,
                stride=2,
                scope='convt2')

            enc6 = slim.layers.conv2d_transpose(  #64x64x16
                enc5,
                16,
                3,
                stride=2,
                scope='convt3',
                normalizer_fn=tf_layers.layer_norm,
                normalizer_params={'scope': 'layer_norm9'})

            # Using largest hidden state for predicting untied conv kernels.
            enc7 = slim.layers.conv2d_transpose(enc6,
                                                DNA_KERN_SIZE**2,
                                                1,
                                                stride=1,
                                                scope='convt4')

            # Only one mask is supported (more should be unnecessary).
            if num_masks != 1:
                raise ValueError('Only one mask is supported for DNA model.')
            transformed = [dna_transformation(prev_image, enc7, DNA_KERN_SIZE)]

            if 'use_masks' in conf:
                masks = slim.layers.conv2d_transpose(enc6,
                                                     num_masks + 1,
                                                     1,
                                                     stride=1,
                                                     scope='convt7')
                masks = tf.reshape(
                    tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])), [
                        int(batch_size),
                        int(img_height),
                        int(img_width), num_masks + 1
                    ])
                mask_list = tf.split(3, num_masks + 1, masks)
                output = mask_list[0] * prev_image
                for layer, mask in zip(transformed, mask_list[1:]):
                    output += layer * mask
            else:
                mask_list = None
                output = transformed[0]  # single DNA-transformed image (no masks)

            gen_images.append(output)
            gen_masks.append(mask_list)

            if dna and pix_distributions is not None:
                transf_distrib = [
                    dna_transformation(prev_pix_distrib, enc7, DNA_KERN_SIZE)
                ]

            if pix_distributions is not None:
                pix_distrib_output = mask_list[0] * prev_pix_distrib
                mult_list = []
                for i in range(num_masks):
                    mult_list.append(transf_distrib[i] * mask_list[i + 1])
                    pix_distrib_output += mult_list[i]

                gen_pix_distrib.append(pix_distrib_output)

            # pred_low_state_stopped = tf.stop_gradient(pred_low_state)

            state_enc1 = slim.layers.fully_connected(
                # pred_low_state[-1],
                low_dim_state,
                100,
                scope='state_enc1')

            state_enc2 = slim.layers.fully_connected(
                state_enc1,
                # int(current_state.get_shape()[1]),
                4,
                scope='state_enc2',
                activation_fn=None)
            current_state = tf.squeeze(state_enc2)
            gen_states.append(current_state)

    if pix_distributions is not None:
        return gen_images, gen_states, gen_masks, gen_pix_distrib, inf_low_state, pred_low_state
    else:
        return gen_images, gen_states, gen_masks, None, inf_low_state, pred_low_state
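
# --- `scheduled_sample` is referenced above but not defined in this snippet.
# The sketch below is consistent with how it is called (mix `num_ground_truth`
# ground-truth examples with generated ones within a batch); treat it as an
# assumption rather than the original implementation.
def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
    idx = tf.random_shuffle(tf.range(int(batch_size)))
    ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth))
    generated_idx = tf.gather(idx, tf.range(num_ground_truth, int(batch_size)))
    ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx)
    generated_examps = tf.gather(generated_x, generated_idx)
    return tf.dynamic_stitch([ground_truth_idx, generated_idx],
                             [ground_truth_examps, generated_examps])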
Example #11
class RNNpropModel(nn_opt.BasicNNOptModel):
    def _build_pre(self, size):
        self.dimA = size
        self.num_of_layers = 2
        self.cellA = MultiRNNCell([LSTMCell(num_units=self.dimA) for _ in range(self.num_of_layers)])
        self.b1 = 0.95
        self.b2 = 0.95
        self.lr = 0.1
        self.eps = 1e-8

    def _build_input(self):
        self.x = self.ph([None])
        self.m = self.ph([None])
        self.v = self.ph([None])
        self.b1t = self.ph([])
        self.b2t = self.ph([])
        self.sid = self.ph([])
        self.cellA_state = tuple((self.ph([None, size.c]), self.ph([None, size.h])) for size in self.cellA.state_size)
        self.input_state = [self.x, self.sid, self.b1t, self.b2t, self.m, self.v, self.cellA_state]

    def _build_initial(self):
        x = self.x
        m = tf.zeros(shape=tf.shape(x))
        v = tf.zeros(shape=tf.shape(x))
        b1t = tf.ones([])
        b2t = tf.ones([])
        cellA_state = self.cellA.zero_state(tf.size(x), tf.float32)
        self.initial_state = [x, tf.zeros([]), b1t, b2t, m, v, cellA_state]

    # return state, fx
    def iter(self, f, i, state):
        x, sid, b1t, b2t, m, v, cellA_state = state

        fx, grad = self._get_fx(f, i, x)
        # self.optimizee_grad.append(grad)
        grad = tf.stop_gradient(grad)

        m = self.b1 * m + (1 - self.b1) * grad
        v = self.b2 * v + (1 - self.b2) * (grad ** 2)

        b1t *= self.b1
        b2t *= self.b2

        sv = tf.sqrt(v / (1 - b2t)) + self.eps
        # TODO
        last = tf.stack([grad / sv, (m / (1 - b1t)) / sv], 1)
        # last = tf.stack([grad], 1)
        # last = self._deepmind_log_encode(grad, p=10)
        # last = tf.stack([grad/tf.norm(grad, ord=2)], 1)
        last = tf.nn.elu(self.fc(last, 20))

        with tf.variable_scope("cellA"):
            lastA, cellA_state = self.cellA(last, cellA_state)

        with tf.variable_scope("fc_A"):
            a = self.fc(lastA, 1, use_bias=True)[:, 0]

        a = tf.tanh(a) * self.lr

        x -= a

        return [x, sid + 1, b1t, b2t, m, v, cellA_state], fx