Example 1
def inference(args):
    """
    Restores only the LSTMPolicy architecture and runs inference with it.
    """
    # get address of checkpoints
    indir = os.path.join(args.log_dir, 'train')
    outdir = os.path.join(
        args.log_dir, 'inference') if args.out_dir is None else args.out_dir
    with open(indir + '/checkpoint', 'r') as f:
        first_line = f.readline().strip()
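    # TF checkpoint index format: model_checkpoint_path: "<dir>/model.ckpt-<step>"; extract <step> below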
    ckpt = first_line.split(' ')[-1].split('/')[-1][:-1]
    ckpt = ckpt.split('-')[-1]
    ckpt = indir + '/model.ckpt-' + ckpt

    # define environment
    if args.record:
        env = create_env(args.env_id,
                         client_id='0',
                         remotes=None,
                         envWrap=args.envWrap,
                         designHead=args.designHead,
                         record=True,
                         noop=args.noop,
                         acRepeat=args.acRepeat,
                         outdir=outdir)
    else:
        env = create_env(args.env_id,
                         client_id='0',
                         remotes=None,
                         envWrap=args.envWrap,
                         designHead=args.designHead,
                         record=False,
                         noop=args.noop,
                         acRepeat=args.acRepeat)
    numaction = env.action_space.n

    with tf.device("/cpu:0"):
        # define policy network
        with tf.variable_scope("global"):
            policy = LSTMPolicy(env.observation_space.shape, numaction,
                                args.designHead)
            policy.global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0, dtype=tf.int32),
                trainable=False)

        # Variable names that start with "local" are not saved in checkpoints.
        if use_tf12_api:
            variables_to_restore = [
                v for v in tf.global_variables()
                if not v.name.startswith("local")
            ]
            init_all_op = tf.global_variables_initializer()
        else:
            variables_to_restore = [
                v for v in tf.all_variables() if not v.name.startswith("local")
            ]
            init_all_op = tf.initialize_all_variables()
        saver = FastSaver(variables_to_restore)

        # print trainable variables
        var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     tf.get_variable_scope().name)
        logger.info('Trainable vars:')
        for v in var_list:
            logger.info('  %s %s', v.name, v.get_shape())

        # summary of rewards
        action_writers = []
        if use_tf12_api:
            summary_writer = tf.summary.FileWriter(outdir)
            for ac_id in range(numaction):
                action_writers.append(
                    tf.summary.FileWriter(
                        os.path.join(outdir, 'action_{}'.format(ac_id))))
        else:
            summary_writer = tf.train.SummaryWriter(outdir)
            for ac_id in range(numaction):
                action_writers.append(
                    tf.train.SummaryWriter(
                        os.path.join(outdir, 'action_{}'.format(ac_id))))
        logger.info("Inference events directory: %s", outdir)

        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        with tf.Session(config=config) as sess:
            logger.info("Initializing all parameters.")
            sess.run(init_all_op)
            logger.info("Restoring trainable global parameters.")
            saver.restore(sess, ckpt)
            logger.info("Restored model was trained for %.2fM global steps",
                        sess.run(policy.global_step) / 1000000.)
            # save the restored weights again, this time including the meta graph:
            metaSaver = tf.train.Saver(variables_to_restore)
            metaSaver.save(
                sess, '/home/swagking0/noreward-rl/models/models_me/mario_me')

            last_state = env.reset()
            if args.render or args.record:
                env.render()
            last_features = policy.get_initial_features()  # reset lstm memory
            length = 0
            rewards = 0
            mario_distances = np.zeros((args.num_episodes, ))
            for i in range(args.num_episodes):
                print("Starting episode %d" % (i + 1))
                if args.recordSignal:
                    from PIL import Image
                    signalCount = 1
                    utils.mkdir_p(outdir + '/recordedSignal/ep_%02d/' % i)
                    Image.fromarray(
                        (255 * last_state[..., -1]).astype('uint8')).save(
                            outdir + '/recordedSignal/ep_%02d/%06d.jpg' %
                            (i, signalCount))

                if args.random:
                    print('I am random policy!')
                elif args.greedy:
                    print('I am greedy policy!')
                else:
                    print('I am sampled policy!')
                while True:
                    # run policy
                    fetched = policy.act_inference(last_state, *last_features)
                    prob_action, action, value_, features = (
                        fetched[0], fetched[1], fetched[2], fetched[3:])
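                    # prob_action: softmax over actions; action: sampled one-hot action;
                    # value_: value estimate; features: updated LSTM state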

                    # choose the env action: random, greedy (argmax of probs), or sampled one-hot
                    if args.random:
                        stepAct = np.random.randint(0, numaction)  # random policy
                    elif args.greedy:
                        stepAct = prob_action.argmax()  # greedy policy
                    else:
                        stepAct = action.argmax()  # sampled policy
                    # print(stepAct, prob_action.argmax(), prob_action)
                    state, reward, terminal, info = env.step(stepAct)

                    # update stats
                    length += 1
                    rewards += reward
                    last_state = state
                    last_features = features
                    if args.render or args.record:
                        env.render()
                    if args.recordSignal:
                        signalCount += 1
                        Image.fromarray(
                            (255 * last_state[..., -1]).astype('uint8')).save(
                                outdir + '/recordedSignal/ep_%02d/%06d.jpg' %
                                (i, signalCount))

                    # store summary
                    summary = tf.Summary()
                    summary.value.add(tag='ep_{}/reward'.format(i),
                                      simple_value=reward)
                    summary.value.add(tag='ep_{}/netreward'.format(i),
                                      simple_value=rewards)
                    summary.value.add(tag='ep_{}/value'.format(i),
                                      simple_value=float(value_[0]))
                    if 'NoFrameskip-v' in args.env_id:  # atari
                        summary.value.add(
                            tag='ep_{}/lives'.format(i),
                            simple_value=env.unwrapped.ale.lives())
                    summary_writer.add_summary(summary, length)
                    summary_writer.flush()
                    summary = tf.Summary()
                    for ac_id in range(numaction):
                        summary.value.add(tag='action_prob',
                                          simple_value=float(
                                              prob_action[ac_id]))
                        action_writers[ac_id].add_summary(summary, length)
                        action_writers[ac_id].flush()

                    timestep_limit = env.spec.tags.get(
                        'wrapper_config.TimeLimit.max_episode_steps')
                    if timestep_limit is None:
                        timestep_limit = env.spec.timestep_limit
                    if terminal or length >= timestep_limit:
                        if length >= timestep_limit or not env.metadata.get(
                                'semantics.autoreset'):
                            last_state = env.reset()
                        last_features = policy.get_initial_features()  # reset lstm memory
                        print(
                            "Episode finished. Sum of rewards: %.2f. Length: %d."
                            % (rewards, length))
                        if 'distance' in info:
                            print('Mario Distance Covered:', info['distance'])
                            mario_distances[i] = info['distance']
                        length = 0
                        rewards = 0
                        if args.render or args.record:
                            env.render()
                        if args.recordSignal:
                            signalCount += 1
                            Image.fromarray(
                                (255 *
                                 last_state[..., -1]).astype('uint8')).save(
                                     outdir +
                                     '/recordedSignal/ep_%02d/%06d.jpg' %
                                     (i, signalCount))
                        break

        logger.info('Finished %d true episodes.', args.num_episodes)
        if 'distance' in info:
            print('Mario Distances:', mario_distances)
            np.save(outdir + '/distances.npy', mario_distances)
        env.close()
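For reference, a minimal command-line driver that feeds these flags into `inference` could look like the sketch below. It is not part of the original project: the flag names simply mirror the `args` attributes the function reads, and the defaults, the `--noop`/`--acRepeat` types, and the wrapper itself are assumptions for illustration.

import argparse

def _example_main():
    # Hypothetical CLI wrapper; flag names mirror the `args` attributes used by inference() above.
    parser = argparse.ArgumentParser(description="Run inference with a trained LSTMPolicy (sketch)")
    parser.add_argument('--log-dir', dest='log_dir', default='tmp/model')  # expects <log_dir>/train/checkpoint
    parser.add_argument('--out-dir', dest='out_dir', default=None)         # defaults to <log_dir>/inference
    parser.add_argument('--env-id', dest='env_id', default='SuperMarioBros-1-1-v0')  # assumed default
    parser.add_argument('--num-episodes', dest='num_episodes', type=int, default=2)
    parser.add_argument('--record', action='store_true')
    parser.add_argument('--render', action='store_true')
    parser.add_argument('--recordSignal', action='store_true')
    parser.add_argument('--random', action='store_true')
    parser.add_argument('--greedy', action='store_true')
    parser.add_argument('--envWrap', action='store_true')
    parser.add_argument('--designHead', default='universe')                # assumed default
    parser.add_argument('--noop', action='store_true')                     # assumed flag type
    parser.add_argument('--acRepeat', type=int, default=0)                 # assumed flag type
    inference(parser.parse_args())

if __name__ == "__main__":
    _example_main()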
Example 2
def inference(args):
    indir = os.path.join(args.log_dir, 'train')
    outdir = os.path.join(args.log_dir, 'player') if args.out_dir is None else args.out_dir

    with open(indir + "/checkpoint", "r") as f:
        first_line = f.readline().strip()
        print ("first_line is : {}".format(first_line))
    ckpt = first_line.split(' ')[-1].split('/')[-1][:-1]
    ckpt = ckpt.split('-')[-1]
    ckpt = indir + '/model.ckpt-' + ckpt

    print ("ckpt: {}".format(ckpt))

    # define environment
    env = create_icegame_env(outdir, args.env_id)
    num_actions = env.action_space.n

    with tf.device("/cpu:0"):
        # define policy network
        with tf.variable_scope("global"):
            policy = LSTMPolicy(env.observation_space.shape, num_actions)
            policy.global_step = tf.get_variable("global_step", [], 
                    tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32), trainable=False)
        # Variable names that start with "local" are not saved in checkpoints.
        variables_to_restore = [v for v in tf.global_variables() if not v.name.startswith("local")]
        init_all_op = tf.global_variables_initializer()

        saver = FastSaver(variables_to_restore)

        # print trainable variables
        var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
        logger.info('Trainable vars:')
        for v in var_list:
            logger.info('  {} {}'.format(v.name, v.get_shape()))
        logger.info("Restored the trained model.")

        # summary of rewards
        action_writers = []
        summary_writer = tf.summary.FileWriter(outdir)
        for act_idx in range(num_actions):
            action_writers.append(tf.summary.FileWriter(
                os.path.join(outdir, "action_{}".format(act_idx))
            ))

        logger.info("Inference events directory: %s", outdir)
        config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)

        with tf.Session(config=config) as sess:
            logger.info("Initializing all parameters.")
            sess.run(init_all_op)
            logger.info("Restoring trainable global parameters.")
            saver.restore(sess, ckpt)
            logger.info("Restored model was trained for %.2fM global steps", sess.run(policy.global_step)/1000000.)

            last_features = policy.get_initial_features()  # reset lstm memory
            length = 0
            rewards = 0

            # For plotting
            plt.ion()
            fig = plt.figure(num=None, figsize=(8, 8), dpi=92, facecolor='w', edgecolor='k')

            gs1 = gridspec.GridSpec(3, 3)
            gs1.update(left=0.05, right=0.85, wspace=0.15)
            ax1 = plt.subplot(gs1[:-1, :])
            ax2 = plt.subplot(gs1[-1, :-1])
            ax3 = plt.subplot(gs1[-1, -1])

            ax1.set_title("IceGame (Agent Lives: {}, UpTimes: {})".format(env.lives, env.sim.get_updated_counter()))

            ind = np.arange(num_actions)
            width = 0.20
            #action_legends = ["Up", "Down", "Left", "Right", "NextUp", "NextDown", "Metropolis"]
            action_legends = [">", "v", "<", "^", "", "", "Metro"]

            for ep in range(args.num_episodes):
                """TODO: policy sampling strategy
                    random, greedy and sampled policy.
                """

                last_state = env.reset()
                steps_rewards = []
                steps_values = []

                # running policy
                while True:
                    fetched = policy.act_inference(last_state, *last_features)
                    prob_action, action, value_, features = fetched[0], fetched[1], fetched[2], fetched[3:]

                    #TODO: policy sampling strategy

                    # take the sampled action (for a greedy policy, use prob_action.argmax() instead)
                    #print("Prob of actions: {}".format(prob_action))
                    stepAct = action.argmax()
                    state, reward, terminal, info = env.step(stepAct)

                    # update stats
                    length += 1
                    rewards += reward
                    last_state = state
                    last_features = features
                    steps_rewards.append(rewards)
                    steps_values.append(value_)

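                    # the IceGame env reports loop statistics via the info dict ("Loop Size", "Loop Area")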
                    if info:
                        loopsize = info["Loop Size"]
                        looparea = info["Loop Area"]

                    """Animation for State and Actions
                    """
                    ax2.clear()
                    ax2.bar(ind, prob_action)
                    ax2.set_xticks(ind + width / 2)
                    ax2.set_xticklabels(action_legends)

                    ax1.imshow(state[:,:,2], 'Reds', interpolation="None",  vmin=-1, vmax=1)
                    # with hist
                    #ax1.imshow(state[:,:,7], 'Reds', interpolation="None",  vmin=-1, vmax=1)
                    ax1.set_title("IceGame: (Agent Lives: {}, UpTimes: {})".format(env.lives, env.sim.get_updated_counter()))

                    ax3.clear()
                    ax3.plot(steps_rewards, linewidth=2)
                    ax3.plot(steps_values, linewidth=2)
                    #plt.savefig("records/{}.png".format(length))

                    plt.pause(0.20)

                    # store summary
                    summary = tf.Summary()
                    summary.value.add(tag='ep_{}/reward'.format(ep), simple_value=reward)
                    summary.value.add(tag='ep_{}/netreward'.format(ep), simple_value=rewards)
                    summary.value.add(tag='ep_{}/value'.format(ep), simple_value=float(value_[0]))

                    if info:
                        summary.value.add(tag='ep_{}/loop_size'.format(ep), simple_value=loopsize)
                        summary.value.add(tag='ep_{}/loop_area'.format(ep), simple_value=looparea)

                    summary_writer.add_summary(summary, length)
                    summary_writer.flush()

                    summary = tf.Summary()
                    for ac_id in range(num_actions):
                        summary.value.add(tag='ep_{}/a_{}'.format(ep, ac_id), simple_value=float(prob_action[ac_id]))
                        action_writers[ac_id].add_summary(summary, length)
                        action_writers[ac_id].flush()

                    """TODO:
                        1. Need more concrete idea for playing the game when interfering.
                        2. Save these values for post processing.
                    """
                    if terminal:
                        #if length >= timestep_limit:
                        #    last_state, _, _, _ = env.reset()

                        last_features = policy.get_initial_features()  # reset lstm memory
                        print("Episode finished. Sum of rewards: %.2f. Length: %d." % (rewards, length))

                        length = 0
                        rewards = 0
                        break

        logger.info('Finished %d true episodes.', args.num_episodes)
        plt.savefig("GameScene.png")
        logger.info("Save the last scene to GameScene.png")
        env.close()
Example 3
def inference(args):
    indir = os.path.join(args.logdir, 'train')
    outdir = os.path.join(args.logdir,
                          'player') if args.outdir is None else args.outdir
    if not os.path.exists(outdir):
        os.makedirs(outdir)

    with open(indir + "/checkpoint", "r") as f:
        first_line = f.readline().strip()
        print("first_line is : {}".format(first_line))
    ckpt = first_line.split(' ')[-1].split('/')[-1][:-1]
    ckpt = ckpt.split('-')[-1]
    ckpt = indir + '/model.ckpt-' + ckpt

    print("ckpt: {}".format(ckpt))

    # define environment
    #env = create_icegame_env(outdir, args.env_id, args)
    env = create_icegame_env(outdir, args.env_id)
    # observation and action spaces
    local_space = env.local_observation_space.n
    global_space = env.global_observation_space.shape
    action_space = env.action_space.n

    # resize the system and enable subregion
    #if env.L != args.system_size:
    #    print ("Enlarge the system {} --> {}".format(env.L, args.system_size))
    #    env.resize_ice_config(args.system_size, args.mcsteps)
    #    env.dump_env_setting()
    #    env.save_ice()

    # our trained CNN always expects 32x32 inputs
    env.enable_subregion()
    print("Enabled sub-region mechanism.")

    # policy recoder
    ppath = os.path.join(outdir, "episodes")
    if not os.path.exists(ppath):
        os.makedirs(ppath)
    pirec = PolicyRecorder(ppath)

    with tf.device("/cpu:0"):
        # define policy network
        with tf.variable_scope("global"):
            if args.policy == "simple":
                policy = models.SimplePolicy(global_space, local_space,
                                             action_space, args)
            elif args.policy == "cnn":
                policy = models.CNNPolicy(global_space, local_space,
                                          action_space, args)
            else:
                raise ValueError("Unknown policy type: {}".format(args.policy))
            policy.global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0, dtype=tf.int32),
                trainable=False)
        # Variable names that start with "local" are not saved in checkpoints.
        variables_to_restore = [
            v for v in tf.global_variables() if not v.name.startswith("local")
        ]
        init_all_op = tf.global_variables_initializer()

        saver = FastSaver(variables_to_restore)

        # print trainable variables
        var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     tf.get_variable_scope().name)
        logger.info('Trainable vars:')
        for v in var_list:
            logger.info('  {} {}'.format(v.name, v.get_shape()))
        logger.info("Restored the trained model.")

        # summary of rewards
        action_writers = []
        summary_writer = tf.summary.FileWriter(outdir)
        """NOT so useful.
        for act_idx in range(action_space):
            action_writers.append(tf.summary.FileWriter(
                os.path.join(outdir, "action_{}".format(act_idx))
            ))
        """

        logger.info("Inference events directory: %s", outdir)
        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)

        with tf.Session(config=config) as sess:
            logger.info("Initializing all parameters.")
            sess.run(init_all_op)
            logger.info("Restoring trainable global parameters.")
            saver.restore(sess, ckpt)
            logger.info("Restored model was trained for %.2fM global steps",
                        sess.run(policy.global_step) / 1000000.)

            #last_features = policy.get_initial_features()  # reset lstm memory
            length = 0
            rewards = 0

            # For plotting
            if args.render:
                import matplotlib.pyplot as plt
                import matplotlib.gridspec as gridspec

                plt.ion()
                fig = plt.figure(num=None,
                                 figsize=(8, 8),
                                 dpi=92,
                                 facecolor='w',
                                 edgecolor='k')

                gs1 = gridspec.GridSpec(3, 3)
                gs1.update(left=0.05, right=0.85, wspace=0.15)
                ax1 = plt.subplot(gs1[:-1, :])
                ax2 = plt.subplot(gs1[-1, :-1])
                ax3 = plt.subplot(gs1[-1, -1])

                ax1.set_title("IceGame (UpTimes: {})".format(
                    env.sim.get_updated_counter()))

                ind = np.arange(action_space)
                width = 0.20
                action_legends = [
                    "head_0", "head_1", "head_2", "tail_0", "tail_1", "tail_2",
                    "Metro"
                ]

                steps_energies = []

            for ep in range(args.num_tests):
                """TODO: policy sampling strategy
                    random, greedy and sampled policy.
                """

                env.start(create_defect=True)
                last_state = env.reset()
                # these for plotting
                steps_rewards = []
                steps_values = []
                step = 0

                # policy recorder
                pirec.attach_episode(ep)
                # TODO: Call save_ice here?

                # running policy
                while True:
                    fetched = policy.act_inference(last_state)
                    prob_action, action, value_ = fetched[0], fetched[1], fetched[2]
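                    # prob_action: action distribution; action: one-hot action (argmax taken below);
                    # value_: value estimate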
                    """TODO: Policy Recorder
                        * prob_action
                        * value_
                        * local config
                        * init_config (of course, but store in other way.)
                        * Store all cases
                        Q: Can we put these in env_hist.json?
                    """

                    stepAct = action.argmax()
                    state, reward, terminal, info = env.step(stepAct)
                    local = last_state.local_obs.tolist()
                    pi_ = prob_action.tolist()
                    value_ = value_.tolist()[0]
                    action_ = action.tolist()

                    # TODO: we need env 'weights', p(s, s', a) = ? (definition still unclear)
                    # And we also want some physical observables
                    pirec.push_step(step, stepAct, pi_, value_, local, reward)

                    # update stats
                    length += 1
                    step += 1
                    rewards += reward
                    last_state = state

                    if info:
                        loopsize = info["Loop Size"]
                        looparea = info["Loop Area"]
                    """Animation for State and Actions
                        Show Energy Bar On Screen.
                    """

                    if args.render:
                        # save list for plotting
                        steps_rewards.append(rewards)
                        steps_values.append(value_)

                        ax2.clear()
                        ax2.bar(ind, prob_action)
                        ax2.set_xticks(ind + width / 2)
                        ax2.set_xticklabels(action_legends)

                        canvas = state.global_obs[:, :, 0]
                        ax1.clear()
                        ax1.imshow(canvas,
                                   'Reds',
                                   interpolation="None",
                                   vmin=-1,
                                   vmax=1)
                        ax1.set_title("IceGame: (UpTimes: {})".format(
                            env.sim.get_updated_counter()))

                        ax3.clear()
                        ax3.plot(steps_energies, linewidth=2)

                        plt.pause(0.05)
                    """TODO:
                        1. Need more concrete idea for playing the game when interfering.
                        2. Save these values for post processing.
                        3. We need penalty for timeout. --> Move timeout into env.
                    """
                    if terminal:
                        print(
                            "Episode finished. Sum of rewards: %.2f. Length: %d."
                            % (rewards, length))
                        pirec.dump_episode()
                        length = 0
                        rewards = 0
                        step = 0
                        break

        logger.info('Finished %d true episodes.', args.num_tests)
        if args.render:
            plt.savefig("GameScene.png")
            logger.info("Saved the last scene to GameScene.png")
        env.close()
Example 4
def inference(args):
    indir = os.path.join(args.log_dir, 'train')
    outdir = os.path.join(
        args.log_dir, 'inference') if args.out_dir is None else args.out_dir

    with open(indir + "/checkpoint", "r") as f:
        first_line = f.readline().strip()
        print("first_line is : {}".format(first_line))
    ckpt = first_line.split(' ')[-1].split('/')[-1][:-1]
    ckpt = ckpt.split('-')[-1]
    ckpt = indir + '/model.ckpt-' + ckpt

    print("ckpt: {}".format(ckpt))

    # define environment
    env = create_icegame_env(outdir, args.env_id)
    num_actions = env.action_space.n

    with tf.device("/cpu:0"):
        # define policy network
        with tf.variable_scope("global"):
            policy = LSTMPolicy(env.observation_space.shape, num_actions)
            policy.global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0, dtype=tf.int32),
                trainable=False)
        # Variable names that start with "local" are not saved in checkpoints.
        variables_to_restore = [
            v for v in tf.global_variables() if not v.name.startswith("local")
        ]
        init_all_op = tf.global_variables_initializer()

        saver = FastSaver(variables_to_restore)

        # print trainable variables
        var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     tf.get_variable_scope().name)
        logger.info('Trainable vars:')
        for v in var_list:
            logger.info('  {} {}'.format(v.name, v.get_shape()))
        logger.info("Restored the trained model.")

        # summary of rewards
        action_writers = []
        summary_writer = tf.summary.FileWriter(outdir)
        for act_idx in range(num_actions):
            action_writers.append(
                tf.summary.FileWriter(
                    os.path.join(outdir, "action_{}".format(act_idx))))

        logger.info("Inference events directory: %s", outdir)
        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)

        with tf.Session(config=config) as sess:
            logger.info("Initializing all parameters.")
            sess.run(init_all_op)
            logger.info("Restoring trainable global parameters.")
            saver.restore(sess, ckpt)
            logger.info("Restored model was trained for %.2fM global steps",
                        sess.run(policy.global_step) / 1000000.)

            last_features = policy.get_initial_features()  # reset lstm memory
            length = 0
            rewards = 0
            loopsizes = []

            # All Episodes records
            for ep in range(args.num_episodes):
                """TODO: policy sampling strategy
                    random, greedy and sampled policy.
                """

                last_state = env.reset()

                # Episode records

                # running policy
                while True:
                    fetched = policy.act_inference(last_state, *last_features)
                    prob_action, action, value_, features = (
                        fetched[0], fetched[1], fetched[2], fetched[3:])

                    #TODO: policy sampling strategy

                    # take the sampled action (for a greedy policy, use prob_action.argmax() instead)
                    stepAct = action.argmax()
                    state, reward, terminal, info = env.step(stepAct)

                    # update stats
                    length += 1
                    rewards += reward
                    last_state = state
                    last_features = features
                    """TODO: Resonable Statistics are necessary
                    """

                    if info:
                        loopsize = info["Loop Size"]
                        looparea = info["Loop Area"]

                    # store summary
                    summary = tf.Summary()
                    summary.value.add(tag='ep_{}/reward'.format(ep),
                                      simple_value=reward)
                    summary.value.add(tag='ep_{}/netreward'.format(ep),
                                      simple_value=rewards)
                    summary.value.add(tag='ep_{}/value'.format(ep),
                                      simple_value=float(value_[0]))

                    if info:
                        summary.value.add(tag='ep_{}/loop_size'.format(ep),
                                          simple_value=loopsize)
                        summary.value.add(tag='ep_{}/loop_area'.format(ep),
                                          simple_value=looparea)
                        loopsizes.append(loopsize)

                    summary_writer.add_summary(summary, length)
                    summary_writer.flush()

                    summary = tf.Summary()
                    for ac_id in range(num_actions):
                        summary.value.add(tag='ep_{}/a_{}'.format(ep, ac_id),
                                          simple_value=float(
                                              prob_action[ac_id]))
                        action_writers[ac_id].add_summary(summary, length)
                        action_writers[ac_id].flush()
                    """TODO:
                        1. Need more concrete idea for playing the game when interfering.
                        2. Save these values for post processing.
                    """
                    if terminal:
                        #if length >= timestep_limit:
                        #    last_state, _, _, _ = env.reset()

                        last_features = policy.get_initial_features()  # reset lstm memory
                        print(
                            "Episode finished. Sum of rewards: %.2f. Length: %d."
                            % (rewards, length))

                        length = 0
                        rewards = 0
                        break

        logger.info('Finished %d true episodes.', args.num_episodes)

        # Count loop topology
        unique, counts = np.unique(loopsizes, return_counts=True)
        loopstatistics = dict(zip(unique, counts))
        print(loopstatistics)
        env.close()