    # copy dataset info to model dir
    copy_dataset_info(config)

    # Transfer data to local vars
    batch_num = config.batch_size
    validation_split = 0.1
    epochs = config.epochs
    sqrd_diff_loss = config.sqrd_diff_loss
    ls_split = config.ls_split
    test_data_types = config.data_type.copy()
    if "inflow" in test_data_types: test_data_types.remove("inflow")

    # Multitile model -> scale by 2
    config.tile_scale = 2

    keras_batch_manager = BatchManager(config, 1, 1)
    sup_param_count = keras_batch_manager.supervised_param_count

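    # channel count: 2 velocity components in 2D (one more in 3D),
    # plus 1 extra channel if a passive field (density or levelset) is present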
    in_out_dim = 3 if "density" in config.data_type or "levelset" in config.data_type else 2
    in_out_dim = in_out_dim + 1 if config.is_3d else in_out_dim

    tiles_use_global = getattr(config, "tiles_use_global", False)
    if keras_batch_manager.use_tiles and tiles_use_global:
        # extra channels for global information
        in_out_dim += 2

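    # resulting network input: (res_z, res_y, res_x, channels) for 3D data, (res_y, res_x, channels) otherwise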
    input_shape = (config.res_z, ) if config.is_3d else ()
    input_shape += (config.res_y, config.res_x, in_out_dim)
    batch_num = config.batch_size
    validation_split = 0.1
    epochs = config.epochs
    input_frame_count = config.input_frame_count
    prediction_window = config.w_num
    decode_predictions = config.decode_predictions
    skip_pred_steps = config.skip_pred_steps
    init_state_network = config.init_state_network
    in_out_states = config.in_out_states
    pred_gradient_loss = config.pred_gradient_loss
    ls_prediction_loss = config.ls_prediction_loss
    ls_supervision = config.ls_supervision
    sqrd_diff_loss = config.sqrd_diff_loss
    ls_split = config.ls_split

    keras_batch_manager = BatchManager(config, input_frame_count, prediction_window)
    sup_param_count = keras_batch_manager.supervised_param_count

    in_out_dim = 3 if "density" in config.data_type else 2
    in_out_dim = in_out_dim + 1 if config.is_3d else in_out_dim
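    # sequence input: input_frame_count frames, each of shape ([res_z,] res_y, res_x, channels)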
    input_shape = (input_frame_count,)
    input_shape += (keras_batch_manager.res_z,) if config.is_3d else ()
    input_shape += (keras_batch_manager.res_y, keras_batch_manager.res_x, in_out_dim)

    print("Input Shape: {}".format(input_shape))

    rec_pred = RecursivePredictionCleanSplit(config=config, input_shape=input_shape, decode_predictions=decode_predictions, skip_pred_steps=skip_pred_steps, init_state_network=init_state_network, in_out_states=in_out_states, pred_gradient_loss=pred_gradient_loss, ls_prediction_loss=ls_prediction_loss, ls_supervision=ls_supervision, sqrd_diff_loss=sqrd_diff_loss, ls_split=ls_split, supervised_parameters=sup_param_count) 

    # Train =====================================================================================================
    if config.is_train:
        if config.load_path:
# Example #3
    epochs = config.epochs
    input_frame_count = config.input_frame_count
    prediction_window = config.w_num
    decode_predictions = config.decode_predictions
    skip_pred_steps = config.skip_pred_steps
    init_state_network = config.init_state_network
    in_out_states = config.in_out_states
    pred_gradient_loss = config.pred_gradient_loss
    ls_prediction_loss = config.ls_prediction_loss
    ls_supervision = config.ls_supervision
    sqrd_diff_loss = config.sqrd_diff_loss
    ls_split = config.ls_split
    test_data_types = config.data_type.copy()
    if "inflow" in test_data_types: test_data_types.remove("inflow")

    keras_batch_manager = BatchManager(config, input_frame_count,
                                       prediction_window)
    sup_param_count = keras_batch_manager.supervised_param_count

    in_out_dim = 3 if "density" in config.data_type or "levelset" in config.data_type else 2
    in_out_dim = in_out_dim + 1 if config.is_3d else in_out_dim
    input_shape = (input_frame_count, )
    input_shape += (config.res_z, ) if config.is_3d else ()
    input_shape += (config.res_y, config.res_x, in_out_dim)

    train_prediction_only = config.train_prediction_only and config.is_train and config.load_path != ''
    if train_prediction_only:
        print("Training only the prediction network!")

    ## Write config to file
    config_d = vars(config) if config else {}
    unparsed_d = vars(unparsed) if unparsed else {}
# Example #4
    def _build_model(self, **kwargs):
        print("Building Model")
        is_3d = self.is_3d
        velo_dim = 3 if is_3d else 2

        batch_manager = kwargs.get("batch_manager", None)
        if batch_manager is None and self.advection_loss > 0.0:
            print("WARNING: no batch manager found... creating dummy")
            batch_manager = BatchManager(self.config,
                                         self.config.input_frame_count,
                                         self.config.w_num,
                                         data_args_path=kwargs.get(
                                             "data_args_path", None))

        # Load predefined model layouts
        self._create_submodels()

        enc = self.ae._encoder
        dec = self.ae._decoder
        p_pred = self.ae._p_pred
        pred = self.pred.model

        # State init network
        if self.in_out_states:
            state_init_in = Input(shape=(self.w_num, self.z_num),
                                  dtype="float32",
                                  name="State_Init_Input")
            state_init_x = Reshape((self.w_num * self.z_num, ),
                                   name="Reshape_io_states")(state_init_in)
            state_init_x = Dense(128)(state_init_x)
            state_init_x = LeakyReLU()(state_init_x)
            state_init_x = Dense(2 * self.pred.encoder_lstm_neurons + 2 *
                                 self.pred.decoder_lstm_neurons)(state_init_x)
            state_init_x = LeakyReLU()(state_init_x)
            state_init_x = Reshape((4, self.pred.encoder_lstm_neurons),
                                   name="Reshape_io_states_2")(state_init_x)
            self.state_init_model = Model(name="State_Init",
                                          inputs=state_init_in,
                                          outputs=state_init_x)
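            # the 4 rows presumably hold the initial (h, c) states of the prediction
            # network's two LSTMs; the reshape assumes encoder and decoder LSTMs share the same width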

        if self.stateful:
            inputs = Input(
                batch_shape=(self.b_num, ) + self.input_shape,
                dtype="float32",
                name="Combined_AE_Input_Fields")  # (b, input_depth, x, y, c)
        else:
            inputs = Input(
                shape=self.input_shape,
                dtype="float32",
                name="Combined_AE_Input_Fields")  # (b, input_depth, y, x, c)

        # Input for GT supervised parameters (e.g. rotation and position)
        # -> (b_num, 14, 2)
        sup_param_inputs = Input(shape=(self.input_shape[0],
                                        self.sup_param_count),
                                 dtype="float32",
                                 name="Combined_AE_Input_Sup_Param")

        if self.use_inflow or self.advection_loss > 0.0:
            input_inflow = Input(
                shape=self.input_inflow_shape,
                dtype="float32",
                name="Inflow_Input")  # (b, input_depth, y, x, 1)

        if self.ls_split > 0.0 and not self.train_prediction_only:
            inputs_full = Lambda(lambda x: x[:, 0:1],
                                 name="ls_split_slice")(inputs)
            inputs_full = Lambda(lambda x: K.squeeze(x, 1),
                                 name="ls_split_0")(inputs_full)
            inputs_vel = Lambda(lambda x: K.concatenate([
                x[..., 0:velo_dim],
                K.zeros_like(x)[..., velo_dim:velo_dim + 1]
            ],
                                                        axis=-1),
                                name="ls_split_1")(inputs_full)
            inputs_den = Lambda(lambda x: K.concatenate([
                K.zeros_like(x)[..., 0:velo_dim], x[..., velo_dim:velo_dim + 1]
            ],
                                                        axis=-1),
                                name="ls_split_2")(inputs_full)
            z_vel = enc(inputs_vel)
            z_vel = Lambda(lambda x: x, name="z_vel")(z_vel)
            z_den = enc(inputs_den)
            z_den = Lambda(lambda x: x, name="z_den")(z_den)
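            # z_vel / z_den: codes of the velocity-only and passive-only inputs, exposed as extra
            # outputs so the ls_split loss can presumably enforce a split latent space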

        enc_input = None
        enc_input_range = inputs.shape[
            1] if self.ls_prediction_loss else self.w_num
        for i in range(enc_input_range):  # input depth iteration
            if enc_input is None:
                enc_input = Lambda(lambda x: x[:, i],
                                   name="Slice_enc_input_{}".format(i))(inputs)
                enc_input = enc(enc_input)
                enc_input = Lambda(lambda x: K.expand_dims(x, axis=1))(
                    enc_input)
            else:
                temp_enc = Lambda(lambda x: x[:, i],
                                  name="Slice_enc_input_{}".format(i))(inputs)
                temp_enc = enc(temp_enc)
                encoded = Lambda(lambda x: K.expand_dims(x, axis=1))(temp_enc)
                enc_input = concatenate([enc_input, encoded],
                                        axis=1)  # (b, input_depth, z)

        # directly extract z to apply supervised latent space loss afterwards
        z = Lambda(lambda x: x[:, 0:1], name="Slice_z")(enc_input)

        # Overwrite supervised latent space entries in enc_input
        # e.g. sup_param_inputs -> (b,14,2)
        enc_input = Lambda(lambda x: x[:, :, 0:-self.sup_param_count],
                           name="sup_param_count_slice")(enc_input)

        # (b_num, 14, 2) -> (b_num, w_num, 2)
        first_input_sup_params = Lambda(
            lambda x: x[:, :self.w_num],
            name="first_input_sup_param_slice")(sup_param_inputs)
        enc_input = concatenate([enc_input, first_input_sup_params],
                                axis=2,
                                name="enc_input_sup_param_concat")

        rec_input = enc_input

        if self.in_out_states:
            if self.init_state_network:
                pred_states_init = self.state_init_model(rec_input)
            else:
                pred_states_init = Lambda(lambda x: K.zeros(
                    (self.b_num, 4, self.pred.encoder_lstm_neurons)
                ))(
                    inputs
                )  # Lambda is a quick hack to initialize the states with zeros (the input tensor itself is not used)

            def slice_states(x):
                return tf.unstack(x, axis=1)

            pred_states_0_0, pred_states_0_1, pred_states_1_0, pred_states_1_1 = Lambda(
                slice_states)(pred_states_init)

        if self.ls_prediction_loss:
            rec_output_ls = None
        rec_output = None
        adv_output = None
        rec_den = None
        for i in range(self.recursive_prediction):
            if self.in_out_states:
                x, pred_states_0_0, pred_states_0_1, pred_states_1_0, pred_states_1_1 = pred(
                    [
                        rec_input, pred_states_0_0, pred_states_0_1,
                        pred_states_1_0, pred_states_1_1
                    ])
            else:
                x = pred([rec_input])
            x = self.pred._fix_output_dimension(x)

            # predicted delta
            # add now to previous input
            pred_add_first_elem = Lambda(
                lambda x: x[:, -self.pred.out_w_num:None],
                name="rec_input_add_slice_{}".format(i))(rec_input)
            x = Add(name="Pred_Add_{}".format(i))(
                [pred_add_first_elem, x])  # previous z + predicted delta z

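            # replace the supervised entries of the predicted code with the ground-truth
            # parameters before feeding the code back into the prediction loop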
            if self.ls_supervision:
                pred_x = Lambda(lambda x: x[:, :, 0:-self.sup_param_count],
                                name="pred_x_slice_{}".format(i))(x)

                sup_param_real = Lambda(
                    lambda x: x[:, self.w_num + i:self.w_num + self.pred.
                                out_w_num + i],
                    name="sup_param_real_{}".format(i))(sup_param_inputs)
                x = concatenate(
                    [pred_x, sup_param_real],
                    axis=2,
                    name="Pred_Real_Supervised_Concat_{}".format(i))

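            # shift the recurrent input window: drop the oldest out_w_num codes and append the new prediction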
            rec_input = Lambda(lambda x: x[:, self.pred.out_w_num:None],
                               name="rec_input_slice_{}".format(i))(rec_input)

            rec_input = concatenate([rec_input, x],
                                    axis=1,
                                    name="Pred_Input_Concat_{}".format(i))
            rec_input_last = x
            if self.decode_predictions:
                if self.ls_prediction_loss:
                    x_ls = x
                x = dec(
                    Reshape((self.z_num, ),
                            name="Reshape_xDecPred_{}".format(i))(x))

            # ########################################################################################################################
            # density/ls advection loss
            # 0) get first GT density field that is to be advected (0,1) -> 2 [take 1]
            # 0) denormalize current passive GT field (z,y,x,1)
            # 1) extract velocity array (z,y,x,3) [or (...,2)]
            # 2) denormalize velocity -> v = keras_data.denorm_vel(v)
            # 3) apply inflow region or obstacle subtract
            # 4) use current passive field (z,y,x,1) as advection src
            # 5) call advect(src, v, dt=keras_data.time_step, mac_adv=False, name="density")
            # 6) store as d+1 for usage in next frame -> rec_den
            # 7) normalize returned advected passive quantity
            # 8) hand over to loss -> (advect(d^t,v^t), d^t+1)
            # 9) use the advected density for reencoding
            # 10) start at 1)

            if self.advection_loss > 0.0 and i < self.recursive_prediction - 1:
                assert self.decode_predictions, (
                    "decode_predictions must be used")
                cur_decoded_pred = x
                # 0) get first GT density field that is to be advected (0,1) -> 2 [take 1]
                if rec_den is None:
                    rec_den = Lambda(lambda x: x[:, self.w_num - 1, ...,
                                                 velo_dim:velo_dim + 1],
                                     name="gt_passive_{}".format(i))(inputs)
                    rec_den = batch_manager.denorm(rec_den,
                                                   self.passive_data_type,
                                                   as_layer=True)

                # 1) extract velocity array (z,y,x,3) [or (...,2)]
                pred_vel = Lambda(
                    lambda x: x[..., 0:velo_dim],
                    name="vel_extract_{}".format(i))(cur_decoded_pred)
                # 2) denormalize velocity -> v = keras_data.denorm_vel(v)
                denorm_pred_vel = batch_manager.denorm(pred_vel,
                                                       "velocity",
                                                       as_layer=True)
                # 3) apply inflow region or obstacle subtract
                cur_inflow = Lambda(
                    lambda x: x[:, self.w_num + i],
                    name="inflow_extract_{}".format(self.w_num +
                                                    i))(input_inflow)
                rec_den = Lambda(
                    lambda x: tf.where(tf.greater(x[0], 0.0), x[0], x[1]))(
                        [cur_inflow, rec_den])
                # 4) use current passive field (z,y,x,1) as advection src
                # 5) call advect(src, v, dt=keras_data.time_step, mac_adv=False, name="density")
                # 6) store as d+1 for usage in next frame -> rec_den
                #print("4) + 5) + 6)")
                rec_den = Lambda(advect,
                                 arguments={
                                     'dt': batch_manager.time_step,
                                     'mac_adv': False,
                                     'name': self.passive_data_type
                                 })([rec_den, denorm_pred_vel])
                # 7) normalize returned advected passive quantity
                rec_den_norm = batch_manager.norm(rec_den,
                                                  self.passive_data_type,
                                                  as_layer=True)
                # 8) hand over to loss -> (advect(d^t,v^t), d^t+1)
                if adv_output is None or self.only_last_prediction:
                    rec_den_norm = Lambda(lambda x: K.expand_dims(x, axis=1))(
                        rec_den_norm)
                    adv_output = rec_den_norm
                else:
                    rec_den_norm = Lambda(lambda x: K.expand_dims(x, axis=1))(
                        rec_den_norm)
                    adv_output = concatenate(
                        [adv_output, rec_den_norm],
                        axis=1,
                        name="Adv_Passive_GT_Concat_{}".format(i))
                # 9) use the advected density for reencoding
                rec_den_norm_sq = Lambda(
                    lambda x: K.squeeze(x, 1),
                    name="rec_den_norm_squeeze_{}".format(i))(rec_den_norm)
                reencoded_input = Lambda(
                    lambda x: K.concatenate(x, axis=-1),
                    name="reencoding_vel_den_{}".format(self.w_num + i))(
                        [pred_vel, rec_den_norm_sq])
                z_reenc = enc(reencoded_input)
                z_reenc = Lambda(lambda x: K.expand_dims(x, axis=1))(z_reenc)
                # 10) take only density part of latent space and replace ls history
                # create mask with npa = np.zeros(shape); npa[:, x:y] = 1; m = K.constant( npa )
                m_np = np.zeros((self.pred.out_w_num, self.z_num),
                                dtype=np.float32)
                m_np[:, self.ls_split_idx:-self.sup_param_count] = 1.0
                # create lambda with a,b: a * m + b * (1-m)
                rec_input_last = Lambda(
                    lambda x: x[0] * K.constant(value=m_np, dtype='float32') +
                    x[1] * (1.0 - K.constant(value=m_np, dtype='float32')),
                    name="z_reenc_stitch_{}".format(self.w_num + i))(
                        [z_reenc, rec_input_last])
                # replace rec_input last elem
                rec_input = Lambda(
                    lambda x: x[:, :-1],
                    name="rec_input_cut_{}".format(self.w_num + i))(rec_input)
                rec_input = concatenate(
                    [rec_input, rec_input_last],
                    axis=1,
                    name="rec_input_concat_{}".format(self.w_num + i))

            if rec_output is None or self.only_last_prediction:
                rec_output = x
            else:
                rec_output = concatenate(
                    [rec_output, x],
                    axis=1,
                    name="Pred_Output_Concat_{}".format(i))

            if self.ls_prediction_loss:
                if rec_output_ls is None or self.only_last_prediction:
                    rec_output_ls = x_ls
                else:
                    rec_output_ls = concatenate(
                        [rec_output_ls, x_ls],
                        axis=1,
                        name="Pred_Output_LS_Concat_{}".format(i))

        if self.decode_predictions:
            if self.only_last_prediction:
                rec_out_shape = (1, ) + self.input_shape[1:]
            else:
                rec_out_shape = (
                    self.recursive_prediction, ) + self.input_shape[1:]
            rec_output = Reshape(rec_out_shape,
                                 name="Prediction_output")(rec_output)

        if self.decode_predictions:
            if self.ls_prediction_loss:
                if self.only_last_prediction:
                    GT_output_LS = Lambda(
                        lambda x: x[:, -1],
                        name="GT_output_LS_slice".format(i))(enc_input)
                    GT_output_LS_shape = (1, ) + int_shape(GT_output_LS)[1:]
                    GT_output_LS = Reshape(
                        GT_output_LS_shape,
                        name="Reshape_last_GT_ls")(GT_output_LS)
                else:
                    GT_output_LS = Lambda(
                        lambda x: x[:, -self.recursive_prediction:None],
                        name="GT_output_LS_slice".format(i))(enc_input)
        else:
            if self.only_last_prediction:
                GT_output = Lambda(
                    lambda x: x[:, -1],
                    name="GT_output_encoded_slice".format(i))(enc_input)
                GT_output_shape = (1, ) + int_shape(GT_output)[1:]
                GT_output = Reshape(GT_output_shape,
                                    name="Reshape_last_GT")(GT_output)
            else:
                GT_output = Lambda(
                    lambda x: x[:, -self.recursive_prediction:None],
                    name="GT_output_encoded_slice".format(i))(enc_input)

        # first half of pred_output is actual prediction, last half is GT to compare against in loss
        if not self.decode_predictions:
            pred_output = concatenate([rec_output, GT_output],
                                      axis=1,
                                      name="Prediction_Output")
        else:
            pred_output = rec_output

        if self.decode_predictions and self.ls_prediction_loss:
            pred_output_LS = concatenate([rec_output_ls, GT_output_LS],
                                         axis=1,
                                         name="Prediction_Output_LS")

        # supervised LS loss
        p_pred_output = p_pred(
            Reshape((self.z_num, ), name="Reshape_pPred")(z))

        # decoder loss
        if not self.train_prediction_only:
            ae_output = dec(
                Reshape((self.z_num, ), name="Reshape_zTrainPredOnly")(z))
            output_list = [ae_output, pred_output]
        else:
            output_list = [pred_output]

        if not self.train_prediction_only:
            output_list.append(p_pred_output)
        if self.ls_prediction_loss:
            output_list.append(pred_output_LS)
        if self.ls_split > 0.0 and not self.train_prediction_only:
            output_list.append(z_vel)
            output_list.append(z_den)
        if self.advection_loss > 0.0:
            output_list.append(adv_output)

        input_list = [inputs, sup_param_inputs]
        if self.use_inflow or self.advection_loss > 0.0:
            input_list.append(input_inflow)

        print("Setup Model")
        if len(self.gpus) > 1:
            with tf.device('/cpu:0'):
                self.model = Model(name="Combined_AE_LSTM",
                                   inputs=input_list,
                                   outputs=output_list)
        else:
            self.model = Model(name="Combined_AE_LSTM",
                               inputs=input_list,
                               outputs=output_list)
    def _build_model(self, **kwargs):
        print("Building Model")
        is_3d = self.is_3d
        velo_dim = 3 if is_3d else 2

        batch_manager = kwargs.get("batch_manager", None)
        if batch_manager is None and self.advection_loss > 0.0:
            print("WARNING: no batch manager found... creating dummy")
            batch_manager = BatchManager(self.config, self.config.input_frame_count, self.config.w_num, data_args_path=kwargs.get("data_args_path", None))

        self._create_submodels()

        enc = self.ae._encoder
        dec = self.ae._decoder

        pred = self.pred.model
        lc = self.latent_compression
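        # latent_compression: presumably a small 1D-conv model that merges the 9 per-tile codes into one code (used below)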

        inputs = Input(shape=self.input_shape, dtype="float32", name="Combined_AE_Input_Fields") # (b, input_depth, tiles, y, x, c)

        print("Input shape: {}".format(inputs))

        def global_concat(x, t):
            # first concat x axis for individual rows
            c0 = K.concatenate([x[:, t, 0], x[:, t, 1], x[:, t, 2]], axis=2)
            c1 = K.concatenate([x[:, t, 3], x[:, t, 4], x[:, t, 5]], axis=2)
            c2 = K.concatenate([x[:, t, 6], x[:, t, 7], x[:, t, 8]], axis=2)
            # then concat y axis
            ct = K.concatenate([c0,c1,c2], axis=1)
            return ct
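        # global_concat stitches the 3x3 tile grid of time step t back into one full field (b, 3*y, 3*x, c)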

        # global concat
        global_t1 = Lambda(global_concat, arguments={'t': 1})(inputs)

        # loop over tiles -> generate z
        z_time = []
        for t in range(self.config.w_num):
            z_tiles = []
            #   0   1   2
            #   3   4   5
            #   6   7   8
            for i in range(9):
                # individual tile compression
                def slice_0(x, t, i):
                    return x[:, t, i]
                cur_input = Lambda(slice_0, name="Slice_ae_input_{}".format(i), arguments={"t": t, "i": i})(inputs)
                z_tiles.append(enc(cur_input))
            z_time.append(z_tiles)
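        # z_time[t][i]: latent code of tile i (3x3 grid, row-major as sketched above) at time step t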

        z_tconv_t0 = Lambda(lambda x: K.expand_dims(x, axis=1))(z_time[0][0])
        def concat_0(x):
            return K.concatenate([x[0], K.expand_dims(x[1], axis=1)], axis=1)
        for i in range(1,9):
            z_tconv_t0 = Lambda(concat_0, name="Concate_z_enc_{}".format(i))([z_tconv_t0, z_time[0][i]])
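        # z_tconv_t0: the 9 tile codes of t0 stacked along a new axis -> (b, 9, z)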

        # Use 1D Conv
        z_tconv_t0 = lc([z_tconv_t0])

        # predict t0 -> t1
        z_pred_t1 = pred([z_tconv_t0])
        z_pred_t1 = self.pred._fix_output_dimension(z_pred_t1)

        x_pred_t1 = dec(Reshape((self.z_num,), name="Reshape_xDecPred_{}".format(0))(z_pred_t1))
        # store prediction of t1 in var
        pred_output = x_pred_t1

        # Pad decoded fields to match total field; (b, 3*y, 3*x, c) -> needed for advection
        x_pred_t1 = Lambda(lambda x: tf.pad(x, [[0,0], [self.ae.input_shape[0], self.ae.input_shape[0]], [self.ae.input_shape[1], self.ae.input_shape[1]], [0,0]], "CONSTANT", constant_values=0))(x_pred_t1)

        # apply advection on t0 density with predicted velocity
        if self.advection_loss > 0.0:
            cur_decoded_pred = x_pred_t1

            # 0) get first GT density field that is to be advected (0,1) -> 2 [take 1]
            global_t1_den = Lambda(lambda x: x[..., velo_dim:velo_dim+1], name="gt_passive_{}".format(0))(global_t1)
            global_t1_den = batch_manager.denorm(global_t1_den, self.passive_data_type, as_layer=True)

            # 1) extract velocity array (z,y,x,3) [or (...,2)]
            pred_vel = Lambda(lambda x: x[..., 0:velo_dim], name="vel_extract_{}".format(0))(cur_decoded_pred)
    
            # 2) denormalize velocity -> v = keras_data.denorm_vel(v)
            denorm_pred_vel = batch_manager.denorm(pred_vel, "velocity", as_layer=True)

            # 4) use current passive field (z,y,x,1) as advection src
            # 5) call advect(src, v, dt=keras_data.time_step, mac_adv=False, name="density")
            # 6) store as d+1 for usage in next frame -> rec_den
            global_t2_pred_den = Lambda(advect, arguments={'dt': batch_manager.time_step, 'mac_adv': False, 'name': self.passive_data_type}, name="Advect_{}".format(0))( [global_t1_den, denorm_pred_vel] )

            # 6.1) cutoff padding, that was added earlier
            global_t2_pred_den = Lambda(lambda x: x[:,self.ae.input_shape[0]:-self.ae.input_shape[0], self.ae.input_shape[1]:-self.ae.input_shape[1]])(global_t2_pred_den)

            # 7) normalize returned advected passive quantity
            rec_den_norm = batch_manager.norm(global_t2_pred_den, self.passive_data_type, as_layer=True)

            # 8) hand over to loss -> (advect(d^t,v^t), d^t+1)
            adv_output = rec_den_norm

        # decoded prediction loss
        output_list = [pred_output]
        
        if self.advection_loss > 0.0:
            # adv_output represents density of t2 produced by d2_adv = (v1_pred, d1_gt)
            output_list.append(adv_output)

        # inputs
        input_list = [inputs]

        print("Setup Model")
        if len(self.gpus) > 1:
            with tf.device('/cpu:0'):
                self.model = Model(name="Combined_AE_LSTM", inputs=input_list, outputs=output_list)
        else:
            self.model = Model(name="Combined_AE_LSTM", inputs=input_list, outputs=output_list)