def load_einet_state(model_path,
                     einet_file='einet.pth',
                     graph_file='einet.rg',
                     n_vars=None, n_classes=None, n_sums=None, n_input_dists=None,
                     exp_fam=None, exp_fam_args=None,
                     use_em=None, em_freq=None, em_stepsize=None,
                     graph=None,
                     map_location=None):
    """Rebuild an EiNet with ``make_einet`` and load saved parameters into it.

    Where the graph comes from:
    - If ``graph`` is given, it is used as-is.
    - Otherwise it is read with ``Graph.read_gpickle`` from
      ``os.path.join(model_path, graph_file)``.

    A fresh model is then built via ``make_einet``. Its parameters are
    restored from the object stored at ``os.path.join(model_path, einet_file)``,
    which is passed to ``einet.load_state_dict``.

    Args:
        model_path: Directory containing the saved files.
        einet_file: File name, relative to ``model_path``, of the saved state dict.
        graph_file: File name, relative to ``model_path``, of the pickled graph.
            Ignored when ``graph`` is given.
        n_vars, n_classes, n_sums, n_input_dists, exp_fam, exp_fam_args,
        use_em, em_freq, em_stepsize: Forwarded unchanged to ``make_einet``.
            Presumably these must match the settings used when the model was
            saved, or ``load_state_dict`` will reject the weights (TODO confirm).
        graph: Optional already-loaded graph. When given, ``graph_file`` is
            not read.
        map_location: Forwarded to ``torch.load``. For example, ``'cpu'`` loads
            a GPU-saved checkpoint on a CPU-only machine. ``None`` keeps
            ``torch.load``'s default behaviour.

    Returns:
        Tuple ``(einet, graph)``.

    Raises:
        ValueError: If ``graph`` is None and ``graph_file`` is empty or None.
    """
    if graph is None:
        if not graph_file:
            raise ValueError(
                "Cannot create graph: pass either `graph` or a non-empty `graph_file`")
        graph_file = os.path.join(model_path, graph_file)
        graph = Graph.read_gpickle(graph_file)

    model_file = os.path.join(model_path, einet_file)

    einet = make_einet(graph,
                       n_vars=n_vars,
                       n_classes=n_classes,
                       n_sums=n_sums,
                       n_input_dists=n_input_dists,
                       exp_fam=exp_fam, exp_fam_args=exp_fam_args,
                       use_em=use_em,
                       em_freq=em_freq, em_stepsize=em_stepsize)
    # Only a state dict is read here, so torch.load's default
    # (weights_only on newer PyTorch) is sufficient and safer.
    state_dict = torch.load(model_file, map_location=map_location)
    einet.load_state_dict(state_dict)

    print("Loaded model from {}".format(model_file))

    return einet, graph
def load_einet(model_path, einet_file='einet.pth', graph_file='einet.rg',
               map_location=None):
    """Load a whole pickled model object and, optionally, its graph.

    The object stored at ``os.path.join(model_path, einet_file)`` is returned
    as-is from ``torch.load``. This presumably is a full model saved with
    ``torch.save(model)``, not a state dict (TODO confirm against the saving
    code).

    Args:
        model_path: Directory containing the saved files.
        einet_file: File name, relative to ``model_path``, of the pickled model.
        graph_file: File name, relative to ``model_path``, of the pickled graph.
            If empty or None, no graph is loaded.
        map_location: Forwarded to ``torch.load``. For example, ``'cpu'`` loads
            a GPU-saved checkpoint on a CPU-only machine.

    Returns:
        Tuple ``(einet, graph)``. ``graph`` is None when ``graph_file`` is
        falsy.
    """
    import inspect

    model_file = os.path.join(model_path, einet_file)

    load_kwargs = {'map_location': map_location}
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle
    # arbitrary objects such as a full nn.Module. Opt out explicitly where the
    # argument exists. Older versions lack it and already unpickle fully.
    # SECURITY: full unpickling can execute arbitrary code, so only load
    # checkpoints from trusted sources.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    einet = torch.load(model_file, **load_kwargs)
    print("Loaded model from {}".format(model_file))

    graph = None
    if graph_file:
        graph_file = os.path.join(model_path, graph_file)
        graph = Graph.read_gpickle(graph_file)

    return einet, graph