Example #1
def train(opt):

    if opt.use_tb:
        tb_dir = f'/home_hongdo/{getpass.getuser()}/tb/{opt.experiment_name}'
        print('tensorboard : ', tb_dir)
        if not os.path.exists(tb_dir):
            os.makedirs(tb_dir)
        writer = SummaryWriter(log_dir=tb_dir)
    """ dataset preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a')
    log = open(f'{save_dir}/{opt.experiment_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()
    """ model configuration """
    converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    # sekim for transfer learning
    model = Model(opt, 38)

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)

    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))

    # sekim change last layer
    in_feature = model.module.Prediction.generator.in_features
    model.module.Prediction.attention_cell.rnn = nn.LSTMCell(
        256 + opt.num_class, 256).to(device)
    model.module.Prediction.generator = nn.Linear(in_feature,
                                                  opt.num_class).to(device)

    print(model.module.Prediction.generator)
    print("Model:")
    print(model)

    model.train()
    """ setup loss """
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
        device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent updates
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)
    """ final options """

    # with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
    with open(f'{save_dir}/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """
    start_iter = 0

    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print("-------------------------------------------------")
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    while (True):
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        preds = model(image, text[:, :-1])  # align with Attention.forward
        target = text[:, 1:]  # without [GO] Symbol
        cost = criterion(preds.view(-1, preds.shape[-1]),
                         target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'{save_dir}/{opt.experiment_name}/log_train.txt',
                      'a') as log:
                # with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)

                model.train()

                # training loss and validation loss

                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                if opt.use_tb:
                    writer.add_scalar('OCR_loss/train_loss', loss_avg.val(), i)
                    writer.add_scalar('OCR_loss/validation_loss', valid_loss, i)
                # reset only after logging so TensorBoard records the real training loss
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    # torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                    torch.save(
                        model.state_dict(),
                        f'{save_dir}/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    # torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                    torch.save(
                        model.state_dict(),
                        f'{save_dir}/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):

                    gt = gt[:gt.find('[s]')]
                    pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 1e+5 == 0:
            # torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')
            torch.save(model.state_dict(),
                       f'{save_dir}/{opt.experiment_name}/iter_{i + 1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
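The block tagged "# sekim change last layer" is the actual transfer-learning step in this example: the model is built and loaded with its original 38-class head, then the attention cell and the generator are rebuilt for the new character set. Below is a minimal, self-contained sketch of that head swap; TinyAttentionHead and its sizes are hypothetical stand-ins, not the repository's Model/Prediction classes.

import torch.nn as nn

class TinyAttentionHead(nn.Module):
    """Stand-in for model.module.Prediction: an LSTM cell plus a generator layer."""

    def __init__(self, hidden_size=256, num_classes=38):
        super().__init__()
        self.rnn = nn.LSTMCell(hidden_size + num_classes, hidden_size)
        self.generator = nn.Linear(hidden_size, num_classes)

def swap_head(head, new_num_classes, hidden_size=256):
    # keep the backbone weights; rebuild only the class-dependent modules,
    # mirroring what the example does for attention_cell.rnn and generator
    in_features = head.generator.in_features
    head.rnn = nn.LSTMCell(hidden_size + new_num_classes, hidden_size)
    head.generator = nn.Linear(in_features, new_num_classes)
    return head

head = TinyAttentionHead(num_classes=38)    # pretrained with 38 classes
head = swap_head(head, new_num_classes=50)  # adapt to a 50-character charset
print(head.generator)                       # Linear(in_features=256, out_features=50, bias=True)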
Example #2
    def train(self, opt):
        # src, tar dataloaders
        src_dataset, tar_dataset, valid_loader = self.dataloader(opt)
        src_dataset_size = src_dataset.total_data_size
        tar_dataset_size = tar_dataset.total_data_size
        train_size = max([src_dataset_size, tar_dataset_size])
        iters_per_epoch = int(train_size / opt.batch_size)

        # Modify train size. Make sure both are of same size.
        # Modify training loop to continue giving src loss after tar is done.

        self.model.train()
        self.global_discriminator.train()
        self.local_discriminator.train()
        start_iter = 0

        if opt.continue_model != '':
            self.load(opt.continue_model)
            print(" [*] Load SUCCESS")

        # loss averager
        cls_loss_avg = Averager()
        sim_loss_avg = Averager()
        loss_avg = Averager()

        # training loop
        print('training start !')
        start_time = time.time()
        best_accuracy = -1
        best_norm_ED = 1e+6
        # i = start_iter
        gamma = 0
        omega = 1
        epoch = 0
        for step in range(start_iter, opt.num_iter + 1):
            epoch = step // iters_per_epoch
            if opt.decay_flag and step > (opt.num_iter // 2):
                self.d_image_opt.param_groups[0]['lr'] -= (opt.lr /
                                                           (opt.num_iter // 2))
                self.d_inst_opt.param_groups[0]['lr'] -= (opt.lr /
                                                          (opt.num_iter // 2))

            src_image, src_labels = src_dataset.get_batch()
            src_image = src_image.to(device)
            src_text, src_length = self.converter.encode(
                src_labels, batch_max_length=opt.batch_max_length)

            tar_image, tar_labels = tar_dataset.get_batch()
            tar_image = tar_image.to(device)
            tar_text, tar_length = self.converter.encode(
                tar_labels, batch_max_length=opt.batch_max_length)

            # Set gradient to zero...
            self.model.zero_grad()
            # Domain classifiers
            self.global_discriminator.zero_grad()
            self.local_discriminator.zero_grad()

            # Attention # align with Attention.forward
            src_preds, src_global_feature, src_local_feature = self.model(
                src_image, src_text[:, :-1])
            # src_global_feature = self.model.visual_feature
            # src_local_feature = self.model.Prediction.context_history
            target = src_text[:, 1:]  # without [GO] Symbol
            src_cls_loss = self.criterion(
                src_preds.view(-1, src_preds.shape[-1]),
                target.contiguous().view(-1))
            src_global_feature = src_global_feature.view(
                src_global_feature.shape[0], -1)
            src_local_feature = src_local_feature.view(
                -1, src_local_feature.shape[-1])

            tar_preds, tar_global_feature, tar_local_feature = self.model(
                tar_image, tar_text[:, :-1], is_train=False)
            # tar_global_feature = self.model.visual_feature
            # tar_local_feature = self.model.Prediction.context_history
            tar_global_feature = tar_global_feature.view(
                tar_global_feature.shape[0], -1)
            tar_local_feature = tar_local_feature.view(
                -1, tar_local_feature.shape[-1])

            src_local_feature, tar_local_feature = filter_local_features(
                opt, src_local_feature, src_preds, tar_local_feature,
                tar_preds)

            # Add domain adaption elements
            # setup hyperparameter
            if step % 2000 == 0:
                p = float(step + start_iter) / opt.num_iter
                gamma = 2. / (1. + np.exp(-10 * p)) - 1
                omega = 1 - 1. / (1. + np.exp(-10 * p))
            self.global_discriminator.module.set_beta(gamma)
            self.local_discriminator.module.set_beta(gamma)

            src_d_img_score = self.global_discriminator(src_global_feature)
            src_d_inst_score = self.local_discriminator(src_local_feature)
            tar_d_img_score = self.global_discriminator(tar_global_feature)
            tar_d_inst_score = self.local_discriminator(tar_local_feature)

            src_d_img_loss = self.D_criterion(
                src_d_img_score,
                torch.zeros_like(src_d_img_score).to(device))
            src_d_inst_loss = self.D_criterion(
                src_d_inst_score,
                torch.zeros_like(src_d_inst_score).to(device))
            tar_d_img_loss = self.D_criterion(
                tar_d_img_score,
                torch.ones_like(tar_d_img_score).to(device))
            tar_d_inst_loss = self.D_criterion(
                tar_d_inst_score,
                torch.ones_like(tar_d_inst_score).to(device))
            d_img_loss = src_d_img_loss + tar_d_img_loss
            d_inst_loss = src_d_inst_loss + tar_d_inst_loss

            # Add domain loss
            loss = src_cls_loss.mean() + omega * (d_img_loss.mean() +
                                                  d_inst_loss.mean())
            loss_avg.add(loss)
            cls_loss_avg.add(src_cls_loss)
            sim_loss_avg.add(d_img_loss + d_inst_loss)

            # frcnn backward
            loss.backward()
            # clip_gradient(self.model, 10.)
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                opt.grad_clip)  # gradient clipping with 5 (Default)
            # frcnn optimizer update
            self.optimizer.step()
            # domain optimizer update
            self.d_inst_opt.step()
            self.d_image_opt.step()

            # validation part
            if step % opt.valInterval == 0:

                elapsed_time = time.time() - start_time
                print(
                    f'[{step}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} CLS_Loss: {cls_loss_avg.val():0.5f} SIMI_Loss: {sim_loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
                )
                # for log
                with open(
                        f'./saved_models/{opt.experiment_name}/log_train.txt',
                        'a') as log:
                    log.write(
                        f'[{step}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                    )
                    loss_avg.reset()
                    cls_loss_avg.reset()
                    sim_loss_avg.reset()

                    self.model.eval()
                    with torch.no_grad():
                        valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                            self.model, self.criterion, valid_loader,
                            self.converter, opt)

                    self.print_prediction_result(preds, labels, log)

                    valid_log = f'[{step}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                    valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                    print(valid_log)
                    log.write(valid_log + '\n')

                    self.model.train()
                    self.global_discriminator.train()
                    self.local_discriminator.train()

                    # keep best accuracy model

                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        save_name = f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                        self.save(opt, save_name)
                    if current_norm_ED < best_norm_ED:
                        best_norm_ED = current_norm_ED
                        save_name = f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                        self.save(opt, save_name)

                    best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                    print(best_model_log)
                    log.write(best_model_log + '\n')

            # save model per 1e+5 iter.
            if (step + 1) % 1e+5 == 0:
                save_name = f'./saved_models/{opt.experiment_name}/iter_{step+1}.pth'
                self.save(opt, save_name)
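The adaptation weights in this example follow a DANN-style schedule: every 2000 steps, gamma (passed to both discriminators through set_beta) and omega (the weight on the domain losses) are recomputed from training progress. A hedged, standalone sketch of just that schedule, assuming only numpy:

import numpy as np

def adaptation_weights(step, num_iter, start_iter=0):
    p = float(step + start_iter) / num_iter       # training progress in [0, 1]
    gamma = 2.0 / (1.0 + np.exp(-10 * p)) - 1.0   # grows from 0 toward 1
    omega = 1.0 - 1.0 / (1.0 + np.exp(-10 * p))   # decays from 0.5 toward 0
    return gamma, omega

for step in (0, 25000, 50000, 100000):
    print(step, adaptation_weights(step, num_iter=100000))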
Example #3
def train(opt):
    print(opt.local_rank)
    opt.device = torch.device('cuda:{}'.format(opt.local_rank))
    device = opt.device
    """ dataset preparation """
    train_dataset = Batch_Balanced_Dataset(opt)

    valid_loader = train_dataset.getValDataloader()
    print('-' * 80)
    """ model configuration """
    if 'CTC' == opt.Prediction:
        converter = CTCLabelConverter(opt.character, opt)
    elif 'Attn' == opt.Prediction:
        converter = AttnLabelConverter(opt.character, opt)
    elif 'CTC_Attn' == opt.Prediction:
        converter = CTCLabelConverter(opt.character, opt), AttnLabelConverter(
            opt.character, opt)

    opt.num_class = len(opt.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)

    model.to(opt.device)

    print(model)
    print('model input parameters', opt.rgb, opt.imgH, opt.imgW,
          opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling,
          opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue
    """ setup loss """
    if 'CTC' == opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    elif 'Attn' == opt.Prediction:
        criterion = torch.nn.CrossEntropyLoss(
            ignore_index=0).to(device), torch.nn.MSELoss(
                reduction="sum").to(device)
        # ignore [GO] token = ignore index 0
    elif 'CTC_Attn' == opt.Prediction:
        criterion = torch.nn.CTCLoss(
            zero_infinity=True).to(device), torch.nn.CrossEntropyLoss(
                ignore_index=0).to(device), torch.nn.MSELoss(
                    reduction='sum').to(device)

    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent updates
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))

    if opt.local_rank == 0:
        print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.sgd:
        optimizer = optim.SGD(filtered_parameters,
                              lr=opt.lr,
                              momentum=0.9,
                              weight_decay=opt.weight_decay)
    elif opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    if opt.local_rank == 0:
        print("Optimizer:")
        print(optimizer)

    if opt.sync_bn:
        model = apex.parallel.convert_syncbn_model(model)

    if opt.amp > 1:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O" + str(opt.amp),
                                          keep_batchnorm_fp32=True,
                                          loss_scale="dynamic")
    else:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O" + str(opt.amp))

    # data parallel for multi-GPU
    model = DDP(model)

    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        try:
            model.load_state_dict(
                torch.load(opt.continue_model,
                           map_location=torch.device(
                               'cuda', torch.cuda.current_device())))
        except:
            traceback.print_exc()
            print(f'COPYING pretrained model from {opt.continue_model}')
            pretrained_dict = torch.load(opt.continue_model,
                                         map_location=torch.device(
                                             'cuda',
                                             torch.cuda.current_device()))

            model_dict = model.state_dict()
            pretrained_dict2 = dict()
            for k, v in pretrained_dict.items():
                if opt.Prediction == 'Attn':
                    if 'module.Prediction_attn.' in k:
                        k = k.replace('module.Prediction_attn.',
                                      'module.Prediction.')
                if k in model_dict and model_dict[k].shape == v.shape:
                    pretrained_dict2[k] = v

            model_dict.update(pretrained_dict2)
            model.load_state_dict(model_dict)

    model.train()
    """ final options """
    with open(f'./saved_models/{opt.experiment_name}/opt.txt',
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        opt_log += str(model)
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """
    start_iter = 0

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    i = start_iter

    ct = opt.batch_mul
    model.zero_grad()

    dist.barrier()
    while (True):
        # train part
        start = time.time()
        image, labels, pos = train_dataset.sync_get_batch()
        end = time.time()
        data_t = end - start

        start = time.time()
        batch_size = image.size(0)

        if 'CTC' == opt.Prediction:
            text, length = converter.encode(
                labels, batch_max_length=opt.batch_max_length)
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] *
                                         batch_size).to(device)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format

            # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            cost = criterion(preds, text, preds_size, length)
            torch.backends.cudnn.enabled = True
        elif 'Attn' == opt.Prediction:
            text, length = converter.encode(
                labels, batch_max_length=opt.batch_max_length)
            preds = model(image, text[:, :-1])  # align with Attention.forward
            preds_attn = preds[0]
            preds_alpha = preds[1]
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion[0](preds_attn.view(-1, preds_attn.shape[-1]),
                                target.contiguous().view(-1))

            if opt.posreg_w > 0.001:
                cost_pos = alpha_loss(preds_alpha, pos, opt, criterion[1])
                print('attn_cost = ', cost, 'pos_cost = ',
                      cost_pos * opt.posreg_w)
                cost += opt.posreg_w * cost_pos
            else:
                print('attn_cost = ', cost)

        elif 'CTC_Attn' == opt.Prediction:
            text_ctc, length_ctc = converter[0].encode(
                labels, batch_max_length=opt.batch_max_length)
            text_attn, length_attn = converter[1].encode(
                labels, batch_max_length=opt.batch_max_length)
            """ ctc prediction and loss """

            # should input text_attn here
            preds = model(image, text_attn[:, :-1])
            preds_ctc = preds[0].log_softmax(2)

            preds_ctc_size = torch.IntTensor([preds_ctc.size(1)] *
                                             batch_size).to(device)
            preds_ctc = preds_ctc.permute(1, 0, 2)  # to use CTCLoss format

            # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16

            torch.backends.cudnn.enabled = False
            cost_ctc = criterion[0](preds_ctc, text_ctc, preds_ctc_size,
                                    length_ctc)
            torch.backends.cudnn.enabled = True
            """ attention prediction and loss """
            preds_attn = preds[1][0]  # align with Attention.forward
            preds_alpha = preds[1][1]

            target = text_attn[:, 1:]  # without [GO] Symbol

            cost_attn = criterion[1](preds_attn.view(-1, preds_attn.shape[-1]),
                                     target.contiguous().view(-1))

            cost = opt.ctc_attn_loss_ratio * cost_ctc + (
                1 - opt.ctc_attn_loss_ratio) * cost_attn

            if opt.posreg_w > 0.001:
                cost_pos = alpha_loss(preds_alpha, pos, opt, criterion[2])
                cost += opt.posreg_w * cost_pos
                cost_ctc = reduce_tensor(cost_ctc)
                cost_attn = reduce_tensor(cost_attn)
                cost_pos = reduce_tensor(cost_pos)
                if opt.local_rank == 0:
                    print('ctc_cost = ', cost_ctc, 'attn_cost = ', cost_attn,
                          'pos_cost = ', cost_pos * opt.posreg_w)
            else:
                cost_ctc = reduce_tensor(cost_ctc)
                cost_attn = reduce_tensor(cost_attn)
                if opt.local_rank == 0:
                    print('ctc_cost = ', cost_ctc, 'attn_cost = ', cost_attn)

        cost /= opt.batch_mul
        if opt.amp > 0:
            with amp.scale_loss(cost, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            cost.backward()
        """ https://github.com/davidlmorton/learning-rate-schedules/blob/master/increasing_batch_size_without_increasing_memory.ipynb """
        ct -= 1
        if ct == 0:
            if opt.amp > 0:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               opt.grad_clip)
            else:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    opt.grad_clip)  # gradient clipping with 5 (Default)
            optimizer.step()
            model.zero_grad()
            ct = opt.batch_mul
        else:
            continue

        train_t = time.time() - start
        cost = reduce_tensor(cost)
        loss_avg.add(cost)
        if opt.local_rank == 0:
            print('iter', i, 'loss =', cost, ', data_t=', data_t, ',train_t=',
                  train_t, ', batchsz=', opt.batch_mul * opt.batch_size)
        sys.stdout.flush()
        # validation part
        if (i > 0 and i % opt.valInterval == 0) or (i == 0 and
                                                    opt.continue_model != ''):
            elapsed_time = time.time() - start_time
            print(
                f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
            )
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt',
                      'a') as log:
                log.write(
                    f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                )
                loss_avg.reset()

                model.eval()
                with torch.no_grad():
                    if 'CTC_Attn' in opt.Prediction:
                        # we only count for attention accuracy, because ctc is used to help attention
                        valid_loss, current_accuracy_ctc, current_accuracy, current_norm_ED_ctc, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                            model, criterion[1], valid_loader, converter[1],
                            opt, converter[0])
                    elif 'Attn' in opt.Prediction:
                        valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                            model, criterion[0], valid_loader, converter, opt)
                    else:
                        valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                            model, criterion, valid_loader, converter, opt)
                model.train()

                for pred, gt in zip(preds[:10], labels[:10]):
                    if 'Attn' in opt.Prediction:
                        pred = pred[:pred.find('[s]')]
                        gt = gt[:gt.find('[s]')]
                    print(f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}')
                    log.write(
                        f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}\n')

                valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'

                if 'CTC_Attn' in opt.Prediction:
                    valid_log += f' ctc_accuracy: {current_accuracy_ctc:0.3f}, ctc_norm_ED: {current_norm_ED_ctc:0.2f}'
                    current_accuracy = max(current_accuracy,
                                           current_accuracy_ctc)
                    current_norm_ED = min(current_norm_ED, current_norm_ED_ctc)

                if opt.local_rank == 0:
                    print(valid_log)
                    log.write(valid_log + '\n')

                    # keep best accuracy model
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                        )
                        torch.save(
                            model,
                            f'./saved_models/{opt.experiment_name}/best_accuracy.model'
                        )
                    if current_norm_ED < best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                        )
                        torch.save(
                            model,
                            f'./saved_models/{opt.experiment_name}/best_norm_ED.model'
                        )
                    best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                    print(best_model_log)
                    log.write(best_model_log + '\n')

        # save model per iter.
        if (i + 1) % opt.save_interval == 0 and opt.local_rank == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        if opt.prof_iter > 0 and i > opt.prof_iter:
            sys.exit()

        i += 1
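This example grows the effective batch size without extra memory by accumulating gradients over opt.batch_mul micro-batches before each optimizer step (the notebook URL quoted in the docstring above describes the trick). A minimal, standalone sketch of that counter pattern with made-up values:

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(8, 2)
optimizer = torch.optim.Adadelta(model.parameters(), lr=1.0)
batch_mul, grad_clip = 4, 5.0   # assumed values

model.zero_grad()
ct = batch_mul
for i in range(20):                                   # stand-in training loop
    x = torch.randn(16, 8)
    y = torch.randint(0, 2, (16,))
    cost = F.cross_entropy(model(x), y) / batch_mul   # cost /= opt.batch_mul
    cost.backward()                                   # gradients accumulate across micro-batches
    ct -= 1
    if ct == 0:                                       # step only every batch_mul batches
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        model.zero_grad()
        ct = batch_mul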
Example #4
def train(opt, log):
    """dataset preparation"""
    # train dataset. for convenience
    if opt.select_data == "label":
        select_data = [
            "1.SVT",
            "2.IIIT",
            "3.IC13",
            "4.IC15",
            "5.COCO",
            "6.RCTW17",
            "7.Uber",
            "8.ArT",
            "9.LSVT",
            "10.MLT19",
            "11.ReCTS",
        ]

    elif opt.select_data == "synth":
        select_data = ["MJ", "ST"]

    elif opt.select_data == "synth_SA":
        select_data = ["MJ", "ST", "SA"]
        opt.batch_ratio = "0.4-0.4-0.2"  # same ratio with SCATTER paper.

    elif opt.select_data == "mix":
        select_data = [
            "1.SVT",
            "2.IIIT",
            "3.IC13",
            "4.IC15",
            "5.COCO",
            "6.RCTW17",
            "7.Uber",
            "8.ArT",
            "9.LSVT",
            "10.MLT19",
            "11.ReCTS",
            "MJ",
            "ST",
        ]

    elif opt.select_data == "mix_SA":
        select_data = [
            "1.SVT",
            "2.IIIT",
            "3.IC13",
            "4.IC15",
            "5.COCO",
            "6.RCTW17",
            "7.Uber",
            "8.ArT",
            "9.LSVT",
            "10.MLT19",
            "11.ReCTS",
            "MJ",
            "ST",
            "SA",
        ]

    else:
        select_data = opt.select_data.split("-")

    # set batch_ratio for each data.
    if opt.batch_ratio:
        batch_ratio = opt.batch_ratio.split("-")
    else:
        batch_ratio = [round(1 / len(select_data), 3)] * len(select_data)

    train_loader = Batch_Balanced_Dataset(opt, opt.train_data, select_data,
                                          batch_ratio, log)

    if opt.semi != "None":
        select_data_unlabel = ["U1.Book32", "U2.TextVQA", "U3.STVQA"]
        batch_ratio_unlabel = [round(1 / len(select_data_unlabel), 3)
                               ] * len(select_data_unlabel)
        dataset_root_unlabel = "data_CVPR2021/training/unlabel/"
        train_loader_unlabel_semi = Batch_Balanced_Dataset(
            opt,
            dataset_root_unlabel,
            select_data_unlabel,
            batch_ratio_unlabel,
            log,
            learn_type="semi",
        )

    AlignCollate_valid = AlignCollate(opt, mode="test")
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt, mode="test")
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=False,
    )
    log.write(valid_dataset_log)
    print("-" * 80)
    log.write("-" * 80 + "\n")
    """ model configuration """
    if "CTC" in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
        opt.sos_token_index = converter.dict["[SOS]"]
        opt.eos_token_index = converter.dict["[EOS]"]
    opt.num_class = len(converter.character)

    model = Model(opt)

    # weight initialization
    for name, param in model.named_parameters():
        if "localization_fc2" in name:
            print(f"Skip {name} as it is already initialized")
            continue
        try:
            if "bias" in name:
                init.constant_(param, 0.0)
            elif "weight" in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if "weight" in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != "":
        fine_tuning_log = f"### loading pretrained model from {opt.saved_model}\n"

        if "MoCo" in opt.saved_model or "MoCo" in opt.self_pre:
            pretrained_state_dict_qk = torch.load(opt.saved_model)
            pretrained_state_dict = {}
            for name in pretrained_state_dict_qk:
                if "encoder_q" in name:
                    rename = name.replace("encoder_q.", "")
                    pretrained_state_dict[rename] = pretrained_state_dict_qk[
                        name]
        else:
            pretrained_state_dict = torch.load(opt.saved_model)

        for name, param in model.named_parameters():
            try:
                param.data.copy_(pretrained_state_dict[name].data
                                 )  # load from pretrained model
                if opt.FT == "freeze":
                    param.requires_grad = False  # Freeze
                    fine_tuning_log += f"pretrained layer (freezed): {name}\n"
                else:
                    fine_tuning_log += f"pretrained layer: {name}\n"
            except:
                fine_tuning_log += f"non-pretrained layer: {name}\n"

        print(fine_tuning_log)
        log.write(fine_tuning_log + "\n")

    # print("Model:")
    # print(model)
    log.write(repr(model) + "\n")
    """ setup loss """
    if "CTC" in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        # ignore [PAD] token
        criterion = torch.nn.CrossEntropyLoss(
            ignore_index=converter.dict["[PAD]"]).to(device)

    if "Pseudo" in opt.semi:
        criterion_SemiSL = PseudoLabelLoss(opt, converter, criterion)
    elif "MeanT" in opt.semi:
        criterion_SemiSL = MeanTeacherLoss(opt, student_for_init_teacher=model)

    # loss averager
    train_loss_avg = Averager()
    semi_loss_avg = Averager()  # semi supervised loss avg

    # keep only the parameters that require gradient descent updates
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print(f"Trainable params num: {sum(params_num)}")
    log.write(f"Trainable params num: {sum(params_num)}\n")
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            filtered_parameters,
            lr=opt.lr,
            momentum=opt.sgd_momentum,
            weight_decay=opt.sgd_weight_decay,
        )
    elif opt.optimizer == "adadelta":
        optimizer = torch.optim.Adadelta(filtered_parameters,
                                         lr=opt.lr,
                                         rho=opt.rho,
                                         eps=opt.eps)
    elif opt.optimizer == "adam":
        optimizer = torch.optim.Adam(filtered_parameters, lr=opt.lr)
    print("Optimizer:")
    print(optimizer)
    log.write(repr(optimizer) + "\n")

    if "super" in opt.schedule:
        if opt.optimizer == "sgd":
            cycle_momentum = True
        else:
            cycle_momentum = False

        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=opt.lr,
            cycle_momentum=cycle_momentum,
            div_factor=20,
            final_div_factor=1000,
            total_steps=opt.num_iter,
        )
        print("Scheduler:")
        print(scheduler)
        log.write(repr(scheduler) + "\n")
    """ final options """
    # print(opt)
    opt_log = "------------ Options -------------\n"
    args = vars(opt)
    for k, v in args.items():
        if str(k) == "character" and len(str(v)) > 500:
            opt_log += f"{str(k)}: So many characters to show all: number of characters: {len(str(v))}\n"
        else:
            opt_log += f"{str(k)}: {str(v)}\n"
    opt_log += "---------------------------------------\n"
    print(opt_log)
    log.write(opt_log)
    log.close()
    """ start training """
    start_iter = 0
    if opt.saved_model != "":
        try:
            start_iter = int(opt.saved_model.split("_")[-1].split(".")[0])
            print(f"continue to train, start_iter: {start_iter}")
        except:
            pass

    start_time = time.time()
    best_score = -1

    # training loop
    for iteration in tqdm(
            range(start_iter + 1, opt.num_iter + 1),
            total=opt.num_iter,
            position=0,
            leave=True,
    ):
        if "MeanT" in opt.semi:
            image_tensors, image_tensors_ema, labels = train_loader.get_batch_ema(
            )
        else:
            image_tensors, labels = train_loader.get_batch()

        image = image_tensors.to(device)
        labels_index, labels_length = converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        # default recognition loss part
        if "CTC" in opt.Prediction:
            preds = model(image)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds_log_softmax = preds.log_softmax(2).permute(1, 0, 2)
            loss = criterion(preds_log_softmax, labels_index, preds_size,
                             labels_length)
        else:
            preds = model(image,
                          labels_index[:, :-1])  # align with Attention.forward
            target = labels_index[:, 1:]  # without [SOS] Symbol
            loss = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        # semi supervised part (SemiSL)
        if "Pseudo" in opt.semi:
            image_unlabel, _ = train_loader_unlabel_semi.get_batch_two_images()
            image_unlabel = image_unlabel.to(device)
            loss_SemiSL = criterion_SemiSL(image_unlabel, model)

            loss = loss + loss_SemiSL
            semi_loss_avg.add(loss_SemiSL)

        elif "MeanT" in opt.semi:
            (
                image_tensors_unlabel,
                image_tensors_unlabel_ema,
            ) = train_loader_unlabel_semi.get_batch_two_images()
            image_unlabel = image_tensors_unlabel.to(device)
            student_input = torch.cat([image, image_unlabel], dim=0)

            image_ema = image_tensors_ema.to(device)
            image_unlabel_ema = image_tensors_unlabel_ema.to(device)
            teacher_input = torch.cat([image_ema, image_unlabel_ema], dim=0)
            loss_SemiSL = criterion_SemiSL(
                student_input=student_input,
                student_logit=preds,
                student=model,
                teacher_input=teacher_input,
                iteration=iteration,
            )

            loss = loss + loss_SemiSL
            semi_loss_avg.add(loss_SemiSL)

        model.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        train_loss_avg.add(loss)

        if "super" in opt.schedule:
            scheduler.step()
        else:
            adjust_learning_rate(optimizer, iteration, opt)

        # validation part.
        # To see training progress, we also conduct validation when 'iteration == 1'
        if iteration % opt.val_interval == 0 or iteration == 1:
            # for validation log
            with open(f"./saved_models/{opt.exp_name}/log_train.txt",
                      "a") as log:
                model.eval()
                with torch.no_grad():
                    (
                        valid_loss,
                        current_score,
                        preds,
                        confidence_score,
                        labels,
                        infer_time,
                        length_of_data,
                    ) = validation(model, criterion, valid_loader, converter,
                                   opt)
                model.train()

                # keep best score (accuracy or norm ED) model on valid dataset
                # Do not use this on test datasets. It would be an unfair comparison
                # (training should be done without referring test set).
                if current_score > best_score:
                    best_score = current_score
                    torch.save(
                        model.state_dict(),
                        f"./saved_models/{opt.exp_name}/best_score.pth",
                    )

                # validation log: loss, lr, score (accuracy or norm ED), time.
                lr = optimizer.param_groups[0]["lr"]
                elapsed_time = time.time() - start_time
                valid_log = f"\n[{iteration}/{opt.num_iter}] Train_loss: {train_loss_avg.val():0.5f}, Valid_loss: {valid_loss:0.5f}"
                valid_log += f", Semi_loss: {semi_loss_avg.val():0.5f}\n"
                valid_log += f'{"Current_score":17s}: {current_score:0.2f}, Current_lr: {lr:0.7f}\n'
                valid_log += f'{"Best_score":17s}: {best_score:0.2f}, Infer_time: {infer_time:0.1f}, Elapsed_time: {elapsed_time:0.1f}'

                # show some predicted results
                dashed_line = "-" * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f"{dashed_line}\n{head}\n{dashed_line}\n"
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if "Attn" in opt.Prediction:
                        gt = gt[:gt.find("[EOS]")]
                        pred = pred[:pred.find("[EOS]")]

                    predicted_result_log += f"{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n"
                predicted_result_log += f"{dashed_line}"
                valid_log = f"{valid_log}\n{predicted_result_log}"
                print(valid_log)
                log.write(valid_log + "\n")

                opt.writer.add_scalar("train/train_loss",
                                      float(f"{train_loss_avg.val():0.5f}"),
                                      iteration)
                opt.writer.add_scalar("train/semi_loss",
                                      float(f"{semi_loss_avg.val():0.5f}"),
                                      iteration)
                opt.writer.add_scalar("train/lr", float(f"{lr:0.7f}"),
                                      iteration)
                opt.writer.add_scalar("train/elapsed_time",
                                      float(f"{elapsed_time:0.1f}"), iteration)
                opt.writer.add_scalar("valid/valid_loss",
                                      float(f"{valid_loss:0.5f}"), iteration)
                opt.writer.add_scalar("valid/current_score",
                                      float(f"{current_score:0.2f}"),
                                      iteration)
                opt.writer.add_scalar("valid/best_score",
                                      float(f"{best_score:0.2f}"), iteration)

                train_loss_avg.reset()
                semi_loss_avg.reset()
    """ Evaluation at the end of training """
    print("Start evaluation on benchmark testset")
    """ keep evaluation model and result logs """
    os.makedirs(f"./result/{opt.exp_name}", exist_ok=True)
    os.makedirs(f"./evaluation_log", exist_ok=True)
    saved_best_model = f"./saved_models/{opt.exp_name}/best_score.pth"
    # os.system(f'cp {saved_best_model} ./result/{opt.exp_name}/')
    model.load_state_dict(torch.load(f"{saved_best_model}"))

    opt.eval_type = "benchmark"
    model.eval()
    with torch.no_grad():
        total_accuracy, eval_data_list, accuracy_list = benchmark_all_eval(
            model, criterion, converter, opt)

    opt.writer.add_scalar("test/total_accuracy",
                          float(f"{total_accuracy:0.2f}"), iteration)
    for eval_data, accuracy in zip(eval_data_list, accuracy_list):
        accuracy = float(accuracy)
        opt.writer.add_scalar(f"test/{eval_data}", float(f"{accuracy:0.2f}"),
                              iteration)

    print(
        f'finished the experiment: {opt.exp_name}, "CUDA_VISIBLE_DEVICES" was {opt.CUDA_VISIBLE_DEVICES}'
    )
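When opt.schedule contains "super", this example hands learning-rate control to PyTorch's OneCycleLR instead of the manual adjust_learning_rate helper. A hedged sketch of that scheduler in isolation; the optimizer, step count, and printing interval below are assumptions:

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1.0,
    cycle_momentum=True,     # only meaningful for momentum-based optimizers such as SGD
    div_factor=20,           # initial lr = max_lr / 20
    final_div_factor=1000,   # final lr = initial lr / 1000
    total_steps=1000,
)

for step in range(1000):
    optimizer.step()         # normally preceded by a forward/backward pass on a batch
    scheduler.step()
    if step % 250 == 0:
        print(step, scheduler.get_last_lr())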
Example #5
def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    # train_dataset (image, label)
    train_dataset = Batch_Balanced_Dataset(opt)
    log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()
    """ model configuration """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)

    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    print("Model:")
    print(model)

    total_num, true_grad_num, false_grad_num = calculate_model_params(model)
    print("Total parameters: ", total_num)
    print("Number of parameters requires grad: ", true_grad_num)
    print("Number of parameters do not require grad: ", false_grad_num)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    # load pretrained model
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_pretrained_networks()
        elif opt.continue_train:
            model.load_checkpoint(opt.model_name)
        else:
            raise Exception('Something went wrong!')
    """ setup loss """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()
    log_dir = f'./saved_models/{opt.exp_name}'
    writer = SummaryWriter(log_dir)

    # """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """
    start_iter = 0
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while (True):
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        model.optimize_parameters()
        writer.add_scalar('train_loss', cost, iteration + 1)
        loss_avg.add(cost)

        # validation part
        # To see training progress, we also conduct validation when 'iteration == 0'.
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt,
                        iteration)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                writer.add_scalar('val_loss', valid_loss, iteration + 1)
                writer.add_scalar('accuracy', current_accuracy, iteration + 1)
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    model.save_checkpoints(iteration, 'best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    model.save_checkpoints(iteration, 'best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            model.save_checkpoints(iteration + 1, opt.model_name)

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
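
The example above takes different loss paths for CTC and attention heads, and the difference is mostly in the tensor shapes fed to the loss. Below is a minimal, self-contained sketch of those shape conventions; the sizes and dummy tensors are illustrative assumptions, not values from the repo.

import torch

batch_size, T, num_class, max_len = 4, 26, 38, 25

# CTC path: torch.nn.CTCLoss expects (T, N, C) log-probabilities.
ctc_criterion = torch.nn.CTCLoss(zero_infinity=True)
preds = torch.randn(batch_size, T, num_class)               # model output, (N, T, C)
log_probs = preds.log_softmax(2).permute(1, 0, 2)           # -> (T, N, C)
preds_size = torch.IntTensor([T] * batch_size)              # input lengths
targets = torch.randint(1, num_class, (batch_size, 10), dtype=torch.long)
target_lengths = torch.IntTensor([10] * batch_size)
ctc_cost = ctc_criterion(log_probs, targets, preds_size, target_lengths)

# Attention path: flatten to (N*T, C) logits vs (N*T,) targets, ignoring [GO]=0.
attn_criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
attn_preds = torch.randn(batch_size, max_len, num_class)
text = torch.randint(0, num_class, (batch_size, max_len + 1))
target = text[:, 1:]                                         # drop the [GO] symbol
attn_cost = attn_criterion(attn_preds.reshape(-1, num_class), target.reshape(-1))
print(ctc_cost.item(), attn_cost.item())
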
Beispiel #6
0
def train(opt):
    """ dataset preparation """
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        # 'True' to check training progress with validation function.
        shuffle=True,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    print('-' * 80)
    """ model configuration """
    if 'Transformer' in opt.SequenceModeling:
        converter = TransformerLabelConverter(opt.character)
    elif 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue
    """ setup loss """
    if 'Transformer' in opt.SequenceModeling:
        criterion = transformer_loss
    elif 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
    else:
        # ignore [GO] token = ignore index 0
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()
    # loss averager
    loss_avg = Averager()

    # filter: keep only parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    elif 'Transformer' in opt.SequenceModeling and opt.use_scheduled_optim:
        optimizer = optim.Adam(filtered_parameters,
                               betas=(0.9, 0.98),
                               eps=1e-09)
        optimizer_schedule = ScheduledOptim(optimizer, opt.d_model,
                                            opt.n_warmup_steps)
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)
    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.experiment_name}/opt.txt',
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """
    start_iter = 0

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    pickle.load = partial(pickle.load, encoding="latin1")
    pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
    if opt.load_weights != '' and check_isfile(opt.load_weights):
        # load pretrained weights but ignore layers that don't match in size
        checkpoint = torch.load(opt.load_weights, pickle_module=pickle)
        if type(checkpoint) == dict:
            pretrain_dict = checkpoint['state_dict']
        else:
            pretrain_dict = checkpoint
        model_dict = model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        model.load_state_dict(model_dict)
        print("Loaded pretrained weights from '{}'".format(opt.load_weights))
        del checkpoint
        torch.cuda.empty_cache()
    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        checkpoint = torch.load(opt.continue_model)
        print(checkpoint.keys())
        model.load_state_dict(checkpoint['state_dict'])
        start_iter = checkpoint['step'] + 1
        print('continue to train start_iter: ', start_iter)
        if 'optimizer' in checkpoint.keys():
            optimizer.load_state_dict(checkpoint['optimizer'])
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()
        if 'best_accuracy' in checkpoint.keys():
            best_accuracy = checkpoint['best_accuracy']
        if 'best_norm_ED' in checkpoint.keys():
            best_norm_ED = checkpoint['best_norm_ED']
        del checkpoint
        torch.cuda.empty_cache()
    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).cuda()
    model.train()
    print("Model size:", count_num_param(model), 'M')
    if 'Transformer' in opt.SequenceModeling and opt.use_scheduled_optim:
        optimizer_schedule.n_current_steps = start_iter

    for i in tqdm(range(start_iter, opt.num_iter)):
        for p in model.parameters():
            p.requires_grad = True

        cpu_images, cpu_texts = train_dataset.get_batch()
        image = cpu_images.cuda()
        if 'Transformer' in opt.SequenceModeling:
            text, length, text_pos = converter.encode(cpu_texts,
                                                      opt.batch_max_length)
        elif 'CTC' in opt.Prediction:
            text, length = converter.encode(cpu_texts)
        else:
            text, length = converter.encode(cpu_texts, opt.batch_max_length)
        batch_size = image.size(0)

        if 'Transformer' in opt.SequenceModeling:
            preds = model(image, text, tgt_pos=text_pos)
            target = text[:, 1:]  # without <s> Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))
        elif 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format
            cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text)
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()

        if 'Transformer' in opt.SequenceModeling and opt.use_scheduled_optim:
            optimizer_schedule.step_and_update_lr()
        elif 'Transformer' in opt.SequenceModeling:
            optimizer.step()
        else:
            # gradient clipping with 5 (Default)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i > 0 and (i + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            print(
                f'[{i+1}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
            )
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt',
                      'a') as log:
                log.write(
                    f'[{i+1}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                )
                loss_avg.reset()

                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, gts, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                for pred, gt in zip(preds[:5], gts[:5]):
                    if 'Transformer' in opt.SequenceModeling:
                        pred = pred[:pred.find('</s>')]
                        gt = gt[:gt.find('</s>')]
                    elif 'Attn' in opt.Prediction:
                        pred = pred[:pred.find('[s]')]
                        gt = gt[:gt.find('[s]')]

                    print(f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}')
                    log.write(
                        f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}\n')

                valid_log = f'[{i+1}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                print(valid_log)
                log.write(valid_log + '\n')

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    state_dict = model.module.state_dict()
                    save_checkpoint(
                        {
                            'best_accuracy': best_accuracy,
                            'state_dict': state_dict,
                        }, False,
                        f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                    )
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    state_dict = model.module.state_dict()
                    save_checkpoint(
                        {
                            'best_norm_ED': best_norm_ED,
                            'state_dict': state_dict,
                        }, False,
                        f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                    )
                    # torch.save(
                    #     model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

        # save model per 1000 iter.
        if (i + 1) % 1000 == 0:
            state_dict = model.module.state_dict()
            optimizer_state_dict = optimizer.state_dict()
            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'optimizer': optimizer_state_dict,
                    'step': i,
                    'best_accuracy': best_accuracy,
                    'best_norm_ED': best_norm_ED,
                }, False,
                f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')
                valid = matches > -1
                mkpts0 = kpts0[valid]
                mkpts1 = kpts1[matches[valid]]
                mconf = conf[valid]
                viz_path = eval_output_dir / '{}_matches.{}'.format(
                    str(i), opt.viz_extension)
                color = cm.jet(mconf)
                stem = pred['file_name']
                text = []

                out = make_matching_plot(image0, image1, kpts0, kpts1, mkpts0,
                                         mkpts1, color, text, viz_path, stem,
                                         opt.show_keypoints, opt.fast_viz,
                                         opt.opencv_display, 'Matches')

        ### eval ###
        superglue.eval()
        homography_auc = validation(superglue, 'datasets/val.txt')

        epoch_loss /= len(train_loader)

        if not os.path.isdir("exp"):
            os.makedirs("exp")

        model_out_path = "exp/model_epoch_{}_{}.pth".format(
            epoch, homography_auc)
        torch.save(superglue, model_out_path)
        print(
            "Epoch [{}/{}] done. Epoch Loss {}. Checkpoint saved to {}".format(
                epoch, opt.epoch, epoch_loss, model_out_path))
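
Beispiel #6 loads pretrained weights while silently dropping every tensor whose name or size no longer matches the current model, which is the usual trick when the prediction head has changed. A minimal sketch of that size-filtered loading, with a throwaway checkpoint standing in for a real one:

import torch
import torch.nn as nn

def load_matching_weights(model, ckpt_path):
    # keep only pretrained tensors whose name and shape match the current model
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    pretrain_dict = checkpoint['state_dict'] if isinstance(checkpoint, dict) and 'state_dict' in checkpoint else checkpoint
    model_dict = model.state_dict()
    matched = {k: v for k, v in pretrain_dict.items()
               if k in model_dict and model_dict[k].size() == v.size()}
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return sorted(matched)

model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 5))
torch.save(model.state_dict(), 'demo_ckpt.pth')        # throwaway checkpoint for the demo
print(load_matching_weights(model, 'demo_ckpt.pth'))
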
Beispiel #8
0
def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=
        True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()
    """ model configuration """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)
    """ setup loss """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # filter: keep only parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)
    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while (True):
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if (
                iteration + 1
        ) % opt.valInterval == 0 or iteration == 0:  # To see training progress, we also conduct validation when 'iteration == 0'
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            torch.save(
                model.state_dict(),
                f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
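
All of these train(opt) functions expect an argparse-style namespace. A minimal sketch of building one for experimentation follows; only a few of the options used above are included, and the default values are assumptions rather than the repo's defaults.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--exp_name', default='demo')
parser.add_argument('--valid_data', default='data/valid')
parser.add_argument('--batch_size', type=int, default=192)
parser.add_argument('--num_iter', type=int, default=300000)
parser.add_argument('--valInterval', type=int, default=2000)
parser.add_argument('--grad_clip', type=float, default=5)
parser.add_argument('--Prediction', default='Attn', help="'CTC' or 'Attn'")
parser.add_argument('--saved_model', default='', help='path to a checkpoint to resume from')
opt = parser.parse_args([])   # empty list: take the defaults instead of reading sys.argv
print(vars(opt))
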
    def train(self, opt):
        # Add custom dataset add cfgs from da-faster-rcnn
        # Make sure you change the imdb_name in factory.py
        """
        Dummy format:

        args.src_dataset == '$YOUR_DATASET_NAME'
        args.src_imdb_name = '$YOUR_DATASET_NAME_2007_trainval'
        args.src_imdbval_name = '$YOUR_DATASET_NAME_2007_test'
        args.set_cfgs = [...]
        """

        # src, tar dataloaders
        src_dataset, tar_dataset, valid_loader = self.dataloader(opt)
        src_dataset_size = src_dataset.total_data_size
        tar_dataset_size = tar_dataset.total_data_size
        train_size = max([src_dataset_size, tar_dataset_size])

        self.model.train()

        start_iter = 0

        if opt.continue_model != '':
            self.load(opt.continue_model)
            print(" [*] Load SUCCESS")
            # if opt.decay_flag and start_iter > (opt.num_iter // 2):
            #     self.d_image_opt.param_groups[0]['lr'] -= (opt.lr / (opt.num_iter // 2)) * (
            #             start_iter - opt.num_iter // 2)
            #     self.d_inst_opt.param_groups[0]['lr'] -= (opt.lr / (opt.num_iter // 2)) * (
            #             start_iter - opt.num_iter // 2)

        # loss averager
        cls_loss_avg = Averager()
        sim_loss_avg = Averager()
        loss_avg = Averager()

        # training loop
        print('training start !')
        start_time = time.time()
        best_accuracy = -1
        best_norm_ED = 1e+6

        for step in range(start_iter, opt.num_iter + 1):

            src_image, src_labels = src_dataset.get_batch()
            src_image = src_image.to(device)
            src_text, src_length = self.converter.encode(
                src_labels, batch_max_length=opt.batch_max_length)

            tar_image, tar_labels = tar_dataset.get_batch()
            tar_image = tar_image.to(device)
            tar_text, tar_length = self.converter.encode(
                tar_labels, batch_max_length=opt.batch_max_length)

            # Set gradient to zero...
            self.model.zero_grad()

            # Attention # align with Attention.forward
            src_preds, src_global_feature, src_local_feature = self.model(
                src_image, src_text[:, :-1])
            target = src_text[:, 1:]  # without [GO] Symbol
            src_cls_loss = self.criterion(
                src_preds.view(-1, src_preds.shape[-1]),
                target.contiguous().view(-1))

            src_local_feature = src_local_feature.view(
                -1, src_local_feature.shape[-1])
            # TODO
            tar_preds, tar_global_feature, tar_local_feature = self.model(
                tar_image, tar_text[:, :-1], is_train=False)

            tar_local_feature = tar_local_feature.view(
                -1, tar_local_feature.shape[-1])

            d_inst_loss = coral_loss(src_local_feature, src_preds,
                                     tar_local_feature, tar_preds)
            # Add domain loss
            loss = src_cls_loss.mean() + 0.1 * d_inst_loss.mean()
            loss_avg.add(loss)
            cls_loss_avg.add(src_cls_loss)
            sim_loss_avg.add(d_inst_loss)

            # frcnn backward
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                opt.grad_clip)  # gradient clipping with 5 (Default)
            # frcnn optimizer update
            self.optimizer.step()

            # validation part
            if step % opt.valInterval == 0:

                elapsed_time = time.time() - start_time
                print(
                    f'[{step}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} CLS_Loss: {cls_loss_avg.val():0.5f} SIMI_Loss: {sim_loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
                )
                # for log
                with open(
                        f'./saved_models/{opt.experiment_name}/log_train.txt',
                        'a') as log:
                    log.write(
                        f'[{step}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                    )
                    loss_avg.reset()
                    cls_loss_avg.reset()
                    sim_loss_avg.reset()

                    self.model.eval()
                    with torch.no_grad():
                        valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                            self.model, self.criterion, valid_loader,
                            self.converter, opt)

                    self.print_prediction_result(preds, labels, log)

                    valid_log = f'[{step}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                    valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                    print(valid_log)
                    log.write(valid_log + '\n')

                    self.model.train()

                    # keep best accuracy model
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        save_name = f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                        self.save(opt, save_name)
                    if current_norm_ED < best_norm_ED:
                        best_norm_ED = current_norm_ED
                        save_name = f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                        self.save(opt, save_name)

                    best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                    print(best_model_log)
                    log.write(best_model_log + '\n')

            # save model per 1e+5 iter.
            if (step + 1) % 1e+5 == 0:
                save_name = f'./saved_models/{opt.experiment_name}/iter_{step+1}.pth'
                self.save(opt, save_name)
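
The class method above mixes the recognition loss with a small, fixed-weight domain-similarity term (loss = src_cls_loss.mean() + 0.1 * d_inst_loss.mean()). A minimal sketch of that weighting pattern follows; the feature_distance helper is a stand-in for illustration and is not the coral_loss used in the snippet.

import torch
import torch.nn as nn

def feature_distance(src_feat, tar_feat):
    # stand-in for a domain-alignment loss such as coral_loss; not the repo's implementation
    return (src_feat.mean(0) - tar_feat.mean(0)).pow(2).mean()

criterion = nn.CrossEntropyLoss(ignore_index=0)               # ignore [GO] token = index 0
src_preds = torch.randn(4, 10, 38, requires_grad=True)        # (batch, steps, classes)
src_target = torch.randint(0, 38, (4, 10))
src_local_feature = torch.randn(40, 256, requires_grad=True)
tar_local_feature = torch.randn(40, 256)

src_cls_loss = criterion(src_preds.reshape(-1, 38), src_target.reshape(-1))
d_inst_loss = feature_distance(src_local_feature, tar_local_feature)
loss = src_cls_loss.mean() + 0.1 * d_inst_loss.mean()   # small fixed weight on the domain term
loss.backward()
print(loss.item())
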
Beispiel #10
0
def train(opt,
          AMP,
          WdB,
          ralph_path,
          train_data_path,
          train_data_list,
          test_data_path,
          test_data_list,
          experiment_name,
          train_batch_size,
          val_batch_size,
          workers,
          lr,
          valInterval,
          num_iter,
          wdbprj,
          continue_model='',
          finetune=''):

    HVD3P = pO.HVD or pO.DDP

    os.makedirs(f'./saved_models/{experiment_name}', exist_ok=True)

    # if OnceExecWorker and WdB:
    #     wandb.init(project=wdbprj, name=experiment_name)
    #     wandb.config.update(opt)

    # load supplied ralph
    with open(ralph_path, 'r') as f:
        ralph_train = json.load(f)

    print('[4] IN TRAIN; BEFORE MAKING DATASET')
    train_dataset = ds_load.myLoadDS(train_data_list,
                                     train_data_path,
                                     ralph=ralph_train)
    valid_dataset = ds_load.myLoadDS(test_data_list,
                                     test_data_path,
                                     ralph=ralph_train)

    # SAVE RALPH FOR LATER USE
    # with open(f'./saved_models/{experiment_name}/ralph.json', 'w+') as f:
    #     json.dump(train_dataset.ralph, f)

    print('[5] DATASET DONE LOADING')
    if OnceExecWorker:
        print(pO)
        print('Alphabet :', len(train_dataset.alph), train_dataset.alph)
        for d in [train_dataset, valid_dataset]:
            print('Dataset Size :', len(d.fns))
            # print('Max LbW : ',max(list(map(len,d.tlbls))) )
            # print('#Chars : ',sum([len(x) for x in d.tlbls]))
            # print('Sample label :',d.tlbls[-1])
            # print("Dataset :", sorted(list(map(len,d.tlbls))) )
            print('-' * 80)

    if opt.num_gpu > 1:
        workers = workers * (1 if HVD3P else opt.num_gpu)

    if HVD3P:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=opt.world_size, rank=opt.rank)
        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_dataset, num_replicas=opt.world_size, rank=opt.rank)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True if not HVD3P else False,
        pin_memory=True,
        num_workers=int(workers),
        sampler=train_sampler if HVD3P else None,
        worker_init_fn=WrkSeeder,
        collate_fn=ds_load.SameTrCollate)

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=val_batch_size,
        pin_memory=True,
        num_workers=int(workers),
        sampler=valid_sampler if HVD3P else None)

    model = OrigamiNet()
    model.apply(init_bn)

    # load finetune ckpt
    if finetune != '':
        model = load_finetune(model, finetune)

    model.train()

    if OnceExecWorker:
        import pprint
        [print(k, model.lreszs[k]) for k in sorted(model.lreszs.keys())]

    biparams = list(
        dict(filter(lambda kv: 'bias' in kv[0],
                    model.named_parameters())).values())
    nonbiparams = list(
        dict(filter(lambda kv: 'bias' not in kv[0],
                    model.named_parameters())).values())

    if not pO.DDP:
        model = model.to(device)
    else:
        model.cuda(opt.rank)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                          gamma=10**(-1 /
                                                                     90000))

    # if OnceExecWorker and WdB:
    #     wandb.watch(model, log="all")

    if pO.HVD:
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        optimizer = hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters())
        # optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters(), compression=hvd.Compression.fp16)

    if pO.DDP and opt.rank != 0:
        random.seed()
        np.random.seed()

    # if AMP:
    #     model, optimizer = amp.initialize(model, optimizer, opt_level = "O1")
    if pO.DP:
        model = torch.nn.DataParallel(model)
    elif pO.DDP:
        model = pDDP(model,
                     device_ids=[opt.rank],
                     output_device=opt.rank,
                     find_unused_parameters=False)

    model_ema = ModelEma(model)

    if continue_model != '':
        if OnceExecWorker:
            print(f'loading pretrained model from {continue_model}')
        checkpoint = torch.load(
            continue_model, map_location=f'cuda:{opt.rank}' if HVD3P else None)
        model.load_state_dict(checkpoint['model'], strict=True)
        optimizer.load_state_dict(checkpoint['optimizer'])
        model_ema._load_checkpoint(continue_model,
                                   f'cuda:{opt.rank}' if HVD3P else None)

    criterion = torch.nn.CTCLoss(reduction='none',
                                 zero_infinity=True).to(device)
    converter = CTCLabelConverter(train_dataset.ralph.values())

    if OnceExecWorker:
        with open(f'./saved_models/{experiment_name}/opt.txt',
                  'a') as opt_file:
            opt_log = '------------ Options -------------\n'
            args = vars(opt)
            for k, v in args.items():
                opt_log += f'{str(k)}: {str(v)}\n'
            opt_log += '---------------------------------------\n'
            opt_log += gin.operative_config_str()
            opt_file.write(opt_log)
            # if WdB:
            #     wandb.config.gin_str = gin.operative_config_str().splitlines()

        print(optimizer)
        print(opt_log)

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    best_CER = 1e+6
    i = 0
    gAcc = 1
    epoch = 1
    btReplay = False and AMP
    max_batch_replays = 3

    if HVD3P: train_sampler.set_epoch(epoch)
    titer = iter(train_loader)

    while (True):
        start_time = time.time()

        model.zero_grad()
        train_loss = Metric(pO, 'train_loss')

        for j in trange(valInterval, leave=False, desc='Training'):

            # Load a batch
            try:
                image_tensors, labels, fnames = next(titer)
            except StopIteration:
                epoch += 1
                if HVD3P: train_sampler.set_epoch(epoch)
                titer = iter(train_loader)
                image_tensors, labels, fnames = next(titer)

            # log filenames
            # fnames = [f'{i}___{fname}' for fname in fnames]
            # with open(f'./saved_models/{experiment_name}/filelog.txt', 'a+') as f:
            #     f.write('\n'.join(fnames) + '\n')

            # Move to device
            image = image_tensors.to(device)
            text, length = converter.encode(labels)
            batch_size = image.size(0)

            replay_batch = True
            maxR = 3
            while replay_batch and maxR > 0:
                maxR -= 1

                # Forward pass
                preds = model(image, text).float()
                preds_size = torch.IntTensor([preds.size(1)] *
                                             batch_size).to(device)
                preds = preds.permute(1, 0, 2).log_softmax(2)

                if i == 0 and OnceExecWorker:
                    print('Model inp : ', image.dtype, image.size())
                    print('CTC inp : ', preds.dtype, preds.size(),
                          preds_size[0])

                # to avoid a cudnn ctc_loss issue, cudnn is disabled while computing the CTC loss
                torch.backends.cudnn.enabled = False
                cost = criterion(preds, text.to(device), preds_size,
                                 length.to(device)).mean() / gAcc
                torch.backends.cudnn.enabled = True

                train_loss.update(cost)

                # cost tracking?
                # with open(f'./saved_models/{experiment_name}/steplog.txt', 'a+') as f:
                #     f.write(f'Step {i} cost: {cost}\n')

                optimizer.zero_grad()
                default_optimizer_step = optimizer.step  # added for batch replay

                # Backward and step
                if not AMP:
                    cost.backward()
                    replay_batch = False
                else:
                    # with amp.scale_loss(cost, optimizer) as scaled_loss:
                    #     scaled_loss.backward()
                    #     if pO.HVD: optimizer.synchronize()

                    # if optimizer.step is default_optimizer_step or not btReplay:
                    #     replay_batch = False
                    # elif maxR>0:
                    #     optimizer.step()
                    pass

            if btReplay:
                pass  #amp._amp_state.loss_scalers[0]._loss_scale = mx_sc

            if (i + 1) % gAcc == 0:

                if pO.HVD and AMP:
                    with optimizer.skip_synchronize():
                        optimizer.step()
                else:
                    optimizer.step()

                model.zero_grad()
                model_ema.update(model, num_updates=i / 2)

                if (i + 1) % (gAcc * 2) == 0:
                    lr_scheduler.step()

            i += 1

        # validation part
        if True:

            elapsed_time = time.time() - start_time
            start_time = time.time()

            model.eval()
            with torch.no_grad():

                valid_loss, current_accuracy, current_norm_ED, ted, bleu, preds, labels, infer_time = validation(
                    model_ema.ema, criterion, valid_loader, converter, opt, pO)

            model.train()
            v_time = time.time() - start_time

            if OnceExecWorker:
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    checkpoint = {
                        'model': model.state_dict(),
                        'state_dict_ema': model_ema.ema.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }
                    torch.save(
                        checkpoint,
                        f'./saved_models/{experiment_name}/best_norm_ED.pth')

                if ted < best_CER:
                    best_CER = ted

                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy

                out = f'[{i}] Loss: {train_loss.avg:0.5f} time: ({elapsed_time:0.1f},{v_time:0.1f})'
                out += f' vloss: {valid_loss:0.3f}'
                out += f' CER: {ted:0.4f} NER: {current_norm_ED:0.4f} lr: {lr_scheduler.get_lr()[0]:0.5f}'
                out += f' bAcc: {best_accuracy:0.1f}, bNER: {best_norm_ED:0.4f}, bCER: {best_CER:0.4f}, B: {bleu*100:0.2f}'
                print(out)

                with open(f'./saved_models/{experiment_name}/log_train.txt',
                          'a') as log:
                    log.write(out + '\n')

                # if WdB:
                #     wandb.log({'lr': lr_scheduler.get_lr()[0], 'It':i, 'nED': current_norm_ED,  'B':bleu*100,
                #     'tloss':train_loss.avg, 'AnED': best_norm_ED, 'CER':ted, 'bestCER':best_CER, 'vloss':valid_loss})

        if DEBUG:
            print(
                f'[!!!] Iteration check. Value of i: {i} | Value of num_iter: {num_iter}'
            )

        # Change i == num_iter to i >= num_iter
        # Add num_iter > 0 condition
        if num_iter > 0 and i >= num_iter:
            print('end the training')
            #sys.exit()
            break
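
Beispiel #10 validates with an exponential moving average of the weights (its ModelEma helper). A hand-rolled sketch of the underlying idea, for illustration only and not the ModelEma class the example imports:

import copy
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
ema_model = copy.deepcopy(model)
decay = 0.999

@torch.no_grad()
def ema_update(model, ema_model, decay):
    # ema = decay * ema + (1 - decay) * current weights
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1 - decay)

# called after each optimizer.step(); validation then runs on ema_model
ema_update(model, ema_model, decay)
print(next(ema_model.parameters()).shape)
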
Beispiel #11
0
def train(args):

    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined
    model_name = args.model_name
    model_loss_fn = args.loss_fn

    config_file = 'config.yaml'

    config = load_config(config_file)
    data_root = config['PATH']['data_root']
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['root']
    model_dir = config['PATH']['model_path']
    best_dir = config['PATH']['best_model_path']

    data_class = config['PATH']['data_class']
    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    base_channel = int(config['PARAMETERS']['base_channels'])
    crop_size = int(config['PARAMETERS']['crop_size'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    epochs = int(config['PARAMETERS']['epoch'])
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    ignore_idx = int(config['PARAMETERS']['ignore_index'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermediate_data_save = os.path.join(root_path, 'train_data', model_name)
    train_info_file = os.path.join(intermediate_data_save,
                                   '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    if not os.path.exists(model_path):
        os.mkdir(model_path)
    if not os.path.exists(best_path):
        os.mkdir(best_path)
    if not os.path.exists(intermediate_data_save):
        os.makedirs(intermediate_data_save)
    if not os.path.exists(log_path):
        os.mkdir(log_path)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('Dataset is loading ...')
    # split dataset
    dir_ = os.path.join(data_root, data_class)
    data_content = train_split(dir_)

    # load training set and validation set
    train_set = data_loader(data_content=data_content,
                            key='train',
                            form='LGG',
                            crop_size=crop_size,
                            batch_size=batch_size,
                            num_works=8)
    n_train = len(train_set)
    train_loader = train_set.load()

    val_set = data_loader(data_content=data_content,
                          key='val',
                          form='LGG',
                          crop_size=crop_size,
                          batch_size=batch_size,
                          num_works=8)

    logger.info('Dataset loading finished!')

    n_val = len(val_set)
    nb_batches = np.ceil(n_train / batch_size)
    n_total = n_train + n_val
    logger.info(
        '{} images will be used in total, {} for training and {} for validation'
        .format(n_total, n_train, n_val))

    net = init_U_Net(input_modalites, output_channels, base_channel)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    if torch.cuda.device_count() > 1:
        logger.info('{} GPUs available.'.format(torch.cuda.device_count()))
        net = nn.DataParallel(net)

    net.to(device)

    if model_loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'CrossEntropy':
        criterion = CrossEntropyLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'FocalLoss':
        criterion = FocalLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_CE':
        criterion = Dice_CE(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_FL':
        criterion = Dice_FL(labels=labels, ignore_index=ignore_idx)
    else:
        raise NotImplementedError()

    optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)

    net, optimizer = amp.initialize(net, optimizer, opt_level='O1')

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'BG_acc': [],
        'NET_acc': [],
        'ED_acc': [],
        'ET_acc': []
    }

    if is_resume:
        try:
            ckp_path = os.path.join(model_dir,
                                    '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)

            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info(
                'Training loss from last time is {}'.format(start_loss) +
                '\n' +
                'Minimum training loss from last time is {}'.format(min_loss))

        except:
            logger.warning(
                'No checkpoint available, start training from scratch')

    # start training
    for epoch in range(start_epoch, epochs):

        # setup to train mode
        net.train()
        running_loss = 0
        dice_coeff_bg = 0
        dice_coeff_net = 0
        dice_coeff_ed = 0
        dice_coeff_et = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train,
                  desc=f'Epoch {epoch+1}/{epochs}',
                  unit='patch') as pbar:

            for i, data in enumerate(train_loader, 0):
                images, segs = data['image'].to(device), data['seg'].to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs = net(images)

                loss = criterion(outputs, segs)
                # loss.backward()
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()

                optimizer.step()

                # save the output at the beginning of each epoch to visualize it
                if i == 0:
                    in_images = images.detach().cpu().numpy()[:, 0, ...]
                    in_segs = segs.detach().cpu().numpy()
                    in_pred = outputs.detach().cpu().numpy()
                    heatmap_plot(image=in_images,
                                 mask=in_segs,
                                 pred=in_pred,
                                 name=model_name,
                                 epoch=epoch + 1)

                running_loss += loss.detach().item()
                dice_score = dice_coe(outputs.detach().cpu(),
                                      segs.detach().cpu())
                dice_coeff_bg += dice_score['BG']
                dice_coeff_ed += dice_score['ED']
                dice_coeff_et += dice_score['ET']
                dice_coeff_net += dice_score['NET']

                # show progress bar
                pbar.set_postfix(
                    **{
                        'Training loss': loss.detach().item(),
                        'Training (avg) accuracy': dice_score['avg']
                    })
                pbar.update(images.shape[0])

                global_step += 1
                if global_step % nb_batches == 0:
                    # validate
                    net.eval()
                    val_loss, val_acc = validation(net, val_set, criterion,
                                                   device, batch_size)

        train_info['train_loss'].append(running_loss / nb_batches)
        train_info['val_loss'].append(val_loss)
        train_info['BG_acc'].append(dice_coeff_bg / nb_batches)
        train_info['NET_acc'].append(dice_coeff_net / nb_batches)
        train_info['ED_acc'].append(dice_coeff_ed / nb_batches)
        train_info['ET_acc'].append(dice_coeff_et / nb_batches)

        # save best trained model
        scheduler.step(running_loss / nb_batches)

        if min_loss > val_loss:
            min_loss = val_loss
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': running_loss / nb_batches,
            'min_loss': min_loss
        }
        verbose = save_ckp(state,
                           is_best,
                           early_stop_count=early_stop_count,
                           early_stop_patience=early_stop_patience,
                           save_model_dir=model_path,
                           best_dir=best_path,
                           name=model_name)

        logger.info('The average training loss for this epoch is {}'.format(
            running_loss / (np.ceil(n_train / batch_size))))
        logger.info(
            'Validation dice loss: {}; Validation (avg) accuracy: {}'.format(
                val_loss, val_acc))
        logger.info('The best validation loss till now is {}'.format(min_loss))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)

        loss_plot(train_info_file, name=model_name)

        if verbose:
            logger.info(
                'The validation loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    logger.info('finish training!')
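
Beispiel #11 combines ReduceLROnPlateau with a simple early-stopping counter and best-checkpoint saving. A minimal sketch of that bookkeeping with dummy validation losses; the model, patience values, and file name are illustrative.

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(3, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)

min_loss = float('inf')
early_stop_count, early_stop_patience = 0, 3

for epoch, val_loss in enumerate([0.9, 0.8, 0.85, 0.84, 0.83, 0.82]):  # dummy losses
    scheduler.step(val_loss)                     # reduce lr when the monitored loss plateaus
    if val_loss < min_loss:
        min_loss, early_stop_count = val_loss, 0
        torch.save(model.state_dict(), 'best_model.pth')   # keep the best weights so far
    else:
        early_stop_count += 1
    if early_stop_count >= early_stop_patience:
        print(f'no improvement for {early_stop_patience} epochs, stopping at epoch {epoch + 1}')
        break
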
Beispiel #12
0
def validation_part(best_accuracy, best_norm_ED, converter, criterion, i,
                    loss_avg, model, opt, start_time, valid_loader):
    if i % opt.valInterval == 0:
        elapsed_time = time.time() - start_time
        # for log
        with open(f'./saved_models/{opt.experiment_name}/log_train.txt',
                  'a') as log:
            model.eval()
            with torch.no_grad():
                valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                    model, criterion, valid_loader, converter, opt)
            model.train()

            # training loss and validation loss
            loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
            loss_avg.reset()

            current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

            # keep best accuracy model (on valid dataset)
            if current_accuracy > best_accuracy:
                best_accuracy = current_accuracy
                torch.save(
                    model.state_dict(),
                    f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
            if current_norm_ED > best_norm_ED:
                best_norm_ED = current_norm_ED
                torch.save(
                    model.state_dict(),
                    f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
            best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

            loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
            print(loss_model_log)
            log.write(loss_model_log + '\n')

            # show some predicted results
            dashed_line = '-' * 80
            head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
            predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
            for gt, pred, confidence in zip(labels[:5], preds[:5],
                                            confidence_score[:5]):
                if 'Attn' in opt.Prediction:
                    gt = gt[:gt.find('[s]')]
                    pred = pred[:pred.find('[s]')]

                predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
            predicted_result_log += f'{dashed_line}'
            print(predicted_result_log)
            log.write(predicted_result_log + '\n')
def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    # opt.select_data = opt.select_data.split('-')#[MJ,ST]
    # opt.batch_ratio = opt.batch_ratio.split('-')#[0.5,0.5]
    # train_dataset = Batch_Balanced_Dataset(opt)

    # log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    # valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    # train_dataset = iiit5k_dataset_builder("/media/ps/hd1/lll/textRecognition/SAR/IIIT5K/train",
    #     "/media/ps/hd1/lll/textRecognition/SAR/IIIT5K/traindata.mat",opt)
    
    # train_dataset = PpocrDataset("/home/ldl/桌面/论文/文本识别/data/paddleocr",
    # "/home/ldl/桌面/论文/文本识别/data/paddleocr/label/train.txt",6625*100)
    


    # train_dataset_chinese = TextRecognition(4068*75,opt.charalength,opt.chinesefile)
    # train_dataset_english = TextRecognition(4068*25,opt.charalength,opt.englishfile)
    # train_dataset = ConcatDataset([train_dataset_chinese,train_dataset_english])
    train_dataset_xunfeieng = mytrdg_cutimg_dataset(total_img_path='/home/ldl/桌面/论文/文本识别/data/finish_data/eng_image/train/img',
        annotation_path='/home/ldl/桌面/论文/文本识别/data/finish_data/eng_image/train/gt')
    train_dataset_xunfeichn = mytrdg_cutimg_dataset(total_img_path='/home/ldl/桌面/论文/文本识别/data/finish_data/lan_image/train/img',
        annotation_path='/home/ldl/桌面/论文/文本识别/data/finish_data/lan_image/train/gt')
    train_dataset = ConcatDataset([train_dataset_xunfeichn,train_dataset_xunfeieng])
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid)

    # valid_dataset = iiit5k_dataset_builder("/media/ps/hd1/lll/textRecognition/SAR/IIIT5K/test",
    #     "/media/ps/hd1/lll/textRecognition/SAR/IIIT5K/testdata.mat",opt)
    # valid_dataset_chinese = TextRecognition(1001,opt.charalength,opt.chinesefile)
    # valid_dataset_english = TextRecognition(1001,opt.charalength,opt.englishfile)
    # valid_dataset = ConcatDataset([valid_dataset_chinese,valid_dataset_english])

    valid_dataset_xunfeieng = mytrdg_cutimg_dataset(total_img_path='/home/ldl/桌面/论文/文本识别/data/finish_data/eng_image/test/img',
        annotation_path='/home/ldl/桌面/论文/文本识别/data/finish_data/eng_image/test/gt')
    valid_dataset_xunfeichn = mytrdg_cutimg_dataset(total_img_path='/home/ldl/桌面/论文/文本识别/data/finish_data/lan_image/test/img',
        annotation_path='/home/ldl/桌面/论文/文本识别/data/finish_data/lan_image/test/gt')
    valid_dataset = ConcatDataset([valid_dataset_xunfeichn,valid_dataset_xunfeieng])

    # valid_dataset = PpocrDataset("/home/ldl/桌面/论文/文本识别/data/paddleocr/Synthetic_Chinese_String_Dataset/images",
    # "/home/ldl/桌面/论文/文本识别/data/paddleocr/Synthetic_Chinese_String_Dataset/test.txt",6625,
    #     split='jpg')
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid)
    # log.write(valid_dataset_log)
    print('-' * 80)
    # log.write('-' * 80 + '\n')
    # log.close()
    
    """ model configuration """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    if opt.num_gpu > 1:
        model = torch.nn.DataParallel(model).to(device)
    else:
        model.to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    # print(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss 
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # filter the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    print(f"Train dataset length {len(train_dataset)}")
    print(f"Train dataset length {len(valid_dataset)}")
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    # if opt.adam:
    #     optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    # else:
    #     optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
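    # NOTE: the Adam/Adadelta selection above is commented out; a fixed Adam optimizer with lr=1e-4 is used instead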
    optimizer = optim.Adam(filtered_parameters,lr=0.0001)
    print("Optimizer:")
    print(optimizer)

    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    # if opt.saved_model != '':
    #     try:
    #         start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
    #         print(f'continue to train, start_iter: {start_iter}')
    #     except:
    #         pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter
    epoch = 0
    train_iter_loader = iter(train_loader)

    while(True):
        # train part
        try:

            image_tensors, labels = next(train_iter_loader)
            # if len(labels)>80:
            #     print(labels)
            print("{:4}".format(iteration), end='\r')
        except StopIteration:
            epoch += 1
            print(f"epoch:{epoch}")
            # if epoch >= 1:
            #     break
            # train_loader = torch.utils.data.DataLoader(
            #     train_dataset, batch_size=opt.batch_size,
            #     shuffle=True,  # 'True' to check training progress with validation function.
            #     num_workers=int(opt.workers),
            #     collate_fn=AlignCollate_valid)
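            # train_loader exhausted: rebuild the iterator and start the next epoch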
            train_iter_loader = iter(train_loader)
            continue
        image = image_tensors.to(device)
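        # encode the string labels into index tensors (and their lengths) for the loss computation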
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * preds.size(0))
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                try:
                    cost = criterion(preds, text, preds_size, length)
                except Exception:
                    print(preds.shape, preds_size.shape)
                    raise

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' 
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                print("validation")
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy >= best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED >= best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+4 iter.
        if (iteration + 1) % 1e+4 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')
        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
def train(opt):
    os.makedirs(opt.log, exist_ok=True)
    writer = SummaryWriter(opt.log)
    """ dataset preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=
        True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)

    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()
    """ model configuration """

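    # two converters are built because the loop below applies a CTC loss to the refiner output and an attention loss to the decoder predictions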
    ctc_converter = CTCLabelConverter(opt.character)
    attn_converter = AttnLabelConverter(opt.character)
    opt.num_class = len(attn_converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    """ setup loss """
    loss_avg = Averager()
    ctc_loss = torch.nn.CTCLoss(zero_infinity=True).to(device)
    attn_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # filter the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter
    pbar = tqdm(range(opt.num_iter))

    for iteration in pbar:

        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        ctc_text, ctc_length = ctc_converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        attn_text, attn_length = attn_converter.encode(
            labels, batch_max_length=opt.batch_max_length)

        batch_size = image.size(0)

        preds, refiner = model(image, attn_text[:, :-1])

        refiner_size = torch.IntTensor([refiner.size(1)] * batch_size)
        refiner = refiner.log_softmax(2).permute(1, 0, 2)
        refiner_loss = ctc_loss(refiner, ctc_text, refiner_size, ctc_length)

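        # combined objective: weighted CTC loss on the refiner output plus a weighted attention loss for every prediction head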
        total_loss = opt.lambda_ctc * refiner_loss
        target = attn_text[:, 1:]  # without [GO] Symbol
        for pred in preds:
            total_loss += opt.lambda_attn * attn_loss(
                pred.view(-1, pred.shape[-1]),
                target.contiguous().view(-1))

        model.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        loss_avg.add(total_loss)
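        # heuristic: tighten the gradient-clipping threshold as the running average loss drops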
        if loss_avg.val() <= 0.6:
            opt.grad_clip = 2
        if loss_avg.val() <= 0.3:
            opt.grad_clip = 1

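        # move intermediate tensors back to the CPU and release cached GPU memory between iterations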
        preds = [p.cpu() for p in preds]  # use a list so the tensors are actually moved off the GPU
        refiner = refiner.cpu()
        image = image.cpu()
        torch.cuda.empty_cache()

        writer.add_scalar('train_loss', loss_avg.val(), iteration)
        pbar.set_description('Iteration {0}/{1}, AvgLoss {2}'.format(
            iteration, opt.num_iter, loss_avg.val()))

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, attn_loss, valid_loader, attn_converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                writer.add_scalar('Val_loss', valid_loss, iteration)
                pbar.set_description(loss_log)
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy_{str(best_accuracy)}.pth'
                    )
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                # print(loss_model_log)

                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if 'Attn' in opt.Prediction or 'Transformer' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                log.write(predicted_result_log + '\n')

        # save model per 1e+3 iter.
        if (iteration + 1) % 1e+3 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.exp_name}/SCATTER_STR.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
Beispiel #15
0
def train(opt):
    """ 准备训练和验证的数据集 """
    transform = transforms.Compose([
        ToTensor(),
    ])
    train_dataset = LmdbDataset(opt.train_data, opt=opt, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
    )

    valid_dataset = LmdbDataset(root=opt.valid_data,
                                opt=opt,
                                transform=transform)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
    )
    print('-' * 80)
    """ 模型的配置 """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    model = model.to(device)
    model.train()
    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        model.load_state_dict(torch.load(opt.continue_model))
    print("Model:")
    print(model)
    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # filter the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)
    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.experiment_name}/opt.txt',
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """
    start_iter = 0
    if opt.continue_model != '':
        start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
        print(f'continue to train, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    i = start_iter

    while (True):
        # train part
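        # each pass over train_loader is one epoch; the outer while-loop keeps restarting it until opt.num_iter iterations are reached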
        for image_tensors, labels in train_loader:
            image = image_tensors.to(device)
            text, length = converter.encode(
                labels, batch_max_length=opt.batch_max_length
            )  # text: [index, index, ..., index], length: [10, 8]
            batch_size = image.size(0)

            if 'CTC' in opt.Prediction:
                # note: with xx = model(image, text) of shape torch.Size([100, 63, 7]), xx.log_softmax(2)[0][0] equals xx[0][0].log_softmax(-1)
                preds = model(image,
                              text).log_softmax(2)  # torch.Size([100, 63, 12])
                preds_size = torch.IntTensor([preds.size(1)] *
                                             batch_size).to(device)
                preds = preds.permute(
                    1, 0, 2
                )  # to use CTCLoss format  # 100 * 63 * 7 ->  63 * 100 * 7

                # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
                # https://github.com/jpuigcerver/PyLaia/issues/16
                torch.backends.cudnn.enabled = False
                cost = criterion(
                    preds, text, preds_size, length
                )  # preds.shape: torch.Size([63, 100, 7]), where 63 is the sequence length, 100 the batch size, and 7 the number of output classes; text.shape: torch.Size([1000]), i.e. 1000 target characters in total
                # preds_size: [63, 63, ..., 63] (100 entries; each 63 is the sequence length); length: [10, 10, ..., 10] (100 entries; each 10 is the label length, i.e. every image here has 10 characters)
                torch.backends.cudnn.enabled = True

            else:
                preds = model(image,
                              text[:, :-1])  # align with Attention.forward
                target = text[:, 1:]  # without [GO] Symbol
                cost = criterion(preds.view(-1, preds.shape[-1]),
                                 target.contiguous().view(-1))

            model.zero_grad()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                opt.grad_clip)  # gradient clipping with 5 (Default)
            optimizer.step()

            loss_avg.add(cost)

            # validation part
            if i % opt.valInterval == 0:
                elapsed_time = time.time() - start_time
                print(
                    f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
                )
                # for log
                with open(
                        f'./saved_models/{opt.experiment_name}/log_train.txt',
                        'a') as log:
                    log.write(
                        f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                    )
                    loss_avg.reset()

                    model.eval()
                    with torch.no_grad():
                        valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                            model, criterion, valid_loader, converter, opt)
                    model.train()

                    for pred, gt in zip(preds[:5], labels[:5]):
                        if 'Attn' in opt.Prediction:
                            pred = pred[:pred.find('[s]')]
                            gt = gt[:gt.find('[s]')]
                        print(f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}')
                        log.write(
                            f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}\n')

                    valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                    valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                    print(valid_log)
                    log.write(valid_log + '\n')

                    # keep best accuracy model
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                        )
                    if current_norm_ED < best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                        )
                    best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                    print(best_model_log)
                    log.write(best_model_log + '\n')

            # save model per 1e+5 iter.
            if (i + 1) % 1e+5 == 0:
                torch.save(
                    model.state_dict(),
                    f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

            if i == opt.num_iter:
                print('end the training')
                sys.exit()
            i += 1
Beispiel #16
0
def train(opt):
    """ dataset preparation """
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')

    #import ipdb;ipdb.set_trace()

    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=
        True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    print('-' * 80)
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU

    model = torch.nn.DataParallel(model).to(device)
    model.train()

    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        model.load_state_dict(torch.load(opt.continue_model))
    print("Model:")
    #print(model)
    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # filter the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)
    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.experiment_name}/opt.txt',
              'a',
              encoding="utf-8") as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """
    start_iter = 0
    if opt.continue_model != '':
        start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
        print(f'continue to train, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    i = start_iter

    while (True):
        # train part
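        # Batch_Balanced_Dataset.get_batch() returns one batch assembled from the selected data sources according to opt.batch_ratio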
        image_tensors, labels = train_dataset.get_batch()

        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)
        #import ipdb;ipdb.set_trace()

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)

            preds_size = torch.IntTensor([preds.size(1)] *
                                         batch_size).to(device)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format

            # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            cost = criterion(preds, text, preds_size, length)
            torch.backends.cudnn.enabled = True

        else:

            preds = model(image, text[:, :-1])  # align with Attention.forward

            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            print(
                f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
            )
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt',
                      'a',
                      encoding="utf-8") as log:
                log.write(
                    f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                )
                loss_avg.reset()

                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                for pred, gt in zip(preds[:5], labels[:5]):
                    if 'Attn' in opt.Prediction:
                        pred = pred[:pred.find('[s]')]
                        gt = gt[:gt.find('[s]')]
                    print(f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}')
                    #pred = pred.encode('utf-8')
                    #gt = gt.encode('utf-8')
                    log.write(
                        f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}\n')

                valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                print(valid_log)
                log.write(valid_log + '\n')

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                    )
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                    )
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 1e+5 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
Beispiel #17
0
def train(opt, show_number=2, amp=False):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt',
               'a',
               encoding="utf8")
    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD,
                                      contrast_adjust=opt.contrast_adjust)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=min(32, opt.batch_size),
        shuffle=
        True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        prefetch_factor=512,
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    if opt.saved_model != '':
        pretrained_dict = torch.load(opt.saved_model)
        if opt.new_prediction:
            model.Prediction = nn.Linear(
                model.SequenceModeling_output,
                len(pretrained_dict['module.Prediction.weight']))

        model = torch.nn.DataParallel(model).to(device)
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(pretrained_dict, strict=False)
        else:
            model.load_state_dict(pretrained_dict)
        if opt.new_prediction:
            model.module.Prediction = nn.Linear(
                model.module.SequenceModeling_output, opt.num_class)
            for name, param in model.module.Prediction.named_parameters():
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            model = model.to(device)
    else:
        # weight initialization
        for name, param in model.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception as e:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)
                continue
        model = torch.nn.DataParallel(model).to(device)

    model.train()
    print("Model:")
    print(model)
    count_parameters(model)
    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # freeze some layers
    try:
        if opt.freeze_FeatureFxtraction:
            for param in model.module.FeatureExtraction.parameters():
                param.requires_grad = False
        if opt.freeze_SequenceModeling:
            for param in model.module.SequenceModeling.parameters():
                param.requires_grad = False
    except:
        pass

    # filter the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.optim == 'adam':
        #optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
        optimizer = optim.Adam(filtered_parameters)
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)
    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.experiment_name}/opt.txt',
              'a',
              encoding="utf8") as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

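    # GradScaler is created unconditionally but only used on the amp=True path below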
    scaler = GradScaler()
    t1 = time.time()

    while (True):
        # train part
        optimizer.zero_grad(set_to_none=True)

        if amp:
            with autocast():
                image_tensors, labels = train_dataset.get_batch()
                image = image_tensors.to(device)
                text, length = converter.encode(
                    labels, batch_max_length=opt.batch_max_length)
                batch_size = image.size(0)

                if 'CTC' in opt.Prediction:
                    preds = model(image, text).log_softmax(2)
                    preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                    preds = preds.permute(1, 0, 2)
                    torch.backends.cudnn.enabled = False
                    cost = criterion(preds, text.to(device),
                                     preds_size.to(device), length.to(device))
                    torch.backends.cudnn.enabled = True
                else:
                    preds = model(image,
                                  text[:, :-1])  # align with Attention.forward
                    target = text[:, 1:]  # without [GO] Symbol
                    cost = criterion(preds.view(-1, preds.shape[-1]),
                                     target.contiguous().view(-1))
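            # standard AMP pattern: scale the loss for backward, unscale the gradients before clipping, then step and update the scaler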
            scaler.scale(cost).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            image_tensors, labels = train_dataset.get_batch()
            image = image_tensors.to(device)
            text, length = converter.encode(
                labels, batch_max_length=opt.batch_max_length)
            batch_size = image.size(0)
            if 'CTC' in opt.Prediction:
                preds = model(image, text).log_softmax(2)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                preds = preds.permute(1, 0, 2)
                torch.backends.cudnn.enabled = False
                cost = criterion(preds, text.to(device), preds_size.to(device),
                                 length.to(device))
                torch.backends.cudnn.enabled = True
            else:
                preds = model(image,
                              text[:, :-1])  # align with Attention.forward
                target = text[:, 1:]  # without [GO] Symbol
                cost = criterion(preds.view(-1, preds.shape[-1]),
                                 target.contiguous().view(-1))
            cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
        loss_avg.add(cost)

        # validation part
        if (i % opt.valInterval == 0) and (i != 0):
            print('training time: ', time.time() - t1)
            t1 = time.time()
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt',
                      'a',
                      encoding="utf8") as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels,\
                    infer_time, length_of_data = validation(model, criterion, valid_loader, converter, opt, device)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.4f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                    )
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                    )
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.4f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                show_number = min(show_number, len(labels))  # guard against len(labels) < show_number

                start = random.randint(0, len(labels) - show_number)
                for gt, pred, confidence in zip(
                        labels[start:start + show_number],
                        preds[start:start + show_number],
                        confidence_score[start:start + show_number]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')
                print('validation time: ', time.time() - t1)
                t1 = time.time()
        # save model per 1e+4 iter.
        if (i + 1) % 1e+4 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
def train(opt, tb):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()
    
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # filter the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    print("~~~~~~~~~~~~Gradient Descent~~~~~~~~~~~~~")
    #print(model.parameters())
    #print(model.)
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Filtered parameters for gradient descent: \n', len(filtered_parameters))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    while(True):
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)

            # (ctc_a) For PyTorch 1.2.0 and 1.3.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
            torch.backends.cudnn.enabled = True

            # # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a).
            # # With PyTorch 1.2.0, the below code occurs NAN, so you may use PyTorch 1.1.0.
            # # Thus, the result of CTCLoss is different in PyTorch 1.1.0 and PyTorch 1.2.0.
            # # See https://github.com/clovaai/deep-text-recognition-benchmark/issues/56#issuecomment-526490707
            # cost = criterion(preds, text, preds_size, length)

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            print(preds[0][0])  # debug: logits of the first decoding step for the first sample
            target = text[:, 1:]  # without [GO] Symbol
            print(target[0])  # debug: target indices for the first sample
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                tb.add_scalar('Training Loss vs Iteration', loss_avg.val(), i)              # Record to Tensorboard
                tb.add_scalar('Validation Loss vs Iteration', valid_loss, i)                # Record to Tensorboard
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'
                tb.add_scalar('Current Accuracy vs Iteration', current_accuracy, i)         # Record to Tensorboard
                tb.add_scalar('Current Norm ED vs Iteration', current_norm_ED, i)           # Record to Tensorboard            

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 100000 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
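
# Aside: the (ctc_a)/(ctc_b) comments above refer to a known cudnn issue with
# torch.nn.CTCLoss in PyTorch 1.2.0/1.3.0. The standalone sketch below is
# illustrative only (shapes and names are made up, it is not part of the example
# above); it shows the same workaround in isolation: cudnn is disabled just for
# the CTC loss call and re-enabled immediately afterwards.
import torch

criterion = torch.nn.CTCLoss(zero_infinity=True)

T, B, C = 26, 2, 38                                    # time steps, batch size, classes (0 = blank)
preds = torch.randn(T, B, C).log_softmax(2)            # (T, B, C) log-probabilities, as CTCLoss expects
text = torch.randint(1, C, (B, 5), dtype=torch.long)   # fake encoded labels, classes 1..C-1
preds_size = torch.IntTensor([T] * B)                  # every prediction uses the full time axis
length = torch.IntTensor([5] * B)                      # label length per sample

torch.backends.cudnn.enabled = False                   # avoid the cudnn ctc_loss issue (ctc_a)
cost = criterion(preds, text, preds_size, length)
torch.backends.cudnn.enabled = True
print(cost.item())
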
Beispiel #19
0
def train(opt):
    """ training pipeline for our character recognition model """
    if not opt.data_filtering_off:
        print(
            "Filtering the images containing characters which are not in opt.character"
        )
        print(
            "Filtering the images whose label is longer than opt.batch_max_length"
        )

    opt.select_data = opt.select_data.split("-")
    opt.batch_ratio = opt.batch_ratio.split("-")
    train_dataset = Batch_Balanced_Dataset(opt)

    # Log the dataset configuration so that this run can be compared with previous runs
    log = open(f"./saved_models/{opt.exp_name}/log_dataset.txt", "a")
    # Build the collation function for the dataloader from the user-supplied image parameters
    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    # Defining our validation dataloader
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True,
    )
    log.write(valid_dataset_log)
    print("-" * 80)
    log.write("-" * 80 + "\n")
    log.close()

    # Using either CTC or Attention for char predictions
    if "CTC" in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    # Running the OCR model in grayscale or RGB
    if opt.rgb:
        opt.input_channel = 3
    # Defining our model using user inputs
    model = Model(opt)
    print(
        "model input parameters",
        opt.imgH,
        opt.imgW,
        opt.num_fiducial,
        opt.input_channel,
        opt.output_channel,
        opt.hidden_size,
        opt.num_class,
        opt.batch_max_length,
        opt.Transformation,
        opt.FeatureExtraction,
        opt.SequenceModeling,
        opt.Prediction,
    )

    # weight initialization
    for name, param in model.named_parameters():
        if "localization_fc2" in name:
            print(f"Skip {name} as it is already initialized")
            continue
        try:
            if "bias" in name:
                init.constant_(param, 0.0)
            elif "weight" in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if "weight" in name:
                param.data.fill_(1)
            continue
    # Putting model in training mode
    model.train()
    # Optionally load a saved model from a previous run (fine-tuning)
    if opt.saved_model != "":
        print(f"loading pretrained model from {opt.saved_model}")
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    # print(model)
    # Send the model to CPU or GPU, depending on availability
    model.to(device)

    # Setting up loss functions in the case of either CTC or Attention
    if "CTC" in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients, i.e. those updated by gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print("Trainable params num : ", sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # Setup of optimizer to be used
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # print(opt)
    with open(f"./saved_models/{opt.exp_name}/opt.txt", "a") as opt_file:
        opt_log = "------------ Options -------------\n"
        args = vars(opt)
        for k, v in args.items():
            opt_log += f"{str(k)}: {str(v)}\n"
        opt_log += "---------------------------------------\n"
        print(opt_log)
        opt_file.write(opt_log)

    # Training iteration starts here
    start_iter = 0
    if opt.saved_model != "":
        try:
            start_iter = int(opt.saved_model.split("_")[-1].split(".")[0])
            print(f"continue to train, start_iter: {start_iter}")
        except ValueError:  # the filename does not end in an iteration number
            pass

    # Setting up initial metrics results and initializing the timer
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if "CTC" in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.log_softmax(2).permute(1, 0, 2)  # CTCLoss expects (T, B, C) log-probabilities
            cost = criterion(preds, text, preds_size, length)

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            # To see training progress, we also conduct validation when 'iteration == 0'.
            elapsed_time = time.time() - start_time
            # for log
            with open(f"./saved_models/{opt.exp_name}/log_train.txt",
                      "a") as log:
                model.eval()
                with torch.no_grad():
                    (
                        valid_loss,
                        current_accuracy,
                        current_norm_ED,
                        preds,
                        confidence_score,
                        labels,
                        infer_time,
                        length_of_data,
                    ) = validation(model, criterion, valid_loader, converter,
                                   opt)
                model.train()

                # training loss and validation loss
                loss_log = f"[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}"
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f"./saved_models/{opt.exp_name}/best_accuracy.pth",
                    )
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f"./saved_models/{opt.exp_name}/best_norm_ED.pth",
                    )
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f"{loss_log}\n{current_model_log}\n{best_model_log}"
                print(loss_model_log)
                log.write(loss_model_log + "\n")

                # show some predicted results
                dashed_line = "-" * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f"{dashed_line}\n{head}\n{dashed_line}\n"
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if "Attn" in opt.Prediction:
                        gt = gt[:gt.find("[s]")]
                        pred = pred[:pred.find("[s]")]

                    predicted_result_log += f"{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n"
                predicted_result_log += f"{dashed_line}"
                print(predicted_result_log)
                log.write(predicted_result_log + "\n")

        # save model per 1e+5 iter.
        if (iteration + 1) % 100000 == 0:
            torch.save(
                model.state_dict(),
                f"./saved_models/{opt.exp_name}/iter_{iteration+1}.pth",
            )

        if (iteration + 1) == opt.num_iter:
            print("end the training")
            sys.exit()
        iteration += 1
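
# For reference, a minimal standalone sketch of the Attention-branch loss
# alignment used in both examples above (all names and shapes below are
# illustrative, not taken from the repository): the decoder input is
# text[:, :-1] (starting with the [GO] token at index 0) and the loss target is
# the same sequence shifted left by one step, text[:, 1:], with index 0 ignored
# by the criterion.
import torch

num_class, batch_max_length, batch_size = 38, 25, 2

# Fake encoded labels: (batch, batch_max_length + 2), index 0 = [GO]/padding.
text = torch.zeros(batch_size, batch_max_length + 2, dtype=torch.long)
text[:, 1:6] = torch.randint(2, num_class, (batch_size, 5))  # 5 real characters per sample

decoder_input = text[:, :-1]   # fed to the model together with the image
target = text[:, 1:]           # the sequence shifted left by one step, without [GO]

# Fake per-step logits in place of the model output: (batch, steps, num_class).
preds = torch.randn(batch_size, batch_max_length + 1, num_class)
criterion = torch.nn.CrossEntropyLoss(ignore_index=0)  # ignore [GO]/padding positions
cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
print(decoder_input.shape, target.shape, cost.item())
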