Example 1
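This first variant builds two sub-networks: a seq2seq encoder-decoder that forecasts the traffic sequence, and a query branch whose LSTM encoding is fused with the traffic representation before the final prediction. All three examples are TensorFlow 1.x / TensorLayer 1.x code and assume that tensorflow (as tf), tensorlayer (as tl), the tensorlayer.layers classes (InputLayer, Seq2Seq, DenseLayer, ...), and a project-level config module are already in scope.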
    def __get_network__(self,
                        encode_seq,
                        decode_seq,
                        is_train=True,
                        reuse=False):

        with tf.variable_scope("seq2seq_model", reuse=reuse) as vs:
            tl.layers.set_name_reuse(reuse)
            net_encode = InputLayer(encode_seq, name='in_root')

            net_decode = InputLayer(decode_seq, name="decode")

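            # Encoder-decoder LSTM over the traffic sequence. return_seq_2d=True
            # flattens the decoder outputs to (batch_size * n_steps, dim_hidden)
            # so a DenseLayer can be applied to every step at once.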
            net_rnn = Seq2Seq(
                net_encode,
                net_decode,
                cell_fn=tf.contrib.rnn.BasicLSTMCell,
                n_hidden=config.dim_hidden,
                initializer=tf.random_uniform_initializer(-0.1, 0.1),
                encode_sequence_length=tl.layers.retrieve_seq_length_op(
                    net_encode.outputs),
                decode_sequence_length=tl.layers.retrieve_seq_length_op(
                    net_decode.outputs),
                initial_state_encode=None,
                # dropout=(0.8 if is_train else None),
                dropout=None,
                n_layer=1,
                return_seq_2d=True,
                name='seq2seq')
            # self.net_rnn_seq2seq = net_rnn
            net_rnn_seq2seq = net_rnn

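            # Project each decoder step to a single scalar prediction.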
            net_out_seq2seq = DenseLayer(net_rnn,
                                         n_units=1,
                                         act=tf.identity,
                                         name='dense2')
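            # Training decodes the full horizon (out_seq_length + 1 steps, the
            # extra step presumably covering the decoder start token);
            # inference emits one step at a time.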
            if is_train:
                net_out_seq2seq = ReshapeLayer(
                    net_out_seq2seq,
                    (config.batch_size, config.out_seq_length + 1, 1),
                    name="reshape_out")
            else:
                net_out_seq2seq = ReshapeLayer(net_out_seq2seq,
                                               (config.batch_size, 1, 1),
                                               name="reshape_out")

            # net_out_seq2seq = net_out_seq2seq
            # net_out = DenseLayer(net_rnn, n_units=64, act=tf.identity, name='dense1')
            # net_out = DenseLayer(net_rnn, n_units=1, act=tf.identity, name='dense2')
            # net_out = ReshapeLayer(net_out, (config.batch_size, config.out_seq_length + 1, 1), name="reshape_out")

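        # Query branch: encode the query decode sequence with an LSTM and fuse
        # it with the traffic representation.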
        with tf.variable_scope(self.model_name, reuse=reuse) as vs:
            tl.layers.set_name_reuse(reuse)
            net_encode_query = InputLayer(self.query_x, name='in_root_query')

            net_decode_query = InputLayer(self.query_decode_seq,
                                          name="decode_query")

            net_rnn_query = RNNLayer(
                net_decode_query,
                cell_fn=tf.contrib.rnn.BasicLSTMCell,
                cell_init_args={"forget_bias": 1.0},
                n_hidden=config.query_dim_hidden,
                initializer=tf.random_uniform_initializer(-0.1, 0.1),
                n_steps=config.out_seq_length,
                return_last=True,
                # return_last=False,
                # return_seq_2d=True,
                name="rnn_query")
            net_rnn_query = ExpandDimsLayer(net_rnn_query,
                                            axis=1,
                                            name="rnn_query_expand")
            net_rnn_query = TileLayer(net_rnn_query,
                                      [1, config.out_seq_length, 1],
                                      name="rnn_query_tile")
            net_rnn_query = ReshapeLayer(
                net_rnn_query, (config.batch_size * config.out_seq_length,
                                config.query_dim_hidden),
                name="rnn_query_reshape")

            net_traffic_state = InputLayer(self.traffic_state,
                                           name="in_traffic_state")

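            # During training the query head consumes the seq2seq decoder
            # states (sliced to the first out_seq_length steps); at inference
            # it consumes the externally fed self.traffic_state instead.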
            if is_train:
                net_rnn_traffic = ReshapeLayer(
                    net_rnn_seq2seq,
                    (config.batch_size, config.out_seq_length + 1,
                     config.dim_hidden),
                    name="reshape_traffic_q1")
                net_rnn_traffic.outputs = tf.slice(
                    net_rnn_traffic.outputs, [0, 0, 0], [
                        config.batch_size, config.out_seq_length,
                        config.dim_hidden
                    ],
                    name="slice_traffic_q")
                net_rnn_traffic = ReshapeLayer(
                    net_rnn_traffic,
                    (config.batch_size * config.out_seq_length,
                     config.dim_hidden),
                    name="reshape_traffic_q2")
                net_out = ConcatLayer([net_rnn_traffic, net_rnn_query],
                                      concat_dim=-1,
                                      name="concat_traffic_query1")
            else:
                net_out = ConcatLayer([net_traffic_state, net_rnn_query],
                                      concat_dim=-1,
                                      name="concat_traffic_query2")

            # net_out = DenseLayer(net_out, n_units=128, act=tf.nn.relu, name="dense_query1")
            # net_out = DenseLayer(net_out, n_units=32, act=tf.nn.relu, name="dense_query2")
            net_out = DenseLayer(net_out,
                                 n_units=1,
                                 act=tf.identity,
                                 name="dense_query3")
            # net_out = ReshapeLayer(net_out, (config.batch_size, config.out_seq_length + 1, 1), name="reshape_out")
            # if is_train:
            net_out = ReshapeLayer(
                net_out, (config.batch_size, config.out_seq_length, 1),
                name="reshape_out")
            # else:
            #    net_out = ReshapeLayer(net_out, (config.batch_size, 1, 1), name="reshape_out")
        return net_rnn_seq2seq, net_out_seq2seq, net_rnn_query, net_out
Example 2
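Example 2 extends the spatial model with a "wide" feature branch and a query branch: exogenous features are embedded by dense layers and fused with the seq2seq decoder states for the wide prediction, while the query head concatenates the traffic state, feature embeddings, and query encoding, then adds its output to a base prediction as a residual.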
    def __get_network__(self,
                        encode_seq,
                        neighbour_seq,
                        decode_seq,
                        features,
                        features_full,
                        is_train=True,
                        reuse=False):
        w_init = tf.random_normal_initializer(stddev=0.02)
        g_init = tf.random_normal_initializer(1., 0.02)

        with tf.variable_scope(self.model_name + "_spatial",
                               reuse=reuse) as vs:
            tl.layers.set_name_reuse(reuse)
            inputs_x_root = InputLayer(encode_seq, name='in_root')
            inputs_x_nbor = InputLayer(neighbour_seq, name="in_neighbour")

            # Encode the neighbour-graph signal: Conv1d + BatchNorm + max-pool
            # over each step's neighbour readings, then flatten back to
            # (batch_size, in_seq_length, -1) so it can join the root sequence.
            n = ReshapeLayer(inputs_x_nbor,
                             (config.batch_size * config.in_seq_length,
                              config.num_neighbour), "reshape1")
            n.outputs = tf.expand_dims(n.outputs, axis=-1)
            n = Conv1d(n,
                       4,
                       4,
                       1,
                       act=tf.identity,
                       padding='SAME',
                       W_init=w_init,
                       name='conv1')
            n = BatchNormLayer(n,
                               act=tf.nn.relu,
                               is_train=is_train,
                               gamma_init=g_init,
                               name='bn1')
            n = MaxPool1d(n, 2, 2, padding='valid', name='maxpool1')
            n = FlattenLayer(n, name="flatten1")
            n = ReshapeLayer(n, (config.batch_size, config.in_seq_length, -1),
                             name="reshape1_back")

            net_encode = ConcatLayer([inputs_x_root, n],
                                     concat_dim=-1,
                                     name="encode")
            net_decode = InputLayer(decode_seq, name="decode")

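            # Same encoder-decoder backbone as in Example 1, now fed with the
            # concatenated root-node and neighbour encodings.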
            net_rnn = Seq2Seq(
                net_encode,
                net_decode,
                cell_fn=tf.contrib.rnn.BasicLSTMCell,
                n_hidden=config.dim_hidden,
                initializer=tf.random_uniform_initializer(-0.1, 0.1),
                encode_sequence_length=tl.layers.retrieve_seq_length_op(
                    net_encode.outputs),
                decode_sequence_length=tl.layers.retrieve_seq_length_op(
                    net_decode.outputs),
                initial_state_encode=None,
                # dropout=(0.8 if is_train else None),
                dropout=None,
                n_layer=1,
                return_seq_2d=True,
                name='seq2seq')
            net_rnn_seq2seq = net_rnn

            net_spatial_out = DenseLayer(net_rnn,
                                         n_units=1,
                                         act=tf.identity,
                                         name='dense2')
            if is_train:
                net_spatial_out = ReshapeLayer(
                    net_spatial_out,
                    (config.batch_size, config.out_seq_length + 1, 1),
                    name="reshape_out")
            else:
                net_spatial_out = ReshapeLayer(net_spatial_out,
                                               (config.batch_size, 1, 1),
                                               name="reshape_out")

        with tf.variable_scope(self.model_name + "_wide", reuse=reuse) as vs:
            tl.layers.set_name_reuse(reuse)
            # Wide branch: dense embeddings of the exogenous features.
            net_features = InputLayer(features, name="in_features")
            net_features_full = InputLayer(features_full,
                                           name="in_features_full")
            net_features_full = ReshapeLayer(
                net_features_full,
                (config.batch_size * (config.out_seq_length + 1),
                 config.dim_features),
                name="reshape_feature_full_1")
            if is_train:
                net_features = ReshapeLayer(
                    net_features,
                    (config.batch_size * (config.out_seq_length + 1),
                     config.dim_features),
                    name="reshape_feature_1")
            else:
                net_features = ReshapeLayer(net_features,
                                            (config.batch_size,
                                             config.dim_features),
                                            name="reshape_feature_1")

            self.net_features_dim = 32
            net_features = DenseLayer(net_features,
                                      n_units=self.net_features_dim,
                                      act=tf.nn.relu,
                                      name='dense_features')
            net_features_full = DenseLayer(net_features_full,
                                           n_units=self.net_features_dim,
                                           act=tf.nn.relu,
                                           name='dense_features_full')
            # self.net_features = net_features

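            # Wide prediction: concatenate the seq2seq decoder states with the
            # per-step feature embeddings, then project to one scalar per step.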
            net_wide_out = ConcatLayer([net_rnn_seq2seq, net_features],
                                       concat_dim=-1,
                                       name="concat_features")
            net_wide_out = DenseLayer(net_wide_out,
                                      n_units=1,
                                      act=tf.identity,
                                      name='dense2')

            if is_train:
                net_wide_out = ReshapeLayer(
                    net_wide_out,
                    (config.batch_size, config.out_seq_length + 1, 1),
                    name="reshape_out")
            else:
                net_wide_out = ReshapeLayer(net_wide_out,
                                            (config.batch_size, 1, 1),
                                            name="reshape_out")

        with tf.variable_scope(self.model_name + "_query", reuse=reuse) as vs:
            tl.layers.set_name_reuse(reuse)

            net_decode_query = InputLayer(self.query_decode_seq,
                                          name="decode_query")

            net_rnn_query = RNNLayer(
                net_decode_query,
                cell_fn=tf.contrib.rnn.BasicLSTMCell,
                cell_init_args={"forget_bias": 1.0},
                n_hidden=config.query_dim_hidden,
                initializer=tf.random_uniform_initializer(-0.1, 0.1),
                n_steps=config.out_seq_length,
                return_last=True,
                # return_last=False,
                # return_seq_2d=True,
                name="rnn_query")
            # Disabled alternative: a DynamicRNNLayer-based query encoder that
            # infers sequence lengths from the decode input.
            # net_rnn_query = DynamicRNNLayer(
            #     net_decode_query,
            #     cell_fn=tf.contrib.rnn.BasicLSTMCell,
            #     cell_init_args={"forget_bias": 1.0},
            #     n_hidden=32,  # or config.query_dim_hidden
            #     initializer=tf.random_uniform_initializer(-0.1, 0.1),
            #     return_last=True,
            #     # dropout=0.8,
            #     sequence_length=tl.layers.retrieve_seq_length_op(
            #         net_decode_query.outputs),
            #     # return_last=False,
            #     # return_seq_2d=True,
            #     name="rnn_query_dynamic")

            net_rnn_query = ExpandDimsLayer(net_rnn_query,
                                            axis=1,
                                            name="rnn_query_expand")
            net_rnn_query = TileLayer(net_rnn_query,
                                      [1, config.out_seq_length, 1],
                                      name="rnn_query_tile")
            net_rnn_query = ReshapeLayer(
                net_rnn_query, (config.batch_size * config.out_seq_length,
                                config.query_dim_hidden),
                name="rnn_query_reshape")
            # net_rnn_query = ReshapeLayer(net_rnn_query, (config.batch_size * config.out_seq_length, 32), name="rnn_query_reshape")

            # self.net_rnn_query = net_rnn_query

            net_traffic_state = InputLayer(self.traffic_state,
                                           name="in_traffic_state")
            # Disabled variant: during training, feed the query head from the
            # seq2seq decoder states and feature embeddings, both sliced to the
            # first out_seq_length steps, instead of from self.traffic_state:
            # if is_train:
            #     net_rnn_traffic = ReshapeLayer(net_rnn_seq2seq, (config.batch_size, config.out_seq_length + 1, config.dim_hidden), name="reshape_traffic_q1")
            #     net_rnn_traffic.outputs = tf.slice(net_rnn_traffic.outputs, [0, 0, 0], [config.batch_size, config.out_seq_length, config.dim_hidden], name="slice_traffic_q")
            #     net_rnn_traffic = ReshapeLayer(net_rnn_traffic, (config.batch_size * config.out_seq_length, config.dim_hidden), name="reshape_traffic_q2")
            #     net_features_traffic = ReshapeLayer(net_features, (config.batch_size, config.out_seq_length + 1, self.net_features_dim), name="reshape_features_q1")
            #     net_features_traffic.outputs = tf.slice(net_features_traffic.outputs, [0, 0, 0], [config.batch_size, config.out_seq_length, self.net_features_dim], name="slice_features_q")
            #     net_features_traffic = ReshapeLayer(net_features_traffic, (config.batch_size * config.out_seq_length, self.net_features_dim), name="reshape_features_q2")
            #     net_query_out = ConcatLayer([net_rnn_traffic, net_features_traffic, net_rnn_query], concat_dim=-1, name="concat_traffic_query1")
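            # Align the full feature embeddings with the prediction horizon:
            # keep the first out_seq_length of the out_seq_length + 1 steps,
            # then flatten for concatenation with the query encoding.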
            net_features_traffic = ReshapeLayer(
                net_features_full,
                (config.batch_size, config.out_seq_length + 1,
                 self.net_features_dim),
                name="reshape_features_q1")
            net_features_traffic.outputs = tf.slice(
                net_features_traffic.outputs, [0, 0, 0], [
                    config.batch_size, config.out_seq_length,
                    self.net_features_dim
                ],
                name="slice_features_q")
            net_features_traffic = ReshapeLayer(
                net_features_traffic,
                (config.batch_size * config.out_seq_length,
                 self.net_features_dim),
                name="reshape_features_q2")

            net_query_out = ConcatLayer(
                [net_traffic_state, net_features_traffic, net_rnn_query],
                concat_dim=-1,
                name="concat_traffic_query1")
            # net_rnn_traffic = ReshapeLayer(net_rnn_seq2seq, (config.batch_size, config.out_seq_length + 1, config.dim_hidden), name="reshape_traffic_q1")
            # net_rnn_traffic.outputs = tf.slice(net_rnn_traffic.outputs, [0, 0, 0], [config.batch_size, config.out_seq_length, config.dim_hidden], name="slice_traffic_q")
            # net_rnn_traffic = ReshapeLayer(net_rnn_traffic, (config.batch_size * config.out_seq_length, config.dim_hidden), name="reshape_traffic_q2")
            # net_query_out = ConcatLayer([net_rnn_traffic, net_features_traffic, net_rnn_query], concat_dim=-1, name="concat_traffic_query1")

            # net_out = DenseLayer(net_out, n_units=128, act=tf.nn.relu, name="dense_query1")
            # net_out = DenseLayer(net_out, n_units=64, act=tf.nn.relu, name="dense_query2")
            # net_query_out = DropoutLayer(net_query_out, keep=0.8, is_fix=True, is_train=is_train, name='drop_query3')
            net_query_out = DenseLayer(net_query_out,
                                       n_units=1,
                                       act=tf.identity,
                                       name="dense_query3")
            # net_out = ReshapeLayer(net_out, (config.batch_size, config.out_seq_length + 1, 1), name="reshape_out")
            # if is_train:
            net_query_out = ReshapeLayer(
                net_query_out, (config.batch_size, config.out_seq_length, 1),
                name="reshape_out")
            # else:
            #    net_out = ReshapeLayer(net_out, (config.batch_size, 1, 1), name="reshape_out")

            # TODO: residual net. A disabled variant added the sliced
            # wide-branch output directly during training:
            # if is_train:
            #     net_query_out.outputs = tf.add(
            #         net_query_out.outputs,
            #         tf.slice(net_wide_out.outputs, [0, 0, 0],
            #                  [config.batch_size, config.out_seq_length, 1]),
            #         name="res_add")
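            # Live path: add the externally supplied base prediction
            # (self.base_pred) as a residual on top of the query correction.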
            net_base_pred = InputLayer(self.base_pred, name="in_net_base_pred")
            net_query_out.outputs = tf.add(net_query_out.outputs,
                                           net_base_pred.outputs,
                                           name="res_add")

        return net_rnn_seq2seq, net_spatial_out, net_wide_out, net_rnn_query, net_query_out
Example 3
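Example 3 is the standalone spatial model: a Conv1d encoder over each step's neighbour readings, concatenated with the root sequence and fed through the same seq2seq backbone as above.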
    def __get_network__(self,
                        encode_seq,
                        neighbour_seq,
                        decode_seq,
                        is_train=True,
                        reuse=False):
        w_init = tf.random_normal_initializer(stddev=0.02)
        g_init = tf.random_normal_initializer(1., 0.02)

        with tf.variable_scope(self.model_name, reuse=reuse) as vs:
            tl.layers.set_name_reuse(reuse)
            inputs_x_root = InputLayer(encode_seq, name='in_root')
            inputs_x_nbor = InputLayer(neighbour_seq, name="in_neighbour")

            # Encode the neighbour-graph signal: Conv1d + BatchNorm + max-pool
            # over each step's neighbour readings, then flatten back to
            # (batch_size, in_seq_length, -1) so it can join the root sequence.
            n = ReshapeLayer(inputs_x_nbor,
                             (config.batch_size * config.in_seq_length,
                              config.num_neighbour), "reshape1")
            n.outputs = tf.expand_dims(n.outputs, axis=-1)
            n = Conv1d(n,
                       4,
                       4,
                       1,
                       act=tf.identity,
                       padding='SAME',
                       W_init=w_init,
                       name='conv1')
            n = BatchNormLayer(n,
                               act=tf.nn.relu,
                               is_train=is_train,
                               gamma_init=g_init,
                               name='bn1')
            n = MaxPool1d(n, 2, 2, padding='valid', name='maxpool1')
            n = FlattenLayer(n, name="flatten1")
            n = ReshapeLayer(n, (config.batch_size, config.in_seq_length, -1),
                             name="reshape1_back")

            net_encode = ConcatLayer([inputs_x_root, n],
                                     concat_dim=-1,
                                     name="encode")
            net_decode = InputLayer(decode_seq, name="decode")

            net_rnn = Seq2Seq(
                net_encode,
                net_decode,
                cell_fn=tf.contrib.rnn.BasicLSTMCell,
                n_hidden=config.dim_hidden,
                initializer=tf.random_uniform_initializer(-0.1, 0.1),
                encode_sequence_length=tl.layers.retrieve_seq_length_op(
                    net_encode.outputs),
                decode_sequence_length=tl.layers.retrieve_seq_length_op(
                    net_decode.outputs),
                initial_state_encode=None,
                # dropout=(0.8 if is_train else None),
                dropout=None,
                n_layer=1,
                return_seq_2d=True,
                name='seq2seq')
            # net_out = DenseLayer(net_rnn, n_units=64, act=tf.identity, name='dense1')
            net_out = DenseLayer(net_rnn,
                                 n_units=1,
                                 act=tf.identity,
                                 name='dense2')
            if is_train:
                net_out = ReshapeLayer(
                    net_out, (config.batch_size, config.out_seq_length + 1, 1),
                    name="reshape_out")
            else:
                net_out = ReshapeLayer(net_out, (config.batch_size, 1, 1),
                                       name="reshape_out")

            self.net_rnn = net_rnn

            return net_out
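
For context, here is a minimal sketch of how a builder like Example 3's is typically driven in TensorLayer 1.x: the graph is built once for training and rebuilt with reuse=True so that inference shares the trained weights. The placeholder shapes, the model instance, and the target_seq placeholder are illustrative assumptions, not taken from the original code.

import tensorflow as tf
import tensorlayer as tl

# config is the same project-level module used by the examples above.
# Trailing feature dimensions (the final 1) are assumptions.
encode_seq = tf.placeholder(tf.float32,
                            [config.batch_size, config.in_seq_length, 1],
                            name="encode_seq")
neighbour_seq = tf.placeholder(
    tf.float32,
    [config.batch_size, config.in_seq_length, config.num_neighbour],
    name="neighbour_seq")
decode_seq = tf.placeholder(tf.float32,
                            [config.batch_size, config.out_seq_length + 1, 1],
                            name="decode_seq")
# Inference decodes one step at a time, matching the (batch_size, 1, 1)
# reshape in the is_train=False branch above.
decode_seq_test = tf.placeholder(tf.float32, [config.batch_size, 1, 1],
                                 name="decode_seq_test")
target_seq = tf.placeholder(tf.float32,
                            [config.batch_size, config.out_seq_length + 1, 1],
                            name="target_seq")

# model is a hypothetical instance of the class that defines __get_network__.
net_train = model.__get_network__(encode_seq, neighbour_seq, decode_seq,
                                  is_train=True, reuse=False)
net_test = model.__get_network__(encode_seq, neighbour_seq, decode_seq_test,
                                 is_train=False, reuse=True)

loss = tl.cost.mean_squared_error(net_train.outputs, target_seq, is_mean=True)
train_op = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(loss)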