Beispiel #1
0
def evaluate_from_model(model_dir, multi_flag=False, eval_data_all=False):
    """
    Evaluating interface: 1. retrieve the flags, 2. get the data, 3. initialize the network, 4. evaluate.
    :param model_dir: The folder to retrieve the model from, relative to ``models/``.
                      A leading ``models/`` directory prefix is accepted and stripped.
    :param multi_flag: If True, call ``ntwk.evaluate_multiple_time()`` instead of a single
                       ``ntwk.evaluate()``; the MSE-distribution plot is then skipped.
    :param eval_data_all: The switch to turn on if you want to put all data in evaluation data
    :return: None
    """
    # Retrieve the flag object
    print("Retrieving flag object for parameters")
    # Only strip a genuine "models/" directory prefix. A bare startswith("models")
    # also matched folder names such as "models_v2/..." and then chopped 7
    # characters off them, pointing load_flags at the wrong folder.
    for prefix in ("models/", "models" + os.sep):
        if model_dir.startswith(prefix):
            model_dir = model_dir[len(prefix):]
            print("after removing prefix models/, now model_dir is:", model_dir)
            break
    flags = helper_functions.load_flags(os.path.join("models", model_dir))
    flags.eval_model = model_dir                    # Point the flags at the model being evaluated

    # Set up the test_ratio for the data sets that need an override
    if flags.data_set in ('ballistics', 'sine_wave', 'robotic_arm'):
        flags.test_ratio = 0.1

    # Get the data
    train_loader, test_loader = data_reader.read_data(flags, eval_data_all=eval_data_all)
    print("Making network now")

    # Make Network
    ntwk = Network(INN, flags, train_loader, test_loader, inference_mode=True, saved_model=flags.eval_model)
    print(ntwk.ckpt_dir)
    print("number of trainable parameters is :")
    pytorch_total_params = sum(p.numel() for p in ntwk.model.parameters() if p.requires_grad)
    print(pytorch_total_params)

    # Evaluation process
    print("Start eval now:")
    if multi_flag:
        ntwk.evaluate_multiple_time()
    else:
        pred_file, truth_file = ntwk.evaluate()
        # Plot the MSE distribution; only a single evaluation run yields the
        # (pred_file, truth_file) pair this plot needs.
        if flags.data_set != 'meta_material':
            plotMSELossDistrib(pred_file, truth_file, flags)
    print("Evaluation finished")

    # If gaussian, plot the scatter plot
    if flags.data_set == 'gaussian_mixture':
        Xpred = helper_functions.get_Xpred(path='data/', name=flags.eval_model)
        Ypred = helper_functions.get_Ypred(path='data/', name=flags.eval_model)

        # Plot the points scatter
        generate_Gaussian.plotData(Xpred, Ypred, save_dir='data/' + flags.eval_model.replace('/','_') + 'generation plot.png', eval_mode=True)
Beispiel #2
0
def PlotPairwiseGeometry(figname, Xpred_dir):
    """
    Plot a pair-wise scatter matrix of the predicted geometry (Xpred) to show
    the correlation between the geometry dimensions that the network learns.
    :param figname: Path of the image file to save the figure to.
    :param Xpred_dir: Directory passed to helper_functions.get_Xpred to locate the Xpred result.
    :return: None
    """
    Xpred = helper_functions.get_Xpred(Xpred_dir)
    # Elsewhere in this file get_Xpred is used as returning the prediction
    # array itself (it is indexed with [:, i]); feeding that to pd.read_csv
    # fails. Only read from disk if a path is handed back instead.
    if isinstance(Xpred, (str, os.PathLike)):
        Xpred = pd.read_csv(Xpred, header=None, delimiter=' ')
    else:
        Xpred = pd.DataFrame(Xpred)

    # scatter_matrix creates its own figure, so no plt.figure() beforehand
    # (that left an unused, never-closed empty figure behind).
    axes = pd.plotting.scatter_matrix(Xpred, alpha=0.2)
    fig = axes.flat[0].figure
    # Title the whole grid rather than only the last sub-axis.
    fig.suptitle("Pair-wise scattering of Geometery predictions")
    fig.savefig(figname)
    plt.close(fig)
Beispiel #3
0
def PlotPossibleGeoSpace(figname,
                         Xpred_dir,
                         compare_original=False,
                         calculate_diversity=None):
    """
    Plot the possible geometry space for a model evaluation result.
    It reads the Xpred (and Xtruth) result inside the Xpred_dir folder and plots
    4 subplots, each scattering column i against column i + 4 (h_i vs r_i),
    covering the 8 geometry dimensions.
    :param figname: Name of the figure; used as the suptitle and saved as figname + '.png'
    :param Xpred_dir: The directory to look for the Xpred / Xtruth files to plot
    :param compare_original: If True, also scatter Xtruth on every subplot
    :param calculate_diversity: None, 'MST' or 'AREA'. When set, the diversity of
        Xpred and Xtruth under that criterion is computed and written on the figure
    :raises ValueError: if calculate_diversity is not None, 'MST' or 'AREA'
    :return: None
    """
    # Fail fast: an unknown criterion previously crashed with a NameError on the
    # undefined diversity values only after all the plotting work was done.
    if calculate_diversity is not None and calculate_diversity not in ('MST', 'AREA'):
        raise ValueError("calculate_diversity must be None, 'MST' or 'AREA', got {!r}"
                         .format(calculate_diversity))

    Xpred = helper_functions.get_Xpred(Xpred_dir)
    Xtruth = helper_functions.get_Xtruth(Xpred_dir)

    # No plt.gca() here: it created a full-figure axes that the 2x2 subplots do
    # not replace (matplotlib >= 3.6), leaving a blank axes behind the grid.
    f = plt.figure()
    print(np.shape(Xpred))
    if calculate_diversity == 'MST':
        diversity_Xpred, diversity_Xtruth = calculate_MST(Xpred, Xtruth)
    elif calculate_diversity == 'AREA':
        diversity_Xpred, diversity_Xtruth = calculate_AREA(Xpred, Xtruth)

    for i in range(4):
        ax = f.add_subplot(2, 2, i + 1)
        ax.scatter(Xpred[:, i], Xpred[:, i + 4], s=3, label="Xpred")
        if compare_original:
            ax.scatter(Xtruth[:, i], Xtruth[:, i + 4], s=3, label="Xtruth")
        ax.set_xlabel('h{}'.format(i))
        ax.set_ylabel('r{}'.format(i))
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.legend()
    if calculate_diversity is not None:
        # Coordinates are in the last subplot's data space, well outside its
        # (-1, 1) limits — presumably to place the label above the grid.
        ax.text(-4,
                3.5,
                'Div_Xpred = {}, Div_Xtruth = {}, under criteria {}'.format(
                    diversity_Xpred, diversity_Xtruth, calculate_diversity),
                zorder=1)
    f.suptitle(figname)
    f.savefig(figname + '.png')
    plt.close(f)