start_last_save = 0
        # make new directory for this run in the case that there is already a
        # project with this name
        run_num = 0
        model_base_filedir = os.path.join(config.model_savedir,
                                          info['NAME'] + '%02d' % run_num)
        while os.path.exists(model_base_filedir):
            run_num += 1
            model_base_filedir = os.path.join(config.model_savedir,
                                              info['NAME'] + '%02d' % run_num)
        os.makedirs(model_base_filedir)
        print("----------------------------------------------")
        print("starting NEW project: %s" % model_base_filedir)

    model_base_filepath = os.path.join(model_base_filedir, info['NAME'])
    write_info_file(info, model_base_filepath, start_step_number)
    heads = list(range(info['N_ENSEMBLE']))
    seed_everything(info["SEED"])

    info['model_base_filepath'] = model_base_filepath
    info['num_actions'] = env.num_actions
    info['action_space'] = range(info['num_actions'])
    vqenv = VQEnv(info,
                  vq_model_loadpath=info['VQ_MODEL_LOADPATH'],
                  device='cpu')

    policy_net = EnsembleNet(n_ensemble=info['N_ENSEMBLE'],
                             n_actions=env.num_actions,
                             reshape_size=info['RESHAPE_SIZE'],
                             num_channels=info['HISTORY_SIZE'],
                             dueling=info['DUELING'],
Esempio n. 2
0
    if args.model_loadpath is not '':
        # what about random states - they will be wrong now???
        # TODO - what about target net update cnt
        target_net.load_state_dict(model_dict['target_net_state_dict'])
        policy_net.load_state_dict(model_dict['policy_net_state_dict'])
        opt.load_state_dict(model_dict['optimizer'])
        print("loaded model state_dicts")
        # TODO cant load buffer yet
        if args.buffer_loadpath == '':
            args.buffer_loadpath = args.model_loadpath.replace(
                '.pkl', '_train_buffer.pkl')
            print("auto loading buffer from:%s" % args.buffer_loadpath)
            rbuffer.load(args.buffer_loadpath)
    info['args'] = args
    write_info_file(info, model_base_filepath, total_steps)
    random_state = np.random.RandomState(info["SEED"])

    board_logger = TensorBoardLogger(model_base_filedir)
    last_target_update = 0
    print("Starting training")
    all_rewards = []

    epsilon_by_frame = lambda frame_idx: info['EPSILON_MIN'] + (info[
        'EPSILON_MAX'] - info['EPSILON_MIN']) * math.exp(-1. * frame_idx /
                                                         info['EPSILON_DECAY'])
    for epoch_num in range(epoch_start, info['N_EPOCHS']):
        ep_reward, total_steps, etime = run_training_episode(
            epoch_num, total_steps)
        all_rewards.append(ep_reward)
        overall_time += etime