def __get_network__(self, encode_seq, decode_seq, is_train=True, reuse=False):
    """Build the TF1/TensorLayer graph: a Seq2Seq traffic model plus a query-RNN head.

    Parameters
    ----------
    encode_seq : tensor/placeholder
        Encoder input sequence (fed to the Seq2Seq encoder).
    decode_seq : tensor/placeholder
        Decoder input sequence (fed to the Seq2Seq decoder).
    is_train : bool
        Selects the training output shape ``(batch, out_seq_length + 1, 1)``
        versus the single-step inference shape ``(batch, 1, 1)``, and selects
        whether the query head concatenates the live seq2seq hidden states
        (train) or the externally-fed ``self.traffic_state`` (inference).
    reuse : bool
        Forwarded to ``tf.variable_scope`` and ``tl.layers.set_name_reuse`` so
        train and inference graphs can share weights.

    Returns
    -------
    (net_rnn_seq2seq, net_out_seq2seq, net_rnn_query, net_out)
        Seq2Seq RNN body, its 1-unit dense output, the query RNN body, and the
        combined traffic+query 1-unit output.

    Also reads instance state: ``self.query_x``, ``self.query_decode_seq``,
    ``self.traffic_state``, ``self.model_name``.
    """
    # NOTE(review): w_init / g_init are created but never used in this variant
    # (a sibling variant uses them for Conv1d / BatchNorm); confirm before removing.
    w_init = tf.random_normal_initializer(stddev=0.02)
    g_init = tf.random_normal_initializer(1., 0.02)

    # --- seq2seq traffic-prediction branch -------------------------------
    with tf.variable_scope("seq2seq_model", reuse=reuse) as vs:
        tl.layers.set_name_reuse(reuse)
        net_encode = InputLayer(encode_seq, name='in_root')
        net_decode = InputLayer(decode_seq, name="decode")
        net_rnn = Seq2Seq(
            net_encode,
            net_decode,
            cell_fn=tf.contrib.rnn.BasicLSTMCell,
            n_hidden=config.dim_hidden,
            initializer=tf.random_uniform_initializer(-0.1, 0.1),
            # Actual (padded) sequence lengths are recovered from the inputs.
            encode_sequence_length=tl.layers.retrieve_seq_length_op(
                net_encode.outputs),
            decode_sequence_length=tl.layers.retrieve_seq_length_op(
                net_decode.outputs),
            initial_state_encode=None,
            # dropout=(0.8 if is_train else None),
            dropout=None,
            n_layer=1,
            return_seq_2d=True,
            name='seq2seq')
        # self.net_rnn_seq2seq = net_rnn
        net_rnn_seq2seq = net_rnn
        # Project each RNN step to a scalar prediction.
        net_out_seq2seq = DenseLayer(net_rnn, n_units=1,
                                     act=tf.identity, name='dense2')
        if is_train:
            net_out_seq2seq = ReshapeLayer(
                net_out_seq2seq,
                (config.batch_size, config.out_seq_length + 1, 1),
                name="reshape_out")
        else:
            net_out_seq2seq = ReshapeLayer(net_out_seq2seq,
                                           (config.batch_size, 1, 1),
                                           name="reshape_out")
        # net_out_seq2seq = net_out_seq2seq
        # net_out = DenseLayer(net_rnn, n_units=64, act=tf.identity, name='dense1')
        # net_out = DenseLayer(net_rnn, n_units=1, act=tf.identity, name='dense2')
        # net_out = ReshapeLayer(net_out, (config.batch_size, config.out_seq_length + 1, 1), name="reshape_out")

    # --- query branch: encode the query decode sequence with its own RNN --
    with tf.variable_scope(self.model_name, reuse=reuse) as vs:
        tl.layers.set_name_reuse(reuse)
        net_encode_query = InputLayer(self.query_x, name='in_root_query')
        net_decode_query = InputLayer(self.query_decode_seq, name="decode_query")
        net_rnn_query = RNNLayer(
            net_decode_query,
            cell_fn=tf.contrib.rnn.BasicLSTMCell,
            cell_init_args={"forget_bias": 1.0},
            n_hidden=config.query_dim_hidden,
            initializer=tf.random_uniform_initializer(-0.1, 0.1),
            n_steps=config.out_seq_length,
            return_last=True,
            # return_last=False,
            # return_seq_2d=True,
            name="rnn_query")
        # Broadcast the final query state across all output time steps:
        # (batch, hidden) -> (batch, 1, hidden) -> (batch, T, hidden)
        # -> (batch*T, hidden) so it can be concatenated per-step below.
        net_rnn_query = ExpandDimsLayer(net_rnn_query, axis=1,
                                        name="rnn_query_expand")
        net_rnn_query = TileLayer(net_rnn_query,
                                  [1, config.out_seq_length, 1],
                                  name="rnn_query_tile")
        net_rnn_query = ReshapeLayer(
            net_rnn_query,
            (config.batch_size * config.out_seq_length,
             config.query_dim_hidden),
            name="rnn_query_reshape")
        net_traffic_state = InputLayer(self.traffic_state,
                                       name="in_traffic_state")
        if is_train:
            # Drop the final (T+1-th) seq2seq step, keeping T steps to align
            # with the tiled query features.
            net_rnn_traffic = ReshapeLayer(
                net_rnn_seq2seq,
                (config.batch_size, config.out_seq_length + 1,
                 config.dim_hidden),
                name="reshape_traffic_q1")
            net_rnn_traffic.outputs = tf.slice(
                net_rnn_traffic.outputs, [0, 0, 0], [
                    config.batch_size, config.out_seq_length,
                    config.dim_hidden
                ],
                name="slice_traffic_q")
            net_rnn_traffic = ReshapeLayer(
                net_rnn_traffic,
                (config.batch_size * config.out_seq_length,
                 config.dim_hidden),
                name="reshape_traffic_q2")
            net_out = ConcatLayer([net_rnn_traffic, net_rnn_query],
                                  concat_dim=-1,
                                  name="concat_traffic_query1")
        else:
            # Inference: traffic state is fed in externally instead of taken
            # from the live seq2seq hidden states.
            net_out = ConcatLayer([net_traffic_state, net_rnn_query],
                                  concat_dim=-1,
                                  name="concat_traffic_query2")
        # net_out = DenseLayer(net_out, n_units=128, act=tf.nn.relu, name="dense_query1")
        # net_out = DenseLayer(net_out, n_units=32, act=tf.nn.relu, name="dense_query2")
        net_out = DenseLayer(net_out, n_units=1, act=tf.identity,
                             name="dense_query3")
        # net_out = ReshapeLayer(net_out, (config.batch_size, config.out_seq_length + 1, 1), name="reshape_out")
        # if is_train:
        net_out = ReshapeLayer(
            net_out, (config.batch_size, config.out_seq_length, 1),
            name="reshape_out")
        # else:
        #     net_out = ReshapeLayer(net_out, (config.batch_size, 1, 1), name="reshape_out")
    return net_rnn_seq2seq, net_out_seq2seq, net_rnn_query, net_out
def __get_network__(self, encode_seq, neighbour_seq, decode_seq, features,
                    features_full, is_train=True, reuse=False):
    """Build the three-branch TF1/TensorLayer graph: spatial seq2seq, wide
    (feature) head, and query head with a residual add over a base prediction.

    Parameters
    ----------
    encode_seq : tensor/placeholder
        Root-node encoder input sequence.
    neighbour_seq : tensor/placeholder
        Neighbour-node sequences, Conv1d-encoded and concatenated onto the
        encoder input.
    decode_seq : tensor/placeholder
        Decoder input for the Seq2Seq branch.
    features : tensor/placeholder
        Per-step auxiliary features for the wide branch
        (``out_seq_length + 1`` steps when training, 1 step otherwise).
    features_full : tensor/placeholder
        Full-horizon features, always ``out_seq_length + 1`` steps; used by
        both the wide branch and the query branch.
    is_train : bool
        Switches output shapes between full-horizon and single-step.
    reuse : bool
        Forwarded to ``tf.variable_scope`` / ``tl.layers.set_name_reuse``.

    Returns
    -------
    (net_rnn_seq2seq, net_spatial_out, net_wide_out, net_rnn_query, net_query_out)

    Side effects: sets ``self.net_features_dim``. Also reads
    ``self.model_name``, ``self.query_decode_seq``, ``self.traffic_state``,
    ``self.base_pred``.
    """
    w_init = tf.random_normal_initializer(stddev=0.02)
    g_init = tf.random_normal_initializer(1., 0.02)

    # --- spatial branch: Conv1d-encode neighbours, then seq2seq ----------
    with tf.variable_scope(self.model_name + "_spatial", reuse=reuse) as vs:
        tl.layers.set_name_reuse(reuse)
        inputs_x_root = InputLayer(encode_seq, name='in_root')
        inputs_x_nbor = InputLayer(neighbour_seq, name="in_neighbour")
        # encoding neighbour graph information
        # Flatten (batch, in_seq) into rows so each time step's neighbour
        # vector is convolved independently.
        n = ReshapeLayer(inputs_x_nbor,
                         (config.batch_size * config.in_seq_length,
                          config.num_neighbour),
                         "reshape1")
        n.outputs = tf.expand_dims(n.outputs, axis=-1)
        # NOTE(review): per TL's Conv1d arg order this is 4 filters, width-4
        # kernel, stride 1 — confirm against the tensorlayer version in use.
        n = Conv1d(n, 4, 4, 1, act=tf.identity, padding='SAME',
                   W_init=w_init, name='conv1')
        n = BatchNormLayer(n, act=tf.nn.relu, is_train=is_train,
                           gamma_init=g_init, name='bn1')
        n = MaxPool1d(n, 2, 2, padding='valid', name='maxpool1')
        n = FlattenLayer(n, name="flatten1")
        n = ReshapeLayer(n, (config.batch_size, config.in_seq_length, -1),
                         name="reshape1_back")
        # Root signal + encoded neighbour features form the encoder input.
        net_encode = ConcatLayer([inputs_x_root, n], concat_dim=-1,
                                 name="encode")
        net_decode = InputLayer(decode_seq, name="decode")
        net_rnn = Seq2Seq(
            net_encode,
            net_decode,
            cell_fn=tf.contrib.rnn.BasicLSTMCell,
            n_hidden=config.dim_hidden,
            initializer=tf.random_uniform_initializer(-0.1, 0.1),
            encode_sequence_length=tl.layers.retrieve_seq_length_op(
                net_encode.outputs),
            decode_sequence_length=tl.layers.retrieve_seq_length_op(
                net_decode.outputs),
            initial_state_encode=None,
            # dropout=(0.8 if is_train else None),
            dropout=None,
            n_layer=1,
            return_seq_2d=True,
            name='seq2seq')
        net_rnn_seq2seq = net_rnn
        net_spatial_out = DenseLayer(net_rnn, n_units=1, act=tf.identity,
                                     name='dense2')
        if is_train:
            net_spatial_out = ReshapeLayer(
                net_spatial_out,
                (config.batch_size, config.out_seq_length + 1, 1),
                name="reshape_out")
        else:
            net_spatial_out = ReshapeLayer(net_spatial_out,
                                           (config.batch_size, 1, 1),
                                           name="reshape_out")

    # --- wide branch: dense-encoded features concatenated with RNN states -
    with tf.variable_scope(self.model_name + "_wide", reuse=reuse) as vs:
        tl.layers.set_name_reuse(reuse)
        # Features
        net_features = InputLayer(features, name="in_features")
        net_features_full = InputLayer(features_full, name="in_features_full")
        net_features_full = ReshapeLayer(
            net_features_full,
            (config.batch_size * (config.out_seq_length + 1),
             config.dim_features),
            name="reshape_feature_full_1")
        if is_train:
            net_features = ReshapeLayer(
                net_features,
                (config.batch_size * (config.out_seq_length + 1),
                 config.dim_features),
                name="reshape_feature_1")
        else:
            net_features = ReshapeLayer(net_features,
                                        (config.batch_size * (1),
                                         config.dim_features),
                                        name="reshape_feature_1")
        # Dimensionality of the dense feature embedding; the query branch
        # below reads this back via self.net_features_dim.
        self.net_features_dim = 32
        net_features = DenseLayer(net_features,
                                  n_units=self.net_features_dim,
                                  act=tf.nn.relu, name='dense_features')
        net_features_full = DenseLayer(net_features_full,
                                       n_units=self.net_features_dim,
                                       act=tf.nn.relu,
                                       name='dense_features_full')
        # self.net_features = net_features
        net_wide_out = ConcatLayer([net_rnn_seq2seq, net_features],
                                   concat_dim=-1, name="concat_features")
        net_wide_out = DenseLayer(net_wide_out, n_units=1, act=tf.identity,
                                  name='dense2')
        if is_train:
            net_wide_out = ReshapeLayer(
                net_wide_out,
                (config.batch_size, config.out_seq_length + 1, 1),
                name="reshape_out")
        else:
            net_wide_out = ReshapeLayer(net_wide_out,
                                        (config.batch_size, 1, 1),
                                        name="reshape_out")

    # --- query branch: query RNN + features + traffic state, residual add -
    with tf.variable_scope(self.model_name + "_query", reuse=reuse) as vs:
        tl.layers.set_name_reuse(reuse)
        net_decode_query = InputLayer(self.query_decode_seq,
                                      name="decode_query")
        net_rnn_query = RNNLayer(
            net_decode_query,
            cell_fn=tf.contrib.rnn.BasicLSTMCell,
            cell_init_args={"forget_bias": 1.0},
            n_hidden=config.query_dim_hidden,
            initializer=tf.random_uniform_initializer(-0.1, 0.1),
            n_steps=config.out_seq_length,
            return_last=True,
            # return_last=False,
            # return_seq_2d=True,
            name="rnn_query")
        # Dead code kept for reference: dynamic-length variant of the query RNN.
        '''
        net_rnn_query = DynamicRNNLayer(
            net_decode_query,
            cell_fn=tf.contrib.rnn.BasicLSTMCell,
            cell_init_args={"forget_bias": 1.0},
            # n_hidden=config.query_dim_hidden,
            n_hidden=32,
            initializer=tf.random_uniform_initializer(-0.1, 0.1),
            return_last=True,
            # dropout=0.8,
            sequence_length=tl.layers.retrieve_seq_length_op(net_decode_query.outputs),
            # return_last=False,
            # return_seq_2d=True,
            name="rnn_query_dynamic"
        )
        '''
        # Broadcast the final query state across all output steps
        # (batch, hidden) -> (batch*T, hidden), matching the per-step rows.
        net_rnn_query = ExpandDimsLayer(net_rnn_query, axis=1,
                                        name="rnn_query_expand")
        net_rnn_query = TileLayer(net_rnn_query,
                                  [1, config.out_seq_length, 1],
                                  name="rnn_query_tile")
        net_rnn_query = ReshapeLayer(
            net_rnn_query,
            (config.batch_size * config.out_seq_length,
             config.query_dim_hidden),
            name="rnn_query_reshape")
        # net_rnn_query = ReshapeLayer(net_rnn_query, (config.batch_size * config.out_seq_length, 32), name="rnn_query_reshape")
        # self.net_rnn_query = net_rnn_query
        net_traffic_state = InputLayer(self.traffic_state,
                                       name="in_traffic_state")
        # Dead code kept for reference: train-time variant that sliced the
        # live seq2seq hidden states instead of using self.traffic_state.
        '''
        if is_train:
            net_rnn_traffic = ReshapeLayer(net_rnn_seq2seq, (config.batch_size, config.out_seq_length + 1, config.dim_hidden), name="reshape_traffic_q1")
            net_rnn_traffic.outputs = tf.slice(net_rnn_traffic.outputs, [0, 0, 0], [config.batch_size, config.out_seq_length, config.dim_hidden], name="slice_traffic_q")
            net_rnn_traffic = ReshapeLayer(net_rnn_traffic, (config.batch_size * config.out_seq_length, config.dim_hidden), name="reshape_traffic_q2")

            net_features_traffic = ReshapeLayer(net_features, (config.batch_size, config.out_seq_length + 1, self.net_features_dim), name="reshape_features_q1")
            net_features_traffic.outputs = tf.slice(net_features_traffic.outputs, [0, 0, 0], [config.batch_size, config.out_seq_length, self.net_features_dim], name="slice_features_q")
            net_features_traffic = ReshapeLayer(net_features_traffic, (config.batch_size * config.out_seq_length, self.net_features_dim), name="reshape_features_q2")

            net_query_out = ConcatLayer([net_rnn_traffic, net_features_traffic, net_rnn_query], concat_dim=-1, name="concat_traffic_query1")
            # net_query_out = ConcatLayer([net_rnn_traffic, net_rnn_query], concat_dim=-1, name="concat_traffic_query1")
        else:
        '''
        # Drop the final (T+1-th) feature step so rows align with the tiled
        # query features, then concatenate traffic state + features + query.
        net_features_traffic = ReshapeLayer(
            net_features_full,
            (config.batch_size, config.out_seq_length + 1,
             self.net_features_dim),
            name="reshape_features_q1")
        net_features_traffic.outputs = tf.slice(
            net_features_traffic.outputs, [0, 0, 0], [
                config.batch_size, config.out_seq_length,
                self.net_features_dim
            ],
            name="slice_features_q")
        net_features_traffic = ReshapeLayer(
            net_features_traffic,
            (config.batch_size * config.out_seq_length,
             self.net_features_dim),
            name="reshape_features_q2")
        net_query_out = ConcatLayer(
            [net_traffic_state, net_features_traffic, net_rnn_query],
            concat_dim=-1,
            name="concat_traffic_query1")
        # net_rnn_traffic = ReshapeLayer(net_rnn_seq2seq, (config.batch_size, config.out_seq_length + 1, config.dim_hidden), name="reshape_traffic_q1")
        # net_rnn_traffic.outputs = tf.slice(net_rnn_traffic.outputs, [0, 0, 0], [config.batch_size, config.out_seq_length, config.dim_hidden], name="slice_traffic_q")
        # net_rnn_traffic = ReshapeLayer(net_rnn_traffic, (config.batch_size * config.out_seq_length, config.dim_hidden), name="reshape_traffic_q2")
        # net_query_out = ConcatLayer([net_rnn_traffic, net_features_traffic, net_rnn_query], concat_dim=-1, name="concat_traffic_query1")
        # net_out = DenseLayer(net_out, n_units=128, act=tf.nn.relu, name="dense_query1")
        # net_out = DenseLayer(net_out, n_units=64, act=tf.nn.relu, name="dense_query2")
        # net_query_out = DropoutLayer(net_query_out, keep=0.8, is_fix=True, is_train=is_train, name='drop_query3')
        net_query_out = DenseLayer(net_query_out, n_units=1, act=tf.identity,
                                   name="dense_query3")
        # net_out = ReshapeLayer(net_out, (config.batch_size, config.out_seq_length + 1, 1), name="reshape_out")
        # if is_train:
        net_query_out = ReshapeLayer(
            net_query_out, (config.batch_size, config.out_seq_length, 1),
            name="reshape_out")
        # else:
        #     net_out = ReshapeLayer(net_out, (config.batch_size, 1, 1), name="reshape_out")

        # TODO residual net
        # Dead code kept for reference: train-time residual over net_wide_out.
        '''
        if is_train:
            net_query_out.outputs = tf.add(
                net_query_out.outputs,
                tf.slice(net_wide_out.outputs, [0, 0, 0],
                         [config.batch_size, config.out_seq_length, 1]),
                name="res_add"
            )
        else:
        '''
        # Residual connection: final prediction = query head + base prediction.
        net_base_pred = InputLayer(self.base_pred, name="in_net_base_pred")
        net_query_out.outputs = tf.add(net_query_out.outputs,
                                       net_base_pred.outputs,
                                       name="res_add")
    return net_rnn_seq2seq, net_spatial_out, net_wide_out, net_rnn_query, net_query_out
def __get_network__(self, encode_seq, neighbour_seq, decode_seq,
                    is_train=True, reuse=False):
    """Build the TF1/TensorLayer graph: Conv1d-encoded neighbour features
    concatenated with the root sequence, fed through a Seq2Seq LSTM, then
    projected to a scalar per step.

    Parameters
    ----------
    encode_seq : tensor/placeholder
        Root-node encoder input sequence.
    neighbour_seq : tensor/placeholder
        Neighbour-node sequences; Conv1d-encoded per time step.
    decode_seq : tensor/placeholder
        Decoder input for the Seq2Seq branch.
    is_train : bool
        Training output shape is ``(batch, out_seq_length + 1, 1)``;
        inference shape is ``(batch, 1, 1)``. Also toggles BatchNorm mode.
    reuse : bool
        Forwarded to ``tf.variable_scope`` / ``tl.layers.set_name_reuse``.

    Returns
    -------
    net_out : TensorLayer layer
        The reshaped scalar-per-step output.

    Side effects: stores the RNN body on ``self.net_rnn``.
    """
    w_init = tf.random_normal_initializer(stddev=0.02)
    g_init = tf.random_normal_initializer(1., 0.02)
    with tf.variable_scope(self.model_name, reuse=reuse) as vs:
        tl.layers.set_name_reuse(reuse)
        inputs_x_root = InputLayer(encode_seq, name='in_root')
        inputs_x_nbor = InputLayer(neighbour_seq, name="in_neighbour")
        # encoding neighbour graph information
        # Flatten (batch, in_seq) into rows so each time step's neighbour
        # vector is convolved independently.
        n = ReshapeLayer(inputs_x_nbor,
                         (config.batch_size * config.in_seq_length,
                          config.num_neighbour),
                         "reshape1")
        n.outputs = tf.expand_dims(n.outputs, axis=-1)
        # NOTE(review): per TL's Conv1d arg order this is 4 filters, width-4
        # kernel, stride 1 — confirm against the tensorlayer version in use.
        n = Conv1d(n, 4, 4, 1, act=tf.identity, padding='SAME',
                   W_init=w_init, name='conv1')
        n = BatchNormLayer(n, act=tf.nn.relu, is_train=is_train,
                           gamma_init=g_init, name='bn1')
        n = MaxPool1d(n, 2, 2, padding='valid', name='maxpool1')
        n = FlattenLayer(n, name="flatten1")
        n = ReshapeLayer(n, (config.batch_size, config.in_seq_length, -1),
                         name="reshape1_back")
        # Root signal + encoded neighbour features form the encoder input.
        net_encode = ConcatLayer([inputs_x_root, n], concat_dim=-1,
                                 name="encode")
        net_decode = InputLayer(decode_seq, name="decode")
        net_rnn = Seq2Seq(
            net_encode,
            net_decode,
            cell_fn=tf.contrib.rnn.BasicLSTMCell,
            n_hidden=config.dim_hidden,
            initializer=tf.random_uniform_initializer(-0.1, 0.1),
            # Actual (padded) sequence lengths are recovered from the inputs.
            encode_sequence_length=tl.layers.retrieve_seq_length_op(
                net_encode.outputs),
            decode_sequence_length=tl.layers.retrieve_seq_length_op(
                net_decode.outputs),
            initial_state_encode=None,
            # dropout=(0.8 if is_train else None),
            dropout=None,
            n_layer=1,
            return_seq_2d=True,
            name='seq2seq')
        # net_out = DenseLayer(net_rnn, n_units=64, act=tf.identity, name='dense1')
        # Project each RNN step to a scalar prediction.
        net_out = DenseLayer(net_rnn, n_units=1, act=tf.identity,
                             name='dense2')
        if is_train:
            net_out = ReshapeLayer(
                net_out, (config.batch_size, config.out_seq_length + 1, 1),
                name="reshape_out")
        else:
            net_out = ReshapeLayer(net_out, (config.batch_size, 1, 1),
                                   name="reshape_out")
    self.net_rnn = net_rnn
    return net_out