def load_model(self, epoch=None):
        """
        Load a pre-trained model through ``fm.load_model`` and store it in
        ``self.model``.

        :param epoch: optional epoch identifier. When given, the model is
            looked up under ``<modelname>_<epoch>``; when None, the plain
            ``self.modelname`` is used. Default: None
        :return: None
        """
        if epoch is not None:
            checkpoint_name = '{}_{}'.format(self.modelname, epoch)
        else:
            checkpoint_name = '{}'.format(self.modelname)

        self.model = fm.load_model(checkpoint_name)
示例#2
0
def get_loss_bottleneck(loader, modelname, subdir, latent_dim_lst, device,
                        a_distr, a_rec, a_spar, a_disen):
    """
    Collect the loss terms of the saved models for every bottleneck dimension.

    For each entry of latent_dim_lst the model stored under
    '<modelname>_<dim>' in subdir is loaded, moved to device and evaluated
    with a cl.MMD_autoencoder_loss criterion (loss_type='l1') via get_loss.

    :param loader: loader that is passed on to get_loss
    :param modelname: base name of the saved models; the latent dimension is appended as '<modelname>_<dim>'
    :param subdir: sub directory handed to fm.load_model
    :param latent_dim_lst: iterable of latent (bottleneck) dimensions to evaluate
    :param device: device the model is moved to; also passed to the criterion and to get_loss
    :param a_distr: forwarded unchanged to cl.MMD_autoencoder_loss as a_distr
    :param a_rec: forwarded unchanged to cl.MMD_autoencoder_loss as a_rec
    :param a_spar: forwarded unchanged to cl.MMD_autoencoder_loss as a_spar
    :param a_disen: forwarded unchanged to cl.MMD_autoencoder_loss as a_disen
    :return: tuple (total_loss, rec_loss, dist_loss, spar_loss, disen_loss) of lists with one entry per latent dimension, in the order of latent_dim_lst. The entries are taken from indices 0 to 4 of the result of get_loss.
    """

    total_loss = []
    rec_loss = []
    dist_loss = []
    spar_loss = []
    disen_loss = []

    for i in latent_dim_lst:
        print('bottleneck dimension: {}'.format(i))
        # Move the loaded model to the device. The previous code called
        # .to(device) on the formatted name string, which raises
        # AttributeError because str has no such method.
        # NOTE(review): assumes fm.load_model returns an object with a
        # .to(device) method (e.g. a torch.nn.Module) -- confirm.
        model = fm.load_model('{}_{}'.format(modelname, i), subdir).to(device)
        criterion = cl.MMD_autoencoder_loss(a_distr=a_distr,
                                            a_rec=a_rec,
                                            a_spar=a_spar,
                                            a_disen=a_disen,
                                            latent_dim=i,
                                            loss_type='l1',
                                            device=device)
        losses = get_loss(loader, model, criterion, i, device)
        total_loss.append(losses[0])
        rec_loss.append(losses[1])
        dist_loss.append(losses[2])
        spar_loss.append(losses[3])
        disen_loss.append(losses[4])

    return total_loss, rec_loss, dist_loss, spar_loss, disen_loss
示例#3
0
def plot_diff_all(get_model,
                  modelname,
                  num_epoch,
                  loader,
                  latent_dim_lst,
                  device='cpu',
                  num_img=1,
                  grid_row_size=10,
                  figsize=(30, 30),
                  filename=None,
                  conditional=False):
    """
    Plot original images, reconstructed images by the INN and the difference between those images for all latent dimensions given in latent_dim_lst.

    For every latent dimension the model saved as
    '<modelname>_<lat_dim>_<num_epoch>' in the sub directory
    '<modelname>_bottleneck' is loaded with fm.load_model and plotted with
    plot_diff. If that fails, a fresh model from get_model() is moved to
    device, its weights are loaded with fm.load_weight and the plot is
    generated from that model instead.

    :param get_model: callable without arguments that returns a new model instance; only used in the fallback path
    :param modelname: base name of the saved models and of the '<modelname>_bottleneck' sub directory
    :param num_epoch: epoch identifier that is part of the saved model name
    :param loader: loader that wraps the train, test or evaluation set
    :param latent_dim_lst: list of dimensions of the latent space of which plots should be generated
    :param device: device passed to plot_diff; the fallback model is moved to it. Default: 'cpu'
    :param num_img: number of images to plot. Default: 1
    :param grid_row_size: number of images in one row in the grid
    :param figsize: currently unused by this function
    :param filename: currently unused by this function; plots are saved as '<modelname>_<lat_dim>' (or 'com_INN_mnist_<lat_dim>' in the fallback path)
    :param conditional: currently unused by this function
    :return: None
    """
    for lat_dim in latent_dim_lst:
        print("Latent Dimension: ", lat_dim)
        try:
            model = fm.load_model(
                '{}_{}_{}'.format(modelname, lat_dim, num_epoch),
                "{}_bottleneck".format(modelname))
            plot_diff(model,
                      loader,
                      lat_dim,
                      device,
                      num_img,
                      grid_row_size,
                      filename='{}_{}'.format(modelname, lat_dim))
        except Exception:
            # Catch Exception rather than everything so that KeyboardInterrupt
            # and SystemExit propagate instead of triggering the fallback.
            model = get_model().to(device)
            model = fm.load_weight(
                model, '{}_{}_{}'.format(modelname, lat_dim, num_epoch),
                '{}_bottleneck'.format(modelname))
            # NOTE(review): the fallback file name is hard-coded and does not
            # use modelname, unlike the primary path -- confirm this is intended.
            plot_diff(model,
                      loader,
                      lat_dim,
                      device,
                      num_img,
                      grid_row_size,
                      filename='com_INN_mnist_{}'.format(lat_dim))