Example #1
def main(args):
    print('Loading data')
    idxs = np.load(args.boards_file, allow_pickle=True)['idxs']
    print(f'Number of Boards: {len(idxs)}')

    if torch.cuda.is_available() and args.num_gpus > 0:
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    if args.shuffle:
        np.random.shuffle(idxs)

    train_idxs = idxs[:-args.num_test]
    test_idxs = idxs[-args.num_test:]

    train_loader = DataLoader(Boards(train_idxs),
                              batch_size=args.batch_size,
                              shuffle=False)
    test_loader = DataLoader(Boards(test_idxs), batch_size=args.batch_size)

    model = AutoEncoder().to(device)
    if args.model_loadname:
        model.load_state_dict(torch.load(args.model_loadname))

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    model.train()
    losses = []
    total_iters = 0

    for epoch in range(args.init_epoch, args.epochs):
        print(f'Running epoch {epoch} / {args.epochs}\n')
        for batch_idx, board in tqdm(enumerate(train_loader),
                                     total=len(train_loader)):
            board = board.to(device)
            optimizer.zero_grad()
            loss = model.loss(board)
            loss.backward()

            losses.append(loss.item())
            optimizer.step()

            if total_iters % args.log_interval == 0:
                tqdm.write(f'Loss: {loss.item()}')

            if total_iters % args.save_interval == 0:
                torch.save(
                    model.state_dict(),
                    append_to_modelname(args.model_savename, total_iters))
                plot_losses(losses, 'vis/ae_losses.png')
            total_iters += 1
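
Example #1 calls two project-specific helpers that are not shown; a minimal sketch of how `append_to_modelname` and `plot_losses` might look, assuming (purely for illustration) that the former inserts the iteration count before the file extension and the latter just writes a loss curve to disk:

import os
import matplotlib
matplotlib.use('Agg')  # headless backend so savefig works without a display
import matplotlib.pyplot as plt


def append_to_modelname(name, iters):
    # hypothetical helper: 'models/ae.pt' -> 'models/ae_10000.pt'
    base, ext = os.path.splitext(name)
    return f'{base}_{iters}{ext}'


def plot_losses(losses, path):
    # hypothetical helper: dump a simple training-loss curve
    plt.figure()
    plt.plot(losses)
    plt.xlabel('iteration')
    plt.ylabel('loss')
    plt.savefig(path)
    plt.close()
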
Example #2
def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    sample_size = 64

    ckpt = torch.load("ckpts/recent.pth")
    model = AutoEncoder(ckpt["nc"], ckpt["ngf"]).to(device)
    model.load_state_dict(ckpt["netG"])

    img_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    dataset = MNIST('./data/MNIST', transform=img_transform)
    dataloader = DataLoader(dataset, batch_size=sample_size)

    imgs, label = next(iter(dataloader))

    new_imgs = reconstruct(model, imgs, label, device)
    vutils.save_image(imgs, "./inference_img/original.png")
    vutils.save_image(new_imgs, "./inference_img/new_img.png")
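
The `reconstruct` helper used above is not shown. The checkpoint keys ("netG", "nc", "ngf") match the conditional autoencoder saved in Example #9, which is called as `model(img, label, label)` with one-hot labels, so a minimal sketch under that assumption could look like this:

import torch
import torch.nn.functional as F


def reconstruct(model, imgs, label, device):
    # hypothetical sketch: one-hot the labels, run the batch through the
    # conditional autoencoder, and return the reconstructions on the CPU
    model.eval()
    with torch.no_grad():
        imgs = imgs.to(device)
        onehot = F.one_hot(label, num_classes=10).float().to(device)
        recon = model(imgs, onehot, onehot)
    return recon.cpu()
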
Example #3
def init_model(path_to_checkpoint=PATH_TO_EMBEDDER):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = AutoEncoder(
        latent_dim=512,
        in_channels=1,
        num_hiddens=256,
        num_res_hiddens=64,
        num_res_layers=4,
        out_channels=1,
    ).to(device)

    net.load_state_dict(
        torch.load(open(path_to_checkpoint, 'rb'), map_location=device)
    )
    
    print()
    print('='*30, end='\n\n')
    print(net.eval())
    print(end='\n\n')
    print('='*30, end='\n\n')

    return device, net
def main():
    with open("config.json") as json_file:
        conf = json.load(json_file)
    dataset_path = os.path.join(conf['data']['dataset_path'],
                                conf['data']['dataset_file'])
    device = conf['train']['device']

    model = AutoEncoder(in_channels=1,
                        dec_channels=1,
                        latent_size=conf['model']['latent_size'])
    model = model.to(device)
    model.load_state_dict(torch.load(load_path))

    dspites_dataset = Dspites(dataset_path)
    train_val = train_val_split(dspites_dataset)
    val_test = train_val_split(train_val['val'], val_split=0.2)

    data_loader_train = DataLoader(train_val['train'],
                                   batch_size=conf['train']['batch_size'],
                                   shuffle=True,
                                   num_workers=2)
    data_loader_val = DataLoader(val_test['val'],
                                 batch_size=200,
                                 shuffle=False,
                                 num_workers=1)
    data_loader_test = DataLoader(val_test['train'],
                                  batch_size=200,
                                  shuffle=False,
                                  num_workers=1)

    print('autoencoder training')
    print('frozen encoder: ', freeze_encoder)
    print('train dataset length: ', len(train_val['train']))
    print('val dataset length: ', len(val_test['val']))
    print('test dataset length: ', len(val_test['train']))

    print('latent space size:', conf['model']['latent_size'])
    print('batch size:', conf['train']['batch_size'])

    loss_function = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    model.train()
    if freeze_encoder:
        model.freeze_encoder()

    for epoch in range(25):
        if epoch > 15:
            for param in optimizer.param_groups:
                param['lr'] = max(0.00001,
                                  param['lr'] / conf['train']['lr_decay'])
                print('lr: ', param['lr'])

        loss_list = []
        model.train()

        for batch_i, batch in enumerate(data_loader_train):
            augment_transform = np.random.choice(augment_transform_list1)
            batch1 = image_batch_transformation(batch, augment_transform)
            loss = autoencoder_step(model, batch1, device, loss_function)
            loss_list.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        mean_epoch_loss = sum(loss_list) / len(loss_list)
        model.eval()
        validation_loss = autoencoder_validation(data_loader_val, model,
                                                 device, loss_function)
        if epoch == 0:
            min_validation_loss = validation_loss
        else:
            min_validation_loss = min(min_validation_loss, validation_loss)
        print('epoch {0}, loss: {1:2.5f}, validation: {2:2.5f}'.format(
            epoch, mean_epoch_loss, validation_loss))
        if min_validation_loss == validation_loss:
            #pass
            torch.save(model.state_dict(), save_path)

    model.load_state_dict(torch.load(save_path))
    test_results = autoencoder_validation(data_loader_test, model, device,
                                          loss_function)
    print('test result: ', test_results)
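
The script above relies on an `autoencoder_validation` helper. A minimal sketch, assuming batches are dicts with an 'image' entry (as in the Dspites dataset used here) and that the model returns `(z, reconstruction)` as in the later examples:

import torch


def autoencoder_validation(data_loader, model, device, loss_function):
    # hypothetical sketch: mean reconstruction loss over a loader
    losses = []
    with torch.no_grad():
        for batch in data_loader:
            original = batch['image'].unsqueeze(1).type(torch.FloatTensor).to(device)
            _, decoded = model(original)
            losses.append(loss_function(decoded, original).item())
    return sum(losses) / len(losses)
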
from config import Config
from models import AutoEncoder, SiameseNetwork

config = Config()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

autoencoder = AutoEncoder(config)
siamese_network = SiameseNetwork(config)

autoencoder_file = '/autoencoder_epoch175_loss1.1991.pth'
siamese_file = '/siamese_network_epoch175_loss1.1991.pth'

if config.load_model:
    autoencoder.load_state_dict(torch.load(config.saved_models_folder + autoencoder_file))
    siamese_network.load_state_dict(torch.load(config.saved_models_folder + siamese_file))

autoencoder.to(device)
autoencoder.train()

siamese_network.to(device)
siamese_network.train()

params = list(autoencoder.parameters()) + list(siamese_network.parameters())

optimizer = torch.optim.Adam(params, lr=config.lr, betas=(0.9, 0.999))

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    # transforms.RandomCrop(size=128),
Example #6
def main(args):
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    print('Loading data')
    data = np.load(args.boards_file, allow_pickle=True)
    idxs = data['idxs']
    labels = data['values'] 
    mask = labels != None
    idxs = idxs[mask]
    labels = labels[mask]
    n = len(idxs)

    if args.shuffle:
        perm = np.random.permutation(n)
        idxs = idxs[perm]
        labels = labels[perm]

    if args.experiment is None:
        experiment = Experiment(project_name="chess-axia")
        experiment.log_parameters(vars(args))
    else:
        experiment = ExistingExperiment(previous_experiment=args.experiment)
    key = experiment.get_key()

    print(f'Number of Boards: {n}')

    if torch.cuda.is_available() and args.num_gpus > 0:
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    if args.num_train is None:
        args.num_train = n - args.num_test
    if args.num_train + args.num_test > n:
        raise ValueError('num-train and num-test sum to more than dataset size')
    train_idxs = idxs[:args.num_train]
    test_idxs = idxs[-args.num_test:]

    train_labels = labels[:args.num_train]
    test_labels = labels[-args.num_test:]
    #print(f'Win percentage: {sum(train_labels)/ len(train_labels):.1%}')
    print('Train size: ' + str(len(train_labels)))

    train_loader = DataLoader(BoardAndPieces(train_idxs, train_labels),
                              batch_size=args.batch_size, collate_fn=collate_fn,
                              shuffle=True)
    test_loader = DataLoader(BoardAndPieces(test_idxs, test_labels),
                             batch_size=args.batch_size, collate_fn=collate_fn)

    ae = AutoEncoder().to(device)
    ae_file = append_to_modelname(args.ae_model, args.ae_iter)
    ae.load_state_dict(torch.load(ae_file))

    model = BoardValuator(ae).to(device)
    loss_fn = model.loss_fn
    if args.model_loadname:
        model.load_state_dict(torch.load(args.model_loadname))

    if args.ae_freeze:
        print('Freezing AE model')
        for param in ae.parameters():
            param.requires_grad = False

    if torch.cuda.device_count() > 1 and args.num_gpus > 1:
        model = torch.nn.DataParallel(model)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    cum_loss = count = 0
    total_iters = args.init_iter

    for epoch in range(args.init_epoch, args.epochs):
        print(f'Running epoch {epoch} / {args.epochs}\n')
        #for batch_idx, (input, mask, label) in tqdm(enumerate(train_loader),
        #                             total=len(train_loader)):
        for batch_idx, (input, mask, label) in enumerate(train_loader):

            model.train()

            input = to(input, device)
            mask = to(mask, device)
            label = to(label, device)

            optimizer.zero_grad()
            output = model(input, mask)
            loss = loss_fn(output, label)
            loss.backward()
            optimizer.step()

            cum_loss += loss.item()
            # cum_acc += acc.item()
            count += 1

            if total_iters % args.log_interval == 0:
                tqdm.write(f'Epoch: {epoch}\t Iter: {total_iters:>6}\t Loss: {loss.item():.5f}')
                # experiment.log_metric('accuracy', cum_acc / count,
                #                       step=total_iters)
                experiment.log_metric('loss', cum_loss / count,
                                      step=total_iters)
                cum_loss = count = 0

            if total_iters % args.save_interval == 0:
                path = get_modelpath(args.model_dirname, key,
                                     args.model_savename, iter=total_iters,
                                     epoch=epoch)
                dirname = os.path.dirname(path)
                if not os.path.exists(dirname):
                    os.makedirs(dirname)
                torch.save(model.state_dict(), path)

            if total_iters % args.eval_interval == 0 and total_iters != 0:
                loss = eval_loss(model, test_loader, device, loss_fn)
                tqdm.write(f'\tTEST: Loss: {loss:.5f}')
                #experiment.log_metric('test accuracy', acc, step=total_iters,
                #                      epoch=epoch)
                experiment.log_metric('test loss', loss, step=total_iters,
                                      epoch=epoch)
            total_iters += 1
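
Example #6 moves batches around with a small `to` helper (and a custom `collate_fn` that is not shown). A sketch of how `to` might look, assuming it simply pushes tensors, or containers of tensors, onto the target device:

import torch


def to(x, device):
    # hypothetical sketch: recursively move tensors to the device
    if torch.is_tensor(x):
        return x.to(device)
    if isinstance(x, (list, tuple)):
        return type(x)(to(item, device) for item in x)
    if isinstance(x, dict):
        return {k: to(v, device) for k, v in x.items()}
    return x
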
Example #7
data_loader_train = DataLoader(train_val['train'], batch_size=conf['train']['batch_size'], shuffle=True, num_workers=2)
data_loader_val = DataLoader(val_test['val'], batch_size=200, shuffle=False, num_workers=1)
data_loader_test = DataLoader(val_test['train'], batch_size=200, shuffle=False, num_workers=1)

print('latent space size:', conf['model']['latent_size'])
print('batch size:', conf['train']['batch_size'])

conf['train']['batch_size'] = 128
data_loader_train = DataLoader(train_val['train'], batch_size=conf['train']['batch_size'], shuffle=True, num_workers=2)
data_loader_val = DataLoader(train_val['val'], batch_size=500, shuffle=False, num_workers=1)

model = AutoEncoder(in_channels=1, dec_channels=1, latent_size=conf['model']['latent_size'])
model = model.to(device)
#
#autoencoder_bce_loss_latent12.pt
model.load_state_dict(torch.load('weights/archi_mega_super_long_metric_learn_6.pt'))

#1 - scale (from 0.5 to 1.0), 2,3 - orientation (cos, sin), 4,5 - position (from 0 to 1)
latent_range = [4,5]
min_value = 0
max_value = 1

regressor = SimpleNet(latent_size=conf['model']['latent_size'], number_of_classes=len(latent_range))
regressor.to(device)

loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(regressor.parameters(), lr=0.001)



def regression_validation(regressor, model, data_loader):
Example #8
    parser.add_argument("--pretrain_epochs", default=20, type=int)
    parser.add_argument("--train_epochs", default=200, type=int)
    parser.add_argument("--save_dir", default="saves")
    args = parser.parse_args()
    print(args)
    epochs_pre = args.pretrain_epochs
    batch_size = args.batch_size

    x, y = load_mnist()
    autoencoder = AutoEncoder().to(device)
    ae_save_path = "saves/sim_autoencoder.pth"

    if os.path.isfile(ae_save_path):
        print("Loading {}".format(ae_save_path))
        checkpoint = torch.load(ae_save_path)
        autoencoder.load_state_dict(checkpoint["state_dict"])
    else:
        print("=> no checkpoint found at '{}'".format(ae_save_path))
        checkpoint = {"epoch": 0, "best": float("inf")}
    pretrain(
        data=x,
        model=autoencoder,
        num_epochs=epochs_pre,
        savepath=ae_save_path,
        checkpoint=checkpoint,
    )

    dec_save_path = "saves/dec.pth"
    dec = DEC(
        n_clusters=10,
        autoencoder=autoencoder,
Example #9
def main(args):
    # Set random seed for reproducibility
    manualSeed = 999
    #manualSeed = random.randint(1, 10000) # use if you want new results
    print("Random Seed: ", manualSeed)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)

    dataroot = args.dataroot
    workers = args.workers
    batch_size = args.batch_size
    nc = args.nc
    ngf = args.ngf
    ndf = args.ndf
    nhd = args.nhd
    num_epochs = args.num_epochs
    lr = args.lr
    beta1 = args.beta1
    ngpu = args.ngpu
    resume = args.resume
    record_pnt = args.record_pnt
    log_pnt = args.log_pnt
    mse = args.mse
    '''
    # We can use an image folder dataset the way we have it setup.
    # Create the dataset
    dataset = dset.ImageFolder(root=dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(image_size),
                                   transforms.CenterCrop(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    # Create the dataloader
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=True, num_workers=workers)
    '''
    dataset = dset.MNIST(
        root=dataroot,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize(0.5, 0.5),
        ]))

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers)

    # Decide which device we want to run on
    device = torch.device("cuda:0" if (
        torch.cuda.is_available() and ngpu > 0) else "cpu")

    # Create the generator
    netG = AutoEncoder(nc, ngf, nhd=nhd).to(device)

    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (ngpu > 1):
        netG = nn.DataParallel(netG, list(range(ngpu)))

    # Apply the weights_init function to randomly initialize all weights
    #  to mean=0, stdev=0.2.
    netG.apply(weights_init)

    # Create the Discriminator
    netD = Discriminator(nc, ndf, ngpu).to(device)

    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (ngpu > 1):
        netD = nn.DataParallel(netD, list(range(ngpu)))

    # Apply the weights_init function to randomly initialize all weights
    #  to mean=0, stdev=0.2.
    netD.apply(weights_init)

    #resume training if args.resume is True
    if resume:
        ckpt = torch.load('ckpts/recent.pth')
        netG.load_state_dict(ckpt["netG"])
        netD.load_state_dict(ckpt["netD"])

    # Initialize BCELoss function
    criterion = nn.BCELoss()
    MSE = nn.MSELoss()
    mse_coeff = 1.
    center_coeff = 0.001

    # Establish convention for real and fake flags during training
    # float flags so torch.full builds float targets for BCELoss
    real_flag = 1.
    fake_flag = 0.

    # Setup Adam optimizers for both G and D
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.dec.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerAE = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    # Training Loop

    # Lists to keep track of progress
    iters = 0

    R_errG = 0
    R_errD = 0
    R_errAE = 0
    R_std = 0
    R_mean = 0

    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(num_epochs):
        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD.zero_grad()
            # Format batch
            real_img, label = data
            real_img, label = real_img.to(device), to_one_hot_vector(
                10, label).to(device)

            b_size = real_img.size(0)
            flag = torch.full((b_size, ), real_flag, device=device)
            # Forward pass real batch through D
            output = netD(real_img, label).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, flag)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate fake image batch with G
            noise = torch.randn(b_size, nhd, 1, 1).to(device)
            fake = netG.dec(noise, label)
            flag.fill_(fake_flag)
            # Classify all fake batch with D
            output = netD(fake.detach(), label).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, flag)
            # Calculate the gradients for this batch
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Add the gradients from the all-real and all-fake batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.dec.zero_grad()
            flag.fill_(real_flag)  # fake flags are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = netD(fake, label).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, flag)
            # Calculate gradients for G
            errG.backward()
            # Update G
            optimizerG.step()

            ############################
            # (3) Update AE network: minimize reconstruction loss
            ###########################
            netG.zero_grad()
            new_img = netG(real_img, label, label)
            hidden = netG.enc(real_img, label)
            central_loss = MSE(hidden, torch.zeros(hidden.shape).to(device))
            errAE = mse_coeff* MSE(real_img, new_img) \
                    + center_coeff* central_loss
            errAE.backward()
            optimizerAE.step()

            R_errG += errG.item()
            R_errD += errD.item()
            R_errAE += errAE.item()
            R_std += (hidden**2).mean().item()
            R_mean += hidden.mean().item()
            # Output training stats
            if i % log_pnt == 0:
                print(
                    '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tLoss_AE: %.4f\t'
                    % (epoch, num_epochs, i, len(dataloader), R_errD / log_pnt,
                       R_errG / log_pnt, R_errAE / log_pnt))
                print('mean: %.4f\tstd: %.4f\tcentral/msecoeff: %4f' %
                      (R_mean / log_pnt, R_std / log_pnt,
                       center_coeff / mse_coeff))
                R_errG = 0.
                R_errD = 0.
                R_errAE = 0.
                R_std = 0.
                R_mean = 0.

            # Check how the generator is doing by saving G's output on fixed_noise
            if (iters % record_pnt == 0) or ((epoch == num_epochs - 1) and
                                             (i == len(dataloader) - 1)):
                vutils.save_image(
                    fake.to("cpu"),
                    './samples/image_{}.png'.format(iters // record_pnt))
                torch.save(
                    {
                        "netG": netG.state_dict(),
                        "netD": netD.state_dict(),
                        "nc": nc,
                        "ngf": ngf,
                        "ndf": ndf
                    }, 'ckpts/recent.pth')

            iters += 1
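
Two helpers that Example #9 assumes but does not define, sketched here as assumptions: `to_one_hot_vector` mapping integer labels to one-hot float vectors, and `weights_init` following the usual DCGAN-style normal initialization:

import torch
import torch.nn as nn
import torch.nn.functional as F


def to_one_hot_vector(num_class, label):
    # hypothetical sketch: integer class labels -> one-hot float tensor
    return F.one_hot(label, num_classes=num_class).float()


def weights_init(m):
    # hypothetical sketch: DCGAN-style init (normal weights, std 0.02)
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
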
Example #10
def main():
    opts = get_argparser().parse_args()

    # dataset
    train_transform = transforms.Compose([
        transforms.RandomCrop(size=512, pad_if_needed=True),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
    ])

    val_transform = transforms.Compose([transforms.ToTensor()])

    train_loader = data.DataLoader(data.ConcatDataset([
        ImageDataset(root='datasets/data/CLIC/train',
                     transform=train_transform),
        ImageDataset(root='datasets/data/CLIC/valid',
                     transform=train_transform),
    ]),
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=2,
                                   drop_last=True)

    val_loader = data.DataLoader(ImageDataset(root='datasets/data/kodak',
                                              transform=val_transform),
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=1)

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("Train set: %d, Val set: %d" %
          (len(train_loader.dataset), len(val_loader.dataset)))
    model = AutoEncoder(C=128, M=128, in_chan=3, out_chan=3).to(device)

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-4,
                                 weight_decay=1e-5)

    # checkpoint
    best_score = 0.0
    cur_epoch = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        model.load_state_dict(torch.load(opts.ckpt))
    else:
        print("[!] Retrain")

    if opts.loss_type == 'ssim':
        criterion = SSIM_Loss(data_range=1.0, size_average=True, channel=3)
    else:
        criterion = MS_SSIM_Loss(data_range=1.0,
                                 size_average=True,
                                 channel=3,
                                 nonnegative_ssim=True)

    #==========   Train Loop   ==========#
    for cur_epoch in range(opts.total_epochs):
        # =====  Train  =====
        model.train()
        for cur_step, images in enumerate(train_loader):
            images = images.to(device, dtype=torch.float32)
            optimizer.zero_grad()
            outputs = model(images)

            loss = criterion(outputs, images)
            loss.backward()

            optimizer.step()

            if (cur_step) % opts.log_interval == 0:
                print("Epoch %d, Batch %d/%d, loss=%.6f" %
                      (cur_epoch, cur_step, len(train_loader), loss.item()))

        # =====  Save Latest Model  =====
        torch.save(model.state_dict(), 'latest_model.pt')

        # =====  Validation  =====
        print("Val on Kodak dataset...")
        cur_score = test(opts, model, val_loader, criterion, device)
        print("%s = %.6f" % (opts.loss_type, cur_score))
        # =====  Save Best Model  =====
        if cur_score > best_score:  # save best model
            best_score = cur_score
            torch.save(model.state_dict(), 'best_model.pt')
            print("Best model saved as best_model.pt")
Example #11
def main():
    loss_function = nn.BCELoss()

    with open("config.json") as json_file:
        conf = json.load(json_file)
    device = conf['train']['device']

    dataset_path = os.path.join(conf['data']['dataset_path'],
                                conf['data']['dataset_file'])
    dspites_dataset = Dspites(dataset_path)
    train_val = train_val_split(dspites_dataset)
    val_test = train_val_split(train_val['val'], val_split=0.2)

    data_loader_train = DataLoader(train_val['train'],
                                   batch_size=conf['train']['batch_size'],
                                   shuffle=True,
                                   num_workers=2)
    data_loader_val = DataLoader(val_test['val'],
                                 batch_size=200,
                                 shuffle=False,
                                 num_workers=1)
    data_loader_test = DataLoader(val_test['train'],
                                  batch_size=200,
                                  shuffle=False,
                                  num_workers=1)

    print('metric learning')
    print('train dataset length: ', len(train_val['train']))
    print('val dataset length: ', len(val_test['val']))
    print('test dataset length: ', len(val_test['train']))

    print('latent space size:', conf['model']['latent_size'])
    print('batch size:', conf['train']['batch_size'])
    print('margin:', conf['train']['margin'])

    loss_list = []
    model = AutoEncoder(in_channels=1,
                        dec_channels=1,
                        latent_size=conf['model']['latent_size'])
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=conf['train']['lr'])

    model.train()
    if load_path:
        model.load_state_dict(torch.load(load_path))

    for epoch in range(10):
        for param in optimizer.param_groups:
            param['lr'] = max(0.00001, param['lr'] / conf['train']['lr_decay'])
            print('lr: ', param['lr'])
        loss_list = []

        for batch_i, batch in enumerate(data_loader_train):
            # if batch_i == 1000:
            #     break
            batch = batch['image']
            batch = batch.type(torch.FloatTensor)
            batch = batch.to(device)
            loss = triplet_step(model, batch, transform1, transform2)
            loss_list.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        recall, recall10 = recall_validation(model, data_loader_val,
                                             transform1, transform2, device)
        if epoch == 0:
            best_validation_recall = recall
        else:
            best_validation_recall = max(best_validation_recall, recall)
        if best_validation_recall == recall and save_path:
            torch.save(model.state_dict(), save_path)
        print('epoch {0}, loss {1:2.4f}'.format(
            epoch,
            sum(loss_list) / len(loss_list)))
        print('recall@3: {0:2.4f}, recall 10%: {1:2.4f}'.format(
            recall, recall10))

    model.load_state_dict(torch.load(save_path))
    recall, recall10 = recall_validation(model, data_loader_test, transform1,
                                         transform2, device)
    print('test recall@3: {0:2.4f}, recall@3 10%: {1:2.4f}'.format(
        recall, recall10))
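
The metric-learning loop above depends on a `triplet_step` helper and two augmentation callables that are not shown. A rough sketch, assuming the encoder output `z` is compared with a cosine-similarity triplet margin and negatives are drawn by shuffling the batch; the margin value and the exact loss form are assumptions:

import torch
import torch.nn as nn


def triplet_step(model, batch, transform1, transform2, margin=0.5):
    # hypothetical sketch: two augmented views per image; pull matching views
    # together, push a shuffled (negative) view away
    z_anchor, _ = model(transform1(batch))
    z_positive, _ = model(transform2(batch))
    z_negative = z_positive[torch.randperm(z_positive.size(0))]
    cos = nn.CosineSimilarity(dim=1)
    loss = torch.relu(cos(z_anchor, z_negative)
                      - cos(z_anchor, z_positive) + margin)
    return loss.mean()
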
def main():
    args = parse.parse()

    # set random seeds
    random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    np.random.seed(args.manual_seed)

    # prepare output directories
    base_dir = Path(args.out_dir)
    model_dir = base_dir.joinpath(args.model_name)
    if (args.resume or args.initialize) and not model_dir.exists():
        raise Exception("Model directory for resume does not exist")
    if not (args.resume or args.initialize) and model_dir.exists():
        c = ""
        while c != "y" and c != "n":
            c = input("Model directory already exists, overwrite? (y/n) ").strip().lower()

        if c == "y":
            shutil.rmtree(model_dir)
        else:
            sys.exit(0)
    model_dir.mkdir(parents=True, exist_ok=True)

    summary_writer_dir = model_dir.joinpath("runs")
    summary_writer_dir.mkdir(exist_ok=True)
    save_path = model_dir.joinpath("checkpoints")
    save_path.mkdir(exist_ok=True)

    # prepare summary writer
    writer = SummaryWriter(summary_writer_dir, comment=args.writer_comment)

    # prepare data
    train_loader, val_loader, test_loader, args = load_dataset(
        args, flatten=args.flatten_image
    )

    # prepare flow model
    if hasattr(flows, args.flow):
        flow_model_template = getattr(flows, args.flow)
    else:
        raise ValueError(f'Unknown flow: {args.flow}')

    flow_list = [flow_model_template(args.zdim) for _ in range(args.num_flows)]
    if args.permute_conv:
        convs = [flows.OneByOneConv(dim=args.zdim) for _ in range(args.num_flows)]
        flow_list = list(itertools.chain(*zip(convs, flow_list)))
    if args.actnorm:
        actnorms = [flows.ActNorm(dim=args.zdim) for _ in range(args.num_flows)]
        flow_list = list(itertools.chain(*zip(actnorms, flow_list)))
    prior = torch.distributions.MultivariateNormal(
        torch.zeros(args.zdim, device=args.device),
        torch.eye(args.zdim, device=args.device),
    )
    flow_model = NormalizingFlowModel(prior, flow_list).to(args.device)

    # prepare losses and autoencoder
    if args.dataset == "mnist":
        args.imshape = (1, 28, 28)
        if args.ae_model == "linear":
            ae_model = AutoEncoder(args.xdim, args.zdim, args.units, "binary").to(
                args.device
            )
            ae_loss = nn.BCEWithLogitsLoss(reduction="sum").to(args.device)

        elif args.ae_model == "conv":
            args.zshape = (8, 7, 7)
            ae_model = ConvAutoEncoder(
                in_channels=1,
                image_size=np.squeeze(args.imshape),
                activation=nn.Hardtanh(0, 1),
            ).to(args.device)
            ae_loss = nn.BCELoss(reduction="sum").to(args.device)

    elif args.dataset == "cifar10":
        args.imshape = (3, 32, 32)
        args.zshape = (8, 8, 8)
        ae_loss = nn.MSELoss(reduction="sum").to(args.device)
        ae_model = ConvAutoEncoder(in_channels=3, image_size=args.imshape).to(
            args.device
        )

    # setup optimizers
    ae_optimizer = optim.Adam(ae_model.parameters(), args.learning_rate)
    flow_optimizer = optim.Adam(flow_model.parameters(), args.learning_rate)

    total_epochs = np.max([args.vae_epochs, args.flow_epochs, args.epochs])

    if args.resume:
        checkpoint = torch.load(args.model_path, map_location=args.device)
        flow_model.load_state_dict(checkpoint["flow_model"])
        ae_model.load_state_dict(checkpoint["ae_model"])
        flow_optimizer.load_state_dict(checkpoint["flow_optimizer"])
        ae_optimizer.load_state_dict(checkpoint["ae_optimizer"])
        init_epoch = checkpoint["epoch"]
    elif args.initialize:
        checkpoint = torch.load(args.model_path, map_location=args.device)
        flow_model.load_state_dict(checkpoint["flow_model"])
        ae_model.load_state_dict(checkpoint["ae_model"])
    else:
        init_epoch = 1

    if args.initialize:
        raise NotImplementedError

    # training loop
    for epoch in trange(init_epoch, total_epochs + 1):
        if epoch <= args.vae_epochs:
            train_ae(
                epoch,
                train_loader,
                ae_model,
                ae_optimizer,
                writer,
                ae_loss,
                device=args.device,
            )
            log_ae_tensorboard_images(
                ae_model,
                val_loader,
                writer,
                epoch,
                "AE/val/Images",
                xshape=args.imshape,
            )
            # evaluate_ae(epoch, test_loader, ae_model, writer, ae_loss)

        if epoch <= args.flow_epochs:
            train_flow(
                epoch,
                train_loader,
                flow_model,
                ae_model,
                flow_optimizer,
                writer,
                device=args.device,
                flatten=not args.no_flatten_latent,
            )

            log_flow_tensorboard_images(
                flow_model,
                ae_model,
                writer,
                epoch,
                "Flow/sampled/Images",
                xshape=args.imshape,
                zshape=args.zshape,
            )

        if epoch % args.save_iter == 0:
            checkpoint_dict = {
                "epoch": epoch,
                "ae_optimizer": ae_optimizer.state_dict(),
                "flow_optimizer": flow_optimizer.state_dict(),
                "ae_model": ae_model.state_dict(),
                "flow_model": flow_model.state_dict(),
            }
            fname = f"model_{epoch}.pt"
            save_checkpoint(checkpoint_dict, save_path, fname)

    if args.save_images:
        p = Path(f"images/mnist/{args.model_name}")
        p.mkdir(parents=True, exist_ok=True)
        n_samples = 10000

        print("final epoch images")
        flow_model.eval()
        ae_model.eval()
        with torch.no_grad():
            z = flow_model.sample(n_samples)
            z = z.to(next(ae_model.parameters()).device)
            xcap = ae_model.decoder.predict(z).to("cpu").view(-1, *args.imshape).numpy()
        xcap = (np.rint(xcap) * int(255)).astype(np.uint8)
        for i, im in enumerate(xcap):
            imsave(f'{p.joinpath(f"im_{i}.png").as_posix()}', np.squeeze(im))

    writer.close()
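
A sketch of the small `save_checkpoint` helper used in the loop above, assuming `save_path` is the `pathlib.Path` created earlier:

import torch
from pathlib import Path


def save_checkpoint(checkpoint_dict, save_path, fname):
    # hypothetical sketch: write the checkpoint dict under the checkpoints dir
    torch.save(checkpoint_dict, Path(save_path).joinpath(fname))
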
Example #13
File: test.py Project: yongxinw/zsl
class Tester(object):
    def __init__(self, args):
        super(Tester, self).__init__()

        self.args = args

        self.model = AutoEncoder(args)
        self.model.load_state_dict(torch.load(args.checkpoint))
        self.model.cuda()
        self.model.eval()

        self.result = {}

        self.train_dataset = CUBDataset(split='train')
        self.test_dataset = CUBDataset(split='test')
        self.val_dataset = CUBDataset(split='val')

        self.train_loader = DataLoader(dataset=self.train_dataset,
                                       batch_size=args.batch_size)
        self.test_loader = DataLoader(dataset=self.test_dataset,
                                      batch_size=args.batch_size)
        self.val_loader = DataLoader(dataset=self.val_dataset,
                                     batch_size=100,
                                     shuffle=True)

        train_cls = self.train_dataset.get_classes('train')
        test_cls = self.test_dataset.get_classes('test')
        print("Load class")
        print(train_cls)
        print(test_cls)

        self.zsl = ZSLPrediction(train_cls, test_cls)

    # def tSNE(self):

    def conse_prediction(self, mode='test'):
        def pred(recon_x, z_tilde, output):
            cls_score = output.detach().cpu().numpy()
            print(cls_score)
            pred = self.zsl.conse_wordembedding_predict(
                cls_score, self.args.conse_top_k)
            return pred

        self.get_features(mode=mode, pred_func=pred)

        if (mode + '_pred') in self.result:

            target = self.result[mode + '_label']
            pred = self.result[mode + '_pred']
            print(target)
            print(pred)
            acc = np.sum(target == pred)
            print(acc)
            total = target.shape[0]
            print(total)
            return acc / float(total)
        else:
            raise NotImplementedError

    def knn_prediction(self, mode='test'):
        self.get_features(mode=mode, pred_func=None)

        if (mode + '_feature') in self.result:
            features = self.result[mode + '_feature']
            labels = self.result[mode + '_label']
            print(labels)
            self.zsl.construct_nn(features,
                                  labels,
                                  k=5,
                                  metric='cosine',
                                  sample_num=5)
            pred = self.zsl.nn_predict(features)

            acc = np.sum(labels == pred)
            total = labels.shape[0]

            return acc / float(total)
        else:
            raise NotImplementedError

    def tSNE(self, mode='train'):
        self.get_features(mode=mode, pred_func=None)

        total_num = self.result[mode + '_feature'].shape[0]

        random_index = np.random.permutation(total_num)

        random_index = random_index[:30]

        self.zsl.tSNE_visualization(self.result[mode+'_feature'][random_index,:], \
                                    self.result[mode+'_label'][random_index], \
                                    mode=mode,
                                    file_name= self.args.tsne_out)

    def get_features(self, mode='test', pred_func=None):
        self.model.eval()
        if pred_func is None and (mode + '_feature') in self.result:
            print("Use cached result")
            return
        if pred_func is not None and (mode + '_pred') in self.result:
            print("Use cached result")
            return

        if mode == 'train':
            loader = self.train_loader
        elif mode == 'test':
            loader = self.test_loader

        all_z = []
        all_label = []
        all_pred = []

        for data in tqdm(loader):
            # if idx == 3:
            #     break

            images = Variable(data['image64_crop'].cuda())
            target = Variable(data['class_id'].cuda())

            recon_x, z_tilde, output = self.model(images)
            target = target.detach().cpu().numpy()

            output = F.softmax(output, dim=1)

            all_label.append(target)
            all_z.append(z_tilde.detach().cpu().numpy())

            if pred_func is not None:
                pred = pred_func(recon_x, z_tilde, output)
                all_pred.append(pred)

        self.result[mode + '_feature'] = np.vstack(all_z)  # all features
        # print(all_label)
        self.result[mode + '_label'] = np.hstack(all_label)  # all test label

        if pred_func is not None:
            self.result[mode + '_pred'] = np.hstack(all_pred)
            print(self.result[mode + '_pred'].shape)
        print(self.result[mode + '_feature'].shape)
        print(self.result[mode + '_label'].shape)

    def validation_recon(self):
        self.model.eval()
        for idx, data in enumerate(self.val_loader):
            if idx == 1:
                break

            images = Variable(data['image64_crop'].cuda())
            recon_x, z_tilde, output = self.model(images)

            all_recon_images = recon_x.detach().cpu().numpy()  #N x 3 x 64 x 64
            all_origi_images = data['image64_crop'].numpy()  #N x 3 x 64 x 64

            for i in range(all_recon_images.shape[0]):
                imsave(
                    './recon/recon' + str(i) + '.png',
                    np.transpose(np.squeeze(all_recon_images[i, :, :, :]),
                                 [1, 2, 0]))
                imsave(
                    './recon/orig' + str(i) + '.png',
                    np.transpose(np.squeeze(all_origi_images[i, :, :, :]),
                                 [1, 2, 0]))

    def test_nn_image(self):
        self.get_features(mode='test', pred_func=None)
        self.get_features(mode='train', pred_func=None)

        N = 100
        random_index = np.random.permutation(
            self.result['test_feature'].shape[0])[:N]

        from sklearn.neighbors import NearestNeighbors

        neigh = NearestNeighbors()
        neigh.fit(self.result['train_feature'])

        test_feature = self.result['test_feature'][random_index, :]
        _, pred_index = neigh.kneighbors(test_feature, 1)

        for i in range(N):
            test_index = random_index[i]

            data = self.test_dataset[test_index]
            image = data['image64_crop'].numpy()  #1 x 3 x 64 x 64
            print(image.shape)
            imsave('./nn_image/test' + str(i) + '.png',
                   np.transpose(np.squeeze(image), [1, 2, 0]))

            train_index = pred_index[i][0]
            print(train_index)
            data = self.train_dataset[train_index]
            image = data['image64_crop'].numpy()  #1 x 3 x 64 x 64
            print(image.shape)
            imsave('./nn_image/train' + str(i) + '.png',
                   np.transpose(np.squeeze(image), [1, 2, 0]))
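
A possible way to drive the `Tester` above; the attribute names come from what the class reads (`checkpoint`, `batch_size`, `conse_top_k`, `tsne_out`), while the concrete values are placeholders:

from argparse import Namespace

args = Namespace(checkpoint='ckpts/autoencoder.pth',  # placeholder path
                 batch_size=64,
                 conse_top_k=5,
                 tsne_out='tsne.png')
tester = Tester(args)
print('kNN accuracy:', tester.knn_prediction(mode='test'))
print('ConSE accuracy:', tester.conse_prediction(mode='test'))
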
Example #14
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filename', required=True, help='Name/path of file')
    parser.add_argument('--savefile',
                        type=str,
                        default='./output.txt',
                        help='Path to the file where results will be saved')
    parser.add_argument('--class_weight',
                        action='store_true',
                        default=None,
                        help='Use balanced class weights')
    parser.add_argument('--seed', default=1234, help='Random seed')

    parser.add_argument('--pretrain_epochs',
                        type=int,
                        default=100,
                        help="Number of epochs to pretrain the AE model")
    parser.add_argument('--dims_layers_ae',
                        type=int,
                        nargs='+',
                        default=[500, 100, 10],
                        help="Dimensions of the layers in the AE")
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help="Learning rate")
    parser.add_argument('--use_dropout',
                        action='store_true',
                        help="Use dropout")
    parser.add_argument('--no-cuda',
                        action='store_true',
                        help='disables CUDA training')
    parser.add_argument('--earlyStopping',
                        type=int,
                        default=None,
                        help='Number of epochs to early stopping')
    parser.add_argument('--use_scheduler', action='store_true')
    args = parser.parse_args()
    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f'Device: {device.type}')

    loaded = np.load(args.filename)
    data = loaded['data']
    labels = loaded['label']
    del loaded

    name_target = PurePosixPath(args.savefile).stem
    save_dir = f'{PurePosixPath(args.savefile).parent}/tensorboard/{name_target}'
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    args.dims_layers_ae = [data.shape[1]] + args.dims_layers_ae
    model_ae = AutoEncoder(args.dims_layers_ae, args.use_dropout).to(device)

    criterion_ae = nn.MSELoss()
    optimizer = torch.optim.Adam(model_ae.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-5)

    scheduler = None
    if args.use_scheduler:
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda ep: 0.95 ** ep)  # exponential decay of the initial lr

    min_val_loss = np.inf
    epochs_no_improve = 0
    fit_time_ae = 0
    writer = SummaryWriter(save_dir)
    model_path = f'{PurePosixPath(args.savefile).parent}/models_AE/{name_target}.pth'
    Path(PurePosixPath(model_path).parent).mkdir(parents=True, exist_ok=True)
    epoch_tqdm = tqdm(range(args.pretrain_epochs), desc="Epoch loss")
    for epoch in epoch_tqdm:
        loss_train, fit_t = train_step(model_ae, criterion_ae, optimizer,
                                       scheduler, data, labels, device, writer,
                                       epoch, args.batch_size)
        fit_time_ae += fit_t
        if loss_train < min_val_loss:
            torch.save(model_ae.state_dict(), model_path)
            epochs_no_improve = 0
            min_val_loss = loss_train
        else:
            epochs_no_improve += 1
        epoch_tqdm.set_description(
            f"Epoch loss: {loss_train:.5f} (minimal loss: {min_val_loss:.5f}, stop: {epochs_no_improve}|{args.earlyStopping})"
        )
        if args.earlyStopping is not None and epoch > args.earlyStopping and epochs_no_improve == args.earlyStopping:
            print('\033[1;31mEarly stopping in AE model\033[0m')
            break

    print('===================================================')
    print('Transforming data to a lower-dimensional space')
    if device.type == "cpu":
        model_ae.load_state_dict(
            torch.load(model_path, map_location=lambda storage, loc: storage))
    else:
        model_ae.load_state_dict(torch.load(model_path))
    model_ae.eval()

    low_data = np.empty((data.shape[0], args.dims_layers_ae[-1]))
    n_batch, rest = divmod(data.shape[0], args.batch_size)
    n_batch = n_batch + 1 if rest else n_batch
    score_time_ae = 0
    with torch.no_grad():
        test_tqdm = tqdm(range(n_batch), desc="Transform data", leave=False)
        for i in test_tqdm:
            start_time = time.time()
            batch = torch.from_numpy(
                data[i * args.batch_size:(i + 1) *
                     args.batch_size, :]).float().to(device)
            # ===================forward=====================
            z, _ = model_ae(batch)
            low_data[i * args.batch_size:(i + 1) *
                     args.batch_size, :] = z.detach().cpu().numpy()
            end_time = time.time()
            score_time_ae += end_time - start_time
    print('Data shape after transformation: {}'.format(low_data.shape))
    print('===================================================')

    if args.class_weight:
        args.class_weight = 'balanced'
    else:
        args.class_weight = None

    # Split data
    sss = StratifiedShuffleSplit(n_splits=3,
                                 test_size=0.1,
                                 random_state=args.seed)
    scoring = {
        'acc': make_scorer(accuracy_score),
        'roc_auc': make_scorer(roc_auc_score, needs_proba=True),
        'mcc': make_scorer(matthews_corrcoef),
        'bal': make_scorer(balanced_accuracy_score),
        'recall': make_scorer(recall_score)
    }

    max_iters = 10000
    save_results(args.savefile,
                 'w',
                 'model',
                 None,
                 True,
                 fit_time_ae=fit_time_ae,
                 score_time_ae=score_time_ae)

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', ConvergenceWarning)
        warnings.simplefilter('ignore', RuntimeWarning)
        environ["PYTHONWARNINGS"] = "ignore"

        # Linear SVM
        print("\rLinear SVM         ", end='')
        parameters = {'C': [0.01, 0.1, 1, 10, 100]}
        # svc = svm.LinearSVC(class_weight=args.class_weight, random_state=seed)
        svc = svm.SVC(kernel='linear',
                      class_weight=args.class_weight,
                      random_state=args.seed,
                      probability=True,
                      max_iter=max_iters)
        clf = GridSearchCV(svc,
                           parameters,
                           cv=sss,
                           n_jobs=-1,
                           scoring=scoring,
                           refit='roc_auc',
                           return_train_score=True)
        try:
            clf.fit(low_data, labels)
        except Exception as e:
            if hasattr(e, 'message'):
                print(e.message)
            else:
                print(e)

        save_results(args.savefile,
                     'a',
                     'Linear SVM',
                     clf,
                     False,
                     fit_time_ae=fit_time_ae,
                     score_time_ae=score_time_ae)

        # RBF SVM
        print("\rRBF SVM             ", end='')
        parameters = {
            'kernel': ['rbf'],
            'C': [0.01, 0.1, 1, 10, 100],
            'gamma': ['scale', 'auto', 1e-2, 1e-3, 1e-4]
        }
        svc = svm.SVC(gamma="scale",
                      class_weight=args.class_weight,
                      random_state=args.seed,
                      probability=True,
                      max_iter=max_iters)
        clf = GridSearchCV(svc,
                           parameters,
                           cv=sss,
                           n_jobs=-1,
                           scoring=scoring,
                           refit='roc_auc',
                           return_train_score=True)
        try:
            clf.fit(low_data, labels)
        except Exception as e:
            if hasattr(e, 'message'):
                print(e.message)
            else:
                print(e)
        save_results(args.savefile,
                     'a',
                     'RBF SVM',
                     clf,
                     False,
                     fit_time_ae=fit_time_ae,
                     score_time_ae=score_time_ae)

        # LogisticRegression
        print("\rLogisticRegression  ", end='')
        lreg = LogisticRegression(random_state=args.seed,
                                  solver='lbfgs',
                                  multi_class='ovr',
                                  class_weight=args.class_weight,
                                  n_jobs=-1,
                                  max_iter=max_iters)
        parameters = {'C': [0.01, 0.1, 1, 10, 100]}
        clf = GridSearchCV(lreg,
                           parameters,
                           cv=sss,
                           n_jobs=-1,
                           scoring=scoring,
                           refit='roc_auc',
                           return_train_score=True)
        try:
            clf.fit(low_data, labels)
        except Exception as e:
            if hasattr(e, 'message'):
                print(e.message)
            else:
                print(e)
        save_results(args.savefile,
                     'a',
                     'LogisticRegression',
                     clf,
                     False,
                     fit_time_ae=fit_time_ae,
                     score_time_ae=score_time_ae)
        print()
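
A sketch of the `train_step` used during AE pre-training above, assuming `model_ae(batch)` returns `(z, reconstruction)` (as the transform loop later implies) and that `labels` is accepted only for interface compatibility:

import time
import numpy as np
import torch


def train_step(model_ae, criterion_ae, optimizer, scheduler, data, labels,
               device, writer, epoch, batch_size):
    # hypothetical sketch: one epoch of mini-batch reconstruction training,
    # returning (mean loss, seconds spent fitting)
    model_ae.train()
    n_batch, rest = divmod(data.shape[0], batch_size)
    n_batch = n_batch + 1 if rest else n_batch
    losses = []
    start = time.time()
    for i in range(n_batch):
        batch = torch.from_numpy(
            data[i * batch_size:(i + 1) * batch_size, :]).float().to(device)
        optimizer.zero_grad()
        _, recon = model_ae(batch)
        loss = criterion_ae(recon, batch)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    if scheduler is not None:
        scheduler.step()
    mean_loss = float(np.mean(losses))
    writer.add_scalar('AE/train_loss', mean_loss, epoch)
    return mean_loss, time.time() - start
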
Example #15
                               batch_size=conf['train']['batch_size'],
                               shuffle=True,
                               num_workers=2)
data_loader_val = DataLoader(val_test['val'],
                             batch_size=200,
                             shuffle=False,
                             num_workers=1)
data_loader_test = DataLoader(val_test['train'],
                              batch_size=200,
                              shuffle=False,
                              num_workers=1)

loss_function = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

model.load_state_dict(torch.load('weights/autoencoder_bce_loss_latent12.pt'))


def autoencoder_step(model, original_batch, device, loss_function):
    original = original_batch['image']
    original = original.unsqueeze(1)
    original = original.type(torch.FloatTensor)
    original = original.to(device)
    with torch.no_grad():
        _, decoded = model(original)
    z_decoded, decoded_decoded = model(decoded.detach())
    z_original, decoded_original = model(original)
    reconstruction_loss_original = loss_function(decoded_original, original)
    reconstruction_loss_decoded = loss_function(decoded_decoded, original)
    reconstruction_loss = reconstruction_loss_original + reconstruction_loss_decoded
    cos_margin = 0.7 * torch.ones(z_decoded.shape[0]).to(device)
        model.freeze_encoder()
        loss_function = return_loss_function(model_frozen)
        mean_epoch_loss, validation_loss = \
            decoder_step(model, loss_function, optimizer, data_loader_train, data_loader_val, device)
        print('         autoencoder loss: {0:2.5f}, BCE val: {1:2.5f}'.format(
            mean_epoch_loss, validation_loss))


if __name__ == "__main__":
    device = conf['train']['device']

    model = AutoEncoder(in_channels=1,
                        dec_channels=1,
                        latent_size=conf['model']['latent_size'])
    model = model.to(device)
    model.load_state_dict(torch.load(load_path))

    dataset_path = os.path.join(conf['data']['dataset_path'],
                                conf['data']['dataset_file'])
    dspites_dataset = Dspites(dataset_path)
    train_val = train_val_split(dspites_dataset)
    val_test = train_val_split(train_val['val'], val_split=0.2)

    data_loader_train = DataLoader(train_val['train'],
                                   batch_size=conf['train']['batch_size'],
                                   shuffle=True,
                                   num_workers=2)
    data_loader_val = DataLoader(val_test['val'],
                                 batch_size=200,
                                 shuffle=False,
                                 num_workers=1)
Example #17
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filename', required=True, help='Name/path of file')
    parser.add_argument('--save_dir', default='./outputs', help='Path to the directory where results will be saved.')

    parser.add_argument('--pretrain_epochs', type=int, default=100, help="Number of epochs to pretrain the AE model")
    parser.add_argument('--epochs', type=int, default=100, help="Number of epochs to train the AE and classifier")
    parser.add_argument('--dims_layers_ae', type=int, nargs='+', default=[500, 100, 10],
                        help="Dimensions of the layers in the AE")
    parser.add_argument('--dims_layers_classifier', type=int, nargs='+', default=[10, 5],
                        help="Dimensions of the layers in the classifier")
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.001, help="Learning rate")
    parser.add_argument('--use_dropout', action='store_true', help="Use dropout")

    parser.add_argument('--no-cuda', action='store_true', help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1234, help='Random seed (default: 1234)')

    parser.add_argument('--procedure', nargs='+', choices=['pre-training_ae', 'training_classifier', 'training_all'],
                        help='Training procedure(s) to run. Choose from: pre-training_ae, training_classifier, '
                             'training_all')
    parser.add_argument('--criterion_classifier', default='BCELoss', choices=['BCELoss', 'HingeLoss'],
                        help='Loss function for the classifier')
    parser.add_argument('--scale_loss', type=float, default=1., help='Weight for the classifier loss')
    parser.add_argument('--earlyStopping', type=int, default=None, help='Number of epochs to early stopping')
    parser.add_argument('--use_scheduler', action='store_true')
    args = parser.parse_args()
    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    loaded = np.load(args.filename)
    x_train = loaded['data_train']
    x_test = loaded['data_test']
    y_train = loaded['lab_train']
    y_test = loaded['lab_test']
    del loaded

    name_target = PurePosixPath(args.filename).parent.stem
    n_split = PurePosixPath(args.filename).stem
    save_dir = f'{args.save_dir}/tensorboard/{name_target}_{n_split}'
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    if args.dims_layers_classifier[0] == -1:
        args.dims_layers_classifier[0] = x_test.shape[1]

    model_classifier = Classifier(args.dims_layers_classifier, args.use_dropout).to(device)
    if args.criterion_classifier == 'HingeLoss':
        criterion_classifier = nn.HingeEmbeddingLoss()
        print('Use "Hinge" loss.')
    else:
        criterion_classifier = nn.BCEWithLogitsLoss()

    model_ae = None
    criterion_ae = None
    if 'training_classifier' != args.procedure[0]:
        args.dims_layers_ae = [x_train.shape[1]] + args.dims_layers_ae
        assert args.dims_layers_ae[-1] == args.dims_layers_classifier[0], \
            'Dimension of latent space must be equal with dimension of input classifier!'

        model_ae = AutoEncoder(args.dims_layers_ae, args.use_dropout).to(device)
        criterion_ae = nn.MSELoss()
        optimizer = torch.optim.Adam(list(model_ae.parameters()) + list(model_classifier.parameters()), lr=args.lr)
    else:
        optimizer = torch.optim.Adam(model_classifier.parameters(), lr=args.lr)

    scheduler = None
    if args.use_scheduler:
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda ep: 0.95 ** ep)  # exponential decay

    writer = SummaryWriter(save_dir)

    total_scores = {'roc_auc': 0, 'acc': 0, 'mcc': 0, 'bal': 0, 'recall': 0,
                    'max_roc_auc': 0, 'max_acc': 0, 'max_mcc': 0, 'max_bal': 0, 'max_recall': 0,
                    'pre-fit_time': 0, 'pre-score_time': 0, 'fit_time': 0, 'score_time': 0
                    }

    dir_model_ae = f'{args.save_dir}/models_AE'
    Path(dir_model_ae).mkdir(parents=True, exist_ok=True)
    # dir_model_classifier = f'{args.save_dir}/models_classifier'
    # Path(dir_model_classifier).mkdir(parents=True, exist_ok=True)

    path_ae = f'{dir_model_ae}/{name_target}_{n_split}.pth'
    # path_classifier = f'{dir_model_classifier}/{name_target}_{n_split}.pth'

    if 'pre-training_ae' in args.procedure:
        min_val_loss = np.inf
        epochs_no_improve = 0

        epoch_tqdm = tqdm(range(args.pretrain_epochs), desc="Epoch pre-train loss")
        for epoch in epoch_tqdm:
            loss_train, time_trn = train_step(model_ae, None, criterion_ae, None, optimizer, scheduler, x_train,
                                              y_train, device, writer, epoch, args.batch_size, 'pre-training_ae')
            loss_test, _ = test_step(model_ae, None, criterion_ae, None, x_test, y_test, device, writer, epoch,
                                     args.batch_size, 'pre-training_ae')

            if not np.isfinite(loss_train):
                break

            total_scores['pre-fit_time'] += time_trn

            if loss_test < min_val_loss:
                torch.save(model_ae.state_dict(), path_ae)
                epochs_no_improve = 0
                min_val_loss = loss_test
            else:
                epochs_no_improve += 1
            epoch_tqdm.set_description(
                f"Epoch pre-train loss: {loss_train:.5f}, test loss: {loss_test:.5f} (minimal val-loss: {min_val_loss:.5f}, stop: {epochs_no_improve}|{args.earlyStopping})")
            if args.earlyStopping is not None and epoch >= args.earlyStopping and epochs_no_improve == args.earlyStopping:
                print('\033[1;31mEarly stopping in pre-training model\033[0m')
                break
        print(f"\033[1;5;33mLoad AE model from '{path_ae}'\033[0m")
        if device.type == "cpu":
            model_ae.load_state_dict(torch.load(path_ae, map_location=lambda storage, loc: storage))
        else:
            model_ae.load_state_dict(torch.load(path_ae))
        model_ae = model_ae.to(device)
        model_ae.eval()

    min_val_loss = np.inf
    epochs_no_improve = 0

    epoch = None
    stage = 'training_classifier' if 'training_classifier' in args.procedure else 'training_all'
    epoch_tqdm = tqdm(range(args.epochs), desc="Epoch train loss")
    for epoch in epoch_tqdm:
        loss_train, time_trn = train_step(model_ae, model_classifier, criterion_ae, criterion_classifier, optimizer,
                                          scheduler, x_train, y_train, device, writer, epoch, args.batch_size,
                                          stage, args.scale_loss)
        loss_test, scores_val, time_tst = test_step(model_ae, model_classifier, criterion_ae, criterion_classifier,
                                                    x_test, y_test, device, writer, epoch, args.batch_size, stage,
                                                    args.scale_loss)

        if not np.isfinite(loss_train):
            break

        total_scores['fit_time'] += time_trn
        total_scores['score_time'] += time_tst
        if total_scores['max_roc_auc'] < scores_val['roc_auc']:
            for key, val in scores_val.items():
                total_scores[f'max_{key}'] = val

        if loss_test < min_val_loss:
            # torch.save(model_ae.state_dict(), path_ae)
            # torch.save(model_classifier.state_dict(), path_classifier)
            epochs_no_improve = 0
            min_val_loss = loss_test
            for key, val in scores_val.items():
                total_scores[key] = val
        else:
            epochs_no_improve += 1
        epoch_tqdm.set_description(
            f"Epoch train loss: {loss_train:.5f}, test loss: {loss_test:.5f} (minimal val-loss: {min_val_loss:.5f}, stop: {epochs_no_improve}|{args.earlyStopping})")
        if args.earlyStopping is not None and epoch >= args.earlyStopping and epochs_no_improve == args.earlyStopping:
            print('\033[1;31mEarly stopping!\033[0m')
            break
    total_scores['score_time'] /= epoch + 1
    writer.close()

    save_file = f'{args.save_dir}/{name_target}.txt'
    head = 'idx;params'
    temp = f'{n_split};pretrain_epochs:{args.pretrain_epochs},dims_layers_ae:{args.dims_layers_ae},' \
           f'dims_layers_classifier:{args.dims_layers_classifier},batch_size:{args.batch_size},lr:{args.lr},' \
           f'use_dropout:{args.use_dropout},procedure:{args.procedure},scale_loss:{args.scale_loss},' \
           f'earlyStopping:{args.earlyStopping}'
    for key, val in total_scores.items():
        head = head + f';{key}'
        temp = temp + f';{val}'

    not_exists = not Path(save_file).exists()
    with open(save_file, 'a') as f:
        if not_exists:
            f.write(f'{head}\n')
        f.write(f'{temp}\n')