예제 #1
0
def trainer(ops):
    """Train one MTCNN stage (P-Net, R-Net or O-Net) as configured by ``ops``.

    ``ops`` is an options object read for the attributes ``seed``, ``pattern``
    ('P-Net' | 'R-Net' | 'O-Net'), ``path_img``, ``path_anno``,
    ``batch_size``, ``num_workers``, ``Optimizer_X`` ('Adam' | 'SGD' |
    'RMSprop'), ``init_lr``, ``ft_model`` (checkpoint loaded if the file
    exists), ``epochs`` and ``ckpt`` (string prefix that checkpoint file
    names are appended to).

    At the start of every epoch after the first, the previous epoch's mean
    loss is compared with the best seen so far. On the fifth consecutive
    check without improvement, the learning rate is halved.

    Checkpoints:
    - Every 50 iterations (i > 1), the state dict is written to
      ``<ckpt><pattern>_latest.pth``.
    - Every 80 iterations, it is written to
      ``<ckpt><pattern>_epoch-<epoch>.pth``.

    Raises:
        ValueError: if ``ops.pattern`` or ``ops.Optimizer_X`` is unsupported.
    """
    # Validate configuration before doing any work. Previously an unknown
    # value only surfaced later as an UnboundLocalError on m_XNet/optimizer.
    if ops.pattern not in ('P-Net', 'R-Net', 'O-Net'):
        raise ValueError('unsupported pattern: {!r}'.format(ops.pattern))
    if ops.Optimizer_X not in ('Adam', 'SGD', 'RMSprop'):
        raise ValueError(
            '------>>> Optimizer init error : {!r}'.format(ops.Optimizer_X))

    set_seed(ops.seed)

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')

    if ops.pattern == 'P-Net':
        m_XNet = PNet()
    elif ops.pattern == 'R-Net':
        m_XNet = RNet()
    else:  # 'O-Net' (validated above)
        m_XNet = ONet()

    # datasets
    dataset = LoadImagesAndLabels(pattern=ops.pattern,
                                  path_img=ops.path_img,
                                  path_anno=ops.path_anno,
                                  batch_size=ops.batch_size)
    print('dataset len : ', len(dataset))
    # The dataset receives batch_size itself, so presumably each item is a
    # whole batch. The loader therefore uses batch_size=1, and the extra
    # leading dimension is squeezed off in the loop below.
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            num_workers=ops.num_workers,
                            shuffle=True,
                            pin_memory=False,
                            drop_last=True)

    print('{} : \n'.format(ops.pattern), m_XNet)
    m_XNet = m_XNet.to(device)

    m_loss = LossFn()

    params = m_XNet.parameters()
    if ops.Optimizer_X == 'Adam':
        optimizer = torch.optim.Adam(params,
                                     lr=ops.init_lr,
                                     betas=(0.9, 0.99),
                                     weight_decay=1e-6)
    elif ops.Optimizer_X == 'SGD':
        optimizer = torch.optim.SGD(params,
                                    lr=ops.init_lr,
                                    momentum=0.9,
                                    weight_decay=1e-6)
    else:  # 'RMSprop' (validated above)
        optimizer = torch.optim.RMSprop(params,
                                        lr=ops.init_lr,
                                        alpha=0.9,
                                        weight_decay=1e-6)

    # load finetune model
    if os.access(ops.ft_model, os.F_OK):
        chkpt = torch.load(ops.ft_model, map_location=device)
        print('chkpt:\n', ops.ft_model)
        m_XNet.load_state_dict(chkpt)

    # Weight of the box-regression term relative to the classification term.
    box_weight = {'O-Net': 0.4, 'R-Net': 0.6}.get(ops.pattern, 1.0)

    # train
    print('  epoch : ', ops.epochs)
    best_loss = np.inf
    loss_mean = 0.
    loss_cls_mean = 0.
    loss_idx = 0.
    init_lr = ops.init_lr
    loss_cnt = 0  # consecutive epoch checks without improvement

    for epoch in range(ops.epochs):
        # Plateau-based LR schedule driven by the previous epoch's mean loss.
        if loss_idx != 0:
            epoch_loss = loss_mean / loss_idx
            if best_loss > epoch_loss:
                best_loss = epoch_loss
                loss_cnt = 0
            elif loss_cnt > 3:
                init_lr = init_lr * 0.5
                set_learning_rate(optimizer, init_lr)
                loss_cnt = 0
            else:
                loss_cnt += 1

        loss_mean = 0.
        loss_cls_mean = 0.
        loss_idx = 0.

        print('\nepoch %d ' % epoch)
        m_XNet.train()
        random.shuffle(dataset.annotations)  # shuffle the image groupings

        for i, (imgs, gt_labels, gt_offsets, pos_num, part_num,
                neg_num) in enumerate(dataloader):
            imgs = imgs.squeeze(0).to(device)  # (bs, 3, h, w)
            gt_labels = gt_labels.squeeze(0).to(device)
            gt_offsets = gt_offsets.squeeze(0).to(device)

            cls_pred, box_offset_pred = m_XNet(imgs)

            cls_loss = m_loss.focal_Loss(gt_labels, cls_pred)
            box_offset_loss = m_loss.box_loss(gt_labels, gt_offsets,
                                              box_offset_pred)
            all_loss = cls_loss * 1.0 + box_offset_loss * box_weight

            loss_mean += all_loss.item()
            loss_cls_mean += cls_loss.item()
            loss_idx += 1.

            optimizer.zero_grad()
            all_loss.backward()
            optimizer.step()

            if i % 5 == 0:
                loc_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
                print("[%s]-<%s>Epoch: %d, [%d/%d],lr:%.6f,Loss:%.5f - Mean Loss:%.5f - Mean cls loss:%.5f,cls_loss:%.5f ,bbox_loss:%.5f,imgs_batch: %4d,best_loss: %.5f"
                      % (loc_time, ops.pattern, epoch, i, len(dataset),
                         optimizer.param_groups[0]['lr'], all_loss.item(),
                         loss_mean / loss_idx, loss_cls_mean / loss_idx,
                         cls_loss.item(), box_offset_loss.item(),
                         imgs.size(0), best_loss),
                      ' ->pos:{},part:{},neg:{}'.format(
                          pos_num.item(), part_num.item(), neg_num.item()))

            if i % 50 == 0 and i > 1:
                accuracy = compute_accuracy(cls_pred, gt_labels)
                print("\n  ------------- >>>  accuracy: %f\n" %
                      (accuracy.item()))
                torch.save(m_XNet.state_dict(),
                           ops.ckpt + '{}_latest.pth'.format(ops.pattern))
            if i % 80 == 0 and i > 1:
                torch.save(
                    m_XNet.state_dict(),
                    ops.ckpt + '{}_epoch-{}.pth'.format(ops.pattern, epoch))
예제 #2
0
def train_onet(model_store_path,
               end_epoch,
               imdb,
               batch_size,
               frequent=50,
               base_lr=0.01,
               use_cuda=True):
    """Train an O-Net (48x48 reader input) with cls, bbox and landmark losses.

    Args:
        model_store_path: directory for checkpoints; created if missing.
        end_epoch: last epoch to run; epochs are numbered 1..end_epoch.
        imdb: annotation database handed to ``TrainImageReader``.
        batch_size: mini-batch size for the reader.
        frequent: log (and record) metrics every ``frequent`` batches.
        base_lr: Adam learning rate (constant; only echoed in the log).
        use_cuda: move the network and every batch to the GPU.

    After each epoch, prints the mean of the metrics recorded at logging
    steps. It then writes ``onet_epoch_<n>.pt`` (state dict) and
    ``onet_epoch_model_<n>.pkl`` (whole pickled module) into
    ``model_store_path``.
    """
    os.makedirs(model_store_path, exist_ok=True)

    lossfn = LossFn()
    net = ONet(is_train=True)
    net.train()
    if use_cuda:
        net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), lr=base_lr)

    train_data = TrainImageReader(imdb, 48, batch_size, shuffle=True)

    def _mean(values):
        # NaN rather than a crash when an epoch logged nothing.
        return sum(values) / len(values) if values else float('nan')

    for cur_epoch in range(1, end_epoch + 1):
        train_data.reset()
        # Metrics sampled at logging steps, kept as plain floats so that the
        # autograd graph of each sampled step is not kept alive all epoch.
        accuracy_list = []
        cls_loss_list = []
        bbox_loss_list = []
        landmark_loss_list = []

        for batch_idx, (image, (gt_label, gt_bbox,
                                gt_landmark)) in enumerate(train_data):

            im_tensor = torch.stack([
                image_tools.convert_image_to_tensor(image[i, :, :, :])
                for i in range(image.shape[0])
            ])
            gt_label = torch.from_numpy(gt_label).float()
            gt_bbox = torch.from_numpy(gt_bbox).float()
            gt_landmark = torch.from_numpy(gt_landmark).float()

            if use_cuda:
                im_tensor = im_tensor.cuda()
                gt_label = gt_label.cuda()
                gt_bbox = gt_bbox.cuda()
                gt_landmark = gt_landmark.cuda()

            cls_pred, box_offset_pred, landmark_offset_pred = net(im_tensor)

            cls_loss = lossfn.cls_loss(gt_label, cls_pred)
            box_offset_loss = lossfn.box_loss(gt_label, gt_bbox,
                                              box_offset_pred)
            landmark_loss = lossfn.landmark_loss(gt_label, gt_landmark,
                                                 landmark_offset_pred)

            all_loss = cls_loss * 0.8 + box_offset_loss * 0.6 + landmark_loss * 1.5

            if batch_idx % frequent == 0:
                # .item() works for both 0-dim and 1-element tensors, whereas
                # .tolist()[0] fails on the 0-dim losses of modern PyTorch.
                accuracy = compute_accuracy(cls_pred, gt_label).item()
                cls_value = cls_loss.item()
                bbox_value = box_offset_loss.item()
                landmark_value = landmark_loss.item()

                print("%s : Epoch: %d, Step: %d, accuracy: %s, det loss: %s, bbox loss: %s, landmark loss: %s, all_loss: %s, lr:%s " % (
                    datetime.datetime.now(), cur_epoch, batch_idx, accuracy,
                    cls_value, bbox_value, landmark_value, all_loss.item(),
                    base_lr))
                accuracy_list.append(accuracy)
                cls_loss_list.append(cls_value)
                bbox_loss_list.append(bbox_value)
                landmark_loss_list.append(landmark_value)

            optimizer.zero_grad()
            all_loss.backward()
            optimizer.step()

        print("Epoch: %d, accuracy: %s, cls loss: %s, bbox loss: %s, landmark loss: %s " % (
            cur_epoch, _mean(accuracy_list), _mean(cls_loss_list),
            _mean(bbox_loss_list), _mean(landmark_loss_list)))
        torch.save(
            net.state_dict(),
            os.path.join(model_store_path, "onet_epoch_%d.pt" % cur_epoch))
        torch.save(
            net,
            os.path.join(model_store_path,
                         "onet_epoch_model_%d.pkl" % cur_epoch))
예제 #3
0
def train_pnet(model_store_path,
               end_epoch,
               imdb,
               batch_size,
               frequent=10,
               base_lr=0.01,
               use_cuda=True):
    """Train a P-Net (12x12 reader input) with cls and bbox-regression losses.

    Args:
        model_store_path: directory for checkpoints; created if missing.
        end_epoch: last epoch to run; epochs are numbered 1..end_epoch.
        imdb: annotation database handed to ``TrainImageReader``.
        batch_size: mini-batch size for the reader.
        frequent: log metrics every ``frequent`` batches.
        base_lr: Adam learning rate (constant; only echoed in the log).
        use_cuda: move the network and every batch to the GPU.

    After each epoch, writes ``pnet_epoch_<n>.pt`` (state dict) and
    ``pnet_epoch_model_<n>.pkl`` (whole pickled module) into
    ``model_store_path``.
    """
    os.makedirs(model_store_path, exist_ok=True)

    lossfn = LossFn()
    net = PNet(is_train=True, use_cuda=use_cuda)
    net.train()

    if use_cuda:
        net.cuda()
    optimizer = torch.optim.Adam(net.parameters(), lr=base_lr)

    train_data = TrainImageReader(imdb, 12, batch_size, shuffle=True)

    # NOTE: `frequent` is honoured as passed; it used to be overwritten with 10.
    for cur_epoch in range(1, end_epoch + 1):
        train_data.reset()  # shuffle

        # Landmarks are yielded by the reader but not used for P-Net training.
        for batch_idx, (image, (gt_label, gt_bbox,
                                _gt_landmark)) in enumerate(train_data):

            im_tensor = torch.stack([
                image_tools.convert_image_to_tensor(image[i, :, :, :])
                for i in range(image.shape[0])
            ])
            gt_label = torch.from_numpy(gt_label).float()
            gt_bbox = torch.from_numpy(gt_bbox).float()

            if use_cuda:
                im_tensor = im_tensor.cuda()
                gt_label = gt_label.cuda()
                gt_bbox = gt_bbox.cuda()

            cls_pred, box_offset_pred = net(im_tensor)

            cls_loss = lossfn.cls_loss(gt_label, cls_pred)
            box_offset_loss = lossfn.box_loss(gt_label, gt_bbox,
                                              box_offset_pred)

            all_loss = cls_loss * 1.0 + box_offset_loss * 0.5

            if batch_idx % frequent == 0:
                accuracy = compute_accuracy(cls_pred, gt_label).item()
                print(
                    "%s : Epoch: %d, Step: %d, accuracy: %s, det loss: %s, bbox loss: %s, all_loss: %s, lr:%s "
                    % (datetime.datetime.now(), cur_epoch, batch_idx, accuracy,
                       cls_loss.item(), box_offset_loss.item(),
                       all_loss.item(), base_lr))

            optimizer.zero_grad()
            all_loss.backward()
            optimizer.step()

        torch.save(
            net.state_dict(),
            os.path.join(model_store_path, "pnet_epoch_%d.pt" % cur_epoch))
        torch.save(
            net,
            os.path.join(model_store_path,
                         "pnet_epoch_model_%d.pkl" % cur_epoch))
예제 #4
0
    def train_pnet(self, train_data_path, model_dir="../data/models/"):
        """Train a P-Net on CUDA and plot the running loss in visdom.

        Makes one pass over a ``DataReader`` that is configured as follows:
        - ``im_size=12``
        - ``batch_size=4096``
        - ``ratios=(2, 1, 1, 2)``, i.e. pos:part:neg:landmark

        Every 1000 steps (including step 0) it:
        - prints accuracy and recall;
        - writes ``pnet_epoch_<k>.pt`` (state dict) and
          ``pnet_epoch_model_<k>.pkl`` (whole module) into ``model_dir``.

        ``k`` counts the checkpoints written so far (0, 1, 2, ...).

        Args:
            train_data_path: path handed to ``DataReader``.
            model_dir: checkpoint directory (defaults to the previously
                hard-coded ``"../data/models/"``).
        """
        device = torch.device('cuda')
        lossfn = LossFn()
        net = PNet()
        # nn.Module.to() moves the module in place (and returns it).
        net.to(device)
        # Switch to training mode (net.eval() would be evaluation mode).
        net.train()
        optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

        self.viz.line(Y=torch.FloatTensor([0.]),
                      X=torch.FloatTensor([0.]),
                      win='pnet_train_loss',
                      opts=dict(title='train loss'))

        # Load data; ratios are pos:part:neg:landmark.
        train_datasets = DataReader(train_data_path,
                                    im_size=12,
                                    transform=self.trainTransform,
                                    batch_size=4096,
                                    ratios=(2, 1, 1, 2))

        # Dedicated checkpoint counter. The old code reused (and mutated) the
        # loop variable `epoch`, so file names collided across epochs.
        ckpt_idx = 0
        for epoch in range(1):
            print("epoch:", epoch)
            for step, (imgs, cls_labels, rois,
                       landmarks) in enumerate(train_datasets):

                # [b, 3, 12, 12],[b],[4],[10]
                im_tensor = imgs.to(device)
                cls_labels = cls_labels.to(device)
                rois = rois.to(device)
                landmarks = landmarks.to(device)

                cls_pred, box_offset_pred, landmarks_pred = net(im_tensor)

                # These are the losses of the current batch only, not an
                # average over the epoch.
                cls_loss = lossfn.cls_loss(cls_labels, cls_pred)
                box_offset_loss = lossfn.box_loss(cls_labels, rois,
                                                  box_offset_pred)
                landmark_loss = lossfn.landmark_loss(cls_labels, landmarks,
                                                     landmarks_pred)

                print("cls_loss:", cls_loss.item())
                print("box_offset_loss:", box_offset_loss.item())
                print("landmark_loss:", landmark_loss.item())

                all_loss = cls_loss * 1.0 + box_offset_loss * 0.5 + landmark_loss * 0.5
                loss_value = all_loss.item()

                # Plot a detached Python float, not the graph-attached CUDA tensor.
                self.viz.line(Y=torch.FloatTensor([loss_value]),
                              X=torch.FloatTensor([step]),
                              win='pnet_train_loss',
                              update='append')

                optimizer.zero_grad()
                all_loss.backward()
                optimizer.step()
                print("all_loss:", loss_value)
                print("-" * 40, "step:", step, "-" * 40)

                if step % 1000 == 0:
                    accuracy = compute_accuracy(cls_pred, cls_labels)
                    recall = compute_recoll(cls_pred, cls_labels)
                    print(
                        "=" * 80, "\n\n=> acc:{}\n=> recoll:{}\n\n".format(
                            accuracy, recall), "=" * 80)

                    torch.save(
                        net.state_dict(),
                        os.path.join(model_dir, "pnet_epoch_%d.pt" % ckpt_idx))
                    torch.save(
                        net,
                        os.path.join(model_dir,
                                     "pnet_epoch_model_%d.pkl" % ckpt_idx))
                    ckpt_idx += 1