Example #1
def train_batch(net, criterion, optimizer, flag=False):
    data = next(train_iter)
    cpu_images, cpu_texts = data
    if ifUnicode:
        # decode UTF-8 bytes to unicode strings and clean the labels
        cpu_texts = [clean_txt(tx.decode('utf-8')) for tx in cpu_texts]

    batch_size = cpu_images.size(0)
    utils.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    utils.loadData(text, t)
    utils.loadData(length, l)

    preds = crnn(image)  # note: uses the global crnn, not the net argument
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
    cost = criterion(preds, text, preds_size, length) / batch_size
    crnn.zero_grad()
    cost.backward()
    if flag:
        # NOTE: this replaces the passed-in optimizer with a fresh Adadelta (lr=1e-4)
        # just before step(), discarding any accumulated optimizer state
        lr = 0.0001
        optimizer = optim.Adadelta(crnn.parameters(), lr=lr)
    optimizer.step()
    return cost
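
The function above reads several module-level objects that are not defined in the snippet (train_iter, image, text, length, converter, criterion, crnn, ifUnicode). The sketch below shows one plausible way that state could be prepared before the loop; the tensor shapes, the CTC criterion, and all names here are assumptions in the style of a typical crnn.pytorch setup, not code taken from the original example.

import torch
import torch.nn as nn
from torch.autograd import Variable

# Assumed hyper-parameters, for illustration only.
batch_size, img_h, img_w = 64, 32, 100

# Pre-allocated buffers that utils.loadData() copies each batch into.
image = Variable(torch.FloatTensor(batch_size, 1, img_h, img_w))
text = Variable(torch.IntTensor(batch_size * 5))   # flattened label indices
length = Variable(torch.IntTensor(batch_size))     # label length per sample

# With torch.nn.CTCLoss the model output passed as `preds` must be
# log-probabilities of shape (T, N, C) and index 0 is the CTC blank;
# warp-ctc based setups pass raw activations instead.
criterion = nn.CTCLoss(blank=0, reduction='sum')

# crnn, converter and train_iter would come from the project's own modules
# (model, label converter, DataLoader iterator); they are placeholders here.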
Example #2
    def test(self):
        image = torch.FloatTensor(self.batch_size, 3, self.img_h, self.img_w)
        text = torch.IntTensor(self.batch_size * 5)
        length = torch.IntTensor(self.batch_size)
        image = Variable(image)
        text = Variable(text)
        length = Variable(length)

        for p in self.model.parameters():
            p.requires_grad = False

        test_loss = 0.0
        correct = 0
        # loss_avg = utils.averager()

        time_start = time.time()
        self.model.eval()
        for data, target in self.test_loader:
            cpu_images = data
            cpu_texts = target
            batch_size = cpu_images.size(0)
            utils.loadData(image, cpu_images)
            if self.use_unicode:
                # labels are already str under Python 3, so no decode is needed
                cpu_texts = [tx for tx in cpu_texts]

            t, l = self.converter.encode(cpu_texts)
            utils.loadData(text, t)
            utils.loadData(length, l)

            if self.use_gpu:
                image = image.cuda()

            preds = self.model(image)
            preds_size = Variable(torch.IntTensor([preds.size(0)] *
                                                  batch_size))
            loss = self.criterion(preds, text, preds_size, length)
            test_loss += loss.item()

            _, preds = preds.max(2)
            # preds = preds.squeeze(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = self.converter.decode(preds.data,
                                              preds_size.data,
                                              raw=False)

            # print("==============================================")
            # print(loss.item())
            # print(target)
            # # print(t)
            # # print(l)
            # total_preds = self.converter.decode(preds.data, preds_size.data, raw=True)
            # print(total_preds)
            # print(sim_preds)
            # if np.isnan(loss.item()):
            #     assert 1 == 0

            # use gt as the loop name so the loader's target variable is not shadowed
            for pred, gt in zip(sim_preds, cpu_texts):
                if pred.strip() == gt.strip():
                    correct += 1

        time_end = time.time()
        time_avg = float(time_end - time_start) / float(
            len(self.test_loader.dataset))
        accuracy = correct / float(len(self.test_loader.dataset))
        test_loss /= len(self.test_loader)
        print('[Test] loss: %f, accuracy: %f, time: %f' %
              (test_loss, accuracy, time_avg))
        return test_loss, accuracy
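
Both the training and the test code call utils.loadData(buffer, batch) to move each batch into the pre-allocated buffer tensors. That helper is not shown in these examples; the version below is a common implementation (resize the buffer in place, then copy), given as an assumed sketch rather than the project's actual code.

import torch

def loadData(v, data):
    # Resize the pre-allocated buffer to the incoming batch and copy the values
    # in place, so the shared `image`, `text` and `length` tensors keep their
    # identity across batches. Reproduced from memory of similar CRNN utilities,
    # not from this repository.
    with torch.no_grad():
        v.resize_(data.size()).copy_(data)
    return v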
Example #3
    def train(self, epoch, decay_epoch=80):
        image = torch.FloatTensor(self.batch_size, 3, self.img_h, self.img_w)
        text = torch.IntTensor(self.batch_size * 5)
        length = torch.IntTensor(self.batch_size)
        image = Variable(image)
        text = Variable(text)
        length = Variable(length)

        print('[train] epoch: %d' % epoch)
        for epoch_i in range(epoch):
            start_time = time.time()
            train_loss = 0.0
            correct = 0

            if epoch_i >= decay_epoch and epoch_i % decay_epoch == 0:  # decay the learning rate
                self.lr = self.lr * 0.1
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = self.lr
                # self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-5)

            print('================================================')
            self.model.train()
            for batch_idx, (data, target) in enumerate(self.train_loader):  # training loop
                # data, target = Variable(data), Variable(target)

                if self.use_unicode:
                    # labels are already str under Python 3, so no decode is needed
                    target = [tx for tx in target]

                # print('data size', data.size())         # [64, 3, 32, 270]
                batch_size = data.size(0)
                utils.loadData(image, data)
                # print('image size', image.size())       # [64, 3, 32, 270]
                t, l = self.converter.encode(target)
                # print(t)
                # print(l)
                utils.loadData(text, t)
                utils.loadData(length, l)

                if self.use_gpu:
                    image = image.cuda()

                # zero the gradients
                self.optimizer.zero_grad()
                for p in self.model.parameters():
                    p.requires_grad = True

                # compute the loss
                preds = self.model(image)  # input image size: [64, 3, 32, 270] (10 chars)
                preds_size = Variable(
                    torch.IntTensor([preds.size(0)] * batch_size))
                # print('preds_size', preds_size)
                loss = self.criterion(preds, text, preds_size, length)
                # self.model.zero_grad()
                # backpropagate to compute gradients
                loss.backward()
                # update the parameters
                self.optimizer.step()
                train_loss += loss.item()
                # (optional) debug: inspect preds.size() and the per-class scores here

                _, preds = preds.max(2)
                # print(preds.size())
                # preds = preds.squeeze(2)
                preds = preds.transpose(1, 0).contiguous().view(-1)
                # print(preds.size())
                sim_preds = self.converter.decode(preds.data,
                                                  preds_size.data,
                                                  raw=False)
                # (optional) debug: decode with raw=True and print predictions/targets here

                # use gt as the loop name so the target list is not shadowed while iterating
                for pred, gt in zip(sim_preds, target):
                    if pred.strip() == gt.strip():
                        correct += 1

            train_loss /= len(self.train_loader)
            acc = float(correct) / float(len(self.train_loader.dataset))
            use_time = time.time() - start_time
            print(
                '[Train] Epoch: {} \tLoss: {:.6f}\tAcc: {:.6f}\tlr: {}\ttime: {}'
                .format(epoch_i, train_loss, acc, self.lr, use_time))

            # Test
            test_loss, test_acc = self.test()
            if test_loss < self.best_loss:
                self.best_loss = test_loss
                str_list = self.model_file.split('.')
                best_model_file = ""
                for str_index in range(len(str_list)):
                    best_model_file = best_model_file + str_list[str_index]
                    if str_index == (len(str_list) - 2):
                        best_model_file += '_best'
                    if str_index != (len(str_list) - 1):
                        best_model_file += '.'
                self.save(best_model_file)  # save the best model (lowest test loss)

            if test_acc > self.best_acc:
                self.best_acc = test_acc
                str_list = self.model_file.split('.')
                best_model_file = ""
                for str_index in range(len(str_list)):
                    best_model_file = best_model_file + str_list[str_index]
                    if str_index == (len(str_list) - 2):
                        best_model_file += '_best_acc'
                    if str_index != (len(str_list) - 1):
                        best_model_file += '.'
                self.save(best_model_file)  # save the best model (highest accuracy)

        self.save(self.model_file)
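
The two blocks that build best_model_file only insert a suffix before the file extension of self.model_file (e.g. crnn.pth becomes crnn_best.pth). Assuming the file name has an extension, the same result can be obtained more directly with os.path.splitext; a small equivalent sketch with a hypothetical helper:

import os

def with_suffix(model_file, suffix):
    # Insert `suffix` before the extension, e.g. 'crnn.pth' -> 'crnn_best.pth'.
    root, ext = os.path.splitext(model_file)
    return root + suffix + ext

# Equivalent to the two branches above (hypothetical helper, not in the example):
# best_model_file = with_suffix(self.model_file, '_best')      # lowest test loss
# best_model_file = with_suffix(self.model_file, '_best_acc')  # highest accuracy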
Example #4
def val(net, dataset, criterion, max_iter=2):
    print('Start val')

    for p in crnn.parameters():  # note: uses the global crnn, not the net argument
        p.requires_grad = False

    net.eval()
    data_loader = torch.utils.data.DataLoader(dataset,
                                              shuffle=False,
                                              batch_size=opt.batchSize,
                                              num_workers=int(opt.workers))
    val_iter = iter(data_loader)

    i = 0
    n_correct = 0
    loss_avg = utils.averager()

    max_iter = min(max_iter, len(data_loader))
    for i in range(max_iter):
        data = next(val_iter)
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        utils.loadData(image, cpu_images)
        if ifUnicode:
            cpu_texts = [clean_txt(tx.decode('utf-8')) for tx in cpu_texts]
        # print(cpu_texts)
        t, l = converter.encode(cpu_texts)
        # print(t)
        # print(l)
        utils.loadData(text, t)
        utils.loadData(length, l)

        preds = crnn(image)
        preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        loss_avg.add(cost)

        _, preds = preds.max(2)  # greedy decode: argmax class index at each time step
        preds = preds.transpose(1, 0).contiguous().view(-1)
        sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
        print(sim_preds)
        print(cpu_texts)
        for pred, target in zip(sim_preds, cpu_texts):
            if pred.strip() == target.strip():
                n_correct += 1

    # (optional) decode with raw=True to inspect the uncollapsed CTC output for a few samples
    accuracy = n_correct / float(max_iter * opt.batchSize)
    testLoss = loss_avg.val()
    print('Test loss: %f, accuracy: %f' % (testLoss, accuracy))
    return testLoss, accuracy
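
All four examples rely on converter.encode/decode. As used here, encode turns a batch of strings into one flat IntTensor of label indices plus a tensor of per-string lengths, and decode(..., raw=False) performs the CTC collapse: repeated indices are merged and the blank (index 0) is dropped. The toy converter below illustrates that contract; it is written from the usage in these examples and is not the project's own strLabelConverter.

import torch

class ToyLabelConverter:
    """Minimal stand-in that mimics the encode/decode contract used above."""

    def __init__(self, alphabet):
        # Index 0 is reserved for the CTC blank, so characters start at 1.
        self.alphabet = alphabet
        self.dict = {ch: i + 1 for i, ch in enumerate(alphabet)}

    def encode(self, texts):
        lengths = [len(s) for s in texts]
        flat = [self.dict[ch] for s in texts for ch in s]
        return torch.IntTensor(flat), torch.IntTensor(lengths)

    def decode(self, t, lengths, raw=False):
        # Split the flattened index tensor per sample, then either return the
        # raw sequence ('-' for blanks) or collapse repeats and drop blanks.
        results, offset = [], 0
        for n in lengths.tolist():
            ids = t[offset:offset + n].tolist()
            offset += n
            if raw:
                results.append(''.join(self.alphabet[i - 1] if i > 0 else '-' for i in ids))
            else:
                chars = [self.alphabet[i - 1] for k, i in enumerate(ids)
                         if i != 0 and (k == 0 or i != ids[k - 1])]
                results.append(''.join(chars))
        return results

# Example: ToyLabelConverter('abc').decode(torch.IntTensor([1, 1, 0, 2]),
#                                          torch.IntTensor([4]))  ->  ['ab']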