Example #1
0
def train(args):
    """Train a tabular Q-learning agent on the environment described by args.

    Runs args.epochs epochs; each epoch optionally trains for args.train_steps
    steps and tests for args.test_eps episodes, recording statistics after each
    phase. On normal completion or Ctrl-C, the model is saved under
    args.job_name, statistics are plotted/closed, and a 10-episode rendered
    demo run is performed.
    """
    # Init gym env
    env = Environment(args)
    agent = Q_Agent(env, args)
    stats = Statistics(None, agent, env, args)
    # Load only when a non-blank path was given. The old check
    # `not args.load_model.isspace()` was wrong for '' ( ''.isspace() is
    # False ), which made an empty argument trigger a bogus loadModel('').
    if args.load_model and not args.load_model.isspace():
        agent.loadModel(args.load_model)

    def _shutdown():
        # Shared teardown for both the normal and interrupted exits:
        # persist the model, flush statistics, then show a rendered demo.
        agent.saveModel(args.job_name)
        stats.plot()
        stats.close()
        agent.test(10, render=True)

    # Train agent
    try:
        for epoch in range(args.epochs):
            print('Epoch #%d' % (epoch + 1))

            if args.train_steps > 0:
                print('Training for %d steps' % args.train_steps)
                agent.train(args.train_steps)
                stats.write(epoch + 1, 'train', tensorboard=False)

            if args.test_eps > 0:
                print('Testing for %d steps' % args.test_eps)
                agent.test(args.test_eps, render=args.render)
                stats.write(epoch + 1, 'test', tensorboard=False)

        _shutdown()
        print('Done')
    except KeyboardInterrupt:
        _shutdown()
        print('Caught keyboard interrupt, stopping run...')
Example #2
0
def test(args):
    """Evaluate a DQN agent inside a TensorFlow session.

    Builds the environment/agent/statistics, optionally restores a model,
    performs args.steps_pre_train random warm-up actions, then runs
    args.epochs epochs of rendered testing (args.test_eps episodes each).
    Enqueue threads and statistics are always shut down, including on Ctrl-C.
    """
    with tf.Session() as sess:
        # Init environment and agent
        env = Environment(args)
        agent = DQN_Agent(sess, env, args)
        stats = Statistics(sess, agent, env, args)
        if args.load_model:
            stats.loadModel()
        # Freeze the graph so any accidental op creation during the run fails fast.
        sess.graph.finalize()

        def _shutdown():
            # Stop the enqueue threads FIRST so no background thread can
            # write into an already-closed Statistics object. (The old
            # KeyboardInterrupt path closed stats before stopping threads.)
            agent.stopEnqueueThreads()
            stats.plot()
            stats.close()

        try:
            print('Taking %d random actions before training' % args.steps_pre_train)
            agent.randomExplore(args.steps_pre_train)

            agent.startEnqueueThreads()
            for epoch in range(args.epochs):
                print('Epoch #%d' % (epoch + 1))

                if args.test_eps > 0:
                    print('Testing for %d episodes' % args.test_eps)
                    agent.test(args.test_eps, render=True)
                    stats.write(epoch + 1, 'test')

            _shutdown()
            print('Done')
        except KeyboardInterrupt:
            _shutdown()
            print('Caught keyboard interrupt, stopping run...')
Example #3
0
def main():
    """Entry point: parse args, build env/net/agent, then run the requested mode.

    Modes (checked in order): two-player play, agent-only play, replay-memory
    pre-fill with random steps, then the standard epoch loop of training
    (with per-epoch weight checkpoints) and testing. Statistics are written
    after every phase and closed at the end.

    Raises:
        ValueError: if args.environment is not one of "ale", "gym", "robot".
    """
    # Process arguments
    args = utils.parse_args()

    # Seed the RNG only when an explicit seed was requested (reproducibility).
    if args.random_seed:
        random.seed(args.random_seed)

    # Instantiate environment class
    if args.environment == "ale":
        env = ALEEnvironment(args.game, args)
    elif args.environment == "gym":
        env = GymEnvironment(args.game, args)
    elif args.environment == "robot":
        env = RobotEnvironment(args.game, args)
    else:
        # `assert False` would be stripped under `python -O`, and the old
        # message lacked a separator ("Unknown environmentale"). Raise instead.
        raise ValueError("Unknown environment: " + args.environment)

    # Instantiate DQN
    action_dim = env.action_dim()
    state_dim = env.state_dim()
    net = DQN(state_dim, action_dim, args)

    # Load weights before starting training
    if args.load_weights:
        net.load(args.load_weights)

    # Instantiate agent
    agent = Agent(env, net, args)

    # Start statistics
    stats = Statistics(agent, agent.net, agent.net.memory, env, args)

    # Play game with two players (user and agent); exits when done.
    if args.two_player:
        player_b = PlayerTwo(args)
        env.set_mode('test')
        stats.reset()
        agent.play_two_players(player_b)
        stats.write(0, "2player")
        sys.exit()

    # Play agent only; exits when done.
    if args.play_games > 0:
        env.set_mode('test')
        stats.reset()
        for _ in range(args.play_games):
            agent.play()
        stats.write(0, "play")
        sys.exit()

    # Populate replay memory with random steps before training starts.
    if args.random_steps:
        env.set_mode('test')
        stats.reset()
        agent.play_random(args.random_steps)
        stats.write(0, "random")

    for epoch in range(args.start_epoch, args.epochs):
        # Train agent
        if args.train_steps:
            env.set_mode('train')
            stats.reset()
            agent.train(args.train_steps)
            stats.write(epoch + 1, "train")

            # Save weights after every epoch
            if args.save_weights_prefix:
                filepath = args.save_weights_prefix + "_%d.h5" % (epoch + 1)
                net.save(filepath)

        # Test agent
        if args.test_steps:
            env.set_mode('test')
            stats.reset()
            agent.test(args.test_steps)
            stats.write(epoch + 1, "test")

    # Stop statistics
    stats.close()
Example #4
0
    agent.play_random(args.random_steps, args)
    stats.write(0, "random")

# Loop over epochs: each epoch optionally trains for args.train_steps steps
# (checkpointing weights afterwards) and then tests for args.test_steps steps.
for epoch in range(args.start_epoch, args.epochs):
    logger.info("Epoch #%d" % (epoch + 1))

    if args.train_steps:
        logger.info(" Training for %d steps" % args.train_steps)
        # Set env mode to 'train' so that loss of life is considered as terminal
        # (the old comment said 'test' but the code sets 'train').
        env.setMode('train')
        stats.reset()
        agent.train(args.train_steps, epoch)
        stats.write(epoch + 1, "train")

        # Checkpoint network weights after every training phase.
        if args.save_weights_prefix:
            filename = args.save_weights_prefix + "_%d.prm" % (epoch + 1)
            logger.info("Saving weights to %s" % filename)
            net.save_weights(filename)

    if args.test_steps:
        logger.info(" Testing for %d steps" % args.test_steps)
        # Set env mode test so that loss of life is not considered as terminal
        env.setMode('test')
        stats.reset()
        agent.test(args.test_steps, epoch)
        stats.write(epoch + 1, "test")

stats.close()
logger.info("All done")
Example #5
0
# Optionally pre-fill the replay memory with random transitions so the
# agent has data to learn from on the very first epoch.
if args.random_steps:
    logger.info("Populating replay memory with %d random moves" % args.random_steps)
    stats.reset()
    agent.play_random(args.random_steps)
    stats.write(0, "random")

# One training phase and one testing phase per epoch.
for epoch in xrange(args.epochs):
    epoch_num = epoch + 1
    logger.info("Epoch #%d" % epoch_num)

    if args.train_steps:
        logger.info(" Training for %d steps" % args.train_steps)
        stats.reset()
        agent.train(args.train_steps, epoch)
        stats.write(epoch_num, "train")

        # Checkpoint the network weights after each training phase.
        if args.save_weights_prefix:
            weights_file = "%s_%d.prm" % (args.save_weights_prefix, epoch_num)
            logger.info("Saving weights to %s" % weights_file)
            net.save_weights(weights_file)

    if args.test_steps:
        logger.info(" Testing for %d steps" % args.test_steps)
        stats.reset()
        agent.test(args.test_steps, epoch)
        stats.write(epoch_num, "test")

stats.close()
logger.info("All done")