def test_cos(self):
    hp = hyperparameters.copy()
    hp["embedding_fn"] = "cos"
    la = LearningAgent(
        hyperparameters=hp,
        configuration=configuration,
        env_name="LunarLander-v2",
    )
    (
        quantile_thresholds_ph,
        inner_product,
        shaped_embedding,
        final_embedding,
    ) = la.learning_network.build_quantile_embedding(
        la.placeholders["quantile_thresholds"])
    la.sess.run(tf.global_variables_initializer())
    memories = build_memories(hyperparameters["batch_size"])
    (
        quantile_thresholds_ph_output,
        inner_product_output,
        shaped_embedding_output,
        final_embedding_output,
    ) = la.sess.run([
        quantile_thresholds_ph,
        inner_product,
        shaped_embedding,
        final_embedding,
    ], feed_dict=la.feed_dict_from_training_batch(memories))
    for batch_idx in range(hyperparameters["batch_size"]):
        for quantile_idx in range(hyperparameters["num_quantiles"]):
            for i_idx in range(hyperparameters["embedding_repeat"]):
                np.testing.assert_almost_equal(
                    i_idx * quantile_thresholds_ph_output[batch_idx]
                    [quantile_idx][0] * math.pi,
                    inner_product_output[batch_idx][quantile_idx][i_idx], 5)
                inner_product_output_i = inner_product_output[batch_idx][
                    quantile_idx][i_idx]
                np.testing.assert_almost_equal(
                    shaped_embedding_output[batch_idx][quantile_idx][i_idx],
                    math.cos(inner_product_output_i))
                np.testing.assert_equal(
                    shaped_embedding_output[batch_idx][quantile_idx][i_idx]
                    >= -1, True)
                np.testing.assert_equal(
                    shaped_embedding_output[batch_idx][quantile_idx][i_idx]
                    <= 1, True)
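# For reference, the cosine quantile embedding verified above follows the IQN
# formulation: each sampled threshold tau is expanded into the vector
# [cos(0 * pi * tau), cos(1 * pi * tau), ..., cos((n - 1) * pi * tau)].
# A minimal standalone NumPy sketch of that expansion (independent of the
# LearningAgent/TensorFlow classes above):
import math

import numpy as np

def cos_quantile_embedding(tau, embedding_repeat):
    """Expand one quantile threshold into its cosine embedding vector."""
    i = np.arange(embedding_repeat)
    return np.cos(i * math.pi * tau)  # every entry lies in [-1, 1]

# Example: tau = 0.25 expanded to an 8-dimensional embedding.
embedding = cos_quantile_embedding(tau=0.25, embedding_repeat=8)
assert np.all(embedding >= -1) and np.all(embedding <= 1)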
def test_accuracy_encodings():
    """Test that the encoding is correct for dice targeted by an accuracy."""
    agent = LearningAgent()
    attacker = ship.Ship(name="Attacker",
                         template=ship_templates["Attacker"],
                         upgrades=[],
                         player_number=1)
    three_brace = ship.Ship(name="Triple Brace",
                            template=ship_templates["Triple Brace"],
                            upgrades=[],
                            player_number=2)
    # Make a brace token red
    three_brace.spend_token('brace', ArmadaTypes.green)

    enc_three_brace, world_state = make_encoding(attacker, three_brace,
                                                 "short", agent)

    # Define the offsets for convenience
    token_begin = Encodings.getAttackTokenOffset()
    token_end = token_begin + ArmadaTypes.max_defense_tokens

    # Verify that no tokens are targeted at first
    assert 0.0 == enc_three_brace[token_begin:token_end].sum()

    # Now make a second token red and target the red brace
    three_brace.spend_token('brace', ArmadaTypes.green)
    green_acc_begin = Encodings.getAttackTokenOffset()
    green_acc_end = green_acc_begin + len(ArmadaTypes.defense_tokens)
    red_acc_begin = Encodings.getAttackTokenOffset() + len(
        ArmadaTypes.defense_tokens)
    red_acc_end = red_acc_begin + len(ArmadaTypes.defense_tokens)
    world_state.attack.accuracy_defender_token(
        ArmadaTypes.defense_tokens.index('brace'), ArmadaTypes.red)
    encoding = Encodings.encodeAttackState(world_state)

    # Verify that only the red token has the accuracy flag set
    assert encoding[red_acc_begin +
                    ArmadaTypes.defense_tokens.index('brace')].item() == 1.
    assert encoding[red_acc_begin:red_acc_end].sum().item() == 1.
    assert encoding[green_acc_begin:green_acc_end].sum().item() == 0.

    # Target both remaining green tokens
    world_state.attack.accuracy_defender_token(
        ArmadaTypes.defense_tokens.index('brace'), ArmadaTypes.green)
    world_state.attack.accuracy_defender_token(
        ArmadaTypes.defense_tokens.index('brace'), ArmadaTypes.green)
    encoding = Encodings.encodeAttackState(world_state)

    # Verify that two green and one red brace have the accuracy flag
    assert encoding[red_acc_begin +
                    ArmadaTypes.defense_tokens.index('brace')].item() == 1.
    assert encoding[red_acc_begin:red_acc_end].sum().item() == 1.
    assert encoding[green_acc_begin +
                    ArmadaTypes.defense_tokens.index('brace')].item() == 2.
    assert encoding[green_acc_begin:green_acc_end].sum().item() == 2.
def test_spent_encodings():
    """Test that the encoding is correct for spent defense tokens."""
    agent = LearningAgent()
    attacker = ship.Ship(name="Attacker",
                         template=ship_templates["Attacker"],
                         upgrades=[],
                         player_number=1)
    defender = ship.Ship(name="Defender",
                         template=ship_templates["All Defense Tokens"],
                         upgrades=[],
                         player_number=2)

    encoding, world_state = make_encoding(attacker, defender, "short", agent)

    # The defender and attacker come first, then the accuracied tokens, then
    # the spent tokens.
    spent_begin = 2 * ship.Ship.encodeSize() + 2 * len(
        ArmadaTypes.defense_tokens)
    spent_end = spent_begin + len(ArmadaTypes.defense_tokens)

    # Verify that no tokens are marked spent by default
    assert torch.sum(encoding[spent_begin:spent_end]) == 0.

    # Spend all of the tokens
    for tidx, ttype in enumerate(ArmadaTypes.defense_tokens):
        world_state.attack.defender_spend_token(ttype, 'green')

    encoding = Encodings.encodeAttackState(world_state)
    assert torch.sum(encoding[spent_begin:spent_end]).item() == len(
        ArmadaTypes.defense_tokens)

    # Try spending the tokens at different indices
    for tidx, ttype in enumerate(ArmadaTypes.defense_tokens):
        # Re-encode and then set the token to spent.
        attacker = ship.Ship(name="Attacker",
                             template=ship_templates["Attacker"],
                             upgrades=[],
                             player_number=1)
        defender = ship.Ship(name="Defender",
                             template=ship_templates["All Defense Tokens"],
                             upgrades=[],
                             player_number=2)
        encoding, world_state = make_encoding(attacker, defender, "short",
                                              agent)
        world_state.attack.defender_spend_token(ttype, 'green')
        encoding = Encodings.encodeAttackState(world_state)
        assert torch.sum(encoding[spent_begin:spent_end]).item() == 1.0
        assert encoding[spent_begin:spent_end][tidx].item() == 1.0
def test_range_encodings():
    """Test that the encoding is correct for ranges."""
    agent = LearningAgent()
    attacker = ship.Ship(name="Attacker",
                         template=ship_templates["Attacker"],
                         upgrades=[],
                         player_number=1)
    no_token = ship.Ship(name="No Defense Tokens",
                         template=ship_templates["No Defense Tokens"],
                         upgrades=[],
                         player_number=2)

    range_begin = Encodings.getAttackRangeOffset()
    for offset, attack_range in enumerate(ArmadaTypes.ranges):
        enc_attack = make_encoding(attacker, no_token, attack_range, agent)[0]
        assert torch.sum(enc_attack[range_begin:range_begin +
                                    len(ArmadaTypes.ranges)]) == 1
        assert 1.0 == enc_attack[range_begin + offset].item()
def test_roll_encodings():
    """Test that the encoding is correct for dice pools and faces."""
    agent = LearningAgent()
    attacker = ship.Ship(name="Attacker",
                         template=ship_templates["Attacker"],
                         upgrades=[],
                         player_number=1)
    no_token = ship.Ship(name="No Defense Tokens",
                         template=ship_templates["No Defense Tokens"],
                         upgrades=[],
                         player_number=2)

    dice_begin = Encodings.getAttackDiceOffset()

    # Do 100 trials to ensure everything is working as expected
    _, world_state = make_encoding(attacker, no_token, "short", agent)
    for _ in range(100):
        # Make a random roll and encode the attack state.
        pool_colors, pool_faces = attacker.roll("front", "short")
        attack = world_state.attack
        attack.pool_faces = pool_faces
        attack.pool_colors = pool_colors
        world_state.updateAttack(attack)
        # The dice section is laid out as [color - 3, face - 6].
        enc_attack = Encodings.encodeAttackState(world_state)

        # Try to find a match for each (color, face) pair in the pool by
        # decrementing the corresponding count in the encoding.
        enc_dice = enc_attack[Encodings.getAttackDiceOffset():]
        for face, color in zip(attack.pool_faces, attack.pool_colors):
            assert 0. < enc_dice[Encodings.dieOffset(color=color,
                                                     face=face)].item()
            enc_dice[Encodings.dieOffset(color=color, face=face)] -= 1
        # All dice from the pool should have been matched and there should be
        # no more encoded.
        assert sum(enc_dice) <= 0.
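# The test above treats the dice section of the encoding as a flattened
# (color x face) grid of per-die counts, per the "[ color - 3, face - 6]"
# note. A minimal sketch of such a layout; the face names here are
# placeholders and the real Encodings.dieOffset may order things differently.
DIE_COLORS = ['red', 'blue', 'black']  # 3 colors, per the note above
DIE_FACES = ['face0', 'face1', 'face2', 'face3', 'face4', 'face5']  # 6 faces

def die_offset(color, face):
    """Index into a flattened 3x6 grid of per-(color, face) die counts."""
    return DIE_COLORS.index(color) * len(DIE_FACES) + DIE_FACES.index(face)

# Example: the first and last slots of the 3*6 = 18 entry grid.
assert die_offset('red', 'face0') == 0
assert die_offset('black', 'face5') == 17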
def test_policy_learning(spend_defense_tokens_model,
                         resolve_attack_effects_model):
    """Train a model to produce a better than random choice policy for defense tokens.

    The spend_defense_tokens_model will be used to determine the quality of this
    network's output. There is no update step as there would be in reinforcement
    learning; this just tests the mechanism.

    Returns:
        (nn.module, nn.module): The 'resolve attack effects' and
                                'spend defense tokens' models.
    """
    def_tokens_model, def_errors, def_eval_errors = spend_defense_tokens_model
    def_tokens_model.eval()
    res_attack_model, res_errors, res_eval_errors = resolve_attack_effects_model
    res_attack_model.eval()
    prediction_models = {
        "attack - spend defense tokens": def_tokens_model.eval(),
        "attack - resolve attack effects": res_attack_model.eval()
    }

    # Do the training. Use the prediction model lifetime to create the loss
    # target. The loss will be the difference between the max possible round
    # and the predicted round. For the defense token model a higher round is
    # better, for the attack effect model a lower round is better.
    loss_fn = {
        "attack - spend defense tokens":
        lambda predictions: 7.0 - predictions,
        "attack - resolve attack effects":
        lambda predictions: predictions - 1.
    }

    # Generate a new learning model
    learning_agent = LearningAgent(SeparatePhaseModel())
    # TODO FIXME The learning agent doesn't really have a training mode
    learning_agent.model.train()
    random_agent = RandomAgent()

    training_ships = ["All Defense Tokens", "All Defense Tokens",
                      "Imperial II-class Star Destroyer",
                      "MC80 Command Cruiser",
                      "Assault Frigate Mark II A",
                      "No Shield Ship", "One Shield Ship", "Mega Die Ship"]
    defenders = []
    attackers = []
    for name in training_ships:
        attackers.append(Ship(name=name, template=ship_templates[name],
                              upgrades=[], player_number=1, device='cpu'))
    for name in training_ships:
        defenders.append(Ship(name=name, template=ship_templates[name],
                              upgrades=[], player_number=2, device='cpu'))

    batch_size = 32
    # Remember the training loss values to test for improvement
    losses = {}
    for subphase in ["attack - spend defense tokens",
                     "attack - resolve attack effects"]:
        losses[subphase] = []

    # This gets samples to use for training. We will use a random agent to
    # generate the states. In reinforcement learning the agent would alternate
    # between random actions and actions from the learning agent to balance
    # exploration of the state space with exploitation of learned behavior.
    samples = get_n_examples(
        n_examples=1000, ship_a=attackers, ship_b=defenders,
        agent=random_agent)

    # Get a batch for each subphase
    for subphase in ["attack - spend defense tokens",
                     "attack - resolve attack effects"]:
        # The learning_agent will take in the world state and attack state and
        # will produce a new action encoding. Then the prediction network will
        # take in a new tuple with this action and predict a final round. The
        # difference between the random action and the final round with the
        # network's action is the reward.
        world_size = Encodings.calculateWorldStateSize()
        action_size = Encodings.calculateActionSize(subphase)
        attack_size = Encodings.calculateAttackSize()
        for batch, lifetimes in collect_attack_batches(
                batch_size=batch_size, attacks=samples, subphase=subphase):
            batch = batch.cuda()
            # The learning agent takes in the world state along with the
            # action as the input tensor, so strip out the old action.
            new_batch = torch.cat(
                (batch[:, :world_size], batch[:, world_size + action_size:]),
                dim=1).cuda()
            # TODO FIXME Just make a forward function that takes in a phase name
            new_actions = learning_agent.model.models[subphase](new_batch)
            new_action_state = torch.cat(
                (batch[:, :world_size], new_actions,
                 batch[:, world_size + action_size:]), dim=1)
            action_result = prediction_models[subphase](new_action_state)
            loss = loss_fn[subphase](action_result)
            learning_agent.model.get_optimizer(subphase).zero_grad()
            loss.sum().backward()
            learning_agent.model.get_optimizer(subphase).step()
            with torch.no_grad():
                losses[subphase].append(loss.mean().item())

    # In reinforcement learning there would also be a phase where the
    # prediction models are updated.
    for subphase in ["attack - spend defense tokens",
                     "attack - resolve attack effects"]:
        print(f"losses for {subphase} start with {losses[subphase][0:5]} "
              f"and end with {losses[subphase][-5:]}")
        assert losses[subphase][-1] < losses[subphase][0]

    # TODO FIXME HERE See if the policy networks produce better results than
    # random actions
    learning_agent.model.models.eval()
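# The loss targets above shape the two policies in opposite directions: the
# defender is rewarded for pushing the predicted final round toward the
# maximum (round 7) and the attacker for pulling it toward round 1. A toy
# check of that shaping with standalone tensors:
import torch

predicted_round = torch.tensor([2.0, 6.0])
defense_loss = 7.0 - predicted_round  # smaller when the defender lasts longer
attack_loss = predicted_round - 1.0   # smaller when the attack ends the game early
assert defense_loss[1] < defense_loss[0]
assert attack_loss[0] < attack_loss[1]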
import MySQLdb
import os
import numpy as np

from learning_agent import LearningAgent
from default_hyperparameters import hyperparameters

configuration = {"render": False, "graph": False}
experiment_name = "lunar_lander_linear_exploration"
la = LearningAgent(
    hyperparameters=hyperparameters,
    configuration=configuration,
    env_name="LunarLander-v2",
)
evaluations = la.execute()

db = MySQLdb.connect(
    host="dqn-db-instance.coib1qtynvtw.us-west-2.rds.amazonaws.com",
    user="******",
    passwd=os.environ['MYSQL_PASS'],
    db="dqn_results")
for evaluation_idx, evaluation in enumerate(evaluations):
    cur = db.cursor()
    # Note: building SQL with str.format is injection-prone; a parameterized
    # version is sketched below.
    cur.execute(
        "insert into experiments (label, x1, x2, x3, x4, x5, x6, x7, x8, x9, "
        "x10, y, checkpoint, trainingSteps, agent_name) values ('{0}', '{1}', "
        "'{2}', '{3}', '{4}', '{5}', '{6}', '{7}', '{8}', '{9}', '{10}', "
        "'{11}', '{12}', '{13}', '{14}')".format(
            experiment_name, evaluation, evaluation_idx, 0, 0, 0, 0, 0, 0, 0,
            0, 0, "checkpoint_" + str(evaluation_idx), 0,
            # The original snippet is truncated here; experiment_name is a
            # placeholder for the missing final agent_name value.
            experiment_name))
    db.commit()
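# Building SQL with str.format as above is injection-prone and breaks on
# quoting. A sketch of the same insert using MySQLdb's parameter binding
# (%s placeholders); the final agent_name value is a placeholder since the
# original snippet is truncated before it:
insert_sql = (
    "insert into experiments (label, x1, x2, x3, x4, x5, x6, x7, x8, x9, "
    "x10, y, checkpoint, trainingSteps, agent_name) values "
    "(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)")
cur.execute(insert_sql, (experiment_name, evaluation, evaluation_idx,
                         0, 0, 0, 0, 0, 0, 0, 0, 0,
                         "checkpoint_" + str(evaluation_idx), 0,
                         experiment_name))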
import sys
import time

import numpy as np
import pygame
import scipy.misc

# NesEmulator, LearningAgent, and the GameAdapter classes come from project
# modules not shown here.

class Experiment:
    def __init__(self, args):
        self.epochLength = args.epochLength
        self.epochNumber = args.epochNumber
        self.testLength = args.testLength
        self.initialState = args.initialState
        self.game = args.game
        self.graphics = args.graphics
        self.visible = args.visible
        self.filename = args.filename
        self.repeatAction = args.repeatAction
        self.mergedFrames = args.mergedFrames
        self.testOnly = args.testOnly
        self.agent = LearningAgent(args)
        self.emulator = NesEmulator()
        self.emulator.loadRom('nes/roms/arkanoid.nes')
        self.emulator.getNextFrame()
        self.actions = args.actions
        if self.game == 'space_invaders':
            from space_invaders import GameAdapter
        elif self.game == 'arkanoid':
            from arkanoid import GameAdapter
        else:
            print 'UNKNOWN GAME', self.game
        self.gameAdapter = GameAdapter(self.emulator)
        self.agent.load(self.filename + '.pkl')
        if self.graphics:
            self.initGraphics()

    def runDemonstration(self):
        self.agent.beginTest()
        while True:
            self.emulator.loadState(self.initialState)
            frame = self.preprocess(self.emulator.getNextFrame().reshape(
                (240, 256, 4)))
            action = self.agent.begin(frame)
            self.gameAdapter.reset()
            start = time.time()
            while True:
                self.handleEvents()
                terminal = False
                reward = 0
                for i in range(4):
                    self.emulator.setController1State(self.actions[action])
                    if i >= 3:
                        # Merge the final frame with the previous one so that
                        # flickering sprites remain visible.
                        frame = np.maximum(
                            frame,
                            self.emulator.getNextFrame().reshape(
                                (240, 256, 4)))
                    else:
                        frame = self.emulator.getNextFrame().reshape(
                            (240, 256, 4))
                    reward += self.gameAdapter.update()
                    terminal = terminal or self.gameAdapter.lost()
                self.displayFrame(frame)
                # Sleep to hold the demonstration at roughly 60 fps.
                end = time.time()
                toSleep = 0.01666 - (end - start)
                if toSleep > 0:
                    time.sleep(toSleep)
                start = time.time()
                if terminal:
                    self.agent.end(reward)
                    break
                action = self.agent.tick(self.preprocess(frame), reward)

    def run(self):
        if self.testOnly:
            self.runDemonstration()
            return
        stepsLeft = 0
        for epoch in range(self.epochNumber):
            # stepsLeft may go below 0, so carry the deficit into this epoch.
            stepsLeft = max(0, self.epochLength + stepsLeft)
            print 'Running epoch ', epoch + 1
            while stepsLeft > 0:
                steps = self.runEpisode()
                stepsLeft -= steps
            print 'Starting testing'
            self.agent.beginTest()
            testStepsLeft = self.testLength
            while testStepsLeft > 0:
                steps = self.runEpisode()
                testStepsLeft -= steps
            self.agent.endTest()
            self.agent.save(self.filename + '.pkl')

    def runEpisode(self):
        self.emulator.loadState(self.initialState)
        frame = self.preprocess(self.emulator.getNextFrame().reshape(
            (240, 256, 4)))
        action = self.agent.begin(frame)
        self.gameAdapter.reset()
        steps = 0
        while True:
            self.handleEvents()
            steps += 1
            reward, terminal, frame = self.tick(action)
            if terminal:
                self.agent.end(reward)
                break
            action = self.agent.tick(self.preprocess(frame), reward)
        return steps

    def tick(self, action):
        reward = 0
        terminal = False
        frame = None
        for i in range(self.repeatAction):
            self.emulator.setController1State(self.actions[action])
            if i >= (self.repeatAction - self.mergedFrames):
                frame = np.maximum(
                    frame,
                    self.emulator.getNextFrame().reshape((240, 256, 4)))
            else:
                frame = self.emulator.getNextFrame().reshape((240, 256, 4))
            reward += self.gameAdapter.update()
            terminal = terminal or self.gameAdapter.lost()
        if self.graphics and self.visible:
            self.displayFrame(frame)
        return reward, terminal, frame

    def initGraphics(self):
        self.screen = pygame.display.set_mode((256, 240))

    def closeWindow(self):
        pygame.display.quit()

    def displayFrame(self, frame):
        img = pygame.image.frombuffer(frame.tostring(), (256, 240), 'RGBX')
        self.screen.blit(img, (0, 0))
        pygame.display.flip()

    def handleEvents(self):
        if not self.graphics:
            return
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                sys.exit()
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_q:
                    sys.exit()
                elif event.key == pygame.K_v:
                    self.visible = not self.visible
                elif event.key == pygame.K_p:
                    self.agent.printStuff()

    def preprocess(self, frame):
        # Crop the playing field, downsample to 100x100, and collapse the
        # four-channel frame to grayscale.
        workingFrame = frame[35:230, 5:200, :]
        workingFrame = scipy.misc.imresize(workingFrame, (100, 100))
        workingFrame = np.dot(workingFrame,
                              [0.299, 0.114, 0.587, 0.0]).astype('uint8')
        return workingFrame
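# The tick method above uses the standard NES/Atari flicker workaround:
# repeat each action for several frames and keep a pixelwise max over the
# last mergedFrames frames, since sprites are often drawn only on alternating
# frames. A minimal standalone NumPy sketch of that merge:
from functools import reduce

import numpy as np

def merge_frames(frames, merged_count=2):
    """Pixelwise max over the last merged_count frames of a repeated action."""
    return reduce(np.maximum, frames[-merged_count:])

# A sprite visible only in the final frame still appears in the merge.
frames = [np.zeros((240, 256), dtype=np.uint8) for _ in range(4)]
frames[-1][100, 128] = 255
assert merge_frames(frames)[100, 128] == 255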
def test_red_token_encodings():
    """Test that the encoding is correct for red defense tokens."""
    agent = LearningAgent()
    attacker = ship.Ship(name="Attacker",
                         template=ship_templates["Attacker"],
                         upgrades=[],
                         player_number=1)
    two_brace = ship.Ship(name="Double Brace",
                          template=ship_templates["Double Brace"],
                          upgrades=[],
                          player_number=2)
    two_redirect = ship.Ship(name="Double Redirect",
                             template=ship_templates["Double Redirect"],
                             upgrades=[],
                             player_number=2)
    two_evade = ship.Ship(name="Double Evade",
                          template=ship_templates["Double Evade"],
                          upgrades=[],
                          player_number=2)
    two_contain = ship.Ship(name="Double Contain",
                            template=ship_templates["Double Contain"],
                            upgrades=[],
                            player_number=2)
    two_scatter = ship.Ship(name="Double Scatter",
                            template=ship_templates["Double Scatter"],
                            upgrades=[],
                            player_number=2)

    two_brace.spend_token('brace', ArmadaTypes.green)
    two_brace.spend_token('brace', ArmadaTypes.green)
    two_redirect.spend_token('redirect', ArmadaTypes.green)
    two_redirect.spend_token('redirect', ArmadaTypes.green)
    two_evade.spend_token('evade', ArmadaTypes.green)
    two_evade.spend_token('evade', ArmadaTypes.green)
    two_contain.spend_token('contain', ArmadaTypes.green)
    two_contain.spend_token('contain', ArmadaTypes.green)
    two_scatter.spend_token('scatter', ArmadaTypes.green)
    two_scatter.spend_token('scatter', ArmadaTypes.green)

    enc_two_brace = make_encoding(attacker, two_brace, "short", agent)[0]
    enc_two_redirect = make_encoding(attacker, two_redirect, "short", agent)[0]
    enc_two_evade = make_encoding(attacker, two_evade, "short", agent)[0]
    enc_two_contain = make_encoding(attacker, two_contain, "short", agent)[0]
    enc_two_scatter = make_encoding(attacker, two_scatter, "short", agent)[0]

    # Check the red token section
    ttensor = torch.zeros(len(ArmadaTypes.defense_tokens))
    ttensor[ArmadaTypes.defense_tokens.index("brace")] = 2.0
    assert torch.allclose(get_defender_tokens(enc_two_brace, "red"), ttensor)

    ttensor = torch.zeros(len(ArmadaTypes.defense_tokens))
    ttensor[ArmadaTypes.defense_tokens.index("redirect")] = 2.0
    assert torch.allclose(get_defender_tokens(enc_two_redirect, "red"), ttensor)

    ttensor = torch.zeros(len(ArmadaTypes.defense_tokens))
    ttensor[ArmadaTypes.defense_tokens.index("evade")] = 2.0
    assert torch.allclose(get_defender_tokens(enc_two_evade, "red"), ttensor)

    ttensor = torch.zeros(len(ArmadaTypes.defense_tokens))
    ttensor[ArmadaTypes.defense_tokens.index("contain")] = 2.0
    assert torch.allclose(get_defender_tokens(enc_two_contain, "red"), ttensor)

    ttensor = torch.zeros(len(ArmadaTypes.defense_tokens))
    ttensor[ArmadaTypes.defense_tokens.index("scatter")] = 2.0
    assert torch.allclose(get_defender_tokens(enc_two_scatter, "red"), ttensor)
def test_token_encodings():
    """Test that the encoding is correct for different defense tokens."""
    agent = LearningAgent()
    no_token = ship.Ship(name="No Defense Tokens",
                         template=ship_templates["No Defense Tokens"],
                         upgrades=[],
                         player_number=1)
    one_brace = ship.Ship(name="Single Brace",
                          template=ship_templates["Single Brace"],
                          upgrades=[],
                          player_number=2)
    two_brace = ship.Ship(name="Double Brace",
                          template=ship_templates["Double Brace"],
                          upgrades=[],
                          player_number=3)
    two_redirect = ship.Ship(name="Double Redirect",
                             template=ship_templates["Double Redirect"],
                             upgrades=[],
                             player_number=4)
    two_evade = ship.Ship(name="Double Evade",
                          template=ship_templates["Double Evade"],
                          upgrades=[],
                          player_number=5)
    two_contain = ship.Ship(name="Double Contain",
                            template=ship_templates["Double Contain"],
                            upgrades=[],
                            player_number=6)
    two_scatter = ship.Ship(name="Double Scatter",
                            template=ship_templates["Double Scatter"],
                            upgrades=[],
                            player_number=7)

    # Encode some attack states
    enc_one_brace = make_encoding(no_token, one_brace, "short", agent)[0]
    enc_two_brace = make_encoding(one_brace, two_brace, "short", agent)[0]
    enc_two_redirect = make_encoding(two_brace, two_redirect, "short", agent)[0]
    enc_two_evade = make_encoding(two_redirect, two_evade, "short", agent)[0]
    enc_two_contain = make_encoding(two_evade, two_contain, "short", agent)[0]
    enc_two_scatter = make_encoding(two_contain, two_scatter, "short", agent)[0]

    # Order of tokens in the encoding:
    # token_types = ["evade", "brace", "scatter", "contain", "redirect"]

    # Check the green token section
    ttensor = torch.zeros(len(ArmadaTypes.defense_tokens))
    ttensor[ArmadaTypes.defense_tokens.index("brace")] = 1.0
    assert torch.allclose(get_defender_tokens(enc_one_brace, "green"), ttensor)

    ttensor = torch.zeros(len(ArmadaTypes.defense_tokens))
    ttensor[ArmadaTypes.defense_tokens.index("brace")] = 2.0
    assert torch.allclose(get_defender_tokens(enc_two_brace, "green"), ttensor)

    ttensor = torch.zeros(len(ArmadaTypes.defense_tokens))
    ttensor[ArmadaTypes.defense_tokens.index("redirect")] = 2.0
    assert torch.allclose(get_defender_tokens(enc_two_redirect, "green"), ttensor)

    ttensor = torch.zeros(len(ArmadaTypes.defense_tokens))
    ttensor[ArmadaTypes.defense_tokens.index("evade")] = 2.0
    assert torch.allclose(get_defender_tokens(enc_two_evade, "green"), ttensor)

    ttensor = torch.zeros(len(ArmadaTypes.defense_tokens))
    ttensor[ArmadaTypes.defense_tokens.index("contain")] = 2.0
    assert torch.allclose(get_defender_tokens(enc_two_contain, "green"), ttensor)

    ttensor = torch.zeros(len(ArmadaTypes.defense_tokens))
    ttensor[ArmadaTypes.defense_tokens.index("scatter")] = 2.0
    assert torch.allclose(get_defender_tokens(enc_two_scatter, "green"), ttensor)
# There is a lot of randomness in dice rolls, so our goal will be to train a
# density model. This means that we will train the network to estimate the
# parameters of the distribution that describes the lifetime. In plainer
# words, we will train the network to predict the average lifetime of the
# vessel and the uncertainty of that prediction. We are rolling three kinds
# of dice, so in reality this is the combination of three probability
# distributions, but to simplify things we will estimate it as a single
# normal.
# The loss will be the negative log of the probability of an outcome given a
# predicted mean and standard deviation. This works because as the
# probability approaches zero its negative log approaches infinity. So if the
# network outputs the mean and sigma the loss will be:
# > d = torch.distributions.Normal(mean, sigma)
# > loss = -d.log_prob(result)

# Create a learning agent. This will initialize the model.
# TODO The learning agent should be the one to take random actions if it is
# not using novelty for exploration.
prediction_agent = LearningAgent(ArmadaModel(with_novelty=args.novelty))

# Load a previously trained model for additional training
# TODO FIXME HERE Reloading with the new novelty training stuff is not working
if os.path.isfile(args.filename):
    prediction_agent.model.load(args.filename)

optimizer = prediction_agent.model.get_optimizer("def_tokens")
if args.novelty:
    novelty_optimizer = prediction_agent.model.get_optimizer(
        "def_tokens_novelty")

examples = []
prediction_params = []
batch_target = []
novelties = []
loss_fn = torch.nn.MSELoss()
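# A minimal, self-contained sketch of that negative log-likelihood loss,
# assuming a hypothetical two-output head that predicts the mean and sigma of
# the lifetime (softplus keeps sigma positive); these names are illustrative,
# not the ArmadaModel interface:
import torch

head = torch.nn.Linear(16, 2)       # hypothetical prediction head
features = torch.randn(8, 16)       # hypothetical encoded attack states
result = torch.rand(8) * 6.0 + 1.0  # observed lifetimes in rounds 1..7

out = head(features)
mean = out[:, 0]
sigma = torch.nn.functional.softplus(out[:, 1]) + 1e-6  # keep sigma > 0

# As the predicted probability of the observed outcome approaches zero, its
# negative log grows without bound, so this penalizes both a bad mean and an
# overconfident sigma.
d = torch.distributions.Normal(mean, sigma)
loss = -d.log_prob(result).mean()
loss.backward()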