Example #1
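# Assumed imports for this snippet (the originals are not shown):
import csv
import datetime
import os
import random
import time

import numpy as np
import matplotlib.pyplot as plt

# ExperienceReplay and checkForBestMove come from elsewhere in this project.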
class Agent:
    def __init__(self, model, memory=None, memory_size=500, nb_frames=None):
        assert len(
            model.get_output_shape_at(0)
        ) == 2, "Model's output shape should be (nb_samples, nb_actions)."
        if memory:
            self.memory = memory
        else:
            self.memory = ExperienceReplay(memory_size)
        if not nb_frames and not model.get_input_shape_at(0)[1]:
            raise Exception("Missing argument : nb_frames not provided")
        elif not nb_frames:
            nb_frames = model.get_input_shape_at(0)[1]
        elif model.get_input_shape_at(
                0
        )[1] and nb_frames and model.get_input_shape_at(0)[1] != nb_frames:
            raise Exception(
                "Dimension mismatch: time dimension of model should equal nb_frames."
            )
        self.model = model
        self.nb_frames = nb_frames
        self.frames = None

    @property
    def memory_size(self):
        return self.memory.memory_size

    @memory_size.setter
    def memory_size(self, value):
        self.memory.memory_size = value

    def reset_memory(self):
        self.memory.reset_memory()

    def check_game_compatibility(self, game):
        # if len(self.model.input_layers_node_indices) != 1:
        #     raise Exception('Multi node input is not supported.')
        game_output_shape = (1, None) + game.get_frame().shape
        if len(game_output_shape) != len(self.model.get_input_shape_at(0)):
            raise Exception(
                'Dimension mismatch. Input shape of the model should be compatible with the game.'
            )
        else:
            for i in range(len(self.model.get_input_shape_at(0))):
                if self.model.get_input_shape_at(0)[i] and game_output_shape[
                        i] and self.model.get_input_shape_at(
                            0)[i] != game_output_shape[i]:
                    raise Exception(
                        'Dimension mismatch. Input shape of the model should be compatible with the game.'
                    )
        if len(
                self.model.get_output_shape_at(0)
        ) != 2 or self.model.get_output_shape_at(0)[1] != game.nb_actions:
            raise Exception(
                'Output shape of model should be (nb_samples, nb_actions).')

    def get_game_data(self, game):
        frame = game.get_frame()
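        # Keep a sliding window of the last nb_frames frames; at the start
        # of an episode the first frame is repeated to fill the window.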
        if self.frames is None:
            self.frames = [frame] * self.nb_frames
        else:
            self.frames.append(frame)
            self.frames.pop(0)
        return np.expand_dims(self.frames, 0)

    def clear_frames(self):
        self.frames = None

    def train(self,
              game,
              nb_epoch=1000,
              batch_size=50,
              gamma=0.9,
              epsilon=[1., .1],
              epsilon_rate=0.5,
              reset_memory=False,
              observe=0,
              checkpoint=None,
              total_sessions=0,
              session_id=1):
        self.check_game_compatibility(game)

        ts = int(time.time())
        fn = "3-normal.csv"
        fn2 = "heat.csv"
        advice_type = "OA"
        meta_advice_type = "HFHA"
        # Feedback frequency: 0.5 is the high-frequency (HF) setting,
        # 0.1 the low-frequency (LF) setting.
        meta_feedback_frequency = 0.1
        heatmap = [[0] * 20 for _ in range(20)]  # 20x20 visit-count grid

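        # session_id selects the advice condition for this run; "RL" is
        # plain epsilon-greedy Q-learning with no simulated player advice.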
        if session_id == 1:
            advice_type = "OA"
        if session_id == 2:
            advice_type = "NA"
        if session_id == 3:
            advice_type = "RL"
        with open(fn, 'a') as f:
            total_frames = 0
            # CSV columns: session_id, advice_type, meta_advice_type, time,
            # epoch, total_frames, score, win_perc, epsilon, loss, steps,
            # advice_given, advice_attempts, model_actions
            if type(epsilon) in {tuple, list}:
                delta = ((epsilon[0] - epsilon[1]) / (nb_epoch * epsilon_rate))
                final_epsilon = epsilon[1]
                epsilon = epsilon[0]
            else:
                final_epsilon = epsilon
            model = self.model
            nb_actions = model.get_output_shape_at(0)[-1]
            win_count = 0
            rolling_win_window = []
            max_obs_loss = -np.inf  # largest |Q| magnitude observed so far
            m_loss = -np.inf  # largest single-batch training loss so far
            for epoch in range(nb_epoch):
                lastAdviceStep = 0
                adviceGiven = 0
                adviceAttempts = 0
                modelActions = 0
                print(heatmap)
                loss = 0.
                game.reset()
                self.clear_frames()
                if reset_memory:
                    self.reset_memory()
                game_over = False
                S = self.get_game_data(game)
                savedModel = False
                while not game_over:
                    a = 0
                    if advice_type == "RL":
                        if np.random.random() < epsilon or epoch < observe:
                            a = int(np.random.randint(game.nb_actions))
                            #print("Random Action")
                        else:
                            q = model.predict(S)
                            qs = model.predict_classes(S)
                            # predict_classes already returns the argmax
                            # action index for each sample.
                            a = int(qs[0])
                    if advice_type == "OA":
                        if np.random.random() < epsilon or epoch < observe:
                            a = int(np.random.randint(game.nb_actions))
                            #print("Random Action")
                        else:
                            # Use the prediction confidences to decide
                            # whether to ask the player for help.
                            q = model.predict(S)
                            qs = model.predict_classes(S)
                            highest_loss = abs(np.amax(q))  # |max Q|
                            lowest_loss = abs(np.amin(q))  # |min Q|
                            if highest_loss > max_obs_loss and highest_loss != 0:
                                max_obs_loss = highest_loss
                            # Scale the smallest |Q| by the largest observed,
                            # take a square root, then squash with
                            # -1/(log(x) - 1) so relative_cost lies in (0, 1];
                            # small values are treated as model confidence.
                            relative_cost = np.power(
                                lowest_loss / max_obs_loss, 0.5)
                            if relative_cost < 1e-20:
                                relative_cost = 1e-20
                            relative_cost = -1 / (np.log(relative_cost) - 1)
                            confidence_score_max = 1
                            confidence_score_min = 0.01
                            feedback_chance = confidence_score_min + (
                                confidence_score_max -
                                confidence_score_min) * relative_cost

                            if feedback_chance < 0.01:
                                feedback_chance = 0.01
                            # Advice is offered with probability
                            # meta_feedback_frequency; even then, the model
                            # acts when it looks confident and at least 10
                            # steps have passed since the last advice.
                            giveAdvice = False
                            if random.random() < meta_feedback_frequency:
                                giveAdvice = True
                                adviceAttempts = adviceAttempts + 1
                            if (relative_cost <= 0.25 and game.stepsTaken >=
                                (lastAdviceStep + 10)) or not giveAdvice:
                                #print("HC: {}".format(max_obs_loss))
                                modelActions = modelActions + 1
                                #print("Highest Loss: {} RC: {} POS: Q0:{}".format(highest_loss, relative_cost, q[0]))
                                a = int(np.argmax(qs[0]))
                            else:
                                # Low-accuracy ("..LA") advice is wrong half
                                # the time: flip a coin and act randomly.
                                if random.random() < .5 and (
                                        meta_advice_type == "HFLA"
                                        or meta_advice_type == "LFLA"):
                                    lastAdviceStep = game.stepsTaken
                                    a = int(np.random.randint(game.nb_actions))
                                    adviceGiven = adviceGiven + 1
                                    #print("Taking BAD Player Action")
                                else:
                                    lastAdviceStep = game.stepsTaken
                                    adviceGiven = adviceGiven + 1
                                    x = game.location[0]
                                    z = game.location[1]
                                    yaw = game.location[2]
                                    a = -1
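                                    # Scripted player advice: pick the action
                                    # that follows a hand-coded route through
                                    # the grid, keyed on the agent's (x, z)
                                    # cell and yaw.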
                                    if z <= 6:
                                        if x < 12:
                                            #print("Segment1")
                                            if yaw == 270:
                                                a = 0
                                            if yaw == 180:
                                                a = 1
                                            if yaw == 90:
                                                a = 3
                                            if yaw == 0:
                                                a = 2
                                        elif x > 15:
                                            #print("Segment2")
                                            if yaw == 90:
                                                a = 0
                                            if yaw == 180:
                                                a = 2
                                            if yaw == 0:
                                                a = 1
                                            if yaw == 270:
                                                a = 3
                                        else:
                                            #print("Segment3")
                                            if yaw == 0:
                                                a = 0
                                            if yaw == 270:
                                                a = 1
                                            if yaw == 90:
                                                a = 2
                                            if yaw == 180:
                                                a = 3
                                    elif x >= 7 and z in (7, 8, 9, 10, 11, 12):
                                        #print("Segment4")
                                        if yaw == 90:
                                            a = 0
                                        if yaw == 180:
                                            a = 2
                                        if yaw == 0:
                                            a = 1
                                        if yaw == 270:
                                            a = 3
                                    elif 3 < x < 7 and z in (7, 8, 9, 10, 11, 12):
                                        if yaw == 0:
                                            a = 0
                                        if yaw == 270:
                                            a = 1
                                        if yaw == 90:
                                            a = 2
                                        if yaw == 180:
                                            a = 3
                                    elif x < 3 and z in (7, 8, 9, 10, 11, 12):
                                        if yaw == 0:
                                            a = 2
                                        if yaw == 270:
                                            a = 0
                                        if yaw == 180:
                                            a = 1
                                        if yaw == 90:
                                            a = 3
                                    elif z in (14, 15):
                                        if yaw == 0:
                                            a = 0
                                        if yaw == 270:
                                            a = 1
                                        if yaw == 90:
                                            a = 2
                                        if yaw == 180:
                                            a = 3
                                    elif z in (16, 17):
                                        #print("Segment6")
                                        if yaw == 270:
                                            a = 0
                                        if yaw == 180:
                                            a = 1
                                        if yaw == 0:
                                            a = 2
                                        if yaw == 90:
                                            a = 3
                                    elif z > 17:
                                        #print("Segment6")
                                        if yaw == 270:
                                            a = 2
                                        if yaw == 180:
                                            a = 0
                                        if yaw == 0:
                                            a = 3
                                        if yaw == 90:
                                            a = 1
                                    else:
                                        a = int(
                                            np.random.randint(game.nb_actions))

                                    if a == -1:
                                        a = int(
                                            np.random.randint(game.nb_actions))

                    if advice_type == "NA":
                        if np.random.random() < epsilon or epoch < observe:
                            a = int(np.random.randint(game.nb_actions))
                            game.play(a)
                            heatmap[game.location[0]][
                                game.location[1]] = heatmap[game.location[0]][
                                    game.location[1]] + 1
                            r = game.get_score()
                            S_prime = self.get_game_data(game)
                            game_over = game.is_over()
                            transition = [S, a, r, S_prime, game_over]
                            self.memory.remember(*transition)
                            S = S_prime
                            #print("Random Action")
                        else:
                            # Same confidence heuristic as in the "OA"
                            # branch above.
                            q = model.predict(S)
                            qs = model.predict_classes(S)
                            highest_loss = abs(np.amax(q))
                            lowest_loss = abs(np.amin(q))
                            if highest_loss > max_obs_loss and highest_loss != 0:
                                max_obs_loss = highest_loss
                            relative_cost = np.power(
                                lowest_loss / max_obs_loss, 0.5)
                            if relative_cost < 1e-20:
                                relative_cost = 1e-20
                            relative_cost = -1 / (np.log(relative_cost) - 1)
                            confidence_score_max = 1
                            confidence_score_min = 0.01
                            feedback_chance = confidence_score_min + (
                                confidence_score_max -
                                confidence_score_min) * relative_cost
                            if feedback_chance < 0.01:
                                feedback_chance = 0.01
                            giveAdvice = False
                            if random.random() < meta_feedback_frequency:
                                giveAdvice = True
                                adviceAttempts = adviceAttempts + 1
                            if (relative_cost <= 0.25 and game.stepsTaken >=
                                (lastAdviceStep + 10)) or not giveAdvice:
                                #print("Taking Model Action")
                                #print("HC: {}".format(max_obs_loss))
                                #print("Confidence: {} RC: {}".format(feedback_chance, relative_cost))
                                modelActions = modelActions + 1
                                #a = int(np.argmin(q[0]))
                                a = int(np.argmax(qs[0]))
                                game.play(a)
                                heatmap[game.location[0]][
                                    game.location[1]] = heatmap[
                                        game.location[0]][game.location[1]] + 1
                                r = game.get_score()
                                S_prime = self.get_game_data(game)
                                game_over = game.is_over()
                                transition = [S, a, r, S_prime, game_over]
                                self.memory.remember(*transition)
                                S = S_prime
                            else:
                                # Simulated player advice; in the
                                # low-accuracy ("..LA") conditions it is a
                                # random action half the time, played twice
                                # for "NA" friction.
                                if random.random() < .5 and (
                                        meta_advice_type == "HFLA"
                                        or meta_advice_type == "LFLA"):
                                    a = int(np.random.randint(game.nb_actions))
                                    adviceGiven = adviceGiven + 1
                                    game.play(a)
                                    heatmap[game.location[0]][game.location[
                                        1]] = heatmap[game.location[0]][
                                            game.location[1]] + 1
                                    lastAdviceStep = game.stepsTaken
                                    r = game.get_score()
                                    S_prime = self.get_game_data(game)
                                    game_over = game.is_over()
                                    transition = [S, a, r, S_prime, game_over]
                                    self.memory.remember(*transition)
                                    S = S_prime
                                    if not game_over:
                                        # Friction: play a second random
                                        # action.
                                        a = int(
                                            np.random.randint(game.nb_actions))
                                        game.play(a)
                                        heatmap[game.location[0]][
                                            game.location[1]] = heatmap[
                                                game.location[0]][
                                                    game.location[1]] + 1
                                        r = game.get_score()
                                        S_prime = self.get_game_data(game)
                                        game_over = game.is_over()
                                        transition = [
                                            S, a, r, S_prime, game_over
                                        ]
                                        self.memory.remember(*transition)
                                        S = S_prime
                                else:
                                    adviceGiven = adviceGiven + 1
                                    lastAdviceStep = game.stepsTaken
                                    x = game.location[0]
                                    z = game.location[1]
                                    yaw = game.location[2]
                                    a = -1
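                                    # Same scripted waypoint advice as in the
                                    # "OA" branch above.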
                                    if z <= 6:
                                        if x < 12:
                                            #print("Segment1")
                                            if yaw == 270:
                                                a = 0
                                            if yaw == 180:
                                                a = 1
                                            if yaw == 90:
                                                a = 3
                                            if yaw == 0:
                                                a = 2
                                        elif x > 15:
                                            #print("Segment2")
                                            if yaw == 90:
                                                a = 0
                                            if yaw == 180:
                                                a = 2
                                            if yaw == 0:
                                                a = 1
                                            if yaw == 270:
                                                a = 3
                                        else:
                                            #print("Segment3")
                                            if yaw == 0:
                                                a = 0
                                            if yaw == 270:
                                                a = 1
                                            if yaw == 90:
                                                a = 2
                                            if yaw == 180:
                                                a = 3
                                    elif x >= 7 and z in (7, 8, 9, 10, 11, 12):
                                        #print("Segment4")
                                        if yaw == 90:
                                            a = 0
                                        if yaw == 180:
                                            a = 2
                                        if yaw == 0:
                                            a = 1
                                        if yaw == 270:
                                            a = 3
                                    elif 3 < x < 7 and z in (7, 8, 9, 10, 11, 12):
                                        if yaw == 0:
                                            a = 0
                                        if yaw == 270:
                                            a = 1
                                        if yaw == 90:
                                            a = 2
                                        if yaw == 180:
                                            a = 3
                                    elif x < 3 and z in (7, 8, 9, 10, 11, 12):
                                        if yaw == 0:
                                            a = 2
                                        if yaw == 270:
                                            a = 0
                                        if yaw == 180:
                                            a = 1
                                        if yaw == 90:
                                            a = 3
                                    elif z in (14, 15):
                                        if yaw == 0:
                                            a = 0
                                        if yaw == 270:
                                            a = 1
                                        if yaw == 90:
                                            a = 2
                                        if yaw == 180:
                                            a = 3
                                    elif z in (16, 17):
                                        #print("Segment6")
                                        if yaw == 270:
                                            a = 0
                                        if yaw == 180:
                                            a = 1
                                        if yaw == 0:
                                            a = 2
                                        if yaw == 90:
                                            a = 3
                                    elif z > 17:
                                        #print("Segment6")
                                        if yaw == 270:
                                            a = 2
                                        if yaw == 180:
                                            a = 0
                                        if yaw == 0:
                                            a = 3
                                        if yaw == 90:
                                            a = 1
                                    else:
                                        a = int(
                                            np.random.randint(game.nb_actions))

                                    if a == -1:
                                        a = int(
                                            np.random.randint(game.nb_actions))

                                # Play the advised action, then an extra
                                # scripted best move ("NA" friction).
                                game.play(a)
                                heatmap[game.location[0]][
                                    game.location[1]] = heatmap[
                                        game.location[0]][game.location[1]] + 1
                                r = game.get_score()
                                S_prime = self.get_game_data(game)
                                game_over = game.is_over()
                                transition = [S, a, r, S_prime, game_over]
                                self.memory.remember(*transition)
                                S = S_prime
                                if not game_over:
                                    game.play(
                                        checkForBestMove(
                                            game.location[0], game.location[1],
                                            game.location[2]))
                                    heatmap[game.location[0]][game.location[
                                        1]] = heatmap[game.location[0]][
                                            game.location[1]] + 1
                                    r = game.get_score()
                                    S_prime = self.get_game_data(game)
                                    game_over = game.is_over()
                                    transition = [S, a, r, S_prime, game_over]
                                    self.memory.remember(*transition)
                                    S = S_prime
                    if not game_over:
                        if advice_type != "NA":
                            game.play(a)
                            heatmap[game.location[0]][
                                game.location[1]] = heatmap[game.location[0]][
                                    game.location[1]] + 1
                            r = game.get_score()
                            S_prime = self.get_game_data(game)
                            game_over = game.is_over()
                            transition = [S, a, r, S_prime, game_over]
                            self.memory.remember(*transition)
                            S = S_prime
                    if epoch >= observe:
                        batch = self.memory.get_batch(model=model,
                                                      batch_size=batch_size,
                                                      gamma=gamma)
                        if batch:
                            inputs, targets = batch
                            mtob = model.train_on_batch(inputs, targets)
                            if mtob > m_loss:
                                m_loss = mtob
                            loss += float(mtob)
                            #print( "LOSS: {} CULM_LOSS: {}".format(mtob,loss))
                    if checkpoint and (savedModel == False) and (
                        (epoch + 1 - observe) % checkpoint == 0
                            or epoch + 1 == nb_epoch):
                        print("Checkpoint... saving model..")
                        if advice_type == "OA":
                            model.save('oa_model.h5')
                        if advice_type == "NA":
                            model.save('na_model.h5')
                        if advice_type == "RL":
                            model.save('rl_model.h5')
                        # model_json = model.to_json()
                        # with open("model.json", "w") as json_file:
                        #    json_file.write(model_json)
                        # #serialize weights to HDF5
                        # model.save_weights("model.h5")
                        savedModel = True
                if game.is_won():
                    win_count += 1
                    rolling_win_window.insert(0, 1)
                else:
                    rolling_win_window.insert(0, 0)
                # Note: as written, the CSV logging below only runs while
                # epsilon is still annealing.
                if epsilon > final_epsilon and epoch >= observe:
                    epsilon -= delta
                    percent_win = 0
                    cdt = datetime.datetime.now()
                    if sum(rolling_win_window) != 0:
                        percent_win = sum(rolling_win_window) / 4  # recent win rate
                    total_frames = total_frames + game.stepsTaken
                    f.write(
                        '{},{},{},{},{},{},{},{},{},{},{},{},{},{}\n'.format(
                            session_id, advice_type, meta_advice_type,
                            str(cdt), (epoch + 1), total_frames, game.score,
                            percent_win, epsilon, loss, game.stepsTaken,
                            adviceGiven, adviceAttempts, modelActions))
                    f.flush()
                    print(
                        "Session: {} | Time: {} | Epoch {:03d}/{:03d} | Steps {:.4f} | Epsilon {:.2f} | Score {} | Loss {}"
                        .format(session_id, str(cdt), epoch + 1, nb_epoch,
                                game.stepsTaken, epsilon, game.score, loss))
                    if len(rolling_win_window) > 4:
                        rolling_win_window.pop()
                    time.sleep(1.0)

            if advice_type == "OA":
                with open("{}OAheatxtues.csv".format(session_id), 'w+') as f2:
                    csvWriter = csv.writer(f2, delimiter=',')
                    csvWriter.writerows(heatmap)
                #heatmap = [ [0]*20 for i in range(20)]
            if advice_type == "RL":
                with open("{}RLheatxtues.csv".format(session_id), 'w+') as f2:
                    csvWriter = csv.writer(f2, delimiter=',')
                    csvWriter.writerows(heatmap)
                #heatmap = [ [0]*20 for i in range(20)]
            if advice_type == "NA":
                with open("{}NAheatxtues.csv".format(session_id), 'w+') as f2:
                    csvWriter = csv.writer(f2, delimiter=',')
                    csvWriter.writerows(heatmap)
                #heatmap = [ [0]*20 for i in range(20)]

    def play(self, game, nb_epoch=10, epsilon=0., visualize=False):
        self.check_game_compatibility(game)
        model = self.model
        win_count = 0
        frames = []
        for epoch in range(nb_epoch):
            print("Playing")
            game.reset()
            self.clear_frames()
            S = self.get_game_data(game)
            if visualize:
                frames.append(game.draw())
            game_over = False
            while not game_over:
                if np.random.rand() < epsilon:
                    print("random")
                    action = int(np.random.randint(0, game.nb_actions))
                else:
                    q = model.predict(S)[0]
                    possible_actions = game.get_possible_actions()
                    q = [q[i] for i in possible_actions]
                    action = possible_actions[np.argmax(q)]
                print(action)
                game.play(action)
                S = self.get_game_data(game)
                if visualize:
                    frames.append(game.draw())
                game_over = game.is_over()
            if game.is_won():
                win_count += 1
        print("Accuracy {} %".format(100. * win_count / nb_epoch))
        #Visualizing/printing images is currently super slow
        if visualize:
            if 'images' not in os.listdir('.'):
                os.mkdir('images')
            for i in range(len(frames)):
                plt.imshow(frames[i], interpolation='none')
                plt.savefig("images/" + game.name + str(i) + ".png")
Example #2
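# Assumed imports for this snippet: numpy as np, plus the project's
# ExperienceReplay, Save and Config helpers.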
class Agent:
    def __init__(self, model, memory=None, memory_size=100, nb_frames=None):
        assert len(
            model.output_shape
        ) == 2, "Model's output shape should be (nb_samples, nb_actions)."
        if memory:
            self.memory = memory
        else:
            self.memory = ExperienceReplay(memory_size)
        if not nb_frames and not model.input_shape[1]:
            raise Exception("Missing argument : nb_frames not provided")
        elif not nb_frames:
            nb_frames = model.input_shape[1]
        elif model.input_shape[
                1] and nb_frames and model.input_shape[1] != nb_frames:
            raise Exception(
                "Dimension mismatch: time dimension of model should equal nb_frames."
            )
        self.model = model
        self.nb_frames = nb_frames  # time dimension of the model input (24 here)
        self.frames = None

    @property
    def memory_size(self):
        return self.memory.memory_size

    @memory_size.setter
    def memory_size(self, value):
        self.memory.memory_size = value

    def reset_memory(self):
        self.memory.reset_memory()

    def check_game_compatibility(self, game):
        game_output_shape = (1, None) + game.get_frame().shape
        if len(game_output_shape) != len(self.model.input_shape):
            raise Exception(
                'Dimension mismatch. Input shape of the model should be compatible with the game.'
            )
        else:
            for i in range(len(self.model.input_shape)):
                if self.model.input_shape[i] and game_output_shape[
                        i] and self.model.input_shape[i] != game_output_shape[
                            i]:
                    raise Exception(
                        'Dimension mismatch. Input shape of the model should be compatible with the game.'
                    )
        if len(self.model.output_shape
               ) != 2 or self.model.output_shape[1] != game.nb_actions:
            raise Exception(
                'Output shape of model should be (nb_samples, nb_actions).')

    def get_game_data(self, game):  # returns scaled
        frame = game.get_frame()  # candidate to return scaled
        if self.frames is None:
            self.frames = [frame] * self.nb_frames
        else:
            self.frames.append(frame)
            self.frames.pop(0)
        return np.expand_dims(self.frames, 0)

    def clear_frames(self):
        self.frames = None

    def train(self,
              game,
              nb_epoch=1000,
              batch_size=50,
              gamma=0.9,
              epsilon=[1., .1],
              epsilon_rate=0.5,
              reset_memory=False,
              observe=0,
              checkpoint=None):
        self.check_game_compatibility(game)
        if type(epsilon) in {tuple, list}:
            delta = ((epsilon[0] - epsilon[1]) / (nb_epoch * epsilon_rate))
            final_epsilon = epsilon[1]
            epsilon = epsilon[0]
        else:
            final_epsilon = epsilon
        save = Save()
        model = self.model
        nb_actions = model.output_shape[-1]
        win_count = 0
        for epoch in range(nb_epoch):
            loss = 0.
            q = np.zeros(3)
            game.reset()
            self.clear_frames()
            if reset_memory:
                self.reset_memory()
            game_over = False
            S = self.get_game_data(game)  # S must be scaled
            i = 0
            while not game_over:
                i = i + 1
                if np.random.random() < epsilon or epoch < observe:
                    a = int(np.random.randint(game.nb_actions))
                    print('>', end='')  # mark a random action, no newline
                else:
                    # S must be scaled
                    q = model.predict(S)  # !
                    a = int(np.argmax(q[0]))
                game.play(a)
                r = game.get_score(a)
                S_prime = self.get_game_data(game)  # S_prime must be scaled
                game_over = game.is_over()
                # S, a, S_prime, must be scaled
                # reward, game over is not scaled in catch/snake
                transition = [S, a, r, S_prime, game_over]  # !
                self.memory.remember(*transition)
                S = S_prime
                if epoch >= observe:
                    batch = self.memory.get_batch(model=model,
                                                  batch_size=batch_size,
                                                  gamma=gamma)
                    if batch:
                        inputs, targets = batch  # scaled
                        loss += float(model.train_on_batch(inputs, targets))

                save.log(game, epoch)

            if game.is_won():
                win_count += 1
            if epsilon > final_epsilon and epoch >= observe:
                epsilon -= delta
            print(' ')
            print(
                "Epoch {:03d}/{:03d} | Loss {:.4f} | Epsilon {:.2f} | Win count {} | loss Avg {:.4f}"
                .format(epoch + 1, nb_epoch, loss, epsilon, win_count,
                        loss / i))

            if epoch % 10 == 0:
                save.save_model(model, Config.f_model)
            save.log_epoch(loss, win_count, loss / i)

    def play(self, game, nb_epoch=10, epsilon=0., visualize=True):
        self.check_game_compatibility(game)
        model = self.model
        win_count = 0
        frames = []
        save = Save()
        for epoch in range(nb_epoch):
            game.reset()
            self.clear_frames()
            S = self.get_game_data(game)  # S must be scaled
            if visualize:
                frames.append(game.draw())
            game_over = False
            while not game_over:
                if np.random.rand() < epsilon:
                    print("random")
                    action = int(np.random.randint(0, game.nb_actions))
                else:
                    # S must be scaled
                    q = model.predict(S)[0]  # !
                    possible_actions = game.get_possible_actions()
                    q = [q[i] for i in possible_actions]
                    action = possible_actions[np.argmax(q)]

                game.play(action)

                S = self.get_game_data(game)
                if visualize:
                    frames.append(game.draw())
                # Without this update the while-loop never terminates.
                game_over = game.is_over()
                save.log(game, epoch)

            if game.is_won():
                win_count += 1
        print("Accuracy {} %".format(100. * win_count / nb_epoch))
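
A note on the annealing arithmetic in train(): with epsilon=[1., .1], nb_epoch=1000 and epsilon_rate=0.5, delta = (1.0 - 0.1) / (1000 * 0.5) = 0.0018, so epsilon falls linearly from 1.0 to 0.1 over the first half of training and then holds. A standalone sketch of the same schedule (the max() clamp is added here; the loop above relies on its epsilon > final_epsilon guard instead):

def epsilon_schedule(nb_epoch, epsilon=(1., .1), epsilon_rate=0.5):
    # Yield the exploration rate used at each epoch.
    delta = (epsilon[0] - epsilon[1]) / (nb_epoch * epsilon_rate)
    eps, final = epsilon
    for _ in range(nb_epoch):
        yield eps
        eps = max(final, eps - delta)

# list(epsilon_schedule(1000))[500] -> 0.1 (up to float rounding)
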
Example #3
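# Assumed imports for this snippet: numpy as np and the project's ExperienceReplay.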
class Agent:

	def __init__(self, model, memory=None, memory_size=1000, nb_frames=None):
		assert len(model.output_shape) == 2, "Model's output shape should be (nb_samples, nb_actions)."
		if memory:
			self.memory = memory
		else:
			self.memory = ExperienceReplay(memory_size)
		if not nb_frames and not model.input_shape[1]:
			raise Exception("Missing argument: nb_frames not provided")
		elif not nb_frames:
			nb_frames = model.input_shape[1]
		elif model.input_shape[1] and nb_frames and model.input_shape[1] != nb_frames:
			raise Exception("Dimension mismatch : time dimension of model should be equal to nb_frames.")
		self.model = model
		self.nb_frames = nb_frames
		self.frames = None

	@property
	def memory_size(self):
		return self.memory.memory_size

	@memory_size.setter
	def memory_size(self, value):
		self.memory.memory_size = value

	def reset_memory(self):
		self.memory.reset_memory()

	def check_game_compatibility(self, game):
		game_output_shape = (1, None) + game.get_frame().shape
		if len(game_output_shape) != len(self.model.input_shape):
			raise Exception('Dimension mismatch. Input shape of the model should be compatible with the game.')
		else:
			for i in range(len(self.model.input_shape)):
				if self.model.input_shape[i] and game_output_shape[i] and self.model.input_shape[i] != game_output_shape[i]:
					raise Exception('Dimension mismatch. Input shape of the model should be compatible with the game.')
		if len(self.model.output_shape) != 2 or self.model.output_shape[1] != game.nb_actions:
			raise Exception('Output shape of model should be (nb_samples, nb_actions).')

	def get_game_data(self, game):
		frame = game.get_frame()
		if self.frames is None:
			self.frames = [frame] * self.nb_frames
		else:
			self.frames.append(frame)
			self.frames.pop(0)
		return np.expand_dims(self.frames, 0)

	def clear_frames(self):
		self.frames = None

	def action_count(self, game):
		#print "game.get_action_count: ", game.get_action_count
		return game.get_action_count

	# SET WHICH RUNS TO PRINT OUT HERE *****************************************************************
	def report_action(self, game):
		return (self.action_count(game) % self.report_freq) == 0

	def train(self, game, nb_epoch=1000, batch_size=50, gamma=0.9, epsilon=[1., .1], epsilon_rate=0.5, reset_memory=False, id=""):

		txt_store_path = "./txtstore/run_1000e_b50_15r_reg_lr1/junk/"
		printing = False
		record_weights = False
		self.max_moves = game.get_max_moves()
		self.report_freq = self.max_moves  # 50


		self.check_game_compatibility(game)
		if type(epsilon)  in {tuple, list}:
			delta =  ((epsilon[0] - epsilon[1]) / (nb_epoch * epsilon_rate))
			final_epsilon = epsilon[1]
			epsilon = epsilon[0]
		else:
			final_epsilon = epsilon
		model = self.model
		nb_actions = model.output_shape[-1]
		win_count = 0

		scores = np.zeros((nb_epoch, self.max_moves / self.report_freq))
		losses = np.zeros((nb_epoch, self.max_moves / self.report_freq))


		for epoch in range(nb_epoch):
			loss = 0.
			game.reset()
			self.clear_frames()
			if reset_memory:
				self.reset_memory()
			game_over = False
			S = self.get_game_data(game)
			no_last_S = True

			plot_showing = False

			while not game_over:
				if np.random.random() < epsilon:
					a = int(np.random.randint(game.nb_actions))
					#if (self.action_count(game) % 100000) == 0:
					'''if self.report_action(game):
						if printing:
							print "random",
						q = model.predict(S)'''
					q = model.predict(S)
					expected_action = (a == int(np.argmax(q[0])))
				else:
					expected_action = True
					q = model.predict(S)
					#print q.shape
					#print q[0]
					# ************************************** CATCHING NANS
					'''if (q[0,0] != q[0,0]):
						ipdb.set_trace(context=9)	# TRACING HERE *********************************************
					'''
					a = int(np.argmax(q[0]))
					#if (self.action_count(game) % 100000) == 0:
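				# Probability of the chosen action under the epsilon-greedy policy:
				# epsilon/nb_actions for any action, plus (1 - epsilon) if it is
				# the greedy one. Stored with the transition for the
				# probability-weighted (ruql=True) replay update below.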
				prob = epsilon/game.nb_actions
				if expected_action:
					prob = 1 - epsilon + prob
				game.play(a, self.report_action(game))
				r = game.get_score()
				#ipdb.set_trace(context=9)	# TRACING HERE *********************************************


				# PRINTING S HERE ******************************************************************

				''' if plot_showing:
					plt.clf()
				plt.imshow(np.reshape(S,(6,6)))
				plt.draw()
				plt.show(block=False)
				plot_showing = True
				print "hi" '''

				# PRINTING S HERE ******************************************************************

				S_prime = self.get_game_data(game)



				'''if self.report_action(game):
					if printing:
						print "S: ", S
						#if no_last_S:
						#	last_S = S
						#	no_last_S = False
						#else:
						#	print "dS:", S - last_S
						#	print "    ==>  Q(lS):", model.predict(last_S)
						#print
						print "    ==>  Q(S): ", q, "    ==>  A: ", a, "    ==> R: %f" % r
						#print "    ==>  Q(S'):", model.predict(S_prime)
						#print
					fo_S.seek(0,2)
					np.savetxt(fo_S, S[0], fmt='%4.4f') #
					fo_Q.seek(0,2)
					np.savetxt(fo_Q, q, fmt='%4.4f') #
					fo_A.seek(0,2)
					fo_A.write(str(a)+"\n") #savetxt(fo, S[0], fmt='%4.4f') #
					fo_R.seek(0,2)
					fo_R.write(str(r)+"\n")
				'''

				#ipdb.set_trace(context=9)	# TRACING HERE *********************************************


				#last_S = S

				game_over = game.is_over()
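				# One transition: state, action, reward, next state, terminal
				# flag and the behaviour probability of the action.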
				transition = [S, a, r, S_prime, game_over, prob]
				self.memory.remember(*transition)
				S = S_prime
				batch = self.memory.get_batch(model=model, batch_size=batch_size, gamma=gamma, ruql=True) #, print_it=False) #self.report_action(game))
				if batch:
					inputs, targets, probs = batch

					#print("model.total_loss: ", model.total_loss)
					'''if record_weights:

						weights_pre = model.get_weights() # GOT WEIGHTS *************************
						#print "weights_pre"
						#print weights_pre

						if self.report_action(game):
							fo_W.seek(0,2)
							np.savetxt(fo_W, weights_pre[0], fmt='%4.4f') #
							fo_W.write("\n")
							fo_Wb.seek(0,2)
							np.savetxt(fo_Wb, weights_pre[1], fmt='%4.4f') #
							fo_Wb.write("\n")'''

					#output = model.train_on_batch(inputs, targets)
					#loss += float(output[0]) #model.train_on_batch(inputs, targets))
					'''print "myAgent"
					print 'inputs: ', type(inputs), "; ", inputs.shape 
					print 'targets: ', type(targets), "; ", targets.shape
					print 'probs: ', type(probs), "; ", probs.shape'''
					loss += float(model.train_on_batch(inputs, targets, probs=probs))

					#if self.report_action(game):
					#	#print output
					#	#fo_G.seek(0,2)
					#	#np.savetxt(fo_G, output[1], fmt='%4.4f') #
					#	#fo_G.write("\n")
					#	#fo_Gb.seek(0,2)
					#	#np.savetxt(fo_Gb, output[2], fmt='%4.4f') #
					#	#fo_Gb.write("\n")

					#weights_post = model.get_weights() # GOT WEIGHTS ********************************
					#print "weights_post"
					#print weights_post
					#ipdb.set_trace()	# TRACING HERE *********************************************

					#print("action_count PRE: ", action_count)
					if self.report_action(game):
						action_count = self.action_count(game)
						#print("action_count/self.report_freq: ", action_count/self.report_freq)
						#print("action_count: ", action_count)
						#print("self.report_freq: ", self.report_freq)
						#print("scores so far: ", scores)
						#print("scores.shape: ", scores.shape)'''
						while (action_count // self.report_freq > scores.shape[1]):
							scores = np.append(scores, np.zeros((nb_epoch, 1)), 1)
							losses = np.append(losses, np.zeros((nb_epoch, 1)), 1)
						scores[epoch, action_count // self.report_freq - 1] = game.get_total_score()
						losses[epoch, action_count // self.report_freq - 1] = loss

						#print ("running a batch (of %d): 1: %d; 2: %d" % (len(batch), batch[0].size, \
						#	batch[1].size))
						#print "memory size: ", self.memory_size
						#print "using memory\n", inputs, "; tgt: ", targets
						#fo_I.seek(0,2)
						#np.savetxt(fo_I, inputs[0], fmt='%4.4f') #
						#fo_T.seek(0,2)
						#np.savetxt(fo_T, targets, fmt='%4.4f') #
					#fo_T.write("\n")
			if game.is_won():
				win_count += 1
			if epsilon > final_epsilon:
				epsilon -= delta
			if (epoch % 50) == 0:
				print("Epoch {:03d}/{:03d} | Loss {:.4f} | Epsilon {:.2f} | Win count {}".format(epoch + 1, nb_epoch, loss, epsilon, win_count))
		pickle.dump(scores, open(txt_store_path + "score" + id + ".p", "wb"))
		pickle.dump(losses, open(txt_store_path + "loss" + id + ".p", "wb"))
		'''
		fo_A.close()
		fo_G.close()
		fo_Gb.close()
		fo_I.close()
		fo_Q.close()
		fo_R.close()
		fo_S.close()
		fo_T.close()
		fo_W.close()
		fo_Wb.close()'''

		average_taken_over = 10
		last_col = self.max_moves // self.report_freq - 1

		fo_log = open("log.txt", "a")

		average_score = np.mean(scores[-average_taken_over:nb_epoch, last_col])
		average_error = np.mean(losses[-average_taken_over:nb_epoch, last_col])

		fo_log.write("\n{:20}|{:^12}|{:^10}|{:^10}|{:^6}|{:^12}|{:^12}|{:^12}{:^6}{:^6}|{:^10}|{:^20}|{:^10}|{:^6}".format(" ", "game moves", "avg score", "error", "WC", "epochs", "batch size", "epsilon from", ".. to", ".. by", "lr", "description", "timer", "reg"))
		fo_log.write("\n{:<20}|{:^12d}|{:^10.2f}|{:^10.2f}|{:^6d}|".format(time.strftime("%d/%m/%Y %H:%M"), self.max_moves, \
			average_score, average_error, win_count)) #average_taken_over,
		fo_log.close()


	def play(self, game, nb_epoch=1, epsilon=0., visualize=False):
		self.check_game_compatibility(game)
		model = self.model
		win_count = 0
		frames = []
		for epoch in range(nb_epoch):
			game.reset()
			self.clear_frames()
			S = self.get_game_data(game)
			if visualize:
				frames.append(game.draw())
			game_over = False
			while not game_over:
				if np.random.rand() < epsilon:
					print("random")
					action = int(np.random.randint(0, game.nb_actions))
				else:
					q = model.predict(S)
					action = int(np.argmax(q[0]))
				game.play(action)
				S = self.get_game_data(game)
				if visualize:
					frames.append(game.draw())
				game_over = game.is_over()
			if game.is_won():
				win_count += 1
		print("Accuracy {} %".format(100. * win_count / nb_epoch))
		if visualize:
			if 'images' not in os.listdir('.'):
				os.mkdir('images')
			for i in range(len(frames)):
				plt.imshow(frames[i], interpolation='none')
				plt.savefig("images/" + game.name + str(i) + ".png")
Example #4
class Agent:
    def __init__(self, model, memory=None, memory_size=1000, nb_frames=None):
        assert len(
            model.output_shape
        ) == 2, "Model's output shape should be (nb_samples, nb_actions)."
        if memory:
            self.memory = memory
        else:
            self.memory = ExperienceReplay(memory_size)
        if not nb_frames and not model.input_shape[1]:
            raise Exception("Missing argument : nb_frames not provided")
        elif not nb_frames:
            nb_frames = model.input_shape[1]
        elif model.input_shape[
                1] and nb_frames and model.input_shape[1] != nb_frames:
            raise Exception(
                "Dimension mismatch : time dimension of model should be equal to nb_frames."
            )
        self.model = model
        self.nb_frames = nb_frames
        self.frames = None

    @property
    def memory_size(self):
        return self.memory.memory_size

    @memory_size.setter
    def memory_size(self, value):
        self.memory.memory_size = value

    def reset_memory(self):
        self.memory.reset_memory()

    def check_game_compatibility(self, game):
        game_output_shape = (1, None) + game.get_frame().shape
        if len(game_output_shape) != len(self.model.input_shape):
            raise Exception(
                'Dimension mismatch. Input shape of the model should be compatible with the game.'
            )
        else:
            for i in range(len(self.model.input_shape)):
                if self.model.input_shape[i] and game_output_shape[
                        i] and self.model.input_shape[i] != game_output_shape[
                            i]:
                    raise Exception(
                        'Dimension mismatch. Input shape of the model should be compatible with the game.'
                    )
        if len(self.model.output_shape
               ) != 2 or self.model.output_shape[1] != game.nb_actions:
            raise Exception(
                'Output shape of model should be (nb_samples, nb_actions).')

    def get_game_data(self, game):
        frame = game.get_frame()
        if self.frames is None:
            self.frames = [frame] * self.nb_frames
        else:
            self.frames.append(frame)
            self.frames.pop(0)
        return np.expand_dims(self.frames, 0)

    def clear_frames(self):
        self.frames = None

    def train(self,
              game,
              nb_epoch=1000,
              batch_size=50,
              gamma=0.9,
              epsilon=[1., .1],
              epsilon_rate=0.5,
              reset_memory=False):
        self.check_game_compatibility(game)
        if type(epsilon) in {tuple, list}:
            delta = ((epsilon[0] - epsilon[1]) / (nb_epoch * epsilon_rate))
            final_epsilon = epsilon[1]
            epsilon = epsilon[0]
        else:
            final_epsilon = epsilon
        model = self.model
        nb_actions = model.output_shape[-1]
        win_count = 0
        for epoch in range(nb_epoch):
            loss = 0.
            game.reset()
            self.clear_frames()
            if reset_memory:
                self.reset_memory()
            game_over = False
            S = self.get_game_data(game)
            while not game_over:
                if np.random.random() < epsilon:
                    a = int(np.random.randint(game.nb_actions))
                else:
                    q = model.predict(S)
                    a = int(np.argmax(q[0]))
                game.play(a)
                r = game.get_score()
                S_prime = self.get_game_data(game)
                game_over = game.is_over()
                transition = [S, a, r, S_prime, game_over]
                self.memory.remember(*transition)
                S = S_prime
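                # Train on a replay batch once the memory can supply one
                # (the guarded pattern the other examples here use).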
                batch = self.memory.get_batch(model=model,
                                              batch_size=batch_size,
                                              gamma=gamma)
                if batch:
                    inputs, targets = batch
                    loss += float(model.train_on_batch(inputs, targets))
            if game.is_won():
                win_count += 1
            if epsilon > final_epsilon:
                epsilon -= delta
            print(
                "Epoch {:03d}/{:03d} | Loss {:.4f} | Epsilon {:.2f} | Win count {}"
                .format(epoch + 1, nb_epoch, loss, epsilon, win_count))

    def play(self, game, nb_epoch=10, epsilon=0., visualize=True):
        self.check_game_compatibility(game)
        model = self.model
        win_count = 0
        frames = []
        for epoch in range(nb_epoch):
            game.reset()
            self.clear_frames()
            S = self.get_game_data(game)
            if visualize:
                frames.append(game.draw())
            game_over = False
            while not game_over:
                if np.random.rand() < epsilon:
                    print("random")
                    action = int(np.random.randint(0, game.nb_actions))
                else:
                    q = model.predict(S)
                    action = int(np.argmax(q[0]))
                game.play(action)
                S = self.get_game_data(game)
                if visualize:
                    frames.append(game.draw())
                game_over = game.is_over()
            if game.is_won():
                win_count += 1
        print("Accuracy {} %".format(100. * win_count / nb_epoch))
        if visualize:
            if 'images' not in os.listdir('.'):
                os.mkdir('images')
            for i in range(len(frames)):
                plt.imshow(frames[i], interpolation='none')
                plt.savefig("images/" + game.name + str(i) + ".png")
Example #5
class Agent:
  def __init__(self, game, mode=SIMPLE, nb_epoch=10000, memory_size=1000, batch_size=50, nb_frames=4, epsilon=1., discount=.9, learning_rate=.1, model=None):

    self.game = game
    self.mode = mode
    self.target_model = None
    self.rows, self.columns = game.field_shape()
    self.nb_epoch = nb_epoch
    self.nb_frames = nb_frames
    self.nb_actions = game.nb_actions()

    if mode == TEST:
      print('Test Mode: Loading model...')
      self.model = load_model(model)
    elif mode == SIMPLE:
      print('Using Plain DQN: Building model...')
      self.model = self.build_model()
    elif mode == DOUBLE:
      print('Using Double DQN: Building primary and target model...')
      self.model = self.build_model()
      self.target_model = self.build_model()
      self.update_target_model()

    # Trades off the importance of sooner versus later rewards.
    # A factor of 0 makes the agent consider only immediate
    # rewards, while a factor of 1 makes it strive for a
    # long-term high reward.
    self.discount = discount

    # The learning rate or step size determines to what extent newly
    # acquired information overrides old information. A factor of 0
    # makes the agent learn nothing, while a factor of 1 makes it
    # consider only the most recent information.
    self.learning_rate = learning_rate
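    # Both appear in the tabular Q-learning update:
    #   Q(s,a) <- Q(s,a) + learning_rate * (r + discount * max_a' Q(s',a') - Q(s,a))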

    # Use epsilon-greedy exploration as our policy.
    # Epsilon determines the probability of choosing a random action
    # and decreases linearly with the number of epochs. Without this
    # policy the network is greedy and settles for the first effective
    # strategy it finds, so we introduce a certain randomness.
    # Epsilon reaches its minimum halfway through the games.
    epsilon_end = self.nb_epoch - (self.nb_epoch // 2)
    self.policy = EpsGreedyPolicy(self.model, epsilon_end, self.nb_actions, epsilon, .1)
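    # With the defaults (nb_epoch=10000, epsilon=1., minimum .1), epsilon
    # falls linearly from 1.0 to 0.1 over the first 5000 epochs.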

    # Create new experience replay memory. Without this optimization
    # training takes extremely long even on a GPU and, most
    # importantly, approximating Q-values with a non-linear function
    # such as our network is not very stable.
    self.memory = ExperienceReplay(self.model, self.target_model, self.nb_actions, memory_size, batch_size, self.discount, self.learning_rate)

    self.frames = None

  def build_model(self):
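    # Channels-first input: one channel per stacked frame. Three conv
    # layers feed a dense head emitting one Q-value per action (linear output).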
    model = Sequential()
    model.add(Conv2D(32, (2, 2), activation='relu', input_shape=(self.nb_frames, self.rows, self.columns), data_format="channels_first"))
    model.add(Conv2D(64, (2, 2), activation='relu'))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(Flatten())
    model.add(Dropout(0.1))
    model.add(Dense(512, activation='relu'))
    model.add(Dense(self.nb_actions))
    model.compile(Adam(), 'MSE')

    return model

  def update_target_model(self):
    self.target_model.set_weights(self.model.get_weights())

  def get_frames(self):
    frame = self.game.get_state()
    if self.frames is None:
      self.frames = [frame] * self.nb_frames
    else:
      self.frames.append(frame)
      self.frames.pop(0)

    # Expand frames to match the input shape for the CNN (4D)
    # 1D      = # batches
    # 2D      = # frames per batch
    # 3D / 4D = game board
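    # e.g. a 10x10 board with nb_frames=4 yields shape (1, 4, 10, 10)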
    return np.expand_dims(self.frames, 0)

  def clear_frames(self):
    self.frames = None

  def print_stats(self, data, y_label, x_label='Epoch', marker='-'):
    data = np.array(data)
    x, y = data.T
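    # Fit a cubic polynomial to overlay a smoothed trend line on the raw data.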
    p = np.polyfit(x, y, 3)

    fig = plt.figure()

    plt.plot(x, y, marker)
    plt.plot(x, np.polyval(p, x), 'r:')
    plt.xlabel(x_label)
    plt.ylabel(y_label)

    words = y_label.split()
    file_name = '_'.join(map(lambda x: x.lower(), words))
    path = './plots/{name}_{size}x{size}_{timestamp}'
    fig.savefig(path.format(size=self.game.grid_size, name=file_name, timestamp=int(time())))

  def train(self, update_freq=10):
    total_steps = 0
    max_steps = self.game.grid_size**2 * 3
    loops = 0
    nb_wins = 0
    cumulative_reward = 0
    duration_buffer = []
    reward_buffer = []
    steps_buffer = []
    wins_buffer = []

    for epoch in range(self.nb_epoch):
      loss = 0.
      duration = 0
      steps = 0

      self.game.reset()
      self.clear_frames()
      done = False

      # Observe the initial state
      state_t = self.get_frames()

      start_time = time()

      while(not done):
        # Explore or Exploit
        action = self.policy.select_action(state_t, epoch)

        # Act on the environment
        _, reward, done, is_victory = self.game.act(action)
        state_tn = self.get_frames()

        cumulative_reward += reward
        steps += 1
        total_steps += 1

        if steps == max_steps and not done:
          loops += 1
          done = True

        # Build transition and remember it (Experience Replay)
        transition = [state_t, action, reward, state_tn, done]
        self.memory.remember(*transition)
        state_t = state_tn

        # Get batch of batch_size samples
        # A batch generally approximates the distribution of the input data
        # better than a single input. The larger the batch, the better the
        # approximation. However, larger batches take longer to process.
        batch = self.memory.get_batch()

        if batch:
          inputs, targets = batch
          loss += float(self.model.train_on_batch(inputs, targets))

        if self.game.is_victory():
          nb_wins += 1

        if done:
          duration = utils.get_time_difference(start_time, time())

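        # Periodically sync the target network with the online network to
        # stabilise the bootstrapped Q-targets (Double DQN).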
        if self.mode == DOUBLE and self.target_model is not None and total_steps % (update_freq) == 0:
          self.update_target_model()

      current_epoch = epoch + 1
      reward_buffer.append([current_epoch, cumulative_reward])
      duration_buffer.append([current_epoch, duration])
      steps_buffer.append([current_epoch, steps])
      wins_buffer.append([current_epoch, nb_wins])

      summary = 'Epoch {:03d}/{:03d} | Loss {:.4f} | Epsilon {:.2f} | Time(ms) {:3.3f} | Steps {:.2f} | Wins {} | Loops {}'
      print(summary.format(current_epoch, self.nb_epoch, loss, self.policy.get_eps(), duration, steps, nb_wins, loops))

    # Generate plots
    self.print_stats(reward_buffer, 'Cumulative Reward')
    self.print_stats(duration_buffer, 'Duration per Game')
    self.print_stats(steps_buffer, 'Steps per Game')
    self.print_stats(wins_buffer, 'Wins')

    path = './models/model_{mode}_{size}x{size}_{epochs}_{timestamp}.h5'
    mode = 'dqn' if self.mode == SIMPLE else 'ddqn'
    self.model.save(path.format(mode=mode, size=self.game.grid_size, epochs=self.nb_epoch, timestamp=int(time())))

  def play(self, nb_games=5, interval=.7):
    nb_wins = 0
    accuracy = 0
    summary = '{}\n\nAccuracy {:.2f}% | Game {}/{} | Wins {}'

    for epoch in range(nb_games):
      self.game.reset()
      self.clear_frames()
      done = False

      state_t = self.get_frames()

      self.print_state(summary, state_t[:,-1], accuracy, epoch, nb_games, nb_wins, 0)

      while(not done):
        q = self.model.predict(state_t)
        action = np.argmax(q[0])

        _, _, done, is_victory = self.game.act(action)
        state_tn = self.get_frames()

        state_t = state_tn

        if is_victory:
          nb_wins += 1

        accuracy = 100. * nb_wins / nb_games

        self.print_state(summary, state_t[:,-1], accuracy, epoch, nb_games, nb_wins, interval)

  def print_state(self, summary, state, accuracy, epoch, nb_games, nb_wins, interval):
    utils.clear_screen()
    print(summary.format(state, accuracy, epoch + 1, nb_games, nb_wins))
    sleep(interval)