def __init__(self,
             save_checkpoint_steps=1,
             save_checkpoint_seconds=0,
             keep_checkpoint_max=5,
             keep_checkpoint_per_n_minutes=0,
             integrated_save=True):
    """Validate and store the checkpoint saving and retention settings.

    Each of the four numeric settings is only validated (through
    `check_int_non_negative`) when it is truthy; falsy values such as
    None or 0 are stored unchanged.

    Args:
        save_checkpoint_steps: Step-based saving setting. A value greater
            than 0 takes precedence over `save_checkpoint_seconds`, which
            is then stored as None. Default: 1.
        save_checkpoint_seconds: Time-based saving setting, kept only when
            the step-based setting is not positive. Default: 0.
        keep_checkpoint_max: Count-based retention setting. A value
            greater than 0 takes precedence over
            `keep_checkpoint_per_n_minutes`, which is then stored as None.
            Default: 5.
        keep_checkpoint_per_n_minutes: Time-based retention setting. When
            both retention settings are unset, `keep_checkpoint_max`
            falls back to 1. Default: 0.
        integrated_save: Flag passed through `check_bool` and stored.
            Default: True.

    Raises:
        ValueError: If all four numeric settings are falsy (None or 0).
    """
    # At least one of the four numeric settings has to be provided.
    if not (save_checkpoint_steps or save_checkpoint_seconds
            or keep_checkpoint_max or keep_checkpoint_per_n_minutes):
        raise ValueError("The input_param can't be all None or 0")

    # Falsy values skip validation and are kept exactly as passed in.
    steps = (check_int_non_negative(save_checkpoint_steps)
             if save_checkpoint_steps else save_checkpoint_steps)
    seconds = (check_int_non_negative(save_checkpoint_seconds)
               if save_checkpoint_seconds else save_checkpoint_seconds)
    max_kept = (check_int_non_negative(keep_checkpoint_max)
                if keep_checkpoint_max else keep_checkpoint_max)
    per_n_minutes = (check_int_non_negative(keep_checkpoint_per_n_minutes)
                     if keep_checkpoint_per_n_minutes
                     else keep_checkpoint_per_n_minutes)

    # Saving policy: a positive step interval overrides the time interval.
    self._save_checkpoint_steps = steps
    self._save_checkpoint_seconds = seconds
    if self._save_checkpoint_steps and self._save_checkpoint_steps > 0:
        self._save_checkpoint_seconds = None

    # Retention policy: a positive maximum overrides the per-minute rule;
    # with neither rule set, keep a single checkpoint.
    self._keep_checkpoint_max = max_kept
    self._keep_checkpoint_per_n_minutes = per_n_minutes
    if self._keep_checkpoint_max and self._keep_checkpoint_max > 0:
        self._keep_checkpoint_per_n_minutes = None
    elif (not self._keep_checkpoint_per_n_minutes
          or self._keep_checkpoint_per_n_minutes == 0):
        self._keep_checkpoint_max = 1

    self._integrated_save = check_bool(integrated_save)
def __init__(self,
             save_checkpoint_steps=1,
             save_checkpoint_seconds=0,
             keep_checkpoint_max=5,
             keep_checkpoint_per_n_minutes=0,
             integrated_save=True,
             model_type="normal"):
    """Validate and store the checkpoint saving and retention settings.

    Each of the four numeric settings is only validated (through
    `check_int_non_negative`) when it is truthy; falsy values such as
    None or 0 are stored unchanged.

    Args:
        save_checkpoint_steps: Step-based saving setting. A value greater
            than 0 takes precedence over `save_checkpoint_seconds`, which
            is then stored as None. Default: 1.
        save_checkpoint_seconds: Time-based saving setting, kept only when
            the step-based setting is not positive. Default: 0.
        keep_checkpoint_max: Count-based retention setting. A value
            greater than 0 takes precedence over
            `keep_checkpoint_per_n_minutes`, which is then stored as None.
            Default: 5.
        keep_checkpoint_per_n_minutes: Time-based retention setting. When
            both retention settings are unset, `keep_checkpoint_max`
            falls back to 1. Default: 0.
        integrated_save: Flag passed through `check_bool` and stored.
            Default: True.
        model_type: When truthy, checked with `check_string` against
            "normal", "fusion" and "quant"; a falsy value is stored
            unchanged. Default: "normal".

    Raises:
        ValueError: If all four numeric settings are falsy (None or 0).
    """
    # At least one of the four numeric settings has to be provided.
    if not (save_checkpoint_steps or save_checkpoint_seconds
            or keep_checkpoint_max or keep_checkpoint_per_n_minutes):
        raise ValueError("The input_param can't be all None or 0")

    # Falsy values skip validation and are kept exactly as passed in.
    steps = (check_int_non_negative(save_checkpoint_steps)
             if save_checkpoint_steps else save_checkpoint_steps)
    seconds = (check_int_non_negative(save_checkpoint_seconds)
               if save_checkpoint_seconds else save_checkpoint_seconds)
    max_kept = (check_int_non_negative(keep_checkpoint_max)
                if keep_checkpoint_max else keep_checkpoint_max)
    per_n_minutes = (check_int_non_negative(keep_checkpoint_per_n_minutes)
                     if keep_checkpoint_per_n_minutes
                     else keep_checkpoint_per_n_minutes)
    kind = (check_string(model_type, ["normal", "fusion", "quant"])
            if model_type else model_type)

    # Saving policy: a positive step interval overrides the time interval.
    self._save_checkpoint_steps = steps
    self._save_checkpoint_seconds = seconds
    if self._save_checkpoint_steps and self._save_checkpoint_steps > 0:
        self._save_checkpoint_seconds = None

    # Retention policy: a positive maximum overrides the per-minute rule;
    # with neither rule set, keep a single checkpoint.
    self._keep_checkpoint_max = max_kept
    self._keep_checkpoint_per_n_minutes = per_n_minutes
    if self._keep_checkpoint_max and self._keep_checkpoint_max > 0:
        self._keep_checkpoint_per_n_minutes = None
    elif (not self._keep_checkpoint_per_n_minutes
          or self._keep_checkpoint_per_n_minutes == 0):
        self._keep_checkpoint_max = 1

    self._model_type = kind
    self._integrated_save = check_bool(integrated_save)
def __init__(self, in_channels, out_channels, kernel_size, stride, pad_mode,
             padding, dilation, group, has_bias, weight_init, bias_init):
    """Validate the convolution attributes and create weight and bias.

    Args:
        in_channels: Input channel count; passed through
            `check_int_positive` and must be divisible by `group`.
        out_channels: Output channel count; passed through
            `check_int_positive` and must be divisible by `group`.
        kernel_size: Indexable whose first two entries must be ints >= 1.
            It is unpacked into the trailing dimensions of the weight
            shape.
        stride: Indexable whose first two entries must be ints >= 1.
        pad_mode: Stored as given; not validated here.
        padding: Passed through `check_int_non_negative`.
        dilation: Indexable whose first two entries must be ints >= 1.
        group: Passed through `check_int_positive`.
        has_bias: Stored as given, then passed through `check_bool` to
            decide whether a bias parameter is created.
        weight_init: Forwarded to `initializer` for the weight.
        bias_init: Forwarded to `initializer` for the bias. When no bias
            is created and this is not 'zeros', a warning is logged.

    Raises:
        ValueError: If `kernel_size`, `stride` or `dilation` fail the
            int/>= 1 check, or if either channel count is not divisible
            by `group`.
    """
    super(_Conv, self).__init__()

    # Record the attributes, validating the scalar ones on the way in.
    self.in_channels = check_int_positive(in_channels)
    self.out_channels = check_int_positive(out_channels)
    self.kernel_size = kernel_size
    self.stride = stride
    self.pad_mode = pad_mode
    self.padding = check_int_non_negative(padding)
    self.dilation = dilation
    self.group = check_int_positive(group)
    self.has_bias = has_bias

    # The type test runs first so the comparisons only ever see ints.
    if not (isinstance(kernel_size[0], int) and isinstance(kernel_size[1], int)) \
            or kernel_size[0] < 1 or kernel_size[1] < 1:
        raise ValueError(
            "Attr 'kernel_size' of 'Conv2D' Op passed " + str(self.kernel_size)
            + ", should be a int or tuple and equal to or greater than 1.")

    if not (isinstance(stride[0], int) and isinstance(stride[1], int)) \
            or stride[0] < 1 or stride[1] < 1:
        raise ValueError(
            "Attr 'stride' of 'Conv2D' Op passed " + str(self.stride)
            + ", should be a int or tuple and equal to or greater than 1.")

    if not (isinstance(dilation[0], int) and isinstance(dilation[1], int)) \
            or dilation[0] < 1 or dilation[1] < 1:
        raise ValueError(
            "Attr 'dilation' of 'Conv2D' Op passed " + str(self.dilation)
            + ", should equal to or greater than 1.")

    # Both channel counts must split evenly across the groups.
    if in_channels % group != 0:
        raise ValueError(
            "Attr 'in_channels' of 'Conv2D' Op must be divisible by attr 'group' of 'Conv2D' Op.")
    if out_channels % group != 0:
        raise ValueError(
            "Attr 'out_channels' of 'Conv2D' Op must be divisible by attr 'group' of 'Conv2D' Op.")

    weight_shape = [out_channels, in_channels // group, *kernel_size]
    self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')

    bias_enabled = check_bool(has_bias)
    if bias_enabled:
        bias_shape = [out_channels]
        self.bias = Parameter(initializer(bias_init, bias_shape), name='bias')
    else:
        # A non-default bias initializer has no effect without a bias.
        if bias_init != 'zeros':
            logger.warning(
                "Value of 'has_bias' is False, value of 'bias_init' will be ignored.")
        self.bias = None