Example #1
def main():
    params = Params()

    mp.set_start_method('spawn')
    lock = mp.Lock()

    actions = mp.Array('i', [-1] * params.n_process, lock=lock)
    count = mp.Value('i', 0)
    best_acc = mp.Value('d', 0.0)

    state_Queue = mp.JoinableQueue()
    action_done = mp.SimpleQueue()
    reward_Queue = mp.JoinableQueue()

    # shared_model = A3C_LSTM_GA()
    # shared_model = shared_model.share_memory()
    #
    # shared_optimizer = SharedAdam(shared_model.parameters(), lr=params.lr, amsgrad=params.amsgrad, weight_decay=params.weight_decay)
    # shared_optimizer.share_memory()
    #run_sim(0, params, shared_model, None,  count, lock)
    #test(params, shared_model, count, lock, best_acc)

    processes = []

    train_process = 0
    test_process = 0

    p = mp.Process(target=learning,
                   args=(
                       params,
                       state_Queue,
                       action_done,
                       actions,
                       reward_Queue,
                   ))
    p.start()
    processes.append(p)
    # test_process += 1

    for rank in range(params.n_process):
        p = mp.Process(target=run_sim,
                       args=(
                           train_process,
                           params,
                           state_Queue,
                           action_done,
                           actions,
                           reward_Queue,
                           lock,
                       ))

        train_process += 1
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
Example #2
def eval_net(dataloader, model, opts):
    torch.manual_seed(23)
    queue = mp.JoinableQueue()
    _outputs, _indices = mp.Queue(), mp.Queue()

    consumer = Consumer(queue, _outputs, _indices, dataloader, opts)
    consumer.start()
    dataset = dataloader.dataset
    model.eval()

    dataset_len = 100

    with tqdm(total=dataset_len) as t:
        for i in range(dataset_len):
            img, heatmap_t, paf_t, ignore_mask_t = dataset.get_img(i,
                                                                   flip=False)
            img_batch = np.zeros(
                (1, 3, opts["train"]["imgSize"], opts["train"]["imgSize"]))
            img_batch[0, :, :, :] = img
            with torch.no_grad():
                imgs_torch = torch.from_numpy(img_batch).float().cuda()
                heatmaps, pafs = model(imgs_torch)
                heatmap = heatmaps[-1][0].data.cpu().numpy()
                paf = pafs[-1][0].data.cpu().numpy()
            # print(heatmap_t.shape)
            queue.put(
                (i, img, heatmap_t, heatmap, paf_t, paf, ignore_mask_t[0]))
            t.update()
    consumer.join()
    outputs, indices = [], []
    for _ in range(dataset_len):
        outputs.append(_outputs.get())
        indices.append(_indices.get())
    print(outputs)
    return outputs, indices
Example #3
    def __init__(self, env_name, env_kwargs, batch_size, policy, baseline,
                 env=None, seed=None, num_workers=1):
        super(MultiTaskSampler, self).__init__(env_name, env_kwargs, batch_size,
                                               policy, seed=seed, env=env)

        self.num_workers = num_workers

        self.task_queue = mp.JoinableQueue()
        self.train_episodes_queue = mp.Queue()
        self.valid_episodes_queue = mp.Queue()
        policy_lock = mp.Lock()

        self.workers = [SamplerWorker(index, env_name, env_kwargs, batch_size, 
                                      self.env.observation_space, self.env.action_space,
                                      self.policy, deepcopy(baseline), self.seed,
                                      self.task_queue, self.train_episodes_queue,
                                      self.valid_episodes_queue, policy_lock)
                        for index in range(num_workers)]

        for worker in self.workers:
            worker.daemon = True  # daemon processes are terminated when the main process exits
            worker.start()

        self._waiting_sample = False
        self._event_loop = asyncio.get_event_loop()
        self._train_consumer_thread = None
        self._valid_consumer_thread = None
Example #4
    def generate_walks(self):

        g = self.graph
        assert g.is_homogeneous

        all_nodes = g.nodes().numpy().tolist() * self.num_walks
        random.shuffle(all_nodes)

        queue = mp.JoinableQueue()
        per_worker = len(all_nodes) // commons.workers + 1
        ps = []
        for i in range(commons.workers):
            chunk = all_nodes[i * per_worker:(i + 1) * per_worker]
            ps.append(
                mp.Process(target=self.sample,
                           args=(g, chunk, self.walk_length, queue)))

        for p in ps:
            p.start()

        all_walks = []
        for i in range(commons.workers):
            all_walks.extend(queue.get())

        for p in ps:
            p.terminate()

        with open(self.walks_file, 'w') as f:
            for walk in all_walks:
                walk = self.simple_filter(walk)
                f.write(' '.join(walk))
                f.write('\n')
Example #5
    def __init__(self,
                 env_name,
                 env_kwargs,
                 batch_size,
                 policy,
                 baseline,
                 env=None,
                 seed=None,
                 num_workers=1):
        # Call the parent-class constructor of MultiTaskSampler
        super(MultiTaskSampler, self).__init__(env_name,
                                               env_kwargs,
                                               batch_size,
                                               policy,
                                               seed=seed,
                                               env=env)

        self.num_workers = num_workers
        # Initialize the queues: a training queue and a validation queue for collecting data from the worker processes
        self.task_queue = mp.JoinableQueue()
        self.train_episodes_queue = mp.Queue()
        self.valid_episodes_queue = mp.Queue()
        policy_lock = mp.Lock()
        # self.Original_policy = self.policy
        # temporary_policy = self.Original_policy

        # Build num_workers workers (the SamplerWorker constructor is called num_workers times)
        self.workers = [SamplerWorker(index,
                                      env_name,
                                      env_kwargs,
                                      batch_size,
                                      self.env.observation_space,
                                      self.env.action_space,
                                      self.policy,
                                      deepcopy(baseline),
                                      self.seed,
                                      self.task_queue,
                                      self.train_episodes_queue,
                                      self.valid_episodes_queue,
                                      policy_lock)
            for index in range(num_workers)]

        for worker in self.workers:
            # Daemon process: it terminates as soon as the main process finishes
            worker.daemon = True
            """
            启动worker (SamplerWorker(index)) 跳转至类SamplerWorker 中的函数 run(self),
            触发采样训练轨迹,inner更新网络,采样验证轨迹数据
            """
            worker.start()

        self._waiting_sample = False
        # Create the event loop and the training/validation consumer threads
        self._event_loop = asyncio.get_event_loop()
        self._train_consumer_thread = None
        self._valid_consumer_thread = None
Example #6
def eval_wer(loader, criterion, lm_weight, index2letter, n_processes=32):

    criterion.eval()

    bar = progressbar.ProgressBar(len(loader))
    bar.start()

    task_q, result_q = mp.JoinableQueue(), mp.Queue()
    processes = []
    for _ in range(n_processes):
        p = mp.Process(
            target=Worker(lm_weight, index2letter, task_q, result_q))
        p.start()
        processes.append(p)

    tasks_fed = 0
    mean_wer = 0.0
    results = 0.0

    for index, data in enumerate(loader):
        bar.update(index)
        batch_size = data[0].size(0)
        tasks_fed += batch_size

        with torch.no_grad():
            seq, seq_lengths, labels, label_lengths = prepare_data(
                data, put_on_cuda=False)
            seq = seq.cuda()

            predictions = criterion.letter_classifier(seq).log_softmax(
                dim=-1).cpu()

        for k in range(batch_size):
            p_ = predictions[k, :, :]
            labels_ = (labels[k, :label_lengths[k]])
            task_q.put((p_, labels_))

        task_q.join()
        while not result_q.empty():
            mean_wer += result_q.get()
            results += 1
        assert results == tasks_fed
    bar.finish()

    for _ in processes:
        task_q.put(None)

    for p in processes:
        p.join()

    mean_wer /= results
    return mean_wer
Example #7
 def reset(self, items: TfmdSource, train_setup=False):
     pv('reset', self.verbose)
     self.step_idx = 0
     self.close(items)
     self.cancel.clear()
     self.queue = mp.JoinableQueue(maxsize=self.n_processes)
     items.items = [
         self.process_cls(start=True,
                          items=self.cached_items,
                          train_queue=self.queue,
                          cancel=self.cancel)
         for _ in range(self.n_processes)
     ]
     if not all([p.is_alive() for p in items.items]):
         raise CancelFitException()
Example #8
 def __init__(self,
              n_processes: int = 1,
              process_cls=None,
              cancel=None,
              verbose: bool = False,
              regular_get: bool = True,
              tracker=None):
     store_attr(but='process_cls')
     self.process_cls = ifnone(process_cls, DataFitProcess)
     self.queue = mp.JoinableQueue(maxsize=self.n_processes)
     self.cancel = ifnone(self.cancel, mp.Event())
     self.pipe_in, self.pipe_out = mp.Pipe(False) if self.verbose else (
         None, None)
     self.cached_items = []
     self._place_holder_out = None
     self.step_idx = 0
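The two snippets above create the JoinableQueue with maxsize=self.n_processes, so producer processes block on put() once that many items are waiting and cannot run ahead of the consumer. A minimal, self-contained sketch of that back-pressure behaviour (illustrative names only, not part of the library above):

import multiprocessing as mp
import time

def producer(q):
    # put() blocks once the queue already holds `maxsize` un-consumed items,
    # which is what keeps a fast producer from running far ahead.
    for i in range(4):
        q.put(i)
        print("produced", i, flush=True)

if __name__ == "__main__":
    q = mp.JoinableQueue(maxsize=2)
    p = mp.Process(target=producer, args=(q,))
    p.start()
    time.sleep(1.0)           # the producer fills the queue (2 items) and then blocks
    for _ in range(4):
        item = q.get()
        q.task_done()         # acknowledge the item so a later q.join() could return
        print("consumed", item, flush=True)
    p.join()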
Example #9
    def __init__(
        self,
        policy_container,
        opposing_policy_container,
        env_gen,
        evaluation_policy_container=None,
        network=None,
        swap_sides=True,
        save_dir="saves",
        epoch_length=500,
        initial_games=64,
        self_play=False,
        lr=0.001,
        stagger=False,
        evaluation_games=100,
    ):
        self.policy_container = policy_container
        self.opposing_policy_container = opposing_policy_container
        self.evaluation_policy_container = evaluation_policy_container
        self.env_gen = env_gen
        self.swap_sides = swap_sides
        self.save_dir = save_dir
        self.epoch_length = epoch_length
        self.self_play = self_play
        self.lr = lr
        self.stagger = stagger

        self.network = network
        self.evaluation_games = evaluation_games

        self.start_time = datetime.datetime.now().isoformat()

        self.task_queue = multiprocessing.JoinableQueue()
        self.memory_queue = multiprocessing.Queue()
        self.result_queue = multiprocessing.Queue()
        self.initial_games = initial_games
        self.writer = SummaryWriter()

        if save_dir:
            os.mkdir(os.path.join(save_dir, self.start_time))
            log_file = os.path.join(save_dir, self.start_time, "log")
            logging.basicConfig(filename=log_file, level=logging.INFO)
        multiprocessing_logging.install_mp_handler()
Example #10
File: batch.py Project: wzhystar/fastNLP
def run_batch_iter(batch):
    q = mp.JoinableQueue(maxsize=10)
    fetch_p = mp.Process(target=run_fetch, args=(batch, q))
    fetch_p.daemon = True
    fetch_p.start()
    # print('fork fetch process')
    while 1:
        try:
            res = q.get(timeout=1)
            q.task_done()
            # print('get fetched')
            if res is None:
                break
            yield res
        except Exception as e:
            if fetch_p.is_alive():
                continue
            else:
                break
    fetch_p.terminate()
    fetch_p.join()
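run_fetch itself is not shown in Example #10. A minimal sketch of a producer compatible with the consumer loop above, assuming batch is simply iterable (hypothetical, not the actual fastNLP implementation): push every fetched item, then a None sentinel, and let q.join() wait until the consumer has acknowledged each item via task_done().

def run_fetch(batch, q):
    # Hypothetical producer matching run_batch_iter above.
    for item in batch:      # assume `batch` yields ready-to-use batches
        q.put(item)
    q.put(None)             # sentinel telling run_batch_iter to stop
    q.join()                # returns once every put item has been task_done()'d

The consumer side tolerates a slow fetcher because q.get(timeout=1) simply times out and falls back to checking fetch_p.is_alive().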
Example #11
import torch
import torch.multiprocessing as mp

from option import opt
import process as proc

MAX_FPS = 30
MAX_SEGMENT_LENGTH = 4
SHARED_QUEUE_LEN = MAX_FPS * MAX_SEGMENT_LENGTH  #Regulate GPU memory usage (> 3 would be fine)

if __name__ == "__main__":
    mp.set_start_method('spawn', force=True)
    torch.multiprocessing.set_sharing_strategy('file_descriptor')

    #create Queue, Pipe
    decode_queue = mp.Queue()
    dnn_queue = mp.JoinableQueue()
    data_queue = mp.JoinableQueue()
    encode_queue = mp.JoinableQueue()
    output_output, output_input = mp.Pipe(duplex=False)

    #create shared tensor
    shared_tensor_list = {}
    res_list = [(270, 480), (360, 640), (540, 960), (1080, 1920)]
    for res in res_list:
        shared_tensor_list[res[0]] = []
        for _ in range(SHARED_QUEUE_LEN):
            shared_tensor_list[res[0]].append(
                torch.ByteTensor(res[0], res[1], 3).cuda().share_memory_())

    #create processes
    decode_process = mp.Process(target=proc.decode,
Example #12
import gym
import torch.multiprocessing as mp

from copy import deepcopy

import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

'''
Let's see how just the non-MAML version adapts

'''



task_queue = mp.JoinableQueue()
train_episodes_queue = mp.Queue()
valid_episodes_queue = mp.Queue()
policy_lock = mp.Lock()
env_name = "2DNavigation-v0"
env_kwargs = {
                    "low": -0.5,
                    "high": 0.5,
                    "task": {"goal": np.array([1, 1])}
                }
env = gym.make(env_name, **env_kwargs)
print(env)

policy = get_policy_for_env(env,
                            hidden_sizes=(64, 64),
                            nonlinearity='tanh')
Example #13
 def __init__(self, processes) -> None:
     super().__init__()
     self.n_jobs = processes
     # self.init_fun = init_fun
     self.task_queue = multiprocessing.JoinableQueue()
     self.results_queue = multiprocessing.Queue()
    print("assembly complete: " + str(name))
    if seq:
        pickle.dump(
            result,
            open(
                "/tigress/noamm/schapiro/gpu/nets_seq_" + str(num_networks) +
                "_" + str(num_trials) + "_" + str(name) + ".pkl", "wb"))
    else:
        pickle.dump(
            result,
            open(
                "/tigress/noamm/schapiro/gpu/nets_sep_" + str(num_networks) +
                "_" + str(num_trials) + "_" + str(name) + ".pkl", "wb"))


q_sep = mp.JoinableQueue()
q_seq = mp.JoinableQueue()

process_sep = [
    mp.Process(target=assemble, args=(i, False, num_trials))
    for i in range(num_networks)
]
process_seq = [
    mp.Process(target=assemble, args=(i, True, num_trials))
    for i in range(num_networks)
]

print("sep")
for process in process_sep:
    process.start()
Example #15
    def __init__(
            self,
            video_path,
            output_path,
            realtime,
            start,
            duration,
            show_time,
            confidence_threshold=0.5,
            exclude_class=None,
            common_cate=False,
    ):
        self.vid_info = cv2_video_info(video_path)
        fps = self.vid_info["fps"]
        if fps == 0 or fps > 100:
            print(
                "Warning: The detected frame rate {} could be wrong. The behavior of this demo code can be abnormal.".format(
                    fps))

        self.realtime = realtime
        self.start = start
        self.duration = duration
        self.show_time = show_time
        self.confidence_threshold = confidence_threshold
        if common_cate:
            self.cate_to_show = self.COMMON_CATES
            self.category_split = (6, 11)
        else:
            self.cate_to_show = self.CATEGORIES
            self.category_split = (14, 63)
        self.cls2label = {class_name: i for i, class_name in enumerate(self.cate_to_show)}
        if exclude_class is None:
            exclude_class = self.EXCLUSION
        self.exclude_id = [self.cls2label[cls_name] for cls_name in exclude_class if cls_name in self.cls2label]

        self.width = self.vid_info["width"]
        self.height = self.vid_info["height"]
        long_side = min(self.width, self.height)
        self.font_size = max(int(round((long_side / 40))), 1)
        self.box_width = max(int(round(long_side / 180)), 1)
        self.font = ImageFont.truetype("./Roboto-Bold.ttf", self.font_size)

        self.box_color = (191, 40, 41)
        self.category_colors = ((176, 85, 234), (87, 118, 198), (52, 189, 199))
        self.category_trans = int(0.6 * 255)

        self.action_dictionary = dict()

        if realtime:
            # Output Video
            width = self.vid_info["width"]
            height = self.vid_info["height"]
            fps = self.vid_info["fps"]
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            self.out_vid = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        else:
            self.frame_queue = mp.JoinableQueue(512)
            self.result_queue = mp.JoinableQueue()
            self.track_queue = mp.JoinableQueue()
            self.done_queue = mp.Queue()
            self.frame_loader = mp.Process(target=self._load_frame, args=(video_path,))
            self.frame_loader.start()
            self.video_writer = mp.Process(target=self._wirte_frame, args=(output_path,))
            self.video_writer.start()
Example #16
def ft(ft_dir, app_args, cleanup_ft_dir=False):
    """Fine tune all the checkpoint files we find in the immediate-directory specified.
    For each checkpoint file we find, we create and queue a FinetuningTask.  
    A FinetuningProcess will pickup the FinetuningTask and process it.
    """
    print("Fine-tuning directory %s" % ft_dir)
    checkpoints = glob.glob(os.path.join(ft_dir, "*checkpoint.pth.tar"))
    assert checkpoints

    # We create a subdirectory, where we will write all of our output
    ft_output_dir = os.path.join(ft_dir, 'ft')

    os.makedirs(ft_output_dir, exist_ok=True)
    print("Writing results to directory %s" % ft_output_dir)
    app_args.output_dir = ft_output_dir

    # Multi-process queues
    tasks = multiprocessing.JoinableQueue()
    results = multiprocessing.Queue()

    # Create and launch the fine-tuning processes
    processes = []
    n_processes = min(app_args.processes, len(checkpoints))
    for i in range(n_processes):
        # Pre-load the data-loaders of each fine-tuning process once
        app = classifier.ClassifierCompressor(
            app_args, script_dir=os.path.dirname(__file__))
        data_loader = classifier.load_data(app.args)
        # Delete log directories
        shutil.rmtree(app.logdir)
        processes.append(FinetuningProcess(tasks, results, data_loader))
        # Start the process
        processes[-1].start()

    n_gpus = torch.cuda.device_count()

    # Enqueue all of the fine-tuning tasks
    for (instance, ckpt_file) in enumerate(checkpoints):
        tasks.put(FinetuningTask(args=(ckpt_file, instance % n_gpus,
                                       app_args)))

    # Push an end-of-tasks marker
    for i in range(len(processes)):
        tasks.put(None)

    # Wait until all tasks finish
    tasks.join()

    # Start printing results
    results_dict = OrderedDict()
    while not results.empty():
        result = results.get()
        results_dict[result[0]] = result[1]

    # Read the results of the AMC experiment (we'll want to use some of the data)
    # import pandas as pd
    # df = pd.read_csv(os.path.join(ft_dir, "amc.csv"))
    # assert len(results_dict) > 0

    if cleanup_ft_dir:
        # cleanup: remove the "ft" directory
        shutil.rmtree(ft_output_dir)
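Example #16 (and the fuller Example #18 below) only shows the producer side: the main process enqueues FinetuningTasks, pushes one None marker per worker, and waits on tasks.join(). A minimal sketch of the consumer side such a worker process would need (hypothetical names, not the actual distiller implementation); note that every get() must be paired with task_done(), including for the None sentinel, or tasks.join() never returns.

import multiprocessing

def run_task(task):
    # Placeholder for the real work, e.g. fine-tuning one checkpoint on one GPU.
    return task

class TaskWorker(multiprocessing.Process):
    def __init__(self, tasks, results):
        super().__init__()
        self.tasks = tasks      # multiprocessing.JoinableQueue of work items
        self.results = results  # multiprocessing.Queue collecting results

    def run(self):
        while True:
            task = self.tasks.get()
            if task is None:             # end-of-tasks marker pushed by the main process
                self.tasks.task_done()
                break
            self.results.put(run_task(task))
            self.tasks.task_done()       # required for tasks.join() to return

Usage would mirror the snippets above: create the two queues, start a few TaskWorker(tasks, results) processes, put the tasks, put one None per worker, then call tasks.join() and drain results.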
Example #17
    def __init__(self,
                 env_name,
                 env_kwargs,
                 batch_size,
                 policy,
                 baseline,
                 dynamics=None,
                 inverse_dynamics=False,
                 env=None,
                 seed=None,
                 num_workers=1,
                 epochs_counter=None,
                 act_prev_mean=None,
                 obs_prev_mean=None,
                 eta=None,
                 benchmark=None,
                 pre_epochs=-1,
                 normalize_spaces=True,
                 add_noise=False):
        super(MultiTaskSampler, self).__init__(env_name,
                                               env_kwargs,
                                               batch_size,
                                               policy,
                                               seed=seed,
                                               env=env)
        # Metaworld
        self.benchmark = benchmark

        ### Dynamics
        self.env_name = env_name
        self.epochs_counter = epochs_counter
        self.pre_epochs = pre_epochs

        self.dynamics = dynamics
        if self.dynamics is not None:
            self.kl_previous = mp.Manager().list()
            kl_previous_lock = mp.Manager().RLock()
            dynamics_lock = mp.Lock()
        else:
            dynamics_lock = None
            self.kl_previous = None
            kl_previous_lock = None

        self.inverse_dynamics = inverse_dynamics

        self.act_prev_mean = act_prev_mean
        self.obs_prev_mean = obs_prev_mean

        act_prev_lock = mp.Manager().RLock()
        obs_prev_lock = mp.Manager().RLock()

        self.num_workers = num_workers

        self.task_queue = mp.JoinableQueue()
        self.train_episodes_queue = mp.Queue()
        self.valid_episodes_queue = mp.Queue()
        policy_lock = mp.Lock()

        self.workers = [
            SamplerWorker(
                index,
                env_name,
                env_kwargs,
                batch_size,
                self.env.observation_space,
                self.env.action_space,
                self.policy,
                deepcopy(baseline),
                self.seed,
                self.task_queue,
                self.train_episodes_queue,
                self.valid_episodes_queue,
                policy_lock,
                # Queues and Epochs
                epochs_counter=epochs_counter,
                pre_epochs=pre_epochs,
                act_prev_lock=act_prev_lock,
                obs_prev_lock=obs_prev_lock,
                act_prev_mean=self.act_prev_mean,
                obs_prev_mean=self.obs_prev_mean,
                # Dynamics
                dynamics=self.dynamics,
                dynamics_lock=dynamics_lock,
                kl_previous=self.kl_previous,
                kl_previous_lock=kl_previous_lock,
                inverse_dynamics=self.inverse_dynamics,
                eta=eta,
                # Metaworld
                benchmark=benchmark,
                normalize_spaces=normalize_spaces,
                add_noise=add_noise) for index in range(num_workers)
        ]

        for worker in self.workers:
            worker.daemon = True
            worker.start()

        self._waiting_sample = False
        self._event_loop = asyncio.get_event_loop()
        self._train_consumer_thread = None
        self._valid_consumer_thread = None
Example #18
def finetune_directory(ft_dir,
                       stats_file,
                       app_args,
                       cleanup_ft_dir=False,
                       checkpoints=None):
    """Fine tune all the checkpoint files we find in the immediate-directory specified.

    For each checkpoint file we find, we create and queue a FinetuningTask.  
    A FinetuningProcess will pickup the FinetuningTask and process it.
    """
    print("Fine-tuning directory %s" % ft_dir)
    if not checkpoints:
        # Get a list of the checkpoint files
        checkpoints = glob.glob(os.path.join(ft_dir, "*checkpoint.pth.tar"))
    assert checkpoints

    # We create a subdirectory, where we will write all of our output
    ft_output_dir = os.path.join(ft_dir, "ft")
    os.makedirs(ft_output_dir, exist_ok=True)
    print("Writing results to directory %s" % ft_output_dir)
    app_args.output_dir = ft_output_dir

    # Multi-process queues
    tasks = multiprocessing.JoinableQueue()
    results = multiprocessing.Queue()

    # Create and launch the fine-tuning processes
    processes = []
    n_processes = min(app_args.processes, len(checkpoints))
    for i in range(n_processes):
        # Pre-load the data-loaders of each fine-tuning process once
        app = classifier.ClassifierCompressor(
            app_args, script_dir=os.path.dirname(__file__))
        data_loader = classifier.load_data(app.args)
        # Delete log directories
        shutil.rmtree(app.logdir)
        processes.append(FinetuningProcess(tasks, results, data_loader))
        # Start the process
        processes[-1].start()

    n_gpus = torch.cuda.device_count()

    # Enqueue all of the fine-tuning tasks
    for (instance, ckpt_file) in enumerate(checkpoints):
        tasks.put(FinetuningTask(args=(ckpt_file, instance % n_gpus,
                                       app_args)))

    # Push an end-of-tasks marker
    for i in range(len(processes)):
        tasks.put(None)

    # Wait until all tasks finish
    tasks.join()

    # Start printing results
    results_dict = OrderedDict()
    while not results.empty():
        result = results.get()
        results_dict[result[0]] = result[1]

    # Read the results of the AMC experiment (we'll want to use some of the data)
    import pandas as pd
    df = pd.read_csv(os.path.join(ft_dir, "amc.csv"))
    assert len(results_dict) > 0
    # Log some info for each checkpoint
    for ckpt_name in sorted(results_dict.keys()):
        net_search_results = df[df["ckpt_name"] ==
                                ckpt_name[:-len("_checkpoint.pth.tar")]]
        search_top1 = net_search_results["top1"].iloc[0]
        normalized_macs = net_search_results["normalized_macs"].iloc[0]
        log_entry = (ft_output_dir, ckpt_name, normalized_macs, search_top1,
                     *results_dict[ckpt_name])
        print("%s <>  %s: %.2f %.2f %.2f %.2f %.2f" % log_entry)
        stats_file.add_record(log_entry)
    if cleanup_ft_dir:
        # cleanup: remove the "ft" directory
        shutil.rmtree(ft_output_dir)
Example #19
 def __init__(self):
     self.queue = mp.JoinableQueue(1)
     self._put_end = False
     self._got_end = False
Example #20
    def train_model(self,
                    num_epochs=10,
                    resume_model=False,
                    resume_memory=False,
                    num_workers=None):
        try:
            evaluator = self.network
            optim = torch.optim.SGD(evaluator.parameters(),
                                    weight_decay=0.0001,
                                    momentum=0.9,
                                    lr=self.lr)
            evaluator.share_memory()

            num_workers = num_workers or multiprocessing.cpu_count()
            player_workers = [
                SelfPlayWorker(
                    self.task_queue,
                    self.memory_queue,
                    self.result_queue,
                    self.env_gen,
                    evaluator=evaluator,
                    start_time=self.start_time,
                    policy_container=self.policy_container,
                    opposing_policy_container=self.opposing_policy_container,
                    evaluation_policy_container=self.evaluation_policy_container,
                    save_dir=self.save_dir,
                    resume=resume_model,
                    self_play=self.self_play,
                ) for _ in range(num_workers - 1)
            ]
            for w in player_workers:
                w.start()

            update_worker_queue = multiprocessing.JoinableQueue()

            update_flag = multiprocessing.Event()
            update_flag.clear()

            update_worker = UpdateWorker(
                memory_queue=self.memory_queue,
                policy_container=self.policy_container,
                evaluator=evaluator,
                optim=optim,
                update_flag=update_flag,
                update_worker_queue=update_worker_queue,
                save_dir=self.save_dir,
                resume=resume_memory,
                start_time=self.start_time,
                stagger=self.stagger,
            )

            update_worker.start()

            for i in range(self.initial_games):
                swap_sides = not i % 2 == 0
                self.task_queue.put(
                    {"play": {
                        "swap_sides": swap_sides,
                        "update": False
                    }})
            self.task_queue.join()
            while not self.result_queue.empty():
                self.result_queue.get()

            saved_model_name = None
            reward = self.evaluate_policy(-1)
            for epoch in range(num_epochs):
                update_flag.set()
                for i in range(self.epoch_length):
                    swap_sides = not i % 2 == 0
                    self.task_queue.put({
                        "play": {
                            "swap_sides": swap_sides,
                            "update": True
                        },
                        "saved_name": saved_model_name,
                    })
                self.task_queue.join()

                saved_model_name = os.path.join(
                    self.save_dir,
                    self.start_time,
                    "model-" + datetime.datetime.now().isoformat() + ":" +
                    str(self.epoch_length * epoch),
                )
                update_flag.clear()
                update_worker_queue.put({"saved_name": saved_model_name})
                reward = self.evaluate_policy(epoch)

                update_worker_queue.join()
                update_worker_queue.put({"reward": reward})

                # Do some evaluation?

            # Clean up
            update_worker.terminate()
            for w in player_workers:
                w.terminate()
            del self.memory_queue
            del self.task_queue
            del self.result_queue
        except Exception as e:
            logging.exception("error in main loop: " + str(e))