示例#1
0
文件: tools.py 项目: koenboeckx/VKHO
def test_replay(model_file, agent_type='qmix', period=None):
    """Load a trained blue-team model and visualize one short episode.

    Builds the blue team (QMIX or REINFORCE agents) from a saved
    ``state_dict``, pits it against ``Agent`` red players in a
    ``RestrictedEnvironment`` and keeps generating episodes until one is
    shorter than 6 steps; that episode is then visualized.

    Args:
        model_file: file name of the saved ``state_dict``; it is appended
            to ``'/home/koen' + args.path``.
        agent_type: ``'qmix'`` or ``'reinforce'``.
        period: forwarded unchanged to ``visualize``.

    Raises:
        ValueError: if ``agent_type`` is not one of the supported values.
    """
    # Fail fast instead of falling through both branches below and crashing
    # later with an UnboundLocalError on ``team_blue``.
    if agent_type not in ('qmix', 'reinforce'):
        raise ValueError(
            "agent_type must be 'qmix' or 'reinforce', got {!r}".format(agent_type))

    import yaml
    from utilities import get_args
    # yaml.load without an explicit Loader is deprecated (an error in
    # PyYAML >= 6); the config is plain data, so safe_load is sufficient.
    # ``with`` guarantees the file handle is closed.
    with open('default_config.yaml', 'r') as config_file:
        args = get_args(yaml.safe_load(config_file))
    path = '/home/koen' + args.path

    # Replay-specific overrides of the values loaded from the config.
    args.gamma = 0.8
    args.max_episode_length = 30
    args.step_penalty = 0.05
    args.a_terrain = True

    # Every network below is initialised from the same checkpoint, so read
    # it from disk only once (load_state_dict copies the tensors in).
    state_dict = torch.load(path + model_file)

    if agent_type == 'qmix':
        model = QMixModel(input_shape=args.n_inputs,
                          n_actions=args.n_actions,
                          args=args)
        target = QMixModel(input_shape=args.n_inputs,
                           n_actions=args.n_actions,
                           args=args)
        model.load_state_dict(state_dict)
        target.load_state_dict(state_dict)
        models = {"model": model, "target": target}
        team_blue = [
            QMIXAgent(idx, "blue", args) for idx in range(args.n_friends)
        ]
    else:  # agent_type == 'reinforce'
        models = RNNModel(input_shape=args.n_inputs,
                          n_actions=args.n_actions,
                          args=args)
        models.load_state_dict(state_dict)
        team_blue = [
            PGAgent(idx, "blue", args) for idx in range(args.n_friends)
        ]

    for agent in team_blue:
        agent.set_model(models)
    # Red agents are numbered after the blue ones.
    team_red = [
        Agent(args.n_friends + idx, "red") for idx in range(args.n_enemies)
    ]
    agents = team_blue + team_red
    env = RestrictedEnvironment(agents, args)
    # NOTE: there is no upper bound on attempts; this loops until some
    # episode is shorter than 6 steps.
    while True:
        episode = generate_episode(env, args)
        print(len(episode))
        if len(episode) < 6:
            visualize(env, episode, period=period)
            break
示例#2
0
import gc
import os
import sys
from shutil import copyfile

from keras.models import load_model
from keras.utils import to_categorical

from alagent import ALAgent
from tagger import CRFTagger
import utilities
import tensorflow as tf
import numpy as np
import time

# Command-line options and the shared logger, created once at import time.
args = utilities.get_args()
logger = utilities.init_logger()

# Run hyper-parameters pulled from the parsed options.
max_len, VOCABULARY, EPISODES, BUDGET, k = (
    args.max_seq_length,
    args.vocab_size,
    args.episodes,
    args.annotation_budget,
    args.k,
)

# Data locations and dataset/policy identifiers.
(rootdir, train_file, dev_file, test_file, emb_file,
 DATASET_NAME, policy_path) = (
    args.root_dir,
    args.train_file,
    args.dev_file,
    args.test_file,
    args.word_vec_file,
    args.dataset_name,
    args.policy_path,
)
示例#3
0
def run(_config):
    """Turn ``_config`` into run arguments and train QMIX agents iteratively."""
    train_iteratively(get_args(_config), agent_type='qmix')
示例#4
0
def run(_config):
    """Turn ``_config`` into run arguments, publish them, then train.

    The parsed arguments are also bound to the module-level ``args``
    (presumably read elsewhere in this module — confirm).
    """
    global args
    parsed = get_args(_config)
    args = parsed
    train(parsed)
        log_file.close()


def main(args):
    """Parse every raw data file listed in the config into per-entry CSVs.

    Reads the JSON config named by ``args.config`` (with task
    ``utilities.Task.parse``), creates
    ``<config['dir']['data']>/<config['name']>`` and, for each entry of
    ``config['data']``, calls ``parse`` to convert ``entry['file']`` into
    ``<output_dir>/<entry['name']>.csv``.

    Args:
        args: parsed command-line arguments; only ``args.config`` is read.
    """
    print('Reading config')
    config = utilities.read_json_config(args.config, utilities.Task.parse)
    print('Starting parsing...')
    output_dir = '{}/{}'.format(config['dir']['data'], config['name'])
    print('Creating data directory {}'.format(output_dir))
    utilities.create_dir(output_dir)
    print('Reading raw data...')
    # Entries are parsed sequentially, one at a time.
    pbar = tqdm(config['data'])
    for entry in pbar:
        infile = entry['file']
        outfile = '{}/{}.csv'.format(output_dir, entry['name'])
        pbar.set_description('Processing raw_data in={} out={}'.format(
            infile, outfile))
        # 'old_format' is optional per entry and defaults to False.
        parse(infile, outfile, entry.get('old_format', False))


if __name__ == '__main__':
    import time
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # if the system clock is adjusted while the job runs.
    start_time = time.perf_counter()
    main(utilities.get_args())
    print("--- %s seconds ---" % (time.perf_counter() - start_time))
    # Threads: --- 71.97676062583923 seconds ---
    # No Thread : --- 64.11785078048706 seconds ---
示例#6
0
    print('Preparing other output dirs')
    cdf_dir = '{}/cdf'.format(output_dir)
    gnuplot_dir = '{}/gnuplot'.format(output_dir)
    plot_dir = '{}/plot'.format(output_dir)
    model_dir = '{}/model'.format(output_dir)

    utilities.create_dir(cdf_dir)
    utilities.create_dir(gnuplot_dir)
    utilities.create_dir(plot_dir)
    utilities.create_dir(model_dir)

    print('Generate diff and plots...')

    pbar = tqdm(predictors)
    for predictor in pbar:
        pbar.set_description('Generate diffs for {}'.format(predictor))
        diff = generate_diff(config, predictors, predictor, dataset)
        pbar.set_description('Saving diffs for {}'.format(predictor))
        sorted_indexes = save_diff(config, cdf_dir, predictor, diff)
        pbar.set_description('Creating plot for {}'.format(predictor))
        save_plot(config, cdf_dir, gnuplot_dir, plot_dir, predictor, diff, sorted_indexes)
        pbar.set_description('Saving model for {}'.format(predictor))
        utilities.save('{}/{}.joblib'.format(model_dir, predictor), predictors[predictor])

    
if __name__ == '__main__':
    import time
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # if the system clock is adjusted while the job runs.
    start_time = time.perf_counter()
    main(utilities.get_args(True))
    print("--- %s seconds ---" % (time.perf_counter() - start_time))
示例#7
0
    def package(self, data):
        """Return ``data`` prefixed with its length.

        The length is packed as a native-size, native-order C ``int``
        (``struct`` format ``'i'``), presumably so the peer can split the
        byte stream back into messages.
        """
        header = struct.pack('i', len(data))
        return header + data


class ClientTransfer(Protocol):
    """Client-side protocol that relays received bytes to a paired server protocol.

    ``set_protocol`` must be called before any data arrives, because
    ``dataReceived`` writes through ``self.server.transport``.
    """

    def __init__(self):
        """Create an unpaired relay; ``server`` is attached by ``set_protocol``."""

    def set_protocol(self, p):
        """Pair this client with server protocol ``p`` and resume its transport."""
        self.server = p
        # Presumably the server transport was paused until this side was ready.
        p.transport.resumeProducing()

    def dataReceived(self, data):
        """Forward ``data`` unchanged to the paired server's transport."""
        self.server.transport.write(data)


if __name__ == '__main__':
    kwargs = get_args(sys.argv[1:])
    # Defaults: all IPv4 interfaces, and port 0 (the OS picks a free port).
    host = kwargs['host'] if kwargs['host'] is not None else '0.0.0.0'
    port = kwargs['port'] if kwargs['port'] is not None else 0
    factory = Factory()
    factory.protocol = Transfer
    # Honour the requested bind address; without ``interface`` Twisted
    # listens on every interface regardless of ``host``.
    listening_port = reactor.listenTCP(port, factory, interface=host)
    # Report the port actually bound (differs from ``port`` when it is 0).
    log.dev_info(
        "Dispatcher started, waiting for connection on port {}".format(
            listening_port.getHost().port))
    reactor.run()
示例#8
0
from qmix import train

if __name__ == '__main__':
    import yaml
    from utilities import get_args
    # yaml.load without an explicit Loader is deprecated (an error in
    # PyYAML >= 6); the config is plain data, so safe_load is sufficient.
    # ``with`` guarantees the file handle is closed.
    with open('default_config.yaml', 'r') as config_file:
        args = get_args(yaml.safe_load(config_file))
    # Override the configured number of training steps for this run.
    args.n_steps = 100
    train(args)
示例#9
0
from utilities import get_args, create_directories
from rrd import create, update
from traceroute import Traceroute
import pathlib
import argparse as ap

parser = ap.ArgumentParser()
parser.add_argument('target')
# Previously hard-coded; now overridable while keeping the same default.
parser.add_argument(
    '--directory', default='/var/www/html/django',
    help='base directory passed to get_args (default: %(default)s)')
cli_args = parser.parse_args()

directory = cli_args.directory
target = cli_args.target

# Settings for this target; a different object from the CLI namespace above.
args = get_args(directory, target)
rrd_dir, graph_dir = create_directories(args, target)
tr = Traceroute(target, rrd_dir, graph_dir)
exists = pathlib.Path(tr.rrd).exists()

# Create the RRD file only if it does not exist yet, then update it.
if not exists:
    create(tr, args)
update(tr)