Example #1
0
def main(args):
    """Train the board AutoEncoder, checkpointing it periodically.

    Args:
        args: parsed command-line namespace. Fields read by the training part:
            ``boards_file``, ``num_gpus``, ``shuffle``, ``num_test``,
            ``batch_size``, ``model_loadname``, ``lr``, ``init_epoch``,
            ``epochs``, ``log_interval``, ``save_interval``,
            ``model_savename``.
    """
    print('Loading data')
    # allow_pickle=True lets np.load unpickle object arrays, which can run
    # arbitrary code: only point --boards_file at trusted files.
    idxs = np.load(args.boards_file, allow_pickle=True)['idxs']
    print(f'Number of Boards: {len(idxs)}')

    if torch.cuda.is_available() and args.num_gpus > 0:
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    if args.shuffle:
        np.random.shuffle(idxs)

    # Hold out the last ``num_test`` boards for testing. Split at an explicit
    # index: the negative slices ``idxs[:-n]`` / ``idxs[-n:]`` misbehave for
    # n == 0 (empty training set, every board in the test set).
    split = max(len(idxs) - args.num_test, 0)
    train_idxs = idxs[:split]
    test_idxs = idxs[split:]

    # shuffle=False: every epoch visits train_idxs in the same order (they are
    # shuffled once above when args.shuffle is set).
    train_loader = DataLoader(Boards(train_idxs),
                              batch_size=args.batch_size,
                              shuffle=False)
    # NOTE(review): test_loader is never used below.
    test_loader = DataLoader(Boards(test_idxs), batch_size=args.batch_size)

    model = AutoEncoder().to(device)
    if args.model_loadname:
        # map_location lets a checkpoint saved on a GPU be loaded on a
        # CPU-only machine (and vice versa).
        model.load_state_dict(
            torch.load(args.model_loadname, map_location=device))

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    model.train()
    losses = []  # per-iteration training loss, plotted at each checkpoint
    total_iters = 0

    for epoch in range(args.init_epoch, args.epochs):
        print(f'Running epoch {epoch} / {args.epochs}\n')
        for batch_idx, board in tqdm(enumerate(train_loader),
                                     total=len(train_loader)):
            board = board.to(device)
            optimizer.zero_grad()
            loss = model.loss(board)
            loss.backward()

            losses.append(loss.item())
            optimizer.step()

            if total_iters % args.log_interval == 0:
                tqdm.write(f'Loss: {loss.item()}')

            # Checkpoint (and refresh the loss plot) every save_interval
            # iterations, including iteration 0.
            if total_iters % args.save_interval == 0:
                torch.save(
                    model.state_dict(),
                    append_to_modelname(args.model_savename, total_iters))
                plot_losses(losses, 'vis/ae_losses.png')
            total_iters += 1

    # NOTE(review): everything below references names that are not defined in
    # this function (train_val, val_test, conf, iterative_training,
    # data_loader_train, data_loader_val, save_path, recall_validation,
    # transform1, transform2). It looks spliced in from a different training
    # script and will raise NameError unless those exist as module globals.
    print('train dataset length: ', len(train_val['train']))
    print('val dataset length: ', len(val_test['val']))
    print('test dataset length: ', len(val_test['train']))

    print('latent space size:', conf['model']['latent_size'])
    print('batch size:', conf['train']['batch_size'])
    print('margin:', conf['train']['margin'])

    # for epoch in range(60):
    #     if epoch > 15:
    #         for param in optimizer.param_groups:
    #             param['lr'] = max(0.00005, param['lr'] / conf['train']['lr_decay'])
    #             print('lr: ', param['lr'])
    #     validation_loss, recall, recall10, mean_epoch_loss_encoder, penalty, mean_epoch_loss_metric \
    #         = contrastive_autoencoder(model, optimizer, data_loader_train, data_loader_val, device)
    #     print('loss: {0:2.5f}, autoencoder val: {1:2.5f}, loss metric learn: {2:1.5f}, penalty: {3:1.5f}'
    #           .format(mean_epoch_loss_encoder, validation_loss, mean_epoch_loss_metric, penalty) )

    iterative_training(model, optimizer, data_loader_train, data_loader_val,
                       device)

    torch.save(model.state_dict(), save_path)

    print('calculating recall')
    #recall, recall10 = recall_validation(model, data_loader_val, transform1, transform2, device)
    #print('train mode, recall: {0:2.4f}, recall10: {1:2.4f}'.format(recall, recall10))
    model.eval()
    recall, recall10 = recall_validation(model, data_loader_val, transform1,
                                         transform2, device)
    print('eval mode, recall: {0:2.4f}, recall10: {1:2.4f}'.format(
        recall, recall10))
Example #3
0
def _grid_search_and_save(estimator, parameters, cv, scoring, features, labels,
                          savefile, model_name, fit_time_ae, score_time_ae):
    """Grid-search ``estimator`` over ``parameters`` and append its results.

    A failing fit is printed, not raised, so the remaining classifiers still
    run. The search object is handed to ``save_results`` either way, exactly
    as the previously inlined code did.
    """
    clf = GridSearchCV(estimator,
                       parameters,
                       cv=cv,
                       n_jobs=-1,
                       scoring=scoring,
                       refit='roc_auc',
                       return_train_score=True)
    try:
        clf.fit(features, labels)
    except Exception as e:  # best effort: report and move on to next model
        print(getattr(e, 'message', e))
    save_results(savefile,
                 'a',
                 model_name,
                 clf,
                 False,
                 fit_time_ae=fit_time_ae,
                 score_time_ae=score_time_ae)


def main():
    """Pretrain an AutoEncoder, embed the data, and grid-search classifiers.

    Steps:
      1. Parse CLI options and seed NumPy / PyTorch.
      2. Load the ``data`` and ``label`` arrays from ``--filename``.
      3. Train the AE (optionally with early stopping), saving the weights of
         the epoch with the lowest *training* loss.
      4. Reload those weights and project every sample to the last AE layer.
      5. Grid-search a linear SVM, an RBF SVM and logistic regression on the
         projection, writing results to ``--savefile`` via ``save_results``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--filename', required=True, help='Name/path of file')
    parser.add_argument('--savefile',
                        type=str,
                        default='./output.txt',
                        help='Path to file where will be save results')
    parser.add_argument('--class_weight',
                        action='store_true',
                        default=None,
                        help='Use balance weight')
    # type=int: a value given on the command line would otherwise be a str,
    # which np.random.seed and the sklearn random_state arguments reject.
    parser.add_argument('--seed', type=int, default=1234, help='Number of seed')

    parser.add_argument('--pretrain_epochs',
                        type=int,
                        default=100,
                        help="Number of epochs to pretrain model AE")
    parser.add_argument('--dims_layers_ae',
                        type=int,
                        nargs='+',
                        default=[500, 100, 10],
                        help="Dimensional of layers in AE")
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help="Learning rate")
    parser.add_argument('--use_dropout',
                        action='store_true',
                        help="Use dropout")
    parser.add_argument('--no-cuda',
                        action='store_true',
                        help='disables CUDA training')
    parser.add_argument('--earlyStopping',
                        type=int,
                        default=None,
                        help='Number of epochs to early stopping')
    parser.add_argument('--use_scheduler', action='store_true')
    args = parser.parse_args()
    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f'Device: {device.type}')

    loaded = np.load(args.filename)
    data = loaded['data']
    labels = loaded['label']
    del loaded

    name_target = PurePosixPath(args.savefile).stem
    save_dir = f'{PurePosixPath(args.savefile).parent}/tensorboard/{name_target}'
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    # First layer width is the input feature count.
    args.dims_layers_ae = [data.shape[1]] + args.dims_layers_ae
    model_ae = AutoEncoder(args.dims_layers_ae, args.use_dropout).to(device)

    criterion_ae = nn.MSELoss()
    optimizer = torch.optim.Adam(model_ae.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-5)

    scheduler = None
    if args.use_scheduler:
        # NOTE(review): LambdaLR sets lr = initial_lr * lr_lambda(epoch), so a
        # constant 0.95 is a fixed 5% reduction, not a decay. If per-epoch
        # decay was intended use ``lambda ep: 0.95 ** ep`` (or
        # MultiplicativeLR) -- depends on how often train_step steps it.
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda ep: 0.95)

    min_val_loss = np.inf  # np.Inf alias was removed in NumPy 2.0
    epochs_no_improve = 0
    fit_time_ae = 0
    writer = SummaryWriter(save_dir)
    model_path = f'{PurePosixPath(args.savefile).parent}/models_AE/{name_target}.pth'
    Path(PurePosixPath(model_path).parent).mkdir(parents=True, exist_ok=True)
    epoch_tqdm = tqdm(range(args.pretrain_epochs), desc="Epoch loss")
    for epoch in epoch_tqdm:
        loss_train, fit_t = train_step(model_ae, criterion_ae, optimizer,
                                       scheduler, data, labels, device, writer,
                                       epoch, args.batch_size)
        fit_time_ae += fit_t
        # Keep the weights of the epoch with the lowest training loss.
        if loss_train < min_val_loss:
            torch.save(model_ae.state_dict(), model_path)
            epochs_no_improve = 0
            min_val_loss = loss_train
        else:
            epochs_no_improve += 1
        epoch_tqdm.set_description(
            f"Epoch loss: {loss_train:.5f} (minimal loss: {min_val_loss:.5f}, stop: {epochs_no_improve}|{args.earlyStopping})"
        )
        # Stop once the loss has not improved for ``earlyStopping``
        # consecutive epochs (patience of at least one epoch). The former
        # ``epoch > es and no_improve == es`` test missed a streak that
        # reached ``es`` exactly at ``epoch == es`` and then never fired.
        if (args.earlyStopping is not None
                and epochs_no_improve >= max(args.earlyStopping, 1)):
            print('\033[1;31mEarly stopping in AE model\033[0m')
            break

    print('===================================================')
    print(f'Transforming data to lower dimensional')
    if device.type == "cpu":
        model_ae.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))
    else:
        model_ae.load_state_dict(torch.load(model_path))
    model_ae.eval()

    low_data = np.empty((data.shape[0], args.dims_layers_ae[-1]))
    n_batch, rest = divmod(data.shape[0], args.batch_size)
    n_batch = n_batch + 1 if rest else n_batch  # include a final partial batch
    score_time_ae = 0
    with torch.no_grad():
        test_tqdm = tqdm(range(n_batch), desc="Transform data", leave=False)
        for i in test_tqdm:
            start_time = time.time()
            batch = torch.from_numpy(
                data[i * args.batch_size:(i + 1) *
                     args.batch_size, :]).float().to(device)
            # ===================forward=====================
            z, _ = model_ae(batch)
            low_data[i * args.batch_size:(i + 1) *
                     args.batch_size, :] = z.detach().cpu().numpy()
            end_time = time.time()
            score_time_ae += end_time - start_time
    print('Data shape after transformation: {}'.format(low_data.shape))
    print('===================================================')

    args.class_weight = 'balanced' if args.class_weight else None

    # Split data
    sss = StratifiedShuffleSplit(n_splits=3,
                                 test_size=0.1,
                                 random_state=args.seed)
    scoring = {
        'acc': make_scorer(accuracy_score),
        'roc_auc': make_scorer(roc_auc_score, needs_proba=True),
        'mcc': make_scorer(matthews_corrcoef),
        'bal': make_scorer(balanced_accuracy_score),
        'recall': make_scorer(recall_score)
    }

    max_iters = 10000
    # Mode 'w': start a fresh results file; each classifier then appends.
    save_results(args.savefile,
                 'w',
                 'model',
                 None,
                 True,
                 fit_time_ae=fit_time_ae,
                 score_time_ae=score_time_ae)

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', ConvergenceWarning)
        warnings.simplefilter('ignore', RuntimeWarning)
        # Presumably meant to also silence warnings in the worker processes
        # started by n_jobs=-1, which do not inherit the filters above.
        environ["PYTHONWARNINGS"] = "ignore"

        # Linear SVM
        print("\rLinear SVM         ", end='')
        # svc = svm.LinearSVC(class_weight=args.class_weight, random_state=seed)
        svc = svm.SVC(kernel='linear',
                      class_weight=args.class_weight,
                      random_state=args.seed,
                      probability=True,
                      max_iter=max_iters)
        _grid_search_and_save(svc, {'C': [0.01, 0.1, 1, 10, 100]}, sss,
                              scoring, low_data, labels, args.savefile,
                              'Linear SVM', fit_time_ae, score_time_ae)

        # RBF SVM
        print("\rRBF SVM             ", end='')
        svc = svm.SVC(gamma="scale",
                      class_weight=args.class_weight,
                      random_state=args.seed,
                      probability=True,
                      max_iter=max_iters)
        _grid_search_and_save(svc, {
            'kernel': ['rbf'],
            'C': [0.01, 0.1, 1, 10, 100],
            'gamma': ['scale', 'auto', 1e-2, 1e-3, 1e-4]
        }, sss, scoring, low_data, labels, args.savefile, 'RBF SVM',
                              fit_time_ae, score_time_ae)

        # LogisticRegression
        print("\rLogisticRegression  ", end='')
        lreg = LogisticRegression(random_state=args.seed,
                                  solver='lbfgs',
                                  multi_class='ovr',
                                  class_weight=args.class_weight,
                                  n_jobs=-1,
                                  max_iter=max_iters)
        _grid_search_and_save(lreg, {'C': [0.01, 0.1, 1, 10, 100]}, sss,
                              scoring, low_data, labels, args.savefile,
                              'LogisticRegression', fit_time_ae,
                              score_time_ae)
        print()
Example #4
0
def main():
    """Train an autoencoder and a normalizing flow over its latent space.

    The AE is trained for ``args.vae_epochs`` epochs and the flow for
    ``args.flow_epochs`` epochs. The loop runs for the largest of
    ``vae_epochs``, ``flow_epochs`` and ``epochs``. Every ``args.save_iter``
    epochs both models and both optimizers are checkpointed.

    Raises:
        FileNotFoundError: ``resume``/``initialize`` requested but the model
            directory does not exist.
        ValueError: ``args.flow`` is not defined in ``flows`` or
            ``args.dataset`` is neither ``"mnist"`` nor ``"cifar10"``.
        NotImplementedError: ``resume`` or ``initialize`` requested.
    """
    args = parse()

    # set random seeds
    random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    np.random.seed(args.manual_seed)

    # prepare output directories
    base_dir = Path(args.out_dir)
    model_dir = base_dir.joinpath(args.model_name)
    if (args.resume or args.initialize) and not model_dir.exists():
        # FileNotFoundError subclasses Exception, so existing handlers still match.
        raise FileNotFoundError("Model directory for resume does not exist")
    if not (args.resume or args.initialize) and model_dir.exists():
        c = ""
        while c not in ("y", "n"):
            c = input("Model directory already exists, overwrite?").strip()

        if c == "y":
            shutil.rmtree(model_dir)
        else:
            sys.exit(0)
    model_dir.mkdir(parents=True, exist_ok=True)

    summary_writer_dir = model_dir.joinpath("runs")
    summary_writer_dir.mkdir(exist_ok=True)
    save_path = model_dir.joinpath("checkpoints")
    save_path.mkdir(exist_ok=True)

    # prepare summary writer; closed in ``finally`` even if training fails
    writer = SummaryWriter(summary_writer_dir, comment=args.writer_comment)
    try:
        # prepare data (load_dataset may update args)
        train_loader, val_loader, test_loader, args = load_dataset(
            args, flatten=args.flatten)

        # prepare flow model
        if not hasattr(flows, args.flow):
            # Previously fell through to a NameError on flow_model_template.
            raise ValueError(
                f"Unknown flow {args.flow!r}: not defined in module 'flows'")
        flow_model_template = getattr(flows, args.flow)

        flow_list = [
            flow_model_template(args.zdim) for _ in range(args.num_flows)
        ]
        # NOTE(review): the prior's tensors are created on the CPU; confirm
        # NormalizingFlowModel moves them when ``.to(args.device)`` is a GPU.
        prior = torch.distributions.MultivariateNormal(torch.zeros(args.zdim),
                                                       torch.eye(args.zdim))
        flow_model = NormalizingFlowModel(prior, flow_list).to(args.device)

        # prepare autoencoder, reconstruction loss and image/latent shapes
        if args.dataset == "mnist":
            ae_model = AutoEncoder(args.xdim, args.zdim, args.units,
                                   "binary").to(args.device)
            args.imshape = (1, 28, 28)
            args.zshape = (args.zdim, )
            ae_loss = nn.BCEWithLogitsLoss(reduction="sum").to(args.device)
        elif args.dataset == "cifar10":
            ae_model = ConvAutoEncoder().to(args.device)
            args.imshape = (3, 32, 32)
            args.zshape = (8, 8, 8)
            ae_loss = nn.MSELoss(reduction="sum").to(args.device)
        else:
            # Previously fell through to a NameError on ae_model / ae_loss.
            raise ValueError(f"Unsupported dataset {args.dataset!r}")

        # setup optimizers
        ae_optimizer = optim.Adam(ae_model.parameters(), args.learning_rate)
        flow_optimizer = optim.Adam(flow_model.parameters(),
                                    args.learning_rate)

        total_epochs = max(args.vae_epochs, args.flow_epochs, args.epochs)

        if args.resume:
            raise NotImplementedError
        if args.initialize:
            raise NotImplementedError

        # training loop (epochs are 1-based)
        for epoch in trange(1, total_epochs + 1):
            if epoch <= args.vae_epochs:
                train_ae(
                    epoch,
                    train_loader,
                    ae_model,
                    ae_optimizer,
                    writer,
                    ae_loss,
                    device=args.device,
                )
                log_ae_tensorboard_images(
                    ae_model,
                    val_loader,
                    writer,
                    epoch,
                    "AE/val/Images",
                    xshape=args.imshape,
                )
                # NOTE(review): unlike train_ae, no device is passed here --
                # confirm evaluate_ae places batches on args.device itself.
                evaluate_ae(epoch, test_loader, ae_model, writer, ae_loss)

            if epoch <= args.flow_epochs:
                train_flow(
                    epoch,
                    train_loader,
                    flow_model,
                    ae_model,
                    flow_optimizer,
                    writer,
                    device=args.device,
                )

                log_flow_tensorboard_images(
                    flow_model,
                    ae_model,
                    writer,
                    epoch,
                    "Flow/sampled/Images",
                    xshape=args.imshape,
                    zshape=args.zshape,
                )

            if epoch % args.save_iter == 0:
                checkpoint_dict = {
                    "epoch": epoch,
                    "ae_optimizer": ae_optimizer.state_dict(),
                    "flow_optimizer": flow_optimizer.state_dict(),
                    "ae_model": ae_model.state_dict(),
                    "flow_model": flow_model.state_dict(),
                }
                fname = f"model_{epoch}.pt"
                save_checkpoint(checkpoint_dict, save_path, fname)
    finally:
        writer.close()
def main():
    """Train the dSprites AutoEncoder and report its loss on a held-out split.

    Reads ``config.json`` for the dataset location, device, latent size,
    batch size and lr decay. Starts from the weights at ``load_path``,
    overwrites ``save_path`` whenever the validation loss ties or beats the
    best so far, then reloads that checkpoint and evaluates it on the test
    loader.

    ``load_path``, ``save_path``, ``freeze_encoder``,
    ``augment_transform_list1`` and ``image_batch_transformation`` are not
    defined in this function (presumably module-level settings).
    """
    with open("config.json") as json_file:
        conf = json.load(json_file)
    dataset_path = os.path.join(conf['data']['dataset_path'],
                                conf['data']['dataset_file'])
    device = conf['train']['device']

    model = AutoEncoder(in_channels=1,
                        dec_channels=1,
                        latent_size=conf['model']['latent_size'])
    model = model.to(device)
    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # load on the configured one.
    model.load_state_dict(torch.load(load_path, map_location=device))

    dspites_dataset = Dspites(dataset_path)
    train_val = train_val_split(dspites_dataset)
    # Note: the 'train' part of this second split is used as the test set.
    val_test = train_val_split(train_val['val'], val_split=0.2)

    data_loader_train = DataLoader(train_val['train'],
                                   batch_size=conf['train']['batch_size'],
                                   shuffle=True,
                                   num_workers=2)
    data_loader_val = DataLoader(val_test['val'],
                                 batch_size=200,
                                 shuffle=False,
                                 num_workers=1)
    data_loader_test = DataLoader(val_test['train'],
                                  batch_size=200,
                                  shuffle=False,
                                  num_workers=1)

    print('autoencoder training')
    print('frozen encoder: ', freeze_encoder)
    print('train dataset length: ', len(train_val['train']))
    print('val dataset length: ', len(val_test['val']))
    print('test dataset length: ', len(val_test['train']))

    print('latent space size:', conf['model']['latent_size'])
    print('batch size:', conf['train']['batch_size'])

    loss_function = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    model.train()
    if freeze_encoder:
        model.freeze_encoder()

    best_validation_loss = float('inf')
    for epoch in range(25):
        # From epoch 16 on, divide the lr by conf['train']['lr_decay'] every
        # epoch, never going below 1e-5.
        if epoch > 15:
            for param in optimizer.param_groups:
                param['lr'] = max(0.00001,
                                  param['lr'] / conf['train']['lr_decay'])
                print('lr: ', param['lr'])

        loss_list = []
        model.train()

        for batch_i, batch in enumerate(data_loader_train):
            augment_transform = np.random.choice(augment_transform_list1)
            # NOTE(review): the augmented ``batch1`` is never used -- the step
            # below trains on the original ``batch``. Confirm whether
            # autoencoder_step was meant to receive ``batch1``.
            batch1 = image_batch_transformation(batch, augment_transform)
            loss = autoencoder_step(model, batch, device, loss_function)
            loss_list.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if not loss_list:
            # Would otherwise be a bare ZeroDivisionError below.
            raise RuntimeError('training data loader yielded no batches')
        mean_epoch_loss = sum(loss_list) / len(loss_list)
        model.eval()
        validation_loss = autoencoder_validation(data_loader_val, model,
                                                 device, loss_function)
        print('epoch {0}, loss: {1:2.5f}, validation: {2:2.5f}'.format(
            epoch, mean_epoch_loss, validation_loss))
        # Save whenever this epoch ties or beats the best validation loss.
        if validation_loss <= best_validation_loss:
            best_validation_loss = validation_loss
            torch.save(model.state_dict(), save_path)

    model.load_state_dict(torch.load(save_path, map_location=device))
    test_results = autoencoder_validation(data_loader_test, model, device,
                                          loss_function)
    print('test result: ', test_results)
        epoch_loss_siamese += siamese_loss.item()

        loss = siamese_loss + ae_loss
        # loss = loss_function(img_batch, decoded_img)

        # print(loss.item())
        loss.backward()

        optimizer.step()

    epoch_loss = epoch_loss_autoencoder + epoch_loss_siamese
    print(f'siamese: {epoch_loss_siamese}, autoencoder: {epoch_loss_autoencoder}, all: {epoch_loss}')
    
    intloss = int(epoch_loss * 10000) / 10000
    if epoch % config.save_frequency == 0:
        torch.save(autoencoder.state_dict(), f'{config.saved_models_folder}/autoencoder_epoch{epoch}_loss{intloss}.pth')
        torch.save(siamese_network.state_dict(), f'{config.saved_models_folder}/siamese_network_epoch{epoch}_loss{intloss}.pth')
        print('Saved models, epoch: ' + str(epoch))

# Take the first batch of the training loader for a quick sanity check: the
# siamese network is fed the same image's features twice.
first_batch = next(iter(train_data_loader), None)
if first_batch is None:
    # The former ``for ... break`` idiom left ``batch`` unbound here and
    # failed on the next line with a confusing NameError.
    raise RuntimeError('train_data_loader yielded no batches')
batch = first_batch[0]
batch = batch.to(device)
batch = transform2(batch)

print('test')
print('the same images')
img1 = batch[0]
# unsqueeze(0) restores a leading batch dimension of size 1
features, decoded_img = autoencoder(img1.unsqueeze(0))
result = siamese_network(features, features)
print(result)
Example #7
0
def main(args):
    """Train an AutoEncoder whose ``dec`` sub-module doubles as a cGAN generator.

    Each iteration over an MNIST batch performs three updates:
      1. ``netD`` on real images vs. images from ``netG.dec(noise, label)``;
      2. the parameters of ``netG.dec`` so that ``netD`` scores them as real;
      3. all of ``netG`` on reconstruction MSE plus ``center_coeff`` times the
         MSE between the ``netG.enc`` code and zero.
    Every ``args.record_pnt`` iterations (and on the very last batch) the
    current generated batch is saved under ``./samples/`` and both networks
    to ``ckpts/recent.pth``.
    """
    import os  # local import: only used to create the output directories

    # Set random seed for reproducibility
    manualSeed = 999
    #manualSeed = random.randint(1, 10000) # use if you want new results
    print("Random Seed: ", manualSeed)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)

    dataroot = args.dataroot
    workers = args.workers
    batch_size = args.batch_size
    nc = args.nc
    ngf = args.ngf
    ndf = args.ndf
    nhd = args.nhd
    num_epochs = args.num_epochs
    lr = args.lr
    beta1 = args.beta1
    ngpu = args.ngpu
    resume = args.resume
    record_pnt = args.record_pnt
    log_pnt = args.log_pnt
    mse = args.mse  # NOTE(review): read but never used below
    '''
    # We can use an image folder dataset the way we have it setup.
    # Create the dataset
    dataset = dset.ImageFolder(root=dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(image_size),
                                   transforms.CenterCrop(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    # Create the dataloader
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=True, num_workers=workers)
    '''
    dataset = dset.MNIST(
        root=dataroot,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize(0.5, 0.5),
        ]))

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers)

    # Decide which device we want to run on
    device = torch.device("cuda:0" if (
        torch.cuda.is_available() and ngpu > 0) else "cpu")

    # Create the generator
    netG = AutoEncoder(nc, ngf, nhd=nhd).to(device)

    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (ngpu > 1):
        netG = nn.DataParallel(netG, list(range(ngpu)))
    # nn.DataParallel does not expose the wrapped module's custom attributes,
    # so ``netG.enc`` / ``netG.dec`` raise AttributeError when ngpu > 1.
    # Keep a handle on the underlying AutoEncoder for sub-module access.
    netG_core = netG.module if isinstance(netG, nn.DataParallel) else netG

    # Initialize all weights with weights_init (defined elsewhere).
    netG.apply(weights_init)

    # Create the Discriminator
    netD = Discriminator(nc, ndf, ngpu).to(device)

    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (ngpu > 1):
        netD = nn.DataParallel(netD, list(range(ngpu)))

    # Initialize all weights with weights_init (defined elsewhere).
    netD.apply(weights_init)

    # resume training if args.resume is True
    if resume:
        # map_location: the checkpoint may have been written on another device.
        ckpt = torch.load('ckpts/recent.pth', map_location=device)
        netG.load_state_dict(ckpt["netG"])
        netD.load_state_dict(ckpt["netD"])

    # Initialize BCELoss function
    criterion = nn.BCELoss()
    MSE = nn.MSELoss()
    mse_coeff = 1.
    center_coeff = 0.001

    # Real/fake target values. They must be floats: BCELoss needs float
    # targets, and torch.full infers an integer dtype from an int fill value.
    real_flag = 1.
    fake_flag = 0.

    # Setup Adam optimizers: D, the generator (decoder only) and the full AE
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG_core.dec.parameters(),
                            lr=lr,
                            betas=(beta1, 0.999))
    optimizerAE = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    # Output locations written in the loop; create them up front so the first
    # save does not fail on a fresh checkout.
    os.makedirs('samples', exist_ok=True)
    os.makedirs('ckpts', exist_ok=True)

    # Training Loop

    iters = 0

    # Running sums since the last log line, divided by log_pnt when printed.
    R_errG = 0
    R_errD = 0
    R_errAE = 0
    R_std = 0  # sum of mean(hidden**2): a second moment, printed as 'std'
    R_mean = 0

    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(num_epochs):
        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD.zero_grad()
            # Format batch
            real_img, label = data
            real_img, label = real_img.to(device), to_one_hot_vector(
                10, label).to(device)

            b_size = real_img.size(0)
            flag = torch.full((b_size, ),
                              real_flag,
                              dtype=torch.float,
                              device=device)
            # Forward pass real batch through D
            output = netD(real_img, label).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, flag)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate fake image batch with G
            noise = torch.randn(b_size, nhd, 1, 1).to(device)
            fake = netG_core.dec(noise, label)
            flag.fill_(fake_flag)
            # Classify all fake batch with D (detach: no gradient into G here)
            output = netD(fake.detach(), label).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, flag)
            # Calculate the gradients for this batch
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Add the gradients from the all-real and all-fake batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG_core.dec.zero_grad()
            flag.fill_(real_flag)  # fake flags are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = netD(fake, label).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, flag)
            # Calculate gradients for G
            errG.backward()
            # Update G
            optimizerG.step()

            ############################
            # (3) Update AE network: minimize reconstruction loss
            ###########################
            netG.zero_grad()
            new_img = netG(real_img, label, label)
            hidden = netG_core.enc(real_img, label)
            # Penalty pulling the latent code towards zero.
            central_loss = MSE(hidden, torch.zeros_like(hidden))
            errAE = mse_coeff * MSE(real_img, new_img) \
                    + center_coeff * central_loss
            errAE.backward()
            optimizerAE.step()

            R_errG += errG.item()
            R_errD += errD.item()
            R_errAE += errAE.item()
            R_std += (hidden**2).mean().item()
            R_mean += hidden.mean().item()
            # Output training stats
            if i % log_pnt == 0:
                print(
                    '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tLoss_AE: %.4f\t'
                    % (epoch, num_epochs, i, len(dataloader), R_errD / log_pnt,
                       R_errG / log_pnt, R_errAE / log_pnt))
                print('mean: %.4f\tstd: %.4f\tcentral/msecoeff: %4f' %
                      (R_mean / log_pnt, R_std / log_pnt,
                       center_coeff / mse_coeff))
                R_errG = 0.
                R_errD = 0.
                R_errAE = 0.
                R_std = 0.
                R_mean = 0.

            # Periodically (and on the very last batch) save the current
            # generated batch and a checkpoint of both networks.
            if (iters % record_pnt == 0) or ((epoch == num_epochs - 1) and
                                             (i == len(dataloader) - 1)):
                vutils.save_image(
                    fake.to("cpu"),
                    './samples/image_{}.png'.format(iters // record_pnt))
                torch.save(
                    {
                        "netG": netG.state_dict(),
                        "netD": netD.state_dict(),
                        "nc": nc,
                        "ngf": ngf,
                        "ndf": ndf
                    }, 'ckpts/recent.pth')

            iters += 1
Example #8
0
        loss, emb_loss, delta_coeff, reconstr_loss = autoencoder_step(
            model, batch, device, loss_function)
        loss_list.append(loss.item())
        emb_loss_list.append(emb_loss.item())
        delta_loss_list.append(delta_coeff.item())
        reconstr_loss_list.append(reconstr_loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    mean_epoch_loss = sum(loss_list) / len(loss_list)
    mean_ebedding_loss = sum(emb_loss_list) / len(emb_loss_list)
    mean_delta_loss = sum(delta_loss_list) / len(delta_loss_list)
    mean_reconstr_loss = sum(reconstr_loss_list) / len(reconstr_loss_list)
    model.eval()
    validation_loss = autoencoder_validation(data_loader_val, model, device,
                                             loss_function)
    if epoch == 0:
        min_validation_loss = validation_loss
    else:
        min_validation_loss = min(min_validation_loss, validation_loss)
    print('epoch {0}, loss: {1:2.5f}, emb loss: {2:2.5f}, reconstr original: {3:2.5f}, validation: {4:2.5f}, reconstract: {5:2.5f}'\
          .format(epoch, mean_epoch_loss, mean_ebedding_loss, mean_delta_loss, validation_loss, mean_reconstr_loss))
    if min_validation_loss == validation_loss:
        torch.save(model.state_dict(),
                   'weights/autoencoder_enhanced_latent12.pt')

# Reload the checkpoint kept for the best validation loss seen during training
# and report its loss.
# NOTE(review): this evaluates on data_loader_val although the result is
# printed as 'test result' -- confirm whether a separate test loader was meant.
model.load_state_dict(torch.load('weights/autoencoder_enhanced_latent12.pt'))
test_results = autoencoder_validation(data_loader_val, model, device,
                                      loss_function)
print('test result: ', test_results)
Example #9
0
def main():
    """Train the image AutoEncoder with an SSIM or MS-SSIM loss.

    Trains on the CLIC ``train`` and ``valid`` splits (random 512-px crops and
    horizontal/vertical flips) and evaluates on Kodak after every epoch.
    The latest weights are written to ``latest_model.pt`` each epoch; the
    weights with the highest validation score seen so far are written to
    ``best_model.pt``.
    """
    opts = get_argparser().parse_args()

    # dataset
    train_transform = transforms.Compose([
        transforms.RandomCrop(size=512, pad_if_needed=True),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
    ])

    val_transform = transforms.Compose([transforms.ToTensor()])

    # Both CLIC splits are used for training; Kodak is the validation set.
    train_loader = data.DataLoader(
        data.ConcatDataset([
            ImageDataset(root='datasets/data/CLIC/train',
                         transform=train_transform),
            ImageDataset(root='datasets/data/CLIC/valid',
                         transform=train_transform),
        ]),
        batch_size=opts.batch_size,
        shuffle=True,
        num_workers=2,
        drop_last=True)

    val_loader = data.DataLoader(ImageDataset(root='datasets/data/kodak',
                                              transform=val_transform),
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=1)

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("Train set: %d, Val set: %d" %
          (len(train_loader.dataset), len(val_loader.dataset)))
    model = AutoEncoder(C=128, M=128, in_chan=3, out_chan=3).to(device)

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-4,
                                 weight_decay=1e-5)

    # checkpoint
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only machine.
        model.load_state_dict(torch.load(opts.ckpt, map_location=device))
    else:
        print("[!] Retrain")

    if opts.loss_type == 'ssim':
        criterion = SSIM_Loss(data_range=1.0, size_average=True, channel=3)
    else:
        criterion = MS_SSIM_Loss(data_range=1.0,
                                 size_average=True,
                                 channel=3,
                                 nonnegative_ssim=True)

    # Highest validation score seen so far. Initialised once, before the
    # epoch loop, so best_model.pt is only overwritten on a real improvement
    # across epochs.
    best_score = 0.0

    #==========   Train Loop   ==========#
    for cur_epoch in range(opts.total_epochs):
        # =====  Train  =====
        model.train()
        for cur_step, images in enumerate(train_loader):
            images = images.to(device, dtype=torch.float32)
            optimizer.zero_grad()
            outputs = model(images)

            loss = criterion(outputs, images)
            loss.backward()

            optimizer.step()

            if cur_step % opts.log_interval == 0:
                print("Epoch %d, Batch %d/%d, loss=%.6f" %
                      (cur_epoch, cur_step, len(train_loader), loss.item()))

        # =====  Save Latest Model  =====
        torch.save(model.state_dict(), 'latest_model.pt')

        # =====  Validation  =====
        print("Val on Kodak dataset...")
        cur_score = test(opts, model, val_loader, criterion, device)
        print("%s = %.6f" % (opts.loss_type, cur_score))
        # =====  Save Best Model  =====
        if cur_score > best_score:  # higher validation score is better
            best_score = cur_score
            torch.save(model.state_dict(), 'best_model.pt')
            print("Best model saved as best_model.pt")
Example #10
0
    elif opt.model == 'unet':
        net = UNet(3, n_classes=1, filters=opt.filters)
    elif opt.model == 'unet3plus':
        net = UNet3Plus(3, n_classes=1, filters=opt.filters)

    if device == torch.device('cuda'):
        net = nn.DataParallel(net, device_ids=[0, 1, 2, 3])
        logger.info(f'use gpu: {net.device_ids}')
    net.to(device=device)

    # optimizer = optim.RMSprop(net.parameters(), lr=opt.lr, weight_decay=1e-8, momentum=0.9)
    optimizer = optim.SGD(net.parameters(), lr=opt.lr, momentum=0.9)
    criterion = nn.BCEWithLogitsLoss()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'max',
                                                     patience=opt.patience)

    train_net(net=net,
              device=device,
              train_loader=train_loader,
              val_loader=val_loader,
              epochs=opt.epochs,
              optimizer=optimizer,
              criterion=criterion,
              scheduler=scheduler,
              clip_grad=opt.clip_grad)

    model_filename = f'{opt.output_dir}/{opt.model}_{opt.filters[0]}_{time.strftime("%Y%m%d_%H%M%S", time.localtime())}.pth'
    torch.save(net.state_dict(), model_filename)
    logger.info(f'Model saved: {model_filename}!!')
Example #11
0
def main():
    """Train the AutoEncoder with a triplet objective on dSprites and report recall.

    Settings are read from ``config.json``. The dataset is split into train,
    validation and test loaders. Each of the 10 epochs decays the learning
    rate (floored at 1e-5), trains with ``triplet_step`` and measures
    validation recall; whenever that recall is the best so far and
    ``save_path`` is set, the weights are saved there. The saved weights (if
    any) are then reloaded and evaluated on the test loader.

    Relies on ``load_path``, ``save_path``, ``transform1`` and ``transform2``
    being defined outside this function.
    """
    with open("config.json") as json_file:
        conf = json.load(json_file)
    device = conf['train']['device']

    dataset_path = os.path.join(conf['data']['dataset_path'],
                                conf['data']['dataset_file'])
    dspites_dataset = Dspites(dataset_path)
    train_val = train_val_split(dspites_dataset)
    val_test = train_val_split(train_val['val'], val_split=0.2)

    data_loader_train = DataLoader(train_val['train'],
                                   batch_size=conf['train']['batch_size'],
                                   shuffle=True,
                                   num_workers=2)
    data_loader_val = DataLoader(val_test['val'],
                                 batch_size=200,
                                 shuffle=False,
                                 num_workers=1)
    data_loader_test = DataLoader(val_test['train'],
                                  batch_size=200,
                                  shuffle=False,
                                  num_workers=1)

    print('metric learning')
    print('train dataset length: ', len(train_val['train']))
    print('val dataset length: ', len(val_test['val']))
    print('test dataset length: ', len(val_test['train']))

    print('latent space size:', conf['model']['latent_size'])
    print('batch size:', conf['train']['batch_size'])
    print('margin:', conf['train']['margin'])

    model = AutoEncoder(in_channels=1,
                        dec_channels=1,
                        latent_size=conf['model']['latent_size'])
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=conf['train']['lr'])

    model.train()
    if load_path:
        model.load_state_dict(torch.load(load_path, map_location=device))

    # Recall is higher-is-better, so keep the checkpoint with the maximum.
    best_validation_recall = None
    for epoch in range(10):
        for param in optimizer.param_groups:
            param['lr'] = max(0.00001, param['lr'] / conf['train']['lr_decay'])
            print('lr: ', param['lr'])
        # recall_validation may leave the model in eval mode; switch back to
        # training mode before the epoch (a no-op if it did not).
        model.train()
        loss_list = []

        for batch in data_loader_train:
            batch = batch['image'].type(torch.FloatTensor).to(device)
            loss = triplet_step(model, batch, transform1, transform2)
            loss_list.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        recall, recall10 = recall_validation(model, data_loader_val,
                                             transform1, transform2, device)
        if best_validation_recall is None or recall >= best_validation_recall:
            best_validation_recall = recall
            if save_path:
                torch.save(model.state_dict(), save_path)
        print('epoch {0}, loss {1:2.4f}'.format(
            epoch,
            sum(loss_list) / len(loss_list)))
        print('recall@3: {0:2.4f}, recall 10%: {1:2.4f}'.format(
            recall, recall10))

    # Only reload when a checkpoint could have been written above.
    if save_path:
        model.load_state_dict(torch.load(save_path, map_location=device))
    # Pass device, consistent with the validation call inside the loop.
    recall, recall10 = recall_validation(model, data_loader_test, transform1,
                                         transform2, device)
    print('test recall@3: {0:2.4f}, recall@3 10%: {1:2.4f}'.format(
        recall, recall10))
def main():
    """Train an autoencoder and a normalizing flow over its latent space.

    Options come from ``parse.parse()``. Output goes to
    ``<out_dir>/<model_name>`` with ``runs`` (TensorBoard) and
    ``checkpoints`` sub-directories. Each epoch trains the autoencoder while
    ``epoch <= args.vae_epochs`` and the flow while
    ``epoch <= args.flow_epochs``. A checkpoint is written every
    ``args.save_iter`` epochs. When ``args.save_images`` is set, samples
    decoded from the flow are written as PNG files at the end.

    Raises:
        FileNotFoundError: ``--resume``/``--initialize`` was given but the
            model directory does not exist.
        ValueError: unknown flow class, dataset, or MNIST autoencoder type.
        NotImplementedError: ``--initialize`` is not supported yet.
    """
    args = parse.parse()

    # set random seeds
    random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    np.random.seed(args.manual_seed)

    # prepare output directories
    base_dir = Path(args.out_dir)
    model_dir = base_dir.joinpath(args.model_name)
    if (args.resume or args.initialize) and not model_dir.exists():
        raise FileNotFoundError("Model directory for resume does not exist")
    if not (args.resume or args.initialize) and model_dir.exists():
        answer = ""
        while answer not in ("y", "n"):
            answer = input("Model directory already exists, overwrite?").strip().lower()

        if answer == "y":
            shutil.rmtree(model_dir)
        else:
            sys.exit(0)
    model_dir.mkdir(parents=True, exist_ok=True)

    summary_writer_dir = model_dir.joinpath("runs")
    summary_writer_dir.mkdir(exist_ok=True)
    save_path = model_dir.joinpath("checkpoints")
    save_path.mkdir(exist_ok=True)

    # prepare summary writer
    writer = SummaryWriter(summary_writer_dir, comment=args.writer_comment)

    # prepare data
    train_loader, val_loader, test_loader, args = load_dataset(
        args, flatten=args.flatten_image
    )

    # prepare flow model; fail clearly instead of a later NameError
    if not hasattr(flows, args.flow):
        raise ValueError(f"Unknown flow class: {args.flow!r}")
    flow_model_template = getattr(flows, args.flow)

    flow_list = [flow_model_template(args.zdim) for _ in range(args.num_flows)]
    if args.permute_conv:
        # Interleave a 1x1 convolution before each flow layer.
        convs = [flows.OneByOneConv(dim=args.zdim) for _ in range(args.num_flows)]
        flow_list = list(itertools.chain(*zip(convs, flow_list)))
    if args.actnorm:
        # Interleave an ActNorm before each (possibly conv-prefixed) element.
        actnorms = [flows.ActNorm(dim=args.zdim) for _ in range(args.num_flows)]
        flow_list = list(itertools.chain(*zip(actnorms, flow_list)))
    prior = torch.distributions.MultivariateNormal(
        torch.zeros(args.zdim, device=args.device),
        torch.eye(args.zdim, device=args.device),
    )
    flow_model = NormalizingFlowModel(prior, flow_list).to(args.device)

    # prepare losses and autoencoder
    if args.dataset == "mnist":
        args.imshape = (1, 28, 28)
        if args.ae_model == "linear":
            ae_model = AutoEncoder(args.xdim, args.zdim, args.units, "binary").to(
                args.device
            )
            ae_loss = nn.BCEWithLogitsLoss(reduction="sum").to(args.device)

        elif args.ae_model == "conv":
            args.zshape = (8, 7, 7)
            ae_model = ConvAutoEncoder(
                in_channels=1,
                image_size=np.squeeze(args.imshape),
                activation=nn.Hardtanh(0, 1),
            ).to(args.device)
            ae_loss = nn.BCELoss(reduction="sum").to(args.device)

        else:
            raise ValueError(f"Unknown ae_model for mnist: {args.ae_model!r}")

    elif args.dataset == "cifar10":
        args.imshape = (3, 32, 32)
        args.zshape = (8, 8, 8)
        ae_loss = nn.MSELoss(reduction="sum").to(args.device)
        ae_model = ConvAutoEncoder(in_channels=3, image_size=args.imshape).to(
            args.device
        )

    else:
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")

    # setup optimizers
    ae_optimizer = optim.Adam(ae_model.parameters(), args.learning_rate)
    flow_optimizer = optim.Adam(flow_model.parameters(), args.learning_rate)

    total_epochs = np.max([args.vae_epochs, args.flow_epochs, args.epochs])

    if args.initialize:
        # Initialising from another run's weights is not supported yet; fail
        # before spending time loading a checkpoint that would be unused.
        raise NotImplementedError

    init_epoch = 1
    if args.resume:
        checkpoint = torch.load(args.model_path, map_location=args.device)
        flow_model.load_state_dict(checkpoint["flow_model"])
        ae_model.load_state_dict(checkpoint["ae_model"])
        flow_optimizer.load_state_dict(checkpoint["flow_optimizer"])
        ae_optimizer.load_state_dict(checkpoint["ae_optimizer"])
        # The stored epoch was already completed before saving; continue with
        # the next one instead of training it a second time.
        init_epoch = checkpoint["epoch"] + 1

    # training loop
    for epoch in trange(init_epoch, total_epochs + 1):
        if epoch <= args.vae_epochs:
            train_ae(
                epoch,
                train_loader,
                ae_model,
                ae_optimizer,
                writer,
                ae_loss,
                device=args.device,
            )
            log_ae_tensorboard_images(
                ae_model,
                val_loader,
                writer,
                epoch,
                "AE/val/Images",
                xshape=args.imshape,
            )

        if epoch <= args.flow_epochs:
            train_flow(
                epoch,
                train_loader,
                flow_model,
                ae_model,
                flow_optimizer,
                writer,
                device=args.device,
                flatten=not args.no_flatten_latent,
            )

            log_flow_tensorboard_images(
                flow_model,
                ae_model,
                writer,
                epoch,
                "Flow/sampled/Images",
                xshape=args.imshape,
                zshape=args.zshape,
            )

        if epoch % args.save_iter == 0:
            checkpoint_dict = {
                "epoch": epoch,
                "ae_optimizer": ae_optimizer.state_dict(),
                "flow_optimizer": flow_optimizer.state_dict(),
                "ae_model": ae_model.state_dict(),
                "flow_model": flow_model.state_dict(),
            }
            fname = f"model_{epoch}.pt"
            save_checkpoint(checkpoint_dict, save_path, fname)

    if args.save_images:
        # NOTE(review): the output directory is always images/mnist/ and
        # np.rint rounds each value before scaling by 255 (a 0/1 threshold if
        # decoder outputs lie in [0, 1]); both look MNIST-specific — confirm
        # before relying on this for cifar10.
        p = Path(f"images/mnist/{args.model_name}")
        p.mkdir(parents=True, exist_ok=True)
        n_samples = 10000

        print("final epoch images")
        flow_model.eval()
        ae_model.eval()
        with torch.no_grad():
            z = flow_model.sample(n_samples)
            z = z.to(next(ae_model.parameters()).device)
            xcap = ae_model.decoder.predict(z).to("cpu").view(-1, *args.imshape).numpy()
        xcap = (np.rint(xcap) * int(255)).astype(np.uint8)
        for i, im in enumerate(xcap):
            imsave(f'{p.joinpath(f"im_{i}.png").as_posix()}', np.squeeze(im))

    writer.close()
Example #13
0
def _parse_ae_classifier_args():
    """Build the command-line parser for ``main`` and return the parsed options."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--filename', required=True, help='Name/path of file')
    parser.add_argument('--save_dir', default='./outputs', help='Path to directory where results will be saved.')

    parser.add_argument('--pretrain_epochs', type=int, default=100, help="Number of epochs to pretrain model AE")
    parser.add_argument('--epochs', type=int, default=100, help="Number of epochs to train AE and classifier")
    parser.add_argument('--dims_layers_ae', type=int, nargs='+', default=[500, 100, 10],
                        help="Dimensional of layers in AE")
    parser.add_argument('--dims_layers_classifier', type=int, nargs='+', default=[10, 5],
                        help="Dimensional of layers in classifier")
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.001, help="Learning rate")
    parser.add_argument('--use_dropout', action='store_true', help="Use dropout")

    parser.add_argument('--no-cuda', action='store_true', help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1234, help='random seed (default: 1234)')

    # Required: main() indexes args.procedure[0] and tests membership in it,
    # both of which fail on argparse's None default.
    parser.add_argument('--procedure', nargs='+', required=True,
                        choices=['pre-training_ae', 'training_classifier', 'training_all'],
                        help='Procedure which you can use. Choice from: pre-training_ae, training_all, '
                             'training_classifier')
    parser.add_argument('--criterion_classifier', default='BCELoss', choices=['BCELoss', 'HingeLoss'],
                        help='Kind of loss function')
    parser.add_argument('--scale_loss', type=float, default=1., help='Weight for loss of classifier')
    parser.add_argument('--earlyStopping', type=int, default=None, help='Number of epochs to early stopping')
    parser.add_argument('--use_scheduler', action='store_true')
    return parser.parse_args()


def main():
    """Pre-train an AutoEncoder and/or train a classifier on its latent space.

    Reads the ``data_train``, ``data_test``, ``lab_train`` and ``lab_test``
    arrays from ``--filename``. It runs the stages listed in
    ``--procedure`` and logs to TensorBoard under ``<save_dir>/tensorboard``.
    Finally it appends one ``;``-separated result row (with a header when
    the file is new) to ``<save_dir>/<target>.txt``.
    """
    args = _parse_ae_classifier_args()
    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # The context manager closes the archive once the arrays have been read.
    with np.load(args.filename) as loaded:
        x_train = loaded['data_train']
        x_test = loaded['data_test']
        y_train = loaded['lab_train']
        y_test = loaded['lab_test']

    name_target = PurePosixPath(args.filename).parent.stem
    n_split = PurePosixPath(args.filename).stem
    save_dir = f'{args.save_dir}/tensorboard/{name_target}_{n_split}'
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    # -1 is a placeholder replaced by the input width (x_test.shape[1]).
    if args.dims_layers_classifier[0] == -1:
        args.dims_layers_classifier[0] = x_test.shape[1]

    model_classifier = Classifier(args.dims_layers_classifier, args.use_dropout).to(device)
    if args.criterion_classifier == 'HingeLoss':
        criterion_classifier = nn.HingeEmbeddingLoss()
        print('Use "Hinge" loss.')
    else:
        criterion_classifier = nn.BCEWithLogitsLoss()

    model_ae = None
    criterion_ae = None
    if args.procedure[0] != 'training_classifier':
        # Prepend the input width so the AE's first layer matches the data.
        args.dims_layers_ae = [x_train.shape[1]] + args.dims_layers_ae
        if args.dims_layers_ae[-1] != args.dims_layers_classifier[0]:
            raise ValueError('Dimension of latent space must be equal with dimension of input classifier!')

        model_ae = AutoEncoder(args.dims_layers_ae, args.use_dropout).to(device)
        criterion_ae = nn.MSELoss()
        optimizer = torch.optim.Adam(list(model_ae.parameters()) + list(model_classifier.parameters()), lr=args.lr)
    else:
        optimizer = torch.optim.Adam(model_classifier.parameters(), lr=args.lr)

    scheduler = None
    if args.use_scheduler:
        # NOTE(review): LambdaLR sets lr = base_lr * lr_lambda(step), so a
        # constant 0.95 only scales the base lr once and never decays it.
        # If decay was intended, `0.95 ** ep` (or MultiplicativeLR) is needed
        # — confirm against how train_step steps the scheduler.
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda ep: 0.95)

    writer = SummaryWriter(save_dir)

    total_scores = {'roc_auc': 0, 'acc': 0, 'mcc': 0, 'bal': 0, 'recall': 0,
                    'max_roc_auc': 0, 'max_acc': 0, 'max_mcc': 0, 'max_bal': 0, 'max_recall': 0,
                    'pre-fit_time': 0, 'pre-score_time': 0, 'fit_time': 0, 'score_time': 0
                    }

    dir_model_ae = f'{args.save_dir}/models_AE'
    Path(dir_model_ae).mkdir(parents=True, exist_ok=True)
    path_ae = f'{dir_model_ae}/{name_target}_{n_split}.pth'

    if 'pre-training_ae' in args.procedure:
        min_val_loss = np.inf
        epochs_no_improve = 0
        ae_saved = False  # whether path_ae was written during this run

        epoch_tqdm = tqdm(range(args.pretrain_epochs), desc="Epoch pre-train loss")
        for epoch in epoch_tqdm:
            loss_train, time_trn = train_step(model_ae, None, criterion_ae, None, optimizer, scheduler, x_train,
                                              y_train, device, writer, epoch, args.batch_size, 'pre-training_ae')
            loss_test, _ = test_step(model_ae, None, criterion_ae, None, x_test, y_test, device, writer, epoch,
                                     args.batch_size, 'pre-training_ae')

            if not np.isfinite(loss_train):
                break

            total_scores['pre-fit_time'] += time_trn

            if loss_test < min_val_loss:
                torch.save(model_ae.state_dict(), path_ae)
                ae_saved = True
                epochs_no_improve = 0
                min_val_loss = loss_test
            else:
                epochs_no_improve += 1
            epoch_tqdm.set_description(
                f"Epoch pre-train loss: {loss_train:.5f}, test loss: {loss_test:.5f} (minimal val-loss: {min_val_loss:.5f}, stop: {epochs_no_improve}|{args.earlyStopping})")
            if args.earlyStopping is not None and epoch >= args.earlyStopping and epochs_no_improve == args.earlyStopping:
                print('\033[1;31mEarly stopping in pre-training model\033[0m')
                break

        # Only reload weights written by this run; otherwise path_ae may be
        # missing or hold a stale checkpoint from an earlier run.
        if ae_saved:
            print(f"\033[1;5;33mLoad model AE from '{path_ae}'\033[0m")
            # map_location=device also restores GPU-saved weights on CPU.
            model_ae.load_state_dict(torch.load(path_ae, map_location=device))
        else:
            print('\033[1;31mNo AE checkpoint saved in this run; keeping current weights\033[0m')
        model_ae = model_ae.to(device)
        model_ae.eval()

    min_val_loss = np.inf
    epochs_no_improve = 0

    epoch = None
    stage = 'training_classifier' if 'training_classifier' in args.procedure else 'training_all'
    epoch_tqdm = tqdm(range(args.epochs), desc="Epoch train loss")
    for epoch in epoch_tqdm:
        loss_train, time_trn = train_step(model_ae, model_classifier, criterion_ae, criterion_classifier, optimizer,
                                          scheduler, x_train, y_train, device, writer, epoch, args.batch_size,
                                          stage, args.scale_loss)
        loss_test, scores_val, time_tst = test_step(model_ae, model_classifier, criterion_ae, criterion_classifier,
                                                    x_test, y_test, device, writer, epoch, args.batch_size, stage,
                                                    args.scale_loss)

        if not np.isfinite(loss_train):
            break

        total_scores['fit_time'] += time_trn
        total_scores['score_time'] += time_tst
        # max_* scores come from the epoch with the best validation ROC AUC.
        if total_scores['max_roc_auc'] < scores_val['roc_auc']:
            for key, val in scores_val.items():
                total_scores[f'max_{key}'] = val

        # Plain scores come from the epoch with the lowest validation loss.
        if loss_test < min_val_loss:
            epochs_no_improve = 0
            min_val_loss = loss_test
            for key, val in scores_val.items():
                total_scores[key] = val
        else:
            epochs_no_improve += 1
        epoch_tqdm.set_description(
            f"Epoch train loss: {loss_train:.5f}, test loss: {loss_test:.5f} (minimal val-loss: {min_val_loss:.5f}, stop: {epochs_no_improve}|{args.earlyStopping})")
        if args.earlyStopping is not None and epoch >= args.earlyStopping and epochs_no_improve == args.earlyStopping:
            print('\033[1;31mEarly stopping!\033[0m')
            break

    # epoch stays None when --epochs is 0 and the loop body never ran.
    if epoch is not None:
        total_scores['score_time'] /= epoch + 1
    writer.close()

    save_file = f'{args.save_dir}/{name_target}.txt'
    params = f'{n_split};pretrain_epochs:{args.pretrain_epochs},dims_layers_ae:{args.dims_layers_ae},' \
             f'dims_layers_classifier:{args.dims_layers_classifier},batch_size:{args.batch_size},lr:{args.lr},' \
             f'use_dropout:{args.use_dropout},procedure:{args.procedure},scale_loss:{args.scale_loss},' \
             f'earlyStopping:{args.earlyStopping}'
    head = ';'.join(['idx;params', *total_scores])
    row = ';'.join([params, *(f'{val}' for val in total_scores.values())])

    write_header = not Path(save_file).exists()
    with open(save_file, 'a', encoding='utf-8') as f:
        if write_header:
            f.write(f'{head}\n')
        f.write(f'{row}\n')