Example #1
def train(output_directory, epochs, learning_rate, iters_per_checkpoint,
          batch_size, seed, checkpoint_path):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    criterion = torch.nn.L1Loss()
    model = CRNN(**CRNN_config).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)

        iteration += 1

    trainset = LJspeechDataset(**data_config)
    # my_collate = collate_fn(trainset)
    train_loader = DataLoader(trainset, num_workers=1, shuffle=True,
                              batch_size=batch_size,
                              collate_fn=collate_fn,
                              pin_memory=False,
                              drop_last=True)

    if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
        os.chmod(output_directory, 0o775)
    print("output directory", output_directory)

    model.train()
    epoch_offset = max(0, int(iteration / len(train_loader)))

    for epoch in range(epoch_offset, epochs):
        epoch_ave_loss = 0
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            model.zero_grad()

            # zeroPadded_batch = pad_sequence(batch, batch_first=True)

            # Next-frame prediction: feed frames [0, T-1) and predict frames [1, T)
            netFeed = batch[:, :-1, :].cuda()
            netTarget = batch[:, 1:, :].cuda()

            netOutput = model(netFeed)

            loss = criterion(netOutput, netTarget)

            reduced_loss = loss.item()

            loss.backward()

            optimizer.step()

            if (iteration % iters_per_checkpoint == 0):
                print("{}:\t{:.9f}".format(iteration, reduced_loss))
            iteration += 1
            epoch_ave_loss += reduced_loss

        checkpoint_path = "{}/CRNN_net_{}".format(output_directory, epoch)
        save_checkpoint(model, optimizer, learning_rate, iteration,
                        checkpoint_path)
        epoch_ave_loss = epoch_ave_loss / (i + 1)
        print("Epoch: {}, the average epoch loss: {}".format(
            epoch, epoch_ave_loss))
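
Example #1 calls `load_checkpoint` and `save_checkpoint` helpers that are not shown. A minimal sketch of what they could look like, inferred only from how they are called above; the dictionary keys are assumptions, not the original implementation:

import torch

def load_checkpoint(checkpoint_path, model, optimizer):
    # Restore the model and optimizer state saved by save_checkpoint below.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['iteration']

def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    # Persist everything needed to resume training from `iteration`.
    torch.save({'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate,
                'iteration': iteration}, filepath)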
Example #2
def main():
    config = Config()

    if not os.path.exists(config.expr_dir):
        os.makedirs(config.expr_dir)

    if torch.cuda.is_available() and not config.use_cuda:
        print("WARNING: You have a CUDA device, so you should probably set cuda in params.py to True")

    # Load the training dataset
    train_dataset = HubDataset(config, "train", transform=None)

    train_kwargs = {'num_workers': 2, 'pin_memory': True,
                    'collate_fn': alignCollate(config.img_height, config.img_width, config.keep_ratio)} if torch.cuda.is_available() else {}

    training_data_batch = DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, drop_last=False, **train_kwargs)

    # Load the fixed-length validation dataset
    eval_dataset = HubDataset(config, "eval", transform=transforms.Compose([ResizeNormalize(config.img_height, config.img_width)]))
    eval_kwargs = {'num_workers': 2, 'pin_memory': False} if torch.cuda.is_available() else {}
    eval_data_batch = DataLoader(eval_dataset, batch_size=config.eval_batch_size, shuffle=False, drop_last=False, **eval_kwargs)

    # Load the variable-length validation dataset (alternative, kept commented out)
    # eval_dataset = HubDataset(config, "eval")
    # eval_kwargs = {'num_workers': 2, 'pin_memory': False,
    #                'collate_fn': alignCollate(config.img_height, config.img_width, config.keep_ratio)} if torch.cuda.is_available() else {}
    # eval_data_batch = DataLoader(eval_dataset, batch_size=config.eval_batch_size, shuffle=False, drop_last=False, **eval_kwargs)

    # Define the network model
    nclass = len(config.label_classes) + 1
    crnn = CRNN(config.img_height, config.nc, nclass, config.hidden_size, n_rnn=config.n_layers)
    # Load a pretrained model
    if config.pretrained != '':
        print('loading pretrained model from %s' % config.pretrained)
        crnn.load_state_dict(torch.load(config.pretrained))
    print(crnn)

    # Compute average for `torch.Variable` and `torch.Tensor`.
    loss_avg = utils.averager()

    # Convert between str and label.
    converter = utils.strLabelConverter(config.label_classes)

    criterion = CTCLoss()           # define the loss function

    # Set up placeholder tensors (refilled on every batch)
    image = torch.FloatTensor(config.train_batch_size, 3, config.img_height, config.img_width)
    text = torch.LongTensor(config.train_batch_size * 5)
    length = torch.LongTensor(config.train_batch_size)

    if config.use_cuda and torch.cuda.is_available():
        criterion = criterion.cuda()
        image = image.cuda()
        crnn = crnn.to(config.device)

    image = Variable(image)
    text = Variable(text)
    length = Variable(length)

    # Set up the optimizer
    if config.adam:
        optimizer = optim.Adam(crnn.parameters(), lr=config.lr, betas=(config.beta1, 0.999))
    elif config.adadelta:
        optimizer = optim.Adadelta(crnn.parameters())
    else:
        optimizer = optim.RMSprop(crnn.parameters(), lr=config.lr)

    def val(net, criterion, eval_data_batch):
        print('Start val')
        for p in net.parameters():
            p.requires_grad = False
        net.eval()

        n_correct = 0
        loss_avg_eval = utils.averager()
        for data in eval_data_batch:
            cpu_images, cpu_texts = data
            batch_size = cpu_images.size(0)
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts)
            utils.loadData(text, t)
            utils.loadData(length, l)
            preds = net(image)
            preds_size = Variable(torch.LongTensor([preds.size(0)] * batch_size))
            cost = criterion(preds, text, preds_size, length) / batch_size
            loss_avg_eval.add(cost)         # accumulate the validation loss

            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
            cpu_texts_decode = list(cpu_texts)
            for pred, target in zip(sim_preds, cpu_texts_decode):       # count exact matches
                if pred == target:
                    n_correct += 1

            raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:config.n_val_disp]
            for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts_decode):
                print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

        accuracy = n_correct / float(len(eval_dataset))
        print('Val loss: %f, accuracy: %f' % (loss_avg_eval.val(), accuracy))

    # Train on a single batch
    def train(net, criterion, optimizer, data):
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)             # actual batch size of this batch
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts)          # convert text labels to index sequences
        utils.loadData(text, t)
        utils.loadData(length, l)
        optimizer.zero_grad()                       # zero the gradients
        preds = net(image)
        preds_size = Variable(torch.LongTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        cost.backward()
        optimizer.step()
        return cost

    for epoch in range(config.nepoch):
        i = 0
        for batch_data in training_data_batch:
            for p in crnn.parameters():
                p.requires_grad = True
            crnn.train()
            cost = train(crnn, criterion, optimizer, batch_data)
            loss_avg.add(cost)
            i += 1

            if i % config.displayInterval == 0:
                print('[%d/%d][%d/%d] Loss: %f' %
                      (epoch, config.nepoch, i, len(training_data_batch), loss_avg.val()))
                loss_avg.reset()

            # if i % config.valInterval == 0:
            #     val(crnn, criterion, eval_data_batch)
            #
            # # do checkpointing
            # if i % config.saveInterval == 0:
            #     torch.save(crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(config.expr_dir, epoch, i))

        val(crnn, criterion, eval_data_batch)
        torch.save(crnn.state_dict(), '{0}/netCRNN_{1}_end.pth'.format(config.expr_dir, epoch))
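
Example #2 fills the preallocated `image`, `text` and `length` tensors through `utils.loadData`, which is not shown. A minimal sketch, assuming it simply resizes the placeholder in place and copies the incoming batch into it:

def loadData(v, data):
    # Resize the preallocated tensor to match the incoming batch and copy it in,
    # so the same placeholder can be reused across iterations.
    with torch.no_grad():
        v.resize_(data.size()).copy_(data)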
Example #3
def main():

    print(torch.__version__)

    with open('config.yaml') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    print(torch.cuda.is_available())
    torch.backends.cudnn.benchmark = True

    char_set = config['char_set']
    # if config['method'] == 'ctc':
    char2idx_ctc, idx2char_ctc = get_char_dict_ctc(char_set)
    char2idx_att, idx2char_att = get_char_dict_attention(char_set)
    config['char2idx_ctc'] = char2idx_ctc
    config['idx2char_ctc'] = idx2char_ctc
    config['char2idx_att'] = char2idx_att
    config['idx2char_att'] = idx2char_att

    batch_size = config['batch_size']

    if not os.path.exists(config['save_path']):
        os.mkdir(config['save_path'])
    print(config)

    train_dataset = TextRecDataset(config, phase='train')
    val_dataset = TextRecDataset(config, phase='val')
    test_dataset = TextRecDataset(config, phase='test')
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=cpu_count(),
                                  pin_memory=False)

    valloader = data.DataLoader(val_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=cpu_count(),
                                pin_memory=False)

    testloader = data.DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=cpu_count(),
                                 pin_memory=False)

    class_num = len(config['char_set']) + 1
    print('class_num', class_num)
    model = CRNN(class_num)
    # decoder = Decoder(class_num, config['max_string_len'], char2idx_att)
    attention_head = AttentionHead(class_num, config['max_string_len'], char2idx_att)

    # criterion = nn.CTCLoss(blank=char2idx['-'], reduction='mean')
    criterion_ctc = CTCFocalLoss(blank=char2idx_ctc['-'], gamma=0.5)
    criterion_att = nn.CrossEntropyLoss(reduction='none')

    if config['use_gpu']:
        model = model.cuda()
        # decoder = decoder.cuda()
        attention_head = attention_head.cuda()
    summary(model, (1, 32, 400))

    # model = torch.nn.DataParallel(model)

    # optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=1e-2, weight_decay=5e-4)
    optimizer = torch.optim.SGD([{'params': model.parameters()},
                                 {'params': attention_head.parameters()}], lr=0.001, momentum=0.9, weight_decay=5e-4)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[500, 800], gamma=0.1)

    print('train start, total batches %d' % len(trainloader))
    iter_cnt = 0
    for i in range(1, config['epochs']+1):
        start = time.time()
        model.train()
        attention_head.train()
        for j, batch in enumerate(trainloader):

            iter_cnt += 1
            imgs = batch[0].cuda()
            labels_length = batch[1].cuda()
            labels_str = batch[2]
            labels_ctc = batch[3].cuda().long()
            labels_ctc_mask = batch[4].cuda().float()
            labels_att = batch[5].cuda().long()
            labels_att_mask = batch[6].cuda().float()

            if config['method'] == 'ctc':
                # CTC loss
                outputs, cnn_features = model(imgs)
                log_prob = outputs.log_softmax(dim=2)
                t, n, c = log_prob.size(0), log_prob.size(1), log_prob.size(2)
                input_length = (torch.ones((n,)) * t).cuda().int()
                loss_ctc = criterion_ctc(log_prob, labels_ctc, input_length, labels_length)

                # attention loss   
                outputs = attention_head(cnn_features, labels_att)
                probs = outputs.permute(1, 2, 0)
                losses_att = criterion_att(probs, labels_att)
                losses_att = losses_att * labels_att_mask
                losses_att = losses_att.sum() / labels_att_mask.sum()

                loss = loss_ctc + losses_att

            else:
                # attention (cross-entropy) loss only
                outputs_ctc, cnn_features = model(imgs)
                outputs_att = attention_head(cnn_features, labels_att)

                outputs = outputs_att.permute(1, 2, 0)
                losses = criterion_att(outputs, labels_att)
                losses = losses * labels_att_mask
                loss = losses.sum() / labels_att_mask.sum()

            optimizer.zero_grad()            
            loss.backward()
            # nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
            optimizer.step()

            if iter_cnt % config['print_freq'] == 0:
                print('epoch %d, iter %d, train loss %f' % (i, iter_cnt, loss.item()))

        print('epoch %d, time %f' % (i, (time.time() - start)))
        scheduler.step()

        print("validating...")
        
        if config['method'] == 'ctc':
            line_acc, rec_score = eval_ctc(model, valloader, idx2char_ctc)
        else:
            line_acc, rec_score = eval_attention(model, attention_head, valloader, idx2char_att)

        if i % config['test_freq'] == 0:
            print("testing...")
            if config['method'] == 'ctc':
                line_acc, rec_score = eval_ctc(model, testloader, idx2char_ctc)
            else:
                line_acc, rec_score = eval_attention(model, attention_head, testloader, idx2char_att)

        if i % config['save_freq'] == 0:
            save_file_name = f"epoch_{i}_acc_{line_acc:.3f}_rec_score_{rec_score:.3f}.pth"
            save_file = os.path.join(config['save_path'], save_file_name)
            torch.save(model.state_dict(), save_file)
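
The `get_char_dict_ctc` helper used above is not shown. A minimal sketch of one plausible implementation, assuming '-' is reserved as the CTC blank symbol (the ordering and names are assumptions):

def get_char_dict_ctc(char_set):
    # Reserve '-' as the CTC blank and map every character to an index.
    chars = ['-'] + [c for c in char_set if c != '-']
    char2idx = {c: i for i, c in enumerate(chars)}
    idx2char = {i: c for i, c in enumerate(chars)}
    return char2idx, idx2char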
Example #4
class Trainer(object):
    def __init__(self):
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
        if args.chars_file == '':
            self.alphabet = alphabetChinese
        else:
            self.alphabet = utils.load_chars(args.chars_file)
        nclass = len(self.alphabet) + 1
        nc = 1
        self.net = CRNN(args.imgH, nc, args.nh, nclass)
        self.train_dataloader, self.val_dataloader = self.dataloader(
            self.alphabet)
        self.criterion = CTCLoss()
        self.optimizer = self.get_optimizer()
        self.converter = utils.strLabelConverter(self.alphabet,
                                                 ignore_case=False)
        self.best_acc = 0.00001

        model_name = '%s' % (args.dataset_name)
        if not os.path.exists(args.save_prefix):
            os.mkdir(args.save_prefix)
        args.save_prefix += model_name

        if args.pretrained != '':
            print('loading pretrained model from %s' % args.pretrained)
            checkpoint = torch.load(args.pretrained)

            if 'model_state_dict' in checkpoint.keys():
                # self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                args.start_epoch = checkpoint['epoch']
                self.best_acc = checkpoint['best_acc']
                checkpoint = checkpoint['model_state_dict']

            from collections import OrderedDict
            model_dict = OrderedDict()
            for k, v in checkpoint.items():
                if 'module' in k:
                    model_dict[k[7:]] = v
                else:
                    model_dict[k] = v
            self.net.load_state_dict(model_dict)

        if not args.cuda and torch.cuda.is_available():
            print(
                "WARNING: You have a CUDA device, so you should probably run with --cuda"
            )

        elif args.cuda and torch.cuda.is_available():
            print('available gpus is ', torch.cuda.device_count())
            self.net = torch.nn.DataParallel(self.net, dim=1).cuda()
            self.criterion = self.criterion.cuda()

    def dataloader(self, alphabet):
        # train_transform = transforms.Compose(
        #     [transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
        #     resizeNormalize(args.imgH)])
        # train_dataset = BaseDataset(args.train_dir, alphabet, transform=train_transform)
        train_dataset = NumDataset(args.train_dir,
                                   alphabet,
                                   transform=resizeNormalize(args.imgH))
        train_dataloader = DataLoader(dataset=train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

        if os.path.exists(args.val_dir):
            # val_dataset = BaseDataset(args.val_dir, alphabet, transform=resizeNormalize(args.imgH))
            val_dataset = NumDataset(args.val_dir,
                                     alphabet,
                                     mode='test',
                                     transform=resizeNormalize(args.imgH))
            val_dataloader = DataLoader(dataset=val_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.num_workers,
                                        pin_memory=True)
        else:
            val_dataloader = None

        return train_dataloader, val_dataloader

    def get_optimizer(self):
        if args.optimizer == 'sgd':
            optimizer = optim.SGD(
                self.net.parameters(),
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.wd,
            )
        elif args.optimizer == 'adam':
            optimizer = optim.Adam(
                self.net.parameters(),
                lr=args.lr,
                betas=(args.beta1, 0.999),
            )
        else:
            optimizer = optim.RMSprop(
                self.net.parameters(),
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.wd,
            )
        return optimizer

    def train(self):
        logging.basicConfig()
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        log_file_path = args.save_prefix + '_train.log'
        log_dir = os.path.dirname(log_file_path)
        if log_dir and not os.path.exists(log_dir):
            os.mkdir(log_dir)
        fh = logging.FileHandler(log_file_path)
        logger.addHandler(fh)
        logger.info(args)
        logger.info('Start training from [Epoch {}]'.format(args.start_epoch +
                                                            1))

        losses = utils.Averager()
        train_accuracy = utils.Averager()

        for epoch in range(args.start_epoch, args.nepoch):
            self.net.train()
            btic = time.time()
            for i, (imgs, labels) in enumerate(self.train_dataloader):
                batch_size = imgs.size()[0]
                imgs = imgs.cuda()
                preds = self.net(imgs).cpu()
                text, length = self.converter.encode(
                    labels
                )  # length: characters per sample in the batch; text: concatenated label indices for the whole batch
                preds_size = torch.IntTensor([preds.size(0)] * batch_size)
                loss_avg = self.criterion(preds, text, preds_size,
                                          length) / batch_size

                self.optimizer.zero_grad()
                loss_avg.backward()
                self.optimizer.step()

                losses.update(loss_avg.item(), batch_size)

                _, preds_m = preds.max(2)
                preds_m = preds_m.transpose(1, 0).contiguous().view(-1)
                sim_preds = self.converter.decode(preds_m.data,
                                                  preds_size.data,
                                                  raw=False)
                n_correct = 0
                for pred, target in zip(sim_preds, labels):
                    if pred == target:
                        n_correct += 1
                train_accuracy.update(n_correct, batch_size, MUL_n=False)

                if args.log_interval and not (i + 1) % args.log_interval:
                    logger.info(
                        '[Epoch {}/{}][Batch {}/{}], Speed: {:.3f} samples/sec, Loss:{:.3f}'
                        .format(epoch + 1, args.nepoch, i + 1,
                                len(self.train_dataloader),
                                batch_size / (time.time() - btic),
                                losses.val()))
                    losses.reset()

            logger.info(
                'Training accuracy: {:.3f}, [#correct:{} / #total:{}]'.format(
                    train_accuracy.val(), train_accuracy.sum,
                    train_accuracy.count))
            train_accuracy.reset()

            if args.val_interval and not (epoch + 1) % args.val_interval:
                acc = self.validate(logger)
                if acc > self.best_acc:
                    self.best_acc = acc
                    save_path = '{:s}_best.pth'.format(args.save_prefix)
                    torch.save(
                        {
                            'epoch': epoch,
                            'model_state_dict': self.net.state_dict(),
                            # 'optimizer_state_dict': self.optimizer.state_dict(),
                            'best_acc': self.best_acc,
                        },
                        save_path)
                logger.info("best acc is:{:.3f}".format(self.best_acc))
                if args.save_interval and not (epoch + 1) % args.save_interval:
                    save_path = '{:s}_{:04d}_{:.3f}.pth'.format(
                        args.save_prefix, epoch + 1, acc)
                    torch.save(
                        {
                            'epoch': epoch,
                            'model_state_dict': self.net.state_dict(),
                            # 'optimizer_state_dict': self.optimizer.state_dict(),
                            'best_acc': self.best_acc,
                        },
                        save_path)

    def validate(self, logger):
        if self.val_dataloader is None:
            return 0
        logger.info('Start validate.')
        losses = utils.Averager()
        self.net.eval()
        n_correct = 0
        with torch.no_grad():
            for i, (imgs, labels) in enumerate(self.val_dataloader):
                batch_size = imgs.size()[0]
                imgs = imgs.cuda()
                preds = self.net(imgs).cpu()
                text, length = self.converter.encode(
                    labels
                )  # length: characters per sample in the batch; text: concatenated label indices for the whole batch
                preds_size = torch.IntTensor(
                    [preds.size(0)] * batch_size)  # timestep * batchsize
                loss_avg = self.criterion(preds, text, preds_size,
                                          length) / batch_size

                losses.update(loss_avg.item(), batch_size)

                _, preds = preds.max(2)
                preds = preds.transpose(1, 0).contiguous().view(-1)
                sim_preds = self.converter.decode(preds.data,
                                                  preds_size.data,
                                                  raw=False)
                for pred, target in zip(sim_preds, labels):
                    if pred == target:
                        n_correct += 1

        accuracy = n_correct / float(losses.count)

        logger.info(
            'Evaling loss: {:.3f}, accuracy: {:.3f}, [#correct:{} / #total:{}]'
            .format(losses.val(), accuracy, n_correct, losses.count))

        return accuracy
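
Example #4 relies on a `utils.Averager` with `update`, `val`, `reset` and public `sum`/`count` fields. A minimal sketch consistent with how it is called above; the `MUL_n` flag is interpreted here as "scale the value by n before accumulating", which is an assumption:

class Averager(object):
    # Running average over a stream of (value, n) updates.
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1, MUL_n=True):
        # With MUL_n=True, `value` is a per-sample average and is scaled by n;
        # with MUL_n=False, `value` is already an absolute count (e.g. n_correct).
        self.sum += value * n if MUL_n else value
        self.count += n

    def val(self):
        return self.sum / self.count if self.count else 0.0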