def _dry_run(
    self,
    devices: List[torch.device],
    training_shape: Optional[Point],
    valid_shapes: Optional[List[Point]],
    shrinkage: Optional[Point],
    fut: Future,
) -> None:
    """Run (or skip) the dry run and deliver its outcome through ``fut``.

    Determines working devices, a training shape, valid shapes and the model
    shrinkage. The tuple ``(devices, training_shape, valid_shapes, shrinkage)``
    is set as the result of ``fut``; any failure is set as the future's
    exception instead of being raised.

    :param devices: devices to test; must be non-empty.
    :param training_shape: optional externally requested training shape; must
        agree with ``self.training_shape`` if both are set.
    :param valid_shapes: optional candidate shapes for validation/inference.
    :param shrinkage: optional expected shrinkage; must agree with
        ``self.shrinkage`` if both are set.
    :param fut: future that receives the result or exception.
    """
    # The dry run may be skipped entirely if the config explicitly provides
    # both the training shape and the shrinkage (and they are compatible).
    if self.config.get(DRY_RUN, {}).get(SKIP, False):
        training_shape_from_config = self.config[TRAINING].get(TRAINING_SHAPE, None)
        shrinkage_from_config = self.config.get(DRY_RUN, {}).get(SHRINKAGE, None)
        if training_shape_from_config is None:
            warnings.warn(f"Cannot skip dry run due to missing {TRAINING_SHAPE} in config:{TRAINING}")
        elif shrinkage_from_config is None:
            warnings.warn(f"Cannot skip dry run due to missing {SHRINKAGE} in config:{DRY_RUN}")
        elif len(training_shape_from_config) != len(shrinkage_from_config):
            warnings.warn(
                f"Cannot skip dry run due to incompatible config values:\n"
                f"\tconfig:{TRAINING}:{TRAINING_SHAPE}:{training_shape_from_config}\n"
                f"\tconfig:{DRY_RUN}:{SHRINKAGE}:{shrinkage_from_config}"
            )
        else:
            # Config values are plain sequences; wrap them as Points with
            # generated axis names d0, d1, ...
            training_shape_from_config = Point(**{f"d{i}": v for i, v in enumerate(training_shape_from_config)})
            shrinkage_from_config = Point(**{f"d{i}": v for i, v in enumerate(shrinkage_from_config)})
            self.logger.info("skip dry run (no_dry_run: true specified in config)")
            fut.set_result(
                (devices, training_shape_from_config, [training_shape_from_config], shrinkage_from_config)
            )
            return

    self.logger.info("Starting dry run for %s", devices)
    try:
        if not devices:
            # plain string (was an f-string without placeholders)
            raise ValueError("Dry run on empty device list")

        # Reconcile the given shrinkage with any previously determined one.
        if self.shrinkage is None:
            self.shrinkage = shrinkage
        elif shrinkage is not None and shrinkage != self.shrinkage:
            raise ValueError(f"given shrinkage {shrinkage} incompatible with self.shrinkage {self.shrinkage}")

        working_devices = self.minimal_device_test(devices)
        failed_devices = set(devices) - set(working_devices)
        if failed_devices:
            self.logger.error(f"Minimal device test failed for {failed_devices}")

        # Reconcile the given training shape with any previously determined one.
        if self.training_shape is None:
            self.training_shape = self._determine_training_shape(training_shape=training_shape, devices=devices)
        elif training_shape is not None and self.training_shape != training_shape:
            raise ValueError(
                f"given training_shape {training_shape} incompatible with self.training_shape {self.training_shape}"
            )

        self._determine_valid_shapes(devices=devices, valid_shapes=valid_shapes)
        fut.set_result((devices, self.training_shape, self.valid_shapes, self.shrinkage))
        self.logger.info("dry run done. shrinkage: %s", self.shrinkage)
    except Exception as e:
        # Log the traceback here; the caller only sees the exception via fut.
        self.logger.error(traceback.format_exc())
        fut.set_exception(e)
def validate_shape(self, devices: Sequence[torch.device], shape: Point, train_mode: bool) -> bool:
    """Check whether ``shape`` can be processed on every device in ``devices``.

    Each device is probed in its own subprocess via ``in_subproc``. In train
    mode the configured loss criterion is instantiated as well. As a side
    effect, the first successful validation determines and stores
    ``self.shrinkage`` (input shape minus output shape, without batch).

    :returns: True iff all devices processed the shape, agreed on the output
        shape, and the implied shrinkage matches ``self.shrinkage``.
    """
    assert devices
    if train_mode:
        crit_class = self.criterion_class
        # Criterion kwargs minus "method", which only selects the criterion class.
        criterion_kwargs = {
            key: value for key, value in self.config[TRAINING][LOSS_CRITERION_CONFIG].items() if key != "method"
        }
    else:
        crit_class = None
        criterion_kwargs = {}

    return_conns = [
        in_subproc(
            self._validate_shape,
            model=self.model,
            device=d,
            shape=shape,
            criterion_class=crit_class,
            criterion_kwargs=criterion_kwargs,
        )
        for d in devices
    ]
    output_shapes = [conn.recv() for conn in return_conns]

    # A string result is an error message from the subprocess.
    for e in output_shapes:
        if isinstance(e, str):
            self.logger.info("Shape %s invalid: %s", shape, e)
            return False

    out = output_shapes[0]
    # generator expression instead of building an intermediate list for any()
    if any(o != out for o in output_shapes[1:]):
        self.logger.warning("different devices returned different output shapes for same input shape!")
        return False

    output_shape = shape.__class__(**{a: s for a, s in zip(shape.order, out)}).drop_batch()
    shrinkage = shape.drop_batch() - output_shape
    if self.shrinkage is None:
        self.shrinkage = shrinkage
        self.logger.info("Determined shrinkage to be %s", shrinkage)
        return True
    else:
        return self.shrinkage == shrinkage
def _determine_valid_shapes(self, devices: Sequence[torch.device], valid_shapes: Sequence[Point]):
    """Filter ``valid_shapes`` down to those that validate on ``devices``.

    Stores the result in ``self.valid_shapes``. Without candidates, the
    training shape is used as the only valid shape.
    """
    # todo: find valid shapes
    if valid_shapes is None:
        self.valid_shapes = [self.training_shape]
        return

    confirmed = []
    for candidate in valid_shapes:
        # Rebuild the point on its own axis order and prepend a batch of 1.
        batched = Point(**{axis: candidate[axis] for axis in candidate.order}).add_batch(1)
        if self.validate_shape(devices=devices, shape=batched, train_mode=False):
            confirmed.append(candidate)
    self.valid_shapes = confirmed
def find_one_shape(
    self,
    lower_limit: Point,
    upper_limit: Point,
    devices: Sequence[torch.device],
    train_mode: bool = False,
    discard: float = 0,
) -> Optional[Point]:
    """Search for one processable shape between ``lower_limit`` and ``upper_limit``.

    Starting at the upper limit, the per-axis difference to the lower limit is
    repeatedly shrunk (largest differences first, each shrink removing a
    ``discard`` fraction plus at least one unit) until a shape validates on all
    ``devices`` or no axis difference remains.

    :param discard: fraction of the remaining axis difference discarded per
        failed attempt; must be in [0, 1).
    :returns: the first valid shape found, or None if none validates.
    """
    assert lower_limit.order == upper_limit.order
    lower = numpy.array(lower_limit)
    upper = numpy.array(upper_limit)
    diff = upper - lower
    assert all(diff >= 0), f"negative diff: {diff} = upper({upper}) - lower({lower}) "
    assert 0 <= discard < 1

    def update_nonzero(diff):
        # Axes that can still be shrunk: their indices, current differences,
        # and how many such axes remain.
        nonzero_index = diff.nonzero()[0]
        nonzero = diff[nonzero_index]
        ndiff = len(nonzero)
        return nonzero_index, nonzero, ndiff

    nonzero_index, nonzero, ndiff = update_nonzero(diff)
    ncomb = numpy.prod(nonzero)
    if ncomb > 10000:
        self.logger.warning("Possibly testing too many combinations!!!")

    while ndiff:
        # Try shrinking the axes with the largest remaining difference first.
        search_order = numpy.argsort(nonzero)[::-1]
        for diff_i in search_order:
            shape = Point(**dict(zip(lower_limit.order, lower + diff)))
            if self.validate_shape(devices=devices, shape=shape, train_mode=train_mode):
                return shape

            # Shape failed: cut this axis' difference by the discard fraction
            # (and at least by one unit). NOTE(review): nonzero/nonzero_index
            # are deliberately only refreshed after the full pass below.
            reduced = int((1.0 - discard) * nonzero[diff_i] - 1)
            diff[nonzero_index[diff_i]] = reduced

        nonzero_index, nonzero, ndiff = update_nonzero(diff)

    return None
def _determine_training_shape(self, devices: Sequence[torch.device], training_shape: Optional[Point] = None):
    """Determine and validate a training shape (without batch dimension).

    If the config specifies a training shape, it is checked against the given
    ``training_shape`` and against the configured lower/upper bounds, then
    validated on ``devices`` in train mode. Otherwise a shape is searched
    between the configured bounds via ``find_one_shape``.

    :raises ValueError: on incompatible/missing config values or if no shape
        can be processed on ``devices``.
    :returns: the validated training shape with the batch dimension removed.
    """
    self.logger.debug("Determine training shape on %s (previous training shape: %s)", devices, training_shape)
    batch_size = self.config[TRAINING][BATCH_SIZE]

    if TRAINING_SHAPE in self.config[TRAINING]:
        # Config shapes are plain sequences; wrap as Point with axes d0, d1, ...
        config_training_shape = Point(**{f"d{i}": v for i, v in enumerate(self.config[TRAINING][TRAINING_SHAPE])})
        if training_shape is None:
            training_shape = config_training_shape
        else:
            assert training_shape == config_training_shape, "training shape unequal to config training shape"

        self.logger.debug("Validate given training shape: %s", training_shape)
        training_shape = training_shape.add_batch(batch_size)

        if TRAINING_SHAPE_UPPER_BOUND in self.config[TRAINING]:
            training_shape_upper_bound = Point(
                b=batch_size,
                **{f"d{i}": v for i, v in enumerate(self.config[TRAINING][TRAINING_SHAPE_UPPER_BOUND])},
            )
            if not (training_shape <= training_shape_upper_bound):
                raise ValueError(
                    f"{TRAINING_SHAPE}: {training_shape} incompatible with {TRAINING_SHAPE_UPPER_BOUND}: "
                    f"{training_shape_upper_bound}"
                )

        if TRAINING_SHAPE_LOWER_BOUND in self.config[TRAINING]:
            training_shape_lower_bound = Point(
                b=batch_size,
                **{f"d{i}": v for i, v in enumerate(self.config[TRAINING][TRAINING_SHAPE_LOWER_BOUND])},
            )
        else:
            # Default lower bound: 1 along every axis (including batch).
            training_shape_lower_bound = Point(**{a: 1 for a in training_shape.order})

        if not (training_shape_lower_bound <= training_shape):
            raise ValueError(
                f"{TRAINING_SHAPE_LOWER_BOUND}{training_shape_lower_bound} incompatible with {TRAINING_SHAPE}"
                f"{training_shape}"
            )

        if not self.validate_shape(devices=devices, shape=training_shape, train_mode=True):
            raise ValueError(f"{TRAINING_SHAPE}: {training_shape} could not be processed on devices: {devices}")
    else:
        self.logger.debug("Determine training shape from lower and upper bound...")
        if TRAINING_SHAPE_UPPER_BOUND not in self.config[TRAINING]:
            raise ValueError(f"config is missing {TRAINING_SHAPE} and/or {TRAINING_SHAPE_UPPER_BOUND}.")

        training_shape_upper_bound = Point(
            b=batch_size, **{f"d{i}": v for i, v in enumerate(self.config[TRAINING][TRAINING_SHAPE_UPPER_BOUND])}
        )
        if TRAINING_SHAPE_LOWER_BOUND in self.config[TRAINING]:
            training_shape_lower_bound = Point(
                b=batch_size,
                **{f"d{i}": v for i, v in enumerate(self.config[TRAINING][TRAINING_SHAPE_LOWER_BOUND])},
            )
        else:
            training_shape_lower_bound = Point(**{a: 1 for a in training_shape_upper_bound.order})

        if not (training_shape_lower_bound <= training_shape_upper_bound):
            raise ValueError(
                f"{TRAINING_SHAPE_LOWER_BOUND}: {training_shape_lower_bound} incompatible with "
                f"{TRAINING_SHAPE_UPPER_BOUND}: {training_shape_upper_bound}"
            )

        self.logger.debug(
            "Determine training shape from lower and upper bound (%s, %s)",
            training_shape_lower_bound,
            training_shape_upper_bound,
        )
        training_shape = self.find_one_shape(
            training_shape_lower_bound, training_shape_upper_bound, devices=devices
        )
        if training_shape is None:
            raise ValueError(
                f"No valid training shape found between lower bound {training_shape_lower_bound} and upper bound "
                f"{training_shape_upper_bound}"
            )

    # BUG FIX: the original called training_shape.drop_batch() and discarded
    # the result (drop_batch returns a value — see its expression use in
    # validate_shape), so the batch axis was never removed. Assigning is safe
    # either way: if drop_batch mutated in place and returned self, this is a
    # no-op; if it returns a new Point, this applies the intended removal.
    training_shape = training_shape.drop_batch()
    return training_shape