Example #1
    def zero_state(self, batch_size, dtype):
        with tf.variable_scope('init', reuse=self.reuse):
            # Initialization of the read vectors (one per read head)
            read_vector_list = [
                expand(tf.tanh(learned_init(self.memory_vector_dim)),
                       dim=0,
                       N=batch_size) for i in range(self.read_head_num)
            ]
            # Initialization of the addressing weights for the read and write heads
            w_list = [
                expand(tf.nn.softmax(learned_init(self.memory_size)),
                       dim=0,
                       N=batch_size)
                for i in range(self.read_head_num + self.write_head_num)
            ]
            # Initialization of the controller RNN state
            controller_init_state = self.controller.zero_state(
                batch_size, dtype)
            # Initialization of the memory matrix M
            M = expand(tf.get_variable(
                'init_M', [self.memory_size, self.memory_vector_dim],
                initializer=tf.constant_initializer(1e-6)),
                       dim=0,
                       N=batch_size)

            # Defined earlier: NTMControllerState = collections.namedtuple(
            #   'NTMControllerState', ('controller_state', 'read_vector_list',
            #   'w_list', 'M'))
            return NTMControllerState(controller_state=controller_init_state,
                                      read_vector_list=read_vector_list,
                                      w_list=w_list,
                                      M=M)
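
Both this zero_state and the one in Example #2 rely on two helpers, expand and learned_init, that are defined elsewhere in the project. A minimal sketch consistent with how they are called here (an assumption about their behavior, not the project's exact code): expand tiles a learned per-example tensor along a new batch axis, and learned_init produces a trainable vector from a bias-free dense layer applied to a constant input.

import tensorflow as tf  # TF 1.x

def expand(x, dim, N):
    # Tile the learned initial value N times along a new (batch) axis at position `dim`.
    return tf.concat([tf.expand_dims(x, dim) for _ in range(N)], axis=dim)

def learned_init(units):
    # A trainable vector of length `units`: a bias-free dense layer applied to a constant input.
    return tf.squeeze(tf.contrib.layers.fully_connected(
        tf.ones([1, 1]), units, activation_fn=None, biases_initializer=None))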
Example #2
    def zero_state(self, batch_size, dtype):
        with tf.compat.v1.variable_scope('init', reuse=self.reuse):
            read_vector_list = [expand(tf.tanh(learned_init(self.memory_vector_dim)), dim=0, N=batch_size)
                for i in range(self.read_head_num)]

            w_list = [expand(tf.nn.softmax(learned_init(self.memory_size)), dim=0, N=batch_size)
                for i in range(self.read_head_num + self.write_head_num)]

            controller_init_state = self.controller.zero_state(batch_size, dtype)

            if self.init_mode == 'learned':
                M = expand(tf.tanh(
                    tf.reshape(
                        learned_init(self.memory_size * self.memory_vector_dim),
                        [self.memory_size, self.memory_vector_dim])
                    ), dim=0, N=batch_size)
            elif self.init_mode == 'random':
                M = expand(
                    tf.tanh(tf.compat.v1.get_variable('init_M', [self.memory_size, self.memory_vector_dim],
                        initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))),
                    dim=0, N=batch_size)
            elif self.init_mode == 'constant':
                M = expand(
                    tf.compat.v1.get_variable('init_M', [self.memory_size, self.memory_vector_dim],
                        initializer=tf.constant_initializer(1e-6)),
                    dim=0, N=batch_size)

            return NTMControllerState(
                controller_state=controller_init_state,
                read_vector_list=read_vector_list,
                w_list=w_list,
                M=M)
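
For context, a hypothetical usage sketch of zero_state: the returned NTMControllerState is passed to dynamic_rnn as the initial state. The constructor arguments follow the local NTMCell signature shown in Example #5; the hyperparameters and shapes are illustrative, and NTMCell is assumed to be importable from the project.

cell = NTMCell(controller_layers=1, controller_units=100,
               memory_size=128, memory_vector_dim=20,
               read_head_num=1, write_head_num=1,
               addressing_mode='content_and_location', shift_range=1,
               reuse=False, output_dim=8, clip_value=20, init_mode='constant')
inputs = tf.compat.v1.placeholder(tf.float32, shape=[32, None, 9])  # batch=32, variable time, 9-bit vectors
initial_state = cell.zero_state(batch_size=32, dtype=tf.float32)
outputs, final_state = tf.compat.v1.nn.dynamic_rnn(
    cell, inputs, initial_state=initial_state, dtype=tf.float32)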
Example #3
    def _build_model(self):
        if args.mann == 'none':

            def single_cell(num_units):
                return tf.contrib.rnn.BasicLSTMCell(num_units, forget_bias=1.0)

            cell = tf.contrib.rnn.OutputProjectionWrapper(
                tf.contrib.rnn.MultiRNNCell([
                    single_cell(args.num_units) for _ in range(args.num_layers)
                ]),
                args.num_bits_per_vector,
                activation=None)

            initial_state = tuple(
                tf.contrib.rnn.LSTMStateTuple(
                    c=expand(tf.tanh(learned_init(args.num_units)),
                             dim=0,
                             N=args.batch_size),
                    h=expand(tf.tanh(learned_init(args.num_units)),
                             dim=0,
                             N=args.batch_size))
                for _ in range(args.num_layers))

        elif args.mann == 'ntm':
            cell = NTMCell(args.num_layers,
                           args.num_units,
                           args.num_memory_locations,
                           args.memory_size,
                           args.num_read_heads,
                           args.num_write_heads,
                           addressing_mode='content_and_location',
                           shift_range=args.conv_shift_range,
                           reuse=False,
                           output_dim=args.num_bits_per_vector,
                           clip_value=args.clip_value,
                           init_mode=args.init_mode)

            initial_state = cell.zero_state(args.batch_size, tf.float32)

        output_sequence, _ = tf.nn.dynamic_rnn(cell=cell,
                                               inputs=self.inputs,
                                               time_major=False,
                                               initial_state=initial_state)

        if args.task == 'copy':
            self.output_logits = output_sequence[:, self.max_seq_len + 1:, :]
        elif args.task == 'associative_recall':
            self.output_logits = output_sequence[
                :, 3 * (self.max_seq_len + 1) + 2:, :]

        if args.task in ('copy', 'associative_recall'):
            self.outputs = tf.sigmoid(self.output_logits)
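
The slice output_sequence[:, self.max_seq_len + 1:, :] assumes the copy-task input consists of max_seq_len data vectors followed by a single delimiter step, so the model's answer occupies the remaining time steps. A small arithmetic check of that assumed layout:

# Assumed copy-task layout: max_seq_len input steps + 1 delimiter + max_seq_len answer steps.
max_seq_len = 5
total_steps = 2 * max_seq_len + 1   # 11 time steps fed to the RNN
answer_start = max_seq_len + 1      # slice start used above (index 6)
assert total_steps - answer_start == max_seq_len  # the slice keeps exactly the answer steps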
Example #4
    def _build_model(self):
        if args.mann == 'none':

            def single_cell(num_units):
                return tf.contrib.rnn.BasicLSTMCell(num_units, forget_bias=1.0)

            cell = tf.contrib.rnn.OutputProjectionWrapper(
                tf.contrib.rnn.MultiRNNCell([
                    single_cell(args.num_units) for _ in range(args.num_layers)
                ]),
                args.num_bits_per_vector,
                activation=None)

            initial_state = tuple(
                tf.contrib.rnn.LSTMStateTuple(
                    c=expand(tf.tanh(learned_init(args.num_units)),
                             dim=0,
                             N=args.batch_size),
                    h=expand(tf.tanh(learned_init(args.num_units)),
                             dim=0,
                             N=args.batch_size))
                for _ in range(args.num_layers))

        elif args.mann == 'ntm':
            cell = NTMCell(args.num_layers,
                           args.num_units,
                           args.num_memory_locations,
                           args.memory_size,
                           args.num_read_heads,
                           args.num_write_heads,
                           addressing_mode='content_and_location',
                           shift_range=args.conv_shift_range,
                           reuse=False,
                           output_dim=args.num_bits_per_vector,
                           clip_value=args.clip_value,
                           init_mode=args.init_mode)

            initial_state = cell.zero_state(args.batch_size, tf.float32)
        elif args.mann == 'dnc':
            access_config = {
                'memory_size': args.num_memory_locations,
                'word_size': args.memory_size,
                'num_reads': args.num_read_heads,
                'num_writes': args.num_write_heads,
            }
            controller_config = {
                'hidden_size': args.num_units,
            }

            cell = DNC(access_config, controller_config,
                       args.num_bits_per_vector, args.clip_value)
            initial_state = cell.initial_state(args.batch_size)

        output_sequence, _ = tf.nn.dynamic_rnn(cell=cell,
                                               inputs=self.inputs,
                                               time_major=False,
                                               initial_state=initial_state)

        if args.task == 'copy' or args.task == 'repeat_copy':
            self.output_logits = output_sequence[:, self.max_seq_len + 1:, :]
        elif args.task == 'associative_recall':
            self.output_logits = output_sequence[
                :, 3 * (self.max_seq_len + 1) + 2:, :]
        elif args.task in ('traversal', 'shortest_path'):
            self.output_logits = output_sequence[:, -self.max_seq_len:, :]

        if args.task in ('copy', 'repeat_copy', 'associative_recall'):
            self.outputs = tf.sigmoid(self.output_logits)

        if args.task in ('traversal', 'shortest_path'):
            output_logits_split = tf.split(self.output_logits, 9, axis=2)
            self.outputs = tf.concat(
                [tf.nn.softmax(logits) for logits in output_logits_split],
                axis=2)
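
For the graph tasks, tf.split divides each output vector into 9 equal groups and a softmax is applied to each group independently, so the output is treated as 9 categorical distributions per time step. A small shape sketch of that step (batch size, time steps, and group width are assumed values):

import tensorflow as tf

logits = tf.zeros([4, 7, 9 * 3])                               # batch=4, time=7, nine groups of width 3
groups = tf.split(logits, 9, axis=2)                           # nine tensors of shape [4, 7, 3]
probs = tf.concat([tf.nn.softmax(g) for g in groups], axis=2)  # back to [4, 7, 27]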
Example #5
    def _build_model(self):
        if args.mann == 'none':

            def single_cell(num_units):
                return tf.contrib.rnn.BasicLSTMCell(num_units, forget_bias=1.0)

            cell = tf.contrib.rnn.OutputProjectionWrapper(
                tf.contrib.rnn.MultiRNNCell([
                    single_cell(args.num_units) for _ in range(args.num_layers)
                ]),
                args.num_bits_per_vector,
                activation=None)

            initial_state = tuple(
                tf.contrib.rnn.LSTMStateTuple(
                    c=expand(tf.tanh(learned_init(args.num_units)),
                             dim=0,
                             N=args.batch_size),
                    h=expand(tf.tanh(learned_init(args.num_units)),
                             dim=0,
                             N=args.batch_size))
                for _ in range(args.num_layers))

        elif args.mann == 'ntm':
            if args.use_local_impl:
                cell = NTMCell(controller_layers=args.num_layers,
                               controller_units=args.num_units,
                               memory_size=args.num_memory_locations,
                               memory_vector_dim=args.memory_size,
                               read_head_num=args.num_read_heads,
                               write_head_num=args.num_write_heads,
                               addressing_mode='content_and_location',
                               shift_range=args.conv_shift_range,
                               reuse=False,
                               output_dim=args.num_bits_per_vector,
                               clip_value=args.clip_value,
                               init_mode=args.init_mode)
            else:

                def single_cell(num_units):
                    return tf.compat.v1.nn.rnn_cell.BasicLSTMCell(
                        num_units, forget_bias=1.0)

                controller = tf.compat.v1.nn.rnn_cell.MultiRNNCell([
                    single_cell(args.num_units) for _ in range(args.num_layers)
                ])

                cell = NTMCell(controller,
                               args.num_memory_locations,
                               args.memory_size,
                               args.num_read_heads,
                               args.num_write_heads,
                               shift_range=args.conv_shift_range,
                               output_dim=args.num_bits_per_vector,
                               clip_value=args.clip_value)

        output_sequence, _ = tf.compat.v1.nn.dynamic_rnn(
            cell=cell,
            inputs=self.inputs,
            time_major=False,
            dtype=tf.float32,
            initial_state=initial_state if args.mann == 'none' else None)

        task_to_offset = {
            CopyTask.name:
            lambda: CopyTask.offset(self.max_seq_len),
            AssociativeRecallTask.name:
            lambda: AssociativeRecallTask.offset(self.max_seq_len),
            SumTask.name:
            lambda: SumTask.offset(self.max_seq_len),
            AverageSumTask.name:
            lambda: AverageSumTask.offset(self.max_seq_len, args.num_experts),
            MTATask.name:
            lambda: MTATask.offset(self.max_seq_len, args.num_experts,
                                   args.two_tuple_weight_precision,
                                   args.two_tuple_alpha_precision)
        }
        try:
            where_output_begins = task_to_offset[args.task]()
            self.output_logits = output_sequence[:, where_output_begins:, :]
        except KeyError:
            raise UnknownTaskError(
                f'No information on output slicing of model for "{args.task}" task'
            )

        # Intentionally kept as a map so that any new task added to the library fails here
        # with an explicit error message; otherwise the code would fail later, mid-training,
        # with a much less obvious error.
        task_to_activation = {
            CopyTask.name: tf.sigmoid,
            AssociativeRecallTask.name: tf.sigmoid,
            SumTask.name: tf.sigmoid,
            AverageSumTask.name: tf.sigmoid,
            MTATask.name: tf.sigmoid,
        }
        try:
            self.outputs = task_to_activation[args.task](self.output_logits)
        except KeyError:
            raise UnknownTaskError(
                f'No information on activation on model outputs for "{args.task}" task'
            )
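
UnknownTaskError is defined elsewhere in the project; a minimal sketch of that exception and of how a hypothetical new task would be registered (the task name and offset call are illustrative only):

class UnknownTaskError(Exception):
    # Assumed definition: raised when a task has no entry in the offset/activation maps.
    pass

# A hypothetical NewTask must be added to both maps, otherwise _build_model fails fast
# with UnknownTaskError instead of failing mid-training with an unrelated error:
#   task_to_offset[NewTask.name] = lambda: NewTask.offset(self.max_seq_len)
#   task_to_activation[NewTask.name] = tf.sigmoid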