Example #1
def train():
    seed = np.random.randint(1000)
    np.random.seed(seed)
    np.set_printoptions(precision=3, suppress=True)
    global Module, Target_module, lock, epoch, start_time
    epoch = 0
    if args.game_source == 'Gym':
        dataiter = rl_data.GymDataIter(
            args.game, args.resized_width, args.resized_height, args.agent_history_length)
    else:
        dataiter = rl_data.MultiThreadFlappyBirdIter(
            args.resized_width, args.resized_height, args.agent_history_length)

    act_dim = dataiter.act_dim
    Module = getNet(act_dim, is_train=True)
    Target_module = getNet(act_dim, is_train=False)
    lock = threading.Lock()

    start_time = time.time()
    actor_learner_threads = [threading.Thread(target=actor_learner_thread, args=(
        thread_id,)) for thread_id in range(args.num_threads)]
    for t in actor_learner_threads:
        t.start()

    for t in actor_learner_threads:
        t.join()
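
getNet() is used throughout these examples but never defined here. Below is a minimal sketch of what it might look like, assuming it builds, binds, and initializes a DQN module the same way setup() does in Example #8; the shapes and initializer are assumptions, not the original implementation.

def getNet(act_dim, is_train=True):
    # Hypothetical helper: build, bind and initialize a DQN module,
    # mirroring the setup() function shown in Example #8.
    devs = mx.cpu()
    net = sym.get_dqn_symbol(act_dim, ispredict=not is_train)
    data_names = ('data', 'rewardInput', 'actionInput') if is_train else ('data',)
    mod = mx.mod.Module(net, data_names=data_names, label_names=None, context=devs)
    data_shapes = [('data', (args.batch_size, args.agent_history_length,
                             args.resized_width, args.resized_height))]
    if is_train:
        data_shapes += [('rewardInput', (args.batch_size, 1)),
                        ('actionInput', (args.batch_size, act_dim))]
    mod.bind(data_shapes=data_shapes, label_shapes=None,
             grad_req='write' if is_train else 'null')
    mod.init_params(mx.init.Xavier(factor_type='in', magnitude=2.34))
    return mod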
Example #2
def test():
    if args.game_source == 'Gym':
        dataiter = rl_data.GymDataIter(
            args.game, args.resized_width, args.resized_height, args.agent_history_length)
    else:
        dataiter = rl_data.MultiThreadFlappyBirdIter(
            args.resized_width, args.resized_height, args.agent_history_length)
    act_dim = dataiter.act_dim
    module = getNet(act_dim, is_train=False)

    s_t = dataiter.get_initial_state()
    ep_reward = 0
    while True:
        batch = mx.io.DataBatch(data=[mx.nd.array([s_t])],
                                label=None)
        module.forward(batch, is_train=False)
        q_out = module.get_outputs()[0].asnumpy()
        action_index = np.argmax(q_out)
        a_t = np.zeros([act_dim])
        a_t[action_index] = 1
        s_t1, r_t, terminal, info = dataiter.act(action_index)
        ep_reward += r_t
        if terminal:
            print('reward', ep_reward)
            ep_reward = 0
            s_t1 = dataiter.get_initial_state()
        s_t = s_t1
Example #3
def getGame():
    if args.game_source == 'Gym':
        dataiter = rl_data.GymDataIter(args.game, args.resized_width,
                                       args.resized_height, args.agent_history_length)
    else:
        dataiter = rl_data.MultiThreadFlappyBirdIter(
            args.resized_width, args.resized_height, args.agent_history_length, visual=True)
    return dataiter
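
A possible way to use getGame(), assuming the data-iterator interface seen in the other examples (get_initial_state()/act()); the action index 0 is only a placeholder:

dataiter = getGame()
act_dim = dataiter.act_dim
s_t = dataiter.get_initial_state()
s_t1, r_t, terminal, info = dataiter.act(0)  # step the game with a placeholder action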
Example #4
def test():
    log_config()

    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')
    ]

    # module
    dataiter = rl_data.GymDataIter('scenes',
                                   args.batch_size,
                                   args.input_length,
                                   web_viz=True)
    print(dataiter.provide_data)
    net = sym.get_symbol_thor(dataiter.act_dim)
    module = mx.mod.Module(net,
                           data_names=[d[0] for d in dataiter.provide_data],
                           label_names=('policy_label', 'value_label'),
                           context=devs)
    module.bind(data_shapes=dataiter.provide_data,
                label_shapes=[('policy_label', (args.batch_size, )),
                              ('value_label', (args.batch_size, 1))],
                for_training=False)

    # load model
    assert args.load_epoch is not None
    assert args.model_prefix is not None
    module.load_params('%s-%04d.params' % (args.model_prefix, args.load_epoch))

    N = args.num_epochs * args.num_examples // args.batch_size

    R = 0
    T = 1e-20
    score = np.zeros((args.batch_size, ))
    for t in range(N):
        dataiter.clear_history()
        data = dataiter.next()
        module.forward(data, is_train=False)
        act = module.get_outputs()[0].asnumpy()
        act = [
            np.random.choice(dataiter.act_dim, p=act[i])
            for i in range(act.shape[0])
        ]
        dataiter.act(act)
        time.sleep(0.05)
        _, reward, _, done = dataiter.history[0]
        T += done.sum()
        score += reward
        R += (done * score).sum()
        score *= (1 - done)

        if t % 100 == 0:
            logging.info('n %d score: %f T: %f' % (t, R / T, T))
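
The score/R/T bookkeeping in the loop above keeps a running mean episode return over a batch of environments. The self-contained NumPy sketch below replays one step of the same update with illustrative numbers:

import numpy as np

score = np.zeros(4)                 # running return of each of 4 parallel envs
R, T = 0.0, 1e-20                   # sum of finished-episode returns, episode count
reward = np.array([1., 0., 2., 1.])
done = np.array([0., 1., 0., 1.])   # envs 1 and 3 just finished an episode

T += done.sum()                     # count the finished episodes
score += reward                     # accumulate this step's rewards
R += (done * score).sum()           # add the returns of the finished episodes
score *= (1 - done)                 # reset the finished envs
print('mean episode return so far:', R / T)  # 0.5 here: (0 + 1) / 2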
Example #5
def test():
    if args.game_source == 'Gym':
        dataiter = rl_data.GymDataIter(args.game, args.resized_width,
                                       args.resized_height,
                                       args.agent_history_length)
    else:
        dataiter = rl_data.MultiThreadFlappyBirdIter(args.resized_width,
                                                     args.resized_height,
                                                     args.agent_history_length)
    act_dim = dataiter.act_dim
    module = getNet(act_dim)

    module.bind(data_shapes=[
        ('data', (1, args.agent_history_length, args.resized_width,
                  args.resized_height)), ('rewardInput', (1, 1)),
        ('actionInput', (1, act_dim)), ('tdInput', (1, 1))
    ],
                label_shapes=None,
                grad_req='null',
                force_rebind=True)
    s_t = dataiter.get_initial_state()

    ep_reward = 0
    while True:
        # dummy placeholder inputs matching the bound batch size of 1
        null_r = np.zeros((1, 1))
        null_a = np.zeros((1, act_dim))
        null_td = np.zeros((1, 1))
        batch = mx.io.DataBatch(data=[
            mx.nd.array([s_t]),
            mx.nd.array(null_r),
            mx.nd.array(null_a),
            mx.nd.array(null_td)
        ],
                                label=None)
        module.forward(batch, is_train=False)
        policy_out, value_out, total_loss, loss_out, policy_out2 = module.get_outputs()
        probs = policy_out.asnumpy()[0]
        action_index = np.argmax(probs)
        a_t = np.zeros([act_dim])
        a_t[action_index] = 1
        s_t1, r_t, terminal, info = dataiter.act(action_index)
        ep_reward += r_t
        if terminal:
            print('reward', ep_reward)
            ep_reward = 0
            s_t1 = dataiter.get_initial_state()
        s_t = s_t1
Example #6
def train():
    # kvstore
    kv = mx.kvstore.create(args.kv_store)

    model_prefix = args.model_prefix
    if model_prefix is not None:
        model_prefix += "-%d" % (kv.rank)
    save_model_prefix = args.save_model_prefix
    if save_model_prefix is None:
        save_model_prefix = model_prefix

    log_config(args.log_dir, args.log_file, save_model_prefix, kv.rank)

    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')
    ]

    epoch_size = args.num_examples / args.batch_size

    if args.kv_store == 'dist_sync':
        epoch_size /= kv.num_workers

    # disable kvstore for single device
    if 'local' in kv.type and (args.gpus is None
                               or len(args.gpus.split(',')) == 1):
        kv = None

    # module
    dataiter = rl_data.GymDataIter('Breakout-v0',
                                   args.batch_size,
                                   args.input_length,
                                   web_viz=True)
    net = sym.get_symbol_atari(dataiter.act_dim)
    module = mx.mod.Module(net,
                           data_names=[d[0] for d in dataiter.provide_data],
                           label_names=('policy_label', 'value_label'),
                           context=devs)
    module.bind(data_shapes=dataiter.provide_data,
                label_shapes=[('policy_label', (args.batch_size, )),
                              ('value_label', (args.batch_size, 1))],
                grad_req='add')

    # load model

    if args.load_epoch is not None:
        assert model_prefix is not None
        _, arg_params, aux_params = mx.model.load_checkpoint(
            model_prefix, args.load_epoch)
    else:
        arg_params = aux_params = None

    # save model
    checkpoint = None if save_model_prefix is None else mx.callback.do_checkpoint(
        save_model_prefix)

    init = mx.init.Mixed(['fc_value_weight|fc_policy_weight', '.*'], [
        mx.init.Uniform(0.001),
        mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)
    ])
    module.init_params(initializer=init,
                       arg_params=arg_params,
                       aux_params=aux_params)

    # optimizer
    module.init_optimizer(kvstore=kv,
                          optimizer='adam',
                          optimizer_params={
                              'learning_rate': args.lr,
                              'wd': args.wd,
                              'epsilon': 1e-3
                          })

    # logging
    np.set_printoptions(precision=3, suppress=True)

    T = 0
    dataiter.reset()
    score = np.zeros((args.batch_size, 1))
    final_score = np.zeros((args.batch_size, 1))

    for epoch in range(args.num_epochs):
        epoch_time = time.time()
        if save_model_prefix:
            module.save_params('%s-%04d.params' % (save_model_prefix, epoch))

        print("IterNum in worker:" + str(epoch_size / args.t_max))
        print("Num Worker :" + str(kv.num_workers))
        tic_20 = time.time()
        for iter_w in range(int(epoch_size / args.t_max)):
            tic = time.time()
            # clear gradients
            for exe in module._exec_group.grad_arrays:
                for g in exe:
                    g[:] = 0

            S, A, V, r, D = [], [], [], [], []
            for t in range(args.t_max + 1):
                data = dataiter.data()
                module.forward(mx.io.DataBatch(data=data, label=None),
                               is_train=False)
                act, _, val = module.get_outputs()
                V.append(val.asnumpy())
                if t < args.t_max:
                    act = act.asnumpy()
                    act = [
                        np.random.choice(dataiter.act_dim, p=act[i])
                        for i in range(act.shape[0])
                    ]
                    reward, done = dataiter.act(act)
                    S.append(data)
                    A.append(act)
                    r.append(reward.reshape((-1, 1)))
                    D.append(done.reshape((-1, 1)))

            err = 0
            R = V[args.t_max]
            for i in reversed(range(args.t_max)):
                R = r[i] + args.gamma * (1 - D[i]) * R
                adv = np.tile(R - V[i], (1, dataiter.act_dim))

                batch = mx.io.DataBatch(
                    data=S[i], label=[mx.nd.array(A[i]),
                                      mx.nd.array(R)])
                module.forward(batch, is_train=True)

                pi = module.get_outputs()[1]
                h = -args.beta * (mx.nd.log(pi + 1e-7) * pi)
                out_acts = np.amax(pi.asnumpy(), 1)
                out_acts = np.reshape(out_acts, (-1, 1))
                out_acts_tile = np.tile(-np.log(out_acts + 1e-7),
                                        (1, dataiter.act_dim))
                module.backward([mx.nd.array(out_acts_tile * adv), h])

                #print('pi', pi[0].asnumpy())
                #print('h', h[0].asnumpy())
                err += (adv**2).mean()
                score += r[i]
                final_score *= (1 - D[i])
                final_score += score * D[i]
                score *= 1 - D[i]
                T += D[i].sum()

            module.update()

            if iter_w % 20 == 0:
                iter_time = time.time() - tic_20
                tic_20 = time.time()
                logging.info(
                    'fps: %f err: %f score: %f final: %f T: %f Epoch: %s Iter: %s 20Iter_Time: %s'
                    % (args.batch_size /
                       (time.time() - tic), err / args.t_max, score.mean(),
                       final_score.mean(), T, epoch, iter_w, iter_time))
            #print(score.squeeze())
            #print(final_score.squeeze())
        e_time = time.time() - epoch_time
        logging.info('Epoch_time : %s ' % (e_time))
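
The backward loop over t_max above computes n-step discounted returns bootstrapped from the critic value V[t_max]. The same computation as a standalone helper (the names are illustrative):

def n_step_returns(rewards, dones, bootstrap_value, gamma):
    # rewards, dones: lists of (batch_size, 1) arrays, one entry per step
    # bootstrap_value: V[t_max], the critic's value for the state after the last step
    R = bootstrap_value
    returns = [None] * len(rewards)
    for i in reversed(range(len(rewards))):
        # episodes that ended at step i (done == 1) do not propagate the bootstrap
        R = rewards[i] + gamma * (1 - dones[i]) * R
        returns[i] = R
    return returns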
Example #7
def actor_learner_thread(thread_id):
    global TMAX, T, Module, Target_module, lock, epoch

    if args.game_source == 'Gym':
        dataiter = rl_data.GymDataIter(args.game, args.resized_width,
                                       args.resized_height,
                                       args.agent_history_length)
    else:
        dataiter = rl_data.MultiThreadFlappyBirdIter(args.resized_width,
                                                     args.resized_height,
                                                     args.agent_history_length,
                                                     visual=True)
    act_dim = dataiter.act_dim

    # Set up per-episode counters
    ep_reward = 0
    ep_t = 0

    score = np.zeros((args.batch_size, 1))

    final_epsilon = sample_final_epsilon()
    initial_epsilon = 0.1
    epsilon = 0.1

    t = 0

    s_batch = []
    s1_batch = []
    a_batch = []
    r_batch = []
    R_batch = []
    terminal_batch = []

    # use a replay memory here to keep a fixed batch size for training
    replayMemory = []

    while T < TMAX:
        tic = time.time()
        epoch += 1
        terminal = False
        s_t = dataiter.get_initial_state()
        ep_reward = 0
        episode_max_q = 0
        ep_t = 0
        ep_loss = 0
        # perform an episode
        terminal = False
        episode_max_q = 0
        while True:
            # perform n steps
            t_start = t
            s_batch = []
            s1_batch = []
            a_batch = []
            r_batch = []
            R_batch = []
            while not (terminal or ((t - t_start) == args.t_max)):
                # TODO: this should forward through the Q-network, not the target
                #       network. However, handling variable-length input in MXNet
                #       requires a rebind, so the target network is used here instead.
                batch = mx.io.DataBatch(data=[mx.nd.array([s_t])], label=None)
                with lock:
                    Target_module.forward(batch, is_train=False)
                    q_out = Target_module.get_outputs()[0].asnumpy()

                # select action using e-greedy
                #print q_out
                action_index = action_select(act_dim, q_out, epsilon)
                #print q_out, action_index

                a_t = np.zeros([act_dim])
                a_t[action_index] = 1

                # scale down epsilon
                if epsilon > final_epsilon:
                    epsilon -= (initial_epsilon - final_epsilon) / \
                        args.anneal_epsilon_timesteps

                # play one step game
                s_t1, r_t, terminal, info = dataiter.act(action_index)
                r_t = np.clip(r_t, -1, 1)
                t += 1
                T += 1
                ep_t += 1
                ep_reward += r_t
                episode_max_q = max(episode_max_q, np.max(q_out))

                s_batch.append(s_t)
                s1_batch.append(s_t1)
                a_batch.append(a_t)
                r_batch.append(r_t)
                R_batch.append(r_t)
                terminal_batch.append(terminal)
                s_t = s_t1

            if terminal:
                R_t = 0
            else:
                batch = mx.io.DataBatch(data=[mx.nd.array([s_t1])], label=None)
                with lock:
                    Target_module.forward(batch, is_train=False)
                    R_t = np.max(Target_module.get_outputs()[0].asnumpy())

            for i in reversed(range(0, t - t_start)):
                R_t = r_batch[i] + args.gamma * R_t
                R_batch[i] = R_t

            if len(replayMemory) + len(s_batch) > args.replay_memory_length:
                replayMemory[0:(len(s_batch) + len(replayMemory)) -
                             args.replay_memory_length] = []
            for i in range(0, t - t_start):
                replayMemory.append(
                    (s_batch[i], a_batch[i], r_batch[i], s1_batch[i],
                     R_batch[i], terminal_batch[i]))

            if len(replayMemory) < args.batch_size:
                continue
            minibatch = random.sample(replayMemory, args.batch_size)
            state_batch = ([data[0] for data in minibatch])
            action_batch = ([data[1] for data in minibatch])
            R_batch = ([data[4] for data in minibatch])

            # estimated reward according to target network
            # print mx.nd.array(state_batch), mx.nd.array([R_batch]),
            # mx.nd.array(action_batch)
            batch = mx.io.DataBatch(data=[
                mx.nd.array(state_batch),
                mx.nd.array(np.reshape(R_batch, (-1, 1))),
                mx.nd.array(action_batch)
            ],
                                    label=None)

            with lock:
                Module.forward(batch, is_train=True)
                loss = np.mean(Module.get_outputs()[0].asnumpy())
                s = summary.scalar('loss', loss)
                summary_writer.add_summary(s, T)
                summary_writer.flush()
                Module.backward()
                Module.update()

            if t % args.network_update_frequency == 0 or terminal:
                with lock:
                    copyTargetQNetwork(Module, Target_module)

            if terminal:
                print "THREAD:", thread_id, "/ TIME", T, "/ TIMESTEP", t, "/ EPSILON", epsilon, "/ REWARD", ep_reward, "/ Q_MAX %.4f" % (
                    episode_max_q), "/ EPSILON PROGRESS", t / float(
                        args.anneal_epsilon_timesteps)
                s = summary.scalar('score', ep_reward)
                summary_writer.add_summary(s, T)
                summary_writer.flush()
                elapsed_time = time.time() - start_time
                steps_per_sec = T / elapsed_time
                print(
                    "### Performance : {} STEPS in {:.0f} sec. {:.0f} STEPS/sec. {:.2f}M STEPS/hour"
                    .format(T, elapsed_time, steps_per_sec,
                            steps_per_sec * 3600 / 1000000.))
                ep_reward = 0
                episode_max_q = 0
                break

        #print epoch
        if args.save_every != 0 and epoch % args.save_every == 0:
            save_params(args.save_model_prefix, Module, epoch)
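
action_select() and copyTargetQNetwork() are called above but not shown. Minimal sketches under common assumptions: epsilon-greedy selection for the former, and a parameter copy through Module.get_params()/set_params() for the latter.

def action_select(act_dim, q_out, epsilon):
    # Hypothetical epsilon-greedy policy: explore with probability epsilon,
    # otherwise take the greedy action under the current Q estimate.
    if random.random() < epsilon:
        return np.random.randint(act_dim)
    return int(np.argmax(q_out))


def copyTargetQNetwork(from_module, to_module):
    # Hypothetical target-network sync: copy the learner's parameters
    # into the target module.
    arg_params, aux_params = from_module.get_params()
    to_module.set_params(arg_params, aux_params)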
Example #8
def setup(isGlobal=False):
    '''
    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]
    '''

    #devs = mx.gpu(1)
    devs = mx.cpu()

    arg_params, aux_params = load_args()

    if args.game_source == 'Gym':
        dataiter = rl_data.GymDataIter(args.game, args.resized_width,
                                       args.resized_height,
                                       args.agent_history_length)
    else:
        dataiter = rl_data.MultiThreadFlappyBirdIter(args.resized_width,
                                                     args.resized_height,
                                                     args.agent_history_length,
                                                     visual=True)
    act_dim = dataiter.act_dim

    mod = mx.mod.Module(sym.get_dqn_symbol(act_dim, ispredict=False),
                        data_names=('data', 'rewardInput', 'actionInput'),
                        label_names=None,
                        context=devs)
    mod.bind(data_shapes=[('data', (args.batch_size, args.agent_history_length,
                                    args.resized_width, args.resized_height)),
                          ('rewardInput', (args.batch_size, 1)),
                          ('actionInput', (args.batch_size, act_dim))],
             label_shapes=None,
             grad_req='write')

    initializer = mx.init.Xavier(factor_type='in', magnitude=2.34)

    if args.load_epoch is not None:
        mod.init_params(arg_params=arg_params, aux_params=aux_params)
    else:
        mod.init_params(initializer)
    # optimizer
    mod.init_optimizer(optimizer='adam',
                       optimizer_params={
                           'learning_rate': args.lr,
                           'wd': args.wd,
                           'epsilon': 1e-3,
                           'clip_gradient': 10.0
                       })

    target_mod = mx.mod.Module(sym.get_dqn_symbol(act_dim, ispredict=True),
                               data_names=('data', ),
                               label_names=None,
                               context=devs)

    target_mod.bind(data_shapes=[
        ('data', (1, args.agent_history_length, args.resized_width,
                  args.resized_height)),
    ],
                    label_shapes=None,
                    grad_req='null')
    if args.load_epoch is not None:
        target_mod.init_params(arg_params=arg_params, aux_params=aux_params)
    else:
        target_mod.init_params(initializer)
    # optimizer
    target_mod.init_optimizer(optimizer='adam',
                              optimizer_params={
                                  'learning_rate': args.lr,
                                  'wd': args.wd,
                                  'epsilon': 1e-3,
                                  'clip_gradient': 10.0
                              })
    if not isGlobal:
        return mod, target_mod, dataiter
    return mod, target_mod
Example #9
def actor_learner_thread(thread_id):
    global TMAX, T, Module, Target_module, lock, epoch, start_time

    if args.game_source == 'Gym':
        dataiter = rl_data.GymDataIter(args.game, args.resized_width,
                                       args.resized_height, args.agent_history_length)
    else:
        dataiter = rl_data.MultiThreadFlappyBirdIter(args.resized_width,
                                                     args.resized_height, args.agent_history_length, visual=True)

    act_dim = dataiter.act_dim
    thread_net = getNet(act_dim, is_train=True)
    thread_net.bind(data_shapes=[('data', (1, args.agent_history_length,
                                           args.resized_width, args.resized_height)),
                                 ('rewardInput', (1, 1)),
                                 ('actionInput', (1, act_dim))],
                    label_shapes=None, grad_req='null', force_rebind=True)
    # Set up per-episode counters
    ep_reward = 0
    episode_max_q = 0
    final_epsilon = sample_final_epsilon()
    initial_epsilon = 0.1
    epsilon = 0.1
    t = 0

    # use a replay memory here to keep a fixed batch size for training
    replayMemory = []

    while T < TMAX:
        epoch += 1
        terminal = False
        s_t = dataiter.get_initial_state()
        ep_reward = 0

        while True:
            t_start = t
            s_batch = []
            s1_batch = []
            a_batch = []
            r_batch = []
            R_batch = []
            terminal_batch = []
            thread_net.bind(data_shapes=[('data', (1, args.agent_history_length,
                                                args.resized_width, args.resized_height)),
                                        ('rewardInput', (1, 1)),
                                        ('actionInput', (1, act_dim))],
                            label_shapes=None, grad_req='null', force_rebind=True)
            with lock:
                thread_net.copy_from_module(Module)
            #thread_net.clear_gradients()
            while not (terminal or ((t - t_start) == args.t_max)):
                batch = mx.io.DataBatch(data=[mx.nd.array([s_t]), mx.nd.array(np.zeros((1, 1))),
                                            mx.nd.array(np.zeros((1, act_dim)))],
                                        label=None)
                thread_net.forward(batch, is_train=False)
                q_out = thread_net.get_outputs()[1].asnumpy()

                # select action using e-greedy
                action_index = action_select(act_dim, q_out, epsilon)
                #print q_out, action_index

                a_t = np.zeros([act_dim])
                a_t[action_index] = 1

                # scale down epsilon
                if epsilon > final_epsilon:
                    epsilon -= (initial_epsilon - final_epsilon) / \
                        args.anneal_epsilon_timesteps

                # play one step game
                s_t1, r_t, terminal, info = dataiter.act(action_index)
                r_t = np.clip(r_t, -1, 1)
                t += 1
                with lock:
                    T += 1
                ep_reward += r_t
                episode_max_q = max(episode_max_q, np.max(q_out))
                s_batch.append(s_t)
                s1_batch.append(s_t1)
                a_batch.append(a_t)
                r_batch.append(r_t)
                R_batch.append(r_t)
                terminal_batch.append(terminal)
                s_t = s_t1

            if terminal:
                R_t = 0
            else:
                batch = mx.io.DataBatch(data=[mx.nd.array([s_t1])], label=None)
                with lock:
                    Target_module.forward(batch, is_train=False)
                    R_t = np.max(Target_module.get_outputs()[0].asnumpy())

            for i in reversed(range(0, t - t_start)):
                R_t = r_batch[i] + args.gamma * R_t
                R_batch[i] = R_t

            if len(replayMemory) + len(s_batch) > args.replay_memory_length:
                replayMemory[0:(len(s_batch) + len(replayMemory)) -
                            args.replay_memory_length] = []
            for i in range(0, t - t_start):
                replayMemory.append(
                    (s_batch[i], a_batch[i], r_batch[i], s1_batch[i],
                    R_batch[i],
                    terminal_batch[i]))

            if len(replayMemory) < args.batch_size:
                continue
            minibatch = random.sample(replayMemory, args.batch_size)
            state_batch = ([data[0] for data in minibatch])
            action_batch = ([data[1] for data in minibatch])
            R_batch = ([data[4] for data in minibatch])

            # TODO: the module can only be forwarded with one fixed batch size at a
            # time, because MXNet needs a rebind for variable-length input
            batch_size = len(minibatch)
            thread_net.bind(data_shapes=[('data', (batch_size, args.agent_history_length,
                                                args.resized_width, args.resized_height)),
                                        ('rewardInput', (batch_size, 1)),
                                        ('actionInput', (batch_size, act_dim))],
                            label_shapes=None, grad_req='write', force_rebind=True)

            batch = mx.io.DataBatch(data=[mx.nd.array(state_batch),
                                          mx.nd.array(np.reshape(
                                              R_batch, (-1, 1))),
                                          mx.nd.array(action_batch)], label=None)

            thread_net.clear_gradients()
            thread_net.forward(batch, is_train=True)
            loss = np.mean(thread_net.get_outputs()[0].asnumpy())
            thread_net.backward()

            s = summary.scalar('loss', loss)
            summary_writer.add_summary(s, T)
            summary_writer.flush()

            with lock:
                Module.clear_gradients()
                Module.add_gradients_from_module(thread_net)
                Module.update()
                Module.clear_gradients()

            #thread_net.update()
            thread_net.clear_gradients()

            if t % args.network_update_frequency == 0 or terminal:
                with lock:
                    Target_module.copy_from_module(Module)

            if terminal:
                print "THREAD:", thread_id, "/ TIME", T, "/ TIMESTEP", t, "/ EPSILON", epsilon, "/ REWARD", ep_reward, "/ Q_MAX %.4f" % episode_max_q, "/ EPSILON PROGRESS", t / float(args.anneal_epsilon_timesteps)
                s = summary.scalar('score', ep_reward)
                summary_writer.add_summary(s, T)
                summary_writer.flush()
                elapsed_time = time.time() - start_time
                steps_per_sec = T / elapsed_time
                print("### Performance : {} STEPS in {:.0f} sec. {:.0f} STEPS/sec. {:.2f}M STEPS/hour".format(
                    T,  elapsed_time, steps_per_sec, steps_per_sec * 3600 / 1000000.))
                ep_reward = 0
                episode_max_q = 0
                break

        if args.save_every != 0 and epoch % args.save_every == 0:
            save_params(args.save_model_prefix, Module, epoch)
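
sample_final_epsilon() and save_params() are likewise referenced but not defined in these examples. Plausible sketches follow; the epsilon values mirror the usual async-DQN recipe and are an assumption, not taken from this code.

def sample_final_epsilon():
    # Assumed async-DQN recipe: draw the final exploration rate
    # from {0.1, 0.01, 0.5} with probabilities {0.4, 0.3, 0.3}.
    final_epsilons = np.array([0.1, 0.01, 0.5])
    probabilities = np.array([0.4, 0.3, 0.3])
    return float(np.random.choice(final_epsilons, p=probabilities))


def save_params(prefix, module, epoch):
    # Hypothetical checkpoint helper built on Module.save_params.
    module.save_params('%s-%04d.params' % (prefix, epoch))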