Пример #1
0
def run(args, kwargs):
    """Build, train and evaluate the VAE variant selected by ``args.model_name``.

    Creates a timestamped snapshot directory under ``snapshots/``, loads the
    data with ``load_dataset``, instantiates the requested model, and hands
    everything to ``utils.perform_experiment.experiment_vae``. The arguments
    and a separator line are appended to ``vae_experiment_log.txt``.

    Args:
        args: experiment configuration namespace. ``args.model_signature`` is
            set here, and ``args`` is rebound to the value returned by
            ``load_dataset``.
        kwargs: extra keyword arguments forwarded to ``load_dataset``.

    Raises:
        ValueError: if ``args.model_name`` is not a supported model name
            (a subclass of ``Exception``, so broad handlers still catch it).
    """
    # First 19 characters of str(datetime.now()) -> 'YYYY-MM-DD HH:MM:SS'.
    args.model_signature = str(datetime.datetime.now())[0:19]

    model_name = (
        f'{args.dataset_name}_{args.model_name}_{args.prior}'
        f'(K_{args.number_components})_wu({args.warmup})'
        f'_z1_{args.z1_size}_z2_{args.z2_size}'
    )

    # DIRECTORY FOR SAVING. The trailing '/' is kept in case experiment_vae
    # appends file names by plain string concatenation (not visible here).
    snapshot_dir = 'snapshots/' + args.model_signature + '_' + model_name + '/'
    os.makedirs(snapshot_dir, exist_ok=True)

    # LOAD DATA=========================================================================================================
    print('load data')
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    # CREATE MODEL======================================================================================================
    print('create model')
    # Model modules are imported lazily so only the selected one is loaded.
    if args.model_name == 'vae':
        from models.VAE import VAE
    elif args.model_name == 'hvae_2level':
        from models.HVAE_2level import VAE
    elif args.model_name == 'convhvae_2level':
        from models.convHVAE_2level import VAE
    elif args.model_name == 'pixelhvae_2level':
        from models.PixelHVAE_2level import VAE
    else:
        raise ValueError('Wrong name of the model! Got {!r}'.format(args.model_name))

    model = VAE(args)
    if args.cuda:
        model.cuda()

    optimizer = AdamNormGrad(model.parameters(), lr=args.lr)

    # ======================================================================================================================
    print(args)
    with open('vae_experiment_log.txt', 'a') as f:
        print(args, file=f)

    # ======================================================================================================================
    print('perform experiment')
    from utils.perform_experiment import experiment_vae
    experiment_vae(args, train_loader, val_loader, test_loader, model, optimizer, snapshot_dir, model_name=args.model_name)
    # ======================================================================================================================
    print('-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-')
    with open('vae_experiment_log.txt', 'a') as f:
        print('-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-\n', file=f)
Пример #2
0
def run(args, kwargs):
    """Build, train and evaluate the VAE variant selected by ``args.model_name``.

    Creates a timestamped snapshot directory under ``snapshots/``, loads the
    data with ``load_dataset``, instantiates the requested model, and hands
    everything to ``utils.perform_experiment.experiment_vae``. The arguments
    and a separator line are appended to ``vae_experiment_log.txt``.

    Args:
        args: experiment configuration namespace. ``args.model_signature`` is
            set here, and ``args`` is rebound to the value returned by
            ``load_dataset``.
        kwargs: extra keyword arguments forwarded to ``load_dataset``.

    Raises:
        ValueError: if ``args.model_name`` is not a supported model name
            (a subclass of ``Exception``, so broad handlers still catch it).
    """
    # First 19 characters of str(datetime.now()) -> 'YYYY-MM-DD HH:MM:SS'.
    args.model_signature = str(datetime.datetime.now())[0:19]

    model_name = (
        f'{args.dataset_name}_{args.model_name}_{args.prior}'
        f'(K_{args.number_components})_wu({args.warmup})'
        f'_z1_{args.z1_size}_z2_{args.z2_size}'
    )

    # DIRECTORY FOR SAVING. The trailing '/' is kept in case experiment_vae
    # appends file names by plain string concatenation (not visible here).
    snapshot_dir = 'snapshots/' + args.model_signature + '_' + model_name + '/'
    os.makedirs(snapshot_dir, exist_ok=True)

    # LOAD DATA=========================================================================================================
    print('load data')
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    # CREATE MODEL======================================================================================================
    print('create model')
    # Model modules are imported lazily so only the selected one is loaded.
    if args.model_name == 'vae':
        from models.VAE import VAE
    elif args.model_name == 'hvae_2level':
        from models.HVAE_2level import VAE
    elif args.model_name == 'convhvae_2level':
        from models.convHVAE_2level import VAE
    elif args.model_name == 'pixelhvae_2level':
        from models.PixelHVAE_2level import VAE
    else:
        raise ValueError('Wrong name of the model! Got {!r}'.format(args.model_name))

    model = VAE(args)
    if args.cuda:
        model.cuda()

    optimizer = AdamNormGrad(model.parameters(), lr=args.lr)

    # ======================================================================================================================
    print(args)
    with open('vae_experiment_log.txt', 'a') as f:
        print(args, file=f)

    # ======================================================================================================================
    print('perform experiment')
    from utils.perform_experiment import experiment_vae
    experiment_vae(args, train_loader, val_loader, test_loader, model, optimizer, snapshot_dir, model_name=args.model_name)
    # ======================================================================================================================
    print('-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-')
    with open('vae_experiment_log.txt', 'a') as f:
        print('-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-\n', file=f)
Пример #3
0
def main():
    """Encode both splits with a trained VAE, fit an SVM on the codes, evaluate.

    Depends on module-level names defined elsewhere: ``VAE``, ``serializer``,
    ``n_latent``, ``n_epoch``, ``x_train``/``y_train``, ``x_test``/``y_test``,
    ``encode``, ``train_svm``, ``predict`` and ``evaluation``.

    Returns:
        0 once evaluation has run.
    """
    vae = VAE(n_in=1, n_latent=n_latent, n_h=64)
    serializer.load(vae, n_epoch)

    train_codes = encode(x_train, y_train, vae)
    test_codes = encode(x_test, y_test, vae)

    def split_codes(codes):
        # Presumably each row of ``encode``'s output is [latent features..., label];
        # the last column is taken as the integer class label.
        return codes[:, :-1], codes[:, -1].astype(np.int32)

    train_x, train_y = split_codes(train_codes)
    test_x, test_y = split_codes(test_codes)

    train_svm(train_x, train_y)
    predictions = predict(test_x)
    evaluation(test_y, predictions)

    return 0
Пример #4
0
        # S = SMV_kernel(patchSize, voxel_size, radius=5)
        # D = np.real(S * D)
    else:
        B0_dir = (0, 0, 1)
        patchSize = (64, 64, 64)
        # patchSize_padding = (64, 64, 128)
        patchSize_padding = patchSize
        extraction_step = (21, 21, 21)
        voxel_size = (1, 1, 1)
        D = dipole_kernel(patchSize_padding, voxel_size, B0_dir)

    # network
    vae3d = VAE(
        input_channels=1, 
        output_channels=2,
        latent_dim=latent_dim,
        use_deconv=use_deconv,
        renorm=renorm,
        flag_r_train=0
    )

    vae3d.to(device)
    print(vae3d)

    # optimizer
    optimizer = optim.Adam(vae3d.parameters(), lr = lr, betas=(0.5, 0.999))
    ms = [0.3, 0.5, 0.7, 0.9]
    ms = [np.floor(m * niter).astype(int) for m in ms]
    scheduler = MultiStepLR(optimizer, milestones = ms, gamma = 0.2)

    # logger
    logger = Logger('logs', rootDir, opt['flag_rsa'], opt['case_validation'], opt['case_test'])
Пример #5
0
    def __init__(self, params):
        """Set up data loaders, the VAE, its optimizer and optional Visdom plots.

        Args:
            params: configuration mapping. Keys read here: "loss", "cuda",
                "seed", "img_size", "path_data", "batch_size", "latent_size",
                "visualize", "alpha" and, when visualizing, "env". The call
                to ``init_vae_model`` may read further keys.
        """
        self.params = params
        available_losses = {
            'ms-ssim': ms_ssim_loss,
            'mse': mse_loss,
            'mix': mix_loss,
        }
        self.loss_function = available_losses[params["loss"]]

        # GPU only when it is both requested and present.
        self.cuda = params["cuda"] and torch.cuda.is_available()
        torch.manual_seed(params["seed"])
        # NOTE(review): benchmark=True enables cuDNN autotuning (possibly
        # non-deterministic); confirm it really fixes the numeric divergence
        # it is meant to address.
        torch.backends.cudnn.benchmark = True
        self.device = torch.device("cuda" if self.cuda else "cpu")

        side = params["img_size"]

        def build_transform(augment):
            # ToPILImage -> square resize -> [random horizontal flip] -> tensor
            steps = [transforms.ToPILImage(), transforms.Resize((side, side))]
            if augment:
                steps.append(transforms.RandomHorizontalFlip())
            steps.append(transforms.ToTensor())
            return transforms.Compose(steps)

        train_transform = build_transform(augment=True)
        val_transform = build_transform(augment=False)

        data_root = params["path_data"]
        train_set = RolloutObservationDataset(data_root, train_transform, train=True)
        val_set = RolloutObservationDataset(data_root, val_transform, train=False)

        def make_loader(dataset, shuffle):
            return torch.utils.data.DataLoader(
                dataset,
                batch_size=params["batch_size"],
                shuffle=shuffle,
                num_workers=0)

        self.train_loader = make_loader(train_set, True)
        self.eval_loader = make_loader(val_set, False)

        vae = VAE(nc=3, ngf=64, ndf=64,
                  latent_variable_size=params["latent_size"],
                  cuda=self.cuda)
        self.model = vae.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters())
        self.init_vae_model()

        self.visualize = params["visualize"]
        if self.visualize:
            env = params['env']
            self.plotter = VisdomLinePlotter(env_name=env)
            self.img_plotter = VisdomImagePlotter(env_name=env)

        # NOTE(review): any falsy alpha (None, but also 0) falls back to 1.0 —
        # confirm that alpha=0 is never meant to be used.
        alpha = params["alpha"]
        self.alpha = alpha if alpha else 1.0
Пример #6
0
class VAE_TRAINER():
    """Train, evaluate and checkpoint a convolutional VAE on rollout frames.

    All configuration is read from the ``params`` mapping given to ``__init__``.
    """

    def __init__(self, params):
        """Build data loaders, model, optimizer and optional Visdom plotters.

        Args:
            params: configuration mapping. Keys read here: "loss", "cuda",
                "seed", "img_size", "path_data", "batch_size", "latent_size",
                "visualize", "alpha" and, when visualizing, "env". Other
                methods additionally read "log_interval", "logdir",
                "noreload" and "vae_location".
        """
        self.params = params
        # Most recent epoch passed to ``train``; ``eval`` and ``checkpoint``
        # fall back to it when no epoch is given explicitly.
        self._last_epoch = None
        self.loss_function = {
            'ms-ssim': ms_ssim_loss,
            'mse': mse_loss,
            'mix': mix_loss
        }[params["loss"]]

        # Choose device: GPU only if requested and available.
        self.cuda = params["cuda"] and torch.cuda.is_available()
        torch.manual_seed(params["seed"])
        # NOTE(review): benchmark=True enables cuDNN autotuning, which can be
        # non-deterministic; confirm it addresses the divergence it targets.
        torch.backends.cudnn.benchmark = True
        self.device = torch.device("cuda" if self.cuda else "cpu")

        # Prepare data transformations
        red_size = params["img_size"]
        transform_train = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((red_size, red_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

        transform_val = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((red_size, red_size)),
            transforms.ToTensor(),
        ])

        # Initialize Data loaders
        op_dataset = RolloutObservationDataset(params["path_data"],
                                               transform_train,
                                               train=True)
        val_dataset = RolloutObservationDataset(params["path_data"],
                                                transform_val,
                                                train=False)

        self.train_loader = torch.utils.data.DataLoader(
            op_dataset,
            batch_size=params["batch_size"],
            shuffle=True,
            num_workers=0)
        self.eval_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=params["batch_size"],
            shuffle=False,
            num_workers=0)

        # Initialize model and hyperparams
        self.model = VAE(nc=3,
                         ngf=64,
                         ndf=64,
                         latent_variable_size=params["latent_size"],
                         cuda=self.cuda).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters())
        self.init_vae_model()
        self.visualize = params["visualize"]
        if self.visualize:
            self.plotter = VisdomLinePlotter(env_name=params['env'])
            self.img_plotter = VisdomImagePlotter(env_name=params['env'])
        # NOTE(review): a falsy alpha (None, but also 0) falls back to 1.0.
        self.alpha = params["alpha"] if params["alpha"] else 1.0

    def _resolve_epoch(self, epoch):
        """Return ``epoch`` or, when it is None, the last epoch given to ``train``."""
        if epoch is None:
            return getattr(self, "_last_epoch", None)
        return epoch

    def train(self, epoch):
        """Run one training epoch over ``self.train_loader``.

        Args:
            epoch: epoch index used for logging/plotting; it is remembered so
                that ``eval`` and ``checkpoint`` can default to it.
        """
        self._last_epoch = epoch
        self.model.train()
        total_mse = 0
        total_ssim = 0
        total_loss = 0
        for batch_idx, data in enumerate(self.train_loader):
            data = data.to(self.device)
            self.optimizer.zero_grad()
            recon_batch, mu, logvar = self.model(data)
            loss, mse, ssim = self.loss_function(recon_batch, data, mu, logvar,
                                                 self.alpha)
            loss.backward()

            total_loss += loss.item()
            total_ssim += ssim
            total_mse += mse
            self.optimizer.step()

            if batch_idx % self.params["log_interval"] == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data),
                    len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader), loss.item()))
                print('MSE: {} , SSIM: {:.4f}'.format(mse, ssim))

        # Approximate number of batches (fractional for a short last batch).
        step = len(self.train_loader.dataset) / float(
            self.params["batch_size"])
        mean_train_loss = total_loss / step
        mean_ssim_loss = total_ssim / step
        mean_mse_loss = total_mse / step
        print('-- Epoch: {} Average loss: {:.4f}'.format(
            epoch, mean_train_loss))
        print('-- Average MSE: {:.5f} Average SSIM: {:.4f}'.format(
            mean_mse_loss, mean_ssim_loss))
        if self.visualize:
            self.plotter.plot('loss', 'train', 'VAE Train Loss', epoch,
                              mean_train_loss)
        return

    def eval(self, epoch=None):
        """Evaluate on ``self.eval_loader`` without gradients.

        Args:
            epoch: epoch index for plot titles/x-axis; defaults to the last
                epoch passed to ``train``.

        Returns:
            The summed batch loss divided by the approximate batch count.
        """
        epoch = self._resolve_epoch(epoch)
        self.model.eval()
        eval_loss = 0
        total_mse = 0
        total_ssim = 0
        vis = True
        img_size = self.params["img_size"]
        with torch.no_grad():
            for data in self.eval_loader:
                data = data.to(self.device)
                recon_batch, mu, logvar = self.model(data)

                loss, mse, ssim = self.loss_function(recon_batch, data, mu,
                                                     logvar, self.alpha)
                eval_loss += loss.item()
                total_ssim += ssim
                total_mse += mse
                if vis:
                    # Show the first 4 inputs next to their reconstructions
                    # (first batch only). view(-1, ...) also copes with a
                    # first batch shorter than batch_size.
                    if self.visualize:
                        org_title = "Epoch: " + str(epoch)
                        comparison1 = torch.cat([
                            data[:4],
                            recon_batch.view(-1, 3, img_size, img_size)[:4]
                        ])
                        self.img_plotter.plot(comparison1, org_title)
                    vis = False

        step = len(self.eval_loader.dataset) / float(self.params["batch_size"])
        mean_eval_loss = eval_loss / step
        mean_ssim_loss = total_ssim / step
        mean_mse_loss = total_mse / step
        print('-- Eval set loss: {:.4f}'.format(mean_eval_loss))
        print('-- Eval MSE: {:.5f} Eval SSIM: {:.4f}'.format(
            mean_mse_loss, mean_ssim_loss))
        if self.visualize:
            # NOTE(review): the 'mse train'/'ssim train' labels and the
            # 'VAE MSE Loss' title on the SSIM curve look mislabeled; they are
            # left unchanged so existing Visdom windows keep matching.
            self.plotter.plot('loss', 'eval', 'VAE Eval Loss', epoch,
                              mean_eval_loss)
            self.plotter.plot('loss', 'mse train', 'VAE MSE Loss', epoch,
                              mean_mse_loss)
            self.plotter.plot('loss', 'ssim train', 'VAE MSE Loss', epoch,
                              mean_ssim_loss)

        return mean_eval_loss

    def init_vae_model(self):
        """Create the VAE log directory and optionally resume from a checkpoint.

        Unless ``params["noreload"]`` is truthy, loads ``best.tar`` from
        ``params["vae_location"]`` and restores model and optimizer state.
        """
        self.vae_dir = os.path.join(self.params["logdir"], 'vae')
        check_dir(self.vae_dir, 'samples')
        if not self.params["noreload"]:  # and os.path.exists(reload_file):
            reload_file = os.path.join(self.params["vae_location"], 'best.tar')
            # map_location lets a GPU-saved checkpoint load on a CPU-only run.
            state = torch.load(reload_file, map_location=self.device)
            print("Reloading model at epoch {}"
                  ", with eval error {}".format(state['epoch'],
                                                state['precision']))
            self.model.load_state_dict(state['state_dict'])
            self.optimizer.load_state_dict(state['optimizer'])

    def checkpoint(self, cur_best, eval_loss, epoch=None):
        """Save the latest checkpoint and, if improved, the best one.

        Args:
            cur_best: best eval loss so far (falsy means "none yet").
            eval_loss: eval loss of the current model.
            epoch: epoch stored in the checkpoint; defaults to the last epoch
                passed to ``train``.

        Returns:
            The updated best eval loss.
        """
        epoch = self._resolve_epoch(epoch)
        best_filename = os.path.join(self.vae_dir, 'best.tar')
        filename = os.path.join(self.vae_dir, 'checkpoint.tar')
        is_best = not cur_best or eval_loss < cur_best
        if is_best:
            cur_best = eval_loss

        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': self.model.state_dict(),
                'precision': eval_loss,
                'optimizer': self.optimizer.state_dict()
            }, is_best, filename, best_filename)
        return cur_best

    def plot(self, train, eval, epochs):
        """Save train/eval loss curves to ``<logdir>/vae_training_curve.png``."""
        plt.plot(epochs, train, label="train loss")
        plt.plot(epochs, eval, label="eval loss")
        plt.legend()
        plt.grid()
        plt.savefig(os.path.join(self.params["logdir"], "vae_training_curve.png"))
        plt.close()
        'param_wdropout_k': 1
    },
]

# %%
for param_setting in param_grid:
    # Start every grid point from a fresh copy of the base config.
    run_config = deepcopy(config)
    run_config.freebits_param = param_setting['free_bits_param']
    run_config.mu_force_beta_param = param_setting['mu_force_beta_param']
    run_config.param_wdropout_k = param_setting['param_wdropout_k']

    vae_kwargs = dict(
        encoder_hidden_size=run_config.vae_encoder_hidden_size,
        decoder_hidden_size=run_config.vae_decoder_hidden_size,
        latent_size=run_config.vae_latent_size,
        vocab_size=run_config.vocab_size,
        param_wdropout_k=run_config.param_wdropout_k,
        embedding_size=run_config.embedding_size,
    )
    vae = VAE(**vae_kwargs).to(run_config.device)

    optimizer = torch.optim.Adam(params=vae.parameters())

    # Initialize the results writer; the label encodes this grid point as
    # "key:value" pairs joined by '-'.
    path_to_results = f'{run_config.results_path}/vae'
    params2string = '-'.join(
        f"{key}:{value}" for key, value in param_setting.items())
    run_label = f'{run_config.run_label}--vae-{params2string}'
    results_writer = ResultsWriter(label=run_label)

    sentence_decoder = utils.make_sentence_decoder(cd.tokenizer, 1)
Пример #8
0
# Evaluate the classifier on the test set and persist its weights.
acc, all_preds = ev.evaluate_classifier(arguments, classifier, test_loader)

print('accuracy {}'.format(acc.item()))

os.makedirs('joint_models/', exist_ok=True)
# Format the Python scalar, not the tensor: formatting the tensor itself would
# embed its repr (e.g. "tensor(0.97, device='cuda:0')") in the filename.
torch.save(
    classifier.state_dict(), 'joint_models/joint_classifier_' +
    arguments.dataset_name + 'accuracy_{}'.format(acc.item()) + '.t')

# --- generator -------------------------------------------------------------
# Loading a saved generator is currently disabled; set this to True to reuse
# the weights at ``model_path`` instead of training from scratch.
load_pretrained_generator = False

model = VAE(arguments)
if arguments.cuda:
    model = model.cuda()

if load_pretrained_generator and os.path.exists(model_path):
    print('loading model...')
    # load_state_dict copies into the model's existing (possibly CUDA)
    # parameters, so loading to CPU first works on every machine.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
else:
    print('training model...')
    optimizer = AdamNormGrad(model.parameters(), lr=arguments.lr)
    tr.experiment_vae(arguments, train_loader, val_loader, test_loader, model,
                      optimizer, dr, arguments.model_name)
results = ev.evaluate_vae(arguments, model, train_loader, test_loader, 0,
Пример #9
0
def train_vae(
    model: VAE,
    optimizer,
    train_loader: DataLoader,
    valid_loader: DataLoader,
    nr_epochs: int,
    device: str,
    results_writer: ResultsWriter,
    config: Config,
    decoder,
):
    """Train a VAE for ``nr_epochs`` epochs using the settings in ``config``.

    Each batch is handled by ``train_batch_vae``. Every ``config.print_every``
    global iterations the batch losses are printed and logged; every
    ``config.train_text_gen_every`` iterations the decoded prediction and
    target are logged and printed. Whenever the within-epoch batch index is a
    multiple of ``config.validate_every`` (except at global iteration 0) the
    model is evaluated on ``valid_loader`` and saved with ``save_model`` when
    its validation ELBO beats the best seen so far. Accumulated train and
    validation results are saved at the end.
    """
    best_valid_loss = np.inf

    loss_fn = make_elbo_criterion(
        vocab_size=model.vocab_size,
        latent_size=model.latent_size,
        freebits_param=config.freebits_param,
        mu_force_beta_param=config.mu_force_beta_param,
    )

    for epoch in range(nr_epochs):
        for idx, (train_batch, batch_sent_lengths) in enumerate(train_loader):
            # Global iteration counter across epochs.
            it = epoch * len(train_loader) + idx

            batch_loss, perp, preds = train_batch_vae(
                model, optimizer, loss_fn, train_batch, device,
                config.mu_force_beta_param, batch_sent_lengths,
                results_writer, it)
            elbo_loss, kl_loss, nlll, mu_loss = batch_loss

            if it % config.print_every == 0:
                print(
                    f'Iteration: {it} || NLLL: {nlll} || Perp: {perp} || KL Loss: {kl_loss} || MuLoss: {mu_loss} || Total: {elbo_loss}'
                )
                results_writer.add_train_batch_results(
                    make_vae_results_dict(batch_loss, perp, model, config,
                                          epoch, it))

            if it % config.train_text_gen_every == 0:
                with torch.no_grad():
                    decoded_first_pred = decoder(preds.detach())
                    # Drop the first token of the target (presumably a
                    # start-of-sentence marker — TODO confirm).
                    decoded_first_true = decoder(train_batch[:, 1:])
                    results_writer.add_sentence_predictions(
                        decoded_first_pred, decoded_first_true, it)

                    print(f'VAE is generating sentences on {it}: \n')
                    print(f'\t The true sentence is: "{decoded_first_true}" \n')
                    print(f'\t The predicted sentence is: "{decoded_first_pred}" \n')

            if idx % config.validate_every == 0 and it != 0:
                print('Validating model')
                valid_losses, valid_perp = evaluate_vae(
                    model,
                    valid_loader,
                    epoch,
                    device,
                    loss_fn,
                    config.mu_force_beta_param,
                    iteration=it)

                results_writer.add_valid_results(
                    make_vae_results_dict(valid_losses, valid_perp, model,
                                          config, epoch, it))

                valid_elbo_loss, valid_kl_loss, valid_nll_loss, valid_mu_loss = valid_losses
                print(
                    f'Validation Results || Elbo loss: {valid_elbo_loss} || KL loss: {valid_kl_loss} || NLLL {valid_nll_loss} || Perp: {valid_perp} ||MU loss {valid_mu_loss}'
                )

                # Keep a copy of the best model by validation ELBO.
                if valid_elbo_loss < best_valid_loss:
                    print(f'New Best Validation score of {valid_elbo_loss}!')
                    best_valid_loss = valid_elbo_loss
                    save_model(
                        f'vae_best_mu{config.mu_force_beta_param}_wd{model.param_wdropout_k}_fb{config.freebits_param}',
                        model, optimizer, it)

                # Restore training mode after validation.
                model.train()

    results_writer.save_train_results()
    results_writer.save_valid_results()

    print('Done training the VAE')
Пример #10
0
    RolloutSequenceDataset(params["path_data"], params["seq_len"], transform, buffer_size=params["train_buffer_size"]),
    batch_size=params['batch_size'], num_workers=1, shuffle=True)
test_loader = DataLoader(
    RolloutSequenceDataset(params["path_data"], params["seq_len"], transform, train=False, buffer_size=params["test_buffer_size"]),
    batch_size=params['batch_size'], num_workers=1)

# Load the pretrained VAE (required). An explicit raise is used instead of
# assert so the check survives ``python -O``.
vae_file = os.path.join(params['logdir'], 'vae', 'best.tar')
if not os.path.exists(vae_file):
    raise FileNotFoundError("VAE Checkpoint does not exist: {}".format(vae_file))
state = torch.load(vae_file, map_location=device)
print("Loading VAE at epoch {} "
      "with test error {}".format(
          state['epoch'], state['precision']))

# NOTE(review): ngf/ndf are set from img_size here; confirm this matches the
# values the checkpoint was trained with.
vae_model = VAE(nc=3, ngf=params["img_size"], ndf=params["img_size"], latent_variable_size=params["latent_size"], cuda=cuda).to(device)
vae_model.load_state_dict(state['state_dict'])

# MD-RNN: resume from checkpoint if one exists, otherwise start fresh.
# In both cases the model is moved to ``device``.
rnn_dir = os.path.join(params['logdir'], 'mdrnn')
rnn_file = os.path.join(rnn_dir, 'best.tar')
mdrnn = MDRNN(params['latent_size'], params['action_size'], params['hidden_size'], params['num_gmm']).to(device)
if os.path.exists(rnn_file):
    state_rnn = torch.load(rnn_file, map_location=device)
    print("Loading MD-RNN at epoch {} "
          "with test error {}".format(
              state_rnn['epoch'], state_rnn['precision']))

    rnn_state_dict = dict(state_rnn['state_dict'])
    mdrnn.load_state_dict(rnn_state_dict)
def main():
    """Score a test set with a trained VAE + VGG16 teacher and plot pixel ROC.

    Parses CLI options, loads the VAE checkpoint from ``--checkpoint_dir``,
    runs ``test`` on the MVTec test split, min-max normalises the anomaly
    scores, picks the threshold maximising F1 on the pixel-level PR curve,
    prints pixel ROCAUC, and saves the ROC plot plus ``plot_fig`` output.

    Relies on module-level ``device`` and ``use_cuda`` defined elsewhere.
    """
    parser = argparse.ArgumentParser(description='Testing')
    parser.add_argument('--obj', type=str, default='.')
    parser.add_argument('--data_type', type=str, default='mvtec')
    parser.add_argument('--data_path', type=str, default='.')
    parser.add_argument('--checkpoint_dir', type=str, default='.')
    parser.add_argument("--grayscale",
                        action='store_true',
                        help='color or grayscale input image')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--img_resize', type=int, default=128)
    parser.add_argument('--crop_size', type=int, default=128)
    parser.add_argument('--seed', type=int, default=None)
    args = parser.parse_args()
    args.save_dir = './' + args.data_type + '/' + args.obj + '/vgg_feature' + '/seed_{}/'.format(
        args.seed)
    os.makedirs(args.save_dir, exist_ok=True)

    # Load model and frozen VGG16 teacher.
    args.input_channel = 1 if args.grayscale else 3
    model = VAE(input_channel=args.input_channel, z_dim=100).to(device)
    # --checkpoint_dir is used as the checkpoint *file* path.
    checkpoint = torch.load(args.checkpoint_dir, map_location=device)
    model.load_state_dict(checkpoint['model'])
    teacher = models.vgg16(pretrained=True).to(device)
    for param in teacher.parameters():
        param.requires_grad = False

    # Whether or not img_resize equals crop_size, the effective size is crop_size.
    img_size = args.crop_size
    kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}

    test_dataset = MVTecDataset(args.data_path,
                                class_name=args.obj,
                                is_train=False,
                                resize=img_size)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              **kwargs)

    scores, test_imgs, recon_imgs, gt_list, gt_mask_list = test(
        model, teacher, test_loader)

    # Min-max normalise scores to [0, 1]; guard the all-equal case, which
    # would otherwise produce 0/0 = NaN and break the PR/ROC metrics.
    scores = np.asarray(scores)
    max_anomaly_score = scores.max()
    min_anomaly_score = scores.min()
    score_range = max_anomaly_score - min_anomaly_score
    if score_range > 0:
        scores = (scores - min_anomaly_score) / score_range
    else:
        scores = np.zeros_like(scores, dtype=float)

    # Threshold = the one maximising pixel-level F1 (0 where P + R == 0).
    gt_mask = np.asarray(gt_mask_list)
    precision, recall, thresholds = precision_recall_curve(
        gt_mask.flatten(), scores.flatten())
    a = 2 * precision * recall
    b = precision + recall
    f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0)
    threshold = thresholds[np.argmax(f1)]

    fpr, tpr, _ = roc_curve(gt_mask.flatten(), scores.flatten())
    per_pixel_rocauc = roc_auc_score(gt_mask.flatten(), scores.flatten())
    print('pixel ROCAUC: %.3f' % (per_pixel_rocauc))

    plt.plot(fpr, tpr, label='%s ROCAUC: %.3f' % (args.obj, per_pixel_rocauc))
    plt.legend(loc="lower right")
    # NOTE(review): args.save_dir already ends in seed_<seed>/, so the seed
    # directory appears twice in this path — confirm that is intended.
    save_dir = os.path.join(args.save_dir, f'seed_{args.seed}',
                            'pictures_{:.4f}'.format(threshold))
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(os.path.join(save_dir, args.obj + '_roc_curve.png'), dpi=100)

    plot_fig(args, test_imgs, recon_imgs, scores, gt_mask_list, threshold,
             save_dir)
Пример #12
0
                writer.add_scalar(tag='loss/test',
                                  scalar_value=loss,
                                  global_step=i)

            likelihood_x[i * batch_size:(i + 1) * batch_size] = logsumexp(
                losses, axis=1) - np.log(number_samples)

        return np.mean(likelihood_x)


if __name__ == '__main__':
    # Script entry: load a pretrained MNIST VAE and print the value returned
    # by its ``calculate_likelihood`` on one of the MNIST loaders.
    from loaders.load_funtions import load_MNIST
    from models.VAE import VAE

    # load_MNIST returns four values; only the second loader and the dataset
    # type are used here.
    _, loader, _, dataset_type = load_MNIST('../datasets/')

    input_shape = (1, 28, 28)

    model = VAE(dimension_latent_space=50,
                input_shape=input_shape,
                dataset_type=dataset_type)
    model.load_state_dict(
        torch.load('../outputs/trained/mnist_bin_standard_50/model',
                   map_location='cpu'))
    model.eval()

    # The second argument is presumably the number of importance samples
    # per input — TODO confirm against calculate_likelihood's signature.
    print(model.calculate_likelihood(loader, 100))