Example #1
File: iqa.py Project: Gavinylk/CNN-FRIQA
def validate(val_loader, model, criterion, show_step=False):
    losses = AverageMeter()
    srocc = SROCC()
    len_val = len(val_loader)
    pb = ProgressBar(len_val-1, show_step=show_step)

    print("Validation")

    # Switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for i, ((img,ref), score) in enumerate(val_loader):
            img, ref, score = img.cuda(), ref.cuda(), score.squeeze().cuda()

            # Compute output
            output = model(img, ref)
            
            loss = criterion(output, score)
            losses.update(loss.data, img.shape[0])

            output = output.cpu().data
            score = score.cpu().data
            srocc.update(score.numpy(), output.numpy())

            pb.show(i, '[{0:5d}/{1:5d}]\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Output {out:.4f}\t'
                    'Target {tar:.4f}\t'
                    .format(i, len_val, loss=losses, 
                    out=output, tar=score))


    return float(1.0-srocc.compute())  # losses.avg
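
The loops in these examples track the loss through an AverageMeter, which is defined elsewhere in the project. A minimal sketch of the usual running-average helper, written here as an assumption about what the project's class provides (val, avg and update(val, n)), would be:

class AverageMeter:
    """Keeps the latest value and a running (weighted) average of a metric."""
    def __init__(self):
        self.val = 0.0    # most recent value
        self.sum = 0.0    # weighted sum of all recorded values
        self.count = 0    # total weight (e.g. number of samples seen)
        self.avg = 0.0    # running average

    def update(self, val, n=1):
        # val: metric value for the current batch; n: batch size used as weight
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count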
Example #2
File: iqa.py Project: Gavinylk/CNN-FRIQA
def train(train_loader, model, criterion, optimizer, epoch):
    losses = AverageMeter()
    len_train = len(train_loader)
    pb = ProgressBar(len_train-1)

    print("Training")

    # Switch to train mode
    model.train()
    criterion.cuda()
    for i, ((img,ref), score) in enumerate(train_loader):
        img, ref, score = img.cuda(), ref.cuda(), score.squeeze().cuda()

        # Compute output
        output = model(img, ref)
        loss = criterion(output, score)

        # Measure accuracy and record loss
        losses.update(loss.data, img.shape[0])

        # Compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pb.show(i, '[{0:5d}/{1:5d}]\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                .format(i, len_train, loss=losses))
Example #3
 def _to_pool(self):
     len_data = self.__len__()
     pb = SimpleProgressBar(len_data)
     print("\ninitializing data pool...")
     for index in range(len_data):
         self._pool(index).store(self.__getitem__(index)[0])
         pb.show(index, "[{:d}]/[{:d}] ".format(index + 1, len_data))
Example #4
def test(test_data_loader, model):
    srocc = SROCC()
    plcc = PLCC()
    rmse = RMSE()
    len_test = len(test_data_loader)
    pb = ProgressBar(len_test, show_step=True)

    print("Testing")

    model.eval()
    with torch.no_grad():
        for i, ((img, ref), score) in enumerate(test_data_loader):
            img, ref = img.cuda(), ref.cuda()
            output = model(img, ref).cpu().data.numpy()
            score = score.data.numpy()

            srocc.update(score, output)
            plcc.update(score, output)
            rmse.update(score, output)

            pb.show(
                i, "Test: [{0:5d}/{1:5d}]\t"
                "Score: {2:.4f}\t"
                "Label: {3:.4f}".format(i + 1, len_test, float(output),
                                        float(score)))

    print("\n\nSROCC: {0:.4f}\n"
          "PLCC: {1:.4f}\n"
          "RMSE: {2:.4f}".format(srocc.compute(), plcc.compute(),
                                 rmse.compute()))
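
Examples #1, #4 and #12 feed (score, output) pairs into SROCC, PLCC and RMSE objects and read the result from compute(). Those metric classes are not shown here; assuming they accumulate all pairs and then compute Spearman rank correlation, Pearson correlation and root-mean-square error, a minimal scipy/numpy sketch with the same update()/compute() interface is:

import numpy as np
from scipy import stats

class _PairAccumulator:
    """Collects (target, prediction) pairs across batches."""
    def __init__(self):
        self.x, self.y = [], []
    def update(self, x, y):
        self.x.append(np.ravel(x))
        self.y.append(np.ravel(y))
    def _stacked(self):
        return np.concatenate(self.x), np.concatenate(self.y)

class SROCC(_PairAccumulator):
    def compute(self):
        x, y = self._stacked()
        return stats.spearmanr(x, y)[0]

class PLCC(_PairAccumulator):
    def compute(self):
        x, y = self._stacked()
        return stats.pearsonr(x, y)[0]

class RMSE(_PairAccumulator):
    def compute(self):
        x, y = self._stacked()
        return float(np.sqrt(np.mean((x - y) ** 2)))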
Example #5
def sample_cgan_concat_given_labels(netG,
                                    given_labels,
                                    batch_size=100,
                                    denorm=True,
                                    to_numpy=True,
                                    verbose=True):
    '''
    netG: pretrained generator network
    given_labels: float, unnormalized labels; they need to be converted to values in [-1,1].
    '''

    ## number of fake images to generate
    nfake = len(given_labels)

    ## normalize regression labels
    labels = given_labels / max_label

    ## generate images
    if batch_size > nfake:
        batch_size = nfake

    fake_images = []
    ## concat to avoid out of index errors
    labels = np.concatenate((labels, labels[0:batch_size]), axis=0)

    netG = netG.cuda()
    netG.eval()
    with torch.no_grad():
        if verbose:
            pb = SimpleProgressBar()
        tmp = 0
        while tmp < nfake:
            z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            c = torch.from_numpy(labels[tmp:(tmp + batch_size)]).type(
                torch.float).cuda()
            batch_fake_images = netG(z, c)
            if denorm:  #denorm imgs to save memory
                assert batch_fake_images.max().item() <= 1.0 and batch_fake_images.min().item() >= -1.0
                batch_fake_images = batch_fake_images * 0.5 + 0.5
                batch_fake_images = batch_fake_images * 255.0
                batch_fake_images = batch_fake_images.type(torch.uint8)
            fake_images.append(batch_fake_images.detach().cpu())
            tmp += batch_size
            if verbose:
                pb.update(min(float(tmp) / nfake, 1) * 100)

    fake_images = torch.cat(fake_images, dim=0)
    #remove extra entries
    fake_images = fake_images[0:nfake]

    if to_numpy:
        fake_images = fake_images.numpy()

    return fake_images, given_labels
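
Several snippets here (Examples #5 through #14) report progress through pb.update(percent) on a SimpleProgressBar. That utility is not included in these examples; a minimal console sketch with the same interface (an assumption, the project's own class may differ) is:

import sys

class SimpleProgressBar:
    """Renders an in-place text progress bar; update() expects a percentage in [0, 100]."""
    def __init__(self, width=50):
        self.width = width

    def update(self, percent):
        percent = max(0.0, min(100.0, percent))
        filled = int(self.width * percent / 100.0)
        bar = '#' * filled + '-' * (self.width - filled)
        sys.stdout.write('\r[{}] {:6.2f}%'.format(bar, percent))
        sys.stdout.flush()
        if percent >= 100.0:
            sys.stdout.write('\n')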
Example #6
def cal_labelscore(PreNet, images, labels_assi, min_label_before_shift, max_label_after_shift, batch_size = 200, resize = None, norm_img = False, num_workers=0):
    '''
    PreNet: pre-trained CNN
    images: fake images
    labels_assi: assigned labels
    resize: if None, do not resize; if resize = (H,W), resize images to 3 x H x W
    '''

    PreNet.eval()

    # assume images have shape n x nc x img_size x img_size
    n = images.shape[0]
    nc = images.shape[1] #number of channels
    img_size = images.shape[2]
    labels_assi = labels_assi.reshape(-1)

    eval_trainset = IMGs_dataset(images, labels_assi, normalize=False)
    eval_dataloader = torch.utils.data.DataLoader(eval_trainset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    labels_pred = np.zeros(n+batch_size)

    nimgs_got = 0
    pb = SimpleProgressBar()
    for batch_idx, (batch_images, batch_labels) in enumerate(eval_dataloader):
        batch_images = batch_images.type(torch.float).cuda()
        batch_labels = batch_labels.type(torch.float).cuda()
        batch_size_curr = len(batch_labels)

        if norm_img:
            batch_images = normalize_images(batch_images)

        batch_labels_pred, _ = PreNet(batch_images)
        labels_pred[nimgs_got:(nimgs_got+batch_size_curr)] = batch_labels_pred.detach().cpu().numpy().reshape(-1)

        nimgs_got += batch_size_curr
        pb.update((float(nimgs_got)/n)*100)

        del batch_images; gc.collect()
        torch.cuda.empty_cache()
    #end for batch_idx

    labels_pred = labels_pred[0:n]


    labels_pred = (labels_pred*max_label_after_shift)-np.abs(min_label_before_shift)
    labels_assi = (labels_assi*max_label_after_shift)-np.abs(min_label_before_shift)

    ls_mean = np.mean(np.abs(labels_pred-labels_assi))
    ls_std = np.std(np.abs(labels_pred-labels_assi))

    return ls_mean, ls_std
Example #7
 def __train_epoch__(self, epoch, loader):
     loss_avg = 0
     train_accuracy_avg = 0
     train_example_count = 0
     global_progress_bar = progress_bar.SimpleProgressBar()
     for b, batch in enumerate(loader):
         example_ids = batch['example_id']
         self.optimizer.zero_grad()
         # end_time = time.time()
         # print('time:{}'.format(end_time - start_time))
         # start_time = time.time()
         input_batch = self.framework.deal_with_example_batch(
             example_ids, loader.example_dict)
         loss, _ = self.framework(**input_batch)
         # end_time = time.time()
         # print('time:{}'.format(end_time-start_time))
         # start_time = time.time()
         loss.backward()
         # end_time = time.time()
         # print('time:{}'.format(end_time - start_time))
         # loss = float(loss)
         self.optimizer.step()
         if hasattr(self, 'scheduler'):
             self.scheduler.step()
         loss_avg += float(loss.item())
         # train_example_count += len(labels)
         # print('epoch:{:5}  batch:{:5}  arg_loss:{:20}'.format(epoch + 1, b + 1, loss), end='\r')
         # end_time = time.time()
         # print('time:{}'.format(end_time-start_time))
         global_progress_bar.update((b + 1) * 100 / len(loader))
     # print()
     loss_avg = loss_avg / len(loader)
     # train_accuracy_avg = train_accuracy_avg / train_example_count
     return loss_avg
Example #8
def extra_parsed_sentence_dict_from_org_file(org_file):
    rows = file_tool.load_data(org_file, 'r')
    parsed_sentence_dict = {}
    pb = progress_bar.SimpleProgressBar()
    print('begin extra parsed sentence dict from original file')
    count = len(rows)
    for row_index, row in enumerate(rows):
        items = row.strip().split('[Sq]')
        if len(items) != 4:
            raise ValueError("file format error")
        sent_id = str(items[0].strip())
        org_sent = str(items[1].strip())
        sent_tokens = str(items[2].strip()).split(' ')

        dependencies = []
        dep_strs = str(items[3].strip()).split('[De]')
        for dep_str in dep_strs:
            def extra_word_index(wi_str):
                wi_str_temp = wi_str.split('-')
                word = ''.join(wi_str_temp[:-1])
                index = str(wi_str_temp[-1])
                return word, index

            dep_itmes = dep_str.strip().split('[|]')
            if len(dep_itmes) != 3:
                raise ValueError("file format error")
            dep_name = str(dep_itmes[0]).strip()
            first_word, first_index = extra_word_index(str(dep_itmes[1]).strip())
            second_word, second_index = extra_word_index(str(dep_itmes[2]).strip())

            word_pair = {
                "first": {"word": first_word, "index": first_index},
                "second": {"word": second_word, "index": second_index}
            }

            dependency_dict = {
                'name': dep_name,
                'word_pair': word_pair
            }
            dependencies.append(dependency_dict)

        for w in sent_tokens:
            if w == '' or len(w) == 0:
                raise ValueError("file format error")

        parsed_info_dict = {
            'original': org_sent,
            'words': sent_tokens,
            'dependencies': dependencies,
            'id': sent_id,
            'has_root': True
        }
        if sent_id in parsed_sentence_dict:
            raise ValueError("file format error")

        parsed_sentence_dict[sent_id] = parsed_info_dict

        pb.update(row_index/count * 100)

    return parsed_sentence_dict
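
Working backwards from the parsing code above, each row of the original file must contain four '[Sq]'-separated fields (sentence id, original sentence, space-separated tokens, dependency list), the dependencies must be separated by '[De]', and each dependency must hold three '[|]'-separated items whose word-index pairs end in '-index'. A hypothetical row in that format (made up for illustration, not taken from the project's data) would be:

# A made-up row that extra_parsed_sentence_dict_from_org_file would accept:
sample_row = "1[Sq]The cat sleeps .[Sq]The cat sleeps .[Sq]det [|] cat-2 [|] The-1[De]nsubj [|] sleeps-3 [|] cat-2"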
Example #9
def sample_cgan_given_labels(netG, given_labels, batch_size=500):
    '''
    netG: pretrained generator network
    given_labels: float, unnormalized labels; they need to be converted to values in [-1,1].
    '''

    ## number of fake images to generate
    nfake = len(given_labels)

    ## normalize regression labels
    labels = given_labels / max_label

    ## generate images
    if batch_size > nfake:
        batch_size = nfake

    netG = netG.cuda()
    netG.eval()

    ## concat to avoid out of index errors
    labels = np.concatenate((labels, labels[0:batch_size]), axis=0)

    fake_images = []

    with torch.no_grad():
        pb = SimpleProgressBar()
        tmp = 0
        while tmp < nfake:
            z = torch.randn(batch_size, dim_z, dtype=torch.float).cuda()
            c = torch.from_numpy(labels[tmp:(tmp + batch_size)]).type(
                torch.float).cuda()
            batch_fake_images = netG(z, c)
            fake_images.append(batch_fake_images.detach().cpu().numpy())
            tmp += batch_size
            pb.update(min(float(tmp) / nfake, 1) * 100)

    fake_images = np.concatenate(fake_images, axis=0)
    #remove extra images
    fake_images = fake_images[0:nfake]

    #denormalize fake images
    if fake_images.max() <= 1.0:
        fake_images = fake_images * 0.5 + 0.5
        fake_images = (fake_images * 255.0).astype(np.uint8)

    return fake_images, given_labels
Example #10
def sample_ccgan_given_labels(netG, net_y2h, labels, batch_size = 500, to_numpy=True, denorm=True, verbose=True):
    '''
    netG: pretrained generator network
    labels: float. normalized labels.
    '''

    nfake = len(labels)
    if batch_size>nfake:
        batch_size=nfake

    fake_images = []
    fake_labels = np.concatenate((labels, labels[0:batch_size]))
    netG=netG.cuda()
    netG.eval()
    net_y2h = net_y2h.cuda()
    net_y2h.eval()
    with torch.no_grad():
        if verbose:
            pb = SimpleProgressBar()
        n_img_got = 0
        while n_img_got < nfake:
            z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            y = torch.from_numpy(fake_labels[n_img_got:(n_img_got+batch_size)]).type(torch.float).view(-1,1).cuda()
            batch_fake_images = netG(z, net_y2h(y))
            if denorm: #denorm imgs to save memory
                assert batch_fake_images.max().item()<=1.0 and batch_fake_images.min().item()>=-1.0
                batch_fake_images = batch_fake_images*0.5+0.5
                batch_fake_images = batch_fake_images*255.0
                batch_fake_images = batch_fake_images.type(torch.uint8)
                # assert batch_fake_images.max().item()>1
            fake_images.append(batch_fake_images.cpu())
            n_img_got += batch_size
            if verbose:
                pb.update(min(float(n_img_got)/nfake, 1)*100)
        ##end while

    fake_images = torch.cat(fake_images, dim=0)
    #remove extra entries
    fake_images = fake_images[0:nfake]
    fake_labels = fake_labels[0:nfake]

    if to_numpy:
        fake_images = fake_images.numpy()

    return fake_images, fake_labels
Example #11
def cal_labelscore(PreNet, images, labels_assi, min_label_before_shift, max_label_after_shift, batch_size = 500, resize = None):
    '''
    PreNet: pre-trained CNN
    images: fake images
    labels_assi: assigned labels
    resize: if None, do not resize; if resize = (H,W), resize images to 3 x H x W
    '''
    PreNet.eval()

    # assume images have shape n x nc x img_size x img_size
    n = images.shape[0]
    nc = images.shape[1] #number of channels
    img_size = images.shape[2]
    labels_assi = labels_assi.reshape(-1)

    # predict labels
    labels_pred = np.zeros(n)
    with torch.no_grad():
        tmp = 0
        pb = SimpleProgressBar()
        n_batches = int(np.ceil(n / batch_size))  # include the final partial batch as well
        for i in range(n_batches):
            pb.update(float(i)*100/n_batches)
            image_tensor = torch.from_numpy(images[tmp:(tmp+batch_size)]).type(torch.float).cuda()
            if resize is not None:
                image_tensor = nn.functional.interpolate(image_tensor, size = resize, scale_factor=None, mode='bilinear', align_corners=False)
            labels_batch, _ = PreNet(image_tensor)
            labels_pred[tmp:(tmp+batch_size)] = labels_batch.detach().cpu().numpy().reshape(-1)
            tmp+=batch_size
        del image_tensor; gc.collect()
        torch.cuda.empty_cache()

    labels_pred = (labels_pred*max_label_after_shift)-np.abs(min_label_before_shift)
    labels_assi = (labels_assi*max_label_after_shift)-np.abs(min_label_before_shift)

    ls_mean = np.mean(np.abs(labels_pred-labels_assi))
    ls_std = np.std(np.abs(labels_pred-labels_assi))

    return ls_mean, ls_std
Example #12
File: iqa.py Project: Gavinylk/CNN-FRIQA
def test(test_data_loader, model):
    scores = []
    srocc = SROCC()
    plcc = PLCC()
    rmse = RMSE()
    len_test = len(test_data_loader)
    pb = ProgressBar(len_test-1, show_step=True)

    print("Testing")

    model.eval()
    with torch.no_grad():
        for i, ((img, ref), score) in enumerate(test_data_loader):
            img, ref = img.cuda(), ref.cuda()
            output = model(img, ref).cpu().data.numpy()
            score = score.data.numpy()

            srocc.update(score, output)
            plcc.update(score, output)
            rmse.update(score, output)

            pb.show(i, 'Test: [{0:5d}/{1:5d}]\t'
                    'Score: {2:.4f}\t'
                    'Label: {3:.4f}'
                    .format(i, len_test, float(output), float(score)))

            scores.append(output)
    
    # Write scores to file
    with open('../test/scores.txt', 'w') as f:
        for s in scores:
            f.write(str(s) + '\n')

    print('\n\nSROCC: {0:.4f}\n'
            'PLCC: {1:.4f}\n'
            'RMSE: {2:.4f}'
            .format(srocc.compute(), plcc.compute(), rmse.compute())
    )
Example #13
def sample_cgan_given_labels(netG,
                             given_labels,
                             class_cutoff_points,
                             batch_size=200,
                             denorm=True,
                             to_numpy=True,
                             verbose=True):
    '''
    given_labels: a numpy array; raw label without any normalization; not class label
    class_cutoff_points: the cutoff points that determine the class membership of a given label
    '''

    class_cutoff_points = np.array(class_cutoff_points)
    num_classes = len(class_cutoff_points) - 1

    nfake = len(given_labels)
    given_class_labels = np.zeros(nfake)
    for i in range(nfake):
        curr_given_label = given_labels[i]
        diff_tmp = class_cutoff_points - curr_given_label
        indx_nonneg = np.where(diff_tmp >= 0)[0]
        if len(indx_nonneg) == 1:  # only the last element of diff_tmp is non-negative
            curr_given_class_label = num_classes - 1
            assert indx_nonneg[0] == num_classes
        elif len(indx_nonneg) > 1:
            if diff_tmp[indx_nonneg[0]] > 0:
                curr_given_class_label = indx_nonneg[0] - 1
            else:
                curr_given_class_label = indx_nonneg[0]
        given_class_labels[i] = curr_given_class_label
    given_class_labels = np.concatenate(
        (given_class_labels, given_class_labels[0:batch_size]))

    if batch_size > nfake:
        batch_size = nfake
    fake_images = []
    netG = netG.cuda()
    netG.eval()
    with torch.no_grad():
        if verbose:
            pb = SimpleProgressBar()
        tmp = 0
        while tmp < nfake:
            z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
            labels = torch.from_numpy(
                given_class_labels[tmp:(tmp + batch_size)]).type(
                    torch.long).cuda()
            if labels.max().item() > num_classes:
                print("Error: max label {}".format(labels.max().item()))
            batch_fake_images = netG(z, labels)
            if denorm:  #denorm imgs to save memory
                assert batch_fake_images.max().item() <= 1.0 and batch_fake_images.min().item() >= -1.0
                batch_fake_images = batch_fake_images * 0.5 + 0.5
                batch_fake_images = batch_fake_images * 255.0
                batch_fake_images = batch_fake_images.type(torch.uint8)
                # assert batch_fake_images.max().item()>1
            fake_images.append(batch_fake_images.detach().cpu())
            tmp += batch_size
            if verbose:
                pb.update(min(float(tmp) / nfake, 1) * 100)

    fake_images = torch.cat(fake_images, dim=0)
    #remove extra entries
    fake_images = fake_images[0:nfake]

    if to_numpy:
        fake_images = fake_images.numpy()

    return fake_images, given_labels
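
The loop at the top of Example #13 converts each raw label into a class index: diff_tmp holds the signed distances to all cutoff points, and the first non-negative entry decides the bin. A small standalone sketch of that mapping, using made-up cutoff points purely to illustrate the index arithmetic, is:

import numpy as np

def label_to_class(raw_label, class_cutoff_points):
    """Mirrors the class-assignment logic used in Example #13."""
    cutoffs = np.array(class_cutoff_points)
    num_classes = len(cutoffs) - 1
    diff = cutoffs - raw_label
    nonneg = np.where(diff >= 0)[0]
    if len(nonneg) == 1:        # only the last cutoff is >= raw_label
        return num_classes - 1
    if diff[nonneg[0]] > 0:     # raw_label lies strictly below the first non-negative cutoff
        return nonneg[0] - 1
    return nonneg[0]            # raw_label sits exactly on a cutoff point

# Hypothetical cutoffs [0, 10, 20, 30] define three classes:
# label_to_class(5, [0, 10, 20, 30])  -> 0
# label_to_class(10, [0, 10, 20, 30]) -> 1
# label_to_class(25, [0, 10, 20, 30]) -> 2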
Example #14
def cal_FID(PreNetFID,
            IMGSr,
            IMGSg,
            batch_size=500,
            resize=None,
            norm_img=False):
    #resize: if None, do not resize; if resize = (H,W), resize images to 3 x H x W

    PreNetFID.eval()

    nr = IMGSr.shape[0]
    ng = IMGSg.shape[0]

    nc = IMGSr.shape[1]  # IMGSr has shape nr x nc x img_size x img_size
    img_size = IMGSr.shape[2]

    if batch_size > min(nr, ng):
        batch_size = min(nr, ng)
        # print("FID: recude batch size to {}".format(batch_size))

    #compute the length of extracted features
    with torch.no_grad():
        test_img = torch.from_numpy(IMGSr[0].reshape(
            (1, nc, img_size, img_size))).type(torch.float).cuda()
        if resize is not None:
            test_img = nn.functional.interpolate(test_img,
                                                 size=resize,
                                                 scale_factor=None,
                                                 mode='bilinear',
                                                 align_corners=False)
        if norm_img:
            test_img = normalize_images(test_img)
        # _, test_features = PreNetFID(test_img)
        test_features = PreNetFID(test_img)
        d = test_features.shape[1]  #length of extracted features

    Xr = np.zeros((nr, d))
    Xg = np.zeros((ng, d))

    #batch_size = 500
    with torch.no_grad():
        tmp = 0
        pb1 = SimpleProgressBar()
        for i in range(int(np.ceil(nr / batch_size))):  # include the final partial batch
            imgr_tensor = torch.from_numpy(IMGSr[tmp:(tmp + batch_size)]).type(
                torch.float).cuda()
            if resize is not None:
                imgr_tensor = nn.functional.interpolate(imgr_tensor,
                                                        size=resize,
                                                        scale_factor=None,
                                                        mode='bilinear',
                                                        align_corners=False)
            if norm_img:
                imgr_tensor = normalize_images(imgr_tensor)
            # _, Xr_tmp = PreNetFID(imgr_tensor)
            Xr_tmp = PreNetFID(imgr_tensor)
            Xr[tmp:(tmp + batch_size)] = Xr_tmp.detach().cpu().numpy()
            tmp += batch_size
            # pb1.update(min(float(i)*100/(nr//batch_size), 100))
            pb1.update(min(tmp / nr * 100, 100))
        del Xr_tmp, imgr_tensor
        gc.collect()
        torch.cuda.empty_cache()

        tmp = 0
        pb2 = SimpleProgressBar()
        for j in range(int(np.ceil(ng / batch_size))):  # include the final partial batch
            imgg_tensor = torch.from_numpy(IMGSg[tmp:(tmp + batch_size)]).type(
                torch.float).cuda()
            if resize is not None:
                imgg_tensor = nn.functional.interpolate(imgg_tensor,
                                                        size=resize,
                                                        scale_factor=None,
                                                        mode='bilinear',
                                                        align_corners=False)
            if norm_img:
                imgg_tensor = normalize_images(imgg_tensor)
            # _, Xg_tmp = PreNetFID(imgg_tensor)
            Xg_tmp = PreNetFID(imgg_tensor)
            Xg[tmp:(tmp + batch_size)] = Xg_tmp.detach().cpu().numpy()
            tmp += batch_size
            # pb2.update(min(float(j)*100/(ng//batch_size), 100))
            pb2.update(min(tmp / ng * 100, 100))
        del Xg_tmp, imgg_tensor
        gc.collect()
        torch.cuda.empty_cache()

    fid_score = FID(Xr, Xg, eps=1e-6)

    return fid_score
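
cal_FID only extracts the feature matrices Xr and Xg; the actual distance comes from FID(Xr, Xg, eps=1e-6), which is defined elsewhere in the project. Assuming it computes the standard Frechet distance between Gaussians fitted to the two feature sets, a minimal numpy/scipy sketch is:

import numpy as np
from scipy import linalg

def FID(Xr, Xg, eps=1e-6):
    """Frechet distance between Gaussians fitted to real and generated feature sets."""
    mu_r, mu_g = Xr.mean(axis=0), Xg.mean(axis=0)
    cov_r = np.cov(Xr, rowvar=False)
    cov_g = np.cov(Xg, rowvar=False)
    # matrix square root of the covariance product; regularize with eps*I if it is singular
    covmean, _ = linalg.sqrtm(cov_r.dot(cov_g), disp=False)
    if not np.isfinite(covmean).all():
        offset = np.eye(cov_r.shape[0]) * eps
        covmean, _ = linalg.sqrtm((cov_r + offset).dot(cov_g + offset), disp=False)
    covmean = covmean.real  # drop tiny imaginary parts caused by numerical error
    diff = mu_r - mu_g
    return float(diff.dot(diff) + np.trace(cov_r) + np.trace(cov_g) - 2.0 * np.trace(covmean))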
Example #15
    def train_final_model(self):
        # self.framework.print_arg_dict()
        self.create_framework()
        # return
        self.logger.info('begin to train final model')
        train_loader = self.data_manager.train_loader(
            self.arg_dict['batch_size'])
        # train_loader, _ = self.data_loader_dict['train_loader_tuple_list'][1]

        max_steps = self.arg_dict["max_steps"]
        train_epochs = self.arg_dict['epoch']
        if max_steps > 0:
            t_total = max_steps
            train_epochs = max_steps // len(train_loader) + 1
        else:
            t_total = len(train_loader) * train_epochs

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.arg_dict["warmup_steps"],
            num_training_steps=t_total)

        self.logger.info("train_loader:{}".format(len(train_loader)))
        loss_list = []

        max_steps_break = False
        step_count = 0
        general_tool.setup_seed(self.arg_dict['seed'])
        for epoch in range(train_epochs):
            loss_avg = 0
            global_progress_bar = progress_bar.SimpleProgressBar()
            for b, batch in enumerate(train_loader):
                example_ids = batch['example_id']
                self.optimizer.zero_grad()

                input_batch = self.framework.deal_with_example_batch(
                    example_ids, train_loader.example_dict)
                loss, _ = self.framework(**input_batch)

                loss.backward()

                self.optimizer.step()
                if hasattr(self, 'scheduler'):
                    self.scheduler.step()
                loss_avg += float(loss.item())
                step_count += 1
                if (max_steps > 0) and (step_count >= max_steps):
                    if max_steps_break:
                        raise RuntimeError
                    max_steps_break = True
                    break

                global_progress_bar.update((b + 1) * 100 / len(train_loader))

            loss_avg = loss_avg / len(train_loader)

            self.logger.info('epoch:{}  arg_loss:{}'.format(
                epoch + 1, loss_avg))
            loss_list.append(loss_avg)
            self.logger.info("current learning rate:{}".format(
                self.scheduler.get_last_lr()[0]))

        record_dict = {
            'loss': loss_list,
        }

        self.save_model()
        self.save_model(cpu=True)
        return record_dict
Example #16
    def __train_fold__(self, train_loader, valid_loader):
        return_state = ""
        epoch = 0
        max_accuracy = 0
        max_accuracy_e = 0
        loss_list = []
        valid_accuracy_list = []
        valid_f1_list = []
        trial_count_report = 0
        trial_report_list = []

        max_steps = self.arg_dict["max_steps"]
        train_epochs = self.arg_dict['epoch']
        if max_steps > 0:
            t_total = max_steps
            train_epochs = max_steps // len(train_loader) + 1
        else:
            t_total = len(train_loader) * train_epochs

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.arg_dict["warmup_steps"],
            num_training_steps=t_total)
        step_count = 0
        max_steps_break = False
        self.logger.info(
            'total step:{} max step:{} warmup_steps:{} epoch:{} '.format(
                t_total, max_steps, self.arg_dict["warmup_steps"],
                train_epochs))
        try:
            best_result = None
            general_tool.setup_seed(self.arg_dict['seed'])
            for epoch in range(train_epochs):
                loss_avg = 0
                global_progress_bar = progress_bar.SimpleProgressBar()
                for b, batch in enumerate(train_loader):
                    example_ids = batch['example_id']
                    self.optimizer.zero_grad()

                    input_batch = self.framework.deal_with_example_batch(
                        example_ids, train_loader.example_dict)
                    loss, _ = self.framework(**input_batch)

                    loss.backward()

                    self.optimizer.step()
                    if hasattr(self, 'scheduler'):
                        self.scheduler.step()
                        # self.logger.info("current learning rate:{}".format(self.scheduler.get_last_lr()[0]))
                    loss_avg += float(loss.item())
                    step_count += 1
                    if (max_steps > 0) and (step_count >= max_steps):
                        if max_steps_break:
                            raise RuntimeError
                        max_steps_break = True
                        break

                    global_progress_bar.update(
                        (b + 1) * 100 / len(train_loader))

                loss_avg = loss_avg / len(train_loader)

                self.logger.info('epoch:{}  arg_loss:{}'.format(
                    epoch + 1, loss_avg))
                self.logger.info("current learning rate:{}".format(
                    self.scheduler.get_last_lr()[0]))
                with torch.no_grad():
                    evaluation_result = self.evaluation_calculation(
                        valid_loader)
                    valid_accuracy = evaluation_result['metric']['accuracy']
                    self.logger.info(evaluation_result['metric'])

                    if valid_accuracy > max_accuracy:
                        best_result = evaluation_result['metric']
                        max_accuracy = valid_accuracy
                        max_accuracy_e = epoch + 1
                        self.save_model()

                    loss_list.append(loss_avg)
                    valid_accuracy_list.append(valid_accuracy)
                    valid_f1_list.append(evaluation_result['metric']['F1'])

                return_state = "finished_max_epoch"
                # print(1.0 - train_accuracy_avg)
                if hasattr(self, 'trial'):
                    self.trial.report(1.0 - valid_accuracy, self.trial_step)
                    self.logger.info('trial_report:{} at step:{}'.format(
                        1.0 - valid_accuracy, self.trial_step))
                    trial_report_list.append(
                        (1.0 - valid_accuracy, self.trial_step))
                    self.trial_step += 1
                    trial_count_report += 1
                    if self.trial.should_prune():
                        raise optuna.exceptions.TrialPruned()

            if hasattr(self, 'trial'):
                self.logger.info(
                    'trial_report_count:{}'.format(trial_count_report))

            if best_result is not None:
                self.logger.info('max acc:{}  F1:{}  best epoch:{}'.format(
                    max_accuracy, best_result['F1'], max_accuracy_e))

        except KeyboardInterrupt:
            return_state = 'KeyboardInterrupt'
            if best_result is not None:
                self.logger.info('max acc:{}  F1:{}  best epoch:{}'.format(
                    max_accuracy, best_result['F1'], max_accuracy_e))
            else:
                print('have not finished one epoch')

        record_dict = {
            'loss': loss_list,
            'valid_acc': valid_accuracy_list,
            'valid_F1': valid_f1_list,
            'trial_report_list': trial_report_list
        }
        return round(1 - max_accuracy, 4), epoch + 1, return_state, record_dict