def run_single_NN():
    """Train the reward network for a single IRL run, logged to ClearML.

    Loads the run inputs from ``NNIRL_param_list.pkl``. If the stored
    ``noisey_features`` flag is set, it zeroes rows 11, 33 and 63 of ``X``.
    Then, when the stored optimiser type is ``'Adam'``, it repeatedly
    evaluates ``net`` on every column of ``X``. It scores the outputs with
    the custom ``NLLFunction`` loss and takes an Adam step. This continues
    until the absolute change in loss between consecutive iterations falls
    below the stored threshold. The network weights are then saved to
    ``./NN_IRL.pth``.

    Returns:
        tuple: ``(net, finalOutput, runtime)``.
            ``net`` is the network loaded from the pickle (trained in place).
            ``finalOutput`` is the ``(X.size()[1], 1)`` output tensor from
            the last iteration, or ``None`` if no training ran.
            ``runtime`` is the elapsed wall-clock time in seconds.
    """
    task = Task.init(
        task_name='Gridworld, n=32, b=4, normal')  # init task on ClearML

    # Load run inputs. NOTE(review): pickle.load can execute arbitrary code;
    # only load NNIRL_param_list.pkl files produced by this project.
    with open("NNIRL_param_list.pkl", "rb") as open_file:
        NNIRL_param_list = pickle.load(open_file)
    threshold = NNIRL_param_list[0]
    optim_type = NNIRL_param_list[1]
    net = NNIRL_param_list[2]
    X = NNIRL_param_list[3]
    initD = NNIRL_param_list[4]
    mu_sa = NNIRL_param_list[5]
    muE = NNIRL_param_list[6]
    F = NNIRL_param_list[7]
    mdp_data = NNIRL_param_list[8]
    configuration_dict = NNIRL_param_list[9]
    truep = NNIRL_param_list[10]
    NLL_EVD_plots = NNIRL_param_list[11]
    # index 12 (example samples) is not needed by this function
    noisey_features = NNIRL_param_list[13]

    NLL = NLLFunction()  # initialise NLL
    # assign constants used by the custom loss
    NLL.F = F
    NLL.muE = muE
    NLL.mu_sa = mu_sa
    NLL.initD = initD
    NLL.mdp_data = mdp_data

    # enable configuration override by ClearML
    configuration_dict = task.connect(configuration_dict)

    start_time = time.time()  # to time execution

    # per-iteration metrics for the optional NLL / EVD plots
    NLList = []
    iterations = []
    evdList = []

    i = 0  # iteration counter
    finalOutput = None  # network output from the last iteration
    loss_value = 1000.0  # initial "previous loss"; first diff is measured against it
    diff = 1000.0  # initial diff, guarantees at least one iteration

    if noisey_features:
        # Zero every feature of rows 11, 33 and 63 of X
        # (states 12, 34 and 64, 1-indexed, when mdp_params.n=8).
        print('\n... adding noise to features at states 12, 34 and 64 ...\n')
        for row in (11, 33, 63):
            X[row, :] = torch.zeros(X.size()[1])

    if optim_type == 'Adam':
        print('\nOptimising with torch.Adam\n')
        # Adam must be given the parameters it optimises. 'base_lr' comes from
        # the (ClearML-overridable) configuration dict; 1e-3 is Adam's default.
        optimizer = torch.optim.Adam(
            net.parameters(),
            lr=configuration_dict.get('base_lr', 1e-3),
            weight_decay=1e-2)  # weight decay for l2 regularisation

        n_columns = X.size()[1]
        # termination criterion: absolute change in loss between iterations
        while diff >= threshold:
            prevLoss = loss_value

            optimizer.zero_grad()  # gradients would otherwise accumulate

            # one network evaluation per column of X, fed as a (1, rows) row
            output = torch.empty(n_columns, 1, dtype=torch.double)
            for j in range(n_columns):
                column = X[:, j]
                output[j] = net(column.view(-1, len(column)))
            finalOutput = output

            loss = NLL.apply(output, initD, mu_sa, muE, F,
                             mdp_data)  # custom gradient
            loss.backward()  # propagate grad through network
            optimizer.step()

            # EVD is only a metric, so it does not need an autograd graph.
            with torch.no_grad():
                evd = NLL.calculate_EVD(truep, torch.matmul(X, output))

            # loss is a scalar tensor (backward() is called without
            # arguments), so store plain floats to avoid keeping graphs alive.
            loss_value = loss.item()
            evd_value = float(evd)

            print('{}: | EVD: {} | loss: {} | diff {}'.format(
                i, evd_value, loss_value, diff))

            # store metrics for plotting / logging
            iterations.append(i)
            NLList.append(loss_value)
            evdList.append(evd_value)
            # NOTE(review): assumes a `tensorboard_writer` (e.g. a
            # SummaryWriter) is defined at module level elsewhere - confirm.
            tensorboard_writer.add_scalar('loss', loss_value, i)
            tensorboard_writer.add_scalar('evd', evd_value, i)
            tensorboard_writer.add_scalar('diff', diff, i)

            i += 1
            diff = abs(prevLoss - loss_value)
    else:
        # TODO: other optimisers (e.g. LBFGS) are not implemented yet; the
        # network is saved untrained and finalOutput stays None.
        print('\nimplement LBFGS\n')

    PATH = './NN_IRL.pth'
    torch.save(net.state_dict(), PATH)

    if NLL_EVD_plots:
        fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True)
        ax1.plot(iterations, NLList)
        ax1.plot(iterations, NLList, 'r+')
        ax1.set_title('Negative Log Likelihood')

        ax2.plot(iterations, evdList)
        ax2.plot(iterations, evdList, 'r+')
        ax2.set_title('Expected Value Diff')

    runtime = time.time() - start_time
    print("\nruntime: --- %s seconds ---\n" % runtime)
    return net, finalOutput, runtime
# Solve the MDP under the true reward r. linearvalueiteration returns
# v, q, logp and truep (presumably state values, Q values, log-policy and
# policy - confirm against its definition).
print("\n... performing value iteration for v, q, logp and truep ...")
v, q, logp, truep = linearvalueiteration(mdp_data, r)
mdp_solution = dict(v=v, q=q, p=truep, logp=logp)
# index of the largest entry in each row of truep
# (presumably the greedy action per state)
optimal_policy = torch.argmax(truep, dim=1)
print("\n... done ...")

# Optionally draw N new example paths from the solved MDP
# (T is presumably the path length).
if new_paths:
    print("\n... sampling paths from true R ...")
    example_samples = sampleexamples(N, T, mdp_solution, mdp_data)
    print("\n... done sampling", N, "paths ...")

NLL = NLLFunction()  # initialise NLL
if new_paths:
    initD, mu_sa, muE, F, mdp_data = NLL.calc_var_values(mdp_data, N, T, example_samples, feature_data)  # calculate required variables
    print("\n... using pre-loaded sampled paths ...")
    # Load variables
    open_file = open("NNIRL_param_list.pkl", "rb")
    NNIRL_param_list = pickle.load(open_file)
    threshold = NNIRL_param_list[0]
    optim_type = NNIRL_param_list[1]
    net = NNIRL_param_list[2]
    initD = NNIRL_param_list[3]
    mu_sa = NNIRL_param_list[4]
    muE = NNIRL_param_list[5]
    mdp_data = NNIRL_param_list[6]
    example_samples = NNIRL_param_list[9]
    mdp_params = NNIRL_param_list[10]
    r = NNIRL_param_list[11]
    mdp_solution = NNIRL_param_list[12]
    feature_data = NNIRL_param_list[13]
    trueNLL = NNIRL_param_list[14]
    normalise = NNIRL_param_list[15]
    user_input = NNIRL_param_list[16]
    worldtype = NNIRL_param_list[17]


    # Add noise to features
    if states_to_remove[index_states_to_corrupt] is not None:
        for state in states_to_remove[index_states_to_corrupt]:
            if random.randint(0, 100) < 3:  #3% chance of NOT using this state
            for i in range(len(feature_data['splittable'][state, :])):
                if random.randint(
                        0, 100) < 22:  #22% chance of inverting the feature
                    #invert the feature, works since binary features
                        state, i] = 1 - feature_data['splittable'][state, i]

    # Connect configuration dict
    configuration_dict = {

    configuration_dict = {
        'number_of_epochs': 3,
        'base_lr': 0.1,
        'p': dropout_val,
        'no_hidden_layers': 3,
        'no_neurons_in_hidden_layers': len(feature_data['splittable'][0]) * 2
    }  #set config params for clearml
    #configuration_dict = task.connect(configuration_dict) #For when task being logged on clearML

    # Assign loss function constants
    NLL.F = feature_data['splittable']
    NLL.muE = muE
    NLL.mu_sa = mu_sa
    NLL.initD = initD