Example #1
def load_saved_model(model: Module,
                     path: str,
                     T: Value,
                     global_reward: Value,
                     model_critic: Module = None) -> None:
    """
    Load a saved model checkpoint from file.
    :param model: model to load parameters for
    :param path: path to load the checkpoint from
    :param T: global step counter, restored to continue training
    :param global_reward: global running reward value, restored to continue training
    :param model_critic: optional separate critic model to load if a non-shared network is used
    :return: None
    """
    if os.path.isfile(path):
        print(f"=> loading model checkpoint '{path}'")
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['model'])
        T.value = checkpoint['epoch']
        global_reward.value = checkpoint['global_reward']
        if model_critic:
            model_critic.load_state_dict(checkpoint['model_critic'])
        print(f"=> loaded model checkpoint '{path}' (T: {checkpoint['epoch']} "
              f"-- global reward: {checkpoint['global_reward']})")
    else:
        print(f"=> no model checkpoint found at '{path}'")
Example #2
    def _generate_parallel(self, iteration, network, device, num_workers):
        q, r = divmod(self.remaining_games, num_workers)
        num_active_workers = Value('i', num_workers)
        resign_threshold = Value('d', self.resign_mgr.threshold())
        evaluator_mgr = BulkEvaluatorManager([network], device, num_workers)
        output_queue = SimpleQueue()

        # start the workers
        workers = []
        for worker_id in range(num_workers):
            num_games = q + 1 if worker_id < r else q
            evaluator = evaluator_mgr.get_evaluator(worker_id, 0)
            worker = Process(
                target=self._worker_job,
                args=(worker_id, num_games, num_active_workers,
                      resign_threshold, evaluator, output_queue),
            )
            workers.append(worker)
            worker.start()

        # start evaluator server
        server = evaluator_mgr.get_server(num_active_workers)
        server.start()

        # collect the examples generated by workers
        while num_active_workers.value > 0 or not output_queue.empty():
            examples, resign_value_history, result = output_queue.get()
            self.example_pool += examples
            self.game_length.append(len(examples))

            # add the history into resignation manager to update the threshold
            if resign_value_history is not None:
                self.resign_mgr.add(resign_value_history, result)
                resign_threshold.value = self.resign_mgr.threshold()

            self.remaining_games -= 1

            # periodically save the progress
            if (self.conf.GAMES_PER_ITERATION - self.remaining_games) \
                    % self.conf.EXAMPLE_POOL_SAVE_FREQUENCY == 0:
                self.save(iteration)
                log.info(
                    f'[iter={iteration}] ExamplePool: checkpoint saved, '
                    f'{self.remaining_games} games remaining'
                )

        for worker in workers:
            worker.join()
        server.join()
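The divmod split at the top of _generate_parallel hands each of the first r workers one extra game, so the per-worker counts sum exactly to the remaining games; a standalone sketch of that arithmetic with arbitrary numbers:

remaining_games, num_workers = 25, 4
q, r = divmod(remaining_games, num_workers)   # q = 6 full shares, r = 1 leftover game
per_worker = [q + 1 if worker_id < r else q for worker_id in range(num_workers)]
assert per_worker == [7, 6, 6, 6] and sum(per_worker) == remaining_games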
Example #3
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server.bind((address, port))
        server.listen()

        param_queue = Queue()
        param_queue.put(net.state_dict())

        shutdown_val = Value('b', 0)

        receiver_proc = Process(target=HandleWorkers,
                                args=(server, replay_memory, mem_lock,
                                      param_queue, shutdown_val))
        receiver_proc.start()

    while True:
        try:
            Train(net, replay_memory, mem_lock, args.output_file)
            if param_queue is not None:
                param_queue.put(net.state_dict())
            torch.save(net.state_dict(), args.output_file)
        except KeyboardInterrupt:
            if server is not None:
                assert (shutdown_val is not None and receiver_proc is not None)
                print("Shutting down...")

                with shutdown_val.get_lock():
                    shutdown_val.value = 1
                receiver_proc.join()
                server.close()
            break
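The shutdown handshake above assumes the receiver process periodically checks the shared flag; a self-contained sketch of that pattern, where poll_until_shutdown is an illustrative stand-in for HandleWorkers rather than the project's implementation:

import time
from multiprocessing import Process, Value

def poll_until_shutdown(shutdown_val):
    # Illustrative worker loop: do a unit of "work", then check the shared flag.
    while True:
        with shutdown_val.get_lock():
            if shutdown_val.value:
                break
        time.sleep(0.1)  # stand-in for real work

if __name__ == "__main__":
    shutdown_val = Value('b', 0)
    worker = Process(target=poll_until_shutdown, args=(shutdown_val,))
    worker.start()
    with shutdown_val.get_lock():
        shutdown_val.value = 1   # signal shutdown, as in the KeyboardInterrupt handler above
    worker.join()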
Example #4
    # Optimization step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Save a checkpoint every 2k iterations (or on exit)
    if iterations % 2000 == 0 or exit:
        print("Saving")
        save_checkpoint(
            {
                'iterations': iterations,
                'state_dict': model.state_dict(),
            }, NAME)
    if exit:
        print("Exiting..")
        control.value = 0
        sys.exit()
    print((iterations, loss))

    # OP Test
    if int(args.debug):
        test_index = 0
        hm_final = hm[test_index, :, :, :]
        paf_final = pafC[test_index, :, :, :]
        poseHeatMaps = torch.cat([hm_final, paf_final],
                                 0).detach().cpu().numpy().copy()
        imageToProcess = imgs.detach().cpu().numpy().copy()[
            test_index, :, :, :]
        imageToProcess = (cv2.merge([
            imageToProcess[0, :, :] + 0.5, imageToProcess[1, :, :] + 0.5,
            imageToProcess[2, :, :] + 0.5
Example #5
            receiver_proc = Process(target=HandleWorkers,
                                    args=(server, out_queue, param_queue,
                                          shutdown_server))
            receiver_proc.start()

            state_dict = resnet.state_dict()
            print(type(state_dict))
            param_queue.put(state_dict)

            tensor = out_queue.get()
            net_output = resnet(tensor)
            print(net_output)

            print("Shutting down server...")
            with shutdown_server.get_lock():
                shutdown_server.value = 1
        else:
            param_queue = Queue()
            server = communication.WorkerSocket(address, port)
            print("Connected to server")

            receiver_proc = Process(target=ReceiveParams,
                                    args=(server, param_queue),
                                    daemon=True)
            receiver_proc.start()

            # for _ in range(3):
            #     SendPlayout(server)
            #     time.sleep(1)

            state_dict = param_queue.get()
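The parameter handoff in this example boils down to pushing a state_dict through a multiprocessing Queue and loading it on the receiving side; a self-contained sketch of that round trip, using a tiny placeholder network instead of the project's resnet:

import torch.nn as nn
from multiprocessing import Queue

net = nn.Linear(4, 2)                  # placeholder for the real network
param_queue = Queue()
param_queue.put(net.state_dict())      # producer side (server process)

received = param_queue.get()           # consumer side (worker process)
replica = nn.Linear(4, 2)
replica.load_state_dict(received)      # worker now mirrors the server's weights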
Example #6
def generate_kernel_parallel(kernel_cfg,
                             x,
                             y,
                             batch_size=32,
                             num_gpus=4,
                             symmetric=False,
                             model_uuid=None,
                             checkpoint_K=None,
                             checkpoint_rows_done=None,
                             cache_path="tc_cache",
                             float32=False,
                             extra_info={},
                             verbose=False,
                             use_tqdm=True):
    ''' Takes two numpy arrays x and y of shapes N x H x W x C and M x H x W x C
        and returns a kernel matrix K of shape N x M.
    '''

    #TODO fixme
    print("Batch Size ", batch_size)
    assert num_gpus <= torch.cuda.device_count()
    N = x.shape[0]
    M = y.shape[0]
    if float32:
        K = np.memmap("/dev/shm/kernel",
                      mode="w+",
                      dtype="float32",
                      shape=(N, M))
    else:
        K = np.memmap("/dev/shm/kernel",
                      mode="w+",
                      dtype="float64",
                      shape=(N, M))

    K.fill(np.inf)
    rows_done = np.memmap("/dev/shm/rowsdone",
                          mode="w+",
                          dtype="uint16",
                          shape=(1, ))
    if checkpoint_rows_done is not None:
        rows_done[:] = np.copy(utils.bytes_to_numpy(checkpoint_rows_done))
        K[:rows_done[0], :] = np.copy(utils.bytes_to_numpy(checkpoint_K))

    n = 0
    done_q = Queue()
    data_q = Queue()
    done = Value('i', 0)
    num_column_blocks = int(N / batch_size)

    x_idxs = torch.arange(x.shape[0])
    y_idxs = torch.arange(y.shape[0])

    x_data = TensorDataset(x_idxs, torch.from_numpy(x))
    x_loader = DataLoader(x_data, batch_size=batch_size)

    y_data = TensorDataset(y_idxs, torch.from_numpy(y))
    y_loader = DataLoader(y_data, batch_size=batch_size)

    processes = []

    x_data = [x for x in x_loader]
    y_data = [y for y in y_loader]
    count = 0
    start_time = time.time()
    for x_idxs, x_b in x_data:
        for y_idxs, y_b in y_data:
            count += 1
            start_x = int(min(x_idxs))
            end_x = int(max(x_idxs) + 1)
            start_y = int(min(y_idxs))
            end_y = int(max(y_idxs) + 1)
            if end_x > rows_done[0]:
                data_q.put(((x_idxs, x_b), (y_idxs, y_b)))
            #print(start_x, start_y)
            if count % 1000 == 0:
                print("Current Count Is: ", count)
    os.environ["OMP_NUM_THREADS"] = str(1)

    for gpu_idx in range(num_gpus):
        p = Process(target=_kernel_gen_help,
                    args=(done_q, data_q, kernel_cfg, batch_size, symmetric,
                          gpu_idx, K.shape, cache_path, float32, done,
                          verbose))
        processes.append(p)

    for i, p in enumerate(processes):
        p.start()
    if symmetric:
        done_work = rows_done[0] * M + (N - rows_done[0]) * (rows_done[0])
    else:
        done_work = rows_done[0] * M
    work_left = N * M - done_work
    last_checkpoint = work_left
    print("Data_q size start", data_q.qsize())
    if use_tqdm:
        pbar = tqdm(total=N * M)
    else:
        pbar = None
    total_progress = 0
    while work_left > 0:
        progress = done_q.get()
        total_progress += progress
        work_left -= progress
        elapsed = time.time() - start_time
        avg_speed = total_progress / elapsed
        time_left = utils.pretty_time_delta(work_left / avg_speed)
        if pbar is not None:
            pbar.update(progress)
        else:
            print(
                f"Work Left: {work_left}, Work done so far: {total_progress}, Time Left: {time_left}"
            )
    if pbar is not None:
        pbar.close()
    print("Data_q size end", data_q.qsize())
    done.value = 1
    for i, p in enumerate(processes):
        p.join()
    np.save("/tmp/K_train_full.npy", K)
    if symmetric:
        _symmetric_fill(K, x, y, batch_size)
    K_copy = np.zeros(K.shape)
    np.copyto(K_copy, K)
    assert np.all(np.isfinite(K_copy))
    return K_copy
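The /dev/shm memmaps above let the GPU worker processes write their kernel blocks into the same array without shipping results back through a queue; a minimal, Linux-only sketch of that shared-memmap pattern with an arbitrary file name and shape:

import os
import numpy as np
from multiprocessing import Process

def fill_row(path, shape, row):
    # Each worker re-opens the same file-backed array and writes only its own row.
    block = np.memmap(path, mode="r+", dtype="float64", shape=shape)
    block[row, :] = row
    block.flush()

if __name__ == "__main__":
    path, shape = "/dev/shm/demo_kernel", (4, 3)
    K = np.memmap(path, mode="w+", dtype="float64", shape=shape)
    K.fill(np.inf)
    workers = [Process(target=fill_row, args=(path, shape, r)) for r in range(shape[0])]
    for p in workers:
        p.start()
    for p in workers:
        p.join()
    assert np.all(np.isfinite(K))   # every row was filled by some worker
    os.remove(path)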
Example #7
def train(args,
          worker_id: int,
          global_model: Union[ActorNetwork, ActorCriticNetwork],
          T: Value,
          global_reward: Value,
          optimizer: torch.optim.Optimizer = None,
          global_model_critic: CriticNetwork = None,
          optimizer_critic: torch.optim.Optimizer = None,
          lr_scheduler: torch.optim.lr_scheduler = None,
          lr_scheduler_critic: torch.optim.lr_scheduler = None):
    """
    Start the worker in training mode, i.e. train the shared model with backprop.
    Loosely based on https://github.com/ikostrikov/pytorch-a3c/blob/master/train.py
    :param args: console arguments
    :param worker_id: id of the worker, used to differentiate workers and init different seeds
    :param global_model: global model which is optimized / for split models: the actor model
    :param T: global counter of steps
    :param global_reward: global running reward value
    :param optimizer: optimizer for shared model/ for split models: actor model
    :param global_model_critic: optional global critic model for split networks
    :param optimizer_critic: optional critic optimizer for split networks
    :param lr_scheduler: optional learning rate scheduler instance for shared model
    / for split models: actor learning rate scheduler
    :param lr_scheduler_critic: optional learning rate scheduler instance for critic model
    :return: None
    """
    torch.manual_seed(args.seed + worker_id)

    if args.worker == 1:
        logging.info(f"Running A2C with {args.n_envs} environments.")
        if "RR" not in args.env_name:
            env = SubprocVecEnv([
                make_env(args.env_name, args.seed, i, args.log_dir)
                for i in range(args.n_envs)
            ])
        else:
            env = DummyVecEnv(
                [make_env(args.env_name, args.seed, worker_id, args.log_dir)])
    else:
        logging.info(f"Running A3C: training worker {worker_id} started.")
        env = DummyVecEnv(
            [make_env(args.env_name, args.seed, worker_id, args.log_dir)])
        # avoid any issues if this is not 1
        args.n_envs = 1

    normalizer = get_normalizer(args.normalizer, env)

    # init local NN instance for worker thread
    model = copy.deepcopy(global_model)
    model.train()

    model_critic = None

    if global_model_critic:
        model_critic = copy.deepcopy(global_model_critic)
        model_critic.train()

    # if no shared optimizer is provided, use an individual one
    if not optimizer:
        optimizer, optimizer_critic = get_optimizer(
            args.optimizer,
            global_model,
            args.lr,
            model_critic=global_model_critic,
            lr_critic=args.lr_critic)
        if args.lr_scheduler == "exponential":
            lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                                  gamma=0.99)
            if optimizer_critic:
                lr_scheduler_critic = torch.optim.lr_scheduler.ExponentialLR(
                    optimizer_critic, gamma=0.99)

    state = torch.Tensor(env.reset())

    t = np.zeros(args.n_envs)
    global_iter = 0
    episode_reward = np.zeros(args.n_envs)

    if worker_id == 0:
        writer = SummaryWriter(log_dir='experiments/runs/')

    while True:
        # Get state of the global model
        model.load_state_dict(global_model.state_dict())
        if not args.shared_model:
            model_critic.load_state_dict(global_model_critic.state_dict())

        # containers for computing loss
        values = []
        log_probs = []
        rewards = []
        entropies = []
        # container to check whether a terminal state was reached from one of the envs
        terminals = []

        # reward_sum = 0
        for step in range(args.rollout_steps):
            t += 1

            if args.shared_model:
                value, mu, std = model(normalizer(state))
            else:
                mu, std = model(normalizer(state))
                value = model_critic(normalizer(state))

            dist = torch.distributions.Normal(mu, std)

            # ------------------------------------------
            # select action
            action = dist.sample()

            # ------------------------------------------
            # Compute statistics for loss
            entropy = dist.entropy().sum(-1).unsqueeze(-1)
            log_prob = dist.log_prob(action).sum(-1).unsqueeze(-1)

            # make selected move
            action = np.clip(action.detach().numpy(), -args.max_action,
                             args.max_action)
            state, reward, dones, _ = env.step(
                action[0]
                if not args.worker == 1 or "RR" in args.env_name else action)

            reward = shape_reward(args, reward)

            episode_reward += reward

            # reaching max_episode_length is treated as terminal here (arguably it should not be)
            dones = np.logical_or(dones, t >= args.max_episode_length)

            values.append(value)
            log_probs.append(log_prob)
            rewards.append(torch.Tensor(reward).unsqueeze(-1))
            entropies.append(entropy)
            terminals.append(torch.Tensor(1 - dones).unsqueeze(-1))

            for i, done in enumerate(dones):
                if done:
                    # keep track of the avg overall global reward
                    with global_reward.get_lock():
                        if global_reward.value == -np.inf:
                            global_reward.value = episode_reward[i]
                        else:
                            global_reward.value = .99 * global_reward.value + .01 * episode_reward[
                                i]
                    if worker_id == 0 and T.value % args.log_frequency == 0:
                        writer.add_scalar("reward/global", global_reward.value,
                                          T.value)

                    episode_reward[i] = 0
                    t[i] = 0
                    if args.worker != 1 or "RR" in args.env_name:
                        env.reset()

            with T.get_lock():
                # this is 1 for A3C and n_envs for A2C (the lock is actually not needed for A2C)
                T.value += args.n_envs

            if lr_scheduler and worker_id == 0 and T.value % args.lr_scheduler_step == 0 and global_iter != 0:
                lr_scheduler.step(T.value / args.lr_scheduler_step)

                if lr_scheduler_critic:
                    lr_scheduler_critic.step(T.value / args.lr_scheduler_step)

            state = torch.Tensor(state)

        if args.shared_model:
            v, _, _ = model(normalizer(state))
            G = v.detach()
        else:
            G = model_critic(normalizer(state)).detach()

        values.append(G)

        # compute loss and backprop
        advantages = torch.zeros((args.n_envs, 1))

        ret = torch.zeros((args.rollout_steps, args.n_envs, 1))
        adv = torch.zeros((args.rollout_steps, args.n_envs, 1))

        # iterate over all time steps from most recent to the starting one
        for i in reversed(range(args.rollout_steps)):
            # G can be seen essentially as the return over the course of the rollout
            G = rewards[i] + args.discount * terminals[i] * G
            if not args.no_gae:
                # Generalized Advantage Estimation
                td_error = rewards[i] + args.discount * terminals[i] * values[
                    i + 1] - values[i]
                # terminals here "reset" advantages to 0, because reset is called internally
                # in the env and a new trajectory is started
                advantages = advantages * args.discount * args.tau * terminals[
                    i] + td_error
            else:
                advantages = G - values[i].detach()

            adv[i] = advantages.detach()
            ret[i] = G.detach()

        policy_loss = -(torch.stack(log_probs) * adv).mean()
        # values[:-1] drops the last element, which is only needed as the bootstrap value for the next timestep
        value_loss = .5 * (ret - torch.stack(values[:-1])).pow(2).mean()
        entropy_loss = torch.stack(entropies).mean()

        # zero grads to reset the gradients
        optimizer.zero_grad()

        if args.shared_model:
            # combined loss for shared architecture
            total_loss = policy_loss + args.value_loss_weight * value_loss - args.entropy_loss_weight * entropy_loss
            total_loss.backward()
        else:
            optimizer_critic.zero_grad()

            value_loss.backward()
            (policy_loss - args.entropy_loss_weight * entropy_loss).backward()

            # this is just used for plotting in tensorboard
            total_loss = policy_loss + args.value_loss_weight * value_loss - args.entropy_loss_weight * entropy_loss

        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        sync_grads(model, global_model)
        optimizer.step()

        if not args.shared_model:
            torch.nn.utils.clip_grad_norm_(model_critic.parameters(),
                                           args.max_grad_norm)
            sync_grads(model_critic, global_model_critic)
            optimizer_critic.step()

        global_iter += 1

        if worker_id == 0 and T.value % args.log_frequency == 0:
            log_to_tensorboard(writer,
                               model,
                               optimizer,
                               rewards,
                               values,
                               total_loss,
                               policy_loss,
                               value_loss,
                               entropy_loss,
                               T.value,
                               model_critic=model_critic,
                               optimizer_critic=optimizer_critic)
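train is meant to be launched as one or more worker processes that share the global model, step counter, and running reward; a minimal launcher sketch under those assumptions, where parse_args, the no-argument ActorCriticNetwork constructor, and args.num_workers are placeholders rather than the project's actual entry point:

import numpy as np
import torch.multiprocessing as mp

if __name__ == "__main__":
    args = parse_args()                         # hypothetical CLI parsing
    global_model = ActorCriticNetwork()         # assumed shared-network constructor
    global_model.share_memory()                 # expose weights to all worker processes

    T = mp.Value('i', 0)                        # global step counter
    global_reward = mp.Value('d', -np.inf)      # matches the -inf check inside train

    workers = [
        mp.Process(target=train, args=(args, worker_id, global_model, T, global_reward))
        for worker_id in range(args.num_workers)
    ]
    for w in workers:
        w.start()
    for w in workers:
        w.join()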