Example #1
def pipeline(prev_frame: np.ndarray, frame: np.ndarray):
    prev_frame = process_frame(prev_frame)
    frame = process_frame(frame)

    # get rigid flow
    rflow = get_rigid_flow(prev_frame, frame)

    # warped image
    wframe = warp(prev_frame, rflow)

    # compute error map
    ssim_loss = ssim(wframe, frame).mean(1, True)
    l1_loss = torch.abs(wframe - frame).mean(1, True)
    err_map = 0.85 * ssim_loss + 0.15 * l1_loss

    # get dynamic flow correction
    input = torch.cat([prev_frame, frame, wframe], dim=1)
    with torch.no_grad():
        enc_output = encoder(input, rflow, err_map)
        dec_output = decoder(enc_output)
        dflow = dec_output[('flow', 0)]

    flow = dflow + rflow
    flow = flow.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    color_flow = flow_to_color(flow)
    return color_flow[..., ::-1]
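The helpers used above (process_frame, get_rigid_flow, encoder, decoder, flow_to_color) live outside this snippet; flow_to_color presumably behaves like flow_vis.flow_to_color, mapping an (H, W, 2) flow field to RGB via the Middlebury color wheel, and the final [..., ::-1] flips RGB to BGR for OpenCV. A minimal, hypothetical driver sketch (file names are placeholders, and process_frame is assumed to handle the array-to-tensor conversion):

import cv2

prev_frame = cv2.imread("frame_000.png")  # BGR uint8 frames; paths are placeholders
frame = cv2.imread("frame_001.png")
vis = pipeline(prev_frame, frame)         # BGR flow visualization
cv2.imwrite("flow_vis.png", vis)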
Example #2
    def warp_patch(self, ng_path, z, bbox, res_mip_range, mip):
        influence_bbox = deepcopy(bbox)
        influence_bbox.uncrop(self.max_displacement, mip=0)

        agg_flow = influence_bbox.identity(mip=mip)
        agg_flow = np.expand_dims(agg_flow, axis=0)
        agg_res = data_handler.get_aggregate_rel_flow(
            z, influence_bbox, res_mip_range, mip, self.process_low_mip,
            self.process_high_mip, self.x_res_ng_paths, self.y_res_ng_paths)
        agg_flow += agg_res

        raw_data = data_handler.get_image_data(ng_path, z, influence_bbox, mip)
        # no need to warp if the flow is identity; warping introduces noise
        if not influence_bbox.is_identity_flow(agg_flow, mip=mip):
            warped = warp(raw_data, agg_flow)
        else:
            warped = raw_data[0]

        mip_disp = int(self.max_displacement / 2**mip)
        cropped = crop(warped, mip_disp)
        # preprocess_data divides by 256 and reshapes to the right dimensions;
        # the data range is already good, so multiply by 256 to compensate
        result = data_handler.preprocess_data(cropped * 256)
        data_handler.save_image_patch(self.dst_ng_path, result, z, bbox, mip)
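crop() is another external helper here; from the call site it trims mip_disp pixels from each border after warping. A minimal sketch under that assumption (not the project's actual implementation):

def crop(img, d):
    # assumption: symmetric border trim of d pixels on the trailing two axes
    return img[..., d:-d, d:-d] if d > 0 else img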
Example #3
    def feature_dtw(self, query_index, inst_index):
        Q = self.feature_query(query_index, inst_index)
        # one similarity score per wav file, initialized to infinity
        similarity = [float('inf')] * len(self.query_mlf.wav_list)
        for w in range(len(self.query_mlf.wav_list)):
            f_file = self.feature_fold + self.query_mlf.wav_list[w] + '.mfc'
            D = util.read_feature(f_file)
            similarity[w] = util.warp(util.cos_dist(D, Q))
            # similarity[w] = util.warp(-np.dot(D, Q.T))
            print(query_index, w, similarity[w])
        return similarity
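util.cos_dist and util.warp are not shown. Plausibly, cos_dist returns a frame-by-frame cosine-distance matrix between document features D and query features Q, which util.warp then reduces to an alignment cost via dynamic time warping; both points are assumptions. A minimal cos_dist sketch under that reading:

import numpy as np

def cos_dist(D, Q):
    # rows of D (n_d, dim) and Q (n_q, dim) are feature frames;
    # returns an (n_d, n_q) matrix of cosine distances
    Dn = D / (np.linalg.norm(D, axis=1, keepdims=True) + 1e-12)
    Qn = Q / (np.linalg.norm(Q, axis=1, keepdims=True) + 1e-12)
    return 1.0 - Dn @ Qn.T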
Example #4
def optical_flow_HS(Im1, Im2, nlevel):
    ni, nj = Im1.shape
    u = np.zeros((ni, nj))
    v = np.zeros((ni, nj))
    for lev in range(nlevel, -1, -1):
        Im1warp = util.warp(Im1, -u, -v)
        Im1c = util.coarsen(Im1warp, lev)
        Im2c = util.coarsen(Im2, lev)

        niter = 100  # fixed-point iterations per pyramid level
        w1 = 100     # first-order (Horn-Schunck) smoothness weight
        w2 = 0       # second-order smoothness weight (disabled here)
        Ix = 0.5 * (util.deriv_x(Im1c) + util.deriv_x(Im2c))
        Iy = 0.5 * (util.deriv_y(Im1c) + util.deriv_y(Im2c))
        It = Im2c - Im1c
        du = np.zeros(Ix.shape)
        dv = np.zeros(Ix.shape)
        for k in range(niter):
            ubar2 = util.laplacian(du) + du
            vbar2 = util.laplacian(dv) + dv
            ubar1 = util.deriv_xx(du) + du
            vbar1 = util.deriv_yy(dv) + dv
            uxy = util.deriv_xy(du)
            vxy = util.deriv_xy(dv)
            du = (w1 * ubar2 + w2 * (ubar1 + vxy)) / (w1 + w2) - Ix * (
                (w1 * (Ix * ubar2 + Iy * vbar2) + w2 * ((ubar1 + vxy) * Ix +
                                                        (vbar1 + uxy) * Iy)) /
                (w1 + w2) + It) / (w1 + w2 + Ix**2 + Iy**2)
            dv = (w1 * vbar2 + w2 * (vbar1 + uxy)) / (w1 + w2) - Iy * (
                (w1 * (Ix * ubar2 + Iy * vbar2) + w2 * ((ubar1 + vxy) * Ix +
                                                        (vbar1 + uxy) * Iy)) /
                (w1 + w2) + It) / (w1 + w2 + Ix**2 + Iy**2)

        u += util.sharpen(du * 2**lev, lev)
        v += util.sharpen(dv * 2**lev, lev)
    return u, v
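With w2 = 0 as set above, the inner loop is the classic Horn-Schunck fixed-point update. Writing u_bar = laplacian(du) + du for the locally averaged flow, each sweep computes

    du <- u_bar - Ix * (Ix*u_bar + Iy*v_bar + It) / (w1 + Ix^2 + Iy^2)

and symmetrically for dv, so w1 plays the role of the smoothness weight alpha^2 in Horn and Schunck's formulation; a nonzero w2 would blend in the second-order (deriv_xx/deriv_yy plus cross-derivative) smoothness terms.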
Example #5
def error(error_type, msg):
    return warp("error", {
        'type': error_type,
        'message': msg
    })
Example #6
def ok(msg_type, data):
    return warp("ok", {
        'type': msg_type,
        'code': check_code(data)
    })
Example #7
def welcome(msg):
    return warp("welcome", msg)
Example #8
##run filter
u = np.zeros((nens, ns, nx, ny))
v = np.zeros((nens, ns, nx, ny))
for s in range(ns):
  xs1 = xs.copy()
  # print('running EnSRF for scale {}'.format(s))
  xs = da.EnSRF(xs, obs_loc, obs, obs_err, localize_cutoff[s], s)
  if run_displacement == 1:
    if s < ns-1:
      for m in range(nens):
        # print('aligning member {:04d}'.format(m+1))
        u[m, s, :, :], v[m, s, :, :] = da.optical_flow_HS(xs1[m, s, :, :, 0], xs[m, s, :, :, 0], nlevel=5)
        for z in range(nz):
          for i in range(s+1, ns):
            xs[m, i, :, :, z] = util.warp(xs[m, i, :, :, z], -u[m, s, :, :], -v[m, s, :, :])
xas = xs.copy()

##sum scales back to full state
xa = np.sum(xas, axis=1)

###adaptive inflation
hxb = np.zeros((nens, nobs))
hxa = np.zeros((nens, nobs))
for m in range(nens):
  for n in range(nobs):
    hxb[m, n] = util.interp3d(xb[m, :, :, :], obs_loc[n, :])
    hxa[m, n] = util.interp3d(xa[m, :, :, :], obs_loc[n, :])
hxbm = np.mean(hxb, axis=0)
hxam = np.mean(hxa, axis=0)
amb = hxam - hxbm
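util.interp3d is external; it samples a 3-D field at a fractional observation location to map model state into observation space. A nearest-neighbour stand-in matching the call signature, for illustration only (the real helper likely interpolates trilinearly):

import numpy as np

def interp3d(field, loc):
    # assumption: loc is a length-3 fractional index into field's axes
    idx = np.clip(np.round(loc).astype(int), 0, np.array(field.shape) - 1)
    return field[tuple(idx)]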
Example #9
    start_epoch, rloss = 0, 0.0  # defaults when training from scratch
    if args.load_checkpoint:
        start_epoch, rloss = load_checkpoint()

    for epoch in range(start_epoch, args.num_epochs):
        for i, data in enumerate(train_dataloader):
            # zero grad
            optimizer.zero_grad()

            # extract data
            imgs1 = [data[('color_aug', -1, i)].to(device) for i in range(4)]
            imgs2 = [data[('color_aug', 0, i)].to(device) for i in range(4)]
            rflow = get_rigid_flow(imgs1[0], imgs2[0])

            # compute warped image using rigid flow
            wimg2_r = warp(imgs1[0], rflow)
            input = torch.cat((imgs1[0], imgs2[0], wimg2_r), dim=1)

            # compute reprojection loss
            # for the warped image with rigid flow
            ssim_loss = ssim(wimg2_r, imgs2[0]).mean(1, True)
            l1_loss = torch.abs(wimg2_r - imgs2[0]).mean(1, True)
            reprojection_loss = 0.85 * ssim_loss + 0.15 * l1_loss

            # compute dynamic flow
            enc_output = encoder(input, rflow, reprojection_loss)
            dec_output = decoder(enc_output)

            loss = 0
            for j in args.scales:
                img1 = imgs1[j]
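The ssim() used in the reprojection loss here (and in Examples #1 and #22) is external. A common self-supervised-depth convention returns a per-pixel dissimilarity map via 3x3 average pooling; the sketch below follows that convention, which is an assumption rather than this codebase's definition. The 0.85/0.15 mix of SSIM and L1 terms is the photometric-error weighting popularized by self-supervised monocular depth methods.

import torch
import torch.nn.functional as F

def ssim(x, y, C1=0.01 ** 2, C2=0.03 ** 2):
    # per-pixel DSSIM in [0, 1]; larger means less similar
    mu_x, mu_y = F.avg_pool2d(x, 3, 1, 1), F.avg_pool2d(y, 3, 1, 1)
    sigma_x = F.avg_pool2d(x * x, 3, 1, 1) - mu_x ** 2
    sigma_y = F.avg_pool2d(y * y, 3, 1, 1) - mu_y ** 2
    sigma_xy = F.avg_pool2d(x * y, 3, 1, 1) - mu_x * mu_y
    num = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)
    den = (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2)
    return torch.clamp((1 - num / den) / 2, 0, 1)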
Example #10
def receive_group_message(data):
    return warp('receive_group_message', data)
Example #11
def group_message(data):
    return warp("group_message", data)
Example #12
def enter_group(data):
    return warp("enter_group", data)
Example #13
def logout(data):
    return warp("logout", data)
Example #14
def login(data):
    return warp("login", data)
Example #15
def send_token(data):
    return warp("send_token", data)
Example #16
def register(data):
    return warp("register", data)
Example #17
def receive_private_message(data):
    return warp('receive_private_message', data)
Example #18
def connect():
    return warp("connect", "")
Example #19
def exit_group(data):
    return warp("exit_group", data)
Example #20
    xs = da.EnSRF(xs, obs_loc, obs, obs_err, localize_cutoff[s], s)  #run DA step of iteration
    if run_displacement == 1:  #run displacement step of iteration
        if s < ns - 1:
            print('aligning members')
            for m in range(nens):
                # print('aligning member {:04d}'.format(m+1))
                u[m, s, :, :], v[m, s, :, :] = da.optical_flow_HS(
                    xs1[m, s, :, :, 0], xs[m, s, :, :, 0], nlevel=5)
                for z in range(nz):
                    for i in range(s + 1, ns):
                        xs[m, i, :, :, z] = util.warp(
                            xs[m, i, :, :, z], -u[m, s, :, :], -v[m, s, :, :])
                xs2[:, s + 1, :, :, :] = xs[:, s + 1, :, :, :]  #save a copy of the aligned prior
print('filtering complete')
xas = xs.copy()  #final analysis

# util.output_ens('2.nc', xa[:, :, :, :])


def set_axis(ax, title):
    ax.set_aspect('equal', 'box')
    ax.set_xlim(0, nx)
    ax.set_ylim(0, ny)
    ax.set_xticks(np.arange(0, nx + 1, 50))
Example #21
def send_private(data):
    return warp("send_private", data)
Example #22
def test_sample():
    global test_iter

    encoder.eval()
    decoder.eval()

    try:
        test_batch = next(test_iter)
    except StopIteration:
        test_iter = iter(test_dataloader)
        test_batch = next(test_iter)

    imgs1 = [test_batch[('color_aug', -1, i)].to(device) for i in range(4)]
    imgs2 = [test_batch[('color_aug', 0, i)].to(device) for i in range(4)]
    rflow = get_rigid_flow(imgs1[0], imgs2[0])

    # compute warped image using rigid flow
    wimg2_r = warp(imgs1[0], rflow)
    input = torch.cat((imgs1[0], imgs2[0], wimg2_r), dim=1)

    # compute reprojection loss
    # for the warped image with rigid flow
    ssim_loss = ssim(wimg2_r, imgs2[0]).mean(1, True)
    l1_loss = torch.abs(wimg2_r - imgs2[0]).mean(1, True)
    reprojection_loss = 0.85 * ssim_loss + 0.15 * l1_loss

    with torch.no_grad():
        enc_output = encoder(input, rflow, reprojection_loss)
        dec_output = decoder(enc_output)

    # compute warped image using rigid and dynamic flow
    dflow = dec_output[('flow', 0)]
    flow = dflow + rflow
    wimg2_dr = warp(imgs1[0], flow)

    # compute reprojection loss
    # for the warped image using the rigid and dynamic flow
    ssim_loss = ssim(wimg2_dr, imgs2[0]).mean(1, True)
    l1_loss = torch.abs(wimg2_dr - imgs2[0]).mean(1, True)
    reprojection_loss = 0.85 * ssim_loss + 0.15 * l1_loss

    # compute mask
    with torch.no_grad():
        input = torch.cat([imgs1[0], imgs2[0], wimg2_dr], dim=1)
        enc_mask_output = encoder_mask(input, flow, reprojection_loss)
        dec_mask_output = decoder_mask(enc_mask_output)
        mask = torch.sigmoid(dec_mask_output[('flow', 0)])
        mask = mask.repeat(1, 3, 1, 1)

    # color flow
    flow = flow.cpu()
    colors = []
    for j in range(args.batch_size):
        color_flow = flow[j].numpy().transpose(1, 2, 0)
        color_flow = flow_to_color(color_flow).transpose(2, 0, 1)
        color_flow = torch.tensor(color_flow).unsqueeze(0).float() / 255
        colors.append(color_flow)
    colors = torch.cat(colors, dim=0)

    img1 = imgs1[0][:args.num_vis].cpu()
    img2 = imgs2[0][:args.num_vis].cpu()
    wimg2_dr = wimg2_dr[:args.num_vis].cpu()
    mask = mask[:args.num_vis].cpu()
    colors = colors[:args.num_vis]

    imgs = torch.cat([
        img1, img2, mask * wimg2_dr,
        0.5 * (wimg2_dr + img2), 0.5 * (img1 + img2),
        colors, mask
    ], dim=3)
    imgs = torchvision.utils.make_grid(imgs, nrow=1, normalize=False)
    imgs = (255 * imgs.numpy().transpose(1, 2, 0)).astype(np.uint8)
    cv2.imwrite("./snapshots/imgs/%d.%d.png" % (epoch, i), imgs[..., ::-1])

    encoder.train()
    decoder.train()
Example #23
    def forward(self, data_in, data_out):
        loss = 0.0
        image = data_in['image']
        camera = data_out['camera']
        ground = data_out['ground']
        depth_out = data_out['depth']
        normal, points = util.normal(depth_out, camera)
        data_out['normal'] = normal
        data_out['points'] = points

        if self.average_weight > 0.0:
            average_depth = depth_out.mean()
            data_out['eval_average_depth'] = average_depth
            data_out['loss_average_depth'] = (
                average_depth - self.average_depth).abs() * self.average_weight
            loss += data_out['loss_average_depth']

        if self.scale_weight > 0.0 and 'distance_previous' in data_in and 'distance_next' in data_in:
            motion = data_out['motion']
            scale = data_out['scale']
            data_out['eval_scale'] = scale.mean()

            # distance previous
            dp_in = data_in['distance_previous']
            dp_out = torch.norm(motion[:, 3:6].detach(), dim=1, keepdim=True)
            loss_scale_previous = (dp_out * scale - dp_in).abs().mean()
            data_out['loss_scale_previous'] = loss_scale_previous * self.scale_weight
            loss += data_out['loss_scale_previous']
            data_out['eval_dp_in'] = dp_in.mean()
            data_out['eval_dp_out'] = dp_out.mean()

            # distance next
            dn_in = data_in['distance_next']
            dn_out = torch.norm(motion[:, 9:12].detach(), dim=1, keepdim=True)
            loss_scale_next = (dn_out * scale - dn_in).abs().mean()
            data_out['loss_scale_next'] = loss_scale_next * self.scale_weight
            loss += data_out['loss_scale_next']
            data_out['eval_dn_in'] = dn_in.mean()
            data_out['eval_dn_out'] = dn_out.mean()

        if self.ground_weight > 0.0:
            ground_grid = util.plane_grid(ground.detach(), camera.detach(),
                                          image.shape, image.device)
            data_out['ground_grid'] = ground_grid
            ground_normal = ground[:, 0:3].reshape(-1, 3, 1, 1)
            ground_d = -ground[:, 3:4].reshape(-1, 1, 1, 1)
            normal_weight = normal.detach() - ground_normal.detach()
            normal_weight = torch.pow(normal_weight, 2.0).sum(dim=1,
                                                              keepdim=True)
            normal_weight = torch.exp(-normal_weight * 5.0)
            data_out['ground_normal_weight'] = normal_weight

            ground_for_dist = torch.cat([ground_normal.detach(), ground_d],
                                        dim=1)
            ground_dist = util.plane_dist(ground_for_dist, camera.detach(),
                                          depth_out.detach()).abs()
            ground_dist = ground_dist[:, :, 1:-1, 1:-1]
            data_out['ground_dist'] = ground_dist
            dist_weight = torch.pow(ground_dist.detach(), 2.0)
            dist_weight = torch.exp(-25.0 * dist_weight)
            data_out['ground_dist_weight'] = dist_weight

            ground_weight = normal_weight * dist_weight
            data_out['ground_weight'] = ground_weight

            ground_normal_residual = ground_weight * (ground_normal - normal.detach()).abs()
            data_out['ground_normal_residual'] = ground_normal_residual.mean(dim=1, keepdim=True)
            data_out['loss_ground_normal'] = data_out['ground_normal_residual'].mean()
            loss += data_out['loss_ground_normal']

            ground_dist_residual = ground_weight * ground_dist.abs()
            data_out['ground_dist_residual'] = ground_dist_residual
            data_out['loss_ground_dist'] = data_out['ground_dist_residual'].mean()
            loss += data_out['loss_ground_dist']

        if self.regular_weight > 0.0:
            loss_regular = 0.0
            depth_down = depth_out
            image_down = image
            for i in range(self.down_times):
                scale_factor = 1.0 if i == 0 else 0.5
                image_down = torch.nn.functional.interpolate(
                    image_down,
                    scale_factor=scale_factor,
                    mode='bilinear',
                    align_corners=True)
                depth_down = torch.nn.functional.interpolate(
                    depth_down,
                    scale_factor=scale_factor,
                    mode='bilinear',
                    align_corners=True)
                if self.regular_flag == 0:  # depth grad
                    regular_grad = torch.cat(util.sobel(depth_down), dim=1) \
                        .abs().mean(dim=1, keepdim=True)
                    image_grad = torch.cat(util.sobel(image_down), dim=1) \
                        .abs().mean(dim=1, keepdim=True)
                elif self.regular_flag == 1:  # depth grad2
                    regular_grad = torch.cat(util.sobel(depth_down), dim=1)
                    regular_grad = torch.cat(util.sobel(regular_grad), dim=1) \
                        .abs().mean(dim=1, keepdim=True)
                    image_grad = torch.cat(util.sobel(image_down, padding=-1), dim=1) \
                        .abs().mean(dim=1, keepdim=True)
                elif self.regular_flag == 2:  # normal grad
                    normal_down, _ = util.normal(depth_down, camera)
                    regular_grad = torch.cat(util.sobel(normal_down), dim=1) \
                        .abs().mean(dim=1, keepdim=True)
                    image_grad = torch.cat(util.sobel(image_down, padding=-1), dim=1) \
                        .abs().mean(dim=1, keepdim=True)
                elif self.regular_flag == 3:  # normal grad 2
                    points = util.unproject(depth_down, camera)
                    points = points[:, 0:3, ...] * points[:, 3:4, ...]
                    grad_x, grad_y = util.sobel(points, padding=0)
                    normal_down = util.cross(grad_x, grad_y)
                    normal_down = torch.nn.functional.normalize(normal_down)
                    regular_grad = torch.cat(util.sobel(normal_down), dim=1) \
                        .abs().mean(dim=1, keepdim=True)
                    image_grad = torch.cat(util.sobel(image_down, padding=-1), dim=1) \
                        .abs().mean(dim=1, keepdim=True)
                else:
                    raise Exception("Invalid regular flag")
                # suppress the regularizer where the image itself has strong edges
                image_grad_inv = torch.exp(-100.0 * image_grad * image_grad)
                regular_residual = regular_grad * regular_grad * image_grad_inv
                data_out['regular_grad_grad%d' % i] = regular_grad
                data_out['regular_image_inv_%d' % i] = image_grad_inv
                data_out['regular_residual_%d' % i] = regular_residual
                loss_regular += torch.pow(regular_residual, 2.0).mean()

            data_out['loss_regular'] = loss_regular / min(
                self.down_times, 4) * self.regular_weight
            loss += data_out['loss_regular']

        if 'depth' in data_in:
            depth_in = data_in['depth']
            depth_out = data_out['depth']
            scale = data_out['scale'].reshape(-1, 1, 1, 1)
            mask = depth_in > (1.0 / 80.0)
            mask &= depth_out > (1.0 / 80.0)
            z_in = torch.zeros_like(depth_in)
            z_out = torch.zeros_like(depth_out)
            z_in[mask] = 1.0 / depth_in[mask]
            z_out[mask] = 1.0 / depth_out[mask]

            residual_abs_rel = torch.zeros_like(depth_in)
            residual_abs_rel[mask] = (1.0 - z_out[mask] / z_in[mask]).abs()
            data_out['residual_abs_rel'] = residual_abs_rel
            data_out['eval_abs_rel'] = residual_abs_rel.sum() / mask.sum()

            residual_abs_rel_scaled = torch.zeros_like(depth_in)
            residual_abs_rel_scaled[mask] = (
                1.0 - (z_out * scale)[mask] / z_in[mask]).abs()
            data_out['residual_abs_rel_scaled'] = residual_abs_rel_scaled
            data_out['eval_abs_rel_scaled'] = residual_abs_rel_scaled.sum() / mask.sum()
            data_out['eval_scale'] = scale.mean()

            if self.depth_weight > 0.0:
                data_out['loss_depth'] = data_out['eval_abs_rel'] * self.depth_weight
                loss += data_out['loss_depth']

        if self.ref_weight > 0:
            loss_previous = 0.0
            loss_next = 0.0
            data_out['base_previous'] = (image - data_in['previous']).abs()
            data_out['base_next'] = (image - data_in['next']).abs()
            data_out['image_previous'] = data_in['previous']
            data_out['image_next'] = data_in['next']
            image_down = data_in['image']
            image_previous_down = data_in['previous']
            image_next_down = data_in['next']
            depth_down = data_out['depth']
            motion_previous = data_out['motion'][:, 0:6]
            motion_next = data_out['motion'][:, 6:12]
            for i in range(self.down_times):
                scale_factor = 1.0 if i == 0 else 0.5
                image_down = torch.nn.functional.interpolate(
                    image_down,
                    scale_factor=scale_factor,
                    mode='bilinear',
                    align_corners=True)
                image_previous_down = torch.nn.functional.interpolate(
                    image_previous_down,
                    scale_factor=scale_factor,
                    mode='bilinear',
                    align_corners=True)
                image_next_down = torch.nn.functional.interpolate(
                    image_next_down,
                    scale_factor=scale_factor,
                    mode='bilinear',
                    align_corners=True)
                depth_down = torch.nn.functional.interpolate(
                    depth_down,
                    scale_factor=scale_factor,
                    mode='bilinear',
                    align_corners=True)
                if self.warp_flag == 0:
                    warp_previous = util.warp(image_previous_down, depth_down,
                                              camera, motion_previous,
                                              self.warp_flag)
                    mask_previous = warp_previous != 0
                    warp_next = util.warp(image_next_down, depth_down, camera,
                                          motion_next, self.warp_flag)
                    mask_next = warp_next != 0
                    residual_previous = (image_down * mask_previous -
                                         warp_previous).abs()
                    residual_next = (image_down * mask_next - warp_next).abs()
                elif self.warp_flag == 1:  # direct
                    warp_previous, record_previous = util.warp(
                        image_down, depth_down, camera, motion_previous,
                        self.warp_flag)
                    mask_previous = record_previous != 0
                    warp_next, record_next = util.warp(image_down, depth_down,
                                                       camera, motion_next,
                                                       self.warp_flag)
                    mask_next = record_next != 0
                    residual_previous = (image_previous_down * mask_previous -
                                         warp_previous).abs()
                    residual_next = (image_next_down * mask_next -
                                     warp_next).abs()
                elif self.warp_flag == 2:  # record
                    warp_previous, record_previous, _, weight_previous = util.warp(
                        image_previous_down, depth_down, camera,
                        motion_previous, self.warp_flag, self.previous_sigma)
                    mask_previous = weight_previous
                    warp_next, record_next, _, weight_next = util.warp(
                        image_next_down, depth_down, camera, motion_next,
                        self.warp_flag, self.next_sigma)
                    mask_next = weight_next
                    residual_previous = ((image_down - warp_previous) *
                                         weight_previous).abs()
                    residual_next = ((image_down - warp_next) *
                                     weight_next).abs()
                    data_out['record_depth_previous_%d' % i] = record_previous
                    data_out['record_depth_next_%d' % i] = record_next
                    data_out['record_weight_previous_%d' % i] = weight_previous
                    data_out['record_weight_next_%d' % i] = weight_next
                elif self.warp_flag == 3:  # wide
                    warp_previous = util.warp(image_previous_down, depth_down,
                                              camera, motion_previous,
                                              self.warp_flag)
                    mask_previous = warp_previous != 0
                    warp_next = util.warp(image_next_down, depth_down, camera,
                                          motion_next, self.warp_flag)
                    mask_next = warp_next != 0
                    residual_previous = (image_down * mask_previous -
                                         warp_previous).abs()
                    residual_next = (image_down * mask_next - warp_next).abs()
                else:
                    raise Exception('Invalid warp flag.')
                data_out['residual_previous_%d' % i] = residual_previous
                data_out['residual_next_%d' % i] = residual_next
                data_out['warp_previous_%d' % i] = warp_previous
                data_out['warp_next_%d' % i] = warp_next
                loss_previous += residual_previous.sum() / (
                    mask_previous.sum() + 1)
                loss_next += residual_next.sum() / (mask_next.sum() + 1)
            sigma_momentum = 0.99
            self.previous_sigma = sigma_momentum * self.previous_sigma + \
                (1.0 - sigma_momentum) * loss_previous.item() * \
                self.sigma_scale / self.down_times
            self.next_sigma = sigma_momentum * self.next_sigma + \
                (1.0 - sigma_momentum) * loss_next.item() * \
                self.sigma_scale / self.down_times
            data_out['eval_previous_sigma'] = self.previous_sigma
            data_out['eval_next_sigma'] = self.next_sigma
            data_out['loss_previous'] = loss_previous / self.down_times * self.ref_weight
            data_out['loss_next'] = loss_next / self.down_times * self.ref_weight
            loss += data_out['loss_previous'] + data_out['loss_next']

        return loss
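For reference, eval_abs_rel above is an absolute relative error computed in inverse-depth space over the validity mask M:

    abs_rel = (1 / |M|) * sum_{p in M} |1 - z_out(p) / z_in(p)|,  where z = 1 / depth

and eval_abs_rel_scaled applies the predicted global scale to z_out first. A minimal, hedged call sketch, assuming the surrounding loss module is instantiated as criterion (the class name is not shown in the snippet):

loss = criterion(data_in, data_out)  # dicts of tensors keyed as in forward()
loss.backward()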