Exemplo n.º 1
0
def conv_block(input, conv_dim, input_dim, res_params, scope):
    """Build a multi-kernel convolutional stem followed by residual blocks.

    The input is reshaped to ``[-1, 1, input_dim, 1]`` and fed to four
    parallel convolutions with kernel widths 1, 2, 3 and 4 (stride 4 along
    the width axis, ``SAME`` padding, no activation function, ``conv_dim``
    filters each). The four outputs are concatenated on the last axis, passed
    through the blocks listed in ``res_params`` and finally flattened to one
    row per example.

    Args:
        input: Tensor that can be reshaped to ``[-1, 1, input_dim, 1]``.
        conv_dim: Number of filters of each of the four stem convolutions.
        input_dim: Width of a single example after the reshape.
        res_params: Iterable of layer specs. The last element of a spec picks
            the block (``'identity'``, ``'downsampling'`` or
            ``'upsampling'``); elements 0 and 1 are forwarded, in that order,
            as the second and third positional arguments of the block.
        scope: Variable scope under which every layer is created.

    Returns:
        A 2-D tensor whose second dimension is the product of dimensions
        1, 2 and 3 of the last block's output.

    Raises:
        ValueError: If a spec in ``res_params`` names an unknown block type.
    """
    with tf.variable_scope(scope):
        input_conv = tf.reshape(input, [-1, 1, input_dim, 1])

        # One branch per kernel width. Keep the ascending order: unnamed slim
        # layers get their scope names from creation order, so reordering
        # would rename the variables of previously saved checkpoints.
        branches = [
            slim.conv2d(activation_fn=None,
                        inputs=input_conv,
                        num_outputs=conv_dim,
                        kernel_size=[1, kernel_width],
                        stride=[1, 4],
                        padding='SAME')
            for kernel_width in (1, 2, 3, 4)
        ]
        conv = tf.concat(branches, -1)

        for param in res_params:
            layer_type = param[-1]
            if layer_type == 'identity':
                conv = identity_block(conv, param[0], param[1])
            elif layer_type == 'downsampling':
                conv = downsample_block(conv, param[0], param[1])
            elif layer_type == 'upsampling':
                conv = upsample_block(conv, param[0], param[1])
            else:
                # ValueError is still an Exception, so existing broad handlers
                # keep working; the message now names the rejected value.
                raise ValueError(
                    'unsupported layer type: {!r}'.format(layer_type))

        # Collapse height, width and channels into a single feature axis.
        flat_dim = conv.shape[1] * conv.shape[2] * conv.shape[3]
        conv = tf.reshape(conv, [-1, flat_dim])
    return conv
Exemplo n.º 2
0
    def build_graph(self, onehot_cards):
        """Assemble the convolutional auto-encoder and return its loss.

        Inside the ``'AutoEncoder'`` variable scope (with an L2 weight
        regularizer of 1e-3 on ``slim.conv2d`` / ``slim.conv2d_transpose``)
        the input is reshaped to ``[-1, 1, INPUT_DIM, 1]``, run through four
        parallel stride-4 convolutions (kernel widths 1-4, 32 filters each),
        an encoder stack, a mean over axes 1 and 2 (exposed as the op named
        ``'encoding'``) and a decoder stack whose output is flattened into
        logits.

        Args:
            onehot_cards: Input tensor; it is also used, zero-padded by four
                columns on axis 1, as the reconstruction target.

        Returns:
            The sum of the reconstruction loss (sigmoid cross-entropy summed
            over the last axis, averaged over the rest) and the collected
            regularization cost divided by the batch size. Both terms are
            also registered with ``add_moving_summary``.

        Raises:
            Exception: If a layer spec names an unknown block type.
        """

        def apply_layers(net, layer_specs):
            # Thread `net` through the block selected by the last element of
            # each spec; elements 0 and 1 are passed on to that block.
            for spec in layer_specs:
                kind = spec[-1]
                if kind == 'identity':
                    block_fn = identity_block
                elif kind == 'upsampling':
                    block_fn = upsample_block
                elif kind == 'downsampling':
                    block_fn = downsample_block
                else:
                    raise Exception('unsupported layer type')
                net = block_fn(net, spec[0], spec[1])
            return net

        with tf.variable_scope('AutoEncoder'):
            with slim.arg_scope([slim.conv2d, slim.conv2d_transpose],
                                weights_regularizer=slim.l2_regularizer(1e-3)):
                input_conv = tf.reshape(onehot_cards, [-1, 1, INPUT_DIM, 1])

                # Stem: one branch per kernel width, created in ascending
                # order and joined on the channel axis.
                stem_branches = [
                    slim.conv2d(activation_fn=None, inputs=input_conv,
                                num_outputs=32, kernel_size=[1, kernel_width],
                                stride=[1, 4], padding='SAME')
                    for kernel_width in (1, 2, 3, 4)
                ]
                conv = tf.concat(stem_branches, -1)

                encoder_specs = (
                    (128, 3, 'identity'),
                    (128, 3, 'identity'),
                    (128, 3, 'downsampling'),
                    (128, 3, 'identity'),
                    (128, 3, 'identity'),
                    (256, 3, 'downsampling'),
                    (256, 3, 'identity'),
                    (256, 3, 'identity'),
                )
                conv = apply_layers(conv, encoder_specs)

                # Average over axes 1 and 2, keeping them as size-1 dims.
                conv = tf.reduce_mean(conv, [1, 2], True)
                # Only creates a named alias of the code; the decoder below
                # continues from `conv` itself.
                tf.identity(conv, name='encoding')

                decoder_specs = (
                    (256, 4, 'upsampling'),
                    (256, 3, 'identity'),
                    (256, 3, 'identity'),
                    (256, 4, 'upsampling'),
                    (128, 3, 'identity'),
                    (128, 3, 'identity'),
                    (128, 4, 'upsampling'),
                    (128, 3, 'identity'),
                    (1, 3, 'identity'),
                )
                conv = apply_layers(conv, decoder_specs)

                print(conv.shape)
                flat_dim = conv.shape[1] * conv.shape[2] * conv.shape[3]
                decoded = tf.reshape(conv, [-1, flat_dim])

        # Four zero columns are appended to the targets; presumably the
        # decoder emits four more logits than the input has columns -- confirm
        # against upsample_block.
        padded_targets = tf.pad(onehot_cards, [[0, 0], [0, 4]])
        elementwise_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=padded_targets, logits=decoded)
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(elementwise_loss, -1), name='reconstruct_loss')

        regularization_cost = regularize_cost_from_collection()
        batch_size = tf.cast(tf.shape(onehot_cards)[0], tf.float32)
        l2_loss = tf.truediv(regularization_cost, batch_size, name='l2_loss')

        add_moving_summary(reconstruction_loss, decay=0)
        add_moving_summary(l2_loss, decay=0)
        return reconstruction_loss + l2_loss