Example #1
def build_model(self):
    # feed variables
    self.input_data = tf.placeholder(
        tf.float32,
        [None, self.un_ae_struct[0]])  # N is _num_examples or batch_size
    # build the stack of AEs
    self.pt_list = list()
    for i in range(len(self.un_ae_struct) - 1):
        print('Build AE-{}...'.format(i + 1))
        n_x = self.un_ae_struct[i]
        n_y = self.un_ae_struct[i + 1]
        if self.ae_type == 'sae' and n_x > n_y:
            ae_type = 'ae'
        else:
            ae_type = self.ae_type
        name = ae_type + '-' + str(i + 1)
        ae = AE(
            name=name,
            en_func=self.en_func,
            loss_func=self.loss_func,  # encoder: [sigmoid] || decoder: [sigmoid] with 'cross_entropy' | [relu] with 'mse'
            ae_type=ae_type,  # ae | dae | sae
            noise_type=self.noise_type,  # Gaussian noise (gs) | Masking noise (mn)
            beta=self.beta,  # penalty weight (coefficient of the second loss term)
            p=self.p,  # DAE: probability that an input dimension is corrupted / SAE sparsity target: desired mean hidden activation, averaged over a batch
            ae_struct=[n_x, n_y],
            ae_epochs=self.ae_epochs,
            batch_size=self.batch_size,
            ae_lr=self.ae_lr)
        # print(ae.__dict__)
        self.pt_list.append(ae)  # collect for layer-wise pretraining
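A minimal sketch of how the list built above could drive greedy layer-wise pretraining; `train` and `transform` are assumed AE methods (the snippet does not show the AE class), with each layer's hidden codes feeding the next layer.

def pre_train(self, x):
    for ae in self.pt_list:
        ae.train(x)           # fit this layer's autoencoder (assumed API)
        x = ae.transform(x)   # its hidden codes become the next layer's input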
Example #2
    def __init__(_, sizes):  # sizes is a list describing the deep architecture; `_` stands in for `self`
        super().__init__()
        _.subnet = torch.nn.ModuleList()  # list so the parameters of all the AEs/RBMs can be recovered
#        _.L1 = torch.nn.Linear(vsize, hsize)
#        _.L2 = torch.nn.Linear(hsize, vsize)
        for i in range(len(sizes) - 1):
            _.subnet.append(AE(sizes[i], sizes[i + 1]))
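For illustration, a forward pass chaining the stacked AEs might look like the sketch below; `encode` is an assumed per-layer method (not shown above), and `_` stands in for `self` as in the original.

    def forward(_, x):
        for ae in _.subnet:
            x = ae.encode(x)  # assumed per-layer encoding method
        return x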
Example #3
def train(config_path):
    modelArgs = NetworkConfigParser.constructModelArgs(config_path, ModelArgs)
    nn = NetworkConfigParser.constructNetwork(config_path)
    train_path, test_path, save_path = NetworkConfigParser.getDataInfo(
        config_path)
    print(nn)
    # TODO : Arguments
    num_hid = nn.layers[1].num_units
    shape = (None, nn.layers[0].num_units)
    train, test, cold = loadTrainTest(train_path, test_path, shape=shape)
    ae = AE(nn, modelArgs)
    evaluate = EvaluateNN(ae)
    theta = ae.nn.getFlattenParams()
    ae.setParameters(theta)
    iterCounter = Counter()
    optimizer = getOptimizer(modelArgs.optimizer, ae, evaluate, theta, train,
                             test, nn, modelArgs, iterCounter,
                             modelArgs.batch_size, modelArgs.max_iter[0])

    optimizer.step_grow = 5.0
    k = 0
    for info in optimizer:
        print "Iteration %d" % k
        if k == 5:
            optimizer.step_grow = 1.2
        if k % 5 == 0:
            ae.setParameters(theta)
            rmse, mae = evaluate.calculateRMSEandMAE(train, test)
            print "Fold :%d Test RMSE: %f Test MAE: %f" % (i, rmse, mae)
        if k > modelArgs.max_iter[0]:
            break
        k += 1
    if save_path:
        _theta = ae.getParameters()
        np.save(save_path, _theta)
Example #4
def main(_):
    #print(FLAGS.__flags)
    file_name =  'm[' + FLAGS.model + ']_lr[' + str(FLAGS.learning_rate) + ']_b[' + str(FLAGS.batch_size) + \
                 ']_ae' + FLAGS.ae_h_dim_list + '_z[' + str(FLAGS.z_dim) +  ']_dis' + FLAGS.dis_h_dim_list
    logger.info(file_name)

    with tf.device('/gpu:%d' % FLAGS.gpu_id):
        ### ===== Build model ===== ###
        if FLAGS.model == "AE":
            logger.info("Build AE model")
            model = AE(logger, FLAGS.learning_rate, FLAGS.input_dim, FLAGS.z_dim, eval(FLAGS.ae_h_dim_list))

        elif FLAGS.model == "VAE":
            logger.info("Build VAE model")
            raise NotImplementedError("the VAE branch is not implemented in this snippet")

        elif FLAGS.model == "VAE_GAN":
            logger.info("Build VAE_GAN model")
            raise NotImplementedError("the VAE_GAN branch is not implemented in this snippet")


        ### ===== Train/Test =====###

        if FLAGS.is_train:
            #logger.info("Start training")
            train_data = load_data(os.path.join(FLAGS.data_dir, 'train_data.npy'))
            val_data = load_data(os.path.join(FLAGS.data_dir, 'val_data.npy'))
            #print(train_data.shape)
            model.train(train_data, FLAGS.batch_size)
        else:
            logger.info("Start testing")
            test_data = load_data(os.path.join(FLAGS.data_dir, 'test_data.npy'))
Example #5
def train(configPath, name):
    useGpu = os.environ.get('GNUMPY_USE_GPU', 'auto')
    if useGpu == "no":
        mode = "cpu"
    else:
        mode = "gpu"

    print('========================================================')
    print('train %s' % name)
    print("the program is on %s" % mode)
    print('========================================================')

    config = configparser.ConfigParser(
        interpolation=configparser.ExtendedInterpolation())
    config.read(configPath)
    model_name = config.get(name, 'model')
    if model_name == "ae":
        from ae import AE
        model = AE(config, name)
    elif model_name == "lae":
        from lae import LAE
        model = LAE(config, name)
    elif model_name == "pae":
        from pae import PAE
        model = PAE(config, name)
    elif model_name == "sae":
        from sae import SAE
        model = SAE(config, name)
    elif model_name == "msae":
        from msae import MSAE
        model = MSAE(config, name)
    else:
        raise ValueError("unknown model: %s" % model_name)

    model.train()
Example #6
def build_model(self):
    # build the stack of AEs
    self.pt_list = list()
    self.parameter_list = list()
    for i in range(len(self.struct) - 1):
        print('Build AE-{}...'.format(i + 1))
        n_x = self.struct[i]
        n_y = self.struct[i + 1]
        if self.ae_type == 'sae' and n_x > n_y:
            ae_type = 'ae'
        else:
            ae_type = self.ae_type
        name = ae_type + '-' + str(i + 1)
        ae = AE(name=name,
                act_type=self.act_type,
                loss_func=self.loss_func,  # encoder: [sigmoid] || decoder: [sigmoid] with 'cross_entropy' | [relu] with 'mse'
                ae_type=ae_type,  # ae | dae | sae
                noise_type=self.noise_type,  # Gaussian noise (gs) | Masking noise (mn)
                beta=self.beta,  # penalty weight (coefficient of the second loss term)
                p=self.p,  # DAE: probability that an input dimension is corrupted / SAE sparsity target: desired mean hidden activation, averaged over a batch
                struct=[n_x, n_y],
                out_size=self.out_size,
                ae_epochs=self.ae_epochs,
                batch_size=self.batch_size,
                lr=self.ae_lr)
        # print(ae.__dict__)
        self.pt_list.append(ae)  # collect for layer-wise pretraining
        self.parameter_list.append([ae.W, ae.bh])
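A sketch of how the collected [W, bh] pairs could seed the fine-tuning network's forward pass; the sigmoid activation and the matmul orientation are assumptions, not shown in the snippet.

def transform(self, x):
    for W, bh in self.parameter_list:
        x = tf.nn.sigmoid(tf.matmul(x, W) + bh)  # activation is assumed
    return x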
Example #7
def train_cleitc(dataloader, seed, **kwargs):
    """
    :param dataloader:
    :param seed:
    :param kwargs:
    :return:
    """
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])

    # get reference encoder
    aux_ae = deepcopy(autoencoder)

    aux_ae.encoder.load_state_dict(torch.load(os.path.join('./model_save/ae5000', f'ft_encoder_{seed}.pt')))
    print('reference encoder loaded')
    reference_encoder = aux_ae.encoder

    # construct transmitter
    transmitter = MLP(input_dim=kwargs['latent_dim'],
                      output_dim=kwargs['latent_dim'],
                      hidden_dims=[kwargs['latent_dim']]).to(kwargs['device'])

    ae_eval_train_history = defaultdict(list)
    ae_eval_test_history = defaultdict(list)

    if kwargs['retrain_flag']:
        cleit_params = [
            autoencoder.parameters(),
            transmitter.parameters()
        ]
        cleit_optimizer = torch.optim.AdamW(chain(*cleit_params), lr=kwargs['lr'])
        # start autoencoder pretraining
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 1 == 0:
                print(f'----Autoencoder Training Epoch {epoch} ----')
            for step, batch in enumerate(dataloader):
                ae_eval_train_history = cleit_train_step(ae=autoencoder,
                                                         reference_encoder=reference_encoder,
                                                         transmitter=transmitter,
                                                         batch=batch,
                                                         device=kwargs['device'],
                                                         optimizer=cleit_optimizer,
                                                         history=ae_eval_train_history)
        torch.save(autoencoder.state_dict(), os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt'))
        torch.save(transmitter.state_dict(), os.path.join(kwargs['model_save_folder'], 'transmitter.pt'))
    else:
        try:
            autoencoder.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt')))
            transmitter.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'transmitter.pt')))
        except FileNotFoundError:
            raise Exception("No pre-trained encoder")

    encoder = EncoderDecoder(encoder=autoencoder.encoder,
                             decoder=transmitter).to(kwargs['device'])

    return encoder, (ae_eval_train_history, ae_eval_test_history)
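A usage sketch for train_cleitc: every keyword below is one the function actually reads from kwargs, and all the values are illustrative placeholders.

encoder, (train_hist, test_hist) = train_cleitc(
    dataloader=train_loader,        # assumed to be a torch DataLoader
    seed=0,
    input_dim=1000,
    latent_dim=128,
    encoder_hidden_dims=[512, 256],
    dop=0.1,
    device='cuda',
    retrain_flag=True,
    lr=1e-4,
    train_num_epochs=500,
    model_save_folder='./model_save')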
Example #8
def loadModel(config_path):
    modelArgs = NetworkConfigParser.constructModelArgs(config_path, ModelArgs)
    nn = NetworkConfigParser.constructNetwork(config_path)
    train_path, test_path, save_path = NetworkConfigParser.getDataInfo(
        config_path)
    ae = AE(nn, modelArgs)
    theta = np.load(save_path + ".npy")
    ae.setParameters(theta)
    return ae
Example #9
def main(epoch_num):
    # download the MNIST dataset
    mnist_train = datasets.MNIST('mnist',
                                 train=True,
                                 transform=transforms.Compose(
                                     [transforms.ToTensor()]),
                                 download=True)
    mnist_test = datasets.MNIST('mnist',
                                train=False,
                                transform=transforms.Compose(
                                    [transforms.ToTensor()]),
                                download=True)

    # wrap the MNIST datasets in DataLoaders
    # batch_size sets the size of each batch; shuffle controls whether the order is randomized (the loader shuffles first, then draws batch_size samples)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    # inspect the shape of one batch of images
    x, label = next(iter(mnist_train))  # grab the first training batch
    print(' img : ', x.shape)  # img :  torch.Size([32, 1, 28, 28]); each iteration yields 32 images of shape (1, 28, 28)

    # setup: device, model, loss, optimizer
    device = torch.device('cuda')
    model = AE().to(device)  # build the AE model and move it to the GPU
    print('The structure of our model is shown below: \n')
    print(model)
    loss_function = nn.MSELoss()  # reconstruction loss
    optimizer = optim.Adam(model.parameters(),
                           lr=1e-3)  # optimize the model's parameters with learning rate 0.001

    # training loop
    loss_epoch = []
    for epoch in range(epoch_num):
        # each epoch iterates over every batch
        for batch_index, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)
            # forward pass
            x_hat = model(x)  # calling the model invokes its forward method
            loss = loss_function(x_hat, x)  # compute the objective (reconstruction loss)
            # backward pass
            optimizer.zero_grad()  # clear old gradients, which would otherwise accumulate
            loss.backward()  # backpropagate; gradients are stored on model.parameters
            optimizer.step()  # apply the update computed from those gradients

        loss_epoch.append(loss.item())
        if epoch % max(1, epoch_num // 10) == 0:
            print('Epoch [{}/{}] : '.format(epoch, epoch_num), 'loss = ',
                  loss.item())  # .item() converts the loss Tensor to a Python float
            # x, _ = next(iter(mnist_test))   # take a batch from the test set
            # with torch.no_grad():
            #     x_hat = model(x)

    return loss_epoch
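A usage sketch: run the loop above and plot the recorded per-epoch loss. matplotlib is an added dependency here, not part of the original snippet.

import matplotlib.pyplot as plt

losses = main(epoch_num=10)
plt.plot(losses)
plt.xlabel('epoch')
plt.ylabel('MSE loss')
plt.show()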
Example #10
def main():
    mnist_train = datasets.MNIST('mnist',
                                 True,
                                 transform=transforms.Compose(
                                     [transforms.ToTensor()]),
                                 download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST('mnist',
                                False,
                                transform=transforms.Compose(
                                    [transforms.ToTensor()]),
                                download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    x, _ = next(iter(mnist_train))
    print('x:', x.shape)

    # device = torch.device('cuda')
    # model = AE().to(device)
    model = AE()
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(1000):

        for batchidx, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            # x = x.to(device)

            x_hat, kld = model(x)
            loss = criteon(x_hat, x)

            if kld is not None:
                elbo = -loss - 1.0 * kld
                loss = -elbo

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # print(epoch, 'loss:', loss.item(), 'kld:', kld.item())
        print(epoch, 'loss', loss.item())

        x, _ = next(iter(mnist_test))
        # x = x.to(device)
        with torch.no_grad():
            x_hat, kld = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))
Example #11
def train():
    mnist_train = datasets.MNIST('../data/mnist',
                                 train=True,
                                 transform=transforms.Compose(
                                     [transforms.ToTensor()]),
                                 download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
    mnist_test = datasets.MNIST('../data/mnist',
                                train=False,
                                transform=transforms.Compose(
                                    [transforms.ToTensor()]),
                                download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    # labels are not needed: this is unsupervised learning
    x, _ = next(iter(mnist_train))
    print('x:', x.shape)

    device = torch.device('cuda')

    model = AE().to(device)
    criteon = nn.MSELoss()  # loss function
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    vis = visdom.Visdom()

    for epoch in range(1000):

        # training pass
        for batchIdx, (x, _) in enumerate(mnist_train):
            # forward: [b, 1, 28, 28]
            x = x.to(device)
            x_hat = model(x)
            loss = criteon(x_hat, x)

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # print the loss
        print('epoch:', epoch, '  loss:', loss.item())

        # test pass
        x, _ = next(iter(mnist_test))
        x = x.to(device)
        with torch.no_grad():  # no gradients needed at test time
            x_hat = model(x)

        vis.images(x, nrow=8, win='x', opts=dict(title='x'))  # show the inputs
        vis.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))  # show the reconstructions
Example #12
def main():

    mnist_train = DataLoader(datasets.MNIST('../Lesson5/mnist_data',
                                            True,
                                            transform=transforms.Compose(
                                                [transforms.ToTensor()]),
                                            download=True),
                             batch_size=32,
                             shuffle=True)

    mnist_test = DataLoader(datasets.MNIST('../Lesson5/mnist_data',
                                           False,
                                           transforms.Compose(
                                               [transforms.ToTensor()]),
                                           download=True),
                            batch_size=32,
                            shuffle=True)

    x, _ = next(iter(mnist_train))
    print(f'x:{x.shape}')

    device = torch.device('cuda')
    model = AE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(1000):

        for batchidx, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)

            x_hat, _ = model(x)
            loss = criteon(x_hat, x)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item())

        x, _ = next(iter(mnist_test))
        x = x.to(device)
        with torch.no_grad():
            x_hat, kld = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))
Example #13
    def loadModel(self, config, name):
        """
        name: path to model file or section name for the model
        """
        if os.path.exists(name):
            from ae import AE
            model = AE.load(name)
        else:
            modelname = self.readField(config, name, "model")
            if modelname == "lae":
                from lae import LAE
                model = LAE(config, name)
            elif modelname == "pae":
                from pae import PAE
                model = PAE(config, name)
            elif modelname == 'ae':
                from ae import AE
                model = AE(config, name)
            else:
                raise ValueError("unknown model: %s" % modelname)
        return model
Example #14
def train_ae(dataloader, **kwargs):
    """
    :param dataloader:
    :param kwargs:
    :return:
    """
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])

    ae_eval_train_history = defaultdict(list)
    ae_eval_test_history = defaultdict(list)

    if kwargs['retrain_flag']:
        ae_optimizer = torch.optim.AdamW(autoencoder.parameters(),
                                         lr=kwargs['lr'])
        # start autoencoder pretraining
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'----Autoencoder Training Epoch {epoch} ----')
            for step, batch in enumerate(dataloader):
                ae_eval_train_history = ae_train_step(
                    ae=autoencoder,
                    batch=batch,
                    device=kwargs['device'],
                    optimizer=ae_optimizer,
                    history=ae_eval_train_history)
        torch.save(autoencoder.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'ae.pt'))
    else:
        try:
            autoencoder.load_state_dict(
                torch.load(os.path.join(kwargs['model_save_folder'], 'ae.pt')))
        except FileNotFoundError:
            raise Exception("No pre-trained encoder")

    return autoencoder.encoder, (ae_eval_train_history, ae_eval_test_history)
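The ae_train_step called above is not shown in this snippet; as a hedged sketch, a single reconstruction step might look like this (assuming the batch's first element is the input tensor and that ae(x) returns the reconstruction):

import torch.nn.functional as F

def ae_train_step(ae, batch, device, optimizer, history):
    ae.train()
    x = batch[0].to(device)
    loss = F.mse_loss(ae(x), x)   # reconstruction objective (assumed)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    history['loss'].append(loss.item())
    return history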
Example #15
K = hyperparams["K"]

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

test_transform = transforms.Compose([
    transforms.ToTensor(),
])

testset = torchvision.datasets.ImageFolder(val_data_folder,
                                           transform=test_transform)
test_loader = torch.utils.data.DataLoader(testset,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          num_workers=20)

model = AE(K=K).to(device)
model = nn.DataParallel(model, device_ids=[0])
model.load_state_dict(
    torch.load(saved_model_name, map_location={'cuda:1': 'cuda:0'}))

for sub in ("out", "lab", "hash"):
    subdir = os.path.join(save_folder_name, sub)
    if not os.path.exists(subdir):
        os.makedirs(subdir)

model.eval()
with tqdm(total=len(test_loader), desc="Batches") as pbar:
    for i, data in enumerate(test_loader):
        img, labels = data
        with torch.no_grad():
            encoded, out, hashed = model(img)
        torch.save(out, save_folder_name + "/out/out_{}.pt".format(i))
        torch.save(labels, save_folder_name + "/lab/lab_{}.pt".format(i))
        torch.save(hashed, save_folder_name + "/hash/hash_{}.pt".format(i))
        pbar.update(1)
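To consume the per-batch files written above, the saved tensors can be reloaded and concatenated; this is a sketch using only the file layout defined in the snippet.

import glob
import torch

paths = sorted(glob.glob(save_folder_name + "/hash/hash_*.pt"),
               key=lambda p: int(p.split('_')[-1].split('.')[0]))  # numeric order
all_hashes = torch.cat([torch.load(p) for p in paths], dim=0)
print(all_hashes.shape)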
Example #16
from ae import AE

ae = AE(1)
if ae is not None:  # <1>
    print('is fine')
# => is fine
Example #17
def train_cleitcs(s_dataloaders, t_dataloaders, val_dataloader, test_dataloader, metric_name, seed, **kwargs):
    """
    :param s_dataloaders:
    :param t_dataloaders:
    :param val_dataloader:
    :param test_dataloader:
    :param metric_name:
    :param seed:
    :param kwargs:
    :return:
    """
    s_train_dataloader = s_dataloaders
    t_train_dataloader = t_dataloaders

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])
    # get reference encoder
    aux_ae = deepcopy(autoencoder)

    aux_ae.encoder.load_state_dict(torch.load(os.path.join('./model_save/ae', f'ft_encoder_{seed}.pt')))
    print('reference encoder loaded')
    reference_encoder = aux_ae.encoder

    # construct transmitter
    transmitter = MLP(input_dim=kwargs['latent_dim'],
                      output_dim=kwargs['latent_dim'],
                      hidden_dims=[kwargs['latent_dim']]).to(kwargs['device'])

    encoder = autoencoder.encoder
    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    target_regressor = EncoderDecoder(encoder=encoder,
                                      decoder=target_decoder).to(kwargs['device'])

    train_history = defaultdict(list)
    # ae_eval_train_history = defaultdict(list)
    val_history = defaultdict(list)
    s_target_regression_eval_train_history = defaultdict(list)
    t_target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)
    cleit_params = [
        target_regressor.parameters(),
        transmitter.parameters()
    ]
    model_optimizer = torch.optim.AdamW(chain(*cleit_params), lr=kwargs['lr'])
    for epoch in range(int(kwargs['train_num_epochs'])):
        if epoch % 50 == 0:
            print(f'Coral training epoch {epoch}')
        for step, s_batch in enumerate(s_train_dataloader):
            t_batch = next(iter(t_train_dataloader))
            train_history = cleit_train_step(model=target_regressor,
                                             transmitter=transmitter,
                                             reference_encoder=reference_encoder,
                                             s_batch=s_batch,
                                             t_batch=t_batch,
                                             device=kwargs['device'],
                                             optimizer=model_optimizer,
                                             alpha=kwargs['alpha'],
                                             history=train_history)
        s_target_regression_eval_train_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                                  dataloader=s_train_dataloader,
                                                                                  device=kwargs['device'],
                                                                                  history=s_target_regression_eval_train_history)

        t_target_regression_eval_train_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                                  dataloader=t_train_dataloader,
                                                                                  device=kwargs['device'],
                                                                                  history=t_target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                              dataloader=val_dataloader,
                                                                              device=kwargs['device'],
                                                                              history=target_regression_eval_val_history)
        target_regression_eval_test_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                               dataloader=test_dataloader,
                                                                               device=kwargs['device'],
                                                                               history=target_regression_eval_test_history)

        save_flag, stop_flag = model_save_check(history=target_regression_eval_val_history,
                                                metric_name=metric_name,
                                                tolerance_count=50)
        if save_flag:
            torch.save(target_regressor.state_dict(), os.path.join(kwargs['model_save_folder'], f'cleitcs_regressor_{seed}.pt'))
        if stop_flag:
            break
    target_regressor.load_state_dict(
        torch.load(os.path.join(kwargs['model_save_folder'], f'cleitcs_regressor_{seed}.pt')))

    # evaluate_target_regression_epoch(regressor=target_regressor,
    #                                  dataloader=val_dataloader,
    #                                  device=kwargs['device'],
    #                                  history=None,
    #                                  seed=seed,
    #                                  output_folder=kwargs['model_save_folder'])
    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=test_dataloader,
                                     device=kwargs['device'],
                                     history=None,
                                     seed=seed,
                                     output_folder=kwargs['model_save_folder'])

    return target_regressor, (
        train_history, s_target_regression_eval_train_history, t_target_regression_eval_train_history,
        target_regression_eval_val_history, target_regression_eval_test_history)
Example #18
import torch
from torch import nn, optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from ae import AE
from visdom import Visdom

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
batchsz = 128
epochs = 50
lr = 1e-3

train_dataset = datasets.MNIST('../data', transform=transforms.ToTensor())
test_dataset = datasets.MNIST('../data', False, transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=batchsz, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batchsz, shuffle=True)

net = AE()
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, 5)
net.to(device)
criterion.to(device)
train_loss = []
viz = Visdom()
for epoch in range(epochs):
    train_loss.clear()
    net.train()
    for step, (x, _) in enumerate(train_loader):
        x = x.to(device)
        x_hat = net(x)

        loss = criterion(x_hat, x)
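        # The snippet is truncated here; the usual remainder, mirroring the
        # other examples in this collection (a sketch, not the original
        # file's code), would be:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
    scheduler.step()  # otherwise the StepLR created above is never used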
Example #19
def train_adda(s_dataloaders, t_dataloaders, val_dataloader, test_dataloader,
               metric_name, seed, **kwargs):
    """
    :param s_dataloaders:
    :param t_dataloaders:
    :param val_dataloader:
    :param test_dataloader:
    :param metric_name:
    :param seed:
    :param kwargs:
    :return:
    """
    s_train_dataloader = s_dataloaders
    t_train_dataloader = t_dataloaders

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])
    encoder = autoencoder.encoder

    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    target_regressor = EncoderDecoder(
        encoder=encoder, decoder=target_decoder).to(kwargs['device'])

    confounding_classifier = MLP(input_dim=kwargs['latent_dim'],
                                 output_dim=1,
                                 hidden_dims=kwargs['classifier_hidden_dims'],
                                 dop=kwargs['dop']).to(kwargs['device'])

    critic_train_history = defaultdict(list)
    gen_train_history = defaultdict(list)
    s_target_regression_eval_train_history = defaultdict(list)
    t_target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    model_optimizer = torch.optim.AdamW(target_regressor.parameters(),
                                        lr=kwargs['lr'])
    classifier_optimizer = torch.optim.RMSprop(
        confounding_classifier.parameters(), lr=kwargs['lr'])
    for epoch in range(int(kwargs['train_num_epochs'])):
        if epoch % 50 == 0:
            print(f'ADDA training epoch {epoch}')
        for step, s_batch in enumerate(s_train_dataloader):
            t_batch = next(iter(t_train_dataloader))
            critic_train_history = critic_train_step(
                critic=confounding_classifier,
                model=target_regressor,
                s_batch=s_batch,
                t_batch=t_batch,
                device=kwargs['device'],
                optimizer=classifier_optimizer,
                history=critic_train_history,
                # clip=0.1,
                gp=10.0)
            if (step + 1) % 5 == 0:
                gen_train_history = gan_gen_train_step(
                    critic=confounding_classifier,
                    model=target_regressor,
                    s_batch=s_batch,
                    t_batch=t_batch,
                    device=kwargs['device'],
                    optimizer=model_optimizer,
                    alpha=1.0,
                    history=gen_train_history)
        s_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=s_train_dataloader,
            device=kwargs['device'],
            history=s_target_regression_eval_train_history)

        t_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=t_train_dataloader,
            device=kwargs['device'],
            history=t_target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=val_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_val_history)
        target_regression_eval_test_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=test_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_test_history)

        save_flag, stop_flag = model_save_check(
            history=target_regression_eval_val_history,
            metric_name=metric_name,
            tolerance_count=50)
        if save_flag:
            torch.save(
                target_regressor.state_dict(),
                os.path.join(kwargs['model_save_folder'],
                             f'adda_regressor_{seed}.pt'))
        if stop_flag:
            break

    target_regressor.load_state_dict(
        torch.load(
            os.path.join(kwargs['model_save_folder'],
                         f'adda_regressor_{seed}.pt')))

    # evaluate_target_regression_epoch(regressor=target_regressor,
    #                                  dataloader=val_dataloader,
    #                                  device=kwargs['device'],
    #                                  history=None,
    #                                  seed=seed,
    #                                  output_folder=kwargs['model_save_folder'])
    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=test_dataloader,
                                     device=kwargs['device'],
                                     history=None,
                                     seed=seed,
                                     output_folder=kwargs['model_save_folder'])

    return target_regressor, (critic_train_history, gen_train_history,
                              s_target_regression_eval_train_history,
                              t_target_regression_eval_train_history,
                              target_regression_eval_val_history,
                              target_regression_eval_test_history)
Example #20
def build_ae(encoder, decoder, args):
    NORM_REGULARIZE = args.getfloat('Network', 'NORM_REGULARIZE')
    VARIATIONAL = args.getfloat('Network', 'VARIATIONAL')
    ae = AE(encoder, decoder, NORM_REGULARIZE, VARIATIONAL)
    return ae
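A sketch of the configuration the function above expects: a [Network] section whose NORM_REGULARIZE and VARIATIONAL keys are read via getfloat. The values are illustrative.

import configparser

args = configparser.ConfigParser()
args.read_string("""
[Network]
NORM_REGULARIZE = 0.0
VARIATIONAL = 0
""")
ae = build_ae(encoder, decoder, args)  # encoder/decoder built elsewhere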
Example #21
valid_ids = train_ids[val_data_idxs]
train_ids = train_ids[train_data_idxs]

train_gen = DataGen(train_ids, train_path, image_size=image_size, channels=channels, batch_size=batch_size)
valid_gen = DataGen(valid_ids, train_path, database_path=database_path, image_size=image_size, channels=channels, batch_size=batch_size)

train_steps = len(train_ids)//batch_size
valid_steps = len(valid_ids)//batch_size

if not os.path.exists(model_path):
    os.makedirs(model_path)

# model, model_middle = UNet(256, 3)

ae = AE(batch_size, units=200)
model_ae, model_encoder, latents, ios = ae.make_ae()

if os.path.exists(os.path.join(model_path, model_name)):
    try:
        model_ae.load_weights(os.path.join(model_path, model_name))
        print('Weights loaded from:')
        print(os.path.join(model_path, model_name))
    except ValueError as e:
        print('{0}'.format(e))
        print('Not loading old weights.')

# model_vae.add_loss(vae_loss)
optimizer = tf.keras.optimizers.Adam(lr=learning_rate) #, clipvalue=1000000.
model_ae.compile(optimizer=optimizer, metrics=["mae"], loss=ae.ae_loss())
model_ae.summary()
Example #22
def main():

    model = AE()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.MSELoss()
    #criteon = nn.CrossEntropyLoss()
    print(model)

    #for epoch in range(epochs):
    for epoch in range(100):
        for step, (x, y) in enumerate(all_loader):
            # x: [b,1,100,100]
            x_hat = model(x, False)
            #print('x shape:',x.shape,'x_hat shape:',x_hat.shape)
            loss = criteon(x_hat, x)
            #backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(epoch, 'loss', loss.item())

        # use the encoder part
        # encoder = []
        # label = []
        encoder = torch.randn(301, 100)
        label = torch.randn(301, 2)
        #tmp = torch.randn(2, 1, 100, 100)
        for x, y in encoder_loader:
            #x,y = iter(all_loader).next()
            with torch.no_grad():
                x_encoder = model(x, True)
                # label.append(y)
                # encoder.append(x_encoder)
                label = y
                encoder = x_encoder
        encoder = encoder.numpy()
        label = label.numpy()
        #print(encoder.shape,label.shape)

        # spectral clustering
        if epoch == 0:
            pred = torch.zeros(301, 101)
            pred = pred.numpy()
        pred.T[0][:] = label.T[0]
        print(pred)

        from sklearn.cluster import SpectralClustering
        from sklearn.metrics import adjusted_rand_score
        from sklearn.metrics import normalized_mutual_info_score
        from sklearn.metrics.pairwise import cosine_similarity
        simatrix = 0.5 * cosine_similarity(encoder) + 0.5
        SC = SpectralClustering(affinity='precomputed',
                                assign_labels='discretize',
                                random_state=100)
        label1 = SC.fit_predict(simatrix)

        pred.T[epoch + 1][:] = label1[:]

        print('pred:', pred.shape)
        print(pred)
        if epoch == 99:
            pd.DataFrame(pred).to_csv("pred.csv", index=False, sep=',')
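The example imports adjusted_rand_score and normalized_mutual_info_score but never calls them; inside the epoch loop, one epoch's clustering could be scored against the ground-truth column already used above:

ari = adjusted_rand_score(label.T[0], label1)
nmi = normalized_mutual_info_score(label.T[0], label1)
print('ARI: %.4f  NMI: %.4f' % (ari, nmi))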
Example #23
    ########## Build model ##########
    ### ckpt path should not contain '[' or ']'
    ae_h_dim_list_replaced = FLAGS.ae_h_dim_list.replace('[', '').replace(
        ']', '').replace(',', '-')
    dis_h_dim_list_replaced = FLAGS.dis_h_dim_list.replace('[', '').replace(
        ']', '').replace(',', '-')
    model_spec = 'm' + FLAGS.model + '_lr' + str(
        FLAGS.learning_rate
    ) + '_e' + str(FLAGS.epoch) + '_keep' + str(FLAGS.keep_prob) + '_b' + str(
        FLAGS.batch_size) + '_ae' + ae_h_dim_list_replaced + '_z' + str(
            FLAGS.z_dim)

    ##### AE #####
    if FLAGS.model == 'AE':
        model = AE(FLAGS.gpu_id, FLAGS.learning_rate, FLAGS.loss_type,
                   input_dim, FLAGS.z_dim, eval(FLAGS.ae_h_dim_list))
    elif FLAGS.model == 'VAE':
        model = VAE(FLAGS.gpu_id, FLAGS.learning_rate, FLAGS.loss_type,
                    input_dim, FLAGS.z_dim, eval(FLAGS.ae_h_dim_list))

    ##### GAN #####
    elif FLAGS.model == 'VANILLA_GAN':
        model = VANILLA_GAN(FLAGS.gpu_id, FLAGS.learning_rate, FLAGS.loss_type,
                            input_dim, FLAGS.z_dim, eval(FLAGS.ae_h_dim_list),
                            eval(FLAGS.dis_h_dim_list))
        model_spec += '_dis' + dis_h_dim_list_replaced
    elif FLAGS.model == 'INFO_GAN':
        model = INFO_GAN(FLAGS.gpu_id, FLAGS.learning_rate, FLAGS.loss_type,
                         input_dim, FLAGS.z_dim, eval(FLAGS.ae_h_dim_list),
                         eval(FLAGS.dis_h_dim_list))
        model_spec += '_dis' + dis_h_dim_list_replaced
Example #24
        x = clusters.x.values

        check_origin_max_date_match(args.origin, clusters)

        for nc in nc_range:
            # Pick up where you left off. EmbedClust is very computationally expensive, so you don't want to
            # re-run unnecessarily. This also allows parallelization (other nodes will pick up unfinished origins).
            with engine.connect() as conn:
                sql = f"select 1 from embed_clust where origin='{args.origin}' and n_clusters={nc} limit 1"
                if args.pickup and conn.execute(sql).fetchone():
                    print(f'{bcolors.WARNING}skip origin={args.origin} nc={nc}{bcolors.ENDC}')
                    continue
            print(f"{bcolors.OKBLUE}origin={args.origin} nc={nc}{bcolors.ENDC}")

            K.clear_session()  # hyperopt creates many graphs, will max memory fast if not cleared
            ae = AE()
            hypers = AE.best_hypers(args.origin)
            if hypers is None:
                print("No embed_clust.use for this forecast origin; go into database and check-box some `use` column")
                break
            ae.compile(hypers)
            embed_clust = EmbedClust(ae, args, nc)

            print('...Pretraining...')
            embed_clust.ae.train(x)

            embed_clust.model.summary()
            embed_clust.compile(loss=['kld', 'mse'], loss_weights=[0.1, 1], optimizer='adam')
            y_pred = embed_clust.fit(x, tol=args.tol)

            # Save for use by RNN. See https://www.safaribooksonline.com/library/view/python-cookbook/0596001673/ch08s08.html
Example #25
def train_dann(s_dataloaders, t_dataloaders, val_dataloader, test_dataloader,
               metric_name, seed, **kwargs):
    """
    :param s_dataloaders:
    :param t_dataloaders:
    :param val_dataloader:
    :param test_dataloader:
    :param metric_name:
    :param seed:
    :param kwargs:
    :return:
    """
    s_train_dataloader = s_dataloaders
    t_train_dataloader = t_dataloaders

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])
    encoder = autoencoder.encoder

    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    target_regressor = EncoderDecoder(
        encoder=encoder, decoder=target_decoder).to(kwargs['device'])

    classifier = MLP(input_dim=kwargs['latent_dim'],
                     output_dim=1,
                     hidden_dims=kwargs['classifier_hidden_dims'],
                     dop=kwargs['dop'],
                     out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    confounder_classifier = EncoderDecoder(encoder=autoencoder.encoder,
                                           decoder=classifier).to(
                                               kwargs['device'])

    train_history = defaultdict(list)
    s_target_regression_eval_train_history = defaultdict(list)
    t_target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    confounded_loss = nn.BCEWithLogitsLoss()
    dann_params = [
        target_regressor.parameters(),
        confounder_classifier.decoder.parameters()
    ]
    dann_optimizer = torch.optim.AdamW(chain(*dann_params), lr=kwargs['lr'])

    # start alternative training
    for epoch in range(int(kwargs['train_num_epochs'])):
        if epoch % 50 == 0:
            print(f'DANN training epoch {epoch}')
        # start autoencoder training epoch
        for step, s_batch in enumerate(s_train_dataloader):
            t_batch = next(iter(t_train_dataloader))
            train_history = dann_train_step(classifier=confounder_classifier,
                                            model=target_regressor,
                                            s_batch=s_batch,
                                            t_batch=t_batch,
                                            loss_fn=confounded_loss,
                                            alpha=kwargs['alpha'],
                                            device=kwargs['device'],
                                            optimizer=dann_optimizer,
                                            history=train_history,
                                            scheduler=None)

        s_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=s_train_dataloader,
            device=kwargs['device'],
            history=s_target_regression_eval_train_history)

        t_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=t_train_dataloader,
            device=kwargs['device'],
            history=t_target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=val_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_val_history)
        target_regression_eval_test_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=test_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_test_history)

        save_flag, stop_flag = model_save_check(
            history=target_regression_eval_val_history,
            metric_name=metric_name,
            tolerance_count=50)
        if save_flag:
            torch.save(
                target_regressor.state_dict(),
                os.path.join(kwargs['model_save_folder'],
                             f'dann_regressor_{seed}.pt'))
        if stop_flag:
            break
    target_regressor.load_state_dict(
        torch.load(
            os.path.join(kwargs['model_save_folder'],
                         f'dann_regressor_{seed}.pt')))

    # evaluate_target_regression_epoch(regressor=target_regressor,
    #                                  dataloader=val_dataloader,
    #                                  device=kwargs['device'],
    #                                  history=None,
    #                                  seed=seed,
    #                                  output_folder=kwargs['model_save_folder'])
    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=test_dataloader,
                                     device=kwargs['device'],
                                     history=None,
                                     seed=seed,
                                     output_folder=kwargs['model_save_folder'])

    return target_regressor, (train_history,
                              s_target_regression_eval_train_history,
                              t_target_regression_eval_train_history,
                              target_regression_eval_val_history,
                              target_regression_eval_test_history)
Example #26
def train_cleita(dataloader, seed, **kwargs):
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])

    # get reference encoder
    aux_ae = deepcopy(autoencoder)

    aux_ae.encoder.load_state_dict(
        torch.load(os.path.join('./model_save', f'ft_encoder_{seed}.pt')))
    print('reference encoder loaded')
    reference_encoder = aux_ae.encoder

    # construct transmitter
    transmitter = MLP(input_dim=kwargs['latent_dim'],
                      output_dim=kwargs['latent_dim'],
                      hidden_dims=[kwargs['latent_dim']]).to(kwargs['device'])

    confounding_classifier = MLP(input_dim=kwargs['latent_dim'],
                                 output_dim=1,
                                 hidden_dims=kwargs['classifier_hidden_dims'],
                                 dop=kwargs['dop']).to(kwargs['device'])

    ae_train_history = defaultdict(list)
    ae_val_history = defaultdict(list)
    critic_train_history = defaultdict(list)
    gen_train_history = defaultdict(list)

    if kwargs['retrain_flag']:
        cleit_params = [autoencoder.parameters(), transmitter.parameters()]
        cleit_optimizer = torch.optim.AdamW(chain(*cleit_params),
                                            lr=kwargs['lr'])
        classifier_optimizer = torch.optim.RMSprop(
            confounding_classifier.parameters(), lr=kwargs['lr'])
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'confounder wgan training epoch {epoch}')
            for step, batch in enumerate(dataloader):
                critic_train_history = critic_train_step(
                    critic=confounding_classifier,
                    ae=autoencoder,
                    reference_encoder=reference_encoder,
                    transmitter=transmitter,
                    batch=batch,
                    device=kwargs['device'],
                    optimizer=classifier_optimizer,
                    history=critic_train_history,
                    # clip=0.1,
                    gp=10.0)
                if (step + 1) % 5 == 0:
                    gen_train_history = gan_gen_train_step(
                        critic=confounding_classifier,
                        ae=autoencoder,
                        transmitter=transmitter,
                        batch=batch,
                        device=kwargs['device'],
                        optimizer=cleit_optimizer,
                        alpha=1.0,
                        history=gen_train_history)

        torch.save(autoencoder.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt'))
        torch.save(transmitter.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'transmitter.pt'))
    else:
        try:
            autoencoder.load_state_dict(
                torch.load(
                    os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt')))
            transmitter.load_state_dict(
                torch.load(
                    os.path.join(kwargs['model_save_folder'],
                                 'transmitter.pt')))
        except FileNotFoundError:
            raise Exception("No pre-trained encoder")

    encoder = EncoderDecoder(encoder=autoencoder.encoder,
                             decoder=transmitter).to(kwargs['device'])

    return encoder, (ae_train_history, ae_val_history, critic_train_history,
                     gen_train_history)
Example #27
# ---------------- Write here the directory and file names you want to use for your model ------------- #

directory_name = '../resources/'
weights_filename = "ae_init.ckpt"
graph_filename = 'ae_graph.pb'
graph_text_filename = 'ae_graph_text.pb'  # This file is human-readable and contains all op names.

# ----------------------------------------------------------------------------------------------------- #

IMAGE_WIDTH = 32
IMAGE_HEIGHT = 32
LATENT_DIM = 2

# ------------------------------- Instantiate the model and save the graph ---------------------------- #
tf.compat.v1.disable_eager_execution()
model = AE(IMAGE_WIDTH, IMAGE_HEIGHT, LATENT_DIM)
model.build()
model.summary()

# Open session and write model config and weights
gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)

with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(
        gpu_options=gpu_options, log_device_placement=False)) as sess:

    init_all_vars_op = tf.compat.v1.variables_initializer(
        tf.compat.v1.global_variables(), name='init_all_vars_op')
    sess.run(init_all_vars_op)

    saver = tf.compat.v1.train.Saver(model.trainable_weights)
    saver_def = saver.as_saver_def()
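    # A sketch of how the file names defined at the top could be used from
    # here (tf.io.write_graph is standard TensorFlow; the original script's
    # exact export steps are not shown):
    tf.io.write_graph(sess.graph_def, directory_name, graph_filename,
                      as_text=False)
    tf.io.write_graph(sess.graph_def, directory_name, graph_text_filename,
                      as_text=True)
    saver.save(sess, directory_name + weights_filename)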
Example #28
def fine_tune_encoder(train_dataloader, val_dataloader, seed, test_dataloader=None,
                      metric_name='cpearsonr',
                      normalize_flag=False, **kwargs):
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])
    encoder = autoencoder.encoder

    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    target_regressor = EncoderDecoder(encoder=encoder,
                                      decoder=target_decoder,
                                      normalize_flag=normalize_flag).to(kwargs['device'])

    target_regression_train_history = defaultdict(list)
    target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    target_regression_optimizer = torch.optim.AdamW(target_regressor.parameters(), lr=kwargs['lr'])

    for epoch in range(kwargs['train_num_epochs']):
        if epoch % 10 == 0:
            print(f'MLP fine-tuning epoch {epoch}')
        for step, batch in enumerate(train_dataloader):
            target_regression_train_history = regression_train_step(model=target_regressor,
                                                                    batch=batch,
                                                                    device=kwargs['device'],
                                                                    optimizer=target_regression_optimizer,
                                                                    history=target_regression_train_history)
        target_regression_eval_train_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                                dataloader=train_dataloader,
                                                                                device=kwargs['device'],
                                                                                history=target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                              dataloader=val_dataloader,
                                                                              device=kwargs['device'],
                                                                              history=target_regression_eval_val_history)

        if test_dataloader is not None:
            target_regression_eval_test_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                                   dataloader=test_dataloader,
                                                                                   device=kwargs['device'],
                                                                                   history=target_regression_eval_test_history)
        save_flag, stop_flag = model_save_check(history=target_regression_eval_val_history,
                                                metric_name=metric_name,
                                                tolerance_count=50)
        if save_flag or epoch == 0:
            torch.save(target_regressor.state_dict(),
                       os.path.join(kwargs['model_save_folder'], f'target_regressor_{seed}.pt'))
            torch.save(target_regressor.encoder.state_dict(),
                       os.path.join(kwargs['model_save_folder'], f'ft_encoder_{seed}.pt'))
        if stop_flag:
            break

    target_regressor.load_state_dict(
        torch.load(os.path.join(kwargs['model_save_folder'], f'target_regressor_{seed}.pt')))

    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=val_dataloader,
                                     device=kwargs['device'],
                                     history=None,
                                     seed=seed,
                                     cv_flag=True,
                                     output_folder=kwargs['model_save_folder'])
    if test_dataloader is not None:
        evaluate_target_regression_epoch(regressor=target_regressor,
                                         dataloader=test_dataloader,
                                         device=kwargs['device'],
                                         history=None,
                                         seed=seed,
                                         output_folder=kwargs['model_save_folder'])


    return target_regressor, (target_regression_train_history, target_regression_eval_train_history,
                              target_regression_eval_val_history, target_regression_eval_test_history)