def __init__(self, self_dict):
        """Build the TensorFlow graph for 3D density-field stylization.

        Copies every entry of ``self_dict`` onto the instance as an
        attribute, loads a frozen pre-trained network (GraphDef), and wires
        up: the density placeholder, the stylization path (potential/stream
        field, velocity, or additive), optional multi-view rotation, volume
        rendering to a grayscale image, and the preprocessing that feeds
        the imported network.

        Args:
            self_dict: dict of hyper-parameters; every key becomes an
                instance attribute via setattr (presumably includes seed,
                style_path, model_path, pool1, field_type, w_field,
                adv_order, v_batch, window_size, rotate, resize_scale,
                transmit, ... — TODO confirm against the caller).
        """
        # get arguments
        for arg in self_dict:
            setattr(self, arg, self_dict[arg])
        # seed both NumPy and TF for reproducibility
        self.rng = np.random.RandomState(self.seed)
        tf.set_random_seed(self.seed)

        # network setting: fresh graph + session, then load the frozen
        # pre-trained model as a GraphDef protobuf
        self.graph = tf.Graph()
        self.sess = tf.InteractiveSession(graph=self.graph)
        self.model_path = self.style_path + "/" + self.model_path
        with tf.gfile.GFile(self.model_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        # fix checkerboard artifacts: ksize should be divisible by the stride size
        # but it changes scale
        if self.pool1:
            for n in graph_def.node:
                if 'conv2d0_pre_relu/conv' in n.name:
                    n.attr['strides'].list.i[1:3] = [1, 1]

        # density input
        # shape: [D,H,W]
        d_shp = [None, None, None]
        self.d = tf.placeholder(dtype=tf.float32, shape=d_shp, name='density')

        # add batch dim / channel dim
        # shape: [1,D,H,W,1]
        d = tf.expand_dims(tf.expand_dims(self.d, axis=0), axis=-1)

        ######
        # sequence stylization: the quantity being optimized (shape depends
        # on field_type, fed per frame at run time)
        self.d_opt = tf.placeholder(dtype=tf.float32, name='opt')

        # number of optimized channels, chosen by field type
        if 'field' in self.field_type:
            if self.w_field == 1:
                self.c = 1
            elif self.w_field == 0:
                self.c = 3
            else:
                self.c = 4  # scalar (1) + vector field (3)
        elif 'density' in self.field_type:
            self.c = 1  # scalar field
        else:
            self.c = 3  # vector field

        if 'field' in self.field_type:
            # flip along axis 2 and rescale by that axis' extent before
            # taking spatial derivatives
            d_opt = self.d_opt[:, :, ::-1] * tf.to_float(
                tf.shape(self.d_opt)[2])
            if self.w_field == 1:
                # potential only: velocity = gradient of a scalar field
                self.v_ = grad(d_opt)
            elif self.w_field == 0:
                # stream function only: velocity = curl
                self.v_ = curl(d_opt)
            else:
                # blend of potential gradient and stream-function curl
                pot = d_opt[..., 0, None]
                strf = d_opt[..., 1:]
                self.v_p = grad(pot)
                self.v_s = curl(strf)
                self.v_ = self.v_p * self.w_field + self.v_s * (1 -
                                                                self.w_field)

            # undo the flip, normalize each component by its grid extent,
            # and reorder components to (z, y, x) for the advection op
            v = self.v_[:, :, ::-1]
            vx = v[..., 0] / tf.to_float(tf.shape(v)[3])
            vy = -v[..., 1] / tf.to_float(tf.shape(v)[2])
            vz = v[..., 2] / tf.to_float(tf.shape(v)[1])
            v = tf.stack([vz, vy, vx], axis=-1)
            d = advect(d, v, order=self.adv_order, is_3d=True)
        elif 'velocity' in self.field_type:
            # directly optimized velocity field
            v = self.d_opt  # [1,D,H,W,3]
            d = advect(d, v, order=self.adv_order, is_3d=True)
        else:
            # stylize by addition
            d += self.d_opt  # [1,D,H,W,1]

        self.b_num = self.v_batch
        ######

        ######
        # velocity fields to advect gradients [B,D,H,W,3]
        # (used for temporal gradient alignment when a window is active)
        if self.window_size > 1:
            self.v = tf.placeholder(dtype=tf.float32, name='velocity')
            self.g = tf.placeholder(dtype=tf.float32, name='gradient')
            self.adv = advect(self.g, self.v, order=self.adv_order, is_3d=True)
        ######

        # value clipping (d >= 0)
        d = tf.maximum(d, 0)

        # stylized 3d result
        self.d_out = d

        if self.rotate:
            d, self.rot_mat = rotate(d)  # [b,D,H,W,1]

            # compute rotation matrices for the sampled viewpoints
            self.rot_mat_, self.views = rot_mat(self.phi0,
                                                self.phi1,
                                                self.phi_unit,
                                                self.theta0,
                                                self.theta1,
                                                self.theta_unit,
                                                sample_type=self.sample_type,
                                                rng=self.rng,
                                                nv=self.n_views)

            if self.n_views is None:
                self.n_views = len(self.views)
            print('# vps:', self.n_views)
            # views must divide evenly into view-batches
            assert (self.n_views % self.v_batch == 0)

        # render 3d volume: accumulate transmittance along the depth axis
        # and composite to a 2D image, then normalize to [0,1]
        transmit = tf.exp(-tf.cumsum(d[:, ::-1], axis=1) * self.transmit)
        d = tf.reduce_sum(d[:, ::-1] * transmit, axis=1)
        d /= tf.reduce_max(d)  # [0,1]

        # resize if needed
        if abs(self.resize_scale - 1) > 1e-7:
            h = tf.to_int32(
                tf.multiply(float(self.resize_scale),
                            tf.to_float(tf.shape(d)[1])))
            w = tf.to_int32(
                tf.multiply(float(self.resize_scale),
                            tf.to_float(tf.shape(d)[2])))
            d = tf.image.resize_images(d, size=[h, w])

        # change the range of image to [0-255]
        self.d_img = tf.concat([d * 255] * 3, axis=-1)  # [B,H,W,3]

        # plug-in to the pre-trained network (mean-subtracted input)
        imagenet_mean = 117.0
        d_preprocessed = self.d_img - imagenet_mean
        tf.import_graph_def(graph_def, {'input': d_preprocessed})
        # collect all imported conv layers (candidates for feature losses)
        self.layers = [
            op.name for op in self.graph.get_operations()
            if op.type == 'Conv2D' and 'import/' in op.name
        ]
# NOTE(review): removed scraped-site artifact ("Exemplo n.º 2" / "0") that broke Python syntax.
    def run(self, params):
        """Optimize per-particle stylization parameters over octaves.

        Iterates coarse-to-fine over ``self.octave_n`` resolutions; at
        each octave, optimizes the per-frame parameters ``g_opt``
        (particle position offsets and/or per-kernel density residuals)
        with one Adam optimizer per frame window, optionally averaging
        gradients over multiple rendered viewpoints. Finally runs
        inference with the optimized parameters.

        Args:
            params: dict with key 'p' (list of per-frame particle
                positions, each an array of shape [N,3]) and, when 'd' is
                in ``self.target_field``, key 'r' (per-frame particle
                densities, [N,num_kernels]).

        Returns:
            dict with:
                'l'      -- per-octave loss history,
                'd_intm' -- intermediate rendered images per octave,
                'p'      -- stylized particle positions,
                'v'      -- particle offsets (if 'p' is a target) or None,
                'd'      -- stylized density fields,
                'r'      -- rendered uint8 images,
                'c'      -- always None (kept for interface compatibility).
        """
        # loss
        self._loss(params)

        # optimizer learning-rate placeholder (fed per octave)
        self.opt_lr = tf.compat.v1.placeholder(tf.float32)

        # adaptive learning rate per octave
        if abs(self.lr_scale - 1) > 1e-7:
            self.lr = [
                self.lr / self.lr_scale**i for i in range(self.octave_n)
            ]

        # settings for octave process: build the resolution pyramid,
        # smallest first
        oct_size = []
        dhw = np.array(self.resolution)
        for _ in range(self.octave_n):
            oct_size.append(dhw)
            # BUG FIX: `np.int` was removed in NumPy >= 1.24; the builtin
            # `int` is the documented replacement and is equivalent here.
            dhw = (dhw // self.octave_scale).astype(int)
        oct_size.reverse()
        print('input size for each octave', oct_size)

        p = params['p']

        # per-frame optimization variables, initialized to zero
        g_opt = []
        if 'p' in self.target_field:
            for i in range(self.num_frames):
                n = p[i].shape[0]
                p_opt_shp = [n, 3]
                p_opt = np.zeros(shape=p_opt_shp, dtype=np.float32)
                g_opt.append(p_opt)

        if 'd' in self.target_field:
            r = params['r']
            for i in range(self.num_frames):
                n = p[i].shape[0]
                r_opt_shp = [n, self.num_kernels]
                r_opt_ = np.zeros(shape=r_opt_shp, dtype=np.float32)
                g_opt.append(r_opt_)

        # optimize
        loss_history = []
        d_intm = []
        opt_ = {}  # lazily-created Adam train ops, keyed by frame window
        for octave in trange(self.octave_n, desc='octave'):
            loss_history_o = []
            d_intm_o = []

            feed = {}
            feed[self.res] = oct_size[octave]
            if self.content_img is not None:
                feed[self.content_feature] = self._content_feature(
                    self.content_img, oct_size[octave][1:])

            if self.style_img is not None:
                style_features = self._style_feature(self.style_img,
                                                     oct_size[octave][1:])

                for i in range(len(self.style_features)):
                    feed[self.style_features[i]] = style_features[i]

                if self.w_hist > 0:
                    hist_features = self._hist_feature(self.style_img,
                                                       oct_size[octave][1:])

                    for i in range(len(self.hist_features)):
                        feed[self.hist_features[i]] = hist_features[i]

            # pick this octave's learning rate
            if type(self.lr) == list:
                lr = self.lr[octave]
            else:
                lr = self.lr

            # optimizer list for each batch
            for step in trange(self.iter, desc='iter'):
                g_tmp = [None] * self.num_frames

                for t in range(0, self.num_frames,
                               self.batch_size * self.interp):
                    for i in range(self.batch_size):
                        feed[self.p[i]] = p[t + i * self.interp]
                        feed[self.opt_ph[i]] = g_opt[t + i * self.interp]
                        if 'd' in self.target_field:
                            feed[self.r[i]] = r[t + i * self.interp]

                    # assign g_opt to self.opt through self.opt_ph
                    self.sess.run(self.opt_init, feed)

                    feed[self.opt_lr] = lr
                    opt_id = t // self.frames_per_opt
                    # opt_id = self.rng.randint(num_opt)
                    if opt_id in opt_:
                        train_op = opt_[opt_id]
                    else:
                        # lazily create one optimizer per frame window and
                        # initialize its slot variables
                        opt = tf.compat.v1.train.AdamOptimizer(
                            learning_rate=self.opt_lr)
                        train_op = opt.minimize(self.total_loss,
                                                var_list=self.opt)
                        self.sess.run(
                            tf.compat.v1.variables_initializer(
                                opt.variables()), feed)
                        opt_[opt_id] = train_op

                    # optimize
                    if self.rotate:
                        # average the optimized values over all viewpoints
                        g_opt_ = None
                        l_ = []
                        for i in range(0, self.n_views, self.v_batch):
                            feed[self.rot_mat] = self.rot_mat_[i:i +
                                                               self.v_batch]
                            _, l_vp = self.sess.run(
                                [train_op, self.total_loss], feed)
                            l_.append(l_vp)

                            g_opt_i = self.sess.run(self.opt, feed)

                            if i == 0:
                                g_opt_ = np.nan_to_num(g_opt_i)
                            else:
                                for j in range(self.batch_size):
                                    g_opt_[j] += np.nan_to_num(g_opt_i[j])

                        loss_history_o.append(np.mean(l_))

                        # resample viewpoints unless sampling is uniform
                        if not 'uniform' in self.sample_type:
                            self.rot_mat_, self.views = rot_mat(
                                self.phi0,
                                self.phi1,
                                self.phi_unit,
                                self.theta0,
                                self.theta1,
                                self.theta_unit,
                                sample_type=self.sample_type,
                                rng=self.rng,
                                nv=self.n_views)

                        for i in range(self.batch_size):
                            g_opt_[i] /= (self.n_views / self.v_batch)
                    else:
                        _, l_ = self.sess.run([train_op, self.total_loss],
                                              feed)
                        loss_history_o.append(l_)

                        g_opt_ = self.sess.run(self.opt, feed)

                    # accumulate the per-step delta for each keyframe
                    for i in range(self.batch_size):
                        g_tmp[t + i * self.interp] = np.nan_to_num(
                            g_opt_[i]) - g_opt[t + i * self.interp]
                        if 'd' in self.target_field:
                            # masking by original density
                            g_tmp[t + i * self.interp] *= \
                                r[t + i * self.interp][..., 0, None]

                    # keep an intermediate rendering at the end of every
                    # octave except the last
                    if step == self.iter - 1 and octave < self.octave_n - 1:
                        if self.rotate:
                            feed[self.rot_mat] = \
                                [np.identity(3)] * self.batch_size

                        d_intm_ = self.sess.run(self.d_img, feed)
                        d_intm_o.append(d_intm_.astype(np.uint8))

                #########
                # gradient alignment: smooth deltas across keyframes
                if self.window_sigma > 0 and self.num_frames > 1:
                    g_tmp[:self.num_frames:self.interp] = denoise(
                        g_tmp[:self.num_frames:self.interp],
                        sigma=(self.window_sigma, 0, 0))

                for t in range(0, self.num_frames, self.interp):
                    g_opt[t] += g_tmp[t]

            loss_history.append(loss_history_o)
            if octave < self.octave_n - 1:
                d_intm.append(np.concatenate(d_intm_o, axis=0))

        # linearly interpolate parameters for in-between frames
        if self.interp > 1:
            w = np.linspace(0, 1, self.interp + 1)
            for t in range(0, self.num_frames - 1, self.interp):
                for i in range(1, self.interp):
                    print(t + i, w[i])
                    g_opt[t + i] = g_opt[t] * (
                        1 - w[i]) + g_opt[t + self.interp] * w[i]

        # gather outputs
        result = {'l': loss_history, 'd_intm': d_intm, 'v': None, 'c': None}

        # final inference with the optimized parameters
        p_sty = [None] * self.num_frames
        v_sty = [None] * self.num_frames
        r_sty = [None] * self.num_frames
        d_sty = [None] * self.num_frames
        for t in range(0, self.num_frames, self.batch_size):
            for i in range(self.batch_size):
                feed[self.p[i]] = p[t + i]
                feed[self.opt_ph[i]] = g_opt[t + i]
                if 'd' in self.target_field:
                    feed[self.r[i]] = r[t + i]

            if self.rotate:
                feed[self.rot_mat] = [np.identity(3)] * self.batch_size

            self.sess.run(self.opt_init, feed)
            p_, d_, d_img = self.sess.run([self.p_out, self.d_out, self.d_img],
                                          feed)

            if 'p' in self.target_field:
                v_ = self.sess.run(self.v, feed)

            for i in range(self.batch_size):
                p_sty[t + i] = p_[i]
                if 'p' in self.target_field:
                    v_sty[t + i] = v_[i]

            d_sty[t:t + self.batch_size] = d_
            r_sty[t:t + self.batch_size] = d_img.astype(np.uint8)

        result['p'] = p_sty
        if 'p' in self.target_field:
            result['v'] = v_sty
        result['d'] = np.array(d_sty)
        result['r'] = np.array(r_sty)

        return result
    def run(self, params):
        """Optimize the stylization field ``d_opt`` over octaves.

        Gradient-descends the flattened field parameters with a
        Laplacian-normalized gradient, coarse-to-fine over
        ``self.octave_n`` resolutions, optionally averaging gradients
        over multiple rendered viewpoints and transporting gradients
        across frames within a temporal window.

        Args:
            params: dict with 'd' (per-frame density volumes) and
                optionally 'mask', 'v' (per-frame velocities for gradient
                transport), 'lr' (per-frame learning rates),
                'content_target', and 'style_target'.

        Returns:
            dict with 'l' (loss history), 'd_iter' (snapshots of the
            stylized result during the final octave), and 'd' (the final
            stylized density per frame).
        """
        # loss
        self._loss(params)

        # gradient of the (negated) loss w.r.t. the optimized field
        g = tf.gradients(-self.total_loss, self.d_opt)[0]

        # laplacian gradient normalizer
        grad_norm = tffunc(np.float32)(partial(lap_normalize,
                                               scale_n=self.lap_n,
                                               c=self.c,
                                               is_3d=True))

        d = params['d']
        if 'mask' in params:
            mask = params['mask']
            # broadcast mask over the optimized channels
            mask = np.stack([mask] * self.c, axis=-1)

        if 'v' in params:
            v = params['v']

        # settings for octave process: resolution pyramid, full size first
        oct_size = []
        hw = np.int32(d.shape)[1:]
        for _ in range(self.octave_n):
            oct_size.append(hw.copy())
            hw = np.int32(np.float32(hw) / self.octave_scale)
        print('input size for each octave', oct_size)

        d_shp = [self.num_frames] + [s for s in oct_size[-1]] + [self.c]
        d_opt_ = np.zeros(shape=d_shp, dtype=np.float32)

        # optimize
        loss_history = []
        for octave in trange(self.octave_n):
            # octave process: scale-down inputs for all but the last octave
            if octave < self.octave_n - 1:
                d_ = []
                for i in range(self.num_frames):
                    d_.append(resize(d[i], oct_size[-octave - 1]))
                d_ = np.array(d_)

                if 'mask' in params:
                    mask_ = []
                    for i in range(self.num_frames):
                        m = resize(mask[i], oct_size[-octave - 1])
                        mask_.append(m)

                if 'v' in params:
                    v_ = []
                    for i in range(self.num_frames - 1):
                        v_.append(resize(v[i], oct_size[-octave - 1]))
                    v_ = np.array(v_)
            else:
                d_ = d
                if 'mask' in params: mask_ = mask
                if 'v' in params: v_ = v

            if octave > 0:
                # upsample the running solution to this octave's resolution
                d_opt__ = []
                for i in range(self.num_frames):
                    d_opt__.append(resize(d_opt_[i], oct_size[-octave - 1]))
                del d_opt_
                d_opt_ = np.array(d_opt__)

            feed = {}

            if 'content_target' in params:
                feed[self.content_feature] = self._content_feature(
                    params['content_target'], oct_size[-octave - 1][1:])

            if 'style_target' in params:
                style_features, style_denoms = self._style_feature(
                    params['style_target'], oct_size[-octave - 1][1:])

                for i in range(len(self.style_features)):
                    feed[self.style_features[i]] = style_features[i]
                    feed[self.style_denoms[i]] = style_denoms[i]

            # collect intermediate snapshots only in the final octave
            if (octave == self.octave_n - 1):
                d_opt_iter = []
            for step in trange(self.iter):
                g__ = []
                for t in trange(self.num_frames):
                    feed[self.d] = d_[t]
                    feed[self.d_opt] = d_opt_[t, None]

                    if self.rotate:
                        # accumulate gradient and loss over view batches
                        g_ = None
                        l_ = 0
                        for i in range(0, self.n_views, self.v_batch):
                            feed[self.rot_mat] = self.rot_mat_[i:i +
                                                               self.v_batch]
                            g_vp, l_vp = self.sess.run([g, self.total_loss],
                                                       feed)
                            if g_ is None:
                                g_ = g_vp
                            else:
                                g_ += g_vp
                            l_ += l_vp
                        l_ /= np.ceil(self.n_views / self.v_batch)
                        # BUG FIX: the averaged loss was computed here but
                        # never recorded, so the rotate path produced an
                        # empty loss history (the non-rotate path records
                        # its loss) -- record it for consistency.
                        loss_history.append(l_)

                        # resample viewpoints unless sampling is uniform
                        if not 'uniform' in self.sample_type:
                            self.rot_mat_, self.views = rot_mat(
                                self.phi0,
                                self.phi1,
                                self.phi_unit,
                                self.theta0,
                                self.theta1,
                                self.theta_unit,
                                sample_type=self.sample_type,
                                rng=self.rng,
                                nv=self.n_views)
                    else:
                        g_, l_ = self.sess.run([g, self.total_loss], feed)
                        loss_history.append(l_)

                    g_ = denoise(g_, sigma=self.g_sigma)

                    # normalize the gradient and scale by the learning rate
                    if 'lr' in params:
                        lr = params['lr'][min(t, len(params['lr']) - 1)]
                        g_[0] = grad_norm(g_[0]) * lr
                    else:
                        g_[0] = grad_norm(g_[0]) * self.lr
                    if 'mask' in params: g_[0] *= mask_[t]

                    g__.append(g_)

                if self.window_size > 1:
                    # temporal window: blend each frame's gradient with
                    # gradients transported from its neighbors
                    n = (self.window_size - 1) // 2
                    for t in range(self.num_frames):
                        t0 = np.maximum(t - n, 0)
                        t1 = np.minimum(t + n, self.num_frames - 1)
                        w = [1 / (t1 - t0 + 1)] * self.num_frames

                        g_ = g__[t].copy() * w[t]
                        for s in range(t0, t1 + 1):
                            if s == t: continue
                            g_ += self._transport(g__[s].copy(), v_, s,
                                                  t) * w[s]  # move s to t

                        d_opt_[t] += g_[0]
                        g__[t] = g_
                else:
                    for t in range(self.num_frames):
                        d_opt_[t] += g__[t][0]

                # to avoid resizing numerical error
                if 'mask' in params:
                    for t in range(self.num_frames):
                        d_opt_[t] *= np.ceil(mask_[t])

                # snapshot the solution every iter_seg steps (final octave
                # only, skipping the first and last step)
                if self.iter_seg > 0 and octave == self.octave_n - 1:
                    if (((step / float(self.iter_seg)) -
                         int(step / self.iter_seg)) < 0.00001) and (
                             step != self.iter - 1) and (step != 0):
                        d_opt_iter.append(np.array(d_opt_, copy=True))
        # gather outputs
        result = {'l': loss_history}

        # render each intermediate snapshot.
        # BUG FIX: the original wrapped self.d_out in tf.identity() inside
        # this loop, adding a new node to the graph on every iteration;
        # fetching self.d_out directly is equivalent and leak-free.
        d_opt_iter = np.array(d_opt_iter)
        d_iter = []
        for i in range(d_opt_iter.shape[0]):
            d__ = []
            for t in range(self.num_frames):
                feed[self.d_opt] = d_opt_iter[i, t, None]
                feed[self.d] = d[t]
                d__.append(self.sess.run(self.d_out, feed)[0, ..., 0])
            d__ = np.array(d__)
            d_iter.append(d__)
        d_iter = np.array(d_iter)
        result['d_iter'] = d_iter

        # final stylized density per frame
        d_ = []
        for t in range(self.num_frames):
            feed[self.d_opt] = d_opt_[t, None]
            feed[self.d] = d[t]
            d_.append(self.sess.run(self.d_out, feed)[0, ..., 0])
        d_ = np.array(d_)
        result['d'] = d_
        return result
# NOTE(review): removed scraped-site artifact ("Exemplo n.º 4" / "0") that broke Python syntax.
    def __init__(self, self_dict):
        """Build the particle-based stylization graph on top of StylerBase.

        For each batch slot, creates placeholders for particle positions
        (and, when density is a target field, per-particle densities),
        trainable offset variables for the target fields, a grid density
        estimate from the particles, optional smoothing, optional
        multi-view rotation, and the final volume rendering that is
        plugged into the loss network.

        Args:
            self_dict: dict of hyper-parameters forwarded to
                StylerBase.__init__ and copied onto the instance.
        """
        StylerBase.__init__(self, self_dict)

        # particle position
        # shape: [N,3], scale: [0,1]
        p = []
        p_shp = [None, 3]
        self.p = []  # input
        self.v = []  # style

        # particle density, [N,nk]
        r_shp = [None, self.num_kernels]
        self.r = []  # input
        self.d = []  # style

        # output
        d = []
        d_gray = []

        pressure = []

        self.opt_init = []
        self.opt_ph = []
        self.opt = []

        # target grid resolution [D,H,W], fed at run time
        self.res = tf.compat.v1.placeholder(tf.int32, [3], name='resolution')

        for i in range(self.batch_size):
            # particle position, [N,3]
            p_ = tf.compat.v1.placeholder(dtype=tf.float32,
                                          shape=p_shp,
                                          name='p%d' % i)
            self.p.append(p_)
            p_ = tf.expand_dims(p_, axis=0)  # [1,N,3]

            # particle velocity, [N,3]
            if 'p' in self.target_field:
                p_opt_ph = tf.compat.v1.placeholder(dtype=tf.float32,
                                                    shape=p_shp,
                                                    name='p_opt_ph%d' % i)
                self.opt_ph.append(p_opt_ph)
                # variable initialized from the placeholder each run;
                # validate_shape=False since N varies per frame
                p_opt = tf.Variable(p_opt_ph,
                                    validate_shape=False,
                                    name='p_opt%d' % i)
                self.opt.append(p_opt)
                p_opt_ = tf.reshape(p_opt, tf.shape(p_opt_ph))
                p_opt_ = tf.expand_dims(p_opt_, axis=0)
                v_ = p_opt_
                self.v.append(v_[0])
                # displace particles by the optimized offset
                p_ += v_

            p.append(p_[0])

            # particle density, [N,nk]
            if 'd' in self.target_field:
                r_ = tf.compat.v1.placeholder(dtype=tf.float32,
                                              shape=r_shp,
                                              name='r%d' % i)
                self.r.append(r_)
                r_ = tf.expand_dims(r_, axis=0)  # [1,N,nk]

                r_opt_ph = tf.compat.v1.placeholder(dtype=tf.float32,
                                                    shape=r_shp,
                                                    name='r_opt_ph')
                self.opt_ph.append(r_opt_ph)
                r_opt = tf.Variable(r_opt_ph,
                                    validate_shape=False,
                                    name='r_opt')
                self.opt.append(r_opt)
                r_opt_ = tf.reshape(r_opt, tf.shape(r_opt_ph))
                r_opt_ = tf.expand_dims(r_opt_, axis=0)  # [1,N,nk]
                # keep the density residual bounded
                r_opt_ = tf.clip_by_value(r_opt_, -1, 1)  #### necessary!
                self.d.append(r_opt_[0])
                r_ += r_opt_

                # weighted avg. density estimation: sum grid splats over
                # kernels with geometrically-scaled support
                for k in range(self.num_kernels):
                    factor = self.kernel_scale**k
                    support = self.support / factor
                    r_k = tf.expand_dims(r_[..., k], axis=-1)
                    d_hat = p2g_wavg(p_,
                                     r_k,
                                     self.domain,
                                     self.res,
                                     self.radius,
                                     self.nsize,
                                     kernel='cubic',
                                     support=support,
                                     clip=self.clip,
                                     is_2d=False)
                    if k == 0:
                        d_ = d_hat
                    else:
                        d_ += d_hat
            else:
                # position-based (SPH) density field estimation
                d_ = p2g(p_,
                         self.domain,
                         self.res,
                         self.radius,
                         self.rest_density,
                         self.nsize,
                         support=self.support,
                         clip=self.clip,
                         is_2d=False)  # [B,N,3] -> [B,D,H,W,1]
                d_ /= self.rest_density  # normalize density

            d.append(d_)

            # pressure estimation: positive where density exceeds rest
            if self.w_pressure > 0 and 'p' in self.target_field:
                pressure_ = tf.where(d_ > 0, d_ - 1, tf.zeros_like(d_))
                pressure.append(pressure_)

        # re-initializes all optimized variables from their placeholders
        self.opt_init = tf.compat.v1.initializers.variables(self.opt)

        # stylized (advected) particles
        self.p_out = p  # [N,3]*B

        # estimated density fields
        d = tf.concat(d, axis=0)  # [B,D,H,W,1]

        if self.w_pressure > 0 and 'p' in self.target_field:
            pressure = tf.concat(pressure, axis=0)  # [B,D,H,W,1]
            self.pressure = pressure

        if self.k > 0:
            # smoothing density for density optimization: build a
            # normalized 3x3x3 kernel from the 1D stencil [1, k, 1]
            k = []
            k1 = np.float32([1, self.k, 1])
            k2 = np.outer(k1, k1)
            for i in k1:
                k.append(k2 * i)
            k = np.array(k)
            k = k[:, :, :, None, None] / k.sum()
            d = tf.nn.conv3d(d, k, [1, 1, 1, 1, 1], 'SAME')

        # value clipping for rendering
        # d = tf.clip_by_value(d, 0, 1)
        d = tf.maximum(d, 0)

        # stylized result
        self.d_out = d  # [B,D,H,W,1]

        ####
        # rotate 3d smoke for rendering
        if self.rotate:
            d, self.rot_mat = rotate(d)  # [B,D,H,W,1] or [B,D,H,W,4]
            self.d_out_rot = d

            # compute rotation matrices for the sampled viewpoints
            self.rot_mat_, self.views = rot_mat(self.phi0,
                                                self.phi1,
                                                self.phi_unit,
                                                self.theta0,
                                                self.theta1,
                                                self.theta_unit,
                                                sample_type=self.sample_type,
                                                rng=self.rng,
                                                nv=self.n_views)

            if self.n_views is None:
                self.n_views = len(self.views)
            print('# vps:', self.n_views)
            # views must divide evenly into view-batches
            assert (self.n_views % self.v_batch == 0)

        # render 3d volume
        if self.render_liquid:
            # opacity-style rendering: one minus final transmittance
            # d = tf.reduce_max(d, axis=1) # [B,H,W,1]
            transmit = tf.exp(-tf.cumsum(d[:, ::-1], axis=1) * self.transmit)
            self.d_trans = transmit
            d = 1 - transmit[:, -1]  # [B,H,W,1], [0,1]
            # d = (1 - transmit[:,-1])*np.array([0.26, 0.5, 0.75]) + transmit[:,-1]*np.array([1, 1, 1]) # [B,H,W,1], [0,1]
        else:
            # smoke rendering: composite density weighted by transmittance
            transmit = tf.exp(-tf.cumsum(d[:, ::-1], axis=1) *
                              self.transmit)[:, ::-1]
            d *= transmit
            d = tf.reduce_sum(d, axis=1)  # [B,H,W,1] or [B,H,W,3]
            d /= tf.reduce_max(d)  # [B,H,W,1], [0,1]

        # mask for style features
        self.d_gray = d  # [B,H,W,1]
        ####

        self._plugin_to_loss_net(d)