Example #1
import random

import numpy as np
import RNA  # ViennaRNA Python bindings
import rnalib

# train_dict (structure -> list of sequences) and getBestActions are defined
# elsewhere in the source file these snippets were taken from.


def train_supervised(policy):
    step = 50000
    policy.train()
    for i in range(step):
        structure = random.choice(list(train_dict.keys()))
        objectiveSeq = train_dict[structure][
            0]  # TODO: select randomly instead of the 1st
        N = len(objectiveSeq)
        structureTemp, energy = RNA.fold(
            rnalib.sequence_to_string(objectiveSeq))
        if structureTemp != structure:
            print('Error: refolded objective sequence does not give the target structure')
            exit()
        env = rnalib.RNAEnvironment(goal=structure, max_steps=1000)

        optimum_actions = getBestActions(env.sequence, objectiveSeq)
        # print(optimum_actions)
        if i % 1000 == 0:
            print('i=' + str(i) + '\t' + structure + '\tobj= ' +
                  str(objectiveSeq) + '\tinitSeq=' + str(env.sequence))
        for a_tup in optimum_actions:
            rnalib.update(actionTuple2Array(N, a_tup), policy, env)
            reward = env.step(a_tup)
            ###########
            if i % 10000 == 0:
                a = array2actionTuple(
                    policy.get_action(np.expand_dims(env.state.T, axis=0)))
                print('train_action = ' + str(a_tup) + '\t policy = ' + str(a))
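
# The action-encoding helpers used throughout these snippets are not shown in
# the source. Below is a minimal, hypothetical sketch, assuming actions are
# encoded position-major as a flat one-hot vector of length 4 * N; the real
# layout may differ.
def actionTuple2Array(N, a_tup):
    """Hypothetical: one-hot encode an (index, base) action tuple."""
    idx, base = a_tup
    arr = np.zeros(4 * N, dtype=np.float32)
    arr[idx * 4 + base] = 1.0
    return arr


def array2actionTuple(arr):
    """Hypothetical: decode a policy score array into an (index, base) tuple."""
    best = int(np.argmax(np.asarray(arr).ravel()))
    return (best // 4, best % 4)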
Example #2
def TestTabularPolicy(policy, max_steps=500, showStat=True):
    structure = random.choice(list(train_dict.keys()))
    objectiveSeq = train_dict[structure][0]
    N = len(objectiveSeq)
    env = rnalib.RNAEnvironment(goal=structure, max_steps=max_steps)
    if showStat:
        print('objective structure: ' + structure)
    while not env.terminated:
        a = epsilonGreedyAction(env.state, policy, epsilon=0.0)
        prevState = env.state
        r = env.step(a)
        if showStat:
        print('step: ' + str(env.count) + '\tseq = ' + str(env.sequence) +
              '\t policy Actn = ' + str(a) + '\t objSeq: ' +
              str(objectiveSeq))
        if r == 5:
            if showStat:
                print('Success')
        flag = 0
        while (env.state == prevState).all() and not env.terminated:
            a = epsilonGreedyAction(env.state, policy, epsilon=1.0)
            r = env.step(a)
            flag = 1
        if flag == 1:
            if showStat:
                print('step: ' + str(env.count) + '\tseq = ' +
                      str(env.sequence) + '\t RANDOM Actn = ' + str(a) +
                      '\t objSeq: ' + str(objectiveSeq))
        if r == 5:
            if showStat:
                print('***SUCCESS***')
            return env.count
    return -1  # ran out of steps without reaching the target structure
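
# Hypothetical usage: train a tabular policy with train_DynaQ (Example #3),
# then roll it out greedily.
policy = train_DynaQ(numEpisodes=100)
steps = TestTabularPolicy(policy, showStat=False)
print('failed' if steps == -1 else 'solved in %d steps' % steps)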
Example #3
def train_DynaQ(alpha=0.5,
                gamma=1.0,
                numEpisodes=500,
                maxPolicySize=10 * 1024 * 1024):
    policy = {}
    for e in range(numEpisodes):
        structure = random.choice(list(train_dict.keys()))
        N = len(structure)
        env = rnalib.RNAEnvironment(goal=structure, max_steps=10000)
        exp_replay = []
        S = env.state
        a = epsilonGreedyAction(S, policy, epsilon=0.6)
        while not env.terminated:
            r = env.step(a)
            Sprime = env.state
            aprime = epsilonGreedyAction(Sprime, policy, epsilon=0.6)
            exp_replay.append((S, a, r, Sprime, aprime))
            S = Sprime
            a = aprime
        # 100 sweeps over the shuffled replay buffer; SARSA-style update with
        # unseen state-action pairs defaulting to a Q-value of 0
        for j in range(100):
            for (S, a, r, Sprime,
                 aprime) in random.sample(exp_replay, len(exp_replay)):
                key = (list_to_tuple(S), a)
                keyPrime = (list_to_tuple(Sprime), aprime)
                policy[key] = policy.get(key, 0.0) + alpha * (
                    r + gamma * policy.get(keyPrime, 0.0) - policy.get(key, 0.0))

        # note: sys.getsizeof measures only the dict object itself, not the
        # keys and values it holds, so this is a rough lower bound
        print(
            str(e) + '\tpolicy size: ' +
            str(sys.getsizeof(policy) / (1024 * 1024)) + ' MB')
        if sys.getsizeof(policy) > maxPolicySize:
            break
    return policy
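
# The tabular helpers used above are not shown in the source. A minimal,
# hypothetical sketch, assuming the policy is a dict mapping
# (state_tuple, action) pairs to Q-values (unseen pairs default to 0) and that
# the trailing axis of the state array is the sequence length:
def list_to_tuple(S):
    """Hypothetical: flatten a state array into a hashable tuple."""
    return tuple(np.asarray(S).ravel().tolist())


def epsilonGreedyAction(S, policy, epsilon=0.1):
    """Hypothetical: epsilon-greedy over all (index, base) actions."""
    N = np.asarray(S).shape[-1]
    actions = [(i, b) for i in range(N) for b in range(4)]
    if random.random() < epsilon:
        return random.choice(actions)
    key = list_to_tuple(S)
    return max(actions, key=lambda act: policy.get((key, act), 0.0))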
Example #4
def TestRandomPolicy():
    structure = random.choice(list(train_dict.keys()))
    objectiveSeq = train_dict[structure][0]
    N = len(objectiveSeq)
    env = rnalib.RNAEnvironment(goal=structure, max_steps=10000)
    max_step = 100000
    steps = 0
    while not env.terminated and steps < max_step:
        a = (random.randint(0, N - 1), random.randint(0, 3))
        r = env.step(a)
        steps += 1
    return steps
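
# Hypothetical usage: average the random-policy step count over a few trials
# to establish a baseline.
trials = [TestRandomPolicy() for _ in range(10)]
print('mean steps (random baseline):', sum(trials) / len(trials))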
Example #5
def findRNA(structure, policy=None, useRandomPolicy=False):
    env = rnalib.RNAEnvironment(goal=structure, max_steps=1000)
    while not env.terminated:
        if useRandomPolicy:
            a = (random.randint(0, len(structure) - 1), random.randint(0, 3))
        else:
            if policy is None:
                raise ValueError('findRNA needs a policy when useRandomPolicy is False')
            a = array2actionTuple(
                policy.get_action(np.expand_dims(env.state.T, axis=0)))
        prevState = env.state
        r = env.step(a)
        if r == 5:
            return 1
        if (env.state == prevState).all():
            a = (random.randint(0, len(structure) - 1), random.randint(0, 3))
            r = env.step(a)
        if r == 5:
            return 1
    return 0
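
# Hypothetical usage with an illustrative dot-bracket target; findRNA returns
# 1 on success and 0 on failure.
print(findRNA('((((....))))', useRandomPolicy=True))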
Example #6
def DQNPolicy(random_policy=False):
	counter, total_steps = 0, 0
	if not random_policy:
		# load the trained network once, outside the puzzle loop
		policy = rnalib.RNA_BiLSTM_Policy(hidden_size=15, num_layers=4)
		policy = torch.load('DQN_policy')
	for puzzle in output:
		if puzzle[0] in 'AUGC':
			# a sequence was given: fold it to obtain the target structure
			structure, _ = RNA.fold(puzzle)
		else:
			structure = puzzle

		env = rnalib.RNAEnvironment(goal=structure, max_steps=1000)

		steps = 0
		start_time = time.time()
		end_time = start_time + 30*60
		f = 0
		while not env.terminated and time.time() < end_time:
			if random_policy:
				a = (random.randint(0,len(structure)-1),random.randint(0,3))
			else:
				a = array2actionTuple(policy.get_action(np.expand_dims(env.state.T, axis=0)))
			r = env.step(a)
			steps += 1
			if r == 5:
				print('SUCCESS',structure,len(structure),steps)
				counter += 1
				total_steps += steps
				f = 1
				break

		if f == 0:
			print('FAILED',structure,len(structure),steps)
			total_steps += steps

	print("TOTAL SOLVED : ",counter, "AVERAGE :", total_steps/100 )
Example #7
def TestPolicy(policy, max_steps=500, showStat=True):
    structure = random.choice(list(train_dict.keys()))
    objectiveSeq = train_dict[structure][0]
    N = len(objectiveSeq)
    env = rnalib.RNAEnvironment(goal=structure, max_steps=max_steps)
    if showStat:
        print('objective structure: ' + structure)
    while not env.terminated:
        a = array2actionTuple(
            policy.get_action(np.expand_dims(env.state.T, axis=0)))
        prevState = env.state
        r = env.step(a)
        if showStat:
            print('step: ' + str(env.count) + '\tseq = ' + str(env.sequence) +
                  '\t policy Actn = ' + str(a) + '\t objSeq: ' +
                  str(objectiveSeq))
        if r == 5:
            if showStat:
                print('***SUCCESS***')
            return env.count
        flag = 0
        while (env.state == prevState).all() and not env.terminated:
            a = epsilonGreedyActionDQN(env.state, policy, epsilon=1.0)
            r = env.step(a)
            flag = 1
        if flag == 1:
            if showStat:
                print('step: ' + str(env.count) + '\tseq = ' +
                      str(env.sequence) + '\t RANDOM Actn = ' + str(a) +
                      '\t objSeq: ' + str(objectiveSeq))
        if r == 5:
            if showStat:
                print('***SUCCESS***')
            return env.count
    return -1
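
# Hypothetical usage, mirroring the load pattern in Example #6: restore the
# saved network and run one evaluation episode.
policy = torch.load('DQN_policy')
print(TestPolicy(policy, max_steps=500, showStat=True))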
Example #8
def train_DQN(policy):
    numEpisodes = 50
    epsilon = 0.3  # exploration rate
    miniBatchSz = 50
    alpha = 0.3  # learning rate
    gamma = 1.0  # discount factor
    frac = 0.3  # fraction of each minibatch forced to be terminal transitions
    policy.train()

    for e in range(numEpisodes):
        structure = random.choice(list(train_dict.keys()))
        N = len(structure)
        env = rnalib.RNAEnvironment(goal=structure, max_steps=10000)

        exp_replay = []
        print('STRUCTURE: ' + structure)
        num_success = 0
        # collect experience until the buffer is large enough and holds enough
        # terminal (success) transitions to fill the forced minibatch fraction
        while len(exp_replay) <= 500 or num_success <= miniBatchSz * frac:
            S = env.state
            a = epsilonGreedyActionDQN(S, policy,
                                       epsilon=epsilon)  #actionTuple
            while not env.terminated:
                r = env.step(a)
                Sprime = env.state
                aprime = epsilonGreedyActionDQN(Sprime,
                                                policy,
                                                epsilon=epsilon)
                Y = r if env.terminated else r + gamma * np.max(
                    policy.get_action(np.expand_dims(Sprime.T, axis=0)))

                if env.terminated:
                    # print('reward = 1 on DQNtrain')
                    exp_replay.append((S, a, Y, None))
                    num_success += 1
                    break
                exp_replay.append((S, a, Y, Sprime))

                S = Sprime
                a = aprime
            env.reset()
        print('training from exp_replay len = ' + str(len(exp_replay)) +
              ' \tEpisode: ' + str(e))
        losses = []
        for j in range(1000):
            # force a fixed fraction (frac) of the minibatch to be terminal states
            miniBatch = random.sample(exp_replay,
                                      int(miniBatchSz * (1 - frac)))
            miniBatch += random.sample(
                [e for e in exp_replay if e[-1] is None],
                int(miniBatchSz * frac))
            losses.append(rnalib.updateParam(miniBatch, model=policy,
                                             lr=alpha))
            if j % 200 == 0:
                print(losses[-1])
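
# epsilonGreedyActionDQN is not shown in the source. A minimal, hypothetical
# sketch, assuming the network scores every (index, base) action,
# array2actionTuple (see Example #1) decodes the arg-max, and the trailing
# axis of the state array is the sequence length:
def epsilonGreedyActionDQN(S, policy, epsilon=0.1):
    """Hypothetical: epsilon-greedy action selection over the DQN output."""
    N = np.asarray(S).shape[-1]
    if random.random() < epsilon:
        return (random.randint(0, N - 1), random.randint(0, 3))
    return array2actionTuple(policy.get_action(np.expand_dims(S.T, axis=0)))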
Example #9
# (This snippet begins mid-way through a helper: it appears to collect every
# (position, base) pair whose value in `a` matches a[idx, base], then return
# one of them at random.)
        for b in range(4):
            if a[i, b] == a[idx, base]:
                base_list.append((i, b))

    return random.choice(base_list)


# Attempt to solve a single puzzle. The target structure, in dot-bracket
# notation, should be supplied as a command-line argument.
puzzle = sys.argv[1]
if puzzle[0] in 'AUGC':
    structure, _ = RNA.fold(puzzle)
else:
    structure = puzzle

env = rnalib.RNAEnvironment(goal=structure, max_steps=1000)

policy = rnalib.RNA_BiLSTM_Policy(hidden_size=15, num_layers=4)
policy = torch.load('DQN_policy')  # torch.load returns the saved policy object

print(len(structure))
print(puzzle)
print(structure)
steps = 0
start_time = time.time()
end_time = start_time + 30 * 60
f = 0
while not env.terminated and time.time() < end_time:
    a = array2actionTuple(
        policy.get_action(np.expand_dims(env.state.T, axis=0)))
    # a = (random.randint(0,len(structure)-1),random.randint(0,3))
    # (the source snippet is truncated here; the loop body presumably continues
    # as in Example #6: step the environment, count the step, stop on success)
    r = env.step(a)
    steps += 1
    if r == 5:
        print('SUCCESS', structure, len(structure), steps)
        f = 1
        break