Example #1
    def __init__(self, *args, preprocessing_transform=None, preprocessing_params: dict = None,
                 transform=None, no_progress: bool = False, **kwargs):
        """
        Args:
            *args, **kwargs: positional and keyword arguments passed to
                initialize(*args, **kwargs) (see the class and initialize documentation).
            preprocessing_params (dict): parameters for the preprocessing:
                - 'cache_dir': path to the cached preprocessed data.
                - 'num_workers': number of processes used in parallel for preprocessing (default: number of cores).
                - 'use_cuda' (bool): whether preprocessing runs on CUDA (default: False).
            preprocessing_transform (Callable): Called on the outputs of _get_data over the indices
                                                from 0 to len(self) during the construction of the
                                                dataset; the preprocessed outputs are then cached
                                                to 'cache_dir'.
            transform (Callable): Called on the preprocessed data at __getitem__.
            no_progress (bool): disable tqdm progress bar for preprocessing.
        """
        self.initialize(*args, **kwargs)
        if preprocessing_transform is not None:
            desc = 'Applying preprocessing'
            if preprocessing_params is None:
                preprocessing_params = {}

            cache_dir = preprocessing_params.get('cache_dir')
            assert cache_dir is not None, 'Cache directory is not given'

            self.cache_convert = helpers.Cache(
                preprocessing_transform,
                cache_dir=cache_dir,
                cache_key=helpers._get_hash(repr(preprocessing_transform))
            )

            # Whether preprocessing runs on CUDA; if so, num_workers must be 0
            # so that the single-process branch below is taken.
            use_cuda = preprocessing_params.get('use_cuda', False)

            num_workers = preprocessing_params.get('num_workers')
            uncached = [idx for idx in range(len(self))
                        if self._get_attributes(idx)['name'] not in self.cache_convert.cached_ids]
            if len(uncached) > 0:
                if num_workers == 0:
                    with torch.no_grad():
                        # Only preprocess samples that are not cached yet,
                        # matching the behavior of the parallel branch below.
                        for idx in tqdm(uncached, desc=desc, disable=no_progress):
                            name = self._get_attributes(idx)['name']
                            data = self._get_data(idx)
                            self.cache_convert(name, *data)
                else:
                    # Preprocess the uncached samples in parallel; the context
                    # manager tears the pool down once all results are consumed.
                    with Pool(num_workers) as p:
                        iterator = p.imap_unordered(
                            _preprocess_task,
                            [(idx, self._get_data, self._get_attributes, self.cache_convert)
                             for idx in uncached])
                        for _ in tqdm(range(len(uncached)), desc=desc, disable=no_progress):
                            next(iterator)
        else:
            self.cache_convert = None

        self.transform = transform
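
The snippet hands a module-level _preprocess_task helper to Pool.imap_unordered without defining it. Below is a minimal sketch of what such a worker could look like, inferred from the tuple (idx, self._get_data, self._get_attributes, self.cache_convert) packed above; the signature and body are assumptions, not the actual implementation:

    import torch

    def _preprocess_task(args):
        # Unpack the tuple assembled by the parent process:
        # (index, data getter, attribute getter, caching callable)
        idx, get_data, get_attributes, cache_convert = args
        with torch.no_grad():
            name = get_attributes(idx)['name']  # per-sample cache key
            data = get_data(idx)                # expensive load / computation
            cache_convert(name, *data)          # preprocess once, write to cache_dir
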
Example #2
    def __init__(self,
                 dataset,
                 preprocessing_transform=None,
                 cache_dir=None,
                 num_workers=None,
                 transform=None,
                 no_progress: bool = False):
        """
        A wrapper dataset that applies a preprocessing transform to a given base
        dataset. The result of the preprocessing transform will be cached to
        disk.

        The base dataset should support a `get_data(idx)` method that returns
        the data to be preprocessed. If no such method is found, its
        `__getitem__(idx)` method will be used instead.

        The base dataset can optionally support a `get_attributes(idx)` method
        that returns the data that should not be preprocessed. If no such
        method is found, an empty dict will be used as attributes. The
        `__getitem__` of `ProcessedDataset` will return both the data and the
        attributes.

        The base dataset can optionally support a `get_cache_key(idx)` method
        that returns the string key to use for caching. If no such method is
        found, the index (as a string) will be used as the cache key.

        Note: if CUDA is used in preprocessing, `num_workers` must be set to 0.

        Args:
            dataset (torch.utils.data.Dataset):
                The base dataset to preprocess.
            cache_dir (str):
                Path to the cached preprocessed data. Must be given if
                `preprocessing_transform` is not None.
            num_workers (int):
                Number of processes used in parallel for preprocessing
                (default: number of cores).
            preprocessing_transform (Callable):
                Called on the outputs of get_data over the indices from 0 to
                `len(self)` during the construction of the dataset; the
                preprocessed outputs are then cached to `cache_dir`.
            transform (Callable):
                Called on the preprocessed data at `__getitem__`.
                The result of this function is not cached, unlike
                `preprocessing_transform`.
            no_progress (bool): Disable tqdm progress bar for preprocessing.
        """
        # TODO: Consider integrating combination into `ProcessedDataset`.

        self.dataset = dataset
        self.transform = transform

        if preprocessing_transform is not None:
            desc = 'Applying preprocessing'

            assert cache_dir is not None, 'Cache directory is not given'

            self.cache_convert = Cache(preprocessing_transform,
                                       cache_dir=cache_dir,
                                       cache_key=_get_hash(
                                           repr(preprocessing_transform)))

            uncached = [
                idx for idx in range(len(self))
                if self.get_cache_key(idx) not in self.cache_convert.cached_ids
            ]
            if len(uncached) > 0:
                if num_workers == 0:
                    with torch.no_grad():
                        # Only preprocess samples that are not cached yet,
                        # matching the behavior of the parallel branch below.
                        for idx in tqdm(uncached,
                                        desc=desc,
                                        disable=no_progress):
                            key = self.get_cache_key(idx)
                            data = self._get_base_data(idx)
                            self.cache_convert(key, data)
                else:
                    # Preprocess the uncached samples in parallel; the context
                    # manager tears the pool down once all results are consumed.
                    with Pool(num_workers) as p:
                        iterator = p.imap_unordered(
                            _preprocess_task,
                            [(idx, self._get_base_data, self.get_cache_key,
                              self.cache_convert) for idx in uncached])
                        for _ in tqdm(range(len(uncached)),
                                      desc=desc,
                                      disable=no_progress):
                            next(iterator)
        else:
            self.cache_convert = None
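
For context, here is a minimal usage sketch of this wrapper (the docstring refers to it as ProcessedDataset). The base dataset, file paths, and transforms below are hypothetical, chosen only to illustrate the optional protocol described above:

    import torch
    from torch.utils.data import Dataset

    class PointCloudFiles(Dataset):
        # Hypothetical base dataset implementing the optional protocol.
        def __init__(self, paths):
            self.paths = paths

        def __len__(self):
            return len(self.paths)

        def get_data(self, idx):
            # The expensive part; its preprocessed result gets cached.
            return torch.load(self.paths[idx])

        def get_attributes(self, idx):
            # Passed through uncached, alongside the preprocessed data.
            return {'path': self.paths[idx]}

        def get_cache_key(self, idx):
            # A stable key, so the cache survives reordering of `paths`.
            return str(self.paths[idx])

    dataset = ProcessedDataset(
        PointCloudFiles(['a.pt', 'b.pt']),
        preprocessing_transform=lambda pts: pts - pts.mean(dim=0),  # cached once
        cache_dir='/tmp/preprocessed',
        num_workers=0,  # must be 0 if the transform uses CUDA
        transform=lambda sample: sample,  # applied at every __getitem__, never cached
    )
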
Example #3
    def learn(self):
        """
        Performs numIters iterations with numEps episodes of self-play in each
        iteration. After every iteration, it retrains the neural network with
        the examples in trainExamplesHistory (each entry holds at most
        maxlenOfQueue examples) and saves a new checkpoint.
        """
        try:
            mp.set_start_method('spawn')
        except RuntimeError:
            pass

        manager = mp.Manager()
        sharedQ = manager.Queue()

        # Create the server-communicating process
        remoteDataQ = manager.Queue()
        remoteSDQ = manager.Queue()
        rrProc = mp.Process(
            target=Coach.remoteSendProcess if self.args.remote_send else Coach.remoteRecvProcess,
            args=((remoteDataQ, remoteSDQ),))
        rrProc.daemon = True
        rrProc.start()

        # Generate self-plays and train
        for i in range(1, self.args.numIters + 1):
            # If remote_send (i.e. Haedong server), update state_dict
            if self.args.remote_send:
                log.info("Checking for state_dict update")
                if not remoteSDQ.empty():
                    sd = remoteSDQ.get()
                    while not remoteSDQ.empty():
                        sd = remoteSDQ.get()
                    self.nnet.nnet.load_state_dict(sd)
                    log.info("Updated state_dict")
                else:
                    log.info("No new state_dict available")

            # Create num_gpu_procs nnProcess
            nnProcs = []
            for j in range(self.args.num_gpu_procs):
                # Run nnProc
                # Copy the weights to CPU so they can be sent to the worker process
                state_dict = {k: v.cpu() for k, v in self.nnet.nnet.state_dict().items()}
                nnProc = mp.Process(
                    target=Coach.nnProcess,
                    args=[(self.game, state_dict, sharedQ, j % torch.cuda.device_count())])
                nnProc.daemon = True
                nnProc.start()
                nnProcs.append(nnProc)

            # Create the self-play process pool (None => one worker per CPU core)
            selfplayPool = Pool(None)

            # Create pool args
            pArgs = []
            for j in range(self.args.numEps):
                # Each episode gets a fresh Game(6) instead of sharing self.game:
                # pArgs.append((self.game, self.args, sharedQ))
                pArgs.append((Game(6), self.args, sharedQ))

            # bookkeeping
            log.info(f'Starting Iter #{i} ... ({self.selfPlaysPlayed} games played)')
            # examples of the iteration
            if not self.skipFirstSelfPlay or i > 1:
                iterationTrainExamples = deque([], maxlen=self.args.maxlenOfQueue)

                log.info('Start generating self-plays')

                with tqdm(total=self.args.numEps) as pbar:
                    # imap_unordered yields episode results as each worker finishes;
                    # the manual pbar replaces the redundant inner tqdm wrapper that
                    # previously drew a second progress bar.
                    for d in selfplayPool.imap_unordered(Coach.executeEpisode, pArgs):
                        if self.args.remote_send:
                            remoteDataQ.put(d)
                            # remoteconn.send(d)
                        else:
                            iterationTrainExamples += d
                        pbar.update()
                
                self.selfPlaysPlayed += self.args.numEps

                # save the iteration examples to the history 
                self.trainExamplesHistory.append(iterationTrainExamples)

            # Close the process pool and wait for its workers to exit
            selfplayPool.close()
            selfplayPool.join()

            # Shut down the NN processes: each worker exits after receiving
            # a None sentinel from the shared queue
            for j in range(self.args.num_gpu_procs):
                sharedQ.put(None)

            for j in range(self.args.num_gpu_procs):
                nnProcs[j].join()

            # If the process is remote_send (i.e. the Haedong server), then skip the training part
            if self.args.remote_send:
                continue
            
            # Otherwise, add the server-generated examples to iterationTrainExamples.
            # The deque is already referenced by trainExamplesHistory, so these
            # additions are picked up by the training below.
            num_remote_selfplays = 0
            while not remoteDataQ.empty():
                d = remoteDataQ.get()
                iterationTrainExamples += d
                num_remote_selfplays += 1
            
            log.info(f'{num_remote_selfplays} self-play data loaded from remote server')
            self.selfPlaysPlayed += num_remote_selfplays

            # Update the trainExamplesHistory
            if len(self.trainExamplesHistory) > self.args.numItersForTrainExamplesHistory:
                log.warning(
                    f"Removing the oldest entry in trainExamples. len(trainExamplesHistory) = {len(self.trainExamplesHistory)}")
                self.trainExamplesHistory.pop(0)

            # backup history to a file
            # NB! the examples were collected using the model from the previous
            # iteration, i.e. before the retraining below
            self.saveTrainExamples(self.selfPlaysPlayed)

            # shuffle examples before training
            trainExamples = []
            for e in self.trainExamplesHistory:
                trainExamples.extend(e)
            shuffle(trainExamples)

            log.info('TRAINING AND SAVING NEW MODEL')
            self.nnet.train(trainExamples)
            self.nnet.save_checkpoint(folder=self.args.checkpoint, filename=self.getCheckpointFile(self.selfPlaysPlayed))

            # Send the new state_dict, moved to CPU so it can be sent safely
            # through the manager queue
            state_dict = {k: v.cpu() for k, v in self.nnet.nnet.state_dict().items()}
            remoteSDQ.put(state_dict)
            log.info('Sent the updated state_dict')
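
Coach.nnProcess is not shown in this excerpt. Below is a hypothetical sketch of the worker loop it implies: it receives one argument tuple, serves inference requests from sharedQ, and exits on the None sentinel that learn() enqueues once per worker. The request format and the NNet class are assumptions made only for illustration:

    import torch

    def nnProcess(args):
        # Tuple layout matches what learn() packs above; everything past the
        # unpacking (request format, NNet class) is assumed for illustration.
        game, state_dict, sharedQ, gpu_id = args
        device = torch.device(f'cuda:{gpu_id}')
        net = NNet(game)                  # hypothetical network wrapper (nn.Module)
        net.load_state_dict(state_dict)   # weights arrive on CPU for safe pickling
        net.to(device).eval()
        while True:
            req = sharedQ.get()
            if req is None:               # sentinel from learn(): shut down cleanly
                break
            board, replyQ = req           # assumed request format
            with torch.no_grad():
                pi, v = net(board.to(device))
            replyQ.put((pi.cpu(), v.cpu()))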