class RegTrainer(Trainer):
    def setup(self):
        """initial the datasets, model, loss and optimizer"""
        args = self.args
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            self.device_count = torch.cuda.device_count()
            # for code conciseness, we release the single gpu version
            assert self.device_count == 1
            logging.info('using {} gpus'.format(self.device_count))
        else:
            raise Exception("gpu is not available")

        self.downsample_ratio = args.downsample_ratio
        self.datasets = {
            x: Crowd(os.path.join(args.data_dir, x), args.crop_size,
                     args.downsample_ratio, args.is_gray, x)
            for x in ['train', 'val']
        }
        self.dataloaders = {
            x: DataLoader(self.datasets[x],
                          collate_fn=(train_collate
                                      if x == 'train' else default_collate),
                          batch_size=(args.batch_size if x == 'train' else 1),
                          shuffle=(True if x == 'train' else False),
                          num_workers=args.num_workers * self.device_count,
                          pin_memory=(True if x == 'train' else False))
            for x in ['train', 'val']
        }
        # self.model = vgg19()
        self.model = CSRNet()
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.weight_decay)

        self.start_epoch = 0
        if args.resume:
            suf = args.resume.rsplit('.', 1)[-1]
            if suf == 'tar':
                checkpoint = torch.load(args.resume, self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer.load_state_dict(
                    checkpoint['optimizer_state_dict'])
                self.start_epoch = checkpoint['epoch'] + 1
            elif suf == 'pth':
                self.model.load_state_dict(torch.load(args.resume,
                                                      self.device))

        self.post_prob = Post_Prob(args.sigma, args.crop_size,
                                   args.downsample_ratio,
                                   args.background_ratio, args.use_background,
                                   self.device)
        self.criterion = Bay_Loss(args.use_background, self.device)
        self.save_list = Save_Handle(max_num=args.max_model_num)
        self.best_mae = np.inf
        self.best_mse = np.inf
        self.best_count = 0

    def train(self):
        """training process"""
        args = self.args
        for epoch in range(self.start_epoch, args.max_epoch):
            logging.info('-' * 5 +
                         'Epoch {}/{}'.format(epoch, args.max_epoch - 1) +
                         '-' * 5)
            self.epoch = epoch
            self.train_epoch()
            if epoch % args.val_epoch == 0 and epoch >= args.val_start:
                self.val_epoch()

    def train_epoch(self):
        epoch_loss = AverageMeter()
        epoch_mae = AverageMeter()
        epoch_mse = AverageMeter()
        epoch_start = time.time()
        self.model.train()  # Set model to training mode

        # Iterate over data.
        for step, (inputs, points, targets,
                   st_sizes) in enumerate(self.dataloaders['train']):
            inputs = inputs.to(self.device)
            st_sizes = st_sizes.to(self.device)
            gd_count = np.array([len(p) for p in points], dtype=np.float32)
            points = [p.to(self.device) for p in points]
            targets = [t.to(self.device) for t in targets]

            with torch.set_grad_enabled(True):
                outputs = self.model(inputs)
                prob_list = self.post_prob(points, st_sizes)
                loss = self.criterion(prob_list, targets, outputs)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                N = inputs.size(0)
                pre_count = torch.sum(outputs.view(N, -1),
                                      dim=1).detach().cpu().numpy()
                res = pre_count - gd_count
                epoch_loss.update(loss.item(), N)
                epoch_mse.update(np.mean(res * res), N)
                epoch_mae.update(np.mean(abs(res)), N)

        logging.info(
            'Epoch {} Train, Loss: {:.2f}, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
            .format(self.epoch, epoch_loss.get_avg(),
                    np.sqrt(epoch_mse.get_avg()), epoch_mae.get_avg(),
                    time.time() - epoch_start))
        model_state_dic = self.model.state_dict()
        save_path = os.path.join(self.save_dir,
                                 '{}_ckpt.tar'.format(self.epoch))
        torch.save(
            {
                'epoch': self.epoch,
                'optimizer_state_dict': self.optimizer.state_dict(),
                'model_state_dict': model_state_dic
            }, save_path)
        self.save_list.append(save_path)  # control the number of saved models

    def val_epoch(self):
        epoch_start = time.time()
        self.model.eval()  # Set model to evaluate mode
        epoch_res = []
        # Iterate over data.
        for inputs, count, name in self.dataloaders['val']:
            inputs = inputs.to(self.device)
            # inputs are images with different sizes
            assert inputs.size(
                0) == 1, 'the batch size should be 1 in validation mode'
            with torch.set_grad_enabled(False):
                outputs = self.model(inputs)
                res = count[0].item() - torch.sum(outputs).item()
                epoch_res.append(res)

        epoch_res = np.array(epoch_res)
        mse = np.sqrt(np.mean(np.square(epoch_res)))
        mae = np.mean(np.abs(epoch_res))
        logging.info(
            'Epoch {} Val, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'.format(
                self.epoch, mse, mae,
                time.time() - epoch_start))

        model_state_dic = self.model.state_dict()
        if (2.0 * mse + mae) < (2.0 * self.best_mse + self.best_mae):
            self.best_mse = mse
            self.best_mae = mae
            logging.info(
                "save best mse {:.2f} mae {:.2f} model epoch {}".format(
                    self.best_mse, self.best_mae, self.epoch))
            torch.save(model_state_dic,
                       os.path.join(self.save_dir, 'best_model.pth'))
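# The train DataLoader above relies on a custom collate function because every
# crop carries a different number of annotation points. A minimal sketch of what
# train_collate could look like (an assumption; the actual helper ships with the
# dataset code and is not shown in this listing), given that the Crowd dataset
# yields (image, points, targets, st_size) tuples:
import torch

def train_collate(batch):
    transposed_batch = list(zip(*batch))
    images = torch.stack(transposed_batch[0], 0)       # fixed-size crops stack into one tensor
    points = transposed_batch[1]                        # variable-length point tensors stay as a tuple
    targets = transposed_batch[2]
    st_sizes = torch.FloatTensor(transposed_batch[3])   # shortest-side size per image
    return images, points, targets, st_sizes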
Example #2
class MyTrainer(Trainer):
    def setup(self):
        """initial the datasets, model, loss and optimizer"""
        args = self.args
        self.skip_test = args.skip_test
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            self.device_count = torch.cuda.device_count()
            # for code conciseness, we release the single gpu version
            assert self.device_count == 1
            logging.info('using {} gpus'.format(self.device_count))
        else:
            raise Exception("gpu is not available")

        self.downsample_ratio = args.downsample_ratio
        lists = {'train': None, 'val': None, 'test': None}
        self.datasets = {x: Crowd(os.path.join(args.data_dir, x),
                                  args.crop_size,
                                  args.downsample_ratio,
                                  args.is_gray, x, args.resize,
                                  im_list=lists[x]) for x in ['train', 'val']}
        self.dataloaders = {x: DataLoader(self.datasets[x],
                                          collate_fn=(train_collate
                                                      if x == 'train' else default_collate),
                                          batch_size=(args.batch_size
                                          if x == 'train' else 1),
                                          shuffle=(True if x == 'train' else False),
                                          num_workers=args.num_workers*self.device_count,
                                          pin_memory=(True if x == 'train' else False))
                            for x in ['train', 'val']}
        self.datasets['test'] = Crowd(os.path.join(args.data_dir, 'test'),
                                    args.crop_size,
                                    args.downsample_ratio,
                                    args.is_gray, 'val', args.resize, 
                                    im_list=lists['test'])
        self.dataloaders['test'] = DataLoader(self.datasets['test'],
                                    collate_fn=default_collate,
                                    batch_size=1,
                                    shuffle=False,
                                    num_workers=args.num_workers*self.device_count,
                                    pin_memory=False)
        print(len(self.dataloaders['train']))
        print(len(self.dataloaders['val']))

        if self.args.net == 'csrnet':
            self.model = CSRNet()
        else:
            self.model = vgg19()

        self.refiner = IndivBlur8(s=args.s, downsample=self.downsample_ratio, softmax=args.soft)
        refine_params = list(self.refiner.adapt.parameters())

        self.model.to(self.device)
        self.refiner.to(self.device)
        params = list(self.model.parameters()) 
        self.optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
        # self.optimizer = optim.SGD(params, lr=args.lr, momentum=0.95, weight_decay=args.weight_decay)
        self.dml_optimizer = torch.optim.Adam(refine_params, lr=1e-7, weight_decay=args.weight_decay)

        self.start_epoch = 0
        if args.resume:
            suf = args.resume.rsplit('.', 1)[-1]
            if suf == 'tar':
                checkpoint = torch.load(args.resume, self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.refiner.load_state_dict(checkpoint['refine_state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                self.start_epoch = checkpoint['epoch'] + 1
            elif suf == 'pth':
                self.model.load_state_dict(torch.load(args.resume, self.device))

        self.crit = torch.nn.MSELoss(reduction='sum')

        self.save_list = Save_Handle(max_num=args.max_model_num)
        self.test_flag = False
        self.best_mae = {}
        self.best_mse = {}
        self.best_epoch = {}
        for stage in ['val', 'test']:
            self.best_mae[stage] = np.inf
            self.best_mse[stage] = np.inf
            self.best_epoch[stage] = 0

    def train(self):
        """training process"""
        args = self.args
        for epoch in range(self.start_epoch, args.max_epoch):
            logging.info('-'*5 + 'Epoch {}/{}'.format(epoch, args.max_epoch - 1) + '-'*5)
            self.epoch = epoch
            self.train_epoch(epoch)
            if epoch % args.val_epoch == 0 and epoch >= args.val_start:
                self.val_epoch()
                if self.test_flag and not self.skip_test:
                    self.val_epoch(stage='test')
                    self.test_flag = False

    def train_epoch(self, epoch=0):
        epoch_loss = AverageMeter()
        epoch_fore = AverageMeter()
        epoch_back = AverageMeter()
        epoch_cls_loss = AverageMeter()
        epoch_cls_acc = AverageMeter()
        epoch_fea_loss = AverageMeter()
        epoch_mae = AverageMeter()
        epoch_mse = AverageMeter()
        epoch_start = time.time()
        self.model.train()  # Set model to training mode
        self.refiner.train()  # Set model to training mode
        s_loss = None

        # Iterate over data.
        for step, (inputs, points, targets, st_sizes) in enumerate(self.dataloaders['train']):
            inputs = inputs.to(self.device)
            st_sizes = st_sizes.to(self.device)
            gd_count = np.array([len(p) for p in points], dtype=np.float32)
            points = [p.to(self.device) for p in points]
            targets = [t.to(self.device) for t in targets]

            with torch.set_grad_enabled(True):
                outputs = self.model(inputs)

                gt = self.refiner(points, inputs, outputs.shape)

                loss = self.crit(gt, outputs)
                loss += 10*cos_loss(gt, outputs)
                loss /= self.args.batch_size

                self.optimizer.zero_grad()
                self.dml_optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.dml_optimizer.step()

                pre_count = outputs[0].sum().detach().cpu().numpy()
                res = pre_count - gd_count[0]
                if step % 100 == 0:
                    print('Error: {}, Pred: {}, GT: {}, Loss: {}'.format(res, pre_count, gd_count[0], loss.item()))

                N = inputs.shape[0]
                epoch_loss.update(loss.item(), N)
                epoch_mse.update(np.mean(res * res), N)
                epoch_mae.update(np.mean(abs(res)), N)

        logging.info('Epoch {} Train, Loss: {:.2f}, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
                .format(self.epoch, epoch_loss.get_avg(), np.sqrt(epoch_mse.get_avg()), epoch_mae.get_avg(),
                    time.time()-epoch_start))
        model_state_dic = self.model.state_dict()
        save_path = os.path.join(self.save_dir, '{}_ckpt.tar'.format(self.epoch))
        torch.save({
            'epoch': self.epoch,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'model_state_dict': model_state_dic,
            'refine_state_dict': self.refiner.state_dict(),
        }, save_path)
        self.save_list.append(save_path)  # control the number of saved models

    def val_epoch(self, stage='val'):
        epoch_start = time.time()
        self.model.eval()  # Set model to evaluate mode
        self.refiner.eval()
        epoch_res = []
        epoch_fore = []
        epoch_back = []
        # Iterate over data.
        for inputs, points, name in self.dataloaders[stage]:
            inputs = inputs.to(self.device)
            # inputs are images with different sizes
            assert inputs.size(0) == 1, 'the batch size should be 1 in validation mode'
            with torch.set_grad_enabled(False):
                outputs = self.model(inputs)
                points = points[0].type(torch.LongTensor)
                res = len(points) - torch.sum(outputs).item()
                epoch_res.append(res)

        epoch_res = np.array(epoch_res)
        mse = np.sqrt(np.mean(np.square(epoch_res)))
        mae = np.mean(np.abs(epoch_res))
        logging.info('{} Epoch {}, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
                     .format(stage, self.epoch, mse, mae, time.time()-epoch_start))

        model_state_dic = self.model.state_dict()
        if (mse + mae) < (self.best_mse[stage] + self.best_mae[stage]):
            self.test_flag = True
            self.best_mse[stage] = mse
            self.best_mae[stage] = mae
            self.best_epoch[stage] = self.epoch 
            logging.info("{} save best mse {:.2f} mae {:.2f} model epoch {}".format(stage,
                                                                            self.best_mse[stage],
                                                                            self.best_mae[stage],
                                                                                 self.epoch))
            torch.save(model_state_dic, os.path.join(self.save_dir, 'best_{}.pth'.format(stage)))
        # print log info
        logging.info('Val: Best Epoch {}, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
                     .format(self.best_epoch['val'], self.best_mse['val'], self.best_mae['val'], time.time()-epoch_start))
        logging.info('Test: Best Epoch {}, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
                     .format(self.best_epoch['test'], self.best_mse['test'], self.best_mae['test'], time.time()-epoch_start))
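# cos_loss used in train_epoch above is not part of this listing. A hedged
# sketch, assuming it penalizes the cosine distance between the refined
# ground-truth map and the predicted density map (hypothetical implementation):
import torch
import torch.nn.functional as F

def cos_loss(gt, output):
    gt_flat = gt.view(gt.size(0), -1)
    out_flat = output.view(output.size(0), -1)
    # 1 - cosine similarity per sample, summed over the batch
    return (1.0 - F.cosine_similarity(out_flat, gt_flat, dim=1)).sum()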
Example #3
    def train(self):
        """training process"""
        args = self.args

        step = 0
        best_acc = 0.0
        step_loss = 0.0
        step_acc = 0
        step_count = 0
        step_start = time.time()

        save_list = Save_Handle(max_num=args.max_model_num)

        for epoch in range(self.start_epoch, args.max_epoch):
            if self.lr_scheduler is not None:
                self.lr_scheduler.step(epoch)
            logging.info('-' * 5 +
                         'Epoch {}/{}'.format(epoch, args.max_epoch - 1) +
                         '-' * 5)

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                epoch_start = time.time()
                if phase == 'train':
                    self.model.train()  # Set model to training mode
                else:
                    self.model.eval()  # Set model to evaluate mode

                epoch_loss = 0.0
                epoch_acc = 0

                # Iterate over data.
                for inputs, labels in self.dataloaders[phase]:
                    inputs = inputs.to(self.device)
                    labels = labels.to(self.device)

                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = self.model(inputs)
                        loss = self.criterion(outputs, labels)

                        _, preds = torch.max(outputs, 1)

                        temp_correct = torch.sum(preds == labels.data)
                        temp_loss_sum = loss.item() * inputs.size(0)

                        if phase == 'train':
                            self.optimizer.zero_grad()
                            loss.backward()
                            self.optimizer.step()

                            step_loss += temp_loss_sum
                            step_acc += temp_correct
                            step_count += inputs.size(0)

                            if step % args.display_step == 0:
                                step_loss = step_loss / step_count
                                step_acc = step_acc.double() / step_count
                                temp_time = time.time()
                                train_elap = temp_time - step_start
                                step_start = temp_time
                                batch_elap = train_elap / args.display_step if step != 0 else train_elap
                                samples_per_s = 1.0 * step_count / train_elap
                                logging.info(
                                    'Step {} Epoch {}, Train Loss: {:.4f} Train Acc: {:.4f}, '
                                    '{:.1f} examples/sec {:.2f} sec/batch'.
                                    format(step, epoch, step_loss, step_acc,
                                           samples_per_s, batch_elap))
                                step_loss = 0.0
                                step_acc = 0
                                step_count = 0
                            step += 1

                    # statistics
                    epoch_loss += temp_loss_sum
                    epoch_acc += temp_correct

                epoch_loss = epoch_loss / len(self.dataloaders[phase].dataset)
                epoch_acc = epoch_acc.double() / len(
                    self.dataloaders[phase].dataset)

                logging.info(
                    'Epoch {} {}, Loss: {:.4f} Acc: {:.4f}, Cost {:.1f} sec'.
                    format(epoch, phase, epoch_loss, epoch_acc,
                           time.time() - epoch_start))

                model_state_dic = (self.model.module.state_dict()
                                   if self.device_count > 1 else self.model.state_dict())
                save_path = os.path.join(self.save_dir,
                                         '{}_ckpt.tar'.format(epoch))
                torch.save(
                    {
                        'epoch': epoch,
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'model_state_dict': model_state_dic
                    }, save_path)
                save_list.append(save_path)

                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    logging.info("save best model epoch {}, acc {:.4f}".format(
                        epoch, epoch_acc))
                    torch.save(model_state_dic,
                               os.path.join(self.save_dir, 'best_model.pth'))
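# Save_Handle (used as save_list above) keeps at most max_num checkpoint files on
# disk, removing the oldest file when a new one is appended. A minimal sketch
# under that assumption (the bundled implementation may differ):
import os

class Save_Handle(object):
    def __init__(self, max_num):
        self.save_list = []
        self.max_num = max_num

    def append(self, save_path):
        if len(self.save_list) < self.max_num:
            self.save_list.append(save_path)
        else:
            remove_path = self.save_list.pop(0)   # drop the oldest checkpoint
            self.save_list.append(save_path)
            if os.path.exists(remove_path):
                os.remove(remove_path)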
class RegTrainer(Trainer):
    def setup(self):
        """initial the datasets, model, loss and optimizer"""
        args = self.args
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            self.device_count = torch.cuda.device_count()
            # for code conciseness, we release the single gpu version
            assert self.device_count == 1
            logging.info('using {} gpus'.format(self.device_count))
        else:
            raise Exception("gpu is not available")

        self.downsample_ratio = args.downsample_ratio
        self.datasets = {x: Crowd(os.path.join(args.data_dir, x),
                                  args.crop_size,
                                  args.downsample_ratio,
                                  x) for x in ['train', 'val', 'test']}
        self.dataloaders = {x: DataLoader(self.datasets[x],
                                          collate_fn=(train_collate
                                                      if x == 'train' else default_collate),
                                          batch_size=(args.batch_size
                                          if x == 'train' else 1),
                                          shuffle=(True if x == 'train' else False),
                                          num_workers=args.num_workers*self.device_count,
                                          pin_memory=(True if x == 'train' else False))
                            for x in ['train', 'val', 'test']}

        self.model = fusion_model()
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        self.start_epoch = 0
        if args.resume:
            suf = args.resume.rsplit('.', 1)[-1]
            if suf == 'tar':
                checkpoint = torch.load(args.resume, self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                self.start_epoch = checkpoint['epoch'] + 1
            elif suf == 'pth':
                self.model.load_state_dict(torch.load(args.resume, self.device))

        self.post_prob = Post_Prob(args.sigma,
                                   args.crop_size,
                                   args.downsample_ratio,
                                   args.background_ratio,
                                   args.use_background,
                                   self.device)
        self.criterion = Bay_Loss(args.use_background, self.device)
        self.save_list = Save_Handle(max_num=args.max_model_num)

        self.best_game0 = np.inf
        self.best_game3 = np.inf
        self.best_count = 0
        self.best_count_1 = 0

    def train(self):
        """training process"""
        args = self.args
        for epoch in range(self.start_epoch, args.max_epoch):
            # logging.info('save dir: '+args.save_dir)
            logging.info('-'*5 + 'Epoch {}/{}'.format(epoch, args.max_epoch - 1) + '-'*5)
            self.epoch = epoch
            self.train_epoch()
            if epoch % args.val_epoch == 0 and epoch >= args.val_start:
                game0_is_best, game3_is_best = self.val_epoch()
                if game0_is_best or game3_is_best:
                    self.test_epoch()

    def train_epoch(self):
        epoch_loss = AverageMeter()
        epoch_game = AverageMeter()
        epoch_mse = AverageMeter()
        epoch_start = time.time()
        self.model.train()  # Set model to training mode

        # Iterate over data.
        for step, (inputs, points, st_sizes) in enumerate(self.dataloaders['train']):

            if type(inputs) == list:
                inputs[0] = inputs[0].to(self.device)
                inputs[1] = inputs[1].to(self.device)
            else:
                inputs = inputs.to(self.device)
            st_sizes = st_sizes.to(self.device)
            gd_count = np.array([len(p) for p in points], dtype=np.float32)
            points = [p.to(self.device) for p in points]

            with torch.set_grad_enabled(True):
                outputs = self.model(inputs)
                prob_list = self.post_prob(points, st_sizes)
                loss = self.criterion(prob_list, outputs)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                if type(inputs) == list:
                    N = inputs[0].size(0)
                else:
                    N = inputs.size(0)
                pre_count = torch.sum(outputs.view(N, -1), dim=1).detach().cpu().numpy()
                res = pre_count - gd_count
                epoch_loss.update(loss.item(), N)
                epoch_mse.update(np.mean(res * res), N)
                epoch_game.update(np.mean(abs(res)), N)

        logging.info('Epoch {} Train, Loss: {:.2f}, GAME0: {:.2f} MSE: {:.2f}, Cost {:.1f} sec'
                     .format(self.epoch, epoch_loss.get_avg(), epoch_game.get_avg(), np.sqrt(epoch_mse.get_avg()),
                             time.time()-epoch_start))
        model_state_dic = self.model.state_dict()
        save_path = os.path.join(self.save_dir, '{}_ckpt.tar'.format(self.epoch))
        torch.save({
            'epoch': self.epoch,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'model_state_dict': model_state_dic
        }, save_path)
        self.save_list.append(save_path)  # control the number of saved models

    def val_epoch(self):
        args = self.args
        self.model.eval()  # Set model to evaluate mode

        # Iterate over data.
        game = [0, 0, 0, 0]
        mse = [0, 0, 0, 0]
        total_relative_error = 0

        for inputs, target, name in self.dataloaders['val']:
            if type(inputs) == list:
                inputs[0] = inputs[0].to(self.device)
                inputs[1] = inputs[1].to(self.device)
            else:
                inputs = inputs.to(self.device)

            # inputs are images with different sizes
            if type(inputs) == list:
                assert inputs[0].size(0) == 1
            else:
                assert inputs.size(0) == 1, 'the batch size should be 1 in validation mode'
            with torch.set_grad_enabled(False):
                outputs = self.model(inputs)

                for L in range(4):
                    abs_error, square_error = eval_game(outputs, target, L)
                    game[L] += abs_error
                    mse[L] += square_error
                relative_error = eval_relative(outputs, target)
                total_relative_error += relative_error

        N = len(self.dataloaders['val'])
        game = [m / N for m in game]
        mse = [torch.sqrt(m / N) for m in mse]
        total_relative_error = total_relative_error / N

        logging.info('Epoch {} Val ({} images), '
                     'GAME0 {game0:.2f} GAME1 {game1:.2f} GAME2 {game2:.2f} GAME3 {game3:.2f} MSE {mse:.2f} Re {relative:.4f}'
                     .format(self.epoch, N, game0=game[0], game1=game[1], game2=game[2],
                             game3=game[3], mse=mse[0], relative=total_relative_error))

        model_state_dic = self.model.state_dict()

        game0_is_best = game[0] < self.best_game0
        game3_is_best = game[3] < self.best_game3

        if game[0] < self.best_game0 or game[3] < self.best_game3:
            self.best_game3 = min(game[3], self.best_game3)
            self.best_game0 = min(game[0], self.best_game0)
            logging.info("*** Best Val GAME0 {:.3f} GAME3 {:.3f} model epoch {}".format(self.best_game0,
                                                                                    self.best_game3,
                                                                                    self.epoch))
            if args.save_all_best:
                torch.save(model_state_dic, os.path.join(self.save_dir, 'best_model_{}.pth'.format(self.best_count)))
                self.best_count += 1
            else:
                torch.save(model_state_dic, os.path.join(self.save_dir, 'best_model.pth'))

        return game0_is_best, game3_is_best

    def test_epoch(self):
        self.model.eval()  # Set model to evaluate mode

        # Iterate over data.
        game = [0, 0, 0, 0]
        mse = [0, 0, 0, 0]
        total_relative_error = 0

        for inputs, target, name in self.dataloaders['test']:
            if type(inputs) == list:
                inputs[0] = inputs[0].to(self.device)
                inputs[1] = inputs[1].to(self.device)
            else:
                inputs = inputs.to(self.device)

            # inputs are images with different sizes
            if type(inputs) == list:
                assert inputs[0].size(0) == 1
            else:
                assert inputs.size(0) == 1, 'the batch size should be 1 in validation mode'
            with torch.set_grad_enabled(False):
                outputs = self.model(inputs)

                for L in range(4):
                    abs_error, square_error = eval_game(outputs, target, L)
                    game[L] += abs_error
                    mse[L] += square_error
                relative_error = eval_relative(outputs, target)
                total_relative_error += relative_error

        N = len(self.dataloaders['test'])
        game = [m / N for m in game]
        mse = [torch.sqrt(m / N) for m in mse]
        total_relative_error = total_relative_error / N

        logging.info('Epoch {} Test ({} images), '
                     'GAME0 {game0:.2f} GAME1 {game1:.2f} GAME2 {game2:.2f} GAME3 {game3:.2f} MSE {mse:.2f} Re {relative:.4f}'
                     .format(self.epoch, N, game0=game[0], game1=game[1], game2=game[2],
                             game3=game[3], mse=mse[0], relative=total_relative_error))
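# eval_game above computes the Grid Average Mean absolute Error (GAME): the map is
# divided into 4**L non-overlapping cells and the per-cell counting errors are
# accumulated. A hedged sketch, assuming output and target are density maps of
# shape (1, 1, H, W) at the same resolution (the dataset used here may provide the
# ground truth differently):
import torch

def eval_game(output, target, L=0):
    H, W = output.shape[-2:]
    cells = 2 ** L                       # cells per side, 4**L regions in total
    abs_error, square_error = 0.0, 0.0
    for i in range(cells):
        for j in range(cells):
            h0, h1 = i * H // cells, (i + 1) * H // cells
            w0, w1 = j * W // cells, (j + 1) * W // cells
            pred = output[..., h0:h1, w0:w1].sum()
            gt = target[..., h0:h1, w0:w1].sum()
            abs_error += torch.abs(pred - gt)
            square_error += (pred - gt) ** 2
    return abs_error, square_error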
Example #5
class RegTrainer(Trainer):
    def setup(self):
        """initialize the datasets, model, loss and optimizer"""
        print("regression Trainer ---> setup")
        args = self.args
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            self.device_count = torch.cuda.device_count()
            # for code conciseness, we release the single gpu version
            assert self.device_count == 1
            logging.info('using {} gpus'.format(self.device_count))
        else:
            raise Exception("gpu is not available")

        self.downsample_ratio = args.downsample_ratio
        self.datasets = {
            x: Crowd(os.path.join(args.data_dir, x), args.crop_size,
                     args.downsample_ratio, args.is_gray, x)
            for x in ['train', 'val']
        }
        # PyTorch DataLoaders
        self.dataloaders = {
            x: DataLoader(
                self.datasets[x],  # the dataset to load from
                collate_fn=(train_collate
                            if x == 'train' else default_collate),
                batch_size=(args.batch_size
                            if x == 'train' else 1),  # samples per batch
                shuffle=(True if x == 'train' else False),  # shuffle the training data
                num_workers=args.num_workers * self.device_count,  # number of worker processes
                pin_memory=(True
                            if x == 'train' else False))  # copy into CUDA pinned memory
            for x in ['train', 'val']
        }
        print()
        self.model = vgg19()  # use the VGG-19 model
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.weight_decay)

        self.start_epoch = 0
        if args.resume:
            suf = args.resume.rsplit('.', 1)[-1]
            if suf == 'tar':
                checkpoint = torch.load(args.resume, self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer.load_state_dict(
                    checkpoint['optimizer_state_dict'])
                self.start_epoch = checkpoint['epoch'] + 1
            elif suf == 'pth':
                self.model.load_state_dict(torch.load(args.resume,
                                                      self.device))

        self.post_prob = Post_Prob(args.sigma, args.crop_size,
                                   args.downsample_ratio,
                                   args.background_ratio, args.use_background,
                                   self.device)
        self.criterion = Bay_Loss(args.use_background, self.device)
        self.save_list = Save_Handle(max_num=args.max_model_num)
        self.best_mae = np.inf
        self.best_mse = np.inf
        self.best_count = 0

    def train(self):
        """training process"""
        print("regression Trainer ---> train")
        args = self.args
        for epoch in range(self.start_epoch, args.max_epoch):
            logging.info('-' * 5 +
                         'Epoch {}/{}'.format(epoch, args.max_epoch - 1) +
                         '-' * 5)
            self.epoch = epoch
            self.train_epoch()
            if epoch % args.val_epoch == 0 and epoch >= args.val_start:
                self.val_epoch()

    '''
    train_epoch is called once per epoch; 1000 epochs are trained here.
    Initialize the loss, MAE and MSE meters, the timer and the model, then start a new round of training.
    '''

    def train_epoch(self):
        print("regression Trainer ---> train_epoch")
        epoch_loss = AverageMeter()
        epoch_mae = AverageMeter()
        epoch_mse = AverageMeter()
        epoch_start = time.time()
        self.model.train()  # Set model to training mode
        '''
        Iterate over the shuffled training data (the order is random).
        inputs, points, targets, st_sizes were defined earlier in train_collate().
        '''
        # Iterate over data.
        for step, (inputs, points, targets,
                   st_sizes) in enumerate(self.dataloaders['train']):

            inputs = inputs.to(self.device)
            st_sizes = st_sizes.to(self.device)
            gd_count = np.array([len(p) for p in points], dtype=np.float32)
            points = [p.to(self.device) for p in points]
            # print("points:")
            # print(points)
            targets = [t.to(self.device) for t in targets]
            # print("targets:")
            # print(targets)

            with torch.set_grad_enabled(True):
                outputs = self.model(inputs)
                # outputs can be reshaped into a 64x64 map and viewed as a density map,
                # but it does not match the one in the paper
                # [what is model --- the VGG-19 model]
                # [inspect what inputs are --- a tensor]
                # print(inputs)
                # [inspect what outputs are --- a tensor]
                # print(outputs)
                '''
                For every training step the output map turns out to be 64x64.
                The data here has already been processed by the convolution layers.
                The image names are unavailable and the data has been shuffled,
                so the order no longer matches the original images.
                '''
                # dm = outputs.squeeze().detach().cpu().numpy()
                # dm_nor = (dm-np.min(dm))/(np.max(dm)-np.min(dm))  # normalize
                # plt.imshow(dm_nor, cmap=cm.jet)
                # the images are replaced by tensors here and names are unavailable, so a counter is used
                # plt.savefig("D:\研究生\BayesCrowdCounting\\" + str(num))
                # print("ok!")
                '''
                Prior probability and loss, defined earlier in setup():
                self.post_prob = Post_Prob(args.sigma, args.crop_size, ...)
                self.criterion = Bay_Loss(args.use_background, self.device)
                '''
                prob_list = self.post_prob(points, st_sizes)
                loss = self.criterion(prob_list, targets, outputs)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                N = inputs.size(0)
                pre_count = torch.sum(outputs.view(N, -1),
                                      dim=1).detach().cpu().numpy()
                res = pre_count - gd_count

                epoch_loss.update(loss.item(), N)
                epoch_mse.update(np.mean(res * res), N)
                epoch_mae.update(np.mean(abs(res)), N)
        '''
        After each epoch, log the averages of the loss, MSE, MAE, etc.
        '''
        logging.info(
            'Epoch {} Train, Loss: {:.2f}, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
            .format(self.epoch, epoch_loss.get_avg(),
                    np.sqrt(epoch_mse.get_avg()), epoch_mae.get_avg(),
                    time.time() - epoch_start))
        model_state_dic = self.model.state_dict()
        save_path = os.path.join(self.save_dir,
                                 '{}_ckpt.tar'.format(self.epoch))
        torch.save(
            {
                'epoch': self.epoch,
                'optimizer_state_dict': self.optimizer.state_dict(),
                'model_state_dict': model_state_dic
            }, save_path)
        self.save_list.append(save_path)  # control the number of saved models

    def val_epoch(self):
        print("regression Trainer ---> val_epoch")
        epoch_start = time.time()
        self.model.eval()  # Set model to evaluate mode
        epoch_res = []
        # Iterate over data.
        for inputs, count, name in self.dataloaders['val']:
            inputs = inputs.to(self.device)
            # inputs are images with different sizes
            assert inputs.size(
                0) == 1, 'the batch size should be 1 in validation mode'
            with torch.set_grad_enabled(False):
                outputs = self.model(inputs)

                res = count[0].item() - torch.sum(outputs).item()
                epoch_res.append(res)

        epoch_res = np.array(epoch_res)
        mse = np.sqrt(np.mean(np.square(epoch_res)))
        mae = np.mean(np.abs(epoch_res))
        logging.info(
            'Epoch {} Val, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'.format(
                self.epoch, mse, mae,
                time.time() - epoch_start))

        model_state_dic = self.model.state_dict()
        if (2.0 * mse + mae) < (2.0 * self.best_mse + self.best_mae):
            self.best_mse = mse
            self.best_mae = mae
            logging.info(
                "save best mse {:.2f} mae {:.2f} model epoch {}".format(
                    self.best_mse, self.best_mae, self.epoch))
            torch.save(model_state_dic,
                       os.path.join(self.save_dir, 'best_model.pth'))
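# Post_Prob and Bay_Loss in the trainers above implement the Bayesian loss for
# point-supervised crowd counting: for every annotated head point a posterior over
# pixels is built from a Gaussian likelihood, and the expected count per point is
# pushed towards 1. A condensed conceptual sketch only (not the repository
# implementation; background modeling and st_size handling are omitted):
import torch

def bayesian_loss(points, density, coords, sigma=8.0):
    # points:  (N, 2) annotated head coordinates
    # density: (M,)   flattened predicted density map
    # coords:  (M, 2) pixel coordinates of the density map
    dist = torch.cdist(points, coords)                          # (N, M) point-to-pixel distances
    likelihood = torch.exp(-dist ** 2 / (2.0 * sigma ** 2))
    posterior = likelihood / (likelihood.sum(dim=0, keepdim=True) + 1e-12)  # p(y_n | x_m)
    expected_count = (posterior * density.unsqueeze(0)).sum(dim=1)          # E[c_n] per point
    return torch.abs(1.0 - expected_count).sum()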