Example #1
    def __init__(self,
                 save_checkpoint_steps=1,
                 save_checkpoint_seconds=0,
                 keep_checkpoint_max=5,
                 keep_checkpoint_per_n_minutes=0,
                 integrated_save=True):

        # At least one save trigger or retention limit must be set.
        if not save_checkpoint_steps and not save_checkpoint_seconds and \
                not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
            raise ValueError("The input_param can't be all None or 0")

        # Validate every supplied value as a non-negative integer.
        if save_checkpoint_steps:
            save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps)
        if save_checkpoint_seconds:
            save_checkpoint_seconds = check_int_non_negative(save_checkpoint_seconds)
        if keep_checkpoint_max:
            keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
        if keep_checkpoint_per_n_minutes:
            keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)

        # A positive step trigger takes precedence over the seconds trigger.
        self._save_checkpoint_steps = save_checkpoint_steps
        self._save_checkpoint_seconds = save_checkpoint_seconds
        if self._save_checkpoint_steps and self._save_checkpoint_steps > 0:
            self._save_checkpoint_seconds = None

        # A positive keep_checkpoint_max takes precedence over minutes-based
        # retention; if neither retention limit is set, keep one checkpoint.
        self._keep_checkpoint_max = keep_checkpoint_max
        self._keep_checkpoint_per_n_minutes = keep_checkpoint_per_n_minutes
        if self._keep_checkpoint_max and self._keep_checkpoint_max > 0:
            self._keep_checkpoint_per_n_minutes = None
        elif not self._keep_checkpoint_per_n_minutes:
            self._keep_checkpoint_max = 1

        self._integrated_save = check_bool(integrated_save)
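
This initializer validates the checkpoint parameters and then resolves conflicts between them: a positive save_checkpoint_steps disables save_checkpoint_seconds, and a positive keep_checkpoint_max disables keep_checkpoint_per_n_minutes. A minimal usage sketch, assuming the snippet is MindSpore's CheckpointConfig and that the import path below is correct (neither is shown in the snippet):

from mindspore.train.callback import CheckpointConfig  # assumed import path

# Save every 100 steps and keep at most 10 checkpoint files. Because
# save_checkpoint_steps > 0, save_checkpoint_seconds is ignored, and because
# keep_checkpoint_max > 0, keep_checkpoint_per_n_minutes is ignored.
config = CheckpointConfig(save_checkpoint_steps=100,
                          save_checkpoint_seconds=0,
                          keep_checkpoint_max=10,
                          keep_checkpoint_per_n_minutes=0,
                          integrated_save=True)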
Example #2
    def __init__(self,
                 save_checkpoint_steps=1,
                 save_checkpoint_seconds=0,
                 keep_checkpoint_max=5,
                 keep_checkpoint_per_n_minutes=0,
                 integrated_save=True,
                 model_type="normal"):

        # At least one save trigger or retention limit must be set.
        if not save_checkpoint_steps and not save_checkpoint_seconds and \
                not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
            raise ValueError("The input_param can't be all None or 0")

        # Validate every supplied value; model_type is restricted to the
        # supported strings.
        if save_checkpoint_steps:
            save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps)
        if save_checkpoint_seconds:
            save_checkpoint_seconds = check_int_non_negative(save_checkpoint_seconds)
        if keep_checkpoint_max:
            keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
        if keep_checkpoint_per_n_minutes:
            keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
        if model_type:
            model_type = check_string(model_type, ["normal", "fusion", "quant"])

        # A positive step trigger takes precedence over the seconds trigger.
        self._save_checkpoint_steps = save_checkpoint_steps
        self._save_checkpoint_seconds = save_checkpoint_seconds
        if self._save_checkpoint_steps and self._save_checkpoint_steps > 0:
            self._save_checkpoint_seconds = None

        # A positive keep_checkpoint_max takes precedence over minutes-based
        # retention; if neither retention limit is set, keep one checkpoint.
        self._keep_checkpoint_max = keep_checkpoint_max
        self._keep_checkpoint_per_n_minutes = keep_checkpoint_per_n_minutes
        if self._keep_checkpoint_max and self._keep_checkpoint_max > 0:
            self._keep_checkpoint_per_n_minutes = None
        elif not self._keep_checkpoint_per_n_minutes:
            self._keep_checkpoint_max = 1

        self._model_type = model_type
        self._integrated_save = check_bool(integrated_save)
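
This variant of Example #1 adds a model_type argument that must be one of "normal", "fusion", or "quant". A sketch of a purely time-based configuration, under the same CheckpointConfig assumption as above:

# Trigger saves by wall-clock time instead of steps: save_checkpoint_steps=0
# leaves the seconds-based trigger active, and keep_checkpoint_max=0 leaves
# the minutes-based retention active.
config = CheckpointConfig(save_checkpoint_steps=0,
                          save_checkpoint_seconds=60,
                          keep_checkpoint_max=0,
                          keep_checkpoint_per_n_minutes=30,
                          model_type="quant")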
Example #3
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 pad_mode, padding, dilation, group, has_bias, weight_init,
                 bias_init):
        super(_Conv, self).__init__()
        # Validate and store the convolution hyperparameters.
        self.in_channels = check_int_positive(in_channels)
        self.out_channels = check_int_positive(out_channels)
        self.kernel_size = kernel_size
        self.stride = stride
        self.pad_mode = pad_mode
        self.padding = check_int_non_negative(padding)
        self.dilation = dilation
        self.group = check_int_positive(group)
        self.has_bias = has_bias
        # kernel_size, stride, and dilation must be 2-tuples of ints >= 1.
        if (not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \
                kernel_size[0] < 1 or kernel_size[1] < 1:
            raise ValueError("Attr 'kernel_size' of 'Conv2D' Op passed " + str(self.kernel_size) +
                             ", should be an int or tuple and equal to or greater than 1.")
        if (not isinstance(stride[0], int)) or (not isinstance(stride[1], int)) or \
                stride[0] < 1 or stride[1] < 1:
            raise ValueError("Attr 'stride' of 'Conv2D' Op passed " + str(self.stride) +
                             ", should be an int or tuple and equal to or greater than 1.")
        if (not isinstance(dilation[0], int)) or (not isinstance(dilation[1], int)) or \
                dilation[0] < 1 or dilation[1] < 1:
            raise ValueError("Attr 'dilation' of 'Conv2D' Op passed " + str(self.dilation) +
                             ", should be equal to or greater than 1.")
        # Grouped convolution requires both channel counts to be divisible
        # by the number of groups.
        if in_channels % group != 0:
            raise ValueError("Attr 'in_channels' of 'Conv2D' Op must be divisible by "
                             "attr 'group' of 'Conv2D' Op.")
        if out_channels % group != 0:
            raise ValueError("Attr 'out_channels' of 'Conv2D' Op must be divisible by "
                             "attr 'group' of 'Conv2D' Op.")

        # Weight shape: (out_channels, in_channels // group, *kernel_size).
        self.weight = Parameter(
            initializer(weight_init, [out_channels, in_channels // group, *kernel_size]),
            name='weight')

        if check_bool(has_bias):
            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
        else:
            # A custom bias_init is meaningless when there is no bias term.
            if bias_init != 'zeros':
                logger.warning("Value of 'has_bias' is False, value of 'bias_init' will be ignored.")
            self.bias = None
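
This base-class initializer validates the hyperparameters shared by convolution layers: kernel_size, stride, and dilation must already be 2-tuples (the checks index [0] and [1] directly), both channel counts must be divisible by group, and a non-default bias_init triggers a warning when has_bias is False. A hypothetical subclass sketch, assuming the MindSpore-style helpers used above (Parameter, initializer, the check_* validators, logger) are in scope; the subclass name and its fixed defaults are illustrative only:

class Conv2dExample(_Conv):
    # Hypothetical subclass; a real layer would also construct the conv op
    # and define a forward computation.
    def __init__(self, in_channels, out_channels, kernel_size):
        # _Conv expects kernel_size, stride, and dilation as 2-tuples of ints >= 1.
        super(Conv2dExample, self).__init__(
            in_channels, out_channels,
            kernel_size=(kernel_size, kernel_size),
            stride=(1, 1), pad_mode='same', padding=0,
            dilation=(1, 1), group=1, has_bias=False,
            weight_init='normal', bias_init='zeros')

# in_channels and out_channels must both be divisible by group (here group=1).
layer = Conv2dExample(in_channels=16, out_channels=32, kernel_size=3)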