Example #1
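# Train an autoencoder and a normalizing flow over its latent space, then
# optionally sample images by drawing latents from the flow and decoding them.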
def main():
    args = parse.parse()

    # set random seeds
    random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    np.random.seed(args.manual_seed)

    # prepare output directories
    base_dir = Path(args.out_dir)
    model_dir = base_dir.joinpath(args.model_name)
    if (args.resume or args.initialize) and not model_dir.exists():
        raise Exception("Model directory for resume does not exist")
    if not (args.resume or args.initialize) and model_dir.exists():
        c = ""
        while c not in ("y", "n"):
            c = input("Model directory already exists, overwrite? (y/n) ").strip().lower()

        if c == "y":
            shutil.rmtree(model_dir)
        else:
            sys.exit(0)
    model_dir.mkdir(parents=True, exist_ok=True)

    summary_writer_dir = model_dir.joinpath("runs")
    summary_writer_dir.mkdir(exist_ok=True)
    save_path = model_dir.joinpath("checkpoints")
    save_path.mkdir(exist_ok=True)

    # prepare summary writer
    writer = SummaryWriter(summary_writer_dir, comment=args.writer_comment)

    # prepare data
    train_loader, val_loader, test_loader, args = load_dataset(
        args, flatten=args.flatten_image
    )

    # prepare flow model
    if hasattr(flows, args.flow):
        flow_model_template = getattr(flows, args.flow)
    else:
        raise ValueError(f"Unknown flow type: {args.flow}")

    flow_list = [flow_model_template(args.zdim) for _ in range(args.num_flows)]
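    # optionally interleave invertible 1x1 convolutions and ActNorm layers between the flow blocks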
    if args.permute_conv:
        convs = [flows.OneByOneConv(dim=args.zdim) for _ in range(args.num_flows)]
        flow_list = list(itertools.chain(*zip(convs, flow_list)))
    if args.actnorm:
        actnorms = [flows.ActNorm(dim=args.zdim) for _ in range(args.num_flows)]
        flow_list = list(itertools.chain(*zip(actnorms, flow_list)))
    prior = torch.distributions.MultivariateNormal(
        torch.zeros(args.zdim, device=args.device),
        torch.eye(args.zdim, device=args.device),
    )
    flow_model = NormalizingFlowModel(prior, flow_list).to(args.device)

    # prepare losses and autoencoder
    if args.dataset == "mnist":
        args.imshape = (1, 28, 28)
        if args.ae_model == "linear":
            ae_model = AutoEncoder(args.xdim, args.zdim, args.units, "binary").to(
                args.device
            )
            ae_loss = nn.BCEWithLogitsLoss(reduction="sum").to(args.device)

        elif args.ae_model == "conv":
            args.zshape = (8, 7, 7)
            ae_model = ConvAutoEncoder(
                in_channels=1,
                image_size=np.squeeze(args.imshape),
                activation=nn.Hardtanh(0, 1),
            ).to(args.device)
            ae_loss = nn.BCELoss(reduction="sum").to(args.device)

    elif args.dataset == "cifar10":
        args.imshape = (3, 32, 32)
        args.zshape = (8, 8, 8)
        ae_loss = nn.MSELoss(reduction="sum").to(args.device)
        ae_model = ConvAutoEncoder(in_channels=3, image_size=args.imshape).to(
            args.device
        )

    # setup optimizers
    ae_optimizer = optim.Adam(ae_model.parameters(), args.learning_rate)
    flow_optimizer = optim.Adam(flow_model.parameters(), args.learning_rate)

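    # run the outer loop for the longest of the configured epoch counts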
    total_epochs = np.max([args.vae_epochs, args.flow_epochs, args.epochs])

    if args.resume:
        checkpoint = torch.load(args.model_path, map_location=args.device)
        flow_model.load_state_dict(checkpoint["flow_model"])
        ae_model.load_state_dict(checkpoint["ae_model"])
        flow_optimizer.load_state_dict(checkpoint["flow_optimizer"])
        ae_optimizer.load_state_dict(checkpoint["ae_optimizer"])
        init_epoch = checkpoint["epoch"]
    elif args.initialize:
        checkpoint = torch.load(args.model_path, map_location=args.device)
        flow_model.load_state_dict(checkpoint["flow_model"])
        ae_model.load_state_dict(checkpoint["ae_model"])
        init_epoch = 1  # only the weights are reused; training restarts from epoch 1
    else:
        init_epoch = 1

    if args.initialize:
        raise NotImplementedError

    # training loop
    for epoch in trange(init_epoch, total_epochs + 1):
        if epoch <= args.vae_epochs:
            train_ae(
                epoch,
                train_loader,
                ae_model,
                ae_optimizer,
                writer,
                ae_loss,
                device=args.device,
            )
            log_ae_tensorboard_images(
                ae_model,
                val_loader,
                writer,
                epoch,
                "AE/val/Images",
                xshape=args.imshape,
            )
            # evaluate_ae(epoch, test_loader, ae_model, writer, ae_loss)

        if epoch <= args.flow_epochs:
            train_flow(
                epoch,
                train_loader,
                flow_model,
                ae_model,
                flow_optimizer,
                writer,
                device=args.device,
                flatten=not args.no_flatten_latent,
            )

            log_flow_tensorboard_images(
                flow_model,
                ae_model,
                writer,
                epoch,
                "Flow/sampled/Images",
                xshape=args.imshape,
                zshape=args.zshape,
            )

        if epoch % args.save_iter == 0:
            checkpoint_dict = {
                "epoch": epoch,
                "ae_optimizer": ae_optimizer.state_dict(),
                "flow_optimizer": flow_optimizer.state_dict(),
                "ae_model": ae_model.state_dict(),
                "flow_model": flow_model.state_dict(),
            }
            fname = f"model_{epoch}.pt"
            save_checkpoint(checkpoint_dict, save_path, fname)

    if args.save_images:
        p = Path(f"images/mnist/{args.model_name}")
        p.mkdir(parents=True, exist_ok=True)
        n_samples = 10000

        print("final epoch images")
        flow_model.eval()
        ae_model.eval()
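        # draw latent samples from the flow and decode them with the autoencoder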
        with torch.no_grad():
            z = flow_model.sample(n_samples)
            z = z.to(next(ae_model.parameters()).device)
            xcap = ae_model.decoder.predict(z).to("cpu").view(-1, *args.imshape).numpy()
        xcap = (np.rint(xcap) * 255).astype(np.uint8)
        for i, im in enumerate(xcap):
            imsave(f'{p.joinpath(f"im_{i}.png").as_posix()}', np.squeeze(im))

    writer.close()
Example #2
File: test.py Project: yongxinw/zsl
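# Evaluates a trained autoencoder for zero-shot learning on CUB: extracts latent
# features and class scores, then runs ConSE and nearest-neighbour predictions.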
class Tester(object):
    def __init__(self, args):
        super(Tester, self).__init__()

        self.args = args

        self.model = AutoEncoder(args)
        self.model.load_state_dict(torch.load(args.checkpoint))
        self.model.cuda()
        self.model.eval()

        self.result = {}

        self.train_dataset = CUBDataset(split='train')
        self.test_dataset = CUBDataset(split='test')
        self.val_dataset = CUBDataset(split='val')

        self.train_loader = DataLoader(dataset=self.train_dataset,
                                       batch_size=args.batch_size)
        self.test_loader = DataLoader(dataset=self.test_dataset,
                                      batch_size=args.batch_size)
        self.val_loader = DataLoader(dataset=self.val_dataset,
                                     batch_size=100,
                                     shuffle=True)

        train_cls = self.train_dataset.get_classes('train')
        test_cls = self.test_dataset.get_classes('test')
        print("Loaded classes:")
        print(train_cls)
        print(test_cls)

        self.zsl = ZSLPrediction(train_cls, test_cls)

    # def tSNE(self):

    def conse_prediction(self, mode='test'):
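        # ConSE-style prediction: combine the classifier's top-k class scores with class word embeddings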
        def pred(recon_x, z_tilde, output):
            cls_score = output.detach().cpu().numpy()
            print(cls_score)
            pred = self.zsl.conse_wordembedding_predict(
                cls_score, self.args.conse_top_k)
            return pred

        self.get_features(mode=mode, pred_func=pred)

        if (mode + '_pred') in self.result:

            target = self.result[mode + '_label']
            pred = self.result[mode + '_pred']
            print(target)
            print(pred)
            acc = np.sum(target == pred)
            print(acc)
            total = target.shape[0]
            print(total)
            return acc / float(total)
        else:
            raise NotImplementedError

    def knn_prediction(self, mode='test'):
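        # k-nearest-neighbour prediction in the latent feature space (cosine distance, k=5)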
        self.get_features(mode=mode, pred_func=None)

        if (mode + '_feature') in self.result:
            features = self.result[mode + '_feature']
            labels = self.result[mode + '_label']
            print(labels)
            self.zsl.construct_nn(features,
                                  labels,
                                  k=5,
                                  metric='cosine',
                                  sample_num=5)
            pred = self.zsl.nn_predict(features)

            acc = np.sum(labels == pred)
            total = labels.shape[0]

            return acc / float(total)
        else:
            raise NotImplementedError

    def tSNE(self, mode='train'):
        self.get_features(mode=mode, pred_func=None)

        total_num = self.result[mode + '_feature'].shape[0]

        random_index = np.random.permutation(total_num)

        random_index = random_index[:30]

        self.zsl.tSNE_visualization(self.result[mode+'_feature'][random_index,:], \
                                    self.result[mode+'_label'][random_index], \
                                    mode=mode,
                                    file_name= self.args.tsne_out)

    def get_features(self, mode='test', pred_func=None):
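        # encode the chosen split once and cache features, labels, and optional predictions in self.result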
        self.model.eval()
        if pred_func is None and (mode + '_feature') in self.result:
            print("Use cached result")
            return
        if pred_func is not None and (mode + '_pred') in self.result:
            print("Use cached result")
            return

        if mode == 'train':
            loader = self.train_loader
        elif mode == 'test':
            loader = self.test_loader
        else:
            loader = self.val_loader

        all_z = []
        all_label = []
        all_pred = []

        for data in tqdm(loader):
            # if idx == 3:
            #     break

            images = Variable(data['image64_crop'].cuda())
            target = Variable(data['class_id'].cuda())

            recon_x, z_tilde, output = self.model(images)
            target = target.detach().cpu().numpy()

            output = F.softmax(output, dim=1)

            all_label.append(target)
            all_z.append(z_tilde.detach().cpu().numpy())

            if pred_func is not None:
                pred = pred_func(recon_x, z_tilde, output)
                all_pred.append(pred)

        self.result[mode + '_feature'] = np.vstack(all_z)  # all features
        # print(all_label)
        self.result[mode + '_label'] = np.hstack(all_label)  # all test label

        if pred_func is not None:
            self.result[mode + '_pred'] = np.hstack(all_pred)
            print(self.result[mode + '_pred'].shape)
        print(self.result[mode + '_feature'].shape)
        print(self.result[mode + '_label'].shape)

    def validation_recon(self):
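        # save reconstructions and the corresponding originals for one validation batch as PNGs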
        self.model.eval()
        for idx, data in enumerate(self.val_loader):
            if idx == 1:
                break

            images = Variable(data['image64_crop'].cuda())
            recon_x, z_tilde, output = self.model(images)

            all_recon_images = recon_x.detach().cpu().numpy()  #N x 3 x 64 x 64
            all_origi_images = data['image64_crop'].numpy()  #N x 3 x 64 x 64

            for i in range(all_recon_images.shape[0]):
                imsave(
                    './recon/recon' + str(i) + '.png',
                    np.transpose(np.squeeze(all_recon_images[i, :, :, :]),
                                 [1, 2, 0]))
                imsave(
                    './recon/orig' + str(i) + '.png',
                    np.transpose(np.squeeze(all_origi_images[i, :, :, :]),
                                 [1, 2, 0]))

    def test_nn_image(self):
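        # for a random subset of test images, save each image together with its nearest training image in latent space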
        self.get_features(mode='test', pred_func=None)
        self.get_features(mode='train', pred_func=None)

        N = 100
        random_index = np.random.permutation(
            self.result['test_feature'].shape[0])[:N]

        from sklearn.neighbors import NearestNeighbors

        neigh = NearestNeighbors()
        neigh.fit(self.result['train_feature'])

        test_feature = self.result['test_feature'][random_index, :]
        _, pred_index = neigh.kneighbors(test_feature, 1)

        for i in range(N):
            test_index = random_index[i]

            data = self.test_dataset[test_index]
            image = data['image64_crop'].numpy()  #1 x 3 x 64 x 64
            print(image.shape)
            imsave('./nn_image/test' + str(i) + '.png',
                   np.transpose(np.squeeze(image), [1, 2, 0]))

            train_index = pred_index[i][0]
            print(train_index)
            data = self.train_dataset[train_index]
            image = data['image64_crop'].numpy()  #1 x 3 x 64 x 64
            print(image.shape)
            imsave('./nn_image/train' + str(i) + '.png',
                   np.transpose(np.squeeze(image), [1, 2, 0]))
Example #3
        batch = batch.to(device)

        embedding = model.encode(batch)
        embedding = embedding.detach()
        del batch
        torch.cuda.empty_cache()
        prediction = classifier(embedding)
        prediction = torch.argmax(prediction, dim=1)
        precision = (prediction == label).sum()
        precision = torch.true_divide(precision, prediction.shape[0])
        precision_list.append(precision.item())
    mean_precision = sum(precision_list) / len(precision_list)
    return mean_precision


model.eval()
loss_list = []
for epoch in range(45):
    loss_list = []
    classifier.train()
    if epoch > 25:
        for param in optimizer.param_groups:
            param['lr'] = max(0.0001, param['lr'] / 1.2)
            print('lr: ', param['lr'])
    for batch_i, batch in enumerate(data_loader_train):
        label = batch['latent'][:, 0]  #figure type
        label = label.type(torch.LongTensor) - 1

        label = label.to(device)
        batch = batch['image'].unsqueeze(1)
        batch = batch.type(torch.FloatTensor)
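The excerpt above is truncated: the loop body at its start belongs to an evaluation helper whose definition was cut off. A minimal sketch of how that surviving body could be wrapped, assuming hypothetical names (evaluate, data_loader_test) and the same batch format as the training loop:

def evaluate(model, classifier, data_loader_test, device):
    # hypothetical wrapper around the surviving evaluation loop body
    model.eval()
    classifier.eval()
    precision_list = []
    with torch.no_grad():
        for batch in data_loader_test:
            label = batch['latent'][:, 0]  # figure type
            label = (label.type(torch.LongTensor) - 1).to(device)
            images = batch['image'].unsqueeze(1).type(torch.FloatTensor).to(device)

            # encode with the frozen autoencoder, classify, and accumulate batch precision
            embedding = model.encode(images)
            prediction = torch.argmax(classifier(embedding), dim=1)
            precision = torch.true_divide((prediction == label).sum(), prediction.shape[0])
            precision_list.append(precision.item())
    return sum(precision_list) / len(precision_list)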
Example #4
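# Adversarially trains an autoencoder (G) against a latent-space discriminator (D)
# on CUB, combining reconstruction, adversarial, and classification losses.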
class Trainer(object):
    def __init__(self, args):
        # load network
        self.G = AutoEncoder(args)
        self.D = Discriminator(args)
        self.G.weight_init()
        self.D.weight_init()
        self.G.cuda()
        self.D.cuda()
        self.criterion = nn.MSELoss()

        # load data
        self.train_dataset = CUBDataset(split='train')
        self.valid_dataset = CUBDataset(split='val')
        self.train_loader = DataLoader(dataset=self.train_dataset, batch_size=args.batch_size)
        self.valid_loader = DataLoader(dataset=self.valid_dataset, batch_size=args.batch_size)

        # Optimizers
        self.G_optim = optim.Adam(self.G.parameters(), lr = args.lr_G)
        self.D_optim = optim.Adam(self.D.parameters(), lr = 0.5 * args.lr_D)
        self.G_scheduler = StepLR(self.G_optim, step_size=30, gamma=0.5)
        self.D_scheduler = StepLR(self.D_optim, step_size=30, gamma=0.5)

        # Parameters
        self.epochs = args.epochs
        self.batch_size = args.batch_size
        self.z_var = args.z_var 
        self.sigma = args.sigma
        self.lambda_1 = args.lambda_1
        self.lambda_2 = args.lambda_2

        log_dir = os.path.join(args.log_dir, datetime.now().strftime("%m_%d_%H_%M_%S"))
        # if not os.path.isdir(log_dir):
            # os.makedirs(log_dir)
        self.writter = SummaryWriter(log_dir)

    def train(self):
        global_step = 0
        self.G.train()
        self.D.train()
        ones = Variable(torch.ones(self.batch_size, 1).cuda())
        zeros = Variable(torch.zeros(self.batch_size, 1).cuda())

        for epoch in range(self.epochs):
            print("training epoch {}".format(epoch))
            all_num = 0.0
            acc_num = 0.0
            images_index = 0
            for data in tqdm(self.train_loader):
                images = Variable(data['image64'].cuda())
                target_image = Variable(data['image64'].cuda()) 
                target = Variable(data['class_id'].cuda())
                recon_x, z_tilde, output = self.G(images)
                z = Variable((self.sigma*torch.randn(z_tilde.size())).cuda())
                log_p_z = log_density_igaussian(z, self.z_var).view(-1, 1)
                ones = Variable(torch.ones(images.size()[0], 1).cuda())
                zeros = Variable(torch.zeros(images.size()[0], 1).cuda())

                # ======== Train Discriminator ======== #
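                # both BCE terms add the log prior density log p(z) to the discriminator logits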
                D_z = self.D(z)
                D_z_tilde = self.D(z_tilde)
                D_loss = F.binary_cross_entropy_with_logits(D_z+log_p_z, ones) + \
                    F.binary_cross_entropy_with_logits(D_z_tilde+log_p_z, zeros)

                total_D_loss = self.lambda_1*D_loss
                self.D_optim.zero_grad()
                total_D_loss.backward(retain_graph=True)
                self.D_optim.step()

                # ======== Train Generator ======== #
                recon_loss = F.mse_loss(recon_x, target_image, reduction='sum').div(self.batch_size)
                G_loss = F.binary_cross_entropy_with_logits(D_z_tilde+log_p_z, ones)
                class_loss = F.cross_entropy(output, target)
                total_G_loss = recon_loss + self.lambda_1*G_loss + self.lambda_2*class_loss
                self.G_optim.zero_grad()
                total_G_loss.backward()
                self.G_optim.step()

                # ======== Compute Classification Accuracy ======== #
                values, indices = torch.max(output, 1)
                acc_num += torch.sum((indices == target)).cpu().item()
                all_num += len(target)
                
                # ======== Log by TensorBoardX
                global_step += 1
                if (global_step + 1) % 10 == 0:
                    self.writter.add_scalar('train/recon_loss', recon_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/G_loss', G_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/D_loss', D_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/classify_loss', class_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/total_G_loss', total_G_loss.cpu().item(), global_step)
                    self.writter.add_scalar('train/acc', acc_num/all_num, global_step)
                    if images_index < 5 and torch.rand(1) < 0.5:
                        self.writter.add_image('train_output_{}'.format(images_index), recon_x[0], global_step)
                        self.writter.add_image('train_target_{}'.format(images_index), target_image[0], global_step)
                        images_index += 1
            if epoch % 2 == 0:
                self.validate(global_step)
                # validate() puts the networks in eval mode; restore train mode for the next epoch
                self.G.train()
                self.D.train()
            # step the LR schedulers once per epoch, after the optimizer updates
            self.G_scheduler.step()
            self.D_scheduler.step()

    def validate(self, global_step):
        self.G.eval()
        self.D.eval()
        acc_num = 0.0
        all_num = 0.0
        recon_loss = 0.0
        images_index = 0
        for data in tqdm(self.valid_loader):
            images = Variable(data['image64'].cuda())
            target_image = Variable(data['image64'].cuda()) 
            target = Variable(data['class_id'].cuda())
            recon_x, z_tilde, output = self.G(images)
            values, indices = torch.max(output, 1)
            acc_num += torch.sum((indices == target)).cpu().item()
            all_num += len(target)
            recon_loss += F.mse_loss(recon_x, target_image, reduction='sum').cpu().item()
            if images_index < 5:
                self.writter.add_image('valid_output_{}'.format(images_index), recon_x[0], global_step)
                self.writter.add_image('valid_target_{}'.format(images_index), target_image[0], global_step)
                images_index += 1

        self.writter.add_scalar('valid/acc', acc_num/all_num, global_step)
        self.writter.add_scalar('valid/recon_loss', recon_loss/all_num, global_step)
Example #5
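# Trains an autoencoder and/or classifier on arrays loaded from an .npz file, with
# optional AE pre-training, early stopping, and semicolon-separated result logging.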
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filename', required=True, help='Name/path of file')
    parser.add_argument('--save_dir', default='./outputs', help='Path to the directory where results will be saved.')

    parser.add_argument('--pretrain_epochs', type=int, default=100, help="Number of epochs to pre-train the AE model")
    parser.add_argument('--epochs', type=int, default=100, help="Number of epochs to train the AE and classifier")
    parser.add_argument('--dims_layers_ae', type=int, nargs='+', default=[500, 100, 10],
                        help="Dimensions of the AE layers")
    parser.add_argument('--dims_layers_classifier', type=int, nargs='+', default=[10, 5],
                        help="Dimensions of the classifier layers")
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.001, help="Learning rate")
    parser.add_argument('--use_dropout', action='store_true', help="Use dropout")

    parser.add_argument('--no-cuda', action='store_true', help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1234, help='random seed (default: 1234)')

    parser.add_argument('--procedure', nargs='+', choices=['pre-training_ae', 'training_classifier', 'training_all'],
                        help='Training procedure(s) to run. Choose from: pre-training_ae, training_classifier, '
                             'training_all')
    parser.add_argument('--criterion_classifier', default='BCELoss', choices=['BCELoss', 'HingeLoss'],
                        help='Loss function for the classifier')
    parser.add_argument('--scale_loss', type=float, default=1., help='Weight for loss of classifier')
    parser.add_argument('--earlyStopping', type=int, default=None, help='Number of epochs to early stopping')
    parser.add_argument('--use_scheduler', action='store_true')
    args = parser.parse_args()
    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    loaded = np.load(args.filename)
    x_train = loaded['data_train']
    x_test = loaded['data_test']
    y_train = loaded['lab_train']
    y_test = loaded['lab_test']
    del loaded

    name_target = PurePosixPath(args.filename).parent.stem
    n_split = PurePosixPath(args.filename).stem
    save_dir = f'{args.save_dir}/tensorboard/{name_target}_{n_split}'
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    if args.dims_layers_classifier[0] == -1:
        args.dims_layers_classifier[0] = x_test.shape[1]

    model_classifier = Classifier(args.dims_layers_classifier, args.use_dropout).to(device)
    if args.criterion_classifier == 'HingeLoss':
        criterion_classifier = nn.HingeEmbeddingLoss()
        print('Use "Hinge" loss.')
    else:
        criterion_classifier = nn.BCEWithLogitsLoss()

    model_ae = None
    criterion_ae = None
    if 'training_classifier' != args.procedure[0]:
        args.dims_layers_ae = [x_train.shape[1]] + args.dims_layers_ae
        assert args.dims_layers_ae[-1] == args.dims_layers_classifier[0], \
            'Dimension of latent space must be equal with dimension of input classifier!'

        model_ae = AutoEncoder(args.dims_layers_ae, args.use_dropout).to(device)
        criterion_ae = nn.MSELoss()
        optimizer = torch.optim.Adam(list(model_ae.parameters()) + list(model_classifier.parameters()), lr=args.lr)
    else:
        optimizer = torch.optim.Adam(model_classifier.parameters(), lr=args.lr)

    scheduler = None
    if args.use_scheduler:
        # decay the learning rate by 5% per epoch (LambdaLR scales the initial lr by the lambda's return value)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda ep: 0.95 ** ep)

    writer = SummaryWriter(save_dir)

    total_scores = {'roc_auc': 0, 'acc': 0, 'mcc': 0, 'bal': 0, 'recall': 0,
                    'max_roc_auc': 0, 'max_acc': 0, 'max_mcc': 0, 'max_bal': 0, 'max_recall': 0,
                    'pre-fit_time': 0, 'pre-score_time': 0, 'fit_time': 0, 'score_time': 0
                    }

    dir_model_ae = f'{args.save_dir}/models_AE'
    Path(dir_model_ae).mkdir(parents=True, exist_ok=True)
    # dir_model_classifier = f'{args.save_dir}/models_classifier'
    # Path(dir_model_classifier).mkdir(parents=True, exist_ok=True)

    path_ae = f'{dir_model_ae}/{name_target}_{n_split}.pth'
    # path_classifier = f'{dir_model_classifier}/{name_target}_{n_split}.pth'

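    # stage 1: optional AE pre-training, checkpointing the weights with the lowest test loss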
    if 'pre-training_ae' in args.procedure:
        min_val_loss = np.inf
        epochs_no_improve = 0

        epoch_tqdm = tqdm(range(args.pretrain_epochs), desc="Epoch pre-train loss")
        for epoch in epoch_tqdm:
            loss_train, time_trn = train_step(model_ae, None, criterion_ae, None, optimizer, scheduler, x_train,
                                              y_train, device, writer, epoch, args.batch_size, 'pre-training_ae')
            loss_test, _ = test_step(model_ae, None, criterion_ae, None, x_test, y_test, device, writer, epoch,
                                     args.batch_size, 'pre-training_ae')

            if not np.isfinite(loss_train):
                break

            total_scores['pre-fit_time'] += time_trn

            if loss_test < min_val_loss:
                torch.save(model_ae.state_dict(), path_ae)
                epochs_no_improve = 0
                min_val_loss = loss_test
            else:
                epochs_no_improve += 1
            epoch_tqdm.set_description(
                f"Epoch pre-train loss: {loss_train:.5f}, test loss: {loss_test:.5f} (minimal val-loss: {min_val_loss:.5f}, stop: {epochs_no_improve}|{args.earlyStopping})")
            if args.earlyStopping is not None and epoch >= args.earlyStopping and epochs_no_improve == args.earlyStopping:
                print('\033[1;31mEarly stopping in pre-training model\033[0m')
                break
        print(f"\033[1;5;33mLoad AE model from '{path_ae}'\033[0m")
        if device.type == "cpu":
            model_ae.load_state_dict(torch.load(path_ae, map_location=lambda storage, loc: storage))
        else:
            model_ae.load_state_dict(torch.load(path_ae))
        model_ae = model_ae.to(device)
        model_ae.eval()

    min_val_loss = np.inf
    epochs_no_improve = 0

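    # stage 2: train the classifier (jointly with the AE unless only 'training_classifier' was requested)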
    epoch = None
    stage = 'training_classifier' if 'training_classifier' in args.procedure else 'training_all'
    epoch_tqdm = tqdm(range(args.epochs), desc="Epoch train loss")
    for epoch in epoch_tqdm:
        loss_train, time_trn = train_step(model_ae, model_classifier, criterion_ae, criterion_classifier, optimizer,
                                          scheduler, x_train, y_train, device, writer, epoch, args.batch_size,
                                          stage, args.scale_loss)
        loss_test, scores_val, time_tst = test_step(model_ae, model_classifier, criterion_ae, criterion_classifier,
                                                    x_test, y_test, device, writer, epoch, args.batch_size, stage,
                                                    args.scale_loss)

        if not np.isfinite(loss_train):
            break

        total_scores['fit_time'] += time_trn
        total_scores['score_time'] += time_tst
        if total_scores['max_roc_auc'] < scores_val['roc_auc']:
            for key, val in scores_val.items():
                total_scores[f'max_{key}'] = val

        if loss_test < min_val_loss:
            # torch.save(model_ae.state_dict(), path_ae)
            # torch.save(model_classifier.state_dict(), path_classifier)
            epochs_no_improve = 0
            min_val_loss = loss_test
            for key, val in scores_val.items():
                total_scores[key] = val
        else:
            epochs_no_improve += 1
        epoch_tqdm.set_description(
            f"Epoch train loss: {loss_train:.5f}, test loss: {loss_test:.5f} (minimal val-loss: {min_val_loss:.5f}, stop: {epochs_no_improve}|{args.earlyStopping})")
        if args.earlyStopping is not None and epoch >= args.earlyStopping and epochs_no_improve == args.earlyStopping:
            print('\033[1;31mEarly stopping!\033[0m')
            break
    total_scores['score_time'] /= epoch + 1
    writer.close()

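    # append this split's hyper-parameters and scores as one semicolon-separated line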
    save_file = f'{args.save_dir}/{name_target}.txt'
    head = 'idx;params'
    temp = f'{n_split};pretrain_epochs:{args.pretrain_epochs},dims_layers_ae:{args.dims_layers_ae},' \
           f'dims_layers_classifier:{args.dims_layers_classifier},batch_size:{args.batch_size},lr:{args.lr},' \
           f'use_dropout:{args.use_dropout},procedure:{args.procedure},scale_loss:{args.scale_loss},' \
           f'earlyStopping:{args.earlyStopping}'
    for key, val in total_scores.items():
        head = head + f';{key}'
        temp = temp + f';{val}'

    not_exists = not Path(save_file).exists()
    with open(save_file, 'a') as f:
        if not_exists:
            f.write(f'{head}\n')
        f.write(f'{temp}\n')