Example #1
0
    def _dry_run(
        self,
        devices: List[torch.device],
        training_shape: Optional[Point],
        valid_shapes: Optional[List[Point]],
        shrinkage: Optional[Point],
        fut: Future,
    ) -> None:
        """Resolve training shape, valid shapes and shrinkage and publish them through ``fut``.

        If the config requests skipping the dry run and provides a training shape
        and a shrinkage of equal length, those config values are converted to
        ``Point`` objects and set as the result of ``fut`` right away; no instance
        attribute is assigned on that path. If skipping was requested but the
        config values are missing or incompatible, a warning is issued and the
        regular dry run is performed instead.

        The regular dry run sets ``fut``'s result to
        ``(devices, self.training_shape, self.valid_shapes, self.shrinkage)``.
        Any exception raised during the regular dry run is logged with its
        traceback and stored on ``fut`` instead of propagating.

        Args:
            devices: devices to run on; an empty list is reported as a ``ValueError`` on ``fut``.
            training_shape: optional training shape; must equal ``self.training_shape`` if that is already set.
            valid_shapes: optional candidate shapes forwarded to ``_determine_valid_shapes``.
            shrinkage: optional shrinkage; must equal ``self.shrinkage`` if that is already set.
            fut: future that receives the result tuple or the exception.
        """
        if self.config.get(DRY_RUN, {}).get(SKIP, False):
            cfg_shape = self.config[TRAINING].get(TRAINING_SHAPE, None)
            cfg_shrinkage = self.config.get(DRY_RUN, {}).get(SHRINKAGE, None)

            # Work out what (if anything) prevents skipping; None means "nothing does".
            if cfg_shape is None:
                skip_blocker = f"Cannot skip dry run due to missing {TRAINING_SHAPE} in config:{TRAINING}"
            elif cfg_shrinkage is None:
                skip_blocker = f"Cannot skip dry run due to missing {SHRINKAGE} in config:{DRY_RUN}"
            elif len(cfg_shape) != len(cfg_shrinkage):
                skip_blocker = (
                    f"Cannot skip dry run due to incompatible config values:\n"
                    f"\tconfig:{TRAINING}:{TRAINING_SHAPE}:{cfg_shape}\n"
                    f"\tconfig:{DRY_RUN}:{SHRINKAGE}:{cfg_shrinkage}"
                )
            else:
                skip_blocker = None

            if skip_blocker is None:
                shape_point = Point(**{f"d{i}": v for i, v in enumerate(cfg_shape)})
                shrinkage_point = Point(**{f"d{i}": v for i, v in enumerate(cfg_shrinkage)})
                self.logger.info("skip dry run (no_dry_run: true specified in config)")
                # The configured training shape doubles as the only valid shape.
                fut.set_result((devices, shape_point, [shape_point], shrinkage_point))
                return

            # Skipping is not possible; fall through to the regular dry run.
            warnings.warn(skip_blocker)

        self.logger.info("Starting dry run for %s", devices)
        try:
            if not devices:
                raise ValueError("Dry run on empty device list")

            # Adopt the given shrinkage, or insist that it matches the one already known.
            if self.shrinkage is None:
                self.shrinkage = shrinkage
            elif shrinkage is not None and shrinkage != self.shrinkage:
                raise ValueError(f"given shrinkage {shrinkage} incompatible with self.shrinkage {self.shrinkage}")

            # NOTE(review): failing devices are only logged; the steps below still
            # receive the full `devices` list - confirm this is intended.
            working_devices = self.minimal_device_test(devices)
            failed_devices = set(devices).difference(working_devices)
            if failed_devices:
                self.logger.error(f"Minimal device test failed for {failed_devices}")

            # Same adopt-or-verify rule for the training shape.
            if self.training_shape is None:
                self.training_shape = self._determine_training_shape(devices=devices, training_shape=training_shape)
            elif training_shape is not None and self.training_shape != training_shape:
                raise ValueError(
                    f"given training_shape {training_shape} incompatible with self.training_shape {self.training_shape}"
                )

            self._determine_valid_shapes(devices=devices, valid_shapes=valid_shapes)

            outcome = (devices, self.training_shape, self.valid_shapes, self.shrinkage)
            fut.set_result(outcome)
            self.logger.info("dry run done. shrinkage: %s", self.shrinkage)
        except Exception as e:
            self.logger.error(traceback.format_exc())
            fut.set_exception(e)
Example #2
0
    def validate_shape(self, devices: Sequence[torch.device], shape: Point,
                       train_mode: bool) -> bool:
        """Check on every given device whether ``shape`` can be processed.

        ``self._validate_shape`` is dispatched through ``in_subproc`` once per
        device; each returned connection is then read with ``recv()``. A string
        result is treated as an error description and makes the shape invalid,
        as does any disagreement between the per-device results.

        From the agreed result the shrinkage (input shape minus output shape,
        both without batch) is computed. If ``self.shrinkage`` is still unset it
        is initialised with that value and the shape is accepted; otherwise the
        result of comparing the stored shrinkage with the computed one is
        returned.

        Args:
            devices: non-empty sequence of devices to test on.
            shape: input shape to validate.
            train_mode: if true, the criterion class and the loss criterion
                config entries (without the "method" key) are passed along.
        """
        assert devices

        # Without train mode no criterion is handed to the per-device check.
        crit_class = None
        criterion_kwargs = {}
        if train_mode:
            crit_class = self.criterion_class
            loss_config = self.config[TRAINING][LOSS_CRITERION_CONFIG]
            criterion_kwargs = {name: setting for name, setting in loss_config.items() if name != "method"}

        # Dispatch to all devices before reading any result back.
        pending = []
        for device in devices:
            pending.append(
                in_subproc(
                    self._validate_shape,
                    model=self.model,
                    device=device,
                    shape=shape,
                    criterion_class=crit_class,
                    criterion_kwargs=criterion_kwargs,
                )
            )
        results = [conn.recv() for conn in pending]

        # The first string result (if any) is the reported reason for rejection.
        failure = next((res for res in results if isinstance(res, str)), None)
        if failure is not None:
            self.logger.info("Shape %s invalid: %s", shape, failure)
            return False

        reference = results[0]
        mismatches = [other != reference for other in results[1:]]
        if any(mismatches):
            self.logger.warning(
                "different devices returned different output shapes for same input shape!"
            )
            return False

        # Map the raw output sizes onto the axes of `shape`, then compare without batch.
        produced = shape.__class__(**dict(zip(shape.order, reference))).drop_batch()
        observed_shrinkage = shape.drop_batch() - produced

        if self.shrinkage is not None:
            return self.shrinkage == observed_shrinkage

        self.shrinkage = observed_shrinkage
        self.logger.info("Determined shrinkage to be %s", observed_shrinkage)
        return True
Example #3
0
 def _determine_valid_shapes(self, devices: Sequence[torch.device], valid_shapes: Sequence[Point]):
     """Populate ``self.valid_shapes``.

     With ``valid_shapes`` being ``None`` the training shape becomes the only
     valid shape. Otherwise every given shape is checked with
     ``validate_shape`` (``train_mode=False``) and only accepted shapes are kept,
     in their original order.
     """
     # todo: find valid shapes
     if valid_shapes is None:
         self.valid_shapes = [self.training_shape]
         return

     accepted = []
     for candidate in valid_shapes:
         # The candidate is rebuilt axis by axis into a new Point before add_batch(1) is applied.
         batched = Point(**{axis: candidate[axis] for axis in candidate.order}).add_batch(1)
         if self.validate_shape(devices=devices, shape=batched, train_mode=False):
             accepted.append(candidate)

     # Assign only once all candidates have been checked.
     self.valid_shapes = accepted
Example #4
0
    def find_one_shape(
        self,
        lower_limit: Point,
        upper_limit: Point,
        devices: Sequence[torch.device],
        train_mode: bool = False,
        discard: float = 0,
    ) -> Optional[Point]:
        """Search for a shape between ``lower_limit`` and ``upper_limit`` accepted by ``validate_shape``.

        The search starts at ``upper_limit`` and, after every rejected candidate,
        reduces the remaining slack (``candidate - lower_limit``) of one axis.
        Axes are visited in order of decreasing slack. Once no slack is left, the
        lower limit itself is validated as the final candidate.

        Args:
            lower_limit: smallest shape to consider (inclusive).
            upper_limit: largest shape to consider (inclusive); must use the same
                axis order as ``lower_limit`` and be >= ``lower_limit`` on every axis.
            devices: devices forwarded to ``validate_shape``.
            train_mode: forwarded to ``validate_shape``.
            discard: fraction in ``[0, 1)`` of an axis' slack that is dropped, in
                addition to one step, after each rejected candidate.

        Returns:
            The first accepted shape, or ``None`` if no tested candidate is accepted.

        Raises:
            AssertionError: if the axis orders differ, ``upper_limit`` is smaller
                than ``lower_limit`` on some axis, or ``discard`` is out of range.
        """
        assert lower_limit.order == upper_limit.order
        lower = numpy.array(lower_limit)
        upper = numpy.array(upper_limit)
        diff = upper - lower
        assert all(
            diff >= 0
        ), f"negative diff: {diff} = upper({upper}) - lower({lower}) "
        assert 0 <= discard < 1

        def update_nonzero(diff):
            """Return indices, values and count of the axes that still have slack."""
            nonzero_index = diff.nonzero()[0]
            nonzero = diff[nonzero_index]
            ndiff = len(nonzero)
            return nonzero_index, nonzero, ndiff

        def current_candidate():
            """Build the shape for the current slack (``lower + diff``)."""
            return Point(**dict(zip(lower_limit.order, lower + diff)))

        nonzero_index, nonzero, ndiff = update_nonzero(diff)

        ncomb = numpy.prod(nonzero)
        if ncomb > 10000:
            self.logger.warning("Possibly testing too many combinations!!!")

        while ndiff:
            # Visit the axes with the largest slack first.
            search_order = numpy.argsort(nonzero)[::-1]
            for diff_i in search_order:
                shape = current_candidate()
                if self.validate_shape(devices=devices,
                                       shape=shape,
                                       train_mode=train_mode):
                    return shape

                # `nonzero` holds the slack values from the start of this sweep;
                # int() truncates towards zero, so the new slack never drops below 0.
                reduced = int((1.0 - discard) * nonzero[diff_i] - 1)
                diff[nonzero_index[diff_i]] = reduced

            nonzero_index, nonzero, ndiff = update_nonzero(diff)

        # No slack is left, so the remaining candidate equals the lower limit. It
        # has not been validated yet: the loop above exits (or is never entered,
        # when lower_limit == upper_limit) before testing it.
        shape = current_candidate()
        if self.validate_shape(devices=devices,
                               shape=shape,
                               train_mode=train_mode):
            return shape

        return None
Example #5
0
    def _determine_training_shape(self, devices: Sequence[torch.device], training_shape: Optional[Point] = None):
        """Determine the training shape from the config and validate it on ``devices``.

        Two modes, selected by whether the training config contains a training shape:

        * Training shape configured: it is used (and must equal ``training_shape``
          if one is given), gets the configured batch size added, is checked
          against the optional upper/lower bounds and validated with
          ``validate_shape(train_mode=True)``.
        * Training shape not configured: an upper bound must be configured; a
          shape between lower bound (all ones if not configured) and upper bound
          is searched with ``find_one_shape``.

        Args:
            devices: devices used for validation.
            training_shape: optional previously known training shape (without batch).

        Returns:
            The result of ``drop_batch()`` on the determined training shape.

        Raises:
            ValueError: if shape and bounds are incompatible, the configured shape
                cannot be processed, the upper bound is missing while no training
                shape is configured, or no valid shape is found.
            AssertionError: if ``training_shape`` differs from the configured one.
        """
        self.logger.debug("Determine training shape on %s (previous training shape: %s)", devices, training_shape)
        training_config = self.config[TRAINING]
        batch_size = training_config[BATCH_SIZE]

        def bound_from_config(key):
            """Build a Point with the configured batch size from the values stored under ``key``."""
            return Point(b=batch_size, **{f"d{i}": v for i, v in enumerate(training_config[key])})

        if TRAINING_SHAPE in training_config:
            config_training_shape = Point(**{f"d{i}": v for i, v in enumerate(training_config[TRAINING_SHAPE])})
            if training_shape is None:
                training_shape = config_training_shape
            else:
                assert training_shape == config_training_shape, "training shape unequal to config training shape"

            self.logger.debug("Validate given training shape: %s", training_shape)
            training_shape = training_shape.add_batch(batch_size)

            if TRAINING_SHAPE_UPPER_BOUND in training_config:
                training_shape_upper_bound = bound_from_config(TRAINING_SHAPE_UPPER_BOUND)
                if not (training_shape <= training_shape_upper_bound):
                    raise ValueError(
                        f"{TRAINING_SHAPE}: {training_shape} incompatible with {TRAINING_SHAPE_UPPER_BOUND}: "
                        f"{training_shape_upper_bound}"
                    )

            if TRAINING_SHAPE_LOWER_BOUND in training_config:
                training_shape_lower_bound = bound_from_config(TRAINING_SHAPE_LOWER_BOUND)
            else:
                # No lower bound configured: fall back to 1 along every axis of the training shape.
                training_shape_lower_bound = Point(**{a: 1 for a in training_shape.order})

            if not (training_shape_lower_bound <= training_shape):
                raise ValueError(
                    f"{TRAINING_SHAPE_LOWER_BOUND}{training_shape_lower_bound} incompatible with {TRAINING_SHAPE}"
                    f"{training_shape}"
                )

            if not self.validate_shape(devices=devices, shape=training_shape, train_mode=True):
                raise ValueError(f"{TRAINING_SHAPE}: {training_shape} could not be processed on devices: {devices}")
        else:
            self.logger.debug("Determine training shape from lower and upper bound...")
            if TRAINING_SHAPE_UPPER_BOUND not in training_config:
                raise ValueError(f"config is missing {TRAINING_SHAPE} and/or {TRAINING_SHAPE_UPPER_BOUND}.")

            training_shape_upper_bound = bound_from_config(TRAINING_SHAPE_UPPER_BOUND)

            if TRAINING_SHAPE_LOWER_BOUND in training_config:
                training_shape_lower_bound = bound_from_config(TRAINING_SHAPE_LOWER_BOUND)
            else:
                # No lower bound configured: fall back to 1 along every axis of the upper bound.
                training_shape_lower_bound = Point(**{a: 1 for a in training_shape_upper_bound.order})

            if not (training_shape_lower_bound <= training_shape_upper_bound):
                raise ValueError(
                    f"{TRAINING_SHAPE_LOWER_BOUND}: {training_shape_lower_bound} incompatible with "
                    f"{TRAINING_SHAPE_UPPER_BOUND}: {training_shape_upper_bound}"
                )

            self.logger.debug(
                "Determine training shape from lower and upper bound (%s, %s)",
                training_shape_lower_bound,
                training_shape_upper_bound,
            )
            # NOTE(review): the search runs with find_one_shape's default train_mode=False,
            # while the configured-shape branch validates with train_mode=True - confirm
            # this asymmetry is intended.
            training_shape = self.find_one_shape(
                training_shape_lower_bound, training_shape_upper_bound, devices=devices
            )
            if training_shape is None:
                raise ValueError(
                    f"No valid training shape found between lower bound {training_shape_lower_bound} and upper bound "
                    f"{training_shape_upper_bound}"
                )

        # Return the value produced by drop_batch() instead of discarding it: this
        # is correct whether drop_batch() returns a new point or modifies the point
        # in place and returns it.
        return training_shape.drop_batch()