    def __init__(
            self, data_module, output_list=None,
            net_func=None, batch_axis=0, num_samples=None,
            disp_time_interval=2, output_fn=None, is_large=False):

        self.data_module = data_module
        self.num_samples = self.data_module.num_samples()
        self.batch_axis = batch_axis
        self.disp_time_interval = disp_time_interval
        self.output_fn = output_fn
        self.is_large = is_large

        if num_samples is not None:
            if self.num_samples < num_samples:
                print("specified number_samples is larger than one epoch")
            else:
                self.num_samples = num_samples

        self.use_net_func = output_list is None  # otherwise use net_func
        if self.use_net_func:
            assert net_func is not None, \
                "one of output_list and net_func must be specified"
            self.net_func = net_func
            # remark: net_func(sess)
        else:
            assert net_func is None, \
                "output_list and net_func should not both be specified"
            self.output_list = output_list
            self.flatten_output_list, self.output_wrap_func = \
                recursive_flatten_with_wrap_func(
                    tmf.is_tf_data, self.output_list)

        self.data_module.reset()
        self.cur_sample_end = 0
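
# Usage sketch (illustrative; the enclosing class and its call sites are not
# shown here). The constructor above supports two mutually exclusive modes:
#
#   # mode 1: hand over a (possibly nested) list of tf tensors to fetch
#   runner = OutputRunner(data_module, output_list=[logits, labels])
#
#   # mode 2: hand over a callable that does the fetching itself
#   runner = OutputRunner(data_module, net_func=lambda sess: sess.run(logits))
#
# `OutputRunner` is a hypothetical name standing in for this class.
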
def spatial_transformer(input_layer, theta, out_size, name=PROVIDED):

    # init
    input_shape = tmf.get_shape(input_layer.tensor)
    assert len(input_shape) == 4, "input tensor must be rank 4"
    if isinstance(theta, np.ndarray):
        theta = tf.constant(theta)
    elif not tmf.is_tf_data(theta):
        theta = theta.tensor

    # apply transformer
    output = transformer(input_layer.tensor,
                         theta,
                         out_size=out_size,
                         name=name)

    # make output shape explicit
    output = tf.reshape(output, [input_shape[0]] + out_size + [input_shape[3]])
    return output
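
# Example (sketch, not part of the original module): warp a batch of images
# with an identity affine transform. This assumes the usual spatial-transformer
# convention of a flattened 2x3 theta of shape [batch, 6] and a prettytensor
# style input layer exposing `.tensor`; the function name and `name` argument
# below are illustrative only.
def _spatial_transformer_identity_example(images_layer, batch_size):
    # identity affine [[1, 0, 0], [0, 1, 0]], flattened row-major and tiled
    # over the batch; passed as a numpy array, it is converted to a
    # tf.constant by spatial_transformer above
    identity_theta = np.tile(
        np.array([[1., 0., 0., 0., 1., 0.]], dtype=np.float32),
        (batch_size, 1))
    # resample to 32x32 while keeping the batch and channel dimensions
    return spatial_transformer(images_layer, identity_theta,
                               out_size=[32, 32], name="stn_identity_example")
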
def coordinate_inv_transformer(input_layer, theta, name=PROVIDED):

    # init
    input_tensor = input_layer.tensor
    input_shape = tmf.get_shape(input_tensor)
    assert len(input_shape) == 3, "input tensor must be rank 3"
    if isinstance(theta, np.ndarray):
        theta = tf.constant(theta)
    elif not tmf.is_tf_data(theta):
        theta = theta.tensor

    keypoint_num = input_shape[1]

    with tf.variable_scope(name):
        kp2_e = tf.concat(
            [input_tensor, tf.ones_like(input_tensor[:, :, :1])], axis=2)
        kp2_e = tf.expand_dims(kp2_e, axis=-1)
        transform_e = tf.tile(tf.expand_dims(theta, axis=1),
                              [1, keypoint_num, 1, 1])
        kp1from2_e = tf.matmul(transform_e, kp2_e)
        kp1from2 = tf.squeeze(kp1from2_e, axis=-1)

    return kp1from2
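
# Example (sketch, not part of the original module): map keypoints through an
# affine transform in homogeneous coordinates. With an identity 2x3 theta the
# keypoints come back unchanged; the function name and `name` argument are
# illustrative only. `keypoints_layer` is assumed to wrap a tensor of shape
# [batch, num_keypoints, 2].
def _coordinate_inv_transformer_identity_example(keypoints_layer, batch_size):
    identity_theta = np.tile(
        np.array([[[1., 0., 0.],
                   [0., 1., 0.]]], dtype=np.float32),
        (batch_size, 1, 1))
    # result shape: [batch, num_keypoints, 2]
    return coordinate_inv_transformer(keypoints_layer, identity_theta,
                                      name="coord_inv_identity_example")
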
    def create_data_tensor(self, raw_data_module, output_index=0):
        # set up: self.data_module and self.data_tensor
        assert not self._called_create_data_tensor, \
            "data_tensor should not be create twice"
        self._called_create_data_tensor = True

        preprocessed_data_module = self.pipeline.data_module_preprocess(
            raw_data_module, mode="random")
        self._data_module = runner.resumable_data_module_wrapper.Net(
            preprocessed_data_module)
        unshuffled_data_tensor = tf_variable_from_data_module(
            self.data_module, self.batch_size, output_index)

        tsc = self.opt.train_shuffle_capacity
        maximum_shuffle_capacity = None
        if self.train_use_shuffle():
            if tsc == "full":
                maximum_shuffle_capacity = self.data_module.num_samples()
            elif tsc == "on":
                # use shuffle provided by the data loader
                pass
            elif tsc >= 1:
                maximum_shuffle_capacity = tsc * self.batch_size
                if maximum_shuffle_capacity > self.data_module.num_samples():
                    maximum_shuffle_capacity = self.data_module.num_samples()

        if maximum_shuffle_capacity is not None:
            minimum_shuffle_capacity = 3 * self.batch_size
            if tmf.is_tf_data(unshuffled_data_tensor):
                unshuffled_data_tensor = [unshuffled_data_tensor]
            data_tensor = tf.train.shuffle_batch(
                unshuffled_data_tensor,
                batch_size=self.batch_size,
                capacity=maximum_shuffle_capacity,
                min_after_dequeue=minimum_shuffle_capacity,
                enqueue_many=True)
        else:
            data_tensor = tgu.sequential_data_buffer(
                unshuffled_data_tensor,
                batch_size=self.batch_size,
                capacity=self.batch_size * 3,
                enqueue_many=True)

        if self.opt.train_color_jittering or self.opt.train_random_mirroring:
            if isinstance(data_tensor, list):
                image = data_tensor[0]
            else:
                image = data_tensor

            if self.opt.train_color_jittering:
                if self.opt.image_color_scaling is not None:
                    clip_value_min = (1 - self.opt.image_color_scaling) * 0.5
                    image = (image -
                             clip_value_min) / self.opt.image_color_scaling
                image = tf.image.random_brightness(image, max_delta=32. / 255.)
                # image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
                # image = tf.image.random_hue(image, max_delta=0.2)
                image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
                image = tf.clip_by_value(image,
                                         clip_value_min=0,
                                         clip_value_max=1)
                if self.opt.image_color_scaling is not None:
                    clip_value_min = (1 - self.opt.image_color_scaling) * 0.5
                    image = image * self.opt.image_color_scaling + clip_value_min

            if self.opt.train_random_mirroring:
                image_flip = tf.reverse(image, axis=[2])
                im_inp_shape = tmf.get_shape(image)
                flip_ind = tf.random_uniform([im_inp_shape[0]] + [1] *
                                             (len(im_inp_shape) - 1),
                                             minval=0.,
                                             maxval=1.,
                                             dtype=tf.float32) > 0.5
                image = tf.where(tf.tile(flip_ind, [1] + im_inp_shape[1:]),
                                 image_flip, image)

            if isinstance(data_tensor, list):
                data_tensor[0] = image
            else:
                data_tensor = image

        self._data_tensor = data_tensor

        return data_tensor
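
    # Augmentation sketch (illustrative, a standalone restatement of the block
    # above, assuming `image` is a float tensor in [0, 1] with static shape
    # [batch, height, width, channels]):
    #
    #   image = tf.image.random_brightness(image, max_delta=32. / 255.)
    #   image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
    #   image = tf.clip_by_value(image, 0., 1.)
    #   flip = tf.random_uniform([batch, 1, 1, 1]) > 0.5
    #   image = tf.where(tf.tile(flip, [1, height, width, channels]),
    #                    tf.reverse(image, axis=[2]), image)
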
    def __init__(self,
                 data_module,
                 loss_tensor,
                 solver_type,
                 solver_kwargs=None,
                 disp_tensor_dict=None,
                 minimizer_kwargs=None,
                 update_ops=None,
                 max_epochs=None,
                 disp_time_interval=2,
                 disp_prefix=None,
                 learning_rate=None,
                 global_step=None,
                 snapshot_func=None,
                 snapshot_interval=7200,
                 snapshot_sharing=None,
                 permanent_snapshot_step_list=None,
                 snapshot_step_list=None,
                 test_func=None,
                 test_steps=10000,
                 logger=None,
                 scope=None,
                 extra_output_tensors=None):

        if solver_kwargs is None:
            solver_kwargs = dict()

        if minimizer_kwargs is None:
            minimizer_kwargs = dict()

        if disp_tensor_dict is None:
            disp_tensor_dict = dict()
        minimizer_kwargs = copy(minimizer_kwargs)

        if learning_rate is not None:
            solver_kwargs["learning_rate"] = learning_rate
        else:
            assert "learning_rate" in solver_kwargs, \
                "learning rate is not set"
        self.learning_rate_tensor = solver_kwargs["learning_rate"]
        if not tmf.is_tf_data(self.learning_rate_tensor):
            self.learning_rate_tensor = tf.constant(self.learning_rate_tensor)
        self.learning_rate = None

        if scope is None:
            scope = "trainer"

        self.data_module = data_module
        optimizer_func = getattr(tf.train, solver_type + "Optimizer")

        self.optimizer = optimizer_func(**solver_kwargs)

        # figure out subiters
        if "var_list" in minimizer_kwargs:
            var_list = minimizer_kwargs["var_list"]
        else:
            var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        self.update_shared = tgu.update_shared_vars(var_list)

        var_list = list(
            set(var_list) -
            set(tgu.get_freeze_collection()))  # remove freeze variables

        self.var_list = var_list
        minimizer_kwargs["var_list"] = var_list

        # cache variables

        old_variable_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

        # define training iters
        with tf.device("/cpu:0"), tf.variable_scope(scope):
            self.iter_variable = tf.Variable(0,
                                             trainable=False,
                                             dtype=tf.int64,
                                             name="trainer_step")
            self.pos_variable = tf.Variable(0,
                                            trainable=False,
                                            dtype=tf.int64,
                                            name="trainer_pos")

        # function for handling update ops
        def attach_updates_to_train_op(train_op_without_updates):
            # add update ops (mainly for batch normalization)
            if update_ops is None:
                train_op = pt.with_update_ops(train_op_without_updates)
            else:
                assert isinstance(update_ops,
                                  list), "update_ops must be a list"
                if update_ops:
                    train_op = tf.group(train_op_without_updates, *update_ops)
                else:
                    train_op = train_op_without_updates
            return train_op

        # define minimizer
        self.gradient_tensors = OrderedDict()

        is_single_device = tmf.is_tf_data(loss_tensor)
        assert is_single_device, \
            "ERROR: this code does not support multiple devices. Use CUDA_VISIBLE_DEVICES=... to specify the GPU."

        raw_gradient_tensor = self.optimizer.compute_gradients(
            loss_tensor, **minimizer_kwargs)

        # disp and extra variables
        self.loss_tensor = loss_tensor
        self.disp_tensor_dict = flatten_str_dict(disp_tensor_dict)
        self.extra_output_tensors = extra_output_tensors

        new_variable_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

        # train for all subiters
        self.gradient_tensor = []
        for g, v in raw_gradient_tensor:
            if hasattr(v, "lr_mult"):
                g *= v.lr_mult
            self.gradient_tensor.append((g, v))
        self.train_op_without_updates = self.optimizer.apply_gradients(
            self.gradient_tensor)
        self.train_op = attach_updates_to_train_op(
            self.train_op_without_updates)
        with tf.control_dependencies([self.update_shared]):
            self.train_op = tf.group(self.train_op)

        # sanity check
        assert "extra" not in disp_tensor_dict, \
            "extra is reserved for extra outputs"

        # saveable variables
        self.saveable_variables = list(
            set(new_variable_list) - set(old_variable_list))

        # helper for extra outputs
        self.extra_output_tensors_flattened, self.extra_output_tensors_wrapfunc = \
            recursive_flatten_with_wrap_func(tmf.is_tf_data, self.extra_output_tensors)

        # setup loss summaries
        self.loss_summaries = []
        with tf.name_scope('trainer_summaries'):
            if disp_prefix is not None:
                summary_prefix = disp_prefix + "_"
            else:
                summary_prefix = ""
            self.loss_summaries.append(
                tf.summary.scalar(summary_prefix + "Loss", self.loss_tensor))
            for k, v in self.disp_tensor_dict.items():
                self.loss_summaries.append(
                    tf.summary.scalar(summary_prefix + k, v))
            self.loss_summaries.append(
                tf.summary.scalar(summary_prefix + "learning_rate",
                                  self.learning_rate_tensor))
        self.merged_loss_summaries = tf.summary.merge(self.loss_summaries)
        self.logger = logger  # stored for later use; the trainer itself does not touch it

        # self.train_op = pt.apply_optimizer(
        #     self.optimizer, losses=[loss_tensor], **minimizer_kwargs)

        # set up variables for the training stage
        self.max_epochs = max_epochs
        self.disp_time_interval = disp_time_interval

        self.disp_prefix = disp_prefix
        self.total_iter = np.uint64(0)
        self.total_pos = np.uint64(0)

        if global_step is None:
            global_step = tf.train.get_or_create_global_step()
        self.global_step = global_step

        with tf.device("/cpu:0"):
            self.iter_variable_inc = tf.assign_add(self.iter_variable, 1)
            self.global_step_inc = tf.assign_add(self.global_step, 1)
            self.pos_assign_placeholder = tf.placeholder(
                tf.int64, shape=[], name="trainer_pos_assign")
            self.pos_variable_assign = tf.assign(self.pos_variable,
                                                 self.pos_assign_placeholder)

        # set up variables for run_init
        self.all_output_tensors = None
        self.all_output_names = None
        self.tmp_training_losses = None
        self.disp_countdown = None
        self.outside_timestamp = None
        self.tmp_iter_start = None

        # set up snapshot saver
        if permanent_snapshot_step_list is None:
            permanent_snapshot_step_list = []
        if snapshot_step_list is None:
            snapshot_step_list = []

        if snapshot_sharing is None:
            self.snapshot_runner_shared = False
            if snapshot_func is None:
                snapshot_interval = None
            else:
                if snapshot_interval is not None:
                    print("  - snapshot in every %d sec" % snapshot_interval)

            self.permanent_snapshot_condition = \
                ArgsSepFunc(lambda the_step: the_step in permanent_snapshot_step_list)
            self.snapshot_condition = \
                ArgsSepFunc(lambda the_step: the_step in snapshot_step_list)

            _snapshot_periodic_runner = PeriodicRun(snapshot_interval,
                                                    snapshot_func)
            _snapshot_periodic_runner.add_extra_true_condition(
                self.snapshot_condition)
            _snapshot_periodic_runner.add_extra_true_condition(self.need_stop)
            _snapshot_periodic_runner.add_extra_true_condition(
                self.permanent_snapshot_condition,
                extra_func=lambda sess, step: snapshot_func(
                    sess, step, preserve=True))
            self.snapshot_periodic_runner = _snapshot_periodic_runner
        else:
            self.snapshot_runner_shared = True
            self.snapshot_periodic_runner = snapshot_sharing.snapshot_periodic_runner
            self.permanent_snapshot_condition = snapshot_sharing.permanent_snapshot_condition
            self.snapshot_condition = snapshot_sharing.snapshot_condition

        # set up test func
        self.test_func = test_func
        self.test_steps = test_steps

        # set up variables for avg update
        self.avg_var_list = None
        self.avg_var_forward_steps = None
        self.avg_var_minimum_update_steps = None
        self.avg_var_update_num = None
        self.avg_var_exact_mode = None
        self.avg_var_running_mode = None
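
        # Construction sketch (illustrative; the enclosing class and the
        # concrete tensors are assumed, with `Trainer` as a hypothetical name):
        #
        #   trainer = Trainer(
        #       data_module, loss_tensor=total_loss, solver_type="Adam",
        #       learning_rate=1e-4, max_epochs=50,
        #       disp_tensor_dict={"recon": recon_loss, "reg": reg_loss})
        #
        # solver_type is suffixed with "Optimizer" and looked up in tf.train,
        # so "Adam" resolves to tf.train.AdamOptimizer.
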
def convert_to_tensor(a):
    # accept a numpy array, a tf tensor, or a layer object exposing `.tensor`
    if isinstance(a, np.ndarray):
        a = tf.constant(a)
    elif not tmf.is_tf_data(a):
        a = a.tensor
    return a
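
# Usage sketch (illustrative): convert_to_tensor normalizes the three kinds of
# inputs used throughout this module.
#
#   convert_to_tensor(np.zeros([2, 3], dtype=np.float32))  # numpy -> tf.constant
#   convert_to_tensor(tf.ones([2, 3]))                      # tf data, passed through
#   convert_to_tensor(some_prettytensor_layer)              # unwrapped via .tensor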