Example #1
0
 def __init__(self, lr, train_loader, model, optimizer, scheduler, logger, device):
     """Keep references to the training components and reset the run state.

     :param lr: learning rate, stored as ``self.lr``
     :param train_loader: iterable of training batches
     :param model: network being trained
     :param optimizer: optimizer, stored as ``self.optimizer``
     :param scheduler: learning-rate scheduler, stored as ``self.scheduler``
     :param logger: logger object, stored as ``self.logger``
     :param device: device passed to ``LossFn`` and stored as ``self.device``
     """
     self.lr, self.train_loader = lr, train_loader
     self.model, self.optimizer, self.scheduler = model, optimizer, scheduler
     self.device = device
     # The loss helper is built for the same device the batches are moved to.
     self.lossfn = LossFn(device)
     self.logger = logger
     # Run counter and pending-scalar buffer both start empty.
     self.run_count, self.scalar_info = 0, {}
Example #2
0
 def __init__(self, lr, train_loader, valid_loader, model_1, model_2,
              optimizer_1, optimizer_2, scheduler_1, scheduler_2, logger,
              device):
     """Keep references to both models' training components.

     :param lr: learning rate, stored as ``self.lr``
     :param train_loader: iterable of training batches
     :param valid_loader: iterable of validation batches
     :param model_1, model_2: the two networks
     :param optimizer_1, optimizer_2: optimizers, one per network
     :param scheduler_1, scheduler_2: learning-rate schedulers, one per network
     :param logger: logger object, stored as ``self.logger``
     :param device: device passed to ``LossFn`` and stored as ``self.device``
     """
     self.lr = lr
     self.train_loader, self.valid_loader = train_loader, valid_loader
     # First network and its optimisation state.
     self.model_1, self.optimizer_1, self.scheduler_1 = (
         model_1, optimizer_1, scheduler_1)
     # Second network and its optimisation state.
     self.model_2, self.optimizer_2, self.scheduler_2 = (
         model_2, optimizer_2, scheduler_2)
     self.device = device
     self.lossfn = LossFn(device, lam=2)
     self.logger = logger
     self.run_count, self.scalar_info = 0, {}
     self.config = Config()
class PNetTrainer(object):
    """Trainer for a network with a classification head and a box-offset head.

    ``model(data)`` must return ``(cls_pred, box_offset_pred)``; each batch's
    target is a dict with ``'label'`` and ``'bbox_target'`` entries.
    """

    def __init__(self, lr, train_loader, model, optimizer, scheduler, logger,
                 device):
        self.lr = lr
        self.train_loader = train_loader
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = LossFn(self.device)
        self.logger = logger
        # Number of completed train() calls; used as the logger's step index.
        self.run_count = 0
        # Epoch scalars waiting to be written to the logger.
        self.scalar_info = {}

    def compute_accuracy(self, prob_cls, gt_cls):
        """Return binary classification accuracy as a 0-dim tensor.

        Only samples whose label is >= 0 are scored; a probability >= 0.6
        counts as a positive prediction. The result is ``nan`` when no sample
        in the batch has a label >= 0.
        """
        prob_cls = torch.squeeze(prob_cls)
        # We only need the detections whose label is >= 0.
        mask = torch.ge(gt_cls, 0)
        valid_gt_cls = torch.masked_select(gt_cls, mask)
        valid_prob_cls = torch.masked_select(prob_cls, mask)
        size = min(valid_gt_cls.size(0), valid_prob_cls.size(0))
        prob_ones = torch.ge(valid_prob_cls, 0.6).float()
        right_ones = torch.eq(prob_ones, valid_gt_cls.float()).float()
        return torch.sum(right_ones) / float(size)

    def update_lr(self, epoch):
        """
        Reset every param group of the optimizer to ``self.lr``.

        :param epoch: current training epoch (unused; kept for interface compatibility)
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def _last_lr(self):
        """Return the first param group's learning rate as last set by the scheduler.

        ``get_lr()`` is intended to be called only inside ``scheduler.step()``;
        called elsewhere, recent PyTorch warns and some schedulers (e.g.
        ``StepLR`` at a step boundary) return a value different from the one in
        use. ``get_last_lr()`` (PyTorch >= 1.4) is used when available.
        """
        get_last_lr = getattr(self.scheduler, 'get_last_lr', None)
        if get_last_lr is None:
            return self.scheduler.get_lr()[0]
        return get_last_lr()[0]

    def train(self, epoch):
        """Run one training epoch over ``self.train_loader``.

        When ``self.logger`` is set, the epoch averages and the learning rate
        are written with ``scalar_summary``; ``self.run_count`` is incremented
        either way.

        :param epoch: epoch index, used only in progress messages
        :return: ``(cls_loss, box_offset_loss, total_loss, accuracy)`` taken
            from each ``AverageMeter``'s ``avg``
        """
        cls_loss_ = AverageMeter()
        box_offset_loss_ = AverageMeter()
        total_loss_ = AverageMeter()
        accuracy_ = AverageMeter()

        # NOTE(review): the scheduler is stepped before any optimizer.step();
        # since PyTorch 1.1 the documented order is the reverse. Confirm the
        # targeted PyTorch version before moving this call.
        self.scheduler.step()
        self.model.train()

        for batch_idx, (data, target) in enumerate(self.train_loader):
            gt_label = target['label']
            gt_bbox = target['bbox_target']
            data, gt_label, gt_bbox = data.to(self.device), gt_label.to(
                self.device), gt_bbox.to(self.device).float()

            cls_pred, box_offset_pred = self.model(data)
            # compute the loss
            cls_loss = self.lossfn.cls_loss(gt_label, cls_pred)
            box_offset_loss = self.lossfn.box_loss(gt_label, gt_bbox,
                                                   box_offset_pred)

            total_loss = cls_loss + box_offset_loss
            accuracy = self.compute_accuracy(cls_pred, gt_label)

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            # Detach before accumulating: if AverageMeter sums its inputs (the
            # usual implementation), live loss tensors would chain every
            # batch's autograd graph into the running sum and keep it alive.
            batch_size = data.size(0)
            cls_loss_.update(cls_loss.detach(), batch_size)
            box_offset_loss_.update(box_offset_loss.detach(), batch_size)
            total_loss_.update(total_loss.detach(), batch_size)
            accuracy_.update(accuracy, batch_size)

            print(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.6f}'
                .format(epoch, batch_idx * len(data),
                        len(self.train_loader.dataset),
                        100. * batch_idx / len(self.train_loader),
                        total_loss.item(), accuracy.item()))

        self.scalar_info['cls_loss'] = cls_loss_.avg
        self.scalar_info['box_offset_loss'] = box_offset_loss_.avg
        self.scalar_info['total_loss'] = total_loss_.avg
        self.scalar_info['accuracy'] = accuracy_.avg
        self.scalar_info['lr'] = self._last_lr()

        if self.logger is not None:
            for tag, value in list(self.scalar_info.items()):
                self.logger.scalar_summary(tag, value, self.run_count)
            self.scalar_info = {}
        self.run_count += 1

        print("|===>Loss: {:.4f}".format(total_loss_.avg))
        return cls_loss_.avg, box_offset_loss_.avg, total_loss_.avg, accuracy_.avg
Example #4
0
class AlexNetTrainer(object):
    """Trainer for a two-stage classifier.

    ``model_1(data)`` returns ``(x, mask)``; part crops are produced from
    ``mask`` and the input by the external helpers ``part_box`` and
    ``get_part``, and ``model_2(img_parts, x)`` returns the class scores.
    Both models are updated from the same classification loss.
    """

    def __init__(self, lr, train_loader, valid_loader, model_1, model_2,
                 optimizer_1, optimizer_2, scheduler_1, scheduler_2, logger,
                 device):
        self.lr = lr
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.model_1 = model_1
        self.optimizer_1 = optimizer_1
        self.scheduler_1 = scheduler_1
        self.model_2 = model_2
        self.optimizer_2 = optimizer_2
        self.scheduler_2 = scheduler_2
        self.device = device
        self.lossfn = LossFn(self.device, lam=2)
        self.logger = logger
        self.run_count = 0
        self.scalar_info = {}
        self.config = Config()
        # Optional debugging hook: uncomment to print gradients via for_hook.
        # self.handle = self.model_2.channelgroup_2.group[0].register_backward_hook(self.for_hook)

    def compute_accuracy(self, prob_cls, gt_cls):
        """Return top-1 accuracy (a Python float) of ``prob_cls`` against ``gt_cls``.

        Scores are reduced along dim 1, so ``prob_cls`` is expected to be
        (batch, classes).
        """
        pred_cls = torch.max(prob_cls, 1)[1].squeeze()
        accuracy = float((pred_cls == gt_cls).sum()) / float(gt_cls.size(0))
        return accuracy

    def update_lr(self, epoch):
        """
        Reset every param group of both optimizers to ``self.lr``.

        :param epoch: current training epoch (unused; kept for interface compatibility)
        """
        for param_group in self.optimizer_1.param_groups:
            param_group['lr'] = self.lr
        for param_group in self.optimizer_2.param_groups:
            param_group['lr'] = self.lr

    def _last_lr(self):
        """Return ``scheduler_1``'s current learning rate for the first param group.

        ``get_last_lr()`` (PyTorch >= 1.4) reports the value in use; calling
        ``get_lr()`` outside ``step()`` warns on recent PyTorch and can return a
        different value for some schedulers, so it is only a fallback.
        """
        get_last_lr = getattr(self.scheduler_1, 'get_last_lr', None)
        if get_last_lr is None:
            return self.scheduler_1.get_lr()[0]
        return get_last_lr()[0]

    def show_image_grid(self, img_origin, img_parts, parts, epoch):
        """Save debug images of a batch and its extracted parts.

        Writes ``img_origin_grid.jpg``, ``img_parts_grid.jpg``, ``grid.npy``,
        ``parts.npy`` and ``plt_image_boxs_epoch<epoch>.jpg`` under
        ``self.config.save_path``. The last one is the origin grid with a box
        drawn around each image's part location.

        :param img_origin: input batch, assumed (B, C, H, W) -- TODO confirm
        :param img_parts: part crops tensor
        :param parts: numpy array of part coordinates; after the transpose
            below, row ``i`` holds ``(x, y)`` for image ``i``
        :param epoch: epoch index used in the output file name
        """
        nrow = 8
        padding = 2
        print(img_parts.size())
        # NOTE(review): ``range=`` was renamed ``value_range`` in newer
        # torchvision; confirm the installed version accepts it.
        torchvision.utils.save_image(img_origin,
                                     self.config.save_path +
                                     '/img_origin_grid.jpg',
                                     nrow=nrow,
                                     padding=padding,
                                     normalize=True,
                                     range=(0, 1))

        torchvision.utils.save_image(img_parts,
                                     self.config.save_path + '/img_parts' +
                                     '_grid.jpg',
                                     nrow=nrow,
                                     padding=padding,
                                     normalize=True,
                                     range=(0, 1))

        print(type(parts))
        print(parts.shape)

        grid = torchvision.utils.make_grid(img_origin,
                                           nrow=nrow,
                                           padding=padding,
                                           normalize=True,
                                           range=(0, 1))

        # Each grid cell is exactly one input image (114x114 for this data).
        height, width = int(img_origin.size(-2)), int(img_origin.size(-1))
        parts = parts.transpose((1, 0))

        # save grid and parts as numpy
        grid = grid.cpu().numpy().transpose((1, 2, 0))
        np.save(self.config.save_path + '/grid.npy', grid)
        np.save(self.config.save_path + '/parts.npy', parts)

        img = cv2.cvtColor(grid * 255, cv2.COLOR_RGB2BGR)
        half_side = 24  # half the side length of each drawn box, in pixels
        colors = [(229, 187, 129), (161, 23, 21), (34, 8, 7), (118, 77, 57)]
        for i in range(parts.shape[0]):
            # make_grid places image i at (padding + col * (width + padding),
            # padding + row * (height + padding)); include the leading padding
            # so each box lines up with its cell.
            off_x = padding + (i % nrow) * (width + padding)
            off_y = padding + (i // nrow) * (height + padding)
            x1 = np.maximum(0, parts[i, 0] - half_side)
            y1 = np.maximum(0, parts[i, 1] - half_side)
            x2 = np.minimum(width, parts[i, 0] + half_side)
            y2 = np.minimum(height, parts[i, 1] + half_side)
            cv2.rectangle(img, (int(x1 + off_x), int(y1 + off_y)),
                          (int(x2 + off_x), int(y2 + off_y)), colors[0], 2)

        cv2.imwrite(self.config.save_path + '/plt_image_boxs_epoch' +
                    str(epoch) + '.jpg', img)  # save image

    def show_mask(self, mask, epoch):
        """Save the min-max normalised, inverted mask batch as ``mask_epoch<epoch>.jpg``.

        Each sample's mask is viewed as a 6x6 single-channel image -- assumes
        the mask holds 36 values per sample; TODO confirm against model_1.
        Division is by zero (giving nan) if the mask is constant.
        """
        # Invert after normalising so that high mask values render dark.
        img = (1 - (mask - torch.min(mask)) /
               (torch.max(mask) - torch.min(mask)))
        grid = torchvision.utils.make_grid(img.view(-1, 1, 6, 6),
                                           nrow=8,
                                           padding=2,
                                           normalize=True,
                                           range=(0, 1))
        img = grid.cpu().numpy().transpose((1, 2, 0))
        img = cv2.cvtColor(img * 255, cv2.COLOR_RGB2BGR)
        cv2.imwrite(self.config.save_path + '/mask_' + 'epoch' + str(epoch) +
                    '.jpg', img)  # save image

    def for_hook(self, module, input_grad, output_grad):
        """Backward hook that prints gradient shapes and a slice of values (debug only)."""
        print('\r\nhook:\r\n')
        print(len(input_grad))
        print(len(output_grad))
        print(input_grad[0].size())
        print(input_grad[1].size())
        print(output_grad[0].size())
        print(input_grad[0][5][5:10])
        print(output_grad[0][5][5:10])

    def train(self, epoch):
        """Train both models for one epoch, then measure validation accuracy.

        :param epoch: epoch index; debug images are saved on the second batch
            of epochs 1, 5, 10 and 15, and it appears in progress messages
        :return: ``(train_cls_loss, train_accuracy, valid_accuracy)`` averages
        """
        cls_loss_ = AverageMeter()
        accuracy_ = AverageMeter()
        accuracy_valid_ = AverageMeter()

        # Training set as model input.
        # NOTE(review): schedulers are stepped before optimizer.step(); since
        # PyTorch 1.1 the documented order is the reverse -- confirm version.
        self.scheduler_1.step()
        self.scheduler_2.step()
        self.model_1.train()
        self.model_2.train()

        for batch_idx, (data, gt_label) in enumerate(self.train_loader):

            data, gt_label = data.to(self.device), gt_label.to(self.device)
            x, mask = self.model_1(data)

            # Part extraction is not differentiated through.
            with torch.no_grad():
                parts = part_box(mask)
                img_parts, parts = get_part(data.cpu(),
                                            parts)  # (1, 64, 48, 48)
                img_parts = torch.from_numpy(img_parts).view(
                    img_parts.shape[0], 1, 48,
                    48).to(self.device)  # view(64, 1, 48, 48)

                if epoch in (1, 5, 10, 15) and batch_idx == 1:
                    self.show_image_grid(data, img_parts, parts, epoch)
                    self.show_mask(mask, epoch)
                    print('save image and parts in result: ' +
                          self.config.save_path)
                    print('epoch: ' + str(epoch))
                    print('batch_idx: ' + str(batch_idx))

            cls_pred = self.model_2(img_parts, x)

            # compute the loss
            cls_loss = self.lossfn.cls_loss(gt_label, cls_pred)
            accuracy = self.compute_accuracy(cls_pred, gt_label)

            # Always true for non-negative epochs; presumably a switch for
            # freezing training in early epochs -- kept as-is.
            if epoch >= 0:
                self.optimizer_1.zero_grad()
                self.optimizer_2.zero_grad()
                cls_loss.backward()
                self.optimizer_1.step()
                self.optimizer_2.step()

            cls_loss_.update(cls_loss.item(), data.size(0))
            accuracy_.update(accuracy, data.size(0))

            if batch_idx % 2000 == 1:
                print('batch_idx: ', batch_idx)
                print('Cls loss: ', cls_loss.item())

        # Validation set as model input.
        with torch.no_grad():
            self.model_1.eval()
            self.model_2.eval()

            for batch_idx, (data, gt_label) in enumerate(self.valid_loader):
                data, gt_label = data.to(self.device), gt_label.to(self.device)

                x, mask = self.model_1(data)

                parts = part_box(mask)
                img_parts, parts = get_part(data.cpu(),
                                            parts)  # (4, 64, 48, 48)

                img_parts = torch.from_numpy(img_parts).view(
                    img_parts.shape[0], 1, 48, 48).to(self.device)

                cls_pred = self.model_2(img_parts, x)

                accuracy_valid = self.compute_accuracy(cls_pred, gt_label)
                accuracy_valid_.update(accuracy_valid, data.size(0))

            # Record the epoch scalars (logging to self.logger is disabled below).
            self.scalar_info['cls_loss'] = cls_loss_.avg
            self.scalar_info['accuracy'] = accuracy_.avg
            self.scalar_info['lr'] = self._last_lr()

            # if self.logger is not None:
            #     for tag, value in list(self.scalar_info.items()):
            #         self.logger.scalar_summary(tag, value, self.run_count)
            #     self.scalar_info = {}
            # self.run_count += 1

        print(
            "\r\nEpoch: {}|===>Train Loss: {:.8f}   Train Accuracy: {:.6f}   valid Accuracy: {:.6f}\r\n"
            .format(epoch, cls_loss_.avg, accuracy_.avg, accuracy_valid_.avg))

        return cls_loss_.avg, accuracy_.avg, accuracy_valid_.avg
Example #5
0
class ONetTrainer(object):
    """Trainer for a network with classification, box-offset and landmark heads.

    ``model(data)`` must return ``(cls_pred, box_offset_pred,
    landmark_offset_pred)``; each batch's target is a dict with ``'label'``,
    ``'bbox_target'`` and ``'landmark_target'`` entries.
    """

    def __init__(self, lr, train_loader, model, optimizer, scheduler, logger, device):
        self.lr = lr
        self.train_loader = train_loader
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = LossFn(self.device)
        self.logger = logger
        # Number of completed train() calls; used as the logger's step index.
        self.run_count = 0
        # Epoch scalars waiting to be written to the logger.
        self.scalar_info = {}

    def compute_accuracy(self, prob_cls, gt_cls):
        """Return binary classification accuracy as a 0-dim tensor.

        Only samples whose label is >= 0 are scored; a probability >= 0.6
        counts as a positive prediction. The result is ``nan`` when no sample
        in the batch has a label >= 0.
        """
        prob_cls = torch.squeeze(prob_cls)
        # We only need the detections whose label is >= 0.
        mask = torch.ge(gt_cls, 0)
        valid_gt_cls = torch.masked_select(gt_cls, mask)
        valid_prob_cls = torch.masked_select(prob_cls, mask)
        size = min(valid_gt_cls.size(0), valid_prob_cls.size(0))
        prob_ones = torch.ge(valid_prob_cls, 0.6).float()
        right_ones = torch.eq(prob_ones, valid_gt_cls.float()).float()
        return torch.sum(right_ones) / float(size)

    def update_lr(self, epoch):
        """
        Reset every param group of the optimizer to ``self.lr``.

        :param epoch: current training epoch (unused; kept for interface compatibility)
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def _last_lr(self):
        """Return the first param group's learning rate as last set by the scheduler.

        ``get_last_lr()`` (PyTorch >= 1.4) reports the value in use; calling
        ``get_lr()`` outside ``step()`` warns on recent PyTorch and can return a
        different value for some schedulers, so it is only a fallback.
        """
        get_last_lr = getattr(self.scheduler, 'get_last_lr', None)
        if get_last_lr is None:
            return self.scheduler.get_lr()[0]
        return get_last_lr()[0]

    def train(self, epoch):
        """Run one training epoch over ``self.train_loader``.

        The total loss is ``cls + 0.5 * box + landmark``. When ``self.logger``
        is set, the epoch averages and learning rate are written with
        ``scalar_summary``; ``self.run_count`` is incremented either way.

        :param epoch: epoch index, used only in progress messages
        :return: ``(cls_loss, box_offset_loss, landmark_loss, total_loss,
            accuracy)`` taken from each ``AverageMeter``'s ``avg``
        """
        cls_loss_ = AverageMeter()
        box_offset_loss_ = AverageMeter()
        landmark_loss_ = AverageMeter()
        total_loss_ = AverageMeter()
        accuracy_ = AverageMeter()

        # NOTE(review): the scheduler is stepped before any optimizer.step();
        # since PyTorch 1.1 the documented order is the reverse -- confirm version.
        self.scheduler.step()
        self.model.train()

        for batch_idx, (data, target) in enumerate(self.train_loader):
            gt_label = target['label']
            gt_bbox = target['bbox_target']
            gt_landmark = target['landmark_target']
            data, gt_label, gt_bbox, gt_landmark = data.to(self.device), gt_label.to(
                self.device), gt_bbox.to(self.device).float(), gt_landmark.to(self.device).float()

            cls_pred, box_offset_pred, landmark_offset_pred = self.model(data)
            # compute the loss
            cls_loss = self.lossfn.cls_loss(gt_label, cls_pred)
            box_offset_loss = self.lossfn.box_loss(
                gt_label, gt_bbox, box_offset_pred)
            landmark_loss = self.lossfn.landmark_loss(gt_label, gt_landmark, landmark_offset_pred)

            total_loss = cls_loss + box_offset_loss * 0.5 + landmark_loss
            accuracy = self.compute_accuracy(cls_pred, gt_label)

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            # Detach before accumulating so the meters (which presumably sum
            # their inputs) do not hold on to each batch's autograd graph.
            batch_size = data.size(0)
            cls_loss_.update(cls_loss.detach(), batch_size)
            box_offset_loss_.update(box_offset_loss.detach(), batch_size)
            landmark_loss_.update(landmark_loss.detach(), batch_size)
            total_loss_.update(total_loss.detach(), batch_size)
            accuracy_.update(accuracy, batch_size)

            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.6f}'.format(
                epoch, batch_idx * len(data), len(self.train_loader.dataset),
                100. * batch_idx / len(self.train_loader), total_loss.item(), accuracy.item()))

        self.scalar_info['cls_loss'] = cls_loss_.avg
        self.scalar_info['box_offset_loss'] = box_offset_loss_.avg
        self.scalar_info['landmark_loss'] = landmark_loss_.avg
        self.scalar_info['total_loss'] = total_loss_.avg
        self.scalar_info['accuracy'] = accuracy_.avg
        self.scalar_info['lr'] = self._last_lr()

        if self.logger is not None:
            for tag, value in list(self.scalar_info.items()):
                self.logger.scalar_summary(tag, value, self.run_count)
            self.scalar_info = {}
        self.run_count += 1

        print("|===>Loss: {:.4f}".format(total_loss_.avg))
        return cls_loss_.avg, box_offset_loss_.avg, landmark_loss_.avg, total_loss_.avg, accuracy_.avg
Example #6
0
class FastCNNTrainer(object):
    """Single-model classifier trainer with a validation pass.

    ``model(data)`` must return ``(cls_pred, feature)``.
    """

    def __init__(self, lr, train_loader, valid_loader, model, optimizer, scheduler, logger, device):
        self.lr = lr
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = LossFn(self.device)
        self.logger = logger
        self.run_count = 0
        self.scalar_info = {}

    def compute_accuracy(self, prob_cls, gt_cls):
        """Return top-1 accuracy (a Python float) of ``prob_cls`` against ``gt_cls``.

        Scores are reduced along dim 1, so ``prob_cls`` is expected to be
        (batch, classes).
        """
        pred_cls = torch.max(prob_cls, 1)[1].squeeze()
        accuracy = float((pred_cls == gt_cls).sum()) / float(gt_cls.size(0))
        return accuracy

    def update_lr(self, epoch):
        """
        Reset every param group of the optimizer to ``self.lr``.

        :param epoch: current training epoch (unused; kept for interface compatibility)
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def _last_lr(self):
        """Return the first param group's learning rate as last set by the scheduler.

        ``get_last_lr()`` (PyTorch >= 1.4) reports the value in use; calling
        ``get_lr()`` outside ``step()`` warns on recent PyTorch and can return a
        different value for some schedulers, so it is only a fallback.
        """
        get_last_lr = getattr(self.scheduler, 'get_last_lr', None)
        if get_last_lr is None:
            return self.scheduler.get_lr()[0]
        return get_last_lr()[0]

    def train(self, epoch):
        """Train for one epoch, then measure validation accuracy.

        :param epoch: epoch index, used in the summary message
        :return: ``(train_cls_loss, train_accuracy, valid_accuracy)`` averages
        """
        cls_loss_ = AverageMeter()
        accuracy_ = AverageMeter()
        accuracy_valid_ = AverageMeter()

        # Training set as model input.
        # NOTE(review): the scheduler is stepped before any optimizer.step();
        # since PyTorch 1.1 the documented order is the reverse -- confirm version.
        self.scheduler.step()
        self.model.train()

        for batch_idx, (data, gt_label) in enumerate(self.train_loader):

            data, gt_label = data.to(self.device), gt_label.to(
                self.device)

            cls_pred, feature = self.model(data)
            # compute the loss
            cls_loss = self.lossfn.cls_loss(gt_label, cls_pred)
            accuracy = self.compute_accuracy(cls_pred, gt_label)

            self.optimizer.zero_grad()
            cls_loss.backward()
            self.optimizer.step()

            cls_loss_.update(cls_loss.item(), data.size(0))
            accuracy_.update(accuracy, data.size(0))

            if batch_idx % 20 == 10:
                print(batch_idx)
                print(cls_loss.item())

        # Validation set as model input.
        with torch.no_grad():
            self.model.eval()

            for batch_idx, (data, gt_label) in enumerate(self.valid_loader):
                data, gt_label = data.to(self.device), gt_label.to(
                    self.device)

                cls_pred, feature = self.model(data)
                accuracy_valid = self.compute_accuracy(cls_pred, gt_label)
                accuracy_valid_.update(accuracy_valid, data.size(0))

            # Record the epoch scalars (logging to self.logger is disabled below).
            self.scalar_info['cls_loss'] = cls_loss_.avg
            self.scalar_info['accuracy'] = accuracy_.avg
            self.scalar_info['lr'] = self._last_lr()

            # if self.logger is not None:
            #     for tag, value in list(self.scalar_info.items()):
            #         self.logger.scalar_summary(tag, value, self.run_count)
            #     self.scalar_info = {}
            # self.run_count += 1

        print("\r\nEpoch: {}|===>Train Loss: {:.8f}   Train Accuracy: {:.6f}   valid Accuracy: {:.6f}\r\n"
              .format(epoch, cls_loss_.avg, accuracy_.avg, accuracy_valid_.avg))

        return cls_loss_.avg, accuracy_.avg, accuracy_valid_.avg

    def calculate_mean_feature(self, num_classes=8, feature_dim=384 * 3 * 3):
        """Return the per-class mean of the model's feature output over ``valid_loader``.

        :param num_classes: number of label values; labels must lie in
            ``[0, num_classes)``
        :param feature_dim: number of elements per sample in the feature output
            (the feature is flattened per sample before accumulation)
        :return: CPU float tensor of shape ``(num_classes, feature_dim)``; row
            ``k`` is the mean feature of samples labelled ``k`` (``nan`` when
            no sample had that label)
        """
        with torch.no_grad():
            self.model.eval()

            feature_sum = torch.zeros((num_classes, feature_dim))
            label_count = torch.zeros(num_classes)

            for data, gt_label in self.valid_loader:
                _, feature = self.model(data.to(self.device))
                # Accumulate on the CPU where the accumulators live; mixing
                # them with GPU tensors would raise a device-mismatch error.
                feature = feature.reshape(feature.size(0), -1)
                feature = feature.to(feature_sum.dtype).cpu()
                labels = gt_label.long().cpu()
                feature_sum.index_add_(0, labels, feature)
                label_count.index_add_(0, labels, torch.ones(labels.size(0)))

            feature_mean = feature_sum / label_count.unsqueeze(1)

        return feature_mean
Example #7
0
class ONetTrainer(object):
    """Trainer for a network with classification, box, landmark and attribute heads.

    ``model(data)`` must return ``(cls_pred, box_offset_pred,
    landmark_offset_pred, attr_pred)``; each batch's target is a dict with
    ``'label'``, ``'bbox_target'``, ``'landmark_target'`` and ``'attribute'``.
    """

    def __init__(self, lr, train_loader, model, optimizer, scheduler, logger, device):
        self.lr = lr
        self.train_loader = train_loader
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = LossFn(self.device)
        self.logger = logger
        # Number of completed train() calls; used as the logger's step index.
        self.run_count = 0
        # Epoch scalars waiting to be written to the logger.
        self.scalar_info = {}

    def compute_accuracy(self, prob_cls, gt_cls, prob_attr, gt_attr):
        """Return ``(accuracy_cls, accuracy_color, accuracy_layer, accuracy_type)``.

        * ``accuracy_cls``: binary accuracy over samples with label >= 0, a
          probability >= 0.6 counting as positive.
        * attribute accuracies: over samples with label == -2. Per sample,
          ``prob_attr`` columns 0-4 score colour, 5-6 layer and 7+ type; the
          arg-max of each slice is compared with ``gt_attr[:, 0]``,
          ``gt_attr[:, 1]`` and ``gt_attr[:, 2]`` respectively.

        Each value is a 0-dim tensor and is ``nan`` when its sample subset is
        empty.
        """
        # Classification accuracy; we only need the detections with label >= 0.
        prob_cls = torch.squeeze(prob_cls)
        mask = torch.ge(gt_cls, 0)
        valid_gt_cls = torch.masked_select(gt_cls, mask)
        valid_prob_cls = torch.masked_select(prob_cls, mask)
        size = min(valid_gt_cls.size(0), valid_prob_cls.size(0))
        prob_ones = torch.ge(valid_prob_cls, 0.6).float()
        right_ones = torch.eq(prob_ones, valid_gt_cls.float()).float()
        accuracy_cls = torch.sum(right_ones) / float(size)

        # Attribute accuracy over samples labelled -2. Boolean-mask indexing
        # keeps the row dimension even when exactly one sample matches, and
        # reshaping (instead of squeezing) keeps the batch dimension at batch
        # size 1, so the column slicing below always sees a 2-D tensor.
        mask_attr = torch.eq(gt_cls, -2)
        prob_attr = prob_attr.reshape(prob_attr.size(0), -1)
        valid_gt_attr = gt_attr[mask_attr]
        valid_prob_attr = prob_attr[mask_attr]
        size_attr = min(valid_gt_attr.size(0), valid_prob_attr.size(0))

        pred_color = torch.max(valid_prob_attr[:, :5], 1)[1]
        pred_layer = torch.max(valid_prob_attr[:, 5:7], 1)[1]
        pred_type = torch.max(valid_prob_attr[:, 7:], 1)[1]
        color_right_ones = torch.eq(pred_color, valid_gt_attr[:, 0]).float()
        layer_right_ones = torch.eq(pred_layer, valid_gt_attr[:, 1]).float()
        type_right_ones = torch.eq(pred_type, valid_gt_attr[:, 2]).float()
        accuracy_color = torch.sum(color_right_ones) / float(size_attr)
        accuracy_layer = torch.sum(layer_right_ones) / float(size_attr)
        accuracy_type = torch.sum(type_right_ones) / float(size_attr)

        return accuracy_cls, accuracy_color, accuracy_layer, accuracy_type

    def update_lr(self, epoch):
        """
        Reset every param group of the optimizer to ``self.lr``.

        :param epoch: current training epoch (unused; kept for interface compatibility)
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def _last_lr(self):
        """Return the first param group's learning rate as last set by the scheduler.

        ``get_last_lr()`` (PyTorch >= 1.4) reports the value in use; calling
        ``get_lr()`` outside ``step()`` warns on recent PyTorch and can return a
        different value for some schedulers, so it is only a fallback.
        """
        get_last_lr = getattr(self.scheduler, 'get_last_lr', None)
        if get_last_lr is None:
            return self.scheduler.get_lr()[0]
        return get_last_lr()[0]

    def train(self, epoch):
        """Run one training epoch over ``self.train_loader``.

        The total loss is ``cls + 0.5 * box + landmark + 0.5 * (color + layer
        + type)``. When ``self.logger`` is set, the epoch averages and learning
        rate are written with ``scalar_summary``; ``self.run_count`` is
        incremented either way.

        :param epoch: epoch index, used only in progress messages
        :return: epoch averages ``(cls_loss, box_offset_loss, landmark_loss,
            total_loss, accuracy_cls, accuracy_color, accuracy_layer,
            accuracy_type)``
        """
        cls_loss_ = AverageMeter()
        box_offset_loss_ = AverageMeter()
        landmark_loss_ = AverageMeter()
        total_loss_ = AverageMeter()
        accuracy_cls_ = AverageMeter()
        accuracy_color_ = AverageMeter()
        accuracy_layer_ = AverageMeter()
        accuracy_type_ = AverageMeter()

        # NOTE(review): the scheduler is stepped before any optimizer.step();
        # since PyTorch 1.1 the documented order is the reverse -- confirm version.
        self.scheduler.step()
        self.model.train()

        for batch_idx, (data, target) in enumerate(self.train_loader):
            gt_label = target['label']
            gt_bbox = target['bbox_target']
            gt_landmark = target['landmark_target']
            gt_attr = target['attribute']
            data, gt_label, gt_bbox, gt_landmark, gt_attr = data.to(self.device), gt_label.to(
                self.device), gt_bbox.to(self.device).float(), gt_landmark.to(
                self.device).float(), gt_attr.to(self.device).long()

            cls_pred, box_offset_pred, landmark_offset_pred, attr_pred = self.model(data)
            # compute the loss
            cls_loss = self.lossfn.cls_loss(gt_label, cls_pred)
            box_offset_loss = self.lossfn.box_loss(
                gt_label, gt_bbox, box_offset_pred)
            landmark_loss = self.lossfn.landmark_loss(gt_label, gt_landmark, landmark_offset_pred)
            color_loss, layer_loss, type_loss = self.lossfn.attr_loss(gt_label, gt_attr, attr_pred)

            total_loss = cls_loss + box_offset_loss * 0.5 + landmark_loss + color_loss*0.5 + layer_loss*0.5 + type_loss*0.5
            accuracy_cls, accuracy_color, accuracy_layer, accuracy_type = self.compute_accuracy(
                cls_pred, gt_label, attr_pred, gt_attr)

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            # Detach before accumulating so the meters (which presumably sum
            # their inputs) do not hold on to each batch's autograd graph.
            batch_size = data.size(0)
            cls_loss_.update(cls_loss.detach(), batch_size)
            box_offset_loss_.update(box_offset_loss.detach(), batch_size)
            landmark_loss_.update(landmark_loss.detach(), batch_size)
            total_loss_.update(total_loss.detach(), batch_size)
            accuracy_cls_.update(accuracy_cls, batch_size)
            accuracy_color_.update(accuracy_color, batch_size)
            accuracy_layer_.update(accuracy_layer, batch_size)
            accuracy_type_.update(accuracy_type, batch_size)

            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, cls_loss: {:.6f}, box_loss: {:.6f}, landmark_loss: {:.6f}, color_loss: {:.6f}, layer_loss: {:.6f}, type_loss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(self.train_loader.dataset),
                100. * batch_idx / len(self.train_loader),
                total_loss.item(), cls_loss.item(), box_offset_loss.item(), landmark_loss.item(),
                color_loss.item(), layer_loss.item(), type_loss.item()))
            print('Accuracy_cls: {:.6f}, Accuracy_color: {:.6f}, Accuracy_layer: {:.6f}, Accuracy_type: {:.6f}'.format(
                accuracy_cls.item(), accuracy_color.item(), accuracy_layer.item(), accuracy_type.item()))

        self.scalar_info['cls_loss'] = cls_loss_.avg
        self.scalar_info['box_offset_loss'] = box_offset_loss_.avg
        self.scalar_info['landmark_loss'] = landmark_loss_.avg
        self.scalar_info['total_loss'] = total_loss_.avg
        self.scalar_info['accuracy_cls'] = accuracy_cls_.avg
        self.scalar_info['accuracy_color'] = accuracy_color_.avg
        self.scalar_info['accuracy_layer'] = accuracy_layer_.avg
        self.scalar_info['accuracy_type'] = accuracy_type_.avg
        self.scalar_info['lr'] = self._last_lr()

        if self.logger is not None:
            for tag, value in list(self.scalar_info.items()):
                self.logger.scalar_summary(tag, value, self.run_count)
            self.scalar_info = {}
        self.run_count += 1

        print("|===>Loss: {:.4f}".format(total_loss_.avg))
        return cls_loss_.avg, box_offset_loss_.avg, landmark_loss_.avg, total_loss_.avg, accuracy_cls_.avg, accuracy_color_.avg, accuracy_layer_.avg, accuracy_type_.avg
Example #8
0
class LeNet_5Trainer(object):
    """Single-model classifier trainer; ``model(data)`` returns class scores."""

    def __init__(self, lr, train_loader, valid_x, valid_y, model, optimizer,
                 scheduler, logger, device):
        self.lr = lr
        self.train_loader = train_loader
        # Validation data is stored but not used by train().
        self.valid_x = valid_x
        self.valid_y = valid_y
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.lossfn = LossFn(self.device)
        self.logger = logger
        self.run_count = 0
        self.scalar_info = {}

    def compute_accuracy(self, prob_cls, gt_cls):
        """Return top-1 accuracy (a Python float) of ``prob_cls`` against ``gt_cls``.

        Scores are reduced along dim 1, so ``prob_cls`` is expected to be
        (batch, classes).
        """
        pred_cls = torch.max(prob_cls, 1)[1].squeeze()
        accuracy = float((pred_cls == gt_cls).sum()) / float(gt_cls.size(0))
        return accuracy

    def train(self, epoch):
        """Run one training epoch over ``self.train_loader``.

        Progress is printed every 50 batches; the scheduler is not stepped and
        ``self.lr`` is recorded as the learning rate.

        :param epoch: epoch index, used only in progress messages
        :return: ``(cls_loss, accuracy)`` taken from each ``AverageMeter``'s ``avg``
        """
        cls_loss_ = AverageMeter()
        accuracy_ = AverageMeter()

        self.model.train()

        for batch_idx, (data, gt_label) in enumerate(self.train_loader):
            data, gt_label = data.to(self.device), gt_label.to(self.device)

            cls_pred = self.model(data)
            # compute the loss
            cls_loss = self.lossfn.cls_loss(gt_label, cls_pred)
            accuracy = self.compute_accuracy(cls_pred, gt_label)

            self.optimizer.zero_grad()
            cls_loss.backward()
            self.optimizer.step()

            # Detach before accumulating so the meter (which presumably sums
            # its inputs) does not hold on to each batch's autograd graph.
            cls_loss_.update(cls_loss.detach(), data.size(0))
            accuracy_.update(accuracy, data.size(0))

            if batch_idx % 50 == 0:

                print(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}\tTrain Accuracy: {:.6f}'
                    .format(epoch, batch_idx * len(data),
                            len(self.train_loader.dataset),
                            100. * batch_idx / len(self.train_loader),
                            cls_loss.item(), accuracy))

        self.scalar_info['cls_loss'] = cls_loss_.avg
        self.scalar_info['accuracy'] = accuracy_.avg
        self.scalar_info['lr'] = self.lr

        # if self.logger is not None:
        #     for tag, value in list(self.scalar_info.items()):
        #         self.logger.scalar_summary(tag, value, self.run_count)
        #     self.scalar_info = {}
        # self.run_count += 1

        print("|===>Loss: {:.4f}   Train Accuracy: {:.6f} ".format(
            cls_loss_.avg, accuracy_.avg))

        return cls_loss_.avg, accuracy_.avg