def loadData(args):
    """Load the CSV inputs and build chronological train/val/test datasets.

    Reads four CSV files under ``args.root`` (``args.adj``, ``args.data``,
    ``args.con_data``, ``args.dis_data``). It stacks three arrays into one:
    the raw series, the output of ``Construction_Layer`` and the time features
    from ``set_time``. The stack is normalized with the statistics of the
    first data column and split along its second axis into the first 70%
    (train), the next 10% (val) and the remainder (test).

    Args:
        args: namespace providing ``root``, ``adj``, ``data``, ``con_data`` and
            ``dis_data`` (file names relative to ``root``); it is also passed
            through to ``data_generate``.

    Returns:
        tuple: ``(train, val, test, test_time, A, mean, std, index)``.
            ``A`` is the adjacency CSV as a float32 ndarray. ``mean`` and
            ``std`` are per-column pandas Series of the data CSV (``std``
            uses ``ddof=0``). ``index`` is the adjacency CSV's row index.
    """
    print("loading the input file(s)")
    adj_df = pd.read_csv(osp.join(args.root, args.adj), index_col=0)
    index = adj_df.index
    A = np.array(adj_df.astype(np.float32))

    df = pd.read_csv(osp.join(args.root, args.data), index_col=0)
    df_con = pd.read_csv(osp.join(args.root, args.con_data), index_col=0)
    df_dis = pd.read_csv(osp.join(args.root, args.dis_data), index_col=0)

    # Chronological split sizes: 70% train, 10% val, the rest is test.
    num_input = df.shape[0]
    train_input = round(0.7 * num_input)
    val_input = round(0.1 * num_input)
    val_end = train_input + val_input

    time = set_time(df.index.values, df)
    dis = set_dis(df_dis)
    con = Construction_Layer(torch.from_numpy(np.array(df_con)),
                             torch.from_numpy(np.array(dis)))
    dataframe = np.stack([np.array(df), con.numpy(), time.numpy()])

    # Explicit column-wise statistics. ``np.mean(df)`` / ``np.std(df)`` forward to
    # ``DataFrame.mean/std(axis=None)``, whose result (Series vs. scalar) depends
    # on the pandas version. ``ddof=0`` reproduces ``np.std``'s population std.
    mean = df.mean(axis=0)
    std = df.std(axis=0, ddof=0)
    # Only the first column's statistics are used for normalization; ``.iloc``
    # makes the positional lookup explicit instead of relying on Series[0].
    X = normalize(dataframe, mean.iloc[0], std.iloc[0])

    train = X[:, :train_input, :]
    val = X[:, train_input:val_end, :]
    test = X[:, val_end:, :]

    train_ = data_generate(train, args, df[:train_input], 'train.pt')
    val_ = data_generate(val, args, df[train_input:val_end], 'val.pt')
    test_ = data_generate(test, args, df[val_end:], 'test.pt')

    train = mydataset(train_[0], train_[1])
    val = mydataset(val_[0], val_[1])
    test = mydataset(test_[0], test_[1])
    test_time = mydataset(test_[2], test_[3])
    return train, val, test, test_time, A, mean, std, index
예제 #2
0
def _iter_clips(features_full, label_note, device, clip_len=19):
    """Yield ``(feature_clip, label_clip)`` pairs for consecutive clips of a batch.

    The last axis of ``features_full`` is cut into
    ``features_full.shape[-1] // clip_len`` non-overlapping windows of
    ``clip_len`` frames; any trailing remainder is dropped. Window ``k`` is
    paired with ``label_note[:, k:k + 1, :]``. Both slices are moved to
    ``device``.
    """
    for clip_id in range(features_full.shape[-1] // clip_len):
        start = clip_id * clip_len
        feature_clip = features_full[:, :, :, start:start + clip_len].to(device)
        label_clip = label_note[:, clip_id:clip_id + 1, :].to(device)
        yield feature_clip, label_clip


def _evaluate(model, dataloader, device):
    """Return ``(mean BCE loss, mean accuracy)`` of ``model`` over every clip.

    Runs with the model in eval mode and with gradients disabled. The caller is
    responsible for switching the model back to train mode. Returns
    ``(nan, nan)`` when ``dataloader`` yields no clips, instead of dividing
    by zero.
    """
    model.eval()
    loss_sum = 0.0
    acc_sum = 0.0
    clip_count = 0
    with torch.no_grad():
        for features_full, label_note in dataloader:
            for feature_clip, label_clip in _iter_clips(features_full,
                                                        label_note, device):
                out_label = model(feature_clip)
                loss_sum += float(get_BCE_loss(out_label, label_clip))
                acc_sum += float(get_accuracy(out_label, label_clip))
                clip_count += 1
    if clip_count == 0:
        return float('nan'), float('nan')
    return loss_sum / clip_count, acc_sum / clip_count


def train():
    """Train the simplified ResNet note model, evaluating and checkpointing it.

    Every ``hparam.step_to_test`` training batches, the model is evaluated on
    the test loader in eval mode without gradients. Accuracy, test loss and the
    most recent training loss are then written to TensorBoard.

    Every ``hparam.step_to_save`` batches, the whole-song sample test and the
    test-set evaluation are run. The weights are saved to
    ``checkpoint/<step>.pth`` and backed up through the ``Logger``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_path = ['data/train/train_extension']
    train_f_path = ['data/train/train_extension_Process_data/FEAT']
    test_path = ['data/test/EvaluationFramework_ISMIR2014/DATASET']
    test_f_path = ['data/test/Process_data/FEAT']
    testsample_path = "data/test_sample/wav_label"
    testsample_f_path = "data/test_sample/FEAT"

    print("cuda" if torch.cuda.is_available() else "cpu")

    augmentation = [Random_volume(rate=0.5, min_range=0.5, max_range=1.5)]

    train_dataloader = DataLoader(mydataset(train_path,
                                            train_f_path,
                                            amount=10000,
                                            augmentation=augmentation,
                                            channel=hparam.FEAT_channel),
                                  batch_size=hparam.batch_size,
                                  shuffle=True,
                                  num_workers=hparam.num_workers)
    test_dataloader = DataLoader(mydataset(test_path,
                                           test_f_path,
                                           channel=hparam.FEAT_channel),
                                 batch_size=hparam.batch_size,
                                 shuffle=True,
                                 num_workers=hparam.num_workers)

    model = get_Resnet(channel=hparam.FEAT_channel,
                       is_simplified=True).to(device)

    optimizer = optim.RMSprop(model.parameters(),
                              lr=hparam.lr,
                              weight_decay=0,
                              momentum=0.9)
    # NOTE(review): this scheduler is never stepped, so the learning rate stays
    # at hparam.lr. Call scheduler.step(<test loss>) if plateau decay is wanted.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10)

    model.train()

    logger = Logger(runs_dir=hparam.runs_path)

    step_count = 0
    ori_runlist = logger.get_runsdir()
    with SummaryWriter(comment=tensor_comment) as writer:
        new_rundir = logger.get_new_runsdir(ori_runlist)
        for idx, code_path in enumerate(hparam.modelcode_path):
            logger.save_codebackup(code_path, new_rundir, index=idx)

        # Detached loss of the most recent training clip. It stays None until a
        # clip has run (a batch shorter than one clip performs no update).
        last_loss = None
        for _ in range(hparam.epoch):
            bar = tqdm(train_dataloader)
            for features_full, label_note in bar:
                # Evaluation below switches to eval mode; restore train mode.
                model.train()
                for feature_clip, label_clip in _iter_clips(features_full,
                                                            label_note, device):
                    optimizer.zero_grad()
                    out_label = model(feature_clip)
                    loss = get_BCE_loss(out_label, label_clip)
                    loss.backward()
                    optimizer.step()
                    last_loss = loss.detach()
                step_count += 1

                if step_count % hparam.step_to_test == 0:
                    # Eval mode so BatchNorm/Dropout behave as at inference and
                    # test batches do not update the running statistics.
                    avg_loss, avg_acc = _evaluate(model, test_dataloader, device)
                    train_loss = (float('nan') if last_loss is None
                                  else float(last_loss))
                    writer.add_scalar('scalar/acc', avg_acc, step_count)
                    writer.add_scalar('scalar/loss/test', avg_loss, step_count)
                    writer.add_scalar('scalar/loss/train', train_loss,
                                      step_count)
                    # tqdm already advances once per iterated batch, so there is
                    # no manual bar.update() here.
                    bar.set_postfix({
                        'loss': f' {train_loss} ',
                        'test_loss': f' {avg_loss} ',
                        'test_acc': f' {avg_acc} '
                    })

                if step_count % hparam.step_to_save == 0:
                    # Run the evaluation helpers in inference mode; the next
                    # training batch switches back to train().
                    model.eval()
                    whole_song_sampletest(testsample_path,
                                          testsample_f_path,
                                          model=model,
                                          writer_in=writer,
                                          timestep=step_count,
                                          channel=hparam.FEAT_channel)
                    torch.save(model.state_dict(),
                               f"checkpoint/{step_count}.pth")
                    logger.save_modelbackup(model, new_rundir)
                    print(f'saved in {step_count}\n')

                    testset_evaluation(test_path[0],
                                       test_f_path[0],
                                       model=model,
                                       writer_in=writer,
                                       timestep=step_count,
                                       channel=hparam.FEAT_channel)
예제 #3
0
    def testing(self):
        """Run the generator and refine generator over the final test set.

        Loads the model weights and evaluates every test batch (batch size 1)
        without gradients. For each batch it saves the reconstructed images and
        signed error maps through ``self.logger.save_images``. At the end it
        writes one CSV row of metrics per batch through
        ``self.logger.save_csv_file``. If the test loader is empty, no CSV is
        written.
        """
        self.parallel = False

        self.load_model()
        self.load_save_model()
        self.logger.changedir()

        t_imageDir = '../../mri_dataset/real_final_test/'
        test_loader = DataLoader(mydataset(t_imageDir, self.sampling, kfold=False),
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=8)

        test_evalution = recon_matrix()
        new_evalutaion = []
        class_list = None

        # Inference mode for the whole run; nothing below re-enables training.
        self.gen.eval()
        self.regen.eval()
        self.dis.eval()

        for i, batch in enumerate(tqdm(test_loader)):
            with torch.no_grad():
                inputa = batch[0].to(self.device).float()
                mask = batch[1].to(self.device).float()

                _, zeroimg = apply_mask(inputa, mask)

                # Stage 1: generator residual on top of the zero-filled image,
                # then ``update`` with the ground truth and mask (presumably a
                # data-consistency step -- confirm).
                prediction = torch.add(zeroimg.float(), self.gen(zeroimg.float()))
                pred = update(prediction, inputa, mask)

                # Stage 2: refine generator on the updated output, same scheme.
                re_prediction = torch.add(zeroimg.float(), self.regen(pred.float()))
                re_pred = update(re_prediction, inputa, mask)

                # Metrics and most saved images use channel 0 only.
                prediction = prediction[:, 0:1]
                pred = pred[:, 0:1]
                inputa = inputa[:, 0:1]
                re_prediction = re_prediction[:, 0:1]
                re_pred = re_pred[:, 0:1]
                zero_ch0 = zeroimg[:, 0:1]

                mid_psnr = test_evalution.psnr(pred, inputa, 1.0)
                final_psnr = test_evalution.psnr(re_pred, inputa, 1.0)
                zero_psnr = test_evalution.psnr(zero_ch0, inputa, 1.0)
                midterm_ssim = test_evalution.ssim(pred, inputa, 1.0)
                finalterm_ssim = test_evalution.ssim(re_pred, inputa, 1.0)
                zeroterm_ssim = test_evalution.ssim(zero_ch0, inputa, 1.0)

                # NOTE(review): column names are kept for CSV compatibility, but
                # 'val_recon_ssim' holds the PSNR of the refined output re_pred.
                evalutiondict = {'val_recon_psnr': mid_psnr,
                                 'val_recon_ssim': final_psnr,
                                 'zero_recon_psnr': zero_psnr,
                                 'midterm_ssim': midterm_ssim,
                                 'finalterm_ssim': finalterm_ssim,
                                 'zeroterm_ssim': zeroterm_ssim}
                class_list, evalutiondict = self.logger.convert_to_list(evalutiondict)
                new_evalutaion.append(list(evalutiondict))

                # 'zero_image' keeps all channels of the zero-filled image.
                total_image = {'recon_image': prediction.cpu().detach().numpy(),
                               'final_image': pred.cpu().detach().numpy(),
                               'input_image': inputa.cpu().detach().numpy(),
                               're_prediction': re_prediction.cpu().detach().numpy(),
                               're_pred': re_pred.cpu().detach().numpy(),
                               'zero_image': zeroimg.cpu().numpy()}
                self.re_normalize(total_image, 255)

                # Signed (not absolute) error maps against the ground truth.
                total_image.update({
                    'recon_error_img': total_image['input_image'] - total_image['final_image'],
                    'final_error_img': total_image['input_image'] - total_image['re_pred'],
                    'zero_error_img': total_image['input_image'] - total_image['zero_image'],
                    'mask': mask.cpu().numpy()})

                self.logger.save_images(total_image, i)

        if class_list is None:
            print('no test batches were evaluated; skipping CSV export')
            return
        self.logger.save_csv_file(np.array(new_evalutaion), 'valid', list(class_list))
예제 #4
0
    def trainning(self):
        """Alternate adversarial training epochs with periodic validation epochs.

        Epochs with ``epoch % self.validnum == 0`` (including epoch 0) are
        validation epochs. On those, the networks are put in eval mode, the
        ``'last'`` checkpoint is saved, and the validation loader is run
        without gradients. The losses, metrics and images of the last
        validation batch are then logged.

        All other epochs step both schedulers and train the generator
        (``self.gen``), the refine generator (``self.regen``) and the
        discriminator (``self.dis``) on every training batch. The metrics of
        the last training batch are logged and decide whether
        ``'bestsave_model'`` is saved. An epoch whose loader yields no batches
        is skipped without logging.
        """
        self.load_model()
        self.load_loss()

        evalution = recon_matrix()
        print('----- Dataset-------------')
        # NOTE(review): the training loader is not shuffled while the validation
        # loader is -- confirm this is intended.
        Dataset = {'train': DataLoader(mydataset(self.imageDir, self.sampling, self.knum, True, self_supervised=True),
                                       batch_size=self.batch_size,
                                       shuffle=False,
                                       num_workers=8),
                   'valid': DataLoader(mydataset(self.imageDir, self.sampling, self.knum),
                                       batch_size=self.batch_size,
                                       shuffle=True,
                                       num_workers=4)}

        best_psnr = 0
        best_epoch = 0
        best_ssim = 0

        for epoch in range(self.epoches):
            if epoch % self.validnum == 0:
                self.gen.eval()
                self.regen.eval()
                self.dis.eval()
                phase = 'valid'
                self.save_model('last', epoch)
                val_recon_psnr = 0
                val_final_psnr = 0
                val_recon_ssim = 0
                val_final_ssim = 0
            else:
                self.gen.train()
                self.regen.train()
                self.dis.train()
                phase = 'train'
                # NOTE(review): passing the epoch to step() is deprecated in
                # PyTorch; kept because the schedulers only step on training
                # epochs, so a plain step() would change the schedule.
                self.schedulerG.step(epoch)
                self.schedulerD.step(epoch)
                recon_psnr = 0
                final_psnr = 0
                recon_ssim = 0
                final_ssim = 0

            print(f"{epoch}/{self.epoches}epochs,IR=>{get_lr(self.optimizerG)},best_epoch=>{best_epoch},phase=>{phase}")
            print(f"==>{self.path}<==")

            # Reset per epoch so an empty loader cannot log stale values.
            summary_val = None
            for i, batch in enumerate(tqdm(Dataset[phase])):
                if phase == 'train':
                    _image = batch[0].to(self.device)
                    mask = batch[1].to(self.device)

                    self.optimizerG.zero_grad()

                    # apply_mask applies the (Cartesian) undersampling mask via an
                    # FFT; the zero-filled image comes back in double precision.
                    under_image, zero_image = apply_mask(_image, mask)
                    zero_image = zero_image.float()

                    # Generator: residual added to the zero-filled image.
                    trained_img = self.gen(zero_image)
                    recon_img = torch.add(trained_img, zero_image).float().to(self.device)

                    mask = mask.to(torch.float32)
                    _image = _image.to(torch.float32)
                    final_img = update(recon_img, _image, mask).to(torch.float32).to(self.device)

                    # Refine generator on the updated output, same residual scheme.
                    retrained_img = self.regen(final_img)
                    re_recon_img = torch.add(retrained_img, zero_image).float().to(self.device)
                    re_final_img = update(re_recon_img, _image, mask).to(torch.float32).to(self.device)

                    # Image-domain losses.
                    recon_loss = self.Lloss(_image, recon_img)
                    error_loss = self.Closs(_image, final_img)
                    re_recon_loss = self.Lloss(_image, re_recon_img)
                    re_error_loss = self.Closs(_image, re_final_img)

                    # Frequency-domain losses against the undersampled measurement.
                    freq_img, _ = apply_mask(final_img, mask)
                    re_freq_img, _ = apply_mask(re_final_img, mask)
                    freq_loss = self.Floss(under_image.to(torch.float32), freq_img).to(torch.float32)
                    re_freq_loss = self.Floss(under_image.to(torch.float32), re_freq_img).to(torch.float32)

                    # WGAN generator loss, computed on channel 0.
                    boost_dis_fake = self.dis(final_img[:, 0:1])
                    re_boost_dis_fake = self.dis(re_final_img[:, 0:1])
                    boost_fake_A = self.WGAN_loss.loss_gen(boost_dis_fake)
                    re_boost_fake_A = self.WGAN_loss.loss_gen(re_boost_dis_fake)

                    # Total-variation regularization.
                    TV_a = self.TVloss(final_img).to(torch.float32)
                    re_TV_a = self.TVloss(re_final_img).to(torch.float32)

                    loss_g = ((recon_loss + error_loss * 10.0 + re_recon_loss + re_error_loss * 10.0) * 10.0
                              + (freq_loss + re_freq_loss))
                    loss_g = loss_g + (boost_fake_A + re_boost_fake_A) * 10.0
                    loss_g = loss_g + (TV_a + re_TV_a)
                    if self.perceptual_loss == True:
                        vgg_error_loss = self.vgg_loss(_image, final_img)
                        vgg_re_error_loss = self.vgg_loss(_image, re_final_img)
                        # Added explicitly; previously this reached loss_g only
                        # through an in-place += on an aliased tensor.
                        loss_g = loss_g + (vgg_error_loss + vgg_re_error_loss)

                    # No retain_graph: the discriminator step below uses detached
                    # fakes, so nothing backpropagates through this graph again.
                    loss_g.backward()
                    self.optimizerG.step()

                    summary_val = {'recon_loss': recon_loss,
                                   'error_loss': error_loss,
                                   'freq_loss': freq_loss,
                                   'TV_a': TV_a,
                                   'boost_fake_A': boost_fake_A}
                    summary_val.update({'re_recon_loss': re_recon_loss,
                                        're_error_loss': re_error_loss,
                                        're_freq_loss': re_freq_loss,
                                        're_TV_a': re_TV_a,
                                        're_boost_fake_A': re_boost_fake_A})
                    if self.perceptual_loss == True:
                        summary_val.update({'vgg_error_loss': vgg_error_loss,
                                            'vgg_re_error_loss': vgg_re_error_loss})

                    # Discriminator step on detached fakes. Backpropagating into
                    # the generator graph after optimizerG.step() modified its
                    # weights in place fails PyTorch's version check, and those
                    # generator gradients were discarded by zero_grad anyway.
                    # NOTE(review): assumes compute_gradient_penalty enables grad on
                    # its interpolates itself (the usual WGAN-GP implementation).
                    real_img = _image[:, 0:1]
                    fake_img = final_img[:, 0:1].detach()
                    re_fake_img = re_final_img[:, 0:1].detach()
                    dis_real_img = self.dis(real_img)
                    dis_fake_img = self.dis(fake_img)
                    re_dis_fake_img = self.dis(re_fake_img)

                    self.optimizerD.zero_grad()

                    dis_loss = self.WGAN_loss.loss_disc(dis_fake_img, dis_real_img)
                    re_dis_loss = self.WGAN_loss.loss_disc(re_dis_fake_img, dis_real_img)
                    GP_loss = compute_gradient_penalty(self.dis, real_img, fake_img, self.device) * 0.0001
                    re_GP_loss = compute_gradient_penalty(self.dis, real_img, re_fake_img, self.device) * 0.0001

                    discrim_loss = (dis_loss + re_dis_loss) + (GP_loss + re_GP_loss)
                    discrim_loss.backward()
                    self.optimizerD.step()
                    summary_val.update({'dis_loss': dis_loss, 'GP_loss': GP_loss})
                    summary_val.update({'re_dis_loss': re_dis_loss, 're_GP_loss': re_GP_loss})

                    final_img = cvt2imag(final_img).cpu().detach().numpy()
                    recon_img = cvt2imag(recon_img).cpu().detach().numpy()
                    _image = cvt2imag(_image).cpu().detach().numpy()
                    re_final_img = cvt2imag(re_final_img).cpu().detach().numpy()
                    re_recon_img = cvt2imag(re_recon_img).cpu().detach().numpy()

                    final_psnr = evalution.PSNR(final_img, _image, 255.0)
                    recon_psnr = evalution.PSNR(recon_img, _image, 255.0)
                    # NOTE(review): the *_ssim names hold PSNR of the refined outputs.
                    final_ssim = evalution.PSNR(re_final_img, _image, 255.0)
                    recon_ssim = evalution.PSNR(re_recon_img, _image, 255.0)

                else:
                    with torch.no_grad():
                        inputa = batch[0].to(self.device).float()
                        mask = batch[1].to(self.device).float()

                        under_im, inputs = apply_mask(inputa, mask)
                        prediction = torch.add(inputs.float(), self.gen(inputs.float()))
                        pred = update(prediction, inputa, mask)

                        # Refine the *updated* output ``pred``: the same input the
                        # refine generator receives in training (final_img).
                        re_prediction = torch.add(inputs.float(), self.regen(pred.float()))
                        re_pred = update(re_prediction, inputa, mask)

                        recon_loss = self.Lloss(inputa, prediction)
                        error_loss = self.Closs(inputa, pred)
                        re_recon_loss = self.Lloss(inputa, re_prediction)
                        re_error_loss = self.Closs(inputa, re_pred)

                        # Frequency-domain losses against the undersampled measurement.
                        freq_img, _ = apply_mask(pred, mask)
                        freq_loss = self.Floss(under_im.to(torch.float32), freq_img).to(torch.float32)
                        re_freq_img, _ = apply_mask(re_pred, mask)
                        re_freq_loss = self.Floss(under_im.to(torch.float32), re_freq_img).to(torch.float32)

                        summary_val = {'recon_loss': recon_loss,
                                       'error_loss': error_loss,
                                       'freq_loss': freq_loss}
                        summary_val.update({'re_recon_loss': re_recon_loss,
                                            're_error_loss': re_error_loss,
                                            're_freq_loss': re_freq_loss})

                        # Metrics on the first sample / first channel only.
                        prediction = cvt2imag(prediction[0, 0]).cpu().detach().numpy()
                        pred = cvt2imag(pred[0, 0]).cpu().detach().numpy()
                        inputa = cvt2imag(inputa[0, 0]).cpu().detach().numpy()
                        re_prediction = cvt2imag(re_prediction[0, 0]).cpu().detach().numpy()
                        re_pred = cvt2imag(re_pred[0, 0]).cpu().detach().numpy()

                        # NOTE(review): here "final" is the pre-update prediction and
                        # "recon" the updated one (the reverse of the training
                        # metrics), and the *_ssim names hold PSNR values.
                        val_final_psnr = evalution.psnr(prediction, inputa, 255.0)
                        val_recon_psnr = evalution.psnr(pred, inputa, 255.0)
                        val_final_ssim = evalution.psnr(re_prediction, inputa, 255.0)
                        val_recon_ssim = evalution.psnr(re_pred, inputa, 255.0)

            if summary_val is None:
                # The loader yielded no batches this epoch; nothing to log.
                continue

            if phase == 'train':
                evalutiondict = {'final_psnr': final_psnr, 'recon_psnr': recon_psnr,
                                 'final_ssim': final_ssim, 'recon_ssim': recon_ssim}
                summary_val.update(evalutiondict)
                self.printall(summary_val, epoch, 'train')

                # NOTE(review): best-model selection uses the metrics of the last
                # *training* batch, not validation metrics.
                if (evalutiondict['final_psnr'] > best_psnr) or (evalutiondict['final_ssim'] > best_ssim):
                    self.save_model('bestsave_model', epoch)
                    best_psnr = evalutiondict['final_psnr']
                    best_epoch = epoch
                    best_ssim = evalutiondict['final_ssim']

            else:
                evalutiondict = {'val_final_psnr': val_final_psnr, 'val_recon_psnr': val_recon_psnr,
                                 'val_final_ssim': val_final_ssim, 'val_recon_ssim': val_recon_ssim}
                summary_val.update(evalutiondict)
                self.printall(summary_val, epoch, 'valid')
                self.logger.save_csv_file(evalutiondict, 'test')

                # Images of the last validation batch.
                total_image = {'recon_image': prediction,
                               'final_image': pred,
                               'input_image': inputa}
                total_image.update({'re_recon_image': re_prediction,
                                    're_final_image': re_pred})
                self.re_normalize(total_image, 255)
                total_image.update({'zero_image': inputs.cpu().numpy()})

                recon_error_img = total_image['input_image'] - total_image['recon_image']
                final_error_img = total_image['input_image'] - total_image['final_image']
                zero_error_img = total_image['input_image'] - total_image['zero_image']
                re_recon_error_img = total_image['input_image'] - total_image['re_recon_image']
                re_final_error_img = total_image['input_image'] - total_image['re_final_image']

                total_image.update({'recon_error_img': recon_error_img,
                                    'final_error_img': final_error_img,
                                    'zero_error_img': zero_error_img,
                                    'mask': mask.cpu().numpy()})
                total_image.update({'re_recon_error_img': re_recon_error_img,
                                    're_final_error_img': re_final_error_img})

                self.logger.save_images(total_image, epoch)
예제 #5
0
    def testing(self):
        """Evaluate the generator on the final test set and save metrics and images.

        Each test batch is reconstructed without gradients. Two outputs are
        scored against the ground truth after min-max normalization to
        [0, 1]:

        - "recon": the prediction, i.e. the zero-filled image plus the
          generator residual.
        - "final": the ``update``d version of that prediction.

        Their PSNR is averaged over all batches and written once through
        ``self.logger.save_csv``. Images and signed error maps are saved for
        every batch through ``self.save_image``. If the loader is empty,
        nothing is written.
        """
        t_imageDir = '../mri_dataset/real_final_test/'
        test_loader = DataLoader(mydataset(t_imageDir, self.sampling, self.knum),
                                 batch_size=self.BATCH_SIZE,
                                 shuffle=True,
                                 num_workers=4)

        self.load_model()
        self.logger.changedir()

        def _minmax(t):
            # Same as cv2.normalize(x, None, 0, 1, cv2.NORM_MINMAX) over the whole
            # tensor (a constant tensor maps to zeros), but works on 4-D batches.
            lo, hi = t.min(), t.max()
            return (t - lo) / (hi - lo).clamp_min(1e-12)

        test_final_psnr = 0
        test_recon_psnr = 0
        test_final_ssim = 0
        test_recon_ssim = 0
        num_batches = 0
        test_evalution = recon_matrix()

        self.gen.eval()
        self.dis.eval()

        for i, batch in enumerate(tqdm(test_loader)):
            with torch.no_grad():
                inputa = batch[0].to(self.device).float()
                mask = batch[2].to(self.device).float()

                _, inputs = apply_mask(inputa, mask)
                prediction = torch.add(inputs.float(), self.gen(inputs.float()))
                pred = update(prediction, inputa, mask)

                prediction_np = _minmax(prediction).cpu().numpy()
                pred_np = _minmax(pred).cpu().numpy()
                inputa_np = _minmax(inputa).cpu().numpy()

            # Score before re_normalize below, which may rescale the arrays.
            final_psnr = test_evalution.PSNR(pred_np, inputa_np, 1)
            recon_psnr = test_evalution.PSNR(prediction_np, inputa_np, 1)
            test_final_psnr += final_psnr
            test_recon_psnr += recon_psnr
            # NOTE(review): the *_ssim entries have always received PSNR values;
            # switch to an SSIM metric if recon_matrix provides one.
            test_final_ssim += final_psnr
            test_recon_ssim += recon_psnr
            num_batches += 1

            total_image = {
                'recon_image': prediction_np,
                'final_image': pred_np,
                'input_image': inputa_np
            }
            self.re_normalize(total_image, 255)
            # The zero-filled input is added after re_normalize, unnormalized.
            total_image.update({'zero_image': inputs.cpu().numpy()})

            total_image.update({
                'recon_error_img': total_image['input_image'] - total_image['recon_image'],
                'final_error_img': total_image['input_image'] - total_image['final_image'],
                'zero_error_img': total_image['input_image'] - total_image['zero_image'],
                'mask': mask.cpu().numpy()
            })
            self.save_image(total_image, i)

        if num_batches == 0:
            print('no test batches were evaluated; skipping CSV export')
            return

        evalutiondict = {
            'val_final_psnr': test_final_psnr / num_batches,
            'val_recon_psnr': test_recon_psnr / num_batches,
            'val_final_ssim': test_final_ssim / num_batches,
            'val_recon_ssim': test_recon_ssim / num_batches
        }
        class_list, evalutiondict = self.logger.convert_to_list(evalutiondict)
        self.logger.save_csv(evalutiondict, 'valid', class_list)
예제 #6
0
파일: train.py 프로젝트: cow5566bad/YOLOv1
def main(args):
    """Build the training and validation datasets and fit a ``Predictor``.

    Args:
        args: namespace providing ``image_dir``/``label_dir`` for the training
            data and ``valid_image_dir``/``valid_label_dir`` for the
            validation data.
    """
    # Presumably mydataset parses the label files on construction, which is
    # what this log line announces -- confirm.
    logging.info('preprocess labels...')
    data = mydataset(args.image_dir, args.label_dir)
    _valid = mydataset(args.valid_image_dir, args.valid_label_dir)
    # The validation set goes to the constructor; the training set to fit_dataset.
    predictor = Predictor(valid=_valid)
    predictor.fit_dataset(data)