Example #1
0
class SingleGAN():
    """Single-generator, multi-domain image-to-image translation model.

    One generator G translates between all domains, conditioned on a one-hot
    domain code (plus a latent code in 'multimodal' mode, where an encoder E
    is also built).  In training mode one discriminator per domain is created.
    Call `initialize(opt)` before any other method.
    """

    def name(self):
        """Return the model identifier."""
        return 'SingleGAN'

    def initialize(self, opt):
        """Select the GPU, store options and build all sub-networks."""
        torch.cuda.set_device(opt.gpu)
        cudnn.benchmark = True
        self.opt = opt
        self.build_models()

    def build_models(self):
        """Construct G (always), E (multimodal mode only) and, when training,
        the per-domain discriminators, GAN criterion and optimizers.

        Weights are loaded from `opt.model_dir` when continuing training or
        testing, otherwise freshly initialized with `opt.init_type`.
        """
        ################### generator #########################################
        self.G = SingleGenerator(input_nc=self.opt.input_nc,
                                 output_nc=self.opt.input_nc,
                                 ngf=self.opt.ngf,
                                 nc=self.opt.c_num + self.opt.d_num,
                                 e_blocks=self.opt.e_blocks,
                                 norm_type=self.opt.norm)
        ################### encoder ###########################################
        self.E = None
        if self.opt.mode == 'multimodal':
            self.E = Encoder(input_nc=self.opt.input_nc,
                             output_nc=self.opt.c_num,
                             nef=self.opt.nef,
                             nd=self.opt.d_num,
                             n_blocks=4,
                             norm_type=self.opt.norm)
        if self.opt.isTrain:
            ################### discriminators #####################################
            self.Ds = []
            for i in range(self.opt.d_num):
                self.Ds.append(
                    D_NET_Multi(input_nc=self.opt.output_nc,
                                ndf=self.opt.ndf,
                                block_num=3,
                                norm_type=self.opt.norm))
            ################### init_weights ########################################
            if self.opt.continue_train:
                self.G.load_state_dict(
                    torch.load('{}/G_{}.pth'.format(self.opt.model_dir,
                                                    self.opt.which_epoch)))
                if self.E is not None:
                    self.E.load_state_dict(
                        torch.load('{}/E_{}.pth'.format(
                            self.opt.model_dir, self.opt.which_epoch)))
                for i in range(self.opt.d_num):
                    self.Ds[i].load_state_dict(
                        torch.load('{}/D_{}_{}.pth'.format(
                            self.opt.model_dir, i, self.opt.which_epoch)))
            else:
                self.G.apply(weights_init(self.opt.init_type))
                if self.E is not None:
                    self.E.apply(weights_init(self.opt.init_type))
                for i in range(self.opt.d_num):
                    self.Ds[i].apply(weights_init(self.opt.init_type))
            ################### use GPU #############################################
            self.G.cuda()
            if self.E is not None:
                self.E.cuda()
            for i in range(self.opt.d_num):
                self.Ds[i].cuda()
            ################### set criterion ########################################
            self.criterionGAN = GANLoss(
                mse_loss=(self.opt.c_gan_mode == 'lsgan'))
            ################## define optimizers #####################################
            self.define_optimizers()
        else:
            # Test mode: only G (and E, if multimodal) are needed.
            self.G.load_state_dict(
                torch.load('{}/G_{}.pth'.format(self.opt.model_dir,
                                                self.opt.which_epoch)))
            self.G.cuda()
            self.G.eval()
            if self.E is not None:
                self.E.load_state_dict(
                    torch.load('{}/E_{}.pth'.format(self.opt.model_dir,
                                                    self.opt.which_epoch)))
                self.E.cuda()
                self.E.eval()

    def sample_latent_code(self, size):
        """Draw a latent code of the given size from N(0, 1) on the GPU."""
        c = torch.cuda.FloatTensor(size).normal_()
        return Variable(c)

    def get_domain_code(self, domainLable):
        """Convert a sequence of domain labels into codes and index lists.

        Returns a (len(domainLable), d_num) one-hot tensor on the GPU and,
        for each domain, a LongTensor of the batch positions that belong to
        that domain.
        """
        domainCode = torch.zeros([len(domainLable), self.opt.d_num])
        domainIndex_cache = [[] for i in range(self.opt.d_num)]
        for index in range(len(domainLable)):
            domainCode[index, domainLable[index]] = 1
            domainIndex_cache[domainLable[index]].append(index)
        domainIndex = []
        for index in domainIndex_cache:
            domainIndex.append(Variable(torch.LongTensor(index)).cuda())
        return Variable(domainCode).cuda(), domainIndex

    def define_optimizer(self, Net):
        """Create the Adam optimizer used for every sub-network."""
        return optim.Adam(Net.parameters(), lr=self.opt.lr, betas=(0.5, 0.999))

    def define_optimizers(self):
        """Create optimizers for G, E (if present) and each discriminator."""
        self.G_opt = self.define_optimizer(self.G)
        self.E_opt = None
        if self.E is not None:
            self.E_opt = self.define_optimizer(self.E)
        self.Ds_opt = []
        for i in range(self.opt.d_num):
            self.Ds_opt.append(self.define_optimizer(self.Ds[i]))

    def update_lr(self, lr):
        """Set a new learning rate on every optimizer's parameter groups."""
        for param_group in self.G_opt.param_groups:
            param_group['lr'] = lr
        if self.E_opt is not None:
            for param_group in self.E_opt.param_groups:
                param_group['lr'] = lr
        for i in range(self.opt.d_num):
            for param_group in self.Ds_opt[i].param_groups:
                param_group['lr'] = lr

    def save(self, name):
        """Save G, E (if trained) and all discriminators under `model_dir`."""
        torch.save(self.G.state_dict(),
                   '{}/G_{}.pth'.format(self.opt.model_dir, name))
        if self.E_opt is not None:
            torch.save(self.E.state_dict(),
                       '{}/E_{}.pth'.format(self.opt.model_dir, name))
        for i in range(self.opt.d_num):
            torch.save(self.Ds[i].state_dict(),
                       '{}/D_{}_{}.pth'.format(self.opt.model_dir, i, name))

    def prepare_image(self, data):
        """Concatenate per-domain batches into (images, sourceD, targetD)."""
        img, sourceD, targetD = data
        return Variable(torch.cat(img, 0)).cuda(), torch.cat(sourceD,
                                                             0), torch.cat(
                                                                 targetD, 0)

    def translation(self, data):
        """Translate a batch into every target domain for visualization.

        Returns parallel lists (one entry per source domain) of images and
        their display names.  In multimodal mode `opt.n_samples` random
        latent codes are drawn per input.
        """
        input, sourceD, targetD = self.prepare_image(data)
        sourceDC, sourceIndex = self.get_domain_code(sourceD)
        targetDC, targetIndex = self.get_domain_code(targetD)

        images, names = [], []
        for i in range(self.opt.d_num):
            # First image of each domain serves as the reference column.
            images.append(
                [tensor2im(input.index_select(0, sourceIndex[i])[0].data)])
            names.append(['D{}'.format(i)])

        if self.opt.mode == 'multimodal':
            for i in range(self.opt.n_samples):
                c_rand = self.sample_latent_code(
                    torch.Size([input.size(0), self.opt.c_num]))
                targetC = torch.cat([targetDC, c_rand], 1)
                output = self.G(input, targetC)
                for j in range(output.size(0)):
                    images[sourceD[j]].append(tensor2im(output[j].data))
                    names[sourceD[j]].append('{}to{}_{}'.format(
                        sourceD[j], targetD[j], i))
        else:
            output = self.G(input, targetDC)
            for i in range(output.size(0)):
                images[sourceD[i]].append(tensor2im(output[i].data))
                names[sourceD[i]].append('{}to{}'.format(
                    sourceD[i], targetD[i]))

        return images, names

    def get_current_errors(self):
        """Return an ordered mapping of the latest loss values for logging."""
        # Renamed from `dict`, which shadowed the builtin.
        errors = []
        for i in range(self.opt.d_num):
            errors += [('D_{}'.format(i), self.errDs[i].data.item())]
            errors += [('G_{}'.format(i), self.errGs[i].data.item())]
        errors += [('errCyc', self.errCyc.data.item())]
        if self.opt.lambda_ide > 0:
            errors += [('errIde', self.errIde.data.item())]
        if self.E is not None:
            errors += [('errKl', self.errKL.data.item())]
            errors += [('errCode', self.errCode.data.item())]
        return OrderedDict(errors)

    def get_current_visuals(self):
        """Return a single labeled image grid (real / fake / cyc [/ ide])."""
        real = make_grid(self.real.data, nrow=self.real.size(0), padding=0)
        fake = make_grid(self.fake.data, nrow=self.real.size(0), padding=0)
        cyc = make_grid(self.cyc.data, nrow=self.real.size(0), padding=0)
        img = [real, fake, cyc]
        # Fixed label typo: was 'rsal,fake,cyc'.
        name = 'real,fake,cyc'
        if self.opt.lambda_ide > 0:
            ide = make_grid(self.ide.data, nrow=self.real.size(0), padding=0)
            img.append(ide)
            name += ',ide'
        img = torch.cat(img, 1)
        return OrderedDict([(name, tensor2im(img))])

    def update_D(self, D, D_opt, real, fake):
        """One optimizer step for discriminator D; returns its loss."""
        D.zero_grad()
        pred_fake = D(fake.detach())  # detach: do not backprop into G here
        pred_real = D(real)
        errD = self.criterionGAN(pred_fake, False) + self.criterionGAN(
            pred_real, True)
        errD.backward()
        D_opt.step()
        return errD

    def calculate_G(self, D, fake):
        """Adversarial generator loss of `fake` against discriminator D."""
        pred_fake = D(fake)
        errG = self.criterionGAN(pred_fake, True)
        return errG

    def update_model(self, data):
        """Run one full optimization step (D updates, then G/E updates).

        Stores the individual loss terms on `self` so that
        `get_current_errors` / `get_current_visuals` can report them.
        """
        ### prepare data ###
        self.real, sourceD, targetD = self.prepare_image(data)
        sourceDC, self.sourceIndex = self.get_domain_code(sourceD)
        targetDC, self.targetIndex = self.get_domain_code(targetD)
        sourceC, targetC = sourceDC, targetDC
        ### generate image ###
        if self.E is not None:
            c_enc, mu, logvar = self.E(self.real, sourceDC)
            c_rand = self.sample_latent_code(c_enc.size())
            sourceC = torch.cat([sourceDC, c_enc], 1)
            targetC = torch.cat([targetDC, c_rand], 1)
        self.fake = self.G(self.real, targetC)
        self.cyc = self.G(self.fake, sourceC)
        if self.E is not None:
            _, mu_enc, _ = self.E(self.fake, targetDC)
        if self.opt.lambda_ide > 0:
            self.ide = self.G(self.real, sourceC)
        ### update D ###
        self.errDs = []
        for i in range(self.opt.d_num):
            errD = self.update_D(
                self.Ds[i], self.Ds_opt[i],
                self.real.index_select(0, self.sourceIndex[i]),
                self.fake.index_select(0, self.targetIndex[i]))
            self.errDs.append(errD)
        ### update G ###
        # Fixed attribute name: was `self.errKl`, while get_current_errors
        # reads `self.errKL`.
        self.errGs, self.errKL, self.errCode, errG_total = [], 0, 0, 0
        self.G.zero_grad()
        for i in range(self.opt.d_num):
            errG = self.calculate_G(
                self.Ds[i], self.fake.index_select(0, self.targetIndex[i]))
            errG_total += errG
            self.errGs.append(errG)
        self.errCyc = torch.mean(
            torch.abs(self.cyc - self.real)) * self.opt.lambda_cyc
        errG_total += self.errCyc
        if self.opt.lambda_ide > 0:
            self.errIde = torch.mean(
                torch.abs(self.ide - self.real)) * self.opt.lambda_ide
            errG_total += self.errIde
        if self.E is not None:
            self.E.zero_grad()
            self.errKL = KL_loss(mu, logvar) * self.opt.lambda_kl
            errG_total += self.errKL
            # retain_graph: the latent-regression loss below reuses the graph.
            errG_total.backward(retain_graph=True)
            self.G_opt.step()
            self.E_opt.step()
            self.G.zero_grad()
            self.E.zero_grad()
            self.errCode = torch.mean(
                torch.abs(mu_enc - c_rand)) * self.opt.lambda_c
            self.errCode.backward()
            self.G_opt.step()
        else:
            errG_total.backward()
            self.G_opt.step()
class SingleGAN():
    """Single-generator multi-domain translation model, extended with a
    frequency-domain (1-D power-spectrum) loss in `update_model`.

    One generator G handles all domains via a one-hot domain code (plus a
    latent code in 'multimodal' mode with encoder E); one discriminator per
    domain is built for training.  Call `initialize(opt)` first.
    """

    def name(self):
        """Return the model identifier."""
        return 'SingleGAN'

    def initialize(self, opt):
        """Select the GPU, store options and build all sub-networks."""
        torch.cuda.set_device(opt.gpu)
        cudnn.benchmark = True
        self.opt = opt
        self.build_models()

    def build_models(self):
        """Construct G (always), E (multimodal only) and, when training, the
        discriminators, criterion and optimizers; load or init weights."""
        ################### generator #########################################
        self.G = SingleGenerator(input_nc=self.opt.input_nc, output_nc=self.opt.input_nc, ngf=self.opt.ngf, nc=self.opt.c_num+self.opt.d_num, e_blocks=self.opt.e_blocks, norm_type=self.opt.norm)
        ################### encoder ###########################################
        self.E = None
        if self.opt.mode == 'multimodal':
            self.E = Encoder(input_nc=self.opt.input_nc, output_nc=self.opt.c_num, nef=self.opt.nef, nd=self.opt.d_num, n_blocks=4, norm_type=self.opt.norm)
        if self.opt.isTrain:
            ################### discriminators #####################################
            self.Ds = []
            for i in range(self.opt.d_num):
                self.Ds.append(D_NET_Multi(input_nc=self.opt.output_nc, ndf=self.opt.ndf, block_num=3, norm_type=self.opt.norm))
            ################### init_weights ########################################
            if self.opt.continue_train:
                self.G.load_state_dict(torch.load('{}/G_{}.pth'.format(self.opt.model_dir, self.opt.which_epoch)))
                if self.E is not None:
                    self.E.load_state_dict(torch.load('{}/E_{}.pth'.format(self.opt.model_dir, self.opt.which_epoch)))
                for i in range(self.opt.d_num):
                    self.Ds[i].load_state_dict(torch.load('{}/D_{}_{}.pth'.format(self.opt.model_dir, i, self.opt.which_epoch)))
            else:
                self.G.apply(weights_init(self.opt.init_type))
                if self.E is not None:
                    self.E.apply(weights_init(self.opt.init_type))
                for i in range(self.opt.d_num):
                    self.Ds[i].apply(weights_init(self.opt.init_type))
            ################### use GPU #############################################
            self.G.cuda()
            if self.E is not None:
                self.E.cuda()
            for i in range(self.opt.d_num):
                self.Ds[i].cuda()
            ################### set criterion ########################################
            self.criterionGAN = GANLoss(mse_loss=(self.opt.c_gan_mode == 'lsgan'))
            ################## define optimizers #####################################
            self.define_optimizers()
        else:
            # Test mode: only G (and E, if multimodal) are needed.
            self.G.load_state_dict(torch.load('{}/G_{}.pth'.format(self.opt.model_dir, self.opt.which_epoch)))
            self.G.cuda()
            self.G.eval()
            if self.E is not None:
                self.E.load_state_dict(torch.load('{}/E_{}.pth'.format(self.opt.model_dir, self.opt.which_epoch)))
                self.E.cuda()
                self.E.eval()

    def sample_latent_code(self, size):
        """Draw a latent code of the given size from N(0, 1) on the GPU."""
        c = torch.cuda.FloatTensor(size).normal_()
        return Variable(c)

    def get_domain_code(self, domainLable):
        """Convert domain labels into (one-hot GPU codes, per-domain batch
        index LongTensors)."""
        domainCode = torch.zeros([len(domainLable), self.opt.d_num])
        domainIndex_cache = [[] for i in range(self.opt.d_num)]
        for index in range(len(domainLable)):
            domainCode[index, domainLable[index]] = 1
            domainIndex_cache[domainLable[index]].append(index)
        domainIndex = []
        for index in domainIndex_cache:
            domainIndex.append(Variable(torch.LongTensor(index)).cuda())
        return Variable(domainCode).cuda(), domainIndex

    def define_optimizer(self, Net):
        """Create the Adam optimizer used for every sub-network."""
        return optim.Adam(Net.parameters(), lr=self.opt.lr, betas=(0.5, 0.999))

    def define_optimizers(self):
        """Create optimizers for G, E (if present) and each discriminator."""
        self.G_opt = self.define_optimizer(self.G)
        self.E_opt = None
        if self.E is not None:
            self.E_opt = self.define_optimizer(self.E)
        self.Ds_opt = []
        for i in range(self.opt.d_num):
            self.Ds_opt.append(self.define_optimizer(self.Ds[i]))

    def update_lr(self, lr):
        """Set a new learning rate on every optimizer's parameter groups."""
        for param_group in self.G_opt.param_groups:
            param_group['lr'] = lr
        if self.E_opt is not None:
            for param_group in self.E_opt.param_groups:
                param_group['lr'] = lr
        for i in range(self.opt.d_num):
            for param_group in self.Ds_opt[i].param_groups:
                param_group['lr'] = lr

    def save(self, name):
        """Save G, E (if trained) and all discriminators under `model_dir`."""
        torch.save(self.G.state_dict(), '{}/G_{}.pth'.format(self.opt.model_dir, name))
        if self.E_opt is not None:
            torch.save(self.E.state_dict(), '{}/E_{}.pth'.format(self.opt.model_dir, name))
        for i in range(self.opt.d_num):
            torch.save(self.Ds[i].state_dict(), '{}/D_{}_{}.pth'.format(self.opt.model_dir, i, name))

    def prepare_image(self, data):
        """Concatenate per-domain batches into (images, sourceD, targetD)."""
        img, sourceD, targetD = data
        return Variable(torch.cat(img, 0)).cuda(), torch.cat(sourceD, 0), torch.cat(targetD, 0)

    def translation(self, data):
        """Translate a batch into every target domain for visualization;
        returns per-source-domain lists of images and display names."""
        input, sourceD, targetD = self.prepare_image(data)
        sourceDC, sourceIndex = self.get_domain_code(sourceD)
        targetDC, targetIndex = self.get_domain_code(targetD)

        images, names = [], []
        for i in range(self.opt.d_num):
            images.append([tensor2im(input.index_select(0, sourceIndex[i])[0].data)])
            names.append(['D{}'.format(i)])

        if self.opt.mode == 'multimodal':
            for i in range(self.opt.n_samples):
                c_rand = self.sample_latent_code(torch.Size([input.size(0), self.opt.c_num]))
                targetC = torch.cat([targetDC, c_rand], 1)
                output = self.G(input, targetC)
                for j in range(output.size(0)):
                    images[sourceD[j]].append(tensor2im(output[j].data))
                    names[sourceD[j]].append('{}to{}_{}'.format(sourceD[j], targetD[j], i))
        else:
            output = self.G(input, targetDC)
            for i in range(output.size(0)):
                images[sourceD[i]].append(tensor2im(output[i].data))
                names[sourceD[i]].append('{}to{}'.format(sourceD[i], targetD[i]))

        return images, names

    def get_current_errors(self):
        """Return an ordered mapping of the latest loss values for logging."""
        # Renamed from `dict`, which shadowed the builtin.
        errors = []
        for i in range(self.opt.d_num):
            errors += [('D_{}'.format(i), self.errDs[i].data.item())]
            errors += [('G_{}'.format(i), self.errGs[i].data.item())]
        errors += [('errCyc', self.errCyc.data.item())]
        if self.opt.lambda_ide > 0:
            errors += [('errIde', self.errIde.data.item())]
        if self.E is not None:
            errors += [('errKl', self.errKL.data.item())]
            errors += [('errCode', self.errCode.data.item())]
        return OrderedDict(errors)

    def get_current_visuals(self):
        """Return a single labeled image grid (real / fake / cyc [/ ide])."""
        real = make_grid(self.real.data, nrow=self.real.size(0), padding=0)
        fake = make_grid(self.fake.data, nrow=self.real.size(0), padding=0)
        cyc = make_grid(self.cyc.data, nrow=self.real.size(0), padding=0)
        img = [real, fake, cyc]
        # Fixed label typo: was 'rsal,fake,cyc'.
        name = 'real,fake,cyc'
        if self.opt.lambda_ide > 0:
            ide = make_grid(self.ide.data, nrow=self.real.size(0), padding=0)
            img.append(ide)
            name += ',ide'
        img = torch.cat(img, 1)
        return OrderedDict([(name, tensor2im(img))])

    def update_D(self, D, D_opt, real, fake):
        """One optimizer step for discriminator D; returns its loss."""
        D.zero_grad()
        pred_fake = D(fake.detach())  # detach: do not backprop into G here
        pred_real = D(real)
        errD = self.criterionGAN(pred_fake, False) + self.criterionGAN(pred_real, True)
        errD.backward()
        D_opt.step()
        return errD

    def calculate_G(self, D, fake):
        """Adversarial generator loss of `fake` against discriminator D."""
        pred_fake = D(fake)
        errG = self.criterionGAN(pred_fake, True)
        return errG

    @staticmethod
    def RGB2gray(rgb):
        """Convert an (H, W, 3) RGB array to grayscale (ITU-R 601 weights).

        Fixed: was declared as an instance method without `self`, which made
        `self.RGB2gray(...)` calls fail; it uses no instance state, so it is
        now a staticmethod.
        """
        r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
        gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
        return gray

    def _batch_psd1d(self, batch, n_bins, epsilon=1e-8):
        """Per-image normalized 1-D power spectrum of a (B, C, H, W) batch.

        Azimuthally averages the shifted log-magnitude FFT of each image's
        grayscale version, min-max normalizes it, and returns a
        (B, n_bins) numpy array.  Extracted from `update_model`, where this
        code was duplicated for the fake and real batches.
        """
        psd = np.zeros([batch.shape[0], n_bins])
        imgs = batch.permute(0, 2, 3, 1)  # hoisted: loop-invariant permute
        for t in range(batch.shape[0]):
            img_numpy = imgs[t, :, :, :].cpu().detach().numpy()
            # ITU-R 601 luma weights (same as RGB2gray).
            img_gray = 0.2989 * img_numpy[:, :, 0] + 0.5870 * img_numpy[:, :, 1] + 0.1140 * img_numpy[:, :, 2]
            fft = np.fft.fft2(img_gray)
            fshift = np.fft.fftshift(fft)
            fshift += epsilon  # avoid log(0)
            magnitude_spectrum = 20 * np.log(np.abs(fshift))
            psd1D = radialProfile.azimuthalAverage(magnitude_spectrum)
            psd1D = (psd1D - np.min(psd1D)) / (np.max(psd1D) - np.min(psd1D))
            psd[t, :] = psd1D
        return psd

    def update_model(self, data):
        """Run one optimization step: update all discriminators, then G (and
        E in multimodal mode) with adversarial, cycle, identity, KL and
        frequency-spectrum losses.

        Returns (errD, errG_total, loss_freq).
        NOTE(review): `errD` is the loss of the *last* discriminator only —
        confirm callers expect that rather than the sum over domains.
        """
        ### prepare data ###
        self.real, sourceD, targetD = self.prepare_image(data)
        sourceDC, self.sourceIndex = self.get_domain_code(sourceD)
        targetDC, self.targetIndex = self.get_domain_code(targetD)
        sourceC, targetC = sourceDC, targetDC
        ### generate image ###
        if self.E is not None:
            c_enc, mu, logvar = self.E(self.real, sourceDC)
            c_rand = self.sample_latent_code(c_enc.size())
            sourceC = torch.cat([sourceDC, c_enc], 1)
            targetC = torch.cat([targetDC, c_rand], 1)
        self.fake = self.G(self.real, targetC)
        self.cyc = self.G(self.fake, sourceC)
        if self.E is not None:
            _, mu_enc, _ = self.E(self.fake, targetDC)
        if self.opt.lambda_ide > 0:
            self.ide = self.G(self.real, sourceC)
        ### update D ###
        self.errDs = []
        for i in range(self.opt.d_num):
            errD = self.update_D(self.Ds[i], self.Ds_opt[i], self.real.index_select(0, self.sourceIndex[i]), self.fake.index_select(0, self.targetIndex[i]))
            self.errDs.append(errD)

        ### update G ###
        # Fixed attribute name: was `self.errKl`, while get_current_errors
        # reads `self.errKL`.
        self.errGs, self.errKL, self.errCode, errG_total = [], 0, 0, 0
        self.G.zero_grad()

        # Pure adversarial generator loss, used to scale the frequency loss.
        g_loss = 0
        for i in range(self.opt.d_num):  # d_num = number of domains
            errG = self.calculate_G(self.Ds[i], self.fake.index_select(0, self.targetIndex[i]))
            errG_total += errG
            g_loss += errG
            self.errGs.append(errG)

        self.errCyc = torch.mean(torch.abs(self.cyc - self.real)) * self.opt.lambda_cyc
        errG_total += self.errCyc

        if self.opt.lambda_ide > 0:
            self.errIde = torch.mean(torch.abs(self.ide - self.real)) * self.opt.lambda_ide
            errG_total += self.errIde

        ###### frequency-domain loss ######
        N = 179  # number of radial frequency bins in the 1-D spectrum
        # Fake- and real-image 1-D power spectra (computed via numpy).
        psd1D_img = torch.from_numpy(self._batch_psd1d(self.fake, N)).float()
        psd1D_img = Variable(psd1D_img, requires_grad=True).to("cuda")
        psd1D_rec = torch.from_numpy(self._batch_psd1d(self.real, N)).float()
        psd1D_rec = Variable(psd1D_rec, requires_grad=True).to('cuda')
        # NOTE(review): the spectra pass through numpy, so despite
        # requires_grad=True no gradient can flow back into G through
        # loss_freq; it only rescales errG_total's value — confirm intended.
        criterion_freq = nn.BCELoss()
        loss_freq = criterion_freq(psd1D_rec, psd1D_img.detach())
        loss_freq *= g_loss
        lambda_freq = 0.5
        errG_total += loss_freq * lambda_freq

        if self.E is not None:
            self.E.zero_grad()
            self.errKL = KL_loss(mu, logvar) * self.opt.lambda_kl
            errG_total += self.errKL
            # retain_graph: the latent-regression loss below reuses the graph.
            errG_total.backward(retain_graph=True)
            self.G_opt.step()
            self.E_opt.step()
            self.G.zero_grad()
            self.E.zero_grad()
            self.errCode = torch.mean(torch.abs(mu_enc - c_rand)) * self.opt.lambda_c
            self.errCode.backward()
            self.G_opt.step()
        else:
            errG_total.backward()
            self.G_opt.step()

        return errD, errG_total, loss_freq
Example #3
0
def train(epoch=5, freeze=True):
    """Evaluate an Encoder classifier over the test corpus and log results.

    Builds a fresh Encoder, runs a single evaluation pass over
    ``./test_corpus.csv``, logs the loss to tensorboard, checkpoints on
    improvement and pickles the loss history.

    NOTE(review): `tb` and `save_path` are not defined in this function —
    presumably module-level globals (a SummaryWriter and an output dir);
    verify they exist before calling.

    Args:
        epoch: unused here — the outer loop runs exactly once (see NOTE below).
        freeze: freeze the ResNet feature-extractor parameters when True.
    """

    #Defining Model
    model = Encoder()
    print(model)

    if freeze:
        for param in model._resnet_extractor.parameters():
            # Fixed: was `param.require_grad` (typo) — that merely set a new
            # attribute and left the extractor trainable.
            param.requires_grad = False

    criterion = nn.CrossEntropyLoss()

    # Optimizer is created for interface parity with the training variant;
    # no backward/step happens in this evaluation-only pass.
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    transform = set_transform()

    test_loader = get_loader("./test_corpus.csv", 8, transform=transform)

    if torch.cuda.is_available():
        model = model.cuda()

    total_train_loss = []
    total_val_loss = []

    best_train = 100000000
    best_valid = 100000000
    not_improve = 0

    # NOTE(review): `epoch` is ignored; this variant runs a single pass.
    for e in range(1):

        loss_train = []
        loss_val = 0
        acc_train = 0
        acc_val = 0

        model.train()
        num_iter = 1

        model.eval()
        num_iter_val = 1
        with torch.no_grad():  # evaluation only: no autograd graph needed
            for i, (images, classes) in enumerate(test_loader):

                # Fixed: inputs must be moved to the GPU *before* the forward
                # pass when the model lives on the GPU; previously the model
                # was fed CPU tensors and only the output was moved.
                if torch.cuda.is_available():
                    images = images.cuda()
                    classes = classes.cuda()

                feature_image = model(images)
                _, preds = torch.max(feature_image.data, 1)

                loss = criterion(feature_image, classes)

                loss_val += loss.cpu().detach().numpy()
                acc_val += torch.sum(preds == classes)

                num_iter_val = i + 1
                del feature_image, classes, preds
                torch.cuda.empty_cache()

        avg_val = loss_val / num_iter_val
        print(f"\t\tValid Loss: {avg_val}")

        tb.add_scalar("Validation_Loss", avg_val, e)
        # NOTE(review): `100 - avg_val` is not an accuracy; `acc_val` is
        # accumulated but never reported — confirm the intended metric.
        tb.add_scalar("Validation_Accuracy", 100 - avg_val, e)

        if avg_val < best_valid:
            total_val_loss.append(avg_val)
            model_save = save_path + "/best_model.th"
            torch.save(model.state_dict(), model_save)
            best_valid = avg_val
            print(f"Model saved to path save/")
            not_improve = 0

        else:
            not_improve += 1
            print(f"Not Improved {not_improve} times ")
        if not_improve == 6:
            break

    save_loss = {"train": total_train_loss, "valid": total_val_loss}
    with open(save_path + "/losses.pickle", "wb") as files:
        pickle.dump(save_loss, files)

    tb.close()
def train(epoch=5, freeze=True):
    """Train an Encoder classifier with per-epoch validation and tensorboard
    logging; checkpoint on validation improvement, early-stop after 6
    non-improving epochs, and pickle the loss history.

    NOTE(review): `learning_rate`, `momentum`, `train_csv`, `valid_csv`,
    `batch_size` and `save_path` are read from module scope — verify they
    are defined before calling.

    Args:
        epoch: number of training epochs to run.
        freeze: freeze the ResNet feature-extractor parameters when True.
    """

    tb = SummaryWriter()
    #Defining Model
    model = Encoder()
    print(model)

    #model.load_state_dict(torch.load(load_path))

    if freeze:
        for param in model._resnet_extractor.parameters():
            # Fixed: was `param.require_grad` (typo) — that merely set a new
            # attribute and left the extractor trainable.
            param.requires_grad = False

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.SGD(model.parameters(),
                          lr=learning_rate,
                          momentum=momentum)

    transform = set_transform()
    train_loader = get_loader(train_csv, batch_size, transform=transform)

    valid_loader = get_loader(valid_csv, batch_size, transform=transform)

    # One sample batch for the tensorboard image grid / model graph.
    img, cls = next(iter(train_loader))

    grid = torchvision.utils.make_grid(img)
    tb.add_image('images', grid, 0)

    if torch.cuda.is_available():
        model = model.cuda()
        img = img.cuda()

    total_train_loss = []
    total_val_loss = []

    best_train = 100000000
    best_valid = 100000000
    not_improve = 0

    tb.add_graph(model, img)
    # Fixed off-by-one: `range(1, epoch)` ran only epoch-1 epochs; keep the
    # 1-based epoch numbering used in the logs.
    for e in range(1, epoch + 1):

        loss_train = []
        loss_val = 0
        acc_train = 0
        acc_val = 0

        ### training pass ###
        model.train()
        num_iter = 1
        for i, (images, classes) in enumerate(train_loader):

            optimizer.zero_grad()

            if torch.cuda.is_available():
                images = images.cuda()
                classes = classes.cuda()

            feature_image = model(images)
            _, preds = torch.max(feature_image.data, 1)

            loss = criterion(feature_image, classes)

            loss.backward()

            optimizer.step()

            loss_train.append(loss.cpu().detach().numpy())
            acc_train += torch.sum(preds == classes)

            del feature_image, classes, preds
            torch.cuda.empty_cache()

            num_iter = i + 1
            if i % 10 == 0:
                print(f"Epoch ({e}/{epoch}) Iter: {i+1} Loss: {loss}")

        avg_loss = sum(loss_train) / num_iter
        print(f"\t\tTotal iter: {num_iter} AVG loss: {avg_loss}")
        tb.add_scalar("Train_Loss", avg_loss, e)
        # NOTE(review): `100 - avg_loss` is not an accuracy; `acc_train` is
        # accumulated but never reported — confirm the intended metric.
        tb.add_scalar("Train_Accuracy", 100 - avg_loss, e)

        total_train_loss.append(avg_loss)

        ### validation pass ###
        model.eval()
        num_iter_val = 1
        with torch.no_grad():  # evaluation only: no autograd graph needed
            for i, (images, classes) in enumerate(valid_loader):

                # Fixed: inputs must be moved to the GPU *before* the forward
                # pass (the training loop above already did this correctly;
                # the validation loop fed CPU tensors to a CUDA model).
                if torch.cuda.is_available():
                    images = images.cuda()
                    classes = classes.cuda()

                feature_image = model(images)
                _, preds = torch.max(feature_image.data, 1)

                loss = criterion(feature_image, classes)

                loss_val += loss.cpu().detach().numpy()
                acc_val += torch.sum(preds == classes)

                num_iter_val = i + 1
                del feature_image, classes, preds
                torch.cuda.empty_cache()

        avg_val = loss_val / num_iter_val
        print(f"\t\tValid Loss: {avg_val}")

        tb.add_scalar("Validation_Loss", avg_val, e)
        tb.add_scalar("Validation_Accuracy", 100 - avg_val, e)

        if avg_val < best_valid:
            total_val_loss.append(avg_val)
            model_save = save_path + "/best_model.th"
            torch.save(model.state_dict(), model_save)
            best_valid = avg_val
            print(f"Model saved to path save/")
            not_improve = 0

        else:
            not_improve += 1
            print(f"Not Improved {not_improve} times ")
        if not_improve == 6:
            break

    save_loss = {"train": total_train_loss, "valid": total_val_loss}
    with open(save_path + "/losses.pickle", "wb") as files:
        pickle.dump(save_loss, files)

    tb.close()