Пример #1
0
    def __init__(self, *defaults, **kwargs):
        '''
        Args:
            defaults(tuple): the default will have:
                0->model: the model of the experiment
                1->train_dataloader: the dataloader
                2->val_dataloader: the dataloader
                3->optimizer: the optimizer of the network
                4->loss_function: the loss function of the model
                5->logger: the logger of the whole training process
                6->config: the config object of the whole process

            kwargs(dict): the default will have:
                verbose(str):
                parallel(bool): True-> data parallel (the legacy misspelled
                    key 'parallele' is still accepted as a fallback)
                pretrain(bool): True-> use the pretrain model
        '''
        # logger & config (data_parallel / load_pretrain below rely on them)
        self.logger = defaults[5]
        self.config = defaults[6]

        # basic things
        # The documented key is 'parallel' (as in the other trainers); the
        # code used to read the misspelled 'parallele', so keep accepting it
        # for callers that already pass that key.
        parallel = (kwargs['parallel'] if 'parallel' in kwargs
                    else kwargs['parallele'])
        if parallel:
            self.model = self.data_parallel(defaults[0])
        else:
            self.model = defaults[0].cuda()

        if kwargs['pretrain']:
            self.load_pretrain()

        self.train_dataloader = defaults[1]
        self.val_dataloader = defaults[2]
        self.optimizer = defaults[3]
        self.loss_function = defaults[4]

        # basic meter
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.loss_basic = AverageMeter()

        # others
        self.verbose = kwargs['verbose']
        self.accuarcy = 0.0  # to store the accuracy varies from epoch to epoch
        self.kwargs = kwargs
        self.result_path = '.'
        self.log_step = 5  # how many steps to print the information
        self.eval_step = 5  # how many steps to use the val dataset to test the model
        self.save_step = 5

        if self.config.RESUME.flag:
            self.resume()

        if self.config.FINETUNE.flag:
            self.fine_tune()
Пример #2
0
    def mini_eval(self, current_step):
        '''
        Quick validation pass, run every ``config.TRAIN.mini_eval_step`` steps.

        Each validation clip is split along dim 2 at ``time_len // 2``: the
        first half is fed to ``self.STAE``, whose two outputs are scored (PSNR)
        against the first half (reconstruction) and the second half
        (prediction). The averaged PSNRs are logged; nothing is returned.

        Args:
            current_step(int): the current training step.
        '''
        if current_step % self.config.TRAIN.mini_eval_step != 0:
            return
        import torch  # only needed for torch.no_grad below

        temp_meter_rec = AverageMeter()
        temp_meter_pred = AverageMeter()
        self.STAE.eval()
        # Validation only: do not build autograd graphs (saves memory/time).
        with torch.no_grad():
            for data in self.val_dataloader:
                # get the reconstruction and prediction video clip
                time_len = data.shape[2]
                rec_time = time_len // 2
                input_rec_mini = data[:, :, 0:rec_time, :, :].cuda()  # 0 ~ t//2 frame
                input_pred_mini = data[:, :, rec_time:time_len, :, :].cuda()  # t//2 ~ t frame

                # Use the model, get the output
                output_rec_mini, output_pred_mini = self.STAE(input_rec_mini)
                rec_psnr_mini = psnr_error(output_rec_mini.detach(),
                                           input_rec_mini)
                pred_psnr_mini = psnr_error(output_pred_mini.detach(),
                                            input_pred_mini)
                temp_meter_rec.update(rec_psnr_mini.detach())
                temp_meter_pred.update(pred_psnr_mini.detach())
        self.logger.info(
            f'&^*_*^& ==> Step:{current_step}/{self.max_steps} the rec PSNR is {temp_meter_rec.avg:.3f}, the pred PSNR is {temp_meter_pred.avg:.3f}'
        )
Пример #3
0
    def mini_eval(self, current_step):
        '''
        Quick validation pass, run every ``config.TRAIN.mini_eval_step`` steps.

        For each validation clip, index 0 along dim 2 is the input frame and
        index 1 the target frame. ``self.G`` returns ``(flow, frame)``; the
        frame is scored (PSNR) against the target, and the flow against the
        flow that ``self.F`` estimates from the (input, target) pair. The
        averaged PSNRs are logged; nothing is returned.

        Args:
            current_step(int): the current training step.
        '''
        if current_step % self.config.TRAIN.mini_eval_step != 0:
            return
        temp_meter_frame = AverageMeter()
        temp_meter_flow = AverageMeter()
        self.G.eval()
        self.D.eval()
        # Validation only: do not build autograd graphs (saves memory/time).
        with torch.no_grad():
            for data in self.val_dataloader:
                # get the data
                target_mini = data[:, :, 1, :, :]
                input_data_mini = data[:, :, 0, :, :]
                # squeeze the dimension
                target_mini = target_mini.view(target_mini.shape[0], -1,
                                               target_mini.shape[-2],
                                               target_mini.shape[-1]).cuda()
                input_data_mini = input_data_mini.view(
                    input_data_mini.shape[0], -1, input_data_mini.shape[-2],
                    input_data_mini.shape[-1]).cuda()

                # Use the model, get the output
                output_flow_G_mini, output_frame_G_mini = self.G(input_data_mini)
                input_gtFlowEstimTensor = torch.cat(
                    [input_data_mini, target_mini], 1)
                gtFlow, _ = flow_batch_estimate(self.F, input_gtFlowEstimTensor)

                frame_psnr_mini = psnr_error(output_frame_G_mini.detach(),
                                             target_mini)
                flow_psnr_mini = psnr_error(output_flow_G_mini.detach(), gtFlow)
                temp_meter_frame.update(frame_psnr_mini.detach())
                temp_meter_flow.update(flow_psnr_mini.detach())
        self.logger.info(
            f'&^*_*^& ==> Step:{current_step}/{self.max_steps} the frame PSNR is {temp_meter_frame.avg:.3f}, the flow PSNR is {temp_meter_flow.avg:.3f}'
        )
Пример #4
0
    def mini_eval(self, current_step):
        '''
        Quick validation pass, run every ``config.TRAIN.mini_eval_step`` steps.

        Every validation batch is moved to the GPU and fed to ``self.MemAE``;
        the PSNR between its first output and the input is averaged over the
        validation set and logged. Nothing is returned.

        Args:
            current_step(int): the current training step.
        '''
        if current_step % self.config.TRAIN.mini_eval_step != 0:
            return
        import torch  # only needed for torch.no_grad below

        temp_meter_frame = AverageMeter()
        self.MemAE.eval()
        # Validation only: do not build autograd graphs (saves memory/time).
        with torch.no_grad():
            for data in self.val_dataloader:
                # get the data
                input_data_mini = data.cuda()
                output_rec, _ = self.MemAE(input_data_mini)

                frame_psnr_mini = psnr_error(output_rec.detach(), input_data_mini)
                temp_meter_frame.update(frame_psnr_mini.detach())
        self.logger.info(
            f'&^*_*^& ==> Step:{current_step}/{self.max_steps} the frame PSNR is {temp_meter_frame.avg:.3f}'
        )
Пример #5
0
    def mini_eval(self, current_step):
        '''
        Quick validation pass, run every 10 steps (never at step 0).

        For each validation clip, the last entry along dim 2 is the target and
        the preceding entries are the input (presumably dim 2 indexes frames).
        ``self.G`` is called with both; the PSNR between its output and the
        target is averaged over the validation set and logged. Nothing is
        returned.

        Args:
            current_step(int): the current training step.
        '''
        if current_step % 10 != 0 or current_step == 0:
            return
        import torch  # only needed for torch.no_grad below

        temp_meter = AverageMeter()
        self.G.eval()
        self.D.eval()
        # Validation only: do not build autograd graphs (saves memory/time).
        with torch.no_grad():
            for data in self.val_dataloader:
                # get the data
                target_mini = data[:, :, -1, :, :].cuda()
                input_data_mini = data[:, :, :-1, :, :].cuda()

                output_frame_G_mini = self.G(input_data_mini, target_mini)
                valid_psnr = psnr_error(output_frame_G_mini.detach(), target_mini)
                temp_meter.update(valid_psnr.detach())
        self.logger.info(
            f'&^*_*^& ==> Step:{current_step}/{self.max_steps} the PSNR is {temp_meter.avg:.3f}'
        )
Пример #6
0
    def mini_eval(self, current_step):
        '''
        Quick validation pass, run every 10 steps (never at step 0).

        For each validation clip (dim 2 is the frame dimension D), frames
        ``0..t`` are folded into the channel dimension and fed to ``self.G`` to
        predict frame ``t+1``. The frame PSNR compares the prediction with the
        real frame ``t+1``. The flow PSNR compares the flow ``self.F``
        estimates for (frame t, prediction) with the flow for (frame t, real
        frame t+1). Averages are logged; nothing is returned.

        Args:
            current_step(int): the current training step.
        '''
        if current_step % 10 != 0 or current_step == 0:
            return
        temp_meter_frame = AverageMeter()
        temp_meter_flow = AverageMeter()
        self.G.eval()
        self.D.eval()
        # Validation only: do not build autograd graphs (saves memory/time).
        with torch.no_grad():
            for data in self.val_dataloader:
                # base on the D to get each frame
                target_mini = data[:, :, -1, :, :].cuda()  # t+1 frame
                input_data = data[:, :, :-1, :, :]  # 0 ~ t frame
                input_last_mini = input_data[:, :, -1, :, :].cuda()  # t frame

                # squeeze the D dimension into the C dimension -> [N, C*D, H, W].
                # The slice above is not contiguous over (C, D), so .view would
                # raise whenever C > 1; reshape copies only when it must.
                input_data_mini = input_data.reshape(
                    input_data.shape[0], -1, input_data.shape[-2],
                    input_data.shape[-1]).cuda()
                output_pred_G = self.G(input_data_mini)
                gtFlow, _ = flow_batch_estimate(
                    self.F, torch.cat([input_last_mini, target_mini], 1))
                predFlow, _ = flow_batch_estimate(
                    self.F, torch.cat([input_last_mini, output_pred_G], 1))
                frame_psnr_mini = psnr_error(output_pred_G.detach(), target_mini)
                flow_psnr_mini = psnr_error(predFlow, gtFlow)
                temp_meter_frame.update(frame_psnr_mini.detach())
                temp_meter_flow.update(flow_psnr_mini.detach())
        self.logger.info(f'&^*_*^& ==> Step:{current_step}/{self.max_steps} the PSNR is {temp_meter_frame.avg:.2f}, the flow PSNR is {temp_meter_flow.avg:.2f}')
Пример #7
0
    def __init__(self,
                 model,
                 train_dataloader,
                 val_dataloader,
                 optimizer,
                 loss_function,
                 logger,
                 config,
                 verbose='None',
                 parallel=True,
                 pretrain=False,
                 **kwargs):
        '''
        Args:
            model: the network to train.
            train_dataloader: the training dataloader.
            val_dataloader: the validation dataloader.
            optimizer: the optimizer of the network.
            loss_function: the loss function of the model.
            logger: the logger of the whole training process.
            config: the config object of the whole process.
            verbose(str): free-form tag stored as ``self.verbose``.
            parallel(bool): True -> wrap the model via ``self.data_parallel``.
            pretrain(bool): True -> call ``self.load_pretrain()``.
            kwargs: extra options, stored as ``self.kwargs``.
        '''
        # logger & config (data_parallel / load_pretrain below rely on them)
        self.logger = logger
        self.config = config

        # basic things
        if parallel:
            self.model = self.data_parallel(model)
        else:
            # data_parallel() places the wrapped model on the GPU; do the same
            # here so both branches hand back a model on the same device.
            self.model = model.cuda()

        if pretrain:
            self.load_pretrain()

        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.optimizer = optimizer
        self.loss_function = loss_function

        # basic meter
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.loss_basic = AverageMeter()

        # others
        self.verbose = verbose
        self.accuarcy = 0.0  # to store the accuracy varies from epoch to epoch
        self.kwargs = kwargs
        self.total_steps = len(self.train_dataloader)
        self.result_path = ''

        if self.config.RESUME.flag:
            self.resume()

        if self.config.FINETUNE.flag:
            self.fine_tune()
Пример #8
0
class BaseTrainer(AbstractTrainer):
    '''
    Define the basic things about training: model placement (optionally
    wrapped in ``torch.nn.DataParallel``), pre-train / resume / fine-tune
    loading, the per-epoch train -> evaluate -> save loop and top-1 / top-5
    evaluation.
    '''
    def __init__(self,
                 model,
                 train_dataloader,
                 val_dataloader,
                 optimizer,
                 loss_function,
                 logger,
                 config,
                 verbose='None',
                 parallel=True,
                 pretrain=False,
                 **kwargs):
        '''
        Args:
            model: the network to train.
            train_dataloader: yields ``(data, target)`` training batches.
            val_dataloader: yields ``(data, target)`` validation batches.
            optimizer: the optimizer of the network.
            loss_function: called as ``loss_function(output, target)``.
            logger: the logger of the whole training process.
            config: the config object of the whole process.
            verbose(str): tag appended to saved checkpoint / model names.
            parallel(bool): True -> wrap the model with DataParallel.
            pretrain(bool): True -> load ``config.MODEL.pretrain_model``.
            kwargs: extra entries read later, e.g. 'writer_dict', 'cae_type',
                'config_name', 'time_stamp'.
        '''
        # logger & config (data_parallel / load_pretrain below rely on them)
        self.logger = logger
        self.config = config

        # basic things
        if parallel:
            self.model = self.data_parallel(model)
        else:
            # train()/evaluate() feed CUDA tensors, so the model must be on
            # the GPU as well (data_parallel() already does this).
            self.model = model.cuda()

        if pretrain:
            self.load_pretrain()

        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.optimizer = optimizer
        self.loss_function = loss_function

        # basic meter
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.loss_basic = AverageMeter()

        # others
        self.verbose = verbose
        self.accuarcy = 0.0  # best accuracy seen so far (updated in run())
        self.kwargs = kwargs
        self.total_steps = len(self.train_dataloader)
        self.result_path = ''

        if self.config.RESUME.flag:
            self.resume()

        if self.config.FINETUNE.flag:
            self.fine_tune()

    def _get_time(self):
        '''
        Get the current local time formatted as ``%Y-%m-%d-%H-%M``.
        '''
        return time.strftime('%Y-%m-%d-%H-%M')  # 2019-08-07-10-34

    def load_pretrain(self):
        '''
        Load weights from ``config.MODEL.pretrain_model`` into ``self.model``.

        A file containing an 'epoch' entry is treated as a checkpoint (weights
        under 'model_state_dict'); otherwise weights are read from 'state_dict'.
        An empty (or missing) path means training from scratch.
        '''
        model_path = self.config.MODEL.pretrain_model
        # `is ''` compared identity, not value; test emptiness instead.
        if not model_path:
            self.logger.info(
                '=>Not have the pre-train model! Training from the scratch')
            return
        self.logger.info('=>Loading the model in {}'.format(model_path))
        pretrain_model = torch.load(model_path)
        if 'epoch' in pretrain_model:
            self.logger.info('(|_|) ==> Use the check point file')
            self.model.load_state_dict(pretrain_model['model_state_dict'])
        else:
            self.logger.info('(+_+) ==> Use the model file')
            self.model.load_state_dict(pretrain_model['state_dict'])

    def resume(self):
        '''
        Restore model and optimizer state from ``config.RESUME.checkpoint_path``.
        '''
        self.logger.info('=> Resume the previous training')
        checkpoint_path = self.config.RESUME.checkpoint_path
        self.logger.info(
            '=> Load the checkpoint from {}'.format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    def fine_tune(self):
        '''
        Freeze every parameter whose top-level module name is not listed in
        ``config.FINETUNE.layer_list`` and print the names that stay trainable.
        A leading 'module.' prefix (added by DataParallel) is skipped.
        '''
        layer_list = self.config.FINETUNE.layer_list
        self.logger.info(
            '=> Freeze layers except start with:{}'.format(layer_list))
        for n, p in self.model.named_parameters():
            parts = n.split('.')
            # consider the data parallel situation
            top_level = parts[1] if parts[0] == 'module' else parts[0]
            if top_level not in layer_list:
                p.requires_grad = False
            if p.requires_grad:
                print(n)
        self.logger.info('Finish Setting freeze layers')

    def data_parallel(self, model):
        '''
        Wrap ``model`` in ``torch.nn.DataParallel`` over ``config.SYSTEM.gpus``
        and move it to the GPU.
        '''
        self.logger.info('<!_!> ==> Data Parallel')
        gpus = [int(i) for i in self.config.SYSTEM.gpus]
        model_parallel = torch.nn.DataParallel(model, device_ids=gpus).cuda()
        return model_parallel

    def run(self, current_epoch):
        '''
        Run the whole process:
        1. print the log information
        2. execute training process
        3. evaluate(including the validation and test)
        4. save model
        '''
        self.logger.info('-0_0- ==>|{}| Start Traing the {}/{} epoch'.format(
            self._get_time(), current_epoch, self.config.TRAIN.epochs))

        # train the model
        self.train(current_epoch)

        # evaluate
        acc = self.evaluate(current_epoch)
        if acc > self.accuarcy:
            self.accuarcy = acc
            # save the model & checkpoint
            self.save(current_epoch, best=True)
        else:
            # save the checkpoint
            self.save(current_epoch)
            self.logger.info(
                'LOL==>the accuracy is not imporved in epcoh{}'.format(
                    current_epoch))

    def save(self, current_epoch, best=False):
        '''
        Always write a checkpoint; when ``best`` is True, tag it 'best' and
        additionally save the model, storing the returned path in
        ``self.result_path``.
        '''
        verbose_tag = self.kwargs['cae_type'] + '#' + self.verbose
        if best:
            save_checkpoint(self.config,
                            self.kwargs['config_name'],
                            self.model,
                            current_epoch,
                            self.loss_basic.val,
                            self.optimizer,
                            self.logger,
                            self.kwargs['time_stamp'],
                            self.accuarcy,
                            flag='best',
                            verbose=verbose_tag)
            self.result_path = save_model(self.config,
                                          self.kwargs['config_name'],
                                          self.model,
                                          self.logger,
                                          self.kwargs['time_stamp'],
                                          self.accuarcy,
                                          verbose=verbose_tag)
        else:
            save_checkpoint(self.config,
                            self.kwargs['config_name'],
                            self.model,
                            current_epoch,
                            self.loss_basic.val,
                            self.optimizer,
                            self.logger,
                            self.kwargs['time_stamp'],
                            self.accuarcy,
                            verbose=verbose_tag)

    def train(self, current_epoch):
        '''
        Train for one epoch over ``train_dataloader``.

        Logs every 10 steps and on the last step, writes the per-step loss to
        the TensorBoard writer in ``kwargs['writer_dict']`` and stores the
        advanced global step counter back into that dict.

        Args:
            current_epoch(int): the index of the current epoch (for logging).
        '''
        start = time.time()
        self.model.train()
        writer = self.kwargs['writer_dict']['writer']
        steps_key = 'global_steps_{}'.format(self.kwargs['cae_type'])
        global_steps = self.kwargs['writer_dict'][steps_key]

        for step, (data, target) in enumerate(self.train_dataloader):
            self.data_time.update(time.time() - start)
            # True Process (inputs follow the model onto the GPU)
            output = self.model(data.cuda())
            loss = self.loss_function(output, target.cuda()).sum()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # record a plain float: a graph-carrying tensor in the meter keeps
            # autograd history alive across steps
            self.loss_basic.update(loss.item())
            self.batch_time.update(time.time() - start)

            if (step % 10 == 0) or (step == self.total_steps - 1):
                msg = ('Epoch: [{0}][{1}/{2}]\t'
                       'Type: {cae_type}\t'
                       'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t'
                       'Speed: {speed:.1f} samples/s\t'
                       'Data: {data_time.val:.3f}s ({data_time.avg:.3f}s)\t'
                       'Loss: {losses.val:.5f} ({losses.avg:.5f})\t').format(
                           current_epoch,
                           step,
                           self.total_steps,
                           cae_type=self.kwargs['cae_type'],
                           batch_time=self.batch_time,
                           speed=self.config.TRAIN.batch_size /
                           self.batch_time.val,
                           data_time=self.data_time,
                           losses=self.loss_basic)
                self.logger.info(msg)
            writer.add_scalar('Train_loss', self.loss_basic.val, global_steps)
            global_steps += 1
            # reset start
            start = time.time()

        self.kwargs['writer_dict'][steps_key] = global_steps

    def evaluate(self, current_epoch):
        '''
        Evaluate top-1 / top-5 accuracy of the model on ``val_dataloader``.
        !!! Will change, e.g. accuracy, mAP.....
        !!! Or can call other methods written by the official

        Args:
            current_epoch(int): the index of the current epoch (for logging).

        Returns:
            float: top-1 accuracy as a fraction of the validation dataset size.
        '''
        self.model.eval()
        hits_top1 = 0.0
        hits_top5 = 0.0
        with torch.no_grad():
            for data, target in self.val_dataloader:
                data = data.cuda()
                target = target.cuda()
                score = self.model(data)
                # topk(5) would raise for models with fewer than 5 classes
                k = min(5, score.size(1))
                _, pred = score.topk(k, 1, largest=True, sorted=True)

                target = target.view(target.size(0), -1).expand_as(pred)
                correct = pred.eq(target).float()

                hits_top5 += correct[:, :5].sum().item()
                hits_top1 += correct[:, :1].sum().item()

        # an empty validation set yields 0.0 instead of dividing by zero
        num_samples = max(len(self.val_dataloader.dataset), 1)
        correct_1 = hits_top1 / num_samples
        correct_5 = hits_top5 / num_samples

        self.logger.info(
            '&^*_*^& ==> Epoch:{}/{} the top1 is {}, the top5 is {}'.format(
                current_epoch, self.config.TRAIN.epochs, correct_1, correct_5))

        return correct_1
Пример #9
0
    def __init__(self, *defaults, **kwargs):
        '''
        Build the STAE trainer.

        Args:
            defaults(tuple): positional components, in order:
                0->model: dict holding the network under key 'STAE'
                1->train_dataloader: the training dataloader
                2->val_dataloader: the validation dataloader
                3->optimizer: dict holding the optimizer under 'optimizer_stae'
                4->loss_function: dict holding 'rec_loss' and 'pred_loss'
                5->logger: the logger of the whole training process
                6->config: the config object of the whole process

            kwargs(dict): must contain:
                hooks: passed to self._register_hooks
                parallel(bool): True-> wrap the network via self.data_parallel
                pretrain(bool): True-> call self.load_pretrain()
                verbose(str), config_name, evaluate_function
                loss_lamada: per-loss weights
                lr_scheduler_dict: holds 'optimizer_stae_scheduler'
                test_dataset_keys: the dataset keys of each video
                test_dataset_dict: the dataset dict of whole test videos
        '''
        # hooks are registered before any other state is set up
        self._hooks, self._eval_hooks = [], []
        self._register_hooks(kwargs['hooks'])

        # logger & config (data_parallel / load_pretrain rely on them)
        self.logger, self.config = defaults[5], defaults[6]

        # place the network, optionally wrapped for data parallelism
        wrap_parallel = kwargs['parallel']
        stae_net = defaults[0]['STAE']
        self.STAE = (self.data_parallel(stae_net) if wrap_parallel
                     else stae_net.cuda())

        if kwargs['pretrain']:
            self.load_pretrain()

        # keep both the loaders and a live iterator over each
        self.train_dataloader, self.val_dataloader = defaults[1], defaults[2]
        self._train_loader_iter = iter(self.train_dataloader)
        self._val_loader_iter = iter(self.val_dataloader)

        # optimizer and losses arrive as dicts
        self.optim_STAE = defaults[3]['optimizer_stae']
        losses = defaults[4]
        self.rec_loss, self.pred_loss = losses['rec_loss'], losses['pred_loss']

        # running meters
        self.batch_time, self.data_time, self.loss_meter_STAE = (
            AverageMeter(), AverageMeter(), AverageMeter())

        # bookkeeping
        self.verbose = kwargs['verbose']
        self.accuarcy = 0.0  # accuracy tracker, starts at 0.0
        self.config_name = kwargs['config_name']
        self.kwargs = kwargs
        self.result_path = ''
        train_cfg = self.config.TRAIN
        self.log_step = train_cfg.log_step  # logging period, in steps
        self.vis_step = train_cfg.vis_step  # visualisation period, in steps
        self.eval_step = train_cfg.eval_step
        self.save_step = train_cfg.save_step  # periodic save regardless of accuracy
        self.max_steps = train_cfg.max_steps

        self.test_dataset_keys = kwargs['test_dataset_keys']
        self.test_dataset_dict = kwargs['test_dataset_dict']
        self.evaluate_function = kwargs['evaluate_function']
        self.loss_lamada = kwargs['loss_lamada']  # loss weights
        self.lr_stae = kwargs['lr_scheduler_dict']['optimizer_stae_scheduler']

        if self.config.RESUME.flag:
            self.resume()
        if self.config.FINETUNE.flag:
            self.fine_tune()
Пример #10
0
class Trainer(DefaultTrainer):
    NAME = ["STAE.TRAIN"]

    def __init__(self, *defaults, **kwargs):
        '''
        Args:
            defaults(tuple): the default will have:
                0->model: dict with the network under key 'STAE'
                1->train_dataloader: the training dataloader
                2->val_dataloader: the validation dataloader
                3->optimizer: dict with the optimizer under key 'optimizer_stae'
                4->loss_function: dict with keys 'rec_loss' and 'pred_loss'
                5->logger: the logger of the whole training process
                6->config: the config object of the whole process

            kwargs(dict): the default will have:
                hooks: passed to self._register_hooks
                verbose(str):
                parallel(bool): True-> data parallel
                pretrain(bool): True-> use the pretrain model
                config_name, evaluate_function
                writer_dict, model_type: read by train()
                loss_lamada(dict): weights for 'rec_loss' and 'pred_loss'
                lr_scheduler_dict(dict): holds 'optimizer_stae_scheduler'
                extra param:
                    test_dataset_keys: the dataset keys of each video
                    test_dataset_dict: the dataset dict of whole test videos
        '''
        self._hooks = []
        self._eval_hooks = []
        self._register_hooks(kwargs['hooks'])
        # logger & config
        self.logger = defaults[5]
        self.config = defaults[6]

        model = defaults[0]
        # basic things
        if kwargs['parallel']:
            self.STAE = self.data_parallel(model['STAE'])
        else:
            self.STAE = model['STAE'].cuda()

        if kwargs['pretrain']:
            self.load_pretrain()

        self.train_dataloader = defaults[1]
        self._train_loader_iter = iter(self.train_dataloader)

        self.val_dataloader = defaults[2]
        self._val_loader_iter = iter(self.val_dataloader)

        # get the optimizer
        optimizer = defaults[3]
        self.optim_STAE = optimizer['optimizer_stae']

        # get the loss_function
        loss_function = defaults[4]
        self.rec_loss = loss_function['rec_loss']
        self.pred_loss = loss_function['pred_loss']

        # basic meter
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.loss_meter_STAE = AverageMeter()

        # others
        self.verbose = kwargs['verbose']
        self.accuarcy = 0.0  # to store the accuracy varies from epoch to epoch
        self.config_name = kwargs['config_name']
        self.kwargs = kwargs
        self.result_path = ''
        self.log_step = self.config.TRAIN.log_step  # how many steps between log messages
        self.vis_step = self.config.TRAIN.vis_step  # how many steps between visualisations
        self.eval_step = self.config.TRAIN.eval_step
        self.save_step = self.config.TRAIN.save_step  # save the model whatever the acc of the model
        self.max_steps = self.config.TRAIN.max_steps

        self.test_dataset_keys = kwargs['test_dataset_keys']
        self.test_dataset_dict = kwargs['test_dataset_dict']

        self.evaluate_function = kwargs['evaluate_function']
        # hyper-parameters of loss
        self.loss_lamada = kwargs['loss_lamada']

        # the lr scheduler
        scheduler_dict = kwargs['lr_scheduler_dict']
        self.lr_stae = scheduler_dict['optimizer_stae_scheduler']

        if self.config.RESUME.flag:
            self.resume()

        if self.config.FINETUNE.flag:
            self.fine_tune()

    def train(self, current_step):
        '''
        Run one optimisation step of the STAE.

        The next batch from the training iterator is split along dim 2 at
        ``time_len // 2``: the first half is both the network input and the
        reconstruction target, the second half is the prediction target. The
        loss ``loss_lamada['rec_loss'] * rec + loss_lamada['pred_loss'] * pred``
        is back-propagated, and logging / TensorBoard visualisation happen
        every ``log_step`` / ``vis_step`` steps.

        Args:
            current_step(int): the current training step.
        '''
        # Pytorch [N, C, D, H, W]

        # initialize
        start = time.time()
        self.STAE.train()
        writer = self.kwargs['writer_dict']['writer']
        global_steps = self.kwargs['writer_dict']['global_steps_{}'.format(
            self.kwargs['model_type'])]

        # get the data
        # NOTE(review): next() raises StopIteration once the iterator is
        # exhausted; presumably re-created elsewhere (e.g. a hook) — confirm.
        data = next(self._train_loader_iter)  # the core for dataloader
        self.data_time.update(time.time() - start)

        # get the reconstruction and prediction video clip
        time_len = data.shape[2]
        rec_time = time_len // 2
        input_rec = data[:, :, 0:rec_time, :, :].cuda()  # 0 ~ t//2 frame
        input_pred = data[:, :,
                          rec_time:time_len, :, :].cuda()  # t//2 ~ t frame

        # True Process =================Start===================
        output_rec, output_pred = self.STAE(input_rec)
        loss_rec = self.rec_loss(output_rec, input_rec)
        loss_pred = self.pred_loss(output_pred, input_pred)

        loss_stae_all = self.loss_lamada[
            'rec_loss'] * loss_rec + self.loss_lamada['pred_loss'] * loss_pred
        self.optim_STAE.zero_grad()
        loss_stae_all.backward()
        self.optim_STAE.step()
        self.loss_meter_STAE.update(loss_stae_all.detach())

        if self.config.TRAIN.adversarial.scheduler.use:
            self.lr_stae.step()
        # ======================End==================

        self.batch_time.update(time.time() - start)

        if (current_step % self.log_step == 0):
            msg = 'Step: [{0}/{1}]\t' \
                'Type: {cae_type}\t' \
                'Time: {batch_time.val:.2f}s ({batch_time.avg:.2f}s)\t' \
                'Speed: {speed:.1f} samples/s\t' \
                'Data: {data_time.val:.2f}s ({data_time.avg:.2f}s)\t' \
                'Loss_STAE: {loss.val:.5f} ({loss.avg:.5f})'.format(current_step, self.max_steps, cae_type=self.kwargs['model_type'], batch_time=self.batch_time, speed=self.config.TRAIN.batch_size/self.batch_time.val, data_time=self.data_time,loss=self.loss_meter_STAE)
            self.logger.info(msg)
        writer.add_scalar('Train_loss_STAE', self.loss_meter_STAE.val,
                          global_steps)
        if (current_step % self.vis_step == 0):
            vis_objects = OrderedDict()
            vis_objects['train_output_rec'] = output_rec.detach()
            vis_objects['train_output_pred'] = output_pred.detach()
            vis_objects['train_input_rec'] = input_rec.detach()
            vis_objects['train_input_pred'] = input_pred.detach()
            training_vis_images(vis_objects, writer, global_steps)
        global_steps += 1

        # expose the current model/optimizer/loss meter (presumably for the
        # saving hooks) and persist the advanced global step counter
        self.saved_model = {'STAE': self.STAE}
        self.saved_optimizer = {'optim_STAE': self.optim_STAE}
        self.saved_loss = {'loss_STAE': self.loss_meter_STAE}
        self.kwargs['writer_dict']['global_steps_{}'.format(
            self.kwargs['model_type'])] = global_steps

    def mini_eval(self, current_step):
        '''
        Quick validation pass, run every ``config.TRAIN.mini_eval_step`` steps.

        Each validation clip is split along dim 2 like in ``train``; the
        reconstruction / prediction PSNRs are averaged and logged.

        Args:
            current_step(int): the current training step.
        '''
        if current_step % self.config.TRAIN.mini_eval_step != 0:
            return
        import torch  # only needed for torch.no_grad below

        temp_meter_rec = AverageMeter()
        temp_meter_pred = AverageMeter()
        self.STAE.eval()
        # Validation only: do not build autograd graphs (saves memory/time).
        with torch.no_grad():
            for data in self.val_dataloader:
                # get the reconstruction and prediction video clip
                time_len = data.shape[2]
                rec_time = time_len // 2
                input_rec_mini = data[:, :, 0:rec_time, :, :].cuda()  # 0 ~ t//2 frame
                input_pred_mini = data[:, :, rec_time:time_len, :, :].cuda()  # t//2 ~ t frame

                # Use the model, get the output
                output_rec_mini, output_pred_mini = self.STAE(input_rec_mini)
                rec_psnr_mini = psnr_error(output_rec_mini.detach(),
                                           input_rec_mini)
                pred_psnr_mini = psnr_error(output_pred_mini.detach(),
                                            input_pred_mini)
                temp_meter_rec.update(rec_psnr_mini.detach())
                temp_meter_pred.update(pred_psnr_mini.detach())
        self.logger.info(
            f'&^*_*^& ==> Step:{current_step}/{self.max_steps} the rec PSNR is {temp_meter_rec.avg:.3f}, the pred PSNR is {temp_meter_pred.avg:.3f}'
        )
Пример #11
0
    def mini_eval(self, current_step):
        '''
        Quick validation pass, run every ``config.TRAIN.mini_eval_step`` steps.

        Each validation clip holds three frames along dim 1 (past, current,
        future). Objects are detected on the current frame and cropped from all
        three frames. Then:
          - A reconstructs the summed gradient of the (future, current) pair,
          - B reconstructs the current object crop,
          - C reconstructs the summed gradient of the (current, past) pair.
        Their PSNRs are averaged and logged; nothing is returned.

        Args:
            current_step(int): the current training step.
        '''
        if current_step % self.config.TRAIN.mini_eval_step != 0:
            return
        temp_meter_A = AverageMeter()
        temp_meter_B = AverageMeter()
        temp_meter_C = AverageMeter()
        self.A.eval()
        self.B.eval()
        self.C.eval()
        self.Detector.eval()
        # Validation only: do not build autograd graphs (saves memory/time).
        with torch.no_grad():
            for data in self.val_dataloader:
                # base on the D to get each frame
                # in this method, D = 3 and not change
                future_mini = data[:, -1, :, :, :].cuda()  # t+1 frame
                current_mini = data[:, 1, :, :, :].cuda()  # t frame
                past_mini = data[:, 0, :, :, :].cuda()  # t-1 frame

                bboxs_mini = get_batch_dets(self.Detector, current_mini)

                for index, bbox in enumerate(bboxs_mini):
                    # no detection: fall back to one all-zero box
                    if bbox.numel() == 0:
                        bbox = bbox.new_zeros([1, 4])
                    # get the crop objects
                    input_currentObject_B, _ = multi_obj_grid_crop(
                        current_mini[index], bbox)
                    future_object, _ = multi_obj_grid_crop(future_mini[index],
                                                           bbox)
                    future2current = torch.stack(
                        [future_object, input_currentObject_B], dim=1)
                    past_object, _ = multi_obj_grid_crop(past_mini[index], bbox)
                    current2past = torch.stack(
                        [input_currentObject_B, past_object], dim=1)

                    _, _, input_objectGradient_A = frame_gradient(future2current)
                    input_objectGradient_A = input_objectGradient_A.sum(1)
                    _, _, input_objectGradient_C = frame_gradient(current2past)
                    # sum C's own gradient (was mistakenly re-summing A's)
                    input_objectGradient_C = input_objectGradient_C.sum(1)

                    _, output_recGradient_A = self.A(input_objectGradient_A)
                    _, output_recObject_B = self.B(input_currentObject_B)
                    _, output_recGradient_C = self.C(input_objectGradient_C)

                    psnr_A = psnr_error(output_recGradient_A.detach(),
                                        input_objectGradient_A)
                    psnr_B = psnr_error(output_recObject_B.detach(),
                                        input_currentObject_B)
                    psnr_C = psnr_error(output_recGradient_C.detach(),
                                        input_objectGradient_C)
                    temp_meter_A.update(psnr_A.detach())
                    temp_meter_B.update(psnr_B.detach())
                    temp_meter_C.update(psnr_C.detach())

        self.logger.info(
            f'&^*_*^& ==> Step:{current_step}/{self.max_steps} the  A PSNR is {temp_meter_A.avg:.2f}, the B PSNR is {temp_meter_B.avg:.2f}, the C PSNR is {temp_meter_C.avg:.2f}'
        )
Пример #12
0
class Trainer(DefaultTrainer):
    '''OCAE trainer: jointly trains the A, B and C autoencoders on objects.

    Objects are the boxes returned by ``get_batch_dets`` on the t frame.
    For every sample, ``B`` reconstructs the cropped objects at t, while
    ``A`` and ``C`` reconstruct the (dim-1 summed) third output of
    ``frame_gradient`` for the stacked (t+1, t) and (t, t-1) crop pairs.
    The detector is only switched to eval mode by this class.
    '''
    NAME = ["OCAE.TRAIN"]

    def __init__(self, *defaults, **kwargs):
        '''
        Args:
            defaults(tuple): the default will have:
                0->model: {'A': net_a, 'B': net_b, 'C': net_c, 'Detector': net_det}
                1->train_dataloader: the dataloader
                2->val_dataloader: the dataloader
                3->optimizer: {'optimizer_abc': op_abc}
                4->loss_function: {'A_loss': .., 'B_loss': .., 'C_loss': ..}
                5->logger: the logger of the whole training process
                6->config: the config object of the whole process

            kwargs(dict): the keys read here are:
                hooks: passed to ``self._register_hooks``
                verbose(str):
                parallel(bool): True -> wrap every model with ``self.data_parallel``
                pretrain(bool): True -> call ``self.load_pretrain()``
                config_name, evaluate_function
                test_dataset_keys: the dataset keys of each test video
                test_dataset_dict: the dataset dict of the whole test videos
                cluster_dataset_keys, cluster_dataset_dict
                loss_lamada(dict): loss weights keyed 'A_loss', 'B_loss', 'C_loss'
                lr_scheduler_dict(dict): must contain 'optimizer_abc_scheduler'
                (``train`` later also reads 'writer_dict' and 'model_type'
                through ``self.kwargs``.)
        '''
        self._hooks = []
        self._register_hooks(kwargs['hooks'])
        # logger & config
        self.logger = defaults[5]
        self.config = defaults[6]

        model = defaults[0]
        # basic things
        if kwargs['parallel']:
            self.A = self.data_parallel(model['A'])
            self.B = self.data_parallel(model['B'])
            self.C = self.data_parallel(model['C'])
            self.Detector = self.data_parallel(model['Detector'])
        else:
            self.A = model['A'].cuda()
            self.B = model['B'].cuda()
            self.C = model['C'].cuda()
            self.Detector = model['Detector'].cuda()

        if kwargs['pretrain']:
            self.load_pretrain()

        self.train_dataloader = defaults[1]
        self._train_loader_iter = iter(self.train_dataloader)

        self.val_dataloader = defaults[2]
        self._val_loader_iter = iter(self.val_dataloader)

        # get the optimizer
        optimizer = defaults[3]
        self.optim_ABC = optimizer['optimizer_abc']

        # get the loss_fucntion
        loss_function = defaults[4]
        self.a_loss = loss_function['A_loss']
        self.b_loss = loss_function['B_loss']
        self.c_loss = loss_function['C_loss']

        # basic meter
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        # NOTE(review): loss_meter_A/B/C are never updated in this class;
        # only loss_meter_ABC is.
        self.loss_meter_A = AverageMeter()
        self.loss_meter_B = AverageMeter()
        self.loss_meter_C = AverageMeter()
        self.loss_meter_ABC = AverageMeter()
        self.psnr = AverageMeter()

        # others
        self.verbose = kwargs['verbose']
        self.accuarcy = 0.0  # to store the accuracy varies from epoch to epoch
        self.config_name = kwargs['config_name']
        self.kwargs = kwargs
        self.result_path = ''
        self.log_step = self.config.TRAIN.log_step  # how many the steps, we will show the information
        self.eval_step = self.config.TRAIN.eval_step
        self.save_step = self.config.TRAIN.save_step  # save the model whatever the acc of the model
        self.max_steps = self.config.TRAIN.max_steps
        # NOTE(review): ``train`` reads ``self.vis_step``, which is not set
        # here; presumably provided by DefaultTrainer or a hook — confirm.
        self.test_dataset_keys = kwargs['test_dataset_keys']
        self.test_dataset_dict = kwargs['test_dataset_dict']

        self.cluster_dataset_keys = kwargs['cluster_dataset_keys']
        self.cluster_dataset_dict = kwargs['cluster_dataset_dict']

        self.evaluate_function = kwargs['evaluate_function']

        # hypyer-parameters of loss
        self.loss_lamada = kwargs['loss_lamada']

        # the lr scheduler
        lr_scheduler_dict = kwargs['lr_scheduler_dict']
        self.lr_abc = lr_scheduler_dict['optimizer_abc_scheduler']

        if self.config.RESUME.flag:
            self.resume()

        if self.config.FINETUNE.flag:
            self.fine_tune()

    @staticmethod
    def _crop_object_inputs(past, current, future, bbox):
        '''Build the A/B/C network inputs from one sample's detected objects.

        Shared by ``train`` and ``mini_eval`` so that both feed the networks
        in exactly the same way.

        Args:
            past, current, future: the t-1, t and t+1 frames of one sample.
            bbox: the detections of that sample. An empty tensor is replaced
                by a single all-zero box, so the crop calls always receive at
                least one box.

        Returns:
            tuple (input_objectGradient_A, input_currentObject_B,
            input_objectGradient_C): B's input is the crop of ``current``.
            A's and C's inputs are the third output of ``frame_gradient`` on
            the stacked (future, current) and (current, past) crops, each
            summed over dim 1.
        '''
        if bbox.numel() == 0:
            bbox = bbox.new_zeros([1, 4])
        current_object, _ = multi_obj_grid_crop(current, bbox)
        future_object, _ = multi_obj_grid_crop(future, bbox)
        future2current = torch.stack([future_object, current_object], dim=1)
        past_object, _ = multi_obj_grid_crop(past, bbox)
        current2past = torch.stack([current_object, past_object], dim=1)

        _, _, gradient_A = frame_gradient(future2current)
        _, _, gradient_C = frame_gradient(current2past)
        # Each branch sums its *own* gradient. The previous inline code built
        # C's input from A's already-summed gradient (copy-paste bug).
        return gradient_A.sum(1), current_object, gradient_C.sum(1)

    def train(self, current_step):
        '''Run one optimisation pass of A, B and C over a training batch.

        Each sample of the batch (all of its detected boxes together) gets
        its own forward pass, backward pass and ``optim_ABC.step()``. Logging
        and visualisation use the tensors of the last processed sample.

        Args:
            current_step(int): the global training step; drives the
                ``log_step`` / ``vis_step`` cadence.
        '''
        # NOTE(review): frames are indexed on dim 1 below (data[:, k]), so the
        # layout looks like [N, D, C, H, W], not [N, C, D, H, W] — confirm.
        start = time.time()
        self.A.train()
        self.B.train()
        self.C.train()
        self.Detector.eval()
        writer = self.kwargs['writer_dict']['writer']
        global_steps = self.kwargs['writer_dict']['global_steps_{}'.format(
            self.kwargs['model_type'])]

        # get the data
        data = next(self._train_loader_iter)  # the core for dataloader
        self.data_time.update(time.time() - start)

        # in this method, D = 3 and not change
        future = data[:, -1, :, :, :].cuda()  # t+1 frame
        current = data[:, 1, :, :, :].cuda()  # t frame
        past = data[:, 0, :, :, :].cuda()  # t-1 frame

        bboxs = get_batch_dets(self.Detector, current)
        # this method trains on the detected objects instead of whole frames
        for index, bbox in enumerate(bboxs):
            (input_objectGradient_A, input_currentObject_B,
             input_objectGradient_C) = self._crop_object_inputs(
                 past[index], current[index], future[index], bbox)

            # True Process =================Start===================
            _, output_recGradient_A = self.A(input_objectGradient_A)
            _, output_recObject_B = self.B(input_currentObject_B)
            _, output_recGradient_C = self.C(input_objectGradient_C)
            loss_A = self.a_loss(output_recGradient_A, input_objectGradient_A)
            loss_B = self.b_loss(output_recObject_B, input_currentObject_B)
            loss_C = self.c_loss(output_recGradient_C, input_objectGradient_C)

            loss_all = (self.loss_lamada['A_loss'] * loss_A
                        + self.loss_lamada['B_loss'] * loss_B
                        + self.loss_lamada['C_loss'] * loss_C)
            self.optim_ABC.zero_grad()
            loss_all.backward()
            self.optim_ABC.step()
            # record
            self.loss_meter_ABC.update(loss_all.detach())
            # NOTE(review): the scheduler steps once per sample (inside this
            # loop), together with every optimizer step — confirm intended.
            if self.config.TRAIN.general.scheduler.use:
                self.lr_abc.step()
            # ======================End==================

        self.batch_time.update(time.time() - start)

        if current_step % self.log_step == 0:
            msg = ('Step: [{0}/{1}]\t'
                   'Type: {cae_type}\t'
                   'Time: {batch_time.val:.2f}s ({batch_time.avg:.2f}s)\t'
                   'Speed: {speed:.1f} samples/s\t'
                   'Data: {data_time.val:.2f}s ({data_time.avg:.2f}s)\t'
                   'Loss_ABC: {losses_ABC.val:.5f} ({losses_ABC.avg:.5f})\t').format(
                       current_step, self.max_steps,
                       cae_type=self.kwargs['model_type'],
                       batch_time=self.batch_time,
                       speed=self.config.TRAIN.batch_size / self.batch_time.val,
                       data_time=self.data_time,
                       losses_ABC=self.loss_meter_ABC)
            self.logger.info(msg)
        writer.add_scalar('Train_loss_ABC', self.loss_meter_ABC.val,
                          global_steps)

        if current_step % self.vis_step == 0:
            vis_objects = OrderedDict()
            vis_objects['train_input_objectGradient_A'] = input_objectGradient_A.detach()
            vis_objects['train_input_currentObject_B'] = input_currentObject_B.detach()
            vis_objects['train_input_objectGradient_C'] = input_objectGradient_C.detach()
            vis_objects['train_output_recGradient_A'] = output_recGradient_A.detach()
            vis_objects['train_output_recObject_B'] = output_recObject_B.detach()
            vis_objects['train_output_recGradient_C'] = output_recGradient_C.detach()
            training_vis_images(vis_objects, writer, global_steps)
        global_steps += 1

        self.saved_model = {'A': self.A, 'B': self.B, 'C': self.C}
        self.saved_optimizer = {'optim_ABC': self.optim_ABC}
        self.saved_loss = {'loss_ABC': self.loss_meter_ABC.val}
        self.kwargs['writer_dict']['global_steps_{}'.format(
            self.kwargs['model_type'])] = global_steps

    def mini_eval(self, current_step):
        '''Log the mean PSNR of the A, B and C reconstructions on val data.

        Runs only when ``current_step`` is a multiple of
        ``config.TRAIN.mini_eval_step``. Nothing is back-propagated, so the
        whole pass runs under ``torch.no_grad()``.
        '''
        if current_step % self.config.TRAIN.mini_eval_step != 0:
            return
        temp_meter_A = AverageMeter()
        temp_meter_B = AverageMeter()
        temp_meter_C = AverageMeter()
        self.A.eval()
        self.B.eval()
        self.C.eval()
        self.Detector.eval()
        with torch.no_grad():
            for data in self.val_dataloader:
                # in this method, D = 3 and not change
                future_mini = data[:, -1, :, :, :].cuda()  # t+1 frame
                current_mini = data[:, 1, :, :, :].cuda()  # t frame
                past_mini = data[:, 0, :, :, :].cuda()  # t-1 frame

                bboxs_mini = get_batch_dets(self.Detector, current_mini)

                for index, bbox in enumerate(bboxs_mini):
                    (input_objectGradient_A, input_currentObject_B,
                     input_objectGradient_C) = self._crop_object_inputs(
                         past_mini[index], current_mini[index],
                         future_mini[index], bbox)

                    _, output_recGradient_A = self.A(input_objectGradient_A)
                    _, output_recObject_B = self.B(input_currentObject_B)
                    _, output_recGradient_C = self.C(input_objectGradient_C)

                    psnr_A = psnr_error(output_recGradient_A.detach(),
                                        input_objectGradient_A)
                    psnr_B = psnr_error(output_recObject_B.detach(),
                                        input_currentObject_B)
                    psnr_C = psnr_error(output_recGradient_C.detach(),
                                        input_objectGradient_C)
                    temp_meter_A.update(psnr_A.detach())
                    temp_meter_B.update(psnr_B.detach())
                    temp_meter_C.update(psnr_C.detach())

        self.logger.info(
            f'&^*_*^& ==> Step:{current_step}/{self.max_steps} the  A PSNR is {temp_meter_A.avg:.2f}, the B PSNR is {temp_meter_B.avg:.2f}, the C PSNR is {temp_meter_C.avg:.2f}'
        )
Пример #13
0
class Trainer(DefaultTrainer):
    NAME = ["ANOPRED.TRAIN"]
    def __init__(self, *defaults, **kwargs):
        '''Wire up the AnoPred (future-frame prediction GAN) trainer.

        Args:
            defaults(tuple): positional objects, in this order:
                0 -> model: {'Generator': net_g, 'Discriminator': net_d, 'FlowNet': net_flow}
                1 -> train_dataloader
                2 -> val_dataloader
                3 -> optimizer: {'optimizer_g': op_g, 'optimizer_d': op_d}
                4 -> loss_function: {'gan_loss': .., 'gradient_loss': ..,
                                     'intentsity_loss': .., 'opticalflow_loss': ..}
                5 -> logger: the logger of the whole training process
                6 -> config: the config object of the whole process

            kwargs(dict): keys read here:
                hooks: passed to ``self._register_hooks``
                verbose(str)
                parallel(bool): True -> wrap each network with ``self.data_parallel``,
                    otherwise move it with ``.cuda()``
                pretrain(bool): True -> call ``self.load_pretrain()``
                config_name, evaluate_function
                test_dataset_keys: the dataset keys of each video
                test_dataset_dict: the dataset dict of whole test videos
                loss_lamada(dict): per-loss weights
                lr_scheduler_dict(dict): holds 'optimizer_g_scheduler' and
                    'optimizer_d_scheduler'
        '''
        self._hooks = []
        self._register_hooks(kwargs['hooks'])
        self.logger = defaults[5]
        self.config = defaults[6]

        networks = defaults[0]
        place = self.data_parallel if kwargs['parallel'] else (lambda net: net.cuda())
        self.G = place(networks['Generator'])
        self.D = place(networks['Discriminator'])
        self.F = place(networks['FlowNet'])  # lite flownet
        # FlowNet is put in eval mode once; train() only switches G and D.
        self.F.eval()

        if kwargs['pretrain']:
            self.load_pretrain()

        self.train_dataloader = defaults[1]
        self._train_loader_iter = iter(self.train_dataloader)
        self.val_dataloader = defaults[2]
        self._val_loader_iter = iter(self.val_dataloader)

        optimizers = defaults[3]
        self.optim_G = optimizers['optimizer_g']
        self.optim_D = optimizers['optimizer_d']

        losses = defaults[4]
        self.gan_loss = losses['gan_loss']
        self.gd_loss = losses['gradient_loss']
        self.int_loss = losses['intentsity_loss']
        self.op_loss = losses['opticalflow_loss']

        # timing and loss meters
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.loss_meter_G = AverageMeter()
        self.loss_meter_D = AverageMeter()

        self.verbose = kwargs['verbose']
        self.accuarcy = 0.0  # accuracy tracked from epoch to epoch
        self.config_name = kwargs['config_name']
        self.kwargs = kwargs
        self.result_path = ''

        train_cfg = self.config.TRAIN
        self.log_step = train_cfg.log_step  # logging interval, in steps
        self.eval_step = train_cfg.eval_step
        self.save_step = train_cfg.save_step  # checkpoint interval, regardless of accuracy
        self.max_steps = train_cfg.max_steps

        self.test_dataset_keys = kwargs['test_dataset_keys']
        self.test_dataset_dict = kwargs['test_dataset_dict']
        self.evaluate_function = kwargs['evaluate_function']

        self.loss_lamada = kwargs['loss_lamada']  # weights of the individual losses

        schedulers = kwargs['lr_scheduler_dict']
        self.lr_g = schedulers['optimizer_g_scheduler']
        self.lr_d = schedulers['optimizer_d_scheduler']

        if self.config.RESUME.flag:
            self.resume()
        if self.config.FINETUNE.flag:
            self.fine_tune()
    

    def train(self, current_step):
        '''Run one generator update followed by one discriminator update.

        The next batch from ``self._train_loader_iter`` is split along dim 2.
        The last slice is the target frame (t+1). The earlier slices (0..t)
        are folded into the channel dimension and used as the generator input.

        Args:
            current_step(int): the global training step; drives the
                ``log_step`` / ``vis_step`` cadence.
        '''
        # Pytorch [N, C, D, H, W]
        start = time.time()
        self.G.train()
        self.D.train()
        writer = self.kwargs['writer_dict']['writer']
        global_steps = self.kwargs['writer_dict']['global_steps_{}'.format(self.kwargs['model_type'])]

        # get the data
        data = next(self._train_loader_iter)
        self.data_time.update(time.time() - start)

        # base on the D to get each frame
        target = data[:, :, -1, :, :].cuda()  # t+1 frame
        input_data = data[:, :, :-1, :, :]  # 0 ~ t frame
        input_last = input_data[:, :, -1, :, :].cuda()  # t frame

        # Fold D into C, giving [N, C*D, H, W]. The slice above is not
        # contiguous, so ``view`` raises as soon as C > 1. ``reshape`` returns
        # the same view whenever one is possible and copies otherwise.
        input_data = input_data.reshape(
            input_data.shape[0], -1, input_data.shape[-2], input_data.shape[-1]).cuda()

        # True Process =================Start===================
        # ---------update optim_G ---------
        # presumably freezes D's parameters while G is being updated
        self.set_requires_grad(self.D, False)
        output_pred_G = self.G(input_data)
        predFlowEstim = torch.cat([input_last, output_pred_G], 1)
        gtFlowEstim = torch.cat([input_last, target], 1)
        # NOTE(review): the *second* output of flow_batch_estimate is used
        # here, while mini_eval uses the *first* one — confirm which is meant.
        _, gtFlow = flow_batch_estimate(self.F, gtFlowEstim)
        _, predFlow = flow_batch_estimate(self.F, predFlowEstim)

        loss_g_adv = self.gan_loss(self.D(output_pred_G), True)
        loss_op = self.op_loss(predFlow, gtFlow)
        loss_int = self.int_loss(output_pred_G, target)
        loss_gd = self.gd_loss(output_pred_G, target)
        loss_g_all = (self.loss_lamada['intentsity_loss'] * loss_int
                      + self.loss_lamada['gradient_loss'] * loss_gd
                      + self.loss_lamada['opticalflow_loss'] * loss_op
                      + self.loss_lamada['gan_loss'] * loss_g_adv)
        self.optim_G.zero_grad()
        loss_g_all.backward()
        self.optim_G.step()
        # record
        self.loss_meter_G.update(loss_g_all.detach())

        if self.config.TRAIN.adversarial.scheduler.use:
            self.lr_g.step()
        # ---------update optim_D ---------------
        self.set_requires_grad(self.D, True)
        self.optim_D.zero_grad()
        temp_t = self.D(target)
        # detached so the discriminator loss does not reach G
        temp_g = self.D(output_pred_G.detach())
        loss_d_1 = self.gan_loss(temp_t, True)
        loss_d_2 = self.gan_loss(temp_g, False)
        loss_d = (loss_d_1 + loss_d_2) * 0.5
        loss_d.backward()

        self.optim_D.step()
        if self.config.TRAIN.adversarial.scheduler.use:
            self.lr_d.step()
        self.loss_meter_D.update(loss_d.detach())
        # ======================End==================

        self.batch_time.update(time.time() - start)

        if current_step % self.log_step == 0:
            msg = ('Step: [{0}/{1}]\t'
                   'Type: {cae_type}\t'
                   'Time: {batch_time.val:.2f}s ({batch_time.avg:.2f}s)\t'
                   'Speed: {speed:.1f} samples/s\t'
                   'Data: {data_time.val:.2f}s ({data_time.avg:.2f}s)\t'
                   'Loss_G: {losses_G.val:.5f} ({losses_G.avg:.5f})\t'
                   'Loss_D:{losses_D.val:.5f}({losses_D.avg:.5f})').format(
                       current_step, self.max_steps,
                       cae_type=self.kwargs['model_type'],
                       batch_time=self.batch_time,
                       speed=self.config.TRAIN.batch_size / self.batch_time.val,
                       data_time=self.data_time,
                       losses_G=self.loss_meter_G,
                       losses_D=self.loss_meter_D)
            self.logger.info(msg)
        writer.add_scalar('Train_loss_G', self.loss_meter_G.val, global_steps)
        writer.add_scalar('Train_loss_D', self.loss_meter_D.val, global_steps)

        # NOTE(review): ``self.vis_step`` is not assigned in this class's
        # __init__; presumably provided by DefaultTrainer or a hook — confirm.
        if current_step % self.vis_step == 0:
            vis_objects = OrderedDict()
            vis_objects['train_target'] = target.detach()
            vis_objects['train_output_pred_G'] = output_pred_G.detach()
            vis_objects['train_gtFlow'] = gtFlow.detach()
            vis_objects['train_predFlow'] = predFlow.detach()
            training_vis_images(vis_objects, writer, global_steps)
        global_steps += 1

        self.saved_model = {'G': self.G, 'D': self.D}
        self.saved_optimizer = {'optim_G': self.optim_G, 'optim_D': self.optim_D}
        self.saved_loss = {'loss_G': self.loss_meter_G.val, 'loss_D': self.loss_meter_D.val}
        self.kwargs['writer_dict']['global_steps_{}'.format(self.kwargs['model_type'])] = global_steps
    
    def mini_eval(self, current_step):
        '''Log the mean frame PSNR and flow PSNR of G over the val set.

        Runs only on every 10th step, step 0 excluded. Nothing is
        back-propagated, so the whole pass runs under ``torch.no_grad()``.
        This avoids building autograd graphs for every validation batch.
        '''
        if current_step % 10 != 0 or current_step == 0:
            return
        temp_meter_frame = AverageMeter()
        temp_meter_flow = AverageMeter()
        self.G.eval()
        self.D.eval()
        with torch.no_grad():
            for data in self.val_dataloader:
                # base on the D to get each frame
                target_mini = data[:, :, -1, :, :].cuda()  # t+1 frame
                input_data = data[:, :, :-1, :, :]  # 0 ~ t frame
                input_last_mini = input_data[:, :, -1, :, :].cuda()  # t frame

                # Fold D into C, giving [N, C*D, H, W]. ``reshape`` rather
                # than ``view``: the slice above is non-contiguous, so
                # ``view`` raises when C > 1.
                input_data_mini = input_data.reshape(
                    input_data.shape[0], -1, input_data.shape[-2], input_data.shape[-1]).cuda()
                output_pred_G = self.G(input_data_mini)
                # NOTE(review): this keeps the *first* output of
                # flow_batch_estimate, while train() keeps the *second* one —
                # confirm which is intended.
                gtFlow, _ = flow_batch_estimate(self.F, torch.cat([input_last_mini, target_mini], 1))
                predFlow, _ = flow_batch_estimate(self.F, torch.cat([input_last_mini, output_pred_G], 1))
                frame_psnr_mini = psnr_error(output_pred_G.detach(), target_mini)
                flow_psnr_mini = psnr_error(predFlow, gtFlow)
                temp_meter_frame.update(frame_psnr_mini.detach())
                temp_meter_flow.update(flow_psnr_mini.detach())
        self.logger.info(f'&^*_*^& ==> Step:{current_step}/{self.max_steps} the PSNR is {temp_meter_frame.avg:.2f}, the flow PSNR is {temp_meter_flow.avg:.2f}')