コード例 #1
0
    def build_test_model(self, reuse=None):
        """Assemble the inference graph, one replica per configured device.

        The source placeholder, the target placeholder and the right-shifted
        target are each split into ``len(self._devices)`` shards.  For every
        shard a beam-search prediction and a test loss are built on the
        matching device.  Predictions are then padded to a common width and
        concatenated into ``self.prediction``; the per-shard losses are
        accumulated into ``self.loss_sum``.  A saver covering all global
        variables is stored on ``self.saver``.

        Args:
            reuse: Forwarded as ``reuse`` to the enclosing variable scope.
        """
        logging.info('Build test model.')
        shard_count = len(self._devices)

        def pad_to_max_length(batch, width):
            """Right-pad a rank-2 batch with 3 (</S>) up to `width` columns."""
            dims = tf.shape(batch)
            filler = tf.ones([dims[0], width - dims[1]], dtype=INT_TYPE) * 3
            return tf.concat([batch, filler], axis=1)

        with self.graph.as_default(), tf.device(self._sync_device):
            # The current variable scope is looked up only after the graph has
            # been made the default one.
            with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
                shifted_dst = shift_right(self.dst_pl)
                src_shards = split_tensor(self.src_pl, shard_count)
                dst_shards = split_tensor(self.dst_pl, shard_count)
                dec_in_shards = split_tensor(shifted_dst, shard_count)

                predictions = []
                total_loss = 0
                shards = zip(src_shards, dst_shards, dec_in_shards,
                             self._devices)
                for idx, (src, dst, dec_in, dev) in enumerate(shards):
                    with tf.device(dev):
                        logging.info('Build model on %s.' % dev)
                        # None on the first replica, True on all later ones.
                        share = idx > 0 or None

                        def run_shard():
                            """Graph branch taken for a non-empty shard."""
                            encoded = self.encoder(src,
                                                   is_training=False,
                                                   reuse=share)
                            decoded_ids = self.beam_search(encoded,
                                                           reuse=share)
                            # The decoder and loss are always built with
                            # reuse=True, exactly as in the training path.
                            dec_out = self.decoder(dec_in,
                                                   encoded,
                                                   is_training=False,
                                                   reuse=True)
                            shard_loss = self.test_loss(dec_out, dst,
                                                        reuse=True)
                            return decoded_ids, shard_loss

                        def skip_shard():
                            """Graph branch taken for an empty shard."""
                            return tf.zeros([0, 0], dtype=INT_TYPE), 0.0

                        # Guard against errors caused by an empty input shard.
                        has_rows = tf.greater(tf.shape(src)[0], 0)
                        shard_pred, shard_loss = tf.cond(
                            has_rows, run_shard, skip_shard)

                        total_loss += shard_loss
                        predictions.append(shard_pred)

                widths = [tf.shape(p)[1] for p in predictions]
                widest = tf.reduce_max(widths)
                padded = [pad_to_max_length(p, widest) for p in predictions]

                self.prediction = tf.concat(padded, axis=0)
                self.loss_sum = total_loss

                self.saver = tf.train.Saver(var_list=tf.global_variables())
コード例 #2
0
    def build_train_model(self, reuse=None):
        """Assemble the multi-device training graph, then the test graph.

        The source and target placeholders are split into one shard per
        entry of ``self._devices``.  Each shard gets its own encoder, decoder
        and training output; the per-shard gradients are averaged, clipped by
        global norm and applied through ``self._optimizer``.  The method sets
        ``self.accuracy``, ``self.loss``, ``self.grads_norm``,
        ``self.train_op`` and ``self.summary_op``, and finally calls
        ``build_test_model(reuse=True)``.

        Args:
            reuse: Forwarded as ``reuse`` to the enclosing variable scope.
        """
        logging.info('Build train model.')
        self.prepare_training()

        def make_placer(worker):
            """Return a device function for ops created on `worker`."""

            def place(op):
                # Ops whose type starts with 'Variable' live on the sync
                # device; everything else stays on the worker.
                if op.type.startswith('Variable'):
                    return self._sync_device
                return worker

            return place

        with self.graph.as_default(), tf.device(self._sync_device):
            with tf.variable_scope(tf.get_variable_scope(),
                                   initializer=self._initializer,
                                   reuse=reuse):
                shard_count = len(self._devices)
                src_shards = split_tensor(self.src_pl, shard_count)
                dst_shards = split_tensor(self.dst_pl, shard_count)

                accuracies, losses, tower_grads = [], [], []
                towers = zip(src_shards, dst_shards, self._devices)
                for idx, (src, dst, worker) in enumerate(towers):
                    # None on the first tower, True on all later ones.
                    share = idx > 0 or None
                    with tf.device(make_placer(worker)):
                        logging.info('Build model on %s.' % worker)
                        enc_out = self.encoder(src,
                                               is_training=True,
                                               reuse=share)
                        dec_out = self.decoder(shift_right(dst),
                                               enc_out,
                                               is_training=True,
                                               reuse=share)
                        tower_acc, tower_loss = self.train_output(
                            dec_out, dst, reuse=share)
                        accuracies.append(tower_acc)
                        losses.append(tower_loss)
                        tower_grads.append(
                            self._optimizer.compute_gradients(tower_loss))

                self.accuracy = tf.reduce_mean(accuracies)
                self.loss = tf.reduce_mean(losses)

                # Average over towers, record histograms, clip, then apply.
                averaged = average_gradients(tower_grads)
                for grad, var in averaged:
                    tag = var.name.split(':')[0]
                    tf.summary.histogram('variables/' + tag, var)
                    tf.summary.histogram('gradients/' + tag, grad)

                clipped, self.grads_norm = tf.clip_by_global_norm(
                    [pair[0] for pair in averaged],
                    clip_norm=self._config.train.grads_clip)
                self.train_op = self._optimizer.apply_gradients(
                    zip(clipped, [pair[1] for pair in averaged]),
                    global_step=self.global_step)

                # Scalar summaries.
                tf.summary.scalar('acc', self.accuracy)
                tf.summary.scalar('loss', self.loss)
                tf.summary.scalar('learning_rate', self.learning_rate)
                tf.summary.scalar('grads_norm', self.grads_norm)
                self.summary_op = tf.summary.merge_all()

        # The test graph shares variables so the model can be evaluated while
        # training is in progress.
        self.build_test_model(reuse=True)
コード例 #3
0
ファイル: model_pretrain.py プロジェクト: zw76859420/ASR-1
    def build_train_model(self, test=True, reuse=None):
        """Build model for training.

        One tower (encoder, decoder, training output) is built for every
        (source, target, device) triple taken from ``self.src_pls``,
        ``self.dst_pls`` and ``self._devices``.  Per-tower gradients are
        averaged, optionally clipped by global norm, and applied through
        ``self._optimizer``.  The method sets ``self.accuracy``,
        ``self.loss``, ``self.grads_norm``, ``self.train_op``,
        ``self.summary_op`` and ``self.saver``.

        Args:
            test: When truthy, ``build_test_model(reuse=True)`` is called
                after the training graph has been built.
            reuse: Forwarded as ``reuse`` to the per-tower variable scope.
        """
        logging.info('Build train model.')
        self.prepare_training()

        with self.graph.as_default():
            acc_list, loss_list, gv_list = [], [], []
            # Maps both `name` and `(device, name)` to the most recent copy of
            # a variable; shared by all towers.
            cache = {}
            # Number of variable elements placed on each device so far.
            load = dict.fromkeys(self._devices, 0)
            for i, (X, Y, device) in enumerate(
                    zip(self.src_pls, self.dst_pls, self._devices)):

                # Both helpers below close over the loop variable `device`;
                # they are only used inside this iteration's `with` blocks.
                def daisy_chain_getter(getter, name, *args, **kwargs):
                    """Get a variable and cache in a daisy chain."""
                    device_var_key = (device, name)
                    if device_var_key in cache:
                        # Already available for this device: return it.
                        return cache[device_var_key]
                    if name in cache:
                        # Known from another device: copy the latest copy.
                        v = tf.identity(cache[name])
                    else:
                        var = getter(name, *args, **kwargs)
                        v = tf.identity(var._ref())  # pylint: disable=protected-access
                    # Update the cache.
                    cache[name] = v
                    cache[device_var_key] = v
                    return v

                def balanced_device_setter(op):
                    """Balance variables to all devices."""
                    if op.type in {'Variable', 'VariableV2', 'VarHandleOp'}:
                        # Place the variable on one of the least-loaded
                        # devices, picking randomly among ties.
                        min_load = min(load.values())
                        min_load_devices = [
                            d for d in load if load[d] == min_load
                        ]
                        chosen_device = random.choice(min_load_devices)
                        load[chosen_device] += op.outputs[0].get_shape(
                        ).num_elements()
                        return chosen_device
                    return device

                with tf.variable_scope(tf.get_variable_scope(),
                                       initializer=self._initializer,
                                       custom_getter=daisy_chain_getter,
                                       reuse=reuse):
                    with tf.device(balanced_device_setter):
                        logging.info('Build model on %s.', device)
                        # `i > 0 or None` is None for the first tower and
                        # True for every later one.
                        encoder_output = self.encoder(
                            X,
                            is_training=True,
                            reuse=i > 0 or None,
                            encoder_scope=self.encoder_scope)
                        decoder_output = self.decoder(
                            shift_right(Y),
                            encoder_output,
                            is_training=True,
                            reuse=i > 0 or None,
                            decoder_scope=self.decoder_scope)
                        acc, loss = self.train_output(
                            decoder_output,
                            Y,
                            reuse=i > 0 or None,
                            decoder_scope=self.decoder_scope)
                        acc_list.append(acc)
                        loss_list.append(loss)

                        # Optionally restrict training to the variables whose
                        # name matches the configured regular expression.
                        var_list = tf.trainable_variables()
                        if self._config.train.var_filter:
                            var_list = [
                                v for v in var_list if re.match(
                                    self._config.train.var_filter, v.name)
                            ]
                        gv_list.append(
                            self._optimizer.compute_gradients(
                                loss, var_list=var_list))

            self.accuracy = tf.reduce_mean(acc_list)
            self.loss = tf.reduce_mean(loss_list)

            # Clip gradients and then apply.
            grads_and_vars = average_gradients(gv_list)
            # grads_and_vars[0] is a (gradient, variable) pair; index the
            # gradient explicitly rather than passing the whole pair to
            # tf.abs.  The summary therefore tracks the mean absolute value
            # of the first averaged gradient.
            avg_abs_grads = tf.reduce_mean(tf.abs(grads_and_vars[0][0]))

            if self._config.train.grads_clip > 0:
                grads, self.grads_norm = tf.clip_by_global_norm(
                    [gv[0] for gv in grads_and_vars],
                    clip_norm=self._config.train.grads_clip)
                # Materialize the pairs so they can be iterated repeatedly.
                grads_and_vars = list(
                    zip(grads, [gv[1] for gv in grads_and_vars]))
            else:
                self.grads_norm = tf.global_norm(
                    [gv[0] for gv in grads_and_vars])

            self.train_op = self._optimizer.apply_gradients(
                grads_and_vars, global_step=self.global_step)

            # Summaries
            tf.summary.scalar('acc', self.accuracy)
            tf.summary.scalar('loss', self.loss)
            tf.summary.scalar('learning_rate', self.learning_rate)
            tf.summary.scalar('grads_norm', self.grads_norm)
            tf.summary.scalar('avg_abs_grads', avg_abs_grads)
            self.summary_op = tf.summary.merge_all()

            self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                        max_to_keep=20)

        # We may want to test the model during training.
        if test:
            self.build_test_model(reuse=True)
コード例 #4
0
    def build_test_model(self, reuse=None):
        """Build model for inference.

        One replica is built for every tuple taken from ``self.src_pls``,
        ``self.dst_pls``, ``self.label_pls``, ``self.src_len_pls`` and
        ``self._devices``.  Each replica produces a word prediction, a label
        prediction and a test loss.  Predictions are padded to a common width
        and concatenated into ``self.prediction`` and
        ``self.prediction_label``; the losses are accumulated into
        ``self.loss_sum``.  A ``loss_test`` scalar summary is registered and
        a saver over all global variables is stored on ``self.saver``.

        Args:
            reuse: Forwarded as ``reuse`` to the enclosing variable scope.
        """
        logging.info('Build test model.')

        with self.graph.as_default(), tf.variable_scope(
                tf.get_variable_scope(), reuse=reuse):

            prediction_list = []
            prediction_label_list = []
            loss_sum = 0
            for i, (X, Y, Z, X_lens, device) in enumerate(
                    zip(self.src_pls, self.dst_pls, self.label_pls,
                        self.src_len_pls, self._devices)):

                with tf.device(device):

                    logging.info('Build model on %s.', device)
                    dec_input = shift_right(Y)
                    # NOTE(review): the result is never consumed below; the
                    # call is kept so the set of ops added to the graph stays
                    # as before.  Confirm whether it can be dropped.
                    dec_label_input = shift_right(Z)

                    def true_fn():
                        """Graph branch taken for a non-empty batch."""
                        # `i > 0 or None` is None for the first replica and
                        # True for every later one.
                        enc_output = self.encoder(X,
                                                  is_training=False,
                                                  reuse=i > 0 or None)

                        prediction, prediction_label = self.beam_search_label(
                            enc_output, X, X_lens, reuse=i > 0 or None)
                        dec_output = self.decoder(dec_input,
                                                  enc_output,
                                                  is_training=False,
                                                  reuse=True)

                        loss = self.test_loss_label(dec_output,
                                                    Y,
                                                    Z,
                                                    reuse=True)

                        return prediction, prediction_label, loss

                    def false_fn():
                        """Graph branch taken for an empty batch."""
                        return tf.zeros([0, 0], dtype=tf.int32), tf.zeros(
                            [0, 0], dtype=tf.int32), 0.0

                    # Only run the model when the batch has at least one row.
                    prediction, prediction_label, loss = tf.cond(
                        tf.greater(tf.shape(X)[0], 0), true_fn, false_fn)

                    loss_sum += loss
                    prediction_list.append(prediction)
                    prediction_label_list.append(prediction_label)

            max_length = tf.reduce_max(
                [tf.shape(pred)[1] for pred in prediction_list])
            max_length_label = tf.reduce_max(
                [tf.shape(pred)[1] for pred in prediction_label_list])

            def pad_to_max_length(tensor, length):
                """Pad the rank-2 tensor with 3 (</S>) to `length` on axis 1."""
                shape = tf.shape(tensor)
                padding = tf.ones([shape[0], length - shape[1]],
                                  dtype=tf.int32) * 3
                return tf.concat([tensor, padding], axis=1)

            # Calculate the prediction of word sequences.
            prediction_list = [
                pad_to_max_length(pred, max_length) for pred in prediction_list
            ]
            self.prediction = tf.concat(prediction_list, axis=0)

            # Calculate the prediction of label sequences.
            prediction_label_list = [
                pad_to_max_length(pred, max_length_label)
                for pred in prediction_label_list
            ]
            self.prediction_label = tf.concat(prediction_label_list, axis=0)

            self.loss_sum = loss_sum

            # Summaries
            tf.summary.scalar('loss_test', self.loss_sum)

            self.saver = tf.train.Saver(var_list=tf.global_variables())
コード例 #5
0
 def inv_bh_feistel(self, left, right):
     """Return ``(shift_right(left ^ right), left)`` for the given halves.

     The XOR of both halves is passed to ``utils.shift_right`` together with
     ``self.beta``, ``self.word_size`` and ``self.modulo``; the original
     ``left`` is returned unchanged as the second element.
     """
     mixed = left ^ right
     shifted = utils.shift_right(mixed, self.beta, self.word_size, self.modulo)
     return shifted, left
コード例 #6
0
 def th_feistel(self, left, right):
     """Return ``(right, (shift_right(left) + right) % self.modulo)``.

     ``left`` is passed to ``utils.shift_right`` together with ``self.alpha``,
     ``self.word_size`` and ``self.modulo``; ``right`` is then added and the
     sum reduced modulo ``self.modulo``.
     """
     shifted = utils.shift_right(left, self.alpha, self.word_size, self.modulo)
     combined = (shifted + right) % self.modulo
     return right, combined