Code Example #1
def main():
    print('Reading Data')
    s = 'navigation'  # 'transport'
    trainf, validf = s + "-data-train-small.pickle", s + "-data-test-small.pickle"
    train, test = SeqData(trainf), SeqData(validf)

    # classType = NavigationTask if s == 'navigation' else TransportTask
    print(train.env.stateSubVectors)
    print('Defining Model')
    # Parameters
    learning_rate = 0.0002
    training_steps = 15000  #2000 # 10000
    batch_size = 64  #256 #128
    display_step = 200
    # Network Parameters
    n_hidden = 200  #128 #5*train.lenOfInput # hidden layer num of features
    len_state = train.lenOfState  # length of the state sub-vector of each input
    len_input = train.lenOfInput

    fake_input = np.reshape(test.data[5], [1, 10, -1])
    fake_state = fake_input[0][0][0:len_state]
    fake_action = fake_input[0][0][len_state:]
    print('Initializing FM')
    with tf.Graph().as_default(), tf.Session() as sess:
        fm = ForwardModel(len_state, len_input, n_hidden)
        print('FM initialized')
        fm.train(train, test, training_steps, batch_size, train.env,
                 learning_rate, display_step, "abcd")
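Examples #1 and #5 rely on the TensorFlow 1.x graph/session API plus NumPy, and they take SeqData and ForwardModel from the Policy-Tree-Planner code base. A minimal import sketch they appear to assume follows; the two project-local import paths are hypothetical and only indicate where those classes would come from.

# Minimal import sketch for the TensorFlow examples (#1 and #5).
import numpy as np
import tensorflow.compat.v1 as tf   # plain `import tensorflow as tf` on a TF 1.x install
tf.disable_v2_behavior()            # only needed when running under TF 2.x

# Project-local classes (import paths are hypothetical; adjust to the repository layout):
# from seqdata import SeqData
# from forward_model import ForwardModel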
Code Example #2
File: LSTMFM2.py  Project: ttaa9/Policy-Tree-Planner
def main():
    f_model_name = 'forward-lstm-stochastic.pt'
    s = 'navigation'  # 'transport'
    trainf, validf = s + "-data-train-small.pickle", s + "-data-test-small.pickle"
    print('Reading Data')
    train, valid = SeqData(trainf), SeqData(validf)
    fm = LSTMForwardModel(train.lenOfInput, train.lenOfState)
    fm.train(train, valid)
    torch.save(fm.state_dict(), f_model_name)
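Example #2 saves only the model parameters with state_dict(); Examples #3 and #4 later restore a checkpoint the same way. A standalone sketch of that standard PyTorch round-trip follows, using a tiny nn.Linear stand-in (the stand-in module and the file name 'demo-fm.pt' are placeholders, not part of the project).

import torch
import torch.nn as nn

model = nn.Linear(74, 64)                     # stand-in for the forward model (input 74, state 64)
torch.save(model.state_dict(), 'demo-fm.pt')  # save parameters only, as Example #2 does

restored = nn.Linear(74, 64)                  # must be constructed with the same architecture
restored.load_state_dict(torch.load('demo-fm.pt'))
restored.eval()                               # optional: switch to inference mode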
Code Example #3
def main():

    ###
    runTraining = True

    ###

    f_model_name = 'LSTM_FM_1_99'
    s = 'navigation'  # 'transport'
    # Read training/validation data
    print('Reading Data')
    trainf, validf = s + "-data-train-small.pickle", s + "-data-test-small.pickle"
    train, valid = SeqData(trainf), SeqData(validf)
    # Load forward model
    ForwardModel = LSTMForwardModel(train.lenOfInput, train.lenOfState)
    ForwardModel.load_state_dict(torch.load(f_model_name))
    # Initialize forward policy
    exampleEnv = generateTask(0, 0, 0, 3, 0)  # Takes about 10 sec to train & solve on my machine
    SimPolicy = SimulationPolicy(exampleEnv)
    # Run training
    if runTraining:
        maxDepth = 3
        SimPolicy.trainSad(
            exampleEnv,
            ForwardModel,
            printActions=True,
            maxDepth=maxDepth,
            # treeBreadth=2,
            eta_lr=0.001,  #0.000375,
            trainIters=500,
            alpha=0.5,
            lambda_h=-0.005,  #-0.0125, # negative = encourage entropy
            useHolder=True,
            holderp=-2.0,
            useOnlyLeaves=False,
            gamma=0.9  #1.5
        )

        # NOTE: the branching factor parameter here is merely the branching level AT THE PARENT
        # It has no effect anywhere else
        s_0 = torch.unsqueeze(avar(torch.FloatTensor(
            exampleEnv.getStateRep())),
                              dim=0)
        tree = Tree(s_0,
                    ForwardModel,
                    SimPolicy,
                    greedy_valueF,
                    exampleEnv,
                    maxDepth=maxDepth)  #, branchingFactor=2)
        tree.measureLossAtTestTime()
        states, actions = tree.getBestPlan()
        print('Final Actions')
        for i in range(len(actions)):
            jq = actions[i][0].data.numpy().argmax()
            print('A' + str(i) + ':', jq, NavigationTask.actions[jq])
Code Example #4
def main():
    f_model_name = 'LSTM_FM_1_99'
    s = 'navigation'  # 'transport'
    trainf, validf = s + "-data-train-small.pickle", s + "-data-test-small.pickle"
    print('Reading Data')
    train, valid = SeqData(trainf), SeqData(validf)
    exampleEnv = generateTask(0, 0, 0, 14, 14)
    ForwardModel = LSTMForwardModel(train.lenOfInput, train.lenOfState)
    ForwardModel.load_state_dict(torch.load(f_model_name))
    SimPolicy = SimulationPolicy(exampleEnv)
    SimPolicy.trainSad(ForwardModel)
    s_0 = torch.unsqueeze(avar(torch.FloatTensor(exampleEnv.getStateRep())),
                          dim=0)
    tree = Tree(s_0, ForwardModel, SimPolicy, greedy_cont_valueF, exampleEnv,
                5, 2)
    states, actions = tree.getBestPlan()
    for i in range(len(actions)):
        print(actions[i][0].data.numpy().argmax())
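The Tree constructor in Example #4 is called with two trailing positional arguments, 5 and 2. Judging from Example #3, which uses maxDepth=... and a commented-out branchingFactor=2, these most likely map to the same parameters; a hedged equivalent with keywords is shown below (the mapping is an assumption, not confirmed by the source).

# Assumption: the positional 5 and 2 correspond to maxDepth and branchingFactor.
tree = Tree(s_0, ForwardModel, SimPolicy, greedy_cont_valueF, exampleEnv,
            maxDepth=5, branchingFactor=2)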
Code Example #5
def inference():
    print('Reading Data')
    s = 'navigation'  # 'transport'
    trainf, validf = s + "-data-train-small.pickle", s + "-data-test-small.pickle"
    train, test = SeqData(trainf), SeqData(validf)
    # classType = NavigationTask if s == 'navigation' else TransportTask
    print(train.env.stateSubVectors)
    print('Defining Model')
    # Parameters
    learning_rate = 0.01
    training_steps = 5000  #2000 # 10000
    batch_size = 64  #256 #128
    display_step = 200
    # Network Parameters
    seq_max_len = 10  # Sequence max length
    n_hidden = 100  #128 #5*train.lenOfInput # hidden layer num of features
    len_state = train.lenOfState  # length of the state sub-vector of each input
    len_input = train.lenOfInput

    fake_input = np.reshape(test.data[5], [1, 10, -1])
    fake_state = fake_input[0][0][0:len_state]
    fake_action = fake_input[0][0][len_state:]
    print(fake_action)

    print('action:', np.argmax(fake_action))
    print('state:', [
        np.argmax(k)
        for k in train.env.deconcatenateOneHotStateVector(fake_state)
    ])
    print(fake_input)

    with tf.Graph().as_default(), tf.Session() as sess:

        fm = ForwardModel(len_state, len_input, n_hidden)

        fm.load_model('abcd.ckpt')
        fake_output, state_out = fm.predict(fake_input)
        fake_output = train.env.deconcatenateOneHotStateVector(fake_output[0])
        fake_output = [np.argmax(i) for i in fake_output]
        print(fake_output)
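The inference routine treats each timestep vector as a one-hot state followed by a 10-way one-hot action, splitting it at len_state (fake_input[0][0][0:len_state] vs. [len_state:]). A standalone NumPy sketch of that split follows, using the 64-dim state / 74-dim input sizes that appear in Examples #6-#8 (the concrete sizes and indices here are only illustrative).

import numpy as np

len_state = 64               # 15 + 15 + 4 + 15 + 15 one-hot sub-vectors (see Example #6)
len_input = len_state + 10   # state followed by a 10-way one-hot action

step = np.zeros(len_input, dtype=np.float32)
step[2] = 1.0                # e.g. agent x = 2 in the first 15-dim block
step[len_state + 5] = 1.0    # action index 5

state, action = step[:len_state], step[len_state:]
print('action:', np.argmax(action))   # -> 5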
Code Example #6
def main():

    ###
    runTraining = True
    generateFigs = False
    ###

    f_model_name = 'LSTM_FM_1_99'
    s = 'navigation'  # 'transport'

    # Generate task
    exampleEnv = generateTask(0, 0, 0, 0, 6)  # <----------------------- Task

    # Greedy value predictor
    # gvp_model_name = "greedy_value_predictor"
    # GreedyVP = GreedyValuePredictor(exampleEnv)
    # GreedyVP.load_state_dict(torch.load(gvp_model_name))
    # greedyValueEstimator = lambda state: greedy_value_predictor(state,GreedyVP=GreedyVP)
    gve = greedy_valueFunc

    # Load forward model
    print('Loading Forward Model')
    lenOfState = 15 * 4 + 4  # exampleEnv.
    lenOfInput = lenOfState + 10  # exampleEnv.
    ForwardModel = LSTMForwardModel(lenOfInput, lenOfState)  #
    ForwardModel.load_state_dict(torch.load(f_model_name))

    # Initialize policy
    SimPolicy = SimulationPolicy(exampleEnv)

    # Train the simulation policy
    print('Starting training')
    SimPolicy.trainReinforce(exampleEnv,
                             ForwardModel,
                             maxDepth=5,
                             N_iters=2000,
                             valueF=gve)

    sys.exit(0)  # Stop here; the code below never runs as written

    #########################################################################################################

    # Read training/validation data
    print('Reading Data')
    trainf, validf = s + "-data-train-small.pickle", s + "-data-test-small.pickle"
    train, valid = SeqData(trainf), SeqData(validf)

    useGreedyRewardPredictor = False
    if useGreedyRewardPredictor:
        print('Loading (greedy) value predictor')
        gvp_model_name = "greedy_value_predictor"
        GreedyVP = GreedyValuePredictor(exampleEnv)
        GreedyVP.load_state_dict(torch.load(gvp_model_name))
        greedyValueEstimator = lambda state: greedy_value_predictor(
            state, GreedyVP=GreedyVP)
    else:
        print('Using L2')
        greedyValueEstimator = greedy_valueFunc

    if generateFigs:

        nRepeats = 10
        nodeNumsSeqs = []
        accsSeqs = []
        namee = '3,4-10'  #'5,9-10' #'0,6-10' # '3,4-10'

        fname = "data-fig-" + namee

        if os.path.exists(fname):
            print('Loading ' + fname)
            with open(fname, 'rb') as outFile:
                p = pickle.load(outFile)
                nodeNumsSeqs, accsSeqs = p
                numSeqs = len(nodeNumsSeqs)
                numPoints = len(nodeNumsSeqs[0])
                t = np.arange(numPoints)

                nodeNumsSeqs = np.array(nodeNumsSeqs)
                meanNodes = np.mean(nodeNumsSeqs, axis=0)
                stderrNodes = np.std(nodeNumsSeqs, axis=0) / np.sqrt(numSeqs)

                accsSeqs = np.array(accsSeqs)
                meanAccs = np.mean(accsSeqs, axis=0)
                stderrAccs = np.std(accsSeqs, axis=0) / np.sqrt(numSeqs)

                smooth = False
                if smooth:
                    from scipy.signal import savgol_filter
                    # window length of wlen and a degree deg polynomial
                    wlen, deg = 5, 2
                    meanNodes = savgol_filter(meanNodes, wlen, deg)
                    meanAccs = savgol_filter(meanAccs, wlen, deg)

                import matplotlib.pyplot as plt
                plt.rc('font', size=20)
                fig, ax = plt.subplots(1)
                # https://stackoverflow.com/questions/3899980/how-to-change-the-font-size-on-a-matplotlib-plot
                ax.plot(t, meanNodes, lw=1.4, label='Total Nodes', color='red')
                ax.fill_between(t,
                                meanNodes + stderrNodes,
                                meanNodes - stderrNodes,
                                facecolor='red',
                                alpha=0.45)
                ax.set_xlabel('Iteration')
                ax.set_ylabel('Number of Nodes')  #, color='b')
                #
                ax2 = ax.twinx()
                ax2.plot(t, meanAccs, lw=1.4, label='Reward', color='blue')
                ax2.fill_between(t,
                                 meanAccs + stderrAccs,
                                 meanAccs - stderrAccs,
                                 facecolor='blue',
                                 alpha=0.45)
                ax2.set_ylabel('Reward')  #, color='r')
                #
                ax.set_xlim([-1, 750])
                ax2.set_ylim([-0.05, 1.05])
                # ax.legend(loc='lower right') #'upper left')
                # ax2.legend(loc='upper left')
                #
                plt.tight_layout()

                plt.show()

        else:
            print('Generating data')
            for ir in range(0, nRepeats):
                print('On iter', ir)
                maxDepth = 4
                exampleEnv = generateTask(0, 0, 0, 5,
                                          9)  # <----------------------- Task
                SimPolicy = SimulationPolicy(exampleEnv)
                overallNumNodes, accuracies = SimPolicy.trainSad(
                    exampleEnv,
                    ForwardModel,
                    printActions=True,
                    maxDepth=maxDepth,
                    # treeBreadth=2,
                    eta_lr=0.00135,  #0.000375,
                    trainIters=750,
                    alpha=12.0,
                    lambda_h=-0.075,  # -0.0125; negative = encourage entropy in actions
                    useHolder=True,
                    holderp=-2.0,
                    useOnlyLeaves=False,
                    gamma=0.0000005,  #0.00000025, #1.5
                    xi=-0.00000125,  # -0.000005, #  0.00000000125
                    valueF=greedyValueEstimator)
                nodeNumsSeqs.append(overallNumNodes)
                accsSeqs.append(accuracies)

            with open(fname, 'wb') as outFile:
                print('Saving data')
                pickle.dump([nodeNumsSeqs, accsSeqs], outFile)

    # Run training
    if runTraining:
        maxDepth = 4
        overallNumNodes, accuracies = SimPolicy.trainSad(
            exampleEnv,
            ForwardModel,
            printActions=True,
            maxDepth=maxDepth,
            # treeBreadth=2,
            eta_lr=0.00125,  #0.000375,
            trainIters=750,
            alpha=12.0,
            lambda_h=-0.075,  #-0.0125, # negative = encourage entropy in actions
            useHolder=True,
            holderp=-2.0,
            useOnlyLeaves=False,
            gamma=0.0000005,  #0.00000025, #1.5
            xi=-0.00000125,  # -0.000005, #  0.00000000125
            valueF=greedyValueEstimator)

        s_0 = torch.unsqueeze(avar(torch.FloatTensor(
            exampleEnv.getStateRep())),
                              dim=0)
        tree = Tree(s_0,
                    ForwardModel,
                    SimPolicy,
                    greedyValueEstimator,
                    exampleEnv,
                    maxDepth=maxDepth)  #, branchingFactor=2)
        tree.measureLossAtTestTime()
        states, actions = tree.getBestPlan()
        print('Final Actions')
        for i in range(len(actions)):
            jq = actions[i][0].data.numpy().argmax()
            print('A' + str(i) + ':', jq, NavigationTask.actions[jq])
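The figure-generation branch in Example #6 averages several training runs and shades a band of ±1 standard error of the mean (std / sqrt(numSeqs)) around each curve passed to fill_between. A standalone sketch of that computation on dummy data follows (the array shape is illustrative).

import numpy as np

runs = np.random.rand(10, 750)          # stand-in for nodeNumsSeqs: 10 repeats, 750 iterations

mean_curve = runs.mean(axis=0)
stderr = runs.std(axis=0) / np.sqrt(runs.shape[0])        # standard error of the mean

upper, lower = mean_curve + stderr, mean_curve - stderr   # band drawn with ax.fill_between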
Code Example #7
def main():
    ####################################################
    trainingLSTM = False
    overwrite = False
    runHenaff = False
    testFM = False
    ###
    useFFANN = True
    trainingFFANN = False
    ####################################################
    if useFFANN:
        f_model_name = 'forward-ffann-stochastic.pt'
        exampleEnv = NavigationTask()
        f = ForwardModelFFANN(exampleEnv)
        if trainingFFANN:
            ts = "navigation-data-train-single-small.pickle"
            vs = "navigation-data-test-single-small.pickle"
            print('Reading Data')
            with open(ts, 'rb') as inFile:
                print('\tReading', ts)
                trainSet = pickle.load(inFile)
            with open(vs, 'rb') as inFile:
                print('\tReading', vs)
                validSet = pickle.load(inFile)

            f.train(trainSet, validSet)
            print('Saving to', f_model_name)
            torch.save(f.state_dict(), f_model_name)
        else:
            f.load_state_dict(torch.load(f_model_name))
            start = np.zeros(74, dtype=np.float32)
            start[0 + 4] = 1
            start[15 + 6] = 1
            start[15 + 15 + 0] = 1
            start[15 + 15 + 4 + 8] = 1
            start[15 + 15 + 4 + 15 + 7] = 1
            start[15 + 15 + 4 + 15 + 15 + 4] = 1.0

            f.test(start)

            for i in range(10):
                width, height = 15, 15
                p_0 = np.array([npr.randint(0, width), npr.randint(0, height)])
                start_pos = [p_0, r.choice(NavigationTask.oriens)]
                goal_pos = np.array(
                    [npr.randint(0, width),
                     npr.randint(0, height)])
                checkEnv = NavigationTask(width=width,
                                          height=height,
                                          agent_start_pos=start_pos,
                                          goal_pos=goal_pos,
                                          track_history=True,
                                          stochasticity=0.0,
                                          maxSteps=10)
                s_0 = checkEnv.getStateRep()
                a = np.zeros(10)
                a[npr.randint(0, 10)] = 1
                inval = np.concatenate((s_0, a))
                checkEnv.performAction(np.argmax(a))
                s_1 = checkEnv.getStateRep()
                f.test(inval, s_1)
                print('----')

    else:
        f_model_name = 'forward-lstm-stochastic.pt'
        s = 'navigation'  # 'transport'
        trainf, validf = s + "-data-train-small.pickle", s + "-data-test-small.pickle"
        print('Reading Data')
        train, valid = SeqData(trainf), SeqData(validf)
        f = ForwardModelLSTM(train.lenOfInput, train.lenOfState)
        if trainingLSTM:
            if os.path.exists(f_model_name) and not overwrite:
                print('Loading from', f_model_name)
                f.load_state_dict(torch.load(f_model_name))
            else:
                f.train(train, valid)
                print('Saving to', f_model_name)
                torch.save(f.state_dict(), f_model_name)
            print('Q-test')
            bdata, blabels, _ = valid.next(2000, nopad=True)
            acc1, _ = f._accuracyBatch(bdata, blabels, valid.env)
            print(acc1)
        if runHenaff:
            print('Loading from', f_model_name)
            f.load_state_dict(torch.load(f_model_name))
            # seq, label = train.randomTrainingPair()
            # start = seq[0][0:64]
            # start[63] = 0
            # start[63-15] = 0
            # start[15+15+4+5] = 1
            # start[15+15+4+15+5] = 1
            # start
            start = np.zeros(64)
            start[0] = 1
            start[15] = 1
            start[15 + 15] = 1
            start[15 + 15 + 4 + 0] = 1
            start[15 + 15 + 4 + 15 + 2] = 1
            print(train.env.deconcatenateOneHotStateVector(start))
            #sys.exit(0)
            print('Building planner')
            planner = HenaffPlanner(f)
            print('Starting generation')
            planner.generatePlan(start, train.env, niters=150)
        if testFM:
            f.load_state_dict(torch.load(f_model_name))
            start = np.zeros(64)
            start[0 + 2] = 1
            start[15 + 3] = 1
            start[15 + 15 + 0] = 1
            start[15 + 15 + 4 + 5] = 1
            start[15 + 15 + 4 + 15 + 5] = 1
            action = np.zeros(10)
            deconRes = train.env.deconcatenateOneHotStateVector(start)
            print('Start state')
            print('px', np.argmax(deconRes[0]))
            print('py', np.argmax(deconRes[1]))
            print('orien', np.argmax(deconRes[2]))
            print('gx', np.argmax(deconRes[3]))
            print('gy', np.argmax(deconRes[4]))
            action[5] = 1.0
            stateAction = [
                torch.cat([(torch.FloatTensor(start)),
                           (torch.FloatTensor(action))])
            ]
            #print('SA:',stateAction)
            #print('Start State')
            #printState( stateAction[0][0:-10], train.env )
            print('Action', NavigationTask.actions[np.argmax(action)])
            f.reInitialize()
            seq = avar(torch.cat(stateAction).view(
                len(stateAction), 1, -1))  # [seqlen x batchlen x hidden_size]
            result = f.forward(seq)
            print('PredState')
            printState(result, train.env)
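The hard-coded offsets in Examples #7 and #8 imply a concatenated one-hot state of length 64: agent x (15), agent y (15), orientation (4), goal x (15), goal y (15); appending the 10-way action gives the 74-dimensional model input, matching lenOfState = 15 * 4 + 4 in Example #6. A hedged helper that builds such a state vector is sketched below (the function name and argument order are illustrative, not from the project).

import numpy as np

def make_nav_state(px, py, orien, gx, gy):
    # Illustrative only: one-hot blocks [px | py | orien | gx | gy] = 15+15+4+15+15 = 64 dims.
    state = np.zeros(64, dtype=np.float32)
    state[px] = 1                       # agent x in [0, 15)
    state[15 + py] = 1                  # agent y in [0, 15)
    state[15 + 15 + orien] = 1          # orientation in [0, 4)
    state[15 + 15 + 4 + gx] = 1         # goal x
    state[15 + 15 + 4 + 15 + gy] = 1    # goal y
    return state

# Reproduces the `start` vector of the testFM branch above (px=2, py=3, orien=0, gx=5, gy=5).
start = make_nav_state(2, 3, 0, 5, 5)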
Code Example #8
def main():
    ####################################################
    trainingLSTM = False
    overwrite = False
    runHenaff = False
    testFM = False
    ###
    useFFANN = True
    trainingFFANN = False
    manualTest = False
    autoTest = False
    henaffHyperSearch = False
    runHenaffFFANN = True
    ####################################################
    if useFFANN:
        f_model_name = 'forward-ffann-noisy-wan-1.pt' # 6 gets 99% on 0.1% noise
        exampleEnv = NavigationTask()
        f = ForwardModelFFANN(exampleEnv)

        if trainingFFANN:
            ############
            ts = "navigation-data-train-single-small.pickle"
            vs = "navigation-data-test-single-small.pickle"
            tsx_noisy = "noisier-actNoise-navigation-data-single.pickle"
            preload_name = f_model_name
            saveName = 'forward-ffann-noisy-wan-2.pt'
            ############
            print('Reading Data')
            with open(ts,'rb') as inFile:
                print('\tReading',ts)
                trainSet = pickle.load(inFile)
            with open(vs,'rb') as inFile:
                print('\tReading',vs)
                validSet = pickle.load(inFile)
            if preload_name is not None:
                print('Loading from',f_model_name)
                f.load_state_dict( torch.load(f_model_name) )
            f.train(trainSet,validSet,noisyDataSetTxLoc=tsx_noisy,f_model_name=saveName)
            print('Saving to',saveName)
            torch.save(f.state_dict(), saveName)

        elif manualTest:
            def softmax(x):
                e_x = np.exp(x - np.max(x))
                return e_x / e_x.sum()
            ###
            #f_model_name = 'forward-ffann-noisy6.pt'
            ###
            f.load_state_dict( torch.load(f_model_name) )
            start = np.zeros(74, dtype=np.float32)
            start[0+4] = 1
            start[15+6] = 1
            start[15+15+0] = 1
            start[15+15+4+8] = 1
            start[15+15+4+15+7] = 1
            start[15+15+4+15+15+4] = 1.0
            f.test(start)
            print('-----\n','Starting manualTest loop')
            for i in range(5):
                width, height = 15, 15
                p_0 = np.array([npr.randint(0,width),npr.randint(0,height)])
                start_pos = [p_0, r.choice(NavigationTask.oriens)]
                goal_pos = np.array([ npr.randint(0,width), npr.randint(0,height) ])
                checkEnv = NavigationTask(
                    width=width, height=height, agent_start_pos=start_pos, goal_pos=goal_pos,
                    track_history=True, stochasticity=0.0, maxSteps=10)
                s_0 = checkEnv.getStateRep()
                #a1, a2 = np.zeros(10), np.zeros(10)
                #a1[ npr.randint(0,10) ] = 1
                #a2[ npr.randint(0,10) ] = 1
                numActions = 3
                currState = avar( torch.FloatTensor(s_0).unsqueeze(0) )
                print('Start State')
                f.printState( currState[0] )
                actionSet = []
                for j in range(numActions):
                    action = np.zeros( 10 )
                    action[ npr.randint(0,10) ] = 1
                    action += npr.randn( 10 )*0.1
                    action = softmax( action )
                    print('\tSoft Noisy Action ',j,'=',action)
                    #### Apply Gumbel Softmax ####
                    temperature = 0.01
                    logProbAction = torch.log( avar(torch.FloatTensor(action)) ) 
                    actiong = gumbel_softmax(logProbAction, temperature)
                    ##############################
                    print('\tGumbel Action ',j,'=',actiong.data.numpy())
                    actionSet.append( actiong )
                    checkEnv.performAction( np.argmax(action) )
                    a = actiong  # avar( torch.FloatTensor(actiong) )
                    currState = f.forward( torch.cat([currState[0],a]).unsqueeze(0) )
                    print("Intermediate State",j)
                    f.printState( currState[0] )
                #checkEnv.performAction(np.argmax(a1))
                #checkEnv.performAction(np.argmax(a2))
                s_1 = checkEnv.getStateRep()
                #inval = np.concatenate( (s_0,a1) )
                #outval1 = f.forward( avar(torch.FloatTensor(inval).unsqueeze(0)) )
                #print(outval1.shape)
                #print(a2.shape)
                #inval2 = np.concatenate( (outval1[0].data.numpy(),a2) )
                #outval2 = f.forward( avar(torch.FloatTensor(inval2).unsqueeze(0)) )
                for action in actionSet:
                    f.printAction(action)
                print('Predicted')
                f.printState( currState[0] )
                print('Actual')
                s1 = avar( torch.FloatTensor( s_1 ).unsqueeze(0) )
                f.printState( s1[0] ) 
                print("Rough accuracy", torch.sum( (currState - s1).pow(2) ).data[0] )
                #print('Predicted',currState.data[0].numpy())
                #print('Actual',s_1)
                #outval1 = f.test(inval,s_1)
                print('----\n')
        if autoTest:
            print('Loading from',f_model_name)
            f.load_state_dict( torch.load(f_model_name) )


        if runHenaffFFANN:
            print('Loading from',f_model_name)
            f.load_state_dict( torch.load(f_model_name) )
            start = np.zeros(64)
            start[0] = 1
            start[15] = 1
            start[15+15] = 1
            start[15+15+4+0] = 1
            start[15+15+4+15+4] = 1
            print(f.env.deconcatenateOneHotStateVector(start))
            #sys.exit(0)
            print('Building planner')
            planner = HenaffPlanner(f,maxNumActions=2)
            print('Starting generation')
            actions = planner.generatePlan(start,niters=100,extraVerbose=False)

        if henaffHyperSearch:
            print('Loading from',f_model_name)
            f.load_state_dict( torch.load(f_model_name) )
            ### Hyper-params ###
            lambda_h = 0.01  # Entropy strength
            eta = 0.5        # Learning rate
            ###
    else:
        f_model_name = 'forward-lstm-stochastic.pt'    
        s = 'navigation' # 'transport'
        trainf, validf = s + "-data-train-small.pickle", s + "-data-test-small.pickle"
        print('Reading Data')
        train, valid = SeqData(trainf), SeqData(validf)
        f = ForwardModelLSTM(train.lenOfInput,train.lenOfState)
        if trainingLSTM:
            if os.path.exists(f_model_name) and not overwrite:
                print('Loading from',f_model_name)
                f.load_state_dict( torch.load(f_model_name) )
            else:
                f.train(train,valid)
                print('Saving to',f_model_name)
                torch.save(f.state_dict(), f_model_name)
            print('Q-test')
            bdata, blabels, _ = valid.next(2000, nopad=True)
            acc1, _ = f._accuracyBatch(bdata,blabels,valid.env)
            print(acc1)
        if runHenaff:
            print('Loading from',f_model_name)
            f.load_state_dict( torch.load(f_model_name) )
            # seq, label = train.randomTrainingPair()
            # start = seq[0][0:64]
            # start[63] = 0
            # start[63-15] = 0
            # start[15+15+4+5] = 1
            # start[15+15+4+15+5] = 1
            # start
            start = np.zeros(64)
            start[0] = 1
            start[15] = 1
            start[15+15] = 1
            start[15+15+4+0] = 1
            start[15+15+4+15+2] = 1
            print(train.env.deconcatenateOneHotStateVector(start))
            #sys.exit(0)
            print('Building planner')
            planner = HenaffPlanner(f)
            print('Starting generation')
            planner.generatePlan(start,train.env,niters=150)
        if testFM:
            f.load_state_dict( torch.load(f_model_name) )
            start = np.zeros(64)
            start[0+2] = 1
            start[15+3] = 1
            start[15+15+0] = 1
            start[15+15+4+5] = 1
            start[15+15+4+15+5] = 1
            action = np.zeros(10)
            deconRes = train.env.deconcatenateOneHotStateVector(start)
            print('Start state')
            print('px',    np.argmax(deconRes[0]) )
            print('py',    np.argmax(deconRes[1]) )
            print('orien', np.argmax(deconRes[2]) )
            print('gx',    np.argmax(deconRes[3]) )
            print('gy',    np.argmax(deconRes[4]) )
            action[5] = 1.0
            stateAction = [torch.cat([(torch.FloatTensor(start)), (torch.FloatTensor(action))])]
            #print('SA:',stateAction)
            #print('Start State')
            #printState( stateAction[0][0:-10], train.env )
            print('Action',NavigationTask.actions[np.argmax( action )])
            f.reInitialize()
            seq = avar(torch.cat(stateAction).view(len(stateAction), 1, -1)) # [seqlen x batchlen x hidden_size]
            result = f.forward(seq)
            print('PredState')
            printState( result, train.env )
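The manualTest loop in Example #8 perturbs a one-hot action, converts it to log-probabilities, and passes it through gumbel_softmax at temperature 0.01 so the relaxed action stays differentiable while remaining nearly one-hot. The project supplies its own gumbel_softmax; the sketch below is a standalone re-implementation of the standard relaxation it corresponds to (function and variable names are illustrative).

import torch
import torch.nn.functional as F

def gumbel_softmax_sample(log_probs, temperature=0.01):
    # Standard Gumbel-softmax relaxation: add Gumbel(0, 1) noise to the log-probabilities,
    # then take a temperature-scaled softmax; low temperatures give near-one-hot samples.
    u = torch.rand_like(log_probs)
    gumbel_noise = -torch.log(-torch.log(u + 1e-20) + 1e-20)
    return F.softmax((log_probs + gumbel_noise) / temperature, dim=-1)

probs = torch.full((10,), 0.1)                       # roughly uniform 10-way action distribution
relaxed = gumbel_softmax_sample(torch.log(probs))    # differentiable, near-one-hot action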