Example #1
    def __init__(self,
                 iterators,
                 optimizer,
                 converter=convert.concat_examples,
                 devices=None):
        if not MultiprocessParallelUpdater.available():
            raise Exception(
                'NCCL is not enabled. MultiprocessParallelUpdater '
                'requires NCCL.\n'
                'Please reinstall chainer after you install NCCL.\n'
                '(see https://github.com/chainer/chainer#installation).')

        assert len(iterators) == len(devices)
        for iterator in iterators[1:]:
            assert len(iterator.dataset) == len(iterators[0].dataset)

        # Correct optimizer parameters for new minibatch size
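        # Each device consumes a full minibatch from its own iterator and the
        # gradients are accumulated across devices, so the effective batch
        # size grows by a factor of len(devices); eps and lr are rescaled
        # heuristically to compensate.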
        optim = optimizer.__class__.__name__
        if optim in ('Adam', 'AdaGrad', 'RMSprop'):
            optimizer.eps *= len(devices)
            warnings.warn(
                'optimizer.eps is changed to {} '
                'by MultiprocessParallelUpdater for new batch size.'.format(
                    optimizer.eps))
        elif optim in ('VaswaniAdam',):
            # no rescaling for this optimizer
            pass
        elif optim in ('RMSpropGraves', 'AdaDelta'):
            optimizer.eps *= len(devices)**2  # not quite right for AdaDelta
            warnings.warn(
                'optimizer.eps is changed to {} '
                'by MultiprocessParallelUpdater for new batch size.'.format(
                    optimizer.eps))
        elif hasattr(optimizer, 'lr'):
            optimizer.lr /= len(devices)
            warnings.warn(
                'optimizer.lr is changed to {} '
                'by MultiprocessParallelUpdater for new batch size.'.format(
                    optimizer.lr))

        super(MultiprocessParallelUpdater,
              self).__init__(iterator=iterators[0],
                             optimizer=optimizer,
                             converter=converter)

        if isinstance(devices, dict):
            devices = devices.copy()  # avoid mutating the caller's dict
            main = devices.pop('main')
            devices = [main] + list(six.itervalues(devices))
        if devices is None or any(device is None for device in devices):
            raise ValueError('must specify GPU devices')

        self._master = optimizer.target
        self._devices = devices
        self._mpu_iterators = iterators
        self._initialized = False

        self._pipes = []
        self._workers = []
        self.comm = None
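
A minimal usage sketch of the constructor above, assuming two GPUs (ids 0 and 1), a toy MNIST classifier, and an NCCL-enabled CuPy install; none of these choices come from the example itself. Each device gets its own iterator over an equally sized random split, which is what the asserts on the dataset lengths require:

import chainer
import chainer.links as L
from chainer import optimizers, training
from chainer.datasets import get_mnist, split_dataset_n_random
from chainer.training.updaters import MultiprocessParallelUpdater

devices = [0, 1]                          # assumed GPU ids; the first is the main device
model = L.Classifier(L.Linear(784, 10))   # toy MNIST classifier, purely illustrative
optimizer = optimizers.MomentumSGD(lr=0.01)
optimizer.setup(model)

train, _ = get_mnist()
# one equally sized random split and one iterator per device
splits = split_dataset_n_random(train, len(devices))
iterators = [chainer.iterators.SerialIterator(split, 32) for split in splits]

updater = MultiprocessParallelUpdater(iterators, optimizer, devices=devices)
trainer = training.Trainer(updater, (10, 'epoch'), out='result')
trainer.run()
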
Example #2
        localization_net.disable_update()

    # if we are using more than one GPU, we need to evenly split the datasets
    if len(args.gpus) > 1:
        gpu_datasets = split_dataset_n_random(train_dataset, len(args.gpus))
        if len(gpu_datasets[0]) != len(gpu_datasets[-1]):
            adapted_second_split = split_dataset(gpu_datasets[-1], len(gpu_datasets[0]))[0]
            gpu_datasets[-1] = adapted_second_split
    else:
        gpu_datasets = [train_dataset]

    train_iterators = [chainer.iterators.MultiprocessIterator(dataset, args.batch_size) for dataset in gpu_datasets]
    validation_iterator = chainer.iterators.MultiprocessIterator(validation_dataset, args.batch_size, repeat=False)

    # use the MultiprocessParallelUpdater to harness the full power of data-parallel computation
    updater = MultiprocessParallelUpdater(train_iterators, optimizer, devices=args.gpus)

    log_dir = os.path.join(args.log_dir, "{}_{}".format(datetime.datetime.now().isoformat(), args.log_name))
    args.log_dir = log_dir

    # back up the current file
    os.makedirs(log_dir, exist_ok=True)
    shutil.copy(__file__, log_dir)

    # back up all necessary configuration params
    report = {
        'log_dir': log_dir,
        'image_size': image_size,
        'target_size': target_shape,
        'localization_net': [localization_net.__class__.__name__, get_definition_filename(localization_net)],
Example #3
        train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    else:
        train_iters = [
            MultithreadIterator(i, args.batchsize // len(devices),
                                n_threads=4)
            for i in split_dataset_n_random(train, len(devices))]

    test_iter = MultithreadIterator(test, args.batchsize, repeat=False,
                                    shuffle=False, n_threads=4)

    # Set up a trainer
    if len(args.gpus) < 2:
        updater = training.StandardUpdater(train_iter, optimizer,
                                           device=args.gpus[0])
    else:
        updater = MultiprocessParallelUpdater(train_iters, optimizer,
                                              devices=devices)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'),
                               out=output_dir)

    if args.cosine:
        trainer.extend(
            CosineAnnealing('lr', int(args.epoch),
                            len(train) / args.batchsize,
                            eta_min=args.eta_min,
                            init=args.lr))
    else:
        trainer.extend(
            extensions.ExponentialShift('lr', 0.1, init=args.lr),
            trigger=triggers.ManualScheduleTrigger(
                [int(args.epoch * 0.50), int(args.epoch * 0.75)], 'epoch'))
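
The snippet stops before the trainer is launched; a plausible continuation is sketched below. It assumes the optimizer's target is the model to evaluate and that the model reports 'main/loss'; only trainer.run() is strictly required, the extension choices are illustrative:

    # evaluate with the master copy of the model on the main GPU
    trainer.extend(extensions.Evaluator(test_iter, optimizer.target,
                                        device=args.gpus[0]))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss', 'elapsed_time']))
    trainer.run()
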
Example #4
        train_iterators = [
            MultiprocessIterator(dataset,
                                 args.batch_size,
                                 n_processes=args.num_processes)
            for dataset in gpu_datasets
        ]

        validation_iterator = MultiprocessIterator(
            validation_dataset,
            args.batch_size,
            n_processes=args.num_processes,
            repeat=False)

    updater = MultiprocessParallelUpdater(
        train_iterators,
        optimizer,
        devices=args.gpus,
        converter=get_concat_and_pad_examples(args.blank_label))
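    # spawn the worker processes and build the NCCL communicator up front;
    # otherwise this happens lazily on the first call to update()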
    updater.setup_workers()

    log_dir = os.path.join(
        args.log_dir, "{}_{}".format(datetime.datetime.now().isoformat(),
                                     args.log_name))
    args.log_dir = log_dir

    # back up the current file
    os.makedirs(log_dir, exist_ok=True)
    shutil.copy(__file__, log_dir)

    # log all necessary configuration params