示例#1
0
def evaluate_llhx(frame, input_loader, n_marg_llh, use_q_llh, mode, device):
    """Average ``frame.llh`` over every batch produced by ``input_loader``.

    Each batch is unpacked as ``(inputs, targets)``; the inputs are moved to
    ``device`` and the targets are ignored (``None`` is passed to
    ``frame.llh`` in their place). Every batch result is recorded in an
    ``Averager`` with ``nrep`` set to the batch length.

    Returns:
        The ``avg`` attribute of the ``Averager`` after the last batch.
    """
    running = Averager()
    for inputs, _ in input_loader:
        inputs = inputs.to(device)
        batch_llh = frame.llh(inputs, None, n_marg_llh, use_q_llh, mode)
        running.update(batch_llh, nrep=len(inputs))
    return running.avg
示例#2
0
def train(model, optimizer, source_loader, args):
    """Train ``model`` for one full pass over ``source_loader``.

    Args:
        model: network whose forward call returns a pair; the second element
            is used as the classification logits.
        optimizer: optimizer that updates the model parameters.
        source_loader: iterable yielding ``(inputs, labels)`` batches.
        args: options object; only ``args.cuda`` is read here.

    Returns:
        ``(loss, acc)``: the values reported by the loss and accuracy
        ``Averager`` instances after the pass.
    """
    model.train()

    acc_avg = Averager()
    loss_avg = Averager()

    # The criterion holds no per-batch state, so build it once.
    criterion = nn.CrossEntropyLoss()

    # Iterate the loader directly. The previous `iter(...).next()` call relied
    # on a Python 2 style method that Python 3 iterators do not provide.
    for sx, sy in source_loader:
        if args.cuda:
            sx, sy = sx.cuda(), sy.cuda()

        _, logits = model(sx)

        loss = criterion(logits, sy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = count_acc(logits, sy)
        acc_avg.add(acc)
        loss_avg.add(loss.item())

    acc = acc_avg.item()
    loss = loss_avg.item()
    return loss, acc
def validation(model, val_loader):
    """Compute the averaged validation loss of ``model`` over ``val_loader``.

    Each batch is unpacked as ``(images, targets, _)``. Images are stacked
    into one tensor; the ``'boxes'`` and ``'labels'`` entries of every target
    are moved to ``CFG.device`` as floats and handed to the model under the
    keys ``'bbox'`` and ``'cls'``. The whole pass runs with gradients
    disabled.

    Returns:
        ``value`` of the loss ``Averager`` after all batches were recorded,
        each weighted by its batch size.
    """
    print("Start validation.\n")
    val_loss_hist = Averager()

    def _as_device_float(tensor):
        # Same conversion is applied to images, boxes and labels.
        return tensor.to(CFG.device).float()

    with torch.no_grad():
        for images, targets, _ in val_loader:
            images = _as_device_float(torch.stack(images))
            bboxes = [_as_device_float(target['boxes']) for target in targets]
            labels = [_as_device_float(target['labels']) for target in targets]
            batch_size = images.shape[0]

            target_res = {'bbox': bboxes, 'cls': labels}

            # Forward pass; the model output carries the loss under 'loss'.
            output = model(images, target_res)
            val_loss_hist.update(output['loss'].detach().item(), batch_size)

            del images, targets, bboxes

    return val_loss_hist.value
示例#4
0
    def train(self, epoch, dataloader):
        """Run one training epoch over ``dataloader`` and log the mean losses.

        Each sample is indexed as ``(data, target, name, *extra)``; any
        elements past the third are forwarded to ``self.warp`` as ``args``
        (``None`` when there are none). After the last batch the scheduler
        is stepped with ``epoch`` and the averaged losses plus the current
        learning rate are written through the private loss-log helper.
        """
        cuda_ok = torch.cuda.is_available()
        self.net.train()

        # "freeze" option: keep batch-norm layers in eval mode while training.
        for module in self.net.modules():
            if isinstance(module, _BatchNorm) and self.args["train"]["freeze"]:
                module.eval()

        loss_avg = Averager()
        lls_avg = Averager()

        for batch_idx, sample in enumerate(dataloader):
            data, target, name = sample[0], sample[1], sample[2]
            if cuda_ok:
                data, target = data.cuda(), target.cuda()

            args = sample[3:] if len(sample) > 3 else None

            # warp usually runs in parallel and returns the losses of a batch.
            total_loss, loss_list = self.warp(data, target, False, args)
            loss_avg.update(total_loss.mean().detach().cpu().numpy())
            loss_list = tuple(part.cpu().numpy() for part in loss_list)

            pieces = [
                "Finish training %d out of %d, " % (batch_idx + 1,
                                                    len(dataloader))
            ]
            pieces.extend("loss %d: %.4f, " % (lid, float(np.mean(part)))
                          for lid, part in enumerate(loss_list))
            print("".join(pieces))

            lls_avg.update(loss_list)
            self.optimizer.zero_grad()
            loss_scalar = torch.mean(total_loss)
            if not self.half:
                loss_scalar.backward()
            else:
                # Automated mixed precision: backprop through the scaled loss.
                with amp.scale_loss(loss_scalar,
                                    self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            torch.nn.utils.clip_grad_value_(self.net.parameters(), 2)
            self.optimizer.step()

        self.scheduler.step(epoch=epoch)
        current_lr = self.scheduler.get_lr()[0]
        self.__writeLossLog(
            "Train",
            epoch,
            meanloss=loss_avg.val(),
            loss_list=lls_avg.val(),
            lr=current_lr,
        )
示例#5
0
    def __init__(self,
                 algo,
                 dataset,
                 kernel_fn,
                 base_model_fn,
                 num_particles=10,
                 resume=False,
                 resume_epoch=None,
                 resume_lr=1e-4):
        """Prepare toy data, the particle models, their parameters and optimizer.

        Args:
            algo: algorithm name; stored and printed only.
            dataset: ``'regression'`` or ``'classification'``; any other value
                raises ``NotImplementedError``.
            kernel_fn: kernel name; only ``'rbf'`` is accepted.
            base_model_fn: zero-argument callable building one model; it is
                called ``num_particles`` times and each model is moved to CUDA.
            num_particles: number of models to build.
            resume, resume_epoch, resume_lr: accepted but not read here.
        """
        self.algo = algo
        self.dataset = dataset
        self.kernel_fn = kernel_fn
        self.num_particles = num_particles
        print("running {} on {}".format(algo, dataset))

        # Toy data for the selected task.
        if self.dataset == 'regression':
            self.data = toy.generate_regression_data(80, 200)
            train_split, test_split = self.data
            self.train_data, self.train_targets = train_split
            self.test_data, self.test_targets = test_split
        elif self.dataset == 'classification':
            train_split = toy.generate_classification_data(100)
            self.train_data, self.train_targets = train_split
            test_split = toy.generate_classification_data(200)
            self.test_data, self.test_targets = test_split
        else:
            raise NotImplementedError

        if kernel_fn != 'rbf':
            raise NotImplementedError
        self.kernel = rbf_fn

        # One model per particle; their parameters are optimised jointly
        # through a single Parameter built from the extracted parameter set.
        self.models = [base_model_fn().cuda() for _ in range(num_particles)]
        param_set, state_dict = extract_parameters(self.models)
        self.state_dict = state_dict
        self.param_set = torch.nn.Parameter(param_set.clone(),
                                            requires_grad=True)

        param_group = {'params': self.param_set, 'lr': 1e-3}
        self.optimizer = torch.optim.Adam([param_group])

        # The dataset was validated above, so it is one of the two options.
        if self.dataset == 'regression':
            loss_cls = torch.nn.MSELoss
        else:
            loss_cls = torch.nn.CrossEntropyLoss
        self.loss_fn = loss_cls()
        self.kernel_width_averager = Averager(shape=())
def perform_lambda_test(n_episodes, n_avg):
    """Sweep Sarsa(lambda) over several lambda values and step sizes, then plot.

    For each ``(lambda, alpha)`` pair the averager is asked for the result of
    ``n_avg`` runs merged with ``average_steps_per_episode``; results are
    grouped per lambda and passed to ``plot_scatters_from_dict``.
    """
    task_factory = GymEpisodeTaskFactory(env, n_episodes,
                                         SarsaLambdaFactory(env))
    averager = Averager(task_factory)

    # These are divided by N_TILINGS again in sarsa to give the final alpha.
    alphas = np.arange(1, 15) / N_TILINGS
    results = defaultdict(lambda: np.zeros(len(alphas)))

    lambdas = (0, .68, .84, .92, .96, .98, .99)
    for lam in lambdas:
        for idx, alpha in enumerate(alphas):
            results[lam][idx] = averager.average(
                (lam, alpha), n_avg, merge=average_steps_per_episode)

    plot_scatters_from_dict(results, 'lambda={}', alphas)
示例#7
0
def validation(model, criterion, evaluation_loader, converter, opt):
    """Evaluate ``model`` on ``evaluation_loader`` and collect loss and accuracy.

    Returns:
        A 7-tuple ``(loss, accuracy, norm_ED, preds_str, labels, infer_time,
        length_of_data)``. ``accuracy`` and ``norm_ED`` are scaled to
        percentages of the number of evaluated samples; ``preds_str`` and
        ``labels`` are those of the last batch only.
    """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()

    for image_tensors, labels in evaluation_loader:
        batch_size = image_tensors.size(0)
        length_of_data += batch_size
        image = image_tensors.to(device)

        # Decode up to the maximum length for every sample.
        max_lengths = [opt.batch_max_length] * batch_size
        length_for_pred = torch.IntTensor(max_lengths).to(device)
        text_for_pred = torch.LongTensor(
            batch_size, opt.batch_max_length + 1).fill_(0).to(device)

        text_for_loss, length_for_loss = converter.encode(
            labels, batch_max_length=opt.batch_max_length)

        tick = time.time()
        # The model returns six values; only the predictions are used here.
        (preds, _global_feature, _local_feature, _attention_weights,
         _transformed_imgs, _control_points) = model(image,
                                                     text_for_pred,
                                                     is_train=False)
        forward_time = time.time() - tick

        preds = preds[:, :text_for_loss.shape[1] - 1, :]
        target = text_for_loss[:, 1:]  # drop the leading [GO] symbol
        flat_preds = preds.contiguous().view(-1, preds.shape[-1])
        flat_target = target.contiguous().view(-1)
        cost = criterion(flat_preds, flat_target)

        # Greedy decoding: take the arg-max index, then map back to text.
        _, preds_index = preds.max(2)
        preds_str = converter.decode(preds_index, length_for_pred)
        labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # Per-batch accuracy statistics.
        batch_n_correct, batch_char_acc = compute_loss(preds_str, labels, opt)
        n_correct += batch_n_correct
        norm_ED += batch_char_acc

    total = float(length_of_data)
    accuracy = n_correct / total * 100
    norm_ED = norm_ED / total * 100

    return (valid_loss_avg.val(), accuracy, norm_ED, preds_str, labels,
            infer_time, length_of_data)
示例#8
0
def train_base(args, base_loader, model, criterion, optimizer, epoch):
    """Train ``model`` for one epoch on ``base_loader`` and report the averages.

    Progress is printed in place after every batch; at the end of the epoch
    the averaged loss and accuracy are sent to ``wandb`` and printed.
    """
    losses = Averager()
    accs = Averager()

    model.train()
    for step, (imgs, lbls) in enumerate(base_loader, start=1):
        imgs = imgs.to(args.device)
        lbls = lbls.to(args.device)

        predicts = model.classify(model(imgs))

        loss = criterion(predicts, lbls)
        acc = accuracy_(predicts, lbls)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        model.weight_norm()

        losses.add(loss.item())
        accs.add(acc)
        progress = " [%d/%d] base loss:%.3f, acc:%.2f" % (
            step, len(base_loader), losses.item(), accs.item())
        print(progress, end='  \r')

    epoch_stats = {
        'base_train_loss': losses.item(),
        'base_train_acc': accs.item()
    }
    wandb.log(epoch_stats)
    print("Epoch %d, base loss:%.4f, acc:%.3f" %
          (epoch, losses.item(), accs.item()))
示例#9
0
def train(opt):
    """Assemble data, model, loss and optimizer from ``opt`` and run training."""
    # --- dataset preparation ---
    train_dataset, valid_loader = dataset_preparation(opt)

    # --- model configuration ---
    converter, model = model_configuration(opt)
    weight_initialization(model)
    # Wrap for multi-GPU use.
    model = set_data_parallel(model, opt)

    # --- loss and its running average ---
    criterion = setup_loss(opt)
    loss_avg = Averager()

    # --- optimizer over the parameters selected by `filter` ---
    # NOTE(review): `filter` resolves to whatever is in module scope here;
    # presumably a project helper selecting trainable parameters. Confirm it
    # is not the builtin, which cannot be called with a single argument.
    filtered_parameters = filter(model)
    optimizer = setup_optimizer(filtered_parameters, opt)

    # --- final options ---
    final_options(opt)

    # --- start training ---
    best_accuracy, best_norm_ED, start_index, start_time = start_training(opt)

    train_all_epochs(best_accuracy, best_norm_ED, converter, criterion,
                     start_index, loss_avg, model, opt, optimizer,
                     start_time, train_dataset, valid_loader)
示例#10
0
文件: run.py 项目: easezyc/MetaHeac
    def train_stage(self, data_train_stage1, data_train_stage2, model, epoch):
        """Run one meta-training epoch over tasks sampled from ``data_train_stage1``.

        Tasks are identified by the ``aid`` column. For every batch,
        ``self.task_count`` tasks are drawn according to
        ``self.sample_method`` (``'normal'``: proportional to the task's row
        count, ``'sqrt'``: proportional to its square root, ``'unit'``:
        uniformly), a support and a query set of ``self.batchsize`` rows are
        sampled per task, and ``model.global_update`` is called on them.
        ``data_train_stage2`` is not used.
        """
        print('Training Epoch {}:'.format(epoch + 1))
        model.train()
        aid_set = list(set(data_train_stage1.aid))
        avg_loss = Averager()
        data_train = data_train_stage1
        n_batch = int(np.ceil(data_train.shape[0] / self.batchsize))

        # Sampling weights per task, normalised to probabilities.
        weighted = self.sample_method in ('normal', 'sqrt')
        if weighted:
            weights = [
                data_train_stage1[data_train_stage1.aid == aid].shape[0]
                for aid in aid_set
            ]
            if self.sample_method == 'sqrt':
                weights = [math.sqrt(count) for count in weights]
            weight_total = sum(weights)
            list_prob = [weight / weight_total for weight in weights]

        for _ in tqdm.tqdm(range(n_batch)):
            if weighted:
                batch_aid_set = np.random.choice(aid_set,
                                                 size=self.task_count,
                                                 replace=False,
                                                 p=list_prob)
            elif self.sample_method == 'unit':
                batch_aid_set = random.sample(aid_set, self.task_count)

            list_sup_x, list_sup_y = [], []
            list_qry_x, list_qry_y = [], []
            for aid in batch_aid_set:
                task_rows = data_train[data_train.aid == aid]
                # Support is drawn before query, each independently.
                batch_sup = task_rows.sample(self.batchsize)
                batch_qry = task_rows.sample(self.batchsize)

                list_sup_x.append(batch_sup[self.train_col])
                list_sup_y.append(batch_sup[self.label_col].values)
                list_qry_x.append(batch_qry[self.train_col])
                list_qry_y.append(batch_qry[self.label_col].values)

            loss = model.global_update(list_sup_x, list_sup_y, list_qry_x,
                                       list_qry_y)
            avg_loss.add(loss.item())

        print('Training Epoch {}; Loss {}; '.format(epoch + 1, avg_loss.item()))
示例#11
0
def val(model, args, val_loader, label):
    """Evaluate ``model`` on ``val_loader`` against the fixed ``label`` tensor.

    The labels carried by each batch are read but not used; loss and accuracy
    are both computed against ``label``. ``args`` is not read.

    Returns:
        ``(loss, accuracy)`` as reported by the two ``Averager`` instances.
    """
    model.eval()
    loss_meter = Averager()
    acc_meter = Averager()

    with torch.no_grad():
        batches = tqdm(enumerate(val_loader, 1), total=len(val_loader))
        for _, batch in batches:
            data, _index_label = batch[0].cuda(), batch[1]
            logits = model(data, mode='val')

            loss_meter.add(F.cross_entropy(logits, label).item())
            acc_meter.add(count_acc(logits, label))

    return loss_meter.item(), acc_meter.item()
示例#12
0
def test(model, target_loader, args):
    """Return the averaged accuracy of ``model`` over ``target_loader``.

    Args:
        model: network whose forward call returns a pair; the second element
            is used as the classification logits.
        target_loader: iterable yielding ``(inputs, labels)`` batches.
        args: accepted for signature compatibility; not read here.

    Returns:
        The value reported by the accuracy ``Averager`` after all batches.
    """
    model.eval()

    acc_avg = Averager()
    use_cuda = torch.cuda.is_available()  # loop-invariant, query once
    with torch.no_grad():
        # Iterate the loader directly. The previous `iter(...).next()` call
        # relied on a Python 2 style method that Python 3 iterators lack.
        for tx, ty in target_loader:
            if use_cuda:
                tx, ty = tx.cuda(), ty.cuda()

            _, logits = model(tx)
            acc_avg.add(count_acc(logits, ty))

    return acc_avg.item()
示例#13
0
def test(model, label, args, few_shot_params):
    """Evaluate the best checkpoint on the test episodes and append the result.

    Loads ``max_acc.pth`` from ``args.checkpoint_dir`` (non-strict), runs
    1000 test episodes (10 when ``args.debug`` is set), prints the running
    accuracy periodically, and appends the mean accuracy together with
    ``args.name`` to ``result.txt`` in ``args.save_dir``.
    """
    n_test, print_freq = (10, 2) if args.debug else (1000, 100)

    test_file = args.dataset_dir + 'test.json'
    test_datamgr = SetDataManager(test_file,
                                  args.dataset_dir,
                                  args.image_size,
                                  mode='val',
                                  n_episode=n_test,
                                  **few_shot_params)
    loader = test_datamgr.get_data_loader(aug=False)

    test_acc_record = np.zeros((n_test,))

    checkpoint_path = osp.join(args.checkpoint_dir, 'max_acc' + '.pth')
    warmup_state = torch.load(checkpoint_path)['params']
    model.load_state_dict(warmup_state, strict=False)
    model.eval()

    ave_acc = Averager()
    with torch.no_grad():
        for episode, batch in enumerate(loader, 1):
            # Batch labels are moved to the GPU but accuracy uses `label`.
            data, _index_label = batch[0].cuda(), batch[1].cuda()
            logits = model(data, 'test')
            acc = count_acc(logits, label)
            ave_acc.add(acc)
            test_acc_record[episode - 1] = acc
            if episode % print_freq == 0:
                print('batch {}: {:.2f}({:.2f})'.format(
                    episode, ave_acc.item() * 100, acc * 100))

    m, pm = compute_confidence_interval(test_acc_record)
    print('Test Acc {:.4f} + {:.4f}'.format(m, pm))
    acc_str = '%4.2f' % (m * 100)
    with open(args.save_dir + '/result.txt', 'a') as f:
        f.write('%s %s\n' % (acc_str, args.name))
示例#14
0
    def main(self):
        """Train for ``self.num_epochs`` epochs, validating once per epoch.

        Returns:
            ``(loss, precisions, validation_losses)``: the per-iteration
            training losses, and the per-epoch precision and validation loss
            (both empty when ``self.val_dataset`` is falsy).

        Raises:
            ValueError: if a training loss becomes NaN.
        """
        print('Starting Training...')
        loss_hist = Averager()
        train_losses = []
        validation_losses = []
        precisions = []
        itr = 1

        for epoch in range(self.num_epochs):
            self.model.train()
            loss_hist.reset()

            for images, targets in self.train_data_loader:
                images = [image.to(self.device) for image in images]
                targets = [{k: v.long().to(self.device)
                            for k, v in t.items()} for t in targets]

                loss_dict = self.model(images, targets)

                # A stray `self.validate()` used to run here on every batch
                # with its result discarded; validation now runs only once
                # per epoch, below.
                losses = sum(loss_dict.values())
                loss_value = losses.item()

                loss_hist.send(loss_value)
                train_losses.append(loss_value)

                self.optimizer.zero_grad()
                losses.backward()

                # Check after backward so the gradients exist for the plot,
                # but before the step so NaN updates never reach the weights.
                if math.isnan(loss_value):
                    plot_grad_flow(self.model.named_parameters())
                    raise ValueError('Loss is nan')

                self.optimizer.step()

                if itr % 50 == 0:
                    print(f"Iteration #{itr} loss: {loss_value}")

                itr += 1

            # update the learning rate
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            if self.val_dataset:
                precision, validation_loss = self.validate()
                precisions.append(precision)
                validation_losses.append(validation_loss)
                print(f'Mean Precision for Validation Data: {precision}')
                print(f'Validation Loss: {validation_loss}')
            print(f"Epoch #{epoch} loss: {loss_hist.value}")
        print('Finished!')
        return train_losses, precisions, validation_losses
def validate(val_loader, model, device):
    """Return the averaged loss of ``model`` over ``val_loader``.

    Args:
        val_loader: iterable yielding ``(images, targets, image_ids)``.
        model: callable taking ``(images, targets)`` and returning the
            individual loss terms, either as a dict or as a plain iterable.
        device: device the images and target tensors are moved to.

    Returns:
        ``value`` of the loss ``Averager`` after the last batch.
    """
    # NOTE(review): the model is put in eval mode before being asked for
    # losses; confirm the model still returns losses in that mode.
    model.eval()
    loss_hist = Averager()
    loss_hist.reset()
    for itr, (images, targets, image_ids) in enumerate(val_loader, start=1):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        # Iterating a dict yields its keys, so summing it directly would try
        # to add up the loss names; sum the values instead. Non-dict outputs
        # are summed as before.
        if isinstance(loss_dict, dict):
            loss_terms = loss_dict.values()
        else:
            loss_terms = loss_dict
        losses = sum(loss_terms)
        loss_value = losses.item()
        loss_hist.send(loss_value)
        if itr % 20 == 0:
            print(f"Iteration: {itr} loss: {loss_hist.value}")
    return loss_hist.value
def perform_sarsa_lambda_comparison(n_episodes, n_avg):
    """Compare Sarsa(lambda) with True Online Sarsa(lambda) over step sizes.

    Both algorithms are run with lambda fixed at 0.84 for every alpha; the
    negated averaged result is stored per algorithm and the two series are
    plotted together.
    """
    # These are divided by N_TILINGS in sarsa to give the final alpha value.
    alphas = np.arange(0.2, 2.2, 0.2)
    lam = 0.84
    results = defaultdict(lambda: np.zeros(len(alphas)))

    contenders = (
        ('Sarsa(Lam) with replacing', SarsaLambdaFactory),
        ('True Online Sarsa(Lam)', TrueOnlineSarsaLambdaFactory),
    )
    for series_name, algorithm_factory in contenders:
        averager = Averager(
            GymEpisodeTaskFactory(env, n_episodes, algorithm_factory(env)))
        for idx, alpha in enumerate(alphas):
            results[series_name][idx] = -averager.average(
                (lam, alpha), n_avg, merge=average_steps_per_episode)

    plot_scatters_from_dict(results, '{}', alphas)
示例#17
0
def train(opt):
    """Train a text-recognition model as configured by ``opt``.

    Builds the batch-balanced training set and the validation loader,
    constructs and initialises the model, selects the loss (CTC or
    cross-entropy) and the optimizer (Adam or Adadelta), then trains until
    ``opt.num_iter`` iterations have run. Every ``opt.valInterval``
    iterations (and on the first one) the model is validated, the logs under
    ``./saved_models/{opt.exp_name}/`` are appended to, and the best-accuracy
    and best-norm-ED weights are saved.

    ``opt`` is modified in place (``select_data``, ``batch_ratio``,
    ``num_class`` and, for RGB input, ``input_channel``).

    The function does not return normally: it calls ``sys.exit()`` once
    ``opt.num_iter`` iterations are reached.
    """
    # --- dataset preparation ---
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager closes the dataset log even if building the
    # validation set or loader raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=
            True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm: fall back to a constant fill
            if 'weight' in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # --- setup loss ---
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # --- final options ---
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # The checkpoint name does not end in "_<iteration>.<ext>";
            # start counting from zero. Only the int() parse failure is
            # swallowed, so interrupts and real errors still propagate.
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        # To see training progress, validation also runs when 'iteration == 0'.
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            torch.save(
                model.state_dict(),
                f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
示例#18
0
def validation(model, criterion, evaluation_loader, converter, opt):
    """Evaluate ``model`` on ``evaluation_loader`` with CTC or attention decoding.

    Args:
        model: recognition model; called with ``(image, text_for_pred)`` and,
            on the attention path, ``is_train=False``.
        criterion: loss matching ``opt.Prediction`` (CTC or cross-entropy).
        evaluation_loader: iterable yielding ``(image_tensors, labels)``.
        converter: label converter providing ``encode`` and ``decode``.
        opt: options; ``Prediction`` and ``batch_max_length`` are read.

    Returns:
        A 7-tuple ``(loss, accuracy, norm_ED, preds_str, labels, infer_time,
        length_of_data)``. ``accuracy`` is a percentage; ``norm_ED`` is the
        accumulated per-sample score (not divided by the sample count);
        ``preds_str`` and ``labels`` are those of the last batch only.
    """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()

    for i, (image_tensors, labels) in enumerate(evaluation_loader):
        batch_size = image_tensors.size(0)
        length_of_data = length_of_data + batch_size
        image = image_tensors.to(device)
        # For max length prediction
        length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                          batch_size).to(device)
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                         1).fill_(0).to(device)

        text_for_loss, length_for_loss = converter.encode(
            labels, batch_max_length=opt.batch_max_length)

        start_time = time.time()
        if 'CTC' in opt.Prediction:
            preds = model(image, text_for_pred).log_softmax(2)
            forward_time = time.time() - start_time

            # Calculate evaluation loss for CTC decoder.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)  # to use CTCloss format

            # To avoid the ctc_loss issue, cudnn is disabled while computing
            # the loss: https://github.com/jpuigcerver/PyLaia/issues/16
            # The previous setting is restored afterwards (even if the loss
            # raises) instead of being forced to True.
            cudnn_was_enabled = torch.backends.cudnn.enabled
            torch.backends.cudnn.enabled = False
            try:
                cost = criterion(preds, text_for_loss, preds_size,
                                 length_for_loss)
            finally:
                torch.backends.cudnn.enabled = cudnn_was_enabled

            # Select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
            preds_str = converter.decode(preds_index.data, preds_size.data)

        else:
            preds = model(image, text_for_pred, is_train=False)
            forward_time = time.time() - start_time

            preds = preds[:, :text_for_loss.shape[1] - 1, :]
            target = text_for_loss[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # calculate accuracy.
        for pred, gt in zip(preds_str, labels):
            if 'Attn' in opt.Prediction:
                # prune after "end of sentence" token ([s])
                pred = pred[:pred.find('[s]')]
                gt = gt[:gt.find('[s]')]

            if pred == gt:
                n_correct += 1
            if len(gt) == 0:
                norm_ED += 1
            else:
                norm_ED += edit_distance(pred, gt) / len(gt)

    accuracy = n_correct / float(length_of_data) * 100

    return (valid_loss_avg.val(), accuracy, norm_ED, preds_str, labels,
            infer_time, length_of_data)
示例#19
0
def train(opt):
    """Train a text-recognition model with periodic validation and checkpoints.

    Depending on whether ``'CTC'`` occurs in ``opt.Prediction`` the function
    picks a CTC or an attention label converter together with the matching
    loss (``CTCLoss`` or ``CrossEntropyLoss``), and depending on ``opt.adam``
    it optimizes with Adam or Adadelta.

    Side effects visible in this function:

    * ``opt`` is mutated: ``select_data`` and ``batch_ratio`` are split on
      ``'-'`` into lists, ``num_class`` is set from the converter, and
      ``input_channel`` is forced to 3 when ``opt.rgb`` is truthy.
    * The options are appended to
      ``./saved_models/{opt.experiment_name}/opt.txt`` and training/validation
      lines are appended to ``log_train.txt`` in the same directory.
    * ``best_accuracy.pth``, ``best_norm_ED.pth`` and ``iter_<N>.pth`` state
      dicts are written into that directory.
    * The function never returns normally: the loop ends through
      ``sys.exit()`` once the iteration counter equals ``opt.num_iter``.

    ``device`` is not defined in this function; it is read from an enclosing
    scope (presumably a module-level global -- confirm against the file).

    Args:
        opt: Option namespace; the attributes read are those referenced in
            the body (``valid_data``, ``batch_size``, ``lr``, ``valInterval``,
            ``num_iter``, ``continue_model``, ...).
    """
    # dataset preparation
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')

    #import ipdb;ipdb.set_trace()

    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=
        True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    print('-' * 80)
    """ model configuration """
    # The converter maps label strings to index tensors and back; its
    # character table also determines the number of output classes.
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    # Biases are zeroed and weights get Kaiming-normal init. Parameters for
    # which the initializer raises (the original comment names batchnorm)
    # fall back to filling weights with 1.
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU

    model = torch.nn.DataParallel(model).to(device)
    model.train()

    # The checkpoint is loaded after the DataParallel wrap, so the saved
    # state dict must use the wrapped parameter names.
    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        model.load_state_dict(torch.load(opt.continue_model))
    print("Model:")
    #print(model)
    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)
    """ final options """
    # print(opt)
    # Append every option as "key: value" to opt.txt and echo it to stdout.
    with open(f'./saved_models/{opt.experiment_name}/opt.txt',
              'a',
              encoding="utf-8") as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """
    start_iter = 0
    if opt.continue_model != '':
        # The resume iteration is parsed from the checkpoint file name: the
        # text after the last '_' and before the first following '.'.
        start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
        print(f'continue to train, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    # norm_ED is compared with '<' below, i.e. lower is treated as better.
    best_norm_ED = 1e+6
    i = start_iter

    while (True):
        # train part
        image_tensors, labels = train_dataset.get_batch()

        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)
        #import ipdb;ipdb.set_trace()

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)

            # Every sample reports the same prediction length, preds.size(1).
            preds_size = torch.IntTensor([preds.size(1)] *
                                         batch_size).to(device)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format

            # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            # NOTE(review): the flag is a process-wide setting and is not
            # restored if criterion() raises.
            torch.backends.cudnn.enabled = False
            cost = criterion(preds, text, preds_size, length)
            torch.backends.cudnn.enabled = True

        else:

            preds = model(image, text[:, :-1])  # align with Attention.forward

            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            print(
                f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
            )
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt',
                      'a',
                      encoding="utf-8") as log:
                log.write(
                    f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                )
                # Restart the running training-loss average for the next
                # validation interval.
                loss_avg.reset()

                # Validate in eval mode without gradients, then switch the
                # model back to train mode.
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # Show and log the first five predictions of the batch that
                # validation() returned.
                for pred, gt in zip(preds[:5], labels[:5]):
                    if 'Attn' in opt.Prediction:
                        # prune after the "end of sentence" token ([s])
                        pred = pred[:pred.find('[s]')]
                        gt = gt[:gt.find('[s]')]
                    print(f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}')
                    #pred = pred.encode('utf-8')
                    #gt = gt.encode('utf-8')
                    log.write(
                        f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}\n')

                valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                print(valid_log)
                log.write(valid_log + '\n')

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                    )
                # keep the model with the lowest norm_ED seen so far
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                    )
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 1e+5 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        # The check runs after the training step for iteration i, so the
        # step with i == opt.num_iter is still executed before exiting.
        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
def _latest_checkpoint(pattern):
    """Return the path matching *pattern* with the highest iteration, or None.

    Checkpoints written by ``train`` are named ``iter_<N>_<kind>.pth``.
    ``glob.glob`` returns matches in arbitrary order, so the iteration number
    is parsed from the file name instead of relying on list position. Names
    that cannot be parsed sort before every parsable one.
    """
    def iteration_of(path):
        try:
            return int(os.path.basename(path).split('_')[1])
        except (IndexError, ValueError):
            return -1

    candidates = glob.glob(pattern)
    if not candidates:
        return None
    return max(candidates, key=iteration_of)


def train(opt):
    """Jointly train the synthesis model, the OCR model and the discriminator.

    Three models are built from ``opt``: ``AdaINGen`` (the synthesis model),
    ``Model`` (the OCR network used for the content loss and, unless
    ``opt.ocrFixed`` is set, trained on the input images) and
    ``MsImageDisV1`` (the image discriminator). Each gets its own optimizer
    and scheduler.

    Every fetched batch is split in half: the first half is the generator
    input, the second half supplies the "real" images for the discriminator
    (``opt.batch_size`` is doubled up front for this reason).

    Side effects visible in this function:

    * ``opt`` is mutated (``select_data``, ``batch_ratio``, ``batch_size``,
      ``num_class``, optionally ``input_channel``, ``saved_synth_model`` and
      ``saved_dis_model``).
    * Files are written below ``os.path.join(opt.exp_dir, opt.exp_name)``:
      plots, ``log_dataset.txt``, ``opt.txt``, ``log_train.txt``, sample
      training images and ``*.pth`` checkpoints.
    * The function never returns normally: the loop ends through
      ``sys.exit()`` once ``iteration + 1 == opt.num_iter``.

    ``device`` is not defined in this function; it is read from an enclosing
    scope (presumably a module-level global -- confirm against the file).

    Args:
        opt: Option namespace; the attributes read are those referenced in
            the body.
    """
    plotDir = os.path.join(opt.exp_dir,opt.exp_name,'plots')
    os.makedirs(plotDir, exist_ok=True)

    # Must run before any further local is created: it receives a snapshot
    # of locals(), which at this point holds only `opt` and `plotDir`.
    lib.print_model_settings(locals().copy())

    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')

    # Doubled because half of every batch is used as real images for the
    # discriminator.
    opt.batch_size = opt.batch_size*2

    train_dataset = Batch_Balanced_Dataset(opt)

    # `with` guarantees the dataset log is closed even if building the
    # validation dataset or loader raises.
    with open(os.path.join(opt.exp_dir,opt.exp_name,'log_dataset.txt'), 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            shuffle=False,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    model = AdaINGen(opt)
    ocrModel = Model(opt)
    disModel = MsImageDisV1(opt)

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # Weight initialization: zero biases, Kaiming-normal weights. Parameters
    # for which the initializer raises (the original comment names batchnorm)
    # fall back to filling weights with 1.
    for currModel in [model, ocrModel, disModel]:
        for name, param in currModel.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)
                continue

    # data parallel for multi-GPU
    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    if not opt.ocrFixed:
        ocrModel.train()
    else:
        # With a fixed OCR model every sub-module except SequenceModeling is
        # put into eval mode.
        ocrModel.module.Transformation.eval()
        ocrModel.module.FeatureExtraction.eval()
        ocrModel.module.AdaptiveAvgPool.eval()
        ocrModel.module.Prediction.eval()

    model = torch.nn.DataParallel(model).to(device)
    model.train()

    disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    if opt.modelFolderFlag:
        # Resume from the newest checkpoints found in the experiment folder.
        exp_path = os.path.join(opt.exp_dir,opt.exp_name)

        latest_synth = _latest_checkpoint(os.path.join(exp_path,"iter_*_synth.pth"))
        if latest_synth is not None:
            opt.saved_synth_model = latest_synth

        latest_dis = _latest_checkpoint(os.path.join(exp_path,"iter_*_dis.pth"))
        if latest_dis is not None:
            opt.saved_dis_model = latest_dis

    # Load pre-trained weights; opt.FT relaxes the key matching (strict=False).
    if opt.saved_ocr_model != '' and opt.saved_ocr_model != 'None':
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        if opt.FT:
            ocrModel.load_state_dict(torch.load(opt.saved_ocr_model), strict=False)
        else:
            ocrModel.load_state_dict(torch.load(opt.saved_ocr_model))
    print("OCRModel:")
    print(ocrModel)

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_synth_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_synth_model))
    print("SynthModel:")
    print(model)

    if opt.saved_dis_model != '' and opt.saved_dis_model != 'None':
        print(f'loading pretrained discriminator model from {opt.saved_dis_model}')
        if opt.FT:
            disModel.load_state_dict(torch.load(opt.saved_dis_model), strict=False)
        else:
            disModel.load_state_dict(torch.load(opt.saved_dis_model))
    print("DisModel:")
    print(disModel)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    recCriterion = torch.nn.L1Loss()
    styleRecCriterion = torch.nn.L1Loss()

    # loss averagers (reset after every validation round)
    loss_avg_ocr = Averager()
    loss_avg = Averager()
    loss_avg_dis = Averager()

    loss_avg_ocrRecon_1 = Averager()
    loss_avg_ocrRecon_2 = Averager()
    loss_avg_gen = Averager()
    loss_avg_imgRecon = Averager()
    loss_avg_styRecon = Averager()

    # ---- optimizers: one per model, over parameters that require grad ----
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    if opt.optim=='adam':
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay)
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay)
    print("SynthOptimizer:")
    print(optimizer)

    # parameters for OCR training
    ocr_filtered_parameters = []
    ocr_params_num = []
    for p in filter(lambda p: p.requires_grad, ocrModel.parameters()):
        ocr_filtered_parameters.append(p)
        ocr_params_num.append(np.prod(p.size()))
    print('OCR Trainable params num : ', sum(ocr_params_num))

    if opt.optim=='adam':
        ocr_optimizer = optim.Adam(ocr_filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay)
    else:
        ocr_optimizer = optim.Adadelta(ocr_filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay)
    print("OCROptimizer:")
    print(ocr_optimizer)

    # parameters for discriminator training
    dis_filtered_parameters = []
    dis_params_num = []
    for p in filter(lambda p: p.requires_grad, disModel.parameters()):
        dis_filtered_parameters.append(p)
        dis_params_num.append(np.prod(p.size()))
    print('Dis Trainable params num : ', sum(dis_params_num))

    if opt.optim=='adam':
        dis_optimizer = optim.Adam(dis_filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay)
    else:
        dis_optimizer = optim.Adadelta(dis_filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay)
    print("DisOptimizer:")
    print(dis_optimizer)

    # ---- final options: append every option to opt.txt and echo it ----
    with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        # Best effort: recover the iteration from a name such as
        # "iter_<N>_synth.pth"; any other naming keeps start_iter at 0.
        try:
            start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except (IndexError, ValueError):
            pass

    # get schedulers
    scheduler = get_scheduler(optimizer,opt)
    ocr_scheduler = get_scheduler(ocr_optimizer,opt)
    dis_scheduler = get_scheduler(dis_optimizer,opt)

    start_time = time.time()
    # The norm_ED values are compared with '>' below, i.e. higher is treated
    # as better (they are plotted as "CharAccuracy").
    best_accuracy = -1
    best_norm_ED = -1
    best_accuracy_ocr = -1
    best_norm_ED_ocr = -1
    iteration = start_iter

    while(True):
        # train part

        # NOTE(review): the schedulers are stepped before the optimizers;
        # kept as in the original to preserve the learning-rate schedule.
        if opt.lr_policy !="None":
            scheduler.step()
            ocr_scheduler.step()
            dis_scheduler.step()

        image_tensors_all, labels_1_all, labels_2_all = train_dataset.get_batch()

        # First half: generator inputs with their labels; second half: real
        # images for the discriminator.
        disCnt = int(image_tensors_all.size(0)/2)
        image_tensors = image_tensors_all[:disCnt]
        image_tensors_real = image_tensors_all[disCnt:disCnt+disCnt]
        labels_gt = labels_1_all[:disCnt]
        labels_2 = labels_2_all[:disCnt]

        image = image_tensors.to(device)
        image_real = image_tensors_real.to(device)
        batch_size = image.size(0)

        # With a fixed OCR model the source labels are the OCR predictions on
        # the input images; otherwise the ground-truth labels are used.
        if opt.ocrFixed:
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            # The predictions are only decoded to strings, so no autograd
            # graph is needed for this forward pass.
            with torch.no_grad():
                if 'CTC' in opt.Prediction:
                    # The original sliced `preds` with `text_for_loss`, a name
                    # that is never defined in this function (NameError); CTC
                    # greedy decoding uses every time step.
                    preds = ocrModel(image, text_for_pred)
                    preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                    _, preds_index = preds.max(2)
                    labels_1 = converter.decode(preds_index.data, preds_size.data)
                else:
                    preds = ocrModel(image, text_for_pred, is_train=False)
                    _, preds_index = preds.max(2)
                    labels_1 = converter.decode(preds_index, length_for_pred)
                    for idx, pred in enumerate(labels_1):
                        pred_EOS = pred.find('[s]')
                        labels_1[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
        else:
            labels_1 = labels_gt

        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

        # forward pass from style and word generator
        images_recon_1, images_recon_2, style = model(image, text_1, text_2)

        if 'CTC' in opt.Prediction:

            if not opt.ocrFixed:
                # ocr training with orig image
                preds_ocr = ocrModel(image, text_1)
                preds_size_ocr = torch.IntTensor([preds_ocr.size(1)] * batch_size)
                preds_ocr = preds_ocr.log_softmax(2).permute(1, 0, 2)

                ocrCost_train = ocrCriterion(preds_ocr, text_1, preds_size_ocr, length_1)

            # content loss for reconstructed images
            preds_1 = ocrModel(images_recon_1, text_1)
            preds_size_1 = torch.IntTensor([preds_1.size(1)] * batch_size)
            preds_1 = preds_1.log_softmax(2).permute(1, 0, 2)

            preds_2 = ocrModel(images_recon_2, text_2)
            preds_size_2 = torch.IntTensor([preds_2.size(1)] * batch_size)
            preds_2 = preds_2.log_softmax(2).permute(1, 0, 2)
            ocrCost_1 = ocrCriterion(preds_1, text_1, preds_size_1, length_1)
            ocrCost_2 = ocrCriterion(preds_2, text_2, preds_size_2, length_2)

        else:
            if not opt.ocrFixed:
                # ocr training with orig image
                preds_ocr = ocrModel(image, text_1[:, :-1])  # align with Attention.forward
                target_ocr = text_1[:, 1:]  # without [GO] Symbol

                ocrCost_train = ocrCriterion(preds_ocr.view(-1, preds_ocr.shape[-1]), target_ocr.contiguous().view(-1))

            # content loss for reconstructed images
            preds_1 = ocrModel(images_recon_1, text_1[:, :-1], is_train=False)  # align with Attention.forward
            target_1 = text_1[:, 1:]  # without [GO] Symbol

            preds_2 = ocrModel(images_recon_2, text_2[:, :-1], is_train=False)  # align with Attention.forward
            target_2 = text_2[:, 1:]  # without [GO] Symbol

            ocrCost_1 = ocrCriterion(preds_1.view(-1, preds_1.shape[-1]), target_1.contiguous().view(-1))
            ocrCost_2 = ocrCriterion(preds_2.view(-1, preds_2.shape[-1]), target_2.contiguous().view(-1))

        if not opt.ocrFixed:
            # training OCR
            ocrModel.zero_grad()
            ocrCost_train.backward()
            ocr_optimizer.step()
            loss_avg_ocr.add(ocrCost_train)
        else:
            # if ocr is fixed; ignore this loss
            loss_avg_ocr.add(torch.tensor(0.0))

        # Domain discriminator: Dis update (generator outputs are detached so
        # only the discriminator receives gradients here)
        disModel.zero_grad()
        disCost = opt.disWeight*0.5*(disModel.module.calc_dis_loss(images_recon_1.detach(), image_real) + disModel.module.calc_dis_loss(images_recon_2.detach(), image))
        disCost.backward()
        dis_optimizer.step()
        loss_avg_dis.add(disCost)

        # [Style Encoder] + [Word Generator] update
        # Adversarial loss
        disGenCost = 0.5*(disModel.module.calc_gen_loss(images_recon_1)+disModel.module.calc_gen_loss(images_recon_2))

        # Input reconstruction loss
        recCost = recCriterion(images_recon_1,image)

        # Pair style reconstruction loss
        if opt.styleReconWeight == 0.0:
            styleRecCost = torch.tensor(0.0)
        else:
            if opt.styleDetach:
                styleRecCost = styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style.detach())
            else:
                styleRecCost = styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style)

        # OCR Content cost
        ocrCost = 0.5*(ocrCost_1+ocrCost_2)

        cost = opt.ocrWeight*ocrCost + opt.reconWeight*recCost + opt.disWeight*disGenCost + opt.styleReconWeight*styleRecCost

        # All three models are zeroed, but only the synthesis optimizer steps:
        # gradients that flow into the OCR model and the discriminator during
        # this backward pass are not applied.
        model.zero_grad()
        ocrModel.zero_grad()
        disModel.zero_grad()
        cost.backward()
        optimizer.step()
        loss_avg.add(cost)

        # Individual losses
        loss_avg_ocrRecon_1.add(opt.ocrWeight*0.5*ocrCost_1)
        loss_avg_ocrRecon_2.add(opt.ocrWeight*0.5*ocrCost_2)
        loss_avg_gen.add(opt.disWeight*disGenCost)
        loss_avg_imgRecon.add(opt.reconWeight*recCost)
        loss_avg_styRecon.add(opt.styleReconWeight*styleRecCost)

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0'

            # Save training images (best effort: a failing sample only
            # prints a warning).
            os.makedirs(os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration)), exist_ok=True)
            for trImgCntr in range(batch_size):
                try:
                    save_image(tensor2im(image[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_input_'+labels_gt[trImgCntr]+'.png'))
                    save_image(tensor2im(images_recon_1[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_recon_'+labels_1[trImgCntr]+'.png'))
                    save_image(tensor2im(images_recon_2[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_pair_'+labels_2[trImgCntr]+'.png'))
                except Exception:
                    print('Warning while saving training image')

            elapsed_time = time.time() - start_time

            # for log
            with open(os.path.join(opt.exp_dir,opt.exp_name,'log_train.txt'), 'a') as log:
                # Everything goes to eval mode for validation ...
                model.eval()
                ocrModel.module.Transformation.eval()
                ocrModel.module.FeatureExtraction.eval()
                ocrModel.module.AdaptiveAvgPool.eval()
                ocrModel.module.SequenceModeling.eval()
                ocrModel.module.Prediction.eval()
                disModel.eval()

                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation_synth_lrw_res(
                        iteration, model, ocrModel, disModel, recCriterion, styleRecCriterion, ocrCriterion, valid_loader, converter, opt)

                # ... and is restored afterwards. With a fixed OCR model only
                # SequenceModeling goes back to train mode, matching the
                # setup before the loop.
                model.train()
                if not opt.ocrFixed:
                    ocrModel.train()
                else:
                    ocrModel.module.SequenceModeling.train()

                disModel.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train OCR loss: {loss_avg_ocr.val():0.5f}, Train Synth loss: {loss_avg.val():0.5f}, Train Dis loss: {loss_avg_dis.val():0.5f}, Valid OCR loss: {valid_loss[0]:0.5f}, Valid Synth loss: {valid_loss[1]:0.5f}, Valid Dis loss: {valid_loss[2]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'

                # Index 0 is the OCR on original images, 1 the reconstruction
                # and 2 the paired (second text) output.
                current_model_log_ocr = f'{"Current_accuracy_OCR":17s}: {current_accuracy[0]:0.3f}, {"Current_norm_ED_OCR":17s}: {current_norm_ED[0]:0.2f}'
                current_model_log_1 = f'{"Current_accuracy_recon":17s}: {current_accuracy[1]:0.3f}, {"Current_norm_ED_recon":17s}: {current_norm_ED[1]:0.2f}'
                current_model_log_2 = f'{"Current_accuracy_pair":17s}: {current_accuracy[2]:0.3f}, {"Current_norm_ED_pair":17s}: {current_norm_ED[2]:0.2f}'

                # plotting
                lib.plot.plot(os.path.join(plotDir,'Train-OCR-Loss'), loss_avg_ocr.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-Synth-Loss'), loss_avg.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-Dis-Loss'), loss_avg_dis.val().item())

                lib.plot.plot(os.path.join(plotDir,'Train-OCR-Recon1-Loss'), loss_avg_ocrRecon_1.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-OCR-Recon2-Loss'), loss_avg_ocrRecon_2.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-Gen-Loss'), loss_avg_gen.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-ImgRecon1-Loss'), loss_avg_imgRecon.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-StyRecon2-Loss'), loss_avg_styRecon.val().item())

                lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Loss'), valid_loss[0].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-Synth-Loss'), valid_loss[1].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-Dis-Loss'), valid_loss[2].item())

                lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Recon1-Loss'), valid_loss[3].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Recon2-Loss'), valid_loss[4].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-Gen-Loss'), valid_loss[5].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-ImgRecon1-Loss'), valid_loss[6].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-StyRecon2-Loss'), valid_loss[7].item())

                lib.plot.plot(os.path.join(plotDir,'Orig-OCR-WordAccuracy'), current_accuracy[0])
                lib.plot.plot(os.path.join(plotDir,'Recon-OCR-WordAccuracy'), current_accuracy[1])
                lib.plot.plot(os.path.join(plotDir,'Pair-OCR-WordAccuracy'), current_accuracy[2])

                lib.plot.plot(os.path.join(plotDir,'Orig-OCR-CharAccuracy'), current_norm_ED[0])
                lib.plot.plot(os.path.join(plotDir,'Recon-OCR-CharAccuracy'), current_norm_ED[1])
                lib.plot.plot(os.path.join(plotDir,'Pair-OCR-CharAccuracy'), current_norm_ED[2])

                # keep best synthesis/discriminator models, judged on the
                # reconstruction metrics (valid dataset)
                if current_accuracy[1] > best_accuracy:
                    best_accuracy = current_accuracy[1]
                    torch.save(model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy.pth'))
                    torch.save(disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy_dis.pth'))
                if current_norm_ED[1] > best_norm_ED:
                    best_norm_ED = current_norm_ED[1]
                    torch.save(model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED.pth'))
                    torch.save(disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED_dis.pth'))
                best_model_log = f'{"Best_accuracy_Recon":17s}: {best_accuracy:0.3f}, {"Best_norm_ED_Recon":17s}: {best_norm_ED:0.2f}'

                # keep best OCR model (valid dataset); nothing is saved when
                # the OCR model is fixed, but the best values are still tracked
                if current_accuracy[0] > best_accuracy_ocr:
                    best_accuracy_ocr = current_accuracy[0]
                    if not opt.ocrFixed:
                        torch.save(ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy_ocr.pth'))
                if current_norm_ED[0] > best_norm_ED_ocr:
                    best_norm_ED_ocr = current_norm_ED[0]
                    if not opt.ocrFixed:
                        torch.save(ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED_ocr.pth'))
                best_model_log_ocr = f'{"Best_accuracy_ocr":17s}: {best_accuracy_ocr:0.3f}, {"Best_norm_ED_ocr":17s}: {best_norm_ED_ocr:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log_ocr}\n{current_model_log_1}\n{current_model_log_2}\n{best_model_log_ocr}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results (first five of each output)
                dashed_line = '-' * 80
                head = f'{"Ground Truth":32s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                for gt_ocr, pred_ocr, confidence_ocr, gt_1, pred_1, confidence_1, gt_2, pred_2, confidence_2 in zip(labels[0][:5], preds[0][:5], confidence_score[0][:5], labels[1][:5], preds[1][:5], confidence_score[1][:5], labels[2][:5], preds[2][:5], confidence_score[2][:5]):
                    if 'Attn' in opt.Prediction:
                        # Only the predictions are pruned at the
                        # end-of-sentence token; ground truths are used as is.
                        pred_ocr = pred_ocr[:pred_ocr.find('[s]')]
                        pred_1 = pred_1[:pred_1.find('[s]')]
                        pred_2 = pred_2[:pred_2.find('[s]')]

                    predicted_result_log += f'{"ocr"}: {gt_ocr:27s} | {pred_ocr:25s} | {confidence_ocr:0.4f}\t{str(pred_ocr == gt_ocr)}\n'
                    predicted_result_log += f'{"recon"}: {gt_1:25s} | {pred_1:25s} | {confidence_1:0.4f}\t{str(pred_1 == gt_1)}\n'
                    predicted_result_log += f'{"pair"}: {gt_2:26s} | {pred_2:25s} | {confidence_2:0.4f}\t{str(pred_2 == gt_2)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

                # restart all running training-loss averages
                loss_avg_ocr.reset()
                loss_avg.reset()
                loss_avg_dis.reset()

                loss_avg_ocrRecon_1.reset()
                loss_avg_ocrRecon_2.reset()
                loss_avg_gen.reset()
                loss_avg_imgRecon.reset()
                loss_avg_styRecon.reset()

            lib.plot.flush()

        lib.plot.tick()

        # save model per 1e+5 iter.
        # NOTE(review): the test is on `iteration`, not `iteration + 1`, so a
        # checkpoint named iter_1_* is also written at iteration 0; kept as in
        # the original -- confirm whether this is intended.
        if (iteration) % 1e+5 == 0:
            torch.save(
                model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_synth.pth'))
            if not opt.ocrFixed:
                torch.save(
                    ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_ocr.pth'))
            torch.save(
                disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_dis.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
def train(opt):
    """Train the recognizer with a joint CTC ("refiner") + attention loss.

    The routine builds the balanced training set and the validation loader,
    initialises (and optionally restores) the model, then runs optimisation
    steps up to ``opt.num_iter``.  On step 0 and every ``opt.valInterval``
    steps it validates, logs to TensorBoard and ``log_train.txt`` and keeps
    the best checkpoints.  A rolling checkpoint is written every 1000 steps.

    Args:
        opt: option namespace.  Besides being read, it is mutated here:
            ``select_data`` / ``batch_ratio`` are split on ``'-'``,
            ``num_class`` is set from the attention converter,
            ``input_channel`` is forced to 3 when ``opt.rgb`` is set, and
            ``grad_clip`` is tightened as the running loss drops.

    Notes:
        * Relies on the module-level ``device``.
        * Files are written below ``./saved_models/{opt.exp_name}/``; that
          directory is expected to exist already.
        * The process terminates through ``sys.exit()`` after the last step.
    """
    os.makedirs(opt.log, exist_ok=True)
    writer = SummaryWriter(opt.log)

    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager guarantees the dataset log is closed even when the
    # validation dataset cannot be built.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)

        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    ctc_converter = CTCLabelConverter(opt.character)
    attn_converter = AttnLabelConverter(opt.character)
    opt.num_class = len(attn_converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm: kaiming init rejects these weights.
            if 'weight' in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))

    # ---- setup loss ----
    loss_avg = Averager()
    ctc_loss = torch.nn.CTCLoss(zero_infinity=True).to(device)
    attn_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # keep only the parameters that require gradient descent
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        for k, v in vars(opt).items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        try:
            # Checkpoints named "..._<iter>.pth" carry the step to resume from.
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # The file name does not end in an integer: start from step 0.
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    # Resume from start_iter instead of always restarting the counter at 0.
    pbar = tqdm(range(start_iter, opt.num_iter))

    for iteration in pbar:

        # ---- train part ----
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        ctc_text, ctc_length = ctc_converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        attn_text, _ = attn_converter.encode(
            labels, batch_max_length=opt.batch_max_length)

        batch_size = image.size(0)

        preds, refiner = model(image, attn_text[:, :-1])

        # CTC loss on the refiner output (time-major log-probabilities).
        refiner_size = torch.IntTensor([refiner.size(1)] * batch_size)
        refiner = refiner.log_softmax(2).permute(1, 0, 2)
        refiner_loss = ctc_loss(refiner, ctc_text, refiner_size, ctc_length)

        # Attention loss summed over every prediction the model returns.
        total_loss = opt.lambda_ctc * refiner_loss
        target = attn_text[:, 1:]  # without [GO] Symbol
        for pred in preds:
            total_loss += opt.lambda_attn * attn_loss(
                pred.view(-1, pred.shape[-1]),
                target.contiguous().view(-1))

        model.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        loss_avg.add(total_loss)
        # Tighten the clipping threshold once the running loss gets small.
        if loss_avg.val() <= 0.6:
            opt.grad_clip = 2
        if loss_avg.val() <= 0.3:
            opt.grad_clip = 1

        # Drop the references to this step's GPU tensors so that
        # empty_cache() can actually return their memory.
        del preds, refiner, image
        torch.cuda.empty_cache()

        writer.add_scalar('train_loss', loss_avg.val(), iteration)
        pbar.set_description('Iteration {0}/{1}, AvgLoss {2}'.format(
            iteration, opt.num_iter, loss_avg.val()))

        # ---- validation part ----
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, attn_loss, valid_loader, attn_converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                # Pass the step so the validation curve is not collapsed
                # onto a single x position.
                writer.add_scalar('Val_loss', valid_loss, iteration)
                pbar.set_description(loss_log)
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy_{str(best_accuracy)}.pth'
                    )
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                # print(loss_model_log)

                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    # Validation always decodes with the attention converter,
                    # so cut at the end-of-sentence token.  partition() leaves
                    # the string intact when '[s]' is missing, whereas slicing
                    # with find() == -1 would drop the last character.
                    gt = gt.partition('[s]')[0]
                    pred = pred.partition('[s]')[0]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                log.write(predicted_result_log + '\n')

        # save model per 1e+3 iter.
        if (iteration + 1) % 1e+3 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.exp_name}/SCATTER_STR.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
示例#22
0
    # NOTE(review): this is the interior of an evaluation routine; the
    # enclosing `def` and the end of the loop body are not part of this excerpt.
    set_gpu(args.gpu)

    # Episodic loader over the MiniImageNet 'test' split.  The sampler is built
    # from the dataset labels, args.batch, args.way and (shot + query).
    dataset = MiniImageNet('test')
    sampler = CategoriesSampler(dataset.label, args.batch, args.way,
                                args.shot + args.query)
    loader = DataLoader(dataset,
                        batch_sampler=sampler,
                        num_workers=8,
                        pin_memory=True)

    # Restore the weights stored at args.load and switch to eval mode.
    model = Convnet().cuda()
    model.load_state_dict(torch.load(args.load))
    model.eval()

    ave_acc = Averager()
    for i, batch in enumerate(loader, 1):
        # data, _ = [_.cuda() for _ in batch]
        # Every element of the batch is moved to the GPU; the fourth one is ignored.
        data, mask, bg, _ = [_.cuda() for _ in batch]
        # The first shot * test_way items form the support set, the rest the queries.
        # NOTE(review): the sampler above uses args.way while this split uses
        # args.test_way — confirm both are meant to be equal.
        p = args.shot * args.test_way
        data_spt, data_query = data[:p], data[p:]
        mask_spt = mask[:p]
        # Binarise the support masks in place: every non-zero entry becomes 1.
        mask_spt[mask_spt != 0] = 1

        N, c, w, h = data.size()
        # Tile the support images and masks N times along a new leading axis.
        data_spt_repeat = data_spt.unsqueeze(0).repeat(N, 1, 1, 1, 1)
        mask_spt_repeat = mask_spt.unsqueeze(0).repeat(N, 1, 1, 1, 1)
        # Tile bg p times along a new leading axis, then swap the first two axes.
        bg_repeat = bg.unsqueeze(0).repeat(p, 1, 1, 1,
                                           1).permute(1, 0, 2, 3, 4)
        # Composite: support pixels where the mask is 1, bg pixels where it is 0.
        merge_spt = data_spt_repeat * mask_spt_repeat.float() + bg_repeat * (
            1 - mask_spt_repeat.float())
示例#23
0
    # NOTE(review): interior of a training routine; the enclosing `def` and the
    # rest of the loop body are not part of this excerpt.
    # Training log: the run arguments plus per-epoch loss/accuracy histories.
    trlog = {}
    trlog['args'] = vars(args)
    trlog['train_loss'] = []
    trlog['val_loss'] = []
    trlog['train_acc'] = []
    trlog['val_acc'] = []
    trlog['max_acc'] = 0.0

    timer = Timer()

    for epoch in range(1, max_epoch + 1):
        # NOTE(review): the scheduler is stepped at the start of the epoch,
        # i.e. before any optimizer step of that epoch — confirm this is intended.
        lr_scheduler.step()

        model.train()

        # Fresh running averages for this epoch (presumably loss and accuracy).
        tl = Averager()
        ta = Averager()

        for i, batch in enumerate(train_loader, 1):
            data, _ = [_ for _ in batch]
            # The first shot * train_way items are the support ("shot") samples.
            p = shot * train_way
            data_shot, data_query = data[:p], data[p:]

            # Class prototypes: embeddings reshaped to (shot, train_way, -1)
            # and averaged over the shot axis.
            proto = model(data_shot)
            proto = proto.reshape(shot, train_way, -1).mean(dim=0)

            # NOTE(review): the repeat count is hard-coded to 2; presumably it
            # has to match the number of query samples per class — confirm.
            label = torch.arange(train_way).repeat(2)
            label = label.type(torch.LongTensor)

            # Cross-entropy over the euclidean_metric scores between query
            # embeddings and prototypes.
            logits = euclidean_metric(model(data_query), proto)
            loss = F.cross_entropy(logits, label)
示例#24
0
    # NOTE(review): interior of a training routine; the enclosing `def` and the
    # rest of the loop body are not part of this excerpt.
    distance = distance_metric(args.distance, model)

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    #lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    # Cosine annealing down to 1e-5, wrapped in a warm-up scheduler.  Both
    # horizons are multiples of len(train_loader), so presumably the scheduler
    # is stepped once per batch rather than once per epoch — confirm.
    steplr_after = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=len(train_loader) * args.epochs, eta_min=1e-5)
    lr_scheduler = WarmupScheduler(optimizer,
                                   multiplier=1,
                                   total_epoch=args.warmup_epochs *
                                   len(train_loader),
                                   after_scheduler=steplr_after)

    best_acc = 0
    # Four independent running averages, created once before the epoch loop.
    train_loss, train_acc, valid_loss, valid_acc = [
        Averager() for i in range(4)
    ]

    for epoch in range(1, args.epochs + 1):

        ############## Train ##############

        model.train()

        for i, batch in enumerate(train_loader):
            # Move every element of the batch to args.device; the second is ignored.
            data, _ = [d.to(args.device) for d in batch]
            # The first shot * train_way items are the support set, the rest the queries.
            p = args.shot * args.train_way
            data_support, data_query = data[:p], data[
                p:]  #(shot*way) #(query*way)

            proto = model(data_support)
示例#25
0
    # NOTE(review): interior of a training routine; `trlog` itself is created
    # above this excerpt and the loop body continues below it.
    # Per-epoch histories plus the best accuracy seen and the epoch it occurred in.
    trlog['train_loss'] = []
    trlog['val_loss'] = []
    trlog['train_acc'] = []
    trlog['val_acc'] = []
    trlog['max_acc'] = 0.0
    trlog['max_acc_epoch'] = 0

    timer = Timer()
    # Counts training batches across all epochs.
    global_count = 0
    writer = SummaryWriter(comment=args.save_path)

    for epoch in range(1, args.max_epoch + 1):
        # NOTE(review): when enabled, the scheduler is stepped at the start of
        # the epoch, before any optimizer step of that epoch — confirm intent.
        if args.lr_decay:
            lr_scheduler.step()
        model.train()
        # Fresh running averages for this epoch (presumably loss and accuracy).
        tl = Averager()
        ta = Averager()

        # Labels 0..way-1 repeated args.query times, built once per epoch and
        # placed on the GPU when one is available.
        label = torch.arange(args.way).repeat(args.query)
        if torch.cuda.is_available():
            label = label.type(torch.cuda.LongTensor)
        else:
            label = label.type(torch.LongTensor)

        for i, batch in enumerate(train_loader, 1):
            global_count = global_count + 1
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data = batch[0]
            # Number of support samples: shot * way.
            p = args.shot * args.way
示例#26
0
def train():
    """Train the AdvancedEAST model and validate it periodically.

    Everything is configured through the module-level ``cfg`` object and the
    module-level ``device``.  One pass of the outer loop is one epoch over the
    training loader.  Training precision/recall/F1 are collected on epoch
    index 5 and on every 10th epoch; validation runs every
    ``cfg.valInterval`` epochs and stores the model with the best mean F1.
    The process ends with ``sys.exit()``.

    Notes:
        * Checkpoints are written below ``./saved_models/<task id>/`` and
          ``./saved_model/``; both directories are expected to exist.
        * For task ids not ending in ``256`` the weights of the previous
          resolution stage are loaded; that checkpoint must exist.
    """
    # ---- dataset preparation ----
    train_dataset_lmdb = LmdbDataset(cfg.lmdb_trainset_dir_name)
    val_dataset_lmdb = LmdbDataset(cfg.lmdb_valset_dir_name)

    train_loader = torch.utils.data.DataLoader(
        train_dataset_lmdb, batch_size=cfg.batch_size,
        collate_fn=data_collate,
        shuffle=True,
        num_workers=int(cfg.workers),
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        val_dataset_lmdb, batch_size=cfg.batch_size,
        collate_fn=data_collate,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(cfg.workers),
        pin_memory=True)

    # -------------------- training process --------------------
    model = advancedEAST()
    if int(cfg.train_task_id[-3:]) != 256:
        # Warm-start from the best checkpoint of the previous resolution stage.
        id_num = cfg.train_task_id[-3:]
        idx_dic = {'384': 256, '512': 384, '640': 512, '736': 640}
        model.load_state_dict(torch.load('./saved_model/3T{}_best_loss.pth'.format(idx_dic[id_num])))
    elif os.path.exists('./saved_model/3T{}_best_loss.pth'.format(cfg.train_task_id)):
        model.load_state_dict(torch.load('./saved_model/3T{}_best_loss.pth'.format(cfg.train_task_id)))

    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.decay)
    loss_func = quad_loss

    train_Loss_list = []
    val_Loss_list = []

    # ---- start training ----
    start_iter = 0
    if cfg.saved_model != '':
        try:
            start_iter = int(cfg.saved_model.split('_')[-1].split('.')[0])
            print('continue to train, start_iter: {}'.format(start_iter))
        except Exception as e:
            # Best effort: an unparsable file name just means "start at 0".
            print(e)

    start_time = time.time()
    best_mF1_score = 0
    i = start_iter
    step_num = 0
    loss_avg = Averager()
    val_loss_avg = Averager()
    eval_p_r_f = eval_pre_rec_f1()

    while True:
        model.train()
        # Training metrics are only collected on selected epochs.
        collect_train_metrics = i == 5 or (i + 1) % 10 == 0

        # ---- train part ----
        for image_tensors, labels, gt_xy_list in train_loader:
            step_num += 1
            batch_x = image_tensors.to(device).float()
            batch_y = labels.to(device).float()  # float64 -> float32

            out = model(batch_x)
            loss = loss_func(batch_y, out)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_avg.add(loss)
            train_Loss_list.append(loss_avg.val())
            if collect_train_metrics:
                eval_p_r_f.add(out, gt_xy_list)  # very time-consuming!

        # save model per 100 epochs.
        if (i + 1) % 1e+2 == 0:
            torch.save(model.state_dict(), './saved_models/{}/{}_iter_{}.pth'.format(cfg.train_task_id, cfg.train_task_id, step_num+1))

        print('Epoch:[{}/{}] Training Loss: {:.3f}'.format(i + 1, cfg.epoch_num, train_Loss_list[-1].item()))
        loss_avg.reset()

        if collect_train_metrics:
            mPre, mRec, mF1_score = eval_p_r_f.val()
            print('Training meanPrecision:{:.2f}% meanRecall:{:.2f}% meanF1-score:{:.2f}%'.format(mPre, mRec, mF1_score))
            eval_p_r_f.reset()

        # ---- evaluation ----
        if (i + 1) % cfg.valInterval == 0:
            elapsed_time = time.time() - start_time
            print('Elapsed time:{}s'.format(round(elapsed_time)))
            model.eval()
            # No gradients are needed for validation; skipping the autograd
            # graph saves memory and time.
            with torch.no_grad():
                for image_tensors, labels, gt_xy_list in valid_loader:
                    # Same float32 cast as in the training branch.
                    batch_x = image_tensors.to(device).float()
                    batch_y = labels.to(device).float()  # float64 -> float32

                    out = model(batch_x)
                    loss = loss_func(batch_y, out)

                    val_loss_avg.add(loss)
                    val_Loss_list.append(val_loss_avg.val())
                    eval_p_r_f.add(out, gt_xy_list)

            mPre, mRec, mF1_score = eval_p_r_f.val()
            print('validation meanPrecision:{:.2f}% meanRecall:{:.2f}% meanF1-score:{:.2f}%'.format(mPre, mRec, mF1_score))
            eval_p_r_f.reset()

            if mF1_score > best_mF1_score:  # record the best model
                best_mF1_score = mF1_score
                torch.save(model.state_dict(), './saved_models/{}/{}_best_mF1_score_{:.3f}.pth'.format(cfg.train_task_id, cfg.train_task_id, mF1_score))
                torch.save(model.state_dict(), './saved_model/{}_best_mF1_score.pth'.format(cfg.train_task_id))

            print('Validation loss:{:.3f}'.format(val_loss_avg.val().item()))
            val_loss_avg.reset()

        # `>=` rather than `==`: when resuming with start_iter beyond
        # cfg.epoch_num an equality test would never fire and loop forever.
        if i >= cfg.epoch_num:
            torch.save(model.state_dict(), './saved_models/{}/{}_iter_{}.pth'.format(cfg.train_task_id, cfg.train_task_id, i+1))
            print('End the training')
            break
        i += 1

    sys.exit()
def train(opt, tb):
    """Train a text-recognition model with either a CTC or an attention head.

    Builds the balanced training set and the validation loader, initialises
    (and optionally restores) the model, then optimises until the step counter
    reaches ``opt.num_iter``.  Every ``opt.valInterval`` steps it validates,
    records scalars on ``tb``, appends to ``log_train.txt`` and keeps the best
    checkpoints.  The process terminates through ``sys.exit()``.

    Args:
        opt: option namespace.  It is mutated here: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` is set from
            the converter and ``input_channel`` is forced to 3 for RGB input.
        tb: object exposing ``add_scalar(tag, value, step)`` (e.g. a
            TensorBoard summary writer).

    Notes:
        * Relies on the module-level ``device``.
        * Files are written below ``./saved_models/{opt.experiment_name}/``;
          that directory is expected to exist already.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager guarantees the dataset log is closed even when the
    # validation dataset cannot be built.
    with open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    use_ctc = 'CTC' in opt.Prediction
    if use_ctc:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm: kaiming init rejects these weights.
            if 'weight' in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # ---- setup loss ----
    if use_ctc:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    print("~~~~~~~~~~~~Gradient Descent~~~~~~~~~~~~~")
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Filtered parameters for gradient descent: \n', len(filtered_parameters))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        for k, v in vars(opt).items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        try:
            # Checkpoints named "..._<iter>.pth" carry the step to resume from.
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # The file name does not end in an integer: start from step 0.
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    while True:
        # ---- train part ----
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if use_ctc:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)

            # (ctc_a) For PyTorch 1.2.0 and 1.3.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            try:
                cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
            finally:
                # Re-enable cuDNN even if the loss computation raises.
                torch.backends.cudnn.enabled = True

            # # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a).
            # # With PyTorch 1.2.0, the below code occurs NAN, so you may use PyTorch 1.1.0.
            # # Thus, the result of CTCLoss is different in PyTorch 1.1.0 and PyTorch 1.2.0.
            # # See https://github.com/clovaai/deep-text-recognition-benchmark/issues/56#issuecomment-526490707
            # cost = criterion(preds, text, preds_size, length)

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # ---- validation part ----
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                tb.add_scalar('Training Loss vs Iteration', loss_avg.val(), i)              # Record to Tensorboard
                tb.add_scalar('Validation Loss vs Iteration', valid_loss, i)                # Record to Tensorboard
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'
                tb.add_scalar('Current Accuracy vs Iteration', current_accuracy, i)         # Record to Tensorboard
                tb.add_scalar('Current Norm ED vs Iteration', current_norm_ED, i)           # Record to Tensorboard

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        # Cut at the end-of-sentence token.  partition() keeps
                        # the string intact when '[s]' is missing, whereas
                        # slicing with find() == -1 would drop the last char.
                        gt = gt.partition('[s]')[0]
                        pred = pred.partition('[s]')[0]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 1e+5 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        # `>=` rather than `==`: when resuming with start_iter beyond
        # opt.num_iter an equality test would never fire and loop forever.
        if i >= opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
def _cycle_batches(loader):
    """Yield batches from ``loader`` forever, restarting it after each epoch.

    Keeping a single generator alive avoids rebuilding the loader's iterator
    for every batch that the training loop draws.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            # An epoch without batches would otherwise spin here forever.
            raise ValueError('data loader did not produce any batch')


def _checkpoint_iteration(path):
    """Return the iteration encoded in an ``iter_<n>_synth.pth`` path, or None."""
    try:
        return int(path.split('_')[-2].split('.')[0])
    except (IndexError, ValueError):
        return None


def _content_styles(noise_styles, content_code, opt):
    """Multiply (at most two) style noise vectors with a content code.

    When ``opt.zAlone`` is set only the first ``opt.latent`` columns of each
    product are kept.
    """
    styles = [noise * content_code for noise in noise_styles[:2]]
    if opt.zAlone:
        # to validate orig style gan results
        styles = [style[:, :opt.latent] for style in styles]
    return styles


def train(opt):
    """Train the text-conditioned generator, its discriminator and the OCR model.

    Each iteration runs, in order: a discriminator update, an R1 penalty every
    ``opt.d_reg_every`` steps, an OCR update on real images (skipped when
    ``opt.ocrFixed`` or ``opt.zAlone``), a generator / content-encoder update
    using the adversarial loss plus the CTC loss of the OCR model on generated
    images, a path-length penalty every ``opt.g_reg_every`` steps and an EMA
    update of ``g_ema``. Every ``opt.valInterval`` iterations sample images,
    plots and averaged losses are written out, and every 1e4 iterations a
    checkpoint with all models and optimizers is saved under
    ``opt.exp_dir/opt.exp_name``.

    Args:
        opt: option namespace. Besides being read, it is modified in place:
            ``opt.classes`` and ``opt.num_class`` are set from the label
            converter, and ``opt.saved_synth_model`` is replaced by the newest
            ``iter_*_synth.pth`` checkpoint when ``opt.modelFolderFlag`` is set.

    Returns:
        Nothing. The function does not return normally: it ends the process
        with ``sys.exit()`` once ``opt.num_iter`` iterations are done, or when
        ``opt.Prediction`` is not a CTC variant.
    """
    lib.print_model_settings(locals().copy())

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
        text_len = opt.batch_max_length+2
    else:
        converter = CTCLabelConverter(opt.character)
        text_len = opt.batch_max_length

    opt.classes = converter.character

    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)

    train_dataset = LmdbStyleDataset(root=opt.train_data, opt=opt)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size*2, #*2 to sample different images from training encoder and discriminator real images
        shuffle=True,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)

    print('-' * 80)

    # NOTE(review): valid_loader is built but never read further down.
    valid_dataset = LmdbStyleDataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size*2, #*2 to sample different images from training encoder and discriminator real images
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)

    print('-' * 80)
    with open(os.path.join(opt.exp_dir,opt.exp_name,'log_dataset.txt'), 'a') as log:
        log.write('-' * 80 + '\n')

    text_dataset = text_gen(opt)
    text_loader = torch.utils.data.DataLoader(
        text_dataset, batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
        pin_memory=True, drop_last=True)
    opt.num_class = len(converter.character)

    # One long-lived stream per loader instead of a fresh iterator per batch.
    train_batches = _cycle_batches(train_loader)
    text_batches = _cycle_batches(text_loader)

    # ---- models ----
    c_code_size = opt.latent
    cEncoder = GlobalContentEncoder(opt.num_class, text_len, opt.char_embed_size, c_code_size)
    ocrModel = ModelV1(opt)

    genModel = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier)
    g_ema = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier)

    disEncModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier, input_dim=opt.input_channel, code_s_dim=c_code_size)

    # Decay 0: presumably makes g_ema start as a copy of genModel -- confirm in accumulate().
    accumulate(g_ema, genModel, 0)

    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        print('Not implemented error')
        sys.exit()

    cEncoder= torch.nn.DataParallel(cEncoder).to(device)
    cEncoder.train()
    genModel = torch.nn.DataParallel(genModel).to(device)
    g_ema = torch.nn.DataParallel(g_ema).to(device)
    genModel.train()
    g_ema.eval()

    disEncModel = torch.nn.DataParallel(disEncModel).to(device)
    disEncModel.train()

    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    if opt.ocrFixed:
        # Frozen recogniser: put every sub-module except SequenceModeling in eval mode.
        if opt.Transformation == 'TPS':
            ocrModel.module.Transformation.eval()
        ocrModel.module.FeatureExtraction.eval()
        ocrModel.module.AdaptiveAvgPool.eval()
        ocrModel.module.Prediction.eval()
    else:
        ocrModel.train()

    # ---- optimizers ----
    # Learning rate and betas are rescaled by reg_every / (reg_every + 1)
    # for the generator and discriminator optimizers.
    g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1)
    d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1)

    optimizer = optim.Adam(
        list(genModel.parameters())+list(cEncoder.parameters()),
        lr=opt.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    dis_optimizer = optim.Adam(
        disEncModel.parameters(),
        lr=opt.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )

    ocr_optimizer = optim.Adam(
        ocrModel.parameters(),
        lr=opt.lr,
        betas=(0.9, 0.99),
    )

    # ---- loading pre-trained files ----
    if opt.modelFolderFlag:
        candidates = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth"))
        if candidates:
            # glob returns files in arbitrary order, so select the checkpoint
            # with the highest iteration number explicitly.
            def _iteration_or_minus_one(path):
                found = _checkpoint_iteration(path)
                return -1 if found is None else found

            opt.saved_synth_model = max(candidates, key=_iteration_or_minus_one)

    if opt.saved_ocr_model !='' and opt.saved_ocr_model !='None':
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        checkpoint = torch.load(opt.saved_ocr_model)
        ocrModel.load_state_dict(checkpoint)

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model)

        # The checkpoints written below contain the content encoder; restore it
        # as well (older files without the key are still accepted).
        if 'cEncoder' in checkpoint:
            cEncoder.load_state_dict(checkpoint['cEncoder'])
        genModel.load_state_dict(checkpoint['genModel'])
        g_ema.load_state_dict(checkpoint['g_ema'])
        disEncModel.load_state_dict(checkpoint['disEncModel'])
        ocrModel.load_state_dict(checkpoint['ocrModel'])

        optimizer.load_state_dict(checkpoint["optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])
        ocr_optimizer.load_state_dict(checkpoint["ocr_optimizer"])

    # loss averager
    loss_avg_dis = Averager()
    loss_avg_gen = Averager()
    log_r1_val = Averager()
    log_avg_path_loss_val = Averager()
    log_avg_mean_path_length_avg = Averager()
    log_ada_aug_p = Averager()
    loss_avg_ocr_sup = Averager()
    loss_avg_ocr_unsup = Averager()

    # ---- final options ----
    with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        opt_log += ''.join(f'{str(k)}: {str(v)}\n' for k, v in vars(opt).items())
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        resumed_iter = _checkpoint_iteration(opt.saved_synth_model)
        if resumed_iter is not None:
            start_iter = resumed_iter
            print(f'continue to train, start_iter: {start_iter}')

    #get schedulers
    scheduler = get_scheduler(optimizer,opt)
    dis_scheduler = get_scheduler(dis_optimizer,opt)
    ocr_scheduler = get_scheduler(ocr_optimizer,opt)

    start_time = time.time()
    iteration = start_iter
    cntr=0

    mean_path_length = 0
    r1_loss = torch.tensor(0.0, device=device)
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0

    accum = 0.5 ** (32 / (10 * 1000))
    ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0
    r_t_stat = 0
    # NOTE(review): 1e-49 is below float32 resolution, so this offset vanishes
    # when added to a float32 tensor -- confirm whether a larger value is wanted.
    epsilon = 10e-50

    while True:
        # train part
        # NOTE(review): the schedulers are stepped before the optimizers of the
        # same iteration; left as is to keep the learning-rate schedule unchanged.
        if opt.lr_policy !="None":
            scheduler.step()
            dis_scheduler.step()
            ocr_scheduler.step()

        image_input_tensors, _, labels, _ = next(train_batches)
        labels_z_c = next(text_batches)

        # The loader batch holds 2 * batch_size images: the first half is used
        # (with its labels) for the OCR update, the second half as real images
        # for the discriminator.
        image_input_tensors = image_input_tensors.to(device)
        gt_image_tensors = image_input_tensors[:opt.batch_size].detach()
        real_image_tensors = image_input_tensors[opt.batch_size:].detach()

        labels_gt = labels[:opt.batch_size]

        # ---- discriminator update ----
        requires_grad(cEncoder, False)
        requires_grad(genModel, False)
        requires_grad(disEncModel, True)
        requires_grad(ocrModel, False)

        text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length)
        text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length)

        z_c_code = cEncoder(text_z_c)
        noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
        style = _content_styles(noise_style, z_c_code, opt)

        fake_img,_ = genModel(style, input_is_latent=opt.input_latent)

        #Domain discriminator
        fake_pred = disEncModel(fake_img)
        real_pred = disEncModel(real_image_tensors)
        disCost = d_logistic_loss(real_pred, fake_pred)
        # Remembered only for the optional wandb report further down.
        dis_real_pred, dis_fake_pred = real_pred.detach(), fake_pred.detach()

        loss_avg_dis.add(disCost)

        disEncModel.zero_grad()
        disCost.backward()
        dis_optimizer.step()

        d_regularize = cntr % opt.d_reg_every == 0

        if d_regularize:
            real_image_tensors.requires_grad = True
            real_pred = disEncModel(real_image_tensors)

            r1_loss = d_r1_loss(real_pred, real_image_tensors)

            disEncModel.zero_grad()
            # The zero-weighted real_pred term adds no value to the loss; it only
            # keeps the discriminator output in the graph that is back-propagated.
            (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward()

            dis_optimizer.step()
        log_r1_val.add(r1_loss)

        # ---- recognizer update ----
        if not opt.ocrFixed and not opt.zAlone:
            requires_grad(disEncModel, False)
            requires_grad(ocrModel, True)

            if 'CTC' in opt.Prediction:
                preds_recon = ocrModel(gt_image_tensors, text_gt, is_train=True)
                preds_size = torch.IntTensor([preds_recon.size(1)] * opt.batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_gt, preds_size, length_gt)
            else:
                print("Not implemented error")
                sys.exit()

            ocrModel.zero_grad()
            ocrCost.backward()
            ocr_optimizer.step()
            loss_avg_ocr_sup.add(ocrCost)
        else:
            loss_avg_ocr_sup.add(torch.tensor(0.0))

        # ---- [Word Generator] update ----
        labels_z_c = next(text_batches)

        requires_grad(cEncoder, True)
        requires_grad(genModel, True)
        requires_grad(disEncModel, False)
        requires_grad(ocrModel, False)

        text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length)

        z_c_code = cEncoder(text_z_c)
        noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
        style = _content_styles(noise_style, z_c_code, opt)

        fake_img,_ = genModel(style, input_is_latent=opt.input_latent)

        fake_pred = disEncModel(fake_img)
        disGenCost = g_nonsaturating_loss(fake_pred)

        if opt.zAlone:
            ocrCost = torch.tensor(0.0)
        else:
            #Compute OCR prediction (Reconstruction of content)
            if 'CTC' in opt.Prediction:
                preds_recon = ocrModel(fake_img, text_z_c, is_train=False)
                preds_size = torch.IntTensor([preds_recon.size(1)] * opt.batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_z_c, preds_size, length_z_c)
            else:
                print("Not implemented error")
                sys.exit()

        genModel.zero_grad()
        cEncoder.zero_grad()

        gen_enc_cost = disGenCost + opt.ocrWeight * ocrCost

        # With zAlone the OCR cost is a constant without a graph, so there is
        # no OCR gradient to balance against; use the plain backward pass.
        if opt.grad_balance and not opt.zAlone:
            gen_enc_cost.backward(retain_graph=True)
            grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, create_graph=True, retain_graph=True)[0]
            grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, create_graph=True, retain_graph=True)[0]
            # Rescale the OCR loss by the ratio of the standard deviations of
            # the adversarial and OCR gradients w.r.t. the generated image.
            a = opt.ocrWeight * torch.div(torch.std(grad_fake_adv), epsilon+torch.std(grad_fake_OCR))
            if a>1000 or a<0.0001:
                print(a)

            ocrCost = a.detach() * ocrCost
            gen_enc_cost = disGenCost + ocrCost
            # NOTE(review): together with the call above, backward() runs three
            # times here and the gradients accumulate; kept unchanged on purpose.
            gen_enc_cost.backward(retain_graph=True)
            with torch.no_grad():
                gen_enc_cost.backward()
        else:
            gen_enc_cost.backward()

        loss_avg_gen.add(disGenCost)
        loss_avg_ocr_unsup.add(opt.ocrWeight * ocrCost)

        optimizer.step()

        g_regularize = cntr % opt.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, opt.batch_size // opt.path_batch_shrink)
            labels_z_c = next(text_batches)

            text_z_c, length_z_c = converter.encode(labels_z_c[:path_batch_size], batch_max_length=opt.batch_max_length)

            z_c_code = cEncoder(text_z_c)
            noise_style = mixing_noise_style(path_batch_size, opt.latent, opt.mixing, device)
            style = _content_styles(noise_style, z_c_code, opt)

            fake_img, grad = genModel(style, return_latents=True, g_path_regularize=True, mean_path_length=mean_path_length)

            decay = 0.01
            path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

            # Running mean of the path length, updated with weight `decay`.
            mean_path_length_orig = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
            path_loss = (path_lengths - mean_path_length_orig).pow(2).mean()
            mean_path_length = mean_path_length_orig.detach().item()

            genModel.zero_grad()
            cEncoder.zero_grad()
            weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss

            if opt.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            optimizer.step()

            # multi-gpu, non-distributed setting: no reduction across ranks
            mean_path_length_avg = mean_path_length

        accumulate(g_ema, genModel, accum)

        log_avg_path_loss_val.add(path_loss)
        log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg))
        log_ada_aug_p.add(torch.tensor(ada_aug_p))

        if get_rank() == 0:
            if wandb and opt.wandb:
                # Values are read from the tensors of this iteration; .item()
                # is only called here so that runs without wandb stay unaffected.
                wandb.log(
                    {
                        "Generator": disGenCost.item(),
                        "Discriminator": disCost.item(),
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1": r1_loss.item(),
                        "Path Length Regularization": path_loss.item(),
                        "Mean Path Length": mean_path_length,
                        "Real Score": dis_real_pred.mean().item(),
                        "Fake Score": dis_fake_pred.mean().item(),
                        "Path Length": path_lengths.mean().item(),
                    }
                )

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0'

            #generate paired content with similar style
            labels_z_c_1 = next(text_batches)
            labels_z_c_2 = next(text_batches)

            text_z_c_1, length_z_c_1 = converter.encode(labels_z_c_1, batch_max_length=opt.batch_max_length)
            text_z_c_2, length_z_c_2 = converter.encode(labels_z_c_2, batch_max_length=opt.batch_max_length)

            # Nothing below is back-propagated, so skip graph construction.
            with torch.no_grad():
                z_c_code_1 = cEncoder(text_z_c_1)
                z_c_code_2 = cEncoder(text_z_c_2)

                # c1/c2: two contents, s1/s2: two style noise draws.
                style_s1 = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
                style_c1_s1 = _content_styles(style_s1, z_c_code_1, opt)
                style_c2_s1 = _content_styles(style_s1, z_c_code_2, opt)

                noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
                style_c1_s2 = _content_styles(noise_style, z_c_code_1, opt)

                if opt.zAlone:
                    #to validate orig style gan results: all three share one style
                    style_c2_s1 = style_c1_s1
                    style_c1_s2 = style_c1_s1

                fake_img_c1_s1, _ = g_ema(style_c1_s1, input_is_latent=opt.input_latent)
                fake_img_c2_s1, _ = g_ema(style_c2_s1, input_is_latent=opt.input_latent)
                fake_img_c1_s2, _ = g_ema(style_c1_s2, input_is_latent=opt.input_latent)

                if not opt.zAlone:
                    #Run OCR prediction
                    if 'CTC' in opt.Prediction:
                        preds = ocrModel(fake_img_c1_s1, text_z_c_1, is_train=False)
                        preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size)
                        _, preds_index = preds.max(2)
                        preds_str_fake_img_c1_s1 = converter.decode(preds_index.data, preds_size.data)

                        preds = ocrModel(fake_img_c2_s1, text_z_c_2, is_train=False)
                        preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size)
                        _, preds_index = preds.max(2)
                        preds_str_fake_img_c2_s1 = converter.decode(preds_index.data, preds_size.data)

                        preds = ocrModel(fake_img_c1_s2, text_z_c_1, is_train=False)
                        preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size)
                        _, preds_index = preds.max(2)
                        preds_str_fake_img_c1_s2 = converter.decode(preds_index.data, preds_size.data)

                        preds = ocrModel(gt_image_tensors, text_gt, is_train=False)
                        preds_size = torch.IntTensor([preds.size(1)] * gt_image_tensors.shape[0])
                        _, preds_index = preds.max(2)
                        preds_str_gt = converter.decode(preds_index.data, preds_size.data)

                    else:
                        print("Not implemented error")
                        sys.exit()
                else:
                    preds_str_fake_img_c1_s1 = [':None:'] * fake_img_c1_s1.shape[0]
                    preds_str_gt = [':None:'] * fake_img_c1_s1.shape[0]

            sample_dir = os.path.join(opt.trainDir,str(iteration))
            os.makedirs(sample_dir, exist_ok=True)
            for trImgCntr in range(opt.batch_size):
                # Best effort: labels and predictions end up in the file name,
                # so saving a single image may fail without aborting training.
                try:
                    save_image(tensor2im(fake_img_c1_s1[trImgCntr].detach()),os.path.join(sample_dir,str(trImgCntr)+'_c1_s1_'+labels_z_c_1[trImgCntr]+'_ocr:'+preds_str_fake_img_c1_s1[trImgCntr]+'.png'))
                    if not opt.zAlone:
                        save_image(tensor2im(fake_img_c2_s1[trImgCntr].detach()),os.path.join(sample_dir,str(trImgCntr)+'_c2_s1_'+labels_z_c_2[trImgCntr]+'_ocr:'+preds_str_fake_img_c2_s1[trImgCntr]+'.png'))
                        save_image(tensor2im(fake_img_c1_s2[trImgCntr].detach()),os.path.join(sample_dir,str(trImgCntr)+'_c1_s2_'+labels_z_c_1[trImgCntr]+'_ocr:'+preds_str_fake_img_c1_s2[trImgCntr]+'.png'))
                        if trImgCntr<gt_image_tensors.shape[0]:
                            save_image(tensor2im(gt_image_tensors[trImgCntr].detach()),os.path.join(sample_dir,str(trImgCntr)+'_gt_act:'+labels_gt[trImgCntr]+'_ocr:'+preds_str_gt[trImgCntr]+'.png'))
                except Exception as save_error:
                    print('Warning while saving training image', save_error)

            elapsed_time = time.time() - start_time
            # for log

            with open(os.path.join(opt.exp_dir,opt.exp_name,'log_train.txt'), 'a') as log:

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}]  \
                    Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\
                    Train UnSup OCR loss: {loss_avg_ocr_unsup.val():0.5f}, Train Sup OCR loss: {loss_avg_ocr_sup.val():0.5f}, \
                    Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \
                    Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \
                    Elapsed_time: {elapsed_time:0.5f}'

                #plotting
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Dis-Loss'), loss_avg_dis.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Gen-Loss'), loss_avg_gen.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-UnSup-OCR-Loss'), loss_avg_ocr_unsup.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Sup-OCR-Loss'), loss_avg_ocr_sup.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-r1_val'), log_r1_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-path_loss_val'), log_avg_path_loss_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-ada_aug_p'), log_ada_aug_p.val().item())

                print(loss_log)
                # The file was opened for exactly this record.
                log.write(loss_log + '\n')

                loss_avg_dis.reset()
                loss_avg_gen.reset()
                loss_avg_ocr_unsup.reset()
                loss_avg_ocr_sup.reset()
                log_r1_val.reset()
                log_avg_path_loss_val.reset()
                log_avg_mean_path_length_avg.reset()
                log_ada_aug_p.reset()

            lib.plot.flush()

        lib.plot.tick()

        # save a resumable checkpoint whenever `iteration` is a multiple of 1e+4;
        # the file name carries iteration + 1, which is read back as start_iter.
        if (iteration) % 1e+4 == 0:
            torch.save({
                'cEncoder':cEncoder.state_dict(),
                'genModel':genModel.state_dict(),
                'g_ema':g_ema.state_dict(),
                'ocrModel':ocrModel.state_dict(),
                'disEncModel':disEncModel.state_dict(),
                'optimizer':optimizer.state_dict(),
                'ocr_optimizer':ocr_optimizer.state_dict(),
                'dis_optimizer':dis_optimizer.state_dict()},
                os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_synth.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
        cntr+=1
    # NOTE(review): these statements look like the body of a separate few-shot
    # evaluation script whose enclosing header is not present here. At this
    # indentation they parse as trailing statements of the function above,
    # placed after a `while` loop that has no `break`.

    # Read the command-line options and echo them.
    args = parser.parse_args()
    pprint(vars(args))

    set_gpu(args.gpu)

    # Episode loader over the MiniImageNet 'test' split; each batch is defined
    # by the sampler (args.batch episodes -- presumably args.way classes with
    # args.folds * args.shot + args.query images each; confirm in CategoriesSampler).
    dataset = MiniImageNet('test')
    sampler = CategoriesSampler(dataset.label,
                                args.batch, args.way, args.folds * args.shot + args.query)
    loader = DataLoader(dataset, batch_sampler=sampler,
                        num_workers=8, pin_memory=True)

    # Restore the trained weights and switch to inference mode.
    model = Convnet().cuda()
    model.load_state_dict(torch.load(args.load))
    model.eval()

    ave_acc = Averager()
    # One-hot encoding of the support labels 0..train_way-1, repeated args.shot times.
    # NOTE(review): the width 20 is hard-coded; scatter_ needs every label to be
    # below 20, i.e. args.train_way <= 20 -- confirm.
    s_label = torch.arange(args.train_way).repeat(args.shot).view(args.shot * args.train_way)
    s_onehot = torch.zeros(s_label.size(0), 20)
    s_onehot = s_onehot.scatter_(1, s_label.unsqueeze(dim=1), 1).cuda()

    for i, batch in enumerate(loader, 1):
        # Move both batch elements to the GPU; the second one is discarded.
        data, _ = [_.cuda() for _ in batch]
        # The first k samples are the shots, the next k the meta-support set,
        # everything after that the queries.
        k = args.way * args.shot
        data_shot, meta_support, data_query = data[:k], data[k:2*k], data[2*k:]

        #p = inter_fold(model, args, data_shot)

        # Embed the shots and average over the shot dimension: one vector per class.
        x = model(data_shot)
        x = x.reshape(args.shot, args.way, -1).mean(dim=0)
        p = x
def _vggfont_checkpoint_iteration(path):
    """Parse the iteration number out of an ``iter_<N>_vggFont.pth`` path.

    Returns the integer found between the last two underscores of ``path``,
    or ``None`` when the path does not follow that naming scheme.
    """
    try:
        return int(path.split('_')[-2].split('.')[0])
    except (ValueError, IndexError):
        return None


def _latest_vggfont_checkpoint(paths):
    """Return the entry of ``paths`` with the highest parsable iteration."""

    def sort_key(path):
        parsed = _vggfont_checkpoint_iteration(path)
        # Names that cannot be parsed sort before every real iteration.
        return -1 if parsed is None else parsed

    return max(paths, key=sort_key)


def train(opt):
    """Train (or, with ``opt.testFlag``, only validate) the VGG font classifier.

    Builds the train/validation ``fontDataset`` loaders, wraps a VGG19 backbone
    in ``VGGFontModel``, optionally resumes from a saved checkpoint, and then
    loops until ``opt.num_iter`` iterations have run. On rank 0 the loop
    validates every ``opt.valInterval`` iterations (and on iteration 0 or
    whenever ``opt.testFlag`` is set), appends a line to ``log_train.txt``,
    plots the metrics, and saves a checkpoint every 15000 iterations.

    Args:
        opt: option namespace; the attributes read here include the data
            paths, ``batch_size``, ``workers``, ``optim``, ``lr``,
            ``lr_policy``, ``distributed``, ``local_rank``, ``exp_dir``,
            ``exp_name``, ``plotDir``, ``saved_font_model``,
            ``modelFolderFlag``, ``valInterval``, ``num_iter``, ``testFlag``
            and ``debugFlag``. ``opt.saved_font_model`` is overwritten when
            ``opt.modelFolderFlag`` is set and a checkpoint is found.

    Raises:
        ValueError: if ``opt.optim`` is neither ``"sgd"`` nor ``"adam"``.
        SystemExit: always, via ``sys.exit()``, once ``opt.num_iter`` is
            reached; the function never returns normally.
    """
    lib.print_model_settings(locals().copy())

    AlignFontCollateObj = AlignFontCollate(imgH=opt.imgH,
                                           imgW=opt.imgW,
                                           keep_ratio_with_pad=opt.PAD)
    train_dataset = fontDataset(imgDir=opt.train_img_dir,
                                annFile=opt.train_ann_file,
                                transform=None,
                                numClasses=opt.numClasses)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        # Shuffling is delegated to the sampler below, so the loader's own
        # shuffle flag stays off.
        shuffle=False,
        sampler=data_sampler(train_dataset,
                             shuffle=True,
                             distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignFontCollateObj,
        pin_memory=True,
        drop_last=False)
    numClasses = np.unique(train_dataset.fontIdx).size

    # sample_data wraps the loader in an iterator consumed with next() below.
    train_loader = sample_data(train_loader)
    print('-' * 80)
    numTrainSamples = len(train_dataset)

    # The validation set reuses the training set's font <-> index mappings.
    valid_dataset = fontDataset(imgDir=opt.train_img_dir,
                                annFile=opt.val_ann_file,
                                transform=None,
                                F2Idx=train_dataset.F2Idx,
                                Idx2F=train_dataset.Idx2F,
                                numClasses=opt.numClasses)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        sampler=data_sampler(valid_dataset,
                             shuffle=False,
                             distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignFontCollateObj,
        pin_memory=True,
        drop_last=False)
    numTestSamples = len(valid_dataset)

    print('numClasses', numClasses)
    print('numTrainSamples', numTrainSamples)
    print('numTestSamples', numTestSamples)

    vggFontModel = VGGFontModel(models.vgg19(pretrained=opt.preTrained),
                                numClasses).to(device)
    for name, param in vggFontModel.classifier.named_parameters():
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            print('Exception in weight init' + name)
            if 'weight' in name:
                param.data.fill_(1)
            continue

    if opt.optim == "sgd":
        print('SGD optimizer')
        optimizer = optim.SGD(vggFontModel.parameters(),
                              lr=opt.lr,
                              momentum=0.9)
    elif opt.optim == "adam":
        print('Adam optimizer')
        optimizer = optim.Adam(vggFontModel.parameters(), lr=opt.lr)
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError(
            f'unsupported optimizer {opt.optim!r}; expected "sgd" or "adam"')
    # get schedulers
    scheduler = get_scheduler(optimizer, opt)

    criterion = torch.nn.CrossEntropyLoss()

    if opt.modelFolderFlag:
        saved_models = glob.glob(
            os.path.join(opt.exp_dir, opt.exp_name, "iter_*_vggFont.pth"))
        if saved_models:
            # glob returns matches in arbitrary order, so choose the
            # checkpoint with the highest iteration number explicitly.
            opt.saved_font_model = _latest_vggfont_checkpoint(saved_models)

    ## Loading pre-trained files
    if opt.saved_font_model != '' and opt.saved_font_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_font_model}')
        checkpoint = torch.load(opt.saved_font_model,
                                map_location=lambda storage, loc: storage)

        vggFontModel.load_state_dict(checkpoint['vggFontModel'])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])

    if opt.distributed:
        vggFontModel = torch.nn.parallel.DistributedDataParallel(
            vggFontModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True)
        vggFontModel.train()

    # Checkpoints store the unwrapped module's weights.
    if opt.distributed:
        vggFontModel_module = vggFontModel.module
    else:
        vggFontModel_module = vggFontModel

    # loss averager
    loss_train = Averager()
    loss_val = Averager()
    train_acc = Averager()
    val_acc = Averager()
    train_acc_5 = Averager()
    val_acc_5 = Averager()

    # final options
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'),
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # start training
    start_iter = 0

    if opt.saved_font_model != '' and opt.saved_font_model != 'None':
        # Resume the iteration counter from the checkpoint file name; a name
        # that does not follow the iter_<N>_ scheme restarts from 0.
        resumed_iter = _vggfont_checkpoint_iteration(opt.saved_font_model)
        if resumed_iter is not None:
            start_iter = resumed_iter
            print(f'continue to train, start_iter: {start_iter}')

    iteration = start_iter

    while (True):
        # train part

        start_time = time.time()
        if not opt.testFlag:

            image_input_tensors, labels_gt = next(train_loader)
            image_input_tensors = image_input_tensors.to(device)
            labels_gt = labels_gt.view(-1).to(device)
            preds = vggFontModel(image_input_tensors)

            loss = criterion(preds, labels_gt)

            vggFontModel.zero_grad()
            loss.backward()
            optimizer.step()

            # Top-5 is clamped so it stays valid with fewer than 5 classes.
            acc1, acc5 = getNumCorrect(preds,
                                       labels_gt,
                                       topk=(1, min(numClasses, 5)))
            train_acc.addScalar(acc1, preds.shape[0])
            train_acc_5.addScalar(acc5, preds.shape[0])

            loss_train.add(loss)

            if opt.lr_policy != "None":
                scheduler.step()

        # Validation, logging, plotting and checkpointing run on rank 0 only.
        if get_rank() == 0:
            if (
                    iteration + 1
            ) % opt.valInterval == 0 or iteration == 0 or opt.testFlag:  # To see training progress, we also conduct validation when 'iteration == 0'
                # validation
                vggFontModel.eval()
                print('Inside val', iteration)

                for vCntr, (image_input_tensors,
                            labels_gt) in enumerate(valid_loader):
                    # In debug mode only the first three batches are scored.
                    if opt.debugFlag and vCntr > 2:
                        break

                    with torch.no_grad():
                        image_input_tensors = image_input_tensors.to(device)
                        labels_gt = labels_gt.view(-1).to(device)

                        preds = vggFontModel(image_input_tensors)
                        loss = criterion(preds, labels_gt)
                        loss_val.add(loss)

                        acc1, acc5 = getNumCorrect(preds,
                                                   labels_gt,
                                                   topk=(1, min(numClasses,
                                                                5)))
                        val_acc.addScalar(acc1, preds.shape[0])
                        val_acc_5.addScalar(acc5, preds.shape[0])

                vggFontModel.train()
                # Covers only the current iteration plus this validation pass,
                # because start_time is reset at the top of every iteration.
                elapsed_time = time.time() - start_time

                with open(
                        os.path.join(opt.exp_dir, opt.exp_name,
                                     'log_train.txt'), 'a') as log:
                    # training loss and validation loss
                    loss_log = f'[{iteration+1}/{opt.num_iter}]  \
                        Train loss: {loss_train.val():0.5f}, Val loss: {loss_val.val():0.5f}, \
                        Train Top-1 Acc: {train_acc.val()*100:0.5f}, Train Top-5 Acc: {train_acc_5.val()*100:0.5f}, \
                        Val Top-1 Acc: {val_acc.val()*100:0.5f}, Val Top-5 Acc: {val_acc_5.val()*100:0.5f}, \
                        Elapsed_time: {elapsed_time:0.5f}'

                    # plotting
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Loss'),
                                  loss_train.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Val-Loss'),
                                  loss_val.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Top-1-Acc'),
                                  train_acc.val() * 100)
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Top-5-Acc'),
                                  train_acc_5.val() * 100)
                    lib.plot.plot(os.path.join(opt.plotDir, 'Val-Top-1-Acc'),
                                  val_acc.val() * 100)
                    lib.plot.plot(os.path.join(opt.plotDir, 'Val-Top-5-Acc'),
                                  val_acc_5.val() * 100)

                    print(loss_log)
                    log.write(loss_log + "\n")

                    # Start fresh running averages for the next interval.
                    loss_train.reset()
                    loss_val.reset()
                    train_acc.reset()
                    val_acc.reset()
                    train_acc_5.reset()
                    val_acc_5.reset()

                lib.plot.flush()

            # save model every 15000 iterations (file is named iteration + 1).
            if (iteration) % 15000 == 0:
                torch.save(
                    {
                        'vggFontModel': vggFontModel_module.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict()
                    },
                    os.path.join(opt.exp_dir, opt.exp_name, 'iter_' +
                                 str(iteration + 1) + '_vggFont.pth'))

            lib.plot.tick()

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1