Example #1
def embedding_lookup(t):
    # 'fc', 'is_train', 'embedding_dim', 'reuse', and 'scope' are captured
    # from the enclosing function; this fragment is not standalone.
    if not reuse: log.warning(scope.name)
    _ = fc(t, int(embedding_dim / 4), is_train,
           info=not reuse, name='fc1')
    _ = fc(_, embedding_dim, is_train,
           info=not reuse, name='fc2')
    return _
Example #2
        def Demo_Encoder(s_h, seq_lengths, scope='Demo_Encoder', reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                state_features = tf.reshape(
                    State_Encoder(tf.reshape(s_h, [-1, self.h, self.w, depth]),
                                  self.batch_size * max_demo_len, reuse=reuse),
                    [self.batch_size, max_demo_len, -1])

                # 'i' is defined by surrounding code not shown in this
                # excerpt; it indexes the RNN cell's variable scope.
                with tf.variable_scope('cell_{}'.format(i), reuse=reuse):
                    if self.encoder_rnn_type == 'lstm':
                        cell = rnn.BasicLSTMCell(
                            num_units=self.num_lstm_cell_units,
                            state_is_tuple=True)
                    elif self.encoder_rnn_type == 'rnn':
                        cell = rnn.BasicRNNCell(num_units=self.num_lstm_cell_units)
                    elif self.encoder_rnn_type == 'gru':
                        cell = rnn.GRUCell(num_units=self.num_lstm_cell_units)
                    else:
                        raise ValueError('Unknown encoder rnn type')

                new_h, cell_state = tf.nn.dynamic_rnn(
                    cell=cell, dtype=tf.float32, sequence_length=seq_lengths,
                    inputs=state_features)
                all_states = new_h
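                # note: cell_state.h / cell_state.c below assume an
                # LSTM-style state tuple; the 'rnn' and 'gru' branches
                # return states without .h/.c attributes.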
                return all_states, cell_state.h, cell_state.c
Example #3
def Token_Decoder(f, token_dim, scope='Token_Decoder', reuse=False):
    with tf.variable_scope(scope, reuse=reuse) as scope:
        if not reuse: log.warning(scope.name)
        _ = fc(f, token_dim, is_train,
               info=not reuse, batch_norm=False,
               activation_fn=None, name='fc1')
        return _
Example #4
        def LSTM_Decoder(visual_h, visual_c, gt_tokens, lstm_cell,
                         unroll_type='teacher_forcing',
                         seq_lengths=None, max_sequence_len=10, token_dim=50,
                         embedding_dim=128, init_state=None,
                         scope='LSTM_Decoder', reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                # prepend the start token <s> (index token_dim + 1) and
                # right-shift the ground-truth tokens by one step
                s_token = tf.zeros([self.batch_size, 1],
                                   dtype=gt_tokens.dtype) + token_dim + 1
                gt_tokens = tf.concat([s_token, gt_tokens[:, :-1]], axis=1)

                embedding_lookup = Token_Embedding(token_dim, embedding_dim,
                                                   reuse=reuse)

                # dynamic_decode implementation
                helper = get_DecoderHelper(embedding_lookup, seq_lengths,
                                           token_dim, gt_tokens=gt_tokens,
                                           unroll_type=unroll_type)
                projection_layer = layers_core.Dense(
                    token_dim, use_bias=False, name="output_projection")
                if init_state is None:
                    init_state = rnn.LSTMStateTuple(visual_c, visual_h)
                decoder = seq2seq.BasicDecoder(
                    lstm_cell, helper, init_state,
                    output_layer=projection_layer)
                # pred_length [batch_size]: length of the predicted sequence
                outputs, final_context_state, pred_length = seq2seq.dynamic_decode(
                    decoder, maximum_iterations=max_sequence_len,
                    scope='dynamic_decoder')
                pred_length = tf.expand_dims(pred_length, axis=1)

                # dynamic_decode produces a variable-length output, so we
                # zero-pad it up to the fixed max_sequence_len.
                rnn_output = outputs.rnn_output
                sz = tf.shape(rnn_output)
                dynamic_pad = tf.zeros(
                    [sz[0], max_sequence_len - sz[1], sz[2]],
                    dtype=rnn_output.dtype)
                pred_seq = tf.concat([rnn_output, dynamic_pad], axis=1)
                seq_shape = pred_seq.get_shape().as_list()
                pred_seq.set_shape(
                    [seq_shape[0], max_sequence_len, seq_shape[2]])

                pred_seq = tf.transpose(
                    tf.reshape(pred_seq,
                               [self.batch_size, max_sequence_len, -1]),
                    [0, 2, 1])  # make_dim: [bs, n, len]
                return pred_seq, pred_length, final_context_state
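The padding step above is the subtle part: seq2seq.dynamic_decode stops unrolling at the longest sequence actually produced in the batch, so the output has to be zero-padded back to the fixed max_sequence_len before its static shape can be set. A minimal standalone sketch of the same idea (the function and tensor names here are illustrative, not from the source):

    import tensorflow as tf

    def pad_time_axis(x, max_len):
        # Zero-pad a [batch, t, dim] tensor along axis 1 up to max_len,
        # where t is only known when the graph runs.
        sz = tf.shape(x)  # dynamic shape: [batch, t, dim]
        pad = tf.zeros([sz[0], max_len - sz[1], sz[2]], dtype=x.dtype)
        return tf.concat([x, pad], axis=1)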
Example #5
        def Demo_Encoder(s_h, per, seq_lengths, scope='Demo_Encoder', reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                state_features = tf.reshape(
                    State_Encoder(tf.reshape(s_h, [-1, self.h, self.w, depth]),
                                  tf.reshape(per, [-1, self.per_dim]),
                                  self.batch_size * max_demo_len, reuse=reuse),
                    [self.batch_size, max_demo_len, -1])
                if self.encoder_rnn_type == 'bilstm':
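                    # ceil/floor are no-ops for an integer unit count; they
                    # only matter if num_lstm_cell_units is fractional and
                    # the units are meant to be split across directions.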
                    fcell = rnn.BasicLSTMCell(
                        num_units=math.ceil(self.num_lstm_cell_units),
                        state_is_tuple=True)
                    bcell = rnn.BasicLSTMCell(
                        num_units=math.floor(self.num_lstm_cell_units),
                        state_is_tuple=True)
                    new_h, cell_state = tf.nn.bidirectional_dynamic_rnn(
                        fcell, bcell, state_features,
                        sequence_length=seq_lengths, dtype=tf.float32)
                    new_h = tf.reduce_sum(tf.stack(new_h, axis=2), axis=2)
                    cell_state = rnn.LSTMStateTuple(
                        tf.reduce_sum(tf.stack(
                            [cs.c for cs in cell_state], axis=1), axis=1),
                        tf.reduce_sum(tf.stack(
                            [cs.h for cs in cell_state], axis=1), axis=1))
                elif self.encoder_rnn_type == 'lstm':
                    cell = rnn.BasicLSTMCell(
                        num_units=self.num_lstm_cell_units,
                        state_is_tuple=True)
                    new_h, cell_state = tf.nn.dynamic_rnn(
                        cell=cell, dtype=tf.float32, sequence_length=seq_lengths,
                        inputs=state_features)
                elif self.encoder_rnn_type == 'rnn':
                    cell = rnn.BasicRNNCell(num_units=self.num_lstm_cell_units)
                    new_h, cell_state = tf.nn.dynamic_rnn(
                        cell=cell, dtype=tf.float32, sequence_length=seq_lengths,
                        inputs=state_features)
                elif self.encoder_rnn_type == 'gru':
                    cell = rnn.GRUCell(num_units=self.num_lstm_cell_units)
                    new_h, cell_state = tf.nn.dynamic_rnn(
                        cell=cell, dtype=tf.float32, sequence_length=seq_lengths,
                        inputs=state_features)
                else:
                    raise ValueError('Unknown encoder rnn type')

                if self.concat_state_feature_direct_prediction:
                    all_states = tf.concat([new_h, state_features], axis=-1)
                else:
                    all_states = new_h
                return all_states, cell_state.h, cell_state.c
Example #6
        def Token_Embedding(token_dim, embedding_dim,
                            scope='Token_Embedding', reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                # the map has token_dim + 1 rows so that the extra row can
                # serve as the start token <s>
                embedding_map = tf.get_variable(
                    name="embedding_map", shape=[token_dim + 1, embedding_dim],
                    initializer=tf.random_uniform_initializer(
                        minval=-0.01, maxval=0.01))

                def embedding_lookup(t):
                    embedding = tf.nn.embedding_lookup(embedding_map, t)
                    return embedding
                return embedding_lookup
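For orientation, a minimal sketch of how the returned closure might be used (TF1 graph mode; the placeholder shape and dimensions are assumptions for illustration, with Token_Embedding and log in scope as above):

    tokens = tf.placeholder(tf.int32, [None, 10])  # [batch, seq_len] token ids
    lookup = Token_Embedding(token_dim=50, embedding_dim=128)
    embedded = lookup(tokens)                      # -> [batch, seq_len, 128]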
Example #7
def State_Encoder(s, batch_size, scope='State_Encoder', reuse=False):
    with tf.variable_scope(scope, reuse=reuse) as scope:
        if not reuse: log.warning(scope.name)
        _ = conv2d(s, 16, is_train, k_h=3, k_w=3,
                   info=not reuse, batch_norm=True, name='conv1')
        _ = conv2d(_, 32, is_train, k_h=3, k_w=3,
                   info=not reuse, batch_norm=True, name='conv2')
        _ = conv2d(_, 48, is_train, k_h=3, k_w=3,
                   info=not reuse, batch_norm=True, name='conv3')
        if self.dataset_type == 'vizdoom':
            _ = conv2d(_, 48, is_train, k_h=3, k_w=3,
                       info=not reuse, batch_norm=True, name='conv4')
            _ = conv2d(_, 48, is_train, k_h=3, k_w=3,
                       info=not reuse, batch_norm=True, name='conv5')
        state_feature = tf.reshape(_, [batch_size, -1])
        return state_feature
Example #8
def State_Encoder(s, per, batch_size, scope='State_Encoder', reuse=False):
    with tf.variable_scope(scope, reuse=reuse) as scope:
        if not reuse: log.warning(scope.name)
        _ = conv2d(s, 16, is_train, k_h=3, k_w=3,
                   info=not reuse, batch_norm=True, name='conv1')
        _ = conv2d(_, 32, is_train, k_h=3, k_w=3,
                   info=not reuse, batch_norm=True, name='conv2')
        _ = conv2d(_, 48, is_train, k_h=3, k_w=3,
                   info=not reuse, batch_norm=True, name='conv3')
        if self.pixel_input:
            _ = conv2d(_, 48, is_train, k_h=3, k_w=3,
                       info=not reuse, batch_norm=True, name='conv4')
            _ = conv2d(_, 48, is_train, k_h=3, k_w=3,
                       info=not reuse, batch_norm=True, name='conv5')
        state_feature = tf.reshape(_, [batch_size, -1])
        if self.state_encoder_fc:
            state_feature = fc(state_feature, 512, is_train,
                               info=not reuse, name='fc1')
            state_feature = fc(state_feature, 512, is_train,
                               info=not reuse, name='fc2')
        state_feature = tf.concat([state_feature, per], axis=-1)
        if not reuse:
            log.info('concat feature {}'.format(state_feature))
        return state_feature
Example #9
        def Sequence_Loss(pred_sequence,
                          gt_sequence,
                          pred_sequence_lengths=None,
                          gt_sequence_lengths=None,
                          max_sequence_len=None,
                          token_dim=None,
                          sequence_type='program',
                          name=None):
            with tf.name_scope(name, "SequenceOutput") as scope:
                log.warning(scope)
                max_sequence_lengths = tf.maximum(pred_sequence_lengths,
                                                  gt_sequence_lengths)
                min_sequence_lengths = tf.minimum(pred_sequence_lengths,
                                                  gt_sequence_lengths)
                gt_mask = tf.sequence_mask(gt_sequence_lengths[:, 0],
                                           max_sequence_len,
                                           dtype=tf.float32,
                                           name='mask')
                max_mask = tf.sequence_mask(max_sequence_lengths[:, 0],
                                            max_sequence_len,
                                            dtype=tf.float32,
                                            name='max_mask')
                min_mask = tf.sequence_mask(min_sequence_lengths[:, 0],
                                            max_sequence_len,
                                            dtype=tf.float32,
                                            name='min_mask')
                labels = tf.reshape(
                    tf.transpose(gt_sequence, [0, 2, 1]),
                    [self.batch_size * max_sequence_len, token_dim])
                logits = tf.reshape(
                    tf.transpose(pred_sequence, [0, 2, 1]),
                    [self.batch_size * max_sequence_len, token_dim])
                # per-token cross entropy: [bs * max_sequence_len]
                cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                    labels=labels, logits=logits)
                # normalize loss
                loss = tf.reduce_sum(cross_entropy * tf.reshape(gt_mask, [-1])) / \
                    tf.reduce_sum(gt_mask)
                output = [gt_sequence, pred_sequence]

                label_argmax = tf.argmax(labels, axis=-1)
                logit_argmax = tf.argmax(logits, axis=-1)

                # accuracy
                # token level acc
                correct_token_pred = tf.reduce_sum(
                    tf.to_float(tf.equal(label_argmax, logit_argmax)) *
                    tf.reshape(min_mask, [-1]))
                token_accuracy = correct_token_pred / tf.reduce_sum(max_mask)
                # seq level acc
                seq_equal = tf.equal(
                    tf.reshape(
                        tf.to_float(label_argmax) * tf.reshape(gt_mask, [-1]),
                        [self.batch_size, -1]),
                    tf.reshape(
                        tf.to_float(logit_argmax) * tf.reshape(gt_mask, [-1]),
                        [self.batch_size, -1]))
                len_equal = tf.equal(gt_sequence_lengths[:, 0],
                                     pred_sequence_lengths[:, 0])
                is_same_seq = tf.logical_and(tf.reduce_all(seq_equal, axis=-1),
                                             len_equal)
                seq_accuracy = tf.reduce_sum(
                    tf.to_float(is_same_seq)) / self.batch_size

                pred_tokens = None
                syntax_accuracy = None
                is_correct_syntax = None

                output_stat = SequenceLossOutput(
                    mask=gt_mask,
                    loss=loss,
                    output=output,
                    token_acc=token_accuracy,
                    seq_acc=seq_accuracy,
                    syntax_acc=syntax_accuracy,
                    is_correct_syntax=is_correct_syntax,
                    pred_tokens=pred_tokens,
                    is_same_seq=is_same_seq,
                )

                return output_stat
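The three masks implement a conservative accuracy measure: token matches are counted only inside the shorter of the predicted and ground-truth lengths (min_mask) but normalized by the longer one (max_mask), so both over- and under-shooting the target length reduce token accuracy. For intuition, tf.sequence_mask builds one binary row per sequence (illustrative values):

    # tf.sequence_mask([2, 4], 5) ->
    #   [[True, True, False, False, False],
    #    [True, True, True,  True,  False]]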
Example #10
    def build(self, is_train=True):
        # demo_len = self.demo_len
        if self.stack_subsequent_state:
            max_demo_len = self.max_demo_len - 1
            demo_len = self.demo_len - 1
            s_h = tf.stack(
                [self.s_h[:, :, :max_demo_len, :, :, :],
                 self.s_h[:, :, 1:, :, :, :]],
                axis=-1)
            depth = self.depth * 2
        else:
            max_demo_len = self.max_demo_len
            demo_len = self.demo_len
            s_h = self.s_h
            depth = self.depth

        # s [bs, h, w, depth] -> feature [bs, v]
        # CNN
        def State_Encoder(s,
                          per,
                          batch_size,
                          scope='State_Encoder',
                          reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                _ = conv2d(s,
                           16,
                           is_train,
                           k_h=3,
                           k_w=3,
                           info=not reuse,
                           batch_norm=True,
                           name='conv1')
                _ = conv2d(_,
                           32,
                           is_train,
                           k_h=3,
                           k_w=3,
                           info=not reuse,
                           batch_norm=True,
                           name='conv2')
                _ = conv2d(_,
                           48,
                           is_train,
                           k_h=3,
                           k_w=3,
                           info=not reuse,
                           batch_norm=True,
                           name='conv3')
                if self.pixel_input:
                    _ = conv2d(_,
                               48,
                               is_train,
                               k_h=3,
                               k_w=3,
                               info=not reuse,
                               batch_norm=True,
                               name='conv4')
                    _ = conv2d(_,
                               48,
                               is_train,
                               k_h=3,
                               k_w=3,
                               info=not reuse,
                               batch_norm=True,
                               name='conv5')
                state_feature = tf.reshape(_, [batch_size, -1])
                if self.state_encoder_fc:
                    state_feature = fc(state_feature,
                                       512,
                                       is_train,
                                       info=not reuse,
                                       name='fc1')
                    state_feature = fc(state_feature,
                                       512,
                                       is_train,
                                       info=not reuse,
                                       name='fc2')
                state_feature = tf.concat([state_feature, per], axis=-1)
                if not reuse:
                    log.info('concat feature {}'.format(state_feature))
                return state_feature

        # s_h [bs, t, h, w, depth] -> feature [bs, v]
        # LSTM
        def Demo_Encoder(s_h,
                         per,
                         seq_lengths,
                         scope='Demo_Encoder',
                         reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                state_features = tf.reshape(
                    State_Encoder(tf.reshape(s_h, [-1, self.h, self.w, depth]),
                                  tf.reshape(per, [-1, self.per_dim]),
                                  self.batch_size * max_demo_len,
                                  reuse=reuse),
                    [self.batch_size, max_demo_len, -1])
                if self.encoder_rnn_type == 'bilstm':
                    fcell = rnn.BasicLSTMCell(
                        num_units=math.ceil(self.num_lstm_cell_units),
                        state_is_tuple=True)
                    bcell = rnn.BasicLSTMCell(
                        num_units=math.floor(self.num_lstm_cell_units),
                        state_is_tuple=True)
                    new_h, cell_state = tf.nn.bidirectional_dynamic_rnn(
                        fcell,
                        bcell,
                        state_features,
                        sequence_length=seq_lengths,
                        dtype=tf.float32)
                    new_h = tf.reduce_sum(tf.stack(new_h, axis=2), axis=2)
                    cell_state = rnn.LSTMStateTuple(
                        tf.reduce_sum(tf.stack(
                            [cs.c for cs in cell_state], axis=1), axis=1),
                        tf.reduce_sum(tf.stack(
                            [cs.h for cs in cell_state], axis=1), axis=1))
                elif self.encoder_rnn_type == 'lstm':
                    cell = rnn.BasicLSTMCell(
                        num_units=self.num_lstm_cell_units,
                        state_is_tuple=True)
                    new_h, cell_state = tf.nn.dynamic_rnn(
                        cell=cell,
                        dtype=tf.float32,
                        sequence_length=seq_lengths,
                        inputs=state_features)
                elif self.encoder_rnn_type == 'rnn':
                    cell = rnn.BasicRNNCell(num_units=self.num_lstm_cell_units)
                    new_h, cell_state = tf.nn.dynamic_rnn(
                        cell=cell,
                        dtype=tf.float32,
                        sequence_length=seq_lengths,
                        inputs=state_features)
                elif self.encoder_rnn_type == 'gru':
                    cell = rnn.GRUCell(num_units=self.num_lstm_cell_units)
                    new_h, cell_state = tf.nn.dynamic_rnn(
                        cell=cell,
                        dtype=tf.float32,
                        sequence_length=seq_lengths,
                        inputs=state_features)
                else:
                    raise ValueError('Unknown encoder rnn type')

                if self.concat_state_feature_direct_prediction:
                    all_states = tf.concat([new_h, state_features], axis=-1)
                else:
                    all_states = new_h
                return all_states, cell_state.h, cell_state.c

        # program token [bs, len] -> embedded tokens [len] list of [bs, dim]
        # tensors
        # Embedding
        def Token_Embedding(token_dim,
                            embedding_dim,
                            scope='Token_Embedding',
                            reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                # the map has token_dim + 1 rows so that the extra row can
                # serve as the start token <s>
                embedding_map = tf.get_variable(
                    name="embedding_map",
                    shape=[token_dim + 1, embedding_dim],
                    initializer=tf.random_uniform_initializer(minval=-0.01,
                                                              maxval=0.01))

                def embedding_lookup(t):
                    embedding = tf.nn.embedding_lookup(embedding_map, t)
                    return embedding

                return embedding_lookup

        # program token feature [bs, u] -> program token [bs, dim_program_token]
        # MLP
        def Token_Decoder(f, token_dim, scope='Token_Decoder', reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                _ = fc(f,
                       token_dim,
                       is_train,
                       info=not reuse,
                       batch_norm=False,
                       activation_fn=None,
                       name='fc1')
                return _

        # Input {{{
        # =========
        # test_k list of [bs, ac, max_demo_len - 1] tensor
        self.gt_test_actions_onehot = [
            single_test_a_h for single_test_a_h in tf.unstack(
                tf.transpose(self.test_a_h, [0, 1, 3, 2]), axis=1)
        ]
        # test_k list of [bs, max_demo_len - 1] tensor
        self.gt_test_actions_tokens = [
            single_test_a_h_token
            for single_test_a_h_token in tf.unstack(self.test_a_h_tokens,
                                                    axis=1)
        ]

        # a_h = self.a_h
        # }}}

        # Graph {{{
        # =========
        # Demo -> Demo feature
        demo_h_list = []
        demo_c_list = []
        demo_feature_history_list = []
        for i in range(self.k):
            demo_feature_history, demo_h, demo_c = \
                Demo_Encoder(s_h[:, i], self.per[:, i],
                             demo_len[:, i], reuse=i > 0)
            demo_feature_history_list.append(demo_feature_history)
            demo_h_list.append(demo_h)
            demo_c_list.append(demo_c)
            if i == 0: log.warning(demo_feature_history)
        demo_h_stack = tf.stack(demo_h_list, axis=1)  # [bs, k, v]
        demo_c_stack = tf.stack(demo_c_list, axis=1)  # [bs, k, v]
        if self.demo_aggregation == 'concat':  # [bs, k*v]
            demo_h_summary = tf.reshape(demo_h_stack, [self.batch_size, -1])
            demo_c_summary = tf.reshape(demo_c_stack, [self.batch_size, -1])
        elif self.demo_aggregation == 'avgpool':  # [bs, v]
            demo_h_summary = tf.reduce_mean(demo_h_stack, axis=1)
            demo_c_summary = tf.reduce_mean(demo_c_stack, axis=1)
        elif self.demo_aggregation == 'maxpool':  # [bs, v]
            demo_h_summary = tf.squeeze(
                tf.layers.max_pooling1d(demo_h_stack,
                                        demo_h_stack.get_shape().as_list()[1],
                                        1, padding='valid',
                                        data_format='channels_last'),
                axis=1)
            demo_c_summary = tf.squeeze(
                tf.layers.max_pooling1d(demo_c_stack,
                                        demo_c_stack.get_shape().as_list()[1],
                                        1, padding='valid',
                                        data_format='channels_last'),
                axis=1)
        else:
            raise ValueError('Unknown demo aggregation type')

        def get_DecoderHelper(embedding_lookup,
                              seq_lengths,
                              token_dim,
                              gt_tokens=None,
                              unroll_type='teacher_forcing'):
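            # Maps an unroll type onto the corresponding tf.contrib.seq2seq
            # helper: teacher forcing always feeds ground-truth tokens,
            # scheduled sampling mixes ground truth with sampled predictions,
            # and greedy feeds back the argmax prediction at each step.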
            if unroll_type == 'teacher_forcing':
                if gt_tokens is None:
                    raise ValueError('teacher_forcing requires gt_tokens')
                embedding = embedding_lookup(gt_tokens)
                helper = seq2seq.TrainingHelper(embedding, seq_lengths)
            elif unroll_type == 'scheduled_sampling':
                if gt_tokens is None:
                    raise ValueError('scheduled_sampling requires gt_tokens')
                embedding = embedding_lookup(gt_tokens)
                # sample_prob 1.0: always sample from ground truth
                # sample_prob 0.0: always sample from prediction
                helper = seq2seq.ScheduledEmbeddingTrainingHelper(
                    embedding,
                    seq_lengths,
                    embedding_lookup,
                    1.0 - self.sample_prob,
                    seed=None,
                    scheduling_seed=None)
            elif unroll_type == 'greedy':
                # during evaluation, we perform greedy unrolling.
                start_token = tf.zeros([self.batch_size],
                                       dtype=tf.int32) + token_dim
                end_token = token_dim - 1
                helper = seq2seq.GreedyEmbeddingHelper(embedding_lookup,
                                                       start_token, end_token)
            else:
                raise ValueError('Unknown unroll type')
            return helper

        def LSTM_Decoder(visual_h,
                         visual_c,
                         gt_tokens,
                         lstm_cell,
                         unroll_type='teacher_forcing',
                         seq_lengths=None,
                         max_sequence_len=10,
                         token_dim=50,
                         embedding_dim=128,
                         init_state=None,
                         scope='LSTM_Decoder',
                         reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                # prepend the start token <s> (index token_dim + 1) and
                # right-shift the ground-truth tokens by one step
                s_token = tf.zeros([self.batch_size, 1],
                                   dtype=gt_tokens.dtype) + token_dim + 1
                gt_tokens = tf.concat([s_token, gt_tokens[:, :-1]], axis=1)

                embedding_lookup = Token_Embedding(token_dim,
                                                   embedding_dim,
                                                   reuse=reuse)

                # dynamic_decode implementation
                helper = get_DecoderHelper(embedding_lookup,
                                           seq_lengths,
                                           token_dim,
                                           gt_tokens=gt_tokens,
                                           unroll_type=unroll_type)
                projection_layer = layers_core.Dense(token_dim,
                                                     use_bias=False,
                                                     name="output_projection")
                if init_state is None:
                    init_state = rnn.LSTMStateTuple(visual_c, visual_h)
                decoder = seq2seq.BasicDecoder(lstm_cell,
                                               helper,
                                               init_state,
                                               output_layer=projection_layer)
                # pred_length [batch_size]: length of the predicted sequence
                outputs, final_context_state, pred_length = seq2seq.dynamic_decode(
                    decoder,
                    maximum_iterations=max_sequence_len,
                    scope='dynamic_decoder')
                pred_length = tf.expand_dims(pred_length, axis=1)

                # dynamic_decode produces a variable-length output, so we
                # zero-pad it up to the fixed max_sequence_len.
                rnn_output = outputs.rnn_output
                sz = tf.shape(rnn_output)
                dynamic_pad = tf.zeros(
                    [sz[0], max_sequence_len - sz[1], sz[2]],
                    dtype=rnn_output.dtype)
                pred_seq = tf.concat([rnn_output, dynamic_pad], axis=1)
                seq_shape = pred_seq.get_shape().as_list()
                pred_seq.set_shape(
                    [seq_shape[0], max_sequence_len, seq_shape[2]])

                pred_seq = tf.transpose(
                    tf.reshape(pred_seq,
                               [self.batch_size, max_sequence_len, -1]),
                    [0, 2, 1])  # make_dim: [bs, n, len]
                return pred_seq, pred_length, final_context_state

        if self.scheduled_sampling:
            train_unroll_type = 'scheduled_sampling'
        else:
            train_unroll_type = 'teacher_forcing'

        # Attention: one mechanism per (test demo, seen demo) pair over the
        # demo feature history, with variables shared across pairs via reuse.
        lstm_cell = rnn.BasicLSTMCell(num_units=self.num_lstm_cell_units)
        attn_mechanisms = []
        for j in range(self.test_k):
            attn_mechanisms_k = []
            for i in range(self.k):
                with tf.variable_scope('AttnMechanism', reuse=i > 0 or j > 0):
                    if self.attn_type == 'luong':
                        attn_mechanism = seq2seq.LuongAttention(
                            self.num_lstm_cell_units,
                            demo_feature_history_list[i],
                            memory_sequence_length=self.demo_len[:, i])
                    elif self.attn_type == 'luong_monotonic':
                        attn_mechanism = seq2seq.LuongMonotonicAttention(
                            self.num_lstm_cell_units,
                            demo_feature_history_list[i],
                            memory_sequence_length=self.demo_len[:, i])
                    else:
                        raise ValueError('Unknown attention type')
                attn_mechanisms_k.append(attn_mechanism)
            attn_mechanisms.append(attn_mechanisms_k)

        self.attn_cells = []
        for i in range(self.test_k):
            attn_cell = PoolingAttentionWrapper(
                lstm_cell,
                attn_mechanisms[i],
                attention_layer_size=self.num_lstm_cell_units,
                alignment_history=True,
                output_attention=True,
                pooling='avgpool')
            self.attn_cells.append(attn_cell)

        # Demo + current state -> action
        self.pred_action_list = []
        self.greedy_pred_action_list = []
        self.greedy_pred_action_len_list = []
        for i in range(self.test_k):
            attn_init_state = self.attn_cells[i].zero_state(
                self.batch_size,
                dtype=tf.float32).clone(cell_state=rnn.LSTMStateTuple(
                    demo_h_summary, demo_c_summary))
            embedding_dim = demo_h_summary.get_shape().as_list()[-1]
            pred_action, pred_action_len, action_state = LSTM_Decoder(
                demo_h_summary,
                demo_c_summary,
                self.gt_test_actions_tokens[i],
                self.attn_cells[i],
                unroll_type=train_unroll_type,
                seq_lengths=self.test_action_len[:, i],
                max_sequence_len=self.max_action_len,
                token_dim=self.action_space,
                embedding_dim=embedding_dim,
                init_state=attn_init_state,
                scope='Manipulation',
                reuse=i > 0)
            assert pred_action.get_shape() == \
                self.gt_test_actions_onehot[i].get_shape()
            self.pred_action_list.append(pred_action)

            greedy_attn_init_state = self.attn_cells[i].zero_state(
                self.batch_size,
                dtype=tf.float32).clone(cell_state=rnn.LSTMStateTuple(
                    demo_h_summary, demo_c_summary))
            greedy_pred_action, greedy_pred_action_len, \
                greedy_action_state = LSTM_Decoder(
                    demo_h_summary, demo_c_summary, self.gt_test_actions_tokens[i],
                    self.attn_cells[i], unroll_type='greedy',
                    seq_lengths=self.test_action_len[:, i],
                    max_sequence_len=self.max_action_len,
                    token_dim=self.action_space,
                    embedding_dim=embedding_dim,
                    init_state=greedy_attn_init_state,
                    scope='Manipulation', reuse=True
                )
            assert greedy_pred_action.get_shape() == \
                self.gt_test_actions_onehot[i].get_shape()
            self.greedy_pred_action_list.append(greedy_pred_action)
            self.greedy_pred_action_len_list.append(greedy_pred_action_len)
        # }}}

        # Build losses {{{
        # ================
        def Sequence_Loss(pred_sequence,
                          gt_sequence,
                          pred_sequence_lengths=None,
                          gt_sequence_lengths=None,
                          max_sequence_len=None,
                          token_dim=None,
                          sequence_type='program',
                          name=None):
            with tf.name_scope(name, "SequenceOutput") as scope:
                log.warning(scope)
                max_sequence_lengths = tf.maximum(pred_sequence_lengths,
                                                  gt_sequence_lengths)
                min_sequence_lengths = tf.minimum(pred_sequence_lengths,
                                                  gt_sequence_lengths)
                gt_mask = tf.sequence_mask(gt_sequence_lengths[:, 0],
                                           max_sequence_len,
                                           dtype=tf.float32,
                                           name='mask')
                max_mask = tf.sequence_mask(max_sequence_lengths[:, 0],
                                            max_sequence_len,
                                            dtype=tf.float32,
                                            name='max_mask')
                min_mask = tf.sequence_mask(min_sequence_lengths[:, 0],
                                            max_sequence_len,
                                            dtype=tf.float32,
                                            name='min_mask')
                labels = tf.reshape(
                    tf.transpose(gt_sequence, [0, 2, 1]),
                    [self.batch_size * max_sequence_len, token_dim])
                logits = tf.reshape(
                    tf.transpose(pred_sequence, [0, 2, 1]),
                    [self.batch_size * max_sequence_len, token_dim])
                # per-token cross entropy: [bs * max_sequence_len]
                cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                    labels=labels, logits=logits)
                # normalize loss
                loss = tf.reduce_sum(cross_entropy * tf.reshape(gt_mask, [-1])) / \
                    tf.reduce_sum(gt_mask)
                output = [gt_sequence, pred_sequence]

                label_argmax = tf.argmax(labels, axis=-1)
                logit_argmax = tf.argmax(logits, axis=-1)

                # accuracy
                # token level acc
                correct_token_pred = tf.reduce_sum(
                    tf.to_float(tf.equal(label_argmax, logit_argmax)) *
                    tf.reshape(min_mask, [-1]))
                token_accuracy = correct_token_pred / tf.reduce_sum(max_mask)
                # seq level acc
                seq_equal = tf.equal(
                    tf.reshape(
                        tf.to_float(label_argmax) * tf.reshape(gt_mask, [-1]),
                        [self.batch_size, -1]),
                    tf.reshape(
                        tf.to_float(logit_argmax) * tf.reshape(gt_mask, [-1]),
                        [self.batch_size, -1]))
                len_equal = tf.equal(gt_sequence_lengths[:, 0],
                                     pred_sequence_lengths[:, 0])
                is_same_seq = tf.logical_and(tf.reduce_all(seq_equal, axis=-1),
                                             len_equal)
                seq_accuracy = tf.reduce_sum(
                    tf.to_float(is_same_seq)) / self.batch_size

                pred_tokens = None
                syntax_accuracy = None
                is_correct_syntax = None

                output_stat = SequenceLossOutput(
                    mask=gt_mask,
                    loss=loss,
                    output=output,
                    token_acc=token_accuracy,
                    seq_acc=seq_accuracy,
                    syntax_acc=syntax_accuracy,
                    is_correct_syntax=is_correct_syntax,
                    pred_tokens=pred_tokens,
                    is_same_seq=is_same_seq,
                )

                return output_stat

        self.loss = 0
        self.output = []

        # Manipulation network loss
        avg_action_loss = 0
        avg_action_token_acc = 0
        avg_action_seq_acc = 0
        seq_match = []
        for i in range(self.test_k):
            action_stat = Sequence_Loss(
                self.pred_action_list[i],
                self.gt_test_actions_onehot[i],
                pred_sequence_lengths=tf.expand_dims(self.test_action_len[:, i],
                                                     axis=1),
                gt_sequence_lengths=tf.expand_dims(self.test_action_len[:, i],
                                                   axis=1),
                max_sequence_len=self.max_action_len,
                token_dim=self.action_space,
                sequence_type='action',
                name="Action_Sequence_Loss_{}".format(i))
            avg_action_loss += action_stat.loss
            avg_action_token_acc += action_stat.token_acc
            avg_action_seq_acc += action_stat.seq_acc
            seq_match.append(action_stat.is_same_seq)
            self.output.extend(action_stat.output)
        avg_action_loss /= self.test_k
        avg_action_token_acc /= self.test_k
        avg_action_seq_acc /= self.test_k
        avg_action_seq_all_acc = tf.reduce_sum(
            tf.to_float(tf.reduce_all(tf.stack(seq_match, axis=1),
                                      axis=-1))) / self.batch_size
        self.loss += avg_action_loss

        greedy_avg_action_loss = 0
        greedy_avg_action_token_acc = 0
        greedy_avg_action_seq_acc = 0
        greedy_seq_match = []
        for i in range(self.test_k):
            greedy_action_stat = Sequence_Loss(
                self.greedy_pred_action_list[i],
                self.gt_test_actions_onehot[i],
                pred_sequence_lengths=self.greedy_pred_action_len_list[i],
                gt_sequence_lengths=tf.expand_dims(self.test_action_len[:, i],
                                                   axis=1),
                max_sequence_len=self.max_action_len,
                token_dim=self.action_space,
                sequence_type='action',
                name="Greedy_Action_Sequence_Loss_{}".format(i))
            greedy_avg_action_loss += greedy_action_stat.loss
            greedy_avg_action_token_acc += greedy_action_stat.token_acc
            greedy_avg_action_seq_acc += greedy_action_stat.seq_acc
            greedy_seq_match.append(greedy_action_stat.is_same_seq)
        greedy_avg_action_loss /= self.test_k
        greedy_avg_action_token_acc /= self.test_k
        greedy_avg_action_seq_acc /= self.test_k
        greedy_avg_action_seq_all_acc = tf.reduce_sum(
            tf.to_float(
                tf.reduce_all(tf.stack(greedy_seq_match, axis=1),
                              axis=-1))) / self.batch_size
        # }}}

        # Evaluation {{{
        # ==============
        self.report_loss = {}
        self.report_accuracy = {}
        self.report_hist = {}
        self.report_loss['avg_action_loss'] = avg_action_loss
        self.report_accuracy['avg_action_token_acc'] = avg_action_token_acc
        self.report_accuracy['avg_action_seq_acc'] = avg_action_seq_acc
        self.report_accuracy['avg_action_seq_all_acc'] = avg_action_seq_all_acc
        self.report_loss['greedy_avg_action_loss'] = greedy_avg_action_loss
        self.report_accuracy['greedy_avg_action_token_acc'] = \
            greedy_avg_action_token_acc
        self.report_accuracy['greedy_avg_action_seq_acc'] = \
            greedy_avg_action_seq_acc
        self.report_accuracy['greedy_avg_action_seq_all_acc'] = \
            greedy_avg_action_seq_all_acc
        self.report_output = []
        # dummy fetch values for evaler
        self.ground_truth_program = self.program
        self.pred_program = []
        self.greedy_pred_program = []
        self.greedy_pred_program_len = []
        self.greedy_program_is_correct_syntax = []
        self.program_is_correct_syntax = []
        self.program_num_execution_correct = []
        self.program_is_correct_execution = []
        self.greedy_num_execution_correct = []
        self.greedy_is_correct_execution = []

        # TensorBoard Summary {{{
        # =======================
        # Loss
        def train_test_scalar_summary(name, value):
            tf.summary.scalar(name, value, collections=['train'])
            tf.summary.scalar("test_{}".format(name),
                              value,
                              collections=['test'])

        train_test_scalar_summary("loss/loss", self.loss)

        if self.scheduled_sampling:
            train_test_scalar_summary("loss/sample_prob", self.sample_prob)
        train_test_scalar_summary("loss/avg_action_loss", avg_action_loss)
        train_test_scalar_summary("loss/avg_action_token_acc",
                                  avg_action_token_acc)
        train_test_scalar_summary("loss/avg_action_seq_acc",
                                  avg_action_seq_acc)
        train_test_scalar_summary("loss/avg_action_seq_all_acc",
                                  avg_action_seq_all_acc)
        tf.summary.scalar("test_loss/greedy_avg_action_loss",
                          greedy_avg_action_loss,
                          collections=['test'])
        tf.summary.scalar("test_loss/greedy_avg_action_token_acc",
                          greedy_avg_action_token_acc,
                          collections=['test'])
        tf.summary.scalar("test_loss/greedy_avg_action_seq_acc",
                          greedy_avg_action_seq_acc,
                          collections=['test'])
        tf.summary.scalar("test_loss/greedy_avg_action_seq_all_acc",
                          greedy_avg_action_seq_all_acc,
                          collections=['test'])

        def program2str(p_token, p_len):
            program_str = []
            for i in range(self.batch_size):
                program_str.append(
                    self.vocab.intseq2str(
                        np.argmax(p_token[i], axis=0)[:p_len[i, 0]]))
            program_str = np.stack(program_str, axis=0)
            return program_str

        tf.summary.text('program_id/id',
                        self.program_id,
                        collections=['train'])
        tf.summary.text('program/ground_truth',
                        tf.py_func(program2str,
                                   [self.program, self.program_len],
                                   tf.string),
                        collections=['train'])
        tf.summary.text('test_program_id/id',
                        self.program_id,
                        collections=['test'])
        tf.summary.text('test_program/ground_truth',
                        tf.py_func(program2str,
                                   [self.program, self.program_len],
                                   tf.string),
                        collections=['test'])

        # Visualization
        def visualized_map(pred, gt):
            dummy = tf.expand_dims(tf.zeros_like(pred), axis=-1)
            pred = tf.expand_dims(tf.nn.softmax(pred, dim=1), axis=-1)
            gt = tf.expand_dims(gt, axis=-1)
            return tf.concat([pred, gt, dummy], axis=-1)

        # Attention visualization
        def build_alignments(alignment_history):
            alignments = []
            for i in alignment_history:
                align = tf.expand_dims(tf.transpose(i.stack(), [1, 2, 0]),
                                       -1) * 255
                align_shape = tf.shape(align)
                alignments.append(align)
                alignments.append(
                    tf.zeros([align_shape[0], 1, align_shape[2], 1],
                             dtype=tf.float32) + 255)
            alignments_image = tf.reshape(
                tf.tile(tf.concat(alignments, axis=1), [1, 1, 1, self.k]),
                [align_shape[0], -1, align_shape[2] * self.k, 1])
            return alignments_image

        alignments = build_alignments(action_state.alignment_history)
        tf.summary.image("attn", alignments, collections=['train'])
        tf.summary.image("test_attn", alignments, collections=['test'])

        greedy_alignments = build_alignments(
            greedy_action_state.alignment_history)
        tf.summary.image("test_greedy_attn",
                         greedy_alignments,
                         collections=['test'])

        if self.pixel_input:
            tf.summary.image("state/initial_state",
                             self.s_h[:, 0, 0, :, :, :],
                             collections=['train'])
            tf.summary.image("state/demo_program_1",
                             self.s_h[0, 0, :, :, :, :],
                             max_outputs=self.max_demo_len,
                             collections=['train'])

        i = 0  # show only the first demo (among k)
        tf.summary.image("visualized_action/k_{}".format(i),
                         visualized_map(self.pred_action_list[i],
                                        self.gt_test_actions_onehot[i]),
                         collections=['train'])
        tf.summary.image("test_visualized_action/k_{}".format(i),
                         visualized_map(self.pred_action_list[i],
                                        self.gt_test_actions_onehot[i]),
                         collections=['test'])
        tf.summary.image("test_visualized_greedy_action/k_{}".format(i),
                         visualized_map(self.greedy_pred_action_list[i],
                                        self.gt_test_actions_onehot[i]),
                         collections=['test'])

        # Visualize demo features
        if self.debug:
            i = 0  # show only the first demo's features
            tf.summary.image("debug/demo_feature_history/k_{}".format(i),
                             tf.image.grayscale_to_rgb(
                                 tf.expand_dims(demo_feature_history_list[i],
                                                -1)),
                             collections=['train'])
        # }}}
        print('\033[93mSuccessfully loaded the model.\033[0m')
Example #11
def main():
    import argparse
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--debug',
                        action='store_true',
                        default=False,
                        help='set to True to see debugging visualization')
    parser.add_argument('--prefix',
                        type=str,
                        default='default',
                        help='a nickname for the training')
    parser.add_argument('--model',
                        type=str,
                        default='synthesis_baseline',
                        choices=[
                            'synthesis_baseline', 'induction_baseline',
                            'summarizer', 'full'
                        ],
                        help='specify which type of models to train')
    parser.add_argument('--dataset_type',
                        type=str,
                        default='karel',
                        choices=['karel', 'vizdoom'])
    parser.add_argument('--dataset_path',
                        type=str,
                        default='datasets/karel_dataset',
                        help='the path to your dataset')
    parser.add_argument('--checkpoint',
                        type=str,
                        default=None,
                        help='specify the path to a pre-trained checkpoint')
    # log
    parser.add_argument('--log_step',
                        type=int,
                        default=10,
                        help='the frequency of outputting log info')
    parser.add_argument('--write_summary_step',
                        type=int,
                        default=100,
                        help='the frequency of writing TensorBoard summaries')
    parser.add_argument('--test_sample_step',
                        type=int,
                        default=100,
                        help='the frequency of performing '
                        'testing inference during training')
    # hyperparameters
    parser.add_argument('--num_k',
                        type=int,
                        default=10,
                        help='the number of seen demonstrations')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--lr_weight_decay',
                        action='store_true',
                        default=False,
                        help='set to `True` to perform exponential weight '
                        'decay on the learning rate')
    parser.add_argument(
        '--scheduled_sampling',
        action='store_true',
        default=False,
        help='set to True to train models with scheduled sampling')
    parser.add_argument('--scheduled_sampling_decay_steps',
                        type=int,
                        default=20000,
                        help='the number of training steps required to decay '
                        'scheduled sampling probability to its minimum.')
    # model hyperparameters
    parser.add_argument('--encoder_rnn_type',
                        default='lstm',
                        choices=['lstm', 'rnn', 'gru'])
    parser.add_argument('--num_lstm_cell_units', type=int, default=512)
    parser.add_argument('--demo_aggregation',
                        type=str,
                        default='avgpool',
                        choices=['concat', 'avgpool', 'maxpool'],
                        help='how to aggregate the demo features')

    config = parser.parse_args()

    if config.dataset_type == 'karel':
        import karel_env.dataset_karel as dataset
        dataset_train, dataset_test, dataset_val \
            = dataset.create_default_splits(config.dataset_path, num_k=config.num_k)
    elif config.dataset_type == 'vizdoom':
        import vizdoom_env.dataset_vizdoom as dataset
        dataset_train, dataset_test, dataset_val \
            = dataset.create_default_splits(config.dataset_path, num_k=config.num_k)
    else:
        raise ValueError(config.dataset_type)

    # Set data dimension in configuration
    data_tuple = dataset_train.get_data(dataset_train.ids[0])
    # s_h: state history, demonstrations
    # a_h: action history, sequence of actions
    # per: sequence of perception primitives
    program, _, s_h, test_s_h, a_h, _, _, _, program_len, demo_len, test_demo_len, \
        per, test_per = data_tuple[:13]

    config.dim_program_token = np.asarray(program.shape)[0]
    config.max_program_len = np.asarray(program.shape)[1]
    config.k = np.asarray(s_h.shape)[0]
    config.test_k = np.asarray(test_s_h.shape)[0]
    config.max_demo_len = np.asarray(s_h.shape)[1]
    config.h = np.asarray(s_h.shape)[2]
    config.w = np.asarray(s_h.shape)[3]
    config.depth = np.asarray(s_h.shape)[4]
    config.action_space = np.asarray(a_h.shape)[2]
    config.per_dim = np.asarray(per.shape)[2]
    if config.dataset_type == 'karel':
        config.dsl_type = dataset_train.dsl_type
        config.env_type = dataset_train.env_type
        config.vizdoom_pos_keys = []
        config.vizdoom_max_init_pos_len = -1
        config.perception_type = ''
        config.level = None
    elif config.dataset_type == 'vizdoom':
        config.dsl_type = 'vizdoom_default'  # vizdoom has 1 dsl type for now
        config.env_type = 'vizdoom_default'  # vizdoom has 1 env type
        config.vizdoom_pos_keys = dataset_train.vizdoom_pos_keys
        config.vizdoom_max_init_pos_len = dataset_train.vizdoom_max_init_pos_len
        config.perception_type = dataset_train.perception_type
        config.level = dataset_train.level

    trainer = Trainer(config, dataset_train, dataset_test)

    log.warning("dataset: %s, learning_rate: %f", config.dataset_path,
                config.learning_rate)
    trainer.train()
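For reference, a typical invocation of the flags defined above might look like the following (the entry-point script name is an assumption, not from the source):

    # python trainer.py --model full --dataset_type karel \
    #     --dataset_path datasets/karel_dataset --num_k 10 --batch_size 32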
Example #12
    def eval_run(self):
        # load checkpoint
        if self.checkpoint:
            self.saver.restore(self.session, self.checkpoint)
            log.info("Loaded from checkpoint!")

        log.infov("Start Inference and Evaluation")

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(self.session,
                                               coord=coord,
                                               start=True)
        try:
            if self.config.pred_program:
                if not os.path.exists(self.output_dir):
                    os.makedirs(self.output_dir)
                log.infov("Output Dir: %s", self.output_dir)
                base_name = os.path.join(
                    self.output_dir,
                    'out_{}_{}'.format(self.checkpoint_name,
                                       self.dataset_split))
                text_file = open('{}.txt'.format(base_name), 'w')
                from karel_env.dsl import get_KarelDSL
                dsl = get_KarelDSL(dsl_type=self.dataset.dsl_type, seed=123)

                hdf5_file = h5py.File('{}.hdf5'.format(base_name), 'w')
                log_file = open('{}.log'.format(base_name), 'w')
            else:
                log_file = None

            if self.config.result_data:
                result_file = h5py.File(self.config.result_data_path, 'w')
                data_file = h5py.File(
                    os.path.join(self.config.dataset_path, 'data.hdf5'), 'r')

            if not self.config.no_loss:
                loss_all = []
                acc_all = []
                hist_all = {}
                time_all = []
                for s in xrange(self.config.max_steps):
                    step, loss, acc, hist, \
                        pred_program, pred_program_len, pred_is_correct_syntax, \
                        greedy_pred_program, greedy_program_len, greedy_is_correct_syntax, \
                        gt_program, gt_program_len, output, program_id, \
                        program_num_execution_correct, program_is_correct_execution, \
                        greedy_num_execution_correct, greedy_is_correct_execution, \
                        step_time = self.run_single_step(self.batch)
                    if not self.config.quiet:
                        step_msg = self.log_step_message(
                            s, loss, acc, hist, step_time)
                    else:
                        step_msg = ''  # keep defined for the log_file write below
                    if self.config.result_data:
                        for i in range(len(program_id)):
                            try:
                                grp = result_file.create_group(program_id[i])
                                grp['program'] = gt_program[i]
                                grp['pred_program'] = greedy_pred_program[i]
                                grp['pred_program_len'] = greedy_program_len[
                                    i][0]
                                grp['s_h'] = data_file[
                                    program_id[i]]['s_h'][()]
                                grp['test_s_h'] = data_file[
                                    program_id[i]]['test_s_h'][()]
                            except ValueError:
                                # h5py raises ValueError when a group with
                                # this name already exists
                                print('Duplicates: {}'.format(program_id[i]))

                    # write pred/gt program
                    if self.config.pred_program:
                        log_file.write('{}\n'.format(step_msg))
                        for i in range(self.batch_size):
                            pred_program_token = np.argmax(
                                pred_program[i, :, :pred_program_len[i, 0]],
                                axis=0)
                            pred_program_str = dsl.intseq2str(
                                pred_program_token)
                            greedy_program_token = np.argmax(
                                greedy_pred_program[i, :, :greedy_program_len[
                                    i, 0]],
                                axis=0)
                            greedy_program_str = dsl.intseq2str(
                                greedy_program_token)
                            try:
                                grp = hdf5_file.create_group(program_id[i])
                            except ValueError:
                                # group already exists; skip duplicates
                                pass
                            else:
                                correctness = ['wrong', 'correct']
                                grp['program_prediction'] = pred_program_str
                                grp['program_syntax'] = \
                                    correctness[int(pred_is_correct_syntax[i])]
                                grp['program_num_execution_correct'] = \
                                    int(program_num_execution_correct[i])
                                grp['program_is_correct_execution'] = \
                                    program_is_correct_execution[i]
                                grp['greedy_prediction'] = \
                                    greedy_program_str
                                grp['greedy_syntax'] = \
                                    correctness[int(greedy_is_correct_syntax[i])]
                                grp['greedy_num_execution_correct'] = \
                                    int(greedy_num_execution_correct[i])
                                grp['greedy_is_correct_execution'] = \
                                    greedy_is_correct_execution[i]

                            text_file.write(
                                '[id: {}]\ngt: {}\npred{}: {}\ngreedy{}: {}\n'.
                                format(
                                    program_id[i],
                                    dsl.intseq2str(
                                        np.argmax(gt_program[
                                            i, :, :gt_program_len[i, 0]],
                                                  axis=0)),
                                    '(error)'
                                    if pred_is_correct_syntax[i] == 0 else '',
                                    pred_program_str,
                                    '(error)' if greedy_is_correct_syntax[i]
                                    == 0 else '',
                                    greedy_program_str,
                                ))
                    loss_all.append(np.array(list(loss.values())))
                    acc_all.append(np.array(list(acc.values())))
                    time_all.append(step_time)
                    for hist_key, hist_value in hist.items():
                        if hist_key not in hist_all:
                            hist_all[hist_key] = []
                        hist_all[hist_key].append(hist_value)

                loss_avg = np.average(np.stack(loss_all), axis=0)
                acc_avg = np.average(np.stack(acc_all), axis=0)
                hist_avg = {}
                for hist_key, hist_values in hist_all.items():
                    hist_avg[hist_key] = np.average(np.stack(hist_values),
                                                    axis=0)
                final_msg = self.log_final_message(
                    loss_avg,
                    loss.keys(),
                    acc_avg,
                    acc.keys(),
                    hist_avg,
                    hist_avg.keys(),
                    np.sum(time_all),
                    write_summary=self.config.write_summary,
                    summary_file=self.config.summary_file)

            if self.config.result_data:
                result_file.close()
                data_file.close()

            if self.config.pred_program:
                log_file.write('{}\n'.format(final_msg))
                log_file.write("Model class: {}\n".format(self.config.model))
                log_file.write("Checkpoint: {}\n".format(self.checkpoint))
                log_file.write("Dataset: {}\n".format(
                    self.config.dataset_path))
                log_file.close()
                text_file.close()
                hdf5_file.close()

        except Exception as e:
            coord.request_stop(e)

        log.warning('Completed Evaluation.')

        coord.request_stop()
        try:
            coord.join(threads, stop_grace_period_secs=3)
        except RuntimeError as e:
            log.warning(str(e))
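
The HDF5 file written by eval_run stores one group per program id with the fields assigned in the prediction loop above. A minimal read-back sketch (the file name here is hypothetical; eval_run derives it from the checkpoint name and data split):

import h5py

with h5py.File('out_model-10000_test.hdf5', 'r') as f:
    for program_id in list(f.keys())[:3]:
        grp = f[program_id]
        print(program_id,
              grp['program_prediction'][()],
              grp['greedy_prediction'][()],
              grp['greedy_syntax'][()])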
Example #13
def main():
    import argparse
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--model',
                        type=str,
                        default='synthesis_baseline',
                        choices=[
                            'synthesis_baseline', 'induction_baseline',
                            'summarizer', 'full'
                        ],
                        help='specify which type of models to evaluate')
    parser.add_argument('--dataset_type',
                        type=str,
                        default='karel',
                        choices=['karel', 'vizdoom'])
    parser.add_argument('--dataset_path',
                        type=str,
                        default='datasets/karel_dataset',
                        help='the path to your dataset')
    parser.add_argument('--dataset_split',
                        type=str,
                        default='test',
                        choices=['train', 'test', 'val'],
                        help='specify the data split to evaluate')
    parser.add_argument('--checkpoint',
                        type=str,
                        default='',
                        help='the path to a trained checkpoint')
    parser.add_argument('--train_dir',
                        type=str,
                        default='',
                        help='the path to train_dir. '
                        'the newest checkpoint will be evaluated')
    parser.add_argument('--output_dir',
                        type=str,
                        default=None,
                        help='the directory to write out programs')
    parser.add_argument('--max_steps',
                        type=int,
                        default=0,
                        help='the number of batches to evaluate. '
                        'set to 0 to evaluate all testing data')
    # hyperparameters
    parser.add_argument('--num_k',
                        type=int,
                        default=10,
                        help='the number of seen demonstrations')
    parser.add_argument('--batch_size', type=int, default=20)
    # model hyperparameters
    parser.add_argument('--encoder_rnn_type',
                        default='lstm',
                        choices=['lstm', 'rnn', 'gru'])
    parser.add_argument('--num_lstm_cell_units', type=int, default=512)
    parser.add_argument('--demo_aggregation',
                        type=str,
                        default='avgpool',
                        choices=['concat', 'avgpool', 'maxpool'],
                        help='how to aggregate the demo features')
    # evaluation task
    parser.add_argument(
        '--no_loss',
        action='store_true',
        default=False,
        help='skip computing and printing the accuracies and losses')
    parser.add_argument('--pred_program',
                        action='store_true',
                        default=False,
                        help='set to True to write out '
                        'predicted and ground truth programs')
    parser.add_argument('--result_data',
                        action='store_true',
                        default=False,
                        help='set to True to save evaluation results')
    parser.add_argument('--result_data_path',
                        type=str,
                        default='result.hdf5',
                        help='the file path to save evaluation results')
    # specify the ids of the testing data that you want to test
    parser.add_argument('--id_list',
                        type=str,
                        help='specify the ids of the data points to evaluate. '
                        'By default, the whole data split is evaluated')
    # unseen test
    parser.add_argument('--unseen_test', action='store_true', default=False)
    # write summary file
    parser.add_argument(
        '--quiet',
        action='store_true',
        default=False,
        help='suppress the per-batch logging of accuracies and losses')
    parser.add_argument('--no_write_summary',
                        action='store_true',
                        default=False,
                        help='set to disable writing out '
                        'the summary of accuracies and losses')
    parser.add_argument(
        '--summary_file',
        type=str,
        default='report.txt',
        help='the path to write the summary of accuracies and losses')
    config = parser.parse_args()

    config.write_summary = not config.no_write_summary

    if config.dataset_type == 'karel':
        import karel_env.dataset_karel as dataset
    elif config.dataset_type == 'vizdoom':
        import vizdoom_env.dataset_vizdoom as dataset
    else:
        raise ValueError('Unknown dataset type: {}'.format(config.dataset_type))

    dataset_train, dataset_test, dataset_val = \
        dataset.create_default_splits(config.dataset_path,
                                      is_train=False, num_k=config.num_k)
    if config.dataset_split == 'train':
        target_dataset = dataset_train
    elif config.dataset_split == 'test':
        target_dataset = dataset_test
    elif config.dataset_split == 'val':
        target_dataset = dataset_val
    else:
        raise ValueError('Unknown dataset split')

    if config.max_steps <= 0:
        config.max_steps = len(target_dataset.ids) // config.batch_size

    if config.dataset_type == 'karel':
        config.perception_type = ''
    elif config.dataset_type == 'vizdoom':
        config.perception_type = target_dataset.perception_type
    else:
        raise ValueError('Unknown dataset type: {}'.format(config.dataset_type))
    # }}}

    # Data dim
    # [n, max_program_len], [max_program_len], [k, max_demo_len, h, w, depth]
    # [k, max_len_demo, ac], [1], [k]
    data_tuple = target_dataset.get_data(target_dataset.ids[0])
    program, _, s_h, test_s_h, a_h, _, _, _, program_len, demo_len, test_demo_len, \
        per, test_per = data_tuple[:13]

    config.dim_program_token = program.shape[0]
    config.max_program_len = program.shape[1]
    config.k = s_h.shape[0]
    config.test_k = test_s_h.shape[0]
    config.max_demo_len = s_h.shape[1]
    config.h = s_h.shape[2]
    config.w = s_h.shape[3]
    config.depth = s_h.shape[4]
    config.action_space = a_h.shape[2]
    config.per_dim = per.shape[2]
    if config.dataset_type == 'karel':
        config.dsl_type = target_dataset.dsl_type
        config.env_type = target_dataset.env_type
        config.vizdoom_pos_keys = []
        config.vizdoom_max_init_pos_len = -1
        config.level = None
    elif config.dataset_type == 'vizdoom':
        config.dsl_type = 'vizdoom_default'  # vizdoom has 1 dsl type for now
        config.env_type = 'vizdoom_default'  # vizdoom has 1 env type
        config.vizdoom_pos_keys = target_dataset.vizdoom_pos_keys
        config.vizdoom_max_init_pos_len = target_dataset.vizdoom_max_init_pos_len
        config.level = target_dataset.level

    evaler = Evaler(config, target_dataset)

    log.warning("dataset: %s", config.dataset_path)
    evaler.eval_run()
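
Because every option has a default, the entry point can also be driven programmatically by overriding sys.argv before calling main(). A usage sketch (the script name and checkpoint path are assumptions, not part of the snippet above):

import sys

sys.argv = ['evaler.py',                      # hypothetical script name
            '--model', 'synthesis_baseline',
            '--dataset_path', 'datasets/karel_dataset',
            '--dataset_split', 'test',
            '--checkpoint', 'train_dir/model-10000',  # hypothetical checkpoint
            '--pred_program']
main()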
Example #14
    def build(self, is_train=True):
        max_demo_len = self.max_demo_len
        demo_len = self.demo_len
        s_h = self.s_h
        depth = self.depth

        # s [bs, h, w, depth] -> feature [bs, v]
        # CNN
        def State_Encoder(s, batch_size, scope='State_Encoder', reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                _ = conv2d(s, 16, is_train, k_h=3, k_w=3,
                           info=not reuse, batch_norm=True, name='conv1')
                _ = conv2d(_, 32, is_train, k_h=3, k_w=3,
                           info=not reuse, batch_norm=True, name='conv2')
                _ = conv2d(_, 48, is_train, k_h=3, k_w=3,
                           info=not reuse, batch_norm=True, name='conv3')
                if self.dataset_type == 'vizdoom':
                    _ = conv2d(_, 48, is_train, k_h=3, k_w=3,
                               info=not reuse, batch_norm=True, name='conv4')
                    _ = conv2d(_, 48, is_train, k_h=3, k_w=3,
                               info=not reuse, batch_norm=True, name='conv5')
                state_feature = tf.reshape(_, [batch_size, -1])
                return state_feature

        # s_h [bs, t, h, w, depth] -> feature [bs, v]
        # LSTM
        def Demo_Encoder(s_h, seq_lengths, scope='Demo_Encoder', reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                state_features = tf.reshape(
                    State_Encoder(tf.reshape(s_h, [-1, self.h, self.w, depth]),
                                  self.batch_size * max_demo_len, reuse=reuse),
                    [self.batch_size, max_demo_len, -1])

                # NOTE: `i` is captured by closure from the enclosing
                # `for i in range(self.k)` demo loop at call time, so each
                # call builds its cell under a distinct 'cell_<i>' scope.
                with tf.variable_scope('cell_{}'.format(i), reuse=reuse):
                    if self.encoder_rnn_type == 'lstm':
                        cell = rnn.BasicLSTMCell(
                            num_units=self.num_lstm_cell_units,
                            state_is_tuple=True)
                    elif self.encoder_rnn_type == 'rnn':
                        cell = rnn.BasicRNNCell(num_units=self.num_lstm_cell_units)
                    elif self.encoder_rnn_type == 'gru':
                        cell = rnn.GRUCell(num_units=self.num_lstm_cell_units)
                    else:
                        raise ValueError('Unknown encoder rnn type')

                new_h, cell_state = tf.nn.dynamic_rnn(
                    cell=cell, dtype=tf.float32, sequence_length=seq_lengths,
                    inputs=state_features)
                all_states = new_h
                return all_states, cell_state.h, cell_state.c

        # program token [bs, len] -> embedded tokens [len] list of [bs, dim]
        # tensors
        # Embedding
        def Token_Embedding(token_dim, embedding_dim,
                            scope='Token_Embedding', reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                # The embedding map has token_dim + 1 rows; the extra row
                # serves as the start token <s>
                embedding_map = tf.get_variable(
                    name="embedding_map", shape=[token_dim + 1, embedding_dim],
                    initializer=tf.random_uniform_initializer(
                        minval=-0.01, maxval=0.01))

                def embedding_lookup(t):
                    embedding = tf.nn.embedding_lookup(embedding_map, t)
                    return embedding
                return embedding_lookup

        # program token feature [bs, u] -> program token [bs, dim_program_token]
        # MLP
        def Token_Decoder(f, token_dim, scope='Token_Decoder', reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                if not reuse: log.warning(scope.name)
                _ = fc(f, token_dim, is_train, info=not reuse,
                       batch_norm=False, activation_fn=None, name='fc1')
                return _

        # Per vector encoder
        def Per_Encoder(token_dim, embedding_dim,
                        scope='Per_Encoder', reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                def embedding_lookup(t):
                    if not reuse: log.warning(scope.name)
                    _ = fc(t, int(embedding_dim/4), is_train,
                           info=not reuse, name='fc1')
                    _ = fc(_, embedding_dim, is_train,
                           info=not reuse, name='fc2')
                    return _
                return embedding_lookup

        # Input {{{
        # =========
        self.ground_truth_program = self.program
        self.gt_tokens = tf.argmax(self.ground_truth_program, axis=1)
        # k-element list of [bs, ac, max_demo_len - 1] tensors
        self.gt_actions_onehot = tf.unstack(
            tf.transpose(self.a_h, [0, 1, 3, 2]), axis=1)
        # k-element list of [bs, max_demo_len - 1] tensors
        self.gt_actions_tokens = tf.unstack(self.a_h_tokens, axis=1)
        self.gt_per = tf.transpose(self.per, [1, 0, 3, 2])
        # }}}

        # Graph {{{
        # =========
        # Demo -> Demo feature
        demo_h_list = []
        demo_c_list = []
        demo_feature_history_list = []
        for i in range(self.k):
            demo_feature_history, demo_h, demo_c = \
                Demo_Encoder(s_h[:, i], demo_len[:, i],
                             reuse=i > 0)
            demo_feature_history_list.append(demo_feature_history)
            demo_h_list.append(demo_h)
            demo_c_list.append(demo_c)
            if i == 0: log.warning(demo_feature_history)
        demo_h_stack = tf.stack(demo_h_list, axis=1)  # [bs, k, v]
        demo_c_stack = tf.stack(demo_c_list, axis=1)  # [bs, k, v]
        if self.demo_aggregation == 'concat':  # [bs, k*v]
            demo_h_summary = tf.reshape(demo_h_stack, [self.batch_size, -1])
            demo_c_summary = tf.reshape(demo_c_stack, [self.batch_size, -1])
        elif self.demo_aggregation == 'avgpool':  # [bs, v]
            demo_h_summary = tf.reduce_mean(demo_h_stack, axis=1)
            demo_c_summary = tf.reduce_mean(demo_c_stack, axis=1)
        elif self.demo_aggregation == 'maxpool':  # [bs, v]
            demo_h_summary = tf.squeeze(
                tf.layers.max_pooling1d(demo_h_stack,
                                        demo_h_stack.get_shape().as_list()[1],
                                        1, padding='valid',
                                        data_format='channels_last'),
                axis=1)
            demo_c_summary = tf.squeeze(
                tf.layers.max_pooling1d(demo_c_stack,
                                        demo_c_stack.get_shape().as_list()[1],
                                        1, padding='valid',
                                        data_format='channels_last'),
                axis=1)
        else:
            raise ValueError('Unknown demo aggregation type')

        def get_DecoderHelper(embedding_lookup, seq_lengths, token_dim,
                              gt_tokens=None, unroll_type='teacher_forcing'):
            if unroll_type == 'teacher_forcing':
                if gt_tokens is None:
                    raise ValueError('teacher_forcing requires gt_tokens')
                embedding = embedding_lookup(gt_tokens)
                helper = seq2seq.TrainingHelper(embedding, seq_lengths)
            elif unroll_type == 'scheduled_sampling':
                if gt_tokens is None:
                    raise ValueError('scheduled_sampling requires gt_tokens')
                embedding = embedding_lookup(gt_tokens)
                # sample_prob 1.0: always sample from ground truth
                # sample_prob 0.0: always sample from prediction
                helper = seq2seq.ScheduledEmbeddingTrainingHelper(
                    embedding, seq_lengths, embedding_lookup,
                    1.0 - self.sample_prob, seed=None,
                    scheduling_seed=None)
            elif unroll_type == 'greedy':
                # during evaluation, we perform greedy unrolling.
                start_token = tf.zeros([self.batch_size], dtype=tf.int32) + \
                    token_dim
                end_token = self.vocab.token2int['m)']
                helper = seq2seq.GreedyEmbeddingHelper(
                    embedding_lookup, start_token, end_token)
            elif unroll_type == 'syntax_greedy':
                start_token = tf.zeros([self.batch_size], dtype=tf.int32) + \
                    token_dim
                end_token = self.vocab.token2int['m)']
                helper = seq2seq_helper.SyntacticGreedyEmbeddingHelper(
                    self.dsl_syntax, self.max_program_len,
                    embedding_lookup, start_token, end_token)
            elif unroll_type == 'syntax_sample':
                start_token = tf.zeros([self.batch_size], dtype=tf.int32) + \
                    token_dim
                end_token = self.vocab.token2int['m)']
                helper = seq2seq_helper.SyntacticSampleEmbeddingHelper(
                    self.dsl_syntax, self.max_program_len,
                    embedding_lookup, start_token, end_token)
            else:
                raise ValueError('Unknown unroll type')
            return helper
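
        # get_DecoderHelper maps unroll types onto tf.contrib.seq2seq helpers:
        #   teacher_forcing    -> TrainingHelper (always feeds ground truth)
        #   scheduled_sampling -> ScheduledEmbeddingTrainingHelper (mixes
        #                         ground truth and model samples)
        #   greedy             -> GreedyEmbeddingHelper (argmax rollout)
        #   syntax_greedy / syntax_sample -> project-specific helpers that
        #                         restrict decoding to syntactically valid tokens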

        def LSTM_Decoder(init_state, gt_tokens, lstm_cell,
                         unroll_type='teacher_forcing',
                         seq_lengths=None, max_sequence_len=10, token_dim=50,
                         embedding_dim=128, scope='LSTM_Decoder', reuse=False):
            with tf.variable_scope(scope, reuse=reuse) as scope:
                # prepend the start token <s>; the augmented embedding map has
                # token_dim + 1 rows, so <s> takes the last valid index
                # (token_dim), matching start_token in the greedy helpers
                s_token = tf.zeros([self.batch_size, 1],
                                   dtype=gt_tokens.dtype) + token_dim
                gt_tokens = tf.concat([s_token, gt_tokens[:, :-1]], axis=1)
                embedding_lookup = Token_Embedding(
                    token_dim, embedding_dim, reuse=reuse)

                # dynamic_decode implementation
                helper = get_DecoderHelper(embedding_lookup, seq_lengths,
                                           token_dim, gt_tokens=gt_tokens,
                                           unroll_type=unroll_type)
                projection_layer = layers_core.Dense(
                    token_dim, use_bias=False, name="output_projection")
                decoder = seq2seq.BasicDecoder(
                    lstm_cell, helper, init_state,
                    output_layer=projection_layer)
                # pred_length [batch_size]: length of the predicted sequence
                outputs, _, pred_length = tf.contrib.seq2seq.dynamic_decode(
                    decoder, maximum_iterations=max_sequence_len,
                    scope='dynamic_decoder')
                pred_length = tf.expand_dims(pred_length, axis=1)

                # as dynamic_decode generates variable-length outputs,
                # we pad them dynamically to match the input embedding shape.
                rnn_output = outputs.rnn_output
                sz = tf.shape(rnn_output)
                dynamic_pad = tf.zeros(
                    [sz[0], max_sequence_len - sz[1], sz[2]],
                    dtype=rnn_output.dtype)
                pred_seq = tf.concat([rnn_output, dynamic_pad], axis=1)
                seq_shape = pred_seq.get_shape().as_list()
                pred_seq.set_shape(
                    [seq_shape[0], max_sequence_len, seq_shape[2]])

                pred_token = outputs.sample_id
                sz = tf.shape(pred_token)
                dynamic_pad = tf.zeros([sz[0], max_sequence_len - sz[1]],
                                       dtype=pred_token.dtype)
                pred_token = tf.concat([pred_token, dynamic_pad], axis=1)
                pred_token.set_shape([pred_token.get_shape().as_list()[0],
                                      max_sequence_len])

                pred_seq = tf.transpose(
                    tf.reshape(pred_seq,
                               [self.batch_size, max_sequence_len, -1]),
                    [0, 2, 1])  # make_dim: [bs, n, len]
                return pred_seq, pred_token, pred_length
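
        # Per the reshapes above, LSTM_Decoder returns:
        #   pred_seq    [bs, token_dim, max_sequence_len]  (logits)
        #   pred_token  [bs, max_sequence_len]             (sampled token ids)
        #   pred_length [bs, 1]                            (decoded lengths)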

        if self.scheduled_sampling:
            train_unroll_type = 'scheduled_sampling'
        else:
            train_unroll_type = 'teacher_forcing'

        # Demo feature -> Program
        # NOTE: `i` still holds its final value (self.k - 1) from the
        # demo-encoding loop above, so this scope is named 'dec_cell_<k-1>'.
        dec_lstm_cells = []
        with tf.variable_scope('dec_cell_{}'.format(i), reuse=False):
            cell = rnn.BasicLSTMCell(
                num_units=self.num_lstm_cell_units)
            dec_lstm_cells.append(cell)
        self.program_lstm_cell = dec_lstm_cells[-1]
        init_state = rnn.LSTMStateTuple(demo_c_summary, demo_h_summary)
        embedding_dim = demo_h_summary.get_shape().as_list()[-1]
        self.pred_program, self.pred_program_tokens, self.pred_program_len = LSTM_Decoder(
            init_state, self.program_tokens,
            self.program_lstm_cell, unroll_type=train_unroll_type,
            seq_lengths=self.program_len[:, 0],
            max_sequence_len=self.max_program_len,
            token_dim=self.dim_program_token,
            embedding_dim=embedding_dim, scope='Program_Decoder', reuse=False
        )
        assert self.pred_program.get_shape() == \
            self.ground_truth_program.get_shape()

        self.greedy_pred_program, self.greedy_pred_program_tokens, self.greedy_pred_program_len = LSTM_Decoder(
            init_state, self.program_tokens,
            self.program_lstm_cell, unroll_type='greedy',
            seq_lengths=self.program_len[:, 0],
            max_sequence_len=self.max_program_len,
            token_dim=self.dim_program_token,
            embedding_dim=embedding_dim, scope='Program_Decoder', reuse=True
        )
        assert self.greedy_pred_program.get_shape() == \
            self.ground_truth_program.get_shape()
        # }}}

        def check_correct_syntax(p_token, p_len, is_same_seq):
            if self.dataset_type == 'karel':
                from karel_env.dsl.dsl_parse import parse
            elif self.dataset_type == 'vizdoom':
                from vizdoom_env.dsl.dsl_parse import parse
            is_correct = []
            for i in range(self.batch_size):
                if is_same_seq[i] == 1:
                    is_correct.append(1)
                else:
                    p_str = self.vocab.intseq2str(p_token[i, :p_len[i, 0]])
                    parse_out = parse(p_str)
                    is_correct.append(1 if parse_out[1] else 0)
            return np.array(is_correct).astype(np.float32)

        # Build losses {{{
        # ================
        def Sequence_Loss(pred_sequence, gt_sequence, pred_sequence_tokens=None,
                          pred_sequence_lengths=None, gt_sequence_lengths=None,
                          max_sequence_len=None, token_dim=None,
                          name=None):
            with tf.name_scope(name, "SequenceOutput"):
                max_sequence_lengths = tf.maximum(pred_sequence_lengths,
                                                  gt_sequence_lengths)
                min_sequence_lengths = tf.minimum(pred_sequence_lengths,
                                                  gt_sequence_lengths)
                gt_mask = tf.sequence_mask(gt_sequence_lengths[:, 0],
                                           max_sequence_len, dtype=tf.float32,
                                           name='mask')
                max_mask = tf.sequence_mask(max_sequence_lengths[:, 0],
                                            max_sequence_len, dtype=tf.float32,
                                            name='max_mask')
                min_mask = tf.sequence_mask(min_sequence_lengths[:, 0],
                                            max_sequence_len, dtype=tf.float32,
                                            name='min_mask')
                labels = tf.reshape(tf.transpose(gt_sequence, [0, 2, 1]),
                                    [self.batch_size*max_sequence_len,
                                     token_dim])
                logits = tf.reshape(tf.transpose(pred_sequence, [0, 2, 1]),
                                    [self.batch_size*max_sequence_len,
                                     token_dim])

                # [bs * max_sequence_len]
                cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                    labels=labels, logits=logits)

                # normalize loss
                loss = tf.reduce_sum(cross_entropy * tf.reshape(gt_mask, [-1])) / \
                    tf.reduce_sum(gt_mask)
                output = [gt_sequence, pred_sequence]

                label_argmax = tf.argmax(labels, axis=-1)
                if pred_sequence_tokens is None:
                    logit_argmax = tf.argmax(logits, axis=-1)
                else:
                    logit_argmax = tf.reshape(
                        tf.cast(pred_sequence_tokens, label_argmax.dtype),
                        [self.batch_size * max_sequence_len])

                # accuracy
                # token level acc
                correct_token_pred = tf.reduce_sum(
                    tf.to_float(tf.equal(label_argmax, logit_argmax)) *
                    tf.reshape(min_mask, [-1]))
                token_accuracy = correct_token_pred / tf.reduce_sum(max_mask)
                # seq level acc
                seq_equal = tf.equal(
                    tf.reshape(tf.to_float(label_argmax) * tf.reshape(gt_mask, [-1]),
                               [self.batch_size, -1]),
                    tf.reshape(tf.to_float(logit_argmax) * tf.reshape(gt_mask, [-1]),
                               [self.batch_size, -1])
                )
                len_equal = tf.equal(gt_sequence_lengths[:, 0],
                                     pred_sequence_lengths[:, 0])
                is_same_seq = tf.to_float(tf.logical_and(
                            tf.reduce_all(seq_equal, axis=-1), len_equal))
                seq_accuracy = tf.reduce_sum(is_same_seq) / self.batch_size

                pred_tokens = tf.reshape(
                    logit_argmax, [self.batch_size, max_sequence_len])
                is_correct_syntax = tf.py_func(
                    check_correct_syntax,
                    [pred_tokens, pred_sequence_lengths, is_same_seq],
                    tf.float32)
                syntax_accuracy = \
                    tf.reduce_sum(is_correct_syntax) / self.batch_size

                output_stat = SequenceLossOutput(
                    mask=gt_mask, loss=loss, output=output,
                    token_acc=token_accuracy,
                    seq_acc=seq_accuracy, syntax_acc=syntax_accuracy,
                    is_correct_syntax=is_correct_syntax,
                    pred_tokens=pred_tokens, is_same_seq=is_same_seq,
                )

                return output_stat

        def exact_program_compare_karel(p_token, p_len, is_correct_syntax,
                                        gt_token, gt_len):
            from karel_env.dsl import dsl_enum_program

            exact_program_correct = []
            for i in range(self.batch_size):
                if is_correct_syntax[i] == 1:
                    p_str = self.vocab.intseq2str(p_token[i, :p_len[i, 0]])
                    gt_str = self.vocab.intseq2str(gt_token[i, :gt_len[i, 0]])

                    p_prog, _ = dsl_enum_program.parse(p_str)
                    gt_prog, _ = dsl_enum_program.parse(gt_str)
                    exact_program_correct.append(float(p_prog == gt_prog))
                else:
                    exact_program_correct.append(0.0)
            return np.array(exact_program_correct, dtype=np.float32)

        def exact_program_compare_vizdoom(p_token, p_len, is_correct_syntax,
                                          gt_token, gt_len):
            from vizdoom_env.dsl import dsl_enum_program

            exact_program_correct = []
            for i in range(self.batch_size):
                if is_correct_syntax[i] == 1:
                    p_str = self.vocab.intseq2str(p_token[i, :p_len[i, 0]])
                    gt_str = self.vocab.intseq2str(gt_token[i, :gt_len[i, 0]])

                    p_prog, _ = dsl_enum_program.parse(p_str)
                    gt_prog, _ = dsl_enum_program.parse(gt_str)
                    exact_program_correct.append(float(p_prog == gt_prog))
                else:
                    exact_program_correct.append(0.0)
            return np.array(exact_program_correct, dtype=np.float32)
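
        # exact_program_compare_karel and exact_program_compare_vizdoom are
        # identical apart from which dsl_enum_program module they import.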

        def generate_program_output_karel(initial_states, max_demo_len,
                                          demo_k, h, w, depth,
                                          p_token, p_len, is_correct_syntax,
                                          is_same_seq):
            from karel_env import karel
            from karel_env.dsl.dsl_parse import parse

            batch_pred_demos = []
            batch_pred_demo_len = []
            for i in range(self.batch_size):
                pred_demos = []
                pred_demo_len = []
                for k in range(demo_k):
                    if is_same_seq[i] == 0 and is_correct_syntax[i] == 1:
                        p_str = self.vocab.intseq2str(p_token[i, :p_len[i, 0]])
                        exe, s_exe = parse(p_str)
                        if not s_exe:
                            raise RuntimeError("s_exe couldn't be False here")
                        karel_world, n, s_run = exe(
                            karel.Karel_world(initial_states[i, k],
                                              make_error=self.env_type != 'no_error'),
                            0)
                        if s_run:
                            exe_s_h = copy.deepcopy(karel_world.s_h)
                            pred_demo_len.append(len(exe_s_h))
                            pred_demo = np.stack(exe_s_h[:pred_demo_len[-1]], axis=0)
                            padded = np.zeros([max_demo_len, h, w, depth])
                            padded[:pred_demo.shape[0], :, :, :] = pred_demo[:max_demo_len]
                            pred_demos.append(padded)
                        else:
                            pred_demo_len.append(0)
                            pred_demos.append(
                                np.zeros([max_demo_len, h, w, depth]))
                    else:
                        pred_demo_len.append(0)
                        pred_demos.append(
                            np.zeros([max_demo_len, h, w, depth]))
                batch_pred_demos.append(np.stack(pred_demos, axis=0))
                batch_pred_demo_len.append(np.stack(pred_demo_len, axis=0))
            return np.stack(batch_pred_demos, axis=0).astype(np.float32), \
                np.stack(batch_pred_demo_len, axis=0).astype(np.int32)

        def generate_program_output_vizdoom(init_pos, init_pos_len,
                                            vizdoom_pos_keys,
                                            max_demo_len,
                                            demo_k, h, w, depth,
                                            p_token, p_len, is_correct_syntax,
                                            is_same_seq):
            from vizdoom_env.vizdoom_env import Vizdoom_env
            from vizdoom_env.dsl.dsl_parse import parse
            from cv2 import resize, INTER_AREA

            world = Vizdoom_env(config='vizdoom_env/asset/default.cfg',
                                perception_type=self.perception_type)
            world.init_game()
            batch_pred_demos = []
            batch_pred_demo_len = []
            for i in range(self.batch_size):
                pred_demos = []
                pred_demo_len = []
                for k in range(demo_k):
                    if is_same_seq[i] == 0 and is_correct_syntax[i] == 1:
                        init_dict = {}
                        for p, key in enumerate(vizdoom_pos_keys):
                            init_dict[key] = np.squeeze(
                                init_pos[i, k, p][:init_pos_len[i, k, p]])
                        world.new_episode(init_dict)
                        p_str = self.vocab.intseq2str(p_token[i, :p_len[i, 0]])
                        exe, compile_success = parse(p_str)
                        if not compile_success:
                            raise RuntimeError(
                                "Compile failure should not happen here")
                        new_w, num_call, success = exe(world, 0)
                        if success:
                            exe_s_h = []
                            for s in world.s_h:
                                if s.shape[0] != h or s.shape[1] != w:
                                    s = resize(s, (h, w),
                                               interpolation=INTER_AREA)
                                exe_s_h.append(s.copy())
                            pred_demo_len.append(len(exe_s_h))
                            pred_demo = np.stack(exe_s_h[:pred_demo_len[-1]],
                                                 axis=0)
                            padded = np.zeros([max_demo_len, h, w, depth])
                            padded[:pred_demo.shape[0], :, :, :] = \
                                pred_demo[:max_demo_len]
                            pred_demos.append(padded)
                        else:
                            pred_demo_len.append(0)
                            pred_demos.append(
                                np.zeros([max_demo_len, h, w, depth]))
                    else:
                        pred_demo_len.append(0)
                        pred_demos.append(
                            np.zeros([max_demo_len, h, w, depth]))
                batch_pred_demos.append(np.stack(pred_demos, axis=0))
                batch_pred_demo_len.append(np.stack(pred_demo_len, axis=0))
            world.end_game()
            return np.stack(batch_pred_demos, axis=0).astype(np.float32), \
                np.stack(batch_pred_demo_len, axis=0).astype(np.int32)

        def ExecuteProgram(s_h, max_demo_len, k, h, w, depth,
                           p_token, p_len,
                           is_correct_syntax, is_same_seq,
                           init_pos=None,
                           init_pos_len=None):
            if self.dataset_type == 'karel':
                initial_states = s_h[:, :, 0, :, :, :]  # [bs, k, h, w, depth]
                execution, execution_len = tf.py_func(
                    generate_program_output_karel,
                    [initial_states,
                     max_demo_len, k, h, w, depth,
                     p_token, p_len, is_correct_syntax, is_same_seq],
                    (tf.float32, tf.int32))
            elif self.dataset_type == 'vizdoom':
                execution, execution_len = tf.py_func(
                    generate_program_output_vizdoom,
                    [init_pos, init_pos_len, self.vizdoom_pos_keys,
                     max_demo_len, k, h, w, depth,
                     p_token, p_len, is_correct_syntax, is_same_seq],
                    (tf.float32, tf.int32))
            else:
                raise ValueError('Unknown dataset_type')
            execution.set_shape([self.batch_size, k,
                                 max_demo_len, h, w, depth])
            execution_len.set_shape([self.batch_size, k])
            return execution, execution_len

        def ExactProgramCompare(p_token, p_len, is_correct_syntax, gt_token, gt_len):
            if self.dataset_type == 'karel':
                exact_program_correct = tf.py_func(
                    exact_program_compare_karel,
                    [p_token, p_len, is_correct_syntax, gt_token, gt_len],
                    (tf.float32))
            elif self.dataset_type == 'vizdoom':
                exact_program_correct = tf.py_func(
                    exact_program_compare_vizdoom,
                    [p_token, p_len, is_correct_syntax, gt_token, gt_len],
                    (tf.float32))
            else:
                raise ValueError('Unknown dataset_type')
            exact_program_correct.set_shape([self.batch_size])
            exact_program_accuracy = tf.reduce_mean(exact_program_correct)
            return exact_program_correct, exact_program_accuracy

        def CompareDemoAndExecution(demo, demo_len, k,
                                    execution, execution_len,
                                    is_same_program):
            _ = tf.equal(demo, execution)
            _ = tf.reduce_all(_, axis=-1)  # reduce depth
            _ = tf.reduce_all(_, axis=-1)  # reduce w
            _ = tf.reduce_all(_, axis=-1)  # reduce h
            _ = tf.reduce_all(_, axis=-1)  # reduce sequence length
            is_same_execution = _  # [bs, k]
            is_same_len = tf.equal(demo_len, execution_len)  # [bs, k]

            is_correct_execution = tf.logical_or(
                tf.logical_and(is_same_execution, is_same_len),
                tf.tile(
                    tf.expand_dims(tf.cast(is_same_program, tf.bool), axis=1),
                    [1, k]))  # [bs, k]
            num_correct_execution = tf.reduce_sum(
                tf.to_float(is_correct_execution), axis=-1)

            hist_list = []
            for i in range(k + 1):
                eq_i = tf.to_float(tf.equal(num_correct_execution, i))
                hist_list.append(tf.reduce_sum(eq_i) / self.batch_size)
            execution_acc_hist = tf.stack(hist_list, axis=0)
            return num_correct_execution, is_correct_execution, execution_acc_hist
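
        # execution_acc_hist[i] is the fraction of the batch whose predicted
        # program reproduces exactly i of the k demonstrations; index k is
        # therefore the fully-correct rate.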

        self.loss = 0
        self.output = []

        program_stat = Sequence_Loss(
            self.pred_program,
            self.ground_truth_program,
            pred_sequence_lengths=self.program_len,
            gt_sequence_lengths=self.program_len,
            max_sequence_len=self.max_program_len,
            token_dim=self.dim_program_token,
            name="Program_Sequence_Loss")

        self.program_is_correct_syntax = program_stat.is_correct_syntax
        self.loss += program_stat.loss
        self.output.extend(program_stat.output)

        self.pred_exact_program_correct, self.pred_exact_program_accuracy = \
            ExactProgramCompare(program_stat.pred_tokens, self.program_len,
                                program_stat.is_correct_syntax,
                                self.gt_tokens, self.program_len)

        # Execute program with TRAINING demo initial states
        program_execution, program_execution_len = ExecuteProgram(
            self.s_h, self.max_demo_len, self.k, self.h, self.w, self.depth,
            program_stat.pred_tokens, self.program_len,
            program_stat.is_correct_syntax, program_stat.is_same_seq,
            init_pos=self.init_pos, init_pos_len=self.init_pos_len)
        self.program_num_execution_correct, self.program_is_correct_execution, \
            program_execution_acc_hist = \
            CompareDemoAndExecution(self.s_h, self.demo_len, self.k,
                                    program_execution, program_execution_len,
                                    program_stat.is_same_seq)
        # Execute program with TESTING demo initial states
        test_program_execution, test_program_execution_len = ExecuteProgram(
            self.test_s_h, self.max_demo_len,
            self.test_k, self.h, self.w, self.depth,
            program_stat.pred_tokens, self.program_len,
            program_stat.is_correct_syntax, program_stat.is_same_seq,
            init_pos=self.test_init_pos, init_pos_len=self.test_init_pos_len)
        self.test_program_num_execution_correct, \
            self.test_program_is_correct_execution, \
            test_program_execution_acc_hist = \
            CompareDemoAndExecution(self.test_s_h, self.test_demo_len,
                                    self.test_k,
                                    test_program_execution,
                                    test_program_execution_len,
                                    program_stat.is_same_seq)

        greedy_program_stat = Sequence_Loss(
            self.greedy_pred_program,
            self.ground_truth_program,
            pred_sequence_tokens=self.greedy_pred_program_tokens,
            pred_sequence_lengths=self.greedy_pred_program_len,
            gt_sequence_lengths=self.program_len,
            max_sequence_len=self.max_program_len,
            token_dim=self.dim_program_token,
            name="Greedy_Program_Sequence_Loss")

        self.greedy_program_is_correct_syntax = \
            greedy_program_stat.is_correct_syntax

        self.greedy_exact_program_correct, self.greedy_exact_program_accuracy = \
            ExactProgramCompare(greedy_program_stat.pred_tokens, self.greedy_pred_program_len,
                                greedy_program_stat.is_correct_syntax,
                                self.gt_tokens, self.program_len)

        # Execute program with TRAINING demo initial states
        greedy_execution, greedy_execution_len = ExecuteProgram(
            self.s_h, self.max_demo_len, self.k, self.h, self.w, self.depth,
            greedy_program_stat.pred_tokens, self.greedy_pred_program_len,
            greedy_program_stat.is_correct_syntax,
            greedy_program_stat.is_same_seq,
            init_pos=self.init_pos, init_pos_len=self.init_pos_len)
        self.greedy_num_execution_correct, self.greedy_is_correct_execution, \
            greedy_execution_acc_hist = \
            CompareDemoAndExecution(self.s_h, self.demo_len, self.k,
                                    greedy_execution, greedy_execution_len,
                                    greedy_program_stat.is_same_seq)
        # Execute program with TESTING demo initial states
        test_greedy_execution, test_greedy_execution_len = ExecuteProgram(
            self.test_s_h, self.max_demo_len, self.test_k,
            self.h, self.w, self.depth,
            greedy_program_stat.pred_tokens, self.greedy_pred_program_len,
            greedy_program_stat.is_correct_syntax,
            greedy_program_stat.is_same_seq,
            init_pos=self.test_init_pos, init_pos_len=self.test_init_pos_len)
        self.test_greedy_num_execution_correct, \
            self.test_greedy_is_correct_execution, \
            test_greedy_execution_acc_hist = \
            CompareDemoAndExecution(self.test_s_h, self.test_demo_len,
                                    self.test_k,
                                    test_greedy_execution,
                                    test_greedy_execution_len,
                                    greedy_program_stat.is_same_seq)
        # }}}

        # Evaluation {{{
        # ==============
        self.report_loss = {}
        self.report_accuracy = {}
        self.report_hist = {}
        self.report_loss['program_loss'] = program_stat.loss
        self.report_accuracy['program_token_acc'] = program_stat.token_acc
        self.report_accuracy['program_seq_acc'] = program_stat.seq_acc
        self.report_accuracy['program_syntax_acc'] = program_stat.syntax_acc
        self.report_accuracy['pred_exact_program_accuracy'] = \
            self.pred_exact_program_accuracy
        self.report_accuracy['greedy_exact_program_accuracy'] = \
            self.greedy_exact_program_accuracy
        self.report_loss['greedy_program_loss'] = greedy_program_stat.loss
        self.report_accuracy['greedy_program_token_acc'] = \
            greedy_program_stat.token_acc
        self.report_accuracy['greedy_program_seq_acc'] = \
            greedy_program_stat.seq_acc
        self.report_accuracy['greedy_program_syntax_acc'] = \
            greedy_program_stat.syntax_acc
        self.report_hist['program_execution_acc_hist'] = \
            program_execution_acc_hist
        self.report_hist['greedy_program_execution_acc_hist'] = \
            greedy_execution_acc_hist
        self.report_hist['test_program_execution_acc_hist'] = \
            test_program_execution_acc_hist
        self.report_hist['test_greedy_program_execution_acc_hist'] = \
            test_greedy_execution_acc_hist
        self.report_output = []

        # Tensorboard Summary {{{
        # =======================
        # Loss
        def train_test_scalar_summary(name, value):
            tf.summary.scalar(name, value, collections=['train'])
            tf.summary.scalar("test_{}".format(name), value,
                              collections=['test'])

        train_test_scalar_summary("loss/loss", self.loss)
        train_test_scalar_summary("loss/program_loss", program_stat.loss)
        train_test_scalar_summary("loss/program_token_acc",
                                  program_stat.token_acc)
        train_test_scalar_summary("loss/program_seq_acc",
                                  program_stat.seq_acc)
        train_test_scalar_summary("loss/program_syntax_acc",
                                  program_stat.syntax_acc)
        if self.scheduled_sampling:
            train_test_scalar_summary("loss/sample_prob", self.sample_prob)
        tf.summary.scalar("test_loss/greedy_program_loss",
                          greedy_program_stat.loss, collections=['test'])
        tf.summary.scalar("test_loss/greedy_program_token_acc",
                          greedy_program_stat.token_acc, collections=['test'])
        tf.summary.scalar("test_loss/greedy_program_seq_acc",
                          greedy_program_stat.seq_acc, collections=['test'])
        tf.summary.scalar("test_loss/greedy_program_syntax_acc",
                          greedy_program_stat.syntax_acc, collections=['test'])

        def program2str(p_token, p_len):
            program_str = []
            for i in range(self.batch_size):
                program_str.append(
                    self.vocab.intseq2str(
                        p_token[i][:p_len[i, 0]]))
            program_str = np.stack(program_str, axis=0)
            return program_str

        tf.summary.text('program_id/id', self.program_id, collections=['train'])
        tf.summary.text('program/pred',
                        tf.py_func(
                            program2str,
                            [tf.argmax(self.pred_program, axis=1), self.program_len],
                            tf.string),
                        collections=['train'])
        tf.summary.text('program/ground_truth',
                        tf.py_func(
                            program2str,
                            [tf.argmax(self.ground_truth_program, axis=1), self.program_len],
                            tf.string),
                        collections=['train'])
        tf.summary.text('test_program_id/id', self.program_id,
                        collections=['test'])
        tf.summary.text('test_program/pred',
                        tf.py_func(
                            program2str,
                            [tf.argmax(self.pred_program, axis=1), self.program_len],
                            tf.string),
                        collections=['test'])
        tf.summary.text('test_program/greedy_pred',
                        tf.py_func(
                            program2str,
                            [self.greedy_pred_program_tokens,
                             self.greedy_pred_program_len],
                            tf.string),
                        collections=['test'])
        tf.summary.text('test_program/ground_truth',
                        tf.py_func(
                            program2str,
                            [tf.argmax(self.ground_truth_program, axis=1), self.program_len],
                            tf.string),
                        collections=['test'])

        # Visualization
        def visualized_map(pred, gt):
            dummy = tf.expand_dims(tf.zeros_like(pred), axis=-1)
            pred = tf.expand_dims(pred, axis=-1)
            gt = tf.expand_dims(gt, axis=-1)
            return tf.clip_by_value(tf.concat([pred, gt, dummy], axis=-1), 0, 1)
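
        # Channel layout: red = prediction, green = ground truth, blue = 0;
        # tokens where prediction and ground truth agree with high
        # probability therefore render yellow in TensorBoard.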

        if self.debug:
            tiled_mask = tf.tile(tf.expand_dims(program_stat.mask, axis=1),
                                 [1, self.dim_program_token, 1])
            tf.summary.image("debug/mask",
                             tf.image.grayscale_to_rgb(
                                 tf.expand_dims(tiled_mask, -1)),
                             collections=['train'])
        tf.summary.image("visualized_program",
                         visualized_map(tf.nn.softmax(self.pred_program, axis=1),
                                        self.ground_truth_program),
                         collections=['train'])
        tf.summary.image("test_visualized_program",
                         visualized_map(tf.nn.softmax(self.pred_program, axis=1),
                                        self.ground_truth_program),
                         collections=['test'])
        tf.summary.image("test_visualized_greedy_program",
                         visualized_map(tf.one_hot(self.greedy_pred_program_tokens,
                                                   self.dim_program_token, axis=1),
                                        self.ground_truth_program),
                         collections=['test'])
        if self.dataset_type == 'vizdoom':
            tf.summary.image("state/initial_state",
                             self.s_h[:, 0, 0, :, :, :],
                             collections=['train'])
            tf.summary.image("state/demo_program_1",
                             self.s_h[0, 0, :, :, :, :],
                             max_outputs=self.max_demo_len,
                             collections=['train'])

        # Visualize demo features
        if self.debug:
            i = 0  # show only the first images
            tf.summary.image("debug/demo_feature_history/k_{}".format(i),
                             tf.image.grayscale_to_rgb(
                                 tf.expand_dims(demo_feature_history_list[i],
                                                -1)),
                             collections=['train'])
        # }}}
        print('\033[93mSuccessfully loaded the model.\033[0m')