コード例 #1
0
 def _build_feed_inputs(self, inputs):
     """Create ``self.inputs`` plus the bookkeeping for inputs that must be fed.

     Entries of ``inputs`` are named ``input_1``, ``input_2``, ... in order
     and handled one by one:

     * a Python ``list`` is converted with ``np.asarray``; when the result is
       1-D a trailing axis is added (shape ``(n,)`` becomes ``(n, 1)``);
     * an ``np.ndarray`` is replaced by a new ``F.placeholder`` whose shape
       is the array's shape with the leading axis set to ``None``
       (presumably the batch axis), and that placeholder is registered as a
       fed input;
     * any other object is used unchanged and is registered as a fed input
       only when ``F.is_placeholder`` reports it as a placeholder (its shape
       is then taken from ``F.int_shape``).

     :param inputs: iterable of input data and/or tensors.
     :return: None. Replaces ``self._input_names``, ``self._feed_inputs``,
         ``self._feed_input_names``, ``self._feed_input_shapes`` and
         ``self.inputs`` with freshly built lists.
     """
     self._input_names = []
     self._feed_inputs = []
     self._feed_input_names = []
     self._feed_input_shapes = []
     self.inputs = []

     for position, value in enumerate(inputs, 1):
         input_name = 'input_%d' % position
         self._input_names.append(input_name)

         # Lists become arrays; a 1-D result gets a trailing axis.
         if isinstance(value, list):
             value = np.asarray(value)
             if value.ndim == 1:
                 value = np.expand_dims(value, 1)

         if not isinstance(value, np.ndarray):
             # Tensor-like object: keep it and only register it for feeding
             # when the backend says it is a placeholder.
             self.inputs.append(value)
             if F.is_placeholder(value):
                 self._feed_inputs.append(value)
                 self._feed_input_names.append(input_name)
                 self._feed_input_shapes.append(F.int_shape(value))
             continue

         # Concrete data: stand in a placeholder with a free leading axis.
         feed_shape = (None,) + value.shape[1:]
         holder = F.placeholder(shape=feed_shape, name=input_name)
         self.inputs.append(holder)
         self._feed_inputs.append(holder)
         self._feed_input_names.append(input_name)
         self._feed_input_shapes.append(feed_shape)
コード例 #2
0
 def _build_feed_targets(self, targets):
     """Create ``self.targets`` plus the bookkeeping for targets that must be fed.

     The number of targets is deliberately not checked against
     ``self.outputs``, because the loss and metrics have already been
     computed by ``model_fn``.

     Entries of ``targets`` are named ``target_1``, ``target_2``, ... in
     order. A ``list`` is converted with ``np.asarray`` (a 1-D result gets a
     trailing axis). An ``np.ndarray`` is replaced by an ``F.placeholder``
     whose leading axis is ``None`` and is always registered for feeding.
     Any other object is kept as-is and registered for feeding only when
     ``F.is_placeholder`` is true for it.

     :param targets: iterable of target data and/or tensors.
     :return: None. Replaces ``self.targets``, ``self._target_names``,
         ``self._feed_targets``, ``self._feed_target_names`` and
         ``self._feed_target_shapes`` with freshly built lists.
     """
     self.targets = []
     self._target_names = []
     self._feed_targets = []
     self._feed_target_names = []
     self._feed_target_shapes = []

     for position, value in enumerate(targets, 1):
         target_name = 'target_%d' % position
         self._target_names.append(target_name)

         if isinstance(value, list):
             value = np.asarray(value)
             if value.ndim == 1:
                 # Give flat label lists an explicit trailing axis.
                 value = np.expand_dims(value, 1)

         if isinstance(value, np.ndarray):
             feed_shape = (None,) + value.shape[1:]
             holder = F.placeholder(shape=feed_shape, name=target_name)
             self.targets.append(holder)
             self._feed_targets.append(holder)
             self._feed_target_names.append(target_name)
             self._feed_target_shapes.append(feed_shape)
             continue

         self.targets.append(value)
         if not F.is_placeholder(value):
             continue
         self._feed_targets.append(value)
         self._feed_target_names.append(target_name)
         self._feed_target_shapes.append(F.int_shape(value))
コード例 #3
0
 def _set_inputs(self, inputs, outputs=None, training=None):
     """Build input placeholders and outputs for a subclassed model.

     :param inputs: model inputs; only a nested list or a non-nested dict is
         supported. The object is stored unchanged in
         ``self._nested_inputs``. Each entry yielded by
         ``utils.valid_data(inputs)`` becomes one element of ``self.inputs``
         named ``input_<i>``: ``list`` entries are converted to arrays (1-D
         ones get a trailing axis), ``np.ndarray`` entries are replaced by an
         ``F.placeholder`` with a ``None`` leading axis, and anything else is
         used as-is (registered for feeding only if ``F.is_placeholder``).
     :param outputs: only checked when ``self.model_fn`` is set; the
         ``model_fn`` branch runs only if it is not None. Its value is not
         otherwise used.
     :param training: passed as ``training=`` to ``self(inputs, ...)`` when
         ``self.forward`` accepts a ``training`` argument.
     :return: None. Sets ``self.inputs`` and the ``_input_names`` /
         ``_feed_input_*`` lists, ``self.outputs`` (and, on the ``model_fn``
         path, ``loss``/``metrics`` plus the train/val hooks),
         ``self._output_names``, ``self._uses_learning_phase`` and
         ``self.built``.
     :raises ValueError: if ``model_fn`` does not return an
         ``EstimatorSpec``.
     """
     self._nested_inputs = inputs
     # Reset every per-input list together with ``self.inputs`` (as
     # ``_build_feed_inputs`` does). Previously only ``self.inputs`` was
     # cleared, so repeated calls appended duplicate names/placeholders and
     # the feed lists no longer lined up with ``self.inputs``.
     self._input_names = []
     self._feed_inputs = []
     self._feed_input_names = []
     self._feed_input_shapes = []
     self.inputs = []
     for i, x in enumerate(utils.valid_data(inputs)):
         name = 'input_%d' % (i + 1)
         self._input_names.append(name)
         if isinstance(x, list):
             x = np.asarray(x)
             if x.ndim == 1:
                 x = np.expand_dims(x, 1)
         if isinstance(x, np.ndarray):
             shape = (None,) + x.shape[1:]
             placeholder = F.placeholder(
                 shape=shape, name=name)
             self.inputs.append(placeholder)
             self._feed_inputs.append(placeholder)
             self._feed_input_names.append(name)
             self._feed_input_shapes.append(shape)
         else:
             self.inputs.append(x)
             if F.is_placeholder(x):
                 self._feed_inputs.append(x)
                 self._feed_input_names.append(name)
                 self._feed_input_shapes.append(F.int_shape(x))
     if self.model_fn is None:
         kwargs = {'training': training} if has_arg(self.forward, 'training') else {}
         self._nested_outputs = self(inputs, **kwargs)
         self.outputs = nest.flatten(self._nested_outputs)
     elif outputs is not None:
         # NOTE(review): ``x_keys``, ``y_keys`` and ``y`` are not defined in
         # this method (nor anywhere visible in this chunk), and ``x`` is only
         # the leftover loop variable (the last input entry). As written this
         # call therefore raises NameError. The nested inputs/targets and
         # their keys need to be made available here — TODO fix.
         logging.info('=>Calling model_fn...')
         result = self.model_fn(
             self, utils.nest_data(
                 self.inputs, x_keys, x),
             utils.nest_data(
                 self.targets, y_keys, y))
         logging.info('=>Finish calling model_fn...')
         if not isinstance(result, EstimatorSpec):
             # The trailing space matters: the two literals are implicitly
             # concatenated (previously produced "must bean instance").
             raise ValueError("Result returned from `model_fn` must be "
                              "an instance of `EstimatorSpec`")
         self.train_hooks.extend(result.train_hooks)
         self.val_hooks.extend(result.val_hooks)
         self.loss = result.loss
         self.metrics = result.metrics
         self.outputs = result.outputs
     # NOTE(review): when ``self.model_fn`` is set and ``outputs`` is None,
     # neither branch above assigns ``self.outputs``; the lines below then
     # rely on a value set elsewhere (AttributeError if there is none).
     self._output_names = [
         'output_%d' % i for i in range(1, len(self.outputs) + 1)]
     self._uses_learning_phase = any(getattr(x, '_uses_learning_phase', False)
                                     for x in self.outputs)
     self.built = True
コード例 #4
0
 def _compile_targets(self, targets):
     """Create one target per model output and record the ones to be fed.

     ``targets`` is passed through ``self._compile_args(targets, 'targets')``
     and then indexed in step with ``self.outputs``. For output ``i``:

     * if ``i`` is in ``self._skip_target_indices`` the target is ``None``
       and nothing is registered for feeding;
     * a ``None`` target becomes an ``F.placeholder`` with the output's rank,
       sparsity and dtype, named ``<output name>_target``;
     * a ``list`` target is converted with ``np.asarray`` (a 1-D result gets
       a trailing axis);
     * an ``np.ndarray`` target is replaced by an ``F.placeholder`` with a
       ``None`` leading axis and is always registered for feeding;
     * any other target is kept and registered only if ``F.is_placeholder``.

     A registered target is appended to ``_feed_targets`` together with the
     output name, its shape and ``self.loss_functions[i]``.

     NOTE(review): the ndarray placeholder is named after the output itself,
     while the ``None`` case uses ``<name>_target``; confirm this is intended.

     :param targets: targets accepted by ``self._compile_args``.
     :return: None. Replaces ``self.targets``, ``self._feed_targets``,
         ``self._feed_target_names``, ``self._feed_target_shapes`` and
         ``self._feed_loss_fns``.
     """
     logging.info("=>Compiling targets...")
     self.targets = []
     self._feed_targets = []
     self._feed_target_names = []
     self._feed_target_shapes = []
     self._feed_loss_fns = []
     targets = self._compile_args(targets, 'targets')

     for idx, output in enumerate(self.outputs):
         if idx in self._skip_target_indices:
             self.targets.append(None)
             continue

         out_name = self.output_names[idx]
         target = targets[idx]
         loss_fn = self.loss_functions[idx]

         if target is None:
             # No data supplied: build a placeholder shaped like the output.
             target = F.placeholder(
                 ndim=len(F.int_shape(output)),
                 name=out_name + '_target',
                 sparse=F.is_sparse(output),
                 dtype=F.dtype(output))
         elif isinstance(target, list):
             target = np.asarray(target)
             if target.ndim == 1:
                 target = np.expand_dims(target, 1)

         if not isinstance(target, np.ndarray):
             self.targets.append(target)
             if F.is_placeholder(target):
                 self._feed_targets.append(target)
                 self._feed_target_names.append(out_name)
                 self._feed_target_shapes.append(F.int_shape(target))
                 self._feed_loss_fns.append(loss_fn)
             continue

         feed_shape = (None,) + target.shape[1:]
         holder = F.placeholder(shape=feed_shape, name=out_name)
         self.targets.append(holder)
         self._feed_targets.append(holder)
         self._feed_target_names.append(out_name)
         self._feed_target_shapes.append(feed_shape)
         self._feed_loss_fns.append(loss_fn)