# Example 1
    def define_graph(self):
        """Build the encoder-forecaster training graph.

        Creates input / ground-truth placeholders, unrolls the encoder over
        the input sequence on GPU 0, passes its final RNN states to the
        forecaster on GPU 1, and defines the weighted combined loss, the
        Adam optimizer, and scalar summaries.

        Side effects: sets ``self.in_data``, ``self.gt_data``,
        ``self.global_step``, ``self.result``, ``self.mse``, ``self.mae``,
        ``self.gdl``, ``self.optimizer`` and ``self.summary``.

        NOTE(review): a second ``define_graph`` defined later in this class
        shadows this one at class-creation time — confirm which version is
        intended to be live.
        """
        with tf.variable_scope("Graph", reuse=tf.AUTO_REUSE):
            # Input frames: (batch, in_seq, h, w, in_c).
            self.in_data = tf.placeholder(self._dtype,
                                          shape=(self._batch, self._in_seq,
                                                 self._h, self._w, self._in_c),
                                          name="input")
            # Ground-truth frames: (batch, out_seq, h, w, 1).
            self.gt_data = tf.placeholder(self._dtype,
                                          shape=(self._batch, self._out_seq,
                                                 self._h, self._w, 1),
                                          name="gt")
            self.global_step = tf.Variable(0, trainable=False)
            with tf.device('/device:GPU:0'):
                encoder_net = Encoder(self._batch,
                                      self._in_seq,
                                      gru_filter=c.ENCODER_GRU_FILTER,
                                      gru_in_chanel=c.ENCODER_GRU_INCHANEL,
                                      conv_kernel=c.CONV_KERNEL,
                                      conv_stride=c.CONV_STRIDE,
                                      h2h_kernel=c.H2H_KERNEL,
                                      i2h_kernel=c.I2H_KERNEL,
                                      height=self._h,
                                      width=self._w)
                # Feed the encoder one time step at a time.
                for i in range(self._in_seq):
                    encoder_net.rnn_encoder(self.in_data[:, i, ...])
            # Hand the encoder's final RNN states to the forecaster.
            states = encoder_net.rnn_states
            with tf.device('/device:GPU:1'):
                forecaster_net = Forecaster(
                    self._batch,
                    self._out_seq,
                    gru_filter=c.DECODER_GRU_FILTER,
                    gru_in_chanel=c.DECODER_GRU_INCHANEL,
                    deconv_kernel=c.DECONV_KERNEL,
                    deconv_stride=c.DECONV_STRIDE,
                    h2h_kernel=c.H2H_KERNEL,
                    i2h_kernel=c.I2H_KERNEL,
                    rnn_states=states,
                    height=self._h,
                    width=self._w)

                # One forecaster step per output frame; predictions
                # accumulate in forecaster_net.pred.
                for i in range(self._out_seq):
                    forecaster_net.rnn_forecaster()
            # Stack per-step predictions along the time axis.
            pred = tf.concat(forecaster_net.pred, axis=1)

            with tf.variable_scope("loss"):
                gt = self.gt_data
                weights = get_loss_weight_symbol(pred)
                self.result = pred
                self.mse = weighted_mse(pred, gt, weights)
                self.mae = weighted_mae(pred, gt, weights)
                self.gdl = gdl_loss(pred, gt)
                # BUGFIX: pair L1_LAMBDA with the L1 (MAE) term and
                # L2_LAMBDA with the L2 (MSE) term. They were swapped here,
                # inconsistent with the other define_graph implementation
                # in this class and with the usual L1/L2 naming.
                loss = c.L1_LAMBDA * self.mae + c.L2_LAMBDA * self.mse + c.GDL_LAMBDA * self.gdl
                self.optimizer = tf.train.AdamOptimizer(c.LR).minimize(
                    loss, global_step=self.global_step)

                tf.summary.scalar('mse', self.mse)
                tf.summary.scalar('mae', self.mae)
                tf.summary.scalar('gdl', self.gdl)
                tf.summary.scalar('combine_loss', loss)
                # merge_all picks up every summary registered above.
                self.summary = tf.summary.merge_all()
    def define_graph(self):
        """Build the encoder-forecaster training graph.

        Creates input / ground-truth placeholders, runs the encoder on
        GPU 0 and the forecaster on GPU 1 (step-by-step when
        ``c.SEQUENCE_MODE`` is set, otherwise over the whole sequence at
        once), then defines the combined loss, optionally gated adversarial
        loss, the Adam optimizer, and merged summaries.

        Side effects: sets ``self.in_data``, ``self.gt_data``,
        ``self.global_step``, ``self.result``, ``self.mse``, ``self.mae``,
        ``self.gdl``, ``self.d_loss``, ``self.loss``, ``self.optimizer``,
        ``self.summary`` and, when ``c.ADVERSARIAL``, ``self.d_pred``.
        """
        with tf.variable_scope("Graph", reuse=tf.AUTO_REUSE):
            # Input frames: (batch, in_seq, h, w, in_c).
            self.in_data = tf.placeholder(self._dtype,
                                          shape=(self._batch, self._in_seq,
                                                 self._h, self._w, self._in_c),
                                          name="input")
            # Ground-truth frames: (batch, out_seq, h, w, 1).
            self.gt_data = tf.placeholder(self._dtype,
                                          shape=(self._batch, self._out_seq,
                                                 self._h, self._w, 1),
                                          name="gt")
            self.global_step = tf.Variable(0, trainable=False)
            with tf.device('/device:GPU:0'):
                encoder_net = Encoder(self._batch,
                                      self._in_seq,
                                      gru_filter=c.ENCODER_GRU_FILTER,
                                      gru_in_chanel=c.ENCODER_GRU_INCHANEL,
                                      conv_kernel=c.CONV_KERNEL,
                                      conv_stride=c.CONV_STRIDE,
                                      h2h_kernel=c.H2H_KERNEL,
                                      i2h_kernel=c.I2H_KERNEL,
                                      height=self._h,
                                      width=self._w)
                if c.SEQUENCE_MODE:
                    # Feed one frame per step.
                    for i in range(self._in_seq):
                        encoder_net.rnn_encoder_step(self.in_data[:, i, ...])
                else:
                    # Feed the full sequence in one call.
                    encoder_net.rnn_encoder(self.in_data)
            # Hand the encoder's final RNN states to the forecaster.
            states = encoder_net.rnn_states
            with tf.device('/device:GPU:1'):
                forecaster_net = Forecaster(
                    self._batch,
                    self._out_seq,
                    gru_filter=c.DECODER_GRU_FILTER,
                    gru_in_chanel=c.DECODER_GRU_INCHANEL,
                    deconv_kernel=c.DECONV_KERNEL,
                    deconv_stride=c.DECONV_STRIDE,
                    h2h_kernel=c.H2H_KERNEL,
                    i2h_kernel=c.I2H_KERNEL,
                    rnn_states=states,
                    height=self._h,
                    width=self._w)
                if c.SEQUENCE_MODE:
                    # Per-step predictions are stacked along the time axis.
                    for i in range(self._out_seq):
                        forecaster_net.rnn_forecaster_step()
                    pred = tf.concat(forecaster_net.pred, axis=1)
                else:
                    forecaster_net.rnn_forecaster()
                    pred = forecaster_net.pred

            with tf.variable_scope("loss"):
                gt = self.gt_data
                weights = get_loss_weight_symbol(pred)

                self.result = pred
                # NOTE(review): unlike the other define_graph, MSE here is
                # unweighted while MAE is weighted — confirm intentional.
                self.mse = tf.reduce_mean(tf.square(pred - gt))
                self.mae = weighted_mae(pred, gt, weights)
                self.gdl = gdl_loss(pred, gt)
                # BUGFIX: removed dead `self.d_loss = self.result` — d_loss
                # is assigned on every path below, and predictions are not
                # a loss value.

                if c.ADVERSARIAL:
                    # Discriminator-side predictions are fed back in.
                    self.d_pred = tf.placeholder(
                        self._dtype,
                        (self._batch, self._out_seq, self._h, self._w, 1))
                    self.d_loss = tf.reduce_mean(tf.square(self.d_pred - gt))
                    # Only add the adversarial term once training has passed
                    # self.adv_involve steps (assumed set elsewhere in this
                    # class — TODO confirm).
                    self.loss = tf.cond(self.global_step > self.adv_involve,
                                        lambda: c.L1_LAMBDA * self.mae \
                                                + c.L2_LAMBDA * self.mse \
                                                + c.GDL_LAMBDA * self.gdl \
                                                + c.ADV_LAMBDA * self.d_loss,
                                        lambda: c.L1_LAMBDA * self.mae \
                                                + c.L2_LAMBDA * self.mse \
                                                + c.GDL_LAMBDA * self.gdl
                                        )
                else:
                    self.loss = c.L1_LAMBDA * self.mae + c.L2_LAMBDA * self.mse + c.GDL_LAMBDA * self.gdl
                    self.d_loss = self.loss

                self.optimizer = tf.train.AdamOptimizer(c.LR).minimize(
                    self.loss, global_step=self.global_step)

                # Merge only this graph's summaries (avoids picking up
                # strays via merge_all).
                self.summary = tf.summary.merge([
                    tf.summary.scalar('mse', self.mse),
                    tf.summary.scalar('mae', self.mae),
                    tf.summary.scalar('gdl', self.gdl),
                    tf.summary.scalar('combine_loss', self.loss)
                ])