Code example #1
 def _transform_dataset(self, mode):
     """Build the per-epoch tf.data pipeline for `mode` and store the resulting
     iterators in a Scheduler keyed by epoch.
     """
     all_output_keys = []
     signature_epoch, mode_ops = self._get_signature_epoch(mode)
     extracted_ds = self.extracted_dataset[mode]
     state = {"mode": mode}
     dataset_map = {}
     for epoch in signature_epoch:
         epoch_ops_all = []
         forward_ops_epoch = []
         filter_ops_epoch = []
         forward_ops_between_filter = []
         # get batch size for the epoch
         global_batch_size = self.get_global_batch_size(epoch)
         # generate ops for specific mode and epoch
         for op in mode_ops:
             if isinstance(op, Scheduler):
                 scheduled_op = op.get_current_value(epoch)
                 if scheduled_op:
                     epoch_ops_all.append(scheduled_op)
             else:
                 epoch_ops_all.append(op)
         # check the ops
         epoch_ops_without_filter = [
             op for op in epoch_ops_all if not isinstance(op, TensorFilter)
         ]
         verify_ops(epoch_ops_without_filter, "Pipeline")
         # arrange operation according to filter location
         for op in epoch_ops_all:
             all_output_keys.append(op.outputs)
             if not isinstance(op, TensorFilter):
                 forward_ops_between_filter.append(op)
             else:
                 forward_ops_epoch.append(forward_ops_between_filter)
                 filter_ops_epoch.append(op)
                 forward_ops_between_filter = []
         forward_ops_epoch.append(forward_ops_between_filter)
         # execute the operations
         dataset = self._execute_ops(extracted_ds, forward_ops_epoch,
                                     filter_ops_epoch, state)
         if self.expand_dims:
             dataset = dataset.flat_map(tf.data.Dataset.from_tensor_slices)
         if self.batch:
             if self.padded_batch:
                 # the map() call is not used for its return value: tracing
                 # _get_padded_shape records padded shapes into self.padded_shape
                 _ = dataset.map(self._get_padded_shape)
                 dataset = dataset.padded_batch(
                     global_batch_size, padded_shapes=self.padded_shape)
             else:
                 dataset = dataset.batch(global_batch_size)
         dataset = dataset.prefetch(buffer_size=1)
         if fe.distribute_strategy:
             dataset = fe.distribute_strategy.experimental_distribute_dataset(
                 dataset)
         dataset_map[epoch] = iter(dataset)
     self.dataset_schedule[mode] = Scheduler(epoch_dict=dataset_map)
     self.all_output_keys = self.all_output_keys | set(
         flatten_list(all_output_keys))
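
All of these examples resolve epoch-dependent values through the same Scheduler pattern: a value is registered per epoch key, get_current_value(epoch) returns whichever value is in effect at that epoch, and a falsy result (hence the `if scheduled_op:` guard) means nothing applies yet. A minimal sketch of that lookup, assuming an epoch_dict whose keys mark the epochs where the value changes (the real FastEstimator Scheduler adds validation and other conveniences):

class Scheduler:
    """Minimal sketch: hold one value per epoch key, and keep returning the
    most recently registered value until a later key overrides it."""
    def __init__(self, epoch_dict):
        self.epoch_dict = epoch_dict

    def get_current_value(self, epoch):
        # the value in effect at `epoch` comes from the largest key that
        # does not exceed `epoch`; before the first key there is none
        candidates = [k for k in self.epoch_dict if k <= epoch]
        return self.epoch_dict[max(candidates)] if candidates else None

For example, Scheduler(epoch_dict={0: "aug_a", 10: "aug_b"}).get_current_value(5) would return "aug_a" under this sketch.
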
Code example #2
File: network.py Project: AriChow/fastestimator
    def prepare(self, mode_list, distribute_strategy):
        """This function constructs the model specified in model definition and create replica of model
         for distributed training across multiple devices if there are multiple GPU available.

        Args:
            mode_list : can be either 'train' or 'eval'
            distribute_strategy : Tensorflow class that defines distribution strategy (e.g. tf.distribute.MirroredStrategy)
        """
        all_output_keys = []
        for mode in mode_list:
            signature_epoch, mode_ops = self._get_signature_epoch(mode)
            epoch_ops_map = {}
            epoch_model_map = {}
            for epoch in signature_epoch:
                epoch_ops = []
                epoch_model = []
                # generate ops for specific mode and epoch
                for op in mode_ops:
                    if isinstance(op, Scheduler):
                        scheduled_op = op.get_current_value(epoch)
                        if scheduled_op:
                            epoch_ops.append(scheduled_op)
                    else:
                        epoch_ops.append(op)
                # check the ops
                verify_ops(epoch_ops, "Network")
                # create model list
                for op in epoch_ops:
                    all_output_keys.append(op.outputs)
                    if isinstance(op, ModelOp):
                        if op.model.keras_model is None:
                            scope = distribute_strategy.scope() if distribute_strategy else NonContext()
                            with scope:
                                op.model.keras_model = op.model.model_def()
                                op.model.keras_model.optimizer = op.model.optimizer
                                op.model.keras_model.loss_name = op.model.loss_name
                                op.model.keras_model.model_name = op.model.model_name
                                assert op.model.model_name not in self.model, \
                                    "duplicated model name: {}".format(op.model.model_name)
                                self.model[
                                    op.model.model_name] = op.model.keras_model
                                if op.model.loss_name not in self.all_losses:
                                    self.all_losses.append(op.model.loss_name)
                        if op.model.keras_model not in epoch_model:
                            epoch_model.append(op.model.keras_model)
                assert epoch_model, "Network has no model for epoch {}".format(
                    epoch)
                epoch_ops_map[epoch] = epoch_ops
                epoch_model_map[epoch] = epoch_model
            self.op_schedule[mode] = Scheduler(epoch_dict=epoch_ops_map)
            self.model_schedule[mode] = Scheduler(epoch_dict=epoch_model_map)
        self.all_output_keys = set(flatten_list(all_output_keys)) - {None}
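
The one-liner `distribute_strategy.scope() if distribute_strategy else NonContext()` needs a stand-in context manager for the single-device case. Its use here implies that NonContext simply does nothing; a sketch under that assumption (contextlib.nullcontext is the standard-library equivalent since Python 3.7):

class NonContext:
    """A context manager that does nothing, so the `with` block can be
    written unconditionally whether or not a distribution strategy exists."""
    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        return False  # never suppress exceptions
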
Code example #3
 def prepare(self, mode_list):
     """This function constructs the operations necessary for each epoch
     """
     all_output_keys = []
     all_models = []
     for mode in mode_list:
         signature_epoch, mode_ops = self._get_signature_epoch(mode)
         epoch_ops_map = {}
         epoch_model_map = {}
         for epoch in signature_epoch:
             epoch_ops = []
             epoch_model = []
             epoch_model_update = defaultdict(lambda: False)
             # generate ops for specific mode and epoch
             for op in mode_ops:
                 if isinstance(op, Scheduler):
                     scheduled_op = op.get_current_value(epoch)
                     if scheduled_op:
                         epoch_ops.append(scheduled_op)
                 else:
                     epoch_ops.append(op)
             # check the ops
             verify_ops(epoch_ops, "Network")
             # create model list
             for op in epoch_ops:
                 all_output_keys.append(op.outputs)
                 if isinstance(op, ModelOp):
                     if op.model not in epoch_model:
                         epoch_model.append(op.model)
                         # touching the defaultdict registers op.model with the
                         # default value False until an UpdateOp marks it True
                         epoch_model_update[op.model] = epoch_model_update[op.model]
                     if op.model not in all_models:
                         all_models.append(op.model)
                 if isinstance(op, UpdateOp):
                     epoch_model_update[op.model] = True
             if mode == "train":
                 for model, has_update in epoch_model_update.items():
                     if not has_update:
                         epoch_ops.append(UpdateOp(model=model))
             assert epoch_model, "Network has no model for epoch {}".format(
                 epoch)
             epoch_ops_map[epoch] = epoch_ops
             epoch_model_map[epoch] = epoch_model
         self.op_schedule[mode] = Scheduler(epoch_dict=epoch_ops_map)
         self.model_schedule[mode] = Scheduler(epoch_dict=epoch_model_map)
     self.all_output_keys = set(flatten_list(all_output_keys)) - {None}
     for model in all_models:
         assert model.model_name not in self.model, "duplicated model name: {}".format(
             model.model_name)
         self.model[model.model_name] = model
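
The self-assignment on epoch_model_update above is not a no-op: reading a missing key from a defaultdict inserts it with the default value, so every model seen in a ModelOp gets registered as False until an UpdateOp flips it to True, and train mode then appends a default UpdateOp for any model still marked False. A standalone illustration of the idiom (the model names are hypothetical):

from collections import defaultdict

seen_update = defaultdict(lambda: False)
seen_update["unet"] = seen_update["unet"]  # lookup inserts "unet" -> False
seen_update["lenet"] = True                # as if an explicit UpdateOp was found

# models still marked False would get a default UpdateOp appended
needs_default_update = [m for m, updated in seen_update.items() if not updated]
print(needs_default_update)  # ['unet']
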
Code example #4
 def _check_ops(self, mode):
     """Collect the ops that apply to `mode` and verify them when any exist."""
     if self.ops_local:
         self.mode_ops[mode] = get_op_from_mode(self.ops_local, mode)
         if len(self.mode_ops[mode]) > 0:
             verify_ops(self.mode_ops[mode], "RecordWriter")
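
get_op_from_mode is not shown in these excerpts. A plausible minimal sketch, assuming each op carries a `mode` attribute that is either None (applies to every mode) or a specific mode string:

def get_op_from_mode(ops, mode):
    # keep ops that declare no mode (run everywhere) or match the requested mode
    return [op for op in ops if op.mode is None or op.mode == mode]
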